@@ -155,15 +155,23 @@ struct task_server {
     json data;
     bool infill_mode = false;
     bool embedding_mode = false;
+    int multitask_id = -1;
 };
 
 struct task_result {
     int id;
+    int multitask_id = -1;
     bool stop;
     bool error;
     json result_json;
 };
 
+struct task_multi {
+    int id;
+    std::set<int> subtasks_remaining{};
+    std::vector<task_result> results{};
+};
+
 // TODO: can become bool if we can't find use of more states
 enum slot_state
 {
@@ -406,6 +414,9 @@ struct llama_client_slot
     double t_prompt_processing; // ms
     double t_token_generation; // ms
 
+    // multitasks
+    int multitask_id = -1;
+
     void reset() {
         num_prompt_tokens = 0;
         generated_text = "";
@@ -529,7 +540,8 @@ struct llama_server_context
 
     std::vector<task_server> queue_tasks;
     std::vector<task_result> queue_results;
-    std::mutex mutex_tasks;
+    std::vector<task_multi> queue_multitasks;
+    std::mutex mutex_tasks; // also guards id_gen, and queue_multitasks
     std::mutex mutex_results;
 
     ~llama_server_context()
@@ -1112,17 +1124,40 @@ struct llama_server_context
         return slot.images.size() > 0;
     }
 
-    void send_error(int id, std::string error)
+    void send_error(task_server& task, std::string error)
     {
         std::lock_guard<std::mutex> lock(mutex_results);
         task_result res;
-        res.id = id;
+        res.id = task.id;
+        res.multitask_id = task.multitask_id;
         res.stop = false;
         res.error = true;
         res.result_json = { { "content", error } };
         queue_results.push_back(res);
     }
 
+    void add_multi_task(int id, std::vector<int>& sub_ids)
+    {
+        std::lock_guard<std::mutex> lock(mutex_tasks);
+        task_multi multi;
+        multi.id = id;
+        std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
+        queue_multitasks.push_back(multi);
+    }
+
+    void update_multi_task(int multitask_id, int subtask_id, task_result& result)
+    {
+        std::lock_guard<std::mutex> lock(mutex_tasks);
+        for (auto & multitask : queue_multitasks)
+        {
+            if (multitask.id == multitask_id)
+            {
+                multitask.subtasks_remaining.erase(subtask_id);
+                multitask.results.push_back(result);
+            }
+        }
+    }
+
     json get_model_props()
     {
         return get_formated_generation(slots[0]);
@@ -1167,6 +1202,7 @@ struct llama_server_context
         std::lock_guard<std::mutex> lock(mutex_results);
         task_result res;
         res.id = slot.task_id;
+        res.multitask_id = slot.multitask_id;
         res.error = false;
         res.stop = false;
 
@@ -1206,6 +1242,7 @@ struct llama_server_context
         std::lock_guard<std::mutex> lock(mutex_results);
         task_result res;
         res.id = slot.task_id;
+        res.multitask_id = slot.multitask_id;
         res.error = false;
         res.stop = true;
 
@@ -1251,6 +1288,12 @@ struct llama_server_context
             res.result_json["model"] = slot.oaicompat_model;
         }
 
+        // parent multitask, if any, needs to be updated
+        if (slot.multitask_id != -1)
+        {
+            update_multi_task(slot.multitask_id, slot.task_id, res);
+        }
+
         queue_results.push_back(res);
     }
 
@@ -1259,6 +1302,7 @@ struct llama_server_context
         std::lock_guard<std::mutex> lock(mutex_results);
         task_result res;
         res.id = slot.task_id;
+        res.multitask_id = slot.multitask_id;
         res.error = false;
         res.stop = true;
 
@@ -1285,16 +1329,26 @@ struct llama_server_context
         queue_results.push_back(res);
     }
 
-    int request_completion(json data, bool infill, bool embedding)
+    int request_completion(json data, bool infill, bool embedding, int multitask_id)
     {
-        std::lock_guard<std::mutex> lock(mutex_tasks);
+        std::unique_lock<std::mutex> lock(mutex_tasks);
         task_server task;
         task.id = id_gen++;
         task.target_id = 0;
         task.data = std::move(data);
         task.infill_mode = infill;
         task.embedding_mode = embedding;
         task.type = COMPLETION_TASK;
+        task.multitask_id = multitask_id;
+
+        // when a completion task's prompt array is not a singleton, we split it into multiple requests
+        if (task.data.at("prompt").size() > 1)
+        {
+            lock.unlock(); // entering new func scope
+            return split_multiprompt_task(task);
+        }
+
+        // otherwise, it's a single-prompt task, we actually queue it
         queue_tasks.push_back(task);
         return task.id;
     }
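
A rough caller-side sketch of how this branch behaves, assuming the nlohmann::json `json` alias and the `llama_server_context` instance (here called `llama`) that main() already uses; none of this is part of the diff itself:

    // Sketch only: which path request_completion() takes for the two prompt shapes.
    json single = { {"prompt", "Hello"}, {"n_predict", 8} };
    json multi  = { {"prompt", json::array({"Hello", "Bonjour"})}, {"n_predict", 8} };

    const int single_id = llama.request_completion(single, false, false, -1); // prompt has size 1: queued as one task
    const int multi_id  = llama.request_completion(multi,  false, false, -1); // prompt array has size 2: split into two
                                                                              // subtasks; the returned id is the multitask id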
@@ -1313,8 +1367,17 @@ struct llama_server_context
 
             for (int i = 0; i < (int) queue_results.size(); i++)
             {
+                // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
+                if (queue_results[i].multitask_id == task_id)
+                {
+                    update_multi_task(task_id, queue_results[i].id, queue_results[i]);
+                    queue_results.erase(queue_results.begin() + i);
+                    continue;
+                }
+
                 if (queue_results[i].id == task_id)
                 {
+                    assert(queue_results[i].multitask_id == -1);
                     task_result res = queue_results[i];
                     queue_results.erase(queue_results.begin() + i);
                     return res;
@@ -1404,6 +1467,27 @@ struct llama_server_context
         queue_tasks.push_back(task);
     }
 
+    int split_multiprompt_task(task_server& multiprompt_task)
+    {
+        auto prompt_count = multiprompt_task.data.at("prompt").size();
+        assert(prompt_count > 1);
+
+        int multitask_id = id_gen++;
+        std::vector<int> subtask_ids(prompt_count);
+        for (int i = 0; i < prompt_count; i++)
+        {
+            json subtask_data = multiprompt_task.data;
+            subtask_data["prompt"] = subtask_data["prompt"][i];
+
+            // subtasks inherit everything else (infill mode, embedding mode, etc.)
+            subtask_ids[i] = request_completion(subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
+        }
+
+        // queue up the multitask so we can track its subtask progression
+        add_multi_task(multitask_id, subtask_ids);
+        return multitask_id;
+    }
+
     void process_tasks()
     {
         std::lock_guard<std::mutex> lock(mutex_tasks);
@@ -1419,7 +1503,7 @@ struct llama_server_context
                     {
                         LOG_TEE("slot unavailable\n");
                         // send error result
-                        send_error(task.id, "slot unavailable");
+                        send_error(task, "slot unavailable");
                         return;
                     }
 
@@ -1433,11 +1517,12 @@ struct llama_server_context
                     slot->infill = task.infill_mode;
                     slot->embedding = task.embedding_mode;
                     slot->task_id = task.id;
+                    slot->multitask_id = task.multitask_id;
 
                     if (!launch_slot_with_data(slot, task.data))
                     {
                         // send error result
-                        send_error(task.id, "internal_error");
+                        send_error(task, "internal_error");
                         break;
                     }
                 } break;
@@ -1453,6 +1538,38 @@ struct llama_server_context
                 } break;
             }
         }
+
+        // remove finished multitasks from the queue of multitasks, and add the corresponding result to the result queue
+        auto queue_iterator = queue_multitasks.begin();
+        while (queue_iterator != queue_multitasks.end())
+        {
+            if (queue_iterator->subtasks_remaining.empty())
+            {
+                // all subtasks done == multitask is done
+                task_result aggregate_result;
+                aggregate_result.id = queue_iterator->id;
+                aggregate_result.stop = true;
+                aggregate_result.error = false;
+
+                // collect json results into one json result
+                std::vector<json> result_jsons;
+                for (auto & subres : queue_iterator->results)
+                {
+                    result_jsons.push_back(subres.result_json);
+                    aggregate_result.error = aggregate_result.error && subres.error;
+                }
+                aggregate_result.result_json = json{ "results", result_jsons };
+
+                std::lock_guard<std::mutex> lock(mutex_results);
+                queue_results.push_back(aggregate_result);
+
+                queue_iterator = queue_multitasks.erase(queue_iterator);
+            }
+            else
+            {
+                ++queue_iterator;
+            }
+        }
     }
 
     bool update_slots() {
@@ -2596,7 +2713,7 @@ int main(int argc, char **argv)
     svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, false, false);
+                const int task_id = llama.request_completion(data, false, false, -1);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.next_result(task_id);
@@ -2685,7 +2802,7 @@ int main(int argc, char **argv)
             {
                 json data = oaicompat_completion_params_parse(json::parse(req.body));
 
-                const int task_id = llama.request_completion(data, false, false);
+                const int task_id = llama.request_completion(data, false, false, -1);
 
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
@@ -2754,7 +2871,7 @@ int main(int argc, char **argv)
     svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, true, false);
+                const int task_id = llama.request_completion(data, true, false, -1);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.next_result(task_id);
@@ -2858,7 +2975,7 @@ int main(int argc, char **argv)
                 {
                     prompt = "";
                 }
-                const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0 } }, false, true);
+                const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0 } }, false, true, -1);
                 task_result result = llama.next_result(task_id);
                 return res.set_content(result.result_json.dump(), "application/json");
             });
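
For the non-streaming /completion path above, the flow for an array prompt looks roughly like this (a sketch under the same assumptions as before, not part of the diff; the aggregate result is produced by process_tasks() once every subtask has finished):

    // Sketch only: array prompt in, single aggregated result out.
    json data = json::parse(R"({"prompt": ["First prompt", "Second prompt"], "n_predict": 16})");
    const int task_id = llama.request_completion(data, false, false, -1); // returns a multitask id
    task_result result = llama.next_result(task_id); // yields only after both subtask results are folded in
    // result.result_json bundles the result_json of each subtask, as assembled in process_tasks().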