server: support multiple generations from one prompt (OAI "n" option) #17775
ngxson merged 11 commits into ggml-org:master from
Conversation
@allozaur @ServeurpersoCom one application of this feature can be having multiple response choices on the web UI. It's a low-priority feature, but I think it could be quite nice to add! Edit: we could technically also add per-response sampling control, for example one response with temperature=0.0 and another with 1.0; there are many possibilities, but we need to see what the use case is exactly. Example on ChatGPT:
Oh, absolutely! I would love to take over this one, maybe still this year?
yeah no rush! feel free to start the task as soon as this PR is merged
This is more of an idea than a desired feature, at least for the moment, but multiple generations from the same prompt would allow for "best-of-n" scenarios. optillm is a good example of this.
@ggerganov pinging in case you missed this PR
ggerganov left a comment:
Very nice! The implementation is much simpler than I anticipated.
```cpp
server_tokens server_tokens::clone() const {
    server_tokens res;
    res.has_mtmd = has_mtmd;
    res.tokens   = tokens;
    for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
        size_t idx = it->first;
        const mtmd::input_chunk_ptr & chunk = it->second;
        res.map_idx_to_media[idx] = mtmd::input_chunk_ptr(mtmd_input_chunk_copy(chunk.get()));
    }
    return res;
}
```
Now that we have this function, I think we can enable host-memory prompt caching with mtmd:
- Update this code:
llama.cpp/tools/server/server-task.cpp
Lines 1396 to 1404 in 6fb3226
- Remove this condition:
llama.cpp/tools/server/server-context.cpp
Lines 886 to 889 in 6fb3226
I haven't tested, but I think the only reason that prompt caching didn't work was because we weren't sure how to copy the server_tokens. So it's worth giving it a try after these changes.
Yes it will be nice to enable RAM cache for mtmd. I created an issue so we can have a look later on: #17821
```cpp
if (slot.is_parent() || slot.is_child()) {
    send_error(slot, "context shift cannot be used for shared prompt", ERROR_TYPE_SERVER);
    slot.release();
    continue;
}
```
Hm, what is the reason to not support context shift here?
Not quite sure about this, but IIUC llama_kv_cache::seq_add does not have a notion of copy-on-write. For example, if a KV cell is used by 2 sequences, one sequence shifting it will cause the second to be shifted as well.
This is fine if the current (generating) token position is synchronized among all sequences, but we don't have explicit logic to guarantee that this will always happen.
Also, the generation length of each sequence can be different, which can be quite difficult to keep track of.
I see, that is correct. The problem is that some of the tokens are shared when we use unified KV cache. It would work with split KV cache, but maybe it's not worth the extra logic branching.
Either way, context shifting is probably something that we should remove at some point - it does not have much value with today's models with more than 128k token contexts.
tools/server/server-context.cpp (outdated)

```cpp
    slot.copy_state_to(*child);
    child->state = SLOT_STATE_DONE_PROMPT;
}
slot.state = SLOT_STATE_DONE_PROMPT;
```

```cpp
states.push_back(child.params.oaicompat_chat_syntax);
tasks.push_back(std::move(child));
```
I think we should improve this by making tasks and states more associated with each other - it feels like this is currently error-prone because one might forget to update the states when adding a new task.
Does it make sense to have the task_result_state be part of the server_task itself?
> Does it make sense to have the task_result_state be part of the server_task itself?
The principle is that server_task will be std::moved to the task queue, and eventually moved to the slot, so it cannot hold task_result_state because the state needs to stay in the HTTP thread.
What I'm thinking is that we can just let server_response_reader create the state for each task, because currently tasks need to be posted via server_response_reader anyway.
Btw, the further plan is to only expose server_response_reader to HTTP handlers, as its API is easier to follow and it's also safer than managing the server_queue/response directly. WDYT?
I'll implement this in a follow-up PR
Edit: nvm it works - just use |
btw for /completions and /infill, I added support for both |
Do I understand correctly that with this change, instead of sending multiple separate requests with the same prompt, I can now send a single request and it will be faster?
Try it with -np, --parallel (not tested yet, I'm not sure)
Yes it will be faster - the |
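For reference, a request body using the OAI-style `n` option might look like the following (fields other than `n` are shown only for illustration; the server returns a single `choices` array with `n` entries, indexed 0..n-1):

```json
{
  "messages": [
    { "role": "user", "content": "Write a haiku about autumn." }
  ],
  "n": 2,
  "temperature": 0.8
}
```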
Great work on this PR! I can confirm parallel sequences work perfectly. Here's my test setup: The implementation correctly:
This is especially efficient when memory-bound since parallel batching allows better compute utilization while waiting for memory bandwidth: getting 3 to 4x total throughput!
…ggml-org#17775)
* backend support
* server: support multiple generations from one prompt (OAI "n" option)
* fix invalid batch
* format oai
* clean up
* disable ctx shift
* add test
* update comments
* fix style
* add n_cmpl to docs [no ci]
* allowing using both n_cmpl and n
I think there is a problem with this implementation - each of the parallel completions appears to process the same input prompt. I'll take a deeper look later to confirm, but from a quick test the computation is more than it should be (i.e. we compute the same prompt |
```cpp
bool is_child() const {
    return is_processing() && task->id_parent >= 0;
```

Suggested change:

```cpp
    return task->id_parent >= 0;
```
@ggerganov yeah right, I think the problem is here: we check is_child() to see if the slot should be set to a waiting state, but at the time of the check, is_processing() is false
Our return_progress = true can show the issue clearly with a script:

Fix #11142
Implementation
The requirement is that the number of slots must be equal to or larger than the number of "n" completion choices.
Put child slots into the SLOT_STATE_WAIT_OTHER state, then copy the parent's state into these slots via llama_memory_seq_cp.
TODO:
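A rough pseudocode sketch of that flow (assumes a unified KV cache where llama_memory_seq_cp is a metadata-only copy; the loop structure and variable names are illustrative, not the actual server code):

```cpp
// pseudocode — omits batching and error handling

// while the parent slot processes the shared prompt, park the children:
for (slot & child : child_slots) {
    child.state = SLOT_STATE_WAIT_OTHER;
}

// once the parent reaches SLOT_STATE_DONE_PROMPT:
for (slot & child : child_slots) {
    // copy the parent's KV cells for the prompt range into the child
    // sequence; with a unified KV cache the cells are shared, so this is
    // cheap and the prompt is only ever computed once
    llama_memory_seq_cp(mem, parent.seq_id, child.seq_id, 0, n_prompt_tokens);
    child.state = SLOT_STATE_DONE_PROMPT; // child can start decoding
}
```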