@@ -117,6 +117,13 @@ class SequenceGenerationResponse:
     token_id: int


+@dataclass
+class EvalQueryRequest:
+    request_id: int
+    num_past_tokens: int
+    query_token_ids: List[int]
+
+
 def sample(logits):
     logits = torch.from_dlpack(logits)
     return torch.argmax(logits, -1).cpu().numpy()
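For reference, an `EvalQueryRequest` splits a sequence into the tokens whose KV entries are already in the cache (`num_past_tokens`) and the trailing tokens whose logits we want evaluated in one batched pass. A minimal sketch of constructing one, using the class added above (the token values are invented for illustration):

```python
# Hypothetical sequence of 6 tokens: the first 4 are already cached,
# the last 2 are the "queries" whose logits we want back.
token_ids = [5, 8, 13, 21, 34, 55]
req = EvalQueryRequest(
    request_id=0,
    num_past_tokens=4,
    query_token_ids=token_ids[-2:],  # [34, 55]
)
```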
@@ -241,6 +248,72 @@ def _pad_to_max(x: List[int], max_len: int) -> List[int]:
     )


+def _prepare_eval_queries(
+    requests: List[EvalQueryRequest],
+    all_slot_mappings,
+    sliding_window,
+    dev,
+):
+    seq_lens = []
+    query_lens = []
+    input_ids = []
+    slot_mapping = []
+    past_slot_mapping = []
+    positions = []
+    permute_map = []
+
+    # Query tokens are laid out after all past tokens, so the first query
+    # offset is the total number of past tokens across all requests.
+    query_offset = sum([request.num_past_tokens for request in requests])
+    past_offset = 0
+
+    for request in requests:
+        num_past_tokens = request.num_past_tokens
+        num_queries = len(request.query_token_ids)
+        query_lens.append(num_queries)
+        request_id = request.request_id
+        input_ids += request.query_token_ids
+
+        positions += [num_past_tokens + i for i in range(num_queries)]
+
+        if sliding_window and num_past_tokens + num_queries >= sliding_window:
+            seq_lens.append(sliding_window)
+            past_slot_mapping += all_slot_mappings[request_id][
+                num_past_tokens - (sliding_window - num_queries) : num_past_tokens
+            ]
+        else:
+            seq_lens.append(num_past_tokens + num_queries)
+            past_slot_mapping += all_slot_mappings[request_id][:num_past_tokens]
+
+        slot_mapping += all_slot_mappings[request_id][
+            num_past_tokens : num_past_tokens + num_queries
+        ]
+
+        # Reorder the (all pasts, then all queries) concatenation back into
+        # per-request [past, query] order.
+        permute_map += list(range(past_offset, past_offset + num_past_tokens)) + list(
+            range(query_offset, query_offset + num_queries)
+        )
+
+        query_offset += num_queries
+        past_offset += num_past_tokens
+
+    input_ids = tvm.nd.array(np.array(input_ids, dtype="int32"), dev)
+    positions = tvm.nd.array(np.array(positions, dtype="int32"), dev)
+    seq_lens = tvm.nd.array(np.array(seq_lens, dtype="int32"), dev)
+    slot_mapping = tvm.nd.array(np.array(slot_mapping, dtype="int32"), dev)
+
+    query_lens = tvm.nd.array(np.array(query_lens, dtype="int32"), dev)
+    past_slot_mapping = tvm.nd.array(np.array(past_slot_mapping, dtype="int32"), dev)
+    permute_map = tvm.nd.array(np.array(permute_map, dtype="int32"), dev)
+
+    return (
+        input_ids,
+        positions,
+        seq_lens,
+        slot_mapping,
+        query_lens,
+        past_slot_mapping,
+        permute_map,
+    )
+
+
 class Model:
     def __init__(
         self, artifact_path, model_name, quant, vocab_size, num_shards, dev, sliding_window
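The one non-obvious output here is `permute_map`. Past keys/values for all requests appear to be gathered first, followed by all query tokens, and the map reorders that concatenation into per-request `[past, query]` order. A minimal pure-Python sketch of just that indexing, with invented toy sizes (2 past + 2 query tokens for request 0, 3 past + 1 query for request 1):

```python
num_past = [2, 3]
num_queries = [2, 1]

permute_map = []
query_offset = sum(num_past)  # queries start after all past tokens
past_offset = 0

for p, q in zip(num_past, num_queries):
    # this request's past tokens, followed by its query tokens
    permute_map += list(range(past_offset, past_offset + p))
    permute_map += list(range(query_offset, query_offset + q))
    query_offset += q
    past_offset += p

print(permute_map)  # [0, 1, 5, 6, 2, 3, 4, 7]
```

So for a flat layout `[past_0, past_1, query_0, query_1]`, the map recovers `[past_0, query_0, past_1, query_1]`, which is what a per-sequence attention kernel would want to consume.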
@@ -443,6 +516,59 @@ def run(args):
     for p, g in zip(prompts, generated):
         print("Prompt = '{}', generated text = '{}'".format(p, g))

+    query_token_lens = [4, 3, 5, 2]
+
+    eval_query_requests = []
+
+    for request_id, query_token_len in zip(request_ids, query_token_lens):
+        queries_to_eval = requests[request_id].token_ids[-query_token_len:]
+        num_past = len(requests[request_id].token_ids) - query_token_len
+        eval_query_requests.append(EvalQueryRequest(request_id, num_past, queries_to_eval))
+
+    (
+        input_ids,
+        positions,
+        seq_lens,
+        slot_mapping,
+        query_lens,
+        past_slot_mapping,
+        permute_map,
+    ) = _prepare_eval_queries(
+        eval_query_requests,
+        cache.slot_mappings,
+        None,
+        model.dev,
+    )
+
+    logits = model.mod["evaluate_multi_query"](
+        input_ids,
+        positions,
+        seq_lens,
+        cache.cache,
+        slot_mapping,
+        query_lens,
+        past_slot_mapping,
+        permute_map,
+        model.params,
+    )[0].numpy()
+
+    assert logits.shape[0] == sum(query_token_lens)
+
+    logits_offset = 0
+
+    for request_id, query_token_len in zip(request_ids, query_token_lens):
+        for i in range(query_token_len - 1):
+            # requests[request_id].token_ids[-query_token_len:] are the "ground truth" tokens.
+            # Doing argmax over multi-timestep logits computed in parallel should yield the
+            # same tokens at the corresponding positions.
+            past_tokens = requests[request_id].token_ids[:-query_token_len]
+            assert (
+                np.argmax(logits[logits_offset + i])
+                == requests[request_id].token_ids[len(past_tokens) + i + 1]
+            )
+
+        logits_offset += query_token_len
+

 if __name__ == "__main__":
     run(parse_args())
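To make the indexing in that verification loop concrete: the logits come back as one row per query token, with each request's rows starting at the cumulative sum of the preceding `query_token_lens`. A small sketch of the layout, assuming the lengths used above:

```python
import numpy as np

query_token_lens = [4, 3, 5, 2]
# Row ranges per request in the stacked logits: [0,4), [4,7), [7,12), [12,14).
offsets = np.concatenate(([0], np.cumsum(query_token_lens)[:-1]))
print(offsets)  # [ 0  4  7 12]

# Row offsets[k] + i holds the logits for the query token at position
# num_past + i of request k, so its argmax is checked against the ground-truth
# token at position num_past + i + 1. The final row of each request predicts a
# token beyond the known sequence, hence the range(query_token_len - 1) bound.
```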