
Commit 66a2e53

Add new Relax function to the batched model for evaluating query tokens over multiple time steps in parallel (#156)
* add new model for evaluating logits over multiple queries using KV cache
* add test
* clean
* Only the number of past tokens is needed
* fix build
* fix
* correctly handle num_past_tokens > sliding_window case
1 parent 1dcb26d commit 66a2e53
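Not part of the commit itself, but a rough sketch of the call flow it adds, condensed from the example script changed below. `model`, `cache`, `requests`, and `request_ids` are assumed to come from that script; `evaluate_multi_query` is the new Relax entry point and returns one logits row per query token across all requests in a single pass.

# Sketch only — see examples/python/run_llama_batched_vllm.py below for the full version.
eval_query_requests = [
    EvalQueryRequest(
        request_id=rid,
        num_past_tokens=len(requests[rid].token_ids) - n,
        query_token_ids=requests[rid].token_ids[-n:],
    )
    for rid, n in zip(request_ids, [4, 3, 5, 2])
]

# Flatten all requests into the tensors the new Relax function expects.
(input_ids, positions, seq_lens, slot_mapping,
 query_lens, past_slot_mapping, permute_map) = _prepare_eval_queries(
    eval_query_requests, cache.slot_mappings, None, model.dev
)

# One logits row per query token, for every request, computed in parallel.
logits = model.mod["evaluate_multi_query"](
    input_ids, positions, seq_lens, cache.cache, slot_mapping,
    query_lens, past_slot_mapping, permute_map, model.params,
)[0].numpy()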

3 files changed: +365 −61 lines


examples/python/run_llama_batched_vllm.py

Lines changed: 126 additions & 0 deletions
@@ -117,6 +117,13 @@ class SequenceGenerationResponse:
     token_id: int
 
 
+@dataclass
+class EvalQueryRequest:
+    request_id: int
+    num_past_tokens: int
+    query_token_ids: List[int]
+
+
 def sample(logits):
     logits = torch.from_dlpack(logits)
     return torch.argmax(logits, -1).cpu().numpy()
@@ -241,6 +248,72 @@ def _pad_to_max(x: List[int], max_len: int) -> List[int]:
     )
 
 
+def _prepare_eval_queries(
+    requests: List[EvalQueryRequest],
+    all_slot_mappings,
+    sliding_window,
+    dev,
+):
+    seq_lens = []
+    query_lens = []
+    input_ids = []
+    slot_mapping = []
+    past_slot_mapping = []
+    positions = []
+    permute_map = []
+
+    query_offset = sum([request.num_past_tokens for request in requests])
+    past_offset = 0
+
+    for request in requests:
+        num_past_tokens = request.num_past_tokens
+        num_queries = len(request.query_token_ids)
+        query_lens.append(num_queries)
+        request_id = request.request_id
+        input_ids += request.query_token_ids
+
+        positions += [num_past_tokens + i for i in range(num_queries)]
+
+        if sliding_window and num_past_tokens + num_queries >= sliding_window:
+            seq_lens.append(sliding_window)
+            past_slot_mapping += all_slot_mappings[request_id][
+                num_past_tokens - (sliding_window - num_queries) : num_past_tokens
+            ]
+        else:
+            seq_lens.append(num_past_tokens + num_queries)
+            past_slot_mapping += all_slot_mappings[request_id][:num_past_tokens]
+
+        slot_mapping += all_slot_mappings[request_id][
+            num_past_tokens : num_past_tokens + num_queries
+        ]
+
+        permute_map += list(range(past_offset, past_offset + num_past_tokens)) + list(
+            range(query_offset, query_offset + num_queries)
+        )
+
+        query_offset += num_queries
+        past_offset += num_past_tokens
+
+    input_ids = tvm.nd.array(np.array(input_ids, dtype="int32"), dev)
+    positions = tvm.nd.array(np.array(positions, dtype="int32"), dev)
+    seq_lens = tvm.nd.array(np.array(seq_lens, dtype="int32"), dev)
+    slot_mapping = tvm.nd.array(np.array(slot_mapping, dtype="int32"), dev)
+
+    query_lens = tvm.nd.array(np.array(query_lens, dtype="int32"), dev)
+    past_slot_mapping = tvm.nd.array(np.array(past_slot_mapping, dtype="int32"), dev)
+    permute_map = tvm.nd.array(np.array(permute_map, dtype="int32"), dev)
+
+    return (
+        input_ids,
+        positions,
+        seq_lens,
+        slot_mapping,
+        query_lens,
+        past_slot_mapping,
+        permute_map,
+    )
+
+
 class Model:
     def __init__(
         self, artifact_path, model_name, quant, vocab_size, num_shards, dev, sliding_window
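A note on the layout `_prepare_eval_queries` builds (reading of the code above, not part of the diff): the flattened stream is assumed to hold every request's past cache slots first, in request order, followed by all query tokens, and `permute_map` is the gather order that puts them back into contiguous per-request sequences — request 0's past then its queries, request 1's past then its queries, and so on. A tiny self-contained illustration with made-up sizes:

# Illustration only (not from the commit): permute_map for two hypothetical requests,
# ignoring the sliding_window branch.
num_past = [3, 2]     # past (cached) tokens per request
num_queries = [2, 1]  # query tokens per request

query_offset = sum(num_past)  # all query tokens sit after all past slots -> 5
past_offset = 0
permute_map = []
for p, q in zip(num_past, num_queries):
    permute_map += list(range(past_offset, past_offset + p))    # this request's past slots
    permute_map += list(range(query_offset, query_offset + q))  # then its query positions
    past_offset += p
    query_offset += q

print(permute_map)  # [0, 1, 2, 5, 6, 3, 4, 7]

When a sliding window is configured and num_past_tokens + num_queries reaches it, the function keeps only the most recent sliding_window - num_queries past slots and caps seq_lens at sliding_window; that is the num_past_tokens > sliding_window case called out in the commit message.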
@@ -443,6 +516,59 @@ def run(args):
     for p, g in zip(prompts, generated):
         print("Prompt = '{}', generated text = '{}'".format(p, g))
 
+    query_token_lens = [4, 3, 5, 2]
+
+    eval_query_requests = []
+
+    for request_id, query_token_len in zip(request_ids, query_token_lens):
+        queries_to_eval = requests[request_id].token_ids[-query_token_len:]
+        num_past = len(requests[request_id].token_ids) - query_token_len
+        eval_query_requests.append(EvalQueryRequest(request_id, num_past, queries_to_eval))
+
+    (
+        input_ids,
+        positions,
+        seq_lens,
+        slot_mapping,
+        query_lens,
+        past_slot_mapping,
+        permute_map,
+    ) = _prepare_eval_queries(
+        eval_query_requests,
+        cache.slot_mappings,
+        None,
+        model.dev,
+    )
+
+    logits = model.mod["evaluate_multi_query"](
+        input_ids,
+        positions,
+        seq_lens,
+        cache.cache,
+        slot_mapping,
+        query_lens,
+        past_slot_mapping,
+        permute_map,
+        model.params,
+    )[0].numpy()
+
+    assert logits.shape[0] == sum(query_token_lens)
+
+    logits_offset = 0
+
+    for request_id, query_token_len in zip(request_ids, query_token_lens):
+        for i in range(query_token_len - 1):
+            # requests[request_id].token_ids[-query_token_len:] are the "ground truth" tokens.
+            # Doing argmax over multi-timestep logits computed in parallel should yield the same
+            # tokens at the corresponding positions.
+            past_tokens = requests[request_id].token_ids[:-query_token_len]
+            assert (
+                np.argmax(logits[logits_offset + i])
+                == requests[request_id].token_ids[len(past_tokens) + i + 1]
+            )
+
+        logits_offset += query_token_len
+
 
 if __name__ == "__main__":
     run(parse_args())
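The consistency check at the end of run() leans on a simple indexing fact: each request's query rows are stacked back to back in the returned logits (hence the assert on sum(query_token_lens)), and row i of a request whose prefix has num_past tokens corresponds to absolute position num_past + i, so its argmax should reproduce token_ids[num_past + i + 1]. The inner loop stops at query_token_len - 1 because the final row predicts a token that is not part of the recorded sequence. A small standalone illustration of the offsets (hypothetical prefix lengths, no model involved):

# Illustration only: how logits rows map back to requests and predicted positions.
query_token_lens = [4, 3, 5, 2]        # from the example above
num_past_per_request = [8, 9, 7, 10]   # hypothetical prefix lengths

logits_offset = 0
for req, (q_len, num_past) in enumerate(zip(query_token_lens, num_past_per_request)):
    for i in range(q_len - 1):
        row = logits_offset + i
        predicted_pos = num_past + i + 1
        print(f"request {req}: logits row {row} should argmax to token_ids[{predicted_pos}]")
    logits_offset += q_len

# Total rows returned: sum(query_token_lens) == 14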

mlc_llm/core.py

Lines changed: 1 addition & 0 deletions
@@ -593,6 +593,7 @@ def mod_transform_before_build(
         # This is equivalent to prefill but without KV cache. It is used for
         # determining the number of paged cache blocks that can be allocated.
         model_names.append("evaluate")
+        model_names.append("evaluate_multi_query")
 
     if args.sep_embed:
         model_names = ["embed", "prefill_with_embed"] + model_names[1:]
