Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions common/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,11 @@ struct common_speculative_impl {

virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0;

// true if this implementation requires the target context to extract embeddings
// true if this implementation requires the target context to extract post-norm embeddings
virtual bool need_embd() const = 0;

// true if this implementation requires the target context to extract pre-norm embeddings
virtual bool need_embd_pre_norm() const { return false; }
};

struct common_speculative_impl_draft_simple : public common_speculative_impl {
Expand Down Expand Up @@ -429,8 +432,8 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams));
}

llama_set_embeddings_pre_norm(ctx_tgt, true);
llama_set_embeddings_pre_norm(ctx_dft, true);
llama_set_embeddings_pre_norm(ctx_tgt, true, /*masked*/ false);
llama_set_embeddings_pre_norm(ctx_dft, true, /*masked*/ true);

pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));

Expand Down Expand Up @@ -691,6 +694,10 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
}

bool need_embd() const override {
return false;
}

bool need_embd_pre_norm() const override {
return true;
}
};
Expand Down Expand Up @@ -1408,6 +1415,20 @@ bool common_speculative_need_embd(common_speculative * spec) {
return false;
}

bool common_speculative_need_embd_pre_norm(common_speculative * spec) {
if (spec == nullptr) {
return false;
}

for (auto & impl : spec->impls) {
if (impl->need_embd_pre_norm()) {
return true;
}
}

return false;
}

void common_speculative_draft(common_speculative * spec) {
if (spec == nullptr) {
return;
Expand Down
5 changes: 4 additions & 1 deletion common/speculative.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,12 @@ void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, co
// process the batch and update the internal state of the speculative context
bool common_speculative_process(common_speculative * spec, const llama_batch & batch);

// true if any implementation requires target embeddings to be extracted
// true if any implementation requires target post-norm embeddings to be extracted
bool common_speculative_need_embd(common_speculative * spec);

// true if any implementation requires target pre-norm embeddings to be extracted
bool common_speculative_need_embd_pre_norm(common_speculative * spec);

// generate drafts for the sequences specified with `common_speculative_get_draft_params`
void common_speculative_draft(common_speculative * spec);

Expand Down
51 changes: 37 additions & 14 deletions src/llama-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -895,8 +895,17 @@ float * llama_context::get_embeddings_pre_norm_ith(int32_t i) {
throw std::runtime_error("no pre-norm embeddings");
}

const int64_t j = output_resolve_row(i);
const uint32_t n_embd = model.hparams.n_embd;

if (!cparams.embeddings_pre_norm_masked) {
// unmasked: pre-norm rows are stored densely, indexed by raw token position.
if (i < 0 || (size_t)(i + 1) * n_embd > embd_pre_norm.size) {
throw std::runtime_error(format("out of range [0, %zu)", embd_pre_norm.size / n_embd));
}
return embd_pre_norm.data + (size_t) i * n_embd;
}

const int64_t j = output_resolve_row(i);
return embd_pre_norm.data + j*n_embd;
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid pre-norm embeddings id %d, reason: %s\n", __func__, i, err.what());
Expand Down Expand Up @@ -1088,10 +1097,11 @@ void llama_context::set_embeddings(bool value) {
//sched_need_reserve = true;
}

void llama_context::set_embeddings_pre_norm(bool value) {
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
void llama_context::set_embeddings_pre_norm(bool value, bool masked) {
LLAMA_LOG_DEBUG("%s: value = %d, masked = %d\n", __func__, value, masked);

cparams.embeddings_pre_norm = value;
cparams.embeddings_pre_norm = value;
cparams.embeddings_pre_norm_masked = masked;
}

void llama_context::set_causal_attn(bool value) {
Expand Down Expand Up @@ -1737,6 +1747,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
};

int64_t n_outputs_prev = 0;
int64_t n_tokens_prev = 0;

do {
const auto & ubatch = mctx->get_ubatch();
Expand Down Expand Up @@ -1882,16 +1893,21 @@ int llama_context::decode(const llama_batch & batch_inp) {

// extract pre-norm embeddings (hidden state before the final output norm)
// only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored.
if (embd_pre_norm.data && t_h_pre_norm && n_outputs > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
GGML_ASSERT(backend_h != nullptr);
{
const bool masked = cparams.embeddings_pre_norm_masked;
const int64_t n_rows = masked ? n_outputs : (int64_t) ubatch.n_tokens;
const int64_t offset = masked ? n_outputs_prev : n_tokens_prev;

if (embd_pre_norm.data && t_h_pre_norm && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
GGML_ASSERT(backend_h != nullptr);

const uint32_t n_embd = hparams.n_embd;
float * embd_pre_norm_out = embd_pre_norm.data + n_outputs_prev*n_embd;
const uint32_t n_embd = hparams.n_embd;
float * embd_pre_norm_out = embd_pre_norm.data + offset*n_embd;

GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_pre_norm.size);
ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_outputs*n_embd*sizeof(float));
GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_pre_norm.size);
ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_rows*n_embd*sizeof(float));
}
}

// Copy backend sampling output if this ubatch produced any sampling tensors.
Expand All @@ -1908,6 +1924,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
}

n_outputs_prev += n_outputs;
n_tokens_prev += ubatch.n_tokens;
} while (mctx->next());

// set to total number of outputs in the batch, for use in llama_get_logits_ith
Expand Down Expand Up @@ -1999,6 +2016,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0;

if (has_embd_pre_norm && !cparams.embeddings_pre_norm_masked) {
// unmasked: pre-norm row exists for every token in the batch, not just
// those flagged via batch.logits[i] -> size by token count instead.
embd_pre_norm.size = (size_t) n_embd * n_batch;
}

// Allocate backend sampling output buffers if there are backend samplers configured.
const bool has_sampling = !sampling.samplers.empty();
if (has_sampling) {
Expand Down Expand Up @@ -3547,8 +3570,8 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
return ctx->get_embeddings_seq(seq_id);
}

void llama_set_embeddings_pre_norm(llama_context * ctx, bool value) {
ctx->set_embeddings_pre_norm(value);
void llama_set_embeddings_pre_norm(llama_context * ctx, bool value, bool masked) {
ctx->set_embeddings_pre_norm(value, masked);
}

float * llama_get_embeddings_pre_norm(llama_context * ctx) {
Expand Down
2 changes: 1 addition & 1 deletion src/llama-context.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ struct llama_context {
void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data);

void set_embeddings (bool value);
void set_embeddings_pre_norm(bool value);
void set_embeddings_pre_norm(bool value, bool masked);
void set_causal_attn(bool value);
void set_warmup(bool value);

Expand Down
3 changes: 2 additions & 1 deletion src/llama-cparams.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ struct llama_cparams {
float yarn_beta_slow;

bool embeddings;
bool embeddings_pre_norm; // also extract the hidden state before the final output norm
bool embeddings_pre_norm; // also extract the hidden state before the final output norm
bool embeddings_pre_norm_masked; // extract for only rows where batch.logits != 0
bool causal_attn;
bool offload_kqv;
bool flash_attn;
Expand Down
10 changes: 5 additions & 5 deletions src/llama-ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,14 @@ LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_c
// pre-norm embeddings (hidden state before the final output norm)
//

// mirrors:
// LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value);
// Set whether the context outputs pre-norm embeddings or not
// If masked == true, output the embeddings only for the tokens with batch.logits != 0
// If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits
LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value, bool masked);

// mirrors:
// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
LLAMA_API float * llama_get_embeddings_pre_norm(struct llama_context * ctx);
LLAMA_API float * llama_get_embeddings_pre_norm (struct llama_context * ctx);

// mirrors:
// LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx, int32_t i);
3 changes: 3 additions & 0 deletions src/llama-graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,9 @@ void llm_graph_result::set_outputs() {
if (t_embd_pooled != nullptr) {
ggml_set_output(t_embd_pooled);
}
if (t_h_pre_norm != nullptr) {
ggml_set_output(t_h_pre_norm);
}
for (auto & [seq_id, t] : t_sampled) {
if (t != nullptr) {
ggml_set_output(t);
Expand Down
6 changes: 5 additions & 1 deletion src/models/qwen35.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
}

if (il == n_transformer_layers - 1 && inp_out_ids) {
if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
Expand Down Expand Up @@ -211,6 +211,10 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
cb(cur, "h_pre_norm", -1);
res->t_h_pre_norm = cur;

if (!cparams.embeddings_pre_norm_masked && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
}

// Final norm
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);

Expand Down
6 changes: 5 additions & 1 deletion src/models/qwen35moe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
}

if (il == n_transformer_layers - 1 && inp_out_ids) {
if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
Expand Down Expand Up @@ -234,6 +234,10 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
cb(cur, "h_pre_norm", -1);
res->t_h_pre_norm = cur;
Comment thread
am17an marked this conversation as resolved.

if (!cparams.embeddings_pre_norm_masked && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
}

// Final norm
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);

Expand Down
5 changes: 5 additions & 0 deletions tools/server/server-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,11 @@ struct server_slot {
return task->need_embd() || (spec && common_speculative_need_embd(spec));
}

bool need_embd_pre_norm() const {
GGML_ASSERT(task);
return spec && common_speculative_need_embd_pre_norm(spec);
}

// if the context does not have a memory module then all embeddings have to be computed within a single ubatch
// also we cannot split if the pooling would require any past tokens
// (MTP supports splitting — uses task->need_embd() not need_embd())
Expand Down
Loading