Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions src/llama-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1026,11 +1026,7 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
llama_sampler_chain_n(sampler) > 0;

if (sampler && can_offload) {
ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output());
auto * host_buft = ggml_backend_dev_host_buffer_type(model.dev_output());
if (host_buft) {
buft = host_buft;
}
auto * buft = ggml_backend_dev_buffer_type(model.dev_output());

sampler->iface->backend_init(sampler, buft);

Expand Down
19 changes: 13 additions & 6 deletions src/llama-graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2419,6 +2419,9 @@ void llm_graph_context::build_sampling() const {
return;
}

std::array<ggml_tensor *, 2> outs;
outs[0] = res->t_logits;

auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
res->add_input(std::move(inp_sampling));

Expand All @@ -2439,14 +2442,14 @@ void llm_graph_context::build_sampling() const {
// add a dummy row of logits
// this trick makes the graph static, regardless of which samplers are activated
// this is important in order to minimize graph reallocations
// TODO: use `ggml_build_forward_select()` when available (https://github.com/ggml-org/llama.cpp/pull/18550)
ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0);

for (const auto & [seq_id, sampler] : samplers) {
const auto it = seq_to_logit_row.find(seq_id);

// inactive samplers always work on the first row
const auto row_idx = seq_to_logit_row.find(seq_id) != seq_to_logit_row.end() ? it->second : 0;
const auto row_idx = it != seq_to_logit_row.end() ? it->second : 0;
const int i_out = it != seq_to_logit_row.end() ? 1 : 0;

ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
Expand All @@ -2463,22 +2466,26 @@ void llm_graph_context::build_sampling() const {

if (data.sampled != nullptr) {
res->t_sampled[seq_id] = data.sampled;
ggml_build_forward_expand(gf, data.sampled);
outs[1] = data.sampled;
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
}

if (data.probs != nullptr) {
res->t_sampled_probs[seq_id] = data.probs;
ggml_build_forward_expand(gf, data.probs);
outs[1] = data.probs;
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
}

if (data.logits != nullptr) {
res->t_sampled_logits[seq_id] = data.logits;
ggml_build_forward_expand(gf, data.logits);
outs[1] = data.logits;
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
}

if (data.candidates != nullptr) {
res->t_candidates[seq_id] = data.candidates;
ggml_build_forward_expand(gf, data.candidates);
outs[1] = data.candidates;
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
}
}

Expand Down
81 changes: 19 additions & 62 deletions src/llama-sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1025,11 +1025,7 @@ struct llama_sampler_dist : public llama_sampler_backend {

std::mt19937 rng;

// backend input
struct ggml_tensor * inp_uniform;

ggml_context_ptr inp_ctx;
ggml_backend_buffer_ptr inp_buf;
ggml_tensor * inp_uniform;
};

static const char * llama_sampler_dist_name(const struct llama_sampler * smpl) {
Expand Down Expand Up @@ -1138,37 +1134,10 @@ static bool llama_sampler_dist_backend_init(
ggml_backend_buffer_type_t buft) {
auto * sctx = (llama_sampler_dist *) smpl->ctx;

// allocate inputs
{
ggml_init_params params = {
/*.mem_size =*/ ggml_tensor_overhead(),
/*.mem_buffer =*/ nullptr,
/*.no_alloc =*/ true,
};

sctx->inp_ctx.reset(ggml_init(params));

// Create the uniform random scalar input tensor. This will be set by
// llama_sampler_dist_backend_set_input after this graph is built.
sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1);
ggml_set_name (sctx->inp_uniform, "uniform");
ggml_set_input(sctx->inp_uniform);

// Allocate all tensors from our context to the backend
sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));

ggml_backend_buffer_clear(sctx->inp_buf.get(), 0);
}

const bool res = llama_sampler_backend_support(smpl, buft);

sctx->init(res);

if (!res) {
sctx->inp_ctx.reset(nullptr);
sctx->inp_buf.reset(nullptr);
}

return res;
}

Expand All @@ -1178,8 +1147,13 @@ static void llama_sampler_dist_backend_apply(
struct ggml_cgraph * gf,
struct llama_sampler_data * data) {
GGML_UNUSED(gf);

auto * sctx = (llama_sampler_dist *) smpl->ctx;

sctx->inp_uniform = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
ggml_set_name (sctx->inp_uniform, "uniform");
ggml_set_input(sctx->inp_uniform);

struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
ggml_set_name(probs, "dist_probs");

Expand Down Expand Up @@ -1226,6 +1200,7 @@ static void llama_sampler_dist_backend_apply(

static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) {
auto * sctx = (llama_sampler_dist *) smpl->ctx;

GGML_ASSERT(sctx->inp_uniform != nullptr);

// We sample in double precision and cast to float to match rnd numbers of
Expand Down Expand Up @@ -1262,8 +1237,6 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
/* .seed_cur = */ seed_cur,
/* .rng = */ std::mt19937(seed_cur),
/* .inp_uniform = */ nullptr,
/* .inp_ctx = */ nullptr,
/* .inp_buf = */ nullptr,
}
);
}
Expand Down Expand Up @@ -3461,9 +3434,6 @@ struct llama_sampler_logit_bias : public llama_sampler_backend {

struct ggml_tensor * inp_logit_bias;
struct ggml_tensor * inp_logit_idxs;

ggml_context_ptr inp_ctx;
ggml_backend_buffer_ptr inp_buf;
};

static const char * llama_sampler_logit_bias_name(const struct llama_sampler * smpl) {
Expand Down Expand Up @@ -3526,6 +3496,16 @@ static void llama_sampler_logit_bias_backend_apply(
return;
}

const size_t n = sctx->logit_bias.size();

sctx->inp_logit_bias = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n);
ggml_set_name(sctx->inp_logit_bias, "logit_bias");
ggml_set_input(sctx->inp_logit_bias);

sctx->inp_logit_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n);
ggml_set_name(sctx->inp_logit_idxs, "logit_idxs");
ggml_set_input(sctx->inp_logit_idxs);

ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f);

cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur));
Expand Down Expand Up @@ -3562,6 +3542,8 @@ static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * sm
static bool llama_sampler_logit_bias_backend_init(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
GGML_UNUSED(buft);

auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;

sctx->init(true);
Expand All @@ -3570,29 +3552,6 @@ static bool llama_sampler_logit_bias_backend_init(
return true;
}

ggml_init_params params = {
/*.mem_size =*/ 2*ggml_tensor_overhead(),
/*.mem_buffer =*/ nullptr,
/*.no_alloc =*/ true,
};

sctx->inp_ctx.reset(ggml_init(params));

const size_t n = sctx->logit_bias.size();

sctx->inp_logit_bias = ggml_new_tensor_2d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1, n);
ggml_set_name(sctx->inp_logit_bias, "logit_bias");
ggml_set_input(sctx->inp_logit_bias);

sctx->inp_logit_idxs = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_I32, n);
ggml_set_name(sctx->inp_logit_idxs, "logit_idxs");
ggml_set_input(sctx->inp_logit_idxs);

// Allocate all tensors from our context to the backend
sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));

ggml_backend_buffer_clear(sctx->inp_buf.get(), 0);

return true;
}

Expand Down Expand Up @@ -3628,8 +3587,6 @@ struct llama_sampler * llama_sampler_init_logit_bias(
/* .to_search = */ {},
/* .inp_logit_bias = */ nullptr,
/* .inp_logit_idxs = */ nullptr,
/* .inp_ctx = */ nullptr,
/* .inp_buf = */ nullptr,
}
);
}
Expand Down
Loading