Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -187,4 +187,6 @@ jobs:
- name: Build example for Android
run: |
sed -i 's/rnllamaBuildFromSource=true/rnllamaBuildFromSource=false/g' example/android/gradle.properties
npm run build:android
sed -i 's/reactNativeArchitectures=.*/reactNativeArchitectures=arm64-v8a,x86_64/g' example/android/gradle.properties
cd example/android
./gradlew assembleDebug --stacktrace
47 changes: 47 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,53 @@ Please visit the [Documentation](docs/API) for more details.

You can also visit the [example](example) to see how to use it.

## MTP Speculative Decoding

MTP speculative decoding can be enabled for GGUF models that contain MTP/NextN layers:

```js
const context = await initLlama({
model: modelPath,
n_ctx: 4096,
n_batch: 1024,
n_ubatch: 512,
n_gpu_layers: 99,
flash_attn_type: 'auto',
cache_type_k: 'q8_0',
cache_type_v: 'q8_0',
speculative: {
type: 'draft-mtp',
n_max: 3,
},
})

const result = await context.completion({
messages: [
{
role: 'user',
content:
'Write a concise TypeScript function that groups an array of objects by a key.',
},
],
chat_template_kwargs: {
preserve_thinking: true,
},
n_predict: 128,
temperature: 0.6,
top_k: 20,
top_p: 0.95,
speculative: {
type: 'draft-mtp',
n_max: 3,
},
})

console.log(result.text)
console.log(result.draft_tokens, result.draft_tokens_accepted)
```

Use `speculative: false` on a completion call to disable MTP for that request. For recurrent or hybrid models, enable MTP at `initLlama` time with a positive `spec_draft_n_max`, `speculative.n_max`, or `speculative.draft.n_max` so llama.cpp can allocate rollback state. MTP currently supports text-only completions; requesting it together with media inputs, or for queued parallel completions, raises an error.

## Multimodal (Vision & Audio)

`llama.rn` supports multimodal capabilities including vision (images) and audio processing. This allows you to interact with models that can understand both text and media content.
Expand Down
6 changes: 3 additions & 3 deletions cpp/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -13354,7 +13354,7 @@ void mmv_fn(
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
ushort tiitg,
uint tiitg,
ushort tiisg,
ushort sgitg) {
disp_fn(args, src0, src1, dst, tgpig, tiisg);
Expand All @@ -13368,7 +13368,7 @@ void mmv_fn(
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
ushort tiitg,
uint tiitg,
ushort tiisg,
ushort sgitg) {
disp_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
Expand All @@ -13385,7 +13385,7 @@ kernel void kernel_mul_mv_id(
device const char * ids,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
uint tiitg[[thread_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const int iid1 = tgpig.z/args.nei0;
Expand Down
4 changes: 4 additions & 0 deletions cpp/jsi/JSICompletion.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ namespace rnllama_jsi {
);
res.setProperty(runtime, "tokens_predicted", (double)ctx->completion->num_tokens_predicted);
res.setProperty(runtime, "tokens_evaluated", (double)ctx->completion->num_prompt_tokens);
res.setProperty(runtime, "draft_tokens", (double)ctx->completion->num_draft_tokens);
res.setProperty(runtime, "draft_tokens_accepted", (double)ctx->completion->num_draft_tokens_accepted);
res.setProperty(runtime, "truncated", ctx->completion->truncated);
res.setProperty(runtime, "context_full", ctx->completion->context_full);
res.setProperty(runtime, "interrupted", ctx->completion->is_interrupted);
Expand Down Expand Up @@ -230,6 +232,8 @@ namespace rnllama_jsi {

res.setProperty(runtime, "tokens_predicted", (double)slot->num_tokens_predicted);
res.setProperty(runtime, "tokens_evaluated", (double)slot->num_prompt_tokens);
res.setProperty(runtime, "draft_tokens", 0.0);
res.setProperty(runtime, "draft_tokens_accepted", 0.0);
res.setProperty(runtime, "truncated", slot->truncated);
res.setProperty(runtime, "context_full", slot->context_full);
res.setProperty(runtime, "interrupted", slot->is_interrupted);
Expand Down
175 changes: 175 additions & 0 deletions cpp/jsi/JSIParams.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
#include "JSIParams.h"
#if defined(RNLLAMA_USE_FRAMEWORK_HEADERS)
#include <rnllama/speculative.h>
#else
#include "speculative.h"
#endif
#include <cmath>
#include <algorithm>
#include <list>
Expand Down Expand Up @@ -57,6 +62,12 @@ namespace rnllama_jsi {
}
#endif

#if defined(__APPLE__)
// Default thread count on Apple platforms: half of the value reported by
// common_cpu_get_num_math(), clamped so that at least one thread is used.
static int default_apple_n_threads() {
    const int halved = common_cpu_get_num_math() / 2;
    return halved < 1 ? 1 : halved;
}
#endif

std::string getPropertyAsString(jsi::Runtime& runtime, const jsi::Object& obj, const char* name, const std::string& defaultValue) {
if (obj.hasProperty(runtime, name)) {
auto val = obj.getProperty(runtime, name);
Expand Down Expand Up @@ -107,6 +118,163 @@ namespace rnllama_jsi {
return defaultValue;
}

// True when the JS value is either `undefined` or `null`.
static bool isNil(const jsi::Value& value) {
    if (value.isUndefined()) {
        return true;
    }
    return value.isNull();
}

// Maps the shorthand "mtp" to the canonical name "draft-mtp"; any other
// name (matching is exact and case-sensitive) is returned unchanged.
static std::string normalizeSpeculativeTypeName(std::string name) {
    const bool isShorthand = (name == "mtp");
    return isShorthand ? std::string("draft-mtp") : std::move(name);
}

// Appends the normalized form of `name` to `typeNames` unless an equal
// entry is already present, so the list keeps first-seen order without
// duplicates.
static void addSpeculativeTypeName(std::vector<std::string>& typeNames, std::string name) {
    std::string canonical = normalizeSpeculativeTypeName(std::move(name));
    for (const auto& existing : typeNames) {
        if (existing == canonical) {
            return;
        }
    }
    typeNames.push_back(std::move(canonical));
}

// Collects speculative type names from a JS value into `typeNames`.
//
// A string contributes a single name. An array contributes each of its
// string elements in index order; non-string elements are skipped. Any other
// value (null, undefined, numbers, booleans, non-array objects) adds nothing.
static void addSpeculativeTypeNamesFromValue(
    jsi::Runtime& runtime,
    const jsi::Value& value,
    std::vector<std::string>& typeNames
) {
    if (value.isString()) {
        addSpeculativeTypeName(typeNames, value.asString(runtime).utf8(runtime));
        return;
    }

    // Null/undefined and other primitives carry no type names.
    if (!value.isObject()) {
        return;
    }

    auto container = value.asObject(runtime);
    if (!container.isArray(runtime)) {
        return;
    }

    auto names = container.asArray(runtime);
    for (size_t index = 0; index < names.size(runtime); ++index) {
        auto entry = names.getValueAtIndex(runtime, index);
        if (!entry.isString()) {
            continue;
        }
        addSpeculativeTypeName(typeNames, entry.asString(runtime).utf8(runtime));
    }
}

// Overrides the draft settings in `draft` with the `n_max`, `n_min`, `p_min`
// and `p_split` properties of `obj`. Each field's current value is passed as
// the fallback to the property reader, so `draft` is both input and output.
static void applySpeculativeDraftOptions(
    jsi::Runtime& runtime,
    const jsi::Object& obj,
    common_params_speculative_draft& draft
) {
    const auto readInt = [&](const char* key, auto fallback) {
        return getPropertyAsInt(runtime, obj, key, fallback);
    };
    const auto readFloat = [&](const char* key, auto fallback) {
        return getPropertyAsFloat(runtime, obj, key, fallback);
    };

    draft.n_max = readInt("n_max", draft.n_max);
    draft.n_min = readInt("n_min", draft.n_min);
    draft.p_min = readFloat("p_min", draft.p_min);
    draft.p_split = readFloat("p_split", draft.p_split);
}

// Reports whether `type` appears in `speculative.types`.
bool hasSpeculativeType(const common_params_speculative& speculative, common_speculative_type type) {
    for (const auto& candidate : speculative.types) {
        if (candidate == type) {
            return true;
        }
    }
    return false;
}

// Replaces `speculative.types` with the types resolved from `typeNames`.
// An empty name list means nothing was requested, so the current types are
// left untouched.
static void applySpeculativeTypeNames(
    common_params_speculative& speculative,
    const std::vector<std::string>& typeNames
) {
    if (!typeNames.empty()) {
        speculative.types = common_speculative_types_from_names(typeNames);
    }
}

// Parses the speculative-decoding options in `params` and applies them to
// `cparams.speculative`.
//
// Recognised inputs, in the order they are applied:
//   - `spec_type`: a type name or an array of type names.
//   - `speculative`:
//       * boolean: `true` selects "draft-mtp", `false` selects "none";
//       * string:  a single type name;
//       * object:  optional `enabled` (boolean), `type` / `types` (name or
//                  array of names), draft options `n_max`, `n_min`, `p_min`,
//                  `p_split` directly on the object, and a nested `draft`
//                  object with the same four options, which overrides them.
//   - flat overrides `spec_draft_n_max`, `spec_draft_n_min`,
//     `spec_draft_p_min`, `spec_draft_p_split`, each followed by a lookup of a
//     property literally named "speculative.n_max", "speculative.n_min", etc.
//     on `params` itself (a dotted key, not nested access).
//
// Type names are de-duplicated and "mtp" is treated as "draft-mtp". When no
// type name is supplied, the existing `types` are kept.
//
// All changes are staged on a copy and committed only after validation. If
// this function throws, `cparams` is exactly as it was on entry. That matters
// for completion requests, which call this on the long-lived context
// parameters: a rejected request must not leave a half-applied configuration
// that makes every later request fail the same check.
//
// Throws std::invalid_argument when the resulting types include
// COMMON_SPECULATIVE_TYPE_DRAFT_MTP but `draft.n_max` is not positive.
static void applySpeculativeOptions(jsi::Runtime& runtime, const jsi::Object& params, common_params& cparams) {
    // Staged copy: nothing below touches `cparams` until the final commit.
    common_params_speculative staged = cparams.speculative;
    std::vector<std::string> typeNames;

    if (params.hasProperty(runtime, "spec_type")) {
        addSpeculativeTypeNamesFromValue(runtime, params.getProperty(runtime, "spec_type"), typeNames);
    }

    if (params.hasProperty(runtime, "speculative")) {
        auto value = params.getProperty(runtime, "speculative");
        if (!isNil(value)) {
            if (value.isBool()) {
                addSpeculativeTypeName(typeNames, value.getBool() ? "draft-mtp" : "none");
            } else if (value.isString()) {
                addSpeculativeTypeName(typeNames, value.asString(runtime).utf8(runtime));
            } else if (value.isObject()) {
                auto speculative = value.asObject(runtime);
                bool enabled = false;
                bool hasEnabled = false;
                bool hasExplicitType = false;

                // Only a boolean `enabled` counts; other value kinds are ignored.
                if (speculative.hasProperty(runtime, "enabled")) {
                    auto enabledValue = speculative.getProperty(runtime, "enabled");
                    if (enabledValue.isBool()) {
                        enabled = enabledValue.getBool();
                        hasEnabled = true;
                    }
                }

                // A type is "explicit" only if it actually added a new name;
                // duplicates of already collected names do not count.
                if (speculative.hasProperty(runtime, "type")) {
                    const size_t oldSize = typeNames.size();
                    addSpeculativeTypeNamesFromValue(runtime, speculative.getProperty(runtime, "type"), typeNames);
                    hasExplicitType = hasExplicitType || typeNames.size() != oldSize;
                }

                if (speculative.hasProperty(runtime, "types")) {
                    const size_t oldSize = typeNames.size();
                    addSpeculativeTypeNamesFromValue(runtime, speculative.getProperty(runtime, "types"), typeNames);
                    hasExplicitType = hasExplicitType || typeNames.size() != oldSize;
                }

                if (hasEnabled) {
                    if (!enabled) {
                        // NOTE(review): "none" is appended after any names already
                        // collected; how a mixed list is resolved depends on
                        // common_speculative_types_from_names. Confirm it disables.
                        addSpeculativeTypeName(typeNames, "none");
                    } else if (!hasExplicitType) {
                        // `enabled: true` without a type defaults to MTP.
                        addSpeculativeTypeName(typeNames, "draft-mtp");
                    }
                }

                // Options directly on `speculative` first, then the nested
                // `draft` object, which overrides them.
                applySpeculativeDraftOptions(runtime, speculative, staged.draft);
                if (speculative.hasProperty(runtime, "draft")) {
                    auto draftValue = speculative.getProperty(runtime, "draft");
                    if (draftValue.isObject()) {
                        applySpeculativeDraftOptions(runtime, draftValue.asObject(runtime), staged.draft);
                    }
                }
            }
        }
    }

    // Flat overrides are applied last; within each pair the dotted key wins.
    staged.draft.n_max = getPropertyAsInt(
        runtime, params, "spec_draft_n_max", staged.draft.n_max);
    staged.draft.n_max = getPropertyAsInt(
        runtime, params, "speculative.n_max", staged.draft.n_max);
    staged.draft.n_min = getPropertyAsInt(
        runtime, params, "spec_draft_n_min", staged.draft.n_min);
    staged.draft.n_min = getPropertyAsInt(
        runtime, params, "speculative.n_min", staged.draft.n_min);
    staged.draft.p_min = getPropertyAsFloat(
        runtime, params, "spec_draft_p_min", staged.draft.p_min);
    staged.draft.p_min = getPropertyAsFloat(
        runtime, params, "speculative.p_min", staged.draft.p_min);
    staged.draft.p_split = getPropertyAsFloat(
        runtime, params, "spec_draft_p_split", staged.draft.p_split);
    staged.draft.p_split = getPropertyAsFloat(
        runtime, params, "speculative.p_split", staged.draft.p_split);

    applySpeculativeTypeNames(staged, typeNames);

    if (hasSpeculativeType(staged, COMMON_SPECULATIVE_TYPE_DRAFT_MTP) &&
        staged.draft.n_max <= 0) {
        throw std::invalid_argument("MTP requires spec_draft_n_max > 0");
    }

    // Validation passed: publish the new configuration in one step.
    cparams.speculative = std::move(staged);
}

void parseCommonParams(jsi::Runtime& runtime, const jsi::Object& params, common_params& cparams) {
cparams.fit_params = false;

Expand Down Expand Up @@ -134,6 +302,10 @@ namespace rnllama_jsi {
std::string cpuMask = getPropertyAsString(runtime, params, "cpu_mask");
#if defined(__ANDROID__)
set_best_cores(cparams.cpuparams, cparams.cpuparams.n_threads);
#elif defined(__APPLE__)
if (cparams.cpuparams.n_threads < 0) {
cparams.cpuparams.n_threads = default_apple_n_threads();
}
#endif

cparams.n_gpu_layers = getPropertyAsInt(runtime, params, "n_gpu_layers", cparams.n_gpu_layers);
Expand Down Expand Up @@ -231,6 +403,8 @@ namespace rnllama_jsi {
}
}
}

applySpeculativeOptions(runtime, params, cparams);
}

void parseCompletionParams(jsi::Runtime& runtime, const jsi::Object& params, rnllama::llama_rn_context* ctx) {
Expand All @@ -242,6 +416,7 @@ namespace rnllama_jsi {
sparams.seed = getPropertyAsInt(runtime, params, "seed", -1);
ctx->params.n_predict = getPropertyAsInt(runtime, params, "n_predict", ctx->params.n_predict);
ctx->params.sampling.ignore_eos = getPropertyAsBool(runtime, params, "ignore_eos", ctx->params.sampling.ignore_eos);
applySpeculativeOptions(runtime, params, ctx->params);

sparams.temp = getPropertyAsDouble(runtime, params, "temperature", sparams.temp);
sparams.n_probs = getPropertyAsInt(runtime, params, "n_probs", sparams.n_probs);
Expand Down
1 change: 1 addition & 0 deletions cpp/jsi/JSIParams.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ namespace rnllama_jsi {
bool getPropertyAsBool(jsi::Runtime& runtime, const jsi::Object& obj, const char* name, bool defaultValue = false);
float getPropertyAsFloat(jsi::Runtime& runtime, const jsi::Object& obj, const char* name, float defaultValue = 0.0f);

bool hasSpeculativeType(const common_params_speculative& speculative, common_speculative_type type);
void parseCommonParams(jsi::Runtime& runtime, const jsi::Object& params, common_params& cparams);
void parseCompletionParams(jsi::Runtime& runtime, const jsi::Object& params, rnllama::llama_rn_context* ctx);
}
8 changes: 8 additions & 0 deletions cpp/jsi/RNLlamaJSI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1064,6 +1064,10 @@ namespace rnllama_jsi {
ctx->tts_wrapper->setGuideTokens(guide_tokens);
}

if (!mediaPaths.empty() && ctx->completion->shouldUseMTP()) {
throw std::runtime_error("MTP speculative decoding currently supports text-only completion");
}

if (!ctx->completion->initSampling()) {
throw std::runtime_error("Failed to initialize sampling");
}
Expand Down Expand Up @@ -1273,6 +1277,10 @@ namespace rnllama_jsi {
}
}

if (hasSpeculativeType(cparams.speculative, COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) {
throw std::runtime_error("MTP speculative decoding is not supported for queued parallel completions");
}

int chat_format = getPropertyAsInt(runtime, params, "chat_format", 0);
std::string reasoningFormatStr = getPropertyAsString(runtime, params, "reasoning_format", "none");
common_reasoning_format reasoning_format = common_reasoning_format_from_name(reasoningFormatStr);
Expand Down
Loading
Loading