Adding batch_size support for torchao llama benchmarks (#1182)
Summary: added a batch_size argument to the torchao llama benchmarks

Test Plan: see benchmarks.sh
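For reference, a single run with the new flag looks like the following; the CHECKPOINT_PATH value is an illustrative placeholder, while the quantization and batch-size choices are taken from benchmarks.sh, so adjust all of them for your own setup:

# Illustrative invocation of the new flag (see benchmarks.sh for the full matrix)
export CHECKPOINT_PATH=checkpoints   # assumed checkpoint root, adjust to your layout
export MODEL_REPO=meta-llama/Meta-Llama-3-8B
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth \
    --compile --quantization int8wo --write_result benchmark_results.txt --batch_size 32

With --batch_size greater than 1, generate.py now also prints the aggregate batch_size * tokens/sec figure alongside the per-sequence throughput.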

Reviewers:

Subscribers:

Tasks:

Tags:
HDCharles authored Oct 29, 2024
1 parent cbd90e3 commit aa86005
Showing 2 changed files with 48 additions and 25 deletions.
16 changes: 15 additions & 1 deletion torchao/_models/llama/benchmarks.sh
@@ -64,4 +64,18 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt

# Different Batch Size Benchmarks
export MODEL_REPO=meta-llama/Meta-Llama-3-8B
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt --batch_size 1
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt --batch_size 32
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt --batch_size 128

python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt --batch_size 1
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt --batch_size 32
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt --batch_size 128

python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 1
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 32
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 128
57 changes: 33 additions & 24 deletions torchao/_models/llama/generate.py
@@ -48,7 +48,7 @@ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = Non
return probs

def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
probs = logits_to_probs(logits[0, -1], temperature, top_k)
probs = logits_to_probs(logits[:, -1], temperature, top_k)
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs

@@ -75,7 +75,7 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc
new_tokens.append(next_token)
callback(new_tokens[-1])
new_probs.append(next_prob)
cur_token = next_token.view(1, -1)
cur_token = next_token

return new_tokens, new_probs

@@ -88,6 +88,7 @@ def generate(
model: Transformer,
prompt: torch.Tensor,
max_new_tokens: int,
batch_size: int,
*,
interactive: bool,
callback = lambda x: x,
@@ -102,34 +103,34 @@

# create an empty tensor of the expected final shape and fill in the current tokens
device = prompt.device
T = prompt.numel()
T = prompt.size(-1)

# calculate how many tokens to generate based on max_new_tokens and model's upper bound (block_size)
max_seq_length = min(T + max_new_tokens, model.config.block_size) if not interactive else 350
new_tokens = max_seq_length - T

# format model input
prompt, input_pos = prepare_inputs_for_model(prompt)
prompt = prompt.repeat(batch_size, 1) # expand prompt based on batchsize

# full prompt+output will be stored in seq
seq = torch.empty(max_seq_length, dtype=prompt.dtype, device=device)
seq[:T] = prompt.view(-1)
seq = torch.empty(batch_size, max_seq_length, dtype=prompt.dtype, device=device)
seq[:, :T] = prompt

# setup model caches
with torch.device(device):
if cache_size is None:
cache_size = max_seq_length
assert cache_size >= max_seq_length, "need cache_size to be greater than max_new_tokens + size-of-prompt"
model.setup_caches(max_batch_size=1, max_seq_length=cache_size, kv_cache_quantization=kv_cache_quantization, linear_causal_mask=linear_causal_mask, prompt_length=T)

# format model input
x, input_pos = prepare_inputs_for_model(prompt, max_new_tokens)
model.setup_caches(max_batch_size=batch_size, max_seq_length=cache_size, kv_cache_quantization=kv_cache_quantization, linear_causal_mask=linear_causal_mask, prompt_length=T)

# execute prefill
next_token = prefill(model, x, input_pos, **sampling_kwargs).clone()
seq[T] = next_token
next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs).clone()
seq[:, T] = next_token.squeeze()
# execute token generation
input_pos = torch.tensor([T], device=device, dtype=torch.int)
generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs)

seq = torch.cat((seq[:T+1], *generated_tokens))
generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs)
seq = torch.cat((seq[:, :T+1], *generated_tokens), dim=-1)

return seq

@@ -157,6 +158,7 @@ def main(
interactive: bool = False,
num_samples: int = 5,
max_new_tokens: int = 100,
batch_size: int = 1,
top_k: int = 200,
temperature: float = 0.8,
checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
@@ -229,9 +231,9 @@ def main(
use_hqq=True
else:
use_hqq=False
groupsize=int(quantization.split("-")[1])
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
quantize_(model, int4_weight_only(group_size=groupsize))
group_size=int(quantization.split("-")[1])
assert group_size in [32,64,128,256], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
quantize_(model, int4_weight_only(group_size=group_size))
if "marlin" in quantization:
from torchao.dtypes import MarlinSparseLayout
quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
@@ -267,9 +269,9 @@ def main(
use_hqq = "hqq" in quantization
quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size, use_hqq=use_hqq), is_observed_linear)
if "uintx" in quantization:
# uintx-nbits-groupsize, e.g. "uintx-2-64"
# uintx-nbits-group_size, e.g. "uintx-2-64"
if "hqq" in quantization:
# uintx-nbits-groupsize-hqq
# uintx-nbits-group_size-hqq
use_hqq = True
else:
use_hqq = False
@@ -303,6 +305,7 @@ def main(
model,
encode_tokens(tokenizer, prompt, bos=True, device=device),
max_new_tokens,
batch_size,
interactive=False,
temperature=temperature,
top_k=top_k,
@@ -375,6 +378,7 @@ def callback(x):
model,
encoded,
max_new_tokens,
batch_size,
interactive=interactive,
callback=callback,
temperature=temperature,
@@ -392,13 +396,13 @@ def callback(x):
t = time.perf_counter() - t0

if not interactive:
tok_list = y.tolist()
tok_list = y[0].tolist()
# truncate text after end of string token
tokens = tok_list if not tokenizer.eos_id() in y else tok_list[:tok_list.index(tokenizer.eos_id())]
tokens = tok_list if not tokenizer.eos_id() in tok_list else tok_list[:tok_list.index(tokenizer.eos_id())]
print(tokenizer.decode(tokens))
else:
print()
tokens_generated = y.size(0) - prompt_length
tokens_generated = (y.size(-1) - prompt_length)
tokens_sec = tokens_generated / t
aggregate_metrics['tokens_per_sec'].append(tokens_sec)
print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
@@ -421,6 +425,8 @@ def callback(x):
bandwidth = model_size * tokpersec
mem = torch.cuda.max_memory_reserved() /1e9
print(f"Average tokens/sec: {tokpersec:.2f}")
if batch_size > 1:
print(f"Average tokens/sec including batches {batch_size*tokpersec:.2f}")
print(f"Average Bandwidth: {bandwidth:.02f} GB/s")
print(f"Peak Memory Usage: {mem:.02f} GB")
print(f"Model Size: {model_size:.02f} GB")
@@ -439,6 +445,7 @@ def callback(x):
result_txt += f"--interactive " if interactive else ""
result_txt += f"--num_samples {num_samples} "
result_txt += f"--max_new_tokens {max_new_tokens} "
result_txt += f"--batch_size {batch_size} "
result_txt += f"--top_k {top_k} "
result_txt += f"--temperature {temperature} "
result_txt += f"--cache_size {cache_size}" if cache_size else ""
@@ -459,13 +466,15 @@ def callback(x):
parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
parser.add_argument('--batch_size', type=int, default=1, help='Batch size to benchmark with')
parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
parser.add_argument('-q', '--quantization', type=str,
help=(
'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
+'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, embed-int8wo'
+'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, '
+'embed-int8wo'
)
)
parser.add_argument("--calibration_limit", type=int, default=10, help="Number of calibration examples")
@@ -484,6 +493,6 @@ def callback(x):

args = parser.parse_args()
main(
args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
args.temperature, args.checkpoint_path, args.quantization, args.calibration_limit, args.calibration_seq_length, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
)

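The core of the generate.py change above is that the prompt is tiled along a new batch dimension and sampling then reads the last position of every row of the logits. Below is a minimal, self-contained sketch of that pattern, not the repository's code: the toy embedding/projection "model", vocabulary size, and prompt are invented for illustration, while prompt.repeat(batch_size, 1), logits[:, -1], and the temperature of 0.8 mirror the diff.

# Minimal sketch of batched autoregressive decoding (illustrative only).
# A real model returns logits of shape [batch_size, seq_len, vocab_size];
# here a random embed/project pair stands in for the Transformer.
import torch

torch.manual_seed(0)
batch_size, vocab_size, hidden, max_new_tokens = 4, 32, 16, 8

# Single prompt of length T, tiled along the batch dimension as in the diff.
prompt = torch.tensor([5, 11, 2])
T = prompt.size(-1)
tokens = prompt.repeat(batch_size, 1)          # shape [batch_size, T]

# Toy stand-in for the model: embed tokens, project back to the vocabulary.
embed = torch.randn(vocab_size, hidden)
proj = torch.randn(hidden, vocab_size)

for _ in range(max_new_tokens):
    logits = embed[tokens] @ proj                         # [batch_size, cur_len, vocab_size]
    probs = torch.softmax(logits[:, -1] / 0.8, dim=-1)    # last position of every sequence
    next_token = torch.multinomial(probs, num_samples=1)  # [batch_size, 1], one sample per row
    tokens = torch.cat((tokens, next_token), dim=-1)      # append one token per sequence

tokens_generated = tokens.size(-1) - T         # per-sequence count, as measured by the benchmark
print(tokens.shape, tokens_generated)          # torch.Size([4, 11]) 8

Indexing logits[:, -1] keeps one next-token distribution per sequence, which is why the single-sequence logits[0, -1] indexing in sample() had to change.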