forked from JRosenkranz/fms-extras
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpaged_speculative_inference.py
408 lines (362 loc) · 13.7 KB
/
paged_speculative_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
import argparse
import itertools
import os
import time
import transformers
import torch
import torch._inductor.config
from fms.models import get_model
from fms.utils import generation, tokenizers
from torch import distributed as dist
import fms_extras.models.paged_gpt_bigcode
import fms_extras.models.paged_llama
from fms_extras.models.speculator import MLPSpeculator
from fms_extras.utils.generation import paged_generate, speculative_generate
# This example script validates the paged model implementation (LLaMA by
# default) by running inference on a couple of prompts, optionally using a
# speculator for speculative decoding. Example launch:
# torchrun --nproc_per_node=1 scripts/paged_speculative_inference.py --variant=7b --model_path=~/models/7B-F --tokenizer=~/models/tokenizer.model --model_source=meta --speculator_path=~/models/speculator_7B_F.pth --compile


def _parse_int_list(value):
    """Parse a comma-separated string of integers, e.g. ``"5,3,2"`` -> ``[5, 3, 2]``.

    A named function (instead of a lambda) lets argparse report
    ``invalid _parse_int_list value`` rather than ``invalid <lambda> value``
    when the input is malformed.
    """
    return [int(item) for item in value.split(",")]


parser = argparse.ArgumentParser(
    description="Script to run inference on a causal model"
)
parser.add_argument("--device_type", type=str, default="cuda")
parser.add_argument(
    "--architecture",
    type=str,
    default="llama",
    help="The model architecture to benchmark",
)
parser.add_argument(
    "--variant",
    type=str,
    default="7b",
    help="The model variant (configuration) to benchmark. E.g. 7b, 13b, 70b.",
)
parser.add_argument(
    "--model_path",
    type=str,
    help="Path to the directory containing LLaMa weights (.pth files sharded by tensor parallel rank, not HF weights)",
)
parser.add_argument(
    "--speculator_path",
    type=str,
    default=None,
    help="Path to the speculator weights: a single .pth checkpoint for --speculator_load_type=singlefile, "
    "otherwise the path passed to get_model (registered_local) or from_pretrained (hf_remote)",
)
parser.add_argument(
    "--speculator_variant",
    type=str,
    default="840m",
    help="The speculator variant (configuration) used with --speculator_load_type=registered_local. "
    "E.g. 840m, 1.4b, 2b, etc.",
)
parser.add_argument(
    "--speculator_source",
    type=str,
    default=None,
    choices=["hf"],
    help="Source format of speculator weights. Note: If the weights path specified in speculator_path are not local and "
    "the source is hf, the weights will be pulled using the normal Huggingface from_pretrained method.",
)
parser.add_argument(
    "--model_source",
    type=str,
    help="Source of the checkpoint. E.g. 'meta', 'hf', None",
)
parser.add_argument(
    "--checkpoint_sharding",
    type=str,
    default=None,
    help="type of weight sharding. E.g. tensor-parallel (tp), None",
)
parser.add_argument(
    "--tokenizer",
    type=str,
    required=True,
    help="Path to the tokenizer (e.g. ~/tokenizer.model)",
)
parser.add_argument(
    "--compile",
    action="store_true",
    help="Use torch.compile (slow for first inference pass)",
)
parser.add_argument(
    "--compile_mode",
    type=str,
    help="Mode for compilation",
    default="default",
    choices=["default", "reduce-overhead"],
)
parser.add_argument(
    "--deterministic",
    action="store_true",
    help="Set torch.use_deterministic_algorithms? Requires env variable `CUBLAS_WORKSPACE_CONFIG=:4096:8`",
)
parser.add_argument(
    "--distributed",
    action="store_true",
    help="This is a distributed job (multiple instances run with RANK+WORLD_SIZE)",
)
parser.add_argument("--context_file", type=str, default=None, help="File to summarize")
parser.add_argument(
    "--batch_input",
    action="store_true",
    help="use a batch of prompts as input (note this is still wip for reduce-overhead=True)",
)
parser.add_argument(
    "--top_k_tokens_per_head",
    type=_parse_int_list,
    default=[5, 3, 2],
    help="Number of tokens to consider from each head when forming the candidate tree. For each candidate branch in the tree, head n produces topk[n] additional sub-branches.",
)
parser.add_argument(
    "--prompt_type",
    type=str,
    choices=["chat", "code"],
    default="chat",
    help="type of prompts to be used, either chat or code",
)
parser.add_argument(
    "--speculator_load_type",
    type=str,
    choices=["singlefile", "registered_local", "hf_remote"],
    default="singlefile",
    help="how to load the speculator",
)
args = parser.parse_args()
# CUDA graphs ("reduce-overhead") are not yet supported for batched input, so a
# batched + compiled run falls back to the "default" compile mode.
compile_mode = args.compile_mode
if args.batch_input and args.compile and compile_mode == "reduce-overhead":
    print(
        "setting compile_mode to default as cudagraphs is not yet supported with batches"
    )
    compile_mode = "default"
# Rank / world size come from the launcher's environment (e.g. torchrun);
# without them this runs as a single local process.
local_rank = int(os.getenv("LOCAL_RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))
if args.device_type != "cuda":
    device = torch.device(args.device_type)
else:
    # One GPU per local rank.
    device = torch.device(args.device_type, local_rank)
    torch.cuda.set_device(device)
# Alternative: torch.set_default_dtype(torch.half)
torch.set_default_dtype(torch.bfloat16)
if args.deterministic:
    # requires setting environment variable: `CUBLAS_WORKSPACE_CONFIG=:4096:8`
    torch.use_deterministic_algorithms(True)
# The process group is initialized unconditionally (the original
# `if args.distributed:` guard is disabled).
# NOTE(review): presumably required because the model is loaded with the
# "fsdp" distributed strategy — confirm before restoring the guard.
dist.init_process_group()
# torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)

# Distribution strategy implied by the launch configuration: tensor parallel
# for --distributed, model parallel when one process sees several GPUs.
# NOTE(review): in the visible file this is only referenced by a commented-out
# get_model argument.
if args.distributed:
    distr_param = "tp"
elif torch.cuda.device_count() > 1 and world_size == 1:
    distr_param = "mp"
else:
    distr_param = None
print("loading model")
# Paged variant of the requested architecture, e.g. "paged_llama".
model = get_model(
    f"paged_{args.architecture}",
    args.variant,
    source=args.model_source,
    model_path=args.model_path,
    device_type=args.device_type,
    checkpoint_sharding=args.checkpoint_sharding,
    distributed_strategy="fsdp",
    # distributed_strategy=distr_param,
    # group=dist.group.WORLD,
)
# Optional separate model for decode steps; replaced by a compiled copy when
# --compile is set.
decode_model = None
tokenizer = tokenizers.get_tokenizer(args.tokenizer)

# Inference only: eval mode and autograd disabled globally.
model.eval()
torch.set_grad_enabled(False)
# Optionally load an MLP speculator for speculative decoding. The loading path
# is selected by --speculator_load_type; afterwards the speculator is moved to
# `device` and its head count is validated against --top_k_tokens_per_head.
speculator = None
if args.speculator_path is not None:
    print("loading speculator")
    # todo: handling of remote weights in get_model
    if args.speculator_load_type == "singlefile":  # manual
        print("loading speculator singlefile")
        # NOTE(review): the architecture is hard-coded (inner dim 6144,
        # n_predict=5, tie_wts/scale_input) and must match the checkpoint.
        # With n_predict=5 the default --top_k_tokens_per_head (3 values)
        # fails the head-count check below. An earlier configuration was
        # MLPSpeculator(emb_dim, 4096, src_vocab_size, n_predict=4).
        speculator = MLPSpeculator(
            model.config.emb_dim,
            6144,
            model.config.src_vocab_size,
            n_predict=5,
            tie_wts=True,
            scale_input=True,
        )
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. The file must contain a "model_state" entry.
        speculator.load_state_dict(
            torch.load(args.speculator_path, map_location=device)["model_state"]
        )
    elif args.speculator_load_type == "registered_local":
        print("loading speculator registered local")
        speculator = get_model(
            "mlp_speculator",
            f"{args.architecture}.{args.variant}.{args.speculator_variant}",
            model_path=args.speculator_path,
            source=args.speculator_source,
            device_type=args.device_type,
        )
    elif args.speculator_load_type == "hf_remote":
        print("loading speculator HF remote")
        from fms_extras.models.hf.modeling_mlp_speculator import (
            MLPSpeculatorPreTrainedModel,
        )

        # The HF wrapper exposes the underlying speculator as `.speculator`.
        speculator = MLPSpeculatorPreTrainedModel.from_pretrained(
            args.speculator_path,  # device_map=args.device_type
        ).speculator
    else:
        # argparse `choices` makes this unreachable from the CLI; if it is ever
        # reached, exit with a non-zero status.
        raise SystemExit("Incorrect speculator_load_type")
    speculator = speculator.to(device)
    if local_rank == 0:
        # Reported once (previously printed twice, before and after .to()).
        total_params = sum(
            p.numel() for p in speculator.parameters() if p.requires_grad
        )
        print(f"\nspeculator has {total_params / 1e6} Million params\n")
    # Each speculator head needs its own top-k value.
    if len(args.top_k_tokens_per_head) != speculator.n_predict:
        # SystemExit with a message exits with status 1; the previous bare
        # exit() reported success on this configuration error.
        raise SystemExit(
            "length of top_k_tokens_per_head must be equal to the speculator's number of heads (n_predict)"
        )
    print("loading complete on rank", local_rank)
print("initializing paged cache")
# cache setup
from fms_extras.utils.cache.paged import PagedKVCacheManager

use_cache = True
# KV head count: the config's explicit `kvheads` when it has one, otherwise
# 1 for multi-query attention or one per attention head.
if not hasattr(model.config, "kvheads"):
    kv_heads = 1 if model.config.multiquery_attn else model.config.nheads
else:
    kv_heads = model.config.kvheads
kv_cache_manager = PagedKVCacheManager(
    model.config.nlayers,
    model.config.nheads,
    model.config.emb_dim,
    kv_heads=kv_heads,
    # The cache is split across ranks only for --distributed runs.
    tensor_parallel_size=(dist.get_world_size() if args.distributed else 1),
    dtype=torch.get_default_dtype(),
    device=device,
    total_num_gpu_blocks=2000,
)
print("cache initialization complete on rank", local_rank)
# BOS is prepended to prompts (and outputs truncated at EOS) only when the
# tokenizer uses distinct BOS and EOS ids.
add_special_tokens = tokenizer.eos_token_id != tokenizer.bos_token_id
def ids_for_prompt(prompt):
    """Tokenize ``prompt`` into a ``torch.long`` tensor of token ids on ``device``.

    The tokenizer's BOS id is prepended when ``add_special_tokens`` is set.
    """
    token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(prompt))
    if add_special_tokens:
        token_ids = [tokenizer.bos_token_id] + token_ids
    return torch.tensor(token_ids, dtype=torch.long, device=device)
def print_result(result, inp, n_steps):
    """Print one decoded sequence and its new-token count (local rank 0 only).

    Args:
        result: token ids of prompt + generated continuation.
        inp: token ids of the prompt alone.
        n_steps: number of generation steps reported by the generator.
    """
    if local_rank == 0:
        if add_special_tokens:
            # stop at EOS token if present
            result = generation.truncate_after_eos(result, tokenizer.eos_token_id)
        tokens = tokenizer.convert_ids_to_tokens(result)
        print(tokenizer.convert_tokens_to_string(tokens))
        print(f"{len(result) - len(inp)} tokens in {n_steps} steps")
        print()
def infer(ids, warmup):
    """Generate up to 100 new tokens for each prompt in ``ids`` and report results.

    Uses ``speculative_generate`` when a speculator is loaded and greedy
    ``paged_generate`` (``do_sample=False``) otherwise; both use the
    module-level ``kv_cache_manager``.

    Args:
        ids: list of prompt token-id tensors, as built by ``ids_for_prompt``.
        warmup: when True the generation still runs, but no results or timings
            are printed.
    """
    # With greedy generation (do_sample=False) we _should_ always get the same results.
    # There is currently a bug in start_pos for batched rotary embeddings that can lead
    # to varying results for the same prompt.
    if local_rank == 0:
        print("==================")
    cudagraphs = compile_mode == "reduce-overhead"
    max_seq_len = (
        model.config.max_expected_seq_len
        if hasattr(model.config, "max_expected_seq_len")
        else model.config.max_pos
    )
    if speculator is not None:
        result, n_steps, ttft, generated_token_time_out = speculative_generate(
            model,
            ids,
            speculator,
            kv_cache_manager,
            new_tokens=100,
            max_seq_len=max_seq_len,
            decode_model=decode_model,
            # todo: we can only reduce-overhead for now when batch size is 1
            flattening=not (args.compile and compile_mode == "reduce-overhead"),
            # cudagraphs=cudagraphs,
            threshes=args.top_k_tokens_per_head,
        )
    else:
        result, n_steps, ttft, generated_token_time_out = paged_generate(
            model,
            ids,
            kv_cache_manager,
            max_new_tokens=100,
            max_seq_len=max_seq_len,
            do_sample=False,
            decode_model=decode_model,
            cudagraphs=cudagraphs,
        )
    if warmup:
        return

    total_tokens = 0
    for res, inp in zip(result, ids):
        print_result(res, inp, n_steps)
        # Counted on the untruncated output (print_result truncates at EOS
        # only for display).
        total_tokens += len(res) - len(inp)

    # Timing is reported by rank 0 only, matching print_result.
    if local_rank != 0:
        return
    avg_tokens = total_tokens / len(result) if result else 0
    print(f"time to first token: {ttft}")
    if avg_tokens > 0:
        print(f"time per token (decode): {generated_token_time_out / avg_tokens}")
    else:
        # Avoid ZeroDivisionError when nothing new was generated.
        print("time per token (decode): n/a (no new tokens generated)")
if args.compile:
    print("compiling model")
    # Bug with kv-cache in PT2.1
    torch._inductor.config.joint_graph_constant_folding = False
    # Compiling can make the first inference pass slow. Two compiled views of
    # the same model are produced:
    # - a full-graph one honouring --compile_mode, used as `decode_model`;
    # - a dynamic-shape one that replaces `model`.
    decode_model = torch.compile(model, mode=compile_mode, fullgraph=True)
    model = torch.compile(model, fullgraph=True, dynamic=True)
    if speculator:
        speculator = torch.compile(speculator, mode=compile_mode)
        # generate_suffixes is compiled separately and attached to the
        # compiled wrapper.
        speculator.generate_suffixes = torch.compile(
            speculator.generate_suffixes, mode=compile_mode
        )
# Pick the instruction template and the four sample requests for
# --prompt_type, then tokenize each formatted prompt with ids_for_prompt.
if args.prompt_type == "chat":
    template = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{}\n\n### Response:"
    prompt1 = template.format(
        "Provide a list of instructions for preparing chicken soup."
    )
    prompt2 = template.format("Explain some popular greetings in Spanish.")
    prompt3 = template.format("Explain to me why ignorance is bliss.")
    prompt4 = template.format(
        "I have just come into a very large sum of money. I received the money from my parents who told me I could do whatever I want with it. My first thought was to go to a financial advisor. Provide me a list of things that I can do with my new found wealth."
    )
elif args.prompt_type == "code":
    template = "[INST] Write code to solve the following coding problem that obeys the constraints and passes the example test cases. Please wrap your code answer using ```:\n{}\n[/INST]"
    prompt1 = template.format("Write a bubble sort function in python.")
    prompt2 = template.format(
        "Using the Java streams API, write a simple function which will get the cumulative sum of a list of integers."
    )
    prompt3 = template.format(
        "In bash, how do I list all directories and sub-directories which contain a .py file."
    )
    prompt4 = template.format(
        "Write a simple decorator in python which will modify all string inputs to ints if possible."
    )
else:
    # argparse `choices` makes this unreachable from the CLI. If it is ever
    # reached, exit with a non-zero status; a bare exit() would report success.
    raise SystemExit("prompt_type must be one of chat or code")
prompt1, prompt2, prompt3, prompt4 = (
    ids_for_prompt(text) for text in (prompt1, prompt2, prompt3, prompt4)
)
# Batched runs use all four prompts; otherwise only the first.
ids = [prompt1, prompt2, prompt3, prompt4] if args.batch_input else [prompt1]
# Warm-up pass (e.g. to trigger compilation); nothing is reported.
infer(ids, warmup=True)
print("generating output", local_rank)
infer(ids, warmup=True)
infer(ids, warmup=False)