diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7c17f291..0523f2cb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -187,4 +187,6 @@ jobs: - name: Build example for Android run: | sed -i 's/rnllamaBuildFromSource=true/rnllamaBuildFromSource=false/g' example/android/gradle.properties - npm run build:android + sed -i 's/reactNativeArchitectures=.*/reactNativeArchitectures=arm64-v8a,x86_64/g' example/android/gradle.properties + cd example/android + ./gradlew assembleDebug --stacktrace diff --git a/README.md b/README.md index a2174a73..5e736938 100644 --- a/README.md +++ b/README.md @@ -188,6 +188,53 @@ Please visit the [Documentation](docs/API) for more details. You can also visit the [example](example) to see how to use it. +## MTP Speculative Decoding + +MTP speculative decoding can be enabled for GGUF models that contain MTP/NextN layers: + +```js +const context = await initLlama({ + model: modelPath, + n_ctx: 4096, + n_batch: 1024, + n_ubatch: 512, + n_gpu_layers: 99, + flash_attn_type: 'auto', + cache_type_k: 'q8_0', + cache_type_v: 'q8_0', + speculative: { + type: 'draft-mtp', + n_max: 3, + }, +}) + +const result = await context.completion({ + messages: [ + { + role: 'user', + content: + 'Write a concise TypeScript function that groups an array of objects by a key.', + }, + ], + chat_template_kwargs: { + preserve_thinking: true, + }, + n_predict: 128, + temperature: 0.6, + top_k: 20, + top_p: 0.95, + speculative: { + type: 'draft-mtp', + n_max: 3, + }, +}) + +console.log(result.text) +console.log(result.draft_tokens, result.draft_tokens_accepted) +``` + +Use `speculative: false` on a completion call to disable MTP for that request. For recurrent or hybrid models, MTP must be enabled at `initLlama` time with a positive draft length (`speculative.n_max`, `speculative.draft.n_max`, or the flat `spec_draft_n_max`) so llama.cpp can allocate rollback state. MTP is currently text-only: completions with media inputs and queued parallel completions throw an error while MTP is enabled.
+ ## Multimodal (Vision & Audio) `llama.rn` supports multimodal capabilities including vision (images) and audio processing. This allows you to interact with models that can understand both text and media content. diff --git a/cpp/ggml-metal/ggml-metal.metal b/cpp/ggml-metal/ggml-metal.metal index 4219d888..d68b391e 100644 --- a/cpp/ggml-metal/ggml-metal.metal +++ b/cpp/ggml-metal/ggml-metal.metal @@ -13354,7 +13354,7 @@ void mmv_fn( device char * dst, threadgroup char * shmem, uint3 tgpig, - ushort tiitg, + uint tiitg, ushort tiisg, ushort sgitg) { disp_fn(args, src0, src1, dst, tgpig, tiisg); @@ -13368,7 +13368,7 @@ void mmv_fn( device char * dst, threadgroup char * shmem, uint3 tgpig, - ushort tiitg, + uint tiitg, ushort tiisg, ushort sgitg) { disp_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); @@ -13385,7 +13385,7 @@ kernel void kernel_mul_mv_id( device const char * ids, threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiitg[[thread_index_in_threadgroup]], + uint tiitg[[thread_index_in_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { const int iid1 = tgpig.z/args.nei0; diff --git a/cpp/jsi/JSICompletion.h b/cpp/jsi/JSICompletion.h index 6e44aa52..ba457c73 100644 --- a/cpp/jsi/JSICompletion.h +++ b/cpp/jsi/JSICompletion.h @@ -163,6 +163,8 @@ namespace rnllama_jsi { ); res.setProperty(runtime, "tokens_predicted", (double)ctx->completion->num_tokens_predicted); res.setProperty(runtime, "tokens_evaluated", (double)ctx->completion->num_prompt_tokens); + res.setProperty(runtime, "draft_tokens", (double)ctx->completion->num_draft_tokens); + res.setProperty(runtime, "draft_tokens_accepted", (double)ctx->completion->num_draft_tokens_accepted); res.setProperty(runtime, "truncated", ctx->completion->truncated); res.setProperty(runtime, "context_full", ctx->completion->context_full); res.setProperty(runtime, "interrupted", ctx->completion->is_interrupted); 
@@ -230,6 +232,8 @@ namespace rnllama_jsi { res.setProperty(runtime, "tokens_predicted", (double)slot->num_tokens_predicted); res.setProperty(runtime, "tokens_evaluated", (double)slot->num_prompt_tokens); + res.setProperty(runtime, "draft_tokens", 0.0); + res.setProperty(runtime, "draft_tokens_accepted", 0.0); res.setProperty(runtime, "truncated", slot->truncated); res.setProperty(runtime, "context_full", slot->context_full); res.setProperty(runtime, "interrupted", slot->is_interrupted); diff --git a/cpp/jsi/JSIParams.cpp b/cpp/jsi/JSIParams.cpp index df977a02..21bfe1b8 100644 --- a/cpp/jsi/JSIParams.cpp +++ b/cpp/jsi/JSIParams.cpp @@ -1,4 +1,9 @@ #include "JSIParams.h" +#if defined(RNLLAMA_USE_FRAMEWORK_HEADERS) +#include +#else +#include "speculative.h" +#endif #include #include #include @@ -57,6 +62,12 @@ namespace rnllama_jsi { } #endif +#if defined(__APPLE__) + static int default_apple_n_threads() { + return std::max(1, common_cpu_get_num_math() / 2); + } +#endif + std::string getPropertyAsString(jsi::Runtime& runtime, const jsi::Object& obj, const char* name, const std::string& defaultValue) { if (obj.hasProperty(runtime, name)) { auto val = obj.getProperty(runtime, name); @@ -107,6 +118,163 @@ namespace rnllama_jsi { return defaultValue; } + static bool isNil(const jsi::Value& value) { + return value.isNull() || value.isUndefined(); + } + + static std::string normalizeSpeculativeTypeName(std::string name) { + if (name == "mtp") { + return "draft-mtp"; + } + return name; + } + + static void addSpeculativeTypeName(std::vector& typeNames, std::string name) { + name = normalizeSpeculativeTypeName(std::move(name)); + if (std::find(typeNames.begin(), typeNames.end(), name) == typeNames.end()) { + typeNames.push_back(std::move(name)); + } + } + + static void addSpeculativeTypeNamesFromValue( + jsi::Runtime& runtime, + const jsi::Value& value, + std::vector& typeNames + ) { + if (isNil(value)) { + return; + } + + if (value.isString()) { + 
addSpeculativeTypeName(typeNames, value.asString(runtime).utf8(runtime)); + return; + } + + if (value.isObject()) { + auto obj = value.asObject(runtime); + if (!obj.isArray(runtime)) { + return; + } + + auto arr = obj.asArray(runtime); + for (size_t i = 0; i < arr.size(runtime); i++) { + auto item = arr.getValueAtIndex(runtime, i); + if (item.isString()) { + addSpeculativeTypeName(typeNames, item.asString(runtime).utf8(runtime)); + } + } + } + } + + static void applySpeculativeDraftOptions( + jsi::Runtime& runtime, + const jsi::Object& obj, + common_params_speculative_draft& draft + ) { /* Reads n_max, n_min, p_min and p_split from obj, passing the current draft values as the fallbacks. */ + draft.n_max = getPropertyAsInt(runtime, obj, "n_max", draft.n_max); + draft.n_min = getPropertyAsInt(runtime, obj, "n_min", draft.n_min); + draft.p_min = getPropertyAsFloat(runtime, obj, "p_min", draft.p_min); + draft.p_split = getPropertyAsFloat(runtime, obj, "p_split", draft.p_split); + } + + bool hasSpeculativeType(const common_params_speculative& speculative, common_speculative_type type) { /* True when type occurs in speculative.types. */ + return std::find(speculative.types.begin(), speculative.types.end(), type) != speculative.types.end(); + } + + static void applySpeculativeTypeNames( + common_params_speculative& speculative, + const std::vector& typeNames + ) { /* An empty name list leaves the configured types untouched. */ + if (typeNames.empty()) { + return; + } + speculative.types = common_speculative_types_from_names(typeNames); + } + + static void applySpeculativeOptions(jsi::Runtime& runtime, const jsi::Object& params, common_params& cparams) { /* Gathers type names from spec_type and speculative (bool, string or object form), applies the draft tuning keys, then throws std::invalid_argument when MTP is selected with n_max <= 0. */ + std::vector typeNames; + + if (params.hasProperty(runtime, "spec_type")) { + addSpeculativeTypeNamesFromValue(runtime, params.getProperty(runtime, "spec_type"), typeNames); + } + + if (params.hasProperty(runtime, "speculative")) { + auto value = params.getProperty(runtime, "speculative"); + if (!isNil(value)) { + if (value.isBool()) { +
if (value.getBool()) { addSpeculativeTypeName(typeNames, "draft-mtp"); } else { typeNames.assign(1, "none"); } /* false replaces any names gathered from spec_type, so the request is really disabled */ + } else if (value.isString()) { + addSpeculativeTypeName(typeNames, value.asString(runtime).utf8(runtime)); + } else if (value.isObject()) { + auto speculative = value.asObject(runtime); + bool enabled = false; + bool hasEnabled = false; + bool hasExplicitType = false; + + if (speculative.hasProperty(runtime, "enabled")) { + auto enabledValue = speculative.getProperty(runtime, "enabled"); + if (enabledValue.isBool()) { + enabled = enabledValue.getBool(); + hasEnabled = true; + } + } + + if (speculative.hasProperty(runtime, "type")) { + const size_t oldSize = typeNames.size(); + addSpeculativeTypeNamesFromValue(runtime, speculative.getProperty(runtime, "type"), typeNames); + hasExplicitType = hasExplicitType || typeNames.size() != oldSize; + } + + if (speculative.hasProperty(runtime, "types")) { + const size_t oldSize = typeNames.size(); + addSpeculativeTypeNamesFromValue(runtime, speculative.getProperty(runtime, "types"), typeNames); + hasExplicitType = hasExplicitType || typeNames.size() != oldSize; + } + + if (hasEnabled) { + if (!enabled) { + typeNames.assign(1, "none"); /* enabled: false wins over spec_type/type/types; merely appending none would leave draft-mtp in the list */ + } else if (!hasExplicitType) { + addSpeculativeTypeName(typeNames, "draft-mtp"); + } + } + + applySpeculativeDraftOptions(runtime, speculative, cparams.speculative.draft); + if (speculative.hasProperty(runtime, "draft")) { + auto draftValue = speculative.getProperty(runtime, "draft"); + if (draftValue.isObject()) { + applySpeculativeDraftOptions(runtime, draftValue.asObject(runtime), cparams.speculative.draft); + } + } + } + } + } + + cparams.speculative.draft.n_max = getPropertyAsInt( /* flat keys are read after the nested object, so they take precedence when both are given */ + runtime, params, "spec_draft_n_max", cparams.speculative.draft.n_max); + cparams.speculative.draft.n_max = getPropertyAsInt( + runtime, params, "speculative.n_max", cparams.speculative.draft.n_max); + cparams.speculative.draft.n_min = getPropertyAsInt( + runtime, params, "spec_draft_n_min", cparams.speculative.draft.n_min); + cparams.speculative.draft.n_min = 
getPropertyAsInt( + runtime, params, "speculative.n_min", cparams.speculative.draft.n_min); + cparams.speculative.draft.p_min = getPropertyAsFloat( + runtime, params, "spec_draft_p_min", cparams.speculative.draft.p_min); + cparams.speculative.draft.p_min = getPropertyAsFloat( + runtime, params, "speculative.p_min", cparams.speculative.draft.p_min); + cparams.speculative.draft.p_split = getPropertyAsFloat( + runtime, params, "spec_draft_p_split", cparams.speculative.draft.p_split); + cparams.speculative.draft.p_split = getPropertyAsFloat( + runtime, params, "speculative.p_split", cparams.speculative.draft.p_split); + + applySpeculativeTypeNames(cparams.speculative, typeNames); + + if (hasSpeculativeType(cparams.speculative, COMMON_SPECULATIVE_TYPE_DRAFT_MTP) && + cparams.speculative.draft.n_max <= 0) { + throw std::invalid_argument("MTP requires spec_draft_n_max > 0"); + } + } + void parseCommonParams(jsi::Runtime& runtime, const jsi::Object& params, common_params& cparams) { cparams.fit_params = false; @@ -134,6 +302,10 @@ namespace rnllama_jsi { std::string cpuMask = getPropertyAsString(runtime, params, "cpu_mask"); #if defined(__ANDROID__) set_best_cores(cparams.cpuparams, cparams.cpuparams.n_threads); +#elif defined(__APPLE__) + if (cparams.cpuparams.n_threads < 0) { + cparams.cpuparams.n_threads = default_apple_n_threads(); + } #endif cparams.n_gpu_layers = getPropertyAsInt(runtime, params, "n_gpu_layers", cparams.n_gpu_layers); @@ -231,6 +403,8 @@ namespace rnllama_jsi { } } } + + applySpeculativeOptions(runtime, params, cparams); } void parseCompletionParams(jsi::Runtime& runtime, const jsi::Object& params, rnllama::llama_rn_context* ctx) { @@ -242,6 +416,7 @@ namespace rnllama_jsi { sparams.seed = getPropertyAsInt(runtime, params, "seed", -1); ctx->params.n_predict = getPropertyAsInt(runtime, params, "n_predict", ctx->params.n_predict); ctx->params.sampling.ignore_eos = getPropertyAsBool(runtime, params, "ignore_eos", ctx->params.sampling.ignore_eos); +
applySpeculativeOptions(runtime, params, ctx->params); sparams.temp = getPropertyAsDouble(runtime, params, "temperature", sparams.temp); sparams.n_probs = getPropertyAsInt(runtime, params, "n_probs", sparams.n_probs); diff --git a/cpp/jsi/JSIParams.h b/cpp/jsi/JSIParams.h index 696e487a..236106df 100644 --- a/cpp/jsi/JSIParams.h +++ b/cpp/jsi/JSIParams.h @@ -12,6 +12,7 @@ namespace rnllama_jsi { bool getPropertyAsBool(jsi::Runtime& runtime, const jsi::Object& obj, const char* name, bool defaultValue = false); float getPropertyAsFloat(jsi::Runtime& runtime, const jsi::Object& obj, const char* name, float defaultValue = 0.0f); + bool hasSpeculativeType(const common_params_speculative& speculative, common_speculative_type type); void parseCommonParams(jsi::Runtime& runtime, const jsi::Object& params, common_params& cparams); void parseCompletionParams(jsi::Runtime& runtime, const jsi::Object& params, rnllama::llama_rn_context* ctx); } diff --git a/cpp/jsi/RNLlamaJSI.cpp b/cpp/jsi/RNLlamaJSI.cpp index aa6f723c..5110ccbc 100644 --- a/cpp/jsi/RNLlamaJSI.cpp +++ b/cpp/jsi/RNLlamaJSI.cpp @@ -1064,6 +1064,10 @@ namespace rnllama_jsi { ctx->tts_wrapper->setGuideTokens(guide_tokens); } + if (!mediaPaths.empty() && ctx->completion->shouldUseMTP()) { + throw std::runtime_error("MTP speculative decoding currently supports text-only completion"); + } + if (!ctx->completion->initSampling()) { throw std::runtime_error("Failed to initialize sampling"); } @@ -1273,6 +1277,10 @@ namespace rnllama_jsi { } } + if (hasSpeculativeType(cparams.speculative, COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) { + throw std::runtime_error("MTP speculative decoding is not supported for queued parallel completions"); + } + int chat_format = getPropertyAsInt(runtime, params, "chat_format", 0); std::string reasoningFormatStr = getPropertyAsString(runtime, params, "reasoning_format", "none"); common_reasoning_format reasoning_format = common_reasoning_format_from_name(reasoningFormatStr); diff --git 
a/cpp/rn-completion.cpp b/cpp/rn-completion.cpp index 70ec2e6f..90d7e93c 100644 --- a/cpp/rn-completion.cpp +++ b/cpp/rn-completion.cpp @@ -5,7 +5,9 @@ #include "rn-common.hpp" #include +#include #include +#include // Include multimodal support #include "tools/mtmd/mtmd.h" @@ -21,6 +23,7 @@ llama_rn_context_completion::llama_rn_context_completion(llama_rn_context* paren // Destructor llama_rn_context_completion::~llama_rn_context_completion() { + resetSpeculative(); if (ctx_sampling != nullptr) { common_sampler_free(ctx_sampling); ctx_sampling = nullptr; @@ -28,6 +31,7 @@ llama_rn_context_completion::~llama_rn_context_completion() { } void llama_rn_context_completion::rewind() { + resetSpeculative(); is_interrupted = false; parent_ctx->params.antiprompt.clear(); parent_ctx->params.sampling.grammar = {}; @@ -37,6 +41,8 @@ void llama_rn_context_completion::rewind() { parent_ctx->params.sampling.generation_prompt.clear(); num_prompt_tokens = 0; num_tokens_predicted = 0; + num_draft_tokens = 0; + num_draft_tokens_accepted = 0; prefill_text = ""; generated_text = ""; generated_text.reserve(parent_ctx->params.n_ctx); @@ -235,8 +241,274 @@ void llama_rn_context_completion::endCompletion() { is_predicting = false; } +bool llama_rn_context_completion::shouldUseMTP() const { /* MTP is active only when DRAFT_MTP is among the configured types and draft.n_max is positive. */ + const auto & types = parent_ctx->params.speculative.types; + return std::find(types.begin(), types.end(), COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != types.end() && + parent_ctx->params.speculative.draft.n_max > 0; +} + +void llama_rn_context_completion::resetSpeculative() { /* Releases spec, the draft context and the batch (each only if allocated) and clears all MTP bookkeeping. */ + if (spec != nullptr) { + common_speculative_free(spec); + spec = nullptr; + } + spec_ctx.reset(); + if (spec_batch_initialized) { + llama_batch_free(spec_batch); + spec_batch = {}; + spec_batch_initialized = false; + } + spec_prompt.clear(); + spec_id_last = LLAMA_TOKEN_NULL; + spec_n_past = 0; + spec_draft.clear(); + spec_pending_tokens.clear(); +} + +void llama_rn_context_completion::initMTP() { /* No-op unless shouldUseMTP(); otherwise validates model and prompt, builds the draft context and speculative state, clears both memories and evaluates the prompt. Throws std::runtime_error on failure. */ + if (!shouldUseMTP()) { + return; + } 
+ if (llama_model_has_encoder(parent_ctx->model)) { + throw std::runtime_error("MTP speculative decoding is only supported for decoder-only models"); + } + if (embd.empty()) { + throw std::runtime_error("MTP speculative decoding requires a non-empty prompt"); + } + + const auto n_mtp = parent_ctx->params.speculative.draft.n_max; + if ((llama_model_is_recurrent(parent_ctx->model) || llama_model_is_hybrid(parent_ctx->model)) && + llama_n_rs_seq(parent_ctx->ctx) < (uint32_t) n_mtp) { + throw std::runtime_error( + "MTP for recurrent or hybrid models must be enabled when loading the model " + "with speculative.type='draft-mtp' and speculative.n_max/spec_draft_n_max set"); + } + + resetSpeculative(); /* start from a clean slate before allocating new draft state */ + + auto cparams = common_context_params_to_llama(parent_ctx->params); + cparams.ctx_type = LLAMA_CONTEXT_TYPE_MTP; + cparams.n_rs_seq = 0; + + spec_ctx.reset(llama_init_from_model(parent_ctx->model, cparams)); + if (spec_ctx == nullptr) { + throw std::runtime_error("failed to create MTP draft context"); + } + + parent_ctx->params.speculative.draft.ctx_tgt = parent_ctx->ctx; + parent_ctx->params.speculative.draft.ctx_dft = spec_ctx.get(); + + spec = common_speculative_init(parent_ctx->params.speculative, 1); + if (spec == nullptr) { + throw std::runtime_error("failed to initialize MTP speculative decoding"); + } + + spec_batch = llama_batch_init(llama_n_batch(parent_ctx->ctx), 0, 1); + spec_batch_initialized = true; + + llama_memory_clear(llama_get_memory(parent_ctx->ctx), false); /* both memories are cleared, so the prompt is evaluated again from position 0 */ + llama_memory_clear(llama_get_memory(spec_ctx.get()), false); + n_past = 0; + + evalMTPPrompt(); +} + +void llama_rn_context_completion::evalMTPPrompt() { /* Decodes every prompt token except the last on the target context in chunks of at most n_batch; the last token is held in spec_id_last. Reads embd.back(), so embd must not be empty. */ + const llama_seq_id seq_id = 0; + const size_t n_prompt = embd.size(); + + spec_prompt.clear(); + spec_pending_tokens.clear(); + spec_draft.clear(); + spec_id_last = embd.back(); + + if (n_prompt > 1) { + spec_prompt.assign(embd.begin(), embd.end() - 1); + } + + const int32_t n_batch = std::max(1, llama_n_batch(parent_ctx->ctx)); + size_t 
offset = 0; + + while (offset < spec_prompt.size()) { + common_batch_clear(spec_batch); + + const size_t n_eval = std::min(n_batch, spec_prompt.size() - offset); + for (size_t i = 0; i < n_eval; ++i) { + // MTP consumes pre-norm embeddings from every target row, but prompt logits are unused. + // Keep one output row per decode batch to preserve the usual llama.cpp graph shape. + const bool needs_logits = i + 1 == n_eval; + common_batch_add(spec_batch, spec_prompt[offset + i], + (llama_pos) (offset + i), { seq_id }, needs_logits); + } + + const int ret = llama_decode(parent_ctx->ctx, spec_batch); + if (ret != 0) { + throw std::runtime_error("failed to evaluate MTP prompt batch, ret=" + std::to_string(ret)); + } + if (!common_speculative_process(spec, spec_batch)) { + throw std::runtime_error("failed to process MTP prompt batch"); + } + + offset += n_eval; + } + + spec_n_past = (llama_pos) spec_prompt.size(); + n_past = spec_n_past; + + common_speculative_begin(spec, seq_id, spec_prompt); +} + +bool llama_rn_context_completion::refillMTPTokens() { /* One draft-and-verify step; returns true when at least one token was queued in spec_pending_tokens. */ + const llama_seq_id seq_id = 0; + + if (spec_id_last == LLAMA_TOKEN_NULL || stopped_eos || stopped_limit || context_full) { + return false; + } + if (parent_ctx->params.n_predict >= 0 && n_remain == 0) { + stopped_limit = true; + has_next_token = false; + return false; + } + + const int32_t n_ctx = parent_ctx->params.n_ctx; + if (spec_n_past + 1 >= n_ctx) { + context_full = true; + has_next_token = false; + return false; + } + + spec_draft.clear(); + + const int32_t remaining = + parent_ctx->params.n_predict < 0 ? std::numeric_limits::max() : (int32_t) n_remain; + const int32_t n_draft_remaining = remaining == std::numeric_limits::max() + ? 
parent_ctx->params.speculative.draft.n_max + : std::max(0, remaining - 1); + const int32_t n_draft_ctx = std::max(0, n_ctx - (int32_t) spec_n_past - 1); + const int32_t n_draft_batch = std::max(0, llama_n_batch(parent_ctx->ctx) - 1); + const int32_t n_draft_limit = std::min( /* draft length is capped by n_max, the remaining token budget, the free context and the batch size, each minus the one slot taken by spec_id_last */ + parent_ctx->params.speculative.draft.n_max, + std::min(n_draft_remaining, std::min(n_draft_ctx, n_draft_batch))); + + if (n_draft_limit > 0) { + common_speculative_get_draft_params(spec, seq_id) = { + /* .drafting = */ true, + /* .n_max = */ n_draft_limit, + /* .n_past = */ spec_n_past, + /* .id_last = */ spec_id_last, + /* .prompt = */ &spec_prompt, + /* .result = */ &spec_draft, + }; + common_speculative_draft(spec); + + if ((int32_t) spec_draft.size() > n_draft_limit) { + spec_draft.resize(n_draft_limit); + } + + common_context_seq_rm(spec_ctx.get(), seq_id, spec_n_past, -1); + } + + const size_t n_draft = spec_draft.size(); + num_draft_tokens += n_draft; + + common_batch_clear(spec_batch); /* the target batch is spec_id_last followed by the draft, every row requesting logits */ + common_batch_add(spec_batch, spec_id_last, spec_n_past, { seq_id }, true); + for (size_t i = 0; i < n_draft; ++i) { + common_batch_add(spec_batch, spec_draft[i], + spec_n_past + (llama_pos) i + 1, { seq_id }, true); + } + + const int ret = llama_decode(parent_ctx->ctx, spec_batch); + if (ret != 0) { + throw std::runtime_error("failed to evaluate MTP target batch, ret=" + std::to_string(ret)); + } + if (!common_speculative_process(spec, spec_batch)) { + throw std::runtime_error("failed to process MTP target batch"); + } + + auto accepted = common_sampler_sample_and_accept_n(ctx_sampling, parent_ctx->ctx, spec_draft); + if (accepted.empty()) { + return false; + } + + size_t accepted_count = accepted.size(); + bool saw_eos = false; + const llama_vocab* vocab = llama_model_get_vocab(parent_ctx->model); + for (size_t i = 0; i < accepted.size(); ++i) { + if (llama_vocab_is_eog(vocab, accepted[i])) { + accepted_count = i + 1; + saw_eos = true; /* the end-of-generation token is counted in accepted_count but never queued for output */ + break; + } + + completion_token_output output; + output.tok = 
accepted[i]; + output.text = common_token_to_piece(parent_ctx->ctx, accepted[i]); + spec_pending_tokens.push_back(std::move(output)); + } + + const size_t n_accepted_draft = saw_eos + ? accepted_count - 1 + : accepted.size() - 1; + if (n_draft > 0) { + const size_t n_accepted = std::min(n_accepted_draft, n_draft); + num_draft_tokens_accepted += n_accepted; + common_speculative_accept(spec, seq_id, (uint16_t) n_accepted); + } + + for (size_t i = 0; i < accepted_count; ++i) { + spec_prompt.push_back(spec_id_last); + spec_id_last = accepted[i]; + } + + spec_n_past += (llama_pos) accepted_count; + n_past = spec_n_past; + + common_context_seq_rm(parent_ctx->ctx, seq_id, spec_n_past, -1); /* NOTE(review): presumably trims rejected draft positions from both contexts - confirm common_context_seq_rm semantics */ + common_context_seq_rm(spec_ctx.get(), seq_id, spec_n_past, -1); + + if (saw_eos) { + stopped_eos = true; + has_next_token = false; + } + + if (parent_ctx->params.n_predict >= 0) { + const size_t emitted = spec_pending_tokens.size(); /* the budget is charged when tokens are queued, not when they are popped */ + n_remain = emitted >= n_remain ? 0 : n_remain - emitted; + if (n_remain == 0 && !saw_eos) { + stopped_limit = true; + has_next_token = false; + } + } + + return !spec_pending_tokens.empty(); +} + +completion_token_output llama_rn_context_completion::nextTokenMTP() { /* Returns the next queued token, initializing MTP on first use and refilling when the queue is empty; tok stays -1 when nothing was produced. */ + completion_token_output result; + result.tok = -1; + + if (spec == nullptr) { + initMTP(); + } + + if (spec_pending_tokens.empty() && !refillMTPTokens()) { + return result; + } + + result = std::move(spec_pending_tokens.front()); + spec_pending_tokens.pop_front(); + num_tokens_predicted++; + has_next_token = !spec_pending_tokens.empty() || (!stopped_eos && !stopped_limit && !context_full); + return result; +} + completion_token_output llama_rn_context_completion::nextToken() { + if (shouldUseMTP()) { + return nextTokenMTP(); + } + completion_token_output result; result.tok = -1; diff --git a/cpp/rn-completion.h b/cpp/rn-completion.h index 86b447c2..bb5992ca 100644 --- a/cpp/rn-completion.h +++ b/cpp/rn-completion.h @@ -6,6 +6,8 @@ #include "sampling.h" #include "nlohmann/json.hpp" #include 
"chat.h" +#include "speculative.h" +#include using json = nlohmann::ordered_json; @@ -60,6 +62,8 @@ struct llama_rn_context_completion { std::string prefill_text; std::string generated_text; std::vector generated_token_probs; + size_t num_draft_tokens = 0; + size_t num_draft_tokens_accepted = 0; size_t num_prompt_tokens = 0; size_t num_tokens_predicted = 0; llama_pos n_past = 0; @@ -81,6 +85,17 @@ struct llama_rn_context_completion { // Sampling context common_sampler *ctx_sampling = nullptr; + // Speculative decoding context for MTP. + common_speculative *spec = nullptr; + llama_context_ptr spec_ctx; + llama_batch spec_batch = {}; + bool spec_batch_initialized = false; + llama_tokens spec_prompt; + llama_token spec_id_last = LLAMA_TOKEN_NULL; + llama_pos spec_n_past = 0; + llama_tokens spec_draft; + std::deque spec_pending_tokens; + // Constructor llama_rn_context_completion(llama_rn_context* parent); @@ -96,6 +111,12 @@ struct llama_rn_context_completion { void beginCompletion(int chat_format, common_reasoning_format reasoning_format, const std::string &generation_prompt = "", const std::string &chat_parser = ""); void endCompletion(); completion_token_output nextToken(); + bool shouldUseMTP() const; + void resetSpeculative(); + void initMTP(); + void evalMTPPrompt(); + bool refillMTPTokens(); + completion_token_output nextTokenMTP(); size_t findStoppingStrings(const std::string &text, const size_t last_token_size, const stop_type type); completion_token_output doCompletion(); completion_chat_output parseChatOutput(bool is_partial); diff --git a/example/ios/Podfile b/example/ios/Podfile index 24d5ca85..5b381274 100644 --- a/example/ios/Podfile +++ b/example/ios/Podfile @@ -17,6 +17,7 @@ end ENV['RCT_NEW_ARCH_ENABLED'] = '1' ENV['RNLLAMA_BUILD_FROM_SOURCE'] = ENV['RNLLAMA_BUILD_FROM_SOURCE'] || '1' +ENV['RNLLAMA_NATIVE_CPU'] = ENV['RNLLAMA_NATIVE_CPU'] || '1' target 'RNLlamaExample' do config = use_native_modules! 
diff --git a/example/src/__tests__/screenRegistry.test.ts b/example/src/__tests__/screenRegistry.test.ts index 6e18bf1a..1e32f469 100644 --- a/example/src/__tests__/screenRegistry.test.ts +++ b/example/src/__tests__/screenRegistry.test.ts @@ -8,6 +8,7 @@ describe('example screen registry', () => { 'SimpleChat', 'TextCompletion', 'StructuredOutput', + 'MTPSpeculative', 'ParallelDecoding', 'Multimodal', 'ToolCalling', diff --git a/example/src/components/ExampleModelSetup.tsx b/example/src/components/ExampleModelSetup.tsx index 385d316a..7b78f408 100644 --- a/example/src/components/ExampleModelSetup.tsx +++ b/example/src/components/ExampleModelSetup.tsx @@ -36,6 +36,7 @@ interface ExampleModelSetupProps { defaultModelSectionTitle?: string customModelSectionTitle?: string addCustomModelLabel?: string + defaultModelsFirst?: boolean isLoading?: boolean initProgress?: number progressText?: string @@ -59,6 +60,7 @@ export function ExampleModelSetup({ defaultModelSectionTitle = 'Default Models', customModelSectionTitle = 'Custom Models', addCustomModelLabel = '+ Add Custom Model', + defaultModelsFirst = false, isLoading = false, initProgress = 0, progressText = '', @@ -67,101 +69,121 @@ export function ExampleModelSetup({ const { theme } = useTheme() const themedStyles = createThemedStyles(theme.colors) - return ( - - - {description} - {children} - - {customModels.length > 0 && onInitializeCustomModel && ( - <> - - {customModelSectionTitle} - - {customModels.map((model) => ( - - onInitializeCustomModel(model, modelPath, mmprojPath) - } - onModelRemoved={async () => { - if (onReloadCustomModels) { - await onReloadCustomModels() - } - }} - initializeButtonText={ - defaultModels[0]?.initializeButtonText || 'Initialize' - } - /> - ))} - - )} - - {onOpenCustomModelModal && onCloseCustomModelModal && ( - - - {addCustomModelLabel} - - - )} - + const customModelSection = + customModels.length > 0 && onInitializeCustomModel ? 
( + <> - {defaultModelSectionTitle} + {customModelSectionTitle} - {defaultModels.map((model) => { - if (model.kind === 'multimodal') { - return ( - - onInitializeModel(model, modelPath, mmprojPath) - } - /> - ) - } + {customModels.map((model) => ( + + onInitializeCustomModel(model, modelPath, mmprojPath) + } + onModelRemoved={async () => { + if (onReloadCustomModels) { + await onReloadCustomModels() + } + }} + initializeButtonText={ + defaultModels[0]?.initializeButtonText || 'Initialize' + } + /> + ))} + + ) : null - if (model.kind === 'tts') { - return ( - - onInitializeModel(model, ttsPath, vocoderPath) - } - /> - ) - } + const customModelButton = + onOpenCustomModelModal && onCloseCustomModelModal ? ( + + + {addCustomModelLabel} + + + ) : null + + const defaultModelSection = ( + <> + + {defaultModelSectionTitle} + + {defaultModels.map((model) => { + if (model.kind === 'multimodal') { + return ( + + onInitializeModel(model, modelPath, mmprojPath) + } + /> + ) + } + if (model.kind === 'tts') { return ( - onInitializeModel(model, modelPath)} + onInitialize={(ttsPath, vocoderPath) => + onInitializeModel(model, ttsPath, vocoderPath) + } /> ) - })} + } + + return ( + onInitializeModel(model, modelPath)} + /> + ) + })} + + ) + + return ( + + + {description} + {children} + + {defaultModelsFirst ? 
( + <> + {defaultModelSection} + {customModelSection} + {customModelButton} + + ) : ( + <> + {customModelSection} + {customModelButton} + {defaultModelSection} + + )} {onOpenCustomModelModal && onCloseCustomModelModal && ( diff --git a/example/src/config/screenMetadata.ts b/example/src/config/screenMetadata.ts index 906a0119..09c9e14d 100644 --- a/example/src/config/screenMetadata.ts +++ b/example/src/config/screenMetadata.ts @@ -21,6 +21,12 @@ export const EXAMPLE_SCREEN_METADATA: Array< homeLabel: 'Structured Output', emoji: '🧾', }, + { + routeName: 'MTPSpeculative', + title: 'MTP Speculative Decoding', + homeLabel: 'MTP Speculative Decoding', + emoji: '🚀', + }, { routeName: 'ParallelDecoding', title: 'Parallel Decoding', diff --git a/example/src/config/screens.ts b/example/src/config/screens.ts index f0eaa9f4..eae0c6d8 100644 --- a/example/src/config/screens.ts +++ b/example/src/config/screens.ts @@ -6,6 +6,7 @@ import ModelInfoScreen from '../screens/ModelInfoScreen' import BenchScreen from '../screens/BenchScreen' import TextCompletionScreen from '../screens/TextCompletionScreen' import StructuredOutputScreen from '../screens/StructuredOutputScreen' +import MTPSpeculativeScreen from '../screens/MTPSpeculativeScreen' import ParallelDecodingScreen from '../screens/ParallelDecodingScreen' import EmbeddingScreen from '../screens/EmbeddingScreen' import StressTestScreen from '../screens/StressTestScreen' @@ -22,6 +23,7 @@ const SCREEN_COMPONENTS: Record< SimpleChat: SimpleChatScreen, TextCompletion: TextCompletionScreen, StructuredOutput: StructuredOutputScreen, + MTPSpeculative: MTPSpeculativeScreen, ParallelDecoding: ParallelDecodingScreen, Multimodal: MultimodalScreen, ToolCalling: ToolCallsScreen, diff --git a/example/src/screens/MTPSpeculativeScreen.tsx b/example/src/screens/MTPSpeculativeScreen.tsx new file mode 100644 index 00000000..44c44efb --- /dev/null +++ b/example/src/screens/MTPSpeculativeScreen.tsx @@ -0,0 +1,673 @@ +import React, {
useCallback, useMemo, useRef, useState } from 'react' +import { + Alert, + ScrollView, + StyleSheet, + Text, + TextInput, + TouchableOpacity, + View, +} from 'react-native' +import { useSafeAreaInsets } from 'react-native-safe-area-context' +import { initLlama } from '../../../src' +import type { NativeCompletionResult } from '../../../src' +import ContextParamsModal from '../components/ContextParamsModal' +import { ExampleModelSetup } from '../components/ExampleModelSetup' +import { + ParameterSwitch, + ParameterTextInput, +} from '../components/ParameterFormFields' +import { useTheme } from '../contexts/ThemeContext' +import { useExampleContext } from '../hooks/useExampleContext' +import { useExampleScreenHeader } from '../hooks/useExampleScreenHeader' +import { + useStoredContextParams, + useStoredCustomModels, +} from '../hooks/useStoredSetting' +import { createThemedStyles, Spacing } from '../styles/commonStyles' +import type { ContextParams } from '../utils/storage' +import { createExampleModelDefinitions } from '../utils/exampleModels' + +const DEFAULT_PROMPT = + 'Write a concise TypeScript function that groups an array of objects by a key.' 
+ +const DEFAULT_DRAFT_TOKENS = 3 +const MAX_DRAFT_TOKENS = 32 +const DEFAULT_MAX_TOKENS = 128 +const MTP_CONTEXT = 4096 +const MTP_BATCH = 1024 +const MTP_UBATCH = 512 +const OUTPUT_FLUSH_INTERVAL_MS = 250 + +const MTP_MODELS = createExampleModelDefinitions( + ['QWEN_3_5_4B_MTP', 'QWEN_3_6_35B_A3B_MTP'], + 'Initialize MTP Model', +) + +type MTPRunMetrics = { + predicted: number + drafted: number + accepted: number + acceptRate: number + wallSeconds: number + tokensPerSecond: number +} + +function parseBoundedInteger( + value: string, + fallback: number, + min: number, + max: number, +) { + const parsed = Number.parseInt(value, 10) + if (Number.isNaN(parsed)) return fallback + return Math.max(min, Math.min(max, parsed)) +} + +function createMTPRunMetrics( + result: NativeCompletionResult, + wallSeconds: number, +): MTPRunMetrics { + const predicted = result.tokens_predicted || 0 + const drafted = result.draft_tokens || 0 + const accepted = result.draft_tokens_accepted || 0 + return { + predicted, + drafted, + accepted, + acceptRate: drafted > 0 ? accepted / drafted : 0, + wallSeconds, + tokensPerSecond: wallSeconds > 0 ? 
predicted / wallSeconds : 0, + } +} + +function logMTPMetrics(metrics: MTPRunMetrics) { + console.log( + [ + 'MTP metrics:', + ` predicted: ${metrics.predicted}`, + ` drafted: ${metrics.drafted}`, + ` accepted: ${metrics.accepted}`, + ` accept_rate: ${metrics.acceptRate.toFixed(3)}`, + ` wall_seconds: ${metrics.wallSeconds.toFixed(2)}`, + ` tokens_per_second: ${metrics.tokensPerSecond.toFixed(2)}`, + ].join('\n'), + ) +} + +export default function MTPSpeculativeScreen({ + navigation, +}: { + navigation: any +}) { + const { theme } = useTheme() + const themedStyles = createThemedStyles(theme.colors) + const styles = createStyles(theme) + const insets = useSafeAreaInsets() + const [prompt, setPrompt] = useState(DEFAULT_PROMPT) + const [output, setOutput] = useState('') + const [isLoading, setIsLoading] = useState(false) + const [isGenerating, setIsGenerating] = useState(false) + const [showContextParamsModal, setShowContextParamsModal] = useState(false) + const [showCustomModelModal, setShowCustomModelModal] = useState(false) + const [draftTokensText, setDraftTokensText] = useState( + DEFAULT_DRAFT_TOKENS.toString(), + ) + const [maxTokensText, setMaxTokensText] = useState( + DEFAULT_MAX_TOKENS.toString(), + ) + const [isMTPEnabled, setIsMTPEnabled] = useState(true) + const [draftCapacity, setDraftCapacity] = useState(DEFAULT_DRAFT_TOKENS) + const [lastResult, setLastResult] = useState( + null, + ) + const [lastRunMetrics, setLastRunMetrics] = useState( + null, + ) + const outputBufferRef = useRef('') + const lastOutputFlushAtRef = useRef(0) + const { + context, + initProgress, + isModelReady, + replaceContext, + setInitProgress, + } = useExampleContext() + const { value: contextParams, setValue: setContextParams } = + useStoredContextParams() + const { value: customModels, reload: reloadCustomModels } = + useStoredCustomModels() + + const draftTokens = useMemo( + () => + parseBoundedInteger( + draftTokensText, + DEFAULT_DRAFT_TOKENS, + 1, + MAX_DRAFT_TOKENS, + ), + 
[draftTokensText], + ) + const maxTokens = useMemo( + () => parseBoundedInteger(maxTokensText, DEFAULT_MAX_TOKENS, 1, 4096), + [maxTokensText], + ) + const displayedMetrics = useMemo(() => { + if (!lastResult) return null + return lastRunMetrics ?? createMTPRunMetrics(lastResult, 0) + }, [lastResult, lastRunMetrics]) + + const handleReset = useCallback(async () => { + outputBufferRef.current = '' + lastOutputFlushAtRef.current = 0 + setOutput('') + setLastResult(null) + setLastRunMetrics(null) + setPrompt(DEFAULT_PROMPT) + if (context) { + await context.clearCache(false) + } + }, [context]) + + useExampleScreenHeader({ + navigation, + isModelReady, + readyActions: [ + { + key: 'reset', + iconName: 'refresh', + onPress: handleReset, + }, + ], + setupActions: [ + { + key: 'context-settings', + iconName: 'cog-outline', + onPress: () => setShowContextParamsModal(true), + }, + ], + }) + + const handleSaveContextParams = (params: ContextParams) => { + setContextParams(params) + } + + const handleInitModel = async ( + modelUri: string, + params?: ContextParams, + ) => { + setIsLoading(true) + setInitProgress(0) + + try { + const baseParams = params || contextParams || {} + const initDraftTokens = draftTokens + const ctx = await initLlama( + { + ...baseParams, + model: modelUri, + use_mlock: false, + use_mmap: true, + n_ctx: MTP_CONTEXT, + n_batch: MTP_BATCH, + n_ubatch: MTP_UBATCH, + n_parallel: 1, + n_gpu_layers: baseParams.n_gpu_layers ?? 
99, + flash_attn_type: 'auto', + cache_type_k: 'q8_0', + cache_type_v: 'q8_0', + ctx_shift: true, + kv_unified: false, + swa_full: false, + no_extra_bufts: false, + speculative: { + type: 'draft-mtp', + n_max: initDraftTokens, + }, + spec_draft_n_max: initDraftTokens, + }, + (progress) => { + setInitProgress(progress) + }, + ) + + await replaceContext(ctx) + console.log( + [ + 'MTP context:', + ` devices: ${ctx.devices?.join(', ') || 'N/A'}`, + ` system_info: ${ctx.systemInfo}`, + ].join('\n'), + ) + setOutput('') + setLastResult(null) + setLastRunMetrics(null) + setDraftCapacity(initDraftTokens) + setInitProgress(100) + } catch (error) { + Alert.alert('Error', `Failed to initialize MTP model: ${error}`) + } finally { + setIsLoading(false) + setInitProgress(0) + } + } + + const handleGenerate = async () => { + if (!context) { + Alert.alert('Error', 'Initialize a model before generating.') + return + } + + const trimmedPrompt = prompt.trim() + if (!trimmedPrompt) { + Alert.alert('Error', 'Please enter a prompt.') + return + } + if (isMTPEnabled && draftTokens > draftCapacity) { + Alert.alert( + 'Error', + `Draft Tokens cannot exceed the initialized MTP capacity (${draftCapacity}). Reinitialize the model to use a larger value.`, + ) + return + } + + setIsGenerating(true) + outputBufferRef.current = '' + lastOutputFlushAtRef.current = Date.now() + setOutput('') + setLastResult(null) + setLastRunMetrics(null) + + try { + const startedAt = Date.now() + const result = await context.completion( + { + messages: [ + { + role: 'user', + content: trimmedPrompt, + }, + ], + chat_template_kwargs: { + preserve_thinking: true, + }, + n_predict: maxTokens, + temperature: 0.6, + top_k: 20, + top_p: 0.95, + speculative: isMTPEnabled + ? { + type: 'draft-mtp', + n_max: draftTokens, + } + : false, + spec_draft_n_max: isMTPEnabled ? 
draftTokens : 0, + }, + (tokenData) => { + outputBufferRef.current += tokenData.token + const now = Date.now() + if (now - lastOutputFlushAtRef.current >= OUTPUT_FLUSH_INTERVAL_MS) { + lastOutputFlushAtRef.current = now + setOutput(outputBufferRef.current) + } + }, + ) + const elapsedSeconds = (Date.now() - startedAt) / 1000 + const metrics = createMTPRunMetrics(result, elapsedSeconds) + logMTPMetrics(metrics) + + setLastResult(result) + setLastRunMetrics(metrics) + const finalText = result.content || result.text + setOutput(finalText || outputBufferRef.current) + } catch (error) { + setOutput(outputBufferRef.current) + if (error !== 'aborted') { + Alert.alert('Error', `Failed to generate: ${error}`) + } + } finally { + setIsGenerating(false) + } + } + + const handleStop = async () => { + if (!context) return + try { + await context.stopCompletion() + } catch (error) { + console.warn('Failed to stop completion:', error) + } + } + + if (!isModelReady) { + return ( + <> + + handleInitModel(modelPath) + } + onInitializeModel={(_model, modelPath) => handleInitModel(modelPath)} + onReloadCustomModels={reloadCustomModels} + showCustomModelModal={showCustomModelModal} + onOpenCustomModelModal={() => setShowCustomModelModal(true)} + onCloseCustomModelModal={() => setShowCustomModelModal(false)} + isLoading={isLoading} + initProgress={initProgress} + progressText={`Initializing MTP model... ${initProgress}%`} + > + + Requirements + + MTP works only with text-only models that include draft prediction + layers. Multimodal prompts are intentionally disabled for this + demo. + + + + + setShowContextParamsModal(false)} + onSave={handleSaveContextParams} + /> + + ) + } + + return ( + + + + Prompt + + + + + + + + + + + + + + + {isGenerating ? ( + + Stop Generation + + ) : ( + + Generate + + )} + + + Output + + {output || 'Generated text will appear here.'} + + + + {lastResult && displayedMetrics && ( + + MTP Metrics + + + + + + 0 + ? 
`${displayedMetrics.wallSeconds.toFixed(2)} s` + : '--' + } + styles={styles} + /> + 0 + ? `${displayedMetrics.tokensPerSecond.toFixed(2)} t/s` + : '--' + } + styles={styles} + /> + + + + + )} + + + setShowContextParamsModal(false)} + onSave={handleSaveContextParams} + /> + + ) +} + +function MetricItem({ + label, + value, + styles, +}: { + label: string + value: string + styles: ReturnType +}) { + return ( + + {label} + {value} + + ) +} + +function createStyles(theme: any) { + return StyleSheet.create({ + container: { + flex: 1, + }, + content: { + padding: Spacing.lg, + }, + setupNote: { + backgroundColor: theme.colors.surface, + borderColor: theme.colors.border, + borderWidth: 1, + borderRadius: Spacing.sm, + padding: Spacing.lg, + marginBottom: Spacing.xl, + }, + setupNoteTitle: { + color: theme.colors.text, + fontSize: 16, + fontWeight: '700', + marginBottom: Spacing.xs, + }, + setupNoteText: { + color: theme.colors.textSecondary, + fontSize: 14, + lineHeight: 20, + }, + section: { + backgroundColor: theme.colors.surface, + borderRadius: Spacing.sm, + borderWidth: 1, + borderColor: theme.colors.border, + padding: Spacing.lg, + marginBottom: Spacing.md, + }, + label: { + color: theme.colors.text, + fontSize: 16, + fontWeight: '700', + marginBottom: Spacing.sm, + }, + textArea: { + backgroundColor: theme.colors.inputBackground, + borderWidth: 1, + borderColor: theme.colors.border, + borderRadius: Spacing.sm, + color: theme.colors.text, + fontSize: 15, + lineHeight: 21, + padding: Spacing.md, + textAlignVertical: 'top', + }, + promptInput: { + minHeight: 140, + }, + controlsGrid: { + flexDirection: 'row', + flexWrap: 'wrap', + gap: Spacing.md, + }, + controlItem: { + flexGrow: 1, + flexBasis: 220, + }, + actionButton: { + backgroundColor: theme.colors.primary, + borderRadius: Spacing.sm, + paddingVertical: Spacing.md, + alignItems: 'center', + marginBottom: Spacing.md, + }, + stopButton: { + backgroundColor: theme.colors.error, + }, + actionButtonText: { + color: 
theme.colors.white, + fontSize: 16, + fontWeight: '700', + }, + outputText: { + color: theme.colors.text, + fontSize: 15, + lineHeight: 22, + minHeight: 120, + }, + metricsPanel: { + backgroundColor: theme.colors.surface, + borderRadius: Spacing.sm, + borderWidth: 1, + borderColor: theme.colors.border, + padding: Spacing.lg, + }, + metricsTitle: { + color: theme.colors.text, + fontSize: 16, + fontWeight: '700', + marginBottom: Spacing.md, + }, + metricsGrid: { + flexDirection: 'row', + flexWrap: 'wrap', + gap: Spacing.sm, + }, + metricItem: { + backgroundColor: theme.colors.card, + borderRadius: Spacing.sm, + paddingHorizontal: Spacing.md, + paddingVertical: Spacing.sm, + minWidth: 120, + flexGrow: 1, + }, + metricLabel: { + color: theme.colors.textSecondary, + fontSize: 12, + marginBottom: 2, + }, + metricValue: { + color: theme.colors.text, + fontSize: 14, + fontWeight: '700', + }, + }) +} diff --git a/example/src/types/example.ts b/example/src/types/example.ts index 6c63bc33..c629e55f 100644 --- a/example/src/types/example.ts +++ b/example/src/types/example.ts @@ -5,6 +5,7 @@ export type ExampleRouteName = | 'SimpleChat' | 'TextCompletion' | 'StructuredOutput' + | 'MTPSpeculative' | 'ParallelDecoding' | 'Multimodal' | 'ToolCalling' diff --git a/example/src/utils/constants.ts b/example/src/utils/constants.ts index 84aab94e..36624a97 100644 --- a/example/src/utils/constants.ts +++ b/example/src/utils/constants.ts @@ -27,6 +27,20 @@ export const MODELS = { mmproj: undefined, size: '1.93GB', }, + QWEN_3_5_4B_MTP: { + name: 'Qwen3.5 4B MTP (Q8_0)', + repo: 'unsloth/Qwen3.5-4B-MTP-GGUF', + filename: 'Qwen3.5-4B-Q8_0.gguf', + mmproj: undefined, + size: '4.3GB', + }, + QWEN_3_6_35B_A3B_MTP: { + name: 'Qwen3.6 35B A3B MTP (Q8_0)', + repo: 'unsloth/Qwen3.6-35B-A3B-MTP-GGUF', + filename: 'Qwen3.6-35B-A3B-Q8_0.gguf', + mmproj: undefined, + size: '37.8GB', + }, SMOL_VLM_500M: { name: 'SmolVLM 500M Instruct (Q8_0)', repo: 'ggml-org/SmolVLM-500M-Instruct-GGUF', @@ -127,7 
+141,7 @@ export const MODELS = { mmproj: undefined, ranking: true, size: '636MB', - } + }, } export const HUGGINGFACE_BASE_URL = 'https://huggingface.co' diff --git a/ios/CMakeLists.txt b/ios/CMakeLists.txt index 91be6f61..3e85f394 100644 --- a/ios/CMakeLists.txt +++ b/ios/CMakeLists.txt @@ -55,6 +55,7 @@ set(PUBLIC_HEADERS ${SOURCE_DIR}/rn-tts.h ${SOURCE_DIR}/llama.h ${SOURCE_DIR}/llama-impl.h + ${SOURCE_DIR}/common/speculative.h ${SOURCE_DIR}/ggml.h ) @@ -100,6 +101,7 @@ add_library(rnllama SHARED # Headers (needed for build) ${SOURCE_DIR}/common/chat.h ${SOURCE_DIR}/common/common.h + ${SOURCE_DIR}/common/speculative.h # Multimodal support (globbed) ${MTMD_MODEL_FILES} diff --git a/llama-rn.podspec b/llama-rn.podspec index e1d28fd0..53304d97 100644 --- a/llama-rn.podspec +++ b/llama-rn.podspec @@ -2,7 +2,7 @@ require "json" package = JSON.parse(File.read(File.join(__dir__, "package.json"))) base_ld_flags = "-framework Accelerate -framework Foundation -framework Metal -framework MetalKit" -base_compiler_flags = "-fno-objc-arc -DLM_GGML_USE_CPU -DLM_GGML_USE_ACCELERATE -DLM_GGML_USE_BLAS -DLM_GGML_BLAS_USE_ACCELERATE -Wno-shorten-64-to-32" +base_compiler_flags = "-fno-objc-arc -DLM_GGML_USE_CPU -DLM_GGML_USE_ACCELERATE -DLM_GGML_USE_BLAS -DLM_GGML_BLAS_USE_ACCELERATE -DLM_GGML_USE_CPU_REPACK -Wno-shorten-64-to-32" if ENV["RNLLAMA_DISABLE_METAL"] != "1" then base_compiler_flags += " -DLM_GGML_USE_METAL" # -DLM_GGML_METAL_NDEBUG @@ -10,7 +10,12 @@ end # Use base_optimizer_flags = "" for debug builds # base_optimizer_flags = "" -base_optimizer_flags = "-O3 -DNDEBUG" +base_optimizer_flags = "-O3 -DNDEBUG -funroll-loops" + +if ENV["RNLLAMA_NATIVE_CPU"] == "1" then + apple_cpu_flags = ENV["RNLLAMA_NATIVE_CPU_FLAGS"] || "-Xarch_arm64 -mcpu=apple-m2" + base_optimizer_flags += " -U__ARM_FEATURE_SVE -U__ARM_FEATURE_SME #{apple_cpu_flags}" +end Pod::Spec.new do |s| s.name = "llama-rn" diff --git a/scripts/build-ios.sh b/scripts/build-ios.sh index 3516ee46..9616aba5 100755 
--- a/scripts/build-ios.sh +++ b/scripts/build-ios.sh @@ -34,6 +34,7 @@ copy_headers() { cp "$ROOT_DIR"/cpp/common/chat.h "$framework_path/Headers/" cp "$ROOT_DIR"/cpp/common/common.h "$framework_path/Headers/" cp "$ROOT_DIR"/cpp/common/sampling.h "$framework_path/Headers/" + cp "$ROOT_DIR"/cpp/common/speculative.h "$framework_path/Headers/" cp "$ROOT_DIR"/cpp/common/json-schema-to-grammar.h "$framework_path/Headers/" cp "$ROOT_DIR"/cpp/common/peg-parser.h "$framework_path/Headers/" } diff --git a/scripts/patches/ggml-metal-mul-mv-id-tiitg.patch b/scripts/patches/ggml-metal-mul-mv-id-tiitg.patch new file mode 100644 index 00000000..8219f3ca --- /dev/null +++ b/scripts/patches/ggml-metal-mul-mv-id-tiitg.patch @@ -0,0 +1,29 @@ +--- ggml-metal/ggml-metal.metal.orig ++++ ggml-metal/ggml-metal.metal +@@ -13355,7 +13355,7 @@ void mmv_fn( + device char * dst, + threadgroup char * shmem, + uint3 tgpig, +- ushort tiitg, ++ uint tiitg, + ushort tiisg, + ushort sgitg) { + disp_fn(args, src0, src1, dst, tgpig, tiisg); +@@ -13369,7 +13369,7 @@ void mmv_fn( + device char * dst, + threadgroup char * shmem, + uint3 tgpig, +- ushort tiitg, ++ uint tiitg, + ushort tiisg, + ushort sgitg) { + disp_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); +@@ -13385,7 +13385,7 @@ kernel void kernel_mul_mv_id( + device const char * ids, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], +- ushort tiitg[[thread_index_in_threadgroup]], ++ uint tiitg[[thread_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const int iid1 = tgpig.z/args.nei0; diff --git a/src/__tests__/index.test.ts b/src/__tests__/index.test.ts index a087014e..d3866dcb 100644 --- a/src/__tests__/index.test.ts +++ b/src/__tests__/index.test.ts @@ -6,13 +6,14 @@ jest.mock('..', () => require('../../jest/mock')) Math.random = () => 0.5 -test('LoRA inputs are normalized and deduplicated', async () => { +test('LoRA 
and speculative inputs are passed through', async () => { await NativeModules.RNLlama.install() const mocks = global as typeof globalThis & { llamaInitContext: jest.Mock llamaApplyLoraAdapters: jest.Mock + llamaCompletion: jest.Mock } - const { llamaInitContext, llamaApplyLoraAdapters } = mocks + const { llamaInitContext, llamaApplyLoraAdapters, llamaCompletion } = mocks const context = await initLlama({ model: 'test.gguf', @@ -22,6 +23,14 @@ test('LoRA inputs are normalized and deduplicated', async () => { { path: 'file:///adapter-a.gguf', scaled: 0.75 }, { path: 'file:///adapter-b.gguf', scaled: 0.5 }, ], + speculative: { + enabled: true, + draft: { + n_max: 4, + p_min: 0.6, + }, + }, + spec_draft_n_min: 1, }) const initParams = llamaInitContext.mock.calls.at(-1)?.[1] @@ -31,10 +40,29 @@ test('LoRA inputs are normalized and deduplicated', async () => { { path: '/adapter-a.gguf', scaled: 0.75 }, { path: '/adapter-b.gguf', scaled: 0.5 }, ], + speculative: { + enabled: true, + draft: { + n_max: 4, + p_min: 0.6, + }, + }, + spec_draft_n_min: 1, }), ) expect(initParams).not.toHaveProperty('lora') + await context.completion({ + prompt: 'Test', + speculative: false, + }) + + expect(llamaCompletion.mock.calls.at(-1)?.[1]).toEqual( + expect.objectContaining({ + speculative: false, + }), + ) + await context.applyLoraAdapters([ { path: 'file:///adapter-b.gguf', scaled: 0.5 }, { path: 'file:///adapter-b.gguf', scaled: 1.0 }, diff --git a/src/index.ts b/src/index.ts index 2bfd29b0..2f024aee 100644 --- a/src/index.ts +++ b/src/index.ts @@ -21,6 +21,9 @@ import type { NativeImageProcessingResult, NativeLlamaChatMessage, NativeBackendDeviceInfo, + NativeSpeculativeConfig, + NativeSpeculativeParams, + NativeSpeculativeType, ParallelStatus, ParallelRequestStatus, } from './types' @@ -64,6 +67,9 @@ export type { JinjaFormattedChatResult, NativeImageProcessingResult, NativeBackendDeviceInfo, + NativeSpeculativeConfig, + NativeSpeculativeParams, + NativeSpeculativeType, 
ParallelStatus, ParallelRequestStatus, } @@ -244,6 +250,8 @@ export type CompletionResponseFormat = { schema?: object // for json_object type } +export type ChatTemplateKwargs = Record + export type CompletionBaseParams = { prompt?: string messages?: RNLlamaOAICompatibleMessage[] @@ -260,7 +268,7 @@ export type CompletionBaseParams = { * Timestamp in seconds since epoch to apply to chat template's strftime_now */ now?: string | number - chat_template_kwargs?: Record + chat_template_kwargs?: ChatTemplateKwargs /** * When enabled, forces the chat parser to treat the entire model output as * plain content, skipping separate parsing of reasoning tokens and tool calls. @@ -274,6 +282,7 @@ export type CompletionBaseParams = { */ prefill_text?: string } + export type CompletionParams = Omit< NativeCompletionParams, 'emit_partial_completion' | 'prompt' @@ -667,7 +676,7 @@ export class LlamaContext { reasoning_format?: 'none' | 'auto' | 'deepseek' add_generation_prompt?: boolean now?: string | number - chat_template_kwargs?: Record + chat_template_kwargs?: ChatTemplateKwargs force_pure_content?: boolean }, ): Promise { diff --git a/src/types.ts b/src/types.ts index d56d3ceb..83ab25a9 100644 --- a/src/types.ts +++ b/src/types.ts @@ -2,6 +2,35 @@ export type NativeEmbeddingParams = { embd_normalize?: number } +export type NativeSpeculativeType = + | 'none' + | 'draft-mtp' + /** + * Alias for draft-mtp. 
+ */ + | 'mtp' + +export type NativeSpeculativeParams = { + enabled?: boolean + type?: NativeSpeculativeType + types?: Array + n_max?: number + n_min?: number + p_min?: number + p_split?: number + draft?: { + n_max?: number + n_min?: number + p_min?: number + p_split?: number + } +} + +export type NativeSpeculativeConfig = + | NativeSpeculativeParams + | NativeSpeculativeType + | boolean + export type NativeContextParams = { model: string /** @@ -99,6 +128,18 @@ export type NativeContextParams = { rope_freq_base?: number rope_freq_scale?: number + /** + * Enable speculative decoding support at context creation time. + * MTP on recurrent/hybrid models must be enabled here so llama.cpp can + * allocate recurrent-state rollback slots. + */ + speculative?: NativeSpeculativeConfig + spec_type?: NativeSpeculativeType | Array + spec_draft_n_max?: number + spec_draft_n_min?: number + spec_draft_p_min?: number + spec_draft_p_split?: number + pooling_type?: number /** @@ -209,6 +250,16 @@ export type NativeCompletionParams = { * Default: `0` */ n_probs?: number + /** + * Per-completion speculative decoding override. For MTP on recurrent/hybrid + * models, load the model with matching MTP options first. + */ + speculative?: NativeSpeculativeConfig + spec_type?: NativeSpeculativeType | Array + spec_draft_n_max?: number + spec_draft_n_min?: number + spec_draft_p_min?: number + spec_draft_p_split?: number /** * Limit the next token selection to the K most probable tokens. Default: `40` */ @@ -413,6 +464,8 @@ export type NativeCompletionResult = { tokens_predicted: number tokens_evaluated: number + draft_tokens: number + draft_tokens_accepted: number truncated: boolean stopped_eos: boolean stopped_word: string