Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions common/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#include "common.h"
#include "ggml.h"
#include "llama.h"
#include "../src/llama-ext.h" // staging API: llama_set_embeddings_pre_norm / llama_get_embeddings_pre_norm_ith (used by MTP)
#include "../src/llama-ext.h" // staging API: llama_set_embeddings_nextn / llama_get_embeddings_nextn_ith (used by MTP)
#include "log.h"
#include "ngram-cache.h"
#include "ngram-map.h"
Expand Down Expand Up @@ -162,7 +162,7 @@ struct common_speculative_impl {
virtual bool need_embd() const = 0;

// true if this implementation requires the target context to extract nextn embeddings (hidden state before the final output norm)
virtual bool need_embd_pre_norm() const { return false; }
virtual bool need_embd_nextn() const { return false; }
};

struct common_speculative_impl_draft_simple : public common_speculative_impl {
Expand Down Expand Up @@ -487,8 +487,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
}
}

llama_set_embeddings_pre_norm(ctx_tgt, true, /*masked*/ false);
llama_set_embeddings_pre_norm(ctx_dft, true, /*masked*/ true);
llama_set_embeddings_nextn(ctx_tgt, true, /*masked*/ false);
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);

pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));

Expand Down Expand Up @@ -583,7 +583,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
// ^--- this is a problem
// TODO: this is generally true, but it would be nice to assert it
{
const float * h_tgt = llama_get_embeddings_pre_norm(ctx_tgt);
const float * h_tgt = llama_get_embeddings_nextn(ctx_tgt);
std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1));

//{
Expand Down Expand Up @@ -625,7 +625,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
verify_h[seq_id].resize((size_t) n_rows * n_embd);

for (int32_t i = 0; i < n_rows; ++i) {
const float * h = llama_get_embeddings_pre_norm_ith(ctx_tgt, i_batch_beg[seq_id] + i);
const float * h = llama_get_embeddings_nextn_ith(ctx_tgt, i_batch_beg[seq_id] + i);
std::memcpy(verify_h[seq_id].data() + (size_t) i * n_embd, h, row_bytes);
}

Expand Down Expand Up @@ -686,7 +686,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
auto * smpl = smpls[seq_id].get();

common_sampler_sample(smpl, ctx_dft, i_batch, true);
h_row = llama_get_embeddings_pre_norm_ith(ctx_dft, i_batch);
h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_batch);
++i_batch;

const auto * cur_p = common_sampler_get_candidates(smpl, true);
Expand Down Expand Up @@ -772,7 +772,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
return false;
}

bool need_embd_pre_norm() const override {
bool need_embd_nextn() const override {
return true;
}
};
Expand Down Expand Up @@ -1539,13 +1539,13 @@ bool common_speculative_need_embd(common_speculative * spec) {
return false;
}

bool common_speculative_need_embd_pre_norm(common_speculative * spec) {
bool common_speculative_need_embd_nextn(common_speculative * spec) {
if (spec == nullptr) {
return false;
}

for (auto & impl : spec->impls) {
if (impl->need_embd_pre_norm()) {
if (impl->need_embd_nextn()) {
return true;
}
}
Expand Down
4 changes: 2 additions & 2 deletions common/speculative.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ bool common_speculative_process(common_speculative * spec, const llama_batch & b
// true if any implementation requires target post-norm embeddings to be extracted
bool common_speculative_need_embd(common_speculative * spec);

// true if any implementation requires target pre-norm embeddings to be extracted
bool common_speculative_need_embd_pre_norm(common_speculative * spec);
// true if any implementation requires target nextn embeddings to be extracted
bool common_speculative_need_embd_nextn(common_speculative * spec);

// generate drafts for the sequences specified with `common_speculative_get_draft_params`
void common_speculative_draft(common_speculative * spec);
Expand Down
141 changes: 71 additions & 70 deletions src/llama-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,20 @@ llama_context::llama_context(
cparams.n_rs_seq = 0;
}

cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor;
cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor;
cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow;
cparams.embeddings = params.embeddings;
cparams.embeddings_pre_norm = false;
cparams.embeddings_pre_norm_masked = false;
cparams.offload_kqv = params.offload_kqv;
cparams.no_perf = params.no_perf;
cparams.pooling_type = params.pooling_type;
cparams.warmup = false;
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor;
cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor;
cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow;
cparams.embeddings = params.embeddings;
cparams.embeddings_nextn = false;
cparams.embeddings_nextn_masked = false;
cparams.offload_kqv = params.offload_kqv;
cparams.no_perf = params.no_perf;
cparams.pooling_type = params.pooling_type;
cparams.warmup = false;


cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
Expand Down Expand Up @@ -889,34 +890,34 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
return it->second.data();
}

float * llama_context::get_embeddings_pre_norm() {
float * llama_context::get_embeddings_nextn() {
output_reorder();

return embd_pre_norm.data;
return embd_nextn.data;
}

float * llama_context::get_embeddings_pre_norm_ith(int32_t i) {
float * llama_context::get_embeddings_nextn_ith(int32_t i) {
output_reorder();

try {
if (embd_pre_norm.data == nullptr) {
throw std::runtime_error("no pre-norm embeddings");
if (embd_nextn.data == nullptr) {
throw std::runtime_error("no nextn embeddings");
}

const uint32_t n_embd = model.hparams.n_embd;

if (!cparams.embeddings_pre_norm_masked) {
// unmasked: pre-norm rows are stored densely, indexed by raw token position.
if (i < 0 || (size_t)(i + 1) * n_embd > embd_pre_norm.size) {
throw std::runtime_error(format("out of range [0, %zu)", embd_pre_norm.size / n_embd));
if (!cparams.embeddings_nextn_masked) {
// unmasked: nextn rows are stored densely, indexed by raw token position.
if (i < 0 || (size_t)(i + 1) * n_embd > embd_nextn.size) {
throw std::runtime_error(format("out of range [0, %zu)", embd_nextn.size / n_embd));
}
return embd_pre_norm.data + (size_t) i * n_embd;
return embd_nextn.data + (size_t) i * n_embd;
}

const int64_t j = output_resolve_row(i);
return embd_pre_norm.data + j*n_embd;
return embd_nextn.data + j*n_embd;
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid pre-norm embeddings id %d, reason: %s\n", __func__, i, err.what());
LLAMA_LOG_ERROR("%s: invalid nextn embeddings id %d, reason: %s\n", __func__, i, err.what());
#ifndef NDEBUG
GGML_ABORT("fatal error");
#else
Expand Down Expand Up @@ -1105,11 +1106,11 @@ void llama_context::set_embeddings(bool value) {
//sched_need_reserve = true;
}

void llama_context::set_embeddings_pre_norm(bool value, bool masked) {
void llama_context::set_embeddings_nextn(bool value, bool masked) {
LLAMA_LOG_DEBUG("%s: value = %d, masked = %d\n", __func__, value, masked);

cparams.embeddings_pre_norm = value;
cparams.embeddings_pre_norm_masked = masked;
cparams.embeddings_nextn = value;
cparams.embeddings_nextn_masked = masked;
}

void llama_context::set_causal_attn(bool value) {
Expand Down Expand Up @@ -1326,7 +1327,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
}

int llama_context::encode(const llama_batch & batch_inp) {
// MTP hook batches carry both token (next-token id) and embd (h_pre_norm row),
// MTP hook batches carry both token (next-token id) and embd (h_nextn row),
// so accept either present rather than requiring exactly one.
GGML_ASSERT(batch_inp.token || batch_inp.embd);

Expand Down Expand Up @@ -1399,9 +1400,9 @@ int llama_context::encode(const llama_batch & batch_inp) {
}
}

auto * t_logits = res->get_logits();
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr;
auto * t_logits = res->get_logits();
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
auto * t_h_nextn = cparams.embeddings_nextn ? res->get_h_nextn() : nullptr;

// extract logits
if (logits.data && t_logits) {
Expand Down Expand Up @@ -1467,14 +1468,14 @@ int llama_context::encode(const llama_batch & batch_inp) {
}
}

// extract pre-norm embeddings (hidden state before the final output norm)
if (embd_pre_norm.data && t_h_pre_norm && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
// extract nextn embeddings (hidden state before the final output norm)
if (embd_nextn.data && t_h_nextn && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn);
GGML_ASSERT(backend_h != nullptr);

const uint32_t n_embd = hparams.n_embd;
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_pre_norm.size);
ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm.data, 0, n_tokens*n_embd*sizeof(float));
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_nextn.size);
ggml_backend_tensor_get_async(backend_h, t_h_nextn, embd_nextn.data, 0, n_tokens*n_embd*sizeof(float));
}

// TODO: hacky solution
Expand Down Expand Up @@ -1629,7 +1630,7 @@ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map<llama_s
}

int llama_context::decode(const llama_batch & batch_inp) {
// MTP hook batches carry both token (next-token id) and embd (h_pre_norm row),
// MTP hook batches carry both token (next-token id) and embd (h_nextn row),
// so accept either present rather than requiring exactly one.
GGML_ASSERT(batch_inp.token || batch_inp.embd);

Expand Down Expand Up @@ -1829,9 +1830,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
//}

auto * t_logits = res->get_logits();
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr;
auto * t_logits = res->get_logits();
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
auto * t_h_nextn = cparams.embeddings_nextn ? res->get_h_nextn() : nullptr;

if (t_embd && res->get_embd_pooled()) {
t_embd = res->get_embd_pooled();
Expand Down Expand Up @@ -1912,22 +1913,22 @@ int llama_context::decode(const llama_batch & batch_inp) {
}
}

// extract pre-norm embeddings (hidden state before the final output norm)
// extract nextn embeddings (hidden state before the final output norm)
// only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored.
{
const bool masked = cparams.embeddings_pre_norm_masked;
const bool masked = cparams.embeddings_nextn_masked;
const int64_t n_rows = masked ? n_outputs : (int64_t) ubatch.n_tokens;
const int64_t offset = masked ? n_outputs_prev : n_tokens_prev;

if (embd_pre_norm.data && t_h_pre_norm && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
if (embd_nextn.data && t_h_nextn && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn);
GGML_ASSERT(backend_h != nullptr);

const uint32_t n_embd = hparams.n_embd;
float * embd_pre_norm_out = embd_pre_norm.data + offset*n_embd;
const uint32_t n_embd = hparams.n_embd;
float * embd_nextn_out = embd_nextn.data + offset*n_embd;

GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_pre_norm.size);
ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_rows*n_embd*sizeof(float));
GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_nextn.size);
ggml_backend_tensor_get_async(backend_h, t_h_nextn, embd_nextn_out, 0, n_rows*n_embd*sizeof(float));
}
}

Expand Down Expand Up @@ -2019,9 +2020,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
const auto n_embd = hparams.n_embd;
const auto n_embd_out = hparams.n_embd_out();

bool has_logits = true;
bool has_embd = cparams.embeddings;
bool has_embd_pre_norm = cparams.embeddings_pre_norm;
bool has_logits = true;
bool has_embd = cparams.embeddings;
bool has_embd_nextn = cparams.embeddings_nextn;

// TODO: hacky enc-dec support
if (model.arch == LLM_ARCH_T5) {
Expand All @@ -2033,14 +2034,14 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
size_t backend_float_count = 0;
size_t backend_token_count = 0;

logits.size = has_logits ? n_vocab*n_outputs_max : 0;
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0;
logits.size = has_logits ? n_vocab*n_outputs_max : 0;
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
embd_nextn.size = has_embd_nextn ? n_embd*n_outputs_max : 0;

if (has_embd_pre_norm && !cparams.embeddings_pre_norm_masked) {
// unmasked: pre-norm row exists for every token in the batch, not just
if (has_embd_nextn && !cparams.embeddings_nextn_masked) {
// unmasked: nextn row exists for every token in the batch, not just
// those flagged via batch.logits[i] -> size by token count instead.
embd_pre_norm.size = (size_t) n_embd * n_batch;
embd_nextn.size = (size_t) n_embd * n_batch;
}

// Allocate backend sampling output buffers if there are backend samplers configured.
Expand All @@ -2057,7 +2058,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {

const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
const size_t new_size =
(logits.size + embd.size + embd_pre_norm.size + backend_float_count) * sizeof(float) +
(logits.size + embd.size + embd_nextn.size + backend_float_count) * sizeof(float) +
( backend_token_count) * sizeof(llama_token);

// alloc only when more than the current capacity is required
Expand All @@ -2074,7 +2075,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
buf_output = nullptr;
logits.data = nullptr;
embd.data = nullptr;
embd_pre_norm.data = nullptr;
embd_nextn.data = nullptr;
}

auto * buft = ggml_backend_cpu_buffer_type();
Expand Down Expand Up @@ -2103,8 +2104,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0};
offset += embd.size * sizeof(float);

embd_pre_norm = has_embd_pre_norm ? buffer_view<float>{(float *) (base + offset), embd_pre_norm.size} : buffer_view<float>{nullptr, 0};
offset += embd_pre_norm.size * sizeof(float);
embd_nextn = has_embd_nextn ? buffer_view<float>{(float *) (base + offset), embd_nextn.size} : buffer_view<float>{nullptr, 0};
offset += embd_nextn.size * sizeof(float);

if (has_sampling) {
sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
Expand Down Expand Up @@ -2172,9 +2173,9 @@ void llama_context::output_reorder() {
}
}

if (embd_pre_norm.size > 0) {
if (embd_nextn.size > 0) {
for (uint64_t k = 0; k < n_embd; k++) {
std::swap(embd_pre_norm.data[i0*n_embd + k], embd_pre_norm.data[i1*n_embd + k]);
std::swap(embd_nextn.data[i0*n_embd + k], embd_nextn.data[i1*n_embd + k]);
}
}

Expand Down Expand Up @@ -3588,20 +3589,20 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
return ctx->get_embeddings_seq(seq_id);
}

void llama_set_embeddings_pre_norm(llama_context * ctx, bool value, bool masked) {
ctx->set_embeddings_pre_norm(value, masked);
void llama_set_embeddings_nextn(llama_context * ctx, bool value, bool masked) {
ctx->set_embeddings_nextn(value, masked);
}

float * llama_get_embeddings_pre_norm(llama_context * ctx) {
float * llama_get_embeddings_nextn(llama_context * ctx) {
ctx->synchronize();

return ctx->get_embeddings_pre_norm();
return ctx->get_embeddings_nextn();
}

float * llama_get_embeddings_pre_norm_ith(llama_context * ctx, int32_t i) {
float * llama_get_embeddings_nextn_ith(llama_context * ctx, int32_t i) {
ctx->synchronize();

return ctx->get_embeddings_pre_norm_ith(i);
return ctx->get_embeddings_nextn_ith(i);
}

bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
Expand Down
Loading
Loading