
Commit c750f24

mikekgfb authored and malfet committed
set buffer size to 8192 as default, decode precision as a string, lint (pytorch#476)
* set buffer size to 8192 as default, decode precision as a string, lint
* typo
* typo
* typo
1 parent 95ef489 commit c750f24

6 files changed: +114 -57 lines changed

build/builder.py

+4 -4

@@ -41,7 +41,7 @@ class BuilderArgs:
    def __post_init__(self):
        if self.device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
-
+
        if not (
            (self.checkpoint_path and self.checkpoint_path.is_file())
            or (self.checkpoint_dir and self.checkpoint_dir.is_dir())

@@ -408,10 +408,10 @@ def _initialize_model(
        print(f"Time to quantize model: {time.time() - t0q:.02f} seconds")

    if builder_args.setup_caches:
-        # TODO: get this from args?
-        max_seq_length = 2048
        with torch.device(builder_args.device):
-            model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+            model.setup_caches(
+                max_batch_size=1, max_seq_length=model.config.max_seq_length
+            )

    model.to(dtype=builder_args.precision)
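
Note: with this hunk, _initialize_model no longer hardcodes a 2048-token cache; it sizes the KV cache from the model's own config (8192 by default after the build/model.py change below). A minimal sketch of the resulting behavior, using hypothetical DummyConfig/DummyModel stand-ins rather than torchchat's real classes:

# Sketch only: DummyConfig/DummyModel are illustrative stand-ins, not torchchat classes.
from dataclasses import dataclass

@dataclass
class DummyConfig:
    max_seq_length: int = 8192  # mirrors the new ModelArgs default below

class DummyModel:
    def __init__(self, config: DummyConfig):
        self.config = config
        self.cache_len = None

    def setup_caches(self, max_batch_size: int, max_seq_length: int):
        # a real model would allocate KV-cache buffers of this length here
        self.cache_len = max_seq_length

model = DummyModel(DummyConfig())
model.setup_caches(max_batch_size=1, max_seq_length=model.config.max_seq_length)
assert model.cache_len == 8192  # follows the config instead of a hardcoded 2048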

build/model.py

+1

@@ -37,6 +37,7 @@ class ModelArgs:
    multiple_of: int = 256
    ffn_dim_multiplier: Optional[int] = None
    use_tiktoken: bool = False
+    max_seq_length: int = 8192

    def __post_init__(self):
        if self.n_local_heads == -1:
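
The new field gives every ModelArgs an 8192-token default context length, which the builder and chat mode now read instead of a hardcoded 2048. A small usage sketch with a trimmed-down stand-in dataclass (not the full ModelArgs):

# Sketch: only the fields relevant to this hunk are shown.
from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelArgsSketch:
    multiple_of: int = 256
    ffn_dim_multiplier: Optional[int] = None
    use_tiktoken: bool = False
    max_seq_length: int = 8192  # new field added by this commit

default_args = ModelArgsSketch()                   # picks up the 8192 default
short_args = ModelArgsSketch(max_seq_length=2048)  # callers can still shrink it
print(default_args.max_seq_length, short_args.max_seq_length)  # 8192 2048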

chat_in_browser.py

+1 -2

@@ -20,7 +20,6 @@ def create_app(*args):
        ["python3", "generate.py", *args], stdin=subprocess.PIPE, stdout=subprocess.PIPE
    )

-
    @app.route("/")
    def main():
        print("Starting chat session.")

@@ -93,7 +92,7 @@ def chat():
        # Strip "Model: " from output
        model_prefix = "Model: "
        if output.startswith(model_prefix):
-            output = output[len(model_prefix):]
+            output = output[len(model_prefix) :]

        global convo

download.py

+4 -7

@@ -36,19 +36,16 @@ def _download_hf_snapshot(
            ignore_patterns="*safetensors*",
        )
    except HTTPError as e:
-        if e.response.status_code == 401: # Missing HuggingFace CLI login.
+        if e.response.status_code == 401:  # Missing HuggingFace CLI login.
            print(
                "Access denied. Create a HuggingFace account and run 'pip3 install huggingface_hub' and 'huggingface-cli login' to authenticate.",
-                file=sys.stderr
+                file=sys.stderr,
            )
            exit(1)
-        elif e.response.status_code == 403: # No access to the specific model.
+        elif e.response.status_code == 403:  # No access to the specific model.
            # The error message includes a link to request access to the given model. This prints nicely and does not include
            # a traceback.
-            print(
-                str(e),
-                file=sys.stderr
-            )
+            print(str(e), file=sys.stderr)
            exit(1)
        else:
            raise e

generate.py

+97 -40

@@ -11,7 +11,7 @@
import time
from dataclasses import dataclass
from pathlib import Path
-from typing import Optional, Tuple, List
+from typing import List, Optional, Tuple

import torch
import torch._dynamo.config

@@ -32,6 +32,7 @@
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"

+
class ChatFormat:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

@@ -62,7 +63,6 @@ def encode_dialog_prompt(self, dialog) -> List[int]:
        return tokens


-
@dataclass
class GeneratorArgs:
    prompt: str = "torchchat is pronounced torch-chat and is so cool because"

@@ -210,11 +210,17 @@ def decode_n_tokens(
):
    new_tokens, new_probs = [], []
    encountered_eos = False
-    for i in range(num_new_tokens - 1): # -1 to save space to run an EoS if dont generate it naturally
+    for i in range(
+        num_new_tokens - 1
+    ):  # -1 to save space to run an EoS if dont generate it naturally
        # Actually better for Inductor to codegen attention here
        with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
            next_token, next_prob = decode_one_token(
-                model, cur_token.clone(), input_pos, need_probs=need_probs, **sampling_kwargs
+                model,
+                cur_token.clone(),
+                input_pos,
+                need_probs=need_probs,
+                **sampling_kwargs,
            )
            input_pos += 1
            new_tokens.append(next_token.clone())

@@ -223,15 +229,25 @@ def decode_n_tokens(
            new_probs.append(next_prob.clone())
        cur_token = next_token.view(1, -1)
        # encountered eos
-        if (next_token.item() == eos_token_id or (eot_id is not None and next_token.item() == eot_id)):
+        if next_token.item() == eos_token_id or (
+            eot_id is not None and next_token.item() == eot_id
+        ):
            encountered_eos = True
-            _, _ = decode_one_token(model, cur_token, input_pos, need_probs, **sampling_kwargs)
+            _, _ = decode_one_token(
+                model, cur_token, input_pos, need_probs, **sampling_kwargs
+            )
            input_pos += 1
            break
    if not encountered_eos:
-        eos_token = torch.tensor([eos_token_id if eot_id is None else eot_id], dtype=cur_token.dtype, device=cur_token.device)
+        eos_token = torch.tensor(
+            [eos_token_id if eot_id is None else eot_id],
+            dtype=cur_token.dtype,
+            device=cur_token.device,
+        )
        new_tokens.append(eos_token.clone())
-        _, _ = decode_one_token(model, eos_token.view(1, -1), input_pos, need_probs, **sampling_kwargs)
+        _, _ = decode_one_token(
+            model, eos_token.view(1, -1), input_pos, need_probs, **sampling_kwargs
+        )
        input_pos += 1

    return new_tokens, new_probs

@@ -337,7 +353,9 @@ def generate(
    with torch.device(device):
        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
        if is_speculative and draft_model is not model:
-            draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+            draft_model.setup_caches(
+                max_batch_size=1, max_seq_length=max_seq_length
+            )

    # create an empty tensor of the expected final shape and
    # fill in the current tokens

@@ -366,7 +384,9 @@ def generate(

    num_tokens_generated = 0
    input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int)
-    accept_counts = [0] * (speculate_k + 1) # creates array of [0, 0, 0, ...] that is speculate_k + 1 long
+    accept_counts = [0] * (
+        speculate_k + 1
+    )  # creates array of [0, 0, 0, ...] that is speculate_k + 1 long

    if is_speculative:
        input_pos = input_pos.item() # for speculative decoding easier to keep on host

@@ -392,12 +412,14 @@ def generate(
            max_new_tokens - 1,
            callback=callback,
            need_probs=False,
-            eos_token_id = tokenizer.eos_id() if tokenizer else 2,
-            eot_id = tokenizer.special_tokens["<|eot_id|>"] if is_llama3_model else None,
+            eos_token_id=tokenizer.eos_id() if tokenizer else 2,
+            eot_id=tokenizer.special_tokens["<|eot_id|>"] if is_llama3_model else None,
            **sampling_kwargs,
        )
        seq[T + 1 : T + 1 + len(generated_tokens)] = torch.cat(generated_tokens)
-        seq = seq[:T + 1 + len(generated_tokens)] # If we dont generate all the way to max_new_tokens slice off the extra space we allocated.
+        seq = seq[
+            : T + 1 + len(generated_tokens)
+        ]  # If we dont generate all the way to max_new_tokens slice off the extra space we allocated.

    generate_stats = {"accept_counts": accept_counts}
    return seq, generate_stats

@@ -410,7 +432,6 @@ def encode_tokens(tokenizer, string, bos=True, device="cpu"):
    return torch.tensor(tokens, dtype=torch.int, device=device)


-
def get_device_info(name: str) -> str:
    import platform
    from subprocess import check_output

@@ -481,7 +502,9 @@ def _main(
    # Piggy backing off of this flag then for now to identify llama3 without prompting user.
    is_llama3_model = tokenizer_args.is_tiktoken
    if generator_args.chat_mode and is_llama3_model:
-        logging.debug("Llama3 model detected in chat mode. Using updated sentence schemas")
+        logging.debug(
+            "Llama3 model detected in chat mode. Using updated sentence schemas"
+        )

    builder_args.setup_caches = False
    model = _initialize_model(builder_args, quantize, tokenizer)

@@ -534,20 +557,29 @@ def _main(
    if generator_args.compile_prefill:
        prefill = torch.compile(prefill, fullgraph=True, dynamic=True)

-    system_prompt=None
+    system_prompt = None
    # Set up our max_seq_length
    if generator_args.chat_mode:
-        max_seq_length = 2048
-        print(f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye")
-        system_prompt = input("System Prompt [Optional]: ")
+        max_seq_length = model.config.max_seq_length
+        print(
+            f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye"
+        )
+        get_system_prompt = input(
+            "Do you want to enter a system prompt? Enter y for yes and anything else for no. \n"
+        )
+        if get_system_prompt == "y" or get_system_prompt == "Y":
+            system_prompt = input("What is your system prompt? \n")
        if is_llama3_model:
            chat_formatter = ChatFormat(tokenizer)
    else:
-        max_seq_length = min(encoded.size(0) + generator_args.max_new_tokens, model.config.block_size)
-
+        max_seq_length = min(
+            encoded.size(0) + generator_args.max_new_tokens, model.config.block_size
+        )

    max_seq_length = (
-        max_seq_length + speculate_k + 1 if draft_model is not None else max_seq_length
+        max_seq_length + speculative_builder_args.speculate_k + 1
+        if draft_model is not None
+        else max_seq_length
    )

    aggregate_metrics = {

@@ -557,39 +589,59 @@ def _main(
    start = -1 if generator_args.compile else 0
    start_pos = 0

-
    # arbitrarily large number as chat mode goes until max_seq length or user exits
    num_samples = generator_args.num_samples if not generator_args.chat_mode else 100000
-    i = -1 # long loop and Im scared someone will add a continue in it, so start at -1 and increment at the start
-    while (i < num_samples):
+    i = (
+        -1
+    )  # long loop and Im scared someone will add a continue in it, so start at -1 and increment at the start
+    while i < num_samples:
        i += 1
        device_sync(device=builder_args.device)
        if i >= 0 and generator_args.chat_mode:
            prompt = input("User: ")
-            if (prompt == "/bye"):
+            if prompt == "/bye":
                print("Exiting Chat.\n")
                break
            if not is_llama3_model:
                if system_prompt:
                    prompt = f"{B_INST} {B_SYS}\n{system_prompt.strip()}\n{E_SYS}\n\n{prompt.strip} {E_INST}"
-                    system_prompt = None # can only provide system prompt on first interaction
+                    system_prompt = (
+                        None  # can only provide system prompt on first interaction
+                    )
                else:
                    prompt = f"{B_INST} {prompt.strip()} {E_INST}"
                encoded = encode_tokens(
                    tokenizer, prompt, bos=True, device=builder_args.device
                )
            else:
-                if system_prompt:
-                    encoded = chat_formatter.encode_dialog_prompt([{"role" : "system", "content" : system_prompt}, {"role" : "user", "content" : prompt}])
+                if system_prompt is not None:
+                    encoded = chat_formatter.encode_dialog_prompt(
+                        [
+                            {"role": "system", "content": system_prompt},
+                            {"role": "user", "content": prompt},
+                        ]
+                    )
                    system_prompt = None
-                elif(i == 0):
-                    encoded = chat_formatter.encode_dialog_prompt([{"role" : "user", "content" : prompt}])
+                elif i == 0:
+                    encoded = chat_formatter.encode_dialog_prompt(
+                        [{"role": "user", "content": prompt}]
+                    )
                else:
-                    encoded = chat_formatter.encode_message({"role" : "user", "content" : prompt})
-                    encoded.extend(chat_formatter.encode_header({"role": "assistant", "content": ""}))
-                encoded = torch.tensor(encoded, dtype=torch.int, device=builder_args.device)
-            if (encoded.size(0) + start_pos > max_seq_length):
-                print("This prompt would take us past the max_seq_length. Ending Conversation.")
+                    encoded = chat_formatter.encode_message(
+                        {"role": "user", "content": prompt}
+                    )
+                    encoded.extend(
+                        chat_formatter.encode_header(
+                            {"role": "assistant", "content": ""}
+                        )
+                    )
+                encoded = torch.tensor(
+                    encoded, dtype=torch.int, device=builder_args.device
+                )
+            if encoded.size(0) + start_pos > max_seq_length:
+                print(
+                    "This prompt would take us past the max_seq_length. Ending Conversation."
+                )
                break

        if generator_args.chat_mode and i >= 0:

@@ -604,12 +656,17 @@ def callback(
        ):
            if done_generating:
                return
-            buffer.append(tokenizer.decode([period_id] + x.tolist())[1:]) # I think this results in the first output token being dropped from the display which is wrong.
+            buffer.append(
+                tokenizer.decode([period_id] + x.tolist())[1:]
+            )  # I think this results in the first output token being dropped from the display which is wrong.
            if x.item() == tokenizer.eos_id():
                done_generating = True
-            if (is_llama3_model and x.item() == tokenizer.special_tokens["<|eot_id|>"]):
+            if (
+                is_llama3_model
+                and x.item() == tokenizer.special_tokens["<|eot_id|>"]
+            ):
                done_generating = True
-                buffer = buffer[:-1] # drop the eot_id from the output buffer
+                buffer = buffer[:-1]  # drop the eot_id from the output buffer
            if len(buffer) == 4 or done_generating:
                print("".join(buffer), end="", flush=True)
                buffer.clear()

@@ -672,7 +729,7 @@ def callback(x):
        )
        logging.debug(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")

-        if (start_pos >= max_seq_length):
+        if start_pos >= max_seq_length:
            print("Max Sequence Length Reached. Ending Conversation.")
            break
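
Outside chat mode, max_seq_length is still sized to the prompt length plus the requested new tokens, clamped to the model's block size, then padded by speculate_k + 1 when a draft model is loaded (the value now read from speculative_builder_args instead of a bare speculate_k). A worked sketch of that arithmetic, with purely illustrative numbers:

# Sketch of the sequence-length sizing logic above; numbers are illustrative only.
from typing import Optional

def plan_seq_length(prompt_len: int, max_new_tokens: int, block_size: int,
                    speculate_k: Optional[int] = None) -> int:
    max_seq_length = min(prompt_len + max_new_tokens, block_size)
    if speculate_k is not None:            # a draft model is present
        max_seq_length += speculate_k + 1  # extra room for the speculated tokens
    return max_seq_length

print(plan_seq_length(100, 200, 2048))      # 300
print(plan_seq_length(100, 200, 2048, 5))   # 306
print(plan_seq_length(2000, 200, 2048))     # 2048 (clamped to block_size)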

quantize.py

+7 -4

@@ -15,7 +15,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from build.utils import find_multiple, get_precision, use_et_backend
+from build.utils import find_multiple, get_precision, name_to_dtype, use_et_backend


#########################################################################

@@ -97,11 +97,14 @@ def quantized_model(self) -> nn.Module:


class PrecisionHandler(QuantHandler):
-    def __init__(self, model: nn.Module, device="cpu", tokenizer=None, **kwargs):
+    def __init__(self, model: nn.Module, device="cpu", tokenizer=None, *, dtype):
        self.model_ = model
        self.device = device
        self.tokenizer = tokenizer
-        self.kwargs = kwargs
+
+        if isinstance(dtype, str):
+            dtype = name_to_dtype(dtype)
+        self.dtype = dtype

    def create_quantized_state_dict(self) -> Dict: # "StateDict"
        pass

@@ -110,7 +113,7 @@ def convert_for_runtime(self) -> nn.Module:
        pass

    def quantized_model(self) -> nn.Module:
-        return self.model_.to(device=self.device, **self.kwargs)
+        return self.model_.to(device=self.device, dtype=self.dtype)


#########################################################################
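
With dtype now keyword-only and allowed to be a string, callers can pass a name like "bf16" and let build.utils.name_to_dtype resolve it. A hedged usage sketch; name_to_dtype_sketch below is an assumed stand-in for the real helper, whose exact name table is not shown in this diff:

# Sketch: resolving a string precision name before casting the model.
import torch
import torch.nn as nn

def name_to_dtype_sketch(name: str) -> torch.dtype:
    # assumed mapping; build.utils.name_to_dtype may accept more spellings
    table = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
    return table[name]

class PrecisionHandlerSketch:
    def __init__(self, model: nn.Module, device="cpu", tokenizer=None, *, dtype):
        self.model_ = model
        self.device = device
        self.tokenizer = tokenizer
        if isinstance(dtype, str):              # "decode precision as a string"
            dtype = name_to_dtype_sketch(dtype)
        self.dtype = dtype

    def quantized_model(self) -> nn.Module:
        return self.model_.to(device=self.device, dtype=self.dtype)

handler = PrecisionHandlerSketch(nn.Linear(4, 4), dtype="bf16")
print(handler.quantized_model().weight.dtype)  # torch.bfloat16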
