diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java index 76ce80fb4cc..e1d5de27c65 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java @@ -224,6 +224,16 @@ public void suppressNonSpeechTokens(boolean enable) { /** No speech threshold. */ public float no_speech_thold; + /** + * RNG seed for reproducible sampling (when temperature > 0). + * Each decoder uses {@code seed + decoder_index} so concurrent decoders get + * unique seeds. Maps to the {@code seed} field added at + * {@code include/whisper.h:553}; without this field declared here the + * subsequent {@code greedy} / {@code beam_search} struct offsets shift by + * 4 bytes and JNA reads garbage from the C-side defaults. + */ + public int seed; + /** Greedy decoding parameters. */ public GreedyParams greedy; @@ -331,6 +341,21 @@ public void setLogitsFilterCallback(WhisperLogitsFilterCallback callback) { public long i_start_rule; public float grammar_penalty; + // Voice Activity Detection (VAD) params -- added by upstream after v1.8.4. + // Without these three fields declared here the C struct's tail is missing + // from the JNA layout, which is fine for read-only callers but corrupts + // the trailing memory whenever a Java caller passes WhisperFullParams + // back into the C ABI (e.g. whisper_full). + + /** Enable VAD pre-filtering inside whisper_full. (default = false) */ + public CBool vad; + + /** Path to the Silero VAD model (only used when {@link #vad} is true). */ + public String vad_model_path; + + /** VAD tuning knobs, mirrors {@code whisper_vad_params}. */ + public WhisperVadParams vad_params; + @Override protected List getFieldOrder() { return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", @@ -343,13 +368,15 @@ protected List getFieldOrder() { "prompt_tokens", "prompt_n_tokens", "language", "detect_language", "suppress_blank", "suppress_nst", "temperature", "max_initial_ts", "length_penalty", "temperature_inc", - "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", - "beam_search", "new_segment_callback", "new_segment_callback_user_data", - "progress_callback", "progress_callback_user_data", - "encoder_begin_callback", "encoder_begin_callback_user_data", - "abort_callback", "abort_callback_user_data", - "logits_filter_callback", "logits_filter_callback_user_data", - "grammar_rules", "n_grammar_rules", "i_start_rule", "grammar_penalty"); + "entropy_thold", "logprob_thold", "no_speech_thold", "seed", + "greedy", "beam_search", "new_segment_callback", + "new_segment_callback_user_data", "progress_callback", + "progress_callback_user_data", "encoder_begin_callback", + "encoder_begin_callback_user_data", "abort_callback", + "abort_callback_user_data", "logits_filter_callback", + "logits_filter_callback_user_data", "grammar_rules", + "n_grammar_rules", "i_start_rule", "grammar_penalty", + "vad", "vad_model_path", "vad_params"); } public static class ByValue extends WhisperFullParams implements Structure.ByValue { diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperVadParams.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperVadParams.java new file mode 100644 index 00000000000..d7487edf811 --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperVadParams.java @@ -0,0 +1,61 @@ +package io.github.ggerganov.whispercpp.params; + +import com.sun.jna.Pointer; +import com.sun.jna.Structure; + +import java.util.Arrays; +import java.util.List; + +/** + * Voice Activity Detection (VAD) parameters. + * Mirrors {@code struct whisper_vad_params} in include/whisper.h. + */ +public class WhisperVadParams extends Structure { + + public WhisperVadParams() { + super(); + } + + public WhisperVadParams(Pointer p) { + super(p); + } + + /** Probability threshold to consider as speech. */ + public float threshold; + + /** Min duration for a valid speech segment. */ + public int min_speech_duration_ms; + + /** Min silence duration to consider speech as ended. */ + public int min_silence_duration_ms; + + /** Max duration of a speech segment before forcing a new segment. */ + public float max_speech_duration_s; + + /** Padding added before and after speech segments. */ + public int speech_pad_ms; + + /** Overlap in seconds when copying audio samples from speech segment. */ + public float samples_overlap; + + @Override + protected List getFieldOrder() { + return Arrays.asList( + "threshold", + "min_speech_duration_ms", + "min_silence_duration_ms", + "max_speech_duration_s", + "speech_pad_ms", + "samples_overlap"); + } + + public static class ByValue extends WhisperVadParams implements Structure.ByValue { + public ByValue() { + super(); + } + + public ByValue(Pointer p) { + super(p); + } + } +} diff --git a/bindings/ruby/README.md b/bindings/ruby/README.md index c6280a6926a..41e7b330d58 100644 --- a/bindings/ruby/README.md +++ b/bindings/ruby/README.md @@ -202,6 +202,8 @@ whisper.transcribe("path/to/audio.wav", params, n_processors: Etc.nprocessors) Note that transcription occasionally might be low accuracy when it works in parallel. +If n_processors is greater than 1, you cannot set any callbacks including new_segment_callback, progress_callback, encoder_begin_callback, abort_callback, and log_callback set by Whisper.log_set. + ### Segments ### Once `Whisper::Context#transcribe` called, you can retrieve segments by `#each_segment`: diff --git a/bindings/ruby/ext/ruby_whisper.c b/bindings/ruby/ext/ruby_whisper.c index ba71d4ba594..5f1917ee805 100644 --- a/bindings/ruby/ext/ruby_whisper.c +++ b/bindings/ruby/ext/ruby_whisper.c @@ -112,6 +112,10 @@ ruby_whisper_log_callback(enum ggml_log_level level, const char * buffer, void * return; } VALUE log_callback = rb_iv_get(mWhisper, "log_callback"); + if (NIL_P(log_callback)) { + return; + } + VALUE udata = rb_iv_get(mWhisper, "user_data"); rb_funcall(log_callback, id_call, 3, INT2NUM(level), rb_str_new2(buffer), udata); } @@ -129,10 +133,16 @@ static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_d rb_iv_set(self, "log_callback", log_callback); rb_iv_set(self, "user_data", user_data); - VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback")); - rb_define_finalizer(log_callback, finalize_log_callback); + if (!NIL_P(log_callback)) { + VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback")); + rb_define_finalizer(log_callback, finalize_log_callback); + } - whisper_log_set(ruby_whisper_log_callback, NULL); + if (NIL_P(log_callback)) { + whisper_log_set(NULL, NULL); + } else { + whisper_log_set(ruby_whisper_log_callback, NULL); + } return Qnil; } diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index 8dfd103c17a..6b0b4df7214 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -2,6 +2,7 @@ #define RUBY_WHISPER_H #include +#include #include #include "whisper.h" diff --git a/bindings/ruby/ext/ruby_whisper_context.c b/bindings/ruby/ext/ruby_whisper_context.c index c39d43bd76c..6e38ead6321 100644 --- a/bindings/ruby/ext/ruby_whisper_context.c +++ b/bindings/ruby/ext/ruby_whisper_context.c @@ -22,7 +22,7 @@ extern const rb_data_type_t ruby_whisper_context_params_type; extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self); extern VALUE rb_whisper_model_s_new(VALUE context); extern VALUE rb_whisper_segment_s_new(VALUE context, int index); -extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context); +extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors); ID transcribe_option_names[1]; @@ -436,7 +436,7 @@ full_body(VALUE rb_args) GetContext(*args->context, rw); TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); - prepare_transcription(rwp, args->context); + prepare_transcription(rwp, args->context, 1); int result = whisper_full(rw->context, rwp->params, args->samples, args->n_samples); return INT2NUM(result); @@ -487,7 +487,7 @@ full_parallel_body(VALUE rb_args) GetContext(*args->context, rw); TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); - prepare_transcription(rwp, args->context); + prepare_transcription(rwp, args->context, args->n_processors); int result = whisper_full_parallel(rw->context, rwp->params, args->samples, args->n_samples, args->n_processors); return INT2NUM(result); diff --git a/bindings/ruby/ext/ruby_whisper_params.c b/bindings/ruby/ext/ruby_whisper_params.c index 61eb1733676..3e5dca9c1e1 100644 --- a/bindings/ruby/ext/ruby_whisper_params.c +++ b/bindings/ruby/ext/ruby_whisper_params.c @@ -29,6 +29,7 @@ extern VALUE cParams; extern VALUE cVADParams; +extern VALUE mWhisper; extern ID id_call; @@ -186,6 +187,35 @@ static bool abort_callback(void * user_data) { return false; } +static void +check_thread_safety(ruby_whisper_params *rwp, VALUE *context, int n_processors) +{ + if (n_processors == 1) { + return; + } + + if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) { + rb_raise(rb_eRuntimeError, "new segment callback not supported on parallel transcription"); + } + + if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) { + rb_raise(rb_eRuntimeError, "progress callback not supported on parallel transcription"); + } + + if (!NIL_P(rwp->encoder_begin_callback_container->callback) || 0 != RARRAY_LEN(rwp->encoder_begin_callback_container->callbacks)) { + rb_raise(rb_eRuntimeError, "encoder begin callback not supported on parallel transcription"); + } + + if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) { + rb_raise(rb_eRuntimeError, "abort callback not supported on parallel transcription"); + } + + VALUE log_callback = rb_iv_get(mWhisper, "log_callback"); + if (!NIL_P(log_callback)) { + rb_raise(rb_eRuntimeError, "log callback not supported for parallel transcription"); + } +} + static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) { if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) { rwp->new_segment_callback_container->context = context; @@ -219,9 +249,13 @@ static void set_vad_params(ruby_whisper_params *rwp) rwp->params.vad_params = rwvp->params; } +/* + TODO: Set abort callback to trap SIGINT and SIGTERM +*/ void -prepare_transcription(ruby_whisper_params *rwp, VALUE *context) +prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors) { + check_thread_safety(rwp, context, n_processors); register_callbacks(rwp, context); set_vad_params(rwp); } @@ -240,6 +274,20 @@ rb_whisper_params_mark(void *p) void ruby_whisper_params_free(ruby_whisper_params *rwp) { + if (rwp->params.language) { + ruby_xfree((void *)rwp->params.language); + } + if (rwp->params.initial_prompt) { + ruby_xfree((void *)rwp->params.initial_prompt); + } + if (rwp->params.vad_model_path) { + ruby_xfree((void *)rwp->params.vad_model_path); + } + + xfree(rwp->new_segment_callback_container); + xfree(rwp->progress_callback_container); + xfree(rwp->encoder_begin_callback_container); + xfree(rwp->abort_callback_container); } void @@ -248,7 +296,7 @@ rb_whisper_params_free(void *p) ruby_whisper_params *rwp = (ruby_whisper_params *)p; // How to free user_data and callback only when not referred to by others? ruby_whisper_params_free(rwp); - free(rwp); + xfree(rwp); } static size_t @@ -276,6 +324,15 @@ ruby_whisper_params_allocate(VALUE klass) ruby_whisper_params *rwp; VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_params, &ruby_whisper_params_type, rwp); rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + if (rwp->params.language != NULL) { + rwp->params.language = ruby_strdup(rwp->params.language); + } + if (rwp->params.initial_prompt != NULL) { + rwp->params.initial_prompt = ruby_strdup(rwp->params.initial_prompt); + } + if (rwp->params.vad_model_path != NULL) { + rwp->params.vad_model_path = ruby_strdup(rwp->params.vad_model_path); + } rwp->diarize = false; rwp->vad_params = TypedData_Wrap_Struct(cVADParams, &ruby_whisper_vad_params_type, (void *)&rwp->params.vad_params); rwp->new_segment_callback_container = rb_whisper_callback_container_allocate(); @@ -296,10 +353,12 @@ ruby_whisper_params_set_language(VALUE self, VALUE value) { ruby_whisper_params *rwp; TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); + ruby_xfree((void *)rwp->params.language); + rwp->params.language = NULL; if (value == Qfalse || value == Qnil) { - rwp->params.language = "auto"; + rwp->params.language = ruby_strdup("auto"); } else { - rwp->params.language = StringValueCStr(value); + rwp->params.language = ruby_strdup(StringValueCStr(value)); } return value; } @@ -608,7 +667,13 @@ ruby_whisper_params_set_initial_prompt(VALUE self, VALUE value) { ruby_whisper_params *rwp; TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); - rwp->params.initial_prompt = StringValueCStr(value); + ruby_xfree((void *)rwp->params.initial_prompt); + rwp->params.initial_prompt = NULL; + if (NIL_P(value)) { + rwp->params.initial_prompt = NULL; + } else { + rwp->params.initial_prompt = ruby_strdup(StringValueCStr(value)); + } return value; } /* @@ -1103,12 +1168,14 @@ ruby_whisper_params_set_vad_model_path(VALUE self, VALUE value) { ruby_whisper_params *rwp; TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); + ruby_xfree((void *)rwp->params.vad_model_path); + rwp->params.vad_model_path = NULL; if (NIL_P(value)) { rwp->params.vad_model_path = NULL; return value; } VALUE path = ruby_whisper_normalize_model_path(value); - rwp->params.vad_model_path = StringValueCStr(path); + rwp->params.vad_model_path = ruby_strdup(StringValueCStr(path)); return value; } diff --git a/bindings/ruby/ext/ruby_whisper_transcribe.cpp b/bindings/ruby/ext/ruby_whisper_transcribe.cpp index c00fbcd1def..3d00566009a 100644 --- a/bindings/ruby/ext/ruby_whisper_transcribe.cpp +++ b/bindings/ruby/ext/ruby_whisper_transcribe.cpp @@ -16,7 +16,7 @@ extern ID id_to_path; extern ID transcribe_option_names[1]; extern void -prepare_transcription(ruby_whisper_params * rwp, VALUE * self); +prepare_transcription(ruby_whisper_params * rwp, VALUE * self, int n_processors); /* * transcribe a single file @@ -73,7 +73,7 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { // rwp->params.encoder_begin_callback_user_data = &is_aborted; // } - prepare_transcription(rwp, &self); + prepare_transcription(rwp, &self, n_processors); if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), n_processors) != 0) { fprintf(stderr, "failed to process audio\n"); diff --git a/bindings/ruby/sig/whisper.rbs b/bindings/ruby/sig/whisper.rbs index 9ade451c6b2..3c59661975b 100644 --- a/bindings/ruby/sig/whisper.rbs +++ b/bindings/ruby/sig/whisper.rbs @@ -37,7 +37,7 @@ module Whisper def self.lang_id: (string name) -> Integer def self.lang_str: (Integer id) -> String def self.lang_str_full: (Integer id) -> String - def self.log_set: (log_callback, Object? user_data) -> log_callback + def self.log_set: (log_callback?, Object? user_data) -> log_callback def self.system_info_str: () -> String class Context @@ -52,6 +52,9 @@ module Whisper # puts text # end # + # If n_processors is greater than 1, you cannot set any callbacks including + # new_segment_callback, progress_callback, encoder_begin_callback, abort_callback, + # and log_callback set by Whisper.log_set def transcribe: (path, Params, ?n_processors: Integer) -> self | (path, Params, ?n_processors: Integer) { (String) -> void } -> self @@ -129,6 +132,9 @@ module Whisper # It seems this approach can offer some speedup in some cases. # However, the transcription accuracy can be worse at the beginning and end of each chunk. # + # If n_processors is greater than 1, you cannot set any callbacks including + # new_segment_callback, progress_callback, encoder_begin_callback, abort_callback, + # and log_callback set by Whisper.log_set def full_parallel: (Params, Array[Float], ?Integer n_samples) -> self | (Params, _Samples, ?Integer n_samples) -> self | (Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self diff --git a/bindings/ruby/test/test_params.rb b/bindings/ruby/test/test_params.rb index 094dba6f48e..ff5c28e9043 100644 --- a/bindings/ruby/test/test_params.rb +++ b/bindings/ruby/test/test_params.rb @@ -46,6 +46,8 @@ def setup def test_language @params.language = "en" assert_equal @params.language, "en" + GC.compact + assert_equal @params.language, "en" @params.language = "auto" assert_equal @params.language, "auto" end diff --git a/bindings/ruby/test/test_whisper.rb b/bindings/ruby/test/test_whisper.rb index 29071210072..f7e25239d5d 100644 --- a/bindings/ruby/test/test_whisper.rb +++ b/bindings/ruby/test/test_whisper.rb @@ -43,9 +43,20 @@ def test_transcribe_n_processors @whisper = Whisper::Context.new("base.en") params = Whisper::Params.new - @whisper.transcribe(AUDIO, params, n_processors: 4) {|text| - assert_match(/what you can do for your country/i, text) - } + without_log_callback do + @whisper.transcribe(AUDIO, params, n_processors: 4) {|text| + assert_match(/what you can do for your country/i, text) + } + end + end + + private + + def without_log_callback + Whisper.log_set nil, nil + yield + ensure + Whisper.log_set ->(level, buffer, user_data) {}, nil end sub_test_case "After transcription" do @@ -229,7 +240,9 @@ def test_full_with_memroy_view_gc def test_full_parallel nprocessors = 2 - @whisper.full_parallel(@params, @samples, @samples.length, nprocessors) + without_log_callback do + @whisper.full_parallel(@params, @samples, @samples.length, nprocessors) + end assert_equal nprocessors, @whisper.full_n_segments text = @whisper.each_segment.collect(&:text).join @@ -240,7 +253,9 @@ def test_full_parallel def test_full_parallel_with_memory_view nprocessors = 2 samples = JFKReader.new(AUDIO) - @whisper.full_parallel(@params, samples, nil, nprocessors) + without_log_callback do + @whisper.full_parallel(@params, samples, nil, nprocessors) + end assert_equal nprocessors, @whisper.full_n_segments text = @whisper.each_segment.collect(&:text).join @@ -259,7 +274,9 @@ def test_full_parallel_without_length_and_n_processors def test_full_parallel_without_length nprocessors = 2 - @whisper.full_parallel(@params, @samples, nil, nprocessors) + without_log_callback do + @whisper.full_parallel(@params, @samples, nil, nprocessors) + end assert_equal nprocessors, @whisper.full_n_segments text = @whisper.each_segment.collect(&:text).join diff --git a/bindings/ruby/whispercpp.gemspec b/bindings/ruby/whispercpp.gemspec index 88b94e7eb8a..2d952222f29 100644 --- a/bindings/ruby/whispercpp.gemspec +++ b/bindings/ruby/whispercpp.gemspec @@ -3,7 +3,7 @@ require_relative "extsources" Gem::Specification.new do |s| s.name = "whispercpp" s.authors = ["Georgi Gerganov", "Todd A. Fisher"] - s.version = '1.3.6' + s.version = '1.3.7' s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby} s.email = 'todd.fisher@gmail.com' s.extra_rdoc_files = ['LICENSE', 'README.md'] diff --git a/examples/bench.wasm/emscripten.cpp b/examples/bench.wasm/emscripten.cpp index 083397db057..7e9f277f66e 100644 --- a/examples/bench.wasm/emscripten.cpp +++ b/examples/bench.wasm/emscripten.cpp @@ -45,7 +45,7 @@ void bench_main(size_t index) { fprintf(stderr, "\n"); fprintf(stderr, "If you wish, you can submit these results here:\n"); fprintf(stderr, "\n"); - fprintf(stderr, " https://github.com/ggerganov/whisper.cpp/issues/89\n"); + fprintf(stderr, " https://github.com/ggml-org/whisper.cpp/issues/89\n"); fprintf(stderr, "\n"); fprintf(stderr, "Please include the following information:\n"); fprintf(stderr, "\n"); diff --git a/examples/bench/bench.cpp b/examples/bench/bench.cpp index 049473d4f32..84915c56a8a 100644 --- a/examples/bench/bench.cpp +++ b/examples/bench/bench.cpp @@ -157,7 +157,7 @@ static int whisper_bench_full(const whisper_params & params) { fprintf(stderr, "\n"); fprintf(stderr, "If you wish, you can submit these results here:\n"); fprintf(stderr, "\n"); - fprintf(stderr, " https://github.com/ggerganov/whisper.cpp/issues/89\n"); + fprintf(stderr, " https://github.com/ggml-org/whisper.cpp/issues/89\n"); fprintf(stderr, "\n"); fprintf(stderr, "Please include the following information:\n"); fprintf(stderr, "\n"); diff --git a/examples/common-ggml.cpp b/examples/common-ggml.cpp index 6f02a2504c5..3f2eded86f7 100644 --- a/examples/common-ggml.cpp +++ b/examples/common-ggml.cpp @@ -74,6 +74,7 @@ bool ggml_common_quantize_0( case GGML_FTYPE_MOSTLY_BF16: case GGML_FTYPE_MOSTLY_MXFP4: case GGML_FTYPE_MOSTLY_NVFP4: + case GGML_FTYPE_MOSTLY_Q1_0: { fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype); return false; @@ -215,6 +216,7 @@ bool ggml_common_quantize_0( case GGML_TYPE_TQ2_0: case GGML_TYPE_MXFP4: case GGML_TYPE_NVFP4: + case GGML_TYPE_Q1_0: case GGML_TYPE_COUNT: { fprintf(stderr, "%s: unsupported quantization type %d (%s)\n", __func__, ttype, ggml_type_name((ggml_type) ttype)); diff --git a/examples/talk-llama/llama-adapter.cpp b/examples/talk-llama/llama-adapter.cpp index d6a5800e63a..4a1aaa955a8 100644 --- a/examples/talk-llama/llama-adapter.cpp +++ b/examples/talk-llama/llama-adapter.cpp @@ -294,7 +294,7 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_ } // get extra buffer types of the CPU - // TODO: a more general solution for non-CPU extra buft should be imlpemented in the future + // TODO: a more general solution for non-CPU extra buft should be implemented in the future // ref: https://github.com/ggml-org/llama.cpp/pull/12593#pullrequestreview-2718659948 std::vector buft_extra; { @@ -418,7 +418,7 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_ } llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * path_lora) { - llama_adapter_lora * adapter = new llama_adapter_lora(); + llama_adapter_lora * adapter = new llama_adapter_lora(model); try { llama_adapter_lora_init_impl(*model, path_lora, *adapter); @@ -471,8 +471,17 @@ int32_t llama_adapter_meta_val_str_by_index(const llama_adapter_lora * adapter, return snprintf(buf, buf_size, "%s", it->second.c_str()); } -void llama_adapter_lora_free(llama_adapter_lora *) { - // deprecated: adapters are freed by llama_model's destructor +void llama_adapter_lora_free(llama_adapter_lora * adapter) { + if (adapter == nullptr) { + return; + } + + if (adapter->model != nullptr) { + adapter->model->loras.erase(adapter); + adapter->model = nullptr; + } + + delete adapter; } uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter) { diff --git a/examples/talk-llama/llama-adapter.h b/examples/talk-llama/llama-adapter.h index aa3ab63ad75..f0b1e50f816 100644 --- a/examples/talk-llama/llama-adapter.h +++ b/examples/talk-llama/llama-adapter.h @@ -61,6 +61,8 @@ struct llama_adapter_lora_weight { }; struct llama_adapter_lora { + llama_model * model = nullptr; + // map tensor name to lora_a_b std::unordered_map ab_map; @@ -75,7 +77,7 @@ struct llama_adapter_lora { // activated lora (aLoRA) std::vector alora_invocation_tokens; - llama_adapter_lora() = default; + explicit llama_adapter_lora(llama_model * model) : model(model) {} ~llama_adapter_lora() = default; llama_adapter_lora_weight * get_weight(ggml_tensor * w); diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index 799d16167ba..633a66fc665 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -56,6 +56,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA2, "gemma2" }, { LLM_ARCH_GEMMA3, "gemma3" }, { LLM_ARCH_GEMMA3N, "gemma3n" }, + { LLM_ARCH_GEMMA4, "gemma4" }, { LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, @@ -73,6 +74,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_ARCTIC, "arctic" }, { LLM_ARCH_DEEPSEEK, "deepseek" }, { LLM_ARCH_DEEPSEEK2, "deepseek2" }, + { LLM_ARCH_DEEPSEEK2OCR, "deepseek2-ocr" }, { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_GLM4, "glm4" }, { LLM_ARCH_GLM4_MOE, "glm4moe" }, @@ -107,6 +109,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" }, { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" }, { LLM_ARCH_HUNYUAN_DENSE, "hunyuan-dense" }, + { LLM_ARCH_HUNYUAN_VL, "hunyuan_vl" }, { LLM_ARCH_SMOLLM3, "smollm3" }, { LLM_ARCH_OPENAI_MOE, "gpt-oss" }, { LLM_ARCH_LFM2, "lfm2" }, @@ -123,6 +126,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, { LLM_ARCH_MISTRAL3, "mistral3" }, + { LLM_ARCH_MISTRAL4, "mistral4" }, { LLM_ARCH_PADDLEOCR, "paddleocr" }, { LLM_ARCH_MIMO2, "mimo2" }, { LLM_ARCH_STEP35, "step35" }, @@ -163,6 +167,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, { LLM_KV_EMBEDDING_LENGTH_OUT, "%s.embedding_length_out" }, + { LLM_KV_EMBEDDING_LENGTH_PER_LAYER, "%s.embedding_length_per_layer_input" }, { LLM_KV_FEATURES_LENGTH, "%s.features_length" }, { LLM_KV_BLOCK_COUNT, "%s.block_count" }, { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" }, @@ -236,6 +241,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, "%s.attention.indexer.head_count" }, { LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" }, { LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" }, + { LLM_KV_ATTENTION_SHARED_KV_LAYERS, "%s.attention.shared_kv_layers" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_COUNT_SWA, "%s.rope.dimension_count_swa" }, @@ -245,6 +251,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, + { LLM_KV_ROPE_SCALING_ALPHA, "%s.rope.scaling.alpha" }, { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, @@ -362,6 +369,9 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, { LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" }, { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + { LLM_TENSOR_FFN_POST_NORM_1, "blk.%d.post_ffw_norm_1" }, + { LLM_TENSOR_FFN_POST_NORM_2, "blk.%d.post_ffw_norm_2" }, + { LLM_TENSOR_FFN_PRE_NORM_2, "blk.%d.pre_ffw_norm_2" }, { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, @@ -371,6 +381,7 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + { LLM_TENSOR_LAYER_OUT_SCALE, "blk.%d.layer_output_scale" }, { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, { LLM_TENSOR_POS_EMBD, "position_embd" }, { LLM_TENSOR_FFN_ACT, "blk.%d.ffn.act" }, @@ -538,2016 +549,6 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_INDEXER_ATTN_Q_B, "blk.%d.indexer.attn_q_b" }, }; -static std::set llm_get_tensor_names(llm_arch arch) { - switch (arch) { - case LLM_ARCH_CLIP: - return {}; - case LLM_ARCH_LLAMA: - case LLM_ARCH_DECI: - case LLM_ARCH_MISTRAL3: - case LLM_ARCH_LLAMA_EMBED: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXP, - LLM_TENSOR_FFN_DOWN_EXP, - LLM_TENSOR_FFN_UP_EXP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_ARCEE: - case LLM_ARCH_STARCODER2: - case LLM_ARCH_NEMOTRON: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_AFMOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_GATE, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_POST_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_LLAMA4: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXP, - LLM_TENSOR_FFN_DOWN_EXP, - LLM_TENSOR_FFN_UP_EXP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - }; - case LLM_ARCH_BAICHUAN: - case LLM_ARCH_ORION: - case LLM_ARCH_XVERSE: - case LLM_ARCH_EXAONE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_FALCON: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_NORM_2, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_GROK: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXP, - LLM_TENSOR_FFN_DOWN_EXP, - LLM_TENSOR_FFN_UP_EXP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_POST_NORM, - LLM_TENSOR_LAYER_OUT_NORM, - LLM_TENSOR_ATTN_OUT_NORM, - }; - case LLM_ARCH_GPT2: - case LLM_ARCH_STARCODER: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_POS_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - }; - case LLM_ARCH_GPTNEOX: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_MPT: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_ACT, - LLM_TENSOR_POS_EMBD, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - }; - case LLM_ARCH_REFACT: - case LLM_ARCH_QWEN2: - case LLM_ARCH_QWEN2VL: - case LLM_ARCH_INTERNLM2: - case LLM_ARCH_GRANITE: - case LLM_ARCH_ERNIE4_5: - case LLM_ARCH_PADDLEOCR: - case LLM_ARCH_SMOLLM3: - case LLM_ARCH_DREAM: - case LLM_ARCH_LLADA: - case LLM_ARCH_PANGU_EMBED: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_BERT: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_TOKEN_TYPES, - LLM_TENSOR_POS_EMBD, - LLM_TENSOR_ATTN_OUT_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_LAYER_OUT_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_CLS, - LLM_TENSOR_CLS_OUT, - }; - case LLM_ARCH_NOMIC_BERT: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_TOKEN_TYPES, - LLM_TENSOR_ATTN_OUT_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_LAYER_OUT_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_NOMIC_BERT_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_TOKEN_TYPES, - LLM_TENSOR_ATTN_OUT_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_LAYER_OUT_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_NEO_BERT: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_ENC_OUTPUT_NORM, - LLM_TENSOR_CLS, - LLM_TENSOR_CLS_OUT, - }; - case LLM_ARCH_EUROBERT: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - }; - case LLM_ARCH_MODERN_BERT: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_CLS, - LLM_TENSOR_CLS_OUT, - LLM_TENSOR_CLS_NORM, - }; - case LLM_ARCH_JINA_BERT_V2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_TOKEN_TYPES, - LLM_TENSOR_ATTN_NORM_2, - LLM_TENSOR_ATTN_OUT_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_LAYER_OUT_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_CLS, - }; - case LLM_ARCH_JINA_BERT_V3: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_TOKEN_TYPES, - LLM_TENSOR_ATTN_OUT_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_LAYER_OUT_NORM, - }; - case LLM_ARCH_BLOOM: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - }; - case LLM_ARCH_STABLELM: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - }; - case LLM_ARCH_QWEN: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_QWEN2MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - }; - case LLM_ARCH_QWEN3: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_CLS_OUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_QWEN3MOE: - case LLM_ARCH_QWEN3VLMOE: - case LLM_ARCH_OLMOE: - case LLM_ARCH_LLADA_MOE: - case LLM_ARCH_RND1: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_QWEN3NEXT: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_GATE, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_SSM_A_NOSCAN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_BETA_ALPHA, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - }; - case LLM_ARCH_QWEN35: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_GATE, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_SSM_A_NOSCAN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_BETA, - LLM_TENSOR_SSM_ALPHA, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - }; - case LLM_ARCH_QWEN35MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_GATE, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_SSM_A_NOSCAN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_BETA, - LLM_TENSOR_SSM_ALPHA, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - }; - case LLM_ARCH_QWEN3VL: - case LLM_ARCH_CHAMELEON: - case LLM_ARCH_HUNYUAN_DENSE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_CLS_OUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_PHI2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_PHI3: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FACTORS_LONG, - LLM_TENSOR_ROPE_FACTORS_SHORT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_PHIMOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FACTORS_LONG, - LLM_TENSOR_ROPE_FACTORS_SHORT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_PLAMO: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_PLAMO2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_X, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_OUT, - LLM_TENSOR_SSM_DT_NORM, - LLM_TENSOR_SSM_B_NORM, - LLM_TENSOR_SSM_C_NORM, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_POST_NORM, - }; - case LLM_ARCH_PLAMO3: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_POST_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_CODESHELL: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_MINICPM: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ROPE_FACTORS_LONG, - LLM_TENSOR_ROPE_FACTORS_SHORT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXP, - LLM_TENSOR_FFN_DOWN_EXP, - LLM_TENSOR_FFN_UP_EXP, - }; - case LLM_ARCH_MINICPM3: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FACTORS_LONG, - LLM_TENSOR_ROPE_FACTORS_SHORT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q_A_NORM, - LLM_TENSOR_ATTN_KV_A_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_A, - LLM_TENSOR_ATTN_Q_B, - LLM_TENSOR_ATTN_KV_A_MQA, - LLM_TENSOR_ATTN_KV_B, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - }; - case LLM_ARCH_GEMMA: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_GEMMA2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_POST_NORM, - }; - case LLM_ARCH_GEMMA3: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_POST_NORM, - }; - case LLM_ARCH_GEMMA3N: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_POST_NORM, - LLM_TENSOR_PER_LAYER_TOKEN_EMBD, - LLM_TENSOR_PER_LAYER_MODEL_PROJ, - LLM_TENSOR_PER_LAYER_PROJ_NORM, - LLM_TENSOR_ALTUP_UNEMBD_PROJ, - LLM_TENSOR_ALTUP_PROJ, - LLM_TENSOR_PER_LAYER_INP_GATE, - LLM_TENSOR_PER_LAYER_PROJ, - LLM_TENSOR_PER_LAYER_POST_NORM, - LLM_TENSOR_ALTUP_CORRECT_COEF, - LLM_TENSOR_ALTUP_CORRECT_SCALE, - LLM_TENSOR_ALTUP_PREDICT_COEF, - LLM_TENSOR_ALTUP_ROUTER, - LLM_TENSOR_ALTUP_ROUTER_NORM, - LLM_TENSOR_LAUREL_L, - LLM_TENSOR_LAUREL_R, - LLM_TENSOR_LAUREL_POST_NORM, - }; - case LLM_ARCH_GEMMA_EMBEDDING: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_DENSE_2_OUT, - LLM_TENSOR_DENSE_3_OUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_POST_NORM, - }; - case LLM_ARCH_MAMBA: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_X, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_OUT, - }; - case LLM_ARCH_MAMBA2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - }; - case LLM_ARCH_JAMBA: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_X, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_DT_NORM, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_B_NORM, - LLM_TENSOR_SSM_C_NORM, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_OUT, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_FALCON_H1: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_COMMAND_R: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - }; - case LLM_ARCH_COHERE2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_DBRX: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_OUT_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_OLMO: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_OLMO2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_FFN_POST_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_OPENELM: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_ARCTIC: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_NORM_EXPS, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_DEEPSEEK: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - }; - case LLM_ARCH_DEEPSEEK2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q_A_NORM, - LLM_TENSOR_ATTN_KV_A_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_A, - LLM_TENSOR_ATTN_Q_B, - LLM_TENSOR_ATTN_KV_A_MQA, - LLM_TENSOR_ATTN_KV_B, - LLM_TENSOR_ATTN_K_B, - LLM_TENSOR_ATTN_V_B, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_PLM: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_KV_A_MQA, - LLM_TENSOR_ATTN_KV_A_NORM, - LLM_TENSOR_ATTN_KV_B, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_CHATGLM: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - }; - case LLM_ARCH_GLM4: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_POST_NORM, - LLM_TENSOR_NEXTN_EH_PROJ, - LLM_TENSOR_NEXTN_EMBED_TOKENS, - LLM_TENSOR_NEXTN_ENORM, - LLM_TENSOR_NEXTN_HNORM, - LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, - LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, - }; - case LLM_ARCH_GLM4_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_EXP_PROBS_B, - LLM_TENSOR_NEXTN_EH_PROJ, - LLM_TENSOR_NEXTN_EMBED_TOKENS, - LLM_TENSOR_NEXTN_ENORM, - LLM_TENSOR_NEXTN_HNORM, - LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, - LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, - }; - case LLM_ARCH_GLM_DSA: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q_A_NORM, - LLM_TENSOR_ATTN_KV_A_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_A, - LLM_TENSOR_ATTN_Q_B, - LLM_TENSOR_ATTN_KV_A_MQA, - LLM_TENSOR_ATTN_KV_B, - LLM_TENSOR_ATTN_K_B, - LLM_TENSOR_ATTN_V_B, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_EXP_PROBS_B, - LLM_TENSOR_INDEXER_K_NORM, - LLM_TENSOR_INDEXER_PROJ, - LLM_TENSOR_INDEXER_ATTN_K, - LLM_TENSOR_INDEXER_ATTN_Q_B, - LLM_TENSOR_NEXTN_EH_PROJ, - LLM_TENSOR_NEXTN_EMBED_TOKENS, - LLM_TENSOR_NEXTN_ENORM, - LLM_TENSOR_NEXTN_HNORM, - LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, - LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, - }; - case LLM_ARCH_BITNET: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_SUB_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_SUB_NORM, - }; - case LLM_ARCH_T5: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_DEC_OUTPUT_NORM, - LLM_TENSOR_DEC_ATTN_NORM, - LLM_TENSOR_DEC_ATTN_Q, - LLM_TENSOR_DEC_ATTN_K, - LLM_TENSOR_DEC_ATTN_V, - LLM_TENSOR_DEC_ATTN_OUT, - LLM_TENSOR_DEC_ATTN_REL_B, - LLM_TENSOR_DEC_CROSS_ATTN_NORM, - LLM_TENSOR_DEC_CROSS_ATTN_Q, - LLM_TENSOR_DEC_CROSS_ATTN_K, - LLM_TENSOR_DEC_CROSS_ATTN_V, - LLM_TENSOR_DEC_CROSS_ATTN_OUT, - LLM_TENSOR_DEC_CROSS_ATTN_REL_B, - LLM_TENSOR_DEC_FFN_NORM, - LLM_TENSOR_DEC_FFN_GATE, - LLM_TENSOR_DEC_FFN_DOWN, - LLM_TENSOR_DEC_FFN_UP, - LLM_TENSOR_ENC_OUTPUT_NORM, - LLM_TENSOR_ENC_ATTN_NORM, - LLM_TENSOR_ENC_ATTN_Q, - LLM_TENSOR_ENC_ATTN_K, - LLM_TENSOR_ENC_ATTN_V, - LLM_TENSOR_ENC_ATTN_OUT, - LLM_TENSOR_ENC_ATTN_REL_B, - LLM_TENSOR_ENC_FFN_NORM, - LLM_TENSOR_ENC_FFN_GATE, - LLM_TENSOR_ENC_FFN_DOWN, - LLM_TENSOR_ENC_FFN_UP, - }; - case LLM_ARCH_T5ENCODER: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ENC_OUTPUT_NORM, - LLM_TENSOR_ENC_ATTN_NORM, - LLM_TENSOR_ENC_ATTN_Q, - LLM_TENSOR_ENC_ATTN_K, - LLM_TENSOR_ENC_ATTN_V, - LLM_TENSOR_ENC_ATTN_OUT, - LLM_TENSOR_ENC_ATTN_REL_B, - LLM_TENSOR_ENC_FFN_NORM, - LLM_TENSOR_ENC_FFN_GATE, - LLM_TENSOR_ENC_FFN_DOWN, - LLM_TENSOR_ENC_FFN_UP, - }; - case LLM_ARCH_JAIS: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - }; - case LLM_ARCH_JAIS2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - }; - case LLM_ARCH_NEMOTRON_H: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_NEMOTRON_H_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - // mamba(2) ssm layers - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - // attention layers - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - // dense FFN - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - // MoE FFN (for MoE layers) - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_EXP_PROBS_B, - LLM_TENSOR_FFN_LATENT_DOWN, - LLM_TENSOR_FFN_LATENT_UP, - // MoE shared expert layer - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - }; - case LLM_ARCH_EXAONE4: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_POST_NORM, - }; - case LLM_ARCH_EXAONE_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_EXP_PROBS_B, - LLM_TENSOR_NEXTN_EH_PROJ, - LLM_TENSOR_NEXTN_EMBED_TOKENS, - LLM_TENSOR_NEXTN_ENORM, - LLM_TENSOR_NEXTN_HNORM, - LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, - LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, - }; - case LLM_ARCH_RWKV6: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_NORM_2, - LLM_TENSOR_TIME_MIX_W1, - LLM_TENSOR_TIME_MIX_W2, - LLM_TENSOR_TIME_MIX_LERP_X, - LLM_TENSOR_TIME_MIX_LERP_W, - LLM_TENSOR_TIME_MIX_LERP_K, - LLM_TENSOR_TIME_MIX_LERP_V, - LLM_TENSOR_TIME_MIX_LERP_R, - LLM_TENSOR_TIME_MIX_LERP_G, - LLM_TENSOR_TIME_MIX_LERP_FUSED, - LLM_TENSOR_TIME_MIX_FIRST, - LLM_TENSOR_TIME_MIX_DECAY, - LLM_TENSOR_TIME_MIX_DECAY_W1, - LLM_TENSOR_TIME_MIX_DECAY_W2, - LLM_TENSOR_TIME_MIX_KEY, - LLM_TENSOR_TIME_MIX_VALUE, - LLM_TENSOR_TIME_MIX_RECEPTANCE, - LLM_TENSOR_TIME_MIX_GATE, - LLM_TENSOR_TIME_MIX_LN, - LLM_TENSOR_TIME_MIX_OUTPUT, - LLM_TENSOR_CHANNEL_MIX_LERP_K, - LLM_TENSOR_CHANNEL_MIX_LERP_R, - LLM_TENSOR_CHANNEL_MIX_KEY, - LLM_TENSOR_CHANNEL_MIX_VALUE, - LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, - }; - case LLM_ARCH_RWKV6QWEN2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_TIME_MIX_W1, - LLM_TENSOR_TIME_MIX_W2, - LLM_TENSOR_TIME_MIX_LERP_X, - LLM_TENSOR_TIME_MIX_LERP_FUSED, - LLM_TENSOR_TIME_MIX_FIRST, - LLM_TENSOR_TIME_MIX_DECAY, - LLM_TENSOR_TIME_MIX_DECAY_W1, - LLM_TENSOR_TIME_MIX_DECAY_W2, - LLM_TENSOR_TIME_MIX_KEY, - LLM_TENSOR_TIME_MIX_VALUE, - LLM_TENSOR_TIME_MIX_RECEPTANCE, - LLM_TENSOR_TIME_MIX_GATE, - LLM_TENSOR_TIME_MIX_OUTPUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_RWKV7: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_NORM_2, - LLM_TENSOR_TIME_MIX_W0, - LLM_TENSOR_TIME_MIX_W1, - LLM_TENSOR_TIME_MIX_W2, - LLM_TENSOR_TIME_MIX_A0, - LLM_TENSOR_TIME_MIX_A1, - LLM_TENSOR_TIME_MIX_A2, - LLM_TENSOR_TIME_MIX_V0, - LLM_TENSOR_TIME_MIX_V1, - LLM_TENSOR_TIME_MIX_V2, - LLM_TENSOR_TIME_MIX_G1, - LLM_TENSOR_TIME_MIX_G2, - LLM_TENSOR_TIME_MIX_K_K, - LLM_TENSOR_TIME_MIX_K_A, - LLM_TENSOR_TIME_MIX_R_K, - LLM_TENSOR_TIME_MIX_LERP_FUSED, - LLM_TENSOR_TIME_MIX_KEY, - LLM_TENSOR_TIME_MIX_VALUE, - LLM_TENSOR_TIME_MIX_RECEPTANCE, - LLM_TENSOR_TIME_MIX_LN, - LLM_TENSOR_TIME_MIX_OUTPUT, - LLM_TENSOR_CHANNEL_MIX_LERP_K, - LLM_TENSOR_CHANNEL_MIX_KEY, - LLM_TENSOR_CHANNEL_MIX_VALUE, - }; - case LLM_ARCH_ARWKV7: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_TIME_MIX_W0, - LLM_TENSOR_TIME_MIX_W1, - LLM_TENSOR_TIME_MIX_W2, - LLM_TENSOR_TIME_MIX_A0, - LLM_TENSOR_TIME_MIX_A1, - LLM_TENSOR_TIME_MIX_A2, - LLM_TENSOR_TIME_MIX_V0, - LLM_TENSOR_TIME_MIX_V1, - LLM_TENSOR_TIME_MIX_V2, - LLM_TENSOR_TIME_MIX_G1, - LLM_TENSOR_TIME_MIX_G2, - LLM_TENSOR_TIME_MIX_K_K, - LLM_TENSOR_TIME_MIX_K_A, - LLM_TENSOR_TIME_MIX_R_K, - LLM_TENSOR_TIME_MIX_LERP_FUSED, - LLM_TENSOR_TIME_MIX_KEY, - LLM_TENSOR_TIME_MIX_VALUE, - LLM_TENSOR_TIME_MIX_RECEPTANCE, - LLM_TENSOR_TIME_MIX_LN, - LLM_TENSOR_TIME_MIX_OUTPUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_GRANITE_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - }; - case LLM_ARCH_GRANITE_HYBRID: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - }; - case LLM_ARCH_WAVTOKENIZER_DEC: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_CONV1D, - LLM_TENSOR_CONVNEXT_DW, - LLM_TENSOR_CONVNEXT_NORM, - LLM_TENSOR_CONVNEXT_PW1, - LLM_TENSOR_CONVNEXT_PW2, - LLM_TENSOR_CONVNEXT_GAMMA, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_POS_NET_CONV1, - LLM_TENSOR_POS_NET_CONV2, - LLM_TENSOR_POS_NET_NORM, - LLM_TENSOR_POS_NET_NORM1, - LLM_TENSOR_POS_NET_NORM2, - LLM_TENSOR_POS_NET_ATTN_NORM, - LLM_TENSOR_POS_NET_ATTN_Q, - LLM_TENSOR_POS_NET_ATTN_K, - LLM_TENSOR_POS_NET_ATTN_V, - LLM_TENSOR_POS_NET_ATTN_OUT, - }; - case LLM_ARCH_BAILINGMOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - }; - case LLM_ARCH_BAILINGMOE2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_EXP_PROBS_B, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_NEXTN_EH_PROJ, - LLM_TENSOR_NEXTN_EMBED_TOKENS, - LLM_TENSOR_NEXTN_ENORM, - LLM_TENSOR_NEXTN_HNORM, - LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, - LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, - LLM_TENSOR_LAYER_OUT_NORM, - }; - case LLM_ARCH_DOTS1: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_ERNIE4_5_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_HUNYUAN_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_OPENAI_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_SINKS, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_LFM2: - return { - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_SHORTCONV_CONV, - LLM_TENSOR_SHORTCONV_INPROJ, - LLM_TENSOR_SHORTCONV_OUTPROJ, - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM_LFM2, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_DENSE_2_OUT, - }; - case LLM_ARCH_LFM2MOE: - return { - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_SHORTCONV_CONV, - LLM_TENSOR_SHORTCONV_INPROJ, - LLM_TENSOR_SHORTCONV_OUTPROJ, - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM_LFM2, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_SMALLTHINKER: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_APERTUS: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_SEED_OSS: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_GROVEMOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_CHEXPS, - LLM_TENSOR_FFN_DOWN_CHEXPS, - LLM_TENSOR_FFN_UP_CHEXPS, - }; - case LLM_ARCH_MINIMAX_M2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_COGVLM: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_VISEXP_ATTN_QKV, - LLM_TENSOR_VISEXP_ATTN_OUT, - LLM_TENSOR_VISEXP_FFN_GATE, - LLM_TENSOR_VISEXP_FFN_DOWN, - LLM_TENSOR_VISEXP_FFN_UP, - }; - case LLM_ARCH_MIMO2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_SINKS, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_STEP35: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ROPE_FACTORS_LONG, - LLM_TENSOR_ROPE_FACTORS_SHORT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_GATE, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_GPTJ: - case LLM_ARCH_UNKNOWN: - return { - LLM_TENSOR_TOKEN_EMBD, - }; - case LLM_ARCH_MAINCODER: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_KIMI_LINEAR: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - // Dense FFN (layer 0 only) - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - // MoE FFN (layers 1+) - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_EXP_PROBS_B, - // Shared experts - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - // KDA (using SSM_ enum prefix, keeping GGUF names for backward compat) - LLM_TENSOR_SSM_CONV1D_Q, - LLM_TENSOR_SSM_CONV1D_K, - LLM_TENSOR_SSM_CONV1D_V, - LLM_TENSOR_SSM_F_A, - LLM_TENSOR_SSM_F_B, - LLM_TENSOR_SSM_BETA, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_G_A, - LLM_TENSOR_SSM_G_B, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_NORM, - // MLA - LLM_TENSOR_ATTN_Q_A, - LLM_TENSOR_ATTN_Q_B, - LLM_TENSOR_ATTN_Q_A_NORM, - LLM_TENSOR_ATTN_KV_A_MQA, - LLM_TENSOR_ATTN_KV_B, - LLM_TENSOR_ATTN_K_B, - LLM_TENSOR_ATTN_V_B, - LLM_TENSOR_ATTN_KV_A_NORM, - }; - default: - GGML_ABORT("unknown architecture for tensor mapping"); - } -} - // declare information about the model weight tensors: // - the layer in which the tensor is going to be used. this is needed in order to assign the correct buffer type for the weight // - the operator which is going to use the weight. this is needed to determine if the respective backend supports the operator @@ -2559,20 +560,20 @@ static std::set llm_get_tensor_names(llm_arch arch) { // example: https://github.com/ggml-org/llama.cpp/pull/17548 // static const std::map LLM_TENSOR_INFOS = { - {LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL}}, - {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output - {LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output - {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_OUTPUT_NORM_LFM2, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // do the norms on the first layer (not the input layer) + {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output + {LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output + {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_OUTPUT_NORM_LFM2, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_ROPE_FREQS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}}, {LLM_TENSOR_ROPE_FACTORS_LONG, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}}, {LLM_TENSOR_ROPE_FACTORS_SHORT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}}, @@ -2680,11 +681,15 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_ATTN_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_FFN_PRE_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_FFN_POST_NORM_1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_FFN_POST_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_NORM_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_Q_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_LAYER_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_LAYER_OUT_SCALE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_Q_A_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_KV_A_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_SUB_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -2705,9 +710,9 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_FFN_UP_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, // altup / laurel (gemma 3n) - {LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ALTUP_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ALTUP_UNEMBD_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_PER_LAYER_INP_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, @@ -2723,7 +728,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_LAUREL_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // this tensor is loaded for T5, but never used {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}}, - {LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}}, + {LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}}, {LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_POS_NET_NORM2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -2778,18 +783,13 @@ std::string LLM_KV::operator()(llm_kv kv) const { } LLM_TN_IMPL::LLM_TN_IMPL(llm_arch arch, llm_tensor tensor, const char * suffix, int bid, int xid) - : arch(arch), tensor(tensor), suffix(suffix), bid(bid), xid(xid), - model_tensors(llm_get_tensor_names(arch)) {} + : arch(arch), tensor(tensor), suffix(suffix), bid(bid), xid(xid) {} std::string LLM_TN_IMPL::str() const { if (LLM_TENSOR_NAMES.find(tensor) == LLM_TENSOR_NAMES.end()) { GGML_ABORT("unknown tensor name for tensor id %d", static_cast(tensor)); } - if (model_tensors.find(tensor) == model_tensors.end()) { - return LLM_TENSOR_NAMES.at(tensor); - } - std::string name = ::format(LLM_TENSOR_NAMES.at(tensor), bid, xid); if (suffix != nullptr) { name += "."; @@ -2875,3 +875,34 @@ bool llm_arch_is_diffusion(const llm_arch & arch) { return false; } } + +bool llm_arch_supports_sm_tensor(const llm_arch & arch) { + switch (arch) { + case LLM_ARCH_GROK: + case LLM_ARCH_MPT: + case LLM_ARCH_PLAMO2: + case LLM_ARCH_MINICPM3: + case LLM_ARCH_GEMMA3N: + case LLM_ARCH_MAMBA: + case LLM_ARCH_MAMBA2: + case LLM_ARCH_JAMBA: + case LLM_ARCH_FALCON_H1: + case LLM_ARCH_OLMO2: + case LLM_ARCH_OLMOE: + case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_GLM_DSA: + case LLM_ARCH_BITNET: + case LLM_ARCH_T5: + case LLM_ARCH_NEMOTRON_H: + case LLM_ARCH_NEMOTRON_H_MOE: + case LLM_ARCH_GRANITE_HYBRID: + case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: + case LLM_ARCH_MINIMAX_M2: + case LLM_ARCH_MISTRAL4: + case LLM_ARCH_KIMI_LINEAR: + return false; + default: + return true; + } +} diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index b1b1dcf1883..8f335f5c7b3 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -60,6 +60,7 @@ enum llm_arch { LLM_ARCH_GEMMA2, LLM_ARCH_GEMMA3, LLM_ARCH_GEMMA3N, + LLM_ARCH_GEMMA4, LLM_ARCH_GEMMA_EMBEDDING, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, @@ -77,6 +78,7 @@ enum llm_arch { LLM_ARCH_ARCTIC, LLM_ARCH_DEEPSEEK, LLM_ARCH_DEEPSEEK2, + LLM_ARCH_DEEPSEEK2OCR, LLM_ARCH_CHATGLM, LLM_ARCH_GLM4, LLM_ARCH_GLM4_MOE, @@ -111,6 +113,7 @@ enum llm_arch { LLM_ARCH_ERNIE4_5_MOE, LLM_ARCH_HUNYUAN_MOE, LLM_ARCH_HUNYUAN_DENSE, + LLM_ARCH_HUNYUAN_VL, LLM_ARCH_SMOLLM3, LLM_ARCH_OPENAI_MOE, LLM_ARCH_LFM2, @@ -127,6 +130,7 @@ enum llm_arch { LLM_ARCH_RND1, LLM_ARCH_PANGU_EMBED, LLM_ARCH_MISTRAL3, + LLM_ARCH_MISTRAL4, LLM_ARCH_PADDLEOCR, LLM_ARCH_MIMO2, LLM_ARCH_STEP35, @@ -167,6 +171,7 @@ enum llm_kv { LLM_KV_CONTEXT_LENGTH, LLM_KV_EMBEDDING_LENGTH, LLM_KV_EMBEDDING_LENGTH_OUT, + LLM_KV_EMBEDDING_LENGTH_PER_LAYER, LLM_KV_FEATURES_LENGTH, LLM_KV_BLOCK_COUNT, LLM_KV_LEADING_DENSE_BLOCK_COUNT, @@ -240,6 +245,7 @@ enum llm_kv { LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, LLM_KV_ATTENTION_INDEXER_TOP_K, + LLM_KV_ATTENTION_SHARED_KV_LAYERS, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_COUNT_SWA, @@ -249,6 +255,7 @@ enum llm_kv { LLM_KV_ROPE_SCALE_LINEAR, LLM_KV_ROPE_SCALING_TYPE, LLM_KV_ROPE_SCALING_FACTOR, + LLM_KV_ROPE_SCALING_ALPHA, LLM_KV_ROPE_SCALING_ATTN_FACTOR, LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, LLM_KV_ROPE_SCALING_FINETUNED, @@ -367,6 +374,9 @@ enum llm_tensor { LLM_TENSOR_FFN_GATE_INP_SHEXP, LLM_TENSOR_FFN_NORM, LLM_TENSOR_FFN_POST_NORM, + LLM_TENSOR_FFN_POST_NORM_1, + LLM_TENSOR_FFN_POST_NORM_2, + LLM_TENSOR_FFN_PRE_NORM_2, LLM_TENSOR_FFN_GATE, LLM_TENSOR_FFN_DOWN, LLM_TENSOR_FFN_UP, @@ -391,6 +401,7 @@ enum llm_tensor { LLM_TENSOR_ATTN_Q_NORM, LLM_TENSOR_ATTN_K_NORM, LLM_TENSOR_LAYER_OUT_NORM, + LLM_TENSOR_LAYER_OUT_SCALE, LLM_TENSOR_POST_ATTN_NORM, LLM_TENSOR_POST_MLP_NORM, LLM_TENSOR_PER_LAYER_TOKEN_EMBD, // gemma3n @@ -576,8 +587,6 @@ struct LLM_TN_IMPL { const int bid; const int xid; - const std::set model_tensors; - LLM_TN_IMPL(llm_arch arch, llm_tensor tensor, const char * suffix, int bid, int xid); std::string str() const; @@ -623,6 +632,7 @@ llm_arch llm_arch_from_string(const std::string & name); const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor); -bool llm_arch_is_recurrent(const llm_arch & arch); -bool llm_arch_is_hybrid (const llm_arch & arch); -bool llm_arch_is_diffusion(const llm_arch & arch); +bool llm_arch_is_recurrent (const llm_arch & arch); +bool llm_arch_is_hybrid (const llm_arch & arch); +bool llm_arch_is_diffusion (const llm_arch & arch); +bool llm_arch_supports_sm_tensor(const llm_arch & arch); diff --git a/examples/talk-llama/llama-batch.h b/examples/talk-llama/llama-batch.h index 8e6fac0efab..f77520e86c3 100644 --- a/examples/talk-llama/llama-batch.h +++ b/examples/talk-llama/llama-batch.h @@ -18,7 +18,7 @@ struct llama_ubatch { } // typical for M-RoPE cases: - // 0 - sequantial position of the tokens/embeddings in the sequence + // 0 - sequential position of the tokens/embeddings in the sequence // 1 - y position in the image // 2 - x position in the image // 3 - other diff --git a/examples/talk-llama/llama-chat.cpp b/examples/talk-llama/llama-chat.cpp index c415a998f33..6554a89b28a 100644 --- a/examples/talk-llama/llama-chat.cpp +++ b/examples/talk-llama/llama-chat.cpp @@ -49,6 +49,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "deepseek", LLM_CHAT_TEMPLATE_DEEPSEEK }, { "deepseek2", LLM_CHAT_TEMPLATE_DEEPSEEK_2 }, { "deepseek3", LLM_CHAT_TEMPLATE_DEEPSEEK_3 }, + { "deepseek-ocr", LLM_CHAT_TEMPLATE_DEEPSEEK_OCR }, { "command-r", LLM_CHAT_TEMPLATE_COMMAND_R }, { "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 }, { "chatglm3", LLM_CHAT_TEMPLATE_CHATGLM_3 }, @@ -59,7 +60,8 @@ static const std::map LLM_CHAT_TEMPLATES = { { "exaone4", LLM_CHAT_TEMPLATE_EXAONE_4 }, { "exaone-moe", LLM_CHAT_TEMPLATE_EXAONE_MOE }, { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD }, - { "granite", LLM_CHAT_TEMPLATE_GRANITE }, + { "granite", LLM_CHAT_TEMPLATE_GRANITE_3_X }, + { "granite-4.0", LLM_CHAT_TEMPLATE_GRANITE_4_0 }, { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT }, { "megrez", LLM_CHAT_TEMPLATE_MEGREZ }, { "yandex", LLM_CHAT_TEMPLATE_YANDEX }, @@ -71,6 +73,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE }, { "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE }, { "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE }, + { "hunyuan-ocr", LLM_CHAT_TEMPLATE_HUNYUAN_OCR }, { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 }, { "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS }, { "grok-2", LLM_CHAT_TEMPLATE_GROK_2 }, @@ -190,7 +193,10 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { } else if (tmpl_contains("rwkv-world") || tmpl_contains("{{- 'User: ' + message['content']|trim + '\\n\\n' -}}")) { return LLM_CHAT_TEMPLATE_RWKV_WORLD; } else if (tmpl_contains("<|start_of_role|>")) { - return LLM_CHAT_TEMPLATE_GRANITE; + if (tmpl_contains("") || tmpl_contains("")) { + return LLM_CHAT_TEMPLATE_GRANITE_4_0; + } + return LLM_CHAT_TEMPLATE_GRANITE_3_X; } else if (tmpl_contains("message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1]")) { return LLM_CHAT_TEMPLATE_GIGACHAT; } else if (tmpl_contains("<|role_start|>")) { @@ -211,6 +217,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_HUNYUAN_MOE; } else if (tmpl_contains("<|start|>") && tmpl_contains("<|channel|>")) { return LLM_CHAT_TEMPLATE_OPENAI_MOE; + } else if (tmpl_contains("<|hy_Assistant|>") && tmpl_contains("<|hy_begin▁of▁sentence|>")) { + return LLM_CHAT_TEMPLATE_HUNYUAN_OCR; } else if (tmpl_contains("<|hy_Assistant|>") && tmpl_contains("<|hy_place▁holder▁no▁3|>")) { return LLM_CHAT_TEMPLATE_HUNYUAN_DENSE; } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) { @@ -548,6 +556,11 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << LU8("<|Assistant|>"); } + } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK_OCR) { + for (auto message : chat) { + // no template + ss << message->content; + } } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_3) { // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb // EXAONE-3.0-7.8B-Instruct @@ -611,8 +624,8 @@ int32_t llm_chat_apply_template( ss << "Assistant: " << trim(chat[i]->content) << "\n\n"; } } - } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) { - // IBM Granite template + } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE_3_X) { + // IBM Granite 3.x template for (const auto & message : chat) { std::string role(message->role); ss << "<|start_of_role|>" << role << "<|end_of_role|>"; @@ -624,6 +637,20 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "<|start_of_role|>assistant<|end_of_role|>"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE_4_0) { + // IBM Granite 4.0 template + for (const auto & message : chat) { + std::string role(message->role); + if (role == "assistant_tool_call") { + ss << "<|start_of_role|>assistant<|end_of_role|><|tool_call|>"; + } else { + ss << "<|start_of_role|>" << role << "<|end_of_role|>"; + } + ss << message->content << "<|end_of_text|>\n"; + } + if (add_ass) { + ss << "<|start_of_role|>assistant<|end_of_role|>"; + } } else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) { // GigaChat template bool has_system = !chat.empty() && std::string(chat[0]->role) == "system"; @@ -798,6 +825,22 @@ int32_t llm_chat_apply_template( ss << "<|hy_User|>" << chat[i]->content << "<|hy_Assistant|>"; } } + } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_OCR) { + // tencent/HunyuanOCR + ss << "<|hy_begin▁of▁sentence|>"; + for (size_t i = 0; i < chat.size(); i++) { + std::string role(chat[i]->role); + if (i == 0 && role == "system") { + ss << chat[i]->content << "<|hy_place▁holder▁no▁3|>"; + continue; + } + + if (role == "user") { + ss << chat[i]->content << "<|hy_User|>"; + } else if (role == "assistant") { + ss << chat[i]->content << "<|hy_Assistant|>"; + } + } } else if (tmpl == LLM_CHAT_TEMPLATE_KIMI_K2) { // moonshotai/Kimi-K2-Instruct for (auto message : chat) { diff --git a/examples/talk-llama/llama-chat.h b/examples/talk-llama/llama-chat.h index 9ed1db128ec..13f936a946c 100644 --- a/examples/talk-llama/llama-chat.h +++ b/examples/talk-llama/llama-chat.h @@ -28,6 +28,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_DEEPSEEK, LLM_CHAT_TEMPLATE_DEEPSEEK_2, LLM_CHAT_TEMPLATE_DEEPSEEK_3, + LLM_CHAT_TEMPLATE_DEEPSEEK_OCR, LLM_CHAT_TEMPLATE_COMMAND_R, LLM_CHAT_TEMPLATE_LLAMA_3, LLM_CHAT_TEMPLATE_CHATGLM_3, @@ -38,7 +39,8 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_EXAONE_4, LLM_CHAT_TEMPLATE_EXAONE_MOE, LLM_CHAT_TEMPLATE_RWKV_WORLD, - LLM_CHAT_TEMPLATE_GRANITE, + LLM_CHAT_TEMPLATE_GRANITE_3_X, + LLM_CHAT_TEMPLATE_GRANITE_4_0, LLM_CHAT_TEMPLATE_GIGACHAT, LLM_CHAT_TEMPLATE_MEGREZ, LLM_CHAT_TEMPLATE_YANDEX, @@ -51,6 +53,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_HUNYUAN_MOE, LLM_CHAT_TEMPLATE_OPENAI_MOE, LLM_CHAT_TEMPLATE_HUNYUAN_DENSE, + LLM_CHAT_TEMPLATE_HUNYUAN_OCR, LLM_CHAT_TEMPLATE_KIMI_K2, LLM_CHAT_TEMPLATE_SEED_OSS, LLM_CHAT_TEMPLATE_GROK_2, diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index 1f7a52d7895..8126249e143 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -1,5 +1,6 @@ #include "llama-context.h" +#include "ggml.h" #include "llama-arch.h" #include "llama-impl.h" #include "llama-batch.h" @@ -8,6 +9,7 @@ #include "llama-mmap.h" #include "llama-model.h" #include "llama-ext.h" +#include "llama.h" #include #include @@ -217,10 +219,10 @@ llama_context::llama_context( if (!hparams.vocab_only) { // GPU backends - for (auto * dev : model.devices) { - ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); + for (const auto & dev : model.devices) { + ggml_backend_t backend = ggml_backend_dev_init(dev.dev, nullptr); if (backend == nullptr) { - throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev))); + throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev.dev))); } backends.emplace_back(backend); } @@ -295,8 +297,8 @@ llama_context::llama_context( if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model.devices.empty()) { // use the host buffer of the first device CPU for faster transfer of the intermediate state - auto * dev = model.devices[0]; - auto * host_buft = ggml_backend_dev_host_buffer_type(dev); + const auto & dev = model.devices[0]; + auto * host_buft = ggml_backend_dev_host_buffer_type(dev.dev); if (host_buft) { buft = host_buft; } @@ -342,14 +344,6 @@ llama_context::llama_context( if (cparams.pipeline_parallel) { LLAMA_LOG_INFO("%s: pipeline parallelism enabled\n", __func__); - - if (!graph_reuse_disable) { - // TODO: figure out a way to make graph reuse work with pipeline parallelism - // ref: https://github.com/ggml-org/llama.cpp/pull/20463 - LLAMA_LOG_WARN("%s: graph reuse is currently not compatible with pipeline parallelism - disabling\n", __func__); - - graph_reuse_disable = true; - } } sched_reserve(); @@ -594,7 +588,7 @@ void llama_context::sched_reserve() { // reserve again with pp graph to avoid ggml-alloc reallocations during inference { - // TODO: not sure if the following graph would be worster case for multi-stream KV caches: + // TODO: not sure if the following graph would be worst case for multi-stream KV caches: // // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get()); // @@ -1028,9 +1022,11 @@ void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void for (auto & backend : backends) { auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get())); - auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback"); - if (set_abort_callback_fn) { - set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data); + if (reg) { + auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback"); + if (set_abort_callback_fn) { + set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data); + } } } } @@ -1165,9 +1161,11 @@ bool llama_context::set_adapter_cvec( int32_t il_end) { LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end); - // TODO: should we reserve? + bool res = cvec->apply(model, data, len, n_embd, il_start, il_end); - return cvec->apply(model, data, len, n_embd, il_start, il_end); + sched_need_reserve = true; + + return res; } llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { @@ -1187,6 +1185,13 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll if (!graph_reuse_disable && res->can_reuse(gparams)) { //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__); + // with pipeline parallelism, the previous graph_compute_async may still be running + // on the GPU. we must synchronize before set_inputs to avoid overwriting input tensors + // that the previous compute is still reading. + if (cparams.pipeline_parallel) { + ggml_backend_sched_synchronize(sched.get()); + } + n_reused++; } else { res->reset(); @@ -1345,8 +1350,11 @@ int llama_context::encode(const llama_batch & batch_inp) { const llama_seq_id seq_id = ubatch.seq_id_unq[s]; const int32_t seq_idx = ubatch.seq_idx[seq_id]; - embd_seq_out[seq_id].resize(n_embd); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); + // use n_embd_out (not n_embd_inp) - the pooled embedding has the model's + // output dimension, which differs from input dimension for deepstack models (e.g. qwen3vl) + const uint32_t n_embd_out = hparams.n_embd_out(); + embd_seq_out[seq_id].resize(n_embd_out); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd_out*seq_idx)*sizeof(float), n_embd_out*sizeof(float)); } } break; case LLAMA_POOLING_TYPE_RANK: @@ -1767,12 +1775,16 @@ int llama_context::decode(const llama_batch & batch_inp) { // extract sequence embeddings (cleared before processing each batch) auto & embd_seq_out = embd_seq; + // use n_embd_out (not n_embd_inp) - the pooled embedding has the model's + // output dimension, which differs from input dimension for deepstack models (e.g. qwen3vl) + const uint32_t n_embd_out = hparams.n_embd_out(); + for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { const llama_seq_id seq_id = ubatch.seq_id_unq[s]; const int32_t seq_idx = ubatch.seq_idx[seq_id]; - embd_seq_out[seq_id].resize(n_embd); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); + embd_seq_out[seq_id].resize(n_embd_out); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd_out*seq_idx)*sizeof(float), n_embd_out*sizeof(float)); } } break; case LLAMA_POOLING_TYPE_RANK: @@ -1944,6 +1956,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0)); return 0; } + ggml_backend_buffer_clear(buf_output.get(), 0); } float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get()); @@ -2623,7 +2636,7 @@ void llama_context::perf_reset() { n_reused = 0; } -std::map llama_context::memory_breakdown() const { +llama_memory_breakdown llama_context::memory_breakdown() const { std::map ret; for (const auto & [buft, size] : model.memory_breakdown()) { ret[buft].model += size; @@ -2933,7 +2946,22 @@ llama_context * llama_init_from_model( params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; } - if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) { + if (model->split_mode() == LLAMA_SPLIT_MODE_TENSOR) { + if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) { + LLAMA_LOG_INFO("%s: enabling flash_attn since it is required for SPLIT_MODE_TENSOR\n", __func__); + params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED; + } + if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_ENABLED) { + LLAMA_LOG_ERROR("%s: SPLIT_MODE_TENSOR requires flash_attn to be enabled\n", __func__); + return nullptr; + } + if (ggml_is_quantized(params.type_k) || ggml_is_quantized(params.type_v)) { + LLAMA_LOG_ERROR("%s: simultaneous use of SPLIT_MODE_TENSOR and KV cache quantization not implemented\n", __func__); + return nullptr; + } + } + + if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_k)) { const uint32_t blck_size = ggml_blck_size(params.type_k); for (uint32_t il = 0; il < model->hparams.n_layer; ++il) { if (model->hparams.n_embd_head_k(il) % blck_size != 0) { @@ -2944,7 +2972,7 @@ llama_context * llama_init_from_model( } } - if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) { + if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_v)) { const uint32_t blck_size = ggml_blck_size(params.type_v); for (uint32_t il = 0; il < model->hparams.n_layer; ++il) { if (model->hparams.n_embd_head_v(il) % blck_size != 0) { @@ -3465,142 +3493,6 @@ void llama_perf_context_reset(llama_context * ctx) { ctx->perf_reset(); } -void llama_memory_breakdown_print(const struct llama_context * ctx) { - const std::vector & devices = ctx->get_model().devices; - - std::map memory_breakdown = ctx->memory_breakdown(); - - std::vector> table_data; - table_data.reserve(devices.size()); - const std::string template_header = "%s: | %s | %s %s %s %s %s %s %s |\n"; - const std::string template_gpu = "%s: | %s | %s = %s + (%s = %s + %s + %s) + %s |\n"; - const std::string template_other = "%s: | %s | %s %s %s = %s + %s + %s %s |\n"; - - table_data.push_back({template_header, "memory breakdown [MiB]", "total", "free", "self", "model", "context", "compute", "unaccounted"}); - - constexpr size_t MiB = 1024 * 1024; - const std::vector desc_prefixes_strip = {"NVIDIA ", "GeForce ", "Tesla ", "AMD ", "Radeon ", "Instinct "}; - - // track seen buffer types to avoid double counting: - std::set seen_buffer_types; - - // accumulative memory breakdown for each device and for host: - std::vector mb_dev(devices.size()); - llama_memory_breakdown_data mb_host; - - for (const auto & buft_mb : memory_breakdown) { - ggml_backend_buffer_type_t buft = buft_mb.first; - const llama_memory_breakdown_data & mb = buft_mb.second; - if (ggml_backend_buft_is_host(buft)) { - mb_host.model += mb.model; - mb_host.context += mb.context; - mb_host.compute += mb.compute; - seen_buffer_types.insert(buft); - continue; - } - ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); - if (dev) { - int i_dev = -1; - for (size_t i = 0; i < devices.size(); i++) { - if (devices[i] == dev) { - i_dev = i; - break; - } - } - if (i_dev != -1) { - mb_dev[i_dev].model += mb.model; - mb_dev[i_dev].context += mb.context; - mb_dev[i_dev].compute += mb.compute; - seen_buffer_types.insert(buft); - continue; - } - } - } - - // print memory breakdown for each device: - for (size_t i = 0; i < devices.size(); i++) { - ggml_backend_dev_t dev = devices[i]; - llama_memory_breakdown_data mb = mb_dev[i]; - - const std::string name = ggml_backend_dev_name(dev); - std::string desc = ggml_backend_dev_description(dev); - for (const std::string & prefix : desc_prefixes_strip) { - if (desc.length() >= prefix.length() && desc.substr(0, prefix.length()) == prefix) { - desc = desc.substr(prefix.length()); - } - } - - size_t free, total; - ggml_backend_dev_memory(dev, &free, &total); - - const size_t self = mb.model + mb.context + mb.compute; - const size_t unaccounted = total - self - free; - - table_data.push_back({ - template_gpu, - " - " + name + " (" + desc + ")", - std::to_string(total / MiB), - std::to_string(free / MiB), - std::to_string(self / MiB), - std::to_string(mb.model / MiB), - std::to_string(mb.context / MiB), - std::to_string(mb.compute / MiB), - std::to_string(unaccounted / MiB)}); - } - - // print memory breakdown for host: - { - const size_t self = mb_host.model + mb_host.context + mb_host.compute; - table_data.push_back({ - template_other, - " - Host", - "", // total - "", // free - std::to_string(self / MiB), - std::to_string(mb_host.model / MiB), - std::to_string(mb_host.context / MiB), - std::to_string(mb_host.compute / MiB), - ""}); // unaccounted - } - - // print memory breakdown for all remaining buffer types: - for (const auto & buft_mb : memory_breakdown) { - ggml_backend_buffer_type_t buft = buft_mb.first; - const llama_memory_breakdown_data & mb = buft_mb.second; - if (seen_buffer_types.count(buft) == 1) { - continue; - } - const std::string name = ggml_backend_buft_name(buft); - const size_t self = mb.model + mb.context + mb.compute; - table_data.push_back({ - template_other, - " - " + name, - "", // total - "", // free - std::to_string(self / MiB), - std::to_string(mb.model / MiB), - std::to_string(mb.context / MiB), - std::to_string(mb.compute / MiB), - ""}); // unaccounted - seen_buffer_types.insert(buft); - } - - for (size_t j = 1; j < table_data[0].size(); j++) { - size_t max_len = 0; - for (const auto & td : table_data) { - max_len = std::max(max_len, td[j].length()); - } - for (auto & td : table_data) { - td[j].insert(j == 1 ? td[j].length() : 0, max_len - td[j].length(), ' '); - } - } - for (const auto & td : table_data) { - LLAMA_LOG_INFO(td[0].c_str(), - __func__, td[1].c_str(), td[2].c_str(), td[3].c_str(), td[4].c_str(), td[5].c_str(), - td[6].c_str(), td[7].c_str(), td[8].c_str()); - } -} - // // training // @@ -3631,3 +3523,11 @@ void llama_opt_epoch( callback_train, callback_eval); } + +// +// ext +// + +llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx) { + return ctx->memory_breakdown(); +} diff --git a/examples/talk-llama/llama-context.h b/examples/talk-llama/llama-context.h index e0d0085c1c3..53c705eaffc 100644 --- a/examples/talk-llama/llama-context.h +++ b/examples/talk-llama/llama-context.h @@ -1,6 +1,7 @@ #pragma once #include "llama.h" +#include "llama-ext.h" #include "llama-cparams.h" #include "llama-graph.h" #include "llama-adapter.h" @@ -22,17 +23,6 @@ class llama_io_write_i; struct llama_memory_i; struct llama_memory_context_i; -// "memory" as in physical memory for a buffer type, in bytes -struct llama_memory_breakdown_data { - size_t model = 0; // memory allocated for the model - size_t context = 0; // memory allocated for the context - size_t compute = 0; // memory allocated for temporary compute buffers - - size_t total() const { - return model + context + compute; - } -}; - struct llama_context { // init scheduler and compute buffers, reserve worst-case graphs llama_context( @@ -172,7 +162,7 @@ struct llama_context { llama_perf_context_data perf_get_data() const; void perf_reset(); - std::map memory_breakdown() const; + llama_memory_breakdown memory_breakdown() const; // // training diff --git a/examples/talk-llama/llama-ext.h b/examples/talk-llama/llama-ext.h index 13ced783b42..8ce29d217cb 100644 --- a/examples/talk-llama/llama-ext.h +++ b/examples/talk-llama/llama-ext.h @@ -1,8 +1,12 @@ #pragma once -#include "llama-context.h" -#include "ggml.h" -#include "stdint.h" +// this is a staging header for new llama.cpp API +// breaking changes and C++ are allowed. everything here should be considered WIP + +#include "llama.h" + +#include +#include // Reserve a new compute graph. It is valid until the next call to llama_graph_reserve. LLAMA_API struct ggml_cgraph * llama_graph_reserve( @@ -10,3 +14,77 @@ LLAMA_API struct ggml_cgraph * llama_graph_reserve( uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs); + +// Get the default ggml_type for a given ftype. +LLAMA_API ggml_type llama_ftype_get_default_type(llama_ftype ftype); + +struct quantize_state_impl; + +LLAMA_API quantize_state_impl * llama_quant_init( + const llama_model * model, + const llama_model_quantize_params * params); + +LLAMA_API void llama_quant_free(quantize_state_impl * qs); + +// Descriptor for constructing a mock model for quantization testing. +struct llama_quant_model_desc { + const char * architecture; + uint32_t n_embd; + uint32_t n_ff; + uint32_t n_layer; + uint32_t n_head; + uint32_t n_head_kv; + uint32_t n_expert; + uint32_t n_embd_head_k; + uint32_t n_embd_head_v; +}; + +// Create a mock model from a metadata descriptor (for testing). +// The returned model must be freed with llama_model_free(). +LLAMA_API llama_model * llama_quant_model_from_metadata(const llama_quant_model_desc * desc); + +// Returns true if this tensor should be quantized (based on name, dims, params). +LLAMA_API bool llama_quant_tensor_allows_quantization( + const quantize_state_impl * qs, + const ggml_tensor * tensor); + +// Compute quantization type assignments for a list of tensors. +// All tensors should be quantizable (use llama_quant_tensor_allows_quantization to filter). +// result_types: caller-allocated array of n_tensors elements, filled with assigned types. +LLAMA_API void llama_quant_compute_types( + quantize_state_impl * qs, + llama_ftype ftype, + ggml_tensor ** tensors, + ggml_type * result_types, + size_t n_tensors); + +// +// device memory querying +// + +// "memory" as in physical memory for a buffer type, in bytes +struct llama_memory_breakdown_data { + size_t model = 0; // memory allocated for the model + size_t context = 0; // memory allocated for the context + size_t compute = 0; // memory allocated for temporary compute buffers + + size_t total() const { + return model + context + compute; + } +}; + +struct llama_device_memory_data { + int64_t total; + int64_t free; + llama_memory_breakdown_data mb; +}; + +// TODO: convert to C-style data structure +using llama_memory_breakdown = std::map; + +LLAMA_API int32_t llama_model_n_expert (const struct llama_model * model); +LLAMA_API int32_t llama_model_n_devices(const struct llama_model * model); + +LLAMA_API ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i); + +LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx); diff --git a/examples/talk-llama/llama-grammar.cpp b/examples/talk-llama/llama-grammar.cpp index aac0d41f2b4..badcbfd0fbb 100644 --- a/examples/talk-llama/llama-grammar.cpp +++ b/examples/talk-llama/llama-grammar.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #define MAX_REPETITION_THRESHOLD 2000 @@ -454,6 +455,7 @@ const char * llama_grammar_parser::parse_sequence( bool is_nested) { size_t last_sym_start = rule.size(); const char * pos = src; + uint64_t n_prev_rules = 1; // use UINT64_MAX as the empty value because we aligned to the proper uint64_t type so -1 can't be used // (though it's technically the same as -1 now) @@ -481,6 +483,18 @@ const char * llama_grammar_parser::parse_sequence( // S' ::= S | llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end()); + // Calculate the total number of rules that will be generated by this repetition + uint64_t total_rules = 1; // Start with 1 for the original rule + if (!no_max && max_times > 0) { + total_rules = max_times; + } else if (min_times > 0) { + total_rules = min_times; + } + + if (n_prev_rules * total_rules >= MAX_REPETITION_THRESHOLD) { + throw std::runtime_error("number of rules that are going to be repeated multiplied by the new repetition exceeds sane defaults, please reduce the number of repetitions or rule complexity"); + } + if (min_times == 0) { rule.resize(last_sym_start); } else { @@ -508,12 +522,15 @@ const char * llama_grammar_parser::parse_sequence( if (n_opt > 0) { rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); } + n_prev_rules *= total_rules; + GGML_ASSERT(n_prev_rules >= 1); }; while (*pos) { if (*pos == '"') { // literal string pos++; last_sym_start = rule.size(); + n_prev_rules = 1; while (*pos != '"') { if (!*pos) { throw std::runtime_error("unexpected end of input"); @@ -531,6 +548,7 @@ const char * llama_grammar_parser::parse_sequence( start_type = LLAMA_GRETYPE_CHAR_NOT; } last_sym_start = rule.size(); + n_prev_rules = 1; while (*pos != ']') { if (!*pos) { throw std::runtime_error("unexpected end of input"); @@ -561,6 +579,7 @@ const char * llama_grammar_parser::parse_sequence( auto token_pair = parse_token(vocab, pos); const char * token_end = token_pair.second; last_sym_start = rule.size(); + n_prev_rules = 1; rule.push_back({type, token_pair.first}); pos = parse_space(token_end, is_nested); } else if (is_word_char(*pos)) { // rule reference @@ -568,12 +587,15 @@ const char * llama_grammar_parser::parse_sequence( uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos); pos = parse_space(name_end, is_nested); last_sym_start = rule.size(); + n_prev_rules = 1; rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); } else if (*pos == '(') { // grouping // parse nested alternates into synthesized rule pos = parse_space(pos + 1, true); + uint32_t n_rules_before = symbol_ids.size(); uint32_t sub_rule_id = generate_symbol_id(rule_name); pos = parse_alternates(pos, rule_name, sub_rule_id, true); + n_prev_rules = std::max(1u, (uint32_t)symbol_ids.size() - n_rules_before); last_sym_start = rule.size(); // output reference to synthesized rule rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); @@ -583,6 +605,7 @@ const char * llama_grammar_parser::parse_sequence( pos = parse_space(pos + 1, is_nested); } else if (*pos == '.') { // any char last_sym_start = rule.size(); + n_prev_rules = 1; rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); pos = parse_space(pos + 1, is_nested); } else if (*pos == '*') { @@ -830,32 +853,54 @@ static bool llama_grammar_match_token( static void llama_grammar_advance_stack( const llama_grammar_rules & rules, const llama_grammar_stack & stack, - llama_grammar_stacks & new_stacks) { - if (stack.empty()) { - if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { - new_stacks.emplace_back(stack); + llama_grammar_stacks & new_stacks) { + std::vector todo; + todo.push_back(stack); + + auto stack_cmp = [](const llama_grammar_stack & a, const llama_grammar_stack & b) { + return std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end(), + [](const llama_grammar_element * pa, const llama_grammar_element * pb) { + return pa < pb; // Compare pointer addresses + } + ); + }; + + std::set seen(stack_cmp); + + while (!todo.empty()) { + llama_grammar_stack curr_stack = std::move(todo.back()); + todo.pop_back(); + + if (seen.find( curr_stack) != seen.end()) { + continue; } - return; - } + seen.insert(curr_stack); - const llama_grammar_element * pos = stack.back(); + if (curr_stack.empty()) { + if (std::find(new_stacks.begin(), new_stacks.end(), curr_stack) == new_stacks.end()) { + new_stacks.emplace_back(std::move(curr_stack)); + } + continue; + } - switch (pos->type) { + const llama_grammar_element * pos = curr_stack.back(); + + switch (pos->type) { case LLAMA_GRETYPE_RULE_REF: { const size_t rule_id = static_cast(pos->value); const llama_grammar_element * subpos = rules[rule_id].data(); do { // init new stack without the top (pos) - llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); + llama_grammar_stack next_stack(curr_stack.begin(), curr_stack.end() - 1); if (!llama_grammar_is_end_of_sequence(pos + 1)) { // if this rule ref is followed by another element, add that to stack - new_stack.push_back(pos + 1); + next_stack.push_back(pos + 1); } if (!llama_grammar_is_end_of_sequence(subpos)) { // if alternate is nonempty, add to stack - new_stack.push_back(subpos); + next_stack.push_back(subpos); } - llama_grammar_advance_stack(rules, new_stack, new_stacks); + todo.push_back(std::move(next_stack)); while (!llama_grammar_is_end_of_sequence(subpos)) { // scan to end of alternate def subpos++; @@ -874,9 +919,9 @@ static void llama_grammar_advance_stack( case LLAMA_GRETYPE_CHAR_ANY: case LLAMA_GRETYPE_TOKEN: case LLAMA_GRETYPE_TOKEN_NOT: - if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { + if (std::find(new_stacks.begin(), new_stacks.end(), curr_stack) == new_stacks.end()) { // only add the stack if it's not a duplicate of one we already have - new_stacks.emplace_back(stack); + new_stacks.emplace_back(std::move(curr_stack)); } break; default: @@ -884,6 +929,7 @@ static void llama_grammar_advance_stack( // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on // those GGML_ABORT("fatal error"); + } } } diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index 9a215bb77a0..2ff23f87cf4 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -1,6 +1,7 @@ #include "llama-graph.h" #include "llama-impl.h" +#include "llama-model.h" #include "llama-batch.h" #include "llama-cparams.h" @@ -19,7 +20,7 @@ // dedup helpers -static ggml_tensor * build_kq_mask( +static ggml_tensor * build_attn_inp_kq_mask( ggml_context * ctx, const llama_kv_cache_context * mctx, const llama_ubatch & ubatch, @@ -28,7 +29,11 @@ static ggml_tensor * build_kq_mask( const auto n_tokens = ubatch.n_tokens; const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; - return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + ggml_tensor * res = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + ggml_set_input(res); + ggml_set_name(res, "attn_inp_kq_mask"); + + return res; } static bool can_reuse_kq_mask( @@ -52,6 +57,21 @@ static bool can_reuse_kq_mask( // impl +static ggml_tensor * ggml_mul_mat_aux( + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * rot) { + const auto n = rot->ne[0]; + + ggml_tensor * res; + + res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n); + res = ggml_mul_mat (ctx, rot, res); + res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); + + return res; +} + void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { if (ubatch->token) { const int64_t n_tokens = ubatch->n_tokens; @@ -429,6 +449,14 @@ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) { mctx->set_input_v_idxs(self_v_idxs, ubatch); mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + + if (self_k_rot) { + mctx->set_input_k_rot(self_k_rot); + } + + if (self_v_rot) { + mctx->set_input_v_rot(self_v_rot); + } } bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) { @@ -476,6 +504,22 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + + if (self_k_rot) { + mctx->get_base()->set_input_k_rot(self_k_rot); + } + + if (self_v_rot) { + mctx->get_base()->set_input_v_rot(self_v_rot); + } + + if (self_k_rot_swa) { + mctx->get_swa()->set_input_k_rot(self_k_rot_swa); + } + + if (self_v_rot_swa) { + mctx->get_swa()->set_input_v_rot(self_v_rot_swa); + } } bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { @@ -532,6 +576,14 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); + if (inp_attn->self_k_rot) { + mctx->get_attn()->set_input_k_rot(inp_attn->self_k_rot); + } + + if (inp_attn->self_v_rot) { + mctx->get_attn()->set_input_v_rot(inp_attn->self_v_rot); + } + const int64_t n_rs = mctx->get_recr()->get_n_rs(); if (inp_rs->s_copy) { @@ -630,6 +682,22 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) { attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn); } + if (inp_attn->self_k_rot) { + attn_ctx->get_base()->set_input_k_rot(inp_attn->self_k_rot); + } + + if (inp_attn->self_v_rot) { + attn_ctx->get_base()->set_input_v_rot(inp_attn->self_v_rot); + } + + if (inp_attn->self_k_rot_swa) { + attn_ctx->get_swa()->set_input_k_rot(inp_attn->self_k_rot_swa); + } + + if (inp_attn->self_v_rot_swa) { + attn_ctx->get_swa()->set_input_v_rot(inp_attn->self_v_rot_swa); + } + const int64_t n_rs = mctx->get_recr()->get_n_rs(); if (inp_rs->s_copy) { @@ -992,6 +1060,84 @@ ggml_tensor * llm_graph_context::build_norm( return cur; } + +llm_graph_qkv llm_graph_context::build_qkv( + const llama_layer & layer, + ggml_tensor * cur, + int64_t n_embd_head, + int64_t n_head, + int64_t n_head_kv, + int il) const { + const int64_t n_embd_q = n_embd_head * n_head; + const int64_t n_embd_kv = n_embd_head * n_head_kv; + + ggml_tensor * Qcur, * Kcur, * Vcur; + + if (layer.wqkv) { + // fused QKV path + ggml_tensor * qkv = build_lora_mm(layer.wqkv, cur, layer.wqkv_s); + cb(qkv, "wqkv", il); + if (layer.wqkv_b) { + qkv = ggml_add(ctx0, qkv, layer.wqkv_b); + cb(qkv, "wqkv_b", il); + } + if (hparams.f_clamp_kqv > 0.0f) { + qkv = ggml_clamp(ctx0, qkv, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(qkv, "wqkv_clamped", il); + } + Qcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head, n_tokens, + ggml_row_size(qkv->type, n_embd_head), qkv->nb[1], 0); + Kcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head_kv, n_tokens, + ggml_row_size(qkv->type, n_embd_head), qkv->nb[1], + ggml_row_size(qkv->type, n_embd_q)); + Vcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head_kv, n_tokens, + ggml_row_size(qkv->type, n_embd_head), qkv->nb[1], + ggml_row_size(qkv->type, n_embd_q + n_embd_kv)); + } else { + // separate Q/K/V path + Qcur = build_lora_mm(layer.wq, cur, layer.wq_s); + cb(Qcur, "Qcur", il); + if (layer.wq_b) { + Qcur = ggml_add(ctx0, Qcur, layer.wq_b); + cb(Qcur, "Qcur", il); + } + if (hparams.f_clamp_kqv > 0.0f) { + Qcur = ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(Qcur, "Qcur_clamped", il); + } + Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + cb(Kcur, "Kcur", il); + if (layer.wk_b) { + Kcur = ggml_add(ctx0, Kcur, layer.wk_b); + cb(Kcur, "Kcur", il); + } + if (hparams.f_clamp_kqv > 0.0f) { + Kcur = ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(Kcur, "Kcur_clamped", il); + } + Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + cb(Vcur, "Vcur", il); + if (layer.wv_b) { + Vcur = ggml_add(ctx0, Vcur, layer.wv_b); + cb(Vcur, "Vcur", il); + } + if (hparams.f_clamp_kqv > 0.0f) { + Vcur = ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(Vcur, "Vcur_clamped", il); + } + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + return { Qcur, Kcur, Vcur }; +} + + ggml_tensor * llm_graph_context::build_ffn( ggml_tensor * cur, ggml_tensor * up, @@ -1516,9 +1662,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn( if (!weight_before_ffn) { experts = ggml_mul(ctx0, experts, weights); - cb(cur, "ffn_moe_weighted", il); + cb(experts, "ffn_moe_weighted", il); } + ggml_build_forward_expand(gf, experts); + ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr }; assert(n_expert_used > 0); @@ -1538,6 +1686,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn( for (uint32_t i = 1; i < hparams.n_expert_used; ++i) { moe_out = ggml_add(ctx0, moe_out, cur_experts[i]); + + ggml_build_forward_expand(gf, moe_out); } if (hparams.n_expert_used == 1) { @@ -1665,7 +1815,7 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const { ggml_tensor * llm_graph_context::build_inp_out_ids() const { // note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls, - // but this would make the graph topology depend on the number of output tokens, which can interere with + // but this would make the graph topology depend on the number of output tokens, which can interfere with // features that require constant topology such as pipeline parallelism // ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471 //if (n_outputs < n_tokens) { @@ -1940,6 +2090,7 @@ ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_no_cache * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, @@ -1973,7 +2124,7 @@ ggml_tensor * llm_graph_context::build_attn( cb(cur, "kqv_out", il); if (wo) { - cur = build_lora_mm(wo, cur); + cur = build_lora_mm(wo, cur, wo_s); } if (wo_b) { @@ -2002,13 +2153,13 @@ static std::unique_ptr build_attn_inp_kv_impl( inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams); - - ggml_set_input(inp->self_kq_mask); - + inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; } + inp->self_k_rot = mctx_cur->build_input_k_rot(ctx0); + inp->self_v_rot = mctx_cur->build_input_v_rot(ctx0); + return inp; } @@ -2024,6 +2175,7 @@ ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_kv * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, @@ -2034,6 +2186,15 @@ ggml_tensor * llm_graph_context::build_attn( int il) const { GGML_ASSERT(v_mla == nullptr); + if (inp->self_k_rot) { + q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot); + k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot); + } + + if (inp->self_v_rot) { + v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot); + } + // these nodes are added to the graph together so that they are not reordered // by doing so, the number of splits in the graph is reduced // expand k later to enable rope fusion which directly writes into k-v cache @@ -2061,11 +2222,20 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); cb(cur, "kqv_out", il); + if (inp->self_v_rot) { + cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot); + } + if (wo) { - cur = build_lora_mm(wo, cur); if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) { // GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators + cur = build_lora_mm(wo, cur); ggml_mul_mat_set_prec(cur, GGML_PREC_F32); + if (wo_s) { + cur = ggml_mul(ctx0, cur, wo_s); + } + } else { + cur = build_lora_mm(wo, cur, wo_s); } } @@ -2090,9 +2260,7 @@ static std::unique_ptr build_attn_inp_k_impl( inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); - inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams); - ggml_set_input(inp->self_kq_mask); - + inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; } @@ -2111,6 +2279,7 @@ ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_k * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, @@ -2145,10 +2314,15 @@ ggml_tensor * llm_graph_context::build_attn( cb(cur, "kqv_out", il); if (wo) { - cur = build_lora_mm(wo, cur); if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators + cur = build_lora_mm(wo, cur); ggml_mul_mat_set_prec(cur, GGML_PREC_F32); + if (wo_s) { + cur = ggml_mul(ctx0, cur, wo_s); + } + } else { + cur = build_lora_mm(wo, cur, wo_s); } } @@ -2163,6 +2337,7 @@ ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_kv_iswa * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, @@ -2171,6 +2346,23 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * v_mla, float kq_scale, int il) const { + const bool is_swa = hparams.is_swa(il); + + auto * k_rot = is_swa ? inp->self_k_rot_swa : inp->self_k_rot; + auto * v_rot = is_swa ? inp->self_v_rot_swa : inp->self_v_rot; + + if (k_rot) { + q_cur = ggml_mul_mat_aux(ctx0, q_cur, k_rot); + if (k_cur) { + k_cur = ggml_mul_mat_aux(ctx0, k_cur, k_rot); + } + } + if (v_rot) { + if (v_cur) { + v_cur = ggml_mul_mat_aux(ctx0, v_cur, v_rot); + } + } + // these nodes are added to the graph together so that they are not reordered // by doing so, the number of splits in the graph is reduced ggml_build_forward_expand(gf, q_cur); @@ -2185,8 +2377,6 @@ ggml_tensor * llm_graph_context::build_attn( const auto * mctx_iswa = inp->mctx; - const bool is_swa = hparams.is_swa(il); - const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base(); // optionally store to KV cache @@ -2211,8 +2401,12 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); cb(cur, "kqv_out", il); + if (v_rot) { + cur = ggml_mul_mat_aux(ctx0, cur, v_rot); + } + if (wo) { - cur = build_lora_mm(wo, cur); + cur = build_lora_mm(wo, cur, wo_s); } if (wo_b) { @@ -2243,6 +2437,7 @@ ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_cross * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, @@ -2267,7 +2462,7 @@ ggml_tensor * llm_graph_context::build_attn( cb(cur, "kqv_out", il); if (wo) { - cur = build_lora_mm(wo, cur); + cur = build_lora_mm(wo, cur, wo_s); } if (wo_b) { @@ -2293,12 +2488,8 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams); - ggml_set_input(inp->self_kq_mask); - ggml_set_name(inp->self_kq_mask, "self_kq_mask"); - + inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; - ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv"); } { @@ -2307,14 +2498,16 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams); - ggml_set_input(inp->self_kq_mask_swa); - ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa"); - + inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams); inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; - ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv"); } + inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0); + inp->self_v_rot = mctx_cur->get_base()->build_input_v_rot(ctx0); + + inp->self_k_rot_swa = mctx_cur->get_swa()->build_input_k_rot(ctx0); + inp->self_v_rot_swa = mctx_cur->get_swa()->build_input_v_rot(ctx0); + return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp)); } @@ -2348,7 +2541,7 @@ ggml_tensor * llm_graph_context::build_rs( ggml_build_forward_expand(gf, ggml_cpy(ctx0, states_extra, - ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s)))); + ggml_view_2d(ctx0, s, state_size, (n_rs - n_seqs), s->nb[1], (rs_head + n_seqs)*s->nb[1]))); return output_states; } @@ -2473,9 +2666,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch); inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch); - inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams); - ggml_set_input(inp_attn->self_kq_mask); - + inp_attn->self_kq_mask = build_attn_inp_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams); inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask; } @@ -2483,9 +2674,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch); inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch); - inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams); - ggml_set_input(inp_attn->self_kq_mask_swa); - + inp_attn->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams); inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa; } diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index 4855685ef71..5cb1756c6a9 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -17,6 +17,7 @@ struct ggml_context; struct ggml_tensor; struct llama_cparams; +struct llama_layer; struct llama_memory_context_i; @@ -308,6 +309,10 @@ class llm_graph_input_attn_kv : public llm_graph_input_i { ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + // note: assumes v_rot^2 == I + ggml_tensor * self_k_rot = nullptr; + ggml_tensor * self_v_rot = nullptr; + // note: these have to be copies because in order to be able to reuse a graph, its inputs // need to carry these parameters with them. otherwise, they can point to freed // llm_graph_params from a previous batch, causing stack-use-after-return @@ -384,6 +389,12 @@ class llm_graph_input_attn_kv_iswa : public llm_graph_input_i { ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_k_rot = nullptr; + ggml_tensor * self_v_rot = nullptr; + + ggml_tensor * self_k_rot_swa = nullptr; + ggml_tensor * self_v_rot_swa = nullptr; + const llama_hparams hparams; const llama_cparams cparams; @@ -697,6 +708,12 @@ using llm_graph_result_ptr = std::unique_ptr; // used in build_rs to properly order writes and avoid unnecessary copies using llm_graph_get_rows_fn = std::function; +struct llm_graph_qkv { + ggml_tensor * q; // [n_embd_head, n_head, n_tokens] + ggml_tensor * k; // [n_embd_head, n_head_kv, n_tokens] + ggml_tensor * v; // [n_embd_head, n_head_kv, n_tokens] +}; + struct llm_graph_context { const llm_arch arch; @@ -783,6 +800,17 @@ struct llm_graph_context { llm_norm_type type, int il) const; + + // compute Q, K, V projections with optional bias and reshape + // supports both fused wqkv and separate wq/wk/wv paths + llm_graph_qkv build_qkv( + const llama_layer & layer, + ggml_tensor * cur, + int64_t n_embd_head, + int64_t n_head, + int64_t n_head_kv, + int il) const; + ggml_tensor * build_ffn( ggml_tensor * cur, ggml_tensor * up, @@ -882,6 +910,7 @@ struct llm_graph_context { llm_graph_input_attn_no_cache * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] @@ -897,6 +926,7 @@ struct llm_graph_context { llm_graph_input_attn_kv * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] @@ -912,6 +942,7 @@ struct llm_graph_context { llm_graph_input_attn_k * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] @@ -928,6 +959,7 @@ struct llm_graph_context { llm_graph_input_attn_kv_iswa * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional @@ -943,6 +975,7 @@ struct llm_graph_context { llm_graph_input_attn_cross * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index 78c0bc27d4d..ac7f9ee8650 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -116,6 +116,7 @@ struct llama_hparams { float rope_freq_base_train_swa = 10000.0f; float rope_freq_scale_train; float rope_freq_scale_train_swa = 1.0f; + float rope_scaling_alpha = 0.0f; // NTK-aware alpha for XDRoPE uint32_t n_ctx_orig_yarn; float rope_yarn_log_mul = 0.0f; @@ -209,6 +210,9 @@ struct llama_hparams { // qwen3vl deepstack uint32_t n_deepstack_layers = 0; + // gemma4 per-layer embedding + uint32_t n_embd_per_layer = 0; + // needed by encoder-decoder models (e.g. T5, FLAN-T5) // ref: https://github.com/ggml-org/llama.cpp/pull/8141 llama_token dec_start_token_id = LLAMA_TOKEN_NULL; diff --git a/examples/talk-llama/llama-impl.cpp b/examples/talk-llama/llama-impl.cpp index 4c0188ee722..b3a94b946d2 100644 --- a/examples/talk-llama/llama-impl.cpp +++ b/examples/talk-llama/llama-impl.cpp @@ -128,7 +128,7 @@ static std::string gguf_data_to_str(enum gguf_type type, const void * data, int case GGUF_TYPE_INT64: return std::to_string(((const int64_t *)data)[i]); case GGUF_TYPE_FLOAT32: return std::to_string(((const float *)data)[i]); case GGUF_TYPE_FLOAT64: return std::to_string(((const double *)data)[i]); - case GGUF_TYPE_BOOL: return ((const bool *)data)[i] ? "true" : "false"; + case GGUF_TYPE_BOOL: return ((const int8_t *)data)[i] != 0 ? "true" : "false"; default: return format("unknown type %d", type); } } diff --git a/examples/talk-llama/llama-kv-cache.cpp b/examples/talk-llama/llama-kv-cache.cpp index 01166fac9ce..09102f549c8 100644 --- a/examples/talk-llama/llama-kv-cache.cpp +++ b/examples/talk-llama/llama-kv-cache.cpp @@ -13,6 +13,65 @@ #include #include +static bool ggml_is_power_of_2(int n) { + return (n & (n - 1)) == 0; +} + +// orthonormal Walsh-Hadamard rotation matrix +// note: res^2 == I +static void ggml_gen_hadamard(ggml_tensor * tensor) { + assert(tensor->type == GGML_TYPE_F32); + + const int n = tensor->ne[0]; + + assert(ggml_is_power_of_2(n)); + assert(tensor->ne[1] == n); + assert(tensor->ne[2] == 1); + assert(tensor->ne[3] == 1); + + std::vector data_f32; + + float * data = (float *) tensor->data; + + if (tensor->type != GGML_TYPE_F32) { + data_f32.resize(n*n); + data = data_f32.data(); + } + + data[0*n + 0] = 1.0 / sqrtf(n); + + for (int s = 1; s < n; s *= 2) { + for (int i = 0; i < s; i++) { + for (int j = 0; j < s; j++) { + const float val = data[i*n + j]; + + data[(i + s)*n + (j )] = val; + data[(i )*n + (j + s)] = val; + data[(i + s)*n + (j + s)] = -val; + } + } + } + + if (tensor->type != GGML_TYPE_F32) { + ggml_quantize_chunk(tensor->type, data, tensor->data, 0, 1, n*n, nullptr); + } +} + +static ggml_tensor * ggml_mul_mat_aux( + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * rot) { + const auto n = rot->ne[0]; + + ggml_tensor * res; + + res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n); + res = ggml_mul_mat (ctx, rot, res); + res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); + + return res; +} + // // llama_kv_cache // @@ -110,6 +169,18 @@ llama_kv_cache::llama_kv_cache( continue; } + if (n_embd_head_k_all == 0) { + n_embd_head_k_all = (int32_t) hparams.n_embd_head_k(il); + } else if (n_embd_head_k_all > 0 && n_embd_head_k_all != (int32_t) hparams.n_embd_head_k(il)) { + n_embd_head_k_all = -1; + } + + if (n_embd_head_v_all == 0) { + n_embd_head_v_all = (int32_t) hparams.n_embd_head_v(il); + } else if (n_embd_head_v_all > 0 && n_embd_head_v_all != (int32_t) hparams.n_embd_head_v(il)) { + n_embd_head_v_all = -1; + } + // [TAG_V_CACHE_VARIABLE] const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max(); @@ -209,6 +280,48 @@ llama_kv_cache::llama_kv_cache( ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); } + const char * LLAMA_ATTN_ROT_DISABLE = getenv("LLAMA_ATTN_ROT_DISABLE"); + const bool attn_rot_disable = LLAMA_ATTN_ROT_DISABLE ? atoi(LLAMA_ATTN_ROT_DISABLE) : false; + if (attn_rot_disable) { + LLAMA_LOG_WARN("%s: attention rotation force disabled (LLAMA_ATTN_ROT_DISABLE)\n", __func__); + } + + attn_rot_k = + !attn_rot_disable && + n_embd_head_k_all > 0 && + ggml_is_quantized(type_k) && + hparams.n_embd_head_k() % 64 == 0; + + attn_rot_v = + !attn_rot_disable && + n_embd_head_v_all > 0 && + ggml_is_quantized(type_v) && + hparams.n_embd_head_v() % 64 == 0; + + LLAMA_LOG_INFO("%s: attn_rot_k = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_k, n_embd_head_k_all); + LLAMA_LOG_INFO("%s: attn_rot_v = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_v, n_embd_head_v_all); + + // pre-compute the haramard matrices and keep them in host memory + // TODO: in the future, we can make copies in the backend buffers to avoid host -> device transfers + if (attn_rot_k || attn_rot_v) { + for (int64_t n = 64; n <= std::max(n_embd_head_k_all, n_embd_head_v_all); n *= 2) { + attn_rot_hadamard[n] = std::vector(n*n); + + ggml_init_params params = { + /* .mem_size = */ 1*ggml_tensor_overhead(), + /* .mem_buffer = */ nullptr, + /* .no_alloc = */ true, + }; + + ggml_context_ptr ctx { ggml_init(params) }; + + ggml_tensor * tmp = ggml_new_tensor_2d(ctx.get(), GGML_TYPE_F32, n, n); + tmp->data = attn_rot_hadamard[n].data(); + + ggml_gen_hadamard(tmp); + } + } + const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG"); debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0; } @@ -1004,6 +1117,14 @@ bool llama_kv_cache::get_has_shift() const { return result; } +ggml_type llama_kv_cache::type_k() const { + return layers[0].k->type; +} + +ggml_type llama_kv_cache::type_v() const { + return layers[0].v->type; +} + uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const { uint32_t result = 0; @@ -1189,6 +1310,47 @@ ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama return v_idxs; } +ggml_tensor * llama_kv_cache::build_input_k_rot(ggml_context * ctx) const { + ggml_tensor * res = nullptr; + + if (attn_rot_k) { + int nrot = 64; + + // TODO: investigate if using the smallest rotation matrix is beneficial also for K (similar as for V) + // ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4141323088 + do { + nrot *= 2; + } while (n_embd_head_k_all % nrot == 0); + nrot /= 2; + + res = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot); + ggml_set_input(res); + ggml_set_name(res, "attn_inp_k_rot"); + } + + return res; +} + +ggml_tensor * llama_kv_cache::build_input_v_rot(ggml_context * ctx) const { + ggml_tensor * res = nullptr; + + if (attn_rot_v) { + int nrot = 64; + // using smaller rotation matrices for V seems beneficial + // ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4146397570 + //do { + // nrot *= 2; + //} while (hparams.n_embd_head_v() % nrot == 0); + //nrot /= 2; + + res = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot); + ggml_set_input(res); + ggml_set_name(res, "attn_inp_v_rot"); + } + + return res; +} + void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const { const uint32_t n_tokens = ubatch->n_tokens; GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream()); @@ -1507,6 +1669,24 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch } } +void llama_kv_cache::set_input_k_rot(ggml_tensor * dst) const { + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + + const auto n_rot = dst->ne[0]; + GGML_ASSERT(attn_rot_hadamard.count(dst->ne[0])); + + memcpy(dst->data, attn_rot_hadamard.at(n_rot).data(), ggml_nbytes(dst)); +} + +void llama_kv_cache::set_input_v_rot(ggml_tensor * dst) const { + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + + const auto n_rot = dst->ne[0]; + GGML_ASSERT(attn_rot_hadamard.count(dst->ne[0])); + + memcpy(dst->data, attn_rot_hadamard.at(n_rot).data(), ggml_nbytes(dst)); +} + size_t llama_kv_cache::total_size() const { size_t size = 0; @@ -1542,6 +1722,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift( ggml_context * ctx, ggml_tensor * cur, ggml_tensor * shift, + ggml_tensor * rot, ggml_tensor * factors, float freq_base, float freq_scale, @@ -1561,17 +1742,22 @@ ggml_tensor * llama_kv_cache::build_rope_shift( // ref: https://github.com/ggml-org/llama.cpp/pull/13870 ? LLAMA_ROPE_TYPE_NEOX : hparams.rope_type; - ggml_tensor * tmp; if (ggml_is_quantized(cur->type)) { // dequantize to f32 -> RoPE -> quantize back tmp = ggml_cast(ctx, cur, GGML_TYPE_F32); + // rotate back + tmp = ggml_mul_mat_aux(ctx, tmp, rot); + tmp = ggml_rope_ext(ctx, tmp, shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); + // rotate fwd + tmp = ggml_mul_mat_aux(ctx, tmp, rot); + tmp = ggml_cpy(ctx, tmp, cur); } else { // we rotate only the first n_rot dimensions @@ -1592,6 +1778,9 @@ class llm_graph_input_k_shift : public llm_graph_input_i { ggml_tensor * k_shift; // I32 [kv_size*n_stream] + // note: assumes k_rot^2 == I + ggml_tensor * k_rot = nullptr; + const llama_kv_cache * kv_self; }; @@ -1601,6 +1790,10 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { if (k_shift) { kv_self->set_input_k_shift(k_shift); } + + if (k_rot) { + kv_self->set_input_k_rot(k_rot); + } } ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const { @@ -1612,6 +1805,8 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream); ggml_set_input(inp->k_shift); + inp->k_rot = build_input_k_rot(ctx); + const auto & cparams = lctx->get_cparams(); for (const auto & layer : layers) { @@ -1636,7 +1831,7 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co ggml_row_size(layer.k->type, n_embd_k_gqa), ggml_row_size(layer.k->type, n_embd_nope)); - ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, il); + ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, inp->k_rot, rope_factors, freq_base_l, freq_scale_l, il); ggml_build_forward_expand(gf, cur); } @@ -2240,6 +2435,14 @@ uint32_t llama_kv_cache_context::get_n_kv() const { return n_kv; } +ggml_type llama_kv_cache_context::type_k() const { + return kv->type_k(); +} + +ggml_type llama_kv_cache_context::type_v() const { + return kv->type_v(); +} + ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const { return kv->get_k(ctx, il, n_kv, sinfos[i_cur]); } @@ -2264,6 +2467,14 @@ ggml_tensor * llama_kv_cache_context::build_input_v_idxs(ggml_context * ctx, con return kv->build_input_v_idxs(ctx, ubatch); } +ggml_tensor * llama_kv_cache_context::build_input_k_rot(ggml_context * ctx) const { + return kv->build_input_k_rot(ctx); +} + +ggml_tensor * llama_kv_cache_context::build_input_v_rot(ggml_context * ctx) const { + return kv->build_input_v_rot(ctx); +} + void llama_kv_cache_context::set_input_k_shift(ggml_tensor * dst) const { kv->set_input_k_shift(dst); } @@ -2283,3 +2494,11 @@ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ub void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { kv->set_input_pos_bucket(dst, ubatch); } + +void llama_kv_cache_context::set_input_k_rot(ggml_tensor * dst) const { + kv->set_input_k_rot(dst); +} + +void llama_kv_cache_context::set_input_v_rot(ggml_tensor * dst) const { + kv->set_input_v_rot(dst); +} diff --git a/examples/talk-llama/llama-kv-cache.h b/examples/talk-llama/llama-kv-cache.h index 33c78c5f210..0b62dc7b232 100644 --- a/examples/talk-llama/llama-kv-cache.h +++ b/examples/talk-llama/llama-kv-cache.h @@ -152,6 +152,9 @@ class llama_kv_cache : public llama_memory_i { bool get_has_shift() const; + ggml_type type_k() const; + ggml_type type_v() const; + // // graph_build API // @@ -191,6 +194,9 @@ class llama_kv_cache : public llama_memory_i { ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; + ggml_tensor * build_input_k_rot(ggml_context * ctx) const; + ggml_tensor * build_input_v_rot(ggml_context * ctx) const; + void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; @@ -199,6 +205,9 @@ class llama_kv_cache : public llama_memory_i { void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; + void set_input_k_rot(ggml_tensor * dst) const; + void set_input_v_rot(ggml_tensor * dst) const; + private: const llama_model & model; const llama_hparams & hparams; @@ -226,6 +235,18 @@ class llama_kv_cache : public llama_memory_i { // SWA const uint32_t n_swa = 0; + // env: LLAMA_ATTN_ROT_DISABLE + bool attn_rot_k = false; + bool attn_rot_v = false; + + // if all layers participating in the cache have constant head size, the value is stored here + // otherwise the value is -1 + int32_t n_embd_head_k_all = 0; + int32_t n_embd_head_v_all = 0; + + // pre-computed hadamard martrices + std::unordered_map> attn_rot_hadamard; + // env: LLAMA_KV_CACHE_DEBUG int debug = 0; @@ -262,6 +283,7 @@ class llama_kv_cache : public llama_memory_i { ggml_context * ctx, ggml_tensor * cur, ggml_tensor * shift, + ggml_tensor * rot, ggml_tensor * factors, float freq_base, float freq_scale, @@ -328,12 +350,15 @@ class llama_kv_cache_context : public llama_memory_context_i { uint32_t get_n_kv() const; + ggml_type type_k() const; + ggml_type type_v() const; + // get views of the current state of the cache ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; // store k_cur and v_cur in the cache based on the provided head location - // note: the heads in k_cur and v_cur should be layed out contiguously in memory + // note: the heads in k_cur and v_cur should be laid out contiguously in memory // - k_cur [n_embd_head_k, n_head_k, n_tokens] // - k_idxs [n_tokens] // - v_cur [n_embd_head_v, n_head_v, n_tokens] @@ -347,6 +372,9 @@ class llama_kv_cache_context : public llama_memory_context_i { ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; + ggml_tensor * build_input_k_rot(ggml_context * ctx) const; + ggml_tensor * build_input_v_rot(ggml_context * ctx) const; + void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const; void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const; @@ -354,6 +382,9 @@ class llama_kv_cache_context : public llama_memory_context_i { void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; + void set_input_k_rot(ggml_tensor * dst) const; + void set_input_v_rot(ggml_tensor * dst) const; + private: llama_memory_status status; diff --git a/examples/talk-llama/llama-memory-hybrid-iswa.cpp b/examples/talk-llama/llama-memory-hybrid-iswa.cpp index 411769672af..10e6b459797 100644 --- a/examples/talk-llama/llama-memory-hybrid-iswa.cpp +++ b/examples/talk-llama/llama-memory-hybrid-iswa.cpp @@ -73,9 +73,9 @@ llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - // TODO: non-sequential equal split can be done if using unified KV cache - // for simplicity, we always use sequential equal split for now - ubatch = balloc.split_equal(n_ubatch, true); + // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) + const bool unified = (mem_attn->get_base()->get_n_stream() == 1); + ubatch = balloc.split_equal(n_ubatch, !unified); } if (ubatch.n_tokens == 0) { diff --git a/examples/talk-llama/llama-memory-hybrid.cpp b/examples/talk-llama/llama-memory-hybrid.cpp index a1b45e4a3cc..4ce1af592c1 100644 --- a/examples/talk-llama/llama-memory-hybrid.cpp +++ b/examples/talk-llama/llama-memory-hybrid.cpp @@ -73,9 +73,9 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - // TODO: non-sequential equal split can be done if using unified KV cache - // for simplicity, we always use sequential equal split for now - ubatch = balloc.split_equal(n_ubatch, true); + // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) + const bool unified = (mem_attn->get_n_stream() == 1); + ubatch = balloc.split_equal(n_ubatch, !unified); } if (ubatch.n_tokens == 0) { diff --git a/examples/talk-llama/llama-memory-recurrent.cpp b/examples/talk-llama/llama-memory-recurrent.cpp index 6e8413f493d..9287fe45e96 100644 --- a/examples/talk-llama/llama-memory-recurrent.cpp +++ b/examples/talk-llama/llama-memory-recurrent.cpp @@ -1,5 +1,6 @@ #include "llama-memory-recurrent.h" +#include "ggml-backend.h" #include "llama-impl.h" #include "llama-io.h" #include "llama-batch.h" @@ -91,8 +92,8 @@ llama_memory_recurrent::llama_memory_recurrent( throw std::runtime_error("failed to create ggml context for rs cache"); } - ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size); - ggml_tensor * s = ggml_new_tensor_1d(ctx, type_s, hparams.n_embd_s()*mem_size); + ggml_tensor * r = ggml_new_tensor_2d(ctx, type_r, hparams.n_embd_r(), mem_size); + ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), mem_size); ggml_format_name(r, "cache_r_l%d", i); ggml_format_name(s, "cache_s_l%d", i); r_l[i] = r; @@ -928,11 +929,8 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell llama_seq_id seq_id; io.read_to(&seq_id, sizeof(seq_id)); - // TODO: llama_memory_recurrent should have a notion of max sequences - //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { - if (seq_id < 0) { - //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); - LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id); + if (seq_id < 0 || (uint32_t) seq_id >= this->n_seq_max) { + LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, this->n_seq_max); return false; } diff --git a/examples/talk-llama/llama-mmap.cpp b/examples/talk-llama/llama-mmap.cpp index c03228e9ce2..ed572da7fb5 100644 --- a/examples/talk-llama/llama-mmap.cpp +++ b/examples/talk-llama/llama-mmap.cpp @@ -40,6 +40,14 @@ #include #endif +#ifdef _WIN32 +# define llama_mmap_ftell _ftelli64 +# define llama_mmap_fseek _fseeki64 +#else +# define llama_mmap_ftell ftello +# define llama_mmap_fseek fseeko +#endif + // TODO: consider moving to llama-impl.h if needed in more places #if defined(_WIN32) static std::string llama_format_win_err(DWORD err) { @@ -86,6 +94,14 @@ struct llama_file::impl { seek(0, SEEK_SET); } + impl(FILE * file) : owns_fp(false) { + fp = file; + fp_win32 = (HANDLE) _get_osfhandle(_fileno(fp)); + seek(0, SEEK_END); + size = tell(); + seek(0, SEEK_SET); + } + size_t tell() const { LARGE_INTEGER li; li.QuadPart = 0; @@ -159,7 +175,7 @@ struct llama_file::impl { } ~impl() { - if (fp) { + if (fp && owns_fp) { std::fclose(fp); } } @@ -209,9 +225,16 @@ struct llama_file::impl { seek(0, SEEK_SET); } + impl(FILE * file) : fname("(file*)"), owns_fp(false) { + fp = file; + seek(0, SEEK_END); + size = tell(); + seek(0, SEEK_SET); + } + size_t tell() const { if (fd == -1) { - long ret = std::ftell(fp); + off_t ret = llama_mmap_ftell(fp); if (ret == -1) { throw std::runtime_error(format("ftell error: %s", strerror(errno))); } @@ -229,7 +252,7 @@ struct llama_file::impl { void seek(size_t offset, int whence) const { off_t ret = 0; if (fd == -1) { - ret = std::fseek(fp, (long) offset, whence); + ret = llama_mmap_fseek(fp, offset, whence); } else { ret = lseek(fd, offset, whence); } @@ -353,7 +376,7 @@ struct llama_file::impl { ~impl() { if (fd != -1) { close(fd); - } else { + } else if (owns_fp) { std::fclose(fp); } } @@ -369,10 +392,14 @@ struct llama_file::impl { FILE * fp{}; size_t size{}; + bool owns_fp = true; }; llama_file::llama_file(const char * fname, const char * mode, const bool use_direct_io) : pimpl(std::make_unique(fname, mode, use_direct_io)) {} + +llama_file::llama_file(FILE * file) : pimpl(std::make_unique(file)) {} + llama_file::~llama_file() = default; size_t llama_file::tell() const { return pimpl->tell(); } diff --git a/examples/talk-llama/llama-mmap.h b/examples/talk-llama/llama-mmap.h index 29ce4d24685..b7d5c61e95f 100644 --- a/examples/talk-llama/llama-mmap.h +++ b/examples/talk-llama/llama-mmap.h @@ -15,6 +15,7 @@ using llama_mlocks = std::vector>; struct llama_file { llama_file(const char * fname, const char * mode, bool use_direct_io = false); + llama_file(FILE * file); ~llama_file(); size_t tell() const; diff --git a/examples/talk-llama/llama-model-loader.cpp b/examples/talk-llama/llama-model-loader.cpp index 413f34c2268..4e65a45a50d 100644 --- a/examples/talk-llama/llama-model-loader.cpp +++ b/examples/talk-llama/llama-model-loader.cpp @@ -36,6 +36,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_ALL_F32: return "all F32"; case LLAMA_FTYPE_MOSTLY_F16: return "F16"; case LLAMA_FTYPE_MOSTLY_BF16: return "BF16"; + case LLAMA_FTYPE_MOSTLY_Q1_0: return "Q1_0"; case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0"; @@ -374,8 +375,9 @@ namespace GGUFMeta { } } else { if (arr_info.gt == GGUF_TYPE_BOOL) { - std::transform((const bool *)arr_info.data, (const bool *)arr_info.data + arr_info.length, result.begin(), [](bool x) { - return static_cast(x); + const int8_t * values = (const int8_t *) arr_info.data; + std::transform(values, values + arr_info.length, result.begin(), [](int8_t x) { + return static_cast(x != 0); }); } else { std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin()); @@ -511,6 +513,7 @@ llama_model_loader::llama_model_loader( void * set_tensor_data_ud, const std::string & fname, std::vector & splits, + FILE * file, bool use_mmap, bool use_direct_io, bool check_tensors, @@ -658,6 +661,36 @@ llama_model_loader::llama_model_loader( LLAMA_LOG_INFO("%s: additional %d GGUFs metadata loaded.\n", __func__, n_split - 1); } + } else if (file != nullptr) { + struct ggml_context * ctx = NULL; + struct gguf_init_params params = { + /*.no_alloc = */ true, + /*.ctx = */ &ctx, + }; + + metadata_ptr.reset(gguf_init_from_file_ptr(file, params)); + metadata = metadata_ptr.get(); + if (metadata == nullptr) { + throw std::runtime_error(format("%s: failed to load model from file pointer", __func__)); + } + + get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); + llm_kv = LLM_KV(llm_arch_from_string(arch_name)); + + files.emplace_back(new llama_file(file)); + contexts.emplace_back(ctx); + + // Save tensors data offset info of the main file. + for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { + std::string tensor_name = std::string(cur->name); + // make sure there is no duplicated tensor names + if (weights_map.find(tensor_name) != weights_map.end()) { + throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur))); + } + n_elements += ggml_nelements(cur); + n_bytes += ggml_nbytes(cur); + weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), 0, metadata, cur)); + } } else { get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); llm_kv = LLM_KV(llm_arch_from_string(arch_name)); @@ -669,7 +702,7 @@ llama_model_loader::llama_model_loader( fver = (enum llama_fver) gguf_get_version(metadata); LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n", - __func__, n_kv, n_tensors, fname.c_str(), llama_file_version_name(fver)); + __func__, n_kv, n_tensors, fname.empty() ? "(file*)" : fname.c_str(), llama_file_version_name(fver)); // determine file type based on the number of tensors for each quantization and print meta data // TODO: make optional @@ -726,6 +759,7 @@ llama_model_loader::llama_model_loader( case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break; case GGML_TYPE_NVFP4: ftype = LLAMA_FTYPE_MOSTLY_NVFP4; break; + case GGML_TYPE_Q1_0: ftype = LLAMA_FTYPE_MOSTLY_Q1_0; break; default: { LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max)); @@ -1127,6 +1161,12 @@ struct ggml_tensor * llama_model_loader::create_tensor( if (overrides->buft == ggml_backend_cpu_buffer_type()) { // when overriding to a CPU buffer, consider the extra buffer types buft = select_weight_buft(hparams, t_meta, op, buft_list_cpu); + if (use_mmap) { + static std::once_flag once; + std::call_once(once, [] { + LLAMA_LOG_WARN("llama_model_loader: tensor overrides to CPU are used with mmap enabled - consider using --no-mmap for better performance\n"); + }); + } } else { buft = overrides->buft; } diff --git a/examples/talk-llama/llama-model-loader.h b/examples/talk-llama/llama-model-loader.h index ed5de729caf..7b3d6703c03 100644 --- a/examples/talk-llama/llama-model-loader.h +++ b/examples/talk-llama/llama-model-loader.h @@ -125,6 +125,7 @@ struct llama_model_loader { void * set_tensor_data_ud, const std::string & fname, std::vector & splits, // optional, only need if the split does not follow naming scheme + FILE * file, bool use_mmap, bool use_direct_io, bool check_tensors, diff --git a/examples/talk-llama/llama-model-saver.cpp b/examples/talk-llama/llama-model-saver.cpp index 6f6538aeccd..26864c18e97 100644 --- a/examples/talk-llama/llama-model-saver.cpp +++ b/examples/talk-llama/llama-model-saver.cpp @@ -1,7 +1,9 @@ #include "llama-model-saver.h" +#include "ggml.h" #include "gguf.h" +#include "llama-arch.h" #include "llama.h" #include "llama-hparams.h" #include "llama-model.h" @@ -10,8 +12,33 @@ #include #include +bool llama_model_saver_supports_arch(llm_arch arch) { + switch (arch) { + case LLM_ARCH_QWEN3NEXT: + case LLM_ARCH_QWEN35: + case LLM_ARCH_QWEN35MOE: + case LLM_ARCH_PLAMO3: + case LLM_ARCH_GEMMA3: + case LLM_ARCH_GEMMA3N: + case LLM_ARCH_COHERE2: + case LLM_ARCH_OLMO2: + case LLM_ARCH_BITNET: + case LLM_ARCH_T5: + case LLM_ARCH_EXAONE_MOE: + case LLM_ARCH_AFMOE: + case LLM_ARCH_APERTUS: + case LLM_ARCH_MIMO2: + case LLM_ARCH_STEP35: + return false; + default: + return true; + } +} + llama_model_saver::llama_model_saver(const struct llama_model * model) : - gguf_ctx(gguf_init_empty()), gguf_ctx_owned(true), model(model), llm_kv(model->arch) {} + gguf_ctx(gguf_init_empty()), gguf_ctx_owned(true), model(model), llm_kv(model->arch) { + GGML_ASSERT(llama_model_saver_supports_arch(model->arch)); +} llama_model_saver::llama_model_saver(enum llm_arch arch, struct gguf_context * gguf_ctx) : gguf_ctx(gguf_ctx == nullptr ? gguf_init_empty() : gguf_ctx), gguf_ctx_owned(gguf_ctx == nullptr), model(nullptr), llm_kv(arch) {} @@ -105,7 +132,10 @@ void llama_model_saver::add_tensor(const struct ggml_tensor * tensor) { return; } if (gguf_find_tensor(gguf_ctx, tensor->name) >= 0) { - GGML_ASSERT(std::string(tensor->name) == "rope_freqs.weight"); // FIXME + const std::string tensor_name = tensor->name; + GGML_ASSERT( + tensor_name == "rope_freqs.weight" || tensor_name == "rope_factors_long.weight" || + tensor_name == "rope_factors_short.weight"); // FIXME return; } gguf_add_tensor(gguf_ctx, tensor); @@ -127,6 +157,7 @@ void llama_model_saver::add_kv_from_model() { tokens[id] = token_data.text; scores[id] = token_data.score; + // FIXME should this be treated as flags? switch(token_data.attr) { case LLAMA_TOKEN_ATTR_UNKNOWN: token_types[id] = LLAMA_TOKEN_TYPE_UNKNOWN; break; case LLAMA_TOKEN_ATTR_UNUSED: token_types[id] = LLAMA_TOKEN_TYPE_UNUSED; break; @@ -134,6 +165,9 @@ void llama_model_saver::add_kv_from_model() { case LLAMA_TOKEN_ATTR_CONTROL: token_types[id] = LLAMA_TOKEN_TYPE_CONTROL; break; case LLAMA_TOKEN_ATTR_USER_DEFINED: token_types[id] = LLAMA_TOKEN_TYPE_USER_DEFINED; break; case LLAMA_TOKEN_ATTR_BYTE: token_types[id] = LLAMA_TOKEN_TYPE_BYTE; break; + // case LLAMA_TOKEN_ATTR_NORMALIZED: ??? + // case LLAMA_TOKEN_ATTR_LSTRIP: ??? + // case LLAMA_TOKEN_ATTR_RSTRIP: ??? case LLAMA_TOKEN_ATTR_UNDEFINED: default: token_types[id] = LLAMA_TOKEN_TYPE_UNDEFINED; break; } @@ -144,6 +178,19 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_GENERAL_ARCHITECTURE, model->arch_name()); // add_kv(LLM_KV_GENERAL_QUANTIZATION_VERSION, ???); // add_kv(LLM_KV_GENERAL_ALIGNMENT, ???); + // add_kv(LLM_KV_GENERAL_FILE_TYPE, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_SEQUENCE, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_TOP_K, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_TOP_P, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_MIN_P, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_XTC_PROBABILITY, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_XTC_THRESHOLD, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_TEMP, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_PENALTY_LAST_N, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_PENALTY_REPEAT, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_MIROSTAT, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_MIROSTAT_TAU, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_MIROSTAT_ETA, ???); add_kv(LLM_KV_GENERAL_NAME, model->name); // add_kv(LLM_KV_GENERAL_AUTHOR, ???); // add_kv(LLM_KV_GENERAL_VERSION, ???); @@ -163,17 +210,31 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); add_kv(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, true); add_kv(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - add_kv(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + add_kv(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp); + add_kv(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_chexp); + add_kv(LLM_KV_SWIGLU_CLAMP_EXP, hparams.swiglu_clamp_exp); + add_kv(LLM_KV_SWIGLU_CLAMP_SHEXP, hparams.swiglu_clamp_shexp); add_kv(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); // add_kv(LLM_KV_TENSOR_DATA_LAYOUT, ???); add_kv(LLM_KV_EXPERT_COUNT, hparams.n_expert); add_kv(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); add_kv(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + add_kv(LLM_KV_EXPERT_GROUP_COUNT, hparams.n_expert_groups); + add_kv(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used); add_kv(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + add_kv(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm); + add_kv(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + add_kv(LLM_KV_EXPERT_GROUP_SCALE, hparams.expert_group_scale); + add_kv(LLM_KV_EXPERTS_PER_GROUP, hparams.n_group_experts); + add_kv(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers); + add_kv(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers); + add_kv(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers); add_kv(LLM_KV_POOLING_TYPE, uint32_t(hparams.pooling_type)); add_kv(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); add_kv(LLM_KV_DECODER_START_TOKEN_ID, hparams.dec_start_token_id); + add_kv(LLM_KV_DECODER_BLOCK_COUNT, hparams.dec_n_layer); add_kv(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping); + add_kv(LLM_KV_ROUTER_LOGIT_SOFTCAPPING, hparams.f_router_logit_softcapping); add_kv(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping); add_kv(LLM_KV_SWIN_NORM, hparams.swin_norm); add_kv(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers); @@ -181,6 +242,9 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); add_kv(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); add_kv(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); + add_kv(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count); + add_kv(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); + // add_kv(LLM_KV_FULL_ATTENTION_INTERVAL, ???); add_kv(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, true); add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, true); @@ -188,22 +252,39 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); add_kv(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k_full); add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v_full); - add_kv(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa); - add_kv(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa); add_kv(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + add_kv(LLM_KV_ATTENTION_GROUPNORM_EPS, hparams.f_norm_group_eps); + add_kv(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups); add_kv(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); add_kv(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); add_kv(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + add_kv(LLM_KV_ATTENTION_DECAY_LORA_RANK, hparams.n_lora_decay); + add_kv(LLM_KV_ATTENTION_ICLR_LORA_RANK, hparams.n_lora_iclr); + add_kv(LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, hparams.n_lora_value_res_mix); + add_kv(LLM_KV_ATTENTION_GATE_LORA_RANK, hparams.n_lora_gate); add_kv(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + // add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, ???); add_kv(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale); + add_kv(LLM_KV_ATTENTION_OUTPUT_SCALE, hparams.f_attn_out_scale); + add_kv(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.attn_temp_length); + add_kv(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale); + add_kv(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl); + add_kv(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl); + add_kv(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa); + add_kv(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa); + add_kv(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, hparams.indexer_n_head); + add_kv(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, hparams.indexer_head_size); + add_kv(LLM_KV_ATTENTION_INDEXER_TOP_K, hparams.indexer_top_k); const float rope_scaling_factor = hparams.rope_freq_scale_train == 1.0f ? 0.0f : 1.0f/hparams.rope_freq_scale_train; add_kv(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot_full); add_kv(LLM_KV_ROPE_DIMENSION_COUNT_SWA, hparams.n_rot_swa); + add_kv(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections); add_kv(LLM_KV_ROPE_FREQ_BASE, hparams.rope_freq_base_train); + add_kv(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); // add_kv(LLM_KV_ROPE_SCALE_LINEAR, rope_scaling_factor); // old name add_kv(LLM_KV_ROPE_SCALING_TYPE, llama_rope_scaling_type_name(hparams.rope_scaling_type_train)); add_kv(LLM_KV_ROPE_SCALING_FACTOR, rope_scaling_factor); @@ -211,6 +292,10 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_ctx_orig_yarn); add_kv(LLM_KV_ROPE_SCALING_FINETUNED, hparams.rope_finetuned); add_kv(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul); + add_kv(LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, hparams.yarn_ext_factor); + add_kv(LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, hparams.yarn_attn_factor); + add_kv(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast); + add_kv(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow); // TODO: implement split file support // add_kv(LLM_KV_SPLIT_NO, ???); @@ -221,8 +306,11 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); add_kv(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); add_kv(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + add_kv(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); add_kv(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms); + add_kv(LLM_KV_KDA_HEAD_DIM, hparams.n_embd_head_kda); + add_kv(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); add_kv(LLM_KV_TOKENIZER_MODEL, vocab.get_tokenizer_model()); @@ -260,15 +348,39 @@ void llama_model_saver::add_kv_from_model() { // TODO: implement LoRA support // add_kv(LLM_KV_ADAPTER_TYPE, ???); // add_kv(LLM_KV_ADAPTER_LORA_ALPHA, ???); + // add_kv(LLM_KV_ADAPTER_LORA_TASK_NAME, ???); + // add_kv(LLM_KV_ADAPTER_LORA_PROMPT_PREFIX, ???); + // add_kv(LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS, ???); + + add_kv(LLM_KV_POSNET_EMBEDDING_LENGTH, hparams.posnet.n_embd); + add_kv(LLM_KV_POSNET_BLOCK_COUNT, hparams.posnet.n_layer); + + add_kv(LLM_KV_CONVNEXT_EMBEDDING_LENGTH, hparams.convnext.n_embd); + add_kv(LLM_KV_CONVNEXT_BLOCK_COUNT, hparams.convnext.n_layer); + + add_kv(LLM_KV_CLASSIFIER_OUTPUT_LABELS, model->classifier_labels); + + add_kv(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); + + add_kv(LLM_KV_XIELU_ALPHA_N, hparams.xielu_alpha_n); + add_kv(LLM_KV_XIELU_ALPHA_P, hparams.xielu_alpha_p); + add_kv(LLM_KV_XIELU_BETA, hparams.xielu_beta); + add_kv(LLM_KV_XIELU_EPS, hparams.xielu_eps); // deprecated // add_kv(LLM_KV_TOKENIZER_PREFIX_ID, ???); // add_kv(LLM_KV_TOKENIZER_SUFFIX_ID, ???); // add_kv(LLM_KV_TOKENIZER_MIDDLE_ID, ???); + + add_kv(LLM_KV_DENSE_2_FEAT_IN, hparams.dense_2_feat_in); + add_kv(LLM_KV_DENSE_2_FEAT_OUT, hparams.dense_2_feat_out); + add_kv(LLM_KV_DENSE_3_FEAT_IN, hparams.dense_3_feat_in); + add_kv(LLM_KV_DENSE_3_FEAT_OUT, hparams.dense_3_feat_out); } void llama_model_saver::add_tensors_from_model() { - if (std::string(model->output->name) != std::string(model->tok_embd->name)) { + if (model->output != nullptr && + std::string(model->output->name) != std::string(model->tok_embd->name)) { add_tensor(model->tok_embd); // some models use the same tensor for tok_embd and output } add_tensor(model->type_embd); @@ -297,3 +409,6 @@ void llama_model_saver::save(const std::string & path_model) { gguf_write_to_file(gguf_ctx, path_model.c_str(), false); } +void llama_model_saver::save(FILE * file) { + gguf_write_to_file_ptr(gguf_ctx, file, false); +} diff --git a/examples/talk-llama/llama-model-saver.h b/examples/talk-llama/llama-model-saver.h index 2b3541ce6c5..36a715e2b6b 100644 --- a/examples/talk-llama/llama-model-saver.h +++ b/examples/talk-llama/llama-model-saver.h @@ -6,6 +6,9 @@ #include +// FIXME temporary function for better error messages +bool llama_model_saver_supports_arch(llm_arch arch); + struct llama_model_saver { struct gguf_context * gguf_ctx = nullptr; const bool gguf_ctx_owned; @@ -37,4 +40,5 @@ struct llama_model_saver { void add_tensors_from_model(); void save(const std::string & path_model); + void save(FILE * file); }; diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index e8e1bbf1cd1..9e2a13cbd43 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -1,6 +1,8 @@ #include "llama-model.h" -#include "ggml.h" +#include "llama-arch.h" +#include "llama-ext.h" +#include "llama-hparams.h" #include "llama-impl.h" #include "llama-mmap.h" #include "llama-cparams.h" @@ -12,10 +14,11 @@ #include "llama-memory-hybrid-iswa.h" #include "llama-memory-recurrent.h" -#include "ggml-cpp.h" - #include "models/models.h" +#include "ggml.h" +#include "ggml-cpp.h" + #include #include #include @@ -24,9 +27,358 @@ #include #include #include +#include #include #include #include +#include +#include + +struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const struct ggml_tensor * tensor, void * userdata) { + const llama_meta_device_get_split_state_userdata * ud = (const llama_meta_device_get_split_state_userdata *) userdata; + const llama_hparams & hparams = ud->model->hparams; + const std::string tensor_name = tensor->name; + + const std::regex pattern_q_weight ("blk\\.\\d*\\.attn_q.weight"); + const std::regex pattern_kv_weight ("blk\\.\\d*\\.attn_(k|v).weight"); + const std::regex pattern_qkv_weight ("blk\\.\\d*\\.attn_qkv.weight"); + const std::regex pattern_q_bias ("blk\\.\\d*\\.attn_q\\.bias"); + const std::regex pattern_kv_bias ("blk\\.\\d*\\.attn_(k|v)\\.bias"); + const std::regex pattern_qkv_bias ("blk\\.\\d*\\.attn_qkv.bias"); + const std::regex pattern_qk_norm ("blk\\.\\d*\\.attn_(q|k)_norm\\.weight"); + const std::regex pattern_kv_cache ("cache_(k|v)_l\\d*"); + const std::regex pattern_attn_sinks ("blk\\.\\d*\\.attn_sinks.weight"); + const std::regex pattern_attn_out_weight ("blk\\.\\d*\\.attn_output.weight"); + const std::regex pattern_attn_out_bias ("blk\\.\\d*\\.attn_output.bias"); + const std::regex pattern_attn_gate_weight("blk\\.\\d*\\.attn_gate.weight"); + + const std::regex pattern_ssm_dt ("blk\\.\\d*\\.ssm_dt.bias"); + const std::regex pattern_ssm_a ("blk\\.\\d*\\.ssm_a"); + const std::regex pattern_ssm_alpha ("blk\\.\\d*\\.ssm_alpha.weight"); + const std::regex pattern_ssm_beta ("blk\\.\\d*\\.ssm_beta.weight"); + const std::regex pattern_ssm_beta_alpha ("blk\\.\\d*\\.ssm_ba.weight"); + const std::regex pattern_r_cache ("cache_r_l\\d*"); + const std::regex pattern_s_cache ("cache_s_l\\d*"); + const std::regex pattern_ssm_conv1d ("blk\\.\\d*\\.ssm_conv1d.weight"); + const std::regex pattern_ssm_out_weight ("blk\\.\\d*\\.ssm_out.weight"); + + const std::regex pattern_ffn_up_gate_weight("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.weight"); + const std::regex pattern_ffn_up_gate_bias ("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.bias"); + const std::regex pattern_ffn_gate_up_weight("blk\\.\\d*\\.ffn_gate_up(_exps)?.weight"); + const std::regex pattern_ffn_down_weight ("blk\\.\\d*\\.ffn_down(_exps)?.weight"); + const std::regex pattern_ffn_down_bias ("blk\\.\\d*\\.ffn_down.bias"); + const std::regex pattern_ffn_down_exps_bias("blk\\.\\d*\\.ffn_down_exps.bias"); + + const std::regex pattern_output_weight("output\\.weight"); + const std::regex pattern_output_bias ("output\\.bias"); + + struct tensor_config { + ggml_backend_meta_split_axis axis; + + const ggml_tensor * tensor_axis_0; + + uint32_t il; + size_t rotation; // when assigning tensor slices, rotate how the rounding is done for more even allocation + }; + + auto get_tensor_config_impl = [&]( + const ggml_backend_meta_split_axis axis, const std::string & suffix = "", const std::string & suffix_fallback = "") -> tensor_config { + // the layers in a tensor can be inhomogeneous, if the pattern is cleanly divided by the number of GPUs there can be aliasing effects, + // count only the same type of previous layers to avoid this + auto get_il_eff = [&](const size_t il){ + size_t ret = 0; + const bool il_is_recurrent = hparams.is_recurrent(il); + const bool il_is_swa = hparams.is_swa(il); + for (size_t il_prev = 0; il_prev < il; il_prev++) { + ret += hparams.is_recurrent(il_prev) == il_is_recurrent && hparams.is_swa(il_prev) == il_is_swa; + } + return ret; + }; + + uint32_t il; + std::string prefix; + size_t rotation; + if (tensor_name.substr(0, 4) == "blk.") { + const size_t length_prefix = tensor_name.find('.', 4); + GGML_ASSERT(length_prefix != std::string::npos); + prefix = tensor_name.substr(0, length_prefix + 1); + il = std::stoull(tensor_name.substr(4, length_prefix)); + rotation = get_il_eff(il) % ud->n_devices; + } else if (tensor_name.substr(0, 6) == "cache_") { + const size_t layer_index_start = tensor_name.find("_l", 6); + GGML_ASSERT(layer_index_start != std::string::npos); + il = std::stoull(tensor_name.substr(layer_index_start + 2)); + prefix = "blk." + std::to_string(il) + "."; + rotation = get_il_eff(il) % ud->n_devices; + } else { + il = 0; + rotation = hparams.n_layer % ud->n_devices; + } + const ggml_tensor * tensor_axis_0 = suffix.empty() ? tensor : ud->model->get_tensor((prefix + suffix).c_str()); + if (tensor_axis_0 == nullptr) { + GGML_ASSERT(!suffix_fallback.empty()); + tensor_axis_0 = ud->model->get_tensor((prefix + suffix_fallback).c_str()); + } + GGML_ASSERT(tensor_axis_0 != nullptr); + return {axis, tensor_axis_0, il, rotation}; + }; + + auto get_tensor_config = [&]() -> tensor_config { + // standard attention + if (std::regex_match(tensor_name, pattern_q_weight) || std::regex_match(tensor_name, pattern_kv_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "attn_output.weight"); + } + if (std::regex_match(tensor_name, pattern_q_bias) || std::regex_match(tensor_name, pattern_kv_bias)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "attn_output.weight"); + } + if (std::regex_match(tensor_name, pattern_qkv_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1); + } + if ( std::regex_match(tensor_name, pattern_qkv_bias)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0); + } + if (std::regex_match(tensor_name, pattern_qk_norm)) { + return get_tensor_config_impl(tensor->ne[1] == 1 ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_1, "attn_output.weight"); + } + if (std::regex_match(tensor_name, pattern_kv_cache) || std::regex_match(tensor_name, pattern_attn_sinks)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "attn_output.weight"); + } + if (std::regex_match(tensor_name, pattern_attn_out_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0); + } + if (std::regex_match(tensor_name, pattern_attn_out_bias)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_MIRRORED); + } + + if (std::regex_match(tensor_name, pattern_attn_gate_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1); + } + if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "ssm_out.weight"); + } + if (std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta) || + std::regex_match(tensor_name, pattern_ssm_beta_alpha)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "ssm_out.weight"); + } + if (std::regex_match(tensor_name, pattern_r_cache) || std::regex_match(tensor_name, pattern_s_cache)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "ssm_out.weight"); + } + if (std::regex_match(tensor_name, pattern_ssm_conv1d)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "ssm_out.weight"); + } + if (std::regex_match(tensor_name, pattern_ssm_out_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0); + } + + // FFN + if (std::regex_match(tensor_name, pattern_ffn_up_gate_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "ffn_down.weight", "ffn_down_exps.weight"); + } + if (std::regex_match(tensor_name, pattern_ffn_up_gate_bias)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "ffn_down.weight", "ffn_down_exps.weight"); + } + if (std::regex_match(tensor_name, pattern_ffn_gate_up_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "ffn_down.weight", "ffn_down_exps.weight"); + } + if (std::regex_match(tensor_name, pattern_ffn_down_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "ffn_down.weight", "ffn_down_exps.weight"); + } + if (std::regex_match(tensor_name, pattern_ffn_down_bias)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_MIRRORED); + } + if (std::regex_match(tensor_name, pattern_ffn_down_exps_bias)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_PARTIAL); + } + + // output + if (std::regex_match(tensor_name, pattern_output_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1); + } + if (std::regex_match(tensor_name, pattern_output_bias)) { + const ggml_tensor * output_weight = ud->model->get_tensor("output.weight"); + GGML_ASSERT(output_weight != nullptr); + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0); + } + + // everything else + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_MIRRORED); + }; + + auto get_split_segments = [&](int axis, uint32_t il) -> std::vector { + if (ud->model->arch == LLM_ARCH_QWEN3NEXT || ud->model->arch == LLM_ARCH_QWEN35 || ud->model->arch == LLM_ARCH_QWEN35MOE) { + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + + // both Qwen 3 Next and Qwen 3.5 support n_v_heads > n_k_heads but the broadcasting pattern is different: + // - Qwen 3 Next: [k0_v0, k0_v1, k1_v2, k1_v3] (this is the default split pattern) + // - Qwen 3.5: [k0_v0, k1_v1, k0_v2, k1_v3] (needs segmenting of V on the scale of K to get the correct pattern) + if (ud->model->arch == LLM_ARCH_QWEN3NEXT) { + if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) { + GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim); + return {key_dim, key_dim, value_dim}; + } + } else { + const int64_t head_ratio = n_v_heads / n_k_heads; + if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) { + GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim); + return std::vector(2 + head_ratio, key_dim); + } + if (std::regex_match(tensor_name, pattern_attn_gate_weight) || std::regex_match(tensor_name, pattern_ssm_out_weight)) { + return std::vector(head_ratio, key_dim); + } + if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) || + std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) { + return std::vector(head_ratio, n_k_heads); + } + if (std::regex_match(tensor_name, pattern_r_cache)) { + return std::vector(2 + head_ratio, key_dim * (hparams.ssm_d_conv - 1)); + } + if (std::regex_match(tensor_name, pattern_s_cache)) { + return std::vector(head_ratio, n_k_heads * head_v_dim * head_v_dim); + } + } + + // the FFN is the same for Qwen 3 Next and Qwen 3.5: + if (std::regex_match(tensor_name, pattern_ffn_gate_up_weight)) { + const int64_t n_ff_exp = hparams.n_ff_exp; + GGML_ASSERT(tensor->ne[axis] == 2*n_ff_exp); + return {n_ff_exp, n_ff_exp}; + } + return {tensor->ne[axis]}; + } + + if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_qkv_bias)) { + const int64_t n_embd = hparams.n_embd; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(il); + GGML_ASSERT(hparams.n_embd_k_gqa() == n_embd_gqa); + GGML_ASSERT(tensor->ne[axis] == n_embd + 2*n_embd_gqa); + return {n_embd, n_embd_gqa, n_embd_gqa}; + } + if (std::regex_match(tensor_name, pattern_ffn_gate_up_weight)) { + const int64_t n_ff_exp = hparams.n_ff_exp; + GGML_ASSERT(tensor->ne[axis] == 2*n_ff_exp); + return {n_ff_exp, n_ff_exp}; + } + return {tensor->ne[axis]}; + }; + + auto get_split_granularity = [&](int64_t blck_size, uint32_t il, const std::vector & segments) -> std::vector { + if (hparams.is_recurrent(il)) { + // linear attention + const int64_t head_dim = hparams.ssm_d_state; + const int64_t granularity_qkv = std::lcm(blck_size, head_dim); + if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_attn_gate_weight) || + std::regex_match(tensor_name, pattern_ssm_conv1d) || std::regex_match(tensor_name, pattern_ssm_out_weight)) { + return std::vector(segments.size(), granularity_qkv); + } + if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) || + std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) { + return std::vector(segments.size(), granularity_qkv / head_dim); + } + if (std::regex_match(tensor_name, pattern_ssm_beta_alpha)) { + return std::vector(segments.size(), 2 * (granularity_qkv / head_dim)); + } + if (std::regex_match(tensor_name, pattern_r_cache)) { + return std::vector(segments.size(), granularity_qkv * (hparams.ssm_d_conv - 1)); + } + if (std::regex_match(tensor_name, pattern_s_cache)) { + return std::vector(segments.size(), granularity_qkv * head_dim); + } + } else { + // regular attention + const uint32_t n_gqa = hparams.n_gqa(il); + const uint32_t n_embd_q = n_gqa * hparams.n_embd_head_k(il); + if (std::regex_match(tensor_name, pattern_attn_sinks)) { + GGML_ASSERT(segments.size() == 1); + return {std::lcm(n_embd_q, blck_size)/n_embd_q * n_gqa}; + } + + const int64_t granularity_q = std::lcm(n_embd_q, blck_size); + if (std::regex_match(tensor_name, pattern_q_weight) || std::regex_match(tensor_name, pattern_q_bias)) { + GGML_ASSERT(segments.size() == 1); + // some models have Q gate tensors, for those cases the granularity needs to be doubled: + if (ud->model->arch == LLM_ARCH_QWEN3NEXT || ud->model->arch == LLM_ARCH_QWEN35 || ud->model->arch == LLM_ARCH_QWEN35MOE) { + return {std::lcm(2*n_embd_q, blck_size)}; + } + return {granularity_q}; + } + if (std::regex_match(tensor_name, pattern_attn_out_weight)) { + GGML_ASSERT(segments.size() == 1); + return {granularity_q}; + } + + const int64_t granularity_kv = granularity_q / n_gqa; + if (std::regex_match(tensor_name, pattern_kv_weight) || + std::regex_match(tensor_name, pattern_kv_bias) || + std::regex_match(tensor_name, pattern_kv_cache)) { + GGML_ASSERT(segments.size() == 1); + return {granularity_kv}; + } + if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_qkv_bias)) { + GGML_ASSERT(segments.size() == 3); + return {granularity_q, granularity_kv, granularity_kv}; + } + } + + // FFN + if (std::regex_match(tensor_name, pattern_ffn_up_gate_weight) || std::regex_match(tensor_name, pattern_ffn_up_gate_bias) || + std::regex_match(tensor_name, pattern_ffn_gate_up_weight) || std::regex_match(tensor_name, pattern_ffn_down_weight)) { + GGML_ASSERT(segments.size() <= 2); + return std::vector(segments.size(), blck_size); + } + + // everything else + GGML_ASSERT(segments.size() == 1); + return {1}; + }; + + ggml_backend_meta_split_state split_state; + memset(&split_state, 0, sizeof(split_state)); + tensor_config tc = get_tensor_config(); + split_state.axis = tc.axis; + if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) { + const int64_t ne_full = tensor->ne[split_state.axis]; + const int64_t blck_size = ggml_blck_size(tc.tensor_axis_0->type); + const float * tensor_split = ud->model->tensor_split(); + std::vector tensor_split_scan; + tensor_split_scan.reserve(ud->n_devices); + for (size_t j = 0; j < ud->n_devices; j++) { + tensor_split_scan.push_back(tensor_split == nullptr ? 0.0f : tensor_split[(j + tc.rotation) % ud->n_devices]); + if (j > 0) { + tensor_split_scan[j] += tensor_split_scan[j - 1]; + } + } + const std::vector segments = get_split_segments(split_state.axis, tc.il); + const std::vector granularity = get_split_granularity(blck_size, tc.il, segments); + for (size_t is = 0; is < segments.size(); is++) { + const int64_t ne_s = segments[is]; + const int64_t g_s = granularity[is]; + GGML_ASSERT(ne_full % g_s == 0); + int64_t low = 0; + size_t j = 0; + for (; j < ud->n_devices - 1; j++) { + int64_t high = tensor_split_scan.back() == 0.0f ? + ne_s * (j+1)/ud->n_devices : ne_s * tensor_split_scan[j]/tensor_split_scan.back(); + if (high % g_s != 0) { + high -= high % g_s; + } + split_state.ne[is*ud->n_devices + (j + tc.rotation) % ud->n_devices] = high - low; + low = high; + } + split_state.ne[is*ud->n_devices + (j + tc.rotation) % ud->n_devices] = ne_s - low; + } + split_state.n_segments = segments.size(); + } else { + memset(split_state.ne, 0, sizeof(split_state.ne)); + split_state.n_segments = 1; + } + return split_state; + GGML_UNUSED(userdata); +} const char * llm_type_name(llm_type type) { switch (type) { @@ -93,6 +445,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_26B: return "26B"; case LLM_TYPE_27B: return "27B"; case LLM_TYPE_30B: return "30B"; + case LLM_TYPE_31B: return "31B"; case LLM_TYPE_32B: return "32B"; case LLM_TYPE_34B: return "34B"; case LLM_TYPE_35B: return "35B"; @@ -127,6 +480,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_16B_A1B: return "16B.A1B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_24B_A2B: return "24B.A2B"; + case LLM_TYPE_26B_A4B: return "26B.A4B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_31B_A3_5B: return "31B.A3.5B"; case LLM_TYPE_35B_A3B: return "35B.A3B"; @@ -181,7 +535,7 @@ static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::st } // CPU: ACCEL -> GPU host -> CPU extra -> CPU -static buft_list_t make_cpu_buft_list(const std::vector & devices, bool use_extra_bufts, bool no_host) { +static buft_list_t make_cpu_buft_list(const std::vector & devices, bool use_extra_bufts, bool no_host) { buft_list_t buft_list; // add ACCEL buffer types @@ -203,10 +557,10 @@ static buft_list_t make_cpu_buft_list(const std::vector & de // a better approach would be to handle this on a weight-by-weight basis using the offload_op // function of the device to determine if it would benefit from being stored in a host buffer if (!no_host) { - for (auto * dev : devices) { - ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev); + for (const auto & dev : devices) { + ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev.dev); if (buft) { - buft_list.emplace_back(dev, buft); + buft_list.emplace_back(dev.dev, buft); break; } } @@ -273,14 +627,16 @@ static buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, llama_split_mode s // add the device extra buffer type (if any) ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); - auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) - ggml_backend_reg_get_proc_address(reg, "ggml_backend_dev_get_extra_bufts"); - - if (ggml_backend_dev_get_extra_bufts_fn) { - ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(dev); - while (extra_bufts && *extra_bufts) { - buft_list.emplace_back(dev, *extra_bufts); - ++extra_bufts; + if (reg) { + auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) + ggml_backend_reg_get_proc_address(reg, "ggml_backend_dev_get_extra_bufts"); + + if (ggml_backend_dev_get_extra_bufts_fn) { + ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(dev); + while (extra_bufts && *extra_bufts) { + buft_list.emplace_back(dev, *extra_bufts); + ++extra_bufts; + } } } @@ -342,6 +698,9 @@ void llama_model::load_arch(llama_model_loader & ml) { if (arch == LLM_ARCH_UNKNOWN) { throw std::runtime_error("unknown model architecture: '" + ml.get_arch_name() + "'"); } + if (!devices.empty() && devices[0].is_meta && !llm_arch_supports_sm_tensor(arch)) { + throw std::runtime_error(std::string("LLAMA_SPLIT_MODE_TENSOR not implemented for architecture '") + llm_arch_name(arch) + "'"); + } } void llama_model::load_hparams(llama_model_loader & ml) { @@ -370,12 +729,21 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); ml.get_key(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out_impl, false); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer); ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false); ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); ml.get_key(LLM_KV_EXPERT_GROUP_COUNT, hparams.n_expert_groups, false); ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used, false); + if (arch == LLM_ARCH_HUNYUAN_VL || arch == LLM_ARCH_HUNYUAN_DENSE) { + if (hparams.n_expert <= 1) { + hparams.n_expert = 0; + hparams.n_expert_used = 0; + } + } + if (arch == LLM_ARCH_WAVTOKENIZER_DEC) { ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd); ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd_out_impl); @@ -454,6 +822,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale; ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false); + ml.get_key(LLM_KV_ROPE_SCALING_ALPHA, hparams.rope_scaling_alpha, false); // non-transformer models do not have attention heads if (hparams.n_head() > 0) { @@ -748,8 +1117,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_BERT: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); switch (hparams.n_layer) { case 3: @@ -781,8 +1148,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { } ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); switch (hparams.n_layer) { case 12: @@ -797,8 +1162,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_JINA_BERT_V2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); hparams.f_max_alibi_bias = 8.0f; switch (hparams.n_layer) { @@ -810,8 +1173,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_JINA_BERT_V3: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); switch (hparams.n_layer) { case 24: @@ -823,8 +1184,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_NOMIC_BERT_MOE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); ml.get_key(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers, 0); if (hparams.n_layer == 12 && hparams.n_embd == 768) { @@ -838,8 +1197,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_NEO_BERT: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); if (hparams.n_layer == 28) { type = LLM_TYPE_250M; @@ -848,8 +1205,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_EUROBERT: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); if (hparams.n_layer == 12) { type = LLM_TYPE_SMALL; // 0.2B @@ -913,7 +1268,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { // fall through case LLM_ARCH_QWEN2: { - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_5B : LLM_TYPE_1B; break; @@ -940,8 +1294,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { } // Set non-causal attention for diffusion models hparams.causal_attn = false; - } - break; + } break; case LLM_ARCH_LLADA: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -955,8 +1308,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { } // Set non-causal attention for diffusion models hparams.causal_attn = false; - } - break; + } break; case LLM_ARCH_LLADA_MOE: { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); @@ -995,7 +1347,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_QWEN3: { - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { case 28: type = hparams.n_embd == 1024 ? LLM_TYPE_0_6B : LLM_TYPE_1_7B; break; @@ -1275,6 +1626,34 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_GEMMA4: + { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + + uint32_t n_kv_shared_layers = 0; + ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, n_kv_shared_layers, false); + + hparams.n_layer_kv_from_start = hparams.n_layer - (int32_t)n_kv_shared_layers; + hparams.f_attention_scale = 1.0f; // Gemma4 uses self.scaling = 1.0 (no pre-attn scaling) + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EMBEDDING_LENGTH_PER_LAYER, hparams.n_embd_per_layer); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa); + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); + + switch (hparams.n_layer) { + case 30: type = LLM_TYPE_26B_A4B; break; + case 35: type = LLM_TYPE_E2B; break; + case 42: type = LLM_TYPE_E4B; break; + case 60: type = LLM_TYPE_31B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_GEMMA_EMBEDDING: { hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC; @@ -1287,7 +1666,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); //applied only if model converted with --sentence-transformers-dense-modules ml.get_key(LLM_KV_DENSE_2_FEAT_IN, hparams.dense_2_feat_in, false); @@ -1587,6 +1965,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { } } break; case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_MISTRAL4: { // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B, Kanana-2-30B-A3B const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26 || (hparams.n_layer == 48 && n_vocab == 128256)); @@ -1623,7 +2002,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // (optional) temperature tuning - used by mistral-large ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false); - ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.n_attn_temp_floor_scale, false); + ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.n_attn_temp_floor_scale, false); // FIXME why not use temperature_length? hparams.f_attn_temp_offset = 0.0f; @@ -1635,6 +2014,26 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_DEEPSEEK2OCR: + { + // similar to deepseek2, but without MLA + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; + } + + switch (hparams.n_layer) { + case 12: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_PLM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -1672,6 +2071,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // NextN/MTP parameters (GLM-OCR) ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); // TODO: when MTP is implemented, this should probably be updated if needed hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; @@ -1705,6 +2105,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // NextN/MTP parameters ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); // TODO: when MTP is implemented, this should probably be updated if needed hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; @@ -1751,6 +2152,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // NextN/MTP parameters ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); // TODO: when MTP is implemented, this should probably be updated if needed hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; @@ -1925,6 +2327,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); switch (hparams.n_layer) { case 32: type = LLM_TYPE_30B_A3B; break; @@ -2053,7 +2456,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { switch (hparams.n_embd) { case 768: type = LLM_TYPE_350M; break; - case 1536: type = (hparams.n_embd == 2048 ? LLM_TYPE_7B_A1B : LLM_TYPE_1B); break; + case 1536: type = (hparams.n_ff() == 512 ? LLM_TYPE_7B_A1B : LLM_TYPE_1B); break; case 2048: case 2560: type = LLM_TYPE_3B; break; case 4096: type = LLM_TYPE_32B; break; default: type = LLM_TYPE_UNKNOWN; @@ -2079,7 +2482,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_GROUPNORM_EPS, hparams.f_norm_group_eps); ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); } break; case LLM_ARCH_BAILINGMOE: { @@ -2107,6 +2509,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); // TODO: when MTP is implemented, this should probably be updated if needed hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; @@ -2197,9 +2600,18 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_HUNYUAN_VL: case LLM_ARCH_HUNYUAN_DENSE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // XDRoPE / NTK-aware scaling: base = rope_theta * alpha^(dim / (dim - 2)) + if (hparams.rope_scaling_alpha > 0.0f) { + const int dim = hparams.n_embd_head_k(); + hparams.rope_freq_base_train = hparams.rope_freq_base_train + * powf(hparams.rope_scaling_alpha, (float)dim / (float)(dim - 2)); + } switch (hparams.n_embd) { case 1024: type = LLM_TYPE_0_5B; break; @@ -2588,11 +3000,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // build a list of buffer types for the CPU and GPU devices pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts, params.no_host); - for (auto * dev : devices) { - buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split); + for (const auto & dev : devices) { + buft_list_t buft_list = make_gpu_buft_list(dev.dev, split_mode, tensor_split); // add CPU buffer types as a fallback buft_list.insert(buft_list.end(), pimpl->cpu_buft_list.begin(), pimpl->cpu_buft_list.end()); - pimpl->gpu_buft_list.emplace(dev, std::move(buft_list)); + pimpl->gpu_buft_list.emplace(dev.dev, std::move(buft_list)); } ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); @@ -2606,7 +3018,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (all_zero) { // default split, by free memory for (size_t i = 0; i < n_devices(); ++i) { - ggml_backend_dev_t dev = devices[i]; + ggml_backend_dev_t dev = devices[i].dev; size_t total; size_t free; ggml_backend_dev_memory(dev, &free, &total); @@ -2642,7 +3054,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { return {cpu_dev, &pimpl->cpu_buft_list}; } const int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + n_devices(), float(il - i_gpu_start)/act_gpu_layers) - splits.begin(); - auto * dev = devices.at(layer_gpu); + auto * dev = devices.at(layer_gpu).dev; LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(dev), is_swa); return {dev, &pimpl->gpu_buft_list.at(dev)}; }; @@ -2708,6 +3120,25 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", bid), {n_embd_, n_ff_, n_expert_}, flags); } }; + + // helper: try to load merged qkv first, fall back to separate q, k, v + auto create_tensor_qkv = [&](llama_layer & layer, int bid, + int64_t n_embd_, int64_t n_embd_q_, int64_t n_embd_k_, int64_t n_embd_v_, + int flags) { + const int64_t n_embd_qkv = n_embd_q_ + n_embd_k_ + n_embd_v_; + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", bid), {n_embd_, n_embd_qkv}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + if (layer.wqkv) { + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", bid), {n_embd_qkv}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + } else { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", bid), {n_embd_, n_embd_q_}, flags); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", bid), {n_embd_, n_embd_k_}, flags); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", bid), {n_embd_, n_embd_v_}, flags); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", bid), {n_embd_q_}, TENSOR_NOT_REQUIRED); + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", bid), {n_embd_k_}, TENSOR_NOT_REQUIRED); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", bid), {n_embd_v_}, TENSOR_NOT_REQUIRED); + } + }; + switch (arch) { case LLM_ARCH_LLAMA: case LLM_ARCH_REFACT: @@ -2733,16 +3164,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -2805,7 +3231,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // No bias for QKV projections as per config: include_bias=false, include_qkv_bias=false layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); @@ -2841,9 +3267,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); @@ -2882,9 +3306,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -2928,7 +3350,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { auto & layer = layers[i]; const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(i); const int64_t n_ff = hparams.n_ff(i); const int64_t n_head = hparams.n_head(i); const int64_t n_head_kv = hparams.n_head_kv(i); @@ -2941,17 +3362,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) { else if (n_head_kv > 0) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); } // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); if (n_ff > 0) { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -3043,9 +3459,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); @@ -3108,9 +3522,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -3175,10 +3587,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -3211,28 +3623,16 @@ bool llama_model::load_tensors(llama_model_loader & ml) { cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); } - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - if (!layer.wqkv) { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); - - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - } - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); @@ -3259,7 +3659,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_MODERN_BERT: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); @@ -3325,9 +3725,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -3342,31 +3740,24 @@ bool llama_model::load_tensors(llama_model_loader & ml) { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0); // token_type_embeddings - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); // LayerNorm - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); //LayerNorm bias + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); // LayerNorm + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); // LayerNorm bias cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, 1}, TENSOR_NOT_REQUIRED); cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {1}, TENSOR_NOT_REQUIRED); for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; // JinaBertLayer - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); //output_dens - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); //output_dens + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); //output_dens layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); //output_norm layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); @@ -3394,8 +3785,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_BLOOM: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); @@ -3414,10 +3805,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -3450,10 +3841,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); @@ -3490,16 +3881,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - // optional bias tensors, present in Stable LM 2 1.6B - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - // optional q and k layernorms, present in StableLM 2 12B layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED); @@ -3527,7 +3911,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd*3}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd*3}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd*3}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -3557,16 +3941,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); @@ -3587,16 +3964,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); @@ -3645,9 +4015,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); @@ -3678,9 +4046,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); @@ -3721,22 +4087,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - - if (layer.wqkv == nullptr) { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); - - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - } + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); @@ -3763,7 +4117,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, TENSOR_NOT_REQUIRED); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, TENSOR_NOT_REQUIRED); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); @@ -3793,19 +4147,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), { n_embd }, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, TENSOR_NOT_REQUIRED); - if (layer.wqkv == nullptr) { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); - - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - } + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), { n_embd }, 0); @@ -3832,9 +4176,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); @@ -3971,10 +4313,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -4006,11 +4348,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -4036,9 +4377,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -4062,9 +4401,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); // layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -4086,9 +4423,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -4110,9 +4445,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); @@ -4147,9 +4480,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); @@ -4175,13 +4506,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - tok_embd_per_layer = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0); + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); + altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); - altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); - altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); - per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight"), {n_embd, n_embd_altup * n_layer}, 0); - per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight"), {n_embd_altup}, 0); + per_layer_tok_embd = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0); + per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight", 0), {n_embd, n_embd_altup * n_layer}, 0); + per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight", 0), {n_embd_altup}, 0); output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); @@ -4190,9 +4522,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); @@ -4219,6 +4549,101 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.laurel_post_norm = create_tensor(tn(LLM_TENSOR_LAUREL_POST_NORM, "weight", i), {n_embd}, 0); } } break; + case LLM_ARCH_GEMMA4: + { + const uint32_t n_embd_per_layer = hparams.n_embd_per_layer; + const int64_t n_ff_exp = hparams.n_ff_exp; + + if (n_embd_head_k != n_embd_head_v) { + throw std::runtime_error("Gemma 4 requires n_embd_head_k == n_embd_head_v"); + } + if (hparams.n_embd_head_k_swa != hparams.n_embd_head_v_swa) { + throw std::runtime_error("Gemma 4 requires n_embd_head_k_swa == n_embd_head_v_swa"); + } + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + if (n_embd_per_layer > 0) { + per_layer_tok_embd = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_per_layer * n_layer, n_vocab}, 0); + per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight", 0), {n_embd, n_embd_per_layer * n_layer}, 0); + per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight", 0), {n_embd_per_layer}, 0); + } + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + int rope_freqs_flag = 0; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + const int64_t n_head = hparams.n_head(i); + const int64_t n_embd_head = hparams.n_embd_head_k(i); + const int64_t n_embd_k = hparams.n_embd_k_gqa(i); + const int64_t n_embd_v = hparams.n_embd_v_gqa(i); + const int kv_flags = hparams.has_kv(i) ? 0 : TENSOR_NOT_REQUIRED; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + // note: use_alternative_attention (v_proj is optional, if it's not present, use k_proj) + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k}, kv_flags); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v}, TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head * n_head, n_embd}, 0); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head}, kv_flags); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), {1u}, TENSOR_NOT_REQUIRED); + + if (!hparams.is_swa(i)) { + // full_attention layers use rope_freqs for proportional rope + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_embd_head/2}, rope_freqs_flag); + rope_freqs_flag = TENSOR_DUPLICATED; + } + + // handle use_double_wide_mlp + int64_t n_ff_cur = hparams.n_ff(i); + + // for expert layers, we use normal FFN as shared expert (same as python code) + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff_cur}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff_cur}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_cur, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + + // MoE router + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); + bool has_expert = layer.ffn_gate_inp != nullptr; + + // norm + if (has_expert) { + layer.ffn_gate_inp_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "scale", i), {n_embd}, 0); + + layer.ffn_pre_norm_2 = create_tensor(tn(LLM_TENSOR_FFN_PRE_NORM_2, "weight", i), {n_embd}, 0); + layer.ffn_post_norm_1 = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM_1, "weight", i), {n_embd}, 0); + layer.ffn_post_norm_2 = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM_2, "weight", i), {n_embd}, 0); + + // MoE FFN + layer.ffn_gate_up_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_UP_EXPS, "weight", i), {n_embd, n_ff_exp * 2, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + + // per-expert scale will be loaded as down_exps_s at the end of the current switch case + } + + // per-layer embeddings + if (n_embd_per_layer > 0) { + layer.per_layer_inp_gate = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_per_layer}, 0); + layer.per_layer_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_per_layer, n_embd}, 0); + layer.per_layer_post_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0); + } + } + } break; case LLM_ARCH_STARCODER2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4239,16 +4664,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -4414,9 +4834,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } else { // Attention layers - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); } @@ -4492,14 +4910,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_head_i = hparams.n_head(i); const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i); const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_i}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa_i}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa_i}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_i, n_embd_k_gqa_i, n_embd_v_gqa_i, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); } // feed forward (w/ optional biases) @@ -4542,9 +4955,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -4572,9 +4983,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0); } - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); @@ -4597,9 +5006,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); @@ -4622,9 +5029,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); @@ -4645,9 +5050,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_head_kv * n_embd_head}, 0); @@ -4678,14 +5081,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_qo_dim}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_kv_dim}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_kv_dim}, 0); + create_tensor_qkv(layer, i, n_embd, n_qo_dim, n_kv_dim, n_kv_dim, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_qo_dim}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_kv_dim}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_kv_dim}, TENSOR_NOT_REQUIRED); layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); @@ -4709,9 +5107,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0); @@ -4778,10 +5174,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -4811,9 +5207,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -4850,9 +5244,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -4883,6 +5275,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_MISTRAL4: { const bool is_mla = hparams.is_mla(); @@ -4960,6 +5353,60 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + } + } break; + case LLM_ARCH_DEEPSEEK2OCR: + { + // similar to deepseek2, but without MLA + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + // norm + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + // Shared expert branch layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); @@ -5145,10 +5592,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -5187,10 +5634,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); // attention biases - all have shape n_embd (output dimension of projections) - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd}, 0); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -5218,17 +5665,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { auto & layer = layers[i]; layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - - if (layer.wqkv == nullptr) { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - } + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); @@ -5261,17 +5698,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { auto & layer = layers[i]; layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, flags | TENSOR_NOT_REQUIRED); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, flags | TENSOR_NOT_REQUIRED); - - if (layer.wqkv == nullptr) { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, flags); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, flags); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, flags); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, flags | TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, flags | TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, flags | TENSOR_NOT_REQUIRED); - } + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, flags); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, flags); @@ -5329,12 +5756,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, flags); // GLM-style attention with bias terms - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, flags); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, flags); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, flags); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, TENSOR_NOT_REQUIRED | flags); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, TENSOR_NOT_REQUIRED | flags); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, TENSOR_NOT_REQUIRED | flags); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, flags); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, flags); @@ -5514,16 +5936,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -5590,14 +6007,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_head_i = hparams.n_head(i); const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i); const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_i}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa_i}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa_i}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_i, n_embd_k_gqa_i, n_embd_v_gqa_i, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); } else { if (n_expert != 0) { const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; @@ -5645,9 +6057,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -5673,9 +6083,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); @@ -5718,9 +6126,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } auto & layer = layers[i]; - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_qo_dim}, flags); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_kv_dim}, flags); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_kv_dim}, flags); + create_tensor_qkv(layer, i, n_embd, n_qo_dim, n_kv_dim, n_kv_dim, flags); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, flags); layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0) | flags); @@ -5773,8 +6179,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // Block 0, LN0 - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); @@ -5888,8 +6294,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // Block 0, LN0 - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); @@ -6044,9 +6450,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED); layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -6060,8 +6464,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd, n_vocab}, 0); - conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd, hparams.posnet.n_embd}, 0); - conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias"), {1, hparams.posnet.n_embd}, 0); + conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight", 0), {7, hparams.n_embd, hparams.posnet.n_embd}, 0); + conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias", 0), {1, hparams.posnet.n_embd}, 0); // posnet { @@ -6126,8 +6530,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { GGML_ASSERT(hparams.posnet.n_embd == hparams.convnext.n_embd); - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {hparams.posnet.n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {hparams.posnet.n_embd}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {hparams.posnet.n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {hparams.posnet.n_embd}, 0); // convnext { @@ -6175,9 +6579,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_rot}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_head_kv * n_rot}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_head_kv * n_rot}, 0); + create_tensor_qkv(layer, i, n_embd, n_head * n_rot, n_head_kv * n_rot, n_head_kv * n_rot, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -6278,9 +6680,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_head_k * n_head, n_embd_head_k * n_head, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); @@ -6333,9 +6733,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -6370,9 +6768,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); // attention projections - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); // Q/K normalization @@ -6430,16 +6826,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -6519,14 +6910,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { /*ATTENTION LAYERS*/ // attention layers (with optional bias) - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {hidden_size, n_embd_head_k * attn_num_attention_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {hidden_size, attn_num_key_value_head * n_embd_head_k}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {hidden_size, attn_num_key_value_head * n_embd_head_v}, 0); + create_tensor_qkv(layer, i, hidden_size, n_embd_head_k * attn_num_attention_head, attn_num_key_value_head * n_embd_head_k, attn_num_key_value_head * n_embd_head_v, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * attn_num_attention_head, hidden_size}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {attn_num_key_value_head * n_embd_head_k}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {attn_num_key_value_head * n_embd_head_v}, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED); layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {hidden_size}, 0); @@ -6560,9 +6946,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); @@ -6580,6 +6964,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); } } break; + case LLM_ARCH_HUNYUAN_VL: case LLM_ARCH_HUNYUAN_DENSE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -6597,9 +6982,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); @@ -6631,9 +7014,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -6658,9 +7039,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_rot}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_head_kv * n_rot}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_head_kv * n_rot}, 0); + create_tensor_qkv(layer, i, n_embd, n_head * n_rot, n_head_kv * n_rot, n_head_kv * n_rot, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0); layer.attn_sinks = create_tensor(tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, 0); @@ -6670,11 +7049,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - // bias - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_head * n_rot}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_head_kv * n_rot}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_head_kv * n_rot}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_gate_inp_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "bias", i), {n_expert}, 0); layer.ffn_gate_exps_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i), {n_ff_exp, n_expert}, 0); @@ -6722,9 +7097,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); GGML_ASSERT(n_embd_v_gqa == n_embd_k_gqa); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, hparams.n_embd_k_gqa(i)}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, hparams.n_embd_v_gqa(i)}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, hparams.n_embd_k_gqa(i), hparams.n_embd_v_gqa(i), 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); } else { @@ -6756,9 +7129,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); @@ -6795,9 +7166,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); @@ -6841,16 +7210,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); } - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_gqa }, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_gqa }, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); @@ -6874,9 +7238,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); @@ -6933,9 +7295,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // q, k, v projections // Python: q_proj, k_proj, v_proj - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k_kda * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_head_k_kda * n_head}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_head_v_kda * n_head}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k_kda * n_head, n_embd_head_k_kda * n_head, n_embd_head_v_kda * n_head, 0); // KDA specific projections // f_a_proj, f_b_proj @@ -7081,16 +7441,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); // weight tensors - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); // bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd_head_k * n_head}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -7147,9 +7502,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (!hparams.is_recurrent(i)) { // Attention layers - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); // Q/K normalization for attention layers @@ -7213,9 +7566,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (!hparams.is_recurrent(i)) { // Attention layers - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); // Q/K normalization for attention layers @@ -7278,9 +7629,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (!hparams.is_recurrent(i)) { // Attention layers - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); // Q/K normalization for attention layers @@ -7319,9 +7668,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); uint32_t n_head = hparams.n_head(i); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_v * n_head, n_embd }, 0); layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); @@ -7380,9 +7727,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); } - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_l}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, 0); // head-wise attention gate (Step35 self_attn.g_proj) @@ -7426,9 +7771,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); @@ -7501,6 +7844,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // recurrent / linear-attention weight scales (per-tensor, shape {1}) + if (!layer.ssm_in_s && layer.ssm_in) { + layer.ssm_in_s = create_tensor(tn(LLM_TENSOR_SSM_IN, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } if (!layer.ssm_out_s && layer.ssm_out) { layer.ssm_out_s = create_tensor(tn(LLM_TENSOR_SSM_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); } @@ -7510,11 +7856,77 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (!layer.ssm_beta_s && layer.ssm_beta) { layer.ssm_beta_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "scale", i), {1}, TENSOR_NOT_REQUIRED); } + + // input scales + if (!layer.wq_in_s && layer.wq) { + layer.wq_in_s = create_tensor(tn(LLM_TENSOR_ATTN_Q, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wk_in_s && layer.wk) { + layer.wk_in_s = create_tensor(tn(LLM_TENSOR_ATTN_K, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wv_in_s && layer.wv) { + layer.wv_in_s = create_tensor(tn(LLM_TENSOR_ATTN_V, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wo_in_s && layer.wo) { + layer.wo_in_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wqkv_in_s && layer.wqkv) { + layer.wqkv_in_s = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wqkv_gate_in_s && layer.wqkv_gate) { + layer.wqkv_gate_in_s = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_gate_in_s && layer.ffn_gate) { + layer.ffn_gate_in_s = create_tensor(tn(LLM_TENSOR_FFN_GATE, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_in_s && layer.ffn_down) { + layer.ffn_down_in_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_in_s && layer.ffn_up) { + layer.ffn_up_in_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_gate_exps_in_s && layer.ffn_gate_exps) { + layer.ffn_gate_exps_in_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "input_scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_exps_in_s && layer.ffn_down_exps) { + layer.ffn_down_exps_in_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "input_scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_exps_in_s && layer.ffn_up_exps) { + layer.ffn_up_exps_in_s = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "input_scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_gate_shexp_in_s && layer.ffn_gate_shexp) { + layer.ffn_gate_shexp_in_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_shexp_in_s && layer.ffn_down_shexp) { + layer.ffn_down_shexp_in_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_shexp_in_s && layer.ffn_up_shexp) { + layer.ffn_up_shexp_in_s = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_in_in_s && layer.ssm_in) { + layer.ssm_in_in_s = create_tensor(tn(LLM_TENSOR_SSM_IN, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_out_in_s && layer.ssm_out) { + layer.ssm_out_in_s = create_tensor(tn(LLM_TENSOR_SSM_OUT, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_alpha_in_s && layer.ssm_alpha) { + layer.ssm_alpha_in_s = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_beta_in_s && layer.ssm_beta) { + layer.ssm_beta_in_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } } } ml.done_getting_tensors(); + // populate tensors_by_name + for (auto & [_, ctx_ptr] : ml.ctx_map) { + for (auto * cur = ggml_get_first_tensor(ctx_ptr.get()); cur != NULL; cur = ggml_get_next_tensor(ctx_ptr.get(), cur)) { + tensors_by_name.emplace_back(ggml_get_name(cur), cur); + } + } + ml.init_mappings(true, use_mlock ? &pimpl->mlock_mmaps : nullptr); pimpl->mappings.reserve(ml.mappings.size()); @@ -7597,14 +8009,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) { buf_map.emplace(idx, buf); } } - pimpl->ctxs_bufs.emplace_back(std::move(ctx_ptr), std::move(bufs)); - for (auto & buf : buf_map) { + for (auto & buf : bufs) { // indicate that this buffer contains weights // this is used by ggml_backend_sched to improve op scheduling: ops that use a weight are preferably scheduled to the backend that contains the weight - ggml_backend_buffer_set_usage(buf.second, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + ggml_backend_buffer_set_usage(buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS); } + pimpl->ctxs_bufs.emplace_back(std::move(ctx_ptr), std::move(bufs)); + ctx_buf_maps.emplace_back(ctx, buf_map); } @@ -7632,13 +8045,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } - // populate tensors_by_name - for (auto & [ctx, _] : pimpl->ctxs_bufs) { - for (auto * cur = ggml_get_first_tensor(ctx.get()); cur != NULL; cur = ggml_get_next_tensor(ctx.get(), cur)) { - tensors_by_name.emplace_back(ggml_get_name(cur), cur); - } - } - if (ml.no_alloc) { return true; } @@ -7683,6 +8089,10 @@ size_t llama_model::n_devices() const { return devices.size(); } +const float * llama_model::tensor_split() const { + return params.tensor_split; +} + uint32_t llama_model::n_gpu_layers() const { return params.n_gpu_layers >= 0 ? params.n_gpu_layers : hparams.n_layer + 1; } @@ -7801,114 +8211,114 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out); size_t i = 0; - for (auto label : classifier_labels) { + for (const auto & label : classifier_labels) { LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str()); } } - } - if (arch == LLM_ARCH_MAMBA || - arch == LLM_ARCH_MAMBA2 || - arch == LLM_ARCH_JAMBA || - arch == LLM_ARCH_FALCON_H1 || - arch == LLM_ARCH_PLAMO2 || - arch == LLM_ARCH_GRANITE_HYBRID || - arch == LLM_ARCH_QWEN3NEXT || - arch == LLM_ARCH_QWEN35 || - arch == LLM_ARCH_QWEN35MOE || - arch == LLM_ARCH_NEMOTRON_H || - arch == LLM_ARCH_NEMOTRON_H_MOE) { - LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); - LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); - LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); - LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); - LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); - LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); - } + if (arch == LLM_ARCH_MAMBA || + arch == LLM_ARCH_MAMBA2 || + arch == LLM_ARCH_JAMBA || + arch == LLM_ARCH_FALCON_H1 || + arch == LLM_ARCH_PLAMO2 || + arch == LLM_ARCH_GRANITE_HYBRID || + arch == LLM_ARCH_QWEN3NEXT || + arch == LLM_ARCH_QWEN35 || + arch == LLM_ARCH_QWEN35MOE || + arch == LLM_ARCH_NEMOTRON_H || + arch == LLM_ARCH_NEMOTRON_H_MOE) { + LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); + LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); + LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); + LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); + LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); + LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); + } - LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); - if (pimpl->n_elements >= 1e12) { - LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); - } else if (pimpl->n_elements >= 1e9) { - LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, pimpl->n_elements*1e-9); - } else if (pimpl->n_elements >= 1e6) { - LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, pimpl->n_elements*1e-6); - } else { - LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, pimpl->n_elements*1e-3); - } + LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); + if (pimpl->n_elements >= 1e12) { + LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); + } else if (pimpl->n_elements >= 1e9) { + LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, pimpl->n_elements*1e-9); + } else if (pimpl->n_elements >= 1e6) { + LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, pimpl->n_elements*1e-6); + } else { + LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, pimpl->n_elements*1e-3); + } - // general kv - LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, name.c_str()); + // general kv + LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, name.c_str()); - if (arch == LLM_ARCH_DEEPSEEK) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - } + if (arch == LLM_ARCH_DEEPSEEK) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + } - if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_GLM_DSA) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); - LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); - LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla()); - LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla()); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - } + if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_DEEPSEEK2OCR || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_MISTRAL4) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); + LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); + LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla()); + LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla()); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + } - if (arch == LLM_ARCH_QWEN2MOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); - } + if (arch == LLM_ARCH_QWEN2MOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + } - if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - } + if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + } - if (arch == LLM_ARCH_MINICPM || - arch == LLM_ARCH_GRANITE || - arch == LLM_ARCH_GRANITE_MOE || - arch == LLM_ARCH_GRANITE_HYBRID || - arch == LLM_ARCH_NEMOTRON_H_MOE) { - LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); - LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); - LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); - } + if (arch == LLM_ARCH_MINICPM || + arch == LLM_ARCH_GRANITE || + arch == LLM_ARCH_GRANITE_MOE || + arch == LLM_ARCH_GRANITE_HYBRID || + arch == LLM_ARCH_NEMOTRON_H_MOE) { + LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); + LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); + LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + } - if (arch == LLM_ARCH_BAILINGMOE) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - } + if (arch == LLM_ARCH_BAILINGMOE) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + } - if (arch == LLM_ARCH_BAILINGMOE2) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers); - } + if (arch == LLM_ARCH_BAILINGMOE2) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers); + } - if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - } + if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + } - if (arch == LLM_ARCH_GROVEMOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_chexp = %d\n", __func__, hparams.n_ff_chexp); - LLAMA_LOG_INFO("%s: n_group_experts = %d\n", __func__, hparams.n_group_experts); - LLAMA_LOG_INFO("%s: expert_group_scale = %.2f\n", __func__, hparams.expert_group_scale); + if (arch == LLM_ARCH_GROVEMOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_chexp = %d\n", __func__, hparams.n_ff_chexp); + LLAMA_LOG_INFO("%s: n_group_experts = %d\n", __func__, hparams.n_group_experts); + LLAMA_LOG_INFO("%s: expert_group_scale = %.2f\n", __func__, hparams.expert_group_scale); + } } vocab.print_info(); @@ -8105,7 +8515,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, } else { llama_memory_i::layer_reuse_cb reuse = nullptr; - if (arch == LLM_ARCH_GEMMA3N) { + if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) { reuse = [&](int32_t il) { if (il >= (int32_t) hparams.n_layer_kv_from_start) { return (int32_t) hparams.n_layer_kv_from_start - (hparams.is_swa(il) ? 2 : 1); @@ -8168,9 +8578,9 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { case LLM_ARCH_LLAMA4: { if (hparams.swa_type == LLAMA_SWA_TYPE_NONE) { - llm = std::make_unique>(*this, params); + llm = std::make_unique>(*this, params); } else { - llm = std::make_unique(*this, params); + llm = std::make_unique>(*this, params); } } break; case LLM_ARCH_LLAMA_EMBED: @@ -8248,23 +8658,19 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { case LLM_ARCH_DREAM: { llm = std::make_unique(*this, params); - } - break; + } break; case LLM_ARCH_LLADA: { llm = std::make_unique(*this, params); - } - break; + } break; case LLM_ARCH_LLADA_MOE: { llm = std::make_unique(*this, params); - } - break; + } break; case LLM_ARCH_RND1: { llm = std::make_unique(*this, params); - } - break; + } break; case LLM_ARCH_QWEN2VL: { llm = std::make_unique(*this, params); @@ -8358,6 +8764,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_GEMMA4: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_GEMMA_EMBEDDING: { llm = std::make_unique(*this, params); @@ -8424,7 +8834,9 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique(*this, params); } break; case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_DEEPSEEK2OCR: case LLM_ARCH_GLM_DSA: + case LLM_ARCH_MISTRAL4: { llm = std::make_unique(*this, params); } break; @@ -8448,11 +8860,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { switch (params.gtype) { case LLM_GRAPH_TYPE_ENCODER: - llm = std::make_unique(*this, params); + llm = std::make_unique>(*this, params); break; case LLM_GRAPH_TYPE_DEFAULT: case LLM_GRAPH_TYPE_DECODER: - llm = std::make_unique(*this, params); + llm = std::make_unique>(*this, params); break; default: GGML_ABORT("invalid graph type"); @@ -8460,9 +8872,8 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } break; case LLM_ARCH_T5ENCODER: { - llm = std::make_unique(*this, params); - } - break; + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_JAIS: { llm = std::make_unique(*this, params); @@ -8574,6 +8985,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_HUNYUAN_VL: case LLM_ARCH_HUNYUAN_DENSE: { llm = std::make_unique(*this, params); @@ -8823,6 +9235,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ARCTIC: case LLM_ARCH_DEEPSEEK: case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_DEEPSEEK2OCR: case LLM_ARCH_PLM: case LLM_ARCH_CHATGLM: case LLM_ARCH_GRANITE: @@ -8836,6 +9249,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: case LLM_ARCH_MISTRAL3: + case LLM_ARCH_MISTRAL4: case LLM_ARCH_LLAMA_EMBED: case LLM_ARCH_MAINCODER: case LLM_ARCH_GLM_DSA: @@ -8874,6 +9288,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GEMMA2: case LLM_ARCH_GEMMA3: case LLM_ARCH_GEMMA3N: + case LLM_ARCH_GEMMA4: case LLM_ARCH_GEMMA_EMBEDDING: case LLM_ARCH_STARCODER2: case LLM_ARCH_OPENELM: @@ -8920,6 +9335,9 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GLM4_MOE: return model->hparams.use_mrope() ? LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NEOX; + case LLM_ARCH_HUNYUAN_VL: + return model->hparams.use_mrope() ? LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NEOX; + // all model arches should be listed explicitly here case LLM_ARCH_UNKNOWN: GGML_ABORT("unknown architecture"); @@ -9054,3 +9472,18 @@ bool llama_model_is_diffusion(const llama_model * model) { const std::vector> & llama_internal_get_tensor_map(const llama_model * model) { return model->tensors_by_name; } + +int32_t llama_model_n_expert(const struct llama_model * model) { + return model->hparams.n_expert; +} + +int32_t llama_model_n_devices(const struct llama_model * model) { + return (int32_t)model->devices.size(); +} + +ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i) { + if (i < 0 || i >= (int)model->devices.size()) { + return nullptr; + } + return model->devices[i].dev; +} diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index 25bf892e7e2..5f101bd6374 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -84,6 +84,7 @@ enum llm_type { LLM_TYPE_26B, LLM_TYPE_27B, LLM_TYPE_30B, + LLM_TYPE_31B, LLM_TYPE_32B, LLM_TYPE_34B, LLM_TYPE_35B, @@ -118,6 +119,7 @@ enum llm_type { LLM_TYPE_16B_A1B, LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_24B_A2B, // lfm2moe + LLM_TYPE_26B_A4B, // Gemma4 LLM_TYPE_30B_A3B, LLM_TYPE_31B_A3_5B, LLM_TYPE_35B_A3B, // Qwen3.5 @@ -244,6 +246,8 @@ struct llama_layer { struct ggml_tensor * wkv_b = nullptr; struct ggml_tensor * wk_b = nullptr; struct ggml_tensor * wv_b = nullptr; + struct ggml_tensor * wqkv_b = nullptr; + struct ggml_tensor * wo_b = nullptr; struct ggml_tensor * wq_cross = nullptr; struct ggml_tensor * wk_cross = nullptr; struct ggml_tensor * wv_cross = nullptr; @@ -254,13 +258,6 @@ struct llama_layer { struct ggml_tensor * wo_enc = nullptr; struct ggml_tensor * wqkv_gate = nullptr; - // attention bias - struct ggml_tensor * bq = nullptr; - struct ggml_tensor * bk = nullptr; - struct ggml_tensor * bv = nullptr; - struct ggml_tensor * bo = nullptr; - struct ggml_tensor * bqkv = nullptr; - // relative position bias struct ggml_tensor * attn_rel_b = nullptr; struct ggml_tensor * attn_rel_b_enc = nullptr; @@ -270,6 +267,9 @@ struct llama_layer { struct ggml_tensor * ffn_norm = nullptr; struct ggml_tensor * ffn_norm_b = nullptr; struct ggml_tensor * ffn_post_norm = nullptr; + struct ggml_tensor * ffn_post_norm_1 = nullptr; // gemma4 + struct ggml_tensor * ffn_post_norm_2 = nullptr; // gemma4 + struct ggml_tensor * ffn_pre_norm_2 = nullptr; // gemma4 struct ggml_tensor * layer_out_norm = nullptr; struct ggml_tensor * layer_out_norm_b = nullptr; struct ggml_tensor * ffn_norm_exps = nullptr; @@ -285,6 +285,7 @@ struct llama_layer { // ff MoE struct ggml_tensor * ffn_gate_inp = nullptr; + struct ggml_tensor * ffn_gate_inp_s = nullptr; // gemma4 struct ggml_tensor * ffn_gate_exps = nullptr; struct ggml_tensor * ffn_down_exps = nullptr; struct ggml_tensor * ffn_up_exps = nullptr; @@ -409,10 +410,32 @@ struct llama_layer { struct ggml_tensor * ffn_gate_shexp_s = nullptr; struct ggml_tensor * ffn_up_shexp_s = nullptr; struct ggml_tensor * ffn_down_shexp_s = nullptr; - struct ggml_tensor * ssm_out_s = nullptr; + struct ggml_tensor * ssm_in_s = nullptr; + struct ggml_tensor * ssm_out_s = nullptr; struct ggml_tensor * ssm_alpha_s = nullptr; struct ggml_tensor * ssm_beta_s = nullptr; + // input scales + struct ggml_tensor * wq_in_s = nullptr; + struct ggml_tensor * wk_in_s = nullptr; + struct ggml_tensor * wv_in_s = nullptr; + struct ggml_tensor * wo_in_s = nullptr; + struct ggml_tensor * wqkv_in_s = nullptr; + struct ggml_tensor * wqkv_gate_in_s = nullptr; + struct ggml_tensor * ffn_gate_in_s = nullptr; + struct ggml_tensor * ffn_up_in_s = nullptr; + struct ggml_tensor * ffn_down_in_s = nullptr; + struct ggml_tensor * ffn_gate_exps_in_s = nullptr; + struct ggml_tensor * ffn_down_exps_in_s = nullptr; + struct ggml_tensor * ffn_up_exps_in_s = nullptr; + struct ggml_tensor * ffn_gate_shexp_in_s= nullptr; + struct ggml_tensor * ffn_up_shexp_in_s = nullptr; + struct ggml_tensor * ffn_down_shexp_in_s= nullptr; + struct ggml_tensor * ssm_in_in_s = nullptr; + struct ggml_tensor * ssm_out_in_s = nullptr; + struct ggml_tensor * ssm_alpha_in_s = nullptr; + struct ggml_tensor * ssm_beta_in_s = nullptr; + // altup & laurel struct ggml_tensor * per_layer_inp_gate = nullptr; struct ggml_tensor * per_layer_proj = nullptr; @@ -461,6 +484,9 @@ struct llama_layer { struct ggml_tensor * indexer_attn_k = nullptr; struct ggml_tensor * indexer_attn_q_b = nullptr; // note: for lora a/b, not bias + // gemma4 layer output scale + struct ggml_tensor * out_scale = nullptr; + struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; @@ -470,6 +496,19 @@ struct llama_layer { struct llama_layer_nextn nextn; }; +struct llama_device { + bool is_meta; + + ggml_backend_dev_t dev; +}; + +struct llama_meta_device_get_split_state_userdata { + size_t n_devices; + const struct llama_model * model; +}; + +struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const struct ggml_tensor * tensor, void * userdata); + struct llama_model { llm_type type = LLM_TYPE_UNKNOWN; llm_arch arch = LLM_ARCH_UNKNOWN; @@ -505,9 +544,9 @@ struct llama_model { struct ggml_tensor * conv1d_b = nullptr; // gemma3n altup - struct ggml_tensor * tok_embd_per_layer = nullptr; struct ggml_tensor * altup_proj = nullptr; struct ggml_tensor * altup_unembd_proj = nullptr; + struct ggml_tensor * per_layer_tok_embd = nullptr; struct ggml_tensor * per_layer_model_proj = nullptr; struct ggml_tensor * per_layer_proj_norm = nullptr; @@ -524,7 +563,7 @@ struct llama_model { std::unordered_map gguf_kv; // list of devices used in this model - std::vector devices; + std::vector devices; // for quantize-stats only std::vector> tensors_by_name; @@ -532,6 +571,9 @@ struct llama_model { // for keeping track of associated LoRA adapters std::unordered_set loras; + // statically allocated context for assigning + struct llama_meta_device_get_split_state_userdata get_split_state_ud; + int64_t t_load_us = 0; int64_t t_start_us = 0; @@ -552,6 +594,7 @@ struct llama_model { size_t size() const; // file size size_t n_tensors() const; size_t n_devices() const; + const float * tensor_split() const; uint32_t n_gpu_layers() const; llama_split_mode split_mode() const; diff --git a/examples/talk-llama/llama-quant.cpp b/examples/talk-llama/llama-quant.cpp index 8e8ce231249..2f0f70b73b6 100644 --- a/examples/talk-llama/llama-quant.cpp +++ b/examples/talk-llama/llama-quant.cpp @@ -1,11 +1,11 @@ -#include "llama.h" #include "llama-impl.h" #include "llama-model.h" #include "llama-model-loader.h" +#include "llama-ext.h" +#include #include #include -#include #include #include #include @@ -84,7 +84,6 @@ static std::string remap_imatrix(const std::string & orig_name, const std::maptensor_types) { - const auto & tensor_types = *static_cast *>(params->tensor_types); - for (const auto & [tname, qtype] : tensor_types) { - tensor_type_patterns.emplace_back(std::regex(tname), qtype); + if (params->tt_overrides) { + for (const auto * p = params->tt_overrides; p->pattern != nullptr; p++) { + tensor_type_patterns.emplace_back(std::regex(p->pattern), p->type); } } } @@ -199,6 +197,7 @@ struct quantize_state_impl { // per-tensor metadata, computed in the preliminary loop and used in the main loop struct tensor_metadata { + std::string name; ggml_type target_type; tensor_category category; std::string remapped_imatrix_name; @@ -344,7 +343,13 @@ static bool tensor_allows_quantization(const llama_model_quantize_params * param quantize &= name.find("attn_rel_b.weight") == std::string::npos; // do not quantize specific multimodal tensors - quantize &= name.find(".position_embd.") == std::string::npos; + quantize &= name.find(".position_embd") == std::string::npos; + quantize &= name.find("sam.pos_embd") == std::string::npos; + quantize &= name.find("sam.neck.") == std::string::npos; + quantize &= name.find("sam.net_") == std::string::npos; + quantize &= name.find(".rel_pos") == std::string::npos; + quantize &= name.find(".patch_embd") == std::string::npos; + quantize &= name.find(".patch_merger") == std::string::npos; return quantize; } @@ -678,9 +683,9 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, const llama_mod LLAMA_LOG_WARN("%s: %-36s - applying manual override: %s -> %s\n", __func__, tensor_name.c_str(), ggml_type_name(new_type), ggml_type_name(qtype)); new_type = qtype; - manual = true; - break; } + manual = true; + break; } } } @@ -784,7 +789,7 @@ static bool tensor_requires_imatrix(const char * tensor_name, const ggml_type ds // given a file type, get the default tensor type // -static ggml_type llama_ftype_get_default_type(llama_ftype ftype) { +ggml_type llama_ftype_get_default_type(llama_ftype ftype) { switch (ftype) { case LLAMA_FTYPE_MOSTLY_Q4_0: return GGML_TYPE_Q4_0; case LLAMA_FTYPE_MOSTLY_Q4_1: return GGML_TYPE_Q4_1; @@ -794,6 +799,7 @@ static ggml_type llama_ftype_get_default_type(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_F16: return GGML_TYPE_F16; case LLAMA_FTYPE_MOSTLY_BF16: return GGML_TYPE_BF16; case LLAMA_FTYPE_ALL_F32: return GGML_TYPE_F32; + case LLAMA_FTYPE_MOSTLY_Q1_0: return GGML_TYPE_Q1_0; case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return GGML_TYPE_MXFP4; @@ -823,16 +829,32 @@ static ggml_type llama_ftype_get_default_type(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_IQ3_S: case LLAMA_FTYPE_MOSTLY_IQ3_M: return GGML_TYPE_IQ3_S; - default: throw std::runtime_error(format("invalid output file type %d\n", ftype)); + default: return GGML_TYPE_COUNT; } } + +static void init_quantize_state_counters(quantize_state_impl & qs, std::vector & metadata) { + for (auto & tm : metadata) { + tensor_category cat = tensor_get_category(tm.name); + tm.category = cat; + + if (category_is_attn_v(cat)) { + ++qs.n_attention_wv; + } + + if (cat == tensor_category::OUTPUT) { + qs.has_tied_embeddings = false; + } + } + qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)qs.model.hparams.n_layer; +} + // // main quantization driver // static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) { - ggml_type default_type; llama_ftype ftype = params->ftype; int nthread = params->nthread; @@ -841,7 +863,10 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: nthread = std::thread::hardware_concurrency(); } - default_type = llama_ftype_get_default_type(ftype); + ggml_type default_type = llama_ftype_get_default_type(ftype); + if (default_type == GGML_TYPE_COUNT) { + throw std::runtime_error(format("invalid output file type %d\n", ftype)); + } // mmap consistently increases speed on Linux, and also increases speed on Windows with // hot cache. It may cause a slowdown on macOS, possibly related to free memory. @@ -851,15 +876,10 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: constexpr bool use_mmap = false; #endif - llama_model_kv_override * kv_overrides = nullptr; - if (params->kv_overrides) { - auto * v = (std::vector*)params->kv_overrides; - kv_overrides = v->data(); - } - + const llama_model_kv_override * kv_overrides = params->kv_overrides; std::vector splits = {}; llama_model_loader ml(/*metadata*/ nullptr, /*set_tensor_data*/ nullptr, /*set_tensor_data_ud*/ nullptr, - fname_inp, splits, use_mmap, /*use_direct_io*/ false, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr); + fname_inp, splits, /*file*/ nullptr, use_mmap, /*use_direct_io*/ false, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr); ml.init_mappings(false); // no prefetching llama_model model(llama_model_default_params()); @@ -873,9 +893,13 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: if (params->only_copy) { ftype = ml.ftype; } + std::unordered_map> i_data; const std::unordered_map> * imatrix_data = nullptr; if (params->imatrix) { - imatrix_data = static_cast>*>(params->imatrix); + for (const llama_model_imatrix_data * p = params->imatrix; p->name != nullptr; p++) { + i_data.emplace(p->name, std::vector(p->data, p->data + p->size)); + } + imatrix_data = & i_data; if (imatrix_data) { LLAMA_LOG_INFO("\n%s: have importance matrix data with %d entries\n", __func__, (int)imatrix_data->size()); @@ -896,7 +920,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: std::vector prune_list = {}; if (params->prune_layers) { - prune_list = *static_cast *>(params->prune_layers); + for (const int32_t * p = params->prune_layers; * p != -1; p++) { + prune_list.push_back(* p); + } } // copy the KV pairs from the input file @@ -910,20 +936,18 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str()); if (params->kv_overrides) { - const std::vector & overrides = *(const std::vector *)params->kv_overrides; - for (const auto & o : overrides) { - if (o.key[0] == 0) break; - if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) { - gguf_set_val_f32(ctx_out.get(), o.key, o.val_f64); - } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) { + for (const llama_model_kv_override * o = params->kv_overrides; o->key[0] != 0; ++o) { + if (o->tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) { + gguf_set_val_f32(ctx_out.get(), o->key, o->val_f64); + } else if (o->tag == LLAMA_KV_OVERRIDE_TYPE_INT) { // Setting type to UINT32. See https://github.com/ggml-org/llama.cpp/pull/14182 for context - gguf_set_val_u32(ctx_out.get(), o.key, (uint32_t)std::abs(o.val_i64)); - } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) { - gguf_set_val_bool(ctx_out.get(), o.key, o.val_bool); - } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) { - gguf_set_val_str(ctx_out.get(), o.key, o.val_str); + gguf_set_val_u32(ctx_out.get(), o->key, (uint32_t)std::abs(o->val_i64)); + } else if (o->tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) { + gguf_set_val_bool(ctx_out.get(), o->key, o->val_bool); + } else if (o->tag == LLAMA_KV_OVERRIDE_TYPE_STR) { + gguf_set_val_str(ctx_out.get(), o->key, o->val_str); } else { - LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o.key); + LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o->key); } } } @@ -961,6 +985,15 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: }); } + // compute tensor metadata once and cache it + std::vector metadata(tensors.size()); + for (size_t i = 0; i < tensors.size(); ++i) { + metadata[i].name = ggml_get_name(tensors[i]->tensor); + } + + // initialize quantization state counters and metadata categories + init_quantize_state_counters(qs, metadata); + int idx = 0; uint16_t n_split = 1; @@ -973,25 +1006,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: std::vector ctx_outs(n_split); ctx_outs[0] = std::move(ctx_out); - // compute tensor metadata once and cache it - std::vector metadata(tensors.size()); - - // initialize quantization state before preliminary loop (counters for use_more_bits) - { - for (size_t i = 0; i < tensors.size(); ++i) { - const auto cat = tensor_get_category(tensors[i]->tensor->name); - if (category_is_attn_v(cat)) { - ++qs.n_attention_wv; - } - if (cat == tensor_category::OUTPUT) { - qs.has_tied_embeddings = false; - } - metadata[i].category = cat; // save and re-use the category while we're at it - } - // these also need to be set to n_layer by default - qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)qs.model.hparams.n_layer; - } - // flag for --dry-run bool will_require_imatrix = false; @@ -1002,7 +1016,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: for (size_t i = 0; i < tensors.size(); ++i) { const auto * it = tensors[i]; const struct ggml_tensor * tensor = it->tensor; - const std::string name = ggml_get_name(tensor); uint16_t i_split = params->keep_split ? it->idx : 0; if (!ctx_outs[i_split]) { @@ -1031,7 +1044,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: " - offending tensor: %s\n" " - target type: %s\n" "============================================================================\n\n", - name.c_str(), ggml_type_name(metadata[i].target_type)); + metadata[i].name.c_str(), ggml_type_name(metadata[i].target_type)); throw std::runtime_error("this quantization requires an imatrix!"); } } @@ -1104,7 +1117,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: new_ofstream(weight.idx); } - const std::string name = ggml_get_name(tensor); const size_t tensor_size = ggml_nbytes(tensor); if (!params->dry_run) { @@ -1235,9 +1247,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: total_size_new += new_size; // update the gguf meta data as we go - gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type); - GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size); - gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data); + gguf_set_tensor_type(ctx_outs[cur_split].get(), metadata[i].name.c_str(), new_type); + GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), metadata[i].name.c_str())) == new_size); + gguf_set_tensor_data(ctx_outs[cur_split].get(), metadata[i].name.c_str(), new_data); // write tensor data + padding fout.write((const char *) new_data, new_size); @@ -1271,7 +1283,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: llama_model_quantize_params llama_model_quantize_default_params() { llama_model_quantize_params result = { /*.nthread =*/ 0, - /*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q5_1, + /*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q8_0, /*.output_tensor_type =*/ GGML_TYPE_COUNT, /*.token_embedding_type =*/ GGML_TYPE_COUNT, /*.allow_requantize =*/ false, @@ -1302,3 +1314,89 @@ uint32_t llama_model_quantize( return 0; } + +// +// Helper functions for external tools exposed in llama-ext.h +// + +quantize_state_impl * llama_quant_init( + const llama_model * model, + const llama_model_quantize_params * params) { + return new quantize_state_impl(*model, params); +} + +void llama_quant_free(quantize_state_impl * qs) { + delete qs; +} + +llama_model * llama_quant_model_from_metadata(const llama_quant_model_desc * desc) { + struct llama_model_params mparams = llama_model_default_params(); + auto * model = new llama_model(mparams); + + model->arch = llm_arch_from_string(desc->architecture); + + // infer llm_type: only LLM_TYPE_70B matters for quantization logic + if (model->arch == LLM_ARCH_LLAMA && desc->n_layer == 80 && desc->n_head != desc->n_head_kv) { + model->type = LLM_TYPE_70B; + } + + model->hparams.n_embd = desc->n_embd; + model->hparams.n_embd_head_k_full = desc->n_embd_head_k; + model->hparams.n_embd_head_v_full = desc->n_embd_head_v; + model->hparams.n_layer = desc->n_layer; + model->hparams.n_expert = desc->n_expert; + + for (uint32_t i = 0; i < desc->n_layer; i++) { + model->hparams.n_head_arr[i] = desc->n_head; + model->hparams.n_head_kv_arr[i] = desc->n_head_kv; + model->hparams.n_ff_arr[i] = desc->n_ff; + } + + return model; +} + +bool llama_quant_tensor_allows_quantization( + const quantize_state_impl * qs, + const ggml_tensor * tensor) { + return tensor_allows_quantization(qs->params, qs->model.arch, tensor); +} + +void llama_quant_compute_types( + quantize_state_impl * qs, + llama_ftype ftype, + ggml_tensor ** tensors, + ggml_type * result_types, + size_t n_tensors) { + // reset per-computation state + qs->n_attention_wv = 0; + qs->n_ffn_down = 0; + qs->n_ffn_gate = 0; + qs->n_ffn_up = 0; + qs->i_attention_wv = 0; + qs->i_ffn_down = 0; + qs->i_ffn_gate = 0; + qs->i_ffn_up = 0; + qs->n_fallback = 0; + qs->has_imatrix = false; + qs->has_tied_embeddings = true; + + // build metadata from tensor names + std::vector metadata(n_tensors); + for (size_t i = 0; i < n_tensors; i++) { + metadata[i].name = ggml_get_name(tensors[i]); + } + + // initialize counters and categories + init_quantize_state_counters(*qs, metadata); + + // use a local copy of params with the requested ftype + llama_model_quantize_params local_params = *qs->params; + local_params.ftype = ftype; + + ggml_type default_type = llama_ftype_get_default_type(ftype); + + // compute types + for (size_t i = 0; i < n_tensors; i++) { + result_types[i] = llama_tensor_get_type(*qs, &local_params, tensors[i], default_type, metadata[i]); + } +} diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index 68ba292d426..163f222ef61 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -493,6 +493,16 @@ struct llm_tokenizer_bpe : llm_tokenizer { "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?(?:\\p{L}\\p{M}*(?: \\p{L}\\p{M}*)*)+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]?|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_GEMMA4: + // Gemma4 uses SPM-style BPE: spaces are replaced with ▁ by the + // normalizer, then BPE merges run on the whole text without + // word-level pre-splitting. We only need to split on newlines + // since BPE merge lookup asserts no newlines in tokens. + regex_exprs = { + "[^\\n]+|[\\n]+", + }; + byte_encode = false; // uses raw UTF-8, not GPT-2 byte encoding + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { @@ -506,6 +516,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { } std::vector regex_exprs; + bool byte_encode = true; // GPT-2 byte encoding; false for SPM-style BPE (raw UTF-8) }; struct llm_tokenizer_bpe_session { @@ -550,9 +561,10 @@ struct llm_tokenizer_bpe_session { void tokenize(const std::string & text, std::vector & output) { int final_prev_index = -1; - const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs); + const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs, tokenizer.byte_encode); symbols_final.clear(); + auto tok_pre = vocab.get_pre_type(); for (const auto & word : word_collection) { work_queue = llm_bigram_bpe::queue(); @@ -565,6 +577,13 @@ struct llm_tokenizer_bpe_session { if (vocab.get_ignore_merges() && vocab.text_to_token(word) != LLAMA_TOKEN_NULL) { symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()}); offset = word.size(); + } else if (tok_pre == LLAMA_VOCAB_PRE_TYPE_GEMMA4 && word.find_first_not_of('\n') == std::string::npos) { + // fix for gemma 4, ref: https://github.com/ggml-org/llama.cpp/pull/21343 + auto tok = vocab.text_to_token(word); + if (tok != LLAMA_TOKEN_NULL) { + symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()}); + offset = word.size(); + } } while (offset < word.size()) { @@ -640,8 +659,17 @@ struct llm_tokenizer_bpe_session { if (token == LLAMA_TOKEN_NULL) { for (auto j = str.begin(); j != str.end(); ++j) { - std::string byte_str(1, *j); - auto token_multibyte = vocab.text_to_token(byte_str); + llama_token token_multibyte = LLAMA_TOKEN_NULL; + if (tokenizer.byte_encode) { + std::string byte_str(1, *j); + token_multibyte = vocab.text_to_token(byte_str); + } else { + // For non-byte-encoded BPE (e.g. gemma-4), byte tokens use <0xXX> format + static const char * hex = "0123456789ABCDEF"; + const uint8_t ch = (uint8_t)*j; + const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 }; + token_multibyte = vocab.text_to_token(buf); + } if (token_multibyte != LLAMA_TOKEN_NULL) { output.push_back(token_multibyte); } @@ -1863,6 +1891,42 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { special_sep_id = LLAMA_TOKEN_NULL; special_pad_id = 3; // <|plamo:pad|> special_mask_id = LLAMA_TOKEN_NULL; + } else if (tokenizer_model == "gemma4") { + type = LLAMA_VOCAB_TYPE_BPE; + + // read bpe merges and populate bpe ranks + const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); + if (merges_keyidx == -1) { + throw std::runtime_error("cannot find tokenizer merges in model file\n"); + } + { + const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); + for (int i = 0; i < n_merges; i++) { + const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); + + std::string first; + std::string second; + + const size_t pos = word.find(' ', 1); + + if (pos != std::string::npos) { + first = word.substr(0, pos); + second = word.substr(pos + 1); + } + + bpe_ranks.emplace(std::make_pair(first, second), i); + } + } + + // default special tokens (to be read from GGUF) + special_bos_id = LLAMA_TOKEN_NULL; + special_eos_id = LLAMA_TOKEN_NULL; + special_unk_id = LLAMA_TOKEN_NULL; + special_sep_id = LLAMA_TOKEN_NULL; + special_pad_id = LLAMA_TOKEN_NULL; + special_mask_id = LLAMA_TOKEN_NULL; + + tokenizer_pre = "gemma4"; } else { throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str())); } @@ -1870,6 +1934,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { // for now, only BPE models have pre-tokenizers if (type == LLAMA_VOCAB_TYPE_BPE) { add_space_prefix = false; + escape_whitespaces = false; clean_spaces = true; if (tokenizer_pre.empty()) { LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__); @@ -1936,6 +2001,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } else if ( tokenizer_pre == "jais-2") { pre_type = LLAMA_VOCAB_PRE_TYPE_JAIS2; + } else if ( + tokenizer_pre == "gemma4") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GEMMA4; + escape_whitespaces = true; } else if ( tokenizer_pre == "jina-v1-en" || tokenizer_pre == "jina-v2-code" || @@ -1952,7 +2021,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } else if ( tokenizer_pre == "qwen2" || tokenizer_pre == "deepseek-r1-qwen" || - tokenizer_pre == "kormo") { + tokenizer_pre == "kormo" || + tokenizer_pre == "f2llmv2") { pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; clean_spaces = false; } else if ( @@ -2129,19 +2199,28 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { throw std::runtime_error("cannot find tokenizer vocab in model file\n"); } + const uint32_t n_tokens = gguf_get_arr_n(ctx, token_idx); + const float * scores = nullptr; const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str()); if (score_idx != -1) { + const uint32_t n_scores = gguf_get_arr_n(ctx, score_idx); + if (n_scores < n_tokens) { + throw std::runtime_error("Index out of array bounds for scores (" + std::to_string(n_scores) + " < " + std::to_string(n_tokens) + ")\n"); + } scores = (const float * ) gguf_get_arr_data(ctx, score_idx); } const int * toktypes = nullptr; const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str()); if (toktype_idx != -1) { + const uint32_t n_toktypes = gguf_get_arr_n(ctx, toktype_idx); + if (n_toktypes < n_tokens) { + throw std::runtime_error("Index out of array bounds for toktypes (" + std::to_string(n_toktypes) + " < " + std::to_string(n_tokens) + ")\n"); + } toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx); } - uint32_t n_tokens = gguf_get_arr_n(ctx, token_idx); id_to_token.resize(n_tokens); for (uint32_t i = 0; i < n_tokens; i++) { @@ -2255,6 +2334,14 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { if (ml.get_key(LLM_KV_TOKENIZER_ADD_SEP, temp, false)) { add_sep = temp; } + + // workaround for Gemma 4 + // ref: https://github.com/ggml-org/llama.cpp/pull/21500 + if (pre_type == LLAMA_VOCAB_PRE_TYPE_GEMMA4 && !add_bos) { + add_bos = true; + + LLAMA_LOG_WARN("%s: override '%s' to 'true' for Gemma4\n", __func__, kv(LLM_KV_TOKENIZER_ADD_BOS).c_str()); + } } // auto-detect special tokens by text @@ -2480,6 +2567,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "[EOS]" // Kimi-K2 || t.first == "<|end_of_text|>" || t.first == "" // smoldocling + || t.first == "" // gemma4 + || t.first == "" // gemma4 + || t.first == "<|tool_response>" // gemma4 + || t.first == "<|end▁of▁sentence|>" // deepseek-ocr ) { special_eog_ids.insert(t.second); if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { @@ -2564,6 +2655,33 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>', or '<|calls|>' and '<|flush|>' tokens, removing '<|end|>' token from EOG list\n", __func__); } } + + // workaround for gemma4 and paddleocr: do not include as an eog token + { + bool has_tool_response = false; + bool has_s = false; + + llama_token s_id = LLAMA_TOKEN_NULL; + + for (auto tid : special_eog_ids) { + const auto & text = id_to_token[tid].text; + if (text == "<|tool_response>") { + has_tool_response = true; + } else if (text == "") { + has_s = true; + s_id = tid; + } + } + + if (has_tool_response && has_s) { + special_eog_ids.erase(s_id); + + auto & attr = id_to_token[s_id].attr; + attr = LLAMA_TOKEN_ATTR_NORMAL; + + LLAMA_LOG_WARN("%s: special_eog_ids contains '<|tool_response>', removing '' token from EOG list\n", __func__); + } + } } // build special tokens cache @@ -2732,7 +2850,9 @@ uint8_t llama_vocab::impl::token_to_byte(llama_token id) const { return strtol(buf.c_str(), NULL, 16); } case LLAMA_VOCAB_TYPE_BPE: { - GGML_ABORT("fatal error"); + // Gemma4 uses BPE with SPM-style byte fallback tokens (<0xXX>) + auto buf = token_data.text.substr(3, 2); + return strtol(buf.c_str(), NULL, 16); } case LLAMA_VOCAB_TYPE_WPM: { GGML_ABORT("fatal error"); @@ -3021,6 +3141,10 @@ std::vector llama_vocab::impl::tokenize( if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { std::string text = fragment.raw_text.substr(fragment.offset, fragment.length); + if (escape_whitespaces) { + llama_escape_whitespace(text); + } + #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str()); #endif @@ -3200,9 +3324,19 @@ int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t return _try_copy(token_text.data(), token_text.size()); } if (attr & LLAMA_TOKEN_ATTR_NORMAL) { + if (escape_whitespaces) { + // SPM-style BPE: tokens contain ▁ for spaces + std::string result = token_text; + llama_unescape_whitespace(result); + return _try_copy(result.data(), result.size()); + } std::string result = llama_decode_text(token_text); return _try_copy(result.data(), result.size()); } + if (attr & LLAMA_TOKEN_ATTR_BYTE) { + char byte = (char) token_to_byte(token); + return _try_copy((char*) &byte, 1); + } break; } case LLAMA_VOCAB_TYPE_RWKV: { @@ -3630,9 +3764,7 @@ int llama_vocab::max_token_len() const { int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const { GGML_ASSERT(token_left.find(' ') == std::string::npos); - GGML_ASSERT(token_left.find('\n') == std::string::npos); GGML_ASSERT(token_right.find(' ') == std::string::npos); - GGML_ASSERT(token_right.find('\n') == std::string::npos); auto it = pimpl->bpe_ranks.find(std::make_pair(token_left, token_right)); if (it == pimpl->bpe_ranks.end()) { diff --git a/examples/talk-llama/llama-vocab.h b/examples/talk-llama/llama-vocab.h index be5b08012df..dd38f45d3a2 100644 --- a/examples/talk-llama/llama-vocab.h +++ b/examples/talk-llama/llama-vocab.h @@ -58,6 +58,7 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_TINY_AYA = 47, LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM = 48, LLAMA_VOCAB_PRE_TYPE_JAIS2 = 49, + LLAMA_VOCAB_PRE_TYPE_GEMMA4 = 50, }; struct LLM_KV; diff --git a/examples/talk-llama/llama.cpp b/examples/talk-llama/llama.cpp index 872e659edca..e9c3028585d 100644 --- a/examples/talk-llama/llama.cpp +++ b/examples/talk-llama/llama.cpp @@ -1,6 +1,5 @@ #include "llama.h" -#include "ggml-cpp.h" #include "llama-impl.h" #include "llama-chat.h" @@ -12,6 +11,7 @@ #include "llama-model.h" #include "ggml.h" +#include "ggml-cpp.h" #include "ggml-backend.h" #include "gguf.h" @@ -24,6 +24,7 @@ #include #include #include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -45,722 +46,6 @@ const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_ty GGML_ABORT("fatal error"); } -struct llama_device_memory_data { - int64_t total; - int64_t free; - llama_memory_breakdown_data mb; -}; - -static std::vector llama_get_device_memory_data( - const char * path_model, const llama_model_params * mparams, const llama_context_params * cparams, - std::vector & devs, uint32_t & hp_ngl, uint32_t & hp_n_ctx_train, uint32_t & hp_n_expert, - const ggml_log_level log_level) { - struct user_data_t { - struct { - ggml_log_callback callback; - void * user_data; - } original_logger; - ggml_log_level min_level; // prints below this log level go to debug log - }; - user_data_t ud; - llama_log_get(&ud.original_logger.callback, &ud.original_logger.user_data); - ud.min_level = log_level; - - llama_log_set([](ggml_log_level level, const char * text, void * user_data) { - const user_data_t * ud = (const user_data_t *) user_data; - const ggml_log_level level_eff = level >= ud->min_level ? level : GGML_LOG_LEVEL_DEBUG; - ud->original_logger.callback(level_eff, text, ud->original_logger.user_data); - }, &ud); - - llama_model_params mparams_copy = *mparams; - mparams_copy.no_alloc = true; - mparams_copy.use_mmap = false; - mparams_copy.use_mlock = false; - - llama_model * model = llama_model_load_from_file(path_model, mparams_copy); - if (model == nullptr) { - llama_log_set(ud.original_logger.callback, ud.original_logger.user_data); - throw std::runtime_error("failed to load model"); - } - - llama_context * ctx = llama_init_from_model(model, *cparams); - if (ctx == nullptr) { - llama_model_free(model); - llama_log_set(ud.original_logger.callback, ud.original_logger.user_data); - throw std::runtime_error("failed to create llama_context from model"); - } - - std::vector ret(model->devices.size()); - - std::map memory_breakdown = ctx->memory_breakdown(); - - for (const auto & [buft, mb] : memory_breakdown) { - if (ggml_backend_buft_is_host(buft)) { - continue; - } - - ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); - if (!dev) { - continue; - } - for (size_t i = 0; i < ret.size(); i++) { - if (model->devices[i] == dev) { - ret[i].mb.model += mb.model; - ret[i].mb.context += mb.context; - ret[i].mb.compute += mb.compute; - break; - } - } - } - for (size_t i = 0; i < ret.size(); i++) { - size_t free; - size_t total; - ggml_backend_dev_memory(model->devices[i], &free, &total); - - // devices can return 0 bytes for free and total memory if they do not - // have any to report. in this case, we will use the host memory as a fallback - // fixes: https://github.com/ggml-org/llama.cpp/issues/18577 - if (free == 0 && total == 0) { - ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); - if (cpu_dev == nullptr) { - throw std::runtime_error(format("%s: no CPU backend found", __func__)); - } - ggml_backend_dev_memory(cpu_dev, &free, &total); - } - ret[i].free = free; - ret[i].total = total; - } - - devs = model->devices; - hp_ngl = model->hparams.n_layer; - hp_n_ctx_train = model->hparams.n_ctx_train; - hp_n_expert = model->hparams.n_expert; - - llama_memory_breakdown_print(ctx); // goes to debug log - - llama_free(ctx); - llama_model_free(model); - llama_log_set(ud.original_logger.callback, ud.original_logger.user_data); - return ret; -} - -// enum to identify part of a layer for distributing its tensors: -enum layer_fraction_t { - LAYER_FRACTION_NONE = 0, // nothing - LAYER_FRACTION_ATTN = 1, // attention - LAYER_FRACTION_UP = 2, // attention + up - LAYER_FRACTION_GATE = 3, // attention + up + gate - LAYER_FRACTION_MOE = 4, // everything but sparse MoE weights -}; -// this enum is only used in llama_params_fit_impl but needs to be defined outside of it to fix a Windows compilation issue - -class llama_params_fit_exception : public std::runtime_error { - using std::runtime_error::runtime_error; -}; - -static void llama_params_fit_impl( - const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams, - float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides, - size_t * margins_s, uint32_t n_ctx_min, enum ggml_log_level log_level) { - constexpr int64_t MiB = 1024*1024; - typedef std::vector dmds_t; - const llama_model_params default_mparams = llama_model_default_params(); - - std::vector devs; - uint32_t hp_ngl = 0; // hparams.n_gpu_layers - uint32_t hp_nct = 0; // hparams.n_ctx_train - uint32_t hp_nex = 0; // hparams.n_expert - - // step 1: get data for default parameters and check whether any changes are necessary in the first place - - LLAMA_LOG_DEBUG("%s: getting device memory data for initial parameters:\n", __func__); - const dmds_t dmds_full = llama_get_device_memory_data(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); - const size_t nd = devs.size(); // number of devices - if (nd == 0) { - LLAMA_LOG_INFO("%s: no devices with dedicated memory found\n", __func__); - return; - } - - std::vector margins; // this function uses int64_t rather than size_t for memory sizes to more conveniently handle deficits - margins.reserve(nd); - for (size_t id = 0; id < nd; id++) { - margins.push_back(margins_s[id]); - } - - std::vector dev_names; - { - dev_names.reserve(nd); - size_t max_length = 0; - for (ggml_backend_dev_t dev : devs) { - std::string name = ggml_backend_dev_name(dev); - name += " ("; - name += ggml_backend_dev_description(dev); - name += ")"; - dev_names.push_back(name); - max_length = std::max(max_length, name.length()); - } - for (std::string & dn : dev_names) { - dn.insert(dn.end(), max_length - dn.length(), ' '); - } - } - - int64_t sum_free = 0; - int64_t sum_projected_free = 0; - int64_t sum_projected_used = 0; - int64_t sum_projected_model = 0; - std::vector projected_free_per_device; - projected_free_per_device.reserve(nd); - - if (nd > 1) { - LLAMA_LOG_INFO("%s: projected memory use with initial parameters [MiB]:\n", __func__); - } - for (size_t id = 0; id < nd; id++) { - const llama_device_memory_data & dmd = dmds_full[id]; - - const int64_t projected_used = dmd.mb.total(); - const int64_t projected_free = dmd.free - projected_used; - projected_free_per_device.push_back(projected_free); - - sum_free += dmd.free; - sum_projected_used += projected_used; - sum_projected_free += projected_free; - sum_projected_model += dmd.mb.model; - - if (nd > 1) { - LLAMA_LOG_INFO("%s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " free vs. target of %6" PRId64 "\n", - __func__, dev_names[id].c_str(), dmd.total/MiB, projected_used/MiB, projected_free/MiB, margins[id]/MiB); - } - } - assert(sum_free >= 0 && sum_projected_used >= 0); - LLAMA_LOG_INFO("%s: projected to use %" PRId64 " MiB of device memory vs. %" PRId64 " MiB of free device memory\n", - __func__, sum_projected_used/MiB, sum_free/MiB); - if (nd == 1) { - if (projected_free_per_device[0] >= margins[0]) { - LLAMA_LOG_INFO("%s: will leave %" PRId64 " >= %" PRId64 " MiB of free device memory, no changes needed\n", - __func__, projected_free_per_device[0]/MiB, margins[0]/MiB); - return; - } - } else { - bool changes_needed = false; - for (size_t id = 0; id < nd; id++) { - if (projected_free_per_device[id] < margins[id]) { - changes_needed = true; - break; - } - } - if (!changes_needed) { - LLAMA_LOG_INFO("%s: targets for free memory can be met on all devices, no changes needed\n", __func__); - return; - } - } - - // step 2: try reducing memory use by reducing the context size - - { - int64_t global_surplus = sum_projected_free; - for (size_t id = 0; id < nd; id++) { - global_surplus -= margins[id]; - } - if (global_surplus < 0) { - if (nd == 1) { - LLAMA_LOG_INFO("%s: cannot meet free memory target of %" PRId64 " MiB, need to reduce device memory by %" PRId64 " MiB\n", - __func__, margins[0]/MiB, -global_surplus/MiB); - } else { - LLAMA_LOG_INFO( - "%s: cannot meet free memory targets on all devices, need to use %" PRId64 " MiB less in total\n", - __func__, -global_surplus/MiB); - } - if (cparams->n_ctx == 0) { - if (hp_nct > n_ctx_min) { - int64_t sum_used_target = sum_free; - for (size_t id = 0; id < nd; id++) { - sum_used_target -= margins[id]; - } - if (nd > 1) { - // for multiple devices we need to be more conservative in terms of how much context we think can fit: - // - for dense models only whole layers can be assigned to devices - // - for MoE models only whole tensors can be assigned to devices, which we estimate to be <= 1/3 of a layer - // - on average we expect a waste of 0.5 layers/tensors per device - // - use slightly more than the expected average for nd devices to be safe - const int64_t model_per_layer = sum_projected_model / std::min(uint32_t(mparams->n_gpu_layers), hp_ngl); - sum_used_target -= (nd + 1) * model_per_layer / (hp_nex == 0 ? 2 : 6); - } - - int64_t sum_projected_used_min_ctx = 0; - cparams->n_ctx = n_ctx_min; - const dmds_t dmds_min_ctx = llama_get_device_memory_data(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); - for (const auto & dmd : dmds_min_ctx) { - sum_projected_used_min_ctx += dmd.mb.total(); - } - if (sum_used_target > sum_projected_used_min_ctx) { - // linear interpolation between minimum and maximum context size: - cparams->n_ctx += (hp_nct - n_ctx_min) * (sum_used_target - sum_projected_used_min_ctx) - / (sum_projected_used - sum_projected_used_min_ctx); - cparams->n_ctx = std::max(cparams->n_ctx - cparams->n_ctx % 256, n_ctx_min); // round down context for CUDA backend - - const int64_t bytes_per_ctx = (sum_projected_used - sum_projected_used_min_ctx) / (hp_nct - n_ctx_min); - const int64_t memory_reduction = (hp_nct - cparams->n_ctx) * bytes_per_ctx; - LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n", - __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB); - if (nd == 1) { - LLAMA_LOG_INFO("%s: entire model can be fit by reducing context\n", __func__); - return; - } - LLAMA_LOG_INFO("%s: entire model should be fit across devices by reducing context\n", __func__); - } else { - const int64_t memory_reduction = sum_projected_used - sum_projected_used_min_ctx; - LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n", - __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB); - } - } else { - if (n_ctx_min == UINT32_MAX) { - LLAMA_LOG_INFO("%s: user has requested full context size of %" PRIu32 " -> no change\n", __func__, hp_nct); - } else { - LLAMA_LOG_INFO("%s: default model context size is %" PRIu32 " which is <= the min. context size of %" PRIu32 " -> no change\n", - __func__, hp_nct, n_ctx_min); - } - } - } else { - LLAMA_LOG_INFO("%s: context size set by user to %" PRIu32 " -> no change\n", __func__, cparams->n_ctx); - } - } - } - - if (mparams->n_gpu_layers != default_mparams.n_gpu_layers) { - throw llama_params_fit_exception("n_gpu_layers already set by user to " + std::to_string(mparams->n_gpu_layers) + ", abort"); - } - if (nd > 1) { - if (!tensor_split) { - throw llama_params_fit_exception("did not provide a buffer to write the tensor_split to, abort"); - } - if (mparams->tensor_split) { - for (size_t id = 0; id < nd; id++) { - if (mparams->tensor_split[id] != 0.0f) { - throw llama_params_fit_exception("model_params::tensor_split already set by user, abort"); - } - } - } - if (mparams->split_mode == LLAMA_SPLIT_MODE_ROW) { - throw llama_params_fit_exception("changing weight allocation for LLAMA_SPLIT_MODE_ROW not implemented, abort"); - } - } - if (!tensor_buft_overrides) { - throw llama_params_fit_exception("did not provide buffer to set tensor_buft_overrides, abort"); - } - if (mparams->tensor_buft_overrides && (mparams->tensor_buft_overrides->pattern || mparams->tensor_buft_overrides->buft)) { - throw llama_params_fit_exception("model_params::tensor_buft_overrides already set by user, abort"); - } - - // step 3: iteratively fill the back to front with "dense" layers - // - for a dense model simply fill full layers, giving each device a contiguous slice of the model - // - for a MoE model, same as dense model but with all MoE tensors in system memory - - // utility function that returns a static C string matching the tensors for a specific layer index and layer fraction: - auto get_overflow_pattern = [&](const size_t il, const layer_fraction_t lf) -> const char * { - constexpr size_t n_strings = 1000; - if (il >= n_strings) { - throw std::runtime_error("at most " + std::to_string(n_strings) + " model layers are supported"); - } - switch (lf) { - case LAYER_FRACTION_ATTN: { - static std::array patterns; - if (patterns[il].empty()) { - patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(up|gate|down).*"; - } - return patterns[il].c_str(); - } - case LAYER_FRACTION_UP: { - static std::array patterns; - if (patterns[il].empty()) { - patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(gate|down).*"; - } - return patterns[il].c_str(); - } - case LAYER_FRACTION_GATE: { - static std::array patterns; - if (patterns[il].empty()) { - patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_down.*"; - } - return patterns[il].c_str(); - } - case LAYER_FRACTION_MOE: { - static std::array patterns; - if (patterns[il].empty()) { - patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(up|down|gate)_(ch|)exps"; - } - return patterns[il].c_str(); - } - default: - GGML_ABORT("fatal error"); - } - }; - - struct ngl_t { - uint32_t n_layer = 0; // number of total layers - uint32_t n_part = 0; // number of partial layers, <= n_layer - - // for the first partial layer varying parts can overflow, all further layers use LAYER_FRACTION_MOE: - layer_fraction_t overflow_type = LAYER_FRACTION_MOE; - - uint32_t n_full() const { - assert(n_layer >= n_part); - return n_layer - n_part; - } - }; - - const size_t ntbo = llama_max_tensor_buft_overrides(); - - // utility function to set n_gpu_layers and tensor_split - auto set_ngl_tensor_split_tbo = [&]( - const std::vector & ngl_per_device, - const std::vector & overflow_bufts, - llama_model_params & mparams) { - mparams.n_gpu_layers = 0; - for (size_t id = 0; id < nd; id++) { - mparams.n_gpu_layers += ngl_per_device[id].n_layer; - if (nd > 1) { - tensor_split[id] = ngl_per_device[id].n_layer; - } - } - assert(uint32_t(mparams.n_gpu_layers) <= hp_ngl + 1); - uint32_t il0 = hp_ngl + 1 - mparams.n_gpu_layers; // start index for tensor buft overrides - - mparams.tensor_split = tensor_split; - - size_t itbo = 0; - for (size_t id = 0; id < nd; id++) { - il0 += ngl_per_device[id].n_full(); - for (uint32_t il = il0; il < il0 + ngl_per_device[id].n_part; il++) { - if (itbo + 1 >= ntbo) { - tensor_buft_overrides[itbo].pattern = nullptr; - tensor_buft_overrides[itbo].buft = nullptr; - itbo++; - mparams.tensor_buft_overrides = tensor_buft_overrides; - throw llama_params_fit_exception("llama_max_tensor_buft_overrides() == " - + std::to_string(ntbo) + " is insufficient for model"); - } - tensor_buft_overrides[itbo].pattern = get_overflow_pattern(il, il == il0 ? ngl_per_device[id].overflow_type : LAYER_FRACTION_MOE); - tensor_buft_overrides[itbo].buft = il == il0 ? overflow_bufts[id] : ggml_backend_cpu_buffer_type(); - itbo++; - } - il0 += ngl_per_device[id].n_part; - } - tensor_buft_overrides[itbo].pattern = nullptr; - tensor_buft_overrides[itbo].buft = nullptr; - itbo++; - mparams.tensor_buft_overrides = tensor_buft_overrides; - }; - - // utility function that returns the memory use per device for given numbers of layers per device - auto get_memory_for_layers = [&]( - const char * func_name, - const std::vector & ngl_per_device, - const std::vector & overflow_bufts) -> std::vector { - llama_model_params mparams_copy = *mparams; - set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, mparams_copy); - - const dmds_t dmd_nl = llama_get_device_memory_data( - path_model, &mparams_copy, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); - - LLAMA_LOG_DEBUG("%s: memory for test allocation by device:\n", func_name); - for (size_t id = 0; id < nd; id++) { - const ngl_t & n = ngl_per_device[id]; - LLAMA_LOG_DEBUG( - "%s: id=%zu, n_layer=%2" PRIu32 ", n_part=%2" PRIu32 ", overflow_type=%d, mem=%6" PRId64 " MiB\n", - func_name, id, n.n_layer, n.n_part, int(n.overflow_type), dmd_nl[id].mb.total()/MiB); - } - - std::vector ret; - ret.reserve(nd); - for (const llama_device_memory_data & dmd : dmd_nl) { - ret.push_back(dmd.mb.total()); - } - return ret; - }; - - int64_t global_surplus_cpu_moe = 0; - if (hp_nex > 0) { - const static std::string pattern_moe_all = "blk\\.\\d+\\.ffn_(up|down|gate)_(ch|)exps"; // matches all MoE tensors - ggml_backend_buffer_type_t cpu_buft = ggml_backend_cpu_buffer_type(); - tensor_buft_overrides[0] = {pattern_moe_all.c_str(), cpu_buft}; - tensor_buft_overrides[1] = {nullptr, nullptr}; - mparams->tensor_buft_overrides = tensor_buft_overrides; - - LLAMA_LOG_DEBUG("%s: getting device memory data with all MoE tensors moved to system memory:\n", __func__); - const dmds_t dmds_cpu_moe = llama_get_device_memory_data( - path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); - - for (size_t id = 0; id < nd; id++) { - global_surplus_cpu_moe += dmds_cpu_moe[id].free; - global_surplus_cpu_moe -= int64_t(dmds_cpu_moe[id].mb.total()) + margins[id]; - } - - if (global_surplus_cpu_moe > 0) { - LLAMA_LOG_INFO("%s: with only dense weights in device memory there is a total surplus of %" PRId64 " MiB\n", - __func__, global_surplus_cpu_moe/MiB); - } else { - LLAMA_LOG_INFO("%s: with only dense weights in device memory there is still a total deficit of %" PRId64 " MiB\n", - __func__, -global_surplus_cpu_moe/MiB); - } - - // reset - tensor_buft_overrides[0] = {nullptr, nullptr}; - mparams->tensor_buft_overrides = tensor_buft_overrides; - } - - std::vector targets; // maximum acceptable memory use per device - targets.reserve(nd); - for (size_t id = 0; id < nd; id++) { - targets.push_back(dmds_full[id].free - margins[id]); - LLAMA_LOG_DEBUG("%s: id=%zu, target=%" PRId64 " MiB\n", __func__, id, targets[id]/MiB); - } - - std::vector overflow_bufts; // which bufts the first partial layer of a device overflows to: - overflow_bufts.reserve(nd); - for (size_t id = 0; id < nd; id++) { - overflow_bufts.push_back(ggml_backend_cpu_buffer_type()); - } - - std::vector ngl_per_device(nd); - std::vector mem = get_memory_for_layers(__func__, ngl_per_device, overflow_bufts); - - // optimize the number of layers per device using the method of false position: - // - ngl_per_device has 0 layers for each device, lower bound - // - try a "high" configuration where a device is given all unassigned layers - // - interpolate the memory use / layer between low and high linearly to get a guess where it meets our target - // - check memory use of our guess, replace either the low or high bound - // - once we only have a difference of a single layer, stop and return the lower bound that just barely still fits - // - the last device has the output layer, which cannot be a partial layer - if (hp_nex == 0) { - LLAMA_LOG_INFO("%s: filling dense layers back-to-front:\n", __func__); - } else { - LLAMA_LOG_INFO("%s: filling dense-only layers back-to-front:\n", __func__); - } - for (int id = nd - 1; id >= 0; id--) { - uint32_t n_unassigned = hp_ngl + 1; - for (size_t jd = id + 1; jd < nd; ++jd) { - assert(n_unassigned >= ngl_per_device[jd].n_layer); - n_unassigned -= ngl_per_device[jd].n_layer; - } - - std::vector ngl_per_device_high = ngl_per_device; - ngl_per_device_high[id].n_layer = n_unassigned; - if (hp_nex > 0) { - ngl_per_device_high[id].n_part = size_t(id) < nd - 1 ? ngl_per_device_high[id].n_layer : ngl_per_device_high[id].n_layer - 1; - } - if (ngl_per_device_high[id].n_layer > 0) { - std::vector mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts); - if (mem_high[id] > targets[id]) { - assert(ngl_per_device_high[id].n_layer > ngl_per_device[id].n_layer); - uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer; - LLAMA_LOG_DEBUG("%s: start filling device %" PRIu32 ", delta=%" PRIu32 "\n", __func__, id, delta); - while (delta > 1) { - uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]); - step_size = std::max(step_size, uint32_t(1)); - step_size = std::min(step_size, delta - 1); - - std::vector ngl_per_device_test = ngl_per_device; - ngl_per_device_test[id].n_layer += step_size; - if (hp_nex) { - ngl_per_device_test[id].n_part += size_t(id) == nd - 1 && ngl_per_device_test[id].n_part == 0 ? - step_size - 1 : step_size; // the first layer is the output layer which must always be full - } - const std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); - - if (mem_test[id] <= targets[id]) { - ngl_per_device = ngl_per_device_test; - mem = mem_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer); - } else { - ngl_per_device_high = ngl_per_device_test; - mem_high = mem_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device_high[id].n_layer); - } - delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer; - } - } else { - assert(ngl_per_device_high[id].n_layer == n_unassigned); - ngl_per_device = ngl_per_device_high; - mem = mem_high; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer); - } - } - - const int64_t projected_margin = dmds_full[id].free - mem[id]; - LLAMA_LOG_INFO( - "%s: - %s: %2" PRIu32 " layers, %6" PRId64 " MiB used, %6" PRId64 " MiB free\n", - __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, mem[id]/MiB, projected_margin/MiB); - } - if (hp_nex == 0 || global_surplus_cpu_moe <= 0) { - set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams); - return; - } - - // step 4: for a MoE model where all dense tensors fit, - // convert the dense-only layers in the back to full layers in the front until all devices are full - // essentially the same procedure as for the dense-only layers except front-to-back - // also, try fitting at least part of one more layer to reduce waste for "small" GPUs with e.g. 24 GiB VRAM - - size_t id_dense_start = nd; - for (int id = nd - 1; id >= 0; id--) { - if (ngl_per_device[id].n_layer > 0) { - id_dense_start = id; - continue; - } - break; - } - assert(id_dense_start < nd); - - LLAMA_LOG_INFO("%s: converting dense-only layers to full layers and filling them front-to-back with overflow to next device/system memory:\n", __func__); - for (size_t id = 0; id <= id_dense_start && id_dense_start < nd; id++) { - std::vector ngl_per_device_high = ngl_per_device; - for (size_t jd = id_dense_start; jd < nd; jd++) { - const uint32_t n_layer_move = jd < nd - 1 ? ngl_per_device_high[jd].n_layer : ngl_per_device_high[jd].n_layer - 1; - ngl_per_device_high[id].n_layer += n_layer_move; - ngl_per_device_high[jd].n_layer -= n_layer_move; - ngl_per_device_high[jd].n_part = 0; - } - size_t id_dense_start_high = nd - 1; - std::vector mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts); - - if (mem_high[id] > targets[id]) { - assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full()); - uint32_t delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full(); - while (delta > 1) { - uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]); - step_size = std::max(step_size, uint32_t(1)); - step_size = std::min(step_size, delta - 1); - - std::vector ngl_per_device_test = ngl_per_device; - size_t id_dense_start_test = id_dense_start; - uint32_t n_converted_test = 0; - for (;id_dense_start_test < nd; id_dense_start_test++) { - const uint32_t n_convert_jd = std::min(step_size - n_converted_test, ngl_per_device_test[id_dense_start_test].n_part); - ngl_per_device_test[id_dense_start_test].n_layer -= n_convert_jd; - ngl_per_device_test[id_dense_start_test].n_part -= n_convert_jd; - ngl_per_device_test[id].n_layer += n_convert_jd; - n_converted_test += n_convert_jd; - - if (ngl_per_device_test[id_dense_start_test].n_part > 0) { - break; - } - } - const std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); - - if (mem_test[id] <= targets[id]) { - ngl_per_device = ngl_per_device_test; - mem = mem_test; - id_dense_start = id_dense_start_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start=%zu\n", - __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); - } else { - ngl_per_device_high = ngl_per_device_test; - mem_high = mem_test; - id_dense_start_high = id_dense_start_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start_high=%zu\n", - __func__, id, ngl_per_device_high[id].n_layer, ngl_per_device_high[id].n_part, id_dense_start_high); - } - assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full()); - delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full(); - } - } else { - ngl_per_device = ngl_per_device_high; - mem = mem_high; - id_dense_start = id_dense_start_high; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start=%zu\n", - __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); - } - - // try to fit at least part of one more layer - if (ngl_per_device[id_dense_start].n_layer > (id < nd - 1 ? 0 : 1)) { - std::vector ngl_per_device_test = ngl_per_device; - size_t id_dense_start_test = id_dense_start; - ngl_per_device_test[id_dense_start_test].n_layer--; - ngl_per_device_test[id_dense_start_test].n_part--; - ngl_per_device_test[id].n_layer++; - ngl_per_device_test[id].n_part++; - if (ngl_per_device_test[id_dense_start_test].n_part == 0) { - id_dense_start_test++; - } - ngl_per_device_test[id].overflow_type = LAYER_FRACTION_UP; - std::vector overflow_bufts_test = overflow_bufts; - if (id < nd - 1) { - overflow_bufts_test[id] = ggml_backend_dev_buffer_type(devs[id + 1]); - } - LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP\n", __func__); - std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); - if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { - ngl_per_device = ngl_per_device_test; - overflow_bufts = overflow_bufts_test; - mem = mem_test; - id_dense_start = id_dense_start_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", UP), id_dense_start=%zu\n", - __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); - - ngl_per_device_test[id].overflow_type = LAYER_FRACTION_GATE; - LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE\n", __func__); - mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); - if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { - ngl_per_device = ngl_per_device_test; - overflow_bufts = overflow_bufts_test; - mem = mem_test; - id_dense_start = id_dense_start_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", GATE), id_dense_start=%zu\n", - __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); - } - } else { - ngl_per_device_test[id].overflow_type = LAYER_FRACTION_ATTN; - LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_ATTN\n", __func__); - mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); - if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { - ngl_per_device = ngl_per_device_test; - overflow_bufts = overflow_bufts_test; - mem = mem_test; - id_dense_start = id_dense_start_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", ATTN), id_dense_start=%zu\n", - __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); - } - } - } - - const int64_t projected_margin = dmds_full[id].free - mem[id]; - LLAMA_LOG_INFO( - "%s: - %s: %2" PRIu32 " layers (%2" PRIu32 " overflowing), %6" PRId64 " MiB used, %6" PRId64 " MiB free\n", - __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB); - } - - // print info for devices that were not changed during the conversion from dense only to full layers: - for (size_t id = id_dense_start + 1; id < nd; id++) { - const int64_t projected_margin = dmds_full[id].free - mem[id]; - LLAMA_LOG_INFO( - "%s: - %s: %2" PRIu32 " layers (%2" PRIu32 " overflowing), %6" PRId64 " MiB used, %6" PRId64 " MiB free\n", - __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB); - } - - set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams); -} - -enum llama_params_fit_status llama_params_fit( - const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams, - float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides, - size_t * margins, uint32_t n_ctx_min, enum ggml_log_level log_level) { - const int64_t t0_us = llama_time_us(); - llama_params_fit_status status = LLAMA_PARAMS_FIT_STATUS_SUCCESS; - try { - llama_params_fit_impl(path_model, mparams, cparams, tensor_split, tensor_buft_overrides, margins, n_ctx_min, log_level); - LLAMA_LOG_INFO("%s: successfully fit params to free device memory\n", __func__); - } catch (const llama_params_fit_exception & e) { - LLAMA_LOG_WARN("%s: failed to fit params to free device memory: %s\n", __func__, e.what()); - status = LLAMA_PARAMS_FIT_STATUS_FAILURE; - } catch (const std::runtime_error & e) { - LLAMA_LOG_ERROR("%s: encountered an error while trying to fit params to free device memory: %s\n", __func__, e.what()); - status = LLAMA_PARAMS_FIT_STATUS_ERROR; - } - const int64_t t1_us = llama_time_us(); - LLAMA_LOG_INFO("%s: fitting params to free memory took %.2f seconds\n", __func__, (t1_us - t0_us) * 1e-6); - return status; -} - struct llama_sampler_chain_params llama_sampler_chain_default_params() { struct llama_sampler_chain_params result = { /*.no_perf =*/ true, @@ -828,7 +113,7 @@ int64_t llama_time_us(void) { // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback static int llama_model_load(struct gguf_context * metadata, llama_model_set_tensor_data_t set_tensor_data, void * set_tensor_data_ud, - const std::string & fname, std::vector & splits, llama_model & model, llama_model_params & params) { + const std::string & fname, std::vector & splits, FILE * file, llama_model & model, llama_model_params & params) { // loading time will be recalculated after the first eval, so // we take page faults deferred by mmap() into consideration model.t_load_us = 0; @@ -837,7 +122,7 @@ static int llama_model_load(struct gguf_context * metadata, llama_model_set_tens model.t_start_us = tm.t_start_us; try { - llama_model_loader ml(metadata, set_tensor_data, set_tensor_data_ud, fname, splits, params.use_mmap, params.use_direct_io, + llama_model_loader ml(metadata, set_tensor_data, set_tensor_data_ud, fname, splits, file, params.use_mmap, params.use_direct_io, params.check_tensors, params.no_alloc, params.kv_overrides, params.tensor_buft_overrides); ml.print_info(); @@ -889,8 +174,24 @@ static struct llama_model * llama_model_load_from_file_impl( void * set_tensor_data_ud, const std::string & path_model, std::vector & splits, + FILE * file, struct llama_model_params params) { - GGML_ASSERT((metadata == nullptr) != path_model.empty() && "exactly one out of metadata and path_model needs to be defined"); + { + int n_sources_defined = 0; + if (metadata != nullptr) { + n_sources_defined++; + } + if (!path_model.empty()) { + n_sources_defined++; + } + if (file != nullptr) { + n_sources_defined++; + } + if (n_sources_defined != 1) { + LLAMA_LOG_ERROR("%s: exactly one out metadata, path_model, and file must be defined\n", __func__); + return nullptr; + } + } ggml_time_init(); if (!params.vocab_only && ggml_backend_reg_count() == 0) { @@ -919,58 +220,111 @@ static struct llama_model * llama_model_load_from_file_impl( // create list of devices to use with this model if (params.devices) { - for (ggml_backend_dev_t * dev = params.devices; *dev; ++dev) { - model->devices.push_back(*dev); + if (params.split_mode == LLAMA_SPLIT_MODE_TENSOR) { + size_t n_devs = 0; + while (params.devices[n_devs]) { + n_devs++; + } + if (n_devs == 0) { + LLAMA_LOG_ERROR("%s: LLAMA_SPLIT_MODE_TENSOR needs >= 1 devices\n", __func__); + return nullptr; + } + LLAMA_LOG_INFO("%s: creating a Meta device with %zu devices\n", __func__, n_devs); + for (size_t i = 0; i < n_devs; ++i) { + LLAMA_LOG_INFO("%s: - device %zu: %s\n", __func__, i, ggml_backend_dev_name(params.devices[i])); + } + model->get_split_state_ud.n_devices = n_devs; + model->get_split_state_ud.model = model; + model->devices.push_back({ + true, ggml_backend_meta_device( + params.devices, n_devs, llama_meta_device_get_split_state, &model->get_split_state_ud) + }); + } else { + for (ggml_backend_dev_t * dev = params.devices; *dev; ++dev) { + model->devices.push_back({false, *dev}); + } } } else { // default device selection // build list of available devices - std::vector gpus; - std::vector igpus; - std::vector rpc_servers; - - for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { - ggml_backend_dev_t dev = ggml_backend_dev_get(i); - switch (ggml_backend_dev_type(dev)) { - case GGML_BACKEND_DEVICE_TYPE_CPU: - case GGML_BACKEND_DEVICE_TYPE_ACCEL: - // skip CPU backends since they are handled separately - break; - - case GGML_BACKEND_DEVICE_TYPE_GPU: { - ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); - if (ggml_backend_reg_name(reg) == std::string("RPC")) { - rpc_servers.push_back(dev); - } else { - // check if there is already a GPU with the same device id - ggml_backend_dev_props props; - ggml_backend_dev_get_props(dev, &props); - auto it = std::find_if(gpus.begin(), gpus.end(), [&props](ggml_backend_dev_t d) { - ggml_backend_dev_props d_props; - ggml_backend_dev_get_props(d, &d_props); - if (props.device_id && d_props.device_id) { - return strcmp(props.device_id, d_props.device_id) == 0; - } - return false; - }); - - if (it != gpus.end()) { - LLAMA_LOG_INFO("%s: skipping device %s (%s) with id %s - already using device %s (%s) with the same id\n", - __func__, - ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), - props.device_id ? props.device_id : "unknown id", - ggml_backend_dev_name(*it), ggml_backend_dev_description(*it)); + std::vector gpus; + std::vector igpus; + std::vector rpc_servers; + + if (params.split_mode == LLAMA_SPLIT_MODE_TENSOR) { + std::vector devs; + devs.reserve(ggml_backend_dev_count()); + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + auto * dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_buffer_type(dev) == ggml_backend_cpu_buffer_type()) { + LLAMA_LOG_INFO("%s: skipping %s (%s) for tensor parallelism\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev)); + continue; + } + devs.push_back(dev); + } + if (devs.empty()) { + LLAMA_LOG_ERROR("%s: LLAMA_SPLIT_MODE_TENSOR needs >= 1 devices\n", __func__); + return nullptr; + } + + LLAMA_LOG_INFO("%s: creating a Meta device for tensor parallelism from %zu devices:\n", __func__, devs.size()); + for (size_t i = 0; i < devs.size(); ++i) { + LLAMA_LOG_INFO("%s: - device %zu: %s (%s)\n", __func__, i, ggml_backend_dev_name(devs[i]), ggml_backend_dev_description(devs[i])); + } + + GGML_ASSERT(!devs.empty()); + model->get_split_state_ud.n_devices = devs.size(); + model->get_split_state_ud.model = model; + gpus.push_back({ + true, ggml_backend_meta_device( + devs.data(), devs.size(), llama_meta_device_get_split_state, &model->get_split_state_ud) + }); + } else { + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + switch (ggml_backend_dev_type(dev)) { + case GGML_BACKEND_DEVICE_TYPE_CPU: + case GGML_BACKEND_DEVICE_TYPE_ACCEL: + // skip CPU backends since they are handled separately + break; + + case GGML_BACKEND_DEVICE_TYPE_GPU: { + ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); + if (ggml_backend_reg_name(reg) == std::string("RPC")) { + rpc_servers.push_back({false, dev}); } else { - gpus.push_back(dev); + // check if there is already a GPU with the same device id + ggml_backend_dev_props props; + ggml_backend_dev_get_props(dev, &props); + auto it = std::find_if(gpus.begin(), gpus.end(), [&props](const llama_device & d) { + ggml_backend_dev_props d_props; + ggml_backend_dev_get_props(d.dev, &d_props); + if (props.device_id && d_props.device_id) { + return strcmp(props.device_id, d_props.device_id) == 0; + } + return false; + }); + + if (it != gpus.end()) { + LLAMA_LOG_INFO("%s: skipping device %s (%s) with id %s - already using device %s (%s) with the same id\n", + __func__, + ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), + props.device_id ? props.device_id : "unknown id", + ggml_backend_dev_name(it->dev), ggml_backend_dev_description(it->dev)); + } else { + gpus.push_back({false, dev}); + } } + break; } - break; - } - case GGML_BACKEND_DEVICE_TYPE_IGPU: - igpus.push_back(dev); - break; + case GGML_BACKEND_DEVICE_TYPE_IGPU: + igpus.push_back({false, dev}); + break; + case GGML_BACKEND_DEVICE_TYPE_META: + GGML_ABORT("fatal error"); + } } } @@ -996,22 +350,22 @@ static struct llama_model * llama_model_load_from_file_impl( llama_model_free(model); return nullptr; } - ggml_backend_dev_t main_gpu = model->devices[params.main_gpu]; + llama_device main_gpu = model->devices[params.main_gpu]; model->devices.clear(); model->devices.push_back(main_gpu); } } - for (auto * dev : model->devices) { + for (const auto & dev : model->devices) { ggml_backend_dev_props props; - ggml_backend_dev_get_props(dev, &props); + ggml_backend_dev_get_props(dev.dev, &props); LLAMA_LOG_INFO("%s: using device %s (%s) (%s) - %zu MiB free\n", __func__, - ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), + ggml_backend_dev_name(dev.dev), ggml_backend_dev_description(dev.dev), props.device_id ? props.device_id : "unknown id", props.memory_free/1024/1024); } - const int status = llama_model_load(metadata, set_tensor_data, set_tensor_data_ud, path_model, splits, *model, params); + const int status = llama_model_load(metadata, set_tensor_data, set_tensor_data_ud, path_model, splits, file, *model, params); GGML_ASSERT(status <= 0); if (status < 0) { if (status == -1) { @@ -1037,7 +391,7 @@ struct llama_model * llama_model_init_from_user( std::vector splits = {}; params.use_mmap = false; params.use_extra_bufts = false; - return llama_model_load_from_file_impl(metadata, set_tensor_data, set_tensor_data_ud, path_model, splits, params); + return llama_model_load_from_file_impl(metadata, set_tensor_data, set_tensor_data_ud, path_model, splits, /*file*/ nullptr, params); } // deprecated struct llama_model * llama_load_model_from_file( @@ -1050,7 +404,7 @@ struct llama_model * llama_model_load_from_file( const char * path_model, struct llama_model_params params) { std::vector splits = {}; - return llama_model_load_from_file_impl(nullptr, nullptr, nullptr, path_model, splits, params); + return llama_model_load_from_file_impl(nullptr, nullptr, nullptr, path_model, splits, /*file*/ nullptr, params); } struct llama_model * llama_model_load_from_splits( @@ -1066,7 +420,17 @@ struct llama_model * llama_model_load_from_splits( for (size_t i = 0; i < n_paths; ++i) { splits.push_back(paths[i]); } - return llama_model_load_from_file_impl(nullptr, nullptr, nullptr, splits.front(), splits, params); + return llama_model_load_from_file_impl(nullptr, nullptr, nullptr, splits.front(), splits, /*file*/ nullptr, params); +} + +struct llama_model * llama_model_load_from_file_ptr(FILE * file, struct llama_model_params params) { + if (!file) { + LLAMA_LOG_ERROR("%s: file is NULL\n", __func__); + return nullptr; + } + std::string path_model; + std::vector splits = {}; + return llama_model_load_from_file_impl(nullptr, nullptr, nullptr, path_model, splits, file, params); } void llama_model_save_to_file(const struct llama_model * model, const char * path_model) { diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index c6e102abe51..eb869814097 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -154,6 +154,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors LLAMA_FTYPE_MOSTLY_MXFP4_MOE = 38, // except 1d tensors LLAMA_FTYPE_MOSTLY_NVFP4 = 39, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q1_0 = 40, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; @@ -191,9 +192,10 @@ extern "C" { LLAMA_API const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type); enum llama_split_mode { - LLAMA_SPLIT_MODE_NONE = 0, // single GPU - LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs - LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported + LLAMA_SPLIT_MODE_NONE = 0, // single GPU + LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs + LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported + LLAMA_SPLIT_MODE_TENSOR = 3, }; // TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979) @@ -380,22 +382,33 @@ extern "C" { size_t n_samplers; }; + struct llama_model_tensor_override { + const char * pattern; + enum ggml_type type; + }; + + struct llama_model_imatrix_data { + const char * name; + const float * data; + size_t size; + }; + // model quantization parameters typedef struct llama_model_quantize_params { - int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() - enum llama_ftype ftype; // quantize to this llama_ftype - enum ggml_type output_tensor_type; // output tensor type - enum ggml_type token_embedding_type; // token embeddings tensor type - bool allow_requantize; // allow quantizing non-f32/f16 tensors - bool quantize_output_tensor; // quantize output.weight - bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored - bool pure; // quantize all tensors to the default type - bool keep_split; // quantize to the same number of shards - bool dry_run; // calculate and show the final quantization size without performing quantization - void * imatrix; // pointer to importance matrix data - void * kv_overrides; // pointer to vector containing overrides - void * tensor_types; // pointer to vector containing tensor types - void * prune_layers; // pointer to vector containing layer indices to prune + int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() + enum llama_ftype ftype; // quantize to this llama_ftype + enum ggml_type output_tensor_type; // output tensor type + enum ggml_type token_embedding_type; // token embeddings tensor type + bool allow_requantize; // allow quantizing non-f32/f16 tensors + bool quantize_output_tensor; // quantize output.weight + bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored + bool pure; // quantize all tensors to the default type + bool keep_split; // quantize to the same number of shards + bool dry_run; // calculate and show the final quantization size without performing quantization + const struct llama_model_imatrix_data * imatrix; // pointer to importance matrix data + const struct llama_model_kv_override * kv_overrides; // pointer to kv overrides + const struct llama_model_tensor_override * tt_overrides; // pointer to tensor overrides + const int32_t * prune_layers; // pointer to layer indices to prune } llama_model_quantize_params; typedef struct llama_logit_bias { @@ -465,6 +478,11 @@ extern "C" { const char * path_model, struct llama_model_params params); + // Load a model from an open FILE pointer + LLAMA_API struct llama_model * llama_model_load_from_file_ptr( + FILE * file, + struct llama_model_params params); + // Load a model from multiple splits (support custom naming scheme) // The paths must be in the correct order LLAMA_API struct llama_model * llama_model_load_from_splits( @@ -493,27 +511,6 @@ extern "C" { // Frees all allocated memory LLAMA_API void llama_free(struct llama_context * ctx); - enum llama_params_fit_status { - LLAMA_PARAMS_FIT_STATUS_SUCCESS = 0, // found allocations that are projected to fit - LLAMA_PARAMS_FIT_STATUS_FAILURE = 1, // could not find allocations that are projected to fit - LLAMA_PARAMS_FIT_STATUS_ERROR = 2, // a hard error occurred, e.g. because no model could be found at the specified path - }; - - // fits mparams and cparams to free device memory (assumes system memory is unlimited) - // - returns true if the parameters could be successfully modified to fit device memory - // - this function is NOT thread safe because it modifies the global llama logger state - // - only parameters that have the same value as in llama_default_model_params are modified - // with the exception of the context size which is modified if and only if equal to 0 - LLAMA_API enum llama_params_fit_status llama_params_fit( - const char * path_model, - struct llama_model_params * mparams, - struct llama_context_params * cparams, - float * tensor_split, // writable buffer for tensor split, needs at least llama_max_devices elements - struct llama_model_tensor_buft_override * tensor_buft_overrides, // writable buffer for overrides, needs at least llama_max_tensor_buft_overrides elements - size_t * margins, // margins of memory to leave per device in bytes - uint32_t n_ctx_min, // minimum context size to set when trying to reduce memory use - enum ggml_log_level log_level); // minimum log level to print during fitting, lower levels go to debug log - LLAMA_API int64_t llama_time_us(void); LLAMA_API size_t llama_max_devices(void); @@ -636,7 +633,6 @@ extern "C" { // Load a LoRA adapter from file // The adapter is valid as long as the associated model is not freed - // All adapters must be loaded before context creation LLAMA_API struct llama_adapter_lora * llama_adapter_lora_init( struct llama_model * model, const char * path_lora); @@ -660,9 +656,8 @@ extern "C" { LLAMA_API int32_t llama_adapter_meta_val_str_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size); // Manually free a LoRA adapter - // NOTE: loaded adapters will be free when the associated model is deleted - LLAMA_API DEPRECATED(void llama_adapter_lora_free(struct llama_adapter_lora * adapter), - "adapters are now freed together with the associated model"); + // NOTE: loaded adapters that are not manually freed will be freed when the associated model is deleted + LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter); // Get the invocation tokens if the current lora is an alora LLAMA_API uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter); @@ -1530,9 +1525,6 @@ extern "C" { LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain); LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain); - // print a breakdown of per-device memory use via LLAMA_LOG: - LLAMA_API void llama_memory_breakdown_print(const struct llama_context * ctx); - // // training // diff --git a/examples/talk-llama/models/afmoe.cpp b/examples/talk-llama/models/afmoe.cpp index 9aabe25c965..2790b12111d 100644 --- a/examples/talk-llama/models/afmoe.cpp +++ b/examples/talk-llama/models/afmoe.cpp @@ -41,22 +41,13 @@ llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_para { ggml_tensor * attn_inp = cur; // save input for gate computation - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // compute gate from input ggml_tensor * gate = build_lora_mm(model.layers[il].wqkv_gate, attn_inp); cb(gate, "attn_gate_proj", il); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - // Q/K normalization Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); @@ -77,10 +68,8 @@ llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_para cb(Kcur, "Kcur_rope", il); } - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - cur = build_attn(inp_attn, - NULL, NULL, // wo will be applied after gating + NULL, NULL, NULL, // wo will be applied after gating Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); @@ -91,7 +80,7 @@ llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_para cb(cur, "attn_gated", il); // now apply output projection - cur = build_lora_mm(model.layers[il].wo, cur); + cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s); cb(cur, "attn_o_proj", il); } diff --git a/examples/talk-llama/models/apertus.cpp b/examples/talk-llama/models/apertus.cpp index 4d65614e466..af44cea6054 100644 --- a/examples/talk-llama/models/apertus.cpp +++ b/examples/talk-llama/models/apertus.cpp @@ -1,7 +1,5 @@ #include "models.h" - - llm_build_apertus::llm_build_apertus(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); @@ -32,25 +30,15 @@ llm_build_apertus::llm_build_apertus(const llama_model & model, const llm_graph_ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); cb(Kcur, "Kcur_normed", il); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -62,7 +50,7 @@ llm_build_apertus::llm_build_apertus(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur_pos", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/arcee.cpp b/examples/talk-llama/models/arcee.cpp index 20b9ffd49eb..2e71f5d9e2a 100644 --- a/examples/talk-llama/models/arcee.cpp +++ b/examples/talk-llama/models/arcee.cpp @@ -1,6 +1,5 @@ #include "models.h" - llm_build_arcee::llm_build_arcee(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); @@ -36,30 +35,8 @@ llm_build_arcee::llm_build_arcee(const llama_model & model, const llm_graph_para ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, @@ -78,7 +55,7 @@ llm_build_arcee::llm_build_arcee(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/arctic.cpp b/examples/talk-llama/models/arctic.cpp index b712e08cbd3..f8ca6aff6ab 100644 --- a/examples/talk-llama/models/arctic.cpp +++ b/examples/talk-llama/models/arctic.cpp @@ -30,18 +30,8 @@ llm_build_arctic::llm_build_arctic(const llama_model & model, const llm_graph_pa // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -60,7 +50,7 @@ llm_build_arctic::llm_build_arctic(const llama_model & model, const llm_graph_pa cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/baichuan.cpp b/examples/talk-llama/models/baichuan.cpp index abd03cd0b97..2d0d05df485 100644 --- a/examples/talk-llama/models/baichuan.cpp +++ b/examples/talk-llama/models/baichuan.cpp @@ -1,6 +1,5 @@ #include "models.h" - llm_build_baichuan::llm_build_baichuan(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); @@ -29,18 +28,8 @@ llm_build_baichuan::llm_build_baichuan(const llama_model & model, const llm_grap // self-attention { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); switch (model.type) { case LLM_TYPE_7B: @@ -67,7 +56,7 @@ llm_build_baichuan::llm_build_baichuan(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/bailingmoe.cpp b/examples/talk-llama/models/bailingmoe.cpp index 25e3369c313..67a7120d622 100644 --- a/examples/talk-llama/models/bailingmoe.cpp +++ b/examples/talk-llama/models/bailingmoe.cpp @@ -28,30 +28,8 @@ llm_build_bailingmoe::llm_build_bailingmoe(const llama_model & model, const llm_ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_rot, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head_k, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, @@ -70,7 +48,7 @@ llm_build_bailingmoe::llm_build_bailingmoe(const llama_model & model, const llm_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il); } diff --git a/examples/talk-llama/models/bailingmoe2.cpp b/examples/talk-llama/models/bailingmoe2.cpp index 42098624663..497b4babd0c 100644 --- a/examples/talk-llama/models/bailingmoe2.cpp +++ b/examples/talk-llama/models/bailingmoe2.cpp @@ -3,7 +3,6 @@ llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -29,15 +28,8 @@ llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const ll // self_attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 0 * sizeof(float) * (n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -56,7 +48,7 @@ llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const ll cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/bert.cpp b/examples/talk-llama/models/bert.cpp index 87331791418..7e046cfd2a4 100644 --- a/examples/talk-llama/models/bert.cpp +++ b/examples/talk-llama/models/bert.cpp @@ -2,7 +2,6 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -28,8 +27,8 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params cb(inpL, "inp_embd", -1); // embed layer norm - inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); - cb(inpL, "inp_norm", -1); + inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, 0); + cb(inpL, "inp_norm", 0); auto * inp_attn = build_attn_inp_no_cache(); @@ -39,35 +38,8 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params ggml_tensor * cur = inpL; { - ggml_tensor * Qcur; - ggml_tensor * Kcur; - ggml_tensor * Vcur; - - // self-attention - if (model.layers[il].wqkv) { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - if (model.layers[il].bqkv) { - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - } - - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1], - 0 * sizeof(float) * (n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd)); - Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); - } else { - Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq); - Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk); - Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - } + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (model.layers[il].attn_q_norm) { Qcur = ggml_reshape_2d(ctx0, Qcur, n_embd_head * n_head, n_tokens); @@ -100,7 +72,7 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); } diff --git a/examples/talk-llama/models/bitnet.cpp b/examples/talk-llama/models/bitnet.cpp index ccf5bc8e82b..71526354ca6 100644 --- a/examples/talk-llama/models/bitnet.cpp +++ b/examples/talk-llama/models/bitnet.cpp @@ -28,33 +28,8 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa // self-attention { - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - // B1.K - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - // B1.V - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -73,7 +48,7 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - NULL, NULL, + NULL, NULL, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cur = build_norm(cur, @@ -82,8 +57,8 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa cb(cur, "attn_sub_norm", il); cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s); - if (model.layers[il].bo) { - cur = ggml_add(ctx0, cur, model.layers[il].bo); + if (model.layers[il].wo_b) { + cur = ggml_add(ctx0, cur, model.layers[il].wo_b); } cb(cur, "attn_out", il); } @@ -121,6 +96,9 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "l_out", il); + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + // input for next layer inpL = cur; } diff --git a/examples/talk-llama/models/bloom.cpp b/examples/talk-llama/models/bloom.cpp index b1c19bb58a2..f3b0999bf54 100644 --- a/examples/talk-llama/models/bloom.cpp +++ b/examples/talk-llama/models/bloom.cpp @@ -2,7 +2,6 @@ llm_build_bloom::llm_build_bloom(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -16,8 +15,8 @@ llm_build_bloom::llm_build_bloom(const llama_model & model, const llm_graph_para inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, - LLM_NORM, -1); - cb(inpL, "inp_norm", -1); + LLM_NORM, 0); + cb(inpL, "inp_norm", 0); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -30,22 +29,11 @@ llm_build_bloom::llm_build_bloom(const llama_model & model, const llm_graph_para // self-attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/chameleon.cpp b/examples/talk-llama/models/chameleon.cpp index 2f24105fa14..21deaba1a6d 100644 --- a/examples/talk-llama/models/chameleon.cpp +++ b/examples/talk-llama/models/chameleon.cpp @@ -36,22 +36,10 @@ llm_build_chameleon::llm_build_chameleon(const llama_model & model, const llm_gr // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (model.layers[il].attn_q_norm) { - Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens, - ggml_element_size(Qcur) * n_embd_head, - ggml_element_size(Qcur) * n_embd_head * n_head, - 0); - cb(Qcur, "Qcur", il); - Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, @@ -60,12 +48,6 @@ llm_build_chameleon::llm_build_chameleon(const llama_model & model, const llm_gr } if (model.layers[il].attn_k_norm) { - Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens, - ggml_element_size(Kcur) * n_embd_head, - ggml_element_size(Kcur) * n_embd_head * n_head_kv, - 0); - cb(Kcur, "Kcur", il); - Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, @@ -73,10 +55,6 @@ llm_build_chameleon::llm_build_chameleon(const llama_model & model, const llm_gr cb(Kcur, "Kcur", il); } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -94,7 +72,7 @@ llm_build_chameleon::llm_build_chameleon(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, nullptr, + model.layers[il].wo, nullptr, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/chatglm.cpp b/examples/talk-llama/models/chatglm.cpp index 5887ed22e7e..7d4a43fdca5 100644 --- a/examples/talk-llama/models/chatglm.cpp +++ b/examples/talk-llama/models/chatglm.cpp @@ -3,7 +3,6 @@ llm_build_chatglm::llm_build_chatglm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -30,37 +29,8 @@ llm_build_chatglm::llm_build_chatglm(const llama_model & model, const llm_graph_ // self-attention { - ggml_tensor * Qcur = nullptr; - ggml_tensor * Kcur = nullptr; - ggml_tensor * Vcur = nullptr; - - if (model.layers[il].wqkv == nullptr) { - Qcur = build_lora_mm(model.layers[il].wq, cur); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - } - Kcur = build_lora_mm(model.layers[il].wk, cur); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - } - Vcur = build_lora_mm(model.layers[il].wv, cur); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - } else { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - if (model.layers[il].bqkv) { - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - } - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - } + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor); Qcur = ggml_rope_ext( @@ -80,7 +50,7 @@ llm_build_chatglm::llm_build_chatglm(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -111,8 +81,13 @@ llm_build_chatglm::llm_build_chatglm(const llama_model & model, const llm_graph_ } - inpL = ggml_add(ctx0, cur, ffn_inp); - cb(inpL, "l_out", il); + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; } cur = build_norm(inpL, diff --git a/examples/talk-llama/models/codeshell.cpp b/examples/talk-llama/models/codeshell.cpp index e8e13e143f2..3ceb5835b85 100644 --- a/examples/talk-llama/models/codeshell.cpp +++ b/examples/talk-llama/models/codeshell.cpp @@ -2,7 +2,6 @@ llm_build_codeshell::llm_build_codeshell(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); GGML_ASSERT(n_embd_head == n_rot); @@ -28,15 +27,8 @@ llm_build_codeshell::llm_build_codeshell(const llama_model & model, const llm_gr // self-attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -55,7 +47,7 @@ llm_build_codeshell::llm_build_codeshell(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/cogvlm.cpp b/examples/talk-llama/models/cogvlm.cpp index 2ef2b6e389b..be3eeeddac7 100644 --- a/examples/talk-llama/models/cogvlm.cpp +++ b/examples/talk-llama/models/cogvlm.cpp @@ -28,18 +28,20 @@ llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_pa for (int il = 0; il < n_layer; ++il) { // get either the text or image weight tensors - ggml_tensor *wqkv, *wo; + ggml_tensor *wqkv, *wo, *wo_s; ggml_tensor *ffn_gate, *ffn_down, *ffn_up; if (is_text) { wqkv = model.layers[il].wqkv; wo = model.layers[il].wo; + wo_s = model.layers[il].wo_s; ffn_gate = model.layers[il].ffn_gate; ffn_down = model.layers[il].ffn_down; ffn_up = model.layers[il].ffn_up; } else { wqkv = model.layers[il].visexp_attn_wqkv; wo = model.layers[il].visexp_attn_wo; + wo_s = nullptr; ffn_gate = model.layers[il].visexp_ffn_gate; ffn_down = model.layers[il].visexp_ffn_down; ffn_up = model.layers[il].visexp_ffn_up; @@ -64,7 +66,7 @@ llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_pa Kcur = ggml_rope(ctx0, Kcur, inp_pos, n_embd_head, rope_type); cur = build_attn(inp_attn, - wo, nullptr, + wo, nullptr, wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); @@ -86,6 +88,10 @@ llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_pa cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer inpL = cur; } diff --git a/examples/talk-llama/models/cohere2-iswa.cpp b/examples/talk-llama/models/cohere2-iswa.cpp index 7c71a59ae7f..670b08e7d97 100644 --- a/examples/talk-llama/models/cohere2-iswa.cpp +++ b/examples/talk-llama/models/cohere2-iswa.cpp @@ -36,30 +36,8 @@ llm_build_cohere2_iswa::llm_build_cohere2_iswa(const llama_model & model, const ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (is_swa) { Qcur = ggml_rope_ext( @@ -80,7 +58,7 @@ llm_build_cohere2_iswa::llm_build_cohere2_iswa(const llama_model & model, const cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/command-r.cpp b/examples/talk-llama/models/command-r.cpp index ba1230f0419..067961caa08 100644 --- a/examples/talk-llama/models/command-r.cpp +++ b/examples/talk-llama/models/command-r.cpp @@ -32,27 +32,8 @@ llm_build_command_r::llm_build_command_r(const llama_model & model, const llm_gr // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (model.layers[il].attn_q_norm) { Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM, il); @@ -73,7 +54,7 @@ llm_build_command_r::llm_build_command_r(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/dbrx.cpp b/examples/talk-llama/models/dbrx.cpp index 73eb5cd24e7..0e882721807 100644 --- a/examples/talk-llama/models/dbrx.cpp +++ b/examples/talk-llama/models/dbrx.cpp @@ -2,7 +2,6 @@ llm_build_dbrx::llm_build_dbrx(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); GGML_ASSERT(n_embd_head == n_rot); @@ -30,19 +29,8 @@ llm_build_dbrx::llm_build_dbrx(const llama_model & model, const llm_graph_params // self-attention { - ggml_tensor * Qcur = nullptr; - ggml_tensor * Kcur = nullptr; - ggml_tensor * Vcur = nullptr; - - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); - cb(cur, "wqkv_clamped", il); - - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -61,7 +49,7 @@ llm_build_dbrx::llm_build_dbrx(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/deci.cpp b/examples/talk-llama/models/deci.cpp index ac448bfcaa8..30272eabd69 100644 --- a/examples/talk-llama/models/deci.cpp +++ b/examples/talk-llama/models/deci.cpp @@ -1,7 +1,5 @@ #include "models.h" - - llm_build_deci::llm_build_deci(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); @@ -47,27 +45,8 @@ llm_build_deci::llm_build_deci(const llama_model & model, const llm_graph_params ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -80,7 +59,7 @@ llm_build_deci::llm_build_deci(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/deepseek.cpp b/examples/talk-llama/models/deepseek.cpp index 3432359e03a..671b72dfead 100644 --- a/examples/talk-llama/models/deepseek.cpp +++ b/examples/talk-llama/models/deepseek.cpp @@ -35,27 +35,8 @@ llm_build_deepseek::llm_build_deepseek(const llama_model & model, const llm_grap ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -68,7 +49,7 @@ llm_build_deepseek::llm_build_deepseek(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/deepseek2.cpp b/examples/talk-llama/models/deepseek2.cpp index d437fe29e71..303fc72c610 100644 --- a/examples/talk-llama/models/deepseek2.cpp +++ b/examples/talk-llama/models/deepseek2.cpp @@ -2,6 +2,9 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B + bool is_ocr = model.arch == LLM_ARCH_DEEPSEEK2OCR; + const bool is_mla = hparams.is_mla(); // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA @@ -54,7 +57,38 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr cb(cur, "attn_norm", il); // self_attention - { + if (is_ocr) { + const int n_embed_head = hparams.n_embd / hparams.n_head(); + const int ocr_rope_type = GGML_ROPE_TYPE_NEOX; + GGML_ASSERT(n_embed_head == n_embd_head_k && n_embed_head == n_embd_head_v); + + ggml_tensor * Qcur = NULL; + ggml_tensor * Kcur = NULL; + ggml_tensor * Vcur = NULL; + + Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Qcur, "q", il); + cb(Kcur, "k", il); + cb(Vcur, "v", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embed_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embed_head, n_head, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embed_head, n_head, n_tokens); + + GGML_ASSERT(fabs(freq_base - 10000.0) < 1e-4); + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_embed_head, ocr_rope_type, 0, freq_base, 1, 0, 1, 0, 0); + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_embed_head, ocr_rope_type, 0, freq_base, 1, 0, 1, 0, 0); + cb(Qcur, "q_pe", il); + cb(Kcur, "k_pe", il); + + cur = build_attn(inp_attn_kv, + model.layers[il].wo, NULL, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + else { ggml_tensor * q = NULL; const bool is_lite = model.layers[il].wq; @@ -148,7 +182,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr // note: MLA with the absorption optimization converts into MQA (ie: GQA with 1 group) cur = build_attn(inp_attn_k, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, kq_scale, il); } else { ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cmpr); @@ -185,7 +219,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr // note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups) cur = build_attn(inp_attn_kv, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } } diff --git a/examples/talk-llama/models/dots1.cpp b/examples/talk-llama/models/dots1.cpp index 07236dd27c9..5d1750fedda 100644 --- a/examples/talk-llama/models/dots1.cpp +++ b/examples/talk-llama/models/dots1.cpp @@ -29,18 +29,8 @@ llm_build_dots1::llm_build_dots1(const llama_model & model, const llm_graph_para // self_attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -59,7 +49,7 @@ llm_build_dots1::llm_build_dots1(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/dream.cpp b/examples/talk-llama/models/dream.cpp index 4edc8530cb3..8e7d9ae64c7 100644 --- a/examples/talk-llama/models/dream.cpp +++ b/examples/talk-llama/models/dream.cpp @@ -1,7 +1,5 @@ #include "models.h" - - llm_build_dream::llm_build_dream(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { //copied from qwen2 @@ -31,22 +29,8 @@ llm_build_dream::llm_build_dream(const llama_model & model, const llm_graph_para // self-attention { - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -59,7 +43,7 @@ llm_build_dream::llm_build_dream(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/ernie4-5-moe.cpp b/examples/talk-llama/models/ernie4-5-moe.cpp index 63baf152c40..fc6a3e17a09 100644 --- a/examples/talk-llama/models/ernie4-5-moe.cpp +++ b/examples/talk-llama/models/ernie4-5-moe.cpp @@ -30,27 +30,8 @@ llm_build_ernie4_5_moe::llm_build_ernie4_5_moe(const llama_model & model, const // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -63,7 +44,7 @@ llm_build_ernie4_5_moe::llm_build_ernie4_5_moe(const llama_model & model, const cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/ernie4-5.cpp b/examples/talk-llama/models/ernie4-5.cpp index d548de0547b..033ba409eab 100644 --- a/examples/talk-llama/models/ernie4-5.cpp +++ b/examples/talk-llama/models/ernie4-5.cpp @@ -29,27 +29,8 @@ llm_build_ernie4_5::llm_build_ernie4_5(const llama_model & model, const llm_grap } // self-attention { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -62,7 +43,7 @@ llm_build_ernie4_5::llm_build_ernie4_5(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { diff --git a/examples/talk-llama/models/eurobert.cpp b/examples/talk-llama/models/eurobert.cpp index e8628d165d0..43fff4daf3a 100644 --- a/examples/talk-llama/models/eurobert.cpp +++ b/examples/talk-llama/models/eurobert.cpp @@ -24,17 +24,8 @@ llm_build_eurobert::llm_build_eurobert(const llama_model & model, const llm_grap LLM_NORM_RMS, il); { - ggml_tensor * Qcur; - ggml_tensor * Kcur; - ggml_tensor * Vcur; - - Qcur = build_lora_mm(model.layers[il].wq, cur); - Kcur = build_lora_mm(model.layers[il].wk, cur); - Vcur = build_lora_mm(model.layers[il].wv, cur); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -53,7 +44,7 @@ llm_build_eurobert::llm_build_eurobert(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, nullptr, + model.layers[il].wo, nullptr, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); } @@ -82,6 +73,7 @@ llm_build_eurobert::llm_build_eurobert(const llama_model & model, const llm_grap cur = ggml_add(ctx0, cur, ffn_inp); + // input for next layer inpL = cur; } cur = inpL; diff --git a/examples/talk-llama/models/exaone-moe.cpp b/examples/talk-llama/models/exaone-moe.cpp index ea75701c528..7b88a31d39d 100644 --- a/examples/talk-llama/models/exaone-moe.cpp +++ b/examples/talk-llama/models/exaone-moe.cpp @@ -35,18 +35,8 @@ llm_build_exaone_moe::llm_build_exaone_moe(const llama_model & model, const llm_ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); @@ -65,7 +55,7 @@ llm_build_exaone_moe::llm_build_exaone_moe(const llama_model & model, const llm_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn_iswa, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/exaone.cpp b/examples/talk-llama/models/exaone.cpp index d4eea58e2f1..4f845bf4106 100644 --- a/examples/talk-llama/models/exaone.cpp +++ b/examples/talk-llama/models/exaone.cpp @@ -1,7 +1,5 @@ #include "models.h" - - llm_build_exaone::llm_build_exaone(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); @@ -34,27 +32,8 @@ llm_build_exaone::llm_build_exaone(const llama_model & model, const llm_graph_pa ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -67,7 +46,7 @@ llm_build_exaone::llm_build_exaone(const llama_model & model, const llm_graph_pa cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/exaone4.cpp b/examples/talk-llama/models/exaone4.cpp index 755af3b747b..34bee3b8fe9 100644 --- a/examples/talk-llama/models/exaone4.cpp +++ b/examples/talk-llama/models/exaone4.cpp @@ -1,6 +1,5 @@ #include "models.h" - template llm_build_exaone4::llm_build_exaone4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { @@ -39,18 +38,8 @@ llm_build_exaone4::llm_build_exaone4(const llama_model & model, const llm_ { ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); @@ -69,7 +58,7 @@ llm_build_exaone4::llm_build_exaone4(const llama_model & model, const llm_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/falcon-h1.cpp b/examples/talk-llama/models/falcon-h1.cpp index ff842d93a41..05accf90fad 100644 --- a/examples/talk-llama/models/falcon-h1.cpp +++ b/examples/talk-llama/models/falcon-h1.cpp @@ -27,19 +27,8 @@ llm_build_falcon_h1::llm_build_falcon_h1(const llama_model & model, const llm_gr cb(cur, "attn_norm", il); // self-attention - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, hparams.rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -52,7 +41,7 @@ llm_build_falcon_h1::llm_build_falcon_h1(const llama_model & model, const llm_gr cb(Vcur, "Vcur-post-rope", il); ggml_tensor * attn_out = build_attn(inp->get_attn(), - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(attn_out, "attn_out", il); diff --git a/examples/talk-llama/models/falcon.cpp b/examples/talk-llama/models/falcon.cpp index 9fcba508878..2f65fa56e1f 100644 --- a/examples/talk-llama/models/falcon.cpp +++ b/examples/talk-llama/models/falcon.cpp @@ -1,9 +1,7 @@ #include "models.h" - llm_build_falcon::llm_build_falcon(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); GGML_ASSERT(n_embd_head == n_rot); @@ -42,12 +40,8 @@ llm_build_falcon::llm_build_falcon(const llama_model & model, const llm_graph_pa cur = attn_norm; } - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // using mode = 2 for neox mode Qcur = ggml_rope_ext( @@ -67,7 +61,7 @@ llm_build_falcon::llm_build_falcon(const llama_model & model, const llm_graph_pa cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/gemma-embedding.cpp b/examples/talk-llama/models/gemma-embedding.cpp index 98110d45e3b..b6de9551c52 100644 --- a/examples/talk-llama/models/gemma-embedding.cpp +++ b/examples/talk-llama/models/gemma-embedding.cpp @@ -9,7 +9,7 @@ llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, inpL = build_inp_embd(model.tok_embd); - // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) + // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings) inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f); cb(inpL, "inp_scaled", -1); @@ -31,18 +31,8 @@ llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -65,7 +55,7 @@ llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } diff --git a/examples/talk-llama/models/gemma.cpp b/examples/talk-llama/models/gemma.cpp index 1869efd389a..09d2ff8bae7 100644 --- a/examples/talk-llama/models/gemma.cpp +++ b/examples/talk-llama/models/gemma.cpp @@ -1,6 +1,5 @@ #include "models.h" - llm_build_gemma::llm_build_gemma(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); @@ -29,18 +28,8 @@ llm_build_gemma::llm_build_gemma(const llama_model & model, const llm_graph_para // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -60,7 +49,7 @@ llm_build_gemma::llm_build_gemma(const llama_model & model, const llm_graph_para cb(Qcur, "Qcur_scaled", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/gemma2-iswa.cpp b/examples/talk-llama/models/gemma2-iswa.cpp index 3927ddd297b..0ef07df8d01 100644 --- a/examples/talk-llama/models/gemma2-iswa.cpp +++ b/examples/talk-llama/models/gemma2-iswa.cpp @@ -31,18 +31,8 @@ llm_build_gemma2_iswa::llm_build_gemma2_iswa(const llama_model & model, const ll // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -61,7 +51,7 @@ llm_build_gemma2_iswa::llm_build_gemma2_iswa(const llama_model & model, const ll Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/gemma3.cpp b/examples/talk-llama/models/gemma3.cpp index bbb4d9a81e8..0da4af21c17 100644 --- a/examples/talk-llama/models/gemma3.cpp +++ b/examples/talk-llama/models/gemma3.cpp @@ -9,7 +9,7 @@ llm_build_gemma3::llm_build_gemma3(const llama_model & model, const llm_gr inpL = build_inp_embd(model.tok_embd); - // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) + // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings) inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f); cb(inpL, "inp_scaled", -1); @@ -47,18 +47,8 @@ llm_build_gemma3::llm_build_gemma3(const llama_model & model, const llm_gr // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -84,7 +74,7 @@ llm_build_gemma3::llm_build_gemma3(const llama_model & model, const llm_gr Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/gemma3n-iswa.cpp b/examples/talk-llama/models/gemma3n-iswa.cpp index 8ce2ae39c2f..f8095417e06 100644 --- a/examples/talk-llama/models/gemma3n-iswa.cpp +++ b/examples/talk-llama/models/gemma3n-iswa.cpp @@ -1,5 +1,12 @@ #include "models.h" +// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim +static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, int idx) { + GGML_ASSERT(idx < (int) x->ne[2]); + return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]), + idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); +} + llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model), @@ -12,7 +19,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const inpL = build_inp_embd(model.tok_embd); - // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) + // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings) inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f); cb(inpL, "inp_scaled", -1); @@ -22,8 +29,11 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const // TODO: is causal == true correct? might need some changes auto * inp_attn = build_attn_inp_kv_iswa(); - // inp_per_layer shape: [n_embd_altup, n_tokens, n_layer] - ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs()); + ggml_tensor * inp_per_layer = build_inp_per_layer(); + ggml_build_forward_expand(gf, inp_per_layer); + + // inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer] + inp_per_layer = project_per_layer_inputs(inpL, inp_per_layer); // inpL now has only 1 altup, project it to the rest of the altups // these "added" altups will be concat to the last dim of inpL @@ -37,8 +47,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const inpL = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup] cb(inpL, "inp_stacked", -1); } - // inpL now has shape: [n_embd, n_tokens, n_altup] - // inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer] + // inpL now has shape: [n_embd, n_tokens, n_altup] for (int il = 0; il < n_layer; ++il) { // this block is made to be closely resemble Gemma3p5DecoderLayer on python code @@ -49,8 +58,8 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const ggml_tensor * predictions = altup_predict(cur, il); // [n_embd, n_tokens, n_altup] // predicted value will go through self-attention and laurel - ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); // [n_embd, n_tokens] - cur = active_prediction; + ggml_tensor * active_prediction = ggml_view_2d_slice(ctx0, predictions, i_altup_act); // [n_embd, n_tokens] + cur = active_prediction; cb(cur, "active_prediction", il); // norm @@ -62,19 +71,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const // self-attention if (hparams.has_kv(il)) { - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); @@ -94,7 +91,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const cb(Kcur, "Kcur_pos", il); cur = build_attn(inp_attn, model.layers[il].wo, - NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, + NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); } else { // reuse KV cache of earlier layers @@ -110,7 +107,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const cb(Qcur, "Qcur_pos", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); } cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); @@ -151,12 +148,13 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const ggml_tensor * first_prediction; // [n_embd, n_tokens] { - first_prediction = view_2d_slice(corrected, i_altup_act); // [n_embd, n_tokens] + first_prediction = ggml_view_2d_slice(ctx0, corrected, i_altup_act); // [n_embd, n_tokens] first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale); first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction); first_prediction = ggml_gelu(ctx0, first_prediction); // [n_embd_altup, n_tokens] cb(first_prediction, "first_prediction_gated", il); - ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_altup, n_tokens] + + ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_altup, n_tokens] first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer); // [n_embd_altup, n_tokens] cb(first_prediction, "first_prediction_scaled", il); @@ -167,7 +165,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const } // equivalent to python code: corrected_predictions[1:] += first_prediction { - ggml_tensor * slice_first = view_2d_slice(corrected, 0); + ggml_tensor * slice_first = ggml_view_2d_slice(ctx0, corrected, 0); ggml_tensor * slice_rest = ggml_view_3d( ctx0, corrected, n_embd, n_tokens, n_altup - 1, ggml_row_size(corrected->type, n_embd), ggml_row_size(corrected->type, n_embd * n_tokens), n_embd * n_tokens * ggml_element_size(corrected)); @@ -185,7 +183,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const // cur now has multiple altup(s), we want to merge them back to 1 altup { - ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act)); // [n_embd, n_tokens] + ggml_tensor * target_magnitude = calc_magnitude(ggml_view_2d_slice(ctx0, cur, i_altup_act)); // [n_embd, n_tokens] // do a view to skip the first slice (active altup) ggml_tensor * alt_slice = ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1, ggml_row_size(cur->type, n_embd), @@ -197,9 +195,9 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const cb(altup_unembd, "altup_unembd", -1); // equivalent to torch.mean(hidden_states, dim=0) - cur = view_2d_slice(cur, 0); // [n_embd, n_tokens] + cur = ggml_view_2d_slice(ctx0, cur, 0); // [n_embd, n_tokens] for (int i = 0; i < n_altup - 1; ++i) { - cur = ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i)); + cur = ggml_add(ctx0, cur, ggml_view_2d_slice(ctx0, altup_unembd, i)); } cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup)); // [n_embd, n_tokens] cb(cur, "unembd_merged", -1); @@ -235,39 +233,34 @@ ggml_tensor * llm_build_gemma3n_iswa::calc_magnitude(ggml_tensor * x) { return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x))); } -// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim -ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice(ggml_tensor * x, int idx) { - GGML_ASSERT(idx < (int) x->ne[2]); - return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]), - idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); -} - // equivalent to get_per_layer_inputs() in python code // output shape: [n_embd_altup, n_layer, n_tokens] -ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() { +ggml_tensor * llm_build_gemma3n_iswa::build_inp_per_layer() { auto inp = std::make_unique(n_embd); ggml_tensor * inp_per_layer; + float tok_embd_scale = sqrtf((float) n_embd_altup); if (ubatch.token) { inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); ggml_set_input(inp->tokens); res->t_inp_tokens = inp->tokens; - inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens); + inp_per_layer = ggml_get_rows (ctx0, model.per_layer_tok_embd, inp->tokens); inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens); - inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup)); + inp_per_layer = ggml_scale (ctx0, inp_per_layer, tok_embd_scale); cb(inp_per_layer, "inp_per_layer_selected", -1); res->add_input(std::move(inp)); } else { - // Vision embedding path: use padding token (ID=0) embedding + // Multimodal embedding path: use padding token (ID=0) embedding // TODO: verify if this is the correct behavior in transformers implementation - const int64_t embd_size = model.tok_embd_per_layer->ne[0]; // n_embd_altup * n_layer + const int64_t embd_size = model.per_layer_tok_embd->ne[0]; // n_embd_altup * n_layer // Extract and dequantize padding token embedding (row 0) - ggml_tensor * padding = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0); - inp_per_layer = ggml_cast(ctx0, padding, GGML_TYPE_F32); + ggml_tensor * padding = ggml_view_1d(ctx0, model.per_layer_tok_embd, embd_size, 0); + inp_per_layer = ggml_cast (ctx0, padding, GGML_TYPE_F32); + inp_per_layer = ggml_scale(ctx0, inp_per_layer, tok_embd_scale); // Reshape to [n_embd_altup, n_layer, 1] inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, 1); - cb(inp_per_layer, "inp_per_layer_vision", -1); + cb(inp_per_layer, "inp_per_layer_multimodal", -1); } return inp_per_layer; } @@ -275,18 +268,19 @@ ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() { // equivalent to project_per_layer_inputs() in python code // this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim // output shape: [n_embd_altup, n_tokens, n_layer] -ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) { +ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer) { const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd); const float per_layer_input_scale = 1.0f / sqrtf(2.0f); - ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds); - per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale); - per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens); - per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, NULL, LLM_NORM_RMS, - -1); // [n_embd_altup, n_layer, n_tokens] + ggml_tensor * per_layer_proj; + per_layer_proj = ggml_mul_mat (ctx0, model.per_layer_model_proj, inp_batch); + per_layer_proj = ggml_scale (ctx0, per_layer_proj, per_layer_projection_scale); + per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens); + + per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, NULL, LLM_NORM_RMS, -1); cb(per_layer_proj, "per_layer_proj", -1); - inp_per_layer = ggml_add(ctx0, per_layer_proj, inp_per_layer); + inp_per_layer = ggml_add (ctx0, per_layer_proj, inp_per_layer); inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale); cb(inp_per_layer, "inp_per_layer", -1); @@ -337,7 +331,7 @@ ggml_tensor * llm_build_gemma3n_iswa::altup_compute_router_modalities(ggml_tenso // input cur shape: [n_embd, n_tokens, n_altup] // output shape: [n_embd, n_tokens, n_altup] ggml_tensor * llm_build_gemma3n_iswa::altup_predict(ggml_tensor * cur, int il) { - ggml_tensor * activated = view_2d_slice(cur, i_altup_act); // [n_embd, n_tokens] + ggml_tensor * activated = ggml_view_2d_slice(ctx0, cur, i_altup_act); // [n_embd, n_tokens] ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens] cb(modalities, "modalities", il); @@ -365,7 +359,7 @@ ggml_tensor * llm_build_gemma3n_iswa::altup_correct(ggml_tensor * predictions, g ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens] cb(modalities, "modalities", il); - ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); + ggml_tensor * active_prediction = ggml_view_2d_slice(ctx0, predictions, i_altup_act); ggml_tensor * innovation = ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens] cb(innovation, "innovation", il); diff --git a/examples/talk-llama/models/gemma4-iswa.cpp b/examples/talk-llama/models/gemma4-iswa.cpp new file mode 100644 index 00000000000..c7fb7747414 --- /dev/null +++ b/examples/talk-llama/models/gemma4-iswa.cpp @@ -0,0 +1,322 @@ +#include "models.h" + +// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim +static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, int idx) { + GGML_ASSERT(idx < (int) x->ne[2]); + return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]), + idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); +} + +llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params), + model(model), + n_embd_per_layer(model.hparams.n_embd_per_layer) { + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) + inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f); + cb(inpL, "inp_scaled", -1); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + // TODO: is causal == true correct? might need some changes + auto * inp_attn = build_attn_inp_kv_iswa(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + ggml_tensor * inp_per_layer = nullptr; + if (model.per_layer_tok_embd) { + inp_per_layer = build_inp_per_layer(); + ggml_build_forward_expand(gf, inp_per_layer); + + // inp_per_layer shape: [n_embd_per_layer, n_tokens, n_layer] + inp_per_layer = project_per_layer_inputs(inpL, inp_per_layer); + } + + for (int il = 0; il < n_layer; ++il) { + const int64_t n_embd_head = hparams.n_embd_head_k(il); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_v(il)); + + const int64_t n_head = hparams.n_head(il); + const int64_t n_head_kv = hparams.n_head_kv(il); + + const float freq_base_l = model.get_rope_freq_base(cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + const int n_rot_l = hparams.n_rot(il); + + // norm + cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + ggml_tensor * freq_factors = nullptr; + if (!hparams.is_swa(il)) { + // full_attention layers use rope_freqs for proportional rope + freq_factors = model.layers[il].rope_freqs; + } + + // Q projection (shared for both non-KV and KV layers) + // this is to mirror Gemma4Attention in pytorch code + ggml_tensor * Qcur; + { + Qcur = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); + cb(Qcur, "Qcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "Qcur_pos", il); + } + + // self-attention + if (hparams.has_kv(il)) { + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = model.layers[il].wv + ? build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s) + : Kcur; // if v_proj is not present, use Kcur as Vcur + cb(Vcur, "Vcur", il); + + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); + Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps); + + cb(Kcur, "Kcur_normed", il); + cb(Vcur, "Vcur_normed", il); + + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + + cb(Kcur, "Kcur_pos", il); + + cur = build_attn(inp_attn, model.layers[il].wo, + nullptr, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, + hparams.f_attention_scale, il); + } else { + // reuse KV cache of earlier layers + cur = build_attn(inp_attn, + model.layers[il].wo, nullptr, model.layers[il].wo_s, + Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); + } + + // TODO @ngxson : strip unused token right after the last KV layer to speed up prompt processing + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + cur = build_norm(cur, + model.layers[il].attn_post_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + ggml_tensor * attn_out = ggml_add(ctx0, cur, inpL); + cb(attn_out, "attn_out", il); + + // feed-forward network + const bool is_moe_layer = model.layers[il].ffn_gate_inp != nullptr; + if (is_moe_layer) { + // MLP (shared exp) + ggml_tensor * cur_mlp = build_norm(attn_out, + model.layers[il].ffn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur_mlp, "ffn_norm_1", il); + + cur_mlp = build_ffn(cur_mlp, + model.layers[il].ffn_up, nullptr, model.layers[il].ffn_up_s, + model.layers[il].ffn_gate, nullptr, model.layers[il].ffn_gate_s, + model.layers[il].ffn_down, nullptr, model.layers[il].ffn_down_s, + nullptr, + LLM_FFN_GELU, LLM_FFN_PAR, il); + cur_mlp = build_norm(cur_mlp, + model.layers[il].ffn_post_norm_1, nullptr, + LLM_NORM_RMS, il); + cb(cur_mlp, "ffn_mlp", il); + + // Expert FFN + ggml_tensor * cur_moe = build_norm(attn_out, + model.layers[il].ffn_pre_norm_2, nullptr, + LLM_NORM_RMS, il); + cb(cur_moe, "ffn_norm_2", il); + + // custom MoE logits calculation (router operates on attn_out, not cur) + ggml_tensor * tmp = ggml_rms_norm(ctx0, attn_out, hparams.f_norm_rms_eps); + tmp = ggml_scale(ctx0, tmp, 1.0f / sqrtf((float) n_embd)); + tmp = ggml_mul(ctx0, tmp, model.layers[il].ffn_gate_inp_s); + ggml_tensor * logits = build_lora_mm(model.layers[il].ffn_gate_inp, tmp); // [n_expert, n_tokens] + cb(logits, "ffn_moe_logits", il); + + cur_moe = build_moe_ffn(cur_moe, + nullptr, // gate_inp + nullptr, // up_exps + nullptr, // gate_exps + model.layers[il].ffn_down_exps, + nullptr, // exp_probs_b (not used for gemma4) + n_expert, n_expert_used, + LLM_FFN_GELU, true, + 1.0f, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il, logits, + model.layers[il].ffn_gate_up_exps, + nullptr, // up_exps_s + nullptr, // gate_exps_s + model.layers[il].ffn_down_exps_s); + cur_moe = build_norm(cur_moe, + model.layers[il].ffn_post_norm_2, nullptr, + LLM_NORM_RMS, il); + cb(cur_moe, "ffn_moe", il); + + cur = ggml_add(ctx0, cur_mlp, cur_moe); + cb(cur, "ffn_moe_combined", il); + } else { + cur = build_norm(attn_out, + model.layers[il].ffn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, nullptr, model.layers[il].ffn_up_s, + model.layers[il].ffn_gate, nullptr, model.layers[il].ffn_gate_s, + model.layers[il].ffn_down, nullptr, model.layers[il].ffn_down_s, + nullptr, + LLM_FFN_GELU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + cur = build_norm(cur, + model.layers[il].ffn_post_norm, nullptr, + LLM_NORM_RMS, -1); + cb(cur, "ffn_post_norm", il); + + // residual connection + cur = ggml_add(ctx0, cur, attn_out); + + // per-layer embedding + if (inp_per_layer) { + ggml_tensor * pe_in = cur; + cb(cur, "pe_in", il); + + cur = build_lora_mm(model.layers[il].per_layer_inp_gate, cur); // [n_embd_per_layer, n_tokens] + cur = ggml_gelu(ctx0, cur); + + ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_per_layer, n_tokens] + + // TODO @ngxson : improve this + if (il == n_layer - 1 && inp_out_ids) { + inp_this_layer = ggml_get_rows(ctx0, inp_this_layer, inp_out_ids); + } + + cur = ggml_mul(ctx0, cur, inp_this_layer); + cur = build_lora_mm(model.layers[il].per_layer_proj, cur); // [n_embd, n_tokens] + cur = build_norm(cur, model.layers[il].per_layer_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "per_layer_embd_out", il); + + // residual connection + cur = ggml_add(ctx0, pe_in, cur); + } + + // layer_scalar + if (model.layers[il].out_scale) { + cur = ggml_mul(ctx0, cur, model.layers[il].out_scale); + cb(cur, "out_scaled", il); + } + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, nullptr, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + if (hparams.f_final_logit_softcapping) { + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); + cur = ggml_tanh(ctx0, cur); + cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping); + } + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} + +// equivalent to get_per_layer_inputs() in python code +// output shape: [n_embd_per_layer, n_layer, n_tokens] +ggml_tensor * llm_build_gemma4_iswa::build_inp_per_layer() { + auto inp = std::make_unique(n_embd); + + ggml_tensor * inp_per_layer; + float tok_embd_scale = sqrtf((float) n_embd_per_layer); + if (ubatch.token) { + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); + ggml_set_input(inp->tokens); + res->t_inp_tokens = inp->tokens; + + inp_per_layer = ggml_get_rows (ctx0, model.per_layer_tok_embd, inp->tokens); + inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_per_layer, n_layer, n_tokens); + inp_per_layer = ggml_scale (ctx0, inp_per_layer, tok_embd_scale); + cb(inp_per_layer, "inp_per_layer_selected", -1); + + res->add_input(std::move(inp)); + } else { + // Multimodal embedding path: use padding token (ID=0) embedding + // TODO: verify if this is the correct behavior in transformers implementation + const int64_t embd_size = model.per_layer_tok_embd->ne[0]; // n_embd_per_layer * n_layer + + // Extract and dequantize padding token embedding (row 0) + ggml_tensor * padding = ggml_view_1d(ctx0, model.per_layer_tok_embd, embd_size, 0); + inp_per_layer = ggml_cast (ctx0, padding, GGML_TYPE_F32); + inp_per_layer = ggml_scale(ctx0, inp_per_layer, tok_embd_scale); + + // Reshape to [n_embd_per_layer, n_layer, 1] + inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_per_layer, n_layer, 1); + cb(inp_per_layer, "inp_per_layer_multimodal", -1); + } + return inp_per_layer; +} + +// equivalent to project_per_layer_inputs() in python code +// this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim +// inp_batch shape: [n_embd, n_tokens] +// inp_per_layer shape: [n_embd_per_layer, n_layer, n_tokens] (from build_inp_per_layer) +// output shape: [n_embd_per_layer, n_tokens, n_layer] +ggml_tensor * llm_build_gemma4_iswa::project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer) { + const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd); + const float per_layer_input_scale = 1.0f / sqrtf(2.0f); + + // note: this matrix multiplication will be performed in the input layer (i.e. on the CPU) + ggml_tensor * per_layer_proj; + per_layer_proj = ggml_mul_mat (ctx0, model.per_layer_model_proj, inp_batch); + per_layer_proj = ggml_scale (ctx0, per_layer_proj, per_layer_projection_scale); + per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_per_layer, n_layer, n_tokens); + + per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, nullptr, LLM_NORM_RMS, -1); + cb(per_layer_proj, "per_layer_proj", -1); + + inp_per_layer = ggml_add (ctx0, per_layer_proj, inp_per_layer); + inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale); + cb(inp_per_layer, "inp_per_layer", -1); + + // permute to shape: [n_embd_per_layer, n_tokens, n_layer] + inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3)); + return inp_per_layer; +} diff --git a/examples/talk-llama/models/glm4-moe.cpp b/examples/talk-llama/models/glm4-moe.cpp index 7938545ed8a..8d4f4a01553 100644 --- a/examples/talk-llama/models/glm4-moe.cpp +++ b/examples/talk-llama/models/glm4-moe.cpp @@ -38,27 +38,8 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap // self-attention { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - } - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - } - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - } - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // Apply Q/K norm if available (GLM-4.5 355B variant) if (model.layers[il].attn_q_norm) { @@ -94,7 +75,7 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_transformer_layers - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/glm4.cpp b/examples/talk-llama/models/glm4.cpp index b6ad8febed3..f0bfda393fa 100644 --- a/examples/talk-llama/models/glm4.cpp +++ b/examples/talk-llama/models/glm4.cpp @@ -1,10 +1,7 @@ #include "models.h" - - llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -41,40 +38,8 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params // self-attention { - ggml_tensor * Qcur = nullptr; - ggml_tensor * Kcur = nullptr; - ggml_tensor * Vcur = nullptr; - - if (model.layers[il].wqkv == nullptr) { - Qcur = build_lora_mm(model.layers[il].wq, cur); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - } - Kcur = build_lora_mm(model.layers[il].wk, cur); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - } - Vcur = build_lora_mm(model.layers[il].wv, cur); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - } else { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - if (model.layers[il].bqkv) { - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - } - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1], - 0 * sizeof(float) * (n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd)); - Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); - } + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (use_mrope) { Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, @@ -100,7 +65,7 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_transformer_layers - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/gpt2.cpp b/examples/talk-llama/models/gpt2.cpp index cb1238f2d34..f8dc53eb723 100644 --- a/examples/talk-llama/models/gpt2.cpp +++ b/examples/talk-llama/models/gpt2.cpp @@ -2,7 +2,6 @@ llm_build_gpt2::llm_build_gpt2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -34,22 +33,11 @@ llm_build_gpt2::llm_build_gpt2(const llama_model & model, const llm_graph_params // self-attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/gptneox.cpp b/examples/talk-llama/models/gptneox.cpp index 1c8fe6c836d..0016ddede43 100644 --- a/examples/talk-llama/models/gptneox.cpp +++ b/examples/talk-llama/models/gptneox.cpp @@ -1,9 +1,7 @@ #include "models.h" - llm_build_gptneox::llm_build_gptneox(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -28,15 +26,8 @@ llm_build_gptneox::llm_build_gptneox(const llama_model & model, const llm_graph_ // self-attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -55,7 +46,7 @@ llm_build_gptneox::llm_build_gptneox(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/granite-hybrid.cpp b/examples/talk-llama/models/granite-hybrid.cpp index 9b54a38c386..e983742bef5 100644 --- a/examples/talk-llama/models/granite-hybrid.cpp +++ b/examples/talk-llama/models/granite-hybrid.cpp @@ -73,31 +73,7 @@ ggml_tensor * llm_build_granite_hybrid::build_attention_layer(ggml_tensor * const llama_model & model, const int64_t n_embd_head, const int il) { - // compute Q and K and (optionally) RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, n_embd_head, hparams.n_head(il), hparams.n_head_kv(il), il); const bool use_rope = hparams.rope_finetuned; if (use_rope) { @@ -116,7 +92,7 @@ ggml_tensor * llm_build_granite_hybrid::build_attention_layer(ggml_tensor * const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); return cur; diff --git a/examples/talk-llama/models/granite.cpp b/examples/talk-llama/models/granite.cpp index 7a7e1664c29..6ea90285225 100644 --- a/examples/talk-llama/models/granite.cpp +++ b/examples/talk-llama/models/granite.cpp @@ -76,31 +76,8 @@ ggml_tensor * llm_build_granite::build_attention_layer( const int64_t n_embd_head, const int il) { - // compute Q and K and (optionally) RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, hparams.n_head(il), hparams.n_head_kv(il), il); const bool use_rope = hparams.rope_finetuned; if (use_rope) { @@ -124,7 +101,7 @@ ggml_tensor * llm_build_granite::build_attention_layer( const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); return cur; diff --git a/examples/talk-llama/models/grok.cpp b/examples/talk-llama/models/grok.cpp index 580d63e36ae..b8f35afdc03 100644 --- a/examples/talk-llama/models/grok.cpp +++ b/examples/talk-llama/models/grok.cpp @@ -30,27 +30,8 @@ llm_build_grok::llm_build_grok(const llama_model & model, const llm_graph_params // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -69,7 +50,7 @@ llm_build_grok::llm_build_grok(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/grovemoe.cpp b/examples/talk-llama/models/grovemoe.cpp index aa60d3e9388..151108a2a71 100644 --- a/examples/talk-llama/models/grovemoe.cpp +++ b/examples/talk-llama/models/grovemoe.cpp @@ -30,18 +30,8 @@ llm_build_grovemoe::llm_build_grovemoe(const llama_model & model, const llm_grap // self_attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -60,7 +50,7 @@ llm_build_grovemoe::llm_build_grovemoe(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/hunyuan-dense.cpp b/examples/talk-llama/models/hunyuan-dense.cpp index 6a51707c85b..1cd85d6d9d4 100644 --- a/examples/talk-llama/models/hunyuan-dense.cpp +++ b/examples/talk-llama/models/hunyuan-dense.cpp @@ -6,6 +6,11 @@ llm_build_hunyuan_dense::llm_build_hunyuan_dense(const llama_model & model, cons GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); GGML_ASSERT(n_embd_head == n_rot); + const bool use_mrope = hparams.use_mrope(); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + ggml_tensor * cur; ggml_tensor * inpL; @@ -34,44 +39,39 @@ llm_build_hunyuan_dense::llm_build_hunyuan_dense(const llama_model & model, cons ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); + + if (use_mrope) { + Qcur = ggml_rope_multi( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_multi( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } else { + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); @@ -83,7 +83,7 @@ llm_build_hunyuan_dense::llm_build_hunyuan_dense(const llama_model & model, cons cb(Qcur, "Qcur_norm", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/hunyuan-moe.cpp b/examples/talk-llama/models/hunyuan-moe.cpp index 806c30b3667..ffe1664b0e1 100644 --- a/examples/talk-llama/models/hunyuan-moe.cpp +++ b/examples/talk-llama/models/hunyuan-moe.cpp @@ -35,27 +35,8 @@ llm_build_hunyuan_moe::llm_build_hunyuan_moe(const llama_model & model, const ll ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, @@ -84,7 +65,7 @@ llm_build_hunyuan_moe::llm_build_hunyuan_moe(const llama_model & model, const ll cb(Qcur, "Qcur_norm", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/internlm2.cpp b/examples/talk-llama/models/internlm2.cpp index 441d250268e..83be2ca0aee 100644 --- a/examples/talk-llama/models/internlm2.cpp +++ b/examples/talk-llama/models/internlm2.cpp @@ -30,27 +30,8 @@ llm_build_internlm2::llm_build_internlm2(const llama_model & model, const llm_gr // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -69,7 +50,7 @@ llm_build_internlm2::llm_build_internlm2(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/jais.cpp b/examples/talk-llama/models/jais.cpp index 135bf288ba1..31101f3c14b 100644 --- a/examples/talk-llama/models/jais.cpp +++ b/examples/talk-llama/models/jais.cpp @@ -2,7 +2,6 @@ llm_build_jais::llm_build_jais(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -24,22 +23,11 @@ llm_build_jais::llm_build_jais(const llama_model & model, const llm_graph_params // self-attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*cur->nb[0]*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*cur->nb[0]*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa)); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/float(n_embd_head), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -66,8 +54,14 @@ llm_build_jais::llm_build_jais(const llama_model & model, const llm_graph_params LLM_FFN_SILU, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); } - inpL = ggml_add(ctx0, cur, ffn_inp); - cb(inpL, "l_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; } cur = build_norm(inpL, model.output_norm, diff --git a/examples/talk-llama/models/jais2.cpp b/examples/talk-llama/models/jais2.cpp index 2cfe484eb52..507e04fa4aa 100644 --- a/examples/talk-llama/models/jais2.cpp +++ b/examples/talk-llama/models/jais2.cpp @@ -31,25 +31,8 @@ llm_build_jais2::llm_build_jais2(const llama_model & model, const llm_graph_para // Self-attention with separate Q, K, V projections { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur_bias", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur_bias", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur_bias", il); - - // Reshape for attention - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // Apply RoPE Qcur = ggml_rope_ext( @@ -68,7 +51,7 @@ llm_build_jais2::llm_build_jais2(const llama_model & model, const llm_graph_para cb(Kcur, "Kcur_rope", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/jamba.cpp b/examples/talk-llama/models/jamba.cpp index c0c89de187a..f82b7795c87 100644 --- a/examples/talk-llama/models/jamba.cpp +++ b/examples/talk-llama/models/jamba.cpp @@ -24,25 +24,12 @@ llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_para } else { // Attention - struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // No RoPE :) cur = build_attn(inp_hybrid->get_attn(), - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, NULL, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/kimi-linear.cpp b/examples/talk-llama/models/kimi-linear.cpp index 4d62f4e7159..58c89c417fc 100644 --- a/examples/talk-llama/models/kimi-linear.cpp +++ b/examples/talk-llama/models/kimi-linear.cpp @@ -268,7 +268,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll ggml_tensor * Vcur = kv_cmpr; cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn_k, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, layer.wv_b, kq_scale_mla, il); + cur = build_attn(inp_attn_k, layer.wo, NULL, layer.wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, layer.wv_b, kq_scale_mla, il); cb(cur, "mla_out", il); } else { // MLA KV cache disabled. Fall back to MHA KV cache. Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k_mla, n_head, n_tokens); @@ -299,7 +299,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll // Direct softmax attention (with MHA KV cache) // Use build_attn with inp_attn for proper mask handling - cur = build_attn(inp_attn_kv, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il); + cur = build_attn(inp_attn_kv, layer.wo, NULL, layer.wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il); cb(cur, "mla_out", il); } } @@ -362,6 +362,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll cur = build_cvec(cur, il); cb(cur, "l_out", il); + // input for next layer inpL = cur; } cur = inpL; diff --git a/examples/talk-llama/models/lfm2.cpp b/examples/talk-llama/models/lfm2.cpp index dfa322166b1..eb8ec3c803a 100644 --- a/examples/talk-llama/models/lfm2.cpp +++ b/examples/talk-llama/models/lfm2.cpp @@ -42,16 +42,8 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_ const auto n_embd_head = hparams.n_embd_head_v(); const auto n_head_kv = hparams.n_head_kv(il); - auto * q = build_lora_mm(model.layers[il].wq, cur); - cb(q, "model.layers.{}.self_attn.q_proj", il); - auto * k = build_lora_mm(model.layers[il].wk, cur); - cb(k, "model.layers.{}.self_attn.k_proj", il); - auto * v = build_lora_mm(model.layers[il].wv, cur); - cb(v, "model.layers.{}.self_attn.v_proj", il); - - q = ggml_reshape_3d(ctx0, q, n_embd_head, n_head, n_tokens); - k = ggml_reshape_3d(ctx0, k, n_embd_head, n_head_kv, n_tokens); - v = ggml_reshape_3d(ctx0, v, n_embd_head, n_head_kv, n_tokens); + auto [q, k, v] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // qk norm q = build_norm(q, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); @@ -66,7 +58,7 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_ attn_factor, beta_fast, beta_slow); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, q, k, v, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); cb(cur, "model.layers.{}.self_attn.out_proj", il); @@ -177,6 +169,9 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_ cb(ffn_norm_out, "model.layers.{}.ffn_out", il); cur = ggml_add(ctx0, cur, ffn_out); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); } cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); diff --git a/examples/talk-llama/models/llada-moe.cpp b/examples/talk-llama/models/llada-moe.cpp index 18de88fde1f..c756d6fde5f 100644 --- a/examples/talk-llama/models/llada-moe.cpp +++ b/examples/talk-llama/models/llada-moe.cpp @@ -30,18 +30,8 @@ llm_build_llada_moe::llm_build_llada_moe(const llama_model & model, const llm_gr // self_attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -66,7 +56,7 @@ llm_build_llada_moe::llm_build_llada_moe(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/llada.cpp b/examples/talk-llama/models/llada.cpp index 0dac9d616ae..501df3c7eaf 100644 --- a/examples/talk-llama/models/llada.cpp +++ b/examples/talk-llama/models/llada.cpp @@ -30,17 +30,8 @@ llm_build_llada::llm_build_llada(const llama_model & model, const llm_graph_para // self-attention { // compute separate Q, K, V projections without bias, matching LLaDALlamaBlock - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -53,7 +44,7 @@ llm_build_llada::llm_build_llada(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/llama.cpp b/examples/talk-llama/models/llama.cpp index e08ae0c0b0e..8d478dc6747 100644 --- a/examples/talk-llama/models/llama.cpp +++ b/examples/talk-llama/models/llama.cpp @@ -43,27 +43,8 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_gra ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, @@ -89,11 +70,8 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_gra cb(Kcur, "Kcur_normed", il); } cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); - if (model.layers[il].wo_s) { - cur = ggml_mul(ctx0, cur, model.layers[il].wo_s); - } cb(cur, "attn_out", il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/llama-iswa.cpp b/examples/talk-llama/models/llama4.cpp similarity index 81% rename from examples/talk-llama/models/llama-iswa.cpp rename to examples/talk-llama/models/llama4.cpp index 67cb9a10ec5..4e4bfb43f33 100644 --- a/examples/talk-llama/models/llama-iswa.cpp +++ b/examples/talk-llama/models/llama4.cpp @@ -1,6 +1,7 @@ #include "models.h" -llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +template +llm_build_llama4::llm_build_llama4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -18,7 +19,14 @@ llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_ ggml_tensor * inp_attn_scale = nullptr; inp_attn_scale = build_inp_attn_scale(); - auto * inp_attn = build_attn_inp_kv_iswa(); + using inp_attn_type = std::conditional_t; + inp_attn_type * inp_attn = nullptr; + + if constexpr (iswa) { + inp_attn = build_attn_inp_kv_iswa(); + } else { + inp_attn = build_attn_inp_kv(); + } const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; @@ -46,27 +54,8 @@ llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (use_rope) { Qcur = ggml_rope_ext( @@ -95,7 +84,7 @@ llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_ cb(Kcur, "Kcur_normed", il); } cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } @@ -176,3 +165,7 @@ llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_ ggml_build_forward_expand(gf, cur); } + +// Explicit template instantiations +template struct llm_build_llama4; +template struct llm_build_llama4; diff --git a/examples/talk-llama/models/maincoder.cpp b/examples/talk-llama/models/maincoder.cpp index a72b7790a1f..8a76931c007 100644 --- a/examples/talk-llama/models/maincoder.cpp +++ b/examples/talk-llama/models/maincoder.cpp @@ -30,18 +30,8 @@ llm_build_maincoder::llm_build_maincoder(const llama_model & model, const llm_gr // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -66,7 +56,7 @@ llm_build_maincoder::llm_build_maincoder(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/mamba-base.cpp b/examples/talk-llama/models/mamba-base.cpp index 9de587db55f..c37f29c487e 100644 --- a/examples/talk-llama/models/mamba-base.cpp +++ b/examples/talk-llama/models/mamba-base.cpp @@ -42,7 +42,7 @@ ggml_tensor * llm_build_mamba_base::build_mamba_layer(llm_graph_input_rs * inp, cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs} - ggml_tensor * xz = build_lora_mm(layer.ssm_in, cur); + ggml_tensor * xz = build_lora_mm(layer.ssm_in, cur, layer.ssm_in_s); // split the above in two // => {d_inner, n_seq_tokens, n_seqs} ggml_tensor * x = ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0); @@ -137,7 +137,7 @@ ggml_tensor * llm_build_mamba_base::build_mamba_layer(llm_graph_input_rs * inp, y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} - cur = build_lora_mm(layer.ssm_out, y); + cur = build_lora_mm(layer.ssm_out, y, layer.ssm_out_s); } // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} @@ -184,7 +184,7 @@ ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp, // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs} - ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur); + ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur, model.layers[il].ssm_in_s); // split the above in three ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim * zxBCdt->nb[0], @@ -278,7 +278,7 @@ ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp, y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} - cur = build_lora_mm(model.layers[il].ssm_out, y); + cur = build_lora_mm(model.layers[il].ssm_out, y, model.layers[il].ssm_out_s); } // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} diff --git a/examples/talk-llama/models/mimo2-iswa.cpp b/examples/talk-llama/models/mimo2-iswa.cpp index 06956915ea0..52c6acfe214 100644 --- a/examples/talk-llama/models/mimo2-iswa.cpp +++ b/examples/talk-llama/models/mimo2-iswa.cpp @@ -58,7 +58,7 @@ llm_build_mimo2_iswa::llm_build_mimo2_iswa(const llama_model & model, const llm_ ggml_tensor * sinks = model.layers[il].attn_sinks; cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, sinks, nullptr, 1.0f/sqrtf(float(n_embd_head_k)), il); } diff --git a/examples/talk-llama/models/minicpm3.cpp b/examples/talk-llama/models/minicpm3.cpp index 89dd7105157..bf12ab73c74 100644 --- a/examples/talk-llama/models/minicpm3.cpp +++ b/examples/talk-llama/models/minicpm3.cpp @@ -134,7 +134,7 @@ llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_grap cb(k_states, "k_states", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, q_states, k_states, v_states, nullptr, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/minimax-m2.cpp b/examples/talk-llama/models/minimax-m2.cpp index 83d0916c08c..b809b79f2b9 100644 --- a/examples/talk-llama/models/minimax-m2.cpp +++ b/examples/talk-llama/models/minimax-m2.cpp @@ -64,7 +64,7 @@ llm_build_minimax_m2::llm_build_minimax_m2(const llama_model & model, const llm_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/mistral3.cpp b/examples/talk-llama/models/mistral3.cpp index 42a5117ff02..b5ae72a2ee1 100644 --- a/examples/talk-llama/models/mistral3.cpp +++ b/examples/talk-llama/models/mistral3.cpp @@ -41,27 +41,8 @@ llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_grap ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, @@ -86,7 +67,7 @@ llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_grap } cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/models.h b/examples/talk-llama/models/models.h index a86b2b1ebd7..94991c55fe8 100644 --- a/examples/talk-llama/models/models.h +++ b/examples/talk-llama/models/models.h @@ -256,9 +256,11 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params); ggml_tensor * calc_magnitude(ggml_tensor * x); - ggml_tensor * view_2d_slice(ggml_tensor * x, int idx); - ggml_tensor * get_per_layer_inputs(); - ggml_tensor * project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer); + + // TODO: refactor in common "per-layer" functionality [TAG_PER_LAYER] + ggml_tensor * build_inp_per_layer(); + ggml_tensor * project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer); + ggml_tensor * gaussian_topk(ggml_tensor * x); ggml_tensor * altup_compute_router_modalities(ggml_tensor * x, int il); ggml_tensor * altup_predict(ggml_tensor * cur, int il); @@ -266,6 +268,18 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { ggml_tensor * altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il); }; +struct llm_build_gemma4_iswa : public llm_graph_context { + const llama_model & model; + + const int64_t n_embd_per_layer; + + llm_build_gemma4_iswa(const llama_model & model, const llm_graph_params & params); + + // TODO: refactor in common "per-layer" functionality [TAG_PER_LAYER] + ggml_tensor * build_inp_per_layer(); + ggml_tensor * project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer); +}; + struct llm_build_gemma_embedding : public llm_graph_context { llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params); }; @@ -393,8 +407,9 @@ struct llm_build_llama : public llm_graph_context { llm_build_llama(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_llama_iswa : public llm_graph_context { - llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params); +template +struct llm_build_llama4 : public llm_graph_context { + llm_build_llama4(const llama_model & model, const llm_graph_params & params); }; struct llm_build_maincoder : public llm_graph_context { @@ -481,7 +496,7 @@ struct llm_build_phi2 : public llm_graph_context { llm_build_phi2(const llama_model & model, const llm_graph_params & params); }; -template +template struct llm_build_phi3 : public llm_graph_context { llm_build_phi3(const llama_model & model, const llm_graph_params & params); }; @@ -687,12 +702,13 @@ struct llm_build_step35_iswa : public llm_graph_context { llm_build_step35_iswa(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_t5_dec : public llm_graph_context { - llm_build_t5_dec(const llama_model & model, const llm_graph_params & params); +template +struct llm_build_t5 : public llm_graph_context { + llm_build_t5(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_t5_enc : public llm_graph_context { - llm_build_t5_enc(const llama_model & model, const llm_graph_params & params); +struct llm_build_t5encoder : public llm_build_t5 { + llm_build_t5encoder(const llama_model & model, const llm_graph_params & params); }; struct llm_build_wavtokenizer_dec : public llm_graph_context { diff --git a/examples/talk-llama/models/modern-bert.cpp b/examples/talk-llama/models/modern-bert.cpp index 26020584c6d..5c6a1b5e1bc 100644 --- a/examples/talk-llama/models/modern-bert.cpp +++ b/examples/talk-llama/models/modern-bert.cpp @@ -2,7 +2,6 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -15,8 +14,8 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const ll cb(inpL, "inp_embd", -1); // embed layer norm - inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, -1); - cb(inpL, "inp_norm", -1); + inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, 0); + cb(inpL, "inp_norm", 0); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -37,14 +36,8 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const ll } // self attention - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - const size_t type_size = ggml_type_size(cur->type); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*type_size, cur->nb[1], 0*type_size*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*type_size, cur->nb[1], 1*type_size*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*type_size, cur->nb[1], 1*type_size*(n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // RoPE Qcur = ggml_rope_ext( @@ -64,7 +57,7 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const ll cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, nullptr, + model.layers[il].wo, nullptr, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); diff --git a/examples/talk-llama/models/mpt.cpp b/examples/talk-llama/models/mpt.cpp index ce44a805f5c..8596bbb2024 100644 --- a/examples/talk-llama/models/mpt.cpp +++ b/examples/talk-llama/models/mpt.cpp @@ -1,10 +1,7 @@ #include "models.h" - - llm_build_mpt::llm_build_mpt(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -38,25 +35,8 @@ llm_build_mpt::llm_build_mpt(const llama_model & model, const llm_graph_params & { cur = attn_norm; - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - if (model.layers[il].bqkv) { - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - } - - if (hparams.f_clamp_kqv > 0.0f) { - cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); - cb(cur, "wqkv_clamped", il); - } - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 0 * sizeof(float) * (n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // Q/K Layernorm if (model.layers[il].attn_q_norm) { @@ -76,7 +56,7 @@ llm_build_mpt::llm_build_mpt(const llama_model & model, const llm_graph_params & cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/nemotron-h.cpp b/examples/talk-llama/models/nemotron-h.cpp index 7af99174d16..dc07d43df58 100644 --- a/examples/talk-llama/models/nemotron-h.cpp +++ b/examples/talk-llama/models/nemotron-h.cpp @@ -65,40 +65,12 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * const llama_model & model, int64_t n_embd_head, int il) { - // compute Q and K - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, n_embd_head, hparams.n_head(il), hparams.n_head_kv(il), il); const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); return cur; @@ -107,9 +79,9 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il) { if (model.layers[il].ffn_gate_inp == nullptr) { cur = build_ffn(cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, model.layers[il].ffn_up_s, NULL, NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, model.layers[il].ffn_down_s, NULL, LLM_FFN_RELU_SQR, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); @@ -136,7 +108,10 @@ ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const lla hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, il, - router_logits); + router_logits, nullptr, + model.layers[il].ffn_up_exps_s, + nullptr, // no gate + model.layers[il].ffn_down_exps_s); cb(moe_out, "ffn_moe_out", il); if (model.layers[il].ffn_latent_up) { @@ -144,9 +119,9 @@ ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const lla } ggml_tensor * ffn_shexp = build_ffn(inp_emb, - model.layers[il].ffn_up_shexp, NULL, NULL, - NULL /* no gate */ , NULL, NULL, - model.layers[il].ffn_down_shexp, NULL, NULL, + model.layers[il].ffn_up_shexp, NULL, model.layers[il].ffn_up_shexp_s, + NULL /* no gate */ , NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, model.layers[il].ffn_down_shexp_s, NULL, LLM_FFN_RELU_SQR, LLM_FFN_PAR, il); cb(ffn_shexp, "ffn_shexp", il); diff --git a/examples/talk-llama/models/nemotron.cpp b/examples/talk-llama/models/nemotron.cpp index 34aa6fa5ec4..054b16fe0ef 100644 --- a/examples/talk-llama/models/nemotron.cpp +++ b/examples/talk-llama/models/nemotron.cpp @@ -31,27 +31,8 @@ llm_build_nemotron::llm_build_nemotron(const llama_model & model, const llm_grap // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -70,7 +51,7 @@ llm_build_nemotron::llm_build_nemotron(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/neo-bert.cpp b/examples/talk-llama/models/neo-bert.cpp index 2fdf4a3692f..da68024a34d 100644 --- a/examples/talk-llama/models/neo-bert.cpp +++ b/examples/talk-llama/models/neo-bert.cpp @@ -2,7 +2,6 @@ llm_build_neo_bert::llm_build_neo_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -27,17 +26,8 @@ llm_build_neo_bert::llm_build_neo_bert(const llama_model & model, const llm_grap LLM_NORM_RMS, il); { - ggml_tensor * Qcur; - ggml_tensor * Kcur; - ggml_tensor * Vcur; - - // self-attention - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // RoPE Qcur = ggml_rope_ext( @@ -57,7 +47,7 @@ llm_build_neo_bert::llm_build_neo_bert(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, nullptr, + model.layers[il].wo, nullptr, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); } diff --git a/examples/talk-llama/models/olmo.cpp b/examples/talk-llama/models/olmo.cpp index 26f4b6ee628..a9974025f07 100644 --- a/examples/talk-llama/models/olmo.cpp +++ b/examples/talk-llama/models/olmo.cpp @@ -30,27 +30,8 @@ llm_build_olmo::llm_build_olmo(const llama_model & model, const llm_graph_params // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (hparams.f_clamp_kqv > 0.0f) { - Qcur = ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (hparams.f_clamp_kqv > 0.0f) { - Kcur = ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (hparams.f_clamp_kqv > 0.0f) { - Vcur = ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -69,7 +50,7 @@ llm_build_olmo::llm_build_olmo(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, nullptr, + model.layers[il].wo, nullptr, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/olmo2.cpp b/examples/talk-llama/models/olmo2.cpp index 5076359e3f9..308d2a600c2 100644 --- a/examples/talk-llama/models/olmo2.cpp +++ b/examples/talk-llama/models/olmo2.cpp @@ -89,7 +89,7 @@ llm_build_olmo2::llm_build_olmo2(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/olmoe.cpp b/examples/talk-llama/models/olmoe.cpp index 83a56a0b3b6..ed46a00ef90 100644 --- a/examples/talk-llama/models/olmoe.cpp +++ b/examples/talk-llama/models/olmoe.cpp @@ -68,7 +68,7 @@ llm_build_olmoe::llm_build_olmoe(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/openai-moe-iswa.cpp b/examples/talk-llama/models/openai-moe-iswa.cpp index 403f130bc41..50992b8d506 100644 --- a/examples/talk-llama/models/openai-moe-iswa.cpp +++ b/examples/talk-llama/models/openai-moe-iswa.cpp @@ -28,27 +28,8 @@ llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model, // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_rot, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_rot, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -67,7 +48,7 @@ llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model, cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, model.layers[il].attn_sinks, nullptr, 1.0f/sqrtf(float(n_rot)), il); cb(cur, "attn_out", il); diff --git a/examples/talk-llama/models/openelm.cpp b/examples/talk-llama/models/openelm.cpp index 5df6fe3e3ce..514ac33517f 100644 --- a/examples/talk-llama/models/openelm.cpp +++ b/examples/talk-llama/models/openelm.cpp @@ -73,7 +73,7 @@ llm_build_openelm::llm_build_openelm(const llama_model & model, const llm_graph_ cb(Qcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/orion.cpp b/examples/talk-llama/models/orion.cpp index 48c01efe368..a5874b6dee7 100644 --- a/examples/talk-llama/models/orion.cpp +++ b/examples/talk-llama/models/orion.cpp @@ -30,30 +30,8 @@ llm_build_orion::llm_build_orion(const llama_model & model, const llm_graph_para // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - // if (model.layers[il].bq) { - // Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - // cb(Qcur, "Qcur", il); - // } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - // if (model.layers[il].bk) { - // Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - // cb(Kcur, "Kcur", il); - // } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - // if (model.layers[il].bv) { - // Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - // cb(Vcur, "Vcur", il); - // } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -72,7 +50,7 @@ llm_build_orion::llm_build_orion(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/paddleocr.cpp b/examples/talk-llama/models/paddleocr.cpp index 340455c2d5f..56cb1d94c5f 100644 --- a/examples/talk-llama/models/paddleocr.cpp +++ b/examples/talk-llama/models/paddleocr.cpp @@ -35,27 +35,8 @@ llm_build_paddleocr::llm_build_paddleocr(const llama_model & model, const llm_gr } // self-attention { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_multi( ctx0, Qcur, inp_pos, nullptr, @@ -74,7 +55,7 @@ llm_build_paddleocr::llm_build_paddleocr(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { diff --git a/examples/talk-llama/models/pangu-embedded.cpp b/examples/talk-llama/models/pangu-embedded.cpp index 1cf0938e68f..53464f21d22 100644 --- a/examples/talk-llama/models/pangu-embedded.cpp +++ b/examples/talk-llama/models/pangu-embedded.cpp @@ -1,6 +1,5 @@ #include "models.h" - llm_build_pangu_embedded::llm_build_pangu_embedded(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); @@ -31,21 +30,8 @@ llm_build_pangu_embedded::llm_build_pangu_embedded(const llama_model & model, co // self attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -63,7 +49,7 @@ llm_build_pangu_embedded::llm_build_pangu_embedded(const llama_model & model, co cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/phi2.cpp b/examples/talk-llama/models/phi2.cpp index 32d40d71fb7..0fb3ffa2e63 100644 --- a/examples/talk-llama/models/phi2.cpp +++ b/examples/talk-llama/models/phi2.cpp @@ -1,9 +1,7 @@ #include "models.h" - llm_build_phi2::llm_build_phi2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -30,29 +28,8 @@ llm_build_phi2::llm_build_phi2(const llama_model & model, const llm_graph_params // self-attention { - ggml_tensor * Qcur = nullptr; - ggml_tensor * Kcur = nullptr; - ggml_tensor * Vcur = nullptr; - - if (model.layers[il].wqkv) { - cur = build_lora_mm(model.layers[il].wqkv, attn_norm_output); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - } else { - Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); - Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); - Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - } + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], attn_norm_output, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -74,7 +51,7 @@ llm_build_phi2::llm_build_phi2(const llama_model & model, const llm_graph_params Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head))); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/phi3.cpp b/examples/talk-llama/models/phi3.cpp index 3d11a9459c4..39af285d3c5 100644 --- a/examples/talk-llama/models/phi3.cpp +++ b/examples/talk-llama/models/phi3.cpp @@ -3,7 +3,6 @@ template llm_build_phi3::llm_build_phi3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -39,27 +38,8 @@ llm_build_phi3::llm_build_phi3(const llama_model & model, const llm_graph_ LLM_NORM_RMS, il); cb(attn_norm_output, "attn_norm", il); - ggml_tensor * Qcur = nullptr; - ggml_tensor * Kcur = nullptr; - ggml_tensor * Vcur = nullptr; - - if (model.layers[il].wqkv) { - cur = build_lora_mm(model.layers[il].wqkv, attn_norm_output); - cb(cur, "wqkv", il); - - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 0 * sizeof(float) * (n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd)); - Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); - } - else { - Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); - Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); - Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - } + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], attn_norm_output, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -80,7 +60,7 @@ llm_build_phi3::llm_build_phi3(const llama_model & model, const llm_graph_ cb(Qcur, "Qcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/plamo.cpp b/examples/talk-llama/models/plamo.cpp index b7a71211042..4d5c84506c2 100644 --- a/examples/talk-llama/models/plamo.cpp +++ b/examples/talk-llama/models/plamo.cpp @@ -30,18 +30,8 @@ llm_build_plamo::llm_build_plamo(const llama_model & model, const llm_graph_para // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -60,7 +50,7 @@ llm_build_plamo::llm_build_plamo(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/plamo2.cpp b/examples/talk-llama/models/plamo2.cpp index f02acbc1869..b6142daebd9 100644 --- a/examples/talk-llama/models/plamo2.cpp +++ b/examples/talk-llama/models/plamo2.cpp @@ -71,6 +71,7 @@ llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_pa cur = ggml_add(ctx0, cur, residual); cb(cur, "ffn_residual", il); + // input for next layer inpL = cur; } @@ -140,7 +141,7 @@ ggml_tensor * llm_build_plamo2::build_plamo2_attn_layer(llm_graph_input_attn_kv ext_factor, attn_factor, beta_fast, beta_slow); cur = build_attn(inp, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, NULL, NULL, NULL, 1.0f / sqrtf(float(n_embd_head_v)), il); } diff --git a/examples/talk-llama/models/plamo3.cpp b/examples/talk-llama/models/plamo3.cpp index 32af6e04663..67844c09f24 100644 --- a/examples/talk-llama/models/plamo3.cpp +++ b/examples/talk-llama/models/plamo3.cpp @@ -73,7 +73,7 @@ llm_build_plamo3::llm_build_plamo3(const llama_model & model, const llm_gr const float attn_scale = 1.0f / sqrtf(float(head_dim_q)); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, attn_scale, il); cb(cur, "attn_out", il); @@ -109,6 +109,8 @@ llm_build_plamo3::llm_build_plamo3(const llama_model & model, const llm_gr cur = build_cvec(cur, il); cb(cur, "l_out", il); + + // input for next layer inpL = cur; } diff --git a/examples/talk-llama/models/plm.cpp b/examples/talk-llama/models/plm.cpp index bcb651ce543..abce6b34d04 100644 --- a/examples/talk-llama/models/plm.cpp +++ b/examples/talk-llama/models/plm.cpp @@ -120,7 +120,7 @@ llm_build_plm::llm_build_plm(const llama_model & model, const llm_graph_params & cb(k_states, "k_states", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, q_states, k_states, v_states, nullptr, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/qwen.cpp b/examples/talk-llama/models/qwen.cpp index 7390f1320bf..44e75d87437 100644 --- a/examples/talk-llama/models/qwen.cpp +++ b/examples/talk-llama/models/qwen.cpp @@ -1,6 +1,5 @@ #include "models.h" - llm_build_qwen::llm_build_qwen(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); @@ -28,15 +27,8 @@ llm_build_qwen::llm_build_qwen(const llama_model & model, const llm_graph_params // self-attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 2*sizeof(float)*(n_embd)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // using mode = 2 for neox mode Qcur = ggml_rope_ext( @@ -56,7 +48,7 @@ llm_build_qwen::llm_build_qwen(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/qwen2.cpp b/examples/talk-llama/models/qwen2.cpp index 58c10622508..2892dd75087 100644 --- a/examples/talk-llama/models/qwen2.cpp +++ b/examples/talk-llama/models/qwen2.cpp @@ -30,30 +30,8 @@ llm_build_qwen2::llm_build_qwen2(const llama_model & model, const llm_graph_para // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -72,7 +50,7 @@ llm_build_qwen2::llm_build_qwen2(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/qwen2moe.cpp b/examples/talk-llama/models/qwen2moe.cpp index 60761789dc9..5f0a6861b68 100644 --- a/examples/talk-llama/models/qwen2moe.cpp +++ b/examples/talk-llama/models/qwen2moe.cpp @@ -30,27 +30,8 @@ llm_build_qwen2moe::llm_build_qwen2moe(const llama_model & model, const llm_grap // self_attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -69,7 +50,7 @@ llm_build_qwen2moe::llm_build_qwen2moe(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/qwen2vl.cpp b/examples/talk-llama/models/qwen2vl.cpp index 9004bab9db1..da7937c7667 100644 --- a/examples/talk-llama/models/qwen2vl.cpp +++ b/examples/talk-llama/models/qwen2vl.cpp @@ -33,21 +33,8 @@ llm_build_qwen2vl::llm_build_qwen2vl(const llama_model & model, const llm_graph_ // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_multi( ctx0, Qcur, inp_pos, nullptr, @@ -66,7 +53,7 @@ llm_build_qwen2vl::llm_build_qwen2vl(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/qwen3.cpp b/examples/talk-llama/models/qwen3.cpp index 52081668477..883dd5f9a90 100644 --- a/examples/talk-llama/models/qwen3.cpp +++ b/examples/talk-llama/models/qwen3.cpp @@ -30,18 +30,8 @@ llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_para // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -66,11 +56,8 @@ llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); - if (model.layers[il].wo_s) { - cur = ggml_mul(ctx0, cur, model.layers[il].wo_s); - } } if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); diff --git a/examples/talk-llama/models/qwen35.cpp b/examples/talk-llama/models/qwen35.cpp index 3108bf331ac..87790f08e4e 100644 --- a/examples/talk-llama/models/qwen35.cpp +++ b/examples/talk-llama/models/qwen35.cpp @@ -64,6 +64,9 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa cur = ggml_add(ctx0, cur, ffn_residual); cb(cur, "post_ffn", il); + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + // Input for next layer inpL = cur; } @@ -176,7 +179,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn( const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp, - nullptr, nullptr, + nullptr, nullptr, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_pregate", il); @@ -222,9 +225,10 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( cb(beta, "beta", il); beta = ggml_sigmoid(ctx0, beta); + cb(beta, "beta_sigmoid", il); ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur, model.layers[il].ssm_alpha_s); - alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); + alpha = ggml_reshape_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); cb(alpha, "alpha", il); ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); @@ -266,7 +270,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( cb(last_conv_states, "last_conv_states", il); ggml_tensor * state_update_target = - ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs, + ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); cb(state_update_target, "state_update_target", il); @@ -342,7 +346,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( // Update the recurrent states ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_state, - ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, + ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] diff --git a/examples/talk-llama/models/qwen35moe.cpp b/examples/talk-llama/models/qwen35moe.cpp index 165e2412e56..7dc6a23c751 100644 --- a/examples/talk-llama/models/qwen35moe.cpp +++ b/examples/talk-llama/models/qwen35moe.cpp @@ -64,6 +64,9 @@ llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_gr cur = ggml_add(ctx0, cur, ffn_residual); cb(cur, "post_moe", il); + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + // Input for next layer inpL = cur; } @@ -176,7 +179,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn( const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp, - nullptr, nullptr, + nullptr, nullptr, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_pregate", il); @@ -222,9 +225,10 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( cb(beta, "beta", il); beta = ggml_sigmoid(ctx0, beta); + cb(beta, "beta_sigmoid", il); ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur, model.layers[il].ssm_alpha_s); - alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); + alpha = ggml_reshape_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); cb(alpha, "alpha", il); ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); @@ -266,7 +270,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( cb(last_conv_states, "last_conv_states", il); ggml_tensor * state_update_target = - ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs, + ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); cb(state_update_target, "state_update_target", il); @@ -342,7 +346,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( // Update the recurrent states ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_state, - ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, + ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] diff --git a/examples/talk-llama/models/qwen3moe.cpp b/examples/talk-llama/models/qwen3moe.cpp index dba46618ff2..16bedba994d 100644 --- a/examples/talk-llama/models/qwen3moe.cpp +++ b/examples/talk-llama/models/qwen3moe.cpp @@ -30,18 +30,8 @@ llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_grap // self_attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -66,11 +56,8 @@ llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); - if (model.layers[il].wo_s) { - cur = ggml_mul(ctx0, cur, model.layers[il].wo_s); - } } if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); diff --git a/examples/talk-llama/models/qwen3next.cpp b/examples/talk-llama/models/qwen3next.cpp index cc479dd075c..1beda70b7cf 100644 --- a/examples/talk-llama/models/qwen3next.cpp +++ b/examples/talk-llama/models/qwen3next.cpp @@ -56,6 +56,9 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr cur = ggml_add(ctx0, cur, ffn_residual); cb(cur, "post_moe", il); + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + // Input for next layer inpL = cur; } @@ -154,7 +157,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn( const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp, - nullptr, nullptr, + nullptr, nullptr, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_pregate", il); @@ -169,7 +172,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn( cur = ggml_mul(ctx0, cur, gate); cb(cur, "attn_gated", il); - cur = build_lora_mm(model.layers[il].wo, cur); + cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s); cb(cur, "attn_output", il); return cur; @@ -351,7 +354,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( cb(last_conv_states, "last_conv_states", il); ggml_tensor * state_update_target = - ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs, + ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); cb(state_update_target, "state_update_target", il); @@ -411,19 +414,19 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( GGML_ASSERT(num_v_heads % num_k_heads == 0); int64_t repeat_factor = num_v_heads / num_k_heads; - // repeat interleave: reshape to (repeat part, 1, remaining part), do repeat, then reshape back - ggml_tensor * q_reshaped = ggml_reshape_3d(ctx0, q_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs); - ggml_tensor * k_reshaped = ggml_reshape_3d(ctx0, k_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs); + // repeat interleave: reshape to (repeat part, 1, remaining part...), do repeat, then reshape back + ggml_tensor * q_reshaped = ggml_reshape_4d(ctx0, q_conv, head_k_dim, 1, num_k_heads, n_seq_tokens * n_seqs); + ggml_tensor * k_reshaped = ggml_reshape_4d(ctx0, k_conv, head_k_dim, 1, num_k_heads, n_seq_tokens * n_seqs); // Repeat along the third dimension (the new dimension with size 1) ggml_tensor * q_repeated = - ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1); + ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads, n_seq_tokens * n_seqs); ggml_tensor * k_repeated = - ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1); + ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads, n_seq_tokens * n_seqs); // Reshape back to merge the head and repeat dimensions - // From [head_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs] - // Back to [head_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs] + // From [head_dim, repeat_factor, num_k_heads, n_seq_tokens * n_seqs] + // Back to [head_dim, repeat_factor * num_k_heads, n_seq_tokens, n_seqs] q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs); k_conv = ggml_reshape_4d(ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs); } @@ -442,7 +445,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( // Update the recurrent states ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_state, - ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, + ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] diff --git a/examples/talk-llama/models/qwen3vl-moe.cpp b/examples/talk-llama/models/qwen3vl-moe.cpp index 195daea66c9..29ee8278a4d 100644 --- a/examples/talk-llama/models/qwen3vl-moe.cpp +++ b/examples/talk-llama/models/qwen3vl-moe.cpp @@ -36,18 +36,8 @@ llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_ // self_attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -72,7 +62,7 @@ llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/qwen3vl.cpp b/examples/talk-llama/models/qwen3vl.cpp index bbd5f42ba5b..faa5f2ef3c8 100644 --- a/examples/talk-llama/models/qwen3vl.cpp +++ b/examples/talk-llama/models/qwen3vl.cpp @@ -36,18 +36,8 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_ // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -72,7 +62,7 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/refact.cpp b/examples/talk-llama/models/refact.cpp index 140700d9e2d..398eb368db0 100644 --- a/examples/talk-llama/models/refact.cpp +++ b/examples/talk-llama/models/refact.cpp @@ -24,25 +24,15 @@ llm_build_refact::llm_build_refact(const llama_model & model, const llm_graph_pa // self-attention { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/rnd1.cpp b/examples/talk-llama/models/rnd1.cpp index c8e1f43400f..a917c19f25a 100644 --- a/examples/talk-llama/models/rnd1.cpp +++ b/examples/talk-llama/models/rnd1.cpp @@ -32,18 +32,8 @@ llm_build_rnd1::llm_build_rnd1(const llama_model & model, const llm_graph_params // self_attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -68,7 +58,7 @@ llm_build_rnd1::llm_build_rnd1(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/rwkv6.cpp b/examples/talk-llama/models/rwkv6.cpp index 15453fbf50f..032b219d6cb 100644 --- a/examples/talk-llama/models/rwkv6.cpp +++ b/examples/talk-llama/models/rwkv6.cpp @@ -8,7 +8,7 @@ llm_build_rwkv6::llm_build_rwkv6(const llama_model & model, const llm_graph_para ggml_tensor * inpL; inpL = build_inp_embd(model.tok_embd); - inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); + inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, 0); auto * rs_inp = build_rs_inp(); diff --git a/examples/talk-llama/models/rwkv7.cpp b/examples/talk-llama/models/rwkv7.cpp index 5caf6553dfe..16ffa6901b9 100644 --- a/examples/talk-llama/models/rwkv7.cpp +++ b/examples/talk-llama/models/rwkv7.cpp @@ -9,7 +9,7 @@ llm_build_rwkv7::llm_build_rwkv7(const llama_model & model, const llm_graph_para ggml_tensor * v_first = nullptr; inpL = build_inp_embd(model.tok_embd); - inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); + inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, 0); auto * rs_inp = build_rs_inp(); diff --git a/examples/talk-llama/models/seed-oss.cpp b/examples/talk-llama/models/seed-oss.cpp index a4d0b75d846..6db8d9781fe 100644 --- a/examples/talk-llama/models/seed-oss.cpp +++ b/examples/talk-llama/models/seed-oss.cpp @@ -32,27 +32,8 @@ llm_build_seed_oss::llm_build_seed_oss(const llama_model & model, const llm_grap // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -71,7 +52,7 @@ llm_build_seed_oss::llm_build_seed_oss(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/smallthinker.cpp b/examples/talk-llama/models/smallthinker.cpp index e2155aacef4..55d09ec325d 100644 --- a/examples/talk-llama/models/smallthinker.cpp +++ b/examples/talk-llama/models/smallthinker.cpp @@ -45,18 +45,8 @@ llm_build_smallthinker::llm_build_smallthinker(const llama_model & model, // self_attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (use_rope) { Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, @@ -69,7 +59,7 @@ llm_build_smallthinker::llm_build_smallthinker(const llama_model & model, cb(Kcur, "Kcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -101,6 +91,7 @@ llm_build_smallthinker::llm_build_smallthinker(const llama_model & model, cur = ffn_out; cur = ggml_add(ctx0, cur, ffn_inp); + cur = build_cvec(cur, il); cb(cur, "l_out", il); diff --git a/examples/talk-llama/models/smollm3.cpp b/examples/talk-llama/models/smollm3.cpp index e267fd8f32f..83636dbf546 100644 --- a/examples/talk-llama/models/smollm3.cpp +++ b/examples/talk-llama/models/smollm3.cpp @@ -34,27 +34,8 @@ llm_build_smollm3::llm_build_smollm3(const llama_model & model, const llm_graph_ // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (use_rope) { Qcur = ggml_rope_ext( @@ -74,7 +55,7 @@ llm_build_smollm3::llm_build_smollm3(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/stablelm.cpp b/examples/talk-llama/models/stablelm.cpp index ff5aced93b3..9c19abd8835 100644 --- a/examples/talk-llama/models/stablelm.cpp +++ b/examples/talk-llama/models/stablelm.cpp @@ -30,30 +30,8 @@ llm_build_stablelm::llm_build_stablelm(const llama_model & model, const llm_grap // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (model.layers[il].attn_q_norm) { Qcur = build_norm(Qcur, @@ -87,7 +65,7 @@ llm_build_stablelm::llm_build_stablelm(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/starcoder.cpp b/examples/talk-llama/models/starcoder.cpp index 941cee98219..cf9fe95c35b 100644 --- a/examples/talk-llama/models/starcoder.cpp +++ b/examples/talk-llama/models/starcoder.cpp @@ -2,7 +2,6 @@ llm_build_starcoder::llm_build_starcoder(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -33,22 +32,11 @@ llm_build_starcoder::llm_build_starcoder(const llama_model & model, const llm_gr // self-attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/starcoder2.cpp b/examples/talk-llama/models/starcoder2.cpp index a5965aceb3b..b6d4d5aac1a 100644 --- a/examples/talk-llama/models/starcoder2.cpp +++ b/examples/talk-llama/models/starcoder2.cpp @@ -30,27 +30,8 @@ llm_build_starcoder2::llm_build_starcoder2(const llama_model & model, const llm_ // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -69,7 +50,7 @@ llm_build_starcoder2::llm_build_starcoder2(const llama_model & model, const llm_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/step35-iswa.cpp b/examples/talk-llama/models/step35-iswa.cpp index 176209cd93e..86aa98909e7 100644 --- a/examples/talk-llama/models/step35-iswa.cpp +++ b/examples/talk-llama/models/step35-iswa.cpp @@ -68,7 +68,7 @@ llm_build_step35_iswa::llm_build_step35_iswa(const llama_model & model, const ll const float kq_scale = 1.0f / sqrtf(float(n_embd_head_k)); ggml_tensor * attn_out = build_attn(inp_attn, - nullptr, nullptr, + nullptr, nullptr, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(attn_out, "attn_out", il); // head-wise attention gate: sigmoid(g_proj(x)) in torch @@ -92,7 +92,7 @@ llm_build_step35_iswa::llm_build_step35_iswa(const llama_model & model, const ll } // output projection - cur = build_lora_mm(model.layers[il].wo, attn_out); + cur = build_lora_mm(model.layers[il].wo, attn_out, model.layers[il].wo_s); cb(cur, "attn_proj", il); } @@ -145,9 +145,11 @@ llm_build_step35_iswa::llm_build_step35_iswa(const llama_model & model, const ll cb(cur, "ffn_out", il); } cur = ggml_add(ctx0, cur, ffn_inp); + cur = build_cvec(cur, il); cb(cur, "l_out", il); + // input for next layer inpL = cur; } diff --git a/examples/talk-llama/models/t5-enc.cpp b/examples/talk-llama/models/t5-enc.cpp deleted file mode 100644 index 395dfb51042..00000000000 --- a/examples/talk-llama/models/t5-enc.cpp +++ /dev/null @@ -1,96 +0,0 @@ -#include "models.h" - -llm_build_t5_enc::llm_build_t5_enc(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v(); - - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); - - ggml_tensor * cur; - ggml_tensor * inpL; - - inpL = build_inp_embd(model.tok_embd); - - ggml_tensor * pos_bucket_enc = build_inp_pos_bucket_enc(); - - auto * inp_attn = build_attn_inp_no_cache(); - - ggml_tensor * inp_out_ids = build_inp_out_ids(); - - for (int il = 0; il < n_layer; ++il) { - ggml_tensor * inpSA = inpL; - - // norm - cur = build_norm(inpL, - model.layers[il].attn_norm_enc, NULL, - LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); - - // self-attention - { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq_enc, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk_enc, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv_enc, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - - ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc; - ggml_tensor * kq_b = build_pos_bias(pos_bucket_enc, attn_rel_b); - - cur = build_attn(inp_attn, - model.layers[il].wo_enc, nullptr, - Qcur, Kcur, Vcur, kq_b, nullptr, nullptr, 1.0f, il); - cb(cur, "kqv_out", il); - } - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - } - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - - // feed-forward network - { - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm_enc, NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); - - // T5 uses relu, flan-T5 uses gelu-gated - cur = build_ffn(cur, - model.layers[il].ffn_up_enc, NULL, NULL, - model.layers[il].ffn_gate_enc, NULL, NULL, - model.layers[il].ffn_down_enc, NULL, NULL, - NULL, - model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU, - model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ, - il); - cb(cur, "ffn_out", il); - } - cur = ggml_add(ctx0, cur, ffn_inp); - cb(cur, "ffn_out", il); - - cur = build_cvec(cur, il); - cb(cur, "l_out", il); - - // input for next layer - inpL = cur; - } - cur = inpL; - cb(cur, "result_embd", -1); - - cur = build_norm(cur, - model.output_norm_enc, NULL, - LLM_NORM_RMS, -1); - - cb(cur, "result_norm", -1); - res->t_embd = cur; - - ggml_build_forward_expand(gf, cur); -} diff --git a/examples/talk-llama/models/t5-dec.cpp b/examples/talk-llama/models/t5.cpp similarity index 64% rename from examples/talk-llama/models/t5-dec.cpp rename to examples/talk-llama/models/t5.cpp index 8ca8372bd4c..9f9dfef4012 100644 --- a/examples/talk-llama/models/t5-dec.cpp +++ b/examples/talk-llama/models/t5.cpp @@ -1,6 +1,7 @@ #include "models.h" -llm_build_t5_dec::llm_build_t5_dec(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +template <> +llm_build_t5::llm_build_t5(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); //const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -34,24 +35,13 @@ llm_build_t5_dec::llm_build_t5_dec(const llama_model & model, const llm_graph_pa // self-attention { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, n_embd_head, n_head, n_head_kv, il); ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b; ggml_tensor * kq_b = build_pos_bias(pos_bucket_dec, attn_rel_b); cur = build_attn(inp_attn_self, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, kq_b, nullptr, nullptr, 1.0f, il); cb(cur, "kqv_out", il); } @@ -82,7 +72,7 @@ llm_build_t5_dec::llm_build_t5_dec(const llama_model & model, const llm_graph_pa Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_outputs_enc); cur = build_attn(inp_attn_cross, - model.layers[il].wo_cross, nullptr, + model.layers[il].wo_cross, nullptr, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); cb(cur, "kqv_out", il); @@ -164,3 +154,99 @@ llm_build_t5_dec::llm_build_t5_dec(const llama_model & model, const llm_graph_pa ggml_build_forward_expand(gf, cur); } + +template <> +llm_build_t5::llm_build_t5(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * pos_bucket_enc = build_inp_pos_bucket_enc(); + + auto * inp_attn = build_attn_inp_no_cache(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm_enc, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq_enc, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk_enc, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv_enc, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc; + ggml_tensor * kq_b = build_pos_bias(pos_bucket_enc, attn_rel_b); + + cur = build_attn(inp_attn, + model.layers[il].wo_enc, nullptr, nullptr, + Qcur, Kcur, Vcur, kq_b, nullptr, nullptr, 1.0f, il); + cb(cur, "kqv_out", il); + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm_enc, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // T5 uses relu, flan-T5 uses gelu-gated + cur = build_ffn(cur, + model.layers[il].ffn_up_enc, NULL, NULL, + model.layers[il].ffn_gate_enc, NULL, NULL, + model.layers[il].ffn_down_enc, NULL, NULL, + NULL, + model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU, + model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ, + il); + cb(cur, "ffn_out", il); + } + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + cb(cur, "result_embd", -1); + + cur = build_norm(cur, + model.output_norm_enc, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/t5encoder.cpp b/examples/talk-llama/models/t5encoder.cpp new file mode 100644 index 00000000000..5c1f9eb4030 --- /dev/null +++ b/examples/talk-llama/models/t5encoder.cpp @@ -0,0 +1,3 @@ +#include "models.h" + +llm_build_t5encoder::llm_build_t5encoder(const llama_model & model, const llm_graph_params & params) : llm_build_t5(model, params) {} diff --git a/examples/talk-llama/models/wavtokenizer-dec.cpp b/examples/talk-llama/models/wavtokenizer-dec.cpp index 537a0d41248..a7776d9cdc9 100644 --- a/examples/talk-llama/models/wavtokenizer-dec.cpp +++ b/examples/talk-llama/models/wavtokenizer-dec.cpp @@ -93,7 +93,7 @@ llm_build_wavtokenizer_dec::llm_build_wavtokenizer_dec(const llama_model & model cur = build_norm(cur, model.tok_norm, model.tok_norm_b, - LLM_NORM, -1); + LLM_NORM, 0); cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); diff --git a/examples/talk-llama/models/xverse.cpp b/examples/talk-llama/models/xverse.cpp index 3a8dfafcceb..53085ec80f6 100644 --- a/examples/talk-llama/models/xverse.cpp +++ b/examples/talk-llama/models/xverse.cpp @@ -28,18 +28,8 @@ llm_build_xverse::llm_build_xverse(const llama_model & model, const llm_graph_pa // self-attention { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -58,7 +48,7 @@ llm_build_xverse::llm_build_xverse(const llama_model & model, const llm_graph_pa cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/unicode.cpp b/examples/talk-llama/unicode.cpp index 122c8ca04a5..dc13e53f09f 100644 --- a/examples/talk-llama/unicode.cpp +++ b/examples/talk-llama/unicode.cpp @@ -470,6 +470,141 @@ static std::vector unicode_regex_split_custom_llama3(const std::string & return bpe_offsets; } +// Qwen2 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" +static std::vector unicode_regex_split_custom_qwen2(const std::string & text, const std::vector & offsets) { + std::vector bpe_offsets; // store the offset of each word + bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size + + const auto cpts = unicode_cpts_from_utf8(text); + + size_t start = 0; + for (auto offset : offsets) { + const size_t offset_ini = start; + const size_t offset_end = start + offset; + assert(offset_end <= cpts.size()); + start = offset_end; + + static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF; + auto _get_cpt = [&] (const size_t pos) -> uint32_t { + return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE; + }; + + auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags { + return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{}; + }; + + size_t _prev_end = offset_ini; + auto _add_token = [&] (const size_t end) -> size_t { + assert(_prev_end <= end && end <= offset_end); + size_t len = end - _prev_end; + if (len > 0) { + bpe_offsets.push_back(len); + } + _prev_end = end; + //if (len > 0) { + // std::string s = ""; + // for(size_t p = end-len; p < end; p++) + // s += unicode_cpt_to_utf8(cpts[p]); + // printf(">>> '%s'\n", s.c_str()); + //} + return len; + }; + + for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) { + const uint32_t cpt = _get_cpt(pos); + const auto flags = _get_flags(pos); + + // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive + if (cpt == '\'' && pos+1 < offset_end) { + uint32_t cpt_next = unicode_tolower(_get_cpt(pos+1)); + if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') { + pos += _add_token(pos+2); + continue; + } + if (pos+2 < offset_end) { + uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2)); + if ((cpt_next == 'r' && cpt_next_next == 'e') || + (cpt_next == 'v' && cpt_next_next == 'e') || + (cpt_next == 'l' && cpt_next_next == 'l')) { + pos += _add_token(pos+3); + continue; + } + } + } + + // regex: [^\r\n\p{L}\p{N}]?\p{L}+ + if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) { + if (flags.is_letter || _get_flags(pos+1).is_letter) { // one or more letters + pos++; + while (_get_flags(pos).is_letter) { + pos++; + } + _add_token(pos); + continue; + } + } + + // regex: \p{N} + if (flags.is_number) { + pos++; + _add_token(pos); + continue; + } + + // regex: ?[^\s\p{L}\p{N}]+[\r\n]* + auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags); + if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags.as_uint()) { + pos += (cpt == ' '); + while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) { + flags2 = _get_flags(++pos); + } + uint32_t cpt2 = _get_cpt(pos); + while (cpt2 == '\r' || cpt2 == '\n') { + cpt2 = _get_cpt(++pos); + } + _add_token(pos); + continue; + } + + size_t num_whitespaces = 0; + size_t last_end_r_or_n = 0; + while (_get_flags(pos+num_whitespaces).is_whitespace) { + uint32_t cpt2 = _get_cpt(pos+num_whitespaces); + if (cpt2 == '\r' || cpt2 == '\n') { + last_end_r_or_n = pos + num_whitespaces + 1; + } + num_whitespaces++; + } + + // regex: \s*[\r\n]+ + if (last_end_r_or_n > 0) { + pos = last_end_r_or_n; + _add_token(pos); + continue; + } + + // regex: \s+(?!\S) + if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) { + pos += num_whitespaces - 1; + _add_token(pos); + continue; + } + + // regex: \s+ + if (num_whitespaces > 0) { + pos += num_whitespaces; + _add_token(pos); + continue; + } + + // no matches + _add_token(++pos); + } + } + + return bpe_offsets; +} + template static std::vector unicode_regex_split_stl(const std::basic_string & text, const std::basic_string & regex, const std::vector & offsets) { using BidirIt = typename std::basic_string::const_iterator; @@ -753,6 +888,35 @@ static std::vector unicode_regex_split_custom_afmoe(const std::string & return bpe_offsets; } +// regex: [^\n]+|[\n]+ +// splits text into runs of non-newline characters and runs of newline characters +static std::vector unicode_regex_split_custom_newlines(const std::string & text, const std::vector & offsets) { + std::vector bpe_offsets; + bpe_offsets.reserve(offsets.size()); + + const auto cpts = unicode_cpts_from_utf8(text); + + size_t start = 0; + for (auto offset : offsets) { + const size_t offset_ini = start; + const size_t offset_end = start + offset; + assert(offset_end <= cpts.size()); + start = offset_end; + + size_t pos = offset_ini; + while (pos < offset_end) { + const bool is_newline = (cpts[pos] == '\n'); + const size_t run_start = pos; + while (pos < offset_end && (cpts[pos] == '\n') == is_newline) { + pos++; + } + bpe_offsets.push_back(pos - run_start); + } + } + + return bpe_offsets; +} + static std::vector unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector & offsets) { std::vector bpe_offsets; @@ -761,14 +925,18 @@ static std::vector unicode_regex_split_custom(const std::string & text, } else if ( regex_expr == "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" || regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") { - bpe_offsets = unicode_regex_split_custom_llama3(text, offsets); + } else if ( + regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") { + bpe_offsets = unicode_regex_split_custom_qwen2(text, offsets); } else if (regex_expr == "\\p{Han}+") { // K2's first pattern - handle all K2 patterns together bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets); } else if (regex_expr == "\\p{AFMoE_digits}") { // AFMOE digit pattern - use custom implementation for proper splitting bpe_offsets = unicode_regex_split_custom_afmoe(text, offsets); + } else if (regex_expr == "[^\\n]+|[\\n]+") { + bpe_offsets = unicode_regex_split_custom_newlines(text, offsets); } else if (regex_expr == "\\d{1,3}(?=(?:\\d{3})*\\b)") { // tiny_aya digit grouping pattern from tokenizer.json: // {"type": "Split", "pattern": {"Regex": "\\d{1,3}(?=(?:\\d{3})*\\b)"}, "behavior": "Isolated"} @@ -912,7 +1080,7 @@ bool unicode_cpt_is_han(uint32_t cpt) { return false; } -std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs) { +std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs, bool byte_encode) { // unicode categories static const std::map k_ucat_enum = { { "\\p{N}", unicode_cpt_flags::NUMBER }, @@ -1099,5 +1267,9 @@ std::vector unicode_regex_split(const std::string & text, const std start += offset; } - return unicode_byte_encoding_process(bpe_words); + if (byte_encode) { + return unicode_byte_encoding_process(bpe_words); + } + + return bpe_words; } diff --git a/examples/talk-llama/unicode.h b/examples/talk-llama/unicode.h index 5bd1362ff41..600ab9216b9 100644 --- a/examples/talk-llama/unicode.h +++ b/examples/talk-llama/unicode.h @@ -108,4 +108,4 @@ uint32_t unicode_tolower(uint32_t cpt); bool unicode_cpt_is_han(uint32_t cpt); -std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs); +std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs, bool byte_encode = true); diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index d446138ecb7..b043c278cae 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -1,12 +1,15 @@ cmake_minimum_required(VERSION 3.14...3.28) # for add_link_options and implicit target directories. + project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) -set(GGML_VERSION_MINOR 9) -set(GGML_VERSION_PATCH 8) +set(GGML_VERSION_MINOR 10) +set(GGML_VERSION_PATCH 2) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") + find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH) if(GIT_EXE) # Get current git commit hash @@ -168,15 +171,16 @@ if (NOT MSVC) option(GGML_AMX_INT8 "ggml: enable AMX-INT8" OFF) option(GGML_AMX_BF16 "ggml: enable AMX-BF16" OFF) endif() -option(GGML_LASX "ggml: enable lasx" ON) -option(GGML_LSX "ggml: enable lsx" ON) -option(GGML_RVV "ggml: enable rvv" ON) -option(GGML_RV_ZFH "ggml: enable riscv zfh" ON) -option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON) -option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON) -option(GGML_RV_ZIHINTPAUSE "ggml: enable riscv zihintpause " ON) -option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF) -option(GGML_VXE "ggml: enable vxe" ${GGML_NATIVE}) +option(GGML_LASX "ggml: enable lasx" ON) +option(GGML_LSX "ggml: enable lsx" ON) +option(GGML_RVV "ggml: enable rvv" ON) +option(GGML_RV_ZFH "ggml: enable riscv zfh" ON) +option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON) +option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON) +option(GGML_RV_ZIHINTPAUSE "ggml: enable riscv zihintpause" ON) +option(GGML_RV_ZVFBFWMA "ggml: enable riscv zvfbfwma" OFF) +option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF) +option(GGML_VXE "ggml: enable vxe" ${GGML_NATIVE}) option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF) set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM") @@ -205,12 +209,14 @@ option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" option(GGML_CUDA_FA "ggml: compile ggml FlashAttention CUDA kernels" ON) option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF) option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT}) +option(GGML_CUDA_NCCL "ggml: use NVIDIA Collective Comm. Library" ON) set (GGML_CUDA_COMPRESSION_MODE "size" CACHE STRING "ggml: cuda link binary compression mode; requires cuda 12.8+") set_property(CACHE GGML_CUDA_COMPRESSION_MODE PROPERTY STRINGS "none;speed;balance;size") option(GGML_HIP "ggml: use HIP" OFF) -option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF) +option(GGML_HIP_GRAPHS "ggml: use HIP graph" ON) +option(GGML_HIP_RCCL "ggml: use ROCm Collective Comm. Library" OFF) option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON) option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF) option(GGML_HIP_MMQ_MFMA "ggml: enable MFMA MMA for CDNA in MMQ" ON) @@ -244,6 +250,7 @@ option(GGML_RPC "ggml: use RPC" option(GGML_SYCL "ggml: use SYCL" OFF) option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF) option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON) +option(GGML_SYCL_HOST_MEM_FALLBACK "ggml: allow host memory fallback in SYCL reorder (requires kernel 6.8+)" ON) option(GGML_SYCL_DNN "ggml: enable oneDNN in the SYCL backend" ON) set (GGML_SYCL_TARGET "INTEL" CACHE STRING "ggml: sycl target device") diff --git a/ggml/cmake/FindNCCL.cmake b/ggml/cmake/FindNCCL.cmake new file mode 100644 index 00000000000..67511e2d56a --- /dev/null +++ b/ggml/cmake/FindNCCL.cmake @@ -0,0 +1,36 @@ +# cmake/FindNCCL.cmake + +# NVIDIA does not distribute CMake files with NCCl, therefore use this file to find it instead. + +find_path(NCCL_INCLUDE_DIR + NAMES nccl.h + HINTS ${NCCL_ROOT} $ENV{NCCL_ROOT} $ENV{CUDA_HOME} /usr/local/cuda + PATH_SUFFIXES include +) + +find_library(NCCL_LIBRARY + NAMES nccl + HINTS ${NCCL_ROOT} $ENV{NCCL_ROOT} $ENV{CUDA_HOME} /usr/local/cuda + PATH_SUFFIXES lib lib64 +) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(NCCL + DEFAULT_MSG + NCCL_LIBRARY NCCL_INCLUDE_DIR +) + +if(NCCL_FOUND) + set(NCCL_LIBRARIES ${NCCL_LIBRARY}) + set(NCCL_INCLUDE_DIRS ${NCCL_INCLUDE_DIR}) + + if(NOT TARGET NCCL::NCCL) + add_library(NCCL::NCCL UNKNOWN IMPORTED) + set_target_properties(NCCL::NCCL PROPERTIES + IMPORTED_LOCATION "${NCCL_LIBRARY}" + INTERFACE_INCLUDE_DIRECTORIES "${NCCL_INCLUDE_DIR}" + ) + endif() +endif() + +mark_as_advanced(NCCL_INCLUDE_DIR NCCL_LIBRARY) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 9fd3f7f32a0..d0c7e5a1be0 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -68,7 +68,7 @@ extern "C" { GGML_API void ggml_backend_buffer_reset (ggml_backend_buffer_t buffer); // tensor copy between different backends - GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst); + GGML_API void ggml_backend_tensor_copy(const struct ggml_tensor * src, struct ggml_tensor * dst); // // Backend (stream) @@ -83,13 +83,17 @@ extern "C" { GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend); GGML_API size_t ggml_backend_get_max_size(ggml_backend_t backend); - GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_set_async (ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_get_async (ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_set_2d_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + GGML_API void ggml_backend_tensor_get_2d_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); // "offset" refers to the offset in tensor->data for setting/getting data - GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); - GGML_API void ggml_backend_tensor_memset( struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_set ( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_get (const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_set_2d( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + GGML_API void ggml_backend_tensor_get_2d(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + GGML_API void ggml_backend_tensor_memset( struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); GGML_API void ggml_backend_synchronize(ggml_backend_t backend); @@ -109,7 +113,7 @@ extern "C" { // the copy is performed after all the currently queued operations in backend_src // backend_dst will wait for the copy to complete before performing other operations // automatic fallback to sync copy if async is not supported - GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst); + GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst); GGML_API ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend); @@ -135,7 +139,9 @@ extern "C" { // integrated GPU device using host memory GGML_BACKEND_DEVICE_TYPE_IGPU, // accelerator devices intended to be used together with the CPU backend (e.g. BLAS or AMX) - GGML_BACKEND_DEVICE_TYPE_ACCEL + GGML_BACKEND_DEVICE_TYPE_ACCEL, + // "meta" device wrapping multiple other devices for tensor parallelism + GGML_BACKEND_DEVICE_TYPE_META, }; // functionality supported by the device @@ -196,7 +202,12 @@ extern "C" { // Common functions that may be obtained using ggml_backend_reg_get_proc_address - // Split buffer type for tensor parallelism + // Context management and operations for faster communication between backends, used for tensor parallelism (meta backend) + typedef void * (*ggml_backend_comm_init_t)(ggml_backend_t * backends, size_t n_backends); + typedef void (*ggml_backend_comm_free_t)(void * comm_ctx); + typedef bool (*ggml_backend_comm_allreduce_tensor_t)(void * comm_ctx, struct ggml_tensor ** tensors); + + // Split buffer type for tensor parallelism (old) typedef ggml_backend_buffer_type_t (*ggml_backend_split_buffer_type_t)(int main_device, const float * tensor_split); // Set the number of threads for the backend typedef void (*ggml_backend_set_n_threads_t)(ggml_backend_t backend, int n_threads); @@ -340,6 +351,53 @@ extern "C" { // Set a callback to be called for each resulting node during graph compute GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data); + // + // Meta backend + // + +#define GGML_BACKEND_META_MAX_DEVICES 16 + + enum ggml_backend_meta_split_axis { + // tensor split by tensor dimensions: + GGML_BACKEND_SPLIT_AXIS_0 = 0, + GGML_BACKEND_SPLIT_AXIS_1 = 1, + GGML_BACKEND_SPLIT_AXIS_2 = 2, + GGML_BACKEND_SPLIT_AXIS_3 = 3, + + GGML_BACKEND_SPLIT_AXIS_MIRRORED = 10, // all values on all backends + GGML_BACKEND_SPLIT_AXIS_PARTIAL = 11, // each backend has a partial sum + + // for internal bookkeeping only: + GGML_BACKEND_SPLIT_AXIS_NONE = 98, + GGML_BACKEND_SPLIT_AXIS_UNKNOWN = 99, + }; + GGML_API const char * ggml_backend_meta_split_axis_name(enum ggml_backend_meta_split_axis split_axis); + + struct ggml_backend_meta_split_state { + enum ggml_backend_meta_split_axis axis; + + // for tensors with axis >= 0 && axis < GGML_MAX_DIMS: + // - each device has a slice of the tensor along the split axis + // - most tensors have n_segments == 1 and a contiguous slice of the tensor data + // - some tensors have an inhomogenenous data layout along the split axis, + // those tensors are divided into segments which are each individually split across devices + // - ne has one entry per segment and device that add up to ggml_tensor::ne for that axis, + // the outer/inner loops are over segments/devices like [seg0_dev0, seg0_dev1, seg1_dev0, seg1_dev1], + // - for example, a transformer may have a fused QKV matrix rather than 3 matrices, those would be 3 separate segments + // that each need to be split individually across devices so that each device gets a slice of Q, K, and V + int64_t ne[16*GGML_BACKEND_META_MAX_DEVICES]; + uint32_t n_segments; + }; + + // function to assign split states for statically allocated tensors, compute tensor split states will be assigned to be compatible: + typedef struct ggml_backend_meta_split_state(*ggml_backend_meta_get_split_state_t)(const struct ggml_tensor * tensor, void * userdata); + + // create a new meta device from "simple" devices, meta buffer type/buffer/backend is then derived from this: + // TODO: this looks a bit strange - a backend API creates a device. I think we should try + // express this as a backend registry functionality instead + GGML_API ggml_backend_dev_t ggml_backend_meta_device( + ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud); + // // Utils // diff --git a/ggml/include/ggml-cuda.h b/ggml/include/ggml-cuda.h index 22ad2c00963..5436c7ef579 100644 --- a/ggml/include/ggml-cuda.h +++ b/ggml/include/ggml-cuda.h @@ -27,6 +27,9 @@ GGML_BACKEND_API bool ggml_backend_is_cuda(ggml_backend_t backend); // device buffer GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device); +// conduct allreduce operation between devices +GGML_BACKEND_API bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends); + // split tensor buffer that splits matrices by rows across multiple devices GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split); diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h index 1c11495b66e..6fcf5a43393 100644 --- a/ggml/include/ggml-rpc.h +++ b/ggml/include/ggml-rpc.h @@ -6,9 +6,9 @@ extern "C" { #endif -#define RPC_PROTO_MAJOR_VERSION 3 -#define RPC_PROTO_MINOR_VERSION 6 -#define RPC_PROTO_PATCH_VERSION 1 +#define RPC_PROTO_MAJOR_VERSION 4 +#define RPC_PROTO_MINOR_VERSION 0 +#define RPC_PROTO_PATCH_VERSION 0 #ifdef __cplusplus static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION"); diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 669f66b650f..703e3783136 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -428,7 +428,8 @@ extern "C" { // GGML_TYPE_IQ4_NL_8_8 = 38, GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block) GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale) - GGML_TYPE_COUNT = 41, + GGML_TYPE_Q1_0 = 41, + GGML_TYPE_COUNT = 42, }; // precision @@ -465,6 +466,7 @@ extern "C" { GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors GGML_FTYPE_MOSTLY_MXFP4 = 25, // except 1d tensors GGML_FTYPE_MOSTLY_NVFP4 = 26, // except 1d tensors + GGML_FTYPE_MOSTLY_Q1_0 = 27, // except 1d tensors }; // available tensor operations: @@ -900,15 +902,17 @@ extern "C" { struct ggml_tensor * b, struct ggml_tensor * ids); - GGML_API struct ggml_tensor * ggml_add1( + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_add1( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_tensor * b), + "use ggml_add instead"); - GGML_API struct ggml_tensor * ggml_add1_inplace( + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_add1_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_tensor * b), + "use ggml_add_inplace instead"); // dst = a // view(dst, nb1, nb2, nb3, offset) += b @@ -1769,8 +1773,32 @@ extern "C" { int n_dims, int mode); - // custom RoPE + // RoPE operations with extended options + // a is the input tensor to apply RoPE to, shape [n_embd, n_head, n_token] + // b is an int32 vector with size n_token // c is freq factors (e.g. phi3-128k), (optional) + // mode can be GGML_ROPE_TYPE_NORMAL or NEOX; for MROPE and VISION mode, use ggml_rope_multi + // + // pseudo-code for computing theta: + // for i in [0, n_dims/2): + // theta[i] = b[i] * powf(freq_base, -2.0 * i / n_dims); + // theta[i] = theta[i] / c[i]; # if c is provided, divide theta by c + // theta[i] = rope_yarn(theta[i], ...); # note: theta = theta * freq_scale is applied here + // + // other params are used by YaRN RoPE scaling, these default values will disable YaRN: + // freq_scale = 1.0f + // ext_factor = 0.0f + // attn_factor = 1.0f + // beta_fast = 0.0f + // beta_slow = 0.0f + // + // example: + // (marking: c = cos, s = sin, 0 = unrotated) + // given a single head with size = 8 --> [00000000] + // GGML_ROPE_TYPE_NORMAL n_dims = 4 --> [cscs0000] + // GGML_ROPE_TYPE_NORMAL n_dims = 8 --> [cscscscs] + // GGML_ROPE_TYPE_NEOX n_dims = 4 --> [ccss0000] + // GGML_ROPE_TYPE_NEOX n_dims = 8 --> [ccccssss] GGML_API struct ggml_tensor * ggml_rope_ext( struct ggml_context * ctx, struct ggml_tensor * a, @@ -1786,6 +1814,36 @@ extern "C" { float beta_fast, float beta_slow); + // multi-dimensional RoPE, for Qwen-VL and similar vision models + // mode can be either VISION, MROPE, IMROPE, cannot be combined with NORMAL or NEOX + // sections specify how many dimensions to rotate in each section: + // section length is equivalent to number of cos/sin pairs, NOT the number of dims + // (i.e. sum of 4 sections are expected to be n_dims/2) + // last sections can be 0, means ignored + // all other options are identical to ggml_rope_ext + // + // important note: + // - NEOX ordering is automatically applied and cannot be disabled for MROPE and VISION + // if you need normal ordering, there are 2 methods: + // (1) split the tensor manually using ggml_view + // (2) permute the weight upon conversion + // - for VISION, n_dims must be head_size/2 + // + // example M-RoPE: + // given sections = [t=4, y=2, x=2, 0] + // given a single head with size = 18 --> [000000000000000000] + // GGML_ROPE_TYPE_MROPE n_dims = 16 --> [ttttyyxxttttyyxx00] (cos/sin are applied in NEOX ordering) + // GGML_ROPE_TYPE_IMROPE n_dims = 16 --> [ttyxttyxttyxttyx00] (interleaved M-RoPE, still NEOX ordering) + // note: the theta for each dim is computed the same way as ggml_rope_ext, no matter the section + // in other words, idx used for theta: [0123456789... until n_dims/2], not reset for each section + // + // example vision RoPE: + // given sections = [y=4, x=4, 0, 0] (last 2 sections are ignored) + // given a single head with size = 8 --> [00000000] + // GGML_ROPE_TYPE_VISION n_dims = 4 --> [yyyyxxxx] + // other values of n_dims are untested and is undefined behavior + // note: unlike MROPE, the theta for each dim is computed differently for each section + // in other words, idx used for theta: [0123] for y section, then [0123] for x section GGML_API struct ggml_tensor * ggml_rope_multi( struct ggml_context * ctx, struct ggml_tensor * a, diff --git a/ggml/include/gguf.h b/ggml/include/gguf.h index 79ee202062b..02d5f221c03 100644 --- a/ggml/include/gguf.h +++ b/ggml/include/gguf.h @@ -77,6 +77,7 @@ extern "C" { }; GGML_API struct gguf_context * gguf_init_empty(void); + GGML_API struct gguf_context * gguf_init_from_file_ptr(FILE * file, struct gguf_init_params params); GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params); //GGML_API struct gguf_context * gguf_init_from_buffer(..); @@ -189,6 +190,7 @@ extern "C" { // // write the entire context to a binary file + GGML_API bool gguf_write_to_file_ptr(const struct gguf_context * ctx, FILE * file, bool only_meta); GGML_API bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta); // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index ddd7f5747ed..ecd054633e8 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -185,9 +185,8 @@ endif() # ggml -if (GGML_BACKEND_DL AND NOT BUILD_SHARED_LIBS) - message(FATAL_ERROR "GGML_BACKEND_DL requires BUILD_SHARED_LIBS") -endif() +# QVAC-18993: GGML_BACKEND_DL works with a static core when PIC is +# enabled below (mirrors the speech-stack patch in qvac-ext-ggml). add_library(ggml-base ../include/ggml.h @@ -200,6 +199,7 @@ add_library(ggml-base ggml.cpp ggml-alloc.c ggml-backend.cpp + ggml-backend-meta.cpp ggml-opt.cpp ggml-threading.cpp ggml-threading.h @@ -265,7 +265,7 @@ function(ggml_add_backend_library backend) target_link_libraries(${backend} PRIVATE ggml-base) target_include_directories(${backend} PRIVATE ..) - if (${BUILD_SHARED_LIBS}) + if (BUILD_SHARED_LIBS OR GGML_BACKEND_DL) target_compile_definitions(${backend} PRIVATE GGML_BACKEND_BUILD) target_compile_definitions(${backend} PUBLIC GGML_BACKEND_SHARED) endif() @@ -469,11 +469,10 @@ endforeach() target_link_libraries(ggml-base PRIVATE Threads::Threads) -find_library(MATH_LIBRARY m) -if (MATH_LIBRARY) - if (NOT WIN32 OR NOT DEFINED ENV{ONEAPI_ROOT}) - target_link_libraries(ggml-base PRIVATE m) - endif() +if (DEFINED MATH_LIBRARY) + target_link_libraries(ggml-base PRIVATE ${MATH_LIBRARY}) +elseif (NOT WIN32 AND NOT DEFINED ENV{ONEAPI_ROOT}) + target_link_libraries(ggml-base PRIVATE m) endif() if (CMAKE_SYSTEM_NAME MATCHES "Android") @@ -484,7 +483,7 @@ if(CMAKE_SYSTEM_NAME MATCHES "visionOS") target_compile_definitions(ggml-base PUBLIC _DARWIN_C_SOURCE) endif() -if (BUILD_SHARED_LIBS) +if (BUILD_SHARED_LIBS OR GGML_BACKEND_DL) foreach (target ggml-base ggml) set_target_properties(${target} PROPERTIES POSITION_INDEPENDENT_CODE ON) target_compile_definitions(${target} PRIVATE GGML_BUILD) diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index 7f414b2311c..a4b01ccf8a1 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -2,6 +2,7 @@ #include "ggml-backend-impl.h" #include "ggml.h" #include "ggml-impl.h" + #include #include #include @@ -1236,6 +1237,9 @@ size_t ggml_backend_alloc_ctx_tensors_from_buft_size(struct ggml_context * ctx, ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) { size_t nbytes_total = 0; + if (ggml_backend_buft_is_meta(buft)) { + return ggml_backend_meta_alloc_ctx_tensors_from_buft(ctx, buft); + } return ggml_backend_alloc_ctx_tensors_from_buft_impl(ctx, buft, &nbytes_total, /*no_alloc =*/ false); } diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h index 59190b7c465..9c56ec30c5f 100644 --- a/ggml/src/ggml-backend-impl.h +++ b/ggml/src/ggml-backend-impl.h @@ -49,6 +49,10 @@ extern "C" { void (*memset_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); void (*set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); void (*get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + // (optional) 2d data copies + void (*set_tensor_2d)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + void (*get_tensor_2d)(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + // (optional) tensor copy: dst is in the buffer, src may be in any buffer, including buffers from a different backend (return false if not supported) bool (*cpy_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); // clear the entire buffer @@ -80,6 +84,20 @@ extern "C" { GGML_API bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer); GGML_API void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage); + // + // Backend (meta) + // + + GGML_API bool ggml_backend_is_meta (ggml_backend_t backend); + GGML_API bool ggml_backend_buffer_is_meta(ggml_backend_buffer_t buf); + GGML_API bool ggml_backend_buft_is_meta (ggml_backend_buffer_type_t buft); + + GGML_API size_t ggml_backend_meta_n_backends (ggml_backend_t meta_backend); + GGML_API ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, size_t index); + + // temporary workaround to statically allocate tensors from a context in a deduplicated way: + GGML_API struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft); + // // Backend (stream) // @@ -90,8 +108,10 @@ extern "C" { void (*free)(ggml_backend_t backend); // (optional) asynchronous tensor data access - void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + void (*set_tensor_async) (ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + void (*get_tensor_async) (ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + void (*set_tensor_2d_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + void (*get_tensor_2d_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); bool (*cpy_tensor_async)(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst); // (optional) complete all pending operations (required if the backend supports async operations) diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp new file mode 100644 index 00000000000..c0ffd9a048b --- /dev/null +++ b/ggml/src/ggml-backend-meta.cpp @@ -0,0 +1,2143 @@ +#include "ggml.h" +#include "ggml-impl.h" +#include "ggml-backend.h" +#include "ggml-backend-impl.h" +#include "ggml-alloc.h" +#include "ggml-cpp.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct ggml_backend_meta_device; +struct ggml_backend_meta_buffer_type; +struct ggml_backend_meta_buffer; +struct ggml_backend_meta; + +const char * ggml_backend_meta_split_axis_name(enum ggml_backend_meta_split_axis split_axis) { + switch (split_axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + return "0"; + case GGML_BACKEND_SPLIT_AXIS_1: + return "1"; + case GGML_BACKEND_SPLIT_AXIS_2: + return "2"; + case GGML_BACKEND_SPLIT_AXIS_3: + return "3"; + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: + return "MIRRORED"; + case GGML_BACKEND_SPLIT_AXIS_PARTIAL: + return "PARTIAL"; + case GGML_BACKEND_SPLIT_AXIS_NONE: + return "NONE"; + case GGML_BACKEND_SPLIT_AXIS_UNKNOWN: + return "UNKNOWN"; + default: + GGML_ABORT("fatal error"); + } +} + +// +// meta backend device +// + +struct ggml_backend_meta_device_context { + std::vector simple_devs; + ggml_backend_meta_get_split_state_t get_split_state; + void * get_split_state_ud; + + std::string name; + std::string description; + + ggml_backend_meta_device_context( + std::vector simple_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud) : + simple_devs(std::move(simple_devs)), get_split_state(get_split_state), get_split_state_ud(get_split_state_ud) { + name = std::string("Meta("); + description = std::string("Meta("); + for (size_t i = 0; i < simple_devs.size(); i++) { + if (i > 0) { + name += ","; + description += ","; + } + name += ggml_backend_dev_name (simple_devs[i]); + description += ggml_backend_dev_description(simple_devs[i]); + } + name += ")"; + description += ")"; + } + + bool operator<(const ggml_backend_meta_device_context & other) const { + return std::tie(simple_devs, get_split_state, get_split_state_ud) + < std::tie(other.simple_devs, other.get_split_state, other.get_split_state_ud); + } +}; + +static bool ggml_backend_dev_is_meta(ggml_backend_dev_t dev); + +static const char * ggml_backend_meta_device_get_name(ggml_backend_dev_t dev) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + return meta_dev_ctx->name.c_str(); +} + +static const char * ggml_backend_meta_device_get_description(ggml_backend_dev_t dev) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + return meta_dev_ctx->description.c_str(); +} + +static void ggml_backend_meta_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + *free = 0; + *total = 0; + for (ggml_backend_dev_t dev : meta_dev_ctx->simple_devs) { + size_t tmp_free, tmp_total; + ggml_backend_dev_memory(dev, &tmp_free, &tmp_total); + *free += tmp_free; + *total += tmp_total; + } +} + +static enum ggml_backend_dev_type ggml_backend_meta_device_get_type(ggml_backend_dev_t dev) { + return GGML_BACKEND_DEVICE_TYPE_META; + + GGML_UNUSED(dev); +} + +static void ggml_backend_meta_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + + // TODO replace placeholders + props->name = ggml_backend_meta_device_get_name(dev); + props->description = ggml_backend_meta_device_get_description(dev); + props->type = ggml_backend_meta_device_get_type(dev); + props->device_id = 0; + + ggml_backend_meta_device_get_memory(dev, &props->memory_free, &props->memory_total); + + props->caps = { + /* .async = */ true, + /* .host_buffer = */ false, // Not implemented. + /* .buffer_from_host_ptr = */ false, // Not implemented. + /* .events = */ false, // Not implemented. + }; + for (ggml_backend_dev_t simple_dev : meta_dev_ctx->simple_devs) { + ggml_backend_dev_props tmp_props; + ggml_backend_dev_get_props(simple_dev, &tmp_props); + props->caps.async = props->caps.async && tmp_props.caps.async; + props->caps.host_buffer = props->caps.host_buffer && tmp_props.caps.host_buffer; + props->caps.buffer_from_host_ptr = props->caps.buffer_from_host_ptr && tmp_props.caps.buffer_from_host_ptr; + props->caps.events = props->caps.events && tmp_props.caps.events; + } +} + +static ggml_backend_t ggml_backend_meta_device_init_backend(ggml_backend_dev_t dev, const char * params); + +static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_backend_dev_t dev); + +static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev); + +static bool ggml_backend_meta_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + return std::all_of(meta_dev_ctx->simple_devs.begin(), meta_dev_ctx->simple_devs.end(), + [op](ggml_backend_dev_t simple_dev) { return ggml_backend_dev_supports_op(simple_dev, op); }); +} + +static bool ggml_backend_meta_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + ggml_backend_dev_t dev_buft = ggml_backend_buft_get_device(buft); + if (!ggml_backend_dev_is_meta(dev_buft)) { + return false; + } + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + const ggml_backend_meta_device_context * meta_buft_dev_ctx = (const ggml_backend_meta_device_context *) dev_buft->context; + if (meta_dev_ctx->simple_devs.size() != meta_buft_dev_ctx->simple_devs.size()) { + return false; + } + for (size_t i = 0; i < meta_dev_ctx->simple_devs.size(); i++) { + if (meta_dev_ctx->simple_devs[i] != meta_buft_dev_ctx->simple_devs[i]) { + return false; + } + } + return true; +} + +static const ggml_backend_device_i ggml_backend_meta_device_iface = { + /* .get_name = */ ggml_backend_meta_device_get_name, + /* .get_description = */ ggml_backend_meta_device_get_description, + /* .get_memory = */ ggml_backend_meta_device_get_memory, + /* .get_type = */ ggml_backend_meta_device_get_type, + /* .get_props = */ ggml_backend_meta_device_get_props, + /* .init_backend = */ ggml_backend_meta_device_init_backend, + /* .get_buffer_type = */ ggml_backend_meta_device_get_buffer_type, + /* .get_host_buffer_type = */ ggml_backend_meta_device_get_host_buffer_type, + /* .buffer_from_host_ptr = */ nullptr, + /* .supports_op = */ ggml_backend_meta_device_supports_op, + /* .supports_buft = */ ggml_backend_meta_device_supports_buft, + /* .offload_op = */ nullptr, + /* .event_new = */ nullptr, + /* .event_free = */ nullptr, + /* .event_synchronize = */ nullptr, +}; + +static bool ggml_backend_dev_is_meta(ggml_backend_dev_t dev) { + return dev != nullptr && dev->iface.get_name == ggml_backend_meta_device_iface.get_name; +} + +static size_t ggml_backend_meta_dev_n_devs(ggml_backend_dev_t meta_dev) { + GGML_ASSERT(ggml_backend_dev_is_meta(meta_dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) meta_dev->context; + return meta_dev_ctx->simple_devs.size(); +} + +static ggml_backend_dev_t ggml_backend_meta_dev_simple_dev(ggml_backend_dev_t meta_dev, size_t index) { + GGML_ASSERT(ggml_backend_dev_is_meta(meta_dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) meta_dev->context; + GGML_ASSERT(index < meta_dev_ctx->simple_devs.size()); + return meta_dev_ctx->simple_devs[index]; +} + +ggml_backend_dev_t ggml_backend_meta_device( + ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud) { + GGML_ASSERT(n_devs <= GGML_BACKEND_META_MAX_DEVICES); + // TODO: this is not thread-safe - needs to be fixed + static std::vector> ctxs; + static std::map meta_devs; + + std::vector simple_devs; + simple_devs.reserve(n_devs); + for (size_t i = 0; i < n_devs; i++) { + simple_devs.push_back(devs[i]); + } + ggml_backend_meta_device_context ctx(simple_devs, get_split_state, get_split_state_ud); + + { + auto it = meta_devs.find(ctx); + if (it != meta_devs.end()) { + return &it->second; + } + } + ctxs.push_back(std::make_unique(ctx)); + + struct ggml_backend_device meta_dev = { + /*iface =*/ ggml_backend_meta_device_iface, + /*reg =*/ nullptr, + /*ctx =*/ ctxs.back().get(), + }; + + auto result = meta_devs.emplace(*ctxs.back(), meta_dev); + return &result.first->second; +} + +// +// meta backend buffer type +// + +struct ggml_backend_meta_buffer_type_context { + std::vector simple_bufts; + + std::string name; + + ggml_backend_meta_buffer_type_context(std::vector simple_bufts) : simple_bufts(std::move(simple_bufts)) { + name = "Meta("; + for (size_t i = 0; i < simple_bufts.size(); i++) { + if (i > 0) { + name += ","; + } + name += ggml_backend_buft_name(simple_bufts[i]); + } + name += ")"; + } + + bool operator<(const ggml_backend_meta_buffer_type_context & other) const { + return simple_bufts < other.simple_bufts; + } +}; + +static size_t ggml_backend_meta_buft_n_bufts(ggml_backend_buffer_type_t meta_buft) { + GGML_ASSERT(ggml_backend_buft_is_meta(meta_buft)); + const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) meta_buft->context; + return meta_buft_ctx->simple_bufts.size(); +} + +static const char * ggml_backend_meta_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + GGML_ASSERT(ggml_backend_buft_is_meta(buft)); + const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) buft->context; + return meta_buft_ctx->name.c_str(); +} + +static ggml_backend_buffer_type_t ggml_backend_meta_buft_simple_buft(ggml_backend_buffer_type_t meta_buft, size_t index) { + GGML_ASSERT(ggml_backend_buft_is_meta(meta_buft)); + const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) meta_buft->context; + GGML_ASSERT(index < meta_buft_ctx->simple_bufts.size()); + return meta_buft_ctx->simple_bufts[index]; +} + +static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size); + +static size_t ggml_backend_meta_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + size_t max_alignment = 1; + for (size_t i = 0; i < n_simple_bufts; i++) { + const size_t alignment = ggml_backend_buft_get_alignment(ggml_backend_meta_buft_simple_buft(buft, i)); + max_alignment = std::max(max_alignment, alignment); + GGML_ASSERT(max_alignment % alignment == 0); + } + return max_alignment; +} + +static size_t ggml_backend_meta_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + size_t max_size = SIZE_MAX; + for (size_t i = 0; i < n_simple_bufts; i++) { + max_size = std::min(max_size, ggml_backend_buft_get_max_size(ggml_backend_meta_buft_simple_buft(buft, i))); + } + return max_size; +} + +static size_t ggml_backend_meta_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + size_t max_alloc_size = 0; + for (size_t i = 0; i < n_simple_bufts; i++) { + const size_t alloc_size = ggml_backend_buft_get_alloc_size(ggml_backend_meta_buft_simple_buft(buft, i), tensor); + max_alloc_size = std::max(max_alloc_size, alloc_size); + } + return max_alloc_size; +} + +static bool ggml_backend_meta_buffer_type_is_host(ggml_backend_buffer_type_t buft) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + for (size_t i = 0; i < n_simple_bufts; i++) { + if (!ggml_backend_buft_is_host(ggml_backend_meta_buft_simple_buft(buft, i))) { + return false; + } + } + return true; +} + +static const struct ggml_backend_buffer_type_i ggml_backend_meta_buffer_type_iface = { + /* .get_name = */ ggml_backend_meta_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_meta_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_meta_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_meta_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_meta_buffer_type_get_alloc_size, + /* .is_host = */ ggml_backend_meta_buffer_type_is_host, +}; + +bool ggml_backend_buft_is_meta(ggml_backend_buffer_type_t buft) { + return buft != nullptr && buft->iface.get_name == ggml_backend_meta_buffer_type_iface.get_name; +} + +static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_backend_dev_t dev) { + static std::map meta_bufts; + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + { + auto it = meta_bufts.find(dev); + if (it != meta_bufts.end()) { + return &it->second; + } + } + + const size_t n_devs = ggml_backend_meta_dev_n_devs(dev); + std::vector simple_bufts; + simple_bufts.reserve(n_devs); + for (size_t i = 0; i < n_devs; i++) { + simple_bufts.push_back(ggml_backend_dev_buffer_type(ggml_backend_meta_dev_simple_dev(dev, i))); + } + ggml_backend_meta_buffer_type_context * buft_ctx = new ggml_backend_meta_buffer_type_context(simple_bufts); + + struct ggml_backend_buffer_type meta_buft = { + /*iface =*/ ggml_backend_meta_buffer_type_iface, + /*device =*/ dev, + /*ctx =*/ buft_ctx, + }; + auto result = meta_bufts.emplace(dev, meta_buft); + return &result.first->second; +} + +static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + + ggml_backend_buffer_type_t host_buft = nullptr; + for (ggml_backend_dev_t simple_dev : meta_dev_ctx->simple_devs) { + ggml_backend_buffer_type_t simple_host_buft = ggml_backend_dev_host_buffer_type(simple_dev); + if (simple_host_buft == nullptr) { + return nullptr; + } + if (host_buft == nullptr) { + host_buft = simple_host_buft; + } else if (host_buft != simple_host_buft) { + // if different simple devices have different host buffer types, + // we cannot provide a single host buffer type for the meta device + return nullptr; + } + } + return host_buft; +} + +// +// meta backend buffer +// + +struct ggml_backend_meta_buffer_context { + static constexpr size_t nbtc = GGML_TENSOR_SIZE - sizeof(ggml_tensor::padding); + + std::map, std::pair> split_state_cache; + std::map< const ggml_tensor *, std::vector> simple_tensors; + + struct buffer_config { + ggml_context * ctx; + ggml_backend_buffer_t buf; + + buffer_config(ggml_context * ctx, ggml_backend_buffer_t buf) : ctx(ctx), buf(buf) {} + }; + std::vector buf_configs; + + int debug; + + ggml_backend_meta_buffer_context() { + const char * GGML_META_DEBUG = getenv("GGML_META_DEBUG"); + debug = GGML_META_DEBUG ? atoi(GGML_META_DEBUG) : 0; + } +}; + +static void ggml_backend_meta_buffer_free_buffer(ggml_backend_buffer_t buffer) { + GGML_ASSERT(ggml_backend_buffer_is_meta(buffer)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context; + for (auto & [ctx, buf] : buf_ctx->buf_configs) { + ggml_backend_buffer_free(buf); + ggml_free(ctx); + } + delete buf_ctx; +} + +static size_t ggml_backend_meta_buffer_n_bufs(ggml_backend_buffer_t meta_buf) { + GGML_ASSERT(ggml_backend_buffer_is_meta(meta_buf)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) meta_buf->context; + return buf_ctx->buf_configs.size(); +} + +static ggml_backend_buffer_t ggml_backend_meta_buffer_simple_buffer(ggml_backend_buffer_t meta_buf, size_t index) { + GGML_ASSERT(ggml_backend_buffer_is_meta(meta_buf)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) meta_buf->context; + GGML_ASSERT(index < buf_ctx->buf_configs.size()); + return buf_ctx->buf_configs[index].buf; +} + +static struct ggml_tensor * ggml_backend_meta_buffer_simple_tensor(const struct ggml_tensor * tensor, size_t index) { + GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context; + GGML_ASSERT(index < buf_ctx->buf_configs.size()); + + auto it = buf_ctx->simple_tensors.find(tensor); + if (it == buf_ctx->simple_tensors.end()) { + return nullptr; + } + return it->second[index]; +} + +static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync) { + const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context; + + auto split_states_equal = [&](const ggml_backend_meta_split_state & a, const ggml_backend_meta_split_state & b) -> bool { + if (a.axis != b.axis) { + return false; + } + for (size_t j = 0; j < n_bufs; j++) { + int64_t sum_a = 0; + for (size_t s = 0; s < a.n_segments; s++) { + sum_a += a.ne[s*n_bufs + j]; + } + int64_t sum_b = 0; + for (size_t s = 0; s < b.n_segments; s++) { + sum_b += b.ne[s*n_bufs + j]; + } + if (sum_a != sum_b) { + return false; + } + } + return true; + }; + + auto handle_generic = [&](const std::vector & src_ss, bool scalar_only) -> ggml_backend_meta_split_state { + ggml_backend_meta_split_state ret = {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, 1}; + for (size_t i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] == nullptr || tensor->src[i] == tensor) { + continue; + } + if (ret.axis == GGML_BACKEND_SPLIT_AXIS_NONE) { + ret = src_ss[i]; + } else if (!split_states_equal(src_ss[i], ret)) { + ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + break; + } + } + if (ret.axis == GGML_BACKEND_SPLIT_AXIS_NONE) { + ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + } + if (scalar_only && ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) { + ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + } + GGML_ASSERT(ret.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); + return ret; + }; + + // Some ops process data on a per-row bases: + auto handle_per_row = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + GGML_ASSERT(src_ss[0].axis != GGML_BACKEND_SPLIT_AXIS_0); + return src_ss[0]; + }; + + // Some ops broadcast the src1 data across src0: + auto handle_bin_bcast = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS && + tensor->src[1]->ne[src_ss[0].axis] == 1 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + return src_ss[0]; + } + if (src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && (src_ss[0].axis == src_ss[1].axis || + (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL)))) { + return src_ss[0]; // GGML_OP_ADD_ID + } + GGML_ASSERT(tensor->src[2] == nullptr || src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); + return handle_generic(src_ss, /*scalar_only =*/ false); + }; + + auto handle_concat = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + const ggml_backend_meta_split_axis concat_axis = ggml_backend_meta_split_axis(ggml_get_op_params_i32(tensor, 0)); + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis >= 0 && src_ss[1].axis < GGML_MAX_DIMS) { + GGML_ASSERT(concat_axis != src_ss[1].axis); + return src_ss[1]; + } + if (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) { + GGML_ASSERT(concat_axis != src_ss[0].axis); + return src_ss[0]; + } + if (src_ss[0].axis == src_ss[1].axis && src_ss[0].axis != concat_axis) { + return src_ss[0]; + } + return handle_generic(src_ss, /*scalar_only =*/ true); + }; + + auto handle_mul_mat = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + return {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, 1}; + } + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + ggml_backend_meta_split_state ret = src_ss[0]; + ret.axis = GGML_BACKEND_SPLIT_AXIS_0; + ret.n_segments = 1; + return ret; + } + if (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + ggml_backend_meta_split_state ret = src_ss[1]; + ret.n_segments = 1; + return ret; + } + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_0) { + GGML_ASSERT(split_states_equal(src_ss[0], src_ss[1])); + return {assume_sync ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_PARTIAL, {0}, 1}; + } + GGML_ABORT("fatal error"); + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + }; + + auto handle_cpy = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) { + int64_t ne_split_src = tensor->src[0]->ne[0]; + for (int dim = 1; dim <= src_ss[0].axis; dim++) { + ne_split_src *= tensor->src[0]->ne[dim]; + } + int64_t ne_split_dst = 1; + for (int dim = 0; dim < GGML_MAX_DIMS; dim++) { + ne_split_dst *= tensor->ne[dim]; + if (ne_split_dst == ne_split_src) { + return {ggml_backend_meta_split_axis(dim), {0}, 1}; + } + } + } + return handle_generic(src_ss, /*scalar_only =*/ false); + }; + + auto handle_reshape = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + switch (src_ss[0].axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: + case GGML_BACKEND_SPLIT_AXIS_3: { + GGML_ASSERT(!ggml_is_permuted(tensor) && !ggml_is_permuted(tensor->src[0])); + if (src_ss[0].axis == ggml_n_dims(tensor->src[0]) - 1) { + return {ggml_backend_meta_split_axis(ggml_n_dims(tensor) - 1), {0}, 1}; + } + std::vector base_ne_in; + base_ne_in.reserve(GGML_MAX_DIMS - src_ss[0].axis); + { + base_ne_in.push_back(1); + int dim = 0; + for (; dim <= src_ss[0].axis; dim++) { + base_ne_in[0] *= tensor->src[0]->ne[dim]; + } + for (; dim <= GGML_MAX_DIMS; dim++) { + base_ne_in.push_back(base_ne_in.back() * tensor->src[0]->ne[dim]); + } + } + int64_t base_ne_out = 1; + for (int dim = 0; dim < GGML_MAX_DIMS; dim++) { + const int64_t base_ne_out_next = base_ne_out *= tensor->ne[dim]; + for (const int64_t & bni : base_ne_in) { + if (bni == base_ne_out_next) { + return {ggml_backend_meta_split_axis(dim), {0}, 1}; + } + } + if (base_ne_out_next > base_ne_in[0]) { + GGML_ASSERT(dim + 1 < GGML_MAX_DIMS); + return {ggml_backend_meta_split_axis(dim + 1), {0}, 1}; + } + base_ne_out = base_ne_out_next; + } + GGML_ABORT("shape mismatch for %s", ggml_op_name(tensor->op)); + } + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: + case GGML_BACKEND_SPLIT_AXIS_PARTIAL: { + return src_ss[0]; + } + default: { + GGML_ABORT("fatal error"); + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + } + } + }; + + auto handle_view = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + if (ggml_is_contiguous(tensor) && ggml_is_contiguous(tensor->src[0])) { + return handle_reshape(src_ss); + } + const int axis = src_ss[0].axis; + { + bool all_strides_the_same = true; + for (int dim = 0; dim < GGML_MAX_DIMS; dim++) { + if (tensor->ne[dim] == 1 && tensor->src[0]->ne[dim] == 1) { + continue; + } + if (tensor->nb[dim] != tensor->src[0]->nb[dim]) { + all_strides_the_same = false; + break; + } + } + if (all_strides_the_same) { + return src_ss[0]; + } + } + if (!ggml_is_permuted(tensor) && !ggml_is_permuted(tensor->src[0]) && axis >= 0 && axis < GGML_MAX_DIMS-1) { + for (int dim = 0; dim < GGML_MAX_DIMS-1; dim++) { + if (tensor->nb[dim+1] == tensor->src[0]->nb[axis+1]) { + return {ggml_backend_meta_split_axis(dim), {0}, 1}; + } + } + GGML_ABORT("fatal error"); + } + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED || src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) { + return src_ss[0]; + } + GGML_ABORT("view of permuted tensor not implemented"); + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + }; + + auto handle_permute = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + switch (src_ss[0].axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: + case GGML_BACKEND_SPLIT_AXIS_3: { + return {ggml_backend_meta_split_axis(tensor->op_params[src_ss[0].axis]), {0}, 1}; + } + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: + case GGML_BACKEND_SPLIT_AXIS_PARTIAL: { + return src_ss[0]; + } + default: { + GGML_ABORT("fatal error"); + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + } + } + }; + + auto handle_transpose = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + switch (src_ss[0].axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: { + return {ggml_backend_meta_split_axis(int(src_ss[0].axis) ^ 1), {0}, 1}; + } + case GGML_BACKEND_SPLIT_AXIS_2: + case GGML_BACKEND_SPLIT_AXIS_3: + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: + case GGML_BACKEND_SPLIT_AXIS_PARTIAL: { + return src_ss[0]; + } + default: { + GGML_ABORT("fatal error"); + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + } + } + }; + + auto handle_get_rows = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + return src_ss[0]; + } + return handle_generic(src_ss, /*scalar_only =*/ true); + }; + + auto handle_set_rows = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + GGML_ASSERT(src_ss[0].axis != GGML_BACKEND_SPLIT_AXIS_1); + GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); + GGML_ASSERT(split_states_equal(src_ss[0], src_ss[2])); + return src_ss[0]; + }; + + auto handle_rope = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); + return src_ss[0]; + }; + + auto handle_pad = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) { + GGML_ASSERT(tensor->op_params[2*src_ss[0].axis + 0] == 0); + GGML_ASSERT(tensor->op_params[2*src_ss[0].axis + 1] == 0); + } + return src_ss[0]; + }; + + auto handle_flash_attn_ext = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + GGML_ASSERT( src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_2); + GGML_ASSERT( src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_2); + GGML_ASSERT( src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_2); + GGML_ASSERT(tensor->src[4] == nullptr || src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); + GGML_ASSERT(tensor->src[4] == nullptr || src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_0); + return {GGML_BACKEND_SPLIT_AXIS_1, {0}, 1}; + }; + + auto handle_ssm_conv = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis == src_ss[1].axis) { + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0) { + return {GGML_BACKEND_SPLIT_AXIS_1, {0}, 1}; + } + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1) { + return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1}; + } + } + return handle_generic(src_ss, /*scalar_only =*/ false); + }; + + auto handle_gated_delta_net = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && + src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && + src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + return src_ss[0]; + } + GGML_ASSERT(src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1); + GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_1); + GGML_ASSERT(src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_1); + GGML_ASSERT(src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_1); + GGML_ASSERT(src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_1); + GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2); + return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1}; + }; + + auto calculate_split_state = [&]() -> ggml_backend_meta_split_state { + if (ggml_nelements(tensor) == 0) { + return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + } + if (ggml_backend_buffer_get_usage(tensor->buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE && tensor->view_src == nullptr) { + ggml_backend_dev_t dev = ggml_backend_buft_get_device(ggml_backend_buffer_get_type(tensor->buffer)); + const ggml_backend_meta_device_context * dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + ggml_backend_meta_split_state ret = dev_ctx->get_split_state(tensor, dev_ctx->get_split_state_ud); + if (ret.axis >= 0 && ret.axis <= GGML_MAX_DIMS) { + const int64_t granularity = ret.axis == GGML_BACKEND_SPLIT_AXIS_0 ? ggml_blck_size(tensor->type) : 1; + int64_t ne_sum = 0; + for (size_t sj = 0; sj < ret.n_segments*n_bufs; sj++) { + GGML_ASSERT(ret.ne[sj] % granularity == 0); + ne_sum += ret.ne[sj]; + } + GGML_ASSERT(ne_sum == tensor->ne[ret.axis]); + } + return ret; + } + + std::vector src_ss(GGML_MAX_SRC, {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, 1}); + for (size_t i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] == nullptr || tensor->src[i] == tensor) { + src_ss[i] = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + continue; + } + src_ss[i] = ggml_backend_meta_get_split_state(tensor->src[i], /*assume_sync =*/ true); + GGML_ASSERT(src_ss[i].axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); + } + + ggml_backend_meta_split_state split_state; + switch (tensor->op) { + case GGML_OP_NONE: { + split_state = {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, 1}; + } break; + case GGML_OP_DUP: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_ADD: + case GGML_OP_ADD_ID: { + split_state = handle_bin_bcast(src_ss); + } break; + case GGML_OP_ADD1: + case GGML_OP_ACC: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: { + split_state = handle_bin_bcast(src_ss); + } break; + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_LOG: + case GGML_OP_SIN: + case GGML_OP_COS: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_SUM: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_SUM_ROWS: + case GGML_OP_CUMSUM: + case GGML_OP_MEAN: + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: { + split_state = handle_per_row(src_ss); + } break; + case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_CONCAT: { + split_state = handle_concat(src_ss); + } break; + case GGML_OP_SILU_BACK: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: + case GGML_OP_GROUP_NORM: + case GGML_OP_L2_NORM: { + split_state = handle_per_row(src_ss); + } break; + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: { + split_state = handle_mul_mat(src_ss); + } break; + case GGML_OP_OUT_PROD: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_SCALE: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_SET: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_CPY: { + split_state = handle_cpy(src_ss); + } break; + case GGML_OP_CONT: + case GGML_OP_RESHAPE: { + split_state = handle_reshape(src_ss); + } break; + case GGML_OP_VIEW: { + split_state = handle_view(src_ss); + } break; + case GGML_OP_PERMUTE: { + split_state = handle_permute(src_ss); + } break; + case GGML_OP_TRANSPOSE: { + split_state = handle_transpose(src_ss); + } break; + case GGML_OP_GET_ROWS: { + split_state = handle_get_rows(src_ss); + } break; + case GGML_OP_GET_ROWS_BACK: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_SET_ROWS: { + split_state = handle_set_rows(src_ss); + } break; + case GGML_OP_DIAG: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_DIAG_MASK_ZERO: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_ROPE: { + split_state = handle_rope(src_ss); + } break; + case GGML_OP_ROPE_BACK: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_CLAMP: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_CONV_TRANSPOSE_1D: + case GGML_OP_IM2COL: + case GGML_OP_IM2COL_BACK: + case GGML_OP_IM2COL_3D: + case GGML_OP_CONV_2D: + case GGML_OP_CONV_3D: + case GGML_OP_CONV_2D_DW: + case GGML_OP_CONV_TRANSPOSE_2D: + case GGML_OP_POOL_1D: + case GGML_OP_POOL_2D: + case GGML_OP_POOL_2D_BACK: + case GGML_OP_UPSCALE: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_PAD: { + split_state = handle_pad(src_ss); + } break; + case GGML_OP_PAD_REFLECT_1D: + case GGML_OP_ROLL: + case GGML_OP_ARANGE: + case GGML_OP_TIMESTEP_EMBEDDING: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_ARGSORT: + case GGML_OP_TOP_K: { + split_state = handle_per_row(src_ss); + } break; + case GGML_OP_LEAKY_RELU: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_TRI: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_FILL: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_FLASH_ATTN_EXT: { + split_state = handle_flash_attn_ext(src_ss); + } break; + case GGML_OP_FLASH_ATTN_BACK: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_SSM_CONV: { + split_state = handle_ssm_conv(src_ss); + } break; + case GGML_OP_SSM_SCAN: + case GGML_OP_WIN_PART: + case GGML_OP_WIN_UNPART: + case GGML_OP_GET_REL_POS: + case GGML_OP_ADD_REL_POS: + case GGML_OP_RWKV_WKV6: + case GGML_OP_GATED_LINEAR_ATTN: + case GGML_OP_RWKV_WKV7: + case GGML_OP_SOLVE_TRI: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_GATED_DELTA_NET: { + split_state = handle_gated_delta_net(src_ss); + } break; + case GGML_OP_UNARY: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_MAP_CUSTOM1: + case GGML_OP_MAP_CUSTOM2: + case GGML_OP_MAP_CUSTOM3: + case GGML_OP_CUSTOM: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_CROSS_ENTROPY_LOSS: + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: { + split_state = handle_per_row(src_ss); + } break; + case GGML_OP_OPT_STEP_ADAMW: + case GGML_OP_OPT_STEP_SGD: + case GGML_OP_GLU: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + default: { + GGML_ABORT("ggml op not implemented: %s", ggml_op_name(tensor->op)); + split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + } break; + } + if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) { + bool first_src_split_by_axis = true; + const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer); + + for (size_t i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] == nullptr || src_ss[i].axis < 0 || src_ss[i].axis >= GGML_MAX_DIMS) { + continue; + } + if (first_src_split_by_axis) { + for (size_t j = 0; j < n_bufs; j++) { + // Take over ratio from src: + for (size_t s = 0; s < src_ss[i].n_segments; s++) { + split_state.ne[s*n_bufs + j] = 0; + } + for (size_t s = 0; s < src_ss[i].n_segments; s++) { + split_state.ne[j] += src_ss[i].ne[s*n_bufs + j]; + } + split_state.ne[j] *= tensor->ne[split_state.axis]; + if (split_state.ne[j] != 0 || tensor->src[i]->ne[src_ss[i].axis] != 0) { + GGML_ASSERT(split_state.ne[j] % tensor->src[i]->ne[src_ss[i].axis] == 0); + split_state.ne[j] /= tensor->src[i]->ne[src_ss[i].axis]; + } + } + } else { + for (size_t j = 0; j < n_bufs; j++) { + int64_t sum = 0; + for (size_t s = 0; s < src_ss[i].n_segments; s++) { + sum += src_ss[i].ne[s*n_bufs + j]; + } + // Assert that ratio is consistent: + GGML_ASSERT(split_state.ne[j] * tensor->src[i]->ne[src_ss[i].axis] + == sum * tensor->ne[split_state.axis]); + } + } + first_src_split_by_axis = false; + } + GGML_ASSERT(!first_src_split_by_axis); + } + return split_state; + }; + + const std::pair key = std::make_pair(tensor, assume_sync); + auto it = buf_ctx->split_state_cache.find(key); + if (it != buf_ctx->split_state_cache.end() && memcmp(it->second.second, (const char *) tensor, sizeof(it->second.second)) != 0) { + buf_ctx->split_state_cache.clear(); + it = buf_ctx->split_state_cache.end(); + } + + if (it == buf_ctx->split_state_cache.end()) { + buf_ctx->split_state_cache[key].first = calculate_split_state(); + memcpy(buf_ctx->split_state_cache[key].second, tensor, sizeof(buf_ctx->split_state_cache[key].second)); + if (buf_ctx->debug > 0) { + std::string srcs_info; + for (size_t i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] == nullptr) { + continue; + } + if (!srcs_info.empty()) { + srcs_info += ", "; + } + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor->src[0], true); + const char * axis_name = ggml_backend_meta_split_axis_name(split_state.axis); + std::string ne_info; + for (size_t j = 0; j < n_bufs; j++) { + if (!ne_info.empty()) { + ne_info += ", "; + } + ne_info += std::to_string(split_state.ne[j]); + } + srcs_info += std::string(tensor->src[i]->name) + "[" + ggml_op_name(tensor->src[i]->op) + ", " + axis_name + ", {" + ne_info + "}]"; + } + std::string ne_info; + for (size_t j = 0; j < n_bufs; j++) { + if (!ne_info.empty()) { + ne_info += ", "; + } + ne_info += std::to_string(buf_ctx->split_state_cache[key].first.ne[j]); + } + GGML_LOG_DEBUG("SPLIT_STATE: {%s} -> %s[%s, %s, {%s}]\n", srcs_info.c_str(), tensor->name, ggml_op_name(tensor->op), + ggml_backend_meta_split_axis_name(buf_ctx->split_state_cache[key].first.axis), ne_info.c_str()); + } + } + + ggml_backend_meta_split_state ret = buf_ctx->split_state_cache[key].first; + GGML_ASSERT(ret.axis != GGML_BACKEND_SPLIT_AXIS_NONE); +#ifndef NDEBUG + if (ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) { + int64_t ne_ret = 0; + for (size_t sj = 0; sj < ret.n_segments*n_bufs; sj++) { + ne_ret += ret.ne[sj]; + } + assert(ne_ret == tensor->ne[int(ret.axis)]); + } +#endif // NDEBUG + return ret; +} + +static void * ggml_backend_meta_buffer_get_base(ggml_backend_buffer_t buffer) { + GGML_UNUSED(buffer); + return (void *) 0x1000000000000000; // FIXME +} + +static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + GGML_ASSERT(ggml_backend_buffer_is_meta(buffer)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context; + const size_t n_simple_bufs = ggml_backend_meta_buffer_n_bufs(buffer); + + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ true); + GGML_ASSERT(ggml_nelements(tensor) == 0 || split_state.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); + GGML_ASSERT(split_state.n_segments <= 16); + + int split_dim = split_state.axis; + int64_t ne[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS]; + for (size_t k = 0; k < GGML_MAX_DIMS; k++) { + ne[k] = tensor->ne[k]; + nb[k] = tensor->nb[k]; + } + + std::vector simple_tensors; + simple_tensors.reserve(n_simple_bufs); + for (size_t j = 0; j < n_simple_bufs; j++) { + ggml_context * simple_ctx = buf_ctx->buf_configs[j].ctx; + ggml_backend_buffer_t simple_buf = buf_ctx->buf_configs[j].buf; + + if (split_dim >= 0 && split_dim < GGML_MAX_DIMS) { + // TODO: the following assert fails for llama-parallel even though the results are correct: + // GGML_ASSERT(ggml_is_contiguously_allocated(tensor)); + ne[split_dim] = 0; + for (size_t s = 0; s < split_state.n_segments; s++) { + ne[split_dim] += split_state.ne[s*n_simple_bufs + j]; + } + for (int i = 0; i < GGML_MAX_DIMS; i++) { + if (tensor->nb[i] > tensor->nb[split_dim]) { + nb[i] = tensor->nb[i] * ne[split_dim]/tensor->ne[split_dim]; + } + } + } + + ggml_tensor * t_ij = ggml_new_tensor(simple_ctx, tensor->type, GGML_MAX_DIMS, ne); + t_ij->op = tensor->op; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + t_ij->nb[i] = nb[i]; + } + t_ij->flags = tensor->flags; + memcpy(t_ij->op_params, tensor->op_params, sizeof(tensor->op_params)); + ggml_set_name(t_ij, tensor->name); + t_ij->buffer = simple_buf; + t_ij->view_src = tensor->view_src; + t_ij->view_offs = tensor->view_offs; + if (t_ij->view_src != nullptr && ggml_backend_buffer_is_meta(t_ij->view_src->buffer)) { + t_ij->view_src = ggml_backend_meta_buffer_simple_tensor(tensor->view_src, j); + if (t_ij->view_offs > 0 && split_dim >= 0 && split_dim < GGML_MAX_DIMS) { + GGML_ASSERT(tensor->ne[split_dim] != 0); + const int split_dim_view_src = ggml_backend_meta_get_split_state(tensor->view_src, /*assume_sync =*/ true).axis; + GGML_ASSERT(split_dim_view_src >= 0 && split_dim_view_src < GGML_MAX_DIMS); + + // The offset can be internal to the data split, in those cases the view offset should not be scaled. + // If however, the offset is larger than the data split then it needs to be scaled proportionally. + bool split_internal_offset = t_ij->view_offs <= tensor->view_src->nb[split_dim_view_src]; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + const size_t dim_size = tensor->ne[i] * tensor->nb[i]; + if (tensor->view_offs <= dim_size && dim_size < tensor->nb[split_dim]) { + split_internal_offset = true; + break; + } + } + if (!split_internal_offset) { + t_ij->view_offs = t_ij->view_offs * ne[split_dim]/tensor->ne[split_dim]; + } + } + } + if (t_ij->view_src != nullptr) { + t_ij->data = (char *) t_ij->view_src->data + t_ij->view_offs; + } else if (simple_buf != nullptr) { + t_ij->data = (char *) ggml_backend_buffer_get_base(simple_buf) + + size_t(tensor->data) - size_t(ggml_backend_buffer_get_base(buffer)); + } + t_ij->extra = tensor->extra; + for (int i = 0; i < GGML_MAX_SRC; i++) { + t_ij->src[i] = tensor->src[i]; + if (tensor->src[i] == tensor) { + t_ij->src[i] = t_ij; + } else if (t_ij->src[i] != nullptr && ggml_backend_buffer_is_meta(t_ij->src[i]->buffer)) { + t_ij->src[i] = ggml_backend_meta_buffer_simple_tensor(tensor->src[i], j); + } + } + + simple_tensors.push_back(t_ij); + } + + // If one of the sources has a zero-sized slice, disable the computation: + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] == nullptr || !ggml_backend_buffer_is_meta(tensor->src[i]->buffer)) { + continue; + } + + const ggml_backend_meta_split_state split_state_src = ggml_backend_meta_get_split_state(tensor->src[i], /*assume_sync =*/ true); + if (split_state_src.axis < 0 || split_state_src.axis >= GGML_MAX_DIMS) { + continue; + } + for (size_t j = 0; j < n_simple_bufs; j++) { + int64_t ne_sum = 0; + for (size_t s = 0; s < split_state_src.n_segments; s++) { + ne_sum += split_state_src.ne[s*n_simple_bufs + j]; + } + if (ne_sum == 0) { + simple_tensors[j]->flags &= ~GGML_TENSOR_FLAG_COMPUTE; + } + } + } + + buf_ctx->simple_tensors[tensor] = simple_tensors; + + return GGML_STATUS_SUCCESS; +} + +static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(buffer); + GGML_ASSERT(ggml_is_contiguous(tensor)); + + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); + + if (split_state.n_segments != 1) { + GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS); + GGML_ASSERT(tensor->ne[3] == 1); + + size_t offset_data = 0; + std::vector simple_offsets(n_bufs, 0); + if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) { + GGML_ASSERT(tensor->ne[2] == 1); + + const size_t row_stride = tensor->nb[1]; + GGML_ASSERT(offset % row_stride == 0); + GGML_ASSERT(size % row_stride == 0); + const int64_t r_start = offset / row_stride; + const int64_t r_count = size / row_stride; + GGML_ASSERT(r_start + r_count <= tensor->ne[1]); + + const int64_t blck_size = ggml_blck_size(tensor->type); + for (size_t s = 0; s < split_state.n_segments; s++) { + for (size_t j = 0; j < n_bufs; j++) { + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); + const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; + ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, + simple_offsets[j] + r_start * simple_tensor->nb[1], nbytes, + r_count, simple_tensor->nb[1], tensor->nb[1]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } + } + GGML_ASSERT(offset_data*r_count == size); + return; + } + GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1); + + const size_t row_stride = tensor->nb[2]; + GGML_ASSERT(offset % row_stride == 0); + GGML_ASSERT(size % row_stride == 0); + const int64_t r_start = offset / row_stride; + const int64_t r_count = size / row_stride; + GGML_ASSERT(r_start + r_count <= tensor->ne[2]); + + for (size_t s = 0; s < split_state.n_segments; s++) { + for (size_t j = 0; j < n_bufs; j++) { + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; + ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, + simple_offsets[j] + r_start * simple_tensor->nb[2], nbytes, + r_count, simple_tensor->nb[2], tensor->nb[2]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } + } + GGML_ASSERT(offset_data*r_count == size); + return; + } + + switch (split_state.axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: { + // Exploit that tensors are contiguous to splice it with simple tensors as "chunks". + const size_t chunk_size_full = tensor->nb[split_state.axis + 1]; + GGML_ASSERT(offset % chunk_size_full == 0); + GGML_ASSERT(size % chunk_size_full == 0); + const int64_t i_start = offset /chunk_size_full; + const int64_t i_stop = (offset + size)/chunk_size_full; + size_t offset_j = 0; + for (size_t j = 0; j < n_bufs; j++) { + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + const size_t simple_offset = i_start * chunk_size_j; + ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_j, simple_offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); + offset_j += chunk_size_j; + } + GGML_ASSERT(offset_j == chunk_size_full); + } break; + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: { + for (size_t j = 0; j < n_bufs; j++) { + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + ggml_backend_tensor_set(simple_tensor, data, offset, size); + } + } break; + case GGML_BACKEND_SPLIT_AXIS_PARTIAL: { + GGML_ASSERT(tensor->type == GGML_TYPE_F32); + const int64_t ne = ggml_nelements(tensor); + std::vector tmp; + tmp.reserve(ne); + for (int64_t i = 0; i < ne; i++) { + tmp.push_back(((const float *) data)[i] / n_bufs); + } + for (size_t j = 0; j < n_bufs; j++) { + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + ggml_backend_tensor_set(simple_tensor, tmp.data(), offset, size); + } + } break; + default: { + GGML_ABORT("fatal error"); + } + } +} + +static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(buffer); + GGML_ASSERT(ggml_is_contiguous(tensor)); + + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); + + if (split_state.n_segments != 1) { + GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS); + GGML_ASSERT(tensor->ne[3] == 1); + + size_t offset_data = 0; + std::vector simple_offsets(n_bufs, 0); + if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) { + GGML_ASSERT(tensor->ne[2] == 1); + + const size_t row_stride = tensor->nb[1]; + GGML_ASSERT(offset % row_stride == 0); + GGML_ASSERT(size % row_stride == 0); + const int64_t r_start = offset / row_stride; + const int64_t r_count = size / row_stride; + GGML_ASSERT(r_start + r_count <= tensor->ne[1]); + + const int64_t blck_size = ggml_blck_size(tensor->type); + for (size_t s = 0; s < split_state.n_segments; s++) { + for (size_t j = 0; j < n_bufs; j++) { + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); + const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; + ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, + simple_offsets[j] + r_start * simple_tensor->nb[1], nbytes, + r_count, simple_tensor->nb[1], tensor->nb[1]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } + } + GGML_ASSERT(offset_data*r_count == size); + return; + } + GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1); + + const size_t row_stride = tensor->nb[2]; + GGML_ASSERT(offset % row_stride == 0); + GGML_ASSERT(size % row_stride == 0); + const int64_t r_start = offset / row_stride; + const int64_t r_count = size / row_stride; + GGML_ASSERT(r_start + r_count <= tensor->ne[2]); + + for (size_t s = 0; s < split_state.n_segments; s++) { + for (size_t j = 0; j < n_bufs; j++) { + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; + ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, + simple_offsets[j] + r_start * simple_tensor->nb[2], nbytes, + r_count, simple_tensor->nb[2], tensor->nb[2]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } + } + GGML_ASSERT(offset_data*r_count == size); + return; + } + + switch (split_state.axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: { + // Exploit that tensors are contiguous to splice it with simple tensors as "chunks". + const size_t chunk_size_full = tensor->nb[split_state.axis + 1]; + GGML_ASSERT(offset % chunk_size_full == 0); + GGML_ASSERT(size % chunk_size_full == 0); + const int64_t i_start = offset /chunk_size_full; + const int64_t i_stop = (offset + size)/chunk_size_full; + size_t offset_j = 0; + for (size_t j = 0; j < n_bufs; j++){ + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + const size_t simple_offset = i_start * chunk_size_j; + ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_j, simple_offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); + offset_j += chunk_size_j; + } + GGML_ASSERT(offset_j == chunk_size_full); + } break; + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: { + // TODO other simple backend may be better + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, 0); + ggml_backend_tensor_get(simple_tensor, data, offset, size); + } break; + default: { + GGML_ABORT("fatal error"); + } + } +} + +static void ggml_backend_meta_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + const size_t n_buffers = ggml_backend_meta_buffer_n_bufs(buffer); + for (size_t i = 0; i < n_buffers; i++) { + ggml_backend_buffer_clear(ggml_backend_meta_buffer_simple_buffer(buffer, i), value); + } +} + +static void ggml_backend_meta_buffer_reset(ggml_backend_buffer_t buffer) { + const size_t n_buffers = ggml_backend_meta_buffer_n_bufs(buffer); + for (size_t i = 0; i < n_buffers; i++) { + ggml_backend_buffer_reset(ggml_backend_meta_buffer_simple_buffer(buffer, i)); + } +} + +static const ggml_backend_buffer_i ggml_backend_meta_buffer_iface = { + /* .free_buffer = */ ggml_backend_meta_buffer_free_buffer, + /* .get_base = */ ggml_backend_meta_buffer_get_base, + /* .init_tensor = */ ggml_backend_meta_buffer_init_tensor, + /* .memset_tensor = */ nullptr, // TODO implement + /* .set_tensor = */ ggml_backend_meta_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_meta_buffer_get_tensor, + /* .set_tensor_2d = */ nullptr, + /* .get_tensor_2d = */ nullptr, + /* .cpy_tensor = */ nullptr, + /* .clear = */ ggml_backend_meta_buffer_clear, + /* .reset = */ ggml_backend_meta_buffer_reset, +}; + +bool ggml_backend_buffer_is_meta(ggml_backend_buffer_t buf) { + return buf != nullptr && buf->iface.free_buffer == ggml_backend_meta_buffer_iface.free_buffer; +} + +static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + + ggml_init_params params = { + /*.mem_size =*/ 1024*1024*1024, // FIXME + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + ggml_backend_meta_buffer_context * buf_ctx = new ggml_backend_meta_buffer_context(); + size_t max_size = 0; + buf_ctx->buf_configs.reserve(n_simple_bufts); + for (size_t i = 0; i < n_simple_bufts; i++) { + ggml_backend_buffer_t simple_buf = ggml_backend_buft_alloc_buffer(ggml_backend_meta_buft_simple_buft(buft, i), size); + max_size = std::max(max_size, ggml_backend_buffer_get_size(simple_buf)); + buf_ctx->buf_configs.emplace_back(ggml_init(params), simple_buf); + } + + return ggml_backend_buffer_init(buft, ggml_backend_meta_buffer_iface, buf_ctx, max_size); +} + +struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + + ggml_init_params params = { + /*.mem_size =*/ 1024*1024*1024, // FIXME + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + ggml_backend_meta_buffer_context * meta_buf_ctx = new ggml_backend_meta_buffer_context(); + meta_buf_ctx->buf_configs.reserve(n_simple_bufts); + for (size_t i = 0; i < n_simple_bufts; i++) { + meta_buf_ctx->buf_configs.emplace_back(ggml_init(params), nullptr); + } + + ggml_backend_buffer_t meta_buf = ggml_backend_buffer_init(buft, ggml_backend_meta_buffer_iface, meta_buf_ctx, 0); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { + t->buffer = meta_buf; + ggml_backend_meta_buffer_init_tensor(meta_buf, t); + t->data = (void *) 0x2000000000000000; // FIXME + } + for (size_t i = 0; i < n_simple_bufts; i++) { + meta_buf_ctx->buf_configs[i].buf = ggml_backend_alloc_ctx_tensors_from_buft( + meta_buf_ctx->buf_configs[i].ctx, ggml_backend_meta_buft_simple_buft(buft, i)); + meta_buf->size = std::max(meta_buf->size, ggml_backend_buffer_get_size(meta_buf_ctx->buf_configs[i].buf)); + } + return meta_buf; +} + +// +// meta backend +// + +static ggml_guid_t ggml_backend_meta_guid() { + static ggml_guid guid = {0xf1, 0x0e, 0x34, 0xcf, 0x9c, 0x6f, 0x43, 0xcb, 0x96, 0x92, 0xbe, 0x8e, 0xbb, 0x71, 0x3f, 0xda}; + return &guid; +} + +struct ggml_backend_meta_context { + struct cgraph_config { + ggml_cgraph * cgraph_main = nullptr; + int offset = 0; // Node offset vs. original graph + + std::vector cgraphs_aux; + }; + struct backend_config { + ggml_backend_t backend; + + std::vector cgraphs; + std::vector nodes; + std::vector bufs; + + backend_config(ggml_backend_t backend, const size_t n_reduce_steps) : backend(backend) { + bufs.resize(n_reduce_steps); + } + }; + std::string name; + std::vector backend_configs; + ggml_context_ptr ctx; + std::vector cgraphs_aux; + std::vector nodes_aux; + size_t n_reduce_steps; + int max_nnodes = 0; + size_t max_tmp_size = 0; + size_t max_subgraphs = 0; + size_t n_subgraphs = 0; + uint64_t uid = 0; + + void * comm_ctx = nullptr; + ggml_backend_comm_allreduce_tensor_t comm_allreduce = nullptr; + + ggml_backend_meta_context(ggml_backend_dev_t meta_dev, const char * params) { + const size_t n_devs = ggml_backend_meta_dev_n_devs(meta_dev); + n_reduce_steps = std::ceil(std::log2(n_devs)); + name = "Meta("; + std::vector simple_backends; + backend_configs.reserve(n_devs); + simple_backends.reserve(n_devs); + for (size_t i = 0; i < n_devs; i++) { + ggml_backend_dev_t simple_dev = ggml_backend_meta_dev_simple_dev(meta_dev, i); + if (i > 0) { + name += ","; + } + name += ggml_backend_dev_name(simple_dev); + simple_backends.push_back(ggml_backend_dev_init(simple_dev, params)); + backend_configs.emplace_back(simple_backends.back(), n_reduce_steps); + } + name += ")"; + + if (n_devs > 1) { + ggml_backend_comm_init_t comm_init = (ggml_backend_comm_init_t) ggml_backend_reg_get_proc_address( + ggml_backend_dev_backend_reg(ggml_backend_get_device(simple_backends[0])), "ggml_backend_comm_init"); + if (comm_init != nullptr) { + comm_ctx = comm_init(simple_backends.data(), simple_backends.size()); + } + } + if (comm_ctx != nullptr) { + comm_allreduce = (ggml_backend_comm_allreduce_tensor_t) + ggml_backend_reg_get_proc_address(ggml_backend_dev_backend_reg( + ggml_backend_get_device(simple_backends[0])), "ggml_backend_comm_allreduce_tensor"); + GGML_ASSERT(comm_allreduce != nullptr); + } + } + + ~ggml_backend_meta_context() { + if (comm_ctx != nullptr) { + ggml_backend_comm_free_t comm_free = (ggml_backend_comm_free_t) ggml_backend_reg_get_proc_address( + ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_configs[0].backend)), "ggml_backend_comm_free"); + GGML_ASSERT(comm_free != nullptr); + comm_free(comm_ctx); + } + for (auto & bc : backend_configs) { + ggml_backend_free(bc.backend); + } + } +}; + +static const char * ggml_backend_meta_get_name(ggml_backend_t backend) { + GGML_ASSERT(ggml_backend_is_meta(backend)); + const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) backend->context; + return backend_ctx->name.c_str(); +} + +static void ggml_backend_meta_free(ggml_backend_t backend) { + GGML_ASSERT(ggml_backend_is_meta(backend)); + ggml_backend_meta_context * backend_ctx = (ggml_backend_meta_context *) backend->context; + delete backend_ctx; + delete backend; +} + +static void ggml_backend_meta_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + const size_t n_backends = ggml_backend_meta_n_backends(backend); + GGML_ASSERT(offset == 0); + GGML_ASSERT(ggml_is_contiguous(tensor)); + + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); + GGML_ASSERT(split_state.n_segments == 1); + + switch (split_state.axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: { + // Exploit that tensors are contiguous to splice it with simple tensors as "chunks". + const size_t chunk_size_full = tensor->nb[split_state.axis + 1]; + GGML_ASSERT(offset % chunk_size_full == 0); + GGML_ASSERT(size % chunk_size_full == 0); + const int64_t i_start = offset /chunk_size_full; + const int64_t i_stop = (offset + size)/chunk_size_full; + size_t offset_j = 0; + for (size_t j = 0; j < n_backends; j++){ + ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j); + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + ggml_backend_tensor_set_2d_async(simple_backend, simple_tensor, (const char *) data + offset_j, offset, chunk_size_j, + i_stop - i_start, chunk_size_j, chunk_size_full); + offset_j += chunk_size_j; + } + GGML_ASSERT(offset_j == chunk_size_full); + } break; + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: { + for (size_t j = 0; j < n_backends; j++) { + ggml_backend_tensor_set_async( + ggml_backend_meta_simple_backend(backend, j), ggml_backend_meta_buffer_simple_tensor(tensor, j), data, offset, size); + } + } break; + default: { + GGML_ABORT("fatal error"); + } + } +} + +static void ggml_backend_meta_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + const size_t n_backends = ggml_backend_meta_n_backends(backend); + GGML_ASSERT(offset == 0); + GGML_ASSERT(ggml_is_contiguous(tensor)); + + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); + GGML_ASSERT(split_state.n_segments == 1); + + switch (split_state.axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: { + // Exploit that tensors are contiguous to splice it with simple tensors as "chunks". + const size_t chunk_size_full = tensor->nb[split_state.axis + 1]; + GGML_ASSERT(offset % chunk_size_full == 0); + GGML_ASSERT(size % chunk_size_full == 0); + const int64_t i_start = offset /chunk_size_full; + const int64_t i_stop = (offset + size)/chunk_size_full; + size_t offset_j = 0; + for (size_t j = 0; j < n_backends; j++){ + ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j); + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + ggml_backend_tensor_get_2d_async(simple_backend, simple_tensor, (char *) data + offset_j, offset, chunk_size_j, + i_stop - i_start, chunk_size_j, chunk_size_full); + offset_j += chunk_size_j; + } + GGML_ASSERT(offset_j == chunk_size_full); + } break; + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: { + // TODO other simple backend may be better + ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, 0); + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, 0); + ggml_backend_tensor_get_async(simple_backend, simple_tensor, data, offset, size); + } break; + default: { + GGML_ABORT("fatal error"); + } + } +} + +static void ggml_backend_meta_synchronize(ggml_backend_t backend) { + const size_t n_backends = ggml_backend_meta_n_backends(backend); + for (size_t i = 0; i < n_backends; i++) { + ggml_backend_synchronize(ggml_backend_meta_simple_backend(backend, i)); + } +} + +static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + GGML_ASSERT(cgraph->grads == nullptr); + const size_t n_backends = ggml_backend_meta_n_backends(backend); + ggml_backend_meta_context * backend_ctx = (ggml_backend_meta_context *) backend->context; + + // If the previous cgraph had a defined UID it can be used to skip rebuilding the subgraphs per simple backend. + const bool needs_rebuild = (cgraph->uid == 0) || (cgraph->uid != backend_ctx->uid); + + bool max_nnodes_raised = false; + if (cgraph->n_nodes > backend_ctx->max_nnodes) { + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + bcj.nodes.resize(cgraph->n_nodes); + bcj.cgraphs.resize(cgraph->n_nodes); + } + backend_ctx->max_nnodes = cgraph->n_nodes; + max_nnodes_raised = true; + assert(needs_rebuild); + } + + if (needs_rebuild) { + size_t n_subgraphs = 0; + size_t max_tmp_size = 0; + + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) { + // FIXME s_copy_main is on the CPU and its view seems to be incorrectly added to the graph nodes. + // For regular usage this doesn't matter since it's a noop but trying to call ggml_backend_meta_buffer_simple_tensor results in a crash. + bcj.nodes[i] = node; + continue; + } + bcj.nodes[i] = ggml_backend_meta_buffer_simple_tensor(node, j); + GGML_ASSERT(bcj.nodes[i]); + } + } + + { + // For MoE models it may make sense to delay the AllReduce in order to reduce I/O: + auto get_i_delayed = [&](const int i) -> int { + int id = i; // i_delayed + int idr = i; // i_delayed return, last safe return value + + ggml_tensor * node = cgraph->nodes[id]; + int32_t n_used = ggml_node_get_use_count(cgraph, id); + + // Skip MIRRORED nodes that don't consume node + auto skip_unrelated = [&]() { + while (id + 1 < cgraph->n_nodes) { + ggml_tensor * next = cgraph->nodes[id+1]; + if (ggml_backend_meta_get_split_state(next, false).axis != GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + break; + } + bool safe = true; + for (int s = 0; s < GGML_MAX_SRC; s++) { + if (next->src[s] == nullptr) { + continue; + } + if (next->src[s] == node) { + safe = false; + break; + } + if (ggml_backend_meta_get_split_state(next->src[s], false).axis != GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + safe = false; + break; + } + } + if (!safe) { + break; + } + id++; + } + }; + + skip_unrelated(); + if (id + 1 >= cgraph->n_nodes) { + return idr; + } + { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op == GGML_OP_ADD_ID && next->src[0] == node && + ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL && + ggml_backend_meta_get_split_state(next->src[2], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + node = next; + id++; + idr = id; + n_used = ggml_node_get_use_count(cgraph, id); + } + } + // Chain of MULs with MIRRORED src[1] + while (true) { + skip_unrelated(); + if (id + 1 >= cgraph->n_nodes) { + return idr; + } + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op == GGML_OP_MUL && next->src[0] == node && + ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + node = next; + id++; + idr = id; + n_used = ggml_node_get_use_count(cgraph, id); + } else { + break; + } + } + + if (n_used != node->ne[1] || id + 2*n_used-1 >= cgraph->n_nodes) { + return idr; + } + for (int32_t k = 0; k < n_used; k++) { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op != GGML_OP_VIEW || next->view_src != node || next->view_offs != k*node->nb[1] || + next->ne[0] != node->ne[0] || next->ne[1] != node->ne[2] || next->nb[1] != node->nb[2] || + ggml_node_get_use_count(cgraph, id+1) != 1) { + return idr; + } + id++; + } + { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id - (n_used-1)] || + next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) { + return idr; + } + id++; + } + for (int32_t k = 0; k < n_used - 2; k++) { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id] || + next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) { + return idr; + } + id++; + } + idr = id; + return idr; + }; + + int i_start = 0; + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) { + continue; + } + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(node, /*assume_sync =*/ false); + if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) { + max_tmp_size = std::max(max_tmp_size, ggml_nbytes(node)); + } + const bool new_subgraph = i + 1 == cgraph->n_nodes || split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL; + if (!new_subgraph) { + continue; + } + + const int i_delayed = get_i_delayed(i); + + // If we can delay the AllReduce we need to consider the interaction with zero-sized tensor slices. + // A backend with such a slice would normally have valid data after participating in the AllReduce with a node that has + // its compute flag disabled and thus gets its data zeroed out. + // If the AllReduce is delayed then the nodes until that point also need to have their compute flag disabled. + if (i_delayed > i) { + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + if ((bcj.nodes[i]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + for (int ii = i + 1; ii <= i_delayed; ii++) { + bcj.nodes[ii]->flags &= ~GGML_TENSOR_FLAG_COMPUTE; + } + } + } + } + + i = i_delayed; + + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + bcj.cgraphs[n_subgraphs].offset = i_start; + } + n_subgraphs++; + i_start = i + 1; + } + GGML_ASSERT(i_start == cgraph->n_nodes); + } + + backend_ctx->uid = cgraph->uid; + backend_ctx->n_subgraphs = n_subgraphs; + + if (max_tmp_size > backend_ctx->max_tmp_size) { + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + for (size_t i = 0; i < backend_ctx->n_reduce_steps; i++) { + bcj.bufs[i].reset(ggml_backend_alloc_buffer(bcj.backend, max_tmp_size)); + } + } + backend_ctx->max_tmp_size = max_tmp_size; + } + + if (max_nnodes_raised || n_subgraphs > backend_ctx->max_subgraphs) { + backend_ctx->max_subgraphs = std::max(backend_ctx->max_subgraphs, n_subgraphs); + const size_t n_nodes_per_device = 3 * backend_ctx->n_reduce_steps; // tmp + ADD (+zeroing) graph per step and device + const size_t n_cgraphs_per_device = 2 * backend_ctx->n_reduce_steps; // ADD ( + zeroing) graph per step and device + const size_t mem_per_device_graphs_main = backend_ctx->max_subgraphs*ggml_graph_overhead_custom(backend_ctx->max_nnodes, cgraph->grads); + const size_t mem_per_device_graphs_aux = n_cgraphs_per_device*backend_ctx->max_subgraphs*ggml_graph_overhead_custom(1, cgraph->grads); + const size_t mem_per_device_nodes_aux = n_nodes_per_device*backend_ctx->max_subgraphs*ggml_tensor_overhead(); + ggml_init_params params = { + /*.mem_size =*/ n_backends * (mem_per_device_graphs_main + mem_per_device_graphs_aux + mem_per_device_nodes_aux), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + backend_ctx->ctx.reset(ggml_init(params)); + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + for (size_t i = 0; i < n_subgraphs; i++) { + bcj.cgraphs[i].cgraph_main = ggml_new_graph_custom(backend_ctx->ctx.get(), cgraph->n_nodes, /*grads =*/ false); + } + } + backend_ctx->cgraphs_aux.resize(n_backends*n_cgraphs_per_device*backend_ctx->max_subgraphs); + for (size_t k = 0; k < backend_ctx->cgraphs_aux.size(); k++) { + backend_ctx->cgraphs_aux[k] = ggml_new_graph_custom(backend_ctx->ctx.get(), 1, cgraph->grads); + } + backend_ctx->nodes_aux.resize(n_backends*n_nodes_per_device*backend_ctx->max_subgraphs); + for (size_t k = 0; k < backend_ctx->nodes_aux.size(); k++) { + backend_ctx->nodes_aux[k] = ggml_new_tensor_1d(backend_ctx->ctx.get(), GGML_TYPE_F32, 1); + } + } + + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + for (size_t i_graph = 0; i_graph < n_subgraphs; i_graph++) { + ggml_cgraph * cgraph_ij = bcj.cgraphs[i_graph].cgraph_main; + const size_t i_node_start = bcj.cgraphs[i_graph].offset; + const size_t i_node_stop = i_graph + 1 < n_subgraphs ? bcj.cgraphs[i_graph + 1].offset : cgraph->n_nodes; + cgraph_ij->n_nodes = i_node_stop - i_node_start; + ggml_hash_set_reset(&cgraph_ij->visited_hash_set); + for (size_t i_node = i_node_start; i_node < i_node_stop; i_node++) { + ggml_tensor * node_ij = bcj.nodes[i_node]; + cgraph_ij->nodes[i_node - i_node_start] = node_ij; + const size_t hash_pos_orig = ggml_hash_find(&cgraph->visited_hash_set, cgraph->nodes[i_node]); + const size_t hash_pos_ij = ggml_hash_insert(&cgraph_ij->visited_hash_set, node_ij); + cgraph_ij->use_counts[hash_pos_ij] = cgraph->use_counts[hash_pos_orig]; + } + cgraph_ij->uid = ggml_graph_next_uid(); + } + } + } + + size_t iga = 0; // i graph aux + size_t ina = 0; // i node aux + + auto get_node_aux = [&](ggml_tensor * t) -> ggml_tensor * { + ggml_tensor * ret = backend_ctx->nodes_aux[ina++]; + memset(ret, 0, sizeof(ggml_tensor)); + ret->op = GGML_OP_NONE; + ret->type = t->type; + for (size_t k = 0; k < GGML_MAX_DIMS; k++) { + ret->ne[k] = t->ne[k]; + ret->nb[k] = t->nb[k]; + } + return ret; + }; + auto set_tmp_data = [&](ggml_tensor * tensor, const size_t j, const size_t i_buf) { + auto & bcj = backend_ctx->backend_configs[j]; + ggml_backend_buffer_ptr & buf_ptr = bcj.bufs[i_buf]; + if (!buf_ptr || ggml_backend_buffer_get_size(buf_ptr.get()) < backend_ctx->max_tmp_size) { + buf_ptr.reset(ggml_backend_alloc_buffer(bcj.backend, backend_ctx->max_tmp_size)); + } + tensor->buffer = buf_ptr.get(); + tensor->data = ggml_backend_buffer_get_base(buf_ptr.get()); + }; + // FIXME usage_counts + auto get_cgraph_aux = [&]() -> ggml_cgraph * { + ggml_cgraph * ret = backend_ctx->cgraphs_aux[iga++]; + return ret; + }; + + // Preferentially use backend-specific allreduce_tensor_async (e.g. NCCL for CUDA), use a generic fallback if unavailable: + auto allreduce_fallback = [&](size_t i) -> ggml_status { + std::vector step_cgraphs(n_backends, nullptr); + + // Zero out nodes that were disabled due to having a zero-sized slice: + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + ggml_tensor * node = bcj.cgraphs[i].cgraph_main->nodes[bcj.cgraphs[i].cgraph_main->n_nodes - 1]; + if (node->flags & GGML_TENSOR_FLAG_COMPUTE) { + continue; + } + ggml_tensor * node_zero = get_node_aux(node); + node_zero->op = GGML_OP_SCALE; // FIXME 0.0f * NaN == NaN + node_zero->src[0] = node; + ggml_set_op_params_f32(node_zero, 0, 0.0f); + node_zero->data = node->data; + node_zero->flags |= GGML_TENSOR_FLAG_COMPUTE; + + step_cgraphs[j] = get_cgraph_aux(); + step_cgraphs[j]->nodes[0] = node_zero; + step_cgraphs[j]->n_nodes = 1; + const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, step_cgraphs[j]); + if (status != GGML_STATUS_SUCCESS) { + return status; + } + } + std::fill(step_cgraphs.begin(), step_cgraphs.end(), nullptr); + + auto push_data = [&](const size_t j_src, const size_t j_dst, const size_t i_buf) { + assert(step_cgraphs[j_dst] == nullptr); + auto & bcj_src = backend_ctx->backend_configs[j_src]; + auto & bcj_dst = backend_ctx->backend_configs[j_dst]; + + ggml_tensor * node_src = bcj_src.cgraphs[i].cgraph_main->nodes[bcj_src.cgraphs[i].cgraph_main->n_nodes - 1]; + ggml_tensor * node_dst = bcj_dst.cgraphs[i].cgraph_main->nodes[bcj_dst.cgraphs[i].cgraph_main->n_nodes - 1]; + GGML_ASSERT(ggml_is_contiguous(node_src)); + GGML_ASSERT(ggml_is_contiguous(node_dst)); + + ggml_tensor * node_tmp = get_node_aux(node_dst); + set_tmp_data(node_tmp, j_dst, i_buf); + + ggml_backend_tensor_copy_async(bcj_src.backend, bcj_dst.backend, node_src, node_tmp); + + ggml_tensor * node_red = get_node_aux(node_dst); + node_red->view_src = node_dst->view_src == nullptr ? node_dst : node_dst->view_src; + node_red->view_offs = node_dst->view_offs; + node_red->op = GGML_OP_ADD; + node_red->src[0] = node_dst; + node_red->src[1] = node_tmp; + node_red->flags |= GGML_TENSOR_FLAG_COMPUTE; + ggml_backend_view_init(node_red); + + ggml_cgraph * cgraph_aux = get_cgraph_aux(); + cgraph_aux->nodes[0] = node_red; + cgraph_aux->n_nodes = 1; + step_cgraphs[j_dst] = cgraph_aux; + }; + + size_t offset_j = n_backends/2; + while ((offset_j & (offset_j - 1)) != 0) { + offset_j--; + } + const size_t offset_j_max = offset_j; + size_t i_buf = 0; + + // If n_backends is not a power of 2, fold in the excess prior to butterfly reduction: + for (size_t j_src = 2*offset_j_max; j_src < n_backends; j_src++) { + const size_t j_dst = j_src - 2*offset_j_max; + push_data(j_src, j_dst, i_buf); + const ggml_status status = ggml_backend_graph_compute_async(backend_ctx->backend_configs[j_dst].backend, step_cgraphs[j_dst]); + if (status != GGML_STATUS_SUCCESS) { + return status; + } + i_buf = 1; + } + + // Butterfly reduction: + for (; offset_j >= 1; offset_j /= 2) { + std::fill(step_cgraphs.begin(), step_cgraphs.end(), nullptr); + + for (size_t j = 0; j < 2*offset_j_max; j++) { + const size_t j_other = j ^ offset_j; + if (j_other >= n_backends) { + continue; + } + push_data(j, j_other, i_buf); + } + + for (size_t j = 0; j < 2*offset_j_max; j++) { + if (step_cgraphs[j] == nullptr) { + continue; + } + auto & bcj = backend_ctx->backend_configs[j]; + const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, step_cgraphs[j]); + if (status != GGML_STATUS_SUCCESS) { + return status; + } + } + i_buf++; + } + assert(i_buf == backend_ctx->n_reduce_steps); + + // If n_backends is not a power of 2, copy back the reduced tensors to the excess: + for (size_t j = 2*offset_j_max; j < n_backends; j++) { + auto & bcj_src = backend_ctx->backend_configs[j - 2*offset_j_max]; + auto & bcj_dst = backend_ctx->backend_configs[j]; + + ggml_tensor * node_src = bcj_src.cgraphs[i].cgraph_main->nodes[bcj_src.cgraphs[i].cgraph_main->n_nodes - 1]; + ggml_tensor * node_dst = bcj_dst.cgraphs[i].cgraph_main->nodes[bcj_dst.cgraphs[i].cgraph_main->n_nodes - 1]; + ggml_backend_tensor_copy_async(bcj_src.backend, bcj_dst.backend, node_src, node_dst); + } + + return GGML_STATUS_SUCCESS; + }; + + + for (size_t i = 0; i < backend_ctx->n_subgraphs; i++) { + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, bcj.cgraphs[i].cgraph_main); + if (status != GGML_STATUS_SUCCESS) { + return status; + } + } + + if (n_backends > 1 && i < backend_ctx->n_subgraphs - 1) { + bool backend_allreduce_success = false; + if (backend_ctx->comm_ctx) { + std::vector nodes; + nodes.reserve(n_backends); + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + ggml_cgraph * cgraph_ij = bcj.cgraphs[i].cgraph_main; + nodes.push_back(cgraph_ij->nodes[cgraph_ij->n_nodes-1]); + } + backend_allreduce_success = backend_ctx->comm_allreduce(backend_ctx->comm_ctx, nodes.data()); + } + + if (!backend_allreduce_success) { + const ggml_status status = allreduce_fallback(i); + if (status != GGML_STATUS_SUCCESS) { + return status; + } + } + } + } + return GGML_STATUS_SUCCESS; +} + +static const ggml_backend_i ggml_backend_meta_i = { + /* .get_name = */ ggml_backend_meta_get_name, + /* .free = */ ggml_backend_meta_free, + /* .set_tensor_async = */ ggml_backend_meta_set_tensor_async, + /* .get_tensor_async = */ ggml_backend_meta_get_tensor_async, + /* .set_tensor_2d_async = */ nullptr, + /* .get_tensor_2d_async = */ nullptr, + /* .cpy_tensor_async = */ nullptr, + /* .synchronize = */ ggml_backend_meta_synchronize, + /* .graph_plan_create = */ nullptr, + /* .graph_plan_free = */ nullptr, + /* .graph_plan_update = */ nullptr, + /* .graph_plan_compute = */ nullptr, + /* .graph_compute = */ ggml_backend_meta_graph_compute, + /* .event_record = */ nullptr, + /* .event_wait = */ nullptr, + /* .graph_optimize = */ nullptr, +}; + +bool ggml_backend_is_meta(ggml_backend_t backend) { + return backend != nullptr && backend->iface.get_name == ggml_backend_meta_i.get_name; +} + +static ggml_backend_t ggml_backend_meta_device_init_backend(ggml_backend_dev_t dev, const char * params) { + ggml_backend_meta_context * backend_ctx = new ggml_backend_meta_context(dev, params); + + ggml_backend_t backend = new struct ggml_backend; + backend->guid = ggml_backend_meta_guid(); + backend->iface = ggml_backend_meta_i; + backend->device = dev; + backend->context = backend_ctx; + return backend; +} + +size_t ggml_backend_meta_n_backends(ggml_backend_t meta_backend) { + GGML_ASSERT(ggml_backend_is_meta(meta_backend)); + const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context; + return backend_ctx->backend_configs.size(); +} + +ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, size_t index) { + GGML_ASSERT(ggml_backend_is_meta(meta_backend)); + const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context; + return backend_ctx->backend_configs[index].backend; +} + diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index 0587109212e..f3c52baff82 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -181,6 +181,12 @@ struct ggml_backend_registry { return; } + for (auto & entry : backends) { + if (entry.reg == reg) { + return; + } + } + #ifndef NDEBUG GGML_LOG_DEBUG("%s: registered backend %s (%zu devices)\n", __func__, ggml_backend_reg_name(reg), ggml_backend_reg_dev_count(reg)); @@ -192,6 +198,12 @@ struct ggml_backend_registry { } void register_device(ggml_backend_dev_t device) { + for (auto & dev : devices) { + if (dev == device) { + return; + } + } + #ifndef NDEBUG GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, ggml_backend_dev_name(device), ggml_backend_dev_description(device)); #endif @@ -534,6 +546,97 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, } } } +#ifdef __ANDROID__ + // QVAC-18993: Android app packaging keeps native libraries + // compressed inside the APK with no on-disk directory to scan + // (the default since AGP 3.6's `useLegacyPackaging=false`). + // The directory-iterator loop above therefore finds nothing on + // Android and the per-search_path `fs::exists` fallback also + // returns false. Try the bare filename instead and let Android's + // dynamic linker resolve it via the in-APK lookup. + // + // For backends that ship as a single library (Vulkan / OpenCL / + // ...) the bare `libggml-.so` filename is enough. + // For the CPU backend with `GGML_CPU_ALL_VARIANTS=ON` there is + // no plain `libggml-cpu.so`, only the per-arch + // `libggml-cpu-android_armv*_*.so` files, so we also + // try each known per-arch variant and let `ggml_backend_score` + // pick the highest-scoring one the device's HWCAP supports. + // + // Mirrors qvac-ext-ggml@speech commit 9562ed04 ("ggml-backend: + // android per-arch CPU variant dlopen fallback") so APK + // consumers of the bundled-ggml whisper-cpp port get the same + // Android-correct CPU init as ggml-speech consumers (parakeet, + // chatterbox, ...). + // + // TODO: keep this list in sync with the + // `ggml_add_cpu_backend_variant(android_armv*_*)` calls in + // ggml/src/CMakeLists.txt. New tiers added there must be + // appended here as well, or devices on the new tier will + // silently fall back to a lower one (perf hit). + std::vector candidate_names = { name_path }; + if (strcmp(name, "cpu") == 0) { + candidate_names.emplace_back("cpu-android_armv8.0_1"); + candidate_names.emplace_back("cpu-android_armv8.2_1"); + candidate_names.emplace_back("cpu-android_armv8.2_2"); + candidate_names.emplace_back("cpu-android_armv8.6_1"); + candidate_names.emplace_back("cpu-android_armv9.0_1"); + candidate_names.emplace_back("cpu-android_armv9.2_1"); + candidate_names.emplace_back("cpu-android_armv9.2_2"); + } + + if (candidate_names.size() == 1) { + // Fast path: Vulkan / OpenCL / Metal / ... single-shot dlopen. + fs::path filename = backend_filename_prefix().native() + + candidate_names[0].native() + + backend_filename_extension().native(); + if (auto reg = get_reg().load_backend(filename, silent)) { + return reg; + } + return nullptr; + } + + // Multi-candidate (Android CPU today): iterate worst -> best with + // a synthetic per-index offset on top of the runtime score so + // that on a device that accepts every variant (e.g. armv9.2 + // phone) the highest tier wins on tie, while a device that + // legitimately supports only the baseline still picks armv8.0_1. + for (size_t idx = 0; idx < candidate_names.size(); ++idx) { + fs::path filename = backend_filename_prefix().native() + + candidate_names[idx].native() + + backend_filename_extension().native(); + dl_handle_ptr handle { dl_load_library(filename) }; + if (!handle) { + if (!silent) { + GGML_LOG_DEBUG("%s: dlopen(%s) failed: %s\n", __func__, + path_str(filename).c_str(), dl_error()); + } + continue; + } + auto score_fn = (ggml_backend_score_t) + dl_get_sym(handle.get(), "ggml_backend_score"); + int s = 1; // base score for backends without ggml_backend_score + if (score_fn) { + s = score_fn(); + if (s == 0) { + continue; + } + } + s += static_cast(idx); +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, + path_str(filename).c_str(), s); +#endif + if (s > best_score) { + best_score = s; + best_path = filename; + } + } + + if (best_score > 0) { + return get_reg().load_backend(best_path, silent); + } +#endif // __ANDROID__ return nullptr; } diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 22c656996cc..d9f8aaec52f 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -123,7 +123,7 @@ size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) { void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) { GGML_ASSERT(buffer); // get_base is optional if the buffer is zero-sized - if (buffer->size == 0) { + if (!ggml_backend_buffer_is_meta(buffer) && buffer->size == 0) { return NULL; } @@ -279,15 +279,57 @@ void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_ten } } +void ggml_backend_tensor_set_2d_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, + size_t n_copies, size_t stride_tensor, size_t stride_data) { + GGML_ASSERT(backend); + GGML_ASSERT(tensor); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + + if (n_copies <= 1 || backend->iface.set_tensor_2d_async == NULL) { + for (size_t i = 0; i < n_copies; i++) { + ggml_backend_tensor_set_async(backend, tensor, (const char *) data + i*stride_data, offset + i*stride_tensor, size); + } + return; + } + if (size == 0) { + return; + } + + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + backend->iface.set_tensor_2d_async(backend, tensor, data, offset, size, n_copies, stride_tensor, stride_data); +} + +void ggml_backend_tensor_get_2d_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, + size_t n_copies, size_t stride_tensor, size_t stride_data) { + GGML_ASSERT(backend); + GGML_ASSERT(tensor); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + + if (n_copies <= 1 || backend->iface.set_tensor_2d_async == NULL) { + for (size_t i = 0; i < n_copies; i++) { + ggml_backend_tensor_get_async(backend, tensor, (char *) data + i*stride_data, offset + i*stride_tensor, size); + } + return; + } + if (size == 0) { + return; + } + + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + backend->iface.get_tensor_2d_async(backend, tensor, data, offset, size, n_copies, stride_tensor, stride_data); +} + void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { GGML_ASSERT(tensor); ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + GGML_ASSERT(buf != NULL && "tensor buffer not set"); if (size == 0) { return; } - GGML_ASSERT(buf != NULL && "tensor buffer not set"); GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); @@ -297,18 +339,62 @@ void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, siz void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { GGML_ASSERT(tensor); ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + GGML_ASSERT(buf != NULL && "tensor buffer not set"); if (size == 0) { return; } - GGML_ASSERT(buf != NULL && "tensor buffer not set"); GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); buf->iface.get_tensor(buf, tensor, data, offset, size); } +void ggml_backend_tensor_set_2d(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, + size_t n_copies, size_t stride_tensor, size_t stride_data) { + GGML_ASSERT(tensor); + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + GGML_ASSERT(buf != NULL && "tensor buffer not set"); + + if (n_copies <= 1 || buf->iface.set_tensor_2d == NULL) { + for (size_t i = 0; i < n_copies; i++) { + ggml_backend_tensor_set(tensor, (const char *) data + i*stride_data, offset + i*stride_tensor, size); + } + return; + } + if (size == 0) { + return; + } + + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + + buf->iface.set_tensor_2d(buf, tensor, data, offset, size, n_copies, stride_tensor, stride_data); +} + +void ggml_backend_tensor_get_2d(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, + size_t n_copies, size_t stride_tensor, size_t stride_data) { + GGML_ASSERT(tensor); + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + GGML_ASSERT(buf != NULL && "tensor buffer not set"); + + if (n_copies <= 1 || buf->iface.set_tensor_2d == NULL) { + for (size_t i = 0; i < n_copies; i++) { + ggml_backend_tensor_get(tensor, (char *) data + i*stride_data, offset + i*stride_tensor, size); + } + return; + } + if (size == 0) { + return; + } + + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); + + buf->iface.get_tensor_2d(buf, tensor, data, offset, size, n_copies, stride_tensor, stride_data); +} + void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { GGML_ASSERT(tensor); ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; @@ -388,7 +474,7 @@ ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) { // backend copy -void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) { +void ggml_backend_tensor_copy(const struct ggml_tensor * src, struct ggml_tensor * dst) { GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts"); if (src == dst) { @@ -402,7 +488,7 @@ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst } else if (!ggml_backend_buffer_copy_tensor(src, dst)) { #ifndef NDEBUG GGML_LOG_DEBUG("%s: warning: slow copy from %s to %s\n", __func__, ggml_backend_buffer_name(src->buffer), ggml_backend_buffer_name(dst->buffer)); -#endif +#endif // NDEBUG size_t nbytes = ggml_nbytes(src); void * data = malloc(nbytes); ggml_backend_tensor_get(src, data, 0, nbytes); @@ -411,7 +497,7 @@ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst } } -void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst) { +void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) { GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts"); if (src == dst) { @@ -500,6 +586,7 @@ enum ggml_backend_dev_type ggml_backend_dev_type(ggml_backend_dev_t device) { } void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props) { + GGML_ASSERT(device); memset(props, 0, sizeof(*props)); device->iface.get_props(device, props); } @@ -610,6 +697,8 @@ static const struct ggml_backend_buffer_i ggml_backend_multi_buffer_i = { /* .memset_tensor = */ NULL, /* .set_tensor = */ NULL, /* .get_tensor = */ NULL, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ NULL, /* .clear = */ ggml_backend_multi_buffer_clear, /* .reset = */ NULL, @@ -941,6 +1030,8 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra GGML_ABORT("%s: failed to initialize context\n", __func__); } + graph->uid = ggml_graph_next_uid(); + // pass 1: assign backends to ops with pre-allocated inputs for (int i = 0; i < graph->n_leafs; i++) { struct ggml_tensor * leaf = graph->leafs[i]; @@ -1388,6 +1479,11 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra assert(graph_copy->size > graph_copy->n_leafs); graph_copy->leafs[graph_copy->n_leafs++] = leaf; } + + // set ids for all splits + for (int i = 0; i < sched->n_splits; ++i) { + sched->splits[i].graph.uid = ggml_graph_next_uid(); + } } static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) { @@ -1899,8 +1995,9 @@ enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct GGML_ASSERT(tensor->data == NULL); GGML_ASSERT(tensor->view_src == NULL); GGML_ASSERT(addr >= ggml_backend_buffer_get_base(buffer)); - GGML_ASSERT((char *)addr + ggml_backend_buffer_get_alloc_size(buffer, tensor) <= - (char *)ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer)); + GGML_ASSERT(ggml_backend_buffer_is_meta(buffer) || + (char *) addr + ggml_backend_buffer_get_alloc_size(buffer, tensor) <= + (char *) ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer)); tensor->buffer = buffer; tensor->data = addr; @@ -2174,6 +2271,8 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_i = { /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor, /* .clear = */ ggml_backend_cpu_buffer_clear, /* .reset = */ NULL, @@ -2186,6 +2285,8 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_from_ptr_i = { /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor, /* .clear = */ ggml_backend_cpu_buffer_clear, /* .reset = */ NULL, diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp index 5de64b816fc..b4c735267e0 100644 --- a/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ggml/src/ggml-blas/ggml-blas.cpp @@ -121,6 +121,8 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg bli_thread_set_num_threads(ctx->n_threads); #elif defined(GGML_BLAS_USE_NVPL) nvpl_blas_set_num_threads(ctx->n_threads); +#elif defined(GGML_BLAS_USE_MKL) + mkl_set_num_threads(ctx->n_threads); #endif for (int64_t i13 = 0; i13 < ne13; i13++) { @@ -261,6 +263,8 @@ static struct ggml_backend_i blas_backend_i = { /* .free = */ ggml_backend_blas_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index fc7c3e3b724..2dc0f40917d 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -25,6 +25,7 @@ #include "ggml-impl.h" #include "ggml.h" + #include #include #include @@ -45,7 +46,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -62,6 +65,7 @@ #include #include #include +#include #include #include #include @@ -69,11 +73,15 @@ #include #include #include +#include #include #include #include #include +#include #include +#include +#include #include #include #include @@ -151,6 +159,107 @@ void ggml_cann_op_unary_gated(std::functionsrc[1] != nullptr || swapped != 0) { + ggml_cann_op_unary_gated(silu_fn, ctx, dst); + return; + } + + // aclnnSwiGlu requires the split dim (src->ne[0]) to be even; fall back otherwise. + if (dst->src[0]->ne[0] % 2 != 0) { + ggml_cann_op_unary_gated(silu_fn, ctx, dst); + return; + } + + ggml_tensor * src0 = dst->src[0]; + size_t elem_size = ggml_element_size(src0); + + // src0 GGML: [2*ne0, ne1, ne2, ne3] → 3D view [2*ne0, ne1, ne2*ne3] + // CANN reversed: [ne2*ne3, ne1, 2*ne0], split along CANN dim 2 (last). + int64_t ne0_x2 = src0->ne[0]; + int64_t ne1 = src0->ne[1]; + int64_t ne23 = src0->ne[2] * src0->ne[3]; + int64_t src3d_ne[] = { ne0_x2, ne1, ne23 }; + size_t src3d_nb[] = { (size_t)src0->nb[0], (size_t)src0->nb[1], (size_t)src0->nb[2] }; + acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0->data, ggml_cann_type_mapping(src0->type), + elem_size, src3d_ne, src3d_nb, 3); + + // dst GGML: [ne0, ne1, ne2, ne3] → 3D view [ne0, ne1, ne2*ne3] + int64_t ne0 = dst->ne[0]; + int64_t dst3d_ne[] = { ne0, ne1, ne23 }; + size_t dst3d_nb[] = { (size_t)dst->nb[0], (size_t)dst->nb[1], (size_t)dst->nb[2] }; + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ggml_cann_type_mapping(dst->type), + elem_size, dst3d_ne, dst3d_nb, 3); + + // CANN tensor [ne23, ne1, 2*ne0]: split along CANN dim 2 (last) = 2*ne0. + GGML_CANN_CALL_ACLNN_OP(ctx, SwiGlu, acl_src.get(), (int64_t)2, acl_dst.get()); +} + +// Fused GeGLU using aclnnGeGluV3: splits input along ne[0] (CANN last dim), +// activates the LEFT half with GELU, multiplies by right half. +// approximate: 0=tanh, 1=none(erf). activateLeft=true matches GGML convention. +// outGelu is a required-but-discard output buffer. +// +// Falls back to the generic two-kernel path when src[1] != nullptr (two +// independent halves) or swapped != 0 (reversed activation order), as +// aclnnGeGluV3 only handles the single interleaved tensor in standard order. +void ggml_cann_geglu(ggml_backend_cann_context & ctx, ggml_tensor * dst, int64_t approximate) { + auto gelu_fn = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { + GGML_CANN_CALL_ACLNN_OP(ctx, Gelu, acl_src, acl_dst); + }; + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + if (dst->src[1] != nullptr || swapped != 0) { + ggml_cann_op_unary_gated(gelu_fn, ctx, dst); + return; + } + + // aclnnGeGluV3 requires the split dim (src->ne[0]) to be even; fall back otherwise. + if (dst->src[0]->ne[0] % 2 != 0) { + ggml_cann_op_unary_gated(gelu_fn, ctx, dst); + return; + } + + ggml_tensor * src0 = dst->src[0]; + acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0); + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + + // Allocate a temporary buffer for the required outGelu output (same shape as dst). + // Build contiguous strides since the pool allocation is a fresh buffer. + size_t elem_size = ggml_element_size(dst); + int64_t ne[GGML_MAX_DIMS] = { dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3] }; + size_t nb[GGML_MAX_DIMS]; + nb[0] = elem_size; + for (int i = 1; i < GGML_MAX_DIMS; i++) { + nb[i] = nb[i - 1] * ne[i - 1]; + } + size_t gelu_out_size = nb[GGML_MAX_DIMS - 1] * ne[GGML_MAX_DIMS - 1]; + ggml_cann_pool_alloc gelu_out_alloc(ctx.pool(), gelu_out_size); + + acl_tensor_ptr acl_gelu_out = ggml_cann_create_tensor( + gelu_out_alloc.get(), ggml_cann_type_mapping(dst->type), elem_size, ne, nb, GGML_MAX_DIMS); + // V3 adds activateLeft param; true → Gelu(left)*right, matching GGML convention. + // GGML dim 0 → CANN last dim (index GGML_MAX_DIMS-1 = 3 for 4D tensor). + GGML_CANN_CALL_ACLNN_OP(ctx, GeGluV3, acl_src.get(), (int64_t)(GGML_MAX_DIMS - 1), approximate, true, + acl_dst.get(), acl_gelu_out.get()); +} + /** * @brief Repeats elements of a tensor along each dimension according to the * specified repeat array. @@ -434,6 +543,9 @@ void ggml_cann_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) { void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src = dst->src[0]; + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + acl_tensor_ptr acl_src = ggml_cann_create_tensor(src); acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); @@ -442,21 +554,33 @@ void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_cann_pool_alloc temp_buffer_allocator(ctx.pool(), n_bytes); void * buffer = temp_buffer_allocator.get(); - int64_t div_ne[] = { 1, src->ne[1], src->ne[2], src->ne[3] }; - size_t div_nb[GGML_MAX_DIMS]; - div_nb[0] = sizeof(float); + int64_t norm_ne[] = { 1, src->ne[1], src->ne[2], src->ne[3] }; + size_t norm_nb[GGML_MAX_DIMS]; + norm_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS; ++i) { - div_nb[i] = div_nb[i - 1] * div_ne[i - 1]; + norm_nb[i] = norm_nb[i - 1] * norm_ne[i - 1]; } - acl_tensor_ptr acl_div = ggml_cann_create_tensor(buffer, ACL_FLOAT, type_size, div_ne, div_nb, GGML_MAX_DIMS); + acl_tensor_ptr acl_norm = ggml_cann_create_tensor(buffer, ACL_FLOAT, sizeof(float), norm_ne, norm_nb, GGML_MAX_DIMS); std::vector norm_dims = { 3 }; acl_int_array_ptr dims_array = ggml_cann_create_int_array(norm_dims.data(), norm_dims.size()); float p_value = 2.0f; acl_scalar_ptr p_scalar = ggml_cann_create_scalar(&p_value, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(ctx, Norm, acl_src.get(), p_scalar.get(), dims_array.get(), true, acl_div.get()); - GGML_CANN_CALL_ACLNN_OP(ctx, Div, acl_src.get(), acl_div.get(), acl_dst.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, Norm, acl_src.get(), p_scalar.get(), dims_array.get(), true, acl_norm.get()); + + ggml_cann_pool_alloc clamp_buffer_allocator(ctx.pool()); + acl_tensor_ptr acl_clamped; + + if (eps > 0.0f) { + void * clamp_buf = clamp_buffer_allocator.alloc(n_bytes); + acl_clamped = ggml_cann_create_tensor(clamp_buf, ACL_FLOAT, sizeof(float), norm_ne, norm_nb, GGML_MAX_DIMS); + acl_scalar_ptr eps_scalar = ggml_cann_create_scalar(&eps, aclDataType::ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, ClampMin, acl_norm.get(), eps_scalar.get(), acl_clamped.get()); + } + + aclTensor * acl_div_input = acl_clamped ? acl_clamped.get() : acl_norm.get(); + GGML_CANN_CALL_ACLNN_OP(ctx, Div, acl_src.get(), acl_div_input, acl_dst.get()); } void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * dst) { @@ -472,56 +596,30 @@ void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * logits_nb[1] = logits_nb[0] * logits_ne[0]; acl_tensor_ptr acl_logits = ggml_cann_create_tensor(src0->data, ACL_FLOAT, sizeof(float), logits_ne, logits_nb, 2); - size_t log_softmax_type_size = sizeof(float); - int64_t log_softmax_n_bytes = nr * nc * log_softmax_type_size; - ggml_cann_pool_alloc log_softmax_allocator(ctx.pool(), log_softmax_n_bytes); - void * log_softmax_buffer = log_softmax_allocator.get(); - - int64_t log_softmax_ne[] = { nc, nr }; - size_t log_softmax_nb[2]; - log_softmax_nb[0] = log_softmax_type_size; - log_softmax_nb[1] = log_softmax_nb[0] * log_softmax_ne[0]; - acl_tensor_ptr acl_log_softmax = ggml_cann_create_tensor(log_softmax_buffer, ACL_FLOAT, log_softmax_type_size, - log_softmax_ne, log_softmax_nb, 2); - - GGML_CANN_CALL_ACLNN_OP(ctx, LogSoftmax, acl_logits.get(), 1, acl_log_softmax.get()); - int64_t labels_ne[] = { nc, nr }; size_t labels_nb[2]; labels_nb[0] = ggml_type_size(src1->type); labels_nb[1] = labels_nb[0] * labels_ne[0]; acl_tensor_ptr acl_labels = ggml_cann_create_tensor(src1->data, ACL_FLOAT, sizeof(float), labels_ne, labels_nb, 2); - size_t mul_type_size = sizeof(float); - int64_t mul_n_bytes = nr * nc * mul_type_size; - ggml_cann_pool_alloc mul_allocator(ctx.pool(), mul_n_bytes); - void * mul_buffer = mul_allocator.get(); + size_t loss_per_sample_type_size = sizeof(float); + int64_t loss_per_sample_n_bytes = nr * loss_per_sample_type_size; + ggml_cann_pool_alloc loss_per_sample_allocator(ctx.pool(), loss_per_sample_n_bytes); + void * loss_per_sample_buffer = loss_per_sample_allocator.get(); - int64_t mul_ne[] = { nc, nr }; - size_t mul_nb[2]; - mul_nb[0] = mul_type_size; - mul_nb[1] = mul_nb[0] * mul_ne[0]; - acl_tensor_ptr acl_mul_result = ggml_cann_create_tensor(mul_buffer, ACL_FLOAT, mul_type_size, mul_ne, mul_nb, 2); + int64_t loss_per_sample_ne[] = { nr }; + size_t loss_per_sample_nb[1]; + loss_per_sample_nb[0] = loss_per_sample_type_size; + acl_tensor_ptr acl_loss_per_sample = ggml_cann_create_tensor( + loss_per_sample_buffer, ACL_FLOAT, loss_per_sample_type_size, loss_per_sample_ne, loss_per_sample_nb, 1); - GGML_CANN_CALL_ACLNN_OP(ctx, Mul, acl_log_softmax.get(), acl_labels.get(), acl_mul_result.get()); + size_t backprop_n_bytes = nr * nc * sizeof(float); + ggml_cann_pool_alloc backprop_allocator(ctx.pool(), backprop_n_bytes); + void * backprop_buffer = backprop_allocator.get(); + acl_tensor_ptr acl_backprop = ggml_cann_create_tensor(backprop_buffer, ACL_FLOAT, sizeof(float), logits_ne, logits_nb, 2); - size_t sum_per_sample_type_size = sizeof(float); - int64_t sum_per_sample_n_bytes = nr * sum_per_sample_type_size; - ggml_cann_pool_alloc sum_per_sample_allocator(ctx.pool(), sum_per_sample_n_bytes); - void * sum_per_sample_buffer = sum_per_sample_allocator.get(); - - int64_t sum_per_sample_ne[] = { nr }; - size_t sum_per_sample_nb[1]; - sum_per_sample_nb[0] = sum_per_sample_type_size; - acl_tensor_ptr acl_sum_per_sample = ggml_cann_create_tensor( - sum_per_sample_buffer, ACL_FLOAT, sum_per_sample_type_size, sum_per_sample_ne, sum_per_sample_nb, 1); - - std::vector sum_dims = { 1 }; - acl_int_array_ptr dims_array = ggml_cann_create_int_array(sum_dims.data(), sum_dims.size()); - bool keep_dims = false; - - GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_mul_result.get(), dims_array.get(), keep_dims, ACL_FLOAT, - acl_sum_per_sample.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, SoftmaxCrossEntropyWithLogits, acl_logits.get(), acl_labels.get(), + acl_loss_per_sample.get(), acl_backprop.get()); size_t total_sum_type_size = sizeof(float); int64_t total_sum_n_bytes = 1 * total_sum_type_size; @@ -537,11 +635,12 @@ void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * std::vector total_sum_dims = { 0 }; acl_int_array_ptr total_sum_dims_array = ggml_cann_create_int_array(total_sum_dims.data(), total_sum_dims.size()); + bool keep_dims = false; - GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_sum_per_sample.get(), total_sum_dims_array.get(), keep_dims, ACL_FLOAT, + GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_loss_per_sample.get(), total_sum_dims_array.get(), keep_dims, ACL_FLOAT, acl_total_sum.get()); - float value = -1.0f / static_cast(nr); + float value = 1.0f / static_cast(nr); acl_scalar_ptr scale_factor = ggml_cann_create_scalar(&value, aclDataType::ACL_FLOAT); acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ACL_FLOAT, sizeof(float), total_sum_ne, total_sum_nb, 1); @@ -579,6 +678,33 @@ void ggml_cann_group_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) { acl_mean_out.get(), acl_rstd_out.get()); } +void ggml_cann_set(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; + ggml_tensor * src1 = dst->src[1]; + + size_t nb1 = ((int32_t *) dst->op_params)[0]; + size_t nb2 = ((int32_t *) dst->op_params)[1]; + size_t nb3 = ((int32_t *) dst->op_params)[2]; + size_t offset = ((int32_t *) dst->op_params)[3]; + bool inplace = (bool) ((int32_t *) dst->op_params)[4]; + + size_t param_nb[] = { ggml_element_size(src0), nb1, nb2, nb3 }; + + // Create a view of dst at the target offset with src1's dimensions + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, src1->ne, param_nb, GGML_MAX_DIMS, ACL_FORMAT_ND, offset); + acl_tensor_ptr acl_src1 = ggml_cann_create_tensor(src1); + + if (!inplace) { + // First copy src0 to dst entirely + size_t cpy_size = ggml_nbytes(dst); + ACL_CHECK( + aclrtMemcpyAsync(dst->data, cpy_size, src0->data, cpy_size, ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream())); + } + + // Copy src1 into the target region of dst + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_dst.get(), acl_src1.get()); +} + void ggml_cann_acc(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src0 = dst->src[0]; ggml_tensor * src1 = dst->src[1]; @@ -642,6 +768,113 @@ void ggml_cann_sum(ggml_backend_cann_context & ctx, ggml_tensor * dst) { aclnn_reduce_sum(ctx, dst, reduce_dims, 4); } +void ggml_cann_cumsum(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src = dst->src[0]; + acl_tensor_ptr acl_src = ggml_cann_create_tensor(src); + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + // GGML cumsum operates along dim 0 (innermost / ne[0]). + // ggml_cann_create_tensor reverses dimensions to [ne3,ne2,ne1,ne0], + // so GGML dim 0 maps to CANN dim 3 (the last dim of the 4-D tensor). + GGML_CANN_CALL_ACLNN_OP(ctx, Cumsum, acl_src.get(), (int64_t)3, + ggml_cann_type_mapping(dst->type), acl_dst.get()); +} + +void ggml_cann_solve_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; // A: [N, N, B2, B3] lower triangular + ggml_tensor * src1 = dst->src[1]; // B: [K, N, B2, B3] + + acl_tensor_ptr acl_a = ggml_cann_create_tensor(src0); + acl_tensor_ptr acl_b = ggml_cann_create_tensor(src1); + acl_tensor_ptr acl_x = ggml_cann_create_tensor(dst); + + // mOut: triangular copy of A (required output), same shape as A. + const size_t a_bytes = ggml_nbytes(src0); + ggml_cann_pool_alloc m_alloc(ctx.pool(), a_bytes); + acl_tensor_ptr acl_m = ggml_cann_create_tensor( + m_alloc.get(), ggml_cann_type_mapping(src0->type), + ggml_type_size(src0->type), src0->ne, src0->nb, GGML_MAX_DIMS); + + // Solve AX = B: upper=false (lower tri), transpose=false, unitriangular=false. + GGML_CANN_CALL_ACLNN_OP(ctx, TriangularSolve, + acl_b.get(), acl_a.get(), false, false, false, + acl_x.get(), acl_m.get()); +} + +void ggml_cann_diag(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src = dst->src[0]; + + GGML_ASSERT(src->ne[1] == 1); + + const int64_t N = src->ne[0]; + const int64_t n_batch = src->ne[2] * src->ne[3]; + const size_t nb_f32 = sizeof(float); + + // Fill dst with zeros. + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + { + float zero = 0.0f; + acl_scalar_ptr acl_zero = ggml_cann_create_scalar(&zero, ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst.get(), acl_zero.get()); + } + + // Copy src vector onto the diagonal of dst via strided views. + // src viewed as [N, n_batch], contiguous strides. + int64_t ne_vec[2] = { N, n_batch }; + size_t nb_src_vec[2] = { nb_f32, N * nb_f32 }; + // dst diagonal view: stride (N+1)*4 steps along the diagonal. + size_t nb_dst_diag[2] = { (N + 1) * nb_f32, N * N * nb_f32 }; + + acl_tensor_ptr acl_src_vec = ggml_cann_create_tensor(src->data, ACL_FLOAT, nb_f32, ne_vec, nb_src_vec, 2); + acl_tensor_ptr acl_dst_diag = ggml_cann_create_tensor(dst->data, ACL_FLOAT, nb_f32, ne_vec, nb_dst_diag, 2); + + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_dst_diag.get(), acl_src_vec.get()); +} + +void ggml_cann_fill(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + float c = ggml_get_op_params_f32(dst, 0); + + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + acl_scalar_ptr acl_c = ggml_cann_create_scalar(&c, ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst.get(), acl_c.get()); +} + +void ggml_cann_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src = dst->src[0]; + + const int64_t S = src->ne[0]; + const int64_t n_batch = src->ne[2] * src->ne[3]; + const size_t nb_f32 = sizeof(float); + + int64_t ne3d[3] = { S, S, n_batch }; + size_t nb3d[3] = { nb_f32, S * nb_f32, S * S * nb_f32 }; + + const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0); + + acl_tensor_ptr acl_src = ggml_cann_create_tensor(src->data, ACL_FLOAT, nb_f32, ne3d, nb3d, 3); + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ACL_FLOAT, nb_f32, ne3d, nb3d, 3); + + switch (ttype) { + case GGML_TRI_TYPE_LOWER: + // Tril(-1): preserve row > col (strict lower), zero upper + diagonal. + GGML_CANN_CALL_ACLNN_OP(ctx, Tril, acl_src.get(), (int64_t)-1, acl_dst.get()); + break; + case GGML_TRI_TYPE_UPPER_DIAG: + // Triu(0): preserve row <= col (upper + diagonal), zero strict lower. + GGML_CANN_CALL_ACLNN_OP(ctx, Triu, acl_src.get(), (int64_t)0, acl_dst.get()); + break; + case GGML_TRI_TYPE_UPPER: + // Triu(1): preserve row < col (strict upper), zero lower + diagonal. + GGML_CANN_CALL_ACLNN_OP(ctx, Triu, acl_src.get(), (int64_t)1, acl_dst.get()); + break; + case GGML_TRI_TYPE_LOWER_DIAG: + // Tril(0): preserve row >= col (lower + diagonal), zero strict upper. + GGML_CANN_CALL_ACLNN_OP(ctx, Tril, acl_src.get(), (int64_t)0, acl_dst.get()); + break; + default: + GGML_ABORT("unsupported tri type"); + } +} + void ggml_cann_upsample_nearest2d(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src = dst->src[0]; acl_tensor_ptr acl_src = ggml_cann_create_tensor(src, nullptr, nullptr, 0, ACL_FORMAT_NCHW); @@ -1544,8 +1777,8 @@ static void aclnn_get_slope(ggml_backend_cann_context & ctx, end = 2 * ((n_head - 1) - n_head_log2) + 1; step = 2; count = n_head - n_head_log2; - aclnn_get_slope_inner(ctx, (char *) slope_buffer + n_head_log2 * sizeof(float), m1, count, start, end + 1, step, - dtype); + aclnn_get_slope_inner(ctx, (char *) slope_buffer + n_head_log2 * ggml_type_size(dtype), m1, count, start, end + 1, + step, dtype); } } @@ -1685,150 +1918,90 @@ void ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) { aclnn_softmax(ctx, softmax_tensor.get(), 3, acl_dst.get()); } -/** - * @brief Performs index select operation on a 4D tensor using the CANN backend. - * - * This function applies the `IndexSelect` operation along a specific dimension - * of the source tensor (`src_buffer`) using the indices from the index tensor (`index`). - * It iterates over the last two dimensions of the source tensor, creates the corresponding - * CANN tensors for the source, index, and output slices, and executes the `IndexSelect` - * operation for each slice. - * - * @param ctx The context for CANN backend operations. - * @param src_buffer The source buffer containing the 4D input tensor data. - * @param src_ne The dimensions of the source tensor. - * @param src_nb The strides (byte offsets) of the source tensor. - * @param dst_buffer The destination buffer where the output tensor data will be written. - * @param dst_ne The dimensions of the destination tensor. - * @param dst_nb The strides (byte offsets) of the destination tensor. - * @param index The index tensor specifying the indices to select from the source tensor. - * @param type The data type of the source and destination tensors. - */ -static void aclnn_index_select_4d(ggml_backend_cann_context & ctx, - void * src_buffer, - int64_t * src_ne, - size_t * src_nb, - void * dst_buffer, - int64_t * dst_ne, - size_t * dst_nb, - ggml_tensor * index, - ggml_type type) { - for (int64_t i = 0; i < src_ne[3]; i++) { - for (int64_t j = 0; j < src_ne[2]; j++) { - // src - acl_tensor_ptr acl_src_tensor = - ggml_cann_create_tensor((char *) src_buffer + i * src_nb[3] + j * src_nb[2], - ggml_cann_type_mapping(type), ggml_type_size(type), src_ne, src_nb, 2); - - // index - acl_tensor_ptr acl_index = ggml_cann_create_tensor( - (char *) index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1], - ggml_cann_type_mapping(index->type), ggml_element_size(index), index->ne, index->nb, 1); - - // out - acl_tensor_ptr acl_out = - ggml_cann_create_tensor((char *) dst_buffer + i * dst_nb[3] + j * dst_nb[2], - ggml_cann_type_mapping(type), ggml_type_size(type), dst_ne, dst_nb, 2); - GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, acl_src_tensor.get(), 0, acl_index.get(), acl_out.get()); - } - } -} - -/** - * @brief Performs inplace index copy operation on a 4D tensor using the CANN backend. - * - * This function applies the `IndexCopy` operation along a specific dimension of the - * destination tensor (`dst_buffer`) by copying elements from the source tensor (`src_buffer`) - * to positions specified by the index tensor (`index`). - * It iterates over the last two dimensions of the tensors, creates the corresponding - * CANN tensors for source, index, and destination slices, and performs the index copy - * operation for each slice. - * - * @param ctx The context for CANN backend operations. - * @param src_buffer The source buffer containing the 4D input tensor data to be copied. - * @param src_ne The dimensions of the source tensor. - * @param src_nb The strides (byte offsets) of the source tensor. - * @param dst_buffer The destination buffer where values will be copied to. - * @param dst_ne The dimensions of the destination tensor. - * @param dst_nb The strides (byte offsets) of the destination tensor. - * @param index The index tensor specifying target positions in the destination tensor. - * @param type The data type of the source and destination tensors. - */ -static void aclnn_index_copy_4d(ggml_backend_cann_context & ctx, - void * src_buffer, - int64_t * src_ne, - size_t * src_nb, - void * dst_buffer, - int64_t * dst_ne, - size_t * dst_nb, - ggml_tensor * index, - ggml_type type) { - for (int64_t i = 0; i < src_ne[3]; i++) { - for (int64_t j = 0; j < src_ne[2]; j++) { - // src - acl_tensor_ptr acl_src_tensor = - ggml_cann_create_tensor((char *) src_buffer + i * src_nb[3] + j * src_nb[2], - ggml_cann_type_mapping(type), ggml_type_size(type), src_ne, src_nb, 2); - - // index - acl_tensor_ptr acl_index = ggml_cann_create_tensor( - (char *) index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1], - ggml_cann_type_mapping(index->type), ggml_element_size(index), index->ne, index->nb, 1); - - // out - acl_tensor_ptr acl_out = - ggml_cann_create_tensor((char *) dst_buffer + i * dst_nb[3] + j * dst_nb[2], - ggml_cann_type_mapping(type), ggml_type_size(type), dst_ne, dst_nb, 2); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexCopy, acl_out.get(), 0, acl_index.get(), acl_src_tensor.get()); - } - } -} void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) { - ggml_tensor * src0 = dst->src[0]; // src + ggml_tensor * src0 = dst->src[0]; // weight ggml_tensor * src1 = dst->src[1]; // index - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16 + || dst->type == GGML_TYPE_BF16); + + // n_idx: number of row indices per (i2, i3) batch slice. + // ggml guarantees: src0->ne[2] == src1->ne[1], src0->ne[3] == src1->ne[2], src1->ne[3] == 1. + const int64_t n_idx = src1->ne[0]; + + // Gather all (i2, i3) batch slices from src into dst. + // ggml_cann_create_tensor reverses dims, so ACL sees [ne1, ne0]. + // GatherV2 with dim=0 gathers along ACL dim-0 == ggml ne[1] (the vocabulary / row axis). + // nb: the 4 strides of the source buffer (nb[0..1] for the 2D slice shape, + // nb[2..3] for computing per-batch-slice base pointer offsets). + auto gather_batched = [&](void * src_base, aclDataType acl_type, size_t type_size, + const size_t * nb) { + int64_t src_ne[2] = { src0->ne[0], src0->ne[1] }; + size_t src_nb_2d[2] = { nb[0], nb[1] }; + int64_t dst_ne[2] = { src0->ne[0], n_idx }; + size_t dst_nb_2d[2] = { dst->nb[0], dst->nb[1] }; + int64_t idx_ne[1] = { n_idx }; + size_t idx_nb[1] = { (size_t)ggml_element_size(src1) }; + + for (int64_t i3 = 0; i3 < src0->ne[3]; i3++) { + for (int64_t i2 = 0; i2 < src0->ne[2]; i2++) { + acl_tensor_ptr acl_src = ggml_cann_create_tensor( + (char *)src_base + i3 * nb[3] + i2 * nb[2], + acl_type, type_size, src_ne, src_nb_2d, 2); + acl_tensor_ptr acl_idx = ggml_cann_create_tensor( + (char *)src1->data + i3 * src1->nb[2] + i2 * src1->nb[1], + ggml_cann_type_mapping(src1->type), (size_t)ggml_element_size(src1), + idx_ne, idx_nb, 1); + acl_tensor_ptr acl_dst = ggml_cann_create_tensor( + (char *)dst->data + i3 * dst->nb[3] + i2 * dst->nb[2], + acl_type, type_size, dst_ne, dst_nb_2d, 2); + GGML_CANN_CALL_ACLNN_OP(ctx, GatherV2, acl_src.get(), 0, acl_idx.get(), acl_dst.get()); + } + } + }; switch (src0->type) { + case GGML_TYPE_BF16: case GGML_TYPE_F16: case GGML_TYPE_F32: if (src0->type == dst->type) { - aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb, dst->data, dst->ne, dst->nb, src1, - dst->type); + gather_batched(src0->data, + ggml_cann_type_mapping(src0->type), ggml_type_size(src0->type), + src0->nb); } else { - acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0); - ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * ggml_element_size(dst)); - void * src_trans_buffer = src_buffer_allocator.get(); - size_t src_trans_nb[GGML_MAX_DIMS]; - src_trans_nb[0] = dst->nb[0]; + // Cast src0 to dst type, then gather. + ggml_cann_pool_alloc src_cast_allocator(ctx.pool(), + ggml_nelements(src0) * ggml_element_size(dst)); + size_t src_cast_nb[GGML_MAX_DIMS]; + src_cast_nb[0] = ggml_type_size(dst->type); for (int i = 1; i < GGML_MAX_DIMS; i++) { - src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; + src_cast_nb[i] = src_cast_nb[i - 1] * src0->ne[i - 1]; } - acl_tensor_ptr src_trans_tensor = - ggml_cann_create_tensor(src_trans_buffer, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS); - aclnn_cast(ctx, acl_src0.get(), src_trans_tensor.get(), ggml_cann_type_mapping(dst->type)); - aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1, - dst->type); + acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0); + acl_tensor_ptr acl_src_cast = ggml_cann_create_tensor( + src_cast_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + src0->ne, src_cast_nb, GGML_MAX_DIMS); + aclnn_cast(ctx, acl_src0.get(), acl_src_cast.get(), ggml_cann_type_mapping(dst->type)); + + gather_batched(src_cast_allocator.get(), + ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + src_cast_nb); } break; case GGML_TYPE_Q8_0: { - // add 1 dim for bcast mul. + // Dequantize Q8_0 to dst type, then gather. size_t weight_nb[GGML_MAX_DIMS + 1], scale_nb[GGML_MAX_DIMS + 1], dequant_nb[GGML_MAX_DIMS + 1]; int64_t weight_ne[GGML_MAX_DIMS + 1], scale_ne[GGML_MAX_DIMS + 1], *dequant_ne; - int64_t scale_offset = 0; - // [3,4,5,64] -> [3,4,5,2,32] - weight_ne[0] = QK8_0; - weight_ne[1] = src0->ne[0] / QK8_0; - weight_nb[0] = sizeof(int8_t); - weight_nb[1] = weight_nb[0] * weight_ne[0]; + weight_ne[0] = QK8_0; + weight_ne[1] = src0->ne[0] / QK8_0; + weight_nb[0] = sizeof(int8_t); + weight_nb[1] = weight_nb[0] * weight_ne[0]; for (int i = 2; i < GGML_MAX_DIMS + 1; i++) { weight_ne[i] = src0->ne[i - 1]; weight_nb[i] = weight_nb[i - 1] * weight_ne[i - 1]; } - // [3,4,5,64] -> [3,4,5,2,1] scale_ne[0] = 1; scale_ne[1] = src0->ne[0] / QK8_0; scale_nb[0] = sizeof(uint16_t); @@ -1837,31 +2010,33 @@ void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) { scale_ne[i] = src0->ne[i - 1]; scale_nb[i] = scale_nb[i - 1] * scale_ne[i - 1]; } - // [3,4,5,64] -> [3,4,5,2,32] dequant_ne = weight_ne; dequant_nb[0] = ggml_type_size(dst->type); for (int i = 1; i < GGML_MAX_DIMS + 1; i++) { dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1]; } - scale_offset = ggml_nelements(src0) * sizeof(int8_t); - ggml_cann_pool_alloc dequant_buffer_allocator(ctx.pool(), - ggml_nelements(src0) * ggml_type_size(dst->type)); - acl_tensor_ptr acl_weight_tensor = ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t), - weight_ne, weight_nb, GGML_MAX_DIMS + 1); - acl_tensor_ptr acl_scale_tensor = - ggml_cann_create_tensor(src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb, - GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset); - acl_tensor_ptr dequant_tensor = - ggml_cann_create_tensor(dequant_buffer_allocator.get(), ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), dequant_ne, dequant_nb, GGML_MAX_DIMS + 1); - aclnn_mul(ctx, acl_weight_tensor.get(), acl_scale_tensor.get(), dequant_tensor.get()); - dequant_nb[0] = ggml_type_size(dst->type); + const int64_t scale_offset = ggml_nelements(src0) * sizeof(int8_t); + ggml_cann_pool_alloc dequant_allocator(ctx.pool(), + ggml_nelements(src0) * ggml_type_size(dst->type)); + acl_tensor_ptr acl_weight = ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t), + weight_ne, weight_nb, GGML_MAX_DIMS + 1); + acl_tensor_ptr acl_scale = ggml_cann_create_tensor( + src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb, + GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset); + acl_tensor_ptr acl_dequant = ggml_cann_create_tensor( + dequant_allocator.get(), ggml_cann_type_mapping(dst->type), + ggml_type_size(dst->type), dequant_ne, dequant_nb, GGML_MAX_DIMS + 1); + aclnn_mul(ctx, acl_weight.get(), acl_scale.get(), acl_dequant.get()); + + // Reinterpret dequant buffer as 4D [src0->ne] with contiguous strides. dequant_ne = src0->ne; + dequant_nb[0] = ggml_type_size(dst->type); for (int i = 1; i < GGML_MAX_DIMS; i++) { dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1]; } - aclnn_index_select_4d(ctx, dequant_buffer_allocator.get(), dequant_ne, dequant_nb, dst->data, dst->ne, - dst->nb, src1, dst->type); + gather_batched(dequant_allocator.get(), + ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + dequant_nb); break; } default: @@ -1871,30 +2046,70 @@ void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) { } void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) { - ggml_tensor * src0 = dst->src[0]; // src - ggml_tensor * src1 = dst->src[1]; // index + ggml_tensor * src0 = dst->src[0]; // source values + ggml_tensor * src1 = dst->src[1]; // row indices + + // n_idx: number of source rows to scatter per batch slice. + // ggml guarantees: src0->ne[1] == src1->ne[0]. + const int64_t n_idx = src1->ne[0]; + + // Copy n_idx rows of src [ne0, n_idx] into dst [ne0, ne1] at positions given by a 1D index. + // ggml_cann_create_tensor reverses dims, so ACL sees [ne1, ne0] for dst. + // InplaceIndexCopy with dim=0 copies along ACL dim-0 == ggml ne[1] (the row axis). + // src_nb: the 4 strides of the source buffer (nb[0..1] for the 2D slice shape, + // nb[2..3] for computing per-batch-slice base pointer offsets). + auto scatter_batched = [&](void * src_base, aclDataType acl_type, size_t type_size, + const size_t * src_nb) { + int64_t d_ne[2] = { dst->ne[0], dst->ne[1] }; + size_t d_nb[2] = { dst->nb[0], dst->nb[1] }; + int64_t s_ne[2] = { dst->ne[0], n_idx }; + size_t s_nb_2d[2] = { src_nb[0], src_nb[1] }; + int64_t i_ne[1] = { n_idx }; + size_t i_nb[1] = { (size_t)ggml_element_size(src1) }; + + for (int64_t i3 = 0; i3 < dst->ne[3]; i3++) { + for (int64_t i2 = 0; i2 < dst->ne[2]; i2++) { + acl_tensor_ptr acl_dst = ggml_cann_create_tensor( + (char *)dst->data + i3 * dst->nb[3] + i2 * dst->nb[2], + acl_type, type_size, d_ne, d_nb, 2); + acl_tensor_ptr acl_idx = ggml_cann_create_tensor( + (char *)src1->data + (i3 % src1->ne[2]) * src1->nb[2] + (i2 % src1->ne[1]) * src1->nb[1], + ggml_cann_type_mapping(src1->type), (size_t)ggml_element_size(src1), + i_ne, i_nb, 1); + acl_tensor_ptr acl_src = ggml_cann_create_tensor( + (char *)src_base + i3 * src_nb[3] + i2 * src_nb[2], + acl_type, type_size, s_ne, s_nb_2d, 2); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexCopy, acl_dst.get(), 0, acl_idx.get(), acl_src.get()); + } + } + }; switch (dst->type) { case GGML_TYPE_F32: - { - aclnn_index_copy_4d(ctx, src0->data, src0->ne, src0->nb, dst->data, dst->ne, dst->nb, src1, dst->type); - break; - } + scatter_batched(src0->data, + ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + src0->nb); + break; case GGML_TYPE_F16: + case GGML_TYPE_BF16: { - acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0); - ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * sizeof(uint16_t)); - void * src_trans_buffer = src_buffer_allocator.get(); - size_t src_trans_nb[GGML_MAX_DIMS]; - src_trans_nb[0] = sizeof(uint16_t); + // Cast src0 (F32) to dst type first. + ggml_cann_pool_alloc src_cast_allocator(ctx.pool(), + ggml_nelements(src0) * ggml_type_size(dst->type)); + size_t src_cast_nb[GGML_MAX_DIMS]; + src_cast_nb[0] = ggml_type_size(dst->type); for (int i = 1; i < GGML_MAX_DIMS; i++) { - src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; + src_cast_nb[i] = src_cast_nb[i - 1] * src0->ne[i - 1]; } - acl_tensor_ptr src_trans_tensor = ggml_cann_create_tensor( - src_trans_buffer, ACL_FLOAT16, ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS); - aclnn_cast(ctx, acl_src0.get(), src_trans_tensor.get(), ggml_cann_type_mapping(dst->type)); - aclnn_index_copy_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1, - dst->type); + acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0); + acl_tensor_ptr acl_src_cast = ggml_cann_create_tensor( + src_cast_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + src0->ne, src_cast_nb, GGML_MAX_DIMS); + aclnn_cast(ctx, acl_src0.get(), acl_src_cast.get(), ggml_cann_type_mapping(dst->type)); + + scatter_batched(src_cast_allocator.get(), + ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + src_cast_nb); break; } default: @@ -1965,7 +2180,7 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context & ctx, ggml_tensor * // Only check env once. static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on")); - if (weight_to_nz && is_matmul_weight(weight)) { + if (weight_to_nz && weight->type != GGML_TYPE_BF16 && is_matmul_weight(weight)) { acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ); } else { acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND); @@ -2146,6 +2361,9 @@ void ggml_cann_mul_mat(ggml_backend_cann_context & ctx, ggml_tensor * dst) { switch (type) { case GGML_TYPE_F32: case GGML_TYPE_F16: +#ifndef ASCEND_310P + case GGML_TYPE_BF16: +#endif ggml_cann_mat_mul_fp(ctx, dst); break; case GGML_TYPE_Q4_0: @@ -2943,6 +3161,27 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { // Rotate full tensor (no tail), using trans tensors GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor.get(), acl_cos_reshape_tensor.get(), acl_sin_reshape_tensor.get(), acl_mode, acl_dst_trans_tensor.get()); + } else if (src0->data == dst->data && !ggml_is_contiguous(src0)) { + // In-place on non-contiguous tensor: RotaryPositionEmbedding cannot safely + // read and write the same non-contiguous buffer. Use contiguous temporaries. + size_t contiguous_nb[GGML_MAX_DIMS]; + contiguous_nb[0] = sizeof(float); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + contiguous_nb[i] = contiguous_nb[i - 1] * src0->ne[i - 1]; + } + int64_t total_elements = ggml_nelements(src0); + ggml_cann_pool_alloc inplace_src_alloc(ctx.pool(), total_elements * sizeof(float)); + ggml_cann_pool_alloc inplace_dst_alloc(ctx.pool(), total_elements * sizeof(float)); + + acl_tensor_ptr acl_src_contig = ggml_cann_create_tensor(inplace_src_alloc.get(), ACL_FLOAT, sizeof(float), + src0->ne, contiguous_nb, GGML_MAX_DIMS); + acl_tensor_ptr acl_dst_contig = ggml_cann_create_tensor(inplace_dst_alloc.get(), ACL_FLOAT, sizeof(float), + dst->ne, contiguous_nb, GGML_MAX_DIMS); + + cann_copy(ctx, acl_src.get(), acl_src_contig.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_contig.get(), acl_cos_reshape_tensor.get(), + acl_sin_reshape_tensor.get(), acl_mode, acl_dst_contig.get()); + cann_copy(ctx, acl_dst_contig.get(), acl_dst.get()); } else { // Rotate full tensor (no tail), using original tensors GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src.get(), acl_cos_reshape_tensor.get(), @@ -2984,6 +3223,58 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { } } +void ggml_cann_rope_cache_preload(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; + + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + int sections[4]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + + GGML_TENSOR_UNARY_OP_LOCALS + + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int) * 4); + + const float theta_scale = powf(freq_base, -2.0f / n_dims); + + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + + bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; + const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + if (is_imrope || mrope_used) { + is_neox = true; + } + + int64_t rope_dims = n_dims; + if (is_vision) { + rope_dims = src0->ne[0]; + } + + // Run the full cache init on the non-captured stream. This performs all + // host-to-device memcpy, aclrtMalloc/Free, and on-device computations + // so that the memory pool is warmed up and cache metadata is populated. + aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections, + mrope_used, is_imrope, is_vision, rope_dims); + + // Reset `cached` so that during graph capture the on-device computations + // (sin/cos, position multiply, repeat, etc.) still execute and get recorded + // into the captured graph. The cache metadata (theta_scale_length, + // theta_scale, sections, position_length, etc.) remains set, which causes + // all host-to-device copy and malloc/free branches to be skipped. + ctx.rope_cache.cached = false; +} + void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src0 = dst->src[0]; @@ -3179,29 +3470,50 @@ void ggml_cann_pad_reflect_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst int64_t paddingsArray[2] = { opts[0], opts[1] }; acl_int_array_ptr paddings = ggml_cann_create_int_array(paddingsArray, 2); - for (int64_t i = 0; i < src0->ne[3]; i++) { - acl_tensor_ptr acl_src = - ggml_cann_create_tensor((char *) src0->data + i * src0->ne[3], ggml_cann_type_mapping(src0->type), - ggml_element_size(src0), src0->ne, src0->nb, 3); + // Collapsing ne[2]*ne[3] into a single batch dimension requires that dim3 + // is contiguous with respect to dim2 in both src and dst. + GGML_ASSERT(src0->nb[3] == src0->nb[2] * src0->ne[2]); + GGML_ASSERT(dst->nb[3] == dst->nb[2] * dst->ne[2]); - acl_tensor_ptr acl_dst = - ggml_cann_create_tensor((char *) dst->data + i * src0->ne[3], ggml_cann_type_mapping(dst->type), - ggml_element_size(dst), dst->ne, dst->nb, 3); + int64_t src_ne_3d[3] = { src0->ne[0], src0->ne[1], src0->ne[2] * src0->ne[3] }; + int64_t dst_ne_3d[3] = { dst->ne[0], dst->ne[1], dst->ne[2] * dst->ne[3] }; - GGML_CANN_CALL_ACLNN_OP(ctx, ReflectionPad1d, acl_src.get(), paddings.get(), acl_dst.get()); - } + acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0->data, ggml_cann_type_mapping(src0->type), + ggml_element_size(src0), src_ne_3d, src0->nb, 3); + + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ggml_cann_type_mapping(dst->type), + ggml_element_size(dst), dst_ne_3d, dst->nb, 3); + + GGML_CANN_CALL_ACLNN_OP(ctx, ReflectionPad1d, acl_src.get(), paddings.get(), acl_dst.get()); } void ggml_cann_count_equal(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src0 = dst->src[0]; ggml_tensor * src1 = dst->src[1]; + // Write element-wise equality (0 or 1) into a temporary buffer to avoid + // modifying src0 in-place. Use the same type as src0 so ReduceSum can + // consume it directly without a type cast. + ggml_cann_pool_alloc eq_alloc(ctx.pool(), ggml_nelements(src0) * ggml_element_size(src0)); + size_t eq_nb[GGML_MAX_DIMS]; + eq_nb[0] = ggml_element_size(src0); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + eq_nb[i] = eq_nb[i - 1] * src0->ne[i - 1]; + } + acl_tensor_ptr acl_eq = ggml_cann_create_tensor( + eq_alloc.get(), ggml_cann_type_mapping(src0->type), ggml_element_size(src0), + src0->ne, eq_nb, GGML_MAX_DIMS); + acl_tensor_ptr acl_self = ggml_cann_create_tensor(src0); acl_tensor_ptr acl_other = ggml_cann_create_tensor(src1); - - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceEqTensor, acl_self.get(), acl_other.get()); - - ggml_cann_sum(ctx, dst); + GGML_CANN_CALL_ACLNN_OP(ctx, EqTensor, acl_self.get(), acl_other.get(), acl_eq.get()); + + // Sum the 0/1 values into dst. + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + int64_t dims[4] = { 0, 1, 2, 3 }; + acl_int_array_ptr dims_arr = ggml_cann_create_int_array(dims, 4); + GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_eq.get(), dims_arr.get(), true, + ggml_cann_type_mapping(dst->type), acl_dst.get()); } void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst) { @@ -3217,6 +3529,27 @@ void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst) { GGML_CANN_CALL_ACLNN_OP(ctx, GtScalar, acl_src.get(), alpha.get(), acl_dst.get()); } +void ggml_cann_softplus(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; + + acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0); + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + + float beta_val = 1.0f; + float threshold_val = 20.0f; + acl_scalar_ptr beta = ggml_cann_create_scalar(&beta_val, ACL_FLOAT); + acl_scalar_ptr threshold = ggml_cann_create_scalar(&threshold_val, ACL_FLOAT); + + GGML_CANN_CALL_ACLNN_OP(ctx, Softplus, acl_src.get(), beta.get(), threshold.get(), acl_dst.get()); +} + +void ggml_cann_geglu_quick(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + auto gelu_quick_fn = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { + GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst); + }; + ggml_cann_op_unary_gated(gelu_quick_fn, ctx, dst); +} + /** * @brief Performs expert-specific matrix multiplication (MoE) with * floating-point precision using the CANN backend. @@ -3599,6 +3932,44 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst acl_k_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne, src1_bsnd_nb, GGML_MAX_DIMS); acl_v_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne, src2_bsnd_nb, GGML_MAX_DIMS); + // Step 2.5: Pad Q, K, V along head dimension if D is not a multiple of 16 + // (required by FusedInferAttentionScoreV2) + const int64_t D = src0->ne[0]; + const int64_t D_padded = GGML_PAD(D, 16); + const bool needs_padding = (D != D_padded); + + ggml_cann_pool_alloc q_pad_allocator(ctx.pool()); + ggml_cann_pool_alloc k_pad_allocator(ctx.pool()); + ggml_cann_pool_alloc v_pad_allocator(ctx.pool()); + + if (needs_padding) { + int64_t paddings[] = { 0, D_padded - D, 0, 0, 0, 0, 0, 0 }; + + auto pad_fa_tensor = [&](acl_tensor_ptr & tensor, const int64_t * bsnd_ne, + ggml_cann_pool_alloc & allocator) { + int64_t pad_ne[GGML_MAX_DIMS] = { D_padded, bsnd_ne[1], bsnd_ne[2], bsnd_ne[3] }; + size_t pad_nb[GGML_MAX_DIMS]; + pad_nb[0] = faElemSize; + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + pad_nb[i] = pad_nb[i - 1] * pad_ne[i - 1]; + } + int64_t nelements = pad_ne[0] * pad_ne[1] * pad_ne[2] * pad_ne[3]; + void * buffer = allocator.alloc(nelements * faElemSize); + acl_tensor_ptr padded = + ggml_cann_create_tensor(buffer, faDataType, faElemSize, pad_ne, pad_nb, GGML_MAX_DIMS); + aclnn_pad(ctx, tensor.get(), padded.get(), paddings); + tensor = std::move(padded); + }; + + pad_fa_tensor(acl_q_tensor, src0_bsnd_ne, q_pad_allocator); + pad_fa_tensor(acl_k_tensor, src1_bsnd_ne, k_pad_allocator); + pad_fa_tensor(acl_v_tensor, src2_bsnd_ne, v_pad_allocator); + + src0_bsnd_ne[0] = D_padded; + src1_bsnd_ne[0] = D_padded; + src2_bsnd_ne[0] = D_padded; + } + // Step 3: create the PSEShift tensor if needed // this tensor is considered as mask (f16) in the llama.cpp acl_tensor_ptr bcast_pse_tensor; @@ -3688,17 +4059,16 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); acl_tensor_ptr fa_dst_tensor; - acl_tensor_ptr acl_dst_tensor; ggml_cann_pool_alloc out_f16_allocator(ctx.pool()); - if (dst->type == GGML_TYPE_F32) { - void * out_f16_buffer = out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize); - + if (dst->type == GGML_TYPE_F32 || needs_padding) { int64_t * out_f16_ne = src0_bsnd_ne; size_t out_f16_nb[GGML_MAX_DIMS]; out_f16_nb[0] = faElemSize; for (int i = 1; i < GGML_MAX_DIMS; ++i) { out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1]; } + int64_t out_nelements = out_f16_ne[0] * out_f16_ne[1] * out_f16_ne[2] * out_f16_ne[3]; + void * out_f16_buffer = out_f16_allocator.alloc(out_nelements * faElemSize); fa_dst_tensor = ggml_cann_create_tensor(out_f16_buffer, faDataType, faElemSize, out_f16_ne, out_f16_nb, GGML_MAX_DIMS); @@ -3730,8 +4100,33 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst nullptr // softmaxLse ); - if (dst->type == GGML_TYPE_F32) { - // Step 6: post-processing, permute and cast to f32 + // Step 6: post-processing — slice padded output and/or cast to f32 + if (needs_padding) { + ggml_cann_pool_alloc sliced_f16_allocator(ctx.pool()); + + if (dst->type == GGML_TYPE_F32) { + int64_t sliced_ne[GGML_MAX_DIMS] = { D, src0_bsnd_ne[1], src0_bsnd_ne[2], src0_bsnd_ne[3] }; + size_t sliced_nb[GGML_MAX_DIMS]; + sliced_nb[0] = faElemSize; + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + sliced_nb[i] = sliced_nb[i - 1] * sliced_ne[i - 1]; + } + int64_t sliced_nelements = sliced_ne[0] * sliced_ne[1] * sliced_ne[2] * sliced_ne[3]; + void * sliced_buffer = sliced_f16_allocator.alloc(sliced_nelements * faElemSize); + acl_tensor_ptr sliced_f16_tensor = ggml_cann_create_tensor(sliced_buffer, faDataType, faElemSize, + sliced_ne, sliced_nb, GGML_MAX_DIMS); + + GGML_CANN_CALL_ACLNN_OP(ctx, Slice, fa_dst_tensor.get(), + (int64_t) -1, (int64_t) 0, D, (int64_t) 1, sliced_f16_tensor.get()); + + acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst); + aclnn_cast(ctx, sliced_f16_tensor.get(), acl_dst_tensor.get(), ggml_cann_type_mapping(dst->type)); + } else { + acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst); + GGML_CANN_CALL_ACLNN_OP(ctx, Slice, fa_dst_tensor.get(), + (int64_t) -1, (int64_t) 0, D, (int64_t) 1, acl_dst_tensor.get()); + } + } else if (dst->type == GGML_TYPE_F32) { acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst); aclnn_cast(ctx, fa_dst_tensor.get(), acl_dst_tensor.get(), ggml_cann_type_mapping(dst->type)); } @@ -3741,46 +4136,65 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst } static void ggml_cann_out_prod_fp(ggml_backend_cann_context & ctx, ggml_tensor * dst) { - ggml_tensor * src0 = dst->src[0]; // weight - ggml_tensor * src1 = dst->src[1]; // input + ggml_tensor * src0 = dst->src[0]; // weight [ne00=m, ne01=K, ne02, ne03] + ggml_tensor * src1 = dst->src[1]; // input [ne10=n, ne11=K, ne12, ne13] GGML_TENSOR_BINARY_OP_LOCALS - acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get()); + // dst[i,j] = sum_k src0[i,k] * src1[j,k] i.e. dst = src0 @ src1^T. + // + // ggml_cann_create_tensor reverses dimension order, so ACL sees: + // acl_src0 slice: ggml[m,K] -> ACL[K,m] + // acl_src1 slice: ggml[n,K] -> ACL[K,n] + // acl_dst slice: ggml[m,n] -> ACL[n,m] + // + // Build a transposed view of src1 by swapping ne[0]/ne[1]: + // src1_t: ggml[K,n] (swapped strides) -> ACL[n,K] + // + // Matmul(src1_t [n,K], src0 [K,m]) = [n,m] = acl_dst ✓ + // + // The outer batch loop is kept because src0 may have fewer batch slices than + // dst (ne02 <= ne2, ne03 <= ne3): this is a strided-broadcast not supported + // by standard CANN Matmul broadcasting. + + const aclDataType src0_acl_type = ggml_cann_type_mapping(src0->type); + const aclDataType src1_acl_type = ggml_cann_type_mapping(src1->type); + const aclDataType dst_acl_type = ggml_cann_type_mapping(dst->type); + const size_t src0_type_sz = ggml_type_size(src0->type); + const size_t src1_type_sz = ggml_type_size(src1->type); + const size_t dst_type_sz = ggml_type_size(dst->type); const int64_t dps2 = ne2 / ne02; const int64_t dps3 = ne3 / ne03; + for (int64_t i3 = 0; i3 < ne3; i3++) { for (int64_t i2 = 0; i2 < ne2; i2++) { const int64_t i02 = i2 / dps2; const int64_t i03 = i3 / dps3; - const int64_t i12 = i2; - const int64_t i13 = i3; - acl_tensor_ptr accumulator = - ggml_cann_create_tensor((char *) dst->data + i2 * nb2 + i3 * nb3, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), dst->ne, dst->nb, 2); - - // The outer product needs to be accumulated in this dimension. - for (int64_t i1 = 0; i1 < ne11; i1++) { - acl_tensor_ptr acl_input = ggml_cann_create_tensor( - (char *) src1->data + i1 * nb11 + i12 * nb12 + i13 * nb13, ggml_cann_type_mapping(src0->type), - ggml_type_size(src0->type), src1->ne, src1->nb, 1); - - acl_tensor_ptr acl_weight = ggml_cann_create_tensor( - (char *) src0->data + i1 * nb01 + i02 * nb02 + i03 * nb03, ggml_cann_type_mapping(src0->type), - ggml_type_size(src0->type), src0->ne, src0->nb, 1); - - ggml_cann_pool_alloc output_allocator(ctx.pool()); - void * output_buffer = output_allocator.alloc(ggml_nbytes(dst)); - acl_tensor_ptr acl_out = ggml_cann_create_tensor(output_buffer, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), dst->ne, dst->nb, 2); - - GGML_CANN_CALL_ACLNN_OP(ctx, Ger, acl_input.get(), acl_weight.get(), acl_out.get()); - float alpha_value = 1.0f; - aclScalar * alpha = aclCreateScalar(&alpha_value, ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, accumulator.get(), acl_out.get(), alpha); - } + // src0 2D slice at [i02, i03]: ggml [m, K] -> ACL [K, m] + int64_t src0_ne[2] = { ne00, ne01 }; + size_t src0_nb[2] = { nb00, nb01 }; + acl_tensor_ptr acl_src0_s = ggml_cann_create_tensor( + (char *) src0->data + i02 * nb02 + i03 * nb03, + src0_acl_type, src0_type_sz, src0_ne, src0_nb, 2); + + // src1 transposed 2D slice at [i2, i3]: swap ne/nb -> ggml[K,n] -> ACL[n,K] + int64_t src1_t_ne[2] = { ne11, ne10 }; + size_t src1_t_nb[2] = { nb11, nb10 }; + acl_tensor_ptr acl_src1_t = ggml_cann_create_tensor( + (char *) src1->data + i2 * nb12 + i3 * nb13, + src1_acl_type, src1_type_sz, src1_t_ne, src1_t_nb, 2); + + // dst 2D slice at [i2, i3]: ggml [m, n] -> ACL [n, m] + int64_t dst_ne[2] = { ne0, ne1 }; + size_t dst_nb[2] = { nb0, nb1 }; + acl_tensor_ptr acl_dst_s = ggml_cann_create_tensor( + (char *) dst->data + i2 * nb2 + i3 * nb3, + dst_acl_type, dst_type_sz, dst_ne, dst_nb, 2); + + // Matmul(src1_t [n,K], src0 [K,m]) = [n,m] = acl_dst_s ✓ + GGML_CANN_CALL_ACLNN_OP(ctx, Matmul, + acl_src1_t.get(), acl_src0_s.get(), acl_dst_s.get(), (int8_t) 1); } } } @@ -4019,3 +4433,4 @@ void ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor * } } } + diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index 3effa1c289c..cdbf9260f85 100644 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -32,6 +32,9 @@ #include #include #include +#include +#include +#include #include #include #include @@ -47,6 +50,9 @@ #include #include #include +#include +#include +#include #include #include #include @@ -69,6 +75,9 @@ */ void ggml_cann_repeat(ggml_backend_cann_context & ctx, ggml_tensor * dst); +void ggml_cann_swiglu(ggml_backend_cann_context & ctx, ggml_tensor * dst); +void ggml_cann_geglu(ggml_backend_cann_context & ctx, ggml_tensor * dst, int64_t approximate); + /** * @brief Applies the Leaky ReLU activation function to a tensor using the CANN * backend. @@ -325,6 +334,48 @@ void ggml_cann_sum_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst); void ggml_cann_sum(ggml_backend_cann_context & ctx, ggml_tensor * dst); +/** + * @brief Computes the cumulative sum of a ggml tensor along dim 0 using the + * CANN backend. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor. dst->op is `GGML_OP_CUMSUM`. + */ +void ggml_cann_cumsum(ggml_backend_cann_context & ctx, ggml_tensor * dst); + +/** + * @brief Computes a triangular mask (tril/triu) of a square ggml tensor + * using the CANN backend. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor. dst->op is `GGML_OP_TRI`. + */ +void ggml_cann_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst); + +/** + * @brief Solves a triangular linear system AX=B using the CANN backend. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor. dst->op is `GGML_OP_SOLVE_TRI`. + */ +void ggml_cann_solve_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst); + +/** + * @brief Creates a diagonal matrix from a vector using the CANN backend. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor. dst->op is `GGML_OP_DIAG`. + */ +void ggml_cann_diag(ggml_backend_cann_context & ctx, ggml_tensor * dst); + +/** + * @brief Fills a tensor with a constant scalar value using the CANN backend. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor. dst->op is `GGML_OP_FILL`. + */ +void ggml_cann_fill(ggml_backend_cann_context & ctx, ggml_tensor * dst); + /** * @brief Upsamples a ggml tensor using nearest neighbor interpolation using * the CANN backend. @@ -461,6 +512,9 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context & ctx, ggml_tensor * // @see ggml_cann_dup. void ggml_cann_cpy(ggml_backend_cann_context & ctx, ggml_tensor * dst); +// @see ggml_cann_acc, but copies src1 into dst instead of adding. +void ggml_cann_set(ggml_backend_cann_context & ctx, ggml_tensor * dst); + /** * @brief Computes the softmax activation with optional masking. * @@ -543,6 +597,21 @@ void ggml_cann_mul_mat(ggml_backend_cann_context & ctx, ggml_tensor * dst); */ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst); +/** + * @brief Pre-load the RoPE cache before ACL graph capture. + * + * This function must be called outside of graph capture to perform + * host-to-device memory copies and device memory allocations that are + * not allowed on a captured stream. After pre-loading, the rope cache + * metadata is updated so that the subsequent call to + * aclnn_rope_cache_init (inside graph capture) skips these operations + * and only records the on-device computations into the captured graph. + * + * @param ctx CANN backend context. + * @param dst A ROPE destination tensor from the computation graph. + */ +void ggml_cann_rope_cache_preload(ggml_backend_cann_context & ctx, ggml_tensor * dst); + /** * @brief Computes the index of the maximum value along the specified dimension * of a ggml tensor using the CANN backend. @@ -798,6 +867,8 @@ void ggml_cann_count_equal(ggml_backend_cann_context & ctx, ggml_tensor * dst); * dst->op is expected to be `GGML_OP_STEP`. */ void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst); +void ggml_cann_softplus(ggml_backend_cann_context & ctx, ggml_tensor * dst); +void ggml_cann_geglu_quick(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Performs the Flash Attention extended operator using the CANN backend. diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index 0120f0dfd1e..1c6e685c38c 100644 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -216,14 +216,16 @@ struct ggml_cann_pool_alloc { #ifdef USE_ACL_GRAPH struct ggml_graph_node_properties { // dst tensor - void * node_address; - int64_t ne[GGML_MAX_DIMS]; - size_t nb[GGML_MAX_DIMS]; + void * node_address; + ggml_type node_type; + int64_t ne[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS]; // src tensor - void * src_address[GGML_MAX_SRC]; - int64_t src_ne[GGML_MAX_SRC][GGML_MAX_DIMS]; - size_t src_nb[GGML_MAX_SRC][GGML_MAX_DIMS]; + void * src_address[GGML_MAX_SRC]; + ggml_type src_type[GGML_MAX_SRC]; + int64_t src_ne[GGML_MAX_SRC][GGML_MAX_DIMS]; + size_t src_nb[GGML_MAX_SRC][GGML_MAX_DIMS]; // op ggml_op node_op; @@ -247,6 +249,10 @@ struct ggml_graph_node_properties { return false; } + if (node->type != this->node_type) { + return false; + } + for (int i = 0; i < GGML_MAX_DIMS; i++) { if (node->ne[i] != this->ne[i]) { return false; @@ -262,6 +268,10 @@ struct ggml_graph_node_properties { return false; } + if (node->src[i]->type != this->src_type[i]) { + return false; + } + for (int d = 0; d < GGML_MAX_DIMS; d++) { if (node->src[i]->ne[d] != this->src_ne[i][d]) { return false; @@ -277,10 +287,7 @@ struct ggml_graph_node_properties { } } - if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) { - return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0; - } - return true; + return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0; } }; @@ -322,6 +329,7 @@ struct ggml_cann_graph { prop.node_address = node->data; prop.node_op = node->op; + prop.node_type = node->type; std::copy_n(node->ne, GGML_MAX_DIMS, prop.ne); std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb); @@ -329,10 +337,12 @@ struct ggml_cann_graph { for (int src = 0; src < GGML_MAX_SRC; ++src) { if (node->src[src]) { prop.src_address[src] = node->src[src]->data; + prop.src_type[src] = node->src[src]->type; std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]); std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]); } else { prop.src_address[src] = nullptr; + prop.src_type[src] = GGML_TYPE_COUNT; std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0); std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0); } diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 3f3de9f0bcb..5f51ea3bb3c 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -36,10 +36,13 @@ #include #include #include +#include #include #include #include +#include #include +#include #define GGML_COMMON_DECL_C @@ -770,6 +773,21 @@ std::unique_ptr ggml_backend_cann_context::new_pool_for_device(i } // cann buffer + +/** + * @brief Tracks multi-threaded write progress for a single tensor. + * + * When multiple threads call set_tensor on different chunks of the same tensor, + * this tracker accumulates progress and defers post-processing (quantized format + * transform or ND-to-NZ conversion) until all data has been written. + */ +struct TensorSetTracker { + std::mutex mtx; ///< Protects concurrent access to this tracker + size_t bytes_written = 0; ///< Accumulated bytes written so far + size_t total_bytes = 0; ///< Target size (full tensor) + std::vector host_buffer; ///< Host staging buffer for quantized tensors +}; + /** * @brief Context for managing a CANN buffer associated with a specific device. * @@ -780,6 +798,9 @@ struct ggml_backend_cann_buffer_context { int32_t device; ///< The device ID associated with this buffer context. void * dev_ptr = nullptr; ///< Pointer to the device memory allocated for the buffer. + std::mutex tracker_mutex; ///< Protects the trackers map + std::unordered_map> trackers; + /** * @brief Constructor to initialize the CANN buffer context. * @@ -792,6 +813,31 @@ struct ggml_backend_cann_buffer_context { * @brief Destructor to free the device memory allocated for the buffer. */ ~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); } + + /** + * @brief Get or create a tracker for the given tensor. + */ + TensorSetTracker * get_or_create_tracker(ggml_tensor * tensor) { + std::lock_guard lock(tracker_mutex); + auto key = tensor->data; + auto it = trackers.find(key); + if (it == trackers.end()) { + auto tracker = std::make_unique(); + tracker->total_bytes = ggml_nbytes(tensor); + auto * ptr = tracker.get(); + trackers[key] = std::move(tracker); + return ptr; + } + return it->second.get(); + } + + /** + * @brief Remove the tracker for the given tensor. + */ + void remove_tracker(ggml_tensor * tensor) { + std::lock_guard lock(tracker_mutex); + trackers.erase(tensor->data); + } }; // cann buffer type @@ -1124,6 +1170,7 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor(ggml_backend_buffer * designed to be used with a global array, one per device. */ struct ggml_cann_nz_workspace { + std::mutex mtx; // Protects ptr/allocated from concurrent access void * ptr; // Pointer to allocated device buffer size_t allocated; // Size of currently allocated buffer in bytes @@ -1190,13 +1237,15 @@ static ggml_cann_nz_workspace g_nz_workspaces[GGML_CANN_MAX_DEVICES]; * @note The workspace buffer used in this function is managed globally and reused * across calls. This reduces overhead from repeated memory allocation and deallocation. */ -static void weight_format_to_nz(ggml_tensor * tensor, size_t offset, int device) { - acl_tensor_ptr weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, offset); +static void weight_format_to_nz(ggml_tensor * tensor, int device) { + acl_tensor_ptr weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, 0); uint64_t workspaceSize = 0; aclOpExecutor * executor; // TransMatmulWeight ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed.get(), &workspaceSize, &executor)); + + std::lock_guard lock(g_nz_workspaces[device].mtx); // Avoid frequent malloc/free of the workspace. g_nz_workspaces[device].realloc(workspaceSize); @@ -1210,7 +1259,13 @@ static void weight_format_to_nz(ggml_tensor * tensor, size_t offset, int device) * @brief Set tensor data in a CANN buffer. * * This function sets tensor data in a CANN buffer, handling transformations - * if needed based on the tensor's type. + * if needed based on the tensor's type. It supports multi-threaded calls + * where different threads write different chunks of the same tensor. + * + * For quantized tensors (Q4_0/Q8_0), data is staged in a host buffer and + * the format transform is deferred until all chunks are written. + * For NZ weight tensors, chunks are uploaded directly but the ND-to-NZ + * conversion is deferred until all chunks are written. * * @param buffer The CANN buffer where the tensor data will be set. * @param tensor Pointer to the tensor whose data will be set. @@ -1226,25 +1281,72 @@ static void ggml_backend_cann_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context; ggml_cann_set_device(ctx->device); - // TODO: refer to cann(#6017), it use thread's default stream. - // For acl, synchronous functions use this default stream. - // Why aclrtSynchronizeDevice? // Only check env once. static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on")); - if (!need_transform(tensor->type)) { + + bool is_quantized = need_transform(tensor->type); + bool is_nz = !is_quantized && tensor->type != GGML_TYPE_BF16 && weight_to_nz && + is_matmul_weight((const ggml_tensor *) tensor); + + // Plain tensor (not quantized, not NZ): direct copy, no tracking needed + if (!is_quantized && !is_nz) { ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE)); - if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) { + return; + } + + // Single-shot write (full tensor at once): handle directly without tracking overhead + if (offset == 0 && size == ggml_nbytes(tensor)) { + if (is_quantized) { + void * transform_buffer = malloc(size); + ggml_backend_cann_transform(tensor, data, transform_buffer); + ACL_CHECK(aclrtMemcpy(tensor->data, size, transform_buffer, size, ACL_MEMCPY_HOST_TO_DEVICE)); + free(transform_buffer); + } else { + // NZ weight GGML_ASSERT(tensor->ne[2] == 1); GGML_ASSERT(tensor->ne[3] == 1); - weight_format_to_nz(tensor, offset, ctx->device); + ACL_CHECK(aclrtMemcpy(tensor->data, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE)); + weight_format_to_nz(tensor, ctx->device); + } + return; + } + + // Chunked write: use tracker to accumulate progress and defer transform/conversion + TensorSetTracker * tracker = ctx->get_or_create_tracker(tensor); + std::unique_lock lock(tracker->mtx); + + if (is_quantized) { + // Stage data in host buffer; transform requires full tensor data + if (tracker->host_buffer.empty()) { + tracker->host_buffer.resize(tracker->total_bytes); } + memcpy(tracker->host_buffer.data() + offset, data, size); } else { - void * transform_buffer = malloc(size); - ggml_backend_cann_transform(tensor, data, transform_buffer); + // NZ weight: upload chunk to device immediately, defer conversion + ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE)); + } - ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, transform_buffer, size, ACL_MEMCPY_HOST_TO_DEVICE)); - free(transform_buffer); + tracker->bytes_written += size; + + // All chunks received: perform deferred transform/conversion + if (tracker->bytes_written >= tracker->total_bytes) { + if (is_quantized) { + void * transform_buffer = malloc(tracker->total_bytes); + ggml_backend_cann_transform(tensor, tracker->host_buffer.data(), transform_buffer); + ACL_CHECK(aclrtMemcpy(tensor->data, tracker->total_bytes, transform_buffer, tracker->total_bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + free(transform_buffer); + } + + if (is_nz) { + GGML_ASSERT(tensor->ne[2] == 1); + GGML_ASSERT(tensor->ne[3] == 1); + weight_format_to_nz(tensor, ctx->device); + } + + // Unlock before removing tracker, as remove_tracker destroys the mutex + lock.unlock(); + ctx->remove_tracker(tensor); } } @@ -1326,6 +1428,22 @@ static bool ggml_backend_cann_buffer_cpy_tensor(ggml_backend_buffer_t buffer, return false; } +/** + * @brief Set a region of a tensor's device memory to a specified value. + * + * @param buffer The CANN buffer containing the tensor. + * @param tensor Pointer to the tensor whose memory will be set. + * @param value The value to which each byte in the region will be set. + * @param offset Byte offset within the tensor's data to start setting. + * @param size Number of bytes to set. + */ +static void ggml_backend_cann_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context; + + ggml_cann_set_device(ctx->device); + ACL_CHECK(aclrtMemset((char *) tensor->data + offset, size, value, size)); +} + /** * @brief Clear a CANN buffer by setting all its memory to a specified value. * @@ -1352,9 +1470,11 @@ static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = { /* .free_buffer = */ ggml_backend_cann_buffer_free_buffer, /* .get_base = */ ggml_backend_cann_buffer_get_base, /* .init_tensor = */ ggml_backend_cann_buffer_init_tensor, - /* .memset_tensor = */ NULL, + /* .memset_tensor = */ ggml_backend_cann_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_cann_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cann_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_cann_buffer_cpy_tensor, /* .clear = */ ggml_backend_cann_buffer_clear, /* .reset = */ NULL, @@ -1443,7 +1563,8 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size(ggml_backend_buffer_t if (ne0 % MATRIX_ROW_PADDING != 0) { size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); } - } else if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) { + } else if (weight_to_nz && tensor->type != GGML_TYPE_BF16 + && is_matmul_weight((const ggml_tensor *) tensor)) { // NZ format weight are not support quantized yet. // If ND tensor transform to NZ, size may changed. int64_t shape[] = { tensor->ne[1], tensor->ne[0] }; @@ -1730,6 +1851,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg case GGML_UNARY_OP_STEP: ggml_cann_step(ctx, dst); break; + case GGML_UNARY_OP_SOFTPLUS: + ggml_cann_softplus(ctx, dst); + break; default: return false; } @@ -1740,20 +1864,16 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg GGML_CANN_CALL_OP_UNARY_GATED(Relu); break; case GGML_GLU_OP_GEGLU: + ggml_cann_geglu(ctx, dst, 0); // approximate=0 → tanh + break; case GGML_GLU_OP_GEGLU_ERF: - // aclnnGelu internally uses the erf-based approximation. - GGML_CANN_CALL_OP_UNARY_GATED(Gelu); + ggml_cann_geglu(ctx, dst, 1); // approximate=1 → erf break; case GGML_GLU_OP_SWIGLU: - GGML_CANN_CALL_OP_UNARY_GATED(Silu); + ggml_cann_swiglu(ctx, dst); break; case GGML_GLU_OP_GEGLU_QUICK: - { - auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { - GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst); - }; - ggml_cann_op_unary_gated(lambda, ctx, dst); - } + ggml_cann_geglu_quick(ctx, dst); break; default: return false; @@ -1815,6 +1935,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg case GGML_OP_CPY: ggml_cann_cpy(ctx, dst); break; + case GGML_OP_SET: + ggml_cann_set(ctx, dst); + break; case GGML_OP_CONT: ggml_cann_dup(ctx, dst); break; @@ -1884,6 +2007,21 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg case GGML_OP_SSM_CONV: ggml_cann_ssm_conv(ctx, dst); break; + case GGML_OP_CUMSUM: + ggml_cann_cumsum(ctx, dst); + break; + case GGML_OP_TRI: + ggml_cann_tri(ctx, dst); + break; + case GGML_OP_FILL: + ggml_cann_fill(ctx, dst); + break; + case GGML_OP_DIAG: + ggml_cann_diag(ctx, dst); + break; + case GGML_OP_SOLVE_TRI: + ggml_cann_solve_tri(ctx, dst); + break; default: return false; } @@ -2219,10 +2357,24 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, if (use_cann_graph) { // If no matching graph is found, the graph needs to be recaptured. graph_capture_required = !cann_ctx->graph_lru_cache.find_and_move_to_front(cgraph); + if (graph_capture_required) { // If no matching graph is found, add a new ACL graph. ggml_cann_graph * new_graph = ggml_cann_graph::create_from_cgraph(cgraph); cann_ctx->graph_lru_cache.push(new_graph); + + // Pre-load rope cache before graph capture. During capture the + // stream cannot perform host-to-device memcpy or device memory + // malloc/free. Running the full cache init now populates the + // cache metadata so these branches are skipped during capture, + // while also warming up the memory pool. + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + if (node->op == GGML_OP_ROPE) { + ggml_cann_rope_cache_preload(*cann_ctx, node); + break; + } + } } } #else @@ -2264,6 +2416,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten case GGML_UNARY_OP_SGN: case GGML_UNARY_OP_STEP: case GGML_UNARY_OP_GELU_ERF: + case GGML_UNARY_OP_SOFTPLUS: return true; default: return false; @@ -2283,6 +2436,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten case GGML_OP_MUL_MAT: { switch (op->src[0]->type) { +#ifndef ASCEND_310P + case GGML_TYPE_BF16: +#endif case GGML_TYPE_F16: case GGML_TYPE_F32: return true; @@ -2320,6 +2476,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten switch (op->src[0]->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: +#ifndef ASCEND_310P + case GGML_TYPE_BF16: +#endif case GGML_TYPE_Q8_0: return true; default: @@ -2332,6 +2491,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten switch (op->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: +#ifndef ASCEND_310P + case GGML_TYPE_BF16: +#endif return true; default: return false; @@ -2341,20 +2503,30 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten case GGML_OP_CPY: { ggml_tensor * src = op->src[0]; +#ifdef ASCEND_310P if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) || (src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_F16)) { - // only support F32 and F16. + // only support F32 and F16 on 310P. + return false; + } +#else + if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_BF16) || + (src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_F16 && src->type != GGML_TYPE_BF16)) { + // only support F32, F16 and BF16. return false; } +#endif return true; } break; case GGML_OP_CONT: { - // TODO: support GGML_TYPE_BF16 switch (op->src[0]->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: +#ifndef ASCEND_310P + case GGML_TYPE_BF16: +#endif return true; default: return false; @@ -2435,6 +2607,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten case GGML_OP_SUM_ROWS: case GGML_OP_ARGSORT: case GGML_OP_ACC: + case GGML_OP_SET: case GGML_OP_GROUP_NORM: return true; case GGML_OP_PAD: @@ -2503,10 +2676,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten // different head sizes of K and V are not supported yet return false; } - if (op->src[0]->ne[0] % 16 != 0) { - // TODO: padding to support - return false; - } float logitSoftcap = 0.0f; memcpy(&logitSoftcap, (const float *) (op->op_params) + 2, sizeof(float)); if (logitSoftcap != 0.0f) { @@ -2516,6 +2685,16 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten } case GGML_OP_SSM_CONV: return true; + case GGML_OP_CUMSUM: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_TRI: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_FILL: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_DIAG: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_SOLVE_TRI: + return op->src[0]->type == GGML_TYPE_F32; default: return false; } @@ -2567,6 +2746,8 @@ static const ggml_backend_i ggml_backend_cann_interface = { /* .free = */ ggml_backend_cann_free, /* .set_tensor_async = */ ggml_backend_cann_set_tensor_async, /* .get_tensor_async = */ ggml_backend_cann_get_tensor_async, + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ ggml_backend_cann_cpy_tensor_async, /* .synchronize = */ ggml_backend_cann_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 92cf739e7a7..f05683b44cd 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -93,6 +93,10 @@ typedef sycl::half2 ggml_half2; // QR = QK / number of values before dequantization // QI = number of 32 bit integers before dequantization +#define QI1_0 (QK1_0 / 32) +#define QR1_0 1 + + #define QI4_0 (QK4_0 / (4 * QR4_0)) #define QR4_0 2 @@ -170,6 +174,13 @@ typedef sycl::half2 ggml_half2; #define GGML_EXTENSION __extension__ #endif // _MSC_VER +#define QK1_0 128 +typedef struct { + ggml_half d; // delta + uint8_t qs[QK1_0 / 8]; // bits / quants +} block_q1_0; +static_assert(sizeof(block_q1_0) == sizeof(ggml_half) + QK1_0 / 8, "wrong q1_0 block size/padding"); + #define QK4_0 32 typedef struct { ggml_half d; // delta diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 6ca3176a2f2..c1c225f0197 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -460,6 +460,10 @@ function(ggml_add_cpu_backend_variant_impl tag_name) endif() if(NOT GGML_CPU_ALL_VARIANTS) set(MARCH_STR "rv64gc") + if (GGML_RVV) + string(APPEND MARCH_STR "v") + endif() + if (GGML_RV_ZFH) string(APPEND MARCH_STR "_zfh") endif() @@ -467,7 +471,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (GGML_XTHEADVECTOR) string(APPEND MARCH_STR "_xtheadvector") elseif (GGML_RVV) - string(APPEND MARCH_STR "_v") if (GGML_RV_ZVFH) string(APPEND MARCH_STR "_zvfh") endif() @@ -475,12 +478,21 @@ function(ggml_add_cpu_backend_variant_impl tag_name) string(APPEND MARCH_STR "_zvfbfwma") endif() endif() + if (GGML_RV_ZICBOP) string(APPEND MARCH_STR "_zicbop") endif() if (GGML_RV_ZIHINTPAUSE) string(APPEND MARCH_STR "_zihintpause") endif() + if (GGML_CPU_RISCV64_SPACEMIT) + # `xsmtvdotii' is only required for GCC >= 15. + if (CMAKE_C_COMPILER_ID STREQUAL "GNU" AND + CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 15) + string(APPEND MARCH_STR "_xsmtvdotii") + endif() + endif() + list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d) else() # Begin with the lowest baseline @@ -570,24 +582,36 @@ function(ggml_add_cpu_backend_variant_impl tag_name) set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz") set(KLEIDIAI_ARCHIVE_MD5 "54049037570ab0ee0a0d126b2ba5ece1") - if (POLICY CMP0135) - cmake_policy(SET CMP0135 NEW) - endif() - - # TODO: Use FetchContent_MakeAvailable with EXCLUDE_FROM_ALL after bumping minimum CMake version to 3.28+ - # Using FetchContent_Populate instead to avoid EXCLUDE_FROM_ALL which requires CMake 3.28 - FetchContent_Declare(KleidiAI_Download + set(KLEIDIAI_FETCH_ARGS URL ${KLEIDIAI_DOWNLOAD_URL} - DOWNLOAD_EXTRACT_TIMESTAMP NEW - URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5}) + URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5} + ) + if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24") + list(APPEND KLEIDIAI_FETCH_ARGS DOWNLOAD_EXTRACT_TIMESTAMP NEW) + endif() - FetchContent_GetProperties(KleidiAI_Download - SOURCE_DIR KLEIDIAI_SRC - POPULATED KLEIDIAI_POPULATED) + if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.28") + FetchContent_Declare(KleidiAI_Download + ${KLEIDIAI_FETCH_ARGS} + EXCLUDE_FROM_ALL + ) - if (NOT KLEIDIAI_POPULATED) - FetchContent_Populate(KleidiAI_Download) + FetchContent_MakeAvailable(KleidiAI_Download) FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC) + else() + FetchContent_Declare(KleidiAI_Download + ${KLEIDIAI_FETCH_ARGS} + ) + + FetchContent_GetProperties(KleidiAI_Download + SOURCE_DIR KLEIDIAI_SRC + POPULATED KLEIDIAI_POPULATED + ) + + if (NOT KLEIDIAI_POPULATED) + FetchContent_Populate(KleidiAI_Download) + FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC) + endif() endif() add_compile_definitions(GGML_USE_CPU_KLEIDIAI) diff --git a/ggml/src/ggml-cpu/amx/amx.cpp b/ggml/src/ggml-cpu/amx/amx.cpp index 9baf3e025e6..1118f7169c9 100644 --- a/ggml/src/ggml-cpu/amx/amx.cpp +++ b/ggml/src/ggml-cpu/amx/amx.cpp @@ -111,6 +111,8 @@ static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = { /* .memset_tensor = */ ggml_backend_amx_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_amx_buffer_set_tensor, /* .get_tensor = */ nullptr, + /* .set_tensor_2d = */ nullptr, + /* .get_tensor_2d = */ nullptr, /* .cpy_tensor = */ nullptr, /* .clear = */ ggml_backend_amx_buffer_clear, /* .reset = */ nullptr, diff --git a/ggml/src/ggml-cpu/amx/mmq.cpp b/ggml/src/ggml-cpu/amx/mmq.cpp index 93a6d397f79..d9383a04be8 100644 --- a/ggml/src/ggml-cpu/amx/mmq.cpp +++ b/ggml/src/ggml-cpu/amx/mmq.cpp @@ -2005,12 +2005,12 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v const int lda = KB * sizeof(TA); //const int ldb = KB * sizeof(TB); - static thread_local packed_B_t Tile0[TILE_N * TILE_K]; - static thread_local packed_B_t Tile1[TILE_N * TILE_K]; - static thread_local int8_t Tile23[TILE_M * TILE_K]; + alignas(64) static thread_local packed_B_t Tile0[TILE_N * TILE_K]; + alignas(64) static thread_local packed_B_t Tile1[TILE_N * TILE_K]; + alignas(64) static thread_local int8_t Tile23[TILE_M * TILE_K]; - static thread_local int32_t TileC0[TILE_M * TILE_N * 4]; - static thread_local int32_t TileC1[TILE_M * TILE_N * 4]; + alignas(64) static thread_local int32_t TileC0[TILE_M * TILE_N * 4]; + alignas(64) static thread_local int32_t TileC1[TILE_M * TILE_N * 4]; // double buffering C to interleave avx512 and amx int32_t * C_cur = TileC0; @@ -2187,21 +2187,21 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v const int m1 = std::max(M - TILE_M, 0); //const int lda = KB * sizeof(TA); - static thread_local int8_t Tile0[TILE_N * TILE_K]; - static thread_local int8_t Tile1[TILE_N * TILE_K]; - static thread_local int8_t Tile23[TILE_M * TILE_K]; + alignas(64) static thread_local int8_t Tile0[TILE_N * TILE_K]; + alignas(64) static thread_local int8_t Tile1[TILE_N * TILE_K]; + alignas(64) static thread_local int8_t Tile23[TILE_M * TILE_K]; // mat mul result for each group - static thread_local int32_t Tile4[TILE_M * TILE_N]; - static thread_local int32_t Tile5[TILE_M * TILE_N]; - static thread_local int32_t Tile6[TILE_M * TILE_N]; - static thread_local int32_t Tile7[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Tile4[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Tile5[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Tile6[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Tile7[TILE_M * TILE_N]; // sum of each QK_K block, contains 8 groups, int32 - static thread_local int32_t Sumi4[TILE_M * TILE_N]; - static thread_local int32_t Sumi5[TILE_M * TILE_N]; - static thread_local int32_t Sumi6[TILE_M * TILE_N]; - static thread_local int32_t Sumi7[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Sumi4[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Sumi5[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Sumi6[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Sumi7[TILE_M * TILE_N]; const int k_group_size = std::is_same::value ? 16 : 32; for (int i = 0; i < KB; ++i) { diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 41da829315b..595ded09f03 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -16,6 +16,7 @@ #define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0 #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K @@ -112,6 +113,7 @@ // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K @@ -160,6 +162,7 @@ #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 @@ -200,6 +203,7 @@ #elif defined(__riscv) // quants.c #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 @@ -240,6 +244,7 @@ // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K @@ -303,6 +308,7 @@ #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c index 82b048bb3ae..fe621332970 100644 --- a/ggml/src/ggml-cpu/arch/arm/quants.c +++ b/ggml/src/ggml-cpu/arch/arm/quants.c @@ -137,6 +137,89 @@ void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in //===================================== Dot products ================================= +void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK1_0; // 128 + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q1_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; i++) { + const float d0 = GGML_CPU_FP16_TO_FP32(x[i].d); + + // Process 4 Q8_0 blocks (each has 32 elements) + for (int k = 0; k < 4; k++) { + const block_q8_0 * GGML_RESTRICT yb = &y[i * 4 + k]; + const float d1 = GGML_CPU_FP16_TO_FP32(yb->d); + + // Get the 4 bytes of bits for this Q8_0 block (32 bits = 4 bytes) + // Bits are at offset k*4 bytes in x[i].qs + const uint8_t * bits = &x[i].qs[k * 4]; + + // Load 32 int8 values from y + const int8x16_t y0 = vld1q_s8(yb->qs); + const int8x16_t y1 = vld1q_s8(yb->qs + 16); + + // Byte 0-1: bits for y0[0..15] + const uint64_t expand0 = table_b2b_0[bits[0]]; + const uint64_t expand1 = table_b2b_0[bits[1]]; + // Byte 2-3: bits for y1[0..15] + const uint64_t expand2 = table_b2b_0[bits[2]]; + const uint64_t expand3 = table_b2b_0[bits[3]]; + + // Build the sign vectors by reinterpreting the table values + uint8x8_t e0 = vcreate_u8(expand0); + uint8x8_t e1 = vcreate_u8(expand1); + uint8x8_t e2 = vcreate_u8(expand2); + uint8x8_t e3 = vcreate_u8(expand3); + + // Shift right by 4 to get 0 or 1 + int8x8_t s0 = vreinterpret_s8_u8(vshr_n_u8(e0, 4)); + int8x8_t s1 = vreinterpret_s8_u8(vshr_n_u8(e1, 4)); + int8x8_t s2 = vreinterpret_s8_u8(vshr_n_u8(e2, 4)); + int8x8_t s3 = vreinterpret_s8_u8(vshr_n_u8(e3, 4)); + + // Convert 0/1 to -1/+1: sign = 2*val - 1 + int8x8_t one = vdup_n_s8(1); + s0 = vsub_s8(vadd_s8(s0, s0), one); // 2*s0 - 1 + s1 = vsub_s8(vadd_s8(s1, s1), one); + s2 = vsub_s8(vadd_s8(s2, s2), one); + s3 = vsub_s8(vadd_s8(s3, s3), one); + + // Combine into 16-element vectors + int8x16_t signs0 = vcombine_s8(s0, s1); + int8x16_t signs1 = vcombine_s8(s2, s3); + + // Multiply signs with y values and accumulate + // dot(signs, y) where signs are +1/-1 + int32x4_t p0 = ggml_vdotq_s32(vdupq_n_s32(0), signs0, y0); + int32x4_t p1 = ggml_vdotq_s32(p0, signs1, y1); + + // Scale by d1 and accumulate + sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(p1), d0 * d1); + } + } + + *s = vaddvq_f32(sumv); +#else + UNUSED(nb); + UNUSED(x); + UNUSED(y); + ggml_vec_dot_q1_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + + void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; @@ -680,6 +763,7 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo const int8x16_t q4_lo_1 = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits_1, m4b)); const int8x16_t q4_hi_1 = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits_1, 4)); +#if defined(__ARM_FEATURE_DOTPROD) const int8x16_t q8_0a = vld1q_s8(y[2*ib].qs); const int8x16_t q8_0b = vld1q_s8(y[2*ib].qs + 16); const int8x16_t q8_lo_0 = vcombine_s8(vget_low_s8(q8_0a), vget_low_s8(q8_0b)); @@ -691,15 +775,40 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo const int8x16_t q8_hi_1 = vcombine_s8(vget_high_s8(q8_1a), vget_high_s8(q8_1b)); const int32x4_t p0 = vaddq_s32( - ggml_vdotq_s32(vdupq_n_s32(0), q4_lo_0, q8_lo_0), - ggml_vdotq_s32(vdupq_n_s32(0), q4_hi_0, q8_hi_0)); + vdotq_s32(vdupq_n_s32(0), q4_lo_0, q8_lo_0), + vdotq_s32(vdupq_n_s32(0), q4_hi_0, q8_hi_0)); const int32x4_t p1 = vaddq_s32( - ggml_vdotq_s32(vdupq_n_s32(0), q4_lo_1, q8_lo_1), - ggml_vdotq_s32(vdupq_n_s32(0), q4_hi_1, q8_hi_1)); + vdotq_s32(vdupq_n_s32(0), q4_lo_1, q8_lo_1), + vdotq_s32(vdupq_n_s32(0), q4_hi_1, q8_hi_1)); - const int32x4_t sums = vpaddq_s32(p0, p1); + const int32x4_t sumi = vpaddq_s32(p0, p1); +#else + const int8x8_t q4_0_lo = vget_low_s8(q4_lo_0); + const int8x8_t q4_0_hi = vget_low_s8(q4_hi_0); + const int8x8_t q4_1_lo = vget_high_s8(q4_lo_0); + const int8x8_t q4_1_hi = vget_high_s8(q4_hi_0); + const int8x8_t q4_2_lo = vget_low_s8(q4_lo_1); + const int8x8_t q4_2_hi = vget_low_s8(q4_hi_1); + const int8x8_t q4_3_lo = vget_high_s8(q4_lo_1); + const int8x8_t q4_3_hi = vget_high_s8(q4_hi_1); + + const int8x8_t q8_0_lo = vld1_s8(y[2*ib].qs); + const int8x8_t q8_0_hi = vld1_s8(y[2*ib].qs + 8); + const int8x8_t q8_1_lo = vld1_s8(y[2*ib].qs + 16); + const int8x8_t q8_1_hi = vld1_s8(y[2*ib].qs + 24); + const int8x8_t q8_2_lo = vld1_s8(y[2*ib+1].qs); + const int8x8_t q8_2_hi = vld1_s8(y[2*ib+1].qs + 8); + const int8x8_t q8_3_lo = vld1_s8(y[2*ib+1].qs + 16); + const int8x8_t q8_3_hi = vld1_s8(y[2*ib+1].qs + 24); + + const int32x4_t sumi = (int32x4_t){ + vaddvq_s32(ggml_nvfp4_dot8(q4_0_lo, q8_0_lo, q4_0_hi, q8_0_hi)), + vaddvq_s32(ggml_nvfp4_dot8(q4_1_lo, q8_1_lo, q4_1_hi, q8_1_hi)), + vaddvq_s32(ggml_nvfp4_dot8(q4_2_lo, q8_2_lo, q4_2_hi, q8_2_hi)), + vaddvq_s32(ggml_nvfp4_dot8(q4_3_lo, q8_3_lo, q4_3_hi, q8_3_hi)), + }; +#endif - // Decode 4 UE4M3 scales to f32 and multiply with q8 scales const float dy0 = GGML_CPU_FP16_TO_FP32(y[2*ib].d); const float dy1 = GGML_CPU_FP16_TO_FP32(y[2*ib+1].d); const float32x4_t nvsc = { @@ -710,7 +819,7 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo }; const float32x4_t scales = vmulq_f32(nvsc, (float32x4_t){dy0, dy0, dy1, dy1}); - acc = vfmaq_f32(acc, vcvtq_f32_s32(sums), scales); + acc = vfmaq_f32(acc, vcvtq_f32_s32(sumi), scales); } sumf = vaddvq_f32(acc); #else diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index 80ff5ce549b..a7534443091 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -5023,6 +5023,71 @@ void ggml_gemm_q8_0_4x8_q8_0(int n, UNUSED(ncols_interleaved); UNUSED(blocklen); +#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) + if (svcntb() * 8 == 256) { + const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx; + + static const uint32_t idx_arr[8] = {0, 1, 4, 5, 2, 3, 6, 7}; + svuint32_t idx = svld1(svptrue_b32(), idx_arr); + static const uint32_t idx_arr1[8] = {0, 1, 2, 3, 1, 2, 3, 0}; + svuint32_t idx_sc1 = svld1(svptrue_b32(), idx_arr1); + static const uint32_t idx_arr2[8] = {0, 1, 2, 3, 0, 1, 2, 3}; + svuint32_t idx_sc2 = svld1(svptrue_b32(), idx_arr2); + + for (int y = 0; y < nr; y += 4) { + const block_q8_0x4 * a_ptr_base = (const block_q8_0x4 *) vy + (y / 4) * nb; + + for (int x = 0; x < nc; x += ncols_interleaved) { + const block_q8_0x4 * b_ptr = b_ptr_base + (x / 4) * nb; + const block_q8_0x4 * a_ptr = a_ptr_base; + + svfloat32_t acc_f32_01 = svdup_f32(0); + svfloat32_t acc_f32_23 = svdup_f32(0); + + for (int b = 0; b < nb; b++) { + + svint32_t acc_01 = svdup_s32(0); + svint32_t acc_23 = svdup_s32(0); + + // Process 4 chunks of 8 positions each + for (int chunk = 0; chunk < 4; chunk++) { + svint8_t s_a01 = svld1rq_s8(svptrue_b8(), a_ptr->qs + chunk * 32); + svint8_t s_a23 = svld1rq_s8(svptrue_b8(), a_ptr->qs + chunk * 32 + 16); + svint8_t s_b0123 = svld1_s8(svptrue_b8(), b_ptr->qs + chunk * 32); + + acc_01 = svmmla_s32(acc_01, s_a01, s_b0123); + acc_23 = svmmla_s32(acc_23, s_a23, s_b0123); + } + + // Reorder outputs from 2×2 tiles to row-major + // acc[01] = [r0c0, r0c1, r1c0, r1c1, r0c2, r0c3, r1c2, r1c3] + // acc[23] = [r2c0, r2c1, r3c0, r3c1, r2c2, r2c3, r3c2, r3c3] + + svint32_t row01 = svtbl_s32(acc_01, idx); + svint32_t row23 = svtbl_s32(acc_23, idx); + + svfloat16_t temp1 = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *) a_ptr->d); + svfloat16_t temp2 = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *) b_ptr->d); + svfloat32_t sv_a_d = svtbl_f32(svcvt_f32_f16_x(svptrue_b32(), svzip1_f16(temp1, temp1)), idx_sc1); + svfloat32_t sv_b_d = svtbl_f32(svcvt_f32_f16_x(svptrue_b32(), svzip1_f16(temp2, temp2)), idx_sc2); + + acc_f32_01 = svmla_f32_x(svptrue_b32(), acc_f32_01, svcvt_f32_s32_x(svptrue_b32(), row01), svmul_lane_f32(sv_b_d, sv_a_d, 0)); + acc_f32_23 = svmla_f32_x(svptrue_b32(), acc_f32_23, svcvt_f32_s32_x(svptrue_b32(), row23), svmul_lane_f32(sv_b_d, sv_a_d, 2)); + a_ptr++; + b_ptr++; + } + + svbool_t pg4 = svptrue_pat_b32(SV_VL4); + svst1_f32(pg4, s + (y+0) * bs + x, acc_f32_01); + svst1_f32(pg4, s + (y+1) * bs + x, svext_f32(acc_f32_01, acc_f32_01, 4)); + svst1_f32(pg4, s + (y+2) * bs + x, acc_f32_23); + svst1_f32(pg4, s + (y+3) * bs + x, svext_f32(acc_f32_23, acc_f32_23, 4)); + } + } + return; + } +#endif // SVE compile-time end + #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx; diff --git a/ggml/src/ggml-cpu/arch/loongarch/quants.c b/ggml/src/ggml-cpu/arch/loongarch/quants.c index f531e916b9e..74e0c086c6d 100644 --- a/ggml/src/ggml-cpu/arch/loongarch/quants.c +++ b/ggml/src/ggml-cpu/arch/loongarch/quants.c @@ -2156,4 +2156,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } - diff --git a/ggml/src/ggml-cpu/arch/powerpc/quants.c b/ggml/src/ggml-cpu/arch/powerpc/quants.c index d3dfd049eaf..644c380c738 100644 --- a/ggml/src/ggml-cpu/arch/powerpc/quants.c +++ b/ggml/src/ggml-cpu/arch/powerpc/quants.c @@ -2302,4 +2302,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } - diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c index 826055dd9a4..d3278d6489f 100644 --- a/ggml/src/ggml-cpu/arch/riscv/quants.c +++ b/ggml/src/ggml-cpu/arch/riscv/quants.c @@ -15,6 +15,12 @@ #include // for qsort #include // for GGML_ASSERT +#ifdef _MSC_VER +#define NOINLINE __declspec(noinline) +#else +#define NOINLINE __attribute__((__noinline__)) +#endif + #define GROUP_MAX_EPS 1e-15f #define GROUP_MAX_EPS_IQ3_XXS 1e-8f #define GROUP_MAX_EPS_IQ2_S 1e-8f @@ -115,10 +121,10 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { assert(k % QK_K == 0); - block_q8_K * y_blocks = (block_q8_K *)y; size_t nb = k / QK_K; -#if defined(__riscv_v_intrinsic) +#if defined __riscv_v_intrinsic + block_q8_K * y_blocks = (block_q8_K *)y; const size_t vlmax_f32m8 = __riscv_vsetvlmax_e32m8(); for (size_t i = 0; i < nb; i++) { @@ -2052,7 +2058,120 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi #endif } -static void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic +static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + // Load qh once for the entire superblock. + vuint16m1_t qh = __riscv_vle16_v_u16m1(x[i].qh, 8); + + // Calculate ls. + vuint16m1_t temp = __riscv_vsrl_vx_u16m1(qh, 12, 8); + temp = __riscv_vand_vx_u16m1(temp, 7, 8); + vint32m2_t ls = __riscv_vreinterpret_v_u32m2_i32m2(__riscv_vwmulu_vx_u32m2(temp, 2, 8)); + ls = __riscv_vadd_vx_i32m2(ls, 1, 8); + + // Calculate delta. + vbool16_t mask = __riscv_vmseq_vx_u16m1_b16(__riscv_vand_vx_u16m1(qh, 0x8000, 8), 0, 8); + vint32m2_t delta_neg = __riscv_vmv_v_x_i32m2(-1, 8); + vint32m2_t delta_pos = __riscv_vmv_v_x_i32m2(1, 8); + vint32m2_t delta = __riscv_vmerge_vvm_i32m2(delta_neg, delta_pos, mask, 8); + + // Load qs. + vuint8m2_t qs = __riscv_vle8_v_u8m2(x[i].qs, 32); + + // Prepare the indices. + const uint64_t shift = 0x0009000600030000; + vuint16m4_t qh_shift = __riscv_vreinterpret_v_u64m4_u16m4(__riscv_vmv_v_x_u64m4(shift, 8)); + vuint16m4_t qh_gather_index = __riscv_vreinterpret_v_i16m4_u16m4( + __riscv_vdiv_vx_i16m4(__riscv_vreinterpret_v_u16m4_i16m4(__riscv_vid_v_u16m4(32)), 4, 32)); + vuint16m4_t qh_ext = __riscv_vlmul_ext_v_u16m2_u16m4(__riscv_vlmul_ext_v_u16m1_u16m2(qh)); + vuint16m4_t qh_index = __riscv_vrgather_vv_u16m4(qh_ext, qh_gather_index, 32); + qh_index = __riscv_vsrl_vv_u16m4(qh_index, qh_shift, 32); + qh_index = __riscv_vand_vx_u16m4(qh_index, 7, 32); + qh_index = __riscv_vsll_vx_u16m4(qh_index, 8, 32); + qh_index = __riscv_vor_vv_u16m4(qh_index, __riscv_vzext_vf2_u16m4(qs, 32), 32); + vuint16m4_t index = __riscv_vsll_vx_u16m4(qh_index, 3, 32); + + // Final lsums. + int32_t lsums_s[8]; + vint32m1_t one_scalar = __riscv_vmv_v_x_i32m1(0, 1); + + // Sub-blocks 1-2 + { + vuint16m1_t grid_index0 = __riscv_vget_v_u16m4_u16m1(index, 0); + vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index0, 8)); + vint8m4_t q80 = __riscv_vle8_v_i8m4(&y[i].qs[0], 64); + vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 128); + lsums_s[0] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 0), one_scalar, 32)); + lsums_s[1] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 1), one_scalar, 32)); + } + __asm__ __volatile__("" ::: "memory"); + // Sub-blocks 3-4 + { + vuint16m1_t grid_index0 = __riscv_vget_v_u16m4_u16m1(index, 1); + vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index0, 8)); + vint8m4_t q80 = __riscv_vle8_v_i8m4(&y[i].qs[64], 64); + vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 128); + lsums_s[2] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 0), one_scalar, 32)); + lsums_s[3] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 1), one_scalar, 32)); + } + __asm__ __volatile__("" ::: "memory"); + // Sub-blocks 5-6 + { + vuint16m1_t grid_index0 = __riscv_vget_v_u16m4_u16m1(index, 2); + vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index0, 8)); + vint8m4_t q80 = __riscv_vle8_v_i8m4(&y[i].qs[128], 64); + vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 128); + lsums_s[4] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 0), one_scalar, 32)); + lsums_s[5] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 1), one_scalar, 32)); + } + __asm__ __volatile__("" ::: "memory"); + // Sub-blocks 7-8 + { + vuint16m1_t grid_index0 = __riscv_vget_v_u16m4_u16m1(index, 3); + vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index0, 8)); + vint8m4_t q80 = __riscv_vle8_v_i8m4(&y[i].qs[192], 64); + vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 128); + lsums_s[6] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 0), one_scalar, 32)); + lsums_s[7] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 1), one_scalar, 32)); + } + __asm__ __volatile__("" ::: "memory"); + vint32m2_t lsums = __riscv_vle32_v_i32m2(&lsums_s[0], 8); + + // Calculate the bsums. + vint16m2_t bsums_0 = __riscv_vle16_v_i16m2(y[i].bsums, 16); + const vuint32m2_t bsums_i32 = __riscv_vreinterpret_v_u16m2_u32m2(__riscv_vreinterpret_v_i16m2_u16m2(bsums_0)); + const vint16m1_t bsums_i32_0 = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vnsrl_wx_u16m1(bsums_i32, 0, 8)); + const vint16m1_t bsums_i32_1 = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vnsrl_wx_u16m1(bsums_i32, 16, 8)); + const vint32m2_t bsums = __riscv_vwadd_vv_i32m2(bsums_i32_0, bsums_i32_1, 8); + + // Accumulation. + vint32m2_t sumi_v = __riscv_vmul_vv_i32m2(ls, lsums, 8); + vint32m2_t sumi1_v = __riscv_vmul_vv_i32m2(__riscv_vmul_vv_i32m2(ls, delta, 8), bsums, 8); + + // Update sumf. + int sumi = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(sumi_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(sumi1_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -2147,10 +2266,14 @@ static void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t *s = sumf; } +#endif void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq1_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; case 256: ggml_vec_dot_iq1_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; @@ -2163,7 +2286,175 @@ void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic +static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_m * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + iq1m_scale_t scale; + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + // Accumulators. + vint32m4_t acc1 = __riscv_vmv_v_x_i32m4(0, 16); + vint32m4_t acc2 = __riscv_vmv_v_x_i32m4(0, 16); + + // We process 8 16-element sub-blocks together. + #pragma GCC unroll 1 + for (int ib = 0; ib < QK_K/128; ib++) { + // Load qh for 8 sub-blocks. + const vuint8mf2_t qh_8 = __riscv_vle8_v_u8mf2(qh, 8); + const vuint16m1_t qh_16_lo = __riscv_vzext_vf2_u16m1(qh_8, 8); + const vuint16m1_t qh_16_hi = __riscv_vsll_vx_u16m1(qh_16_lo, 8, 8); + const vuint16m2_t qhb = __riscv_vzext_vf2_u16m2( + __riscv_vreinterpret_v_u16m1_u8m1(__riscv_vor_vv_u16m1(qh_16_lo, qh_16_hi, 8)), 16); + qh += 8; + + // Prepare grid indices. + const vuint16m2_t qsb = __riscv_vzext_vf2_u16m2(__riscv_vle8_v_u8m1(&qs[0], 16), 16); + const vuint16m2_t shift = __riscv_vreinterpret_v_u32m2_u16m2(__riscv_vmv_v_x_u32m2(0x00040008, 8)); + vuint16m2_t index = __riscv_vor_vv_u16m2(qsb, __riscv_vand_vx_u16m2(__riscv_vsll_vv_u16m2(qhb, shift, 16), 0x700, 16), 16); + index = __riscv_vsll_vx_u16m2(index, 3, 16); + qs += 16; + + // Prepare the deltas. + const vbool8_t mask = __riscv_vmsgtu_vx_u16m2_b8( + __riscv_vand_vv_u16m2(qhb, __riscv_vreinterpret_v_u32m2_u16m2(__riscv_vmv_v_x_u32m2(0x00800008, 8)), 16), 0, 16); + const vint64m8_t delta_pos = __riscv_vmv_v_x_i64m8(0x0101010101010101, 16); + const vint8m8_t delta = __riscv_vreinterpret_v_i64m8_i8m8( + __riscv_vmerge_vxm_i64m8(delta_pos, 0xffffffffffffffff, mask, 16)); + + // Sub-blocks 0-3 + { + // Load the grid. + const vint8m4_t iq1b = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vreinterpret_v_u64m4_i64m4( + __riscv_vluxei16_v_u64m4(iq1s_grid, __riscv_vget_v_u16m2_u16m1(index, 0), 8))); + + // Calculate the lsums. + // + // Sub-block 0, 1 + { + // Load q8 for each sub-block. + const vint8m2_t q8b = __riscv_vle8_v_i8m2(q8, 32); + q8 += 32; + + // Calculate the lsums. + const vint16m4_t lsum1 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m4_i8m2(iq1b, 0), q8b, 32); + const vint16m4_t lsum2 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m8_i8m2(delta, 0), q8b, 32); + + // Prepare the scales. + const int16_t ls_0 = 2*((sc[0] >> 0) & 0x7) + 1; + const int16_t ls_1 = 2*((sc[0] >> 3) & 0x7) + 1; + + // Accumulate in acc0 and acc1 for each sub-block. + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_0, __riscv_vget_v_i16m4_i16m2(lsum1, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_1, __riscv_vget_v_i16m4_i16m2(lsum1, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_0, __riscv_vget_v_i16m4_i16m2(lsum2, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_1, __riscv_vget_v_i16m4_i16m2(lsum2, 1), 16); + } + __asm__ __volatile__("" ::: "memory"); + // Sub-block 2, 3 + { + // Load q8 for each sub-block. + const vint8m2_t q8b = __riscv_vle8_v_i8m2(q8, 32); + q8 += 32; + + // Calculate the lsums. + const vint16m4_t lsum1 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m4_i8m2(iq1b, 1), q8b, 32); + const vint16m4_t lsum2 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m8_i8m2(delta, 1), q8b, 32); + + // Prepare the scales. + const int16_t ls_0 = 2*((sc[0] >> 6) & 0x7) + 1; + const int16_t ls_1 = 2*((sc[0] >> 9) & 0x7) + 1; + + // Accumulate in acc0 and acc1 for each sub-block. + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_0, __riscv_vget_v_i16m4_i16m2(lsum1, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_1, __riscv_vget_v_i16m4_i16m2(lsum1, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_0, __riscv_vget_v_i16m4_i16m2(lsum2, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_1, __riscv_vget_v_i16m4_i16m2(lsum2, 1), 16); + } + sc += 1; + } + __asm__ __volatile__("" ::: "memory"); + // Sub-blocks 4-7 + { + // Load the grid. + const vint8m4_t iq1b = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vreinterpret_v_u64m4_i64m4( + __riscv_vluxei16_v_u64m4(iq1s_grid, __riscv_vget_v_u16m2_u16m1(index, 1), 8))); + + // Calculate the lsums. + // + // Sub-block 4, 5 + { + // Load q8 for each sub-block. + const vint8m2_t q8b = __riscv_vle8_v_i8m2(q8, 32); + q8 += 32; + + // Calculate the lsums. + const vint16m4_t lsum1 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m4_i8m2(iq1b, 0), q8b, 32); + const vint16m4_t lsum2 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m8_i8m2(delta, 2), q8b, 32); + + // Prepare the scales. + const int16_t ls_0 = 2*((sc[0] >> 0) & 0x7) + 1; + const int16_t ls_1 = 2*((sc[0] >> 3) & 0x7) + 1; + + // Accumulate in acc0 and acc1 for each sub-block. + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_0, __riscv_vget_v_i16m4_i16m2(lsum1, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_1, __riscv_vget_v_i16m4_i16m2(lsum1, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_0, __riscv_vget_v_i16m4_i16m2(lsum2, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_1, __riscv_vget_v_i16m4_i16m2(lsum2, 1), 16); + } + __asm__ __volatile__("" ::: "memory"); + // Sub-block 6, 7 + { + // Load q8 for each sub-block. + const vint8m2_t q8b = __riscv_vle8_v_i8m2(q8, 32); + q8 += 32; + + // Calculate the lsums. + const vint16m4_t lsum1 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m4_i8m2(iq1b, 1), q8b, 32); + const vint16m4_t lsum2 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m8_i8m2(delta, 3), q8b, 32); + + // Prepare the scales. + const int16_t ls_0 = 2*((sc[0] >> 6) & 0x7) + 1; + const int16_t ls_1 = 2*((sc[0] >> 9) & 0x7) + 1; + + // Accumulate in acc0 and acc1 for each sub-block. + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_0, __riscv_vget_v_i16m4_i16m2(lsum1, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_1, __riscv_vget_v_i16m4_i16m2(lsum1, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_0, __riscv_vget_v_i16m4_i16m2(lsum2, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_1, __riscv_vget_v_i16m4_i16m2(lsum2, 1), 16); + } + sc += 1; + } + } + + // Reduce and accumulate in `sumf`. + vint32m1_t one = __riscv_vmv_v_x_i32m1(0, 1); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(acc1, one, 16)); + int sumi2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(acc2, one, 16)); + sumf += y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16) * (sumi1 + IQ1M_DELTA * sumi2); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -2190,9 +2481,10 @@ static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t vint32m2_t acc1 = __riscv_vmv_v_x_i32m2(0, 16); vint32m2_t acc2 = __riscv_vmv_v_x_i32m2(0, 16); - // We process 4 sub-blocks together. + // We process 8 16-element sub-blocks together. + #pragma GCC unroll 1 for (int ib = 0; ib < QK_K/128; ib++) { - // Load qh for 4 sub-blocks. + // Load qh for 8 sub-blocks. const vuint8mf4_t qh_8 = __riscv_vle8_v_u8mf4(qh, 8); const vuint16mf2_t qh_16_lo = __riscv_vzext_vf2_u16mf2(qh_8, 8); const vuint16mf2_t qh_16_hi = __riscv_vsll_vx_u16mf2(qh_16_lo, 8, 8); @@ -2200,6 +2492,8 @@ static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t __riscv_vreinterpret_v_u16mf2_u8mf2(__riscv_vor_vv_u16mf2(qh_16_lo, qh_16_hi, 8)), 16); qh += 8; + __asm__ __volatile__("" ::: "memory"); + // Prepare grid indices. const vuint16m1_t qsb = __riscv_vzext_vf2_u16m1(__riscv_vle8_v_u8mf2(&qs[0], 16), 16); const vuint16m1_t shift = __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00040008, 8)); @@ -2207,6 +2501,8 @@ static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t index = __riscv_vsll_vx_u16m1(index, 3, 16); qs += 16; + __asm__ __volatile__("" ::: "memory"); + // Load the grid. const vint8m4_t iq1b = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vreinterpret_v_u64m4_i64m4( __riscv_vluxei16_v_u64m4(iq1s_grid, index, 16))); @@ -2215,9 +2511,8 @@ static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t const vbool16_t mask = __riscv_vmsgtu_vx_u16m1_b16( __riscv_vand_vv_u16m1(qhb, __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00800008, 8)), 16), 0, 16); const vint64m4_t delta_pos = __riscv_vmv_v_x_i64m4(0x0101010101010101, 16); - const vint64m4_t delta_neg = __riscv_vmv_v_x_i64m4(0xffffffffffffffff, 16); const vint8m4_t delta = __riscv_vreinterpret_v_i64m4_i8m4( - __riscv_vmerge_vvm_i64m4(delta_pos, delta_neg, mask, 16)); + __riscv_vmerge_vxm_i64m4(delta_pos, 0xffffffffffffffff, mask, 16)); // Load q8 for sub-blocks. const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 128); @@ -2258,6 +2553,8 @@ static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_3_1, __riscv_vget_v_i16m8_i16m1(lsum1, 7), 16); acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_3_0, __riscv_vget_v_i16m8_i16m1(lsum2, 6), 16); acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_3_1, __riscv_vget_v_i16m8_i16m1(lsum2, 7), 16); + + __asm__ __volatile__("" ::: "memory"); } // Reduce and accumulate in `sumf`. @@ -2269,10 +2566,14 @@ static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t *s = sumf; } +#endif void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq1_m_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; case 256: ggml_vec_dot_iq1_m_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; @@ -2285,6 +2586,7 @@ void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } +#if defined __riscv_v_intrinsic static const uint8_t sign_gather_indices_arr[64] = { 0,0,0,0,0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2,2,2,2,2, 3,3,3,3,3,3,3,3, 4,4,4,4,4,4,4,4, 5,5,5,5,5,5,5,5, 6,6,6,6,6,6,6,6, 7,7,7,7,7,7,7,7 @@ -2295,8 +2597,7 @@ static const uint8_t sign_bit_masks_arr[64] = { 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128 }; - -static void ggml_vec_dot_iq2_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq2_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); @@ -2387,7 +2688,7 @@ static void ggml_vec_dot_iq2_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t *s = 0.125f * sumf; } -static void ggml_vec_dot_iq2_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq2_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); @@ -2488,6 +2789,7 @@ static void ggml_vec_dot_iq2_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t } *s = 0.125f * sumf; } +#endif void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -2507,7 +2809,7 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -#if defined(__riscv_v) +#if defined __riscv_v_intrinsic static const int8_t keven_signs_q2xs[1024] = { 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, @@ -2542,9 +2844,85 @@ static const int8_t keven_signs_q2xs[1024] = { 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, }; -#endif -static void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint64_t * grid64 = (const uint64_t *)iq2xs_grid; + + float sumf = 0.0f; +#pragma GCC unroll 1 + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * GGML_RESTRICT qs = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + + int32_t sum_int = 0; + + // Loop over 4 subblocks of 64 elements + for (int ib64 = 0; ib64 < QK_K / 64; ++ib64) { + + // Load indices. + vuint16m1_t v_qs = __riscv_vle16_v_u16m1(qs, 8); + qs += 8; + + // Prepare offsets + vuint16m1_t vidx_grid = __riscv_vsll_vx_u16m1(__riscv_vand_vx_u16m1(v_qs, 511, 8), 3, 8); + vuint16m1_t vidx_sign = __riscv_vsll_vx_u16m1(__riscv_vsrl_vx_u16m1(v_qs, 9, 8), 3, 8); + + // load values and signs from the lookup tables + vuint64m4_t vq2_64 = __riscv_vluxei16_v_u64m4(grid64, vidx_grid, 8); + vuint64m4_t vs2_64 = __riscv_vluxei16_v_u64m4(signs64, vidx_sign, 8); + vint8m4_t q2u = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vq2_64)); + vint8m4_t q2s = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vs2_64)); + vint8m4_t q2_final = __riscv_vmul_vv_i8m4(q2u, q2s, 64); + asm volatile("" ::: "memory"); + vint8m4_t q8v = __riscv_vle8_v_i8m4(q8, 64); + q8 += 64; + + vint16m8_t prod = __riscv_vwmul_vv_i16m8(q2_final, q8v, 64); + asm volatile("" ::: "memory"); + vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); + + int32_t sum0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1( + __riscv_vget_v_i16m8_i16m2(prod, 0), zero_vec, 16)); + + int32_t sum1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1( + __riscv_vget_v_i16m8_i16m2(prod, 1), zero_vec, 16)); + + int32_t sum2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1( + __riscv_vget_v_i16m8_i16m2(prod, 2), zero_vec, 16)); + + int32_t sum3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1( + __riscv_vget_v_i16m8_i16m2(prod, 3), zero_vec, 16)); + + const uint8_t scale_byte_1 = scales[0]; + const uint8_t scale_byte_2 = scales[1]; + scales += 2; + + sum_int += sum0 * ((scale_byte_1 & 0x0F) * 2 + 1); + sum_int += sum1 * ((scale_byte_1 >> 4) * 2 + 1); + sum_int += sum2 * ((scale_byte_2 & 0x0F) * 2 + 1); + sum_int += sum3 * ((scale_byte_2 >> 4) * 2 + 1); + } + + sumf += d * sum_int; + } + *s = 0.125f * sumf; +} + +static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -2618,10 +2996,14 @@ static void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_ } *s = 0.125f * sumf; } +#endif void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq2_xs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; case 256: ggml_vec_dot_iq2_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; @@ -2634,7 +3016,8 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } -static void ggml_vec_dot_iq2_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic +static NOINLINE void ggml_vec_dot_iq2_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -2725,7 +3108,7 @@ static void ggml_vec_dot_iq2_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size *s = 0.125f * sumf; } -static void ggml_vec_dot_iq2_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq2_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -2818,6 +3201,7 @@ static void ggml_vec_dot_iq2_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size } *s = 0.125f * sumf; } +#endif void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -2825,16 +3209,112 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const case 128: ggml_vec_dot_iq2_xxs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); break; - default: + default: // 256 and above ggml_vec_dot_iq2_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; } #else - ggml_vec_dot_iq2_xxs_q8_K(n, s, bs, vx, bx, vy, by, nrc); + ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } -static void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic +static NOINLINE void ggml_vec_dot_iq3_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + const block_iq3_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint32_t * grid32 = (const uint32_t *)iq3s_grid; + + vuint8mf2_t v_id_8 = __riscv_vid_v_u8mf2(8); + vuint8m2_t v_id_32 = __riscv_vid_v_u8m2(32); + + // Keeping these in a tight scope to hint they're only needed for the mask computation. + vuint8m2_t v_sign_gather_indices, v_sign_masks; + { + vuint8m2_t v_shifts = __riscv_vand_vx_u8m2(v_id_32, 7, 32); + vuint8m2_t v_one_32 = __riscv_vmv_v_x_u8m2(1, 32); + v_sign_gather_indices = __riscv_vsrl_vx_u8m2(v_id_32, 3, 32); + v_sign_masks = __riscv_vsll_vv_u8m2(v_one_32, v_shifts, 32); + } + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d); + const float combined_scale = d * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const uint8_t * GGML_RESTRICT signs = x[i].signs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float sum_block = 0.0f; + + for (int ib = 0; ib < 8; ++ib) { + + // Grid lookup + vuint8m2_t v_grid_u8; + { + vuint8mf2_t v_qs_u8 = __riscv_vle8_v_u8mf2(qs, 8); + qs += 8; + + uint8_t qh_val = *qh++; + vuint8mf2_t v_qh_val = __riscv_vmv_v_x_u8mf2(qh_val, 8); + v_qh_val = __riscv_vsrl_vv_u8mf2(v_qh_val, v_id_8, 8); + v_qh_val = __riscv_vand_vx_u8mf2(v_qh_val, 1, 8); + + vuint16m1_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qs_u8, 8); + v_qs_u16 = __riscv_vsll_vx_u16m1(v_qs_u16, 2, 8); + + vuint16m1_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qh_val, 8); + v_qh_u16 = __riscv_vsll_vx_u16m1(v_qh_u16, 10, 8); + + vuint16m1_t v_grid_offsets = __riscv_vor_vv_u16m1(v_qs_u16, v_qh_u16, 8); + + vuint32m2_t v_grid_packed = __riscv_vluxei16_v_u32m2(grid32, v_grid_offsets, 8); + v_grid_u8 = __riscv_vreinterpret_v_u32m2_u8m2(v_grid_packed); + } + __asm__ volatile ("" ::: "memory"); + + //Sign application and dot product + int32_t s_val; + { + vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs, 4); + signs += 4; + + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 32); + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 32); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 32); + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 32); + q8 += 32; + + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 32); + vint16m4_t v_dot = __riscv_vwmulsu_vv_i16m4(v_q8_signed, v_grid_u8, 32); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + s_val = __riscv_vmv_x_s_i32m1_i32( + __riscv_vwredsum_vs_i16m4_i32m1(v_dot, v_zero, 32)); + } + __asm__ volatile ("" ::: "memory"); + { + uint8_t sc_byte = scales[ib >> 1]; + int sc_val = (ib & 1) ? (sc_byte >> 4) : (sc_byte & 0xF); + sc_val = sc_val * 2 + 1; + sum_block += (float)(s_val * sc_val); + } + } + sumf += sum_block * combined_scale; + } + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); UNUSED(nrc); UNUSED(bx); @@ -2928,10 +3408,14 @@ static void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t } *s = sumf; } +#endif void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq3_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; case 256: ggml_vec_dot_iq3_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; @@ -2944,7 +3428,101 @@ void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -static void ggml_vec_dot_iq3_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic +static NOINLINE void ggml_vec_dot_iq3_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq3_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + const int nb = n / QK_K; + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint32_t * grid32 = (const uint32_t *)iq3xxs_grid; + + // constants for unpacking logic + const uint32_t shifts_val[8] = {0, 7, 14, 21, 0, 7, 14, 21}; + vuint32m2_t v_shifts = __riscv_vle32_v_u32m2(shifts_val, 8); + + const uint32_t gather_idx_val[8] = {0, 0, 0, 0, 1, 1, 1, 1}; + vuint32m2_t v_gather_idx = __riscv_vle32_v_u32m2(gather_idx_val, 8); + + uint32_t aux32[2]; + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q3_indices = x[i].qs; + const uint8_t * GGML_RESTRICT metadata = x[i].qs + QK_K/4; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float block_sum = 0.0f; + + // Process 64 weights per loop + for (int ib = 0; ib < QK_K / 64; ++ib) { + + // load of metadata via memcpy + memcpy(aux32, metadata, 2 * sizeof(uint32_t)); + metadata += 2 * sizeof(uint32_t); + + vuint8m1_t v_q3_idx_u8 = __riscv_vle8_v_u8m1(q3_indices, 16); + q3_indices += 16; + + vuint16m2_t v_q3_idx_u16 = __riscv_vwmulu_vx_u16m2(v_q3_idx_u8, 4, 16); + + vuint32m4_t v_q3_magnitudes_u32 = __riscv_vluxei16_v_u32m4(grid32, v_q3_idx_u16, 16); + + vint8m4_t v_q3_magnitudes = __riscv_vreinterpret_v_u8m4_i8m4( + __riscv_vreinterpret_v_u32m4_u8m4(v_q3_magnitudes_u32)); + + vuint32m2_t v_aux = __riscv_vle32_v_u32m2(aux32, 2); + + vuint32m2_t v_aux_expanded = __riscv_vrgather_vv_u32m2(v_aux, v_gather_idx, 8); + + vuint32m2_t v_s_vals_raw = __riscv_vand_vx_u32m2( + __riscv_vsrl_vv_u32m2(v_aux_expanded, v_shifts, 8), 127, 8); + + vuint16m1_t sign_indices_byte_offset = __riscv_vsll_vx_u16m1( + __riscv_vncvt_x_x_w_u16m1(v_s_vals_raw, 8), 3, 8); + + vuint64m4_t v_s_vals_u64 = __riscv_vluxei16_v_u64m4(signs64, sign_indices_byte_offset, 8); + + vint8m4_t v_s_vals = __riscv_vreinterpret_v_u8m4_i8m4( + __riscv_vreinterpret_v_u64m4_u8m4(v_s_vals_u64)); + + vint8m4_t v_q3_signed = __riscv_vmul_vv_i8m4(v_q3_magnitudes, v_s_vals, 64); + asm volatile("" ::: "memory"); + vint8m4_t v_q8 = __riscv_vle8_v_i8m4(q8, 64); + q8 += 64; + + vint16m8_t v_dot = __riscv_vwmul_vv_i16m8(v_q8, v_q3_signed, 64); + + asm volatile("" ::: "memory"); + + vint16m4_t v_dot_1 = __riscv_vget_v_i16m8_i16m4(v_dot, 0); + vint16m4_t v_dot_2 = __riscv_vget_v_i16m8_i16m4(v_dot, 1); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + + vint32m1_t v_sum_1 = __riscv_vwredsum_vs_i16m4_i32m1(v_dot_1, v_zero, 32); + vint32m1_t v_sum_2 = __riscv_vwredsum_vs_i16m4_i32m1(v_dot_2, v_zero, 32); + + int32_t sum1_i = __riscv_vmv_x_s_i32m1_i32(v_sum_1); + int32_t sum2_i = __riscv_vmv_x_s_i32m1_i32(v_sum_2); + + const float scale1_f = (float)(2 * (aux32[0] >> 28) + 1); + const float scale2_f = (float)(2 * (aux32[1] >> 28) + 1); + + block_sum += sum1_i * scale1_f + sum2_i * scale2_f; + } + + sumf += d * block_sum; + } + *s = 0.25f * sumf; +} + +static NOINLINE void ggml_vec_dot_iq3_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -3036,10 +3614,14 @@ static void ggml_vec_dot_iq3_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size } *s = 0.25f * sumf; } +#endif void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq3_xxs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; case 256: ggml_vec_dot_iq3_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; @@ -3052,7 +3634,8 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const #endif } -static void ggml_vec_dot_iq4_nl_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic +static NOINLINE void ggml_vec_dot_iq4_nl_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bx); @@ -3082,12 +3665,14 @@ static void ggml_vec_dot_iq4_nl_q8_0_vl128(int n, float * GGML_RESTRICT s, size_ vint8m2_t q8b2 = __riscv_vle8_v_i8m2(y[ib + 1].qs, 32); // Unpack the weight blocks. - vuint8m2_t iq4bits1; - iq4bits1 = __riscv_vset_v_u8m1_u8m2(iq4bits1, 0, __riscv_vand_vx_u8m1(iq4_packed1, 0xf, 16)); - iq4bits1 = __riscv_vset_v_u8m1_u8m2(iq4bits1, 1, __riscv_vsrl_vx_u8m1(iq4_packed1, 4, 16)); - vuint8m2_t iq4bits2; - iq4bits2 = __riscv_vset_v_u8m1_u8m2(iq4bits2, 0, __riscv_vand_vx_u8m1(iq4_packed2, 0xf, 16)); - iq4bits2 = __riscv_vset_v_u8m1_u8m2(iq4bits2, 1, __riscv_vsrl_vx_u8m1(iq4_packed2, 4, 16)); + vuint8m2_t iq4bits1 = __riscv_vcreate_v_u8m1_u8m2( + __riscv_vand_vx_u8m1(iq4_packed1, 0xf, 16), + __riscv_vsrl_vx_u8m1(iq4_packed1, 4, 16) + ); + vuint8m2_t iq4bits2 = __riscv_vcreate_v_u8m1_u8m2( + __riscv_vand_vx_u8m1(iq4_packed2, 0xf, 16), + __riscv_vsrl_vx_u8m1(iq4_packed2, 4, 16) + ); // Gather values from the lookup table. vint8m2_t iq4b1 = __riscv_vrgather_vv_i8m2(values, iq4bits1, 32); @@ -3105,7 +3690,7 @@ static void ggml_vec_dot_iq4_nl_q8_0_vl128(int n, float * GGML_RESTRICT s, size_ *s = sumf; } -static void ggml_vec_dot_iq4_nl_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq4_nl_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bx); @@ -3161,6 +3746,7 @@ static void ggml_vec_dot_iq4_nl_q8_0_vl256(int n, float * GGML_RESTRICT s, size_ *s = sumf; } +#endif void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -3168,7 +3754,7 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v case 128: ggml_vec_dot_iq4_nl_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc); break; - default: + default: // 256 and above ggml_vec_dot_iq4_nl_q8_0_vl256(n, s, bs, vx, bx, vy, by, nrc); break; } @@ -3177,7 +3763,74 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } -static void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic +static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_K == 0); + + const block_iq4_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16); + float sumf = 0; + + for (int ibl = 0; ibl < nb; ++ibl) { + const int8_t * q8 = y[ibl].qs; + const uint8_t * iq4 = x[ibl].qs; + uint16_t h = x[ibl].scales_h; + + // We process 2 sub-blocks together. + int sumi1 = 0, sumi2 = 0; + #pragma GCC unroll 1 + for (int ib = 0; ib < QK_K / 64; ++ib) { + // Load the packed weights. + const vuint8m2_t iq4_packed = __riscv_vle8_v_u8m2(iq4, 32); + iq4 += 32; + + // Unpack the weight blocks. + const vuint8m2_t iq4bits_lo = __riscv_vand_vx_u8m2(iq4_packed, 0xf, 32); + const vuint8m2_t iq4bits_hi = __riscv_vsrl_vx_u8m2(iq4_packed, 4, 32); + const vuint8m4_t iq4bits = __riscv_vcreate_v_u8m2_u8m4(iq4bits_lo, iq4bits_hi); + const vuint8m4_t iq4bits_reorder = __riscv_vcreate_v_u8m1_u8m4( + __riscv_vmv_v_v_u8m1(__riscv_vget_v_u8m4_u8m1(iq4bits, 0), 16), + __riscv_vmv_v_v_u8m1(__riscv_vget_v_u8m4_u8m1(iq4bits, 2), 16), + __riscv_vmv_v_v_u8m1(__riscv_vget_v_u8m4_u8m1(iq4bits, 1), 16), + __riscv_vmv_v_v_u8m1(__riscv_vget_v_u8m4_u8m1(iq4bits, 3), 16) + ); + const vint8m4_t iq4b = __riscv_vrgather_vv_i8m4(values, iq4bits_reorder, 64); + + // Multiply with activations. + const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 64); + q8 += 64; + const vint16m8_t prod = __riscv_vwmul_vv_i16m8(iq4b, q8b, 64); + + // Reduce separately. + const int acc0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32)); + + const int ls1 = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32; + const int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32; + h >>= 4; + + sumi1 += acc0 * ls1; + sumi2 += acc1 * ls2; + + __asm__ __volatile__("" ::: "memory"); + } + + sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bx); @@ -3190,19 +3843,17 @@ static void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_ const int nb = n / QK_K; -#if defined __riscv_v_intrinsic const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16); float sumf = 0; - int acc[4]; // Indices for re-ordering IQ4 data. - uint64_t index[16] = { + uint16_t index[16] = { 0, 1, 8, 9, 2, 3, 10, 11, 4, 5,12, 13, 6, 7, 14, 15, }; - vuint64m4_t i_vec = __riscv_vle64_v_u64m4(index, 16); + vuint16m1_t i_vec = __riscv_vle16_v_u16m1(index, 16); for (int ibl = 0; ibl < nb; ++ibl) { const int8_t * q8 = y[ibl].qs; @@ -3211,30 +3862,33 @@ static void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_ int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0; + #pragma GCC unroll 1 for (int ib = 0; ib < QK_K / 128; ++ib) { // Weights and activations. vuint8m2_t iq4_packed = __riscv_vle8_v_u8m2(iq4, 64); - vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 128); iq4 += 64; - q8 += 128; // Unpack the weight blocks. vuint8m2_t iq4bits_lo = __riscv_vand_vx_u8m2(iq4_packed, 0xf, 64); vuint8m2_t iq4bits_hi = __riscv_vsrl_vx_u8m2(iq4_packed, 4, 64); - vuint8m4_t iq4bits; - iq4bits = __riscv_vset_v_u8m2_u8m4(iq4bits, 0, iq4bits_lo); - iq4bits = __riscv_vset_v_u8m2_u8m4(iq4bits, 1, iq4bits_hi); - vuint8m4_t iq4bits_reorder = __riscv_vreinterpret_v_u64m4_u8m4(__riscv_vrgather_vv_u64m4(__riscv_vreinterpret_v_u8m4_u64m4(iq4bits), i_vec, 16)); + vuint8m4_t iq4bits = __riscv_vcreate_v_u8m2_u8m4(iq4bits_lo, iq4bits_hi); + vuint8m4_t iq4bits_reorder = __riscv_vreinterpret_v_u64m4_u8m4(__riscv_vrgatherei16_vv_u64m4(__riscv_vreinterpret_v_u8m4_u64m4(iq4bits), i_vec, 16)); vint8m4_t iq4b = __riscv_vrgather_vv_i8m4(values, iq4bits_reorder, 128); + __asm__ __volatile__("" ::: "memory"); + // Multiply with activations. + vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 128); vint16m8_t prod = __riscv_vwmul_vv_i16m8(iq4b, q8b, 128); + q8 += 128; + + __asm__ __volatile__("" ::: "memory"); // Reduce separately. - __riscv_vse32_v_i32m1(&acc[0],__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32), 1); - __riscv_vse32_v_i32m1(&acc[1],__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32), 1); - __riscv_vse32_v_i32m1(&acc[2],__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 32), 1); - __riscv_vse32_v_i32m1(&acc[3],__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 32), 1); + int acc0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32)); + int acc1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32)); + int acc2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 32)); + int acc3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 32)); int ls1 = ((x[ibl].scales_l[ib * 2 + 0] & 0xf) | ((h << 4) & 0x30)) - 32; int ls2 = ((x[ibl].scales_l[ib * 2 + 0] >> 4) | ((h << 2) & 0x30)) - 32; @@ -3242,28 +3896,27 @@ static void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_ int ls4 = ((x[ibl].scales_l[ib * 2 + 1] >> 4) | ((h >> 2) & 0x30)) - 32; h >>= 8; - sumi1 += acc[0] * ls1; - sumi2 += acc[1] * ls2; - sumi3 += acc[2] * ls3; - sumi4 += acc[3] * ls4; + sumi1 += acc0 * ls1; + sumi2 += acc1 * ls2; + sumi3 += acc2 * ls3; + sumi4 += acc3 * ls4; + + __asm__ __volatile__("" ::: "memory"); } sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2 + sumi3 + sumi4); } *s = sumf; - -#else - UNUSED(x); - UNUSED(y); - UNUSED(nb); - ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); -#endif } +#endif void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq4_xs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; case 256: ggml_vec_dot_iq4_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; @@ -3276,7 +3929,8 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } -static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic +static NOINLINE void ggml_vec_dot_tq1_0_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bx); @@ -3292,8 +3946,107 @@ static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t uint8_t pow[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27}; for (int i = 0; i < nb; i++) { + const uint8_t * GGML_RESTRICT tq = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + // First loop. - vint32m4_t suml1; + vint16m4_t suml1; + { + const int vl = 32; + const vuint8m2_t tqb = __riscv_vle8_v_u8m2(tq, vl); + tq += 32; + + { + const vuint16m4_t tq0 = __riscv_vsrl_vx_u16m4(__riscv_vwmulu_vx_u16m4(tqb, 3, vl), 8, vl); + const vint16m4_t q80 = __riscv_vwcvt_x_x_v_i16m4(__riscv_vle8_v_i8m2(q8, vl), vl); + suml1 = __riscv_vmul_vv_i16m4(__riscv_vreinterpret_v_u16m4_i16m4(__riscv_vsub_vx_u16m4(tq0, 1, vl)), q80, vl); + q8 += 32; + } + + uint8_t pow3 = 3; + #pragma GCC unroll 1 + for (int t = 0; t < 4; t++) { + const vuint16m4_t tqn = __riscv_vsrl_vx_u16m4(__riscv_vwmulu_vx_u16m4(__riscv_vmul_vx_u8m2(tqb, pow3, vl), 3, vl), 8, vl); + const vint16m4_t q8n = __riscv_vwcvt_x_x_v_i16m4(__riscv_vle8_v_i8m2(q8, vl), vl); + suml1 = __riscv_vmacc_vv_i16m4(suml1, __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vsub_vx_u16m4(tqn, 1, vl)), q8n, vl); + pow3 *= 3; + q8 += 32; + } + } + + // Second loop. + vint16m2_t suml2; + { + const int vl = 16; + const vuint8m1_t tqb = __riscv_vle8_v_u8m1(tq, vl); + + { + const vuint16m2_t tq0 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(tqb, 3, vl), 8, vl); + const vint16m2_t q80 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(q8, vl), vl); + suml2 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq0, 1, vl)), q80, vl); + q8 += 16; + } + + uint8_t pow3 = 3; + #pragma GCC unroll 1 + for (int t = 0; t < 4; t++) { + const vuint16m2_t tqn = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tqb, pow3, vl), 3, vl), 8, vl); + const vint16m2_t q8n = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(q8, vl), vl); + suml2 = __riscv_vmacc_vv_i16m2(suml2, __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tqn, 1, vl)), q8n, vl); + pow3 *= 3; + q8 += 16; + } + } + + // Third loop. + vint16m2_t suml3; + { + const int vl = 16; + + uint32_t qh; + memcpy(&qh, &x[i].qh[0], 4); + // Prevent fusion with vmv. + __asm__ __volatile__("" : "+r"(qh)); + const vuint8m1_t tqb = __riscv_vreinterpret_v_u32m1_u8m1(__riscv_vmv_v_x_u32m1(qh, vl / 4)); + + const vuint8m1_t p = __riscv_vle8_v_u8m1(pow, vl); + + const vuint16m2_t tq0 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vv_u8m1(tqb, p, vl), 3, vl), 8, vl); + + const vint16m2_t q80 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(q8, vl), vl); + + suml3 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq0, 1, vl)), q80, vl); + } + + vint16m2_t sumb = __riscv_vadd_vv_i16m2(__riscv_vget_v_i16m4_i16m2(suml1, 0), __riscv_vget_v_i16m4_i16m2(suml1, 1), 16); + sumb = __riscv_vadd_vv_i16m2(sumb, suml2, 16); + sumb = __riscv_vadd_vv_i16m2(sumb, suml3, 16); + + vint32m1_t sum = __riscv_vwredsum_vs_i16m2_i32m1(sumb, __riscv_vmv_v_x_i32m1(0, 1), 16); + sumf += __riscv_vmv_x_s_i32m1_i32(sum) * y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq1_0 * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0.0f; + uint8_t pow[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27}; + + for (int i = 0; i < nb; i++) { + // First loop. + vint16m2_t suml1; { const int vl = 32; vuint8m1_t tq = __riscv_vle8_v_u8m1(x[i].qs, vl); @@ -3316,13 +4069,13 @@ static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t vint16m2_t sum3 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq3, 1, vl)), q83, vl); vint16m2_t sum4 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq4, 1, vl)), q84, vl); - vint32m4_t sumi0 = __riscv_vwadd_vv_i32m4(sum0, sum1, vl); - vint32m4_t sumi1 = __riscv_vwadd_vv_i32m4(sum2, sum3, vl); - suml1 = __riscv_vadd_vv_i32m4(__riscv_vwcvt_x_x_v_i32m4(sum4, vl), __riscv_vadd_vv_i32m4(sumi0, sumi1, vl), vl); + vint16m2_t sumi0 = __riscv_vadd_vv_i16m2(sum0, sum1, vl); + vint16m2_t sumi1 = __riscv_vadd_vv_i16m2(sum2, sum3, vl); + suml1 = __riscv_vadd_vv_i16m2(sum4, __riscv_vadd_vv_i16m2(sumi0, sumi1, vl), vl); } // Second loop. - vint32m2_t suml2; + vint16m1_t suml2; { const int vl = 16; vuint8mf2_t tq = __riscv_vle8_v_u8mf2(x[i].qs + 32, vl); @@ -3345,13 +4098,13 @@ static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t vint16m1_t sum3 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq3, 1, vl)), q83, vl); vint16m1_t sum4 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq4, 1, vl)), q84, vl); - vint32m2_t sumi0 = __riscv_vwadd_vv_i32m2(sum0, sum1, vl); - vint32m2_t sumi1 = __riscv_vwadd_vv_i32m2(sum2, sum3, vl); - suml2 = __riscv_vadd_vv_i32m2(__riscv_vwcvt_x_x_v_i32m2(sum4, vl), __riscv_vadd_vv_i32m2(sumi0, sumi1, vl), vl); + vint16m1_t sumi0 = __riscv_vadd_vv_i16m1(sum0, sum1, vl); + vint16m1_t sumi1 = __riscv_vadd_vv_i16m1(sum2, sum3, vl); + suml2 = __riscv_vadd_vv_i16m1(sum4, __riscv_vadd_vv_i16m1(sumi0, sumi1, vl), vl); } // Third loop. - vint32m2_t suml3; + vint16m1_t suml3; { const int vl = 16; @@ -3367,24 +4120,26 @@ static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 240, vl), vl); - vint16m1_t sum0 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl); - suml3 = __riscv_vwcvt_x_x_v_i32m2(sum0, vl); + suml3 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl); } - vint32m2_t sumb = __riscv_vadd_vv_i32m2(__riscv_vget_v_i32m4_i32m2(suml1, 0), __riscv_vget_v_i32m4_i32m2(suml1, 1), 16); - sumb = __riscv_vadd_vv_i32m2(sumb, suml2, 16); - sumb = __riscv_vadd_vv_i32m2(sumb, suml3, 16); + vint16m1_t sumb = __riscv_vadd_vv_i16m1(__riscv_vget_v_i16m2_i16m1(suml1, 0), __riscv_vget_v_i16m2_i16m1(suml1, 1), 16); + sumb = __riscv_vadd_vv_i16m1(sumb, __riscv_vadd_vv_i16m1(suml2, suml3, 16), 16); - vint32m1_t sum = __riscv_vredsum_vs_i32m2_i32m1(sumb, __riscv_vmv_v_x_i32m1(0, 1), 16); + vint32m1_t sum = __riscv_vwredsum_vs_i16m1_i32m1(sumb, __riscv_vmv_v_x_i32m1(0, 1), 16); sumf += __riscv_vmv_x_s_i32m1_i32(sum) * y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); } *s = sumf; } +#endif void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_tq1_0_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; case 256: ggml_vec_dot_tq1_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; @@ -3397,7 +4152,90 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -static void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic +static NOINLINE void ggml_vec_dot_tq2_0_q8_K_vl128(const int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq2_0 * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + int32_t sumi = 0; + + for (size_t j = 0; j < sizeof(x[0].qs); j += 32) { + const int8_t * py0 = &y[i].qs[j * 4 + 0 * 32]; + const int8_t * py1 = &y[i].qs[j * 4 + 1 * 32]; + const int8_t * py2 = &y[i].qs[j * 4 + 2 * 32]; + const int8_t * py3 = &y[i].qs[j * 4 + 3 * 32]; + const uint8_t* px = &x[i].qs[j]; + + size_t vl = __riscv_vsetvl_e16m4(32); + vint16m4_t vacc16 = __riscv_vmv_v_x_i16m4(0, vl); + + // Load Raw Packed elements + vl = __riscv_vsetvl_e8m2(32); + vuint8m2_t vx_u8 = __riscv_vle8_v_u8m2(px, vl); + + // Process bits 1:0 + { + // Unpack + vuint8m2_t t0 = __riscv_vand_vx_u8m2(vx_u8, 0x03, vl); + vint8m2_t vq = __riscv_vsub_vx_i8m2(__riscv_vreinterpret_v_u8m2_i8m2(t0), 1, vl); + vint8m2_t vy = __riscv_vle8_v_i8m2(py0, vl); + // Accumulate + vacc16 = __riscv_vwmacc_vv_i16m4(vacc16, vq, vy, vl); + } + __asm__ volatile("" ::: "memory"); + // Process bits 3:2 + { + vuint8m2_t t1 = __riscv_vsrl_vx_u8m2(vx_u8, 2, vl); + t1 = __riscv_vand_vx_u8m2(t1, 0x03, vl); + vint8m2_t vq = __riscv_vsub_vx_i8m2(__riscv_vreinterpret_v_u8m2_i8m2(t1), 1, vl); + + vint8m2_t vy = __riscv_vle8_v_i8m2(py1, vl); + vacc16 = __riscv_vwmacc_vv_i16m4(vacc16, vq, vy, vl); + } + __asm__ volatile("" ::: "memory"); + // Process bits 5:4 + { + vuint8m2_t t2 = __riscv_vsrl_vx_u8m2(vx_u8, 4, vl); + t2 = __riscv_vand_vx_u8m2(t2, 0x03, vl); + vint8m2_t vq = __riscv_vsub_vx_i8m2(__riscv_vreinterpret_v_u8m2_i8m2(t2), 1, vl); + + vint8m2_t vy = __riscv_vle8_v_i8m2(py2, vl); + vacc16 = __riscv_vwmacc_vv_i16m4(vacc16, vq, vy, vl); + } + __asm__ volatile("" ::: "memory"); + // Process bits 7:6 + { + vuint8m2_t t3 = __riscv_vsrl_vx_u8m2(vx_u8, 6, vl); + vint8m2_t vq = __riscv_vsub_vx_i8m2(__riscv_vreinterpret_v_u8m2_i8m2(t3), 1, vl); + + vint8m2_t vy = __riscv_vle8_v_i8m2(py3, vl); + vacc16 = __riscv_vwmacc_vv_i16m4(vacc16, vq, vy, vl); + } + __asm__ volatile("" ::: "memory"); + vl = __riscv_vsetvl_e16m4(32); + vint32m1_t vzero32 = __riscv_vmv_v_x_i32m1(0, 1); + vint32m1_t vred32 = __riscv_vwredsum_vs_i16m4_i32m1(vacc16, vzero32, vl); + sumi += __riscv_vmv_x_s_i32m1_i32(vred32); + } + + const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + sumf += (float)sumi * d; + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -3467,10 +4305,14 @@ static void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t *s = sumf; } +#endif void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_tq2_0_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; case 256: ggml_vec_dot_tq2_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; @@ -3483,7 +4325,8 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -static void ggml_vec_dot_mxfp4_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic +static NOINLINE void ggml_vec_dot_mxfp4_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bx); @@ -3513,12 +4356,14 @@ static void ggml_vec_dot_mxfp4_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t vint8m2_t q8b2 = __riscv_vle8_v_i8m2(y[ib + 1].qs, 32); // Unpack the weight blocks. - vuint8m2_t mxbits1; - mxbits1 = __riscv_vset_v_u8m1_u8m2(mxbits1, 0, __riscv_vand_vx_u8m1(mx_packed1, 0xf, 16)); - mxbits1 = __riscv_vset_v_u8m1_u8m2(mxbits1, 1, __riscv_vsrl_vx_u8m1(mx_packed1, 4, 16)); - vuint8m2_t mxbits2; - mxbits2 = __riscv_vset_v_u8m1_u8m2(mxbits2, 0, __riscv_vand_vx_u8m1(mx_packed2, 0xf, 16)); - mxbits2 = __riscv_vset_v_u8m1_u8m2(mxbits2, 1, __riscv_vsrl_vx_u8m1(mx_packed2, 4, 16)); + vuint8m2_t mxbits1 = __riscv_vcreate_v_u8m1_u8m2( + __riscv_vand_vx_u8m1(mx_packed1, 0xf, 16), + __riscv_vsrl_vx_u8m1(mx_packed1, 4, 16) + ); + vuint8m2_t mxbits2 = __riscv_vcreate_v_u8m1_u8m2( + __riscv_vand_vx_u8m1(mx_packed2, 0xf, 16), + __riscv_vsrl_vx_u8m1(mx_packed2, 4, 16) + ); // Gather values from the lookup table. vint8m2_t mxb1 = __riscv_vrgather_vv_i8m2(values, mxbits1, 32); @@ -3536,7 +4381,7 @@ static void ggml_vec_dot_mxfp4_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t *s = sumf; } -static void ggml_vec_dot_mxfp4_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_mxfp4_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bx); @@ -3592,6 +4437,7 @@ static void ggml_vec_dot_mxfp4_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t *s = sumf; } +#endif void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -3599,11 +4445,11 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo case 128: ggml_vec_dot_mxfp4_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc); break; - default: + default: // 256 and above ggml_vec_dot_mxfp4_q8_0_vl256(n, s, bs, vx, bx, vy, by, nrc); break; } #else - return ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); + ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp b/ggml/src/ggml-cpu/arch/riscv/repack.cpp index cd5807879ea..c37488cae54 100644 --- a/ggml/src/ggml-cpu/arch/riscv/repack.cpp +++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp @@ -107,8 +107,7 @@ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR } #else UNUSED(nb); - UNUSED(y); - ggml_quantize_mat_q8_0_4x4_generic(x, vy, k); + ggml_quantize_mat_q8_0_4x8_generic(x, vy, k); #endif } @@ -203,6 +202,7 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +#if defined __riscv_zvfh void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -222,7 +222,6 @@ void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined __riscv_v_intrinsic const block_q8_0 * a_ptr = (const block_q8_0 *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); @@ -256,9 +255,6 @@ void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); } - return; -#endif - ggml_gemv_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -280,7 +276,6 @@ void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined __riscv_v_intrinsic const block_q8_K * a_ptr = (const block_q8_K *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { @@ -392,9 +387,6 @@ void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); } - return; -#endif - ggml_gemv_q4_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); } void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -416,7 +408,6 @@ void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined __riscv_v_intrinsic const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); const block_q8_0 * a_ptr = (const block_q8_0 *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { @@ -451,9 +442,6 @@ void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); } - return; -#endif - ggml_gemv_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -476,7 +464,6 @@ void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v UNUSED(blocklen); UNUSED(bs); -#if defined __riscv_v_intrinsic const block_q8_0 * a_ptr = (const block_q8_0 *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb); @@ -505,9 +492,6 @@ void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); } - return; -#endif - ggml_gemv_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -679,9 +663,9 @@ void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } // End K-Block __riscv_vse32_v_f32m2(s + col_tile, v_sumf, vl); - } } +#endif void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; @@ -909,6 +893,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +#if defined __riscv_zvfh void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -929,7 +914,6 @@ void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined __riscv_v_intrinsic for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { @@ -994,9 +978,6 @@ void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); } } - return; -#endif - ggml_gemm_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -1019,7 +1000,6 @@ void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined __riscv_v_intrinsic for (int y = 0; y < nr / 4; y++) { const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { @@ -1267,9 +1247,6 @@ void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); } } - return; -#endif - ggml_gemm_q4_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); } void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -1292,7 +1269,6 @@ void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined __riscv_v_intrinsic const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); @@ -1355,9 +1331,6 @@ void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); } } - return; -#endif - ggml_gemm_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -1380,7 +1353,6 @@ void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined __riscv_v_intrinsic for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { @@ -1429,9 +1401,6 @@ void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); } } - return; -#endif - ggml_gemm_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -1731,3 +1700,4 @@ void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } } } +#endif diff --git a/ggml/src/ggml-cpu/arch/s390/quants.c b/ggml/src/ggml-cpu/arch/s390/quants.c index 34184ed8510..500857579a7 100644 --- a/ggml/src/ggml-cpu/arch/s390/quants.c +++ b/ggml/src/ggml-cpu/arch/s390/quants.c @@ -1463,4 +1463,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } - diff --git a/ggml/src/ggml-cpu/arch/wasm/quants.c b/ggml/src/ggml-cpu/arch/wasm/quants.c index 74a359e6d12..648c6fcaba7 100644 --- a/ggml/src/ggml-cpu/arch/wasm/quants.c +++ b/ggml/src/ggml-cpu/arch/wasm/quants.c @@ -1218,4 +1218,3 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } - diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c index 74d699f633d..94b19b82bbc 100644 --- a/ggml/src/ggml-cpu/arch/x86/quants.c +++ b/ggml/src/ggml-cpu/arch/x86/quants.c @@ -274,6 +274,18 @@ static inline __m256 quad_mx_delta_float(const uint8_t x0, const float y0, const } #endif #elif defined(__SSSE3__) +static inline __m128i bytes_from_bits_16(const uint8_t * x) { + uint16_t x16; + memcpy(&x16, x, sizeof(uint16_t)); + + const __m128i shuf_mask = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); + __m128i bytes = _mm_shuffle_epi8(_mm_set1_epi16((short) x16), shuf_mask); + const __m128i bit_mask = _mm_set_epi64x(0x7fbfdfeff7fbfdfe, 0x7fbfdfeff7fbfdfe); + bytes = _mm_or_si128(bytes, bit_mask); + + return _mm_cmpeq_epi8(bytes, _mm_set1_epi64x(-1)); +} + // horizontally add 4x4 floats static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) { __m128 res_0 =_mm_hadd_ps(a, b); @@ -540,6 +552,152 @@ static inline __m128i get_scale_shuffle(int i) { } #endif +void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK1_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q1_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + +#if defined(__AVX2__) + const __m256i ones_8 = _mm256_set1_epi8(1); + const __m256i ones_16 = _mm256_set1_epi16(1); + const __m256i byte_shuf = _mm256_setr_epi8( + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3); + const __m256i bit_masks = _mm256_setr_epi8( + 1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128, + 1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128); + const __m256i zero = _mm256_setzero_si256(); + __m256 acc = _mm256_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d); + const uint32_t * GGML_RESTRICT qs32 = (const uint32_t *) x[ib].qs; + const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib * 4]; + + __m256 acc_block; + { + const __m256i qy = _mm256_loadu_si256((const __m256i *) y_ptr[0].qs); + const __m256i sm = _mm256_cmpeq_epi8( + _mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs32[0]), byte_shuf), bit_masks), zero); + const __m256i sy = _mm256_sub_epi8(_mm256_xor_si256(qy, sm), sm); + const __m256i s32 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy), ones_16); + acc_block = _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[0].d)), _mm256_cvtepi32_ps(s32)); + } + for (int K = 1; K < 4; ++K) { + const __m256i qy = _mm256_loadu_si256((const __m256i *) y_ptr[K].qs); + const __m256i sm = _mm256_cmpeq_epi8( + _mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs32[K]), byte_shuf), bit_masks), zero); + const __m256i sy = _mm256_sub_epi8(_mm256_xor_si256(qy, sm), sm); + const __m256i s32 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy), ones_16); + acc_block = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[K].d)), _mm256_cvtepi32_ps(s32), acc_block); + } + acc = _mm256_fmadd_ps(_mm256_set1_ps(d0), acc_block, acc); + } + + *s = hsum_float_8(acc); +#elif defined(__AVX__) + const __m128i ones_8 = _mm_set1_epi8(1); + const __m128i ones_16 = _mm_set1_epi16(1); + const __m128i zero = _mm_setzero_si128(); + __m256 acc = _mm256_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d); + const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib * 4]; + __m256 acc_block; + { + const __m256i bit_mask = bytes_from_bits_32(&x[ib].qs[0]); + const __m128i bit_mask_0 = _mm256_castsi256_si128(bit_mask); + const __m128i bit_mask_1 = _mm256_extractf128_si256(bit_mask, 1); + const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[0].qs[0]); + const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[0].qs[16]); + const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero); + const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero); + const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0); + const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1); + const __m128i sum16_0 = _mm_maddubs_epi16(ones_8, sy_0); + const __m128i sum16_1 = _mm_maddubs_epi16(ones_8, sy_1); + const __m128i sum32_0 = _mm_madd_epi16(sum16_0, ones_16); + const __m128i sum32_1 = _mm_madd_epi16(sum16_1, ones_16); + const __m256 q = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_1, sum32_0)); + acc_block = _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[0].d)), q); + } + for(int K = 1; K < 4; ++K) { + const __m256i bit_mask = bytes_from_bits_32(&x[ib].qs[(K) * 4]); + const __m128i bit_mask_0 = _mm256_castsi256_si128(bit_mask); + const __m128i bit_mask_1 = _mm256_extractf128_si256(bit_mask, 1); + const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[(K)].qs[0]); + const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[(K)].qs[16]); + const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero); + const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero); + const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0); + const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1); + const __m128i sum16_0 = _mm_maddubs_epi16(ones_8, sy_0); + const __m128i sum16_1 = _mm_maddubs_epi16(ones_8, sy_1); + const __m128i sum32_0 = _mm_madd_epi16(sum16_0, ones_16); + const __m128i sum32_1 = _mm_madd_epi16(sum16_1, ones_16); + const __m256 q = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_1, sum32_0)); + acc_block = _mm256_add_ps(acc_block, _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[(K)].d)), q)); + } +#undef Q1_AVX_BLOCK + + acc = _mm256_add_ps(acc, _mm256_mul_ps(_mm256_set1_ps(d0), acc_block)); + } + + *s = hsum_float_8(acc); +#elif defined(__SSSE3__) + const __m128i ones_8 = _mm_set1_epi8(1); + const __m128i ones_16 = _mm_set1_epi16(1); + const __m128i zero = _mm_setzero_si128(); + __m128 acc_0 = _mm_setzero_ps(); + __m128 acc_1 = _mm_setzero_ps(); + __m128 acc_2 = _mm_setzero_ps(); + __m128 acc_3 = _mm_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const __m128 d0 = _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d)); + const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib * 4]; + +#define Q1_SSSE3_BLOCK(QS_OFF, Y_IDX, ACC) \ + { \ + const __m128i bit_mask_0 = bytes_from_bits_16(&x[ib].qs[(QS_OFF) + 0]); \ + const __m128i bit_mask_1 = bytes_from_bits_16(&x[ib].qs[(QS_OFF) + 2]); \ + const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[(Y_IDX)].qs[0]); \ + const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[(Y_IDX)].qs[16]); \ + const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero); \ + const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero); \ + const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0); \ + const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1); \ + const __m128i sum_0 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_0), ones_16); \ + const __m128i sum_1 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_1), ones_16); \ + const __m128 q = _mm_cvtepi32_ps(_mm_add_epi32(sum_0, sum_1)); \ + (ACC) = _mm_add_ps((ACC), _mm_mul_ps(_mm_mul_ps(d0, _mm_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[(Y_IDX)].d))), q)); \ + } + Q1_SSSE3_BLOCK(0, 0, acc_0) + Q1_SSSE3_BLOCK(4, 1, acc_1) + Q1_SSSE3_BLOCK(8, 2, acc_2) + Q1_SSSE3_BLOCK(12, 3, acc_3) +#undef Q1_SSSE3_BLOCK + } + + *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); +#else + UNUSED(nb); + UNUSED(x); + UNUSED(y); + ggml_vec_dot_q1_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; @@ -2142,9 +2300,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi #if defined __AVX2__ - const __m256i m4 = _mm256_set1_epi8(0xF); - const __m256i m2 = _mm256_set1_epi8(3); - const __m256i m32s = _mm256_set1_epi8(32); + const __m256i m3 = _mm256_set1_epi8(3); + const __m256i m15 = _mm256_set1_epi8(15); __m256 acc = _mm256_setzero_ps(); @@ -2156,53 +2313,45 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const uint8_t * GGML_RESTRICT qh = x[i].qh; const int8_t * GGML_RESTRICT q8 = y[i].qs; + const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); + const __m256i scales_16 = _mm256_cvtepi8_epi16(scales); + const __m256i q8sclsub = _mm256_slli_epi32(_mm256_madd_epi16(q8sums, scales_16), 5); __m256i sumi = _mm256_setzero_si256(); int is = 0; for (int j = 0; j < QK_K/128; ++j) { - - const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); - const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); - const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); - const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); - is += 4; - const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32; - const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4); - const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4); - const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4); - const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4); + const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m3), 4); + const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, _mm256_set1_epi8(12)), 2); + const __m256i q4h_2 = _mm256_and_si256(q4bitsH, _mm256_set1_epi8(48)); + const __m256i q4h_3 = _mm256_srli_epi16(_mm256_and_si256(q4bitsH, _mm256_set1_epi8(-64)), 2); - const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); - const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1); - const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2); - const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3); + const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m15), q4h_0); + const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m15), q4h_1); + const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m15), q4h_2); + const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m15), q4h_3); const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); - __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); - __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2); - __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3); - __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2); __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3); - p16_0 = _mm256_sub_epi16(p16_0, q8s_0); - p16_1 = _mm256_sub_epi16(p16_1, q8s_1); - p16_2 = _mm256_sub_epi16(p16_2, q8s_2); - p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); + const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); + const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); + const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); + is += 4; p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); @@ -2214,6 +2363,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } + sumi = _mm256_sub_epi32(sumi, q8sclsub); acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); } diff --git a/ggml/src/ggml-cpu/arch/x86/repack.cpp b/ggml/src/ggml-cpu/arch/x86/repack.cpp index 33c6cb65098..af1cebad131 100644 --- a/ggml/src/ggml-cpu/arch/x86/repack.cpp +++ b/ggml/src/ggml-cpu/arch/x86/repack.cpp @@ -531,7 +531,6 @@ static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t UNUSED(bs); - __m128i changemask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0); __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0); // Permute mask used for easier vector processing at later stages @@ -580,6 +579,7 @@ static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t if constexpr ( std::is_same_v || std::is_same_v) { + const __m128i changemask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0); col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask); } else if constexpr (std::is_same_v) { // Load 8 E8M0 exponents and convert to float via LUT diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h index 88a9c9ec057..5d1ca5ffcc3 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-impl.h +++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h @@ -306,6 +306,7 @@ inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) { #if !defined(__ARM_FEATURE_DOTPROD) +// NOTE: this fallback produces the same total sum as native vdotq_s32 but with different per-lane grouping — do not use when individual lane values matter. inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) { const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b)); const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b)); @@ -319,6 +320,15 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) #endif // !defined(__ARM_FEATURE_DOTPROD) +static inline int32x4_t ggml_nvfp4_dot8(const int8x8_t q4_lo, const int8x8_t q8_lo, + const int8x8_t q4_hi, const int8x8_t q8_hi) { + const int16x8_t p_lo = vmull_s8(q4_lo, q8_lo); + const int16x8_t p_hi = vmull_s8(q4_hi, q8_hi); + const int32x4_t sum_lo = vpaddlq_s16(p_lo); + const int32x4_t sum_hi = vpaddlq_s16(p_hi); + return vaddq_s32(sum_lo, sum_hi); +} + #endif // defined(__ARM_NEON) #ifdef __wasm_simd128__ diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 8b323bd9b06..2b3eb5b5ce6 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -217,6 +217,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_F16, .nrows = 1, }, + [GGML_TYPE_Q1_0] = { + .from_float = quantize_row_q1_0, + .vec_dot = ggml_vec_dot_q1_0_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + .nrows = 1, + }, [GGML_TYPE_Q4_0] = { .from_float = quantize_row_q4_0, .vec_dot = ggml_vec_dot_q4_0_q8_0, @@ -2350,11 +2356,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_FLASH_ATTN_BACK: case GGML_OP_SSM_CONV: case GGML_OP_SSM_SCAN: + { + n_tasks = n_threads; + } break; case GGML_OP_RWKV_WKV6: case GGML_OP_GATED_LINEAR_ATTN: case GGML_OP_RWKV_WKV7: { - n_tasks = n_threads; + const int64_t n_heads = node->src[1]->ne[1]; + n_tasks = MIN(n_threads, n_heads); } break; case GGML_OP_WIN_PART: case GGML_OP_WIN_UNPART: @@ -2871,8 +2881,12 @@ struct ggml_cplan ggml_graph_plan( const int64_t ne11 = node->src[1]->ne[1]; // H const int64_t ne12 = node->src[1]->ne[2]; // Channels In - cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03; - cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; + GGML_ASSERT(node->src[0]->type == GGML_TYPE_F16 || node->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(node->src[1]->type == GGML_TYPE_F32); + + cur += ggml_type_size(node->src[0]->type) * ne00 * ne01 * ne02 * ne03; + cur += ggml_type_size(node->src[0]->type) * ne10 * ne11 * ne12; + } break; case GGML_OP_TOP_K: { diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index ddf1737a317..128883b41ce 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -195,6 +195,8 @@ static const struct ggml_backend_i ggml_backend_cpu_i = { /* .free = */ ggml_backend_cpu_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, /* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create, diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp index 9bcc18d442c..0ecf7ae02ac 100644 --- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp @@ -1461,7 +1461,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { return false; } if ((op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_I32) && - ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) { + ggml_ne(op->src[1], 3) == 1) { return true; } } @@ -1473,10 +1473,12 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) { return (ggml::cpu::tensor_traits *) op->src[0]->extra; } else { + if (op->src[0]->type != GGML_TYPE_F16) { + return nullptr; + } std::array kernel_chain; const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain); - const bool has_kernel = slot_total > 0; - if (has_kernel && op->src[1]->ne[1] > 1) { + if (slot_total > 0 && op->src[1]->ne[1] > 1) { if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) || (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) { return nullptr; diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index c89e5076f26..e13828e3be6 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -180,44 +180,49 @@ inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) { } #endif +#if defined(__riscv_v_intrinsic) +template <> inline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) { + return __riscv_vfmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1()); +} +template <> inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) { + return __riscv_vfmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2()); +} +template <> inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) { + return __riscv_vfmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4()); +} +template <> inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) { + return __riscv_vfmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8()); +} +#endif + #if defined(__riscv_zvfh) -template <> -inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) { +template <> inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) { return __riscv_vfwmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1()); } -inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) { +template <> inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) { return __riscv_vfwmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2()); } -inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) { +template <> inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) { return __riscv_vfwmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4()); } -inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) { +template <> inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) { return __riscv_vfwmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8()); } -inline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) { - return __riscv_vfmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1()); -} -inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) { - return __riscv_vfmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2()); -} -inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) { - return __riscv_vfmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4()); -} -inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) { - return __riscv_vfmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8()); -} #endif #if defined(__riscv_zvfbfwma) -inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) { +template <> inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) { return __riscv_vfwmaccbf16_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1()); } -inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) { +template <> inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) { return __riscv_vfwmaccbf16_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2()); } -inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) { +template <> inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) { return __riscv_vfwmaccbf16_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4()); } +template <> inline vfloat32m8_t madd(vbfloat16m4_t a, vbfloat16m4_t b, vfloat32m8_t c) { + return __riscv_vfwmaccbf16_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8()); +} #endif //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -272,7 +277,7 @@ inline float hsum(__m512 x) { } #endif // __AVX512F__ -#if defined(__riscv_zvfh) +#if defined(__riscv_v_intrinsic) inline float hsum(vfloat32m1_t x) { return __riscv_vfmv_f_s_f32m1_f32( __riscv_vfredusum_vs_f32m1_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m1())); @@ -379,6 +384,21 @@ template <> inline __m256bh load(const float *p) { } #endif +#if defined(__riscv_v_intrinsic) +template <> inline vfloat32m1_t load(const float *p) { + return __riscv_vle32_v_f32m1(p, __riscv_vsetvlmax_e32m1()); +} +template <> inline vfloat32m2_t load(const float *p) { + return __riscv_vle32_v_f32m2(p, __riscv_vsetvlmax_e32m2()); +} +template <> inline vfloat32m4_t load(const float *p) { + return __riscv_vle32_v_f32m4(p, __riscv_vsetvlmax_e32m4()); +} +template <> inline vfloat32m8_t load(const float *p) { + return __riscv_vle32_v_f32m8(p, __riscv_vsetvlmax_e32m8()); +} +#endif + #if defined(__riscv_zvfh) template <> inline vfloat16mf2_t load(const ggml_fp16_t *p) { return __riscv_vle16_v_f16mf2(reinterpret_cast(p), __riscv_vsetvlmax_e16mf2()); @@ -392,18 +412,6 @@ template <> inline vfloat16m2_t load(const ggml_fp16_t *p) { template <> inline vfloat16m4_t load(const ggml_fp16_t *p) { return __riscv_vle16_v_f16m4(reinterpret_cast(p), __riscv_vsetvlmax_e16m4()); } -template <> inline vfloat32m1_t load(const float *p) { - return __riscv_vle32_v_f32m1(p, __riscv_vsetvlmax_e32m1()); -} -template <> inline vfloat32m2_t load(const float *p) { - return __riscv_vle32_v_f32m2(p, __riscv_vsetvlmax_e32m2()); -} -template <> inline vfloat32m4_t load(const float *p) { - return __riscv_vle32_v_f32m4(p, __riscv_vsetvlmax_e32m4()); -} -template <> inline vfloat32m8_t load(const float *p) { - return __riscv_vle32_v_f32m8(p, __riscv_vsetvlmax_e32m8()); -} #endif #if defined(__riscv_zvfbfwma) @@ -416,23 +424,14 @@ template <> inline vbfloat16m1_t load(const ggml_bf16_t *p) { template <> inline vbfloat16m2_t load(const ggml_bf16_t *p) { return __riscv_vle16_v_bf16m2(reinterpret_cast(p), __riscv_vsetvlmax_e16m2()); } +template <> inline vbfloat16m4_t load(const ggml_bf16_t *p) { + return __riscv_vle16_v_bf16m4(reinterpret_cast(p), __riscv_vsetvlmax_e16m4()); +} #endif -#if defined(__riscv_zvfh) +#if defined(__riscv_v_intrinsic) template T set_zero(); -template <> inline vfloat16mf2_t set_zero() { - return __riscv_vfmv_v_f_f16mf2(0, __riscv_vsetvlmax_e16mf2()); -} -template <> inline vfloat16m1_t set_zero() { - return __riscv_vfmv_v_f_f16m1(0, __riscv_vsetvlmax_e16m1()); -} -template <> inline vfloat16m2_t set_zero() { - return __riscv_vfmv_v_f_f16m2(0, __riscv_vsetvlmax_e16m2()); -} -template <> inline vfloat16m4_t set_zero() { - return __riscv_vfmv_v_f_f16m4(0, __riscv_vsetvlmax_e16m4()); -} template <> inline vfloat32m1_t set_zero() { return __riscv_vfmv_v_f_f32m1(0.0f, __riscv_vsetvlmax_e32m1()); } @@ -449,14 +448,22 @@ template <> inline vfloat32m8_t set_zero() { #if defined(__riscv_v_intrinsic) template size_t vlmax() { - if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16mf2(); } - else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m1(); } - else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m2(); } - else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m4(); } - else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e32m1(); } + if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e32m1(); } else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e32m2(); } else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e32m4(); } else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e32m8(); } + #if defined (__riscv_zvfh) + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16mf2(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m1(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m2(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m4(); } + #endif + #if defined (__riscv_zvfbfwma) + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16mf2(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m1(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m2(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m4(); } + #endif return 0; } #endif @@ -2314,6 +2321,9 @@ class tinyBLAS_Q0_PPC { } void matmul(int64_t m, int64_t n) { + #if defined(_AIX) || defined(__BIG_ENDIAN__) + mnpack(0, m, 0, n); + #else const int64_t mc = 64; const int64_t kc = 64; int64_t nc = 64; @@ -2327,7 +2337,6 @@ class tinyBLAS_Q0_PPC { } else { n_aligned = (n / 64) * 64; } - if (n_aligned > 0) { if (n_aligned % 64 == 0) nc = 64; else if (n_aligned == n) nc = n; @@ -2345,6 +2354,7 @@ class tinyBLAS_Q0_PPC { } else { mnpack(0, m, 0, n); } + #endif } private: @@ -3184,16 +3194,21 @@ class tinyBLAS_PPC { } void matmul(int64_t m, int64_t n) { + #if defined(_AIX) || defined(__BIG_ENDIAN__) + mnpack(0, m, 0, n); + #else int64_t mc = 256; int64_t nc = 256; int64_t kc = 256; if (m % mc == 0 && n % nc == 0 && k % kc == 0) { matmul_tiled(m, n, mc, nc, kc); } else { mnpack(0, m, 0, n); } + #endif } private: + __attribute__((always_inline)) inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) { vec_t vec_C[4]; __builtin_mma_disassemble_acc(vec_C, ACC); @@ -3204,6 +3219,7 @@ class tinyBLAS_PPC { } } + __attribute__((always_inline)) inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) { vec_t vec_C[4]; __builtin_mma_disassemble_acc(vec_C, ACC); @@ -3738,7 +3754,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 params->ith, params->nth}; tb.matmul(m, n); return true; -#elif defined(__riscv_zvfh) +#elif defined(__riscv_v_intrinsic) #if LMUL == 1 tinyBLAS_RVV tb{ params, k, (const float *)A, lda, @@ -3802,23 +3818,25 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 return true; } #elif defined(__riscv_zvfbfwma) - #if LMUL == 1 - tinyBLAS_RVV tb{ params, - k, (const ggml_bf16_t *)A, lda, - (const ggml_bf16_t *)B, ldb, - (float *)C, ldc}; - #elif LMUL == 2 - tinyBLAS_RVV tb{ params, - k, (const ggml_bf16_t *)A, lda, - (const ggml_bf16_t *)B, ldb, - (float *)C, ldc}; - #else // LMUL = 4 - tinyBLAS_RVV tb{ params, - k, (const ggml_bf16_t *)A, lda, - (const ggml_bf16_t *)B, ldb, - (float *)C, ldc}; - #endif - return tb.matmul(m, n); + if (Btype == GGML_TYPE_BF16) { + #if LMUL == 1 + tinyBLAS_RVV tb{ params, + k, (const ggml_bf16_t *)A, lda, + (const ggml_bf16_t *)B, ldb, + (float *)C, ldc}; + #elif LMUL == 2 + tinyBLAS_RVV tb{ params, + k, (const ggml_bf16_t *)A, lda, + (const ggml_bf16_t *)B, ldb, + (float *)C, ldc}; + #else // LMUL = 4 + tinyBLAS_RVV tb{ params, + k, (const ggml_bf16_t *)A, lda, + (const ggml_bf16_t *)B, ldb, + (float *)C, ldc}; + #endif + return tb.matmul(m, n); + } #endif return false; } diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 3f85e531daa..a9bc21da6f0 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -664,6 +664,7 @@ void ggml_compute_forward_add( { ggml_compute_forward_add_non_quantized(params, dst); } break; + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -1113,6 +1114,7 @@ void ggml_compute_forward_add1( GGML_ABORT("fatal error"); } } break; + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -1242,6 +1244,7 @@ void ggml_compute_forward_acc( } break; case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -4331,6 +4334,7 @@ void ggml_compute_forward_out_prod( const ggml_tensor * src0 = dst->src[0]; switch (src0->type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -4606,6 +4610,7 @@ void ggml_compute_forward_set( } break; case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -4829,6 +4834,7 @@ void ggml_compute_forward_get_rows( const ggml_tensor * src0 = dst->src[0]; switch (src0->type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -5554,6 +5560,7 @@ void ggml_compute_forward_clamp( ggml_compute_forward_clamp_f16(params, dst); } break; case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -6923,16 +6930,15 @@ void ggml_compute_forward_conv_3d( ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type); } -// ggml_compute_forward_conv_transpose_2d - -void ggml_compute_forward_conv_transpose_2d( - const ggml_compute_params * params, - ggml_tensor * dst) { +template +static void ggml_compute_forward_conv_transpose_2d_impl( + const ggml_compute_params * params, + ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; - GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -6943,7 +6949,7 @@ void ggml_compute_forward_conv_transpose_2d( const int nk = ne00*ne01*ne02*ne03; - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == ggml_type_size(src0->type)); GGML_ASSERT(nb10 == sizeof(float)); if (ith == 0) { @@ -6951,12 +6957,12 @@ void ggml_compute_forward_conv_transpose_2d( // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout) { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + kernel_t * const wdata = (kernel_t *) params->wdata + 0; for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02); - ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03; + const kernel_t * const src = (kernel_t *)((char *) src0->data + i03*nb03 + i02*nb02); + kernel_t * dst_data = wdata + i02*ne01*ne00*ne03; for (int64_t i01 = 0; i01 < ne01; i01++) { for (int64_t i00 = 0; i00 < ne00; i00++) { dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00]; @@ -6968,13 +6974,17 @@ void ggml_compute_forward_conv_transpose_2d( // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh) { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk; + kernel_t * const wdata = (kernel_t *) params->wdata + nk; for (int i12 = 0; i12 < ne12; i12++) { for (int i11 = 0; i11 < ne11; i11++) { const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11); - ggml_fp16_t * dst_data = wdata + i11*ne10*ne12; + kernel_t * dst_data = wdata + i11*ne10*ne12; for (int i10 = 0; i10 < ne10; i10++) { - dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]); + if constexpr (std::is_same_v) { + dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]); + } else { + dst_data[i10*ne12 + i12] = src[i10]; + } } } } @@ -6996,21 +7006,27 @@ void ggml_compute_forward_conv_transpose_2d( const int ip0 = dp*ith; const int ip1 = MIN(ip0 + dp, np); - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; - ggml_fp16_t * const wdata_src = wdata + nk; + kernel_t * const wdata = (kernel_t *) params->wdata + 0; + kernel_t * const wdata_src = wdata + nk; for (int i2 = ip0; i2 < ip1; i2++) { // Cout float * dst_data = (float *)((char *) dst->data + i2*nb2); - ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03; + kernel_t * wdata_kernel = wdata + i2*ne01*ne00*ne03; for (int i11 = 0; i11 < ne11; i11++) { for (int i10 = 0; i10 < ne10; i10++) { const int i1n = i11*ne10*ne12 + i10*ne12; for (int i01 = 0; i01 < ne01; i01++) { for (int i00 = 0; i00 < ne00; i00++) { float v = 0; - ggml_vec_dot_f16(ne03, &v, 0, - wdata_src + i1n, 0, - wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1); + if constexpr (std::is_same_v) { + ggml_vec_dot_f16(ne03, &v, 0, + wdata_src + i1n, 0, + wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1); + } else { + ggml_vec_dot_f32(ne03, &v, 0, + wdata_src + i1n, 0, + wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1); + } dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v; } } @@ -7019,6 +7035,28 @@ void ggml_compute_forward_conv_transpose_2d( } } +void ggml_compute_forward_conv_transpose_2d( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_conv_transpose_2d_impl(params, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_conv_transpose_2d_impl(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_conv_2d_dw struct ggml_conv_2d_dw_params { @@ -9922,13 +9960,9 @@ static void ggml_compute_forward_rwkv_wkv6_f32( const int ith = params->ith; const int nth = params->nth; - if (ith >= HEADS) { - return; - } - - const int h_start = (HEADS * ith) / nth; - const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? - (HEADS * (ith + 1)) / nth : HEADS; + const int h_start = (HEADS * (ith )) / nth; + const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? + (HEADS * (ith + 1)) / nth : HEADS; float * k = (float *) dst->src[0]->data; float * v = (float *) dst->src[1]->data; @@ -10139,13 +10173,9 @@ static void ggml_compute_forward_gla_f32( const int ith = params->ith; const int nth = params->nth; - if (ith >= HEADS) { - return; - } - - const int h_start = (HEADS * ith) / nth; - const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? - (HEADS * (ith + 1)) / nth : HEADS; + const int h_start = (HEADS * (ith )) / nth; + const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? + (HEADS * (ith + 1)) / nth : HEADS; float * k = (float *) dst->src[0]->data; float * v = (float *) dst->src[1]->data; @@ -10602,13 +10632,9 @@ static void ggml_compute_forward_rwkv_wkv7_f32( const int ith = params->ith; const int nth = params->nth; - if (ith >= HEADS) { - return; - } - - const int h_start = (HEADS * ith) / nth; - const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? - (HEADS * (ith + 1)) / nth : HEADS; + const int h_start = (HEADS * (ith )) / nth; + const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? + (HEADS * (ith + 1)) / nth : HEADS; float * r = (float *) dst->src[0]->data; float * w = (float *) dst->src[1]->data; diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c index 7ebbb9c6f15..e5f9a4083f9 100644 --- a/ggml/src/ggml-cpu/quants.c +++ b/ggml/src/ggml-cpu/quants.c @@ -22,6 +22,10 @@ #define UNUSED GGML_UNUSED +void quantize_row_q1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { + quantize_row_q1_0_ref(x, y, k); +} + void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { quantize_row_q4_0_ref(x, y, k); } @@ -116,6 +120,57 @@ void quantize_row_q8_K_generic(const float * GGML_RESTRICT x, void * GGML_RESTRI //===================================== Dot products ================================= +void ggml_vec_dot_q1_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK1_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q1_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + const float d0 = GGML_CPU_FP16_TO_FP32(x[i].d); + + float sumi = 0.0f; + + for (int k = 0; k < 4; k++) { + const block_q8_0 * GGML_RESTRICT yb = &y[i * 4 + k]; + const float d1 = GGML_CPU_FP16_TO_FP32(yb->d); + int sumi_block = 0; + + const uint8_t * GGML_RESTRICT bits = &x[i].qs[k * 4]; + const int8_t * GGML_RESTRICT qy = yb->qs; + + for (int b = 0; b < 4; ++b, qy += 8) { + const unsigned mask = bits[b]; + sumi_block += ((mask & 0x01) ? qy[0] : -qy[0]) + + ((mask & 0x02) ? qy[1] : -qy[1]) + + ((mask & 0x04) ? qy[2] : -qy[2]) + + ((mask & 0x08) ? qy[3] : -qy[3]) + + ((mask & 0x10) ? qy[4] : -qy[4]) + + ((mask & 0x20) ? qy[5] : -qy[5]) + + ((mask & 0x40) ? qy[6] : -qy[6]) + + ((mask & 0x80) ? qy[7] : -qy[7]); + } + + sumi += d1 * sumi_block; + } + + sumf += d0 * sumi; + } + + *s = sumf; +} + + void ggml_vec_dot_q4_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; diff --git a/ggml/src/ggml-cpu/quants.h b/ggml/src/ggml-cpu/quants.h index 3584aaa43e8..d4bc87a1c05 100644 --- a/ggml/src/ggml-cpu/quants.h +++ b/ggml/src/ggml-cpu/quants.h @@ -12,6 +12,7 @@ extern "C" { #endif // Quantization +void quantize_row_q1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -36,6 +37,7 @@ void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); // Dot product +void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -68,6 +70,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void quantize_row_q8_0_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void quantize_row_q8_1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void quantize_row_q8_K_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void ggml_vec_dot_q1_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 6b76ab3bfb1..f18758f16bb 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -1365,6 +1365,7 @@ void ggml_gemv_q8_0_4x8_q8_0_generic(int n, } } +// Only enable these for RISC-V. #if defined __riscv_zvfh void ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; @@ -1568,6 +1569,7 @@ void ggml_gemv_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, assert(nc % 16 == 0); UNUSED(bs); + UNUSED(nr); const int nb = n / QK_K; const block_q2_Kx16 * x = (const block_q2_Kx16 *)vx; @@ -2381,6 +2383,7 @@ void ggml_gemm_q8_0_4x8_q8_0_generic(int n, } } +// Only enable these for RISC-V. #if defined __riscv_zvfh void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; diff --git a/ggml/src/ggml-cpu/simd-gemm.h b/ggml/src/ggml-cpu/simd-gemm.h index 78d663e593e..4119d04f895 100644 --- a/ggml/src/ggml-cpu/simd-gemm.h +++ b/ggml/src/ggml-cpu/simd-gemm.h @@ -109,6 +109,96 @@ static void simd_gemm( C += N; } } +#elif defined(GGML_SIMD) && defined(__riscv_v_intrinsic) +// RM accumulators + 1 B vector = RM + 1 <= 8 => RM <= 7 +// Microkernel: C[RM x vl] += A[RM x K] * B[K x N] +template +static inline void rvv_simd_gemm_ukernel( + float * GGML_RESTRICT C, + const float * GGML_RESTRICT A, + const float * GGML_RESTRICT B, + int K, int N, size_t vl) +{ + static_assert(RM >= 1 && RM <= 7, "RM must be 1..7 for LMUL=4"); + + vfloat32m4_t acc_0 = __riscv_vle32_v_f32m4(C + 0 * N, vl); + vfloat32m4_t acc_1, acc_2, acc_3, acc_4, acc_5, acc_6; + if constexpr (RM > 1) acc_1 = __riscv_vle32_v_f32m4(C + 1 * N, vl); + if constexpr (RM > 2) acc_2 = __riscv_vle32_v_f32m4(C + 2 * N, vl); + if constexpr (RM > 3) acc_3 = __riscv_vle32_v_f32m4(C + 3 * N, vl); + if constexpr (RM > 4) acc_4 = __riscv_vle32_v_f32m4(C + 4 * N, vl); + if constexpr (RM > 5) acc_5 = __riscv_vle32_v_f32m4(C + 5 * N, vl); + if constexpr (RM > 6) acc_6 = __riscv_vle32_v_f32m4(C + 6 * N, vl); + + for (int kk = 0; kk < K; kk++) { + vfloat32m4_t b_0 = __riscv_vle32_v_f32m4(B + kk * N, vl); + + acc_0 = __riscv_vfmacc_vf_f32m4(acc_0, A[0 * K + kk], b_0, vl); + if constexpr (RM > 1) acc_1 = __riscv_vfmacc_vf_f32m4(acc_1, A[1 * K + kk], b_0, vl); + if constexpr (RM > 2) acc_2 = __riscv_vfmacc_vf_f32m4(acc_2, A[2 * K + kk], b_0, vl); + if constexpr (RM > 3) acc_3 = __riscv_vfmacc_vf_f32m4(acc_3, A[3 * K + kk], b_0, vl); + if constexpr (RM > 4) acc_4 = __riscv_vfmacc_vf_f32m4(acc_4, A[4 * K + kk], b_0, vl); + if constexpr (RM > 5) acc_5 = __riscv_vfmacc_vf_f32m4(acc_5, A[5 * K + kk], b_0, vl); + if constexpr (RM > 6) acc_6 = __riscv_vfmacc_vf_f32m4(acc_6, A[6 * K + kk], b_0, vl); + } + + __riscv_vse32_v_f32m4(C + 0 * N, acc_0, vl); + if constexpr (RM > 1) __riscv_vse32_v_f32m4(C + 1 * N, acc_1, vl); + if constexpr (RM > 2) __riscv_vse32_v_f32m4(C + 2 * N, acc_2, vl); + if constexpr (RM > 3) __riscv_vse32_v_f32m4(C + 3 * N, acc_3, vl); + if constexpr (RM > 4) __riscv_vse32_v_f32m4(C + 4 * N, acc_4, vl); + if constexpr (RM > 5) __riscv_vse32_v_f32m4(C + 5 * N, acc_5, vl); + if constexpr (RM > 6) __riscv_vse32_v_f32m4(C + 6 * N, acc_6, vl); +} + +template +static inline void rvv_simd_gemm_dispatch_tail( + float * GGML_RESTRICT C, + const float * GGML_RESTRICT A, + const float * GGML_RESTRICT B, + int K, int N, int KN, int remaining_rows) +{ + if constexpr (RM > 0) { + if (remaining_rows == RM) { + int64_t jj = 0; + for (; jj + KN <= N; jj += KN) { + rvv_simd_gemm_ukernel(C + jj, A, B + jj, K, N, KN); + } + if (jj < N) { + rvv_simd_gemm_ukernel(C + jj, A, B + jj, K, N, N - jj); + } + } else { + rvv_simd_gemm_dispatch_tail(C, A, B, K, N, KN, remaining_rows); + } + } +} + +static constexpr int GEMM_RM = 7; + +// C[M x N] += A[M x K] * B[K x N] +static void simd_gemm( + float * GGML_RESTRICT C, + const float * GGML_RESTRICT A, + const float * GGML_RESTRICT B, + int M, int K, int N) +{ + const int KN = (int)__riscv_vlenb(); + int64_t ii = 0; + for (; ii + GEMM_RM <= M; ii += GEMM_RM) { + int64_t jj = 0; + for (; jj + KN <= N; jj += KN) { + rvv_simd_gemm_ukernel(C + jj, A, B + jj, K, N, KN); + } + if (jj < N) { + rvv_simd_gemm_ukernel(C + jj, A, B + jj, K, N, N - jj); + } + A += GEMM_RM * K; + C += GEMM_RM * N; + } + + int remaining_rows = M - ii; + rvv_simd_gemm_dispatch_tail(C, A, B, K, N, KN, remaining_rows); +} #if defined(__GNUC__) && !defined(__clang__) #pragma GCC diagnostic pop diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 3198b33b509..bcd68da9aa9 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -126,7 +126,7 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG const int ggml_f16_epr = sve_register_length / 16; // running when 16 const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers - const int np = (n & ~(ggml_f16_step - 1)); + int np = (n & ~(ggml_f16_step - 1)); svfloat16_t sum_00 = svdup_n_f16(0.0f); svfloat16_t sum_01 = svdup_n_f16(0.0f); @@ -224,71 +224,75 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG } GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03); GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13); + np = n; + #elif defined(__riscv_v_intrinsic) + #if defined(__riscv_zvfh) + size_t vl = __riscv_vsetvlmax_e32m4(); + + // initialize accumulators to all zeroes + vfloat32m4_t vsum0_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl); + vfloat32m4_t vsum0_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl); + vfloat32m4_t vsum1_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl); + vfloat32m4_t vsum1_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl); + + // calculate step size + const size_t epr = __riscv_vsetvlmax_e16m2(); + const size_t step = epr * 2; + int np = (n & ~(step - 1)); + + // unroll by 2 along the row dimension + for (int i = 0; i < np; i += step) { + vfloat16m2_t ay0 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), epr); + vfloat16m2_t ax0_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), epr); + vfloat16m2_t ax1_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), epr); + vsum0_0 = __riscv_vfwmacc_vv_f32m4(vsum0_0, ax0_0, ay0, epr); + vsum1_0 = __riscv_vfwmacc_vv_f32m4(vsum1_0, ax1_0, ay0, epr); + + vfloat16m2_t ay1 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i + epr), epr); + vfloat16m2_t ax0_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i + epr), epr); + vfloat16m2_t ax1_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i + epr), epr); + vsum0_1 = __riscv_vfwmacc_vv_f32m4(vsum0_1, ax0_1, ay1, epr); + vsum1_1 = __riscv_vfwmacc_vv_f32m4(vsum1_1, ax1_1, ay1, epr); + } - #elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh) - size_t vl = __riscv_vsetvlmax_e32m4(); - - // initialize accumulators to all zeroes - vfloat32m4_t vsum0_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl); - vfloat32m4_t vsum0_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl); - vfloat32m4_t vsum1_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl); - vfloat32m4_t vsum1_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl); - - // calculate step size - const size_t epr = __riscv_vsetvlmax_e16m2(); - const size_t step = epr * 2; - const int np = (n & ~(step - 1)); - - // unroll by 2 along the row dimension - for (int i = 0; i < np; i += step) { - vfloat16m2_t ay0 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), epr); - vfloat16m2_t ax0_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), epr); - vfloat16m2_t ax1_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), epr); - vsum0_0 = __riscv_vfwmacc_vv_f32m4(vsum0_0, ax0_0, ay0, epr); - vsum1_0 = __riscv_vfwmacc_vv_f32m4(vsum1_0, ax1_0, ay0, epr); - - vfloat16m2_t ay1 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i + epr), epr); - vfloat16m2_t ax0_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i + epr), epr); - vfloat16m2_t ax1_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i + epr), epr); - vsum0_1 = __riscv_vfwmacc_vv_f32m4(vsum0_1, ax0_1, ay1, epr); - vsum1_1 = __riscv_vfwmacc_vv_f32m4(vsum1_1, ax1_1, ay1, epr); - } + vfloat32m4_t vsum0 = __riscv_vfadd_vv_f32m4(vsum0_0, vsum0_1, vl); + vfloat32m4_t vsum1 = __riscv_vfadd_vv_f32m4(vsum1_0, vsum1_1, vl); - vfloat32m4_t vsum0 = __riscv_vfadd_vv_f32m4(vsum0_0, vsum0_1, vl); - vfloat32m4_t vsum1 = __riscv_vfadd_vv_f32m4(vsum1_0, vsum1_1, vl); + // leftovers + for (int i = np; i < n; i += vl) { + vl = __riscv_vsetvl_e16m2(n - i); + vfloat16m2_t ay = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), vl); + vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), vl); + vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), vl); - // leftovers - for (int i = np; i < n; i += vl) { - vl = __riscv_vsetvl_e16m2(n - i); - vfloat16m2_t ay = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), vl); - vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), vl); - vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), vl); - - vsum0 = __riscv_vfwmacc_vv_f32m4(vsum0, ax0, ay, vl); - vsum1 = __riscv_vfwmacc_vv_f32m4(vsum1, ax1, ay, vl); - } - - // reduce - vl = __riscv_vsetvlmax_e32m2(); - vfloat32m2_t acc0_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum0, 0), - __riscv_vget_v_f32m4_f32m2(vsum0, 1), vl); - vl = __riscv_vsetvlmax_e32m1(); - vfloat32m1_t acc0_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc0_0, 0), - __riscv_vget_v_f32m2_f32m1(acc0_0, 1), vl); - vfloat32m1_t redsum0 = __riscv_vfredusum_vs_f32m1_f32m1( - acc0_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl); - - vl = __riscv_vsetvlmax_e32m2(); - vfloat32m2_t acc1_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum1, 0), - __riscv_vget_v_f32m4_f32m2(vsum1, 1), vl); - vl = __riscv_vsetvlmax_e32m1(); - vfloat32m1_t acc1_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc1_0, 0), - __riscv_vget_v_f32m2_f32m1(acc1_0, 1), vl); - vfloat32m1_t redsum1 = __riscv_vfredusum_vs_f32m1_f32m1( - acc1_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl); - sumf[0] = __riscv_vfmv_f_s_f32m1_f32(redsum0); - sumf[1] = __riscv_vfmv_f_s_f32m1_f32(redsum1); + vsum0 = __riscv_vfwmacc_vv_f32m4(vsum0, ax0, ay, vl); + vsum1 = __riscv_vfwmacc_vv_f32m4(vsum1, ax1, ay, vl); + } + // reduce + vl = __riscv_vsetvlmax_e32m2(); + vfloat32m2_t acc0_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum0, 0), + __riscv_vget_v_f32m4_f32m2(vsum0, 1), vl); + vl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t acc0_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc0_0, 0), + __riscv_vget_v_f32m2_f32m1(acc0_0, 1), vl); + vfloat32m1_t redsum0 = __riscv_vfredusum_vs_f32m1_f32m1( + acc0_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl); + + vl = __riscv_vsetvlmax_e32m2(); + vfloat32m2_t acc1_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum1, 0), + __riscv_vget_v_f32m4_f32m2(vsum1, 1), vl); + vl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t acc1_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc1_0, 0), + __riscv_vget_v_f32m2_f32m1(acc1_0, 1), vl); + vfloat32m1_t redsum1 = __riscv_vfredusum_vs_f32m1_f32m1( + acc1_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl); + sumf[0] = __riscv_vfmv_f_s_f32m1_f32(redsum0); + sumf[1] = __riscv_vfmv_f_s_f32m1_f32(redsum1); + np = n; + #else + const int np = 0; + #endif #else const int np = (n & ~(GGML_F16_STEP - 1)); @@ -313,21 +317,17 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { GGML_F16_VEC_REDUCE(sumf[k], sum[k]); } - - // leftovers - for (int i = np; i < n; ++i) { - for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { - sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i])); - } - } #endif #else - for (int i = 0; i < n; ++i) { + // scalar path + const int np = 0; +#endif + // scalar and leftovers + for (int i = np; i < n; ++i) { for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i])); } } -#endif for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) { s[i] = (float)sumf[i]; @@ -532,40 +532,45 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, svst1_f16(pg, (__fp16 *)(y + np2), hy); } np = n; -#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic - const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v); - const _Float16 scale = *(const _Float16*)(&s); - - // calculate step size - const int epr = __riscv_vsetvlmax_e16m4(); - const int step = epr * 2; - int np = (n & ~(step - 1)); - - // unroll by 2 - for (int i = 0; i < np; i += step) { - vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr); - vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr); - ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, epr); - __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr); - __asm__ __volatile__ ("" ::: "memory"); - - vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr); - vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr); - ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr); - __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr); - __asm__ __volatile__ ("" ::: "memory"); - } +#elif defined(__riscv_v_intrinsic) // implies __riscv_v_intrinsic + #if defined (__riscv_zvfh) + const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v); + const _Float16 scale = *(const _Float16*)(&s); - // leftovers - int vl; - for (int i = np; i < n; i += vl) { - vl = __riscv_vsetvl_e16m4(n - i); - vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, vl); - vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl); - ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl); - __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl); - } - np = n; + // calculate step size + const int epr = __riscv_vsetvlmax_e16m4(); + const int step = epr * 2; + int np = (n & ~(step - 1)); + + // unroll by 2 + for (int i = 0; i < np; i += step) { + vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr); + vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr); + ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, epr); + __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr); + __asm__ __volatile__ ("" ::: "memory"); + + vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr); + vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr); + ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr); + __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr); + __asm__ __volatile__ ("" ::: "memory"); + } + + // leftovers + int vl; + for (int i = np; i < n; i += vl) { + vl = __riscv_vsetvl_e16m4(n - i); + vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, vl); + vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl); + ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl); + __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl); + } + np = n; + #else + // fall to scalar path + const int np = 0; + #endif #elif defined(GGML_SIMD) const int np = (n & ~(GGML_F16_STEP - 1)); @@ -584,10 +589,11 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, } } #else + // scalar path const int np = 0; #endif - // leftovers + // scalar and leftovers for (int i = np; i < n; ++i) { y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v); } @@ -785,7 +791,7 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float const int ggml_f16_step = 2 * ggml_f16_epr; GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v); - const int np = (n & ~(ggml_f16_step - 1)); + int np = (n & ~(ggml_f16_step - 1)); svfloat16_t ay1, ay2; for (int i = 0; i < np; i += ggml_f16_step) { @@ -805,36 +811,43 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float svfloat16_t out = svmul_f16_m(pg, hy, vx); svst1_f16(pg, (__fp16 *)(y + np), out); } -#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh) - const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v); - const _Float16 scale = *(const _Float16*)(&s); - - // calculate step size - const int epr = __riscv_vsetvlmax_e16m4(); - const int step = epr * 2; - const int np = (n & ~(step - 1)); + np = n; +#elif defined(__riscv_v_intrinsic) + #if defined(__riscv_zvfh) + const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v); + const _Float16 scale = *(const _Float16*)(&s); - // unroll by 2 - for (int i = 0; i < np; i += step) { - vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr); - ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, epr); - __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr); - __asm__ __volatile__ ("" ::: "memory"); + // calculate step size + const int epr = __riscv_vsetvlmax_e16m4(); + const int step = epr * 2; + int np = (n & ~(step - 1)); - vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr); - ay1 = __riscv_vfmul_vf_f16m4(ay1, scale, epr); - __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr); - __asm__ __volatile__ ("" ::: "memory"); - } + // unroll by 2 + for (int i = 0; i < np; i += step) { + vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr); + ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, epr); + __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr); + __asm__ __volatile__ ("" ::: "memory"); + + vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr); + ay1 = __riscv_vfmul_vf_f16m4(ay1, scale, epr); + __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr); + __asm__ __volatile__ ("" ::: "memory"); + } - // leftovers - int vl; - for (int i = np; i < n; i += vl) { - vl = __riscv_vsetvl_e16m4(n - i); - vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl); - ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, vl); - __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl); - } + // leftovers + int vl; + for (int i = np; i < n; i += vl) { + vl = __riscv_vsetvl_e16m4(n - i); + vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl); + ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, vl); + __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl); + } + np = n; + #else + // fall to scalar path + const int np = 0; + #endif #elif defined(GGML_SIMD) const int np = (n & ~(GGML_F16_STEP - 1)); @@ -850,17 +863,14 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); } } - - // leftovers - for (int i = np; i < n; ++i) { - y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v); - } #else - // scalar - for (int i = 0; i < n; ++i) { + // scalar path + const int np = 0; +#endif + // scalar and leftovers + for (int i = np; i < n; ++i) { y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v); } -#endif } inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); } @@ -1026,12 +1036,12 @@ inline static float ggml_gelu_quick_f32(float x) { return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x))); } -//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { -// const uint16_t * i16 = (const uint16_t *) x; -// for (int i = 0; i < n; ++i) { -// y[i] = ggml_table_gelu_quick_f16[i16[i]]; -// } -//} +inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + const uint16_t * i16 = (const uint16_t *) x; + for (int i = 0; i < n; ++i) { + y[i] = ggml_table_gelu_quick_f16[i16[i]]; + } +} #ifdef GGML_GELU_QUICK_FP16 inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) { @@ -1050,13 +1060,6 @@ inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * } #endif -inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { - for (int i = 0; i < n; ++i) { - float v = GGML_CPU_FP16_TO_FP32(x[i]); - y[i] = GGML_CPU_FP32_TO_FP16(v*(1.0f/(1.0f+expf(GELU_QUICK_COEF*v)))); - } -} - // Sigmoid Linear Unit (SiLU) function inline static float ggml_silu_f32(float x) { return x/(1.0f + expf(-x)); diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index 262f88204e0..b54d4a6b107 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -116,12 +116,11 @@ if (CUDAToolkit_FOUND) list(APPEND GGML_SOURCES_CUDA ${SRCS}) add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) else() - file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) - file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) - file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) + list(APPEND GGML_SOURCES_CUDA + template-instances/fattn-vec-instance-f16-f16.cu + template-instances/fattn-vec-instance-q4_0-q4_0.cu + template-instances/fattn-vec-instance-q8_0-q8_0.cu + template-instances/fattn-vec-instance-bf16-bf16.cu) endif() ggml_add_backend_library(ggml-cuda @@ -182,6 +181,16 @@ if (CUDAToolkit_FOUND) target_link_libraries(ggml-cuda PRIVATE CUDA::cuda_driver) endif() + if (GGML_CUDA_NCCL) + find_package(NCCL) + if (NCCL_FOUND) + add_compile_definitions(GGML_USE_NCCL) + target_link_libraries(ggml-cuda PRIVATE NCCL::NCCL) + else() + message(STATUS "Warning: NCCL not found, performance for multiple CUDA GPUs will be suboptimal") + endif() + endif() + set(CUDA_CXX_FLAGS "") set(CUDA_FLAGS -use_fast_math -extended-lambda) diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index 4896669c32a..0f3f017b534 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -47,35 +47,59 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, #ifdef STRIDED_ITERATOR_AVAILABLE auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols); #else - ggml_cuda_pool_alloc offsets_alloc(pool, nrows + 1); + // offset_iterator needs to populate nrows + 1 elements, so we also have to ceildiv nrows + 1 by block_size + const int nrows_offset = nrows + 1; + ggml_cuda_pool_alloc offsets_alloc(pool, nrows_offset); int * offset_iterator = offsets_alloc.get(); - const dim3 offset_grid((nrows + block_size - 1) / block_size); + const dim3 offset_grid((nrows_offset + block_size - 1) / block_size); init_offsets<<>>(offset_iterator, ncols, nrows); #endif CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream)); size_t temp_storage_bytes = 0; + bool is_capturing = false; +#ifdef USE_CUDA_GRAPH + // Currently (confirmed for CCCL <= 3.2) DeviceSegmentedSort does not support stream capture, while DeviceSegmentedRadixSort does. + // See https://github.com/NVIDIA/cccl/issues/5661#issuecomment-3229037149 + // TODO: constrain this to the CCCL versions that have this issue once it's resolved in a future CCCL release. + cudaStreamCaptureStatus capture_status; + CUDA_CHECK(cudaStreamIsCapturing(stream, &capture_status)); + is_capturing = (capture_status != cudaStreamCaptureStatusNone); +#endif // USE_CUDA_GRAPH + if (order == GGML_SORT_ORDER_ASC) { if (nrows == 1) { - DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) - temp_indices, dst, // values (indices) - ncols, 0, sizeof(float) * 8, stream); + CUDA_CHECK(DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream)); + } else if (is_capturing) { + CUDA_CHECK(DeviceSegmentedRadixSort::SortPairs( + nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols * nrows, nrows, // num items, num segments + offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream)); } else { - DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) - temp_indices, dst, // values (indices) - ncols * nrows, nrows, // num items, num segments - offset_iterator, offset_iterator + 1, stream); + CUDA_CHECK(DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, + temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols * nrows, nrows, // num items, num segments + offset_iterator, offset_iterator + 1, stream)); } } else { if (nrows == 1) { - DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) - temp_indices, dst, // values (indices) - ncols, 0, sizeof(float) * 8, stream); + CUDA_CHECK(DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, + temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream)); + } else if (is_capturing) { + CUDA_CHECK(DeviceSegmentedRadixSort::SortPairsDescending( + nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols * nrows, nrows, + offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream)); } else { - DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, - dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1, - stream); + CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, + temp_indices, dst, ncols * nrows, nrows, + offset_iterator, offset_iterator + 1, stream)); } } @@ -84,22 +108,33 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, if (order == GGML_SORT_ORDER_ASC) { if (nrows == 1) { - DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) - temp_indices, dst, // values (indices) - ncols, 0, sizeof(float) * 8, stream); + CUDA_CHECK(DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, + temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream)); + } else if (is_capturing) { + CUDA_CHECK(DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, + temp_indices, dst, ncols * nrows, nrows, offset_iterator, + offset_iterator + 1, 0, sizeof(float) * 8, stream)); } else { - DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, - ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream); + CUDA_CHECK(DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, + temp_indices, dst, ncols * nrows, nrows, offset_iterator, + offset_iterator + 1, stream)); } } else { if (nrows == 1) { - DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) - temp_indices, dst, // values (indices) - ncols, 0, sizeof(float) * 8, stream); + CUDA_CHECK(DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, + temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream)); + } else if (is_capturing) { + CUDA_CHECK(DeviceSegmentedRadixSort::SortPairsDescending( + d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols * nrows, nrows, + offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream)); } else { - DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, - temp_indices, dst, ncols * nrows, nrows, offset_iterator, - offset_iterator + 1, stream); + CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, + temp_keys, temp_indices, dst, ncols * nrows, nrows, + offset_iterator, offset_iterator + 1, stream)); } } } diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index 7339fe0c070..adb4d5f0cb9 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -472,6 +472,36 @@ void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, } } +void ggml_cuda_op_fused_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) { + GGML_ASSERT(2 <= n_fuse && n_fuse <= 8); + + switch (n_fuse) { + case 2: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 3: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 4: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 5: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 6: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 7: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 8: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + default: + GGML_ASSERT(false && "Unsupported n_fuse value"); + } +} + void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; diff --git a/ggml/src/ggml-cuda/binbcast.cuh b/ggml/src/ggml-cuda/binbcast.cuh index 62bc950111b..12624785b44 100644 --- a/ggml/src/ggml-cuda/binbcast.cuh +++ b/ggml/src/ggml-cuda/binbcast.cuh @@ -9,3 +9,4 @@ void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse); +void ggml_cuda_op_fused_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse); diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 36d8a3aaab2..10817505d9f 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -65,8 +65,9 @@ #define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue #define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a #define GGML_CUDA_CC_CDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers -#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing +#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x90a) // MI210 (gfx90a), minimum acc register renaming #define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300 +#define GGML_CUDA_CC_CDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x950) // MI350X/MI355X // RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32 #define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000 @@ -87,7 +88,8 @@ #define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1) #define GGML_CUDA_CC_IS_CDNA1(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_CDNA2) #define GGML_CUDA_CC_IS_CDNA2(cc) (cc >= GGML_CUDA_CC_CDNA2 && cc < GGML_CUDA_CC_CDNA3) -#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1) +#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_CDNA4) +#define GGML_CUDA_CC_IS_CDNA4(cc) (cc >= GGML_CUDA_CC_CDNA4 && cc < GGML_CUDA_CC_RDNA1) // Moore Threads #define MUSART_HMASK 40300 // MUSA rc4.3, min. ver. for half2 -> uint mask comparisons @@ -186,6 +188,10 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in #define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str) +#ifdef GGML_USE_NCCL +#define NCCL_CHECK(err) CUDA_CHECK_GEN(err, ncclSuccess, ncclGetErrorString) +#endif // GGML_USE_NCCL + #if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) static const char * cu_get_error_str(CUresult err) { const char * err_str; @@ -263,10 +269,6 @@ static const char * cu_get_error_str(CUresult err) { #define FLASH_ATTN_AVAILABLE #endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220) -#if defined(TURING_MMA_AVAILABLE) -#define LDMATRIX_TRANS_AVAILABLE -#endif // defined(TURING_MMA_AVAILABLE) - static bool fp16_available(const int cc) { return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL || (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1); @@ -799,6 +801,47 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { #endif // CUDART_VERSION >= 12050 } +static __device__ __forceinline__ float ggml_cuda_ue4m3_to_fp32(uint8_t x) { +#if defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000 + // ROCm does not support fp8 in software on devices with fp8 hardware, + // but CDNA3 supports only e4m3_fnuz (no inf). + const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation. + const __hip_fp8_e4m3_fnuz xf = *reinterpret_cast(&bits); + return static_cast(xf) / 2; +#else +#if defined(FP8_AVAILABLE) && !defined(GGML_USE_HIP) + const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation. + const __nv_fp8_e4m3 xf = *reinterpret_cast(&bits); + return static_cast(xf) / 2; +#else + if (x == 0 || (x == 0x7F && x != 0xFF)) { // Convert NaN to 0.0f + return 0.0f; + } + const int exp = (x >> 3) & 0xF; + const int man = x & 0x7; + float raw; + if (exp == 0) { + raw = ldexpf((float) man, -9); + } else { + raw = ldexpf(1.0f + (float) man / 8.0f, exp - 7); + } + return static_cast(raw / 2); +#endif // defined(FP8_AVAILABLE) && !defined(GGML_USE_HIP) +#endif // defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000 +} + +static __device__ __forceinline__ uint8_t ggml_cuda_fp32_to_ue4m3(float x) { +#if defined(BLACKWELL_MMA_AVAILABLE) // This is used for NVFP4 subblock scale quantizations only + if (!(x > 0.0f)) { + return 0; + } + const __nv_fp8_e4m3 xf(x); + return xf.__x; +#else + NO_DEVICE_CODE; // Used only for NVFP4 Scales for Activations, only for Blackwell +#endif // defined(BLACKWELL_MMA_AVAILABLE) +} + __device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) { const uint8_t sign_bit = (x < 0.0f) << 3; float ax = fabsf(x) * e; @@ -889,6 +932,13 @@ struct ggml_cuda_type_traits { static constexpr int qr = 1; }; +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK1_0; + static constexpr int qr = QR1_0; + static constexpr int qi = QI1_0; +}; + template<> struct ggml_cuda_type_traits { static constexpr int qk = QK4_0; @@ -931,6 +981,13 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI_MXFP4; }; +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_NVFP4; + static constexpr int qr = QR_NVFP4; + static constexpr int qi = QI_NVFP4; +}; + template<> struct ggml_cuda_type_traits { static constexpr int qk = QK_K; @@ -1121,19 +1178,6 @@ struct ggml_tensor_extra_gpu { #define USE_CUDA_GRAPH #endif -struct ggml_cuda_graph_node_properties { - void * node_data; - ggml_op node_op; - enum ggml_type node_type; - int32_t flags; - int64_t ne[GGML_MAX_DIMS]; - size_t nb[GGML_MAX_DIMS]; - void * src_data[GGML_MAX_SRC]; - int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; -}; - -static_assert(std::is_trivial::value, "ggml_cuda_graph_node_properties must be trivial"); - struct ggml_cuda_graph { #ifdef USE_CUDA_GRAPH ~ggml_cuda_graph() { @@ -1150,13 +1194,15 @@ struct ggml_cuda_graph { std::vector nodes; bool disable_due_to_gpu_arch = false; bool warmup_complete = false; - std::vector props; - - // these are extra tensors (inputs) that participate in the ggml graph but are not nodes - // they properties also have to match in order to be able to safely reuse a CUDA graph - // ref: https://github.com/ggml-org/llama.cpp/pull/18583 - // ref: https://github.com/ggml-org/llama.cpp/pull/19165 - std::vector extra; + uint64_t uid = 0; + int64_t last_used_time = 0; + struct node_properties { + ggml_tensor node; + void * node_src_data_ptrs[GGML_MAX_SRC]; + int64_t node_src_ne[GGML_MAX_SRC][GGML_MAX_DIMS]; + size_t node_src_nb[GGML_MAX_SRC][GGML_MAX_DIMS]; + }; + std::vector node_props; bool is_enabled() const { static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); @@ -1331,12 +1377,28 @@ struct ggml_backend_cuda_context { // when the computation is split across CPU/GPU (e.g., with --n-cpu-moe) std::unordered_map> cuda_graphs; + int64_t last_graph_eviction_sweep = 0; + ggml_cuda_graph * cuda_graph(const void * first_node_ptr) { + const int64_t time_now = ggml_time_us(); + + // sweep every 5s, evicting cuda graphs unused for >=10s + if (time_now - last_graph_eviction_sweep >= 5'000'000) { + last_graph_eviction_sweep = time_now; + for (auto it = cuda_graphs.begin(); it != cuda_graphs.end(); ) { + if (time_now - it->second->last_used_time >= 10'000'000) { + it = cuda_graphs.erase(it); + } else { + ++it; + } + } + } + auto it = cuda_graphs.find(first_node_ptr); if (it == cuda_graphs.end()) { - cuda_graphs[first_node_ptr] = std::make_unique(); - return cuda_graphs[first_node_ptr].get(); + it = cuda_graphs.emplace(first_node_ptr, std::make_unique()).first; } + it->second->last_used_time = time_now; return it->second.get(); } diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu index e9ffd274b99..102f944f924 100644 --- a/ggml/src/ggml-cuda/concat.cu +++ b/ggml/src/ggml-cuda/concat.cu @@ -1,96 +1,79 @@ #include "concat.cuh" // contiguous kernels -static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) { - int nidx = threadIdx.x + blockIdx.x * blockDim.x; - if (nidx >= ne0) { - return; - } - - int offset_dst = - nidx + - blockIdx.y * ne0 + - blockIdx.z * ne0 * gridDim.y; - - if (nidx < ne00) { // src0 - int offset_src = - nidx + - blockIdx.y * ne00 + - blockIdx.z * ne00 * gridDim.y; - dst[offset_dst] = x[offset_src]; - } else { - int offset_src = - (nidx - ne00) + - blockIdx.y * (ne0 - ne00) + - blockIdx.z * (ne0 - ne00) * gridDim.y; - dst[offset_dst] = y[offset_src]; - } -} - -static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) { - int nidx = threadIdx.x + blockIdx.x * blockDim.x; - if (nidx >= ne0) { - return; - } +template +static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_f32_cont(const float * x, + const float * y, + float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne0, + int64_t ne1, + int64_t ne2) { + static_assert(dim >= 0 && dim <= 2, "dim must be in [0, 2]"); + + const int64_t n = ne0 * ne1 * ne2; + + for (int64_t i = (int64_t) blockIdx.x * blockDim.x + threadIdx.x; i < n; i += (int64_t) blockDim.x * gridDim.x) { + if constexpr (dim == 0) { + const int64_t row = i / ne0; + const int64_t i0 = i - row * ne0; + + if (i0 < ne00) { + dst[i] = x[row * ne00 + i0]; + } else { + dst[i] = y[row * (ne0 - ne00) + (i0 - ne00)]; + } + } else if constexpr (dim == 1) { + const int64_t dst_plane = ne0 * ne1; + const int64_t src0_plane = ne0 * ne01; + const int64_t src1_plane = dst_plane - src0_plane; + const int64_t i2 = i / dst_plane; + const int64_t i01 = i - i2 * dst_plane; + + if (i01 < src0_plane) { + dst[i] = x[i2 * src0_plane + i01]; + } else { + dst[i] = y[i2 * src1_plane + (i01 - src0_plane)]; + } + } else { + const int64_t src0_size = ne0 * ne1 * ne02; - int offset_dst = - nidx + - blockIdx.y * ne0 + - blockIdx.z * ne0 * gridDim.y; - - if (blockIdx.y < (unsigned)ne01) { // src0 - int offset_src = - nidx + - blockIdx.y * ne0 + - blockIdx.z * ne0 * ne01; - dst[offset_dst] = x[offset_src]; - } else { - int offset_src = - nidx + - (blockIdx.y - ne01) * ne0 + - blockIdx.z * ne0 * (gridDim.y - ne01); - dst[offset_dst] = y[offset_src]; + if (i < src0_size) { + dst[i] = x[i]; + } else { + dst[i] = y[i - src0_size]; + } + } } } -static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) { - int nidx = threadIdx.x + blockIdx.x * blockDim.x; - if (nidx >= ne0) { - return; - } - - int offset_dst = - nidx + - blockIdx.y * ne0 + - blockIdx.z * ne0 * gridDim.y; - - if (blockIdx.z < (unsigned)ne02) { // src0 - int offset_src = - nidx + - blockIdx.y * ne0 + - blockIdx.z * ne0 * gridDim.y; - dst[offset_dst] = x[offset_src]; - } else { - int offset_src = - nidx + - blockIdx.y * ne0 + - (blockIdx.z - ne02) * ne0 * gridDim.y; - dst[offset_dst] = y[offset_src]; - } -} +static void concat_f32_cuda(const float * x, + const float * y, + float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int dim, + cudaStream_t stream) { + const int64_t n = ne0 * ne1 * ne2; + const int num_blocks = (n + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE; -static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) { - int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE; - dim3 gridDim(num_blocks, ne1, ne2); if (dim == 0) { - concat_f32_dim0<<>>(x, y, dst, ne0, ne00); + concat_f32_cont<0> + <<>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); return; } if (dim == 1) { - concat_f32_dim1<<>>(x, y, dst, ne0, ne01); + concat_f32_cont<1> + <<>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); return; } - concat_f32_dim2<<>>(x, y, dst, ne0, ne02); + concat_f32_cont<2><<>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); } // non-contiguous kernel (slow) diff --git a/ggml/src/ggml-cuda/conv2d-transpose.cu b/ggml/src/ggml-cuda/conv2d-transpose.cu index 03224e404d3..6cbd6f879e6 100644 --- a/ggml/src/ggml-cuda/conv2d-transpose.cu +++ b/ggml/src/ggml-cuda/conv2d-transpose.cu @@ -1,12 +1,20 @@ -#include - #include "conv2d-transpose.cuh" -#include "ggml.h" - -__global__ void conv2d_transpose_kernel(const float * __restrict__ input, const half * __restrict__ kernel, - float * __restrict__ output, const int in_w, const int in_h, const int out_w, - const int out_h, const int kernel_w, const int kernel_h, const int stride, - const int c_in, const int c_out, const int batches) { +#include "convert.cuh" + +template +static __global__ void conv2d_transpose_kernel(const float * __restrict__ input, + const kernel_t * __restrict__ kernel, + float * __restrict__ output, + const int in_w, + const int in_h, + const int out_w, + const int out_h, + const int kernel_w, + const int kernel_h, + const int stride, + const int c_in, + const int c_out, + const int batches) { const int global_idx = blockIdx.x * blockDim.x + threadIdx.x; const int total_elements = out_w * out_h * c_out * batches; @@ -26,24 +34,32 @@ __global__ void conv2d_transpose_kernel(const float * __restrict__ input, const for (int c_in_idx = 0; c_in_idx < c_in; c_in_idx++) { for (int kh = 0; kh < kernel_h; ++kh) { int in_y = out_y_idx - kh; - if (in_y < 0 || in_y % stride) continue; + if (in_y < 0 || in_y % stride) { + continue; + } in_y /= stride; - if (in_y >= in_h) continue; + if (in_y >= in_h) { + continue; + } for (int kw = 0; kw < kernel_w; ++kw) { int in_x = out_x_idx - kw; - if (in_x < 0 || in_x % stride) continue; + if (in_x < 0 || in_x % stride) { + continue; + } in_x /= stride; - if (in_x >= in_w) continue; + if (in_x >= in_w) { + continue; + } const int input_idx = (in_w * in_h * c_in) * n_idx + (in_w * in_h) * c_in_idx + (in_w) *in_y + in_x; const int kernel_idx = (kernel_h * kernel_w * c_out) * c_in_idx + (kernel_h * kernel_w) * c_idx + (kernel_w) *kh + kw; - float input_val = input[input_idx]; - half kern_val = kernel[kernel_idx]; + float input_val = input[input_idx]; + kernel_t kern_val = kernel[kernel_idx]; - accumulator += input_val * (float) kern_val; + accumulator += input_val * ggml_cuda_cast(kern_val); } } } @@ -56,11 +72,12 @@ void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor const ggml_tensor * kernel = dst->src[0]; const ggml_tensor * input = dst->src[1]; - GGML_ASSERT(kernel->type == GGML_TYPE_F16 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); + GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32); + GGML_ASSERT(input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); const float * input_data = (const float *) input->data; float * output_data = (float *) dst->data; - const half * kernel_data = (const half *) kernel->data; + const void * kernel_data = kernel->data; const int input_w = input->ne[0]; const int input_h = input->ne[1]; @@ -82,10 +99,17 @@ void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor GGML_ASSERT(ggml_is_contiguous(kernel)); GGML_ASSERT(ggml_is_contiguous(dst)); - const int total = (output_w * output_h * channels_out * batches); + const int total = output_w * output_h * channels_out * batches; const int blocks = (total + CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE; - conv2d_transpose_kernel<<>>( - input_data, kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, kernel_h, stride, - channels_in, channels_out, batches); + if (kernel->type == GGML_TYPE_F16) { + conv2d_transpose_kernel<<>>( + input_data, (const half *) kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, + kernel_h, stride, channels_in, channels_out, batches); + + } else { + conv2d_transpose_kernel<<>>( + input_data, (const float *) kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, + kernel_h, stride, channels_in, channels_out, batches); + } } diff --git a/ggml/src/ggml-cuda/conv2d-transpose.cuh b/ggml/src/ggml-cuda/conv2d-transpose.cuh index c9430b24850..72889c5f0fa 100644 --- a/ggml/src/ggml-cuda/conv2d-transpose.cuh +++ b/ggml/src/ggml-cuda/conv2d-transpose.cuh @@ -1,4 +1,5 @@ #include "common.cuh" #define CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE 256 + void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index b70492c7d6c..61630a35a29 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -617,6 +617,45 @@ static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t dequantize_block_mxfp4<<>>(vx, y); } +template +static __global__ void dequantize_block_nvfp4( + const void * __restrict__ vx, + dst_t * __restrict__ yy, + const int64_t ne) { + const int64_t i = blockIdx.x; + const int tid = threadIdx.x; + + const int64_t base = i * QK_NVFP4; + if (base >= ne) { + return; + } + + const block_nvfp4 * x = (const block_nvfp4 *) vx; + const block_nvfp4 & xb = x[i]; + + const int sub = tid / (QK_NVFP4_SUB / 2); + const int j = tid % (QK_NVFP4_SUB / 2); + + const float d = ggml_cuda_ue4m3_to_fp32(xb.d[sub]); + const uint8_t q = xb.qs[sub * (QK_NVFP4_SUB / 2) + j]; + + const int64_t y0 = base + sub * QK_NVFP4_SUB + j; + const int64_t y1 = y0 + QK_NVFP4_SUB / 2; + + yy[y0] = ggml_cuda_cast(d * kvalues_mxfp4[q & 0x0F]); + yy[y1] = ggml_cuda_cast(d * kvalues_mxfp4[q >> 4]); +} + +template +static void dequantize_row_nvfp4_cuda( + const void * vx, + dst_t * y, + const int64_t k, + cudaStream_t stream) { + GGML_ASSERT(k % QK_NVFP4 == 0); + const int nb = k / QK_NVFP4; + dequantize_block_nvfp4<<>>(vx, y, k); +} template static __global__ void convert_unary( const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, @@ -672,6 +711,8 @@ to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) { to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { switch (type) { + case GGML_TYPE_Q1_0: + return dequantize_block_cont_cuda; case GGML_TYPE_Q4_0: return dequantize_row_q4_0_cuda; case GGML_TYPE_Q4_1: @@ -715,6 +756,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq3_s_cuda; case GGML_TYPE_MXFP4: return dequantize_row_mxfp4_cuda; + case GGML_TYPE_NVFP4: + return dequantize_row_nvfp4_cuda; case GGML_TYPE_F32: return convert_unary_cont_cuda; case GGML_TYPE_BF16: @@ -726,6 +769,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { switch (type) { + case GGML_TYPE_Q1_0: + return dequantize_block_cont_cuda; case GGML_TYPE_Q4_0: return dequantize_row_q4_0_cuda; case GGML_TYPE_Q4_1: @@ -766,6 +811,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq3_s_cuda; case GGML_TYPE_MXFP4: return dequantize_row_mxfp4_cuda; + case GGML_TYPE_NVFP4: + return dequantize_row_nvfp4_cuda; case GGML_TYPE_F16: return convert_unary_cont_cuda; case GGML_TYPE_BF16: @@ -779,6 +826,8 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) { switch (type) { case GGML_TYPE_F32: return convert_unary_cuda; + case GGML_TYPE_Q1_0: + return dequantize_block_cuda; case GGML_TYPE_Q4_0: return dequantize_block_cuda; case GGML_TYPE_Q4_1: @@ -800,6 +849,8 @@ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) { switch (type) { case GGML_TYPE_F32: return convert_unary_cuda; + case GGML_TYPE_Q1_0: + return dequantize_block_cuda; case GGML_TYPE_Q4_0: return dequantize_block_cuda; case GGML_TYPE_Q4_1: @@ -821,6 +872,8 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) { switch (type) { case GGML_TYPE_F16: return convert_unary_cuda; + case GGML_TYPE_Q1_0: + return dequantize_block_cuda; case GGML_TYPE_Q4_0: return dequantize_block_cuda; case GGML_TYPE_Q4_1: diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index 09f9a33f909..f5d37c7b998 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -41,6 +41,16 @@ template return __bfloat162float(x); } else if constexpr(std::is_same_v && std::is_same_v) { return __float22half2_rn(x); + } else if constexpr(std::is_same_v && std::is_same_v) { +#ifdef GGML_USE_HIP + return make_float2(__bfloat162float(__low2bfloat16(x)), __bfloat162float(__high2bfloat16(x))); +#else +#if __CUDA_ARCH__ >= 800 + return __bfloat1622float2(x); +#else + return make_float2(__bfloat162float(x.x), __bfloat162float(x.y)); +#endif // __CUDA_ARCH__ >= 800 +#endif // GGML_USE_HIP } else if constexpr(std::is_same_v && std::is_same_v) { // bypass compile error on cuda 12.0.1 #ifdef GGML_USE_HIP diff --git a/ggml/src/ggml-cuda/dequantize.cuh b/ggml/src/ggml-cuda/dequantize.cuh index e060fb29fdc..9ae1342fc0e 100644 --- a/ggml/src/ggml-cuda/dequantize.cuh +++ b/ggml/src/ggml-cuda/dequantize.cuh @@ -1,5 +1,27 @@ #include "common.cuh" +static __device__ __forceinline__ void dequantize_q1_0(const void * vx, const int64_t ib, const int iqs, float2 & v){ + const block_q1_0 * x = (const block_q1_0 *) vx; + + const float d = x[ib].d; + + const int bit_index_0 = iqs; + const int bit_index_1 = iqs + 1; + + const int byte_index_0 = bit_index_0 / 8; + const int bit_offset_0 = bit_index_0 % 8; + + const int byte_index_1 = bit_index_1 / 8; + const int bit_offset_1 = bit_index_1 % 8; + + // Extract bits: 1 = +d, 0 = -d (branchless) + const int bit_0 = (x[ib].qs[byte_index_0] >> bit_offset_0) & 1; + const int bit_1 = (x[ib].qs[byte_index_1] >> bit_offset_1) & 1; + + v.x = (2*bit_0 - 1) * d; + v.y = (2*bit_1 - 1) * d; +} + static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){ const block_q4_0 * x = (const block_q4_0 *) vx; diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index e9abdf288c4..beeb5238946 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -74,6 +74,37 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16( return sum; } +template +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_bf16( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) { + + const nv_bfloat162 * K_bf16 = (const nv_bfloat162 *) K_c; + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) { + __align__(16) nv_bfloat162 tmp[cpy_ne]; + ggml_cuda_memcpy_1(tmp, K_bf16 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) { +#ifdef V_DOT2_F32_F16_AVAILABLE + // FIXME replace macros in vector FA kernel with templating and use FP32 for BF16 + ggml_cuda_mad(sum, ggml_cuda_cast(tmp[k_KQ_1]), __half22float2(((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1])); +#else + ggml_cuda_mad(sum, ggml_cuda_cast(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]); +#endif // V_DOT2_F32_F16_AVAILABLE + } + } + + return sum; +} + template static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -321,6 +352,19 @@ static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict_ } } +template +static __device__ __forceinline__ void dequantize_V_bf16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + static_assert(std::is_same_v, "BF16 V dequantization only supports float output"); + static_assert(ne % 2 == 0, "bad ne"); + __align__(16) nv_bfloat162 tmp[ne/2]; + ggml_cuda_memcpy_1(tmp, (const nv_bfloat16 *) vx + i0); + float2 * dst_f2 = (float2 *) dst; +#pragma unroll + for (int l = 0; l < ne/2; ++l) { + dst_f2[l] = ggml_cuda_cast(tmp[l]); + } +} + template static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { const block_q4_0 * x = (const block_q4_0 *) vx; @@ -547,6 +591,8 @@ constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() { return vec_dot_fattn_vec_KQ_q5_1; } else if constexpr (type_K == GGML_TYPE_Q8_0) { return vec_dot_fattn_vec_KQ_q8_0; + } else if constexpr (type_K == GGML_TYPE_BF16) { + return vec_dot_fattn_vec_KQ_bf16; } else { static_assert(type_K == -1, "bad type"); return nullptr; @@ -567,6 +613,8 @@ constexpr __device__ dequantize_V_t get_dequantize_V() { return dequantize_V_q5_1; } else if constexpr (type_V == GGML_TYPE_Q8_0) { return dequantize_V_q8_0; + } else if constexpr (type_V == GGML_TYPE_BF16) { + return dequantize_V_bf16; } else { static_assert(type_V == -1, "bad type"); return nullptr; @@ -628,9 +676,96 @@ static __global__ void flash_attn_mask_to_KV_max( template // D == head size __launch_bounds__(D, 1) -static __global__ void flash_attn_stream_k_fixup( - float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, - const int ne11, const int ne12, const int nbatch_fa) { +static __global__ void flash_attn_stream_k_fixup_uniform( + float * __restrict__ dst, + const float2 * __restrict__ dst_fixup, + const int ne01, const int ne02, + const int ne12, const int nblocks_stream_k, + const int gqa_ratio, + const int blocks_per_tile, + const uint3 fd_iter_j_z_ne12, + const uint3 fd_iter_j_z, + const uint3 fd_iter_j) { + constexpr int ncols = ncols1*ncols2; + + const int tile_idx = blockIdx.x; // One block per output tile. + const int j = blockIdx.y; + const int c = blockIdx.z; + const int jc = j*ncols2 + c; + const int tid = threadIdx.x; + + // nblocks_stream_k is a multiple of ntiles_dst (== gridDim.x), so each tile gets the same number of blocks. + const int b_first = tile_idx * blocks_per_tile; + const int b_last = b_first + blocks_per_tile - 1; + + const float * dst_fixup_data = ((const float *) dst_fixup) + nblocks_stream_k*(2*2*ncols); + + // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index + const uint2 dm0 = fast_div_modulo(tile_idx, fd_iter_j_z_ne12); + const uint2 dm1 = fast_div_modulo(dm0.y, fd_iter_j_z); + const uint2 dm2 = fast_div_modulo(dm1.y, fd_iter_j); + + const int sequence = dm0.x; + const int z_KV = dm1.x; + const int zt_gqa = dm2.x; + const int jt = dm2.y; + + const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index. + + if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) { + return; + } + + dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid; + + // Load the partial result that needs a fixup + float dst_val = *dst; + float max_val; + float rowsum; + { + const float2 tmp = dst_fixup[b_last*ncols + jc]; + max_val = tmp.x; + rowsum = tmp.y; + } + + // Combine with all previous blocks in this tile. + for (int bidx = b_last - 1; bidx >= b_first; --bidx) { + const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid]; + + const float2 tmp = dst_fixup[(nblocks_stream_k + bidx)*ncols + jc]; + + const float max_val_new = fmaxf(max_val, tmp.x); + + const float diff_val = max_val - max_val_new; + const float diff_add = tmp.x - max_val_new; + + const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f; + const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f; + + dst_val = scale_val*dst_val + scale_add*dst_add; + rowsum = scale_val*rowsum + scale_add*tmp.y; + + max_val = max_val_new; + } + + // Write back final result: + *dst = dst_val / rowsum; +} + +// General fixup kernel for the case where the number of blocks per tile is not uniform across tiles +// (blocks_num.x not a multiple of ntiles_dst) +template // D == head size +__launch_bounds__(D, 1) +static __global__ void flash_attn_stream_k_fixup_general( + float * __restrict__ dst, + const float2 * __restrict__ dst_fixup, + const int ne01, const int ne02, + const int gqa_ratio, + const int total_work, + const uint3 fd_iter_k_j_z_ne12, + const uint3 fd_iter_k_j_z, + const uint3 fd_iter_k_j, + const uint3 fd_iter_k) { constexpr int ncols = ncols1*ncols2; const int bidx0 = blockIdx.x; @@ -641,27 +776,26 @@ static __global__ void flash_attn_stream_k_fixup( const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols); - const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - - const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; - const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; - const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2; - - const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; - const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; + const int kbc0 = int64_t(bidx0 + 0)*total_work / gridDim.x; + const int kbc0_stop = int64_t(bidx0 + 1)*total_work / gridDim.x; const bool did_not_have_any_data = kbc0 == kbc0_stop; - const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; - const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0; + const bool wrote_beginning_of_tile = fastmodulo(kbc0, fd_iter_k) == 0; + const bool did_not_write_last = fastdiv(kbc0, fd_iter_k) == fastdiv(kbc0_stop, fd_iter_k) && fastmodulo(kbc0_stop, fd_iter_k) != 0; if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) { return; } // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index - const int sequence = kbc0 /(iter_k*iter_j*iter_z_gqa*ne12); - const int z_KV = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa); - const int zt_gqa = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j); - const int jt = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k; + const uint2 dm0 = fast_div_modulo(kbc0, fd_iter_k_j_z_ne12); + const uint2 dm1 = fast_div_modulo(dm0.y, fd_iter_k_j_z); + const uint2 dm2 = fast_div_modulo(dm1.y, fd_iter_k_j); + const uint2 dm3 = fast_div_modulo(dm2.y, fd_iter_k); + + const int sequence = dm0.x; + const int z_KV = dm1.x; + const int zt_gqa = dm2.x; + const int jt = dm3.x; const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index. @@ -685,10 +819,11 @@ static __global__ void flash_attn_stream_k_fixup( // Iterate over previous blocks and compute the combined results. // All CUDA blocks that get here must have a previous block that needs a fixup. + const int tile_kbc0 = fastdiv(kbc0, fd_iter_k); int bidx = bidx0 - 1; int kbc_stop = kbc0; while(true) { - const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; + const int kbc = int64_t(bidx)*total_work / gridDim.x; if (kbc == kbc_stop) { // Did not have any data. bidx--; kbc_stop = kbc; @@ -714,7 +849,7 @@ static __global__ void flash_attn_stream_k_fixup( max_val = max_val_new; // If this block started in a previous tile we are done and don't need to combine additional partial results. - if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) { + if (fastmodulo(kbc, fd_iter_k) == 0 || fastdiv(kbc, fd_iter_k) < tile_kbc0) { break; } bidx--; @@ -928,14 +1063,28 @@ void launch_fattn( const int tiles_nwaves = (ntiles_dst + max_blocks - 1) / max_blocks; const int tiles_efficiency_percent = 100 * ntiles_dst / (max_blocks*tiles_nwaves); - const int nblocks_stream_k = std::min(max_blocks, ntiles_KV*ntiles_dst); - const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75; - blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_dst; + blocks_num.x = ntiles_dst; blocks_num.y = 1; blocks_num.z = 1; + if(use_stream_k) { + const int nblocks_stream_k_raw = std::min(max_blocks, ntiles_KV*ntiles_dst); + // Round down to a multiple of ntiles_dst so that each output tile gets the same number of blocks (avoids fixup). + // Only do this if the occupancy loss from rounding is acceptable. + const int nblocks_stream_k_rounded = (nblocks_stream_k_raw / ntiles_dst) * ntiles_dst; + const int max_efficiency_loss_percent = 5; + const int efficiency_loss_percent = nblocks_stream_k_rounded > 0 + ? 100 * (nblocks_stream_k_raw - nblocks_stream_k_rounded) / nblocks_stream_k_raw + : 100; + const int nblocks_stream_k = efficiency_loss_percent <= max_efficiency_loss_percent + ? nblocks_stream_k_rounded + : nblocks_stream_k_raw; + + blocks_num.x = nblocks_stream_k; + } + if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2))); } @@ -1015,13 +1164,40 @@ void launch_fattn( CUDA_CHECK(cudaGetLastError()); if (stream_k) { - if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. + if ((int)blocks_num.x % ntiles_dst == 0 && (int)blocks_num.x > ntiles_dst) { + // Optimized fixup: nblocks_stream_k is a multiple of ntiles_dst, launch one block per tile. + const int nblocks_sk = (int)blocks_num.x; + const int bpt = nblocks_sk / ntiles_dst; + + const uint3 fd0 = init_fastdiv_values(ntiles_x * ntiles_z_gqa * K->ne[2]); + const uint3 fd1 = init_fastdiv_values(ntiles_x * ntiles_z_gqa); + const uint3 fd2 = init_fastdiv_values(ntiles_x); + + const dim3 block_dim_combine(DV, 1, 1); + const dim3 blocks_num_combine = {(unsigned)ntiles_dst, ncols1, ncols2}; + + flash_attn_stream_k_fixup_uniform + <<>> + ((float *) KQV->data, dst_tmp_meta.ptr, + Q->ne[1], Q->ne[2], K->ne[2], nblocks_sk, + gqa_ratio, bpt, fd0, fd1, fd2); + } else if (ntiles_dst % blocks_num.x != 0) { + // General fixup for the cases where nblocks_stream_k < ntiles_dst. + const int total_work = ntiles_KV * ntiles_dst; + + const uint3 fd_k_j_z_ne12 = init_fastdiv_values(ntiles_KV * ntiles_x * ntiles_z_gqa * K->ne[2]); + const uint3 fd_k_j_z = init_fastdiv_values(ntiles_KV * ntiles_x * ntiles_z_gqa); + const uint3 fd_k_j = init_fastdiv_values(ntiles_KV * ntiles_x); + const uint3 fd_k = init_fastdiv_values(ntiles_KV); + const dim3 block_dim_combine(DV, 1, 1); const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2}; - flash_attn_stream_k_fixup + flash_attn_stream_k_fixup_general <<>> - ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa); + ((float *) KQV->data, dst_tmp_meta.ptr, + Q->ne[1], Q->ne[2], gqa_ratio, total_work, + fd_k_j_z_ne12, fd_k_j_z, fd_k_j, fd_k); } } else if (parallel_blocks > 1) { const dim3 block_dim_combine(DV, 1, 1); diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index fff70c8eb89..3f01e858de7 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -66,6 +66,14 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 32, 128, 128, 128, 1, false); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); @@ -80,6 +88,14 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 32, 128, 128, 128, 1, false); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); @@ -89,6 +105,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co } static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false); @@ -103,6 +124,13 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 64, 160, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 128, 2, 64, 160, 128, 64, 2, false); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); @@ -286,12 +314,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) { constexpr int warp_size = ggml_cuda_get_physical_warp_size(); // K/V data is loaded with decreasing granularity for D for better memory bandwidth. - // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. + // The minimum granularity is 16 bytes. + constexpr int h2_per_chunk = 16/sizeof(half2); + const int chunks_per_row = D2 / h2_per_chunk; if constexpr (use_cp_async) { + static_assert(warp_size == 32, "bad warp_size"); static_assert(!oob_check, "OOB check not compatible with cp_async"); constexpr int preload = 64; - constexpr int h2_per_chunk = 16/sizeof(half2); - const int chunks_per_row = D2 / h2_per_chunk; const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV); @@ -329,11 +358,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( // 6: max 1*16= 16 bytes, 8 half ggml_cuda_unroll<6>{}(load); } else { - // TODO use ggml_cuda_memcpy_1 + const half2 zero[4] = {{0.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}}; auto load = [&] __device__ (const int n) { - const int stride_k = warp_size >> n; - const int k0_start = stride_k == warp_size ? 0 : D2 - D2 % (2*stride_k); - const int k0_stop = D2 - D2 % (1*stride_k); + const int stride_k = 32 >> n; + const int k0_start = stride_k == 32 ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); + const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k); const int stride_i = warp_size / stride_k; if (k0_start == k0_stop) { @@ -352,15 +381,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k); - tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f); + ggml_cuda_memcpy_1<16>(tile_KV + i*stride_tile + k*4, + !oob_check || i < i_sup ? KV + i*stride_KV + k*h2_per_chunk : zero); } } }; - // 1: max 32* 4=128 bytes, 64 half - // 2: max 16* 4= 64 bytes, 32 half - // 3: max 8* 4= 32 bytes, 16 half - // 4: max 4* 4= 16 bytes, 8 half - ggml_cuda_unroll<4>{}(load); + // 1: max 32*16=512 bytes, 256 half + // 2: max 16*16=256 bytes, 128 half + // 3: max 8*16=128 bytes, 64 half + // 4: max 4*16= 64 bytes, 32 half + // 5: max 2*16= 32 bytes, 16 half + // 6: max 1*16= 16 bytes, 8 half + ggml_cuda_unroll<6>{}(load); } } @@ -843,11 +875,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } -#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE) - T_A_VKQ A_identity; - make_identity_mat(A_identity); -#endif // defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE) - // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V: #pragma unroll for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) { @@ -878,29 +905,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J; T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load. -#if defined(LDMATRIX_TRANS_AVAILABLE) load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); -#elif defined(AMD_MFMA_AVAILABLE) - // MFMA A register layout: A_mat[i=lane%16][k=4*(lane/16)+reg]. - // Normal load gives A_mat[seq][dv] but we need A_mat[dv][seq] = V^T. - // Load with transposed addressing: 4 strided half loads. - { - const half2 * xs0 = tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2; - const half * xs0_h = (const half *) xs0; - const int stride_h = stride_tile_V * 2; // stride in half units - half * A_h = (half *) A.x; -#pragma unroll - for (int l = 0; l < 4; ++l) { - A_h[l] = xs0_h[(4*(threadIdx.x / 16) + l) * stride_h + threadIdx.x % 16]; - } - } -#else - // TODO: Try to transpose tile_V when loading gmem to smem. - // Use mma to transpose T_A_VKQ for RDNA. - T_A_VKQ A_trans; - load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); - mma(A, A_trans, A_identity); -#endif // defined(LDMATRIX_TRANS_AVAILABLE) if constexpr (T_B_KQ::I == 8) { mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); } else { @@ -1221,7 +1226,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( float KQ_max_scale[cols_per_thread]; #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { - const int jc = cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col); + const int jc = (threadIdx.y/np)*cols_per_warp + (cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col)); const float sink = sinks_f[jc % ncols2]; const float KQ_max_new = fmaxf(KQ_max[col], sink); @@ -1552,7 +1557,7 @@ static __global__ void flash_attn_ext_f16( #if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)) // Skip unused kernel variants for faster compilation: - if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { + if (use_logit_softcap && !(DKQ == 128 || DKQ == 256 || DKQ == 512)) { NO_DEVICE_CODE; return; } @@ -1815,11 +1820,24 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 64) DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64) DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64) +extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 4); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 4); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 4); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 16, 4); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8); + // The number of viable configurations for Deepseek is very limited: extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); +// Mistral Small 4 (DKQ=320, DV=256), GQA=32-only build: +extern DECL_FATTN_MMA_F16_CASE(320, 256, 1, 32); +extern DECL_FATTN_MMA_F16_CASE(320, 256, 2, 32); + // For GLM 4.7 Flash extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu index 3fcb09b7a2b..d60634cc0e9 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cu +++ b/ggml/src/ggml-cuda/fattn-tile.cu @@ -38,6 +38,14 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst); } break; + case 320: { + GGML_ASSERT(V->ne[0] == 256); + ggml_cuda_flash_attn_ext_tile_case<320, 256>(ctx, dst); + } break; + case 512: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case<512, 512>(ctx, dst); + } break; case 576: { GGML_ASSERT(V->ne[0] == 512); ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst); diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index f3fa80ab23d..585f2c22853 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -68,6 +68,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 16, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) @@ -124,6 +130,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 16, 256, 2, 32, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64) @@ -187,6 +199,13 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 512, 1, 128, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 32, 512, 1, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) @@ -251,6 +270,13 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 128, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 4, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 32, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64) @@ -767,7 +793,7 @@ static __global__ void flash_attn_tile( #ifdef GGML_USE_WMMA_FATTN (ncols2 != 1 && DV != 40 && DV != 72 && DV != 512) || #endif // GGML_USE_WMMA_FATTN - (use_logit_softcap && !(DV == 128 || DV == 256)) + (use_logit_softcap && !(DV == 128 || DV == 256 || DV == 512)) ) { GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, max_bias, m0, m1, n_head_log2, logit_softcap, @@ -1098,7 +1124,7 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm constexpr size_t nbytes_shared = 0; #ifdef GGML_USE_HIP - if constexpr (DV <= 128) { + if constexpr (DKQ <= 128) { if (Q->ne[1] > 32/ncols2) { constexpr int cols_per_block = 64; const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; @@ -1112,7 +1138,7 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm #endif // GGML_USE_HIP #ifndef GGML_USE_HIP - if constexpr (DV <= 256) + if constexpr (DKQ <= 256) #endif // GGML_USE_HIP { if (Q->ne[1] > 16/ncols2) { @@ -1126,14 +1152,16 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm } } - if (Q->ne[1] > 8/ncols2) { - constexpr int cols_per_block = 16; - const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; - const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - fattn_kernel_t fattn_kernel = flash_attn_tile; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); - return; + if constexpr (ncols2 <= 16) { + if (Q->ne[1] > 8/ncols2) { + constexpr int cols_per_block = 16; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } } if constexpr (ncols2 <= 8) { @@ -1192,7 +1220,26 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX; const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0; - if constexpr (DV == 512) { + if constexpr (DKQ == 320) { + // This branch is only used for Mistral Small 4 which has a GQA ratio of 32. + // On AMD, simply use that GQA ratio with 32 columns / block since we always have enough SRAM. + // On NVIDIA however, the tile kernel is only used for GPUs that can't use the mma kernel (Pascal and older). + // Therefore, use a GQA ratio of 16 with 16 columns / block to stay below 48 kiB of SRAM / block. +#ifdef GGML_USE_HIP + if (use_gqa_opt && gqa_ratio % 32 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } +#else + if (use_gqa_opt && gqa_ratio % 16 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } +#endif // GGML_USE_HIP + GGML_ABORT("flash-attn tile (320/256): expected GQA ratio multiple of 32"); + } + + if constexpr (DKQ == 576) { if (use_gqa_opt && gqa_ratio % 16 == 0) { launch_fattn_tile_switch_ncols1(ctx, dst); return; @@ -1203,7 +1250,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm } } - if constexpr (DV <= 256) { + if constexpr (DKQ <= 512 && DKQ != 320) { if (use_gqa_opt && gqa_ratio % 8 == 0) { launch_fattn_tile_switch_ncols1(ctx, dst); return; @@ -1214,13 +1261,15 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm return; } - if (use_gqa_opt && gqa_ratio % 2 == 0) { - launch_fattn_tile_switch_ncols1(ctx, dst); + if constexpr (DV <= 256) { + if (use_gqa_opt && gqa_ratio % 2 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + + launch_fattn_tile_switch_ncols1(ctx, dst); return; } - - launch_fattn_tile_switch_ncols1(ctx, dst); - return; } GGML_ABORT("fatal error"); } @@ -1255,4 +1304,6 @@ extern DECL_FATTN_TILE_CASE( 96, 96); extern DECL_FATTN_TILE_CASE(112, 112); extern DECL_FATTN_TILE_CASE(128, 128); extern DECL_FATTN_TILE_CASE(256, 256); +extern DECL_FATTN_TILE_CASE(320, 256); +extern DECL_FATTN_TILE_CASE(512, 512); extern DECL_FATTN_TILE_CASE(576, 512); diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index 7cbe32633e5..f0bd42a5761 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -75,17 +75,17 @@ static __global__ void flash_attn_ext_vec( #endif // GGML_USE_HIP constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device(); - constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q; - constexpr int nthreads_V = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q; + constexpr int nthreads_KQ = (type_K == GGML_TYPE_F16 || type_K == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_KQ_q; + constexpr int nthreads_V = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_V_q; static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K"); static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V"); - constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4; + constexpr int V_rows_per_thread = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 2*cpy_ne : 4; constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V; constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ(); - constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; + constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16 && type_K != GGML_TYPE_BF16; #ifdef V_DOT2_F32_F16_AVAILABLE constexpr dequantize_V_t dequantize_V = get_dequantize_V(); #else @@ -323,8 +323,18 @@ static __global__ void flash_attn_ext_vec( #pragma unroll for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { half2 tmp[V_rows_per_thread/2]; - dequantize_V(V + k*nb21, tmp, - 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); + if constexpr (type_V == GGML_TYPE_BF16) { + float2 tmp_f[V_rows_per_thread/2]; + dequantize_V(V + k*nb21, tmp_f, + 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); +#pragma unroll + for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { + tmp[i_VKQ_1] = __float22half2_rn(tmp_f[i_VKQ_1]); + } + } else { + dequantize_V(V + k*nb21, tmp, + 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); + } #pragma unroll for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { #pragma unroll @@ -563,6 +573,7 @@ void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_ten extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_BF16); \ EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0) @@ -570,6 +581,7 @@ EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0) +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_BF16) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0) @@ -577,6 +589,7 @@ EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0) +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_BF16) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0) @@ -584,3 +597,4 @@ EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0) +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_BF16) diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 85c177f496f..8256591b21d 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -75,13 +75,17 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con return; } - if (use_gqa_opt && gqa_ratio % 2 == 0) { - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + if constexpr (DKQ <= 256) { + if (use_gqa_opt && gqa_ratio % 2 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + return; + } + + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); return; + } else { + GGML_ABORT("fatal error"); } - - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); - return; } if (use_gqa_opt && gqa_ratio > 4) { @@ -94,12 +98,16 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con return; } - if (use_gqa_opt && gqa_ratio > 1) { - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); - return; - } + if constexpr (DKQ <= 256) { + if (use_gqa_opt && gqa_ratio > 1) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + return; + } - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + } else { + GGML_ABORT("fatal error"); + } } static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -135,6 +143,26 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg GGML_ASSERT(V->ne[0] == 256); ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst); break; + case 320: + // For Mistral Small 4, go straight to the ncols1 switch (ncols2=32-only build). + GGML_ASSERT(V->ne[0] == 256); + { + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + const bool use_gqa_opt = mask && max_bias == 0.0f; + GGML_ASSERT(use_gqa_opt); + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + GGML_ASSERT(gqa_ratio % 32 == 0); + + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<320, 256, 32>(ctx, dst); + } + break; + case 512: + GGML_ASSERT(V->ne[0] == 512); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<512, 512>(ctx, dst); + break; case 576: { // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels. GGML_ASSERT(V->ne[0] == 512); @@ -224,6 +252,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) @@ -231,6 +260,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1) @@ -238,6 +268,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0) @@ -245,6 +276,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1) @@ -252,6 +284,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0) @@ -259,10 +292,20 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q8_0) + + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16) #else FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16) #endif // GGML_CUDA_FA_ALL_QUANTS GGML_ABORT("fatal error"); @@ -325,6 +368,22 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_NONE; } break; + case 320: + if (V->ne[0] != 256 || !gqa_opt_applies) { + return BEST_FATTN_KERNEL_NONE; + } + if (gqa_ratio % 32 != 0) { + return BEST_FATTN_KERNEL_NONE; + } + break; + case 512: + if (V->ne[0] != K->ne[0]) { + return BEST_FATTN_KERNEL_NONE; + } + if (!gqa_opt_applies) { + return BEST_FATTN_KERNEL_NONE; + } + break; case 576: if (V->ne[0] != 512) { return BEST_FATTN_KERNEL_NONE; @@ -355,6 +414,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const #endif // GGML_CUDA_FA_ALL_QUANTS case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: + case GGML_TYPE_BF16: break; default: return BEST_FATTN_KERNEL_NONE; @@ -408,7 +468,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } // Use the WMMA kernel if possible: - if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) { + if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 512 && Q->ne[0] != 576) { if (can_use_vector_kernel && Q->ne[1] <= 2) { return BEST_FATTN_KERNEL_VEC; } @@ -441,7 +501,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } // Use MFMA flash attention for CDNA (MI100+): - if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 576) { + if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 512 && Q->ne[0] != 576) { const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1); // MMA vs tile crossover benchmarked on MI300X @ d32768: // hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%) diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index 1ce6d5f31b5..6b44bec7317 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -1,7 +1,8 @@ #include "gated_delta_net.cuh" template -__global__ void gated_delta_net_cuda(const float * q, +__global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2) +gated_delta_net_cuda(const float * q, const float * k, const float * v, const float * g, @@ -38,7 +39,7 @@ __global__ void gated_delta_net_cuda(const float * q, const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v; state += state_offset; - curr_state += state_offset; + curr_state += state_offset + col * S_v; attn_data += (sequence * n_tokens * H + h_idx) * S_v; constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v; @@ -46,10 +47,11 @@ __global__ void gated_delta_net_cuda(const float * q, constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size; float s_shard[rows_per_lane]; // state is stored transposed: M[col][i] = S[i][col], row col is contiguous + #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; - s_shard[r] = curr_state[col * S_v + i]; + s_shard[r] = curr_state[i]; } for (int t = 0; t < n_tokens; t++) { @@ -63,6 +65,16 @@ __global__ void gated_delta_net_cuda(const float * q, const float beta_val = *beta_t; + // Cache k and q in registers + float k_reg[rows_per_lane]; + float q_reg[rows_per_lane]; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + k_reg[r] = k_t[i]; + q_reg[r] = q_t[i]; + } + if constexpr (!KDA) { const float g_val = expf(*g_t); @@ -70,8 +82,7 @@ __global__ void gated_delta_net_cuda(const float * q, float kv_shard = 0.0f; #pragma unroll for (int r = 0; r < rows_per_lane; r++) { - const int i = r * warp_size + lane; - kv_shard += s_shard[r] * k_t[i]; + kv_shard += s_shard[r] * k_reg[r]; } float kv_col = warp_reduce_sum(kv_shard); @@ -83,9 +94,8 @@ __global__ void gated_delta_net_cuda(const float * q, float attn_partial = 0.0f; #pragma unroll for (int r = 0; r < rows_per_lane; r++) { - const int i = r * warp_size + lane; - s_shard[r] = g_val * s_shard[r] + k_t[i] * delta_col; - attn_partial += s_shard[r] * q_t[i]; + s_shard[r] = g_val * s_shard[r] + k_reg[r] * delta_col; + attn_partial += s_shard[r] * q_reg[r]; } float attn_col = warp_reduce_sum(attn_partial); @@ -99,7 +109,7 @@ __global__ void gated_delta_net_cuda(const float * q, #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; - kv_shard += expf(g_t[i]) * s_shard[r] * k_t[i]; + kv_shard += expf(g_t[i]) * s_shard[r] * k_reg[r]; } float kv_col = warp_reduce_sum(kv_shard); @@ -113,8 +123,8 @@ __global__ void gated_delta_net_cuda(const float * q, #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; - s_shard[r] = expf(g_t[i]) * s_shard[r] + k_t[i] * delta_col; - attn_partial += s_shard[r] * q_t[i]; + s_shard[r] = expf(g_t[i]) * s_shard[r] + k_reg[r] * delta_col; + attn_partial += s_shard[r] * q_reg[r]; } float attn_col = warp_reduce_sum(attn_partial); diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 2fab33243dd..e99cba63d34 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -179,6 +179,10 @@ static void ggml_cuda_get_rows_switch_src0_type( get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d, ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); break; + case GGML_TYPE_Q1_0: + get_rows_cuda_q(src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; case GGML_TYPE_Q4_0: get_rows_cuda_q(src0_d, src1_d, dst_d, ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 5a0be4a472a..fbe0fa06242 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -82,7 +82,6 @@ #include #include #include -#include static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); @@ -126,7 +125,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) if (err == hipSuccess) { // hipMemAdviseSetCoarseGrain is an optional performance hint; // ignore errors (e.g. hipErrorInvalidValue on some APU/iGPU configs). - cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device); + (void)cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device); (void)hipGetLastError(); // clear any error } @@ -325,6 +324,22 @@ static ggml_cuda_device_info ggml_cuda_init() { // configure logging to stdout // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr)); + if (getenv("GGML_CUDA_P2P") != nullptr) { + for (int id = 0; id < info.device_count; ++id) { + ggml_cuda_set_device(id); + for (int id_other = 0; id_other < info.device_count; ++id_other) { + if (id == id_other) { + continue; + } + int can_access_peer; + CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other)); + if (can_access_peer) { + CUDA_CHECK(cudaDeviceEnablePeerAccess(id_other, 0)); + } + } + } + } + return info; } @@ -353,15 +368,21 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { } ~ggml_cuda_pool_leg() { + clear_pool(); + GGML_ASSERT(pool_size == 0); + } + + void clear_pool() { ggml_cuda_set_device(device); for (int i = 0; i < MAX_BUFFERS; ++i) { ggml_cuda_buffer & b = buffer_pool[i]; if (b.ptr != nullptr) { CUDA_CHECK(cudaFree(b.ptr)); pool_size -= b.size; + b.ptr = nullptr; + b.size = 0; } } - GGML_ASSERT(pool_size == 0); } void * alloc(size_t size, size_t * actual_size) override { @@ -406,7 +427,20 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { size_t look_ahead_size = (size_t) (1.05 * size); look_ahead_size = 256 * ((look_ahead_size + 255)/256); ggml_cuda_set_device(device); - CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device)); + cudaError_t err = ggml_cuda_device_malloc(&ptr, look_ahead_size, device); + if (err == cudaErrorMemoryAllocation) { + (void)cudaGetLastError(); + const size_t cached_bytes = pool_size; + GGML_LOG_DEBUG(GGML_CUDA_NAME " pool[%d]: alloc of %.2f MiB failed, flushing %.2f MiB of cached buffers and retrying\n", + device, look_ahead_size/1024.0/1024.0, cached_bytes/1024.0/1024.0); + CUDA_CHECK(cudaDeviceSynchronize()); + clear_pool(); + err = ggml_cuda_device_malloc(&ptr, look_ahead_size, device); + if (err == cudaSuccess) { + GGML_LOG_DEBUG(GGML_CUDA_NAME " pool[%d]: retry succeeded\n", device); + } + } + CUDA_CHECK(err); *actual_size = look_ahead_size; pool_size += look_ahead_size; #ifdef DEBUG_CUDA_MALLOC @@ -633,26 +667,46 @@ static enum ggml_status ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer } static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { - ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; + ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context; ggml_cuda_set_device(ctx->device); - CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + offset, value, size, cudaStreamPerThread)); + CUDA_CHECK(cudaMemsetAsync((char *) tensor->data + offset, value, size, cudaStreamPerThread)); CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); } static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; + ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context; ggml_cuda_set_device(ctx->device); - CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread)); + CUDA_CHECK(cudaMemcpyAsync((char *) tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread)); CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); } static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context; + + ggml_cuda_set_device(ctx->device); + CUDA_CHECK(cudaMemcpyAsync(data, (const char *) tensor->data + offset, size, cudaMemcpyDeviceToHost, cudaStreamPerThread)); + CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); +} + +static void ggml_backend_cuda_buffer_set_tensor_2d(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, + size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context; + + ggml_cuda_set_device(ctx->device); + CUDA_CHECK(cudaMemcpy2DAsync( + (char *) tensor->data + offset, stride_tensor, data, stride_data, size, n_copies, cudaMemcpyHostToDevice, cudaStreamPerThread)); + CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); +} + +static void ggml_backend_cuda_buffer_get_tensor_2d(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, + size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; ggml_cuda_set_device(ctx->device); - CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cudaStreamPerThread)); + CUDA_CHECK(cudaMemcpy2DAsync( + data, stride_data, (const char *) tensor->data + offset, stride_tensor, size, n_copies, cudaMemcpyDeviceToHost, cudaStreamPerThread)); CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); } @@ -692,6 +746,8 @@ static const ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = { /* .memset_tensor = */ ggml_backend_cuda_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_cuda_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cuda_buffer_get_tensor, + /* .set_tensor_2d = */ ggml_backend_cuda_buffer_set_tensor_2d, + /* .get_tensor_2d = */ ggml_backend_cuda_buffer_get_tensor_2d, /* .cpy_tensor = */ ggml_backend_cuda_buffer_cpy_tensor, /* .clear = */ ggml_backend_cuda_buffer_clear, /* .reset = */ NULL, @@ -1004,6 +1060,8 @@ static const ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_cuda_split_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cuda_split_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ NULL, /* .clear = */ ggml_backend_cuda_split_buffer_clear, /* .reset = */ NULL, @@ -1080,6 +1138,148 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_inte /* .is_host = */ ggml_backend_cuda_split_buffer_type_is_host, }; +#ifdef GGML_USE_NCCL +struct ggml_backend_cuda_comm_context { + std::vector backends; + std::vector comms; + + ~ggml_backend_cuda_comm_context() { + for (ncclComm_t comm : comms) { + NCCL_CHECK(ncclCommDestroy(comm)); + } + } +}; +#endif // GGML_USE_NCCL + +static void ggml_backend_cuda_comm_free(void * comm_ctx_v) { +#ifdef GGML_USE_NCCL + if (comm_ctx_v == nullptr) { + return; + } + ggml_backend_cuda_comm_context * comm_ctx = (ggml_backend_cuda_comm_context *) comm_ctx_v; + delete comm_ctx; +#else + GGML_UNUSED(comm_ctx_v); +#endif // GGML_USE_NCCL +} + +static void * ggml_backend_cuda_comm_init(ggml_backend_t * backends, size_t n_backends) { +#ifdef GGML_USE_NCCL + for (size_t i = 0; i < n_backends; i++) { + if (!ggml_backend_is_cuda(backends[i])) { + return nullptr; + } + } + ggml_backend_cuda_comm_context * ret = new ggml_backend_cuda_comm_context; + std::vector dev_ids; + ret->backends.reserve(n_backends); + dev_ids.reserve(n_backends); + for (size_t i = 0; i < n_backends; i++) { + ret->backends.push_back(backends[i]); + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context; + dev_ids.push_back(cuda_ctx->device); + } + + ret->comms.resize(n_backends); + NCCL_CHECK(ncclCommInitAll(ret->comms.data(), n_backends, dev_ids.data())); + return ret; +#else + // If NCCL is installed it is used by default for optimal performance. + // However, NVIDIA does not distribute NCCL with CUDA so users may be unwittingly missing this package. + // RCCL is disabled by default, users are explicitly opting in. + // Therefore print no warning for RCCL. +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + static bool warning_printed = false; + if (!warning_printed) { + GGML_LOG_WARN("%s: NVIDIA Collective Communications Library (NCCL) is unavailable, multi GPU performance will be suboptimal\n", __func__); + warning_printed = true; + } +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + GGML_UNUSED_VARS(backends, n_backends); + return nullptr; +#endif // GGML_USE_NCCL +} + +static bool ggml_backend_cuda_comm_allreduce_tensor(void * comm_ctx_v, struct ggml_tensor ** tensors) { +#ifdef GGML_USE_NCCL + const int64_t ne = ggml_nelements(tensors[0]); + // FIXME the input of llm_graph_context::build_in_out_ids can produce a tensor with 0 elements if n_outputs == 0 + // This then causes a crash in this function + if (ne == 0) { + return true; + } + + GGML_ASSERT(comm_ctx_v != nullptr); + ggml_backend_cuda_comm_context * comm_ctx = (ggml_backend_cuda_comm_context *) comm_ctx_v; + const size_t n_backends = comm_ctx->backends.size(); + + for (size_t i = 0; i < n_backends; ++i) { + GGML_ASSERT(tensors[i] != nullptr); + GGML_ASSERT(ggml_nelements(tensors[i]) == ne); + GGML_ASSERT(ggml_is_contiguously_allocated(tensors[i])); + } + + // For small tensors, simply reduce them as FP32. + // The following heuristic for how "small" a tensor should be is based on RTX 4090s connected via 16x PCIe 4.0. + if ((n_backends <= 2 && ne < 32768) || (n_backends == 3 && ne < 131072) || (n_backends >= 4 && ne < 262144)) { + for (size_t i = 0; i < n_backends; ++i) { + if ((tensors[i]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context; + ggml_cuda_set_device(cuda_ctx->device); + CUDA_CHECK(cudaMemsetAsync(tensors[i]->data, 0, ggml_nbytes(tensors[i]), cuda_ctx->stream())); + } + } + NCCL_CHECK(ncclGroupStart()); + for (size_t i = 0; i < n_backends; ++i) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context; + NCCL_CHECK(ncclAllReduce(tensors[i]->data, tensors[i]->data, ne, ncclFloat, ncclSum, comm_ctx->comms[i], cuda_ctx->stream())); + } + NCCL_CHECK(ncclGroupEnd()); + + return true; + } + + // For large tensors it's faster to compress them to BF16 for the reduction: + to_bf16_cuda_t to_bf16 = ggml_get_to_bf16_cuda(GGML_TYPE_F32); + to_fp32_cuda_t to_fp32 = ggml_get_to_fp32_cuda(GGML_TYPE_BF16); + + ggml_cuda_pool_alloc tmp[GGML_CUDA_MAX_DEVICES]; + for (size_t i = 0; i < n_backends; ++i) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context; + tmp[i].pool = &cuda_ctx->pool(); + tmp[i].alloc(ne); + + ggml_cuda_set_device(cuda_ctx->device); + if (tensors[i]->flags & GGML_TENSOR_FLAG_COMPUTE) { + to_bf16(tensors[i]->data, tmp[i].get(), ne, cuda_ctx->stream()); + } else { + CUDA_CHECK(cudaMemsetAsync(tmp[i].get(), 0, ne * sizeof(nv_bfloat16), cuda_ctx->stream())); + } + CUDA_CHECK(cudaGetLastError()); + } + + NCCL_CHECK(ncclGroupStart()); + for (size_t i = 0; i < n_backends; ++i) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context; + NCCL_CHECK(ncclAllReduce(tmp[i].get(), tmp[i].get(), ne, ncclBfloat16, ncclSum, comm_ctx->comms[i], cuda_ctx->stream())); + } + NCCL_CHECK(ncclGroupEnd()); + + for (size_t i = 0; i < n_backends; ++i) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context; + + ggml_cuda_set_device(cuda_ctx->device); + to_fp32(tmp[i].get(), (float *) tensors[i]->data, ne, cuda_ctx->stream()); + CUDA_CHECK(cudaGetLastError()); + } + + return true; +#else + GGML_UNUSED_VARS(comm_ctx_v, tensors); + return false; +#endif // GGML_USE_NCCL +} + ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split) { static std::mutex mutex; std::lock_guard lock(mutex); @@ -1297,7 +1497,12 @@ static void ggml_cuda_op_mul_mat_cublas( const bool supports_bf16 = GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) || (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2); - const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT; + const bool use_fp16 = + src0->type != GGML_TYPE_NVFP4 && + (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && + ggml_is_contiguous(src0) && + row_diff == src0->ne[1] && + dst->op_params[0] == GGML_PREC_DEFAULT; if (supports_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) { ggml_cuda_pool_alloc src1_as_bf16(ctx.pool(id)); @@ -1421,64 +1626,6 @@ static void ggml_cuda_op_mul_mat_cublas( GGML_UNUSED_VARS(dst, src1_ddq_i, src1_padded_row_size); } -static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) { - static bool peer_access_enabled = false; - - const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE; - - if (peer_access_enabled == enable_peer_access) { - return; - } - -#ifdef NDEBUG - for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { - ggml_cuda_set_device(id); - CUDA_CHECK(cudaDeviceSynchronize()); - } - - for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { - ggml_cuda_set_device(id); - - for (int id_other = 0; id_other < ggml_backend_cuda_get_device_count(); ++id_other) { - if (id == id_other) { - continue; - } - if (id != main_device && id_other != main_device) { - continue; - } - - int can_access_peer; - CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other)); - if (can_access_peer) { - if (enable_peer_access) { - cudaError_t err = cudaDeviceEnablePeerAccess(id_other, 0); - if (err != cudaErrorPeerAccessAlreadyEnabled) { - CUDA_CHECK(err); - } else { - // reset the error - (void)cudaGetLastError(); - } - } else { - cudaError_t err = cudaDeviceDisablePeerAccess(id_other); - if (err != cudaErrorPeerAccessNotEnabled) { - CUDA_CHECK(err); - } else { - // reset the error - (void)cudaGetLastError(); - } - } - } - } - } - - ggml_cuda_set_device(main_device); -#endif // NDEBUG - - peer_access_enabled = enable_peer_access; - - GGML_UNUSED(main_device); -} - static cudaError_t ggml_cuda_Memcpy2DPeerAsync( void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) { @@ -2338,7 +2485,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE); if (ne2 <= MMVQ_MAX_BATCH_SIZE) { if (ggml_is_quantized(src0->type)) { - if (ne2 <= MMVQ_MMID_MAX_BATCH_SIZE) { + const int mmvq_mmid_max = get_mmvq_mmid_max_batch(src0->type, cc); + if (ne2 <= mmvq_mmid_max) { ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); return; } @@ -2478,11 +2626,6 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * } static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) { - // why is this here instead of mul_mat? - if (dst->src[0] != nullptr && ggml_backend_buft_is_cuda_split(dst->src[0]->buffer->buft)) { - ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device); - } - switch (dst->op) { case GGML_OP_ARGMAX: ggml_cuda_argmax(ctx, dst); @@ -2840,21 +2983,43 @@ static void ggml_backend_cuda_free(ggml_backend_t backend) { } static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); - CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cuda_ctx->stream())); + CUDA_CHECK(cudaMemcpyAsync((char *) tensor->data + offset, data, size, cudaMemcpyHostToDevice, cuda_ctx->stream())); } static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + + GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); + + CUDA_CHECK(cudaMemcpyAsync(data, (const char *) tensor->data + offset, size, cudaMemcpyDeviceToHost, cuda_ctx->stream())); +} + +static void ggml_backend_cuda_set_tensor_2d_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, + size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + + GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); + + CUDA_CHECK(cudaMemcpy2DAsync( + (char *) tensor->data + offset, stride_tensor, data, stride_data, size, n_copies, cudaMemcpyHostToDevice, cuda_ctx->stream())); +} + +static void ggml_backend_cuda_get_tensor_2d_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, + size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); - CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cuda_ctx->stream())); + CUDA_CHECK(cudaMemcpy2DAsync( + data, stride_data, (const char *) tensor->data + offset, stride_tensor, size, n_copies, cudaMemcpyDeviceToHost, cuda_ctx->stream())); } static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) { @@ -2865,21 +3030,21 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_ return false; } - if (!ggml_backend_buffer_is_cuda(src->buffer) || !ggml_backend_buffer_is_cuda(dst->buffer)) { + if (!ggml_backend_buffer_is_cuda(buf_src) || !ggml_backend_buffer_is_cuda(buf_dst)) { return false; } // device -> device copy - ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *)backend_src->context; - ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *)backend_dst->context; + ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *) backend_src->context; + ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *) backend_dst->context; - ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context; - ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context; + ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *) buf_src->context; + ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *) buf_dst->context; if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) { #ifndef NDEBUG GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__); -#endif +#endif // NDEBUG return false; } @@ -2892,7 +3057,7 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_ return false; #else CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, cuda_ctx_dst->device, src->data, cuda_ctx_src->device, ggml_nbytes(dst), cuda_ctx_src->stream())); -#endif +#endif // GGML_CUDA_NO_PEER_COPY } // record event on src stream after the copy @@ -2941,14 +3106,18 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { } // [TAG_MUL_MAT_ID_CUDA_GRAPHS] - if (node->op == GGML_OP_MUL_MAT_ID && (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > MMVQ_MMID_MAX_BATCH_SIZE)) { - // under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs - // TODO: figure out a way to enable for larger batch sizes, without hurting performance - // ref: https://github.com/ggml-org/llama.cpp/pull/18958 - use_cuda_graph = false; + if (node->op == GGML_OP_MUL_MAT_ID) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const int mmvq_mmid_max = get_mmvq_mmid_max_batch(node->src[0]->type, cc); + if (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > mmvq_mmid_max) { + // under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs + // TODO: figure out a way to enable for larger batch sizes, without hurting performance + // ref: https://github.com/ggml-org/llama.cpp/pull/18958 + use_cuda_graph = false; #ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__); + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__); #endif + } } if (!use_cuda_graph) { @@ -2959,74 +3128,6 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { return use_cuda_graph; } -static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) { - memset(props, 0, sizeof(ggml_cuda_graph_node_properties)); - props->node_data = node->data; - props->node_op = node->op; - props->node_type = node->type; - props->flags = node->flags; - for (int i = 0; i < GGML_MAX_DIMS; i++) { - props->ne[i] = node->ne[i]; - props->nb[i] = node->nb[i]; - } - for (int i = 0; i < GGML_MAX_SRC; i++) { - if (!node->src[i]) { - continue; - } - - props->src_data[i] = node->src[i]->data; - } - memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS); -} - -static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) { - if (node->data != props->node_data && node->op != GGML_OP_VIEW) { - return false; - } - - if (node->op != props->node_op) { - return false; - } - - if (node->type != props->node_type) { - return false; - } - - for (int i = 0; i < GGML_MAX_DIMS; i++) { - if (node->ne[i] != props->ne[i]) { - return false; - } - if (node->nb[i] != props->nb[i]) { - return false; - } - } - - if (node->op != GGML_OP_VIEW) { - for (int i = 0; i < GGML_MAX_SRC; i++) { - if (!node->src[i]) { - if (props->src_data[i] != nullptr) { - return false; - } - continue; - } - - if (node->src[i]->data != props->src_data[i]) { - return false; - } - } - } - - if (memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { - return false; - } - - if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) != (props->flags & GGML_TENSOR_FLAG_COMPUTE)) { - return false; - } - - return true; -} - static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) { return cgraph->nodes[0]; } @@ -3037,53 +3138,37 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx const void * graph_key = ggml_cuda_graph_get_key(cgraph); ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); - // Check if the graph size has changed - if (graph->props.size() != (size_t)cgraph->n_nodes) { - res = true; - graph->props.resize(cgraph->n_nodes); + if (cgraph->uid != 0 && + cgraph->uid == graph->uid) { + GGML_LOG_DEBUG("CUDA Graph id %zu reused\n", cgraph->uid); + GGML_ASSERT((int)graph->node_props.size() == cgraph->n_nodes); + return false; } - // Loop over nodes in GGML graph to determine if CUDA graph update is required - // and store properties to allow this comparison for the next token - std::unordered_set seen_node; - std::vector srcs_extra; - for (int i = 0; i < cgraph->n_nodes; i++) { - bool props_match = true; - - seen_node.insert(cgraph->nodes[i]); - - if (!res) { - props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]); - } - if (!props_match) { - res = true; - } - ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]); - - for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) { - ggml_tensor * src = cgraph->nodes[i]->src[src_idx]; - if (src && seen_node.find(src) == seen_node.end()) { - srcs_extra.push_back(src); - } - } - } + graph->uid = cgraph->uid; - if (graph->extra.size() != (size_t) srcs_extra.size()) { + // Check if the graph size has changed + if ((int)graph->node_props.size() != cgraph->n_nodes) { res = true; - graph->extra.resize(srcs_extra.size()); + graph->node_props.resize(cgraph->n_nodes); } - for (size_t i = 0; i < srcs_extra.size(); ++i) { - bool props_match = true; - - if (!res) { - props_match = ggml_cuda_graph_node_properties_match(srcs_extra[i], &graph->extra[i]); + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_cuda_graph::node_properties prop = {}; + memcpy(&prop.node, cgraph->nodes[i], sizeof(ggml_tensor)); + + for (int j = 0; j < GGML_MAX_SRC; ++j) { + if (cgraph->nodes[i]->src[j]) { + prop.node_src_data_ptrs[j] = cgraph->nodes[i]->src[j]->data; + memcpy(prop.node_src_ne[j], cgraph->nodes[i]->src[j]->ne, sizeof(prop.node_src_ne[j])); + memcpy(prop.node_src_nb[j], cgraph->nodes[i]->src[j]->nb, sizeof(prop.node_src_nb[j])); + } } - if (!props_match) { + if (res || memcmp(&graph->node_props[i], &prop, sizeof(prop)) != 0) { + graph->node_props[i] = prop; res = true; } - ggml_cuda_graph_node_set_properties(&graph->extra[i], srcs_extra[i]); } return res; @@ -3298,6 +3383,71 @@ static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int nod return true; } +// returns whether the write (out) nodes overwrite the read nodes in operation +static bool ggml_cuda_check_fusion_memory_ranges(const ggml_cgraph * cgraph, + const int node_idx, + const int node_count, + const int * out_nodes, + const int out_count, + const bool is_topk_moe = false) { + auto nodes_overlap = [&](const ggml_tensor * a, const ggml_tensor * b) { + const int64_t a_start = (int64_t) a->data; + const int64_t a_end = a_start + ggml_backend_buft_get_alloc_size(a->buffer->buft, a); + + const int64_t b_start = (int64_t) b->data; + const int64_t b_end = b_start + ggml_backend_buft_get_alloc_size(b->buffer->buft, b); + + if ((b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end)) { + return true; + } + + return false; + }; + + bool is_ok = true; + // exception for topk-moe, as each row is read entirely before writing + if (ggml_nrows(cgraph->nodes[node_idx]) == 1 && is_topk_moe) { + return true; + } + + for (int i = 0; i < out_count; ++i) { + const ggml_tensor * dst = cgraph->nodes[out_nodes[i]]; + + for (int j = node_idx; j < node_idx + node_count; ++j) { + // Loop over all srcs of all nodes in the fusion. If the src overlaps + // the destination and the src is not an intermediate node that's being + // elided, then disable fusion. + + for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) { + const ggml_tensor * src = cgraph->nodes[j]->src[src_idx]; + + if (!src || src->op == GGML_OP_NONE) { + continue; + } + + if (nodes_overlap(dst, src)) { + bool found = false; + + for (int k = node_idx; k < j; ++k) { + if (cgraph->nodes[k] == src) { + found = true; + break; + } + } + + if (!found) { + is_ok = false; + break; + } + } + } + } + } + + return is_ok; +} + + static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list ops, @@ -3327,7 +3477,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, const ggml_tensor * glu = cgraph->nodes[node_idx + 4]; if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu, ffn_up_bias, ffn_gate_bias)) { - return true; + int out_nodes[] = { node_idx + 4 }; + return ggml_cuda_check_fusion_memory_ranges(cgraph, node_idx, (int)ops.size(), out_nodes, 1); } } @@ -3338,7 +3489,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, const ggml_tensor * glu = cgraph->nodes[node_idx + 2]; if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) { - return true; + int out_nodes[] = { node_idx + 2 }; + return ggml_cuda_check_fusion_memory_ranges(cgraph, node_idx, (int)ops.size(), out_nodes, 1); } } @@ -3404,6 +3556,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) { const ggml_tensor * ssm_conv = cgraph->nodes[node_idx]; const ggml_tensor * silu = cgraph->nodes[node_idx+1]; + if (ggml_get_unary_op(silu) != unary_ops.begin()[0]) { + return false; + } if (ssm_conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) { return false; @@ -3412,6 +3567,31 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, return true; } + if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SSM_CONV && ops.begin()[1] == GGML_OP_ADD + && ops.begin()[2] == GGML_OP_UNARY && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) { + const ggml_tensor * ssm_conv = cgraph->nodes[node_idx]; + const ggml_tensor * add = cgraph->nodes[node_idx+1]; + const ggml_tensor * silu = cgraph->nodes[node_idx+2]; + if (ggml_get_unary_op(silu) != unary_ops.begin()[0]) { + return false; + } + + if (ssm_conv->type != GGML_TYPE_F32 || add->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) { + return false; + } + + // ADD must consume ssm_conv's output and broadcast a 1-D channel-wise bias. + const ggml_tensor * bias = (add->src[0] == ssm_conv) ? add->src[1] : add->src[0]; + if (bias->type != GGML_TYPE_F32 || !ggml_is_contiguous(bias)) { + return false; + } + if (ggml_nelements(bias) != ssm_conv->ne[0] || bias->ne[0] != ssm_conv->ne[0]) { + return false; + } + + return true; + } + if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_MUL && unary_ops.size() == 1 && (unary_ops.begin()[0] == GGML_UNARY_OP_SILU || unary_ops.begin()[0] == GGML_UNARY_OP_SIGMOID || unary_ops.begin()[0] == GGML_UNARY_OP_SOFTPLUS)) { const ggml_tensor * unary = cgraph->nodes[node_idx]; @@ -3440,6 +3620,30 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, return true; } + if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_SQR + && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_RELU) { + const ggml_tensor * unary = cgraph->nodes[node_idx]; + const ggml_tensor * sqr = cgraph->nodes[node_idx+1]; + + if (ggml_get_unary_op(unary) != GGML_UNARY_OP_RELU) { + return false; + } + + if (unary->type != GGML_TYPE_F32 && unary->type != GGML_TYPE_F16) { + return false; + } + + if (unary->type != sqr->type) { + return false; + } + + if (!ggml_is_contiguous(unary->src[0])) { + return false; + } + + return true; + } + if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) { const ggml_tensor *scale = cgraph->nodes[node_idx]; @@ -3464,67 +3668,360 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, return false; } -// returns whether the write (out) nodes overwrite the read nodes in operation -static bool ggml_cuda_check_fusion_memory_ranges(ggml_cgraph * cgraph, - int node_idx, - int node_count, - int * out_nodes, - int out_count) { - auto nodes_overlap = [&](const ggml_tensor * a, const ggml_tensor * b) { - const int64_t a_start = (int64_t) a->data; - const int64_t a_end = a_start + ggml_nbytes(a); +// try and fuse nodes and return the number of nodes to skip +static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, int i) { + + static bool disable_fusion = getenv("GGML_CUDA_DISABLE_FUSION") != nullptr && std::atoi(getenv("GGML_CUDA_DISABLE_FUSION")); + if (disable_fusion) { + return 0; + } + + ggml_tensor * node = cgraph->nodes[i]; + + //topk-moe + if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX || + cgraph->nodes[i]->op == GGML_OP_ARGSORT) { + ggml_cuda_topk_moe_args args; + const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args); + std::vector ops; + + if (can_fuse) { + const ggml_tensor * logits = node->src[0]; + ggml_tensor * weights = nullptr; + ggml_tensor * ids = nullptr; + const ggml_tensor * bias = nullptr; + const ggml_tensor * clamp = nullptr; + const ggml_tensor * scale = nullptr; + + if (!args.delayed_softmax) { + ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX; + int out_nodes[2]; // nodes which can't be elided + + if (args.prob_bias) { + bias = cgraph->nodes[i + 2]->src[1]; + ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT, GGML_OP_VIEW, + GGML_OP_GET_ROWS }); + out_nodes[0] = i + 4; + ids = cgraph->nodes[i + 4]; + } else { + ops.insert(ops.end(), + { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS }); + out_nodes[0] = i + 3; + ids = cgraph->nodes[i + 3]; + } - const int64_t b_start = (int64_t) b->data; - const int64_t b_end = b_start + ggml_nbytes(b); + if (args.norm) { + ops.insert(ops.end(), + { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV, GGML_OP_RESHAPE }); + clamp = cgraph->nodes[i + ops.size() - 3]; + } + if (args.scale) { + ops.insert(ops.end(), { GGML_OP_SCALE }); + scale = cgraph->nodes[i + ops.size() - 1]; + } - if ((b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end)) { - return true; + weights = cgraph->nodes[i + ops.size() - 1]; + out_nodes[1] = i + ops.size() - 1; + + if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && + ggml_cuda_should_use_topk_moe(node, logits, weights, ids) && + ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/true)) { + ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); + return ops.size() - 1; + } + } else if (!args.norm && !args.prob_bias) { + //special case gpt-oss, no norm, no bias. + ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, + GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }); + weights = cgraph->nodes[i + 5]; + ids = cgraph->nodes[i + 1]; + const ggml_tensor * softmax = cgraph->nodes[i + 4]; + + int out_nodes[2] = { i + 1, i + 5 }; + if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && + ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) && + ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/true)) { + ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); + return ops.size() - 1; + } + } } + } - return false; - }; + //RoPE + view + set-rows + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) { + ggml_tensor * rope = cgraph->nodes[i]; + ggml_tensor * set_rows = cgraph->nodes[i + 2]; - bool is_ok = true; - // for nrows=1, all fusion operations correctly read the src before writing dst or do it elementwise, so we should be ok - if (ggml_nrows(cgraph->nodes[node_idx]) == 1) { - return true; + ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows); + return 2; } - for (int i = 0; i < out_count; ++i) { - const ggml_tensor * dst = cgraph->nodes[out_nodes[i]]; + // multi-(add or mul) + if (node->op == GGML_OP_ADD || node->op == GGML_OP_MUL) { + int n_fuse = 0; + ggml_op ops[8]; + std::fill(ops, ops + 8, node->op); - for (int j = node_idx; j < node_idx + node_count; ++j) { - // Loop over all srcs of all nodes in the fusion. If the src overlaps - // the destination and the src is not an intermediate node that's being - // elided, then disable fusion. + for (; n_fuse <= 6; ++n_fuse) { + if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) { + break; + } + if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) { + break; + } + if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) { + break; + } + } - for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) { - const ggml_tensor * src = cgraph->nodes[j]->src[src_idx]; + n_fuse++; - if (!src || src->op == GGML_OP_NONE) { - continue; - } + if (n_fuse > 1) { + ggml_tensor fused_node; + memcpy(&fused_node, node, sizeof(ggml_tensor)); + for (int j = 0; j < n_fuse - 1; ++j) { + fused_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1]; + } + fused_node.data = cgraph->nodes[i + n_fuse - 1]->data; + if (node->op == GGML_OP_ADD) { + ggml_cuda_op_fused_add(*cuda_ctx, &fused_node, n_fuse); + } else { + ggml_cuda_op_fused_mul(*cuda_ctx, &fused_node, n_fuse); + } + return n_fuse - 1; + } + } - if (nodes_overlap(dst, src)) { - bool found = false; + bool fused_mul_mat_vec = false; + int fused_node_count = 0; - for (int k = node_idx; k < j; ++k) { - if (cgraph->nodes[k] == src) { - found = true; - break; - } - } + // gate + glu + up + for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) { + const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID; - if (!found) { - is_ok = false; - break; + if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) { + ggml_tensor * glu = cgraph->nodes[i + 4]; + ggml_tensor * gate_bias_n = glu->src[0]; + ggml_tensor * up_bias_n = glu->src[1]; + + //we don't assume the order for {gate, up}. Instead infer it from the bias tensor + ggml_tensor * gate_n = nullptr; + ggml_tensor * up_n = nullptr; + + if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) { + gate_n = cgraph->nodes[i]; + up_n = cgraph->nodes[i + 2]; + } else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) { + gate_n = cgraph->nodes[i + 2]; + up_n = cgraph->nodes[i]; + } else { + continue; + } + + auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) { + if (op_bias == GGML_OP_ADD) { + if (bias_node->src[0] == mul_node) { + return bias_node->src[1]; + } + if (bias_node->src[1] == mul_node) { + return bias_node->src[0]; } + return (ggml_tensor *) nullptr; } + GGML_ASSERT(op_bias == GGML_OP_ADD_ID); + GGML_ASSERT(bias_node->src[0] == mul_node); + return bias_node->src[1]; + }; + + ggml_tensor * up_bias_tensor = get_bias_tensor(up_bias_n, up_n, bias_op); + ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op); + + if (!up_bias_tensor || !gate_bias_tensor) { + continue; + } + + // we don't support repeating adds + if (bias_op == GGML_OP_ADD && (!ggml_are_same_shape(gate_bias_n->src[0], gate_bias_n->src[1]) || + !ggml_are_same_shape(up_bias_n->src[0], up_bias_n->src[1]))) { + continue; + } + + const ggml_tensor * src0 = up_n->src[0]; + const ggml_tensor * src1 = up_n->src[1]; + const ggml_tensor * ids = up_n->src[2]; + + if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.gate = gate_n->src[0]; + fusion_data.x_bias = up_bias_tensor; + fusion_data.gate_bias = gate_bias_tensor; + fusion_data.glu_op = ggml_get_glu_op(glu); + + ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 5; + break; + } + + if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.gate = gate_n->src[0]; + fusion_data.x_bias = up_bias_tensor; + fusion_data.gate_bias = gate_bias_tensor; + fusion_data.glu_op = ggml_get_glu_op(glu); + + ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 5; + break; + } + } else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) { + ggml_tensor * glu = cgraph->nodes[i + 2]; + ggml_tensor * gate = glu->src[0]; + ggml_tensor * up = glu->src[1]; + + bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1]) || + (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]); + + if (!ok) { + continue; + } + + const ggml_tensor * src0 = up->src[0]; + const ggml_tensor * src1 = up->src[1]; + const ggml_tensor * ids = up->src[2]; + + if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.gate = gate->src[0]; + fusion_data.glu_op = ggml_get_glu_op(glu); + + ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 3; + break; + } + + if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.gate = gate->src[0]; + fusion_data.glu_op = ggml_get_glu_op(glu); + + ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 3; + break; } } } - return is_ok; + if (fused_mul_mat_vec) { + return fused_node_count - 1; + } + + fused_mul_mat_vec = false; + fused_node_count = 0; + + // gate + add + glu + up + add + for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) { + const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID; + + if (!ggml_can_fuse(cgraph, i, { op, bias_op })) { + continue; + } + + ggml_tensor * mm_node = cgraph->nodes[i]; + ggml_tensor * bias_node = cgraph->nodes[i + 1]; + + ggml_tensor * bias_tensor = nullptr; + if (bias_op == GGML_OP_ADD) { + if (bias_node->src[0] == mm_node) { + bias_tensor = bias_node->src[1]; + } else if (bias_node->src[1] == mm_node) { + bias_tensor = bias_node->src[0]; + } else { + continue; + } + } else { + if (bias_node->src[0] != mm_node) { + continue; + } + bias_tensor = bias_node->src[1]; + } + + const ggml_tensor * src0 = mm_node->src[0]; + const ggml_tensor * src1 = mm_node->src[1]; + const ggml_tensor * ids = mm_node->src[2]; + + if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) { + continue; + } + + if (bias_op == GGML_OP_ADD && !ggml_are_same_shape(bias_node->src[0], bias_node->src[1])) { + continue; + } + + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.x_bias = bias_tensor; + + if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) { + ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 2; + break; + } + + if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) { + ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 2; + break; + } + } + + if (fused_mul_mat_vec) { + return fused_node_count - 1; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD }, {})) { + ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i + 1], cgraph->nodes[i + 2]); + return 2; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) { + ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i + 1]); + return 1; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_ADD, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) { + ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i + 1], cgraph->nodes[i + 2]); + return 2; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) { + ggml_cuda_op_ssm_conv(*cuda_ctx, node, /*bias_add_node=*/ nullptr, cgraph->nodes[i + 1]); + return 1; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SILU }) || + ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SIGMOID }) || + ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SOFTPLUS })) { + ggml_cuda_op_unary_mul(*cuda_ctx, node, cgraph->nodes[i + 1]); + return 1; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_SQR }, { GGML_UNARY_OP_RELU })) { + ggml_cuda_op_relu_sqr(*cuda_ctx, node, cgraph->nodes[i + 1]); + return 1; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) { + ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i + 2], node); + return 2; + } + + return 0; } static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) { @@ -3673,345 +4170,11 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud continue; } - // start of fusion operations - static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); - if (!disable_fusion) { - ggml_cuda_topk_moe_args args; - - if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX || - cgraph->nodes[i]->op == GGML_OP_ARGSORT) { - const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args); - - std::vector ops; - - if (can_fuse) { - const ggml_tensor * logits = node->src[0]; - ggml_tensor * weights = nullptr; - ggml_tensor * ids = nullptr; - const ggml_tensor * bias = nullptr; - const ggml_tensor * clamp = nullptr; - const ggml_tensor * scale = nullptr; - - if (!args.delayed_softmax) { - ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX; - int out_nodes[2]; // nodes which can't be elided - - if (args.prob_bias) { - bias = cgraph->nodes[i + 2]->src[1]; - ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT, - GGML_OP_VIEW, GGML_OP_GET_ROWS }); - out_nodes[0] = i + 4; - ids = cgraph->nodes[i + 4]; - } else { - ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, - GGML_OP_GET_ROWS }); - out_nodes[0] = i + 3; - ids = cgraph->nodes[i + 3]; - } - - if (args.norm) { - ops.insert(ops.end(), { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP, - GGML_OP_DIV, GGML_OP_RESHAPE }); - clamp = cgraph->nodes[i + ops.size() - 3]; - } - if (args.scale) { - ops.insert(ops.end(), { GGML_OP_SCALE }); - scale = cgraph->nodes[i + ops.size() - 1]; - } - - weights = cgraph->nodes[i + ops.size() - 1]; - out_nodes[1] = i + ops.size() - 1; - - if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && - ggml_cuda_should_use_topk_moe(node, logits, weights, ids) && - ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) { - ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); - i += ops.size() - 1; - continue; - } - } else if (!args.norm && !args.prob_bias) { - //special case gpt-oss, no norm, no bias. - ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS, - GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }); - weights = cgraph->nodes[i + 5]; - ids = cgraph->nodes[i + 1]; - const ggml_tensor * softmax = cgraph->nodes[i + 4]; - - int out_nodes[2] = { i + 1, i + 5 }; - if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && - ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) && - ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) { - ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); - i += ops.size() - 1; - continue; - } - } - } - } - - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) { - ggml_tensor * rope = cgraph->nodes[i]; - ggml_tensor * set_rows = cgraph->nodes[i + 2]; - - ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows); - i += 2; - continue; - } - - if (node->op == GGML_OP_ADD) { - int n_fuse = 0; - ggml_op ops[8]; - std::fill(ops, ops + 8, GGML_OP_ADD); - - for (; n_fuse <= 6; ++n_fuse){ - if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) { - break; - } - if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) { - break; - } - if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) { - break; - } - } - - n_fuse++; - - if (n_fuse > 1) { - ggml_tensor fused_add_node; - memcpy(&fused_add_node, node, sizeof(ggml_tensor)); - for (int j = 0; j < n_fuse - 1; ++j) { - fused_add_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1]; - } - fused_add_node.data = cgraph->nodes[i + n_fuse - 1]->data; - ggml_cuda_op_fused_add(*cuda_ctx, &fused_add_node, n_fuse); - i += n_fuse - 1; - - continue; - } - } - - bool fused_mul_mat_vec = false; - int fused_node_count = 0; - - for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) { - const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID; - - if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) { - ggml_tensor * glu = cgraph->nodes[i + 4]; - ggml_tensor * gate_bias_n = glu->src[0]; - ggml_tensor * up_bias_n = glu->src[1]; - - //we don't assume the order for {gate, up}. Instead infer it from the bias tensor - ggml_tensor * gate_n = nullptr; - ggml_tensor * up_n = nullptr; - - if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) { - gate_n = cgraph->nodes[i]; - up_n = cgraph->nodes[i + 2]; - } else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) { - gate_n = cgraph->nodes[i + 2]; - up_n = cgraph->nodes[i]; - } else { - continue; - } - - auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) { - if (op_bias == GGML_OP_ADD) { - if (bias_node->src[0] == mul_node) { - return bias_node->src[1]; - } - if (bias_node->src[1] == mul_node) { - return bias_node->src[0]; - } - return (ggml_tensor *) nullptr; - } - GGML_ASSERT(op_bias == GGML_OP_ADD_ID); - GGML_ASSERT(bias_node->src[0] == mul_node); - return bias_node->src[1]; - }; - - ggml_tensor * up_bias_tensor = get_bias_tensor(up_bias_n, up_n, bias_op); - ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op); - - if (!up_bias_tensor || !gate_bias_tensor) { - continue; - } + int nodes_to_skip = ggml_cuda_try_fuse(cuda_ctx, cgraph, i); - // we don't support repeating adds - if (bias_op == GGML_OP_ADD && - (!ggml_are_same_shape(gate_bias_n->src[0], gate_bias_n->src[1]) || - !ggml_are_same_shape(up_bias_n->src[0], up_bias_n->src[1]))) { - continue; - } - - const ggml_tensor * src0 = up_n->src[0]; - const ggml_tensor * src1 = up_n->src[1]; - const ggml_tensor * ids = up_n->src[2]; - - if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) { - ggml_cuda_mm_fusion_args_host fusion_data{}; - fusion_data.gate = gate_n->src[0]; - fusion_data.x_bias = up_bias_tensor; - fusion_data.gate_bias = gate_bias_tensor; - fusion_data.glu_op = ggml_get_glu_op(glu); - - ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data); - fused_mul_mat_vec = true; - fused_node_count = 5; - break; - } - - if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) { - ggml_cuda_mm_fusion_args_host fusion_data{}; - fusion_data.gate = gate_n->src[0]; - fusion_data.x_bias = up_bias_tensor; - fusion_data.gate_bias = gate_bias_tensor; - fusion_data.glu_op = ggml_get_glu_op(glu); - - ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data); - fused_mul_mat_vec = true; - fused_node_count = 5; - break; - } - } else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) { - ggml_tensor * glu = cgraph->nodes[i + 2]; - ggml_tensor * gate = glu->src[0]; - ggml_tensor * up = glu->src[1]; - - bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1]) - || (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]); - - if (!ok) continue; - - const ggml_tensor * src0 = up->src[0]; - const ggml_tensor * src1 = up->src[1]; - const ggml_tensor * ids = up->src[2]; - - if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) { - ggml_cuda_mm_fusion_args_host fusion_data{}; - fusion_data.gate = gate->src[0]; - fusion_data.glu_op = ggml_get_glu_op(glu); - - ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data); - fused_mul_mat_vec = true; - fused_node_count = 3; - break; - } - - if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) { - ggml_cuda_mm_fusion_args_host fusion_data{}; - fusion_data.gate = gate->src[0]; - fusion_data.glu_op = ggml_get_glu_op(glu); - - ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data); - fused_mul_mat_vec = true; - fused_node_count = 3; - break; - } - } - } - - if (fused_mul_mat_vec) { - i += fused_node_count - 1; - continue; - } - - fused_mul_mat_vec = false; - fused_node_count = 0; - - for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) { - const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID; - - if (!ggml_can_fuse(cgraph, i, { op, bias_op })) { - continue; - } - - ggml_tensor * mm_node = cgraph->nodes[i]; - ggml_tensor * bias_node = cgraph->nodes[i + 1]; - - ggml_tensor * bias_tensor = nullptr; - if (bias_op == GGML_OP_ADD) { - if (bias_node->src[0] == mm_node) { - bias_tensor = bias_node->src[1]; - } else if (bias_node->src[1] == mm_node) { - bias_tensor = bias_node->src[0]; - } else { - continue; - } - } else { - if (bias_node->src[0] != mm_node) { - continue; - } - bias_tensor = bias_node->src[1]; - } - - const ggml_tensor * src0 = mm_node->src[0]; - const ggml_tensor * src1 = mm_node->src[1]; - const ggml_tensor * ids = mm_node->src[2]; - - if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) { - continue; - } - - if (bias_op == GGML_OP_ADD && !ggml_are_same_shape(bias_node->src[0], bias_node->src[1])) { - continue; - } - - ggml_cuda_mm_fusion_args_host fusion_data{}; - fusion_data.x_bias = bias_tensor; - - if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) { - ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data); - fused_mul_mat_vec = true; - fused_node_count = 2; - break; - } - - if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) { - ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data); - fused_mul_mat_vec = true; - fused_node_count = 2; - break; - } - } - - if (fused_mul_mat_vec) { - i += fused_node_count - 1; - continue; - } - - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) { - ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]); - i += 2; - continue; - } - - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) { - ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]); - i++; - continue; - } - - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) { - ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i+1]); - i++; - continue; - } - - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SILU }) || - ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SIGMOID }) || - ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SOFTPLUS })) { - ggml_cuda_op_unary_mul(*cuda_ctx, node, cgraph->nodes[i+1]); - i++; - continue; - } - - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) { - i += 2; - ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node); - continue; - } + if (nodes_to_skip != 0) { + i += nodes_to_skip; + continue; } #ifndef NDEBUG assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); @@ -4425,6 +4588,8 @@ static const ggml_backend_i ggml_backend_cuda_interface = { /* .free = */ ggml_backend_cuda_free, /* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async, /* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async, + /* .set_tensor_2d_async = */ ggml_backend_cuda_set_tensor_2d_async, + /* .get_tensor_2d_async = */ ggml_backend_cuda_get_tensor_2d_async, /* .cpy_tensor_async = */ ggml_backend_cuda_cpy_tensor_async, /* .synchronize = */ ggml_backend_cuda_synchronize, /* .graph_plan_create = */ NULL, @@ -4775,12 +4940,14 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g switch (a->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -4811,6 +4978,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_TYPE_F32: case GGML_TYPE_BF16: case GGML_TYPE_I32: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -5211,6 +5379,15 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t static void * ggml_backend_cuda_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) { GGML_UNUSED(reg); + if (strcmp(name, "ggml_backend_comm_init") == 0) { + return (void *)ggml_backend_cuda_comm_init; + } + if (strcmp(name, "ggml_backend_comm_free") == 0) { + return (void *)ggml_backend_cuda_comm_free; + } + if (strcmp(name, "ggml_backend_comm_allreduce_tensor") == 0) { + return (void *)ggml_backend_cuda_comm_allreduce_tensor; + } if (strcmp(name, "ggml_backend_split_buffer_type") == 0) { return (void *)ggml_backend_cuda_split_buffer_type; } diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 5d1dadd3e4f..79bb2934c5f 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -86,17 +86,12 @@ namespace ggml_cuda_mma { // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR - static constexpr bool is_i_major(const data_layout dl) { - return dl == DATA_LAYOUT_I_MAJOR || - dl == DATA_LAYOUT_I_MAJOR_MIRRORED; - } - static constexpr __device__ data_layout get_input_data_layout() { -#if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(RDNA3) || defined(VOLTA_MMA_AVAILABLE) return DATA_LAYOUT_I_MAJOR_MIRRORED; #else return DATA_LAYOUT_I_MAJOR; -#endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // defined(RDNA3) || defined(VOLTA_MMA_AVAILABLE) } template @@ -113,7 +108,6 @@ namespace ggml_cuda_mma { T x[ne] = {0}; static constexpr __device__ bool supported() { - if (I == 64 && J == 2) return true; if (I == 16 && J == 8) return true; if (I == 32 && J == 4) return true; if (I == 16 && J == 16) return true; @@ -122,7 +116,7 @@ namespace ggml_cuda_mma { } static __device__ __forceinline__ int get_i(const int l) { - if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> + if constexpr (I == 16 && J == 4) { return threadIdx.x % 16; } else if constexpr (I == 16 && J == 8) { return threadIdx.x % 16; @@ -139,8 +133,8 @@ namespace ggml_cuda_mma { } static __device__ __forceinline__ int get_j(const int l) { - if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> - return (2 * ((threadIdx.x / 16) % 2) + l); + if constexpr (I == 16 && J == 4) { + return threadIdx.x / 16; } else if constexpr (I == 16 && J == 8) { return 2 * (threadIdx.x / 16) + l; } else if constexpr (I == 32 && J == 4) { @@ -154,7 +148,7 @@ namespace ggml_cuda_mma { return -1; } } -#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#elif defined(VOLTA_MMA_AVAILABLE) static constexpr int ne = I * J / 32; T x[ne] = {0}; @@ -283,7 +277,7 @@ namespace ggml_cuda_mma { static constexpr int J = J_; static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(VOLTA_MMA_AVAILABLE) static constexpr int ne = I * J / WARP_SIZE; half2 x[ne] = {{0.0f, 0.0f}}; @@ -407,7 +401,7 @@ namespace ggml_cuda_mma { return -1; } } -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // defined(VOLTA_MMA_AVAILABLE) }; template @@ -701,57 +695,12 @@ namespace ggml_cuda_mma { } #endif // defined(TURING_MMA_AVAILABLE) - static __device__ __forceinline__ void make_identity_mat(tile<16, 8, half2> & t) { -#if defined(RDNA4) - const int row = t.get_i(0); - const int left_right = t.get_j(0) / 4; - const int up_down = row / 8; - const int idx = row % 8; - reinterpret_cast(t.x)[idx] = left_right == up_down ? 1.0f : 0.0f; -#else - GGML_UNUSED_VARS(t); - NO_DEVICE_CODE; -#endif // defined(RDNA4) - } - template static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) { -#if defined(AMD_MFMA_AVAILABLE) - if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> -#pragma unroll - for (int l = 0; l < t.ne; ++l) { - t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; - } - } else { - ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); - } -#elif defined(AMD_WMMA_AVAILABLE) - // All wmma layout has contiguous data when i-major. - if constexpr (is_i_major(dl)) { - // the data must be aligned to 16 bytes when bigger than ggml_cuda_get_max_cpy_bytes() - constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes(); - if constexpr (sizeof(t.x) > aligned_copy_bytes) { - static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size"); - constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes; -#pragma unroll - for (int i = 0; i < aligned_copy_count; ++i) { - ggml_cuda_memcpy_1(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i)); - } - } else { - ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); - } - } else { -#pragma unroll - for (int l = 0; l < t.ne; ++l) { - t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; - } - } -#else #pragma unroll for (int l = 0; l < t.ne; ++l) { t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; } -#endif // defined(AMD_MFMA_AVAILABLE) } template @@ -764,26 +713,37 @@ namespace ggml_cuda_mma { : "=r"(xi[0]), "=r"(xi[1]) : "l"(xs)); #else - load_generic(t, xs0, stride); + GGML_UNUSED_VARS(t, xs0, stride); + NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } - template + template static __device__ __forceinline__ void load_ldmatrix( - tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) { + tile<16, 4, T, dl> & t, const T * __restrict__ xs0, const int stride) { #ifdef TURING_MMA_AVAILABLE int * xi = (int *) t.x; const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride; asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" : "=r"(xi[0]), "=r"(xi[1]) : "l"(xs)); +#elif defined(AMD_WMMA_AVAILABLE) +#ifdef RDNA3 + static_assert(dl == DATA_LAYOUT_I_MAJOR_MIRRORED, "bad data layout"); + static_assert(sizeof(t.x) == 16, "bad ne"); + ggml_cuda_memcpy_1<8>(t.x + 0, xs0 + t.get_i(0)*stride + 0); + ggml_cuda_memcpy_1<8>(t.x + 2, xs0 + t.get_i(0)*stride + 2); +#else + static_assert(dl == DATA_LAYOUT_I_MAJOR, "bad data layout"); + static_assert(sizeof(t.x) == 8, "bad ne"); + ggml_cuda_memcpy_1<8>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0)); +#endif // RDNA3 +#elif defined(AMD_MFMA_AVAILABLE) + static_assert(sizeof(t.x) == 4, "bad ne"); + ggml_cuda_memcpy_1<4>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0)); #else -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA GGML_UNUSED_VARS(t, xs0, stride); NO_DEVICE_CODE; -#else - load_generic(t, xs0, stride); -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA #endif // TURING_MMA_AVAILABLE } @@ -796,19 +756,26 @@ namespace ggml_cuda_mma { asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3]) : "l"(xs)); -#else -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA -#if 1 - // TODO: more generic handling - static_assert(sizeof(T) == 4, "bad type size"); +#elif defined(VOLTA_MMA_AVAILABLE) ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0); ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4); +#elif defined(AMD_WMMA_AVAILABLE) +#ifdef RDNA3 + static_assert(dl == DATA_LAYOUT_I_MAJOR_MIRRORED, "bad data layout"); + static_assert(sizeof(t.x) == 32, "bad ne"); + ggml_cuda_memcpy_1<16>(t.x + 0, xs0 + t.get_i(0)*stride + 0); + ggml_cuda_memcpy_1<16>(t.x + 4, xs0 + t.get_i(0)*stride + 4); #else - load_generic(t, xs0, stride); -#endif // 1 + static_assert(dl == DATA_LAYOUT_I_MAJOR, "bad data layout"); + static_assert(sizeof(t.x) == 16, "bad ne"); + ggml_cuda_memcpy_1<16>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0)); +#endif // RDNA3 +#elif defined(AMD_MFMA_AVAILABLE) + static_assert(sizeof(t.x) == 8, "bad ne"); + ggml_cuda_memcpy_1<8>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0)); #else - load_generic(t, xs0, stride); -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + GGML_UNUSED_VARS(t, xs0, stride); + NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } @@ -827,23 +794,30 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ void load_ldmatrix( tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(VOLTA_MMA_AVAILABLE) ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride); #else GGML_UNUSED_VARS(t, xs0, stride); NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // defined(VOLTA_MMA_AVAILABLE) } template static __device__ __forceinline__ void load_ldmatrix_trans( tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { #ifdef TURING_MMA_AVAILABLE - int * xi = (int * ) t.x; + int * xi = (int *) t.x; const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];" : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3]) : "l"(xs)); +#elif defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + half * xh = (half *) t.x; +#pragma unroll + for (int l = 0; l < t.ne; ++l) { + xh[2*l + 0] = ((const half *) xs0)[(2*t.get_j(l) + 0)*(2*stride) + t.get_i(l)]; + xh[2*l + 1] = ((const half *) xs0)[(2*t.get_j(l) + 1)*(2*stride) + t.get_i(l)]; + } #else GGML_UNUSED_VARS(t, xs0, stride); NO_DEVICE_CODE; @@ -1025,7 +999,8 @@ namespace ggml_cuda_mma { const floatx2_t& a_frag = reinterpret_cast(A.x[0]); const floatx2_t& b_frag = reinterpret_cast(B.x[0]); acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0); -#elif defined(CDNA2) || defined(CDNA1) +#elif defined(CDNA4) || defined(CDNA2) || defined(CDNA1) + // CDNA4 (gfx950) does not support xf32 MFMA, use f32 path like CDNA2/CDNA1 #pragma unroll for (int i = 0; i < 2; ++i) { acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(A.x[i], B.x[i], acc_frag, 0, 0, 0); @@ -1040,25 +1015,35 @@ namespace ggml_cuda_mma { #endif // AMD_MFMA_AVAILABLE } - static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D, - const tile<16, 8, int> & A, - const tile<8, 8, int> & B, - uint32_t a_scale, - uint32_t b_scale) { + template + static __device__ __forceinline__ void mma_block_scaled_fp4(tile<16, 8, float> & D, + const tile<16, 8, int> & A, + const tile<8, 8, int> & B, + uint32_t a_scale, + uint32_t b_scale) { #ifdef BLACKWELL_MMA_AVAILABLE const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; float * Dxi = (float *) D.x; - asm volatile( - "mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 " - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, " - "%10, {0, 0}, %11, {0, 0};" - : "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3]) - : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale)); + if constexpr (type == GGML_TYPE_MXFP4) { + asm volatile( + "mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, " + "%10, {0, 0}, %11, {0, 0};" + : "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale)); + } else { + asm volatile( + "mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::4X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue4m3 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, " + "%10, {0, 0}, %11, {0, 0};" + : "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale)); + } #else GGML_UNUSED_VARS(D, A, B, a_scale, b_scale); -#endif // BLACKWELL_MMA_AVAILABLE +#endif // BLACKWELL_MMA_AVAILABLE } static __device__ __forceinline__ void mma( @@ -1187,7 +1172,7 @@ namespace ggml_cuda_mma { #elif defined(AMD_MFMA_AVAILABLE) using floatx4_t = __attribute__((ext_vector_type(4))) float; floatx4_t& acc_frag = reinterpret_cast(D.x[0]); -#if defined(CDNA3) || defined(CDNA2) +#if defined(CDNA4) || defined(CDNA3) || defined(CDNA2) using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16; const bf16x4_t& a_frag = reinterpret_cast(A.x[0]); const bf16x4_t& b_frag = reinterpret_cast(B.x[0]); @@ -1216,74 +1201,28 @@ namespace ggml_cuda_mma { #if defined(AMD_MFMA_AVAILABLE) using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; int32x4_t * acc = (int32x4_t *) D.x; -#if defined(CDNA3) - acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0], - ((int64_t *) B.x)[0], - acc[0], - 0, 0, 0); -#elif defined(CDNA2) || defined(CDNA) - acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], - B.x[0], - acc[0], - 0, 0, 0); - acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1], - B.x[1], - acc[0], - 0, 0, 0); -#endif // defined(CDNA3) - +#if defined(CDNA4) || defined(CDNA3) + acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0], ((int64_t *) B.x)[0], acc[0], 0, 0, 0); +#elif defined(CDNA2) || defined(CDNA1) + acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], B.x[0], acc[0], 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1], B.x[1], acc[0], 0, 0, 0); +#endif // defined(CDNA4) || defined(CDNA3) #elif defined(AMD_WMMA_AVAILABLE) - using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; int32x8_t * acc = (int32x8_t *) D.x; - #if defined(RDNA4) using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; int32x2_t * a_vec = (int32x2_t *) A.x; int32x2_t * b_vec = (int32x2_t *) B.x; - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( - true, - a_vec[0], - true, - b_vec[0], - acc[0], - true - ); - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( - true, - a_vec[1], - true, - b_vec[1], - acc[0], - true - ); - + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a_vec[0], true, b_vec[0], acc[0], true); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a_vec[1], true, b_vec[1], acc[0], true); #elif defined(RDNA3) using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; int32x4_t * a_vec = (int32x4_t *) A.x; int32x4_t * b_vec = (int32x4_t *) B.x; - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( - true, - a_vec[0], - true, - b_vec[0], - acc[0], - true - ); - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( - true, - a_vec[1], - true, - b_vec[1], - acc[0], - true - ); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, a_vec[0], true, b_vec[0], acc[0], true); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, a_vec[1], true, b_vec[1], acc[0], true); #endif // RDNA4 - #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; @@ -1295,21 +1234,12 @@ namespace ggml_cuda_mma { #if defined(AMD_MFMA_AVAILABLE) using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int; int32x16_t * acc = (int32x16_t *) D.x; -#if defined(CDNA3) - acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0], - ((int64_t *) B.x)[0], - acc[0], - 0, 0, 0); -#elif defined(CDNA2) || defined(CDNA) - acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0], - B.x[0], - acc[0], - 0, 0, 0); - acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1], - B.x[1], - acc[0], - 0, 0, 0); -#endif // defined(CDNA3) +#if defined(CDNA4) || defined(CDNA3) + acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0], ((int64_t *) B.x)[0], acc[0], 0, 0, 0); +#elif defined(CDNA2) || defined(CDNA1) + acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0], B.x[0], acc[0], 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1], B.x[1], acc[0], 0, 0, 0); +#endif // defined(CDNA4) || defined(CDNA3) #else GGML_UNUSED_VARS(D, A, B); @@ -1328,7 +1258,7 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ void mma( tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(VOLTA_MMA_AVAILABLE) const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; int * Dxi = (int *) D.x; @@ -1343,12 +1273,12 @@ namespace ggml_cuda_mma { #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#endif // defined(VOLTA_MMA_AVAILABLE) } static __device__ __forceinline__ void mma( tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(VOLTA_MMA_AVAILABLE) const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; int * Dxi = (int *) D.x; @@ -1363,41 +1293,35 @@ namespace ggml_cuda_mma { #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#endif // defined(VOLTA_MMA_AVAILABLE) } template static __device__ __forceinline__ void mma( tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) { -#if defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) + using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; + int32x4_t * acc = (int32x4_t *) D.x; +#if defined(CDNA4) || defined(CDNA3) + const int64_t xA = uint32_t(A.x[0]); + const int64_t xB = uint32_t(B.x[0]); + acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(xA, xB, acc[0], 0, 0, 0); +#elif defined(CDNA2) || defined(CDNA1) + acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], B.x[0], acc[0], 0, 0, 0); +#endif // defined(CDNA4) || defined(CDNA3) +#elif defined(AMD_WMMA_AVAILABLE) using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; int32x8_t * acc = (int32x8_t *) D.x; #if defined(RDNA4) using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; int32x2_t * a_vec = (int32x2_t *) A.x; int32x2_t * b_vec = (int32x2_t *) B.x; - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( - true, - a_vec[0], - true, - b_vec[0], - acc[0], - false - ); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a_vec[0], true, b_vec[0], acc[0], false); #elif defined(RDNA3) using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; int32x4_t * a_vec = (int32x4_t *) A.x; int32x4_t * b_vec = (int32x4_t *) B.x; - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( - true, - a_vec[0], - true, - b_vec[0], - acc[0], - false - ); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, a_vec[0], true, b_vec[0], acc[0], false); #endif // RDNA4 #else GGML_UNUSED(D); diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 9a69f41d159..e1add5e0331 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -5,6 +5,9 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { switch (args.type_x) { + case GGML_TYPE_Q1_0: + mul_mat_q_case(ctx, args, stream); + break; case GGML_TYPE_Q4_0: mul_mat_q_case(ctx, args, stream); break; @@ -23,6 +26,9 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, con case GGML_TYPE_MXFP4: mul_mat_q_case(ctx, args, stream); break; + case GGML_TYPE_NVFP4: + mul_mat_q_case(ctx, args, stream); + break; case GGML_TYPE_Q2_K: mul_mat_q_case(ctx, args, stream); break; @@ -116,7 +122,7 @@ void ggml_cuda_mul_mat_q( || GGML_CUDA_CC_IS_CDNA(cc); // TODO: tighter pool buffer size vs q8 path - const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4; + const bool use_native_fp4 = blackwell_mma_available(cc) && (src0->type == GGML_TYPE_MXFP4 || src0->type == GGML_TYPE_NVFP4); if (!ids) { const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 + @@ -127,9 +133,9 @@ void ggml_cuda_mul_mat_q( const int64_t s11 = src1->nb[1] / ts_src1; const int64_t s12 = src1->nb[2] / ts_src1; const int64_t s13 = src1->nb[3] / ts_src1; - if (use_native_mxfp4) { + if (use_native_fp4) { static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1)); - quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, + quantize_mmq_fp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream); } else { @@ -140,10 +146,8 @@ void ggml_cuda_mul_mat_q( } // Stride depends on quantization format - const int64_t s12 = use_native_mxfp4 ? - ne11 * ne10_padded * sizeof(block_fp4_mmq) / - (8 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 256 values (8 blocks of 32) - : + const int64_t s12 = use_native_fp4 ? + ne11 * ne10_padded * sizeof(block_fp4_mmq) / (QK_K * sizeof(int)) : // block_fp4_mmq holds 256 values ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int)); const int64_t s13 = ne12*s12; @@ -192,8 +196,8 @@ void ggml_cuda_mul_mat_q( const int64_t s12 = src1->nb[2] / ts_src1; const int64_t s13 = src1->nb[3] / ts_src1; - if (use_native_mxfp4) { - quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13, + if (use_native_fp4) { + quantize_mmq_fp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream); } else { quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13, @@ -202,8 +206,9 @@ void ggml_cuda_mul_mat_q( CUDA_CHECK(cudaGetLastError()); } - const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) : - ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int)); + static_assert(QK_K == 8 * QK_MXFP4, "QK_K needs to be 8 * QK_MXFP4"); + const int64_t s12 = use_native_fp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (QK_K * sizeof(int)) : + ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int)); const int64_t s13 = ne12*s12; // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid. @@ -267,12 +272,14 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t bool mmq_supported; switch (type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -362,5 +369,4 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t } return (!GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; - } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 255e59f6fc6..edf546d8f1e 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -10,9 +10,9 @@ using namespace ggml_cuda_mma; #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available. -#define MMQ_ITER_K 256 -#define MMQ_ITER_K_MXFP4_FP4 512 -#define MMQ_NWARPS 8 +#define MMQ_ITER_K 256 +#define MMQ_ITER_K_FP4 512 +#define MMQ_NWARPS 8 typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride); typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00); @@ -46,9 +46,12 @@ struct block_q8_1_mmq { int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each }; +// this struct is used for fp4 data types (currently only used for Blackwell) +// mxfp4 has block size 32, each int32 of d4 contains 2 e8m0 scales in the lower 16 bits +// nvfp4 has block size 16, each int32 of d4 contains 4 ue4m3 scales struct block_fp4_mmq { - uint32_t d4[4]; // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc. - int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values + uint32_t d4[4]; + int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte) }; static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size"); @@ -57,6 +60,8 @@ static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected b static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { switch (type_x) { + case GGML_TYPE_Q1_0: + return MMQ_Q8_1_DS_LAYOUT_D4; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: return MMQ_Q8_1_DS_LAYOUT_DS4; @@ -68,6 +73,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { return MMQ_Q8_1_DS_LAYOUT_D4; case GGML_TYPE_MXFP4: return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_NVFP4: + return MMQ_Q8_1_DS_LAYOUT_D4; case GGML_TYPE_Q2_K: return MMQ_Q8_1_DS_LAYOUT_D2S6; case GGML_TYPE_Q3_K: @@ -100,7 +107,7 @@ struct tile_x_sizes { }; static int get_mmq_x_max_host(const int cc) { - return (amd_mfma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 : + return (turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 : GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? #ifdef GGML_CUDA_FORCE_MMQ 128 : 64; @@ -110,9 +117,9 @@ static int get_mmq_x_max_host(const int cc) { } static constexpr __device__ int get_mmq_x_max_device() { -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) +#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) return 128; -#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#else // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) #if defined(GGML_USE_HIP) return 64; @@ -139,10 +146,11 @@ static int get_mmq_y_host(const int cc) { static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) { #if defined(BLACKWELL_MMA_AVAILABLE) - return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K; -#else - return MMQ_ITER_K; +if (type == GGML_TYPE_NVFP4 || type == GGML_TYPE_MXFP4) { + return MMQ_ITER_K_FP4; +} #endif // defined(BLACKWELL_MMA_AVAILABLE) + return MMQ_ITER_K; } static constexpr __device__ int get_mmq_y_device() { @@ -183,12 +191,14 @@ static constexpr __device__ int get_mmq_y_device() { static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) { switch (type) { + case GGML_TYPE_Q1_0: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0; case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1; case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1; case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1; + case GGML_TYPE_NVFP4: return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K; case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K; case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K; @@ -206,12 +216,13 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml } } -#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) -#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) -#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) -#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4) -#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) -#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7) +#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) // MXFP4 and NVFP4 Blackwell +#define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) // NVFP4 Generic +#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4) +#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) +#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7) static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding."); @@ -220,9 +231,12 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4"); +static_assert(MMQ_MMA_TILE_X_K_NVFP4 % 8 == 4, "Wrong padding."); + static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { switch (type) { + case GGML_TYPE_Q1_0: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1; case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0; @@ -230,6 +244,11 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0; // tile sizes are the same for Q8_1 and FP4 for blackwell case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1; +#if defined(BLACKWELL_MMA_AVAILABLE) + case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_FP4; +#else + case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_NVFP4; +#endif // defined(BLACKWELL_MMA_AVAILABLE) case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K; case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1; @@ -295,6 +314,87 @@ static constexpr __device__ int mmq_get_nwarps_device() { // ------------------------------------------------------------ +template static __device__ __forceinline__ void load_tiles_q1_0( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + + constexpr int blocks_per_iter = MMQ_ITER_K / QK1_0; + constexpr int threads_per_row = blocks_per_iter * QI1_0; + constexpr int nrows = warp_size / threads_per_row; + constexpr int scale_entries_per_block = QK1_0 / QK8_1; + constexpr int scale_entries_per_row = blocks_per_iter * scale_entries_per_block; + + const int txi = threadIdx.x % threads_per_row; + const int kbx = txi / QI1_0; + const int kqsx = txi % QI1_0; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + kbx; + const int qs_offset = 4*kqsx; + const int qs0 = bxi->qs[qs_offset + 0] | (bxi->qs[qs_offset + 1] << 8) | + (bxi->qs[qs_offset + 2] << 16) | (bxi->qs[qs_offset + 3] << 24); + + int unpacked_bytes[8]; +#pragma unroll + for (int j = 0; j < 8; ++j) { + const int shift = j * 4; + const int bits4 = (qs0 >> shift) & 0x0F; + const int b0 = (bits4 & 0x01) ? 1 : -1; + const int b1 = (bits4 & 0x02) ? 1 : -1; + const int b2 = (bits4 & 0x04) ? 1 : -1; + const int b3 = (bits4 & 0x08) ? 1 : -1; + unpacked_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24); + } + + const int dst_offset = kbx*(scale_entries_per_block*QI8_0) + kqsx*QI8_0; +#pragma unroll + for (int j = 0; j < 8; ++j) { +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + dst_offset + j] = unpacked_bytes[j]; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + dst_offset + j] = unpacked_bytes[j]; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } + } + + const int ksx = threadIdx.x % scale_entries_per_row; + const int scale_block = ksx / scale_entries_per_block; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + scale_block; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + ksx] = bxi->d; +#else + x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + ksx] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } +} + template static __device__ __forceinline__ void load_tiles_q4_0( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { constexpr int nwarps = mmq_get_nwarps_device(); @@ -379,17 +479,25 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); int u[2*VDR_Q4_0_Q8_1_MMQ]; -#pragma unroll - for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; - u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)]; + constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes(); + constexpr int mcpy_int = max_cpy / sizeof(int); + static_assert(VDR_Q4_0_Q8_1_MMQ == 4, "bad VDR_Q4_0_Q8_1_MMQ"); + + int tmp0[4], tmp1[4]; + + #pragma unroll + for (int l0 = 0; l0 < 4 / mcpy_int; ++l0) { + ggml_cuda_memcpy_1(tmp0 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + l0 * mcpy_int] ); + ggml_cuda_memcpy_1(tmp1 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_0 + l0 * mcpy_int]); } + u[0]=tmp0[0]; u[2]=tmp0[1]; u[4]=tmp0[2]; u[6]=tmp0[3]; + u[1]=tmp1[0]; u[3]=tmp1[1]; u[5]=tmp1[2]; u[7]=tmp1[3]; + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u, x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); @@ -482,17 +590,25 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); int u[2*VDR_Q4_1_Q8_1_MMQ]; -#pragma unroll - for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; - u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)]; + constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes(); + constexpr int mcpy_int = max_cpy / sizeof(int); + static_assert(VDR_Q4_0_Q8_1_MMQ == 4, "bad VDR_Q4_0_Q8_1_MMQ"); + + int tmp0[4], tmp1[4]; + + #pragma unroll + for (int l0 = 0; l0 < 4 / mcpy_int; ++l0) { + ggml_cuda_memcpy_1(tmp0 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + l0 * mcpy_int] ); + ggml_cuda_memcpy_1(tmp1 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_1 + l0 * mcpy_int]); } + u[0]=tmp0[0]; u[2]=tmp0[1]; u[4]=tmp0[2]; u[6]=tmp0[3]; + u[1]=tmp1[0]; u[3]=tmp1[1]; u[5]=tmp1[2]; u[7]=tmp1[3]; + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u, x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); @@ -826,6 +942,187 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr } } +#ifdef BLACKWELL_MMA_AVAILABLE +template +static __device__ __forceinline__ void load_tiles_nvfp4_nvfp4(const char * __restrict__ x, + int * __restrict__ x_tile, + const int kbx0, + const int i_max, + const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + constexpr int iter_k = get_iter_k(GGML_TYPE_NVFP4); + constexpr int threads_per_row = iter_k / QK_NVFP4; // each thread processes 1 block + constexpr int rows_per_warp = warp_size / threads_per_row; + + uint32_t * x_u32 = (uint32_t *) x_tile; + + const int txi = threadIdx.x; + const int kbx = txi % threads_per_row; + const int row_in_warp = txi / threads_per_row; + + const block_nvfp4 * bxi_base = (const block_nvfp4 *) x + kbx0 + kbx; + uint32_t * x_u32_scale = x_u32 + 64 + kbx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) { + int i = i0 + threadIdx.y * rows_per_warp + row_in_warp; + + if constexpr (need_check) { + i = min(i, i_max); + } + + const block_nvfp4 * bxi = bxi_base + i * stride; + const int row_base = i * MMQ_MMA_TILE_X_K_FP4; + const int q_base = row_base + 8 * kbx; + + const uint32_t * src_qs = reinterpret_cast(bxi->qs); + +#pragma unroll + for (int sub = 0; sub < QK_NVFP4 / QK_NVFP4_SUB; ++sub) { + x_u32[q_base + 2 * sub + 0] = src_qs[2 * sub + 0]; + x_u32[q_base + 2 * sub + 1] = src_qs[2 * sub + 1]; + } + + x_u32_scale[row_base] = get_int_b4(bxi->d, 0); + } +} + +// Shared MMA kernel for MXFP4 and NVFP4 on Blackwell. +// Both quantizations encode values as e2m1 (FP4) and produce one uint32 scale per +// m16n8k64 MMA call; only the PTX kind (scale_vec::2X ue8m0 vs scale_vec::4X ue4m3) +// and the per-type stride constant differ. +template +static __device__ __forceinline__ void vec_dot_fp4_fp4_mma(const int * __restrict__ x, + const int * __restrict__ y, + float * __restrict__ sum, + const int k00) { + static_assert(type == GGML_TYPE_MXFP4 || type == GGML_TYPE_NVFP4, + "vec_dot_fp4_fp4_mma: type must be MXFP4 or NVFP4"); + + typedef tile<16, 8, int> tile_A; + typedef tile<8, 8, int> tile_B; + typedef tile<16, 8, float> tile_C; + + constexpr int stride = MMQ_MMA_TILE_X_K_FP4; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp / tile_C::I; + constexpr int nfrags = MMQ_TILE_NE_K / tile_A::J; + + y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K); + const int * y_qs = (const int *) y + 4; + const uint32_t * y_sc = (const uint32_t *) y; + + // 2 threads per quad supply the packed scale register to the block_scale MMA, + // see https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling + const int tidx_A = threadIdx.x / 4 + (threadIdx.x % 2) * 8; + const int tidx_B = threadIdx.x / 4; + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + tile_A A[ntx][nfrags]; + uint32_t scaleA[ntx][nfrags]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int frag = 0; frag < nfrags; ++frag) { + const int k0 = k00 + frag * tile_A::J; + load_ldmatrix(A[n][frag], x_qs + (i0 + n * tile_A::I) * stride + k0, stride); + scaleA[n][frag] = x_sc[(i0 + n * tile_A::I + tidx_A) * stride + k0 / tile_A::J]; + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) { + tile_B B[nfrags]; + uint32_t scaleB[nfrags]; + +#pragma unroll + for (int frag = 0; frag < nfrags; ++frag) { + const int k0 = frag * tile_B::J; + load_generic(B[frag], y_qs + j0 * MMQ_TILE_Y_K + k0, MMQ_TILE_Y_K); + scaleB[frag] = y_sc[(j0 + tidx_B) * MMQ_TILE_Y_K + frag]; + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int frag = 0; frag < nfrags; ++frag) { + tile_C C = {}; + mma_block_scaled_fp4(C, A[n][frag], B[frag], scaleA[n][frag], scaleB[frag]); +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l]; + } + } + } + } +} +#endif // BLACKWELL_MMA_AVAILABLE + + +template +static __device__ __forceinline__ void load_tiles_nvfp4(const char * __restrict__ x, + int * __restrict__ x_tile, + const int kb0, + const int i_max, + const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_NVFP4, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / QK_NVFP4; + constexpr int rows_per_warp = warp_size / threads_per_row; + const int kbx = threadIdx.x % threads_per_row; + const int row_in_warp = threadIdx.x / threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) { + int i = i0 + threadIdx.y * rows_per_warp + row_in_warp; + + if constexpr (need_check) { + i = min(i, i_max); + } + + const block_nvfp4 * bxi = (const block_nvfp4 *) x + kb0 + i * stride + kbx; + const uint32_t * __restrict__ src_qs = reinterpret_cast(bxi->qs); + const int kqs = 16 * kbx; + const int ksc = 4 * kbx; + +#pragma unroll + for (int sub = 0; sub < QK_NVFP4 / QK_NVFP4_SUB; ++sub) { + const int2 q0 = get_int_from_table_16(src_qs[2 * sub + 0], kvalues_mxfp4); + const int2 q1 = get_int_from_table_16(src_qs[2 * sub + 1], kvalues_mxfp4); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 0] = q0.x; + x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 1] = q1.x; + x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 2] = q0.y; + x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 3] = q1.y; + x_df[i * MMQ_MMA_TILE_X_K_NVFP4 + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]); +#else + x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 0] = q0.x; + x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 1] = q1.x; + x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 2] = q0.y; + x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 3] = q1.y; + x_df[i * (2 * MMQ_TILE_NE_K * 2 / QI_NVFP4) + i / (QK_NVFP4_SUB / QI_NVFP4) + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } + } +} + template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { @@ -887,13 +1184,13 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); float dB; const int j = j0 + tile_C::get_j(0); @@ -996,77 +1293,6 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } -template -static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x, - const int * __restrict__ y, - float * __restrict__ sum, - const int k00) { - typedef tile<16, 8, int> tile_A; - typedef tile<8, 8, int> tile_B; - typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp / tile_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K); - - // Match layout from load_tiles_mxfp4_fp4 - const int * x_qs = (const int *) x; - const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K); - const int * y_qs = (const int *) y + 4; - const uint32_t * y_sc = (const uint32_t *) y; - - // tile_A has a length of 64 logical values vs. 32 values in block_mxfp4 - tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)]; - uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)]; - - // Block scale - // Each thread has to point to a 4 byte scale value - // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { -#pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) { - const int k0 = k00 + k01; - - load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0, - MMQ_MMA_TILE_X_K_FP4); - - // based on block-scaling document, 2 threads in each quad need to supply to the scale value - const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8; - scaleA[n][k01 / (2 * QI_MXFP4)] = - *(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4)); - } - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) { -#pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) { - tile_B B; - uint32_t scaleB; // 2xN scales - - load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K); - - scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)]; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C C; - - mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB); -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l]; - } - } - } - } -} template static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( @@ -1128,13 +1354,13 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); const int j = j0 + tile_C::get_j(0); const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]); @@ -1229,7 +1455,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } -// Used for Q3_K, IQ2_S, and IQ2_XS +// Used for NVFP4, Q3_K, IQ2_S, and IQ2_XS template static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { @@ -1268,57 +1494,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) - constexpr data_layout input_layout = get_input_data_layout(); - typedef tile<16, 8, int, input_layout> tile_A; - typedef tile<16, 8, int, input_layout> tile_B; - typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; - typedef tile<64, 2, int, input_layout> tile_load; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { - const int k0 = k00 + k01; - - tile_A A[ntx]; -#pragma unroll - for (int n = 0; n < ntx; ++n) { - load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { - tile_B B[1]; - load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - - const int j = j0 + tile_C::get_j(0); - const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C C; - mma(C, A[n], B[0]); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = i0 + n*tile_C::I + tile_C::get_i(l); - sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB; - } - } - } - } -#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr data_layout input_layout = get_input_data_layout(); typedef tile<16, 4, int, input_layout> tile_A; typedef tile<16, 4, int, input_layout> tile_B; @@ -1343,13 +1519,13 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); const int j = j0 + tile_C::get_j(0); const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; @@ -1575,74 +1751,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) - constexpr data_layout input_layout = get_input_data_layout(); - typedef tile<16, 8, int, input_layout> tile_A; - typedef tile<16, 8, int, input_layout> tile_B; - typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; - typedef tile<64, 2, int, input_layout> tile_load; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2; - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { - const int k0 = k00 + k01; - - tile_A A[ntx]; -#pragma unroll - for (int n = 0; n < ntx; ++n) { - load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { - tile_B B[1]; - load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - - const int j = j0 + tile_C::get_j(0); - const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2; - const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0 - : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y - : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x); - - tile_C Cm; - if (k01 >= MMQ_TILE_NE_K * 3/4) { - tile_A A1; - A1.x[0] = 0x01010101; - A1.x[1] = 0x01010101; - mma(Cm, A1, B[0]); - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C Cd; - mma(Cd, A[n], B[0]); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = i0 + n*tile_C::I + tile_C::get_i(l); - const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]); - float tmp = Cd.x[l]*dm.x; - if (k01 >= MMQ_TILE_NE_K * 3/4) { - tmp -= Cm.x[l]*dm.y; - } - sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB; - sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB; - } - } - } - } -#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr data_layout input_layout = get_input_data_layout(); typedef tile<16, 4, int, input_layout> tile_A; typedef tile<16, 4, int, input_layout> tile_B; @@ -1667,13 +1776,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); const int j = j0 + tile_C::get_j(0); const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y; @@ -2406,59 +2515,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) - constexpr data_layout input_layout = get_input_data_layout(); - typedef tile<16, 8, int, input_layout> tile_A; - typedef tile<16, 8, int, input_layout> tile_B; - typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; - typedef tile<64, 2, int, input_layout> tile_load; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; - const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { - const int k0 = k00 + k01; - - tile_A A[ntx]; -#pragma unroll - for (int n = 0; n < ntx; ++n) { - load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { - tile_B B[1]; - load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - - const int j = j0 + tile_C::get_j(0); - const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C C; - mma(C, A[n], B[0]); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = i0 + n*tile_C::I + tile_C::get_i(l); - const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16); - sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB; - } - } - } - } -#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr data_layout input_layout = get_input_data_layout(); typedef tile<16, 4, int, input_layout> tile_A; typedef tile<16, 4, int, input_layout> tile_B; @@ -2484,13 +2541,13 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); const int j = j0 + tile_C::get_j(0); const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; @@ -3208,6 +3265,14 @@ static __device__ __forceinline__ void mmq_write_back_mma( template struct mmq_type_traits; +template +struct mmq_type_traits { + static constexpr int vdr = VDR_Q1_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q1_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + template struct mmq_type_traits { static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; @@ -3253,7 +3318,7 @@ struct mmq_type_traits { static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ; #ifdef BLACKWELL_MMA_AVAILABLE static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_mxfp4_mxfp4_mma; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_fp4_fp4_mma; #else static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; @@ -3261,6 +3326,19 @@ struct mmq_type_traits { static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; +template +struct mmq_type_traits { + static constexpr int vdr = VDR_NVFP4_Q8_1_MMQ; +#ifdef BLACKWELL_MMA_AVAILABLE + static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4_nvfp4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_fp4_fp4_mma; +#else + static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; +#endif // BLACKWELL_MMA_AVAILABLE + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + template struct mmq_type_traits { static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ; @@ -3392,7 +3470,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile( #if defined(BLACKWELL_MMA_AVAILABLE) // FP4 tile stores 8 blocks - constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1; + constexpr int ne_block = (type == GGML_TYPE_MXFP4 || type == GGML_TYPE_NVFP4) ? QK_K : 4 * QK8_1; #else constexpr int ne_block = 4 * QK8_1; #endif // defined(BLACKWELL_MMA_AVAILABLE) @@ -3464,10 +3542,10 @@ template static __global__ void mul_mat_q( const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst, const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup, - const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst, - const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, - const int ncols_max) { + const uint3 blocks_per_ne00, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst, + const uint3 channel_ratio, const uint3 nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const uint3 sample_ratio, const uint3 nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const uint3 ntx) { // Skip unused template specializations for faster compilation: if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) { @@ -3481,8 +3559,7 @@ static __global__ void mul_mat_q( constexpr int qk = ggml_cuda_type_traits::qk; constexpr int mmq_y = get_mmq_y_device(); - const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x - const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y + const uint32_t nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y // Initialize the ids for writing back data with just the index. // For regular matrix multiplications this is never changed. @@ -3503,8 +3580,9 @@ static __global__ void mul_mat_q( // On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead: #if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA { - const int wt = blockIdx.z / nchannels_y; - const int zt = blockIdx.z - wt*nchannels_y; + const uint2 tmp2 = fast_div_modulo(blockIdx.z, nchannels_y); + const int wt = tmp2.x; + const int zt = tmp2.y; const int jt = blockIdx.y; const int it = blockIdx.x; @@ -3547,40 +3625,40 @@ static __global__ void mul_mat_q( const int tile_x_max_i = nrows_x - it*mmq_y - 1; const int tile_y_max_j = col_diff - jt*mmq_x - 1; - const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; + const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; constexpr bool fixup = false; mul_mat_q_process_tile (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, - tile_x_max_i, tile_y_max_j, 0, ncols_x/qk); + tile_x_max_i, tile_y_max_j, 0, blocks_per_ne00.z); return; } -#endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA - - constexpr int ITER_K = get_iter_k(type); +#endif // (defined(GGML_USE_HIP) && !defined(CDNA4) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA - const int64_t blocks_per_ne00 = ncols_x / qk; - constexpr int blocks_per_iter = ITER_K / qk; + constexpr int ITER_K = get_iter_k(type); + constexpr int blocks_per_iter = ITER_K / qk; // kbc == k block continuous, current index in continuous ijk space. - int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; - int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + int kbc = int64_t(blockIdx.x) *(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x; + int kbc_stop = int64_t(blockIdx.x + 1)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x; - kbc -= (kbc % blocks_per_ne00) % blocks_per_iter; - kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter; + kbc -= fastmodulo(kbc, blocks_per_ne00) % blocks_per_iter; + kbc_stop -= fastmodulo(kbc_stop, blocks_per_ne00) % blocks_per_iter; // kb0 == k index when doing the matrix multiplication for an output tile. - int kb0_start = kbc % blocks_per_ne00; - int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc); - while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) { - int tmp = kbc; - const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); - tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); - const int zt = tmp / (ntx*blocks_per_ne00); - tmp -= zt * (ntx*blocks_per_ne00); - const int jt = tmp / blocks_per_ne00; + int kb0_start = fastmodulo(kbc, blocks_per_ne00); + int kb0_stop = min(blocks_per_ne00.z, uint32_t(kb0_start + kbc_stop - kbc)); + while (kbc < kbc_stop && kb0_stop == int(blocks_per_ne00.z)) { + int tmp = fastdiv(kbc, blocks_per_ne00); + uint2 tmp2 = fast_div_modulo(tmp, ntx); + const int jt = tmp2.y; + tmp = tmp2.x; + tmp2 = fast_div_modulo(tmp, nchannels_y); + const int zt = tmp2.y; + tmp = tmp2.x; + tmp2 = fast_div_modulo(tmp, nsamples_y); + const int wt = tmp2.y; + const int it = tmp2.x; // Defaults for regular matrix multiplication: int col_low = 0; @@ -3598,11 +3676,11 @@ static __global__ void mul_mat_q( offset_dst = 0; if (jt*mmq_x >= col_diff) { - kbc += blocks_per_ne00; - kbc -= kbc % blocks_per_ne00; + kbc += blocks_per_ne00.z; + kbc -= fastmodulo(kbc, blocks_per_ne00); kb0_start = 0; - kb0_stop = min(blocks_per_ne00, kbc_stop - kbc); + kb0_stop = min(blocks_per_ne00.z, uint32_t(kbc_stop - kbc)); continue; } @@ -3627,32 +3705,34 @@ static __global__ void mul_mat_q( const int tile_x_max_i = nrows_x - it*mmq_y - 1; const int tile_y_max_j = col_diff - jt*mmq_x - 1; - const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; + const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. mul_mat_q_process_tile (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); - kbc += blocks_per_ne00; - kbc -= kbc % blocks_per_ne00; + kbc += blocks_per_ne00.z; + kbc -= fastmodulo(kbc, blocks_per_ne00); kb0_start = 0; - kb0_stop = min(blocks_per_ne00, kbc_stop - kbc); + kb0_stop = min(blocks_per_ne00.z, uint32_t(kbc_stop - kbc)); } if (kbc >= kbc_stop) { return; } - int tmp = kbc; - const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); - tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); - const int zt = tmp / (ntx*blocks_per_ne00); - tmp -= zt * (ntx*blocks_per_ne00); - const int jt = tmp / blocks_per_ne00; + int tmp = fastdiv(kbc, blocks_per_ne00); + uint2 tmp2 = fast_div_modulo(tmp, ntx); + const int jt = tmp2.y; + tmp = tmp2.x; + tmp2 = fast_div_modulo(tmp, nchannels_y); + const int zt = tmp2.y; + tmp = tmp2.x; + tmp2 = fast_div_modulo(tmp, nsamples_y); + const int wt = tmp2.y; + const int it = tmp2.x; // Defaults for regular matrix multiplication: int col_low = 0; @@ -3694,7 +3774,7 @@ static __global__ void mul_mat_q( const int tile_x_max_i = nrows_x - it*mmq_y - 1; const int tile_y_max_j = col_diff - jt*mmq_x - 1; - const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; + const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. mul_mat_q_process_tile @@ -3703,46 +3783,37 @@ static __global__ void mul_mat_q( } template -static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst, - const int32_t * expert_bounds, - float * __restrict__ dst, - const float * __restrict__ tmp_last_tile, - const int ncols_x, - const int nrows_x, - const int ncols_dst, - const size_t stride_col_dst, - const int nchannels_y, - const size_t stride_channel_dst, - const int nsamples_y, - const size_t stride_sample_dst, - const int ncols_max) { - constexpr int mmq_y = get_mmq_y_device(); - constexpr int qk = ggml_cuda_type_traits::qk; - constexpr int ITER_K = get_iter_k(type); - - constexpr int blocks_per_iter = ITER_K / qk; - const int64_t blocks_per_ne00 = ncols_x / qk; +__launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device()/2, 1) +static __global__ void mul_mat_q_stream_k_fixup( + const int32_t * __restrict__ ids_dst, const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, + float * __restrict__ tmp_last_tile, const uint3 blocks_per_ne00, const int nrows_x, const int ncols_dst, + const int stride_col_dst, const uint3 nchannels_y, const int stride_channel_dst, const uint3 nsamples_y, + const int stride_sample_dst, const uint3 ntx) { + constexpr int mmq_y = get_mmq_y_device(); + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int ITER_K = get_iter_k(type); + constexpr int blocks_per_iter = ITER_K / qk; - constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int nwarps = mmq_get_nwarps_device()/2; constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f}; + float sum[mmq_x / nwarps] = {0.0f}; + const int i = blockIdx.y*warp_size + threadIdx.x; - const int ntx = (ncols_max + mmq_x - 1) / mmq_x; - const int nty = (nrows_x + mmq_y - 1) / mmq_y; + const int nty = (nrows_x + mmq_y - 1) / mmq_y; const int bidx0 = blockIdx.x; // kbc == k block continuous, current index in continuous ijk space. - int64_t kbc0 = (int64_t) bidx0 *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; - int64_t kbc0_stop = (int64_t)(bidx0 + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + int kbc0 = int64_t(blockIdx.x) *(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x; + int kbc0_stop = int64_t(blockIdx.x + 1)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x; - kbc0 -= (kbc0 % blocks_per_ne00) % blocks_per_iter; - kbc0_stop -= (kbc0_stop % blocks_per_ne00) % blocks_per_iter; + kbc0 -= fastmodulo(kbc0, blocks_per_ne00) % blocks_per_iter; + kbc0_stop -= fastmodulo(kbc0_stop, blocks_per_ne00) % blocks_per_iter; const bool did_not_have_any_data = kbc0 == kbc0_stop; - const bool wrote_beginning_of_tile = kbc0 % blocks_per_ne00 == 0; - const bool did_not_write_last = kbc0/blocks_per_ne00 == kbc0_stop/blocks_per_ne00 && kbc0_stop % blocks_per_ne00 != 0; + const bool wrote_beginning_of_tile = fastmodulo(kbc0, blocks_per_ne00) == 0; + const bool did_not_write_last = fastdiv(kbc0, blocks_per_ne00) == fastdiv(kbc0_stop, blocks_per_ne00) && fastmodulo(kbc0_stop, blocks_per_ne00) != 0; if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) { return; } @@ -3751,11 +3822,11 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst, // Iterate over previous blocks and sum up partial sums written to fixup buffer. // All CUDA blocks that get here must have a previous block that needs a fixup. - int64_t bidx = bidx0 - 1; - int64_t kbc_stop = kbc0; + int bidx = bidx0 - 1; + int kbc_stop = kbc0; while(true) { - int64_t kbc = bidx*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; - kbc -= (kbc % blocks_per_ne00) % blocks_per_iter; + int kbc = int64_t(bidx)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x; + kbc -= fastmodulo(kbc, blocks_per_ne00) % blocks_per_iter; if (kbc == kbc_stop) { // Did not have any data. bidx--; @@ -3765,20 +3836,16 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst, any_fixup = true; + #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i]; - } + sum[j0/nwarps] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i]; } // If this block started in a previous tile we are done and don't need to combine additional partial results. - if (kbc % blocks_per_ne00 == 0 || kbc/blocks_per_ne00 < kbc0/blocks_per_ne00) { + if (fastmodulo(kbc, blocks_per_ne00) == 0 || fastdiv(kbc, blocks_per_ne00) < fastdiv(kbc0, blocks_per_ne00)) { break; } bidx--; @@ -3789,14 +3856,16 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst, return; } - int tmp = kbc0; - const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); - tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); - const int zt = tmp / (ntx*blocks_per_ne00); - tmp -= zt * (ntx*blocks_per_ne00); - const int jt = tmp / blocks_per_ne00; + int tmp = fastdiv(kbc0, blocks_per_ne00); + uint2 tmp2 = fast_div_modulo(tmp, ntx); + const int jt = tmp2.y; + tmp = tmp2.x; + tmp2 = fast_div_modulo(tmp, nchannels_y); + const int zt = tmp2.y; + tmp = tmp2.x; + tmp2 = fast_div_modulo(tmp, nsamples_y); + const int wt = tmp2.y; + const int it = tmp2.x; if (!ids_dst) { const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y; @@ -3804,6 +3873,9 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst, const int i_max = nrows_x - it*mmq_y - 1; const int j_max = ncols_dst - jt*mmq_x - 1; + if (need_check && i > i_max) { + return; + } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { @@ -3813,16 +3885,7 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst, return; } -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - if (need_check && i > i_max) { - continue; - } - - dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; - } + dst[j*stride_col_dst + i] += sum[j0/nwarps]; } return; } @@ -3842,6 +3905,9 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst, const int i_max = nrows_x - it*mmq_y - 1; const int j_max = col_diff - jt*mmq_x - 1; + if (need_check && i > i_max) { + return; + } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { @@ -3851,16 +3917,7 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst, return; } -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - if (need_check && i > i_max) { - continue; - } - - dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; - } + dst[ids_dst_shared[j]*stride_col_dst + i] += sum[j0/nwarps]; } } @@ -3908,29 +3965,44 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a const int channel_ratio = args.nchannels_y / args.nchannels_x; const int sample_ratio = args.nsamples_y / args.nsamples_x; + const uint3 blocks_per_ne00_fd = init_fastdiv_values(args.ncols_x / ggml_cuda_type_traits::qk); + const uint3 ntx_fd = init_fastdiv_values(ntx); + const uint3 nchannels_y_fd = init_fastdiv_values(args.nchannels_y); + const uint3 nsamples_y_fd = init_fastdiv_values(args.nsamples_y); + const uint3 channel_ratio_fd = init_fastdiv_values(channel_ratio); + const uint3 sample_ratio_fd = init_fastdiv_values(sample_ratio); + if (!args.use_stream_k) { if (args.nrows_x % mmq_y == 0) { constexpr bool need_check = false; mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, - args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, - channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, - sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, - args.ncols_max); + blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + ntx_fd); } else { constexpr bool need_check = true; mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, - args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, - channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, - sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, - args.ncols_max); + blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + ntx_fd); } return; } - const dim3 block_nums_stream_k(nsm, 1, 1); - const bool fixup_needed = ntx*nty*ntzw % nsm != 0; + // For the stream-k kernel it is possible to run it with tiling by setting the number of CUDA blocks equal to the number of tiles. + // This is worthwhile if the efficiency of tiling is high and skipping the fixup kernel is more important. + const int ntiles_dst = ntx * nty * ntzw; + const int tiles_nwaves = (ntiles_dst + nsm - 1) / nsm; + const int tiles_efficiency_percent = 100 * ntiles_dst / (nsm*tiles_nwaves); + const dim3 block_nums_stream_k(GGML_CUDA_CC_IS_NVIDIA(cc) && tiles_efficiency_percent >= 90 ? ntiles_dst : nsm, 1, 1); + + GGML_ASSERT(ntiles_dst * blocks_per_ne00_fd.z < (1 << 30)); // Assert that variable kbc will not overflow. + + const bool fixup_needed = ntiles_dst % block_nums_stream_k.x != 0; ggml_cuda_pool & pool = ctx.pool(id); ggml_cuda_pool_alloc tmp_fixup(pool); @@ -3938,40 +4010,45 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a tmp_fixup.alloc(block_nums_stream_k.x * mmq_x*mmq_y); } + const dim3 block_nums_fixup(block_nums_stream_k.x, mmq_y/warp_size, 1); + const dim3 block_dims_fixup(block_dims.x, block_dims.y/2, block_dims.z); + if (args.nrows_x % mmq_y == 0) { constexpr bool need_check = false; mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, - args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, - channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, - sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, - args.ncols_max); + blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + ntx_fd); if (!fixup_needed) { return; } - mul_mat_q_stream_k_fixup<<>> - (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, - args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst, - args.ncols_max); + CUDA_CHECK(cudaGetLastError()); + mul_mat_q_stream_k_fixup<<>> + (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, + args.nrows_dst, nchannels_y_fd, args.stride_channel_dst, nsamples_y_fd, args.stride_sample_dst, + ntx_fd); } else { constexpr bool need_check = true; mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, - args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, - channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, - sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, - args.ncols_max); + blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + ntx_fd); if (!fixup_needed) { return; } - mul_mat_q_stream_k_fixup<<>> - (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, - args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst, - args.ncols_max); + CUDA_CHECK(cudaGetLastError()); + mul_mat_q_stream_k_fixup<<>> + (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, + args.nrows_dst, nchannels_y_fd, args.stride_channel_dst, nsamples_y_fd, args.stride_sample_dst, + ntx_fd); } } @@ -4069,6 +4146,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q5_0); extern DECL_MMQ_CASE(GGML_TYPE_Q5_1); extern DECL_MMQ_CASE(GGML_TYPE_Q8_0); extern DECL_MMQ_CASE(GGML_TYPE_MXFP4); +extern DECL_MMQ_CASE(GGML_TYPE_NVFP4); extern DECL_MMQ_CASE(GGML_TYPE_Q2_K); extern DECL_MMQ_CASE(GGML_TYPE_Q3_K); extern DECL_MMQ_CASE(GGML_TYPE_Q4_K); @@ -4095,3 +4173,4 @@ void ggml_cuda_op_mul_mat_q( const int64_t src1_padded_row_size, cudaStream_t stream); bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts); + diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 632246e43fd..da48f313a38 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -9,12 +9,14 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) { switch (type) { + case GGML_TYPE_Q1_0: return vec_dot_q1_0_q8_1; case GGML_TYPE_Q4_0: return vec_dot_q4_0_q8_1; case GGML_TYPE_Q4_1: return vec_dot_q4_1_q8_1; case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1; case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1; case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1; case GGML_TYPE_MXFP4: return vec_dot_mxfp4_q8_1; + case GGML_TYPE_NVFP4: return vec_dot_nvfp4_q8_1; case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1; case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1; case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1; @@ -33,14 +35,16 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) } } -static constexpr __device__ int get_vdr_mmvq(ggml_type type) { +static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) { switch (type) { + case GGML_TYPE_Q1_0: return VDR_Q1_0_Q8_1_MMVQ; case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ; case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ; case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ; case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ; case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ; case GGML_TYPE_MXFP4: return VDR_MXFP4_Q8_1_MMVQ; + case GGML_TYPE_NVFP4: return VDR_NVFP4_Q8_1_MMVQ; case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ; case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ; case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ; @@ -95,6 +99,200 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) { return MMVQ_PARAMETERS_GENERIC; } +// Per-architecture maximum batch size for which MMVQ should be used for MUL_MAT_ID. +// Returns a value <= MMVQ_MAX_BATCH_SIZE. Default is MMVQ_MAX_BATCH_SIZE. +// Check https://github.com/ggml-org/llama.cpp/pull/20905#issuecomment-4145835627 for details + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_pascal_older(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ1_S: return 6; + case GGML_TYPE_IQ1_M: return 6; + case GGML_TYPE_IQ2_S: return 4; + case GGML_TYPE_IQ2_XS: return 5; + case GGML_TYPE_IQ2_XXS: return 5; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 4; + case GGML_TYPE_IQ4_NL: return 6; + case GGML_TYPE_IQ4_XS: return 5; + case GGML_TYPE_MXFP4: return 4; + case GGML_TYPE_NVFP4: return 4; + case GGML_TYPE_Q2_K: return 4; + case GGML_TYPE_Q3_K: return 4; + case GGML_TYPE_Q4_0: return 6; + case GGML_TYPE_Q4_1: return 6; + case GGML_TYPE_Q4_K: return 5; + case GGML_TYPE_Q5_0: return 6; + case GGML_TYPE_Q5_1: return 6; + case GGML_TYPE_Q5_K: return 5; + case GGML_TYPE_Q6_K: return 4; + case GGML_TYPE_Q8_0: return 4; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_turing_plus(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ2_S: return 7; + case GGML_TYPE_IQ3_S: return 6; + case GGML_TYPE_IQ3_XXS: return 7; + case GGML_TYPE_MXFP4: return 7; + case GGML_TYPE_NVFP4: return 8; + case GGML_TYPE_Q2_K: return 7; + case GGML_TYPE_Q3_K: return 5; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_gcn(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ1_S: return 5; + case GGML_TYPE_IQ1_M: return 5; + case GGML_TYPE_IQ2_S: return 4; + case GGML_TYPE_IQ2_XS: return 4; + case GGML_TYPE_IQ2_XXS: return 4; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 4; + case GGML_TYPE_IQ4_NL: return 6; + case GGML_TYPE_IQ4_XS: return 4; + case GGML_TYPE_Q2_K: return 4; + case GGML_TYPE_Q3_K: return 4; + case GGML_TYPE_Q4_0: return 5; + case GGML_TYPE_Q4_1: return 5; + case GGML_TYPE_Q4_K: return 4; + case GGML_TYPE_Q5_K: return 4; + case GGML_TYPE_Q6_K: return 4; + case GGML_TYPE_Q8_0: return 4; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_cdna(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ2_S: return 5; + case GGML_TYPE_IQ2_XS: return 5; + case GGML_TYPE_IQ2_XXS: return 5; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 5; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna1_rdna2(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ2_S: return 4; + case GGML_TYPE_IQ2_XS: return 4; + case GGML_TYPE_IQ2_XXS: return 4; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 4; + case GGML_TYPE_Q2_K: return 7; + case GGML_TYPE_Q3_K: return 4; + case GGML_TYPE_Q4_K: return 5; + case GGML_TYPE_Q5_K: return 6; + case GGML_TYPE_Q6_K: return 5; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna3(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ1_S: return 6; + case GGML_TYPE_IQ1_M: return 6; + case GGML_TYPE_IQ2_S: return 4; + case GGML_TYPE_IQ2_XS: return 4; + case GGML_TYPE_IQ2_XXS: return 4; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 4; + case GGML_TYPE_IQ4_NL: return 6; + case GGML_TYPE_IQ4_XS: return 6; + case GGML_TYPE_Q4_K: return 4; + case GGML_TYPE_Q5_K: return 4; + case GGML_TYPE_Q6_K: return 4; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna4(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ1_S: return 7; + case GGML_TYPE_IQ1_M: return 7; + case GGML_TYPE_IQ2_S: return 4; + case GGML_TYPE_IQ2_XS: return 4; + case GGML_TYPE_IQ2_XXS: return 4; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 4; + case GGML_TYPE_IQ4_NL: return 7; + case GGML_TYPE_IQ4_XS: return 5; + case GGML_TYPE_MXFP4: return 5; + case GGML_TYPE_NVFP4: return 5; + case GGML_TYPE_Q3_K: return 4; + case GGML_TYPE_Q4_0: return 7; + case GGML_TYPE_Q4_1: return 7; + case GGML_TYPE_Q4_K: return 4; + case GGML_TYPE_Q5_0: return 7; + case GGML_TYPE_Q5_1: return 7; + case GGML_TYPE_Q5_K: return 5; + case GGML_TYPE_Q6_K: return 5; + case GGML_TYPE_Q8_0: return 7; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +// Host function: returns the max batch size for the current arch+type at runtime. +int get_mmvq_mmid_max_batch(ggml_type type, int cc) { + // NVIDIA: Volta, Ada Lovelace, and Blackwell always use MMVQ for MUL_MAT_ID. + if (GGML_CUDA_CC_IS_NVIDIA(cc)) { + if (cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE) { + return MMVQ_MAX_BATCH_SIZE; + } + if (cc >= GGML_CUDA_CC_TURING) { + return get_mmvq_mmid_max_batch_turing_plus(type); + } + return get_mmvq_mmid_max_batch_pascal_older(type); + } + + // AMD + if (GGML_CUDA_CC_IS_AMD(cc)) { + if (GGML_CUDA_CC_IS_RDNA4(cc)) { + return get_mmvq_mmid_max_batch_rdna4(type); + } + if (GGML_CUDA_CC_IS_RDNA3(cc)) { + return get_mmvq_mmid_max_batch_rdna3(type); + } + if (GGML_CUDA_CC_IS_RDNA1(cc) || GGML_CUDA_CC_IS_RDNA2(cc)) { + return get_mmvq_mmid_max_batch_rdna1_rdna2(type); + } + if (GGML_CUDA_CC_IS_CDNA(cc)) { + return get_mmvq_mmid_max_batch_cdna(type); + } + if (GGML_CUDA_CC_IS_GCN(cc)) { + return get_mmvq_mmid_max_batch_gcn(type); + } + } + return MMVQ_MAX_BATCH_SIZE; +} + +// Device constexpr: returns the max batch size for the current arch+type at compile time. +template +static constexpr __device__ int get_mmvq_mmid_max_batch_for_device() { +#if defined(RDNA4) + return get_mmvq_mmid_max_batch_rdna4(type); +#elif defined(RDNA3) + return get_mmvq_mmid_max_batch_rdna3(type); +#elif defined(RDNA2) || defined(RDNA1) + return get_mmvq_mmid_max_batch_rdna1_rdna2(type); +#elif defined(CDNA) + return get_mmvq_mmid_max_batch_cdna(type); +#elif defined(GCN) + return get_mmvq_mmid_max_batch_gcn(type); +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || __CUDA_ARCH__ >= GGML_CUDA_CC_ADA_LOVELACE) + return MMVQ_MAX_BATCH_SIZE; +#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING + return get_mmvq_mmid_max_batch_turing_plus(type); +#else + return get_mmvq_mmid_max_batch_pascal_older(type); +#endif +} + static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_dst, mmvq_parameter_table_id table_id) { if (table_id == MMVQ_PARAMETERS_GENERIC) { switch (ncols_dst) { @@ -173,11 +371,11 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d return 1; } -static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) { +static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id, bool small_k = false, int nwarps = 1) { if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) { switch (ncols_dst) { case 1: - return 1; + return small_k ? nwarps : 1; case 2: case 3: case 4: @@ -193,7 +391,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int return 1; } -template +template __launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, @@ -208,7 +406,7 @@ static __global__ void mul_mat_vec_q( constexpr int vdr = get_vdr_mmvq(type); constexpr mmvq_parameter_table_id table_id = get_device_table_id(); constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id); - constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id); + constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); @@ -220,22 +418,13 @@ static __global__ void mul_mat_vec_q( const uint32_t channel_dst = blockIdx.y; - uint32_t token_idx = 0; uint32_t channel_x; uint32_t channel_y; uint32_t sample_dst; - if constexpr (is_multi_token_id) { - // Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case - token_idx = blockIdx.z; - channel_x = ids[channel_dst + token_idx * ids_stride]; - channel_y = fastmodulo(channel_dst, nchannels_y); - sample_dst = 0; - } else { - channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio); - channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; - sample_dst = blockIdx.z; - } + channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio); + channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; + sample_dst = blockIdx.z; const uint32_t sample_x = fastdiv(sample_dst, sample_ratio); const uint32_t sample_y = sample_dst; @@ -292,9 +481,6 @@ static __global__ void mul_mat_vec_q( float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}}; const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y; - if constexpr (is_multi_token_id) { - y += token_idx*stride_col_y; - } const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x; for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { @@ -348,10 +534,6 @@ static __global__ void mul_mat_vec_q( dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0; - if constexpr (is_multi_token_id) { - dst += token_idx*stride_col_dst; - } - // sum up partial sums and write back result #pragma unroll for (int j = 0; j < ncols_dst; ++j) { @@ -411,17 +593,82 @@ static __global__ void mul_mat_vec_q( } } +// Dedicated MoE multi-token kernel. +// Grid: (ceil(nrows_x / c_rows_per_block), nchannels_dst) +// Block: (warp_size, ncols_dst) - each warp handles one token independently. +// No shared memory reduction needed since each warp works alone. +template +__launch_bounds__(get_mmvq_mmid_max_batch_for_device()*ggml_cuda_get_physical_warp_size(), 1) +static __global__ void mul_mat_vec_q_moe( + const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, + float * __restrict__ dst, + const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t nrows_x, + const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst, + const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, + const uint32_t ncols_dst, const uint32_t ids_stride) { + + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int qi = ggml_cuda_type_traits::qi; + constexpr int vdr = get_vdr_mmvq(type); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); + + const uint32_t token_idx = threadIdx.y; + const int row0 = c_rows_per_block*blockIdx.x; + const int blocks_per_row_x = ncols_x / qk; + constexpr int blocks_per_iter = vdr * warp_size / qi; + + const uint32_t channel_dst = blockIdx.y; + + if (token_idx >= ncols_dst) { + return; + } + + const uint32_t channel_x = ids[channel_dst + token_idx * ids_stride]; + const uint32_t channel_y = fastmodulo(channel_dst, nchannels_y); + + const block_q8_1 * y = ((const block_q8_1 *) vy) + channel_y*stride_channel_y + token_idx*stride_col_y; + const int kbx_offset = channel_x*stride_channel_x + row0*stride_row_x; + + // partial sum for each thread + float tmp[c_rows_per_block] = {0.0f}; + + for (int kbx = threadIdx.x / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { + const int kby = kbx * (qk/QK8_1); + const int kqs = vdr * (threadIdx.x % (qi/vdr)); + +#pragma unroll + for (int i = 0; i < c_rows_per_block; ++i) { + tmp[i] += vec_dot_q_cuda(vx, &y[kby], kbx_offset + i*stride_row_x + kbx, kqs); + } + } + + // Warp-level reduction only - no shared memory needed +#pragma unroll + for (int i = 0; i < c_rows_per_block; ++i) { + tmp[i] = warp_reduce_sum(tmp[i]); + } + + // Write results + if (threadIdx.x < c_rows_per_block && (c_rows_per_block == 1 || uint32_t(row0 + threadIdx.x) < nrows_x)) { + dst[channel_dst*stride_channel_dst + token_idx*stride_col_dst + row0 + threadIdx.x] = tmp[threadIdx.x]; + } +} + template static std::pair calc_launch_params( const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens, - const int warp_size, const mmvq_parameter_table_id table_id) { - const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id); + const int warp_size, const mmvq_parameter_table_id table_id, const bool small_k = false) { + const int nwarps = calc_nwarps(type, ncols_dst, table_id); + const int rpb = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps); + const int64_t nblocks = (nrows_x + rpb - 1) / rpb; const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens); - const dim3 block_dims(warp_size, calc_nwarps(type, ncols_dst, table_id), 1); + const dim3 block_dims(warp_size, nwarps, 1); return {block_nums, block_dims}; } -template +template static void mul_mat_vec_q_switch_fusion( const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, @@ -434,7 +681,7 @@ static void mul_mat_vec_q_switch_fusion( const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; if constexpr (c_ncols_dst == 1) { if (has_fusion) { - mul_mat_vec_q<<>> + mul_mat_vec_q<<>> (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); @@ -444,12 +691,33 @@ static void mul_mat_vec_q_switch_fusion( GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); - mul_mat_vec_q<<>> + mul_mat_vec_q<<>> (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); } +template +static void mul_mat_vec_q_moe_launch( + const void * vx, const void * vy, const int32_t * ids, float * dst, + const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t nrows_x, + const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst, + const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, + const uint32_t ncols_dst, const uint32_t ids_stride, + const int warp_size, const int nchannels_dst, cudaStream_t stream) { + + constexpr int rows_per_block = 2; // 2 gives best perf based on tuning + const int64_t nblocks_rows = (nrows_x + rows_per_block - 1) / rows_per_block; + const dim3 block_nums(nblocks_rows, nchannels_dst); + const dim3 block_dims(warp_size, ncols_dst); + + mul_mat_vec_q_moe<<>>( + vx, vy, ids, dst, ncols_x, nchannels_y, nrows_x, + stride_row_x, stride_col_y, stride_col_dst, + stride_channel_x, stride_channel_y, stride_channel_dst, + ncols_dst, ids_stride); +} + template static void mul_mat_vec_q_switch_ncols_dst( const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, @@ -468,31 +736,88 @@ static void mul_mat_vec_q_switch_ncols_dst( const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x); const int device = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[device].cc; const int warp_size = ggml_cuda_info().devices[device].warp_size; - const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc); + const mmvq_parameter_table_id table_id = get_device_table_id(cc); const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; const bool has_ids = ids != nullptr; + const auto should_use_small_k = [&](int c_ncols_dst) { + // When K is small, increase rows_per_block to match nwarps so each warp has more work to do + // Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle. + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int qi = ggml_cuda_type_traits::qi; + constexpr int vdr = get_vdr_mmvq(type); + const int blocks_per_row_x = ncols_x / qk; + const int blocks_per_iter_1warp = vdr * warp_size / qi; + const int nwarps = calc_nwarps(type, c_ncols_dst, table_id); + bool use = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp; + + constexpr std::array iq_slow_turing = { + GGML_TYPE_IQ3_XXS, + GGML_TYPE_IQ3_S, + }; + constexpr std::array iq_slow_other = { + GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M, GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, + GGML_TYPE_IQ2_S, GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS, + }; + constexpr std::array slow_pascal = { + GGML_TYPE_IQ3_S, + GGML_TYPE_Q2_K, + GGML_TYPE_Q3_K, + }; + + const bool is_nvidia_turing_plus = GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_TURING; + const bool is_nvidia_pascal_older = GGML_CUDA_CC_IS_NVIDIA(cc) && cc < GGML_CUDA_CC_VOLTA; + + if (is_nvidia_turing_plus) { + if (ncols_dst == 1 && + std::find(iq_slow_turing.begin(), iq_slow_turing.end(), type) != iq_slow_turing.end()) { + use = false; + } + } else if ((ncols_dst == 1 && std::find(iq_slow_other.begin(), iq_slow_other.end(), type) != iq_slow_other.end()) || + (is_nvidia_pascal_older && std::find(slow_pascal.begin(), slow_pascal.end(), type) != slow_pascal.end()) || + GGML_CUDA_CC_IS_RDNA(cc)) { + use = false; + } + + return use; + }; + if (has_ids && ncols_dst > 1) { - // Multi-token MUL_MAT_ID path only - single-token goes through regular path below - constexpr int c_ncols_dst = 1; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id); - mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, ids_stride, stream); + // Multi-token MUL_MAT_ID path - dedicated MoE kernel + mul_mat_vec_q_moe_launch( + vx, vy, ids, dst, ncols_x, nchannels_y_fd, nrows_x, + stride_row_x, stride_col_y, stride_col_dst, + stride_channel_x, stride_channel_y, stride_channel_dst, + ncols_dst, ids_stride, warp_size, nchannels_dst, stream); return; } switch (ncols_dst) { case 1: { constexpr int c_ncols_dst = 1; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, ids_stride, stream); + + bool use_small_k = should_use_small_k(c_ncols_dst); + + if (use_small_k) { + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, + nsamples_dst, warp_size, table_id, true); + mul_mat_vec_q_switch_fusion( + vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, + stride_sample_x, stride_sample_y, stride_sample_dst, dims.first, dims.second, 0, ids_stride, + stream); + } else { + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, + nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion( + vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, + stride_sample_x, stride_sample_y, stride_sample_dst, dims.first, dims.second, 0, ids_stride, + stream); + } } break; case 2: { constexpr int c_ncols_dst = 2; @@ -566,6 +891,12 @@ static void mul_mat_vec_q_switch_type( const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, const int ids_stride, cudaStream_t stream) { switch (type_x) { + case GGML_TYPE_Q1_0: + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; case GGML_TYPE_Q4_0: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, @@ -602,6 +933,12 @@ static void mul_mat_vec_q_switch_type( nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; + case GGML_TYPE_NVFP4: + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; case GGML_TYPE_Q2_K: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh index 8a154631f69..6bf0a8e8677 100644 --- a/ggml/src/ggml-cuda/mmvq.cuh +++ b/ggml/src/ggml-cuda/mmvq.cuh @@ -1,7 +1,10 @@ #include "common.cuh" #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels. -#define MMVQ_MMID_MAX_BATCH_SIZE 4 // Max. batch size for which to use MMVQ kernels for MUL_MAT_ID + +// Returns the maximum batch size for which MMVQ should be used for MUL_MAT_ID, +// based on the quantization type and GPU architecture (compute capability). +int get_mmvq_mmid_max_batch(ggml_type type, int cc); void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr); diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index 4300ffc148c..52f664719ae 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -70,6 +70,102 @@ __device__ __forceinline__ uint8_t compute_e8m0_scale(float amax) { return static_cast(biased); } + +static __global__ void quantize_mmq_nvfp4( + const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy, + const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2) { +#if defined(BLACKWELL_MMA_AVAILABLE) + + const int64_t i0_base = ((int64_t) blockDim.x * blockIdx.y + threadIdx.x) * QK_NVFP4_SUB; + if (i0_base >= ne0) { + return; + } + + const int64_t i1 = blockIdx.x; + const int64_t i2 = blockIdx.z % ne2; + const int64_t i3 = blockIdx.z / ne2; + const int64_t i01 = ids ? ids[i1] : i1; + const int64_t k_block = i0_base / QK_K; + const int64_t blocks_per_col = (ne0 + QK_K - 1) / QK_K; + if (k_block >= blocks_per_col) { + return; + } + + const int64_t ib = blockIdx.z * ((int64_t) blocks_per_col * ne1) + k_block * ne1 + blockIdx.x; + block_fp4_mmq * y = (block_fp4_mmq *) vy; + block_fp4_mmq * yb = y + ib; + + const int sub = (i0_base % QK_K) / QK_NVFP4_SUB; + + float vals_raw[QK_NVFP4_SUB]; + float amax_raw = 0.0f; + const int64_t base_idx = i3 * s03 + i2 * s02 + i01 * s01; +#pragma unroll + for (int k = 0; k < QK_NVFP4_SUB; k++) { + const int64_t i00 = i0_base + k; + if (i00 < ne00) { + const float v = x[base_idx + i00]; + vals_raw[k] = v; + amax_raw = fmaxf(amax_raw, fabsf(v)); + } else { + vals_raw[k] = 0.0f; + } + } + + static constexpr int test_offsets[5] = { 0, -1, 1, -2, 2}; + const int first_fp8_code = (int) ggml_cuda_fp32_to_ue4m3(amax_raw / 6.0f); + + float best_err = FLT_MAX; + uint8_t fp8_code = 0; + float subblock_scale = 0.0f; + +#pragma unroll // Check +/- 2 to find best code to reduce NVFP4 activation loss. Negligible overhead on Blackwell. + for (int i = 0; i < 5; i++) { + const int test_code = first_fp8_code + test_offsets[i]; + if (test_code < 0 || test_code > 0x7e) { + continue; + } + const uint8_t code = (uint8_t) test_code; + const float test_scale = ggml_cuda_ue4m3_to_fp32(code); + const float test_inv_scale = test_scale > 0.0f ? 0.5f / test_scale : 0.0f; + float cur_err = 0.0f; +#pragma unroll + for (int k = 0; k < QK_NVFP4_SUB; ++k) { + const float v = vals_raw[k]; + const uint8_t q = ggml_cuda_float_to_fp4_e2m1(v, test_inv_scale); + const float err_diff = fabsf(v) - fabsf(kvalues_mxfp4[q & 0x7]) * test_scale; + cur_err = fmaf(err_diff, err_diff, cur_err); + } + + if (cur_err < best_err) { + best_err = cur_err; + fp8_code = test_code; + subblock_scale = test_scale; + } + } + + const float inv_scale = subblock_scale > 0.0f ? 0.5f / subblock_scale : 0.0f; + uint32_t q0 = 0; + uint32_t q1 = 0; +#pragma unroll // this is faster than the previous __nv_fp4x4_e2m1 + for (int k = 0; k < QK_NVFP4_SUB / 4; ++k) { + q0 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 0], inv_scale) << (8 * k); + q0 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 8], inv_scale) << (8 * k + 4); + q1 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 4], inv_scale) << (8 * k); + q1 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 12], inv_scale) << (8 * k + 4); + } + + uint32_t * yqs = reinterpret_cast(yb->qs); + yqs[2 * sub + 0] = q0; + yqs[2 * sub + 1] = q1; + reinterpret_cast(yb->d4)[sub] = fp8_code; +#else + NO_DEVICE_CODE; // This is for Blackwell NVFP4 activations only. +#endif // defined(BLACKWELL_MMA_AVAILABLE) + +} + // quantize values in the format mxfp4 is stored which is interleaved nibbles // i.e. a block a0-a31 is represented as a0a16,a1a17 ...a15a31 static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x, @@ -316,28 +412,32 @@ void quantize_mmq_q8_1_cuda( } } -void quantize_mmq_mxfp4_cuda(const float * x, - const int32_t * ids, - void * vy, - [[maybe_unused]] const ggml_type type_src0, - const int64_t ne00, - const int64_t s01, - const int64_t s02, - const int64_t s03, - const int64_t ne0, - const int64_t ne1, - const int64_t ne2, - const int64_t ne3, - cudaStream_t stream) { - GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0); - - constexpr int nwarps = 8; - constexpr int vals_per_warp = 2 * QK_MXFP4; - constexpr int vals_per_block = nwarps * vals_per_warp; - - const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block; - const dim3 num_blocks(ne1, block_num_y, ne2 * ne3); - const dim3 block_size(WARP_SIZE, nwarps, 1); - - quantize_mmq_mxfp4<<>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); +void quantize_mmq_fp4_cuda( + const float * x, const int32_t * ids, void * vy, const ggml_type type_src0, + const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) { + GGML_ASSERT(type_src0 == GGML_TYPE_MXFP4 || type_src0 == GGML_TYPE_NVFP4); + GGML_ASSERT(ne0 > 0); + + if (type_src0 == GGML_TYPE_NVFP4) { + GGML_ASSERT(ne00 % QK_NVFP4 == 0); + constexpr int nvfp4_block_size = 128; + const int64_t block_num_y = (ne0 + QK_NVFP4_SUB * nvfp4_block_size - 1) / (QK_NVFP4_SUB * nvfp4_block_size); + const dim3 block_size(nvfp4_block_size, 1, 1); + const dim3 num_blocks(ne1, block_num_y, ne2 * ne3); + quantize_mmq_nvfp4<<>>( + x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); + } else { + GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0); + + constexpr int nwarps = 8; + constexpr int vals_per_warp = 2 * QK_MXFP4; + constexpr int vals_per_block = nwarps * vals_per_warp; + + const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block; + const dim3 num_blocks(ne1, block_num_y, ne2 * ne3); + const dim3 block_size(WARP_SIZE, nwarps, 1); + + quantize_mmq_mxfp4<<>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); + } } diff --git a/ggml/src/ggml-cuda/quantize.cuh b/ggml/src/ggml-cuda/quantize.cuh index 6a91df63578..768a3ae6de6 100644 --- a/ggml/src/ggml-cuda/quantize.cuh +++ b/ggml/src/ggml-cuda/quantize.cuh @@ -26,7 +26,7 @@ void quantize_mmq_q8_1_cuda( ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream); -void quantize_mmq_mxfp4_cuda(const float * x, +void quantize_mmq_fp4_cuda(const float * x, const int32_t * ids, void * vy, ggml_type type_src0, diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index 69985cd335c..4841389fbc8 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -3,6 +3,7 @@ template static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1, + const float * __restrict__ bias, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, const int64_t n_t) { @@ -27,6 +28,8 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float w[j] = w_block[tid * stride_w + j]; } + float b = bias != nullptr ? bias[bidy * split_d_inner + tid] : 0.0f; + for (int64_t i = 0; i < n_t; i++) { float sumf = 0.0f; @@ -42,12 +45,14 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float for (size_t j = 0; j < d_conv; j++) { sumf += x[(i + j) % d_conv] * w[j]; } + sumf += b; y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf; } } template static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1, + const float * __restrict__ bias, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, const int64_t n_t) { @@ -97,6 +102,8 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, w[j] = w_block[tid * stride_w + j]; } + float b = bias != nullptr ? bias[bidy * split_d_inner + tid] : 0.0f; + // Compute from shared memory for (int64_t i = 0; i < local_n_t; i++) { float sumf = 0.0f; @@ -104,12 +111,13 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, for (size_t j = 0; j < d_conv; j++) { sumf += smem[tid * n_cols + i + j] * w[j]; } + sumf += b; y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf; } } template -static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1, +static void ssm_conv_f32_cuda(const float * src0, const float * src1, const float * bias, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t, const int64_t n_s, cudaStream_t stream) { @@ -120,30 +128,38 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int constexpr int kNC = decltype(NC)::value; if (n_t <= 32) { const dim3 blocks(n_s, (nr + threads - 1) / threads, 1); - ssm_conv_f32<<>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, + ssm_conv_f32<<>>(src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); } else { const int64_t split_n_t = 32; dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t); const size_t smem_size = threads * (kNC - 1 + split_n_t) * sizeof(float); ssm_conv_long_token_f32<<>>( - src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); + src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); } }; switch (nc) { case 3: launch_kernel(std::integral_constant{}); break; case 4: launch_kernel(std::integral_constant{}); break; + case 5: launch_kernel(std::integral_constant{}); break; case 9: launch_kernel(std::integral_constant{}); break; - default: GGML_ABORT("Only support kernel sizes 3, 4, 9 right now."); + default: GGML_ABORT("Only support kernel sizes 3, 4, 5, 9 right now."); } } -void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst) { +void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * bias_add_node, ggml_tensor * silu_dst) { const struct ggml_tensor * src0 = dst->src[0]; // conv_x const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight + const bool fuse_bias = bias_add_node != nullptr; const bool fuse_silu = silu_dst != nullptr; + // bias always comes with silu. + GGML_ASSERT(!fuse_bias || fuse_silu); + + // The bias (when fused) is the non-conv operand of the ADD node. + const struct ggml_tensor * bias = fuse_bias ? (bias_add_node->src[0] == dst ? bias_add_node->src[1] : bias_add_node->src[0]) : nullptr; + // When fusing, write to silu_dst (the node downstream references). const struct ggml_tensor * out = fuse_silu ? silu_dst : dst; @@ -159,16 +175,23 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, g const float * src0_d = (const float *) src0->data; const float * src1_d = (const float *) src1->data; + const float * bias_d = fuse_bias ? (const float *) bias->data : nullptr; float * dst_d = (float *) out->data; cudaStream_t stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(out->type == GGML_TYPE_F32); + if (fuse_bias) { + GGML_ASSERT(bias->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(bias)); + GGML_ASSERT(ggml_nelements(bias) == nr); + } + if (fuse_silu) { - ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], + ssm_conv_f32_cuda(src0_d, src1_d, bias_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], out->nb[2], nc, nr, n_t, n_s, stream); } else { - ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], + ssm_conv_f32_cuda(src0_d, src1_d, bias_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], out->nb[2], nc, nr, n_t, n_s, stream); } } diff --git a/ggml/src/ggml-cuda/ssm-conv.cuh b/ggml/src/ggml-cuda/ssm-conv.cuh index f96a1cd2484..8514ca84920 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cuh +++ b/ggml/src/ggml-cuda/ssm-conv.cuh @@ -1,3 +1,3 @@ #include "common.cuh" -void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst = nullptr); +void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * bias_add_node = nullptr, ggml_tensor * silu_dst = nullptr); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu index 1f554d81e5e..8fc3b17976e 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu @@ -2,4 +2,5 @@ #include "../fattn-mma-f16.cuh" +DECL_FATTN_MMA_F16_CASE(320, 256, 1, 32); DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu index dc16829021f..22d383173f3 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8); +DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu index 517993cb068..d2415bfa957 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu @@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4); +DECL_FATTN_MMA_F16_CASE(512, 512, 16, 4); DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu index 264751d65ec..abd2b21ce04 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu @@ -2,4 +2,5 @@ #include "../fattn-mma-f16.cuh" +DECL_FATTN_MMA_F16_CASE(320, 256, 2, 32); DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu index 97b19c67ade..8eec1d74e29 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu @@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4); +DECL_FATTN_MMA_F16_CASE(512, 512, 2, 4); DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu index 163b1d939e4..84b674cd05a 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8); +DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu index 989626dfa5e..3475dfea08a 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu @@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4); +DECL_FATTN_MMA_F16_CASE(512, 512, 4, 4); DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu index bad296b4141..5906398db91 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8); +DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu index 173de7aac7d..684cd25ce0d 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu @@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4); +DECL_FATTN_MMA_F16_CASE(512, 512, 8, 4); DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu index 680a13ca6de..4bc60d62f91 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8); +DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu new file mode 100644 index 00000000000..c91f508079d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(320, 256); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu new file mode 100644 index 00000000000..7c61d8d2ecd --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(512, 512); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu new file mode 100644 index 00000000000..3a2fa99b05b --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu new file mode 100644 index 00000000000..60f0f6f7952 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_F16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu new file mode 100644 index 00000000000..489e05f08c3 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu new file mode 100644 index 00000000000..6fa3c26d309 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu new file mode 100644 index 00000000000..421027fb29d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu new file mode 100644 index 00000000000..abbc9434802 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu new file mode 100644 index 00000000000..d641f859d81 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu new file mode 100644 index 00000000000..d1071dc2438 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu new file mode 100644 index 00000000000..8afda314238 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu new file mode 100644 index 00000000000..506864ac18d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu new file mode 100644 index 00000000000..0bbda8371e6 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu new file mode 100644 index 00000000000..79be24daf9e --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu new file mode 100644 index 00000000000..45636e5e70c --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index e382df1ae20..5e9a1cb2eb3 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -3,9 +3,9 @@ from glob import glob import os -HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576] +HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 320, 512, 576] -TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"] +TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"] SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. @@ -32,10 +32,11 @@ SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n" TYPES_MMQ = [ + "GGML_TYPE_Q1_0", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K", "GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S", - "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4" + "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4", "GGML_TYPE_NVFP4" ] SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. @@ -61,7 +62,7 @@ def get_short_name(long_quant_name): os.remove(filename) for head_size_kq in HEAD_SIZES_KQ: - head_size_v = head_size_kq if head_size_kq != 576 else 512 + head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512) with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f: f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v)) @@ -83,11 +84,16 @@ def get_short_name(long_quant_name): continue if head_size_kq == 72: continue - if head_size_kq != 576 and ncols2 in (16, 32): + # Skip compilation of unused ncols2 values for niche head sizes: + if head_size_kq == 320 and ncols2 != 32: # Mistral Small 4 continue - if head_size_kq == 576 and ncols2 not in (4, 16, 32): + if head_size_kq == 512 and ncols2 not in (4, 8): # Gemma 4 continue - head_size_v = head_size_kq if head_size_kq != 576 else 512 + if head_size_kq == 576 and ncols2 not in (4, 16, 32): # Deepseek, GLM 4.7 Flash + continue + if head_size_kq not in (320, 576) and ncols2 in (16, 32): + continue + head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512) f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v)) for type in TYPES_MMQ: diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu new file mode 100644 index 00000000000..2cb140d35a3 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_NVFP4); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu new file mode 100644 index 00000000000..f0686b0d0d8 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q1_0); diff --git a/ggml/src/ggml-cuda/top-k.cu b/ggml/src/ggml-cuda/top-k.cu index 785a18389f2..59ce36fb1c9 100644 --- a/ggml/src/ggml-cuda/top-k.cu +++ b/ggml/src/ggml-cuda/top-k.cu @@ -25,14 +25,14 @@ static void top_k_cub(ggml_cuda_pool & pool, auto indexes_in = cuda::make_counting_iterator(0); size_t temp_storage_bytes = 0; - DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, ncols, k, - env); + CUDA_CHECK(DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, ncols, k, + env)); ggml_cuda_pool_alloc temp_storage_alloc(pool, temp_storage_bytes); void * d_temp_storage = temp_storage_alloc.get(); - DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, - ncols, k, env); + CUDA_CHECK(DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, + ncols, k, env)); } #elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 4ad30fa1f35..2aeba26f414 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -65,6 +65,11 @@ static __device__ __forceinline__ float op_sqr(float x) { return x * x; } +static __device__ __forceinline__ float op_relu_sqr(float x) { + const float r = fmaxf(x, 0.0f); + return r * r; +} + static __device__ __forceinline__ float op_sqrt(float x) { return sqrtf(x); } @@ -615,3 +620,21 @@ void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary GGML_ABORT("Unsupported unary op for fused unary+mul"); } } + +/* fused relu + sqr */ + +void ggml_cuda_op_relu_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * relu_node, ggml_tensor * sqr_node) { + const ggml_tensor * src = relu_node->src[0]; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src)); + GGML_ASSERT(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16); + GGML_ASSERT(src->type == sqr_node->type); + + const int k = ggml_nelements(src); + if (src->type == GGML_TYPE_F16) { + unary_cuda((const half *)src->data, (half *)sqr_node->data, k, stream); + } else { + unary_cuda((const float *)src->data, (float *)sqr_node->data, k, stream); + } +} diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index f1dd2183a6c..81ed873ecc3 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -91,6 +91,8 @@ void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node); +void ggml_cuda_op_relu_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * relu_node, ggml_tensor * sqr_node); + __device__ __forceinline__ float ggml_cuda_op_silu_single(float x) { return x / (1.0f + expf(-x)); } diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index ab803aca21b..d1741cc8d7b 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -106,6 +106,9 @@ static __device__ __forceinline__ uint32_t unpack_ksigns(const uint8_t v) { // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q +#define VDR_Q1_0_Q8_1_MMVQ 1 // Process one 32-element chunk at a time for parallelism +#define VDR_Q1_0_Q8_1_MMQ 4 // Q1_0 has 128 bits (4 ints) per block + #define VDR_Q4_0_Q8_1_MMVQ 2 #define VDR_Q4_0_Q8_1_MMQ 4 @@ -322,6 +325,38 @@ static __device__ __forceinline__ float vec_dot_mxfp4_q8_1( return d * sumi; } +#define VDR_NVFP4_Q8_1_MMVQ 4 +#define VDR_NVFP4_Q8_1_MMQ 8 + +static __device__ __forceinline__ float vec_dot_nvfp4_q8_1( + const void * __restrict__ vbq, + const block_q8_1 * __restrict__ bq8_1, + const int32_t & kbx, + const int32_t & iqs) { + + const block_nvfp4 * bq4 = (const block_nvfp4 *) vbq + kbx; + float sum = 0.0f; +#pragma unroll + for (int i = 0; i < VDR_NVFP4_Q8_1_MMVQ/2; i++) { + const int32_t iqs0 = iqs + 2*i; + const int32_t iqs1 = iqs0 + 1; + const int32_t is = iqs0 >> 1; + const int2 v0 = get_int_from_table_16(get_int_b4(bq4->qs, iqs0), kvalues_mxfp4); + const int2 v1 = get_int_from_table_16(get_int_b4(bq4->qs, iqs1), kvalues_mxfp4); + const block_q8_1 * bq8 = bq8_1 + (is >> 1); + const int32_t i8 = ((is & 1) << 2); + + int sumi = ggml_cuda_dp4a(v0.x, get_int_b4(bq8->qs, i8 + 0), 0); + sumi = ggml_cuda_dp4a(v0.y, get_int_b4(bq8->qs, i8 + 2), sumi); + sumi = ggml_cuda_dp4a(v1.x, get_int_b4(bq8->qs, i8 + 1), sumi); + sumi = ggml_cuda_dp4a(v1.y, get_int_b4(bq8->qs, i8 + 3), sumi); + + const float d = ggml_cuda_ue4m3_to_fp32(bq4->d[is]) * __low2float(bq8->ds); + sum += d * float(sumi); + } + + return sum; +} #define VDR_Q2_K_Q8_1_MMVQ 1 #define VDR_Q2_K_Q8_1_MMQ 4 @@ -637,6 +672,51 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( return d6 * sumf_d; } +static __device__ __forceinline__ float vec_dot_q1_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_q1_0 * bq1_0 = (const block_q1_0 *) vbq + kbx; + + // Q1_0: 128 elements with ONE scale + // Q8_1: 32 elements per block with individual scales + // iqs selects which of the 4 chunks of 32 elements to process (0-3) + + const float d1 = bq1_0->d; + + // Process only the chunk specified by iqs + const block_q8_1 * bq8_1_chunk = bq8_1 + iqs; + + // Load 32 bits (4 bytes) for this chunk from Q1_0 + const int offset = iqs * 4; + const int v = bq1_0->qs[offset + 0] | (bq1_0->qs[offset + 1] << 8) | + (bq1_0->qs[offset + 2] << 16) | (bq1_0->qs[offset + 3] << 24); + + // Unpack 32 bits into 32 signed values (-1 or +1) + int vi_bytes[8]; +#pragma unroll + for (int j = 0; j < 8; ++j) { + const int shift = j * 4; + const int bits4 = (v >> shift) & 0x0F; + const int b0 = (bits4 & 0x01) ? 1 : -1; + const int b1 = (bits4 & 0x02) ? 1 : -1; + const int b2 = (bits4 & 0x04) ? 1 : -1; + const int b3 = (bits4 & 0x08) ? 1 : -1; + vi_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24); + } + + // Compute dot product for this 32-element chunk + int sumi = 0; +#pragma unroll + for (int j = 0; j < 8; ++j) { + const int u = get_int_b4(bq8_1_chunk->qs, j); + sumi = ggml_cuda_dp4a(vi_bytes[j], u, sumi); + } + + // Apply Q1_0's single scale and this chunk's Q8_1 scale + const float d8 = __low2float(bq8_1_chunk->ds); + return d1 * d8 * sumi; +} + static __device__ __forceinline__ float vec_dot_q4_0_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { diff --git a/ggml/src/ggml-cuda/vendors/cuda.h b/ggml/src/ggml-cuda/vendors/cuda.h index ba032cfab4b..323c9801934 100644 --- a/ggml/src/ggml-cuda/vendors/cuda.h +++ b/ggml/src/ggml-cuda/vendors/cuda.h @@ -6,9 +6,14 @@ #include #include -#if CUDART_VERSION >= 12050 +#ifdef GGML_USE_NCCL +#include +#endif // GGML_USE_NCCL + +#if CUDART_VERSION >= 11080 #include -#endif // CUDART_VERSION >= 12050 +#define FP8_AVAILABLE +#endif // CUDART_VERSION >= 11080 #if CUDART_VERSION >= 12080 #include diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 35d1e1a0639..78ca364d38f 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -10,6 +10,11 @@ #include #endif // defined(GGML_HIP_ROCWMMA_FATTN) +#ifdef GGML_USE_NCCL +#include +#endif // GGML_USE_NCCL + + #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT #define CUBLAS_OP_N HIPBLAS_OP_N @@ -53,6 +58,7 @@ #define cudaDeviceProp hipDeviceProp_t #define cudaDeviceSynchronize hipDeviceSynchronize #define cudaError_t hipError_t +#define cudaErrorMemoryAllocation hipErrorOutOfMemory #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled #define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled #define cudaEventCreateWithFlags hipEventCreateWithFlags @@ -183,6 +189,10 @@ #define GCN #endif // defined(GCN5) || defined(GCN4) +#if defined(__gfx950__) +#define CDNA4 +#endif // defined(__gfx950__) + #if defined(__gfx942__) #define CDNA3 #endif // defined(__gfx942__) @@ -195,9 +205,9 @@ #define CDNA1 #endif // defined(__gfx908__) -#if defined(CDNA3) || defined(CDNA2) || defined(CDNA1) +#if defined(CDNA4) || defined(CDNA3) || defined(CDNA2) || defined(CDNA1) #define CDNA // For the entire family -#endif // defined(CDNA3) || defined(CDNA2) || defined(CDNA1) +#endif // defined(CDNA4) || defined(CDNA3) || defined(CDNA2) || defined(CDNA1) #if defined(__GFX12__) #define RDNA4 @@ -235,6 +245,12 @@ typedef __hip_bfloat16 nv_bfloat16; typedef __hip_bfloat162 nv_bfloat162; +#if HIP_VERSION >= 60200000 +#include +typedef __hip_fp8_e4m3 __nv_fp8_e4m3; +#define FP8_AVAILABLE +#endif // HIP_VERSION >= 60200000 + typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4))); static __device__ __forceinline__ int __vsubss4(const int a, const int b) { diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h index 1abb8acfd4b..8aa056e9174 100644 --- a/ggml/src/ggml-cuda/vendors/musa.h +++ b/ggml/src/ggml-cuda/vendors/musa.h @@ -42,6 +42,7 @@ #define cudaDeviceProp musaDeviceProp #define cudaDeviceSynchronize musaDeviceSynchronize #define cudaError_t musaError_t +#define cudaErrorMemoryAllocation musaErrorMemoryAllocation #define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled #define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled #define cudaEventCreateWithFlags musaEventCreateWithFlags diff --git a/ggml/src/ggml-hexagon/CMakeLists.txt b/ggml/src/ggml-hexagon/CMakeLists.txt index f3a583543c6..b82bae0c103 100644 --- a/ggml/src/ggml-hexagon/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/CMakeLists.txt @@ -22,7 +22,8 @@ message(STATUS "hexagon: using ${HEXAGON_SDK_ROOT} and ${HEXAGON_TOOLS_ROOT} for include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake) include(ExternalProject) -option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF) +option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF) +option(GGML_HEXAGON_FA_EXP2_HF "ggml-hexagon: use FP16 exp2 polynomial in FA softmax instead of F32 exp round-trip" OFF) set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate") set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)") diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 19917cb1140..df4ed101464 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -7,10 +7,17 @@ #include #include -#include #include +#include +#include #include #include +#include +#include +#include +#include +#include +#include #ifdef _WIN32 # include @@ -33,22 +40,36 @@ #include "ggml-impl.h" #include "ggml-quants.h" #include "op-desc.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp_iface.h" #include "htp-drv.h" -static size_t opt_ndev = 1; -static size_t opt_nhvx = 0; // use all -static int opt_arch = 0; // autodetect -static int opt_etm = 0; -static int opt_verbose = 0; -static int opt_profile = 0; -static int opt_hostbuf = 1; // hostbuf ON by default -static int opt_experimental = 0; +using intvec = std::vector; +using uintvec = std::vector; +using u32vec = std::vector; + +static int opt_arch = 0; // autodetect +static size_t opt_ndev = 1; +static size_t opt_nhvx = 0; // use all +static int opt_use_hmx = 1; // when set, enable HMX; when 0, use HVX only +static size_t opt_vmem = HTP_OP_MAX_VMEM_DEFAULT; // max available va space for buffer mappings +static size_t opt_mbuf = 1ul * 1024 * 1024 * 1024; // max buffer size +static int opt_etm = 0; +static int opt_verbose = 0; +static int opt_profile = 0; // profiling mode (0-disabled, 1-basic, 2-pmu) +static int opt_hostbuf = 1; // hostbuf ON by default + +// Default PMU events, if profiling with PMU (mode=2) is enabled +// See https://docs.qualcomm.com/doc/80-N2040-60/topic/pmu-events.html +// https://docs.qualcomm.com/doc/80-N2040-61/topic/hvx-pmu-events.html +static u32vec opt_pmu_evt { 0x3, 0x111, 0x100, 0x105, 0x240, 0x256, 0x7D, 0x8C }; // Enable all stages by default -static int opt_opmask = HTP_OPMASK_QUEUE | HTP_OPMASK_QUANTIZE | HTP_OPMASK_COMPUTE; -static int opt_opsync = 0; // synchronous ops +static int opt_opstage = HTP_OPSTAGE_QUEUE | HTP_OPSTAGE_COMPUTE; +static int opt_opbatch = 1024; // max number of ops in a batch +static int opt_opqueue = 16; // max number of pending batches + +static std::regex* opt_opfilter = NULL; // regex of ops to not claim #define HEX_VERBOSE(...) \ if (opt_verbose) GGML_LOG_DEBUG(__VA_ARGS__) @@ -85,42 +106,39 @@ static void ggml_hexagon_dump_op_exec(const std::string &sess_name, const ggml_t op_desc desc(op); GGML_LOG_DEBUG("ggml-hex: %s execute-op %s: %s : %s : %s : %s : %s : flags 0x%x\n", sess_name.c_str(), - ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, req_flags); + ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, req_flags); } static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct ggml_tensor * op, bool supp) { if (!opt_verbose) return; op_desc desc(op); - GGML_LOG_DEBUG("ggml-hex: %s supports-op %s : %s : %s : %s : %s : %s : %s\n", sess_name.c_str(), - ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, supp ? "yes" : "no"); + GGML_LOG_DEBUG("ggml-hex: %s supports-op %s: %s : %s : %s : %s : %s : %s\n", sess_name.c_str(), + ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, supp ? "yes" : "no"); } static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_tensor * op, - uint32_t op_usec, uint32_t op_cycles, uint32_t op_pkts, uint64_t call_usec) { + uint32_t op_usec, uint32_t op_cycles, const uint32_t pmu[]) { if (!opt_profile) return; + char pmu_str[256] = ""; + if (opt_profile > 1) { + static_assert(HTP_PROF_PMU_NCNT == 8, "current implementation assumes 8 PMU counters"); + sprintf(pmu_str, " pmu [%u,%u,%u,%u,%u,%u,%u,%u]", + pmu[0], pmu[1], pmu[2], pmu[3], pmu[4], pmu[5], pmu[6], pmu[7]); + } + op_desc desc(op); - GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : %s : op-usec %u op-cycles %u op-pkts %u (%f) call-usec %llu\n", sess_name.c_str(), - ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, - op_usec, op_cycles, op_pkts, (float) op_cycles / op_pkts, (unsigned long long) call_usec); + GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : usec %u cycles %u%s\n", sess_name.c_str(), + ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, op_usec, op_cycles, pmu_str); } // ** backend sessions -struct ggml_hexagon_session { - ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false); - ~ggml_hexagon_session() noexcept(true); - - void allocate(int dev_id) noexcept(false); - void release() noexcept(true); - - void enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync = false); - void flush(); - - ggml_backend_buffer_type buffer_type = {}; - ggml_backend_buffer_type repack_buffer_type = {}; +struct ggml_hexagon_opbatch; +struct ggml_hexagon_opqueue; +struct ggml_hexagon_session { std::string name; remote_handle64 handle; dspqueue_t queue; @@ -132,87 +150,28 @@ struct ggml_hexagon_session { bool valid_handle; bool valid_queue; bool valid_iface; - std::atomic op_pending; - uint32_t prof_usecs; - uint32_t prof_cycles; - uint32_t prof_pkts; -}; - -void ggml_hexagon_session::enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync) { - // Bump pending flag (cleared in the session::flush once we get the response) - this->op_pending++; // atomic inc - - int err = dspqueue_write(this->queue, - 0, // flags - the framework will autoset this - n_bufs, // number of buffers - bufs, // buffer references - sizeof(req), // Message length - (const uint8_t *) &req, // Message - DSPQUEUE_TIMEOUT // Timeout - ); - - if (err != 0) { - GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", this->name.c_str(), (unsigned) err); - } - - if (sync) { - flush(); - } -} - -// Flush HTP response queue i.e wait for all outstanding requests to complete -void ggml_hexagon_session::flush() { - dspqueue_t q = this->queue; - - // Repeatedly read packets from the queue until it's empty. We don't - // necessarily get a separate callback for each packet, and new packets - // may arrive while we're processing the previous one. - - while (this->op_pending) { - struct htp_general_rsp rsp; - uint32_t rsp_size; - uint32_t flags; - struct dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS]; - uint32_t n_bufs; - - // Read response packet from queue - int err = dspqueue_read(q, &flags, - HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references - &n_bufs, // Number of buffer references - bufs, // Buffer references - sizeof(rsp), // Max message length - &rsp_size, // Message length - (uint8_t *) &rsp, // Message - DSPQUEUE_TIMEOUT); // Timeout + std::atomic op_pending; + ggml_hexagon_opbatch* op_batch; + ggml_hexagon_opqueue* op_queue; - if (err == AEE_EEXPIRED) { - // TODO: might need to bail out if the HTP is stuck on something - continue; - } + ggml_backend_buffer_type buffer_type = {}; + ggml_backend_buffer_type repack_buffer_type = {}; - if (err != 0) { - GGML_ABORT("ggml-hex: dspqueue_read failed: 0x%08x\n", (unsigned) err); - } + ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false); + ~ggml_hexagon_session() noexcept(true); - // Basic sanity checks - if (rsp_size != sizeof(rsp)) { - GGML_ABORT("ggml-hex: dspcall : bad response (size)\n"); - } + const char* c_name() const { return name.c_str(); } - if (rsp.status != HTP_STATUS_OK) { - GGML_LOG_ERROR("ggml-hex: dspcall : dsp-rsp: %s\n", status_to_str(rsp.status)); - // TODO: handle errors - } + void allocate(int dev_id) noexcept(false); + void release() noexcept(true); - // TODO: update profiling implementation, currently only works for opt_opsync mode - this->prof_usecs = rsp.prof_usecs; - this->prof_cycles = rsp.prof_cycles; - this->prof_pkts = rsp.prof_pkts; + void enqueue_op(htp_op_code opcode, const ggml_tensor *op); + void flush(bool all = true); - this->op_pending--; // atomic dec - } -} + void flush_pending(bool all = false); + void flush_batch(); +}; // ** backend buffers @@ -226,82 +185,94 @@ struct ggml_backend_hexagon_buffer_type_context { std::string name; }; -struct ggml_backend_hexagon_buffer_context { - bool mmap_to(ggml_hexagon_session * s) { - HEX_VERBOSE("ggml-hex: %s mmaping buffer: base %p domain-id %d session-id %d size %zu fd %d repack %d\n", - s->name.c_str(), (void *) this->base, s->domain_id, s->session_id, this->size, this->fd, - (int) this->repack); +struct ggml_hexagon_shared_buffer { + ggml_hexagon_session * sess; + uint8_t * base; + size_t size; + int fd; + bool mapped; + bool pinned; + + void mmap() { + fastrpc_map_flags flags = this->pinned ? FASTRPC_MAP_FD : FASTRPC_MAP_FD_DELAYED; - int err = fastrpc_mmap(s->domain_id, this->fd, (void *) this->base, 0, this->size, FASTRPC_MAP_FD); + int err = fastrpc_mmap(sess->domain_id, this->fd, (void *) this->base, 0, this->size, flags); if (err != 0) { - GGML_LOG_ERROR("ggml-hex: buffer mapping failed : domain_id %d size %zu fd %d error 0x%08x\n", - s->domain_id, this->size, this->fd, (unsigned) err); - return false; + GGML_LOG_ERROR("ggml-hex: %s buffer mapping failed : domain_id %d size %zu fd %d error 0x%08x\n", sess->c_name(), + sess->domain_id, this->size, this->fd, (unsigned) err); + throw std::runtime_error("ggml-hex: fastrpc_mmap failed (see log for details)"); } - return true; - } + HEX_VERBOSE("ggml-hex: %s mapped buffer: base %p size %zu fd %d pinned %u\n", + sess->c_name(), (void *) this->base, this->size, this->fd, pinned); - bool mmap() { - if (this->mapped) { - return true; - } - if (!mmap_to(this->sess)) { - return false; - } this->mapped = true; - return true; } - void munmap() { - if (!this->mapped) { - return; + void unmap() { + if (!this->mapped) return; + + if (!this->pinned) { + // HTP might still hold a reference, tell it drop it + htp_iface_munmap(sess->handle, this->fd); } - fastrpc_munmap(this->sess->domain_id, this->fd, this->base, this->size); + fastrpc_munmap(sess->domain_id, this->fd, (void *) this->base, this->size); + + HEX_VERBOSE("ggml-hex: %s unmapped buffer: base %p size %zu fd %d\n", sess->c_name(), + (void *) this->base, size, this->fd); + this->mapped = false; + this->fd = -1; } - ggml_backend_hexagon_buffer_context(ggml_hexagon_session * sess, size_t size, bool repack) { - size += 4 * 1024; // extra page for padding + void alloc(size_t size) { + if (this->base) return; - this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size); + this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS, size); if (!this->base) { - GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer : size %zu\n", sess->name.c_str(), size); + GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer : size %zu\n", sess->c_name(), size); throw std::runtime_error("ggml-hex: rpcmem_alloc failed (see log for details)"); } this->fd = rpcmem_to_fd(this->base); if (this->fd < 0) { - GGML_LOG_ERROR("ggml-hex: %s failed to get FD for buffer %p\n", sess->name.c_str(), (void *) this->base); - rpcmem_free(this->base); - this->base = NULL; + GGML_LOG_ERROR("ggml-hex: %s failed to get FD for buffer %p\n", sess->c_name(), (void *) this->base); throw std::runtime_error("ggml-hex: rpcmem_to_fd failed (see log for details)"); } + this->size = size; - HEX_VERBOSE("ggml-hex: %s allocated buffer: base %p size %zu fd %d repack %d\n", sess->name.c_str(), - (void *) this->base, size, this->fd, (int) repack); + HEX_VERBOSE("ggml-hex: %s allocated buffer: base %p size %zu fd %d pinned %d\n", sess->c_name(), + (void *) this->base, this->size, this->fd, (int) pinned); + mmap(); + } + + void free() { + if (!this->base) return; + + unmap(); + rpcmem_free(this->base); + + HEX_VERBOSE("ggml-hex: %s freed buffer: base %p size %zu fd %d\n", sess->c_name(), + (void *) this->base, size, this->fd); + this->base = NULL; + } + + ggml_hexagon_shared_buffer(ggml_hexagon_session * sess, size_t size, bool pinned = false) { this->sess = sess; - this->size = size; + this->size = 0; + this->base = nullptr; + this->fd = -1; this->mapped = false; - this->repack = repack; - } + this->pinned = pinned; - ~ggml_backend_hexagon_buffer_context() { - munmap(); - if (this->base) { - rpcmem_free(this->base); - this->base = NULL; - } + alloc(size); } - ggml_hexagon_session * sess; // primary session - uint8_t * base; - size_t size; - int fd; - bool mapped; // mmap is done - bool repack; // repacked buffer + ~ggml_hexagon_shared_buffer() { + free(); + } }; static ggml_hexagon_session * ggml_backend_hexagon_buffer_get_sess(ggml_backend_buffer_t buffer) { @@ -309,30 +280,26 @@ static ggml_hexagon_session * ggml_backend_hexagon_buffer_get_sess(ggml_backend_ } static void ggml_backend_hexagon_buffer_free_buffer(ggml_backend_buffer_t buffer) { - auto ctx = static_cast(buffer->context); - delete ctx; + auto sbuf = static_cast(buffer->context); + delete sbuf; } static void * ggml_backend_hexagon_buffer_get_base(ggml_backend_buffer_t buffer) { - auto ctx = static_cast(buffer->context); - return ctx->base; + auto sbuf = static_cast(buffer->context); + return sbuf->base; } static enum ggml_status ggml_backend_hexagon_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { - auto ctx = static_cast(buffer->context); - auto sess = ctx->sess; + auto sbuf = static_cast(buffer->context); + auto sess = sbuf->sess; - HEX_VERBOSE("ggml-hex: %s init-tensor %s : base %p data %p nbytes %zu usage %d repack %d\n", sess->name.c_str(), - tensor->name, (void *) ctx->base, tensor->data, ggml_nbytes(tensor), (int) buffer->usage, - (int) ctx->repack); + HEX_VERBOSE("ggml-hex: %s init-tensor %s : base %p data %p nbytes %zu usage %d\n", sess->c_name(), + tensor->name, (void *) sbuf->base, tensor->data, ggml_nbytes(tensor), (int) buffer->usage); if (tensor->view_src != NULL && tensor->view_offs == 0) { - ; // nothing to do for the view - } else { - if (!ctx->mapped) { - ctx->mmap(); - } + return GGML_STATUS_SUCCESS; // nothing to do for the view } + return GGML_STATUS_SUCCESS; } @@ -460,7 +427,7 @@ static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) { d[7] = x[i * 8 + 7].d; } - if (opt_verbose > 1) { + if (opt_verbose > 2) { for (int i = 0; i < nb; i++) { dump_packed_block_q4x4x2(y, i, k); } @@ -479,7 +446,7 @@ static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) { const uint8_t * y_q = y + 0; // quants first const uint8_t * y_d = y + qrow_size; // then scales - if (opt_verbose > 1) { + if (opt_verbose > 2) { for (int i = 0; i < nb; i++) { dump_packed_block_q4x4x2(y, i, k); } @@ -795,7 +762,7 @@ static void repack_row_q8x4x2(uint8_t * y, const block_q8_0 * x, int64_t k) { d[7] = x[i * 8 + 7].d; } - if (opt_verbose > 1) { + if (opt_verbose > 2) { for (int i = 0; i < nb; i++) { dump_packed_block_q8x4x2(y, i, k); } @@ -813,7 +780,7 @@ static void unpack_row_q8x4x2(block_q8_0 * x, const uint8_t * y, int64_t k) { const uint8_t * y_q = y + 0; // quants first const uint8_t * y_d = y + qrow_size; // then scales - if (opt_verbose > 1) { + if (opt_verbose > 2) { for (int i = 0; i < nb; i++) { dump_packed_block_q8x4x2(y, i, k); } @@ -1148,7 +1115,7 @@ static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k) e[7] = x[i * 8 + 7].e; } - if (opt_verbose > 1) { + if (opt_verbose > 2) { for (int i = 0; i < nb; i++) { dump_packed_block_mxfp4x4x2(y, i, k); } @@ -1167,7 +1134,7 @@ static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k) const uint8_t * y_q = y + 0; // quants first const uint8_t * y_e = y + qrow_size; // then scales - if (opt_verbose > 1) { + if (opt_verbose > 2) { for (int i = 0; i < nb; i++) { dump_packed_block_mxfp4x4x2(y, i, k); } @@ -1386,11 +1353,10 @@ static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer, const void * data, size_t offset, size_t size) { - auto ctx = (ggml_backend_hexagon_buffer_context *) buffer->context; - auto sess = ctx->sess; + auto sbuf = (ggml_hexagon_shared_buffer *) buffer->context; + auto sess = sbuf->sess; - HEX_VERBOSE("ggml-hex: %s set-tensor %s : data %p offset %zu size %zu\n", sess->name.c_str(), tensor->name, data, - offset, size); + HEX_VERBOSE("ggml-hex: %s set-tensor %s : data %p offset %zu size %zu\n", sess->c_name(), tensor->name, data, offset, size); switch (tensor->type) { case GGML_TYPE_Q4_0: @@ -1405,6 +1371,13 @@ static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer, repack_q8_0_q8x4x2(tensor, data, size); break; + case GGML_TYPE_IQ4_NL: + GGML_ASSERT(offset == 0); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); + // IQ4_NL has identical block layout to Q4_0 (ggml_half d + uint8_t qs[16]) + repack_q4_0_q4x4x2(tensor, data, size); + break; + case GGML_TYPE_MXFP4: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); @@ -1422,11 +1395,10 @@ static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer, void * data, size_t offset, size_t size) { - auto ctx = (ggml_backend_hexagon_buffer_context *) buffer->context; - auto sess = ctx->sess; + auto sbuf = (ggml_hexagon_shared_buffer *) buffer->context; + auto sess = sbuf->sess; - HEX_VERBOSE("ggml-hex: %s get-tensor %s : data %p offset %zu size %zu\n", sess->name.c_str(), tensor->name, data, - offset, size); + HEX_VERBOSE("ggml-hex: %s get-tensor %s : data %p offset %zu size %zu\n", sess->c_name(), tensor->name, data, offset, size); switch (tensor->type) { case GGML_TYPE_Q4_0: @@ -1441,6 +1413,12 @@ static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer, repack_q8x4x2_q8_0(data, tensor, size); break; + case GGML_TYPE_IQ4_NL: + GGML_ASSERT(offset == 0); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); + repack_q4x4x2_q4_0(data, tensor, size); + break; + case GGML_TYPE_MXFP4: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); @@ -1464,10 +1442,10 @@ static bool ggml_backend_hexagon_buffer_cpy_tensor(ggml_backend_buffer_t bu } static void ggml_backend_hexagon_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { - auto ctx = (ggml_backend_hexagon_buffer_context *) buffer->context; - auto sess = ctx->sess; - HEX_VERBOSE("ggml-hex: %s clear-buff base %p size %zu\n", sess->name.c_str(), (void *) ctx->base, ctx->size); - memset(ctx->base, value, ctx->size); + auto sbuf = (ggml_hexagon_shared_buffer *) buffer->context; + auto sess = sbuf->sess; + HEX_VERBOSE("ggml-hex: %s clear-buff base %p size %zu\n", sess->c_name(), (void *) sbuf->base, sbuf->size); + memset(sbuf->base, value, sbuf->size); } static ggml_backend_buffer_i ggml_backend_hexagon_buffer_interface = { @@ -1477,6 +1455,8 @@ static ggml_backend_buffer_i ggml_backend_hexagon_buffer_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_hexagon_buffer_set_tensor, /* .get_tensor = */ ggml_backend_hexagon_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_hexagon_buffer_cpy_tensor, /* .clear = */ ggml_backend_hexagon_buffer_clear, /* .reset = */ NULL, @@ -1492,10 +1472,11 @@ static ggml_backend_buffer_t ggml_backend_hexagon_buffer_type_alloc_buffer( ggml_backend_buffer_type_t buffer_type, size_t size) { auto sess = static_cast(buffer_type->context)->sess; try { - ggml_backend_hexagon_buffer_context * ctx = new ggml_backend_hexagon_buffer_context(sess, size, false /*repack*/); - return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, ctx, size); + size += 4 * 1024; // guard page + ggml_hexagon_shared_buffer * sbuf = new ggml_hexagon_shared_buffer(sess, size); + return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, sbuf, size); } catch (const std::exception & exc) { - GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context: %s\n", sess->name.c_str(), exc.what()); + GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context (host): %s\n", sess->c_name(), exc.what()); return nullptr; } } @@ -1504,10 +1485,11 @@ static ggml_backend_buffer_t ggml_backend_hexagon_repack_buffer_type_alloc_buffe ggml_backend_buffer_type_t buffer_type, size_t size) { auto sess = static_cast(buffer_type->context)->sess; try { - ggml_backend_hexagon_buffer_context * ctx = new ggml_backend_hexagon_buffer_context(sess, size, true /*repack*/); - return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, ctx, size); + size += 4 * 1024; // guard page + ggml_hexagon_shared_buffer * sbuf = new ggml_hexagon_shared_buffer(sess, size); + return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, sbuf, size); } catch (const std::exception & exc) { - GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context: %s\n", sess->name.c_str(), exc.what()); + GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context (repack): %s\n", sess->c_name(), exc.what()); return nullptr; } } @@ -1522,7 +1504,7 @@ static size_t ggml_backend_hexagon_buffer_type_get_alloc_size(ggml_backend_buffe } static size_t ggml_backend_hexagon_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) { - return 1 * 1024 * 1024 * 1024; // 1GB per buffer + return opt_mbuf; // typically 1GB per buffer GGML_UNUSED(buffer_type); } @@ -1554,213 +1536,686 @@ static ggml_backend_buffer_type_i ggml_backend_hexagon_repack_buffer_type_interf /* .is_host = */ ggml_backend_hexagon_repack_buffer_type_is_host, }; -void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { - this->valid_session = false; - this->valid_handle = false; - this->valid_queue = false; - this->valid_iface = false; +// Backend session implementation - this->domain_id = 3; // Default for CDSP, updated after the session is created - this->session_id = 0; // Default for CDSP, updated after the session is created - this->dev_id = dev_id; - this->name = std::string("HTP") + std::to_string(dev_id); +struct ggml_hexagon_opbatch { + ggml_hexagon_session* sess; - this->op_pending = 0; - this->prof_usecs = 0; - this->prof_cycles = 0; - this->prof_pkts = 0; + std::vector ops; // pointers to original ops - GGML_LOG_INFO("ggml-hex: allocating new session: %s\n", this->name.c_str()); + std::vector h_bufs; // htp buffer descriptors + std::vector h_tens; // htp tensor descriptors + std::vector h_ops; // htp op descriptors - domain * my_domain = get_domain(this->domain_id); - if (my_domain == NULL) { - GGML_LOG_ERROR("ggml-hex: unable to get domain struct for CDSP\n"); - throw std::runtime_error("ggml-hex: failed to get CDSP domain (see log for details)"); - } + std::unordered_map b_map; // buffer fd to index + std::unordered_map t_map; // tensor ptr to index + std::unordered_multimap d_map; // tensor data to index - // Create new session - if (dev_id != 0) { - struct remote_rpc_reserve_new_session n; - n.domain_name_len = strlen(CDSP_DOMAIN_NAME); - n.domain_name = const_cast(CDSP_DOMAIN_NAME); - n.session_name = const_cast(this->name.c_str()); - n.session_name_len = this->name.size(); + unsigned int n_bufs; // num buffers in the batch + unsigned int n_tens; // num tensors ... + unsigned int n_ops; // num ops ... + size_t b_vmem; // sum of all buffer sizes - int err = remote_session_control(FASTRPC_RESERVE_NEW_SESSION, (void *) &n, sizeof(n)); - if (err != AEE_SUCCESS) { - GGML_LOG_ERROR("ggml-hex: failed to reserve new session %d : error 0x%x\n", dev_id, err); - throw std::runtime_error("ggml-hex: remote_session_control(new-sess) failed (see log for details)"); - } + unsigned int n_bufs_max; + unsigned int n_tens_max; + unsigned int n_ops_max; + size_t b_vmem_max; - // Save the IDs - this->session_id = n.session_id; - this->domain_id = n.effective_domain_id; - this->valid_session = true; + void reset() { + n_bufs = 0; + n_tens = 0; + n_ops = 0; + b_vmem = 0; + + b_map.clear(); + t_map.clear(); + d_map.clear(); } - // Get session URI + ggml_hexagon_opbatch(ggml_hexagon_session *sess, size_t batch_size, size_t max_vmem) { + this->sess = sess; - char session_uri[256]; - { - char htp_uri[256]; - snprintf(htp_uri, sizeof(htp_uri), "file:///libggml-htp-v%u.so?htp_iface_skel_handle_invoke&_modver=1.0", opt_arch); + n_bufs_max = HTP_OP_MAX_BUFS; + n_ops_max = batch_size; + n_tens_max = n_ops_max + n_ops_max * HTP_OP_MAX_INPUTS; - struct remote_rpc_get_uri u = {}; - u.session_id = this->session_id; - u.domain_name = const_cast(CDSP_DOMAIN_NAME); - u.domain_name_len = strlen(CDSP_DOMAIN_NAME); - u.module_uri = const_cast(htp_uri); - u.module_uri_len = strlen(htp_uri); - u.uri = session_uri; - u.uri_len = sizeof(session_uri); + b_vmem_max = max_vmem; - int err = remote_session_control(FASTRPC_GET_URI, (void *) &u, sizeof(u)); - if (err != AEE_SUCCESS) { - // fallback to single session uris - int htp_URI_domain_len = strlen(htp_uri) + MAX_DOMAIN_NAMELEN; + ops.resize(n_ops_max); - snprintf(session_uri, htp_URI_domain_len, "%s%s", htp_uri, my_domain->uri); + h_bufs.resize(n_bufs_max); + h_tens.resize(n_tens_max); + h_ops.resize(n_ops_max); - GGML_LOG_WARN("ggml-hex: failed to get URI for session %d : error 0x%x. Falling back to single session URI: %s\n", dev_id, err, session_uri); - } - } + b_map.reserve(n_bufs_max); + t_map.reserve(n_tens_max); + d_map.reserve(n_tens_max); - // Enable Unsigned PD - { - struct remote_rpc_control_unsigned_module u; - u.domain = this->domain_id; - u.enable = 1; - int err = remote_session_control(DSPRPC_CONTROL_UNSIGNED_MODULE, (void *) &u, sizeof(u)); - if (err != AEE_SUCCESS) { - GGML_LOG_ERROR("ggml-hex: failed to enable unsigned PD for session %d : error 0x%x\n", dev_id, err); - throw std::runtime_error("ggml-hex: remote_session_control(unsign) failed (see log for details)"); - } - } + GGML_LOG_INFO("ggml-hex: %s op batching: n-bufs %u n-tensors %u n-ops %u vmem %zu\n", + sess->c_name(), n_bufs_max, n_tens_max, n_ops_max, b_vmem_max); - // Open session - int err = htp_iface_open(session_uri, &this->handle); - if (err != AEE_SUCCESS) { - GGML_LOG_ERROR("ggml-hex: failed to open session %d : error 0x%x\n", dev_id, err); - throw std::runtime_error("ggml-hex: failed to open session (see log for details)"); + reset(); } - this->valid_handle = true; + bool empty() const { return n_ops == 0; } - GGML_LOG_INFO("ggml-hex: new session: %s : session-id %d domain-id %d uri %s handle 0x%lx\n", this->name.c_str(), - this->session_id, this->domain_id, session_uri, (unsigned long) this->handle); + // add buffer and return its index + int add_buffer(ggml_hexagon_shared_buffer * sbuf) { + // Lookup by fd + auto it = b_map.find(sbuf->fd); + if (it != b_map.end()) { return it->second; } - // Enable FastRPC QoS mode - { - struct remote_rpc_control_latency l; - l.enable = 1; + // Add new buffer to the batch + int bi = n_bufs++; + GGML_ASSERT(n_bufs < HTP_OP_MAX_BUFS); - int err = remote_handle64_control(this->handle, DSPRPC_CONTROL_LATENCY, (void *) &l, sizeof(l)); - if (err != 0) { - GGML_LOG_WARN("ggml-hex: failed to enable fastrpc QOS mode: 0x%08x\n", (unsigned) err); - } - } + b_map.insert({sbuf->fd, bi}); - // Now let's setup the DSP queue - err = dspqueue_create(this->domain_id, - 0, // Flags - 128 * 1024, // Request queue size (in bytes) - 64 * 1024, // Response queue size (in bytes) - nullptr, // Read packet callback (we handle reads explicitly) - nullptr, // Error callback (we handle errors during reads) - (void *) this, // Callback context - &queue); - if (err != 0) { - GGML_LOG_ERROR("ggml-hex: %s dspqueue_create failed: 0x%08x\n", this->name.c_str(), (unsigned) err); - throw std::runtime_error("ggml-hex: failed to create dspqueue (see log for details)"); - } + htp_buf_desc &b = h_bufs[bi]; + b.base = (uint64_t) sbuf->base; + b.fd = sbuf->fd; + b.size = sbuf->size; - this->valid_queue = true; + b_vmem += b.size; - // Export queue for use on the DSP - err = dspqueue_export(queue, &this->queue_id); - if (err != 0) { - GGML_LOG_ERROR("ggml-hex: dspqueue_export failed: 0x%08x\n", (unsigned) err); - throw std::runtime_error("ggml-hex: dspqueue export failed (see log for details)"); - } + HEX_VERBOSE("ggml-hex: add-buffer #%u : fd %d base %p size %zu : vmem %zu\n", bi, b.fd, (void*) sbuf->base, (size_t) b.size, b_vmem); - if (opt_etm) { - err = htp_iface_enable_etm(this->handle); - if (err != 0) { - GGML_LOG_ERROR("ggml-hex: failed to enable ETM tracing: 0x%08x\n", (unsigned) err); - } + return bi; } - // Start the DSP-side service. We need to pass the queue ID to the - // DSP in a FastRPC call; the DSP side will import the queue and start - // listening for packets in a callback. - err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx); - if (err != 0) { - GGML_LOG_ERROR("ggml-hex: failed to start session: 0x%08x\n", (unsigned) err); - throw std::runtime_error("ggml-hex: iface start failed (see log for details)"); + bool same_shape(const htp_tensor * h, const ggml_tensor * t) const { + return (h->ne[0] == t->ne[0]) && (h->ne[1] == t->ne[1]) && (h->ne[2] == t->ne[2]) && (h->ne[3] == t->ne[3]) && + (h->nb[0] == t->nb[0]) && (h->nb[1] == t->nb[1]) && (h->nb[2] == t->nb[2]) && (h->nb[3] == t->nb[3]); } - this->valid_iface = true; -} - -void ggml_hexagon_session::release() noexcept(true) { - GGML_LOG_INFO("ggml-hex: releasing session: %s\n", this->name.c_str()); - int err; + // add tensor and return its index + int add_tensor(const ggml_tensor * t) { + auto sbuf = static_cast(t->buffer->context); - // Stop the DSP-side service and close the queue - if (this->valid_iface) { - err = htp_iface_stop(this->handle); - if (err != 0) { - GGML_ABORT("ggml-hex: htp_iface_stop failed: 0x%08x\n", (unsigned) err); + // First lookup by tensor data + auto range = d_map.equal_range(t->data); + for (auto it = range.first; it != range.second; ++it) { + htp_tensor * h = &h_tens[it->second]; + if (same_shape(h, t)) { return it->second; } } - } - if (opt_etm) { - err = htp_iface_disable_etm(this->handle); - if (err != 0) { - GGML_LOG_ERROR("ggml-hex: warn : failed to disable ETM tracing: 0x%08x\n", (unsigned) err); - } - } + // Lookup by tensor ptr + auto it = t_map.find(t); + if (it != t_map.end()) { return it->second; } - if (this->valid_queue) { - err = dspqueue_close(queue); - if (err != 0) { - GGML_ABORT("ggml-hex: dspqueue_close failed: 0x%08x\n", (unsigned) err); - } - } + // Add new tensor to the batch + int ti = n_tens++; + GGML_ASSERT(n_tens <= n_tens_max); - if (this->valid_handle) { - htp_iface_close(this->handle); - } -} + t_map.insert({t, ti}); + d_map.insert({t->data, ti}); -ggml_hexagon_session::ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false) { - buffer_type.device = dev; - repack_buffer_type.device = dev; + uint64_t t_offset = (uint8_t *) t->data - sbuf->base; + size_t t_size = ggml_nbytes(t); - try { - allocate(dev_id); + htp_tensor &h = h_tens[ti]; + h.bi = add_buffer(sbuf); + h.data = t_offset; + h.size = t_size; + h.type = t->type; + h.ne[0] = t->ne[0]; h.ne[1] = t->ne[1]; h.ne[2] = t->ne[2]; h.ne[3] = t->ne[3]; + h.nb[0] = t->nb[0]; h.nb[1] = t->nb[1]; h.nb[2] = t->nb[2]; h.nb[3] = t->nb[3]; - buffer_type.iface = ggml_backend_hexagon_buffer_type_interface; - buffer_type.context = new ggml_backend_hexagon_buffer_type_context(this->name, this); + h.flags = 0; + if (ggml_backend_buffer_get_usage(t->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) { + h.flags |= HTP_TENSOR_COMPUTE; + } - repack_buffer_type.iface = ggml_backend_hexagon_repack_buffer_type_interface; - repack_buffer_type.context = new ggml_backend_hexagon_buffer_type_context(this->name + "-REPACK", this); - } catch (const std::exception & exc) { - release(); - throw; + HEX_VERBOSE("ggml-hex: add-tensor #%u %s : bi %d data %p offset %zu size %zu flags 0x%x : %zu:%zu:%zu:%zu\n", + ti, t->name, h.bi, (void*) t->data, (size_t) t_offset, t_size, h.flags, + (size_t) t->ne[0], (size_t) t->ne[1], (size_t) t->ne[2], (size_t) t->ne[3]); + + return ti; } -} -ggml_hexagon_session::~ggml_hexagon_session() noexcept(true) { - release(); + bool fit_op(const struct ggml_tensor *t) const { + if (n_ops >= n_ops_max ) return false; - delete static_cast(buffer_type.context); - delete static_cast(repack_buffer_type.context); -} + // check how much extras we will need + size_t extra_bufs = 0; + size_t extra_vmem = 0; + size_t extra_tens = 0; -// ** backend interface + auto fit_tensor = [&](const ggml_tensor *t) { + if (!t_map.count(t)) { + extra_tens++; -static bool ggml_backend_buffer_is_hexagon(const struct ggml_backend_buffer * b) { + auto sbuf = static_cast(t->buffer->context); + if (!b_map.count(sbuf->fd)) { + extra_vmem += sbuf->size; + extra_bufs += 1; + } + } + }; + + for (unsigned int i=0; i < HTP_OP_MAX_INPUTS && t->src[i]; i++) { + fit_tensor(t->src[i]); + } + fit_tensor(t); + + if ((extra_bufs + n_bufs) > n_bufs_max) return false; + if ((extra_tens + n_tens) > n_tens_max) return false; + if ((extra_vmem + b_vmem) > b_vmem_max) return false; + + return true; + } + + // assumes that fit_op() was called first and returned true + void add_op(htp_op_code opcode, const struct ggml_tensor * t) { + // Add new op + + unsigned int n = n_ops++; + GGML_ASSERT(n_ops <= n_ops_max); + + ops[n] = t; + + htp_op_desc &o = h_ops[n]; + memcpy(&o.params, &t->op_params, sizeof(t->op_params)); + o.opcode = opcode; + o.flags = 0; + + if (!(opt_opstage & HTP_OPSTAGE_COMPUTE)) { + o.flags |= HTP_OPFLAGS_SKIP_COMPUTE; + } + + ggml_hexagon_dump_op_exec(sess->c_name(), t, o.flags); + + for (unsigned int i=0; i < HTP_OP_MAX_INPUTS; i++) { + o.src[i] = t->src[i] ? add_tensor(t->src[i]) : 0xffff; + } + o.dst = add_tensor(t); + } +}; + +struct ggml_hexagon_opqueue { + // Shared buffer for storing batches + ggml_hexagon_shared_buffer *shm_buf; + size_t shm_blk_size; + + using opvec = std::vector; + + std::queue done; // completed batch ids + std::vector op_cache; // per batch op cache + std::vector start_usec; // per batch start time + + ggml_hexagon_opqueue(ggml_hexagon_session *sess, size_t batch_size, size_t depth) { + size_t n_bufs = HTP_OP_MAX_BUFS; + size_t n_ops = batch_size; + size_t n_tensors = n_ops + n_ops * HTP_OP_MAX_INPUTS; + + shm_blk_size = sizeof(htp_buf_desc) * n_bufs + + sizeof(htp_tensor) * n_tensors + + sizeof(htp_op_desc) * n_ops + + sizeof(htp_prof_desc) * n_ops; + + shm_buf = new ggml_hexagon_shared_buffer(sess, shm_blk_size * depth, true /* pinned */); + + op_cache.resize(depth); + start_usec.resize(depth, 0); + + // init done queue + for (unsigned int i = 0; i < depth; i++) { done.push(i); } + + if (opt_verbose) { + GGML_LOG_INFO("ggml-hex: %s allocated op-queue : batch-size %zu depth %zu shm-size %zu shm-block-size %zu\n", + sess->c_name(), batch_size, depth, shm_buf->size, shm_blk_size); + } + } + + ~ggml_hexagon_opqueue() { + delete shm_buf; + } + + // push new batch + bool push(htp_opbatch_req& req, dspqueue_buffer& dbuf, ggml_hexagon_opbatch* op_batch) { + static_assert(sizeof(htp_opbatch_req) % 8 == 0, "sizeof(htp_opbatch_req) must be multiple of 8"); + static_assert(sizeof(htp_opbatch_rsp) % 8 == 0, "sizeof(htp_opbatch_rsp) must be multiple of 8"); + static_assert(sizeof(htp_buf_desc) % 8 == 0, "sizeof(htp_buf_desc) must be multiple of 8"); + static_assert(sizeof(htp_tensor) % 8 == 0, "sizeof(htp_tensor) must be multiple of 8"); + static_assert(sizeof(htp_op_desc) % 8 == 0, "sizeof(htp_op_desc) must be multiple of 8"); + static_assert(sizeof(htp_prof_desc) % 8 == 0, "sizeof(htp_prof_desc) must be multiple of 8"); + + if (done.empty()) { return false; } + + req.id = done.front(); done.pop(); // batch id + req.n_bufs = op_batch->n_bufs; + req.n_tensors = op_batch->n_tens; + req.n_ops = op_batch->n_ops; + + op_cache[req.id] = op_batch->ops; + start_usec[req.id] = ggml_time_us(); + + const size_t b_size = sizeof(htp_buf_desc) * req.n_bufs; + const size_t t_size = sizeof(htp_tensor) * req.n_tensors; + const size_t o_size = sizeof(htp_op_desc) * req.n_ops; + const size_t p_size = sizeof(htp_prof_desc) * req.n_ops; + + dbuf.ptr = shm_buf->base + (req.id * shm_blk_size); + dbuf.fd = shm_buf->fd; + dbuf.flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT; + dbuf.offset = (uint8_t*) dbuf.ptr - (uint8_t*) shm_buf->base; + dbuf.size = b_size + t_size + o_size + p_size; + + GGML_ASSERT(dbuf.size <= shm_blk_size); + + uint8_t * m_ptr = (uint8_t*) dbuf.ptr; + uint8_t * b_ptr = m_ptr; m_ptr += b_size; + uint8_t * t_ptr = m_ptr; m_ptr += t_size; + uint8_t * o_ptr = m_ptr; + + memcpy(b_ptr, (void *) op_batch->h_bufs.data(), b_size); + memcpy(t_ptr, (void *) op_batch->h_tens.data(), t_size); + memcpy(o_ptr, (void *) op_batch->h_ops.data(), o_size); + + HEX_VERBOSE("ggml-hex: %s op-queue push batch #%u : n-bufs %u n-tensors %u n-ops %u vmem %zu : b-size %zu t-size %zu o-size %zu m-size %zu\n", + shm_buf->sess->c_name(), req.id, req.n_bufs, req.n_tensors, req.n_ops, op_batch->b_vmem, + b_size, t_size, o_size, (size_t) dbuf.size); + + op_batch->reset(); + + if (opt_verbose > 1) { + htp_buf_desc *b = (htp_buf_desc*) b_ptr; + for (unsigned int i=0; i < req.n_bufs; i++) { + GGML_LOG_DEBUG("ggml-hex: %s htp-buf #%u : fd %d base %p size %zu\n", shm_buf->sess->c_name(), i, + b[i].fd, (void *) b[i].base, (size_t) b[i].size); + } + htp_tensor *t = (htp_tensor*) t_ptr; + for (unsigned int i=0; i < req.n_tensors; i++) { + GGML_LOG_DEBUG("ggml-hex: %s htp-tensor #%u : bi %u offset %u size %u : %zu:%zu:%zu:%zu\n", + shm_buf->sess->c_name(), i, t[i].bi, t[i].data, t[i].size, + (size_t) t[i].ne[0], (size_t) t[i].ne[1], (size_t) t[i].ne[2], (size_t) t[i].ne[3]); + } + } + + return true; + } + + void pop(htp_opbatch_rsp rsp, dspqueue_buffer dbuf) { + GGML_ASSERT(rsp.id < op_cache.size()); + + done.push(rsp.id); + + const size_t b_size = sizeof(htp_buf_desc) * rsp.n_bufs; + const size_t t_size = sizeof(htp_tensor) * rsp.n_tensors; + const size_t o_size = sizeof(htp_op_desc) * rsp.n_ops; + const size_t p_size = sizeof(htp_prof_desc) * rsp.n_ops; + + const size_t m_size = b_size + t_size + o_size + p_size; + GGML_ASSERT(m_size <= shm_blk_size); + + HEX_VERBOSE("ggml-hex: %s op-queue pop batch #%u : n-bufs %u n-tensors %u n-ops %u : m-size %zu b-size %zu t-size %zu o-size %zu\n", + shm_buf->sess->c_name(), rsp.id, rsp.n_bufs, rsp.n_tensors, rsp.n_ops, + (size_t) dbuf.size, b_size, t_size, o_size); + + uint8_t * m_ptr = (uint8_t*) dbuf.ptr; + uint8_t * p_ptr = m_ptr + (b_size + t_size + o_size); + + if (opt_profile && rsp.n_ops > 0) { + auto & ops = op_cache[rsp.id]; + + uint64_t batch_usec = ggml_time_us() - start_usec[rsp.id]; + uint32_t htp_usec = 0; + + GGML_ASSERT(rsp.n_ops <= ops.size()); + + const htp_prof_desc * pd = (const htp_prof_desc *) p_ptr; + for (uint32_t i = 0; i < rsp.n_ops; i++) { + htp_usec += pd[i].usecs; + ggml_hexagon_dump_op_prof(shm_buf->sess->name, ops[i], pd[i].usecs, pd[i].cycles, pd[i].pmu); + } + + GGML_LOG_DEBUG("ggml-hex: %s profile-batch n-ops %u batch-dur-usec %lld htp-ops-usec %u\n", + shm_buf->sess->c_name(), rsp.n_ops, (long long) batch_usec, htp_usec); + } + } +}; + +// Flush HTP response queue i.e wait for all outstanding requests to complete +void ggml_hexagon_session::flush_pending(bool all) { + while (this->op_pending) { + struct htp_opbatch_rsp rsp; + uint32_t rsp_size; + uint32_t flags; + + struct dspqueue_buffer dbuf; + uint32_t n_dbufs; + + // Read response packet from queue + int err = dspqueue_read(this->queue, &flags, 1, &n_dbufs, &dbuf, sizeof(rsp), &rsp_size, (uint8_t *) &rsp, DSPQUEUE_TIMEOUT); + if (err == AEE_EEXPIRED) { + continue; + } + + if (err != 0) { + GGML_ABORT("ggml-hex: dspqueue_read failed: 0x%08x\n", (unsigned) err); + } + + // Basic sanity checks + if (rsp_size != sizeof(rsp) || n_dbufs != 1) { + GGML_ABORT("ggml-hex: %s dspcall : bad response : size %u dspbufs %u\n", this->c_name(), rsp_size, n_dbufs); + } + + if (rsp.status != HTP_STATUS_OK) { + GGML_LOG_ERROR("ggml-hex: %s dspcall : dsp-rsp: %s\n", this->c_name(), status_to_str(rsp.status)); + // TODO: handle errors + } + + op_queue->pop(rsp, dbuf); + + this->op_pending--; // atomic dec + + if (!all) break; + } +} + +void ggml_hexagon_session::flush_batch() { + if (op_batch->empty()) { return; } + + htp_opbatch_req req {}; + dspqueue_buffer dbuf{}; + + if (!op_queue->push(req, dbuf, op_batch)) { + flush_pending(false); + op_queue->push(req, dbuf, op_batch); + } + + // Bump pending flag (cleared in the session::flush once we get the response) + this->op_pending++; // atomic inc + + HEX_VERBOSE("ggml-hex: %s queue-opbatch: %p size %u\n", this->c_name(), dbuf.ptr, dbuf.size); + + int err = dspqueue_write(this->queue, 0, 1, &dbuf, sizeof(req), (const uint8_t*) &req, DSPQUEUE_TIMEOUT); + if (err != 0) { + GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", this->c_name(), (unsigned) err); + } +} + +void ggml_hexagon_session::enqueue_op(htp_op_code opcode, const ggml_tensor *op) { + if (!op_batch->fit_op(op)) { + flush_batch(); + } + op_batch->add_op(opcode, op); +} + +// Flush HTP response queue i.e wait for all outstanding requests to complete +void ggml_hexagon_session::flush(bool all) { + flush_batch(); + flush_pending(all); +} + +static size_t ggml_hexagon_measure_max_vmem(ggml_hexagon_session *sess) { + // Allocate a bunch pinned buffers till failure. + // This is kind of expensive but handy for figuring out exactly how much we can mmap on a specific device. + // Typically we're going to allocate all/most of these buffers anyway for the model weights. + + std::vector sbufs; + + const size_t MiB = 1024 * 1024; + const size_t GiB = MiB * 1024; + + size_t vmem = 0; + size_t step = 256u * MiB; + + try { + sbufs.push_back(new ggml_hexagon_shared_buffer(sess, GiB, true)); vmem += GiB; + sbufs.push_back(new ggml_hexagon_shared_buffer(sess, GiB, true)); vmem += GiB; + sbufs.push_back(new ggml_hexagon_shared_buffer(sess, GiB, true)); vmem += GiB; + + while (1) { + sbufs.push_back(new ggml_hexagon_shared_buffer(sess, step, true)); + vmem += step; + } + } catch (...) { } + + for (auto b : sbufs) { delete b; } + + return vmem - step; // backoff to account for overhead from internal mappings +} + +void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { + this->valid_session = false; + this->valid_handle = false; + this->valid_queue = false; + this->valid_iface = false; + + this->domain_id = 3; // Default for CDSP, updated after the session is created + this->session_id = 0; // Default for CDSP, updated after the session is created + this->dev_id = dev_id; + this->name = std::string("HTP") + std::to_string(dev_id); + + this->op_pending = 0; + + GGML_LOG_DEBUG("ggml-hex: %s allocating new session\n", this->name.c_str()); + + domain * my_domain = get_domain(this->domain_id); + if (my_domain == NULL) { + GGML_LOG_ERROR("ggml-hex: unable to get domain struct for CDSP\n"); + throw std::runtime_error("ggml-hex: failed to get CDSP domain (see log for details)"); + } + + // Create new session + if (dev_id != 0) { + struct remote_rpc_reserve_new_session n; + n.domain_name_len = strlen(CDSP_DOMAIN_NAME); + n.domain_name = const_cast(CDSP_DOMAIN_NAME); + n.session_name = const_cast(this->name.c_str()); + n.session_name_len = this->name.size(); + + int err = remote_session_control(FASTRPC_RESERVE_NEW_SESSION, (void *) &n, sizeof(n)); + if (err != AEE_SUCCESS) { + GGML_LOG_ERROR("ggml-hex: failed to reserve new session %d : error 0x%x\n", dev_id, err); + throw std::runtime_error("ggml-hex: remote_session_control(new-sess) failed (see log for details)"); + } + + // Save the IDs + this->session_id = n.session_id; + this->domain_id = n.effective_domain_id; + this->valid_session = true; + } + + // Get session URI + + char session_uri[256]; + { + char htp_uri[256]; + snprintf(htp_uri, sizeof(htp_uri), "file:///libggml-htp-v%u.so?htp_iface_skel_handle_invoke&_modver=1.0", opt_arch); + + struct remote_rpc_get_uri u = {}; + u.session_id = this->session_id; + u.domain_name = const_cast(CDSP_DOMAIN_NAME); + u.domain_name_len = strlen(CDSP_DOMAIN_NAME); + u.module_uri = const_cast(htp_uri); + u.module_uri_len = strlen(htp_uri); + u.uri = session_uri; + u.uri_len = sizeof(session_uri); + + int err = remote_session_control(FASTRPC_GET_URI, (void *) &u, sizeof(u)); + if (err != AEE_SUCCESS) { + // fallback to single session uris + int htp_URI_domain_len = strlen(htp_uri) + MAX_DOMAIN_NAMELEN; + + snprintf(session_uri, htp_URI_domain_len, "%s%s", htp_uri, my_domain->uri); + + GGML_LOG_WARN("ggml-hex: failed to get URI for session %d : error 0x%x. Falling back to single session URI: %s\n", dev_id, err, session_uri); + } + } + + // Enable Unsigned PD + { + struct remote_rpc_control_unsigned_module u; + u.domain = this->domain_id; + u.enable = 1; + int err = remote_session_control(DSPRPC_CONTROL_UNSIGNED_MODULE, (void *) &u, sizeof(u)); + if (err != AEE_SUCCESS) { + GGML_LOG_ERROR("ggml-hex: failed to enable unsigned PD for session %d : error 0x%x\n", dev_id, err); + throw std::runtime_error("ggml-hex: remote_session_control(unsign) failed (see log for details)"); + } + } + + // Open session + int err = htp_iface_open(session_uri, &this->handle); + if (err != AEE_SUCCESS) { + GGML_LOG_ERROR("ggml-hex: failed to open session %d : error 0x%x\n", dev_id, err); + throw std::runtime_error("ggml-hex: failed to open session (see log for details)"); + } + + this->valid_handle = true; + + // Enable FastRPC QoS mode + { + struct remote_rpc_control_latency l; + l.enable = 1; + + int err = remote_handle64_control(this->handle, DSPRPC_CONTROL_LATENCY, (void *) &l, sizeof(l)); + if (err != 0) { + GGML_LOG_WARN("ggml-hex: failed to enable fastrpc QOS mode: 0x%08x\n", (unsigned) err); + } + } + + GGML_LOG_INFO("ggml-hex: %s new session : session-id %d domain-id %d uri %s handle 0x%lx\n", this->c_name(), + this->session_id, this->domain_id, session_uri, (unsigned long) this->handle); + + const size_t req_q_size = (sizeof(htp_opbatch_req) * opt_opqueue * 2) + 1024; + const size_t rsp_q_size = (sizeof(htp_opbatch_rsp) * opt_opqueue * 2) + 1024; + + // Now let's setup the DSP queue + err = dspqueue_create(this->domain_id, + 0, // Flags + req_q_size, // Request queue size (in bytes) + rsp_q_size, // Response queue size (in bytes) + nullptr, // Read packet callback (we handle reads explicitly) + nullptr, // Error callback (we handle errors during reads) + (void *) this, // Callback context + &queue); + if (err != 0) { + GGML_LOG_ERROR("ggml-hex: %s dspqueue_create failed: 0x%08x\n", this->name.c_str(), (unsigned) err); + throw std::runtime_error("ggml-hex: failed to create dspqueue (see log for details)"); + } + + this->valid_queue = true; + + // Export queue for use on the DSP + err = dspqueue_export(queue, &this->queue_id); + if (err != 0) { + GGML_LOG_ERROR("ggml-hex: dspqueue_export failed: 0x%08x\n", (unsigned) err); + throw std::runtime_error("ggml-hex: dspqueue export failed (see log for details)"); + } + + if (opt_etm) { + err = htp_iface_etm(this->handle, 1); + if (err != 0) { + GGML_LOG_ERROR("ggml-hex: failed to enable ETM tracing: 0x%08x\n", (unsigned) err); + } + } + + if (opt_profile) { + htp_iface_pmu_conf pmu_conf{}; + std::copy(opt_pmu_evt.begin(), opt_pmu_evt.end(), pmu_conf.events); + + err = htp_iface_profiler(this->handle, opt_profile, &pmu_conf); + if (err != 0) { + GGML_LOG_ERROR("ggml-hex: failed to enable profiling: 0x%08x\n", (unsigned) err); + } + } + + // Allocate buffers and state for op batching + this->op_queue = new ggml_hexagon_opqueue(this, opt_opbatch, opt_opqueue); + + if (!opt_vmem) { + opt_vmem = ggml_hexagon_measure_max_vmem(this); + GGML_LOG_INFO("ggml-hex: %s measured max vmem %zu\n", this->c_name(), opt_vmem); + } + + this->op_batch = new ggml_hexagon_opbatch(this, opt_opbatch, opt_vmem); + + // Start dspqueue/opbatch processing + err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx, opt_use_hmx, opt_vmem); + if (err != 0) { + GGML_LOG_ERROR("ggml-hex: %s failed to start session: 0x%08x\n", this->c_name(), (unsigned) err); + throw std::runtime_error("ggml-hex: iface start failed (see log for details)"); + } + this->valid_iface = true; +} + +void ggml_hexagon_session::release() noexcept(true) { + GGML_LOG_INFO("ggml-hex: releasing session: %s\n", this->name.c_str()); + + int err; + + if (this->valid_iface) { + // Stop dspqueue/opbatch processing + err = htp_iface_stop(this->handle); + if (err != 0) { + GGML_ABORT("ggml-hex: htp_iface_stop failed: 0x%08x\n", (unsigned) err); + } + } + + delete this->op_batch; + delete this->op_queue; + + if (opt_etm) { + err = htp_iface_etm(this->handle, 0); + if (err != 0) { + GGML_LOG_ERROR("ggml-hex: warn : failed to disable ETM tracing: 0x%08x\n", (unsigned) err); + } + } + + if (opt_profile) { + htp_iface_pmu_conf pmu_conf{}; + err = htp_iface_profiler(this->handle, 0, &pmu_conf); + if (err != 0) { + GGML_LOG_ERROR("ggml-hex: warn : failed to disable profiling: 0x%08x\n", (unsigned) err); + } + } + + if (this->valid_queue) { + err = dspqueue_close(queue); + if (err != 0) { + GGML_ABORT("ggml-hex: dspqueue_close failed: 0x%08x\n", (unsigned) err); + } + } + + if (this->valid_handle) { + htp_iface_close(this->handle); + } +} + +ggml_hexagon_session::ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false) { + buffer_type.device = dev; + repack_buffer_type.device = dev; + + op_batch = nullptr; + op_queue = nullptr; + + try { + allocate(dev_id); + + buffer_type.iface = ggml_backend_hexagon_buffer_type_interface; + buffer_type.context = new ggml_backend_hexagon_buffer_type_context(this->name, this); + + repack_buffer_type.iface = ggml_backend_hexagon_repack_buffer_type_interface; + repack_buffer_type.context = new ggml_backend_hexagon_buffer_type_context(this->name + "-REPACK", this); + } catch (const std::exception & exc) { + release(); + throw; + } +} + +ggml_hexagon_session::~ggml_hexagon_session() noexcept(true) { + release(); + + delete static_cast(buffer_type.context); + delete static_cast(repack_buffer_type.context); +} + +// ** backend interface + +static bool ggml_backend_buffer_is_hexagon(const struct ggml_backend_buffer * b) { return b->buft->iface.get_alignment == ggml_backend_hexagon_buffer_type_get_alignment; } @@ -1799,9 +2254,12 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess return false; } - return opt_experimental; -} + if (dst->ne[3] != 1) { + return false; + } + return true; +} static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; @@ -1818,6 +2276,7 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s switch (src0->type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: if (src0->ne[0] % 32) { return false; @@ -1867,6 +2326,7 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session switch (src0->type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: if ((src0->ne[0] % 32)) { return false; @@ -1960,8 +2420,8 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses return false; } - // TODO: add support for non-contigiuos tensors - if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) { + // TODO: add support for non-contiguous elements within a row + if (!ggml_is_contiguous_rows(src0) || !ggml_is_contiguous_rows(dst)) { return false; } @@ -2064,6 +2524,23 @@ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * s } } + // Reject non-HVX-aligned sizes when ne[0] > HVX_F32_LANES + // The HVX softmax implementation has issues with tail handling for larger non-aligned sizes + // Small sizes (ne[0] <= 32) work correctly with tail-only processing + const int64_t ne0 = src0->ne[0]; + if (ne0 > 32 && (ne0 & (32 - 1)) != 0) { + return false; + } + + // HVX vector size constraints for softmax + #define SOFTMAX_MAX_ROW_SIZE 131072 // 128K elements max for numerical precision + + // Reject very large row sizes to avoid numerical precision issues + // Softmax accumulation over many elements can lead to precision loss + if (ne0 > SOFTMAX_MAX_ROW_SIZE) { + return false; + } + return true; } @@ -2192,359 +2669,104 @@ static bool ggml_hexagon_supported_ssm_conv(const struct ggml_hexagon_session * return false; // src0 should be effectively 3D } - const int d_conv = src1->ne[0]; - const int d_inner = src0->ne[1]; - const int n_t = dst->ne[1]; - const int n_s = dst->ne[2]; - - if (src0->ne[0] != d_conv - 1 + n_t || src0->ne[1] != d_inner || src0->ne[2] != n_s) { - return false; - } - if (src1->ne[0] != d_conv || src1->ne[1] != d_inner) { - return false; - } - if (dst->ne[0] != d_inner || dst->ne[1] != n_t || dst->ne[2] != n_s) { - return false; - } - - // TODO: add support for non-contiguous tensors - if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { - return false; - } - - return true; -} - -enum dspqbuf_type { - DSPQBUF_TYPE_DSP_WRITE_CPU_READ = 0, - DSPQBUF_TYPE_CPU_WRITE_DSP_READ, - DSPQBUF_TYPE_CONSTANT, -}; - -static void dspqbuf_dump(dspqueue_buffer * d, const struct ggml_tensor * t, dspqbuf_type type) { - if (opt_verbose < 2) return; - - auto buf = static_cast(t->buffer->context); - auto sess = buf->sess; - - GGML_LOG_DEBUG("ggml-hex: %s dspqbuf : %s base-addr %p base-size %zu data %p offset %u size %u\n", sess->name.c_str(), - t->name, (void *) buf->base, buf->size, (void *) d->ptr, (unsigned int) d->offset, - (unsigned int) d->size); -} - -// Init hexagon tensor from GGML tensor and Hexagon buffer -static void htp_req_tensor_init(htp_tensor * h, const ggml_tensor * t) { - h->data = 0; // updated by the receiver - h->type = t->type; - h->ne[0] = t->ne[0]; - h->ne[1] = t->ne[1]; - h->ne[2] = t->ne[2]; - h->ne[3] = t->ne[3]; - h->nb[0] = t->nb[0]; - h->nb[1] = t->nb[1]; - h->nb[2] = t->nb[2]; - h->nb[3] = t->nb[3]; -} - -static size_t htp_req_buff_init(htp_tensor *h, dspqueue_buffer * d, const ggml_tensor * t, dspqbuf_type type) { - if (!t) { - return 0; - } - - auto buf = static_cast(t->buffer->context); - - memset(d, 0, sizeof(*d)); - d->fd = buf->fd; - d->ptr = t->data; - d->offset = (uint8_t *) t->data - buf->base; - d->size = ggml_nbytes(t); - - if (!d->size) { - // Some requests contain srcs where ggml_nbytes() returns 0 but the rest of the op is non-empty - d->size = 64; - } - - switch (type) { - case DSPQBUF_TYPE_DSP_WRITE_CPU_READ: - // Flush CPU - d->flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER; - break; - case DSPQBUF_TYPE_CPU_WRITE_DSP_READ: - // Flush CPU, Invalidate DSP - d->flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT; - break; - default: - // Constant buffer, no cache maintenance - d->flags = 0; - break; - } - - htp_req_tensor_init(h, t); - - dspqbuf_dump(d, t, type); - - return 1; -} - -typedef size_t (*htp_req_init_func_t)(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * op); - -template -static inline void ggml_hexagon_dispatch_op(ggml_hexagon_session *sess, const struct ggml_tensor * op, uint32_t flags) { - uint64_t t = ggml_time_us(); - - // Construct HTP request - htp_general_req req; - memset(&req, 0, sizeof(req)); - - req.flags = flags; - if (!(opt_opmask & HTP_OPMASK_QUANTIZE)) { - req.flags |= HTP_OPFLAGS_SKIP_QUANTIZE; - } - if (!(opt_opmask & HTP_OPMASK_COMPUTE)) { - req.flags |= HTP_OPFLAGS_SKIP_COMPUTE; - } - - ggml_hexagon_dump_op_exec(sess->name, op, req.flags); - - if ((opt_opmask & HTP_OPMASK_QUEUE)) { - dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS]; - size_t n_bufs = _init_req_func(&req, bufs, op); - sess->enqueue(req, bufs, n_bufs, opt_opsync); - } - - t = ggml_time_us() - t; - - ggml_hexagon_dump_op_prof(sess->name, op, sess->prof_usecs, sess->prof_cycles, sess->prof_pkts, t); -} - -template -static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - switch (t->op) { - case GGML_OP_MUL_MAT: - req->op = HTP_OP_MUL_MAT; - break; - case GGML_OP_MUL: - req->op = HTP_OP_MUL; - break; - case GGML_OP_ADD: - req->op = HTP_OP_ADD; - break; - case GGML_OP_SUB: - req->op = HTP_OP_SUB; - break; - case GGML_OP_DIV: - req->op = HTP_OP_DIV; - break; - default: - GGML_ABORT("ggml-hex: binary : unsupported op: %d\n", t->op); - break; - } - - // src0: Weights (mulmat) or First Operand (binary op). - // If constant (e.g. weights), no cache management is needed. - // src1: Input Activations (mulmat) or Second Operand (binary op). - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], _is_src0_constant ? DSPQBUF_TYPE_CONSTANT : DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} - -static inline size_t init_cpy_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - req->op = HTP_OP_CPY; - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} - -static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - req->op = HTP_OP_GET_ROWS; - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} - -static inline size_t init_argsort_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - req->op = HTP_OP_ARGSORT; - memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} - -template -static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - switch (t->op) { - case GGML_OP_MUL_MAT_ID: - req->op = HTP_OP_MUL_MAT_ID; - break; - case GGML_OP_ADD_ID: - req->op = HTP_OP_ADD_ID; - break; - default: - GGML_ABORT("ggml-hex: unsupported op: %d\n", t->op); - } - - // src0: Weights (mulmat) or Input Activations (other op). - // If constant, no cache management is needed. - // src1: Input Activations (mulmat) or Second Operand (binary op). - // src2: Expert IDs (mulmat) or Activated Experts (other op). - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], _is_src0_constant ? DSPQBUF_TYPE_CONSTANT : DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} + const int d_conv = src1->ne[0]; + const int d_inner = src0->ne[1]; + const int n_t = dst->ne[1]; + const int n_s = dst->ne[2]; -static inline size_t init_set_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - req->op = HTP_OP_SET_ROWS; + if (src0->ne[0] != d_conv - 1 + n_t || src0->ne[1] != d_inner || src0->ne[2] != n_s) { + return false; + } + if (src1->ne[0] != d_conv || src1->ne[1] != d_inner) { + return false; + } + if (dst->ne[0] != d_inner || dst->ne[1] != n_t || dst->ne[2] != n_s) { + return false; + } - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + // TODO: add support for non-contiguous tensors + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { + return false; + } - return n_bufs; + return true; } -static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); - - bool supported = false; - - switch (t->op) { - case GGML_OP_RMS_NORM: - req->op = HTP_OP_RMS_NORM; - supported = true; - break; - - case GGML_OP_SCALE: - req->op = HTP_OP_SCALE; - supported = true; - break; - - case GGML_OP_SQR: - req->op = HTP_OP_SQR; - supported = true; - break; +static bool ggml_hexagon_supported_cumsum(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; - case GGML_OP_SQRT: - req->op = HTP_OP_SQRT; - supported = true; - break; + if (src0->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } - case GGML_OP_UNARY: - if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) { - req->op = HTP_OP_UNARY_SILU; - supported = true; - } else if (ggml_get_unary_op(t) == GGML_UNARY_OP_GELU) { - req->op = HTP_OP_UNARY_GELU; - supported = true; - } - break; + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) { + return false; + } - case GGML_OP_GLU: - if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU) { - req->op = HTP_OP_GLU_SWIGLU; - supported = true; - } else if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU_OAI) { - req->op = HTP_OP_GLU_SWIGLU_OAI; - supported = true; - } else if (ggml_get_glu_op(t) == GGML_GLU_OP_GEGLU) { - req->op = HTP_OP_GLU_GEGLU; - supported = true; - } - break; + GGML_UNUSED(sess); + return true; +} - case GGML_OP_SOFT_MAX: - req->op = HTP_OP_SOFTMAX; - supported = true; - break; +static bool ggml_hexagon_supported_diag(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; - default: - break; + // diag only supports F32 currently + if (src0->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; } - if (!supported) { - GGML_ABORT("ggml-hex: unary : unsupported op: %d\n", t->op); + // Input must have ne[1] == 1 (vector input) + if (src0->ne[1] != 1) { + return false; } - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} - -static inline size_t init_sum_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); - req->op = HTP_OP_SUM_ROWS; - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + // Output must be square in first two dimensions + if (dst->ne[0] != dst->ne[1] || dst->ne[0] != src0->ne[0]) { + return false; + } - return n_bufs; + GGML_UNUSED(sess); + return true; } -static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); - req->op = HTP_OP_ROPE; +static bool ggml_hexagon_supported_solve_tri(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; // A + const struct ggml_tensor * src1 = op->src[1]; // B + const struct ggml_tensor * dst = op; // X - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} + if (!src0 || !src1) { + return false; + } -static inline size_t init_flash_attn_ext_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); - req->op = HTP_OP_FLASH_ATTN_EXT; + if (src0->type != GGML_TYPE_F32 || src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src3, &bufs[n_bufs], t->src[3], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src4, &bufs[n_bufs], t->src[4], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + if (src0->ne[0] != src0->ne[1]) { + return false; + } - return n_bufs; -} + if (src0->ne[1] != src1->ne[1]) { + return false; + } -static inline size_t init_ssm_conv_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - req->op = HTP_OP_SSM_CONV; + if (src0->ne[2] != src1->ne[2] || src0->ne[3] != src1->ne[3]) { + return false; + } - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CONSTANT); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + if (dst->ne[0] != src1->ne[0] || dst->ne[1] != src1->ne[1] || dst->ne[2] != src1->ne[2] || dst->ne[3] != src1->ne[3]) { + return false; + } - return n_bufs; + GGML_UNUSED(sess); + return true; } static const char * ggml_backend_hexagon_name(ggml_backend_t backend) { auto sess = static_cast(backend->context); - return sess->name.c_str(); + return sess->c_name(); } static void ggml_backend_hexagon_free(ggml_backend_t backend) { @@ -2553,139 +2775,76 @@ static void ggml_backend_hexagon_free(ggml_backend_t backend) { delete backend; } -static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op0) { - return (op0 && op0->src[1] == op1->src[1] && ggml_is_quantized(op0->src[0]->type)); -} +static htp_op_code op_remap_to_htp(const ggml_tensor * t) { + switch (t->op) { + case GGML_OP_FLASH_ATTN_EXT: return HTP_OP_FLASH_ATTN_EXT; + case GGML_OP_MUL_MAT: return HTP_OP_MUL_MAT; + case GGML_OP_MUL_MAT_ID: return HTP_OP_MUL_MAT_ID; + case GGML_OP_MUL: return HTP_OP_MUL; + case GGML_OP_ADD: return HTP_OP_ADD; + case GGML_OP_ADD_ID: return HTP_OP_ADD_ID; + case GGML_OP_SUB: return HTP_OP_SUB; + case GGML_OP_DIV: return HTP_OP_DIV; + case GGML_OP_CPY: return HTP_OP_CPY; + case GGML_OP_CONT: return HTP_OP_CPY; + case GGML_OP_GET_ROWS: return HTP_OP_GET_ROWS; + case GGML_OP_SET_ROWS: return HTP_OP_SET_ROWS; + case GGML_OP_SUM_ROWS: return HTP_OP_SUM_ROWS; + case GGML_OP_ARGSORT: return HTP_OP_ARGSORT; + case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM; + case GGML_OP_SCALE: return HTP_OP_SCALE; + case GGML_OP_SQR: return HTP_OP_SQR; + case GGML_OP_SQRT: return HTP_OP_SQRT; + case GGML_OP_SOFT_MAX: return HTP_OP_SOFTMAX; + case GGML_OP_SSM_CONV: return HTP_OP_SSM_CONV; + case GGML_OP_ROPE: return HTP_OP_ROPE; + case GGML_OP_REPEAT: return HTP_OP_REPEAT; + case GGML_OP_CUMSUM: return HTP_OP_CUMSUM; + case GGML_OP_FILL: return HTP_OP_FILL; + case GGML_OP_DIAG: return HTP_OP_DIAG; + case GGML_OP_SOLVE_TRI: return HTP_OP_SOLVE_TRI; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(t)) { + case GGML_UNARY_OP_SILU: return HTP_OP_UNARY_SILU; + case GGML_UNARY_OP_GELU: return HTP_OP_UNARY_GELU; + case GGML_UNARY_OP_SIGMOID: return HTP_OP_UNARY_SIGMOID; + case GGML_UNARY_OP_NEG: return HTP_OP_UNARY_NEG; + case GGML_UNARY_OP_EXP: return HTP_OP_UNARY_EXP; + case GGML_UNARY_OP_SOFTPLUS: return HTP_OP_UNARY_SOFTPLUS; + default: + break; + } + break; -static inline bool is_compute_op(ggml_tensor *node) -{ - return !ggml_op_is_empty(node->op) && !ggml_is_empty(node) && (node->flags & GGML_TENSOR_FLAG_COMPUTE); -} + case GGML_OP_GLU: + switch (ggml_get_glu_op(t)) { + case GGML_GLU_OP_SWIGLU: return HTP_OP_GLU_SWIGLU; + case GGML_GLU_OP_SWIGLU_OAI: return HTP_OP_GLU_SWIGLU_OAI; + case GGML_GLU_OP_GEGLU: return HTP_OP_GLU_GEGLU; + default: break; + } + break; -// scan the graph and figure out last compute op index -static inline int last_compute_op(ggml_cgraph * graph) { - int last = 0; - for (int i = 0; i < graph->n_nodes; ++i) { - if (is_compute_op(graph->nodes[i])) { - last = i; - } + default: + GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(t)); } + return HTP_OP_INVALID; +} - return last; +static inline bool op_is_compute(ggml_tensor *node) +{ + return !ggml_op_is_empty(node->op) && !ggml_is_empty(node) && (node->flags & GGML_TENSOR_FLAG_COMPUTE); } static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) { auto sess = static_cast(backend->context); - HEX_VERBOSE("ggml-hex: %s graph-compute n_nodes %d\n", sess->name.c_str(), graph->n_nodes); - - const int last = last_compute_op(graph); - - const struct ggml_tensor * prev_op = nullptr; // prev executed op + HEX_VERBOSE("ggml-hex: %s graph-compute n_nodes %d\n", sess->c_name(), graph->n_nodes); for (int i = 0; i < graph->n_nodes; ++i) { - ggml_tensor * node = graph->nodes[i]; - - if (!is_compute_op(node)) { - continue; - } - - uint32_t flags = 0; - - // skip quantizer if src1 is reused - if (op_reuse_src1(node, prev_op)) { - flags |= HTP_OPFLAGS_SKIP_QUANTIZE; - } - - prev_op = node; - - // ask for early notification for the last Op - if (i == last) { - flags |= HTP_OPFLAGS_EARLY_WAKEUP; - } - - switch (node->op) { - case GGML_OP_MUL_MAT: - if (ggml_is_quantized(node->src[0]->type)) { - ggml_hexagon_dispatch_op>(sess, node, flags); - } else { - ggml_hexagon_dispatch_op>(sess, node, flags); - } - break; - case GGML_OP_MUL_MAT_ID: - if (ggml_is_quantized(node->src[0]->type)) { - ggml_hexagon_dispatch_op>(sess, node, flags); - } else { - ggml_hexagon_dispatch_op>(sess, node, flags); - } - break; - case GGML_OP_MUL: - case GGML_OP_ADD: - case GGML_OP_SUB: - case GGML_OP_DIV: - ggml_hexagon_dispatch_op>(sess, node, flags); - break; - case GGML_OP_ADD_ID: - ggml_hexagon_dispatch_op>(sess, node, flags); - break; - case GGML_OP_RMS_NORM: - case GGML_OP_SCALE: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - case GGML_OP_SQR: - case GGML_OP_SQRT: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - case GGML_OP_SUM_ROWS: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - case GGML_OP_UNARY: - if ((ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) || - (ggml_get_unary_op(node) == GGML_UNARY_OP_GELU)) { - ggml_hexagon_dispatch_op(sess, node, flags); - } - break; - case GGML_OP_GLU: - if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) || - (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI) || - (ggml_get_glu_op(node) == GGML_GLU_OP_GEGLU)) { - ggml_hexagon_dispatch_op(sess, node, flags); - } - break; - case GGML_OP_SOFT_MAX: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - case GGML_OP_ROPE: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - case GGML_OP_FLASH_ATTN_EXT: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - case GGML_OP_SET_ROWS: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - case GGML_OP_GET_ROWS: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - case GGML_OP_CPY: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - case GGML_OP_ARGSORT: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - case GGML_OP_SSM_CONV: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - default: - GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node)); + ggml_tensor * n = graph->nodes[i]; + if (op_is_compute(n) && (opt_opstage & HTP_OPSTAGE_QUEUE)) { + sess->enqueue_op(op_remap_to_htp(n), n); } } @@ -2698,7 +2857,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg static void ggml_backend_hexagon_synchronize(ggml_backend_t backend) { auto sess = static_cast(backend->context); - HEX_VERBOSE("ggml-hex: %s synchronize\n", sess->name.c_str()); + HEX_VERBOSE("ggml-hex: %s synchronize\n", sess->c_name()); // Wait until all pending ops complete sess->flush(); @@ -2876,6 +3035,8 @@ static struct ggml_backend_i hexagon_backend_i = { /* .free = */ ggml_backend_hexagon_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ ggml_backend_hexagon_synchronize, /* .graph_plan_create = */ NULL, @@ -2915,7 +3076,7 @@ static ggml_backend_t ggml_backend_hexagon_device_init(ggml_backend_dev_t dev, c static const char * ggml_backend_hexagon_device_get_name(ggml_backend_dev_t dev) { auto sess = static_cast(dev->context); - return sess->name.c_str(); + return sess->c_name(); GGML_UNUSED(dev); } @@ -2926,8 +3087,7 @@ static const char * ggml_backend_hexagon_device_get_description(ggml_backend_dev } static void ggml_backend_hexagon_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { - // ~2GB per session for now - *free = 2ULL * 1024 * 1024 * 1024; + *free = 0; *total = *free; GGML_UNUSED(dev); @@ -3006,9 +3166,58 @@ static bool ggml_hexagon_supported_cpy(const struct ggml_hexagon_session * sess, return true; } +static bool ggml_hexagon_supported_cont(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + GGML_UNUSED(sess); + const struct ggml_tensor * src0 = op->src[0]; + + // CONT is same-type only, supports f32 and f16 + if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) return false; + + return true; +} + +static bool ggml_hexagon_supported_repeat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + GGML_UNUSED(sess); + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; + + // Support f32 and f16 + if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) return false; + + // src and dst must be the same type + if (src0->type != dst->type) return false; + + // dst dims must be multiples of src dims + if (dst->ne[0] % src0->ne[0] != 0) return false; + if (dst->ne[1] % src0->ne[1] != 0) return false; + if (dst->ne[2] % src0->ne[2] != 0) return false; + if (dst->ne[3] % src0->ne[3] != 0) return false; + + // require contiguous tensors (no transposition) + if (ggml_is_transposed(src0) || ggml_is_transposed(dst)) return false; + + return true; +} + +static bool ggml_hexagon_supported_fill(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * dst = op; + + if (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) { + return false; + } + + GGML_UNUSED(sess); + return true; +} + static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { auto sess = static_cast(dev->context); + // reject ops that match the filter + if (opt_opfilter && std::regex_match(ggml_op_desc(op), *opt_opfilter)) { + return false; + } + // all srcs & dsts must be mapped to the same session if (!ggml_hexagon_supported_buffers(sess, op)) { ggml_hexagon_dump_op_supp(sess->name, op, false); @@ -3025,6 +3234,13 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = true; break; + case GGML_OP_MUL: + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_DIV: + supp = ggml_hexagon_supported_binary(sess, op); + break; + case GGML_OP_MUL_MAT: supp = ggml_hexagon_supported_mul_mat(sess, op); break; @@ -3033,13 +3249,6 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_mul_mat_id(sess, op); break; - case GGML_OP_MUL: - case GGML_OP_ADD: - case GGML_OP_SUB: - case GGML_OP_DIV: - supp = ggml_hexagon_supported_binary(sess, op); - break; - case GGML_OP_ADD_ID: supp = ggml_hexagon_supported_add_id(sess, op); break; @@ -3063,21 +3272,34 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons break; case GGML_OP_UNARY: - { - const auto unary_op = ggml_get_unary_op(op); - if (unary_op == GGML_UNARY_OP_SILU || unary_op == GGML_UNARY_OP_GELU) { + switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_SOFTPLUS: + supp = ggml_hexagon_supported_unary(sess, op); + break; + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_GELU: supp = ggml_hexagon_supported_activations(sess, op); - } - break; + break; + default: + break; } + break; + case GGML_OP_GLU: - { - const auto glu_op = ggml_get_glu_op(op); - if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI) || (glu_op == GGML_GLU_OP_GEGLU)) { + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: + case GGML_GLU_OP_GEGLU: supp = ggml_hexagon_supported_activations(sess, op); - } - break; + break; + default: + break; } + break; + case GGML_OP_ROPE: supp = ggml_hexagon_supported_rope(sess, op); break; @@ -3098,6 +3320,14 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_cpy(sess, op); break; + case GGML_OP_CONT: + supp = ggml_hexagon_supported_cont(sess, op); + break; + + case GGML_OP_REPEAT: + supp = ggml_hexagon_supported_repeat(sess, op); + break; + case GGML_OP_ARGSORT: supp = ggml_hexagon_supported_argsort(sess, op); break; @@ -3106,6 +3336,22 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_ssm_conv(sess, op); break; + case GGML_OP_CUMSUM: + supp = ggml_hexagon_supported_cumsum(sess, op); + break; + + case GGML_OP_FILL: + supp = ggml_hexagon_supported_fill(sess, op); + break; + + case GGML_OP_DIAG: + supp = ggml_hexagon_supported_diag(sess, op); + break; + + case GGML_OP_SOLVE_TRI: + supp = ggml_hexagon_supported_solve_tri(sess, op); + break; + default: break; } @@ -3172,21 +3418,6 @@ struct ggml_hexagon_registry { ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) { GGML_LOG_INFO("ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev %zu\n", opt_ndev); - if (!opt_arch) { - int err = get_hex_arch_ver(CDSP_DOMAIN_ID, &opt_arch); - if (err != 0) { - GGML_LOG_ERROR("ggml-hex: failed to query HTP version (err %d) defaulting to v73\n", err); - opt_arch = 73; - } - } - -#if defined(__ANDROID__) - if (opt_arch < 75) { - opt_ndev = 1; - GGML_LOG_WARN("ggml-hex: forcing ndev to 1 for SoCs archs lower than v75.\n"); - } -#endif - GGML_LOG_INFO("ggml-hex: Hexagon Arch version v%d\n", opt_arch); // Create devices / sessions @@ -3241,6 +3472,26 @@ static void * ggml_backend_hexagon_get_proc_address(ggml_backend_reg_t reg, cons return NULL; } +template std::vector str_to_vec(const char* str) { + std::stringstream ss(str); + std::vector v; + std::string t; + + while (std::getline(ss, t, ',')) { + v.push_back(std::stoul(t, nullptr, 0)); + } + + return v; +} + +template std::string vec_to_str(std::vector v) { + std::stringstream ss; + ss << std::setbase(BASE) << std::showbase; + for (auto i : v) { ss << i << ','; } + auto str = ss.str(); str.pop_back(); // drop last comma + return str; +} + static void ggml_hexagon_init(ggml_backend_reg * reg) { // Basic sanity checks to make sure definitions match static_assert((unsigned int) HTP_TYPE_Q4_0 == (unsigned int) GGML_TYPE_Q4_0, @@ -3249,45 +3500,85 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { "please update hexagon_type to match ggml_type"); static_assert((unsigned int) HTP_TYPE_MXFP4 == (unsigned int) GGML_TYPE_MXFP4, "please update hexagon_type to match ggml_type"); + static_assert((unsigned int) HTP_TYPE_IQ4_NL == (unsigned int) GGML_TYPE_IQ4_NL, + "please update hexagon_type to match ggml_type"); + + const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE"); + const char * str_hostbuf = getenv("GGML_HEXAGON_HOSTBUF"); + const char * str_opstage = getenv("GGML_HEXAGON_OPSTAGE"); + const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH"); + const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE"); + const char * str_opfilter = getenv("GGML_HEXAGON_OPFILTER"); + const char * str_profile = getenv("GGML_HEXAGON_PROFILE"); + const char * str_etm = getenv("GGML_HEXAGON_ETM"); + const char * str_nhvx = getenv("GGML_HEXAGON_NHVX"); + const char * str_use_hmx = getenv("GGML_HEXAGON_USE_HMX"); + const char * str_ndev = getenv("GGML_HEXAGON_NDEV"); + const char * str_arch = getenv("GGML_HEXAGON_ARCH"); + const char * str_vmem = getenv("GGML_HEXAGON_VMEM"); + const char * str_mbuf = getenv("GGML_HEXAGON_MBUF"); + + // Init Arch first since it affects other defaults + if (!str_arch) { + int err = get_hex_arch_ver(CDSP_DOMAIN_ID, &opt_arch); + if (err != 0) { + GGML_LOG_ERROR("ggml-hex: failed to query HTP version (err %d) defaulting to v73\n", err); + opt_arch = 73; + } + } else { + if (str_arch[0] == 'v' || str_arch[0] == 'V') { + str_arch++; + } + opt_arch = strtoul(str_arch, NULL, 0); + } + + size_t MiB = 1024 * 1024; - const char * str_experimental = getenv("GGML_HEXAGON_EXPERIMENTAL"); - const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE"); - const char * str_hostbuf = getenv("GGML_HEXAGON_HOSTBUF"); - const char * str_opmask = getenv("GGML_HEXAGON_OPMASK"); - const char * str_opsync = getenv("GGML_HEXAGON_OPSYNC"); - const char * str_profile = getenv("GGML_HEXAGON_PROFILE"); - const char * str_etm = getenv("GGML_HEXAGON_ETM"); - const char * str_nhvx = getenv("GGML_HEXAGON_NHVX"); - const char * str_ndev = getenv("GGML_HEXAGON_NDEV"); - const char * str_arch = getenv("GGML_HEXAGON_ARCH"); - - opt_experimental = str_experimental ? atoi(str_experimental) : 0; - opt_verbose = str_verbose ? atoi(str_verbose) : 0; - opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; - opt_opmask = str_opmask ? strtoul(str_opmask, NULL, 0) : opt_opmask; - opt_opsync = str_opsync ? atoi(str_opsync) : 0; - opt_profile = str_profile ? atoi(str_profile) : 0; - opt_etm = str_etm ? atoi(str_etm) : 0; - opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx; - opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev; + // Update vmem default + opt_vmem = opt_arch >= 75 ? HTP_OP_MAX_VMEM_DEFAULT : 3000 * MiB; + + auto RE_ICASE = std::regex_constants::icase; + + opt_opfilter = str_opfilter ? new std::regex(str_opfilter, RE_ICASE) : NULL; + opt_verbose = str_verbose ? atoi(str_verbose) : 0; + opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; + opt_opstage = str_opstage ? strtoul(str_opstage, NULL, 0) : opt_opstage; + opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch; + opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue; + opt_profile = str_profile ? atoi(str_profile) : 0; + opt_etm = str_etm ? atoi(str_etm) : 0; + opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx; + opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx; + opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev; + opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; + opt_mbuf = str_mbuf ? strtoul(str_mbuf, NULL, 0) * MiB : opt_mbuf; + opt_vmem = str_vmem ? strtoul(str_vmem, NULL, 0) * MiB : opt_vmem; if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) { opt_ndev = GGML_HEXAGON_MAX_SESSIONS; } - if (str_arch) { - if (str_arch[0] == 'v') { - str_arch++; - } - opt_arch = strtoul(str_arch, NULL, 0); +#if defined(__ANDROID__) + if (opt_arch < 75) { + opt_ndev = 1; + GGML_LOG_WARN("ggml-hex: forcing ndev to 1 for SoCs archs lower than v75.\n"); } +#endif - opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : 1; + if (str_profile) { + opt_pmu_evt = [&]() -> std::vector { + auto v = str_to_vec(str_profile); + switch (v.size()) { + case 1: opt_profile = v[0]; return opt_pmu_evt; // mode with default pmu events + case 8: opt_profile = 2; return v; // mode with custom pmu events + default: opt_profile = 0; return {}; // garbage input + }}(); + if (opt_profile == 1) opt_pmu_evt = {}; + GGML_LOG_INFO("ggml-hex: Profiling mode %u : pmu-evt [ %s ]\n", opt_profile, + vec_to_str(opt_pmu_evt).c_str()); + } reg->context = new ggml_hexagon_registry(reg); - - HEX_VERBOSE("ggml-hex: size-of-general-req %zu size-of-general-rsp %zu\n", sizeof(struct htp_general_req), - sizeof(struct htp_general_rsp)); } static const struct ggml_backend_reg_i ggml_backend_hexagon_reg_i = { diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 02d07a503d5..7c9e4cda5f1 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -30,8 +30,13 @@ add_library(${HTP_LIB} SHARED set-rows-ops.c get-rows-ops.c cpy-ops.c + repeat-ops.c argsort-ops.c ssm-conv.c + cumsum-ops.c + fill-ops.c + diag-ops.c + solve-tri-ops.c ) target_compile_definitions(${HTP_LIB} PRIVATE @@ -39,6 +44,32 @@ target_compile_definitions(${HTP_LIB} PRIVATE $,FARF_HIGH=1,> FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}) +if (GGML_HEXAGON_FA_EXP2_HF) + message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)") + target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1) +endif() + +# HMX acceleration: available on v73+ architectures +set(HTP_HMX_VERSIONS v73 v75 v79 v81) +list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx) + +if (_hmx_idx GREATER_EQUAL 0) + target_sources(${HTP_LIB} PRIVATE + hmx-queue.c + hmx-matmul-ops.c + hmx-flash-attn-ops.c + ) + + # -mhmx enables HMX instruction set (needed by files that include hmx-utils.h) + set_source_files_properties( + hmx-matmul-ops.c + hmx-flash-attn-ops.c + PROPERTIES COMPILE_OPTIONS "-mhmx" + ) + + target_compile_definitions(${HTP_LIB} PRIVATE HTP_HAS_HMX=1) +endif() + build_idl(htp_iface.idl ${HTP_LIB}) set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON) diff --git a/ggml/src/ggml-hexagon/htp/act-ops.c b/ggml/src/ggml-hexagon/htp/act-ops.c index d8b924981e0..6416d2dfbc3 100644 --- a/ggml/src/ggml-hexagon/htp/act-ops.c +++ b/ggml/src/ggml-hexagon/htp/act-ops.c @@ -14,59 +14,42 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" -#define htp_act_preamble3 \ - const uint32_t ne00 = src0->ne[0]; \ - const uint32_t ne01 = src0->ne[1]; \ - const uint32_t ne02 = src0->ne[2]; \ - const uint32_t ne03 = src0->ne[3]; \ - \ - const uint32_t ne10 = src1->ne[0]; \ - const uint32_t ne11 = src1->ne[1]; \ - const uint32_t ne12 = src1->ne[2]; \ - const uint32_t ne13 = src1->ne[3]; \ - \ - const uint32_t ne0 = dst->ne[0]; \ - const uint32_t ne1 = dst->ne[1]; \ - const uint32_t ne2 = dst->ne[2]; \ - const uint32_t ne3 = dst->ne[3]; \ - \ - const uint32_t nb00 = src0->nb[0]; \ - const uint32_t nb01 = src0->nb[1]; \ - const uint32_t nb02 = src0->nb[2]; \ - const uint32_t nb03 = src0->nb[3]; \ - \ - const uint32_t nb10 = src1->nb[0]; \ - const uint32_t nb11 = src1->nb[1]; \ - const uint32_t nb12 = src1->nb[2]; \ - const uint32_t nb13 = src1->nb[3]; \ - \ - const uint32_t nb0 = dst->nb[0]; \ - const uint32_t nb1 = dst->nb[1]; \ - const uint32_t nb2 = dst->nb[2]; \ - const uint32_t nb3 = dst->nb[3]; - -#define htp_act_preamble2 \ - const uint32_t ne00 = src0->ne[0]; \ - const uint32_t ne01 = src0->ne[1]; \ - const uint32_t ne02 = src0->ne[2]; \ - const uint32_t ne03 = src0->ne[3]; \ - \ - const uint32_t ne0 = dst->ne[0]; \ - const uint32_t ne1 = dst->ne[1]; \ - const uint32_t ne2 = dst->ne[2]; \ - const uint32_t ne3 = dst->ne[3]; \ - \ - const uint32_t nb00 = src0->nb[0]; \ - const uint32_t nb01 = src0->nb[1]; \ - const uint32_t nb02 = src0->nb[2]; \ - const uint32_t nb03 = src0->nb[3]; \ - \ - const uint32_t nb0 = dst->nb[0]; \ - const uint32_t nb1 = dst->nb[1]; \ - const uint32_t nb2 = dst->nb[2]; \ +#define htp_act_preamble \ + const struct htp_tensor * src0 = actx->octx->src[0]; \ + const struct htp_tensor * src1 = actx->octx->src[1]; \ + const struct htp_tensor * dst = actx->octx->dst; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t ne10 = src1 ? src1->ne[0] : 0; \ + const uint32_t ne11 = src1 ? src1->ne[1] : 0; \ + const uint32_t ne12 = src1 ? src1->ne[2] : 0; \ + const uint32_t ne13 = src1 ? src1->ne[3] : 0; \ + \ + const uint32_t nb10 = src1 ? src1->nb[0] : 0; \ + const uint32_t nb11 = src1 ? src1->nb[1] : 0; \ + const uint32_t nb12 = src1 ? src1->nb[2] : 0; \ + const uint32_t nb13 = src1 ? src1->nb[3] : 0; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; struct htp_act_context { @@ -97,10 +80,7 @@ struct htp_act_context { static void glu_swiglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { struct htp_act_context * actx = (struct htp_act_context *) data; - const struct htp_tensor * src0 = &actx->octx->src0; - const struct htp_tensor * src1 = &actx->octx->src1; - const struct htp_tensor * dst = &actx->octx->dst; - htp_act_preamble3; + htp_act_preamble; size_t src0_row_size = actx->src0_row_size; size_t src1_row_size = actx->src1_row_size; @@ -207,10 +187,7 @@ static void glu_swiglu_f32_per_thread(unsigned int nth, unsigned int ith, void * static void glu_swiglu_oai_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { struct htp_act_context * actx = (struct htp_act_context *) data; - const struct htp_tensor * src0 = &actx->octx->src0; - const struct htp_tensor * src1 = &actx->octx->src1; - const struct htp_tensor * dst = &actx->octx->dst; - htp_act_preamble3; + htp_act_preamble; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); @@ -332,9 +309,7 @@ static void glu_swiglu_oai_f32_per_thread(unsigned int nth, unsigned int ith, vo static void unary_gelu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { struct htp_act_context * actx = (struct htp_act_context *) data; - const struct htp_tensor * src0 = &actx->octx->src0; - const struct htp_tensor * dst = &actx->octx->dst; - htp_act_preamble2; + htp_act_preamble; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); @@ -433,9 +408,7 @@ static void unary_gelu_f32_per_thread(unsigned int nth, unsigned int ith, void * static void unary_silu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { struct htp_act_context * actx = (struct htp_act_context *) data; - const struct htp_tensor * src0 = &actx->octx->src0; - const struct htp_tensor * dst = &actx->octx->dst; - htp_act_preamble2; + htp_act_preamble; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); @@ -533,10 +506,7 @@ static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; static void glu_geglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { struct htp_act_context * actx = (struct htp_act_context *) data; - const struct htp_tensor * src0 = &actx->octx->src0; - const struct htp_tensor * src1 = &actx->octx->src1; - const struct htp_tensor * dst = &actx->octx->dst; - htp_act_preamble3; + htp_act_preamble; size_t src0_row_size = actx->src0_row_size; size_t src1_row_size = actx->src1_row_size; @@ -652,9 +622,9 @@ static void glu_geglu_f32_per_thread(unsigned int nth, unsigned int ith, void * } static int execute_op_activations_f32(struct htp_ops_context * octx) { - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; if (((src0->ne[0] * SIZEOF_FP32) != src0->nb[1]) || ((dst->ne[0] * SIZEOF_FP32) != dst->nb[1])) { FARF(ERROR, "Non-contiguous tensors are not supported at this time \n"); @@ -697,25 +667,20 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) { const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); size_t src0_row_size = src0->nb[1]; - size_t src1_row_size = src1->nb[1]; // zero bytes if src1 is not used + size_t src1_row_size = src1 ? src1->nb[1] : src0->nb[1]; size_t dst_row_size = dst->nb[1]; - const bool src1_valid = src1->ne[0]; - if (!src1_valid) { - src1_row_size = src0_row_size; - } - const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + // VTCM scratchpads for all tensors // N rows per thread, padded to HVX vector size - size_t spad_size_per_row = (src0_row_size_aligned + src1_row_size_aligned) + dst_row_size_aligned; size_t vtcm_row_per_thread = (octx->ctx->vtcm_size)/ (n_threads* spad_size_per_row); // Make sure the reserved vtcm size is sufficient - if(vtcm_row_per_thread ==0){ + if (vtcm_row_per_thread == 0) { FARF(ERROR, "act-%s : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", op_type, octx->ctx->vtcm_size, spad_size_per_row * n_threads); return HTP_STATUS_VTCM_TOO_SMALL; @@ -733,7 +698,11 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) { octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; - if (src1->ne[0]) { + octx->src0_spad.src = NULL; + octx->src1_spad.src = NULL; + octx->dst_spad.src = NULL; + + if (src1) { FARF(HIGH, "%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, @@ -773,9 +742,9 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) { // Pointers and GLU logic const uint8_t * data_src0 = (const uint8_t *) src0->data; - const uint8_t * data_src1 = (const uint8_t *) src1->data; + const uint8_t * data_src1 = src1 ? (const uint8_t *) src1->data : NULL; - if (!src1_valid && (octx->op == HTP_OP_GLU_SWIGLU || octx->op == HTP_OP_GLU_SWIGLU_OAI || octx->op == HTP_OP_GLU_GEGLU)) { + if (!src1 && (octx->op == HTP_OP_GLU_SWIGLU || octx->op == HTP_OP_GLU_SWIGLU_OAI || octx->op == HTP_OP_GLU_GEGLU)) { const int32_t swapped = octx->op_params[1]; data_src1 = data_src0; actx.src1_row_size = actx.src0_row_size; @@ -799,7 +768,7 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) { int op_activations(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; - switch (octx->src0.type) { + switch (octx->src[0]->type) { case HTP_TYPE_F32: err = execute_op_activations_f32(octx); break; diff --git a/ggml/src/ggml-hexagon/htp/argsort-ops.c b/ggml/src/ggml-hexagon/htp/argsort-ops.c index 170220e8f80..bdd0623615d 100644 --- a/ggml/src/ggml-hexagon/htp/argsort-ops.c +++ b/ggml/src/ggml-hexagon/htp/argsort-ops.c @@ -12,7 +12,7 @@ #include "hex-dma.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" #ifndef MIN @@ -164,13 +164,19 @@ static void quicksort_values_indices_desc(float * values, int32_t * indices, int if (i < right) quicksort_values_indices_desc(values, indices, i, right); } +// LUT for ramp initialization of argsort output (first 32 members) +int32_t argosrt_ramp_lut[32] __attribute__((aligned(VLEN))) = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31 +}; + static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) { struct htp_argsort_context * actx = (struct htp_argsort_context *)data; struct htp_ops_context * octx = actx->octx; // Unpack context - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * dst = octx->dst; // Scratchpad memory uint8_t * spad = octx->src0_spad.data + octx->src0_spad.size_per_thread * i; @@ -205,8 +211,12 @@ static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) { // Padded to 128 bytes. size_t values_size = hex_round_up(ne00 * sizeof(float), 128); + size_t num_vec_ind_values = hmx_ceil_div(ne00, VLEN/(sizeof(int32_t))); float * values_buf = (float *) spad; int32_t * indices_buf = (int32_t *) (spad + values_size); + HVX_Vector * indices_buf_vec = (HVX_Vector *) (spad + values_size); + const HVX_Vector ind_init_vec = *(HVX_Vector *)argosrt_ramp_lut; + const HVX_Vector ind_diff_vec = Q6_V_vsplat_R(32); for (uint32_t r = start_row; r < end_row; r++) { uint32_t src_offset = r * nb01; @@ -218,9 +228,11 @@ static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) { hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1); hvx_copy_f32_au((uint8_t*)values_buf, src_ptr, ne00); - // Initialize indices - for (uint32_t j = 0; j < ne00; j++) { - indices_buf[j] = j; + // Initialize indices - Start with values 0..31, add 32 for additional vec iterations + HVX_Vector curr_ind_vec = ind_init_vec; + for (uint32_t j_vec = 0; j_vec < num_vec_ind_values; j_vec++) { + indices_buf_vec[j_vec] = curr_ind_vec; + curr_ind_vec = Q6_Vw_vadd_VwVw(curr_ind_vec, ind_diff_vec); } // Sort values and mirror swaps to indices @@ -237,16 +249,16 @@ static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) { int op_argsort(struct htp_ops_context * octx) { // Check supported types - if (octx->src0.type != HTP_TYPE_F32) { + if (octx->src[0]->type != HTP_TYPE_F32) { return HTP_STATUS_NO_SUPPORT; } - const uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3]; + const uint32_t total_rows = octx->src[0]->ne[1] * octx->src[0]->ne[2] * octx->src[0]->ne[3]; const uint32_t n_threads = MIN(total_rows, octx->n_threads); // Allocate scratchpad // We need 1 row of float + 1 row of int32 per thread. - uint32_t ne00 = octx->src0.ne[0]; + uint32_t ne00 = octx->src[0]->ne[0]; size_t values_size = hex_round_up(ne00 * sizeof(float), 128); size_t indices_size = hex_round_up(ne00 * sizeof(int32_t), 128); size_t spad_per_thread = values_size + indices_size; @@ -266,9 +278,9 @@ int op_argsort(struct htp_ops_context * octx) { octx->src0_spad.size_per_thread = spad_per_thread; FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)", - octx->src0.ne[0], octx->src0.ne[1], octx->src0.ne[2], octx->src0.ne[3], - octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3], - octx->src0.data, octx->dst.data); + octx->src[0]->ne[0], octx->src[0]->ne[1], octx->src[0]->ne[2], octx->src[0]->ne[3], + octx->dst->ne[0], octx->dst->ne[1], octx->dst->ne[2], octx->dst->ne[3], + octx->src[0]->data, octx->dst->data); struct htp_argsort_context actx; actx.octx = octx; diff --git a/ggml/src/ggml-hexagon/htp/binary-ops.c b/ggml/src/ggml-hexagon/htp/binary-ops.c index ec90f22de52..52013ad0fec 100644 --- a/ggml/src/ggml-hexagon/htp/binary-ops.c +++ b/ggml/src/ggml-hexagon/htp/binary-ops.c @@ -14,7 +14,7 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" #ifndef MIN @@ -24,31 +24,29 @@ // Context for binary operations struct htp_binary_context { struct htp_ops_context * octx; - struct fastdiv_values dim1_div; - struct fastdiv_values dim2_div; - struct fastdiv_values dim12_div; + + struct fastdiv_values src0_dim1_div; // ne01 + struct fastdiv_values src0_dim2_div; // ne02 + struct fastdiv_values src0_dim12_div;// ne03 struct fastdiv_values src1_dim1_div; // ne11 struct fastdiv_values src1_dim2_div; // ne12 struct fastdiv_values src1_dim3_div; // ne13 - uint32_t nrows_per_thread; - bool split_at_ne01; - bool split_at_ne02; - - // Precomputed values uint32_t block_max; + uint32_t nrows_per_thread; size_t src0_row_size_aligned; size_t src1_row_size_aligned; size_t dst_row_size_aligned; - uint32_t src1_fetch_rows; // 1 or block_max - uint32_t src1_dma_stride; // 0 or stride + + bool split_at_ne01; + bool split_at_ne02; }; -#define htp_binary_preamble \ - const struct htp_tensor * src0 = &octx->src0; \ - const struct htp_tensor * src1 = &octx->src1; \ - struct htp_tensor * dst = &octx->dst; \ +#define htp_binary_preamble \ + const struct htp_tensor * src0 = octx->src[0]; \ + const struct htp_tensor * src1 = octx->src[1]; \ + const struct htp_tensor * dst = octx->dst; \ \ const uint32_t ne00 = src0->ne[0]; \ const uint32_t ne01 = src0->ne[1]; \ @@ -72,12 +70,11 @@ struct htp_binary_context { const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; -static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_t ir, uint32_t end_row, - uint32_t ne01, uint32_t ne02) { +static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_t ir, uint32_t end_row, uint32_t ne01, uint32_t ne02) { uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir, &bctx->dim12_div); + i03 = fastdiv(ir, &bctx->src0_dim12_div); rem = ir - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); + i02 = fastdiv(rem, &bctx->src0_dim1_div); i01 = rem - i02 * ne01; uint32_t rows_left = end_row - ir; @@ -184,13 +181,15 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; - const uint32_t src0_type = octx->src0.type; + const uint32_t src0_type = octx->src[0]->type; const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t total_rows = ne01 * ne02 * ne03; const uint32_t start_row = bctx->nrows_per_thread * ith; const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); if (start_row >= end_row) return; + FARF(HIGH, "binary-scalar: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned); + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; @@ -204,9 +203,9 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); rem = ir_prefetch - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); + i02 = fastdiv(rem, &bctx->src0_dim1_div); i01 = rem - i02 * ne01; uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; @@ -215,7 +214,7 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; - dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0); dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; @@ -229,9 +228,9 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir, &bctx->dim12_div); + i03 = fastdiv(ir, &bctx->src0_dim12_div); rem = ir - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); + i02 = fastdiv(rem, &bctx->src0_dim1_div); i01 = rem - i02 * ne01; // src1 indices (broadcast/repeat) @@ -255,9 +254,9 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); uint32_t p03, p02, p01, prem; - p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); prem = ir_prefetch - p03 * (ne02 * ne01); - p02 = fastdiv(prem, &bctx->dim1_div); + p02 = fastdiv(prem, &bctx->src0_dim1_div); p01 = prem - p02 * ne01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; @@ -275,13 +274,15 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; - const uint32_t src0_type = octx->src0.type; + const uint32_t src0_type = octx->src[0]->type; const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t total_rows = ne01 * ne02 * ne03; const uint32_t start_row = bctx->nrows_per_thread * ith; const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); if (start_row >= end_row) return; + FARF(HIGH, "binary-same-shape: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned); + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); uint8_t * src1_spad_base = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); @@ -297,9 +298,9 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); rem = ir_prefetch - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); + i02 = fastdiv(rem, &bctx->src0_dim1_div); i01 = rem - i02 * ne01; uint32_t i13 = (ne13 == 1) ? 0 : i03; @@ -307,23 +308,23 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi uint32_t i11 = (ne11 == 1) ? 0 : i01; uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; - uint8_t * src1_base = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11; + uint8_t * src1_curr = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; uint8_t * s1_spad = src1_spad_base + spad_idx * src1_spad_half; uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; - dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0); dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); - dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, current_block_size); + dma_queue_push(q, dma_make_ptr(s1_spad, src1_curr), bctx->src1_row_size_aligned, nb11, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; } for (uint32_t ir = start_row; ir < end_row; ) { uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); - uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; uint8_t * s1_spad = (uint8_t *) dma_queue_pop(q).dst; @@ -335,9 +336,9 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi } uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir, &bctx->dim12_div); + i03 = fastdiv(ir, &bctx->src0_dim12_div); rem = ir - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); + i02 = fastdiv(rem, &bctx->src0_dim1_div); i01 = rem - i02 * ne01; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size); @@ -345,9 +346,9 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); uint32_t p03, p02, p01, prem; - p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); prem = ir_prefetch - p03 * (ne02 * ne01); - p02 = fastdiv(prem, &bctx->dim1_div); + p02 = fastdiv(prem, &bctx->src0_dim1_div); p01 = prem - p02 * ne01; uint32_t p13 = (ne13 == 1) ? 0 : p03; @@ -358,7 +359,7 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi uint8_t * s1_next = (uint8_t *)src1->data + p13 * nb13 + p12 * nb12 + p11 * nb11; dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); - dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, next_block_size); + dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, nb11, row_size_bytes, next_block_size); ir_prefetch += next_block_size; } @@ -373,15 +374,17 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; - const uint32_t src0_type = octx->src0.type; + const uint32_t src0_type = octx->src[0]->type; const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t total_rows = ne01 * ne02 * ne03; - const uint32_t start_row = bctx->nrows_per_thread * ith; - const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); if (start_row >= end_row) return; + FARF(HIGH, "binary-row-bcast: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned); + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); - uint8_t * src1_spad = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); + uint8_t * src1_spad_base = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; @@ -391,15 +394,14 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, uint32_t ir_prefetch = start_row; int spad_idx = 0; - void * s1_ptr = (void *) src1_spad; + void * s1_ptr = (void *) src1_spad_base; for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); - uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir_prefetch, &bctx->dim12_div); - rem = ir_prefetch - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); - i01 = rem - i02 * ne01; + uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t rem = ir_prefetch - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; @@ -407,7 +409,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; - dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0); dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; @@ -415,7 +417,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, for (uint32_t ir = start_row; ir < end_row; ) { uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); - uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; for (uint32_t r = 0; r < current_block_size; r++) { @@ -425,21 +427,19 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00); } - uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir, &bctx->dim12_div); - rem = ir - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); - i01 = rem - i02 * ne01; + uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div); + uint32_t rem = ir - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size); if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); - uint32_t p03, p02, p01, prem; - p03 = fastdiv(ir_prefetch, &bctx->dim12_div); - prem = ir_prefetch - p03 * (ne02 * ne01); - p02 = fastdiv(prem, &bctx->dim1_div); - p01 = prem - p02 * ne01; + uint32_t p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t prem = ir_prefetch - p03 * (ne02 * ne01); + uint32_t p02 = fastdiv(prem, &bctx->src0_dim1_div); + uint32_t p01 = prem - p02 * ne01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); ir_prefetch += next_block_size; @@ -455,17 +455,19 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; - const uint32_t src0_type = octx->src0.type; + const uint32_t src0_type = octx->src[0]->type; const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t total_rows = ne01 * ne02 * ne03; - const uint32_t start_row = bctx->nrows_per_thread * ith; - const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); if (start_row >= end_row) return; + FARF(HIGH, "binary-complex: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned); + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); - size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; - size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; dma_queue * q = octx->ctx->dma[ith]; uint32_t ir_prefetch = start_row; @@ -473,11 +475,10 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); - uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir_prefetch, &bctx->dim12_div); - rem = ir_prefetch - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); - i01 = rem - i02 * ne01; + uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t rem = ir_prefetch - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; @@ -485,7 +486,7 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; - dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0); dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; @@ -496,11 +497,10 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; - uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir, &bctx->dim12_div); - rem = ir - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); - i01 = rem - i02 * ne01; + uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div); + uint32_t rem = ir - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; for (uint32_t r = 0; r < current_block_size; r++) { uint32_t r_i01 = i01 + r; @@ -521,11 +521,10 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); - uint32_t p03, p02, p01, prem; - p03 = fastdiv(ir_prefetch, &bctx->dim12_div); - prem = ir_prefetch - p03 * (ne02 * ne01); - p02 = fastdiv(prem, &bctx->dim1_div); - p01 = prem - p02 * ne01; + uint32_t p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t prem = ir_prefetch - p03 * (ne02 * ne01); + uint32_t p02 = fastdiv(prem, &bctx->src0_dim1_div); + uint32_t p01 = prem - p02 * ne01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); ir_prefetch += next_block_size; @@ -541,18 +540,20 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; - const uint32_t src0_type = octx->src0.type; + const uint32_t src0_type = octx->src[0]->type; const uint32_t elem_size_bytes = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16); const uint32_t row_size_bytes = ne00 * elem_size_bytes;; const uint32_t total_rows = ne01 * ne02 * ne03; - const uint32_t start_row = bctx->nrows_per_thread * ith; - const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); if (start_row >= end_row) return; uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); - size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; - size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + FARF(HIGH, "binary-repeat: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned); dma_queue * q = octx->ctx->dma[ith]; uint32_t ir_prefetch = start_row; @@ -560,11 +561,10 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); - uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir_prefetch, &bctx->dim12_div); - rem = ir_prefetch - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); - i01 = rem - i02 * ne01; + uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t rem = ir_prefetch - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; @@ -572,7 +572,7 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; - dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0); dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; @@ -583,11 +583,10 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; - uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir, &bctx->dim12_div); - rem = ir - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); - i01 = rem - i02 * ne01; + uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div); + uint32_t rem = ir - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; for (uint32_t r = 0; r < current_block_size; r++) { uint32_t r_i01 = i01 + r; @@ -612,11 +611,10 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); - uint32_t p03, p02, p01, prem; - p03 = fastdiv(ir_prefetch, &bctx->dim12_div); - prem = ir_prefetch - p03 * (ne02 * ne01); - p02 = fastdiv(prem, &bctx->dim1_div); - p01 = prem - p02 * ne01; + uint32_t p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t prem = ir_prefetch - p03 * (ne02 * ne01); + uint32_t p02 = fastdiv(prem, &bctx->src0_dim1_div); + uint32_t p01 = prem - p02 * ne01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); ir_prefetch += next_block_size; @@ -631,10 +629,10 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { struct htp_binary_context * bctx = (struct htp_binary_context *) data; struct htp_ops_context * octx = bctx->octx; - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - const struct htp_tensor * src2 = &octx->src2; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * src2 = octx->src[2]; + const struct htp_tensor * dst = octx->dst; const uint32_t ne00 = src0->ne[0]; const uint32_t ne01 = src0->ne[1]; @@ -646,6 +644,7 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { const uint32_t nb02 = src0->nb[2]; const uint32_t nb03 = src0->nb[3]; const uint32_t nb11 = src1->nb[1]; // src1 row stride + const uint32_t nb1 = dst->nb[1]; const uint32_t nb2 = dst->nb[2]; const uint32_t nb3 = dst->nb[3]; @@ -657,8 +656,8 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); - size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; - size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; dma_queue * q = octx->ctx->dma[ith]; uint32_t ir_prefetch = start_row; @@ -666,11 +665,10 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); - uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir_prefetch, &bctx->dim12_div); - rem = ir_prefetch - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); - i01 = rem - i02 * ne01; + uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t rem = ir_prefetch - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; @@ -678,7 +676,7 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; - dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), 0); dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; @@ -689,11 +687,10 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; - uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir, &bctx->dim12_div); - rem = ir - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); - i01 = rem - i02 * ne01; + uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div); + uint32_t rem = ir - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; for (uint32_t r = 0; r < current_block_size; r++) { uint32_t r_i01 = i01 + r; // linear within block since we split at ne01 @@ -712,11 +709,10 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); - uint32_t p03, p02, p01, prem; - p03 = fastdiv(ir_prefetch, &bctx->dim12_div); - prem = ir_prefetch - p03 * (ne02 * ne01); - p02 = fastdiv(prem, &bctx->dim1_div); - p01 = prem - p02 * ne01; + uint32_t p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t prem = ir_prefetch - p03 * (ne02 * ne01); + uint32_t p02 = fastdiv(prem, &bctx->src0_dim1_div); + uint32_t p01 = prem - p02 * ne01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); ir_prefetch += next_block_size; @@ -727,52 +723,48 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { } static int execute_op_binary(struct htp_ops_context * octx) { - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); // Use packed row sizes for VTCM allocation - const uint32_t src0_type = octx->src0.type; + const uint32_t src0_type = octx->src[0]->type; const size_t elem_size = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16); const size_t src0_row_size = src0->ne[0] * elem_size; const size_t src1_row_size = src1->ne[0] * elem_size; - const size_t dst_row_size = dst->ne[0] * elem_size; + const size_t dst_row_size = dst->ne[0] * elem_size; - // Align to VLEN - const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); - const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); + size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); bool is_add_id = (octx->op == HTP_OP_ADD_ID); bool is_scalar = !is_add_id && (src1->ne[0] == 1); - // Determine which kernel we will use to alloc memory and dispatch - bool use_vector_same = !is_add_id && !is_scalar && ((src0->nb[1] % VLEN) == 0) && (src1->ne[0] == src0->ne[0]) && + bool is_transposed = (src0->nb[1] < src0_row_size || src1->nb[1] < src1_row_size || dst->nb[1] < dst_row_size); + + bool is_same_shape = !is_add_id && !is_scalar && !is_transposed && + (src1->ne[0] == src0->ne[0] && src0->ne[0] % VLEN == 0) && (src1->ne[1] == src0->ne[1] || src1->ne[1] == 1) && (src1->ne[2] == src0->ne[2] || src1->ne[2] == 1) && (src1->ne[3] == src0->ne[3] || src1->ne[3] == 1); - bool is_row_bcast = use_vector_same && (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1); - bool use_complex = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] == src0->ne[0]); - bool use_repeat = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] != src0->ne[0]); + bool is_row_bcast = is_same_shape && (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1); + bool is_complex = !is_add_id && !is_scalar && !is_same_shape && (src1->ne[0] == src0->ne[0]); + bool is_repeat = !is_add_id && !is_scalar && !is_same_shape && (src1->ne[0] != src0->ne[0]); size_t spad_row_total; - if (is_scalar) { - spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); - } else if (is_row_bcast) { - spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); - } else if (use_vector_same) { + if (is_same_shape) { spad_row_total = 2 * (src0_row_size_aligned + src1_row_size_aligned + dst_row_size_aligned); - } else if (is_add_id) { - spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); // src1 read directly } else { spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); } size_t rows_per_buffer = octx->ctx->vtcm_size / (n_threads * spad_row_total); + // Adjust for static src1 in row_bcast case if (is_row_bcast) { size_t needed_static = src1_row_size_aligned; @@ -782,36 +774,34 @@ static int execute_op_binary(struct htp_ops_context * octx) { } if (rows_per_buffer < 1) { - FARF(ERROR, "binary: VTCM too small\n"); - return HTP_STATUS_VTCM_TOO_SMALL; + FARF(ERROR, "binary: VTCM too small\n"); + return HTP_STATUS_VTCM_TOO_SMALL; } octx->src0_spad.size_per_thread = rows_per_buffer * 2 * src0_row_size_aligned; octx->dst_spad.size_per_thread = rows_per_buffer * 2 * dst_row_size_aligned; - if (is_scalar || use_complex || use_repeat || is_add_id) { - octx->src1_spad.size_per_thread = 0; - } else if (is_row_bcast) { + if (is_add_id || is_scalar || is_complex || is_repeat || is_row_bcast) { octx->src1_spad.size_per_thread = 0; } else { octx->src1_spad.size_per_thread = rows_per_buffer * 2 * src1_row_size_aligned; } + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; if (is_row_bcast) { octx->src1_spad.size = src1_row_size_aligned; } else { octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread; } - octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; if (octx->ctx->vtcm_size < (octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size)) { return HTP_STATUS_VTCM_TOO_SMALL; } - octx->src0_spad.data = octx->ctx->vtcm_base; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL; + octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->dst_spad.src = NULL; if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { return HTP_STATUS_OK; @@ -823,46 +813,37 @@ static int execute_op_binary(struct htp_ops_context * octx) { } struct htp_binary_context bctx; - bctx.octx = octx; - bctx.nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; - bctx.block_max = rows_per_buffer; + bctx.octx = octx; + bctx.nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; + bctx.block_max = rows_per_buffer; bctx.src0_row_size_aligned = src0_row_size_aligned; bctx.src1_row_size_aligned = src1_row_size_aligned; bctx.dst_row_size_aligned = dst_row_size_aligned; - bctx.dim1_div = init_fastdiv_values(src0->ne[1]); - bctx.dim2_div = init_fastdiv_values(src0->ne[2]); - bctx.dim12_div = init_fastdiv_values(src0->ne[1] * src0->ne[2]); + bctx.src0_dim1_div = init_fastdiv_values(src0->ne[1]); + bctx.src0_dim2_div = init_fastdiv_values(src0->ne[2]); + bctx.src0_dim12_div = init_fastdiv_values(src0->ne[1] * src0->ne[2]); - bctx.src1_dim1_div = init_fastdiv_values(src1->ne[1]); - bctx.src1_dim2_div = init_fastdiv_values(src1->ne[2]); - bctx.src1_dim3_div = init_fastdiv_values(src1->ne[3]); + bctx.src1_dim1_div = init_fastdiv_values(src1->ne[1]); + bctx.src1_dim2_div = init_fastdiv_values(src1->ne[2]); + bctx.src1_dim3_div = init_fastdiv_values(src1->ne[3]); bool src0_contig_dim1 = (src0->nb[2] == src0->ne[1] * src0->nb[1]); - bool dst_contig_dim1 = (dst->nb[2] == src0->ne[1] * dst->nb[1]); + bool dst_contig_dim1 = (dst->nb[2] == src0->ne[1] * dst->nb[1]); bool src0_contig_dim2 = (src0->nb[3] == src0->ne[2] * src0->nb[2]); - bool dst_contig_dim2 = (dst->nb[3] == src0->ne[2] * dst->nb[2]); - - bctx.split_at_ne01 = (src0->ne[2] > 1) && - ((src1->ne[1] > 1) || (src1->ne[2] > 1) || !src0_contig_dim1 || !dst_contig_dim1); + bool dst_contig_dim2 = (dst->nb[3] == src0->ne[2] * dst->nb[2]); - bctx.split_at_ne02 = (src0->ne[3] > 1) && - ((src1->ne[2] > 1) || (src1->ne[3] > 1) || !src0_contig_dim2 || !dst_contig_dim2); - - // Precompute specific kernel parameters - if (use_vector_same) { - bctx.src1_dma_stride = (src1->ne[1] == 1) ? 0 : src1->nb[1]; - bctx.src1_fetch_rows = (src1->ne[1] == 1) ? 1 : rows_per_buffer; - } + bctx.split_at_ne01 = (src0->ne[2] > 1) && ((src1->ne[1] > 1) || (src1->ne[2] > 1) || !src0_contig_dim1 || !dst_contig_dim1); + bctx.split_at_ne02 = (src0->ne[3] > 1) && ((src1->ne[2] > 1) || (src1->ne[3] > 1) || !src0_contig_dim2 || !dst_contig_dim2); worker_callback_t worker_func; - if (is_add_id) worker_func = binary_job_add_id; - else if (is_scalar) worker_func = binary_job_scalar; - else if (is_row_bcast) worker_func = binary_job_vector_row_broadcast; - else if (use_vector_same) worker_func = binary_job_vector_same_shape; - else if (use_complex) worker_func = binary_job_vector_complex; - else worker_func = binary_job_element_repeat; + if (is_add_id) worker_func = binary_job_add_id; + else if (is_scalar) worker_func = binary_job_scalar; + else if (is_row_bcast) worker_func = binary_job_vector_row_broadcast; + else if (is_same_shape) worker_func = binary_job_vector_same_shape; + else if (is_complex) worker_func = binary_job_vector_complex; + else worker_func = binary_job_element_repeat; if (is_row_bcast) { dma_queue_pop(q); @@ -876,12 +857,12 @@ static int execute_op_binary(struct htp_ops_context * octx) { int op_binary(struct htp_ops_context * octx) { // Does not support permutations of src1 - const struct htp_tensor * src1 = &octx->src1; + const struct htp_tensor * src1 = octx->src[1]; if (src1->nb[1] < src1->nb[0]) { return HTP_STATUS_NO_SUPPORT; } - const uint32_t src0_type = octx->src0.type; + const uint32_t src0_type = octx->src[0]->type; if ((src0_type == HTP_TYPE_F32) || (src0_type == HTP_TYPE_F16)) { return execute_op_binary(octx); } diff --git a/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake b/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake index 7fa236e328f..ed5c198468c 100644 --- a/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +++ b/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake @@ -138,15 +138,15 @@ set(CMAKE_SHARED_LIBRARY_SONAME_C_FLAG "-Wl,-soname,") set(CMAKE_SHARED_LIBRARY_SONAME_CXX_FLAG "-Wl,-soname,") #Compiler Options -set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}") +set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -flto -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}") set(CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g") -set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g") -set(CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS} -O3") +set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O2 -g") +set(CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS} -O2") set(CMAKE_C_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g") -set(CMAKE_C_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g") -set(CMAKE_C_FLAGS_RELEASE "${COMMON_FLAGS} -O3") +set(CMAKE_C_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O2 -g") +set(CMAKE_C_FLAGS_RELEASE "${COMMON_FLAGS} -O2") set(CMAKE_ASM_FLAGS_DEBUG "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_DEBUG}") set(CMAKE_ASM_FLAGS_RELEASE "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_RELEASE}") diff --git a/ggml/src/ggml-hexagon/htp/cpy-ops.c b/ggml/src/ggml-hexagon/htp/cpy-ops.c index a40d866b9c3..e5b9d350fd7 100644 --- a/ggml/src/ggml-hexagon/htp/cpy-ops.c +++ b/ggml/src/ggml-hexagon/htp/cpy-ops.c @@ -11,7 +11,7 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" #include "hvx-utils.h" @@ -32,10 +32,10 @@ struct htp_copy_context { void (*copy)(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith); }; -#define cpy_preamble \ - struct htp_tensor *src0 = &octx->src0; \ - struct htp_tensor *dst = &octx->dst; \ - \ +#define cpy_preamble \ + const struct htp_tensor *src0 = octx->src[0]; \ + const struct htp_tensor *dst = octx->dst; \ + \ const uint32_t ne00 = src0->ne[0]; \ const uint32_t ne01 = src0->ne[1]; \ const uint32_t ne02 = src0->ne[2]; \ diff --git a/ggml/src/ggml-hexagon/htp/cumsum-ops.c b/ggml/src/ggml-hexagon/htp/cumsum-ops.c new file mode 100644 index 00000000000..2ced1971236 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/cumsum-ops.c @@ -0,0 +1,270 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" +#include "hvx-types.h" +#include "hvx-utils.h" +#include "hex-dma.h" + +#define htp_cumsum_tensors_preamble \ + const struct htp_tensor * restrict src0 = octx->src[0]; \ + const struct htp_tensor * restrict dst = octx->dst; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; + +struct htp_cumsum_context { + struct htp_ops_context * octx; + size_t src_row_size; + size_t dst_row_size; + size_t src_row_size_aligned; + size_t dst_row_size_aligned; + uint32_t rows_per_thread; + uint32_t total_rows; +}; + +#define htp_cumsum_preamble \ + struct htp_cumsum_context * cctx = (struct htp_cumsum_context *) data; \ + struct htp_ops_context * octx = cctx->octx; \ + htp_cumsum_tensors_preamble; \ + dma_queue * dma_queue = octx->ctx->dma[ith]; + +// --------------------------------------------------------------------------- +// HVX prefix scan helpers +// --------------------------------------------------------------------------- + +#if __HVX_ARCH__ > 75 +static inline HVX_Vector hvx_cumsum_vadd(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_vadd_VsfVsf(a, b); +} +#else +static inline HVX_Vector hvx_cumsum_vadd(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)); +} +#endif // __HVX_ARCH__ > 75 + +static inline HVX_Vector hvx_prefix_scan_f32(HVX_Vector v, HVX_Vector carry_in) { + const HVX_Vector zero = Q6_V_vsplat_R(0); + + v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 4)); + v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 8)); + v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 16)); + v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 32)); + v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 64)); + v = hvx_cumsum_vadd(v, carry_in); + + return v; +} + +static inline HVX_Vector hvx_splat_last_f32(HVX_Vector v) { + return hvx_vec_repl4(Q6_V_vror_VR(v, 124)); +} + +static inline void hvx_cumsum_row_f32(const float * restrict src, float * restrict dst, uint32_t n) { + const uint32_t nvec = n / VLEN_FP32; + const uint32_t nloe = n % VLEN_FP32; + + HVX_Vector carry = Q6_V_vsplat_R(0); + + for (uint32_t i = 0; i < nvec; i++) { + HVX_Vector v = *((const HVX_UVector *) (src + i * VLEN_FP32)); + v = hvx_prefix_scan_f32(v, carry); + hvx_vec_store_u(dst + i * VLEN_FP32, VLEN, v); + carry = hvx_splat_last_f32(v); + } + + if (nloe) { + float acc = hvx_vec_get_f32(carry); + const float * src_tail = src + nvec * VLEN_FP32; + float * dst_tail = dst + nvec * VLEN_FP32; + for (uint32_t i = 0; i < nloe; i++) { + acc += src_tail[i]; + dst_tail[i] = acc; + } + } +} + +// --------------------------------------------------------------------------- +// Per thread worker: Double-buffered DMA +// --------------------------------------------------------------------------- + +static void cumsum_thread_f32_dma(unsigned int nth, unsigned int ith, void * data) { + htp_cumsum_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ir0 = cctx->rows_per_thread * ith; + const uint32_t ir1 = MIN(ir0 + cctx->rows_per_thread, cctx->total_rows); + + if (ir0 >= ir1) { + return; + } + + const size_t src_row_size = cctx->src_row_size; + const size_t dst_row_size = cctx->dst_row_size; + const size_t src_row_size_aligned = cctx->src_row_size_aligned; + const size_t dst_row_size_aligned = cctx->dst_row_size_aligned; + + const uint8_t * src_data = (const uint8_t *) src0->data; + uint8_t * dst_data = (uint8_t *) dst->data; + + uint8_t * src_spad = octx->src0_spad.data + (ith * src_row_size_aligned * 2); + uint8_t * dst_spad = octx->dst_spad.data + (ith * dst_row_size_aligned * 2); + + for (uint32_t ir = ir0, spad_idx = 0; ir < ir1 && spad_idx < 2; ir++, spad_idx++) { + // Dummy dst writeback to establish queue ordering + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(dst_data, dst_spad + (spad_idx * dst_row_size_aligned)), + dst_row_size, dst_row_size_aligned, 0); + + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src_spad + (spad_idx * src_row_size_aligned), + src_data + (ir * src_row_size)), + src_row_size_aligned, src_row_size, 1); + } + + for (uint32_t ir = ir0; ir < ir1; ir++) { + float * dst_spad_row = (float *) dma_queue_pop(dma_queue).src; + float * src_spad_row = (float *) dma_queue_pop(dma_queue).dst; + + hvx_cumsum_row_f32(src_spad_row, dst_spad_row, ne00); + + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(dst_data + (ir * dst_row_size), (uint8_t *) dst_spad_row), + dst_row_size, dst_row_size_aligned, 1); + + const uint32_t next_row = ir + 2; + if (next_row < ir1) { + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr((uint8_t *) src_spad_row, src_data + (next_row * src_row_size)), + src_row_size_aligned, src_row_size, 1); + } + } + + dma_queue_flush(dma_queue); + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "cumsum-f32-dma %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// --------------------------------------------------------------------------- +// Per thread worker: Direct HVX (no DMA) +// --------------------------------------------------------------------------- + +static void cumsum_thread_f32(unsigned int nth, unsigned int ith, void * data) { + htp_cumsum_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint8_t * src_data = (const uint8_t *) src0->data; + uint8_t * dst_data = (uint8_t *) dst->data; + + const uint32_t ir0 = cctx->rows_per_thread * ith; + const uint32_t ir1 = MIN(ir0 + cctx->rows_per_thread, cctx->total_rows); + + for (uint32_t ir = ir0; ir < ir1; ir++) { + const float * restrict src_row = (const float *) (src_data + ir * cctx->src_row_size); + float * restrict dst_row = (float *) (dst_data + ir * cctx->dst_row_size); + hvx_cumsum_row_f32(src_row, dst_row, ne00); + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "cumsum-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_cumsum_f32(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * dst = octx->dst; + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const uint32_t total_rows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, total_rows); + + const size_t src_row_size = src0->nb[1]; + const size_t dst_row_size = dst->nb[1]; + const size_t src_row_size_aligned = hex_round_up(src_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + + // 2 ping-pong buffers per thread for src and dst + const size_t spad_per_thread = 2 * (src_row_size_aligned + dst_row_size_aligned); + + octx->src0_spad.size_per_thread = src_row_size_aligned * 2; + octx->dst_spad.size_per_thread = dst_row_size_aligned * 2; + + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + + octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.src = NULL; + + struct htp_cumsum_context cctx = { + .octx = octx, + .src_row_size = src_row_size, + .dst_row_size = dst_row_size, + .src_row_size_aligned = src_row_size_aligned, + .dst_row_size_aligned = dst_row_size_aligned, + .rows_per_thread = (total_rows + n_threads - 1) / n_threads, + .total_rows = total_rows, + }; + + if (octx->ctx->vtcm_size < spad_per_thread * n_threads) { + worker_pool_run_func(octx->ctx->worker_pool, cumsum_thread_f32, &cctx, n_threads); + } else { + worker_pool_run_func(octx->ctx->worker_pool, cumsum_thread_f32_dma, &cctx, n_threads); + } + + return HTP_STATUS_OK; +} + +int op_cumsum(struct htp_ops_context * octx) { + const struct htp_tensor * dst = octx->dst; + + int err = HTP_STATUS_OK; + + switch (dst->type) { + case HTP_TYPE_F32: + err = op_cumsum_f32(octx); + break; + default: + err = HTP_STATUS_NO_SUPPORT; + break; + } + + return err; +} diff --git a/ggml/src/ggml-hexagon/htp/diag-ops.c b/ggml/src/ggml-hexagon/htp/diag-ops.c new file mode 100644 index 00000000000..9b3194d9084 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/diag-ops.c @@ -0,0 +1,216 @@ +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" +#include "hvx-types.h" +#include "hex-utils.h" +#include "hvx-copy.h" +#include "hex-dma.h" + +#define htp_diag_tensors_preamble \ + const struct htp_tensor * restrict src0 = octx->src[0]; \ + const struct htp_tensor * restrict dst = octx->dst; \ + \ + const uint32_t ne02 = src0->ne[2]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; + +struct htp_diag_context { + struct htp_ops_context * octx; + size_t src_batch_size; + size_t dst_row_size; + size_t src_batch_size_aligned; + size_t dst_row_size_aligned; + uint32_t batches_per_thread; + uint32_t total_batches; +}; + +#define htp_diag_preamble \ + struct htp_diag_context * dctx = (struct htp_diag_context *) data; \ + struct htp_ops_context * octx = dctx->octx; \ + htp_diag_tensors_preamble; + +static inline void hvx_diag_row_f32(const float * restrict src, float * restrict dst, + uint32_t row_idx, uint32_t n) { + hvx_splat_f32_a((uint8_t *) dst, 0.0f, n); + dst[row_idx] = src[row_idx]; +} + +// --------------------------------------------------------------------------- +// Per thread worker: DMA src fetch, compute in VTCM, DMA dst writeback +// --------------------------------------------------------------------------- + +static void diag_thread_f32_dma(unsigned int nth, unsigned int ith, void * data) { + htp_diag_preamble; + dma_queue * dma_queue = octx->ctx->dma[ith]; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ib0 = dctx->batches_per_thread * ith; + const uint32_t ib1 = MIN(ib0 + dctx->batches_per_thread, dctx->total_batches); + + if (ib0 >= ib1) { + return; + } + + const size_t src_batch_size = dctx->src_batch_size; + const size_t dst_row_size = dctx->dst_row_size; + const size_t src_batch_size_aligned = dctx->src_batch_size_aligned; + const size_t dst_row_size_aligned = dctx->dst_row_size_aligned; + + const uint8_t * src_data = (const uint8_t *) src0->data; + uint8_t * dst_data = (uint8_t *) dst->data; + + // 1 src buffer + 1 dst row buffer per thread in VTCM + uint8_t * src_spad = octx->src0_spad.data + (ith * src_batch_size_aligned); + uint8_t * dst_spad = octx->dst_spad.data + (ith * dst_row_size_aligned); + + for (uint32_t ib = ib0; ib < ib1; ib++) { + const uint32_t i3 = ib / ne02; + const uint32_t i2 = ib % ne02; + + const uint8_t * src_batch = src_data + i3 * nb03 + i2 * nb02; + + // Fetch source vector into VTCM + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src_spad, src_batch), + src_batch_size_aligned, src_batch_size, 1); + dma_queue_flush(dma_queue); + + const float * src_spad_f32 = (const float *) src_spad; + float * dst_spad_f32 = (float *) dst_spad; + + for (uint32_t i1 = 0; i1 < ne1; i1++) { + // Compute row in VTCM + hvx_diag_row_f32(src_spad_f32, dst_spad_f32, i1, ne0); + + // Write completed row back to DDR + uint8_t * dst_row = dst_data + i3 * nb3 + i2 * nb2 + i1 * nb1; + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(dst_row, dst_spad), + dst_row_size, dst_row_size_aligned, 1); + dma_queue_flush(dma_queue); + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "diag-f32-dma %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ib0, ib1, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// --------------------------------------------------------------------------- +// Per thread worker: Direct HVX (no DMA) +// --------------------------------------------------------------------------- + +static void diag_thread_f32(unsigned int nth, unsigned int ith, void * data) { + htp_diag_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint8_t * src_data = (const uint8_t *) src0->data; + uint8_t * dst_data = (uint8_t *) dst->data; + + const uint32_t ib0 = dctx->batches_per_thread * ith; + const uint32_t ib1 = MIN(ib0 + dctx->batches_per_thread, dctx->total_batches); + + for (uint32_t ib = ib0; ib < ib1; ib++) { + const uint32_t i3 = ib / ne02; + const uint32_t i2 = ib % ne02; + + const float * restrict src_batch = (const float *)(src_data + i3 * nb03 + i2 * nb02); + + for (uint32_t i1 = 0; i1 < ne1; i1++) { + float * restrict dst_row = (float *)(dst_data + i3 * nb3 + i2 * nb2 + i1 * nb1); + hvx_diag_row_f32(src_batch, dst_row, i1, ne0); + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "diag-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ib0, ib1, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_diag_f32(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * dst = octx->dst; + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const uint32_t total_batches = src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, total_batches); + + const size_t src_batch_size = src0->ne[0] * sizeof(float); + const size_t dst_row_size = dst->ne[0] * sizeof(float); + const size_t src_batch_size_aligned = hex_round_up(src_batch_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + + // 1 src buffer + 1 dst row buffer per thread + const size_t spad_per_thread = src_batch_size_aligned + dst_row_size_aligned; + + octx->src0_spad.size_per_thread = src_batch_size_aligned; + octx->dst_spad.size_per_thread = dst_row_size_aligned; + + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + + octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.src = NULL; + + struct htp_diag_context dctx = { + .octx = octx, + .src_batch_size = src_batch_size, + .dst_row_size = dst_row_size, + .src_batch_size_aligned = src_batch_size_aligned, + .dst_row_size_aligned = dst_row_size_aligned, + .batches_per_thread = (total_batches + n_threads - 1) / n_threads, + .total_batches = total_batches, + }; + + if (octx->ctx->vtcm_size < spad_per_thread * n_threads) { + worker_pool_run_func(octx->ctx->worker_pool, diag_thread_f32, &dctx, n_threads); + } else { + worker_pool_run_func(octx->ctx->worker_pool, diag_thread_f32_dma, &dctx, n_threads); + } + + return HTP_STATUS_OK; +} + +int op_diag(struct htp_ops_context * octx) { + const struct htp_tensor * dst = octx->dst; + + int err = HTP_STATUS_OK; + + switch (dst->type) { + case HTP_TYPE_F32: + err = op_diag_f32(octx); + break; + default: + err = HTP_STATUS_NO_SUPPORT; + break; + } + + return err; +} diff --git a/ggml/src/ggml-hexagon/htp/fill-ops.c b/ggml/src/ggml-hexagon/htp/fill-ops.c new file mode 100644 index 00000000000..3ccfbe74ee4 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/fill-ops.c @@ -0,0 +1,123 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include + +#include + +#include "hvx-copy.h" +#include "hvx-utils.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" + +// ggml op_params layout for FILL: +// op_params[0] (as float) - the scalar fill value + +#define fill_preamble \ + const struct htp_tensor * dst = octx->dst; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; \ + \ + const uint32_t nr = ne1 * ne2 * ne3; + +struct htp_fill_context { + struct htp_ops_context * octx; + uint32_t nrows_per_thread; + uint32_t total_rows; // ne1 * ne2 * ne3 + bool opt_path; + HVX_Vector splat_vec; + uint32_t elem_size; +}; + +static void fill_thread(unsigned int nth, unsigned int ith, void * data) { + const struct htp_fill_context * fctx = (const struct htp_fill_context *) data; + struct htp_ops_context * octx = fctx->octx; + fill_preamble; + + // Parallelise over the flat row index spanning ne1*ne2*ne3 + const uint32_t ir0 = fctx->nrows_per_thread * ith; + const uint32_t ir1 = MIN(ir0 + fctx->nrows_per_thread, fctx->total_rows); + + uint64_t t1 = HAP_perf_get_qtimer_count(); + + if (fctx->opt_path) { + // Opt path: tensor is fully contiguous, treat as flat array + const uint32_t elem_start = ir0 * ne0; + const uint32_t elem_end = ir1 * ne0; + uint8_t * dst_ptr = (uint8_t *) dst->data + elem_start * fctx->elem_size; + hvx_splat_u(dst_ptr, fctx->splat_vec, elem_end - elem_start, fctx->elem_size); + } else { + // Non-contiguous path: must respect strides + for (uint32_t ir = ir0; ir < ir1; ++ir) { + const uint32_t i1 = ir % ne1; + const uint32_t i2 = (ir / ne1) % ne2; + const uint32_t i3 = ir / (ne1 * ne2); + uint8_t * dst_ptr = (uint8_t *) dst->data + i1*nb1 + i2*nb2 + i3*nb3; + hvx_splat_u(dst_ptr, fctx->splat_vec, ne0, fctx->elem_size); + } + } + + uint64_t t2 = HAP_perf_get_qtimer_count(); + FARF(HIGH, "fill %u/%u: rows %u:%u usec %u\n", + ith, nth, ir0, ir1, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_fill(struct htp_ops_context * octx) { + fill_preamble; + + if (dst->type != HTP_TYPE_F32 && dst->type != HTP_TYPE_F16) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + // nr = ne1*ne2*ne3 (flat row count across all outer dims); parallelise over it. + const uint32_t n_threads = MIN(nr, octx->n_threads); + + // Optimize if fully contiguous: skip stride arithmetic, treat as flat array + const bool opt_path = (nb2 == nb1 * ne1) && (nb3 == nb2 * ne2); + + FARF(HIGH, "fill: (%ux%ux%ux%u) type=%u opt=%d\n", + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], dst->type, (int) opt_path); + + float val_f32 = 0.f; + memcpy(&val_f32, &octx->op_params[0], sizeof(float)); + + struct htp_fill_context fctx = { + .octx = octx, + .nrows_per_thread = (nr + n_threads - 1) / n_threads, + .total_rows = nr, + .opt_path = opt_path, + }; + + switch (dst->type) { + case HTP_TYPE_F32: + fctx.splat_vec = hvx_vec_splat_f32(val_f32); + fctx.elem_size = sizeof(float); + break; + case HTP_TYPE_F16: + fctx.splat_vec = hvx_vec_splat_f16((_Float16) val_f32); + fctx.elem_size = sizeof(_Float16); + break; + default: + return HTP_STATUS_NO_SUPPORT; + } + + worker_pool_run_func(octx->ctx->worker_pool, fill_thread, &fctx, n_threads); + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index 6dc978dd68a..d95df6ac9d5 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -15,15 +15,16 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" #include "htp-ops.h" +#include "htp-ops.h" +#include "hmx-ops.h" // Must be multiple of 32 #define FLASH_ATTN_BLOCK_SIZE (32 * 2) // This is a bit of a hack because the compiler is strugling to properly inline // the default hvx_vec_f32_to_f16 with output into the local array. -static void __attribute__((noinline)) hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1) +static __attribute__((noinline)) void hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1) { *(HVX_Vector *) ptr = hvx_vec_f32_to_f16(v0, v1); } @@ -278,12 +279,12 @@ static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * data) { struct htp_fa_context * factx = (struct htp_fa_context *) data; const struct htp_ops_context * octx = factx->octx; - const struct htp_tensor * q = &octx->src0; - const struct htp_tensor * k = &octx->src1; - const struct htp_tensor * v = &octx->src2; - const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL; - const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL; - const struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * q = octx->src[0]; + const struct htp_tensor * k = octx->src[1]; + const struct htp_tensor * v = octx->src[2]; + const struct htp_tensor * mask = octx->src[3]; + const struct htp_tensor * sinks = octx->src[4]; + const struct htp_tensor * dst = octx->dst; const uint32_t neq0 = q->ne[0]; const uint32_t neq1 = q->ne[1]; @@ -346,6 +347,9 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * const HVX_Vector logit_cap = hvx_vec_splat_f32(factx->logit_softcap); + dma_cache m_cache; + dma_cache_init(&m_cache, spad_m, factx->size_m_block, DMA_CACHE_MAX_SIZE); + for (uint32_t ir = ir0; ir < ir1; ++ir) { const uint32_t iq3 = fastdiv(ir, &factx->src0_div21); const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &factx->src0_div1); @@ -389,9 +393,8 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * // Mask if (mask) { const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start); - uint8_t * m_dst = spad_m + (ib % 2) * factx->size_m_block; // Mask is 1D contiguous for this row - dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1); + dma_cache_push(dma, &m_cache, m_src, current_block_size * 2, current_block_size * 2, current_block_size * 2, 1); } // FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u", @@ -554,7 +557,7 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * // Mask if (mask) { const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start); - dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1); + dma_cache_push(dma, &m_cache, m_src, next_block_size * 2, next_block_size * 2, next_block_size * 2, 1); } // FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u : iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u", @@ -608,17 +611,28 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * } int op_flash_attn_ext(struct htp_ops_context * octx) { - const struct htp_tensor * q = &octx->src0; - const struct htp_tensor * k = &octx->src1; - const struct htp_tensor * v = &octx->src2; - const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL; - const struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * q = octx->src[0]; + const struct htp_tensor * k = octx->src[1]; + const struct htp_tensor * v = octx->src[2]; + const struct htp_tensor * mask = octx->src[3]; + const struct htp_tensor * dst = octx->dst; // Check support if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || k->type != HTP_TYPE_F16 || v->type != HTP_TYPE_F16) { return HTP_STATUS_NO_SUPPORT; } +#ifdef HTP_HAS_HMX + // HMX path: prefill (neq1 >= 32), head_dim multiple of 32, F16 KV + if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 32 == 0 && q->ne[1] >= 32) { + int ret = hmx_flash_attn_ext(octx); + if (ret == HTP_STATUS_OK) { + return ret; + } + // VTCM too small or other failure -> fall through to HVX path + } +#endif + struct htp_fa_context factx; factx.octx = octx; @@ -684,7 +698,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { octx->src0_spad.size_per_thread = size_q_block * 1; octx->src1_spad.size_per_thread = factx.size_k_block * 2; octx->src2_spad.size_per_thread = factx.size_v_block * 2; - octx->src3_spad.size_per_thread = mask ? factx.size_m_block * 2 : 0; + octx->src3_spad.size_per_thread = mask ? factx.size_m_block * DMA_CACHE_MAX_SIZE : 0; octx->dst_spad.size_per_thread = size_vkq_acc; octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; @@ -699,11 +713,11 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { return HTP_STATUS_VTCM_TOO_SMALL; } - octx->src0_spad.data = octx->ctx->vtcm_base; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size; - octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size; - octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size; + octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL; + octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->src2_spad.src = NULL; + octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size; octx->src3_spad.src = NULL; + octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size; octx->dst_spad.src = NULL; if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads); diff --git a/ggml/src/ggml-hexagon/htp/get-rows-ops.c b/ggml/src/ggml-hexagon/htp/get-rows-ops.c index 047d2850aaa..5a1dc933860 100644 --- a/ggml/src/ggml-hexagon/htp/get-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/get-rows-ops.c @@ -11,7 +11,7 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" #include "hvx-utils.h" @@ -23,27 +23,33 @@ struct get_rows_context { }; #define get_rows_preamble \ - const uint32_t ne00 = octx->src0.ne[0]; \ - const uint32_t ne01 = octx->src0.ne[1]; \ - const uint32_t ne02 = octx->src0.ne[2]; \ - const uint32_t ne03 = octx->src0.ne[3]; \ - \ - const uint32_t ne10 = octx->src1.ne[0]; \ - const uint32_t ne11 = octx->src1.ne[1]; \ - const uint32_t ne12 = octx->src1.ne[2]; \ - \ - const uint32_t nb01 = octx->src0.nb[1]; \ - const uint32_t nb02 = octx->src0.nb[2]; \ - const uint32_t nb03 = octx->src0.nb[3]; \ - \ - const uint32_t nb10 = octx->src1.nb[0]; \ - const uint32_t nb11 = octx->src1.nb[1]; \ - const uint32_t nb12 = octx->src1.nb[2]; \ - \ - const uint32_t nb1 = octx->dst.nb[1]; \ - const uint32_t nb2 = octx->dst.nb[2]; \ - const uint32_t nb3 = octx->dst.nb[3]; \ - \ + const uint32_t ne00 = octx->src[0]->ne[0]; \ + const uint32_t ne01 = octx->src[0]->ne[1]; \ + const uint32_t ne02 = octx->src[0]->ne[2]; \ + const uint32_t ne03 = octx->src[0]->ne[3]; \ + \ + const uint32_t ne10 = octx->src[1]->ne[0]; \ + const uint32_t ne11 = octx->src[1]->ne[1]; \ + const uint32_t ne12 = octx->src[1]->ne[2]; \ + const uint32_t ne13 = octx->src[1]->ne[3]; \ + \ + const uint32_t ne0 = octx->dst->ne[0]; \ + const uint32_t ne1 = octx->dst->ne[1]; \ + const uint32_t ne2 = octx->dst->ne[2]; \ + const uint32_t ne3 = octx->dst->ne[3]; \ + \ + const uint32_t nb01 = octx->src[0]->nb[1]; \ + const uint32_t nb02 = octx->src[0]->nb[2]; \ + const uint32_t nb03 = octx->src[0]->nb[3]; \ + \ + const uint32_t nb10 = octx->src[1]->nb[0]; \ + const uint32_t nb11 = octx->src[1]->nb[1]; \ + const uint32_t nb12 = octx->src[1]->nb[2]; \ + \ + const uint32_t nb1 = octx->dst->nb[1]; \ + const uint32_t nb2 = octx->dst->nb[2]; \ + const uint32_t nb3 = octx->dst->nb[3]; \ + \ const uint32_t nr = ne10 * ne11 * ne12; static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) { @@ -51,12 +57,14 @@ static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da struct htp_ops_context * octx = grctx->octx; get_rows_preamble; + uint64_t qt = HAP_perf_get_qtimer_count(); + // parallelize by src1 elements (which correspond to dst rows) const uint32_t dr = grctx->src1_nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; - const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); + const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32); for (uint32_t i = ir0; i < ir1; ++i) { const uint32_t i12 = fastdiv(i, &grctx->get_rows_div_ne10_ne11); @@ -64,7 +72,7 @@ static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da const uint32_t i11 = fastdiv(rem, &grctx->get_rows_div_ne10); const uint32_t i10 = rem - i11 * ne10; - const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; + const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12; uint32_t i01 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; @@ -73,10 +81,14 @@ static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da continue; } - const uintptr_t src0_ptr = octx->src0.data + i01*nb01 + i11*nb02 + i12*nb03; - const uintptr_t dst_ptr = octx->dst.data + i10*nb1 + i11*nb2 + i12*nb3; + const uintptr_t src0_ptr = octx->src[0]->data + i01*nb01 + i11*nb02 + i12*nb03; + const uintptr_t dst_ptr = octx->dst->data + i10*nb1 + i11*nb2 + i12*nb3; hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00); } + + qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt); + FARF(HIGH, "get-rows-f32-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt); } int op_get_rows(struct htp_ops_context * octx) { @@ -84,15 +96,15 @@ int op_get_rows(struct htp_ops_context * octx) { const uint32_t n_threads = MIN(nr, octx->n_threads); - if (octx->src0.type != HTP_TYPE_F32) { + if (octx->src[0]->type != HTP_TYPE_F32) { return HTP_STATUS_NO_SUPPORT; } - if (octx->dst.type != HTP_TYPE_F32) { + if (octx->dst->type != HTP_TYPE_F32) { return HTP_STATUS_NO_SUPPORT; } - if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) { + if (octx->src[1]->type != HTP_TYPE_I32 && octx->src[1]->type != HTP_TYPE_I64) { return HTP_STATUS_NO_SUPPORT; } @@ -102,8 +114,8 @@ int op_get_rows(struct htp_ops_context * octx) { struct get_rows_context grctx; grctx.octx = octx; - grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]); - grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]); + grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src[1]->ne[0]); + grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src[1]->ne[0] * octx->src[1]->ne[1]); grctx.src1_nrows_per_thread = (nr + n_threads - 1) / n_threads; diff --git a/ggml/src/ggml-hexagon/htp/hex-dma.c b/ggml/src/ggml-hexagon/htp/hex-dma.c index 44e1be40c5d..b66e2d2603c 100644 --- a/ggml/src/ggml-hexagon/htp/hex-dma.c +++ b/ggml/src/ggml-hexagon/htp/hex-dma.c @@ -31,8 +31,8 @@ dma_queue * dma_queue_create(size_t capacity) { q->capacity = capacity; q->idx_mask = capacity - 1; - q->desc = (hexagon_udma_descriptor_type1_t *) memalign(64, capacity * sizeof(hexagon_udma_descriptor_type1_t)); - memset(q->desc, 0, capacity * sizeof(hexagon_udma_descriptor_type1_t)); + q->desc = (dma_descriptor_2d *) memalign(64, capacity * sizeof(dma_descriptor_2d)); + memset(q->desc, 0, capacity * sizeof(dma_descriptor_2d)); q->dptr = (dma_ptr *) memalign(4, capacity * sizeof(dma_ptr)); memset(q->dptr, 0, capacity * sizeof(dma_ptr)); diff --git a/ggml/src/ggml-hexagon/htp/hex-dma.h b/ggml/src/ggml-hexagon/htp/hex-dma.h index 350ab9d966f..7685473f463 100644 --- a/ggml/src/ggml-hexagon/htp/hex-dma.h +++ b/ggml/src/ggml-hexagon/htp/hex-dma.h @@ -10,19 +10,84 @@ extern "C" { #endif +// Define the HW descriptor structs here since the ones in HexSDK are a bit out of date +typedef struct dma_descriptor_1d_s { + void * next; + uint32_t size:24; + uint32_t desc_size:2; + uint32_t dst_comp:1; + uint32_t src_comp:1; + uint32_t dst_bypass:1; + uint32_t src_bypass:1; + uint32_t order:1; + uint32_t done:1; + void * src; + void * dst; +} dma_descriptor_1d; + +#if __HVX_ARCH__ < 75 + +typedef struct dma_descriptor_2d_s { + void * next; + uint32_t reserved0:24; + uint32_t desc_size:2; + uint32_t dst_comp:1; + uint32_t src_comp:1; + uint32_t dst_bypass:1; + uint32_t src_bypass:1; + uint32_t order:1; + uint32_t done:1; + void * src; + void * dst; + uint32_t desc_type:8; + uint32_t reserved1:24; + uint32_t row_size:16; + uint32_t nrows:16; + uint32_t src_stride:16; + uint32_t dst_stride:16; + uint32_t src_offset:16; + uint32_t dst_offset:16; +} dma_descriptor_2d; + +#else + +typedef struct dma_descriptor_2d_s { + void * next; + uint32_t dst_stride:24; + uint32_t desc_size:2; + uint32_t dst_comp:1; + uint32_t src_comp:1; + uint32_t dst_bypass:1; + uint32_t src_bypass:1; + uint32_t order:1; + uint32_t done:1; + void * src; + void * dst; + uint32_t desc_type:8; + uint32_t reserved0:24; + uint32_t row_size:24; + uint32_t nrows_lo:8; + uint32_t nrows_hi:8; + uint32_t src_stride:24; + uint32_t offset:24; + uint32_t reserved1:8; +} dma_descriptor_2d; + +#endif + typedef struct { - void *dst; + void *dst; const void *src; } dma_ptr; typedef struct { - hexagon_udma_descriptor_type1_t * desc; // descriptor pointers - hexagon_udma_descriptor_type1_t * tail; // tail pointer - dma_ptr * dptr; // dst/src pointers - uint32_t push_idx; - uint32_t pop_idx; - uint32_t capacity; - uint32_t idx_mask; + dma_descriptor_2d * desc; // descriptor pointers + dma_descriptor_2d * tail; // tail pointer + dma_ptr * dptr; // dst/src pointers + uint32_t push_idx; + uint32_t pop_idx; + uint32_t capacity; + uint32_t idx_mask; } dma_queue; dma_queue * dma_queue_create(size_t capacity); @@ -59,71 +124,95 @@ static inline dma_ptr dma_make_ptr(void *dst, const void *src) return p; } -static inline bool dma_queue_push(dma_queue * q, - dma_ptr dptr, - size_t dst_row_size, - size_t src_row_size, - size_t width, // width in bytes. number of bytes to transfer per row - size_t nrows) { +#if __HVX_ARCH__ < 73 +static const uint32_t dma_src_l2_bypass_on = 1; +static const uint32_t dma_dst_l2_bypass_on = 0; +#else +static const uint32_t dma_src_l2_bypass_on = 1; +static const uint32_t dma_dst_l2_bypass_on = 1; +#endif + +static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t size) { if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) { - FARF(ERROR, "dma-push: queue full\n"); + FARF(HIGH, "dma-push: queue full\n"); return false; } - hexagon_udma_descriptor_type1_t * desc = &q->desc[q->push_idx]; + dma_descriptor_1d * desc = (dma_descriptor_1d *) &q->desc[q->push_idx]; + desc->next = NULL; + desc->desc_size = 0; // 1D mode + desc->src_bypass = dma_src_l2_bypass_on; + desc->dst_bypass = dma_dst_l2_bypass_on; + desc->order = 0; + desc->done = 0; + desc->src = (void *) dptr.src; + desc->dst = (void *) dptr.dst; + desc->size = size; + + q->dptr[q->push_idx] = dptr; + + if (size) { + dmlink(q->tail, desc); + q->tail = (dma_descriptor_2d *) desc; + } else { + desc->done = 1; + } + + // FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src); + q->push_idx = (q->push_idx + 1) & q->idx_mask; + return true; +} + +static inline bool dma_queue_push_single_2d(dma_queue * q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t row_size, size_t nrows) { + if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) { + FARF(HIGH, "dma-push: queue full\n"); + return false; + } + + dma_descriptor_2d * desc = &q->desc[q->push_idx]; desc->next = NULL; - desc->length = 0; - desc->desctype = HEXAGON_UDMA_DESC_DESCTYPE_TYPE1; - desc->dstbypass = 1; - desc->srcbypass = 1; -#if __HVX_ARCH__ >= 73 - desc->dstbypass = 1; - desc->srcbypass = 1; -#else - desc->dstbypass = 0; - desc->srcbypass = 1; -#endif + desc->reserved0 = 0; + desc->reserved1 = 0; + desc->desc_size = 1; // 2d mode + desc->src_bypass = dma_src_l2_bypass_on; + desc->dst_bypass = dma_dst_l2_bypass_on; + desc->src_comp = 0; + desc->dst_comp = 0; desc->order = 0; - desc->dstate = HEXAGON_UDMA_DESC_DSTATE_INCOMPLETE; + desc->done = 0; + desc->src_stride = src_stride; + desc->dst_stride = dst_stride; desc->src = (void *) dptr.src; desc->dst = (void *) dptr.dst; - desc->allocation = 0; - desc->padding = 0; - desc->roiwidth = width; - desc->roiheight = nrows; - desc->srcstride = src_row_size; - desc->dststride = dst_row_size; - desc->srcwidthoffset = 0; - desc->dstwidthoffset = 0; + desc->row_size = row_size; + +#if __HVX_ARCH__ < 75 + desc->desc_type = 0; // 2d (16-bit) mode + desc->nrows = nrows; + desc->src_offset = 0; + desc->dst_offset = 0; +#else + desc->desc_type = 9; // 2d (24-bit) mode + desc->nrows_lo = (nrows & 0xff); + desc->nrows_hi = (nrows >> 8); + desc->offset = 0; +#endif q->dptr[q->push_idx] = dptr; - dmlink(q->tail, desc); - q->tail = desc; + if (nrows) { + dmlink(q->tail, desc); + q->tail = desc; + } else { + desc->done = 1; + } - // FARF(ERROR, "dma-push: i %u width %u nrows %d dst %p src %p\n", q->push_idx, width, nrows, dptr.dst, dptr.src); + // FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src); q->push_idx = (q->push_idx + 1) & q->idx_mask; return true; } -static inline bool dma_queue_push_ddr_to_vtcm(dma_queue * q, - dma_ptr dptr, - size_t dst_row_size, - size_t src_row_size, - size_t nrows) { - return dma_queue_push(q, dptr, dst_row_size, src_row_size, src_row_size, nrows); -} - - -static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q, - dma_ptr dptr, - size_t dst_row_size, - size_t src_row_size, - size_t nrows) { - return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows); -} - static inline dma_ptr dma_queue_pop(dma_queue * q) { dma_ptr dptr = { NULL }; @@ -131,15 +220,12 @@ static inline dma_ptr dma_queue_pop(dma_queue * q) { return dptr; } - hexagon_udma_descriptor_type1_t * desc = &q->desc[q->pop_idx]; + dma_descriptor_2d * desc = &q->desc[q->pop_idx]; // Wait for desc to complete - while (1) { - dmpoll(); - if (desc->dstate == HEXAGON_UDMA_DESC_DSTATE_COMPLETE) { - break; - } + while (!desc->done) { // FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx); + dmpoll(); } dptr = q->dptr[q->pop_idx]; @@ -175,6 +261,110 @@ static inline uint32_t dma_queue_capacity(dma_queue * q) { return q->capacity; } +#if __HVX_ARCH__ < 75 + +// Overflow-safe DMA push: all 2d descriptor fields (row_size, nrows, src_stride, dst_stride) are 16-bit, max 65535. +// This version transparently handles values that exceed the 16-bit limit and submits chained DMA transtions. + +#define DMA_MAX_FIELD_VAL 65535u + +static inline bool dma_queue_push(dma_queue *q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t row_size, size_t nrows) { + // Fast path: everything fits in 16 bits + if (nrows == 0 || __builtin_expect( + row_size <= DMA_MAX_FIELD_VAL && + nrows <= DMA_MAX_FIELD_VAL && + src_stride <= DMA_MAX_FIELD_VAL && + dst_stride <= DMA_MAX_FIELD_VAL, 1)) { + return dma_queue_push_single_2d(q, dptr, dst_stride, src_stride, row_size, nrows); + } + + // Contiguous block + // Use 1d DMA mode which supports sizes up to 24-bits (16MB) + if (nrows == 1 || (row_size == src_stride && row_size == dst_stride)) { + size_t total = row_size * nrows; + return dma_queue_push_single_1d(q, dptr, total); + } + + // Stride overflow — fall back to row-by-row. + { + const uint8_t *src = (const uint8_t *) dptr.src; + uint8_t *dst = (uint8_t *) dptr.dst; + for (size_t r = 0; r < nrows; ++r) { + dma_ptr p = dma_make_ptr(dst + r * dst_stride, src + r * src_stride); + if (!dma_queue_push_single_1d(q, p, row_size)) + return false; + if (r + 1 < nrows) + dma_queue_pop(q); + } + return true; + } +} + +#else // HVX_ARCH >= 75 + +static inline bool dma_queue_push(dma_queue *q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t row_size, size_t nrows) { + // On v75 and up we always use 2d 24-bit mode + return dma_queue_push_single_2d(q, dptr, dst_stride, src_stride, row_size, nrows); +} + +#endif + +static inline bool dma_queue_push_ddr_to_vtcm(dma_queue * q, dma_ptr dptr, size_t dst_row_size, size_t src_row_size, size_t nrows) { + return dma_queue_push(q, dptr, dst_row_size, src_row_size, src_row_size, nrows); +} + +static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q, dma_ptr dptr, size_t dst_row_size, size_t src_row_size, size_t nrows) { + return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows); +} + +#define DMA_CACHE_MAX_SIZE 64U + +typedef struct { + uint8_t *base; + uint32_t line_size; + uint32_t capacity; + uint32_t src[DMA_CACHE_MAX_SIZE]; + uint16_t age[DMA_CACHE_MAX_SIZE]; +} dma_cache; + +static inline void dma_cache_init(dma_cache *c, uint8_t *base, uint32_t line_size, uint32_t capacity) +{ + c->capacity = (capacity > DMA_CACHE_MAX_SIZE) ? DMA_CACHE_MAX_SIZE : capacity; + c->base = base; + c->line_size = line_size; + + for (unsigned i=0; i < c->capacity; i++) { + c->src[i] = 0; + c->age[i] = 0; + } +} + +static inline bool dma_cache_push(dma_queue *q, dma_cache *c, const uint8_t * src, uint32_t dst_stride, uint32_t src_stride, uint32_t row_size, uint32_t nrows) +{ + uint32_t o_idx = 0; + uint16_t o_age = 0; + uint8_t * dst = 0; + + for (unsigned i=0; i < c->capacity; i++) { + if (c->src[i] == (uint32_t) src) { + c->age[i] = 0; + dst = c->base + (i * c->line_size); nrows = 0; // dummy dma + // FARF(ERROR, "dma-cache: found %p", src); + } else { + c->age[i]++; + if (c->age[i] > o_age) { o_age = c->age[i]; o_idx = i; } + } + } + if (!dst) { + // FARF(ERROR, "dma-cache: replacing #%u : age %u %p -> %p", o_idx, c->age[o_idx], (void *) c->src[o_idx], src); + c->age[o_idx] = 0; + c->src[o_idx] = (uint32_t) src; + dst = c->base + o_idx * c->line_size; // normal nrows dma + } + + return dma_queue_push(q, dma_make_ptr(dst, src), dst_stride, src_stride, row_size, nrows); +} + #ifdef __cplusplus } // extern "C" #endif diff --git a/ggml/src/ggml-hexagon/htp/hex-dump.h b/ggml/src/ggml-hexagon/htp/hex-dump.h index e3badb57f92..19d173c2232 100644 --- a/ggml/src/ggml-hexagon/htp/hex-dump.h +++ b/ggml/src/ggml-hexagon/htp/hex-dump.h @@ -21,6 +21,15 @@ static inline void hex_dump_uint8_line(char * pref, const uint8_t * x, uint32_t FARF(HIGH, "%s\n", str); } +static inline void hex_dump_uint32_line(char * pref, const uint32_t * x, uint32_t n) { + char str[1024], *p = str, *p_end = str + sizeof(str); + p += snprintf(p, p_end - p, "%s: ", pref); + for (int i = 0; i < n; i++) { + p += snprintf(p, p_end - p, "%u, ", (unsigned int) x[i]); + } + FARF(HIGH, "%s\n", str); +} + static inline void hex_dump_int32_line(char * pref, const int32_t * x, uint32_t n) { char str[1024], *p = str, *p_end = str + sizeof(str); p += snprintf(p, p_end - p, "%s: ", pref); diff --git a/ggml/src/ggml-hexagon/htp/hex-utils.h b/ggml/src/ggml-hexagon/htp/hex-utils.h index fb8a25a3f20..6239ceff4b4 100644 --- a/ggml/src/ggml-hexagon/htp/hex-utils.h +++ b/ggml/src/ggml-hexagon/htp/hex-utils.h @@ -3,8 +3,11 @@ #include #include +#include +#include #include "hexagon_types.h" +#include "hexagon_protos.h" #include "hex-fastdiv.h" #include "hex-dump.h" @@ -29,10 +32,30 @@ static inline uint64_t hex_get_pktcnt() { return pktcnt; } -static inline int32_t hex_is_aligned(void * addr, uint32_t align) { +static inline uint32_t hex_ceil_pow2(uint32_t x) { + if (x <= 1) { return 1; } + int p = 2; + x--; + while (x >>= 1) { p <<= 1; } + return p; +} + +static inline size_t hmx_ceil_div(size_t num, size_t den) { + return (num + den - 1) / den; +} + +static inline int32_t hex_is_aligned(const void * addr, uint32_t align) { return ((size_t) addr & (align - 1)) == 0; } +static inline size_t hex_align_up(size_t v, size_t align) { + return hmx_ceil_div(v, align) * align; +} + +static inline size_t hex_align_down(size_t v, size_t align) { + return (v / align) * align; +} + static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) { uint32_t left_off = (size_t) addr & (chunk_size - 1); uint32_t right_off = left_off + n; @@ -43,9 +66,72 @@ static inline uint32_t hex_round_up(uint32_t n, uint32_t m) { return m * ((n + m - 1) / m); } +static inline size_t hex_smin(size_t a, size_t b) { + return a < b ? a : b; +} + +static inline size_t hex_smax(size_t a, size_t b) { + return a > b ? a : b; +} + +static inline void hex_swap_ptr(void ** p1, void ** p2) { + void * t = *p1; + *p1 = *p2; + *p2 = t; +} + static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) { const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height)); Q6_l2fetch_AP((void *) p, control); } +#define HEX_L2_LINE_SIZE 64 +#define HEX_L2_FLUSH_SIZE (128 * 1024) + +static inline void hex_l2flush(void * addr, size_t size) { + if (size > HEX_L2_FLUSH_SIZE) { + qurt_mem_cache_clean((qurt_addr_t) 0, 0, QURT_MEM_CACHE_FLUSH_INVALIDATE_ALL, QURT_MEM_DCACHE); + } else { + const uint32_t s = (uint32_t) addr; + const uint32_t e = s + size; + for (uint32_t i = s; i < e; i += HEX_L2_LINE_SIZE * 4) { + Q6_dccleaninva_A((void *) i + HEX_L2_LINE_SIZE * 0); + Q6_dccleaninva_A((void *) i + HEX_L2_LINE_SIZE * 1); + Q6_dccleaninva_A((void *) i + HEX_L2_LINE_SIZE * 2); + Q6_dccleaninva_A((void *) i + HEX_L2_LINE_SIZE * 3); + } + } +} + +static inline void hex_pause() { + asm volatile(" pause(#255)\n"); +} + +#ifndef HEX_NUM_PMU_COUNTERS +#define HEX_NUM_PMU_COUNTERS 8 +#endif + +static inline void hex_get_pmu(uint32_t counters[]) { +#if __HVX_ARCH__ >= 79 + asm volatile("%0 = upmucnt0" : "=r"(counters[0])); + asm volatile("%0 = upmucnt1" : "=r"(counters[1])); + asm volatile("%0 = upmucnt2" : "=r"(counters[2])); + asm volatile("%0 = upmucnt3" : "=r"(counters[3])); + asm volatile("%0 = upmucnt4" : "=r"(counters[4])); + asm volatile("%0 = upmucnt5" : "=r"(counters[5])); + asm volatile("%0 = upmucnt6" : "=r"(counters[6])); + asm volatile("%0 = upmucnt7" : "=r"(counters[7])); +#else + counters[0] = qurt_pmu_get(QURT_PMUCNT0); + counters[1] = qurt_pmu_get(QURT_PMUCNT1); + counters[2] = qurt_pmu_get(QURT_PMUCNT2); + counters[3] = qurt_pmu_get(QURT_PMUCNT3); + counters[4] = qurt_pmu_get(QURT_PMUCNT4); + counters[5] = qurt_pmu_get(QURT_PMUCNT5); + counters[6] = qurt_pmu_get(QURT_PMUCNT6); + counters[7] = qurt_pmu_get(QURT_PMUCNT7); + // qurt_pmu_get_pmucnt(counters); +#endif +} + #endif /* HEX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c new file mode 100644 index 00000000000..8a6d7c14edf --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c @@ -0,0 +1,1840 @@ +// HMX-accelerated Flash Attention for prefill (neq1 >= 32). +// Ported from htp-ops-lib/src/dsp/ops/flash_attn.c, adapted to the htp/ codebase. + +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include +#include +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "hex-dma.h" +#include "hmx-profile.h" +#include "hmx-queue.h" +#include "hmx-utils.h" +#include "htp-ctx.h" +#include "htp-ops.h" +#include "hvx-dump.h" +#include "hvx-reduce.h" +#include "hvx-utils.h" +#include "vtcm-utils.h" +#include "worker-pool.h" + +// ============================================================================ +// Constants +// ============================================================================ + +// Tile constants from hmx-utils.h +// HMX_FP16_TILE_N_ROWS = 32 +// HMX_FP16_TILE_N_COLS = 32 +// HMX_FP16_TILE_N_ELMS = 1024 +// HMX_FP16_TILE_SIZE = 2048 + +// ============================================================================ +// Dynamic block size computation (GQA-aware) +// ============================================================================ + +// Exact VTCM usage for a given (gqa_factor, DK, DV, Br, Bc) configuration. +// g_br = hex_align_up(gqa_factor * Br, 32) replaces Br for all Q/O/S/P/D dimensions. +// Layout: Q + O_ping + O_pong + K_dma*2 + V_dma*2 + K_tile + V_tile + S + P + D + vectors + scales +// Mask is DMA'd into a VTCM buffer (Br rows per KV block) to avoid DDR reads in softmax. +static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads) { + const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS); + const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK] + const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong + const size_t k_dma_size = hex_align_up(Bc * DK * sizeof(__fp16), 4096); // K DMA: [Bc, DK] x2 double-buf + const size_t v_dma_size = hex_align_up(Bc * DV * sizeof(__fp16), 4096); // V DMA: [Bc, DV] x2 double-buf + const size_t k_tile_size = hex_align_up(Bc * DK * sizeof(__fp16), 4096); // K tiles: [Bc, DK] interleaved + const size_t v_tile_size = hex_align_up(Bc * DV * sizeof(__fp16), 4096); // V tiles: [Bc, DV] interleaved + const size_t s_tile_size = hex_align_up(g_br * Bc * sizeof(__fp16), 4096); // S/P:[g_br, Bc] + const size_t d_tile_size = hex_align_up(g_br * g_br * sizeof(__fp16), 4096); // D: [g_br, g_br] + const size_t col_vec_size = hex_align_up(g_br * sizeof(__fp16), 256); // m, l, etc. + const size_t row_vec_size = hex_align_up(Bc * sizeof(__fp16), 256); + const size_t m_line_size = hex_align_up(Bc * sizeof(__fp16), 128); + const size_t m_buf_size = hex_align_up(Br * m_line_size, 4096); + const size_t slopes_size = hex_align_up(g_br * sizeof(__fp16), 128); + + return q_tile_size * 1 // Q tiles + + o_tile_size * 2 // O ping-pong + + k_dma_size * 2 // K DMA x2 + + v_dma_size * 2 // V DMA x2 + + k_tile_size * 1 // K tiles + + v_tile_size * 1 // V tiles + + s_tile_size * 2 // S + P + + d_tile_size * 1 // D (diagonal matrix) + + col_vec_size * 4 // m_vec, l_vec, s_rowmax, p_rowsum + + row_vec_size * 2 * n_threads // per-thread softmax row scratch + + m_buf_size * 1 // mask VTCM buffer [Br rows] + + slopes_size // Slopes + + 256 * 2; // HMX scales (id + qk) +} + +// ============================================================================ +// FP16 exp2 polynomial (ported from htp-ops-lib/include/dsp/hvx_math.h) +// ============================================================================ +// 5th-order Horner polynomial for exp2(x) in qf16/hf16 domain. Input must be +// ≤ 0 (safe softmax invariant — overflow handling omitted). ~18 ALU ops per +// 64 fp16 lanes, fully parallel across HVX threads (no scatter/gather engine). +// Replaces the F32 round-trip (qf16→f32→exp→f32→f16, ~44 ops for 2×32 lanes). +static inline HVX_Vector hvx_exp2_hf(HVX_Vector x_v) { + const HVX_Vector zero_v = Q6_V_vzero(); + const HVX_Vector half_hf_v = Q6_Vh_vsplat_R(0x3800); // fp16 0.5 + + // k = round_toward_neg_inf(x); f = (float)k; frac = x - f + HVX_Vector x_minus_half = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vsub_VhfVhf(x_v, half_hf_v)); + HVX_Vector k_v = Q6_Vh_equals_Vhf(x_minus_half); // truncate to int16 + HVX_Vector f_v = Q6_Vhf_equals_Vh(k_v); // back to fp16 + + HVX_Vector x_qf16 = Q6_Vqf16_vsub_VhfVhf(x_v, f_v); // fractional part in qf16 + + // Horner: y = ((((E5*x + E4)*x + E3)*x + E2)*x + E1)*x + E0 + HVX_Vector y = Q6_Vqf16_vmpy_Vqf16Vqf16(Q6_Vh_vsplat_R(0x5082), x_qf16); // E5*x + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x157d)); // + E4 + y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x20ed)); // + E3 + y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x2b1b)); // + E2 + y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x33b0)); // + E1 + y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x398c)); // + E0 + y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); // y = y * x + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x3c00)); // + 1.0 + + // Combine polynomial (mantissa) with integer part (exponent): result = y * 2^k + y = Q6_Vhf_equals_Vqf16(y); + HVX_Vector y_exp = Q6_Vuh_vlsr_VuhR(Q6_Vh_vasl_VhR(y, 1), 11); + y_exp = Q6_Vh_vadd_VhVh(k_v, y_exp); + HVX_VectorPred q_underflow = Q6_Q_vcmp_gt_VhVh(zero_v, y_exp); + y = Q6_Vh_vaslacc_VhVhR(y, k_v, 10); + return Q6_V_vmux_QVV(q_underflow, zero_v, y); +} + +#define FA_MIN_KV_BLOCKS 3 + +// Cost-based (Br, Bc) search for flash attention with pipeline constraint. +// +// VTCM model (same as before): +// overhead + g_br * per_gbr + g_br² * per_gbr2 + Bc * per_bc + g_br * Bc * per_gbr_bc +// +// Cost model (minimization objective): +// Q * (c_q_fixed + K * c_iter_fixed), where Q = ceil(qo/Br), K = ceil(kv/Bc) +static int hmx_fa_find_chunk_size(size_t * Br_out, + size_t * Bc_out, + size_t gqa_factor, + size_t DK, + size_t DV, + size_t qo_len, + size_t kv_len, + size_t vtcm_budget, + size_t n_threads) { + const size_t T = HMX_FP16_TILE_N_ROWS; // 32 + const size_t br_unit = hmx_ceil_div(T, gqa_factor); + // Bc must be a multiple of 64 so that n_tiles_per_bc is even. The softmax + // P-tile write uses a dual-tile pattern (vshuff + two stores 16 slots apart) + // that would race across r0 blocks if the last dual-tile is half-occupied. + // See .cursor/todos/hmx-flash-attn-bc-search-space.md for the perf trade-off. + const size_t bc_unit = HMX_FP16_TILE_N_COLS * 2; // 64 + const size_t fp16 = sizeof(__fp16); + + // Approximate per-unit VTCM costs (without per-buffer alignment padding). + const size_t per_gbr = (DK + 2 * DV) * fp16 + 4 * fp16; // Q + O×2 + 4 col vectors + const size_t per_gbr2 = fp16; // D diagonal matrix + const size_t per_bc = + 3 * (DK + DV) * fp16 + 2 * n_threads * fp16; // K_dma×2 + V_dma×2 + K_tile + V_tile + row bufs + const size_t per_gbr_bc = 2 * fp16; // S + P + + const size_t overhead = 256 * 2 + 13 * 4096; + + if (vtcm_budget <= overhead) { + return -1; + } + const size_t usable = vtcm_budget - overhead; + + // Br_max: largest Br aligned to br_unit that does not exceed qo_len. + const size_t Br_max = qo_len >= br_unit ? hex_align_down(qo_len, br_unit) : br_unit; + + // Pipeline constraint: cap Bc so n_kv_blocks >= FA_MIN_KV_BLOCKS. + // Only relax when kv_len is too short to form enough blocks. + const bool can_pipeline = (kv_len >= FA_MIN_KV_BLOCKS * bc_unit && n_threads >= 2); + const size_t Bc_limit = can_pipeline ? hex_align_down(kv_len / FA_MIN_KV_BLOCKS, bc_unit) : + (kv_len >= bc_unit ? hex_align_down(kv_len, bc_unit) : bc_unit); + // Cost coefficients calibrated from profiling + const size_t c_q_fixed = 1400; // per-Q-block: q_load + epilogue o_update + o_norm + o_store + const size_t c_iter_fixed = 200; // per-KV-iter: HMX queue push/pop + DMA pop + barriers + + size_t best_cost = SIZE_MAX, best_mn = 0; + size_t best_Br = 0, best_Bc = 0; + + for (size_t Br = Br_max; Br >= br_unit; Br -= br_unit) { + const size_t g_br = hex_align_up(gqa_factor * Br, T); + + // g_br-dependent VTCM cost: g_br * per_gbr + g_br² * per_gbr2 + const size_t gbr_cost = g_br * per_gbr + g_br * g_br * per_gbr2; + if (gbr_cost >= usable) { + if (Br == br_unit) { + break; + } + continue; + } + + // Analytically solve for max Bc: + // remain >= Bc * (per_bc + g_br * per_gbr_bc + Br * fp16_mask) + // The Br * fp16 term accounts for the VTCM mask buffer [Br × Bc]. + const size_t remain = usable - gbr_cost; + const size_t bc_denom = per_bc + g_br * per_gbr_bc + Br * fp16; + size_t Bc = hex_smin(hex_align_down(remain / bc_denom, bc_unit), Bc_limit); + if (Bc < bc_unit) { + if (Br == br_unit) { + break; + } + continue; + } + + // Exact VTCM verification (alignment padding may push over budget) + while (Bc >= bc_unit && hmx_fa_compute_vtcm_usage(gqa_factor, DK, DV, Br, Bc, n_threads) > vtcm_budget) { + Bc -= bc_unit; + } + if (Bc < bc_unit) { + if (Br == br_unit) { + break; + } + continue; + } + + const size_t q_blocks = (qo_len + Br - 1) / Br; + const size_t kv_blocks = (kv_len + Bc - 1) / Bc; + const size_t cost = q_blocks * (c_q_fixed + kv_blocks * c_iter_fixed); + const size_t mn = Br * Bc; + + if (cost < best_cost || (cost == best_cost && mn > best_mn)) { + best_cost = cost; + best_mn = mn; + best_Br = Br; + best_Bc = Bc; + } + + if (Br == br_unit) { + break; + } + } + + if (best_Br == 0) { + return -1; + } + + *Br_out = best_Br; + *Bc_out = best_Bc; + return 0; +} + +// ============================================================================ +// Tile interleave / extract helpers +// ============================================================================ + +// transpose scatter offsets moved to hmx-utils.h as hmx_transpose_scatter_offsets + +// Scatter offsets for diagonal tile: entry[2i] = i*136, entry[2i+1] = i*136+6 +// 136 = 4 * 32 + 8 = byte offset to diagonal in a 32x32 fp16 interleaved tile +static const int16_t d_tile_scatter_offsets[64] __attribute__((aligned(128))) = { + 0 * 136, 0 * 136 + 6, + 1 * 136, 1 * 136 + 6, + 2 * 136, 2 * 136 + 6, + 3 * 136, 3 * 136 + 6, + 4 * 136, 4 * 136 + 6, + 5 * 136, 5 * 136 + 6, + 6 * 136, 6 * 136 + 6, + 7 * 136, 7 * 136 + 6, + 8 * 136, 8 * 136 + 6, + 9 * 136, 9 * 136 + 6, + 10 * 136, 10 * 136 + 6, + 11 * 136, 11 * 136 + 6, + 12 * 136, 12 * 136 + 6, + 13 * 136, 13 * 136 + 6, + 14 * 136, 14 * 136 + 6, + 15 * 136, 15 * 136 + 6, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, +}; + +// hmx_interleave_rows_to_tiles and hmx_interleave_cols_to_tiles are in hmx-utils.h + +// ============================================================================ +// HMX Flash Attention context (GQA-merged) +// ============================================================================ + +struct hmx_fa_context { + const struct htp_ops_context * octx; + bool use_pipeline; // true when n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2 + uint32_t n_threads; + + // Op parameters + float scale; + float max_bias; + float logit_softcap; + uint32_t n_head_log2; + float m0, m1; + + // Dimensions + uint32_t DK, DV; + uint32_t n_kv; // kv_len + uint32_t n_kv_heads; // number of KV heads + uint32_t n_heads; // number of Q heads + uint32_t G; // GQA factor = n_heads / n_kv_heads + uint32_t n_kv_blocks; + uint32_t neq1; // Q token count + + // Types + bool is_q_fp32; + bool is_dst_fp32; + + // Dynamic block sizes + uint32_t Br; // Q tokens per block (before GQA expansion) + uint32_t Bc; + uint32_t g_br; // hex_align_up(G * Br, 32) - actual tile row dim + + // VTCM buffers (allocated by vtcm_seq_alloc) + __fp16 * vtcm_q_tiles; // Q tile format [g_br, D] + __fp16 * vtcm_o_tiles[2]; // O ping-pong [g_br, D] + __fp16 * vtcm_k_fp16[2]; // K DMA double-buffer [Bc, D] + __fp16 * vtcm_v_fp16[2]; // V DMA double-buffer [Bc, D] + __fp16 * vtcm_k_tiles; // K tiles (transposed) + __fp16 * vtcm_v_tiles; // V tiles (column-major) + __fp16 * vtcm_s_tiles; // S = QK^T [g_br, Bc] + __fp16 * vtcm_p_tiles; // P = softmax(S) [g_br, Bc] + __fp16 * vtcm_d_tiles; // Diagonal rescale [g_br, g_br] + HVX_Vector * vtcm_m_vec; // Row max [g_br] + HVX_Vector * vtcm_l_vec; // Row sum [g_br] + HVX_Vector * vtcm_s_rowmax; // Softmax intermediate [g_br] + HVX_Vector * vtcm_p_rowsum; // Softmax intermediate [g_br] + HVX_Vector * vtcm_row_bufs; // Per-thread softmax row scratch [n_threads][2][Bc/64] + uint8_t * vtcm_hmx_scales_id; // HMX output scales (identity) + uint8_t * vtcm_hmx_scales_qk; // HMX output scales (qk_scale) + __fp16 * vtcm_mask_buf; // VTCM mask buffer [Br × m_line], DMA'd per KV block + __fp16 * vtcm_slopes; // ALiBi slopes [g_br] + size_t row_buf_stride; // HVX vectors per row buffer (Bc/64) + size_t mask_buf_row_stride; // elements (__fp16) per row in mask buffer + bool mask_broadcast; // true when mask->ne[2] == 1 (head-independent, single 2D DMA) +}; + +// ============================================================================ +// Multi-thread K interleave phase +// ============================================================================ + +typedef struct { + struct hmx_fa_context * factx; + int kv_rows; + size_t src_stride; + size_t buf_idx; +} fa_k_int_args_t; + +static void fa_k_interleave_thread(unsigned int n, unsigned int i, void * data) { + fa_k_int_args_t * args = (fa_k_int_args_t *) data; + struct hmx_fa_context * factx = args->factx; + + const int total_rows = args->kv_rows; + const int rows_per_t = hex_align_up(hmx_ceil_div(total_rows, n), 2); // ensure even (row pairs) + const int start = i * rows_per_t; + const int end = hex_smin(start + rows_per_t, total_rows); + + if (start >= total_rows) { + return; + } + + hmx_interleave_rows_to_tiles(factx->vtcm_k_tiles, factx->vtcm_k_fp16[args->buf_idx], total_rows, (int) factx->DK, + (int) args->src_stride, start, end); +} + +static void fa_phase_k_interleave(struct hmx_fa_context * factx, int kv_rows, size_t src_stride, size_t buf_idx) { + worker_pool_context_t wp = factx->octx->ctx->worker_pool; + fa_k_int_args_t args = { factx, kv_rows, src_stride, buf_idx }; + if (factx->n_threads > 1 && kv_rows >= (int) (factx->n_threads * 2)) { + worker_pool_run_func(wp, fa_k_interleave_thread, &args, factx->n_threads); + } else { + fa_k_interleave_thread(1, 0, &args); + } +} + +// ============================================================================ +// Multi-thread V interleave phase +// ============================================================================ + +typedef struct { + struct hmx_fa_context * factx; + int kv_rows; + size_t src_stride; + size_t buf_idx; + size_t n_col_tiles; +} fa_v_int_args_t; + +static void fa_v_interleave_thread(unsigned int n, unsigned int i, void * data) { + fa_v_int_args_t * args = (fa_v_int_args_t *) data; + struct hmx_fa_context * factx = args->factx; + + const int total_rows = args->kv_rows; + const int rows_per_t = hex_align_up(hmx_ceil_div(total_rows, n), 2); + const int start = i * rows_per_t; + const int end = hex_smin(start + rows_per_t, total_rows); + + if (start >= total_rows) { + return; + } + + hmx_interleave_cols_to_tiles(factx->vtcm_v_tiles, factx->vtcm_v_fp16[args->buf_idx], total_rows, (int) factx->DV, + (int) args->src_stride, (int) args->n_col_tiles, start, end); +} + +static void fa_phase_v_interleave(struct hmx_fa_context * factx, + int kv_rows, + size_t src_stride, + size_t buf_idx, + size_t n_col_tiles) { + worker_pool_context_t wp = factx->octx->ctx->worker_pool; + fa_v_int_args_t args = { factx, kv_rows, src_stride, buf_idx, n_col_tiles }; + if (factx->n_threads > 1 && kv_rows >= (int) (factx->n_threads * 2)) { + worker_pool_run_func(wp, fa_v_interleave_thread, &args, factx->n_threads); + } else { + fa_v_interleave_thread(1, 0, &args); + } +} + +// ============================================================================ +// Multi-thread Q load phase: read Q[G × neq1, DK] from DDR, convert F32→F16 +// (or deal F16 pairs), and write interleaved into vtcm_q_tiles. +// Each thread owns a disjoint range of row pairs; writes target distinct tile +// slots (r0 selects tile row, r1 selects intra-tile slot), so there is no +// write conflict. Padding fill (when n_rows_g < g_br) is done single-threaded +// by the caller before dispatching. +// ============================================================================ + +typedef struct { + struct hmx_fa_context * factx; + const struct htp_tensor * q; + uint32_t q_start; + uint32_t kv_head; + uint32_t ib3; + size_t n_rows_g; +} fa_q_load_args_t; + +static void fa_q_load_thread(unsigned int n, unsigned int i, void * data) { + fa_q_load_args_t * args = (fa_q_load_args_t *) data; + struct hmx_fa_context * factx = args->factx; + + const size_t n_rows_g = args->n_rows_g; + const size_t G = factx->G; + const size_t DK = factx->DK; + + // Partition row pairs across threads. Keep each thread's start even so r/r+1 + // are always in the same thread's range. + const size_t rows_per_t = hex_align_up(hmx_ceil_div(n_rows_g, n), 2); + const size_t start = (size_t) i * rows_per_t; + const size_t end = hex_smin(start + rows_per_t, n_rows_g); + + if (start >= n_rows_g) { + return; + } + + const struct htp_tensor * q = args->q; + const uint32_t q_start = args->q_start; + const uint32_t kv_head = args->kv_head; + const uint32_t ib3 = args->ib3; + + for (size_t r = start; r < end; r += 2) { + const bool next_row_valid = (r + 1) < n_rows_g; + + const size_t q_idx0 = (r + 0) / G; + const size_t h_idx0 = (r + 0) % G; + const size_t q_idx1 = (r + 1) / G; + const size_t h_idx1 = (r + 1) % G; + + const uint8_t * q_ptr0 = (const uint8_t *) q->data + (q_start + q_idx0) * q->nb[1] + + (kv_head * G + h_idx0) * q->nb[2] + ib3 * q->nb[3]; + const uint8_t * q_ptr1 = next_row_valid ? ((const uint8_t *) q->data + (q_start + q_idx1) * q->nb[1] + + (kv_head * G + h_idx1) * q->nb[2] + ib3 * q->nb[3]) : + NULL; + + size_t r0 = r / HMX_FP16_TILE_N_ROWS; + size_t r1 = r % HMX_FP16_TILE_N_ROWS; + __fp16 * out_base = factx->vtcm_q_tiles + r0 * HMX_FP16_TILE_N_ROWS * DK; + + if (factx->is_q_fp32) { + const HVX_Vector * pv_in0 = (const HVX_Vector *) q_ptr0; + const HVX_Vector * pv_in1 = q_ptr1 ? (const HVX_Vector *) q_ptr1 : NULL; + + for (uint32_t d = 0; d < DK / 32; ++d) { + HVX_Vector v0 = pv_in0[d]; + HVX_Vector v1 = pv_in1 ? pv_in1[d] : Q6_V_vzero(); + HVX_Vector v_hf = hvx_vec_f32_to_f16_shuff(v0, v1); + + HVX_Vector * out_tile = (HVX_Vector *) (out_base + d * HMX_FP16_TILE_N_ELMS); + out_tile[r1 / 2] = v_hf; + } + } else { + const HVX_Vector * pv_in0 = (const HVX_Vector *) q_ptr0; + const HVX_Vector * pv_in1 = q_ptr1 ? (const HVX_Vector *) q_ptr1 : NULL; + + for (uint32_t d = 0; d < DK / 64; ++d) { + HVX_Vector v0 = pv_in0[d]; + HVX_Vector v1 = pv_in1 ? pv_in1[d] : Q6_V_vzero(); + HVX_VectorPair vp = Q6_W_vshuff_VVR(v1, v0, -2); + + __fp16 * out_dual_tile = out_base + d * HMX_FP16_TILE_N_ELMS * 2; + HVX_Vector * pv_out0 = ((HVX_Vector *) out_dual_tile) + r1 / 2; + HVX_Vector * pv_out1 = pv_out0 + 16; + + *pv_out0 = Q6_V_lo_W(vp); + *pv_out1 = Q6_V_hi_W(vp); + } + } + } +} + +static void fa_phase_q_load(struct hmx_fa_context * factx, + const struct htp_tensor * q, + uint32_t q_start, + uint32_t kv_head, + uint32_t ib3, + size_t n_rows_g) { + worker_pool_context_t wp = factx->octx->ctx->worker_pool; + fa_q_load_args_t args = { factx, q, q_start, kv_head, ib3, n_rows_g }; + // Require >= 2 row pairs per thread so partitioning is worthwhile. + if (factx->n_threads > 1 && n_rows_g >= (size_t) (factx->n_threads * 2)) { + worker_pool_run_func(wp, fa_q_load_thread, &args, factx->n_threads); + } else { + fa_q_load_thread(1, 0, &args); + } +} + +// ============================================================================ +// Multi-thread O store phase: read O tiles from VTCM, convert F16->F32 (or +// deal F16 pairs), and write to strided DDR dst tensor. Each thread owns a +// disjoint row range; writes target distinct dst rows (different q_idx/h_idx +// pairs produced by r/G and r%G), so there is no write conflict. +// ============================================================================ + +typedef struct { + struct hmx_fa_context * factx; + const struct htp_tensor * dst; + const __fp16 * o_tile_src; + uint32_t q_start; + uint32_t kv_head; + uint32_t ib3; + size_t n_rows_g; +} fa_o_store_args_t; + +static void fa_o_store_thread(unsigned int n, unsigned int i, void * data) { + fa_o_store_args_t * args = (fa_o_store_args_t *) data; + struct hmx_fa_context * factx = args->factx; + + const size_t n_rows_g = args->n_rows_g; + const size_t G = factx->G; + const size_t DV = factx->DV; + + const size_t rows_per_t = hmx_ceil_div(n_rows_g, n); + const size_t start = (size_t) i * rows_per_t; + const size_t end = hex_smin(start + rows_per_t, n_rows_g); + + if (start >= n_rows_g) { + return; + } + + const struct htp_tensor * dst = args->dst; + const __fp16 * o_tile_src = args->o_tile_src; + const uint32_t q_start = args->q_start; + const uint32_t kv_head = args->kv_head; + const uint32_t ib3 = args->ib3; + + for (size_t r = start; r < end; ++r) { + const size_t q_idx = r / G; + const size_t h_idx = r % G; + + // FIX(dst-indexing): ggml_flash_attn_ext() creates dst as permute(0,2,1,3) -> + // [DV, n_heads, n_tokens, n_seq], so head stride is nb[1] and token stride is nb[2]. + uint8_t * dst_row = (uint8_t *) dst->data + (kv_head * G + h_idx) * dst->nb[1] + + (q_start + q_idx) * dst->nb[2] + ib3 * dst->nb[3]; + + size_t r0 = r / HMX_FP16_TILE_N_ROWS; + size_t r1 = r % HMX_FP16_TILE_N_ROWS; + const __fp16 * tile_row_base = o_tile_src + r0 * HMX_FP16_TILE_N_ROWS * DV; + + if (factx->is_dst_fp32) { + float * out = (float *) dst_row; + for (uint32_t d = 0; d < DV / 32; ++d) { + const HVX_Vector * in_tile = (const HVX_Vector *) (tile_row_base + d * HMX_FP16_TILE_N_ELMS); + HVX_VectorPair vp = hvx_vec_f16_to_f32_shuff(in_tile[r1 / 2]); + if (r1 % 2 == 0) { + *(HVX_UVector *) (out + d * 32) = Q6_V_lo_W(vp); + } else { + *(HVX_UVector *) (out + d * 32) = Q6_V_hi_W(vp); + } + } + } else { + __fp16 * out = (__fp16 *) dst_row; + for (uint32_t d = 0; d < DV / 64; ++d) { + const __fp16 * in_dual_tile = tile_row_base + d * HMX_FP16_TILE_N_ELMS * 2; + const HVX_Vector * pv_in0 = ((const HVX_Vector *) in_dual_tile) + r1 / 2; + const HVX_Vector * pv_in1 = pv_in0 + 16; + HVX_VectorPair vp = Q6_W_vdeal_VVR(*pv_in1, *pv_in0, -2); + if (r1 % 2 == 0) { + *(HVX_UVector *) (out + d * 64) = Q6_V_lo_W(vp); + } else { + *(HVX_UVector *) (out + d * 64) = Q6_V_hi_W(vp); + } + } + } + } +} + +static void fa_phase_o_store(struct hmx_fa_context * factx, + const struct htp_tensor * dst, + const __fp16 * o_tile_src, + uint32_t q_start, + uint32_t kv_head, + uint32_t ib3, + size_t n_rows_g) { + worker_pool_context_t wp = factx->octx->ctx->worker_pool; + fa_o_store_args_t args = { factx, dst, o_tile_src, q_start, kv_head, ib3, n_rows_g }; + if (factx->n_threads > 1 && n_rows_g >= (size_t) (factx->n_threads * 2)) { + worker_pool_run_func(wp, fa_o_store_thread, &args, factx->n_threads); + } else { + fa_o_store_thread(1, 0, &args); + } +} + +// ============================================================================ +// Multi-thread softmax phase + serial m/l update + build_D +// ============================================================================ + +typedef struct { + struct hmx_fa_context * factx; + size_t kv_rows; + size_t n_rows_g; + size_t n_col_tiles; + size_t n_tiles_per_bc; + size_t n_row_tiles; + size_t n_row_tiles_g_br; + uint32_t Bc; + uint32_t G; + uint32_t kv_head; + uint32_t kv_start; + uint32_t q_start; + uint32_t ib3; + bool has_alibi; // true when max_bias != 0 (need slope * mask + add) + + // ALiBi per-head slopes (indexed by GQA-merged row: slope[r] for r in [0, n_rows_g)) + // slope[r] = 1.0 when max_bias == 0 (no ALiBi) + // Pointer into hmx_fa_context.vtcm_slopes (sized to g_br) + __fp16 * slopes; + + // Mask info (preloaded before softmax) + const struct htp_tensor * mask; + const __fp16 * mask_vtcm; // VTCM mask buffer base (NULL = DDR fallback) + size_t mask_vtcm_row_stride; // elements (__fp16) per row in VTCM mask buffer +} fa_softmax_args_t; + +static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { + fa_softmax_args_t * args = (fa_softmax_args_t *) data; + struct hmx_fa_context * factx = args->factx; + + const size_t n_rows_g = args->n_rows_g; + const size_t kv_rows = args->kv_rows; + const size_t Bc = args->Bc; + const size_t G = args->G; + const size_t n_tiles_per_bc = args->n_tiles_per_bc; + const size_t n_row_vec_cnt = hmx_ceil_div(n_rows_g, 64); + + // Partition r_vec_idx across threads + const size_t vecs_per_t = hmx_ceil_div(n_row_vec_cnt, n); + const size_t vec_start = i * vecs_per_t; + const size_t vec_end = hex_smin(vec_start + vecs_per_t, n_row_vec_cnt); + + if (vec_start >= n_row_vec_cnt) { + return; + } + + // Per-thread row scratch: thread i uses bufs at offset i * 2 * stride + const size_t row_buf_stride = factx->row_buf_stride; + HVX_Vector * my_row_buf0 = factx->vtcm_row_bufs + i * 2 * row_buf_stride; + HVX_Vector * my_row_buf1 = my_row_buf0 + row_buf_stride; + + const HVX_Vector v_neg_inf = Q6_Vh_vsplat_R(0xfbff); + + // Per-row accumulators: each fp16 lane in a 64-lane vector holds one row's scalar. + // CONTRACT: lane bits must be IEEE fp16 (hf), never qf16 — qf16 uses a different + // bit layout, so a later hf-domain read would silently produce wrong values. + // Convert first via Q6_Vhf_equals_Vqf16(). For reference: vtcm_m_vec/vtcm_s_rowmax + // are hf; vtcm_l_vec is qf16 — don't mix them up. + + for (size_t r_vec_idx = vec_start; r_vec_idx < vec_end; ++r_vec_idx) { + HVX_Vector rowmax_acc_v = v_neg_inf; + HVX_Vector rowsum_acc_v = Q6_V_vzero(); + HVX_Vector m_prev_v = factx->vtcm_m_vec[r_vec_idx]; + + for (int r_vec_off = 0; r_vec_off < 64; r_vec_off += 2) { + int r = r_vec_idx * 64 + r_vec_off; + if (r >= (int) hex_align_up(n_rows_g, 2)) { + break; + } + + int r0 = r / HMX_FP16_TILE_N_ROWS; + int r1 = r % HMX_FP16_TILE_N_ROWS; + + const __fp16 * s_ld_base = factx->vtcm_s_tiles + r0 * HMX_FP16_TILE_N_ROWS * Bc; + __fp16 * p_st_base = factx->vtcm_p_tiles + r0 * HMX_FP16_TILE_N_ROWS * Bc; + + // Decode 2 rows from S tiles into per-thread row buffers + HVX_Vector * pv_row_buf0 = my_row_buf0; + HVX_Vector * pv_row_buf1 = my_row_buf1; + for (size_t c = 0; c < kv_rows; c += 64) { + const __fp16 * in_dual_tile = s_ld_base + (c / 64) * HMX_FP16_TILE_N_ELMS * 2; + const HVX_Vector * pv_s_in0 = ((const HVX_Vector *) in_dual_tile) + r1 / 2; + const HVX_Vector * pv_s_in1 = pv_s_in0 + 16; + + HVX_VectorPair vp_s_dual_row = Q6_W_vdeal_VVR(*pv_s_in1, *pv_s_in0, -2); + *pv_row_buf0++ = Q6_V_lo_W(vp_s_dual_row); + *pv_row_buf1++ = Q6_V_hi_W(vp_s_dual_row); + } + + // Apply softcap if enabled (in F32 precision) + if (factx->logit_softcap != 0.0f) { + // When EXP2_HF is on, fold log2(e) into v_cap so the output lands in + // log2(e)-scaled space for the downstream exp2. log2(e) is kept OUT + // of qk_scale in this configuration (see scale setup) so tanh sees + // the physical QK/(√d·c) argument. + float cap = factx->logit_softcap; +#ifdef HMX_FA_USE_EXP2_HF + cap *= 1.44269504f; // log2(e) +#endif + const HVX_Vector v_cap = hvx_vec_splat_f32(cap); + for (size_t c = 0; c < kv_rows; c += 64) { + size_t ci = c / 64; + + HVX_VectorPair r0_f32 = hvx_vec_f16_to_f32(my_row_buf0[ci]); + HVX_Vector t0_lo = hvx_vec_tanh_f32(Q6_V_lo_W(r0_f32)); + HVX_Vector t0_hi = hvx_vec_tanh_f32(Q6_V_hi_W(r0_f32)); + t0_lo = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(t0_lo, v_cap)); + t0_hi = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(t0_hi, v_cap)); + my_row_buf0[ci] = hvx_vec_f32_to_f16(t0_lo, t0_hi); + + HVX_VectorPair r1_f32 = hvx_vec_f16_to_f32(my_row_buf1[ci]); + HVX_Vector t1_lo = hvx_vec_tanh_f32(Q6_V_lo_W(r1_f32)); + HVX_Vector t1_hi = hvx_vec_tanh_f32(Q6_V_hi_W(r1_f32)); + t1_lo = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(t1_lo, v_cap)); + t1_hi = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(t1_hi, v_cap)); + my_row_buf1[ci] = hvx_vec_f32_to_f16(t1_lo, t1_hi); + } + } + + // Apply mask & compute rowmax(S) + // + // Optimizations over baseline: + // A. No-ALiBi fast path: when max_bias==0 (slope≡1.0), skip the + // slope multiplication — still add mask (additive bias) but + // avoid the mul_f16_f16. Saves 2 ops/dual-row vs ALiBi path. + // B. GQA mask row dedup: G consecutive Q rows share one mask row + // (qi = r / G). Reuse mask vector when qi is unchanged between + // row0 and row1 (saves ~75% of VTCM loads for G=4). + + // ALiBi slopes — only needed when has_alibi (scheme A) + HVX_Vector v_slope0, v_slope1; + if (args->has_alibi) { + v_slope0 = hvx_vec_splat_f16(args->slopes[r + 0]); + v_slope1 = (r + 1 < (int) n_rows_g) ? hvx_vec_splat_f16(args->slopes[r + 1]) : Q6_V_vzero(); + } + + const HVX_Vector v_threshold = Q6_Vh_vsplat_R(0xcc00); // fp16 -16.0 (hoisted outside for-c) + + HVX_Vector v_s_rowmax0 = v_neg_inf; + HVX_Vector v_s_rowmax1 = v_neg_inf; + for (size_t c = 0; c < kv_rows; c += 64) { + size_t ci = c / 64; + const size_t ne = hex_smin(kv_rows - c, 64); + HVX_VectorPred q_tail_keep = Q6_Q_vsetq2_R(ne * sizeof(__fp16)); + + if (args->mask) { + HVX_Vector v_mask0, v_mask1; + + if (args->mask_vtcm) { + // Read mask from VTCM buffer (DMA'd per KV block). + // GQA dedup (scheme B): skip load when qi unchanged. + const size_t qi0 = (r + 0) / G; + v_mask0 = *(const HVX_UVector *) (args->mask_vtcm + qi0 * args->mask_vtcm_row_stride + c); + v_mask1 = v_neg_inf; + if (r + 1 < (int) n_rows_g) { + const size_t qi1 = (r + 1) / G; + if (qi1 == qi0) { + v_mask1 = v_mask0; // scheme B: reuse — same mask row + } else { + v_mask1 = *(const HVX_UVector *) (args->mask_vtcm + qi1 * args->mask_vtcm_row_stride + c); + } + } + } else { + // Fallback: read mask directly from DDR (when mask->ne[2] > 1). + const struct htp_tensor * mask = args->mask; + const size_t q_idx0 = args->q_start + ((r + 0) / G); + const size_t h_idx0 = args->kv_head * G + (r + 0) % G; + const uint32_t im2_0 = h_idx0 % mask->ne[2]; + const uint32_t im3_0 = args->ib3 % mask->ne[3]; + + const __fp16 * m0_ptr = (const __fp16 *) ((const uint8_t *) mask->data + q_idx0 * mask->nb[1] + + im2_0 * mask->nb[2] + im3_0 * mask->nb[3]) + args->kv_start + c; + v_mask0 = *(const HVX_UVector *) m0_ptr; + v_mask1 = v_neg_inf; + + if (r + 1 < (int) n_rows_g) { + const size_t q_idx1 = args->q_start + ((r + 1) / G); + if (q_idx1 == q_idx0) { + // scheme B: same mask row in DDR path + v_mask1 = v_mask0; + } else { + const size_t h_idx1 = args->kv_head * G + (r + 1) % G; + const uint32_t im2_1 = h_idx1 % mask->ne[2]; + const uint32_t im3_1 = args->ib3 % mask->ne[3]; + const __fp16 * m1_ptr = (const __fp16 *) ((const uint8_t *) mask->data + q_idx1 * mask->nb[1] + + im2_1 * mask->nb[2] + im3_1 * mask->nb[3]) + args->kv_start + c; + v_mask1 = *(const HVX_UVector *) m1_ptr; + } + } + } + + // Threshold: mask values below -16.0 are treated as -inf (causal mask). + HVX_VectorPred q_keep0 = Q6_Q_and_QQ(Q6_Q_vcmp_gt_VhfVhf(v_mask0, v_threshold), q_tail_keep); + HVX_VectorPred q_keep1 = Q6_Q_and_QQ(Q6_Q_vcmp_gt_VhfVhf(v_mask1, v_threshold), q_tail_keep); + + if (args->has_alibi) { + // ALiBi path: S += slope * mask (full mul + add) + HVX_Vector v_sm0 = hvx_vec_mul_f16_f16(v_mask0, v_slope0); + HVX_Vector v_sm1 = hvx_vec_mul_f16_f16(v_mask1, v_slope1); + my_row_buf0[ci] = Q6_V_vmux_QVV(q_keep0, hvx_vec_add_f16_f16(my_row_buf0[ci], v_sm0), v_neg_inf); + my_row_buf1[ci] = Q6_V_vmux_QVV(q_keep1, hvx_vec_add_f16_f16(my_row_buf1[ci], v_sm1), v_neg_inf); + } else { + // No-ALiBi fast path (scheme A): slope≡1.0, skip the mul + // but still add mask (additive positional bias). vmux + // clamps mask < -16 to -inf as a numerical safeguard. + my_row_buf0[ci] = Q6_V_vmux_QVV(q_keep0, hvx_vec_add_f16_f16(my_row_buf0[ci], v_mask0), v_neg_inf); + my_row_buf1[ci] = Q6_V_vmux_QVV(q_keep1, hvx_vec_add_f16_f16(my_row_buf1[ci], v_mask1), v_neg_inf); + } + } else { + if (ne < 64) { + my_row_buf0[ci] = Q6_V_vmux_QVV(q_tail_keep, my_row_buf0[ci], v_neg_inf); + my_row_buf1[ci] = Q6_V_vmux_QVV(q_tail_keep, my_row_buf1[ci], v_neg_inf); + } + } + + v_s_rowmax0 = Q6_Vhf_vmax_VhfVhf(v_s_rowmax0, my_row_buf0[ci]); + v_s_rowmax1 = Q6_Vhf_vmax_VhfVhf(v_s_rowmax1, my_row_buf1[ci]); + } + + v_s_rowmax0 = hvx_vec_reduce_max_f16(v_s_rowmax0); + v_s_rowmax1 = hvx_vec_reduce_max_f16(v_s_rowmax1); + + // Splat m_prev[r], m_prev[r+1] from the per-row accumulator. + // vror brings the target lane to lane 0, then extract + re-splat. + HVX_Vector v_m_prev0 = hvx_vec_splat_f16(hvx_vec_get_f16(Q6_V_vror_VR(m_prev_v, r_vec_off * 2))); + HVX_Vector v_m_prev1 = hvx_vec_splat_f16(hvx_vec_get_f16(Q6_V_vror_VR(m_prev_v, (r_vec_off + 1) * 2))); + + // HVX max — both operands are splats, so result is splat of m_new. + HVX_Vector v_dup_m0 = Q6_Vhf_vmax_VhfVhf(v_m_prev0, v_s_rowmax0); + HVX_Vector v_dup_m1 = Q6_Vhf_vmax_VhfVhf(v_m_prev1, v_s_rowmax1); + + // Insert row r, r+1 rowmax into rowmax_acc_v via 2-byte-wide vmux. + // Byte ranges: lane0 = [r_vec_off*2 .. r_vec_off*2+1], lane1 shifted by 2. + // vsetq2 handles the n=128 corner case when r_vec_off reaches 62. + { + HVX_VectorPred p_start = Q6_Q_vsetq_R(r_vec_off * 2); + HVX_VectorPred p_mid = Q6_Q_vsetq_R((r_vec_off + 1) * 2); + HVX_VectorPred p_end = Q6_Q_vsetq2_R((r_vec_off + 2) * 2); + HVX_VectorPred p_lane0 = Q6_Q_and_QQn(p_mid, p_start); + HVX_VectorPred p_lane1 = Q6_Q_and_QQn(p_end, p_mid); + rowmax_acc_v = Q6_V_vmux_QVV(p_lane0, v_dup_m0, rowmax_acc_v); + rowmax_acc_v = Q6_V_vmux_QVV(p_lane1, v_dup_m1, rowmax_acc_v); + } + + // Compute P = exp(S - m_new), using HVX exp + const HVX_Vector v_zero = Q6_V_vzero(); + HVX_Vector v_p_rowsum0 = v_zero; + HVX_Vector v_p_rowsum1 = v_zero; + +#ifdef HMX_FA_USE_EXP2_HF + // FP16 exp2 polynomial path (matches htp-ops-lib flash_attn.c): + // P = exp2(S - m_new) + for (size_t c = 0; c < kv_rows; c += 64) { + size_t ci = c / 64; + HVX_Vector v_s_minus_m0 = Q6_Vqf16_vsub_VhfVhf(my_row_buf0[ci], v_dup_m0); + HVX_Vector v_s_minus_m1 = Q6_Vqf16_vsub_VhfVhf(my_row_buf1[ci], v_dup_m1); + + HVX_Vector v_p_row0_hf = hvx_exp2_hf(Q6_Vhf_equals_Vqf16(v_s_minus_m0)); + HVX_Vector v_p_row1_hf = hvx_exp2_hf(Q6_Vhf_equals_Vqf16(v_s_minus_m1)); +#else + // F32 exp path: qf16 → f32 → exp → f32 → f16. Higher precision, + for (size_t c = 0; c < kv_rows; c += 64) { + size_t ci = c / 64; + HVX_Vector v_s_minus_m0 = Q6_Vqf16_vsub_VhfVhf(my_row_buf0[ci], v_dup_m0); + HVX_Vector v_s_minus_m1 = Q6_Vqf16_vsub_VhfVhf(my_row_buf1[ci], v_dup_m1); + + HVX_VectorPair vp0 = hvx_vec_f16_to_f32_shuff(Q6_Vhf_equals_Vqf16(v_s_minus_m0)); + HVX_Vector p0_lo = hvx_vec_exp_f32(Q6_V_lo_W(vp0)); + HVX_Vector p0_hi = hvx_vec_exp_f32(Q6_V_hi_W(vp0)); + HVX_Vector v_p_row0_hf = hvx_vec_f32_to_f16_shuff(p0_lo, p0_hi); + + HVX_VectorPair vp1 = hvx_vec_f16_to_f32_shuff(Q6_Vhf_equals_Vqf16(v_s_minus_m1)); + HVX_Vector p1_lo = hvx_vec_exp_f32(Q6_V_lo_W(vp1)); + HVX_Vector p1_hi = hvx_vec_exp_f32(Q6_V_hi_W(vp1)); + HVX_Vector v_p_row1_hf = hvx_vec_f32_to_f16_shuff(p1_lo, p1_hi); +#endif + // Write P to tile format. Dual-tile pattern assumes Bc is a + // multiple of 64 (enforced by bc_unit=64 in hmx_fa_find_chunk_size), + // so both tile halves are always in the current r0 block. + __fp16 * out_dual_tile = p_st_base + (c / 64) * HMX_FP16_TILE_N_ELMS * 2; + HVX_Vector * pv_p_out0 = ((HVX_Vector *) out_dual_tile) + r1 / 2; + HVX_Vector * pv_p_out1 = pv_p_out0 + 16; + + HVX_VectorPair vp_p_dual = Q6_W_vshuff_VVR(v_p_row1_hf, v_p_row0_hf, -2); + *pv_p_out0 = Q6_V_lo_W(vp_p_dual); + *pv_p_out1 = Q6_V_hi_W(vp_p_dual); + + HVX_VectorPair vp_p0 = hvx_vec_f16_to_f32_shuff(v_p_row0_hf); + HVX_VectorPair vp_p1 = hvx_vec_f16_to_f32_shuff(v_p_row1_hf); + + v_p_rowsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(v_p_rowsum0, Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(vp_p0), Q6_V_hi_W(vp_p0))); + v_p_rowsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(v_p_rowsum1, Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(vp_p1), Q6_V_hi_W(vp_p1))); + } + + HVX_Vector rowsum0_sf = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(v_p_rowsum0)); + HVX_Vector rowsum1_sf = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(v_p_rowsum1)); + { + // Both inputs are f32 splats, so the f32->f16 output is an fp16 splat. + HVX_Vector rv0_v = hvx_vec_f32_to_f16(rowsum0_sf, rowsum0_sf); + HVX_Vector rv1_v = hvx_vec_f32_to_f16(rowsum1_sf, rowsum1_sf); + + HVX_VectorPred p_start = Q6_Q_vsetq_R(r_vec_off * 2); + HVX_VectorPred p_mid = Q6_Q_vsetq_R((r_vec_off + 1) * 2); + HVX_VectorPred p_end = Q6_Q_vsetq2_R((r_vec_off + 2) * 2); + HVX_VectorPred p_lane0 = Q6_Q_and_QQn(p_mid, p_start); + HVX_VectorPred p_lane1 = Q6_Q_and_QQn(p_end, p_mid); + rowsum_acc_v = Q6_V_vmux_QVV(p_lane0, rv0_v, rowsum_acc_v); + rowsum_acc_v = Q6_V_vmux_QVV(p_lane1, rv1_v, rowsum_acc_v); + } + } + + factx->vtcm_s_rowmax[r_vec_idx] = rowmax_acc_v; + factx->vtcm_p_rowsum[r_vec_idx] = rowsum_acc_v; + } +} + +// Serial m/l update + build_D. Must run after softmax barrier (s_rowmax written by all threads). +// +// noinline: function boundary acts as a hard compiler barrier so the (size_t)addr scatter +// intrinsics inside cannot be hoisted past the call site. Mirrors the structural protection +// matmul gets for free via worker_pool function-pointer dispatch. Without this, the compiler +// can reorder the scatter past the subsequent hmx_queue_push and the HMX-queue worker thread +// reads stale VTCM (PPL → ~vocab-size). +static __attribute__((noinline)) void fa_ml_update_and_build_d(struct hmx_fa_context * factx, + size_t n_rows_g, + size_t n_row_tiles, + size_t n_row_tiles_g_br) { + // Reuse s_rowmax buffer for exp(m_diff) — safe because softmax is fully complete + HVX_Vector * const mvec_exp_m_diff = factx->vtcm_s_rowmax; + + const size_t n_row_vec_cnt = hmx_ceil_div(n_rows_g, 64); + for (size_t i = 0; i < n_row_vec_cnt; ++i) { + HVX_Vector v_m_prev = factx->vtcm_m_vec[i]; + HVX_Vector v_m_curr = Q6_Vhf_vmax_VhfVhf(v_m_prev, factx->vtcm_s_rowmax[i]); + HVX_Vector v_m_diff = Q6_Vqf16_vsub_VhfVhf(v_m_prev, v_m_curr); + +#ifdef HMX_FA_USE_EXP2_HF + // Base-2 path: must match P = exp2(S - m_new) in fa_softmax_thread. + HVX_Vector v_exp_m_diff = hvx_exp2_hf(Q6_Vhf_equals_Vqf16(v_m_diff)); +#else + HVX_VectorPair vp_diff = hvx_vec_f16_to_f32_shuff(Q6_Vhf_equals_Vqf16(v_m_diff)); + HVX_Vector exp_lo = hvx_vec_exp_f32(Q6_V_lo_W(vp_diff)); + HVX_Vector exp_hi = hvx_vec_exp_f32(Q6_V_hi_W(vp_diff)); + HVX_Vector v_exp_m_diff = hvx_vec_f32_to_f16_shuff(exp_lo, exp_hi); +#endif + + HVX_Vector v_l_curr = Q6_Vqf16_vmpy_Vqf16Vhf(factx->vtcm_l_vec[i], v_exp_m_diff); + v_l_curr = Q6_Vqf16_vadd_Vqf16Vhf(v_l_curr, factx->vtcm_p_rowsum[i]); + + factx->vtcm_m_vec[i] = v_m_curr; + factx->vtcm_l_vec[i] = v_l_curr; + mvec_exp_m_diff[i] = v_exp_m_diff; + } + + // Build diagonal tile D = diag(exp(m_diff)) + const HVX_Vector v_offsets = *(const HVX_Vector *) d_tile_scatter_offsets; + const HVX_VectorPred q_32_mask = Q6_Q_vsetq_R(32 * sizeof(__fp16)); + for (size_t i = 0; i < n_row_tiles; ++i) { + const HVX_Vector v_content = Q6_V_vror_VR(mvec_exp_m_diff[i / 2], (i % 2) * 64); + __fp16 * out_base = factx->vtcm_d_tiles + i * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + Q6_vscatter_QRMVhV(q_32_mask, (size_t) out_base, HMX_FP16_TILE_SIZE - 1, v_offsets, v_content); + // Compiler barrier — Q6_vscatter takes (size_t)addr; without this the + // compiler may not recognize the volatile read below as aliasing and + // could reorder it before the scatter, defeating the HW drain. + __asm__ __volatile__("" ::: "memory"); + // Per-tile drain: scatter regions are disjoint (stride > tile size), + // so a single drain at tile 0 does NOT retire later tiles' entries. + (void) *(volatile HVX_Vector *) out_base; + } +} + +// Build D = diag(1/l) tile for the final O = D @ O normalization. +// +// noinline: same rationale as fa_ml_update_and_build_d — keeps Q6_vscatter from +// being hoisted past the subsequent hmx_queue_push at the o_norm call site. +static __attribute__((noinline)) void fa_build_d_diag_inv_l(struct hmx_fa_context * factx, + size_t n_row_tiles, + size_t n_row_tiles_g_br) { + const HVX_Vector v_offsets = *(const HVX_Vector *) d_tile_scatter_offsets; + const HVX_VectorPred q_32_mask = Q6_Q_vsetq_R(32 * sizeof(__fp16)); + const HVX_Vector one = hvx_vec_splat_f32(1.0f); + + HVX_Vector v_content = Q6_V_vzero(); + for (size_t i = 0; i < n_row_tiles; ++i) { + if ((i % 2) == 0) { + HVX_Vector v_l_hf = Q6_Vhf_equals_Vqf16(factx->vtcm_l_vec[i / 2]); + HVX_VectorPair vp_l = hvx_vec_f16_to_f32_shuff(v_l_hf); + HVX_Vector inv_lo = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(one, hvx_vec_inverse_f32(Q6_V_lo_W(vp_l)))); + HVX_Vector inv_hi = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(one, hvx_vec_inverse_f32(Q6_V_hi_W(vp_l)))); + v_content = hvx_vec_f32_to_f16_shuff(inv_lo, inv_hi); + } else { + v_content = Q6_V_vror_VR(v_content, 64); + } + + __fp16 * out_base = factx->vtcm_d_tiles + i * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + Q6_vscatter_QRMVhV(q_32_mask, (size_t) out_base, HMX_FP16_TILE_SIZE - 1, v_offsets, v_content); + // Compiler barrier — see fa_ml_update_and_build_d for rationale. + __asm__ __volatile__("" ::: "memory"); + (void) *(volatile HVX_Vector *) out_base; + } +} + +// Combined: multi-thread softmax -> barrier -> serial m/l update + build_D +static void fa_phase_softmax_and_build_d(struct hmx_fa_context * factx, + fa_softmax_args_t * sargs, + size_t n_row_tiles, + size_t n_row_tiles_g_br) { + worker_pool_context_t wp = factx->octx->ctx->worker_pool; + const size_t n_row_vec_cnt = hmx_ceil_div(sargs->n_rows_g, 64); + + if (factx->n_threads > 1 && n_row_vec_cnt >= 2) { + uint32_t n_use = (uint32_t) hex_smin((size_t) factx->n_threads, n_row_vec_cnt); + worker_pool_run_func(wp, fa_softmax_thread, sargs, n_use); + } else { + fa_softmax_thread(1, 0, sargs); + } + // barrier implicit in worker_pool_run_func return + + fa_ml_update_and_build_d(factx, sargs->n_rows_g, n_row_tiles, n_row_tiles_g_br); +} + +// ============================================================================ +// HMX job structs and worker functions +// ============================================================================ + +typedef struct { + const __fp16 * q_tiles; + const __fp16 * k_tiles; + __fp16 * s_tiles; + size_t n_row_tiles; + size_t n_col_tiles; + size_t n_dot_tiles; // DK / 32 + size_t n_tiles_per_bc; + uint8_t * hmx_scales; +} hmx_fa_qk_job_t; + +static void hmx_fa_qk_dot_worker(void * data) { + hmx_fa_qk_job_t * job = (hmx_fa_qk_job_t *) data; + const size_t n_row_tiles = job->n_row_tiles; + const size_t n_col_tiles = job->n_col_tiles; + const size_t n_dot_tiles = job->n_dot_tiles; + const size_t n_tiles_per_bc = job->n_tiles_per_bc; + const __fp16 * restrict q_tiles = job->q_tiles; + const __fp16 * restrict k_tiles = job->k_tiles; + __fp16 * restrict s_tiles = job->s_tiles; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(n_dot_tiles > 0); + + Q6_bias_mxmem2_A((void *) job->hmx_scales); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < n_col_tiles; ++c) { + const __fp16 * row_tiles = q_tiles + r * HMX_FP16_TILE_N_ROWS * n_dot_tiles * HMX_FP16_TILE_N_COLS; + const __fp16 * col_tiles = k_tiles + c * HMX_FP16_TILE_N_COLS * n_dot_tiles * HMX_FP16_TILE_N_COLS; + __fp16 * out_tile = s_tiles + (r * n_tiles_per_bc + c) * HMX_FP16_TILE_N_ELMS; + + for (size_t k = 0; k < n_dot_tiles; ++k) { + Q6_activation_hf_mxmem_RR((unsigned int) row_tiles, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) col_tiles, 2047); + row_tiles += HMX_FP16_TILE_N_ELMS; + col_tiles += HMX_FP16_TILE_N_ELMS; + } + Q6_mxmem_AR_after_hf(out_tile, 0); + } + } +} + +typedef struct { + __fp16 * o_curr; + const __fp16 * o_prev; + const __fp16 * p_tiles; + const __fp16 * v_tiles; + const __fp16 * d_tiles; + uint8_t * hmx_scales; + size_t n_row_tiles; + size_t n_col_tiles; + size_t n_row_tiles_g_br; + size_t n_tiles_per_bc; + size_t DV; +} hmx_fa_o_update_job_t; + +static void hmx_fa_o_update_worker(void * data) { + hmx_fa_o_update_job_t * job = (hmx_fa_o_update_job_t *) data; + const size_t n_row_tiles = job->n_row_tiles; + const size_t n_col_tiles = job->n_col_tiles; + const size_t n_row_tiles_g_br = job->n_row_tiles_g_br; + const size_t n_tiles_per_bc = job->n_tiles_per_bc; + const size_t DV_tiles = job->DV / 32; + const __fp16 * restrict d_tiles = job->d_tiles; + const __fp16 * restrict p_tiles = job->p_tiles; + const __fp16 * restrict v_tiles = job->v_tiles; + const __fp16 * restrict o_prev = job->o_prev; + __fp16 * restrict o_curr = job->o_curr; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(DV_tiles > 0); + + Q6_bias_mxmem2_A((void *) job->hmx_scales); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < DV_tiles; ++c) { + // D[r,r] @ O_prev[r,c] — only the diagonal tile + const __fp16 * d_diag = d_tiles + r * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + const __fp16 * o_rc = o_prev + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047); + + // P @ V (accumulate on same accumulator) + const __fp16 * p_tile_in = p_tiles + (r * n_tiles_per_bc) * HMX_FP16_TILE_N_ELMS; + const __fp16 * v_tile_in = v_tiles + (c * n_tiles_per_bc) * HMX_FP16_TILE_N_ELMS; + for (size_t k = 0; k < n_col_tiles; ++k) { + Q6_activation_hf_mxmem_RR((unsigned int) p_tile_in, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) v_tile_in, 2047); + p_tile_in += HMX_FP16_TILE_N_ELMS; + v_tile_in += HMX_FP16_TILE_N_ELMS; + } + + __fp16 * o_tile_out = o_curr + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + Q6_mxmem_AR_after_hf(o_tile_out, 0); + } + } +} + +typedef struct { + __fp16 * o_curr; // output (row-major tile layout) + const __fp16 * o_prev; // input (column-major tile layout) + const __fp16 * d_tiles; // diag(1/l) tiles + uint8_t * hmx_scales; + size_t n_row_tiles; + size_t n_row_tiles_g_br; + size_t DV; +} hmx_fa_o_norm_job_t; + +static void hmx_fa_o_norm_worker(void * data) { + hmx_fa_o_norm_job_t * job = (hmx_fa_o_norm_job_t *) data; + const size_t n_row_tiles = job->n_row_tiles; + const size_t n_row_tiles_g_br = job->n_row_tiles_g_br; + const size_t DV_tiles = job->DV / 32; + const __fp16 * restrict d_tiles = job->d_tiles; + const __fp16 * restrict o_prev = job->o_prev; + __fp16 * restrict o_curr = job->o_curr; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(DV_tiles > 0); + + Q6_bias_mxmem2_A((void *) job->hmx_scales); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < DV_tiles; ++c) { + const __fp16 * d_diag = d_tiles + r * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + const __fp16 * o_rc = o_prev + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + __fp16 * o_out = o_curr + (r * DV_tiles + c) * HMX_FP16_TILE_N_ELMS; + + Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047); + Q6_mxmem_AR_after_hf(o_out, 0); + } + } +} + +// Populate per-GQA-row ALiBi slopes for a given KV head. +// Row r in the GQA-merged block maps to Q head h = kv_head * G + r % G. +// slope(h) = m0^(h+1) when h < n_head_log2, else m1^(2*(h-n_head_log2)+1). +// When max_bias == 0, all slopes are 1.0 (no ALiBi). +static __attribute__((noinline)) void fa_compute_slopes(fa_softmax_args_t * sargs, + const struct hmx_fa_context * factx, + uint32_t kv_head, + size_t n_rows_g) { + if (factx->max_bias == 0.0f) { + for (size_t r = 0; r < n_rows_g; ++r) { + sargs->slopes[r] = 1.0f; + } + return; + } + + const uint32_t G = factx->G; + const uint32_t n_head_log2 = factx->n_head_log2; + const float m0 = factx->m0; + const float m1 = factx->m1; + + for (size_t r = 0; r < n_rows_g; ++r) { + const uint32_t h = kv_head * G + r % G; + sargs->slopes[r] = (h < n_head_log2) ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1); + } +} + +// ============================================================================ +// Core HMX flash attention algorithm (GQA-merged) +// ============================================================================ + +int hmx_flash_attn_ext(struct htp_ops_context * octx) { + const struct htp_tensor * q = octx->src[0]; + const struct htp_tensor * k = octx->src[1]; + const struct htp_tensor * v = octx->src[2]; + const struct htp_tensor * mask = (octx->src[3] && octx->src[3]->data) ? octx->src[3] : NULL; + const struct htp_tensor * dst = octx->dst; + + struct htp_context * const ctx = octx->ctx; + + if (!ctx->hmx_enabled) { + return HTP_STATUS_NO_SUPPORT; + } + + // Dimensions + const uint32_t neq0 = q->ne[0]; // head_dim (DK) + const uint32_t neq1 = q->ne[1]; // n_tokens + const uint32_t neq2 = q->ne[2]; // n_heads + const uint32_t neq3 = q->ne[3]; // n_seqs + + const uint32_t nek0 = k->ne[0]; // head_dim + const uint32_t nek1 = k->ne[1]; // kv_len + + const uint32_t nev0 = v->ne[0]; // head_dim (DV) + + const uint32_t DK = neq0; + const uint32_t DV = nev0; + + // HMX requires head_dim to be multiple of 32 + if (DK % 32 != 0 || DV % 32 != 0) { + return HTP_STATUS_NO_SUPPORT; + } + if (neq1 < 32) { + return HTP_STATUS_NO_SUPPORT; + } + + // GQA factor + const uint32_t n_kv_heads = k->ne[2]; + const uint32_t G = neq2 / n_kv_heads; + + // Thread count for multi-thread HVX phases + const uint32_t n_threads = octx->n_threads; + + // Compute dynamic block sizes (GQA-aware, accounting for per-thread row bufs) + size_t Br, Bc; + const size_t vtcm_budget = ctx->vtcm_size; + if (hmx_fa_find_chunk_size(&Br, &Bc, G, DK, DV, neq1, nek1, vtcm_budget, n_threads) != 0) { + return HTP_STATUS_VTCM_TOO_SMALL; + } + + const size_t g_br = hex_align_up(G * Br, HMX_FP16_TILE_N_ROWS); + + const uint32_t n_kv_blocks = (nek1 + Bc - 1) / Bc; + const bool use_pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2); + + FARF(HIGH, "hmx-fa: neq1=%u nek1=%u DK=%u DV=%u G=%u Br=%zu Bc=%zu g_br=%zu n_kv_blocks=%u pipeline=%d vtcm=%zu", + neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, use_pipeline, vtcm_budget); + + // ======== Build context ======== + struct hmx_fa_context factx; + memset(&factx, 0, sizeof(factx)); + factx.octx = octx; + factx.n_threads = octx->ctx->n_threads; + factx.DK = DK; + factx.DV = DV; + factx.n_kv = nek1; + factx.n_kv_heads = n_kv_heads; + factx.n_heads = neq2; + factx.G = G; + factx.neq1 = neq1; + factx.Br = (uint32_t) Br; + factx.Bc = (uint32_t) Bc; + factx.g_br = (uint32_t) g_br; + factx.n_kv_blocks = n_kv_blocks; + factx.is_q_fp32 = (q->type == HTP_TYPE_F32); + factx.is_dst_fp32 = (dst->type == HTP_TYPE_F32); + factx.use_pipeline = use_pipeline; + factx.mask_broadcast = (mask != NULL && mask->ne[2] == 1); + + // Extract op parameters (mutable during softcap adjustment, then stored as const in factx) + float scale = 1.0f, max_bias = 0.0f, logit_softcap = 0.0f; + memcpy(&scale, (float *) octx->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float)); + + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } + +#ifdef HMX_FA_USE_EXP2_HF + // Pre-bake log2(e) into qk_scale so HMX-produced S tiles are in log2(e)-scaled + // space. Then exp2(S - m) in the softmax equals base-e exp((S - m) / log2(e)), + // preserving ggml's base-e softmax semantics. Matches htp-ops-lib flash_attn.c. + // + // When softcap is active we cannot pre-bake log2(e) here — it would land inside + // the tanh argument and shift the softcap knee from x≈c to x≈c/log2(e), giving + // numerically wrong softcapped values. Instead fold log2(e) into the post-tanh + // multiplier (see softcap block: v_cap absorbs log2(e)). + if (logit_softcap == 0.0f) { + scale *= 1.44269504f; // log2(e) + } +#endif + + factx.scale = scale; + factx.max_bias = max_bias; + factx.logit_softcap = logit_softcap; + + factx.n_head_log2 = 1u << (uint32_t) floor(log2(neq2)); + factx.m0 = powf(2.0f, -(max_bias) / factx.n_head_log2); + factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2); + + // ======== VTCM allocation (GQA-aware) ======== + const size_t q_tile_bytes = hex_align_up(g_br * DK * sizeof(__fp16), 4096); + const size_t o_tile_bytes = hex_align_up(g_br * DV * sizeof(__fp16), 4096); + const size_t k_dma_bytes = hex_align_up(Bc * DK * sizeof(__fp16), 4096); + const size_t v_dma_bytes = hex_align_up(Bc * DV * sizeof(__fp16), 4096); + const size_t k_tile_bytes = hex_align_up(Bc * DK * sizeof(__fp16), 4096); + const size_t v_tile_bytes = hex_align_up(Bc * DV * sizeof(__fp16), 4096); + const size_t s_tile_bytes = hex_align_up(g_br * Bc * sizeof(__fp16), 4096); + const size_t d_tile_bytes = hex_align_up(g_br * g_br * sizeof(__fp16), 4096); + const size_t col_vec_bytes = hex_align_up(g_br * sizeof(__fp16), 256); + const size_t row_vec_bytes = hex_align_up(Bc * sizeof(__fp16), 256); + const size_t m_line_bytes = hex_align_up(Bc * sizeof(__fp16), 128); + const size_t m_buf_bytes = hex_align_up(Br * m_line_bytes, 4096); + const size_t slopes_bytes = hex_align_up(g_br * sizeof(__fp16), 128); + + uint8_t * vtcm_cur = ctx->vtcm_base; + + factx.vtcm_q_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, q_tile_bytes); + factx.vtcm_o_tiles[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, o_tile_bytes); + factx.vtcm_o_tiles[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, o_tile_bytes); + factx.vtcm_k_fp16[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_dma_bytes); + factx.vtcm_k_fp16[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_dma_bytes); + factx.vtcm_v_fp16[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes); + factx.vtcm_v_fp16[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes); + factx.vtcm_k_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_tile_bytes); + factx.vtcm_v_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes); + factx.vtcm_s_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, s_tile_bytes); + factx.vtcm_p_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, s_tile_bytes); + factx.vtcm_d_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, d_tile_bytes); + factx.vtcm_m_vec = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, col_vec_bytes); + factx.vtcm_l_vec = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, col_vec_bytes); + factx.vtcm_s_rowmax = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, col_vec_bytes); + factx.vtcm_p_rowsum = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, col_vec_bytes); + factx.vtcm_row_bufs = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, row_vec_bytes * 2 * n_threads); + factx.row_buf_stride = row_vec_bytes / sizeof(HVX_Vector); + factx.vtcm_hmx_scales_id = vtcm_seq_alloc(&vtcm_cur, 256); + factx.vtcm_hmx_scales_qk = vtcm_seq_alloc(&vtcm_cur, 256); + factx.vtcm_mask_buf = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, m_buf_bytes); + factx.mask_buf_row_stride = m_line_bytes / sizeof(__fp16); + factx.vtcm_slopes = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, slopes_bytes); + + if ((size_t) (vtcm_cur - ctx->vtcm_base) > ctx->vtcm_size) { + return HTP_STATUS_VTCM_TOO_SMALL; + } + + // ======== Initialize HMX output scales ======== + // Identity scale (1.0) for O updates and normalization + hmx_init_column_scales(factx.vtcm_hmx_scales_id, Q6_V_vsplat_R(0x3c00)); // 1.0 + + // QK scale embedded in HMX output + hmx_init_column_scales(factx.vtcm_hmx_scales_qk, hvx_vec_splat_f16(factx.scale)); + + // ======== Skip compute if profiling ======== + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + // Profiling timers + TIMER_DEFINE(total); + TIMER_DEFINE(q_load); + TIMER_DEFINE(kv_dma); + TIMER_DEFINE(k_interleave); + TIMER_DEFINE(v_interleave); + TIMER_DEFINE(qk_dot); + TIMER_DEFINE(softmax); + TIMER_DEFINE(o_update); + TIMER_DEFINE(o_norm); + TIMER_DEFINE(o_store); + + TIMER_START(total); + + // ======== DMA setup ======== + dma_queue * const dma = ctx->dma[0]; + + // Padded row sizes for DMA + const size_t size_k_row = nek0 * sizeof(__fp16); + const size_t size_v_row = nev0 * sizeof(__fp16); + const size_t size_k_row_padded = hex_round_up(nek0 * sizeof(__fp16), 128); + const size_t size_v_row_padded = hex_round_up(nev0 * sizeof(__fp16), 128); + + const size_t n_row_tiles_g_br = g_br / HMX_FP16_TILE_N_ROWS; + const size_t n_tiles_per_bc = Bc / HMX_FP16_TILE_N_COLS; + + // Q/O element size for Q load and O store + const size_t qo_element_size = factx.is_q_fp32 ? sizeof(float) : sizeof(__fp16); + + // ======== HMX lock strategy ======== + // Pipeline: queue thread auto-acquires HMX lock on first push; released by suspend. + // Fallback: main thread holds the lock (original behavior). + if (!factx.use_pipeline) { + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + } + + // ======== Reusable job descriptors for pipeline ======== + hmx_fa_qk_job_t qk_job; + hmx_fa_o_update_job_t ou_job; + hmx_fa_o_norm_job_t on_job; + + // ======== Main loop: per batch, per KV head, per Q block ======== + for (uint32_t ib3 = 0; ib3 < neq3; ++ib3) { + for (uint32_t kv_head = 0; kv_head < n_kv_heads; ++kv_head) { + const uint32_t ik2 = kv_head; + const uint32_t ik3 = ib3 / (neq3 / k->ne[3]); + const uint32_t iv2 = kv_head; + const uint32_t iv3 = ib3 / (neq3 / v->ne[3]); + + for (uint32_t q_start = 0; q_start < neq1; q_start += Br) { + const uint32_t n_q_rows = hex_smin(Br, neq1 - q_start); + const size_t n_rows_g = n_q_rows * G; + const size_t g_br_actual = hex_align_up(n_rows_g, HMX_FP16_TILE_N_ROWS); + const size_t n_row_tiles = g_br_actual / HMX_FP16_TILE_N_ROWS; + + // ---- Load Q block [g_br, D] -> tiles, interleaving G heads ---- + TIMER_START(q_load); + if (n_rows_g < g_br) { + hvx_splat_u8_a(factx.vtcm_q_tiles, 0, q_tile_bytes); + } + fa_phase_q_load(&factx, q, q_start, kv_head, ib3, n_rows_g); + TIMER_STOP(q_load); + + // ---- Initialize per-block state ---- + hvx_splat_u8_a(factx.vtcm_l_vec, 0, col_vec_bytes); + hvx_splat_u8_a(factx.vtcm_d_tiles, 0, d_tile_bytes); + hvx_splat_u16_a(factx.vtcm_m_vec, 0xfbff, col_vec_bytes/2); + + __fp16 * o_tile_prev = factx.vtcm_o_tiles[0]; + __fp16 * o_tile_curr = factx.vtcm_o_tiles[1]; + hvx_splat_u8_a(o_tile_prev, 0, o_tile_bytes); + + // ---- KV block loop with DMA double-buffering ---- + size_t buf_idx = 0; + + // Prefetch first KV block + if (factx.n_kv_blocks > 0) { + const uint32_t kv_rows0 = hex_smin(Bc, nek1); + + const uint8_t * k_src = (const uint8_t *) k->data + ik2 * k->nb[2] + ik3 * k->nb[3]; + dma_queue_push(dma, dma_make_ptr(factx.vtcm_k_fp16[0], k_src), size_k_row_padded, k->nb[1], + size_k_row, kv_rows0); + + const uint8_t * v_src = (const uint8_t *) v->data + iv2 * v->nb[2] + iv3 * v->nb[3]; + dma_queue_push(dma, dma_make_ptr(factx.vtcm_v_fp16[0], v_src), size_v_row_padded, v->nb[1], + size_v_row, kv_rows0); + } + + // Mask DMA: single 2D transfer of n_q_rows unique mask rows into VTCM buffer. + // Only when mask is head-broadcast (ne[2]==1); otherwise softmax reads DDR directly. + #define MASK_DMA_PUSH(kv_start_val, kv_rows_val, has_mask_dma_var) \ + do { \ + has_mask_dma_var = false; \ + if (mask && factx.mask_broadcast) { \ + const uint32_t _im3 = ib3 % mask->ne[3]; \ + const uint8_t * _ms = (const uint8_t *) mask->data + q_start * mask->nb[1] + _im3 * mask->nb[3] + \ + (kv_start_val) * sizeof(__fp16); \ + dma_queue_push(dma, dma_make_ptr(factx.vtcm_mask_buf, _ms), m_line_bytes, mask->nb[1], \ + (kv_rows_val) * sizeof(__fp16), n_q_rows); \ + has_mask_dma_var = true; \ + } \ + } while (0) + + #define MASK_DMA_POP(has_mask_dma_var) \ + do { \ + if (has_mask_dma_var) { \ + dma_queue_pop(dma); \ + } \ + } while (0) + + #define DMA_PREFETCH_KV(blk_val) \ + do { \ + if ((blk_val) < factx.n_kv_blocks) { \ + const uint32_t _ns = (blk_val) * Bc; \ + const uint32_t _nr = hex_smin(Bc, nek1 - _ns); \ + size_t _nb = 1 - buf_idx; \ + const uint8_t * _ks = (const uint8_t *) k->data + _ns * k->nb[1] + ik2 * k->nb[2] + ik3 * k->nb[3]; \ + dma_queue_push(dma, dma_make_ptr(factx.vtcm_k_fp16[_nb], _ks), size_k_row_padded, k->nb[1], size_k_row, _nr); \ + const uint8_t * _vs = (const uint8_t *) v->data + _ns * v->nb[1] + iv2 * v->nb[2] + iv3 * v->nb[3]; \ + dma_queue_push(dma, dma_make_ptr(factx.vtcm_v_fp16[_nb], _vs), size_v_row_padded, v->nb[1], size_v_row, _nr); \ + } \ + } while (0) + + const size_t k_src_stride = size_k_row_padded / sizeof(__fp16); + const size_t v_src_stride = size_v_row_padded / sizeof(__fp16); + + if (factx.use_pipeline) { + // ================================================================== + // Pipeline path: HVX phases ‖ HMX queue worker + // ================================================================== + struct hmx_queue * hmx_q = ctx->hmx_queue; + + for (uint32_t kv_blk = 0; kv_blk < factx.n_kv_blocks; ++kv_blk) { + const uint32_t kv_start = kv_blk * Bc; + const uint32_t kv_rows = hex_smin(Bc, nek1 - kv_start); + const size_t n_col_tiles = hmx_ceil_div(kv_rows, HMX_FP16_TILE_N_COLS); + + // Wait for current KV DMA + TIMER_START(kv_dma); + dma_queue_pop(dma); // K + dma_queue_pop(dma); // V + TIMER_STOP(kv_dma); + + // Push mask DMA for this block (single 2D DMA when broadcast) + bool has_mask_dma = false; + MASK_DMA_PUSH(kv_start, kv_rows, has_mask_dma); + + // ---- Phase 1: K_int(blk) ‖ O_update(blk-1) ---- + if (kv_blk > 0) { + // Submit O_update for previous block (HMX worker) + ou_job.o_curr = o_tile_curr; + ou_job.o_prev = o_tile_prev; + ou_job.p_tiles = factx.vtcm_p_tiles; + ou_job.v_tiles = factx.vtcm_v_tiles; + ou_job.d_tiles = factx.vtcm_d_tiles; + ou_job.hmx_scales = factx.vtcm_hmx_scales_id; + ou_job.n_row_tiles = n_row_tiles; + ou_job.n_col_tiles = hmx_ceil_div(hex_smin(Bc, nek1 - (kv_blk - 1) * Bc), HMX_FP16_TILE_N_COLS); + ou_job.n_row_tiles_g_br = n_row_tiles_g_br; + ou_job.n_tiles_per_bc = n_tiles_per_bc; + ou_job.DV = DV; + hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_o_update_worker, &ou_job)); + } + + TIMER_START(k_interleave); + fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx); + TIMER_STOP(k_interleave); + + if (kv_blk > 0) { + hmx_queue_pop(hmx_q); + hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev); + } + + // ---- Phase 2: qk_dot(blk) on HMX ‖ V_int(blk) + DMA prefetch on HVX ---- + qk_job.q_tiles = factx.vtcm_q_tiles; + qk_job.k_tiles = factx.vtcm_k_tiles; + qk_job.s_tiles = factx.vtcm_s_tiles; + qk_job.n_row_tiles = n_row_tiles; + qk_job.n_col_tiles = n_col_tiles; + qk_job.n_dot_tiles = DK / 32; + qk_job.n_tiles_per_bc = n_tiles_per_bc; + qk_job.hmx_scales = factx.vtcm_hmx_scales_qk; + TIMER_START(qk_dot); + hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_qk_dot_worker, &qk_job)); + + // DMA push next block (non-blocking, before worker_pool) + DMA_PREFETCH_KV(kv_blk + 1); + + TIMER_START(v_interleave); + fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc); + TIMER_STOP(v_interleave); + + hmx_queue_pop(hmx_q); + TIMER_STOP(qk_dot); + + // ---- Phase 3: softmax(blk) + build_D(blk) | HMX idle ---- + // Pop mask DMA before softmax (ensures VTCM buffer is ready) + MASK_DMA_POP(has_mask_dma); + + fa_softmax_args_t sargs; + memset(&sargs, 0, sizeof(sargs)); + sargs.factx = &factx; + sargs.kv_rows = kv_rows; + sargs.n_rows_g = n_rows_g; + sargs.n_col_tiles = n_col_tiles; + sargs.n_tiles_per_bc = n_tiles_per_bc; + sargs.n_row_tiles = n_row_tiles; + sargs.n_row_tiles_g_br = n_row_tiles_g_br; + sargs.Bc = Bc; + sargs.G = G; + sargs.kv_head = kv_head; + sargs.kv_start = kv_start; + sargs.q_start = q_start; + sargs.ib3 = ib3; + sargs.has_alibi = (factx.max_bias != 0.0f); + sargs.mask = mask; + sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL; + sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride; + sargs.slopes = factx.vtcm_slopes; + fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g); + + TIMER_START(softmax); + fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br); + TIMER_STOP(softmax); + + buf_idx = 1 - buf_idx; + } // end KV block loop (pipeline) + + // Epilogue: O_update for last block + if (factx.n_kv_blocks > 0) { + const uint32_t last_blk = factx.n_kv_blocks - 1; + const size_t last_cols = hmx_ceil_div(hex_smin(Bc, nek1 - last_blk * Bc), HMX_FP16_TILE_N_COLS); + ou_job.o_curr = o_tile_curr; + ou_job.o_prev = o_tile_prev; + ou_job.p_tiles = factx.vtcm_p_tiles; + ou_job.v_tiles = factx.vtcm_v_tiles; + ou_job.d_tiles = factx.vtcm_d_tiles; + ou_job.hmx_scales = factx.vtcm_hmx_scales_id; + ou_job.n_row_tiles = n_row_tiles; + ou_job.n_col_tiles = last_cols; + ou_job.n_row_tiles_g_br = n_row_tiles_g_br; + ou_job.n_tiles_per_bc = n_tiles_per_bc; + ou_job.DV = DV; + + TIMER_START(o_update); + hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_o_update_worker, &ou_job)); + hmx_queue_pop(hmx_q); + TIMER_STOP(o_update); + + hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev); + } + + } else { + // ================================================================== + // Fallback path: sequential with multi-thread HVX phases + // Main thread holds HMX lock, runs HMX inline. + // ================================================================== + + for (uint32_t kv_blk = 0; kv_blk < factx.n_kv_blocks; ++kv_blk) { + const uint32_t kv_start = kv_blk * Bc; + const uint32_t kv_rows = hex_smin(Bc, nek1 - kv_start); + const size_t n_col_tiles = hmx_ceil_div(kv_rows, HMX_FP16_TILE_N_COLS); + + TIMER_START(kv_dma); + dma_queue_pop(dma); // K + dma_queue_pop(dma); // V + TIMER_STOP(kv_dma); + + bool has_mask_dma = false; + MASK_DMA_PUSH(kv_start, kv_rows, has_mask_dma); + DMA_PREFETCH_KV(kv_blk + 1); + + // K interleave (multi-thread HVX) + TIMER_START(k_interleave); + fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx); + TIMER_STOP(k_interleave); + + // QK dot (inline HMX on main thread) + TIMER_START(qk_dot); + { + const size_t n_dot_tiles = (size_t) (DK / 32); + const __fp16 * restrict q_base = factx.vtcm_q_tiles; + const __fp16 * restrict k_base = factx.vtcm_k_tiles; + __fp16 * restrict s_base = factx.vtcm_s_tiles; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(n_dot_tiles > 0); + + Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_qk); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < n_col_tiles; ++c) { + const __fp16 * row_tiles = q_base + r * HMX_FP16_TILE_N_ROWS * DK; + const __fp16 * col_tiles = k_base + c * HMX_FP16_TILE_N_COLS * DK; + __fp16 * out_tile = s_base + (r * n_tiles_per_bc + c) * HMX_FP16_TILE_N_ELMS; + for (size_t k = 0; k < n_dot_tiles; ++k) { + Q6_activation_hf_mxmem_RR((unsigned int) row_tiles, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) col_tiles, 2047); + row_tiles += HMX_FP16_TILE_N_ELMS; + col_tiles += HMX_FP16_TILE_N_ELMS; + } + Q6_mxmem_AR_after_hf(out_tile, 0); + } + } + } + TIMER_STOP(qk_dot); + + // Pop mask DMA + MASK_DMA_POP(has_mask_dma); + + // Softmax + build_D (multi-thread HVX + serial m/l update) + fa_softmax_args_t sargs; + memset(&sargs, 0, sizeof(sargs)); + sargs.factx = &factx; + sargs.kv_rows = kv_rows; + sargs.n_rows_g = n_rows_g; + sargs.n_col_tiles = n_col_tiles; + sargs.n_tiles_per_bc = n_tiles_per_bc; + sargs.n_row_tiles = n_row_tiles; + sargs.n_row_tiles_g_br = n_row_tiles_g_br; + sargs.Bc = Bc; + sargs.G = G; + sargs.kv_head = kv_head; + sargs.kv_start = kv_start; + sargs.q_start = q_start; + sargs.ib3 = ib3; + sargs.has_alibi = (factx.max_bias != 0.0f); + sargs.mask = mask; + sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL; + sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride; + sargs.slopes = factx.vtcm_slopes; + fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g); + + TIMER_START(softmax); + fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br); + TIMER_STOP(softmax); + + // V interleave (multi-thread HVX) + TIMER_START(v_interleave); + // FIX(v-stride): use n_tiles_per_bc (block-invariant) as V tile layout + // stride to match o_update's v_tile access. Using per-block n_col_tiles + // misplaces DV_tile 1..3 in the last partial KV block. + fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc); + TIMER_STOP(v_interleave); + + // O update (inline HMX on main thread) + TIMER_START(o_update); + { + const size_t DV_tiles = (size_t) (DV / 32); + const __fp16 * restrict d_base = factx.vtcm_d_tiles; + const __fp16 * restrict p_base = factx.vtcm_p_tiles; + const __fp16 * restrict v_base = factx.vtcm_v_tiles; + const __fp16 * restrict op_base = o_tile_prev; + __fp16 * restrict oc_base = o_tile_curr; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(DV_tiles > 0); + + Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_id); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < DV_tiles; ++c) { + const __fp16 * d_diag = d_base + r * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + const __fp16 * o_rc = op_base + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047); + + const __fp16 * p_tile_in = p_base + (r * n_tiles_per_bc) * HMX_FP16_TILE_N_ELMS; + const __fp16 * v_tile_in = v_base + (c * n_tiles_per_bc) * HMX_FP16_TILE_N_ELMS; + for (size_t k = 0; k < n_col_tiles; ++k) { + Q6_activation_hf_mxmem_RR((unsigned int) p_tile_in, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) v_tile_in, 2047); + p_tile_in += HMX_FP16_TILE_N_ELMS; + v_tile_in += HMX_FP16_TILE_N_ELMS; + } + + __fp16 * o_tile_out = oc_base + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + Q6_mxmem_AR_after_hf(o_tile_out, 0); + } + } + hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev); + } + TIMER_STOP(o_update); + + buf_idx = 1 - buf_idx; + } // end KV block loop (fallback) + } + + // ---- Final normalization: O = diag(1/l) @ O ---- + TIMER_START(o_norm); + { + fa_build_d_diag_inv_l(&factx, n_row_tiles, n_row_tiles_g_br); + + // HMX: O_final = diag(1/l) @ O_prev + if (factx.use_pipeline) { + on_job.o_curr = o_tile_curr; + on_job.o_prev = o_tile_prev; + on_job.d_tiles = factx.vtcm_d_tiles; + on_job.hmx_scales = factx.vtcm_hmx_scales_id; + on_job.n_row_tiles = n_row_tiles; + on_job.n_row_tiles_g_br = n_row_tiles_g_br; + on_job.DV = DV; + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_fa_o_norm_worker, &on_job)); + hmx_queue_pop(ctx->hmx_queue); + } else { + const size_t DV_tiles = (size_t) (DV / 32); + const __fp16 * restrict d_base = factx.vtcm_d_tiles; + const __fp16 * restrict op_base = o_tile_prev; + __fp16 * restrict oc_base = o_tile_curr; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(DV_tiles > 0); + + Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_id); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < DV_tiles; ++c) { + const __fp16 * d_diag = d_base + r * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + const __fp16 * o_rc = op_base + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + __fp16 * o_out = oc_base + (r * DV_tiles + c) * HMX_FP16_TILE_N_ELMS; + + Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047); + Q6_mxmem_AR_after_hf(o_out, 0); + } + } + } + } + TIMER_STOP(o_norm); + + // ---- Store O block ---- + TIMER_START(o_store); + fa_phase_o_store(&factx, dst, o_tile_curr, q_start, kv_head, ib3, n_rows_g); + TIMER_STOP(o_store); + +#undef MASK_DMA_PUSH +#undef MASK_DMA_POP +#undef DMA_PREFETCH_KV + + } // end Q block loop + } // end KV head loop + } // end batch loop + + if (factx.use_pipeline) { + hmx_queue_suspend(ctx->hmx_queue); + } else { + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + } + + TIMER_STOP(total); + +#if defined(ENABLE_PROFILE_TIMERS) + FARF(HIGH, "hmx-fa: %lld us, q_load=%lld kv_dma=%lld k_interleave=%lld v_interleave=%lld", TIMER_US(total), + TIMER_US(q_load), TIMER_US(kv_dma), TIMER_US(k_interleave), TIMER_US(v_interleave)); + FARF(HIGH, " qk_dot=%lld softmax=%lld o_update=%lld o_norm=%lld o_store=%lld", TIMER_US(qk_dot), TIMER_US(softmax), + TIMER_US(o_update), TIMER_US(o_norm), TIMER_US(o_store)); +#endif + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c new file mode 100644 index 00000000000..2666a78a96a --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -0,0 +1,1757 @@ +#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include +#include +#include +#include + +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" + +#include "hex-dma.h" +#include "worker-pool.h" + +#include "hvx-utils.h" +#include "hvx-dump.h" +#include "htp-ctx.h" +#include "htp-ops.h" + +#include "hmx-ops.h" +#include "hmx-utils.h" +#include "hmx-queue.h" +#include "hmx-profile.h" + +#include "vtcm-utils.h" + +static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { + -8, 0, -7, 0, -6, 0, -5, 0, -4, 0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, +}; + +// MXFP4 dequantization LUT: maps 4-bit index to fp16 mantissa value +// kvalues: 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6 +static const __fp16 mxfp4_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { + 0, 0, 0.5, 0, 1, 0, 1.5, 0, 2, 0, 3, 0, 4, 0, 6, 0, 0, 0, -0.5, 0, -1, 0, -1.5, 0, -2, 0, -3, 0, -4, 0, -6, 0, +}; + +static const __fp16 iq4_nl_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { + -127, 0, -104, 0, -83, 0, -65, 0, -49, 0, -35, 0, -22, 0, -10, 0, + 1, 0, 13, 0, 25, 0, 38, 0, 53, 0, 69, 0, 89, 0, 113, 0, +}; + +// Scales per x4x2 logical block: 8 × sizeof(__fp16) = 16 bytes +#define HMX_X4X2_SCALES_PER_BLK 8 +#define HMX_X4X2_DBLK_SIZE 16 // 8 * 2 bytes (fp16 scales for Q4_0/Q8_0/IQ4_NL) +#define HMX_X4X2_MXFP4_EBLK_SIZE 8 // 8 * 1 byte (E8M0 scales for MXFP4) + +// Compute the byte stride of one row in x4x2 format. +// Numerically equals ggml_row_size(type, k) when k is 256-aligned, because +// x4x2 packing has the same density as block_q4_0 / block_q8_0. +// Layout per row: [quants: nb*128 (Q4) or nb*256 (Q8)][scales: nb*16 bytes] +// Total per row = nb * (128+16) = 144*nb (Q4) or nb * (256+16) = 272*nb (Q8). +// Callers must ensure k is a multiple of 256 (enforced by proc_hmx_matmul_req). +static inline size_t get_x4x2_row_stride(int weight_type, int k) { + int nb = (k + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2; + switch (weight_type) { + case HTP_TYPE_Q4_0: + case HTP_TYPE_IQ4_NL: + return (size_t) nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb + case HTP_TYPE_Q8_0: + return (size_t) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb + case HTP_TYPE_MXFP4: + return (size_t) nb * (QK_MXFP4x4x2 / 2 + HMX_X4X2_MXFP4_EBLK_SIZE); // 136 * nb + default: + return 0; + } +} + +// --- Overflow-safe arithmetic for VTCM budget calculation --- + +static inline bool hmx_mul_overflow(size_t a, size_t b, size_t *out) { + if (a != 0 && b > SIZE_MAX / a) return true; + *out = a * b; + return false; +} + +static inline bool hmx_add_overflow(size_t a, size_t b, size_t *out) { + if (a > SIZE_MAX - b) return true; + *out = a + b; + return false; +} + +// Search for optimal (mc, nc) chunk sizes within VTCM budget. +// +// VTCM model: nc * per_n_cost + mc * per_m_cost + mc * nc * per_mn_cost + overhead +// +// Minimize ceil(m/mc) * m_block_cost + ceil(n/nc) * n_block_cost. +// All matmul paths repeat weight processing per M-block and activation loading +// per N-block, so discrete block counts drive total overhead. +// Tie-break: when cost is equal, prefer larger mc * nc. +// +// Caller-provided coefficients: +// m_block_cost: penalty per extra M-block (weight redundancy, scales with n). +// n_block_cost: penalty per extra N-block (activation redundancy, scales with m). +// +// Algorithm: nc sweeps from n_max down by 32, analytically solving for mc_max. +// Returns 0 on success, -1 if VTCM is insufficient. +static int hmx_compute_chunks(size_t vtcm_total, + size_t overhead, + size_t per_n_cost, + size_t per_m_cost, + size_t per_mn_cost, + int m, + int n, + size_t m_block_cost, + size_t n_block_cost, + size_t * m_chunk_out, + size_t * n_chunk_out, + size_t * total_out) { + if (m <= 0 || n <= 0) return -1; + if (vtcm_total <= overhead) return -1; + if (per_n_cost == 0 || per_m_cost == 0 || per_mn_cost == 0) return -1; + + const size_t usable = vtcm_total - overhead; + + size_t best_cost = SIZE_MAX; + size_t best_mn = 0; + size_t best_m = 0, best_n = 0; + + const size_t n_max = hex_align_down((size_t)n, HMX_FP16_TILE_N_COLS); + for (size_t nc = n_max; nc >= HMX_FP16_TILE_N_COLS; nc -= HMX_FP16_TILE_N_COLS) { + size_t n_fixed = 0, ncmn = 0, mc_denom = 0; + if (hmx_mul_overflow(nc, per_n_cost, &n_fixed)) continue; + if (n_fixed >= usable) goto next_nc; + + if (hmx_mul_overflow(nc, per_mn_cost, &ncmn)) goto next_nc; + if (hmx_add_overflow(per_m_cost, ncmn, &mc_denom) || mc_denom == 0) goto next_nc; + + { + size_t remain = usable - n_fixed; + size_t mc = remain / mc_denom; + mc = hex_align_down(mc, HMX_FP16_TILE_N_ROWS); + mc = hex_smin(mc, (size_t)m); + + if (mc == 0) { + goto next_nc; + } + + size_t mblocks = ((size_t) m + mc - 1) / mc; + size_t nblocks = ((size_t) n + nc - 1) / nc; + size_t cost = mblocks * m_block_cost + nblocks * n_block_cost; + size_t mn = mc * nc; + if (cost < best_cost || (cost == best_cost && mn > best_mn)) { + best_cost = cost; + best_mn = mn; + best_m = mc; + best_n = nc; + } + } + +next_nc: + if (nc == HMX_FP16_TILE_N_COLS) break; // avoid size_t underflow + } + + if (best_m == 0 || best_n == 0) return -1; + + // Compute exact total (with overflow checks) + size_t t0 = 0, t1 = 0, t2 = 0, mn = 0, total = 0; + if (hmx_mul_overflow(best_n, per_n_cost, &t0)) return -1; + if (hmx_mul_overflow(best_m, per_m_cost, &t1)) return -1; + if (hmx_mul_overflow(best_m, best_n, &mn)) return -1; + if (hmx_mul_overflow(mn, per_mn_cost, &t2)) return -1; + if (hmx_add_overflow(t0, t1, &total)) return -1; + if (hmx_add_overflow(total, t2, &total)) return -1; + if (hmx_add_overflow(total, overhead, &total)) return -1; + + *m_chunk_out = best_m; + *n_chunk_out = best_n; + *total_out = total; + return 0; +} + +// --- x4x2 format dequantizers --- + +// Dequantize one x4x2 Q4_0 group (32 elements from 32 packed bytes) -> 32 FP16 in first 64 bytes. +// In x4x2, sub-blocks 0..3 use lower nibbles, sub-blocks 4..7 use upper nibbles +// of the same 32 packed bytes. +static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx( + const uint8_t *packed_32, bool upper_nibbles, + const __fp16 *scale, const HVX_Vector vlut_cvt) { + HVX_Vector vq = hvx_vmemu(packed_32); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_scales = hvx_vec_splat_f16(*scale); + // q4x4x2 stores two int4 values per byte. Keep only the selected nibble. + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + // Shuffle before LUT + v_quants = Q6_Vb_vshuff_Vb(v_quants); + // Use standard vlut16 (not _nomatch) to avoid stale-register NaN. + // _nomatch retains the previous destination-register value for colliding + // indices, but the C intrinsic doesn't model the implicit read so the + // compiler may allocate a register containing garbage/NaN. + HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); + HVX_Vector v_hf = Q6_V_lo_W(vp); + + return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales)); +} + +// Batch-dequantize 4 contiguous x4x2 Q4_0 groups (4x32 = 128 packed bytes) using +// full HVX vector width. One vmemu + one vlut16 replaces 4 separate calls. +// Output: out[0..3] each hold 32 FP16 values in the first 64 bytes. +static inline void dequantize_x4x2_q4_0_x4groups_hvx( + const uint8_t *packed_128, bool upper_nibbles, + const __fp16 *scales_4, const HVX_Vector vlut_cvt, + HVX_Vector out[4]) { + // Load all 128 packed bytes (4 contiguous 32-byte groups) + HVX_Vector vq = hvx_vmemu(packed_128); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + + // Shuffle before LUT + v_quants = Q6_Vb_vshuff_Vb(v_quants); + + // Full-width vlut16: 128 byte lookups -> 128 fp16 results in a VectorPair + HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); + HVX_Vector v_lo = Q6_V_lo_W(vp); // [group0: 32 fp16 | group1: 32 fp16] + HVX_Vector v_hi = Q6_V_hi_W(vp); // [group2: 32 fp16 | group3: 32 fp16] + + // Build per-group scale vectors: first 64 bytes use scale_a, last 64 use scale_b + HVX_VectorPred q64 = Q6_Q_vsetq_R(64); + HVX_Vector v_sc01 = Q6_V_vmux_QVV(q64, hvx_vec_splat_f16(scales_4[0]), hvx_vec_splat_f16(scales_4[1])); + HVX_Vector v_sc23 = Q6_V_vmux_QVV(q64, hvx_vec_splat_f16(scales_4[2]), hvx_vec_splat_f16(scales_4[3])); + + v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); + v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); + + // Extract individual groups: scatter uses q_mask64 so only first 64 bytes matter + out[0] = v_lo; // group0 already in [0:63] + out[1] = v_hi; // group2 already in [0:63] +} + +// Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes. +static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(const int8_t *quants_32, const __fp16 *scale) { + HVX_Vector vq = hvx_vmemu(quants_32); + HVX_Vector v_scales = hvx_vec_splat_f16(*scale); + HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(vq)); + HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0); + return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales)); +} + +// --- MXFP4 E8M0 scale conversion and dequantization --- +// +// HVX batch-convert 8 E8M0 bytes (one x4x2 block's scales) to __fp16[8] on stack. +// Scalar loads from the stack array execute on the scalar pipeline, in parallel +// with HVX vlut16/vmpy/vscatter — freeing HVX slots in the hot loop. +// Arithmetic: fp16_bits = clamp(e - 112, 0, 30) << 10 +// e=0..112 -> 0 (underflow), e=113..142 -> valid fp16, e>=143 -> clamped to 2^15. + +typedef struct { + __fp16 v[8] __attribute__((aligned(16))); +} mxfp4_scales_t; + +static inline mxfp4_scales_t mxfp4_convert_scales(const uint8_t * e8m0_8) { + mxfp4_scales_t s; + HVX_Vector v = hvx_vmemu(e8m0_8); + HVX_Vector vh = Q6_V_lo_W(Q6_Wuh_vunpack_Vub(v)); + vh = Q6_Vh_vsub_VhVh(vh, Q6_Vh_vsplat_R(112)); + vh = Q6_Vh_vmax_VhVh(vh, Q6_V_vzero()); + vh = Q6_Vh_vmin_VhVh(vh, Q6_Vh_vsplat_R(30)); + vh = Q6_Vh_vasl_VhR(vh, 10); + hvx_vec_store_u(s.v, 16, vh); + return s; +} + +static inline HVX_Vector mxfp4_extract_splat(mxfp4_scales_t scales, int idx) { + return hvx_vec_splat_f16(scales.v[idx]); +} + +// Dequantize one x4x2 MXFP4 group (32 elements from 32 packed bytes) -> 32 FP16. +static inline HVX_Vector dequantize_x4x2_mxfp4_group_hvx(const uint8_t * packed_32, + bool upper_nibbles, + int sub_blk, + const HVX_Vector vlut_cvt, + mxfp4_scales_t scales) { + HVX_Vector vq = hvx_vmemu(packed_32); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq; + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + + HVX_Vector v_sc = mxfp4_extract_splat(scales, sub_blk); + + v_quants = Q6_Vb_vshuff_Vb(v_quants); + HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); + HVX_Vector v_hf = Q6_V_lo_W(vp); + + return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_sc)); +} + +// Batch-dequantize 4 contiguous x4x2 MXFP4 groups (4x32 = 128 packed bytes). +static inline void dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_128, + bool upper_nibbles, + int sub_blk_base, + const HVX_Vector vlut_cvt, + mxfp4_scales_t scales, + HVX_Vector out[4]) { + HVX_Vector vq = hvx_vmemu(packed_128); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq; + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + + v_quants = Q6_Vb_vshuff_Vb(v_quants); + + HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); + HVX_Vector v_lo = Q6_V_lo_W(vp); + HVX_Vector v_hi = Q6_V_hi_W(vp); + + HVX_VectorPred q64 = Q6_Q_vsetq_R(64); + HVX_Vector v_sc01 = Q6_V_vmux_QVV(q64, mxfp4_extract_splat(scales, sub_blk_base + 0), + mxfp4_extract_splat(scales, sub_blk_base + 1)); + HVX_Vector v_sc23 = Q6_V_vmux_QVV(q64, mxfp4_extract_splat(scales, sub_blk_base + 2), + mxfp4_extract_splat(scales, sub_blk_base + 3)); + + v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); + v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); + + out[0] = v_lo; + out[1] = Q6_V_vror_VR(v_lo, 64); + out[2] = v_hi; + out[3] = Q6_V_vror_VR(v_hi, 64); +} + +// Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16. +// Input: vtcm_src has n_cols rows of x4x2 data, each row_stride bytes. +// Output: vtcm_dst in tile-major FP16 layout. +static void dequantize_x4x2_weight_to_fp16_tiles_task( + __fp16 *restrict vtcm_dst, + const uint8_t *restrict vtcm_src, + int n_cols, int k_block, + size_t row_stride, int weight_type, + int start_tile, int end_tile) { + + const int n_k_tiles = (unsigned)k_block / HMX_FP16_TILE_N_COLS; + const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL); + const int qrow_size = is_q4 ? ((unsigned)k_block / 2) : k_block; + + const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) ? hvx_vmem(iq4_nl_to_fp16_lut) : + (weight_type == HTP_TYPE_MXFP4) ? hvx_vmem(mxfp4_to_fp16_lut) : + hvx_vmem(q4_0_to_fp16_lut); + + // vscatter setup: write dequantized K-values directly to transposed [K][N] tile positions. + // Each int32 element holds a K-row-pair (2 adjacent fp16 values). word[i] at offset i*128 + // maps to K-rows 2i and 2i+1. Column offset (n*4) added per row. + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); // 4 bytes = 1 column step + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); // first 16 words (64 bytes) + + unsigned ct = (unsigned)start_tile / n_k_tiles; // column tile index + unsigned kt = (unsigned)start_tile % n_k_tiles; // K tile index + for (unsigned t = start_tile; t < end_tile; ) { + if (kt >= n_k_tiles) { kt = 0; ct++; } + + // --- Batch-4 fast path for Q4: process 4 contiguous K-tiles with one vlut16 per row --- + if (is_q4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) { + unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2; + unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4 + bool upper = (sub_blk_base >= 4); + unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2); // 128 contiguous packed bytes + unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + + sub_blk_base * (int)sizeof(__fp16); // 4 consecutive scales + + __fp16 *tile_bases[4]; + for (unsigned g = 0; g < 4; g++) { tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; } + + HVX_Vector v_off = v_scat_base; + + unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride; + unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; + + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { + HVX_Vector v0[2]; + const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; + dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0); + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]); + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + + + r0 = vtcm_src + row_offset; row_offset += row_stride; + dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0); + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]); + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + + for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); } + t += 4; kt += 4; + continue; + } + + // --- Batch-4 fast path for MXFP4: same nibble layout but E8M0 scales --- + if (weight_type == HTP_TYPE_MXFP4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) { + int blk_idx = (kt * 32) / QK_MXFP4x4x2; + int sub_blk_base = ((kt * 32) % QK_MXFP4x4x2) / 32; // 0 or 4 + bool upper = (sub_blk_base >= 4); + int packed_off = blk_idx * (QK_MXFP4x4x2 / 2); // 128 contiguous packed bytes + int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE; // all 8 E8M0 scales + + __fp16 * tile_bases[4]; + for (int g = 0; g < 4; g++) { + tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; + } + + HVX_Vector v_off = v_scat_base; + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { + int row0 = ct * HMX_FP16_TILE_N_COLS + r; + int row1 = row0 + 1; + const uint8_t * r0 = vtcm_src + row0 * row_stride; + const uint8_t * r1 = vtcm_src + row1 * row_stride; + + // Batch-convert all 8 E8M0 scales once per row (stays in HVX register) + mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off); + + HVX_Vector v0[4], v1[4]; + dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8, v0); + if (row1 < n_cols) { + mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off); + dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8, v1); + } else { + v1[0] = v1[1] = v1[2] = v1[3] = Q6_V_vzero(); + } + + for (int g = 0; g < 4; g++) { + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v0[g]); + } + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + for (int g = 0; g < 4; g++) { + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v1[g]); + } + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + + for (int g = 0; g < 4; g++) { + (void) *(volatile HVX_Vector *) (tile_bases[g]); + } + + t += 4; + continue; + } + + // --- Single-tile fallback --- + __fp16 *tile_base = vtcm_dst + t * HMX_FP16_TILE_N_ELMS; + + if (is_q4) { + unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2; + unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32; + bool upper = (sub_blk >= 4); + unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; + unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16); + + HVX_Vector v_off = v_scat_base; // reset to column 0 + unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride; + unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { + const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; + const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; + + HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx( + r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); + HVX_Vector v1 = (row1 < n_cols) + ? dequantize_x4x2_q4_0_group_hvx( + r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt) + : Q6_V_vzero(); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + (void) *(volatile HVX_Vector *)(tile_base); + } else if (weight_type == HTP_TYPE_MXFP4) { + int blk_idx = (kt * 32) / QK_MXFP4x4x2; + int sub_blk = ((kt * 32) % QK_MXFP4x4x2) / 32; + bool upper = (sub_blk >= 4); + int byte_off = blk_idx * (QK_MXFP4x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; + int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE; + + HVX_Vector v_off = v_scat_base; + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { + int row0 = ct * HMX_FP16_TILE_N_COLS + r; + int row1 = row0 + 1; + + const uint8_t * r0 = vtcm_src + row0 * row_stride; + const uint8_t * r1 = vtcm_src + row1 * row_stride; + + // Batch-convert all 8 E8M0 scales once per row (stays in HVX register) + mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off); + + HVX_Vector v0 = dequantize_x4x2_mxfp4_group_hvx(r0 + byte_off, upper, sub_blk, vlut_cvt, r0_e8); + HVX_Vector v1; + if (row1 < n_cols) { + mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off); + v1 = dequantize_x4x2_mxfp4_group_hvx(r1 + byte_off, upper, sub_blk, vlut_cvt, r1_e8); + } else { + v1 = Q6_V_vzero(); + } + + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + (void) *(volatile HVX_Vector *) (tile_base); + } else { + // Q8_0 + int blk_idx = (kt * 32) / QK_Q8_0x4x2; + int sub_blk = ((kt * 32) % QK_Q8_0x4x2) / 32; + int byte_off = blk_idx * QK_Q8_0x4x2 + sub_blk * 32; + int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16); + + HVX_Vector v_off = v_scat_base; // reset to column 0 + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { + int row0 = ct * HMX_FP16_TILE_N_COLS + r; + int row1 = row0 + 1; + + const uint8_t *r0 = vtcm_src + row0 * row_stride; + const uint8_t *r1 = vtcm_src + row1 * row_stride; + + HVX_Vector v0 = dequantize_x4x2_q8_0_group_hvx( + (const int8_t *)(r0 + byte_off), (const __fp16 *)(r0 + scale_off)); + HVX_Vector v1 = (row1 < n_cols) + ? dequantize_x4x2_q8_0_group_hvx( + (const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off)) + : Q6_V_vzero(); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + (void) *(volatile HVX_Vector *)(tile_base); + } + ++t; ++kt; + } + + // Drain HVX scatter write buffer: a vmem load on the same HW thread retires + // all pending scatter entries to VTCM. Without this, the main thread's HMX + // reads may see stale data because atomic_fetch_sub (release) only orders + // regular stores, not the HVX scatter buffer. + if (start_tile < end_tile) { + (void) *(volatile HVX_Vector *)(vtcm_dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); + } +} + +typedef struct { + __fp16 *dst; + const uint8_t *src; + int n_cols; + int k_block; + size_t row_stride; + int weight_type; + int n_tot_tiles; + int n_tiles_per_task; + int n_tasks; +} x4x2_dequantize_state_t; + +static void dequantize_x4x2_worker_loop(unsigned int n, unsigned int i, void *data) { + x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; + + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { + int start = task_id * state->n_tiles_per_task; + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); + + dequantize_x4x2_weight_to_fp16_tiles_task( + state->dst, state->src, state->n_cols, state->k_block, + state->row_stride, state->weight_type, start, end); + } +} + +static void dequantize_x4x2_weight_chunk_to_fp16_tiles( + struct htp_context *ctx, __fp16 *vtcm_dst, + const void *vtcm_src, int n_cols, int k_block, + size_t row_stride, int weight_type) { + + assert(n_cols % HMX_FP16_TILE_N_COLS == 0); + assert(k_block % HMX_FP16_TILE_N_COLS == 0); + + size_t n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS; + size_t n_k_tiles = k_block / HMX_FP16_TILE_N_COLS; + size_t n_tot_tiles = n_col_tiles * n_k_tiles; + + size_t n_tiles_per_task = hmx_ceil_div(n_tot_tiles, ctx->n_threads); + + x4x2_dequantize_state_t state; + state.n_tasks = (n_tot_tiles + n_tiles_per_task - 1) / n_tiles_per_task; + state.n_tot_tiles = n_tot_tiles; + state.n_tiles_per_task = n_tiles_per_task; + state.dst = vtcm_dst; + state.src = (const uint8_t *)vtcm_src; + state.n_cols = n_cols; + state.k_block = k_block; + state.row_stride = row_stride; + state.weight_type = weight_type; + + worker_pool_run_func(ctx->worker_pool, dequantize_x4x2_worker_loop, &state, ctx->n_threads); +} + +// --- End x4x2 dequantizers --- + +// requires external HMX lock +static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict activation, const __fp16 *restrict weight, const __fp16 *restrict scales, + int n_row_tiles, int n_col_tiles, int n_dot_tiles) { + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(n_dot_tiles > 0); + + Q6_bias_mxmem2_A((void *)scales); + for (int r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < n_col_tiles; ++c) { + Q6_mxclracc_hf(); + + const __fp16 *row_tiles = activation + r * n_dot_tiles * HMX_FP16_TILE_N_ELMS; + const __fp16 *col_tiles = weight + c * n_dot_tiles * HMX_FP16_TILE_N_ELMS; + + for (int k = 0; k < n_dot_tiles; ++k) { + Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047); + Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047); + row_tiles += HMX_FP16_TILE_N_ELMS; + col_tiles += HMX_FP16_TILE_N_ELMS; + } + + __fp16 *out_tile = output + (r * n_col_tiles + c) * HMX_FP16_TILE_N_ELMS; + Q6_mxmem_AR_after_hf(out_tile, 0); + } + } +} + +// --- Async HMX matmul job (for pipeline overlap) --- + +typedef struct { + __fp16 * output; + const __fp16 * activation; + const __fp16 * weight; + const __fp16 * scales; + uint32_t n_row_tiles; + uint32_t n_col_tiles; + uint32_t n_dot_tiles; +} hmx_matmul_job_t; + +static void hmx_matmul_worker_fn(void * data) { + hmx_matmul_job_t * job = (hmx_matmul_job_t *) data; + FARF(HIGH, "hmx-mm-job: n_row_tiles %u n_col_tiles %u n_dot_tiles %u", job->n_row_tiles, job->n_col_tiles, job->n_dot_tiles); + core_dot_chunk_fp16(job->output, job->activation, job->weight, job->scales, job->n_row_tiles, job->n_col_tiles, job->n_dot_tiles); +} + +static inline void hmx_matmul_job_init(hmx_matmul_job_t * job, + __fp16 * output, + const __fp16 * activation, + const __fp16 * weight, + const __fp16 * scales, + int n_row_tiles, + int n_col_tiles, + int n_dot_tiles) { + job->output = output; + job->activation = activation; + job->weight = weight; + job->scales = scales; + job->n_row_tiles = n_row_tiles; + job->n_col_tiles = n_col_tiles; + job->n_dot_tiles = n_dot_tiles; +} + +// output : fp16 -> f32p + +static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, int n_rows, int n_cols, int n) { + assert(n_cols % HMX_FP16_TILE_N_COLS == 0); + const size_t tile_row_stride = (n_cols / HMX_FP16_TILE_N_COLS) * HMX_FP16_TILE_N_ELMS; + + const HVX_Vector one = hvx_vec_splat_f16(1.0); + + for (size_t r = 0; r < n_rows; r += 2) { + const size_t r0 = r / HMX_FP16_TILE_N_ROWS; + const size_t r1 = (r % HMX_FP16_TILE_N_ROWS) / 2; // index of the row pair within the tile + const __fp16 *row_base = vtcm_src + r0 * tile_row_stride; + float *output_row_base = dst + r * n; // global memory row base for row r (and r+1) + + #pragma unroll(4) + for (size_t c = 0; c < n_cols; c += HMX_FP16_TILE_N_COLS) { + const size_t c0 = c / HMX_FP16_TILE_N_COLS; + const __fp16 *tile = row_base + c0 * HMX_FP16_TILE_N_ELMS; + HVX_Vector v = ((const HVX_Vector *) tile)[r1]; + HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one); + + volatile HVX_Vector *pv_out0 = (volatile HVX_Vector *) (output_row_base + c + 0); + volatile HVX_Vector *pv_out1 = (volatile HVX_Vector *) (output_row_base + c + n); // next row in global memory + + *pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)); + if (r + 1 < n_rows) { + *pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp)); + } + } + } +} + +typedef struct { + const __fp16 *vtcm_src; + float *dst; + int n_tasks; + int n_tot_chunks; + int n_chunks_per_task; + int n_cols; + int n; // DDR row stride (total output columns) +} output_transfer_task_state_t; + +static void transfer_output_chunk_worker_fn(unsigned int n, unsigned int i, void *data) { + output_transfer_task_state_t *st = (output_transfer_task_state_t *) data; + + for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) { + int chunk_idx = task_id * st->n_chunks_per_task; + size_t chunk_size = hex_smin(st->n_tot_chunks - chunk_idx, st->n_chunks_per_task); + + float *dst = st->dst + chunk_idx * st->n; + const __fp16 *vtcm_src = st->vtcm_src + chunk_idx * st->n_cols; + transfer_output_chunk_fp16_to_fp32(dst, vtcm_src, chunk_size, st->n_cols, st->n); + } +} + +static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, const __fp16 *vtcm_src, + int n_rows, int n_cols, int n) { + assert(n_cols % HMX_FP16_TILE_N_COLS == 0); + + size_t n_tot_chunks = n_rows; + size_t n_chunks_per_task = HMX_FP16_TILE_N_ROWS; // must be multiple of HMX_FP16_TILE_N_ROWS (32) + + output_transfer_task_state_t state; + state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; + state.n_tot_chunks = n_tot_chunks; + state.n_chunks_per_task = n_chunks_per_task; + state.dst = dst; + state.vtcm_src = vtcm_src; + state.n_cols = n_cols; + state.n = n; + + worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_worker_fn, &state, ctx->n_threads); +} + +// activations : fp32 -> fp16 + +static void transfer_activation_chunk_fp32_to_fp16(__fp16 *restrict vtcm_dst, const float *restrict src, int n_rows, int k_block, int k_stride) { + for (int r = 0; r < n_rows; r += 2) { + int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index + int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx + + const bool next_row_valid = (r + 1) < n_rows; + + const HVX_Vector *pv_in0 = (const HVX_Vector *) (src + (r + 0) * k_stride); + const HVX_Vector *pv_in1 = (const HVX_Vector *) (src + (r + 1) * k_stride); + for (int c = 0; c < k_block; c += 32) { + HVX_Vector v0 = *pv_in0++; + HVX_Vector v1 = next_row_valid ? *pv_in1++ : Q6_V_vzero(); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + // compute output position + int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index + int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } +} + +typedef struct { + __fp16 *dst; + const float *src; + int n_tasks; + int n_tot_chunks; + int n_chunks_per_task; + int k_block; + int k_stride; +} activation_transfer_task_state_t; + +static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i, void *data) { + activation_transfer_task_state_t *st = (activation_transfer_task_state_t *) data; + + for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) { + // one chunk: one row + int chunk_idx = task_id * st->n_chunks_per_task; + size_t chunk_size = hex_smin(st->n_tot_chunks - chunk_idx, st->n_chunks_per_task); + + __fp16 *dst = st->dst + chunk_idx * st->k_block; + const float *src = st->src + chunk_idx * st->k_stride; + transfer_activation_chunk_fp32_to_fp16(dst, src, chunk_size, st->k_block, st->k_stride); + } +} + +static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride) { + assert(k_block % HMX_FP16_TILE_N_COLS == 0 && k_stride % HMX_FP16_TILE_N_COLS == 0); + assert(VLEN == 32 * sizeof(float)); + + size_t n_tot_chunks = n_rows; + size_t n_chunks_per_task = 32; // must be multiple of 32 to ensure correct destination address + + activation_transfer_task_state_t state; + state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; + state.n_tot_chunks = n_tot_chunks; + state.n_chunks_per_task = n_chunks_per_task; + state.dst = dst; + state.src = src; + state.k_block = k_block; + state.k_stride = k_stride; + + worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, ctx->n_threads); +} + +// + +#define FALLBACK_TO_STANDARD 1 + +// C += AB +static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __fp16 *restrict b, + const __fp16 *restrict col_scales, const __fp16 *restrict eye_tile, + int n_row_tiles, int n_col_tiles, int n_dot_tiles, bool zero_init) { + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(n_dot_tiles > 0); + + Q6_bias_mxmem2_A((void *)col_scales); + + const size_t dot_tile_stride = n_dot_tiles * HMX_FP16_TILE_N_ELMS; + for (size_t i = 0; i < n_row_tiles; ++i) { + const __fp16 *row_base = a + i * dot_tile_stride; + __fp16 *res_base = c + i * n_col_tiles * HMX_FP16_TILE_N_ELMS; + for (size_t j = 0; j < n_col_tiles; ++j) { + Q6_mxclracc_hf(); + + const __fp16 *col_tiles = b + j * dot_tile_stride; + const __fp16 *row_tiles = row_base; + __fp16 *accum_tile = res_base + j * HMX_FP16_TILE_N_ELMS; + if (!zero_init) { + Q6_activation_hf_mxmem_RR((unsigned int)accum_tile, 2047); + Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047); + } + + for (int k = 0; k < n_dot_tiles; ++k) { + Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047); + Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047); + row_tiles += HMX_FP16_TILE_N_ELMS; + col_tiles += HMX_FP16_TILE_N_ELMS; + } + Q6_mxmem_AR_after_hf(accum_tile, 0); + } + } +} + +static __attribute__((noinline)) int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, + float *restrict out, const float *restrict x, const uint8_t *restrict w, + int m, int k, int n, int weight_type) { + // assume k % 32 == 0 && n % 32 == 0 + const size_t row_stride = get_x4x2_row_stride(weight_type, k); + if (row_stride == 0) { + return -1; + } + + const size_t vtcm_budget = ctx->vtcm_size; + + const size_t K_BLOCK_SIZE = 1024; + + // Fallback: if k doesn't need K-blocking, out-stationary has no advantage + const size_t k_iters_check = (k + K_BLOCK_SIZE - 1) / K_BLOCK_SIZE; + if (k_iters_check <= 1) { + FARF(HIGH, "%s: K_BLK=%zu >= k=%d, fallback to standard path", __func__, K_BLOCK_SIZE, k); + return FALLBACK_TO_STANDARD; + } + + // Dynamic M,N search via hmx_compute_chunks + const size_t sub_row_stride_alloc = get_x4x2_row_stride(weight_type, K_BLOCK_SIZE); + const size_t per_m = K_BLOCK_SIZE * sizeof(float) // scratch1: M×K×4 (act DMA staging F32) + + K_BLOCK_SIZE * sizeof(__fp16); // activation: M×K×2 (F16 tiles) + const size_t per_n = sub_row_stride_alloc // scratch0: N×sub_row(K) (packed quant) + + K_BLOCK_SIZE * sizeof(__fp16); // weight: N×K×2 (F16 tiles) + const size_t per_mn = sizeof(__fp16); // output: M×N×2 (out-stationary) + + // Alignment margin: hex_align_up can add up to 2047 bytes per buffer; + // scratch1 (mc×6144) is naturally 2048-aligned, remaining 4 buffers need margin + const size_t align_margin = 4 * HMX_FP16_TILE_SIZE; + const size_t overhead = HMX_FP16_TILE_SIZE + 256 + align_margin; // eye_tile + scales + alignment + + size_t M_BLOCK_SIZE, N_BLOCK_SIZE, vtcm_used; + // Cost-based search: minimize ceil(m/mc)*m_block_cost + ceil(n/nc)*n_block_cost. + // From profiling: wt_dequant per element ≈ 1.5× activation load per element. + // m_block_cost = n*3: each extra M-block re-dequants all N×K weight (expensive). + // n_block_cost = m*2: each extra N-block re-loads all M×K activation (cheaper). + const size_t m_block_cost = (size_t) n * 3; + const size_t n_block_cost = (size_t) m * 2; + if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn, m, n, m_block_cost, n_block_cost, &M_BLOCK_SIZE, + &N_BLOCK_SIZE, &vtcm_used) != 0) { + FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); + return -1; + } + + // Compute precise buffer sizes from searched M,N and fixed K + const size_t weight_size = hex_align_up(N_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t act_size = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t out_size = hex_align_up(M_BLOCK_SIZE * N_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t scratch0_sz = hex_align_up(N_BLOCK_SIZE * sub_row_stride_alloc, HMX_FP16_TILE_SIZE); + const size_t scratch1_sz = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(float), HMX_FP16_TILE_SIZE); + + const size_t total_vtcm = weight_size + act_size + out_size + scratch0_sz + scratch1_sz + HMX_FP16_TILE_SIZE + 256; + if (total_vtcm > vtcm_budget) { + FARF(HIGH, "%s: VTCM overflow after search: need %zu have %zu (M=%zu N=%zu K=%zu)", __func__, total_vtcm, + vtcm_budget, M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE); + return -1; + } + + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, out_size); + uint8_t *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_sz); + uint8_t *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_sz); + __fp16 *vtcm_eye_tile = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, HMX_FP16_TILE_SIZE); + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + assert((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) <= vtcm_budget); + + FARF(HIGH, "hmx-mm: m=%d k=%d n=%d wtype=%d block M=%zu N=%zu K=%zu vtcm=%zu/%zu", m, k, n, weight_type, + M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE, (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); + + // initialize eye tile (32x32 identity matrix) + { + HVX_Vector v; + v = Q6_V_vzero(); + v = Q6_Vw_vinsert_VwR(v, 0x3c000000); + v = Q6_V_vror_VR(v, VLEN - 4); + v = Q6_Vw_vinsert_VwR(v, 0x00003c00); + for (int i = 0; i < 16; ++i) { + ((HVX_Vector *) vtcm_eye_tile)[i] = v; + v = Q6_V_vror_VR(v, VLEN - 8); + } + } + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 + + TIMER_DEFINE(fetch); + TIMER_DEFINE(act_load); + TIMER_DEFINE(wt_dequant); + TIMER_DEFINE(core); + + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + + for (size_t mr = 0; mr < m; mr += M_BLOCK_SIZE) { + size_t m_blk_sz = hex_smin(m - mr, M_BLOCK_SIZE); + for (size_t nc = 0; nc < n; nc += N_BLOCK_SIZE) { + size_t n_blk_sz = hex_smin(n - nc, N_BLOCK_SIZE); + + const int n_row_tiles = hmx_ceil_div(m_blk_sz, HMX_FP16_TILE_N_ROWS); + const int n_col_tiles = hmx_ceil_div(n_blk_sz, HMX_FP16_TILE_N_COLS); + + for (size_t kk = 0; kk < k; kk += K_BLOCK_SIZE) { + const size_t k_blk_sz = hex_smin(k - kk, K_BLOCK_SIZE); + + TIMER_START(fetch); + // fetch activation block into VTCM + { + const float *activation_block = x + mr * k + kk; + + dma_queue_push(ctx->dma[0], + dma_make_ptr(vtcm_scratch1, activation_block), + k_blk_sz * sizeof(float), + k * sizeof(float), + k_blk_sz * sizeof(float), + m_blk_sz); + } + + // fetch weight block into VTCM (x4x2 sub-block: quants + scales) + const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz); + { + const int blk_start = kk / QK_Q4_0x4x2; + const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2; + const int full_qrow = (weight_type == HTP_TYPE_Q8_0) ? k : (k / 2); + const int scale_blk_size = (weight_type == HTP_TYPE_MXFP4) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE; + uint8_t *dst = vtcm_scratch0; + const uint8_t *src = w + nc * row_stride; + const size_t n_rows = n_blk_sz; + const size_t src_stride = row_stride; + const size_t dst_stride = sub_row_stride; + const size_t quant_off = (weight_type == HTP_TYPE_Q8_0) ? (blk_start * QK_Q8_0x4x2) : (blk_start * (QK_Q4_0x4x2 / 2)); + const size_t quant_width = (weight_type == HTP_TYPE_Q8_0) ? (nb_sub * QK_Q8_0x4x2) : (nb_sub * (QK_Q4_0x4x2 / 2)); + const size_t scale_off = full_qrow + blk_start * scale_blk_size; + const size_t scale_width = nb_sub * scale_blk_size; + + // 2D DMA: quants sub-range + dma_queue_push(ctx->dma[0], dma_make_ptr(dst, src + quant_off), dst_stride, src_stride, quant_width, n_rows); + // 2D DMA: scales sub-range + dma_queue_push(ctx->dma[0], dma_make_ptr(dst + quant_width, src + scale_off), dst_stride, src_stride, scale_width, n_rows); + } + TIMER_STOP(fetch); + + TIMER_START(act_load); + // load activation block + { + dma_queue_pop(ctx->dma[0]); // wait for act DNA + transfer_activation_chunk_threaded(ctx, vtcm_activation, (float *) vtcm_scratch1, m_blk_sz, k_blk_sz, k_blk_sz); + } + TIMER_STOP(act_load); + + TIMER_START(wt_dequant); + // dequantize weight block + { + dma_queue_pop(ctx->dma[0]); + dma_queue_pop(ctx->dma[0]); + // vtcm_scratch0 is used to store the qweight chunk + // worker_pool_run_func already returned, so fetch is done + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, vtcm_scratch0, + n_blk_sz, k_blk_sz, sub_row_stride, weight_type); + } + TIMER_STOP(wt_dequant); + + // core mma + TIMER_START(core); + { + core_mma_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, vtcm_eye_tile, n_row_tiles, + n_col_tiles, k_blk_sz / HMX_FP16_TILE_N_COLS, kk == 0); + } + TIMER_STOP(core); + } + + // store output block + { + float *output_block = out + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output_block, vtcm_output, m_blk_sz, n_blk_sz, n); + } + } + } + + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + +#if defined(ENABLE_PROFILE_TIMERS) + FARF(HIGH, "fetch: %lld us, act_load: %lld us, wt_dequant: %lld us, core: %lld us", + TIMER_US(fetch), TIMER_US(act_load), TIMER_US(wt_dequant), TIMER_US(core)); +#endif + return 0; +} + +int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, + const uint8_t *restrict permuted_weight, int m, int k, int n, + int weight_type) { + if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } + if (k % 32 != 0 || n % 32 != 0) { return -1; } + + if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { + return -1; + } + + // for large m, k (e.g. prefill FFN Down), use out-stationary version + if (m >= 128 && k > n && n > 1024) { + int rc = mat_mul_qk_0_d16a32_out_stationary(ctx, dst, activation, permuted_weight, m, k, n, weight_type); + if (rc != FALLBACK_TO_STANDARD) { + return rc; // 0 success, -1 error + } + FARF(HIGH, "hmx_matmul_qk: out-stationary fallback to standard m=%d k=%d n=%d", m, k, n); + // fall through to standard path + } + + size_t row_stride = get_x4x2_row_stride(weight_type, k); + if (row_stride == 0) { + return -1; + } + + FARF(HIGH, "hmx_matmul_qk: STANDARD path m=%d k=%d n=%d type=%d", m, k, n, weight_type); + + // --- Dynamic VTCM layout --- + const size_t vtcm_budget = ctx->vtcm_size; + const size_t vec_dot_size = k * sizeof(__fp16); + + // Pipeline = 4-stage DMA→dequant→HMX→store with HMX worker overlap. + // Only pays off when the chunker yields >=2 n-chunks, so the main loop can + // overlap HMX (C) with HVX (B/D); with a single n-chunk the extra VTCM for + // double-buffered output and the worker-dispatch overhead are pure loss. + // Try pipeline costs first; fall back to sequential if the layout collapses + // to one n-chunk. m >= 128 floor keeps HMX utilization reasonable. + const size_t pipe_per_n = row_stride + 2 * vec_dot_size; // Q + S0 + S1 (dequant bufs) + const size_t pipe_per_mn = 2 * sizeof(__fp16); // O x 2 (output double buffer) + const size_t seq_per_n = vec_dot_size + 2 * row_stride; // W + S0 + S1 (x4x2 DMA bufs) + const size_t seq_per_mn = sizeof(__fp16); // O x 1 + + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; + bool use_pipeline = false; + + if (m >= 128) { + size_t mc = 0, nc = 0, used = 0; + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, pipe_per_n, /*per_m=*/vec_dot_size, pipe_per_mn, m, n, + /*m_block_cost=*/(size_t) n * 3, + /*n_block_cost=*/(size_t) m * 2, &mc, &nc, &used) == 0 && + hmx_ceil_div((size_t) n, nc) >= 2) { + m_chunk_n_rows = mc; + n_chunk_n_cols = nc; + vtcm_used = used; + use_pipeline = true; + } + } + + if (!use_pipeline) { + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, seq_per_n, /*per_m=*/vec_dot_size, seq_per_mn, m, n, + /*m_block_cost=*/(size_t) n * 3, + /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { + FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); + return -1; + } + } + + // Compute precise buffer sizes per execution path + const size_t weight_area_size = hex_align_up( + n_chunk_n_cols * (use_pipeline ? row_stride : vec_dot_size), HMX_FP16_TILE_SIZE); + const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t output_area_size = hex_align_up( + m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + + size_t scratch0_size, scratch1_size, scratch2_size; + if (use_pipeline) { + scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0 + scratch1_size = scratch0_size; // dequant buf 1 + scratch2_size = output_area_size; // output buf 1 + } else { + scratch0_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); // x4x2 DMA buf 0 + scratch1_size = scratch0_size; // x4x2 DMA buf 1 + scratch2_size = 0; // unused + } + + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); + void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_size); + void *vtcm_scratch2 = scratch2_size ? vtcm_seq_alloc(&vtcm_ptr, scratch2_size) : NULL; + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) { + FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__, + (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + return -1; + } + + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 + + FARF(HIGH, "%s: m=%d k=%d n=%d wtype=%d pipe=%d mc=%zu nc=%zu vtcm=%zu/%zu", + __func__, m, k, n, weight_type, use_pipeline, + m_chunk_n_rows, n_chunk_n_cols, + (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + + TIMER_DEFINE(activation_load); + TIMER_DEFINE(weight_load); + TIMER_DEFINE(hmx_core); + TIMER_DEFINE(output_store); + + TIMER_DEFINE(total); + TIMER_START(total); + + FARF(HIGH, "hmx_matmul_qk: %s mc=%zu nc=%zu vtcm=%zu/%zu", + use_pipeline ? "PIPELINE" : "SEQUENTIAL", m_chunk_n_rows, n_chunk_n_cols, + (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + + if (!use_pipeline) { + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + // transfer activation matrix chunk into VTCM + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); + + TIMER_START(activation_load); + { + const float *activation_chunk = activation + mr * k; + transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k); + } + TIMER_STOP(activation_load); + + void *buf_curr = vtcm_scratch0; + void *buf_next = vtcm_scratch1; + + { + const size_t n_cols_first = hex_smin(n, n_chunk_n_cols); + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), row_stride, row_stride, row_stride, n_cols_first); + } + + for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); + + TIMER_START(weight_load); + { + dma_queue_pop(ctx->dma[0]); // wait until current weight chunk become ready + + const size_t nc_next = nc + n_chunk_n_cols; + if (nc_next < n) { + const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols); + + const uint8_t *next_weight_chunk = permuted_weight + nc_next * row_stride; + + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), row_stride, row_stride, row_stride, n_cols_next); + } + + // Dequant + vscatter writes directly to [K, N] transposed tiles. + // HMX computes C = A x B, where A=[M,K] activation, B=[K,N] weight. + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, buf_curr, n_cols, k, row_stride, weight_type); + + hex_swap_ptr(&buf_curr, &buf_next); + } + TIMER_STOP(weight_load); + + TIMER_START(hmx_core); + { + core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32); + } + TIMER_STOP(hmx_core); + + TIMER_START(output_store); + { + float *output = dst + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output, vtcm_output, n_rows, n_cols, n); + } + TIMER_STOP(output_store); + } + } + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + } else { + // 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D) + // HMX compute (C) runs on dedicated worker thread, overlapping with HVX stages (B, D). + + // A --> B: vtcm_qweight, 1 buffer + // B --> C: vtcm_weight0/vtcm_weight1, 2 buffers + // C --> D: vtcm_output0/vtcm_output1, 2 buffers + + // Async timeline (C overlaps B+D): + // main+HVX: [A0][Act][B0][A1][sub C0][B1‖C0][A2][wait,sub C1][D0+B2‖C1][wait,sub C2][D1‖C2][wait][D2] + // HMX queue: [████ C0 ████████][████ C1 ████████████][████ C2 ████████] + + int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); + hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors + + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + + void *vtcm_qweight = vtcm_weight; + void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; + void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; + + // prologue: A0 + const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); + { + // Use 2D DMA (n_cols rows x row_stride) to avoid 16-bit roiwidth overflow. + const uint8_t *qweight_chunk_A0 = permuted_weight; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0); + } + + { + const float *activation_chunk = activation + mr * k; + transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k); + } + + // prologue: B0, A1, submit C0 (async), B1 (overlaps C0) + { + // B0: wait for DMA, dequant weight chunk 0 + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type); + + // A1: issue DMA for weight chunk 1 + const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); + if (1 < n_chunk_cnt) { + const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1); + } + + // submit C0 (non-blocking — HMX worker executes in parallel) + hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, + (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, + hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), + hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); + + // B1: DMA pop + dequant (runs in parallel with C0 on HMX worker) + if (1 < n_chunk_cnt) { + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type); + } + } + + // main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1}) + for (int i = 0; i < n_chunk_cnt; ++i) { + const size_t nc = i * n_chunk_n_cols; + const size_t nc_p1 = nc + 1 * n_chunk_n_cols; + const size_t nc_p2 = nc + 2 * n_chunk_n_cols; + + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); + const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); + + // issue A_{i+2}: DMA push (non-blocking) + if (i + 2 < n_chunk_cnt) { + const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2); + } + + // wait C_i: block until prologue/previous C completes + hmx_queue_pop(ctx->hmx_queue); + + // submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below) + // job_slots[(i+1)%2] is safe: C_i just completed, freeing slot i%2's + // counterpart — and (i+1)%2 was last used by C_{i-1} which completed + // before C_i was submitted. + if (i + 1 < n_chunk_cnt) { + hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], + (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], + vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), + hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); + } + + // D_i: store output (multi-thread HVX, parallel with C_{i+1}) + float *output_chunk = dst + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n); + + // B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1}) + if (i + 2 < n_chunk_cnt) { + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type); + } + } + } + + hmx_queue_suspend(ctx->hmx_queue); + } + + TIMER_STOP(total); + +#if defined(ENABLE_PROFILE_TIMERS) + FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d pipeline=%d", __func__, TIMER_US(total), m, k, n, use_pipeline); + if (!use_pipeline) { + FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", + TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); + size_t weight_size = (size_t)n * row_stride; + float bandwidth = 1e-3f * weight_size / (float)TIMER_US(weight_load); + FARF(HIGH, " weight load bandwidth: %.2f GB/s", bandwidth); + } +#endif + + return 0; +} + +// + +static inline int hmx_matmul_batch_r2(const hmx_matmul_w16a32_batched_params_t *params) { + return params->ne02 > 0 ? params->ne12 / params->ne02 : 1; +} + +static inline int hmx_matmul_batch_r3(const hmx_matmul_w16a32_batched_params_t *params) { + return params->ne03 > 0 ? params->ne13 / params->ne03 : 1; +} + +static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, + int dst_b2, int dst_b3) { + const int r2 = hmx_matmul_batch_r2(params); + const int r3 = hmx_matmul_batch_r3(params); + return (const __fp16 *) ((const uint8_t *) params->permuted_weight + + (size_t) (dst_b2 / r2) * params->src0_nb2 + + (size_t) (dst_b3 / r3) * params->src0_nb3); +} + +static inline const float *hmx_matmul_activation_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, + int dst_b2, int dst_b3) { + return (const float *) ((const uint8_t *) params->activation + + (size_t) dst_b2 * params->src1_nb2 + + (size_t) dst_b3 * params->src1_nb3); +} + +static inline float *hmx_matmul_dst_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, + int dst_b2, int dst_b3) { + return (float *) ((uint8_t *) params->dst + + (size_t) dst_b2 * params->dst_nb2 + + (size_t) dst_b3 * params->dst_nb3); +} + +static int hmx_mat_mul_permuted_w16a32_batched_legacy(struct htp_context *ctx, + const hmx_matmul_w16a32_batched_params_t *params) { + int ret = 0; + for (int b3 = 0; b3 < params->ne13 && ret == 0; ++b3) { + for (int b2 = 0; b2 < params->ne12 && ret == 0; ++b2) { + ret = hmx_mat_mul_permuted_w16a32(ctx, + hmx_matmul_dst_batch_ptr(params, b2, b3), + hmx_matmul_activation_batch_ptr(params, b2, b3), + hmx_matmul_weight_batch_ptr(params, b2, b3), + params->m, params->k, params->n, + params->act_stride, params->weight_stride); + } + } + return ret; +} + +int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmul_w16a32_batched_params_t *params) { + if (!ctx || !params || !params->dst || !params->activation || !params->permuted_weight) { return -1; } + if (!params->m || !params->k || !params->n) { return -1; } + if (params->act_stride < params->k || params->weight_stride < params->k || params->dst_stride < params->n) { return -1; } + if (params->ne02 <= 0 || params->ne03 <= 0 || params->ne12 <= 0 || params->ne13 <= 0) { return -1; } + if (params->ne12 % params->ne02 != 0 || params->ne13 % params->ne03 != 0) { return -1; } + if (params->k % 32 != 0 || params->n % 32 != 0) { return -1; } + + if (!hex_is_aligned(params->dst, VLEN) || + !hex_is_aligned(params->activation, VLEN) || + !hex_is_aligned(params->permuted_weight, VLEN)) { + return -1; + } + + const int group_size = hmx_matmul_batch_r2(params); + + if (group_size <= 1) { + FARF(HIGH, "%s: no dim2 GQA reuse (group=%d), using legacy batched loop", __func__, group_size); + return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); + } + + // Grouped path: reuse interleaved weight across all q_heads sharing a + // kv_head. Each q_head gets its own activation buffer in VTCM (so + // activation is loaded once per m_chunk and reused across all n_chunks), + // and each q_head is computed individually to avoid tile-major packing + // issues. m_chunk_n_rows is always a multiple of 32 (from + // hmx_compute_chunks), so per-head tile arrays don't overlap. + const size_t vtcm_budget = ctx->vtcm_size; + const size_t vec_dot_size = params->k * sizeof(__fp16); + + // When the activation has a large stride (e.g. permuted Q tensor with + // act_stride >> k), HVX vector loads from strided DDR thrash L2 cache. + // Allocate an F32 scratch buffer in VTCM and use 2D DMA to gather + // strided rows into a contiguous block before the F32->F16 conversion. + const bool use_dma_activation = (params->act_stride > params->k); + const size_t f32_scratch_per_m = use_dma_activation ? (size_t) params->k * sizeof(float) : 0; + + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; + // FP16 weight: interleave and activation load have similar per-element cost. + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, + /*per_n=*/3 * vec_dot_size, + /*per_m=*/group_size * vec_dot_size + f32_scratch_per_m, + /*per_mn=*/sizeof(__fp16), params->m, params->n, + /*m_block_cost=*/(size_t) params->n, + /*n_block_cost=*/(size_t) params->m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { + FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__); + return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); + } + + const size_t act_head_stride = m_chunk_n_rows * (size_t) params->k; // fp16 elements between heads + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t activation_area_size = hex_align_up(group_size * m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t f32_scratch_size = use_dma_activation + ? hex_align_up(m_chunk_n_rows * (size_t) params->k * sizeof(float), HMX_FP16_TILE_SIZE) : 0; + + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; + + if ((size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base) > vtcm_budget) { + FARF(HIGH, "%s: grouped layout overflowed VTCM, falling back to legacy batched loop", __func__); + return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); + } + + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 + + FARF(HIGH, "%s: grouped path m=%d k=%d n=%d group=%d streams=%d mc=%zu nc=%zu vtcm=%zu/%zu", + __func__, params->m, params->k, params->n, group_size, params->ne13, + m_chunk_n_rows, n_chunk_n_cols, + (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); + + TIMER_DEFINE(activation_load); + TIMER_DEFINE(weight_load); + TIMER_DEFINE(hmx_core); + TIMER_DEFINE(output_store); + TIMER_DEFINE(total); + + TIMER_START(total); + + const size_t fp16_row_bytes = (size_t) params->k * sizeof(__fp16); + const size_t weight_row_bytes = (size_t) params->weight_stride * sizeof(__fp16); + + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + + for (int b3 = 0; b3 < params->ne13; ++b3) { + for (int b2_base = 0; b2_base < params->ne12; b2_base += group_size) { + const __fp16 *weight_group = hmx_matmul_weight_batch_ptr(params, b2_base, b3); + + for (size_t mr = 0; mr < (size_t) params->m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin((size_t) params->m - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div((int) n_rows, HMX_FP16_TILE_N_ROWS); + + // Pre-load activations for all heads in the group (once per m_chunk). + // When the source is strided (permuted Q), use 2D DMA to gather + // contiguous rows into a VTCM scratch buffer first, then HVX + // converts from the contiguous VTCM buffer. This avoids L2 cache + // thrashing from HVX loads at large strides. + TIMER_START(activation_load); + for (int g = 0; g < group_size; ++g) { + const float *activation_chunk = hmx_matmul_activation_batch_ptr(params, b2_base + g, b3) + mr * params->act_stride; + __fp16 *vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; + if (use_dma_activation) { + const size_t row_bytes = (size_t) params->k * sizeof(float); + const size_t stride_bytes = (size_t) params->act_stride * sizeof(float); + dma_queue_push(ctx->dma[0], + dma_make_ptr(vtcm_f32_act, activation_chunk), + row_bytes, stride_bytes, row_bytes, n_rows); + dma_queue_pop(ctx->dma[0]); + transfer_activation_chunk_threaded(ctx, vtcm_act_g, + vtcm_f32_act, (int) n_rows, + params->k, params->k); + } else { + transfer_activation_chunk_threaded(ctx, vtcm_act_g, + activation_chunk, (int) n_rows, + params->k, params->act_stride); + } + } + TIMER_STOP(activation_load); + + void *buf_curr = vtcm_scratch0; + void *buf_next = vtcm_scratch1; + + { + const size_t n_cols_first = hex_smin((size_t) params->n, n_chunk_n_cols); + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, weight_group), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); + } + + for (size_t nc = 0; nc < (size_t) params->n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin((size_t) params->n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div((int) n_cols, HMX_FP16_TILE_N_COLS); + + TIMER_START(weight_load); + { + dma_queue_pop(ctx->dma[0]); + + const size_t nc_next = nc + n_chunk_n_cols; + if (nc_next < (size_t) params->n) { + const size_t n_cols_next = hex_smin((size_t) params->n - nc_next, n_chunk_n_cols); + const __fp16 *next_weight_chunk = weight_group + nc_next * params->weight_stride; + + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); + } + + hmx_interleave_rows_to_tiles(vtcm_weight, (const __fp16 *) buf_curr, n_cols, params->k, params->k, + 0, n_cols); + hex_swap_ptr(&buf_curr, &buf_next); + } + TIMER_STOP(weight_load); + + // Reuse the interleaved weight for every q_head in this GQA group + for (int g = 0; g < group_size; ++g) { + TIMER_START(hmx_core); + { + const __fp16 * vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; + core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, + params->k / 32); + } + TIMER_STOP(hmx_core); + + TIMER_START(output_store); + { + float *output = hmx_matmul_dst_batch_ptr(params, b2_base + g, b3) + mr * params->dst_stride + nc; + transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride); + } + TIMER_STOP(output_store); + } + } + } + } + } + + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + + TIMER_STOP(total); + +#if defined(ENABLE_PROFILE_TIMERS) + FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d group=%d", __func__, TIMER_US(total), + params->m, params->k, params->n, group_size); + FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", + TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); +#endif + + return 0; +} + +// + +int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, + const __fp16 *restrict permuted_weight, int m, int k, int n, + int act_stride, int weight_stride) { + if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } + if (act_stride < k || weight_stride < k) { return -1; } + if (k % 32 != 0 || n % 32 != 0) { return -1; } + + if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { + return -1; + } + + // --- Dynamic VTCM layout --- + const size_t vtcm_budget = ctx->vtcm_size; + const size_t vec_dot_size = k * sizeof(__fp16); + + // DMA-based activation gather for strided tensors (see batched path comment). + const bool use_dma_activation = (act_stride > k); + const size_t f32_scratch_per_m = use_dma_activation ? (size_t) k * sizeof(float) : 0; + + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; + // FP16 weight: interleave and activation load have similar per-element cost. + if (hmx_compute_chunks(vtcm_budget, + /*overhead=*/256, + /*per_n=*/3 * vec_dot_size, // W + S0 + S1 + /*per_m=*/vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch + /*per_mn=*/sizeof(__fp16), // O + m, n, + /*m_block_cost=*/(size_t) n, + /*n_block_cost=*/(size_t) m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { + FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); + return -1; + } + + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t f32_scratch_size = use_dma_activation + ? hex_align_up(m_chunk_n_rows * (size_t) k * sizeof(float), HMX_FP16_TILE_SIZE) : 0; + + // VTCM layout: weight | activation | output | scratch0 | scratch1 | scales | [f32_scratch] + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; + if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) { + FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__, + (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + return -1; + } + + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 + + FARF(HIGH, "%s: m=%d k=%d n=%d mc=%zu nc=%zu vtcm=%zu/%zu", + __func__, m, k, n, m_chunk_n_rows, n_chunk_n_cols, + (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + + TIMER_DEFINE(activation_load); + TIMER_DEFINE(weight_load); + TIMER_DEFINE(hmx_core); + TIMER_DEFINE(output_store); + + TIMER_DEFINE(total); + TIMER_START(total); + + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + // transfer activation matrix chunk into VTCM + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); + + TIMER_START(activation_load); + { + const float *activation_chunk = activation + mr * act_stride; + if (use_dma_activation) { + const size_t row_bytes = (size_t) k * sizeof(float); + const size_t stride_bytes = (size_t) act_stride * sizeof(float); + dma_queue_push(ctx->dma[0], + dma_make_ptr(vtcm_f32_act, activation_chunk), + row_bytes, stride_bytes, row_bytes, n_rows); + dma_queue_pop(ctx->dma[0]); + transfer_activation_chunk_threaded(ctx, vtcm_activation, + vtcm_f32_act, n_rows, k, k); + } else { + transfer_activation_chunk_threaded(ctx, vtcm_activation, + activation_chunk, n_rows, k, act_stride); + } + } + TIMER_STOP(activation_load); + + const size_t fp16_row_bytes = (size_t) k * sizeof(__fp16); + const size_t weight_row_bytes = (size_t) weight_stride * sizeof(__fp16); + + void *buf_curr = vtcm_scratch0; + void *buf_next = vtcm_scratch1; + + // issue async DMA for the first weight chunk + // NOTE: use 2D DMA (n_cols rows x fp16_row_bytes) to avoid 16-bit roiwidth overflow. + // The source rows can be strided (e.g. KV-cache K after ggml_permute). + { + const size_t n_cols_first = hex_smin(n, n_chunk_n_cols); + + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); + } + + for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); + + TIMER_START(weight_load); + { + dma_queue_pop(ctx->dma[0]); // wait until current weight chunk is ready + + // issue async DMA for the next weight chunk (double buffering) + const size_t nc_next = nc + n_chunk_n_cols; + if (nc_next < n) { + const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols); + const __fp16 *next_weight_chunk = permuted_weight + nc_next * weight_stride; + + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); + } + + // interleave row-major fp16 from scratch into tile-major in vtcm_weight + hmx_interleave_rows_to_tiles(vtcm_weight, (const __fp16 *) buf_curr, n_cols, k, k, 0, n_cols); + + hex_swap_ptr(&buf_curr, &buf_next); + } + TIMER_STOP(weight_load); + + TIMER_START(hmx_core); + { + core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32); + } + TIMER_STOP(hmx_core); + + TIMER_START(output_store); + { + float *output = dst + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output, vtcm_output, n_rows, n_cols, n); + } + TIMER_STOP(output_store); + } + + } + + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + + TIMER_STOP(total); + +#if defined(ENABLE_PROFILE_TIMERS) + FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d", __func__, TIMER_US(total), m, k, n); + FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", + TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); + { + size_t weight_size = (size_t)k * n * sizeof(__fp16); + float bandwidth = 1e-3f * weight_size / (float)TIMER_US(weight_load); + FARF(HIGH, " weight load bandwidth: %.2f GB/s", bandwidth); + } +#endif + + return 0; +} diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.h b/ggml/src/ggml-hexagon/htp/hmx-ops.h new file mode 100644 index 00000000000..1c78ffadd1c --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-ops.h @@ -0,0 +1,71 @@ +// HMX operation entry-point declarations. +// Ported from htp-ops-lib/include/dsp/ops.h (renamed, benchmark kernels removed). (https://github.com/haozixu/htp-ops-lib) + +#ifndef HMX_OPS_H +#define HMX_OPS_H + +#include +#include + +#include "htp-ops.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct { + float *dst; + const float *activation; + const __fp16 *permuted_weight; + int m; + int k; + int n; + int act_stride; + int weight_stride; + int dst_stride; + int ne02; + int ne03; + int ne12; + int ne13; + size_t src0_nb2; + size_t src0_nb3; + size_t src1_nb2; + size_t src1_nb3; + size_t dst_nb2; + size_t dst_nb3; +} hmx_matmul_w16a32_batched_params_t; + +// HMX matrix multiplication — tile-permuted FP16 weights, FP32 activation/output +// act_stride: activation row stride in elements (= k for contiguous, or +// nb[1]/sizeof(float) for permuted tensors like attention Q). +// weight_stride: weight row stride in elements (= k for compact weights, or +// nb[1]/sizeof(__fp16) for permuted KV-cache views used by QK). +int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, + float *restrict dst, + const float *activation, + const __fp16 *permuted_weight, + int m, int k, int n, + int act_stride, + int weight_stride); + +// Batched F16 wrapper over hmx_mat_mul_permuted_w16a32. +// Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3. +int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, + const hmx_matmul_w16a32_batched_params_t *params); + +// HMX matrix multiplication — tile-permuted quantised weights (Q4_0/Q8_0/IQ4_NL) +int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, + float *restrict dst, + const float *activation, + const uint8_t *permuted_weight, + int m, int k, int n, + int weight_type); + +// HMX flash attention +int hmx_flash_attn_ext(struct htp_ops_context * octx); + +#ifdef __cplusplus +} +#endif + +#endif // HMX_OPS_H diff --git a/ggml/src/ggml-hexagon/htp/hmx-profile.h b/ggml/src/ggml-hexagon/htp/hmx-profile.h new file mode 100644 index 00000000000..01eece720c5 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-profile.h @@ -0,0 +1,34 @@ +// Conditional fine-grained profiling macros for HMX operations. +// +// Define ENABLE_PROFILE_TIMERS (via compiler flag or before including this +// header) to instrument sub-operation latencies with HAP qtimer. When the +// macro is not defined the TIMER_* helpers expand to nothing so there is zero +// overhead. +// +// Usage: +// TIMER_DEFINE(my_phase); // declare accumulator variable +// TIMER_START(my_phase); // snapshot start time +// ... work ... +// TIMER_STOP(my_phase); // accumulate elapsed ticks +// FARF(ALWAYS, "my_phase: %lld us", TIMER_US(my_phase)); + +#ifndef HMX_PROFILE_H +#define HMX_PROFILE_H + +#include + +// #define ENABLE_PROFILE_TIMERS + +#if defined(ENABLE_PROFILE_TIMERS) +# define TIMER_DEFINE(name) int64_t name##_ticks = 0 +# define TIMER_START(name) int64_t name##_t0 = HAP_perf_get_qtimer_count() +# define TIMER_STOP(name) name##_ticks += HAP_perf_get_qtimer_count() - name##_t0 +# define TIMER_US(name) HAP_perf_qtimer_count_to_us(name##_ticks) +#else +# define TIMER_DEFINE(name) +# define TIMER_START(name) +# define TIMER_STOP(name) +# define TIMER_US(name) 0LL +#endif + +#endif // HMX_PROFILE_H diff --git a/ggml/src/ggml-hexagon/htp/hmx-queue.c b/ggml/src/ggml-hexagon/htp/hmx-queue.c new file mode 100644 index 00000000000..5b1d83a0cbf --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-queue.c @@ -0,0 +1,158 @@ +#pragma clang diagnostic ignored "-Wunused-function" + +#include +#include +#include + +#include +#include + +#include + +#include "hmx-queue.h" + +#define QURT_LOWEST_PRIO (254) + +static inline void hmx_lock(struct hmx_queue *q) +{ + if (!q->hmx_locked) { + HAP_compute_res_hmx_lock(q->hap_rctx); + q->hmx_locked = true; + } +} + +static inline void hmx_unlock(struct hmx_queue *q) +{ + if (q->hmx_locked) { + HAP_compute_res_hmx_unlock(q->hap_rctx); + q->hmx_locked = false; + } +} + +static inline void hmx_queue_process(struct hmx_queue *q, bool* killed) { + unsigned int ir = atomic_load(&q->idx_read); + + while (ir != atomic_load(&q->idx_write)) { + struct hmx_queue_desc *d = &q->desc[ir]; + if (!d->done) { + FARF(HIGH, "hmx-queue-process: ir %u func %p data %p", ir, d->func, d->data); + + enum hmx_queue_signal sig = (enum hmx_queue_signal) (unsigned int) d->func; + switch (sig) { + case HMX_QUEUE_NOOP: /* noop */; break; + case HMX_QUEUE_KILL: *killed = true; break; + case HMX_QUEUE_SUSPEND: hmx_unlock(q); break; + default: + hmx_lock(q); + d->func(d->data); + break; + } + + atomic_fetch_add(&d->done, 1); + } + + ir = (ir + 1) & q->idx_mask; + atomic_store(&q->idx_read, ir); + } +} + +static void hmx_queue_thread(void * arg) { + struct hmx_queue * q = (struct hmx_queue *) arg; + + FARF(HIGH, "hmx-queue-thread: started"); + + bool killed = false; + + unsigned int poll_cnt = HMX_QUEUE_POLL_COUNT; + unsigned int prev_seqn = 0; + while (!killed) { + unsigned int seqn = atomic_load(&q->seqn); + if (seqn == prev_seqn) { + if (--poll_cnt) { hex_pause(); continue; } + FARF(HIGH, "hmx-queue-thread: sleeping"); + qurt_futex_wait(&q->seqn, prev_seqn); + continue; + } + prev_seqn = seqn; + poll_cnt = HMX_QUEUE_POLL_COUNT; + + FARF(HIGH, "hmx-queue-thread: new work"); + + hmx_queue_process(q, &killed); + } + + FARF(HIGH, "hmx-queue-thread: stopped"); +} + +struct hmx_queue * hmx_queue_create(size_t capacity, uint32_t hap_rctx) { + capacity = hex_ceil_pow2(capacity); + + struct hmx_queue * q = (struct hmx_queue *) memalign(32, sizeof(struct hmx_queue)); + if (q == NULL) { + FARF(ERROR, "%s: failed to allocate DMA queue\n", __FUNCTION__); + return NULL; + } + memset(q, 0, sizeof(struct hmx_queue)); + q->capacity = capacity; + q->idx_mask = capacity - 1; + q->hap_rctx = hap_rctx; + + q->desc = (struct hmx_queue_desc *) memalign(64, capacity * sizeof(struct hmx_queue_desc)); + if (!q->desc) { + FARF(ERROR, "hmx-queue: failed to allocate HMX queue descriptors\n"); + return NULL; + } + memset(q->desc, 0, capacity * sizeof(struct hmx_queue_desc)); + + const size_t stack_size = HMX_QUEUE_THREAD_STACK_SIZE; + q->stack = (unsigned char *) memalign(64, stack_size); + if (!q->stack) { + FARF(ERROR, "hmx-queue: thread stack allocation failed (%zu bytes)", stack_size); + return NULL; + } + memset(q->stack, 0, stack_size); + + // Match caller thread priority (same pattern as worker-pool.c). + int prio = qurt_thread_get_priority(qurt_thread_get_id()); + if (prio < 1) { + prio = 1; + } + if (prio > QURT_LOWEST_PRIO) { + prio = QURT_LOWEST_PRIO; + } + + qurt_thread_attr_t attr; + qurt_thread_attr_init(&attr); + qurt_thread_attr_set_stack_addr(&attr, q->stack); + qurt_thread_attr_set_stack_size(&attr, stack_size); + qurt_thread_attr_set_priority(&attr, prio); + qurt_thread_attr_set_name(&attr, "hmx-queue"); + + int err = qurt_thread_create(&q->thread, &attr, hmx_queue_thread, q); + if (err) { + FARF(ERROR, "hmx-worker: thread create failed (%d)", err); + return NULL; + } + + FARF(HIGH, "hmx-queue: capacity %u\n", capacity); + + return q; +} + +void hmx_queue_delete(struct hmx_queue * q) { + if (!q) { + return; + } + + // Tell the worker to exit. + hmx_queue_flush(q); + hmx_queue_signal(q, HMX_QUEUE_KILL); + hmx_queue_flush(q); + + int status; + qurt_thread_join(q->thread, &status); + + free(q->desc); + free(q->stack); + free(q); +} diff --git a/ggml/src/ggml-hexagon/htp/hmx-queue.h b/ggml/src/ggml-hexagon/htp/hmx-queue.h new file mode 100644 index 00000000000..0d48c280f52 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-queue.h @@ -0,0 +1,134 @@ +#ifndef HMX_QUEUE_H +#define HMX_QUEUE_H + +#include +#include +#include + +#include +#include +#include +#include + +#include "hex-utils.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define HMX_QUEUE_THREAD_STACK_SIZE (16 * 1024) +#define HMX_QUEUE_POLL_COUNT 2000 + +typedef void (*hmx_queue_func)(void *); + +// Dummy funcs used as signals +enum hmx_queue_signal { + HMX_QUEUE_NOOP = 0, // aka NULL + HMX_QUEUE_SUSPEND, + HMX_QUEUE_KILL +}; + +struct hmx_queue_desc { + hmx_queue_func func; + void * data; + atomic_uint done; +}; + +struct hmx_queue { + struct hmx_queue_desc * desc; + atomic_uint idx_write; // updated by producer (push) + atomic_uint idx_read; // updated by consumer (process) + unsigned int idx_pop; // updated by producer (pop) + uint32_t idx_mask; + uint32_t capacity; + + atomic_uint seqn; // incremented for all pushes, used with futex + qurt_thread_t thread; + void * stack; + uint32_t hap_rctx; + bool hmx_locked; +}; + +struct hmx_queue * hmx_queue_create(size_t capacity, uint32_t hap_rctx); +void hmx_queue_delete(struct hmx_queue * q); + +static inline struct hmx_queue_desc hmx_queue_make_desc(hmx_queue_func func, void * data) { + struct hmx_queue_desc d = { func, data }; + return d; +} + +static inline bool hmx_queue_push(struct hmx_queue * q, struct hmx_queue_desc d) { + unsigned int ir = atomic_load(&q->idx_read); + unsigned int iw = q->idx_write; + + if (((iw + 1) & q->idx_mask) == ir) { + FARF(HIGH, "hmx-queue-push: queue is full\n"); + return false; + } + + atomic_store(&d.done, 0); + + FARF(HIGH, "hmx-queue-push: iw %u func %p data %p\n", iw, d.func, d.data); + + q->desc[iw] = d; + atomic_store(&q->idx_write, (iw + 1) & q->idx_mask); + // wake up our thread + atomic_fetch_add(&q->seqn, 1); + qurt_futex_wake(&q->seqn, 1); + + return true; +} + +static inline bool hmx_queue_signal(struct hmx_queue *q, enum hmx_queue_signal sig) { + return hmx_queue_push(q, hmx_queue_make_desc((hmx_queue_func) sig, NULL)); +} + +static inline bool hmx_queue_empty(struct hmx_queue * q) { + return q->idx_pop == q->idx_write; +} + +static inline uint32_t hmx_queue_depth(struct hmx_queue * q) { + return (q->idx_read - q->idx_read) & q->idx_mask; +} + +static inline uint32_t hmx_queue_capacity(struct hmx_queue * q) { + return q->capacity; +} + +static inline struct hmx_queue_desc hmx_queue_pop(struct hmx_queue * q) { + unsigned int ip = q->idx_pop; + unsigned int iw = q->idx_write; + + struct hmx_queue_desc rd = { NULL, NULL }; + if (ip == iw) { + return rd; + } + + // Wait for desc to complete + struct hmx_queue_desc * d = &q->desc[ip]; + while (!atomic_load(&d->done)) { + FARF(HIGH, "hmx-queue-pop: waiting for HMX queue : %u\n", ip); + hex_pause(); + } + + rd = *d; + q->idx_pop = (ip + 1) & q->idx_mask; + + FARF(HIGH, "hmx-queue-pop: ip %u func %p data %p\n", ip, rd.func, rd.data); + return rd; +} + +static inline void hmx_queue_flush(struct hmx_queue * q) { + while (hmx_queue_pop(q).func != NULL) ; +} + +static inline void hmx_queue_suspend(struct hmx_queue *q) { + hmx_queue_signal(q, HMX_QUEUE_SUSPEND); + hmx_queue_flush(q); +} + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif /* HMX_QUEUE_H */ diff --git a/ggml/src/ggml-hexagon/htp/hmx-utils.h b/ggml/src/ggml-hexagon/htp/hmx-utils.h new file mode 100644 index 00000000000..68f174d6937 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-utils.h @@ -0,0 +1,202 @@ +// HMX tile-level inline helpers (FP16 32x32 tile operations). +// Ported from htp-ops-lib/include/dsp/hmx_utils.h. (https://github.com/haozixu/htp-ops-lib) + +#ifndef HMX_UTILS_H +#define HMX_UTILS_H + +#include "hvx-base.h" + +#include +#include +#include + +#define HMX_FP16_TILE_N_ROWS 32 +#define HMX_FP16_TILE_N_COLS 32 +#define HMX_FP16_TILE_N_ELMS 1024 +#define HMX_FP16_TILE_SIZE 2048 + +// Initialise aligned 256-byte area with scale vector + zero padding. +static inline void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) { + volatile HVX_Vector *pv = (HVX_Vector *) out_scales; + pv[0] = v_scale; + pv[1] = Q6_V_vzero(); +} + +// --- Shared scatter offsets and interleave helper --- + +// vscatter offsets for fused dequant+transpose: write K-values directly to [K][N] tile. +// word[i] = i*128 maps K-row-pair i to byte offset i*128. +// Column offset (n*4) is added at runtime. Entries 0..15 cover one tile (region 2047); +// entries 16..31 cover the next adjacent tile (region 4095) — pick region size at the +// call site to scatter into one tile (masked) or two contiguous tiles (unmasked). +static const int32_t hmx_transpose_scatter_offsets[32] __attribute__((aligned(VLEN))) = { + 0 * 128, 1 * 128, 2 * 128, 3 * 128, 4 * 128, 5 * 128, 6 * 128, 7 * 128, 8 * 128, 9 * 128, 10 * 128, + 11 * 128, 12 * 128, 13 * 128, 14 * 128, 15 * 128, 16 * 128, 17 * 128, 18 * 128, 19 * 128, 20 * 128, 21 * 128, + 22 * 128, 23 * 128, 24 * 128, 25 * 128, 26 * 128, 27 * 128, 28 * 128, 29 * 128, 30 * 128, 31 * 128, +}; + +// Scatter row-major FP16 data (in VTCM scratch) into transposed [K][N] tiles. +// vtcm_src: [n_cols][src_stride] row-major fp16 (only first k elements per row are used) +// vtcm_dst: [n_col_tiles][n_k_tiles][HMX_FP16_TILE_N_ELMS] tile-major interleaved fp16 +// Processes rows [start_row, end_row) for multi-thread slicing. +// Full range: start_row=0, end_row=n_cols. +static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst, + const __fp16 * restrict vtcm_src, + int n_cols, + int k, + int src_stride, + int start_row, + int end_row) { + assert(k % HMX_FP16_TILE_N_COLS == 0); + + const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + // Each hvx_vmemu load brings 64 fp16 = 128 bytes covering 2 adjacent K-tiles. + // When n_k_tiles is even, scatter into 2 K-tiles per call (region 4095, no mask) + // using the upper half of hmx_transpose_scatter_offsets. Tail one K-tile (when + // n_k_tiles is odd) falls back to single-tile masked scatter. + const bool pair_scatter = (n_k_tiles & 1) == 0; + const size_t pair_region = (size_t) (2 * HMX_FP16_TILE_SIZE - 1); + const size_t single_region = (size_t) (HMX_FP16_TILE_SIZE - 1); + __builtin_assume(k > 0); + __builtin_assume(end_row > start_row); + + if (pair_scatter) { + // Step c by 64 fp16 (two K-tiles per scatter), advance dst by 2 tiles per iter. + const int c_step = 2 * HMX_FP16_TILE_N_COLS; + const size_t c_byte_step = (size_t) c_step * sizeof(__fp16); + const size_t dst_step = 2 * (size_t) HMX_FP16_TILE_N_ELMS; + const int n_c_iters = k / c_step; + + for (int r = start_row; r < end_row; r += 2) { + const int ct = r / HMX_FP16_TILE_N_ROWS; + const int local_r = r % HMX_FP16_TILE_N_ROWS; + const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols; + const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4)); + const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step); + + __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS; + const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride); + const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL; + + if (p1) { + for (int i = 0; i < n_c_iters; ++i) { + HVX_Vector v0 = hvx_vmemu(p0); + p0 += c_byte_step; + HVX_Vector v1 = hvx_vmemu(p1); + p1 += c_byte_step; + Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0); + Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, v1); + tile_base += dst_step; + } + } else { + const HVX_Vector vzero = Q6_V_vzero(); + for (int i = 0; i < n_c_iters; ++i) { + HVX_Vector v0 = hvx_vmemu(p0); + p0 += c_byte_step; + Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0); + Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, vzero); + tile_base += dst_step; + } + } + } + } else { + // Fallback: scatter one K-tile per call (region 2047, masked). + const int c_step = HMX_FP16_TILE_N_COLS; + const size_t c_byte_step = (size_t) c_step * sizeof(__fp16); + const size_t dst_step = (size_t) HMX_FP16_TILE_N_ELMS; + const int n_c_iters = k / c_step; + + for (int r = start_row; r < end_row; r += 2) { + const int ct = r / HMX_FP16_TILE_N_ROWS; + const int local_r = r % HMX_FP16_TILE_N_ROWS; + const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols; + const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4)); + const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step); + + __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS; + const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride); + const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL; + + if (p1) { + for (int i = 0; i < n_c_iters; ++i) { + HVX_Vector v0 = hvx_vmemu(p0); + p0 += c_byte_step; + HVX_Vector v1 = hvx_vmemu(p1); + p1 += c_byte_step; + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0); + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, v1); + tile_base += dst_step; + } + } else { + const HVX_Vector vzero = Q6_V_vzero(); + for (int i = 0; i < n_c_iters; ++i) { + HVX_Vector v0 = hvx_vmemu(p0); + p0 += c_byte_step; + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0); + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, vzero); + tile_base += dst_step; + } + } + } + } +} + +// Interleave row-major FP16 data into column-major tile format. +// Input: [n_rows, head_dim] row-major. Output: tile[dim_tile][row_tile]. +// Processes rows [start_row, end_row) for multi-thread slicing. +// Full range: start_row=0, end_row=n_rows. +static inline void hmx_interleave_cols_to_tiles(__fp16 * restrict tiles_out, + const __fp16 * restrict src, + int n_rows, + int head_dim, + int src_stride, + int n_row_tiles, + int start_row, + int end_row) { + __builtin_assume(head_dim > 0); + const size_t tile_stride_elms = (size_t) n_row_tiles * HMX_FP16_TILE_N_ELMS; + + for (int r = start_row; r < end_row; r += 2) { + const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_rows; + + const HVX_Vector * pv_in0 = (const HVX_Vector *) (src + r * src_stride); + const HVX_Vector * pv_in1 = next_row_valid ? (const HVX_Vector *) (src + (r + 1) * src_stride) : NULL; + + // Row-pair invariants hoisted out of the c loop. + const int r0 = r / HMX_FP16_TILE_N_ROWS; + const int r1_half = (r % HMX_FP16_TILE_N_ROWS) / 2; + + // tb0 starts at tile (c0=0, r0); tb1 at the adjacent dim-tile (c0=1, r0). + // Each c step (+= 64) advances both by 2 dim-tiles worth of fp16. + __fp16 * tb0 = tiles_out + (size_t) r0 * HMX_FP16_TILE_N_ELMS; + __fp16 * tb1 = tb0 + tile_stride_elms; + const size_t tb_step = 2 * tile_stride_elms; + + if (pv_in1) { + for (int c = 0; c < head_dim; c += 64) { + HVX_Vector v0 = *pv_in0++; + HVX_Vector v1 = *pv_in1++; + HVX_VectorPair vp = Q6_W_vshuff_VVR(v1, v0, -2); + ((HVX_Vector *) tb0)[r1_half] = Q6_V_lo_W(vp); + ((HVX_Vector *) tb1)[r1_half] = Q6_V_hi_W(vp); + tb0 += tb_step; + tb1 += tb_step; + } + } else { + const HVX_Vector vzero = Q6_V_vzero(); + for (int c = 0; c < head_dim; c += 64) { + HVX_Vector v0 = *pv_in0++; + HVX_VectorPair vp = Q6_W_vshuff_VVR(vzero, v0, -2); + ((HVX_Vector *) tb0)[r1_half] = Q6_V_lo_W(vp); + ((HVX_Vector *) tb1)[r1_half] = Q6_V_hi_W(vp); + tb0 += tb_step; + tb1 += tb_step; + } + } + } +} + +#endif // HMX_UTILS_H diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index a707d98239c..e9c563ca887 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -2,34 +2,109 @@ #define HTP_CTX_H #include "hex-dma.h" +#include "hmx-queue.h" +#include "htp-ops.h" #include "worker-pool.h" #include #include #include #include +#include #define HTP_MAX_NTHREADS 10 +#define HTP_MAX_MMAPS 16 + +// Memory mapping +struct htp_mmap { + uint64_t size; + uint64_t base; + uint32_t fd; + uint32_t reserved; +}; + +// Scratchpad state +struct htp_spad { + const struct htp_tensor * src; // original src of the data (for reuse) + uint8_t * data; // pointer to an area in vtcm + uint32_t stride; // stride used inside this spad + uint32_t size; // total size + uint32_t size_per_thread; // size per thread +}; + +struct htp_context; + +// Context while processing an Op +// TODO: fold this into the main context +struct htp_ops_context { + struct htp_context * ctx; + + enum htp_op_code op; // FIXME: rename to opcode + int32_t op_params[HTP_OP_MAX_PARAMS]; + + const struct htp_tensor * src[HTP_OP_MAX_INPUTS]; + const struct htp_tensor * dst; + + // TODO convert these to an array + struct htp_spad src0_spad; + struct htp_spad src1_spad; + struct htp_spad src2_spad; + struct htp_spad src3_spad; + struct htp_spad dst_spad; + + uint32_t n_threads; + uint32_t flags; +}; // Main context for htp DSP backend struct htp_context { - dspqueue_t queue; - dma_queue * dma[HTP_MAX_NTHREADS]; - worker_pool_context_t worker_pool; - uint32_t n_threads; + dspqueue_t queue; + dma_queue * dma[HTP_MAX_NTHREADS]; + struct htp_mmap mmap[HTP_MAX_MMAPS]; + worker_pool_context_t worker_pool; + uint32_t n_threads; - int thread_id; - int thread_prio; + int thread_id; + int thread_prio; - uint8_t * vtcm_base; - size_t vtcm_size; - uint32_t vtcm_rctx; + bool hmx_enabled; + bool etm; + uint32_t profiler; - atomic_bool vtcm_valid; - atomic_bool vtcm_inuse; - atomic_bool vtcm_needs_release; + uint8_t * vtcm_base; + size_t vtcm_size; + uint32_t vtcm_rctx; + atomic_bool vtcm_valid; + atomic_bool vtcm_needs_release; - uint32_t opmask; + uint64_t max_vmem; + + struct htp_ops_context octx; + +#ifdef HTP_HAS_HMX + struct hmx_queue * hmx_queue; // Async HMX queue for pipeline overlap +#endif }; +int op_matmul(struct htp_ops_context * octx); +int op_matmul_id(struct htp_ops_context * octx); +int op_binary(struct htp_ops_context * octx); +int op_unary(struct htp_ops_context * octx); +int op_sum_rows(struct htp_ops_context * octx); +int op_activations(struct htp_ops_context * octx); +int op_softmax(struct htp_ops_context * octx); +int op_add_id(struct htp_ops_context * octx); +int op_rope(struct htp_ops_context * octx); +int op_flash_attn_ext(struct htp_ops_context * octx); +int op_set_rows(struct htp_ops_context * octx); +int op_get_rows(struct htp_ops_context * octx); +int op_cpy(struct htp_ops_context * octx); +int op_repeat(struct htp_ops_context * octx); +int op_argsort(struct htp_ops_context * octx); +int op_ssm_conv(struct htp_ops_context * octx); +int op_cumsum(struct htp_ops_context * octx); +int op_fill(struct htp_ops_context * octx); +int op_diag(struct htp_ops_context * octx); +int op_solve_tri(struct htp_ops_context * octx); + #endif /* HTP_CTX_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h deleted file mode 100644 index 52dcc36d8f7..00000000000 --- a/ggml/src/ggml-hexagon/htp/htp-msg.h +++ /dev/null @@ -1,155 +0,0 @@ -#ifndef HTP_MSG_H -#define HTP_MSG_H - -#include - -// ggml-common.h must be included prio to this header - -// Mask to enable various stages of the Ops. -// Used for debugging and profiling. -enum { - HTP_OPMASK_QUEUE = (1 << 0), // Enable Queueing (ie calls into the DSP) - HTP_OPMASK_QUANTIZE = (1 << 1), // Enable Quantize - HTP_OPMASK_COMPUTE = (1 << 2), // Enable Compute -}; - -// Op flags -enum { - HTP_OPFLAGS_SKIP_QUANTIZE = (1 << 0), // Skip dynamic quantization (reuse quantized tensors) - HTP_OPFLAGS_SKIP_COMPUTE = (1 << 1), // Skip actual computation (used for profiling) - HTP_OPFLAGS_EARLY_WAKEUP = (1 << 2) // Send early wakeup notification -}; - -enum htp_status { - HTP_STATUS_OK = 1, - HTP_STATUS_INTERNAL_ERR = 2, - HTP_STATUS_NO_SUPPORT = 3, - HTP_STATUS_INVAL_PARAMS = 4, - HTP_STATUS_VTCM_TOO_SMALL = 5, -}; - -// The values must match the ggml_type. -// Duplicated here because we can't include full ggml.h in the htp build. -// We have some static_asserts in the cpp code to ensure things are in sync. -enum htp_data_type { - HTP_TYPE_F32 = 0, - HTP_TYPE_F16 = 1, - HTP_TYPE_Q4_0 = 2, - HTP_TYPE_Q8_0 = 8, - HTP_TYPE_I32 = 26, - HTP_TYPE_I64 = 27, - HTP_TYPE_MXFP4 = 39, - HTP_TYPE_COUNT -}; - -// Do not reorder first 4 (used as an index) -enum htp_op { - HTP_OP_MUL = 0, - HTP_OP_ADD = 1, - HTP_OP_SUB = 2, - HTP_OP_DIV = 3, - HTP_OP_MUL_MAT, - HTP_OP_MUL_MAT_ID, - HTP_OP_RMS_NORM, - HTP_OP_UNARY_SILU, - HTP_OP_UNARY_GELU, - HTP_OP_GLU_SWIGLU, - HTP_OP_GLU_SWIGLU_OAI, - HTP_OP_GLU_GEGLU, - HTP_OP_SOFTMAX, - HTP_OP_ADD_ID, - HTP_OP_ROPE, - HTP_OP_FLASH_ATTN_EXT, - HTP_OP_SET_ROWS, - HTP_OP_GET_ROWS, - HTP_OP_SCALE, - HTP_OP_CPY, - HTP_OP_ARGSORT, - HTP_OP_SQR, - HTP_OP_SQRT, - HTP_OP_SUM_ROWS, - HTP_OP_SSM_CONV, - INVALID -}; - -static inline size_t htp_t_block_size(uint32_t t) { - switch (t) { - case HTP_TYPE_F32: - return 1; - case HTP_TYPE_F16: - return 1; - case HTP_TYPE_Q4_0: - return QK4_0; - case HTP_TYPE_Q8_0: - return QK8_0; - case HTP_TYPE_MXFP4: - return QK_MXFP4; - default: - assert(0 && "unsupported HTP data type"); - } - return 0; -} - -static inline size_t htp_type_nbytes(uint32_t t) { - switch (t) { - case HTP_TYPE_F32: - return 4; - case HTP_TYPE_F16: - return 2; - case HTP_TYPE_Q4_0: - return sizeof(block_q4_0); - case HTP_TYPE_Q8_0: - return sizeof(block_q8_0); - case HTP_TYPE_MXFP4: - return sizeof(block_mxfp4); - default: - assert(0 && "unsupported HTP data type"); - } - return 0; -} - -// Internal types -#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128) -#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks -#define QK_MXFP4x4x2 256 // 4x MXFP4 blocks concat with next 4x MXFP4 blocks - -#define HTP_MAX_DIMS 4 - -struct htp_tensor { - uint32_t data; // Buffer offset in the messages, and data pointer on the NSP - uint32_t type; // Data type - uint32_t ne[HTP_MAX_DIMS]; // Number of elements - uint32_t nb[HTP_MAX_DIMS]; // Stride in bytes (see ggml.h ggml_tensor) -}; - -#define HTP_MAX_OP_PARAMS 64 - -struct htp_general_req { - uint32_t op; // GGML/HTP Op - int32_t op_params[HTP_MAX_OP_PARAMS / sizeof(int32_t)]; - // Params for the op, e.g. epsilon of RMS norm - uint32_t flags; // Request flags - - struct htp_tensor src0; // Input0 tensor - struct htp_tensor src1; // Input1 tensor - struct htp_tensor src2; // Input2 tensor - struct htp_tensor src3; // Input3 tensor - struct htp_tensor src4; // Input4 tensor - struct htp_tensor dst; // Output tensor - - // should be multiple of 64 bytes (cacheline) -}; - -struct htp_general_rsp { - uint32_t op; // GGML/HTP Op - uint32_t status; // HTP_STATUS_... - uint32_t prof_usecs; // Number of usec per request - uint32_t prof_cycles; // Number of cycles per request - uint32_t prof_pkts; // Number of instruction packets per request - uint8_t unused[44]; // Pad to 64 bytes -}; - -#define HTP_MAX_MESSAGE_SIZE sizeof(struct htp_general_req) -#define HTP_MAX_PACKET_BUFFERS 8 - -#endif /* HTP_MSG_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 2ef20936f1b..66a3150c1a0 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -1,63 +1,177 @@ #ifndef HTP_OPS_H #define HTP_OPS_H -#include "htp-ctx.h" -#include "htp-msg.h" -#include "worker-pool.h" - #include -#include - -#include - -// ggml-common.h must be included prior to this header - -struct htp_spad { - uint8_t * data; - size_t stride; - size_t size; - size_t size_per_thread; -}; - -struct htp_ops_context { - struct htp_context * ctx; - - enum htp_op op; - int32_t op_params[HTP_MAX_OP_PARAMS / sizeof(int32_t)]; - - struct htp_tensor src0; - struct htp_tensor src1; - struct htp_tensor src2; - struct htp_tensor src3; - struct htp_tensor src4; - struct htp_tensor dst; - - struct htp_spad src0_spad; - struct htp_spad src1_spad; - struct htp_spad src2_spad; - struct htp_spad src3_spad; - struct htp_spad dst_spad; - - worker_pool_context_t * wpool; // worker pool - uint32_t n_threads; // num threads - - uint32_t flags; -}; - -int op_matmul(struct htp_ops_context * octx); -int op_matmul_id(struct htp_ops_context * octx); -int op_binary(struct htp_ops_context * octx); -int op_unary(struct htp_ops_context * octx); -int op_sum_rows(struct htp_ops_context * octx); -int op_activations(struct htp_ops_context * octx); -int op_softmax(struct htp_ops_context * octx); -int op_add_id(struct htp_ops_context * octx); -int op_rope(struct htp_ops_context * octx); -int op_flash_attn_ext(struct htp_ops_context * octx); -int op_set_rows(struct htp_ops_context * octx); -int op_get_rows(struct htp_ops_context * octx); -int op_cpy(struct htp_ops_context * octx); -int op_argsort(struct htp_ops_context * octx); -int op_ssm_conv(struct htp_ops_context * octx); + +// ggml-common.h must be included prio to this header + +enum htp_status { + HTP_STATUS_OK = 1, + HTP_STATUS_INTERNAL_ERR = 2, + HTP_STATUS_NO_SUPPORT = 3, + HTP_STATUS_INVAL_PARAMS = 4, + HTP_STATUS_VTCM_TOO_SMALL = 5, +}; + +// First set of values must match the ggml_type. +// Duplicated here because we can't include full ggml.h in the htp build. +// We have some static_asserts in the cpp code to ensure things are in sync. +enum htp_data_type { + HTP_TYPE_F32 = 0, + HTP_TYPE_F16 = 1, + HTP_TYPE_Q4_0 = 2, + HTP_TYPE_Q8_0 = 8, + HTP_TYPE_IQ4_NL = 20, + HTP_TYPE_I32 = 26, + HTP_TYPE_I64 = 27, + HTP_TYPE_MXFP4 = 39, + + // types used internally for repack, dyn.quant, etc + HTP_TYPE_Q4_0x4x2 = 200, + HTP_TYPE_Q8_0x4x2, + HTP_TYPE_MXFP4x4x2, + + HTP_TYPE_INVALID +}; + +// Constats for internal types +#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128) +#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks +#define QK_MXFP4x4x2 256 // 4x MXFP4 blocks concat with next 4x MXFP4 blocks + + +// Mask to enable various stages of the Ops. +// Used for debugging and profiling. +enum htp_op_stage { + HTP_OPSTAGE_QUEUE = (1 << 0), // Enable Queueing (ie calls into NPU) + HTP_OPSTAGE_COMPUTE = (1 << 1), // Enable Compute +}; + +// Do not reorder first 4 (used as an index) +enum htp_op_code { + HTP_OP_MUL = 0, + HTP_OP_ADD = 1, + HTP_OP_SUB = 2, + HTP_OP_DIV = 3, + HTP_OP_MUL_MAT, + HTP_OP_MUL_MAT_ID, + HTP_OP_RMS_NORM, + HTP_OP_UNARY_SILU, + HTP_OP_UNARY_GELU, + HTP_OP_UNARY_SIGMOID, + HTP_OP_UNARY_EXP, + HTP_OP_UNARY_NEG, + HTP_OP_UNARY_SOFTPLUS, + HTP_OP_GLU_SWIGLU, + HTP_OP_GLU_SWIGLU_OAI, + HTP_OP_GLU_GEGLU, + HTP_OP_SOFTMAX, + HTP_OP_ADD_ID, + HTP_OP_ROPE, + HTP_OP_FLASH_ATTN_EXT, + HTP_OP_SET_ROWS, + HTP_OP_GET_ROWS, + HTP_OP_SCALE, + HTP_OP_CPY, + HTP_OP_ARGSORT, + HTP_OP_SQR, + HTP_OP_SQRT, + HTP_OP_SUM_ROWS, + HTP_OP_SSM_CONV, + HTP_OP_REPEAT, + HTP_OP_CUMSUM, + HTP_OP_FILL, + HTP_OP_DIAG, + HTP_OP_SOLVE_TRI, + HTP_OP_INVALID +}; + +#define HTP_OP_MAX_DIMS 4 // aka GGML_MAX_DIMS +#define HTP_OP_MAX_INPUTS 6 // aka GGML_MAX_SRCS +#define HTP_OP_MAX_PARAMS 16 // aka GGML_MAX_OP_PARAMS + +#define HTP_OP_MAX_BUFS 16 +#define HTP_OP_MAX_REQS 256 +#define HTP_OP_MAX_TENSORS (HTP_OP_MAX_REQS * HTP_OP_MAX_INPUTS + HTP_OP_MAX_REQS) + +#define HTP_OP_MAX_VMEM_DEFAULT (3355443200u) + +#define HTP_MMAP_MAX_VMEM (2147483648u) + +enum htp_tensor_flags { + HTP_TENSOR_COMPUTE = (1U << 0), // Tensor buffer temporal compute data (not weights) + HTP_TENSOR_FLUSHED = (1U << 1) // Tensor buffer has been flushed (set by the NPU) +}; + +// Tensor descriptor +struct htp_tensor { + uint32_t data; // Buffer offset in the messages, and data pointer on the NPU + uint32_t size; // Data size in bytes + uint32_t flags; // Buffer / tensor flags + uint16_t type; // Data type + uint16_t bi; // Buffer index + uint32_t ne[HTP_OP_MAX_DIMS]; // Number of elements + uint32_t nb[HTP_OP_MAX_DIMS]; // Stride in bytes (see ggml.h ggml_tensor) +}; + +// Buffer descriptor +struct htp_buf_desc { + uint64_t base; // base address + uint64_t size; // total size + uint32_t flags; // buffer flags (unused) + uint32_t fd; // file descriptor +}; + +enum htp_op_flags { + HTP_OPFLAGS_SKIP_COMPUTE = (1U << 0), // Skip actual computation (used for profiling) +}; + +// Op descriptor +struct htp_op_desc { + uint32_t opcode; // GGML/HTP Op + uint32_t flags; // Op flags + int32_t params[HTP_OP_MAX_PARAMS]; // Params for the op, e.g. epsilon of RMS norm + uint16_t src[HTP_OP_MAX_INPUTS]; // Input tensors indices + uint16_t dst; // Output tensor index +}; + +enum htp_profiler_mode { + HTP_PROF_DISABLED = 0, + HTP_PROF_BASIC = 1, + HTP_PROF_PMU = 2, +}; + +#define HTP_PROF_PMU_NCNT 8 + +// Profile descriptor +struct htp_prof_desc { + uint32_t opcode; // GGML/HTP Op + uint32_t usecs; // Number of usec + uint32_t cycles; // Number of cycles + uint32_t pad; // Unused + uint32_t pmu[HTP_PROF_PMU_NCNT]; // PMU counters +}; + +struct htp_opbatch_req { + uint32_t id; // Batch id + uint32_t n_bufs; // Number of buffers + uint32_t n_tensors; // Number of tensors + uint32_t n_ops; // Number of ops + uint32_t flags; // unused + uint32_t pad; // unused + // struct htp_buf_desc bufs[]; -- dspqueue buf 0 + // struct htp_tensor tensors[]; -- dspqueue buf 0 + // struct htp_op_desc ops[]; -- dspqueue buf 0 +}; + +struct htp_opbatch_rsp { + uint32_t id; // Batch id + uint32_t status; // HTP_STATUS_... + uint32_t n_bufs; // Number of buffers + uint32_t n_tensors; // Number of tensors + uint32_t n_ops; // Number of op profile descriptors + uint32_t pad; // unused + // struct htp_prof_desc profs[]; -- dspqueue buf 0 +}; #endif /* HTP_OPS_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp_iface.idl b/ggml/src/ggml-hexagon/htp/htp_iface.idl index 9ebd937e46d..d696a5fba0c 100644 --- a/ggml/src/ggml-hexagon/htp/htp_iface.idl +++ b/ggml/src/ggml-hexagon/htp/htp_iface.idl @@ -6,11 +6,17 @@ #include "AEEStdDef.idl" #include "remote.idl" +struct htp_iface_pmu_conf { + uint32 events[8]; +}; + interface htp_iface : remote_handle64 { - AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx); + AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx, in uint64 max_vmem); AEEResult stop(); - AEEResult enable_etm(); - AEEResult disable_etm(); + AEEResult mmap(in uint32 fd, in uint32 size); + AEEResult munmap(in uint32 fd); + AEEResult profiler(in uint32 mode, in htp_iface_pmu_conf pmu); + AEEResult etm(in uint32 enable); }; #endif /* HTP_IDL */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h index 578ca288fb6..f6cb02951d0 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-base.h +++ b/ggml/src/ggml-hexagon/htp/hvx-base.h @@ -3,10 +3,15 @@ #include #include +#include +#include #include "hex-utils.h" #include "hvx-types.h" +#define hvx_vmem(A) *((HVX_Vector *)(A)) +#define hvx_vmemu(A) *((HVX_UVector *)(A)) + static inline void hvx_vec_store_u(void * restrict dst, uint32_t n, HVX_Vector v) { // Rotate as needed. v = Q6_V_vlalign_VVR(v, v, (size_t) dst); @@ -72,6 +77,12 @@ static inline int32_t hvx_vec_get_i32(HVX_Vector v) { return x; } +static inline _Float16 hvx_vec_get_f16(HVX_Vector v) { + _Float16 __attribute__((aligned(128))) x; + hvx_vec_store_a(&x, 2, v); + return x; +} + static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) { // abs by clearing the fp16 sign bit HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff); @@ -110,11 +121,20 @@ static inline HVX_VectorPred hvx_vec_is_nan_f16(HVX_Vector v) { return Q6_Q_and_QQ(p_exp, p_frac); } -static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) { - const HVX_Vector zero = Q6_V_vsplat_R(0); +static inline HVX_Vector hvx_vec_f32_to_f16_shuff(HVX_Vector v0, HVX_Vector v1) { +#if __HVX_ARCH__ >= 81 + HVX_Vector q0 = Q6_Vqf32_equals_Vsf(v0); + HVX_Vector q1 = Q6_Vqf32_equals_Vsf(v1); +#else + const HVX_Vector zero = Q6_V_vzero(); HVX_Vector q0 = Q6_Vqf32_vadd_VsfVsf(v0, zero); HVX_Vector q1 = Q6_Vqf32_vadd_VsfVsf(v1, zero); - HVX_Vector v = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0))); +#endif + return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0)); +} + +static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) { + HVX_Vector v = Q6_Vh_vdeal_Vh(hvx_vec_f32_to_f16_shuff(v0, v1)); #if __HVX_ARCH__ < 79 // replace NaNs with -INF, older arches produce NaNs for (-INF + 0.0) @@ -126,6 +146,30 @@ static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) { return v; } +#if __HVX_ARCH__ >= 79 +static inline HVX_VectorPair hvx_vec_f16_to_f32_shuff(HVX_Vector v) { + const HVX_Vector one = hvx_vec_splat_f16(1.0); + HVX_VectorPair p = Q6_Wsf_vmpy_VhfVhf(v, one); + return Q6_W_vcombine_VV(Q6_V_hi_W(p), Q6_V_lo_W(p)); +} +static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) { + const HVX_Vector one = hvx_vec_splat_f16(1.0); + HVX_VectorPair p = Q6_Wsf_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(v), one); + return Q6_W_vcombine_VV(Q6_V_hi_W(p), Q6_V_lo_W(p)); +} +#else +static inline HVX_VectorPair hvx_vec_f16_to_f32_shuff(HVX_Vector v) { + const HVX_Vector one = hvx_vec_splat_f16(1.0); + HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(v, one); + return Q6_W_vcombine_VV(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p)), Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p))); +} +static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) { + const HVX_Vector one = hvx_vec_splat_f16(1.0); + HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(v), one); + return Q6_W_vcombine_VV(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p)), Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p))); +} +#endif + /* Q6_Vsf_equals_Vw is only available on v73+.*/ #if __HVX_ARCH__ < 73 static inline HVX_Vector hvx_vec_i32_to_qf32(HVX_Vector const in) @@ -218,6 +262,18 @@ static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b) return Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b)); } +static inline HVX_Vector hvx_vec_add_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)); +} + +static inline HVX_Vector hvx_vec_sub_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b)); +} + +static inline HVX_Vector hvx_vec_mul_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)); +} + #else static inline HVX_Vector hvx_vec_add_f16_f16(HVX_Vector a, HVX_Vector b) @@ -235,6 +291,18 @@ static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b) return Q6_Vhf_vmpy_VhfVhf(a, b); } +static inline HVX_Vector hvx_vec_add_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_vadd_VsfVsf(a, b); +} + +static inline HVX_Vector hvx_vec_sub_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_vsub_VsfVsf(a, b); +} + +static inline HVX_Vector hvx_vec_mul_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_vmpy_VsfVsf(a, b); +} + #endif // __HVX_ARCH__ < 79 #endif /* HVX_BASE_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-copy.h b/ggml/src/ggml-hexagon/htp/hvx-copy.h index 851482e01b2..a3e33c3b3af 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-copy.h +++ b/ggml/src/ggml-hexagon/htp/hvx-copy.h @@ -7,7 +7,8 @@ #include "hvx-base.h" -#define hvx_splat_loop_body(dst_type, vec_store) \ +#define hvx_splat_pragma(x) _Pragma(#x) +#define hvx_splat_loop_body(dst_type, vec_store, unroll_cnt) \ do { \ dst_type * restrict vdst = (dst_type *) dst; \ \ @@ -16,7 +17,7 @@ \ uint32_t i = 0; \ \ - _Pragma("unroll(4)") \ + hvx_splat_pragma(unroll(unroll_cnt)) \ for (; i < nvec; i++) { \ vdst[i] = src; \ } \ @@ -25,31 +26,47 @@ } \ } while(0) -static inline void hvx_splat_a(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) { +static inline void hvx_splat_a(void * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) { assert((unsigned long) dst % 128 == 0); - hvx_splat_loop_body(HVX_Vector, hvx_vec_store_a); + hvx_splat_loop_body(HVX_Vector, hvx_vec_store_a, 4); } -static inline void hvx_splat_u(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) { - hvx_splat_loop_body(HVX_UVector, hvx_vec_store_u); +static inline void hvx_splat_u(void * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) { + hvx_splat_loop_body(HVX_UVector, hvx_vec_store_u, 4); } -static inline void hvx_splat_f32_a(uint8_t * restrict dst, float v, uint32_t n) { +static inline void hvx_splat_f32_a(void * restrict dst, float v, uint32_t n) { hvx_splat_a(dst, hvx_vec_splat_f32(v), n, sizeof(float)); } -static inline void hvx_splat_f32_u(uint8_t * restrict dst, float v, uint32_t n) { +static inline void hvx_splat_f32_u(void * restrict dst, float v, uint32_t n) { hvx_splat_u(dst, hvx_vec_splat_f32(v), n, sizeof(float)); } -static inline void hvx_splat_f16_a(uint8_t * restrict dst, _Float16 v, uint32_t n) { +static inline void hvx_splat_f16_a(void * restrict dst, _Float16 v, uint32_t n) { hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16)); } -static inline void hvx_splat_f16_u(uint8_t * restrict dst, _Float16 v, uint32_t n) { +static inline void hvx_splat_f16_u(void * restrict dst, _Float16 v, uint32_t n) { hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16)); } +static inline void hvx_splat_u16_a(void * restrict dst, uint16_t v, uint32_t n) { + hvx_splat_a(dst, Q6_Vh_vsplat_R(v), n, sizeof(uint16_t)); +} + +static inline void hvx_splat_u16_u(void * restrict dst, uint16_t v, uint32_t n) { + hvx_splat_u(dst, Q6_Vh_vsplat_R(v), n, sizeof(uint16_t)); +} + +static inline void hvx_splat_u8_a(void * restrict dst, uint8_t v, uint32_t n) { + hvx_splat_a(dst, Q6_Vb_vsplat_R(v), n, 1); +} + +static inline void hvx_splat_u8_u(void * restrict dst, uint8_t v, uint32_t n) { + hvx_splat_u(dst, Q6_Vb_vsplat_R(v), n, 1); +} + #define hvx_copy_loop_body(dst_type, src_type, vec_store) \ do { \ dst_type * restrict vdst = (dst_type *) dst; \ diff --git a/ggml/src/ggml-hexagon/htp/hvx-div.h b/ggml/src/ggml-hexagon/htp/hvx-div.h index 05cefea039f..53ee304e749 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-div.h +++ b/ggml/src/ggml-hexagon/htp/hvx-div.h @@ -16,8 +16,10 @@ #if __HVX_ARCH__ < 79 #define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) +#define HVX_OP_MUL_F16(a, b) Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b)) #else #define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) +#define HVX_OP_MUL_F16(a, b) Q6_Vhf_vmpy_VhfVhf(a, b) #endif // Compute div by scaler in f32. Requires first by expanding fp32 to fp16 and converting the result back to fp32. @@ -43,46 +45,67 @@ static inline HVX_Vector hvx_div_mul_f16_const_using_f32(HVX_Vector vec1_hf, HVX return res; } -#define hvx_div_scaler_f16_loop_body(dst_type, src_type, vec_store) \ - do { \ - dst_type * restrict vdst = (dst_type *) dst; \ - src_type * restrict vsrc = (src_type *) src; \ - HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \ - \ - const uint32_t nvec = n / VLEN_FP16; \ - const uint32_t nloe = n % VLEN_FP16; \ - \ - uint32_t i = 0; \ - \ - _Pragma("unroll(4)") \ - for (; i < nvec; i++) { \ - HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \ - vdst[i] = res; \ - } \ - if (nloe) { \ - HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \ - vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \ - } \ +// Variant for =v79 +static inline HVX_Vector hvx_vec_hybrid_div_f16(HVX_Vector vec1, HVX_Vector vec2, HVX_Vector f32_nan_inf_mask, HVX_Vector f16_nan_inf_mask, HVX_Vector vec_hf_one_1_0) { +#if __HVX_ARCH__ < 79 + // For older architectures, use f16 reciprocal to avoid NaN/-inf issues + HVX_Vector vec2_inv = hvx_vec_inverse_f16_guard(vec2, f16_nan_inf_mask); + return HVX_OP_MUL_F16(vec1, vec2_inv); +#else + return hvx_vec_div_f16_using_f32(vec1, vec2, f32_nan_inf_mask, vec_hf_one_1_0); +#endif +} + #define hvx_div_f16_loop_body(dst_type, src0_type, src1_type, vec_store) \ do { \ dst_type * restrict vdst = (dst_type *) dst; \ src0_type * restrict vsrc0 = (src0_type *) src0; \ src1_type * restrict vsrc1 = (src1_type *) src1; \ \ - const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \ + const HVX_Vector f32_nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \ + const HVX_Vector f16_nan_inf_mask = Q6_Vh_vsplat_R(0x7c00); \ const HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \ \ const uint32_t nvec = n / VLEN_FP16; \ @@ -144,11 +179,15 @@ static inline HVX_Vector hvx_vec_div_f16_using_f32(HVX_Vector vec1, HVX_Vector v \ _Pragma("unroll(4)") \ for (; i < nvec; i++) { \ - HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \ + HVX_Vector res = hvx_vec_hybrid_div_f16(vsrc0[i], vsrc1[i], \ + f32_nan_inf_mask, f16_nan_inf_mask, \ + hf_one); \ vdst[i] = res; \ } \ if (nloe) { \ - HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \ + HVX_Vector res = hvx_vec_hybrid_div_f16(vsrc0[i], vsrc1[i], \ + f32_nan_inf_mask, f16_nan_inf_mask, \ + hf_one); \ vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \ } \ } while(0) @@ -247,5 +286,6 @@ HVX_DIV_DISPATCHER(hvx_div_f32) HVX_DIV_DISPATCHER(hvx_div_f16) #undef HVX_OP_MUL_F32 +#undef HVX_OP_MUL_F16 #endif // HVX_DIV_H diff --git a/ggml/src/ggml-hexagon/htp/hvx-exp.h b/ggml/src/ggml-hexagon/htp/hvx-exp.h index 44dfe232a3d..e71ec4909a6 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-exp.h +++ b/ggml/src/ggml-hexagon/htp/hvx-exp.h @@ -3,6 +3,7 @@ #include #include +#include #include "hvx-base.h" #include "hvx-floor.h" @@ -16,8 +17,8 @@ #define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805 #define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408 #define EXP_ONE (0x3f800000) // 1.0 -#define EXP_RANGE_R (0x41a00000) // 20.0 -#define EXP_RANGE_L (0xc1a00000) // -20.0 +#define EXP_RANGE_R (0x42B17218) // ln(FLT_MAX) approx = 88.7228 +#define EXP_RANGE_L (0xC2B00000) // -88.0 (approx log(FLT_MIN)) static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) { HVX_Vector z_qf32_v; @@ -47,12 +48,12 @@ static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) { HVX_Vector temp_v = in_vec; - // Clamp inputs to (-20.0, 20.0) + // Clamp inputs to (-88.0, 88.0) to avoid overflow/underflow HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, Q6_V_vsplat_R(EXP_RANGE_R)); HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(Q6_V_vsplat_R(EXP_RANGE_L), in_vec); in_vec = Q6_V_vmux_QVV(pred_cap_right, Q6_V_vsplat_R(EXP_RANGE_R), temp_v); - in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), temp_v); + in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), in_vec); epsilon_v = Q6_Vqf32_vmpy_VsfVsf(log2e, in_vec); epsilon_v = Q6_Vsf_equals_Vqf32(epsilon_v); @@ -69,12 +70,12 @@ static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) { // normalize before every QFloat's vmpy x_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(x_qf32_v, zero_v); + x_v = Q6_Vsf_equals_Vqf32(x_qf32_v); + // z = x * x; z_qf32_v = Q6_Vqf32_vmpy_Vqf32Vqf32(x_qf32_v, x_qf32_v); z_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(z_qf32_v, zero_v); - x_v = Q6_Vsf_equals_Vqf32(x_qf32_v); - // y = E4 + E5 * x; E_const = Q6_V_vsplat_R(EXP_COEFF_5); y_v = Q6_Vqf32_vmpy_VsfVsf(E_const, x_v); @@ -145,7 +146,7 @@ static inline HVX_Vector hvx_vec_exp_f32_guard(HVX_Vector in_vec, HVX_Vector max return Q6_V_vmux_QVV(pred0, inf, out); } -static inline void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) { +static inline void hvx_exp_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int num_elems, bool negate) { int left_over = num_elems & (VLEN_FP32 - 1); int num_elems_whole = num_elems - left_over; @@ -162,7 +163,7 @@ static inline void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict HVX_Vector vec_out = Q6_V_vzero(); static const float kInf = INFINITY; - static const float kMaxExp = 88.02f; // log(INF) + static const float kMaxExp = 88.7228f; const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp); const HVX_Vector inf = hvx_vec_splat_f32(kInf); diff --git a/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h b/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h index 095193277ea..37f3e7b6fae 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +++ b/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h @@ -2,6 +2,7 @@ #define HVX_SIGMOID_H #include "hvx-base.h" +#include "hvx-inverse.h" #define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022 #define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777 diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h index 08343798794..a518ad37331 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -15,12 +15,4 @@ #include "hvx-div.h" #include "hvx-base.h" -#ifndef GATHER_TYPE -# if defined(__hexagon__) -# define GATHER_TYPE(_a) (intptr_t) _a -# else -# define GATHER_TYPE(_a) (HVX_Vector *) _a -# endif -#endif - #endif /* HVX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 3f99dbb32c4..49c1a15b344 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -1,5 +1,7 @@ #pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" #pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" #include #include @@ -12,17 +14,20 @@ #include #include #include +#include #include #include -#include "hex-dma.h" #include "hex-utils.h" +#include "hex-dma.h" +#include "hmx-queue.h" #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" #include "htp-ops.h" +#include "htp-ops.h" +#include "htp_iface.h" #include "worker-pool.h" AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { @@ -34,7 +39,7 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { return AEE_ENOMEMORY; } - // Use the context structure as a handle + // Use the context structure as the handle *handle = (remote_handle64) ctx; // Enable FARF logs @@ -96,6 +101,74 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { } } +#if __HVX_ARCH__ >= 75 + { + // Set HMX clock + HAP_power_request_t request; + memset(&request, 0, sizeof(HAP_power_request_t)); + request.type = HAP_power_set_HMX_v2; + request.hmx_v2.set_clock = TRUE; + request.hmx_v2.target_corner = HAP_DCVS_EXP_VCORNER_MAX; + request.hmx_v2.min_corner = HAP_DCVS_EXP_VCORNER_MAX; + request.hmx_v2.max_corner = HAP_DCVS_EXP_VCORNER_MAX; + request.hmx_v2.perf_mode = HAP_CLK_PERF_HIGH; + FARF(ALWAYS, "Setting HMX clock\n"); + err = HAP_power_set((void *) &ctx, &request); + if (err != AEE_SUCCESS) { + FARF(ERROR, "Error setting HMX clock."); + return err; + } + } +#endif + + return AEE_SUCCESS; +} + +AEEResult htp_iface_etm(remote_handle64 handle, uint32_t enable) { + int err = enable ? HAP_user_etm_enable() : HAP_user_etm_disable(); + if (err) { + if (err == AEE_EVERSIONNOTSUPPORT) { + FARF(ERROR, "API HAP_user_etm_enable/disable is not supported\n"); + } else { + FARF(ERROR, "Error executing HAP_user_etm_enable/disable with error code : 0x%x\n", err); + } + } + return err; +} + +AEEResult htp_iface_profiler(remote_handle64 handle, uint32_t mode, const htp_iface_pmu_conf* pmu_conf) { + struct htp_context * ctx = (struct htp_context *) handle; + if (!ctx) { + return AEE_EBADPARM; + } + + if (mode == HTP_PROF_PMU) { + const uint32_t* events = pmu_conf->events; + + // Pack 4 event IDs (low 8 bits) into each 32-bit config register + uint32_t evtcfg = 0, evtcfg1 = 0, cfg = 0, i = 0; + for (; i < HEX_NUM_PMU_COUNTERS/2; i++) { + evtcfg |= ((events[i + 0] & 0xFF) << (i * 8)); + evtcfg1 |= ((events[i + 4] & 0xFF) << (i * 8)); + } + + // For events >255 pack high 2 bits of all 8 event IDs into cfg register + // 2 bits per counter: bits [1:0] for counter 0, [3:2] for counter 1, etc. + for (i = 0; i < HEX_NUM_PMU_COUNTERS; i++) { + cfg |= (((events[i] >> 8) & 3) << (i * 2)); + } + + FARF(ALWAYS, "Configuring PMU registers: evtcfg = 0x%x, evtcfg1 = 0x%x, pmucfg = 0x%x", evtcfg, evtcfg1, cfg); + + // Configure PMU registers + qurt_pmu_set(QURT_PMUCFG, cfg); + qurt_pmu_set(QURT_PMUEVTCFG, evtcfg); + qurt_pmu_set(QURT_PMUEVTCFG1, evtcfg1); + qurt_pmu_enable(1); + } + + ctx->profiler = mode; + return AEE_SUCCESS; } @@ -111,91 +184,128 @@ AEEResult htp_iface_close(remote_handle64 handle) { return AEE_EITEMBUSY; } + // release the mmaps (if any) + for (uint32_t i=0; immap[i].size) { +#if __HVX_ARCH__ > 73 + HAP_munmap2((void *) ctx->mmap[i].base, ctx->mmap[i].size); +#else + HAP_munmap((void *) ctx->mmap[i].base, ctx->mmap[i].size); +#endif + ctx->mmap[i].size = 0; + ctx->mmap[i].base = NULL; + ctx->mmap[i].fd = -1; + } + } + + if (ctx->profiler) { + qurt_pmu_enable(1); + } + + if (ctx->etm) { + HAP_user_etm_disable(); + } + free(ctx); return AEE_SUCCESS; } -AEEResult htp_iface_enable_etm(remote_handle64 handle) { - int err = HAP_user_etm_enable(); - if (err) { - if (err == AEE_EVERSIONNOTSUPPORT) { - FARF(ERROR, "API HAP_user_etm_enable is not supported\n"); - } else { - FARF(ERROR, "Error executing HAP_user_etm_enable with error code : 0x%x\n", err); +AEEResult htp_iface_mmap(remote_handle64 handle, uint32_t fd, uint32_t size) { + struct htp_context * ctx = (struct htp_context *) handle; + if (!ctx) { + return AEE_EBADPARM; + } + + // See if we already have this mapping + for (uint32_t i=0; immap[i]; + if (m->fd == fd) { + return AEE_SUCCESS; } } - return err; + + // Add new mapping + for (uint32_t i=0; immap[i]; + if (!m->size) { + FARF(HIGH, "mmap : fd %u size %u", fd, size); +#if __HVX_ARCH__ > 73 + void *va = HAP_mmap2(NULL, size, HAP_PROT_READ | HAP_PROT_WRITE, 0, fd, 0); +#else + if (size > HTP_MMAP_MAX_VMEM) { // HAP_mmap has a size limit of 2GB + FARF(ERROR, "mmap failed : size %u exceeds 2GB limit for HAP_mmap", (uint32_t) size); + abort(); // can't do much else at this point + } + + void *va = HAP_mmap(NULL, size, HAP_PROT_READ | HAP_PROT_WRITE, 0, fd, 0); +#endif + if (va == (void*)-1) { + FARF(ERROR, "mmap failed : va %p fd %u size %u", va, fd, (uint32_t) size); + return AEE_EFAILED; + } + + m->base = (uint64_t) va; + m->fd = fd; + m->size = size; + + return AEE_SUCCESS; + } + } + + return AEE_ENOMEMORY; } -AEEResult htp_iface_disable_etm(remote_handle64 handle) { - int err = HAP_user_etm_disable(); - if (err) { - if (err == AEE_EVERSIONNOTSUPPORT) { - FARF(ERROR, "API HAP_user_etm_disable is not supported\n"); - } else { - FARF(ERROR, "Error executing HAP_user_etm_disable with error code : 0x%x\n", err); +AEEResult htp_iface_munmap(remote_handle64 handle, uint32 fd) { + struct htp_context * ctx = (struct htp_context *) handle; + if (!ctx) { + return AEE_EBADPARM; + } + + for (uint32_t i=0; immap[i]; + if (fd < 0 || m->fd == fd) { + FARF(HIGH, "unmmap : base %p fd %u size %u", (void*) m->base, m->fd, (uint32_t) m->size); +#if __HVX_ARCH__ > 73 + HAP_munmap2((void *) m->base, m->size); +#else + HAP_munmap((void *) m->base, m->size); +#endif + m->size = 0; + m->base = NULL; + m->fd = -1; } } - return err; + + return AEE_SUCCESS; } -static int vtcm_acquire(struct htp_context * ctx) { - int err; +static void vtcm_acquire(struct htp_context * ctx) { if (!ctx->vtcm_valid) { - // Temporarily bump thread priority to make sure it's higher than other sessions. - // This way the resource manager will notify the other thread to release VTCM. - // Note that we need to reaquire VTCM at normal priority for this to work next time. - qurt_thread_set_priority(qurt_thread_get_id(), ctx->thread_prio - 10); - err = HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000); + int err = HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000u); if (err != 0) { - FARF(ERROR, "Failed to acquire VTCM: 0x%08x", (unsigned)err); + FARF(ERROR, "ggml-hex: failed to acquire VTCM: 0x%08x", (unsigned)err); abort(); } - HAP_compute_res_release_cached(ctx->vtcm_rctx); - qurt_thread_set_priority(qurt_thread_get_id(), ctx->thread_prio); - err = HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000); - if (err != 0) { - FARF(ERROR, "Failed to acquire VTCM: 0x%08x", (unsigned)err); - abort(); - } + ctx->vtcm_needs_release = false; ctx->vtcm_valid = true; - } - ctx->vtcm_inuse = true; - return 0; + // Drop the priority to make sure we get the release callback from other GGML-HTP and QNN-HTP sessions + HAP_compute_res_update_priority(ctx->vtcm_rctx, ctx->thread_prio + 10); + } } -static int vtcm_release(struct htp_context * ctx) { - ctx->vtcm_inuse = false; - - if (ctx->vtcm_valid && ctx->vtcm_needs_release) { +static void vtcm_release(struct htp_context * ctx) { + if (ctx->vtcm_valid) { ctx->vtcm_valid = false; ctx->vtcm_needs_release = false; HAP_compute_res_release_cached(ctx->vtcm_rctx); } - - return 0; } static int vtcm_release_callback(unsigned int rctx, void * state) { struct htp_context * ctx = (struct htp_context *) state; - - if (!ctx || ctx->vtcm_rctx != rctx) { - return AEE_EBADPARM; - } - - // If VTCM is not inuse (not processing Ops) release it right here - // otherwise we'll release it once we're done with the current Op. - - if (ctx->vtcm_inuse) { - ctx->vtcm_needs_release = true; - return 0; - } - - ctx->vtcm_valid = false; - HAP_compute_res_release_cached(ctx->vtcm_rctx); - + ctx->vtcm_needs_release = true; return 0; } @@ -207,7 +317,7 @@ static int vtcm_alloc(struct htp_context * ctx) { HAP_compute_res_attr_init(&attr); HAP_compute_res_attr_set_serialize(&attr, 0); HAP_compute_res_attr_set_cache_mode(&attr, 1); - HAP_compute_res_attr_set_vtcm_param_v2(&attr, vtcm_size, 0, vtcm_size); + HAP_compute_res_attr_set_vtcm_param_v2(&attr, vtcm_size, vtcm_size, vtcm_size); // single page HAP_compute_res_attr_set_release_callback(&attr, vtcm_release_callback, (void *) ctx); HAP_compute_res_attr_set_hmx_param(&attr, 1); @@ -229,7 +339,6 @@ static int vtcm_alloc(struct htp_context * ctx) { ctx->vtcm_size = vtcm_size; ctx->vtcm_rctx = rctx; ctx->vtcm_valid = false; - ctx->vtcm_inuse = false; ctx->vtcm_needs_release = false; return 0; @@ -246,7 +355,7 @@ static void vtcm_free(struct htp_context * ctx) { static void htp_packet_callback(dspqueue_t queue, int error, void * context); static void htp_error_callback(dspqueue_t queue, int error, void * context); -AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx) { +AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx, uint32 use_hmx, uint64_t max_vmem) { struct htp_context * ctx = (struct htp_context *) handle; if (!ctx) { @@ -264,12 +373,12 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que htp_error_callback, // Error callback; no errors expected on the DSP (void *) ctx, // Callback context &ctx->queue); - if (err) { FARF(ERROR, "Queue import failed with 0x%08x", (unsigned) err); return err; } + ctx->max_vmem = max_vmem; ctx->thread_id = qurt_thread_get_id(); ctx->thread_prio = qurt_thread_get_priority(ctx->thread_id); @@ -280,6 +389,19 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que return AEE_ENOMEMORY; } +#ifdef HTP_HAS_HMX + ctx->hmx_enabled = use_hmx; + ctx->hmx_queue = NULL; + if (use_hmx) { + ctx->hmx_queue = hmx_queue_create(16, ctx->vtcm_rctx); + if (!ctx->hmx_queue) { + FARF(ERROR, "hmx-queue-create failed"); + ctx->hmx_enabled = false; + } + } + FARF(HIGH, "HMX %s (use_hmx=%d)", ctx->hmx_enabled ? "enabled" : "disabled", use_hmx); +#endif + qurt_sysenv_max_hthreads_t hw_threads; qurt_sysenv_get_max_hw_threads(&hw_threads); uint32_t hw_nhvx = (qurt_hvx_get_units() >> 8) & 0xFF; @@ -297,7 +419,7 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que ctx->n_threads = n_hvx; for (int i = 0; i < ctx->n_threads; i++) { // see discussion https://github.com/ggml-org/llama.cpp/pull/18151#discussion_r2632388541 - ctx->dma[i] = dma_queue_create(64); + ctx->dma[i] = dma_queue_create(128); } // init worker pool @@ -341,6 +463,14 @@ AEEResult htp_iface_stop(remote_handle64 handle) { dma_queue_delete(ctx->dma[i]); } +#ifdef HTP_HAS_HMX + if (ctx->hmx_queue) { + hmx_queue_delete(ctx->hmx_queue); + ctx->hmx_queue = NULL; + } + ctx->hmx_enabled = false; +#endif + vtcm_free(ctx); return AEE_SUCCESS; @@ -354,846 +484,392 @@ static void htp_error_callback(dspqueue_t queue, int error, void * context) { struct profile_data { uint64_t usecs; uint64_t cycles; - uint64_t pkts; + uint32_t pmu_counters[HEX_NUM_PMU_COUNTERS]; }; -static inline void profile_start(struct profile_data * d) { - d->usecs = HAP_perf_get_qtimer_count(); - d->cycles = hex_get_cycles(); - d->pkts = hex_get_pktcnt(); +static inline void profile_start(uint32_t mode, struct profile_data * d) { + switch (mode) { + case HTP_PROF_PMU: + hex_get_pmu(d->pmu_counters); + // fallthrough + case HTP_PROF_BASIC: + d->usecs = HAP_perf_get_qtimer_count(); + d->cycles = hex_get_cycles(); + break; + default: + break; + } } -static inline void profile_stop(struct profile_data * d) { - d->usecs = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - d->usecs); - d->cycles = hex_get_cycles() - d->cycles; - d->pkts = hex_get_pktcnt() - d->pkts; +static inline void profile_stop(uint32_t mode, struct profile_data * d) { + uint32_t pmu_counters[HEX_NUM_PMU_COUNTERS]; + switch (mode) { + case HTP_PROF_PMU: + hex_get_pmu(pmu_counters); + for (int i = 0; i < HEX_NUM_PMU_COUNTERS; i++) { + d->pmu_counters[i] = pmu_counters[i] - d->pmu_counters[i]; + } + // fallthrough + case HTP_PROF_BASIC: + d->usecs = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - d->usecs); + d->cycles = hex_get_cycles() - d->cycles; + break; + default: + break; + } } -static int send_htp_rsp(struct htp_context * c, - uint32_t op, - uint32_t status, - struct dspqueue_buffer * bufs, - size_t n_bufs, - struct profile_data * prof) { - // Prep response struct - struct htp_general_rsp rsp; - rsp.op = op; - rsp.status = status; - rsp.prof_usecs = prof->usecs; - rsp.prof_cycles = prof->cycles; - rsp.prof_pkts = prof->pkts; - - int err = dspqueue_write(c->queue, - 0, // Flags - n_bufs, - bufs, // Buffer references - sizeof(rsp), - (const uint8_t *) &rsp, // Message - DSPQUEUE_TIMEOUT_NONE); +static int execute_op(struct htp_ops_context * octx) { + switch (octx->op) { + case HTP_OP_MUL_MAT: + return op_matmul(octx); - if (err != 0) { - FARF(ERROR, "dspqueue_write failed: 0x%08x", (unsigned) err); - } + case HTP_OP_MUL_MAT_ID: + return op_matmul_id(octx); - return err; -} + case HTP_OP_MUL: + case HTP_OP_ADD: + case HTP_OP_SUB: + case HTP_OP_DIV: + case HTP_OP_ADD_ID: + return op_binary(octx); -static void proc_matmul_req(struct htp_context * ctx, - struct htp_general_req * req, - struct dspqueue_buffer * bufs, - size_t n_bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[2].fd; - rsp_bufs[0].ptr = bufs[2].ptr; - rsp_bufs[0].size = bufs[2].size; - rsp_bufs[0].offset = bufs[2].offset; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.dst.data = (uint32_t) bufs[2].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_matmul(&octx); - vtcm_release(ctx); - } + case HTP_OP_RMS_NORM: + case HTP_OP_SCALE: + case HTP_OP_SQR: + case HTP_OP_SQRT: + case HTP_OP_UNARY_SOFTPLUS: + case HTP_OP_UNARY_SIGMOID: + case HTP_OP_UNARY_NEG: + case HTP_OP_UNARY_EXP: + return op_unary(octx); - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} + case HTP_OP_UNARY_SILU: + case HTP_OP_UNARY_GELU: + case HTP_OP_GLU_SWIGLU: + case HTP_OP_GLU_SWIGLU_OAI: + case HTP_OP_GLU_GEGLU: + return op_activations(octx); -static void proc_argsort_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[1].fd; - rsp_bufs[0].ptr = bufs[1].ptr; - rsp_bufs[0].offset = bufs[1].offset; - rsp_bufs[0].size = bufs[1].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.dst.data = (uint32_t) bufs[1].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_argsort(&octx); - vtcm_release(ctx); - } + case HTP_OP_SOFTMAX: + return op_softmax(octx); - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} + case HTP_OP_ROPE: + return op_rope(octx); -static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[1].fd; - rsp_bufs[0].ptr = bufs[1].ptr; - rsp_bufs[0].offset = bufs[1].offset; - rsp_bufs[0].size = bufs[1].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.dst.data = (uint32_t) bufs[1].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_cpy(&octx); - vtcm_release(ctx); - } + case HTP_OP_FLASH_ATTN_EXT: + return op_flash_attn_ext(octx); - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} + case HTP_OP_SET_ROWS: + return op_set_rows(octx); -static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[2].fd; - rsp_bufs[0].ptr = bufs[2].ptr; - rsp_bufs[0].offset = bufs[2].offset; - rsp_bufs[0].size = bufs[2].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.dst.data = (uint32_t) bufs[2].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_get_rows(&octx); - vtcm_release(ctx); - } + case HTP_OP_GET_ROWS: + return op_get_rows(octx); - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} + case HTP_OP_SUM_ROWS: + return op_sum_rows(octx); -static void proc_matmul_id_req(struct htp_context * ctx, - struct htp_general_req * req, - struct dspqueue_buffer * bufs, - size_t n_bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[3].fd; - rsp_bufs[0].ptr = bufs[3].ptr; - rsp_bufs[0].size = bufs[3].size; - rsp_bufs[0].offset = bufs[3].offset; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.src2 = req->src2; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.src2.data = (uint32_t) bufs[2].ptr; - octx.dst.data = (uint32_t) bufs[3].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_matmul_id(&octx); - vtcm_release(ctx); - } + case HTP_OP_CPY: + return op_cpy(octx); - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} + case HTP_OP_REPEAT: + return op_repeat(octx); -static void proc_binary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[2].fd; - rsp_bufs[0].ptr = bufs[2].ptr; - rsp_bufs[0].offset = bufs[2].offset; - rsp_bufs[0].size = bufs[2].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.dst.data = (uint32_t) bufs[2].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_binary(&octx); - vtcm_release(ctx); - } + case HTP_OP_ARGSORT: + return op_argsort(octx); - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} + case HTP_OP_SSM_CONV: + return op_ssm_conv(octx); -static void proc_add_id_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[3].fd; - rsp_bufs[0].ptr = bufs[3].ptr; - rsp_bufs[0].offset = bufs[3].offset; - rsp_bufs[0].size = bufs[3].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.src2 = req->src2; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.src2.data = (uint32_t) bufs[2].ptr; - octx.dst.data = (uint32_t) bufs[3].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_binary(&octx); - vtcm_release(ctx); - } + case HTP_OP_CUMSUM: + return op_cumsum(octx); - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} + case HTP_OP_FILL: + return op_fill(octx); -static void proc_unary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[1].fd; - rsp_bufs[0].ptr = bufs[1].ptr; - rsp_bufs[0].offset = bufs[1].offset; - rsp_bufs[0].size = bufs[1].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.dst.data = (uint32_t) bufs[1].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_unary(&octx); - vtcm_release(ctx); - } + case HTP_OP_DIAG: + return op_diag(octx); - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} + case HTP_OP_SOLVE_TRI: + return op_solve_tri(octx); -static void proc_sum_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[1].fd; - rsp_bufs[0].ptr = bufs[1].ptr; - rsp_bufs[0].offset = bufs[1].offset; - rsp_bufs[0].size = bufs[1].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.dst.data = (uint32_t) bufs[1].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_sum_rows(&octx); - vtcm_release(ctx); + case HTP_OP_INVALID: + break; + + // No default to catch missing cases } - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); + FARF(ERROR, "Unknown Op %u", octx->op); + return -1; } -static void proc_ssm_conv_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; - - // We've written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[2].fd; - rsp_bufs[0].ptr = bufs[2].ptr; - rsp_bufs[0].offset = bufs[2].offset; - rsp_bufs[0].size = bufs[2].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup OP context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.dst.data = (uint32_t) bufs[2].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_ssm_conv(&octx); - vtcm_release(ctx); +static inline bool reuse_buf(struct htp_context *ctx, uint32_t *m_reuse, struct htp_buf_desc *b) { + b->base = NULL; + + for (uint32_t i=0; immap + i; + if (m->size && m->fd == b->fd) { + b->base = m->base; + *m_reuse |= (1 << i); + return true; + } } - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); + return false; } -static void proc_activations_req(struct htp_context * ctx, - struct htp_general_req * req, - struct dspqueue_buffer * bufs, - uint32_t n_bufs) { - struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; - - int write_idx = (n_bufs == 3) ? 2 : 1; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[write_idx].fd; - rsp_bufs[0].ptr = bufs[write_idx].ptr; - rsp_bufs[0].offset = bufs[write_idx].offset; - rsp_bufs[0].size = bufs[write_idx].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - if (3 == n_bufs) { - octx.src1 = req->src1; +static inline void drop_mmap(struct htp_context *ctx, struct htp_mmap *m) { + if (m->size) { + FARF(HIGH, "unmap : fd %u base %p size %u", m->fd, (void*) m->base, (uint32_t) m->size); +#if __HVX_ARCH__ > 73 + HAP_munmap2((void *) m->base, m->size); +#else + HAP_munmap((void *) m->base, m->size); +#endif + m->size = 0; + m->base = 0; + m->fd = -1; } - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - if (3 == n_bufs) { - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.dst.data = (uint32_t) bufs[2].ptr; - } else { - octx.dst.data = (uint32_t) bufs[1].ptr; - } - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); +} - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - if (octx.op == HTP_OP_SOFTMAX) { - rsp_status = op_softmax(&octx); - } else { - rsp_status = op_activations(&octx); +static inline void mmap_buf(struct htp_context *ctx, struct htp_buf_desc *b) { + if (b->base) return; // already mapped + + // find unused mapping + for (uint32_t i=0; i < HTP_MAX_MMAPS; i++) { + struct htp_mmap *m = &ctx->mmap[i]; + if (!m->size) { +#if __HVX_ARCH__ > 73 + void *va = HAP_mmap2(NULL, b->size, HAP_PROT_READ | HAP_PROT_WRITE, 0, b->fd, 0); +#else + if (b->size > HTP_MMAP_MAX_VMEM) { // HAP_mmap has a size limit of 2GB + FARF(ERROR, "mmap failed : size %u exceeds 2GB limit for HAP_mmap", (uint32_t) b->size); + abort(); // can't do much else at this point + } + + void *va = HAP_mmap(NULL, b->size, HAP_PROT_READ | HAP_PROT_WRITE, 0, b->fd, 0); +#endif + if (va == (void*)-1) { + FARF(ERROR, "mmap failed : va %p fd %u size %u", va, b->fd, (uint32_t) b->size); + abort(); // can't do much else at this point + } + + m->base = b->base = (uint64_t) va; + m->fd = b->fd; + m->size = b->size; + + FARF(HIGH, "mmap : fd %u base %p size %u", m->fd, (void*) m->base, (uint32_t) m->size); + return; } - vtcm_release(ctx); } - - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } -static void proc_rope_req(struct htp_context * ctx, - struct htp_general_req * req, - struct dspqueue_buffer * bufs, - uint32_t n_bufs) { - struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; - - int write_idx = n_bufs - 1; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[write_idx].fd; - rsp_bufs[0].ptr = bufs[write_idx].ptr; - rsp_bufs[0].offset = bufs[write_idx].offset; - rsp_bufs[0].size = bufs[write_idx].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - if (4 == n_bufs) { - octx.src2 = req->src2; - } - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - if (4 == n_bufs) { - octx.src2.data = (uint32_t) bufs[2].ptr; - octx.dst.data = (uint32_t) bufs[3].ptr; - } else { - octx.dst.data = (uint32_t) bufs[2].ptr; +static void prep_op_bufs(struct htp_context *ctx, struct htp_buf_desc *bufs, uint32_t n_bufs) { + uint32_t m_reuse = 0; // mmap reuse mask (index from ctx->mmap array) + uint32_t b_reuse = 0; // buf reuse count + + uint64_t m_vmem = 0; // mapped vmem + uint64_t e_vmem = 0; // extra vmem + + // See what we can reuse + for (uint32_t i=0; i < n_bufs; i++) { + struct htp_buf_desc *b = bufs + i; + if (reuse_buf(ctx, &m_reuse, b)) { b_reuse++; } else { e_vmem += b->size; } + FARF(HIGH, "prep-buf #%u : pass0 fd %u base %p size %u flags 0x%x", i, b->fd, (void*) b->base, (uint32_t) b->size, b->flags); } - octx.n_threads = ctx->n_threads; - struct profile_data prof; - profile_start(&prof); + if (b_reuse == n_bufs) return; // all bufs reuse existing mappings + + // See how much vmem we have mmaped right now + for (uint32_t i=0; immap[i].size; } - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_rope(&octx); - vtcm_release(ctx); + FARF(HIGH, "prep-bufs : pass1 mmap-vmem %zu extra-vmem %zu max-vmem %zu : n-bufs %u b-reuse %u", + (size_t) m_vmem, (size_t) e_vmem, (size_t) ctx->max_vmem, n_bufs, b_reuse); + + if ((m_vmem + e_vmem) > ctx->max_vmem) { + // Drop unused mappings + for (uint32_t i=0; i < HTP_MAX_MMAPS; i++) { + bool used = m_reuse & (1<mmap + i); } + } } - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); + // Create missing mappings + for (uint32_t i=0; i < n_bufs; i++) { + struct htp_buf_desc *b = bufs + i; + mmap_buf(ctx, b); + FARF(HIGH, "prep-buf #%u : pass1 fd %u base %p size %u flags 0x%x", i, b->fd, (void*) b->base, (uint32_t) b->size, b->flags); + } } -static void proc_set_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[2].fd; - rsp_bufs[0].ptr = bufs[2].ptr; - rsp_bufs[0].offset = bufs[2].offset; - rsp_bufs[0].size = bufs[2].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.dst.data = (uint32_t) bufs[2].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_set_rows(&octx); - vtcm_release(ctx); - } +static void prep_tensor(struct htp_context *ctx, struct htp_buf_desc *bufs, uint32_t idx, struct htp_tensor *t) { + uint32_t offset = t->data; + uint32_t size = t->size; + uint32_t bi = t->bi; + + t->data = bufs[bi].base + offset; // update data to the actual pointer - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); + FARF(HIGH, "prep-tensor #%u: bi %u offset %u size %u data %p : %u:%u:%u:%u", idx, t->bi, offset, t->size, (void*) t->data, + t->ne[0], t->ne[1], t->ne[3], t->ne[3]); } -static void proc_flash_attn_ext_req(struct htp_context * ctx, - struct htp_general_req * req, - struct dspqueue_buffer * bufs, - uint32_t n_bufs) { - // Setup Op context - struct htp_ops_context octx; - memset(&octx, 0, sizeof(octx)); - - octx.ctx = ctx; - octx.n_threads = ctx->n_threads; - - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.src2 = req->src2; - octx.src3 = req->src3; - octx.src4 = req->src4; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.src2.data = (uint32_t) bufs[2].ptr; - - int last_buf = 3; - - if (octx.src3.ne[0]) { - octx.src3.data = (uint32_t) bufs[last_buf++].ptr; // mask is valid +static void prep_tensors(struct htp_context *ctx, struct htp_buf_desc *bufs, struct htp_tensor *tens, uint32_t n_tens) { + for (uint32_t i=0; i < n_tens; i++) { + prep_tensor(ctx, bufs, i, tens + i); } +} - if (octx.src4.ne[0]) { - octx.src4.data = (uint32_t) bufs[last_buf++].ptr; // sinks is valid - } +static void proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens, uint32_t idx, struct htp_op_desc * op) { + memcpy(octx->op_params, op->params, sizeof(octx->op_params)); + octx->flags = op->flags; + octx->op = op->opcode; + + FARF(HIGH, "proc-op #%u: opcode %u flags 0x%x", idx, octx->op, octx->flags); - octx.dst.data = (uint32_t) bufs[last_buf].ptr; + // Prep input tensors + for (uint32_t i=0; isrc[i] == 0xffff ? NULL : tens + op->src[i]; - struct profile_data prof; - profile_start(&prof); + octx->src[i] = src; + if (!src) continue; - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_flash_attn_ext(&octx); - vtcm_release(ctx); + if (!(src->flags & HTP_TENSOR_FLUSHED) && (src->flags & HTP_TENSOR_COMPUTE)) { + // flush compute buffers on input + hex_l2flush((void *) src->data, src->size); + } + + FARF(HIGH, "prep-src #%u: data %p size %u : %u:%u:%u:%u", op->src[i], (void*) src->data, src->size, + src->ne[0], src->ne[1], src->ne[3], src->ne[3]); } - profile_stop(&prof); + // Prep output tensor + struct htp_tensor *dst = tens + op->dst; + + octx->dst = dst; + + FARF(HIGH, "prep-dst #%u: data %p size %u : %u:%u:%u:%u", op->dst, (void*) dst->data, dst->size, + dst->ne[0], dst->ne[1], dst->ne[3], dst->ne[3]); + + (void) execute_op(octx); - struct dspqueue_buffer rsp_buf = bufs[last_buf]; - rsp_buf.flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + // flush buffers on output + hex_l2flush((void *) dst->data, dst->size); + dst->flags |= HTP_TENSOR_FLUSHED; - send_htp_rsp(ctx, req->op, rsp_status, &bufs[last_buf], 1, &prof); + FARF(HIGH, "post-dst #%u: data %p size %u : %u:%u:%u:%u", op->dst, (void*) dst->data, dst->size, + dst->ne[0], dst->ne[1], dst->ne[3], dst->ne[3]); } +#define DSPQUEUE_POLL_TIMEOUT_USEC 100 +#define DSPQUEUE_POLL_COUNT 100 + static void htp_packet_callback(dspqueue_t queue, int error, void * context) { struct htp_context * ctx = (struct htp_context *) context; - // Repeatedly read packets from the queue until it's empty. We don't - // necessarily get a separate callback for each packet, and new packets - // may arrive while we're processing the previous one. This ensures we - // keep the DSP busy as much as possible and avoid waiting for the CPU. + int err; + + uint32_t poll_count = DSPQUEUE_POLL_COUNT; - while (1) { - struct htp_general_req req; - uint32_t req_size; + vtcm_acquire(ctx); - struct dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS]; - uint32_t n_bufs; - uint32_t flags; + while (!ctx->vtcm_needs_release) { + struct htp_opbatch_req req; + uint32_t r_size = sizeof(req); - // Read packet from queue - int err = dspqueue_read_noblock(queue, &flags, - HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references - &n_bufs, // Number of buffer references - bufs, // Buffer references - sizeof(req), // Max message length - &req_size, // Message length - (uint8_t *) &req); // Message + struct dspqueue_buffer dbuf; + uint32_t n_dbufs = 1; + uint32_t flags = 0; + err = dspqueue_read_noblock(queue, &flags, n_dbufs, &n_dbufs, &dbuf, r_size, &r_size, (uint8_t *) &req); if (err == AEE_EWOULDBLOCK) { - // Consumed all packets available for now - return; + if (--poll_count) { + qurt_sleep(DSPQUEUE_POLL_TIMEOUT_USEC); + continue; + } + break; } if (err != 0) { FARF(ERROR, "dspqueue_read_noblock failed: 0x%08x", (unsigned) err); - return; + break; } - if (req_size != sizeof(req)) { - FARF(ERROR, "Invalid request size"); + if (r_size < sizeof(req) || n_dbufs != 1) { + FARF(ERROR, "invalid request : size %u n-dbufs %u", r_size, n_dbufs); continue; } - if (req.flags & HTP_OPFLAGS_EARLY_WAKEUP) { - // Host wants early notification - dspqueue_write_early_wakeup_noblock(ctx->queue, 10, 0); - } + // Reset poll count for valid requests + poll_count = DSPQUEUE_POLL_COUNT; - // Process packet based on its message type - switch (req.op) { - case HTP_OP_MUL_MAT: - if (n_bufs != 3) { - FARF(ERROR, "Bad matmul-req buffer list"); - continue; - } - proc_matmul_req(ctx, &req, bufs, n_bufs); - break; + const uint32_t n_bufs = req.n_bufs; + const uint32_t n_tens = req.n_tensors; + const uint32_t n_ops = req.n_ops; - case HTP_OP_MUL_MAT_ID: - if (n_bufs != 4) { - FARF(ERROR, "Bad matmul-id-req buffer list"); - continue; - } - proc_matmul_id_req(ctx, &req, bufs, n_bufs); - break; - - case HTP_OP_MUL: - case HTP_OP_ADD: - case HTP_OP_SUB: - case HTP_OP_DIV: - if (n_bufs != 3) { - FARF(ERROR, "Bad binary-req buffer list"); - continue; - } - proc_binary_req(ctx, &req, bufs); - break; - - case HTP_OP_RMS_NORM: - case HTP_OP_SCALE: - if (n_bufs != 2) { - FARF(ERROR, "Bad unary-req buffer list"); - continue; - } + const uint32_t b_size = sizeof(struct htp_buf_desc) * n_bufs; + const uint32_t t_size = sizeof(struct htp_tensor) * n_tens; + const uint32_t o_size = sizeof(struct htp_op_desc) * n_ops; + const uint32_t p_size = sizeof(struct htp_prof_desc) * n_ops; - proc_unary_req(ctx, &req, bufs); - break; + if (dbuf.size < b_size + t_size + o_size + p_size) { + FARF(ERROR, "invalid opbatch memory block size %u", dbuf.size); + break; + } - case HTP_OP_SQR: - case HTP_OP_SQRT: - if (n_bufs != 2) { - FARF(ERROR, "Bad unary-req buffer list"); - continue; - } + FARF(HIGH, "processing opbatch #%u: n-bufs %u n-tensors %u n-ops %u : m-size %u b-size %u t-size %u o-size %u", req.id, + n_bufs, n_tens, n_ops, dbuf.size, b_size, t_size, o_size); - proc_unary_req(ctx, &req, bufs); - break; + // Setup descriptor pointers + uint8_t * m_ptr = dbuf.ptr; + struct htp_buf_desc* bufs = (struct htp_buf_desc*) m_ptr; m_ptr += b_size; + struct htp_tensor* tens = (struct htp_tensor*) m_ptr; m_ptr += t_size; + struct htp_op_desc* ops = (struct htp_op_desc*) m_ptr; m_ptr += o_size; + struct htp_prof_desc* pds = (struct htp_prof_desc*) m_ptr; - case HTP_OP_SUM_ROWS: - if (n_bufs != 2) { - FARF(ERROR, "Bad unary-req buffer list"); - continue; - } + prep_op_bufs(ctx, bufs, n_bufs); + prep_tensors(ctx, bufs, tens, n_tens); - proc_sum_rows_req(ctx, &req, bufs); - break; + struct htp_ops_context *octx = &ctx->octx; + memset(octx, 0, sizeof(*octx)); + octx->n_threads = ctx->n_threads; + octx->ctx = ctx; - case HTP_OP_UNARY_SILU: - case HTP_OP_UNARY_GELU: - if (n_bufs != 2) { - FARF(ERROR, "Bad act-req buffer list"); - continue; - } - proc_activations_req(ctx, &req, bufs, n_bufs); - break; - - case HTP_OP_GLU_SWIGLU: - case HTP_OP_GLU_SWIGLU_OAI: - case HTP_OP_SOFTMAX: - case HTP_OP_GLU_GEGLU: - if ((n_bufs != 2) && (n_bufs != 3)) { - FARF(ERROR, "Bad act-req buffer list"); - continue; - } - proc_activations_req(ctx, &req, bufs, n_bufs); - break; + for (uint32_t i=0; i < n_ops; i++) { + struct profile_data prof; - case HTP_OP_ADD_ID: - if (n_bufs != 4) { - FARF(ERROR, "Bad add-id-req buffer list"); - continue; - } - proc_add_id_req(ctx, &req, bufs); - break; + profile_start(ctx->profiler, &prof); - case HTP_OP_ROPE: - if ((n_bufs != 3) && (n_bufs != 4)) { - FARF(ERROR, "Bad rope-req buffer list"); - continue; - } - proc_rope_req(ctx, &req, bufs, n_bufs); - break; + proc_op_req(octx, tens, i, &ops[i]); - case HTP_OP_FLASH_ATTN_EXT: - if (!(n_bufs >= 4 && n_bufs <= 6)) { - FARF(ERROR, "Bad flash-attn-ext-req buffer list"); - continue; - } - proc_flash_attn_ext_req(ctx, &req, bufs, n_bufs); - break; + profile_stop(ctx->profiler, &prof); - case HTP_OP_SET_ROWS: - if (n_bufs != 3) { - FARF(ERROR, "Bad set-rows-req buffer list"); - continue; + if (ctx->profiler) { + pds[i].opcode = ops[i].opcode; + pds[i].usecs = prof.usecs; + pds[i].cycles = prof.cycles; + for (int j = 0; j < HEX_NUM_PMU_COUNTERS; j++) { + pds[i].pmu[j] = prof.pmu_counters[j]; } - proc_set_rows_req(ctx, &req, bufs); - break; - - case HTP_OP_GET_ROWS: - if (n_bufs != 3) { - FARF(ERROR, "Bad get-rows-req buffer list"); - continue; - } - proc_get_rows_req(ctx, &req, bufs); - break; + } + } - case HTP_OP_CPY: - if (n_bufs != 2) { - FARF(ERROR, "Bad cpy-req buffer list"); - continue; - } - proc_cpy_req(ctx, &req, bufs); - break; + // dspqueue_write_early_wakeup_noblock(ctx->queue, 10, 0); - case HTP_OP_ARGSORT: - if (n_bufs != 2) { - FARF(ERROR, "Bad argsort-req buffer list"); - continue; - } - proc_argsort_req(ctx, &req, bufs); - break; + struct htp_opbatch_rsp rsp; + rsp.id = req.id; + rsp.status = HTP_STATUS_OK; + rsp.n_bufs = n_bufs; + rsp.n_tensors = n_tens; + rsp.n_ops = n_ops; - case HTP_OP_SSM_CONV: - if (n_bufs != 3) { - FARF(ERROR, "Bad ssm-conv-req buffer list"); - continue; - } - proc_ssm_conv_req(ctx, &req, bufs); - break; + dbuf.flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT; - default: - FARF(ERROR, "Unknown Op %u", req.op); - break; + err = dspqueue_write(queue, 0, 1, &dbuf, sizeof(rsp), (const uint8_t *) &rsp, DSPQUEUE_TIMEOUT_NONE); + if (err != 0) { + FARF(ERROR, "dspqueue_write failed: 0x%08x", (unsigned) err); + break; } } + + vtcm_release(ctx); } diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 73aaba79ebf..a0c265132c8 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -16,8 +16,9 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" #include "htp-ops.h" +#include "htp-ops.h" +#include "hmx-ops.h" #define MM_SPAD_SRC0_NROWS 16 #define MM_SPAD_SRC1_NROWS 16 @@ -60,6 +61,16 @@ static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = { 0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x22, 0x20, 0x24, 0x20, 0x21, 0x22, 0x20, 0x20, }; +// IQ4_NL dequantization LUT: maps 4-bit index (0-15) to int8 kvalue +// kvalues: -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113 +static const uint8_t __attribute__((aligned(VLEN))) kvalues_iq4nl_lut[] = { + 0x81, 0, 0x98, 0, 0xAD, 0, 0xBF, 0, 0xCF, 0, 0xDD, 0, 0xEA, 0, 0xF6, 0, 0x01, 0, 0x0D, 0, 0x19, 0, 0x26, 0, + 0x35, 0, 0x45, 0, 0x59, 0, 0x71, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +}; + static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = { 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 6, 0, 8, 0, 12, 0, 0, 0, 0xff, 0, 0xfe, 0, 0xfd, 0, 0xfc, 0, 0xfa, 0, 0xf8, 0, 0xf4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -68,6 +79,73 @@ static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, }; +static inline HVX_Vector_x8 hvx_vec_load_iq4nlx4x8_full(const uint8_t * restrict ptr) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) + HVX_Vector v2_3 = vptr[1]; // ... + HVX_Vector v4_5 = vptr[2]; // ... + HVX_Vector v6_7 = vptr[3]; // ... + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut; + + HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 + HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F + HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4 + HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F + HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4 + HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F + HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 + + v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0); + v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0); + v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0); + v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0); + v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0); + v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0); + v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0); + v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0); + + HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; + return r; +} + +static inline HVX_Vector_x8 hvx_vec_load_iq4nlx4x8_partial(const uint8_t * restrict ptr, uint32_t n) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + const uint32_t qk = QK_Q4_0x4x2; // 256 + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut; + + HVX_Vector_x8 r; + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nb; i++) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements + r.v[i * 2 + 0] = Q6_Vb_vlut32_VbVbI(v0, lut, 0); + r.v[i * 2 + 1] = Q6_Vb_vlut32_VbVbI(v1, lut, 0); + } + + if (nloe) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements + HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:... + r.v[i * 2 + 0] = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v0_1_p), lut, 0); + r.v[i * 2 + 1] = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v0_1_p), lut, 0); + } + + return r; +} + // q4x4x2 and q8x4x2 are the flat q4/8_0 formats where all quants are stored first followed by all scales static inline size_t q8x4x2_row_size(uint32_t ne) { @@ -921,6 +999,293 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 } +// ======== IQ4_NL x Q8_0 vec_dot kernels ======== +// Same structure as Q4_0 vec_dot but uses IQ4_NL LUT-based load (4-bit index -> int8 kvalue). +// Scale format is identical to Q4_0 (fp16 scales). + +static void vec_dot_iq4nlx4x2_q8x4x2_1x1(const int n, + float * restrict s0, + const void * restrict vx0, + const void * restrict vy0) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + HVX_Vector r0_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + } + + r0_sum = hvx_vec_reduce_sum_f32(r0_sum); + + hvx_vec_store_u(s0, 4, r0_sum); +} + +static void vec_dot_iq4nlx4x2_q8x4x2_2x1(const int n, + float * restrict s0, + const void * restrict vx0, + const void * restrict vx1, + const void * restrict vy0) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); + hvx_vec_store_u(s0, 8, rsum); +} + +static void vec_dot_iq4nlx4x2_q8x4x2_2x2(const int n, + float * restrict s0, + float * restrict s1, + const void * restrict vx0, + const void * restrict vx1, + const void * restrict vy0, + const void * restrict vy1) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + assert((unsigned long) vy1 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; + + const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; + const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; + + HVX_Vector r0_c0_sum = Q6_V_vzero(); + HVX_Vector r0_c1_sum = Q6_V_vzero(); + HVX_Vector r1_c0_sum = Q6_V_vzero(); + HVX_Vector r1_c1_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); + + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); + + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); + r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); + r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); + r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); + hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); +} + static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size assert((unsigned long) vx0 % 128 == 0); @@ -1533,11 +1898,11 @@ static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void * hvx_vec_store_u(&s[0], 4, rsum); } -#define htp_matmul_tensors_preamble \ - struct htp_tensor * restrict src0 = &octx->src0; \ - struct htp_tensor * restrict src1 = &octx->src1; \ - struct htp_tensor * restrict src2 = &octx->src2; \ - struct htp_tensor * restrict dst = &octx->dst; \ +#define htp_matmul_tensors_preamble \ + const struct htp_tensor * restrict src0 = octx->src[0]; \ + const struct htp_tensor * restrict src1 = octx->src[1]; \ + const struct htp_tensor * restrict src2 = octx->src[2]; \ + const struct htp_tensor * restrict dst = octx->dst; \ struct htp_spad * restrict src0_spad = &octx->src0_spad; \ struct htp_spad * restrict src1_spad = &octx->src1_spad; \ struct htp_spad * restrict dst_spad = &octx->dst_spad; \ @@ -1859,8 +2224,8 @@ struct mmid_row_mapping { static void matmul_id(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; - struct htp_tensor * restrict ids = &octx->src2; - struct htp_spad * restrict src2_spad = &octx->src2_spad; + const struct htp_tensor * restrict ids = octx->src[2]; + struct htp_spad * restrict src2_spad = &octx->src2_spad; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); @@ -1978,8 +2343,8 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) { static void matvec_id(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; - struct htp_tensor * restrict ids = &octx->src2; - struct htp_spad * restrict src2_spad = &octx->src2_spad; + const struct htp_tensor * restrict ids = octx->src[2]; + struct htp_spad * restrict src2_spad = &octx->src2_spad; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); @@ -2248,7 +2613,7 @@ static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) struct htp_matmul_context * mmctx = data; struct htp_ops_context * octx = mmctx->octx; - const struct htp_tensor * src = &octx->src1; + const struct htp_tensor * src = octx->src[1]; uint8_t * restrict dst = octx->src1_spad.data; struct htp_spad * spad = &octx->src0_spad; uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; @@ -2295,7 +2660,7 @@ static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) { struct htp_matmul_context * mmctx = data; struct htp_ops_context * octx = mmctx->octx; - const struct htp_tensor * src = &octx->src1; + const struct htp_tensor * src = octx->src[1]; uint8_t * restrict dst = octx->src1_spad.data; uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; uint32_t dst_stride = octx->src1_spad.stride; @@ -2337,7 +2702,7 @@ static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) { struct htp_matmul_context * mmctx = data; struct htp_ops_context * octx = mmctx->octx; - const struct htp_tensor * src = &octx->src1; + const struct htp_tensor * src = octx->src[1]; uint8_t * restrict dst = octx->src1_spad.data; uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; uint32_t dst_stride = octx->src1_spad.stride; @@ -2393,6 +2758,12 @@ static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_t mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1; mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2; return 0; + case HTP_TYPE_IQ4_NL: + mmctx->type = "iq4nlx4x2-f32"; + mmctx->vec_dot_1x1 = vec_dot_iq4nlx4x2_q8x4x2_1x1; + mmctx->vec_dot_2x1 = vec_dot_iq4nlx4x2_q8x4x2_2x1; + mmctx->vec_dot_2x2 = vec_dot_iq4nlx4x2_q8x4x2_2x2; + return 0; case HTP_TYPE_MXFP4: mmctx->type = "mxfp4x4x2-f32"; mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1; @@ -2430,7 +2801,7 @@ static void htp_mminit_spad(struct htp_ops_context * octx, octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; } -int op_matmul(struct htp_ops_context * octx) { +static int op_matmul_hvx(struct htp_ops_context * octx) { htp_matmul_tensors_preamble; struct htp_matmul_context mmctx_struct = {0}; @@ -2454,7 +2825,7 @@ int op_matmul(struct htp_ops_context * octx) { worker_callback_t quant_job_func; worker_callback_t matmul_job_func = src1_nrows > 1 ? matmul_2d : matvec_2d; - bool need_quant = !(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE); + bool need_quant = true; if (src0->type == HTP_TYPE_F16) { // Try optimized f16-f16 path first (src1 in VTCM) @@ -2468,7 +2839,7 @@ int op_matmul(struct htp_ops_context * octx) { // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting). // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul. const bool is_batched = (ne02 > 1) || (ne03 > 1); - const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1); + const bool is_permuted = htp_is_permuted(octx->src[0]) || htp_is_permuted(octx->src[1]); if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) { // Optimized path @@ -2545,27 +2916,176 @@ int op_matmul(struct htp_ops_context * octx) { return HTP_STATUS_VTCM_TOO_SMALL; } - octx->src0_spad.data = octx->ctx->vtcm_base; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + // Place src1 spad first. We use it for dyn.quant and may reuse between ops + octx->src1_spad.data = octx->ctx->vtcm_base; + octx->src0_spad.data = octx->src1_spad.data + octx->src1_spad.size; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + + octx->src1_spad.src = (src1 == octx->src1_spad.src) ? src1 : NULL; + octx->src0_spad.src = NULL; + octx->dst_spad.src = NULL; octx->src0_spad.stride = src0_row_size_padded; octx->src1_spad.stride = src1_row_size; - if (need_quant) { + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) + return HTP_STATUS_OK; + + if (need_quant && !octx->src1_spad.src) { const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); + octx->src1_spad.src = src1; } - if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - const uint32_t n_matmul_jobs = octx->n_threads; - worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs); - } + const uint32_t n_matmul_jobs = octx->n_threads; + worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs); return HTP_STATUS_OK; } +int op_matmul(struct htp_ops_context * octx) { + htp_matmul_tensors_preamble; + +#ifndef HTP_HAS_HMX + return op_matmul_hvx(octx); +#else + if (!octx->ctx->hmx_enabled) { + return op_matmul_hvx(octx); + } + + // HMX weight tile requires N to be 32-aligned. + if (src0->ne[1] % 32 != 0) { + return op_matmul_hvx(octx); + } + + // HMX supports F16, Q4_0, Q8_0, IQ4_NL, MXFP4 weights. + // Other types fall back to HVX. + uint32_t wtype = src0->type; + if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && wtype != HTP_TYPE_MXFP4) { + return op_matmul_hvx(octx); + } + + // Quantised HMX path requires K aligned to 256 (x4x2 super-block). + // F16 HMX path requires K aligned to 32 (tile width). + if (wtype != HTP_TYPE_F16 && src0->ne[0] % 256 != 0) { + return op_matmul_hvx(octx); + } + + if (wtype == HTP_TYPE_F16 && src0->ne[0] % 32 != 0) { + return op_matmul_hvx(octx); + } + + const bool is_batched = (src0->ne[2] * src0->ne[3] > 1 || src1->ne[2] * src1->ne[3] > 1); + + // Quantised HMX kernels only handle flat 2D matmul (host already rejects + // batched quantised, but guard here too). F16 batched matmul is handled + // by the dedicated wrapper in hmx-matmul-ops.c. + if (is_batched && src0->type != HTP_TYPE_F16) { + return op_matmul_hvx(octx); + } + + // HMX assumes contiguous row-major layout. Fall back for permuted + // tensors where strides are non-monotonic (e.g. transposed KV cache). + if (src0->nb[0] > src0->nb[1] || src1->nb[0] > src1->nb[1]) { + return op_matmul_hvx(octx); + } + + // M alignment: when M > 32 but not 32-aligned, we split into + // HMX (first m_hmx = M & ~31 rows) + HVX (remaining m_tail rows). + // When M <= 32 and not 32-aligned, fall back entirely to HVX. + const int m_total = (int) src1->ne[1]; + const int m_tail = m_total % 32; + const int m_hmx = m_total - m_tail; + + if (m_hmx == 0) { + return op_matmul_hvx(octx); + } + + // Always re-quantize src1 since HMX kernel overwrites vtcm/spad, + // so any previously cached quantized data is invalid. + octx->src1_spad.src = NULL; + + int k = (int) src0->ne[0]; // inner dimension + int n = (int) src0->ne[1]; // weight columns + + // --- Phase 1: HMX on the first m_hmx (32-aligned) rows --- + int ret = -1; + + // Row strides in elements. For compact tensors these equal k; for + // permuted attention views they can be larger, so pass the real stride. + const int act_stride = (int)(src1->nb[1] / sizeof(float)); + const int wgt_stride = (int)(src0->nb[1] / sizeof(__fp16)); + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + if (src0->type == HTP_TYPE_F16) { + if (is_batched) { + hmx_matmul_w16a32_batched_params_t batch_params = { + .dst = (float *) dst->data, + .activation = (float *) src1->data, + .permuted_weight = (const __fp16 *) src0->data, + .m = m_hmx, + .k = k, + .n = n, + .act_stride = act_stride, + .weight_stride = wgt_stride, + .dst_stride = (int) (dst->nb[1] / sizeof(float)), + .ne02 = ne02, + .ne03 = ne03, + .ne12 = ne12, + .ne13 = ne13, + .src0_nb2 = src0->nb[2], + .src0_nb3 = src0->nb[3], + .src1_nb2 = src1->nb[2], + .src1_nb3 = src1->nb[3], + .dst_nb2 = dst->nb[2], + .dst_nb3 = dst->nb[3], + }; + ret = hmx_mat_mul_permuted_w16a32_batched(octx->ctx, &batch_params); + } else { + ret = hmx_mat_mul_permuted_w16a32(octx->ctx, + (float*) dst->data, (float*) src1->data, (const __fp16 *) src0->data, + m_hmx, k, n, act_stride, wgt_stride); + } + } else { + ret = hmx_mat_mul_permuted_qk_0_d16a32(octx->ctx, + (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, + m_hmx, k, n, (int) src0->type); + } + + if (ret != 0) { + FARF(HIGH, "HMX matmul failed (ret=%d), falling back to HVX", ret); + return op_matmul(octx); + } + + // --- Phase 2: HVX on the remaining m_tail rows --- + if (m_tail > 0) { + // copy of src1 and dst + struct htp_tensor src1_tail = *src1; + struct htp_tensor dst_tail = *dst; + + src1_tail.ne[1] = m_tail; // only tail rows + dst_tail.ne[1] = m_tail; // only tail rows + + // Offset activation and dst pointers past the HMX-processed rows. + // Use nb[1] (row stride in bytes) to compute the byte offset. + src1_tail.data += (uint32_t) m_hmx * src1->nb[1]; + dst_tail.data += (uint32_t) m_hmx * dst->nb[1]; + + octx->src[1] = &src1_tail; + octx->dst = &dst_tail; + + FARF(HIGH, "hmx-matmul: HVX tail m_tail %d src1 %p dst %p", m_tail, (void *) src1_tail.data, (void *) dst_tail.data); + return op_matmul_hvx(octx); + } + + return 0; +#endif // HTP_HAS_HMX +} + int op_matmul_id(struct htp_ops_context * octx) { htp_matmul_tensors_preamble; @@ -2573,7 +3093,7 @@ int op_matmul_id(struct htp_ops_context * octx) { struct htp_matmul_context * mmctx = &mmctx_struct; mmctx->octx = octx; - struct htp_tensor * restrict ids = &octx->src2; + const struct htp_tensor * restrict ids = octx->src[2]; const size_t src0_row_size = nb01; const size_t dst_row_size = nb1; @@ -2626,11 +3146,17 @@ int op_matmul_id(struct htp_ops_context * octx) { return HTP_STATUS_VTCM_TOO_SMALL; } - octx->src0_spad.data = octx->ctx->vtcm_base; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size; + // Place src1 spad first. We use it for dyn.quant and may reuse in subseq ops. + octx->src1_spad.data = octx->ctx->vtcm_base; + octx->src0_spad.data = octx->src1_spad.data + octx->src1_spad.size; + octx->src2_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.data = octx->src2_spad.data + octx->src2_spad.size; + octx->src1_spad.src = (src1 == octx->src1_spad.src) ? src1 : NULL; + octx->src0_spad.src = NULL; + octx->src2_spad.src = NULL; + octx->dst_spad.src = NULL; + octx->src0_spad.stride = src0_row_size_padded; octx->src1_spad.stride = src1_row_size; @@ -2654,17 +3180,18 @@ int op_matmul_id(struct htp_ops_context * octx) { } } - // Setup worker pool callbacks - if (!(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE)) { + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) + return HTP_STATUS_OK; + + if (octx->src1_spad.src != src1) { const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); + octx->src1_spad.src = src1; } - if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - const uint32_t n_matmul_jobs = octx->n_threads; - worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs); - } + const uint32_t n_matmul_jobs = octx->n_threads; + worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/repeat-ops.c b/ggml/src/ggml-hexagon/htp/repeat-ops.c new file mode 100644 index 00000000000..a6f2f0ed5f3 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/repeat-ops.c @@ -0,0 +1,148 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include + +#include + +#include "hvx-utils.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" +#include "htp-ops.h" + +struct htp_repeat_context { + struct htp_ops_context * octx; + + uint32_t nr0; + uint32_t nr1; + uint32_t nr2; + uint32_t nr3; + + uint32_t nrows_per_thread; + uint32_t total_dst_rows; // ne1 * ne2 * ne3 + + size_t type_size; +}; + +static void repeat_job_per_thread(unsigned int nth, unsigned int ith, void * data) { + const struct htp_repeat_context * rctx = (const struct htp_repeat_context *) data; + struct htp_ops_context * octx = rctx->octx; + const struct htp_tensor * src = octx->src[0]; + const struct htp_tensor * dst = octx->dst; + + const uint32_t ne00 = src->ne[0]; + const uint32_t ne01 = src->ne[1]; + const uint32_t ne02 = src->ne[2]; + const uint32_t ne03 = src->ne[3]; + + const uint32_t nb00 = src->nb[0]; + const uint32_t nb01 = src->nb[1]; + const uint32_t nb02 = src->nb[2]; + const uint32_t nb03 = src->nb[3]; + + const uint32_t ne0 = dst->ne[0]; + const uint32_t ne1 = dst->ne[1]; + const uint32_t ne2 = dst->ne[2]; + const uint32_t ne3 = dst->ne[3]; + + const uint32_t nb0 = dst->nb[0]; + const uint32_t nb1 = dst->nb[1]; + const uint32_t nb2 = dst->nb[2]; + const uint32_t nb3 = dst->nb[3]; + + const uint32_t nr0 = rctx->nr0; + const uint32_t nr1 = rctx->nr1; + const uint32_t nr2 = rctx->nr2; + const uint32_t nr3 = rctx->nr3; + + const size_t row_bytes = ne00 * rctx->type_size; + + const uint32_t row_start = rctx->nrows_per_thread * ith; + const uint32_t row_end = MIN(row_start + rctx->nrows_per_thread, rctx->total_dst_rows); + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + for (uint32_t dst_row = row_start; dst_row < row_end; dst_row++) { + // Decompose flat dst row index into (i1, i2, i3) + const uint32_t i1 = dst_row % ne1; + const uint32_t i2 = (dst_row / ne1) % ne2; + const uint32_t i3 = dst_row / (ne1 * ne2); + + // Map to source indices (tiling) + const uint32_t k1 = i1 % ne01; + const uint32_t k2 = i2 % ne02; + const uint32_t k3 = i3 % ne03; + + const uint8_t * src_row = (const uint8_t *) src->data + k1 * nb01 + k2 * nb02 + k3 * nb03; + uint8_t * dst_base = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3; + + // Tile along dimension 0 + for (uint32_t i0 = 0; i0 < nr0; i0++) { + uint8_t * dst_ptr = dst_base + i0 * ne00 * nb0; + memcpy(dst_ptr, src_row, row_bytes); + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "repeat %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n", + ith, nth, src->ne[0], src->ne[1], src->ne[2], src->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + row_start, row_end, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_repeat(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * dst = octx->dst; + + // Validate that dst dims are multiples of src dims + if (dst->ne[0] % src0->ne[0] != 0 || + dst->ne[1] % src0->ne[1] != 0 || + dst->ne[2] % src0->ne[2] != 0 || + dst->ne[3] % src0->ne[3] != 0) { + FARF(ERROR, "repeat: dst dims must be multiples of src dims\n"); + return HTP_STATUS_INVAL_PARAMS; + } + + size_t type_size; + switch (src0->type) { + case HTP_TYPE_F32: type_size = 4; break; + case HTP_TYPE_F16: type_size = 2; break; + default: + FARF(ERROR, "repeat: unsupported type %u\n", src0->type); + return HTP_STATUS_NO_SUPPORT; + } + + const uint32_t total_dst_rows = dst->ne[1] * dst->ne[2] * dst->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, total_dst_rows); + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + struct htp_repeat_context rctx = { + .octx = octx, + .nr0 = dst->ne[0] / src0->ne[0], + .nr1 = dst->ne[1] / src0->ne[1], + .nr2 = dst->ne[2] / src0->ne[2], + .nr3 = dst->ne[3] / src0->ne[3], + .nrows_per_thread = (total_dst_rows + n_threads - 1) / n_threads, + .total_dst_rows = total_dst_rows, + .type_size = type_size, + }; + + FARF(HIGH, "repeat: (%ux%ux%ux%u) -> (%ux%ux%ux%u) nr=(%u,%u,%u,%u)\n", + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + rctx.nr0, rctx.nr1, rctx.nr2, rctx.nr3); + + worker_pool_run_func(octx->ctx->worker_pool, repeat_job_per_thread, &rctx, n_threads); + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c index be9469538f6..1d8b0796bc9 100644 --- a/ggml/src/ggml-hexagon/htp/rope-ops.c +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -15,7 +15,7 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" // Redefined the types GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX as we can't include ggml.h @@ -253,10 +253,10 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { struct htp_rope_context * rctx = (struct htp_rope_context *) data; struct htp_ops_context * octx = rctx->octx; - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - const struct htp_tensor * src2 = &octx->src2; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * src2 = octx->src[2]; + const struct htp_tensor * dst = octx->dst; htp_rope_preamble; @@ -284,7 +284,7 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { dma_queue * dma_queue = octx->ctx->dma[ith]; const int32_t * pos = (const int32_t *) src1->data; - const float * freq_factors = src2->data ? (const float *) src2->data : NULL; + const float * freq_factors = src2 ? (const float *) src2->data : NULL; uint32_t ir = 0; uint32_t prev_i2 = (uint32_t) -1; @@ -333,8 +333,8 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { // (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start)); } - // Skip DMA transactions from prev block (if any) - // No need to wait for these since the DMA is setup for in-order processing + // Skip output DMA transactions from prev block (if any) + // No need to wait for those here since we're explicitly waiting for the latest prefecthes below. for (uint32_t d=0; d < dma_depth; d++) { dma_queue_pop_nowait(dma_queue); } // Compute loop @@ -384,10 +384,10 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { static int execute_op_rope_f32(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - const struct htp_tensor * src2 = &octx->src2; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * src2 = octx->src[2]; + const struct htp_tensor * dst = octx->dst; const char * op_type = "rope-f32"; @@ -424,19 +424,16 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) { return HTP_STATUS_VTCM_TOO_SMALL; } - // Assign sizes octx->src0_spad.size_per_thread = src0_spad_per_thread; octx->dst_spad.size_per_thread = dst_spad_per_thread; octx->src0_spad.size = n_threads * src0_spad_per_thread; octx->dst_spad.size = n_threads * dst_spad_per_thread; octx->src1_spad.size = 0; - // Assign pointers - octx->src0_spad.data = octx->ctx->vtcm_base; - octx->src1_spad.data = NULL; - octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL; + octx->src1_spad.data = NULL; octx->src1_spad.src = NULL; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.src = NULL; - // Fill context struct htp_rope_context rctx; memset(&rctx, 0, sizeof(struct htp_rope_context)); @@ -483,7 +480,7 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) { int op_rope(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; - switch (octx->src0.type) { + switch (octx->src[0]->type) { case HTP_TYPE_F32: err = execute_op_rope_f32(octx); break; diff --git a/ggml/src/ggml-hexagon/htp/set-rows-ops.c b/ggml/src/ggml-hexagon/htp/set-rows-ops.c index 4b6967749f8..0def7b408bf 100644 --- a/ggml/src/ggml-hexagon/htp/set-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/set-rows-ops.c @@ -14,33 +14,37 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" -#define set_rows_preamble \ - const uint32_t ne00 = octx->src0.ne[0]; \ - const uint32_t ne01 = octx->src0.ne[1]; \ - const uint32_t ne02 = octx->src0.ne[2]; \ - const uint32_t ne03 = octx->src0.ne[3]; \ - \ - const uint32_t ne10 = octx->src1.ne[0]; \ - const uint32_t ne11 = octx->src1.ne[1]; \ - const uint32_t ne12 = octx->src1.ne[2]; \ - \ - const uint32_t nb01 = octx->src0.nb[1]; \ - const uint32_t nb02 = octx->src0.nb[2]; \ - const uint32_t nb03 = octx->src0.nb[3]; \ - \ - const uint32_t nb10 = octx->src1.nb[0]; \ - const uint32_t nb11 = octx->src1.nb[1]; \ - const uint32_t nb12 = octx->src1.nb[2]; \ - \ - const uint32_t nb1 = octx->dst.nb[1]; \ - const uint32_t nb2 = octx->dst.nb[2]; \ - const uint32_t nb3 = octx->dst.nb[3]; \ - \ - const uint32_t ne1 = octx->dst.ne[1]; \ - \ +#define set_rows_preamble \ + const uint32_t ne00 = octx->src[0]->ne[0]; \ + const uint32_t ne01 = octx->src[0]->ne[1]; \ + const uint32_t ne02 = octx->src[0]->ne[2]; \ + const uint32_t ne03 = octx->src[0]->ne[3]; \ + \ + const uint32_t ne10 = octx->src[1]->ne[0]; \ + const uint32_t ne11 = octx->src[1]->ne[1]; \ + const uint32_t ne12 = octx->src[1]->ne[2]; \ + const uint32_t ne13 = octx->src[1]->ne[3]; \ + \ + const uint32_t nb01 = octx->src[0]->nb[1]; \ + const uint32_t nb02 = octx->src[0]->nb[2]; \ + const uint32_t nb03 = octx->src[0]->nb[3]; \ + \ + const uint32_t nb10 = octx->src[1]->nb[0]; \ + const uint32_t nb11 = octx->src[1]->nb[1]; \ + const uint32_t nb12 = octx->src[1]->nb[2]; \ + \ + const uint32_t nb1 = octx->dst->nb[1]; \ + const uint32_t nb2 = octx->dst->nb[2]; \ + const uint32_t nb3 = octx->dst->nb[3]; \ + \ + const uint32_t ne0 = octx->dst->ne[0]; \ + const uint32_t ne1 = octx->dst->ne[1]; \ + const uint32_t ne2 = octx->dst->ne[2]; \ + const uint32_t ne3 = octx->dst->ne[3]; \ + \ const uint32_t nr = ne01; struct htp_set_rows_context { @@ -56,12 +60,14 @@ static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da set_rows_preamble; + uint64_t qt = HAP_perf_get_qtimer_count(); + // parallelize by rows of src0 const uint32_t dr = srctx->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; - const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); + const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32); for (uint32_t i03 = 0; i03 < ne03; ++i03) { for (uint32_t i02 = 0; i02 < ne02; ++i02) { @@ -70,7 +76,7 @@ static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11); const uint32_t i10 = i; - const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; + const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12; uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; if (i1 >= ne1) { @@ -78,14 +84,18 @@ static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da continue; } - const uintptr_t src0_ptr = octx->src0.data + i*nb01 + i02*nb02 + i03*nb03; - const uintptr_t dst_ptr = octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3; + const uintptr_t src0_ptr = octx->src[0]->data + i*nb01 + i02*nb02 + i03*nb03; + const uintptr_t dst_ptr = octx->dst->data + i1*nb1 + i02*nb2 + i03*nb3; // copy row hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00); } } } + + qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt); + FARF(HIGH, "set-rows-f32-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt); } static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *data) { @@ -94,12 +104,14 @@ static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *da set_rows_preamble; + uint64_t qt = HAP_perf_get_qtimer_count(); + // parallelize by rows of src0 const uint32_t dr = srctx->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; - const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); + const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32); for (uint32_t i03 = 0; i03 < ne03; ++i03) { for (uint32_t i02 = 0; i02 < ne02; ++i02) { @@ -108,7 +120,7 @@ static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *da const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11); const uint32_t i10 = i; - const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; + const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12; uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; if (i1 >= ne1) { @@ -116,13 +128,17 @@ static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *da continue; } - const uint8_t* src0_ptr = (const uint8_t *) octx->src0.data + i*nb01 + i02*nb02 + i03*nb03; - uint8_t* dst_ptr = (uint8_t *) octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3; + const uint8_t* src0_ptr = (const uint8_t *) octx->src[0]->data + i*nb01 + i02*nb02 + i03*nb03; + uint8_t* dst_ptr = (uint8_t *) octx->dst->data + i1*nb1 + i02*nb2 + i03*nb3; hvx_copy_f16_f32_uu(dst_ptr, src0_ptr, ne00); } } } + + qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt); + FARF(HIGH, "set-rows-f16-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt); } int op_set_rows(struct htp_ops_context * octx) { @@ -130,15 +146,15 @@ int op_set_rows(struct htp_ops_context * octx) { const uint32_t n_threads = MIN(nr, octx->n_threads); - if (octx->src0.type != HTP_TYPE_F32) { + if (octx->src[0]->type != HTP_TYPE_F32) { return HTP_STATUS_NO_SUPPORT; } - if (octx->dst.type != HTP_TYPE_F32 && octx->dst.type != HTP_TYPE_F16) { + if (octx->dst->type != HTP_TYPE_F32 && octx->dst->type != HTP_TYPE_F16) { return HTP_STATUS_NO_SUPPORT; } - if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) { + if (octx->src[1]->type != HTP_TYPE_I32 && octx->src[1]->type != HTP_TYPE_I64) { return HTP_STATUS_NO_SUPPORT; } @@ -153,7 +169,7 @@ int op_set_rows(struct htp_ops_context * octx) { srctx.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads; - switch(octx->dst.type) { + switch(octx->dst->type) { case HTP_TYPE_F32: worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_threads); break; diff --git a/ggml/src/ggml-hexagon/htp/softmax-ops.c b/ggml/src/ggml-hexagon/htp/softmax-ops.c index 8dae7f1ed55..d78bcc0eb24 100644 --- a/ggml/src/ggml-hexagon/htp/softmax-ops.c +++ b/ggml/src/ggml-hexagon/htp/softmax-ops.c @@ -15,68 +15,89 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" -#define htp_softmax_preamble3 \ - const uint32_t ne00 = src0->ne[0]; \ - const uint32_t ne01 = src0->ne[1]; \ - const uint32_t ne02 = src0->ne[2]; \ - const uint32_t ne03 = src0->ne[3]; \ - \ - const uint32_t nb00 = src0->nb[0]; \ - const uint32_t nb01 = src0->nb[1]; \ - const uint32_t nb02 = src0->nb[2]; \ - const uint32_t nb03 = src0->nb[3]; \ - \ - const uint32_t ne10 = (src1->ne[0]) ? src1->ne[0] : 1; \ - const uint32_t ne11 = (src1->ne[0]) ? src1->ne[1] : 1; \ - const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1; \ - const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1; \ - \ - const uint32_t nb10 = (src1->ne[0]) ? src1->nb[0] : 1; \ - const uint32_t nb11 = (src1->ne[0]) ? src1->nb[1] : 1; \ - const uint32_t nb12 = (src1->ne[0]) ? src1->nb[2] : 1; \ - const uint32_t nb13 = (src1->ne[0]) ? src1->nb[3] : 1; \ - \ - const uint32_t ne0 = dst->ne[0]; \ - const uint32_t ne1 = dst->ne[1]; \ - const uint32_t ne2 = dst->ne[2]; \ - const uint32_t ne3 = dst->ne[3]; \ - \ - const uint32_t nb0 = dst->nb[0]; \ - const uint32_t nb1 = dst->nb[1]; \ - const uint32_t nb2 = dst->nb[2]; \ +#define htp_softmax_preamble3 \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t ne10 = src1 ? src1->ne[0] : 1; \ + const uint32_t ne11 = src1 ? src1->ne[1] : 1; \ + const uint32_t ne12 = src1 ? src1->ne[2] : 1; \ + const uint32_t ne13 = src1 ? src1->ne[3] : 1; \ + \ + const uint32_t nb10 = src1 ? src1->nb[0] : 1; \ + const uint32_t nb11 = src1 ? src1->nb[1] : 1; \ + const uint32_t nb12 = src1 ? src1->nb[2] : 1; \ + const uint32_t nb13 = src1 ? src1->nb[3] : 1; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; struct htp_softmax_context { + struct htp_ops_context * octx; + bool use_f16; bool use_src1; + uint32_t n_head; uint32_t n_head_log2; - float scale; - float max_bias; - float m0; - float m1; + float scale; + float max_bias; + float m0; + float m1; - uint32_t src0_nrows_per_thread; struct fastdiv_values fastdiv_ne01; struct fastdiv_values fastdiv_ne02; struct fastdiv_values fastdiv_ne12; // For mask broadcasting struct fastdiv_values fastdiv_ne13; // For mask broadcasting - size_t spad_stride; - struct htp_ops_context * octx; + uint32_t src0_nrows_per_thread; }; +static void apply_mask(float * restrict wp0, + const float * restrict mp_f32, + const __fp16 * restrict mp_f16, + uint32_t ne00, + float slope, + bool use_f16) { + if (!mp_f32) { + return; + } + if (use_f16) { + for (uint32_t i = 0; i < ne00; ++i) { + wp0[i] += slope * (float) mp_f16[i]; + } + } else { + for (uint32_t i = 0; i < ne00; ++i) { + wp0[i] += slope * mp_f32[i]; + } + } +} + static void init_softmax_ctx(struct htp_softmax_context * smctx, struct htp_ops_context * octx) { - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; memset(smctx, 0, sizeof(struct htp_softmax_context)); - memcpy(&smctx->scale, (float *) octx->op_params, sizeof(float)); + memcpy(&smctx->scale, (float *) octx->op_params, sizeof(float)); memcpy(&smctx->max_bias, (float *) octx->op_params + 1, sizeof(float)); smctx->n_head = src0->ne[2]; @@ -85,8 +106,8 @@ static void init_softmax_ctx(struct htp_softmax_context * smctx, struct htp_ops_ smctx->m0 = powf(2.0f, -(smctx->max_bias) / smctx->n_head_log2); smctx->m1 = powf(2.0f, -(smctx->max_bias / 2.0f) / smctx->n_head_log2); - smctx->use_src1 = (src1->ne[0] != 0); - smctx->use_f16 = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16); + smctx->use_src1 = (src1 != 0); + smctx->use_f16 = (src1 != 0) && (src1->type == HTP_TYPE_F16); smctx->octx = octx; @@ -97,8 +118,8 @@ static void init_softmax_ctx(struct htp_softmax_context * smctx, struct htp_ops_ if (ne01 > 0) smctx->fastdiv_ne01 = init_fastdiv_values(ne01); if (ne02 > 0) smctx->fastdiv_ne02 = init_fastdiv_values(ne02); - const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1; - const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1; + const uint32_t ne12 = src1 ? src1->ne[2] : 1; + const uint32_t ne13 = src1 ? src1->ne[3] : 1; if (ne12 > 0) smctx->fastdiv_ne12 = init_fastdiv_values(ne12); if (ne13 > 0) smctx->fastdiv_ne13 = init_fastdiv_values(ne13); @@ -139,10 +160,7 @@ static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src, } } -static void hvx_fast_softmax_f32(const uint8_t * restrict src, - uint8_t * restrict dst, - uint8_t * restrict pad, - const int num_elems) { +static void hvx_fast_softmax_f32(const uint8_t * restrict src, uint8_t * restrict dst, uint8_t * restrict pad, const int num_elems) { const HVX_Vector * restrict v_src = (HVX_Vector *) src; HVX_Vector * restrict v_pad = (HVX_Vector *) pad; HVX_Vector * restrict v_dst = (HVX_Vector *) dst; @@ -188,27 +206,20 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src, } } -static float hvx_softmax_f32(const uint8_t * restrict src, - uint8_t * restrict dst, - uint8_t * restrict spad, - const int num_elems, - const float max) { +static float hvx_softmax_f32(const uint8_t * restrict src, uint8_t * restrict dst, uint8_t * restrict spad, const int num_elems, const float max) { hvx_sub_scalar_f32(spad, src, max, num_elems); - hvx_exp_f32(spad, dst, num_elems, false); - - float sum = hvx_reduce_sum_f32(dst, num_elems); - - return sum; + hvx_exp_f32(dst, spad, num_elems, false); + return hvx_reduce_sum_f32(dst, num_elems); } static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) { struct htp_softmax_context * smctx = (struct htp_softmax_context *) data; struct htp_ops_context * octx = smctx->octx; - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; htp_softmax_preamble3; @@ -223,22 +234,26 @@ static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) { return; } - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); + uint64_t qt = HAP_perf_get_qtimer_count(); int is_aligned = 1; int opt_path = 0; + if (!hex_is_aligned((void *) src0->data, VLEN) || !hex_is_aligned((void *) dst->data, VLEN)) { is_aligned = 0; FARF(HIGH, "softmax-f32: unaligned addresses in elementwise op, possibly slower execution\n"); } + + // Only use the fast path when aligned AND row size is multiple of VLEN (128 bytes) + // The fast path (hvx_fast_softmax_f32) doesn't handle tail elements + // The non-opt path uses hvx_softmax_f32 which properly handles all sizes via its helper functions if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { opt_path = 1; } - uint8_t * src0_spad_data = octx->src0_spad.data + (ith * smctx->spad_stride); - uint8_t * src1_spad_data = octx->src1_spad.data + (ith * smctx->spad_stride); - uint8_t * dst_spad_data = octx->dst_spad.data + (ith * smctx->spad_stride); + uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * src1_spad_data = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); + uint8_t * dst_spad_data = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); float * wp0 = (float *) src0_spad_data; float * wp1 = (float *) src1_spad_data; @@ -278,47 +293,29 @@ static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) { // ALiBi if (i2 != prev_i2) { const uint32_t h = i2; // head - - slope = (smctx->max_bias > 0.0f) ? - h < smctx->n_head_log2 ? - powf(smctx->m0, h + 1) : - powf(smctx->m1, 2 * (h - smctx->n_head_log2) + 1) : - 1.0f; + slope = (smctx->max_bias > 0.0f) ? h < smctx->n_head_log2 ? powf(smctx->m0, h + 1) : powf(smctx->m1, 2 * (h - smctx->n_head_log2) + 1) : 1.0f; prev_i2 = i2; } - float * sp = (float *) ((char *) octx->src0.data + i1 * nb01 + i2 * nb02 + i3 * nb03); - float * dp = (float *) ((char *) octx->dst.data + i1 * nb1 + i2 * nb2 + i3 * nb3); + float * sp = (float *) ((char *) src0->data + i1 * nb01 + i2 * nb02 + i3 * nb03); + float * dp = (float *) ((char *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3); // broadcast the mask across rows - __fp16 * mp_f16 = (smctx->use_src1) ? - (__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) : - NULL; - float * mp_f32 = (smctx->use_src1) ? - (float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) : - NULL; + __fp16 * mp_f16 = (smctx->use_src1) ? (__fp16 *) ((char *) src1->data + i11 * nb11 + i12 * nb12 + i13 * nb13) : NULL; + float * mp_f32 = (smctx->use_src1) ? (float *) ((char *) src1->data + i11 * nb11 + i12 * nb12 + i13 * nb13) : NULL; if ((1 == opt_path) && (mp_f32) && !(smctx->use_f16)) { - hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, smctx->scale, - (const uint8_t *) mp_f32, slope); - } else { + hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, smctx->scale, (const uint8_t *) mp_f32, slope); + hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00); + } else if (1 == opt_path) { hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, smctx->scale); - if (mp_f32) { - if (smctx->use_f16) { - for (int i = 0; i < ne00; ++i) { - wp0[i] += slope * (float) mp_f16[i]; - } - } else { - for (int i = 0; i < ne00; ++i) { - wp0[i] += slope * mp_f32[i]; - } - } - } - } - - if (1 == opt_path) { + apply_mask(wp0, mp_f32, mp_f16, ne00, slope, smctx->use_f16); hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00); } else { + // Non-optimized path: uses HVX helper functions that properly handle all tensor sizes + // including non-multiples of 32 (the HVX vector lane count for f32) + hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, smctx->scale); + apply_mask(wp0, mp_f32, mp_f16, ne00, slope, smctx->use_f16); float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00); float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max); sum = sum > 0.0 ? (1.0 / sum) : 1; @@ -326,54 +323,47 @@ static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) { } } - t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "softmax-f32 %d/%d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, - smctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, - ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); + qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt); + FARF(HIGH, "softmax-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u : opt %u f16 %u usec %u\n", ith, nth, + ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, + ne0, ne1, ne2, ne3, opt_path, smctx->use_f16, (unsigned) qt); } static int execute_op_softmax_f32(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; struct htp_softmax_context smctx; const char * op_type = "softmax-f32"; - switch (octx->op) { - case HTP_OP_SOFTMAX: - init_softmax_ctx(&smctx, octx); - break; - - default: - FARF(ERROR, "Unsupported Op %u\n", octx->op); - return HTP_STATUS_NO_SUPPORT; - } + init_softmax_ctx(&smctx, octx); const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); + smctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; + const size_t src0_row_size = src0->nb[1]; const size_t src1_row_size = src0_row_size; const size_t dst_row_size = dst->nb[1]; // VTCM scratchpads for all tensors - // N rows per thread, padded to HVX vector size - octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads; - octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads; - octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads; + // 4 rows per thread, padded to HVX vector size + octx->src0_spad.size_per_thread = hex_round_up(4 * src0_row_size, 128); + octx->src1_spad.size_per_thread = hex_round_up(4 * src1_row_size, 128); + octx->dst_spad.size_per_thread = hex_round_up(4 * dst_row_size, 128); - // Use stride for calculating offset - smctx.spad_stride = hex_round_up(src0_row_size, 128); + octx->src0_spad.size = octx->src0_spad.size_per_thread * n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * n_threads; size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; - if (src1->ne[0]) { - FARF(HIGH, - "%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", + if (src1) { + FARF(HIGH, "%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); @@ -385,19 +375,17 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) { // Make sure the reserved vtcm size is sufficient if (octx->ctx->vtcm_size < spad_size) { - FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, - spad_size); + FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, spad_size); return HTP_STATUS_VTCM_TOO_SMALL; } - octx->src0_spad.data = octx->ctx->vtcm_base; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL; + octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->dst_spad.src = NULL; - if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - smctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; - worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_threads); - } + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) return err; + + worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_threads); return err; } @@ -405,7 +393,7 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) { int op_softmax(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; - switch (octx->src0.type) { + switch (octx->src[0]->type) { case HTP_TYPE_F32: err = execute_op_softmax_f32(octx); break; diff --git a/ggml/src/ggml-hexagon/htp/solve-tri-ops.c b/ggml/src/ggml-hexagon/htp/solve-tri-ops.c new file mode 100644 index 00000000000..ae8e1a50495 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/solve-tri-ops.c @@ -0,0 +1,267 @@ +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" +#include "hvx-types.h" +#include "hvx-utils.h" + +struct htp_solve_tri_context { + struct htp_ops_context * octx; + uint32_t jobs_per_thread; + uint32_t total_jobs; + uint32_t k_chunks; + uint32_t col_block; +}; + +static inline void solve_tri_row_scalar(const float * A_row, + const float * B_row, + float * X, + uint32_t row, + uint32_t k, + uint32_t col0, + uint32_t coln, + float inv_diag) { + for (uint32_t col = col0; col < col0 + coln; ++col) { + float sum = 0.0f; + for (uint32_t t = 0; t < row; ++t) { + sum += A_row[t] * X[t * k + col]; + } + X[row * k + col] = (B_row[col] - sum) * inv_diag; + } +} + +static inline HVX_Vector hvx_load_partial_f32(const float * src, uint32_t n) { + HVX_Vector v = *((const HVX_UVector *) src); + HVX_VectorPred mask = Q6_Q_vsetq2_R(n * sizeof(float)); + return Q6_V_vmux_QVV(mask, v, Q6_V_vzero()); +} + +static inline void solve_tri_row_hvx(const float * A_row, + const float * B_row, + float * X, + uint32_t row, + uint32_t k, + uint32_t col0, + uint32_t coln, + float inv_diag) { + const bool full = (coln == VLEN_FP32); + + HVX_Vector sum_v = Q6_V_vzero(); + for (uint32_t t = 0; t < row; ++t) { + const float a = A_row[t]; + const float * x_row_col = X + t * k + col0; + + HVX_Vector x_v = full ? *((const HVX_UVector *) x_row_col) : hvx_load_partial_f32(x_row_col, coln); + HVX_Vector a_v = hvx_vec_splat_f32(a); + sum_v = hvx_vec_add_f32_f32(sum_v, hvx_vec_mul_f32_f32(x_v, a_v)); + } + + const float * b_row_col = B_row + col0; + float * x_out_col = X + row * k + col0; + + HVX_Vector b_v = full ? *((const HVX_UVector *) b_row_col) : hvx_load_partial_f32(b_row_col, coln); + HVX_Vector inv_diag_v = hvx_vec_splat_f32(inv_diag); + + HVX_Vector out_v = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(b_v, sum_v), inv_diag_v); + hvx_vec_store_u((void *) x_out_col, coln * sizeof(float), out_v); +} + +// Batch-level thread: each job is one full batch. +static void solve_tri_batch_thread_f32(unsigned int nth, unsigned int ith, void * data) { + struct htp_solve_tri_context * sctx = (struct htp_solve_tri_context *) data; + struct htp_ops_context * octx = sctx->octx; + + const struct htp_tensor * src0 = octx->src[0]; // A + const struct htp_tensor * src1 = octx->src[1]; // B + const struct htp_tensor * dst = octx->dst; // X + + const uint32_t n = src0->ne[0]; + const uint32_t k = src1->ne[0]; + + const uint32_t ne02 = src0->ne[2]; + + const uint32_t col_block = VLEN_FP32; + const uint32_t k_full = (k / col_block) * col_block; + + const uint32_t start_batch = sctx->jobs_per_thread * ith; + const uint32_t end_batch = MIN(start_batch + sctx->jobs_per_thread, sctx->total_jobs); + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + for (uint32_t batch = start_batch; batch < end_batch; ++batch) { + const uint32_t i03 = batch / ne02; + const uint32_t i02 = batch - i03 * ne02; + + const float * A_batch = + (const float *) ((const uint8_t *) (uintptr_t) src0->data + i02 * src0->nb[2] + i03 * src0->nb[3]); + const float * B_batch = + (const float *) ((const uint8_t *) (uintptr_t) src1->data + i02 * src1->nb[2] + i03 * src1->nb[3]); + float * X_batch = (float *) ((uint8_t *) (uintptr_t) dst->data + i02 * dst->nb[2] + i03 * dst->nb[3]); + + for (uint32_t row = 0; row < n; ++row) { + const float diag = A_batch[row * n + row]; + const float inv_diag = 1.0f / diag; + const float * A_row = A_batch + row * n; + const float * B_row = B_batch + row * k; + + uint32_t col0 = 0; + for (; col0 < k_full; col0 += col_block) { + solve_tri_row_hvx(A_row, B_row, X_batch, row, k, col0, col_block, inv_diag); + } + + if (col0 < k) { + const uint32_t coln = k - col0; + if (coln >= 8) { + solve_tri_row_hvx(A_row, B_row, X_batch, row, k, col0, coln, inv_diag); + } else { + solve_tri_row_scalar(A_row, B_row, X_batch, row, k, col0, coln, inv_diag); + } + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "solve-tri-batch %d/%d: A=(%ux%u) B=(%ux%u) batch %u:%u usec %u\n", + ith, nth, n, n, k, n, start_batch, end_batch, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// Chunk-level thread: each job is one (batch, col_chunk) pair. +static void solve_tri_chunk_thread_f32(unsigned int nth, unsigned int ith, void * data) { + struct htp_solve_tri_context * sctx = (struct htp_solve_tri_context *) data; + struct htp_ops_context * octx = sctx->octx; + + const struct htp_tensor * src0 = octx->src[0]; // A + const struct htp_tensor * src1 = octx->src[1]; // B + const struct htp_tensor * dst = octx->dst; // X + + const uint32_t n = src0->ne[0]; + const uint32_t k = src1->ne[0]; + + const uint32_t ne02 = src0->ne[2]; + + const uint32_t start_job = sctx->jobs_per_thread * ith; + const uint32_t end_job = MIN(start_job + sctx->jobs_per_thread, sctx->total_jobs); + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + for (uint32_t job = start_job; job < end_job; ++job) { + const uint32_t batch = job / sctx->k_chunks; + const uint32_t chunk = job - batch * sctx->k_chunks; + + const uint32_t i03 = batch / ne02; + const uint32_t i02 = batch - i03 * ne02; + + const uint32_t col0 = chunk * sctx->col_block; + const uint32_t coln = MIN(sctx->col_block, k - col0); + + const float * A_batch = + (const float *) ((const uint8_t *) (uintptr_t) src0->data + i02 * src0->nb[2] + i03 * src0->nb[3]); + const float * B_batch = + (const float *) ((const uint8_t *) (uintptr_t) src1->data + i02 * src1->nb[2] + i03 * src1->nb[3]); + float * X_batch = (float *) ((uint8_t *) (uintptr_t) dst->data + i02 * dst->nb[2] + i03 * dst->nb[3]); + + const bool use_hvx = (coln >= 8); + + for (uint32_t row = 0; row < n; ++row) { + const float diag = A_batch[row * n + row]; + const float inv_diag = 1.0f / diag; + + const float * A_row = A_batch + row * n; + const float * B_row = B_batch + row * k; + + if (use_hvx) { + solve_tri_row_hvx(A_row, B_row, X_batch, row, k, col0, coln, inv_diag); + } else { + solve_tri_row_scalar(A_row, B_row, X_batch, row, k, col0, coln, inv_diag); + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "solve-tri-chunk %d/%d: A=(%ux%u) B=(%ux%u) job %u:%u usec %u\n", + ith, nth, n, n, k, n, start_job, end_job, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_solve_tri(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; // A + const struct htp_tensor * src1 = octx->src[1]; // B + const struct htp_tensor * dst = octx->dst; // X + + if (src0->type != HTP_TYPE_F32 || src1->type != HTP_TYPE_F32 || dst->type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + // left=true, lower=true, uni=false only + if (src0->ne[0] != src0->ne[1]) { + return HTP_STATUS_INVAL_PARAMS; + } + if (src0->ne[1] != src1->ne[1]) { + return HTP_STATUS_INVAL_PARAMS; + } + if (src0->ne[2] != src1->ne[2] || src0->ne[3] != src1->ne[3]) { + return HTP_STATUS_INVAL_PARAMS; + } + if (dst->ne[0] != src1->ne[0] || dst->ne[1] != src1->ne[1] || dst->ne[2] != src1->ne[2] || + dst->ne[3] != src1->ne[3]) { + return HTP_STATUS_INVAL_PARAMS; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const uint32_t k = src1->ne[0]; + + const uint32_t col_block = VLEN_FP32; + const uint32_t k_chunks = (k + col_block - 1) / col_block; + const uint32_t total_batches = src0->ne[2] * src0->ne[3]; + const bool batched = total_batches >= (uint32_t) octx->n_threads; + + FARF(HIGH, "solve-tri: (%ux%ux%ux%u) x (%ux%ux%ux%u) -> (%ux%ux%ux%u) : batched %d\n", + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], batched); + + if (batched) { + // Batch-level parallelism + const uint32_t n_threads = MIN((uint32_t) octx->n_threads, total_batches); + + struct htp_solve_tri_context sctx = { + .octx = octx, + .jobs_per_thread = (total_batches + n_threads - 1) / n_threads, + .total_jobs = total_batches, + .k_chunks = k_chunks, + .col_block = col_block, + }; + + worker_pool_run_func(octx->ctx->worker_pool, solve_tri_batch_thread_f32, &sctx, n_threads); + } else { + // Chunk-level parallelism + const uint32_t total_jobs = total_batches * k_chunks; + const uint32_t n_threads = MIN((uint32_t) octx->n_threads, MAX(total_jobs, 1)); + + struct htp_solve_tri_context sctx = { + .octx = octx, + .jobs_per_thread = (total_jobs + n_threads - 1) / n_threads, + .total_jobs = total_jobs, + .k_chunks = k_chunks, + .col_block = col_block, + }; + + worker_pool_run_func(octx->ctx->worker_pool, solve_tri_chunk_thread_f32, &sctx, n_threads); + } + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/ssm-conv.c b/ggml/src/ggml-hexagon/htp/ssm-conv.c index b3c1ef9572e..a28fd03e978 100644 --- a/ggml/src/ggml-hexagon/htp/ssm-conv.c +++ b/ggml/src/ggml-hexagon/htp/ssm-conv.c @@ -16,14 +16,14 @@ #include "ggml-common.h" #include "htp-ctx.h" #include "hex-dma.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" #include "hvx-utils.h" -#define htp_ssm_conv_tensors_preamble \ - struct htp_tensor * restrict src0 = &octx->src0; \ - struct htp_tensor * restrict src1 = &octx->src1; \ - struct htp_tensor * restrict dst = &octx->dst; \ +#define htp_ssm_conv_tensors_preamble \ + const struct htp_tensor * restrict src0 = octx->src[0]; \ + const struct htp_tensor * restrict src1 = octx->src[1]; \ + const struct htp_tensor * restrict dst = octx->dst; \ struct htp_spad * restrict src0_spad = &octx->src0_spad; \ struct htp_spad * restrict src1_spad = &octx->src1_spad; \ struct htp_spad * restrict dst_spad = &octx->dst_spad; \ @@ -151,7 +151,7 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void const int dr = scctx->nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = MIN(ir0 + dr, d_inner); - const int ir = ir1 - ir0; + const uint32_t ir = ir1 - ir0; if (ir0 >= ir1) { return; // No work for this thread @@ -205,10 +205,10 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void HVX_Vector acc_vec = Q6_V_vsplat_R(0); for (uint32_t i0 = 0; i0 < d_conv; ++i0) { - Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])), - src0_gather_len, (*(const HVX_Vector *) src0_offsets)); - Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)), - src1_gather_len, (*(const HVX_Vector *) src1_offsets)); + uint32_t src0_base = (uint32_t) spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0]); + uint32_t src1_base = (uint32_t) spad_src1 + (i0 + i1 * nc) * sizeof(float); + Q6_vgather_ARMVw(src0_vec, src0_base, src0_gather_len, (*(const HVX_Vector *) src0_offsets)); + Q6_vgather_ARMVw(src1_vec, src1_base, src1_gather_len, (*(const HVX_Vector *) src1_offsets)); HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec); acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod); @@ -222,10 +222,10 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void HVX_Vector acc_vec = Q6_V_vsplat_R(0); for (uint32_t i0 = 0; i0 < d_conv; ++i0) { - Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])), - src0_gather_len, (*(const HVX_Vector *) src0_offsets)); - Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)), - src1_gather_len, (*(const HVX_Vector *) src1_offsets)); + uint32_t src0_base = (uint32_t) spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0]); + uint32_t src1_base = (uint32_t) spad_src1 + (i0 + i1 * nc) * sizeof(float); + Q6_vgather_ARMVw(src0_vec, src0_base, src0_gather_len, (*(const HVX_Vector *) src0_offsets)); + Q6_vgather_ARMVw(src1_vec, src1_base, src1_gather_len, (*(const HVX_Vector *) src1_offsets)); HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec); acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod); @@ -289,9 +289,9 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) { // Compute gather scratchpad size for src0 and src1 const size_t gather_spad_size = n_threads * VLEN * 2; - octx->src0_spad.data = octx->ctx->vtcm_base + gather_spad_size; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + octx->src0_spad.data = octx->ctx->vtcm_base + gather_spad_size; octx->src0_spad.src = NULL; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL; + octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->dst_spad.src = NULL; FARF(HIGH, "ssm_conv-f32: gather-spad:%zu spad-per-thread:(%u:%u:%u) spad-sizes:(%u:%u:%u) spad-data:(%p:%p:%p)\n", gather_spad_size, octx->src0_spad.size_per_thread, octx->src1_spad.size_per_thread, @@ -323,8 +323,9 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) { } int op_ssm_conv(struct htp_ops_context * octx) { - int err = HTP_STATUS_OK; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * dst = octx->dst; + + int err = HTP_STATUS_OK; switch (dst->type) { case HTP_TYPE_F32: diff --git a/ggml/src/ggml-hexagon/htp/sum-rows-ops.c b/ggml/src/ggml-hexagon/htp/sum-rows-ops.c index 352650b689b..874c41ab2ac 100644 --- a/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/sum-rows-ops.c @@ -14,13 +14,13 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" -#define sum_rows_preamble \ - struct htp_tensor *src0 = &octx->src0;\ - struct htp_tensor *dst = &octx->dst; \ - \ +#define sum_rows_preamble \ + const struct htp_tensor *src0 = octx->src[0]; \ + const struct htp_tensor *dst = octx->dst; \ + \ const uint32_t ne00 = src0->ne[0]; \ const uint32_t ne01 = src0->ne[1]; \ const uint32_t ne02 = src0->ne[2]; \ @@ -94,7 +94,7 @@ static void sum_rows_thread_f32(unsigned int nth, unsigned int ith, void *data) int op_sum_rows(struct htp_ops_context * octx) { sum_rows_preamble; - if (octx->src0.type != HTP_TYPE_F32) { + if (octx->src[0]->type != HTP_TYPE_F32) { return HTP_STATUS_NO_SUPPORT; } diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 5bbd5040d3d..819cdc49bd9 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -9,12 +9,14 @@ #include #include "hex-dma.h" +#include "hvx-exp.h" +#include "hvx-sigmoid.h" #include "hvx-utils.h" #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" struct htp_unary_context { @@ -24,8 +26,8 @@ struct htp_unary_context { const uint8_t * data_src0; uint8_t * data_dst; - size_t src0_row_size; - size_t dst_row_size; + size_t src0_data_row_size; // actual data bytes per row + size_t dst_data_row_size; // actual data bytes per row size_t src0_row_size_aligned; size_t dst_row_size_aligned; @@ -39,6 +41,40 @@ struct htp_unary_context { uint32_t nc; }; +// Convert flat row index to DDR byte offset using the tensor's actual strides. +// ir = i1 + ne1*(i2 + ne2*i3) => offset = i1*nb1 + i2*nb2 + i3*nb3 +static inline size_t unary_row_offset(uint32_t ir, + uint32_t ne1, uint32_t ne2, + size_t nb1, size_t nb2, size_t nb3) { + const uint32_t i1 = ir % ne1; + const uint32_t i2 = (ir / ne1) % ne2; + const uint32_t i3 = ir / (ne1 * ne2); + return i1 * nb1 + i2 * nb2 + i3 * nb3; +} +// Safe DMA block size from row `ir`: clamp to the tighter dim-1 slice +// boundary of src and dst so the nb1 stride stays valid for all rows. +static inline uint32_t unary_block_size(uint32_t ir, + uint32_t end_row, + uint32_t block, + bool src_contig, + bool dst_contig, + uint32_t src_ne1, + uint32_t dst_ne1) { + uint32_t limit = MIN(block, end_row - ir); + + if (!src_contig) { + const uint32_t src_slice_end = (ir / src_ne1 + 1) * src_ne1; + limit = MIN(limit, src_slice_end - ir); + } + + if (!dst_contig) { + const uint32_t dst_slice_end = (ir / dst_ne1 + 1) * dst_ne1; + limit = MIN(limit, dst_slice_end - ir); + } + + return limit; +} + #define htp_unary_preamble \ const uint32_t ne00 = src->ne[0]; \ const uint32_t ne01 = src->ne[1]; \ @@ -65,34 +101,61 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src, uint8_t * restrict pad, const int num_elems, float epsilon) { + (void)pad; + const HVX_Vector * restrict v_src = (HVX_Vector *) src; HVX_Vector * restrict v_dst = (HVX_Vector *) dst; - HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000); + const int nvec = num_elems / VLEN_FP32; // number of full vectors + const int nloe = num_elems % VLEN_FP32; // leftover elements + + // Compute sum of squares for full vectors + HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000); HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon); - int step_of_1 = num_elems >> 5; #pragma unroll(4) - for (int i = 0; i < step_of_1; i++) { + for (int i = 0; i < nvec; i++) { HVX_Vector v1 = v_src[i]; HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); - sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); + sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); } - sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); // replicated over all lanes + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); + } + + // Reduce HVX sum + sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems); HVX_Vector denom_v = hvx_vec_inverse_f32(t_v); HVX_Vector mean_v = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v); HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v); + // Scale full vectors HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v)); #pragma unroll(4) - for (int i = 0; i < step_of_1; i++) { + for (int i = 0; i < nvec; i++) { HVX_Vector v1 = v_src[i]; HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v); - v_dst[i] = Q6_Vsf_equals_Vqf32(v2); + v_dst[i] = Q6_Vsf_equals_Vqf32(v2); + } + + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v); + HVX_Vector result = Q6_Vsf_equals_Vqf32(v2); + + // Store with masking to avoid overwriting memory beyond the tensor + hvx_vec_store_a(&v_dst[nvec], nloe * 4, result); } } @@ -166,11 +229,80 @@ static void sqrt_f32(const float * restrict src, } } +static void neg_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_scale_f32_aa(dst_local, src_local, row_elems, -1.0f); + } +} + +static void exp_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_exp_f32(dst_local, src_local, row_elems, false); + } +} + +static void sigmoid_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_sigmoid_f32_aa(dst_local, src_local, row_elems); + } +} + +static void softplus_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + // softplus(x) = log(1 + exp(x)) + // Match CPU reference: ggml_compute_softplus_f32() in ggml-impl.h + for (uint32_t ir = 0; ir < num_rows; ir++) { + const float * restrict src_f = (const float *)((const uint8_t *)src + (ir * row_size)); + float * restrict dst_f = (float *)((uint8_t *)dst + (ir * row_size)); + + for (uint32_t i = 0; i < row_elems; i++) { + float x = src_f[i]; + // For x > 20: softplus(x) ≈ x (avoids exp overflow) + dst_f[i] = (x > 20.0f) ? x : logf(1.0f + expf(x)); + } + } +} + static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { const struct htp_unary_context * uctx = (const struct htp_unary_context *) data; struct htp_ops_context * octx = uctx->octx; - const struct htp_tensor * src = &octx->src0; - const struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src = octx->src[0]; + const struct htp_tensor * dst = octx->dst; htp_unary_preamble; @@ -178,8 +310,8 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * int32_t * op_params = octx->op_params; uint32_t src0_nrows_per_thread = uctx->src0_nrows_per_thread; - const size_t src0_row_size = uctx->src0_row_size; - const size_t dst_row_size = uctx->dst_row_size; + const size_t src0_data_row_size = uctx->src0_data_row_size; + const size_t dst_data_row_size = uctx->dst_data_row_size; const size_t src0_row_size_aligned = uctx->src0_row_size_aligned; const size_t dst_row_size_aligned = uctx->dst_row_size_aligned; @@ -205,7 +337,16 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * size_t src0_spad_half_size = uctx->src0_spad_half_size; size_t dst_spad_half_size = uctx->dst_spad_half_size; - const int BLOCK = uctx->block; + // Non-contiguous tensors have gaps at dim-2/3 boundaries that a single-stride + // 2D DMA descriptor cannot span. Clamp BLOCK to ne1 (one dim-1 slice) so every + // transfer stays within a nb1-uniform region. Skipped for contiguous tensors. + const bool src0_contig = (nb02 == (size_t)ne01 * nb01) && + (nb03 == (size_t)ne02 * nb02); + const bool dst_contig = (nb2 == (size_t)ne1 * nb1) && + (nb3 == (size_t)ne2 * nb2); + const uint32_t src0_max_block = src0_contig ? uctx->block : MIN((uint32_t)uctx->block, ne01); + const uint32_t dst_max_block = dst_contig ? uctx->block : MIN((uint32_t)uctx->block, ne1); + const uint32_t BLOCK = MIN(src0_max_block, dst_max_block); if (BLOCK == 0) { FARF(ERROR, "unary-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", octx->src0_spad.size_per_thread, src0_row_size_aligned); @@ -214,21 +355,23 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * dma_queue * dma_queue = octx->ctx->dma[ith]; - for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { - const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; spad_idx++) { + const uint32_t block_size = unary_block_size(ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1); // Dummy DMA transation for sequencing (interleaving dst,src,dst,...) - dma_queue_push_vtcm_to_ddr(dma_queue, + dma_queue_push(dma_queue, dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)), - dst_row_size, dst_row_size_aligned, 0); + nb1, dst_row_size_aligned, dst_data_row_size, 0); - dma_queue_push_ddr_to_vtcm(dma_queue, - dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + (ir * src0_row_size)), - src0_row_size_aligned, src0_row_size, block_size); + const size_t src0_off = unary_row_offset(ir, ne01, ne02, nb01, nb02, nb03); + dma_queue_push(dma_queue, + dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + src0_off), + src0_row_size_aligned, nb01, src0_data_row_size, block_size); + ir += block_size; } - for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) { - const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + for (uint32_t ir = src0_start_row; ir < src0_end_row; ) { + const uint32_t block_size = unary_block_size(ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1); float * dst_spad = (float *) dma_queue_pop(dma_queue).src; float * src0_spad = (float *) dma_queue_pop(dma_queue).dst; @@ -247,22 +390,41 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * case HTP_OP_SQRT: sqrt_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); break; + case HTP_OP_UNARY_NEG: + neg_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_UNARY_EXP: + exp_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_UNARY_SIGMOID: + sigmoid_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_UNARY_SOFTPLUS: + softplus_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; default: break; } - dma_queue_push_vtcm_to_ddr(dma_queue, - dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), - dst_row_size, dst_row_size_aligned, block_size); + const size_t dst_off = unary_row_offset(ir, ne1, ne2, nb1, nb2, nb3); + dma_queue_push(dma_queue, + dma_make_ptr(data_dst + dst_off, dst_spad), + nb1, dst_row_size_aligned, dst_data_row_size, block_size); // prefetch N+2 loop iteration if any - const uint32_t pref_block = (ir + BLOCK * 2); - if (pref_block < src0_end_row) { - const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block); - dma_queue_push_ddr_to_vtcm(dma_queue, - dma_make_ptr(src0_spad, data_src + (pref_block * src0_row_size)), - src0_row_size_aligned, src0_row_size, pref_block_size); + const uint32_t next_ir = ir + block_size; + if (next_ir < src0_end_row) { + const uint32_t next_block_size = unary_block_size(next_ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1); + const uint32_t pref_ir = next_ir + next_block_size; + if (pref_ir < src0_end_row) { + const uint32_t pref_block_size = unary_block_size(pref_ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1); + const size_t src0_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb01, nb02, nb03); + dma_queue_push(dma_queue, + dma_make_ptr(src0_spad, data_src + src0_pref_off), + src0_row_size_aligned, nb01, src0_data_row_size, pref_block_size); + } } + ir += block_size; } dma_queue_flush(dma_queue); @@ -277,8 +439,8 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * static int execute_op_unary_f32(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; - const struct htp_tensor * src0 = &octx->src0; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * dst = octx->dst; const char * op_type = NULL; @@ -295,6 +457,18 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { case HTP_OP_SQRT: op_type = "sqrt-f32"; break; + case HTP_OP_UNARY_NEG: + op_type = "neg-f32"; + break; + case HTP_OP_UNARY_EXP: + op_type = "exp-f32"; + break; + case HTP_OP_UNARY_SIGMOID: + op_type = "sigmoid-f32"; + break; + case HTP_OP_UNARY_SOFTPLUS: + op_type = "softplus-f32"; + break; default: FARF(ERROR, "Unsupported unary Op %u\n", octx->op); @@ -304,11 +478,11 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); - const size_t src0_row_size = src0->nb[1]; - const size_t dst_row_size = dst->nb[1]; + const size_t src0_data_row_size = src0->ne[0] * sizeof(float); + const size_t dst_data_row_size = dst->ne[0] * sizeof(float); - const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); - const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + const size_t src0_row_size_aligned = hex_round_up(src0_data_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_data_row_size, VLEN); // VTCM scratchpads for all tensors // N rows per thread, padded to HVX vector size @@ -346,8 +520,8 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { .data_src0 = (const uint8_t *)src0->data, .data_dst = (uint8_t *)dst->data, - .src0_row_size = src0_row_size, - .dst_row_size = dst_row_size, + .src0_data_row_size = src0_data_row_size, + .dst_data_row_size = dst_data_row_size, .src0_row_size_aligned = src0_row_size_aligned, .dst_row_size_aligned = dst_row_size_aligned, @@ -368,7 +542,7 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { int op_unary(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; - switch (octx->src0.type) { + switch (octx->src[0]->type) { case HTP_TYPE_F32: err = execute_op_unary_f32(octx); break; diff --git a/ggml/src/ggml-hexagon/htp/vtcm-utils.h b/ggml/src/ggml-hexagon/htp/vtcm-utils.h new file mode 100644 index 00000000000..b129fb74e31 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/vtcm-utils.h @@ -0,0 +1,16 @@ +#ifndef VTCM_UTILS_H +#define VTCM_UTILS_H + +#include "hex-utils.h" + +#include +#include +#include + +static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) { + uint8_t *p = *vtcm_ptr; + *vtcm_ptr += size; + return p; +} + +#endif // VTCM_UTILS_H diff --git a/ggml/src/ggml-hexagon/libggml-htp.inf b/ggml/src/ggml-hexagon/libggml-htp.inf index 656d2d9ab26..39cefcdda38 100644 --- a/ggml/src/ggml-hexagon/libggml-htp.inf +++ b/ggml/src/ggml-hexagon/libggml-htp.inf @@ -8,7 +8,7 @@ CatalogFile = libggml-htp.cat PnpLockDown = 1 [DestinationDirs] -Drivers_Dir = 6 +Drivers_Dir = 13 [SourceDisksNames] 1 = %DiskId% @@ -18,6 +18,7 @@ libggml-htp-v68.so = 1 libggml-htp-v69.so = 1 libggml-htp-v73.so = 1 libggml-htp-v75.so = 1 +libggml-htp-v79.so = 1 libggml-htp-v81.so = 1 [ControlFlags] @@ -31,6 +32,7 @@ libggml-htp-v68.so,,,0x10 ;COPYFLG_NO_OVERWRITE libggml-htp-v69.so,,,0x10 ;COPYFLG_NO_OVERWRITE libggml-htp-v73.so,,,0x10 ;COPYFLG_NO_OVERWRITE libggml-htp-v75.so,,,0x10 ;COPYFLG_NO_OVERWRITE +libggml-htp-v79.so,,,0x10 ;COPYFLG_NO_OVERWRITE libggml-htp-v81.so,,,0x10 ;COPYFLG_NO_OVERWRITE [Strings] diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index b44ed0f7215..a7d4e0ea2b5 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -47,15 +47,16 @@ find_package(hip REQUIRED) find_package(hipblas REQUIRED) find_package(rocblas REQUIRED) +if (GGML_HIP_RCCL) + find_package(rccl REQUIRED) +endif() + if (${hip_VERSION} VERSION_LESS 6.1) message(FATAL_ERROR "At least ROCM/HIP V6.1 is required") endif() message(STATUS "HIP and hipBLAS found") -# Workaround old compilers -set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} --gpu-max-threads-per-block=1024") - file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh") list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h") @@ -74,12 +75,11 @@ if (GGML_CUDA_FA_ALL_QUANTS) list(APPEND GGML_SOURCES_ROCM ${SRCS}) add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) else() - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu") - list(APPEND GGML_SOURCES_ROCM ${SRCS}) - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu") - list(APPEND GGML_SOURCES_ROCM ${SRCS}) - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu") - list(APPEND GGML_SOURCES_ROCM ${SRCS}) + list(APPEND GGML_SOURCES_ROCM + ../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu + ../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu) endif() ggml_add_backend_library(ggml-hip @@ -122,6 +122,10 @@ if (NOT GGML_HIP_MMQ_MFMA) add_compile_definitions(GGML_HIP_NO_MMQ_MFMA) endif() +if (GGML_HIP_RCCL) + add_compile_definitions(GGML_USE_NCCL) # RCCL has the same interface as NCCL. +endif() + if (GGML_HIP_EXPORT_METRICS) set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Rpass-analysis=kernel-resource-usage --save-temps") endif() @@ -132,6 +136,11 @@ endif() if (CXX_IS_HIPCC) set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX) + if (WIN32 AND CMAKE_BUILD_TYPE STREQUAL "Debug") + # CMake on Windows doesn't support the HIP language yet. + # Therefore we workaround debug build's failure on HIP backend this way. + set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES COMPILE_FLAGS "-O2 -g") + endif() target_link_libraries(ggml-hip PRIVATE hip::device) else() set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE HIP) @@ -141,4 +150,8 @@ if (GGML_STATIC) message(FATAL_ERROR "Static linking not supported for HIP/ROCm") endif() +if (GGML_HIP_RCCL) + target_link_libraries(ggml-hip PRIVATE ggml-base roc::rccl) +endif() + target_link_libraries(ggml-hip PRIVATE ggml-base hip::host roc::rocblas roc::hipblas) diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 92568655956..62b76abbcec 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -30,6 +30,8 @@ extern "C" { void ggml_print_backtrace(void); +uint64_t ggml_graph_next_uid(void); + #ifndef MIN # define MIN(a, b) ((a) < (b) ? (a) : (b)) #endif @@ -338,6 +340,10 @@ struct ggml_cgraph { struct ggml_hash_set visited_hash_set; enum ggml_cgraph_eval_order order; + + // an optional identifier that can be utilized to recognize same graphs if two non-zero values match + // a value of 0 means it is not set and should be ignored + uint64_t uid; }; // returns a slice of cgraph with nodes [i0, i1) @@ -773,6 +779,5 @@ inline bool ggml_check_edges(const struct ggml_cgraph * cgraph, // expose GGUF internals for test code GGML_API size_t gguf_type_size(enum gguf_type type); -GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params); GGML_API void gguf_write_to_buf(const struct gguf_context * ctx, std::vector & buf, bool only_meta); #endif // __cplusplus diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 72ad876d5e4..d211bf79f14 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -246,6 +246,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal case GGML_UNARY_OP_EXP: op_num = OP_UNARY_NUM_EXP; break; case GGML_UNARY_OP_SOFTPLUS: op_num = OP_UNARY_NUM_SOFTPLUS; break; case GGML_UNARY_OP_EXPM1: op_num = OP_UNARY_NUM_EXPM1; break; + case GGML_UNARY_OP_FLOOR: op_num = OP_UNARY_NUM_FLOOR; break; + case GGML_UNARY_OP_CEIL: op_num = OP_UNARY_NUM_CEIL; break; + case GGML_UNARY_OP_ROUND: op_num = OP_UNARY_NUM_ROUND; break; + case GGML_UNARY_OP_TRUNC: op_num = OP_UNARY_NUM_TRUNC; break; + case GGML_UNARY_OP_XIELU: op_num = OP_UNARY_NUM_XIELU; break; default: GGML_ABORT("fatal error"); } break; default: GGML_ABORT("fatal error"); @@ -672,7 +677,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta const ggml_type tsrc1 = op->src[1]->type; const bool bc_inp = op->src[0]->ne[0] % 32 != 0; - const bool bc_out = op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0; + + constexpr int NRA = SZ_SIMDGROUP * N_MM_BLOCK_Y * N_MM_SIMD_GROUP_Y; + constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X; + + const bool has_tensor = ggml_metal_device_get_props(ggml_metal_library_get_device(lib))->has_tensor; + + const bool bc_out = has_tensor + ? (op->ne[0] % NRA != 0 || op->ne[1] % NRB != 0) + : (op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0); snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1)); snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out); @@ -689,8 +702,20 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta ggml_metal_cv_free(cv); } - // when the output size is not multiple of 64x32, we need extra smem to prevent out-of-bounds writes - res.smem = bc_out ? 8192 : 4096 + 2048; + if (has_tensor) { + res.nr0 = NRA; + res.nr1 = NRB; + + const size_t smem_a = NRA * N_MM_NK_TOTAL * sizeof(ggml_fp16_t); + res.smem = smem_a; + } else { + res.nr0 = 64; + res.nr1 = 32; + + res.smem = bc_out ? 8192 : (4096 + 2048); + } + + res.nsg = N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y; return res; } @@ -732,6 +757,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta suffix = ne00 % 4 == 0 ? "_4" : ""; } } break; + case GGML_TYPE_Q1_0: + { + nsg = N_SG_Q1_0; + nr0 = N_R0_Q1_0; + } break; case GGML_TYPE_Q4_0: { nsg = N_SG_Q4_0; @@ -944,6 +974,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m smem = 32*sizeof(float)*nr0; suffix = ne00 % 4 == 0 ? "_4" : ""; } break; + case GGML_TYPE_Q1_0: + { + nsg = N_SG_Q1_0; + nr0 = N_R0_Q1_0; + } break; case GGML_TYPE_Q4_0: { nsg = N_SG_Q4_0; @@ -1748,6 +1783,28 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d(ggml_met return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_3d(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_CONV_3D); + + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + GGML_ASSERT(op->type == GGML_TYPE_F32); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_conv_3d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + } + + return res; +} + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_UPSCALE); @@ -1782,6 +1839,23 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_met return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_roll(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_ROLL); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_roll_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + } + + return res; +} + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_PAD); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index fd2b3ddeb55..a6c1dab5515 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -102,6 +102,8 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev void ggml_metal_library_free(ggml_metal_library_t lib); +ggml_metal_device_t ggml_metal_library_get_device(ggml_metal_library_t lib); + struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name); struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv); @@ -148,9 +150,11 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_3d (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_roll (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 82101f4714e..fe90aafe7bc 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -95,8 +95,8 @@ int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_wi struct ggml_metal_library { id obj; - id device; + ggml_metal_device_t dev; ggml_metal_pipelines_t pipelines; // cache of compiled pipelines NSLock * lock; @@ -251,7 +251,7 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) { ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library)); res->obj = library; - res->device = device; + res->dev = dev; res->pipelines = ggml_metal_pipelines_init(); res->lock = [NSLock new]; @@ -318,7 +318,7 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev } res->obj = library; - res->device = device; + res->dev = dev; res->pipelines = ggml_metal_pipelines_init(); res->lock = [NSLock new]; @@ -341,6 +341,10 @@ void ggml_metal_library_free(ggml_metal_library_t lib) { free(lib); } +ggml_metal_device_t ggml_metal_library_get_device(ggml_metal_library_t lib) { + return lib->dev; +} + struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) { [lib->lock lock]; @@ -405,7 +409,8 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_ return res; } - id obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error]; + id device = ggml_metal_device_get_obj(lib->dev); + id obj = [device newComputePipelineStateWithFunction:mtl_function error:&error]; [mtl_function release]; @@ -690,7 +695,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) { " auto tB = B.slice((int)tgid.x, 0); \n" " \n" " matmul2d< \n" - " matmul2d_descriptor(8, 8, dynamic_extent), \n" + " matmul2d_descriptor(16, 16, dynamic_extent), \n" " execution_simdgroups<4>> mm; \n" " \n" " auto cT = mm.get_destination_cooperative_tensor(); \n" @@ -699,7 +704,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) { " auto sB = tB.slice(0, 0); \n" " mm.run(sB, sA, cT); \n" " \n" - " auto tC = tensor, tensor_inline>(C, dextents(4, 4)); \n" + " auto tC = tensor, tensor_inline>(C, dextents(16, 16)); \n" " \n" " cT.store(tC); \n" "}"; @@ -740,7 +745,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) { " auto tB = B.slice((int)tgid.x, 0); \n" " \n" " matmul2d< \n" - " matmul2d_descriptor(8, 8, dynamic_extent), \n" + " matmul2d_descriptor(16, 16, dynamic_extent), \n" " execution_simdgroups<4>> mm; \n" " \n" " auto cT = mm.get_destination_cooperative_tensor(); \n" @@ -749,7 +754,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) { " auto sB = tB.slice(0, 0); \n" " mm.run(sB, sA, cT); \n" " \n" - " auto tC = tensor, tensor_inline>(C, dextents(4, 4)); \n" + " auto tC = tensor, tensor_inline>(C, dextents(16, 16)); \n" " \n" " cT.store(tC); \n" "}"; @@ -814,7 +819,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) { } // print MTL GPU family: - GGML_LOG_INFO("%s: GPU name: %s\n", __func__, dev->props.name); + GGML_LOG_INFO("%s: GPU name: %s (%s)\n", __func__, dev->props.name, dev->props.desc); // determine max supported GPU family // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf @@ -931,13 +936,13 @@ void ggml_metal_device_rsets_keep_alive(ggml_metal_device_t dev) { } struct ggml_metal_event { - void * obj; // id + void * obj; // id atomic_int value; }; void ggml_metal_event_encode_signal(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf_raw) { - id event = (id)ev->obj; + id event = (id)ev->obj; id cmd_buf = (id) cmd_buf_raw; @@ -945,7 +950,7 @@ void ggml_metal_event_encode_signal(ggml_metal_event_t ev, ggml_metal_cmd_buf_t } void ggml_metal_event_encode_wait(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf_raw) { - id event = (id)ev->obj; + id event = (id)ev->obj; id cmd_buf = (id) cmd_buf_raw; @@ -953,7 +958,7 @@ void ggml_metal_event_encode_wait(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cm } ggml_metal_event_t ggml_metal_device_event_init(ggml_metal_device_t dev) { - id event = [dev->mtl_device newEvent]; + id event = [dev->mtl_device newSharedEvent]; ggml_metal_event_t ev = calloc(1, sizeof(struct ggml_metal_event)); @@ -964,7 +969,7 @@ ggml_metal_event_t ggml_metal_device_event_init(ggml_metal_device_t dev) { } void ggml_metal_device_event_free(ggml_metal_device_t dev, ggml_metal_event_t ev) { - id event = ev->obj; + id event = ev->obj; [event release]; free(ev); @@ -973,14 +978,13 @@ void ggml_metal_device_event_free(ggml_metal_device_t dev, ggml_metal_event_t ev } void ggml_metal_device_event_synchronize(ggml_metal_device_t dev, ggml_metal_event_t ev) { - @autoreleasepool { - id event = ev->obj; - - id cmd_buf = [dev->mtl_queue commandBuffer]; - [cmd_buf encodeWaitForEvent:event value:atomic_load_explicit(&ev->value, memory_order_relaxed)]; - [cmd_buf commit]; - [cmd_buf waitUntilCompleted]; + id event = ev->obj; + const bool res = [event waitUntilSignaledValue:atomic_load_explicit(&ev->value, memory_order_relaxed) timeoutMS:60000]; + if (!res) { + GGML_ABORT("%s: failed to wait for event\n", __func__); } + + GGML_UNUSED(dev); } void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total) { @@ -1039,6 +1043,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_SOFTPLUS: case GGML_UNARY_OP_EXPM1: + case GGML_UNARY_OP_FLOOR: + case GGML_UNARY_OP_CEIL: + case GGML_UNARY_OP_ROUND: + case GGML_UNARY_OP_TRUNC: + case GGML_UNARY_OP_XIELU: return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); default: return false; @@ -1077,6 +1086,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + case GGML_OP_CONV_3D: + return ggml_is_contiguous(op->src[0]) && + ggml_is_contiguous(op->src[1]) && + (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) && + op->src[1]->type == GGML_TYPE_F32; case GGML_OP_SUM: return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); case GGML_OP_TRI: @@ -1128,6 +1142,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_ARGSORT: case GGML_OP_TOP_K: case GGML_OP_ARANGE: + case GGML_OP_ROLL: return true; case GGML_OP_FLASH_ATTN_EXT: // for new head sizes, add checks here @@ -1143,12 +1158,30 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te op->src[0]->ne[0] != 192 && op->src[0]->ne[0] != 256 && op->src[0]->ne[0] != 320 && + op->src[0]->ne[0] != 512 && op->src[0]->ne[0] != 576) { return false; } if (op->src[1]->type != op->src[2]->type) { return false; } + switch (op->src[1]->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + break; + case GGML_TYPE_BF16: + if (!has_bfloat) { + return false; + } + break; + default: + return false; + } return has_simdgroup_mm; // TODO: over-restricted for vec-kernels case GGML_OP_SSM_CONV: case GGML_OP_SSM_SCAN: @@ -1174,6 +1207,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_TYPE_F16: case GGML_TYPE_BF16: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -1200,6 +1234,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te default: return false; } + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 53437b23cda..ff74cafb5b7 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -1,6 +1,19 @@ #ifndef GGML_METAL_IMPL #define GGML_METAL_IMPL +// kernel parameters for mat-mat threadgroups +// +// TODO: become function constants + +#define SZ_SIMDGROUP 16 +#define N_MM_NK 2 +#define N_MM_NK_TOTAL (SZ_SIMDGROUP * N_MM_NK) + +#define N_MM_BLOCK_X 4 +#define N_MM_BLOCK_Y 2 +#define N_MM_SIMD_GROUP_X 2 +#define N_MM_SIMD_GROUP_Y 2 + // kernel parameters for mat-vec threadgroups // // N_R0: number of src0 rows to process per simdgroup @@ -8,6 +21,9 @@ // // TODO: for optimal performance, become function of the device and work size +#define N_R0_Q1_0 8 +#define N_SG_Q1_0 2 + #define N_R0_Q4_0 4 #define N_SG_Q4_0 2 @@ -120,6 +136,11 @@ #define OP_UNARY_NUM_EXP 114 #define OP_UNARY_NUM_SOFTPLUS 115 #define OP_UNARY_NUM_EXPM1 116 +#define OP_UNARY_NUM_FLOOR 117 +#define OP_UNARY_NUM_CEIL 118 +#define OP_UNARY_NUM_ROUND 119 +#define OP_UNARY_NUM_TRUNC 120 +#define OP_UNARY_NUM_XIELU 121 #define OP_SUM_ROWS_NUM_SUM_ROWS 10 #define OP_SUM_ROWS_NUM_MEAN 11 @@ -643,6 +664,42 @@ typedef struct { int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources } ggml_metal_kargs_im2col; +typedef struct { + int32_t IW; + int32_t IH; + int32_t ID; + int32_t OW; + int32_t OH; + int32_t OD; + int32_t KW; + int32_t KH; + int32_t KD; + int32_t s0; + int32_t s1; + int32_t s2; + int32_t p0; + int32_t p1; + int32_t p2; + int32_t d0; + int32_t d1; + int32_t d2; + int32_t IC; + int32_t N; + int32_t OC; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_conv_3d; + typedef struct{ int32_t ne00; uint64_t nb01; @@ -973,6 +1030,29 @@ typedef struct { int32_t p1; } ggml_metal_kargs_pad_reflect_1d; +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + int32_t s0; + int32_t s1; + int32_t s2; + int32_t s3; +} ggml_metal_kargs_roll; + typedef struct { uint64_t nb1; int dim; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index c0bcad392b9..5fa162c875c 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -394,6 +394,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx); } break; + case GGML_OP_CONV_3D: + { + n_fuse = ggml_metal_op_conv_3d(ctx, idx); + } break; case GGML_OP_UPSCALE: { n_fuse = ggml_metal_op_upscale(ctx, idx); @@ -406,6 +410,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_pad_reflect_1d(ctx, idx); } break; + case GGML_OP_ROLL: + { + n_fuse = ggml_metal_op_roll(ctx, idx); + } break; case GGML_OP_ARANGE: { n_fuse = ggml_metal_op_arange(ctx, idx); @@ -783,6 +791,13 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { args.max = ggml_get_op_params_f32(op, 1); } + if (op->op == GGML_OP_UNARY && ggml_get_unary_op(op) == GGML_UNARY_OP_XIELU) { + args.slope = ggml_get_op_params_f32(op, 1); // alpha_n + args.scale = ggml_get_op_params_f32(op, 2); // alpha_p + args.bias = ggml_get_op_params_f32(op, 3); // beta + args.val = ggml_get_op_params_f32(op, 4); // eps + } + auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); if (pipeline.c4) { @@ -2043,6 +2058,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_BF16 || + op->src[0]->type == GGML_TYPE_Q1_0 || op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 || op->src[0]->type == GGML_TYPE_Q5_0 || @@ -2179,7 +2195,12 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { const size_t smem = pipeline.smem; ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); - ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1); + + const int nr0 = pipeline.nr0; + const int nr1 = pipeline.nr1; + const int nsg = pipeline.nsg; + + ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + nr1 - 1) / nr1), ((ne01 + nr0 - 1) / nr0), ne12 * ne13, 32, nsg, 1); } else { auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op); @@ -3697,6 +3718,77 @@ int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_conv_3d(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + // 1. Extract standard dimensions and byte strides + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + // 2. Extract hyperparams from op_params + const int32_t s0 = ((const int32_t *)(op->op_params))[0]; + const int32_t s1 = ((const int32_t *)(op->op_params))[1]; + const int32_t s2 = ((const int32_t *)(op->op_params))[2]; + const int32_t p0 = ((const int32_t *)(op->op_params))[3]; + const int32_t p1 = ((const int32_t *)(op->op_params))[4]; + const int32_t p2 = ((const int32_t *)(op->op_params))[5]; + const int32_t d0 = ((const int32_t *)(op->op_params))[6]; + const int32_t d1 = ((const int32_t *)(op->op_params))[7]; + const int32_t d2 = ((const int32_t *)(op->op_params))[8]; + const int32_t IC = ((const int32_t *)(op->op_params))[9]; + const int32_t N = ((const int32_t *)(op->op_params))[10]; + const int32_t OC = ((const int32_t *)(op->op_params))[11]; + + // 3. Build the parameter struct using the macro-generated variables + ggml_metal_kargs_conv_3d args = { + /*.IW =*/ (int32_t)op->src[1]->ne[0], + /*.IH =*/ (int32_t)op->src[1]->ne[1], + /*.ID =*/ (int32_t)op->src[1]->ne[2], + /*.OW =*/ (int32_t)op->ne[0], + /*.OH =*/ (int32_t)op->ne[1], + /*.OD =*/ (int32_t)op->ne[2], + /*.KW =*/ (int32_t)op->src[0]->ne[0], + /*.KH =*/ (int32_t)op->src[0]->ne[1], + /*.KD =*/ (int32_t)op->src[0]->ne[2], + s0, s1, s2, + p0, p1, p2, + d0, d1, d2, + IC, N, OC, + nb00, nb01, nb02, nb03, // Weight strides + nb10, nb11, nb12, nb13, // Input strides + nb0, nb1, nb2, nb3 // Output strides + }; + + // 4. Fetch the JIT pipeline + auto pipeline = ggml_metal_library_get_pipeline_conv_3d(lib, op); + + // 5. Grid mapping + int nth0 = 32; // Standard SIMD width for Apple Silicon + int nth1 = 1; + int nth2 = 1; + + int64_t spatial_volume = args.OW * args.OH * args.OD; + + int ntg0 = (spatial_volume + nth0 - 1) / nth0; + int ntg1 = args.OC; + int ntg2 = args.N; + + // 6. Bind and Dispatch via the ggml C wrapper + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, ntg0, ntg1, ntg2, nth0, nth1, nth2); + + return 1; +} + int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); @@ -3862,6 +3954,59 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_roll(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + const int32_t s0 = ggml_get_op_params_i32(op, 0); + const int32_t s1 = ggml_get_op_params_i32(op, 1); + const int32_t s2 = ggml_get_op_params_i32(op, 2); + const int32_t s3 = ggml_get_op_params_i32(op, 3); + + ggml_metal_kargs_roll args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.s0 =*/ s0, + /*.s1 =*/ s1, + /*.s2 =*/ s2, + /*.s3 =*/ s3 + }; + + auto pipeline = ggml_metal_library_get_pipeline_roll(lib, op); + + const int nth = std::min(1024, ne0); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + + return 1; +} + int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 019f2fec9ed..36c61071b4f 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -75,11 +75,13 @@ int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx); int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx); int ggml_metal_op_im2col (ggml_metal_op_t ctx, int idx); int ggml_metal_op_conv_2d (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_conv_3d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_conv_transpose_2d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_upscale (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pad (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pad_reflect_1d (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_roll (ggml_metal_op_t ctx, int idx); int ggml_metal_op_arange (ggml_metal_op_t ctx, int idx); int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx); int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index 9382ce53b36..cc329d67594 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -90,6 +90,8 @@ static ggml_backend_buffer_i ggml_backend_metal_buffer_shared_i = { /* .memset_tensor = */ ggml_backend_metal_buffer_shared_memset_tensor, /* .set_tensor = */ ggml_backend_metal_buffer_shared_set_tensor, /* .get_tensor = */ ggml_backend_metal_buffer_shared_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_metal_buffer_shared_cpy_tensor, /* .clear = */ ggml_backend_metal_buffer_shared_clear, /* .reset = */ NULL, @@ -158,15 +160,17 @@ static void ggml_backend_metal_buffer_private_clear(ggml_backend_buffer_t buffer } static ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = { - /* .free_buffer = */ ggml_backend_metal_buffer_private_free_buffer, - /* .get_base = */ ggml_backend_metal_buffer_private_get_base, - /* .init_tensor = */ NULL, - /* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor, - /* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor, - /* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor, - /* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor, - /* .clear = */ ggml_backend_metal_buffer_private_clear, - /* .reset = */ NULL, + /* .free_buffer = */ ggml_backend_metal_buffer_private_free_buffer, + /* .get_base = */ ggml_backend_metal_buffer_private_get_base, + /* .init_tensor = */ NULL, + /* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor, + /* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor, + /* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, + /* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor, + /* .clear = */ ggml_backend_metal_buffer_private_clear, + /* .reset = */ NULL, }; static bool ggml_backend_buffer_is_metal(ggml_backend_buffer_t buffer) { @@ -563,6 +567,8 @@ static ggml_backend_i ggml_backend_metal_i = { /* .free = */ ggml_backend_metal_free, /* .set_tensor_async = */ ggml_backend_metal_set_tensor_async, /* .get_tensor_async = */ ggml_backend_metal_get_tensor_async, + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ ggml_backend_metal_cpy_tensor_async, // only needed for multi-GPU setups /* .synchronize = */ ggml_backend_metal_synchronize, /* .graph_plan_create = */ NULL, @@ -912,6 +918,10 @@ ggml_backend_reg_t ggml_backend_metal_reg(void) { static std::vector devs; if (!initialized) { + // workaround macOS limitation (kIOGPUCommandBufferCallbackErrorImpactingInteractivity) until proper fix becomes possible + // ref: https://github.com/ggml-org/llama.cpp/issues/20141#issuecomment-4272947703 + setenv("AGX_RELAX_CDM_CTXSTORE_TIMEOUT", "1", true); + static ggml_backend_metal_reg_ptr reg_ctx(ggml_backend_metal_reg_init()); for (int i = 0; i < g_devices; ++i) { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index b2328605dd9..c372eaedeae 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -118,6 +118,56 @@ void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg } #endif +template +void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & reg) { + device const uint8_t * qs = xb->qs; + const float d = xb->d; + const float neg_d = -d; + + const int byte_offset = il * 2; // il*16 bits = il*2 bytes + const uint8_t b0 = qs[byte_offset]; + const uint8_t b1 = qs[byte_offset + 1]; + + float4x4 reg_f; + + reg_f[0][0] = select(neg_d, d, bool(b0 & 0x01)); + reg_f[0][1] = select(neg_d, d, bool(b0 & 0x02)); + reg_f[0][2] = select(neg_d, d, bool(b0 & 0x04)); + reg_f[0][3] = select(neg_d, d, bool(b0 & 0x08)); + reg_f[1][0] = select(neg_d, d, bool(b0 & 0x10)); + reg_f[1][1] = select(neg_d, d, bool(b0 & 0x20)); + reg_f[1][2] = select(neg_d, d, bool(b0 & 0x40)); + reg_f[1][3] = select(neg_d, d, bool(b0 & 0x80)); + + reg_f[2][0] = select(neg_d, d, bool(b1 & 0x01)); + reg_f[2][1] = select(neg_d, d, bool(b1 & 0x02)); + reg_f[2][2] = select(neg_d, d, bool(b1 & 0x04)); + reg_f[2][3] = select(neg_d, d, bool(b1 & 0x08)); + reg_f[3][0] = select(neg_d, d, bool(b1 & 0x10)); + reg_f[3][1] = select(neg_d, d, bool(b1 & 0x20)); + reg_f[3][2] = select(neg_d, d, bool(b1 & 0x40)); + reg_f[3][3] = select(neg_d, d, bool(b1 & 0x80)); + + reg = (type4x4) reg_f; +} + +template +void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) { + const float d = xb->d; + const float neg_d = -d; + const int base = il * 4; + const uint8_t byte = xb->qs[base / 8]; + const int s = base % 8; + + float4 reg_f; + reg_f[0] = select(neg_d, d, bool((byte >> (s )) & 1)); + reg_f[1] = select(neg_d, d, bool((byte >> (s + 1)) & 1)); + reg_f[2] = select(neg_d, d, bool((byte >> (s + 2)) & 1)); + reg_f[3] = select(neg_d, d, bool((byte >> (s + 3)) & 1)); + + reg = (type4) reg_f; +} + template void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) { device const uint16_t * qs = ((device const uint16_t *)xb + 1); @@ -152,6 +202,23 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r } } +void quantize_q1_0(device const float * src, device block_q1_0 & dst) { + float sum_abs = 0.0f; + for (int j = 0; j < QK1_0; j++) { + sum_abs += fabs(src[j]); + } + dst.d = sum_abs / QK1_0; + + for (int j = 0; j < QK1_0 / 8; j++) { + dst.qs[j] = 0; + } + for (int j = 0; j < QK1_0; j++) { + if (src[j] >= 0.0f) { + dst.qs[j / 8] |= (1 << (j % 8)); + } + } +} + void quantize_q4_0(device const float * src, device block_q4_0 & dst) { #pragma METAL fp math_mode(safe) float amax = 0.0f; // absolute max @@ -1094,6 +1161,31 @@ kernel void kernel_unary_impl( // TODO: precise implementation dst_ptr[i0] = (T) (exp(x) - 1); } + + if (FC_OP == OP_UNARY_NUM_FLOOR) { + dst_ptr[i0] = (T) floor(x); + } + + if (FC_OP == OP_UNARY_NUM_CEIL) { + dst_ptr[i0] = (T) ceil(x); + } + + if (FC_OP == OP_UNARY_NUM_ROUND) { + dst_ptr[i0] = (T) round(x); + } + + if (FC_OP == OP_UNARY_NUM_TRUNC) { + dst_ptr[i0] = (T) trunc(x); + } + + if (FC_OP == OP_UNARY_NUM_XIELU) { + const TC xi = x; + const TC gate = TC(xi > TC(0.0f)); + const TC clamped = fmin(xi, TC(args.val)); + const TC y_pos = TC(args.scale) * xi * xi + TC(args.bias) * xi; + const TC y_neg = (exp(clamped) - TC(1.0f) - xi) * TC(args.slope) + TC(args.bias) * xi; + dst_ptr[i0] = (T) (gate * y_pos + (TC(1.0f) - gate) * y_neg); + } } #undef FC_OP @@ -3100,6 +3192,35 @@ kernel void kernel_group_norm_f32( } } +// Q1_0 dot product: dot = d * (2 * Σ(yl[i] where bit=1) - sumy) +inline float block_q_n_dot_y(device const block_q1_0 * qb_curr, float sumy, thread float * yl, int il) { + device const uint8_t * qs = qb_curr->qs + il / 8; + const uint8_t b0 = qs[0]; + const uint8_t b1 = qs[1]; + + float acc = 0.0f; + + acc += select(0.0f, yl[ 0], bool(b0 & 0x01)); + acc += select(0.0f, yl[ 1], bool(b0 & 0x02)); + acc += select(0.0f, yl[ 2], bool(b0 & 0x04)); + acc += select(0.0f, yl[ 3], bool(b0 & 0x08)); + acc += select(0.0f, yl[ 4], bool(b0 & 0x10)); + acc += select(0.0f, yl[ 5], bool(b0 & 0x20)); + acc += select(0.0f, yl[ 6], bool(b0 & 0x40)); + acc += select(0.0f, yl[ 7], bool(b0 & 0x80)); + + acc += select(0.0f, yl[ 8], bool(b1 & 0x01)); + acc += select(0.0f, yl[ 9], bool(b1 & 0x02)); + acc += select(0.0f, yl[10], bool(b1 & 0x04)); + acc += select(0.0f, yl[11], bool(b1 & 0x08)); + acc += select(0.0f, yl[12], bool(b1 & 0x10)); + acc += select(0.0f, yl[13], bool(b1 & 0x20)); + acc += select(0.0f, yl[14], bool(b1 & 0x40)); + acc += select(0.0f, yl[15], bool(b1 & 0x80)); + + return qb_curr->d * (2.0f * acc - sumy); +} + // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) // il indicates where the q4 quants begin (0 or QK4_0/4) // we assume that the yl's have been multiplied with the appropriate scale factor @@ -3321,6 +3442,85 @@ void mul_vec_q_n_f32_impl( } } +template +void kernel_mul_mv_q1_0_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const short NSG = FC_mul_mv_nsg; + + const int nb = args.ne00/QK1_0; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * NSG + sgitg) * nr0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13; + + device const float * y = (device const float *) (src1 + offset1); + + device const block_q1_0 * ax[nr0]; + for (int row = 0; row < nr0; ++row) { + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + ax[row] = (device const block_q1_0 *) ((device char *) src0 + offset0); + } + + float yl[16]; + float sumf[nr0] = {0.f}; + + const short ix = (tiisg/8); + const short il = (tiisg%8)*16; + + device const float * yb = y + ix*QK1_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) { + float sumy = 0.f; + + FOR_UNROLL (short i = 0; i < 16; i++) { + yl[i] = yb[i]; + sumy += yb[i]; + } + + FOR_UNROLL (short row = 0; row < nr0; row++) { + sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy, yl, il); + } + + yb += QK1_0 * (N_SIMDWIDTH/8); + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < nr0; ++row) { + const float tot = simd_sum(sumf[row]); + + if (tiisg == 0 && first_row + row < args.ne01) { + dst_f32[first_row + row] = tot; + } + } +} + +[[host_name("kernel_mul_mv_q1_0_f32")]] +kernel void kernel_mul_mv_q1_0_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_q1_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + kernel void kernel_mul_mv_q4_0_f32( constant ggml_metal_kargs_mul_mv & args, device const char * src0, @@ -3713,6 +3913,11 @@ template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]] kernel mul_mv_ext_q4 template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, bfloat4, 4, dequantize_bf16_t4>; #endif +template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q1_0, 128, dequantize_q1_0_t4>; +template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q1_0, 128, dequantize_q1_0_t4>; +template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q1_0, 128, dequantize_q1_0_t4>; +template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q1_0, 128, dequantize_q1_0_t4>; + template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>; template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>; template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>; @@ -4883,6 +5088,98 @@ kernel void kernel_upscale_bilinear_f32( } } +template +kernel void kernel_conv_3d( + constant ggml_metal_kargs_conv_3d & args, + device const char * src0, // Weights [IC * OC, KD, KH, KW] + device const char * src1, // Inputs [IC * N, ID, IH, IW] + device char * dst, // Outputs [OC * N, OD, OH, OW] + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]) { + + // 1. Un-flatten the spatial dimension from Grid X + int64_t spatial_idx = tgpig.x * 32 + tpitg.x; + + if (spatial_idx >= args.OW * args.OH * args.OD) { + return; // Thread falls outside the spatial volume + } + + int64_t od = spatial_idx / (args.OW * args.OH); + int64_t oh = (spatial_idx / args.OW) % args.OH; + int64_t ow = spatial_idx % args.OW; + + // 2. Map Y to Channels, Z to Batch + int64_t oc = tgpig.y; + int64_t batch_idx = tgpig.z; + + // 3. Calculate anchor coordinates in the Input volume + int64_t i_w_base = ow * args.s0 - args.p0; + int64_t i_h_base = oh * args.s1 - args.p1; + int64_t i_d_base = od * args.s2 - args.p2; + + float sum = 0.0f; + + // 4. Gather Loop (Iterate over Input Channels -> Depth -> Height -> Width) + for (int64_t ic = 0; ic < args.IC; ++ic) { + + // ggml packs batch and channel together in the 4th dimension + int64_t src_cn_idx = batch_idx * args.IC + ic; + int64_t w_cn_idx = oc * args.IC + ic; + + for (int64_t kz = 0; kz < args.KD; ++kz) { + int64_t id = i_d_base + kz * args.d2; + if (id < 0 || id >= args.ID) continue; // Boundary check (Padding) + + for (int64_t ky = 0; ky < args.KH; ++ky) { + int64_t ih = i_h_base + ky * args.d1; + if (ih < 0 || ih >= args.IH) continue; + + for (int64_t kx = 0; kx < args.KW; ++kx) { + int64_t iw = i_w_base + kx * args.d0; + if (iw < 0 || iw >= args.IW) continue; + + // Convert multi-dimensional coordinates to flat byte offsets + int64_t w_idx = kx*args.nb00 + ky*args.nb01 + kz*args.nb02 + w_cn_idx*args.nb03; + int64_t i_idx = iw*args.nb10 + ih*args.nb11 + id*args.nb12 + src_cn_idx*args.nb13; + + // Dereference memory and cast weights to f32 if they were f16 + float w_val = (float)*(device const T*)((device const char*)src0 + w_idx); + float i_val = *(device const float*)((device const char*)src1 + i_idx); + + sum += w_val * i_val; + } + } + } + } + + // 5. Write the accumulated value out to RAM + int64_t dst_cn_idx = batch_idx * args.OC + oc; + int64_t d_idx = ow*args.nb0 + oh*args.nb1 + od*args.nb2 + dst_cn_idx*args.nb3; + + *(device float*)(dst + d_idx) = sum; +} + +// Explicit instantiations so the JIT compiler can find them by name +template [[host_name("kernel_conv_3d_f32_f32")]] +kernel void kernel_conv_3d( + constant ggml_metal_kargs_conv_3d & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]); + +// Explicit instantiation for f16 weights +template [[host_name("kernel_conv_3d_f16_f32")]] +kernel void kernel_conv_3d( + constant ggml_metal_kargs_conv_3d & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]); + + static inline float bicubic_weight1(float x) { const float a = -0.75f; return ((a + 2) * x - (a + 3)) * x * x + 1; @@ -4950,6 +5247,40 @@ kernel void kernel_upscale_bicubic_f32( } } +kernel void kernel_roll_f32( + constant ggml_metal_kargs_roll & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + device const float * src0_ptr = (device const float *) src0; + device float * dst_ptr = (device float *) dst; + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + // apply shifts and wrap around + int64_t i00 = i0 - args.s0; + int64_t i01 = i1 - args.s1; + int64_t i02 = i2 - args.s2; + int64_t i03 = i3 - args.s3; + + if (i00 < 0) { i00 += args.ne00; } else if (i00 >= args.ne00) { i00 -= args.ne00; } + if (i01 < 0) { i01 += args.ne01; } else if (i01 >= args.ne01) { i01 -= args.ne01; } + if (i02 < 0) { i02 += args.ne02; } else if (i02 >= args.ne02) { i02 -= args.ne02; } + if (i03 < 0) { i03 += args.ne03; } else if (i03 >= args.ne03) { i03 -= args.ne03; } + + int64_t src_idx = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00 + i00; + int64_t dst_idx = i3 *args.ne2 *args.ne1 *args.ne0 + i2 *args.ne1 *args.ne0 + i1 *args.ne0 + i0; + + dst_ptr[dst_idx] = src0_ptr[src_idx]; + } +} + kernel void kernel_pad_f32( constant ggml_metal_kargs_pad & args, device const char * src0, @@ -6177,6 +6508,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6192,6 +6524,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #if defined(GGML_METAL_HAS_BF16) @@ -6208,6 +6541,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #endif @@ -6224,6 +6558,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6239,6 +6574,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6254,6 +6590,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6269,6 +6606,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6284,6 +6622,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #undef FA_TYPES @@ -6865,6 +7204,17 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk320_dv256")]] kernel flas template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f32_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_HAS_BF16) @@ -7006,6 +7356,7 @@ kernel void kernel_cpy_f32_q( typedef decltype(kernel_cpy_f32_q) cpy_f_q_t; template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q1_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q; template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; @@ -7046,12 +7397,14 @@ kernel void kernel_cpy_q_f32( typedef decltype(kernel_cpy_q_f32) cpy_q_f_t; +template [[host_name("kernel_cpy_q1_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q1_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; @@ -8953,7 +9306,137 @@ constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]]; constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]]; // each block_q contains 16*nl weights -template +#ifdef GGML_METAL_HAS_TENSOR +template< + typename SA, typename SA_4x4, typename SA_8x8, + typename SB, typename SB_2x4, typename SB_8x8, + typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread SA_4x4 &), + typename T0, typename T0_4x4, typename T1, typename T1_2x4> +kernel void kernel_mul_mm( + constant ggml_metal_kargs_mul_mm & args, + device const char * srcA, + device const char * srcB, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig [[threadgroup_position_in_grid]], + ushort tiitg [[thread_index_in_threadgroup]], + ushort sgitg [[simdgroup_index_in_threadgroup]]) { + (void) sgitg; + + // Matrix dimensions: A(M,K) x B(K,N) -> C(M,N) + const int K = args.ne00; + const int M = args.ne0; + const int N = args.ne1; + + // Batch dimension handling + const int im = tgpig.z; + const int i12 = im % args.ne12; + const int i13 = im / args.ne12; + + // Batch offsets for srcA and srcB + const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + + // Tile dimensions + constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X; + constexpr int NRA = SZ_SIMDGROUP * N_MM_BLOCK_Y * N_MM_SIMD_GROUP_Y; + + // Tile offsets in output matrix + const int ra = tgpig.y * NRA; + const int rb = tgpig.x * NRB; + + // Threadgroup memory for dequantized A tile only + threadgroup SA * sa = (threadgroup SA *)(shmem); + + // Work-item count for A loading + constexpr int A_WORK_ITEMS = NRA * N_MM_NK; + constexpr int NUM_THREADS = N_SIMDWIDTH * N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y; + + // tA wraps threadgroup memory + auto tA = tensor(sa, dextents(N_MM_NK_TOTAL, NRA)); + + // tB wraps device memory directly + device T1 * ptrB = (device T1 *)(srcB + args.nb12*i12 + args.nb13*i13); + const int strideB = args.nb11 / sizeof(T1); + auto tB = tensor(ptrB, dextents(K, N), array({1, strideB})); + + // Configure matmul operation + mpp::tensor_ops::matmul2d< + mpp::tensor_ops::matmul2d_descriptor( + NRB, NRA, N_MM_NK_TOTAL, false, true, true, + mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate), + execution_simdgroups> mm; + + auto cT = mm.get_destination_cooperative_tensor(); + + // Accumulate partial results over K dimension + for (int loop_k = 0; loop_k < K; loop_k += N_MM_NK_TOTAL) { + // === PHASE 1: Dequantization of A into threadgroup memory === + for (int work = tiitg; work < A_WORK_ITEMS; work += NUM_THREADS) { + const int row = work / N_MM_NK; + const int k_chunk = work % N_MM_NK; + const int k_pos = loop_k + k_chunk * 16; + const short k_base = k_chunk * 16; + + // Bounds check: skip device read if row is out of matrix bounds + if (ra + row < M) { + if (is_same::value && FC_mul_mm_bc_inp) { + // Element-wise reads when K is not aligned (nb01 not aligned for half4x4/float4x4). + // MSL spec Table 2.5: half4x4 requires 8-byte alignment. When K is odd, + // nb01 = K*2 is not 8-byte aligned, so odd-row pointers are misaligned. + // Mirrors the legacy kernel's existing guard. + device const T0 * row_ptr = (device const T0 *)(srcA + args.nb01 * (ra + row) + offset0); + + FOR_UNROLL (short i = 0; i < 16; i++) { + sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? (SA) row_ptr[k_pos + i] : (SA)0; + } + } else { + const int block_idx = k_pos / (16 * nl); + const short il = (k_pos / 16) % nl; + + device const block_q * row_ptr = (device const block_q *)(srcA + args.nb01 * (ra + row) + offset0); + + SA_4x4 temp_a; + dequantize_func(row_ptr + block_idx, il, temp_a); + + FOR_UNROLL (short i = 0; i < 16; i++) { + // Zero-pad A for K positions beyond valid range (handles partial K iterations) + sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? temp_a[i/4][i%4] : (SA)0; + } + } + } else { + // Zero-pad rows beyond matrix bounds + FOR_UNROLL (short i = 0; i < 16; i++) { + sa[row * N_MM_NK_TOTAL + (k_base + i)] = (SA)0; + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // === PHASE 2: Tensor matmul === + auto mA = tA.slice(0, 0); + auto mB = tB.slice(loop_k, rb); + + mm.run(mB, mA, cT); + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // Store result tile to output matrix (with batch offset) + // cT.store handles bounds checking via tD's extents (M, N) + device float * dstBatch = (device float *)dst + im * N * M; + + auto tD = tensor(dstBatch, dextents(M, N), array({1, M})); + cT.store(tD.slice(ra, rb)); +} + +#else + +template< + typename S0, typename S0_4x4, typename S0_8x8, + typename S1, typename S1_2x4, typename S1_8x8, + typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), + typename T0, typename T0_4x4, typename T1, typename T1_2x4> kernel void kernel_mul_mm( constant ggml_metal_kargs_mul_mm & args, device const char * src0, @@ -8967,10 +9450,6 @@ kernel void kernel_mul_mm( threadgroup S0 * sa = (threadgroup S0 *)(shmem); threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); -#ifdef GGML_METAL_HAS_TENSOR - threadgroup float * sc = (threadgroup float *)(shmem); -#endif - constexpr int NR0 = 64; constexpr int NR1 = 32; @@ -9010,7 +9489,6 @@ kernel void kernel_mul_mm( + args.nb11*(r1 + lr1) + args.nb10*iy); -#ifndef GGML_METAL_HAS_TENSOR S0_8x8 ma[4]; S1_8x8 mb[2]; @@ -9019,19 +9497,8 @@ kernel void kernel_mul_mm( for (short i = 0; i < 8; i++){ mc[i] = make_filled_simdgroup_matrix(0.f); } -#else - auto tA = tensor, tensor_inline>(sa, dextents(NK, NR0)); - auto tB = tensor, tensor_inline>(sb, dextents(NR1, NK )); - - mpp::tensor_ops::matmul2d< - mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate), - execution_simdgroups<4>> mm; - - auto cT = mm.get_destination_cooperative_tensor(); -#endif for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) { -#ifndef GGML_METAL_HAS_TENSOR // load data and store to threadgroup memory if (is_same::value && FC_mul_mm_bc_inp) { threadgroup_barrier(mem_flags::mem_threadgroup); @@ -9101,66 +9568,6 @@ kernel void kernel_mul_mm( *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y)); } -#else - // load data and store to threadgroup memory - if (is_same::value && FC_mul_mm_bc_inp) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - // no need for dequantization - for (short i = 0; i < 16; i++) { - const short sx = 2*il0 + i/8; - const short sy = (tiitg/NL0)/8; - - const short lx = i%8; - const short ly = (tiitg/NL0)%8; - //const short lx = (tiitg/NL0)%8; - //const short ly = i%8; - - *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0; - } - } else { - S0_4x4 temp_a; - dequantize_func(x, il, temp_a); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - FOR_UNROLL (short i = 0; i < 16; i++) { - const short sx = 2*il0 + i/8; - const short sy = (tiitg/NL0)/8; - - const short lx = i%8; - const short ly = (tiitg/NL0)%8; - //const short lx = (tiitg/NL0)%8; - //const short ly = i%8; - - *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4]; - } - } - - if (FC_mul_mm_bc_inp) { - for (short i = 0; i < 8; ++i) { - const short sx = (tiitg%NL1); - const short sy = (tiitg/NL1)/8; - - const short lx = i; - const short ly = (tiitg/NL1)%8; - //const short lx = (tiitg/NL1)%8; - //const short ly = i; - - *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0; - } - } else { - const short sx = (tiitg%NL1); - const short sy = (tiitg/NL1)/8; - - //const short lx = i; - const short ly = (tiitg/NL1)%8; - //const short lx = (tiitg/NL1)%8; - //const short ly = i; - - *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y)); - } -#endif il = (il + 2 < nl) ? il + 2 : il % 2; x = (il < 2) ? x + (2 + nl - 1)/nl : x; @@ -9169,7 +9576,6 @@ kernel void kernel_mul_mm( threadgroup_barrier(mem_flags::mem_threadgroup); -#ifndef GGML_METAL_HAS_TENSOR // load matrices from threadgroup memory and conduct outer products threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2)); threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2)); @@ -9196,24 +9602,10 @@ kernel void kernel_mul_mm( lsma += 8*64; lsmb += 4*64; } -#else - auto sA = tA.slice(0, 0); - auto sB = tB.slice(0, 0); - - mm.run(sB, sA, cT); -#endif } if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) { // if no bounds checks on the output are needed, we can directly write to device memory -#ifdef GGML_METAL_HAS_TENSOR - device float * C = (device float *) dst + - r0 + \ - r1 * args.ne0 + im*args.ne1*args.ne0; - - auto tC = tensor, tensor_inline>(C, dextents(args.ne0, NR1)); - cT.store(tC); -#else device float * C = (device float *) dst + (r0 + 32*(sgitg & 1)) + \ (r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0; @@ -9221,21 +9613,15 @@ kernel void kernel_mul_mm( for (short i = 0; i < 8; i++) { simdgroup_store(mc[i], C + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false); } -#endif } else { // block is smaller than 64x32, we should avoid writing data outside of the matrix threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0; -#ifdef GGML_METAL_HAS_TENSOR - auto tC = tensor, tensor_inline>(sc, dextents(NR0, NR1)); - cT.store(tC); -#else for (short i = 0; i < 8; i++) { simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false); } -#endif threadgroup_barrier(mem_flags::mem_threadgroup); @@ -9261,6 +9647,8 @@ kernel void kernel_mul_mm( } } +#endif // GGML_METAL_HAS_TENSOR + template // n_expert_used kernel void kernel_mul_mm_id_map0( constant ggml_metal_kargs_mul_mm_id_map0 & args, @@ -9436,7 +9824,7 @@ kernel void kernel_mul_mm_id( const short ib = 8*sx + sy; - *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0; + *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? (S0) *((device T0 *) x + i) : (S0) 0; } } else { S0_4x4 temp_a; @@ -9649,6 +10037,7 @@ template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_ro typedef decltype(kernel_get_rows_q) get_rows_q_t; +template [[host_name("kernel_get_rows_q1_0")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q; @@ -9711,6 +10100,7 @@ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_m #if defined(GGML_METAL_HAS_BF16) template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm; #endif +template [[host_name("kernel_mul_mm_q1_0_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm; @@ -9734,6 +10124,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q1_0_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm; @@ -9766,6 +10157,7 @@ template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_m #if defined(GGML_METAL_HAS_BF16) template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id; #endif +template [[host_name("kernel_mul_mm_id_q1_0_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id; @@ -9789,6 +10181,7 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_m template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q1_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; @@ -9943,6 +10336,7 @@ template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4 template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q1_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt index d76cb51977f..cc53c812ce5 100644 --- a/ggml/src/ggml-musa/CMakeLists.txt +++ b/ggml/src/ggml-musa/CMakeLists.txt @@ -48,12 +48,11 @@ if (MUSAToolkit_FOUND) list(APPEND GGML_SOURCES_MUSA ${SRCS}) add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) else() - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu") - list(APPEND GGML_SOURCES_MUSA ${SRCS}) - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu") - list(APPEND GGML_SOURCES_MUSA ${SRCS}) - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu") - list(APPEND GGML_SOURCES_MUSA ${SRCS}) + list(APPEND GGML_SOURCES_MUSA + ../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu + ../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu) endif() set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX) diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 1f8250934b0..5ed83eeb48a 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -89,10 +89,15 @@ set(GGML_OPENCL_KERNELS mul_mv_q4_1_f32 mul_mv_q4_1_f32_flat mul_mv_q4_k_f32 + mul_mv_q4_k_f32_flat + mul_mv_q5_k_f32 + mul_mv_q5_k_f32_flat mul_mv_q6_k_f32 mul_mv_q6_k_f32_flat mul_mv_q8_0_f32 mul_mv_q8_0_f32_flat + mul_mv_iq4_nl_f32 + mul_mv_iq4_nl_f32_flat mul_mv_mxfp4_f32 mul_mv_mxfp4_f32_flat mul_mv_id_q4_0_f32_8x_flat @@ -107,11 +112,22 @@ set(GGML_OPENCL_KERNELS mul_mm_q4_0_f32_l4_lm mul_mm_q4_1_f32_l4_lm mul_mm_q8_0_f32_l4_lm + mul_mm_iq4_nl_f32_l4_lm + mul_mm_q4_k_f32_l4_lm + mul_mm_q5_k_f32_l4_lm mul_mm_q6_k_f32_l4_lm mul_mm_q8_0_f32_8x4 gemv_noshuffle_q4_1_f32 gemm_noshuffle_q4_1_f32 + gemv_noshuffle_iq4_nl_f32 + gemm_noshuffle_iq4_nl_f32 gemv_noshuffle_general_q8_0_f32 + gemv_noshuffle_q4_k_f32 + gemm_noshuffle_q4_k_f32 + gemv_noshuffle_q6_k_f32 + gemm_noshuffle_q6_k_f32 + gemv_noshuffle_q5_k_f32 + gemm_noshuffle_q5_k_f32 mul neg norm diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index e1dca6b4b4d..11f72a5198a 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -394,6 +394,9 @@ struct ggml_backend_opencl_context { bool fp16_support; bool has_vector_subgroup_broadcast; bool disable_fusion; + + bool adreno_has_large_buffer; + bool adreno_use_large_buffer; ggml_cl_compiler_version adreno_cl_compiler_version; int adreno_wave_size; @@ -529,20 +532,35 @@ struct ggml_backend_opencl_context { cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1; cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans; cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans; + cl_kernel kernel_convert_block_q6_K_noshuffle, kernel_restore_block_q6_K_noshuffle; cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; cl_kernel kernel_convert_block_q4_0_noshuffle; cl_kernel kernel_restore_block_q4_0_noshuffle; cl_kernel kernel_convert_block_q4_1_noshuffle; cl_kernel kernel_restore_block_q4_1_noshuffle; + cl_kernel kernel_convert_block_q4_K_noshuffle; + cl_kernel kernel_restore_block_q4_K_noshuffle; + cl_kernel kernel_convert_block_q4_K, kernel_restore_block_q4_K; + cl_kernel kernel_convert_block_q5_K, kernel_restore_block_q5_K; + cl_kernel kernel_convert_block_q5_K_noshuffle; + cl_kernel kernel_restore_block_q5_K_noshuffle; cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K; + cl_kernel kernel_convert_block_iq4_nl, kernel_restore_block_iq4_nl; + cl_kernel kernel_convert_block_iq4_nl_noshuffle; + cl_kernel kernel_restore_block_iq4_nl_noshuffle; cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; cl_kernel kernel_mul_mv_q4_1_f32; cl_kernel kernel_mul_mv_q4_1_f32_flat; cl_kernel kernel_mul_mv_q4_K_f32; + cl_kernel kernel_mul_mv_q4_K_f32_flat; + cl_kernel kernel_mul_mv_q5_K_f32; + cl_kernel kernel_mul_mv_q5_K_f32_flat; cl_kernel kernel_mul_mv_q6_K_f32; cl_kernel kernel_mul_mv_q6_K_f32_flat; cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat; cl_kernel kernel_mul_mv_q8_0_f32, kernel_mul_mv_q8_0_f32_flat; + cl_kernel kernel_mul_mv_iq4_nl_f32; + cl_kernel kernel_mul_mv_iq4_nl_f32_flat; cl_kernel kernel_solve_tri_f32; cl_kernel kernel_im2col_f32, kernel_im2col_f16; cl_kernel kernel_argsort_f32_i32; @@ -578,7 +596,10 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mm_q4_0_f32_l4_lm; cl_kernel kernel_mul_mm_q4_1_f32_l4_lm; cl_kernel kernel_mul_mm_q8_0_f32_l4_lm; + cl_kernel kernel_mul_mm_q4_k_f32_l4_lm; + cl_kernel kernel_mul_mm_q5_k_f32_l4_lm; cl_kernel kernel_mul_mm_q6_k_f32_l4_lm; + cl_kernel kernel_mul_mm_iq4_nl_f32_l4_lm; std::vector profiling_info; @@ -713,6 +734,14 @@ struct ggml_backend_opencl_context { cl_kernel kernel_gemm_noshuffle_q4_1_f32; cl_kernel kernel_mul_mm_q8_0_f32_8x4; cl_kernel CL_mul_mat_vec_q8_0_f32; + cl_kernel kernel_gemv_noshuffle_q4_k_f32; + cl_kernel kernel_gemm_noshuffle_q4_k_f32; + cl_kernel kernel_gemv_noshuffle_q6_K_f32; + cl_kernel kernel_gemm_noshuffle_q6_K_f32; + cl_kernel kernel_gemv_noshuffle_q5_k_f32; + cl_kernel kernel_gemm_noshuffle_q5_k_f32; + cl_kernel kernel_gemv_noshuffle_iq4_nl_f32; + cl_kernel kernel_gemm_noshuffle_iq4_nl_f32; #endif // GGML_OPENCL_USE_ADRENO_KERNELS void free() { @@ -781,6 +810,10 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve " -cl-mad-enable -cl-unsafe-math-optimizations" " -cl-finite-math-only -cl-fast-relaxed-math"; + if (backend_ctx->adreno_use_large_buffer) { + compile_opts += " -qcom-enable-large-buffer "; + } + GGML_LOG_INFO("ggml_opencl: loading OpenCL kernels"); // add @@ -917,8 +950,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q8_0_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0_trans", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_K", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_K", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_K_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_K_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_K", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_K", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_K_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_K_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q6_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q6_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_iq4_nl = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_iq4_nl", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_iq4_nl = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_iq4_nl", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_iq4_nl_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_iq4_nl_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_iq4_nl_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_iq4_nl_noshuffle", &err), err)); GGML_LOG_CONT("."); } @@ -1209,6 +1256,56 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mv_q4_k_f32_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_k_f32_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_k_f32_flat.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q4_K_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q4_K_f32_flat", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // mul_mv_q5_k_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q5_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q5_k_f32.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q5_K_f32 = clCreateKernel(prog, "kernel_mul_mv_q5_K_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // mul_mv_q5_k_f32_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q5_k_f32_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q5_k_f32_flat.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q5_K_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q5_K_f32_flat", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + } + // mul_mv_q6_k_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1274,6 +1371,40 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mv_iq4_nl_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_iq4_nl_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_iq4_nl_f32.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_iq4_nl_f32 = clCreateKernel(prog, "kernel_mul_mv_iq4_nl_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // mul_mv_iq4_nl_f32_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_iq4_nl_f32_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_iq4_nl_f32_flat.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_iq4_nl_f32_flat = clCreateKernel(prog, "kernel_mul_mv_iq4_nl_f32_flat", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mv_mxfp4_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1482,6 +1613,40 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mm_iq4_nl_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_iq4_nl_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_iq4_nl_f32_l4_lm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_iq4_nl_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_iq4_nl_f32_l4_lm", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // mul_mm_q4_k_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_q4_k_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_q4_k_f32_l4_lm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_q4_k_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q4_k_f32_l4_lm", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mm_q6_k_f32_l4_lm { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1499,6 +1664,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mm_q5_k_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_q5_k_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_q5_k_f32_l4_lm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_q5_k_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q5_k_f32_l4_lm", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mm_f16_f32_kq_kqv { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -2528,6 +2710,45 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // gemm_noshuffle_iq4_nl_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_noshuffle_iq4_nl_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_noshuffle_iq4_nl_f32.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_iq4_nl_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_iq4_nl_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle_iq4_nl_f32 + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable "; + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_noshuffle_iq4_nl_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_noshuffle_iq4_nl_f32.cl"); +#endif + + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_iq4_nl_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_iq4_nl_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mm_q8_0_f32_8x4 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -2568,6 +2789,45 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // gemm_noshuffle_q4_k_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_noshuffle_q4_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_noshuffle_q4_k_f32.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q4_k_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q4_k_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle_q4_k_f32 + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable "; + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_noshuffle_q4_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_noshuffle_q4_k_f32.cl"); +#endif + + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q4_k_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q4_k_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + std::string CL_moe_compile_opts = std::string("-cl-std=") + opencl_c_std + " -cl-mad-enable " " -cl-fast-relaxed-math"; @@ -2603,6 +2863,84 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemm_moe_mxfp4_f32, "kernel_gemm_moe_mxfp4_f32", &err), err)); GGML_LOG_CONT("."); } + + // gemv_noshuffle_q6_k_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_noshuffle_q6_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_noshuffle_q6_k_f32.cl"); +#endif + + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable "; + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q6_K_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q6_K_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // gemm_noshuffle_q6_k_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_noshuffle_q6_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_noshuffle_q6_k_f32.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q6_K_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q6_K_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle_q5_k_f32 + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable "; + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_noshuffle_q5_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_noshuffle_q5_k_f32.cl"); +#endif + + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q5_k_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q5_k_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_noshuffle_q5_k_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_noshuffle_q5_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_noshuffle_q5_k_f32.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q5_k_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q5_k_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } #endif // GGML_OPENCL_USE_ADRENO_KERNELS GGML_LOG_CONT("\n"); } @@ -2937,6 +3275,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { // Check if ext_buffer contains cl_khr_fp16 backend_ctx->fp16_support = strstr(ext_buffer, "cl_khr_fp16") != NULL; GGML_LOG_INFO("ggml_opencl: device FP16 support: %s\n", backend_ctx->fp16_support ? "true" : "false"); + // check Adreno large buffer support + backend_ctx->adreno_has_large_buffer = strstr(ext_buffer, "cl_qcom_large_buffer") != NULL; // fp16 is required if (!backend_ctx->fp16_support) { @@ -3003,6 +3343,18 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { GGML_LOG_INFO("ggml_opencl: using kernels optimized for Adreno (GGML_OPENCL_USE_ADRENO_KERNELS)\n"); #endif // GGML_OPENCL_USE_ADRENO_KERNELS + // determine whether to use large buffer for Adreno + backend_ctx->adreno_use_large_buffer = getenv("GGML_OPENCL_ADRENO_USE_LARGE_BUFFER") != nullptr && + backend_ctx->gpu_family == GPU_FAMILY::ADRENO; + if (backend_ctx->adreno_use_large_buffer) { + if (!backend_ctx->adreno_has_large_buffer) { + GGML_LOG_INFO("ggml_opencl: Adreno large buffer requested but not supported by driver, will use regular buffer\n"); + backend_ctx->adreno_use_large_buffer = false; + } else { + GGML_LOG_INFO("ggml_opencl: Adreno large buffer enabled\n"); + } + } + cl_int err; // A local ref of cl_context for convenience @@ -3347,31 +3699,141 @@ struct ggml_tensor_extra_cl_q8_0 { } }; -struct ggml_tensor_extra_cl_q6_K { - // Lower 4 bits of quantized weights. - cl_mem ql = nullptr; - // Upper 2 bits of quantized weights. - cl_mem qh = nullptr; - // Scales for each block. - cl_mem s = nullptr; - // Scales for each super block. - cl_mem d = nullptr; +struct ggml_tensor_extra_cl_iq4_nl { + cl_mem q = nullptr; + cl_mem q_img = nullptr; - size_t size_ql = 0; - size_t size_qh = 0; - size_t size_s = 0; - size_t size_d = 0; + cl_mem d = nullptr; + cl_mem d_img = nullptr; - ~ggml_tensor_extra_cl_q6_K() { + size_t size_q = 0; + size_t size_d = 0; + + ~ggml_tensor_extra_cl_iq4_nl() { reset(); } void reset() { - if (ql != nullptr) { - CL_CHECK(clReleaseMemObject(ql)); - ql = nullptr; - } - if (qh != nullptr) { + if (q != nullptr) { CL_CHECK(clReleaseMemObject(q)); q = nullptr; } + if (d != nullptr) { CL_CHECK(clReleaseMemObject(d)); d = nullptr; } + q_img = nullptr; + d_img = nullptr; + size_q = 0; + size_d = 0; + } +}; + +struct ggml_tensor_extra_cl_q4_K { + // Quantized values + cl_mem q = nullptr; + // Scales for each super block. + cl_mem s = nullptr; + // Scales + cl_mem d = nullptr; + // Min + cl_mem dm = nullptr; + + ~ggml_tensor_extra_cl_q4_K() { + reset(); + } + + void reset() { + if (q != nullptr) { + CL_CHECK(clReleaseMemObject(q)); + q = nullptr; + } + if (s != nullptr) { + CL_CHECK(clReleaseMemObject(s)); + s = nullptr; + } + if (d != nullptr) { + CL_CHECK(clReleaseMemObject(d)); + d = nullptr; + } + if (dm != nullptr) { + CL_CHECK(clReleaseMemObject(dm)); + dm = nullptr; + } + } +}; + +struct ggml_tensor_extra_cl_q5_K { + // Lower 4 bits of quantized weights. + cl_mem q = nullptr; + // Upper 1 bit of quantized weights. + cl_mem qh = nullptr; + // Scales for each block. + cl_mem s = nullptr; + // Scales for each super block. + cl_mem d = nullptr; + // Min for each super block. + cl_mem dm = nullptr; + + size_t size_q = 0; + size_t size_qh = 0; + size_t size_s = 0; + size_t size_d = 0; + size_t size_dm = 0; + + ~ggml_tensor_extra_cl_q5_K() { + reset(); + } + + void reset() { + if (q != nullptr) { + CL_CHECK(clReleaseMemObject(q)); + q = nullptr; + } + if (qh != nullptr) { + CL_CHECK(clReleaseMemObject(qh)); + qh = nullptr; + } + if (s != nullptr) { + CL_CHECK(clReleaseMemObject(s)); + s = nullptr; + } + if (d != nullptr) { + CL_CHECK(clReleaseMemObject(d)); + d = nullptr; + } + if (dm != nullptr) { + CL_CHECK(clReleaseMemObject(dm)); + dm = nullptr; + } + + size_q = 0; + size_qh = 0; + size_s = 0; + size_d = 0; + size_dm = 0; + } +}; + +struct ggml_tensor_extra_cl_q6_K { + // Lower 4 bits of quantized weights. + cl_mem ql = nullptr; + // Upper 2 bits of quantized weights. + cl_mem qh = nullptr; + // Scales for each block. + cl_mem s = nullptr; + // Scales for each super block. + cl_mem d = nullptr; + + size_t size_ql = 0; + size_t size_qh = 0; + size_t size_s = 0; + size_t size_d = 0; + + ~ggml_tensor_extra_cl_q6_K() { + reset(); + } + + void reset() { + if (ql != nullptr) { + CL_CHECK(clReleaseMemObject(ql)); + ql = nullptr; + } + if (qh != nullptr) { CL_CHECK(clReleaseMemObject(qh)); qh = nullptr; } @@ -3761,7 +4223,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return op->src[1]->type == GGML_TYPE_F32; } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 || op->src[0]->type == GGML_TYPE_MXFP4 || + op->src[0]->type == GGML_TYPE_IQ4_NL || op->src[0]->type == GGML_TYPE_Q4_K || + op->src[0]->type == GGML_TYPE_Q5_K || op->src[0]->type == GGML_TYPE_Q6_K) { return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); } else if (op->src[0]->type == GGML_TYPE_Q8_0) { @@ -3879,6 +4343,8 @@ static ggml_backend_i ggml_backend_opencl_i = { /* .free = */ ggml_backend_opencl_free, /* .set_tensor_async = */ NULL, /* ggml_backend_opencl_set_tensor_async */ /* .get_tensor_async = */ NULL, /* ggml_backend_opencl_get_tensor_async */ + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* ggml_backend_opencl_cpy_tensor_async */ /* .synchronize = */ ggml_backend_opencl_synchronize, /* .graph_plan_create = */ NULL, @@ -3956,12 +4422,30 @@ struct ggml_backend_opencl_buffer_context { for (ggml_tensor_extra_cl_q8_0 * e : temp_tensor_extras_q8_0_in_use) { delete e; } + for (ggml_tensor_extra_cl_iq4_nl * e : temp_tensor_extras_iq4_nl) { + delete e; + } + for (ggml_tensor_extra_cl_iq4_nl * e : temp_tensor_extras_iq4_nl_in_use) { + delete e; + } + for (ggml_tensor_extra_cl_q4_K * e : temp_tensor_extras_q4_K) { + delete e; + } + for (ggml_tensor_extra_cl_q4_K * e : temp_tensor_extras_q4_K_in_use) { + delete e; + } for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K) { delete e; } for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K_in_use) { delete e; } + for (ggml_tensor_extra_cl_q5_K * e : temp_tensor_extras_q5_K) { + delete e; + } + for (ggml_tensor_extra_cl_q5_K * e : temp_tensor_extras_q5_K_in_use) { + delete e; + } } ggml_tensor_extra_cl * ggml_opencl_alloc_temp_tensor_extra() { @@ -4039,6 +4523,51 @@ struct ggml_backend_opencl_buffer_context { return extra; } + ggml_tensor_extra_cl_iq4_nl * ggml_opencl_alloc_temp_tensor_extra_iq4_nl() { + ggml_tensor_extra_cl_iq4_nl * extra; + if (temp_tensor_extras_iq4_nl.empty()) { + extra = new ggml_tensor_extra_cl_iq4_nl(); + } else { + extra = temp_tensor_extras_iq4_nl.back(); + temp_tensor_extras_iq4_nl.pop_back(); + } + + temp_tensor_extras_iq4_nl_in_use.push_back(extra); + + extra->reset(); + return extra; + } + + ggml_tensor_extra_cl_q4_K * ggml_opencl_alloc_temp_tensor_extra_q4_K() { + ggml_tensor_extra_cl_q4_K * extra; + if (temp_tensor_extras_q4_K.empty()) { + extra = new ggml_tensor_extra_cl_q4_K(); + } else { + extra = temp_tensor_extras_q4_K.back(); + temp_tensor_extras_q4_K.pop_back(); + } + + temp_tensor_extras_q4_K_in_use.push_back(extra); + + extra->reset(); + return extra; + } + + ggml_tensor_extra_cl_q5_K * ggml_opencl_alloc_temp_tensor_extra_q5_K() { + ggml_tensor_extra_cl_q5_K * extra; + if (temp_tensor_extras_q5_K.empty()) { + extra = new ggml_tensor_extra_cl_q5_K(); + } else { + extra = temp_tensor_extras_q5_K.back(); + temp_tensor_extras_q5_K.pop_back(); + } + + temp_tensor_extras_q5_K_in_use.push_back(extra); + + extra->reset(); + return extra; + } + ggml_tensor_extra_cl_q6_K * ggml_opencl_alloc_temp_tensor_extra_q6_K() { ggml_tensor_extra_cl_q6_K * extra; if (temp_tensor_extras_q6_K.empty()) { @@ -4080,6 +4609,21 @@ struct ggml_backend_opencl_buffer_context { } temp_tensor_extras_q8_0_in_use.clear(); + for (ggml_tensor_extra_cl_iq4_nl * e : temp_tensor_extras_iq4_nl_in_use) { + temp_tensor_extras_iq4_nl.push_back(e); + } + temp_tensor_extras_iq4_nl_in_use.clear(); + + for (ggml_tensor_extra_cl_q4_K * e : temp_tensor_extras_q4_K_in_use) { + temp_tensor_extras_q4_K.push_back(e); + } + temp_tensor_extras_q4_K_in_use.clear(); + + for (ggml_tensor_extra_cl_q5_K * e : temp_tensor_extras_q5_K_in_use) { + temp_tensor_extras_q5_K.push_back(e); + } + temp_tensor_extras_q5_K_in_use.clear(); + for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K_in_use) { temp_tensor_extras_q6_K.push_back(e); } @@ -4101,6 +4645,12 @@ struct ggml_backend_opencl_buffer_context { std::vector temp_tensor_extras_mxfp4_in_use; std::vector temp_tensor_extras_q8_0; std::vector temp_tensor_extras_q8_0_in_use; + std::vector temp_tensor_extras_iq4_nl; + std::vector temp_tensor_extras_iq4_nl_in_use; + std::vector temp_tensor_extras_q4_K; + std::vector temp_tensor_extras_q4_K_in_use; + std::vector temp_tensor_extras_q5_K; + std::vector temp_tensor_extras_q5_K_in_use; std::vector temp_tensor_extras_q6_K; std::vector temp_tensor_extras_q6_K_in_use; @@ -4721,134 +5271,107 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, GGML_ASSERT(tensor->ne[2] == 1); GGML_ASSERT(tensor->ne[3] == 1); - // Transpose weights - size_t q_size_bytes = K * M / 4 * sizeof(float); - cl_buffer_region region; - region.origin = 0; - region.size = q_size_bytes; - cl_mem qT_d = clCreateSubBuffer( - backend_ctx->prealloc_quant_trans.buffer, - 0, - CL_BUFFER_CREATE_TYPE_REGION, - ®ion, - &err); - CL_CHECK(err); - - cl_mem q_d_image1D; - cl_mem qT_d_image1D; - - cl_image_format img_fmt_1d; - cl_image_desc img_desc_1d; - - img_fmt_1d = { CL_RGBA, CL_FLOAT }; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 4 / 4; - img_desc_1d.buffer = extra->q; - q_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); - CL_CHECK(err); - - img_fmt_1d = { CL_RGBA, CL_FLOAT }; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 4 / 4; - img_desc_1d.buffer = qT_d; - qT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); - CL_CHECK(err); - - int height_q = M / 4; - int width_q = K / 4 / 4; - kernel = backend_ctx->kernel_transpose_32; - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_d_image1D)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &qT_d_image1D)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_q)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_q)); + transpose_2d_as_32b(backend_ctx, extra->q, extra->q, size_q, K/4, M); + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/32, M); + } // end transpose +#endif // GGML_OPENCL_USE_ADRENO_KERNELS - size_t local_size_q[3] = {4, 16, 1}; - size_t global_size_q[3] = {static_cast(width_q), static_cast(height_q), 1}; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_q, local_size_q, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); + return; + } + if (tensor->type == GGML_TYPE_IQ4_NL) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tensors in OpenCL backend should have been allocated and initialized"); - // Transpose scales - size_t d_size_bytes = M * (K / 32) * 2; - region.origin = 0; - region.size = d_size_bytes; - cl_mem dT_d = clCreateSubBuffer( - backend_ctx->prealloc_scales_trans.buffer, - 0, - CL_BUFFER_CREATE_TYPE_REGION, - ®ion, - &err); - CL_CHECK(err); + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_iq4_nl * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_iq4_nl(); - cl_mem d_d_image1D; - cl_mem dT_d_image1D; + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*(ggml_blck_size(tensor->type)/2); + GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_fmt_1d = { CL_R, CL_HALF_FLOAT }; - img_desc_1d.image_width = M * K / 32; - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.buffer = extra->d; - d_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); - CL_CHECK(err); + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + CL_CHECK(clEnqueueWriteBuffer( + queue, data_device, CL_TRUE, 0, + ggml_nbytes(tensor), data, 0, NULL, NULL)); - img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT }; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 32 / 4; - img_desc_1d.buffer = dT_d; - dT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); - CL_CHECK(err); + cl_buffer_region region; - int height_s = M / 4; - int width_s = K / 32; + // Create subbuffer for scales. + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_d; + extra->d = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + auto previous_origin = region.origin; - kernel = backend_ctx->kernel_transpose_16_4x1; + // Create subbuffer for quants. + region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); + region.size = size_q; + extra->q = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &d_d_image1D)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &dT_d_image1D)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_s)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_s)); + #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_convert_block_iq4_nl; + if (use_adreno_kernels(backend_ctx, tensor)) { + kernel = backend_ctx->kernel_convert_block_iq4_nl_noshuffle; + } + #else + cl_kernel kernel = backend_ctx->kernel_convert_block_iq4_nl; + #endif + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; - size_t local_size_s[3] = {4, 16, 1}; - size_t global_size_s[3] = {static_cast(width_s), static_cast(height_s), 1}; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_s, local_size_s, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_uchar), &mask_F0)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &n_blk)); - // copy transposed buffer contents to original buffers - CL_CHECK(clEnqueueCopyBuffer(queue, qT_d, extra->q, 0, 0, q_size_bytes, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); + size_t global_work_size[] = {(size_t)CEIL_DIV(n_blk, 64)*64, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; - CL_CHECK(clEnqueueCopyBuffer(queue, dT_d, extra->d, 0, 0, d_size_bytes, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); - CL_CHECK(clReleaseMemObject(qT_d)); - CL_CHECK(clReleaseMemObject(dT_d)); + tensor->extra = extra; - CL_CHECK(clReleaseMemObject(q_d_image1D)); - CL_CHECK(clReleaseMemObject(d_d_image1D)); - CL_CHECK(clReleaseMemObject(qT_d_image1D)); - CL_CHECK(clReleaseMemObject(dT_d_image1D)); - } // end transpose -#endif // GGML_OPENCL_USE_ADRENO_KERNELS +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + int M = tensor->ne[1]; + int K = tensor->ne[0]; + GGML_ASSERT(K % 32 == 0); + // Transpose q as ushort + transpose_2d_as_16b(backend_ctx, extra->q, extra->q, size_q, K/4, M); + // Transpose d as ushort + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/32, M); + } +#endif return; } - if (tensor->type == GGML_TYPE_Q6_K) { + if (tensor->type == GGML_TYPE_Q4_K) { ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); // Allocate the new extra and create aliases from the original. ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; - ggml_tensor_extra_cl_q6_K * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q6_K(); + ggml_tensor_extra_cl_q4_K * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q4_K(); - size_t size_ql = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; - size_t size_qh = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/4; - size_t size_s = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/16; - size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); - GGML_ASSERT(size_ql + size_qh + size_s + size_d == ggml_nbytes(tensor) && - "Incorrect tensor size"); + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_dm = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_s = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*(3 * ggml_blck_size(tensor->type) / 64); + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + GGML_ASSERT(size_d + size_dm + size_s + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); cl_int err; cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, @@ -4860,26 +5383,26 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, cl_buffer_region region; - // Subbuffer for ql + // Create subbuffer for d. region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); - region.size = size_ql; - extra->ql = clCreateSubBuffer( + region.size = size_d; + extra->d = clCreateSubBuffer( extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); CL_CHECK(err); auto previous_origin = region.origin; - // Subbuffer for qh - region.origin = align_to(previous_origin + size_ql, backend_ctx->alignment); - region.size = size_qh; - extra->qh = clCreateSubBuffer( + // Create subbuffer for mins. + region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); + region.size = size_dm; + extra->dm = clCreateSubBuffer( extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); CL_CHECK(err); previous_origin = region.origin; - // Subbuffer for scales - region.origin = align_to(previous_origin + size_qh, backend_ctx->alignment); + // Create subbuffer for s. + region.origin = align_to(previous_origin + size_dm, backend_ctx->alignment); region.size = size_s; extra->s = clCreateSubBuffer( extra_orig->data_device, CL_MEM_READ_WRITE, @@ -4887,23 +5410,144 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_CHECK(err); previous_origin = region.origin; - // Create subbuffer for d. + // Create subbuffer for quants. region.origin = align_to(previous_origin + size_s, backend_ctx->alignment); + region.size = size_q; + extra->q = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_K; + if (use_adreno_kernels(backend_ctx, tensor)) { + kernel = backend_ctx->kernel_convert_block_q4_K_noshuffle; + } + #else + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_K; + #endif + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + + int M = tensor->ne[1]; + int K = tensor->ne[0]; + + GGML_ASSERT(K % 32 == 0); + + // Transpose q, d, dm as ushort + transpose_2d_as_16b(backend_ctx, extra->q, extra->q, size_q, K/4, M); + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/256, M); + transpose_2d_as_16b(backend_ctx, extra->dm, extra->dm, size_dm, K/256, M); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + return; + } + if (tensor->type == GGML_TYPE_Q5_K) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + + // Allocate the new extra and create aliases from the original. + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_q5_K * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q5_K(); + + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + size_t size_qh = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/8; + size_t size_s = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*(3*ggml_blck_size(tensor->type)/64); + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_dm = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + GGML_ASSERT(size_q + size_qh + size_s + size_d + size_dm == ggml_nbytes(tensor) && + "Incorrect tensor size"); + + cl_int err; + cl_mem data_device; + CL_CHECK((data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, ggml_nbytes(tensor), NULL, &err), err)); + CL_CHECK(clEnqueueWriteBuffer(queue, data_device, CL_TRUE, 0, ggml_nbytes(tensor), data, 0, NULL, NULL)); + + cl_buffer_region region; + + // Create subbuffer for d. + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); region.size = size_d; extra->d = clCreateSubBuffer( extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); CL_CHECK(err); + auto previous_origin = region.origin; + + // Create subbuffer for dm. + region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); + region.size = size_dm; + extra->dm = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); previous_origin = region.origin; - // Flatten the weights - cl_kernel kernel = backend_ctx->kernel_convert_block_q6_K; + // Create subbuffer for s. + region.origin = align_to(previous_origin + size_dm, backend_ctx->alignment); + region.size = size_s; + extra->s = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->ql)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->d)); + // Create subbuffer for q (lower 4 bits) + region.origin = align_to(previous_origin + size_s, backend_ctx->alignment); + region.size = size_q; + extra->q = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Create subbuffer for qh (upper 1 bit) + region.origin = align_to(previous_origin + size_q, backend_ctx->alignment); + region.size = size_qh; + CL_CHECK((extra->qh = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + CL_CHECK(err); + + #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_K; + if (use_adreno_kernels(backend_ctx, tensor)) { + kernel = backend_ctx->kernel_convert_block_q5_K_noshuffle; + } + #else + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_K; + #endif + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_F0)); size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; size_t local_work_size[] = {64, 1, 1}; @@ -4913,12 +5557,134 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clWaitForEvents(1, &evt)); CL_CHECK(clReleaseMemObject(data_device)); + extra->size_q = size_q; + extra->size_qh = size_qh; + extra->size_s = size_s; + extra->size_d = size_d; + extra->size_dm = size_dm; + + tensor->extra = extra; +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + + int M = tensor->ne[1]; + int K = tensor->ne[0]; + + GGML_ASSERT(K % 32 == 0); + + // Transpose q, d, dm as ushort, qh as uchar + transpose_2d_as_16b(backend_ctx, extra->q, extra->q, size_q, K/4, M); + transpose_2d_as_8b (backend_ctx, extra->qh, extra->qh, size_qh, K/8, M); + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/256, M); + transpose_2d_as_16b(backend_ctx, extra->dm, extra->dm, size_dm, K/256, M); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + return; + } + if (tensor->type == GGML_TYPE_Q6_K) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + + // Allocate the new extra and create aliases from the original. + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_q6_K * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q6_K(); + + size_t size_ql = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + size_t size_qh = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/4; + size_t size_s = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/16; + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + GGML_ASSERT(size_ql + size_qh + size_s + size_d == ggml_nbytes(tensor) && + "Incorrect tensor size"); + + cl_int err; + cl_mem data_device; + CL_CHECK((data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, ggml_nbytes(tensor), NULL, &err), err)); + CL_CHECK(clEnqueueWriteBuffer(queue, data_device, CL_TRUE, 0, ggml_nbytes(tensor), data, 0, NULL, NULL)); + + cl_buffer_region region; + + // Subbuffer for ql + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_ql; + CL_CHECK((extra->ql = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + auto previous_origin = region.origin; + + // Subbuffer for qh + region.origin = align_to(previous_origin + size_ql, backend_ctx->alignment); + region.size = size_qh; + CL_CHECK((extra->qh = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + previous_origin = region.origin; + + // Subbuffer for scales + region.origin = align_to(previous_origin + size_qh, backend_ctx->alignment); + region.size = size_s; + CL_CHECK((extra->s = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + previous_origin = region.origin; + + // Create subbuffer for d. + region.origin = align_to(previous_origin + size_s, backend_ctx->alignment); + region.size = size_d; + CL_CHECK((extra->d = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + previous_origin = region.origin; + + // Flatten the weights + cl_kernel kernel; +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + kernel = backend_ctx->kernel_convert_block_q6_K; + if (use_adreno_kernels(backend_ctx, tensor)) { + kernel = backend_ctx->kernel_convert_block_q6_K_noshuffle; + } +#else + kernel = backend_ctx->kernel_convert_block_q6_K; +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + cl_uchar mask = 0xff; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->ql)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)CEIL_DIV(n_blk, 64)*64, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + extra->size_ql = size_ql; extra->size_qh = size_qh; extra->size_s = size_s; extra->size_d = size_d; tensor->extra = extra; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + cl_int M = tensor->ne[1]; // ne01 + cl_int K = tensor->ne[0]; // ne00 + + // Transpose ql as ushort + transpose_2d_as_16b(backend_ctx, + extra->ql, extra->ql, size_ql, K/4, M); + + // Transpose qh as uchar + transpose_2d_as_8b(backend_ctx, + extra->qh, extra->qh, size_qh, K/4, M); + + // Transpose s as ushort + transpose_2d_as_16b(backend_ctx, + extra->s, extra->s, size_s, K/16/2, M); + + // Transpose d as ushort + transpose_2d_as_16b(backend_ctx, + extra->d, extra->d, size_d, K/256, M); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS return; } #endif // GGML_OPENCL_SOA_Q @@ -5245,22 +6011,66 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } - if (tensor->type == GGML_TYPE_Q6_K) { - ggml_tensor_extra_cl_q6_K * extra = (ggml_tensor_extra_cl_q6_K *)tensor->extra; + if (tensor->type == GGML_TYPE_IQ4_NL) { + ggml_tensor_extra_cl_iq4_nl * extra = (ggml_tensor_extra_cl_iq4_nl *)tensor->extra; cl_int err; cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, ggml_nbytes(tensor), NULL, &err); CL_CHECK(err); - cl_kernel kernel = backend_ctx->kernel_restore_block_q6_K; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->ql)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + static ggml_cl_buffer buf_trans_q; + static ggml_cl_buffer buf_trans_d; + static ggml_cl_buffer buf_unpacked; - size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + cl_int M = tensor->ne[1]; + cl_int K = tensor->ne[0]; + GGML_ASSERT(K % 32 == 0); + + size_t size_q = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*(ggml_blck_size(tensor->type)/2); + size_t size_d = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t); + GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + + buf_trans_q.allocate(backend_ctx->context, size_q); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_unpacked.allocate(backend_ctx->context, ggml_nbytes(tensor)); + + // transpose q, d back + transpose_2d_as_16b(backend_ctx, extra->q, buf_trans_q.buffer, size_q, M, K/4); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/32); + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + + cl_kernel kernel = backend_ctx->kernel_restore_block_iq4_nl_noshuffle; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_q.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_unpacked.buffer)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_uchar), &mask_F0)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)n_blk, 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, buf_unpacked.buffer, CL_TRUE, offset, size, data, 0, NULL, NULL)); + return; + } +#endif + cl_kernel kernel = backend_ctx->kernel_restore_block_iq4_nl; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)n_blk, 1, 1}; size_t local_work_size[] = {1, 1, 1}; cl_event evt; @@ -5273,41 +6083,288 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } -#endif // GGML_OPENCL_SOA_Q + if (tensor->type == GGML_TYPE_Q4_K) { + ggml_tensor_extra_cl_q4_K * extra = (ggml_tensor_extra_cl_q4_K *)tensor->extra; - ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); - CL_CHECK(clEnqueueReadBuffer( - queue, extra->data_device, CL_TRUE, extra->offset + tensor->view_offs + offset, - size, data, 0, NULL, NULL)); + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; - GGML_UNUSED(buffer); -} +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + int M = tensor->ne[1]; + int K = tensor->ne[0]; -static void ggml_backend_opencl_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { - ggml_backend_dev_t dev = buffer->buft->device; - ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev); - cl_command_queue queue = backend_ctx->queue; + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_dm = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); - ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; - for (cl_mem buf : ctx->buffer) { - CL_CHECK(clEnqueueFillBuffer(queue, buf, &value, sizeof(value), 0, buffer->size, 0, NULL, NULL)); - } - CL_CHECK(clFinish(queue)); -} + static ggml_cl_buffer buf_trans_q; + static ggml_cl_buffer buf_trans_d; + static ggml_cl_buffer buf_trans_dm; -static void ggml_backend_opencl_buffer_reset(ggml_backend_buffer_t buffer) { - ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; - ctx->reset(); -} + buf_trans_q.allocate(backend_ctx->context, size_q); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_trans_dm.allocate(backend_ctx->context, size_dm); + + // Transpose q, d, dm back + transpose_2d_as_16b(backend_ctx, extra->q, buf_trans_q.buffer, size_q, M, K/4); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/256); + transpose_2d_as_16b(backend_ctx, extra->dm, buf_trans_dm.buffer, size_dm, M, K/256); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_K_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_q.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &buf_trans_dm.buffer)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_F0)); -static ggml_backend_buffer_i ggml_backend_opencl_buffer_interface = { - /* .free_buffer = */ ggml_backend_opencl_buffer_free_buffer, - /* .get_base = */ ggml_backend_opencl_buffer_get_base, - /* .init_tensor = */ ggml_backend_opencl_buffer_init_tensor, - /* .memset_tensor = */ NULL, - /* .set_tensor = */ ggml_backend_opencl_buffer_set_tensor, - /* .get_tensor = */ ggml_backend_opencl_buffer_get_tensor, + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_K; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } + if (tensor->type == GGML_TYPE_Q5_K) { + ggml_tensor_extra_cl_q5_K * extra = (ggml_tensor_extra_cl_q5_K *)tensor->extra; + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + int M = tensor->ne[1]; + int K = tensor->ne[0]; + + size_t size_q = extra->size_q; + size_t size_qh = extra->size_qh; + size_t size_d = extra->size_d; + size_t size_dm = extra->size_dm; + + static ggml_cl_buffer buf_trans_q; + static ggml_cl_buffer buf_trans_qh; + static ggml_cl_buffer buf_trans_d; + static ggml_cl_buffer buf_trans_dm; + + buf_trans_q.allocate(backend_ctx->context, size_q); + buf_trans_qh.allocate(backend_ctx->context, size_qh); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_trans_dm.allocate(backend_ctx->context, size_dm); + + // Reverse transpose q, qh, d, dm + transpose_2d_as_16b(backend_ctx, extra->q, buf_trans_q.buffer, size_q, M, K/4); + transpose_2d_as_8b (backend_ctx, extra->qh, buf_trans_qh.buffer, size_qh, M, K/8); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/256); + transpose_2d_as_16b(backend_ctx, extra->dm, buf_trans_dm.buffer, size_dm, M, K/256); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_K_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_q.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_qh.buffer)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &buf_trans_dm.buffer)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_K; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } + if (tensor->type == GGML_TYPE_Q6_K) { + ggml_tensor_extra_cl_q6_K * extra = (ggml_tensor_extra_cl_q6_K *)tensor->extra; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + static ggml_cl_buffer buf_trans_ql; + static ggml_cl_buffer buf_trans_qh; + static ggml_cl_buffer buf_trans_s; + static ggml_cl_buffer buf_trans_d; + static ggml_cl_buffer buf_unpacked; + + cl_int M = tensor->ne[1]; // ne01 + cl_int K = tensor->ne[0]; // ne00 + + GGML_ASSERT(K % ggml_blck_size(tensor->type) == 0); + + size_t size_ql = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + size_t size_qh = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/4; + size_t size_s = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/16; + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + GGML_ASSERT(size_ql + size_qh + size_s + size_d == ggml_nbytes(tensor) && "Incorrect tensor size"); + + buf_trans_ql.allocate(backend_ctx->context, size_ql); + buf_trans_qh.allocate(backend_ctx->context, size_qh); + buf_trans_s.allocate(backend_ctx->context, size_s); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_unpacked.allocate(backend_ctx->context, ggml_nbytes(tensor)); + + // transpose ql, qh, s and d back + transpose_2d_as_16b(backend_ctx, extra->ql, buf_trans_ql.buffer, size_ql, M, K/4); + transpose_2d_as_8b(backend_ctx, extra->qh, buf_trans_qh.buffer, size_qh, M, K/4); + transpose_2d_as_16b(backend_ctx, extra->s, buf_trans_s.buffer, size_s, M, K/16/2); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/256); + + // unpack + cl_uchar mask = 0xFF; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + cl_kernel kernel = backend_ctx->kernel_restore_block_q6_K_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_ql.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_qh.buffer)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_trans_s.buffer)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &buf_unpacked.buffer)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)n_blk, 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer(queue, buf_unpacked.buffer, CL_TRUE, offset, size, data, 0, NULL, NULL)); + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_uchar mask = 0xFF; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + cl_kernel kernel = backend_ctx->kernel_restore_block_q6_K; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->ql)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)n_blk, 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } +#endif // GGML_OPENCL_SOA_Q + + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; + + CL_CHECK(clEnqueueReadBuffer( + queue, extra->data_device, CL_TRUE, extra->offset + tensor->view_offs + offset, + size, data, 0, NULL, NULL)); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_opencl_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + ggml_backend_dev_t dev = buffer->buft->device; + ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev); + cl_command_queue queue = backend_ctx->queue; + + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + for (cl_mem buf : ctx->buffer) { + CL_CHECK(clEnqueueFillBuffer(queue, buf, &value, sizeof(value), 0, buffer->size, 0, NULL, NULL)); + } + CL_CHECK(clFinish(queue)); +} + +static void ggml_backend_opencl_buffer_reset(ggml_backend_buffer_t buffer) { + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ctx->reset(); +} + +static ggml_backend_buffer_i ggml_backend_opencl_buffer_interface = { + /* .free_buffer = */ ggml_backend_opencl_buffer_free_buffer, + /* .get_base = */ ggml_backend_opencl_buffer_get_base, + /* .init_tensor = */ ggml_backend_opencl_buffer_init_tensor, + /* .memset_tensor = */ NULL, + /* .set_tensor = */ ggml_backend_opencl_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_opencl_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ NULL, /* .clear = */ ggml_backend_opencl_buffer_clear, /* .reset = */ ggml_backend_opencl_buffer_reset, @@ -5331,6 +6388,11 @@ static ggml_backend_buffer_t ggml_backend_opencl_buffer_type_alloc_buffer(ggml_b cl_int err; cl_mem mem = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, size, NULL, &err); + if (err != CL_SUCCESS && backend_ctx->adreno_use_large_buffer) { + cl_mem_properties props[] = { 0x41A6 /* CL_LARGE_BUFFER_QCOM */, 1, 0 }; + mem = clCreateBufferWithProperties(backend_ctx->context, props, CL_MEM_READ_WRITE, size, NULL, &err); + } + if (err != CL_SUCCESS) { GGML_LOG_INFO("%s: failed to allocate %.2f MiB\n", __func__, size / 1024.0 / 1024.0); return nullptr; @@ -5553,6 +6615,8 @@ typedef struct { static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); +#define QK_MXFP4 32 + #include #ifdef __cplusplus #include "half.hpp" @@ -5596,7 +6660,7 @@ static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tenso buf_d = malloc(size_e); CL_CHECK(clEnqueueReadBuffer(queue, extra->q, CL_TRUE, 0, size_q, buf_q, 0, NULL, NULL)); - CL_CHECK(clEnqueueReadBuffer(queue, extra->d, CL_TRUE, 0, size_e, buf_d, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, extra->e, CL_TRUE, 0, size_e, buf_d, 0, NULL, NULL)); CL_CHECK(clFinish(queue)); } else { // Read out the tensor from GPU memory. @@ -9084,7 +10148,7 @@ static void ggml_cl_mul_mat_q4_1_f32_adreno(ggml_backend_t backend, const ggml_t #endif } -static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cl_mul_mat_iq4_nl_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { #ifdef GGML_OPENCL_USE_ADRENO_KERNELS GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -9093,181 +10157,119 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - const enum ggml_type src0t = src0->type; - const enum ggml_type src1t = src1->type; - - GGML_ASSERT(src0t == GGML_TYPE_Q8_0); - GGML_ASSERT(src1t == GGML_TYPE_F32); - ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl_iq4_nl * extra0_iq4_nl = (ggml_tensor_extra_cl_iq4_nl *)src0->extra; - ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; - - GGML_ASSERT(src1->view_offs == 0); - GGML_ASSERT(dst->view_offs == 0); + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - - const int ne10 = src1->ne[0]; - const int ne12 = src1->ne[2]; - const int ne0 = dst->ne[0]; const int ne1 = dst->ne[1]; - GGML_ASSERT(ne00 == ne10); - GGML_ASSERT((ne00 % 32) == 0); - GGML_ASSERT(ne0 == ne01); + GGML_ASSERT(ne00 % 32 == 0); cl_context context = backend_ctx->context; cl_kernel kernel; - // init CL objects - cl_int status; - cl_image_format img_fmt_1d; - cl_image_desc img_desc_1d; + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; cl_buffer_region region; - cl_mem A_image1d; - cl_mem B_image1d; - cl_mem B_sub_buffer; - cl_mem S_image1d; - - cl_mem D_image1d; - cl_mem D_sub_buffer; int M = ne01; int N = ne1; int K = ne00; - // create an image for A - img_fmt_1d = { CL_R, CL_FLOAT}; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 4; // Divide by 4 for char -> float - img_desc_1d.buffer = extra0_q8_0->q; - A_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); - CL_CHECK(status); - - // create an image for Scale - img_fmt_1d = { CL_R, CL_HALF_FLOAT}; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 32; // Block size is 32 - img_desc_1d.buffer = extra0_q8_0->d; - S_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); - CL_CHECK(status); - - // create a sub_buffer for B - region.origin = (extra1->offset); // + src1->view_offs); - region.size = K * N * sizeof(float); - B_sub_buffer = clCreateSubBuffer((extra1->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); - CL_CHECK(status); - - // create an image for B from sub_buffer: RGBA (OCL) - img_fmt_1d = {CL_RGBA, CL_FLOAT}; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = K * N / 4; - img_desc_1d.buffer = B_sub_buffer; - B_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); - CL_CHECK(status); + if (ne1 == 1) { + cl_mem q_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; - // Create subbuffer and image1d_buffer for dst - region.origin = (extrad->offset); // + dst->view_offs; - region.size = M * N * sizeof(float); - D_sub_buffer = clCreateSubBuffer((extrad->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); - CL_CHECK(status); + // image for q + img_fmt = { CL_R, CL_UNSIGNED_INT32}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 2 / 4; + img_desc.buffer = extra0_iq4_nl->q; + CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); - img_fmt_1d = {CL_R, CL_FLOAT}; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * N; - img_desc_1d.buffer = D_sub_buffer; - D_image1d = clCreateImage(context, CL_MEM_WRITE_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); - CL_CHECK(status); + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); - size_t local_work_size[3] = {1, 1, 1}; - size_t global_work_size[3] = {1, 1, 1}; + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); - if (N == 1) { - kernel = backend_ctx->CL_mul_mat_vec_q8_0_f32; + kernel = backend_ctx->kernel_gemv_noshuffle_iq4_nl_f32; - int r2 = 1; - int r3 = 1; - cl_uint k_arg = 0; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_iq4_nl->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &A_image1d)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &extra0_q8_0->d)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &B_image1d)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_ulong), &extra1->offset)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_ulong), &extrad->offset)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne0)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne1)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &r2)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &r3)); + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; - size_t wavesize = backend_ctx->adreno_wave_size; - local_work_size[0] = wavesize; - local_work_size[1] = 4; // reduce factor - local_work_size[2] = 1; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); - global_work_size[0] = ((M + wavesize - 1) / wavesize) * wavesize; - global_work_size[1] = 4; // reduce factor - global_work_size[2] = 1; + CL_CHECK(clReleaseMemObject(q_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); } else { - cl_ulong offsetd = extrad->offset + dst->view_offs; - cl_mem B_image1d_trans = nullptr; - // for B transpose - cl_mem B_d = nullptr; - int padding; - - //how many extra elements beyond multiple of 8 - int extra_elements = N % 8; + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; - //how much padding to add - padding = 0; + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; if (extra_elements > 0){ padding = 8 - extra_elements; } - // Specify the starting offset (in bytes) + // subbuffer for transposed activations region.origin = 0; - // Specify the size of the sub-buffer (divide by 2 for FP16) region.size = K * (N + padding) * sizeof(float)/2; backend_ctx->prealloc_act_trans.allocate(context, region.size); - B_d = clCreateSubBuffer( - backend_ctx->prealloc_act_trans.buffer, - 0, - CL_BUFFER_CREATE_TYPE_REGION, - ®ion, - &status); - CL_CHECK(status); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); - cl_image_format image_format_B_d_output = { CL_RGBA, CL_HALF_FLOAT }; //(CL_HALF_FLOAT for FP16) - cl_image_desc image_desc_B_d_output = { - CL_MEM_OBJECT_IMAGE1D_BUFFER, - static_cast(K * (N + padding)/4), - 0, 0, 0, 0, 0, 0, 0, { B_d } - }; - B_image1d_trans = clCreateImage( - context, - 0, - &image_format_B_d_output, - &image_desc_B_d_output, - NULL, - &status); - CL_CHECK(status); + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + // transpose activations int height_B = N/4; if (height_B == 0) { height_B = 1; @@ -9276,53 +10278,808 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t int padded_height_B = (N + padding)/4; kernel = backend_ctx->kernel_transpose_32_16; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &B_image1d)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &B_image1d_trans)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); - size_t local_size_t[2] = { 1, 16 }; - size_t global_size_t[2] = { - static_cast(width_B), - static_cast(padded_height_B) - }; + size_t local_work_size_t[2] = { 1, 16 }; + size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); - backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_size_t, local_size_t, dst); + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_iq4_nl_f32; + int padded_N = N + padding; - kernel = backend_ctx->kernel_mul_mm_q8_0_f32_8x4; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_iq4_nl->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_iq4_nl->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne1)); + + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {1, 128, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_img_trans)); + } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif +} + +static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + GGML_ASSERT(src0->type == GGML_TYPE_Q8_0); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + GGML_ASSERT(src1->view_offs == 0); + GGML_ASSERT(dst->view_offs == 0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + + const int ne10 = src1->ne[0]; + const int ne12 = src1->ne[2]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + + GGML_ASSERT(ne00 == ne10); + GGML_ASSERT((ne00 % 32) == 0); + GGML_ASSERT(ne0 == ne01); + + cl_context context = backend_ctx->context; + cl_kernel kernel; + + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; + cl_buffer_region region; + + int M = ne01; + int N = ne1; + int K = ne00; + + if (ne1 == 1) { + cl_mem q_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; + + // image for q + img_fmt = { CL_R, CL_UNSIGNED_INT32}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 4; + img_desc.buffer = extra0_q8_0->q; + CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // create a sub_buffer for B + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer((extra1->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + kernel = backend_ctx->CL_mul_mat_vec_q8_0_f32; + + int r2 = 1; + int r3 = 1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q8_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &extra1->offset)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &extrad->offset)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); + + size_t wavesize = backend_ctx->adreno_wave_size; + size_t local_work_size[] = { wavesize, 4, 1 }; + size_t global_work_size[] = { CEIL_DIV(M, wavesize)*wavesize, 4, 1 }; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(q_img)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + } else { + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } + + // subbuffer for transposed activations + region.origin = 0; + region.size = K * (N + padding) * sizeof(float)/2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // transpose activations + int height_B = N/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = K/4; + int padded_height_B = (N + padding)/4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_work_size_t[2] = { 1, 16 }; + size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); - int N_with_padding = N + padding; + // gemm + kernel = backend_ctx->kernel_mul_mm_q8_0_f32_8x4; + int padded_N = N + padding; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q8_0->q)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q8_0->d)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &B_image1d_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img_trans)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &K)); CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &M)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &N_with_padding)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &padded_N)); CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &N)); CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &offsetd)); - global_work_size[0] = (size_t)(N + 7) / 8; - global_work_size[1] = (size_t)(M + 3) / 4; - global_work_size[2] = 1; + size_t global_work_size[] = { (size_t)CEIL_DIV(N, 8), (size_t)CEIL_DIV(M, 4), 1 }; + size_t local_work_size[] = { 2, 128, 1 }; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(b_img_trans)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif +} + +static void ggml_cl_mul_mat_q4_k_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl_q4_K * extra0_q4_k = (ggml_tensor_extra_cl_q4_K *)src0->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + + const int ne1 = dst->ne[1]; - local_work_size[0] = 2; - local_work_size[1] = 128; - local_work_size[2] = 1; + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + + cl_context context = backend_ctx->context; + cl_kernel kernel; + + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; + cl_buffer_region region; + + int M = ne01; + int N = ne1; + int K = ne00; + + cl_uchar mask_d6 = 0x3F; + cl_uchar mask_d4 = 0x0F; + cl_uchar mask_hi2 = 0xC0; + + if (ne1 == 1) { + cl_mem q_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; + + // image for q + img_fmt = { CL_R, CL_UNSIGNED_INT32}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 2 / 4; + img_desc.buffer = extra0_q4_k->q; + CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + kernel = backend_ctx->kernel_gemv_noshuffle_q4_k_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_k->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_k->dm)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q4_k->s)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_uchar), &mask_d6)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_uchar), &mask_d4)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_uchar), &mask_hi2)); + + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(q_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); + } else { + + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } + + // subbuffer for transposed activations + region.origin = 0; + region.size = K * (N + padding) * sizeof(float)/2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // transpose activations + int height_B = N/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = K/4; + int padded_height_B = (N + padding)/4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_work_size_t[2] = { 1, 16 }; + size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_q4_k_f32; + int padded_N = N + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_k->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_k->s)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_k->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q4_k->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_uchar), &mask_d6)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_uchar), &mask_d4)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_uchar), &mask_hi2)); + + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {1, 128, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_img_trans)); + } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif +} + +static void ggml_cl_mul_mat_q6_K_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + + const int ne1 = dst->ne[1]; + + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + + cl_context context = backend_ctx->context; + cl_kernel kernel; + + cl_int err; + cl_buffer_region region; + cl_image_format img_fmt; + cl_image_desc img_desc; + + // subbuffer and image for activation + if (ne1 == 1) { + cl_mem ql_img = nullptr; + cl_mem qh_img = nullptr; + cl_mem b_sub_buffer = nullptr; + cl_mem b_img = nullptr; + + // image for ql + img_fmt.image_channel_order = CL_R; + img_fmt.image_channel_data_type = CL_FLOAT; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = ne01 * ne00 / 8; + img_desc.buffer = extra0_q6_K->ql; + CL_CHECK((ql_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // image for qh + img_fmt.image_channel_order = CL_R; + img_fmt.image_channel_data_type = CL_HALF_FLOAT; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = ne01 * ne00 / 8; + img_desc.buffer = extra0_q6_K->qh; + CL_CHECK((qh_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + region.origin = offset1; + region.size = ne00 * ne1 * sizeof(float); + CL_CHECK((b_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + img_fmt.image_channel_order = CL_RGBA; + img_fmt.image_channel_data_type = CL_FLOAT; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = ne00 * ne1 / 4; + img_desc.buffer = b_sub_buffer; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + kernel = backend_ctx->kernel_gemv_noshuffle_q6_K_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &ql_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &qh_img)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q6_K->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q6_K->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne01)); + + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(ql_img)); + CL_CHECK(clReleaseMemObject(qh_img)); + CL_CHECK(clReleaseMemObject(b_sub_buffer)); + CL_CHECK(clReleaseMemObject(b_img)); + } else { + cl_mem b_sub_buf; + cl_mem b_buf_trans; + cl_mem b_img; + cl_mem b_img_trans; + + // subbuffer for activation + region.origin = offset1; + region.size = ne00 * ne1 * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activation + img_fmt.image_channel_order = CL_RGBA; + img_fmt.image_channel_data_type = CL_FLOAT; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = ne00 * ne1 / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = ne1 % 8; + int padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } + + // subbuffer for transposed activation + region.origin = 0; + region.size = ne00 * (ne1 + padding) * sizeof(float)/2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activation + img_fmt.image_channel_order = CL_RGBA; + img_fmt.image_channel_data_type = CL_HALF_FLOAT; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = ne00 * (ne1 + padding) / 4; + img_desc.buffer = b_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // transpose activation + int height_B = ne1/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = ne00/4; + int padded_height_B = (ne1 + padding) / 4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_size_t[2] = { 1, 16 }; + size_t global_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_size_t, local_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_q6_K_f32; + int padded_N = ne1 + padding; + + cl_ushort mask_f000 = 0xF000; + cl_uchar mask_c0 = 0xC0; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q6_K->ql)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q6_K->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q6_K->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q6_K->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ushort),&mask_f000)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_uchar), &mask_c0)); + + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {2, 128, 1}; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img_trans)); + } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif +} + +static void ggml_cl_mul_mat_q5_K_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl_q5_K * extra0_q5_k = (ggml_tensor_extra_cl_q5_K *)src0->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne1 = dst->ne[1]; + + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + + cl_context context = backend_ctx->context; + cl_kernel kernel; + + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; + cl_buffer_region region; + + int M = ne01; + int N = ne1; + int K = ne00; + + cl_uchar mask_d6 = 0x3F; + cl_uchar mask_d4 = 0x0F; + cl_uchar mask_hi2 = 0xC0; + + if (ne1 == 1) { + cl_mem q_img = nullptr; + cl_mem qh_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; + + // image for q (CL_R, CL_UNSIGNED_INT32): width = M*K/2/4 + img_fmt = {CL_R, CL_UNSIGNED_INT32}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 2 / 4; + img_desc.buffer = extra0_q5_k->q; + CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // image for qh (CL_R, CL_HALF_FLOAT): width = M*K/16 + img_fmt = {CL_R, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 16; + img_desc.buffer = extra0_q5_k->qh; + CL_CHECK((qh_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations (CL_RGBA, CL_FLOAT): width = K*N/4 + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + kernel = backend_ctx->kernel_gemv_noshuffle_q5_k_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &qh_img)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_k->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_k->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra0_q5_k->s)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_uchar), &mask_d6)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_uchar), &mask_d4)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_uchar), &mask_hi2)); + + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(q_img)); + CL_CHECK(clReleaseMemObject(qh_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); + } else { + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; + if (extra_elements > 0) { + padding = 8 - extra_elements; + } + + // subbuffer for transposed activations + region.origin = 0; + region.size = K * (N + padding) * sizeof(float) / 2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // transpose activations + int height_B = N / 4; + if (height_B == 0) height_B = 1; + int width_B = K / 4; + int padded_height_B = (N + padding) / 4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_work_size_t[2] = {1, 16}; + size_t global_work_size_t[2] = {(size_t)width_B, (size_t)padded_height_B}; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_q5_k_f32; + int padded_N = N + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_k->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_k->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_k->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_k->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra0_q5_k->dm)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_uchar), &mask_d6)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_uchar), &mask_d4)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_uchar), &mask_hi2)); + + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {1, 128, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_img_trans)); } - - // enqueue kernel with profiling - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); - - // deallocate sub buffers and images - CL_CHECK(clReleaseMemObject(A_image1d)); - CL_CHECK(clReleaseMemObject(B_sub_buffer)); - CL_CHECK(clReleaseMemObject(B_image1d)); - CL_CHECK(clReleaseMemObject(S_image1d)); - CL_CHECK(clReleaseMemObject(D_sub_buffer)); - CL_CHECK(clReleaseMemObject(D_image1d)); #else GGML_UNUSED(backend); GGML_UNUSED(src0); @@ -9357,6 +11114,9 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; + ggml_tensor_extra_cl_iq4_nl * extra0_iq4_nl = (ggml_tensor_extra_cl_iq4_nl *)src0->extra; + ggml_tensor_extra_cl_q4_K * extra0_q4_K = (ggml_tensor_extra_cl_q4_K *)src0->extra; + ggml_tensor_extra_cl_q5_K * extra0_q5_K = (ggml_tensor_extra_cl_q5_K *)src0->extra; ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra; #endif @@ -9459,6 +11219,12 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co return; } + // iq4_nl x fp32 + if (src0t == GGML_TYPE_IQ4_NL && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_iq4_nl_f32_adreno(backend, src0, src1, dst); + return; + } + // q8_0 x fp32 if (src0t == GGML_TYPE_Q8_0 && src1t == GGML_TYPE_F32 && enable_adreno_trans_weight(backend_ctx, src0)) { @@ -9466,6 +11232,24 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co return; } + // q4_k x fp32 + if (src0t == GGML_TYPE_Q4_K && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_q4_k_f32_adreno(backend, src0, src1, dst); + return; + } + + // q6_K x fp32 + if (src0t == GGML_TYPE_Q6_K && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_q6_K_f32_adreno(backend, src0, src1, dst); + return; + } + + // q5_K x fp32 + if (src0t == GGML_TYPE_Q5_K && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_q5_K_f32_adreno(backend, src0, src1, dst); + return; + } + // q4_0 x fp32 if(src0t == GGML_TYPE_Q4_0 && src1t == GGML_TYPE_F32) { // TODO: remove duplicate definitions of image description + format -- move to top @@ -10005,6 +11789,137 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); return; } + case GGML_TYPE_IQ4_NL: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_iq4_nl_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_iq4_nl->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_iq4_nl->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } + case GGML_TYPE_Q4_K: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q4_k_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_K->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_K->s)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_K->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q4_K->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } + case GGML_TYPE_Q5_K: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q5_k_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_K->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_K->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_K->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_K->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra0_q5_K->dm)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } case GGML_TYPE_Q6_K: { if (ne11 < 32) { break; @@ -10443,12 +12358,113 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne1)); CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2)); CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q + break; + } + case GGML_TYPE_IQ4_NL: { +#ifdef GGML_OPENCL_SOA_Q + kernel = backend_ctx->kernel_mul_mv_iq4_nl_f32_flat; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 8; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 8; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_iq4_nl->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_iq4_nl->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); +#else + kernel = backend_ctx->kernel_mul_mv_iq4_nl_f32; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); #endif // GGML_OPENCL_SOA_Q break; } case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: { +#ifdef GGML_OPENCL_SOA_Q + kernel = backend_ctx->kernel_mul_mv_q4_K_f32_flat; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 2; + ndst = 16; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_K->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_K->s)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_K->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q4_K->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r3)); +#else kernel = backend_ctx->kernel_mul_mv_q4_K_f32; if (backend_ctx->gpu_family == INTEL) { @@ -10482,9 +12498,84 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne1)); CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2)); CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q + break; + } + case GGML_TYPE_Q5_K: { +#ifdef GGML_OPENCL_SOA_Q + kernel = backend_ctx->kernel_mul_mv_q5_K_f32_flat; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 2; + ndst = 16; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_K->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_K->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_K->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_K->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra0_q5_K->dm)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &r3)); +#else + kernel = backend_ctx->kernel_mul_mv_q5_K_f32; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(int), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q break; } - case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: #ifdef GGML_OPENCL_SOA_Q kernel = backend_ctx->kernel_mul_mv_q6_K_f32_flat; @@ -10633,6 +12724,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_MXFP4 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 || + src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_Q2_K) { // Each SIMD group produces N_DST values in the result. Assuming each // workgroup has N_SIMDGROUP SIMD groups, then each workgroup will @@ -10652,7 +12744,10 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co } else if (src0t == GGML_TYPE_Q3_K) { GGML_ASSERT(false && "not implemented"); } else if (src0t == GGML_TYPE_Q5_K) { - GGML_ASSERT(false && "not implemented"); + size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } else if (src0t == GGML_TYPE_Q6_K) { size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 78ef9c177f6..f3937d8304c 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -28,6 +28,7 @@ #define QK8_0 32 #define QR8_0 1 #define QK_K 256 +#define K_SCALE_SIZE (3 * QK_K / 64) #define K_QUANTS_PER_ITERATION 2 typedef char int8_t; @@ -55,6 +56,27 @@ struct block_q4_1 { uchar qs[QK4_1 / 2]; // nibbles / quants }; +//------------------------------------------------------------------------------ +// block_q4_k +//------------------------------------------------------------------------------ +struct block_q4_K { + half d; // delta + half dm; // min + uchar s[K_SCALE_SIZE]; + uchar q[QK_K / 2]; // nibbles / quants +}; + +//------------------------------------------------------------------------------ +// block_q5_k +//------------------------------------------------------------------------------ +struct block_q5_K { + half d; // delta + half dm; // min + uchar s[K_SCALE_SIZE]; + uchar qh[QK_K / 8]; + uchar qs[QK_K / 2]; // nibbles / quants +}; + //------------------------------------------------------------------------------ // block_q6_K //------------------------------------------------------------------------------ @@ -65,6 +87,17 @@ struct block_q6_K { half d; // super-block scale }; +//------------------------------------------------------------------------------ +// block_iq4_nl +//------------------------------------------------------------------------------ +#define QK4_NL 32 + +struct block_iq4_nl +{ + half d; + uint8_t qs[QK4_NL / 2]; +}; + //------------------------------------------------------------------------------ // kernel_convert_block_q4_0 // Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA). @@ -408,6 +441,288 @@ kernel void kernel_restore_block_q8_0_trans( } } +//------------------------------------------------------------------------------ +// kernel_convert_block_q4_K +// Convert the block_q4_K format to 4 separate arrays (AOS -> SOA). +// This kernel does not deshuffle the bits. +// Each thread processes a super block. +// Mask args are just to keep the signature consistent with the no-shuffle +// version and they are not used in this kernel. +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_q4_K( + global struct block_q4_K * src0, + global uchar * dst_q, + global uchar * dst_s, + global half * dst_d, + global half * dst_dm, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q4_K * b = (global struct block_q4_K *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK_K/2*get_global_id(0); + global uchar * s = (global uchar *) dst_s + K_SCALE_SIZE*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * dm = (global half *) dst_dm + get_global_id(0); + + *d = b->d; + *dm = b->dm; + + for (int i = 0; i < QK_K/2; ++i) { + q[i] = b->q[i]; + } + for (int i = 0; i < K_SCALE_SIZE; ++i) { + s[i] = b->s[i]; + } +} + +// Restore block_q4_K from flattened arrays. +// Each thread processes a super block. +// Mask args are just to keep the signature consistent with the no-shuffle ones. +kernel void kernel_restore_block_q4_K( + global uchar * src_q, + global uchar * src_s, + global half * src_d, + global half * src_dm, + global struct block_q4_K * dst, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q4_K * b = (global struct block_q4_K *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK_K/2*get_global_id(0); + global uchar * s = (global uchar *) src_s + K_SCALE_SIZE*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * dm = (global half *) src_dm + get_global_id(0); + + b->d = *d; + b->dm = *dm; + + for (int i = 0; i < QK_K/2; ++i) { + b->q[i] = q[i]; + } + for (int i = 0; i < K_SCALE_SIZE; ++i) { + b->s[i] = s[i]; + } +} + +kernel void kernel_convert_block_q4_K_noshuffle( + global struct block_q4_K * src0, + global uchar * dst_q, + global uchar * dst_s, + global half * dst_d, + global half * dst_dm, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q4_K * b = (global struct block_q4_K *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK_K/2 * get_global_id(0); + global uchar * s = (global uchar *) dst_s + K_SCALE_SIZE * get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * dm = (global half *) dst_dm + get_global_id(0); + + *d = b->d; + *dm = b->dm; + + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar x0 = b->q[i*32 + 2*j]; + uchar x1 = b->q[i*32 + 2*j + 1]; + q[i*32 + j] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4); + q[i*32 + j + 16] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0); + } + } + + for (int i = 0; i < K_SCALE_SIZE; ++i) { + s[i] = b->s[i]; + } +} + +kernel void kernel_restore_block_q4_K_noshuffle( + global uchar * src_q, + global uchar * src_s, + global half * src_d, + global half * src_dm, + global struct block_q4_K * dst, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q4_K * b = (global struct block_q4_K *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK_K/2 * get_global_id(0); + global uchar * s = (global uchar *) src_s + K_SCALE_SIZE * get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * dm = (global half *) src_dm + get_global_id(0); + + b->d = *d; + b->dm = *dm; + + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar lo = q[i*32 + j]; + uchar hi = q[i*32 + j + 16]; + b->q[i*32 + 2*j] = convert_uchar((lo & mask_0F) | ((hi & mask_0F) << 4)); + b->q[i*32 + 2*j + 1] = convert_uchar(((lo & mask_F0) >> 4) | (hi & mask_F0)); + } + } + + for (int i = 0; i < K_SCALE_SIZE; ++i) { + b->s[i] = s[i]; + } +} + +//------------------------------------------------------------------------------ +// kernel_convert_block_q5_K +// Convert the block_q5_K format to 5 separate arrays (AOS -> SOA). +// Each thread processes a super block. +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_q5_K( + global struct block_q5_K * src0, + global uchar * dst_q, + global uchar * dst_qh, + global uchar * dst_s, + global half * dst_d, + global half * dst_dm, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q5_K * b = (global struct block_q5_K *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK_K/2*get_global_id(0); + global uchar * qh = (global uchar *) dst_qh + QK_K/8*get_global_id(0); + global uchar * s = (global uchar *) dst_s + K_SCALE_SIZE*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * dm = (global half *) dst_dm + get_global_id(0); + + *d = b->d; + *dm = b->dm; + + for (int i = 0; i < QK_K/2; ++i) { + q[i] = b->qs[i]; + } + for (int i = 0; i < QK_K/8; ++i) { + qh[i] = b->qh[i]; + } + for (int i = 0; i < K_SCALE_SIZE; ++i) { + s[i] = b->s[i]; + } +} + +// Restore block_q5_K from flattened arrays. +// Each thread processes a super block. +kernel void kernel_restore_block_q5_K( + global uchar * src_q, + global uchar * src_qh, + global uchar * src_s, + global half * src_d, + global half * src_dm, + global struct block_q5_K * dst, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q5_K * b = (global struct block_q5_K *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK_K/2*get_global_id(0); + global uchar * qh = (global uchar *) src_qh + QK_K/8*get_global_id(0); + global uchar * s = (global uchar *) src_s + K_SCALE_SIZE*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * dm = (global half *) src_dm + get_global_id(0); + + b->d = *d; + b->dm = *dm; + + for (int i = 0; i < QK_K/2; ++i) { + b->qs[i] = q[i]; + } + for (int i = 0; i < QK_K/8; ++i) { + b->qh[i] = qh[i]; + } + for (int i = 0; i < K_SCALE_SIZE; ++i) { + b->s[i] = s[i]; + } +} + +kernel void kernel_convert_block_q5_K_noshuffle( + global struct block_q5_K * src0, + global uchar * dst_q, + global uchar * dst_qh, + global uchar * dst_s, + global half * dst_d, + global half * dst_dm, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q5_K * b = (global struct block_q5_K *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK_K/2 * get_global_id(0); + global uchar * qh = (global uchar *) dst_qh + QK_K/8 * get_global_id(0); + global uchar * s = (global uchar *) dst_s + K_SCALE_SIZE * get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * dm = (global half *) dst_dm + get_global_id(0); + + *d = b->d; + *dm = b->dm; + + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar x0 = b->qs[i*32 + 2*j]; + uchar x1 = b->qs[i*32 + 2*j + 1]; + q[i*32 + j] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4); + q[i*32 + j + 16] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0); + } + } + + for (int l = 0; l < QK_K/8; ++l) { + uchar x0 = 0; + for (int i = 0; i < 8; ++i) { + x0 |= ((b->qh[(l%4)*8+i] >> (l/4)) & 0x01) << i; + } + qh[l] = x0; + } + + for (int i = 0; i < K_SCALE_SIZE; ++i) { + s[i] = b->s[i]; + } +} + +kernel void kernel_restore_block_q5_K_noshuffle( + global uchar * src_q, + global uchar * src_qh, + global uchar * src_s, + global half * src_d, + global half * src_dm, + global struct block_q5_K * dst, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q5_K * b = (global struct block_q5_K *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK_K/2 * get_global_id(0); + global uchar * qh = (global uchar *) src_qh + QK_K/8 * get_global_id(0); + global uchar * s = (global uchar *) src_s + K_SCALE_SIZE * get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * dm = (global half *) src_dm + get_global_id(0); + + b->d = *d; + b->dm = *dm; + + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar lo = q[i*32 + j]; + uchar hi = q[i*32 + j + 16]; + b->qs[i*32 + 2*j] = convert_uchar((lo & mask_0F) | ((hi & mask_0F) << 4)); + b->qs[i*32 + 2*j + 1] = convert_uchar(((lo & mask_F0) >> 4) | (hi & mask_F0)); + } + } + + for (int g = 0; g < 4; ++g) { + for (int i = 0; i < 8; ++i) { + uchar x0 = 0; + for (int k = 0; k < 8; ++k) { + x0 |= ((qh[4*k+g] >> i) & 0x01) << k; + } + b->qh[g*8+i] = x0; + } + } + + for (int i = 0; i < K_SCALE_SIZE; ++i) { + b->s[i] = s[i]; + } +} + //------------------------------------------------------------------------------ // kernel_convert_block_q6_K // Convert the block_q6_K format to 3 separate arrays (AOS -> SOA). @@ -419,8 +734,13 @@ kernel void kernel_convert_block_q6_K( global uchar * dst_ql, global uchar * dst_qh, global char * dst_s, - global half * dst_d + global half * dst_d, + uchar mask_lsb_8, + ulong n_blk ) { + if (get_global_id(0) >= n_blk) { + return; + } global struct block_q6_K * b = (global struct block_q6_K *) src0 + get_global_id(0); global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0); global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0); @@ -447,8 +767,13 @@ kernel void kernel_restore_block_q6_K( global uchar * dst_qh, global char * dst_s, global half * dst_d, - global struct block_q6_K * dst + global struct block_q6_K * dst, + uchar mask_lsb_8, + ulong n_blk ) { + if (get_global_id(0) >= n_blk) { + return; + } global struct block_q6_K * b = (global struct block_q6_K *) dst + get_global_id(0); global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0); global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0); @@ -467,3 +792,213 @@ kernel void kernel_restore_block_q6_K( b->scales[i] = s[i]; } } + +kernel void kernel_convert_block_q6_K_noshuffle( + global struct block_q6_K * src0, + global uchar * dst_ql, + global uchar * dst_qh, + global char * dst_s, + global half * dst_d, + uchar mask_lsb_8, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + global struct block_q6_K * b = (global struct block_q6_K *) src0 + get_global_id(0); + global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0); + global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0); + global char * s = (global char *) dst_s + QK_K/16*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + + for (int i = 0; i < QK_K/2/4; ++i) { + uchar x0 = b->ql[i*2 + 0] & mask_lsb_8; + uchar x1 = b->ql[i*2 + 1] & mask_lsb_8; + ql[i + 0] = (x0 & 0x0F) | ((x1 & 0x0F) << 4); + ql[i + 32] = ((x0 & 0xF0) >> 4) | (x1 & 0xF0); + + uchar x2 = b->ql[i*2 + 0 + 64] & mask_lsb_8; + uchar x3 = b->ql[i*2 + 1 + 64] & mask_lsb_8; + ql[i + 64] = (x2 & 0x0F) | ((x3 & 0x0F) << 4); + ql[i + 96] = ((x2 & 0xF0) >> 4) | (x3 & 0xF0); + } + + for (int i = 0; i < QK_K/4/8; ++i) { + uchar x0 = b->qh[i*4 + 0] & mask_lsb_8; + uchar x1 = b->qh[i*4 + 1] & mask_lsb_8; + uchar x2 = b->qh[i*4 + 2] & mask_lsb_8; + uchar x3 = b->qh[i*4 + 3] & mask_lsb_8; + qh[i + 0] = (x0 & 0x03) | ((x1 & 0x03) << 2) | ((x2 & 0x03) << 4) | ((x3 & 0x03) << 6); + qh[i + 8] = ((x0 & 0x0C) >> 2) | (x1 & 0x0C) | ((x2 & 0x0C) << 2) | ((x3 & 0x0C) << 4); + qh[i + 16] = ((x0 & 0x30) >> 4) | ((x1 & 0x30) >> 2) | (x2 & 0x30) | ((x3 & 0x30) << 2); + qh[i + 24] = ((x0 & 0xC0) >> 6) | ((x1 & 0xC0) >> 4) | ((x2 & 0xC0) >> 2) | (x3 & 0xC0); + + uchar x4 = b->qh[i*4 + 0 + 32] & mask_lsb_8; + uchar x5 = b->qh[i*4 + 1 + 32] & mask_lsb_8; + uchar x6 = b->qh[i*4 + 2 + 32] & mask_lsb_8; + uchar x7 = b->qh[i*4 + 3 + 32] & mask_lsb_8; + qh[i + 32] = (x4 & 0x03) | ((x5 & 0x03) << 2) | ((x6 & 0x03) << 4) | ((x7 & 0x03) << 6); + qh[i + 40] = ((x4 & 0x0C) >> 2) | (x5 & 0x0C) | ((x6 & 0x0C) << 2) | ((x7 & 0x0C) << 4); + qh[i + 48] = ((x4 & 0x30) >> 4) | ((x5 & 0x30) >> 2) | (x6 & 0x30) | ((x7 & 0x30) << 2); + qh[i + 56] = ((x4 & 0xC0) >> 6) | ((x5 & 0xC0) >> 4) | ((x6 & 0xC0) >> 2) | (x7 & 0xC0); + } + + for (int i = 0; i < QK_K/16; ++i) { + s[i] = b->scales[i]; + } +} + +kernel void kernel_restore_block_q6_K_noshuffle( + global uchar * src_ql, + global uchar * src_qh, + global char * src_s, + global half * src_d, + global struct block_q6_K * dst, + uchar mask_lsb_8, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + global struct block_q6_K * b = (global struct block_q6_K *) dst + get_global_id(0); + global uchar * ql = (global uchar *) src_ql + QK_K/2*get_global_id(0); + global uchar * qh = (global uchar *) src_qh + QK_K/4*get_global_id(0); + global char * s = (global char *) src_s + QK_K/16*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + + b->d = *d; + + for (int i = 0; i < QK_K/2/4; ++i) { + uchar x0 = ql[i + 0] & mask_lsb_8; + uchar x1 = ql[i + 32] & mask_lsb_8; + b->ql[i*2 + 0] = (x0 & 0x0F) | ((x1 & 0x0F) << 4); + b->ql[i*2 + 1] = ((x0 & 0xF0) >> 4) | (x1 & 0xF0); + + uchar x2 = ql[i + 64] & mask_lsb_8; + uchar x3 = ql[i + 96] & mask_lsb_8; + b->ql[i*2 + 0 + 64] = (x2 & 0x0F) | ((x3 & 0x0F) << 4); + b->ql[i*2 + 1 + 64] = ((x2 & 0xF0) >> 4) | (x3 & 0xF0); + } + + for (int i = 0; i < QK_K/4/8; ++i) { + uchar x0 = qh[i + 0] & mask_lsb_8; + uchar x1 = qh[i + 8] & mask_lsb_8; + uchar x2 = qh[i + 16] & mask_lsb_8; + uchar x3 = qh[i + 24] & mask_lsb_8; + b->qh[i*4 + 0] = (x0 & 0x03) | ((x1 & 0x03) << 2) | ((x2 & 0x03) << 4) | ((x3 & 0x03) << 6); + b->qh[i*4 + 1] = ((x0 & 0x0C) >> 2) | (x1 & 0x0C) | ((x2 & 0x0C) << 2) | ((x3 & 0x0C) << 4); + b->qh[i*4 + 2] = ((x0 & 0x30) >> 4) | ((x1 & 0x30) >> 2) | (x2 & 0x30) | ((x3 & 0x30) << 2); + b->qh[i*4 + 3] = ((x0 & 0xC0) >> 6) | ((x1 & 0xC0) >> 4) | ((x2 & 0xC0) >> 2) | (x3 & 0xC0); + + uchar x4 = qh[i + 0 + 32] & mask_lsb_8; + uchar x5 = qh[i + 8 + 32] & mask_lsb_8; + uchar x6 = qh[i + 16 + 32] & mask_lsb_8; + uchar x7 = qh[i + 24 + 32] & mask_lsb_8; + b->qh[i*4 + 0 + 32] = (x4 & 0x03) | ((x5 & 0x03) << 2) | ((x6 & 0x03) << 4) | ((x7 & 0x03) << 6); + b->qh[i*4 + 1 + 32] = ((x4 & 0x0C) >> 2) | (x5 & 0x0C) | ((x6 & 0x0C) << 2) | ((x7 & 0x0C) << 4); + b->qh[i*4 + 2 + 32] = ((x4 & 0x30) >> 4) | ((x5 & 0x30) >> 2) | (x6 & 0x30) | ((x7 & 0x30) << 2); + b->qh[i*4 + 3 + 32] = ((x4 & 0xC0) >> 6) | ((x5 & 0xC0) >> 4) | ((x6 & 0xC0) >> 2) | (x7 & 0xC0); + } + + for (int i = 0; i < QK_K/16; ++i) { + b->scales[i] = s[i]; + } +} + +//------------------------------------------------------------------------------ +// kernel_convert_block_iq4_nl +// Convert the block_iq4_nl format to 2 separate arrays (AOS -> SOA). +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_iq4_nl( + global struct block_iq4_nl * src0, + global uchar * dst_q, + global half * dst_d, + uchar mask_0F, + uchar mask_F0, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + global struct block_iq4_nl * b = (global struct block_iq4_nl *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK4_NL/2*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + + for (int i = 0; i < QK4_NL/2; ++i) { + q[i] = b->qs[i]; + } +} + +kernel void kernel_restore_block_iq4_nl( + global uchar * src_q, + global half * src_d, + global struct block_iq4_nl * dst, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + global struct block_iq4_nl * b = (global struct block_iq4_nl *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK4_NL/2*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + + b->d = *d; + + for (int i = 0; i < QK4_NL/2; ++i) { + b->qs[i] = q[i]; + } +} + +kernel void kernel_convert_block_iq4_nl_noshuffle( + global struct block_iq4_nl * src0, + global uchar * dst_q, + global half * dst_d, + uchar mask_0F, + uchar mask_F0, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + global struct block_iq4_nl * b = (global struct block_iq4_nl *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK4_NL/2*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + for (int i = 0; i < QK4_NL/4; ++i) { + uchar x0 = b->qs[2*i + 0]; + uchar x1 = b->qs[2*i + 1]; + + q[i + 0 ] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4); + q[i + QK4_NL/4] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0); + } +} + +kernel void kernel_restore_block_iq4_nl_noshuffle( + global uchar * src_q, + global half * src_d, + global struct block_iq4_nl * dst, + uchar mask_0F, + uchar mask_F0, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + global struct block_iq4_nl * b = (global struct block_iq4_nl *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK4_NL/2*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + + b->d = *d; + for (int i = 0; i < QK4_NL/4; ++i) { + uchar x0 = q[i + 0 ]; + uchar x1 = q[i + QK4_NL/4]; + + b->qs[2*i + 0] = convert_uchar((x0 & mask_0F) | ((x1 & mask_0F) << 4)); + b->qs[2*i + 1] = convert_uchar(((x0 & mask_F0) >> 4) | (x1 & mask_F0)); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl new file mode 100644 index 00000000000..6869d822862 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl @@ -0,0 +1,150 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#ifdef cl_qcom_reqd_sub_group_size +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +constant half kvalues_iq4nl[16] = { + (half)-127.f, (half)-104.f, (half)-83.f, (half)-65.f, + (half) -49.f, (half) -35.f, (half)-22.f, (half)-10.f, + (half) 1.f, (half) 13.f, (half) 25.f, (half) 38.f, + (half) 53.f, (half) 69.f, (half) 89.f, (half)113.f +}; + +// Packed LUT: 2 FP16 values per uint, 8 unique constant loads instead of 16 +constant uint iq4nl_packed[8] = { + 0xD680D7F0u, // idx 0,1: -127, -104 + 0xD410D530u, // idx 2,3: -83, -65 + 0xD060D220u, // idx 4,5: -49, -35 + 0xC900CD80u, // idx 6,7: -22, -10 + 0x4A803C00u, // idx 8,9: 1, 13 + 0x50C04E40u, // idx 10,11: 25, 38 + 0x545052A0u, // idx 12,13: 53, 69 + 0x57105590u // idx 14,15: 89, 113 +}; + +// Packed dequant: 1 uint constant load (8-way divergence) + shift + as_half +#define IQ4_NL_DEQUANT(nibble) as_half((ushort)(iq4nl_packed[(nibble) >> 1] >> (((nibble) & 1u) << 4))) + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif + +kernel void kernel_gemm_noshuffle_iq4_nl_f32( + global const ushort * src0_q, + global const half * src0_d, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int m, + int n, + int k, + int n_no_padding +) { + dst = (global float *)((global char *)dst + offsetd); + + int m_4 = m >> 2; + int n_4 = n >> 2; + + int gy = get_global_id(0); + int gx = get_global_id(1); + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 dequantized_weights; + + global const ushort * weight_ptr = src0_q + gx_2; + global const half * scale_ptr = src0_d + gx_2; + + for (int i = 0; i < k; i += 4) { + B.s0123 = read_imageh(src1, gy*2 + (i)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i)*(n_4)+1); + + ushort4 bits4 = vload4(0, weight_ptr + (i/4)*(m)); + + half4 scale = vload4(0, scale_ptr + (i/32)*(m)); + + // j=0 + dequantized_weights.s0 = IQ4_NL_DEQUANT(bits4.s0 & 0x000Fu) * scale.s0; + dequantized_weights.s1 = IQ4_NL_DEQUANT(bits4.s1 & 0x000Fu) * scale.s1; + dequantized_weights.s2 = IQ4_NL_DEQUANT(bits4.s2 & 0x000Fu) * scale.s2; + dequantized_weights.s3 = IQ4_NL_DEQUANT(bits4.s3 & 0x000Fu) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=1 + B.s0123 = read_imageh(src1, gy*2 + (i+1)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i+1)*(n_4)+1); + dequantized_weights.s0 = IQ4_NL_DEQUANT((bits4.s0 >> 4) & 0x000Fu) * scale.s0; + dequantized_weights.s1 = IQ4_NL_DEQUANT((bits4.s1 >> 4) & 0x000Fu) * scale.s1; + dequantized_weights.s2 = IQ4_NL_DEQUANT((bits4.s2 >> 4) & 0x000Fu) * scale.s2; + dequantized_weights.s3 = IQ4_NL_DEQUANT((bits4.s3 >> 4) & 0x000Fu) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (i+2)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i+2)*(n_4)+1); + dequantized_weights.s0 = IQ4_NL_DEQUANT((bits4.s0 >> 8) & 0x000Fu) * scale.s0; + dequantized_weights.s1 = IQ4_NL_DEQUANT((bits4.s1 >> 8) & 0x000Fu) * scale.s1; + dequantized_weights.s2 = IQ4_NL_DEQUANT((bits4.s2 >> 8) & 0x000Fu) * scale.s2; + dequantized_weights.s3 = IQ4_NL_DEQUANT((bits4.s3 >> 8) & 0x000Fu) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (i+3)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i+3)*(n_4)+1); + dequantized_weights.s0 = IQ4_NL_DEQUANT((bits4.s0 >> 12) & 0x000Fu) * scale.s0; + dequantized_weights.s1 = IQ4_NL_DEQUANT((bits4.s1 >> 12) & 0x000Fu) * scale.s1; + dequantized_weights.s2 = IQ4_NL_DEQUANT((bits4.s2 >> 12) & 0x000Fu) * scale.s2; + dequantized_weights.s3 = IQ4_NL_DEQUANT((bits4.s3 >> 12) & 0x000Fu) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + + int idx = (gy<<3)*m + (gx<<2); + + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl new file mode 100644 index 00000000000..99fd1fd7bf1 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl @@ -0,0 +1,172 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif +#define QK_K 256 +#define K_SCALE_SIZE 12 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2 +) { + if (j < 4) { + *d = q[j] & mask_d6; + *m = q[j+4] & mask_d6; + } else { + *d = (q[j+4] & mask_d4) | ((q[j-4] & mask_hi2) >> 2); + *m = ((q[j+4] >> 4) & mask_d4) | ((q[j] & mask_hi2) >> 2); + } +} + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif +kernel void kernel_gemm_noshuffle_q4_k_f32( + global const ushort * src0_q, + global const uchar * src0_s, + global const half * src0_d, + global const half * src0_dm, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int m, + int n, + int k, + int n_no_padding, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2 +) { + dst = (global float *)((global char *)dst + offsetd); + int n_4 = n >> 2; + int gy = get_global_id(0); + int gx = get_global_id(1); + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 dequantized_weights; + + int num_blocks_K = k / QK_K; + + global const ushort * weight_ptr = src0_q + gx_2; + global const half * d_ptr = src0_d + gx_2; + global const half * dm_ptr = src0_dm + gx_2; + + for (int i = 0; i < k; i += 32) { + int sb_idx = i / QK_K; + int sub_idx = (i / 32) % 8; + + half4 d = vload4(0, d_ptr + sb_idx * m); + half4 dm = vload4(0, dm_ptr + sb_idx * m); + + global const uchar * sc0 = src0_s + (gx_2+0) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + global const uchar * sc1 = src0_s + (gx_2+1) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + global const uchar * sc2 = src0_s + (gx_2+2) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + global const uchar * sc3 = src0_s + (gx_2+3) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + + uchar sv0, mn0, sv1, mn1, sv2, mn2, sv3, mn3; + get_scale_min_k4(sub_idx, sc0, &sv0, &mn0, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(sub_idx, sc1, &sv1, &mn1, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(sub_idx, sc2, &sv2, &mn2, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(sub_idx, sc3, &sv3, &mn3, mask_d6, mask_d4, mask_hi2); + + half4 scale = convert_half4(convert_float4(d) * convert_float4((uchar4)(sv0, sv1, sv2, sv3))); + half4 mval = convert_half4(convert_float4(dm) * convert_float4((uchar4)(mn0, mn1, mn2, mn3))); + + for (int l = 0; l < 32; l += 4) { + int ki = i + l; + ushort4 bits4 = vload4(0, weight_ptr + (ki/4) * m); + + // j=0 + B.s0123 = read_imageh(src1, gy*2 + (ki+0) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+0) * n_4); + dequantized_weights.s0 = (bits4.s0 & 0x000F) * scale.s0 - mval.s0; + dequantized_weights.s1 = (bits4.s1 & 0x000F) * scale.s1 - mval.s1; + dequantized_weights.s2 = (bits4.s2 & 0x000F) * scale.s2 - mval.s2; + dequantized_weights.s3 = (bits4.s3 & 0x000F) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=1 + B.s0123 = read_imageh(src1, gy*2 + (ki+1) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+1) * n_4); + dequantized_weights.s0 = ((bits4.s0 & 0x00F0) >> 4) * scale.s0 - mval.s0; + dequantized_weights.s1 = ((bits4.s1 & 0x00F0) >> 4) * scale.s1 - mval.s1; + dequantized_weights.s2 = ((bits4.s2 & 0x00F0) >> 4) * scale.s2 - mval.s2; + dequantized_weights.s3 = ((bits4.s3 & 0x00F0) >> 4) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (ki+2) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+2) * n_4); + dequantized_weights.s0 = ((bits4.s0 & 0x0F00) >> 8) * scale.s0 - mval.s0; + dequantized_weights.s1 = ((bits4.s1 & 0x0F00) >> 8) * scale.s1 - mval.s1; + dequantized_weights.s2 = ((bits4.s2 & 0x0F00) >> 8) * scale.s2 - mval.s2; + dequantized_weights.s3 = ((bits4.s3 & 0x0F00) >> 8) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (ki+3) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+3) * n_4); + dequantized_weights.s0 = ((bits4.s0 & 0xF000) >> 12) * scale.s0 - mval.s0; + dequantized_weights.s1 = ((bits4.s1 & 0xF000) >> 12) * scale.s1 - mval.s1; + dequantized_weights.s2 = ((bits4.s2 & 0xF000) >> 12) * scale.s2 - mval.s2; + dequantized_weights.s3 = ((bits4.s3 & 0xF000) >> 12) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + } + + int idx = (gy<<3)*m + (gx<<2); + + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl new file mode 100644 index 00000000000..058c0f7edc6 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl @@ -0,0 +1,176 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif +#define QK_K 256 +#define K_SCALE_SIZE 12 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2 +) { + if (j < 4) { + *d = q[j] & mask_d6; + *m = q[j+4] & mask_d6; + } else { + *d = (q[j+4] & mask_d4) | ((q[j-4] & mask_hi2) >> 2); + *m = ((q[j+4] >> 4) & mask_d4) | ((q[j] & mask_hi2) >> 2); + } +} + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif +kernel void kernel_gemm_noshuffle_q5_k_f32( + global const ushort * src0_q, + global const uchar * src0_qh, + global const uchar * src0_s, + global const half * src0_d, + global const half * src0_dm, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int m, + int n, + int k, + int n_no_padding, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2 +) { + dst = (global float *)((global char *)dst + offsetd); + int n_4 = n >> 2; + int gy = get_global_id(0); + int gx = get_global_id(1); + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 dequantized_weights; + + int num_blocks_K = k / QK_K; + + global const ushort * weight_ptr = src0_q + gx_2; + global const uchar * qh_ptr = src0_qh + gx_2; + global const half * d_ptr = src0_d + gx_2; + global const half * dm_ptr = src0_dm + gx_2; + + for (int i = 0; i < k; i += 32) { + int sb_idx = i / QK_K; + int sub_idx = (i / 32) % 8; + + half4 d = vload4(0, d_ptr + sb_idx * m); + half4 dm = vload4(0, dm_ptr + sb_idx * m); + + global const uchar * sc0 = src0_s + (gx_2+0) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + global const uchar * sc1 = src0_s + (gx_2+1) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + global const uchar * sc2 = src0_s + (gx_2+2) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + global const uchar * sc3 = src0_s + (gx_2+3) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + + uchar sv0, mn0, sv1, mn1, sv2, mn2, sv3, mn3; + get_scale_min_k4(sub_idx, sc0, &sv0, &mn0, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(sub_idx, sc1, &sv1, &mn1, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(sub_idx, sc2, &sv2, &mn2, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(sub_idx, sc3, &sv3, &mn3, mask_d6, mask_d4, mask_hi2); + + half4 scale = convert_half4(convert_float4(d) * convert_float4((uchar4)(sv0, sv1, sv2, sv3))); + half4 mval = convert_half4(convert_float4(dm) * convert_float4((uchar4)(mn0, mn1, mn2, mn3))); + + for (int l = 0; l < 32; l += 4) { + int ki = i + l; + ushort4 bits4 = vload4(0, weight_ptr + (ki/4) * m); + uchar4 qh_bits = vload4(0, qh_ptr + (ki/8) * m); + int qh_shift = ki % 8; + + // j=0 + B.s0123 = read_imageh(src1, gy*2 + (ki+0) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+0) * n_4); + dequantized_weights.s0 = ((bits4.s0 & 0x000F) | (((qh_bits.s0 >> (qh_shift+0)) & 1) << 4)) * scale.s0 - mval.s0; + dequantized_weights.s1 = ((bits4.s1 & 0x000F) | (((qh_bits.s1 >> (qh_shift+0)) & 1) << 4)) * scale.s1 - mval.s1; + dequantized_weights.s2 = ((bits4.s2 & 0x000F) | (((qh_bits.s2 >> (qh_shift+0)) & 1) << 4)) * scale.s2 - mval.s2; + dequantized_weights.s3 = ((bits4.s3 & 0x000F) | (((qh_bits.s3 >> (qh_shift+0)) & 1) << 4)) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=1 + B.s0123 = read_imageh(src1, gy*2 + (ki+1) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+1) * n_4); + dequantized_weights.s0 = (((bits4.s0 & 0x00F0) >> 4) | (((qh_bits.s0 >> (qh_shift+1)) & 1) << 4)) * scale.s0 - mval.s0; + dequantized_weights.s1 = (((bits4.s1 & 0x00F0) >> 4) | (((qh_bits.s1 >> (qh_shift+1)) & 1) << 4)) * scale.s1 - mval.s1; + dequantized_weights.s2 = (((bits4.s2 & 0x00F0) >> 4) | (((qh_bits.s2 >> (qh_shift+1)) & 1) << 4)) * scale.s2 - mval.s2; + dequantized_weights.s3 = (((bits4.s3 & 0x00F0) >> 4) | (((qh_bits.s3 >> (qh_shift+1)) & 1) << 4)) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (ki+2) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+2) * n_4); + dequantized_weights.s0 = (((bits4.s0 & 0x0F00) >> 8) | (((qh_bits.s0 >> (qh_shift+2)) & 1) << 4)) * scale.s0 - mval.s0; + dequantized_weights.s1 = (((bits4.s1 & 0x0F00) >> 8) | (((qh_bits.s1 >> (qh_shift+2)) & 1) << 4)) * scale.s1 - mval.s1; + dequantized_weights.s2 = (((bits4.s2 & 0x0F00) >> 8) | (((qh_bits.s2 >> (qh_shift+2)) & 1) << 4)) * scale.s2 - mval.s2; + dequantized_weights.s3 = (((bits4.s3 & 0x0F00) >> 8) | (((qh_bits.s3 >> (qh_shift+2)) & 1) << 4)) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (ki+3) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+3) * n_4); + dequantized_weights.s0 = (((bits4.s0 & 0xF000) >> 12) | (((qh_bits.s0 >> (qh_shift+3)) & 1) << 4)) * scale.s0 - mval.s0; + dequantized_weights.s1 = (((bits4.s1 & 0xF000) >> 12) | (((qh_bits.s1 >> (qh_shift+3)) & 1) << 4)) * scale.s1 - mval.s1; + dequantized_weights.s2 = (((bits4.s2 & 0xF000) >> 12) | (((qh_bits.s2 >> (qh_shift+3)) & 1) << 4)) * scale.s2 - mval.s2; + dequantized_weights.s3 = (((bits4.s3 & 0xF000) >> 12) | (((qh_bits.s3 >> (qh_shift+3)) & 1) << 4)) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + } + + int idx = (gy<<3)*m + (gx<<2); + + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl new file mode 100644 index 00000000000..3a9c624508a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl @@ -0,0 +1,140 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif +kernel void kernel_gemm_noshuffle_q6_K_f32( + global const ushort * src0_ql, + global const uchar * src0_qh, + global const ushort * src0_s, + global const half * src0_d, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int m, + int n, + int k, + int n_no_padding, + ushort mask_f000, + uchar mask_c0 +) { + dst = (global float *)( (global char *)dst + offsetd ); + + int m_4 = m >> 2; + int n_4 = n >> 2; + + int gy = get_global_id(0); // n + int gx = get_global_id(1); // m + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 dequantized_weights; + + global const ushort * ptr_ql = src0_ql + gx_2; + global const uchar * ptr_qh = src0_qh + gx_2; + global const ushort * ptr_s = src0_s + gx_2; + global const half * ptr_d = src0_d + gx_2; + + for (int i = 0; i < k; i += 4) { + // load 4x elements (ushort) of ql on M, each ushort contains 4 weights + // 4x ushort correspons to 4 rows on M + ushort4 bits4 = vload4(0, ptr_ql + (i/4)*m); // ql packed in 4s in ushort + uchar4 bits2 = vload4(0, ptr_qh + (i/4)*m); // qh packed in 4s in uchar + + // load 4 consecutive scales + char8 scale_s_8 = as_char8(vload4(0, ptr_s + (i/16/2)*m)); // 1 char scale every 16 elements, packed in 2s + char4 scale_s = ((i/16) % 2) == 0 ? scale_s_8.s0246 : scale_s_8.s1357; // transposed as ushort, 2 blocks + half4 scale_d = vload4(0, ptr_d + (i/256)*m); // 1 half scale every 256 elements + + // j=0 + // load 2x 4 elements of activations on N, corresponding to 8 rows on N + B.s0123 = read_imageh(src1, gy*2 + (i + 0)*n_4 + 0); + B.s4567 = read_imageh(src1, gy*2 + (i + 0)*n_4 + 1); + dequantized_weights.s0 = (convert_half((bits4.s0 & 0x000F) | ((bits2.s0 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0; + dequantized_weights.s1 = (convert_half((bits4.s1 & 0x000F) | ((bits2.s1 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s1; + dequantized_weights.s2 = (convert_half((bits4.s2 & 0x000F) | ((bits2.s2 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s2; + dequantized_weights.s3 = (convert_half((bits4.s3 & 0x000F) | ((bits2.s3 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=1 + B.s0123 = read_imageh(src1, gy*2 + (i + 1)*n_4 + 0); + B.s4567 = read_imageh(src1, gy*2 + (i + 1)*n_4 + 1); + dequantized_weights.s0 = (convert_half((((bits4.s0 & 0x00F0) >> 4) | ((bits2.s0 & 0x0C) << 2))) - 32.f) * scale_s.s0 * scale_d.s0; + dequantized_weights.s1 = (convert_half((((bits4.s1 & 0x00F0) >> 4) | ((bits2.s1 & 0x0C) << 2))) - 32.f) * scale_s.s1 * scale_d.s1; + dequantized_weights.s2 = (convert_half((((bits4.s2 & 0x00F0) >> 4) | ((bits2.s2 & 0x0C) << 2))) - 32.f) * scale_s.s2 * scale_d.s2; + dequantized_weights.s3 = (convert_half((((bits4.s3 & 0x00F0) >> 4) | ((bits2.s3 & 0x0C) << 2))) - 32.f) * scale_s.s3 * scale_d.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (i + 2)*n_4 + 0); + B.s4567 = read_imageh(src1, gy*2 + (i + 2)*n_4 + 1); + dequantized_weights.s0 = (convert_half((((bits4.s0 & 0x0F00) >> 8) | (bits2.s0 & 0x30))) - 32.f) * scale_s.s0 * scale_d.s0; + dequantized_weights.s1 = (convert_half((((bits4.s1 & 0x0F00) >> 8) | (bits2.s1 & 0x30))) - 32.f) * scale_s.s1 * scale_d.s1; + dequantized_weights.s2 = (convert_half((((bits4.s2 & 0x0F00) >> 8) | (bits2.s2 & 0x30))) - 32.f) * scale_s.s2 * scale_d.s2; + dequantized_weights.s3 = (convert_half((((bits4.s3 & 0x0F00) >> 8) | (bits2.s3 & 0x30))) - 32.f) * scale_s.s3 * scale_d.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (i + 3)*n_4 + 0); + B.s4567 = read_imageh(src1, gy*2 + (i + 3)*n_4 + 1); + dequantized_weights.s0 = (convert_half((((bits4.s0 & mask_f000) >> 12) | ((bits2.s0 & mask_c0) >> 2))) - 32.f) * scale_s.s0 * scale_d.s0; + dequantized_weights.s1 = (convert_half((((bits4.s1 & mask_f000) >> 12) | ((bits2.s1 & mask_c0) >> 2))) - 32.f) * scale_s.s1 * scale_d.s1; + dequantized_weights.s2 = (convert_half((((bits4.s2 & mask_f000) >> 12) | ((bits2.s2 & mask_c0) >> 2))) - 32.f) * scale_s.s2 * scale_d.s2; + dequantized_weights.s3 = (convert_half((((bits4.s3 & mask_f000) >> 12) | ((bits2.s3 & mask_c0) >> 2))) - 32.f) * scale_s.s3 * scale_d.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + + int idx = (gy<<3)*m + (gx<<2); + + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl new file mode 100644 index 00000000000..9386bf25a6f --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl @@ -0,0 +1,302 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +#define QK4_NL 32 +#define NSUBGROUPS 4 +#define SUBGROUP_SIZE 64 + +constant half kvalues_iq4nl[16] = { + (half)-127.f, (half)-104.f, (half)-83.f, (half)-65.f, + (half) -49.f, (half) -35.f, (half)-22.f, (half)-10.f, + (half) 1.f, (half) 13.f, (half) 25.f, (half) 38.f, + (half) 53.f, (half) 69.f, (half) 89.f, (half)113.f +}; + +// Packed LUT: 2 FP16 values per uint, 8 unique constant loads instead of 16 +constant uint iq4nl_packed[8] = { + 0xD680D7F0u, // idx 0,1: -127, -104 + 0xD410D530u, // idx 2,3: -83, -65 + 0xD060D220u, // idx 4,5: -49, -35 + 0xC900CD80u, // idx 6,7: -22, -10 + 0x4A803C00u, // idx 8,9: 1, 13 + 0x50C04E40u, // idx 10,11: 25, 38 + 0x545052A0u, // idx 12,13: 53, 69 + 0x57105590u // idx 14,15: 89, 113 +}; + +// Packed dequant: 1 uint constant load (8-way divergence) + shift + as_half +#define IQ4_NL_DEQUANT(nibble) as_half((ushort)(iq4nl_packed[(nibble) >> 1] >> (((nibble) & 1u) << 4))) + +#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s0 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s1 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s2 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s3 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s4 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s5 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s6 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s7 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s0 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s1 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s2 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s3 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s4 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s5 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s6 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s7 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s0 & 0x000F)) * scale.s0 * shared_y.s0; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x00F0) >> 4)) * scale.s0 * shared_y.s1; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x0F00) >> 8)) * scale.s0 * shared_y.s2; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0xF000) >> 12)) * scale.s0 * shared_y.s3; \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s2 & 0x000F)) * scale.s0 * shared_y.s4; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x00F0) >> 4)) * scale.s0 * shared_y.s5; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x0F00) >> 8)) * scale.s0 * shared_y.s6; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0xF000) >> 12)) * scale.s0 * shared_y.s7; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s1 & 0x000F)) * scale.s1 * shared_y.s0; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x00F0) >> 4)) * scale.s1 * shared_y.s1; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x0F00) >> 8)) * scale.s1 * shared_y.s2; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0xF000) >> 12)) * scale.s1 * shared_y.s3; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s3 & 0x000F)) * scale.s1 * shared_y.s4; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x00F0) >> 4)) * scale.s1 * shared_y.s5; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x0F00) >> 8)) * scale.s1 * shared_y.s6; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0xF000) >> 12)) * scale.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s4 & 0x000F)) * scale.s0 * shared_y.s0; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x00F0) >> 4)) * scale.s0 * shared_y.s1; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x0F00) >> 8)) * scale.s0 * shared_y.s2; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0xF000) >> 12)) * scale.s0 * shared_y.s3; \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s6 & 0x000F)) * scale.s0 * shared_y.s4; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x00F0) >> 4)) * scale.s0 * shared_y.s5; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x0F00) >> 8)) * scale.s0 * shared_y.s6; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0xF000) >> 12)) * scale.s0 * shared_y.s7; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s5 & 0x000F)) * scale.s1 * shared_y.s0; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x00F0) >> 4)) * scale.s1 * shared_y.s1; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x0F00) >> 8)) * scale.s1 * shared_y.s2; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0xF000) >> 12)) * scale.s1 * shared_y.s3; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s7 & 0x000F)) * scale.s1 * shared_y.s4; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x00F0) >> 4)) * scale.s1 * shared_y.s5; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x0F00) >> 8)) * scale.s1 * shared_y.s6; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0xF000) >> 12)) * scale.s1 * shared_y.s7; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s0 & 0x000F)) * scale.s0 * shared_y.s0; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x00F0) >> 4)) * scale.s0 * shared_y.s1; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x0F00) >> 8)) * scale.s0 * shared_y.s2; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0xF000) >> 12)) * scale.s0 * shared_y.s3; \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s2 & 0x000F)) * scale.s0 * shared_y.s4; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x00F0) >> 4)) * scale.s0 * shared_y.s5; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x0F00) >> 8)) * scale.s0 * shared_y.s6; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0xF000) >> 12)) * scale.s0 * shared_y.s7; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s1 & 0x000F)) * scale.s1 * shared_y.s0; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x00F0) >> 4)) * scale.s1 * shared_y.s1; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x0F00) >> 8)) * scale.s1 * shared_y.s2; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0xF000) >> 12)) * scale.s1 * shared_y.s3; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s3 & 0x000F)) * scale.s1 * shared_y.s4; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x00F0) >> 4)) * scale.s1 * shared_y.s5; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x0F00) >> 8)) * scale.s1 * shared_y.s6; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0xF000) >> 12)) * scale.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s4 & 0x000F)) * scale.s0 * shared_y.s0; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x00F0) >> 4)) * scale.s0 * shared_y.s1; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x0F00) >> 8)) * scale.s0 * shared_y.s2; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0xF000) >> 12)) * scale.s0 * shared_y.s3; \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s6 & 0x000F)) * scale.s0 * shared_y.s4; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x00F0) >> 4)) * scale.s0 * shared_y.s5; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x0F00) >> 8)) * scale.s0 * shared_y.s6; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0xF000) >> 12)) * scale.s0 * shared_y.s7; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s5 & 0x000F)) * scale.s1 * shared_y.s0; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x00F0) >> 4)) * scale.s1 * shared_y.s1; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x0F00) >> 8)) * scale.s1 * shared_y.s2; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0xF000) >> 12)) * scale.s1 * shared_y.s3; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s7 & 0x000F)) * scale.s1 * shared_y.s4; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x00F0) >> 4)) * scale.s1 * shared_y.s5; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x0F00) >> 8)) * scale.s1 * shared_y.s6; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0xF000) >> 12)) * scale.s1 * shared_y.s7; \ + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_gemv_noshuffle_iq4_nl_f32( + read_only image1d_buffer_t src0_q, + global half2 * src0_d, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int ne00, + int ne01) +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + uint K = ne00; + uint M = ne01; + + uint LINE_STRIDE_A = M / 2; + uint BLOCK_STRIDE_A = NSUBGROUPS * M; + + private uint4 regA; + private half2 regS; + private float8 regB; + + private float2 totalSum = (float2)(0.0f); + + // loop along K in block granularity, skip 4 blocks every iter + for (uint k = groupId; k < (K / QK4_NL); k += NSUBGROUPS) { + regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of two rows + // first 4 fibers in each wave load 8 B values to its private scope + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + // load half weights for two blocks in consecutive rows + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + } + + // reduction in local memory, assumes #wave=4 + local float2 reduceLM[SUBGROUP_SIZE * 3]; + if (groupId == 1) { + reduceLM[SUBGROUP_SIZE * 0 + slid] = totalSum; + } + if (groupId == 2) { + reduceLM[SUBGROUP_SIZE * 1 + slid] = totalSum; + } + if (groupId == 3) { + reduceLM[SUBGROUP_SIZE * 2 + slid] = totalSum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 0 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 1 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 2 + slid]; + } + + // 2 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(totalSum, 0, &(dst[gid * 2])); + } + +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl new file mode 100644 index 00000000000..dd1e2b55c0b --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl @@ -0,0 +1,318 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +#define QK_K 256 +#define NSUBGROUPS 4 +#define SUBGROUP_SIZE 64 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2 +) { + if (j < 4) { + *d = q[j] & mask_d6; + *m = q[j+4] & mask_d6; + } else { + *d = (q[j+4] & mask_d4) | ((q[j-4] & mask_hi2) >> 2); + *m = ((q[j+4] >> 4) & mask_d4) | ((q[j] & mask_hi2) >> 2); + } +} + +#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, minv, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, minv, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, minv, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s7; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, minv, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s7; \ + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_gemv_noshuffle_q4_k_f32( + read_only image1d_buffer_t src0_q, + global half2 * src0_d, + global half2 * src0_m, + global uchar * src0_s, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2) +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + uint K = ne00; + uint M = ne01; + + uint LINE_STRIDE_A = M / 2; + uint BLOCK_STRIDE_A = NSUBGROUPS * M; + uint scales_per_row = (K / QK_K) * 12; + + private uint4 regA; + private half2 regS; + private half2 regM; + private float8 regB; + + private float2 totalSum = (float2)(0.0f); + + for (uint k = groupId; k < (K / 32); k += NSUBGROUPS) { + uint sb = k / 8; + uint j = k % 8; + + half2 d = src0_d[gid + sb * LINE_STRIDE_A]; + half2 dm = src0_m[gid + sb * LINE_STRIDE_A]; + + global const uchar * sc0 = src0_s + 2 * gid * scales_per_row + sb * 12; + global const uchar * sc1 = src0_s + (2 * gid + 1) * scales_per_row + sb * 12; + + uchar sv0, mn0, sv1, mn1; + get_scale_min_k4(j, sc0, &sv0, &mn0, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(j, sc1, &sv1, &mn1, mask_d6, mask_d4, mask_hi2); + + regS = convert_half2(convert_float2(d) * convert_float2((uchar2)(sv0, sv1))); + regM = convert_half2(convert_float2(dm) * convert_float2((uchar2)(mn0, mn1))); + + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + // load half weights for two blocks in consecutive rows + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regM, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regM, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + } + + // reduction in local memory, assumes #wave=4 + local float2 reduceLM[SUBGROUP_SIZE * 3]; + if (groupId == 1) { + reduceLM[SUBGROUP_SIZE * 0 + slid] = totalSum; + } + if (groupId == 2) { + reduceLM[SUBGROUP_SIZE * 1 + slid] = totalSum; + } + if (groupId == 3) { + reduceLM[SUBGROUP_SIZE * 2 + slid] = totalSum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 0 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 1 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 2 + slid]; + } + + // 2 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(totalSum, 0, &(dst[gid * 2])); + } + +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl new file mode 100644 index 00000000000..c40db166638 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl @@ -0,0 +1,326 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +#define QK_K 256 +#define NSUBGROUPS 4 +#define SUBGROUP_SIZE 64 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2 +) { + if (j < 4) { + *d = q[j] & mask_d6; + *m = q[j+4] & mask_d6; + } else { + *d = (q[j+4] & mask_d4) | ((q[j-4] & mask_hi2) >> 2); + *m = ((q[j+4] >> 4) & mask_d4) | ((q[j] & mask_hi2) >> 2); + } +} + +#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, bits1, scale, minv, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | ((bits1.s0 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | ((bits1.s1 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s0 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s1 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s0 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s1 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s0 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s1 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s0 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s1 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s0 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s1 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s0 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s1 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s0 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s1 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | ((bits1.s2 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | ((bits1.s3 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s2 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s3 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s2 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s3 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s2 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s3 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s2 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s3 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s2 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s3 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s2 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s3 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s2 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s3 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, bits1, scale, minv, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | ((bits1.s4 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | ((bits1.s5 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s4 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s5 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s4 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s5 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s4 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s5 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s4 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s5 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s4 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s5 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s4 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s5 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s4 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s5 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | ((bits1.s6 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | ((bits1.s7 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s6 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s7 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s6 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s7 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s6 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s7 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s6 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s7 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s6 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s7 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s6 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s7 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s6 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s7 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, bits1, scale, minv, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | ((bits1.s0 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s0 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s0 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s0 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s0 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s0 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s0 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s0 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | ((bits1.s1 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s1 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s1 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s1 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s1 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s1 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s1 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s1 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | ((bits1.s2 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s2 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s2 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s2 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s2 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s2 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s2 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s2 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | ((bits1.s3 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s3 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s3 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s3 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s3 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s3 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s3 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s3 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s7; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, bits1, scale, minv, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | ((bits1.s4 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s4 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s4 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s4 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s4 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s4 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s4 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s4 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | ((bits1.s5 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s5 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s5 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s5 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s5 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s5 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s5 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s5 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | ((bits1.s6 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s6 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s6 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s6 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s6 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s6 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s6 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s6 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | ((bits1.s7 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s7 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s7 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s7 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s7 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s7 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s7 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s7 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s7; \ + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_gemv_noshuffle_q5_k_f32( + read_only image1d_buffer_t src0_q, + read_only image1d_buffer_t src0_qh, + global half2 * src0_d, + global half2 * src0_m, + global uchar * src0_s, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2) +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + uint K = ne00; + uint M = ne01; + + uint LINE_STRIDE_A = M / 2; + uint BLOCK_STRIDE_A = NSUBGROUPS * M; + + uint LINE_STRIDE_A_QH = M / 2; + uint BLOCK_STRIDE_A_QH = NSUBGROUPS * M / 2; + uint scales_per_row = (K / QK_K) * 12; + + private uint4 regA; + private ushort4 regH; + private half2 regS; + private half2 regM; + private float8 regB; + + private float2 totalSum = (float2)(0.0f); + + for (uint k = groupId; k < (K / 32); k += NSUBGROUPS) { + uint sb = k / 8; + uint j = k % 8; + + half2 d = src0_d[gid + sb * LINE_STRIDE_A]; + half2 dm = src0_m[gid + sb * LINE_STRIDE_A]; + + global const uchar * sc0 = src0_s + 2 * gid * scales_per_row + sb * 12; + global const uchar * sc1 = src0_s + (2 * gid + 1) * scales_per_row + sb * 12; + + uchar sv0, mn0, sv1, mn1; + get_scale_min_k4(j, sc0, &sv0, &mn0, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(j, sc1, &sv1, &mn1, mask_d6, mask_d4, mask_hi2); + + regS = convert_half2(convert_float2(d) * convert_float2((uchar2)(sv0, sv1))); + regM = convert_half2(convert_float2(dm) * convert_float2((uchar2)(mn0, mn1))); + + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + regH.s0 = as_ushort(read_imageh(src0_qh, (gid + k * BLOCK_STRIDE_A_QH + LINE_STRIDE_A_QH * 0)).x); + regH.s1 = as_ushort(read_imageh(src0_qh, (gid + k * BLOCK_STRIDE_A_QH + LINE_STRIDE_A_QH * 1)).x); + regH.s2 = as_ushort(read_imageh(src0_qh, (gid + k * BLOCK_STRIDE_A_QH + LINE_STRIDE_A_QH * 2)).x); + regH.s3 = as_ushort(read_imageh(src0_qh, (gid + k * BLOCK_STRIDE_A_QH + LINE_STRIDE_A_QH * 3)).x); + + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), as_uchar8(regH), regS, regM, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), as_uchar8(regH), regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), as_uchar8(regH), regS, regM, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), as_uchar8(regH), regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + } + + // reduction in local memory, assumes #wave=4 + local float2 reduceLM[SUBGROUP_SIZE * 3]; + if (groupId == 1) { + reduceLM[SUBGROUP_SIZE * 0 + slid] = totalSum; + } + if (groupId == 2) { + reduceLM[SUBGROUP_SIZE * 1 + slid] = totalSum; + } + if (groupId == 3) { + reduceLM[SUBGROUP_SIZE * 2 + slid] = totalSum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 0 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 1 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 2 + slid]; + } + + // 2 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(totalSum, 0, &(dst[gid * 2])); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl new file mode 100644 index 00000000000..6f89cf968b9 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl @@ -0,0 +1,293 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define NSUBGROUPS 4 +#define SUBGROUP_SIZE 64 + +#define dequantize_block_acc_bcast_8_hi(total_sum, bits4, bits2, scale_d, scale_s, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x000F) ) | ((bits2.s0 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s0; \ + total_sum.s0 += ((float)(((bits4.s0 & 0x00F0) >> 4) | ((bits2.s0 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s1; \ + total_sum.s0 += ((float)(((bits4.s0 & 0x0F00) >> 8) | ((bits2.s0 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s2; \ + total_sum.s0 += ((float)(((bits4.s0 & 0xF000) >> 12) | ((bits2.s0 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s3; \ + total_sum.s0 += ((float)(((bits4.s2 & 0x000F) ) | ((bits2.s2 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s4; \ + total_sum.s0 += ((float)(((bits4.s2 & 0x00F0) >> 4) | ((bits2.s2 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s5; \ + total_sum.s0 += ((float)(((bits4.s2 & 0x0F00) >> 8) | ((bits2.s2 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s6; \ + total_sum.s0 += ((float)(((bits4.s2 & 0xF000) >> 12) | ((bits2.s2 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s7; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x000F) ) | ((bits2.s1 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s0; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x00F0) >> 4) | ((bits2.s1 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s1; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x0F00) >> 8) | ((bits2.s1 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s2; \ + total_sum.s1 += ((float)(((bits4.s1 & 0xF000) >> 12) | ((bits2.s1 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s3; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x000F) ) | ((bits2.s3 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s4; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x00F0) >> 4) | ((bits2.s3 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s5; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x0F00) >> 8) | ((bits2.s3 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s6; \ + total_sum.s1 += ((float)(((bits4.s3 & 0xF000) >> 12) | ((bits2.s3 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x000F) ) | ((bits2.s4 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s0; \ + total_sum.s0 += ((float)(((bits4.s4 & 0x00F0) >> 4) | ((bits2.s4 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s1; \ + total_sum.s0 += ((float)(((bits4.s4 & 0x0F00) >> 8) | ((bits2.s4 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s2; \ + total_sum.s0 += ((float)(((bits4.s4 & 0xF000) >> 12) | ((bits2.s4 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s3; \ + total_sum.s0 += ((float)(((bits4.s6 & 0x000F) ) | ((bits2.s6 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s4; \ + total_sum.s0 += ((float)(((bits4.s6 & 0x00F0) >> 4) | ((bits2.s6 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s5; \ + total_sum.s0 += ((float)(((bits4.s6 & 0x0F00) >> 8) | ((bits2.s6 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s6; \ + total_sum.s0 += ((float)(((bits4.s6 & 0xF000) >> 12) | ((bits2.s6 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s7; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x000F) ) | ((bits2.s5 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s0; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x00F0) >> 4) | ((bits2.s5 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s1; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x0F00) >> 8) | ((bits2.s5 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s2; \ + total_sum.s1 += ((float)(((bits4.s5 & 0xF000) >> 12) | ((bits2.s5 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s3; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x000F) ) | ((bits2.s7 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s4; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x00F0) >> 4) | ((bits2.s7 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s5; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x0F00) >> 8) | ((bits2.s7 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s6; \ + total_sum.s1 += ((float)(((bits4.s7 & 0xF000) >> 12) | ((bits2.s7 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s7; \ + +#define dequantize_block_acc_bcast_8_lo(total_sum, bits4, bits2, scale_d, scale_s, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x000F) ) | ((bits2.s0 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s0; \ + total_sum.s0 += ((float)(((bits4.s0 & 0x00F0) >> 4) | ((bits2.s0 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s1; \ + total_sum.s0 += ((float)(((bits4.s0 & 0x0F00) >> 8) | ((bits2.s0 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s2; \ + total_sum.s0 += ((float)(((bits4.s0 & 0xF000) >> 12) | ((bits2.s0 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s3; \ + total_sum.s0 += ((float)(((bits4.s2 & 0x000F) ) | ((bits2.s2 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s4; \ + total_sum.s0 += ((float)(((bits4.s2 & 0x00F0) >> 4) | ((bits2.s2 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s5; \ + total_sum.s0 += ((float)(((bits4.s2 & 0x0F00) >> 8) | ((bits2.s2 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s6; \ + total_sum.s0 += ((float)(((bits4.s2 & 0xF000) >> 12) | ((bits2.s2 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s7; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x000F) ) | ((bits2.s1 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s0; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x00F0) >> 4) | ((bits2.s1 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s1; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x0F00) >> 8) | ((bits2.s1 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s2; \ + total_sum.s1 += ((float)(((bits4.s1 & 0xF000) >> 12) | ((bits2.s1 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s3; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x000F) ) | ((bits2.s3 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s4; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x00F0) >> 4) | ((bits2.s3 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s5; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x0F00) >> 8) | ((bits2.s3 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s6; \ + total_sum.s1 += ((float)(((bits4.s3 & 0xF000) >> 12) | ((bits2.s3 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x000F) ) | ((bits2.s4 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s0; \ + total_sum.s0 += ((float)(((bits4.s4 & 0x00F0) >> 4) | ((bits2.s4 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s1; \ + total_sum.s0 += ((float)(((bits4.s4 & 0x0F00) >> 8) | ((bits2.s4 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s2; \ + total_sum.s0 += ((float)(((bits4.s4 & 0xF000) >> 12) | ((bits2.s4 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s3; \ + total_sum.s0 += ((float)(((bits4.s6 & 0x000F) ) | ((bits2.s6 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s4; \ + total_sum.s0 += ((float)(((bits4.s6 & 0x00F0) >> 4) | ((bits2.s6 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s5; \ + total_sum.s0 += ((float)(((bits4.s6 & 0x0F00) >> 8) | ((bits2.s6 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s6; \ + total_sum.s0 += ((float)(((bits4.s6 & 0xF000) >> 12) | ((bits2.s6 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s7; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x000F) ) | ((bits2.s5 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s0; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x00F0) >> 4) | ((bits2.s5 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s1; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x0F00) >> 8) | ((bits2.s5 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s2; \ + total_sum.s1 += ((float)(((bits4.s5 & 0xF000) >> 12) | ((bits2.s5 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s3; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x000F) ) | ((bits2.s7 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s4; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x00F0) >> 4) | ((bits2.s7 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s5; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x0F00) >> 8) | ((bits2.s7 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s6; \ + total_sum.s1 += ((float)(((bits4.s7 & 0xF000) >> 12) | ((bits2.s7 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s7; \ + +#define dequantize_block_acc_bcast_1_hi(total_sum, bits4, bits2, scale_d, scale_s, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x000F) ) | ((bits2.s0 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x000F) ) | ((bits2.s1 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x00F0) >> 4) | ((bits2.s0 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x00F0) >> 4) | ((bits2.s1 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x0F00) >> 8) | ((bits2.s0 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x0F00) >> 8) | ((bits2.s1 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sum.s0 += ((float)(((bits4.s0 & 0xF000) >> 12) | ((bits2.s0 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0xF000) >> 12) | ((bits2.s1 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sum.s0 += ((float)(((bits4.s2 & 0x000F) ) | ((bits2.s2 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x000F) ) | ((bits2.s3 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sum.s0 += ((float)(((bits4.s2 & 0x00F0) >> 4) | ((bits2.s2 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x00F0) >> 4) | ((bits2.s3 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sum.s0 += ((float)(((bits4.s2 & 0x0F00) >> 8) | ((bits2.s2 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x0F00) >> 8) | ((bits2.s3 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sum.s0 += ((float)(((bits4.s2 & 0xF000) >> 12) | ((bits2.s2 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0xF000) >> 12) | ((bits2.s3 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x000F) ) | ((bits2.s4 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x000F) ) | ((bits2.s5 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x00F0) >> 4) | ((bits2.s4 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x00F0) >> 4) | ((bits2.s5 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x0F00) >> 8) | ((bits2.s4 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x0F00) >> 8) | ((bits2.s5 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sum.s0 += ((float)(((bits4.s4 & 0xF000) >> 12) | ((bits2.s4 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0xF000) >> 12) | ((bits2.s5 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sum.s0 += ((float)(((bits4.s6 & 0x000F) ) | ((bits2.s6 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x000F) ) | ((bits2.s7 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sum.s0 += ((float)(((bits4.s6 & 0x00F0) >> 4) | ((bits2.s6 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x00F0) >> 4) | ((bits2.s7 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sum.s0 += ((float)(((bits4.s6 & 0x0F00) >> 8) | ((bits2.s6 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x0F00) >> 8) | ((bits2.s7 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sum.s0 += ((float)(((bits4.s6 & 0xF000) >> 12) | ((bits2.s6 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0xF000) >> 12) | ((bits2.s7 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + +#define dequantize_block_acc_bcast_1_lo(total_sum, bits4, bits2, scale_d, scale_s, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x000F) ) | ((bits2.s0 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x000F) ) | ((bits2.s1 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x00F0) >> 4) | ((bits2.s0 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x00F0) >> 4) | ((bits2.s1 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x0F00) >> 8) | ((bits2.s0 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x0F00) >> 8) | ((bits2.s1 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sum.s0 += ((float)(((bits4.s0 & 0xF000) >> 12) | ((bits2.s0 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0xF000) >> 12) | ((bits2.s1 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sum.s0 += ((float)(((bits4.s2 & 0x000F) ) | ((bits2.s2 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x000F) ) | ((bits2.s3 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sum.s0 += ((float)(((bits4.s2 & 0x00F0) >> 4) | ((bits2.s2 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x00F0) >> 4) | ((bits2.s3 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sum.s0 += ((float)(((bits4.s2 & 0x0F00) >> 8) | ((bits2.s2 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x0F00) >> 8) | ((bits2.s3 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sum.s0 += ((float)(((bits4.s2 & 0xF000) >> 12) | ((bits2.s2 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0xF000) >> 12) | ((bits2.s3 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x000F) ) | ((bits2.s4 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x000F) ) | ((bits2.s5 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x00F0) >> 4) | ((bits2.s4 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x00F0) >> 4) | ((bits2.s5 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x0F00) >> 8) | ((bits2.s4 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x0F00) >> 8) | ((bits2.s5 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sum.s0 += ((float)(((bits4.s4 & 0xF000) >> 12) | ((bits2.s4 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0xF000) >> 12) | ((bits2.s5 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sum.s0 += ((float)(((bits4.s6 & 0x000F) ) | ((bits2.s6 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x000F) ) | ((bits2.s7 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sum.s0 += ((float)(((bits4.s6 & 0x00F0) >> 4) | ((bits2.s6 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x00F0) >> 4) | ((bits2.s7 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sum.s0 += ((float)(((bits4.s6 & 0x0F00) >> 8) | ((bits2.s6 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x0F00) >> 8) | ((bits2.s7 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sum.s0 += ((float)(((bits4.s6 & 0xF000) >> 12) | ((bits2.s6 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0xF000) >> 12) | ((bits2.s7 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + +#if defined(ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_gemv_noshuffle_q6_K_f32( + read_only image1d_buffer_t src0_ql, + read_only image1d_buffer_t src0_qh, + global half2 * src0_s, + global half2 * src0_d, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int ne00, + int ne01 +) { + int grp = get_local_id(1); + int gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + int nb = ne00 / 32; + + uint4 reg_a_l; + ushort4 reg_a_h; + half2 reg_d; + char4 reg_s; + float8 reg_b; + + float2 total_sum = 0.0f; + + int line_stride_a = ne01 / 2; + int block_stride_a = NSUBGROUPS * ne01; + + for (int k = grp; k < nb; k += NSUBGROUPS) { + reg_d = src0_d[gid + k/8 * line_stride_a]; + reg_s = as_char4(src0_s[gid + k * line_stride_a]); + + if (slid < 4) { + reg_b.s0123 = read_imagef(src1, 0 + slid*2 + k*8); + reg_b.s4567 = read_imagef(src1, 1 + slid*2 + k*8); + } + + reg_a_l.s0 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*0).x; + reg_a_l.s1 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*1).x; + reg_a_l.s2 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*2).x; + reg_a_l.s3 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*3).x; + + reg_a_h.s0 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*0).x); + reg_a_h.s1 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*1).x); + reg_a_h.s2 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*2).x); + reg_a_h.s3 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*3).x); + +#ifdef VECTOR_SUB_GROUP_BROADCAT + dequantize_block_acc_bcast_8_hi(total_sum, as_ushort8(reg_a_l), as_uchar8(reg_a_h), reg_d, reg_s, reg_b); +#else + dequantize_block_acc_bcast_1_hi(total_sum, as_ushort8(reg_a_l), as_uchar8(reg_a_h), reg_d, reg_s, reg_b); +#endif // VECTOR_SUB_GROUP_BROADCAT + + reg_a_l.s0 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*4).x; + reg_a_l.s1 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*5).x; + reg_a_l.s2 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*6).x; + reg_a_l.s3 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*7).x; + + reg_a_h.s0 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*4).x); + reg_a_h.s1 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*5).x); + reg_a_h.s2 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*6).x); + reg_a_h.s3 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*7).x); + +#ifdef VECTOR_SUB_GROUP_BROADCAT + dequantize_block_acc_bcast_8_lo(total_sum, as_ushort8(reg_a_l), as_uchar8(reg_a_h), reg_d, reg_s, reg_b); +#else + dequantize_block_acc_bcast_1_lo(total_sum, as_ushort8(reg_a_l), as_uchar8(reg_a_h), reg_d, reg_s, reg_b); +#endif // VECTOR_SUB_GROUP_BROADCAT + } + + local float2 reduce_lm[SUBGROUP_SIZE * 3]; + if (grp == 1) { + reduce_lm[SUBGROUP_SIZE*0 + slid] = total_sum; + } + if (grp == 2) { + reduce_lm[SUBGROUP_SIZE*1 + slid] = total_sum; + } + if (grp == 3) { + reduce_lm[SUBGROUP_SIZE*2 + slid] = total_sum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + if (grp == 0) { + total_sum += reduce_lm[SUBGROUP_SIZE*0 + slid]; + } + if (grp == 0) { + total_sum += reduce_lm[SUBGROUP_SIZE*1 + slid]; + } + if (grp == 0) { + total_sum += reduce_lm[SUBGROUP_SIZE*2 + slid]; + } + + if (grp == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(total_sum, 0, &(dst[gid * 2])); + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl new file mode 100644 index 00000000000..11ff7f8d9dc --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl @@ -0,0 +1,171 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 8 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +constant float kvalues_iq4nl[16] = { + -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, + 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f +}; + +kernel void kernel_mul_mm_iq4_nl_f32_l4_lm( + global uchar4 * src0_q, + global half * src0_d, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 4; + int iqs = idx % 4; + + float d = (float)src0_d[ib]; + global uchar4 * qs = src0_q + ib*4 + iqs; + uchar4 q = *qs; + // IQ4_NL: use lookup table instead of linear (nibble - 8) + float4 v1 = (float4)(kvalues_iq4nl[(q.s0 )&0x0F], kvalues_iq4nl[(q.s1 )&0x0F], + kvalues_iq4nl[(q.s2 )&0x0F], kvalues_iq4nl[(q.s3 )&0x0F])*d; + float4 v2 = (float4)(kvalues_iq4nl[(q.s0>>4)&0x0F], kvalues_iq4nl[(q.s1>>4)&0x0F], + kvalues_iq4nl[(q.s2>>4)&0x0F], kvalues_iq4nl[(q.s3>>4)&0x0F])*d; + + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3; + } else { + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl new file mode 100644 index 00000000000..2235b1ae838 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl @@ -0,0 +1,179 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 4 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q4_k_f32_l4_lm( + global uchar4 * src0_q, + global uchar * src0_s, + global half * src0_d, + global half * src0_dm, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 64; + int iqs = (idx % 64) * 2; + + int n = iqs / 32; + int b = (iqs % 32) / 16; + int is = 2 * n + b; + int qsi = n * 32 + (iqs % 16) * 2; + + char * scales = src0_s + ib * 12; + + int scidx0 = (is < 4) ? is : (is + 4); + int scidx1 = (is < 4) ? is : (is - 4); + int scidxmask1 = (is < 4) ? 0x30 : 0xC0; + int scidxshift1 = (is < 4) ? 0 : 2; + int mbidx0 = is + 4; + int mbidx1 = (is < 4) ? is + 4 : is; + int mbidxmask0 = (is < 4) ? 0xF : 0xF0; + int mbidxshift0 = (is < 4) ? 0 : 4; + int mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + int mbidxshift1 = (is < 4) ? 0 : 2; + + uchar sc = (scales[scidx0] & 0xF) | ((scales[scidx1] & scidxmask1) >> scidxshift1); + uchar mbyte = ((scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((scales[mbidx1] & mbidxmask1) >> mbidxshift1); + + float d = (float)src0_d[ib] * (float)sc; + float m = -(float)src0_dm[ib] * (float)mbyte; + + global uchar4 * qs = src0_q + ib*32 + (qsi >> 2); + uchar4 q = *qs; + float4 v1 = (convert_float4((uchar4)((q.s0 >> (b * 4))&0x0F, (q.s1 >> (b * 4))&0x0F, (q.s2 >> (b * 4))&0x0F, (q.s3 >> (b * 4))&0x0F)))*d + m; + + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = v1.s3; + } else { + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl new file mode 100644 index 00000000000..8e191f57e83 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl @@ -0,0 +1,192 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 4 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q5_k_f32_l4_lm( + global uchar4 * src0_q, + global uchar * src0_qh, + global uchar * src0_s, + global half * src0_d, + global half * src0_dm, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 64; + int iqs = (idx % 64) * 2; + + int n = iqs / 32; + int b = (iqs % 32) / 16; + int is = 2 * n + b; + int qsi = n * 32 + (iqs % 16) * 2; + + global uchar * scales = src0_s + ib * 12; + + int scidx0 = (is < 4) ? is : (is + 4); + int scidx1 = (is < 4) ? is : (is - 4); + int scidxmask1 = (is < 4) ? 0x30 : 0xC0; + int scidxshift1 = (is < 4) ? 0 : 2; + int mbidx0 = is + 4; + int mbidx1 = (is < 4) ? is + 4 : is; + int mbidxmask0 = (is < 4) ? 0xF : 0xF0; + int mbidxshift0 = (is < 4) ? 0 : 4; + int mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + int mbidxshift1 = (is < 4) ? 0 : 2; + + uchar sc = (scales[scidx0] & 0xF) | ((scales[scidx1] & scidxmask1) >> scidxshift1); + uchar mbyte = ((scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((scales[mbidx1] & mbidxmask1) >> mbidxshift1); + + float d = (float)src0_d[ib] * (float)sc; + float m = -(float)src0_dm[ib] * (float)mbyte; + + int qh_base = (iqs % 16) * 2; + int bit_pos = 2*n + b; + uchar h0 = (src0_qh[ib*32 + qh_base + 0] >> bit_pos) & 1; + uchar h1 = (src0_qh[ib*32 + qh_base + 1] >> bit_pos) & 1; + uchar h2 = (src0_qh[ib*32 + qh_base + 2] >> bit_pos) & 1; + uchar h3 = (src0_qh[ib*32 + qh_base + 3] >> bit_pos) & 1; + + global uchar4 * qs = src0_q + ib*32 + (qsi >> 2); + uchar4 q = *qs; + float4 v1 = (convert_float4((uchar4)( + ((q.s0 >> (b * 4))&0x0F) | (h0 << 4), + ((q.s1 >> (b * 4))&0x0F) | (h1 << 4), + ((q.s2 >> (b * 4))&0x0F) | (h2 << 4), + ((q.s3 >> (b * 4))&0x0F) | (h3 << 4) + )))*d + m; + + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = v1.s3; + } else { + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl new file mode 100644 index 00000000000..a6a325cd729 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl @@ -0,0 +1,164 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_NL 32 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +constant float kvalues_iq4nl[16] = { + -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, + 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f +}; + +//------------------------------------------------------------------------------ +// block_iq4_nl +//------------------------------------------------------------------------------ +struct block_iq4_nl +{ + half d; + uint8_t qs[QK4_NL / 2]; +}; + +//------------------------------------------------------------------------------ +// mul_vec_q_n_f32 +//------------------------------------------------------------------------------ +// Compute inner product between half a block of iq4_nl and 16 floats (yl). +// il indicates where the quants begin (0 or 8). +inline float block_iq4_nl_dot_y( + global struct block_iq4_nl * qb_curr, + private float * yl, + int il +) { + float d = qb_curr->d; + float acc = 0.f; + global uchar * qs = qb_curr->qs + il; + for (int i = 0; i < 8; ++i) { + acc += yl[i] * kvalues_iq4nl[qs[i] & 0x0F]; + acc += yl[i+8] * kvalues_iq4nl[qs[i] >> 4]; + } + return d * acc; +} + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup group works on 4 rows +#define N_SUBGROUP 1 // number of subgroups in a thread group +#define N_SUBGROUP_SIZE 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SUBGROUP 1 +#define N_SUBGROUP_SIZE 64 +#endif + +inline void mul_vec_q_n_f32( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + + const ulong nb = ne00/QK4_NL; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SUBGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_iq4_nl * x = (global struct block_iq4_nl *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[16]; // src1 vector cache + float sumf[N_DST]={0.f}; + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK4_NL + il; + + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += N_SUBGROUP_SIZE/2) { + for (int i = 0; i < 8; ++i) { + yl[i] = yb[i]; + yl[i+8] = yb[i+16]; + } + + for (int row = 0; row < N_DST; row++) { + sumf[row] += block_iq4_nl_dot_y(x+ib+row*nb, yl, il); + } + + yb += QK4_NL * (N_SUBGROUP_SIZE/2); + } + + float tot[N_DST] = { + sub_group_reduce_add(sumf[0]), sub_group_reduce_add(sumf[1]), + sub_group_reduce_add(sumf[2]), sub_group_reduce_add(sumf[3])}; + for (int row = 0; row < N_DST; ++row) { + if (get_sub_group_local_id() == 0 && first_row + row < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot[row]; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_iq4_nl_f32( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl new file mode 100644 index 00000000000..8c5b3f52e42 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl @@ -0,0 +1,202 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_NL 32 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +constant float kvalues_iq4nl[16] = { + -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, + 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f +}; + +//------------------------------------------------------------------------------ +// block_iq4_nl +//------------------------------------------------------------------------------ +struct block_iq4_nl +{ + half d; + uint8_t qs[QK4_NL / 2]; +}; + +// Compute dot product between half a block of iq4_nl quants and activations. +// x points to the quant bytes, dh points to the scale. +// yl has 16 activation values: [0..7] for low nibbles, [8..15] for high nibbles. +// il indicates offset into the quant bytes (0 or 8). +inline float block_iq4_nl_dot_y_flat( + global uchar * x, + global half * dh, + private float * yl, + int il +) { + float d = *dh; + global uchar * qs = x + il; + float acc = 0.f; + for (int i = 0; i < 8; ++i) { + acc += yl[i] * kvalues_iq4nl[qs[i] & 0x0F]; + acc += yl[i+8] * kvalues_iq4nl[qs[i] >> 4]; + } + return d * acc; +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 8 // each subgroup works on 8 rows +#define N_SUBGROUP 1 // number of subgroups in a thread group +#define N_SUBGROUP_SIZE 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 8 +#define N_SUBGROUP 1 +#define N_SUBGROUP_SIZE 64 +#endif + +inline void mul_vec_q_n_f32_8x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_NL; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SUBGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_NL/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_NL/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[16]; + float8 sumf = 0.f; + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix*QK4_NL + il; + + for (int ib = ix; ib < nb; ib += N_SUBGROUP_SIZE/2) { + for (int i = 0; i < 8; ++i) { + yl[i] = yb[i]; + yl[i+8] = yb[i+16]; + } + + sumf.s0 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 0*nb*QK4_NL/2, d + ib + 0*nb, yl, il); + sumf.s1 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 1*nb*QK4_NL/2, d + ib + 1*nb, yl, il); + sumf.s2 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 2*nb*QK4_NL/2, d + ib + 2*nb, yl, il); + sumf.s3 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 3*nb*QK4_NL/2, d + ib + 3*nb, yl, il); + + sumf.s4 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 4*nb*QK4_NL/2, d + ib + 4*nb, yl, il); + sumf.s5 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 5*nb*QK4_NL/2, d + ib + 5*nb, yl, il); + sumf.s6 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 6*nb*QK4_NL/2, d + ib + 6*nb, yl, il); + sumf.s7 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 7*nb*QK4_NL/2, d + ib + 7*nb, yl, il); + + yb += QK4_NL * (N_SUBGROUP_SIZE/2); + } + + float8 tot = (float8)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), + sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), + sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_iq4_nl_f32_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl new file mode 100644 index 00000000000..d92fb968904 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl @@ -0,0 +1,196 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +//------------------------------------------------------------------------------ +// block_q4_K +//------------------------------------------------------------------------------ +#define QK_K 256 +#define BLOCK_Q4K_SIZE 144 +#define K_SCALE_SIZE 12 + +// 8 blocks of 32 elements each +// weight is represented as x = a * q + b +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + + uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uchar qs[QK_K/2]; // 4-bit quants +} block_q4_K; + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // number of rows each SIMD group works on +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // SIMD group size +#elif defined (ADRENO_GPU) +#define N_DST 16 +#define N_SIMDGROUP 2 +#define N_SIMDWIDTH 64 +#endif + +#undef BLOCK_STRIDE +// number of (super) blocks each subgroup processes +// each thread in a subgroup processes a block (32 weights) +#define BLOCK_STRIDE (N_SIMDWIDTH/8) + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q4_K_f32_flat( + global uchar * src0_q, + global uchar * src0_s, + global half * src0_d, + global half * src0_dm, + global char * src1, + int offset1, + global char * dst, + int offsetd, + int ne00, + int ne01, + ulong nb01, + ulong nb02, + ulong nb03, + int ne12, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = src1 + offset1; + dst = dst + offsetd; + + ushort kmask1 = 0x3f3f; + ushort kmask2 = 0x0f0f; + ushort kmask3 = 0xc0c0; + + int ix = get_sub_group_local_id()/8; + int it = get_sub_group_local_id()%8; + int iq = it/4; + int ir = it%4; + + int nb = ne00/QK_K; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + int offset_src0 = (first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03)/BLOCK_Q4K_SIZE; + uint blk = nb01 / BLOCK_Q4K_SIZE; + global uchar * blk_q = (global uchar *)src0_q + offset_src0*(QK_K/2); + global uchar * blk_s = (global uchar *)src0_s + offset_src0*K_SCALE_SIZE; + global half * blk_d = (global half *)src0_d + offset_src0; + global half * blk_dm = (global half *)src0_dm + offset_src0; + + int offset_src1 = r1*nb11 + (i12)*nb12 + (i13)*nb13; + global float * y = (global float *)(src1 + offset_src1); + + float yl[16]; + float yh[16]; + float sumf[N_DST] = {0.f}; + float all_sum; + + global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; + + ushort sc16[4]; + uchar * sc8 = (uchar *)sc16; + + for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) { + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = y4[i+0]; + sumy.s0 += yl[i+0]; + + yl[i+8] = y4[i+32]; + sumy.s1 += yl[i+8]; + + yh[i+0] = y4[i+128]; + sumy.s2 += yh[i+0]; + + yh[i+8] = y4[i+160]; + sumy.s3 += yh[i+8]; + } + + global ushort * q1 = (global ushort *)(blk_q + ib * (QK_K/2)) + (16 * iq + 4 * ir); + global ushort * sc = (global ushort *)(blk_s + ib * K_SCALE_SIZE) + iq; + global half * d = blk_d + ib; + global half * dm = blk_dm + ib; + + for (int row = 0; row < N_DST; row++) { + sc16[0] = sc[0] & kmask1; + sc16[1] = sc[2] & kmask1; + sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); + sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); + + global ushort * q2 = q1 + 32; + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1.s0 += yl[i+0] * (q1[i/2] & 0x000F); + acc1.s1 += yl[i+1] * (q1[i/2] & 0x0F00); + acc1.s2 += yl[i+8] * (q1[i/2] & 0x00F0); + acc1.s3 += yl[i+9] * (q1[i/2] & 0xF000); + acc2.s0 += yh[i+0] * (q2[i/2] & 0x000F); + acc2.s1 += yh[i+1] * (q2[i/2] & 0x0F00); + acc2.s2 += yh[i+8] * (q2[i/2] & 0x00F0); + acc2.s3 += yh[i+9] * (q2[i/2] & 0xF000); + } + + float dall = *d; + float dmin = *dm; + sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] + + (acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f + + (acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] + + (acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) - + dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]); + + q1 += blk*64; + sc += blk*6; + d += blk; + dm += blk; + } + + y4 += BLOCK_STRIDE * QK_K; + } + + global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0; + + for (int row = 0; row < N_DST; ++row) { + all_sum = sub_group_reduce_add(sumf[row]); + if (first_row + row < ne01) { + if (get_sub_group_local_id() == 0) { + dst_f32[first_row + row] = all_sum; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl new file mode 100644 index 00000000000..b2058abc1b6 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl @@ -0,0 +1,187 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK_K 256 +#define K_SCALE_SIZE 12 + +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uchar qh[QK_K/8]; // quants, high bit (1 bit per value, packed 8 per byte) + uchar qs[QK_K/2]; // quants, low 4 bits (2 values per byte) +} block_q5_K; + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 16 +#elif defined(ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +#define BLOCK_STRIDE (N_SIMDWIDTH/8) + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q5_K_f32( + global char * src0, + int offset0, + global char * src1, + int offset1, + global char * dst, + int offsetd, + int ne00, + int ne01, + ulong nb01, + ulong nb02, + ulong nb03, + int ne12, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + ushort kmask1 = 0x3f3f; + ushort kmask2 = 0x0f0f; + ushort kmask3 = 0xc0c0; + + int ix = get_sub_group_local_id()/8; // super block index + int it = get_sub_group_local_id()%8; // block index (inside super block) + int iq = it/4; // 0 or 1 - first or second half of the super block + int ir = it%4; // 0...3 - block index in the half super block + + int nb = ne00/QK_K; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + int offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + int offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global block_q5_K * x = (global block_q5_K *) (src0 + offset_src0); + global float * y = (global float *) (src1 + offset_src1); + + float yl[16]; + float yh[16]; + float sumf[N_DST] = {0.f}; + float all_sum; + + global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; + + uchar u1_lo = (uchar)(1 << (2*iq)); + uchar u2_lo = (uchar)(2 << (2*iq)); + uchar u1_hi = (uchar)(1 << (2*iq + 4)); + uchar u2_hi = (uchar)(2 << (2*iq + 4)); + + ushort sc16[4]; + uchar * sc8 = (uchar *)sc16; + + for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) { + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = y4[i+0]; + sumy.s0 += yl[i+0]; + + yl[i+8] = y4[i+32]; + sumy.s1 += yl[i+8]; + + yh[i+0] = y4[i+128]; + sumy.s2 += yh[i+0]; + + yh[i+8] = y4[i+160]; + sumy.s3 += yh[i+8]; + } + + global ushort * sc = (global ushort *)x[ib].scales + iq; + global ushort * q1 = (global ushort *)x[ib].qs + 16 * iq + 4 * ir; + global uchar * qh = x[ib].qh + 8 * ir; + global half * dh = &x[ib].d; + + for (int row = 0; row < N_DST; row++) { + sc16[0] = sc[0] & kmask1; + sc16[1] = sc[2] & kmask1; + sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); + sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); + + global ushort * q2 = q1 + 32; + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1.s0 += yl[i+0] * ((q1[i/2] & 0x000F) + (qh[i+0] & u1_lo ? 16.f : 0.f)); + acc1.s1 += yl[i+1] * ((q1[i/2] & 0x0F00) + (qh[i+1] & u1_lo ? 16.f*256.f : 0.f)); + acc1.s2 += yl[i+8] * ((q1[i/2] & 0x00F0) + (qh[i+0] & u2_lo ? 16.f*16.f : 0.f)); + acc1.s3 += yl[i+9] * ((q1[i/2] & 0xF000) + (qh[i+1] & u2_lo ? 16.f*4096.f: 0.f)); + acc2.s0 += yh[i+0] * ((q2[i/2] & 0x000F) + (qh[i+0] & u1_hi ? 16.f : 0.f)); + acc2.s1 += yh[i+1] * ((q2[i/2] & 0x0F00) + (qh[i+1] & u1_hi ? 16.f*256.f : 0.f)); + acc2.s2 += yh[i+8] * ((q2[i/2] & 0x00F0) + (qh[i+0] & u2_hi ? 16.f*16.f : 0.f)); + acc2.s3 += yh[i+9] * ((q2[i/2] & 0xF000) + (qh[i+1] & u2_hi ? 16.f*4096.f: 0.f)); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] + + (acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f + + (acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] + + (acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) - + dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]); + + q1 += nb01/2; + sc += nb01/2; + dh += nb01/2; + qh += nb01; + } + + y4 += BLOCK_STRIDE * QK_K; + } + + global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0; + + for (int row = 0; row < N_DST; ++row) { + all_sum = sub_group_reduce_add(sumf[row]); + if (first_row + row < ne01) { + if (get_sub_group_local_id() == 0) { + dst_f32[first_row + row] = all_sum; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl new file mode 100644 index 00000000000..e353a72be70 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl @@ -0,0 +1,203 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +//------------------------------------------------------------------------------ +// block_q5_K +//------------------------------------------------------------------------------ +#define QK_K 256 +#define BLOCK_Q5K_SIZE 176 +#define K_SCALE_SIZE 12 + +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uchar qh[QK_K/8]; // quants, high bit (1 bit per value, packed 8 per byte) + uchar qs[QK_K/2]; // quants, low 4 bits (2 values per byte) +} block_q5_K; + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 16 +#elif defined(ADRENO_GPU) +#define N_DST 16 +#define N_SIMDGROUP 2 +#define N_SIMDWIDTH 64 +#endif + +#undef BLOCK_STRIDE +// number of (super) blocks each subgroup processes +// each thread in a subgroup processes a block (32 weights) +#define BLOCK_STRIDE (N_SIMDWIDTH/8) + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q5_K_f32_flat( + global uchar * src0_q, + global uchar * src0_qh, + global uchar * src0_s, + global half * src0_d, + global half * src0_dm, + global char * src1, + int offset1, + global char * dst, + int offsetd, + int ne00, + int ne01, + ulong nb01, + ulong nb02, + ulong nb03, + int ne12, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = src1 + offset1; + dst = dst + offsetd; + + ushort kmask1 = 0x3f3f; + ushort kmask2 = 0x0f0f; + ushort kmask3 = 0xc0c0; + + int ix = get_sub_group_local_id()/8; + int it = get_sub_group_local_id()%8; + int iq = it/4; + int ir = it%4; + + int nb = ne00/QK_K; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + int offset_src0 = (first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03)/BLOCK_Q5K_SIZE; + uint blk = nb01 / BLOCK_Q5K_SIZE; + global uchar * blk_q = (global uchar *)src0_q + offset_src0*(QK_K/2); + global uchar * blk_qh = (global uchar *)src0_qh + offset_src0*(QK_K/8); + global uchar * blk_s = (global uchar *)src0_s + offset_src0*K_SCALE_SIZE; + global half * blk_d = (global half *)src0_d + offset_src0; + global half * blk_dm = (global half *)src0_dm + offset_src0; + + int offset_src1 = r1*nb11 + (i12)*nb12 + (i13)*nb13; + global float * y = (global float *)(src1 + offset_src1); + + float yl[16]; + float yh[16]; + float sumf[N_DST] = {0.f}; + float all_sum; + + global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; + + uchar u1_lo = (uchar)(1 << (2*iq)); + uchar u2_lo = (uchar)(2 << (2*iq)); + uchar u1_hi = (uchar)(1 << (2*iq + 4)); + uchar u2_hi = (uchar)(2 << (2*iq + 4)); + + ushort sc16[4]; + uchar * sc8 = (uchar *)sc16; + + for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) { + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = y4[i+0]; + sumy.s0 += yl[i+0]; + + yl[i+8] = y4[i+32]; + sumy.s1 += yl[i+8]; + + yh[i+0] = y4[i+128]; + sumy.s2 += yh[i+0]; + + yh[i+8] = y4[i+160]; + sumy.s3 += yh[i+8]; + } + + global ushort * q1 = (global ushort *)(blk_q + ib * (QK_K/2)) + (16 * iq + 4 * ir); + global uchar * qh = (global uchar *)(blk_qh + ib * (QK_K/8)) + 8 * ir; + global ushort * sc = (global ushort *)(blk_s + ib * K_SCALE_SIZE) + iq; + global half * d = blk_d + ib; + global half * dm = blk_dm + ib; + + for (int row = 0; row < N_DST; row++) { + sc16[0] = sc[0] & kmask1; + sc16[1] = sc[2] & kmask1; + sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); + sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); + + global ushort * q2 = q1 + 32; + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1.s0 += yl[i+0] * ((q1[i/2] & 0x000F) + (qh[i+0] & u1_lo ? 16.f : 0.f)); + acc1.s1 += yl[i+1] * ((q1[i/2] & 0x0F00) + (qh[i+1] & u1_lo ? 16.f*256.f : 0.f)); + acc1.s2 += yl[i+8] * ((q1[i/2] & 0x00F0) + (qh[i+0] & u2_lo ? 16.f*16.f : 0.f)); + acc1.s3 += yl[i+9] * ((q1[i/2] & 0xF000) + (qh[i+1] & u2_lo ? 16.f*4096.f: 0.f)); + acc2.s0 += yh[i+0] * ((q2[i/2] & 0x000F) + (qh[i+0] & u1_hi ? 16.f : 0.f)); + acc2.s1 += yh[i+1] * ((q2[i/2] & 0x0F00) + (qh[i+1] & u1_hi ? 16.f*256.f : 0.f)); + acc2.s2 += yh[i+8] * ((q2[i/2] & 0x00F0) + (qh[i+0] & u2_hi ? 16.f*16.f : 0.f)); + acc2.s3 += yh[i+9] * ((q2[i/2] & 0xF000) + (qh[i+1] & u2_hi ? 16.f*4096.f: 0.f)); + } + + float dall = *d; + float dmin = *dm; + sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] + + (acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f + + (acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] + + (acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) - + dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]); + + q1 += blk*64; + qh += blk*32; + sc += blk*6; + d += blk; + dm += blk; + } + + y4 += BLOCK_STRIDE * QK_K; + } + + global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0; + + for (int row = 0; row < N_DST; ++row) { + all_sum = sub_group_reduce_add(sumf[row]); + if (first_row + row < ne01) { + if (get_sub_group_local_id() == 0) { + dst_f32[first_row + row] = all_sum; + } + } + } +} diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index 0938d2273e9..5095e799849 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -19,7 +19,6 @@ #include #include #include -#include #include #include #include @@ -207,8 +206,22 @@ int GgmlOvDecoder::compute_op_case(const ggml_tensor * node) const { break; } case GGML_OP_ROPE: { + const int mode = node->op_params[2]; + switch (mode) { + case GGML_ROPE_TYPE_NEOX: { + op_case = 0x00010000; + break; + } + case GGML_ROPE_TYPE_IMROPE: { + op_case = 0x00020000; + break; + } + default: + op_case = 0x00000000; + break; + } if (node->src[0]->op == GGML_OP_VIEW) { - op_case = 2; + op_case = (op_case | 0x00000002); } break; } @@ -573,9 +586,6 @@ std::map GgmlOvDecoder::get_kv_param_res_names() const } std::map> GgmlOvDecoder::create_weight_nodes(ggml_cgraph * cgraph, bool naive) { - static std::mutex weights_mutex; - std::lock_guard lock(weights_mutex); - std::map> model_weights; auto * nodes = cgraph->nodes; auto n_nodes = cgraph->n_nodes; diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp index cc3cb4583cd..4140136aca2 100644 --- a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include ov::Core & ov_singleton_core() { @@ -42,11 +43,13 @@ void ggml_openvino_device_config::init() { {"NPUW_DQ", "YES" }, {"NPUW_DQ_FULL", "NO" }, }; - if (cache_dir) { + if (cache_dir && strlen(cache_dir) > 0) { compile_config["NPUW_CACHE_DIR"] = cache_dir; + compile_config.insert(ov::cache_mode(ov::CacheMode::OPTIMIZE_SIZE)); } - } else if (cache_dir) { - ov_singleton_core().set_property(ov::cache_dir(cache_dir)); + } else if (cache_dir && strlen(cache_dir) > 0) { + compile_config.insert(ov::cache_dir(cache_dir)); + compile_config.insert(ov::cache_mode(ov::CacheMode::OPTIMIZE_SIZE)); } // Initialize remote context with queue sharing for GPU @@ -259,10 +262,12 @@ ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_ten layout.weights_size = layout.is_u4 ? (n_elements / 2) : n_elements; int64_t n_blocks = n_elements / layout.weights_per_block; layout.scales_size = n_blocks * sizeof(uint16_t); - // For symmetric quantization, we only need one zp value (not one per block) - // Zero points are stored in U4 or U8 format matching the weight type - size_t n_zp_elements = layout.is_symmetric ? 1 : n_blocks; - layout.zp_size = layout.is_u4 ? ((n_zp_elements + 1) / 2) : n_zp_elements; + // For symmetric quantization, no zp needed (weights stored as signed) + if (layout.is_symmetric) { + layout.zp_size = 0; + } else { + layout.zp_size = layout.is_u4 ? ((n_blocks + 1) / 2) : n_blocks; + } layout.weights_offset = 0; layout.scales_offset = ((layout.weights_size + alignment - 1) / alignment) * alignment; @@ -313,10 +318,12 @@ ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_ten // Scales: F16 per block int64_t n_blocks = n_elements / layout.weights_per_block; layout.scales_size = n_blocks * sizeof(uint16_t); // F16 = 2 bytes - // Zero points: U4 or U8 matching weight type - // For symmetric quantization, we only need one zp value (not one per block) - size_t n_zp_elements = layout.is_symmetric ? 1 : n_blocks; - layout.zp_size = layout.is_u4 ? ((n_zp_elements + 1) / 2) : n_zp_elements; + // For symmetric quantization, no zp needed (weights stored as signed) + if (layout.is_symmetric) { + layout.zp_size = 0; + } else { + layout.zp_size = layout.is_u4 ? ((n_blocks + 1) / 2) : n_blocks; + } // Layout in buffer: [weights | scales | zp] with alignment layout.weights_offset = 0; diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index 0031cb7369f..4f3ebf2536b 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -97,6 +97,8 @@ struct ggml_backend_openvino_buffer_context { ov_buffer = std::make_shared(std::move(usm_tensor)); } else { data = ggml_aligned_malloc(size); + GGML_ASSERT(data); + memset(data, 0, size); ov_buffer = std::make_shared(ov::element::u8, ov::Shape{size}, data); } @@ -143,13 +145,18 @@ static void * ggml_backend_openvino_buffer_get_base(ggml_backend_buffer_t buffer return ctx->data; } +static bool is_stateful_enabled() { + static const auto * stateful = getenv("GGML_OPENVINO_STATEFUL_EXECUTION"); + return stateful && *stateful != '\0' && strcmp(stateful, "0") != 0; +} + static enum ggml_status ggml_backend_openvino_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { // GGML_LOG_DEBUG("%s: buffer usage=%d, tensor name=%s\n", __func__, buffer->usage, tensor->name); ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; // Put kvcache on device memory for GPU (NPU memory is too small even for kvcache) if (strncmp(tensor->name, "cache_", 6) == 0 && !ctx->is_remote && ggml_openvino_get_device_name() == "GPU" && - !getenv("GGML_OPENVINO_STATEFUL_EXECUTION")) { + !is_stateful_enabled()) { GGML_ASSERT(ctx->tensor_extras.empty()); auto device = ctx->device; auto size = ctx->size; @@ -410,6 +417,8 @@ static const ggml_backend_buffer_i ggml_backend_openvino_buffer_interface = { /* .memset_tensor = */ ggml_backend_openvino_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_openvino_buffer_set_tensor, /* .get_tensor = */ ggml_backend_openvino_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_openvino_buffer_cpy_tensor, /* .clear = */ ggml_backend_openvino_buffer_clear, /* .reset = */ NULL, @@ -596,6 +605,14 @@ bool ggml_backend_buft_is_openvino_host(ggml_backend_buffer_type_t buft) { static void ggml_backend_openvino_free(ggml_backend_t backend) { ggml_backend_openvino_context * ctx = (ggml_backend_openvino_context *) backend->context; + + if (ctx->runtime_context) { + auto r_ctx = std::static_pointer_cast(ctx->runtime_context); + if (--r_ctx->backend_count == 0) { + r_ctx->clear_caches(); + } + } + delete ctx; delete backend; } @@ -615,6 +632,8 @@ static const ggml_backend_i ggml_backend_openvino_interface = { /* .free = */ ggml_backend_openvino_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, /* .graph_plan_create = */ NULL, @@ -638,7 +657,12 @@ static ggml_guid_t ggml_backend_openvino_guid(void) { } static std::shared_ptr get_ov_runtime_context_ptr() { - static std::shared_ptr r_ctx = std::make_shared(); + static std::shared_ptr r_ctx = [] { + auto ctx = std::make_shared(); + ctx->device = ggml_openvino_get_device_name(); + ctx->stateful = is_stateful_enabled() && !ggml_openvino_is_npu(); + return ctx; + }(); return r_ctx; } @@ -663,8 +687,7 @@ GGML_BACKEND_API ggml_backend_t ggml_backend_openvino_init(int device) { } std::shared_ptr r_ctx = std::static_pointer_cast(ctx->runtime_context); - r_ctx->device = ggml_openvino_get_device_name(); - r_ctx->stateful = getenv("GGML_OPENVINO_STATEFUL_EXECUTION") && !ggml_openvino_is_npu(); + r_ctx->backend_count++; ggml_backend_t openvino_backend = new ggml_backend{ /* .guid = */ ggml_backend_openvino_guid(), @@ -877,7 +900,7 @@ static bool is_op_unsupported_case(const ggml_tensor * op) { const int32_t * op_params = op->op_params; const int n_dims = op_params[1]; const int mode = op_params[2]; - if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) { + if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_IMROPE) { // GGML_LOG_WARN("OpenVINO backend does not support ROPE with mode %d\n", mode); return true; } @@ -890,14 +913,6 @@ static bool is_op_unsupported_case(const ggml_tensor * op) { // GGML_LOG_WARN("OpenVINO backend does not support ROPE with type %s\n", ggml_type_name(op->type)); return true; } - float freq_scale; - float ext_factor; - memcpy(&freq_scale, op_params + 6, sizeof(float)); - memcpy(&ext_factor, op_params + 7, sizeof(float)); - if (ext_factor != 0.0f) { - // GGML_LOG_WARN("OpenVINO backend does not support ROPE with ext_factor %f != 0.0f\n", ext_factor); - return true; - } if (op->src[0]->op == GGML_OP_VIEW) { if (op->src[0]->view_src->ne[1] != op->src[0]->ne[2]) { // GGML_LOG_WARN( @@ -907,6 +922,12 @@ static bool is_op_unsupported_case(const ggml_tensor * op) { return true; } } + if (mode == GGML_ROPE_TYPE_IMROPE && + (op->src[2] != 0 || ((const float *) op_params)[6] != 1 || ((const float *) op_params)[7] != 0 || + ((const float *) op_params)[8] != 1)) { + // GGML_LOG_WARN("OpenVINO backend does not support IMROPE with freq_factors, freq_scale, ext_factor, and attn_factor\n"); + return true; + } break; } default: @@ -936,6 +957,7 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con // GGML_OP_SOFT_MAX, GGML_OP_SET_ROWS, GGML_OP_FLASH_ATTN_EXT, GGML_OP_CPY}; static const std::set supported_unary_ops{ + GGML_UNARY_OP_GELU, GGML_UNARY_OP_SILU, }; static const std::set supported_glu_ops{ diff --git a/ggml/src/ggml-openvino/ggml-quants.cpp b/ggml/src/ggml-openvino/ggml-quants.cpp index dbf38646ddd..57d66df4f01 100644 --- a/ggml/src/ggml-openvino/ggml-quants.cpp +++ b/ggml/src/ggml-openvino/ggml-quants.cpp @@ -46,6 +46,7 @@ void unpack_32_4(const uint8_t * data, uint8_t * dst) { // Extracts (weight, scales, zp) from Q4_0 tensors. // Data layout is: |16 bit scale|32 x 4bit weights|. +// When zp_arr is empty (symmetric), weights are stored as signed i4 (value - 8). void extract_q4_0_data(const ggml_tensor * tensor, ov::Tensor & weights_arr, ov::Tensor & scales_arr, @@ -55,28 +56,32 @@ void extract_q4_0_data(const ggml_tensor * tensor, auto * data = static_cast(tensor->data); auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * zp = static_cast(zp_arr.data()); - - bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization - // For Q4_0, zero point is always 8 - if (is_scalar_zp) { - zp[0] = 8 | (8 << 4); // Pack two 4-bit values - } + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i4); // Signed i4 path - ov::parallel_for(scales_arr.get_size(), [&](size_t i) { - scales[i] = ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block))); - // For asymmetric quantization, compute per-block zero points - if (!is_scalar_zp) { + if (!is_symmetric) { + auto * zp = static_cast(zp_arr.data()); + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + scales[i] = ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block))); // Pack two 4-bit zero points per byte if (i % 2 == 0) { zp[i / 2] = 8; // Lower nibble } else { zp[i / 2] |= (8 << 4); // Upper nibble } - } - unpack_32_4(data + i * bytes_per_block + 2, weights + i * 16); - }); + unpack_32_4(data + i * bytes_per_block + 2, weights + i * 16); + }); + } else { + // Symmetric: unpack as u4 then convert to i4 by subtracting 8 (XOR each nibble) + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + scales[i] = ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block))); + unpack_32_4(data + i * bytes_per_block + 2, weights + i * 16); + // Convert u4 to i4: subtract 8 from each nibble. XOR 0x88 flips each nibble by 8. + for (int j = 0; j < 16; ++j) { + weights[i * 16 + j] ^= 0x88; + } + }); + } } // Extracts (weight, scales, zp) from Q4_1 tensors. @@ -123,6 +128,7 @@ void extract_q4_1_data(const ggml_tensor * tensor, // Extracts (weight, scales, zp) from Q8_0 tensors. // Data layout is: |16 bit scale|32 x 8bit weights|. +// When zp_arr is empty (symmetric), weights are stored as signed i8 directly. void extract_q8_0_data(const ggml_tensor * tensor, ov::Tensor & weights_arr, ov::Tensor & scales_arr, @@ -133,29 +139,30 @@ void extract_q8_0_data(const ggml_tensor * tensor, auto * data = static_cast(tensor->data); auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * zp = static_cast(zp_arr.data()); - - bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization - // For Q8_0, zero point is always 128 - if (is_scalar_zp) { - zp[0] = 128; - } + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i8); // Signed i8 path - ov::parallel_for(scales_arr.get_size(), [&](size_t i) { - uint8_t * block_data = data + i * bytes_per_block; - scales[i] = ov::float16::from_bits(*(uint16_t *) block_data); - // For asymmetric quantization, store per-block zero points - if (!is_scalar_zp) { + if (!is_symmetric) { + auto * zp = static_cast(zp_arr.data()); + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + scales[i] = ov::float16::from_bits(*(uint16_t *) block_data); zp[i] = 128; - } - for (size_t j = 0; j < weights_per_block; ++j) { - uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes. - // Original data is in int8_t, so we add a bias of -128 and invert the first bit. - x ^= 1 << 7; - weights[i * weights_per_block + j] = x; - } - }); + for (size_t j = 0; j < weights_per_block; ++j) { + uint8_t x = block_data[j + 2]; + x ^= 1 << 7; // Convert int8 to uint8 by flipping sign bit + weights[i * weights_per_block + j] = x; + } + }); + } else { + // Symmetric: store original int8 values directly (no unsigned bias) + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + scales[i] = ov::float16::from_bits(*(uint16_t *) block_data); + // Copy int8 weights as-is (the tensor element type is i8) + memcpy(weights + i * weights_per_block, block_data + 2, weights_per_block); + }); + } } void unpack_256_4(const uint8_t * data, uint8_t * dst) { @@ -256,44 +263,62 @@ void extract_q6_k_data(const ggml_tensor * tensor, auto * data = static_cast(tensor->data); auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * zp = static_cast(zp_arr.data()); - - bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization - - // For Q6_K, zero point is always 32 - if (is_scalar_zp) { - zp[0] = 32; - } - - ov::parallel_for(n_super_block, [&](size_t i) { - uint8_t * block_data = data + i * bytes_per_block; - float scale_factor = - static_cast(ov::float16::from_bits(*((uint16_t *) block_data + 104))); // (128+64+16)/2 + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i8); // Signed i8 path - for (size_t j = 0; j < 16; j++) { - scales[j + i * 16] = - ov::float16(scale_factor * static_cast(*((int8_t *) (block_data + 128 + 64 + j)))); - // For asymmetric quantization, store per-block zero points - if (!is_scalar_zp) { + if (!is_symmetric) { + auto * zp = static_cast(zp_arr.data()); + ov::parallel_for(n_super_block, [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + float scale_factor = static_cast(ov::float16::from_bits(*((uint16_t *) block_data + 104))); + for (size_t j = 0; j < 16; j++) { + scales[j + i * 16] = + ov::float16(scale_factor * static_cast(*((int8_t *) (block_data + 128 + 64 + j)))); zp[j + i * 16] = 32; } - } - - uint8_t * ql = block_data; - uint8_t * qh = block_data + 128; - - for (int64_t j = 0; j < 32; ++j) { - weights[i * 256 + j] = (ql[j] & 0xF) | (((qh[j] >> 0) & 3) << 4); - weights[i * 256 + j + 32] = (ql[32 + j] & 0xF) | (((qh[j] >> 2) & 3) << 4); - weights[i * 256 + j + 64] = (ql[j] >> 4) | (((qh[j] >> 4) & 3) << 4); - weights[i * 256 + j + 96] = (ql[32 + j] >> 4) | (((qh[j] >> 6) & 3) << 4); - weights[i * 256 + j + 128] = (ql[64 + j] & 0xF) | (((qh[32 + j] >> 0) & 3) << 4); - weights[i * 256 + j + 160] = (ql[96 + j] & 0xF) | (((qh[32 + j] >> 2) & 3) << 4); - weights[i * 256 + j + 192] = (ql[64 + j] >> 4) | (((qh[32 + j] >> 4) & 3) << 4); - weights[i * 256 + j + 224] = (ql[96 + j] >> 4) | (((qh[32 + j] >> 6) & 3) << 4); - } - }); + uint8_t * ql = block_data; + uint8_t * qh = block_data + 128; + for (int64_t j = 0; j < 32; ++j) { + weights[i * 256 + j] = (ql[j] & 0xF) | (((qh[j] >> 0) & 3) << 4); + weights[i * 256 + j + 32] = (ql[32 + j] & 0xF) | (((qh[j] >> 2) & 3) << 4); + weights[i * 256 + j + 64] = (ql[j] >> 4) | (((qh[j] >> 4) & 3) << 4); + weights[i * 256 + j + 96] = (ql[32 + j] >> 4) | (((qh[j] >> 6) & 3) << 4); + weights[i * 256 + j + 128] = (ql[64 + j] & 0xF) | (((qh[32 + j] >> 0) & 3) << 4); + weights[i * 256 + j + 160] = (ql[96 + j] & 0xF) | (((qh[32 + j] >> 2) & 3) << 4); + weights[i * 256 + j + 192] = (ql[64 + j] >> 4) | (((qh[32 + j] >> 4) & 3) << 4); + weights[i * 256 + j + 224] = (ql[96 + j] >> 4) | (((qh[32 + j] >> 6) & 3) << 4); + } + }); + } else { + // Symmetric: subtract 32 from each weight to store as signed i8 + ov::parallel_for(n_super_block, [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + float scale_factor = static_cast(ov::float16::from_bits(*((uint16_t *) block_data + 104))); + for (size_t j = 0; j < 16; j++) { + scales[j + i * 16] = + ov::float16(scale_factor * static_cast(*((int8_t *) (block_data + 128 + 64 + j)))); + } + uint8_t * ql = block_data; + uint8_t * qh = block_data + 128; + auto * signed_weights = reinterpret_cast(weights); + for (int64_t j = 0; j < 32; ++j) { + signed_weights[i * 256 + j] = static_cast((ql[j] & 0xF) | (((qh[j] >> 0) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 32] = + static_cast((ql[32 + j] & 0xF) | (((qh[j] >> 2) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 64] = static_cast((ql[j] >> 4) | (((qh[j] >> 4) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 96] = + static_cast((ql[32 + j] >> 4) | (((qh[j] >> 6) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 128] = + static_cast((ql[64 + j] & 0xF) | (((qh[32 + j] >> 0) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 160] = + static_cast((ql[96 + j] & 0xF) | (((qh[32 + j] >> 2) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 192] = + static_cast((ql[64 + j] >> 4) | (((qh[32 + j] >> 4) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 224] = + static_cast((ql[96 + j] >> 4) | (((qh[32 + j] >> 6) & 3) << 4)) - 32; + } + }); + } } static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t * d, uint8_t * m) { @@ -389,11 +414,10 @@ ov::Output make_int8_weights(ov::Tensor & weight, size_t group_size, bool use_bias) { ov::Shape orig_shape = weight.get_shape(); + bool is_signed = (weight.get_element_type() == ov::element::i8); // Symmetric: signed weights, no ZP // Expand dimensions for scales and zp/bias auto scale_shape = scales.get_shape(); - auto zp_shape = zp.get_shape(); - bool is_scalar_zp = zp_shape.empty(); // Symmetric quantization ov::Shape packed_shape = {orig_shape[0], orig_shape[1] / group_size, group_size}; @@ -403,37 +427,48 @@ ov::Output make_int8_weights(ov::Tensor & weight, } else { scale_shape.push_back(1); scales.set_shape(scale_shape); - // For symmetric quantization, zp remains scalar (don't resize) - if (!is_scalar_zp) { + if (!is_signed && zp.get_size() > 0) { + auto zp_shape = zp.get_shape(); zp_shape.push_back(1); zp.set_shape(zp_shape); } } - // Create graph nodes - auto weights_node = std::make_shared(ov::element::u8, packed_shape, - static_cast(weight.data()), nullptr); - weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; auto scales_f16 = std::make_shared(scales); - auto weights_f16 = std::make_shared(weights_node, ov::element::f16); ov::Output result; - if (use_bias && !is_scalar_zp) { - // Bias path: w * s + b (zp tensor holds f16 bias values) - auto bias_f16 = std::make_shared(zp); - auto w_s = std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); - result = std::make_shared(w_s, bias_f16, ov::op::AutoBroadcastType::NUMPY); + if (is_signed) { + // Signed path: q * s (no zero point subtraction needed) + auto weights_node = std::make_shared(ov::element::i8, packed_shape, + static_cast(weight.data()), nullptr); + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; + auto weights_f16 = std::make_shared(weights_node, ov::element::f16); + result = std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); } else { - // Zero point path: (w - zp) * s - auto zero_point = std::make_shared(zp); - float zp_value; - if (ov::op::util::get_single_value(zero_point, zp_value)) { - zero_point = ov::op::v0::Constant::create(zero_point->get_element_type(), {}, {zp_value}); + // Unsigned path + auto weights_node = std::make_shared(ov::element::u8, packed_shape, + static_cast(weight.data()), nullptr); + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; + auto weights_f16 = std::make_shared(weights_node, ov::element::f16); + + if (use_bias && zp.get_size() > 0) { + // Bias path: w * s + b (zp tensor holds f16 bias values) + auto bias_f16 = std::make_shared(zp); + auto w_s = + std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); + result = std::make_shared(w_s, bias_f16, ov::op::AutoBroadcastType::NUMPY); + } else { + // Zero point path: (w - zp) * s + auto zero_point = std::make_shared(zp); + float zp_value; + if (ov::op::util::get_single_value(zero_point, zp_value)) { + zero_point = ov::op::v0::Constant::create(zero_point->get_element_type(), {}, {zp_value}); + } + auto zero_point_f16 = std::make_shared(zero_point, ov::element::f16); + auto w_zp = + std::make_shared(weights_f16, zero_point_f16, ov::op::AutoBroadcastType::NUMPY); + result = std::make_shared(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); } - auto zero_point_f16 = std::make_shared(zero_point, ov::element::f16); - auto w_zp = - std::make_shared(weights_f16, zero_point_f16, ov::op::AutoBroadcastType::NUMPY); - result = std::make_shared(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); } if (packed_shape.size() != 2) { @@ -452,11 +487,10 @@ ov::Output make_int4_weights(ov::Tensor & weight, size_t group_size, bool use_bias) { ov::Shape orig_weight_shape = weight.get_shape(); + bool is_signed = (weight.get_element_type() == ov::element::i4); // Symmetric: signed weights, no ZP // Expand dimensions for scales and zp/bias ov::Shape scale_shape = scales.get_shape(); - auto zp_shape = zp.get_shape(); - bool is_scalar_zp = zp_shape.empty(); // Symmetric quantization // Create INT4 weight tensor ov::Shape packed_shape = {orig_weight_shape[0], orig_weight_shape[1] / group_size, group_size}; @@ -467,36 +501,48 @@ ov::Output make_int4_weights(ov::Tensor & weight, } else { scale_shape.push_back(1); scales.set_shape(scale_shape); - // For symmetric quantization, zp remains scalar (don't resize) - if (!is_scalar_zp) { + if (!is_signed && zp.get_size() > 0) { + auto zp_shape = zp.get_shape(); zp_shape.push_back(1); zp.set_shape(zp_shape); } } - auto weights_node = std::make_shared(ov::element::u4, packed_shape, - static_cast(weight.data()), nullptr); - weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; - auto weights_f16 = std::make_shared(weights_node, ov::element::f16); auto scales_f16 = std::make_shared(scales); ov::Output result; - if (use_bias && !is_scalar_zp) { - // Bias path: w * s + b (zp tensor holds f16 bias values) - auto bias_f16 = std::make_shared(zp); - auto w_s = std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); - result = std::make_shared(w_s, bias_f16, ov::op::AutoBroadcastType::NUMPY); + if (is_signed) { + // Signed path: q * s (no zero point subtraction needed) + auto weights_node = std::make_shared(ov::element::i4, packed_shape, + static_cast(weight.data()), nullptr); + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; + auto weights_f16 = std::make_shared(weights_node, ov::element::f16); + result = std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); } else { - // Zero point path: (w - zp) * s - auto zero_points_node = std::make_shared(zp); - float zp_value; - if (ov::op::util::get_single_value(zero_points_node, zp_value)) { - zero_points_node = ov::op::v0::Constant::create(zero_points_node->get_element_type(), {}, {zp_value}); + // Unsigned path + auto weights_node = std::make_shared(ov::element::u4, packed_shape, + static_cast(weight.data()), nullptr); + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; + auto weights_f16 = std::make_shared(weights_node, ov::element::f16); + + if (use_bias && zp.get_size() > 0) { + // Bias path: w * s + b (zp tensor holds f16 bias values) + auto bias_f16 = std::make_shared(zp); + auto w_s = + std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); + result = std::make_shared(w_s, bias_f16, ov::op::AutoBroadcastType::NUMPY); + } else { + // Zero point path: (w - zp) * s + auto zero_points_node = std::make_shared(zp); + float zp_value; + if (ov::op::util::get_single_value(zero_points_node, zp_value)) { + zero_points_node = ov::op::v0::Constant::create(zero_points_node->get_element_type(), {}, {zp_value}); + } + auto zero_points_f16 = std::make_shared(zero_points_node, ov::element::f16); + auto w_zp = + std::make_shared(weights_f16, zero_points_f16, ov::op::AutoBroadcastType::NUMPY); + result = std::make_shared(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); } - auto zero_points_f16 = std::make_shared(zero_points_node, ov::element::f16); - auto w_zp = - std::make_shared(weights_f16, zero_points_f16, ov::op::AutoBroadcastType::NUMPY); - result = std::make_shared(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); } if (packed_shape.size() != 2) { @@ -699,24 +745,32 @@ OvWeight process_weight_tensor(const ggml_tensor * tensor, const void * data, vo // Quantized path (normal extraction or quantized requant) // Create weight/scale/zp tensors - shared between both paths - ov::element::Type weight_type = layout.is_u4 ? ov::element::u4 : ov::element::u8; + // For symmetric quantization, use signed types (i4/i8) and no ZP tensor + ov::element::Type weight_type = layout.is_symmetric ? (layout.is_u4 ? ov::element::i4 : ov::element::i8) : + (layout.is_u4 ? ov::element::u4 : ov::element::u8); ov::Shape scale_shape = {node_shape[0], node_shape[1] / layout.weights_per_block}; - ov::Shape zp_shape = layout.is_symmetric ? ov::Shape{} : scale_shape; if (output_base_ptr) { uint8_t * buf_base = static_cast(output_base_ptr); result.weights = ov::Tensor(weight_type, node_shape, buf_base + layout.weights_offset); result.scales = ov::Tensor(ov::element::f16, scale_shape, buf_base + layout.scales_offset); - result.zp = ov::Tensor(weight_type, zp_shape, buf_base + layout.zp_offset); + if (!layout.is_symmetric) { + ov::element::Type zp_type = layout.is_u4 ? ov::element::u4 : ov::element::u8; + result.zp = ov::Tensor(zp_type, scale_shape, buf_base + layout.zp_offset); + } + // else: result.zp remains default-constructed (empty) for symmetric } else { result.weights = ov::Tensor(weight_type, node_shape); result.scales = ov::Tensor(ov::element::f16, scale_shape); - if (use_bias && !layout.is_symmetric) { - // bias only has effect for asymmetric quant - result.zp = ov::Tensor(ov::element::f16, zp_shape); - } else { - result.zp = ov::Tensor(weight_type, zp_shape); + if (!layout.is_symmetric) { + if (use_bias) { + result.zp = ov::Tensor(ov::element::f16, scale_shape); + } else { + ov::element::Type zp_type = layout.is_u4 ? ov::element::u4 : ov::element::u8; + result.zp = ov::Tensor(zp_type, scale_shape); + } } + // else: result.zp remains default-constructed (empty) for symmetric } if (layout.is_requant && layout.requant_type.has_value()) { @@ -741,59 +795,75 @@ void quantize_q4_0(const float * x, auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * zp = static_cast(zp_arr.data()); - bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization - - // For Q4_0, zero point is always 8 - if (is_scalar_zp) { - zp[0] = 8 | (8 << 4); // Pack two 4-bit values - } + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i4); // Signed i4 path - for (int i = 0; i < nb; i++) { - float amax = 0.0f; // absolute max - float max = 0.0f; - - for (int j = 0; j < qk; j++) { - const float v = x[i * qk + j]; - if (amax < fabsf(v)) { - amax = fabsf(v); - max = v; + if (!is_symmetric) { + auto * zp = static_cast(zp_arr.data()); + for (int i = 0; i < nb; i++) { + float amax = 0.0f; + float max = 0.0f; + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } } - } - - const float d = max / -8; - - if (d == 0) { - scales[i] = ov::float16(1.0f); - // zp is already set to 8 for symmetric, or set per-block for asymmetric - if (!is_scalar_zp) { + const float d = max / -8; + if (d == 0) { + scales[i] = ov::float16(1.0f); if (i % 2 == 0) { zp[i / 2] = 8; } else { zp[i / 2] |= (8 << 4); } + memset(weights + i * qk / 2, 8 | (8 << 4), qk / 2); + continue; } - memset(weights + i * qk / 2, 8 | (8 << 4), qk / 2); - continue; - } - - const float id = 1.0f / d; - scales[i] = ov::float16(d); - // For asymmetric quantization, store per-block zero points - if (!is_scalar_zp) { + const float id = 1.0f / d; + scales[i] = ov::float16(d); if (i % 2 == 0) { zp[i / 2] = 8; } else { zp[i / 2] |= (8 << 4); } + for (int j = 0; j < qk / 2; ++j) { + const float x0 = x[i * qk + 2 * j] * id; + const float x1 = x[i * qk + 2 * j + 1] * id; + const uint8_t xi0 = MIN(15, (int8_t) (x0 + 8.5f)); + const uint8_t xi1 = MIN(15, (int8_t) (x1 + 8.5f)); + weights[i * qk / 2 + j] = xi0 | (xi1 << 4); + } } - - for (int j = 0; j < qk / 2; ++j) { - const float x0 = x[i * qk + 2 * j] * id; - const float x1 = x[i * qk + 2 * j + 1] * id; - const uint8_t xi0 = MIN(15, (int8_t) (x0 + 8.5f)); - const uint8_t xi1 = MIN(15, (int8_t) (x1 + 8.5f)); - weights[i * qk / 2 + j] = xi0 | (xi1 << 4); + } else { + // Symmetric: produce signed i4 values in [-8, 7] + for (int i = 0; i < nb; i++) { + float amax = 0.0f; + float max = 0.0f; + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + const float d = max / -8; + if (d == 0) { + scales[i] = ov::float16(1.0f); + // i4 value 0 packed: 0x00 + memset(weights + i * qk / 2, 0, qk / 2); + continue; + } + const float id = 1.0f / d; + scales[i] = ov::float16(d); + for (int j = 0; j < qk / 2; ++j) { + const float x0 = x[i * qk + 2 * j] * id; + const float x1 = x[i * qk + 2 * j + 1] * id; + // Signed i4: range [-8, 7]. Quantize as round(x*id), then pack as 4-bit two's complement. + int8_t si0 = (int8_t) std::max(-8, std::min(7, (int) roundf(x0))); + int8_t si1 = (int8_t) std::max(-8, std::min(7, (int) roundf(x1))); + weights[i * qk / 2 + j] = (si0 & 0x0F) | ((si1 & 0x0F) << 4); + } } } } @@ -809,36 +879,42 @@ void quantize_q8_0(const float * x, auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * zp = static_cast(zp_arr.data()); - bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization - - // For Q8_0, zero point is always 128 - if (is_scalar_zp) { - zp[0] = 128; - } - - for (int i = 0; i < nb; i++) { - float amax = 0.0f; // absolute max + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i8); // Signed i8 path - for (int j = 0; j < qk; j++) { - const float v = x[i * qk + j]; - if (amax < fabsf(v)) { - amax = fabsf(v); + if (!is_symmetric) { + auto * zp = static_cast(zp_arr.data()); + for (int i = 0; i < nb; i++) { + float amax = 0.0f; + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + amax = std::max(amax, fabsf(v)); } - } - - const float d = amax / 127.0f; - const float id = d ? 1.0f / d : 0.0f; - scales[i] = ov::float16(d); - // For asymmetric quantization, store per-block zero points - if (!is_scalar_zp) { + const float d = amax / 127.0f; + const float id = d ? 1.0f / d : 0.0f; + scales[i] = ov::float16(d); zp[i] = 128; + for (int j = 0; j < qk; ++j) { + const float x0 = x[i * qk + j] * id; + const int8_t xi0 = roundf(x0); + weights[i * qk + j] = (uint8_t) (xi0 + 128); + } } - - for (int j = 0; j < qk; ++j) { - const float x0 = x[i * qk + j] * id; - const int8_t xi0 = roundf(x0); - weights[i * qk + j] = (uint8_t) (xi0 + 128); + } else { + // Symmetric: store signed int8 values directly + auto * signed_weights = reinterpret_cast(weights); + for (int i = 0; i < nb; i++) { + float amax = 0.0f; + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + amax = std::max(amax, fabsf(v)); + } + const float d = amax / 127.0f; + const float id = d ? 1.0f / d : 0.0f; + scales[i] = ov::float16(d); + for (int j = 0; j < qk; ++j) { + const float x0 = x[i * qk + j] * id; + signed_weights[i * qk + j] = (int8_t) roundf(x0); + } } } } @@ -861,12 +937,8 @@ void quantize_q8_1(const float * x, for (int j = 0; j < qk; j++) { const float v = x[i * qk + j]; - if (v < min) { - min = v; - } - if (v > max) { - max = v; - } + min = std::min(v, min); + max = std::max(v, max); } const float d = (max - min) / ((1 << 8) - 1); diff --git a/ggml/src/ggml-openvino/openvino/op/rope.cpp b/ggml/src/ggml-openvino/openvino/op/rope.cpp index 26dc2d24f82..a8db9b38930 100644 --- a/ggml/src/ggml-openvino/openvino/op/rope.cpp +++ b/ggml/src/ggml-openvino/openvino/op/rope.cpp @@ -9,12 +9,17 @@ #include #include #include +#include +#include +#include #include #include #include +#include #include #include #include +#include #include #include @@ -33,6 +38,12 @@ OutputVector translate_rope(const NodeContext & context) { auto data_node = context.get_input(0).get_node_shared_ptr(); auto output_shape = context.get_output_shape().to_shape(); int32_t * op_params = context.get_output_op_params(); + const int mode = (op_case & 0xFFFF0000) >> 16; + op_case = (op_case & 0x0000FFFF); + + constexpr int TYPE_NORMAL = 0; + constexpr int TYPE_NEOX = 1; + constexpr int TYPE_IMROPE = 2; Output cos_theta_node; Output sin_theta_node; @@ -45,7 +56,7 @@ OutputVector translate_rope(const NodeContext & context) { if (context.get_input_size() == 3) { rope_freqs_weight = context.get_input(2).get_node_shared_ptr(); } - auto sin_cos = make_sin_cos(op_params, inp_pos, rope_freqs_weight); + auto sin_cos = make_sin_cos(op_params, inp_pos, rope_freqs_weight, mode == TYPE_IMROPE); sin_theta_node = sin_cos.first; cos_theta_node = sin_cos.second; } @@ -65,11 +76,7 @@ OutputVector translate_rope(const NodeContext & context) { } } - const int mode = op_params[2]; - constexpr int ROPE_TYPE_NORMAL = 0; - constexpr int ROPE_TYPE_NEOX = 2; - - if (mode == ROPE_TYPE_NORMAL) { + if (mode == TYPE_NORMAL) { auto neg_one = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1}); auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); @@ -97,7 +104,7 @@ OutputVector translate_rope(const NodeContext & context) { auto data_shape = ov::op::v0::Constant::create( ov::element::i64, {4}, std::vector{1, -1, (int64_t) output_shape[2], (int64_t) output_shape[3]}); res = std::make_shared(stack, data_shape, false); - } else if (mode == ROPE_TYPE_NEOX) { + } else if (mode == TYPE_NEOX) { auto data_split = std::make_shared( data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {-1}), 2); Output slice_data_node_0 = data_split->outputs()[0]; @@ -112,6 +119,25 @@ OutputVector translate_rope(const NodeContext & context) { std::make_shared(slice_data_node_1, cos_theta_node)); res = std::make_shared(ov::OutputVector{first_half_node, second_half_node}, -1); + } else if (mode == TYPE_IMROPE) { + int64_t n_dims = data_node->get_shape()[3]; + auto cos_sin_shape = std::make_shared(ov::element::i64, ov::Shape{4}, std::vector{1,-1,1,(n_dims >> 1)}); + auto cos_reshaped = std::make_shared(cos_theta_node, cos_sin_shape, true); + auto sin_reshaped = std::make_shared(sin_theta_node, cos_sin_shape, true); + + auto split_axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {3}); + auto split_a = std::make_shared(data_node, split_axis, 2); + auto x0 = split_a->output(0); + auto x1 = split_a->output(1); + auto mul_a = std::make_shared(x0, cos_reshaped); + auto mul_b = std::make_shared(x1, sin_reshaped); + auto sub = std::make_shared(mul_a, mul_b); + + auto mul_c = std::make_shared(x0, sin_reshaped); + auto mul_d = std::make_shared(x1, cos_reshaped); + auto add = std::make_shared(mul_c, mul_d); + + res = std::make_shared(ov::OutputVector{sub, add}, 3); } return rename_outputs_with_suffix({res}, context.get_name()); diff --git a/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp b/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp new file mode 100644 index 00000000000..d1e9efc33a5 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp @@ -0,0 +1,25 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_unary_gelu(const NodeContext & context) { + num_inputs_check(context, 1, 1); + + auto input = context.get_input(0); + auto res = std::make_shared(input); + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op_table.cpp b/ggml/src/ggml-openvino/openvino/op_table.cpp index beadafe8103..1385539279c 100644 --- a/ggml/src/ggml-openvino/openvino/op_table.cpp +++ b/ggml/src/ggml-openvino/openvino/op_table.cpp @@ -31,6 +31,7 @@ std::unordered_map get_supported_ops() { {"GGML_OP_SOFT_MAX", op::translate_soft_max }, {"GGML_OP_SUB", op::translate_1to1_match_2_inputs}, {"GGML_OP_TRANSPOSE", op::translate_transpose }, + {"GGML_UNARY_OP_GELU", op::translate_unary_gelu }, {"GGML_UNARY_OP_SILU", op::translate_unary_silu }, {"GGML_OP_VIEW", op::translate_view }, {"GGML_GLU_OP_SWIGLU", op::translate_glu_swiglu }, diff --git a/ggml/src/ggml-openvino/openvino/op_table.h b/ggml/src/ggml-openvino/openvino/op_table.h index 37f763117aa..f546796d2ee 100644 --- a/ggml/src/ggml-openvino/openvino/op_table.h +++ b/ggml/src/ggml-openvino/openvino/op_table.h @@ -21,6 +21,7 @@ GGML_OP_CONVERTER(translate_rms_norm); GGML_OP_CONVERTER(translate_rope); GGML_OP_CONVERTER(translate_scale); GGML_OP_CONVERTER(translate_unary_silu); +GGML_OP_CONVERTER(translate_unary_gelu); GGML_OP_CONVERTER(translate_soft_max); GGML_OP_CONVERTER(translate_transpose); GGML_OP_CONVERTER(translate_view); diff --git a/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp b/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp deleted file mode 100644 index ed2a3ab6d1b..00000000000 --- a/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +++ /dev/null @@ -1,123 +0,0 @@ -#include "eliminate_zp.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace ov { -namespace frontend { -namespace ggml { -namespace pass { - -EliminateZeroPoints::EliminateZeroPoints() { - // Find pattern: - // (Multiply Any(scale) - // (Subtract (Convert Constant(data))) - // (Convert Constant(zero_point))) - // where zero_point is a scalar - // If data is u4 and zp value is 8 (q4_0), Replace the Subtract with an i4 Constant whose value is data - zp_val - // If data is u8 and zp value is 128 (q8_0) or 32 (q6_k), Replace the Subtract with an i8 Constant - - auto m_data_constant = ov::pass::pattern::wrap_type(); - auto m_data_convert = ov::pass::pattern::wrap_type({m_data_constant}); - - auto m_zp_constant = ov::pass::pattern::wrap_type(); - auto m_zp_convert = ov::pass::pattern::wrap_type({m_zp_constant}); - - auto m_subtract = ov::pass::pattern::wrap_type({m_data_convert, m_zp_convert}); - auto m_scale = ov::pass::pattern::any_input(); - auto m_multiply = ov::pass::pattern::wrap_type({m_scale, m_subtract}); - - const auto callback = [=](ov::pass::pattern::Matcher & m) { - const auto & pattern_map = m.get_pattern_value_map(); - - auto multiply_node = - std::dynamic_pointer_cast(pattern_map.at(m_multiply).get_node_shared_ptr()); - auto subtract_node = - std::dynamic_pointer_cast(pattern_map.at(m_subtract).get_node_shared_ptr()); - auto data_constant = - std::dynamic_pointer_cast(pattern_map.at(m_data_constant).get_node_shared_ptr()); - auto zp_constant = - std::dynamic_pointer_cast(pattern_map.at(m_zp_constant).get_node_shared_ptr()); - - if (!multiply_node || !subtract_node || !data_constant || !zp_constant) { - return false; - } - - if (ov::shape_size(zp_constant->get_shape()) != 1) { - return false; - } - - auto data_type = data_constant->get_element_type(); - auto zp_data = zp_constant->cast_vector(); - - if (zp_data.empty()) { - return false; - } - - int zp_value = zp_data[0]; - - bool should_eliminate = false; - ov::element::Type target_type; - - if (data_type == ov::element::u4 && zp_value == 8) { - should_eliminate = true; - target_type = ov::element::i4; - } else if (data_type == ov::element::u8 && (zp_value == 128 || zp_value == 32)) { - should_eliminate = true; - target_type = ov::element::i8; - } - - if (!should_eliminate) { - return false; - } - - auto data_shape = data_constant->get_shape(); - size_t total_elements = ov::shape_size(data_shape); - - std::shared_ptr new_constant; - - // TODO improve performance - if (data_type == ov::element::u4) { - auto data_values = data_constant->cast_vector(); - std::vector adjusted_values(total_elements); - - ov::parallel_for(total_elements, [&](size_t i) { - adjusted_values[i] = static_cast(static_cast(data_values[i]) - 8); - }); - - new_constant = std::make_shared(target_type, data_shape, adjusted_values); - } else if (data_type == ov::element::u8) { - auto data_values = data_constant->cast_vector(); - std::vector adjusted_values(total_elements); - - ov::parallel_for(total_elements, [&, zp_value](size_t i) { - adjusted_values[i] = static_cast(static_cast(data_values[i]) - zp_value); - }); - - new_constant = std::make_shared(target_type, data_shape, adjusted_values); - } - - auto new_convert = - std::make_shared(new_constant, subtract_node->get_output_element_type(0)); - ov::replace_node(subtract_node, new_convert); - - return true; - }; - - register_matcher( - std::make_shared(m_multiply, "ov::frontend::ggml::pass::EliminateZeroPoints"), - callback); -} - -} // namespace pass -} // namespace ggml -} // namespace frontend -} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h b/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h deleted file mode 100644 index edd3cd718d9..00000000000 --- a/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +++ /dev/null @@ -1,17 +0,0 @@ -#include "openvino/pass/matcher_pass.hpp" - -namespace ov { -namespace frontend { -namespace ggml { -namespace pass { - -class EliminateZeroPoints : public ov::pass::MatcherPass { -public: - OPENVINO_MATCHER_PASS_RTTI("ov::frontend::ggml::pass::EliminateZeroPoints") - EliminateZeroPoints(); -}; - -} // namespace pass -} // namespace ggml -} // namespace frontend -} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp b/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp new file mode 100644 index 00000000000..f051891c481 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp @@ -0,0 +1,41 @@ +// Copyright (C) 2018-2026 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +namespace ov { + +/** + * @brief Holds weightless caching attributes of a single constant. + * + * WeightlessCacheAttribute class represents runtime info attribute that holds + * the values of original size of the constant in bytes and the binary offset of the + * constant's data in the weights file used by the weightless caching mechanism. It's + * not copyable in case the data was changed (the original node was replaced by a new + * one produced during the tranformation pipeline) - in that case weightless caching + * can't be used for that constant. + */ +class OPENVINO_API WeightlessCacheAttribute : public RuntimeAttribute { +public: + OPENVINO_RTTI("WeightlessCacheAttribute", "0", RuntimeAttribute) + + WeightlessCacheAttribute() = delete; + + WeightlessCacheAttribute(size_t original_size, size_t bin_offset, ov::element::Type original_dtype) + : original_size(original_size), + bin_offset(bin_offset), + original_dtype(original_dtype) {} + + bool is_copyable() const override; + + size_t original_size; + size_t bin_offset; + ov::element::Type original_dtype; +}; + +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/translate_session.cpp b/ggml/src/ggml-openvino/openvino/translate_session.cpp index 23a1dea2496..0f68a1f5062 100644 --- a/ggml/src/ggml-openvino/openvino/translate_session.cpp +++ b/ggml/src/ggml-openvino/openvino/translate_session.cpp @@ -3,15 +3,16 @@ #include "ggml-openvino/openvino/node_context.h" #include "ggml-openvino/openvino/utils.h" #include "input_model.h" -#include "pass/eliminate_zp.h" #include "pass/mark_decompression_convert_constant_folding.h" #include "pass/squeeze_matmul.h" +#include "rt_info/weightless_caching_attributes.hpp" #include #include #include #include #include +#include #include #include #include @@ -33,7 +34,6 @@ #include #include #include -#include namespace ov { namespace frontend { @@ -240,6 +240,31 @@ std::shared_ptr TranslateSession::translate_graph(const frontend::InputMo resulting_model = std::make_shared(results, used_params); apply_transformations(resulting_model); + + // Set WeightlessCacheAttribute on large constants to avoid unnecessary memory copies + // in the NPUW plugin. Without this attribute, NPUW's LazyTensor constructor + // (lazy_tensor.cpp, op::Const::Const) will memcpy every constant "in case export + // occurs", doubling memory usage per compile_model call. + // + // The bin_offset field serves as a unique key (not a real file offset) — this is + // the same convention the GPU plugin uses for non-IR models (see + // Plugin::set_weightless_cache_attributes in intel_gpu/src/plugin/plugin.cpp). + // Each constant must have a distinct bin_offset, otherwise GPU's weightless cache + // import will map multiple constants to the same data. + // + // Small constants (< 16 elements) are excluded since they may be introduced by + // optimization patterns and the overhead is negligible. + size_t offset = 0; + for (auto & node : resulting_model->get_ordered_ops()) { + if (auto cnst = ov::as_type_ptr(node); + cnst && cnst->get_byte_size() / cnst->get_element_type().size() >= 16) { + auto & rt_info = cnst->get_rt_info(); + if (rt_info.find(ov::WeightlessCacheAttribute::get_type_info_static()) == rt_info.end()) { + rt_info[ov::WeightlessCacheAttribute::get_type_info_static()] = + ov::WeightlessCacheAttribute(cnst->get_byte_size(), offset++, cnst->get_element_type()); + } + } + } return resulting_model; } @@ -257,7 +282,6 @@ std::shared_ptr TranslateSession::apply_transformations(std::shared_ptris_static()) { - manager.register_pass(); manager.register_pass(); } manager.run_passes(model); diff --git a/ggml/src/ggml-openvino/openvino/utils.cpp b/ggml/src/ggml-openvino/openvino/utils.cpp index 65356a51b51..0baaf88e17a 100644 --- a/ggml/src/ggml-openvino/openvino/utils.cpp +++ b/ggml/src/ggml-openvino/openvino/utils.cpp @@ -2,6 +2,7 @@ #include "ggml-impl.h" +#include #include #include #include @@ -13,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -87,8 +89,11 @@ ov::Output rope_yarn_ramp_mix(int n_dims, const float corr_dims[2], fl auto ramp_y = std::make_shared(std::make_shared(dim_ids, corr_low), denom); auto ramp_clamped = std::make_shared(ramp_y, 0.0f, 1.0f); + // rope_yarn_ramp returns (1 - clamp(y)), so invert before scaling + auto one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {1.0f}); + auto ramp_inverted = std::make_shared(one, ramp_clamped); auto ext_factor_node = ov::op::v0::Constant::create(ov::element::f32, Shape{}, {ext_factor}); - auto ramp_mix = std::make_shared(ramp_clamped, ext_factor_node); + auto ramp_mix = std::make_shared(ramp_inverted, ext_factor_node); return ramp_mix; } @@ -115,6 +120,7 @@ void ggml_rope_yarn_corr_dims(int n_dims, std::pair, ov::Output> make_sin_cos(int32_t * rope_params, std::shared_ptr inp_pos, std::shared_ptr rope_freqs_weight, + bool imrope, bool stateful) { if (stateful) { inp_pos = std::make_shared(inp_pos, ov::op::v0::Constant::create(ov::element::i64, {1}, {0})); @@ -122,6 +128,13 @@ std::pair, ov::Output> make_sin_cos(int32_t * rope_params auto pos_perm = std::make_shared(ov::element::i64, ov::Shape{3}, std::vector{2, 1, 0}); inp_pos = std::make_shared(inp_pos, pos_perm); + } else if (imrope) { + inp_pos = std::make_shared(inp_pos, ov::element::f32); + auto pos_shape = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{5}, {0, 0, 0, 4, -1}); + inp_pos = std::make_shared(inp_pos, pos_shape, true); + auto pos_transpose_shape = + std::make_shared(ov::element::i64, ov::Shape{5}, std::vector{0, 1, 2, 4, 3}); + inp_pos = std::make_shared(inp_pos, pos_transpose_shape); } else { inp_pos = std::make_shared(inp_pos, ov::element::f32); auto pos_perm = @@ -136,6 +149,7 @@ std::pair, ov::Output> make_sin_cos(int32_t * rope_params float beta_fast; float beta_slow; const int n_dims = rope_params[1]; + const size_t n_dims_half = n_dims >> 1; const int n_ctx_orig = rope_params[4]; memcpy(&freq_base, rope_params + 5, sizeof(float)); memcpy(&freq_scale, rope_params + 6, sizeof(float)); @@ -146,57 +160,74 @@ std::pair, ov::Output> make_sin_cos(int32_t * rope_params const float theta_scale = powf(freq_base, -2.0f / n_dims); - float corr_dims[2]; - ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); - - std::vector factor(n_dims / 2); - factor[0] = 1.0f; - for (size_t i = 1; i < factor.size(); i++) { - factor[i] = theta_scale * factor[i - 1]; - } + std::vector factor(n_dims_half); Output freq_factors; - if (stateful) { - freq_factors = - std::make_shared(ov::element::f32, ov::Shape{1, 1, factor.size()}, factor); - } else { - freq_factors = - std::make_shared(ov::element::f32, ov::Shape{1, 1, 1, factor.size()}, factor); - } - if (rope_freqs_weight) { - freq_factors = std::make_shared(freq_factors, rope_freqs_weight); - } - - auto theta_extrap = std::make_shared(freq_factors, inp_pos); - auto theta_interp = std::make_shared( - theta_extrap, ov::op::v0::Constant::create(ov::element::f32, {1}, {freq_scale})); Output theta; float mscale = attn_factor; - if (ext_factor == 0.0f) { - theta = theta_interp; + if (imrope) { + std::vector gather_indices(n_dims_half); + for (size_t j = 0; j < n_dims_half; j++) { + gather_indices[j] = j % 3; + factor[j] = std::pow(theta_scale, j); + } + auto gather_indices_const = + std::make_shared(ov::element::i64, ov::Shape{n_dims_half}, gather_indices); + auto gather_axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {4}); + inp_pos = std::make_shared(inp_pos, gather_indices_const, gather_axis); + auto factor_const = std::make_shared(ov::element::f32, ov::Shape{n_dims_half}, factor); + theta = std::make_shared(inp_pos, factor_const); } else { - auto ramp_mix = rope_yarn_ramp_mix(n_dims, corr_dims, ext_factor); - Output one; + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + factor[0] = 1.0f; + for (size_t i = 1; i < factor.size(); i++) { + factor[i] = theta_scale * factor[i - 1]; + } if (stateful) { - one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1}, {1.0f}); + freq_factors = + std::make_shared(ov::element::f32, ov::Shape{1, 1, factor.size()}, factor); } else { - one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {1.0f}); + freq_factors = + std::make_shared(ov::element::f32, ov::Shape{1, 1, 1, factor.size()}, factor); + } + if (rope_freqs_weight) { + freq_factors = std::make_shared(freq_factors, rope_freqs_weight); } - auto one_minus_ramp = std::make_shared(one, ramp_mix); - theta = std::make_shared(std::make_shared(theta_interp, one_minus_ramp), - std::make_shared(theta_extrap, ramp_mix)); - mscale *= (1.0f + 0.1f * std::log(1.0f / freq_scale)); + auto theta_extrap = std::make_shared(freq_factors, inp_pos); + auto theta_interp = std::make_shared( + theta_extrap, ov::op::v0::Constant::create(ov::element::f32, {1}, {freq_scale})); + + if (ext_factor == 0.0f) { + theta = theta_interp; + } else { + auto ramp_mix = rope_yarn_ramp_mix(n_dims, corr_dims, ext_factor); + Output one; + if (stateful) { + one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1}, {1.0f}); + } else { + one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {1.0f}); + } + auto one_minus_ramp = std::make_shared(one, ramp_mix); + + theta = std::make_shared(std::make_shared(theta_interp, one_minus_ramp), + std::make_shared(theta_extrap, ramp_mix)); + mscale *= (1.0f + 0.1f * std::log(1.0f / freq_scale)); + } } Output cos_theta = std::make_shared(theta); Output sin_theta = std::make_shared(theta); - auto mscale_node = ov::op::v0::Constant::create(ov::element::f32, Shape{}, {mscale}); + if (!imrope) { + auto mscale_node = ov::op::v0::Constant::create(ov::element::f32, Shape{}, {mscale}); + + cos_theta = std::make_shared(cos_theta, mscale_node); + sin_theta = std::make_shared(sin_theta, mscale_node); + } - cos_theta = std::make_shared(cos_theta, mscale_node); - sin_theta = std::make_shared(sin_theta, mscale_node); return std::make_pair(sin_theta, cos_theta); } diff --git a/ggml/src/ggml-openvino/openvino/utils.h b/ggml/src/ggml-openvino/openvino/utils.h index 88dcad4c906..767dd4c53ea 100644 --- a/ggml/src/ggml-openvino/openvino/utils.h +++ b/ggml/src/ggml-openvino/openvino/utils.h @@ -67,6 +67,7 @@ OutputVector rename_outputs_with_suffix(const OutputVector& outputs, const std:: std::pair, ov::Output> make_sin_cos(int32_t* rope_params, std::shared_ptr inp_pos, std::shared_ptr rope_freqs_weight = nullptr, + bool imrope = false, bool stateful = false); ov::Output process_view_input(const NodeContext& context, int input_index, int slice_len = 0); diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp index 1b553a0de00..998ef7c9eb4 100644 --- a/ggml/src/ggml-openvino/utils.cpp +++ b/ggml/src/ggml-openvino/utils.cpp @@ -81,8 +81,8 @@ ov::Tensor create_ov_output_tensor(std::shared_ptr ggml_decoder, enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr r_ctx) { auto & core = ov_singleton_core(); const auto & config = ggml_openvino_get_compile_config(); - auto device = r_ctx->device; - bool stateful = r_ctx->stateful; + const auto & device = r_ctx->device; + const auto & stateful = r_ctx->stateful; static auto is_static = false; if (is_naive(cgraph)) { @@ -106,14 +106,26 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< int64_t infer_end_time; { - std::lock_guard lock(r_ctx->ov_compute_mutex); + std::shared_ptr entry; + ModelParams old_m_params; - auto it = r_ctx->decoder_cache.find(key); + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + auto it = r_ctx->decoder_cache.find(key); + cache_hit = it != r_ctx->decoder_cache.end(); + if (cache_hit) { + entry = it->second; + } else { + auto mutex = std::make_shared(); + entry = std::make_shared(mutex); + r_ctx->decoder_cache[key] = entry; + } + } + + std::lock_guard lock(*(entry->mutex)); - cache_hit = it != r_ctx->decoder_cache.end(); - ModelParams old_m_params; if (cache_hit) { - ggml_decoder = it->second; + ggml_decoder = entry->ptr; old_m_params = ggml_decoder->get_model_params(); cache_hit = old_m_params.can_reuse_dynamically(m_params); } @@ -126,7 +138,10 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< ggml_decoder->update_io(cgraph); } ggml_decoder->add_extra_inputs(); - infer_request = r_ctx->infer_request_cache.at(key); + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + infer_request = r_ctx->infer_request_cache.at(key); + } if (stateful) { const auto * inp_pos = get_inp_pos_tensor(cgraph); @@ -170,7 +185,10 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< conversion_end_time = decoder_end_time; compile_end_time = decoder_end_time; } else { - r_ctx->infer_request_cache.erase(key); + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + r_ctx->infer_request_cache.erase(key); + } std::shared_ptr model; auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph); @@ -199,8 +217,7 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< } compile_end_time = ggml_time_us(); infer_request = std::make_shared(compiled_model.create_infer_request()); - r_ctx->infer_request_cache[key] = infer_request; - r_ctx->decoder_cache[key] = ggml_decoder; + entry->ptr = ggml_decoder; std::vector ov_input_names; std::vector ov_output_names; @@ -210,8 +227,13 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< for (const auto & ov_output : model->get_results()) { ov_output_names.push_back(ov_output->get_friendly_name()); } - r_ctx->ov_input_names_cache[key] = std::move(ov_input_names); - r_ctx->ov_output_names_cache[key] = std::move(ov_output_names); + + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + r_ctx->infer_request_cache[key] = infer_request; + r_ctx->ov_input_names_cache[key] = std::move(ov_input_names); + r_ctx->ov_output_names_cache[key] = std::move(ov_output_names); + } if (stateful) { const auto * inp_pos = get_inp_pos_tensor(cgraph); @@ -224,8 +246,13 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< } } - auto ov_input_names = r_ctx->ov_input_names_cache[key]; - auto ov_output_names = r_ctx->ov_output_names_cache[key]; + std::vector ov_input_names; + std::vector ov_output_names; + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + ov_input_names = r_ctx->ov_input_names_cache[key]; + ov_output_names = r_ctx->ov_output_names_cache[key]; + } for (size_t i = 0; i < ov_input_names.size(); i++) { auto param_name = ov_input_names[i]; @@ -306,12 +333,26 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrdecoder_cache.find(key); - - cache_hit = it != r_ctx->decoder_cache.end(); + std::shared_ptr entry; ModelParams old_m_params; + + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + auto it = r_ctx->decoder_cache.find(key); + cache_hit = it != r_ctx->decoder_cache.end(); + if (cache_hit) { + entry = it->second; + } else { + auto mutex = std::make_shared(); + entry = std::make_shared(mutex); + r_ctx->decoder_cache[key] = entry; + } + } + + std::lock_guard lock(*(entry->mutex)); + if (cache_hit) { - ggml_decoder = it->second; + ggml_decoder = entry->ptr; old_m_params = ggml_decoder->get_model_params(); cache_hit = old_m_params.can_reuse_statically(m_params); } @@ -325,14 +366,21 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrupdate_io(cgraph); } ggml_decoder->add_extra_inputs(); - infer_request = is_prefill ? r_ctx->infer_request_cache_prefill.at(key) : r_ctx->infer_request_cache.at(key); + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + infer_request = + is_prefill ? r_ctx->infer_request_cache_prefill.at(key) : r_ctx->infer_request_cache.at(key); + } decoder_end_time = ggml_time_us(); conversion_end_time = decoder_end_time; compile_end_time = decoder_end_time; } else { - r_ctx->infer_request_cache.erase(key); - r_ctx->infer_request_cache_prefill.erase(key); + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + r_ctx->infer_request_cache.erase(key); + r_ctx->infer_request_cache_prefill.erase(key); + } std::shared_ptr model; auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph); @@ -372,16 +420,14 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrinfer_request_cache_prefill[key] = - std::make_shared(compiled_model_prefill.create_infer_request()); - r_ctx->infer_request_cache[key] = - std::make_shared(compiled_model_decode.create_infer_request()); + auto infer_request_prefill = std::make_shared(compiled_model_prefill.create_infer_request()); + auto infer_request_decode = std::make_shared(compiled_model_decode.create_infer_request()); compile_end_time = ggml_time_us(); model = is_prefill ? model_prefill : model_decode; ggml_decoder = is_prefill ? ggml_decoder_prefill : ggml_decoder_decode; - infer_request = is_prefill ? r_ctx->infer_request_cache_prefill[key] : r_ctx->infer_request_cache[key]; - r_ctx->decoder_cache[key] = ggml_decoder; + infer_request = is_prefill ? infer_request_prefill : infer_request_decode; + entry->ptr = ggml_decoder; std::vector ov_input_names; std::vector ov_output_names; @@ -391,18 +437,29 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrget_results()) { ov_output_names.push_back(ov_output->get_friendly_name()); } - r_ctx->ov_input_names_cache[key] = std::move(ov_input_names); - r_ctx->ov_output_names_cache[key] = std::move(ov_output_names); + + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + r_ctx->infer_request_cache_prefill[key] = infer_request_prefill; + r_ctx->infer_request_cache[key] = infer_request_decode; + r_ctx->ov_input_names_cache[key] = std::move(ov_input_names); + r_ctx->ov_output_names_cache[key] = std::move(ov_output_names); + } } - auto ov_input_names = r_ctx->ov_input_names_cache[key]; - auto ov_output_names = r_ctx->ov_output_names_cache[key]; + std::vector ov_input_names_local; + std::vector ov_output_names_local; + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + ov_input_names_local = r_ctx->ov_input_names_cache[key]; + ov_output_names_local = r_ctx->ov_output_names_cache[key]; + } if (is_prefill) { auto inp_len = inp_pos->ne[0]; for (int chunk_index = 0; chunk_index * prefill_chunk_size < inp_len; chunk_index++) { - for (size_t i = 0; i < ov_input_names.size(); i++) { - auto param_name = ov_input_names[i]; + for (size_t i = 0; i < ov_input_names_local.size(); i++) { + auto param_name = ov_input_names_local[i]; auto input_tensor = get_ov_input_tensor_static_prefill(ggml_decoder, param_name, chunk_index); infer_request->set_input_tensor(i, input_tensor); @@ -412,8 +469,8 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrget_model_outputs().at(ov_output_names[i]); + for (size_t i = 0; i < ov_output_names_local.size(); i++) { + auto * ggml_tensor = ggml_decoder->get_model_outputs().at(ov_output_names_local[i]); auto output_tensor = create_ov_output_tensor(ggml_decoder, infer_request, i, ggml_tensor); infer_request->set_output_tensor(i, output_tensor); } @@ -421,16 +478,16 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrinfer(); if (getenv("GGML_OPENVINO_DEBUG_OUTPUT")) { - for (size_t i = 0; i < ov_output_names.size(); i++) { + for (size_t i = 0; i < ov_output_names_local.size(); i++) { const auto output_tensor = infer_request->get_output_tensor(i); - print_output_tensor_info(ov_output_names[i], output_tensor, output_tensor.data()); + print_output_tensor_info(ov_output_names_local[i], output_tensor, output_tensor.data()); } } } infer_end_time = ggml_time_us(); } else { - for (size_t i = 0; i < ov_input_names.size(); i++) { - auto param_name = ov_input_names[i]; + for (size_t i = 0; i < ov_input_names_local.size(); i++) { + auto param_name = ov_input_names_local[i]; auto input_tensor = get_ov_input_tensor_static_decode(ggml_decoder, param_name); infer_request->set_input_tensor(i, input_tensor); @@ -440,8 +497,8 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrget_model_outputs().at(ov_output_names[i]); + for (size_t i = 0; i < ov_output_names_local.size(); i++) { + auto * ggml_tensor = ggml_decoder->get_model_outputs().at(ov_output_names_local[i]); auto output_tensor = create_ov_output_tensor(ggml_decoder, infer_request, i, ggml_tensor); infer_request->set_output_tensor(i, output_tensor); } @@ -450,9 +507,9 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrget_output_tensor(i); - print_output_tensor_info(ov_output_names[i], output_tensor, output_tensor.data()); + print_output_tensor_info(ov_output_names_local[i], output_tensor, output_tensor.data()); } } } diff --git a/ggml/src/ggml-openvino/utils.h b/ggml/src/ggml-openvino/utils.h index 656573d1389..2c72e33c352 100644 --- a/ggml/src/ggml-openvino/utils.h +++ b/ggml/src/ggml-openvino/utils.h @@ -3,12 +3,15 @@ #include "ggml-impl.h" #include +#include #include #include +#include #include #include #include #include +#include #include struct graph_key { @@ -40,11 +43,17 @@ struct graph_key_hash { } }; +struct decoder_runtime_ctx { + decoder_runtime_ctx(std::shared_ptr mutex) : mutex(std::move(mutex)) {} + std::shared_ptr mutex; + std::shared_ptr ptr; +}; + struct ov_runtime_context { - std::mutex ov_compute_mutex; + mutable std::mutex ctx_mutex; std::string device; bool stateful; - std::unordered_map, graph_key_hash> decoder_cache; + std::unordered_map, graph_key_hash> decoder_cache; std::unordered_map, graph_key_hash> infer_request_cache; std::unordered_map, graph_key_hash> infer_request_cache_prefill; std::unordered_map, graph_key_hash> ov_input_names_cache; @@ -53,11 +62,22 @@ struct ov_runtime_context { // Simultanous stateful inference request support to be added. size_t stateful_kv_size; std::map kv_state_input_name_map; + std::atomic backend_count; ov_runtime_context() : device("CPU"), stateful(false), - stateful_kv_size(0) {} + stateful_kv_size(0), + backend_count(0) {} + + void clear_caches() { + std::lock_guard lock(ctx_mutex); + decoder_cache.clear(); + infer_request_cache.clear(); + infer_request_cache_prefill.clear(); + ov_input_names_cache.clear(); + ov_output_names_cache.clear(); + } }; enum ggml_status ov_graph_compute(struct ggml_cgraph * cgraph, ggml_backend_t backend); diff --git a/ggml/src/ggml-opt.cpp b/ggml/src/ggml-opt.cpp index e078ad14a39..53903defa8f 100644 --- a/ggml/src/ggml-opt.cpp +++ b/ggml/src/ggml-opt.cpp @@ -589,6 +589,7 @@ void ggml_opt_free(ggml_opt_context_t opt_ctx) { ggml_backend_buffer_free(opt_ctx->buf_cpu); ggml_free(opt_ctx->ctx_static); ggml_free(opt_ctx->ctx_cpu); + ggml_free(opt_ctx->ctx_copy); delete opt_ctx; } diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 48695a61ea3..15443aa554a 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -32,6 +32,41 @@ static inline int best_index_int8(int n, const int8_t * val, float x) { return x - val[mu-1] < val[mu] - x ? mu-1 : mu; } +// reference implementation for deterministic creation of model files +void quantize_row_q1_0_ref(const float * GGML_RESTRICT x, block_q1_0 * GGML_RESTRICT y, int64_t k) { + static const int qk = QK1_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float sum_abs = 0.0f; + for (int j = 0; j < qk; j++) { + sum_abs += fabsf(x[i*qk + j]); + } + const float d = sum_abs / qk; + + y[i].d = GGML_FP32_TO_FP16(d); + + // Clear all bits first + for (int j = 0; j < qk / 8; ++j) { + y[i].qs[j] = 0; + } + + // Just store sign of each weight directly (no normalization) + for (int j = 0; j < qk; ++j) { + const int bit_index = j; + const int byte_index = bit_index / 8; + const int bit_offset = bit_index % 8; + + if (x[i*qk + j] >= 0.0f) { + y[i].qs[byte_index] |= (1 << bit_offset); + } + } + } +} + // reference implementation for deterministic creation of model files void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k) { static const int qk = QK4_0; @@ -339,6 +374,26 @@ void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RE } } +void dequantize_row_q1_0(const block_q1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + static const int qk = QK1_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + const float neg_d = -d; + + for (int j = 0; j < qk; ++j) { + const int byte_index = j / 8; + const int bit_offset = j % 8; + const uint8_t bit = (x[i].qs[byte_index] >> bit_offset) & 1; + y[i*qk + j] = bit ? d : neg_d; + } + } +} + void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { static const int qk = QK4_0; @@ -1978,6 +2033,22 @@ static void quantize_row_q4_0_impl(const float * GGML_RESTRICT x, block_q4_0 * G } } +size_t quantize_q1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + quantize_row_q1_0_ref(src, dst, (int64_t)nrow*n_per_row); + return nrow * ggml_row_size(GGML_TYPE_Q1_0, n_per_row); + } + size_t row_size = ggml_row_size(GGML_TYPE_Q1_0, n_per_row); + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q1_0_ref(src, (block_q1_0*)qrow, n_per_row); + src += n_per_row; + qrow += row_size; + } + return nrow * row_size; +} + + size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { quantize_row_q4_0_ref(src, dst, (int64_t)nrow*n_per_row); @@ -5286,6 +5357,10 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte } } } break; + case GGML_TYPE_Q1_0: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_q1_0, data, nb); + } break; case GGML_TYPE_Q4_0: { VALIDATE_ROW_DATA_D_F16_IMPL(block_q4_0, data, nb); diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index 00604f75c0e..d56c86da890 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -14,6 +14,7 @@ extern "C" { // NOTE: these functions are defined as GGML_API because they used by the CPU backend // Quantization +GGML_API void quantize_row_q1_0_ref(const float * GGML_RESTRICT x, block_q1_0 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q4_1_ref(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q5_0_ref(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k); @@ -41,6 +42,7 @@ GGML_API void quantize_row_iq3_s_ref (const float * GGML_RESTRICT x, block_iq3_ GGML_API void quantize_row_iq2_s_ref (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); // Dequantization +GGML_API void dequantize_row_q1_0(const block_q1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); @@ -90,6 +92,7 @@ GGML_API size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTR GGML_API size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_q1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); diff --git a/ggml/src/ggml-rpc/CMakeLists.txt b/ggml/src/ggml-rpc/CMakeLists.txt index f5acb8ec2cb..40e11fead63 100644 --- a/ggml/src/ggml-rpc/CMakeLists.txt +++ b/ggml/src/ggml-rpc/CMakeLists.txt @@ -2,8 +2,32 @@ message(STATUS "Using RPC backend") ggml_add_backend_library(ggml-rpc ggml-rpc.cpp + transport.cpp ) if (WIN32) target_link_libraries(ggml-rpc PRIVATE ws2_32) endif() + +# RDMA auto-detection (Linux only, requires libibverbs) +if (NOT WIN32 AND NOT APPLE) + find_library(IBVERBS_LIB ibverbs) + if (IBVERBS_LIB) + option(GGML_RPC_RDMA "ggml: enable RDMA transport for RPC" ON) + else() + option(GGML_RPC_RDMA "ggml: enable RDMA transport for RPC" OFF) + endif() +else() + set(GGML_RPC_RDMA OFF CACHE BOOL "RDMA not available on this platform" FORCE) +endif() + +if (GGML_RPC_RDMA) + if (NOT IBVERBS_LIB) + find_library(IBVERBS_LIB ibverbs REQUIRED) + endif() + target_compile_definitions(ggml-rpc PRIVATE GGML_RPC_RDMA) + target_link_libraries(ggml-rpc PRIVATE ${IBVERBS_LIB}) + message(STATUS " RDMA transport enabled (auto-detected)") +else() + message(STATUS " RDMA transport disabled") +endif() diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index d7c8ad8c168..7176d2feef9 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -2,30 +2,17 @@ #include "ggml-impl.h" #include "ggml-backend-impl.h" #include "ggml-cpp.h" +#include "transport.h" +#include #include +#include #include #include #include #include #include #include -#ifdef _WIN32 -# define WIN32_LEAN_AND_MEAN -# ifndef NOMINMAX -# define NOMINMAX -# endif -# include -# include -#else -# include -# include -# include -# include -# include -# include -# include -#endif #include #include #include @@ -39,29 +26,6 @@ static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG"); namespace fs = std::filesystem; -static constexpr size_t MAX_CHUNK_SIZE = 1024ull * 1024ull * 1024ull; // 1 GiB - -#ifdef _WIN32 -typedef SOCKET sockfd_t; -using ssize_t = __int64; -#else -typedef int sockfd_t; -#endif - -// cross-platform socket -struct socket_t { - sockfd_t fd; - socket_t(sockfd_t fd) : fd(fd) {} - ~socket_t() { - LOG_DBG("[%s] closing socket %d\n", __func__, this->fd); -#ifdef _WIN32 - closesocket(this->fd); -#else - close(this->fd); -#endif - } -}; - // macro for nicer error messages on server crash #define RPC_STATUS_ASSERT(x) if (!(x)) GGML_ABORT("Remote RPC server crashed or returned malformed response") @@ -115,10 +79,16 @@ static_assert(RPC_CMD_HELLO == 14, "RPC_CMD_HELLO must be always 14"); // Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold const size_t HASH_THRESHOLD = 10 * 1024 * 1024; +struct rpc_msg_hello_req { + uint8_t conn_caps[RPC_CONN_CAPS_SIZE]; +}; + struct rpc_msg_hello_rsp { uint8_t major; uint8_t minor; uint8_t patch; + uint8_t padding; + uint8_t conn_caps[RPC_CONN_CAPS_SIZE]; }; struct rpc_msg_device_count_rsp { @@ -288,153 +258,27 @@ static uint64_t fnv_hash(const uint8_t * data, size_t len) { return hash; } -static std::shared_ptr make_socket(sockfd_t fd) { -#ifdef _WIN32 - if (fd == INVALID_SOCKET) { - return nullptr; - } -#else - if (fd < 0) { - return nullptr; - } -#endif - return std::make_shared(fd); -} - -static bool set_no_delay(sockfd_t sockfd) { - int flag = 1; - // set TCP_NODELAY to disable Nagle's algorithm - int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int)); - return ret == 0; -} - -static bool set_reuse_addr(sockfd_t sockfd) { - int flag = 1; - int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int)); - return ret == 0; -} - -static std::shared_ptr socket_connect(const char * host, int port) { - struct sockaddr_in addr; - auto sockfd = socket(AF_INET, SOCK_STREAM, 0); - auto sock_ptr = make_socket(sockfd); - if (sock_ptr == nullptr) { - return nullptr; - } - if (!set_no_delay(sockfd)) { - GGML_LOG_ERROR("Failed to set TCP_NODELAY\n"); - return nullptr; - } - addr.sin_family = AF_INET; - addr.sin_port = htons(port); - struct hostent * server = gethostbyname(host); - if (server == NULL) { - GGML_LOG_ERROR("Cannot resolve host '%s'\n", host); - return nullptr; - } - memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length); - if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) { - return nullptr; - } - return sock_ptr; -} - -static std::shared_ptr socket_accept(sockfd_t srv_sockfd) { - auto client_socket_fd = accept(srv_sockfd, NULL, NULL); - auto client_socket = make_socket(client_socket_fd); - if (client_socket == nullptr) { - return nullptr; - } - if (!set_no_delay(client_socket_fd)) { - GGML_LOG_ERROR("Failed to set TCP_NODELAY\n"); - return nullptr; - } - return client_socket; -} - -static std::shared_ptr create_server_socket(const char * host, int port) { - auto sockfd = socket(AF_INET, SOCK_STREAM, 0); - auto sock = make_socket(sockfd); - if (sock == nullptr) { - return nullptr; - } - if (!set_reuse_addr(sockfd)) { - GGML_LOG_ERROR("Failed to set SO_REUSEADDR\n"); - return nullptr; - } - if (inet_addr(host) == INADDR_NONE) { - GGML_LOG_ERROR("Invalid host address: %s\n", host); - return nullptr; - } - struct sockaddr_in serv_addr; - serv_addr.sin_family = AF_INET; - serv_addr.sin_addr.s_addr = inet_addr(host); - serv_addr.sin_port = htons(port); - - if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) { - return nullptr; - } - if (listen(sockfd, 1) < 0) { - return nullptr; - } - return sock; -} - -static bool send_data(sockfd_t sockfd, const void * data, size_t size) { - size_t bytes_sent = 0; - while (bytes_sent < size) { - size_t size_to_send = std::min(size - bytes_sent, MAX_CHUNK_SIZE); - ssize_t n = send(sockfd, (const char *)data + bytes_sent, size_to_send, 0); - if (n < 0) { - GGML_LOG_ERROR("send failed (bytes_sent=%zu, size_to_send=%zu)\n", - bytes_sent, size_to_send); - return false; - } - bytes_sent += (size_t)n; - } - return true; -} - -static bool recv_data(sockfd_t sockfd, void * data, size_t size) { - size_t bytes_recv = 0; - while (bytes_recv < size) { - size_t size_to_recv = std::min(size - bytes_recv, MAX_CHUNK_SIZE); - ssize_t n = recv(sockfd, (char *)data + bytes_recv, size_to_recv, 0); - if (n < 0) { - GGML_LOG_ERROR("recv failed (bytes_recv=%zu, size_to_recv=%zu)\n", - bytes_recv, size_to_recv); - return false; - } - if (n == 0) { - LOG_DBG("recv returned 0 (peer closed?)\n"); - return false; - } - bytes_recv += (size_t)n; - } - return true; -} - -static bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size) { - if (!send_data(sockfd, &msg_size, sizeof(msg_size))) { +static bool send_msg(socket_ptr sock, const void * msg, size_t msg_size) { + if (!sock->send_data(&msg_size, sizeof(msg_size))) { return false; } - return send_data(sockfd, msg, msg_size); + return sock->send_data(msg, msg_size); } -static bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size) { +static bool recv_msg(socket_ptr sock, void * msg, size_t msg_size) { uint64_t size; - if (!recv_data(sockfd, &size, sizeof(size))) { + if (!sock->recv_data(&size, sizeof(size))) { return false; } if (size != msg_size) { return false; } - return recv_data(sockfd, msg, msg_size); + return sock->recv_data(msg, msg_size); } -static bool recv_msg(sockfd_t sockfd, std::vector & input) { +static bool recv_msg(socket_ptr sock, std::vector & input) { uint64_t size; - if (!recv_data(sockfd, &size, sizeof(size))) { + if (!sock->recv_data(&size, sizeof(size))) { return false; } try { @@ -443,7 +287,7 @@ static bool recv_msg(sockfd_t sockfd, std::vector & input) { GGML_LOG_ERROR("Failed to allocate input buffer of size %" PRIu64 "\n", size); return false; } - return recv_data(sockfd, input.data(), size); + return sock->recv_data(input.data(), size); } static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) { @@ -452,21 +296,25 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int return false; } host = endpoint.substr(0, pos); - port = std::stoi(endpoint.substr(pos + 1)); + try { + port = std::stoi(endpoint.substr(pos + 1)); + } catch (...) { + return false; + } return true; } // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) | // No response -static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cmd, const void * input, size_t input_size) { +static bool send_rpc_cmd(socket_ptr sock, enum rpc_cmd cmd, const void * input, size_t input_size) { uint8_t cmd_byte = cmd; - if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) { + if (!sock->send_data(&cmd_byte, sizeof(cmd_byte))) { return false; } - if (!send_data(sock->fd, &input_size, sizeof(input_size))) { + if (!sock->send_data(&input_size, sizeof(input_size))) { return false; } - if (!send_data(sock->fd, input, input_size)) { + if (!sock->send_data(input, input_size)) { return false; } return true; @@ -474,20 +322,18 @@ static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cm // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) | // RPC response: | response_size (8 bytes) | response_data (response_size bytes) | -static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) { +static bool send_rpc_cmd(socket_ptr sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) { if (!send_rpc_cmd(sock, cmd, input, input_size)) { return false; } - // TODO: currently the output_size is always known, do we need support for commands with variable output size? - // even if we do, we can skip sending output_size from the server for commands with known output size uint64_t out_size; - if (!recv_data(sock->fd, &out_size, sizeof(out_size))) { + if (!sock->recv_data(&out_size, sizeof(out_size))) { return false; } if (out_size != output_size) { return false; } - if (!recv_data(sock->fd, output, output_size)) { + if (!sock->recv_data(output, output_size)) { return false; } return true; @@ -495,17 +341,25 @@ static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cm // RPC client-side implementation -static bool check_server_version(const std::shared_ptr & sock) { - rpc_msg_hello_rsp response; - bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response)); +// Performs HELLO handshake with transport auto-negotiation. +// Advertises local capabilities via conn_caps; if the server responds with +// matching capabilities, the socket is upgraded transparently. +static bool negotiate_hello(const std::shared_ptr & sock) { + rpc_msg_hello_req request = {}; + rpc_msg_hello_rsp response = {}; + + sock->get_caps(request.conn_caps); + + bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, &request, sizeof(request), &response, sizeof(response)); RPC_STATUS_ASSERT(status); + if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) { - GGML_LOG_ERROR("RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch); + GGML_LOG_ERROR("RPC server version mismatch: %d.%d.%d\n", + response.major, response.minor, response.patch); return false; } - if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) { - GGML_LOG_INFO("WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch); - } + + sock->update_caps(response.conn_caps); return true; } @@ -513,7 +367,6 @@ static std::shared_ptr get_socket(const std::string & endpoint) { static std::mutex mutex; std::lock_guard lock(mutex); static std::unordered_map> sockets; - static bool initialized = false; auto it = sockets.find(endpoint); if (it != sockets.end()) { @@ -527,26 +380,18 @@ static std::shared_ptr get_socket(const std::string & endpoint) { GGML_LOG_ERROR("Failed to parse endpoint: %s\n", endpoint.c_str()); return nullptr; } -#ifdef _WIN32 - if (!initialized) { - WSADATA wsaData; - int res = WSAStartup(MAKEWORD(2, 2), &wsaData); - if (res != 0) { - return nullptr; - } - initialized = true; + + if (!rpc_transport_init()) { + return nullptr; } -#else - GGML_UNUSED(initialized); -#endif - auto sock = socket_connect(host.c_str(), port); + auto sock = socket_t::connect(host.c_str(), port); if (sock == nullptr) { return nullptr; } - if (!check_server_version(sock)) { + if (!negotiate_hello(sock)) { return nullptr; } - LOG_DBG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd); + LOG_DBG("[%s] connected to %s\n", __func__, endpoint.c_str()); sockets[endpoint] = sock; return sock; } @@ -589,8 +434,10 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) { ggml_backend_buffer_t buffer = tensor->buffer; ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; result.buffer = ctx != nullptr ? ctx->remote_ptr : 0; + result.data = reinterpret_cast(tensor->data); } else { result.buffer = 0; + result.data = 0; } for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) { result.ne[i] = tensor->ne[i]; @@ -606,7 +453,6 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) { } result.view_src = reinterpret_cast(tensor->view_src); result.view_offs = tensor->view_offs; - result.data = reinterpret_cast(tensor->data); // Avoid sending uninitialized data over the wire memset(result.name, 0, sizeof(result.name)); @@ -705,6 +551,8 @@ static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_rpc_buffer_set_tensor, /* .get_tensor = */ ggml_backend_rpc_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_rpc_buffer_cpy_tensor, /* .clear = */ ggml_backend_rpc_buffer_clear, /* .reset = */ NULL, @@ -892,6 +740,8 @@ static ggml_backend_i ggml_backend_rpc_interface = { /* .free = */ ggml_backend_rpc_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ ggml_backend_rpc_synchronize, /* .graph_plan_create = */ NULL, @@ -1008,8 +858,8 @@ class rpc_server { bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response); struct stored_graph { - ggml_context_ptr ctx_ptr; - ggml_cgraph * graph; + std::vector buffer; + ggml_cgraph * graph; }; private: @@ -1162,12 +1012,18 @@ ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rp return nullptr; } + // Fix: Prevent division by zero if blck_size is 0 (e.g., deprecated types) + if (ggml_blck_size((enum ggml_type)tensor->type) == 0) { + GGML_LOG_ERROR("[%s] invalid tensor type received (blck_size is 0): %u\n", __func__, tensor->type); + return nullptr; + } + ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); // ggml_new_tensor_4d might fail if dimensions are invalid, although less likely to crash than invalid type if (result == nullptr) { - GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\\n", __func__, tensor->type); + GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\n", __func__, tensor->type); return nullptr; } @@ -1245,7 +1101,7 @@ bool rpc_server::set_tensor(const std::vector & input) { fs::path cache_file = fs::path(cache_dir) / hash_str; std::ofstream ofs(cache_file, std::ios::binary); ofs.write((const char *)data, size); - GGML_LOG_INFO("[%s] saved to '%s'\n", __func__, cache_file.c_str()); + GGML_LOG_INFO("[%s] saved to '%s'\n", __func__, cache_file.string().c_str()); } ggml_backend_tensor_set(tensor, data, offset, size); return true; @@ -1333,7 +1189,9 @@ bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) { if (buffer && buffer->iface.init_tensor) { buffer->iface.init_tensor(buffer, tensor); } else { - GGML_LOG_ERROR("Null buffer for tensor passed to init_tensor function\n"); + if (!buffer) { + GGML_LOG_ERROR("Tensor with null buffer passed to init_tensor function\n"); + } } if (tensor->extra != nullptr) { @@ -1440,6 +1298,10 @@ ggml_tensor * rpc_server::create_node(uint64_t id, if (result == nullptr) { return nullptr; } + if (result->buffer == nullptr && result->data != nullptr) { + GGML_LOG_ERROR("[%s] invalid data ptr", __func__); + return nullptr; + } tensor_map[id] = result; for (int i = 0; i < GGML_MAX_SRC; i++) { // Check if the source ID is 0 before calling create_node recursively @@ -1505,10 +1367,12 @@ bool rpc_server::graph_compute(const std::vector & input) { LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors); size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false); - + if (stored_graphs[device].buffer.size() < buf_size) { + stored_graphs[device].buffer.resize(buf_size); + } struct ggml_init_params params = { /*.mem_size =*/ buf_size, - /*.mem_buffer =*/ NULL, + /*.mem_buffer =*/ stored_graphs[device].buffer.data(), /*.no_alloc =*/ true, }; ggml_context_ptr ctx_ptr { ggml_init(params) }; @@ -1538,7 +1402,6 @@ bool rpc_server::graph_compute(const std::vector & input) { } ggml_status status = ggml_backend_graph_compute(backends[device], graph); GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC"); - stored_graphs[device].ctx_ptr.swap(ctx_ptr); stored_graphs[device].graph = graph; return true; } @@ -1579,27 +1442,46 @@ rpc_server::~rpc_server() { } static void rpc_serve_client(const std::vector & backends, const char * cache_dir, - sockfd_t sockfd) { + socket_ptr sock) { rpc_server server(backends, cache_dir); uint8_t cmd; - if (!recv_data(sockfd, &cmd, 1)) { + if (!sock->recv_data(&cmd, 1)) { return; } - // the first command sent by the client must be HELLO if (cmd != RPC_CMD_HELLO) { GGML_LOG_ERROR("Expected HELLO command, update client\n"); return; } - if (!recv_msg(sockfd, nullptr, 0)) { + + // Read input_size and validate protocol version + uint64_t hello_input_size; + if (!sock->recv_data(&hello_input_size, sizeof(hello_input_size))) { return; } - rpc_msg_hello_rsp response; - server.hello(response); - if (!send_msg(sockfd, &response, sizeof(response))) { + + if (hello_input_size != sizeof(rpc_msg_hello_req)) { + GGML_LOG_ERROR("HELLO request size mismatch (%zu vs %zu) — client needs upgrade to protocol v%d.x\n", + (size_t)hello_input_size, sizeof(rpc_msg_hello_req), RPC_PROTO_MAJOR_VERSION); return; } + + rpc_msg_hello_req req = {}; + if (!sock->recv_data(&req, sizeof(req))) { + return; + } + + rpc_msg_hello_rsp rsp = {}; + server.hello(rsp); + // Advertise server transport capabilities based on client's caps + sock->get_caps(rsp.conn_caps); + if (!send_msg(sock, &rsp, sizeof(rsp))) { + return; + } + + // Activate transport upgrade using client's caps + sock->update_caps(req.conn_caps); while (true) { - if (!recv_data(sockfd, &cmd, 1)) { + if (!sock->recv_data(&cmd, 1)) { break; } if (cmd >= RPC_CMD_COUNT) { @@ -1613,115 +1495,115 @@ static void rpc_serve_client(const std::vector & backends, const return; } case RPC_CMD_DEVICE_COUNT: { - if (!recv_msg(sockfd, nullptr, 0)) { + if (!recv_msg(sock, nullptr, 0)) { return; } rpc_msg_device_count_rsp response; response.device_count = backends.size(); - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_ALLOC_BUFFER: { rpc_msg_alloc_buffer_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_alloc_buffer_rsp response; if (!server.alloc_buffer(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_GET_ALLOC_SIZE: { rpc_msg_get_alloc_size_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_get_alloc_size_rsp response; if (!server.get_alloc_size(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_GET_ALIGNMENT: { rpc_msg_get_alignment_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_get_alignment_rsp response; if (!server.get_alignment(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_GET_MAX_SIZE: { rpc_msg_get_max_size_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_get_max_size_rsp response; if (!server.get_max_size(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_BUFFER_GET_BASE: { rpc_msg_buffer_get_base_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_buffer_get_base_rsp response; if (!server.buffer_get_base(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_FREE_BUFFER: { rpc_msg_free_buffer_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } if (!server.free_buffer(request)) { return; } - if (!send_msg(sockfd, nullptr, 0)) { + if (!send_msg(sock, nullptr, 0)) { return; } break; } case RPC_CMD_BUFFER_CLEAR: { rpc_msg_buffer_clear_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } if (!server.buffer_clear(request)) { return; } - if (!send_msg(sockfd, nullptr, 0)) { + if (!send_msg(sock, nullptr, 0)) { return; } break; } case RPC_CMD_SET_TENSOR: { std::vector input; - if (!recv_msg(sockfd, input)) { + if (!recv_msg(sock, input)) { return; } if (!server.set_tensor(input)) { @@ -1731,62 +1613,62 @@ static void rpc_serve_client(const std::vector & backends, const } case RPC_CMD_SET_TENSOR_HASH: { rpc_msg_set_tensor_hash_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_set_tensor_hash_rsp response; if (!server.set_tensor_hash(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_INIT_TENSOR: { rpc_msg_init_tensor_req request; - if (!recv_msg(sockfd, &request,sizeof(request))) { + if (!recv_msg(sock, &request,sizeof(request))) { return; } if (!server.init_tensor(request)) { return; } - if (!send_msg(sockfd, nullptr, 0)) { + if (!send_msg(sock, nullptr, 0)) { return; } break; } case RPC_CMD_GET_TENSOR: { rpc_msg_get_tensor_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } std::vector response; if (!server.get_tensor(request, response)) { return; } - if (!send_msg(sockfd, response.data(), response.size())) { + if (!send_msg(sock, response.data(), response.size())) { return; } break; } case RPC_CMD_COPY_TENSOR: { rpc_msg_copy_tensor_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_copy_tensor_rsp response; if (!server.copy_tensor(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_GRAPH_COMPUTE: { std::vector input; - if (!recv_msg(sockfd, input)) { + if (!recv_msg(sock, input)) { return; } if (!server.graph_compute(input)) { @@ -1796,7 +1678,7 @@ static void rpc_serve_client(const std::vector & backends, const } case RPC_CMD_GRAPH_RECOMPUTE: { rpc_msg_graph_recompute_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } if (!server.graph_recompute(request)) { @@ -1806,14 +1688,14 @@ static void rpc_serve_client(const std::vector & backends, const } case RPC_CMD_GET_DEVICE_MEMORY: { rpc_msg_get_device_memory_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_get_device_memory_rsp response; if (!server.get_device_memory(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; @@ -1866,36 +1748,34 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir if (!parse_endpoint(endpoint, host, port)) { return; } -#ifdef _WIN32 - { - WSADATA wsaData; - int res = WSAStartup(MAKEWORD(2, 2), &wsaData); - if (res != 0) { - fprintf(stderr, "WSAStartup failed: %d\n", res); - return; - } + +#ifdef GGML_RPC_RDMA + printf(" transport : TCP (RDMA auto-negotiate enabled)\n"); +#else + printf(" transport : TCP\n"); +#endif // GGML_RPC_RDMA + if (!rpc_transport_init()) { + fprintf(stderr, "Failed to initialize RPC transport\n"); + return; } -#endif - auto server_socket = create_server_socket(host.c_str(), port); + auto server_socket = socket_t::create_server(host.c_str(), port); if (server_socket == nullptr) { fprintf(stderr, "Failed to create server socket\n"); return; } while (true) { - auto client_socket = socket_accept(server_socket->fd); + auto client_socket = server_socket->accept(); if (client_socket == nullptr) { fprintf(stderr, "Failed to accept client connection\n"); return; } printf("Accepted client connection\n"); fflush(stdout); - rpc_serve_client(backends, cache_dir, client_socket->fd); + rpc_serve_client(backends, cache_dir, client_socket); printf("Client connection closed\n"); fflush(stdout); } -#ifdef _WIN32 - WSACleanup(); -#endif + rpc_transport_shutdown(); for (auto backend : backends) { ggml_backend_free(backend); } diff --git a/ggml/src/ggml-rpc/transport.cpp b/ggml/src/ggml-rpc/transport.cpp new file mode 100644 index 00000000000..a728152421f --- /dev/null +++ b/ggml/src/ggml-rpc/transport.cpp @@ -0,0 +1,683 @@ +#include "transport.h" +#include "ggml-impl.h" + +#ifdef _WIN32 +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include +# include +#else +# include +# include +# include +# include +# include +# include +# include +#endif +#include +#include +#include + +#ifdef GGML_RPC_RDMA +# include +# include +# ifndef _WIN32 +# include +# endif +#endif // GGML_RPC_RDMA + +#ifdef _WIN32 +typedef SOCKET sockfd_t; +using ssize_t = __int64; +#else +typedef int sockfd_t; +#endif + +static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG"); + +#define LOG_DBG(...) \ + do { if (RPC_DEBUG) GGML_LOG_DEBUG(__VA_ARGS__); } while (0) + +#ifdef GGML_RPC_RDMA +static constexpr size_t RDMA_CHUNK = 256 * 1024; // 256 KiB per send/recv (fits default 8 MiB memlock) +static constexpr int RDMA_RX_DEPTH = 24; // pre-posted recv ring: 24 × 256 KiB = 6 MiB +static constexpr size_t RDMA_GID_SIZE = 16; // RoCE GID / IB GID is always 16 bytes +using rdma_gid_t = std::array; + +struct rdma_conn { + struct ibv_context * ctx = nullptr; + struct ibv_pd * pd = nullptr; + struct ibv_cq * scq = nullptr; // send completions + struct ibv_cq * rcq = nullptr; // recv completions + struct ibv_qp * qp = nullptr; + + void * tx_buf = nullptr; + struct ibv_mr * tx_mr = nullptr; + + void * rx_buf = nullptr; // RDMA_RX_DEPTH × RDMA_CHUNK contiguous + struct ibv_mr * rx_mr = nullptr; + int rx_head = 0; + + uint32_t max_inline = 0; + + uint8_t * rx_slot(int i) const { + return static_cast(rx_buf) + static_cast(i) * RDMA_CHUNK; + } + + bool post_rx(int i) { + struct ibv_sge sge = {}; + sge.addr = (uintptr_t)rx_slot(i); + sge.length = RDMA_CHUNK; + sge.lkey = rx_mr->lkey; + struct ibv_recv_wr wr = {}, * bad = nullptr; + wr.wr_id = (uint64_t)i; + wr.sg_list = &sge; + wr.num_sge = 1; + return ibv_post_recv(qp, &wr, &bad) == 0; + } + + ~rdma_conn() { + if (tx_mr) ibv_dereg_mr(tx_mr); + if (rx_mr) ibv_dereg_mr(rx_mr); + free(tx_buf); + free(rx_buf); + if (qp) ibv_destroy_qp(qp); + if (scq) ibv_destroy_cq(scq); + if (rcq) ibv_destroy_cq(rcq); + if (pd) ibv_dealloc_pd(pd); + if (ctx) ibv_close_device(ctx); + } +}; + +// Local RDMA parameters captured during the probe phase and later consumed +// by rdma_activate() after the remote side's caps arrive via HELLO. +struct rdma_local_info { + uint32_t qpn = 0; + uint32_t psn = 0; + uint8_t gid[RDMA_GID_SIZE] = {}; + uint8_t ib_port = 0; + int gid_idx = 0; + enum ibv_mtu path_mtu = IBV_MTU_1024; +}; + +struct rdma_caps { + uint32_t qpn; + uint32_t psn; + uint8_t gid[RDMA_GID_SIZE]; +}; + +static_assert(sizeof(rdma_caps) == RPC_CONN_CAPS_SIZE, "rdma_caps must match conn_caps size"); + +#endif // GGML_RPC_RDMA + +struct socket_t::impl { + impl(sockfd_t fd) : use_rdma(false), fd(fd) {} + ~impl(); + bool send_data(const void * data, size_t size); + bool recv_data(void * data, size_t size); + void get_caps(uint8_t * local_caps); + void update_caps(const uint8_t * remote_caps); + +#ifdef GGML_RPC_RDMA + bool tcp_peer_closed(); + std::optional rdma_build_target_gid(); + bool rdma_probe(); + bool rdma_activate(uint32_t remote_qpn, uint32_t remote_psn, const uint8_t * remote_gid); + bool rdma_poll(struct ibv_cq * cq, struct ibv_wc * wc); + bool rdma_send(const void * data, size_t size); + bool rdma_recv(void * data, size_t size); + + std::unique_ptr rdma; + rdma_local_info rdma_local = {}; +#endif // GGML_RPC_RDMA + bool use_rdma; + sockfd_t fd; +}; + +socket_t::impl::~impl() { +#ifdef GGML_RPC_RDMA + rdma.reset(); +#endif // GGML_RPC_RDMA + LOG_DBG("[%s] closing socket %d\n", __func__, this->fd); +#ifdef _WIN32 + if (fd != INVALID_SOCKET) closesocket(this->fd); +#else + if (fd >= 0) close(this->fd); +#endif +} + +#ifdef GGML_RPC_RDMA + +bool socket_t::impl::tcp_peer_closed() { + if (fd < 0) return false; +#ifndef _WIN32 + struct pollfd pfd = { fd, POLLIN | POLLRDHUP, 0 }; + int r = poll(&pfd, 1, 0); + return r > 0 && (pfd.revents & (POLLHUP | POLLERR | POLLRDHUP)); +#else + return false; +#endif +} + +// Build a RoCE GID-shaped 16-byte target from a TCP socket's local address. +// Used to match the socket's local IP against the kernel's GID table so that +// a single memcmp handles IPv4, IPv4-mapped IPv6, and native IPv6 uniformly: +// AF_INET -> ::ffff:a.b.c.d (bytes 10-11 = 0xff, last 4 = IPv4) +// AF_INET6 (IPv4-mapped) -> ::ffff:a.b.c.d (already in GID shape) +// AF_INET6 (native v6) -> the 16-byte IPv6 address as-is +// Returns std::nullopt on unsupported family or getsockname failure. +std::optional socket_t::impl::rdma_build_target_gid() { + sockaddr_storage addr = {}; + socklen_t addr_len = sizeof(addr); + if (getsockname(fd, reinterpret_cast(&addr), &addr_len) != 0) { + return std::nullopt; + } + rdma_gid_t target = {}; + if (addr.ss_family == AF_INET) { + const auto * a = reinterpret_cast(&addr); + target[10] = 0xff; + target[11] = 0xff; + memcpy(&target[12], &a->sin_addr, 4); + return target; + } + if (addr.ss_family == AF_INET6) { + const auto * a = reinterpret_cast(&addr); + memcpy(target.data(), &a->sin6_addr, RDMA_GID_SIZE); + return target; + } + return std::nullopt; +} + +bool socket_t::impl::rdma_probe() { + const char * dev_env = std::getenv("GGML_RDMA_DEV"); + const char * gid_env = std::getenv("GGML_RDMA_GID"); + + auto target_gid = rdma_build_target_gid(); + if (!target_gid) { + return false; + } + + const uint8_t ib_port = 1; + int num_devs = 0; + ibv_device ** devs = ibv_get_device_list(&num_devs); + if (!devs || num_devs == 0) return false; + + ibv_context * ibctx = nullptr; + const char * matched_dev = nullptr; + int gid_idx = gid_env ? atoi(gid_env) : -1; + int gid_version = IBV_GID_TYPE_IB; // 0 = unknown/IB + + for (int d = 0; d < num_devs; d++) { + const char * dn = ibv_get_device_name(devs[d]); + if (dev_env && strcmp(dev_env, dn) != 0) continue; + + ibv_context * ctx = ibv_open_device(devs[d]); + if (!ctx) continue; + + ibv_port_attr pa; + if (ibv_query_port(ctx, ib_port, &pa) != 0) { ibv_close_device(ctx); continue; } + + int found_gid = gid_idx; + int found_version = IBV_GID_TYPE_IB; + if (found_gid < 0) { + // Find a GID on this port whose bytes equal the local TCP address + // (IPv4 or IPv6). Prefer RoCE v2 (UDP/IP, L3-routable) over v1 + // (raw Ethernet, same-L2 only) so silent hangs on L3-routed paths + // are avoided. ibv_query_gid_ex returns gid+type in one call. + int v2_idx = -1; + int v1_idx = -1; + for (int i = 0; i < pa.gid_tbl_len; i++) { + ibv_gid_entry entry = {}; + if (ibv_query_gid_ex(ctx, ib_port, i, &entry, 0) != 0) continue; + if (memcmp(entry.gid.raw, target_gid->data(), RDMA_GID_SIZE) != 0) continue; + if (entry.gid_type == IBV_GID_TYPE_ROCE_V2 && v2_idx < 0) { + v2_idx = i; + } else if (entry.gid_type == IBV_GID_TYPE_ROCE_V1 && v1_idx < 0) { + v1_idx = i; + } + } + if (v2_idx >= 0) { + found_gid = v2_idx; + found_version = IBV_GID_TYPE_ROCE_V2; + } else if (v1_idx >= 0) { + found_gid = v1_idx; + found_version = IBV_GID_TYPE_ROCE_V1; + } + } else { + // Explicit GID index from GGML_RDMA_GID — fetch its type for logging. + ibv_gid_entry entry = {}; + if (ibv_query_gid_ex(ctx, ib_port, found_gid, &entry, 0) == 0) { + found_version = entry.gid_type; + } + } + if (found_gid >= 0) { + ibctx = ctx; + gid_idx = found_gid; + gid_version = found_version; + matched_dev = dn; + rdma_local.path_mtu = pa.active_mtu; + break; + } + ibv_close_device(ctx); + } + ibv_free_device_list(devs); + if (!ibctx) return false; + + rdma_local.ib_port = ib_port; + rdma_local.gid_idx = gid_idx; + + rdma = std::make_unique(); + rdma->ctx = ibctx; + + rdma->pd = ibv_alloc_pd(ibctx); + if (!rdma->pd) return false; + + rdma->scq = ibv_create_cq(ibctx, 16, nullptr, nullptr, 0); + rdma->rcq = ibv_create_cq(ibctx, RDMA_RX_DEPTH + 4, nullptr, nullptr, 0); + if (!rdma->scq || !rdma->rcq) return false; + + ibv_qp_init_attr qia = {}; + qia.send_cq = rdma->scq; + qia.recv_cq = rdma->rcq; + qia.qp_type = IBV_QPT_RC; + qia.cap.max_send_wr = 4; + qia.cap.max_recv_wr = RDMA_RX_DEPTH + 4; + qia.cap.max_send_sge = 1; + qia.cap.max_recv_sge = 1; + qia.cap.max_inline_data = 256; + + rdma->qp = ibv_create_qp(rdma->pd, &qia); + if (!rdma->qp) return false; + rdma->max_inline = qia.cap.max_inline_data; + + rdma->tx_buf = aligned_alloc(4096, RDMA_CHUNK); + rdma->rx_buf = aligned_alloc(4096, static_cast(RDMA_RX_DEPTH) * RDMA_CHUNK); + if (!rdma->tx_buf || !rdma->rx_buf) return false; + + rdma->tx_mr = ibv_reg_mr(rdma->pd, rdma->tx_buf, RDMA_CHUNK, IBV_ACCESS_LOCAL_WRITE); + rdma->rx_mr = ibv_reg_mr(rdma->pd, rdma->rx_buf, static_cast(RDMA_RX_DEPTH) * RDMA_CHUNK, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); + if (!rdma->tx_mr || !rdma->rx_mr) return false; + + ibv_gid local_gid; + if (ibv_query_gid(ibctx, ib_port, gid_idx, &local_gid) != 0) return false; + + rdma_local.qpn = rdma->qp->qp_num; + rdma_local.psn = rdma->qp->qp_num & 0xffffff; + memcpy(&rdma_local.gid, &local_gid, RDMA_GID_SIZE); + + const char * ver_str = ""; + if (gid_version == IBV_GID_TYPE_ROCE_V2) { + ver_str = " RoCEv2"; + } else if (gid_version == IBV_GID_TYPE_ROCE_V1) { + ver_str = " RoCEv1"; + } + GGML_LOG_INFO("RDMA probed: dev=%s gid=%d%s qpn=%u inline=%u\n", + matched_dev, gid_idx, ver_str, rdma_local.qpn, rdma->max_inline); + return true; +} + +// Phase 2: Given remote QPN/PSN/GID, transition QP: RESET->INIT->pre-post->RTR->RTS. +// On success, the connection is live and ready for rdma_send/rdma_recv. +bool socket_t::impl::rdma_activate(uint32_t remote_qpn, uint32_t remote_psn, const uint8_t * remote_gid) { + // RESET -> INIT + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_INIT; + a.port_num = rdma_local.ib_port; + a.pkey_index = 0; + a.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_LOCAL_WRITE; + if (ibv_modify_qp(rdma->qp, &a, + IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) { + return false; + } + } + + for (int i = 0; i < RDMA_RX_DEPTH; i++) { + if (!rdma->post_rx(i)) return false; + } + + // INIT -> RTR + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_RTR; + a.path_mtu = rdma_local.path_mtu; + a.dest_qp_num = remote_qpn; + a.rq_psn = remote_psn; + a.max_dest_rd_atomic = 1; + a.min_rnr_timer = 1; + a.ah_attr.is_global = 1; + memcpy(&a.ah_attr.grh.dgid, remote_gid, RDMA_GID_SIZE); + a.ah_attr.grh.hop_limit = 1; + a.ah_attr.grh.sgid_index = rdma_local.gid_idx; + a.ah_attr.dlid = 0; + a.ah_attr.port_num = rdma_local.ib_port; + if (ibv_modify_qp(rdma->qp, &a, + IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | + IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER) != 0) { + return false; + } + } + + // RTR -> RTS + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_RTS; + a.timeout = 14; + a.retry_cnt = 7; + a.rnr_retry = 7; + a.sq_psn = rdma_local.psn; + a.max_rd_atomic = 1; + if (ibv_modify_qp(rdma->qp, &a, + IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | + IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC) != 0) { + return false; + } + } + + GGML_LOG_INFO("RDMA activated: qpn=%u->%u mtu=%d rx_depth=%d\n", + rdma_local.qpn, remote_qpn, 128 << rdma_local.path_mtu, RDMA_RX_DEPTH); + return true; +} + +bool socket_t::impl::rdma_poll(struct ibv_cq * cq, struct ibv_wc * wc) { + for (uint64_t s = 0; ; s++) { + int n = ibv_poll_cq(cq, 1, wc); + if (n > 0) { + if (wc->status != IBV_WC_SUCCESS) { + GGML_LOG_ERROR("RDMA CQ wc error: status=%d (%s) vendor_err=0x%x\n", + wc->status, ibv_wc_status_str(wc->status), wc->vendor_err); + } + return wc->status == IBV_WC_SUCCESS; + } + if (n < 0) return false; + if ((s & 0xFFFFF) == 0 && s > 0) { + if (tcp_peer_closed()) { + return false; + } + } + } +} + +bool socket_t::impl::rdma_send(const void * data, size_t size) { + rdma_conn * c = rdma.get(); + const uint8_t * src = (const uint8_t *)data; + size_t rem = size; + while (rem > 0) { + size_t chunk = std::min(rem, RDMA_CHUNK); + + struct ibv_sge sge = {}; + struct ibv_send_wr wr = {}, * bad = nullptr; + wr.opcode = IBV_WR_SEND; + wr.sg_list = &sge; + wr.num_sge = 1; + + if (chunk <= c->max_inline) { + sge.addr = (uintptr_t)src; + sge.length = chunk; + wr.send_flags = IBV_SEND_SIGNALED | IBV_SEND_INLINE; + } else { + memcpy(c->tx_buf, src, chunk); + sge.addr = (uintptr_t)c->tx_buf; + sge.length = chunk; + sge.lkey = c->tx_mr->lkey; + wr.send_flags = IBV_SEND_SIGNALED; + } + + if (ibv_post_send(c->qp, &wr, &bad) != 0) return false; + struct ibv_wc wc; + if (!rdma_poll(c->scq, &wc)) return false; + + src += chunk; + rem -= chunk; + } + return true; +} + +bool socket_t::impl::rdma_recv(void * data, size_t size) { + rdma_conn * c = rdma.get(); + uint8_t * dst = (uint8_t *)data; + size_t rem = size; + while (rem > 0) { + struct ibv_wc wc; + if (!rdma_poll(c->rcq, &wc)) return false; + + int slot = (int)wc.wr_id; + size_t got = wc.byte_len; + memcpy(dst, c->rx_slot(slot), got); + + if (!c->post_rx(slot)) return false; + + dst += got; + rem -= got; + } + return true; +} + +#endif // GGML_RPC_RDMA + +bool socket_t::impl::send_data(const void * data, size_t size) { +#ifdef GGML_RPC_RDMA + if (use_rdma) { + return rdma_send(data, size); + } +#endif + size_t bytes_sent = 0; + while (bytes_sent < size) { + size_t size_to_send = std::min(size - bytes_sent, MAX_CHUNK_SIZE); + ssize_t n = send(fd, (const char *)data + bytes_sent, size_to_send, 0); + if (n < 0) { + GGML_LOG_ERROR("send failed (bytes_sent=%zu, size_to_send=%zu)\n", + bytes_sent, size_to_send); + return false; + } + bytes_sent += (size_t)n; + } + return true; +} + +bool socket_t::impl::recv_data(void * data, size_t size) { +#ifdef GGML_RPC_RDMA + if (use_rdma) { + return rdma_recv(data, size); + } +#endif + size_t bytes_recv = 0; + while (bytes_recv < size) { + size_t size_to_recv = std::min(size - bytes_recv, MAX_CHUNK_SIZE); + ssize_t n = recv(fd, (char *)data + bytes_recv, size_to_recv, 0); + if (n < 0) { + GGML_LOG_ERROR("recv failed (bytes_recv=%zu, size_to_recv=%zu)\n", + bytes_recv, size_to_recv); + return false; + } + if (n == 0) { + LOG_DBG("recv returned 0 (peer closed?)\n"); + return false; + } + bytes_recv += (size_t)n; + } + return true; +} + +void socket_t::impl::get_caps(uint8_t * local_caps) { + memset(local_caps, 0, RPC_CONN_CAPS_SIZE); +#ifdef GGML_RPC_RDMA + rdma_local = {}; + if (rdma_probe()) { + rdma_caps rc = {}; + rc.qpn = rdma_local.qpn; + rc.psn = rdma_local.psn; + memcpy(rc.gid, rdma_local.gid, RDMA_GID_SIZE); + memcpy(local_caps, &rc, sizeof(rc)); + } else { + rdma.reset(); + } +#endif // GGML_RPC_RDMA +} + +void socket_t::impl::update_caps(const uint8_t * remote_caps) { +#ifdef GGML_RPC_RDMA + if (!rdma) { + return; + } + rdma_caps rc = {}; + memcpy(&rc, remote_caps, sizeof(rc)); + if (rc.qpn == 0) { + rdma.reset(); + return; + } + if (rdma_activate(rc.qpn, rc.psn, rc.gid)) { + use_rdma = true; + } else { + GGML_LOG_ERROR("RDMA activate failed, staying on TCP\n"); + rdma.reset(); + } +#else + (void)remote_caps; +#endif // GGML_RPC_RDMA +} + + +///////////////////////////////////////////////////////////////////////////// + +socket_t::socket_t(std::unique_ptr p) : pimpl(std::move(p)) {} + +socket_t::~socket_t() = default; + +bool socket_t::send_data(const void * data, size_t size) { + return pimpl->send_data(data, size); +} + +bool socket_t::recv_data(void * data, size_t size) { + return pimpl->recv_data(data, size); +} + +void socket_t::get_caps(uint8_t * local_caps) { + return pimpl->get_caps(local_caps); +} + +void socket_t::update_caps(const uint8_t * remote_caps) { + return pimpl->update_caps(remote_caps); +} + +static bool is_valid_fd(sockfd_t sockfd) { +#ifdef _WIN32 + return sockfd != INVALID_SOCKET; +#else + return sockfd >= 0; +#endif +} + +static bool set_no_delay(sockfd_t sockfd) { + int flag = 1; + // set TCP_NODELAY to disable Nagle's algorithm + int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int)); + return ret == 0; +} + +static bool set_reuse_addr(sockfd_t sockfd) { + int flag = 1; + int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int)); + return ret == 0; +} + +socket_ptr socket_t::accept() { + auto client_socket_fd = ::accept(pimpl->fd, NULL, NULL); + if (!is_valid_fd(client_socket_fd)) { + return nullptr; + } + if (!set_no_delay(client_socket_fd)) { + GGML_LOG_ERROR("Failed to set TCP_NODELAY\n"); + return nullptr; + } + return socket_ptr(new socket_t(std::make_unique(client_socket_fd))); +} + +socket_ptr socket_t::create_server(const char * host, int port) { + auto sockfd = socket(AF_INET, SOCK_STREAM, 0); + if (!is_valid_fd(sockfd)) { + return nullptr; + } + if (!set_reuse_addr(sockfd)) { + GGML_LOG_ERROR("Failed to set SO_REUSEADDR\n"); + return nullptr; + } + if (inet_addr(host) == INADDR_NONE) { + GGML_LOG_ERROR("Invalid host address: %s\n", host); + return nullptr; + } + struct sockaddr_in serv_addr; + serv_addr.sin_family = AF_INET; + serv_addr.sin_addr.s_addr = inet_addr(host); + serv_addr.sin_port = htons(port); + + if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) { + return nullptr; + } + if (listen(sockfd, 1) < 0) { + return nullptr; + } + return socket_ptr(new socket_t(std::make_unique(sockfd))); +} + +socket_ptr socket_t::connect(const char * host, int port) { + auto sockfd = socket(AF_INET, SOCK_STREAM, 0); + if (!is_valid_fd(sockfd)) { + return nullptr; + } + if (!set_no_delay(sockfd)) { + GGML_LOG_ERROR("Failed to set TCP_NODELAY\n"); + return nullptr; + } + struct sockaddr_in addr; + addr.sin_family = AF_INET; + addr.sin_port = htons(port); + struct hostent * server = gethostbyname(host); + if (server == NULL) { + GGML_LOG_ERROR("Cannot resolve host '%s'\n", host); + return nullptr; + } + memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length); + if (::connect(sockfd, (struct sockaddr *)&addr, sizeof(addr)) < 0) { + return nullptr; + } + return socket_ptr(new socket_t(std::make_unique(sockfd))); +} + +#ifdef _WIN32 +static std::mutex g_rpc_transport_mu; +static bool g_rpc_transport_wsa_started = false; +#endif + +bool rpc_transport_init() { +#ifdef _WIN32 + std::lock_guard lock(g_rpc_transport_mu); + if (g_rpc_transport_wsa_started) { + return true; + } + WSADATA wsaData; + int res = WSAStartup(MAKEWORD(2, 2), &wsaData); + if (res != 0) { + return false; + } + g_rpc_transport_wsa_started = true; + return true; +#else + return true; +#endif +} + +void rpc_transport_shutdown() { +#ifdef _WIN32 + std::lock_guard lock(g_rpc_transport_mu); + if (!g_rpc_transport_wsa_started) { + return; + } + WSACleanup(); + g_rpc_transport_wsa_started = false; +#endif +} diff --git a/ggml/src/ggml-rpc/transport.h b/ggml/src/ggml-rpc/transport.h new file mode 100644 index 00000000000..73b85cc530a --- /dev/null +++ b/ggml/src/ggml-rpc/transport.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include +#include + +struct socket_t; +typedef std::shared_ptr socket_ptr; + +static constexpr size_t MAX_CHUNK_SIZE = 1024ull * 1024ull * 1024ull; // 1 GiB +static constexpr size_t RPC_CONN_CAPS_SIZE = 24; + +struct socket_t { + ~socket_t(); + + bool send_data(const void * data, size_t size); + bool recv_data(void * data, size_t size); + + socket_ptr accept(); + + void get_caps(uint8_t * local_caps); + void update_caps(const uint8_t * remote_caps); + + static socket_ptr create_server(const char * host, int port); + static socket_ptr connect(const char * host, int port); + +private: + struct impl; + explicit socket_t(std::unique_ptr p); + std::unique_ptr pimpl; +}; + +bool rpc_transport_init(); +void rpc_transport_shutdown(); diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index 7b07b227874..8e589fa238d 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -154,6 +154,11 @@ if (GGML_SYCL_GRAPH) target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GRAPH) endif() +if (GGML_SYCL_HOST_MEM_FALLBACK) + message(STATUS "find GGML_SYCL_HOST_MEM_FALLBACK") + target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_HOST_MEM_FALLBACK) +endif() + if (GGML_SYCL_DEVICE_ARCH) target_compile_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH}) target_link_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH}) diff --git a/ggml/src/ggml-sycl/add-id.cpp b/ggml/src/ggml-sycl/add-id.cpp index 8929017a999..e0adc4fe423 100644 --- a/ggml/src/ggml-sycl/add-id.cpp +++ b/ggml/src/ggml-sycl/add-id.cpp @@ -56,7 +56,7 @@ void ggml_sycl_add_id(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { float* dst_d = (float*)dst->data; const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device]; - assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0); + GGML_ASSERT(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0); int threads = std::min((unsigned int)ne00, max_work_group_size); // cols diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index b30b7f2beb7..a526d8e58bc 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -24,6 +24,7 @@ #include "dmmv.hpp" #include "element_wise.hpp" #include "fattn.hpp" +#include "gated_delta_net.hpp" #include "gla.hpp" #include "im2col.hpp" #include "mmq.hpp" @@ -31,6 +32,7 @@ #include "norm.hpp" #include "outprod.hpp" #include "pad.hpp" +#include "pad_reflect_1d.hpp" #include "quantize.hpp" #include "quants.hpp" #include "roll.hpp" @@ -39,8 +41,8 @@ #include "ssm_conv.hpp" #include "softmax.hpp" #include "tsembd.hpp" +#include "upscale.hpp" #include "wkv.hpp" -#include "pad_reflect_1d.hpp" #endif // GGML_SYCL_BACKEND_HPP diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index fcb0db99c6b..5abf2290651 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -23,10 +23,18 @@ #include "ggml-impl.h" #include "ggml-sycl.h" #include "presets.hpp" +#include "type.hpp" #include "sycl_hw.hpp" namespace syclexp = sycl::ext::oneapi::experimental; +#if defined(__INTEL_LLVM_COMPILER) && __has_include() + #include + #ifndef GGML_SYCL_HAS_BF16 + #define GGML_SYCL_HAS_BF16 + #endif +#endif + #if GGML_SYCL_DNNL #include "dnnl.hpp" #include "dnnl_sycl.hpp" @@ -216,7 +224,7 @@ struct sycl_device_info { // cudaOccupancyMaxActiveBlocksPerMultiprocessor bool vmm; // virtual memory support size_t total_vram; - //sycl_hw_info hw_info; \\ device id and aarch, currently not used + sycl_hw_info hw_info; optimize_feature opt_feature; }; @@ -965,4 +973,10 @@ static T block_reduce(T val, T * shared_vals, int block_size_template) { return val; } +static __dpct_inline__ float ggml_sycl_ue4m3_to_fp32(uint8_t x) { + const uint32_t bits = x * (x != 0x7F && x != 0xFF); + const __nv_fp8_e4m3 xf = *reinterpret_cast(&bits); + return static_cast(xf) / 2; +} + #endif // GGML_SYCL_COMMON_HPP diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index d17aca2cac4..67b9c06f3e4 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -2,13 +2,6 @@ #include "dequantize.hpp" #include "presets.hpp" -#if defined(__INTEL_LLVM_COMPILER) - #if __has_include() - #include - #define GGML_SYCL_HAS_BF16 - #endif -#endif - template static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, const sycl::nd_item<3> &item_ct1) { @@ -151,6 +144,25 @@ static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int } +template +static void dequantize_row_q8_0_sycl_reorder(const void *vx, dst_t *y, const int64_t k, + dpct::queue_ptr stream) { + + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + int constexpr WARP_K = WARP_SIZE * QK8_0; + const int n_warp = (k + WARP_K - 1) / WARP_K; + GGML_ASSERT(k % QK8_0 == 0); + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) * + sycl::range<3>(1, 1, WARP_SIZE), + sycl::range<3>(1, 1, WARP_SIZE)), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{ + dequantize_block_q8_0_reorder(vx, y, k, item_ct1); + }); + +} + template static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { @@ -482,6 +494,18 @@ static void dequantize_row_mxfp4_sycl(const void * vx, dst_t * y, const int64_t }); } +template +static void dequantize_row_nvfp4_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) { + GGML_ASSERT(k % QK_NVFP4 == 0); + const int nb = k / QK_NVFP4; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> /*item_ct1*/) { + dequantize_block_nvfp4(vx, y, k); + }); +} + + template static void dequantize_block_nc(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02, @@ -602,7 +626,12 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) { case GGML_TYPE_Q5_1: return dequantize_block_sycl; case GGML_TYPE_Q8_0: - return dequantize_block_sycl; + if (dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q8_0_sycl_reorder; + } else { + return dequantize_block_sycl; + } case GGML_TYPE_Q2_K: return dequantize_row_q2_K_sycl; case GGML_TYPE_Q3_K: @@ -641,6 +670,8 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) { return dequantize_row_iq4_nl_sycl; case GGML_TYPE_MXFP4: return dequantize_row_mxfp4_sycl; + case GGML_TYPE_NVFP4: + return dequantize_row_nvfp4_sycl; case GGML_TYPE_F32: return convert_unary_sycl; #ifdef GGML_SYCL_HAS_BF16 @@ -648,6 +679,7 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) { return convert_unary_sycl; #endif default: + GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type)); return nullptr; } } @@ -668,7 +700,12 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { case GGML_TYPE_Q5_1: return dequantize_block_sycl; case GGML_TYPE_Q8_0: - return dequantize_block_sycl; + if (dst->src[0]->extra && + ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q8_0_sycl_reorder; + } else { + return dequantize_block_sycl; + } case GGML_TYPE_Q2_K: return dequantize_row_q2_K_sycl; case GGML_TYPE_Q3_K: @@ -708,6 +745,8 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { return dequantize_row_iq4_nl_sycl; case GGML_TYPE_MXFP4: return dequantize_row_mxfp4_sycl; + case GGML_TYPE_NVFP4: + return dequantize_row_nvfp4_sycl; case GGML_TYPE_F16: return convert_unary_sycl; #ifdef GGML_SYCL_HAS_BF16 @@ -715,11 +754,28 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { return convert_unary_sycl; #endif default: + GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type)); return nullptr; } } +#ifdef GGML_SYCL_HAS_BF16 +to_bf16_sycl_t ggml_get_to_bf16_sycl(ggml_type type, ggml_tensor * /*dst*/) { + switch (type) { + case GGML_TYPE_F32: + return convert_unary_sycl; + case GGML_TYPE_F16: + return convert_unary_sycl; + case GGML_TYPE_BF16: + return convert_unary_sycl; + default: + GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type)); + return nullptr; + } +} +#endif + to_fp16_nc_sycl_t ggml_get_to_fp16_nc_sycl(ggml_type type) { switch (type) { case GGML_TYPE_F32: diff --git a/ggml/src/ggml-sycl/convert.hpp b/ggml/src/ggml-sycl/convert.hpp index 6e621f2154d..8de79d10ff6 100644 --- a/ggml/src/ggml-sycl/convert.hpp +++ b/ggml/src/ggml-sycl/convert.hpp @@ -23,6 +23,11 @@ typedef to_t_sycl_t to_fp16_sycl_t; to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst); to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor * dst); +#ifdef GGML_SYCL_HAS_BF16 +typedef to_t_sycl_t to_bf16_sycl_t; +to_bf16_sycl_t ggml_get_to_bf16_sycl(ggml_type type, ggml_tensor * dst); +#endif + // Nc = Non-contiguous template using to_t_nc_sycl_t = void (*)(const void * x, T * y, int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, @@ -35,15 +40,19 @@ template inline dst_t ggml_sycl_cast(src_t x) { if constexpr (std::is_same_v) { return x; +#ifdef GGML_SYCL_HAS_BF16 } else if constexpr (std::is_same_v) { return sycl::ext::oneapi::bfloat16(float(x)); } else if constexpr (std::is_same_v) { return static_cast(x); +#endif } else if constexpr (std::is_same_v && std::is_same_v) { return x.template convert(); +#ifdef GGML_SYCL_HAS_BF16 } else if constexpr (std::is_same_v && std::is_same_v>) { return {x.x, x.y}; +#endif } else if constexpr(std::is_same_v) { return int32_t(x); } else { diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index da2a605daa8..19fa88680d6 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -14,6 +14,7 @@ #define GGML_SYCL_DEQUANTIZE_HPP #include "common.hpp" +#include "convert.hpp" typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v); typedef void (*dequantize_kernel_t_reorder)(const void *d, const int64_t ib, const void *qs, @@ -143,6 +144,22 @@ static __dpct_inline__ void dequantize_q5_1(const void *vx, const int64_t ib, #endif // GGML_SYCL_F16 } +static __dpct_inline__ void dequantize_q8_0_reorder(const void *d_ptr, const int64_t ib, const void *qs, + const int iqs, dfloat2 &v) { + const dfloat d = (const dfloat)*((const sycl::half*)d_ptr + ib); + + v.x() = ((const int8_t *)qs)[iqs + 0]; + v.y() = ((const int8_t *)qs)[iqs + 1]; + +#ifdef GGML_SYCL_F16 + v.s0() *= d; + v.s1() *= d; +#else + v.x() *= d; + v.y() *= d; +#endif // GGML_SYCL_F16 +} + static __dpct_inline__ void dequantize_q8_0(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { const block_q8_0 * x = (const block_q8_0 *) vx; @@ -222,6 +239,34 @@ static void dequantize_block_q4_0_reorder(const void * __restrict__ vx, dst_t * } +// Dequantize Q8_0 from reorder layout: [all qs (k bytes)][all d values] +// Each thread handles one block of QK8_0 elements. +template +static void dequantize_block_q8_0_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t k, + const sycl::nd_item<3> &item_ct1) { + + const int64_t i = item_ct1.get_group(2); + const int64_t tid = item_ct1.get_local_id(2); + const int lane_ib = i * WARP_SIZE + tid; + + if (lane_ib >= k / QK8_0) { + return; + } + + dst_t * y_ptr = yy + lane_ib * QK8_0; + + auto qs = (const int8_t*)vx + lane_ib * QK8_0; + auto s_ptr = (const sycl::half*)((const uint8_t*)vx + k) + lane_ib; + + const float d = float(*s_ptr); + +#pragma unroll + for (int l = 0; l < QK8_0; ++l) { + y_ptr[l] = d * qs[l]; + } + +} + template static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32, const sycl::nd_item<3> &item_ct1) { @@ -838,4 +883,36 @@ static void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restr } } + +template +static void dequantize_block_nvfp4( + const void * __restrict__ vx, + dst_t * __restrict__ yy, + const int64_t ne) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int64_t i = item_ct1.get_group(2); + const int tid = item_ct1.get_local_id(2); + + const int64_t base = i * QK_NVFP4; + if (base >= ne) { + return; + } + + const block_nvfp4 * x = (const block_nvfp4 *) vx; + const block_nvfp4 & xb = x[i]; + + const int sub = tid / (QK_NVFP4_SUB / 2); + const int j = tid % (QK_NVFP4_SUB / 2); + + const float d = ggml_sycl_ue4m3_to_fp32(xb.d[sub]); + const uint8_t q = xb.qs[sub * (QK_NVFP4_SUB / 2) + j]; + + const int64_t y0 = base + sub * QK_NVFP4_SUB + j; + const int64_t y1 = y0 + QK_NVFP4_SUB / 2; + + yy[y0] = ggml_sycl_cast(d * kvalues_mxfp4[q & 0x0F]); + yy[y1] = ggml_sycl_cast(d * kvalues_mxfp4[q >> 4]); +} + + #endif // GGML_SYCL_DEQUANTIZE_HPP diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index 4f2760110c2..5577bf73b28 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -615,6 +615,162 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx, } } +static void dequantize_mul_mat_vec_q4_k_reorder(const void *__restrict__ vx, + const float *__restrict__ yy, + float *__restrict__ dst, + const int ncols, int nrows, + const sycl::nd_item<3> &item_ct1) { + + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + if (row > nrows) return; + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + // SOA base pointers for the reordered layout: + // [qs: nb * QK_K/2] [scales: nb * K_SCALE_SIZE] [dm: nb * sizeof(half2)] + const int nb = nrows * num_blocks_per_row; + const uint8_t * qs_base = (const uint8_t *)vx; + const uint8_t * scales_base = qs_base + (size_t)nb * (QK_K / 2); + const sycl::half2 * dm_base = (const sycl::half2 *)(scales_base + (size_t)nb * K_SCALE_SIZE); + +#if QK_K == 256 + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const int tid = + item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = + item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1 + + const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 + + const int il = tid/step; // 0...3 + const int ir = tid - step*il; // 0...7 or 0...3 + const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4 + + const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const int in = il%2; + + const int l0 = n*(2*ir + in); + const int q_offset = 32*im + l0; + const int y_offset = 64*im + l0; + + uint16_t aux[4]; + const uint8_t * sc = (const uint8_t *)aux; + +#if K_QUANTS_PER_ITERATION == 2 + uint32_t q32[4]; + const uint8_t * q4 = (const uint8_t *)q32; +#else + uint16_t q16[4]; + const uint8_t * q4 = (const uint8_t *)q16; +#endif + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + const int bi = ib0 + i; + + const float * y1 = yy + i*QK_K + y_offset; + const float * y2 = y1 + 128; + + const sycl::half2 dm_val = dm_base[bi]; + const float dall = dm_val[0]; + const float dmin = dm_val[1]; + + const uint16_t * a = (const uint16_t *)(scales_base + bi * K_SCALE_SIZE); + aux[0] = a[im+0] & kmask1; + aux[1] = a[im+2] & kmask1; + aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); + aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); + +#if K_QUANTS_PER_ITERATION == 2 + const uint32_t * q1 = (const uint32_t *)(qs_base + bi * (QK_K / 2) + q_offset); + const uint32_t * q2 = q1 + 16; + + q32[0] = q1[0] & 0x0f0f0f0f; + q32[1] = q1[0] & 0xf0f0f0f0; + q32[2] = q2[0] & 0x0f0f0f0f; + q32[3] = q2[0] & 0xf0f0f0f0; + + sycl::float4 s = {0.f, 0.f, 0.f, 0.f}; + float smin = 0; + for (int l = 0; l < 4; ++l) { + s.x() += y1[l] * q4[l + 0]; s.y() += y1[l + 32] * q4[l + 4]; + s.z() += y2[l] * q4[l + 8]; s.w() += y2[l + 32] * q4[l + 12]; + smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; + } + tmp += dall * (s.x() * sc[0] + s.y() * sc[1] * 1.f / 16.f + + s.z() * sc[4] + s.w() * sc[5] * 1.f / 16.f) - + dmin * smin; +#else + const uint16_t * q1 = (const uint16_t *)(qs_base + bi * (QK_K / 2) + q_offset); + const uint16_t * q2 = q1 + 32; + + q16[0] = q1[0] & 0x0f0f; + q16[1] = q1[0] & 0xf0f0; + q16[2] = q2[0] & 0x0f0f; + q16[3] = q2[0] & 0xf0f0; + + float4 s = {0.f, 0.f, 0.f, 0.f}; + float smin = 0; + for (int l = 0; l < 2; ++l) { + s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2]; + s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6]; + smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; + } + tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin; +#endif + + } +#else + const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15 + const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); + + const int step = tid * K_QUANTS_PER_ITERATION; + + uint16_t aux16[2]; + const uint8_t * s = (const uint8_t *)aux16; + + float tmp = 0; + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + const int bi = ib0 + i; + + const uint8_t * q = qs_base + bi * (QK_K / 2) + step; + const float * y = yy + i*QK_K + step; + const uint16_t * a = (const uint16_t *)(scales_base + bi * K_SCALE_SIZE); + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + const sycl::half2 dm_val = dm_base[bi]; + const float d = (float)dm_val[0]; + const float m = (float)dm_val[1]; + float sum = 0.f; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2]) + + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2]) + + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3]) + + y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]); + } + tmp += sum; + } + +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (tid == 0) { + dst[row] = tmp; + } +} + /* DPCT1110:7: The total declared local variable size in device function dequantize_mul_mat_vec_q5_k exceeds 128 bytes and may cause high register @@ -864,6 +1020,129 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa } } +static void dequantize_mul_mat_vec_q6_k_reorder(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows, + const sycl::nd_item<3> &item_ct1) { + + static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); + + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + if (row > nrows) return; + + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + // SOA base pointers for the reordered layout: + // [ql: nb * QK_K/2] [qh: nb * QK_K/4] [scales: nb * QK_K/16] [d: nb * sizeof(half)] + const int nb = nrows * num_blocks_per_row; + const uint8_t * ql_base = (const uint8_t *)vx; + const uint8_t * qh_base = ql_base + (size_t)nb * (QK_K / 2); + const int8_t * scales_base = (const int8_t *)(qh_base + (size_t)nb * (QK_K / 4)); + const sycl::half * d_base = (const sycl::half *)((const uint8_t *)scales_base + (size_t)nb * (QK_K / 16)); + +#if QK_K == 256 + + const int tid = + item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = + item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0, 1 + + const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 + + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0...15 or 0...7 + +#if K_QUANTS_PER_ITERATION == 1 + const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 + const int is = 0; +#else + const int l0 = 4 * in; // 0, 4, 8, ..., 28 + const int is = in / 4; +#endif + const int ql_offset = 64*im + l0; + const int qh_offset = 32*im + l0; + const int s_offset = 8*im + is; + const int y_offset = 128*im + l0; + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + const int bi = ib0 + i; + + const float * y = yy + i * QK_K + y_offset; + const uint8_t * ql = ql_base + bi * (QK_K / 2) + ql_offset; + const uint8_t * qh = qh_base + bi * (QK_K / 4) + qh_offset; + const int8_t * s = scales_base + bi * (QK_K / 16) + s_offset; + + const float d = d_base[bi]; + +#if K_QUANTS_PER_ITERATION == 1 + float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32) + + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32) + + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32) + + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32) + + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32) + + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32) + + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32) + +y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32); + tmp += sum; +#else + float sum = 0; + for (int l = 0; l < 4; ++l) { + sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32) + + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32) + + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32) + + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32); + } + tmp += sum; +#endif + + } + +#else + + const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...7 + const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); // 0...3 + + const int step = tid * K_QUANTS_PER_ITERATION; + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + const int bi = ib0 + i; + + const float * y = yy + i * QK_K + step; + const uint8_t * ql = ql_base + bi * (QK_K / 2) + step; + const uint8_t * qh = qh_base + bi * (QK_K / 4) + step; + const int8_t * s = scales_base + bi * (QK_K / 16); + + const float d = d_base[bi]; + + float sum = 0; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32) + + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32) + + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32) + + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32); + } + tmp += sum; + + } + +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (tid == 0) { + dst[row] = tmp; + } +} + static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloat *y, float *dst, const int ncols, const int nrows, @@ -972,6 +1251,103 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y, } } +static void dequantize_mul_mat_vec_q8_0_sycl_reorder(const void *vx, const dfloat *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + // Q8_0 reorder layout: [all qs (ncols*nrows bytes)][all d values] + // Cannot reuse dequantize_mul_mat_vec_reorder template because it has + // Q4_0-specific constants hardcoded (d_ptr offset and qs stride). + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + if (row >= nrows) return; + + const int tid = item_ct1.get_local_id(2); + const int iter_stride = 8*2*GGML_SYCL_DMMV_X; + const int vals_per_iter = iter_stride / WARP_SIZE; + const int ncols_left = ncols % (QK8_0*WARP_SIZE); + const int ncols_align = ncols - ncols_left; + +#ifdef GGML_SYCL_F16 + sycl::half2 tmp = {0.0f, 0.0f}; +#else + float tmp = 0.0f; +#endif + const char *d_ptr = (const char*)vx + ncols*nrows; // d after all qs + + int i = 0; + for (i = 0; i < ncols_align; i += iter_stride) { + const int col = i + vals_per_iter*tid; + const int ib = (row*ncols + col)/QK8_0; + const int iqs = col % QK8_0; + +#pragma unroll + for (int j = 0; j < vals_per_iter; j += 2) { + dfloat2 v; + dequantize_q8_0_reorder((const void *)d_ptr, ib, (const void *)vx, + ib * QK8_0 + iqs + j, v); + +#ifdef GGML_SYCL_F16 + dfloat2 t1{y[col + j + 0], y[col + j + 1]}; + tmp += v * t1; +#else + tmp += v.x() * y[col + j + 0]; + tmp += v.y() * y[col + j + 1]; +#endif + } + } + + // handle remaining columns + for (; i < ncols; i += iter_stride) { + if (tid >= ncols_left/QK8_0) continue; + const int col = i + vals_per_iter*tid; + const int ib = (row*ncols + col)/QK8_0; + const int iqs = col % QK8_0; + +#pragma unroll + for (int j = 0; j < vals_per_iter; j += 2) { + dfloat2 v; + dequantize_q8_0_reorder((const void *)d_ptr, ib, (const void *)vx, + ib * QK8_0 + iqs + j, v); + +#ifdef GGML_SYCL_F16 + dfloat2 t1{y[col + j + 0], y[col + j + 1]}; + tmp += v * t1; +#else + tmp += v.x() * y[col + j + 0]; + tmp += v.y() * y[col + j + 1]; +#endif + } + } + + // reduce + const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2; + for (int mask = mask_start; mask > 0; mask >>= 1) { + tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (tid == 0) { +#ifdef GGML_SYCL_F16 + dst[row] = tmp.x() + tmp.y(); +#else + dst[row] = tmp; +#endif + } + }); + } +} + static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y, float *dst, const int ncols, const int nrows, @@ -1070,6 +1446,38 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y, }); } +static void dequantize_mul_mat_vec_q4_K_sycl_reorder(const void *vx, const float *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int ny = 2 / K_QUANTS_PER_ITERATION; + const int block_num_y = (nrows + ny - 1) / ny; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { + dequantize_mul_mat_vec_q4_k_reorder(vx, y, dst, ncols, nrows, item_ct1); + }); +} + +static void dequantize_mul_mat_vec_q6_K_sycl_reorder(const void *vx, const float *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int ny = 2 / K_QUANTS_PER_ITERATION; + const int block_num_y = (nrows + ny - 1) / ny; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { + dequantize_mul_mat_vec_q6_k_reorder(vx, y, dst, ncols, nrows, item_ct1); + }); +} + void ggml_sycl_op_dequantize_mul_mat_vec( ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, @@ -1122,7 +1530,12 @@ void ggml_sycl_op_dequantize_mul_mat_vec( dequantize_mul_mat_vec_q5_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); break; case GGML_TYPE_Q8_0: - dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + dequantize_mul_mat_vec_q8_0_sycl_reorder(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + } else { + dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + } break; case GGML_TYPE_Q2_K: dequantize_mul_mat_vec_q2_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); @@ -1133,8 +1546,7 @@ void ggml_sycl_op_dequantize_mul_mat_vec( case GGML_TYPE_Q4_K: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { - // reorder is currently not supported for dmmv - GGML_ABORT("Unimplemented dequantize case case for q4_k reorder"); + dequantize_mul_mat_vec_q4_K_sycl_reorder(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); } else { dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); } @@ -1143,7 +1555,12 @@ void ggml_sycl_op_dequantize_mul_mat_vec( dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); break; case GGML_TYPE_Q6_K: - dequantize_mul_mat_vec_q6_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + dequantize_mul_mat_vec_q6_K_sycl_reorder(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + } else { + dequantize_mul_mat_vec_q6_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + } break; case GGML_TYPE_F16: convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index acd51bf45b2..249e80c826e 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -294,30 +294,6 @@ static void unary_op_trunc_kernel(const T * x, T * dst, const int k, const sycl: } } -template -static void upscale(const T *x, T *dst, const int nb00, const int nb01, - const int nb02, const int nb03, const int ne10, const int ne11, - const int ne12, const int ne13, const float sf0, const float sf1, - const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) { - int index = item_ct1.get_local_id(0) + - item_ct1.get_group(0) * item_ct1.get_local_range(0); - if (index >= ne10 * ne11 * ne12 * ne13) { - return; - } - // operation - int i10 = index % ne10; - int i11 = (index / ne10) % ne11; - int i12 = (index / (ne10 * ne11)) % ne12; - int i13 = (index / (ne10 * ne11 * ne12)) % ne13; - - int i00 = static_cast(i10 / sf0); - int i01 = static_cast(i11 / sf1); - int i02 = static_cast(i12 / sf2); - int i03 = static_cast(i13 / sf3); - - dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00); -} - template static void clamp(const T * x, T * dst, const float min, const float max, const int k, const sycl::nd_item<1> &item_ct1) { @@ -379,7 +355,7 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst, const int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE; stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> /*item_ct1*/) { acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset); }); } @@ -392,20 +368,6 @@ static void arange_kernel(T * dst, const int k, T start, T step, } } -template -static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01, - const int nb02, const int nb03, const int ne10, const int ne11, - const int ne12, const int ne13, const float sf0, const float sf1, - const float sf2, const float sf3, queue_ptr stream) { - int dst_size = ne10 * ne11 * ne12 * ne13; - int num_blocks = ceil_div(dst_size, SYCL_UPSCALE_BLOCK_SIZE); - sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE); - stream->parallel_for( - sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { - upscale(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1); - }); -} - template static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); @@ -505,42 +467,6 @@ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & c } } -template -static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) { - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); - - GGML_ASSERT(dst->src[0]->type == dst->type); - - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - - const float sf0 = (float) dst->ne[0] / dst->src[0]->ne[0]; - const float sf1 = (float) dst->ne[1] / dst->src[0]->ne[1]; - const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2]; - const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3]; - switch (dst->type) { - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2], - (int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3, - main_stream, std::forward(args)...); - break; - } - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2], - (int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3, - main_stream, std::forward(args)...); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - template static inline void ggml_sycl_op_unary( ggml_backend_sycl_context & ctx, ggml_tensor * dst, F func) { @@ -784,15 +710,6 @@ static inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor }); } -static inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - ggml_sycl_detail::dispatch_ggml_sycl_op_upscale(ctx, dst, - [](const auto* src, auto* dst_ptr, int nb00, int nb01, int nb02, int nb03, - int ne10, int ne11, int ne12, int ne13, float sf0, float sf1, float sf2, float sf3, - queue_ptr stream) { - ggml_sycl_detail::upscale_sycl(src, dst_ptr, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, stream); - }); -} - static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { float min_val; float max_val; @@ -1131,12 +1048,6 @@ void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { ggml_sycl_op_sqr(ctx, dst); } -void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); - ggml_sycl_op_upscale(ctx, dst); -} - - void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_clamp(ctx, dst); diff --git a/ggml/src/ggml-sycl/element_wise.hpp b/ggml/src/ggml-sycl/element_wise.hpp index 7c71974687a..997132166ab 100644 --- a/ggml/src/ggml-sycl/element_wise.hpp +++ b/ggml/src/ggml-sycl/element_wise.hpp @@ -71,8 +71,6 @@ void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst); -void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst); - void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/fattn-tile.cpp b/ggml/src/ggml-sycl/fattn-tile.cpp index 9d4f019cf51..9449d75784d 100644 --- a/ggml/src/ggml-sycl/fattn-tile.cpp +++ b/ggml/src/ggml-sycl/fattn-tile.cpp @@ -44,6 +44,10 @@ void ggml_sycl_flash_attn_ext_tile(ggml_backend_sycl_context & ctx, ggml_tensor GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_sycl_flash_attn_ext_tile_case<256, 256>(ctx, dst); } break; + case 512: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_sycl_flash_attn_ext_tile_case<512, 512>(ctx, dst); + } break; case 576: { GGML_ASSERT(V->ne[0] == 512); ggml_sycl_flash_attn_ext_tile_case<576, 512>(ctx, dst); diff --git a/ggml/src/ggml-sycl/fattn-tile.hpp b/ggml/src/ggml-sycl/fattn-tile.hpp index 29fd0f8c9ec..9ba5296968d 100644 --- a/ggml/src/ggml-sycl/fattn-tile.hpp +++ b/ggml/src/ggml-sycl/fattn-tile.hpp @@ -67,9 +67,16 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp16(const int DKQ, co GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 2, 64, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 32, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 64, 64) return 0; } @@ -123,6 +130,12 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp32(const int DKQ, co GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 2, 128, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 32, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64) GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64) GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64) @@ -130,134 +143,6 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp32(const int DKQ, co return 0; } -static constexpr uint32_t ggml_sycl_fattn_tile_get_config_amd(const int DKQ, const int DV, const int ncols) { - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 3, 32, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 3, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 2, 32, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 128, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 64) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 2, 256, 2, 128, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 2, 64, 128) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 8, 256, 2, 64, 128) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 128) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2, 64, 32) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 2, 256, 2, 128, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 4, 256, 2, 64, 128) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 128) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64) - - return 0; -} - -static constexpr uint32_t ggml_sycl_fattn_tile_get_config_amd_rdna(const int DKQ, const int DV, const int ncols) { - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 8, 32, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 64, 8, 32, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 5, 128, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 5, 128, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 128, 4, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 128, 5, 64, 64) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 2, 64, 8, 32, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 8, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 8, 128, 8, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 3, 128, 128) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3, 64, 64) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 8, 32, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 6, 32, 256) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 8, 128, 6, 32, 256) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64) - - return 0; -} - static constexpr uint32_t ggml_sycl_fattn_tile_get_config(const int DKQ, const int DV, const int ncols, const int cc) { if(fast_fp16_available(cc)) return ggml_sycl_fattn_tile_get_config_fp16(DKQ, DV, ncols); @@ -310,11 +195,11 @@ static __dpct_inline__ void flash_attn_tile_load_tile(const sycl::half2 * const sycl::half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes(); constexpr int cpy_ne = cpy_nb / 4; auto load = [&] (const int n) { - auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); const int stride_j = warp_size >> n; if (stride_j == 0) { @@ -455,7 +340,7 @@ static __dpct_inline__ void flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp, flash_attn_tile_load_tile (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup); - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); #ifdef SYCL_FAST_FP16 static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K"); @@ -505,7 +390,7 @@ static __dpct_inline__ void flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp, } if (k_KQ_0 + nbatch_K < DKQ) { - item_ct1.barrier(); // Sync not needed on last iteration. + item_ct1.barrier(sycl::access::fence_space::local_space); // Sync not needed on last iteration. } } @@ -545,7 +430,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp, const int k_VKQ_max, const int col_Q_0, float * KQ_max_new_shared) { - auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes(); constexpr int cpy_ne = cpy_nb / 4; @@ -620,14 +505,14 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp, } if constexpr (np == 1) { - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); } else { static_assert(cpw == 1, "bad cpw"); if (item_ct1.get_local_id(2) == 0) { KQ_max_new_shared[item_ct1.get_local_id(1)] = KQ_max_new[0]; } - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); KQ_max_new[0] = KQ_max_new_shared[(item_ct1.get_local_id(1) & ~(np - 1)) + item_ct1.get_local_id(2) % np]; KQ_max_new[0] = warp_reduce_max(KQ_max_new[0]); } @@ -697,7 +582,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp, for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) { flash_attn_tile_load_tile (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0); - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); #ifdef SYCL_FAST_FP16 #pragma unroll @@ -765,7 +650,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp, } } #endif // SYCL_FAST_FP16 - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); } } @@ -972,7 +857,7 @@ static void flash_attn_tile(const char * Q, } } - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); // Main loop over KV cache: const int k_VKQ_max = KV_max ? KV_max[sequence * item_ct1.get_group_range(2) + item_ct1.get_group(2)] : ne11; @@ -1051,7 +936,7 @@ static void flash_attn_tile(const char * Q, return; } - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); #pragma unroll for (int ip = 1; ip < np; ++ip) { @@ -1193,37 +1078,39 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_sycl_context & ctx, ggm constexpr size_t nbytes_shared = 0; - if constexpr (DV <= 256) { - if (Q->ne[1] > 16/ncols2) { - constexpr int cols_per_block = 32; - const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; - const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - launch_fattn, warp_size> - (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); - return; + if (DV < 512 && Q->ne[1] < 32) { + if constexpr (ncols2 <= 32) { + if (Q->ne[1] > 16/ncols2) { + constexpr int cols_per_block = 32; + const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + launch_fattn, warp_size> + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); + return; + } } - } - - if (Q->ne[1] > 8/ncols2) { - constexpr int cols_per_block = 16; - const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; - const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - launch_fattn, warp_size> - (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); - return; - } - - if constexpr (ncols2 <= 8) { - if (Q->ne[1] > 4/ncols2) { - constexpr int cols_per_block = 8; - const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; - const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - launch_fattn, warp_size> - (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); - return; + if constexpr (ncols2 <= 16) { + if (Q->ne[1] > 8/ncols2) { + constexpr int cols_per_block = 16; + const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + launch_fattn, warp_size> + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); + return; + } + } + if constexpr (ncols2 <= 8) { + if (Q->ne[1] > 4/ncols2) { + constexpr int cols_per_block = 8; + const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + launch_fattn, warp_size> + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); + return; + } } } @@ -1249,6 +1136,16 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_sycl_context & ctx, ggm return; } + { + constexpr int cols_per_block = ncols2*2; + const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + launch_fattn, warp_size> + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); + return; + } + GGML_ABORT("fatal error"); } @@ -1280,6 +1177,16 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_sycl_context & ctx, ggm launch_fattn_tile_switch_ncols1(ctx, dst); return; } + // ncols2=2 and ncols2=1 fallbacks only for cases where ncols=2 config exists (DKQ == DV). + // For DKQ == 576, DV == 512 only GQA-optimized variants are implemented. + if constexpr (DKQ == DV) { + if (use_gqa_opt && gqa_ratio % 2 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } } if constexpr (DV <= 256) { @@ -1334,5 +1241,6 @@ extern DECL_FATTN_TILE_CASE( 96, 96); extern DECL_FATTN_TILE_CASE(112, 112); extern DECL_FATTN_TILE_CASE(128, 128); extern DECL_FATTN_TILE_CASE(256, 256); +extern DECL_FATTN_TILE_CASE(512, 512); extern DECL_FATTN_TILE_CASE(576, 512); diff --git a/ggml/src/ggml-sycl/fattn-vec.hpp b/ggml/src/ggml-sycl/fattn-vec.hpp index 48c389052f4..8031acfdff8 100644 --- a/ggml/src/ggml-sycl/fattn-vec.hpp +++ b/ggml/src/ggml-sycl/fattn-vec.hpp @@ -664,4 +664,11 @@ EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0) +EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_F16) +EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q4_0) +EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q4_1) +EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q5_0) +EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q5_1) +EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q8_0) + #endif // GGML_SYCL_FATTN_VEC_HPP diff --git a/ggml/src/ggml-sycl/fattn.cpp b/ggml/src/ggml-sycl/fattn.cpp index c276ed89827..7c6e6112fdc 100644 --- a/ggml/src/ggml-sycl/fattn.cpp +++ b/ggml/src/ggml-sycl/fattn.cpp @@ -34,6 +34,7 @@ FATTN_VEC_CASE( 64, type_K, type_V) \ FATTN_VEC_CASE(128, type_K, type_V) \ FATTN_VEC_CASE(256, type_K, type_V) \ + FATTN_VEC_CASE(512, type_K, type_V) \ static void ggml_sycl_flash_attn_ext_vec(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { ggml_tensor * Q = dst->src[0]; @@ -141,6 +142,7 @@ static best_fattn_kernel ggml_sycl_get_best_fattn_kernel(const int device, const case 128: case 112: case 256: + case 512: if (V->ne[0] != K->ne[0]) { return BEST_FATTN_KERNEL_NONE; } @@ -185,7 +187,7 @@ static best_fattn_kernel ggml_sycl_get_best_fattn_kernel(const int device, const } // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes: - const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0; + const bool can_use_vector_kernel = Q->ne[0] <= 512 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0; // Todo: Use the XMX kernel if possible: diff --git a/ggml/src/ggml-sycl/gated_delta_net.cpp b/ggml/src/ggml-sycl/gated_delta_net.cpp index 8c76afbd571..ebc587524bf 100644 --- a/ggml/src/ggml-sycl/gated_delta_net.cpp +++ b/ggml/src/ggml-sycl/gated_delta_net.cpp @@ -55,7 +55,7 @@ void gated_delta_net_sycl(const float * q, #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; - s_shard[r] = curr_state[i * S_v + col]; + s_shard[r] = curr_state[col * S_v + i]; } for (int t = 0; t < n_tokens; t++) { @@ -137,7 +137,7 @@ void gated_delta_net_sycl(const float * q, #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; - state[i * S_v + col] = s_shard[r]; + state[col * S_v + i] = s_shard[r]; } } @@ -176,14 +176,12 @@ static void launch_gated_delta_net(const float * q_d, const sycl::uint3 neqk1_magic = init_fastdiv_values(neqk1); const sycl::uint3 rq3_magic = init_fastdiv_values(rq3); - int cc = ggml_sycl_info().devices[ggml_sycl_get_device()].cc; - switch (S_v) { case 16: { constexpr int sv = 16; stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_delta_net_sycl(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); @@ -194,7 +192,7 @@ static void launch_gated_delta_net(const float * q_d, { constexpr int sv = 32; stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_delta_net_sycl(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); @@ -205,7 +203,7 @@ static void launch_gated_delta_net(const float * q_d, { constexpr int sv = 64; stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_delta_net_sycl( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); @@ -217,7 +215,7 @@ static void launch_gated_delta_net(const float * q_d, { constexpr int sv = 128; stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_delta_net_sycl( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); diff --git a/ggml/src/ggml-sycl/gemm.hpp b/ggml/src/ggml-sycl/gemm.hpp index dcf6c7aeeb4..c202da110be 100644 --- a/ggml/src/ggml-sycl/gemm.hpp +++ b/ggml/src/ggml-sycl/gemm.hpp @@ -29,6 +29,9 @@ class DnnlGemmWrapper { static constexpr dt to_dt() { if constexpr (std::is_same_v) return dt::f32; else if constexpr (std::is_same_v) return dt::f16; +#ifdef GGML_SYCL_HAS_BF16 + else if constexpr (std::is_same_v) return dt::bf16; +#endif else static_assert(0); } diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 12819705849..f06147eeeb8 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -44,7 +44,6 @@ #include "ggml-sycl/backend.hpp" #include "ggml-sycl/common.hpp" #include "ggml-sycl/element_wise.hpp" -#include "ggml-sycl/gated_delta_net.hpp" #include "ggml-sycl/gemm.hpp" #include "ggml-sycl/getrows.hpp" #include "ggml-sycl/norm.hpp" @@ -105,6 +104,7 @@ static ggml_sycl_device_info ggml_sycl_init() { info.max_work_group_sizes[i] = prop.get_max_work_group_size(); info.devices[i].max_wg_per_cu = info.max_work_group_sizes[i] / prop.get_max_compute_units(); + info.devices[i].hw_info = get_device_hw_info(&device); } @@ -412,11 +412,22 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer, assert(tensor->view_src->buffer->buft == buffer->buft); return GGML_STATUS_SUCCESS; } - if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) && - !g_ggml_sycl_disable_optimize) { - ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; - tensor->extra = extra; - ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx. + + if (!g_ggml_sycl_disable_optimize) { + // set reorder extra buffer based on supported type + switch (tensor->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K:{ + ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; + tensor->extra = extra; + ctx->tensor_extras.push_back(extra); + break; + } + default: + break; + } } if (ggml_is_quantized(tensor->type)) { @@ -570,9 +581,15 @@ static void ggml_backend_sycl_buffer_clear(ggml_backend_buffer_t buffer, SYCL_CHECK( CHECK_TRY_ERROR(dpct::get_current_device().queues_wait_and_throw())); - SYCL_CHECK(CHECK_TRY_ERROR((*stream) - .memset(ctx->dev_ptr, value, buffer->size) - .wait())); + constexpr size_t MAX_CHUNK = 2ULL << 30; // 2 GiB + for (size_t off = 0; off < buffer->size; off += MAX_CHUNK) { + size_t chunk = std::min(buffer->size - off, MAX_CHUNK); + SYCL_CHECK(CHECK_TRY_ERROR( + (*stream) + .memset(static_cast(ctx->dev_ptr) + off, value, chunk) + .wait() + )); + } } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ @@ -622,6 +639,8 @@ static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = { /* .memset_tensor = */ ggml_backend_sycl_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_sycl_buffer_set_tensor, /* .get_tensor = */ ggml_backend_sycl_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_sycl_buffer_cpy_tensor, /* .clear = */ ggml_backend_sycl_buffer_clear, /* .reset = */ ggml_backend_sycl_buffer_reset, @@ -1068,6 +1087,8 @@ static struct ggml_backend_buffer_i ggml_backend_sycl_split_buffer_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_sycl_split_buffer_set_tensor, /* .get_tensor = */ ggml_backend_sycl_split_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ NULL, /* .clear = */ ggml_backend_sycl_split_buffer_clear, /* .reset = */ NULL, @@ -2156,6 +2177,31 @@ inline void ggml_sycl_op_mul_mat_sycl( #else bool use_fp16 = false; #endif + +#if GGML_SYCL_DNNL && defined(GGML_SYCL_HAS_BF16) + // Fast path for bf16 src0 + if (src0->type == GGML_TYPE_BF16 && !g_ggml_sycl_disable_dnn && ggml_is_contiguous(src0) && + row_diff == src0->ne[1]) { + using bf16_t = sycl::ext::oneapi::bfloat16; + ggml_sycl_pool_alloc src1_as_bf16(ctx.pool(), src1_ncols*ne10); + if (src1->type != GGML_TYPE_BF16) { + const to_bf16_sycl_t to_bf16_sycl = ggml_get_to_bf16_sycl(src1->type, dst); + GGML_ASSERT(to_bf16_sycl != nullptr); + to_bf16_sycl(src1_ddf_i, src1_as_bf16.get(), src1_ncols*ne10, stream); + } else { + stream->memcpy(src1_as_bf16.get(), src1_ddf_i, src1_ncols*ne10*sizeof(bf16_t)); + } + DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, + src0_dd_i, DnnlGemmWrapper::to_dt(), + src1_as_bf16.get(), DnnlGemmWrapper::to_dt(), + dst_dd_i, DnnlGemmWrapper::to_dt(), stream); + GGML_UNUSED(dst); + GGML_UNUSED(src1_ddq_i); + GGML_UNUSED(src1_padded_row_size); + return; + } +#endif + if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && use_fp16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) { ggml_sycl_pool_alloc src0_as_f16(ctx.pool()); @@ -3249,6 +3295,7 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) { inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: return true; case GGML_TYPE_Q4_K: case GGML_TYPE_Q6_K: @@ -3261,6 +3308,7 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) { inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: return true; default: return false; @@ -3270,6 +3318,7 @@ inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) { inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: case GGML_TYPE_Q4_K: case GGML_TYPE_Q6_K: return true; @@ -3325,9 +3374,55 @@ static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) { sycl::free(ptr, *stream); } -static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset, +// RAII wrapper for temporary reorder buffers with optional host memory fallback. +// When device allocation fails and GGML_SYCL_HOST_MEM_FALLBACK is enabled, +// falls back to host memory so the reorder kernel can still run (over PCIe). +// Device access to host memory requires Linux kernel 6.8+ (Ubuntu 26.04+). +struct sycl_reorder_temp_buffer { + void * ptr = nullptr; + dpct::queue_ptr stream; + + sycl_reorder_temp_buffer(dpct::queue_ptr stream, size_t size) : stream(stream) { + ptr = sycl_ext_malloc_device(stream, size); +#ifdef GGML_SYCL_HOST_MEM_FALLBACK + if (!ptr) { + ptr = sycl::malloc_host(size, *stream); + if (ptr) { + host_fallback = true; + GGML_LOG_WARN("%s: device alloc of %zu bytes failed, using host memory fallback\n", __func__, size); + } + } +#endif + } + + ~sycl_reorder_temp_buffer() { + if (!ptr) { + return; + } + if (host_fallback) { + sycl::free(ptr, *stream); + } else { + sycl_ext_free(stream, ptr); + } + } + + explicit operator bool() const { return ptr != nullptr; } + + sycl_reorder_temp_buffer(const sycl_reorder_temp_buffer &) = delete; + sycl_reorder_temp_buffer & operator=(const sycl_reorder_temp_buffer &) = delete; + +private: + bool host_fallback = false; +}; + +static bool reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset, dpct::queue_ptr stream) { - uint8_t * tmp_buf = static_cast(sycl_ext_malloc_device(stream, size)); + sycl_reorder_temp_buffer tmp(stream, size); + if (!tmp) { + GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size); + return false; + } + uint8_t * tmp_buf = static_cast(tmp.ptr); sycl::event copy_event; SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); @@ -3356,16 +3451,60 @@ static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nr if (!g_ggml_sycl_use_async_mem_op) { reorder_event.wait_and_throw(); } - sycl_ext_free(stream, tmp_buf); + return true; +} + +static bool reorder_qw_q8_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset, + dpct::queue_ptr stream) { + sycl_reorder_temp_buffer tmp(stream, size); + if (!tmp) { + GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size); + return false; + } + uint8_t * tmp_buf = static_cast(tmp.ptr); + + sycl::event copy_event; + SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); + if (!g_ggml_sycl_use_async_mem_op) { + copy_event.wait(); + } + + GGML_ASSERT((size % sizeof(block_q8_0) == 0)); + GGML_ASSERT((offset % sizeof(block_q8_0) == 0)); + int offset_blks = offset / sizeof(block_q8_0); + auto qs_ptr = data_device + offset_blks * QK8_0; + auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows) + offset_blks; + + auto reorder_event = stream->parallel_for( + size / sizeof(block_q8_0), + [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + const block_q8_0* x = (const block_q8_0*)tmp_buf; + const int ib = i; + + for (int j = 0; j < QK8_0; j++) + { + *((int8_t*)qs_ptr + ib * QK8_0 + j) = x[ib].qs[j]; + } + *(d_ptr + ib) = x[ib].d; + }); + if (!g_ggml_sycl_use_async_mem_op) { + reorder_event.wait_and_throw(); + } + return true; } -static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { +static bool reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { GGML_ASSERT(size % sizeof(block_q4_K) == 0); GGML_ASSERT(offset % sizeof(block_q4_K) == 0); const int nblocks = size / sizeof(block_q4_K); - uint8_t * tmp_buf = static_cast(sycl_ext_malloc_device(stream, size)); + sycl_reorder_temp_buffer tmp(stream, size); + if (!tmp) { + GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size); + return false; + } + uint8_t * tmp_buf = static_cast(tmp.ptr); sycl::event copy_event; SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); @@ -3394,16 +3533,21 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d if (!g_ggml_sycl_use_async_mem_op) { reorder_event.wait_and_throw(); } - sycl_ext_free(stream, tmp_buf); + return true; } -static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { +static bool reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { GGML_ASSERT(size % sizeof(block_q6_K) == 0); GGML_ASSERT(offset % sizeof(block_q6_K) == 0); const int nblocks = size / sizeof(block_q6_K); - uint8_t * tmp_buf = static_cast(sycl_ext_malloc_device(stream, size)); + sycl_reorder_temp_buffer tmp(stream, size); + if (!tmp) { + GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size); + return false; + } + uint8_t * tmp_buf = static_cast(tmp.ptr); sycl::event copy_event; SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); @@ -3442,10 +3586,10 @@ static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, d if (!g_ggml_sycl_use_async_mem_op) { reorder_event.wait_and_throw(); } - sycl_ext_free(stream, tmp_buf); + return true; } -static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { +static bool reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { uint8_t * data_device = (uint8_t *) src0->data; size_t ncols = src0->ne[0]; size_t nrows = src0->ne[1]; @@ -3453,17 +3597,16 @@ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { switch (src0->type) { case GGML_TYPE_Q4_0: - reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream); - break; + return reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream); + case GGML_TYPE_Q8_0: + return reorder_qw_q8_0(data_device, ncols, nrows, size, 0, stream); case GGML_TYPE_Q4_K: - reorder_qw_q4_k(data_device, size, 0, stream); - break; + return reorder_qw_q4_k(data_device, size, 0, stream); case GGML_TYPE_Q6_K: - reorder_qw_q6_k(data_device, size, 0, stream); - break; + return reorder_qw_q6_k(data_device, size, 0, stream); default: GGML_ABORT("reorder_qw() called with unsupported type"); - break; + return false; } } @@ -3503,8 +3646,9 @@ static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * break; } - reorder_qw(src0, ctx->stream()); - extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering + if (reorder_qw(src0, ctx->stream())) { + extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering + } } @@ -3560,9 +3704,16 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization // is enabled takes precedence over DMMV, the current if-else implementation // requires disabling DMMV if both conditions are met + if (!g_ggml_sycl_prioritize_dmmv && ((should_reorder_tensor(ctx, dst) && ggml_sycl_supports_reorder_mmvq(src0->type)))) { - use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; + // Arc770 get benefit with Q4_0 by skipping it. + if (!(ggml_sycl_info().devices[ctx.device].hw_info.arch == + gpu_arch::intel_gpu_acm_g10 && + src0->type == GGML_TYPE_Q4_0)) { + use_dequantize_mul_mat_vec = + use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; + } } if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { @@ -3665,6 +3816,51 @@ __dpct_inline__ static void k_copy_dst_from_contiguous( } } +// Fused MoE TG fast path. Returns false to fall back to the per-expert loop below. +static bool ggml_sycl_mul_mat_id_mmvq_fused( + ggml_backend_sycl_context & ctx, const ggml_tensor * src0, + const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) +{ + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + if (ne12 != 1) return false; + if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) return false; + if (ne10 != src0->ne[0] || ne10 % QK8_1 != 0) return false; + if (!ggml_is_contiguous(src1)) return false; + + // Reorder layout not supported; fall back. + const ggml_tensor_extra_gpu * src0_extra = + static_cast(src0->extra); + if (src0_extra && src0_extra->optimized_feature.reorder) return false; + + const int64_t n_ids_per_group = ids->ne[0]; + if (ids->ne[1] != 1) return false; + if (ne11 != 1 && ne11 != n_ids_per_group) return false; + + const queue_ptr stream = ctx.stream(); + const int src1_padded_cols = GGML_PAD((int) ne10, MATRIX_ROW_PADDING); + const int n_experts_used = (int) n_ids_per_group; + const int nrows = (int) src0->ne[1]; + + ggml_sycl_pool_alloc src1_q8_alloc(ctx.pool(), + (size_t) ne11 * src1_padded_cols * sizeof(block_q8_1) / QK8_1); + char * src1_ddq = src1_q8_alloc.get(); + quantize_row_q8_1_sycl( + (const float *) src1->data, src1_ddq, (int) ne10, (int) ne11, + src1_padded_cols, stream); + + const size_t bytes_per_qrow = (size_t) src1_padded_cols * sizeof(block_q8_1) / QK8_1; + const size_t src1_row_stride = (ne11 == 1) ? 0 : bytes_per_qrow; + + return ggml_sycl_mul_mat_vec_q_id( + src0->type, src0->data, src1_ddq, (const int32_t *) ids->data, + (float *) dst->data, (int) ne10, nrows, n_experts_used, + /*expert_weight_stride=*/ src0->nb[2], + /*dst_row_stride=*/ dst->nb[1], + src1_row_stride, stream); +} + static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, ggml_tensor *dst) try { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3); @@ -3680,6 +3876,12 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const int64_t n_as = ne02; const int64_t n_ids = ids->ne[0]; + if (ne12 == 1) { + if (ggml_sycl_mul_mat_id_mmvq_fused(ctx, src0, src1, ids, dst)) { + return; + } + } + std::vector ids_host(ggml_nbytes(ids)); const char * ids_dev = (const char *) ids->data; @@ -3730,8 +3932,9 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, } } } else { - ggml_sycl_pool_alloc src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1)); - ggml_sycl_pool_alloc dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); + const int64_t n_routed_rows = ids->ne[1] * n_ids; + ggml_sycl_pool_alloc src1_contiguous(ctx.pool(), sizeof(float)*n_routed_rows*ne10); + ggml_sycl_pool_alloc dst_contiguous(ctx.pool(), sizeof(float)*n_routed_rows*ne0); src1_row.data = src1_contiguous.get(); dst_row.data = dst_contiguous.get(); @@ -4497,6 +4700,8 @@ static ggml_backend_i ggml_backend_sycl_interface = { /* .free = */ ggml_backend_sycl_free, /* .set_tensor_async = */ ggml_backend_sycl_set_tensor_async, /* .get_tensor_async = */ ggml_backend_sycl_get_tensor_async, + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, // ggml_backend_sycl_cpy_tensor_async, // // TODO: update for the new // interface @@ -4665,26 +4870,19 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g struct ggml_tensor * a = op->src[0]; struct ggml_tensor * b = op->src[1]; - if (a->ne[3] != b->ne[3]) { + // disable Q1_0 until implementation + if (a->type == GGML_TYPE_Q1_0 || b->type == GGML_TYPE_Q1_0) { return false; } - ggml_type a_type = a->type; - if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS || - a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S || - a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ2_S || - a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ1_M - ) { - if (b->ne[1] == 1 && ggml_nrows(b) > 1) { - return false; - } - } - ggml_type src0_type = op->src[0]->type; - if (src0_type == GGML_TYPE_BF16 ) { - // TODO: support GGML_TYPE_BF16 - // FIXME: keep a list of supported types to avoid breaking the backend when a new type is added + + if (a->ne[3] != b->ne[3]) { return false; } + ggml_type src0_type = op->src[0]->type; + + + // TODO: The configuration below needs more work to be supported with oneDNN if (ggml_is_permuted(a) && !ggml_is_contiguous(a) && a->ne[2] > 1 && a->ne[3] > 1 && src0_type == GGML_TYPE_F16) { @@ -4863,9 +5061,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_ROPE: case GGML_OP_ROPE_BACK: case GGML_OP_IM2COL: - return true; case GGML_OP_UPSCALE: - return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS); + return true; case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 316aa0d0fb5..8fa2198f35a 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -537,9 +537,9 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx, static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK4_0 == 0); - const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); - constexpr size_t num_subgroups = 16; - GGML_ASSERT(block_num_y % num_subgroups == 0); + // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel. + constexpr size_t num_subgroups = WARP_SIZE; + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups; const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE)); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); @@ -613,6 +613,23 @@ static void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float } } +static void mul_mat_vec_nvfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_NVFP4 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + + { + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1); + }); + }); + } +} static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, @@ -662,6 +679,25 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy, } } +static void reorder_mul_mat_vec_q8_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, + const int nrows, dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK8_0 == 0); + // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel. + constexpr size_t num_subgroups = WARP_SIZE; + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups; + + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE)); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder>(vx, vy, dst, ncols, nrows, + nd_item); + }); + }); +} + static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -762,9 +798,9 @@ static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); - constexpr size_t num_subgroups = 16; - GGML_ASSERT(block_num_y % num_subgroups == 0); + // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel. + constexpr size_t num_subgroups = WARP_SIZE; + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups; const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); @@ -806,9 +842,9 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy, static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); - constexpr size_t num_subgroups = 16; - GGML_ASSERT(block_num_y % num_subgroups == 0); + // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel. + constexpr size_t num_subgroups = WARP_SIZE; + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups; const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); @@ -1084,7 +1120,13 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_Q8_0: - mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q8_0_q8_1_sycl\n"); + reorder_mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } else { + mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_Q2_K: mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); @@ -1145,8 +1187,11 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens case GGML_TYPE_MXFP4: mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; + case GGML_TYPE_NVFP4: + mul_mat_vec_nvfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; default: - GGML_ABORT("fatal error"); + GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(src0->type)); } } GGML_UNUSED(src1); @@ -1154,3 +1199,154 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens GGML_UNUSED(src1_ddf_i); GGML_UNUSED(ctx); } + +// src1_row_stride: 0 for shared src1 (gate/up proj), else per-expert stride (down proj). +template +static void mul_mat_vec_q_moe( + const void * __restrict__ vx_base, const void * __restrict__ vy_base, + float * __restrict__ dst_base, const int32_t * __restrict__ ids_dev, + const int ncols, const int nrows, + const size_t expert_weight_stride, const size_t dst_row_stride, + const size_t src1_row_stride, + const sycl::nd_item<3> & item_ct1) { + + const int expert_idx = item_ct1.get_group(1); + const int i02 = ids_dev[expert_idx]; + + const char * vx = (const char *) vx_base + (size_t) i02 * expert_weight_stride; + const char * vy = (const char *) vy_base + (size_t) expert_idx * src1_row_stride; + float * dst = (float *) ((char *) dst_base + (size_t) expert_idx * dst_row_stride); + + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + constexpr int blocks_per_warp = (vdr * WARP_SIZE + qi - 1) / qi; + + float tmp = 0.0f; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; i += blocks_per_warp) { + const int ibx = row * blocks_per_row + i; + const int iby = i * (qk / QK8_1); + + for (size_t elem = 0; elem < qi / vdr; elem += WARP_SIZE) { + const int iqs = elem + vdr * (item_ct1.get_local_id(2) % (qi / vdr)); + tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs); + } + } + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } +} + +template +static void launch_mul_mat_vec_q_moe( + const void * vx_base, const void * vy, const int32_t * ids_dev, + float * dst_base, const int ncols, const int nrows, const int n_experts_used, + const size_t expert_weight_stride, const size_t dst_row_stride, + const size_t src1_row_stride, + dpct::queue_ptr stream) { + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, (unsigned) n_experts_used, (unsigned) block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_moe( + vx_base, vy, dst_base, ids_dev, ncols, nrows, + expert_weight_stride, dst_row_stride, src1_row_stride, item); + }); + }); +} + +bool ggml_sycl_mul_mat_vec_q_id( + enum ggml_type src0_type, + const void * vx_base, + const void * vy, + const int32_t * ids_dev, + float * dst_base, + int ncols, + int nrows, + int n_experts_used, + size_t expert_weight_stride, + size_t dst_row_stride, + size_t src1_row_stride, + dpct::queue_ptr stream) { + switch (src0_type) { + case GGML_TYPE_Q4_0: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q4_1: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q5_0: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q5_1: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q8_0: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q2_K: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q3_K: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q4_K: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q5_K: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q6_K: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_MXFP4: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_NVFP4: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + default: + return false; + } +} diff --git a/ggml/src/ggml-sycl/mmvq.hpp b/ggml/src/ggml-sycl/mmvq.hpp index 049b43d4535..d674dc1d61e 100644 --- a/ggml/src/ggml-sycl/mmvq.hpp +++ b/ggml/src/ggml-sycl/mmvq.hpp @@ -24,4 +24,20 @@ void ggml_sycl_op_mul_mat_vec_q( const int64_t src1_ncols, const int64_t src1_padded_row_size, const dpct::queue_ptr &stream); +// Requires standard (non-reorder) block layout for src0. +// Returns false if src0_type isn't handled; caller should fall back. +bool ggml_sycl_mul_mat_vec_q_id( + enum ggml_type src0_type, + const void * vx_base, // start of stacked expert weights + const void * vy, // pre-quantized src1 (Q8_1) + const int32_t * ids_dev, // device-side int32, length n_experts_used + float * dst_base, + int ncols, + int nrows, + int n_experts_used, + size_t expert_weight_stride, // bytes between experts in vx_base + size_t dst_row_stride, // bytes between dst rows + size_t src1_row_stride, // 0 = shared src1, else per-expert stride in bytes + dpct::queue_ptr stream); + #endif // GGML_SYCL_MMVQ_HPP diff --git a/ggml/src/ggml-sycl/quants.hpp b/ggml/src/ggml-sycl/quants.hpp index 14490fea5be..1f5b62740a8 100644 --- a/ggml/src/ggml-sycl/quants.hpp +++ b/ggml/src/ggml-sycl/quants.hpp @@ -105,6 +105,27 @@ template <> struct block_q_t { static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } }; +template <> struct block_q_t { + struct traits { + static constexpr uint32_t qk = QK8_0; // 32 + static constexpr uint32_t qi = QI8_0; // 8 + static constexpr uint32_t qr = QR8_0; // 1 + static constexpr uint32_t vdr_mmvq = 4; + }; + + // Q8_0 reorder layout: [qs0|qs1|...|qsN][d0|d1|...|dN] + // Each block has 32 int8 weights (32 bytes) followed by all scales + static constexpr std::pair get_block_offset(const int block_index, const int /* nblocks */) { + return { block_index * QK8_0, 0 }; + } + + static constexpr std::pair get_d_offset(int nrows, int ncols, const int block_index) { + return { (ncols * nrows) + block_index * sizeof(ggml_half), 0 }; + } + + static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } // 1 +}; + } // namespace ggml_sycl_reordered #endif // GGML_SYCL_QUANTS_HPP diff --git a/ggml/src/ggml-sycl/set_rows.cpp b/ggml/src/ggml-sycl/set_rows.cpp index a641c100913..8fb41943525 100644 --- a/ggml/src/ggml-sycl/set_rows.cpp +++ b/ggml/src/ggml-sycl/set_rows.cpp @@ -4,7 +4,11 @@ namespace utils { template static constexpr bool is_arithmetic_v() { - return std::is_arithmetic_v || std::is_same_v || std::is_same_v; + return std::is_arithmetic_v || std::is_same_v +#ifdef GGML_SYCL_HAS_BF16 + || std::is_same_v +#endif + ; } } @@ -181,6 +185,7 @@ static void set_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * s stream ); break; +#ifdef GGML_SYCL_HAS_BF16 case GGML_TYPE_BF16: set_rows_sycl( src0_d, src1_d, (char *)dst->data, @@ -193,6 +198,7 @@ static void set_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * s stream ); break; +#endif case GGML_TYPE_Q8_0: set_rows_sycl_q(src0_d, src1_d, (block_q8_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream); break; diff --git a/ggml/src/ggml-sycl/sycl_hw.cpp b/ggml/src/ggml-sycl/sycl_hw.cpp index 7041140034b..03b0c37a3cd 100644 --- a/ggml/src/ggml-sycl/sycl_hw.cpp +++ b/ggml/src/ggml-sycl/sycl_hw.cpp @@ -1,15 +1,67 @@ #include "sycl_hw.hpp" -// TODO: currently not used -/* -sycl_hw_info get_device_hw_info(sycl::device *device_ptr) { - sycl_hw_info res; - int32_t id = device_ptr->get_info(); - res.device_id = id; +using namespace std; - syclex::architecture arch = device_ptr->get_info(); - res.arch = arch; +/*defined in +* /opt/intel/oneapi/compiler/latest/include/sycl/ext/oneapi/experimental/device_architecture.def +*/ +static map> arch2name = { + {gpu_arch::intel_gpu_bdw, {"intel_gpu_bdw", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_skl, {"intel_gpu_skl", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_kbl, {"intel_gpu_kbl", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_cfl, {"intel_gpu_cfl", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_apl, {"intel_gpu_apl", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_glk, {"intel_gpu_glk", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_whl, {"intel_gpu_whl", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_aml, {"intel_gpu_aml", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_cml, {"intel_gpu_cml", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_icllp, {"intel_gpu_icllp", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_ehl, {"intel_gpu_ehl", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_tgllp, {"intel_gpu_tgllp", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_rkl, {"intel_gpu_rkl", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_adl_s, {"intel_gpu_adl_s", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_adl_p, {"intel_gpu_adl_p", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_adl_n, {"intel_gpu_adl_n", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_dg1, {"intel_gpu_dg1", GPU_FAMILY_DGPU_CLIENT_GAME}}, + {gpu_arch::intel_gpu_acm_g10, {"intel_gpu_acm_g10", GPU_FAMILY_DGPU_CLIENT_GAME}}, + {gpu_arch::intel_gpu_acm_g11, {"intel_gpu_acm_g11", GPU_FAMILY_DGPU_CLIENT_GAME}}, + {gpu_arch::intel_gpu_acm_g12, {"intel_gpu_acm_g12", GPU_FAMILY_DGPU_CLIENT_GAME}}, + {gpu_arch::intel_gpu_pvc, {"intel_gpu_pvc", GPU_FAMILY_DGPU_CLOUD}}, + {gpu_arch::intel_gpu_pvc_vg, {"intel_gpu_pvc_vg", GPU_FAMILY_DGPU_CLOUD}}, + {gpu_arch::intel_gpu_mtl_u, {"intel_gpu_mtl_u", GPU_FAMILY_IGPU_XE}}, + {gpu_arch::intel_gpu_mtl_h, {"intel_gpu_mtl_h", GPU_FAMILY_IGPU_XE}}, + {gpu_arch::intel_gpu_arl_h, {"intel_gpu_arl_h", GPU_FAMILY_IGPU_XE}}, + {gpu_arch::intel_gpu_bmg_g21, {"intel_gpu_bmg_g21", GPU_FAMILY_DGPU_CLIENT_GAME}}, + {gpu_arch::intel_gpu_bmg_g31, {"intel_gpu_bmg_g31", GPU_FAMILY_DGPU_CLIENT_GAME}}, + {gpu_arch::intel_gpu_lnl_m, {"intel_gpu_lnl_m", GPU_FAMILY_IGPU_XE}}, + {gpu_arch::intel_gpu_ptl_h, {"intel_gpu_ptl_h", GPU_FAMILY_IGPU_XE}}, + {gpu_arch::intel_gpu_ptl_u, {"intel_gpu_ptl_u", GPU_FAMILY_IGPU_XE}}, + {gpu_arch::intel_gpu_wcl, {"intel_gpu_wcl", GPU_FAMILY_IGPU_XE}} +}; + + +sycl_hw_info get_device_hw_info(sycl::device* device_ptr) { + sycl_hw_info res; + int32_t id = + device_ptr->get_info(); + res.device_id = id; + + res.name = device_ptr->get_info(); - return res; + syclex::architecture arch = + device_ptr->get_info(); + res.arch = arch; + + map>::iterator it = + arch2name.find(res.arch); + if (it != arch2name.end()) { + res.arch_name = it->second.first; + res.gpu_family = it->second.second; + } else { + res.arch_name = "unknown"; + res.gpu_family = GPU_FAMILY_UKNOWN; + } + + return res; } -*/ diff --git a/ggml/src/ggml-sycl/sycl_hw.hpp b/ggml/src/ggml-sycl/sycl_hw.hpp index 36b140bf037..a5d20462572 100644 --- a/ggml/src/ggml-sycl/sycl_hw.hpp +++ b/ggml/src/ggml-sycl/sycl_hw.hpp @@ -9,18 +9,30 @@ #include namespace syclex = sycl::ext::oneapi::experimental; +using gpu_arch = sycl::ext::oneapi::experimental::architecture; + +// It's used to mark the GPU computing capacity +// The value must flow the order of performance. +enum sycl_intel_gpu_family { + GPU_FAMILY_UKNOWN = -1, + // iGPU without Xe core, before Meteor Lake iGPU(Xe) + GPU_FAMILY_IGPU_NON_XE = 0, + // iGPU with Xe core, Meteor Lake iGPU or newer. + GPU_FAMILY_IGPU_XE = 1, + // dGPU for gaming in client/data center (DG1/FLex 140 or newer). + GPU_FAMILY_DGPU_CLIENT_GAME = 2, + // dGPU for AI in cloud, PVC or newer. + GPU_FAMILY_DGPU_CLOUD = 3 +}; -// TODO: currently not used -/* struct sycl_hw_info { syclex::architecture arch; + const char* arch_name; int32_t device_id; + std::string name; + sycl_intel_gpu_family gpu_family; }; -bool is_in_vector(std::vector &vec, int item); - sycl_hw_info get_device_hw_info(sycl::device *device_ptr); -*/ - #endif // SYCL_HW_HPP diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp new file mode 100644 index 00000000000..9a6a1877566 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp @@ -0,0 +1,6 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.hpp" + +DECL_FATTN_TILE_CASE(512, 512); + diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp index 32cf4f2859b..43ef94c118c 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp index a61a19021bb..9404061d456 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp index 63b74fb347a..a8bb9f52d0c 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp index 46e2d9853c5..7d61f6ab0af 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp index 7aabb6ff6e4..753bae09f83 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp index 148ea217f62..546a93b2570 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp index 4b169dbcdbc..53c8c2f2654 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp index 79f530b1815..5b409c55f21 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp index 2f7db51ce82..8c4ef588d63 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp index 9e3bf0b14a1..83f0a07552e 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp index 18081879cec..9df9b03bba4 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp index 1c387b0d87c..6980c2a65bb 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp index f005b3762cc..bd61bc1dc2b 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp index 3553b1cdd16..492e229a58e 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp index 687ec567115..30f88a2ebd5 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp index 2663bfe7466..db76663604e 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp index 641b7c7ae2a..1dbcc8a85a8 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp index 3d3181d4719..d30996a6259 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp index 85d5026ad4f..bc0f635d922 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp index 1e81401a2c9..9e0378107cb 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp index 54251473f97..a8535ac9156 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp index d418c1fb21e..43d4fae9a61 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp index 0f26cfabd09..23335a41640 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp index 4fb98723519..52550a33757 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp index 85b79cd1976..4651f14c050 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp index 7348323b28b..2310fd8792c 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp index f19af2aa0ba..d2494048bc1 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp index d7075bac600..be3a1fe97f5 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp index 627f9a57755..be0a89409ca 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp index 23304eecd35..6781efcb0d2 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp index 95acb5d4fbf..43a70ae3543 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp index 5e88f4bab8a..fa7eb8163ca 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp index 69f297feb0c..79d9cfbee96 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp index 455842a9421..86befd5d327 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp index f7ef7391571..c2f619b0b16 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp index 1c633bdf2fa..7cf31f8b8a1 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-sycl/type.hpp b/ggml/src/ggml-sycl/type.hpp new file mode 100644 index 00000000000..d7ff89d7d42 --- /dev/null +++ b/ggml/src/ggml-sycl/type.hpp @@ -0,0 +1,112 @@ +#pragma once + +#include +#include +#include + +inline uint8_t float_to_e4m3(float f) +{ + if (sycl::isnan(f)) { + return 0x7F; // Canonical NaN (positive) + } + + uint32_t bits = sycl::bit_cast(f); + uint32_t sign = (bits >> 31) & 0x1u; + uint32_t exp = (bits >> 23) & 0xFFu; + uint32_t mant = bits & 0x7FFFFFu; + + // Zero + if (exp == 0 && mant == 0) { + return static_cast(sign << 7); + } + + // Extract biased exponent and mantissa for FP8 + int e = static_cast(exp) - 127; // true exponent (IEEE bias 127) + uint32_t m = mant; + + // Handle very large values → NaN (NVIDIA behavior for E4M3) + if (e > 7) { // max exponent for E4M3 is 7 (biased 14) + return static_cast((sign << 7) | 0x7F); + } + + // Handle subnormals and normal numbers + if (e < -6) { // smallest normal exponent is -6 + // Subnormal in FP8: shift mantissa right + int shift = -6 - e; + m = (m | 0x800000u) >> (shift + 1); // +1 because we lose the implicit 1 position + if (shift > 23) m = 0; + } else { + // Normal number: adjust exponent bias from 127 to 7 + int new_exp = e + 7; + m = (m >> 20) & 0x7u; // take top 3 mantissa bits (after implicit 1) + m |= (static_cast(new_exp) << 3); + } + + // Round-to-nearest-even (simple guard + round bit) + // For better accuracy you can add sticky bit, but this is sufficient for most use cases + uint32_t round_bit = (mant >> 19) & 0x1u; // bit after the 3 mantissa bits + if (round_bit) { + m += 1; + // Carry into exponent if mantissa overflows + if ((m & 0x8u) != 0) { + m = (m & 0x7u) | ((m & 0x38u) << 1); // simple carry handling + // If exponent overflows after carry → NaN + if ((m >> 3) > 14) { + return static_cast((sign << 7) | 0x7F); + } + } + } + + uint8_t result = static_cast((sign << 7) | (m & 0x7F)); + return result; +} + +inline float e4m3_to_float(uint8_t x) +{ + if (x == 0) return 0.0f; + + uint8_t sign = (x >> 7) & 0x1u; + uint8_t exp = (x >> 3) & 0xFu; + uint8_t mant = x & 0x7u; + + // NaN (NVIDIA uses 0x7F / 0xFF as NaN) + if (exp == 0xF && mant != 0) { + return std::numeric_limits::quiet_NaN(); + } + if (exp == 0xF) { // 0x7F or 0xFF treated as NaN + return std::numeric_limits::quiet_NaN(); + } + + float val; + + if (exp == 0) { + // Subnormal + val = mant * (1.0f / 8.0f) * sycl::pow(2.0f, -6.0f); + } else { + // Normal: implicit leading 1 + bias 7 + val = (1.0f + mant / 8.0f) * sycl::pow(2.0f, static_cast(exp) - 7.0f); + } + + return sign ? -val : val; +} + +// The actual type definition +struct __nv_fp8_e4m3 { + uint8_t raw; + + __nv_fp8_e4m3() = default; + + explicit __nv_fp8_e4m3(float f) : raw(float_to_e4m3(f)) {} + explicit __nv_fp8_e4m3(sycl::half h) : raw(float_to_e4m3(static_cast(h))) {} + + operator float() const { return e4m3_to_float(raw); } + operator sycl::half() const { return static_cast(static_cast(*this)); } + + // Allow direct access for vector loads/stores + operator uint8_t&() { return raw; } + operator uint8_t() const { return raw; } +}; + +using __nv_fp8x2_e4m3 = sycl::vec<__nv_fp8_e4m3, 2>; +using __nv_fp8x4_e4m3 = sycl::vec<__nv_fp8_e4m3, 4>; + diff --git a/ggml/src/ggml-sycl/upscale.cpp b/ggml/src/ggml-sycl/upscale.cpp new file mode 100644 index 00000000000..e42cb419d83 --- /dev/null +++ b/ggml/src/ggml-sycl/upscale.cpp @@ -0,0 +1,410 @@ +#include "upscale.hpp" + +static void upscale_f32(const float * x, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int ne13, + const float sf0, const float sf1, const float sf2, const float sf3) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + int index = item_ct1.get_local_id(2) + item_ct1.get_group(2) * item_ct1.get_local_range(2); + if (index >= ne10 * ne11 * ne12 * ne13) { + return; + } + + int i10 = index % ne10; + int i11 = (index / ne10) % ne11; + int i12 = (index / (ne10 * ne11)) % ne12; + int i13 = (index / (ne10 * ne11 * ne12)) % ne13; + + int i00 = i10 / sf0; + int i01 = i11 / sf1; + int i02 = i12 / sf2; + int i03 = i13 / sf3; + + dst[index] = *((const float*)((const char*)x + i03 * nb03 + i02 * nb02 + + i01 * nb01 + i00 * nb00)); +} + +static void upscale_f32_bilinear(const float * x, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne00_src, const int ne01_src, + const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst, + const float sf0, const float sf1, const float sf2, const float sf3, + const float pixel_offset) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int64_t index = item_ct1.get_local_id(2) + + item_ct1.get_group(2) * item_ct1.get_local_range(2); + const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + + if (index >= dst_total_elements) { + return; + } + + const int i10_dst = index % ne10_dst; + const int i11_dst = (index / ne10_dst) % ne11_dst; + const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst; + const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst); + + const int i02_src = (int)(i12_dst / sf2); + const int i03_src = (int)(i13_dst / sf3); + + const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset; + int y0_src = (int) sycl::floor((float) y_src_f); + int y1_src = y0_src + 1; + + y0_src = sycl::max(0, sycl::min(y0_src, ne01_src - 1)); + y1_src = sycl::max(0, sycl::min(y1_src, ne01_src - 1)); + + float dy = y_src_f - (float)y0_src; + dy = sycl::max(0.0f, sycl::min(dy, 1.0f)); + + float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset; + int x0_src = (int) sycl::floor(x_src_f); + int x1_src = x0_src + 1; + + x0_src = sycl::max(0, sycl::min(x0_src, ne00_src - 1)); + x1_src = sycl::max(0, sycl::min(x1_src, ne00_src - 1)); + + float dx = x_src_f - (float)x0_src; + dx = sycl::max(0.0f, sycl::min(dx, 1.0f)); + + const float* p_a = + (const float*)((const char*)x + (int64_t)x0_src * nb00 + + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + + (int64_t)i03_src * nb03); + const float* p_b = + (const float*)((const char*)x + (int64_t)x1_src * nb00 + + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + + (int64_t)i03_src * nb03); + const float* p_c = + (const float*)((const char*)x + (int64_t)x0_src * nb00 + + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + + (int64_t)i03_src * nb03); + const float* p_d = + (const float*)((const char*)x + (int64_t)x1_src * nb00 + + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + + (int64_t)i03_src * nb03); + + const float val_a = *p_a; + const float val_b = *p_b; + const float val_c = *p_c; + const float val_d = *p_d; + + float result = val_a * (1.0f - dx) * (1.0f - dy) + + val_b * dx * (1.0f - dy) + + val_c * (1.0f - dx) * dy + + val_d * dx * dy; + + dst[index] = result; +} + +// Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True) +// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp +static void upscale_f32_bilinear_antialias(const float * src0, + float * dst, + const int nb00, + const int nb01, + const int nb02, + const int nb03, + const int ne00_src, + const int ne01_src, + const int ne10_dst, + const int ne11_dst, + const int ne12_dst, + const int ne13_dst, + const float sf0, + const float sf1, + const float sf2, + const float sf3, + const float pixel_offset) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int64_t index = item_ct1.get_local_id(2) + + item_ct1.get_group(2) * item_ct1.get_local_range(2); + const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + + if (index >= dst_total_elements) { + return; + } + + const int i10_dst = index % ne10_dst; + const int i11_dst = (index / ne10_dst) % ne11_dst; + const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst; + const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst); + + const int i02_src = (int)(i12_dst / sf2); + const int i03_src = (int)(i13_dst / sf3); + + const float y = ((float)i11_dst + pixel_offset) / sf1; + const float x = ((float)i10_dst + pixel_offset) / sf0; + + // support and invscale, minimum 1 pixel for bilinear + const float support1 = sycl::max(1.0f / sf1, 1.0f); + const float invscale1 = 1.0f / support1; + const float support0 = sycl::max(1.0f / sf0, 1.0f); + const float invscale0 = 1.0f / support0; + + // the range of source pixels that contribute + const int64_t x_min = sycl::max(int64_t(0), int64_t(x - support0 + pixel_offset)); + const int64_t x_max = sycl::min(int64_t(ne00_src), int64_t(x + support0 + pixel_offset)); + const int64_t y_min = sycl::max(int64_t(0), int64_t(y - support1 + pixel_offset)); + const int64_t y_max = sycl::min(int64_t(ne01_src), int64_t(y + support1 + pixel_offset)); + + // bilinear filter with antialiasing + float val = 0.0f; + float total_weight = 0.0f; + + auto triangle_filter = [](float x) -> float { + return sycl::max(1.0f - sycl::fabs(x), 0.0f); + }; + + for (int64_t sy = y_min; sy < y_max; sy++) { + const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1); + + for (int64_t sx = x_min; sx < x_max; sx++) { + const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0); + const float weight = weight_x * weight_y; + + if (weight <= 0.0f) { + continue; + } + + const float pixel = + *(const float*)((const char*)src0 + sx * nb00 + sy * nb01 + + i02_src * nb02 + i03_src * nb03); + val += pixel * weight; + total_weight += weight; + } + } + + if (total_weight > 0.0f) { + val /= total_weight; + } + + dst[index] = val; +} + +namespace bicubic_interpolation { +static float weight1(float x, const float &a) { return ((a + 2) * x - (a + 3)) * x * x + 1; }; +static float weight2(float x, const float &a) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; }; + +static float bicubic(float p0, float p1, float p2, float p3, float x, float a) { + const float w0 = weight2(x + 1, a); + const float w1 = weight1(x + 0, a); + const float w2 = weight1(1 - x, a); + const float w3 = weight2(2 - x, a); + return p0 * w0 + p1 * w1 + p2 * w2 + p3 * w3; +}; + +} + +static void upscale_f32_bicubic(const float * x, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne00_src, const int ne01_src, + const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst, + const float sf0, const float sf1, const float sf2, const float sf3, + const float pixel_offset) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const float a = -0.75f; + using bicubic_interpolation::bicubic; + + const int64_t index = item_ct1.get_local_id(2) + + item_ct1.get_group(2) * item_ct1.get_local_range(2); + const int64_t dst_total_elements = + ne10_dst * ne11_dst * ne12_dst * ne13_dst; + + if (index >= dst_total_elements) { + return; + } + + const int i10_dst = index % ne10_dst; + const int i11_dst = (index / ne10_dst) % ne11_dst; + const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst; + const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst); + + const int i02_src = (int)(i12_dst / sf2); + const int i03_src = (int)(i13_dst / sf3); + + const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset; + const int y0_src = (int) sycl::floor((float) y_src_f); + const float dy = y_src_f - (float)y0_src; + + const float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset; + const int x0_src = (int) sycl::floor((float) x_src_f); + const float dx = x_src_f - (float)x0_src; + + const char * x_base = (const char *)x + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03; + + auto load = [=](int x_off, int y_off) -> float { + int i00_src = sycl::max(0, sycl::min(x0_src + x_off, ne00_src - 1)); + int i01_src = sycl::max(0, sycl::min(y0_src + y_off, ne01_src - 1)); + return *(const float *)(x_base + (int64_t)i00_src * nb00 + (int64_t)i01_src * nb01); + }; + + const float result = bicubic( + bicubic(load(-1, -1), load(0, -1), load(1, -1), load(2, -1), dx, a), + bicubic(load(-1, 0), load(0, 0), load(1, 0), load(2, 0), dx, a), + bicubic(load(-1, 1), load(0, 1), load(1, 1), load(2, 1), dx, a), + bicubic(load(-1, 2), load(0, 2), load(1, 2), load(2, 2), dx, a), + dy, + a); + + dst[index] = result; +} + +static void upscale_f32_sycl(const float * x, + float * dst, + const int nb00, + const int nb01, + const int nb02, + const int nb03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const float sf0, + const float sf1, + const float sf2, + const float sf3, + dpct::queue_ptr stream) { + const int64_t dst_size = ne10 * ne11 * ne12 * ne13; + const int64_t num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE; + + stream->parallel_for( + sycl::nd_range<3>( + sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)), + [=](sycl::nd_item<3> /*item_ct1*/) { + upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3); + }); +} + +static void upscale_f32_bilinear_sycl(const float * x, + float * dst, + const int nb00, + const int nb01, + const int nb02, + const int nb03, + const int ne00_src, + const int ne01_src, + const int ne10_dst, + const int ne11_dst, + const int ne12_dst, + const int ne13_dst, + const float sf0, + const float sf1, + const float sf2, + const float sf3, + const float pixel_offset, + bool antialias, + dpct::queue_ptr stream) { + const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + const int64_t num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE; + + if (antialias) { + stream->parallel_for( + sycl::nd_range<3>( + sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)), + [=](sycl::nd_item<3> /*item_ct1*/) { + upscale_f32_bilinear_antialias( + x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, + ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); + }); + } else { + stream->parallel_for( + sycl::nd_range<3>( + sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)), + [=](sycl::nd_item<3> /*item_ct1*/) { + upscale_f32_bilinear( + x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, + ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); + }); + } +} + +static void upscale_f32_bicubic_sycl(const float * x, + float * dst, + const int nb00, + const int nb01, + const int nb02, + const int nb03, + const int ne00_src, + const int ne01_src, + const int ne10_dst, + const int ne11_dst, + const int ne12_dst, + const int ne13_dst, + const float sf0, + const float sf1, + const float sf2, + const float sf3, + const float pixel_offset, + dpct::queue_ptr stream) { + const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + const int64_t num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE; + + { + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>( + sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)), + [=](sycl::nd_item<3> /*item_ct1*/) { + upscale_f32_bicubic( + x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, + ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); + }); + }); + } +} + +void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + dpct::queue_ptr stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int mode_flags = dst->op_params[0]; + const ggml_scale_mode mode = (ggml_scale_mode)(mode_flags & 0xFF); + + float sf0 = (float)dst->ne[0]/src0->ne[0]; + float sf1 = (float)dst->ne[1]/src0->ne[1]; + float sf2 = (float)dst->ne[2]/src0->ne[2]; + const float sf3 = (float)dst->ne[3]/src0->ne[3]; + + float pixel_offset = 0.5f; + if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) { + sf0 = dst->ne[0] > 1 && src0->ne[0] > 1 + ? (float)(dst->ne[0] - 1) / (src0->ne[0] - 1) + : sf0; + sf1 = dst->ne[1] > 1 && src0->ne[1] > 1 + ? (float)(dst->ne[1] - 1) / (src0->ne[1] - 1) + : sf1; + pixel_offset = 0.0f; + } + + if (mode == GGML_SCALE_MODE_NEAREST) { + upscale_f32_sycl( + src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream); + } else if (mode == GGML_SCALE_MODE_BILINEAR) { + const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS); + upscale_f32_bilinear_sycl( + src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + sf0, sf1, sf2, sf3, pixel_offset, antialias, stream); + } else if (mode == GGML_SCALE_MODE_BICUBIC) { + upscale_f32_bicubic_sycl( + src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + sf0, sf1, sf2, sf3, pixel_offset, stream); + } +} + +void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_upscale(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/upscale.hpp b/ggml/src/ggml-sycl/upscale.hpp new file mode 100644 index 00000000000..c36c1bdc970 --- /dev/null +++ b/ggml/src/ggml-sycl/upscale.hpp @@ -0,0 +1,9 @@ +#pragma once + +#include +#include "dpct/helper.hpp" +#include "common.hpp" + +#define SYCL_UPSCALE_BLOCK_SIZE 256 + +void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index 9a267d85a0c..9253168e5ea 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -15,6 +15,7 @@ #include "dpct/helper.hpp" #include "ggml.h" +#include "type.hpp" #include "quants.hpp" typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, @@ -31,6 +32,18 @@ static __dpct_inline__ int get_int_b1(const void * x, const int & i32) { return x32; } +static __dpct_inline__ int get_int_b2(const void * x, const int & i32) { + const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment + + int x32 = x16[2*i32 + 0] << 0; + x32 |= x16[2*i32 + 1] << 16; + + return x32; +} + +static __dpct_inline__ int get_int_b4(const void * x, const int & i32) { + return ((const int *) x)[i32]; // assume at least 4 byte alignment +} static __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) { const uint16_t* x16 = @@ -338,6 +351,46 @@ template <> struct reorder_vec_dot_q_sycl { }; }; +template <> struct reorder_vec_dot_q_sycl { + static constexpr ggml_type gtype = GGML_TYPE_Q8_0; + + using q8_0_block = ggml_sycl_reordered::block_q_t; + using q8_0_traits = typename q8_0_block::traits; + + __dpct_inline__ float vec_dot_q8_0_q8_1_impl(const int * v, const int * u, const float & d8_0, const sycl::half2 & ds8) { + int sumi = 0; + +#pragma unroll + for (size_t i = 0; i < q8_0_traits::vdr_mmvq; ++i) { + // Q8_0 values are signed int8, no nibble extraction needed + // Direct dp4a: each int packs 4 int8 values + sumi = dpct::dp4a(v[i], u[i], sumi); + } + + const sycl::float2 ds8f = ds8.convert(); + + // Q8_0 has no bias term (values are signed), so just scale + return d8_0 * sumi * ds8f.x(); + } + + __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair ibx_offset, + const std::pair d_offset, const int8_t * q8_1_quant_ptr, + const sycl::half2 * q8_1_ds, const int & iqs) { + const int8_t * bq8_0 = static_cast(vbq) + ibx_offset.first; + const ggml_half d = *(reinterpret_cast(static_cast(vbq) + d_offset.first)); + int v[q8_0_traits::vdr_mmvq]; + int u[q8_0_traits::vdr_mmvq]; + +#pragma unroll + for (size_t i = 0; i < q8_0_traits::vdr_mmvq; ++i) { + v[i] = get_int_from_int8(bq8_0, iqs + i); + u[i] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i); + } + + return vec_dot_q8_0_q8_1_impl(v, u, d, *q8_1_ds); + }; +}; + static inline float vec_dot_q4_K_q8_1_common(const int * __restrict__ q4, const uint16_t * __restrict__ scales, const ggml_half2 & dm, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { @@ -755,6 +808,35 @@ static __dpct_inline__ float vec_dot_mxfp4_q8_1(const void * __restrict__ vbq, return d * sumi; } +#define VDR_NVFP4_Q8_1_MMVQ 4 +#define VDR_NVFP4_Q8_1_MMQ 8 + +static __dpct_inline__ float vec_dot_nvfp4_q8_1(const void * __restrict__ vbq, + const block_q8_1 * __restrict__ bq8_1, + const int32_t & iqs) { + const block_nvfp4 * bq4 = (const block_nvfp4 *) vbq; + float sum = 0.0f; +#pragma unroll + for (int i = 0; i < VDR_NVFP4_Q8_1_MMVQ/2; i++) { + const int32_t iqs0 = iqs + 2*i; + const int32_t iqs1 = iqs0 + 1; + const int32_t is = iqs0 >> 1; + const sycl::int2 v0 = get_int_from_table_16(get_int_b4(bq4->qs, iqs0), kvalues_mxfp4); + const sycl::int2 v1 = get_int_from_table_16(get_int_b4(bq4->qs, iqs1), kvalues_mxfp4); + const block_q8_1 * bq8 = bq8_1 + (is >> 1); + const int32_t i8 = ((is & 1) << 2); + + int sumi = ggml_sycl_dp4a(v0.x(), get_int_b4(bq8->qs, i8 + 0), 0); + sumi = ggml_sycl_dp4a(v0.y(), get_int_b4(bq8->qs, i8 + 2), sumi); + sumi = ggml_sycl_dp4a(v1.x(), get_int_b4(bq8->qs, i8 + 1), sumi); + sumi = ggml_sycl_dp4a(v1.y(), get_int_b4(bq8->qs, i8 + 3), sumi); + + const float d = ggml_sycl_ue4m3_to_fp32(bq4->d[is]) * (bq8->ds)[0]; + sum += d * float(sumi); + } + + return sum; +} static __dpct_inline__ float vec_dot_q5_0_q8_1(const void *__restrict__ vbq, diff --git a/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp b/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp index 6b95362dd80..b6c561cd61e 100644 --- a/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +++ b/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp @@ -101,6 +101,8 @@ const ggml_backend_buffer_i ggml_backend_remoting_buffer_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_remoting_buffer_set_tensor, /* .get_tensor = */ ggml_backend_remoting_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_remoting_buffer_cpy_tensor, /* .clear = */ ggml_backend_remoting_buffer_clear, /* .reset = */ NULL, @@ -113,6 +115,8 @@ const ggml_backend_buffer_i ggml_backend_remoting_buffer_from_ptr_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_remoting_buffer_set_tensor_from_ptr, /* .get_tensor = */ ggml_backend_remoting_buffer_get_tensor_from_ptr, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_remoting_buffer_cpy_tensor, /* .clear = */ ggml_backend_remoting_buffer_clear, /* .reset = */ NULL, diff --git a/ggml/src/ggml-virtgpu/ggml-backend.cpp b/ggml/src/ggml-virtgpu/ggml-backend.cpp index a63ee2b9d2f..12756c9282f 100644 --- a/ggml/src/ggml-virtgpu/ggml-backend.cpp +++ b/ggml/src/ggml-virtgpu/ggml-backend.cpp @@ -34,6 +34,8 @@ static ggml_backend_i ggml_backend_remoting_interface = { /* .free = */ ggml_backend_remoting_free, /* .set_tensor_async = */ NULL, // ggml_backend_remoting_set_tensor_async, /* .get_tensor_async = */ NULL, // ggml_backend_remoting_get_tensor_async, + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, // ggml_backend_remoting_cpy_tensor_async, /* .synchronize = */ NULL, // ggml_backend_remoting_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 7092361d2ea..c2f1883328f 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -21,6 +21,20 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher(); #include +// SPIR-V Headers: different SDK installations expose different include paths. +// LunarG Vulkan SDK on Windows typically provides . +// Linux packages, MSYS2 and MinGW often use the Khronos layout . +#if __has_include() +# include +#elif __has_include() +# include +#elif __has_include() +# include +#else + // Fallback to let the compiler throw a standard "file not found" error +# include +#endif + #include #include #include @@ -191,6 +205,7 @@ struct vk_queue; struct vk_command_buffer { vk::CommandBuffer buf; + uint64_t use_counter = 0; bool in_use = false; }; @@ -425,10 +440,12 @@ struct vk_fa_pipeline_state { bool f32acc; uint32_t flags; uint32_t limit_occupancy_shmem; + ggml_type k_type; + ggml_type v_type; bool operator<(const vk_fa_pipeline_state &b) const { - return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem) < - std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem); + return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem, k_type, v_type) < + std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem, b.k_type, b.v_type); } }; @@ -784,6 +801,7 @@ struct vk_device_struct { vk_pipeline pipeline_arange_f32; vk_pipeline pipeline_fill_f32; + vk_pipeline pipeline_fill_f16; vk_pipeline pipeline_geglu[2]; vk_pipeline pipeline_reglu[2]; @@ -938,19 +956,24 @@ struct vk_subbuffer { } }; -// vk_event is used for the event-related backend interfaces. It uses 'event' for -// event_wait and 'fence' for event_synchronize. Polling on an event for +struct vk_semaphore { + vk::Semaphore s; + uint64_t value; +}; + +// vk_event is used for the event-related backend interfaces. It uses vk::Events for +// event_wait and a timeline semaphore for event_synchronize. Polling on an event for // event_synchronize wouldn't be sufficient to wait for command buffers to complete, // and would lead to validation errors. struct vk_event { + std::vector events_free; // Events available for reuse + std::vector events_submitted; // Events that are fully submitted and can be reused on next synchronize vk::Event event; - vk::Fence fence; - vk_command_buffer* cmd_buffer = nullptr; -}; + bool has_event; -struct vk_semaphore { - vk::Semaphore s; - uint64_t value; + vk_semaphore tl_semaphore; + vk_command_buffer* cmd_buffer = nullptr; + uint64_t cmd_buffer_use_counter = 0; }; struct vk_submission { @@ -1106,6 +1129,16 @@ struct vk_op_glu_push_constants { uint32_t mode; // 0: default, 1: swapped, 2: split float alpha; // for swiglu_oai float limit; + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + uint32_t ne01; + uint32_t ne02; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t ne11; + uint32_t ne12; }; struct vk_op_unary_push_constants { @@ -1371,7 +1404,7 @@ struct vk_op_im2col_push_constants { uint32_t IW; uint32_t IH; uint32_t OW; uint32_t OH; uint32_t KW; uint32_t KH; - uint32_t pelements; + uint32_t OH_batch; uint32_t CHW; int32_t s0; int32_t s1; int32_t p0; int32_t p1; @@ -2115,6 +2148,66 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast(spv_data)); + + // Patch SPIR-V to enable RTE rounding for FP16, avoiding the need for + // separate shader variants compiled with -DRTE16. + std::vector spv; + if (device->float_controls_rte_fp16) { + const uint32_t* spv_words = reinterpret_cast(spv_data); + size_t word_count = spv_size / sizeof(uint32_t); + spv.assign(spv_words, spv_words + word_count); + + // Find insertion points respecting SPIR-V layout order: + // Header(5) -> OpCapability -> OpExtension -> ... -> OpEntryPoint -> OpExecutionMode -> ... + size_t pos = 5; // skip header + size_t cap_insert_pos = pos; + size_t ext_insert_pos = pos; + size_t exec_insert_pos = pos; + uint32_t entry_point_id = 0; + + while (pos < spv.size()) { + uint32_t opcode = spv[pos] & spv::OpCodeMask; + uint32_t len = spv[pos] >> spv::WordCountShift; + if (len == 0) break; + + if (opcode == spv::OpCapability) { + cap_insert_pos = pos + len; + ext_insert_pos = pos + len; + } else if (opcode == spv::OpExtension) { + ext_insert_pos = pos + len; + } else if (opcode == spv::OpEntryPoint) { + entry_point_id = spv[pos + 2]; + exec_insert_pos = pos + len; + } else if (opcode == spv::OpExecutionMode || opcode == spv::OpExecutionModeId) { + exec_insert_pos = pos + len; + } else if (entry_point_id != 0) { + break; + } + + pos += len; + } + + // Insert from latest position first so earlier indices stay valid. + + // OpExecutionMode %entrypoint RoundingModeRTE 16 + uint32_t exec_mode[] = { (4u << spv::WordCountShift) | spv::OpExecutionMode, entry_point_id, spv::ExecutionModeRoundingModeRTE, 16 }; + spv.insert(spv.begin() + exec_insert_pos, std::begin(exec_mode), std::end(exec_mode)); + + // OpExtension "SPV_KHR_float_controls" + const char ext_str[] = "SPV_KHR_float_controls"; + size_t ext_str_words = CEIL_DIV(sizeof(ext_str), sizeof(uint32_t)); + std::vector extension(1 + ext_str_words, 0); + extension[0] = (uint32_t)((1 + ext_str_words) << spv::WordCountShift) | spv::OpExtension; + memcpy(&extension[1], ext_str, sizeof(ext_str)); + spv.insert(spv.begin() + ext_insert_pos, extension.begin(), extension.end()); + + // OpCapability RoundingModeRTE + uint32_t capability[] = { (2u << spv::WordCountShift) | spv::OpCapability, spv::CapabilityRoundingModeRTE }; + spv.insert(spv.begin() + cap_insert_pos, std::begin(capability), std::end(capability)); + + shader_module_create_info = vk::ShaderModuleCreateInfo({}, spv.size() * sizeof(uint32_t), spv.data()); + } + pipeline->shader_module = device->device.createShaderModule(shader_module_create_info); vk::PushConstantRange pcr( @@ -2319,7 +2412,7 @@ static vk_command_buffer* ggml_vk_create_cmd_buffer(vk_device& device, vk_comman vk::CommandBufferLevel::ePrimary, 1); const std::vector cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info); - p.cmd_buffers.push_back({ cmd_buffers.front(), true }); + p.cmd_buffers.push_back({ cmd_buffers.front(), 0, true }); return &p.cmd_buffers[p.cmd_buffers.size()-1]; } @@ -2788,6 +2881,15 @@ static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subct ); } +static void ggml_vk_reset_event(vk_context& ctx, vk::Event& event) { + VK_LOG_DEBUG("ggml_vk_set_event()"); + + ctx->s->buffer->buf.resetEvent( + event, + ctx->p->q->stage_flags + ); +} + static void ggml_vk_set_event(vk_context& ctx, vk::Event& event) { VK_LOG_DEBUG("ggml_vk_set_event()"); @@ -2833,11 +2935,10 @@ struct vk_fa_tuning_params { } }; -static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc); +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type); static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc); static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { - GGML_UNUSED(kv_type); vk_fa_tuning_params result{}; result.path = FA_SCALAR; @@ -2889,7 +2990,7 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0; - if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc)) { + if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, kv_type)) { result.block_rows /= 2; } @@ -2942,7 +3043,7 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device return result; } -static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { +static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) { GGML_UNUSED(n_kv); GGML_UNUSED(f32acc); @@ -2956,7 +3057,7 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device if (small_rows) { result.block_rows = 32; result.block_cols = 32; - } else if (ggml_is_quantized(kv_type) || hsk >= 256 || hsv >= 256) { + } else if (ggml_is_quantized(k_type) || ggml_is_quantized(v_type) || hsk >= 256 || hsv >= 256) { result.block_rows = (hsk >= 512 || hsv >= 512) ? 32 : 64; result.block_cols = 32; } else { @@ -2970,7 +3071,13 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device return result; } -static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { +static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) { + // Mixed K/V is only implemented on the coopmat2 (flash_attn_cm2) path; never use scalar/cm1. + if (k_type != v_type) { + GGML_ASSERT(device->coopmat2); + return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); + } + FaCodePath path = device->coopmat2 ? FA_COOPMAT2 : device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; @@ -2982,7 +3089,7 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_ if (path == FA_COOPMAT1) { bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) || (!f32acc && device->coopmat_support_16x16x16_f16acc); - const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc); + const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc); bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc); if (!shape_ok || !shmem_ok) { @@ -2995,20 +3102,25 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_ path = FA_SCALAR; } + // Q1_0 K/V is only implemented on coopmat2 (flash_attn_cm2); there is no scalar FA shader for it. + if ((k_type == GGML_TYPE_Q1_0 || v_type == GGML_TYPE_Q1_0) && device->coopmat2) { + path = FA_COOPMAT2; + } + switch (path) { case FA_SCALAR: - return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc); + return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, k_type, f32acc); case FA_COOPMAT1: - return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc); + return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc); case FA_COOPMAT2: - return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc); + return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); default: throw std::runtime_error("unsupported FaCodePath"); } } static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc, - bool use_mask, bool use_mask_opt, bool use_logit_softcap) { + bool use_mask, bool use_mask_opt, bool use_logit_softcap, ggml_type k_type, ggml_type v_type) { const bool old_amd_windows = device->vendor_id == VK_VENDOR_ID_AMD && device->driver_id == vk::DriverId::eAmdProprietary && (device->architecture == AMD_GCN || device->architecture == AMD_RDNA1 || device->architecture == AMD_RDNA2); @@ -3019,12 +3131,32 @@ static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const const uint32_t subgroup_size = params.disable_subgroups ? 0 : params.subgroup_size; - return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem}; + return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem, k_type, v_type}; } static std::vector get_fa_spec_constants(const vk_fa_pipeline_state& state) { - return {state.workgroup_size, state.Br, state.Bc, state.HSK, state.HSV, !state.aligned, state.D_split, - state.row_split, state.subgroup_size, state.shmem_staging ? 1u : 0u, state.flags, state.limit_occupancy_shmem}; + const auto fa_block_bytes = [](ggml_type t) -> uint32_t { + // decodeBufF32 uses a block of vec4s for a better memory access pattern. + return t == GGML_TYPE_F32 ? 16u : (uint32_t) ggml_type_size(t); + }; + return { + /* 0 WorkGroupSize */ state.workgroup_size, + /* 1 Br */ state.Br, + /* 2 Bc */ state.Bc, + /* 3 HSK */ state.HSK, + /* 4 HSV */ state.HSV, + /* 5 Clamp */ static_cast(!state.aligned), + /* 6 D_split */ state.D_split, + /* 7 row_split */ state.row_split, + /* 8 SubGroupSize */ state.subgroup_size, + /* 9 SHMEM_STAGING */ state.shmem_staging ? 1u : 0u, + /*10 Flags */ state.flags, + /*11 LIMIT_OCCUPANCY_SHMEM */ state.limit_occupancy_shmem, + /*12 FaTypeK */ static_cast(state.k_type), + /*13 FaTypeV */ static_cast(state.v_type), + /*14 FaBlockBytesK */ fa_block_bytes(state.k_type), + /*15 FaBlockBytesV */ fa_block_bytes(state.v_type), + }; } static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { @@ -3055,6 +3187,10 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec case GGML_TYPE_MXFP4: lut_size = 4*16; break; + case GGML_TYPE_NVFP4: + // Same kvalues budget as MXFP4 plus ue4m3_fp32_lut[128] (types.glsl, DATA_A_NVFP4). + lut_size = 4*16 + 128u * (uint32_t)sizeof(float); + break; default: break; } @@ -3420,13 +3556,47 @@ static void ggml_vk_load_shaders(vk_device& device) { if (device->fp16) { CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, ) CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product && device->subgroup_clustered) { + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _int8) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _int8) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _int8) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _int8) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _int8) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _int8) + } else +#endif + { + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, ) + } } else { CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32) CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32) + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product && device->subgroup_clustered) { + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32_int8) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32_int8) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32_int8) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32_int8) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32_int8) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32_int8) + } else +#endif + { + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32) + } } #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (device->coopmat1_fa_support) { @@ -3434,19 +3604,42 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1) CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT1, _cm1) } #endif #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) +#define CREATE_FA_CM2_MIXED() \ + for (int fa_k_ty = 0; fa_k_ty < (int)GGML_TYPE_COUNT; ++fa_k_ty) { \ + for (auto &fa : device->pipeline_flash_attn_f32_f16[fa_k_ty]) { \ + FaCodePath path = fa.first.path; \ + uint32_t Br = fa.first.Br; \ + uint32_t Bc = fa.first.Bc; \ + bool aligned = fa.first.aligned; \ + bool f32acc = fa.first.f32acc; \ + if (path == FA_COOPMAT2) { \ + if (aligned) { \ + if (f32acc) { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \ + } else { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \ + } \ + } else { \ + if (f32acc) { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \ + } else { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \ + } \ + } \ + } \ + } \ + } if (device->coopmat2) { - CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2) + CREATE_FA_CM2_MIXED(); } +#undef CREATE_FA_CM2_MIXED #endif #undef CREATE_FA @@ -3475,6 +3668,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) } #endif + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q1_0], matmul_q1_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0], matmul_q4_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1], matmul_q4_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0], matmul_q5_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) @@ -3495,6 +3689,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_NVFP4], matmul_nvfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) GGML_ASSERT(device->subgroup_ballot); @@ -3504,6 +3699,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 5) } #endif + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_subgroup_q1_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) @@ -3524,6 +3720,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) #undef CREATE_MM #undef CREATE_MM2 } else @@ -3565,6 +3762,7 @@ static void ggml_vk_load_shaders(vk_device& device) { #endif if (device->coopmat_acc_f16_support) { + CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0], matmul_q1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -3586,7 +3784,9 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4], matmul_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } else { + CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0].f32acc, matmul_q1_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -3608,6 +3808,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4].f32acc, matmul_nvfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } GGML_ASSERT(device->subgroup_ballot); @@ -3621,6 +3822,7 @@ static void ggml_vk_load_shaders(vk_device& device) { } #endif + CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_subgroup_q1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); @@ -3641,6 +3843,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); #undef CREATE_MM2 #undef CREATE_MM } else @@ -3684,6 +3887,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0], matmul_q1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); @@ -3705,6 +3909,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4], matmul_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -3730,6 +3935,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_subgroup_q1_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); @@ -3750,6 +3956,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -3774,6 +3981,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_q1_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); @@ -3794,6 +4002,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_nvfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -3847,6 +4056,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0].f32acc, matmul_q1_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); @@ -3868,6 +4078,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4].f32acc, matmul_nvfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -3891,6 +4102,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0].f32acc, matmul_id_subgroup_q1_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); @@ -3911,12 +4123,14 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4].f32acc, matmul_id_subgroup_nvfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); } else { CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0].f32acc, matmul_id_q1_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); @@ -3937,6 +4151,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4].f32acc, matmul_id_nvfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); } } // reusing CREATE_MM from the fp32 path @@ -4014,6 +4229,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[reduc], arr_dmmv_f32_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32", arr_dmmv_f16_f32_f32_len[reduc], arr_dmmv_f16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32", arr_dmmv_bf16_f32_f32_len[reduc], arr_dmmv_bf16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q1_0][i], "mul_mat_vec_q1_0_f32_f32", arr_dmmv_q1_0_f32_f32_len[reduc], arr_dmmv_q1_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32", arr_dmmv_q4_0_f32_f32_len[reduc], arr_dmmv_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32", arr_dmmv_q4_1_f32_f32_len[reduc], arr_dmmv_q4_1_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32", arr_dmmv_q5_0_f32_f32_len[reduc], arr_dmmv_q5_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); @@ -4034,10 +4250,12 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32", arr_dmmv_iq4_xs_f32_f32_len[reduc16], arr_dmmv_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[reduc16], arr_dmmv_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[reduc16], arr_dmmv_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_NVFP4][i], "mul_mat_vec_nvfp4_f32_f32", arr_dmmv_nvfp4_f32_f32_len[reduc16], arr_dmmv_nvfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32", arr_dmmv_bf16_f16_f32_len[reduc], arr_dmmv_bf16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q1_0][i], "mul_mat_vec_q1_0_f16_f32", arr_dmmv_q1_0_f16_f32_len[reduc], arr_dmmv_q1_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32", arr_dmmv_q4_0_f16_f32_len[reduc], arr_dmmv_q4_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32", arr_dmmv_q4_1_f16_f32_len[reduc], arr_dmmv_q4_1_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32", arr_dmmv_q5_0_f16_f32_len[reduc], arr_dmmv_q5_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); @@ -4058,6 +4276,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32", arr_dmmv_iq4_xs_f16_f32_len[reduc16], arr_dmmv_iq4_xs_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32", arr_dmmv_iq4_nl_f16_f32_len[reduc16], arr_dmmv_iq4_nl_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32", arr_dmmv_mxfp4_f16_f32_len[reduc16], arr_dmmv_mxfp4_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_NVFP4][i], "mul_mat_vec_nvfp4_f16_f32", arr_dmmv_nvfp4_f16_f32_len[reduc16], arr_dmmv_nvfp4_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -4088,6 +4307,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", arr_dmmv_id_f32_f32_f32_len[reduc], arr_dmmv_id_f32_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {wg_size_subgroup, 1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", arr_dmmv_id_f16_f32_f32_len[reduc], arr_dmmv_id_f16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", arr_dmmv_id_bf16_f32_f32_len[reduc], arr_dmmv_id_bf16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q1_0], "mul_mat_vec_id_q1_0_f32", arr_dmmv_id_q1_0_f32_f32_len[reduc], arr_dmmv_id_q1_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", arr_dmmv_id_q4_0_f32_f32_len[reduc], arr_dmmv_id_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", arr_dmmv_id_q4_1_f32_f32_len[reduc], arr_dmmv_id_q4_1_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", arr_dmmv_id_q5_0_f32_f32_len[reduc], arr_dmmv_id_q5_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); @@ -4108,6 +4328,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", arr_dmmv_id_iq4_xs_f32_f32_len[reduc16], arr_dmmv_id_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", arr_dmmv_id_iq4_nl_f32_f32_len[reduc16], arr_dmmv_id_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", arr_dmmv_id_mxfp4_f32_f32_len[reduc16], arr_dmmv_id_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_NVFP4], "mul_mat_vec_id_nvfp4_f32", arr_dmmv_id_nvfp4_f32_f32_len[reduc16], arr_dmmv_id_nvfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -4142,6 +4363,7 @@ static void ggml_vk_load_shaders(vk_device& device) { // dequant shaders ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q1_0], "dequant_q1_0", dequant_q1_0_len, dequant_q1_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 8, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); @@ -4162,11 +4384,13 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_NVFP4], "dequant_nvfp4", dequant_nvfp4_len, dequant_nvfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); // get_rows ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_BF16], "get_rows_bf16", get_rows_bf16_len, get_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q1_0], "get_rows_q1_0", get_rows_q1_0_len, get_rows_q1_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); @@ -4187,11 +4411,13 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_NVFP4], "get_rows_nvfp4", get_rows_nvfp4_len, get_rows_nvfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_I32], "get_rows_i32", get_rows_i32_len, get_rows_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_BF16], "get_rows_bf16_f32", get_rows_bf16_f32_len, get_rows_bf16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q1_0], "get_rows_q1_0_f32", get_rows_q1_0_f32_len, get_rows_q1_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); @@ -4212,6 +4438,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_NVFP4], "get_rows_nvfp4_f32", get_rows_nvfp4_f32_len, get_rows_nvfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); @@ -4244,10 +4471,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true); - if (device->float_controls_rte_fp16 && - sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) { + if (sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) { ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f32, "rms_norm_mul_rope_f32_f32", rms_norm_mul_rope_f32_f32_len, rms_norm_mul_rope_f32_f32_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, "rms_norm_mul_rope_f32_f16", rms_norm_mul_rope_f32_f16_rte_len, rms_norm_mul_rope_f32_f16_rte_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, "rms_norm_mul_rope_f32_f16", rms_norm_mul_rope_f32_f16_len, rms_norm_mul_rope_f32_f16_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true); } ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); @@ -4272,43 +4498,32 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_32, "cpy_transpose_32", cpy_transpose_32_len, cpy_transpose_32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_16, "cpy_transpose_16", cpy_transpose_16_len, cpy_transpose_16_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); - if (device->float_controls_rte_fp16) { - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - } else { - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - } - -#define SET_ROWS(itype, rte) \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32], "set_rows_f32" #itype, set_rows_f32 ## itype ## rte ## _len, set_rows_f32 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16], "set_rows_f16" #itype, set_rows_f16 ## itype ## rte ## _len, set_rows_f16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## rte ## _len, set_rows_bf16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## rte ## _len, set_rows_q4_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## rte ## _len, set_rows_q4_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## rte ## _len, set_rows_q5_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## rte ## _len, set_rows_q5_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## rte ## _len, set_rows_q8_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## rte ## _len, set_rows_iq4_nl ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); - - if (device->float_controls_rte_fp16) { - SET_ROWS(_i32, _rte) - SET_ROWS(_i64, _rte) - } else { - SET_ROWS(_i32, ) - SET_ROWS(_i64, ) - } + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q1_0], "cpy_f32_q1_0", cpy_f32_q1_0_len, cpy_f32_q1_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + +#define SET_ROWS(itype) \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32], "set_rows_f32" #itype, set_rows_f32 ## itype ## _len, set_rows_f32 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16], "set_rows_f16" #itype, set_rows_f16 ## itype ## _len, set_rows_f16 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## _len, set_rows_bf16 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q1_0], "set_rows_q1_0" #itype, set_rows_q1_0 ## itype ## _len, set_rows_q1_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## _len, set_rows_q4_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## _len, set_rows_q4_1 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## _len, set_rows_q5_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## _len, set_rows_q5_1 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## _len, set_rows_q8_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## _len, set_rows_iq4_nl ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); + + SET_ROWS(_i32) + SET_ROWS(_i64) #undef SET_ROWS + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q1_0], "cpy_q1_0_f32", cpy_q1_0_f32_len, cpy_q1_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q1_0), 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_0], "cpy_q5_0_f32", cpy_q5_0_f32_len, cpy_q5_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1); @@ -4324,11 +4539,10 @@ static void ggml_vk_load_shaders(vk_device& device) { return s; }; - bool rte = device->float_controls_rte_fp16; #define CREATE_BINARY(name, namemod, spec, bindings) \ for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \ ggml_vk_create_pipeline2(device, device->pipeline_ ## name ## namemod[s0][s1][d], \ - #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \ + #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \ "main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1); CREATE_BINARY(add, , {0}, 4) @@ -4371,13 +4585,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - if (device->float_controls_rte_fp16) { - ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32_rte", log_f32_rte_len, log_f32_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16_rte", log_f16_rte_len, log_f16_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - } else { - ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - } + ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -4418,19 +4627,9 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_UNARY(floor) CREATE_UNARY(trunc) CREATE_UNARY(sgn) + CREATE_UNARY(exp) #undef CREATE_UNARY -#define CREATE_UNARY_RTE(name) \ - if (device->float_controls_rte_fp16) { \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ - } else { \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ - } - CREATE_UNARY_RTE(exp) -#undef CREATE_UNARY_RTE - ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f16, "add1_f16_f16", add1_f16_f16_len, add1_f16_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f32, "add1_f16_f32", add1_f16_f32_len, add1_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_add1_f32_f32, "add1_f32_f32", add1_f32_f32_len, add1_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); @@ -4438,15 +4637,11 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_fill_f16, "fill_f16", fill_f16_len, fill_f16_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); #define CREATE_GLU(name) \ - if (device->float_controls_rte_fp16) { \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ - } else { \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ - } + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); CREATE_GLU(geglu) CREATE_GLU(reglu) @@ -4479,25 +4674,14 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - if (device->float_controls_rte_fp16) { - ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - - ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_rte_len, rope_norm_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_rte_len, rope_neox_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_rte_len, rope_multi_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - } else { - ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_len, rope_multi_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - } + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_len, rope_multi_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); for (uint32_t i = 0; i < num_argsort_pipelines; ++i) { uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2); @@ -4559,13 +4743,8 @@ static void ggml_vk_load_shaders(vk_device& device) { #define IM2COL(bda) \ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ - if (device->float_controls_rte_fp16) { \ - ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ - } else { \ - ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ - } + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); if (device->shader_int64 && device->buffer_device_address) { IM2COL(_bda) } else { @@ -4589,12 +4768,42 @@ static void ggml_vk_load_shaders(vk_device& device) { {"gated_delta_net_f32_d64", "gated_delta_net_f32_d64_kda"}, {"gated_delta_net_f32_d128", "gated_delta_net_f32_d128_kda"}, }; + const bool use_subgroup_reduce = device->subgroup_arithmetic; for (uint32_t si = 0; si < 3; si++) { + const uint32_t S_V = gdn_sizes[si]; + GGML_ASSERT(is_pow2(S_V)); + + uint32_t lanes_per_column; + if (S_V >= 128u && device->subgroup_clustered) { + lanes_per_column = 8u; + } else { + // Use largest power-of-two that divides both S_V and subgroup_size so that + // (1) S_V % lanes_per_column == 0 and (2) S_V % (subgroup_size / lanes_per_column) == 0. + // This means we don't need extra bounds checking logic in the shader. + lanes_per_column = std::min(S_V, device->subgroup_size); + } + + const bool need_clustered_shader = lanes_per_column != 1 && (lanes_per_column < device->subgroup_size); + size_t gdn_len; + const void * gdn_data; + if (use_subgroup_reduce && need_clustered_shader) { + gdn_len = gated_delta_net_f32_len; + gdn_data = (const void *)gated_delta_net_f32_data; + } else if (use_subgroup_reduce) { + gdn_len = gated_delta_net_f32_nocluster_len; + gdn_data = (const void *)gated_delta_net_f32_nocluster_data; + } else { + gdn_len = gated_delta_net_f32_shmem_len; + gdn_data = (const void *)gated_delta_net_f32_shmem_data; + } + + const uint32_t cols_per_wg = device->subgroup_size / lanes_per_column; + const std::array wg_denoms = {1u, 1u, cols_per_wg}; + for (uint32_t kda = 0; kda < 2; kda++) { ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net[si][kda], - gdn_names[si][kda], gated_delta_net_f32_len, gated_delta_net_f32_data, - "main", 7, sizeof(vk_op_gated_delta_net_push_constants), - {1, 1, 1}, {gdn_sizes[si], kda}, 1); + gdn_names[si][kda], gdn_len, gdn_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants), + wg_denoms, {S_V, kda, device->subgroup_size, lanes_per_column}, 1, true, use_subgroup_reduce, device->subgroup_size); } } } @@ -4981,8 +5190,9 @@ static vk_device ggml_vk_get_device(size_t idx) { std::vector queue_family_props = device->physical_device.getQueueFamilyProperties(); // Try to find a non-graphics compute queue and transfer-focused queues - // On AMD, the graphics queue seems to be faster, so don't avoid it - const vk::QueueFlagBits graphics_flag = device->vendor_id == VK_VENDOR_ID_AMD ? (vk::QueueFlagBits)0 : vk::QueueFlagBits::eGraphics; + // Allow overriding avoiding the graphics queue because it can increase performance on RADV + const bool allow_graphics_queue = (getenv("GGML_VK_ALLOW_GRAPHICS_QUEUE") != nullptr); + const vk::QueueFlagBits graphics_flag = allow_graphics_queue ? (vk::QueueFlagBits)0 : vk::QueueFlagBits::eGraphics; const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, graphics_flag, -1, 1); const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | graphics_flag, compute_queue_family_index, 1); @@ -4998,7 +5208,7 @@ static vk_device ggml_vk_get_device(size_t idx) { } else { device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities}); } - vk::DeviceCreateInfo device_create_info; + vk::DeviceCreateInfo device_create_info{}; std::vector device_extensions; vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures(); @@ -5367,12 +5577,10 @@ static vk_device ggml_vk_get_device(size_t idx) { #endif device->name = GGML_VK_NAME + std::to_string(idx); - device_create_info = { - vk::DeviceCreateFlags(), - device_queue_create_infos, - {}, - device_extensions - }; + device_create_info + .setFlags(vk::DeviceCreateFlags()) + .setQueueCreateInfos(device_queue_create_infos) + .setPEnabledExtensionNames(device_extensions); device_create_info.setPNext(&device_features2); device->device = device->physical_device.createDevice(device_create_info); @@ -5443,11 +5651,14 @@ static vk_device ggml_vk_get_device(size_t idx) { ggml_vk_load_shaders(device); + // Only use transfer queue on AMD non-GCN, when the graphics queue is not enabled + const bool prefers_transfer_queue = device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN && !allow_graphics_queue; + if (!device->single_queue) { const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0; ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true); - device->async_use_transfer_queue = (getenv("GGML_VK_ASYNC_USE_TRANSFER_QUEUE") != nullptr); + device->async_use_transfer_queue = prefers_transfer_queue || (getenv("GGML_VK_ASYNC_USE_TRANSFER_QUEUE") != nullptr); } else { // TODO: Use pointer or reference to avoid copy device->transfer_queue.copyFrom(device->compute_queue); @@ -5953,6 +6164,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type VK_LOG_DEBUG("ggml_vk_get_to_fp16()"); switch (type) { case GGML_TYPE_F32: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -5973,6 +6185,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -6024,6 +6237,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte } switch (src0_type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -6044,6 +6258,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -6089,6 +6304,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -6109,6 +6325,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -6179,6 +6396,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16)); switch (src0_type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -6199,6 +6417,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -6247,6 +6466,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -6267,6 +6487,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -6392,6 +6613,7 @@ static vk_subbuffer ggml_vk_tensor_subbuffer( static vk_command_buffer* ggml_vk_get_or_create_cmd_buffer(vk_device& device, vk_command_pool& pool) { for (auto& cmd_buffer : pool.cmd_buffers) { if (!cmd_buffer.in_use) { + cmd_buffer.use_counter++; cmd_buffer.in_use = true; return &cmd_buffer; } @@ -6496,14 +6718,15 @@ static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) { } static vk_context ggml_vk_get_compute_ctx(ggml_backend_vk_context * ctx) { + vk_context result; if (!ctx->compute_ctx.expired()) { - return ctx->compute_ctx.lock(); - } - - vk_context result = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + result = ctx->compute_ctx.lock(); + } else { + result = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->compute_ctx = result; - ggml_vk_ctx_begin(ctx->device, result); + ctx->compute_ctx = result; + ggml_vk_ctx_begin(ctx->device, result); + } if (ctx->device->async_use_transfer_queue && ctx->transfer_semaphore_last_submitted < ctx->transfer_semaphore.value) { result->s->wait_semaphores.push_back(ctx->transfer_semaphore); @@ -6674,7 +6897,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont } } -static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) { +static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) { VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")"); // Check if src is pinned memory vk_buffer buf = nullptr; @@ -6684,7 +6907,7 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz if (buf != nullptr) { // Memory is pinned, use as staging buffer std::vector slices(1); - if (width == spitch) { + if (width == spitch && width == dpitch) { // Only do single write if stride is equal slices[0].srcOffset = buf_offset; slices[0].dstOffset = offset; @@ -6693,7 +6916,7 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz slices.resize(height); for (size_t i = 0; i < height; i++) { slices[i].srcOffset = buf_offset + i * spitch; - slices[i].dstOffset = offset + i * width; + slices[i].dstOffset = offset + i * dpitch; slices[i].size = width; } } @@ -6710,21 +6933,30 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz } // Staging buffer required - const size_t copy_size = width*height; - ggml_vk_ensure_sync_staging_buffer(dst->device, copy_size); + const size_t staging_size = width * height; + ggml_vk_ensure_sync_staging_buffer(dst->device, staging_size); vk_buffer& staging_buffer = dst->device->sync_staging; - VkBufferCopy buf_copy = { - 0, - offset, - copy_size}; + std::vector slices(1); + if (width == dpitch) { + slices[0].srcOffset = 0; + slices[0].dstOffset = offset; + slices[0].size = staging_size; + } else { + slices.resize(height); + for (size_t i = 0; i < height; i++) { + slices[i].srcOffset = i * width; + slices[i].dstOffset = offset + i * dpitch; + slices[i].size = width; + } + } ggml_vk_sync_buffers(nullptr, subctx); - vkCmdCopyBuffer(subctx->s->buffer->buf, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); + subctx->s->buffer->buf.copyBuffer((VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, slices); if (width == spitch) { - deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys); + deferred_memcpy((uint8_t *)staging_buffer->ptr, src, staging_size, &subctx->in_memcpys); } else { for (size_t i = 0; i < height; i++) { deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys); @@ -6735,24 +6967,24 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz static bool ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) { VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")"); - return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging); + return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, size, 1, sync_staging); } -static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height) { +static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t dpitch, size_t width, size_t height) { VK_LOG_DEBUG("ggml_vk_buffer_write_2d(" << width << ", " << height << ")"); // Buffer is already mapped if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { GGML_ASSERT(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); for (size_t i = 0; i < height; i++) { - memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width); + memcpy((uint8_t *)dst->ptr + offset + i * dpitch, (const uint8_t *) src + i * spitch, width); } } else { std::lock_guard guard(dst->device->mutex); vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool); ggml_vk_ctx_begin(dst->device, subctx); - bool ret = ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true); + bool ret = ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, dpitch, width, height, true); GGML_ASSERT(ret); ggml_vk_ctx_end(subctx); @@ -6773,7 +7005,7 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * static void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src, size_t size) { VK_LOG_DEBUG("ggml_vk_buffer_write(" << size << ")"); - ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1); + ggml_vk_buffer_write_2d(dst, offset, src, size, size, size, 1); } static bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) { @@ -6819,15 +7051,35 @@ static bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size } // Fall back to staging buffer - const size_t copy_size = dpitch * height; - ggml_vk_ensure_sync_staging_buffer(src->device, copy_size); + const size_t staging_size = width * height; + ggml_vk_ensure_sync_staging_buffer(src->device, staging_size); vk_buffer& staging_buffer = src->device->sync_staging; + std::vector staging_slices(1); + if (width == spitch) { + staging_slices[0].srcOffset = offset; + staging_slices[0].dstOffset = 0; + staging_slices[0].size = staging_size; + } else { + staging_slices.resize(height); + for (size_t i = 0; i < height; i++) { + staging_slices[i].srcOffset = offset + i * spitch; + staging_slices[i].dstOffset = i * width; + staging_slices[i].size = width; + } + } + ggml_vk_sync_buffers(nullptr, subctx); - subctx->s->buffer->buf.copyBuffer(src->buffer, staging_buffer->buffer, slices); + subctx->s->buffer->buf.copyBuffer(src->buffer, staging_buffer->buffer, staging_slices); - deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys); + if (width == dpitch) { + deferred_memcpy(dst, staging_buffer->ptr, staging_size, &subctx->out_memcpys); + } else { + for (size_t i = 0; i < height; i++) { + deferred_memcpy((uint8_t *) dst + i * dpitch, (const uint8_t *) staging_buffer->ptr + i * width, width, &subctx->out_memcpys); + } + } return true; } @@ -6835,8 +7087,8 @@ static bool ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, sync_staging); } -static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) { - VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")"); +static void ggml_vk_buffer_read_2d(vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height) { + VK_LOG_DEBUG("ggml_vk_buffer_read_2d(" << src->buffer << ", " << offset << ", " << width << ", " << height << ")"); // If the device is not an UMA device the memory is host-accessible through rebar. While writing // through PCIe is sufficient fast reading back data from PCIe is slower than going through @@ -6844,18 +7096,20 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_ if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) { GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); - memcpy(dst, (uint8_t *) src->ptr + offset, size); + for (size_t i = 0; i < height; i++) { + memcpy((uint8_t *) dst + i * dpitch, (const uint8_t *) src->ptr + offset + i * spitch, width); + } } else { std::lock_guard guard(src->device->mutex); vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool); ggml_vk_ctx_begin(src->device, subctx); - bool ret = ggml_vk_buffer_read_async(subctx, src, offset, dst, size, true); + bool ret = ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, spitch, dpitch, width, height, true); GGML_ASSERT(ret); ggml_vk_ctx_end(subctx); ggml_vk_submit(subctx, src->device->fence); - VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences"); + VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read_2d waitForFences"); src->device->device.resetFences({ src->device->fence }); ggml_vk_queue_command_pools_cleanup(src->device); @@ -6865,6 +7119,11 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_ } } +static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")"); + ggml_vk_buffer_read_2d(src, offset, dst, size, size, size, 1); +} + static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { VK_LOG_DEBUG("ggml_vk_buffer_copy_async(" << size << ")"); // Make sure both buffers are on same device @@ -6896,7 +7155,7 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr // Copy to src staging buffer ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size); // Copy to dst buffer - ggml_vk_buffer_write_2d(dst, dst_offset, src->device->sync_staging->ptr, 0, size, 1); + ggml_vk_buffer_write(dst, dst_offset, src->device->sync_staging->ptr, size); } } @@ -7192,6 +7451,7 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const } if (src->type == GGML_TYPE_F32) { switch (to) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -7206,6 +7466,7 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const if (to == GGML_TYPE_F32) { switch (src->type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -7625,20 +7886,14 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ return true; } case VK_VENDOR_ID_INTEL: - if (k < 2048) { + if (device->driver_id == vk::DriverId::eIntelProprietaryWindows) { + // Intel Windows proprietary driver MMVQ performance is worse than fp16, see + // https://github.com/ggml-org/llama.cpp/issues/17628 return false; } - if (device->driver_id == vk::DriverId::eIntelProprietaryWindows) { - // Intel Windows proprietary driver tuning - switch (src0_type) { - case GGML_TYPE_MXFP4: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - return false; - default: - return true; - } + if (k < 2048) { + return false; } switch (src0_type) { @@ -8687,7 +8942,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx } } -static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) { +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) { GGML_UNUSED(f32acc); // Needs to be kept up to date on shader changes const uint32_t wg_size = params.workgroup_size; @@ -8696,21 +8951,51 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); + const bool mmq = device->integer_dot_product && device->subgroup_clustered && + (kv_type == GGML_TYPE_Q4_0 || kv_type == GGML_TYPE_Q4_1 || + kv_type == GGML_TYPE_Q5_0 || kv_type == GGML_TYPE_Q5_1 || + kv_type == GGML_TYPE_Q8_0 || kv_type == GGML_TYPE_IQ4_NL); + // tmpsh is overestimated slightly const uint32_t tmpsh = wg_size * sizeof(float); const uint32_t tmpshv4 = wg_size * 4 * float_type_size; const uint32_t masksh = Bc * (Br + 1) * float_type_size; - const uint32_t Qf = Br * (hsk / 4 + 1) * 4 * float_type_size; + uint32_t Qf, kvsh, kblocksh_size; + if (mmq) { + // block_b_cache: int32_t qs[8] + FLOAT_TYPEV2 ds + const uint32_t block_b_size = 8 * sizeof(int32_t) + 2 * float_type_size; + Qf = Br * (hsk / 32) * block_b_size; + + // kvsh uses D = HSV (K goes through kblocksh instead) + kvsh = params.shmem_staging ? Bc * (hsv / 4 + 1) * 4 * float_type_size : 4 * float_type_size; + + // block_a_cache size depends on quant type + uint32_t block_a_size; + switch (kv_type) { + case GGML_TYPE_Q4_0: block_a_size = 4 * sizeof(uint32_t) + float_type_size; break; + case GGML_TYPE_Q4_1: block_a_size = 4 * sizeof(uint32_t) + 2 * float_type_size; break; + case GGML_TYPE_Q5_0: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + float_type_size; break; + case GGML_TYPE_Q5_1: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + 2 * float_type_size; break; + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: block_a_size = 8 * sizeof(int32_t) + float_type_size; break; + default: block_a_size = 0; break; + } + kblocksh_size = params.shmem_staging ? Bc * (hsk / 32) * block_a_size : block_a_size; + } else { + Qf = Br * (hsk / 4 + 1) * 4 * float_type_size; + + const uint32_t D = std::max(hsk, hsv); + kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size; - const uint32_t D = std::max(hsk, hsv); - const uint32_t kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size; + kblocksh_size = 0; + } - const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh; + const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh + kblocksh_size; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; - VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported); + VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", mmq=" << mmq << ", total_size=" << total_size << ", supported=" << supported); return supported; } @@ -8809,8 +9094,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx assert(dst->type == GGML_TYPE_F32); assert(q->type == GGML_TYPE_F32); - assert(k->type == v->type); - uint32_t gqa_ratio = 1; uint32_t qk_ratio = neq2 / nek2; uint32_t workgroups_x = (uint32_t)neq1; @@ -8821,7 +9104,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga. // For coopmat2 FA, we always use the small size (which is still pretty large for gqa). - vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, f32acc); + vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, v->type, f32acc); const uint32_t max_gqa = std::min(tuning_params.block_rows, 32u); if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa && @@ -8834,7 +9117,11 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx workgroups_y /= gqa_ratio; } - tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, f32acc); + tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, v->type, f32acc); + + if (tuning_params.path != FA_COOPMAT2) { + GGML_ASSERT(k->type == v->type); + } const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); @@ -8873,7 +9160,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively. bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768 && nem0 >= tuning_params.block_cols * 16; vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc, - mask != nullptr, use_mask_opt, logit_softcap != 0); + mask != nullptr, use_mask_opt, logit_softcap != 0, k->type, v->type); vk_pipeline pipeline = nullptr; @@ -9656,6 +9943,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const if (dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_fill_f32; } + if (dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_fill_f16; + } return nullptr; default: return nullptr; @@ -9876,7 +10166,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint32_t batch = src1->ne[is_2D ? 3 : 2]; - elements = { OW * KW * KH, OH, batch * IC }; + const uint32_t CHW = IC * KH * KW; + // Cap X workgroups to limit concurrent IC channel reads. + // The shader loops over X to cover the full CHW dimension. + // AMD prefers a lower limit + const uint32_t min_cap = ctx->device->vendor_id == VK_VENDOR_ID_AMD ? 512u : 4096u; + const uint32_t x_elements = std::min(CHW, std::max(min_cap, OW * KH * KW)); + elements = { x_elements, OW, OH * batch }; elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); } break; @@ -10423,7 +10719,7 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf}, - pc, { H, n_seqs, 1u }); + pc, { H, n_seqs, S_v }); } static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { @@ -11003,8 +11299,6 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const const float alpha = op_params_f[2]; const float limit = op_params_f[3]; - GGML_ASSERT(ggml_is_contiguous(src0)); - if (!split) { GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]); } else { @@ -11022,7 +11316,17 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const (uint32_t)dst->ne[0], mode, alpha, - limit + limit, + (uint32_t)(src0->nb[1] / src0->nb[0]), + (uint32_t)(src0->nb[2] / src0->nb[0]), + (uint32_t)(src0->nb[3] / src0->nb[0]), + (uint32_t)src0->ne[1], + (uint32_t)src0->ne[2], + (uint32_t)(dst->nb[1] / dst->nb[0]), + (uint32_t)(dst->nb[2] / dst->nb[0]), + (uint32_t)(dst->nb[3] / dst->nb[0]), + (uint32_t)dst->ne[1], + (uint32_t)dst->ne[2] }); } @@ -11531,7 +11835,6 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32 - const uint32_t pelements = OW * KW * KH; const uint32_t batch = src1->ne[is_2D ? 3 : 2]; const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; @@ -11543,7 +11846,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co dst_addr, batch_offset, offset_delta, IC, IW, IH, OW, OH, KW, KH, - pelements, + OH * batch, IC * KH * KW, s0, s1, p0, p1, d0, d1, batch * IC }); @@ -12801,6 +13104,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr if (vk_perf_logger_enabled && vk_perf_logger_concurrent) { ctx->query_node_idx[ctx->query_idx] = node_idx; compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); + ggml_vk_sync_buffers(ctx, compute_ctx); } } // Add all fused nodes to the unsynchronized lists. @@ -13401,6 +13705,20 @@ static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); } +static void ggml_backend_vk_buffer_set_tensor_2d(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, + size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor_2d(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ", " << + n_copies << ", " << stride_tensor << ", " << stride_data << ")"); + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; + vk_buffer buf = buf_ctx->dev_buffer; + + if (size == 0) { + return; + } + + ggml_vk_buffer_write_2d(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, stride_data, stride_tensor, size, n_copies); +} + static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; @@ -13414,6 +13732,21 @@ static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, cons ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); } +static void ggml_backend_vk_buffer_get_tensor_2d(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, + size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor_2d(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ", " << + n_copies << ", " << stride_tensor << ", " << stride_data << ")"); + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; + + if (size == 0) { + return; + } + + vk_buffer buf = buf_ctx->dev_buffer; + + ggml_vk_buffer_read_2d(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, stride_tensor, stride_data, size, n_copies); +} + static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { if (ggml_nbytes(src) == 0) { return true; @@ -13448,6 +13781,8 @@ static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = { /* .memset_tensor = */ ggml_backend_vk_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_vk_buffer_set_tensor, /* .get_tensor = */ ggml_backend_vk_buffer_get_tensor, + /* .set_tensor_2d = */ ggml_backend_vk_buffer_set_tensor_2d, + /* .get_tensor_2d = */ ggml_backend_vk_buffer_get_tensor_2d, /* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor, /* .clear = */ ggml_backend_vk_buffer_clear, /* .reset = */ NULL, @@ -13603,8 +13938,9 @@ static ggml_backend_buffer_type_t ggml_backend_vk_get_default_buffer_type(ggml_b return &ctx->device->buffer_type; } -static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - VK_LOG_DEBUG("ggml_backend_vk_set_tensor_async(" << size << ")"); +static void ggml_backend_vk_set_tensor_2d_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, + size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + VK_LOG_DEBUG("ggml_backend_vk_set_tensor_2d_async(" << size << ", " << n_copies << ")"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type"); @@ -13618,7 +13954,6 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor if (ctx->device->async_use_transfer_queue) { if (ctx->transfer_ctx.expired()) { - // Initialize new transfer context cpy_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool); ctx->transfer_ctx = cpy_ctx; ggml_vk_ctx_begin(ctx->device, cpy_ctx); @@ -13633,25 +13968,48 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset; - bool ret = ggml_vk_buffer_write_async(cpy_ctx, buf, dst_offset, data, size); + bool ret = ggml_vk_buffer_write_2d_async(cpy_ctx, buf, dst_offset, data, stride_data, stride_tensor, size, n_copies); if (!ret) { - ggml_vk_ensure_sync_staging_buffer(ctx, size); + const size_t staging_size = size * n_copies; + ggml_vk_ensure_sync_staging_buffer(ctx, staging_size); ggml_vk_sync_buffers(nullptr, cpy_ctx); - vk::BufferCopy buffer_cpy; - buffer_cpy.srcOffset = 0; - buffer_cpy.dstOffset = dst_offset; - buffer_cpy.size = size; + std::vector slices(1); + if (size == stride_tensor) { + slices[0].srcOffset = 0; + slices[0].dstOffset = dst_offset; + slices[0].size = staging_size; + } else { + slices.resize(n_copies); + for (size_t i = 0; i < n_copies; i++) { + slices[i].srcOffset = i * size; + slices[i].dstOffset = dst_offset + i * stride_tensor; + slices[i].size = size; + } + } + + cpy_ctx->s->buffer->buf.copyBuffer(ctx->sync_staging->buffer, buf->buffer, slices); - cpy_ctx->s->buffer->buf.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy }); - deferred_memcpy(ctx->sync_staging->ptr, data, size, &cpy_ctx->in_memcpys); + if (size == stride_data) { + deferred_memcpy(ctx->sync_staging->ptr, data, staging_size, &cpy_ctx->in_memcpys); + } else { + for (size_t i = 0; i < n_copies; i++) { + deferred_memcpy((uint8_t *)ctx->sync_staging->ptr + i * size, (const uint8_t *)data + i * stride_data, size, &cpy_ctx->in_memcpys); + } + } ggml_vk_synchronize(ctx); } } -static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { - VK_LOG_DEBUG("ggml_backend_vk_get_tensor_async(" << size << ")"); +static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_set_tensor_async(" << size << ")"); + ggml_backend_vk_set_tensor_2d_async(backend, tensor, data, offset, size, 1, size, size); +} + +static void ggml_backend_vk_get_tensor_2d_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, + size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + VK_LOG_DEBUG("ggml_backend_vk_get_tensor_2d_async(" << size << ", " << n_copies << ")"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type"); @@ -13666,24 +14024,45 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_ vk_buffer buf = buf_ctx->dev_buffer; auto src_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset; - bool ret = ggml_vk_buffer_read_async(compute_ctx, buf, src_offset, data, size); + bool ret = ggml_vk_buffer_read_2d_async(compute_ctx, buf, src_offset, data, stride_tensor, stride_data, size, n_copies); - // If that failed, copy synchronously through a staging buffer if (!ret) { - ggml_vk_ensure_sync_staging_buffer(ctx, size); + const size_t staging_size = size * n_copies; + ggml_vk_ensure_sync_staging_buffer(ctx, staging_size); ggml_vk_sync_buffers(nullptr, compute_ctx); - vk::BufferCopy buffer_cpy; - buffer_cpy.srcOffset = src_offset; - buffer_cpy.dstOffset = 0; - buffer_cpy.size = size; + std::vector slices(1); + if (size == stride_tensor) { + slices[0].srcOffset = src_offset; + slices[0].dstOffset = 0; + slices[0].size = staging_size; + } else { + slices.resize(n_copies); + for (size_t i = 0; i < n_copies; i++) { + slices[i].srcOffset = src_offset + i * stride_tensor; + slices[i].dstOffset = i * size; + slices[i].size = size; + } + } - compute_ctx->s->buffer->buf.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy }); - deferred_memcpy(data, ctx->sync_staging->ptr, size, &compute_ctx->out_memcpys); + compute_ctx->s->buffer->buf.copyBuffer(buf->buffer, ctx->sync_staging->buffer, slices); + + if (size == stride_data) { + deferred_memcpy(data, ctx->sync_staging->ptr, staging_size, &compute_ctx->out_memcpys); + } else { + for (size_t i = 0; i < n_copies; i++) { + deferred_memcpy((uint8_t *)data + i * stride_data, (const uint8_t *)ctx->sync_staging->ptr + i * size, size, &compute_ctx->out_memcpys); + } + } ggml_vk_synchronize(ctx); } } +static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_get_tensor_async(" << size << ")"); + ggml_backend_vk_get_tensor_2d_async(backend, tensor, data, offset, size, 1, size, size); +} + static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) { VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async(" << src << " -> " << dst << ", size=" << ggml_nbytes(src) << ")"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend_dst->context; @@ -13797,6 +14176,7 @@ static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) { ctx->submit_pending = false; if (cmd_buf) { cmd_buf->in_use = false; + cmd_buf->buf.reset(); } } @@ -14158,8 +14538,7 @@ static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, co } // conditions for pipeline creation - if (!(ctx->device->float_controls_rte_fp16 && - sizeof(vk_op_rms_norm_mul_rope_push_constants) <= ctx->device->properties.limits.maxPushConstantsSize)) { + if (sizeof(vk_op_rms_norm_mul_rope_push_constants) > ctx->device->properties.limits.maxPushConstantsSize) { return false; } @@ -14288,6 +14667,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg compute_ctx = ggml_vk_get_compute_ctx(ctx); ctx->query_idx = 0; compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); + ggml_vk_sync_buffers(ctx, compute_ctx); } ctx->prealloc_y_last_pipeline_used = nullptr; @@ -14524,6 +14904,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i]; ctx->query_fusion_names[ctx->query_idx] = fusion_string; compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); + ggml_vk_sync_buffers(ctx, compute_ctx); } else { // track a fusion string and number of fused ops for the current node_idx ctx->query_fusion_names[i] = fusion_string; @@ -14858,18 +15239,31 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx); auto* cmd_buf = compute_ctx->s->buffer; // retrieve pointer before it gets reset - // the backend interface doesn't have an explicit reset, so reset it here - // before we record the command to set it - ctx->device->device.resetEvent(vkev->event); - ctx->device->device.resetFences({ vkev->fence }); + if (vkev->has_event) { + // Move existing event into submitted + vkev->events_submitted.push_back(vkev->event); + } + + // Grab the next event and record it, create one if necessary + if (vkev->events_free.empty()) { + vkev->event = ctx->device->device.createEvent({}); + } else { + vkev->event = vkev->events_free.back(); + vkev->events_free.pop_back(); + } + + vkev->has_event = true; ggml_vk_set_event(compute_ctx, vkev->event); + vkev->tl_semaphore.value++; + compute_ctx->s->signal_semaphores.push_back(vkev->tl_semaphore); ggml_vk_ctx_end(compute_ctx); - ggml_vk_submit(compute_ctx, {vkev->fence}); + ggml_vk_submit(compute_ctx, {}); ctx->submit_pending = true; vkev->cmd_buffer = cmd_buf; + vkev->cmd_buffer_use_counter = cmd_buf->use_counter; ctx->compute_ctx.reset(); } @@ -14880,9 +15274,10 @@ static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_even vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx); - ggml_vk_wait_events(compute_ctx, {vkev->event}); - ggml_vk_ctx_end(compute_ctx); - ctx->compute_ctx.reset(); + if (vkev->has_event) { + // Wait for latest event + ggml_vk_wait_events(compute_ctx, { vkev->event }); + } } // TODO: enable async and synchronize @@ -14891,6 +15286,8 @@ static ggml_backend_i ggml_backend_vk_interface = { /* .free = */ ggml_backend_vk_free, /* .set_tensor_async = */ ggml_backend_vk_set_tensor_async, /* .get_tensor_async = */ ggml_backend_vk_get_tensor_async, + /* .set_tensor_2d_async = */ ggml_backend_vk_set_tensor_2d_async, + /* .get_tensor_2d_async = */ ggml_backend_vk_get_tensor_2d_async, /* .cpy_tensor_async = */ ggml_backend_vk_cpy_tensor_async, /* .synchronize = */ ggml_backend_vk_synchronize, /* .graph_plan_create = */ NULL, @@ -15157,8 +15554,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_GLU_OP_SWIGLU_OAI: case GGML_GLU_OP_GEGLU_ERF: case GGML_GLU_OP_GEGLU_QUICK: - return ggml_is_contiguous(op->src[0]) && - (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (op->src[0]->type == op->type); default: @@ -15178,6 +15574,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -15198,6 +15595,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return false; @@ -15246,42 +15644,27 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) { return false; } - // It's straightforward to support different K/V dequant, but would - // significantly increase the number of pipelines - if (op->src[1]->type != op->src[2]->type) { + // mismatching K/V type is currently supported for coopmat2 only. + if (op->src[1]->type != op->src[2]->type && !coopmat2) { return false; } - switch (op->src[1]->type) { - case GGML_TYPE_F16: - case GGML_TYPE_F32: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q8_0: - // supported in scalar and coopmat2 paths - break; - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently - //case GGML_TYPE_Q2_K: - //case GGML_TYPE_Q3_K: - //case GGML_TYPE_Q4_K: - //case GGML_TYPE_Q5_K: - //case GGML_TYPE_Q6_K: - //case GGML_TYPE_IQ1_S: - //case GGML_TYPE_IQ1_M: - //case GGML_TYPE_IQ2_XXS: - //case GGML_TYPE_IQ2_XS: - //case GGML_TYPE_IQ2_S: - //case GGML_TYPE_IQ3_XXS: - //case GGML_TYPE_IQ3_S: - //case GGML_TYPE_IQ4_XS: - case GGML_TYPE_IQ4_NL: - // currently supported only in coopmat2 path - if (!coopmat2) { + auto fa_kv_ok = [coopmat2](ggml_type t) { + switch (t) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_0: + return true; + case GGML_TYPE_Q1_0: + return coopmat2; + default: return false; } - break; - default: + }; + if (!fa_kv_ok(op->src[1]->type) || !fa_kv_ok(op->src[2]->type)) { return false; } if (!coopmat2 && !(device->subgroup_shuffle && device->subgroup_vote)) { @@ -15296,6 +15679,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -15316,6 +15700,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: case GGML_TYPE_I32: return true; default: @@ -15328,6 +15713,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -15351,6 +15737,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -15365,6 +15752,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (src1_type == GGML_TYPE_F32) { switch (src0_type) { case GGML_TYPE_F16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -15492,8 +15880,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32) || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16); case GGML_OP_ARANGE: - case GGML_OP_FILL: return op->type == GGML_TYPE_F32; + case GGML_OP_FILL: + return op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16; case GGML_OP_SCALE: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_PAD: @@ -15672,10 +16061,13 @@ static ggml_backend_event_t ggml_backend_vk_device_event_new(ggml_backend_dev_t return nullptr; } - // The event/fence is expected to initially be in the signaled state. - vkev->event = device->device.createEvent({}); - vkev->fence = device->device.createFence({vk::FenceCreateFlagBits::eSignaled}); - device->device.setEvent(vkev->event); + // No events initially, they get created on demand + vkev->has_event = false; + + vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 }; + vk::SemaphoreCreateInfo ci{}; + ci.setPNext(&tci); + vkev->tl_semaphore = { device->device.createSemaphore(ci), 0 }; return new ggml_backend_event { /* .device = */ dev, @@ -15689,8 +16081,16 @@ static void ggml_backend_vk_device_event_free(ggml_backend_dev_t dev, ggml_backe vk_event *vkev = (vk_event *)event->context; - device->device.destroyFence(vkev->fence); - device->device.destroyEvent(vkev->event); + device->device.destroySemaphore(vkev->tl_semaphore.s); + for (auto& event : vkev->events_free) { + device->device.destroyEvent(event); + } + for (auto& event : vkev->events_submitted) { + device->device.destroyEvent(event); + } + if (vkev->has_event) { + device->device.destroyEvent(vkev->event); + } delete vkev; delete event; } @@ -15701,10 +16101,29 @@ static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggm auto device = ggml_vk_get_device(ctx->device); vk_event *vkev = (vk_event *)event->context; - VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize"); - // Finished using current command buffer so we flag for reuse - if (vkev->cmd_buffer) { - vkev->cmd_buffer->in_use = false; + // Only do something if the event has actually been used + if (vkev->has_event) { + vk::Semaphore sem = vkev->tl_semaphore.s; + uint64_t val = vkev->tl_semaphore.value; + vk::SemaphoreWaitInfo swi{vk::SemaphoreWaitFlags{}, sem, val}; + VK_CHECK(device->device.waitSemaphores(swi, UINT64_MAX), "event_synchronize"); + + // Reset and move submitted events + for (auto& event : vkev->events_submitted) { + device->device.resetEvent(event); + } + vkev->events_free.insert(vkev->events_free.end(), vkev->events_submitted.begin(), vkev->events_submitted.end()); + vkev->events_submitted.clear(); + + // Finished using current command buffer so we flag for reuse + if (vkev->cmd_buffer) { + // Only flag for reuse if it hasn't been reused already + if (vkev->cmd_buffer_use_counter == vkev->cmd_buffer->use_counter) { + vkev->cmd_buffer->in_use = false; + vkev->cmd_buffer->buf.reset(); + } + vkev->cmd_buffer = nullptr; + } } } @@ -15958,6 +16377,7 @@ static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev) case 0xE20C: // B570 return 18; case 0xE20B: // B580 + case 0xE211: // Pro B60 return 20; default: return 0; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp index 06df5095258..6a692147478 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp @@ -4,7 +4,7 @@ #include "generic_unary_head.glsl" #include "dequant_funcs.glsl" -#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4) +#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4) || defined(DATA_A_NVFP4) // 16 invocations needed for init_iq_shmem layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in; #else diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp index b8c40eec102..710c15296da 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp @@ -1,6 +1,5 @@ #version 450 -#include "rte.glsl" #include "types.glsl" #if defined(SET_ROWS) && QUANT_K == 1 @@ -184,6 +183,31 @@ void quantize(uint dst_idx, uint src_idx) } #endif +#if defined(DATA_A_Q1_0) +void quantize(uint dst_idx, uint src_idx) +{ + float sum_abs = 0.0; + + [[unroll]] for (int j = 0; j < QUANT_K_Q1_0; j++) { + sum_abs += abs(data_s[src_idx + j]); + } + + const float d = sum_abs / QUANT_K_Q1_0; + + data_q[dst_idx].d = float16_t(d); + + [[unroll]] for (int j = 0; j < QUANT_K_Q1_0 / 8; ++j) { + data_q[dst_idx].qs[j] = uint8_t(0); + } + + [[unroll]] for (int j = 0; j < QUANT_K_Q1_0; ++j) { + if (data_s[src_idx + j] >= 0.0) { + data_q[dst_idx].qs[j / 8] |= uint8_t(1 << (j % 8)); + } + } +} +#endif + #if defined(DATA_A_IQ4_NL) uint best_index(float x) { if (x <= kvalues_iq4nl[0]) return 0; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index 7865a6bda79..88d07d2dfd5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -87,6 +87,23 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { } #endif +#if defined(DATA_A_Q1_0) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint bits = uint(data_a[a_offset + ib].qs[iqs / 8u]) >> (iqs % 8u); + return vec2( + (bits & 1u) != 0u ? 1.0f : -1.0f, + (bits & 2u) != 0u ? 1.0f : -1.0f); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint bits = uint(data_a[a_offset + ib].qs[iqs / 8u]) >> (iqs % 8u); + return vec4( + (bits & 1u) != 0u ? 1.0f : -1.0f, + (bits & 2u) != 0u ? 1.0f : -1.0f, + (bits & 4u) != 0u ? 1.0f : -1.0f, + (bits & 8u) != 0u ? 1.0f : -1.0f); +} +#endif + #if defined(DATA_A_IQ1_S) vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint ib32 = iqs / 32; @@ -433,6 +450,25 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { } #endif +#if defined(DATA_A_NVFP4) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint sub = iqs >> 4; + const float d = ue4m3_to_fp32(data_a[a_offset + ib].d[sub]); + const uint j = iqs & 7; + const uint shift = (iqs & 8) >> 1; // 0 or 4 + const uint vui0 = uint(data_a[a_offset + ib].qs[sub * 8u + j]); + const uint vui1 = uint(data_a[a_offset + ib].qs[sub * 8u + j + 1]); + const uint qs0 = (vui0 >> shift) & 0xF; + const uint qs1 = (vui1 >> shift) & 0xF; + return vec2(float(kvalues_mxfp4[qs0]), float(kvalues_mxfp4[qs1])) * d * 0.5; +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const vec2 v0 = dequantize(ib, iqs, a_offset); + const vec2 v1 = dequantize(ib, iqs + 2u, a_offset); + return vec4(v0.x, v0.y, v1.x, v1.y); +} +#endif + #if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16) vec2 get_dm(uint ib, uint a_offset) { return vec2(0, 0); @@ -454,12 +490,25 @@ vec2 get_dm(uint ib, uint a_offset) { } #endif +#if defined(DATA_A_Q1_0) +vec2 get_dm(uint ib, uint a_offset) { + const float d = float(data_a[a_offset + ib].d); + return vec2(d, 0); +} +#endif + #if defined(DATA_A_MXFP4) vec2 get_dm(uint ib, uint a_offset) { return vec2(e8m0_to_fp32(data_a[a_offset + ib].e), 0); } #endif +#if defined(DATA_A_NVFP4) +vec2 get_dm(uint ib, uint a_offset) { + return vec2(1.0, 0.0); +} +#endif + #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) vec2 get_dm(uint ib, uint a_offset) { const vec2 dm = vec2(data_a_packed32[a_offset + ib].dm); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl index 8ac6482dc94..c582aba87dc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl @@ -13,6 +13,18 @@ float16_t dequantFuncF32(const in decodeBufF32 bl, const in uint blockCoords[2], return vf16[idx]; } +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ1_0 { + block_q1_0 block; +}; + +float16_t dequantFuncQ1_0(const in decodeBufQ1_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint bit = (uint(bl.block.qs[(idx & 0x78) >> 3]) >> (idx & 0x7)) & 1u; + return bit != 0u ? d : -d; +} + layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 { block_q4_0_packed16 block; }; @@ -685,7 +697,27 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords } #endif -#if defined(DATA_A_Q4_0) +#if defined(DATA_A_NVFP4) +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufNVFP4 { + block_nvfp4 block; +}; + +float16_t dequantFuncNVFP4(const in decodeBufNVFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint idx = coordInBlock[1]; + const uint sub = (idx & 0x30) >> 4; + const uint iqs = ((idx & 0x30) >> 1) + (idx & 0x7); + const uint shift = (idx & 0x8) >> 1; + const float d = ue4m3_to_fp32(bl.block.d[sub]); + uint qs = uint(bl.block.qs[iqs]); + qs = (qs >> shift) & 0xF; + return float16_t(kvalues_mxfp4[qs] * d * 0.5); +} +#endif + +#if defined(DATA_A_Q1_0) +#define dequantFuncA dequantFuncQ1_0 +#elif defined(DATA_A_Q4_0) #define dequantFuncA dequantFuncQ4_0 #elif defined(DATA_A_Q4_1) #define dequantFuncA dequantFuncQ4_1 @@ -729,6 +761,8 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords #define dequantFuncA dequantFuncIQ4_NL #elif defined(DATA_A_MXFP4) #define dequantFuncA dequantFuncMXFP4 +#elif defined(DATA_A_NVFP4) +#define dequantFuncA dequantFuncNVFP4 #elif defined(DATA_A_F32) #define dequantFuncA dequantFuncF32 #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp new file mode 100644 index 00000000000..689089160b7 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp @@ -0,0 +1,32 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_nvfp4 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + init_iq_shmem(gl_WorkGroupSize); + + const uint tid = gl_LocalInvocationID.x % 64; + const uint sub = tid / 16; + const uint ir = tid % 16; + const uint ib = 16 * i + ir; + if (ib >= p.nel / 64) { + return; + } + + const uint q_idx = 8 * sub; + const uint b_idx = 1024 * i + 64 * ir + 16 * sub; + + const float d = ue4m3_to_fp32(data_a[ib].d[sub]); + + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l + 0] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF])); + data_b[b_idx + l + 8] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4])); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp new file mode 100644 index 00000000000..ca0bdbc63e0 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp @@ -0,0 +1,29 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q1_0 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid / 4; + const uint ir = tid % 4; + const uint ib = 4*i + ir; + if (ib >= p.nel / 128) { + return; + } + + const uint b_idx = 512*i + 128*ir + 8*il; + + const float d = float(data_a[ib].d); + const uint bits = uint(data_a[ib].qs[il]); + + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l] = D_TYPE((bits & (1u << l)) != 0u ? d : -d); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp b/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp index cd3f42f4911..79761324f55 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp @@ -1,6 +1,5 @@ #version 450 -#include "rte.glsl" #include "types.glsl" #include "generic_unary_head.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp b/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp index b69d4ddb096..c7cf5ec68f7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp @@ -1,6 +1,5 @@ #version 450 -#include "rte.glsl" #include "generic_head.glsl" #include "types.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index ec48f5b1152..6e6bdabc92e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -10,6 +10,13 @@ #extension GL_EXT_shader_subgroup_extended_types_float16 : require #endif +#ifdef MMQ +#extension GL_EXT_integer_dot_product : require +#extension GL_KHR_shader_subgroup_clustered : require + +#include "mul_mmq_shmem_types.glsl" +#endif + #extension GL_KHR_shader_subgroup_shuffle : enable #extension GL_KHR_shader_subgroup_vote : enable @@ -41,15 +48,34 @@ shared FLOAT_TYPEV4 tmpshv4[tmpsh_size]; const uint32_t masksh_stride = Br + 1; shared FLOAT_TYPE masksh[Bc * masksh_stride]; +#ifndef MMQ const uint32_t qf_stride = HSK / 4 + 1; shared FLOAT_TYPEV4 Qf[Br * qf_stride]; +#else +const uint32_t qf_stride = HSK / 32; +shared block_b_cache Qf[Br * qf_stride]; +#endif + +#ifndef MMQ const uint32_t D = HSK > HSV ? HSK : HSV; +#else +const uint32_t D = HSV; +#endif const uint32_t kvsh_stride = D / 4 + 1; shared FLOAT_TYPEV4 kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1]; +#ifdef MMQ + +shared block_a_cache kblocksh[SHMEM_STAGING != 0 ? Bc * qf_stride : 1]; +#endif + shared vec4 occupancy_limiter[LIMIT_OCCUPANCY_SHMEM > 0 ? LIMIT_OCCUPANCY_SHMEM : 1]; +#ifdef MMQ +#include "flash_attn_mmq_funcs.glsl" +#endif + void main() { #ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); @@ -82,10 +108,39 @@ void main() { [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) { uint32_t d = (idx + tid) % (HSK / 4); uint32_t r = (idx + tid) / (HSK / 4); - if (r < Br && d < HSK / 4 && - i * Br + r < N) { + const bool is_in_bounds = r < Br && d < HSK / 4 && i * Br + r < N; +#ifndef MMQ + if (is_in_bounds) { Qf[r * qf_stride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); } +#else + const uint buf_ib = r * qf_stride + d / 8; + const uint buf_iqs = d % 8; + + FLOAT_TYPEV4 vals = is_in_bounds ? FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale) : FLOAT_TYPEV4(0.0f); + const FLOAT_TYPEV4 abs_vals = abs(vals); + + const FLOAT_TYPE thread_max = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w)); + const FLOAT_TYPE amax = subgroupClusteredMax(thread_max, 8); + const FLOAT_TYPE qd = amax / FLOAT_TYPE(127.0); + const FLOAT_TYPE qd_inv = qd != FLOAT_TYPE(0.0) ? FLOAT_TYPE(1.0) / qd : FLOAT_TYPE(0.0); + vals = round(vals * qd_inv); + + Qf[buf_ib].qs[buf_iqs] = pack32(i8vec4(vals)); + +#if defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL) + if (buf_iqs == 0) { + Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0); + } +#else // Q4_0, Q4_1, Q5_0, Q5_1 + const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w; + const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8); + + if (buf_iqs == 0) { + Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd); + } +#endif +#endif } barrier(); @@ -195,6 +250,7 @@ void main() { if (SHMEM_STAGING != 0) { barrier(); +#ifndef MMQ [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) { uint32_t d = (idx + tid) % (HSK / 4); uint32_t c = (idx + tid) / (HSK / 4); @@ -214,9 +270,29 @@ void main() { kvsh[c * kvsh_stride + d] = K_Tf; } } +#else // MMQ + const uint ints_per_block = 8 / QUANT_R_MMQ; + const uint quant_iters = Bc * HSK / 32 * ints_per_block; + [[unroll]] for (uint32_t idx = 0; idx < quant_iters; idx += gl_WorkGroupSize.x) { + const uint32_t iqs = (idx + tid) % ints_per_block; + const uint32_t ib = (idx + tid) / ints_per_block; + const uint32_t c = ib / (HSK / 32); + const uint32_t block = ib % (HSK / 32); + if (idx + gl_WorkGroupSize.x <= quant_iters || c < Bc) { + const uint buf_ib = c * qf_stride + block; + if (!KV_bounds_check || j * Bc + c < KV) { + const uint global_ib = (j * Bc + c) * k_stride + block; + k_block_to_shmem(buf_ib, global_ib, iqs, k_offset); + } else { + k_block_to_shmem_zero(buf_ib, iqs); + } + } + } +#endif // MMQ barrier(); } +#ifndef MMQ // More d iterations means Q register caching becomes relevant // Few iterations means the additional registers needed are worse than the speed-up from caching if (HSK_per_thread / 4 > 4) { @@ -245,7 +321,7 @@ void main() { #endif } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf)); + Sf[r][c] += dot(ACC_TYPEV4(Q_cache[r]), ACC_TYPEV4(K_Tf)); } } } @@ -270,11 +346,115 @@ void main() { #endif } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf)); + Sf[r][c] += dot(ACC_TYPEV4(Qf[tile_row(r) * qf_stride + d * D_split + d_tid]), ACC_TYPEV4(K_Tf)); + } + } + } + } +#else // MMQ + const uint hsk4 = HSK_per_thread / 4; + const uint d_per_step = (hsk4 % 8 == 0) ? 8 : + (hsk4 % 4 == 0) ? 4 : + (hsk4 % 2 == 0) ? 2 : 1; + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } + + [[unroll]] for (uint32_t d_block = 0; d_block < HSK_per_thread / 4; d_block += d_per_step) { + int32_t k_quants[d_per_step]; + ACC_TYPEV2 k_dm; + + if (SHMEM_STAGING != 0) { + const uint k_block_idx = (d_tid * (HSK_per_thread / 4) + d_block) / 8; + const uint buf_ib = (c * cols_per_iter + col_tid) * qf_stride + k_block_idx; +#if QUANT_AUXF == 1 + k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm, 0.0); +#else + k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm); +#endif + +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) + if (d_per_step == 8) { + [[unroll]] for (uint32_t d = 0; d < 4; d++) { + uint vui = kblocksh[buf_ib].qs[d]; + k_quants[d ] = int32_t( vui & 0x0F0F0F0F); + k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F); +#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) + uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF; + uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF; + k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u); + k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u); +#endif + } + } else +#endif + { + [[unroll]] for (uint32_t d = 0; d < d_per_step; d++) { + k_quants[d] = get_k_qs_shmem(buf_ib, (d_tid * (HSK_per_thread / 4) + d_block) % 8 + d); + } + } + } else { + const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d_tid * (HSK_per_thread / 4) + d_block); + const uint ib = coord / BLOCK_SIZE; + const uint iqs = (coord % BLOCK_SIZE); + +#if QUANT_AUXF == 1 + k_dm = ACC_TYPEV2(get_k_d(ib, k_offset), 0.0); +#else + k_dm = ACC_TYPEV2(get_k_dm(ib, k_offset)); +#endif +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) + if (d_per_step == 8) { +#if defined(DATA_A_Q5_0) + uint qh = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qh[0], + k_packed.k_data_packed16[k_offset + ib].qh[1])); +#elif defined(DATA_A_Q5_1) + uint qh = k_packed.k_data_packed16[k_offset + ib].qh; +#endif + [[unroll]] for (uint32_t d = 0; d < 4; d++) { +#if defined(A_TYPE_PACKED32) + uint vui = k_packed32.k_data_packed32[k_offset + ib].qs[d]; +#else + uint vui = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 0], + k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 1])); +#endif + k_quants[d ] = int32_t( vui & 0x0F0F0F0F); + k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F); +#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) + uint qh_lo = (qh >> (d * 4)) & 0xF; + uint qh_hi = (qh >> (d * 4 + 16)) & 0xF; + k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u); + k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u); +#endif + } + } else +#endif + { + [[unroll]] for (uint32_t d = 0; d < d_per_step; d++) { + k_quants[d] = get_k_qs(ib, iqs + d * 4, k_offset); + } + } + } + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint qib = tile_row(r) * qf_stride + (d_tid * (HSK_per_thread / 4) + d_block) / 8; + const uint qiqs = (d_tid * (HSK_per_thread / 4) + d_block) % 8; + + int32_t acc = 0; + [[unroll]] for (uint32_t d = 0; d < d_per_step; d++) { + acc += dotPacked4x8EXT(Qf[qib].qs[qiqs + d], k_quants[d]); + } + + Sf[r][c] += ACC_TYPE(acc) * ACC_TYPE(Qf[qib].ds.x) * k_dm.x; + if ((d_tid * (HSK_per_thread / 4) + d_block) % 8 == 0) { + Sf[r][c] += k_dot_correction(qib, k_dm); } } } } +#endif // MMQ [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { // Compute sum across the D_split diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 172d38f034e..efed3a73e22 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -13,6 +13,12 @@ layout (constant_id = 8) const uint32_t SubGroupSize = 32; layout (constant_id = 9) const uint32_t SHMEM_STAGING = 0; layout (constant_id = 10) const uint32_t Flags = 0; layout (constant_id = 11) const uint32_t LIMIT_OCCUPANCY_SHMEM = 0; +// ggml_type enumerant for K/V +layout (constant_id = 12) const uint32_t FaTypeK = 0; +layout (constant_id = 13) const uint32_t FaTypeV = 0; +// sizeof(decode buffer): quants -> ggml block size; F32 -> 16 (decodeBufF32 vec4). +layout (constant_id = 14) const uint32_t FaBlockBytesK = 2; +layout (constant_id = 15) const uint32_t FaBlockBytesV = 2; const bool USE_MASK_OPT = (Flags & 1) != 0; const bool MASK_ENABLE = (Flags & 2) != 0; @@ -89,6 +95,11 @@ layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16 layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed; #endif +#if defined(A_TYPE_PACKED32) +layout (binding = 1) readonly buffer K_PACKED32 {A_TYPE_PACKED32 k_data_packed32[];} k_packed32; +layout (binding = 2) readonly buffer V_PACKED32 {A_TYPE_PACKED32 v_data_packed32[];} v_packed32; +#endif + #ifndef BLOCK_SIZE #define BLOCK_SIZE 1 #endif @@ -110,7 +121,11 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { #if defined(DATA_A_Q4_0) #define BLOCK_BYTE_SIZE 18 +#elif defined(DATA_A_Q4_1) +#define BLOCK_BYTE_SIZE 20 +#endif +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { if (binding_idx == BINDING_IDX_K) { uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); @@ -119,7 +134,12 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { vui_lo >>= shift; vui_hi >>= shift; - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f)); + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); +#ifdef DATA_A_Q4_1 + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * nibbles + FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].m); +#else + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f)); +#endif } else { uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); @@ -127,11 +147,100 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { vui_lo >>= shift; vui_hi >>= shift; - return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f)); + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); +#ifdef DATA_A_Q4_1 + return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * nibbles + FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].m); +#else + return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f)); +#endif } } #endif +#if defined(DATA_A_Q5_0) +#define BLOCK_BYTE_SIZE 22 +#elif defined(DATA_A_Q5_1) +#define BLOCK_BYTE_SIZE 24 +#endif + +#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) +FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + if (binding_idx == BINDING_IDX_K) { + uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + +#ifdef DATA_A_Q5_1 + uint qh = k_packed.k_data_packed16[a_offset + ib].qh; +#else + uint qh = uint(k_packed.k_data_packed16[a_offset + ib].qh[0]) | (uint(k_packed.k_data_packed16[a_offset + ib].qh[1]) << 16); +#endif + FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) * FLOAT_TYPE(16.0f); + + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); +#ifdef DATA_A_Q5_1 + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles + hb) + FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].m); +#else + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f)); +#endif + } else { + uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + +#ifdef DATA_A_Q5_1 + uint qh = v_packed.v_data_packed16[a_offset + ib].qh; +#else + uint qh = uint(v_packed.v_data_packed16[a_offset + ib].qh[0]) | (uint(v_packed.v_data_packed16[a_offset + ib].qh[1]) << 16); +#endif + FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) * FLOAT_TYPE(16.0f); + + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); +#ifdef DATA_A_Q5_1 + return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles + hb) + FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].m); +#else + return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f)); +#endif + } +} +#endif + + +#if defined(DATA_A_IQ4_NL) +#define BLOCK_BYTE_SIZE 18 + +FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + if (binding_idx == BINDING_IDX_K) { + uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4( + kvalues_iq4nl[vui_lo & 0xF], + kvalues_iq4nl[(vui_lo >> 8) & 0xF], + kvalues_iq4nl[vui_hi & 0xF], + kvalues_iq4nl[(vui_hi >> 8) & 0xF]); + } else { + uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + + return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4( + kvalues_iq4nl[vui_lo & 0xF], + kvalues_iq4nl[(vui_lo >> 8) & 0xF], + kvalues_iq4nl[vui_hi & 0xF], + kvalues_iq4nl[(vui_hi >> 8) & 0xF]); + } +} +#endif #if defined(DATA_A_Q8_0) #define BLOCK_BYTE_SIZE 34 FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 0ea181342ce..8a7bbaeb92c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -17,8 +17,57 @@ #extension GL_EXT_null_initializer : enable #include "types.glsl" -#include "dequant_funcs_cm2.glsl" #include "flash_attn_base.glsl" +#include "dequant_funcs_cm2.glsl" + +// buffer_reference stride = sizeof(struct) = FaBlockBytesK/V. +layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_K { + uint8_t raw[FaBlockBytesK]; +}; +layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_V { + uint8_t raw[FaBlockBytesV]; +}; + +uint fa_block_elems(uint ty) { + switch (ty) { + case 0u: return 4u; // GGML_TYPE_F32: vec4 block (matches decodeBufF32 / dequantFuncF32) + case 1u: return 1u; // GGML_TYPE_F16 + case 2u: return uint(QUANT_K_Q4_0); + case 3u: return uint(QUANT_K_Q4_1); + case 6u: return uint(QUANT_K_Q5_0); + case 7u: return uint(QUANT_K_Q5_1); + case 8u: return uint(QUANT_K_Q8_0); + case 41u: return uint(QUANT_K_Q1_0); + default: + return 1u; + } +} + +float16_t faDecodeK(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { + switch (FaTypeK) { + case 0u: return dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock); + case 2u: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); + case 3u: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); + case 6u: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); + case 7u: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); + case 8u: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); + case 41u: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); + default: return float16_t(0); + } +} + +float16_t faDecodeV(const decodeBufFA_V bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { + switch (FaTypeV) { + case 0u: return dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock); + case 2u: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); + case 3u: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); + case 6u: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); + case 7u: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); + case 8u: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); + case 41u: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); + default: return float16_t(0); + } +} layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; layout (binding = 1) readonly buffer K {uint8_t data_k[];}; @@ -55,12 +104,6 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele return max(elem0, elem1); } -#if BLOCK_SIZE > 1 -#define DECODEFUNC , DEQUANTFUNC -#else -#define DECODEFUNC -#endif - // Store the output when doing grouped query attention. // Rows index by Q's dimension 2, and the first N rows are valid. D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) @@ -95,10 +138,6 @@ ACC_TYPE perElemOpNonGqaSplitKStoreCol0(const in uint32_t r, const in uint32_t c } void main() { -#ifdef NEEDS_INIT_IQ_SHMEM - init_iq_shmem(gl_WorkGroupSize); -#endif - init_indices(); tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); @@ -107,10 +146,10 @@ void main() { tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); -#if BLOCK_SIZE > 1 - tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE); - tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE); -#endif + const uint bs_k = fa_block_elems(FaTypeK); + const uint bs_v = fa_block_elems(FaTypeV); + tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, bs_k); + tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, bs_v); tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK); tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK); @@ -120,10 +159,12 @@ void main() { if (Clamp != gl_CooperativeMatrixClampModeConstantNV) { q_stride &= ~7; -#if BLOCK_SIZE == 1 - k_stride &= ~7; - v_stride &= ~7; -#endif + if (bs_k == 1u) { + k_stride &= ~7; + } + if (bs_v == 1u) { + v_stride &= ~7; + } m_stride &= ~7; } tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1); @@ -230,7 +271,13 @@ void main() { coopmat K_T; uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13; - coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC); + // F16: bs_k==1 (direct load). F32: bs_k==4 (vec4 / dequantFuncF32). Q4/Q8 family: bs_k==32. Q1_0: bs_k==128. + const bool k_use_decode = (bs_k > 1u); + if (k_use_decode) { + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose, faDecodeK); + } else { + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose); + } S = coopMatMulAdd(Qf16, K_T, S); if (LOGIT_SOFTCAP) { @@ -291,7 +338,12 @@ void main() { coopmat V; uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23; - coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) DECODEFUNC); + const bool v_use_decode = (bs_v > 1u); + if (v_use_decode) { + coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad), faDecodeV); + } else { + coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad)); + } L = eM*L + rowsum; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl new file mode 100644 index 00000000000..e14e62d546a --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl @@ -0,0 +1,149 @@ +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) +int32_t get_k_qs(uint ib, uint iqs, uint a_offset) { +#ifdef DATA_A_Q4_0 + uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], + k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); +#else + uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4]; +#endif + + uint shift = (iqs & 0x10) >> 2; + vui >>= shift; + + return int32_t(vui & 0x0F0F0F0F); +} +#endif + +#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) +int32_t get_k_qs(uint ib, uint iqs, uint a_offset) { +#ifdef DATA_A_Q5_0 + uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], + k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); + uint qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qh[0], + k_packed.k_data_packed16[a_offset + ib].qh[1])); +#else + uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4]; + uint qh = k_packed.k_data_packed16[a_offset + ib].qh; +#endif + + uint shift = (iqs & 0x10) >> 2; + vui >>= shift; + + uint qh_bits = (qh >> iqs) & 0xF; + return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u); +} +#endif + +#if defined(DATA_A_Q8_0) +int32_t get_k_qs(uint ib, uint iqs, uint a_offset) { + return pack32(i16vec2(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2], k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])); +} +#endif + +#if defined(DATA_A_IQ4_NL) +int32_t get_k_qs(uint ib, uint iqs, uint a_offset) { + uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], + k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); + uint shift = (iqs & 0x10) >> 2; + vui >>= shift; + + u8vec4 idx = unpack8(vui & 0x0F0F0F0F); + return pack32(i8vec4(kvalues_iq4nl_const[idx.x], + kvalues_iq4nl_const[idx.y], + kvalues_iq4nl_const[idx.z], + kvalues_iq4nl_const[idx.w])); +} +#endif + +#if QUANT_AUXF == 1 +FLOAT_TYPE get_k_d(uint ib, uint a_offset) { + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d); +} +#else +FLOAT_TYPEV2 get_k_dm(uint ib, uint a_offset) { + return FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + ib].dm); +} +#endif + +void k_block_to_shmem(const uint buf_ib, const uint global_ib, const uint iqs, const uint a_offset) { +#if defined(DATA_A_Q4_0) + kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2], + k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1])); +#elif defined(DATA_A_Q4_1) + kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs]; +#elif defined(DATA_A_Q5_0) + kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2], + k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1])); + if (iqs == 0) { + kblocksh[buf_ib].qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qh[0], + k_packed.k_data_packed16[a_offset + global_ib].qh[1])); + } +#elif defined(DATA_A_Q5_1) + kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs]; + if (iqs == 0) { + kblocksh[buf_ib].qh = k_packed.k_data_packed16[a_offset + global_ib].qh; + } +#elif defined(DATA_A_Q8_0) + kblocksh[buf_ib].qs[iqs] = pack32(i16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2], + k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1])); +#elif defined(DATA_A_IQ4_NL) + const uint qs = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2], + k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1])); + const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F); + const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F); + kblocksh[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_iq4nl_const[i_a0.x], kvalues_iq4nl_const[i_a0.y], + kvalues_iq4nl_const[i_a0.z], kvalues_iq4nl_const[i_a0.w])); + kblocksh[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_iq4nl_const[i_a1.x], kvalues_iq4nl_const[i_a1.y], + kvalues_iq4nl_const[i_a1.z], kvalues_iq4nl_const[i_a1.w])); +#endif + + if (iqs == 0) { +#if QUANT_AUXF == 1 + kblocksh[buf_ib].dm = FLOAT_TYPE(k_packed.k_data_packed16[a_offset + global_ib].d); +#else + kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + global_ib].dm); +#endif + } +} + +int32_t get_k_qs_shmem(const uint buf_ib, const uint pos) { +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) + uint sub = pos % 4; + uint shift = ((pos % 8) >= 4) ? 4 : 0; + return int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F); +#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) + uint sub = pos % 4; + uint shift = ((pos % 8) >= 4) ? 4 : 0; + int32_t result = int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F); + uint qh_bits = (kblocksh[buf_ib].qh >> (pos * 4)) & 0xF; + return result | int32_t((qh_bits * 0x02040810u) & 0x10101010u); +#elif defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL) + return kblocksh[buf_ib].qs[pos]; +#endif +} + +ACC_TYPE k_dot_correction(const uint qib, const ACC_TYPEV2 k_dm) { +#if defined(DATA_A_Q4_0) + return -ACC_TYPE(8.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x; +#elif defined(DATA_A_Q5_0) + return -ACC_TYPE(16.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x; +#elif defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) + return ACC_TYPE(Qf[qib].ds.y) * k_dm.y; +#else + return ACC_TYPE(0.0); +#endif +} + +void k_block_to_shmem_zero(const uint buf_ib, const uint iqs) { + kblocksh[buf_ib].qs[iqs] = 0; +#if defined(DATA_A_IQ4_NL) + kblocksh[buf_ib].qs[iqs + 4] = 0; +#endif + if (iqs == 0) { +#if QUANT_AUXF == 1 + kblocksh[buf_ib].dm = FLOAT_TYPE(0.0f); +#else + kblocksh[buf_ib].dm = FLOAT_TYPEV2(0.0f); +#endif + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp index f008859b99d..5e9f8308c1d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp @@ -1,11 +1,25 @@ #version 450 #extension GL_EXT_control_flow_attributes : require - +#extension GL_KHR_shader_subgroup_basic : enable +#if USE_SUBGROUP_CLUSTERED +#extension GL_KHR_shader_subgroup_clustered : enable +#endif +#if USE_SUBGROUP_ADD +#extension GL_KHR_shader_subgroup_arithmetic : enable +#endif + +// Caller guarantees valid spec constants: S_V % COLS_PER_WG == 0 and S_V % LANES_PER_COLUMN == 0, +// so no bounds checking is needed. layout(constant_id = 0) const uint S_V = 128; layout(constant_id = 1) const uint KDA = 0; +layout(constant_id = 2) const uint SUBGROUP_SIZE = 32; +layout(constant_id = 3) const uint LANES_PER_COLUMN = 32; + +const uint COLS_PER_WG = SUBGROUP_SIZE / LANES_PER_COLUMN; +const uint ROWS_PER_LANE = S_V / LANES_PER_COLUMN; -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x_id = 2, local_size_y = 1, local_size_z = 1) in; layout(push_constant) uniform Parameters { uint H; @@ -27,14 +41,61 @@ layout(binding = 4) readonly buffer BetaBuf { FLOAT_TYPE data_beta[]; }; layout(binding = 5) readonly buffer StateBuf { FLOAT_TYPE data_state[]; }; layout(binding = 6) buffer DstBuf { FLOAT_TYPE data_dst[]; }; -shared FLOAT_TYPE s_k[S_V]; -shared FLOAT_TYPE s_q[S_V]; -shared FLOAT_TYPE s_g[S_V]; // KDA only: cached exp(g[i]) +#if !USE_SUBGROUP_ADD && !USE_SUBGROUP_CLUSTERED +shared FLOAT_TYPE temp[SUBGROUP_SIZE]; + +// This does a reduction across groups of LANES_PER_COLUMN +FLOAT_TYPE reduce_add_shmem(FLOAT_TYPE partial) { + const uint lane = gl_SubgroupInvocationID; + temp[lane] = partial; + barrier(); + [[unroll]] for (uint s = LANES_PER_COLUMN / 2u; s > 0; s >>= 1u) { + FLOAT_TYPE other = temp[lane ^ s]; + barrier(); + temp[lane] += other; + barrier(); + } + const FLOAT_TYPE result = temp[lane]; + barrier(); + return result; +} +#endif + +// clusterSize for subgroupClusteredAdd must be a compile-time constant; branch on spec constant +FLOAT_TYPE reduce_partial(FLOAT_TYPE partial) { + switch (LANES_PER_COLUMN) { + case 1u: + return partial; +#if USE_SUBGROUP_CLUSTERED + // Workaround for GLSL requiring a literal constant for the cluster size. + // The branches should all fold away. + case 2u: + return subgroupClusteredAdd(partial, 2u); + case 4u: + return subgroupClusteredAdd(partial, 4u); + case 8u: + return subgroupClusteredAdd(partial, 8u); + case 16u: + return subgroupClusteredAdd(partial, 16u); + case 32u: + return subgroupClusteredAdd(partial, 32u); + case 64u: + return subgroupClusteredAdd(partial, 64u); +#endif + default: +#if USE_SUBGROUP_ADD + return subgroupAdd(partial); +#else + return reduce_add_shmem(partial); +#endif + } +} void main() { const uint head_id = gl_WorkGroupID.x; - const uint seq_id = gl_WorkGroupID.y; - const uint col = gl_LocalInvocationID.x; + const uint seq_id = gl_WorkGroupID.y; + const uint lane = gl_SubgroupInvocationID % LANES_PER_COLUMN; + const uint col = gl_WorkGroupID.z * COLS_PER_WG + (gl_SubgroupInvocationID / LANES_PER_COLUMN); const uint iq1 = head_id % neq1; const uint iq3 = seq_id / rq3; @@ -42,9 +103,9 @@ void main() { const uint state_size = S_V * S_V; const uint state_base = (seq_id * H + head_id) * state_size; - FLOAT_TYPE state[S_V]; - [[unroll]] for (uint i = 0; i < S_V; i++) { - state[i] = FLOAT_TYPE(data_state[state_base + col * S_V + i]); + FLOAT_TYPE s_shard[ROWS_PER_LANE]; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + s_shard[r] = FLOAT_TYPE(data_state[state_base + col * S_V + r * LANES_PER_COLUMN + lane]); } uint attn_off = (seq_id * n_tokens * H + head_id) * S_V; @@ -53,76 +114,56 @@ void main() { const uint q_off = iq3 * sq3 + t * sq2 + iq1 * sq1; const uint k_off = q_off; const uint v_off = seq_id * sv3 + t * sv2 + head_id * sv1; - - s_q[col] = FLOAT_TYPE(data_q[q_off + col]); - s_k[col] = FLOAT_TYPE(data_k[k_off + col]); - const uint gb_off = seq_id * sb3 + t * sb2 + head_id * sb1; + const FLOAT_TYPE beta_val = FLOAT_TYPE(data_beta[gb_off]); - if (KDA != 0) { - const uint g_base = gb_off * S_V; - s_g[col] = exp(FLOAT_TYPE(data_g[g_base + col])); + FLOAT_TYPE k_reg[ROWS_PER_LANE]; + FLOAT_TYPE q_reg[ROWS_PER_LANE]; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + const uint i = r * LANES_PER_COLUMN + lane; + k_reg[r] = FLOAT_TYPE(data_k[k_off + i]); + q_reg[r] = FLOAT_TYPE(data_q[q_off + i]); } - barrier(); - - const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]); - const FLOAT_TYPE beta_val = FLOAT_TYPE(data_beta[gb_off]); - + FLOAT_TYPE g_exp[ROWS_PER_LANE]; if (KDA == 0) { const FLOAT_TYPE g_val = exp(FLOAT_TYPE(data_g[gb_off])); - - FLOAT_TYPE kv_col = 0.0; - [[unroll]] for (uint i = 0; i < S_V; i += 4) { - kv_col += dot( - vec4(state[i], state[i+1], state[i+2], state[i+3]), - vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]) - ); - } - - FLOAT_TYPE delta_col = (v_val - g_val * kv_col) * beta_val; - - FLOAT_TYPE attn_col = 0.0; - [[unroll]] for (uint i = 0; i < S_V; i += 4) { - vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]); - vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]); - sv = g_val * sv + kv * delta_col; - state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w; - - attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3])); + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + g_exp[r] = g_val; } - - data_dst[attn_off + col] = attn_col * scale; } else { - FLOAT_TYPE kv_col = 0.0; - [[unroll]] for (uint i = 0; i < S_V; i += 4) { - vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]); - vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]); - vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]); - kv_col += dot(gv * sv, kv); + const uint g_base = gb_off * S_V; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + const uint i = r * LANES_PER_COLUMN + lane; + g_exp[r] = exp(FLOAT_TYPE(data_g[g_base + i])); } + } + + const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]); - FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val; + FLOAT_TYPE kv_shard = 0.0; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + kv_shard += g_exp[r] * s_shard[r] * k_reg[r]; + } + FLOAT_TYPE kv_col = reduce_partial(kv_shard); - FLOAT_TYPE attn_col = 0.0; - [[unroll]] for (uint i = 0; i < S_V; i += 4) { - vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]); - vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]); - vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]); - sv = gv * sv + kv * delta_col; - state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w; + FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val; - attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3])); - } + FLOAT_TYPE attn_partial = 0.0; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + s_shard[r] = g_exp[r] * s_shard[r] + k_reg[r] * delta_col; + attn_partial += s_shard[r] * q_reg[r]; + } + FLOAT_TYPE attn_col = reduce_partial(attn_partial); + if (lane == 0) { data_dst[attn_off + col] = attn_col * scale; } attn_off += S_V * H; - barrier(); } - [[unroll]] for (uint i = 0; i < S_V; i++) { - data_dst[s_off + state_base + col * S_V + i] = state[i]; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl index ba7909c4d38..dc657f3c708 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl @@ -1,7 +1,6 @@ #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_control_flow_attributes : require -#include "rte.glsl" #include "utils.glsl" #if RMS_NORM_ROPE_FUSION #include "rope_params.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl index 2168989340b..d8fdd8f7b5e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl @@ -1,6 +1,5 @@ #extension GL_EXT_shader_16bit_storage : require -#include "rte.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; @@ -16,4 +15,14 @@ layout (push_constant) uniform parameter uint mode; float alpha; float limit; + uint nb01; + uint nb02; + uint nb03; + uint ne01; + uint ne02; + uint nb11; + uint nb12; + uint nb13; + uint ne11; + uint ne12; } p; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl index 85cf65a9eca..359461306a5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl @@ -8,22 +8,32 @@ void main() { const uint row = i / p.ne20; const uint col = i - row * p.ne20; + const uint i3 = row / (p.ne01 * p.ne02); + const uint i2 = (row % (p.ne01 * p.ne02)) / p.ne01; + const uint i1 = row % p.ne01; + const uint src_idx = i3 * p.nb03 + i2 * p.nb02 + i1 * p.nb01 + col; + + const uint dst_i3 = row / (p.ne11 * p.ne12); + const uint dst_i2 = (row % (p.ne11 * p.ne12)) / p.ne11; + const uint dst_i1 = row % p.ne11; + const uint dst_idx = dst_i3 * p.nb13 + dst_i2 * p.nb12 + dst_i1 * p.nb11 + col; + if (p.mode == 0) { // Default const uint offset = p.ne00 / 2; - const uint idx = row * p.ne00 + col; + const uint idx = src_idx; - data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset]))); + data_d[dst_idx] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset]))); } else if (p.mode == 1) { // Swapped const uint offset = p.ne00 / 2; - const uint idx = row * p.ne00 + col; + const uint idx = src_idx; - data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx]))); + data_d[dst_idx] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx]))); } else { // Split - const uint idx = row * p.ne00 + col; + const uint idx = src_idx; - data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx]))); + data_d[dst_idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx]))); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index db14f5a3cf3..ba4c2103f0c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -3,7 +3,6 @@ #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_control_flow_attributes : require -#include "rte.glsl" #include "types.glsl" layout (push_constant) uniform parameter @@ -14,7 +13,7 @@ layout (push_constant) uniform parameter uint IW; uint IH; uint OW; uint OH; uint KW; uint KH; - uint pelements; + uint OH_batch; uint CHW; int s0; int s1; int p0; int p1; @@ -35,82 +34,60 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; layout (buffer_reference) buffer D_ptr {D_TYPE d;}; #endif -void im2col(const uint y, const uint z) { - const uint gidx = gl_GlobalInvocationID.x; +void im2col(const uint ow, const uint z_idx) { + const uint oh = z_idx % p.OH; + const uint batch_idx = z_idx / p.OH; - const uint oh = y; - const uint batch = z / p.IC; - const uint ic = z % p.IC; + const uint gidx = gl_LocalInvocationID.x; + const uint src_batch = batch_idx * p.batch_offset; + const BDA_OFFSET_T dst_row = ((BDA_OFFSET_T(batch_idx) * p.OH + oh) * p.OW + ow) * p.CHW; - const uint src_base = ic * p.offset_delta + batch * p.batch_offset; - const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH); - const int oh_s1 = int(oh) * p.s1; - const uint ksize = p.OW * p.KH; + const uint KHKW = p.KH * p.KW; - const uint base_linear_idx = gidx * NUM_ITER; + uint wg_x = gl_WorkGroupID.x; + do { + const uint wg_offset = wg_x * 512; - uint current_kx = base_linear_idx / ksize; - const uint rem = base_linear_idx - (current_kx * ksize); - uint current_ky = rem / p.OW; - uint current_ix = rem % p.OW; + [[unroll]] for (uint i = 0; i < NUM_ITER; ++i) { + const uint chw_idx = wg_offset + gidx + i * BLOCK_SIZE; - A_TYPE values[NUM_ITER]; - BDA_OFFSET_T offset_dst[NUM_ITER]; - [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { - values[idx] = A_TYPE(0); - } - - [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { - - const uint linear_idx = base_linear_idx + idx; - - if (linear_idx >= p.pelements) { - continue; - } - - const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0; - const uint iih = oh_s1 + current_ky * p.d1 - p.p1; - - offset_dst[idx] = dst_base + BDA_OFFSET_T(current_ix) * p.CHW + current_ky * p.KW + current_kx; - - if ((iih < p.IH) && (iiw < p.IW)) { - values[idx] = data_a[src_base + iih * p.IW + iiw]; - } - - if (++current_ix == p.OW) { - current_ix = 0; - if (++current_ky == p.KH) { - current_ky = 0; - current_kx++; + if (chw_idx >= p.CHW) { + return; } - } - } - [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { + const uint ic = chw_idx / KHKW; + const uint rem = chw_idx - ic * KHKW; + const uint ky = rem / p.KW; + const uint kx = rem - ky * p.KW; - const uint linear_idx = base_linear_idx + idx; + const uint iiw = ow * p.s0 + kx * p.d0 - p.p0; + const uint iih = oh * p.s1 + ky * p.d1 - p.p1; - if (linear_idx >= p.pelements) { - continue; - } + A_TYPE val = A_TYPE(0); + if (iih < p.IH && iiw < p.IW) { + val = data_a[src_batch + ic * p.offset_delta + iih * p.IW + iiw]; + } #if BDA - D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx]); - dst_addr.d = D_TYPE(values[idx]); + D_ptr out_ptr = D_ptr(p.dst_addr + D_SIZE * (dst_row + chw_idx)); + out_ptr.d = D_TYPE(val); #else - data_d[offset_dst[idx]] = D_TYPE(values[idx]); + data_d[dst_row + chw_idx] = D_TYPE(val); #endif - } + } + + wg_x += gl_NumWorkGroups.x; + } while (wg_x * 512 < p.CHW); } void main() { - uint y = gl_GlobalInvocationID.y; - while (y < p.OH) { + uint ow = gl_GlobalInvocationID.y; + while (ow < p.OW) { uint z = gl_GlobalInvocationID.z; - while (z < p.batch_IC) { - im2col(y, z); + while (z < p.OH_batch) { + im2col(ow, z); z += gl_NumWorkGroups.z; } - y += gl_NumWorkGroups.y; + ow += gl_NumWorkGroups.y; } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp index 4bf8b4ca046..93f61fd8543 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp @@ -4,7 +4,6 @@ #extension GL_EXT_control_flow_attributes : require #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "rte.glsl" #include "types.glsl" layout (push_constant) uniform parameter diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/log.comp b/ggml/src/ggml-vulkan/vulkan-shaders/log.comp index ff2812d3d75..3cda6a63c45 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/log.comp @@ -1,6 +1,5 @@ #version 450 -#include "rte.glsl" #include "types.glsl" #include "generic_unary_head.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl index 337dbd796ad..e8d053cdd43 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl @@ -6,8 +6,8 @@ #define MAT_VEC_FUSION_FLAGS_SCALE1 0x8 layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -#if defined(A_TYPE_VEC4) -layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];}; +#if defined(A_TYPEV4) +layout (binding = 0) readonly buffer AV4 {A_TYPEV4 data_a_v4[];}; #endif #if defined(A_TYPE_PACKED16) layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; @@ -17,11 +17,11 @@ layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32 #endif layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; -#ifdef B_TYPE_VEC2 -layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];}; +#ifdef B_TYPEV2 +layout (binding = 1) readonly buffer BV2 {B_TYPEV2 data_b_v2[];}; #endif -#ifdef B_TYPE_VEC4 -layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; +#ifdef B_TYPEV4 +layout (binding = 1) readonly buffer BV4 {B_TYPEV4 data_b_v4[];}; #endif layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index 619de054cb8..975cec8013f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -41,7 +41,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); - const FLOAT_TYPE_VEC2 dm = vec2(data_a[ib0 + i].dm); + const FLOAT_TYPEV2 dm = vec2(data_a[ib0 + i].dm); [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp index 6af5a81587d..93fbacc6282 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -14,7 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row; - const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm); + const FLOAT_TYPEV2 dm = FLOAT_TYPEV2(data_a[ib0 + i].dm); const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp index 3695b47b98d..54d7e1bcdca 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -14,7 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row; - const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm); + const FLOAT_TYPEV2 dm = FLOAT_TYPEV2(data_a[ib0 + i].dm); const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl index 6ddbed309d7..bc580aeeb83 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl @@ -11,8 +11,8 @@ FLOAT_TYPE get_dm(uint ib) { #endif #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) -FLOAT_TYPE_VEC2 get_dm(uint ib) { - return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); +FLOAT_TYPEV2 get_dm(uint ib) { + return FLOAT_TYPEV2(data_a_packed32[ib].dm); } #endif @@ -23,9 +23,9 @@ FLOAT_TYPE get_dm(uint ib) { #endif #if defined(DATA_A_Q2_K) -FLOAT_TYPE_VEC2 get_dm(uint ib) { +FLOAT_TYPEV2 get_dm(uint ib) { const uint ib_k = ib / 8; - return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm); + return FLOAT_TYPEV2(data_a_packed32[ib_k].dm); } #endif @@ -296,15 +296,24 @@ vec2 get_dm_scale(uint ib, uint iqs) { const uint ib_k = ib / 8; const uint iqs_k = (ib % 8) * 8 + iqs; const uint is = iqs_k / 8; - u8vec2 scale_dm; - if (is < 4) { - scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F); - } else { - scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2), - (data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2)); - } - return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm); + const uvec3 scales = uvec3(data_a_packed32[ib_k].scales[0], + data_a_packed32[ib_k].scales[1], + data_a_packed32[ib_k].scales[2]); + const uint scalesoffs = (is & 3) * 8; + + const uint scidx0 = (is < 4) ? 0 : 2; + const uint scidxshift0 = scalesoffs; + const uint scidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; + const uint mbidx0 = (is < 4) ? 1 : 2; + const uint mbidxshift0 = (is < 4) ? scalesoffs : scalesoffs + 4; + const uint mbidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; + + const uint8_t sc = uint8_t(((scales[scidx0] >> scidxshift0) & 0xF) | ((scales[0] >> scidxshift1) & 0x30)); + const uint8_t mbyte = uint8_t(((scales[mbidx0] >> mbidxshift0) & 0xF) | ((scales[1] >> mbidxshift1) & 0x30)); + u8vec2 scale_dm = u8vec2(sc, mbyte); + + return FLOAT_TYPEV2(data_a_packed32[ib_k].dm) * FLOAT_TYPEV2(scale_dm); } FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { @@ -422,7 +431,7 @@ vec2 get_dm(uint ib, uint iqs) { const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1); // the -1 cancels out the bias in iq1s_grid_gpu - return FLOAT_TYPE_VEC2(dl, dl * (delta - 1)); + return FLOAT_TYPEV2(dl, dl * (delta - 1)); } FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 23f3bd8d6d0..89346e48e06 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -125,8 +125,8 @@ layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working wit #define SHMEM_STRIDE (BK / 2 + 1) #endif -shared FLOAT_TYPE_VEC2 buf_a[BM * SHMEM_STRIDE]; -shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE]; +shared FLOAT_TYPEV2 buf_a[BM * SHMEM_STRIDE]; +shared FLOAT_TYPEV2 buf_b[BN * SHMEM_STRIDE]; #define NUM_WARPS (BLOCK_SIZE / WARP) @@ -258,17 +258,17 @@ void main() { sums[i] = coopmat(0.0f); } #else - ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2]; + ACC_TYPEV2 sums[WMITER * TM * WNITER * TN/2]; #if defined(DATA_A_F32) || defined(DATA_A_F16) - FLOAT_TYPE_VEC4 cache_a[WMITER * TM]; - FLOAT_TYPE_VEC4 cache_b; + FLOAT_TYPEV4 cache_a[WMITER * TM]; + FLOAT_TYPEV4 cache_b; #else - FLOAT_TYPE_VEC2 cache_a[WMITER * TM]; - FLOAT_TYPE_VEC2 cache_b; + FLOAT_TYPEV2 cache_a[WMITER * TM]; + FLOAT_TYPEV2 cache_b; #endif [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) { - sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f); + sums[i] = ACC_TYPEV2(0.0f, 0.0f); } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index ce7f2d699a2..73595168984 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -3,7 +3,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin #if LOAD_VEC_A == 8 const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - FLOAT_TYPE_VEC8 aa = FLOAT_TYPE_VEC8(data_a[idx]); + FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]); buf_a[buf_idx ] = aa[0].xy; buf_a[buf_idx + 1] = aa[0].zw; buf_a[buf_idx + 2] = aa[1].xy; @@ -11,38 +11,38 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin #elif LOAD_VEC_A == 4 const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(data_a[idx]); + FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]); buf_a[buf_idx ] = aa.xy; buf_a[buf_idx + 1] = aa.zw; #else // LOAD_VEC_BATCH_A == 2 const uint idx = pos_a + col * p.stride_a + row * 2; const uint buf_idx = col * SHMEM_STRIDE + row; if (idx_m < p.M && block + row * 2 + 1 < end_k) { - buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx], - data_a[idx + 1]); + buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx], + data_a[idx + 1]); } else if (idx_m < p.M && block + row * 2 < end_k) { - buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx], 0.0f); + buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx], 0.0f); } else { - buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + buf_a[buf_idx] = FLOAT_TYPEV2(0.0f); } #endif #elif defined(DATA_A_BF16) #if LOAD_VEC_A == 4 const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_a[idx])); + FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx])); buf_a[buf_idx ] = aa.xy; buf_a[buf_idx + 1] = aa.zw; #else // LOAD_VEC_BATCH_A == 2 const uint idx = pos_a + col * p.stride_a + row * 2; const uint buf_idx = col * SHMEM_STRIDE + row; if (idx_m < p.M && block + row * 2 + 1 < end_k) { - buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]), - TO_FLOAT_TYPE(data_a[idx + 1])); + buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]), + TO_FLOAT_TYPE(data_a[idx + 1])); } else if (idx_m < p.M && block + row * 2 < end_k) { - buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]), 0.0f); + buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]), 0.0f); } else { - buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + buf_a[buf_idx] = FLOAT_TYPEV2(0.0f); } #endif #elif defined(DATA_A_Q4_0) @@ -57,10 +57,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d; const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d; - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v0.zw); - buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v1.xy); - buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw); + buf_a[buf_idx ] = FLOAT_TYPEV2(v0.xy); + buf_a[buf_idx + 1] = FLOAT_TYPEV2(v0.zw); + buf_a[buf_idx + 8] = FLOAT_TYPEV2(v1.xy); + buf_a[buf_idx + 9] = FLOAT_TYPEV2(v1.zw); #elif defined(DATA_A_Q4_1) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; @@ -73,10 +73,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * dm.x + dm.y; const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * dm.x + dm.y; - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy); - buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v0.zw); - buf_a[buf_idx + 8 ] = FLOAT_TYPE_VEC2(v1.xy); - buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw); + buf_a[buf_idx ] = FLOAT_TYPEV2(v0.xy); + buf_a[buf_idx + 1 ] = FLOAT_TYPEV2(v0.zw); + buf_a[buf_idx + 8 ] = FLOAT_TYPEV2(v1.xy); + buf_a[buf_idx + 9 ] = FLOAT_TYPEV2(v1.zw); #elif defined(DATA_A_Q5_0) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; @@ -92,8 +92,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint vui = uint(data_a_packed16[ib].qs[iqs]); const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d; - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz); - buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw); + buf_a[buf_idx ] = FLOAT_TYPEV2(v.xz); + buf_a[buf_idx + 8] = FLOAT_TYPEV2(v.yw); #elif defined(DATA_A_Q5_1) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; @@ -112,10 +112,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 v0 = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) * dm.x + dm.y; const vec4 v1 = vec4(((vui >> 16) & 0xF) | qh2.x, ((vui >> 20) & 0xF) | qh2.y, ((vui >> 24) & 0xF) | qh3.x, ((vui >> 28) & 0xF) | qh3.y) * dm.x + dm.y; - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xz); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v1.xz); - buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v0.yw); - buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.yw); + buf_a[buf_idx ] = FLOAT_TYPEV2(v0.xz); + buf_a[buf_idx + 1] = FLOAT_TYPEV2(v1.xz); + buf_a[buf_idx + 8] = FLOAT_TYPEV2(v0.yw); + buf_a[buf_idx + 9] = FLOAT_TYPEV2(v1.yw); #elif defined(DATA_A_Q8_0) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -128,8 +128,22 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy; const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d; - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw); + buf_a[buf_idx ] = FLOAT_TYPEV2(v.xy); + buf_a[buf_idx + 1] = FLOAT_TYPEV2(v.zw); +#elif defined(DATA_A_Q1_0) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 16; + const uint iqs = idx & 0xfu; + + const float d = float(data_a[ib].d); + const uint bits = uint(data_a[ib].qs[iqs]); + + buf_a[buf_idx ] = FLOAT_TYPEV2((bits & 0x01u) != 0u ? d : -d, (bits & 0x02u) != 0u ? d : -d); + buf_a[buf_idx + 1] = FLOAT_TYPEV2((bits & 0x04u) != 0u ? d : -d, (bits & 0x08u) != 0u ? d : -d); + buf_a[buf_idx + 2] = FLOAT_TYPEV2((bits & 0x10u) != 0u ? d : -d, (bits & 0x20u) != 0u ? d : -d); + buf_a[buf_idx + 3] = FLOAT_TYPEV2((bits & 0x40u) != 0u ? d : -d, (bits & 0x80u) != 0u ? d : -d); #elif defined(DATA_A_Q2_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -147,8 +161,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 v = dm.x * float(scales & 0xF) * qs - dm.y * float(scales >> 4); - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw); + buf_a[buf_idx ] = FLOAT_TYPEV2(v.xy); + buf_a[buf_idx + 1] = FLOAT_TYPEV2(v.zw); #elif defined(DATA_A_Q3_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -171,8 +185,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec2 qs = vec2(unpack8((uint(data_a_packed16[ib].qs[qsi / 2]) >> qsshift) & 0x0303).xy); const vec2 hm = vec2(unpack8(((uint(data_a_packed16[ib].hmask[hmi / 2]) >> (4 * n + halfsplit)) & 0x0101 ^ 0x0101) << 2).xy); - buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * (qs.x - hm.x), - dl * (qs.y - hm.y)); + buf_a[buf_idx] = FLOAT_TYPEV2(dl * (qs.x - hm.x), + dl * (qs.y - hm.y)); #elif defined(DATA_A_Q4_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -187,27 +201,28 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec2 loadd = vec2(data_a[ib].dm); - const uint scidx0 = (is < 4) ? is : (is + 4); - const uint scidx1 = (is < 4) ? is : (is - 4); - const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; - const uint scidxshift1 = (is < 4) ? 0 : 2; - const uint mbidx0 = is + 4; - const uint mbidx1 = (is < 4) ? is + 4 : is; - const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; - const uint mbidxshift0 = (is < 4) ? 0 : 4; - const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; - const uint mbidxshift1 = (is < 4) ? 0 : 2; + const uvec3 scales = uvec3(data_a_packed32[ib].scales[0], + data_a_packed32[ib].scales[1], + data_a_packed32[ib].scales[2]); + const uint scalesoffs = (is & 3) * 8; + + const uint scidx0 = (is < 4) ? 0 : 2; + const uint scidxshift0 = scalesoffs; + const uint scidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; + const uint mbidx0 = (is < 4) ? 1 : 2; + const uint mbidxshift0 = (is < 4) ? scalesoffs : scalesoffs + 4; + const uint mbidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; - const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); - const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + const uint8_t sc = uint8_t(((scales[scidx0] >> scidxshift0) & 0xF) | ((scales[0] >> scidxshift1) & 0x30)); + const uint8_t mbyte = uint8_t(((scales[mbidx0] >> mbidxshift0) & 0xF) | ((scales[1] >> mbidxshift1) & 0x30)); const float d = loadd.x * sc; const float m = -loadd.y * mbyte; const vec4 q = vec4(unpack8((data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F)); - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m)); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m)); + buf_a[buf_idx ] = FLOAT_TYPEV2(fma(d, q.x, m), fma(d, q.y, m)); + buf_a[buf_idx + 1] = FLOAT_TYPEV2(fma(d, q.z, m), fma(d, q.w, m)); #elif defined(DATA_A_Q5_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -223,19 +238,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec2 loadd = vec2(data_a[ib].dm); - const uint scidx0 = (is < 4) ? is : (is + 4); - const uint scidx1 = (is < 4) ? is : (is - 4); - const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; - const uint scidxshift1 = (is < 4) ? 0 : 2; - const uint mbidx0 = is + 4; - const uint mbidx1 = (is < 4) ? is + 4 : is; - const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; - const uint mbidxshift0 = (is < 4) ? 0 : 4; - const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; - const uint mbidxshift1 = (is < 4) ? 0 : 2; + const uvec3 scales = uvec3(data_a_packed32[ib].scales[0], + data_a_packed32[ib].scales[1], + data_a_packed32[ib].scales[2]); + const uint scalesoffs = (is & 3) * 8; - const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); - const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + const uint scidx0 = (is < 4) ? 0 : 2; + const uint scidxshift0 = scalesoffs; + const uint scidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; + const uint mbidx0 = (is < 4) ? 1 : 2; + const uint mbidxshift0 = (is < 4) ? scalesoffs : scalesoffs + 4; + const uint mbidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; + + const uint8_t sc = uint8_t(((scales[scidx0] >> scidxshift0) & 0xF) | ((scales[0] >> scidxshift1) & 0x30)); + const uint8_t mbyte = uint8_t(((scales[mbidx0] >> mbidxshift0) & 0xF) | ((scales[1] >> mbidxshift1) & 0x30)); const float d = loadd.x * sc; const float m = -loadd.y * mbyte; @@ -244,8 +260,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint qh = ((data_a_packed32[ib].qh[qhi / 4] >> (iqs / 16)) & 0x01010101) << 4; const vec4 q = vec4(unpack8(qs | qh)); - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m)); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m)); + buf_a[buf_idx ] = FLOAT_TYPEV2(fma(d, q.x, m), fma(d, q.y, m)); + buf_a[buf_idx + 1] = FLOAT_TYPEV2(fma(d, q.z, m), fma(d, q.w, m)); #elif defined(DATA_A_Q6_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -267,7 +283,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint qh = (uint(data_a_packed16[ib].qh[qhi]) >> qhshift) & 0x0303; const vec2 q = (vec2(unpack8(ql | (qh << 4)).xy) - 32) * dscale; - buf_a[buf_idx] = FLOAT_TYPE_VEC2(q.x, q.y); + buf_a[buf_idx] = FLOAT_TYPEV2(q.x, q.y); #elif defined(DATA_A_IQ1_S) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -284,8 +300,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]); [[unroll]] for (int k = 0; k < 4; ++k) { - buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta), - dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta)); + buf_a[buf_idx + k] = FLOAT_TYPEV2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta), + dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta)); } #elif defined(DATA_A_IQ1_M) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; @@ -306,8 +322,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); [[unroll]] for (int k = 0; k < 4; ++k) { - buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta), - dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta)); + buf_a[buf_idx + k] = FLOAT_TYPEV2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta), + dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta)); } #elif defined(DATA_A_IQ2_XXS) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; @@ -332,14 +348,14 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 grid0 = vec4(unpack8(grid.x)); const vec4 grid1 = vec4(unpack8(grid.y)); - buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x, - (sign & 2) != 0 ? -grid0.y : grid0.y); - buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z, - (sign & 8) != 0 ? -grid0.w : grid0.w); - buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x, - (sign & 32) != 0 ? -grid1.y : grid1.y); - buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z, - (sign & 128) != 0 ? -grid1.w : grid1.w); + buf_a[buf_idx ] = db * FLOAT_TYPEV2((sign & 1) != 0 ? -grid0.x : grid0.x, + (sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 1] = db * FLOAT_TYPEV2((sign & 4) != 0 ? -grid0.z : grid0.z, + (sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 2] = db * FLOAT_TYPEV2((sign & 16) != 0 ? -grid1.x : grid1.x, + (sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 3] = db * FLOAT_TYPEV2((sign & 64) != 0 ? -grid1.z : grid1.z, + (sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ2_XS) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -358,14 +374,14 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 grid0 = vec4(unpack8(grid.x)); const vec4 grid1 = vec4(unpack8(grid.y)); - buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x, - (sign & 2) != 0 ? -grid0.y : grid0.y); - buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z, - (sign & 8) != 0 ? -grid0.w : grid0.w); - buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x, - (sign & 32) != 0 ? -grid1.y : grid1.y); - buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z, - (sign & 128) != 0 ? -grid1.w : grid1.w); + buf_a[buf_idx ] = db * FLOAT_TYPEV2((sign & 1) != 0 ? -grid0.x : grid0.x, + (sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 1] = db * FLOAT_TYPEV2((sign & 4) != 0 ? -grid0.z : grid0.z, + (sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 2] = db * FLOAT_TYPEV2((sign & 16) != 0 ? -grid1.x : grid1.x, + (sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 3] = db * FLOAT_TYPEV2((sign & 64) != 0 ? -grid1.z : grid1.z, + (sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ2_S) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -386,14 +402,14 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 grid0 = vec4(unpack8(grid.x)); const vec4 grid1 = vec4(unpack8(grid.y)); - buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x, - (sign & 2) != 0 ? -grid0.y : grid0.y); - buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z, - (sign & 8) != 0 ? -grid0.w : grid0.w); - buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x, - (sign & 32) != 0 ? -grid1.y : grid1.y); - buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z, - (sign & 128) != 0 ? -grid1.w : grid1.w); + buf_a[buf_idx ] = db * FLOAT_TYPEV2((sign & 1) != 0 ? -grid0.x : grid0.x, + (sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 1] = db * FLOAT_TYPEV2((sign & 4) != 0 ? -grid0.z : grid0.z, + (sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 2] = db * FLOAT_TYPEV2((sign & 16) != 0 ? -grid1.x : grid1.x, + (sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 3] = db * FLOAT_TYPEV2((sign & 64) != 0 ? -grid1.z : grid1.z, + (sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ3_XXS) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -414,10 +430,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint grid = iq3xxs_grid[qs]; const vec4 v = db * vec4(unpack8(grid)); - buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x, - (sign & 2) != 0 ? -v.y : v.y); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z, - (sign & 8) != 0 ? -v.w : v.w); + buf_a[buf_idx ] = FLOAT_TYPEV2((sign & 1) != 0 ? -v.x : v.x, + (sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 1] = FLOAT_TYPEV2((sign & 4) != 0 ? -v.z : v.z, + (sign & 8) != 0 ? -v.w : v.w); #elif defined(DATA_A_IQ3_S) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -436,27 +452,28 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)]; const vec4 v = db * vec4(unpack8(grid)); - buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x, - (sign & 2) != 0 ? -v.y : v.y); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z, - (sign & 8) != 0 ? -v.w : v.w); + buf_a[buf_idx ] = FLOAT_TYPEV2((sign & 1) != 0 ? -v.x : v.x, + (sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 1] = FLOAT_TYPEV2((sign & 4) != 0 ? -v.z : v.z, + (sign & 8) != 0 ? -v.w : v.w); #elif defined(DATA_A_IQ4_XS) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - const uint ib = idx / 128; // 2 values per idx - const uint ib32 = (idx % 128) / 16; // 0..7 - const uint iq = 16 * ib32 + 2 * (idx % 8); + const uint ib = idx / 64; // 4 values per idx + const uint ib32 = (idx % 64) / 8; // 0..7 + const uint iq = 4 * ib32 + (idx % 4); const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3; - const uint qshift = (idx & 8) >> 1; - u8vec2 qs = unpack8((uint(data_a_packed16[ib].qs[iq/2]) >> qshift) & 0x0F0F).xy; + const uint qshift = idx & 4; + u8vec4 qs = unpack8((uint(data_a_packed32[ib].qs[iq]) >> qshift) & 0x0F0F0F0F); const float d = float(data_a[ib].d); - const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]); + const vec4 v = d * float(int(sl | (sh << 4)) - 32) * vec4(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y], kvalues_iq4nl[qs.z], kvalues_iq4nl[qs.w]); - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); + buf_a[buf_idx ] = FLOAT_TYPEV2(v.xy); + buf_a[buf_idx + 1] = FLOAT_TYPEV2(v.zw); #elif defined(DATA_A_IQ4_NL) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; @@ -467,10 +484,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d); const uint vui = uint(data_a_packed16[ib].qs[iqs]); - buf_a[buf_idx ] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[vui & 0xF], - kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]); - buf_a[buf_idx + 8] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)], - kvalues_iq4nl[vui >> 12]); + buf_a[buf_idx ] = d * FLOAT_TYPEV2(kvalues_iq4nl[vui & 0xF], + kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]); + buf_a[buf_idx + 8] = d * FLOAT_TYPEV2(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)], + kvalues_iq4nl[vui >> 12]); #elif defined(DATA_A_MXFP4) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; @@ -482,10 +499,27 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint vui = uint(data_a[ib].qs[iqs]); const uint vui2 = uint(data_a[ib].qs[iqs+1]); - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui & 0xF] * d, - kvalues_mxfp4[vui2 & 0xF] * d); - buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui >> 4] * d, - kvalues_mxfp4[vui2 >> 4] * d); + buf_a[buf_idx ] = FLOAT_TYPEV2(kvalues_mxfp4[vui & 0xF] * d, + kvalues_mxfp4[vui2 & 0xF] * d); + buf_a[buf_idx + 8] = FLOAT_TYPEV2(kvalues_mxfp4[vui >> 4] * d, + kvalues_mxfp4[vui2 >> 4] * d); +#elif defined(DATA_A_NVFP4) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + // lo and hi nibbles are 8 elements apart, which doesn't quite line up with + // how the thread mapping and buf_idx calculation works for other types. + const uint buf_idx = col * SHMEM_STRIDE + (row & 3) + (row & ~3) * 2; + + const uint ib = idx / 16u; + const uint sub = (idx & 0xC) >> 2; + const uint iqs = (idx & 0xF) * 2; + const float d = ue4m3_to_fp32(data_a[ib].d[sub]) * 0.5; + const uint vui = uint(data_a[ib].qs[iqs]); + const uint vui2 = uint(data_a[ib].qs[iqs+1]); + + buf_a[buf_idx ] = FLOAT_TYPEV2(kvalues_mxfp4[vui & 0xF] * d, + kvalues_mxfp4[vui2 & 0xF] * d); + buf_a[buf_idx + 4] = FLOAT_TYPEV2(kvalues_mxfp4[vui >> 4] * d, + kvalues_mxfp4[vui2 >> 4] * d); #endif } @@ -495,7 +529,7 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin // Not supported for b_type bf16 because bf16mat2x4 does not exist const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; - FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]); + FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]); buf_b[buf_idx + 0] = bb[0].xy; buf_b[buf_idx + 1] = bb[0].zw; buf_b[buf_idx + 2] = bb[1].xy; @@ -504,9 +538,9 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; #if defined(DATA_B_BF16) - FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx])); + FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx])); #else - FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]); + FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]); #endif buf_b[buf_idx + 0] = bb.xy; buf_b[buf_idx + 1] = bb.zw; @@ -514,12 +548,12 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin const uint idx = pos_b + col * p.stride_b + row * 2; const uint buf_idx = col * SHMEM_STRIDE + row; if (idx_n < p.N && block + row * 2 + 1 < end_k) { - buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), - TO_FLOAT_TYPE(data_b[idx + 1])); + buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), + TO_FLOAT_TYPE(data_b[idx + 1])); } else if (idx_n < p.N && block + row * 2 < end_k) { - buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); + buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); } else { - buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + buf_b[buf_idx] = FLOAT_TYPEV2(0.0f); } #endif } @@ -530,7 +564,7 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin const u16vec2 row_idx = row_ids[col]; const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; - FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]); + FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]); buf_b[buf_idx + 0] = bb[0].xy; buf_b[buf_idx + 1] = bb[0].zw; buf_b[buf_idx + 2] = bb[1].xy; @@ -540,9 +574,9 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; #if defined(DATA_B_BF16) - FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx])); + FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx])); #else - FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]); + FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]); #endif buf_b[buf_idx + 0] = bb.xy; buf_b[buf_idx + 1] = bb.zw; @@ -552,14 +586,14 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin if (row_i < _ne1 && block + row * 2 + 1 < end_k) { const u16vec2 row_idx = row_ids[col]; const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2; - buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), - TO_FLOAT_TYPE(data_b[idx + 1])); + buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), + TO_FLOAT_TYPE(data_b[idx + 1])); } else if (row_i < _ne1 && block + row * 2 < end_k) { const u16vec2 row_idx = row_ids[col]; const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2; - buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); + buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); } else { - buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + buf_b[buf_idx] = FLOAT_TYPEV2(0.0f); } #endif } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl index 9c297d1c60d..59931b04b94 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl @@ -21,7 +21,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs]; if (iqs == 0) { - buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); + buf_a[buf_ib].dm = FLOAT_TYPEV2(data_a_packed32[ib].dm); } #endif } @@ -72,7 +72,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs]; if (iqs == 0) { - buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); + buf_a[buf_ib].dm = FLOAT_TYPEV2(data_a_packed32[ib].dm); buf_a[buf_ib].qh = data_a_packed32[ib].qh; } #endif @@ -203,7 +203,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6); if (iqs == 0) { - buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm); + buf_a[buf_ib].dm = FLOAT_TYPEV2(data_a_packed32[ib_k].dm); buf_a[buf_ib].scales = unpack8(uint32_t(data_a_packed16[ib_k].scales[iqs_k / 8])).xy; // vec4 used due to #12147 } } @@ -264,7 +264,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const i8vec2 scales = i8vec2(unpack8(uint32_t(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) | (((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4))).xy); // vec4 used due to #12147 - buf_a[buf_ib].d_scales = FLOAT_TYPE_VEC2(float(data_a_packed16[ib_k].d) * vec2(scales - 32)); + buf_a[buf_ib].d_scales = FLOAT_TYPEV2(float(data_a_packed16[ib_k].d) * vec2(scales - 32)); } } @@ -334,7 +334,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { (data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2)); } - buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(vec2(data_a_packed32[ib_k].dm) * vec2(scale_dm)); + buf_a[buf_ib].dm = FLOAT_TYPEV2(vec2(data_a_packed32[ib_k].dm) * vec2(scale_dm)); } } @@ -385,7 +385,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const uint is = iqs_k / 4; const i8vec2 scales = unpack8(int32_t(data_a_packed16[ib_k].scales[is / 2])).xy; - buf_a[buf_ib].d_scales = FLOAT_TYPE_VEC2(float(data_a_packed16[ib_k].d) * vec2(scales)); + buf_a[buf_ib].d_scales = FLOAT_TYPEV2(float(data_a_packed16[ib_k].d) * vec2(scales)); } } @@ -426,7 +426,7 @@ void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bo const uint ib_inner = ib % 4; if (iqs == 0) { - buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]); + buf_b[buf_ib].ds = FLOAT_TYPEV2(data_b[ib_outer].ds[ib_inner]); } const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs]; @@ -436,7 +436,7 @@ void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bo buf_b[buf_ib].qs[iqs * 4 + 3] = values.w; } else { if (iqs == 0) { - buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(0.0f); + buf_b[buf_ib].ds = FLOAT_TYPEV2(0.0f); } buf_b[buf_ib].qs[iqs * 4 ] = 0; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl index 1c0f5306f38..10552d013a2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl @@ -8,7 +8,7 @@ struct block_a_cache { #define QUANT_R_MMQ 2 struct block_a_cache { uint32_t qs[16/4]; - FLOAT_TYPE_VEC2 dm; + FLOAT_TYPEV2 dm; }; #elif defined(DATA_A_Q5_0) #define QUANT_R_MMQ 2 @@ -22,7 +22,7 @@ struct block_a_cache { struct block_a_cache { uint32_t qs[16/4]; uint32_t qh; - FLOAT_TYPE_VEC2 dm; + FLOAT_TYPEV2 dm; }; #elif defined(DATA_A_Q8_0) #define QUANT_R_MMQ 1 @@ -32,6 +32,12 @@ struct block_a_cache { int32_t qs[32/4]; FLOAT_TYPE dm; }; +#elif defined(DATA_A_IQ4_NL) +#define QUANT_R_MMQ 2 +struct block_a_cache { + int32_t qs[8]; + FLOAT_TYPE dm; +}; #elif defined(DATA_A_MXFP4) #define QUANT_R_MMQ 2 struct block_a_cache { @@ -43,36 +49,36 @@ struct block_a_cache { struct block_a_cache { uint32_t qs[2]; u8vec2 scales; - FLOAT_TYPE_VEC2 dm; + FLOAT_TYPEV2 dm; }; #elif defined(DATA_A_Q3_K) #define QUANT_R_MMQ 2 struct block_a_cache { uint32_t qs[4]; - FLOAT_TYPE_VEC2 d_scales; + FLOAT_TYPEV2 d_scales; }; #elif defined(DATA_A_Q4_K) #define QUANT_R_MMQ 2 struct block_a_cache { uint32_t qs[4]; - FLOAT_TYPE_VEC2 dm; + FLOAT_TYPEV2 dm; }; #elif defined(DATA_A_Q5_K) #define QUANT_R_MMQ 1 struct block_a_cache { int32_t qs[8]; - FLOAT_TYPE_VEC2 dm; + FLOAT_TYPEV2 dm; }; #elif defined(DATA_A_Q6_K) #define QUANT_R_MMQ 1 struct block_a_cache { int32_t qs[8]; - FLOAT_TYPE_VEC2 d_scales; + FLOAT_TYPEV2 d_scales; }; #endif struct block_b_cache { int32_t qs[8]; - FLOAT_TYPE_VEC2 ds; + FLOAT_TYPEV2 ds; }; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp b/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp index 10cf5202a4a..26d194e9e8d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp @@ -8,7 +8,6 @@ #extension GL_KHR_shader_subgroup_basic : enable #endif -#include "rte.glsl" #include "types.glsl" #include "utils.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl index d9b4d4c03f3..51a127bcd87 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl @@ -2,7 +2,6 @@ #extension GL_EXT_shader_16bit_storage : require -#include "rte.glsl" #include "rope_params.glsl" layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl index ec6ceaca9bd..2e2a7e14c66 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl @@ -1,8 +1,6 @@ #if !defined(GGML_ROPE_PARAMS) #define GGML_ROPE_PARAMS -#include "rte.glsl" - struct rope_params { uint rope_mode; uint nrows; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl deleted file mode 100644 index ad51c1e80b8..00000000000 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +++ /dev/null @@ -1,5 +0,0 @@ - -#if RTE16 -#extension GL_EXT_spirv_intrinsics : enable -spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits -#endif // RTE16 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp b/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp index e18d0ffa307..f9b78f96072 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp @@ -1,6 +1,5 @@ #version 450 -#include "rte.glsl" #include "types.glsl" #include "generic_unary_head.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl index bdb2c09259b..4bcd97756fd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -188,6 +188,22 @@ struct block_q8_0_packed16 #define DATA_A_QUANT_LEGACY #endif +#define QUANT_K_Q1_0 128 +#define QUANT_R_Q1_0 1 + +struct block_q1_0 +{ + float16_t d; + uint8_t qs[QUANT_K_Q1_0 / 8]; +}; + +#if defined(DATA_A_Q1_0) +#define QUANT_K QUANT_K_Q1_0 +#define QUANT_R QUANT_R_Q1_0 +#define QUANT_AUXF 1 +#define A_TYPE block_q1_0 +#endif + #define QUANT_K_Q8_1 32 #define QUANT_R_Q8_1 1 @@ -1676,6 +1692,7 @@ struct block_iq4_nl_packed16 #if defined(DATA_A_IQ4_NL) #define QUANT_K QUANT_K_IQ4_NL #define QUANT_R QUANT_R_IQ4_NL +#define QUANT_AUXF 1 #define A_TYPE block_iq4_nl #define A_TYPE_PACKED16 block_iq4_nl_packed16 #endif @@ -1696,6 +1713,22 @@ struct block_mxfp4 #define A_TYPE block_mxfp4 #endif +#define QUANT_K_NVFP4 64 +#define QUANT_R_NVFP4 1 + +struct block_nvfp4 +{ + uint8_t d[QUANT_K_NVFP4 / 16]; + uint8_t qs[QUANT_K_NVFP4 / 2]; +}; + +#if defined(DATA_A_NVFP4) +#define QUANT_K QUANT_K_NVFP4 +#define QUANT_R QUANT_R_NVFP4 +#define QUANT_AUXF 1 +#define A_TYPE block_nvfp4 +#endif + #if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS) const int8_t kvalues_iq4nl_const[16] = { int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10), @@ -1715,7 +1748,7 @@ void init_iq_shmem(uvec3 wgsize) } #endif -#if defined(DATA_A_MXFP4) +#if defined(DATA_A_MXFP4) || defined(DATA_A_NVFP4) const int8_t kvalues_mxfp4_const[16] = { int8_t(0), int8_t(1), int8_t(2), int8_t(3), int8_t(4), int8_t(6), int8_t(8), int8_t(12), int8_t(0), int8_t(-1), int8_t(-2), int8_t(-3), int8_t(-4), int8_t(-6), int8_t(-8), int8_t(-12), @@ -1723,6 +1756,24 @@ const int8_t kvalues_mxfp4_const[16] = { shared int8_t kvalues_mxfp4[16]; +#if defined(DATA_A_NVFP4) +// UE4M3 scale in NVFP4 blocks use only 7 bits; sign (bit 7) is always zero. +shared float ue4m3_fp32_lut[128]; + +float ue4m3_to_fp32_build(uint u) { + if (u == 0u || u == 127u) { + return 0.0; + } + const uint exp = (u >> 3) & 15u; + const uint man = u & 7u; + if (exp == 0u) { + return float(man) * (1.0 / 512.0); + } + const uint bits = (exp + 120u) << 23 | (man << 20); + return uintBitsToFloat(bits); +} +#endif + #define NEEDS_INIT_IQ_SHMEM void init_iq_shmem(uvec3 wgsize) { @@ -1730,6 +1781,11 @@ void init_iq_shmem(uvec3 wgsize) for (uint i = gl_LocalInvocationIndex.x; i < kvalues_mxfp4.length(); i += wgsize.x) { kvalues_mxfp4[i] = kvalues_mxfp4_const[i]; } +#if defined(DATA_A_NVFP4) + for (uint i = gl_LocalInvocationIndex.x; i < 128u; i += wgsize.x) { + ue4m3_fp32_lut[i] = ue4m3_to_fp32_build(i); + } +#endif barrier(); } #endif @@ -1766,6 +1822,12 @@ float e8m0_to_fp32(uint8_t x) { return uintBitsToFloat(bits); } +#if defined(DATA_A_NVFP4) +float ue4m3_to_fp32(uint8_t x) { + return ue4m3_fp32_lut[uint(x)]; +} +#endif + #if BDA #extension GL_EXT_buffer_reference : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 4b00ba3debb..6f2a929c40c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -45,6 +45,7 @@ std::string target_cpp = ""; const std::vector type_names = { "f32", "f16", + "q1_0", "q4_0", "q4_1", "q5_0", @@ -65,6 +66,7 @@ const std::vector type_names = { "iq4_xs", "iq4_nl", "mxfp4", + "nvfp4", "bf16", }; @@ -137,6 +139,7 @@ void execute_command(std::vector& command, std::string& stdout_str, pid_t pid = fork(); if (pid < 0) { + std::cerr << strerror(errno) << "\n"; throw std::runtime_error("Failed to fork process"); } @@ -404,8 +407,8 @@ std::map merge_maps(const std::map> compiles; -void string_to_spv(std::string name, const std::string& source, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { - name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); +void string_to_spv(std::string name, const std::string& source, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false, const std::string& suffix = "") { + name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")) + suffix; std::string out_path = join_paths(output_dir, name + ".spv"); if (input_filepath == "") { @@ -445,8 +448,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c base_dict["FLOAT16"] = "1"; } - base_dict["ACC_TYPE" ] = f16acc ? "float16_t" : "float"; - base_dict["ACC_TYPE_VEC2"] = f16acc ? "f16vec2" : "vec2"; + base_dict["ACC_TYPE" ] = f16acc ? "float16_t" : "float"; + base_dict["ACC_TYPEV2"] = f16acc ? "f16vec2" : "vec2"; if (f16acc) { base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)"; } @@ -513,10 +516,10 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c }; const std::map float_type_dict_f16 = { - {"FLOAT_TYPE", FLOAT_TYPE(1, "f16")}, - {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "f16")}, - {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "f16")}, - {"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, "f16")}, + {"FLOAT_TYPE", FLOAT_TYPE(1, "f16")}, + {"FLOAT_TYPEV2", FLOAT_TYPE(2, "f16")}, + {"FLOAT_TYPEV4", FLOAT_TYPE(4, "f16")}, + {"FLOAT_TYPEV8", FLOAT_TYPE(8, "f16")}, }; // Shaders with f16 B_TYPE @@ -535,9 +538,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32"; const std::map float_type_dict_bf16 = { - {"FLOAT_TYPE", FLOAT_TYPE(1, "bf16")}, - {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "bf16")}, - {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "bf16")}, + {"FLOAT_TYPE", FLOAT_TYPE(1, "bf16")}, + {"FLOAT_TYPEV2", FLOAT_TYPE(2, "bf16")}, + {"FLOAT_TYPEV4", FLOAT_TYPE(4, "bf16")}, }; // If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader @@ -552,9 +555,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c for (const auto& tname : type_names) { std::string load_vec_quant = "2"; - if ((tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) + if ((tname == "q1_0") || (tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) load_vec_quant = "8"; - else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4")) + else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_xs") || (tname == "iq4_nl") || (tname == "mxfp4") || (tname == "nvfp4")) load_vec_quant = "4"; if (tname == "bf16") { @@ -568,10 +571,10 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant; const std::map float_type_dict = { - {"FLOAT_TYPE", FLOAT_TYPE(1, tname)}, - {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, tname)}, - {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, tname)}, - {"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, tname)}, + {"FLOAT_TYPE", FLOAT_TYPE(1, tname)}, + {"FLOAT_TYPEV2", FLOAT_TYPE(2, tname)}, + {"FLOAT_TYPEV4", FLOAT_TYPE(4, tname)}, + {"FLOAT_TYPEV8", FLOAT_TYPE(8, tname)}, }; // don't generate f32 variants for coopmat2 @@ -623,39 +626,37 @@ void process_shaders() { for (const bool& fp16 : {false, true}) { std::map base_dict; if (fp16) { - base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}}; + base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV2", "f16vec2"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}}; } else { - base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV4", "vec4"}}; + base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"FLOAT_TYPEV4", "vec4"}}; } // flash attention for (const bool& f16acc : {false, true}) { std::map fa_base_dict = base_dict; fa_base_dict["ACC_TYPE"] = fp16 && f16acc ? "float16_t" : "float"; + fa_base_dict["ACC_TYPEV2"] = fp16 && f16acc ? "f16vec2" : "vec2"; fa_base_dict["ACC_TYPEV4"] = fp16 && f16acc ? "f16vec4" : "vec4"; if (fp16 && f16acc) { fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)"; } + if (fp16) { +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + string_to_spv("flash_attn_f32_f16_mixed", "flash_attn_cm2.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc); +#endif + } + for (const auto& tname : type_names) { if (tname == "bf16") continue; if (fp16) { -#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) - if (tname == "f16") { - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", - merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc); - } else { - std::string data_a_key = "DATA_A_" + to_uppercase(tname); - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, true, f16acc); - } -#endif #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); - } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { + } else if (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "iq4_nl" || tname == "q8_0" || tname == "f32") { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); @@ -666,45 +667,51 @@ void process_shaders() { if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc); - } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { + } else if (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "iq4_nl" || tname == "q8_0" || tname == "f32") { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc); +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (tname != "f32") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }, {"MMQ", "1"}}), fp16, false, false, f16acc, "_int8"); + } +#endif } } } } - std::map base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}}; + std::map base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}}; for (const auto& tname : type_names) { // mul mat vec std::string data_a_key = "DATA_A_" + to_uppercase(tname); std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; - string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); - string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}})); - string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); - string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); - string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); - string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); // mul mat vec with integer dot product #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname) || tname == "iq1_s" || tname == "iq1_m") { - string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}})); - string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); - string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}})); + string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); } #endif @@ -725,9 +732,9 @@ void process_shaders() { string_to_spv("get_rows_i32", "get_rows.comp", {{"TEMP_TYPE", "uint"}, {"A_TYPE", "uint"}, {"B_TYPE", "int"}, {"D_TYPE", "uint"}}); - string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}); - string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}); - string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}); + string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPEV4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}); + string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPEV4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}); + string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPEV4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}); // Norms string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); @@ -735,7 +742,7 @@ void process_shaders() { string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_mul_rope_f32_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float"}, {"RMS_NORM_ROPE_FUSION", "1"}})); - string_to_spv("rms_norm_mul_rope_f32_f16_rte", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RMS_NORM_ROPE_FUSION", "1"}, {"RTE16", "1"}})); + string_to_spv("rms_norm_mul_rope_f32_f16", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RMS_NORM_ROPE_FUSION", "1"}})); string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); @@ -757,17 +764,14 @@ void process_shaders() { string_to_spv("cpy_transpose_16", "copy_transpose.comp", {{"A_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}}); string_to_spv("cpy_transpose_32", "copy_transpose.comp", {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}}); - for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { + for (std::string t : {"q1_0", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); } - for (std::string t : {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { - string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("set_rows_" + t + "_i32_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); - string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("set_rows_" + t + "_i64_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); + for (std::string t : {"f32", "f16", "bf16", "q1_0", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { + string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); } auto get_type_str = [](bool f16) { @@ -784,12 +788,10 @@ void process_shaders() { for (auto src0_f16 : {false, true}) { for (auto src1_f16 : {false, true}) { for (auto dst_f16 : {false, true}) { - for (auto rte : {false, true}) { auto source = op == "add_rms" ? std::string("add") : op; - auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : ""); + auto name = op + get_suffix(src0_f16, src1_f16, dst_f16); auto add_rms = op == "add_rms" ? "1" : "0"; - string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}, {"ADD_RMS" , add_rms}}); - } + string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"ADD_RMS" , add_rms}}); } } } @@ -837,14 +839,11 @@ void process_shaders() { string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); - for (auto rte : {false, true}) { - std::string suffix = rte ? "_rte" : ""; - string_to_spv("exp_f16" + suffix, "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("exp_f32" + suffix, "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"} , {"RTE16", rte ? "1" : "0"}}); + string_to_spv("exp_f16", "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("exp_f32", "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("log_f16" + suffix, "log.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("log_f32" + suffix, "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); - } + string_to_spv("log_f16", "log.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("log_f32", "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("gelu_erf_f16", "gelu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); @@ -887,6 +886,7 @@ void process_shaders() { string_to_spv("add1_f32_f32", "add1.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("arange_f32", "arange.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("fill_f32", "fill.comp", {{"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("fill_f16", "fill.comp", {{"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); string_to_spv("step_f16", "step.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("step_f32", "step.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("round_f16", "round.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); @@ -898,21 +898,18 @@ void process_shaders() { string_to_spv("trunc_f16", "trunc.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("trunc_f32", "trunc.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - for (auto rte : {false, true}) { - std::string suffix = rte ? "_rte" : ""; - string_to_spv("geglu_f16" + suffix, "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("geglu_f32" + suffix, "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("reglu_f16" + suffix, "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("reglu_f32" + suffix, "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("swiglu_f16" + suffix, "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("swiglu_f32" + suffix, "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("swiglu_oai_f16" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("swiglu_oai_f32" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("geglu_erf_f16" + suffix, "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("geglu_erf_f32" + suffix, "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("geglu_quick_f32" + suffix,"geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); - } + string_to_spv("geglu_f16", "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("geglu_f32", "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("reglu_f16", "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("swiglu_oai_f16", "swiglu_oai.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("swiglu_oai_f32", "swiglu_oai.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); @@ -932,25 +929,18 @@ void process_shaders() { string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); - string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}}); - string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); - string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}}); - string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); - string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("rope_multi_f32_f16", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}}); - string_to_spv("rope_multi_f32_f16_rte", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); - string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}}); @@ -973,7 +963,6 @@ void process_shaders() { std::string bda_def = bda ? "1" : "0"; string_to_spv("im2col" + dim_str + "_f32" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"D_SIZE", "4"}, {"BDA", bda_def}})); string_to_spv("im2col" + dim_str + "_f32_f16" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"BDA", bda_def}})); - string_to_spv("im2col" + dim_str + "_f32_f16_rte" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"RTE16", "1"}, {"BDA", bda_def}})); } } @@ -987,7 +976,9 @@ void process_shaders() { string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); - string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}})); + string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}, {"USE_SUBGROUP_CLUSTERED", "1"}})); + string_to_spv("gated_delta_net_f32_nocluster", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}, {"USE_SUBGROUP_CLUSTERED", "0"}})); + string_to_spv("gated_delta_net_f32_shmem", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "0"}, {"USE_SUBGROUP_CLUSTERED", "0"}})); string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); @@ -1024,8 +1015,8 @@ void process_shaders() { string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); - string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}}); - string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}}); + string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"ADD_RMS" , "0"}}); + string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"ADD_RMS" , "1"}}); string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}}); string_to_spv("ssm_scan_subgroup_f32", "ssm_scan.comp", {{"A_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}); @@ -1078,8 +1069,8 @@ void write_output_files() { std::string suffixes[2] = {"_f32", "_f16"}; for (std::string op : {"add", "sub", "mul", "div", "add_rms"}) { - hdr << "extern const void * " << op << "_data[2][2][2][2];\n"; - hdr << "extern const uint64_t " << op << "_len[2][2][2][2];\n"; + hdr << "extern const void * " << op << "_data[2][2][2];\n"; + hdr << "extern const uint64_t " << op << "_len[2][2][2];\n"; std::string op_file = op == "add_rms" ? "add.comp" : std::string(op) + ".comp"; if (basename(input_filepath) != op_file) { @@ -1087,8 +1078,8 @@ void write_output_files() { } std::stringstream data = make_generic_stringstream(); std::stringstream len = make_generic_stringstream(); - data << "const void * " << op << "_data[2][2][2][2] = "; - len << "const uint64_t " << op << "_len[2][2][2][2] = "; + data << "const void * " << op << "_data[2][2][2] = "; + len << "const uint64_t " << op << "_len[2][2][2] = "; for (uint32_t t0 = 0; t0 < 2; ++t0) { if (t0 == 0) { data << "{"; @@ -1104,20 +1095,10 @@ void write_output_files() { data << "{"; len << "{"; } - for (uint32_t rte = 0; rte < 2; ++rte) { - if (rte == 0) { - data << "{"; - len << "{"; - } - data << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : ""); - len << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : ""); - data << "_data,"; - len << "_len,"; - if (rte == 1) { - data << "}, "; - len << "}, "; - } - } + data << op << suffixes[t0] << suffixes[t1] << suffixes[t2]; + len << op << suffixes[t0] << suffixes[t1] << suffixes[t2]; + data << "_data,"; + len << "_len,"; if (t2 == 1) { data << "}, "; len << "}, "; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 3d7e59fddf3..cff93b8d170 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -1,6 +1,7 @@ #ifndef GGML_WEBGPU_SHADER_LIB_HPP #define GGML_WEBGPU_SHADER_LIB_HPP +#include "ggml-impl.h" #include "ggml-wgsl-shaders.hpp" #include "ggml.h" #include "pre_wgsl.hpp" @@ -26,36 +27,30 @@ // Matrix multiplication parameters // Register tiling parameters -#define WEBGPU_MUL_MAT_TILE_M 8 -#define WEBGPU_MUL_MAT_TILE_N 8 -#define WEBGPU_MUL_MAT_WG_SIZE_M 8 -#define WEBGPU_MUL_MAT_WG_SIZE_N 8 -#define WEBGPU_MUL_MAT_TILE_K 32 +#define WEBGPU_MUL_MAT_TILE_M 4 +#define WEBGPU_MUL_MAT_TILE_N 4 +#define WEBGPU_MUL_MAT_WG_SIZE_M 8 +#define WEBGPU_MUL_MAT_WG_SIZE_N 8 +#define WEBGPU_MUL_MAT_REG_TILE_K_FLOAT 8 +#define WEBGPU_MUL_MAT_REG_TILE_K_QUANT 32 // Subgroup matrix parameters // The number of subgroups in the M dimension -#define WEBGPU_MUL_MAT_SUBGROUP_M 2 +#define WEBGPU_MUL_MAT_SUBGROUP_M 2 // The number of subgroups in the N dimension -#define WEBGPU_MUL_MAT_SUBGROUP_N 2 +#define WEBGPU_MUL_MAT_SUBGROUP_N 4 // The number of subgroup matrices each subgroup accumulates over -#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4 -#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2 +#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4 +#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2 +#define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT 32 +#define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT 32 // Matrix-vector multiplication parameters #define WEBGPU_MUL_MAT_VEC_WG_SIZE 256 -// Must be multiple of 4 to work with vectorized paths, and must divide -// mul_mat_vec wg size -#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 64 -#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 256 - -#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 64 -#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 256 - -// Requires 32 threads per output (wg_size/outputs_per_wg == 32) -#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8 -// Requires at least two (and multiple of 2) k-quant blocks per tile -#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 512 +#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 4 +#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 4 +#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 4 // default size for legacy matrix multiplication #define WEBGPU_MUL_MAT_WG_SIZE 256 @@ -65,19 +60,33 @@ template inline void ggml_webgpu_hash_combine(size_t & seed, const seed ^= std::hash{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2); } +// Calculates base address of a tensor ignoring the fake base pointer +inline uintptr_t ggml_webgpu_tensor_addr(const ggml_tensor * tensor) { + const ggml_tensor * base_tensor = tensor->view_src ? tensor->view_src : tensor; + return (uintptr_t) base_tensor->data + tensor->view_offs; +} + +inline bool ggml_webgpu_tensor_equal(const ggml_tensor * a, const ggml_tensor * b) { + return a->buffer == b->buffer && ggml_webgpu_tensor_addr(a) == ggml_webgpu_tensor_addr(b); +} + +inline bool ggml_webgpu_tensor_overlap(const ggml_tensor * a, const ggml_tensor * b) { + return a->buffer == b->buffer && ggml_webgpu_tensor_addr(a) < ggml_webgpu_tensor_addr(b) + ggml_nbytes(b) && + ggml_webgpu_tensor_addr(b) < ggml_webgpu_tensor_addr(a) + ggml_nbytes(a); +} + struct ggml_webgpu_shader_lib_context { ggml_tensor * src0; ggml_tensor * src1; ggml_tensor * src2; ggml_tensor * src3; ggml_tensor * src4; + ggml_tensor * src5; ggml_tensor * dst; uint32_t max_wg_size; size_t wg_mem_limit_bytes = 0; - bool inplace = false; - bool overlap = false; - bool src_overlap = false; + bool supports_subgroups = false; bool supports_subgroup_matrix = false; uint32_t sg_mat_m = 0; uint32_t sg_mat_n = 0; @@ -93,6 +102,51 @@ struct webgpu_pipeline { struct ggml_webgpu_generic_shader_decisions { uint32_t wg_size = 0; + bool inplace = false; +}; + +struct ggml_webgpu_binary_shader_decisions { + uint32_t wg_size = 0; + bool inplace = false; + bool overlap = false; + bool src_overlap = false; +}; + +struct ggml_webgpu_processed_shader { + std::string wgsl; + std::string variant; + std::shared_ptr decisions; +}; + +struct ggml_webgpu_ssm_conv_shader_decisions { + uint32_t block_size; + uint32_t tokens_per_wg; +}; + +struct ggml_webgpu_ssm_scan_pipeline_key { + int type; + int d_state; + bool xbc_overlap; + + bool operator==(const ggml_webgpu_ssm_scan_pipeline_key & other) const { + return type == other.type && d_state == other.d_state && xbc_overlap == other.xbc_overlap; + } +}; + +struct ggml_webgpu_ssm_scan_pipeline_key_hash { + size_t operator()(const ggml_webgpu_ssm_scan_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.d_state); + ggml_webgpu_hash_combine(seed, key.xbc_overlap); + return seed; + } +}; + +struct ggml_webgpu_ssm_scan_shader_decisions { + uint32_t wg_size; + uint32_t tokens_per_tile; + bool xbc_overlap = false; }; /** Argsort **/ @@ -131,6 +185,26 @@ struct ggml_webgpu_set_rows_shader_decisions { uint32_t wg_size; }; +/** Set **/ + +struct ggml_webgpu_set_pipeline_key { + ggml_type type; + bool inplace; + + bool operator==(const ggml_webgpu_set_pipeline_key & other) const { + return type == other.type && inplace == other.inplace; + } +}; + +struct ggml_webgpu_set_pipeline_key_hash { + size_t operator()(const ggml_webgpu_set_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.inplace); + return seed; + } +}; + /** Get Rows **/ struct ggml_webgpu_get_rows_pipeline_key { @@ -151,6 +225,55 @@ struct ggml_webgpu_get_rows_pipeline_key_hash { } }; +/** Row Norm **/ + +struct ggml_webgpu_row_norm_pipeline_key { + ggml_op op; + bool inplace; + + bool operator==(const ggml_webgpu_row_norm_pipeline_key & other) const { + return op == other.op && inplace == other.inplace; + } +}; + +struct ggml_webgpu_row_norm_pipeline_key_hash { + size_t operator()(const ggml_webgpu_row_norm_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.op); + ggml_webgpu_hash_combine(seed, key.inplace); + return seed; + } +}; + +/** RMS_NORM + MUL **/ + +struct ggml_webgpu_rms_norm_mul_pipeline_key { + bool inplace; // rn_src == dst + bool overlap; // mul_src == dst + bool src_overlap; // rn_src == mul_src + + bool operator==(const ggml_webgpu_rms_norm_mul_pipeline_key & other) const { + return inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap; + } +}; + +struct ggml_webgpu_rms_norm_mul_pipeline_key_hash { + size_t operator()(const ggml_webgpu_rms_norm_mul_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.inplace); + ggml_webgpu_hash_combine(seed, key.overlap); + ggml_webgpu_hash_combine(seed, key.src_overlap); + return seed; + } +}; + +struct ggml_webgpu_rms_norm_mul_shader_decisions { + uint32_t wg_size = 0; + bool inplace = false; + bool overlap = false; + bool src_overlap = false; +}; + /** Pad **/ struct ggml_webgpu_pad_pipeline_key { bool circular; @@ -166,6 +289,107 @@ struct ggml_webgpu_pad_pipeline_key_hash { } }; +/** Solve Tri **/ +struct ggml_webgpu_solve_tri_pipeline_key { + int type; + int n; + int k; + + bool operator==(const ggml_webgpu_solve_tri_pipeline_key & other) const { + return type == other.type && n == other.n && k == other.k; + } +}; + +struct ggml_webgpu_solve_tri_pipeline_key_hash { + size_t operator()(const ggml_webgpu_solve_tri_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.n); + ggml_webgpu_hash_combine(seed, key.k); + return seed; + } +}; + +/** SSM Conv **/ +struct ggml_webgpu_ssm_conv_pipeline_key { + int type; + int vectorized; + + bool operator==(const ggml_webgpu_ssm_conv_pipeline_key & other) const { + return type == other.type && vectorized == other.vectorized; + } +}; + +/** CONV 2D */ +struct ggml_webgpu_conv2d_pipeline_key { + ggml_type weight_type; + ggml_type input_type; + ggml_type output_type; + + bool operator==(const ggml_webgpu_conv2d_pipeline_key & other) const { + return weight_type == other.weight_type && input_type == other.input_type && output_type == other.output_type; + } +}; + +struct ggml_webgpu_conv2d_pipeline_key_hash { + size_t operator()(const ggml_webgpu_conv2d_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.weight_type); + ggml_webgpu_hash_combine(seed, key.input_type); + ggml_webgpu_hash_combine(seed, key.output_type); + return seed; + } +}; + +/** Im2Col **/ +struct ggml_webgpu_im2col_pipeline_key { + ggml_type input_type; + ggml_type output_type; + + bool operator==(const ggml_webgpu_im2col_pipeline_key & other) const { + return input_type == other.input_type && output_type == other.output_type; + } +}; + +struct ggml_webgpu_im2col_pipeline_key_hash { + size_t operator()(const ggml_webgpu_im2col_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.input_type); + ggml_webgpu_hash_combine(seed, key.output_type); + return seed; + } +}; + +/** Gated Delta Net **/ +struct ggml_webgpu_gated_delta_net_pipeline_key { + int type; + int s_v; + int kda; + + bool operator==(const ggml_webgpu_gated_delta_net_pipeline_key & other) const { + return type == other.type && s_v == other.s_v && kda == other.kda; + } +}; + +struct ggml_webgpu_gated_delta_net_pipeline_key_hash { + size_t operator()(const ggml_webgpu_gated_delta_net_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.s_v); + ggml_webgpu_hash_combine(seed, key.kda); + return seed; + } +}; + +struct ggml_webgpu_ssm_conv_pipeline_key_hash { + size_t operator()(const ggml_webgpu_ssm_conv_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.vectorized); + return seed; + } +}; + /** Scale **/ struct ggml_webgpu_scale_pipeline_key { @@ -182,6 +406,31 @@ struct ggml_webgpu_scale_pipeline_key_hash { } }; +/** Upscale **/ + +struct ggml_webgpu_upscale_pipeline_key { + ggml_type input_type; + ggml_type output_type; + uint32_t base_mode; + bool antialias; + + bool operator==(const ggml_webgpu_upscale_pipeline_key & other) const { + return input_type == other.input_type && output_type == other.output_type && base_mode == other.base_mode && + antialias == other.antialias; + } +}; + +struct ggml_webgpu_upscale_pipeline_key_hash { + size_t operator()(const ggml_webgpu_upscale_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.input_type); + ggml_webgpu_hash_combine(seed, key.output_type); + ggml_webgpu_hash_combine(seed, key.base_mode); + ggml_webgpu_hash_combine(seed, key.antialias); + return seed; + } +}; + /** Concat **/ struct ggml_webgpu_concat_pipeline_key { @@ -244,13 +493,15 @@ struct ggml_webgpu_binary_pipeline_key_hash { /** Unary **/ struct ggml_webgpu_unary_pipeline_key { - int type; - int op; - bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella - bool inplace; + int type; + int op; + bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella + bool inplace; + ggml_tri_type ttype; // only used for GGML_OP_TRI bool operator==(const ggml_webgpu_unary_pipeline_key & other) const { - return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace; + return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace && + ttype == other.ttype; } }; @@ -261,25 +512,35 @@ struct ggml_webgpu_unary_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.op); ggml_webgpu_hash_combine(seed, key.is_unary); ggml_webgpu_hash_combine(seed, key.inplace); + ggml_webgpu_hash_combine(seed, key.ttype); return seed; } }; /** FlashAttention */ +enum ggml_webgpu_flash_attn_path : uint32_t { + GGML_WEBGPU_FLASH_ATTN_PATH_NONE = 0u, + GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX = 1u, + GGML_WEBGPU_FLASH_ATTN_PATH_TILE = 2u, + GGML_WEBGPU_FLASH_ATTN_PATH_VEC = 3u, +}; + struct ggml_webgpu_flash_attn_pipeline_key { ggml_type kv_type; uint32_t head_dim_qk; uint32_t head_dim_v; bool kv_direct; + bool kv_overlap; bool has_mask; bool has_sinks; bool uses_logit_softcap; + uint32_t path; bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const { return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && - kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks && - uses_logit_softcap == other.uses_logit_softcap; + kv_direct == other.kv_direct && kv_overlap == other.kv_overlap && has_mask == other.has_mask && + has_sinks == other.has_sinks && uses_logit_softcap == other.uses_logit_softcap && path == other.path; } }; @@ -290,26 +551,105 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.head_dim_qk); ggml_webgpu_hash_combine(seed, key.head_dim_v); ggml_webgpu_hash_combine(seed, key.kv_direct); + ggml_webgpu_hash_combine(seed, key.kv_overlap); ggml_webgpu_hash_combine(seed, key.has_mask); ggml_webgpu_hash_combine(seed, key.has_sinks); ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); + ggml_webgpu_hash_combine(seed, key.path); return seed; } }; -struct ggml_webgpu_flash_attn_shader_lib_context { - ggml_webgpu_flash_attn_pipeline_key key; - uint32_t sg_mat_m; - uint32_t sg_mat_n; - uint32_t sg_mat_k; - size_t wg_mem_limit_bytes; - uint32_t max_subgroup_size; +struct ggml_webgpu_flash_attn_decisions { + uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_NONE; + uint32_t q_tile = 0; + uint32_t kv_tile = 0; + uint32_t wg_size = 0; + bool kv_direct = false; + bool kv_overlap = false; }; -struct ggml_webgpu_flash_attn_shader_decisions { - uint32_t q_tile = 0; - uint32_t kv_tile = 0; - uint32_t wg_size = 0; +inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH = 4u; +inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE = 4u; + +inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) { + if (key.path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC || key.kv_type != GGML_TYPE_F16 || + key.head_dim_qk != key.head_dim_v) { + return 1u; + } + + switch (key.head_dim_qk) { + case 64: + case 192: + case 576: + return 2u; + case 96: + return 4u; + default: + return 1u; + } +} + +inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_key( + const ggml_webgpu_shader_lib_context & context, + uint32_t path) { + const bool has_mask = context.src3 != nullptr; + const bool has_sinks = context.src4 != nullptr; + bool kv_direct = false; + if (path != GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { + uint32_t kv_direct_align = GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH; + if (path == GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) { + kv_direct_align = context.sg_mat_k; + } + kv_direct = (context.src1->type == GGML_TYPE_F16) && + (context.src0->ne[0] % std::max(1u, kv_direct_align) == 0) && + (context.src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); + } + + ggml_webgpu_flash_attn_pipeline_key key = {}; + key.kv_type = context.src1->type; + key.head_dim_qk = (uint32_t) context.src0->ne[0]; + key.head_dim_v = (uint32_t) context.src2->ne[0]; + key.kv_direct = kv_direct; + key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2); + key.has_mask = has_mask; + key.has_sinks = has_sinks; + key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f; + key.path = path; + return key; +} + +struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key { + uint32_t head_dim_v; + uint32_t wg_size; +}; + +struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash { + size_t operator()(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.head_dim_v); + ggml_webgpu_hash_combine(seed, key.wg_size); + return seed; + } +}; + +inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lhs, + const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & rhs) { + return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size; +} + +struct ggml_webgpu_flash_attn_blk_pipeline_key { + uint32_t kv_tile; + + bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const { return kv_tile == other.kv_tile; } +}; + +struct ggml_webgpu_flash_attn_blk_pipeline_key_hash { + size_t operator()(const ggml_webgpu_flash_attn_blk_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.kv_tile); + return seed; + } }; // This is exposed because it's necessary in supports_op @@ -336,6 +676,124 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; } +inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_context & context, + const ggml_webgpu_flash_attn_pipeline_key & key) { + const size_t limit_bytes = context.wg_mem_limit_bytes; + uint32_t q_tile = context.sg_mat_m; + uint32_t kv_granularity = context.sg_mat_n; + if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { + q_tile = GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE; + kv_granularity = std::max(1u, context.max_subgroup_size); + } else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { + q_tile = 1u; + kv_granularity = 8u; + } + const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; + size_t bytes_per_kv = 0; + if (!key.kv_direct) { + bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v); + } + if (key.has_mask) { + bytes_per_kv += q_tile; + } + bytes_per_kv += q_tile; + bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; + const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; + return (max_kv_tile / kv_granularity) * kv_granularity; +} + +inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( + const ggml_webgpu_shader_lib_context & context, + size_t storage_offset_alignment) { + ggml_webgpu_flash_attn_decisions decisions = {}; + const size_t alignment = std::max(1u, storage_offset_alignment); + const auto * K = context.src1; + const auto * V = context.src2; + GGML_ASSERT(K != nullptr); + GGML_ASSERT(V != nullptr); + + const auto flash_attn_tensor_offset = [](const ggml_tensor * tensor) -> size_t { + constexpr uintptr_t ptr_base_addr = 0x1000u; + const ggml_tensor * base = tensor->view_src != nullptr ? tensor->view_src : tensor; + return reinterpret_cast(base->data) - ptr_base_addr + tensor->view_offs; + }; + + const uint32_t k_offset_elems = + (uint32_t) ((flash_attn_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type)); + const uint32_t v_offset_elems = + (uint32_t) ((flash_attn_tensor_offset(V) & (alignment - 1)) / ggml_type_size(V->type)); + const bool f16_vec4_aligned = (k_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u) && + (v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u); + const bool kv_vec_type_supported = + K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; + const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) && + (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && + kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && + (context.src2->type == K->type); + const bool use_tile = context.supports_subgroups && !context.supports_subgroup_matrix && K->type == GGML_TYPE_F16 && + V->type == GGML_TYPE_F16 && f16_vec4_aligned && + (context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && + (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && !use_vec; + + decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : + use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : + context.supports_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX : + GGML_WEBGPU_FLASH_ATTN_PATH_NONE; + + if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) { + return decisions; + } + + const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path); + decisions.kv_direct = key.kv_direct; + const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key); + // invalidate if even the smallest kv_tile doesn't fit in shared memory + if (max_kv_tile == 0) { + decisions.path = GGML_WEBGPU_FLASH_ATTN_PATH_NONE; + return decisions; + } + + if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { + decisions.q_tile = 1u; + decisions.kv_tile = std::max(8u, std::min(32u, max_kv_tile)); + decisions.kv_tile = (decisions.kv_tile / 8u) * 8u; + decisions.wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); + if (decisions.kv_direct) { + decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD); + while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { + decisions.kv_tile -= 8u; + } + } + return decisions; + } + + decisions.q_tile = + decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE : context.sg_mat_m; + decisions.kv_tile = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? + std::min(64u, max_kv_tile) : + std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); + decisions.wg_size = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? + GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE : + std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); + + if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { + const uint32_t tile_kv_granularity = std::max(1u, context.max_subgroup_size); + decisions.kv_tile = + std::max(tile_kv_granularity, (decisions.kv_tile / tile_kv_granularity) * tile_kv_granularity); + } + + if (decisions.kv_direct) { + GGML_ASSERT(decisions.kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); + while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { + decisions.kv_tile -= decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? + std::max(1u, context.max_subgroup_size) : + context.sg_mat_n; + } + } + return decisions; +} + /** Matrix Multiplication **/ struct ggml_webgpu_legacy_mul_mat_pipeline_key { @@ -378,7 +836,6 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash { struct ggml_webgpu_mul_mat_vec_shader_decisions { uint32_t wg_size; - uint32_t tile_k; uint32_t outputs_per_wg; uint32_t vec_size; }; @@ -426,6 +883,120 @@ struct ggml_webgpu_mul_mat_shader_decisions { uint32_t mul_mat_wg_size; }; +/** MUL_MAT_ID **/ + +struct ggml_webgpu_mul_mat_id_pipeline_key { + ggml_type src0_type; + ggml_type src1_type; + uint32_t n_experts; + int vectorized; + + bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const { + return src0_type == other.src0_type && src1_type == other.src1_type && n_experts == other.n_experts && + vectorized == other.vectorized; + } +}; + +struct ggml_webgpu_mul_mat_id_pipeline_key_hash { + size_t operator()(const ggml_webgpu_mul_mat_id_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src0_type); + ggml_webgpu_hash_combine(seed, key.src1_type); + ggml_webgpu_hash_combine(seed, key.n_experts); + ggml_webgpu_hash_combine(seed, key.vectorized); + return seed; + } +}; + +/** Cpy **/ + +struct ggml_webgpu_cpy_pipeline_key { + ggml_type src_type; + ggml_type dst_type; + + bool operator==(const ggml_webgpu_cpy_pipeline_key & other) const { + return src_type == other.src_type && dst_type == other.dst_type; + } +}; + +struct ggml_webgpu_cpy_pipeline_key_hash { + size_t operator()(const ggml_webgpu_cpy_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src_type); + ggml_webgpu_hash_combine(seed, key.dst_type); + return seed; + } +}; + +/** Glu **/ + +struct ggml_webgpu_glu_pipeline_key { + ggml_glu_op glu_op; + ggml_type type; + bool split; + + bool operator==(const ggml_webgpu_glu_pipeline_key & other) const { + return glu_op == other.glu_op && type == other.type && split == other.split; + } +}; + +struct ggml_webgpu_glu_pipeline_key_hash { + size_t operator()(const ggml_webgpu_glu_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.glu_op); + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.split); + return seed; + } +}; + +/** Rope **/ + +struct ggml_webgpu_rope_pipeline_key { + ggml_type type; + bool inplace; + bool has_ff; + + bool operator==(const ggml_webgpu_rope_pipeline_key & other) const { + return type == other.type && inplace == other.inplace && has_ff == other.has_ff; + } +}; + +struct ggml_webgpu_rope_pipeline_key_hash { + size_t operator()(const ggml_webgpu_rope_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.inplace); + ggml_webgpu_hash_combine(seed, key.has_ff); + return seed; + } +}; + +/** SoftMax **/ + +struct ggml_webgpu_soft_max_pipeline_key { + ggml_type mask_type; + bool has_mask; + bool has_sink; + bool inplace; + + bool operator==(const ggml_webgpu_soft_max_pipeline_key & other) const { + return mask_type == other.mask_type && has_mask == other.has_mask && has_sink == other.has_sink && + inplace == other.inplace; + } +}; + +struct ggml_webgpu_soft_max_pipeline_key_hash { + size_t operator()(const ggml_webgpu_soft_max_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.mask_type); + ggml_webgpu_hash_combine(seed, key.has_mask); + ggml_webgpu_hash_combine(seed, key.has_sink); + ggml_webgpu_hash_combine(seed, key.inplace); + return seed; + } +}; + class ggml_webgpu_shader_lib { wgpu::Device device; pre_wgsl::Preprocessor preprocessor; @@ -435,22 +1006,43 @@ class ggml_webgpu_shader_lib { std::unordered_map argsort_pipelines; // key is order std::unordered_map argsort_merge_pipelines; // key is order std::unordered_map cumsum_pipelines; // key is fixed, no variants yet + std::unordered_map + row_norm_pipelines; // op/inplace + std::unordered_map - get_rows_pipelines; // src_type, vectorized + get_rows_pipelines; // src_type, vectorized std::unordered_map - unary_pipelines; // type/op/inplace + unary_pipelines; // type/op/inplace std::unordered_map - scale_pipelines; // inplace + scale_pipelines; // inplace + std::unordered_map + solve_tri_pipelines; // type + std::unordered_map + ssm_conv_pipelines; // type/vectorized + std::unordered_map + ssm_scan_pipelines; // type/d_state + std::unordered_map + gated_delta_net_pipelines; // type/S_v/kda std::unordered_map - pad_pipelines; // circular/non-circular + pad_pipelines; // circular/non-circular std::unordered_map - binary_pipelines; // type/op/inplace/overlap + binary_pipelines; // type/op/inplace/overlap std::unordered_map - concat_pipelines; // type + concat_pipelines; // type std::unordered_map - repeat_pipelines; // type + repeat_pipelines; // type std::unordered_map flash_attn_pipelines; + std::unordered_map + flash_attn_vec_reduce_pipelines; + std::unordered_map + flash_attn_blk_pipelines; std::unordered_map @@ -458,10 +1050,33 @@ class ggml_webgpu_shader_lib { std::unordered_map mul_mat_vec_pipelines; // fast mat-vec (n==1) std::unordered_map - mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup) + mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup) + std::unordered_map mul_mat_id_gather_pipelines; // key is fixed + std::unordered_map + mul_mat_id_pipelines; // src0_type/src1_type + std::unordered_map + mul_mat_id_vec_pipelines; // src0_type/src1_type std::unordered_map set_rows_pipelines; + std::unordered_map set_pipelines; + std::unordered_map cpy_pipelines; + std::unordered_map glu_pipelines; + std::unordered_map + rope_pipelines; + std::unordered_map + soft_max_pipelines; + std::unordered_map + conv2d_pipelines; + std::unordered_map + im2col_pipelines; + + std::unordered_map + rms_norm_mul_pipelines; + std::unordered_map + upscale_pipelines; public: ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; } @@ -479,8 +1094,50 @@ class ggml_webgpu_shader_lib { return sum_rows_pipelines[1]; } - webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) { - bool vec4 = context.src0->ne[0] % 4 == 0; + webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_row_norm_pipeline_key key = {}; + key.op = context.dst->op; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); + + auto it = row_norm_pipelines.find(key); + if (it != row_norm_pipelines.end()) { + return it->second; + } + std::vector defines; + std::string variant; + + switch (key.op) { + case GGML_OP_RMS_NORM: + defines.push_back("RMS_NORM"); + variant = "rms_norm"; + break; + case GGML_OP_L2_NORM: + defines.push_back("L2_NORM"); + variant = "l2_norm"; + break; + default: + GGML_ABORT("Unsupported op for row_norm shader"); + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + const uint32_t row_norm_wg_size = 128u; + uint32_t wg_size = std::min(context.max_wg_size, row_norm_wg_size); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + auto processed = preprocessor.preprocess(wgsl_row_norm, defines); + auto decisions = std::make_shared(); + decisions->wg_size = wg_size; + decisions->inplace = key.inplace; + row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant); + row_norm_pipelines[key].context = decisions; + return row_norm_pipelines[key]; + } + + webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) { + bool vec4 = context.src0->ne[0] % 4 == 0; auto it = argmax_pipelines.find(vec4); if (it != argmax_pipelines.end()) { @@ -500,9 +1157,10 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_set_rows_pipeline_key key = { .dst_type = context.dst->type, - .vec4 = context.src0->ne[0] % 4 == 0, - .i64_idx = context.src1->type == GGML_TYPE_I64 }; + ggml_webgpu_set_rows_pipeline_key key = {}; + key.dst_type = context.dst->type; + key.vec4 = context.src0->ne[0] % 4 == 0; + key.i64_idx = context.src1->type == GGML_TYPE_I64; auto it = set_rows_pipelines.find(key); if (it != set_rows_pipelines.end()) { @@ -546,6 +1204,49 @@ class ggml_webgpu_shader_lib { return set_rows_pipelines[key]; } + webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_set_pipeline_key key = {}; + key.type = context.dst->type; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); + + auto it = set_pipelines.find(key); + if (it != set_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "set"; + + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_I32: + defines.push_back("TYPE_I32"); + variant += "_i32"; + break; + default: + GGML_ABORT("Unsupported type for set shader"); + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_set, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + decisions->inplace = key.inplace; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + set_pipelines[key] = pipeline; + return set_pipelines[key]; + } + webgpu_pipeline get_cumsum_pipeline(const ggml_webgpu_shader_lib_context & context) { auto it = cumsum_pipelines.find(1); if (it != cumsum_pipelines.end()) { @@ -614,10 +1315,9 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_get_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { const bool vectorized = context.src0->type == GGML_TYPE_F32 && context.dst->ne[0] % 4 == 0; - ggml_webgpu_get_rows_pipeline_key key = { - .src_type = context.src0->type, - .vectorized = (int) vectorized, - }; + ggml_webgpu_get_rows_pipeline_key key = {}; + key.src_type = context.src0->type; + key.vectorized = (int) vectorized; auto it = get_rows_pipelines.find(key); if (it != get_rows_pipelines.end()) { @@ -632,6 +1332,7 @@ class ggml_webgpu_shader_lib { switch (key.src_type) { case GGML_TYPE_F32: + defines.push_back("FLOAT_PARALLEL"); if (key.vectorized) { defines.push_back("F32_VEC"); defines.push_back("SRC_TYPE=vec4"); @@ -646,6 +1347,7 @@ class ggml_webgpu_shader_lib { variant += "_f32"; break; case GGML_TYPE_F16: + defines.push_back("FLOAT_PARALLEL"); defines.push_back("F16"); defines.push_back("SRC_TYPE=f16"); defines.push_back("DST_TYPE=f32"); @@ -653,6 +1355,7 @@ class ggml_webgpu_shader_lib { variant += "_f16"; break; case GGML_TYPE_I32: + defines.push_back("FLOAT_PARALLEL"); defines.push_back("I32"); defines.push_back("SRC_TYPE=i32"); defines.push_back("DST_TYPE=i32"); @@ -664,6 +1367,32 @@ class ggml_webgpu_shader_lib { std::string type_upper = type_str; std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + switch (key.src_type) { + case GGML_TYPE_Q1_0: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ4_NL: + { + // Quantized types using u32 buffers for portability. + defines.push_back("SRC_TYPE=u32"); + defines.push_back("U32_DEQUANT_HELPERS"); + break; + } + default: + { + defines.push_back(std::string("SRC_TYPE=") + type_str); + } + } + defines.push_back("BYTE_HELPERS"); defines.push_back(type_upper + "_T"); defines.push_back(type_upper); @@ -674,11 +1403,12 @@ class ggml_webgpu_shader_lib { variant += "_"; variant += type_str; - defines.push_back(std::string("SRC_TYPE=") + type_str); defines.push_back("DST_TYPE=f32"); - if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) || - key.src_type == GGML_TYPE_IQ4_NL) { + if (key.src_type == GGML_TYPE_Q1_0) { + defines.push_back("BLOCK_SIZE=128u"); + } else if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) || + key.src_type == GGML_TYPE_IQ4_NL) { defines.push_back("BLOCK_SIZE=32u"); } else if (key.src_type >= GGML_TYPE_Q2_K) { defines.push_back("BLOCK_SIZE=256u"); @@ -705,7 +1435,8 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_scale_pipeline_key key = { .inplace = context.inplace }; + ggml_webgpu_scale_pipeline_key key = {}; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); auto it = scale_pipelines.find(key); if (it != scale_pipelines.end()) { @@ -725,14 +1456,189 @@ class ggml_webgpu_shader_lib { auto processed = preprocessor.preprocess(wgsl_scale, defines); auto decisions = std::make_shared(); decisions->wg_size = context.max_wg_size; + decisions->inplace = key.inplace; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; scale_pipelines[key] = pipeline; return scale_pipelines[key]; } + webgpu_pipeline get_solve_tri_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_solve_tri_pipeline_key key = {}; + key.type = context.dst->type; + key.n = (int) context.src0->ne[0]; + key.k = (int) context.src1->ne[0]; + + auto it = solve_tri_pipelines.find(key); + if (it != solve_tri_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "solve_tri"; + + switch (key.type) { + case GGML_TYPE_F32: + variant += "_f32"; + break; + default: + GGML_ABORT("Unsupported type for solve_tri shader"); + } + + const uint32_t wg_size = std::min((uint32_t) key.n, context.max_wg_size); + const uint32_t k_tile = wg_size; + const uint32_t bytes_per_row = ((uint32_t) key.n + wg_size) * GGML_WEBGPU_F32_SIZE_BYTES; + const uint32_t batch_n = (uint32_t) (context.wg_mem_limit_bytes / bytes_per_row); + + defines.push_back(std::string("N=") + std::to_string(key.n)); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + defines.push_back(std::string("K_TILE=") + std::to_string(k_tile)); + defines.push_back(std::string("BATCH_N=") + std::to_string(batch_n)); + + auto processed = preprocessor.preprocess(wgsl_solve_tri, defines); + auto decisions = std::make_shared(); + decisions->wg_size = wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + solve_tri_pipelines[key] = pipeline; + return solve_tri_pipelines[key]; + } + + webgpu_pipeline get_ssm_conv_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_ssm_conv_pipeline_key key = {}; + key.type = context.dst->type; + key.vectorized = context.src1->ne[0] == 4; + + auto it = ssm_conv_pipelines.find(key); + if (it != ssm_conv_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "ssm_conv"; + + switch (key.type) { + case GGML_TYPE_F32: + variant += "_f32"; + break; + default: + GGML_ABORT("Unsupported type for ssm_conv shader"); + } + + if (key.vectorized) { + defines.push_back("VECTORIZED"); + variant += "_vec4"; + } + + constexpr uint32_t block_size = 32u; + constexpr uint32_t tokens_per_wg = 8u; + + defines.push_back("BLOCK_SIZE=" + std::to_string(block_size) + "u"); + defines.push_back("TOKENS_PER_WG=" + std::to_string(tokens_per_wg) + "u"); + + auto processed = preprocessor.preprocess(wgsl_ssm_conv, defines); + auto decisions = std::make_shared(); + decisions->block_size = block_size; + decisions->tokens_per_wg = tokens_per_wg; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + ssm_conv_pipelines[key] = pipeline; + return ssm_conv_pipelines[key]; + } + + webgpu_pipeline get_ssm_scan_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_ssm_scan_pipeline_key key = {}; + key.type = context.dst->type; + key.d_state = (int) context.src0->ne[0]; + key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) && + ggml_webgpu_tensor_overlap(context.src1, context.src5); + + auto it = ssm_scan_pipelines.find(key); + if (it != ssm_scan_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "ssm_scan"; + + switch (key.type) { + case GGML_TYPE_F32: + variant += "_f32"; + break; + default: + GGML_ABORT("Unsupported type for ssm_scan shader"); + } + + const uint32_t wg_size = (uint32_t) key.d_state; + + constexpr uint32_t tokens_per_tile = 4u; + + defines.push_back("WG_SIZE=" + std::to_string(wg_size) + "u"); + defines.push_back("TOKENS_PER_TILE=" + std::to_string(tokens_per_tile) + "u"); + + if (context.supports_subgroups) { + defines.push_back("USE_SUBGROUP_REDUCTION"); + variant += "_sg_reduce"; + } else { + variant += "_wg_reduce"; + } + + if (key.xbc_overlap) { + defines.push_back("XBC_OVERLAP"); + } + + variant += "_d" + std::to_string(key.d_state); + + auto processed = preprocessor.preprocess(wgsl_ssm_scan, defines); + auto decisions = std::make_shared(); + decisions->wg_size = wg_size; + decisions->tokens_per_tile = tokens_per_tile; + decisions->xbc_overlap = key.xbc_overlap; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + ssm_scan_pipelines[key] = pipeline; + return ssm_scan_pipelines[key]; + } + + webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_gated_delta_net_pipeline_key key = {}; + key.type = context.dst->type; + key.s_v = (int) context.src2->ne[0]; + key.kda = context.src3->ne[0] == context.src2->ne[0]; + + auto it = gated_delta_net_pipelines.find(key); + if (it != gated_delta_net_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "gated_delta_net"; + + switch (key.type) { + case GGML_TYPE_F32: + variant += "_f32"; + break; + default: + GGML_ABORT("Unsupported type for gated_delta_net shader"); + } + + if (key.kda) { + defines.push_back("KDA"); + variant += "_kda"; + } + + defines.push_back("S_V=" + std::to_string(key.s_v) + "u"); + defines.push_back("WG_SIZE=" + std::to_string(key.s_v) + "u"); + + auto processed = preprocessor.preprocess(wgsl_gated_delta_net, defines); + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + gated_delta_net_pipelines[key] = pipeline; + return gated_delta_net_pipelines[key]; + } + webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 }; + ggml_webgpu_pad_pipeline_key key = {}; + key.circular = ggml_get_op_params_i32(context.dst, 8) != 0; auto it = pad_pipelines.find(key); if (it != pad_pipelines.end()) { @@ -759,15 +1665,13 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_mul_mat_vec_pipeline_key key = { - .src0_type = context.src0->type, - .src1_type = context.src1->type, - // Quantized mat-vec path currently runs scalar; only allow vectorization when both inputs are float - .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && - (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0, - }; + ggml_webgpu_mul_mat_vec_pipeline_key key = {}; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; + key.vectorized = (context.src0->ne[0] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0; auto it = mul_mat_vec_pipelines.find(key); if (it != mul_mat_vec_pipelines.end()) { @@ -775,7 +1679,8 @@ class ggml_webgpu_shader_lib { } std::vector defines; - std::string variant = "mul_mat_vec"; + std::string variant = "mul_mat_vec"; + const char * shader_src = wgsl_mul_mat_vec; // src0 type (matrix row) switch (context.src0->type) { @@ -800,9 +1705,26 @@ class ggml_webgpu_shader_lib { defines.push_back("BYTE_HELPERS"); defines.push_back("MUL_ACC_" + type_upper); - - // For fast path we always dequantize from f16 inside the shader - defines.push_back("SRC0_INNER_TYPE=f16"); + defines.push_back("U32_DEQUANT_HELPERS"); + defines.push_back("SRC0_INNER_TYPE=u32"); + switch (context.src0->type) { + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + defines.push_back(type_upper + "_GRID"); + break; + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ3_XXS: + defines.push_back(type_upper + "_GRID"); + defines.push_back(type_upper + "_TABLES"); + break; + default: + break; + } break; } } @@ -825,25 +1747,27 @@ class ggml_webgpu_shader_lib { defines.push_back(key.vectorized ? "VEC" : "SCALAR"); uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE; - uint32_t tile_k = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K; uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG; - if (key.src0_type >= GGML_TYPE_Q2_K) { - tile_k = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K; + if (key.src0_type == GGML_TYPE_Q1_0) { + outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG; + } else if (key.src0_type >= GGML_TYPE_Q2_K) { outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG; } else if (key.src0_type >= GGML_TYPE_Q4_0) { - tile_k = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K; outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG; } defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - defines.push_back(std::string("TILE_K=") + std::to_string(tile_k)); defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg)); + defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION"); + variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce"; + if (key.vectorized) { + variant += "_vectorized"; + } - auto processed = preprocessor.preprocess(wgsl_mul_mat_vec, defines); + auto processed = preprocessor.preprocess(shader_src, defines); auto decisions = std::make_shared(); decisions->wg_size = wg_size; - decisions->tile_k = tile_k; decisions->outputs_per_wg = outputs_per_wg; decisions->vec_size = key.vectorized ? 4 : 1; @@ -854,15 +1778,14 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_mul_mat_pipeline_key key = { - .src0_type = context.src0->type, - .src1_type = context.src1->type, - .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 && - (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0, - .use_subgroup_matrix = context.supports_subgroup_matrix - }; + ggml_webgpu_mul_mat_pipeline_key key = {}; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; + key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0; + key.use_subgroup_matrix = context.supports_subgroup_matrix; auto it = mul_mat_fast_pipelines.find(key); if (it != mul_mat_fast_pipelines.end()) { @@ -915,9 +1838,27 @@ class ggml_webgpu_shader_lib { defines.push_back("MUL_ACC_" + type_upper); defines.push_back("INIT_SRC0_SHMEM_" + type_upper); defines.push_back("INIT_SRC1_SHMEM_FLOAT"); - - // Use f16 inside the shader for quantized types - defines.push_back("SRC0_INNER_TYPE=f16"); + defines.push_back("U32_DEQUANT_HELPERS"); + defines.push_back("SRC0_INNER_TYPE=u32"); + + switch (context.src0->type) { + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + defines.push_back(type_upper + "_GRID"); + break; + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + defines.push_back(type_upper + "_GRID"); + defines.push_back(type_upper + "_TABLES"); + break; + default: + break; + } variant += std::string("_") + src0_name; break; @@ -927,13 +1868,22 @@ class ggml_webgpu_shader_lib { // VEC/SCALAR controls defines.push_back(key.vectorized ? "VEC" : "SCALAR"); + const bool is_quant = ggml_is_quantized(context.src0->type); + + uint32_t tile_k; + if (key.use_subgroup_matrix) { + tile_k = is_quant ? WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT : WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT; + } else { + tile_k = is_quant ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT; + } + // Tiles defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u"); defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u"); - defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u"); // Subgroup matrix specifics if (key.use_subgroup_matrix) { + defines.push_back("TILE_K=" + std::to_string(tile_k) + "u"); defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u"); defines.push_back("SUBGROUP_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + "u"); defines.push_back("SUBGROUP_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + "u"); @@ -953,12 +1903,13 @@ class ggml_webgpu_shader_lib { if (!key.use_subgroup_matrix) { defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u"); defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u"); + defines.push_back("TILE_K=" + std::to_string(tile_k) + "u"); } auto processed = preprocessor.preprocess(shader_src, defines); auto decisions = std::make_shared(); - decisions->tile_k = WEBGPU_MUL_MAT_TILE_K; + decisions->tile_k = tile_k; decisions->tile_m = WEBGPU_MUL_MAT_TILE_M; decisions->tile_n = WEBGPU_MUL_MAT_TILE_N; decisions->use_subgroup_matrix = key.use_subgroup_matrix; @@ -982,8 +1933,9 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_mul_mat_legacy_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_legacy_mul_mat_pipeline_key key = { .src0_type = context.src0->type, - .src1_type = context.src1->type }; + ggml_webgpu_legacy_mul_mat_pipeline_key key = {}; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; auto it = mul_mat_legacy_pipelines.find(key); if (it != mul_mat_legacy_pipelines.end()) { @@ -1022,11 +1974,34 @@ class ggml_webgpu_shader_lib { break; default: { - // quantized types std::string type_upper = src0_name; std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); - defines.push_back(std::string("SRC0_TYPE=") + src0_name); + switch (context.src0->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ4_NL: + { + // Quantized types using u32 buffers for portability. + defines.push_back("SRC0_TYPE=u32"); + defines.push_back("U32_DEQUANT_HELPERS"); + break; + } + default: + { + defines.push_back(std::string("SRC0_TYPE=") + src0_name); + } + } + defines.push_back("BYTE_HELPERS"); defines.push_back(type_upper + "_T"); defines.push_back(type_upper); @@ -1050,79 +2025,397 @@ class ggml_webgpu_shader_lib { return mul_mat_legacy_pipelines[key]; } - webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) { - const bool is_unary = context.dst->op == GGML_OP_UNARY; - const int op = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op; - ggml_webgpu_unary_pipeline_key key = { - .type = context.dst->type, - .op = op, - .is_unary = is_unary, - .inplace = context.inplace, - }; + webgpu_pipeline get_mul_mat_id_gather_pipeline(const ggml_webgpu_shader_lib_context & context) { + auto it = mul_mat_id_gather_pipelines.find(1); + if (it != mul_mat_id_gather_pipelines.end()) { + return it->second; + } + std::vector defines; + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - auto it = unary_pipelines.find(key); - if (it != unary_pipelines.end()) { + auto processed = preprocessor.preprocess(wgsl_mul_mat_id_gather, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, "mul_mat_id_gather"); + pipeline.context = decisions; + mul_mat_id_gather_pipelines[1] = pipeline; + return pipeline; + } + + webgpu_pipeline get_mul_mat_id_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_mul_mat_id_pipeline_key key = {}; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; + key.n_experts = context.src0->ne[2]; + key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0; + + auto it = mul_mat_id_pipelines.find(key); + if (it != mul_mat_id_pipelines.end()) { return it->second; } std::vector defines; - std::string variant = - key.is_unary ? ggml_unary_op_name((ggml_unary_op) key.op) : ggml_op_name((ggml_op) key.op); - defines.push_back(variant); + std::string variant = "mul_mat_id"; + defines.push_back("MUL_MAT_ID"); - switch (key.type) { + // src1 type + switch (context.src1->type) { case GGML_TYPE_F32: - defines.push_back("TYPE_F32"); + defines.push_back("SRC1_INNER_TYPE=f32"); + break; + case GGML_TYPE_F16: + defines.push_back("SRC1_INNER_TYPE=f16"); + break; + default: + GGML_ABORT("Unsupported src1 type for mul_mat fast shader"); + } + + // src0 type + const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); + const char * src0_name = src0_traits->type_name; + + switch (context.src0->type) { + case GGML_TYPE_F32: + defines.push_back("SRC0_INNER_TYPE=f32"); + defines.push_back("INIT_SRC0_SHMEM_FLOAT"); + defines.push_back("INIT_SRC1_SHMEM_FLOAT"); variant += "_f32"; break; case GGML_TYPE_F16: - defines.push_back("TYPE_F16"); + defines.push_back("SRC0_INNER_TYPE=f16"); + defines.push_back("INIT_SRC0_SHMEM_FLOAT"); + defines.push_back("INIT_SRC1_SHMEM_FLOAT"); variant += "_f16"; break; default: - GGML_ABORT("Unsupported type for unary shader"); + { + std::string type_upper = src0_name; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + defines.push_back("BYTE_HELPERS"); + defines.push_back("INIT_SRC0_SHMEM_" + type_upper); + defines.push_back("INIT_SRC1_SHMEM_FLOAT"); + defines.push_back("U32_DEQUANT_HELPERS"); + defines.push_back("SRC0_INNER_TYPE=u32"); + + switch (context.src0->type) { + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + defines.push_back(type_upper + "_GRID"); + break; + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + defines.push_back(type_upper + "_GRID"); + defines.push_back(type_upper + "_TABLES"); + break; + default: + break; + } + + variant += std::string("_") + src0_name; + break; + } } - if (key.inplace) { - defines.push_back("INPLACE"); - variant += "_inplace"; + // VEC/SCALAR controls + defines.push_back(key.vectorized ? "VEC" : "SCALAR"); + + // mul_mat_id is register-tile only. + const uint32_t tile_k = + ggml_is_quantized(context.src0->type) ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT; + + // Tiles + defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u"); + defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u"); + defines.push_back("TILE_K=" + std::to_string(tile_k) + "u"); + + defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u"); + defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u"); + + // variant suffix for src1 type + variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16"); + if (key.vectorized) { + variant += "_vectorized"; } - defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + auto processed = preprocessor.preprocess(wgsl_mul_mat_id, defines); - auto processed = preprocessor.preprocess(wgsl_unary, defines); - auto decisions = std::make_shared(); - decisions->wg_size = context.max_wg_size; - webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); - pipeline.context = decisions; - unary_pipelines[key] = pipeline; - return unary_pipelines[key]; - } + auto decisions = std::make_shared(); + decisions->tile_k = tile_k; + decisions->tile_m = WEBGPU_MUL_MAT_TILE_M; + decisions->tile_n = WEBGPU_MUL_MAT_TILE_N; + decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M; + decisions->wg_size_n = WEBGPU_MUL_MAT_WG_SIZE_N; + decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE_M * WEBGPU_MUL_MAT_WG_SIZE_N; - webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_binary_pipeline_key key = { - .type = context.dst->type, - .op = context.dst->op, - .inplace = context.inplace, - .overlap = context.overlap, - .src_overlap = context.src_overlap, - }; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + mul_mat_id_pipelines[key] = pipeline; + return mul_mat_id_pipelines[key]; + } - auto it = binary_pipelines.find(key); - if (it != binary_pipelines.end()) { + webgpu_pipeline get_mul_mat_id_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_mul_mat_id_pipeline_key key = {}; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; + key.n_experts = context.src0->ne[2]; + key.vectorized = (context.src0->ne[0] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0; + + auto it = mul_mat_id_vec_pipelines.find(key); + if (it != mul_mat_id_vec_pipelines.end()) { return it->second; } std::vector defines; - std::string op_name = ggml_op_name((ggml_op) key.op); - std::string variant = op_name; - - defines.push_back(std::string("OP_") + op_name); + std::string variant = "mul_mat_id_vec"; + const char * shader_src = wgsl_mul_mat_id_vec; - switch (key.type) { + // src1 type + switch (context.src1->type) { case GGML_TYPE_F32: - defines.push_back("TYPE_F32"); - variant += "_f32"; + defines.push_back("SRC1_INNER_TYPE=f32"); + break; + case GGML_TYPE_F16: + defines.push_back("SRC1_INNER_TYPE=f16"); + break; + default: + GGML_ABORT("Unsupported src1 type for mul_mat fast shader"); + } + + // src0 type + switch (context.src0->type) { + case GGML_TYPE_F32: + defines.push_back("SRC0_INNER_TYPE=f32"); + defines.push_back("MUL_ACC_FLOAT"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC0_INNER_TYPE=f16"); + defines.push_back("MUL_ACC_FLOAT"); + variant += "_f16"; + break; + default: + { + // Quantized types: use helpers but accumulate in f16 + const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); + std::string src0_name = src0_traits->type_name; + std::string type_upper = src0_name; + variant += "_" + src0_name; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + defines.push_back("BYTE_HELPERS"); + defines.push_back("MUL_ACC_" + type_upper); + defines.push_back("U32_DEQUANT_HELPERS"); + defines.push_back("SRC0_INNER_TYPE=u32"); + switch (context.src0->type) { + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + defines.push_back(type_upper + "_GRID"); + break; + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ3_XXS: + defines.push_back(type_upper + "_GRID"); + defines.push_back(type_upper + "_TABLES"); + break; + default: + break; + } + break; + } + } + + // VEC/SCALAR controls + defines.push_back(key.vectorized ? "VEC" : "SCALAR"); + + uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE; + uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG; + + if (key.src0_type == GGML_TYPE_Q1_0) { + outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG; + } else if (key.src0_type >= GGML_TYPE_Q2_K) { + outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG; + } else if (key.src0_type >= GGML_TYPE_Q4_0) { + outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG; + } + + // variant suffix for src1 type + variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16"); + + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg)); + defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION"); + variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce"; + if (key.vectorized) { + variant += "_vectorized"; + } + + defines.push_back(std::string("N_EXPERTS=") + std::to_string(key.n_experts)); + + auto processed = preprocessor.preprocess(shader_src, defines); + + auto decisions = std::make_shared(); + decisions->wg_size = wg_size; + decisions->outputs_per_wg = outputs_per_wg; + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + mul_mat_id_vec_pipelines[key] = pipeline; + return mul_mat_id_vec_pipelines[key]; + } + + webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) { + const bool is_unary = context.dst->op == GGML_OP_UNARY; + const int op = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op; + ggml_webgpu_unary_pipeline_key key = {}; + key.type = context.dst->type; + key.op = op; + key.is_unary = is_unary; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst) || context.dst->op == GGML_OP_FILL; + key.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0); + + auto it = unary_pipelines.find(key); + if (it != unary_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = + key.is_unary ? ggml_unary_op_name((ggml_unary_op) key.op) : ggml_op_name((ggml_op) key.op); + defines.push_back(variant); + + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("TYPE_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported type for unary shader"); + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + if (op == GGML_OP_TRI) { + switch (key.ttype) { + case GGML_TRI_TYPE_LOWER: + defines.push_back("TRI_TYPE_LOWER"); + variant += "_tri_type_lower"; + break; + case GGML_TRI_TYPE_LOWER_DIAG: + defines.push_back("TRI_TYPE_LOWER_DIAG"); + variant += "_tri_type_lower_diag"; + break; + case GGML_TRI_TYPE_UPPER: + defines.push_back("TRI_TYPE_UPPER"); + variant += "_tri_type_upper"; + break; + case GGML_TRI_TYPE_UPPER_DIAG: + defines.push_back("TRI_TYPE_UPPER_DIAG"); + variant += "_tri_upper_diag"; + break; + default: + GGML_ABORT("Unsupported ggml_tri_type for unary shader"); + } + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_unary, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + decisions->inplace = key.inplace; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + unary_pipelines[key] = pipeline; + return unary_pipelines[key]; + } + + webgpu_pipeline get_rms_norm_mul_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_rms_norm_mul_pipeline_key key = {}; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); + key.overlap = ggml_webgpu_tensor_equal(context.src1, context.dst); + key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1); + + auto it = rms_norm_mul_pipelines.find(key); + if (it != rms_norm_mul_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string op_name = "RMS_NORM_MUL"; + std::string variant = op_name; + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } else if (key.overlap) { + defines.push_back("OVERLAP"); + variant += "_overlap"; + } else if (key.src_overlap) { + defines.push_back("SRC_OVERLAP"); + variant += "_src_overlap"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_rms_norm_mul, defines); + auto pipeline_decisions = std::make_shared(); + pipeline_decisions->wg_size = context.max_wg_size; + pipeline_decisions->inplace = key.inplace; + pipeline_decisions->overlap = key.overlap; + pipeline_decisions->src_overlap = key.src_overlap; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = pipeline_decisions; + rms_norm_mul_pipelines[key] = pipeline; + return rms_norm_mul_pipelines[key]; + } + + webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_binary_pipeline_key key = {}; + key.type = context.dst->type; + key.op = context.dst->op; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); + key.overlap = ggml_webgpu_tensor_equal(context.src1, context.dst); + key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1); + + auto it = binary_pipelines.find(key); + if (it != binary_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string op_name = ggml_op_name((ggml_op) key.op); + std::string variant = op_name; + + defines.push_back(std::string("OP_") + op_name); + + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; break; case GGML_TYPE_F16: defines.push_back("TYPE_F16"); @@ -1145,19 +2438,22 @@ class ggml_webgpu_shader_lib { defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - auto processed = preprocessor.preprocess(wgsl_binary, defines); - auto decisions = std::make_shared(); - decisions->wg_size = context.max_wg_size; + auto processed = preprocessor.preprocess(wgsl_binary, defines); + auto pipeline_decisions = std::make_shared(); + pipeline_decisions->wg_size = context.max_wg_size; + pipeline_decisions->inplace = key.inplace; + pipeline_decisions->overlap = key.overlap; + pipeline_decisions->src_overlap = key.src_overlap; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); - pipeline.context = decisions; + pipeline.context = pipeline_decisions; binary_pipelines[key] = pipeline; return binary_pipelines[key]; } webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_concat_pipeline_key key = { - .type = context.dst->type, - }; + ggml_webgpu_concat_pipeline_key key = {}; + key.type = context.dst->type; auto it = concat_pipelines.find(key); if (it != concat_pipelines.end()) { @@ -1192,9 +2488,8 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_repeat_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_repeat_pipeline_key key = { - .type = context.dst->type, - }; + ggml_webgpu_repeat_pipeline_key key = {}; + key.type = context.dst->type; auto it = repeat_pipelines.find(key); if (it != repeat_pipelines.end()) { @@ -1232,30 +2527,20 @@ class ggml_webgpu_shader_lib { return repeat_pipelines[key]; } - webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) { - const bool has_mask = context.src3 != nullptr; - const bool has_sinks = context.src4 != nullptr; - - bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) && - (context.src1->ne[1] % context.sg_mat_n == 0); - - ggml_webgpu_flash_attn_pipeline_key key = { - .kv_type = context.src1->type, - .head_dim_qk = (uint32_t) context.src0->ne[0], - .head_dim_v = (uint32_t) context.src2->ne[0], - .kv_direct = kv_direct, - .has_mask = has_mask, - .has_sinks = has_sinks, - .uses_logit_softcap = (*(float *) &context.dst->op_params[2]) != 0.0f, - }; - - auto it = flash_attn_pipelines.find(key); + webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context, + size_t storage_offset_alignment) { + const ggml_webgpu_flash_attn_decisions decisions = + ggml_webgpu_flash_attn_get_decisions(context, storage_offset_alignment); + GGML_ASSERT(decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_NONE); + ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path); + auto it = flash_attn_pipelines.find(key); if (it != flash_attn_pipelines.end()) { return it->second; } - std::vector defines; - std::string variant = "flash_attn"; + std::string variant = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC ? "flash_attn_vec" : + decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? "flash_attn_tile" : + "flash_attn"; switch (key.kv_type) { case GGML_TYPE_F32: @@ -1277,7 +2562,12 @@ class ggml_webgpu_shader_lib { if (key.has_mask) { defines.push_back("MASK"); - variant += "_mask"; + if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { + defines.push_back("BLK"); + variant += "_mask_blk"; + } else { + variant += "_mask"; + } } if (key.has_sinks) { defines.push_back("SINKS"); @@ -1291,6 +2581,10 @@ class ggml_webgpu_shader_lib { defines.push_back("KV_DIRECT"); variant += "_kvdirect"; } + if (key.kv_overlap) { + defines.push_back("KV_OVERLAP"); + variant += "_kv_overlap"; + } defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); @@ -1298,37 +2592,456 @@ class ggml_webgpu_shader_lib { defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); variant += std::string("_hsv") + std::to_string(key.head_dim_v); - defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); - defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); - defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); + const char * shader_src = wgsl_flash_attn; + if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { + defines.push_back("KV_GRANULARITY=8"); + defines.push_back(std::string("VEC_NE=") + std::to_string(ggml_webgpu_flash_attn_pick_vec_ne(key)) + "u"); + shader_src = wgsl_flash_attn_vec_split; + } else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { + shader_src = wgsl_flash_attn_tile; + defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size)); + defines.push_back("KV_STAGE_STRIDE=" + std::to_string(std::max(key.head_dim_qk, key.head_dim_v))); + variant += "_tile"; + } else { + defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); + defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); + defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); + } - uint32_t q_tile = context.sg_mat_m; - uint32_t kv_tile = - std::min(ggml_webgpu_flash_attn_max_kv_tile({ key, context.sg_mat_m, context.sg_mat_n, context.sg_mat_k, - context.wg_mem_limit_bytes, context.max_subgroup_size }), - context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); - if (key.kv_direct) { - while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { - kv_tile -= context.sg_mat_n; - } + auto pipeline_decisions = std::make_shared(decisions); + pipeline_decisions->kv_overlap = key.kv_overlap; + defines.push_back(std::string("Q_TILE=") + std::to_string(decisions.q_tile)); + defines.push_back(std::string("KV_TILE=") + std::to_string(decisions.kv_tile)); + defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions.wg_size)); + + webgpu_pipeline pipeline = + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant); + pipeline.context = pipeline_decisions; + flash_attn_pipelines[key] = pipeline; + return flash_attn_pipelines[key]; + } + + webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context, uint32_t kv_tile) { + ggml_webgpu_flash_attn_blk_pipeline_key key = {}; + key.kv_tile = kv_tile; + auto it = flash_attn_blk_pipelines.find(key); + if (it != flash_attn_blk_pipelines.end()) { + return it->second; } - defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); - defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); + std::vector defines; + std::string variant = "flash_attn_vec_blk"; + + defines.push_back(std::string("KV_TILE=") + std::to_string(key.kv_tile)); + variant += std::string("_kvt") + std::to_string(key.kv_tile); - uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); + uint32_t wg_size = 1; + while ((wg_size << 1) <= context.max_wg_size) { + wg_size <<= 1; + } defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + variant += std::string("_wg") + std::to_string(wg_size); - auto processed = preprocessor.preprocess(wgsl_flash_attn, defines); - auto decisions = std::make_shared(); - decisions->q_tile = q_tile; - decisions->kv_tile = kv_tile; - decisions->wg_size = wg_size; + webgpu_pipeline pipeline = + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_blk, defines), variant); + flash_attn_blk_pipelines[key] = pipeline; + return flash_attn_blk_pipelines[key]; + } - webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); - pipeline.context = decisions; - flash_attn_pipelines[key] = pipeline; - return flash_attn_pipelines[key]; + webgpu_pipeline get_flash_attn_vec_reduce_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_flash_attn_vec_reduce_pipeline_key key = {}; + key.head_dim_v = (uint32_t) context.src2->ne[0]; + key.wg_size = context.max_wg_size; + auto it = flash_attn_vec_reduce_pipelines.find(key); + if (it != flash_attn_vec_reduce_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "flash_attn_vec_reduce"; + + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(key.head_dim_v); + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + variant += std::string("_wg") + std::to_string(context.max_wg_size); + + webgpu_pipeline pipeline = + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_reduce, defines), variant); + flash_attn_vec_reduce_pipelines[key] = pipeline; + return flash_attn_vec_reduce_pipelines[key]; + } + + webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_cpy_pipeline_key key = {}; + key.src_type = context.src0->type; + key.dst_type = context.dst->type; + + auto it = cpy_pipelines.find(key); + if (it != cpy_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "cpy"; + + switch (key.src_type) { + case GGML_TYPE_F32: + defines.push_back("SRC_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported src type for cpy shader"); + } + + switch (key.dst_type) { + case GGML_TYPE_F32: + defines.push_back("DST_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("DST_F16"); + variant += "_f16"; + break; + case GGML_TYPE_I32: + defines.push_back("DST_I32"); + variant += "_i32"; + break; + default: + GGML_ABORT("Unsupported dst type for cpy shader"); + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_cpy, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + cpy_pipelines[key] = pipeline; + return cpy_pipelines[key]; + } + + webgpu_pipeline get_glu_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_glu_pipeline_key key = {}; + key.glu_op = ggml_get_glu_op(context.dst); + key.type = context.dst->type; + key.split = (context.src1 != nullptr); + + auto it = glu_pipelines.find(key); + if (it != glu_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "glu"; + + switch (key.glu_op) { + case GGML_GLU_OP_REGLU: + defines.push_back("OP_REGLU"); + variant += "_reglu"; + break; + case GGML_GLU_OP_GEGLU: + defines.push_back("OP_GEGLU"); + variant += "_geglu"; + break; + case GGML_GLU_OP_SWIGLU: + defines.push_back("OP_SWIGLU"); + variant += "_swiglu"; + break; + case GGML_GLU_OP_SWIGLU_OAI: + defines.push_back("OP_SWIGLU_OAI"); + variant += "_swiglu_oai"; + break; + case GGML_GLU_OP_GEGLU_ERF: + defines.push_back("OP_GEGLU_ERF"); + variant += "_geglu_erf"; + break; + case GGML_GLU_OP_GEGLU_QUICK: + defines.push_back("OP_GEGLU_QUICK"); + variant += "_geglu_quick"; + break; + default: + GGML_ABORT("Unsupported GLU op"); + } + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("TYPE_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported type for GLU shader"); + } + + if (key.split) { + variant += "_split"; + } else { + defines.push_back("NO_SPLIT"); + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_glu, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + glu_pipelines[key] = pipeline; + return glu_pipelines[key]; + } + + webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_rope_pipeline_key key = {}; + key.type = context.dst->type; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); + key.has_ff = (context.src2 != nullptr); + + auto it = rope_pipelines.find(key); + if (it != rope_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "rope"; + + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("TYPE_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported type for ROPE shader"); + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + if (key.has_ff) { + defines.push_back("FF_FUNC"); + variant += "_ff"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_rope, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + decisions->inplace = key.inplace; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + rope_pipelines[key] = pipeline; + return rope_pipelines[key]; + } + + webgpu_pipeline get_soft_max_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_soft_max_pipeline_key key = {}; + key.mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32; + key.has_mask = (context.src1 != nullptr); + key.has_sink = (context.src2 != nullptr); + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); + + auto it = soft_max_pipelines.find(key); + if (it != soft_max_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "soft_max"; + + if (key.has_mask) { + defines.push_back("HAS_MASK"); + switch (key.mask_type) { + case GGML_TYPE_F32: + defines.push_back("MASK_F32"); + variant += "_mask_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("MASK_F16"); + variant += "_mask_f16"; + break; + default: + GGML_ABORT("Unsupported type for SOFT_MAX shader"); + } + } + + if (key.has_sink) { + defines.push_back("HAS_SINK"); + variant += "_sink"; + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_soft_max, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + decisions->inplace = key.inplace; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + soft_max_pipelines[key] = pipeline; + return soft_max_pipelines[key]; + } + + webgpu_pipeline get_conv2d_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_conv2d_pipeline_key key = {}; + key.weight_type = context.src0->type; + key.input_type = context.src1->type; + key.output_type = context.dst->type; + + auto it = conv2d_pipelines.find(key); + if (it != conv2d_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "conv_2d"; + + auto push_type_defines = [&](const char * prefix, ggml_type type) { + std::string s_prefix = prefix; + if (type == GGML_TYPE_F32) { + defines.push_back(s_prefix + "_F32"); + } else if (type == GGML_TYPE_F16) { + defines.push_back(s_prefix + "_F16"); + } else { + GGML_ABORT("Unsupported type for CONV_2D shader"); + } + }; + + push_type_defines("WEIGHT", key.weight_type); + push_type_defines("INPUT", key.input_type); + push_type_defines("OUTPUT", key.output_type); + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_conv2d, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + conv2d_pipelines[key] = pipeline; + return conv2d_pipelines[key]; + } + + webgpu_pipeline get_im2col_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_im2col_pipeline_key key = {}; + key.input_type = context.src1->type; + key.output_type = context.dst->type; + + auto it = im2col_pipelines.find(key); + if (it != im2col_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "im2col"; + + auto push_type_defines = [&](const char * prefix, ggml_type type) { + std::string s_prefix = prefix; + if (type == GGML_TYPE_F32) { + defines.push_back(s_prefix + "_F32"); + } else if (type == GGML_TYPE_F16) { + defines.push_back(s_prefix + "_F16"); + } else { + GGML_ABORT("Unsupported type for IM2COL shader"); + } + }; + + push_type_defines("INPUT", key.input_type); + push_type_defines("OUTPUT", key.output_type); + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_im2col, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + im2col_pipelines[key] = pipeline; + return im2col_pipelines[key]; + } + + webgpu_pipeline get_upscale_pipeline(const ggml_webgpu_shader_lib_context & context) { + const uint32_t mode_flags = (uint32_t) ggml_get_op_params_i32(context.dst, 0); + const uint32_t base_mode = mode_flags & 0xFFu; + const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS) != 0u; + + ggml_webgpu_upscale_pipeline_key key = {}; + key.input_type = context.src0->type; + key.output_type = context.dst->type; + key.base_mode = base_mode; + key.antialias = antialias; + + auto it = upscale_pipelines.find(key); + if (it != upscale_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "upscale"; + + if (key.input_type == GGML_TYPE_F16) { + defines.push_back("SRC_F16"); + variant += "_src_f16"; + } else { + variant += "_src_f32"; + } + + if (key.output_type == GGML_TYPE_F16) { + defines.push_back("DST_F16"); + variant += "_dst_f16"; + } else { + variant += "_dst_f32"; + } + + switch (base_mode) { + case GGML_SCALE_MODE_NEAREST: + defines.push_back("NEAREST"); + variant += "_nearest"; + break; + case GGML_SCALE_MODE_BILINEAR: + defines.push_back("BILINEAR"); + variant += "_bilinear"; + break; + case GGML_SCALE_MODE_BICUBIC: + defines.push_back("BICUBIC"); + variant += "_bicubic"; + break; + default: + GGML_ABORT("Unsupported upscale mode"); + } + + if (antialias) { + defines.push_back("ANTIALIAS"); + variant += "_aa"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_upscale, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + upscale_pipelines[key] = pipeline; + return upscale_pipelines[key]; } private: @@ -1350,25 +3063,6 @@ class ggml_webgpu_shader_lib { pipeline_desc.layout = nullptr; // nullptr means auto layout return { device.CreateComputePipeline(&pipeline_desc), label }; } - - static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) { - const size_t limit_bytes = context.wg_mem_limit_bytes; - const size_t q_tile = context.sg_mat_m; - const size_t base_q_bytes = - (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + - 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; - size_t bytes_per_kv = 0; - if (!context.key.kv_direct) { - bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v); - } - if (context.key.has_mask) { - bytes_per_kv += q_tile; - } - bytes_per_kv += q_tile; - bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; - const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; - return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n; - } }; #endif // GGML_WEBGPU_SHADER_LIB_HPP diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 128b7dc3de8..cab0aead198 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -8,6 +8,7 @@ #include "ggml-backend-impl.h" #include "ggml-impl.h" #include "ggml-webgpu-shader-lib.hpp" +#include "ggml.h" #ifdef __EMSCRIPTEN__ # include @@ -16,7 +17,6 @@ #include #include -#include #include #include #ifdef GGML_WEBGPU_GPU_PROFILE @@ -25,7 +25,6 @@ #if defined(GGML_WEBGPU_DEBUG) || defined(GGML_WEBGPU_CPU_PROFILE) || defined(GGML_WEBGPU_GPU_PROFILE) # include #endif -#include #include #include #include @@ -43,6 +42,12 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim wg_x = CEIL_DIV(total_wg, wg_y); } +static inline uint32_t ggml_webgpu_u32_from_f32(float value) { + uint32_t bits; + memcpy(&bits, &value, sizeof(bits)); + return bits; +} + #ifdef GGML_WEBGPU_DEBUG # define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl # define WEBGPU_DEBUG_BUF_ELEMS 512 @@ -75,21 +80,19 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim #endif // GGML_WEBGPU_CPU_PROFILE #ifdef GGML_WEBGPU_GPU_PROFILE -# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 32 -# define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16 // e.g. enough for two timestamps +# define WEBGPU_MAX_PROFILE_QUERY_COUNT 4096u +# define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES (WEBGPU_MAX_PROFILE_QUERY_COUNT * sizeof(uint64_t)) #endif /* Constants */ -#define WEBGPU_NUM_PARAM_BUFS 96u -#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u -#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0 -// Maximum number of in-flight submissions per-thread, to avoid exhausting the -// parameter buffer pool -#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) -#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters -#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 -#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 +#define WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE 64u +#define WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN 10u +#define WEBGPU_RUNTIME_WAIT_TIMEOUT_MS 30000u +#define WEBGPU_RUNTIME_WAIT_TIMEOUT_NS (WEBGPU_RUNTIME_WAIT_TIMEOUT_MS * 1e6) +#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters +#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 +#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 // For operations which process a row in parallel, this seems like a reasonable // default @@ -105,12 +108,9 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim // their locations. static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT -// Always returns the base offset of a tensor, regardless of views. -static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) { - if (tensor->view_src) { - return (uint8_t *) tensor->view_src->data - (uint8_t *) webgpu_ptr_base; - } - return (uint8_t *) tensor->data - (uint8_t *) webgpu_ptr_base; +static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) { + const ggml_tensor * base_tensor = tensor->view_src ? tensor->view_src : tensor; + return (size_t) ((uintptr_t) base_tensor->data - (uintptr_t) webgpu_ptr_base) + tensor->view_offs; } /* Struct definitions */ @@ -122,164 +122,64 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, wgpu::BufferUsage usage, const char * label); -// Holds a pool of parameter buffers for WebGPU operations -struct webgpu_buf_pool { - std::vector free; - - // The pool must be synchronized because - // 1. The memset pool is shared globally by every ggml buffer, - // since allocating a pool per ggml buffer would consume too much memory. - // 2. For the per-thread buffer pools in webgpu_context, - // buffers are allocated and freed in Dawn callbacks, - // which can run on a different thread than the calling thread. - std::mutex mutex; - std::condition_variable cv; - size_t cur_pool_size; - size_t max_pool_size; - wgpu::Device device; - wgpu::BufferUsage dev_buf_usage; - size_t buf_size; - bool should_grow; - - void init(wgpu::Device device, - int num_bufs, - size_t buf_size, - wgpu::BufferUsage dev_buf_usage, - bool should_grow = false, - size_t max_pool_size = WEBGPU_NUM_PARAM_BUFS * 2) { - this->max_pool_size = max_pool_size; - this->cur_pool_size = num_bufs; - this->device = device; - this->dev_buf_usage = dev_buf_usage; - this->buf_size = buf_size; - this->should_grow = should_grow; - for (int i = 0; i < num_bufs; i++) { - wgpu::Buffer dev_buf; - ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf"); - free.push_back(dev_buf); - } +// Slot-based parameter arena for compute graph encoding. Each encoded kernel +// gets a unique uniform-buffer slice within the current batch, and the slot +// cursor is reset immediately after that batch is submitted. +struct webgpu_param_arena { + wgpu::Buffer buffer; + size_t slot_stride = 0; + size_t slot_size = 0; + uint32_t slot_count = 0; + uint32_t next_slot = 0; + + void init(wgpu::Device device, size_t slot_size, uint32_t slot_count, size_t alignment) { + this->slot_stride = ROUNDUP_POW2(slot_size, alignment); + this->slot_size = slot_size; + this->slot_count = slot_count; + this->next_slot = 0; + + ggml_webgpu_create_buffer(device, buffer, this->slot_stride * slot_count, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, "ggml_webgpu_param_arena"); } - wgpu::Buffer alloc_bufs() { - std::unique_lock lock(mutex); - if (!free.empty()) { - wgpu::Buffer buf = free.back(); - free.pop_back(); - return buf; + size_t alloc_slot(size_t size) { + GGML_ASSERT(size <= slot_size); + if (next_slot >= slot_count) { + GGML_ABORT("ggml_webgpu: parameter arena exhausted while encoding a batch"); } - // Try growing the pool if no free buffers - if (free.empty() && cur_pool_size < max_pool_size && should_grow) { - cur_pool_size++; - wgpu::Buffer dev_buf; - ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf"); - - if (!dev_buf) { - GGML_ABORT("webgpu_buf_pool: failed to allocate buffers"); - } - return dev_buf; - } - cv.wait(lock, [this] { return !free.empty(); }); - wgpu::Buffer buf = free.back(); - free.pop_back(); - return buf; + return slot_stride * next_slot++; } - void free_bufs(std::vector bufs) { - std::lock_guard lock(mutex); - free.insert(free.end(), bufs.begin(), bufs.end()); - cv.notify_all(); - } + void reset() { next_slot = 0; } void cleanup() { - std::lock_guard lock(mutex); - for (auto & buf : free) { - if (buf) { - buf.Destroy(); - } + if (buffer) { + buffer.Destroy(); + buffer = nullptr; } - free.clear(); } - ~webgpu_buf_pool() { this->cleanup(); } + ~webgpu_param_arena() { this->cleanup(); } }; +struct webgpu_encoded_op { + uint32_t num_kernels = 0; #ifdef GGML_WEBGPU_GPU_PROFILE -struct webgpu_gpu_profile_bufs { - wgpu::Buffer host_buf; - wgpu::Buffer dev_buf; - wgpu::QuerySet query_set; -}; - -// Holds a pool of parameter buffers for WebGPU operations -struct webgpu_gpu_profile_buf_pool { - std::vector free; - - std::mutex mutex; - - std::condition_variable cv; - - void init(wgpu::Device device, - int num_bufs, - size_t buf_size, - wgpu::BufferUsage dev_buf_usage, - wgpu::BufferUsage host_buf_usage) { - for (int i = 0; i < num_bufs; i++) { - wgpu::Buffer host_buf; - wgpu::Buffer dev_buf; - ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_profile_buf"); - ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_profile_buf"); - // Create a query set for 2 timestamps - wgpu::QuerySetDescriptor ts_query_set_desc = {}; - - ts_query_set_desc.type = wgpu::QueryType::Timestamp; - ts_query_set_desc.count = 2; - wgpu::QuerySet ts_query_set = device.CreateQuerySet(&ts_query_set_desc); - - free.push_back({ host_buf, dev_buf, ts_query_set }); - } - } - - webgpu_gpu_profile_bufs alloc_bufs() { - std::unique_lock lock(mutex); - cv.wait(lock, [this] { return !free.empty(); }); - webgpu_gpu_profile_bufs bufs = free.back(); - free.pop_back(); - return bufs; - } - - void free_bufs(std::vector bufs) { - std::lock_guard lock(mutex); - free.insert(free.end(), bufs.begin(), bufs.end()); - cv.notify_all(); - } - - void cleanup() { - std::lock_guard lock(mutex); - for (auto & bufs : free) { - bufs.host_buf.Destroy(); - bufs.dev_buf.Destroy(); - bufs.query_set.Destroy(); - } - free.clear(); - } - - ~webgpu_gpu_profile_buf_pool() { this->cleanup(); } -}; + std::vector pipeline_names; #endif +}; -struct webgpu_command { - uint32_t num_kernels; - wgpu::CommandBuffer commands; - std::vector params_bufs; -#ifdef GGML_WEBGPU_GPU_PROFILE - webgpu_gpu_profile_bufs timestamp_query_bufs; - std::string pipeline_name; -#endif +struct webgpu_dispatch_desc { + webgpu_pipeline pipeline; + std::vector params; + std::vector bind_group_entries; + std::pair workgroups = { 1, 1 }; }; struct webgpu_capabilities { wgpu::Limits limits; + bool supports_subgroups = false; bool supports_subgroup_matrix = false; uint32_t sg_mat_m = 0; @@ -297,16 +197,19 @@ struct webgpu_global_context_struct { wgpu::Adapter adapter; wgpu::Device device; wgpu::Queue queue; + uint32_t command_submit_batch_size = WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE; + uint32_t max_inflight_batches = UINT32_MAX; webgpu_capabilities capabilities; // Shared buffer to move data from device to host wgpu::Buffer get_tensor_staging_buf; - // Global mutex for pipeline and staging buffer, will be refactored to exclude pipeline caches. + // Global mutex for get_tensor std::recursive_mutex mutex; - webgpu_buf_pool memset_buf_pool; - std::map memset_pipelines; // variant or type index + wgpu::Buffer memset_params_buf; + webgpu_pipeline memset_pipeline; + // TODO: We should rework the CPU profiling time handling to make it more useful. ref: https://github.com/ggml-org/llama.cpp/pull/22050 #ifdef GGML_WEBGPU_CPU_PROFILE // Profiling: labeled CPU time in ms (total) std::unordered_map cpu_time_ms; @@ -314,13 +217,6 @@ struct webgpu_global_context_struct { std::unordered_map cpu_detail_ms; #endif -#ifdef GGML_WEBGPU_GPU_PROFILE - // Profiling: per-shader GPU time in ms - std::unordered_map shader_gpu_time_ms; - // Profiling: pool of timestamp query buffers (one per operation) - webgpu_gpu_profile_buf_pool timestamp_query_buf_pool; -#endif - #ifdef GGML_WEBGPU_DEBUG wgpu::Buffer debug_host_buf; wgpu::Buffer debug_dev_buf; @@ -331,6 +227,10 @@ struct webgpu_global_context_struct { this->get_tensor_staging_buf.Destroy(); this->get_tensor_staging_buf = nullptr; } + if (this->memset_params_buf) { + this->memset_params_buf.Destroy(); + this->memset_params_buf = nullptr; + } #ifdef GGML_WEBGPU_DEBUG if (this->debug_host_buf) { this->debug_host_buf.Destroy(); @@ -346,13 +246,6 @@ struct webgpu_global_context_struct { typedef std::shared_ptr webgpu_global_context; -struct webgpu_submission { - wgpu::FutureWaitInfo submit_done; -#ifdef GGML_WEBGPU_GPU_PROFILE - std::vector profile_futures; -#endif -}; - // All the base objects needed to run operations on a WebGPU device struct webgpu_context_struct { // Points to global instances owned by ggml_backend_webgpu_reg_context @@ -360,19 +253,47 @@ struct webgpu_context_struct { std::unique_ptr shader_lib; - webgpu_buf_pool param_buf_pool; - wgpu::Buffer set_rows_dev_error_buf; - wgpu::Buffer set_rows_host_error_buf; - - std::map> cpy_pipelines; // src_type, dst_type + webgpu_param_arena param_arena; + wgpu::Buffer set_rows_dev_error_buf; + wgpu::Buffer set_rows_host_error_buf; + wgpu::CommandEncoder active_command_encoder; + wgpu::ComputePassEncoder active_compute_pass; - std::map rms_norm_pipelines; // inplace - std::map>> rope_pipelines; // type, ff, inplace - std::map>> glu_pipelines; // glu_op, type, split + size_t memset_bytes_per_thread; - std::map>> soft_max_pipelines; // mask_type, has_sink, inplace +#ifdef GGML_WEBGPU_GPU_PROFILE + // Profiling: per-shader GPU time in ms + std::unordered_map shader_gpu_time_ms; + wgpu::Buffer profile_timestamp_dev_buf; + wgpu::Buffer profile_timestamp_host_buf; + wgpu::QuerySet profile_timestamp_query_set; + uint32_t profile_timestamp_query_count = 0; +#endif - size_t memset_bytes_per_thread; + ~webgpu_context_struct() { +#ifdef GGML_WEBGPU_GPU_PROFILE + if (this->profile_timestamp_host_buf) { + this->profile_timestamp_host_buf.Destroy(); + this->profile_timestamp_host_buf = nullptr; + } + if (this->profile_timestamp_dev_buf) { + this->profile_timestamp_dev_buf.Destroy(); + this->profile_timestamp_dev_buf = nullptr; + } + if (this->profile_timestamp_query_set) { + this->profile_timestamp_query_set.Destroy(); + this->profile_timestamp_query_set = nullptr; + } +#endif + if (this->set_rows_host_error_buf) { + this->set_rows_host_error_buf.Destroy(); + this->set_rows_host_error_buf = nullptr; + } + if (this->set_rows_dev_error_buf) { + this->set_rows_dev_error_buf.Destroy(); + this->set_rows_dev_error_buf = nullptr; + } + } }; typedef std::shared_ptr webgpu_context; @@ -451,110 +372,119 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, buffer = device.CreateBuffer(&buffer_desc); } -/** End WebGPU object initializations */ +static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) { + ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context; + return ctx->buffer; +} -/** WebGPU Actions */ +static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) { + size_t offset = ggml_webgpu_tensor_offset(t); + return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); +} -static bool ggml_backend_webgpu_handle_wait_status(wgpu::WaitStatus status, bool allow_timeout = false) { - switch (status) { - case wgpu::WaitStatus::Success: - return true; - case wgpu::WaitStatus::TimedOut: - if (allow_timeout) { - return false; - } - GGML_LOG_ERROR("ggml_webgpu: WaitAny timed out unexpectedly\n"); - return false; - case wgpu::WaitStatus::Error: - GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n"); - return false; - default: - GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n"); - return false; - } +static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) { + size_t offset = ggml_webgpu_tensor_offset(t); + return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); } -#ifdef GGML_WEBGPU_GPU_PROFILE -static void ggml_backend_webgpu_erase_completed_futures(std::vector & futures) { - futures.erase(std::remove_if(futures.begin(), futures.end(), - [](const wgpu::FutureWaitInfo & info) { return info.completed; }), - futures.end()); +static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) { + return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT); } -static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context & ctx, - std::vector & futures, - bool block) { - if (futures.empty()) { - return; - } +struct ggml_webgpu_merged_binding_range { + size_t offset; + size_t size; +}; - uint64_t timeout_ms = block ? UINT64_MAX : 0; - if (block) { - while (!futures.empty()) { - auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms); - if (ggml_backend_webgpu_handle_wait_status(waitStatus)) { - ggml_backend_webgpu_erase_completed_futures(futures); - } - } - } else { - auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms); - if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) { - ggml_backend_webgpu_erase_completed_futures(futures); - } +static ggml_webgpu_merged_binding_range ggml_webgpu_tensor_merged_binding_range( + webgpu_context & ctx, + std::initializer_list tensors) { + size_t merged_offset = SIZE_MAX; + size_t merged_end = 0; + + for (ggml_tensor * tensor : tensors) { + const size_t bind_offset = ggml_webgpu_tensor_align_offset(ctx, tensor); + const size_t bind_end = bind_offset + ggml_webgpu_tensor_binding_size(ctx, tensor); + + merged_offset = std::min(merged_offset, bind_offset); + merged_end = std::max(merged_end, bind_end); } + + return { merged_offset, merged_end - merged_offset }; } -#endif -// Wait for the queue to finish processing all submitted work -static void ggml_backend_webgpu_wait(webgpu_global_context & ctx, - std::vector & subs, - bool block = true) { - // If we have too many in-flight submissions, wait on the oldest one first. - if (subs.empty()) { - return; +static uint32_t ggml_webgpu_tensor_merged_element_offset(const ggml_tensor * tensor, + const ggml_webgpu_merged_binding_range & merged_range) { + return (uint32_t) ((ggml_webgpu_tensor_offset(tensor) - merged_range.offset) / ggml_type_size(tensor->type)); +} + +static wgpu::BindGroupEntry ggml_webgpu_make_bind_group_entry(uint32_t binding, + wgpu::Buffer buffer, + uint64_t offset, + uint64_t size) { + wgpu::BindGroupEntry entry = {}; + entry.binding = binding; + entry.buffer = std::move(buffer); + entry.offset = offset; + entry.size = size; + return entry; +} + +static wgpu::BindGroupEntry ggml_webgpu_make_tensor_bind_group_entry(webgpu_context & ctx, + uint32_t binding, + ggml_tensor * tensor) { + return ggml_webgpu_make_bind_group_entry(binding, ggml_webgpu_tensor_buf(tensor), + ggml_webgpu_tensor_align_offset(ctx, tensor), + ggml_webgpu_tensor_binding_size(ctx, tensor)); +} + +/** End WebGPU object initializations */ + +/** WebGPU Actions */ + +template +static void ggml_backend_webgpu_check_wait_status(wgpu::WaitStatus wait_status, + T callback_status, + T success_status, + const char * wait_name, + const char * failure_name, + const char * callback_message) { + if (wait_status == wgpu::WaitStatus::TimedOut) { + GGML_ABORT("ggml_webgpu: %s timed out after %u ms\n", wait_name, WEBGPU_RUNTIME_WAIT_TIMEOUT_MS); } - while (subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) { - auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, UINT64_MAX); - if (ggml_backend_webgpu_handle_wait_status(waitStatus)) { -#ifdef GGML_WEBGPU_GPU_PROFILE - ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true); -#endif - subs.erase(subs.begin()); - } + if (wait_status == wgpu::WaitStatus::Error) { + GGML_ABORT("ggml_webgpu: %s failed\n", wait_name); } - - if (subs.empty()) { - return; + if (callback_status != success_status) { + GGML_ABORT("ggml_webgpu: %s failed with status %d: %s\n", failure_name, static_cast(callback_status), + callback_message); } +} - if (block) { - for (auto & sub : subs) { - while (!sub.submit_done.completed) { - auto waitStatus = ctx->instance.WaitAny(1, &sub.submit_done, UINT64_MAX); - ggml_backend_webgpu_handle_wait_status(waitStatus); - } -#ifdef GGML_WEBGPU_GPU_PROFILE - ggml_backend_webgpu_wait_profile_futures(ctx, sub.profile_futures, true); -#endif - } - subs.clear(); - } else { - // Poll each submit future once and remove completed submissions. - for (auto sub = subs.begin(); sub != subs.end();) { - auto waitStatus = ctx->instance.WaitAny(1, &sub->submit_done, 0); - ggml_backend_webgpu_handle_wait_status(waitStatus, true); -#ifdef GGML_WEBGPU_GPU_PROFILE - ggml_backend_webgpu_wait_profile_futures(ctx, sub->profile_futures, false); - if (sub->submit_done.completed && sub->profile_futures.empty()) { -#else - if (sub->submit_done.completed) { -#endif - sub = subs.erase(sub); - } else { - ++sub; - } - } - } +// TODO: these next two functions may want tuning across different platforms and workloads, +static uint32_t ggml_backend_webgpu_get_max_inflight_batches() { + return UINT32_MAX; +} + +static uint32_t ggml_backend_webgpu_get_command_submit_batch_size() { + return WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE; +} + +static void ggml_backend_webgpu_wait_queue(webgpu_global_context & ctx) { + wgpu::QueueWorkDoneStatus callback_status = wgpu::QueueWorkDoneStatus::Error; + std::string callback_message; + + const wgpu::WaitStatus wait_status = ctx->instance.WaitAny( + ctx->queue.OnSubmittedWorkDone( + wgpu::CallbackMode::AllowSpontaneous, + [&callback_status, &callback_message](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { + callback_status = status; + callback_message = std::string(message); + }), + WEBGPU_RUNTIME_WAIT_TIMEOUT_NS); + + ggml_backend_webgpu_check_wait_status(wait_status, callback_status, wgpu::QueueWorkDoneStatus::Success, + "Queue wait", "Queue work", callback_message.c_str()); } static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx, @@ -562,14 +492,31 @@ static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx, wgpu::MapMode mode, size_t offset, size_t size) { - ctx->instance.WaitAny(buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous, - [](wgpu::MapAsyncStatus status, wgpu::StringView message) { - if (status != wgpu::MapAsyncStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n", - message.data); - } - }), - UINT64_MAX); + wgpu::MapAsyncStatus callback_status = wgpu::MapAsyncStatus::Error; + std::string callback_message; + + const wgpu::WaitStatus wait_status = ctx->instance.WaitAny( + buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous, + [&callback_status, &callback_message](wgpu::MapAsyncStatus status, wgpu::StringView message) { + callback_status = status; + callback_message = std::string(message); + }), + WEBGPU_RUNTIME_WAIT_TIMEOUT_NS); + + ggml_backend_webgpu_check_wait_status(wait_status, callback_status, wgpu::MapAsyncStatus::Success, + "Buffer map wait", "Buffer map", callback_message.c_str()); +} + +static void ggml_backend_webgpu_submit_commands(webgpu_context & ctx, + const wgpu::CommandBuffer commands, + uint32_t & num_inflight_batches) { + if (num_inflight_batches >= ctx->global_ctx->max_inflight_batches) { + ggml_backend_webgpu_wait_queue(ctx->global_ctx); + num_inflight_batches = 0; + } + + ctx->global_ctx->queue.Submit(1, &commands); + num_inflight_batches++; } #ifdef GGML_WEBGPU_DEBUG @@ -588,147 +535,79 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) { } #endif -static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context & ctx, - std::vector & commands, - webgpu_buf_pool & param_buf_pool) { - std::vector command_buffers; - std::vector params_bufs; - webgpu_submission submission; -#ifdef GGML_WEBGPU_GPU_PROFILE - std::vector> pipeline_name_and_ts_bufs; -#endif - - for (const auto & command : commands) { - command_buffers.push_back(command.commands); - params_bufs.insert(params_bufs.end(), command.params_bufs.begin(), command.params_bufs.end()); - } - ctx->queue.Submit(command_buffers.size(), command_buffers.data()); - - wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone( - wgpu::CallbackMode::AllowSpontaneous, - [¶m_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { - if (status != wgpu::QueueWorkDoneStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str()); - } - // Free the staged buffers - param_buf_pool.free_bufs(params_bufs); - }); - submission.submit_done = { p_f }; - -#ifdef GGML_WEBGPU_GPU_PROFILE - for (const auto & command : commands) { - auto label = command.pipeline_name; - auto ts_bufs = command.timestamp_query_bufs; - - wgpu::Future f = ts_bufs.host_buf.MapAsync( - wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous, - [ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) { - if (status != wgpu::MapAsyncStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str()); - } else { - const uint64_t * ts_data = (const uint64_t *) ts_bufs.host_buf.GetConstMappedRange(); - // WebGPU timestamps are in ns; convert to ms - double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6; - ctx->shader_gpu_time_ms[label] += elapsed_ms; - } - // We can't unmap in here due to WebGPU reentrancy limitations. - ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs }); - }); - submission.profile_futures.push_back({ f }); - } -#endif - return submission; -} - -static webgpu_command ggml_backend_webgpu_build_multi( - webgpu_global_context & ctx, - webgpu_buf_pool & param_buf_pool, - const std::vector & pipelines, - const std::vector> & params_list, - const std::vector> & bind_group_entries_list, - const std::vector> & workgroups_list) { - GGML_ASSERT(pipelines.size() == params_list.size()); - GGML_ASSERT(pipelines.size() == bind_group_entries_list.size()); - GGML_ASSERT(pipelines.size() == workgroups_list.size()); - - std::vector params_bufs_list; +static webgpu_encoded_op ggml_backend_webgpu_build_multi(webgpu_context & ctx, + const std::vector & dispatches) { + webgpu_encoded_op result = {}; std::vector bind_groups; + std::vector param_offsets; + result.num_kernels = dispatches.size(); - for (size_t i = 0; i < pipelines.size(); i++) { - wgpu::Buffer params_bufs = param_buf_pool.alloc_bufs(); + for (size_t i = 0; i < dispatches.size(); i++) { + const webgpu_dispatch_desc & dispatch = dispatches[i]; + const size_t param_size = dispatch.params.size() * sizeof(uint32_t); + const size_t param_offset = ctx->param_arena.alloc_slot(param_size); - std::vector entries = bind_group_entries_list[i]; + std::vector entries = dispatch.bind_group_entries; uint32_t params_binding_num = entries.size(); - entries.push_back( - { .binding = params_binding_num, .buffer = params_bufs, .offset = 0, .size = params_bufs.GetSize() }); + entries.push_back(ggml_webgpu_make_bind_group_entry(params_binding_num, ctx->param_arena.buffer, param_offset, + ctx->param_arena.slot_size)); wgpu::BindGroupDescriptor bind_group_desc; - bind_group_desc.layout = pipelines[i].pipeline.GetBindGroupLayout(0); + bind_group_desc.layout = dispatch.pipeline.pipeline.GetBindGroupLayout(0); bind_group_desc.entryCount = entries.size(); bind_group_desc.entries = entries.data(); - bind_group_desc.label = pipelines[i].name.c_str(); - bind_groups.push_back(ctx->device.CreateBindGroup(&bind_group_desc)); - - params_bufs_list.push_back(params_bufs); + bind_group_desc.label = dispatch.pipeline.name.c_str(); + bind_groups.push_back(ctx->global_ctx->device.CreateBindGroup(&bind_group_desc)); + param_offsets.push_back(param_offset); } - wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); - for (size_t i = 0; i < params_bufs_list.size(); i++) { - ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t)); + for (size_t i = 0; i < param_offsets.size(); i++) { + ctx->global_ctx->queue.WriteBuffer(ctx->param_arena.buffer, param_offsets[i], dispatches[i].params.data(), + dispatches[i].params.size() * sizeof(uint32_t)); } #ifdef GGML_WEBGPU_GPU_PROFILE - webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs(); - if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { - ts_bufs.host_buf.Unmap(); - } + for (size_t i = 0; i < dispatches.size(); i++) { + GGML_ASSERT(ctx->profile_timestamp_query_count + 2 <= WEBGPU_MAX_PROFILE_QUERY_COUNT); + const uint32_t query_begin = ctx->profile_timestamp_query_count++; + const uint32_t query_end = ctx->profile_timestamp_query_count++; - wgpu::PassTimestampWrites ts_writes = { .querySet = ts_bufs.query_set, - .beginningOfPassWriteIndex = 0, - .endOfPassWriteIndex = 1 }; - wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes }; - wgpu::ComputePassEncoder pass = encoder.BeginComputePass(&pass_desc); -#else - wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); -#endif - for (size_t i = 0; i < pipelines.size(); i++) { - pass.SetPipeline(pipelines[i].pipeline); + wgpu::PassTimestampWrites ts_writes = {}; + ts_writes.querySet = ctx->profile_timestamp_query_set; + ts_writes.beginningOfPassWriteIndex = query_begin; + ts_writes.endOfPassWriteIndex = query_end; + wgpu::ComputePassDescriptor pass_desc = {}; + pass_desc.timestampWrites = &ts_writes; + + wgpu::ComputePassEncoder pass = ctx->active_command_encoder.BeginComputePass(&pass_desc); + + pass.SetPipeline(dispatches[i].pipeline.pipeline); pass.SetBindGroup(0, bind_groups[i]); - pass.DispatchWorkgroups(workgroups_list[i].first, workgroups_list[i].second, 1); + pass.DispatchWorkgroups(dispatches[i].workgroups.first, dispatches[i].workgroups.second, 1); + pass.End(); + result.pipeline_names.push_back(dispatches[i].pipeline.name); + } +#else + for (size_t i = 0; i < dispatches.size(); i++) { + ctx->active_compute_pass.SetPipeline(dispatches[i].pipeline.pipeline); + ctx->active_compute_pass.SetBindGroup(0, bind_groups[i]); + ctx->active_compute_pass.DispatchWorkgroups(dispatches[i].workgroups.first, dispatches[i].workgroups.second, 1); } - pass.End(); - -#ifdef GGML_WEBGPU_GPU_PROFILE - encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0); - encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize()); #endif - wgpu::CommandBuffer commands = encoder.Finish(); - webgpu_command result = {}; - result.commands = commands; - result.params_bufs = params_bufs_list; - result.num_kernels = pipelines.size(); -#ifdef GGML_WEBGPU_GPU_PROFILE - result.timestamp_query_bufs = ts_bufs; - // TODO: handle multiple pipeline names - result.pipeline_name = pipelines.front().name; -#endif return result; } -static webgpu_command ggml_backend_webgpu_build(webgpu_global_context & ctx, - webgpu_buf_pool & param_buf_pool, - webgpu_pipeline & pipeline, - std::vector params, - std::vector bind_group_entries, - uint32_t wg_x, - uint32_t wg_y = 1) { - return ggml_backend_webgpu_build_multi(ctx, param_buf_pool, - { - pipeline - }, - { std::move(params) }, { std::move(bind_group_entries) }, - { { wg_x, wg_y } }); +static webgpu_encoded_op ggml_backend_webgpu_build(webgpu_context & ctx, + webgpu_pipeline & pipeline, + std::vector params, + std::vector bind_group_entries, + uint32_t wg_x, + uint32_t wg_y = 1) { + return ggml_backend_webgpu_build_multi( + ctx, { + { pipeline, std::move(params), std::move(bind_group_entries), { wg_x, wg_y } }, + }); } static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, @@ -736,18 +615,37 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, uint32_t value, size_t offset, size_t size) { - std::vector params = { (uint32_t) offset, (uint32_t) size, value }; - std::vector entries = { - { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() } - }; - size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread; - uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg); + std::vector params = { (uint32_t) offset, (uint32_t) size, value }; + std::vector entries = { ggml_webgpu_make_bind_group_entry(0, buf, 0, buf.GetSize()) }; + size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread; + uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg); + + ctx->queue.WriteBuffer(ctx->memset_params_buf, 0, params.data(), params.size() * sizeof(uint32_t)); + + wgpu::BindGroupEntry params_entry = {}; + params_entry.binding = 1; + params_entry.buffer = ctx->memset_params_buf; + params_entry.offset = 0; + params_entry.size = WEBGPU_PARAMS_BUF_SIZE_BYTES; + entries.push_back(params_entry); + + wgpu::BindGroupDescriptor bind_group_desc; + bind_group_desc.layout = ctx->memset_pipeline.pipeline.GetBindGroupLayout(0); + bind_group_desc.entryCount = entries.size(); + bind_group_desc.entries = entries.data(); + bind_group_desc.label = ctx->memset_pipeline.name.c_str(); + wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc); + + wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); + pass.SetPipeline(ctx->memset_pipeline.pipeline); + pass.SetBindGroup(0, bind_group); + pass.DispatchWorkgroups(wg_x, 1, 1); + pass.End(); - webgpu_command command = - ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x); - std::vector commands = { command }; - std::vector sub = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) }; - ggml_backend_webgpu_wait(ctx, sub); + wgpu::CommandBuffer command = encoder.Finish(); + std::vector commands = { command }; + ctx->queue.Submit(commands.size(), commands.data()); } /** End WebGPU Actions */ @@ -787,12 +685,12 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { #ifdef GGML_WEBGPU_GPU_PROFILE std::cout << "\n[ggml_webgpu gpu profiling summary]\n"; double total_gpu = 0.0; - for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) { + for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) { total_gpu += kv.second; } std::cout << "ggml_webgpu: total gpu time (all shaders): " << total_gpu << " ms\n"; std::cout << "\nggml_webgpu: gpu breakdown:\n"; - for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) { + for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) { double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0; std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << std::fixed << std::setprecision(2) << pct << "%)\n"; @@ -807,162 +705,586 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { delete backend; } -static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) { - return webgpu_tensor_offset(tensor) + tensor->view_offs; +static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_cpy_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + + uint32_t ne = (uint32_t) ggml_nelements(dst); + + std::vector params = { + ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + // Convert byte-strides to element-strides + (uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + // Logical shapes + (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], (uint32_t) dst->ne[2] + }; + + std::vector entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst), + }; + + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) { - ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context; - return ctx->buffer; +static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_set_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + const bool inplace = decisions->inplace; + + const uint32_t ne = inplace ? (uint32_t) ggml_nelements(src1) : (uint32_t) ggml_nelements(dst); + const uint32_t dst_type_size = (uint32_t) ggml_type_size(dst->type); + + std::vector params = { + ne, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (((const int32_t *) dst->op_params)[3] / dst_type_size), + + (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + + 1u, + (uint32_t) (((const int32_t *) dst->op_params)[0] / dst_type_size), + (uint32_t) (((const int32_t *) dst->op_params)[1] / dst_type_size), + (uint32_t) (((const int32_t *) dst->op_params)[2] / dst_type_size), + + (uint32_t) src1->ne[0], + (uint32_t) src1->ne[1], + (uint32_t) src1->ne[2], + (uint32_t) src1->ne[3], + }; + + std::vector entries; + uint32_t binding_index = 0; + if (!inplace) { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0)); + binding_index++; + } + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index, src1)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index + 1, dst)); + + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) { - size_t offset = ggml_webgpu_tensor_offset(t); - return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); +static webgpu_encoded_op ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_pad_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + + const uint32_t ne = (uint32_t) ggml_nelements(dst); + + std::vector params = { + ne, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + // Strides (in elements) + (uint32_t) (src->nb[0] / ggml_type_size(src->type)), + (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + // Shapes + (uint32_t) src->ne[0], + (uint32_t) src->ne[1], + (uint32_t) src->ne[2], + (uint32_t) src->ne[3], + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + // Pad sizes + (uint32_t) ggml_get_op_params_i32(dst, 0), + (uint32_t) ggml_get_op_params_i32(dst, 1), + (uint32_t) ggml_get_op_params_i32(dst, 2), + (uint32_t) ggml_get_op_params_i32(dst, 3), + (uint32_t) ggml_get_op_params_i32(dst, 4), + (uint32_t) ggml_get_op_params_i32(dst, 5), + (uint32_t) ggml_get_op_params_i32(dst, 6), + (uint32_t) ggml_get_op_params_i32(dst, 7), + }; + + std::vector entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst), + }; + + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) { - size_t offset = ggml_webgpu_tensor_offset(t); - return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); +static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + + webgpu_pipeline pipeline = ctx->shader_lib->get_solve_tri_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + + (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + + (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + + (uint32_t) src1->ne[0], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + }; + + std::vector entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), + }; + + const uint32_t wg_x = CEIL_DIV((uint32_t) src1->ne[0], decisions->wg_size); + const uint32_t wg_y = (uint32_t) (dst->ne[2] * dst->ne[3]); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } -static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) { - return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT); +static webgpu_encoded_op ggml_webgpu_conv_2d(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + const int32_t s0 = ggml_get_op_params_i32(dst, 0); + const int32_t s1 = ggml_get_op_params_i32(dst, 1); + const int32_t p0 = ggml_get_op_params_i32(dst, 2); + const int32_t p1 = ggml_get_op_params_i32(dst, 3); + const int32_t d0 = ggml_get_op_params_i32(dst, 4); + const int32_t d1 = ggml_get_op_params_i32(dst, 5); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + + (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + + (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + + (uint32_t) src0->ne[0], + (uint32_t) src0->ne[1], + (uint32_t) src0->ne[2], + + (uint32_t) src1->ne[0], + (uint32_t) src1->ne[1], + + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + + (uint32_t) s0, + (uint32_t) s1, + (uint32_t) p0, + (uint32_t) p1, + (uint32_t) d0, + (uint32_t) d1, + }; + + std::vector entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), + }; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_conv2d_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + + uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size); + uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg); + uint32_t wg_y = CEIL_DIV(total_wg, wg_x); + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } -// Used to determine if two tensors are the same for in-place operations -static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) { - return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) && - (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b)); +static webgpu_encoded_op ggml_webgpu_im2col(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + const int32_t s0 = ggml_get_op_params_i32(dst, 0); + const int32_t s1 = ggml_get_op_params_i32(dst, 1); + const int32_t p0 = ggml_get_op_params_i32(dst, 2); + const int32_t p1 = ggml_get_op_params_i32(dst, 3); + const int32_t d0 = ggml_get_op_params_i32(dst, 4); + const int32_t d1 = ggml_get_op_params_i32(dst, 5); + const bool is_2D = ggml_get_op_params_i32(dst, 6) == 1; + + const uint32_t KW = src0->ne[0]; + const uint32_t KH = is_2D ? src0->ne[1] : 1; + const uint32_t IC = is_2D ? src0->ne[2] : src0->ne[1]; + + const uint32_t IW = src1->ne[0]; + const uint32_t IH = is_2D ? src1->ne[1] : 1; + const uint32_t N = is_2D ? src1->ne[3] : src1->ne[2]; + + const uint32_t OW = dst->ne[1]; + const uint32_t OH = is_2D ? dst->ne[2] : 1; + + const uint32_t si0 = (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)); + const uint32_t si1 = is_2D ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0; + const uint32_t si2 = is_2D ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)); + const uint32_t si3 = is_2D ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)); + + const uint32_t so0 = (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)); + const uint32_t so1 = (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)); + const uint32_t so2 = is_2D ? (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)) : 0; + const uint32_t so3 = is_2D ? (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)) : + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + si0, + si1, + si2, + si3, + so0, + so1, + so2, + so3, + + KW, + KH, + IC, + + IW, + IH, + N, + + OW, + OH, + + (uint32_t) s0, + (uint32_t) s1, + (uint32_t) p0, + (uint32_t) p1, + (uint32_t) d0, + (uint32_t) d1, + }; + + std::vector entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst), + }; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_im2col_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + + uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size); + uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg); + uint32_t wg_y = CEIL_DIV(total_wg, wg_x); + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } -// Used to determine if two tensors share the same buffer and their byte ranges overlap, -static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) { - return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) && - ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) && - ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a)); -} +static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_conv_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + + const uint32_t token_tiles = CEIL_DIV((uint32_t) dst->ne[1], decisions->tokens_per_wg); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + + (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + + (uint32_t) src1->ne[0], + (uint32_t) src0->ne[1], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + token_tiles, + }; + + std::vector entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), + }; + + const uint32_t wg_x = CEIL_DIV((uint32_t) src0->ne[1], decisions->block_size); + const uint32_t wg_y = token_tiles * (uint32_t) dst->ne[2]; + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); +} + +static webgpu_encoded_op ggml_webgpu_ssm_scan(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * src3, + ggml_tensor * src4, + ggml_tensor * src5, + ggml_tensor * src6, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src4 = src4; + shader_lib_ctx.src5 = src5; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; + + webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_scan_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + const bool xbc_overlap = decisions->xbc_overlap; + + uint32_t offset_x = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)); + uint32_t offset_B = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src4) / ggml_type_size(src4->type)); + uint32_t offset_C = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src5) / ggml_type_size(src5->type)); + size_t xbc_bind_offset = 0; + size_t xbc_bind_size = 0; + if (xbc_overlap) { + const ggml_webgpu_merged_binding_range merged_range = + ggml_webgpu_tensor_merged_binding_range(ctx, { src1, src4, src5 }); + xbc_bind_offset = merged_range.offset; + xbc_bind_size = merged_range.size; + offset_x = ggml_webgpu_tensor_merged_element_offset(src1, merged_range); + offset_B = ggml_webgpu_tensor_merged_element_offset(src4, merged_range); + offset_C = ggml_webgpu_tensor_merged_element_offset(src5, merged_range); + } + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + offset_x, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src3) / ggml_type_size(src3->type)), + offset_B, + offset_C, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src6) / ggml_type_size(src6->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), -struct binary_overlap_flags { - bool inplace; // src0 == dst - bool overlap; // src1 == dst - bool src_overlap; -}; + (uint32_t) (src2->nb[1] / ggml_type_size(src2->type)), + (uint32_t) (src2->nb[2] / ggml_type_size(src2->type)), -static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { - binary_overlap_flags flags = {}; - flags.inplace = ggml_webgpu_tensor_equal(src0, dst); - flags.overlap = ggml_webgpu_tensor_overlap(src1, dst); - flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1); + (uint32_t) src3->ne[0], + (uint32_t) (src3->nb[1] / ggml_type_size(src3->type)), - return flags; -} + (uint32_t) (src4->nb[1] / ggml_type_size(src4->type)), + (uint32_t) (src4->nb[2] / ggml_type_size(src4->type)), + (uint32_t) (src4->nb[3] / ggml_type_size(src4->type)), -static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - uint32_t ne = (uint32_t) ggml_nelements(dst); + (uint32_t) (src5->nb[1] / ggml_type_size(src5->type)), + (uint32_t) (src5->nb[2] / ggml_type_size(src5->type)), + (uint32_t) (src5->nb[3] / ggml_type_size(src5->type)), - std::vector params = { - ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - // Convert byte-strides to element-strides - (uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)), - (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)), - (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), - (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), - // Logical shapes - (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0], - (uint32_t) dst->ne[1], (uint32_t) dst->ne[2] + (uint32_t) src0->ne[0], + (uint32_t) src0->ne[1], + (uint32_t) src0->ne[2], + (uint32_t) src4->ne[1], + (uint32_t) src1->ne[2], + (uint32_t) src1->ne[3], + (uint32_t) ggml_nelements(src1), }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), }; + if (xbc_overlap) { + entries.push_back( + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), xbc_bind_offset, xbc_bind_size)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src6)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, dst)); + } else { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src4)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, src5)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 6, src6)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 7, dst)); + } + + const uint32_t total_wg = (uint32_t) (src0->ne[1] * src0->ne[2] * src1->ne[3]); + const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + uint32_t wg_x; + uint32_t wg_y; + compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); - uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->cpy_pipelines[src->type][dst->type], - params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } -static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup - }; +static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * src3, + ggml_tensor * src4, + ggml_tensor * src5, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.src3 = src3; + shader_lib_ctx.src4 = src4; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_gated_delta_net_pipeline(shader_lib_ctx); + + const uint32_t s_v = (uint32_t) src2->ne[0]; + const uint32_t h = (uint32_t) src2->ne[1]; + const uint32_t n_tokens = (uint32_t) src2->ne[2]; + const uint32_t n_seqs = (uint32_t) src2->ne[3]; + const float scale = 1.0f / sqrtf((float) s_v); + uint32_t scale_u32; + memcpy(&scale_u32, &scale, sizeof(scale_u32)); - webgpu_pipeline pipeline = ctx->shader_lib->get_pad_pipeline(shader_lib_ctx); + std::vector params = { + h, + n_tokens, + n_seqs, + s_v * h * n_tokens * n_seqs, - auto * decisions = static_cast(pipeline.context.get()); + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), - const uint32_t ne = (uint32_t) ggml_nelements(dst); + (uint32_t) (src2->nb[1] / ggml_type_size(src2->type)), + (uint32_t) (src2->nb[2] / ggml_type_size(src2->type)), + (uint32_t) (src2->nb[3] / ggml_type_size(src2->type)), - std::vector params = { - ne, - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - // Strides (in elements) - (uint32_t) (src->nb[0] / ggml_type_size(src->type)), - (uint32_t) (src->nb[1] / ggml_type_size(src->type)), - (uint32_t) (src->nb[2] / ggml_type_size(src->type)), - (uint32_t) (src->nb[3] / ggml_type_size(src->type)), - // Shapes - (uint32_t) src->ne[0], - (uint32_t) src->ne[1], - (uint32_t) src->ne[2], - (uint32_t) src->ne[3], - (uint32_t) dst->ne[0], - (uint32_t) dst->ne[1], - (uint32_t) dst->ne[2], - (uint32_t) dst->ne[3], - // Pad sizes - (uint32_t) ggml_get_op_params_i32(dst, 0), - (uint32_t) ggml_get_op_params_i32(dst, 1), - (uint32_t) ggml_get_op_params_i32(dst, 2), - (uint32_t) ggml_get_op_params_i32(dst, 3), - (uint32_t) ggml_get_op_params_i32(dst, 4), - (uint32_t) ggml_get_op_params_i32(dst, 5), - (uint32_t) ggml_get_op_params_i32(dst, 6), - (uint32_t) ggml_get_op_params_i32(dst, 7), + (uint32_t) (src4->nb[1] / ggml_type_size(src4->type)), + (uint32_t) (src4->nb[2] / ggml_type_size(src4->type)), + (uint32_t) (src4->nb[3] / ggml_type_size(src4->type)), + + (uint32_t) src0->ne[1], + (uint32_t) (src2->ne[3] / src0->ne[3]), + scale_u32, }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2), ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src4), ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, src5), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 6, dst), }; - uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, h, n_seqs); } -static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, - ggml_tensor * src, - ggml_tensor * idx, - ggml_tensor * dst) { +static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, + ggml_tensor * src, + ggml_tensor * idx, + ggml_tensor * dst) { // For set rows specifically, we need to check if src and idx are empty // tensors. if (ggml_is_empty(src) || ggml_is_empty(idx)) { return std::nullopt; } - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = idx, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = idx; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_set_rows_pipeline(shader_lib_ctx); @@ -985,25 +1307,14 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(idx), - .offset = ggml_webgpu_tensor_align_offset(ctx, idx), - .size = ggml_webgpu_tensor_binding_size(ctx, idx) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, idx), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), }; if (decisions->i64_idx) { - entries.push_back({ .binding = 3, - .buffer = ctx->set_rows_dev_error_buf, - .offset = 0, - .size = ctx->set_rows_dev_error_buf.GetSize() }); + entries.push_back(ggml_webgpu_make_bind_group_entry(3, ctx->set_rows_dev_error_buf, 0, + ctx->set_rows_dev_error_buf.GetSize())); } uint32_t threads; @@ -1013,7 +1324,7 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; } uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, 1); } // Workgroup size is a common constant @@ -1024,16 +1335,17 @@ static std::vector ggml_webgpu_wg_size_entry(uint32_t wg_si return constants; } -static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, - ggml_tensor * src, - ggml_tensor * idx, - ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, - .max_wg_size = WEBGPU_MAX_WG_SIZE, - }; +static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, + ggml_tensor * src, + ggml_tensor * idx, + ggml_tensor * dst) { + const bool float_parallel = src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16 || src->type == GGML_TYPE_I32; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = WEBGPU_MAX_WG_SIZE; webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -1057,30 +1369,22 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2]) }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(idx), - .offset = ggml_webgpu_tensor_align_offset(ctx, idx), - .size = ggml_webgpu_tensor_binding_size(ctx, idx) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, idx), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst) }; - uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->wg_size); + uint32_t blocks_per_row = (uint32_t) (dst->ne[0] / (src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0 ? 4 : 1)); + uint32_t total_rows = (uint32_t) (dst->ne[1] * dst->ne[2] * dst->ne[3]); + uint32_t total_threads = float_parallel ? blocks_per_row * total_rows : total_rows; + uint32_t wg_x = CEIL_DIV(total_threads, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { // Determine if this is a mat-vec operation bool is_vec = (dst->ne[1] == 1); @@ -1100,16 +1404,24 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: - case GGML_TYPE_Q8_1: case GGML_TYPE_Q6_K: - use_fast = true; - break; - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: - // we don't have fast mat-vec for these types, but we do have (semi) fast mat-mat - use_fast = !is_vec; + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q1_0: + use_fast = true; + break; + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + use_fast = true; break; default: break; @@ -1119,17 +1431,18 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, break; } - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix, - .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, - .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, - .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, - .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; + shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix; + shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; + shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; + shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; // Get or create pipeline webgpu_pipeline pipeline; @@ -1164,18 +1477,9 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, // Build bind group entries std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), }; // Calculate workgroup dimensions @@ -1219,22 +1523,212 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); } - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } -#ifndef __EMSCRIPTEN__ -static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, - ggml_tensor * Q, - ggml_tensor * K, - ggml_tensor * V, - ggml_tensor * mask, - ggml_tensor * sinks, - ggml_tensor * dst) { - float scale = *(float *) dst->op_params; - float max_bias; - memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); - float logit_softcap; - memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float)); +static webgpu_encoded_op ggml_webgpu_mul_mat_id_vec(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { + const uint32_t param_n_expert = (uint32_t) src0->ne[2]; + const uint32_t param_n_expert_used = (uint32_t) dst->ne[1]; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.dst = dst; + shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_mul_mat_id_vec_pipeline(shader_lib_ctx); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) src0->ne[0], + (uint32_t) src0->ne[1], + param_n_expert, + param_n_expert_used, + (uint32_t) src1->ne[1], + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + }; + + std::vector entries = { + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), ggml_webgpu_tensor_align_offset(ctx, src0), + ggml_webgpu_tensor_binding_size(ctx, src0)), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), ggml_webgpu_tensor_align_offset(ctx, src1), + ggml_webgpu_tensor_binding_size(ctx, src1)), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(src2), ggml_webgpu_tensor_align_offset(ctx, src2), + ggml_webgpu_tensor_binding_size(ctx, src2)), + ggml_webgpu_make_bind_group_entry(3, ggml_webgpu_tensor_buf(dst), ggml_webgpu_tensor_align_offset(ctx, dst), + ggml_webgpu_tensor_binding_size(ctx, dst)), + }; + + uint32_t wg_x = 1; + uint32_t wg_y = 1; + + auto * decisions = static_cast(pipeline.context.get()); + + const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg); + uint32_t total_wg = output_groups * param_n_expert_used; + compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); +} + +static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { + // we can use mat-vec fast path + if (dst->ne[2] == 1) { + return ggml_webgpu_mul_mat_id_vec(ctx, src0, src1, src2, dst); + } + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + // Get or create pipeline + webgpu_pipeline gather_pipeline; + webgpu_pipeline main_pipeline; + + std::vector dispatches; + + gather_pipeline = ctx->shader_lib->get_mul_mat_id_gather_pipeline(shader_lib_ctx); + main_pipeline = ctx->shader_lib->get_mul_mat_id_pipeline(shader_lib_ctx); + + const uint32_t param_n_expert = (uint32_t) src0->ne[2]; + const uint32_t param_n_expert_used = (uint32_t) dst->ne[1]; + const uint32_t param_n_tokens = (uint32_t) dst->ne[2]; + + // params for mul_mat_id_gather.wgsl + std::vector gather_params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)), + param_n_expert, + param_n_expert_used, + param_n_tokens, + (uint32_t) (src2->nb[1] / ggml_type_size(src2->type)), + }; + + const size_t dst_offset = ggml_webgpu_tensor_offset(dst); + const size_t gathered_buf_nbytes = src0->ne[2] * src1->ne[2] * sizeof(uint32_t); + + const size_t gathered_expert_used_align_offset = ROUNDUP_POW2( + dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + const size_t gathered_tokens_align_offset = + ROUNDUP_POW2(gathered_expert_used_align_offset + gathered_buf_nbytes, + ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + const size_t gathered_count_ids_align_offset = + ROUNDUP_POW2(gathered_tokens_align_offset + gathered_buf_nbytes, + ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + + const size_t gathered_binding_size = ROUNDUP_POW2(gathered_buf_nbytes, WEBGPU_STORAGE_BUF_BINDING_MULT); + const size_t gathered_count_ids_binding_size = + ROUNDUP_POW2(src0->ne[2] * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); + + // bind group entries for mul_mat_id_gather.wgsl + std::vector gather_entries = { + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src2), ggml_webgpu_tensor_align_offset(ctx, src2), + ggml_webgpu_tensor_binding_size(ctx, src2)), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), gathered_expert_used_align_offset, + gathered_binding_size), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(dst), gathered_tokens_align_offset, + gathered_binding_size), + ggml_webgpu_make_bind_group_entry(3, ggml_webgpu_tensor_buf(dst), gathered_count_ids_align_offset, + gathered_count_ids_binding_size), + }; + + const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + + const uint32_t gather_total_wg = param_n_expert; + const uint32_t gather_wg_x = std::min(gather_total_wg, max_wg_per_dim); + const uint32_t gather_wg_y = CEIL_DIV(gather_total_wg, gather_wg_x); + + dispatches.push_back({ + gather_pipeline, std::move(gather_params), std::move(gather_entries), { gather_wg_x, gather_wg_y } + }); + + // params for mul_mat_id.wgsl + std::vector main_params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) src0->ne[0], + (uint32_t) src0->ne[1], + param_n_expert, + param_n_expert_used, + param_n_tokens, + (uint32_t) src1->ne[1], + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + }; + + // bind group entries for mul_mat_id.wgsl + std::vector main_entries = { + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), ggml_webgpu_tensor_align_offset(ctx, src0), + ggml_webgpu_tensor_binding_size(ctx, src0)), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), ggml_webgpu_tensor_align_offset(ctx, src1), + ggml_webgpu_tensor_binding_size(ctx, src1)), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(dst), ggml_webgpu_tensor_align_offset(ctx, dst), + ggml_webgpu_tensor_binding_size(ctx, dst)), + ggml_webgpu_make_bind_group_entry(3, ggml_webgpu_tensor_buf(dst), gathered_expert_used_align_offset, + gathered_binding_size), + ggml_webgpu_make_bind_group_entry(4, ggml_webgpu_tensor_buf(dst), gathered_tokens_align_offset, + gathered_binding_size), + ggml_webgpu_make_bind_group_entry(5, ggml_webgpu_tensor_buf(dst), gathered_count_ids_align_offset, + gathered_count_ids_binding_size), + }; + + // Calculate workgroup dimensions + uint32_t wg_x = 1; + uint32_t wg_y = 1; + + auto * main_decisions = static_cast(main_pipeline.context.get()); + + uint32_t wg_m; + + uint32_t tile_m_s = main_decisions->tile_m * main_decisions->wg_size_m; + uint32_t tile_n_s = main_decisions->tile_n * main_decisions->wg_size_n; + wg_m = CEIL_DIV(dst->ne[0], tile_m_s); + uint32_t total_gathered = dst->ne[1] * dst->ne[2]; + uint32_t max_active_experts = std::min((uint32_t) src0->ne[2], total_gathered); + uint32_t max_wg_n = CEIL_DIV(total_gathered, tile_n_s) + max_active_experts; + uint32_t total_wg = wg_m * max_wg_n; + + compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); + + dispatches.push_back({ + main_pipeline, std::move(main_params), std::move(main_entries), { wg_x, wg_y } + }); + + return ggml_backend_webgpu_build_multi(ctx, dispatches); +} + +static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, + ggml_tensor * Q, + ggml_tensor * K, + ggml_tensor * V, + ggml_tensor * mask, + ggml_tensor * sinks, + ggml_tensor * dst) { + float scale = ggml_get_op_params_f32(dst, 0); + float max_bias = ggml_get_op_params_f32(dst, 1); + float logit_softcap = ggml_get_op_params_f32(dst, 2); if (logit_softcap != 0.0f) { scale /= logit_softcap; } @@ -1242,13 +1736,44 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, float m0 = powf(2.0f, -(max_bias) / n_head_log2); float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - const int has_mask = (mask != nullptr); - const int has_sinks = (sinks != nullptr); + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = Q; + shader_lib_ctx.src1 = K; + shader_lib_ctx.src2 = V; + shader_lib_ctx.src3 = mask; + shader_lib_ctx.src4 = sinks; + shader_lib_ctx.dst = dst; + shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; + shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; + shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; + shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; + webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline( + shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + auto * decisions = static_cast(pipeline.context.get()); + const int has_mask = (mask != nullptr); + const int has_sinks = (sinks != nullptr); + const bool kv_overlap = decisions->kv_overlap; + + uint32_t offset_k = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)); + uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)); + size_t kv_bind_offset = 0; + size_t kv_bind_size = 0; + if (kv_overlap) { + const ggml_webgpu_merged_binding_range merged_range = ggml_webgpu_tensor_merged_binding_range(ctx, { K, V }); + kv_bind_offset = merged_range.offset; + kv_bind_size = merged_range.size; + offset_k = ggml_webgpu_tensor_merged_element_offset(K, merged_range); + offset_v = ggml_webgpu_tensor_merged_element_offset(V, merged_range); + } std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)), + offset_k, + offset_v, has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0, has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), @@ -1266,86 +1791,225 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, (uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3 has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3 (uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA) - *(uint32_t *) &scale, // scale (possibly adjusted for logit softcap) - *(uint32_t *) &max_bias, - *(uint32_t *) &logit_softcap, - *(uint32_t *) &n_head_log2, - *(uint32_t *) &m0, - *(uint32_t *) &m1 + ggml_webgpu_u32_from_f32(scale), // scale (possibly adjusted for logit softcap) + ggml_webgpu_u32_from_f32(max_bias), + ggml_webgpu_u32_from_f32(logit_softcap), + ggml_webgpu_u32_from_f32(n_head_log2), + ggml_webgpu_u32_from_f32(m0), + ggml_webgpu_u32_from_f32(m1) }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(Q), - .offset = ggml_webgpu_tensor_align_offset(ctx, Q), - .size = ggml_webgpu_tensor_binding_size(ctx, Q) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(K), - .offset = ggml_webgpu_tensor_align_offset(ctx, K), - .size = ggml_webgpu_tensor_binding_size(ctx, K) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(V), - .offset = ggml_webgpu_tensor_align_offset(ctx, V), - .size = ggml_webgpu_tensor_binding_size(ctx, V) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, Q), }; - uint32_t binding_index = 3; + if (kv_overlap) { + entries.push_back( + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), kv_bind_offset, kv_bind_size)); + } else { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V)); + } + uint32_t binding_index = kv_overlap ? 2u : 3u; if (has_mask) { - entries.push_back({ .binding = binding_index++, - .buffer = ggml_webgpu_tensor_buf(mask), - .offset = ggml_webgpu_tensor_align_offset(ctx, mask), - .size = ggml_webgpu_tensor_binding_size(ctx, mask) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, mask)); } if (has_sinks) { - entries.push_back({ .binding = binding_index++, - .buffer = ggml_webgpu_tensor_buf(sinks), - .offset = ggml_webgpu_tensor_align_offset(ctx, sinks), - .size = ggml_webgpu_tensor_binding_size(ctx, sinks) }); - } - entries.push_back({ .binding = binding_index++, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = Q, - .src1 = K, - .src2 = V, - .src3 = mask, - .src4 = sinks, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, - .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, - .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, - .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, - .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size, + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, sinks)); + } + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, dst)); + + if (decisions->path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { + uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); + uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + } + + wgpu::Buffer blk_buf = {}; + uint64_t blk_size_bytes = 0; + uint32_t blk_nblk0 = 0; + uint32_t blk_nblk1 = 0; + uint32_t blk_batch_count = 0; + + const uint32_t vec_nwg_cap = std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); + uint32_t nwg = 1u; + const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile); + while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { + nwg <<= 1; + } + nwg = std::min(nwg, vec_nwg_cap); + const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; + const bool use_vec_reduce = nwg > 1u; + GGML_ASSERT(nrows <= UINT32_MAX); + + uint64_t tmp_stats_base = 0; + uint64_t tmp_size_bytes = 0; + wgpu::Buffer tmp_buf = {}; + uint64_t tmp_bind_offset = 0; + uint64_t tmp_bind_size = 0; + const size_t align_bytes = ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; + const size_t dst_offset = ggml_webgpu_tensor_offset(dst); + size_t scratch_offset = ROUNDUP_POW2(dst_offset + ggml_nbytes(dst), align_bytes); + + if (use_vec_reduce) { + const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; + const uint64_t tmp_stats_elems = nrows * 2u * nwg; + tmp_stats_base = tmp_data_elems; + tmp_size_bytes = + ROUNDUP_POW2((tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); + GGML_ASSERT(tmp_stats_base <= UINT32_MAX); + tmp_buf = ggml_webgpu_tensor_buf(dst); + tmp_bind_offset = scratch_offset; + tmp_bind_size = tmp_size_bytes; + scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes); + } else { + // nwg==1 writes final dst directly in vec-split; bind tmp to a tiny non-overlapping scratch region. + tmp_size_bytes = WEBGPU_STORAGE_BUF_BINDING_MULT; + tmp_buf = ggml_webgpu_tensor_buf(dst); + tmp_bind_offset = scratch_offset; + tmp_bind_size = tmp_size_bytes; + scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes); + } + + webgpu_pipeline blk_pipeline; + std::vector blk_params; + std::vector blk_entries; + if (has_mask) { + blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile); + blk_nblk1 = (uint32_t) Q->ne[1]; + blk_buf = ggml_webgpu_tensor_buf(dst); + const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); + blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; + const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; + blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); + const ggml_webgpu_shader_lib_context blk_shader_ctx = shader_lib_ctx; + blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx, decisions->kv_tile); + + blk_params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), // offset_mask + (uint32_t) Q->ne[1], // seq_len_q + (uint32_t) K->ne[1], // seq_len_kv + stride_mask3, // stride_mask3 + blk_nblk0, // nblk0 + blk_nblk1, // nblk1 + }; + blk_entries = { + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(mask), + ggml_webgpu_tensor_align_offset(ctx, mask), + ggml_webgpu_tensor_binding_size(ctx, mask)), + ggml_webgpu_make_bind_group_entry(1, blk_buf, scratch_offset, blk_size_bytes), + }; + scratch_offset = ROUNDUP_POW2(scratch_offset + blk_size_bytes, align_bytes); + } + + std::vector split_params = params; + if (has_mask) { + split_params.push_back(0u); // blk_base + split_params.push_back(blk_nblk0); // blk_nblk0 + split_params.push_back(blk_nblk1); // blk_nblk1 + } + split_params.push_back(0u); // tmp_data_base + split_params.push_back((uint32_t) tmp_stats_base); // tmp_stats_base + split_params.push_back(nwg); // nwg + + std::vector split_entries = { + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(Q), ggml_webgpu_tensor_align_offset(ctx, Q), + ggml_webgpu_tensor_binding_size(ctx, Q)), }; + if (kv_overlap) { + split_entries.push_back( + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), kv_bind_offset, kv_bind_size)); + } else { + split_entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), + ggml_webgpu_tensor_align_offset(ctx, K), + ggml_webgpu_tensor_binding_size(ctx, K))); + split_entries.push_back(ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(V), + ggml_webgpu_tensor_align_offset(ctx, V), + ggml_webgpu_tensor_binding_size(ctx, V))); + } + uint32_t split_binding_index = kv_overlap ? 2u : 3u; + if (has_mask) { + split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(mask), + ggml_webgpu_tensor_align_offset(ctx, mask), + ggml_webgpu_tensor_binding_size(ctx, mask))); + } + if (has_sinks) { + split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(sinks), + ggml_webgpu_tensor_align_offset(ctx, sinks), + ggml_webgpu_tensor_binding_size(ctx, sinks))); + } + if (has_mask) { + split_entries.push_back( + ggml_webgpu_make_bind_group_entry(split_binding_index++, blk_buf, blk_entries[1].offset, blk_size_bytes)); + } + split_entries.push_back( + ggml_webgpu_make_bind_group_entry(split_binding_index++, tmp_buf, tmp_bind_offset, tmp_bind_size)); + split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(dst), + ggml_webgpu_tensor_align_offset(ctx, dst), + ggml_webgpu_tensor_binding_size(ctx, dst))); + + webgpu_pipeline reduce_pipeline; + std::vector reduce_params; + std::vector reduce_entries; + if (use_vec_reduce) { + const uint32_t reduce_wg_size = std::max( + 32u, std::min(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); + ggml_webgpu_shader_lib_context reduce_shader_ctx = shader_lib_ctx; + reduce_shader_ctx.max_wg_size = reduce_wg_size; + reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx); + + reduce_params = { + (uint32_t) nrows, // nrows + (uint32_t) Q->ne[1], // seq_len_q + (uint32_t) Q->ne[2], // n_heads + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), // offset_dst + nwg, // nwg + 0u, // tmp_data_base + (uint32_t) tmp_stats_base, // tmp_stats_base + }; + + reduce_entries = { + ggml_webgpu_make_bind_group_entry(0, tmp_buf, tmp_bind_offset, tmp_size_bytes), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), ggml_webgpu_tensor_align_offset(ctx, dst), + ggml_webgpu_tensor_binding_size(ctx, dst)), + }; + } - webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx); + uint32_t wg_x = Q->ne[1] * Q->ne[2] * Q->ne[3]; + const uint64_t split_wg_total = (uint64_t) wg_x * nwg; + GGML_ASSERT(split_wg_total <= UINT32_MAX); - auto * decisions = static_cast(pipeline.context.get()); + std::vector dispatches; - uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); - uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + if (has_mask) { + dispatches.push_back({ + blk_pipeline, std::move(blk_params), std::move(blk_entries), { blk_nblk0, blk_nblk1 * blk_batch_count } + }); + } + dispatches.push_back({ + pipeline, std::move(split_params), std::move(split_entries), { (uint32_t) split_wg_total, 1u } + }); + if (use_vec_reduce) { + dispatches.push_back({ + reduce_pipeline, std::move(reduce_params), std::move(reduce_entries), { (uint32_t) nrows, 1u } + }); + } + + return ggml_backend_webgpu_build_multi(ctx, dispatches); } -#endif -static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool is_unary = dst->op == GGML_OP_UNARY; - bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL); - - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = inplace, - }; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx); - auto * decisions = static_cast(pipeline.context.get()); + auto * decisions = static_cast(pipeline.context.get()); + const bool inplace = decisions->inplace; uint32_t ne = (uint32_t) ggml_nelements(dst); @@ -1372,10 +2036,10 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s float alpha_p = ggml_get_op_params_f32(dst, 2); float beta = ggml_get_op_params_f32(dst, 3); float eps = ggml_get_op_params_f32(dst, 4); - params.push_back(*reinterpret_cast(&alpha_n)); - params.push_back(*reinterpret_cast(&alpha_p)); - params.push_back(*reinterpret_cast(&beta)); - params.push_back(*reinterpret_cast(&eps)); + params.push_back(ggml_webgpu_u32_from_f32(alpha_n)); + params.push_back(ggml_webgpu_u32_from_f32(alpha_p)); + params.push_back(ggml_webgpu_u32_from_f32(beta)); + params.push_back(ggml_webgpu_u32_from_f32(eps)); break; } default: @@ -1384,71 +2048,61 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s } else if (dst->op == GGML_OP_CLAMP) { float clamp_min = ggml_get_op_params_f32(dst, 0); float clamp_max = ggml_get_op_params_f32(dst, 1); - params.push_back(*reinterpret_cast(&clamp_min)); - params.push_back(*reinterpret_cast(&clamp_max)); + params.push_back(ggml_webgpu_u32_from_f32(clamp_min)); + params.push_back(ggml_webgpu_u32_from_f32(clamp_max)); } else if (dst->op == GGML_OP_FILL) { float fill_val = ggml_get_op_params_f32(dst, 0); - params.push_back(*reinterpret_cast(&fill_val)); + params.push_back(ggml_webgpu_u32_from_f32(fill_val)); effective_src = dst; // fill simply fills dst } std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(effective_src), - .offset = ggml_webgpu_tensor_align_offset(ctx, effective_src), - .size = ggml_webgpu_tensor_binding_size(ctx, effective_src) }, + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, effective_src), }; if (!inplace) { - entries.push_back({ .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { - binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst); - - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = flags.inplace, - .overlap = flags.overlap, - .src_overlap = flags.src_overlap, - }; - - webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx); +static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - auto * decisions = static_cast(pipeline.context.get()); + webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); uint32_t ne = (uint32_t) ggml_nelements(dst); size_t src0_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src0); size_t src1_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src1); - uint32_t offset_merged_src0 = 0; - uint32_t offset_merged_src1 = 0; - if (flags.src_overlap) { - size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset); - offset_merged_src0 = (uint32_t) ((src0_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type)); - offset_merged_src1 = (uint32_t) ((src1_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type)); + uint32_t offset_src0 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)); + uint32_t offset_src1 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)); + size_t merged_offset = 0; + size_t merged_size = 0; + if (decisions->src_overlap) { + const ggml_webgpu_merged_binding_range merged_range = + ggml_webgpu_tensor_merged_binding_range(ctx, { src0, src1 }); + merged_offset = merged_range.offset; + merged_size = merged_range.size; + offset_src0 = ggml_webgpu_tensor_merged_element_offset(src0, merged_range); + offset_src1 = ggml_webgpu_tensor_merged_element_offset(src1, merged_range); } std::vector params = { ne, - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + offset_src0, + offset_src1, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - offset_merged_src0, - offset_merged_src1, (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), @@ -1468,53 +2122,30 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, std::vector entries; - if (flags.src_overlap) { - size_t merged_offset = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset); - size_t merged_end = std::max(src0_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src0), - src1_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src1)); - entries.push_back({ - .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = merged_offset, - .size = merged_end - merged_offset, - }); - entries.push_back({ - .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst), - }); + if (decisions->src_overlap) { + entries.push_back( + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset, merged_size)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } else { - entries.push_back({ - .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = src0_webgpu_tensor_align_offset, - .size = ggml_webgpu_tensor_binding_size(ctx, src0), - }); - entries.push_back({ - .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = src1_webgpu_tensor_align_offset, - .size = ggml_webgpu_tensor_binding_size(ctx, src1), - }); - if (!flags.inplace && !flags.overlap) { - entries.push_back({ - .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst), - }); + entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), + src0_webgpu_tensor_align_offset, + ggml_webgpu_tensor_binding_size(ctx, src0))); + entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), + src1_webgpu_tensor_align_offset, + ggml_webgpu_tensor_binding_size(ctx, src1))); + if (!decisions->inplace && !decisions->overlap) { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst)); } } uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_concat(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { uint32_t ne = (uint32_t) ggml_nelements(dst); uint32_t dim = (uint32_t) dst->op_params[0]; @@ -1540,34 +2171,24 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx, }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * dst) { uint32_t ne = (uint32_t) ggml_nelements(dst); std::vector params = { ne, @@ -1587,31 +2208,106 @@ static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src (uint32_t) (dst->ne[2]) }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst), }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_repeat_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - int inplace = ggml_webgpu_tensor_equal(src, dst); +static std::optional ggml_webgpu_rms_norm_mul(webgpu_context & ctx, + ggml_tensor * rn_src, + ggml_tensor * rn_dst, + ggml_tensor * mul_src0, + ggml_tensor * mul_src1, + ggml_tensor * dst) { + ggml_tensor * mul_src; + + if (ggml_webgpu_tensor_equal(rn_dst, mul_src0)) { + mul_src = mul_src1; + } else if (ggml_webgpu_tensor_equal(rn_dst, mul_src1)) { + mul_src = mul_src0; + } else { + GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1"); + } + + uint32_t offset_rn_src = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, rn_src) / ggml_type_size(rn_src->type)); + uint32_t offset_mul_src = + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mul_src) / ggml_type_size(mul_src->type)); + size_t merged_offset = 0; + size_t merged_size = 0; + + std::vector params = { + offset_rn_src, + offset_mul_src, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (rn_src->nb[1] / ggml_type_size(rn_src->type)), + (uint32_t) (rn_src->nb[2] / ggml_type_size(rn_src->type)), + (uint32_t) (rn_src->nb[3] / ggml_type_size(rn_src->type)), + (uint32_t) (mul_src->nb[1] / ggml_type_size(mul_src->type)), + (uint32_t) (mul_src->nb[2] / ggml_type_size(mul_src->type)), + (uint32_t) (mul_src->nb[3] / ggml_type_size(mul_src->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + (uint32_t) mul_src->ne[0], + (uint32_t) mul_src->ne[1], + (uint32_t) mul_src->ne[2], + (uint32_t) mul_src->ne[3], + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(rn_dst, 0)) // epsilon, treated as f32 in the shader + }; + + std::vector entries; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = rn_src; + shader_lib_ctx.src1 = mul_src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + + if (decisions->src_overlap) { + const ggml_webgpu_merged_binding_range merged_range = + ggml_webgpu_tensor_merged_binding_range(ctx, { rn_src, mul_src }); + merged_offset = merged_range.offset; + merged_size = merged_range.size; + offset_rn_src = ggml_webgpu_tensor_merged_element_offset(rn_src, merged_range); + offset_mul_src = ggml_webgpu_tensor_merged_element_offset(mul_src, merged_range); + params[0] = offset_rn_src; + params[1] = offset_mul_src; + } + + if (decisions->inplace || decisions->overlap) { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src)); + } else if (decisions->src_overlap) { + entries.push_back( + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(rn_src), merged_offset, merged_size)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); + } else { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst)); + } + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(dst)); +} +static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), @@ -1625,33 +2321,42 @@ static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * s (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3], - *(uint32_t *) dst->op_params // epsilon, treated as f32 in the shader + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 0)) // epsilon, treated as f32 in the shader }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) } - }; - if (!inplace) { - entries.push_back({ .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - } + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->rms_norm_pipelines[inplace], params, - entries, ggml_nrows(src)); + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) }; + if (!decisions->inplace) { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); + } + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(src)); } -static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * dst) { - const int inplace = ggml_webgpu_tensor_equal(src0, dst); - const int has_freq_factor = (src2 != nullptr); +static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_rope_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + + const bool inplace = decisions->inplace; + const int has_freq_factor = (src2 != nullptr); const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; @@ -1695,49 +2400,47 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, (uint32_t) src0->ne[2], (uint32_t) n_dims, (uint32_t) mode, - *(uint32_t *) &theta_scale, - *(uint32_t *) &attn_factor, - *(uint32_t *) &freq_scale, - *(uint32_t *) &ext_factor, - *(uint32_t *) &corr_dims[0], - *(uint32_t *) &corr_dims[1], + ggml_webgpu_u32_from_f32(theta_scale), + ggml_webgpu_u32_from_f32(attn_factor), + ggml_webgpu_u32_from_f32(freq_scale), + ggml_webgpu_u32_from_f32(ext_factor), + ggml_webgpu_u32_from_f32(corr_dims[0]), + ggml_webgpu_u32_from_f32(corr_dims[1]), (uint32_t) sections[0], (uint32_t) sections[1], (uint32_t) sections[2], (uint32_t) sections[3] }; - - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) } - }; - uint32_t dst_binding = 2; + + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1) }; + uint32_t dst_binding = 2; if (has_freq_factor) { dst_binding = 3; - entries.push_back({ .binding = 2, - .buffer = ggml_webgpu_tensor_buf(src2), - .offset = ggml_webgpu_tensor_align_offset(ctx, src2), - .size = ggml_webgpu_tensor_binding_size(ctx, src2) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2)); } if (!inplace) { - entries.push_back({ .binding = dst_binding, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, dst_binding, dst)); } - webgpu_pipeline pipeline = ctx->rope_pipelines[dst->type][has_freq_factor][inplace]; - uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_glu_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + const int split = (src1 != nullptr); std::vector params = { @@ -1760,45 +2463,31 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], - (uint32_t) ((int32_t *) dst->op_params)[1], // swapped - *(uint32_t *) &dst->op_params[2], // alpha, for swiglu_oai - *(uint32_t *) &dst->op_params[3], // limit, for swiglu_oai + (uint32_t) ((int32_t *) dst->op_params)[1], // swapped + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 2)), // alpha, for swiglu_oai + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 3)), // limit, for swiglu_oai }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), }; uint32_t dst_binding = 1; if (split) { dst_binding = 2; - entries.push_back({ .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }); - } - entries.push_back({ .binding = dst_binding, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - - webgpu_pipeline pipeline = ctx->glu_pipelines[ggml_get_glu_op(dst)][dst->type][split]; - uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); -} - -static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - bool inplace = ggml_webgpu_tensor_equal(src, dst); - - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = inplace, - }; + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1)); + } + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, dst_binding, dst)); + + uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); +} + +static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_scale_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -1817,54 +2506,55 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], - *(uint32_t *) dst->op_params, // scale - *(uint32_t *) &dst->op_params[1] // bias + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 0)), // scale + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 1)) // bias }; // bindgroups unchanged - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) }; - if (!inplace) { - entries.push_back({ .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + if (!decisions->inplace) { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); -} - -static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * dst) { - const int inplace = ggml_webgpu_tensor_equal(src0, dst); - const int mask_type = (src1 != nullptr) ? src1->type : 2; // use 2 for no mask here - const int has_sink = (src2 != nullptr); - float max_bias; - memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); - float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2]))); - float m0 = powf(2.0f, -(max_bias) / n_head_log2); - float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); +} + +static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_soft_max_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + + const bool inplace = decisions->inplace; + const int has_mask = (src1 != nullptr); + const int has_sink = (src2 != nullptr); + float max_bias = ggml_get_op_params_f32(dst, 1); + float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2]))); + float m0 = powf(2.0f, -(max_bias) / n_head_log2); + float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), - mask_type < 2 ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0, + has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0, has_sink ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), - mask_type < 2 ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0, - mask_type < 2 ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0, - mask_type < 2 ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0, + has_mask ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0, + has_mask ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0, + has_mask ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0, (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), @@ -1872,83 +2562,63 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1], (uint32_t) src0->ne[2], - mask_type < 2 ? (uint32_t) src1->ne[2] : 0, - mask_type < 2 ? (uint32_t) src1->ne[3] : 0, - *(uint32_t *) dst->op_params, // scale - *(uint32_t *) &max_bias, - *(uint32_t *) &n_head_log2, - *(uint32_t *) &m0, - *(uint32_t *) &m1 + has_mask ? (uint32_t) src1->ne[2] : 0, + has_mask ? (uint32_t) src1->ne[3] : 0, + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 0)), // scale + ggml_webgpu_u32_from_f32(max_bias), + ggml_webgpu_u32_from_f32(n_head_log2), + ggml_webgpu_u32_from_f32(m0), + ggml_webgpu_u32_from_f32(m1) }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) } - }; - uint32_t binding_num = 1; - if (mask_type < 2) { - entries.push_back({ .binding = binding_num, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }); + std::vector entries = { ggml_webgpu_make_bind_group_entry( + 0, ggml_webgpu_tensor_buf(src0), ggml_webgpu_tensor_align_offset(ctx, src0), + ggml_webgpu_tensor_binding_size(ctx, src0)) }; + uint32_t binding_num = 1; + if (has_mask) { + entries.push_back(ggml_webgpu_make_bind_group_entry(binding_num, ggml_webgpu_tensor_buf(src1), + ggml_webgpu_tensor_align_offset(ctx, src1), + ggml_webgpu_tensor_binding_size(ctx, src1))); binding_num++; } if (has_sink) { - entries.push_back({ .binding = binding_num, - .buffer = ggml_webgpu_tensor_buf(src2), - .offset = ggml_webgpu_tensor_align_offset(ctx, src2), - .size = ggml_webgpu_tensor_binding_size(ctx, src2) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_num, src2)); binding_num++; } if (!inplace) { - entries.push_back({ .binding = binding_num, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_num, dst)); } - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, - ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries, - ggml_nrows(dst)); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(dst)); } -static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) src->ne[0] }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst) }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_argmax_pipeline(shader_lib_ctx); uint32_t wg_x = ggml_nelements(dst); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool is_top_k = dst->op == GGML_OP_TOP_K; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; webgpu_pipeline argsort_pipeline = ctx->shader_lib->get_argsort_pipeline(shader_lib_ctx); auto * argsort_decisions = static_cast(argsort_pipeline.context.get()); @@ -1997,10 +2667,7 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr const uint32_t stride_idx2 = out_ne0 * (uint32_t) dst->ne[1]; const uint32_t stride_idx3 = stride_idx2 * (uint32_t) dst->ne[2]; - std::vector pipelines; - std::vector> params_list; - std::vector> entries_list; - std::vector> workgroups_list; + std::vector dispatches; const uint32_t init_offset = start_in_tmp ? offset_tmp : offset_dst; const size_t init_align_offset = start_in_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst); @@ -2017,21 +2684,16 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr const uint32_t wg_x_init = std::min(total_wg_init, max_wg); const uint32_t wg_y_init = CEIL_DIV(total_wg_init, wg_x_init); std::vector init_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = init_align_offset, .size = init_binding_size } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), init_align_offset, init_binding_size) }; - pipelines.push_back(argsort_pipeline); - params_list.push_back(std::move(init_params)); - entries_list.push_back(std::move(init_entries)); - workgroups_list.push_back({ wg_x_init, wg_y_init }); + dispatches.push_back({ + argsort_pipeline, std::move(init_params), std::move(init_entries), { wg_x_init, wg_y_init } + }); if (merge_passes == 0) { - return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, - entries_list, workgroups_list); + return ggml_backend_webgpu_build_multi(ctx, dispatches); } bool in_is_tmp = start_in_tmp; @@ -2072,59 +2734,45 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr nrows }; std::vector merge_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_in, .size = size_in }, - { .binding = 2, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_out, .size = size_out } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), align_in, size_in), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(dst), align_out, size_out) }; const uint32_t total_wg_merge = nm * nrows; const uint32_t wg_x_merge = std::min(total_wg_merge, max_wg); const uint32_t wg_y_merge = CEIL_DIV(total_wg_merge, wg_x_merge); - workgroups_list.push_back({ wg_x_merge, wg_y_merge }); - pipelines.push_back(argsort_merge_pipeline); - params_list.push_back(std::move(merge_params)); - entries_list.push_back(std::move(merge_entries)); + dispatches.push_back({ + argsort_merge_pipeline, std::move(merge_params), std::move(merge_entries), { wg_x_merge, wg_y_merge } + }); len <<= 1; in_is_tmp = !in_is_tmp; } - return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, entries_list, - workgroups_list); + return ggml_backend_webgpu_build_multi(ctx, dispatches); } -static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) src->ne[0] }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst) }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_cumsum_pipeline(shader_lib_ctx); uint32_t wg_x = ggml_nrows(dst); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool total_sum = dst->op == GGML_OP_SUM; std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), @@ -2135,36 +2783,105 @@ static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * s total_sum ? 1 : (uint32_t) src->ne[1], total_sum ? 1 : (uint32_t) src->ne[2] }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst) }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_sum_rows_pipeline(shader_lib_ctx); uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); +} + +static bool ggml_webgpu_can_fuse_rms_norm_mul(const struct ggml_cgraph * cgraph, int node_idx) { + if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + return false; + } + + // additional constraints specific to this fusion + const ggml_tensor * rms_norm = cgraph->nodes[node_idx]; + const ggml_tensor * mul = cgraph->nodes[node_idx + 1]; + + GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(rms_norm->type == GGML_TYPE_F32); + // rms_norm only supports f32 + if (mul->src[0]->type != GGML_TYPE_F32 || mul->src[1]->type != GGML_TYPE_F32 || mul->type != GGML_TYPE_F32) { + return false; + } + // if rms_norm is the B operand, then we don't handle broadcast + if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) { + return false; + } + // rms_norm shader assumes contiguous rows + if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) { + return false; + } + + return true; +} + +static webgpu_encoded_op ggml_webgpu_upscale(webgpu_context ctx, ggml_tensor * src, ggml_tensor * dst) { + const uint32_t mode_flags = (uint32_t) ggml_get_op_params_i32(dst, 0); + std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + (uint32_t) (src->nb[0] / ggml_type_size(src->type)), + (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + + (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + + (uint32_t) src->ne[0], + (uint32_t) src->ne[1], + (uint32_t) src->ne[2], + (uint32_t) src->ne[3], + + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + + mode_flags }; + + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst) }; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_upscale_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size); + uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg); + uint32_t wg_y = CEIL_DIV(total_wg, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } // Returns the encoded command, or std::nullopt if the operation is a no-op -static std::optional ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { +static std::optional ggml_webgpu_encode(webgpu_context ctx, + ggml_cgraph * cgraph, + int node_idx, + int & num_encoded_ops) { + ggml_tensor ** nodes = cgraph->nodes; + ggml_tensor * node = nodes[node_idx]; + if (ggml_is_empty(node)) { return std::nullopt; } if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { return std::nullopt; } - WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")"); + WEBGPU_LOG_DEBUG("ggml_webgpu_encode(" << node << ", " << ggml_op_name(node->op) << ")"); ggml_tensor * src0 = node->src[0]; ggml_tensor * src1 = node->src[1]; @@ -2181,18 +2898,18 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_CPY: case GGML_OP_CONT: return ggml_webgpu_cpy(ctx, src0, node); + case GGML_OP_SET: + return ggml_webgpu_set(ctx, src0, src1, node); case GGML_OP_SET_ROWS: return ggml_webgpu_set_rows(ctx, src0, src1, node); case GGML_OP_GET_ROWS: return ggml_webgpu_get_rows(ctx, src0, src1, node); case GGML_OP_MUL_MAT: return ggml_webgpu_mul_mat(ctx, src0, src1, node); + case GGML_OP_MUL_MAT_ID: + return ggml_webgpu_mul_mat_id(ctx, src0, src1, src2, node); case GGML_OP_FLASH_ATTN_EXT: -#ifndef __EMSCRIPTEN__ return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node); -#else - return std::nullopt; -#endif case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_MUL: @@ -2203,7 +2920,15 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_REPEAT: return ggml_webgpu_repeat(ctx, src0, node); case GGML_OP_RMS_NORM: - return ggml_webgpu_rms_norm(ctx, src0, node); + if (ggml_webgpu_can_fuse_rms_norm_mul(cgraph, node_idx)) { + num_encoded_ops = 2; + ggml_tensor * mul_node = nodes[node_idx + 1]; + return ggml_webgpu_rms_norm_mul(ctx, src0, node, mul_node->src[0], mul_node->src[1], mul_node); + } else { + return ggml_webgpu_row_norm(ctx, src0, node); + } + case GGML_OP_L2_NORM: + return ggml_webgpu_row_norm(ctx, src0, node); case GGML_OP_ROPE: return ggml_webgpu_rope(ctx, src0, src1, src2, node); case GGML_OP_GLU: @@ -2220,7 +2945,18 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_SQRT: case GGML_OP_SIN: case GGML_OP_COS: + case GGML_OP_DIAG: + case GGML_OP_TRI: return ggml_webgpu_unary_op(ctx, src0, node); + case GGML_OP_SOLVE_TRI: + return ggml_webgpu_solve_tri(ctx, src0, src1, node); + case GGML_OP_SSM_CONV: + return ggml_webgpu_ssm_conv(ctx, src0, src1, node); + case GGML_OP_SSM_SCAN: + return ggml_webgpu_ssm_scan(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node->src[6], + node); + case GGML_OP_GATED_DELTA_NET: + return ggml_webgpu_gated_delta_net(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node); case GGML_OP_PAD: return ggml_webgpu_pad(ctx, src0, node); case GGML_OP_ARGMAX: @@ -2234,11 +2970,73 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_SUM: case GGML_OP_SUM_ROWS: return ggml_webgpu_sum_rows(ctx, src0, node); + case GGML_OP_CONV_2D: + return ggml_webgpu_conv_2d(ctx, src0, src1, node); + case GGML_OP_IM2COL: + return ggml_webgpu_im2col(ctx, src0, src1, node); + case GGML_OP_UPSCALE: + return ggml_webgpu_upscale(ctx, src0, node); default: return std::nullopt; } } +#ifdef GGML_WEBGPU_GPU_PROFILE +static void ggml_backend_webgpu_collect_profile_results(webgpu_context & ctx, + const std::vector & pipeline_names, + uint32_t & num_inflight_batches) { + if (pipeline_names.empty()) { + return; + } + + wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder(); + encoder.ResolveQuerySet(ctx->profile_timestamp_query_set, 0, ctx->profile_timestamp_query_count, + ctx->profile_timestamp_dev_buf, 0); + encoder.CopyBufferToBuffer(ctx->profile_timestamp_dev_buf, 0, ctx->profile_timestamp_host_buf, 0, + ctx->profile_timestamp_query_count * sizeof(uint64_t)); + + wgpu::CommandBuffer profile_commands = encoder.Finish(); + ggml_backend_webgpu_submit_commands(ctx, profile_commands, num_inflight_batches); + + const size_t mapped_size = ctx->profile_timestamp_query_count * sizeof(uint64_t); + GGML_ASSERT(ctx->profile_timestamp_query_count == 2 * pipeline_names.size()); + + ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->profile_timestamp_host_buf, wgpu::MapMode::Read, 0, + mapped_size); + const uint64_t * ts_data = (const uint64_t *) ctx->profile_timestamp_host_buf.GetConstMappedRange(0, mapped_size); + + for (size_t i = 0; i < pipeline_names.size(); ++i) { + // WebGPU timestamps are in ns; convert to ms. + const double elapsed_ms = double(ts_data[2 * i + 1] - ts_data[2 * i]) * 1e-6; + ctx->shader_gpu_time_ms[pipeline_names[i]] += elapsed_ms; + } + + ctx->profile_timestamp_host_buf.Unmap(); +} +#endif + +// Don't bother checking set_rows index overflow for now, since practically the WebGPU doesn't need to support +// models that would require it right now. +static void ggml_backend_webgpu_check_set_rows(webgpu_context & ctx, uint32_t & num_inflight_batches) { +#ifdef GGML_WEBGPU_CHECK_SET_ROWS + wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder(); + encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0, + ctx->set_rows_host_error_buf.GetSize()); + wgpu::CommandBuffer commands = encoder.Finish(); + ggml_backend_webgpu_submit_commands(ctx, commands, num_inflight_batches); + ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->set_rows_host_error_buf, wgpu::MapMode::Read, 0, + ctx->set_rows_host_error_buf.GetSize()); + const uint32_t * error_data = (const uint32_t *) ctx->set_rows_host_error_buf.GetConstMappedRange(); + if (*error_data) { + GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); + } + ctx->set_rows_host_error_buf.Unmap(); +#else + GGML_UNUSED(ctx); + GGML_UNUSED(num_inflight_batches); +#endif +} + static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)"); @@ -2247,69 +3045,186 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute); - std::vector commands; - std::vector subs; - uint32_t num_batched_kernels = 0; - bool contains_set_rows = false; + std::vector commands; + + uint32_t num_batched_kernels = 0; + uint32_t num_inflight_batches = 0; + bool contains_set_rows = false; + bool batch_compute_passes = true; + int num_encoded_ops = 1; + int node_idx = 0; + +#ifdef GGML_WEBGPU_GPU_PROFILE + ctx->profile_timestamp_query_count = 0; + batch_compute_passes = false; + std::vector profile_pipeline_names; +#endif + + ctx->active_command_encoder = ctx->global_ctx->device.CreateCommandEncoder(); + if (batch_compute_passes) { + ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass(); + } - for (int i = 0; i < cgraph->n_nodes; i++) { - if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) { + while (node_idx < cgraph->n_nodes) { + if (cgraph->nodes[node_idx]->op == GGML_OP_SET_ROWS) { contains_set_rows = true; } - if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) { + if (auto cmd = ggml_webgpu_encode(ctx, cgraph, node_idx, num_encoded_ops)) { commands.push_back(*cmd); num_batched_kernels += cmd.value().num_kernels; +#ifdef GGML_WEBGPU_GPU_PROFILE + profile_pipeline_names.insert(profile_pipeline_names.end(), cmd->pipeline_names.begin(), + cmd->pipeline_names.end()); +#endif } - if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) { - num_batched_kernels = 0; - subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool)); - // Process events and check for completed submissions - ctx->global_ctx->instance.ProcessEvents(); - ggml_backend_webgpu_wait(ctx->global_ctx, subs, false); + if (num_batched_kernels >= ctx->global_ctx->command_submit_batch_size) { + if (ctx->active_compute_pass) { + ctx->active_compute_pass.End(); + } + num_batched_kernels = 0; + wgpu::CommandBuffer batch_commands = ctx->active_command_encoder.Finish(); + ggml_backend_webgpu_submit_commands(ctx, batch_commands, num_inflight_batches); + + // reset state for next batch + ctx->active_command_encoder = ctx->global_ctx->device.CreateCommandEncoder(); + if (batch_compute_passes) { + ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass(); + } + ctx->param_arena.reset(); commands.clear(); } + + node_idx += num_encoded_ops; + num_encoded_ops = 1; } - if (!commands.empty()) { - subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool)); + + if (ctx->active_compute_pass) { + ctx->active_compute_pass.End(); + ctx->active_compute_pass = nullptr; + } + + if (num_batched_kernels > 0) { + wgpu::CommandBuffer batch_commands = ctx->active_command_encoder.Finish(); + ggml_backend_webgpu_submit_commands(ctx, batch_commands, num_inflight_batches); + ctx->param_arena.reset(); commands.clear(); } + ctx->active_command_encoder = nullptr; + +#ifdef GGML_WEBGPU_GPU_PROFILE + ggml_backend_webgpu_collect_profile_results(ctx, profile_pipeline_names, num_inflight_batches); +#endif - // If there are SET_ROWS operations in this graph, copy the error buffers to the host for checking. if (contains_set_rows) { - wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder(); - encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0, - ctx->set_rows_host_error_buf.GetSize()); - wgpu::CommandBuffer set_rows_commands = encoder.Finish(); - ctx->global_ctx->queue.Submit(1, &set_rows_commands); - ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->set_rows_host_error_buf, wgpu::MapMode::Read, 0, - ctx->set_rows_host_error_buf.GetSize()); - const uint32_t * error_data = (const uint32_t *) ctx->set_rows_host_error_buf.GetConstMappedRange(); - if (*error_data) { - GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); - } - ctx->set_rows_host_error_buf.Unmap(); + ggml_backend_webgpu_check_set_rows(ctx, num_inflight_batches); } - ggml_backend_webgpu_wait(ctx->global_ctx, subs); WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx); return GGML_STATUS_SUCCESS; } +struct ggml_backend_webgpu_event_context { + webgpu_global_context global_ctx; + wgpu::Future future; + bool recorded = false; +}; + +static ggml_backend_event_t ggml_backend_webgpu_device_event_new(ggml_backend_dev_t device) { + ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) device->context; + + auto * event_ctx = new ggml_backend_webgpu_event_context(); + event_ctx->global_ctx = dev_ctx->webgpu_global_ctx; + + auto * event = new ggml_backend_event; + event->device = device; + event->context = event_ctx; + return event; +} + +static void ggml_backend_webgpu_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) { + GGML_UNUSED(dev); + delete static_cast(event->context); + delete event; +} + +static void ggml_backend_webgpu_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) { + GGML_UNUSED(dev); + ggml_backend_webgpu_event_context * event_ctx = (ggml_backend_webgpu_event_context *) event->context; + if (!event_ctx->recorded) { + return; + } + wgpu::WaitStatus status = + event_ctx->global_ctx->instance.WaitAny(event_ctx->future, WEBGPU_RUNTIME_WAIT_TIMEOUT_NS); + if (status == wgpu::WaitStatus::TimedOut) { + GGML_ABORT("ggml_webgpu: event_synchronize timed out after %u ms\n", WEBGPU_RUNTIME_WAIT_TIMEOUT_MS); + } + event_ctx->recorded = false; +} + +static void ggml_backend_webgpu_event_record(ggml_backend_t backend, ggml_backend_event_t event) { + ggml_backend_webgpu_context * backend_ctx = (ggml_backend_webgpu_context *) backend->context; + ggml_backend_webgpu_event_context * event_ctx = (ggml_backend_webgpu_event_context *) event->context; + + event_ctx->future = backend_ctx->webgpu_ctx->global_ctx->queue.OnSubmittedWorkDone( + wgpu::CallbackMode::AllowSpontaneous, [](wgpu::QueueWorkDoneStatus, wgpu::StringView) {}); + event_ctx->recorded = true; +} + +static void ggml_backend_webgpu_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { + GGML_UNUSED(backend); + ggml_backend_webgpu_device_event_synchronize(nullptr, event); +} + +static void ggml_backend_webgpu_set_tensor_async(ggml_backend_t backend, + ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size) { + GGML_UNUSED(backend); + auto * buf_ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context; + size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset; + + // Write aligned portion + buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4); + + if (size % 4 != 0) { + // If size is not a multiple of 4, we need to memset the remaining bytes + size_t remaining_size = size % 4; + + // pack the remaining bytes into a uint32_t + uint32_t val32 = 0; + + for (size_t i = 0; i < remaining_size; i++) { + ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i]; + } + // memset the remaining bytes + ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32, + total_offset + (size - remaining_size), remaining_size); + } +} + +static void ggml_backend_webgpu_synchronize(ggml_backend_t backend) { + ggml_backend_webgpu_context * backend_ctx = (ggml_backend_webgpu_context *) backend->context; + ggml_backend_webgpu_wait_queue(backend_ctx->webgpu_ctx->global_ctx); +} + static ggml_backend_i ggml_backend_webgpu_i = { /* .get_name = */ ggml_backend_webgpu_name, /* .free = */ ggml_backend_webgpu_free, - /* .set_tensor_async = */ NULL, + /* .set_tensor_async = */ ggml_backend_webgpu_set_tensor_async, /* .get_tensor_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, - /* .synchronize = */ NULL, + /* .synchronize = */ ggml_backend_webgpu_synchronize, /* .graph_plan_create = */ NULL, /* .graph_plan_free = */ NULL, /* .graph_plan_update = */ NULL, /* .graph_plan_compute = */ NULL, /* .graph_compute = */ ggml_backend_webgpu_graph_compute, - /* .event_record = */ NULL, - /* .event_wait = */ NULL, + /* .event_record = */ ggml_backend_webgpu_event_record, + /* .event_wait = */ ggml_backend_webgpu_event_wait, /* .graph_optimize = */ NULL, }; @@ -2350,7 +3265,7 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buf_ctx->label << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")"); - size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; + size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset; // This is a trick to set all bytes of a u32 to the same 1 byte value. uint32_t val32 = (uint32_t) value * 0x01010101; @@ -2369,7 +3284,7 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buf_ctx->label << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); - size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; + size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset; buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4); @@ -2386,17 +3301,6 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, // memset the remaining bytes ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size); - } else { - // wait for WriteBuffer to complete - buf_ctx->global_ctx->instance.WaitAny(buf_ctx->global_ctx->queue.OnSubmittedWorkDone( - wgpu::CallbackMode::AllowSpontaneous, - [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { - if (status != wgpu::QueueWorkDoneStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", - std::string(message).c_str()); - } - }), - UINT64_MAX); } WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, buf_ctx->global_ctx); } @@ -2412,7 +3316,7 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, << ", " << offset << ", " << size << ")"); wgpu::Device device = buf_ctx->global_ctx->device; - size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; + size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset; size_t final_size = size; if (size % 4 != 0) { @@ -2469,6 +3373,8 @@ static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = { /* .memset_tensor = */ ggml_backend_webgpu_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_webgpu_buffer_set_tensor, /* .get_tensor = */ ggml_backend_webgpu_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ NULL, // TODO: optional, implement this /* .clear = */ ggml_backend_webgpu_buffer_clear, /* .reset = */ NULL, // TODO: optional, think it coordinates with @@ -2537,6 +3443,89 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer } } break; + case GGML_OP_FLASH_ATTN_EXT: + { + const ggml_tensor * Q = tensor->src[0]; + const ggml_tensor * K = tensor->src[1]; + const ggml_tensor * V = tensor->src[2]; + const ggml_tensor * mask = tensor->src[3]; + const ggml_tensor * sinks = tensor->src[4]; + if (Q && K && V) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = const_cast(Q); + shader_lib_ctx.src1 = const_cast(K); + shader_lib_ctx.src2 = const_cast(V); + shader_lib_ctx.src3 = const_cast(mask); + shader_lib_ctx.src4 = const_cast(sinks); + shader_lib_ctx.dst = const_cast(tensor); + shader_lib_ctx.max_wg_size = + ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.wg_mem_limit_bytes = + ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups; + shader_lib_ctx.supports_subgroup_matrix = + ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix; + shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; + shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; + shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size; + + const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions( + shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + + if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { + const uint32_t kv_tile = decisions.kv_tile; + + const uint32_t vec_nwg_cap = std::max( + 1u, std::min(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size)); + uint32_t nwg = 1u; + const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile); + while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { + nwg <<= 1; + } + nwg = std::min(nwg, vec_nwg_cap); + + const size_t align = + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; + const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; + if (nwg > 1u) { + const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; + const uint64_t tmp_stats_elems = nrows * 2u * nwg; + const size_t tmp_size_bytes = ROUNDUP_POW2( + (tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); + res += tmp_size_bytes + align; + } else { + res += WEBGPU_STORAGE_BUF_BINDING_MULT + align; + } + if (mask != nullptr) { + const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile); + const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u); + const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); + const uint32_t blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; + const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; + const size_t blk_size_bytes = + ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); + res += blk_size_bytes + align; + } + res = ROUNDUP_POW2(res, WEBGPU_STORAGE_BUF_BINDING_MULT); + } + } + } + break; + case GGML_OP_MUL_MAT_ID: + { + const ggml_tensor * src0 = tensor->src[0]; + const ggml_tensor * src1 = tensor->src[1]; + if (src0 && src1) { + const size_t gathered_size = sizeof(uint32_t) * tensor->src[0]->ne[2] * tensor->src[1]->ne[2]; + const size_t gathered_count_ids_size = sizeof(uint32_t) * tensor->src[0]->ne[2]; + res = ROUNDUP_POW2( + res + gathered_size * 2 + gathered_count_ids_size + + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment * 3, + WEBGPU_STORAGE_BUF_BINDING_MULT); + } + } + break; default: break; } @@ -2592,8 +3581,9 @@ static void ggml_backend_webgpu_device_get_props(ggml_backend_dev_t dev, struct } static ggml_guid_t ggml_backend_webgpu_guid(void) { - static const char * guid_str = "__ggml_webgpu :)"; - return reinterpret_cast((void *) guid_str); + static ggml_guid guid = { 0x67, 0xc7, 0xa4, 0xb1, 0x78, 0x74, 0x4f, 0x51, + 0x9d, 0x65, 0x44, 0x6d, 0xe4, 0x1b, 0x82, 0x9a }; + return &guid; } static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { @@ -2603,153 +3593,11 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { ctx->capabilities.memset_bytes_per_thread = CEIL_DIV(ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads); std::vector constants(2); - constants[0].key = "wg_size"; - constants[0].value = WEBGPU_MAX_WG_SIZE; - constants[1].key = "bytes_per_thread"; - constants[1].value = ctx->capabilities.memset_bytes_per_thread; - ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants); -} - -static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants); - webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_I32] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_i32, "cpy_f32_i32", constants); - webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants); - webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants); - webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F16] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants); -} - -static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE); - - webgpu_ctx->rms_norm_pipelines[0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rms_norm, "rms_norm", constants); - webgpu_ctx->rms_norm_pipelines[1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants); -} - -static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32, "rope_f32", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants); - - webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16, "rope_f16", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants); -} - -static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - // REGLU - webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32, "reglu_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16, "reglu_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants); - - // GEGLU - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32, "geglu_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16, "geglu_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants); - - // SWIGLU - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants); - - // SWIGLU_OAI - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants); - - // GEGLU_ERF - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants); - - // GEGLU_QUICK - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants); -} - -static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE); - - // f32 (no mask) - webgpu_ctx->soft_max_pipelines[2][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants); - webgpu_ctx->soft_max_pipelines[2][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants); - webgpu_ctx->soft_max_pipelines[2][1][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants); - webgpu_ctx->soft_max_pipelines[2][1][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants); - - // f32 mask (mask_type = 0) - webgpu_ctx->soft_max_pipelines[0][0][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants); - webgpu_ctx->soft_max_pipelines[0][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants); - webgpu_ctx->soft_max_pipelines[0][1][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants); - webgpu_ctx->soft_max_pipelines[0][1][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace, - "soft_max_f32_mask_f32_sink_inplace", constants); - - // f16 mask (mask_type = 1) - webgpu_ctx->soft_max_pipelines[1][0][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants); - webgpu_ctx->soft_max_pipelines[1][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants); - webgpu_ctx->soft_max_pipelines[1][1][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants); - webgpu_ctx->soft_max_pipelines[1][1][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace, - "soft_max_f32_mask_f16_sink_inplace", constants); + constants[0].key = "wg_size"; + constants[0].value = WEBGPU_MAX_WG_SIZE; + constants[1].key = "bytes_per_thread"; + constants[1].value = ctx->capabilities.memset_bytes_per_thread; + ctx->memset_pipeline = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants); } static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { @@ -2787,19 +3635,25 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { } #endif ctx->webgpu_global_ctx->adapter.GetInfo(&info); + ctx->webgpu_global_ctx->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size(); + ctx->webgpu_global_ctx->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches(); wgpu::SupportedFeatures features; ctx->webgpu_global_ctx->adapter.GetFeatures(&features); // we require f16 support GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); + ctx->webgpu_global_ctx->capabilities.supports_subgroups = + ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups); -#ifndef __EMSCRIPTEN__ - // Only support square f16 matrices of size 8 or 16 for now bool valid_subgroup_matrix_config = false; +#ifndef __EMSCRIPTEN__ + // Accept f16 subgroup matrix configurations (square or non-square). + // NVIDIA GPUs typically report square configs (e.g. 16x16x16), + // while Intel Xe2 GPUs report non-square configs (e.g. 8x16x16). + // The shaders are already parameterized to handle any M/N/K dimensions. if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) { const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i]; - if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) && - config.componentType == wgpu::SubgroupMatrixComponentType::F16 && + if (config.componentType == wgpu::SubgroupMatrixComponentType::F16 && config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) { ctx->webgpu_global_ctx->capabilities.sg_mat_m = config.M; ctx->webgpu_global_ctx->capabilities.sg_mat_n = config.N; @@ -2809,8 +3663,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { } } } - ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config; #endif + ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config; // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. @@ -2821,11 +3675,14 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { #ifndef __EMSCRIPTEN__ required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization); if (ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) { - required_features.push_back(wgpu::FeatureName::Subgroups); required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); } #endif + if (ctx->webgpu_global_ctx->capabilities.supports_subgroups) { + required_features.push_back(wgpu::FeatureName::Subgroups); + } + #ifdef GGML_WEBGPU_GPU_PROFILE required_features.push_back(wgpu::FeatureName::TimestampQuery); #endif @@ -2855,12 +3712,12 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { // Enable Dawn-specific toggles to increase native performance // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these, // only for native performance? - const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init", - "disable_polyfills_on_integer_div_and_mod" }; - const char * const deviceDisabledToggles[] = { "timestamp_quantization" }; + const char * const deviceEnabledToggles[] = { "disable_robustness", "disable_workgroup_init", + "disable_polyfills_on_integer_div_and_mod" }; + const char * const deviceDisabledToggles[] = { "timestamp_quantization" }; wgpu::DawnTogglesDescriptor deviceTogglesDesc; deviceTogglesDesc.enabledToggles = deviceEnabledToggles; - deviceTogglesDesc.enabledToggleCount = 4; + deviceTogglesDesc.enabledToggleCount = 3; deviceTogglesDesc.disabledToggles = deviceDisabledToggles; deviceTogglesDesc.disabledToggleCount = 1; @@ -2881,19 +3738,11 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { GGML_ASSERT(ctx->webgpu_global_ctx->device != nullptr); ggml_webgpu_init_memset_pipeline(ctx->webgpu_global_ctx); - ctx->webgpu_global_ctx->memset_buf_pool.init(ctx->webgpu_global_ctx->device, 1, WEBGPU_PARAMS_BUF_SIZE_BYTES, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, - wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); + ggml_webgpu_create_buffer(ctx->webgpu_global_ctx->device, ctx->webgpu_global_ctx->memset_params_buf, + WEBGPU_PARAMS_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, + "memset_params_buf"); ctx->webgpu_global_ctx->queue = ctx->webgpu_global_ctx->device.GetQueue(); -#ifdef GGML_WEBGPU_GPU_PROFILE - // Initialize buffer pool for timestamp queries, used for profiling - ctx->webgpu_global_ctx->timestamp_query_buf_pool.init( - ctx->webgpu_global_ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, - wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc, - wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst); -#endif - GGML_LOG_INFO( "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | " "device_desc: %s\n", @@ -2907,9 +3756,10 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { webgpu_context webgpu_ctx = std::make_shared(); webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx; webgpu_ctx->shader_lib = std::make_unique(dev_ctx->webgpu_global_ctx->device); - webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, - wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true); + webgpu_ctx->param_arena.init( + webgpu_ctx->global_ctx->device, WEBGPU_PARAMS_BUF_SIZE_BYTES, + webgpu_ctx->global_ctx->command_submit_batch_size + WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN, + webgpu_ctx->global_ctx->capabilities.limits.minUniformBufferOffsetAlignment); ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_dev_error_buf, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "set_rows_dev_error_buf"); @@ -2917,11 +3767,19 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf"); - ggml_webgpu_init_cpy_pipeline(webgpu_ctx); - ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx); - ggml_webgpu_init_rope_pipeline(webgpu_ctx); - ggml_webgpu_init_glu_pipeline(webgpu_ctx); - ggml_webgpu_init_soft_max_pipeline(webgpu_ctx); +#ifdef GGML_WEBGPU_GPU_PROFILE + ggml_webgpu_create_buffer( + webgpu_ctx->global_ctx->device, webgpu_ctx->profile_timestamp_dev_buf, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, + wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc, "profile_timestamp_dev_buf"); + ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->profile_timestamp_host_buf, + WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "profile_timestamp_host_buf"); + wgpu::QuerySetDescriptor query_set_desc = {}; + query_set_desc.type = wgpu::QueryType::Timestamp; + query_set_desc.count = WEBGPU_MAX_PROFILE_QUERY_COUNT; + webgpu_ctx->profile_timestamp_query_set = webgpu_ctx->global_ctx->device.CreateQuerySet(&query_set_desc); +#endif + #ifdef GGML_WEBGPU_DEBUG // Initialize debug buffers ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_host_buf, @@ -2961,17 +3819,16 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = { /* .iface = */ { - /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name, - /* .alloc_buffer = */ - ggml_backend_webgpu_buffer_type_alloc_buffer, /* .get_alignment = */ - ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size = */ - ggml_backend_webgpu_buffer_type_get_max_size, /* .get_alloc_size = */ - ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host = */ NULL, // defaults to false + /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_webgpu_buffer_type_get_alloc_size, + /* .is_host = */ NULL, // defaults to false }, /* .device = */ dev, - /* .context = */ - NULL + /* .context = */ NULL }; return &ggml_backend_webgpu_buffer_type; @@ -2984,6 +3841,7 @@ static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggm static bool ggml_webgpu_supported_qtype(ggml_type type) { switch (type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -3053,6 +3911,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) || (op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32); break; + case GGML_OP_SET: + supports_op = src0->type == src1->type && src0->type == op->type && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32); + break; case GGML_OP_SET_ROWS: supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 && (src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32)); @@ -3074,6 +3936,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const switch (src0->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -3103,33 +3966,112 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const } break; } + case GGML_OP_MUL_MAT_ID: + switch (src1->type) { + case GGML_TYPE_F16: + supports_op |= (src0->type == GGML_TYPE_F16); + break; + case GGML_TYPE_F32: + switch (src0->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q1_0: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + supports_op = true; + break; + default: + break; + } + break; + default: + break; + } + break; case GGML_OP_FLASH_ATTN_EXT: { -#ifndef __EMSCRIPTEN__ - if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) { + supports_op = src0->type == GGML_TYPE_F32 && + (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 || + src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) && + src2->type == src1->type && op->type == GGML_TYPE_F32; + if (!supports_op) { break; } - // Head dimensions must fit in workgroup memory with minimum tile sizes - size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - const bool has_mask = op->src[3] != nullptr; - const bool kv_direct = src1->type == GGML_TYPE_F16 && - (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 && - (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0; - const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( - ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n, - (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct); - if (min_bytes > limit_bytes) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.src3 = op->src[3]; + shader_lib_ctx.src4 = op->src[4]; + shader_lib_ctx.dst = const_cast(op); + shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups; + shader_lib_ctx.supports_subgroup_matrix = ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix; + shader_lib_ctx.wg_mem_limit_bytes = + ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; + shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; + shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size; + + const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions( + shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + const size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + const bool has_mask = op->src[3] != nullptr; + if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) { + supports_op = false; + break; + } + if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { + const size_t min_bytes = + ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], + (uint32_t) src2->ne[0], has_mask, decisions.kv_direct); + if (min_bytes > limit_bytes) { + supports_op = false; + } break; } - supports_op = src0->type == GGML_TYPE_F32 && - (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 || - src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) && - src2->type == src1->type && op->type == GGML_TYPE_F32; -#endif + if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { + const size_t min_bytes = + ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], + (uint32_t) src2->ne[0], has_mask, decisions.kv_direct); + if (min_bytes > limit_bytes) { + supports_op = false; + } + break; + } + + if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) { + supports_op = false; + break; + } + const size_t min_bytes = + ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], + (uint32_t) src2->ne[0], has_mask, decisions.kv_direct); + if (min_bytes > limit_bytes) { + supports_op = false; + } break; } case GGML_OP_RMS_NORM: + case GGML_OP_L2_NORM: supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; break; case GGML_OP_ROPE: @@ -3192,6 +4134,40 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const } } break; + case GGML_OP_TRI: + supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; + break; + case GGML_OP_DIAG: + supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; + break; + case GGML_OP_SOLVE_TRI: + supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32; + break; + case GGML_OP_CONV_2D: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) && + (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); + break; + case GGML_OP_IM2COL: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + break; + case GGML_OP_SSM_CONV: + supports_op = op->type == GGML_TYPE_F32; + break; + case GGML_OP_SSM_SCAN: + supports_op = op->type == GGML_TYPE_F32 && + src0->ne[0] <= ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + break; + case GGML_OP_GATED_DELTA_NET: + { + const uint32_t s_v = (uint32_t) src2->ne[0]; + supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && + src2->type == GGML_TYPE_F32 && op->src[3]->type == GGML_TYPE_F32 && + op->src[4]->type == GGML_TYPE_F32 && op->src[5]->type == GGML_TYPE_F32 && + s_v <= ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + } + break; case GGML_OP_CLAMP: supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); break; @@ -3232,6 +4208,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_SUM_ROWS: supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type && ggml_is_contiguous_rows(src0); break; + case GGML_OP_UPSCALE: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + break; default: break; } @@ -3273,9 +4253,9 @@ static struct ggml_backend_device_i ggml_backend_webgpu_device_i = { /* .supports_op = */ ggml_backend_webgpu_device_supports_op, /* .supports_buft = */ ggml_backend_webgpu_device_supports_buft, /* .offload_op = */ NULL, - /* .event_new = */ NULL, - /* .event_free = */ NULL, - /* .event_synchronize = */ NULL, + /* .event_new = */ ggml_backend_webgpu_device_event_new, + /* .event_free = */ ggml_backend_webgpu_device_event_free, + /* .event_synchronize = */ ggml_backend_webgpu_device_event_synchronize, }; /* End GGML Backend Device Interface */ @@ -3330,9 +4310,25 @@ static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = { ggml_backend_reg_t ggml_backend_webgpu_reg() { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()"); - static ggml_backend_webgpu_reg_context ctx; - ctx.name = GGML_WEBGPU_NAME; - ctx.device_count = 1; + // Intentionally leak the global registry context to avoid crashing inside + // Dawn/Vulkan static teardown during process exit. + static ggml_backend_webgpu_reg_context * ctx = new ggml_backend_webgpu_reg_context(); + + static ggml_backend_reg reg = { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_webgpu_reg_i, + /* .context = */ ctx, + }; + + ctx->name = GGML_WEBGPU_NAME; + ctx->device_count = 0; + + // Keep one Dawn/WebGPU instance alive for the lifetime of the static backend + // registry. Recreating it on repeated registry lookups can invalidate + // adapter/device references that are still held by the backend/device layer. + if (ctx->webgpu_global_ctx != nullptr && ctx->webgpu_global_ctx->instance != nullptr) { + return ® + } wgpu::InstanceDescriptor instance_descriptor{}; std::vector instance_features = { wgpu::InstanceFeatureName::TimedWaitAny }; @@ -3347,23 +4343,33 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { instance_descriptor.nextInChain = &instanceTogglesDesc; #endif - wgpu::Instance inst = wgpu::CreateInstance(&instance_descriptor); - ctx.webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct()); - ctx.webgpu_global_ctx->instance = std::move(inst); + wgpu::Instance inst = wgpu::CreateInstance(&instance_descriptor); + ctx->webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct()); + ctx->webgpu_global_ctx->instance = std::move(inst); + + // Probe for adapter support + wgpu::Adapter adapter; + if (ctx->webgpu_global_ctx->instance != nullptr) { + wgpu::RequestAdapterOptions options = {}; + + // probe for adapter support + ctx->webgpu_global_ctx->instance.WaitAny( + ctx->webgpu_global_ctx->instance.RequestAdapter( + &options, wgpu::CallbackMode::AllowSpontaneous, + [&adapter](wgpu::RequestAdapterStatus status, wgpu::Adapter _adapter, const char * message) { + if (status != wgpu::RequestAdapterStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); + return; + } + adapter = std::move(_adapter); + }), + UINT64_MAX); + } -#ifdef __EMSCRIPTEN__ - if (ctx.webgpu_global_ctx->instance == nullptr) { - GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. Make sure either -sASYNCIFY or -sJSPI is set\n"); - return nullptr; + if (adapter != nullptr) { + ctx->device_count = 1; } -#endif - GGML_ASSERT(ctx.webgpu_global_ctx->instance != nullptr); - static ggml_backend_reg reg = { - /* .api_version = */ GGML_BACKEND_API_VERSION, - /* .iface = */ ggml_backend_webgpu_reg_i, - /* .context = */ &ctx, - }; return ® } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl index a748dc1b86c..605de7aa7be 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl @@ -7,8 +7,6 @@ struct Params { offset_src0: u32, offset_src1: u32, offset_dst: u32, - offset_merged_src0: u32, - offset_merged_src1: u32, stride_src0_0: u32, stride_src0_1: u32, @@ -134,8 +132,8 @@ fn update(dst_i: u32, src0_i: u32, src1_i: u32) { @compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3) { if (gid.x < params.ne) { - let src0_i = params.offset_src0 + params.offset_merged_src0 + src0_index(gid.x); - let src1_i = params.offset_src1 + params.offset_merged_src1 + src1_index(gid.x); + let src0_i = params.offset_src0 + src0_index(gid.x); + let src1_i = params.offset_src1 + src1_index(gid.x); update(params.offset_dst + gid.x, src0_i, src1_i); } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl index 9a5b18ebc07..14c045b0ba6 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl @@ -8,12 +8,74 @@ fn get_byte_i32(value: u32, index: u32) -> i32 { } #endif -#ifdef Q4_0_T -struct q4_0 { - d: f16, - qs: array -}; +#ifdef U32_DEQUANT_HELPERS +#ifdef DECLARE_BYTE_LOADERS_SRC +fn load_u16_at_src(byte_offset: u32) -> u32 { + let word = src[byte_offset / 4u]; + let shift = (byte_offset & 0x2u) * 8u; + return (word >> shift) & 0xFFFFu; +} + +fn load_u32_at_src(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 0x3u) * 8u; + let lo = src[word_idx]; + let hi = src[word_idx + 1u]; + let shifted = (lo >> shift) | (hi << (32u - shift)); + return select(shifted, lo, shift == 0u); +} + +fn load_f16_at_src(byte_offset: u32) -> f16 { + let packed = unpack2x16float(load_u16_at_src(byte_offset)); + return f16(packed[0]); +} + +fn load_f16_as_f32_at_src(byte_offset: u32) -> f32 { + let word = src[byte_offset / 4u]; + let shift = (byte_offset & 0x2u) * 8u; + let d_bits = (word >> shift) & 0xFFFFu; + return unpack2x16float(d_bits)[0]; +} +#endif + +#ifdef DECLARE_BYTE_LOADERS_SRC0 +fn load_u16_at_src0(byte_offset: u32) -> u32 { + let word = src0[byte_offset / 4u]; + let shift = (byte_offset & 0x2u) * 8u; + return (word >> shift) & 0xFFFFu; +} + +// Always reads the 4-byte-aligned word containing byte_offset. +// Caller extracts the 16-bit half it needs via & 0xFFFFu or >> 16u. +// this is used in k-quants for better performance +fn load_u32_at_src0_aligned(byte_offset: u32) -> u32 { + return src0[(byte_offset & ~3u) / 4u]; +} + +fn load_u32_at_src0(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 0x3u) * 8u; + let lo = src0[word_idx]; + let hi = src0[word_idx + 1u]; + let shifted = (lo >> shift) | (hi << (32u - shift)); + return select(shifted, lo, shift == 0u); +} + +fn load_f16_at_src0(byte_offset: u32) -> f16 { + let packed = unpack2x16float(load_u16_at_src0(byte_offset)); + return f16(packed[0]); +} + +fn load_f16_as_f32_at_src0(byte_offset: u32) -> f32 { + let word = src0[byte_offset / 4u]; + let shift = (byte_offset & 0x2u) * 8u; + let d_bits = (word >> shift) & 0xFFFFu; + return unpack2x16float(d_bits)[0]; +} #endif +#endif + + #ifdef Q4_1_T struct q4_1 { @@ -23,13 +85,6 @@ struct q4_1 { }; #endif -#ifdef Q5_0_T -struct q5_0 { - d: f16, - qh: array, - qs: array -}; -#endif #ifdef Q5_1_T struct q5_1 { @@ -40,12 +95,6 @@ struct q5_1 { }; #endif -#ifdef Q8_0_T -struct q8_0 { - d: f16, - qs: array -}; -#endif #ifdef Q8_1_T struct q8_1 { @@ -64,14 +113,6 @@ struct q2_K { }; #endif -#ifdef Q3_K_T -struct q3_K { - hmask: array, - qs: array, - scales: array, - d: f16 -}; -#endif #if defined(Q4_K_SCALE_MIN) || defined(Q5_K_SCALE_MIN) fn get_scale_min(is: u32, scales: array) -> vec2 { @@ -108,64 +149,6 @@ struct q5_K { }; #endif -#ifdef Q6_K_T -struct q6_K { - ql: array, - qh: array, - scales: array, - d: f16 -}; -#endif - -#ifdef IQ2_XXS_T -struct iq2_xxs { - d: f16, - qs: array -}; -#endif - -#ifdef IQ2_XS_T -struct iq2_xs { - d: f16, - qs: array, - scales: array -}; -#endif - -#ifdef IQ2_S_T -struct iq2_s { - d: f16, - qs: array, - qh: array, - scales: array -}; -#endif - -#ifdef IQ3_XXS_T -struct iq3_xxs { - d: f16, - qs: array -}; -#endif - -#ifdef IQ3_S_T -struct iq3_s { - d: f16, - qs: array, - qh: array, - signs: array, - scales: array -}; -#endif - -#ifdef IQ1_S_T -struct iq1_s { - d: f16, - qs: array, - qh: array -}; -#endif - #ifdef IQ1_M_T struct iq1_m { qs: array, @@ -174,17 +157,9 @@ struct iq1_m { }; #endif -#ifdef IQ4_NL_T -struct iq4_nl { - d: f16, - qs: array, -}; -#endif - #ifdef IQ4_XS_T struct iq4_xs { - d: f16, - scales_h: f16, + d_scales_h: u32, scales_l: u32, qs: array }; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl new file mode 100644 index 00000000000..9eb131dc221 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl @@ -0,0 +1,165 @@ +#include "common_decls.tmpl" +enable f16; + +@group(0) @binding(0) +#if defined(WEIGHT_F32) +var weights: array; +#elif defined(WEIGHT_F16) +var weights: array; +#endif + +@group(0) @binding(1) +#if defined(INPUT_F32) +var input: array; +#elif defined(INPUT_F16) +var input: array; +#endif + +@group(0) @binding(2) +#if defined(OUTPUT_F32) +var output: array; +#elif defined(OUTPUT_F16) +var output: array; +#endif + +struct Params { + offset_w: u32, + offset_i: u32, + offset_o: u32, + + // element strides + sw0: u32, sw1: u32, sw2: u32, sw3: u32, + si0: u32, si1: u32, si2: u32, si3: u32, + so0: u32, so1: u32, so2: u32, so3: u32, + + // kernel dimensions + KW: u32, KH: u32, IC: u32, + // input dimensions + IW: u32, IH: u32, + // output dimensions + OW: u32, OH: u32, OC_out: u32, N_out: u32, + + // stride + s0: u32, s1: u32, + // padding + p0: u32, p1: u32, + // dilation + d0: u32, d1: u32, +}; + +@group(0) @binding(3) +var params: Params; + +fn load_weight(idx: u32) -> f32 { + #if defined(WEIGHT_F32) + return weights[idx]; + #elif defined(WEIGHT_F16) + return f32(weights[idx]); + #endif +} + +fn load_input(idx: u32) -> f32 { + #if defined(INPUT_F32) + return input[idx]; + #elif defined(INPUT_F16) + return f32(input[idx]); + #endif +} + +fn store_output(idx: u32, val: f32) { + #if defined(OUTPUT_F32) + output[idx] = val; + #elif defined(OUTPUT_F16) + output[idx] = f16(val); + #endif +} + +fn ceil_div_u32(x: u32, y: u32) -> u32 { + return (x + y - 1) / y; +} + +// returns the first valid kernel index k such that base + k * step >= 0 +fn first_valid_k(base: i32, step: u32) -> u32 { + if (base >= 0) { + return 0; + } + + return ceil_div_u32(u32(-base), step); +} + +// returns the first invalid kernel index k such that base + k * step >= limit so valid k are in [0, end_valid_k) +fn end_valid_k(base: i32, step: u32, limit: u32, k_max: u32) -> u32 { + let remaining = i32(limit) - base; + if (remaining <= 0) { + return 0; + } + + return min(k_max, ceil_div_u32(u32(remaining), step)); +} + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(global_invocation_id) gid: vec3, + @builtin(num_workgroups) num_wg: vec3 +) { + + let threads_per_group = u32(WG_SIZE); + let i_out = gid.x + (num_wg.x * threads_per_group) * gid.y; + let n_out = params.OW * params.OH * params.OC_out * params.N_out; + + var sum: f32 = 0.0; + if (i_out >= n_out) { + return; + } + + // Kernel layout: [KW, KH, IC, ..] + // Input layout: [IW, IH, .., ..] + // Output layout: [OW, OH, OC, N] + + var i = i_out; + let n = i / (params.OC_out * params.OH * params.OW); + i = i % (params.OC_out * params.OH * params.OW); + let oc = i / (params.OH * params.OW); + i = i % (params.OH * params.OW); + let oh = i / params.OW; + let ow = i % params.OW; + + let ow_base = i32(ow * params.s0) - i32(params.p0); + let oh_base = i32(oh * params.s1) - i32(params.p1); + + // clip the valid kernel window once + let kw_begin = first_valid_k(ow_base, params.d0); + let kw_end = end_valid_k(ow_base, params.d0, params.IW, params.KW); + let kh_begin = first_valid_k(oh_base, params.d1); + let kh_end = end_valid_k(oh_base, params.d1, params.IH, params.KH); + + // entire receptive field is out of bounds + if (kw_begin >= kw_end || kh_begin >= kh_end) { + let out_idx = params.offset_o + ow * params.so0 + oh * params.so1 + oc * params.so2 + n * params.so3; + store_output(out_idx, 0.0); + return; + } + + let weight_oc_base = params.offset_w + oc * params.sw3; + let input_n_base = params.offset_i + n * params.si3; + + for (var ic: u32 = 0; ic < params.IC; ic += 1) { + let w_base_ic = ic * params.sw2 + weight_oc_base; + let in_base = ic * params.si2 + input_n_base; + + for (var kh: u32 = kh_begin; kh < kh_end; kh += 1) { + let ih = u32(oh_base + i32(kh * params.d1)); + let w_row_base = w_base_ic + kh * params.sw1; + let in_row_base = in_base + ih * params.si1; + for (var kw: u32 = kw_begin; kw < kw_end; kw += 1) { + let iw = u32(ow_base + i32(kw * params.d0)); + let w_idx = w_row_base + kw * params.sw0; + let in_idx = in_row_base + iw * params.si0; + sum += load_weight(w_idx) * load_input(in_idx); + } + } + } + + let out_idx = params.offset_o + ow * params.so0 + oh * params.so1 + oc * params.so2 + n * params.so3; + store_output(out_idx, sum); +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl similarity index 60% rename from ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl index b5e93b812fd..fa3bdf4e393 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl @@ -1,66 +1,41 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "SRC_TYPE": "f32", - "DST_TYPE": "f32" - } - }, - { - "REPLS": { - "SRC_TYPE": "f32", - "DST_TYPE": "i32" - } - }, - { - "REPLS": { - "SRC_TYPE": "f32", - "DST_TYPE": "f16" - } - }, - { - "REPLS": { - "SRC_TYPE": "f16", - "DST_TYPE": "f16" - } - }, - { - "REPLS": { - "SRC_TYPE": "f16", - "DST_TYPE": "f32" - } - } -] +enable f16; -#end(VARIANTS) +#ifdef SRC_F32 +#define SRC_TYPE f32 +#elif defined(SRC_F16) +#define SRC_TYPE f16 +#endif -#define(SHADER) -enable f16; +#ifdef DST_F32 +#define DST_TYPE f32 +#elif defined(DST_F16) +#define DST_TYPE f16 +#elif defined(DST_I32) +#define DST_TYPE i32 +#endif @group(0) @binding(0) -var src: array<{{SRC_TYPE}}>; +var src: array; @group(0) @binding(1) -var dst: array<{{DST_TYPE}}>; +var dst: array; -struct Params { - ne: u32, // total number of elements - offset_src: u32, // in elements - offset_dst: u32, // in elements +struct Params{ + ne: u32, + offset_src: u32, + offset_dst: u32, - // Strides (in elements) — may be permuted stride_src0: u32, stride_src1: u32, stride_src2: u32, stride_src3: u32, + stride_dst0: u32, stride_dst1: u32, stride_dst2: u32, stride_dst3: u32, - // Logical shapes src_ne0: u32, src_ne1: u32, src_ne2: u32, @@ -73,8 +48,7 @@ struct Params { @group(0) @binding(2) var params: Params; -override wg_size: u32; -@compute @workgroup_size(wg_size) +@compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3) { if (gid.x >= params.ne) { return; @@ -102,6 +76,6 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 + j2 * params.stride_dst2 + j3 * params.stride_dst3; - dst[params.offset_dst + dst_idx] = {{DST_TYPE}}((src[params.offset_src + src_idx])); + dst[params.offset_dst + dst_idx] = DST_TYPE((src[params.offset_src + src_idx])); } -#end(SHADER) + diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py index 8b5cfe715e7..79a3a9597ab 100755 --- a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +++ b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py @@ -1,41 +1,8 @@ import os import re -import ast import argparse -def extract_block(text, name): - pattern = rf'#define\({name}\)\s*(.*?)#end\({name}\)' - match = re.search(pattern, text, re.DOTALL) - if not match: - raise ValueError(f"Missing block: {name}") - return match.group(1).strip() - - -def parse_decls(decls_text): - decls = {} - for name, code in re.findall(r'#decl\((.*?)\)\s*(.*?)#enddecl\(\1\)', decls_text, re.DOTALL): - decls[name.strip()] = code.strip() - return decls - - -def replace_repl_placeholders(variant, template_map): - for repl, code in variant["REPLS"].items(): - for key, val in template_map.items(): - # Match "key" and avoid matching subsequences using by using \b - code = re.sub(rf'\b{re.escape(str(key))}\b', str(val), code) - variant["REPLS"][repl] = code - return variant - - -def replace_placeholders(shader_text, replacements): - for key, val in replacements.items(): - # Match {{KEY}} literally, where KEY is escaped - pattern = r'{{\s*' + re.escape(key) + r'\s*}}' - shader_text = re.sub(pattern, str(val), shader_text) - return shader_text - - def expand_includes(shader, input_dir): """ Replace #include "file" lines in the text with the contents of that file. @@ -98,84 +65,24 @@ def write_shader(shader_name, shader_code, output_dir, outfile, input_dir): outfile.write(f'const char* wgsl_{shader_name} = wgsl_{shader_name}_str().c_str();\n\n') -def generate_variants(fname, input_dir, output_dir, outfile): - shader_path = os.path.join(input_dir, fname) - shader_base_name = fname.split(".")[0] - - with open(shader_path, "r", encoding="utf-8") as f: - text = f.read() - - try: - variants = ast.literal_eval(extract_block(text, "VARIANTS")) - except ValueError: - write_shader(shader_base_name, text, output_dir, outfile, input_dir) - else: - try: - decls_map = parse_decls(extract_block(text, "DECLS")) - except ValueError: - decls_map = {} - try: - templates_map = ast.literal_eval(extract_block(text, "REPL_TEMPLATES")) - except ValueError: - templates_map = {} - - for fname in sorted(os.listdir(input_dir)): - if fname.endswith(".tmpl"): - tmpl_path = os.path.join(input_dir, fname) - with open(tmpl_path, "r", encoding="utf-8") as f_tmpl: - decls = f_tmpl.read() - decls_map.update(parse_decls(decls)) - - shader_template = extract_block(text, "SHADER") - for variant in variants: - if "DECLS" in variant: - decls = variant["DECLS"] - else: - decls = [] - decls_code = "" - for key in decls: - if key not in decls_map: - raise ValueError(f"DECLS key '{key}' not found.") - decls_code += decls_map[key] + "\n\n" - final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template) - if "REPLS" in variant: - variant = replace_repl_placeholders(variant, templates_map) - final_shader = replace_placeholders(final_shader, variant["REPLS"]) - # second run to expand placeholders in repl_template - final_shader = replace_placeholders(final_shader, variant["REPLS"]) - final_shader = expand_includes(final_shader, input_dir) - - if "SHADER_NAME" in variant: - output_name = variant["SHADER_NAME"] - elif "SHADER_SUFFIX" in variant: - output_name = f"{shader_base_name}_" + variant["SHADER_SUFFIX"] - elif "REPLS" in variant and "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]: - output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]]) - elif "REPLS" in variant and "SRC_TYPE" in variant["REPLS"] and "DST_TYPE" in variant["REPLS"]: - output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC_TYPE"], variant["REPLS"]["DST_TYPE"]]) - elif "REPLS" in variant and "TYPE" in variant["REPLS"]: - output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"] - else: - output_name = shader_base_name - write_shader(output_name, final_shader, output_dir, outfile, input_dir) - - def main(): parser = argparse.ArgumentParser() parser.add_argument("--input_dir", required=True) parser.add_argument("--output_file", required=True) - parser.add_argument("--output_dir") args = parser.parse_args() - if args.output_dir: - os.makedirs(args.output_dir, exist_ok=True) - with open(args.output_file, "w", encoding="utf-8") as out: out.write("// Auto-generated shader embedding\n") out.write("#include \n\n") for fname in sorted(os.listdir(args.input_dir)): if fname.endswith(".wgsl"): - generate_variants(fname, args.input_dir, args.output_dir, out) + shader_path = os.path.join(args.input_dir, fname) + shader_name = fname.replace(".wgsl", "") + + with open(shader_path, "r", encoding="utf-8") as f: + shader_code = f.read() + + write_shader(shader_name, shader_code, None, out, args.input_dir) if __name__ == "__main__": diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index b6822161464..6d5d69fb8de 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -6,6 +6,8 @@ enable chromium_experimental_subgroup_matrix; #ifdef KV_F32 #define KV_TYPE f32 +#elif defined(KV_Q4_0) || defined(KV_Q8_0) +#define KV_TYPE u32 #else #define KV_TYPE f16 #endif @@ -37,11 +39,13 @@ enable chromium_experimental_subgroup_matrix; #define NQ 16 // Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights #define F16_PER_BLOCK 9 +#define BLOCK_SIZE_BYTES 18u #define WEIGHTS_PER_F16 4 #elif defined(KV_Q8_0) #define NQ 8 // Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights #define F16_PER_BLOCK 17 +#define BLOCK_SIZE_BYTES 34u #define WEIGHTS_PER_F16 2 #endif #define F16_PER_THREAD (NQ / WEIGHTS_PER_F16) @@ -55,6 +59,47 @@ fn get_byte_i32(value: u32, index: u32) -> i32 { return bitcast(((value >> (index * 8)) & 0xFF) << 24) >> 24; } +#if defined(KV_Q4_0) || defined(KV_Q8_0) +fn load_k_u16_at(byte_offset: u32) -> u32 { + let word = K[byte_offset / 4u]; + let shift = (byte_offset & 2u) * 8u; + return (word >> shift) & 0xFFFFu; +} + +fn load_k_u32_at(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 3u) * 8u; + let lo = K[word_idx]; + if (shift == 0u) { + return lo; + } + let hi = K[word_idx + 1u]; + return (lo >> shift) | (hi << (32u - shift)); +} + +fn load_v_u16_at(byte_offset: u32) -> u32 { + let word = V[byte_offset / 4u]; + let shift = (byte_offset & 2u) * 8u; + return (word >> shift) & 0xFFFFu; +} + +fn load_v_u32_at(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 3u) * 8u; + let lo = V[word_idx]; + if (shift == 0u) { + return lo; + } + let hi = V[word_idx + 1u]; + return (lo >> shift) | (hi << (32u - shift)); +} + +fn f16_from_u16(bits: u32) -> f16 { + let packed = unpack2x16float(bits); + return f16(packed[0]); +} +#endif + struct Params { offset_q: u32, offset_k: u32, @@ -93,26 +138,55 @@ struct Params { }; @group(0) @binding(0) var Q: array; +#ifdef KV_OVERLAP +@group(0) @binding(1) var K: array; +#define V K +#else @group(0) @binding(1) var K: array; @group(0) @binding(2) var V: array; +#endif #if defined(MASK) && defined(SINKS) +#ifdef KV_OVERLAP +@group(0) @binding(2) var mask: array; +@group(0) @binding(3) var sinks: array; +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#else @group(0) @binding(3) var mask: array; @group(0) @binding(4) var sinks: array; #define DST_BINDING 5 #define PARAMS_BINDING 6 +#endif #elif defined(MASK) +#ifdef KV_OVERLAP +@group(0) @binding(2) var mask: array; +#define DST_BINDING 3 +#define PARAMS_BINDING 4 +#else @group(0) @binding(3) var mask: array; #define DST_BINDING 4 #define PARAMS_BINDING 5 +#endif #elif defined(SINKS) +#ifdef KV_OVERLAP +@group(0) @binding(2) var sinks: array; +#define DST_BINDING 3 +#define PARAMS_BINDING 4 +#else @group(0) @binding(3) var sinks: array; #define DST_BINDING 4 #define PARAMS_BINDING 5 +#endif +#else +#ifdef KV_OVERLAP +#define DST_BINDING 2 +#define PARAMS_BINDING 3 #else #define DST_BINDING 3 #define PARAMS_BINDING 4 #endif +#endif @group(0) @binding(DST_BINDING) var dst: array>; @group(0) @binding(PARAMS_BINDING) var params: Params; @@ -254,12 +328,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (global_k_row < params.seq_len_kv) { let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = K[base_idx]; // scale + let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; + let d = f16_from_u16(load_k_u16_at(block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = K[base_idx + 1u + block_offset + j]; - let q_1 = K[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_k_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; @@ -282,12 +355,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (global_k_row < params.seq_len_kv) { let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = K[base_idx]; // scale + let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; + let d = f16_from_u16(load_k_u16_at(block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = K[base_idx + 1u + block_offset + j]; - let q_1 = K[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_k_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f16(q_byte) * d; @@ -326,35 +398,35 @@ fn main(@builtin(workgroup_id) wg_id: vec3, #endif for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) { let inter_offset = kv_block * SG_MAT_N; - var acc: subgroup_matrix_result = subgroupMatrixLoad>(&inter_shmem, inter_offset, false, KV_TILE); + var acc: subgroup_matrix_result = subgroupMatrixLoad>(&inter_shmem, inter_offset, false, KV_TILE); - var q_cur = subgroupMatrixLoad>(&q_shmem, 0u, false, HEAD_DIM_QK); + var q_cur = subgroupMatrixLoad>(&q_shmem, 0u, false, HEAD_DIM_QK); #ifdef KV_DIRECT - var k_cur = subgroupMatrixLoad>(&K, k_global_offset + 0u, true, params.stride_k1); + var k_cur = subgroupMatrixLoad>(&K, k_global_offset + 0u, true, params.stride_k1); #else - var k_cur = subgroupMatrixLoad>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK); + var k_cur = subgroupMatrixLoad>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK); #endif var t: u32 = 1u; for (; t + 1u < HEAD_DIM_QK / SG_MAT_K; t += 2u) { let h0 = t * SG_MAT_K; - var q0 = subgroupMatrixLoad>(&q_shmem, h0, false, HEAD_DIM_QK); + var q0 = subgroupMatrixLoad>(&q_shmem, h0, false, HEAD_DIM_QK); #ifdef KV_DIRECT - var k0 = subgroupMatrixLoad>(&K, k_global_offset + h0, true, params.stride_k1); + var k0 = subgroupMatrixLoad>(&K, k_global_offset + h0, true, params.stride_k1); #else - var k0 = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK); + var k0 = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK); #endif acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc); q_cur = q0; k_cur = k0; let h1 = (t + 1u) * SG_MAT_K; - var q1g = subgroupMatrixLoad>(&q_shmem, h1, false, HEAD_DIM_QK); + var q1g = subgroupMatrixLoad>(&q_shmem, h1, false, HEAD_DIM_QK); #ifdef KV_DIRECT - var k1g = subgroupMatrixLoad>(&K, k_global_offset + h1, true, params.stride_k1); + var k1g = subgroupMatrixLoad>(&K, k_global_offset + h1, true, params.stride_k1); #else - var k1g = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK); + var k1g = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK); #endif acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc); q_cur = q1g; @@ -364,11 +436,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3, // handle odd tail if (t < HEAD_DIM_QK / SG_MAT_K) { let h = t * SG_MAT_K; - var qn = subgroupMatrixLoad>(&q_shmem, h, false, HEAD_DIM_QK); + var qn = subgroupMatrixLoad>(&q_shmem, h, false, HEAD_DIM_QK); #ifdef KV_DIRECT - var kn = subgroupMatrixLoad>(&K, k_global_offset + h, true, params.stride_k1); + var kn = subgroupMatrixLoad>(&K, k_global_offset + h, true, params.stride_k1); #else - var kn = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK); + var kn = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK); #endif acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc); q_cur = qn; @@ -459,12 +531,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (global_v_row < params.seq_len_kv) { let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = V[base_idx]; // scale + let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; + let d = f16_from_u16(load_v_u16_at(block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = V[base_idx + 1u + block_offset + j]; - let q_1 = V[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_v_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; @@ -487,12 +558,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (global_v_row < params.seq_len_kv) { let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = V[base_idx]; // scale + let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; + let d = f16_from_u16(load_v_u16_at(block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = V[base_idx + 1u + block_offset + j]; - let q_1 = V[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_v_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f16(q_byte) * d; @@ -525,7 +595,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, head_dim_block < HEAD_DIM_V; head_dim_block += num_subgroups * SG_MAT_N) { // load O submatrix from shared memory - var o_sg_mat: subgroup_matrix_result = subgroupMatrixLoad>( + var o_sg_mat: subgroup_matrix_result = subgroupMatrixLoad>( &o_shmem, head_dim_block, false, @@ -533,7 +603,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, ); for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) { let p_offset = kv_block * SG_MAT_N; - var p_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( + var p_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( &inter_shmem, p_offset, false, @@ -544,7 +614,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, #ifdef KV_DIRECT let v_block_row = kv_tile + kv_block * SG_MAT_N; let v_global_offset = v_head_offset + v_block_row * params.stride_v1 + head_dim_block; - var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( + var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( &V, v_global_offset, false, @@ -552,7 +622,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, ); #else let v_block_offset = kv_block * SG_MAT_N * HEAD_DIM_V; - var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( + var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( &kv_shmem, v_block_offset + head_dim_block, false, diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl new file mode 100644 index 00000000000..37ea23b80c8 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl @@ -0,0 +1,330 @@ +enable f16; +enable subgroups; + +#define HEAD_DIM_QK 64 +#define HEAD_DIM_V 64 +#define KV_STAGE_STRIDE 64 +#define Q_TILE 4 +#define KV_TILE 64 +#define WG_SIZE 128 + +struct Params { + offset_q: u32, + offset_k: u32, + offset_v: u32, + offset_mask: u32, + offset_sinks: u32, + offset_dst: u32, + + n_heads: u32, + seq_len_q: u32, + seq_len_kv: u32, + + stride_q1: u32, + stride_q2: u32, + stride_q3: u32, + stride_k1: u32, + stride_k2: u32, + stride_k3: u32, + stride_v1: u32, + stride_v2: u32, + stride_v3: u32, + stride_mask3: u32, + + q_per_kv: u32, + + scale: f32, + max_bias: f32, + logit_softcap: f32, + n_head_log2: f32, + m0: f32, + m1: f32, +}; + +@group(0) @binding(0) var Q: array; +#ifdef KV_OVERLAP +@group(0) @binding(1) var K: array>; +#define V K +#else +@group(0) @binding(1) var K: array>; +@group(0) @binding(2) var V: array>; +#endif + +#if defined(MASK) && defined(SINKS) +#ifdef KV_OVERLAP +@group(0) @binding(2) var mask: array; +@group(0) @binding(3) var sinks: array; +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#else +@group(0) @binding(3) var mask: array; +@group(0) @binding(4) var sinks: array; +#define DST_BINDING 5 +#define PARAMS_BINDING 6 +#endif +#elif defined(MASK) +#ifdef KV_OVERLAP +@group(0) @binding(2) var mask: array; +#define DST_BINDING 3 +#define PARAMS_BINDING 4 +#else +@group(0) @binding(3) var mask: array; +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#endif +#elif defined(SINKS) +#ifdef KV_OVERLAP +@group(0) @binding(2) var sinks: array; +#define DST_BINDING 3 +#define PARAMS_BINDING 4 +#else +@group(0) @binding(3) var sinks: array; +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#endif +#else +#ifdef KV_OVERLAP +#define DST_BINDING 2 +#define PARAMS_BINDING 3 +#else +#define DST_BINDING 3 +#define PARAMS_BINDING 4 +#endif +#endif + +@group(0) @binding(DST_BINDING) var dst: array>; +@group(0) @binding(PARAMS_BINDING) var params: Params; + +const FLOAT_MIN: f32 = -1.0e9; +const Q_CHUNKS: u32 = HEAD_DIM_QK / 4u; +const V_CHUNKS: u32 = HEAD_DIM_V / 4u; +const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE; +const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE; + +var q_shmem: array; +var kv_shmem: array; +var p_shmem: array; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_size) subgroup_size: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_invocation_id) sg_inv_id: u32) { + if (subgroup_size == 0u || num_subgroups < Q_TILE) { + return; + } + + let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE; + let wg_per_batch = wg_per_head * params.n_heads; + + let dst2_stride = HEAD_DIM_V * params.n_heads; + let dst3_stride = dst2_stride * params.seq_len_q; + + let batch_idx = wg_id.x / wg_per_batch; + let q_batch_offset = params.offset_q + batch_idx * params.stride_q3; + let k_batch_offset = params.offset_k + batch_idx * params.stride_k3; + let v_batch_offset = params.offset_v + batch_idx * params.stride_v3; + let dst_batch_offset = params.offset_dst + batch_idx * dst3_stride; + let wg_in_batch = wg_id.x % wg_per_batch; + + let head_idx = wg_in_batch / wg_per_head; + let q_head_offset = q_batch_offset + head_idx * params.stride_q2; + let k_head_idx = head_idx / params.q_per_kv; + let v_head_offset = v_batch_offset + k_head_idx * params.stride_v2; + let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2; + + let wg_in_head = wg_in_batch % wg_per_head; + let q_row_start = wg_in_head * Q_TILE; + let global_q_row = q_row_start + subgroup_id; + let row_active = subgroup_id < Q_TILE && global_q_row < params.seq_len_q; + +#ifdef MASK + let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv; +#endif + + let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * HEAD_DIM_V; + + let head = f32(head_idx); + let slope = select(1.0, + select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), + pow(params.m0, head + 1.0), + head < params.n_head_log2), + params.max_bias > 0.0); + + for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { + let q_tile_row = elem_idx / HEAD_DIM_QK; + let q_col = elem_idx % HEAD_DIM_QK; + let head_q_row = q_row_start + q_tile_row; + let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1; + q_shmem[elem_idx] = f16(select( + 0.0, + Q[global_q_row_offset + q_col] * params.scale, + head_q_row < params.seq_len_q)); + } + + workgroupBarrier(); + + var row_max = FLOAT_MIN; + var exp_sum = 0.0; + var out_regs: array, OUT_REGS_PER_LANE>; + for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) { + out_regs[reg_idx] = vec4(0.0); + } + + let q_base = subgroup_id * HEAD_DIM_QK; + let subgroup_p_offset = subgroup_id * KV_TILE; + + for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) { + let kv_count = min(KV_TILE, params.seq_len_kv - kv_tile); + let score_slots = min(SCORE_REGS_PER_LANE, (kv_count + subgroup_size - 1u) / subgroup_size); + let out_slots = min(OUT_REGS_PER_LANE, (V_CHUNKS + subgroup_size - 1u) / subgroup_size); + var local_scores: array; + for (var slot = 0u; slot < SCORE_REGS_PER_LANE; slot += 1u) { + local_scores[slot] = FLOAT_MIN; + } + + for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * Q_CHUNKS; vec_idx_local += WG_SIZE) { + let kv_local = vec_idx_local / Q_CHUNKS; + let chunk = vec_idx_local % Q_CHUNKS; + let global_k_row = kv_tile + kv_local; + let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u; + let k4 = K[k_vec_index]; + let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; + kv_shmem[kv_off + 0u] = k4.x; + kv_shmem[kv_off + 1u] = k4.y; + kv_shmem[kv_off + 2u] = k4.z; + kv_shmem[kv_off + 3u] = k4.w; + } + + workgroupBarrier(); + + var local_max = FLOAT_MIN; + if (row_active) { + for (var slot = 0u; slot < score_slots; slot += 1u) { + let kv_local = sg_inv_id + slot * subgroup_size; + if (kv_local >= kv_count) { + continue; + } + + let global_k_row = kv_tile + kv_local; + var dot_val = 0.0; + for (var chunk = 0u; chunk < Q_CHUNKS; chunk += 1u) { + let q_off = q_base + chunk * 4u; + let qv = vec4( + f32(q_shmem[q_off + 0u]), + f32(q_shmem[q_off + 1u]), + f32(q_shmem[q_off + 2u]), + f32(q_shmem[q_off + 3u])); + let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; + let kv = vec4( + f32(kv_shmem[kv_off + 0u]), + f32(kv_shmem[kv_off + 1u]), + f32(kv_shmem[kv_off + 2u]), + f32(kv_shmem[kv_off + 3u])); + dot_val += dot(qv, kv); + } +#ifdef LOGIT_SOFTCAP + dot_val = params.logit_softcap * tanh(dot_val); +#endif +#ifdef MASK + let mask_idx = mask_global_offset + subgroup_id * params.seq_len_kv + global_k_row; + dot_val += slope * f32(mask[mask_idx]); +#endif + local_scores[slot] = dot_val; + local_max = max(local_max, dot_val); + } + } + + let tile_max = subgroupMax(local_max); + let new_max = max(row_max, tile_max); + let cur_exp = exp(row_max - new_max); + exp_sum *= cur_exp; + for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) { + out_regs[reg_idx] *= cur_exp; + } + + var local_sum = 0.0; + for (var slot = 0u; slot < score_slots; slot += 1u) { + let kv_local = sg_inv_id + slot * subgroup_size; + if (row_active && kv_local < kv_count) { + let p = exp(local_scores[slot] - new_max); + p_shmem[subgroup_p_offset + kv_local] = p; + local_sum += p; + } + } + + workgroupBarrier(); + + for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * V_CHUNKS; vec_idx_local += WG_SIZE) { + let kv_local = vec_idx_local / V_CHUNKS; + let chunk = vec_idx_local % V_CHUNKS; + let global_v_row = kv_tile + kv_local; + let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u; + let v4 = V[v_vec_index]; + let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; + kv_shmem[kv_off + 0u] = v4.x; + kv_shmem[kv_off + 1u] = v4.y; + kv_shmem[kv_off + 2u] = v4.z; + kv_shmem[kv_off + 3u] = v4.w; + } + + workgroupBarrier(); + + let tile_sum = subgroupAdd(local_sum); + exp_sum += tile_sum; + row_max = new_max; + + if (row_active) { + for (var reg_idx = 0u; reg_idx < out_slots; reg_idx += 1u) { + let chunk = sg_inv_id + reg_idx * subgroup_size; + if (chunk >= V_CHUNKS) { + continue; + } + + var acc = out_regs[reg_idx]; + for (var kv_local = 0u; kv_local < kv_count; kv_local += 1u) { + let p = p_shmem[subgroup_p_offset + kv_local]; + let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; + let v4 = vec4( + f32(kv_shmem[kv_off + 0u]), + f32(kv_shmem[kv_off + 1u]), + f32(kv_shmem[kv_off + 2u]), + f32(kv_shmem[kv_off + 3u])); + acc += p * v4; + } + out_regs[reg_idx] = acc; + } + } + + workgroupBarrier(); + } + +#ifdef SINKS + if (row_active) { + let sink_score = sinks[params.offset_sinks + head_idx]; + let sink_max = max(row_max, sink_score); + let sink_scale = exp(row_max - sink_max); + for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) { + out_regs[reg_idx] *= sink_scale; + } + exp_sum = exp_sum * sink_scale + exp(sink_score - sink_max); + row_max = sink_max; + } +#endif + + if (row_active) { + let inv_exp_sum = select(0.0, 1.0 / exp_sum, exp_sum != 0.0); + let row_base = dst_global_offset + subgroup_id * dst2_stride; + let out_slots = min(OUT_REGS_PER_LANE, (V_CHUNKS + subgroup_size - 1u) / subgroup_size); + for (var reg_idx = 0u; reg_idx < out_slots; reg_idx += 1u) { + let chunk = sg_inv_id + reg_idx * subgroup_size; + if (chunk >= V_CHUNKS) { + continue; + } + let dst_vec_index = (row_base + chunk * 4u) >> 2u; + dst[dst_vec_index] = out_regs[reg_idx] * inv_exp_sum; + } + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl new file mode 100644 index 00000000000..b4f7c16c35d --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl @@ -0,0 +1,101 @@ +diagnostic(off, subgroup_uniformity); +enable f16; + +#define KV_TILE 32 +#define WG_SIZE 32 + +struct Params { + offset_mask: u32, + seq_len_q: u32, + seq_len_kv: u32, + stride_mask3: u32, + // Number of KV blocks and Q blocks per batch. + // nblk0 = ceil(seq_len_kv / KV_TILE), nblk1 = seq_len_q. + nblk0: u32, + nblk1: u32, +}; + +@group(0) @binding(0) var mask: array; +@group(0) @binding(1) var blk: array; +@group(0) @binding(2) var params: Params; + +const MASK_MIN: f32 = -65504.0; +const MASK_MAX: f32 = 65504.0; +var wg_min: array; +var wg_max: array; +var wg_any: array; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3) { + // Dispatch mapping: + // - x indexes KV blocks + // - y flattens (batch_idx, q_blk) as y = batch_idx * nblk1 + q_blk + let kv_blk = wg_id.x; + let y = wg_id.y; + let q_blk = y % params.nblk1; + let batch_idx = y / params.nblk1; + if (kv_blk >= params.nblk0) { + return; + } + + let q_start = q_blk; + let k_start = kv_blk * KV_TILE; + + let mask_batch = select(0u, batch_idx, params.stride_mask3 > 0u); + let mask_batch_base = params.offset_mask + mask_batch * params.stride_mask3; + + // We keep min/max to classify: + // - fully masked (max <= MASK_MIN) + // - all-zero mask (min == 0 && max == 0) + // - mixed/general mask + var local_min = MASK_MAX; + var local_max = -MASK_MAX; + var local_any = 0u; + + let q_row = q_start; + if (q_row < params.seq_len_q) { + let row_base = mask_batch_base + q_row * params.seq_len_kv; + for (var k_rel = local_id.x; k_rel < KV_TILE; k_rel += WG_SIZE) { + let k_col = k_start + k_rel; + if (k_col >= params.seq_len_kv) { + continue; + } + let mv = f32(mask[row_base + k_col]); + local_min = min(local_min, mv); + local_max = max(local_max, mv); + local_any = 1u; + } + } + + wg_min[local_id.x] = local_min; + wg_max[local_id.x] = local_max; + wg_any[local_id.x] = local_any; + workgroupBarrier(); + + // Thread 0 writes one state per block. + if (local_id.x == 0u) { + var mmin = wg_min[0]; + var mmax = wg_max[0]; + var many = wg_any[0]; + for (var i = 1u; i < WG_SIZE; i += 1u) { + mmin = min(mmin, wg_min[i]); + mmax = max(mmax, wg_max[i]); + many = max(many, wg_any[i]); + } + + var state = 0u; + if (many != 0u) { + if (mmax <= MASK_MIN) { + state = 0u; + } else if (mmin == 0.0 && mmax == 0.0) { + state = 2u; + } else { + state = 1u; + } + } + + let blk_idx = (batch_idx * params.nblk1 + q_blk) * params.nblk0 + kv_blk; + blk[blk_idx] = state; + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl new file mode 100644 index 00000000000..9a0de82a56a --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl @@ -0,0 +1,78 @@ +diagnostic(off, subgroup_uniformity); +enable f16; +enable subgroups; + +// Default values +#define HEAD_DIM_V 64 +#define WG_SIZE 128 + +struct Params { + nrows: u32, + seq_len_q: u32, + n_heads: u32, + offset_dst: u32, + nwg: u32, + tmp_data_base: u32, + tmp_stats_base: u32, +}; + +@group(0) @binding(0) var tmp: array; +@group(0) @binding(1) var dst: array>; +@group(0) @binding(2) var params: Params; + +const FLOAT_MIN: f32 = -1.0e9; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(subgroup_id) subgroup_id: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_size) subgroup_size: u32, + @builtin(subgroup_invocation_id) sg_inv_id: u32) { + let rid = wg_id.x; + if (rid >= params.nrows) { + return; + } + + let rows_per_batch = params.n_heads * params.seq_len_q; + let batch_idx = rid / rows_per_batch; + let rem = rid % rows_per_batch; + let head_idx = rem / params.seq_len_q; + let q_row = rem % params.seq_len_q; + + let dst2_stride = HEAD_DIM_V * params.n_heads; + let dst3_stride = dst2_stride * params.seq_len_q; + let row_base = params.offset_dst + batch_idx * dst3_stride + q_row * dst2_stride + head_idx * HEAD_DIM_V; + + let thread = sg_inv_id; + if (params.nwg > subgroup_size) { + return; + } + + let stats_base = params.tmp_stats_base + rid * (2u * params.nwg); + let active_thread = thread < params.nwg; + let si = select(0.0, tmp[stats_base + 2u * thread + 0u], active_thread); + let mi = select(FLOAT_MIN, tmp[stats_base + 2u * thread + 1u], active_thread); + let m = subgroupMax(mi); + let ms = select(0.0, exp(mi - m), active_thread); + let s = subgroupAdd(si * ms); + let inv_s = select(0.0, 1.0 / s, s != 0.0); + + let row_tmp_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg); + for (var elem_base = subgroup_id * 4u; elem_base < HEAD_DIM_V; elem_base += num_subgroups * 4u) { + var weighted = vec4(0.0, 0.0, 0.0, 0.0); + if (active_thread) { + let src = row_tmp_base + thread * HEAD_DIM_V + elem_base; + weighted = vec4(tmp[src + 0u], tmp[src + 1u], tmp[src + 2u], tmp[src + 3u]) * ms; + } + + let sum_x = subgroupAdd(weighted.x); + let sum_y = subgroupAdd(weighted.y); + let sum_z = subgroupAdd(weighted.z); + let sum_w = subgroupAdd(weighted.w); + + if (thread == 0u) { + let dst_vec_index = (row_base + elem_base) >> 2u; + dst[dst_vec_index] = vec4(sum_x, sum_y, sum_z, sum_w) * inv_s; + } + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl new file mode 100644 index 00000000000..b1e234784a8 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl @@ -0,0 +1,708 @@ +diagnostic(off, subgroup_uniformity); +enable f16; +enable subgroups; + +#ifdef KV_F32 +#define KV_TYPE f32 +#else +#define KV_TYPE f16 +#endif + +#define HEAD_DIM_QK 64 +#define HEAD_DIM_V 64 + +#define KV_GRANULARITY 8 +#define KV_TILE 16 +#define WG_SIZE 64 +#ifndef VEC_NE +#define VEC_NE 4u +#endif + +#define KV_BLOCKS (KV_TILE / KV_GRANULARITY) + +#define BLOCK_SIZE 32 +#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) +#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE) +#if defined(KV_Q4_0) +#define NQ 16 +#define F16_PER_BLOCK 9 +#define WEIGHTS_PER_F16 4 +#elif defined(KV_Q8_0) +#define NQ 8 +#define F16_PER_BLOCK 17 +#define WEIGHTS_PER_F16 2 +#endif +#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16) + +fn get_byte(value: u32, index: u32) -> u32 { + return (value >> (index * 8)) & 0xFF; +} + +fn get_byte_i32(value: u32, index: u32) -> i32 { + return bitcast(((value >> (index * 8)) & 0xFF) << 24) >> 24; +} + +struct Params { + offset_q: u32, + offset_k: u32, + offset_v: u32, + offset_mask: u32, + offset_sinks: u32, + offset_dst: u32, + + // shapes of Q/K/V + n_heads: u32, + seq_len_q: u32, + seq_len_kv: u32, + + // strides (in elements) + stride_q1: u32, + stride_q2: u32, + stride_q3: u32, + stride_k1: u32, + stride_k2: u32, + stride_k3: u32, + stride_v1: u32, + stride_v2: u32, + stride_v3: u32, + stride_mask3: u32, + + // repeat factors for K/V, e.g., MHA vs. MQA vs. GQA + q_per_kv: u32, + + // softmax params + scale: f32, + max_bias: f32, + logit_softcap: f32, + n_head_log2: f32, + m0: f32, + m1: f32, + +#ifdef BLK + blk_base: u32, + blk_nblk0: u32, + blk_nblk1: u32, +#endif + + tmp_data_base: u32, + tmp_stats_base: u32, + nwg: u32, +}; + +@group(0) @binding(0) var Q: array; +#ifdef KV_OVERLAP +#if defined(KV_Q4_0) || defined(KV_Q8_0) +@group(0) @binding(1) var K: array; +#else +@group(0) @binding(1) var K: array>; +#endif +#define V K +#else +#if defined(KV_Q4_0) || defined(KV_Q8_0) +@group(0) @binding(1) var K: array; +#else +@group(0) @binding(1) var K: array>; +#endif +#if defined(KV_Q4_0) || defined(KV_Q8_0) +@group(0) @binding(2) var V: array; +#else +@group(0) @binding(2) var V: array>; +#endif +#endif +#if defined(MASK) && defined(SINKS) +#ifdef KV_OVERLAP +@group(0) @binding(2) var mask: array; +@group(0) @binding(3) var sinks: array; +#ifdef BLK +#define BLK_BINDING 4 +#define TMP_BINDING 5 +#define DST_BINDING 6 +#define PARAMS_BINDING 7 +#else +#define TMP_BINDING 4 +#define DST_BINDING 5 +#define PARAMS_BINDING 6 +#endif +#else +@group(0) @binding(3) var mask: array; +@group(0) @binding(4) var sinks: array; +#ifdef BLK +#define BLK_BINDING 5 +#define TMP_BINDING 6 +#define DST_BINDING 7 +#define PARAMS_BINDING 8 +#else +#define TMP_BINDING 5 +#define DST_BINDING 6 +#define PARAMS_BINDING 7 +#endif +#endif +#elif defined(MASK) +#ifdef KV_OVERLAP +@group(0) @binding(2) var mask: array; +#ifdef BLK +#define BLK_BINDING 3 +#define TMP_BINDING 4 +#define DST_BINDING 5 +#define PARAMS_BINDING 6 +#else +#define TMP_BINDING 3 +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#endif +#else +@group(0) @binding(3) var mask: array; +#ifdef BLK +#define BLK_BINDING 4 +#define TMP_BINDING 5 +#define DST_BINDING 6 +#define PARAMS_BINDING 7 +#else +#define TMP_BINDING 4 +#define DST_BINDING 5 +#define PARAMS_BINDING 6 +#endif +#endif +#elif defined(SINKS) +#ifdef KV_OVERLAP +@group(0) @binding(2) var sinks: array; +#define TMP_BINDING 3 +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#else +@group(0) @binding(3) var sinks: array; +#define TMP_BINDING 4 +#define DST_BINDING 5 +#define PARAMS_BINDING 6 +#endif +#else +#ifdef KV_OVERLAP +#define TMP_BINDING 2 +#define DST_BINDING 3 +#define PARAMS_BINDING 4 +#else +#define TMP_BINDING 3 +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#endif +#endif + +#ifdef BLK +@group(0) @binding(BLK_BINDING) var blk: array; +#endif +@group(0) @binding(TMP_BINDING) var tmp: array; +@group(0) @binding(DST_BINDING) var dst: array>; +@group(0) @binding(PARAMS_BINDING) var params: Params; + +// Just a very small float value. +const FLOAT_MIN: f32 = -1.0e9; + +var q_shmem: array; + +#ifndef KV_DIRECT +const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V); +// we can reuse the same shmem for K and V since we only need one at a time +var kv_shmem: array; +#endif + +var o_shmem: array; + +#ifdef MASK +// storage for mask values +var mask_shmem: array; +#endif + +// note that we reuse the same storage for both since we only need one at a time +var inter_shmem: array; + +// Storage for row max and exp sum during online softmax +fn calc_softmax_term(kv_idx: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 { + var v = select(FLOAT_MIN, + f32(inter_shmem[kv_idx]) * params.scale, + kv_idx < KV_TILE); +#ifdef LOGIT_SOFTCAP + v = params.logit_softcap * tanh(v); +#endif +#ifdef MASK + if (apply_mask) { + var mask_val = select(0.0, f32(mask_shmem[kv_idx]), kv_idx < KV_TILE); + v += select(mask_val, slope * mask_val, has_bias); + } +#endif + return v; +} + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_size) subgroup_size: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_invocation_id) sg_inv_id: u32) { + // Vec path processes exactly one query row per workgroup, so subgroup 0 can + // keep the running softmax state in private storage. + var row_max = FLOAT_MIN; + var exp_sum = 0.0; + + for (var i = local_id.x; i < HEAD_DIM_V; i += WG_SIZE) { + o_shmem[i] = 0.0; + } + + // workgroups per head/batch + let wg_per_head = params.seq_len_q; + let wg_per_batch = wg_per_head * params.n_heads; + + let dst2_stride = HEAD_DIM_V * params.n_heads; + let dst3_stride = dst2_stride * params.seq_len_q; + + let iwg = wg_id.x % params.nwg; + let base_wg_id = wg_id.x / params.nwg; + + // batch index + let batch_idx = base_wg_id / wg_per_batch; + let q_batch_offset = params.offset_q + batch_idx * params.stride_q3; + let k_batch_offset = params.offset_k + batch_idx * params.stride_k3; + let v_batch_offset = params.offset_v + batch_idx * params.stride_v3; + let wg_in_batch = base_wg_id % wg_per_batch; + + // head index + let head_idx = wg_in_batch / wg_per_head; + let q_head_offset = q_batch_offset + head_idx * params.stride_q2; + let k_head_idx = head_idx / params.q_per_kv; + let v_head_idx = k_head_idx; + let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2; + let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2; + + // Vec path handles one Q row per workgroup. + let wg_in_head = wg_in_batch % wg_per_head; + let q_row_start = wg_in_head; + +#ifdef MASK + // mask offset + let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv; +#endif + + let head = f32(head_idx); + let has_bias = params.max_bias > 0.0; + let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), has_bias); + + // load the single Q row into shared memory + for (var elem_idx = local_id.x; elem_idx < HEAD_DIM_QK; elem_idx += WG_SIZE) { + let global_q_row_offset = q_head_offset + q_row_start * params.stride_q1; + q_shmem[elem_idx] = f16(select( + 0.0, + Q[global_q_row_offset + elem_idx], + q_row_start < params.seq_len_q)); + } + + for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) { +#ifdef BLK + let q_blk = q_row_start; + let kv_blk = kv_tile / KV_TILE; + let blk_batch = select(0u, batch_idx, params.stride_mask3 > 0u); + let blk_idx = params.blk_base + (blk_batch * params.blk_nblk1 + q_blk) * params.blk_nblk0 + kv_blk; + let blk_state_local = blk[blk_idx]; +#else + let blk_state_local = 1u; +#endif + let blk_state = blk_state_local; + let skip_tile = blk_state == 0u; + for (var elem_idx = local_id.x; elem_idx < KV_TILE; elem_idx += WG_SIZE) { + inter_shmem[elem_idx] = f16(0.0); + } + + // load k tile into shared memory +#if defined(KV_Q4_0) + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let k_row = blck_idx / BLOCKS_K; + let global_k_row = kv_tile + k_row; + let block_k = blck_idx % BLOCKS_K; + let row_offset = k_row * HEAD_DIM_QK; + + if (global_k_row < params.seq_len_kv) { + let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; + let base_idx = global_block_idx * F16_PER_BLOCK; + let d = K[base_idx]; + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = K[base_idx + 1u + block_offset + j]; + let q_1 = K[base_idx + 1u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; + let q_lo = (f16(q_byte & 0xF) - 8.0) * d; + let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; + kv_shmem[row_offset + idx] = q_lo; + kv_shmem[row_offset + idx + 16u] = q_hi; + } + } + } + } +#elif defined(KV_Q8_0) + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let k_row = blck_idx / BLOCKS_K; + let global_k_row = kv_tile + k_row; + let block_k = blck_idx % BLOCKS_K; + let row_offset = k_row * HEAD_DIM_QK; + + if (global_k_row < params.seq_len_kv) { + let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; + let base_idx = global_block_idx * F16_PER_BLOCK; + let d = K[base_idx]; + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = K[base_idx + 1u + block_offset + j]; + let q_1 = K[base_idx + 1u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = f16(q_byte) * d; + let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; + kv_shmem[row_offset + idx] = q_val; + } + } + } + } +#elif defined(KV_DIRECT) + // Direct global loads for KV +#else + for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * 4u) { + let k_row = elem_idx / HEAD_DIM_QK; + let k_col = elem_idx % HEAD_DIM_QK; + let global_k_row = kv_tile + k_row; + let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; + let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK; + let vec_idx = (global_k_row_offset + k_col) >> 2u; + let k4 = select(vec4(0.0), K[vec_idx], in_bounds); + kv_shmem[elem_idx + 0u] = f16(k4.x); + kv_shmem[elem_idx + 1u] = f16(k4.y); + kv_shmem[elem_idx + 2u] = f16(k4.z); + kv_shmem[elem_idx + 3u] = f16(k4.w); + } +#endif + + workgroupBarrier(); + + // accumulate q block * k block into registers across the entire KV tile + if (!skip_tile) { + let num_of_threads = subgroup_size / VEC_NE; + let tx = sg_inv_id % num_of_threads; + let ty = sg_inv_id / num_of_threads; + if (subgroup_id == 0u && q_row_start < params.seq_len_q) { + for (var kv_base : u32 = 0u; kv_base < KV_TILE; kv_base += VEC_NE) { + let kv_idx = kv_base + ty; + var partial_sum: f32 = 0.0; + let kv_valid = kv_idx < KV_TILE && (kv_tile + kv_idx) < params.seq_len_kv; + if (kv_valid) { + for (var i = tx; i < (HEAD_DIM_QK / 4u); i += num_of_threads) { + let q_off = i * 4u; + + let qv = vec4( + f32(q_shmem[q_off + 0u]), + f32(q_shmem[q_off + 1u]), + f32(q_shmem[q_off + 2u]), + f32(q_shmem[q_off + 3u])); +#ifdef KV_DIRECT + let idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u); + let kv = vec4(K[idx >> 2u]); +#else + let idx = kv_idx * HEAD_DIM_QK + (i * 4u); + let kv = vec4( + f32(kv_shmem[idx + 0u]), + f32(kv_shmem[idx + 1u]), + f32(kv_shmem[idx + 2u]), + f32(kv_shmem[idx + 3u])); +#endif + partial_sum += dot(qv, kv); + } + } + var sum = partial_sum; + // Reduce over tx threads (NL) for this ty stripe. + var tx_delta = num_of_threads >> 1u; + loop { + if (tx_delta == 0u) { + break; + } + let sh = subgroupShuffleDown(sum, tx_delta); + if (tx < tx_delta) { + sum += sh; + } + tx_delta >>= 1u; + } + + let sum_bcast = subgroupShuffle(sum, num_of_threads * ty); + if (tx == 0u && kv_valid) { + inter_shmem[kv_idx] = f16(sum_bcast); + } + } + } + } + + +#ifdef MASK + let apply_mask = !skip_tile && (blk_state != 2u); + if (apply_mask) { + // load mask tile into shared memory for this KV block + for (var elem_idx = local_id.x; elem_idx < KV_TILE; elem_idx += WG_SIZE) { + let global_k_col = kv_tile + elem_idx; + let mask_in_bounds = q_row_start < params.seq_len_q && global_k_col < params.seq_len_kv; + let mask_idx = mask_global_offset + global_k_col; + mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds); + } + } +#else + let apply_mask = false; +#endif + + workgroupBarrier(); + + // online softmax + if (!skip_tile && subgroup_id == 0u && q_row_start < params.seq_len_q) { + var prev_max = row_max; + var final_max = prev_max; + // pass 1: compute final max across the full KV tile in chunks + for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { + let kv_idx = kv_offset + sg_inv_id; + let kv_valid = kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE; + let softmax_term = select(FLOAT_MIN, + calc_softmax_term(kv_idx, slope, has_bias, apply_mask), + kv_valid); + final_max = subgroupMax(max(final_max, softmax_term)); + } + + var total_exp_term: f32 = 0.0; + // pass 2: compute exp sum and write P using final_max + for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { + let kv_idx = kv_offset + sg_inv_id; + let softmax_term = calc_softmax_term(kv_idx, slope, has_bias, apply_mask); + let cur_p = select(0.0, + exp(softmax_term - final_max), + kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE); + total_exp_term += subgroupAdd(cur_p); + if (kv_idx < KV_TILE) { + inter_shmem[kv_idx] = f16(cur_p); + } + } + + let cur_exp = exp(prev_max - final_max); + + row_max = final_max; + exp_sum = exp_sum * cur_exp + total_exp_term; + + for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { + o_shmem[elem_idx] = f16(f32(o_shmem[elem_idx]) * cur_exp); + } + } + + // load v tile into shared memory +#if defined(KV_Q4_0) + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let v_row = blck_idx / BLOCKS_V; + let global_v_row = kv_tile + v_row; + let block_k = blck_idx % BLOCKS_V; + let row_offset = v_row * HEAD_DIM_V; + + if (global_v_row < params.seq_len_kv) { + let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; + let base_idx = global_block_idx * F16_PER_BLOCK; + let d = V[base_idx]; + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = V[base_idx + 1u + block_offset + j]; + let q_1 = V[base_idx + 1u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; + let q_lo = (f16(q_byte & 0xF) - 8.0) * d; + let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; + kv_shmem[row_offset + idx] = q_lo; + kv_shmem[row_offset + idx + 16u] = q_hi; + } + } + } + } +#elif defined(KV_Q8_0) + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let v_row = blck_idx / BLOCKS_V; + let global_v_row = kv_tile + v_row; + let block_k = blck_idx % BLOCKS_V; + let row_offset = v_row * HEAD_DIM_V; + + if (global_v_row < params.seq_len_kv) { + let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; + let base_idx = global_block_idx * F16_PER_BLOCK; + let d = V[base_idx]; + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = V[base_idx + 1u + block_offset + j]; + let q_1 = V[base_idx + 1u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = f16(q_byte) * d; + let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; + kv_shmem[row_offset + idx] = q_val; + } + } + } + } +#elif defined(KV_DIRECT) + // Direct global loads for KV +#else + for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * 4u) { + let v_row = elem_idx / HEAD_DIM_V; + let v_col = elem_idx % HEAD_DIM_V; + let global_v_row = kv_tile + v_row; + let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; + let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V; + let vec_idx = (global_v_row_offset + v_col) >> 2u; + let v4 = select(vec4(0.0), V[vec_idx], in_bounds); + kv_shmem[elem_idx + 0u] = f16(v4.x); + kv_shmem[elem_idx + 1u] = f16(v4.y); + kv_shmem[elem_idx + 2u] = f16(v4.z); + kv_shmem[elem_idx + 3u] = f16(v4.w); + } +#endif + + workgroupBarrier(); + + if (!skip_tile) { + // we have P (KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem + // we want to compute O += P * V across the full KV tile + let ne_threads : u32 = VEC_NE; + let nl_threads = max(1u, subgroup_size / ne_threads); + let tx_pv = sg_inv_id % nl_threads; + let ty_pv = sg_inv_id / nl_threads; + if (subgroup_id == 0u && q_row_start < params.seq_len_q) { + for (var vec_col = tx_pv; vec_col < (HEAD_DIM_V / 4u); vec_col += nl_threads) { + var lo = vec4(0.0, 0.0, 0.0, 0.0); + for (var cc = 0u; cc < KV_TILE / ne_threads; cc += 1u) { + let kv_idx = cc * ne_threads + ty_pv; + let v_row = kv_tile + kv_idx; + if (v_row >= params.seq_len_kv) { + continue; + } + + let p = f32(inter_shmem[kv_idx]); +#ifdef KV_DIRECT + let v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u; + let v4 = vec4(V[v_idx >> 2u]); +#else + let v_idx = kv_idx * HEAD_DIM_V + vec_col * 4u; + let v4 = vec4( + f32(kv_shmem[v_idx + 0u]), + f32(kv_shmem[v_idx + 1u]), + f32(kv_shmem[v_idx + 2u]), + f32(kv_shmem[v_idx + 3u])); +#endif + lo += p * v4; + } + + var lo_x = lo.x; + var lo_y = lo.y; + var lo_z = lo.z; + var lo_w = lo.w; + // Reduce over ty threads (NE) for this tx thread. + var ty_delta = ne_threads >> 1u; + loop { + if (ty_delta == 0u) { + break; + } + let thread_delta = ty_delta * nl_threads; + let shx = subgroupShuffleDown(lo_x, thread_delta); + let shy = subgroupShuffleDown(lo_y, thread_delta); + let shz = subgroupShuffleDown(lo_z, thread_delta); + let shw = subgroupShuffleDown(lo_w, thread_delta); + if (ty_pv < ty_delta) { + lo_x += shx; + lo_y += shy; + lo_z += shz; + lo_w += shw; + } + ty_delta >>= 1u; + } + + if (ty_pv == 0u) { + let elem_base = vec_col * 4u; + o_shmem[elem_base + 0u] = f16(f32(o_shmem[elem_base + 0u]) + lo_x); + o_shmem[elem_base + 1u] = f16(f32(o_shmem[elem_base + 1u]) + lo_y); + o_shmem[elem_base + 2u] = f16(f32(o_shmem[elem_base + 2u]) + lo_z); + o_shmem[elem_base + 3u] = f16(f32(o_shmem[elem_base + 3u]) + lo_w); + } + } + } + } + + workgroupBarrier(); + } + + +#ifdef SINKS + // Sinks are global terms and must be applied exactly once across split workgroups. + if (iwg == 0u && subgroup_id == 0u && q_row_start < params.seq_len_q) { + var prev_max = row_max; + + // for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum + let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0u); + let new_max = subgroupMax(max(prev_max, sink_val)); + let max_exp = exp(prev_max - new_max); + let sink_exp = exp(sink_val - new_max); + + let sink_exp_sum = subgroupAdd(sink_exp); + + row_max = new_max; + exp_sum = exp_sum * max_exp + sink_exp_sum; + + for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { + o_shmem[elem_idx] = f16(f32(o_shmem[elem_idx]) * max_exp); + } + } + workgroupBarrier(); +#endif + let rows_per_batch = params.n_heads * params.seq_len_q; + if (subgroup_id == 0u && q_row_start < params.seq_len_q) { + if (params.nwg == 1u) { + let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0); + let row_base: u32 = params.offset_dst + batch_idx * dst3_stride + q_row_start * dst2_stride + + head_idx * HEAD_DIM_V; + + for (var elem_base = sg_inv_id * 4u; elem_base < HEAD_DIM_V; elem_base += subgroup_size * 4u) { + let v = vec4( + f32(o_shmem[elem_base + 0u]) * scale, + f32(o_shmem[elem_base + 1u]) * scale, + f32(o_shmem[elem_base + 2u]) * scale, + f32(o_shmem[elem_base + 3u]) * scale + ); + + let dst_vec_index: u32 = (row_base + elem_base) >> 2u; + dst[dst_vec_index] = v; + } + } else { + let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + q_row_start; + let tmp_row_data_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg) + iwg * HEAD_DIM_V; + let tmp_row_stats_base = params.tmp_stats_base + rid * (2u * params.nwg) + 2u * iwg; + + for (var elem_base = sg_inv_id * 4u; + elem_base < HEAD_DIM_V; + elem_base += subgroup_size * 4u) { + + let tbase = tmp_row_data_base + elem_base; + tmp[tbase + 0u] = f32(o_shmem[elem_base + 0u]); + tmp[tbase + 1u] = f32(o_shmem[elem_base + 1u]); + tmp[tbase + 2u] = f32(o_shmem[elem_base + 2u]); + tmp[tbase + 3u] = f32(o_shmem[elem_base + 3u]); + } + + if (sg_inv_id == 0u) { + tmp[tmp_row_stats_base + 0u] = exp_sum; + tmp[tmp_row_stats_base + 1u] = row_max; + } + } + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl new file mode 100644 index 00000000000..f9d98fda40b --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl @@ -0,0 +1,132 @@ +@group(0) @binding(0) +var src_q: array; + +@group(0) @binding(1) +var src_k: array; + +@group(0) @binding(2) +var src_v: array; + +@group(0) @binding(3) +var src_g: array; + +@group(0) @binding(4) +var src_beta: array; + +@group(0) @binding(5) +var src_state: array; + +@group(0) @binding(6) +var dst: array; + +struct Params { + h: u32, + n_tokens: u32, + n_seqs: u32, + s_off: u32, + + sq1: u32, + sq2: u32, + sq3: u32, + + sv1: u32, + sv2: u32, + sv3: u32, + + sb1: u32, + sb2: u32, + sb3: u32, + + neq1: u32, + rq3: u32, + scale: f32, +}; + +@group(0) @binding(7) +var params: Params; + +var sh_k: array; +var sh_q: array; +var sh_g: array; + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(workgroup_id) workgroup_id: vec3, + @builtin(local_invocation_id) local_id: vec3 +) { + let head_id = workgroup_id.x; + let seq_id = workgroup_id.y; + let col = local_id.x; + + let iq1 = head_id % params.neq1; + let iq3 = seq_id / params.rq3; + + let state_size = S_V * S_V; + let state_base = (seq_id * params.h + head_id) * state_size; + + var state: array; + for (var i = 0u; i < S_V; i++) { + state[i] = src_state[state_base + col * S_V + i]; + } + + var attn_off = (seq_id * params.n_tokens * params.h + head_id) * S_V; + + for (var t = 0u; t < params.n_tokens; t++) { + let q_off = iq3 * params.sq3 + t * params.sq2 + iq1 * params.sq1; + let k_off = q_off; + let v_off = seq_id * params.sv3 + t * params.sv2 + head_id * params.sv1; + let gb_off = seq_id * params.sb3 + t * params.sb2 + head_id * params.sb1; + + sh_q[col] = src_q[q_off + col]; + sh_k[col] = src_k[k_off + col]; + +#ifdef KDA + let g_base = gb_off * S_V; + sh_g[col] = exp(src_g[g_base + col]); +#endif + + workgroupBarrier(); + + let v_val = src_v[v_off + col]; + let beta_val = src_beta[gb_off]; + + var kv_col = 0.0; + var delta_col = 0.0; + var attn_col = 0.0; + +#ifdef KDA + for (var i = 0u; i < S_V; i++) { + kv_col += (sh_g[i] * state[i]) * sh_k[i]; + } + + delta_col = (v_val - kv_col) * beta_val; + + for (var i = 0u; i < S_V; i++) { + state[i] = sh_g[i] * state[i] + sh_k[i] * delta_col; + attn_col += state[i] * sh_q[i]; + } +#else + let g_val = exp(src_g[gb_off]); + + for (var i = 0u; i < S_V; i++) { + kv_col += state[i] * sh_k[i]; + } + + delta_col = (v_val - g_val * kv_col) * beta_val; + + for (var i = 0u; i < S_V; i++) { + state[i] = g_val * state[i] + sh_k[i] * delta_col; + attn_col += state[i] * sh_q[i]; + } +#endif + + dst[attn_off + col] = attn_col * params.scale; + attn_off += S_V * params.h; + + workgroupBarrier(); + } + + for (var i = 0u; i < S_V; i++) { + dst[params.s_off + state_base + col * S_V + i] = state[i]; + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl index b10800e36d2..5710cd35469 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl @@ -1,6 +1,8 @@ enable f16; +#define DECLARE_BYTE_LOADERS_SRC #include "common_decls.tmpl" + #ifdef F32_VEC fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset]; @@ -25,19 +27,38 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } #endif +#ifdef Q1_0 +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block_byte_base = (src_base + offset) * 18; + let d = load_f16_as_f32_at_src(block_byte_base); + for (var j: u32 = 0u; j < 4u; j++) { + let q_packed = load_u32_at_src(block_byte_base + 2u + j * 4u); + let dst_base128 = dst_base + offset * 128u + j * 32u; + for (var k: u32 = 0; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + for (var bit: u32 = 0; bit < 8u; bit++) { + let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u); + dst[dst_base128 + k * 8u + bit] = w; + } + } + } +} +#endif + #ifdef Q4_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block_q4_0 = src[src_base + offset]; - let d = f32(block_q4_0.d); - for (var j: u32 = 0; j < 4; j++) { - let q_packed = bitcast(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1])); + let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes + let d = load_f16_as_f32_at_src(block_byte_base); + for (var j: u32 = 0u; j < 4; j++) { + let q_byte_offset = block_byte_base + 2 + j * 4; + let q_packed = load_u32_at_src(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); - let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d; - let q_lo = (f32(q_byte & 0xF) - 8.0f) * d; + let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d; + let q_lo = (f32(q_byte & 0xFu) - 8.0) * d; let dst_offset = dst_base + offset * 32 + j * 4 + k; dst[dst_offset] = q_lo; - dst[dst_offset + 16] = q_hi; + dst[dst_offset + 16u] = q_hi; } } } @@ -64,17 +85,22 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef Q5_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block_q5_0 = src[src_base + offset]; - let d = f32(block_q5_0.d); - let qh_packed = bitcast(vec2(block_q5_0.qh[0], block_q5_0.qh[1])); + let block_byte_base = (src_base + offset) * 22; // Block stride: 22 bytes + let d = load_f16_as_f32_at_src(block_byte_base); + let qh_packed = load_u32_at_src(block_byte_base + 2); for (var j: u32 = 0; j < 4; j++) { - let q_packed = bitcast(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1])); + let q_byte_offset = block_byte_base + 6 + j * 4; + let q_packed = load_u32_at_src(q_byte_offset); + for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); + let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10; let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; + let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10; let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d; + let dst_offset = dst_base + offset * 32 + j * 4 + k; dst[dst_offset] = q_lo; dst[dst_offset + 16] = q_hi; @@ -106,14 +132,15 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef Q8_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block_q8_0 = src[src_base + offset]; - let d = f32(block_q8_0.d); - for (var j: u32 = 0; j < 8; j++) { - let q_packed = bitcast(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1])); - for (var k: u32 = 0; k < 4; k++) { + let block_byte_base = (src_base + offset) * 34; // Block stride: 34 bytes + let d = load_f16_as_f32_at_src(block_byte_base); + for (var j: u32 = 0u; j < 8u; j++) { + let q_byte_offset = block_byte_base + 2u + j * 4u; + let q_packed = load_u32_at_src(q_byte_offset); + for (var k: u32 = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d; - let dst_offset = dst_base + offset * 32 + j * 4 + k; + let dst_offset = dst_base + offset * 32u + j * 4u + k; dst[dst_offset] = q_val; } } @@ -152,36 +179,42 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef Q3_K fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes + + // Bytes 108-109: f16 scale 'd' + let d = load_f16_as_f32_at_src(block_byte_base + 108); - // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale, - // and 2-bits from the last 4 bytes + // Bytes 96-107: 12 bytes of scales (3 u32s) let kmask1: u32 = 0x03030303; let kmask2: u32 = 0x0f0f0f0f; + var scale_vals: array; - for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); - } + scale_vals[0] = load_u32_at_src(block_byte_base + 96); + scale_vals[1] = load_u32_at_src(block_byte_base + 100); + scale_vals[2] = load_u32_at_src(block_byte_base + 104); + var tmp: u32 = scale_vals[2]; scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4); scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); - // convert arrays of f16 -> u32 + // Bytes 0-31: 32 bytes of hmask (8 u32s) var hmask_vals: array; for (var i: u32 = 0; i < 8; i++) { - hmask_vals[i] = bitcast(vec2(block.hmask[2 * i], block.hmask[2 * i + 1])); + hmask_vals[i] = load_u32_at_src(block_byte_base + i * 4); } + + // Bytes 32-95: 64 bytes of qs (16 u32s) var qs_vals: array; - for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = bitcast(vec2(block.qs[2 * i], block.qs[2 * i + 1])); + for (var i: u32 = 0u; i < 16; i++) { + qs_vals[i] = load_u32_at_src(block_byte_base + 32 + i * 4); } var dst_i = dst_base + offset * 256; var is: u32 = 0; var m: u32 = 1; + // 2 halves of the block (128 elements each) for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { // 4 groups (each group has 2 blocks of 16 elements) @@ -191,11 +224,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let sc = get_byte(scale_vals[is / 4], is % 4); is++; let dl = d * (f32(sc) - 32.0); - for (var l: u32 = 0u; l < 16u; l++) { + + for (var l: u32 = 0; l < 16; l++) { let q_idx = q_b_idx + k + l; let hm_idx = k + l; let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4); let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4); + let hm = select(4.0, 0.0, (hmask_byte & m) != 0); let qs_val = (q_byte >> shift) & 3; dst[dst_i] = (f32(qs_val) - hm) * dl; @@ -268,21 +303,27 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef Q6_K // 16 blocks of 16 elements each fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 210; // Block stride: 210 bytes - // convert arrays of f16 -> u32 + // Bytes 208-209: f16 scale 'd' + let d = load_f16_as_f32_at_src(block_byte_base + 208); + + // Bytes 0-127: 128 bytes of ql (32 u32s) var ql_vals: array; for (var i: u32 = 0; i < 32; i++) { - ql_vals[i] = bitcast(vec2(block.ql[2 * i], block.ql[2 * i + 1])); + ql_vals[i] = load_u32_at_src(block_byte_base + i * 4); } + + // Bytes 128-191: 64 bytes of qh (16 u32s) var qh_vals: array; - for (var i: u32 = 0; i < 16; i++) { - qh_vals[i] = bitcast(vec2(block.qh[2 * i], block.qh[2 * i + 1])); + for (var i: u32 = 0; i < 16u; i++) { + qh_vals[i] = load_u32_at_src(block_byte_base + 128 + i * 4u); } + + // Bytes 192-207: 16 bytes of scales (4 u32s) var scale_vals: array; for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); + scale_vals[i] = load_u32_at_src(block_byte_base + 192 + i * 4); } var dst_i = dst_base + offset * 256; @@ -323,12 +364,14 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ2_XXS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 66; // Block stride: 66 bytes + let d = load_f16_as_f32_at_src(block_byte_base); var dst_i = dst_base + offset * 256; for (var ib: u32 = 0; ib < 32; ib += 4) { - let aux0 = bitcast(vec2(block.qs[ib], block.qs[ib + 1])); - let aux1 = bitcast(vec2(block.qs[ib + 2], block.qs[ib + 3])); + let aux0_offset = block_byte_base + 2 + ib * 2; + let aux1_offset = block_byte_base + 2 + (ib + 2) * 2; + let aux0 = load_u32_at_src(aux0_offset); + let aux1 = load_u32_at_src(aux1_offset); let db = d * (0.5 + f32(aux1 >> 28)) * 0.25; for (var l: u32 = 0; l < 4; l++) { let ig = get_byte(aux0, l) * 8; @@ -345,15 +388,19 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } #endif + + #ifdef IQ2_XS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 74; // Block stride: 74 bytes + let d = load_f16_as_f32_at_src(block_byte_base); var dst_i = dst_base + offset * 256; + var scale_vals = array( - bitcast(vec2(block.scales[0], block.scales[1])), - bitcast(vec2(block.scales[2], block.scales[3])) + load_u32_at_src(block_byte_base + 66), + load_u32_at_src(block_byte_base + 70) ); + for (var ib: u32 = 0; ib < 32; ib += 4) { let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4); let db = array( @@ -361,7 +408,8 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { d * (0.5 + f32(s >> 4)) * 0.25 ); for (var l: u32 = 0; l < 4; l++) { - let qs_val = bitcast(vec2(block.qs[ib + l], 0.0)); + let qs_offset = block_byte_base + 2 + (ib + l) * 2; + let qs_val = load_u32_at_src(qs_offset) & 0xFFFF; let ig = (qs_val & 511) * 8; let is = qs_val >> 9; let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); @@ -379,21 +427,23 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ2_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 82; // Block stride: 82 bytes + let d = load_f16_as_f32_at_src(block_byte_base); var dst_i = dst_base + offset * 256; + var qs_vals : array; for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); + qs_vals[i] = load_u32_at_src(block_byte_base + 2 + i * 4); } - var qh_vals = array( - bitcast(vec2(block.qh[0], block.qh[1])), - bitcast(vec2(block.qh[2], block.qh[3])) - ); - var scale_vals = array( - bitcast(vec2(block.scales[0], block.scales[1])), - bitcast(vec2(block.scales[2], block.scales[3])) - ); + + var qh_vals: array; + qh_vals[0] = load_u32_at_src(block_byte_base + 66); + qh_vals[1] = load_u32_at_src(block_byte_base + 70); + + var scale_vals: array; + scale_vals[0] = load_u32_at_src(block_byte_base + 74); + scale_vals[1] = load_u32_at_src(block_byte_base + 78); + for (var ib: u32 = 0; ib < 8; ib ++) { let s = get_byte(scale_vals[ib / 4], ib % 4); let db = array( @@ -419,16 +469,17 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ3_XXS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 98; // Block stride: 98 bytes + let d = load_f16_as_f32_at_src(block_byte_base); var dst_i = dst_base + offset * 256; for (var ib: u32 = 0; ib < 16; ib += 2) { - let sc_sign = bitcast(vec2(block.qs[ib + 32], block.qs[ib + 33])); + let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2; + let sc_sign = load_u32_at_src(sc_sign_offset); let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5; for (var l: u32 = 0; l < 4; l++) { let is = (sc_sign >> (7 * l)) & 127; let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - let ig_val = bitcast(vec2(block.qs[ib * 2 + l], 0.0)); + let ig_val = load_u32_at_src(block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF; let ig1 = get_byte(ig_val, 0); let ig2 = get_byte(ig_val, 1); for (var j: u32 = 0; j < 4; j++) { @@ -448,18 +499,22 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ3_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes + let d = load_f16_as_f32_at_src(block_byte_base); var dst_i = dst_base + offset * 256; + var qh_vals = array( - bitcast(vec2(block.qh[0], block.qh[1])), - bitcast(vec2(block.qh[2], block.qh[3])) + load_u32_at_src(block_byte_base + 66), + load_u32_at_src(block_byte_base + 70) ); + var sign_vals: array; for (var i: u32 = 0; i < 8; i++) { - sign_vals[i] = bitcast(vec2(block.signs[i * 2], block.signs[i * 2 + 1])); + sign_vals[i] = load_u32_at_src(block_byte_base + 74 + i * 4); } - var scale_vals = bitcast(vec2(block.scales[0], block.scales[1])); + + var scale_vals = load_u32_at_src(block_byte_base + 106); + for (var ib: u32 = 0; ib < 4; ib++) { let s = get_byte(scale_vals, ib); let db = array( @@ -472,7 +527,7 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let sign_w = sign_vals[ib * 2 + k]; for (var l: u32 = 0; l < 4; l++) { let signs = get_byte(sign_w, l); - let ig_val = bitcast(vec2(block.qs[ib * 8 + k * 4 + l], 0.0)); + let ig_val = load_u32_at_src(block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF; let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256); let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256); for (var j: u32 = 0; j < 4; j++) { @@ -493,14 +548,14 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ1_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 50; // Block stride: 50 bytes + let d = load_f16_as_f32_at_src(block_byte_base); var dst_i = dst_base + offset * 256; for (var ib: u32 = 0; ib < 8; ib++) { - let qh = bitcast(vec2(block.qh[ib], 0.0)); - let dl = d * (2 * f32((qh >> 12) & 7) + 1); + let qh = load_u32_at_src(block_byte_base + 34 + ib * 2) & 0xFFFF; + let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0); let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0); - let qs_w = bitcast(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1])); + let qs_w = load_u32_at_src(block_byte_base + 2 + ib * 4); for (var l: u32 = 0; l < 4; l++) { let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8; for (var j: u32 = 0; j < 8; j++) { @@ -560,12 +615,12 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ4_NL fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes + let d = load_f16_as_f32_at_src(block_byte_base); var dst_i = dst_base + offset * 32; var qs: array; for (var i: u32 = 0; i < 4; i++) { - qs[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); + qs[i] = load_u32_at_src(block_byte_base + 2 + i * 4); } for (var j: u32 = 0; j < 16; j++) { let qsb = get_byte(qs[j / 4], j % 4); @@ -579,8 +634,8 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ4_XS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; - let d = f32(block.d); - let scales_h = bitcast(vec2(block.scales_h, 0.0)); + let d = unpack2x16float(block.d_scales_h)[0]; + let scales_h = block.d_scales_h >> 16; var dst_i = dst_base + offset * 256; for (var ib: u32 = 0; ib < 8; ib++) { let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4); @@ -640,6 +695,35 @@ var params: Params; @compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3) { +#ifdef FLOAT_PARALLEL + let blocks_per_row = params.ne0 / BLOCK_SIZE; + let row_count = params.n_rows * params.ne2 * params.ne3; + + if (gid.x >= blocks_per_row * row_count) { + return; + } + + let block_idx = gid.x % blocks_per_row; + var row_idx = gid.x / blocks_per_row; + let i_dst3 = row_idx / (params.ne2 * params.n_rows); + + row_idx = row_idx % (params.ne2 * params.n_rows); + let i_dst2 = row_idx / params.n_rows; + let i_dst1 = row_idx % params.n_rows; + + let i_idx2 = i_dst3 % params.idx2; + let i_idx1 = i_dst2 % params.idx1; + let i_idx0 = i_dst1; + + let i_idx = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2; + + let idx_val = u32(idx[i_idx]); + + let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3; + let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3; + + copy_elements(i_src_row, i_dst_row, block_idx); +#else if (gid.x >= params.n_rows * params.ne2 * params.ne3) { return; } @@ -664,5 +748,5 @@ fn main(@builtin(global_invocation_id) gid: vec3) { for (var i: u32 = 0; i < params.ne0/BLOCK_SIZE; i++) { copy_elements(i_src_row, i_dst_row, i); } +#endif } - diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl deleted file mode 100644 index 03fcd548689..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +++ /dev/null @@ -1,323 +0,0 @@ -#define(VARIANTS) - -[ - { - "SHADER_NAME": "reglu_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "REGLU"] - }, - { - "SHADER_NAME": "reglu_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "REGLU"] - }, - { - "SHADER_NAME": "reglu_f16", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_SPLIT", "REGLU"] - }, - { - "SHADER_NAME": "reglu_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "REGLU"] - }, - { - "SHADER_NAME": "geglu_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "GEGLU"] - }, - { - "SHADER_NAME": "geglu_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "GEGLU"] - }, - { - "SHADER_NAME": "geglu_f16", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_SPLIT", "GEGLU"] - }, - { - "SHADER_NAME": "geglu_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "GEGLU"] - }, - { - "SHADER_NAME": "swiglu_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "SWIGLU"] - }, - { - "SHADER_NAME": "swiglu_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "SWIGLU"] - }, - { - "SHADER_NAME": "swiglu_f16", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_SPLIT", "SWIGLU"] - }, - { - "SHADER_NAME": "swiglu_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "SWIGLU"] - }, - { - "SHADER_NAME": "swiglu_oai_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "SWIGLU_OAI"] - }, - { - "SHADER_NAME": "swiglu_oai_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "SWIGLU_OAI"] - }, - { - "SHADER_NAME": "geglu_erf_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "GEGLU_ERF"] - }, - { - "SHADER_NAME": "geglu_erf_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "GEGLU_ERF"] - }, - { - "SHADER_NAME": "geglu_erf_f16", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_SPLIT", "GEGLU_ERF"] - }, - { - "SHADER_NAME": "geglu_erf_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "GEGLU_ERF"] - }, - { - "SHADER_NAME": "geglu_quick_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "GEGLU_QUICK"] - }, - { - "SHADER_NAME": "geglu_quick_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "GEGLU_QUICK"] - }, - { - "SHADER_NAME": "geglu_quick_f16", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_SPLIT", "GEGLU_QUICK"] - }, - { - "SHADER_NAME": "geglu_quick_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "GEGLU_QUICK"] - }, -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(REGLU) -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - return max(a, 0) * b; -} -#enddecl(REGLU) - -#decl(GEGLU) -const SQRT_2_OVER_PI: {{TYPE}} = 0.79788456080286535587989211986876; -const GELU_COEF_A: {{TYPE}} = 0.044715; - -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a); - return 0.5 * a * (2.0 - 2.0 / (exp(2 * val) + 1)) * b; -} -#enddecl(GEGLU) - -#decl(SWIGLU) -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - return a / (1.0 + exp(-a)) * b; -} -#enddecl(SWIGLU) - -#decl(SWIGLU_OAI) -fn op(a: f32, b: f32) -> f32 { - let xi = min(a, params.limit); - let gi = max(min(b, params.limit), -params.limit); - var out_glu = xi / (1.0 + exp(-xi * params.alpha)); - out_glu = out_glu * (1.0 + gi); - return out_glu; -} -#enddecl(SWIGLU_OAI) - -#decl(GEGLU_ERF) -const p_erf: {{TYPE}} = 0.3275911; -const a1_erf: {{TYPE}} = 0.254829592; -const a2_erf: {{TYPE}} = -0.284496736; -const a3_erf: {{TYPE}} = 1.421413741; -const a4_erf: {{TYPE}} = -1.453152027; -const a5_erf: {{TYPE}} = 1.061405429; -const SQRT_2_INV: {{TYPE}} = 0.7071067811865476; - -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - let a_div_sqr2 = a * SQRT_2_INV; - let sign_x = sign(a_div_sqr2); - let x = abs(a_div_sqr2); - let t = 1.0 / (1.0 + p_erf * x); - let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x)); - let erf_approx = sign_x * y; - return 0.5 * a * (1.0 + erf_approx) * b; -} -#enddecl(GEGLU_ERF) - -#decl(GEGLU_QUICK) -const GELU_QUICK_COEF: {{TYPE}} = -1.702; - -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b; -} -#enddecl(GEGLU_QUICK) - -#decl(NO_SPLIT) -@group(0) @binding(1) -var dst: array<{{TYPE}}>; - -@group(0) @binding(2) -var params: Params; - -fn a_value(base: u32) -> {{TYPE}} { - let offset: u32 = select(0, params.ne0, params.swapped != 0); - return src0[base + offset]; -} - -fn b_value(base: u32) -> {{TYPE}} { - let offset: u32 = select(params.ne0, 0, params.swapped != 0); - return src0[base + offset]; -} -#enddecl(NO_SPLIT) - -#decl(SPLIT) -@group(0) @binding(1) -var src1: array<{{TYPE}}>; - -@group(0) @binding(2) -var dst: array<{{TYPE}}>; - -@group(0) @binding(3) -var params: Params; - -fn a_value(base: u32) -> {{TYPE}} { - return src0[base]; -} - -fn b_value(base: u32) -> {{TYPE}} { - return src1[base]; -} -#enddecl(SPLIT) - -#end(DECLS) - -#define(SHADER) - -enable f16; - -struct Params { - offset_src0: u32, - offset_src1: u32, - offset_dst: u32, - - // Strides (in elements) - stride_src01: u32, - stride_src02: u32, - stride_src03: u32, - - stride_src11: u32, - stride_src12: u32, - stride_src13: u32, - - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // shape of dst - ne: u32, - ne0: u32, - ne1: u32, - ne2: u32, - - swapped: u32, - alpha: f32, - limit: f32, -} - -@group(0) @binding(0) -var src0: array<{{TYPE}}>; - -DECLS - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.ne) { - return; - } - - var i = gid.x; - let i3 = i / (params.ne2 * params.ne1 * params.ne0); - i = i % (params.ne2 * params.ne1 * params.ne0); - let i2 = i / (params.ne1 * params.ne0); - i = i % (params.ne1 * params.ne0); - let i1 = i / params.ne0; - let i0 = i % params.ne0; - - let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0; - let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0; - let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0; - - dst[i_dst] = op(a_value(i_a), b_value(i_b)); -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl new file mode 100644 index 00000000000..e6d7608cec5 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl @@ -0,0 +1,155 @@ +enable f16; + +#ifdef TYPE_F32 +#define DataType f32 +#endif +#ifdef TYPE_F16 +#define DataType f16 +#endif + +#ifdef OP_REGLU +fn op(a: DataType, b: DataType) -> DataType { + return max(a, 0) * b; +} +#endif + +#ifdef OP_GEGLU +const SQRT_2_OVER_PI: DataType = 0.79788456080286535587989211986876; +const GELU_COEF_A: DataType = 0.044715; + +fn op(a: DataType, b: DataType) -> DataType { + let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a); + return 0.5 * a * (2.0 - 2.0/ (exp(2* val) + 1)) * b; +} +#endif + +#ifdef OP_SWIGLU +fn op(a: DataType, b: DataType) -> DataType { + return a / (1.0 + exp(-a)) * b; +} +#endif +#ifdef OP_SWIGLU_OAI +fn op(a: f32, b: f32) -> f32 { + let xi = min(a, params.limit); + let gi = max(min(b, params.limit), -params.limit); + var out_glu = xi / (1.0 + exp(-xi * params.alpha)); + out_glu = out_glu * (1.0 + gi); + return out_glu; +} +#endif +#ifdef OP_GEGLU_ERF +const p_erf: DataType = 0.3275911; +const a1_erf: DataType = 0.254829592; +const a2_erf: DataType = -0.284496736; +const a3_erf: DataType = 1.421413741; +const a4_erf: DataType = -1.453152027; +const a5_erf: DataType = 1.061405429; +const SQRT_2_INV: DataType = 0.7071067811865476; + +fn op(a: DataType, b: DataType) -> DataType { + let a_div_sqr2 = a * SQRT_2_INV; + let sign_x = sign(a_div_sqr2); + let x = abs(a_div_sqr2); + let t = 1.0 / (1.0 + p_erf * x); + let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x)); + let erf_approx = sign_x * y; + return 0.5 * a * (1.0 + erf_approx) * b; +} +#endif +#ifdef OP_GEGLU_QUICK +const GELU_QUICK_COEF: DataType = -1.702; + +fn op(a: DataType, b: DataType) -> DataType { + return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b; +} +#endif + +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + // Strides (in elements) + stride_src01: u32, + stride_src02: u32, + stride_src03: u32, + + stride_src11: u32, + stride_src12: u32, + stride_src13: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // shape of dst + ne: u32, + ne0: u32, + ne1: u32, + ne2: u32, + + swapped: u32, + alpha: f32, + limit: f32, +} + +@group(0) @binding(0) +var src0: array; + +#ifdef NO_SPLIT +@group(0) @binding(1) +var dst: array; + +@group(0) @binding(2) +var params: Params; + +fn a_value(base: u32) -> DataType { + let offset: u32 = select(0, params.ne0, params.swapped != 0); + return src0[base + offset]; +} + +fn b_value(base: u32) -> DataType { + let offset: u32 = select(params.ne0, 0, params.swapped != 0); + return src0[base + offset]; +} + +#else +@group(0) @binding(1) +var src1: array; + +@group(0) @binding(2) +var dst: array; + +@group(0) @binding(3) +var params: Params; + +fn a_value(base: u32) -> DataType { + return src0[base]; +} + +fn b_value(base: u32) -> DataType { + return src1[base]; +} + +#endif + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.ne) { + return; + } + + var i = gid.x; + let i3 = i / (params.ne2 * params.ne1 * params.ne0); + i = i % (params.ne2 * params.ne1 * params.ne0); + let i2 = i / (params.ne1 * params.ne0); + i = i % (params.ne1 * params.ne0); + let i1 = i / params.ne0; + let i0 = i % params.ne0; + + let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0; + let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0; + let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0; + + dst[i_dst] = op(a_value(i_a), b_value(i_b)); +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl new file mode 100644 index 00000000000..386ebab879f --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl @@ -0,0 +1,101 @@ +#include "common_decls.tmpl" +enable f16; + +@group(0) @binding(0) +#if defined(INPUT_F32) +var input: array; +#elif defined(INPUT_F16) +var input: array; +#endif + +@group(0) @binding(1) +#if defined(OUTPUT_F32) +var output: array; +#elif defined(OUTPUT_F16) +var output: array; +#endif + +struct Params { + offset_i: u32, + offset_o: u32, + + // element strides + si0: u32, si1: u32, si2: u32, si3: u32, + so0: u32, so1: u32, so2: u32, so3: u32, + + KW: u32, KH: u32, IC: u32, + IW: u32, IH: u32, N: u32, + OW: u32, OH: u32, + + // stride + s0: u32, s1: u32, + // padding + p0: u32, p1: u32, + // dilation + d0: u32, d1: u32, +} + +@group(0) @binding(2) +var params: Params; + +fn load_input(idx: u32) -> f32 { + #if defined(INPUT_F32) + return input[idx]; + #elif defined(INPUT_F16) + return f32(input[idx]); + #endif +} + +fn store_output(idx: u32, val: f32) { + #if defined(OUTPUT_F32) + output[idx] = val; + #elif defined(OUTPUT_F16) + output[idx] = f16(val); + #endif +} + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(global_invocation_id) gid: vec3, + @builtin(num_workgroups) num_wg: vec3 +) { + + let threads_per_group = u32(WG_SIZE); + let i_out = gid.x + (num_wg.x * threads_per_group) * gid.y; + let K = params.KW * params.KH * params.IC; + let M = params.OW * params.OH; + let total = K * M * params.N; + + if (i_out >= total) { + return; + } + + // decode (k, m, n) + var i = i_out; + let n = i / (K * M); + i = i % (K * M); + let m = i / K; + let k = i % K; + + // decode (oh, ow) + let oh = m / params.OW; + let ow = m % params.OW; + + // decode (kw, kh, ic) + let kw = k % params.KW; + let tmp = k / params.KW; + let kh = tmp % params.KH; + let ic = tmp / params.KH; + + let iw_i32 = i32(ow * params.s0 + kw * params.d0) - i32(params.p0); + let ih_i32 = i32(oh * params.s1 + kh * params.d1) - i32(params.p1); + + if (iw_i32 >= 0 && iw_i32 < i32(params.IW) && ih_i32 >= 0 && ih_i32 < i32(params.IH)) { + let iw = u32(iw_i32); + let ih = u32(ih_i32); + let in_idx = params.offset_i + iw * params.si0 + ih * params.si1 + ic * params.si2 + n * params.si3; + store_output(params.offset_o + k * params.so0 + ow * params.so1 + oh * params.so2 + n * params.so3, load_input(in_idx)); + } else { + store_output(params.offset_o + k * params.so0 + ow * params.so1 + oh * params.so2 + n * params.so3, 0.0); + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl index 5b9f5b36224..fcbefdeb802 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl @@ -1,7 +1,9 @@ enable f16; +#define DECLARE_BYTE_LOADERS_SRC0 #include "common_decls.tmpl" + #ifdef FLOAT const BLOCK_SIZE = 1u; @@ -20,11 +22,12 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef Q4_0 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q4_0 = src0[src0_idx_base + offset]; - let d = f32(block_q4_0.d); + let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes + let d = load_f16_as_f32_at_src0(block_byte_base); var sum: f32 = 0.0; for (var j: u32 = 0; j < 4; j++) { - let q_packed = bitcast(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1])); + let q_byte_offset = block_byte_base + 2 + j * 4; + let q_packed = load_u32_at_src0(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d; @@ -61,12 +64,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef Q5_0 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q5_0 = src0[src0_idx_base + offset]; - let d = f32(block_q5_0.d); + let block_byte_base = (src0_idx_base + offset) * 22; // Block stride: 22 bytes + let d = load_f16_as_f32_at_src0(block_byte_base); var sum: f32 = 0.0; - let qh_packed = bitcast(vec2(block_q5_0.qh[0], block_q5_0.qh[1])); + let qh_packed = load_u32_at_src0(block_byte_base + 2); for (var j: u32 = 0; j < 4; j++) { - let q_packed = bitcast(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1])); + let q_byte_offset = block_byte_base + 6 + j * 4; + let q_packed = load_u32_at_src0(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10; @@ -107,12 +111,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef Q8_0 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q8_0 = src0[src0_idx_base + offset]; - let d = f32(block_q8_0.d); + let block_byte_base = (src0_idx_base + offset) * 34; // Block stride: 34 bytes + let d = load_f16_as_f32_at_src0(block_byte_base); var sum: f32 = 0.0; for (var j: u32 = 0; j < 8; j++) { - let q_packed = bitcast(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1])); - for (var k: u32 = 0; k < 4; k++) { + let q_byte_offset = block_byte_base + 2 + j * 4; + let q_packed = load_u32_at_src0(q_byte_offset); + for (var k: u32 = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d; let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; @@ -178,31 +183,37 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef Q3_K // 16 blocks of 16 elements each fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes + + // Bytes 108-109: f16 scale 'd' + let d = load_f16_as_f32_at_src0(block_byte_base + 108); // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale, // and 2-bits from the last 4 bytes + // Bytes 96-107: 12 bytes of scales (3 u32s) let kmask1: u32 = 0x03030303; let kmask2: u32 = 0x0f0f0f0f; var scale_vals: array; - for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); - } + scale_vals[0] = load_u32_at_src0(block_byte_base + 96); + scale_vals[1] = load_u32_at_src0(block_byte_base + 100); + scale_vals[2] = load_u32_at_src0(block_byte_base + 104); + var tmp: u32 = scale_vals[2]; scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4); scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); - // convert arrays of f16 -> u32 + // Bytes 0-31: 32 bytes of hmask (8 u32s) var hmask_vals: array; for (var i: u32 = 0; i < 8; i++) { - hmask_vals[i] = bitcast(vec2(block.hmask[2 * i], block.hmask[2 * i + 1])); + hmask_vals[i] = load_u32_at_src0(block_byte_base + i * 4); } + + // Bytes 32-95: 64 bytes of qs (16 u32s) var qs_vals: array; - for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = bitcast(vec2(block.qs[2 * i], block.qs[2 * i + 1])); + for (var i: u32 = 0u; i < 16; i++) { + qs_vals[i] = load_u32_at_src0(block_byte_base + 32 + i * 4); } var sum = 0.0; @@ -301,21 +312,27 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef Q6_K // 16 blocks of 16 elements each fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 210; // Block stride: 210 bytes + + // Bytes 208-209: f16 scale 'd' + let d = load_f16_as_f32_at_src0(block_byte_base + 208); - // convert arrays of f16 -> u32 + // Bytes 0-127: 128 bytes of ql (32 u32s) var ql_vals: array; for (var i: u32 = 0; i < 32; i++) { - ql_vals[i] = bitcast(vec2(block.ql[2 * i], block.ql[2 * i + 1])); + ql_vals[i] = load_u32_at_src0(block_byte_base + i * 4); } + + // Bytes 128-191: 64 bytes of qh (16 u32s) var qh_vals: array; for (var i: u32 = 0; i < 16; i++) { - qh_vals[i] = bitcast(vec2(block.qh[2 * i], block.qh[2 * i + 1])); + qh_vals[i] = load_u32_at_src0(block_byte_base + 128 + i * 4); } + + // Bytes 192-207: 16 bytes of scales (4 u32s) var scale_vals: array; for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); + scale_vals[i] = load_u32_at_src0(block_byte_base + 192 + i * 4); } var sum = 0.0; @@ -358,13 +375,15 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ2_XXS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 66; // Block stride: 66 bytes + let d = load_f16_as_f32_at_src0(block_byte_base); var src1_i = src1_idx_base + offset * 256; var sum = 0.0; for (var ib: u32 = 0; ib < 32; ib += 4) { - let aux0 = bitcast(vec2(block.qs[ib], block.qs[ib + 1])); - let aux1 = bitcast(vec2(block.qs[ib + 2], block.qs[ib + 3])); + let aux0_offset = block_byte_base + 2 + ib * 2; + let aux1_offset = block_byte_base + 2 + (ib + 2) * 2; + let aux0 = load_u32_at_src0(aux0_offset); + let aux1 = load_u32_at_src0(aux1_offset); let db = d * (0.5 + f32(aux1 >> 28)) * 0.25; for (var l: u32 = 0; l < 4; l++) { let ig = get_byte(aux0, l) * 8; @@ -384,13 +403,15 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ2_XS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 74; // Block stride: 74 bytes + let d = load_f16_as_f32_at_src0(block_byte_base); var src1_i = src1_idx_base + offset * 256; + var scale_vals = array( - bitcast(vec2(block.scales[0], block.scales[1])), - bitcast(vec2(block.scales[2], block.scales[3])) + load_u32_at_src0(block_byte_base + 66), + load_u32_at_src0(block_byte_base + 70) ); + var sum = 0.0; for (var ib: u32 = 0; ib < 32; ib += 4) { let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4); @@ -399,7 +420,8 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { d * (0.5 + f32(s >> 4)) * 0.25 ); for (var l: u32 = 0; l < 4; l++) { - let qs_val = bitcast(vec2(block.qs[ib + l], 0.0)); + let qs_offset = block_byte_base + 2 + (ib + l) * 2; + let qs_val = load_u32_at_src0(qs_offset) & 0xFFFF; let ig = (qs_val & 511) * 8; let is = qs_val >> 9; let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); @@ -418,21 +440,23 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ2_S fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 82; // Block stride: 82 bytes + let d = load_f16_as_f32_at_src0(block_byte_base); var src1_i = src1_idx_base + offset * 256; + var qs_vals : array; for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); + qs_vals[i] = load_u32_at_src0(block_byte_base + 2 + i * 4); } - var qh_vals = array( - bitcast(vec2(block.qh[0], block.qh[1])), - bitcast(vec2(block.qh[2], block.qh[3])) - ); - var scale_vals = array( - bitcast(vec2(block.scales[0], block.scales[1])), - bitcast(vec2(block.scales[2], block.scales[3])) - ); + + var qh_vals: array; + qh_vals[0] = load_u32_at_src0(block_byte_base + 66); + qh_vals[1] = load_u32_at_src0(block_byte_base + 70); + + var scale_vals: array; + scale_vals[0] = load_u32_at_src0(block_byte_base + 74); + scale_vals[1] = load_u32_at_src0(block_byte_base + 78); + var sum = 0.0; for (var ib: u32 = 0; ib < 8; ib ++) { let s = get_byte(scale_vals[ib / 4], ib % 4); @@ -460,17 +484,18 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ3_XXS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 98; // Block stride: 98 bytes + let d = load_f16_as_f32_at_src0(block_byte_base); var src1_i = src1_idx_base + offset * 256; var sum = 0.0; for (var ib: u32 = 0; ib < 16; ib += 2) { - let sc_sign = bitcast(vec2(block.qs[ib + 32], block.qs[ib + 33])); + let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2; + let sc_sign = load_u32_at_src0(sc_sign_offset); let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5; for (var l: u32 = 0; l < 4; l++) { let is = (sc_sign >> (7 * l)) & 127; let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - let ig_val = bitcast(vec2(block.qs[ib * 2 + l], 0.0)); + let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF; let ig1 = get_byte(ig_val, 0); let ig2 = get_byte(ig_val, 1); for (var j: u32 = 0; j < 4; j++) { @@ -491,18 +516,22 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ3_S fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes + let d = load_f16_as_f32_at_src0(block_byte_base); var src1_i = src1_idx_base + offset * 256; + var qh_vals = array( - bitcast(vec2(block.qh[0], block.qh[1])), - bitcast(vec2(block.qh[2], block.qh[3])) + load_u32_at_src0(block_byte_base + 66), + load_u32_at_src0(block_byte_base + 70) ); + var sign_vals: array; for (var i: u32 = 0; i < 8; i++) { - sign_vals[i] = bitcast(vec2(block.signs[i * 2], block.signs[i * 2 + 1])); + sign_vals[i] = load_u32_at_src0(block_byte_base + 74 + i * 4); } - var scale_vals = bitcast(vec2(block.scales[0], block.scales[1])); + + var scale_vals = load_u32_at_src0(block_byte_base + 106); + var sum = 0.0; for (var ib: u32 = 0; ib < 4; ib++) { let s = get_byte(scale_vals, ib); @@ -516,7 +545,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let sign_w = sign_vals[ib * 2 + k]; for (var l: u32 = 0; l < 4; l++) { let signs = get_byte(sign_w, l); - let ig_val = bitcast(vec2(block.qs[ib * 8 + k * 4 + l], 0.0)); + let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF; let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256); let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256); for (var j: u32 = 0; j < 4; j++) { @@ -538,15 +567,15 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ1_S fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 50; // Block stride: 50 bytes + let d = load_f16_as_f32_at_src0(block_byte_base); var src1_i = src1_idx_base + offset * 256; var sum = 0.0; for (var ib: u32 = 0; ib < 8; ib++) { - let qh = bitcast(vec2(block.qh[ib], 0.0)); - let dl = d * (2 * f32((qh >> 12) & 7) + 1); + let qh = load_u32_at_src0(block_byte_base + 34 + ib * 2) & 0xFFFF; + let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0); let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0); - let qs_w = bitcast(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1])); + let qs_w = load_u32_at_src0(block_byte_base + 2 + ib * 4); for (var l: u32 = 0; l < 4; l++) { let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8; for (var j: u32 = 0; j < 8; j++) { @@ -610,13 +639,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ4_NL fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes + let d = load_f16_as_f32_at_src0(block_byte_base); var src1_i = src1_idx_base + offset * 32; var sum = 0.0; var qs: array; for (var i: u32 = 0; i < 4; i++) { - qs[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); + qs[i] = load_u32_at_src0(block_byte_base + 2 + i * 4); } for (var j: u32 = 0; j < 16; j++) { let qsb = get_byte(qs[j / 4], j % 4); @@ -631,8 +660,8 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ4_XS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - let scales_h = bitcast(vec2(block.scales_h, 0.0)); + let d = unpack2x16float(block.d_scales_h)[0]; + let scales_h = block.d_scales_h >> 16; var src1_i = src1_idx_base + offset * 256; var sum = 0.0; for (var ib: u32 = 0; ib < 8; ib++) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index de60ebbcf2b..51cf08f196f 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -42,6 +42,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 } #endif // INIT_SRC0_SHMEM_FLOAT +#ifndef MUL_MAT_ID #ifdef INIT_SRC1_SHMEM_FLOAT fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) { for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) { @@ -58,13 +59,47 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3 } } #endif // INIT_SRC1_SHMEM_FLOAT +#endif + +#ifdef INIT_SRC0_SHMEM_Q1_0 +const BLOCK_SIZE = 128u; +const BLOCK_SIZE_BYTES = 18u; +const NQ = 8u; // 8 weights (1 byte of qs) per thread per iteration + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { + let tile_m = i / TILE_K; + let tile_k_start = i % TILE_K; + let global_m = offset_m + tile_m; + let global_k_start = k_outer + tile_k_start; + + if (global_m >= params.m) { + break; + } + + let block_k = global_k_start / BLOCK_SIZE; + let byte_in_block = (global_k_start % BLOCK_SIZE) / 8u; + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_at_src0(block_byte_base); + let q_byte = load_u32_at_src0(block_byte_base + 2u + byte_in_block) & 0xFFu; + + for (var bit = 0u; bit < NQ; bit++) { + let global_k = global_k_start + bit; + if (global_k < params.k) { + shmem[i + bit] = select(-d, d, ((q_byte >> bit) & 1u) != 0u); + } + } + } +} +#endif // INIT_SRC0_SHMEM_Q1_0 #ifdef INIT_SRC0_SHMEM_Q4_0 const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 18u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; const NQ = 16u; -const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; @@ -81,14 +116,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; - let scale_idx = src0_idx * F16_PER_BLOCK; - let d = src0[scale_idx]; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_at_src0(block_byte_base); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = src0[scale_idx + 1u + block_offset + j]; - let q_1 = src0[scale_idx + 1u + block_offset + j + 1]; - - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_u32_at_src0(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; @@ -104,10 +137,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q4_1 const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 20u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; const NQ = 16u; -const F16_PER_BLOCK = 10u; // 1 scale + 8 packed weights + 1 mean const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; @@ -124,15 +157,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; - let scale_idx = src0_idx * F16_PER_BLOCK; - let d = src0[scale_idx]; - let m = src0[scale_idx + 1u]; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_at_src0(block_byte_base); + let m = load_f16_at_src0(block_byte_base + 2u); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = src0[scale_idx + 2u + block_offset + j]; - let q_1 = src0[scale_idx + 2u + block_offset + j + 1]; - - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); + let q_packed = load_u32_at_src0(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); let q_lo = f16(q_byte & 0xF) * d + m; @@ -149,11 +180,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q5_0 // 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 22u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. // tile_k is defined as 32u, so blocks_k ends up being 1 always override BLOCKS_K = TILE_K / BLOCK_SIZE; const NQ = 16u; -const F16_PER_BLOCK = 11u; // 1 scale + 2 qh + 8 packed weights const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights @@ -171,18 +202,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; - let scale_idx = src0_idx * F16_PER_BLOCK; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = src0[scale_idx]; - let qh0 = src0[scale_idx + 1u]; - let qh1 = src0[scale_idx + 2u]; - let qh_packed = bitcast(vec2(qh0, qh1)); + let d = load_f16_at_src0(block_byte_base); + let qh_packed = load_u32_at_src0(block_byte_base + 2u); for (var j = 0u; j < 2; j++) { - let q_0 = src0[scale_idx + 3u + block_offset + (j*2)]; - let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u]; - - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u); + let q_packed = load_u32_at_src0(q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -207,11 +234,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q5_1 // 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 24u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. // tile_k is defined as 32u, so blocks_k ends up being 1 always override BLOCKS_K = TILE_K / BLOCK_SIZE; const NQ = 16u; -const F16_PER_BLOCK = 12u; // 1 scale + 2 qh + 8 packed weights + 1 mean const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights @@ -229,20 +256,16 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; - let scale_idx = src0_idx * F16_PER_BLOCK; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = src0[scale_idx]; - let m = src0[scale_idx + 1u]; - let qh0 = src0[scale_idx + 2u]; - let qh1 = src0[scale_idx + 3u]; - let qh_packed = bitcast(vec2(qh0, qh1)); + let d = load_f16_at_src0(block_byte_base); + let m = load_f16_at_src0(block_byte_base + 2u); + let qh_packed = load_u32_at_src0(block_byte_base + 4u); for (var j = 0u; j < 2; j++) { - let q_0 = src0[scale_idx + 4u + block_offset + (j*2)]; - let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u]; - - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u); + let q_packed = load_u32_at_src0(q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -266,10 +289,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q8_0 const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 34u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; const NQ = 16u; -const F16_PER_BLOCK = 17u; // 1 scale + 16 in array of weights const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread @@ -286,14 +309,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; - let scale_idx = src0_idx * F16_PER_BLOCK; - let d = src0[scale_idx]; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_at_src0(block_byte_base); for (var j = 0u; j < F16_PER_THREAD; j+=2) { - let q_0 = src0[scale_idx + 1u + block_offset + j]; - let q_1 = src0[scale_idx + 1u + block_offset + j + 1]; - - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_u32_at_src0(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); @@ -308,10 +329,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q8_1 const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 36u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; const NQ = 16u; -const F16_PER_BLOCK = 18u; // 1 scale + 1 mean + 8 32-bit values in array of weights const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread, 2 threads per block @@ -328,15 +349,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; - let scale_idx = src0_idx * F16_PER_BLOCK; - let d = src0[scale_idx]; - let m = src0[scale_idx + 1u]; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_at_src0(block_byte_base); + let m = load_f16_at_src0(block_byte_base + 2u); for (var j = 0u; j < F16_PER_THREAD; j+=2) { - let q_0 = src0[scale_idx + 2u + block_offset + j]; - let q_1 = src0[scale_idx + 2u + block_offset + j + 1]; - - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); + let q_packed = load_u32_at_src0(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); @@ -351,7 +370,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q2_K const BLOCK_SIZE = 256u; -const F16_PER_BLOCK = 42u; +const BLOCK_SIZE_BYTES = 84u; fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { // Use standard thread layout instead of lane/row_group @@ -371,10 +390,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let k_in_block = global_k % BLOCK_SIZE; let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let scale_idx = src0_idx * F16_PER_BLOCK; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = src0[scale_idx + 40u]; - let dmin = src0[scale_idx + 41u]; + let d = load_f16_at_src0(block_byte_base + 80u); + let dmin = load_f16_at_src0(block_byte_base + 82u); // Decode the element at position k_in_block let block_of_32 = k_in_block / 32u; @@ -387,18 +406,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let is = k_in_block / 16u; - let sc_0 = src0[scale_idx + 2u * (is / 4u)]; - let sc_1 = src0[scale_idx + 2u * (is / 4u) + 1u]; - let sc_packed = bitcast(vec2(sc_0, sc_1)); + let sc_packed = load_u32_at_src0(block_byte_base + 4u * (is / 4u)); let sc = get_byte(sc_packed, is % 4u); let dl = d * f16(sc & 0xFu); let ml = dmin * f16(sc >> 4u); let q_idx = q_b_idx + k + l; - let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)]; - let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u)); let q_byte = get_byte(q_packed, q_idx % 4u); let qs_val = (q_byte >> shift) & 3u; @@ -410,7 +425,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q3_K const BLOCK_SIZE = 256u; -const F16_PER_BLOCK = 55u; +const BLOCK_SIZE_BYTES = 110u; fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { @@ -429,9 +444,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let k_in_block = global_k % BLOCK_SIZE; let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let scale_idx = src0_idx * F16_PER_BLOCK; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = src0[scale_idx + 54u]; + let d = load_f16_at_src0(block_byte_base + 108u); // Load and unpack scales let kmask1: u32 = 0x03030303u; @@ -439,9 +454,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 var scale_vals: array; for (var i: u32 = 0u; i < 4u; i++) { - let scale_0 = src0[scale_idx + 48u + (2u*i)]; - let scale_1 = src0[scale_idx + 48u + (2u*i) + 1u]; - scale_vals[i] = bitcast(vec2(scale_0, scale_1)); + scale_vals[i] = load_u32_at_src0(block_byte_base + 96u + 4u * i); } var tmp: u32 = scale_vals[2]; @@ -453,16 +466,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 // Load hmask and qs arrays var hmask_vals: array; for (var i: u32 = 0u; i < 8u; i++) { - let hmask_0 = src0[scale_idx + (2u*i)]; - let hmask_1 = src0[scale_idx + (2u*i) + 1u]; - hmask_vals[i] = bitcast(vec2(hmask_0, hmask_1)); + hmask_vals[i] = load_u32_at_src0(block_byte_base + 4u * i); } var qs_vals: array; for (var i: u32 = 0u; i < 16u; i++) { - let qs_0 = src0[scale_idx + 16u + (2u*i)]; - let qs_1 = src0[scale_idx + 16u + (2u*i) + 1u]; - qs_vals[i] = bitcast(vec2(qs_0, qs_1)); + qs_vals[i] = load_u32_at_src0(block_byte_base + 32u + 4u * i); } let half = k_in_block / 128u; // 0 or 1 @@ -502,7 +511,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q4_K const BLOCK_SIZE = 256u; -const F16_PER_BLOCK = 72u; +const BLOCK_SIZE_BYTES = 144u; fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { @@ -521,18 +530,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let k_in_block = global_k % BLOCK_SIZE; let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let scale_idx = src0_idx * F16_PER_BLOCK; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = src0[scale_idx]; - let dmin = src0[scale_idx + 1u]; - - // Load packed scales - var scale_vals: array; - for (var i: u32 = 0u; i < 3u; i++) { - let scale_0 = src0[scale_idx + 2u + (2u*i)]; - let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u]; - scale_vals[i] = bitcast(vec2(scale_0, scale_1)); - } + let d = load_f16_at_src0(block_byte_base); + let dmin = load_f16_at_src0(block_byte_base + 2u); // Map k_in_block to loop structure: // Outer loop over 64-element groups (alternating q_b_idx) @@ -549,15 +550,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 var sc: u32; var mn: u32; + let scale_base = block_byte_base + 4u; + if (is < 4u) { - let sc_byte = get_byte(scale_vals[is / 4u], is % 4u); - let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u); + let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u); + let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); sc = sc_byte & 63u; mn = min_byte & 63u; } else { - let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u); - let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u); - let min_hi = get_byte(scale_vals[is / 4u], is % 4u); + let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u); + let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u); + let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); @@ -567,9 +570,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let ml = dmin * f16(mn); let q_idx = q_b_idx + l; - let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)]; - let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u)); let q_byte = get_byte(q_packed, q_idx % 4u); let qs_val = (q_byte >> shift) & 0xFu; @@ -582,7 +583,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q5_K const BLOCK_SIZE = 256u; -const F16_PER_BLOCK = 88u; +const BLOCK_SIZE_BYTES = 176u; fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { @@ -601,18 +602,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let k_in_block = global_k % BLOCK_SIZE; let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let scale_idx = src0_idx * F16_PER_BLOCK; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = src0[scale_idx]; - let dmin = src0[scale_idx + 1u]; + let d = load_f16_at_src0(block_byte_base); + let dmin = load_f16_at_src0(block_byte_base + 2u); - // Load packed scales - var scale_vals: array; - for (var i: u32 = 0u; i < 3u; i++) { - let scale_0 = src0[scale_idx + 2u + (2u*i)]; - let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u]; - scale_vals[i] = bitcast(vec2(scale_0, scale_1)); - } // The original loop processes elements in groups of 64 // Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4] @@ -633,15 +627,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 var sc: u32; var mn: u32; + let scale_base = block_byte_base + 4u; + if (is < 4u) { - let sc_byte = get_byte(scale_vals[is / 4u], is % 4u); - let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u); + let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u); + let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); sc = sc_byte & 63u; mn = min_byte & 63u; } else { - let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u); - let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u); - let min_hi = get_byte(scale_vals[is / 4u], is % 4u); + let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u); + let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u); + let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); @@ -651,15 +647,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let ml = dmin * f16(mn); let q_idx = q_b_idx + l; - let q_0 = src0[scale_idx + 24u + 2u * (q_idx / 4u)]; - let q_1 = src0[scale_idx + 24u + 2u * (q_idx / 4u) + 1u]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_packed = load_u32_at_src0(block_byte_base + 48u + 4u * (q_idx / 4u)); let q_byte = get_byte(q_packed, q_idx % 4u); - let qh_0 = src0[scale_idx + 8u + 2u * (l / 4u)]; - let qh_1 = src0[scale_idx + 8u + 2u * (l / 4u) + 1u]; - let qh_packed = bitcast(vec2(qh_0, qh_1)); + let qh_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (l / 4u)); let qh_byte = get_byte(qh_packed, l % 4u); @@ -675,7 +667,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q6_K const BLOCK_SIZE = 256u; -const F16_PER_BLOCK = 105u; +const BLOCK_SIZE_BYTES = 210u; fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { @@ -694,7 +686,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let k_in_block = global_k % BLOCK_SIZE; let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let scale_idx = src0_idx * F16_PER_BLOCK; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; let half = k_in_block / 128u; let pos_in_half = k_in_block % 128u; @@ -707,30 +699,18 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 // Load only ql13 word needed let ql13_flat = ql_b_idx + l; - let ql13_word = ql13_flat / 4u; - let ql13 = bitcast(vec2( - src0[scale_idx + 2u * ql13_word], - src0[scale_idx + 2u * ql13_word + 1u] - )); - let ql13_b = get_byte(ql13, ql13_flat % 4u); + let ql13 = load_u32_at_src0(block_byte_base + ql13_flat); + let ql13_b = get_byte(ql13, 0u); // Load only ql24 word needed let ql24_flat = ql_b_idx + l + 32u; - let ql24_word = ql24_flat / 4u; - let ql24 = bitcast(vec2( - src0[scale_idx + 2u * ql24_word], - src0[scale_idx + 2u * ql24_word + 1u] - )); - let ql24_b = get_byte(ql24, ql24_flat % 4u); + let ql24 = load_u32_at_src0(block_byte_base + ql24_flat); + let ql24_b = get_byte(ql24, 0u); // Load only qh word needed let qh_flat = qh_b_idx + l; - let qh_word = qh_flat / 4u; - let qh = bitcast(vec2( - src0[scale_idx + 64u + 2u * qh_word], - src0[scale_idx + 64u + 2u * qh_word + 1u] - )); - let qh_b = get_byte(qh, qh_flat % 4u); + let qh = load_u32_at_src0(block_byte_base + 128u + qh_flat); + let qh_b = get_byte(qh, 0u); let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0); let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0); @@ -740,14 +720,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 // Load only the scale word needed let is = l / 16u; let sc_idx = sc_b_idx + is + quarter * 2u; - let sc_word = sc_idx / 4u; - let sc = bitcast(vec2( - src0[scale_idx + 96u + 2u * sc_word], - src0[scale_idx + 96u + 2u * sc_word + 1u] - )); - let sc_val = get_byte_i32(sc, sc_idx % 4u); + let sc = load_u32_at_src0(block_byte_base + 192u + sc_idx); + let sc_val = get_byte_i32(sc, 0u); - let d = src0[scale_idx + 104u]; + let d = load_f16_at_src0(block_byte_base + 208u); var q_val: f16; if (quarter == 0u) { @@ -764,3 +740,426 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 } } #endif // INIT_SRC0_SHMEM_Q6_K + +#ifdef INIT_SRC0_SHMEM_IQ4_NL +const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 18u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_at_src0(block_byte_base); + + let pos = k_in_block % 16u; + let nib_shift = (k_in_block / 16u) * 4u; + let q_packed = load_u32_at_src0(block_byte_base + 2u + (pos / 4u) * 4u); + let nib = (get_byte(q_packed, pos % 4u) >> nib_shift) & 0xFu; + + shmem[elem_idx] = d * f16(kvalues_iq4nl[nib]); + } +} +#endif // INIT_SRC0_SHMEM_IQ4_NL + +#ifdef INIT_SRC0_SHMEM_IQ4_XS +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 136u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + + let d_scales_h = load_u32_at_src0(block_byte_base); + let d = bitcast>(d_scales_h).x; + let scales_h = d_scales_h >> 16u; + + let ib = k_in_block / 32u; + let pos = k_in_block % 32u; + + let scales_l_word = load_u32_at_src0(block_byte_base + 4u); + let ls_lo = (get_byte(scales_l_word, ib / 2u) >> ((ib & 1u) * 4u)) & 0xFu; + let ls_hi = ((scales_h >> (2u * ib)) & 3u) << 4u; + let dl = d * f16(i32(ls_lo | ls_hi) - 32); + + let iqs = ib * 16u + (pos % 16u); + let nib_shift = (pos / 16u) * 4u; + let q_packed = load_u32_at_src0(block_byte_base + 8u + (iqs / 4u) * 4u); + let nib = (get_byte(q_packed, iqs % 4u) >> nib_shift) & 0xFu; + + shmem[elem_idx] = dl * f16(kvalues_iq4nl[nib]); + } +} +#endif // INIT_SRC0_SHMEM_IQ4_XS + +#ifdef INIT_SRC0_SHMEM_IQ1_S +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 50u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let ib = k_in_block / 32u; + let pos = k_in_block % 32u; + let l = pos / 8u; + let j = pos % 8u; + + let qh = load_u32_at_src0(block_byte_base + 34u + ib * 2u) & 0xFFFFu; + let dl = d * (2.0 * f32((qh >> 12u) & 7u) + 1.0); + let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u); + + let qs_w = load_u32_at_src0(block_byte_base + 2u + ib * 4u); + let ig = (get_byte(qs_w, l) | (((qh >> (3u * l)) & 7u) << 8u)) * 8u; + + let gw = iq1_grid[(ig + j) / 16u]; + let g = (gw >> (((ig + j) % 16u) * 2u)) & 3u; + let gs = bitcast(g << 30u) >> 30u; + + shmem[elem_idx] = f16(dl * (f32(gs) + delta)); + } +} +#endif // INIT_SRC0_SHMEM_IQ1_S + +#ifdef INIT_SRC0_SHMEM_IQ1_M +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 56u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + + let scales0 = load_u32_at_src0(block_byte_base + 48u); + let scales1 = load_u32_at_src0(block_byte_base + 52u); + let scale_packed = ((scales0 >> 12u) & 0xFu) | + ((scales0 >> 24u) & 0x00F0u) | + ((scales1 >> 4u) & 0x0F00u) | + ((scales1 >> 16u) & 0xF000u); + let d = f32(bitcast>(scale_packed).x); + + let ib = k_in_block / 32u; + let pos = k_in_block % 32u; + let l = pos / 8u; + let j = pos % 8u; + + let scales = select(scales0, scales1, ib >= 4u); + let sw = (scales >> (16u * ((ib / 2u) % 2u))) & 0xFFFFu; + let s_pair = (sw >> (6u * (ib % 2u) + 3u * (l / 2u))) & 0x7u; + let dl = d * f32(2u * s_pair + 1u); + + let qh_word = load_u32_at_src0(block_byte_base + 32u + (ib / 2u) * 4u); + let qh = qh_word >> (16u * (ib % 2u)); + let qh_nib = (qh >> (4u * l)) & 0xFu; + + let qs_w = load_u32_at_src0(block_byte_base + ib * 4u); + let idx = get_byte(qs_w, l) | ((qh_nib & 7u) << 8u); + let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh_nib & 0x8u) != 0u); + + let ig = idx * 8u; + let gw = iq1_grid[(ig + j) / 16u]; + let g = (gw >> (((ig + j) % 16u) * 2u)) & 3u; + let gs = bitcast(g << 30u) >> 30u; + + shmem[elem_idx] = f16(dl * (f32(gs) + delta)); + } +} +#endif // INIT_SRC0_SHMEM_IQ1_M + +#ifdef INIT_SRC0_SHMEM_IQ2_XXS +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 66u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let entry_idx = k_in_block / 8u; + let j = k_in_block % 8u; + + let ib = entry_idx & ~3u; + let l = entry_idx & 3u; + + let aux0 = load_u32_at_src0(block_byte_base + 2u + ib * 2u); + let aux1 = load_u32_at_src0(block_byte_base + 2u + (ib + 2u) * 2u); + let db = d * (0.5 + f32(aux1 >> 28u)) * 0.25; + + let ig = get_byte(aux0, l) * 8u; + let is = (aux1 >> (7u * l)) & 127u; + let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u); + + let g = get_byte(iq2xxs_grid[(ig + j) / 4u], (ig + j) % 4u); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u); + + shmem[elem_idx] = f16(db * f32(g) * m); + } +} +#endif // INIT_SRC0_SHMEM_IQ2_XXS + +#ifdef INIT_SRC0_SHMEM_IQ2_XS +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 74u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let entry_idx = k_in_block / 8u; + let j = k_in_block % 8u; + + let ib = entry_idx & ~3u; + let l = entry_idx & 3u; + + let scales_word = load_u32_at_src0(block_byte_base + 66u + (ib / 16u) * 4u); + let s = get_byte(scales_word, (ib % 16u) / 4u); + let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, (l / 2u) != 0u); + let dl = d * (0.5 + f32(s_nib)) * 0.25; + + let qs_word = load_u32_at_src0(block_byte_base + 2u + (ib + l) * 2u); + let qs_val = qs_word & 0xFFFFu; + let ig = (qs_val & 511u) * 8u; + let is = qs_val >> 9u; + let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u); + + let g = get_byte(iq2xs_grid[(ig + j) / 4u], (ig + j) % 4u); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u); + + shmem[elem_idx] = f16(dl * f32(g) * m); + } +} +#endif // INIT_SRC0_SHMEM_IQ2_XS + +#ifdef INIT_SRC0_SHMEM_IQ2_S +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 82u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let ib = k_in_block / 32u; + let l = (k_in_block % 32u) / 8u; + let j = k_in_block % 8u; + + let scales_word = load_u32_at_src0(block_byte_base + 74u + (ib / 4u) * 4u); + let s = get_byte(scales_word, ib % 4u); + let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, (l / 2u) != 0u); + let dl = d * (0.5 + f32(s_nib)) * 0.25; + + let qs_word = load_u32_at_src0(block_byte_base + 2u + ib * 4u); + let qh_word = load_u32_at_src0(block_byte_base + 66u + (ib / 4u) * 4u); + let qh_b = (get_byte(qh_word, ib % 4u) << (8u - 2u * l)) & 0x300u; + let ig = (get_byte(qs_word, l) | qh_b) * 8u; + + let signs_word = load_u32_at_src0(block_byte_base + 34u + ib * 4u); + let signs = get_byte(signs_word, l); + + let g = get_byte(iq2s_grid[(ig + j) / 4u], (ig + j) % 4u); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u); + + shmem[elem_idx] = f16(dl * f32(g) * m); + } +} +#endif // INIT_SRC0_SHMEM_IQ2_S + +#ifdef INIT_SRC0_SHMEM_IQ3_XXS +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 98u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let ib_pair = k_in_block / 32u; + let in_pair = k_in_block % 32u; + let l = in_pair / 8u; + let in_l = in_pair % 8u; + let k2 = in_l / 4u; + let j = in_l % 4u; + + let ib = ib_pair * 2u; + let sc_sign_off = block_byte_base + 2u + (ib + 32u) * 2u; + let sc_sign = load_u32_at_src0(sc_sign_off); + let db = d * (0.5 + f32(sc_sign >> 28u)) * 0.5; + let is = (sc_sign >> (7u * l)) & 127u; + let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u); + + let ig_word = load_u32_at_src0(block_byte_base + 2u + (ib * 2u + l) * 2u) & 0xFFFFu; + let ig_byte = get_byte(ig_word, k2); + let g = get_byte(iq3xxs_grid[ig_byte], j); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[k2], j) & signs) != 0u); + + shmem[elem_idx] = f16(db * f32(g) * m); + } +} +#endif // INIT_SRC0_SHMEM_IQ3_XXS + +#ifdef INIT_SRC0_SHMEM_IQ3_S +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 110u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let ib = k_in_block / 64u; + let rest = k_in_block % 64u; + let k = rest / 32u; + let in_k = rest % 32u; + let l = in_k / 8u; + let in_l = in_k % 8u; + let k2 = in_l / 4u; + let j = in_l % 4u; + + let scales_word = load_u32_at_src0(block_byte_base + 106u); + let s = get_byte(scales_word, ib); + let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, k != 0u); + let dl = d * (1.0 + 2.0 * f32(s_nib)); + + let qh_word = load_u32_at_src0(block_byte_base + 66u + (ib / 2u) * 4u); + let qh_byte = get_byte(qh_word, (ib % 2u) * 2u + k); + + let ig_word = load_u32_at_src0(block_byte_base + 2u + (ib * 8u + k * 4u + l) * 2u) & 0xFFFFu; + let ig_lo = get_byte(ig_word, 0u) | ((qh_byte << (8u - 2u * l)) & 256u); + let ig_hi = get_byte(ig_word, 1u) | ((qh_byte << (7u - 2u * l)) & 256u); + let ig = select(ig_lo, ig_hi, k2 != 0u); + + let signs_word = load_u32_at_src0(block_byte_base + 74u + (ib * 2u + k) * 4u); + let signs = get_byte(signs_word, l); + + let g = get_byte(iq3s_grid[ig], j); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[k2], j) & signs) != 0u); + + shmem[elem_idx] = f16(dl * f32(g) * m); + } +} +#endif // INIT_SRC0_SHMEM_IQ3_S diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl new file mode 100644 index 00000000000..91039ff2546 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl @@ -0,0 +1,195 @@ +enable f16; + +#define DECLARE_BYTE_LOADERS_SRC0 +#include "common_decls.tmpl" + +#include "mul_mat_decls.tmpl" + +#ifdef VEC +fn store_val(acc: array, TILE_N>, tn: u32, tm: u32) -> vec4 { + return vec4(f32(acc[tn][tm]), f32(acc[tn][tm + 1]), f32(acc[tn][tm + 2]), f32(acc[tn][tm + 3])); +} +#endif + +#ifdef SCALAR +fn store_val(acc: array, TILE_N>, tn: u32, tm: u32) -> f32 { + return f32(acc[tn][tm]); +} +#endif + +struct MulMatIdParams { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + k: u32, + m: u32, + n_expert: u32, + n_expert_used: u32, + n_tokens: u32, + b_ne1: u32, + + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, +}; + +@group(0) @binding(0) var src0: array; // [cols, rows, n_expert] +@group(0) @binding(1) var src1: array; // [cols, b_ne1, n_tokens] +@group(0) @binding(2) var dst: array; // [rows, n_expert_used, n_tokens] +@group(0) @binding(3) var global_gathered_expert_used: array; // [n_expert][n_tokens] +@group(0) @binding(4) var global_gathered_tokens: array; // [n_expert][n_tokens] +@group(0) @binding(5) var gathered_count_ids: array; // [n_expert] + +@group(0) @binding(6) var params: MulMatIdParams; + +fn get_local_n(thread_id: u32) -> u32 { + return thread_id / WORKGROUP_SIZE_M; +} +fn get_local_m(thread_id: u32) -> u32 { + return thread_id % WORKGROUP_SIZE_M; +} + +const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N; +const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M; +const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N; + +var shmem: array; +var gathered_expert_used: array; +var gathered_tokens: array; + +#ifdef INIT_SRC1_SHMEM_FLOAT +fn init_shmem_id_src1(thread_id: u32, offset_src1: u32, rest_token_n: u32, k_outer: u32) { + for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) { + let tile_n = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + if (tile_n < rest_token_n) { + let global_src10 = k_outer + tile_k; + let expert_used_idx = gathered_expert_used[tile_n] % params.b_ne1; + let token_idx = gathered_tokens[tile_n]; + let src1_idx = offset_src1 + token_idx * params.stride_12 + expert_used_idx * params.stride_11 + global_src10; + let src1_val = select( + SRC1_TYPE(0.0), + src1[src1_idx/VEC_SIZE], + global_src10 < params.k); + store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx); + } else { + store_shmem(SHMEM_TYPE(0.0), TILE_SRC0_SHMEM + elem_idx); + } + } +} +#endif // INIT_SRC1_SHMEM_FLOAT + +@compute @workgroup_size(TOTAL_WORKGROUP_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(num_workgroups) num_wg: vec3) { + + let thread_id = local_id.x; + let local_m = get_local_m(thread_id); + let local_n = get_local_n(thread_id); + + var expert_idx:u32 = 0xFFFFFFFFu; + var wg_in_batch:u32 = 0; + var wg_sum:u32 = 0; + let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M); + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + + for (var i = 0u;i < params.n_expert;i += 1) { + let wg_n_count = (gathered_count_ids[i] + WORKGROUP_SIZE_N * TILE_N - 1u) / (WORKGROUP_SIZE_N * TILE_N); + let wg_per_matrix = wg_m_count * wg_n_count; + if (wg_sum <= wg_linear && wg_linear < wg_sum + wg_per_matrix) { + expert_idx = i; + wg_in_batch = wg_linear - wg_sum; + break; + } + wg_sum += wg_per_matrix; + } + + let is_valid = expert_idx != 0xFFFFFFFFu; + + var wg_m: u32 = 0; + var wg_n: u32 = 0; + var offset_wg_m: u32 = 0; + var offset_wg_n: u32 = 0; + var rest_token_n: u32 = 0; + var src0_batch_offset: u32 = 0; + + wg_m = wg_in_batch % wg_m_count; + wg_n = wg_in_batch / wg_m_count; + + offset_wg_m = wg_m * WORKGROUP_SIZE_M * TILE_M; + offset_wg_n = wg_n * WORKGROUP_SIZE_N * TILE_N; + + if (is_valid) { + rest_token_n = gathered_count_ids[expert_idx] - offset_wg_n; + let global_gathered_base = expert_idx * params.n_tokens + offset_wg_n; + for (var i = thread_id; i < TILE_N * WORKGROUP_SIZE_N && offset_wg_n + i < gathered_count_ids[expert_idx]; i += TOTAL_WORKGROUP_SIZE) { + gathered_expert_used[i] = global_gathered_expert_used[global_gathered_base + i]; + gathered_tokens[i] = global_gathered_tokens[global_gathered_base + i]; + } + src0_batch_offset = params.offset_src0 + expert_idx * params.stride_02; + } + + workgroupBarrier(); + + let output_row_base = offset_wg_m + local_m * TILE_M; + let output_col_base = offset_wg_n + local_n * TILE_N; + + let dst2_stride = params.m * params.n_expert_used; + let dst1_stride = params.m; + + var acc: array, TILE_N>; + + for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { + + if (is_valid) { + init_shmem_src0(thread_id, src0_batch_offset, offset_wg_m, k_outer); + init_shmem_id_src1(thread_id, params.offset_src1, rest_token_n, k_outer); + } + + workgroupBarrier(); + + if (is_valid) { + let k_end = min(TILE_K, params.k - k_outer); + + for (var k_inner = 0u; k_inner < k_end; k_inner++) { + var src0_tile: array; + for (var tm = 0u; tm < TILE_M; tm++) { + let src0_m = local_m * TILE_M + tm; + let src0_idx = k_inner + src0_m * TILE_K; + src0_tile[tm] = shmem[src0_idx]; + } + for (var tn = 0u; tn < TILE_N; tn++) { + let src1_n = local_n * TILE_N + tn; + let src1_idx = src1_n * TILE_K + k_inner; + let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx]; + for (var tm = 0u; tm < TILE_M; tm++) { + acc[tn][tm] += src0_tile[tm] * src1_val; + } + } + } + } + + workgroupBarrier(); + } + + if (is_valid) { + for (var tn = 0u; tn < TILE_N; tn++) { + let n_idx = output_col_base + tn; + if (n_idx < gathered_count_ids[expert_idx]) { + let dst1_idx = gathered_expert_used[n_idx - offset_wg_n]; + let dst2_idx = gathered_tokens[n_idx - offset_wg_n]; + let dst12_offset = params.offset_dst + dst2_idx * dst2_stride + dst1_idx * dst1_stride; + for (var tm = 0u; tm < TILE_M; tm += VEC_SIZE) { + let global_row = output_row_base + tm; + if (global_row < params.m) { + let dst_idx = dst12_offset + global_row; + dst[dst_idx/VEC_SIZE] = store_val(acc, tn, tm); + } + } + } + } + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl new file mode 100644 index 00000000000..d79d5f3f282 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl @@ -0,0 +1,55 @@ +enable f16; + +struct MulMatIdGatherParams { + offset_ids: u32, + + n_expert: u32, + n_expert_used: u32, + n_tokens: u32, + + stride_ids_1: u32, +}; + +@group(0) @binding(0) var ids: array; // [n_expert_used, n_tokens] +@group(0) @binding(1) var global_gathered_expert_used: array; // [n_expert][n_tokens] +@group(0) @binding(2) var global_gathered_tokens: array; // [n_expert][n_tokens] +@group(0) @binding(3) var gathered_count_ids: array; // [n_expert] + +@group(0) @binding(4) var params: MulMatIdGatherParams; + +var count:atomic; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(num_workgroups) num_wg: vec3) { + + let thread_id = local_id.x; + let own_expert = wg_id.y * num_wg.x + wg_id.x; // the expert assigned to this workgroup + + if (own_expert < params.n_expert) { + if (thread_id == 0u) { + atomicStore(&count, 0); + } + + workgroupBarrier(); + + for (var i = thread_id;i < params.n_expert_used * params.n_tokens;i += WG_SIZE) { + let row = i / params.n_expert_used; + let col = i % params.n_expert_used; + let expert = u32(ids[params.offset_ids + row * params.stride_ids_1 + col]); + if (own_expert == expert) { + let pos = atomicAdd(&count, 1u); + let gathered_id = own_expert * params.n_tokens + pos; + global_gathered_expert_used[gathered_id] = col; + global_gathered_tokens[gathered_id] = row; + } + } + + workgroupBarrier(); + + if (thread_id == 0u) { + gathered_count_ids[own_expert] = atomicLoad(&count); + } + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl new file mode 100644 index 00000000000..6ff9bcf2df0 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl @@ -0,0 +1,154 @@ +#ifdef USE_SUBGROUP_REDUCTION +enable subgroups; +#endif +enable f16; + +#define DECLARE_BYTE_LOADERS_SRC0 +#include "common_decls.tmpl" + +#include "mul_mat_vec_acc.tmpl" + +struct MulMatIdVecParams { + offset_src0: u32, + offset_src1: u32, + offset_ids: u32, + offset_dst: u32, + + k: u32, + m: u32, + n_expert: u32, + n_expert_used: u32, + b_ne1: u32, + + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, +}; + +@group(0) @binding(0) var src0: array; // [cols, rows, n_expert] +@group(0) @binding(1) var src1: array; // [cols, b_ne1, n_tokens(1)] +@group(0) @binding(2) var ids: array; // [n_experd_used, n_tokens(1)] +@group(0) @binding(3) var dst: array; // [rows, n_expert_used, n_tokens(1)] + +// "mul_mat_vec_acc.tmpl" requires params.k, params.m, params.stride_01 +@group(0) @binding(4) var params: MulMatIdVecParams; + +// Flattened as [row][thread] to keep each row's reduction contiguous in memory. +var partial_sums: array; + +fn partial_index(row: u32, thread: u32) -> u32 { + return row * WG_SIZE + thread; +} + +var gathered_count_ids: array; +var gathered_expert_used: array; + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) wg_id: vec3, + @builtin(num_workgroups) num_wg: vec3 +#ifdef USE_SUBGROUP_REDUCTION + , @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_invocation_id) subgroup_invocation_id: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_size) subgroup_size: u32 +#endif +) { + + let thread_id = local_id.x; + + for (var i = thread_id;i < params.n_expert;i += WG_SIZE) { + gathered_count_ids[i] = 0; + } + + workgroupBarrier(); + + // gather the selected experts for the target token. + for (var col = thread_id;col < params.n_expert_used;col += WG_SIZE) { + let expert = ids[params.offset_ids + col]; + gathered_count_ids[expert] = 1; + gathered_expert_used[expert] = col; + } + + workgroupBarrier(); + + let output_groups:u32 = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG; + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + + var own_expert:u32 = 0; + var wg_in_batch:u32 = 0; + var wg_sum:u32 = 0; + + for (var i = 0u;i < params.n_expert;i += 1) { + let wg_vec_count = gathered_count_ids[i]; // 1 or 0 + let wg_per_matrix = output_groups * wg_vec_count; + if (wg_sum <= wg_linear && wg_linear < wg_sum + wg_per_matrix) { + own_expert = i; + wg_in_batch = wg_linear - wg_sum; + break; + } + wg_sum += wg_per_matrix; + } + + let row_base = (wg_linear % output_groups) * OUTPUTS_PER_WG; + let dst1_stride = params.m; + + let src0_batch_offset = params.offset_src0 + own_expert * params.stride_02; + let src1_idx_base = params.offset_src1 + (gathered_expert_used[own_expert] % params.b_ne1) * params.stride_11; + let dst_idx_base = params.offset_dst + gathered_expert_used[own_expert] * dst1_stride + row_base; + + let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base); + +#ifdef USE_SUBGROUP_REDUCTION + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let subgroup_total = subgroupAdd(acc[row]); + if (subgroup_invocation_id == 0u) { + partial_sums[partial_index(row, subgroup_id)] = subgroup_total; + } + } + + workgroupBarrier(); + + for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) { + let output_row = row_base + row; + var row_acc = 0.0f; + for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) { + row_acc += partial_sums[partial_index(row, k)]; + } + let row_total = subgroupAdd(row_acc); + if (subgroup_invocation_id == 0) { + dst[dst_idx_base + row] = row_total; + } + } +#endif + +#ifdef USE_WORKGROUP_REDUCTION + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] = acc[row]; + } + + workgroupBarrier(); + + var stride:u32 = WG_SIZE / 2u; + + while (stride > 0) { + if (thread_id < stride) { + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)]; + } + } + + workgroupBarrier(); + stride = stride / 2; + } + + if (thread_id < OUTPUTS_PER_WG) { + let output_row = row_base + thread_id; + if (output_row < params.m) { + dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)]; + } + } +#endif +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl index b1da421a691..98bbdeb83ba 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl @@ -1,17 +1,19 @@ enable f16; +#define DECLARE_BYTE_LOADERS_SRC0 #include "common_decls.tmpl" + #include "mul_mat_decls.tmpl" #ifdef VEC -fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> vec4 { - return vec4(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn])); +fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> vec4 { + return vec4(acc[tm][tn], acc[tm + 1][tn], acc[tm + 2][tn], acc[tm + 3][tn]); } #endif #ifdef SCALAR -fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> f32 { - return f32(acc[tm][tn]); +fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> f32 { + return acc[tm][tn]; } #endif @@ -98,7 +100,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let offset_m = wg_m * WORKGROUP_SIZE_M * TILE_M; let offset_n = wg_n * WORKGROUP_SIZE_N * TILE_N; - var acc: array, TILE_M>; + var acc: array, TILE_M>; for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { @@ -122,7 +124,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let src1_idx = src1_n * TILE_K + k_inner; let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx]; for (var tm = 0u; tm < TILE_M; tm++) { - acc[tm][tn] += src0_tile[tm] * src1_val; + acc[tm][tn] += f32(src0_tile[tm]) * f32(src1_val); } } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl index 9f9ef279f29..d86a72ce6e0 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl @@ -3,9 +3,14 @@ enable f16; enable subgroups; enable chromium_experimental_subgroup_matrix; +#define DECLARE_BYTE_LOADERS_SRC0 #include "common_decls.tmpl" + #include "mul_mat_decls.tmpl" +// TODO: this shader path does not work with some models like qwen2.5 on Metal devices, f16 accumulation causes NaNs. +// See https://github.com/ggml-org/llama.cpp/issues/21602 + #ifdef VEC fn store_dst(shmem_idx: u32, dst_idx: u32) { dst[dst_idx] = vec4( @@ -193,4 +198,3 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } } } - diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index 94f4bae11f4..a194cf40468 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -1,375 +1,12 @@ +#ifdef USE_SUBGROUP_REDUCTION +enable subgroups; +#endif enable f16; +#define DECLARE_BYTE_LOADERS_SRC0 #include "common_decls.tmpl" -#ifdef VEC - -#define VEC_SIZE 4 -#define DST_TYPE vec4 -#define SRC0_TYPE vec4 -#define SRC1_TYPE vec4 - -fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { - return f32(dot(SRC1_TYPE(src0_val), src1_val)); -} - -fn store_val(group_base: u32) -> vec4 { - return vec4(partial_sums[group_base], - partial_sums[group_base + THREADS_PER_OUTPUT], - partial_sums[group_base + THREADS_PER_OUTPUT * 2], - partial_sums[group_base + THREADS_PER_OUTPUT * 3]); -} -#endif - -#ifdef SCALAR - -#define VEC_SIZE 1 -#define DST_TYPE f32 -#define SRC0_TYPE SRC0_INNER_TYPE -#define SRC1_TYPE SRC1_INNER_TYPE - -fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { - return f32(src0_val) * f32(src1_val); -} - -fn store_val(group_base: u32) -> f32 { - return partial_sums[group_base]; -} -#endif - -#ifdef MUL_ACC_FLOAT -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * VEC_SIZE; i < tile_size; i += THREADS_PER_OUTPUT * VEC_SIZE) { - let a = src0[(idx_base + k_outer + i) / VEC_SIZE]; - let b = shared_vector[i / VEC_SIZE]; - local_sum += inner_dot(a, b); - } - return local_sum; -} -#endif - -#ifdef MUL_ACC_Q4_0 - -const BLOCK_SIZE = 32; -const NQ = 16u; // number of weights per thread -const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(src0[scale_idx]); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = src0[scale_idx + 1 + block_offset + j]; - let q_1 = src0[scale_idx + 1 + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d; - let q_lo = (f32(q_byte & 0xF) - 8.0) * d; - local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k]; - local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16]; - } - } - } - return local_sum; -} -#endif - -#ifdef MUL_ACC_Q4_1 - -const BLOCK_SIZE = 32; -const NQ = 16u; // number of weights per thread -const F16_PER_BLOCK = 10u; -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(src0[scale_idx]); - let m = f32(src0[scale_idx + 1u]); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = src0[scale_idx + 2u + block_offset + j]; - let q_1 = src0[scale_idx + 2u + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = f32((q_byte >> 4) & 0xF) * d + m; - let q_lo = f32(q_byte & 0xF) * d + m; - local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k]; - local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16]; - } - } - } - return local_sum; -} -#endif - -#ifdef MUL_ACC_Q5_0 - -const BLOCK_SIZE = 32; -const NQ = 16u; // number of weights per thread -const F16_PER_BLOCK = 11u; -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(src0[scale_idx]); - let qh0 = src0[scale_idx + 1u]; - let qh1 = src0[scale_idx + 2u]; - let qh_packed = bitcast(vec2(qh0, qh1)); - - for (var j = 0u; j < 2; j++) { - let q_0 = src0[scale_idx + 3u + block_offset + (j*2)]; - let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u]; - let q_packed = bitcast(vec2(q_0, q_1)); - - let j_adjusted = j + (block_offset / 2u); - - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - - let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10; - let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; - let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10; - let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d; - - local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k]; - local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16]; - } - - } - } - return local_sum; -} -#endif - - -#ifdef MUL_ACC_Q5_1 - -const BLOCK_SIZE = 32; -const NQ = 16u; // number of weights per thread -const F16_PER_BLOCK = 12u; -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(src0[scale_idx]); - let m = src0[scale_idx + 1u]; - let qh0 = src0[scale_idx + 2u]; - let qh1 = src0[scale_idx + 3u]; - let qh_packed = bitcast(vec2(qh0, qh1)); - - for (var j = 0u; j < 2; j++) { - let q_0 = src0[scale_idx + 4u + block_offset + (j*2)]; - let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u]; - let q_packed = bitcast(vec2(q_0, q_1)); - - let j_adjusted = j + (block_offset / 2u); - - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - - let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10; - let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + f32(m); - let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10; - let q_lo = f32((q_byte & 0xF) | qh_lo) * d + f32(m); - - local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k]; - local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16]; - } - - } - } - return local_sum; -} -#endif - - -#ifdef MUL_ACC_Q8_0 - -const BLOCK_SIZE = 32; -const NQ = 16u; // number of weights per thread -const F16_PER_BLOCK = 17u; -const WEIGHTS_PER_F16 = 2u; -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(src0[scale_idx]); - - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = src0[scale_idx + 1 + block_offset + j]; - let q_1 = src0[scale_idx + 1 + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f32(q_byte) * d; - local_sum += q_val * shared_vector[shmem_idx + j * 2 + k]; - } - } - } - return local_sum; -} -#endif - - -#ifdef MUL_ACC_Q8_1 - -const BLOCK_SIZE = 32; -const NQ = 16u; // number of weights per thread -const F16_PER_BLOCK = 18u; -const WEIGHTS_PER_F16 = 2u; -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(src0[scale_idx]); - let m = src0[scale_idx + 1u]; - - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = src0[scale_idx + 2u + block_offset + j]; - let q_1 = src0[scale_idx + 2u + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f32(q_byte) * d + f32(m); - local_sum += q_val * shared_vector[shmem_idx + j * 2 + k]; - } - } - } - return local_sum; -} -#endif - -#ifdef MUL_ACC_Q6_K - -const BLOCK_SIZE = 256u; -const F16_PER_BLOCK = 105u; - -fn load_u32_at(bbase: u32, byte_offset: u32) -> u32 { - let aligned = byte_offset & ~3u; - let idx = bbase + aligned / 2u; - return bitcast(vec2(src0[idx], src0[idx + 1u])); -} - -fn byte_of(v: u32, b: u32) -> u32 { - return (v >> (b * 8u)) & 0xFFu; -} - -fn sbyte_of(v: u32, b: u32) -> i32 { - let raw = i32((v >> (b * 8u)) & 0xFFu); - return select(raw, raw - 256, raw >= 128); -} - -fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - let tid = tig / 2u; - let ix = tig % 2u; - let ip = tid / 8u; - let il = tid % 8u; - let l0 = 4u * il; - let is = 8u * ip + l0 / 16u; - - let y_offset = 128u * ip + l0; - let q_offset_l = 64u * ip + l0; - let q_offset_h = 32u * ip + l0; - - let nb = tile_size / BLOCK_SIZE; - let k_block_start = k_outer / BLOCK_SIZE; - - // Aligned scale byte position (is can be odd) - let sc_base_byte = 192u + (is & ~3u); - let sc_byte_pos = is & 3u; - - var local_sum = 0.0; - - for (var i = ix; i < nb; i += 2u) { - let bbase = (idx_base + k_block_start + i) * F16_PER_BLOCK; - - let d_raw = load_u32_at(bbase, 208u); - let d = f32(bitcast>(d_raw)[0]); - - let ql1_u32 = load_u32_at(bbase, q_offset_l); - let ql2_u32 = load_u32_at(bbase, q_offset_l + 32u); - let qh_u32 = load_u32_at(bbase, 128u + q_offset_h); - let sc_u32_0 = load_u32_at(bbase, sc_base_byte); - let sc_u32_1 = load_u32_at(bbase, sc_base_byte + 4u); - - let sc0 = sbyte_of(sc_u32_0, sc_byte_pos); - let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u); - let sc4 = sbyte_of(sc_u32_1, sc_byte_pos); - let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u); - - var sums = vec4(0.0, 0.0, 0.0, 0.0); - - for (var l = 0u; l < 4u; l++) { - let y_base = i * BLOCK_SIZE + y_offset + l; - let yl0 = f32(shared_vector[y_base]); - let yl1 = f32(shared_vector[y_base + 32u]); - let yl2 = f32(shared_vector[y_base + 64u]); - let yl3 = f32(shared_vector[y_base + 96u]); - - let q1b = byte_of(ql1_u32, l); - let q2b = byte_of(ql2_u32, l); - let qhb = byte_of(qh_u32, l); - - let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32); - let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32); - let dq2 = f32(i32((q1b >> 4u) | ((qhb & 0x30u) )) - 32); - let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32); - - sums[0] += yl0 * dq0; - sums[1] += yl1 * dq1; - sums[2] += yl2 * dq2; - sums[3] += yl3 * dq3; - } - - local_sum += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) + - sums[2] * f32(sc4) + sums[3] * f32(sc6)); - } - - return local_sum; -} -#endif +#include "mul_mat_vec_acc.tmpl" struct MulMatParams { offset_src0: u32, @@ -390,27 +27,34 @@ struct MulMatParams { broadcast3: u32 }; -// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included -@group(0) @binding(0) var src0: array; // M rows, K columns -@group(0) @binding(1) var src1: array; // K rows, N columns (transposed) -@group(0) @binding(2) var dst: array; // M rows, N columns (transposed) +@group(0) @binding(0) var src0: array; +@group(0) @binding(1) var src1: array; +@group(0) @binding(2) var dst: array; +// "mul_mat_vec_acc.tmpl" requires params.k, params.m, params.stride_01 @group(0) @binding(3) var params: MulMatParams; -const THREADS_PER_OUTPUT = WG_SIZE / OUTPUTS_PER_WG; +// Flattened as [row][thread] to keep each row's reduction contiguous in memory. +var partial_sums: array; -// Shared memory for collaborative loading and reduction -var shared_vector: array; // Cache vector tile -var partial_sums: array; // For reduction +fn partial_index(row: u32, thread: u32) -> u32 { + return row * WG_SIZE + thread; +} @compute @workgroup_size(WG_SIZE) fn main( @builtin(local_invocation_id) local_id: vec3, @builtin(workgroup_id) wg_id: vec3, - @builtin(num_workgroups) num_wg: vec3) { + @builtin(num_workgroups) num_wg: vec3 +#ifdef USE_SUBGROUP_REDUCTION + , @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_invocation_id) subgroup_invocation_id: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_size) subgroup_size: u32 +#endif +) { let thread_id = local_id.x; - // Handle batch dimensions let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; let wg_linear = wg_id.y * num_wg.x + wg_id.x; let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG; @@ -419,12 +63,7 @@ fn main( return; } - // Which of the outputs does this thread belong to? - let thread_group = thread_id / THREADS_PER_OUTPUT; - let thread_in_group = thread_id % THREADS_PER_OUTPUT; - - // Each workgroup computes OUTPUTS_PER_WG consecutive outputs - let output_row = (wg_linear % output_groups) * OUTPUTS_PER_WG + thread_group; + let row_base = (wg_linear % output_groups) * OUTPUTS_PER_WG; let dst2_stride = params.m * params.n; let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); @@ -435,46 +74,60 @@ fn main( let src02_idx = dst2_idx / params.broadcast2; let src12_idx = dst2_idx; - let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + output_row * params.stride_01; + let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; - let dst_idx = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + output_row; + let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base; - var local_sum = 0.0; + let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base); - // Each thread processes multiple K elements and accumulates - for (var k_tile = 0u; k_tile < params.k; k_tile += TILE_K) { - let tile_size = min(TILE_K, params.k - k_tile); - - // Cooperatively load vector tile into shared memory (all threads) - for (var i = thread_id * VEC_SIZE; i < tile_size; i += WG_SIZE * VEC_SIZE) { - shared_vector[i / VEC_SIZE] = src1[(src1_idx_base + k_tile + i) / VEC_SIZE]; +#ifdef USE_SUBGROUP_REDUCTION + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let subgroup_total = subgroupAdd(acc[row]); + if (subgroup_invocation_id == 0u) { + partial_sums[partial_index(row, subgroup_id)] = subgroup_total; } + } - workgroupBarrier(); + workgroupBarrier(); - if (output_row < params.m) { - local_sum += mul_acc(thread_in_group, tile_size, src0_idx_base, k_tile); + for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) { + let output_row = row_base + row; + var row_acc = 0.0f; + for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) { + row_acc += partial_sums[partial_index(row, k)]; + } + let row_total = subgroupAdd(row_acc); + if (subgroup_invocation_id == 0) { + dst[dst_idx_base + row] = row_total; } + } +#endif - workgroupBarrier(); +#ifdef USE_WORKGROUP_REDUCTION + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] = acc[row]; } - // Store partial sums and reduce within each partition - partial_sums[thread_id] = local_sum; workgroupBarrier(); - let group_base = thread_group * THREADS_PER_OUTPUT; - let thread_base = group_base + thread_in_group; - var offset: u32 = THREADS_PER_OUTPUT / 2; - while (offset > 0) { - if (thread_in_group < offset) { - partial_sums[thread_base] += partial_sums[thread_base + offset]; + + var stride = WG_SIZE / 2u; + + while (stride > 0) { + if (thread_id < stride) { + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)]; + } } - offset = offset / 2; + workgroupBarrier(); + stride = stride / 2; } - // Store back to global memory - if (output_row < params.m && thread_group % VEC_SIZE == 0 && thread_in_group == 0) { - dst[dst_idx / VEC_SIZE] = store_val(group_base); + if (thread_id < OUTPUTS_PER_WG) { + let output_row = row_base + thread_id; + if (output_row < params.m) { + dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)]; + } } +#endif } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl new file mode 100644 index 00000000000..1f59bd14863 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl @@ -0,0 +1,1391 @@ +#ifdef U32_DEQUANT_HELPERS +#define SRC0_TYPE u32 + +fn byte_of(v: u32, b: u32) -> u32 { + return (v >> (b * 8u)) & 0xFFu; +} + +fn sbyte_of(v: u32, b: u32) -> i32 { + let raw = i32((v >> (b * 8u)) & 0xFFu); + return select(raw, raw - 256, raw >= 128); +} +#endif + +#ifdef VEC +#define VEC_SIZE 4u +#define SRC0_TYPE vec4 +#define SRC1_TYPE vec4 + +fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { + return f32(dot(SRC1_TYPE(src0_val), src1_val)); +} +#endif + +#ifdef SCALAR +#define VEC_SIZE 1u +#define SRC0_TYPE SRC0_INNER_TYPE +#define SRC1_TYPE SRC1_INNER_TYPE + +fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { + return f32(src0_val) * f32(src1_val); +} +#endif + +#ifdef MUL_ACC_FLOAT +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let k_vec = params.k / VEC_SIZE; + let src1_idx_base_vec = src1_idx_base / VEC_SIZE; + + // Each thread walks K, loads from the vector, and updates + // a small block of output rows held in registers. + for (var k = thread_id; k < k_vec; k += WG_SIZE) { + let x = src1[src1_idx_base_vec + k]; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let src0_idx = (src0_batch_offset + output_row * params.stride_01) / VEC_SIZE + k; + acc[row] += inner_dot(src0[src0_idx], x); + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q1_0 +#define BLOCK_SIZE 128 +#define BLOCK_SIZE_BYTES 18 +#define THREADS_PER_BLOCK 16 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let q_byte = load_u32_at_src0(block_byte_base + 2u + thread_within_block) & 0xFFu; + var row_sum = 0.0; + for (var bit = 0u; bit < 8u; bit++) { + let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u); + row_sum += w * x_block[bit]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q4_0 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 18 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % 4; + for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + var row_sum = 0.0; + + let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = (f32(q_byte & 0xFu) - 8.0) * d; + let q_hi = (f32((q_byte >> 4u) & 0xFu) - 8.0) * d; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q4_1 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 20 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let m = f32(load_f16_at_src0(block_byte_base + 2u)); + var row_sum = 0.0; + + let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * thread_within_block); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(q_byte & 0xFu) * d + m; + let q_hi = f32((q_byte >> 4u) & 0xFu) * d + m; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q5_0 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 22 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qh_packed = load_u32_at_src0(block_byte_base + 2u); + let q_packed = load_u32_at_src0(block_byte_base + 6u + 4u * thread_within_block); + let qh_shift = thread_within_block * 4u; + var row_sum = 0.0; + + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; + let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; + let q_lo = (f32((q_byte & 0xFu) | qh_lo) - 16.0) * d; + let q_hi = (f32(((q_byte >> 4u) & 0xFu) | qh_hi) - 16.0) * d; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q5_1 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 24 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let m = f32(load_f16_at_src0(block_byte_base + 2u)); + let qh_packed = load_u32_at_src0(block_byte_base + 4u); + let q_packed = load_u32_at_src0(block_byte_base + 8u + 4u * thread_within_block); + let qh_shift = thread_within_block * 4u; + var row_sum = 0.0; + + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; + let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; + let q_lo = f32((q_byte & 0xFu) | qh_lo) * d + m; + let q_hi = f32(((q_byte >> 4u) & 0xFu) | qh_hi) * d + m; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q8_0 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 34 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + var row_sum = 0.0; + + for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { + let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * (thread_within_block * 2u + packed_idx)); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d; + row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q8_1 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 36 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let m = f32(load_f16_at_src0(block_byte_base + 2u)); + var row_sum = 0.0; + + for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { + let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * (thread_within_block * 2u + packed_idx)); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d + m; + row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q2_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 84 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let lane = tid / 2u; + let phase = tid % 2u; + let iq = lane / 4u; + let ir = lane % 4u; + let is = ir / 2u; + + let y_offset = 128u * iq + 8u * ir + 4u * phase; + let sc0_byte = 8u * iq + is; + let sc2_byte = 8u * iq + is + 2u; + let sc4_byte = 8u * iq + is + 4u; + let sc6_byte = 8u * iq + is + 6u; + let qs_byte = 16u + (16u * iq + 4u * ir) * 2u + 4u * phase; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 4u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + 32u + i]); + x_block[i + 8u] = f32(src1[x_base + 64u + i]); + x_block[i + 12u] = f32(src1[x_base + 96u + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let dall = f32(load_f16_at_src0(block_byte_base + 80u)); + let dmin = f32(load_f16_at_src0(block_byte_base + 82u)) * (1.0 / 16.0); + + let sc0 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc0_byte), sc0_byte & 3u); + let sc2 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc2_byte), sc2_byte & 3u); + let sc4 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc4_byte), sc4_byte & 3u); + let sc6 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc6_byte), sc6_byte & 3u); + + let q_u32 = load_u32_at_src0_aligned(block_byte_base + qs_byte); + let qs0 = q_u32 & 0xFFFFu; + let qs1 = q_u32 >> 16u; + + var sumy = vec4(0.0, 0.0, 0.0, 0.0); + var acc1 = vec4(0.0, 0.0, 0.0, 0.0); + var acc2 = vec4(0.0, 0.0, 0.0, 0.0); + + sumy[0] = x_block[0] + x_block[1] + x_block[2] + x_block[3]; + sumy[1] = x_block[4] + x_block[5] + x_block[6] + x_block[7]; + sumy[2] = x_block[8] + x_block[9] + x_block[10] + x_block[11]; + sumy[3] = x_block[12] + x_block[13] + x_block[14] + x_block[15]; + + acc1[0] = x_block[0] * f32(qs0 & 0x0003u) + x_block[2] * f32(qs1 & 0x0003u); + acc2[0] = x_block[1] * f32(qs0 & 0x0300u) + x_block[3] * f32(qs1 & 0x0300u); + acc1[1] = x_block[4] * f32(qs0 & 0x000Cu) + x_block[6] * f32(qs1 & 0x000Cu); + acc2[1] = x_block[5] * f32(qs0 & 0x0C00u) + x_block[7] * f32(qs1 & 0x0C00u); + acc1[2] = x_block[8] * f32(qs0 & 0x0030u) + x_block[10] * f32(qs1 & 0x0030u); + acc2[2] = x_block[9] * f32(qs0 & 0x3000u) + x_block[11] * f32(qs1 & 0x3000u); + acc1[3] = x_block[12] * f32(qs0 & 0x00C0u) + x_block[14] * f32(qs1 & 0x00C0u); + acc2[3] = x_block[13] * f32(qs0 & 0xC000u) + x_block[15] * f32(qs1 & 0xC000u); + + acc[row] += dall * ((acc1[0] + (1.0/256.0) * acc2[0]) * f32(sc0 & 0xFu) + + (acc1[1] + (1.0/256.0) * acc2[1]) * f32(sc2 & 0xFu) / 4.0 + + (acc1[2] + (1.0/256.0) * acc2[2]) * f32(sc4 & 0xFu) / 16.0 + + (acc1[3] + (1.0/256.0) * acc2[3]) * f32(sc6 & 0xFu) / 64.0) + - dmin * (sumy[0] * f32(sc0 & 0xF0u) + sumy[1] * f32(sc2 & 0xF0u) + + sumy[2] * f32(sc4 & 0xF0u) + sumy[3] * f32(sc6 & 0xF0u)); + } + } + } + + return acc; +} +#endif + + +#ifdef MUL_ACC_Q3_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 110 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let lane = tid / 2u; + let phase = tid % 2u; + let ip = lane / 4u; + let il = 2u * ((lane % 4u) / 2u); + let ir = lane % 2u; + let l0 = 8u * ir; + + let q_byte = 32u + 32u * ip + l0 + 16u * phase; + let h_byte = l0 + 16u * phase; + let y_offset = 128u * ip + 32u * il + l0 + 16u * phase; + + let s_shift1 = 4u * ip; + let s_shift2 = s_shift1 + il; + + let v1 = select(64.0, 4.0, il == 0u); + let v2 = 4.0 * v1; + let shift = 2u * il; + + var qm0: u32; var qm1: u32; var qm2: u32; var qm3: u32; + if (il == 0u) { + qm0 = 0x0003u; qm1 = 0x0300u; qm2 = 0x000Cu; qm3 = 0x0C00u; + } else { + qm0 = 0x0030u; qm1 = 0x3000u; qm2 = 0x00C0u; qm3 = 0xC000u; + } + + let mm_idx = 2u * ip + il / 2u; + var hm0: u32; var hm1: u32; var hm2: u32; var hm3: u32; + switch (mm_idx) { + case 0u: { hm0=0x0001u; hm1=0x0100u; hm2=0x0002u; hm3=0x0200u; } + case 1u: { hm0=0x0004u; hm1=0x0400u; hm2=0x0008u; hm3=0x0800u; } + case 2u: { hm0=0x0010u; hm1=0x1000u; hm2=0x0020u; hm3=0x2000u; } + default: { hm0=0x0040u; hm1=0x4000u; hm2=0x0080u; hm3=0x8000u; } + } + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 8u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 8u] = f32(src1[x_base + 32u + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base + 108u)); + let a_base = 96u; + let a_il0 = load_u16_at_src0(block_byte_base + a_base + il * 2u); + let a_il1 = load_u16_at_src0(block_byte_base + a_base + (il + 1u) * 2u); + let a_4 = load_u16_at_src0(block_byte_base + a_base + 8u); + let a_5 = load_u16_at_src0(block_byte_base + a_base + 10u); + + var scales32 = a_4 | (a_5 << 16u); + let aux32 = ((scales32 >> s_shift2) << 4u) & 0x30303030u; + scales32 = a_il0 | (a_il1 << 16u); + scales32 = ((scales32 >> s_shift1) & 0x0F0F0F0Fu) | aux32; + + let scale0 = f32(i32(byte_of(scales32, phase + 0u)) - 32); + let scale1 = f32(i32(byte_of(scales32, phase + 2u)) - 32); + + let q_u32_0 = load_u32_at_src0(block_byte_base + q_byte + 0u); + let q_u32_1 = load_u32_at_src0(block_byte_base + q_byte + 4u); + let h_u32_0 = load_u32_at_src0(block_byte_base + h_byte + 0u); + let h_u32_1 = load_u32_at_src0(block_byte_base + h_byte + 4u); + + var s1 = 0.0; var s2 = 0.0; var s3 = 0.0; + var s4 = 0.0; var s5 = 0.0; var s6 = 0.0; + + for (var l = 0u; l < 8u; l += 2u) { + let q_u32 = select(q_u32_0, q_u32_1, l >= 4u); + let qs = select(q_u32 & 0xFFFFu, q_u32 >> 16u, (l & 2u) != 0u); + let h_u32 = select(h_u32_0, h_u32_1, l >= 4u); + let hv = select(h_u32 & 0xFFFFu, h_u32 >> 16u, (l & 2u) != 0u); + + s1 += x_block[l + 0u] * f32(qs & qm0); + s2 += x_block[l + 1u] * f32(qs & qm1); + s3 += select(0.0, x_block[l + 0u], (hv & hm0) == 0u) + + select(0.0, x_block[l + 1u], (hv & hm1) == 0u); + s4 += x_block[l + 8u] * f32(qs & qm2); + s5 += x_block[l + 9u] * f32(qs & qm3); + s6 += select(0.0, x_block[l + 8u], (hv & hm2) == 0u) + + select(0.0, x_block[l + 9u], (hv & hm3) == 0u); + } + + let d1 = d * (s1 + (1.0/256.0) * s2 - s3 * v1); + let d2 = d * (s4 + (1.0/256.0) * s5 - s6 * v2); + acc[row] += (d1 * scale0 + 0.25 * d2 * scale1) / f32(1u << shift); + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q4_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 144 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let il = tid / 4u; + let ir = tid % 4u; + let im = il / 2u; + let in = il % 2u; + let l0 = 4u * (2u * ir + in); + + let y_offset = 64u * im + l0; + let q_offset = 32u * im + l0; + let sc0_byte = 4u + im * 2u; + let sc2_byte = 4u + (im + 2u) * 2u; + let sc4_byte = 4u + (im + 4u) * 2u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 4u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + 32u + i]); + x_block[i + 8u] = f32(src1[x_base + 128u + i]); + x_block[i + 12u] = f32(src1[x_base + 160u + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base + 0u)); + let dmin = f32(load_f16_at_src0(block_byte_base + 2u)); + + let sc0_u32 = load_u32_at_src0_aligned(block_byte_base + sc0_byte); + let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u); + let sc2_u32 = load_u32_at_src0_aligned(block_byte_base + sc2_byte); + let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u); + let sc4_u32 = load_u32_at_src0_aligned(block_byte_base + sc4_byte); + let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u); + + let sc16_0 = sc0 & 0x3F3Fu; + let sc16_1 = sc2 & 0x3F3Fu; + let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u); + let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u); + + let scale0 = f32(sc16_0 & 0xFFu); + let scale1 = f32((sc16_0 >> 8u) & 0xFFu); + let min0 = f32(sc16_1 & 0xFFu); + let min1 = f32((sc16_1 >> 8u) & 0xFFu); + let scale2 = f32(sc16_2 & 0xFFu); + let scale3 = f32((sc16_2 >> 8u) & 0xFFu); + let min2 = f32(sc16_3 & 0xFFu); + let min3 = f32((sc16_3 >> 8u) & 0xFFu); + + let q1_u32 = load_u32_at_src0_aligned(block_byte_base + 16u + q_offset); + let q2_u32 = load_u32_at_src0_aligned(block_byte_base + 80u + q_offset); + + var dot = vec4(0.0, 0.0, 0.0, 0.0); + var sumx = vec4(0.0, 0.0, 0.0, 0.0); + for (var i = 0u; i < 4u; i++) { + let q1b = byte_of(q1_u32, i); + let q2b = byte_of(q2_u32, i); + dot[0] += x_block[i] * f32(q1b & 0x0Fu); + dot[1] += x_block[i + 4u] * f32(q1b >> 4u); + dot[2] += x_block[i + 8u] * f32(q2b & 0x0Fu); + dot[3] += x_block[i + 12u] * f32(q2b >> 4u); + sumx[0] += x_block[i]; + sumx[1] += x_block[i + 4u]; + sumx[2] += x_block[i + 8u]; + sumx[3] += x_block[i + 12u]; + } + + acc[row] += d * (dot[0] * scale0 + dot[1] * scale1 + dot[2] * scale2 + dot[3] * scale3) + - dmin * (sumx[0] * min0 + sumx[1] * min1 + sumx[2] * min2 + sumx[3] * min3); + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q5_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 176 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let il = tid / 4u; + let ir = tid % 4u; + let im = il / 2u; + let in = il % 2u; + let l0 = 4u * (2u * ir + in); + + let y_offset = 64u * im + l0; + let q_offset = 48u + 32u * im + l0; + let qh_offset = 16u + 8u * ir + 4u * in; + let sc0_byte = 4u + im * 2u; + let sc2_byte = 4u + (im + 2u) * 2u; + let sc4_byte = 4u + (im + 4u) * 2u; + + let hm1 = 1u << (2u * im); + let hm2 = hm1 << 1u; + let hm3 = hm1 << 4u; + let hm4 = hm2 << 4u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 4u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + 32u + i]); + x_block[i + 8u] = f32(src1[x_base + 128u + i]); + x_block[i + 12u] = f32(src1[x_base + 160u + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base + 0u)); + let dmin = f32(load_f16_at_src0(block_byte_base + 2u)); + + let sc0_u32 = load_u32_at_src0_aligned(block_byte_base + sc0_byte); + let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u); + let sc2_u32 = load_u32_at_src0_aligned(block_byte_base + sc2_byte); + let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u); + let sc4_u32 = load_u32_at_src0_aligned(block_byte_base + sc4_byte); + let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u); + + let sc16_0 = sc0 & 0x3F3Fu; + let sc16_1 = sc2 & 0x3F3Fu; + let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u); + let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u); + + let f0 = f32(sc16_0 & 0xFFu); + let f1 = f32((sc16_0 >> 8u) & 0xFFu); + let m0 = f32(sc16_1 & 0xFFu); + let m1 = f32((sc16_1 >> 8u) & 0xFFu); + let f4 = f32(sc16_2 & 0xFFu); + let f5 = f32((sc16_2 >> 8u) & 0xFFu); + let m4 = f32(sc16_3 & 0xFFu); + let m5 = f32((sc16_3 >> 8u) & 0xFFu); + + let q1_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset); + let q2_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset + 64u); + let qh_u32 = load_u32_at_src0_aligned(block_byte_base + qh_offset); + + var vals = vec4(0.0, 0.0, 0.0, 0.0); + var sumy = vec4(0.0, 0.0, 0.0, 0.0); + for (var i = 0u; i < 4u; i++) { + let q1b = byte_of(q1_u32, i); + let q2b = byte_of(q2_u32, i); + let qhb = byte_of(qh_u32, i); + + let yl0 = x_block[i]; + let yl8 = x_block[i + 4u]; + let yh0 = x_block[i + 8u]; + let yh8 = x_block[i + 12u]; + + sumy[0] += yl0; + sumy[1] += yl8; + sumy[2] += yh0; + sumy[3] += yh8; + + let q0 = f32((q1b & 0x0Fu) | select(0u, 0x10u, (qhb & hm1) != 0u)); + let q1 = f32((q1b >> 4u) | select(0u, 0x10u, (qhb & hm2) != 0u)); + let q2 = f32((q2b & 0x0Fu) | select(0u, 0x10u, (qhb & hm3) != 0u)); + let q3 = f32((q2b >> 4u) | select(0u, 0x10u, (qhb & hm4) != 0u)); + + vals[0] += yl0 * q0; + vals[1] += yl8 * q1; + vals[2] += yh0 * q2; + vals[3] += yh8 * q3; + } + + acc[row] += d * (f0 * vals[0] + f1 * vals[1] + f4 * vals[2] + f5 * vals[3]) + - dmin * (sumy[0] * m0 + sumy[1] * m1 + + sumy[2] * m4 + sumy[3] * m5); + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q6_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 210 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let ip = tid / 8u; + let il = tid % 8u; + let l0 = 4u * il; + let is = 8u * ip + l0 / 16u; + + let y_offset = 128u * ip + l0; + let q_offset_l = 64u * ip + l0; + let q_offset_h = 32u * ip + l0; + + let num_blocks = params.k / BLOCK_SIZE; + let sc_base_byte = 192u + (is & ~3u); + let sc_byte_pos = is & 3u; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var l = 0u; l < 4u; l++) { + x_block[l] = f32(src1[x_base + l]); + x_block[l + 4u] = f32(src1[x_base + 32u + l]); + x_block[l + 8u] = f32(src1[x_base + 64u + l]); + x_block[l + 12u] = f32(src1[x_base + 96u + l]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base + 208u)); + let ql1_u32 = load_u32_at_src0(block_byte_base + q_offset_l); + let ql2_u32 = load_u32_at_src0(block_byte_base + q_offset_l + 32u); + let qh_u32 = load_u32_at_src0(block_byte_base + 128u + q_offset_h); + let sc_u32_0 = load_u32_at_src0(block_byte_base + sc_base_byte); + let sc_u32_1 = load_u32_at_src0(block_byte_base + sc_base_byte + 4u); + + let sc0 = sbyte_of(sc_u32_0, sc_byte_pos); + let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u); + let sc4 = sbyte_of(sc_u32_1, sc_byte_pos); + let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u); + + var sums = vec4(0.0, 0.0, 0.0, 0.0); + + for (var l = 0u; l < 4u; l++) { + let q1b = byte_of(ql1_u32, l); + let q2b = byte_of(ql2_u32, l); + let qhb = byte_of(qh_u32, l); + + let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32); + let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32); + let dq2 = f32(i32((q1b >> 4u) | (qhb & 0x30u)) - 32); + let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32); + + sums[0] += x_block[l] * dq0; + sums[1] += x_block[l + 4u] * dq1; + sums[2] += x_block[l + 8u] * dq2; + sums[3] += x_block[l + 12u] * dq3; + } + + acc[row] += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) + + sums[2] * f32(sc4) + sums[3] * f32(sc6)); + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ1_S +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 50 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base)); + let qh = load_u32_at_src0(block_byte_base + 34u + sub_blk * 2u) & 0xFFFFu; + let dl = d * f32(2u * ((qh >> 12u) & 7u) + 1u); + let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u); + let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u); + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_byte = get_byte(qs_w, l); + let ig = (qs_byte | (((qh >> (3u * l)) & 7u) << 8u)) * 8u; + let gw = iq1_grid[ig / 16u]; + let bit_base = (ig % 16u) * 2u; + for (var j = 0u; j < 8u; j++) { + let g = (gw >> (bit_base + j * 2u)) & 3u; + let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); + row_sum += dl * (gs + delta) * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ1_M +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 56 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let sc_lo = load_u32_at_src0(block_byte_base + 48u); + let sc_hi = load_u32_at_src0(block_byte_base + 52u); + let sc0 = sc_lo & 0xFFFFu; + let sc1 = (sc_lo >> 16u) & 0xFFFFu; + let sc2 = sc_hi & 0xFFFFu; + let sc3 = (sc_hi >> 16u) & 0xFFFFu; + let d_bits = (sc0 >> 12u) | ((sc1 >> 8u) & 0xF0u) | ((sc2 >> 4u) & 0xF00u) | (sc3 & 0xF000u); + let d = f32(bitcast>(d_bits)[0]); + + let sc_u16 = select(select(sc2, sc3, sub_blk >= 6u), + select(sc0, sc1, sub_blk >= 2u), + sub_blk < 4u); + + let qs_w = load_u32_at_src0(block_byte_base + sub_blk * 4u); + let qh = load_u32_at_src0(block_byte_base + 32u + sub_blk * 2u) & 0xFFFFu; + let qh_lo = qh & 0xFFu; + let qh_hi = (qh >> 8u) & 0xFFu; + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let bit_off = 6u * (sub_blk % 2u) + 3u * (l / 2u); + let sub_scale = (sc_u16 >> bit_off) & 0x7u; + let dl = d * f32(2u * sub_scale + 1u); + let qh_byte = select(qh_lo, qh_hi, l >= 2u); + let ll2 = l % 2u; + let grid_idx = get_byte(qs_w, l) | (((qh_byte >> (4u * ll2)) & 7u) << 8u); + let delta = select(IQ1_DELTA, -IQ1_DELTA, ((qh_byte >> (3u + 4u * ll2)) & 1u) != 0u); + let ig = grid_idx * 8u; + let gw = iq1_grid[ig / 16u]; + let bit_base = (ig % 16u) * 2u; + for (var j = 0u; j < 8u; j++) { + let g = (gw >> (bit_base + j * 2u)) & 3u; + let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); + row_sum += dl * (gs + delta) * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ2_XXS +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 66 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let aux_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); + let aux_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); + let ls = aux_hi >> 28u; + let db = d * (0.5 + f32(ls)) * 0.25; + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let grid_idx = (aux_lo >> (8u * l)) & 0xFFu; + let signs_idx = (aux_hi >> (7u * l)) & 0x7Fu; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let gw_lo = iq2xxs_grid[grid_idx * 2u]; + let gw_hi = iq2xxs_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ2_XS +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 74 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); + let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); + let scales_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); + let scales_byte = get_byte(scales_word, sub_blk % 4u); + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let half2 = (l % 2u) * 16u; + let qs_val = (qs_word >> half2) & 0xFFFFu; + let grid_idx = qs_val & 0x1FFu; + let signs_idx = (qs_val >> 9u) & 0x7Fu; + let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; + let db = d * (0.5 + f32(sub_scale)) * 0.25; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let gw_lo = iq2xs_grid[grid_idx * 2u]; + let gw_hi = iq2xs_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ2_S +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 82 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u); + let sg_w = load_u32_at_src0(block_byte_base + 34u + sub_blk * 4u); + let qh_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); + let qh_byte = get_byte(qh_word, sub_blk % 4u); + let sc_word = load_u32_at_src0(block_byte_base + 74u + (sub_blk / 4u) * 4u); + let scales_byte = get_byte(sc_word, sub_blk % 4u); + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_byte = get_byte(qs_w, l); + let sign_byte = get_byte(sg_w, l); + let grid_idx = qs_byte | (((qh_byte >> (2u * l)) & 3u) << 8u); + let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; + let db = d * (0.5 + f32(sub_scale)) * 0.25; + let gw_lo = iq2s_grid[grid_idx * 2u]; + let gw_hi = iq2s_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ3_XXS +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 98 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); + let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); + let aux = load_u32_at_src0(block_byte_base + 66u + sub_blk * 4u); + let ls = aux >> 28u; + let db = d * (0.5 + f32(ls)) * 0.5; + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let byte_pos = (l % 2u) * 2u; + let grid_idx_0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; + let grid_idx_1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; + let signs_idx = (aux >> (7u * l)) & 0x7Fu; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let grid1 = iq3xxs_grid[grid_idx_0]; + let grid2 = iq3xxs_grid[grid_idx_1]; + for (var j = 0u; j < 4u; j++) { + let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); + let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); + let s1 = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + let s2 = select(1.0, -1.0, ((signs >> (j + 4u)) & 1u) != 0u); + row_sum += db * b1 * s1 * x_block[ll * 8u + j]; + row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ3_S +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 110 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); + let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); + let qh_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); + let qh_byte = get_byte(qh_word, sub_blk % 4u); + let sg_w = load_u32_at_src0(block_byte_base + 74u + sub_blk * 4u); + let sc_word = load_u32_at_src0(block_byte_base + 106u); + let scales_byte = get_byte(sc_word, sub_blk / 2u); + let sub_scale = (scales_byte >> (4u * (sub_blk % 2u))) & 0xFu; + let db = d * (1.0 + 2.0 * f32(sub_scale)); + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let byte_pos = (l % 2u) * 2u; + let qs0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; + let qs1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; + let grid_idx_1 = qs0 | (((qh_byte >> (2u * l)) & 1u) << 8u); + let grid_idx_2 = qs1 | (((qh_byte >> (2u * l + 1u)) & 1u) << 8u); + let sign_byte = get_byte(sg_w, l); + let grid1 = iq3s_grid[grid_idx_1]; + let grid2 = iq3s_grid[grid_idx_2]; + for (var j = 0u; j < 4u; j++) { + let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); + let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); + let s1 = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); + let s2 = select(1.0, -1.0, ((sign_byte >> (j + 4u)) & 1u) != 0u); + row_sum += db * b1 * s1 * x_block[ll * 8u + j]; + row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ4_NL +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 18 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4u; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + i + 16u]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + var row_sum = 0.0; + + let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(kvalues_iq4nl[q_byte & 0xFu]) * d; + let q_hi = f32(kvalues_iq4nl[(q_byte >> 4u) & 0xFu]) * d; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ4_XS +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 136 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let y_offset = sub_blk * 32u + half * 16u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let scales_h = load_u16_at_src0(block_byte_base + 2u); + let scales_l_word = load_u32_at_src0(block_byte_base + 4u); + let sl_byte = get_byte(scales_l_word, sub_blk / 2u); + let sl = (sl_byte >> (4u * (sub_blk % 2u))) & 0xFu; + let sh_bits = (scales_h >> (2u * sub_blk)) & 3u; + let ls = i32(sl | (sh_bits << 4u)); + let dl = d * f32(ls - 32); + + let qs_byte_off = 8u + sub_blk * 16u; + let q_w0 = load_u32_at_src0(block_byte_base + qs_byte_off); + let q_w1 = load_u32_at_src0(block_byte_base + qs_byte_off + 4u); + let q_w2 = load_u32_at_src0(block_byte_base + qs_byte_off + 8u); + let q_w3 = load_u32_at_src0(block_byte_base + qs_byte_off + 12u); + + var row_sum = 0.0; + for (var i = 0u; i < 16u; i++) { + let q_word = select( + select(q_w0, q_w1, i >= 4u), + select(q_w2, q_w3, i >= 12u), + i >= 8u); + let q_byte = get_byte(q_word, i % 4u); + let nib = select(q_byte & 0xFu, (q_byte >> 4u) & 0xFu, half == 1u); + row_sum += f32(kvalues_iq4nl[nib]) * dl * x_block[i]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl new file mode 100644 index 00000000000..fd20a4e54c9 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl @@ -0,0 +1,152 @@ +#ifdef OVERLAP + +@group(0) @binding(0) +var rn_src: array; + +@group(0) @binding(1) +var mul_src: array; + +@group(0) @binding(2) +var params: Params; + +fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) { + mul_src[dst_offset] = scale * rn_src[rn_src_offset] * mul_src[mul_src_offset]; +} + +#elif INPLACE + +@group(0) @binding(0) +var rn_src: array; + +@group(0) @binding(1) +var mul_src: array; + +@group(0) @binding(2) +var params: Params; + +fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) { + rn_src[dst_offset] = scale * rn_src[rn_src_offset] * mul_src[mul_src_offset]; +} + +#elif SRC_OVERLAP + +@group(0) @binding(0) +var merged_src: array; + +@group(0) @binding(1) +var dst: array; + +@group(0) @binding(2) +var params: Params; + +fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) { + dst[dst_offset] = scale * merged_src[rn_src_offset] * merged_src[mul_src_offset]; +} + +#else + +@group(0) @binding(0) +var rn_src: array; + +@group(0) @binding(1) +var mul_src: array; + +@group(0) @binding(2) +var dst: array; + +@group(0) @binding(3) +var params: Params; + +fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) { + dst[dst_offset] = scale * rn_src[rn_src_offset] * mul_src[mul_src_offset]; +} + +#endif + +struct Params { + offset_rn_src: u32, + offset_mul_src: u32, + offset_dst: u32, + + stride_rn_src1: u32, + stride_rn_src2: u32, + stride_rn_src3: u32, + + stride_mul_src1: u32, + stride_mul_src2: u32, + stride_mul_src3: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + mul_src_ne0: u32, + mul_src_ne1: u32, + mul_src_ne2: u32, + mul_src_ne3: u32, + + ne0: u32, + ne1: u32, + ne2: u32, + ne3: u32, + + eps: f32 +}; + +var scratch: array; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wid: vec3, + @builtin(local_invocation_id) lid: vec3) { + + // one thread per row + var i = wid.x; + let i3 = i / (params.ne2 * params.ne1); + i = i % (params.ne2 * params.ne1); + let i2 = i / params.ne1; + let i1 = i % params.ne1; + let i_rn_src_row = params.offset_rn_src + i3 * params.stride_rn_src3 + i2 * params.stride_rn_src2 + i1 * params.stride_rn_src1; + let i_mul_src_row = params.offset_mul_src + (i3 % params.mul_src_ne3) * params.stride_mul_src3 + (i2 % params.mul_src_ne2) * params.stride_mul_src2 + (i1 % params.mul_src_ne1) * params.stride_mul_src1; + let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; + + let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE; + + var sum = 0.0f; + var col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } +#ifdef SRC_OVERLAP + sum += pow(merged_src[i_rn_src_row + col], 2.0); +#else + sum += pow(rn_src[i_rn_src_row + col], 2.0); +#endif + col += WG_SIZE; + } + + scratch[lid.x] = sum; + + workgroupBarrier(); + + var offset: u32 = WG_SIZE / 2; + while (offset > 0) { + if (lid.x < offset) { + scratch[lid.x] += scratch[lid.x + offset]; + } + offset = offset / 2; + workgroupBarrier(); + } + sum = scratch[0]; + + let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps); + + col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } + update(i_rn_src_row + col, i_dst_row + col, scale, i_mul_src_row + col % params.mul_src_ne0); + col += WG_SIZE; + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rope.wgsl similarity index 73% rename from ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/rope.wgsl index 84dc8dbff61..1c874e14240 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rope.wgsl @@ -1,138 +1,12 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"] - }, - { - "SHADER_SUFFIX": "f32_inplace", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"] - }, - { - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"] - }, - { - "SHADER_SUFFIX": "f16_inplace", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"] - }, - { - "SHADER_SUFFIX": "f32_ff", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"] - }, - { - "SHADER_SUFFIX": "f32_ff_inplace", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"] - }, - { - "SHADER_SUFFIX": "f16_ff", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"] - }, - { - "SHADER_SUFFIX": "f16_ff_inplace", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(ROTATE) -fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { - dst[i_dst0] = {{TYPE}}(out0); - dst[i_dst1] = {{TYPE}}(out1); -} -#enddecl(ROTATE) - -#decl(ROTATE_INPLACE) -fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { - src0[i_dst0] = {{TYPE}}(out0); - src0[i_dst1] = {{TYPE}}(out1); -} -#enddecl(ROTATE_INPLACE) - -#decl(NO_FF_FUNC) -fn freq_factor(i: u32) -> f32 { - return 1.0f; -} -#enddecl(NO_FF_FUNC) - -#decl(FF_FUNC) -fn freq_factor(i: u32) -> f32 { - return src2[params.offset_src2 + i/2]; -} -#enddecl(FF_FUNC) - -#decl(NO_FF_BINDINGS) - -@group(0) @binding(2) -var dst: array<{{TYPE}}>; - -@group(0) @binding(3) -var params: Params; - -#enddecl(NO_FF_BINDINGS) - -#decl(NO_FF_BINDINGS_INPLACE) - -@group(0) @binding(2) -var params: Params; - -#enddecl(NO_FF_BINDINGS_INPLACE) - -#decl(FF_BINDINGS) - -@group(0) @binding(2) -var src2: array; - -@group(0) @binding(3) -var dst: array<{{TYPE}}>; - -@group(0) @binding(4) -var params: Params; - -#enddecl(FF_BINDINGS) - -#decl(FF_BINDINGS_INPLACE) - -@group(0) @binding(2) -var src2: array; - -@group(0) @binding(3) -var params: Params; - -#enddecl(FF_BINDINGS_INPLACE) - -#end(DECLS) - -#define(SHADER) - enable f16; +#ifdef TYPE_F32 +#define DataType f32 +#endif +#ifdef TYPE_F16 +#define DataType f16 +#endif + struct Params { offset_src0: u32, offset_src1: u32, @@ -168,12 +42,69 @@ struct Params { }; @group(0) @binding(0) -var src0: array<{{TYPE}}>; - +var src0: array; @group(0) @binding(1) var src1: array; -DECLS +#ifdef INPLACE + +#ifdef FF_FUNC + +@group(0) @binding(2) +var src2: array; + +@group(0) @binding(3) +var params: Params; + +#else + +@group(0) @binding(2) +var params: Params; + +#endif + +#else + +#ifdef FF_FUNC +@group(0) @binding(2) +var src2: array; + +@group(0) @binding(3) +var dst: array; + +@group(0) @binding(4) +var params: Params; + +#else +@group(0) @binding(2) +var dst: array; + +@group(0) @binding(3) +var params: Params; +#endif +#endif + +#ifdef FF_FUNC +fn freq_factor(i: u32) -> f32 { + return src2[params.offset_src2 + i/2]; +} + +#else +fn freq_factor(i: u32) -> f32 { + return 1.0f; +} +#endif +#ifdef INPLACE +fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { + src0[i_dst0] = DataType(out0); + src0[i_dst1] = DataType(out1); +} +#else +fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { + dst[i_dst0] = DataType(out0); + dst[i_dst1] = DataType(out1); +} +#endif fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 { let y = (f32(i / 2) - low) / max(0.001f, high - low); @@ -184,7 +115,7 @@ fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 { // TODO: check performance of instantiating once on the CPU and passed as buffer, since it's repeated per-row fn rope_yarn(theta_extrap: f32, i: u32) -> vec2 { var mscale = params.attn_factor; - var theta = params.freq_scale * theta_extrap; + var theta = params.freq_scale * theta_extrap; if (params.ext_factor != 0.0f) { let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor; theta = theta * (1 - ramp_mix) + theta_extrap * ramp_mix; @@ -211,10 +142,9 @@ fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 { } } -override wg_size: u32; -@compute @workgroup_size(wg_size) +@compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3) { - // two elements per thread + // two elements per n_threads if (gid.x >= params.n_threads) { return; } @@ -290,6 +220,5 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let x0 = f32(src0[i_src]); let x1 = f32(src0[i_src + pair_offset(is_neox, is_mrope, is_vision)]); rotate(i_dst, i_dst + pair_offset(is_neox, is_mrope, is_vision), x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x); -} -#end(SHADER) +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl similarity index 79% rename from ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl index 712b921f1ab..bd8d32bded7 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl @@ -1,47 +1,21 @@ -#define(VARIANTS) - -[ - { - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_SUFFIX": "inplace", - "DECLS": ["INPLACE"] - }, -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(NOT_INPLACE) - +#ifdef INPLACE fn update(src_offset: u32, dst_offset: u32, scale: f32) { - dst[dst_offset] = scale * src[src_offset]; + src[dst_offset] = scale * src[src_offset]; } @group(0) @binding(1) -var dst: array; - -@group(0) @binding(2) var params: Params; - -#enddecl(NOT_INPLACE) - -#decl(INPLACE) - +#else fn update(src_offset: u32, dst_offset: u32, scale: f32) { - src[dst_offset] = scale * src[src_offset]; + dst[dst_offset] = scale * src[src_offset]; } @group(0) @binding(1) -var params: Params; - -#enddecl(INPLACE) - -#end(DECLS) +var dst: array; -#define(SHADER) +@group(0) @binding(2) +var params: Params; +#endif struct Params { offset_src: u32, // in elements @@ -68,12 +42,9 @@ struct Params { @group(0) @binding(0) var src: array; -DECLS +var scratch: array; -override wg_size: u32; -var scratch: array; - -@compute @workgroup_size(wg_size) +@compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3) { @@ -86,7 +57,7 @@ fn main(@builtin(workgroup_id) wid: vec3, let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1; let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; - let elems = (params.ne0 + wg_size - 1) / wg_size; + let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE; var sum = 0.0f; var col = lid.x; @@ -95,12 +66,12 @@ fn main(@builtin(workgroup_id) wid: vec3, break; } sum += pow(src[i_src_row + col], 2.0); - col += wg_size; + col += WG_SIZE; } scratch[lid.x] = sum; workgroupBarrier(); - var offset = wg_size / 2; + var offset: u32 = WG_SIZE / 2; while (offset > 0) { if (lid.x < offset) { scratch[lid.x] += scratch[lid.x + offset]; @@ -110,14 +81,18 @@ fn main(@builtin(workgroup_id) wid: vec3, } sum = scratch[0]; +#ifdef RMS_NORM let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps); +#elif defined(L2_NORM) + let scale = 1.0/max(sqrt(sum), params.eps); +#endif + col = lid.x; for (var j: u32 = 0; j < elems; j++) { if (col >= params.ne0) { break; } update(i_src_row + col, i_dst_row + col, scale); - col += wg_size; + col += WG_SIZE; } } -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl new file mode 100644 index 00000000000..0a7ae9bdb2c --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl @@ -0,0 +1,109 @@ +#ifdef TYPE_I32 +#define TYPE i32 +#else +#define TYPE f32 +#endif + +#ifndef INPLACE +@group(0) @binding(0) +var src0: array; +#define SRC1_BINDING 1 +#else +#define SRC1_BINDING 0 +#endif + +#define DST_BINDING SRC1_BINDING + 1 +#define PARAMS_BINDING SRC1_BINDING + 2 + +@group(0) @binding(SRC1_BINDING) +var src1: array; + +@group(0) @binding(DST_BINDING) +var dst: array; + +struct Params { + ne: u32, + offset_src0: u32, + offset_src1: u32, + offset_view: u32, + + stride_src10: u32, + stride_src11: u32, + stride_src12: u32, + stride_src13: u32, + + stride_dst10: u32, + stride_dst11: u32, + stride_dst12: u32, + stride_dst13: u32, + + src1_ne0: u32, + src1_ne1: u32, + src1_ne2: u32, + src1_ne3: u32, +}; + +@group(0) @binding(PARAMS_BINDING) +var params: Params; + +fn decode_src1_coords(idx: u32) -> vec4 { + var i = idx; + let plane = params.src1_ne2 * params.src1_ne1 * params.src1_ne0; + let i3 = i / plane; + i = i % plane; + let row = params.src1_ne1 * params.src1_ne0; + let i2 = i / row; + i = i % row; + let i1 = i / params.src1_ne0; + let i0 = i % params.src1_ne0; + return vec4(i0, i1, i2, i3); +} + +fn decode_view_coords(rel: u32) -> vec4 { + let i3 = rel / params.stride_dst13; + let rem3 = rel % params.stride_dst13; + let i2 = rem3 / params.stride_dst12; + let rem2 = rem3 % params.stride_dst12; + let i1 = rem2 / params.stride_dst11; + let i0 = rem2 % params.stride_dst11; + return vec4(i0, i1, i2, i3); +} + +fn view_rel_from_coords(coords: vec4) -> u32 { + return coords.x * params.stride_dst10 + coords.y * params.stride_dst11 + + coords.z * params.stride_dst12 + coords.w * params.stride_dst13; +} + +fn src1_idx_from_coords(coords: vec4) -> u32 { + return coords.x * params.stride_src10 + coords.y * params.stride_src11 + + coords.z * params.stride_src12 + coords.w * params.stride_src13; +} + +fn in_set_view(rel: u32, coords: vec4) -> bool { + return view_rel_from_coords(coords) == rel; +} + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.ne) { + return; + } + +#ifdef INPLACE + let coords = decode_src1_coords(gid.x); + + let src1_idx = params.offset_src1 + src1_idx_from_coords(coords); + let dst_idx = params.offset_view + view_rel_from_coords(coords); + + dst[dst_idx] = src1[src1_idx]; +#else + let rel = select(params.ne, gid.x - params.offset_view, gid.x >= params.offset_view); + let coords = decode_view_coords(rel); + + if (rel < params.stride_dst13 * params.src1_ne3 && in_set_view(rel, coords)) { + dst[gid.x] = src1[params.offset_src1 + src1_idx_from_coords(coords)]; + } else { + dst[gid.x] = src0[params.offset_src0 + gid.x]; + } +#endif +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.wgsl similarity index 59% rename from ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/soft_max.wgsl index c74dc4cc923..10edf136048 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.wgsl @@ -1,262 +1,162 @@ -#define(VARIANTS) -[ - { - "SHADER_NAME": "soft_max_f32", - "DECLS": ["BASE_BINDINGS", "NOT_INPLACE", "NO_MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_inplace", - "DECLS": ["BASE_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_sink", - "DECLS": ["SINK_BINDINGS", "NOT_INPLACE", "NO_MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_sink_inplace", - "DECLS": ["SINK_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f32", - "REPLS": { - "MASK_TYPE" : "f32", - }, - "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f32_inplace", - "REPLS": { - "MASK_TYPE" : "f32", - }, - "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f16", - "REPLS": { - "MASK_TYPE" : "f16", - }, - "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f16_inplace", - "REPLS": { - "MASK_TYPE" : "f16", - }, - "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f32_sink", - "REPLS": { - "MASK_TYPE" : "f32", - }, - "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f32_sink_inplace", - "REPLS": { - "MASK_TYPE" : "f32", - }, - "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f16_sink", - "REPLS": { - "MASK_TYPE" : "f16", - }, - "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f16_sink_inplace", - "REPLS": { - "MASK_TYPE" : "f16", - }, - "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"] - } -] -#end(VARIANTS) - -#define(DECLS) - -#decl(BASE_BINDINGS) -@group(0) @binding(1) -var dst: array; +enable f16; -@group(0) @binding(2) -var params: Params; -#enddecl(BASE_BINDINGS) +#ifdef MASK_F32 +#define MaskType f32 +#endif +#ifdef MASK_F16 +#define MaskType f16 +#endif -#decl(BASE_BINDINGS_INPLACE) -@group(0) @binding(1) -var params: Params; -#enddecl(BASE_BINDINGS_INPLACE) +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_sinks: u32, + offset_dst: u32, + + // Strides (in elements) + stride_src01: u32, + stride_src02: u32, + stride_src03: u32, + + stride_src11: u32, + stride_src12: u32, + stride_src13: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // shape of src0/dst + ne: u32, + ne0: u32, + ne1: u32, + ne2: u32, -#decl(SINK_BINDINGS) + // shape of src1 + ne12: u32, + ne13: u32, + + scale: f32, + max_bias: f32, + n_head_log2: f32, + m0: f32, + m1: f32, +}; + +@group(0) @binding(0) +var src: array; + +#ifdef HAS_MASK +#ifdef HAS_SINK @group(0) @binding(1) +var mask: array; +@group(0) @binding(2) var sinks: array; -@group(0) @binding(2) -var dst: array; +#ifdef INPLACE +@group(0) @binding(3) +var params: Params; +#else @group(0) @binding(3) +var dst: array; +@group(0) @binding(4) var params: Params; -#enddecl(SINK_BINDINGS) +#endif -#decl(SINK_BINDINGS_INPLACE) +#else @group(0) @binding(1) -var sinks: array; +var mask: array; +#ifdef INPLACE @group(0) @binding(2) var params: Params; -#enddecl(SINK_BINDINGS_INPLACE) - -#decl(MASK_BINDINGS) -@group(0) @binding(1) -var mask: array<{{MASK_TYPE}}>; +#else @group(0) @binding(2) var dst: array; - @group(0) @binding(3) var params: Params; -#enddecl(MASK_BINDINGS) +#endif +#endif -#decl(MASK_BINDINGS_INPLACE) +#else +#ifdef HAS_SINK @group(0) @binding(1) -var mask: array<{{MASK_TYPE}}>; +var sinks: array; +#ifdef INPLACE @group(0) @binding(2) var params: Params; -#enddecl(MASK_BINDINGS_INPLACE) - -#decl(MASK_SINK_BINDINGS) -@group(0) @binding(1) -var mask: array<{{MASK_TYPE}}>; +#else @group(0) @binding(2) -var sinks: array; - -@group(0) @binding(3) var dst: array; - -@group(0) @binding(4) +@group(0) @binding(3) var params: Params; -#enddecl(MASK_SINK_BINDINGS) +#endif -#decl(MASK_SINK_BINDINGS_INPLACE) +#else +#ifdef INPLACE @group(0) @binding(1) -var mask: array<{{MASK_TYPE}}>; - +var params: Params; +#else +@group(0) @binding(1) +var dst: array; @group(0) @binding(2) -var sinks: array; - -@group(0) @binding(3) var params: Params; -#enddecl(MASK_SINK_BINDINGS_INPLACE) +#endif +#endif +#endif -#decl(NOT_INPLACE) +#ifdef INPLACE fn inter_value(i: u32) -> f32 { - return dst[i]; + return src[i]; } - fn update(i: u32, val: f32) { - dst[i] = val; + src[i] = val; } -#enddecl(NOT_INPLACE) -#decl(INPLACE) +#else fn inter_value(i: u32) -> f32 { - return src[i]; + return dst[i]; } - fn update(i: u32, val: f32) { - src[i] = val; + dst[i] = val; } -#enddecl(INPLACE) +#endif -#decl(NO_MASK) +#ifdef HAS_MASK fn mask_val(i: u32) -> f32 { - return 0.0; + return f32(mask[i]); } -#enddecl(NO_MASK) -#decl(MASK) +#else fn mask_val(i: u32) -> f32 { - return f32(mask[i]); + return 0.0; } -#enddecl(MASK) +#endif -#decl(NO_SINK) +#ifdef HAS_SINK fn lower_max_bound(i2: u32) -> f32 { - return -1e30; + return sinks[params.offset_sinks + i2]; } - fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 { - return val; + return val + exp(sinks[params.offset_sinks + i2] - max_val); } -#enddecl(NO_SINK) - -#decl(SINK) +#else fn lower_max_bound(i2: u32) -> f32 { - return sinks[params.offset_sinks + i2]; + return -1e30; } - fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 { - return val + exp(sinks[params.offset_sinks + i2] - max_val); + return val; } -#enddecl(SINK) - -#end(DECLS) - -#define(SHADER) -enable f16; - -struct Params { - offset_src0: u32, - offset_src1: u32, - offset_sinks: u32, - offset_dst: u32, - - // Strides (in elements) - stride_src01: u32, - stride_src02: u32, - stride_src03: u32, - - stride_src11: u32, - stride_src12: u32, - stride_src13: u32, - - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // shape of src0/dst - ne: u32, - ne0: u32, - ne1: u32, - ne2: u32, - - // shape of src1 - ne12: u32, - ne13: u32, - - scale: f32, - max_bias: f32, - n_head_log2: f32, - m0: f32, - m1: f32, -}; - -@group(0) @binding(0) -var src: array; - -DECLS +#endif const CACHE_SIZE: u32 = 16; +var scratch: array; -override wg_size: u32; -var scratch: array; - -@compute @workgroup_size(wg_size) +@compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3) { @@ -268,7 +168,7 @@ fn main(@builtin(workgroup_id) wid: vec3, let i_src0_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01; let i_src1_row = params.offset_src1 + (i3 % params.ne13) * params.stride_src13 + (i2 % params.ne12) * params.stride_src12 + i1 * params.stride_src11; let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; - let elems = (params.ne0 + wg_size - 1) / wg_size; + let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE; let head = f32(i2); let slope = select(1, select(pow(params.m1, 2 * (head - params.n_head_log2) + 1), pow(params.m0, head + 1), head < params.n_head_log2), params.max_bias > 0); @@ -286,12 +186,12 @@ fn main(@builtin(workgroup_id) wid: vec3, if (col < CACHE_SIZE) { cache[col] = val; } - col += wg_size; + col += WG_SIZE; } scratch[lid.x] = max_val; workgroupBarrier(); - var offset = wg_size / 2; + var offset: u32 = WG_SIZE / 2; while (offset > 0) { if (lid.x < offset) { scratch[lid.x] = max(scratch[lid.x], scratch[lid.x + offset]); @@ -317,12 +217,12 @@ fn main(@builtin(workgroup_id) wid: vec3, } else { update(i_dst_row + col, ex); } - col += wg_size; + col += WG_SIZE; } scratch[lid.x] = sum; workgroupBarrier(); - offset = wg_size / 2; + offset = WG_SIZE / 2; while (offset > 0) { if (lid.x < offset) { scratch[lid.x] += scratch[lid.x + offset]; @@ -339,7 +239,7 @@ fn main(@builtin(workgroup_id) wid: vec3, break; } update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip); - col += wg_size; + col += WG_SIZE; } } -#end(SHADER) + diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl new file mode 100644 index 00000000000..9d5d902cb1e --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl @@ -0,0 +1,121 @@ +@group(0) @binding(0) +var src0: array; + +@group(0) @binding(1) +var src1: array; + +@group(0) @binding(2) +var dst: array; + +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + stride_src00: u32, + stride_src01: u32, + stride_src02: u32, + stride_src03: u32, + + stride_src10: u32, + stride_src11: u32, + stride_src12: u32, + stride_src13: u32, + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + k: u32, + ne2: u32, + ne3: u32, +}; + +@group(0) @binding(3) +var params: Params; + +var shA: array; +var shB: array; + +fn src0_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 { + return params.offset_src0 + + col * params.stride_src00 + + row * params.stride_src01 + + i2 * params.stride_src02 + + i3 * params.stride_src03; +} + +fn src1_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 { + return params.offset_src1 + + col * params.stride_src10 + + row * params.stride_src11 + + i2 * params.stride_src12 + + i3 * params.stride_src13; +} + +fn dst_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 { + return params.offset_dst + + col * params.stride_dst0 + + row * params.stride_dst1 + + i2 * params.stride_dst2 + + i3 * params.stride_dst3; +} + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(workgroup_id) workgroup_id: vec3, + @builtin(local_invocation_id) local_id: vec3 +) { + let batch = workgroup_id.y; + let col = workgroup_id.x * WG_SIZE + local_id.x; + let i3 = batch / params.ne2; + let i2 = batch % params.ne2; + let active_lane = local_id.x < K_TILE; + let active_col = active_lane && col < params.k; + + var X: array; + + for (var row_base = 0u; row_base < N; row_base += BATCH_N) { + let cur_n = min(BATCH_N, N - row_base); + + for (var i = local_id.x; i < cur_n * N; i += WG_SIZE) { + let tile_row = i / N; + let tile_col = i % N; + shA[i] = src0[src0_idx(row_base + tile_row, tile_col, i2, i3)]; + } + + for (var i = local_id.x; i < cur_n * K_TILE; i += WG_SIZE) { + let tile_row = i / K_TILE; + let tile_col = i % K_TILE; + let global_col = workgroup_id.x * WG_SIZE + tile_col; + let sh_idx = tile_row * K_TILE + tile_col; + + if (global_col < params.k) { + shB[sh_idx] = src1[src1_idx(row_base + tile_row, global_col, i2, i3)]; + } else { + shB[sh_idx] = 0.0; + } + } + + workgroupBarrier(); + + if (active_col) { + for (var row_offset = 0u; row_offset < cur_n; row_offset++) { + let r = row_base + row_offset; + var b = shB[row_offset * K_TILE + local_id.x]; + let a_row = row_offset * N; + + for (var t = 0u; t < r; t++) { + b -= shA[a_row + t] * X[t]; + } + + let x = b / shA[a_row + r]; + X[r] = x; + dst[dst_idx(r, col, i2, i3)] = x; + } + } + + workgroupBarrier(); + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl new file mode 100644 index 00000000000..11511305ed8 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl @@ -0,0 +1,65 @@ +@group(0) @binding(0) +var src0: array; + +@group(0) @binding(1) +var src1: array; + +@group(0) @binding(2) +var dst: array; + +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + stride_src01: u32, + stride_src02: u32, + stride_src11: u32, + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + + nc: u32, + nr: u32, + n_t: u32, + n_s: u32, + token_tiles: u32, +}; + +@group(0) @binding(3) +var params: Params; + +@compute @workgroup_size(BLOCK_SIZE, TOKENS_PER_WG) +fn main(@builtin(global_invocation_id) gid: vec3) { + let i1 = gid.x; + let tile_y = gid.y / TOKENS_PER_WG; + let local_token = gid.y % TOKENS_PER_WG; + let i3 = tile_y / params.token_tiles; + let token_tile = tile_y % params.token_tiles; + let i2 = token_tile * TOKENS_PER_WG + local_token; + + if (i1 >= params.nr || i2 >= params.n_t || i3 >= params.n_s) { + return; + } + + let src0_base = params.offset_src0 + i3 * params.stride_src02 + i2 + i1 * params.stride_src01; + let src1_base = params.offset_src1 + i1 * params.stride_src11; + + var sum = 0.0; + +#ifdef VECTORIZED + sum = + src0[src0_base + 0u] * src1[src1_base + 0u] + + src0[src0_base + 1u] * src1[src1_base + 1u] + + src0[src0_base + 2u] * src1[src1_base + 2u] + + src0[src0_base + 3u] * src1[src1_base + 3u]; +#else + for (var i0 = 0u; i0 < params.nc; i0++) { + sum += src0[src0_base + i0] * src1[src1_base + i0]; + } +#endif + + let dst_idx = params.offset_dst + i3 * params.stride_dst2 + i2 * params.stride_dst1 + i1 * params.stride_dst0; + dst[dst_idx] = sum; +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl new file mode 100644 index 00000000000..05761dec353 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl @@ -0,0 +1,193 @@ +#ifdef USE_SUBGROUP_REDUCTION +enable subgroups; +#endif + +struct Params { + offset_s: u32, + offset_x: u32, + offset_dt: u32, + offset_A: u32, + offset_B: u32, + offset_C: u32, + offset_ids: u32, + offset_dst: u32, + + stride_s1: u32, + stride_s2: u32, + stride_s3: u32, + + stride_x1: u32, + stride_x2: u32, + stride_x3: u32, + + stride_dt1: u32, + stride_dt2: u32, + + a_ne0: u32, + stride_A1: u32, + + stride_B1: u32, + stride_B2: u32, + stride_B3: u32, + + stride_C1: u32, + stride_C2: u32, + stride_C3: u32, + + d_state: u32, + d_inner: u32, + n_head: u32, + n_group: u32, + n_seq_tokens: u32, + n_seqs: u32, + + y_elems: u32, +}; + +@group(0) @binding(0) var s_in: array; +#ifdef XBC_OVERLAP +@group(0) @binding(1) var x_B_C_merged: array; +@group(0) @binding(2) var dt: array; +@group(0) @binding(3) var A: array; +@group(0) @binding(4) var ids: array; +@group(0) @binding(5) var dst: array; +@group(0) @binding(6) var params: Params; +#else +@group(0) @binding(1) var x: array; +@group(0) @binding(2) var dt: array; +@group(0) @binding(3) var A: array; +@group(0) @binding(4) var B: array; +@group(0) @binding(5) var C: array; +@group(0) @binding(6) var ids: array; +@group(0) @binding(7) var dst: array; +@group(0) @binding(8) var params: Params; +#endif + +var shared_x_dt: array; +var shared_dtsp: array; +var shared_reduce: array; + +fn reduce_base(token_in_tile: u32) -> u32 { + return token_in_tile * WG_SIZE; +} + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) wg_id: vec3, + @builtin(num_workgroups) num_wg: vec3 +#ifdef USE_SUBGROUP_REDUCTION + , @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_invocation_id) subgroup_invocation_id: u32, + @builtin(num_subgroups) num_subgroups: u32 +#endif +) { + let tid = local_id.x; + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + + let i1 = wg_linear % params.d_inner; + let head_seq = wg_linear / params.d_inner; + let ir = head_seq % params.n_head; + let i3 = head_seq / params.n_head; + + let state_slot = u32(ids[params.offset_ids + i3]); + let g = ir / (params.n_head / params.n_group); + + let s_idx = params.offset_s + tid + i1 * params.stride_s1 + ir * params.stride_s2 + state_slot * params.stride_s3; + var s_prev = s_in[s_idx]; + + let A0 = A[params.offset_A + (tid % params.a_ne0) + ir * params.stride_A1]; + + for (var token_base = 0u; token_base < params.n_seq_tokens; token_base += TOKENS_PER_TILE) { + if (tid < TOKENS_PER_TILE) { + let token = token_base + tid; + if (token < params.n_seq_tokens) { + let x_idx = params.offset_x + i1 + ir * params.stride_x1 + token * params.stride_x2 + i3 * params.stride_x3; + let dt_idx = params.offset_dt + ir + token * params.stride_dt1 + i3 * params.stride_dt2; + let dt0 = dt[dt_idx]; + let dtsp = select(log(1.0 + exp(dt0)), dt0, dt0 > 20.0); + shared_dtsp[tid] = dtsp; +#ifdef XBC_OVERLAP + shared_x_dt[tid] = x_B_C_merged[x_idx] * dtsp; +#else + shared_x_dt[tid] = x[x_idx] * dtsp; +#endif + } + } + + workgroupBarrier(); + + for (var token_in_tile = 0u; token_in_tile < TOKENS_PER_TILE; token_in_tile++) { + let token = token_base + token_in_tile; + if (token >= params.n_seq_tokens) { + break; + } + + let x_dt = shared_x_dt[token_in_tile]; + let dA = exp(shared_dtsp[token_in_tile] * A0); + let reduce_idx = reduce_base(token_in_tile) + tid; + + let b_idx = params.offset_B + tid + g * params.stride_B1 + token * params.stride_B2 + i3 * params.stride_B3; + let c_idx = params.offset_C + tid + g * params.stride_C1 + token * params.stride_C2 + i3 * params.stride_C3; +#ifdef XBC_OVERLAP + let s = s_prev * dA + x_B_C_merged[b_idx] * x_dt; +#else + let s = s_prev * dA + B[b_idx] * x_dt; +#endif + s_prev = s; + +#ifdef USE_SUBGROUP_REDUCTION +#ifdef XBC_OVERLAP + let subgroup_partial = subgroupAdd(s * x_B_C_merged[c_idx]); +#else + let subgroup_partial = subgroupAdd(s * C[c_idx]); +#endif + if (subgroup_invocation_id == 0u) { + shared_reduce[reduce_idx - tid + subgroup_id] = subgroup_partial; + } +#else +#ifdef XBC_OVERLAP + shared_reduce[reduce_idx] = s * x_B_C_merged[c_idx]; +#else + shared_reduce[reduce_idx] = s * C[c_idx]; +#endif +#endif + + workgroupBarrier(); + +#ifdef USE_SUBGROUP_REDUCTION + if (tid == 0u) { + var sum = 0.0; + for (var sg = 0u; sg < num_subgroups; sg++) { + sum += shared_reduce[reduce_base(token_in_tile) + sg]; + } + let y_idx = + params.offset_dst + i1 + ir * params.d_inner + token * (params.n_head * params.d_inner) + + i3 * (params.n_seq_tokens * params.n_head * params.d_inner); + dst[y_idx] = sum; + } +#else + for (var stride = WG_SIZE / 2u; stride > 0u; stride >>= 1u) { + if (tid < stride) { + shared_reduce[reduce_idx] += shared_reduce[reduce_idx + stride]; + } + workgroupBarrier(); + } + + if (tid == 0u) { + let y_idx = + params.offset_dst + i1 + ir * params.d_inner + token * (params.n_head * params.d_inner) + + i3 * (params.n_seq_tokens * params.n_head * params.d_inner); + dst[y_idx] = shared_reduce[reduce_base(token_in_tile)]; + } +#endif + + workgroupBarrier(); + } + } + + let state_idx = + params.offset_dst + params.y_elems + tid + i1 * params.d_state + ir * (params.d_state * params.d_inner) + + i3 * (params.d_state * params.d_inner * params.n_head); + dst[state_idx] = s_prev; +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl index feaf6d0ac29..b8f1bca1284 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl @@ -5,7 +5,6 @@ enable f16; #define TYPE f32 #endif - @group(0) @binding(0) var src: array; @@ -57,12 +56,20 @@ fn main(@builtin(global_invocation_id) gid: vec3) { return; } var i = gid.x; - let i3 = i / (params.ne2 * params.ne1 * params.ne0); - i = i % (params.ne2 * params.ne1 * params.ne0); - let i2 = i / (params.ne1 * params.ne0); - i = i % (params.ne1 * params.ne0); - let i1 = i / params.ne0; - let i0 = i % params.ne0; + let ne2 = params.ne2; +#ifdef DIAG + let ne1 = params.ne0; +#else + let ne1 = params.ne1; +#endif + let ne0 = params.ne0; + + let i3 = i / (ne2 * ne1 * ne0); + i = i % (ne2 * ne1 * ne0); + let i2 = i / (ne1 * ne0); + i = i % (ne1 * ne0); + let i1 = i / ne0; + let i0 = i % ne0; let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + i2 * params.stride_src2 + i3 * params.stride_src3; @@ -100,7 +107,8 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let res = src[params.offset_src + src_idx] / (1.0 + exp(-src[params.offset_src + src_idx])); #endif #ifdef EXP - let res = exp(src[params.offset_src + src_idx]); + let src_f32 = f32(src[params.offset_src + src_idx]); + let res = TYPE(exp(src_f32)); #endif #ifdef LOG let res = TYPE(log(f32(src[params.offset_src + src_idx]))); @@ -139,22 +147,20 @@ fn main(@builtin(global_invocation_id) gid: vec3) { -9.010913, 9.010913))); #endif #ifdef XIELU + let val = f32(src[params.offset_src + src_idx]); let res = - select(((exp(min(src[params.offset_src + src_idx], TYPE(params.eps))) - 1.0) - - src[params.offset_src + src_idx]) * - TYPE(params.alpha_n) + - TYPE(params.beta) * src[params.offset_src + src_idx], - TYPE(params.alpha_p) * src[params.offset_src + src_idx] * - src[params.offset_src + src_idx] + - TYPE(params.beta) * src[params.offset_src + src_idx], - src[params.offset_src + src_idx] > 0.0); + TYPE(select( + ((exp(min(val, params.eps)) - 1.0) - val) * params.alpha_n + params.beta * val, + params.alpha_p * val * val + params.beta * val, + val > 0.0)); #endif #ifdef SOFTPLUS let src_f32 = f32(src[params.offset_src + src_idx]); let res = TYPE(select(log(1.0 + exp(src_f32)), src_f32, src_f32 > 20.0)); #endif #ifdef EXPM1 - let res = exp(src[params.offset_src + src_idx]) - 1.0; + let src_f32 = f32(src[params.offset_src + src_idx]); + let res = TYPE(exp(src_f32) - 1.0); #endif #ifdef FLOOR let res = floor(src[params.offset_src + src_idx]); @@ -174,7 +180,7 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let res = src[params.offset_src + src_idx] * src[params.offset_src + src_idx]; #endif #ifdef SQRT - let res = sqrt(src[params.offset_src + src_idx]); + let res = TYPE(sqrt(f32(src[params.offset_src + src_idx]))); #endif #ifdef SIN let res_f32 = sin(f32(src[params.offset_src + src_idx])); @@ -184,6 +190,20 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let res_f32 = cos(f32(src[params.offset_src + src_idx])); let res = TYPE(res_f32); #endif +#ifdef DIAG + let res = select(0.0, src[params.offset_src + i0 + i2 * params.stride_src2 + i3 * params.stride_src3], i0 == i1); +#endif +#ifdef TRI +#ifdef TRI_TYPE_LOWER + let res = select(0.0, src[params.offset_src + src_idx], i0 < i1); +#elif TRI_TYPE_LOWER_DIAG + let res = select(0.0, src[params.offset_src + src_idx], i0 <= i1); +#elif TRI_TYPE_UPPER + let res = select(0.0, src[params.offset_src + src_idx], i0 > i1); +#elif TRI_TYPE_UPPER_DIAG + let res = select(0.0, src[params.offset_src + src_idx], i0 >= i1); +#endif +#endif #ifdef INPLACE src[params.offset_src + src_idx] = res; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl new file mode 100644 index 00000000000..e9ef8822644 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl @@ -0,0 +1,240 @@ +#if defined(SRC_F16) || defined(DST_F16) +enable f16; +#endif + +#ifdef SRC_F16 +#define SRC_TYPE f16 +#else +#define SRC_TYPE f32 +#endif + +#ifdef DST_F16 +#define DST_TYPE f16 +#else +#define DST_TYPE f32 +#endif + +@group(0) @binding(0) +var input: array; + +@group(0) @binding(1) +var output: array; + +struct Params { + offset_i: u32, + offset_o: u32, + + // element strides + si0: u32, si1: u32, si2: u32, si3: u32, + so0: u32, so1: u32, so2: u32, so3: u32, + + src_w: u32, + src_h: u32, + src_z: u32, + src_n: u32, + + dst_w: u32, + dst_h: u32, + dst_z: u32, + dst_n: u32, + + mode_flags: u32, +}; + +@group(0) @binding(2) +var params: Params; + +const GGML_SCALE_FLAG_ALIGN_CORNERS: u32 = 1u << 8u; + +fn get_clamped_input(x: i32, y: i32, z: u32, n: u32) -> f32 { + let cx = u32(clamp(x, 0, i32(params.src_w) - 1)); + let cy = u32(clamp(y, 0, i32(params.src_h) - 1)); + let i = params.offset_i + cx * params.si0 + cy * params.si1 + z * params.si2 + n * params.si3; + return f32(input[i]); +} + +fn cubic_weight(t: f32, a: f32) -> f32 { + let at = abs(t); + if (at <= 1.0) { + return (a + 2.0) * at * at * at - (a + 3.0) * at * at + 1.0; + } else if (at <= 2.0) { + return a * at * at * at - 5.0 * a * at * at + 8.0 * a * at - 4.0 * a; + } else { + return 0.0; + } +} + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(global_invocation_id) gid: vec3, + @builtin(num_workgroups) num_wg: vec3 +) { + + let i_out = gid.x + (num_wg.x * u32(WG_SIZE)) * gid.y; + let total = params.dst_w * params.dst_h * params.dst_z * params.dst_n; + + if (i_out >= total) { + return; + } + + // decode (x, y, z, n) + var i = i_out; + let x_dst = i % params.dst_w; + i = i / params.dst_w; + let y_dst = i % params.dst_h; + i = i / params.dst_h; + let z_dst = i % params.dst_z; + let n_dst = i / params.dst_z; + + // scale factors + var sf0 = f32(params.dst_w) / f32(params.src_w); + var sf1 = f32(params.dst_h) / f32(params.src_h); + var sf2 = f32(params.dst_z) / f32(params.src_z); + var sf3 = f32(params.dst_n) / f32(params.src_n); + + let align_corners = (params.mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) != 0; + + // pixel_offset: 0.5 for half-pixel-center (default), 0.0 for align_corners + var pixel_offset = 0.5; + if (align_corners) { + pixel_offset = 0.0; + if (params.dst_w > 1 && params.src_w > 1) { + sf0 = f32(params.dst_w - 1) / f32(params.src_w - 1); + } + if (params.dst_h > 1 && params.src_h > 1) { + sf1 = f32(params.dst_h - 1) / f32(params.src_h - 1); + } + } + + let z_src = min(params.src_z - 1, u32(floor(f32(z_dst) / sf2))); + let n_src = min(params.src_n - 1, u32(floor(f32(n_dst) / sf3))); + + var result = 0.0; + +#if defined(NEAREST) + + let x_src = min(params.src_w - 1, u32(floor(f32(x_dst) / sf0))); + let y_src = min(params.src_h - 1, u32(floor(f32(y_dst) / sf1))); + + result = get_clamped_input(i32(x_src), i32(y_src), z_src, n_src); + +#elif defined(BILINEAR) + +#if defined(ANTIALIAS) + + // Antialiased bilinear: triangle filter over a variable support region. + let support0 = max(1.0f / sf0, 1.0f); + let support1 = max(1.0f / sf1, 1.0f); + let invscale0 = 1.0 / support0; + let invscale1 = 1.0 / support1; + + let fx = (f32(x_dst) + pixel_offset) / sf0; + let fy = (f32(y_dst) + pixel_offset) / sf1; + + let x_min = max(i32(fx - support0 + pixel_offset), 0); + let y_min = max(i32(fy - support1 + pixel_offset), 0); + let x_max = min(i32(fx + support0 + pixel_offset), i32(params.src_w)); + let y_max = min(i32(fy + support1 + pixel_offset), i32(params.src_h)); + + var weighted_sum = 0.0; + var total_weight = 0.0; + + for (var x = x_min; x < x_max; x += 1) { + let wx = max(1.0 - abs(f32(x) - fx + pixel_offset) * invscale0, 0.0); + for (var y = y_min; y < y_max; y += 1) { + let wy = max(1.0 - abs(f32(y) - fy + pixel_offset) * invscale1, 0.0); + let w = wx * wy; + if (w > 0.0) { + weighted_sum += get_clamped_input(x, y, z_src, n_src) * w; + total_weight += w; + } + } + } + + if (total_weight > 0.0) { + result = weighted_sum / total_weight; + } + +#else + + let fx = (f32(x_dst) + pixel_offset) / sf0 - pixel_offset; + let fy = (f32(y_dst) + pixel_offset) / sf1 - pixel_offset; + let x0 = i32(floor(fx)); + let y0 = i32(floor(fy)); + let dx = clamp(fx - f32(x0), 0.0, 1.0); + let dy = clamp(fy - f32(y0), 0.0, 1.0); + let a = get_clamped_input(x0, y0, z_src, n_src); + let b = get_clamped_input(x0 + 1, y0, z_src, n_src); + let c = get_clamped_input(x0, y0 + 1, z_src, n_src); + let d = get_clamped_input(x0 + 1, y0 + 1, z_src, n_src); + + let wa = (1.0 - dx) * (1.0 - dy); + let wb = dx * (1.0 - dy); + let wc = (1.0 - dx) * dy; + let wd = dx * dy; + + result = a * wa + b * wb + c * wc + d * wd; + +#endif + +#elif defined(BICUBIC) + + // bicubic convolution with alpha = -0.75 (PyTorch default) + let alpha = -0.75; + let fx = (f32(x_dst) + pixel_offset) / sf0 - pixel_offset; + let fy = (f32(y_dst) + pixel_offset) / sf1 - pixel_offset; + + let x0 = i32(floor(fx)); + let y0 = i32(floor(fy)); + let dx = fx - f32(x0); + let dy = fy - f32(y0); + + // horizontal weights for offsets -1, 0, 1, 2 + let wx0 = cubic_weight(dx + 1.0, alpha); + let wx1 = cubic_weight(dx, alpha); + let wx2 = cubic_weight(1.0 - dx, alpha); + let wx3 = cubic_weight(2.0 - dx, alpha); + + // vertical weights for offsets -1, 0, 1, 2 + let wy0 = cubic_weight(dy + 1.0, alpha); + let wy1 = cubic_weight(dy, alpha); + let wy2 = cubic_weight(1.0 - dy, alpha); + let wy3 = cubic_weight(2.0 - dy, alpha); + + // intermediate horizontal interpolation for 4x4 grid of pixels + // x0-1, x0, x0+1, x0+2, y0-1 + let p0 = get_clamped_input(x0 - 1, y0 - 1, z_src, n_src); + let p1 = get_clamped_input(x0, y0 - 1, z_src, n_src); + let p2 = get_clamped_input(x0 + 1, y0 - 1, z_src, n_src); + let p3 = get_clamped_input(x0 + 2, y0 - 1, z_src, n_src); + let row0 = p0 * wx0 + p1 * wx1 + p2 * wx2 + p3 * wx3; + + // x0-1, x0, x0+1, x0+2, y0 + let q0 = get_clamped_input(x0 - 1, y0, z_src, n_src); + let q1 = get_clamped_input(x0, y0, z_src, n_src); + let q2 = get_clamped_input(x0 + 1, y0, z_src, n_src); + let q3 = get_clamped_input(x0 + 2, y0, z_src, n_src); + let row1 = q0 * wx0 + q1 * wx1 + q2 * wx2 + q3 * wx3; + + // x0-1, x0, x0+1, x0+2, y0+1 + let r0 = get_clamped_input(x0 - 1, y0 + 1, z_src, n_src); + let r1 = get_clamped_input(x0, y0 + 1, z_src, n_src); + let r2 = get_clamped_input(x0 + 1, y0 + 1, z_src, n_src); + let r3 = get_clamped_input(x0 + 2, y0 + 1, z_src, n_src); + let row2 = r0 * wx0 + r1 * wx1 + r2 * wx2 + r3 * wx3; + + // x0-1, x0, x0+1, x0+2, y0+2 + let s0 = get_clamped_input(x0 - 1, y0 + 2, z_src, n_src); + let s1 = get_clamped_input(x0, y0 + 2, z_src, n_src); + let s2 = get_clamped_input(x0 + 1, y0 + 2, z_src, n_src); + let s3 = get_clamped_input(x0 + 2, y0 + 2, z_src, n_src); + let row3 = s0 * wx0 + s1 * wx1 + s2 * wx2 + s3 * wx3; + + // final vertical interpolation + result = row0 * wy0 + row1 * wy1 + row2 * wy2 + row3 * wy3; + +#endif + + let dst_idx = params.offset_o + x_dst * params.so0 + y_dst * params.so1 + z_dst * params.so2 + n_dst * params.so3; + output[dst_idx] = DST_TYPE(result); +} diff --git a/ggml/src/ggml-zdnn/ggml-zdnn.cpp b/ggml/src/ggml-zdnn/ggml-zdnn.cpp index 9b6938abf7e..639b818d128 100644 --- a/ggml/src/ggml-zdnn/ggml-zdnn.cpp +++ b/ggml/src/ggml-zdnn/ggml-zdnn.cpp @@ -313,6 +313,8 @@ static ggml_backend_buffer_i ggml_backend_zdnn_buffer_i = { /* .memset_tensor = */ ggml_backend_zdnn_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_zdnn_buffer_set_tensor, /* .get_tensor = */ ggml_backend_zdnn_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ NULL, /* .clear = */ ggml_backend_zdnn_buffer_clear, /* .reset = */ NULL, @@ -417,20 +419,22 @@ static enum ggml_status ggml_backend_zdnn_graph_compute(ggml_backend_t backend, } static ggml_backend_i ggml_backend_zdnn_i = { - /* .get_name = */ ggml_backend_zdnn_name, - /* .free = */ ggml_backend_zdnn_free, - /* .set_tensor_async = */ NULL, - /* .get_tensor_async = */ NULL, - /* .cpy_tensor_async = */ NULL, - /* .synchronize = */ NULL, - /* .graph_plan_create = */ NULL, - /* .graph_plan_free = */ NULL, - /* .graph_plan_update = */ NULL, - /* .graph_plan_compute = */ NULL, - /* .graph_compute = */ ggml_backend_zdnn_graph_compute, - /* .event_record = */ NULL, - /* .event_wait = */ NULL, - /* .graph_optimize = */ NULL, + /* .get_name = */ ggml_backend_zdnn_name, + /* .free = */ ggml_backend_zdnn_free, + /* .set_tensor_async = */ NULL, + /* .get_tensor_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .cpy_tensor_async = */ NULL, + /* .synchronize = */ NULL, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_zdnn_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, + /* .graph_optimize = */ NULL, }; static ggml_guid_t ggml_backend_zdnn_guid(void) { diff --git a/ggml/src/ggml-zendnn/CMakeLists.txt b/ggml/src/ggml-zendnn/CMakeLists.txt index 9bdb4e836d3..4f321a25257 100644 --- a/ggml/src/ggml-zendnn/CMakeLists.txt +++ b/ggml/src/ggml-zendnn/CMakeLists.txt @@ -28,7 +28,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF") ExternalProject_Add( zendnn GIT_REPOSITORY https://github.com/amd/ZenDNN.git - GIT_TAG a18adf8c605fb5f5e52cefd7eda08a7b18febbaf # ZenDNN-2026-WW08 + GIT_TAG f79f7321a1add65ced6397a6bfab7edba6e3e14e # ZenDNN-2026-WW13 PREFIX ${ZENDNN_PREFIX} SOURCE_DIR ${ZENDNN_SOURCE_DIR} BINARY_DIR ${ZENDNN_BUILD_DIR} diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp index c8760304008..2b82c7c1dbb 100644 --- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp +++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp @@ -190,6 +190,170 @@ static void ggml_zendnn_compute_forward_mul_mat( } } +struct mmid_row_mapping { + int32_t i1; + int32_t i2; +}; + +static void ggml_zendnn_compute_forward_mul_mat_id( + ggml_backend_zendnn_context * ctx, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; // expert weights + const ggml_tensor * src1 = dst->src[1]; // inputs + const ggml_tensor * ids = dst->src[2]; // expert ids + + GGML_TENSOR_BINARY_OP_LOCALS + + // exit for no tokens to process + if (ne2 == 0 || ne11 == 0) { + return; + } + + ggml_type const vec_dot_type = src0->type; + ggml_from_float_t const from_float = ggml_get_type_traits(vec_dot_type)->from_float_ref; + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == ggml_type_size(src0->type)); + GGML_ASSERT(nb10 == ggml_type_size(src1->type)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(ne03 == 1); + GGML_ASSERT(ne13 == 1); + GGML_ASSERT(ne3 == 1); + + // row groups + const int n_ids = ids->ne[0]; // n_expert_used + const int n_as = ne02; // n_experts + + std::vector matrix_row_counts(n_as, 0); + std::vector> matrix_rows(n_as); + + int64_t max_rows = 0; + // group rows by expert (preprocessing step) + for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { + for (int id = 0; id < n_ids; ++id) { + const int32_t i02 = *(const int32_t *)((const char *)ids->data + iid1*ids->nb[1] + id*ids->nb[0]); + + GGML_ASSERT(i02 >= 0 && i02 < n_as); + + matrix_rows[i02].push_back({id, iid1}); + matrix_row_counts[i02]++; + if (matrix_row_counts[i02] > max_rows) { + max_rows = matrix_row_counts[i02]; + } + } + } + + if (max_rows == 0) { + return; // no rows to process + } + + const size_t row_size = ggml_row_size(vec_dot_type, ne10); + + // size for converting src1 rows to vec_dot_type if needed + const size_t nbw1 = row_size; + const size_t nbw2 = nbw1 * ne11; + const size_t nbw3 = nbw2 * ne12; + const size_t src1_conv_size = (src1->type != vec_dot_type) ? ne13 * nbw3 : 0; + + // size for MoE gather/scatter buffers + const size_t wdata_cur_size = max_rows * row_size; + const size_t dst_cur_size = max_rows * ggml_row_size(dst->type, ne01); + + // allocate single buffer for all needs + const size_t total_size = src1_conv_size + wdata_cur_size + dst_cur_size; + if (ctx->work_size < total_size) { + ctx->work_data.reset(new char[total_size]); + ctx->work_size = total_size; + } + + // partition the buffer + char * work_data = ctx->work_data.get(); + char * wdata_cur = work_data + src1_conv_size; + char * dst_cur = wdata_cur + wdata_cur_size; + + if (src1->type != vec_dot_type) { + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + #pragma omp parallel for collapse(3) num_threads(ctx->n_threads) schedule(static) + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + const float * src1_f32 = (float *)((char *)src1->data + i11*nb11 + i12*nb12 + i13*nb13); + void * src1_conv = (char *)work_data + i11*nbw1 + i12*nbw2 + i13*nbw3; + from_float(src1_f32, src1_conv, ne10); + } + } + } + } + + const void * wdata = src1->type == vec_dot_type ? src1->data : work_data; + + // process each expert with gather -> gemm -> scatter pattern + for (int64_t cur_a = 0; cur_a < n_as; ++cur_a) { + const int64_t cne1 = matrix_row_counts[cur_a]; + + if (cne1 == 0) { + continue; + } + + const char * src0_cur = (const char *) src0->data + cur_a*nb02; + + // gather input rows for this expert + #pragma omp parallel for num_threads(ctx->n_threads) schedule(static) + for (int64_t ir1 = 0; ir1 < cne1; ++ir1) { + const mmid_row_mapping & row_mapping = matrix_rows[cur_a][ir1]; + const int64_t id = row_mapping.i1; + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; + + std::memcpy( + wdata_cur + ir1 * row_size, + (const char *) wdata + (i11 + i12*ne11) * row_size, + row_size + ); + } + + // batched gemm for all tokens in this expert + if (!ggml_zendnn_sgemm(ctx, + ne01, // m + cne1, // n + ne10, // k + src0_cur, + ne00, // lda + wdata_cur, + ne10, // ldb + dst_cur, + ne01, // ldc + src0->type, + vec_dot_type, + dst->type)) { + GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__); + } + + // scatter output rows to destination + #pragma omp parallel for num_threads(ctx->n_threads) schedule(static) + for (int64_t ir1 = 0; ir1 < cne1; ++ir1) { + const mmid_row_mapping & row_mapping = matrix_rows[cur_a][ir1]; + const int64_t id = row_mapping.i1; + const int64_t i1 = id; + const int64_t i2 = row_mapping.i2; + + std::memcpy( + (char *) dst->data + i1*nb1 + i2*nb2, + dst_cur + ir1 * ggml_row_size(dst->type, ne01), + ggml_row_size(dst->type, ne01) + ); + } + } +} + // backend interface static const char * ggml_backend_zendnn_get_name(ggml_backend_t backend) { @@ -218,6 +382,9 @@ static ggml_status ggml_backend_zendnn_graph_compute(ggml_backend_t backend, ggm case GGML_OP_MUL_MAT: ggml_zendnn_compute_forward_mul_mat(ctx, node); break; + case GGML_OP_MUL_MAT_ID: + ggml_zendnn_compute_forward_mul_mat_id(ctx, node); + break; case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: @@ -240,6 +407,8 @@ static struct ggml_backend_i ggml_backend_zendnn_i = { /* .free = */ ggml_backend_zendnn_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, /* .graph_plan_create = */ NULL, @@ -361,6 +530,7 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const return true; case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: { const ggml_tensor * weights = op->src[0]; const ggml_tensor * inputs = op->src[1]; @@ -374,6 +544,17 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const ne0 < min_batch || ne1 < min_batch || ne10 < min_batch) { return false; } + // MUL_MAT_ID performs best with a moderate number of experts due to its + // gather + batched matmul + scatter approach. Future versions will leverage + // ZenDNN's grouped_gemm for better scalability with larger expert counts: + // https://github.com/amd/ZenDNN/blob/main/docs/operator/lowoha_group_gemm_operator.md + if (op->op == GGML_OP_MUL_MAT_ID) { + const int64_t n_experts = weights->ne[2]; + const int64_t max_experts = 32; + if (n_experts > max_experts) { + return false; + } + } switch (weights->type) { case GGML_TYPE_F32: case GGML_TYPE_BF16: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 4c0764a0ac5..81343eeb14c 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -53,6 +53,21 @@ #define UNUSED GGML_UNUSED +uint64_t ggml_graph_next_uid(void) { +#ifdef _MSC_VER +#if defined(_WIN32) + static volatile LONG counter = 1; + return (uint64_t) InterlockedIncrement(&counter) - 1; +#else + static volatile long long counter = 1; + return (uint64_t) _InterlockedIncrement64(&counter) - 1; +#endif +#else + static uint64_t counter = 1; + return __atomic_fetch_add(&counter, 1, __ATOMIC_RELAXED); +#endif +} + // Needed for ggml_fp32_to_bf16_row() #if defined(__AVX512BF16__) #if defined(_MSC_VER) @@ -651,6 +666,14 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { .to_float = (ggml_to_float_t) ggml_fp16_to_fp32_row, .from_float_ref = (ggml_from_float_t) ggml_fp32_to_fp16_row, }, + [GGML_TYPE_Q1_0] = { + .type_name = "q1_0", + .blck_size = QK1_0, + .type_size = sizeof(block_q1_0), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q1_0, + .from_float_ref = (ggml_from_float_t) quantize_row_q1_0_ref, + }, [GGML_TYPE_Q4_0] = { .type_name = "q4_0", .blck_size = QK4_0, @@ -1384,6 +1407,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_BF16: wtype = GGML_TYPE_BF16; break; case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break; case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break; + case GGML_FTYPE_MOSTLY_Q1_0: wtype = GGML_TYPE_Q1_0; break; case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break; case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break; case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break; @@ -4962,6 +4986,7 @@ static struct ggml_tensor * ggml_interpolate_impl( GGML_ASSERT((mode & 0xFF) < GGML_SCALE_MODE_COUNT); // TODO: implement antialias for modes other than bilinear GGML_ASSERT(!(mode & GGML_SCALE_FLAG_ANTIALIAS) || (mode & 0xFF) == GGML_SCALE_MODE_BILINEAR); + GGML_ASSERT(a->type == GGML_TYPE_F32); struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); @@ -5307,6 +5332,7 @@ struct ggml_tensor * ggml_flash_attn_ext( GGML_ASSERT(q->ne[3] == v->ne[3]); if (mask) { + GGML_ASSERT(mask->type == GGML_TYPE_F16); GGML_ASSERT(ggml_is_contiguous(mask)); //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); @@ -7087,6 +7113,7 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz /*.use_counts =*/ use_counts_ptr, /*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr }, /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT, + /*.uid =*/ 0, }; ggml_hash_set_reset(&cgraph->visited_hash_set); @@ -7114,6 +7141,7 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1) /*.use_counts =*/ cgraph0->use_counts, /*.visited_hash_set =*/ cgraph0->visited_hash_set, /*.order =*/ cgraph0->order, + /*.uid =*/ 0 }; return cgraph; @@ -7633,7 +7661,7 @@ size_t ggml_quantize_chunk( int64_t nrows, int64_t n_per_row, const float * imatrix) { - const int64_t n = (int64_t) nrows * n_per_row; + const int64_t n = nrows * n_per_row; if (ggml_quantize_requires_imatrix(type)) { GGML_ASSERT(imatrix != NULL); @@ -7650,20 +7678,21 @@ size_t ggml_quantize_chunk( size_t result = 0; switch (type) { - case GGML_TYPE_Q4_0: result = quantize_q4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q4_1: result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_MXFP4: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_NVFP4: result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q5_K: result = quantize_q5_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_TQ1_0: result = quantize_tq1_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_TQ2_0: result = quantize_tq2_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q1_0: result = quantize_q1_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q4_0: result = quantize_q4_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q4_1: result = quantize_q4_1 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q5_0: result = quantize_q5_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q5_1: result = quantize_q5_1 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q8_0: result = quantize_q8_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_MXFP4: result = quantize_mxfp4 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_NVFP4: result = quantize_nvfp4 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q2_K: result = quantize_q2_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q3_K: result = quantize_q3_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q4_K: result = quantize_q4_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q5_K: result = quantize_q5_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q6_K: result = quantize_q6_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_TQ1_0: result = quantize_tq1_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_TQ2_0: result = quantize_tq2_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ3_XXS: result = quantize_iq3_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; @@ -7728,9 +7757,9 @@ struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads) { } bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1) { - if (p0->n_threads != p1->n_threads ) return false; - if (p0->prio != p1->prio ) return false; - if (p0->poll != p1->poll ) return false; - if (p0->strict_cpu != p1->strict_cpu ) return false; + if (p0->n_threads != p1->n_threads ) return false; + if (p0->prio != p1->prio ) return false; + if (p0->poll != p1->poll ) return false; + if (p0->strict_cpu != p1->strict_cpu ) return false; return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0; } diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp index cbeedf6c4b6..ab3cc974867 100644 --- a/ggml/src/gguf.cpp +++ b/ggml/src/gguf.cpp @@ -394,7 +394,11 @@ bool gguf_read_emplace_helper(const struct gguf_reader & gr, std::vector & bu gguf_write_out(ctx, gw, only_meta); } +bool gguf_write_to_file_ptr(const struct gguf_context * ctx, FILE * file, bool only_meta) { + GGML_ASSERT(file); + + try { + gguf_writer_file gw(file); + gguf_write_out(ctx, gw, only_meta); + } catch (const std::runtime_error& ex) { + GGML_LOG_ERROR("%s: failed to write GGUF data: %s\n", __func__, ex.what()); + return false; + } + return true; +} + bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) { FILE * file = ggml_fopen(fname, "wb"); @@ -1516,17 +1533,13 @@ bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, boo return false; } - try { - gguf_writer_file gw(file); - gguf_write_out(ctx, gw, only_meta); - } catch (const std::runtime_error& ex) { - GGML_LOG_ERROR("%s: failed to write GGUF data into '%s': %s\n", __func__, fname, ex.what()); - fclose(file); - return false; + const bool success = gguf_write_to_file_ptr(ctx, file, only_meta); + if (!success) { + GGML_LOG_ERROR("%s: failed to write GGUF data into '%s'\n", __func__, fname); } fclose(file); - return true; + return success; } size_t gguf_get_meta_size(const struct gguf_context * ctx) { diff --git a/include/whisper.h b/include/whisper.h index 85091e02fea..2670313593b 100644 --- a/include/whisper.h +++ b/include/whisper.h @@ -699,6 +699,16 @@ extern "C" { const float * samples, int n_samples); + // Like whisper_vad_detect_speech, but does not reset LSTM state. + // Use for streaming: call whisper_vad_reset_state() between utterances. + WHISPER_API bool whisper_vad_detect_speech_no_reset( + struct whisper_vad_context * vctx, + const float * samples, + int n_samples); + + // Reset LSTM hidden/cell states to zero. + WHISPER_API void whisper_vad_reset_state(struct whisper_vad_context * vctx); + WHISPER_API int whisper_vad_n_probs(struct whisper_vad_context * vctx); WHISPER_API float * whisper_vad_probs (struct whisper_vad_context * vctx); diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 6557fb46cbe..812e721a8c5 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -c044a8eeae2591faa0950c8b5e514cbc4bbfc4ca +19eac6f0edaf285506eb6228d31bb9caeda9aba1 diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f360411d704..3a09c7b9157 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -67,8 +67,12 @@ if (WHISPER_COREML) include(DefaultTargetOptions) - target_include_directories(${TARGET} PUBLIC - . + # PRIVATE (not PUBLIC) so the raw source-tree include path doesn't + # propagate through the install/export interface. whisper.coreml is an + # internal helper that only whisper.cpp (in the same directory) consumes; + # external users of `whisper` never need to see these headers. + target_include_directories(${TARGET} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} ) target_link_libraries(${TARGET} PRIVATE ${FOUNDATION_FRAMEWORK} ${COREML_FRAMEWORK}) @@ -79,7 +83,12 @@ if (WHISPER_COREML) ) set_target_properties(${TARGET} PROPERTIES FOLDER "libs") - install(TARGETS ${TARGET} LIBRARY) + # Join the whisper-targets export set so that consumers' install(EXPORT) + # works when whisper PRIVATE-links to whisper.coreml. Without this the + # iOS / xcframework build fails at CMake generate time with + # install(EXPORT "whisper-targets" ...) includes target "whisper" which + # requires target "whisper.coreml" that is not in any export set. + install(TARGETS ${TARGET} EXPORT whisper-targets LIBRARY) endif() if (WHISPER_OPENVINO) diff --git a/src/whisper.cpp b/src/whisper.cpp index 0a9619dcac7..25eee380665 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -5162,7 +5162,11 @@ struct whisper_vad_context * whisper_vad_init_with_params( return vctx; } -bool whisper_vad_detect_speech( +void whisper_vad_reset_state(whisper_vad_context * vctx) { + ggml_backend_buffer_clear(vctx->buffer, 0); +} + +bool whisper_vad_detect_speech_no_reset( struct whisper_vad_context * vctx, const float * samples, int n_samples) { @@ -5174,9 +5178,6 @@ bool whisper_vad_detect_speech( WHISPER_LOG_INFO("%s: detecting speech in %d samples\n", __func__, n_samples); WHISPER_LOG_INFO("%s: n_chunks: %d\n", __func__, n_chunks); - // Reset LSTM hidden/cell states - ggml_backend_buffer_clear(vctx->buffer, 0); - vctx->probs.resize(n_chunks); WHISPER_LOG_INFO("%s: props size: %u\n", __func__, n_chunks); @@ -5244,6 +5245,14 @@ bool whisper_vad_detect_speech( return true; } +bool whisper_vad_detect_speech( + struct whisper_vad_context * vctx, + const float * samples, + int n_samples) { + whisper_vad_reset_state(vctx); + return whisper_vad_detect_speech_no_reset(vctx, samples, n_samples); +} + int whisper_vad_segments_n_segments(struct whisper_vad_segments * segments) { return segments->data.size(); } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 09e77ea89c2..6b80b023ffb 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -110,3 +110,18 @@ target_compile_definitions(${VAD_TEST} PRIVATE SAMPLE_PATH="${PROJECT_SOURCE_DIR}/samples/jfk.wav") add_test(NAME ${VAD_TEST} COMMAND ${VAD_TEST}) set_tests_properties(${VAD_TEST} PROPERTIES LABELS "base;en") + +# Streaming VAD test pins the upstream PR #3677 contract: +# whisper_vad_detect_speech_no_reset + whisper_vad_reset_state must +# combine to yield exactly the same per-chunk probabilities as the +# original whisper_vad_detect_speech call. Uses only the silero VAD +# model (no whisper transcribe model needed). +set(VAD_TEST test-vad-streaming) +add_executable(${VAD_TEST} ${VAD_TEST}.cpp) +target_include_directories(${VAD_TEST} PRIVATE ../include ../ggml/include ../examples) +target_link_libraries(${VAD_TEST} PRIVATE common) +target_compile_definitions(${VAD_TEST} PRIVATE + VAD_MODEL_PATH="${PROJECT_SOURCE_DIR}/models/for-tests-silero-v6.2.0-ggml.bin" + SAMPLE_PATH="${PROJECT_SOURCE_DIR}/samples/jfk.wav") +add_test(NAME ${VAD_TEST} COMMAND ${VAD_TEST}) +set_tests_properties(${VAD_TEST} PROPERTIES LABELS "unit;vad") diff --git a/tests/test-vad-streaming.cpp b/tests/test-vad-streaming.cpp new file mode 100644 index 00000000000..4c4d092eefa --- /dev/null +++ b/tests/test-vad-streaming.cpp @@ -0,0 +1,133 @@ +// Regression coverage for the streaming VAD API added by upstream +// ggml-org/whisper.cpp PR #3677 (commit 166c20b4) and pulled into the +// tetherto fork via QVAC-18991. The upstream change did not ship a test, +// so we add one here to lock the contract: +// +// 1. whisper_vad_detect_speech is idempotent (reset is implicit). +// 2. whisper_vad_reset_state restores the LSTM state, so a no-reset +// run after a reset produces the same probs as the very first run. +// 3. Two contiguous whisper_vad_detect_speech_no_reset calls on +// adjacent halves of the input produce the same per-chunk probs +// as a single whisper_vad_detect_speech(full_input) — i.e. the +// LSTM state actually carries across the boundary (within the +// tolerance of running the same graph through two scheduler +// activations). +// +// Split is performed at a multiple of vctx->n_window so that the second +// half starts cleanly on a chunk boundary and no zero-padding is +// introduced mid-stream that would diverge from the single-shot run. + +#include "whisper.h" +#include "common-whisper.h" + +#include +#include +#include +#include + +#ifdef NDEBUG +#undef NDEBUG +#endif +#include + +static std::vector snapshot_probs(struct whisper_vad_context * vctx) { + const int n = whisper_vad_n_probs(vctx); + const float * p = whisper_vad_probs(vctx); + return std::vector(p, p + n); +} + +static void assert_probs_near(const std::vector & a, + const std::vector & b, + float tol, + const char * label) { + assert(a.size() == b.size()); + float worst = 0.0f; + for (size_t i = 0; i < a.size(); ++i) { + const float d = std::fabs(a[i] - b[i]); + if (d > worst) worst = d; + } + printf("%s: max |diff| = %.6f over %zu probs (tol = %.6f)\n", label, worst, a.size(), tol); + assert(worst <= tol); +} + +int main() { + const std::string vad_model_path = VAD_MODEL_PATH; + const std::string sample_path = SAMPLE_PATH; + + std::vector pcmf32; + std::vector> pcmf32s; + assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false)); + assert(pcmf32.size() > 0); + + struct whisper_vad_context_params ctx_params = whisper_vad_default_context_params(); + struct whisper_vad_context * vctx = whisper_vad_init_from_file_with_params( + vad_model_path.c_str(), ctx_params); + assert(vctx != nullptr); + + // --- Test 1: detect_speech is idempotent (implicit reset). ----------- + assert(whisper_vad_detect_speech(vctx, pcmf32.data(), (int)pcmf32.size())); + const auto probs_first = snapshot_probs(vctx); + + assert(whisper_vad_detect_speech(vctx, pcmf32.data(), (int)pcmf32.size())); + const auto probs_second = snapshot_probs(vctx); + + assert_probs_near(probs_first, probs_second, 0.0f, + "Test1 detect_speech idempotent"); + + // --- Test 2: reset_state restores initial LSTM state. ---------------- + whisper_vad_reset_state(vctx); + assert(whisper_vad_detect_speech_no_reset(vctx, pcmf32.data(), (int)pcmf32.size())); + const auto probs_no_reset_a = snapshot_probs(vctx); + + whisper_vad_reset_state(vctx); + assert(whisper_vad_detect_speech_no_reset(vctx, pcmf32.data(), (int)pcmf32.size())); + const auto probs_no_reset_b = snapshot_probs(vctx); + + assert_probs_near(probs_no_reset_a, probs_no_reset_b, 0.0f, + "Test2 reset_state restores LSTM"); + + // detect_speech (which resets internally) must also match the + // reset+no_reset sequence — proves they share identical semantics + // when starting from a clean state. + assert_probs_near(probs_first, probs_no_reset_a, 0.0f, + "Test2b detect_speech == reset+no_reset"); + + // --- Test 3: streaming carries LSTM state across calls. -------------- + // Split exactly on a chunk boundary so we don't introduce mid-stream + // zero padding. The Silero v6.2.0 VAD model fixture uses a fixed + // 512-sample window at 16 kHz (32 ms); the chunk count is + // ceil(n_samples / 512), with the last chunk zero-padded if short. + constexpr int kSileroWindow = 512; + const int total_chunks = (int)probs_first.size(); + assert(total_chunks == ((int)pcmf32.size() + kSileroWindow - 1) / kSileroWindow); + const int half_chunks = total_chunks / 2; + assert(half_chunks >= 1 && total_chunks - half_chunks >= 1); + const int split_idx = half_chunks * kSileroWindow; + assert(split_idx > 0 && split_idx < (int)pcmf32.size()); + + whisper_vad_reset_state(vctx); + assert(whisper_vad_detect_speech_no_reset(vctx, pcmf32.data(), split_idx)); + const auto probs_part1 = snapshot_probs(vctx); + assert((int)probs_part1.size() == half_chunks); + + assert(whisper_vad_detect_speech_no_reset( + vctx, pcmf32.data() + split_idx, (int)pcmf32.size() - split_idx)); + const auto probs_part2 = snapshot_probs(vctx); + assert((int)probs_part2.size() == total_chunks - half_chunks); + + // Concatenate part1 + part2 and compare against single-shot probs. + std::vector probs_stitched; + probs_stitched.reserve(total_chunks); + probs_stitched.insert(probs_stitched.end(), probs_part1.begin(), probs_part1.end()); + probs_stitched.insert(probs_stitched.end(), probs_part2.begin(), probs_part2.end()); + + // Float-equality is the contract: the per-chunk graph is the same + // and the LSTM state is preserved exactly across the no_reset call + // boundary. If a future refactor introduces tiny numerical drift + // across scheduler resets, bump the tolerance — but never silently. + assert_probs_near(probs_first, probs_stitched, 0.0f, + "Test3 streaming == single-shot"); + + whisper_vad_free(vctx); + return 0; +} diff --git a/tts-cpp/src/chatterbox_tts.cpp b/tts-cpp/src/chatterbox_tts.cpp index edd762a7285..603cbc74f3f 100644 --- a/tts-cpp/src/chatterbox_tts.cpp +++ b/tts-cpp/src/chatterbox_tts.cpp @@ -43,6 +43,7 @@ #endif #include +#include #include #include #include