@@ -11,7 +11,7 @@
 import time
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional, Tuple, List
+from typing import List, Optional, Tuple
 
 import torch
 import torch._dynamo.config
@@ -32,6 +32,7 @@
 B_INST, E_INST = "[INST]", "[/INST]"
 B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"
 
+
 class ChatFormat:
     def __init__(self, tokenizer):
         self.tokenizer = tokenizer
@@ -62,7 +63,6 @@ def encode_dialog_prompt(self, dialog) -> List[int]:
         return tokens
 
 
-
 @dataclass
 class GeneratorArgs:
     prompt: str = "torchchat is pronounced torch-chat and is so cool because"
@@ -210,11 +210,17 @@ def decode_n_tokens(
 ):
     new_tokens, new_probs = [], []
     encountered_eos = False
-    for i in range(num_new_tokens - 1):  # -1 to save space to run an EoS if dont generate it naturally
+    for i in range(
+        num_new_tokens - 1
+    ):  # -1 to save space to run an EoS if dont generate it naturally
         # Actually better for Inductor to codegen attention here
         with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
             next_token, next_prob = decode_one_token(
-                model, cur_token.clone(), input_pos, need_probs=need_probs, **sampling_kwargs
+                model,
+                cur_token.clone(),
+                input_pos,
+                need_probs=need_probs,
+                **sampling_kwargs,
             )
             input_pos += 1
             new_tokens.append(next_token.clone())
@@ -223,15 +229,25 @@ def decode_n_tokens(
                 new_probs.append(next_prob.clone())
             cur_token = next_token.view(1, -1)
             # encountered eos
-            if (next_token.item() == eos_token_id or (eot_id is not None and next_token.item() == eot_id)):
+            if next_token.item() == eos_token_id or (
+                eot_id is not None and next_token.item() == eot_id
+            ):
                 encountered_eos = True
-                _, _ = decode_one_token(model, cur_token, input_pos, need_probs, **sampling_kwargs)
+                _, _ = decode_one_token(
+                    model, cur_token, input_pos, need_probs, **sampling_kwargs
+                )
                 input_pos += 1
                 break
     if not encountered_eos:
-        eos_token = torch.tensor([eos_token_id if eot_id is None else eot_id], dtype=cur_token.dtype, device=cur_token.device)
+        eos_token = torch.tensor(
+            [eos_token_id if eot_id is None else eot_id],
+            dtype=cur_token.dtype,
+            device=cur_token.device,
+        )
         new_tokens.append(eos_token.clone())
-        _, _ = decode_one_token(model, eos_token.view(1, -1), input_pos, need_probs, **sampling_kwargs)
+        _, _ = decode_one_token(
+            model, eos_token.view(1, -1), input_pos, need_probs, **sampling_kwargs
+        )
         input_pos += 1
 
     return new_tokens, new_probs
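Note on the hunk above: the loop runs `num_new_tokens - 1` times so that one slot stays reserved for an end-of-sequence token, which is appended if the model never produces one on its own. A standalone toy sketch of that reservation pattern; the `decode_one` stand-in and token ids below are made up for illustration and are not the torchchat API:

```python
import torch

def decode_n_toy(decode_one, cur_token, num_new_tokens, eos_token_id):
    """Toy illustration: generate up to num_new_tokens, always ending with EOS."""
    new_tokens = []
    encountered_eos = False
    for _ in range(num_new_tokens - 1):  # reserve one slot for a forced EOS
        next_token = decode_one(cur_token)
        new_tokens.append(next_token.clone())
        cur_token = next_token.view(1, -1)
        if next_token.item() == eos_token_id:
            encountered_eos = True
            break
    if not encountered_eos:
        # the model never stopped on its own: fill the reserved slot with EOS
        new_tokens.append(torch.tensor([eos_token_id], dtype=cur_token.dtype))
    return new_tokens

# usage with a fake sampler that cycles through ids 3..7 and never emits EOS (id 99)
fake = lambda tok: (tok + 1) % 5 + 3
out = decode_n_toy(fake, torch.tensor([[3]]), num_new_tokens=6, eos_token_id=99)
print([t.item() for t in out])  # 5 sampled ids followed by the forced EOS id 99
```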
@@ -337,7 +353,9 @@ def generate(
     with torch.device(device):
         model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
         if is_speculative and draft_model is not model:
-            draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+            draft_model.setup_caches(
+                max_batch_size=1, max_seq_length=max_seq_length
+            )
 
     # create an empty tensor of the expected final shape and
     # fill in the current tokens
@@ -366,7 +384,9 @@ def generate(
 
     num_tokens_generated = 0
     input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int)
-    accept_counts = [0] * (speculate_k + 1)  # creates array of [0, 0, 0, ...] that is speculate_k + 1 long
+    accept_counts = [0] * (
+        speculate_k + 1
+    )  # creates array of [0, 0, 0, ...] that is speculate_k + 1 long
 
     if is_speculative:
         input_pos = input_pos.item()  # for speculative decoding easier to keep on host
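For context on `accept_counts` above: it acts as a histogram indexed by how many of the `speculate_k` draft tokens were accepted in a speculative step (0 through `speculate_k`, hence length `speculate_k + 1`). A hedged sketch of how such a histogram is typically filled and summarized; the step results below are invented, not torchchat output:

```python
speculate_k = 4
accept_counts = [0] * (speculate_k + 1)  # slots for 0..speculate_k accepted drafts

# pretend results of a few speculative steps: draft tokens accepted each time
accepted_per_step = [4, 2, 0, 4, 3, 4]
for n_accepted in accepted_per_step:
    accept_counts[n_accepted] += 1

total_steps = sum(accept_counts)
# each step typically yields the accepted draft tokens plus one token from the target model
tokens_emitted = sum(n * count for n, count in enumerate(accept_counts)) + total_steps
mean_accepted = sum(n * count for n, count in enumerate(accept_counts)) / total_steps
print(accept_counts)           # [1, 0, 1, 1, 3]
print(tokens_emitted)          # 23
print(f"{mean_accepted:.2f}")  # 2.83
```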
@@ -392,12 +412,14 @@ def generate(
             max_new_tokens - 1,
             callback=callback,
             need_probs=False,
-            eos_token_id = tokenizer.eos_id() if tokenizer else 2,
-            eot_id = tokenizer.special_tokens["<|eot_id|>"] if is_llama3_model else None,
+            eos_token_id=tokenizer.eos_id() if tokenizer else 2,
+            eot_id=tokenizer.special_tokens["<|eot_id|>"] if is_llama3_model else None,
             **sampling_kwargs,
         )
         seq[T + 1 : T + 1 + len(generated_tokens)] = torch.cat(generated_tokens)
-        seq = seq[:T + 1 + len(generated_tokens)]  # If we dont generate all the way to max_new_tokens slice off the extra space we allocated.
+        seq = seq[
+            : T + 1 + len(generated_tokens)
+        ]  # If we dont generate all the way to max_new_tokens slice off the extra space we allocated.
 
     generate_stats = {"accept_counts": accept_counts}
     return seq, generate_stats
@@ -410,7 +432,6 @@ def encode_tokens(tokenizer, string, bos=True, device="cpu"):
     return torch.tensor(tokens, dtype=torch.int, device=device)
 
 
-
 def get_device_info(name: str) -> str:
     import platform
     from subprocess import check_output
@@ -481,7 +502,9 @@ def _main(
     # Piggy backing off of this flag then for now to identify llama3 without prompting user.
     is_llama3_model = tokenizer_args.is_tiktoken
     if generator_args.chat_mode and is_llama3_model:
-        logging.debug("Llama3 model detected in chat mode. Using updated sentence schemas")
+        logging.debug(
+            "Llama3 model detected in chat mode. Using updated sentence schemas"
+        )
 
     builder_args.setup_caches = False
     model = _initialize_model(builder_args, quantize, tokenizer)
@@ -534,20 +557,29 @@ def _main(
     if generator_args.compile_prefill:
         prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
 
-    system_prompt = None
+    system_prompt = None
     # Set up our max_seq_length
     if generator_args.chat_mode:
-        max_seq_length = 2048
-        print(f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye")
-        system_prompt = input("System Prompt [Optional]: ")
+        max_seq_length = model.config.max_seq_length
+        print(
+            f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye"
+        )
+        get_system_prompt = input(
+            "Do you want to enter a system prompt? Enter y for yes and anything else for no. \n"
+        )
+        if get_system_prompt == "y" or get_system_prompt == "Y":
+            system_prompt = input("What is your system prompt? \n")
         if is_llama3_model:
             chat_formatter = ChatFormat(tokenizer)
     else:
-        max_seq_length = min(encoded.size(0) + generator_args.max_new_tokens, model.config.block_size)
-
+        max_seq_length = min(
+            encoded.size(0) + generator_args.max_new_tokens, model.config.block_size
+        )
 
     max_seq_length = (
-        max_seq_length + speculate_k + 1 if draft_model is not None else max_seq_length
+        max_seq_length + speculative_builder_args.speculate_k + 1
+        if draft_model is not None
+        else max_seq_length
     )
 
     aggregate_metrics = {
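To make the cache-sizing change above concrete: chat mode now uses the model's full context window instead of a hard-coded 2048, the non-chat path is the prompt length plus the new-token budget capped at `block_size`, and a draft model reserves `speculate_k + 1` extra positions (k drafts plus one verified token). A small illustrative restatement of that arithmetic with plain values, no torchchat objects:

```python
def plan_max_seq_length(chat_mode, prompt_len, max_new_tokens,
                        model_max_seq_length, block_size, speculate_k=None):
    """Illustrative restatement of the cache-sizing logic in this diff."""
    if chat_mode:
        max_seq_length = model_max_seq_length          # whole context window
    else:
        max_seq_length = min(prompt_len + max_new_tokens, block_size)
    if speculate_k is not None:                        # draft model present
        max_seq_length += speculate_k + 1              # k drafts + 1 verified token
    return max_seq_length

print(plan_max_seq_length(False, 35, 200, 8192, 2048))                # 235
print(plan_max_seq_length(False, 2000, 200, 8192, 2048))              # 2048 (capped)
print(plan_max_seq_length(True, 35, 200, 8192, 2048, speculate_k=5))  # 8198
```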
@@ -557,39 +589,59 @@ def _main(
     start = -1 if generator_args.compile else 0
     start_pos = 0
 
-
     # arbitrarily large number as chat mode goes until max_seq length or user exits
     num_samples = generator_args.num_samples if not generator_args.chat_mode else 100000
-    i = -1  # long loop and Im scared someone will add a continue in it, so start at -1 and increment at the start
-    while (i < num_samples):
+    i = (
+        -1
+    )  # long loop and Im scared someone will add a continue in it, so start at -1 and increment at the start
+    while i < num_samples:
         i += 1
         device_sync(device=builder_args.device)
         if i >= 0 and generator_args.chat_mode:
             prompt = input("User: ")
-            if (prompt == "/bye"):
+            if prompt == "/bye":
                 print("Exiting Chat.\n")
                 break
             if not is_llama3_model:
                 if system_prompt:
                     prompt = f"{B_INST} {B_SYS}\n{system_prompt.strip()}\n{E_SYS}\n\n{prompt.strip} {E_INST}"
-                    system_prompt = None  # can only provide system prompt on first interaction
+                    system_prompt = (
+                        None  # can only provide system prompt on first interaction
+                    )
                 else:
                     prompt = f"{B_INST} {prompt.strip()} {E_INST}"
                 encoded = encode_tokens(
                     tokenizer, prompt, bos=True, device=builder_args.device
                 )
             else:
-                if system_prompt:
-                    encoded = chat_formatter.encode_dialog_prompt([{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}])
+                if system_prompt is not None:
+                    encoded = chat_formatter.encode_dialog_prompt(
+                        [
+                            {"role": "system", "content": system_prompt},
+                            {"role": "user", "content": prompt},
+                        ]
+                    )
                     system_prompt = None
-                elif (i == 0):
-                    encoded = chat_formatter.encode_dialog_prompt([{"role": "user", "content": prompt}])
+                elif i == 0:
+                    encoded = chat_formatter.encode_dialog_prompt(
+                        [{"role": "user", "content": prompt}]
+                    )
                 else:
-                    encoded = chat_formatter.encode_message({"role": "user", "content": prompt})
-                    encoded.extend(chat_formatter.encode_header({"role": "assistant", "content": ""}))
-                encoded = torch.tensor(encoded, dtype=torch.int, device=builder_args.device)
-                if (encoded.size(0) + start_pos > max_seq_length):
-                    print("This prompt would take us past the max_seq_length. Ending Conversation.")
+                    encoded = chat_formatter.encode_message(
+                        {"role": "user", "content": prompt}
+                    )
+                    encoded.extend(
+                        chat_formatter.encode_header(
+                            {"role": "assistant", "content": ""}
+                        )
+                    )
+                encoded = torch.tensor(
+                    encoded, dtype=torch.int, device=builder_args.device
+                )
+                if encoded.size(0) + start_pos > max_seq_length:
+                    print(
+                        "This prompt would take us past the max_seq_length. Ending Conversation."
+                    )
                     break
 
         if generator_args.chat_mode and i >= 0:
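The branch above keeps two prompt conventions side by side: non-llama3 models get the classic [INST]/<<SYS>> string wrapping, while llama3 models go through `ChatFormat` with role-tagged messages. A hedged sketch of what the llama2-style string looks like on the first and later turns, using plain string formatting only (the helper name is illustrative, not torchchat code):

```python
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"

def llama2_turn(prompt, system_prompt=None):
    """Wrap a user turn in the llama2 instruct format; system prompt only on turn one."""
    if system_prompt:
        return f"{B_INST} {B_SYS}\n{system_prompt.strip()}\n{E_SYS}\n\n{prompt.strip()} {E_INST}"
    return f"{B_INST} {prompt.strip()} {E_INST}"

print(llama2_turn("What is torchchat?", system_prompt="Answer briefly."))
print(llama2_turn("And how do I install it?"))

# llama3 models instead pass role-tagged messages, roughly:
dialog = [
    {"role": "system", "content": "Answer briefly."},
    {"role": "user", "content": "What is torchchat?"},
]
# encoded = chat_formatter.encode_dialog_prompt(dialog)  # as in the hunk above
```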
@@ -604,12 +656,17 @@ def callback(
             ):
                 if done_generating:
                     return
-                buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])  # I think this results in the first output token being dropped from the display which is wrong.
+                buffer.append(
+                    tokenizer.decode([period_id] + x.tolist())[1:]
+                )  # I think this results in the first output token being dropped from the display which is wrong.
                 if x.item() == tokenizer.eos_id():
                     done_generating = True
-                if (is_llama3_model and x.item() == tokenizer.special_tokens["<|eot_id|>"]):
+                if (
+                    is_llama3_model
+                    and x.item() == tokenizer.special_tokens["<|eot_id|>"]
+                ):
                     done_generating = True
-                    buffer = buffer[:-1] # drop the eot_id from the output buffer
+                    buffer = buffer[:-1]  # drop the eot_id from the output buffer
                 if len(buffer) == 4 or done_generating:
                     print("".join(buffer), end="", flush=True)
                     buffer.clear()
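On the callback hunk above: decoded pieces accumulate in `buffer` and are flushed to the terminal every four tokens or when generation ends, with the trailing `<|eot_id|>` kept out of the display. A minimal hedged sketch of that buffered-flush pattern using a plain list of strings, no tokenizer involved:

```python
def make_printer(flush_every=4, stop_piece="<|eot_id|>"):
    buffer = []

    def on_piece(piece):
        done = piece == stop_piece
        if not done:
            buffer.append(piece)       # normal pieces are buffered
        if len(buffer) == flush_every or done:
            print("".join(buffer), end="", flush=True)
            buffer.clear()             # flush every N pieces or at the end
        return done

    return on_piece

emit = make_printer()
for piece in ["Hello", ",", " world", "!", " Bye", ".", "<|eot_id|>"]:
    if emit(piece):
        print()  # newline once the stop token arrives
        break
```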
@@ -672,7 +729,7 @@ def callback(x):
         )
         logging.debug(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
 
-        if (start_pos >= max_seq_length):
+        if start_pos >= max_seq_length:
             print("Max Sequence Length Reached. Ending Conversation.")
             break
 