Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 56 additions & 49 deletions tools/server/server-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2928,8 +2928,10 @@ struct server_context_impl {
has_mtmd = true;
}

const int32_t n_before_user = slot.task->params.n_before_user;
const bool n_before_user_known = n_before_user > 0;
const auto & user_boundaries = slot.task->params.user_boundaries;
const auto is_user_boundary = [&user_boundaries](int32_t pos) {
return std::binary_search(user_boundaries.begin(), user_boundaries.end(), pos);
};

// add prompt tokens for processing in the current batch
while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) {
Expand Down Expand Up @@ -2959,10 +2961,8 @@ struct server_context_impl {

slot.n_prompt_tokens_processed++;

// stop the prompt batch exactly before the latest user input, so a checkpoint
// can be created after the previous messages
if (n_before_user_known &&
slot.prompt.n_tokens() == n_before_user) {
// stop the prompt batch before each user message so a checkpoint can be created
if (is_user_boundary((int32_t) slot.prompt.n_tokens())) {
break;
}

Expand Down Expand Up @@ -3008,7 +3008,7 @@ struct server_context_impl {
slot.init_sampler();
} else {
// skip ordinary mid-prompt checkpoints
if (!n_before_user_known && !near_prompt_end) {
if (user_boundaries.empty() && !near_prompt_end) {
do_checkpoint = false;
}
}
Expand All @@ -3020,21 +3020,12 @@ struct server_context_impl {
// their token position is the batch start rather than the prompt end
const int32_t n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur;

{
const bool is_on_user =
n_before_user_known &&
n_tokens_start == n_before_user;

const bool is_after_user =
n_before_user_known &&
n_tokens_start > n_before_user;

if (do_checkpoint && !user_boundaries.empty()) {
const bool is_allowed =
!n_before_user_known ||
is_on_user ||
(is_after_user && near_prompt_end);
is_user_boundary(n_tokens_start) ||
(n_tokens_start > user_boundaries.back() && near_prompt_end);

if (do_checkpoint && !is_allowed) {
if (!is_allowed) {
do_checkpoint = false;
}
}
Expand Down Expand Up @@ -3566,48 +3557,64 @@ void server_context::on_sleeping_changed(std::function<void(bool)> callback) {
impl->queue_tasks.on_sleeping_state(std::move(callback));
}

// compute the number of tokens before the last user message in the prompt
static int32_t prompt_get_n_before_user(
const json & message_spans,
static int32_t prompt_n_tokens_before_byte(
int32_t byte_pos,
const std::string & prompt,
const std::vector<raw_buffer> & files,
const llama_vocab * vocab,
mtmd_context * mctx) {
int32_t result = -1;
int32_t byte_pos = -1;
GGML_ASSERT(byte_pos >= 0 && (size_t) byte_pos <= prompt.size());

for (const auto & span : message_spans) {
const std::string role = json_value(span, "role", std::string());
const std::string prefix = prompt.substr(0, (size_t) byte_pos);

if (role == "user") {
byte_pos = json_value(span, "pos", -1);
}
const std::string marker = get_media_marker();
size_t n_prefix_media = 0;
for (size_t pos = 0; (pos = prefix.find(marker, pos)) != std::string::npos; pos += marker.size()) {
n_prefix_media++;
}

if (byte_pos >= 0) {
GGML_ASSERT((size_t) byte_pos <= prompt.size());
GGML_ASSERT(n_prefix_media <= files.size());

const std::string prefix = prompt.substr(0, (size_t) byte_pos);
if (mctx != nullptr && n_prefix_media > 0) {
// TODO: this makes a copy - avoid it
std::vector<raw_buffer> prefix_files(files.begin(), files.begin() + n_prefix_media);
return (int32_t) process_mtmd_prompt(mctx, prefix, prefix_files).size();
}

const std::string marker = get_media_marker();
size_t n_prefix_media = 0;
for (size_t pos = 0; (pos = prefix.find(marker, pos)) != std::string::npos; pos += marker.size()) {
n_prefix_media++;
}
return (int32_t) tokenize_input_prompts(vocab, nullptr, prefix, true, true)[0].size();
}

// compute the number of tokens before each user message in the prompt
static std::vector<int32_t> prompt_get_user_boundaries(
const json & message_spans,
const std::string & prompt,
const std::vector<raw_buffer> & files,
const llama_vocab * vocab,
mtmd_context * mctx) {
std::vector<int32_t> result;
result.reserve(message_spans.size());

GGML_ASSERT(n_prefix_media <= files.size());
for (const auto & span : message_spans) {
if (json_value(span, "role", std::string()) != "user") {
continue;
}

if (mctx != nullptr && n_prefix_media > 0) {
// TODO: this makes a copy - avoid it
std::vector<raw_buffer> prefix_files(files.begin(), files.begin() + n_prefix_media);
const int32_t byte_pos = json_value(span, "pos", -1);
if (byte_pos < 0) {
continue;
}

result = (int32_t) process_mtmd_prompt(mctx, prefix, prefix_files).size();
} else {
result = (int32_t) tokenize_input_prompts(vocab, nullptr, prefix, true, true)[0].size();
const int32_t n_tok = prompt_n_tokens_before_byte(byte_pos, prompt, files, vocab, mctx);
if (n_tok > 0) {
result.push_back(n_tok);
}
}
Comment thread
mfielding92 marked this conversation as resolved.

std::sort(result.begin(), result.end());
result.erase(std::unique(result.begin(), result.end()), result.end());

SRV_TRC("message_spans: last user message: byte_pos=%d, media=%zu, n_before_user=%d\n",
byte_pos, n_prefix_media, result);
if (!result.empty()) {
SRV_TRC("message_spans: %zu user turn boundary(ies)\n", result.size());
}

return result;
Expand Down Expand Up @@ -3665,8 +3672,8 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(

const auto message_spans = json_value(data, "message_spans", json::array());
if (prompt.is_string() && message_spans.is_array()) {
task.params.n_before_user =
prompt_get_n_before_user(
task.params.user_boundaries =
prompt_get_user_boundaries(
message_spans,
prompt.get<std::string>(),
files,
Expand Down
4 changes: 2 additions & 2 deletions tools/server/server-task.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ struct task_params {

int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled)

// number of prompt tokens before the latest user message
int32_t n_before_user = -1;
// number of prompt tokens before each user message
std::vector<int32_t> user_boundaries;

int64_t t_max_prompt_ms = -1; // TODO: implement
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
Expand Down