1 change: 1 addition & 0 deletions .gitignore
@@ -181,6 +181,7 @@ _build/
# hip files generated by PyTorch
*.hip
*_hip*
hip_compat.h

# Benchmark dataset
*.json
9 changes: 9 additions & 0 deletions benchmarks/benchmark_latency.py
@@ -30,6 +30,7 @@ def main(args: argparse.Namespace):
trust_remote_code=args.trust_remote_code,
dtype=args.dtype,
enforce_eager=args.enforce_eager,
kv_cache_dtype=args.kv_cache_dtype,
)

for batch_size in args.batch_size:
@@ -152,6 +153,14 @@ def run_to_completion(profile_dir: Optional[str] = None):
parser.add_argument('--enforce-eager',
action='store_true',
help='enforce eager mode and disable CUDA graph')
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=['auto', 'fp8'],
default='auto',
help='Data type for kv cache storage. If "auto", will use model data '
'type. FP8_E5M2 is only supported on cuda version greater than 11.8. '
'On AMD GPUs, only the more standard FP8_E4M3 is supported for inference.')
parser.add_argument(
'--profile',
action='store_true',
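
For orientation, a minimal usage sketch of the new flag as wired up above; the model name is an arbitrary example and is not part of this diff:

# Illustrative only: benchmark_latency.py forwards the parsed --kv-cache-dtype value
# verbatim to the LLM constructor, so passing the flag is equivalent to the call below.
from vllm import LLM

llm = LLM(
    model="facebook/opt-125m",   # arbitrary example model, not taken from the PR
    kv_cache_dtype="fp8",        # "auto" keeps the model dtype; "fp8" enables the 8-bit KV cache
)
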
13 changes: 12 additions & 1 deletion benchmarks/benchmark_throughput.py
@@ -71,6 +71,7 @@ def run_vllm(
dtype: str,
max_model_len: Optional[int],
enforce_eager: bool,
kv_cache_dtype: str,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
@@ -83,6 +84,7 @@
dtype=dtype,
max_model_len=max_model_len,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
)

# Add the requests to the engine.
@@ -206,7 +208,8 @@ def main(args: argparse.Namespace):
args.quantization, args.tensor_parallel_size,
args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype,
args.max_model_len, args.enforce_eager)
args.max_model_len, args.enforce_eager,
args.kv_cache_dtype)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -284,6 +287,14 @@ def main(args: argparse.Namespace):
parser.add_argument("--enforce-eager",
action="store_true",
help="enforce eager execution")
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=["auto", "fp8"],
default="auto",
help='Data type for kv cache storage. If "auto", will use model data '
'type. FP8_E5M2 is only supported on cuda version greater than 11.8. '
'On AMD GPUs, only the more standard FP8_E4M3 is supported for inference.')
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
34 changes: 19 additions & 15 deletions benchmarks/kernels/benchmark_paged_attention.py
@@ -1,9 +1,11 @@
from typing import Optional
import argparse
import random
import time

import torch

from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
from vllm._C import ops

NUM_BLOCKS = 1024
@@ -23,6 +25,7 @@ def main(
dtype: torch.dtype,
seed: int,
do_profile: bool,
kv_cache_dtype: Optional[str] = None,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
@@ -59,15 +62,10 @@ def main(
block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")

# Create the KV cache.
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda")
key_cache.uniform_(-scale, scale)
value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size)
value_cache = torch.empty(size=value_cache_shape,
dtype=dtype,
device="cuda")
value_cache.uniform_(-scale, scale)
key_caches, value_caches = create_kv_caches_with_random(
NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype,
dtype)
key_cache, value_cache = key_caches[0], value_caches[0]

# Prepare for the paged attention kernel.
output = torch.empty_like(query)
@@ -106,6 +104,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
)
elif version == "v2":
ops.paged_attention_v2(
@@ -123,6 +122,7 @@
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
)
else:
raise ValueError(f"Invalid version: {version}")
@@ -168,16 +168,19 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
default="half")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--profile", action="store_true")
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=["auto", "fp8"],
default="auto",
help='Data type for kv cache storage. If "auto", will use model data '
'type. FP8_E5M2 is only supported on cuda version greater than 11.8. '
'On AMD GPUs, only the more standard FP8_E4M3 is supported for inference.')
args = parser.parse_args()
print(args)

if args.num_query_heads % args.num_kv_heads != 0:
raise ValueError("num_query_heads must be divisible by num_kv_heads")
dtype_to_torch_dtype = {
"half": torch.half,
"bfloat16": torch.bfloat16,
"float": torch.float,
}
main(
version=args.version,
num_seqs=args.batch_size,
@@ -187,7 +190,8 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
head_size=args.head_size,
block_size=args.block_size,
use_alibi=args.use_alibi,
dtype=dtype_to_torch_dtype[args.dtype],
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
seed=args.seed,
do_profile=args.profile,
kv_cache_dtype=args.kv_cache_dtype,
)
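
The deleted inline cache construction above hints at what the newly imported helpers do. A hypothetical sketch follows, reconstructed only from the removed lines and the import added in this diff; the actual vllm.utils implementations (including how "fp8" values are stored) may differ:

# Hypothetical sketch of the vllm.utils helpers used by this benchmark.
# STR_DTYPE_TO_TORCH_DTYPE plays the role of the removed dtype_to_torch_dtype dict;
# create_kv_caches_with_random generalizes the removed inline cache construction to
# num_layers caches and accepts the kv_cache_dtype selector.
import torch

STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
}

def create_kv_caches_with_random(num_blocks, block_size, num_layers, num_heads,
                                 head_size, cache_dtype, model_dtype, device="cuda"):
    if cache_dtype not in (None, "auto"):
        # The real helper also handles "fp8" storage; this sketch does not reproduce that.
        raise NotImplementedError("fp8 cache storage is outside this sketch")
    scale = head_size**-0.5
    x = 16 // torch.tensor([], dtype=model_dtype).element_size()
    key_caches, value_caches = [], []
    for _ in range(num_layers):
        key_cache = torch.empty(num_blocks, num_heads, head_size // x, block_size, x,
                                dtype=model_dtype, device=device).uniform_(-scale, scale)
        value_cache = torch.empty(num_blocks, num_heads, head_size, block_size,
                                  dtype=model_dtype, device=device).uniform_(-scale, scale)
        key_caches.append(key_cache)
        value_caches.append(value_cache)
    return key_caches, value_caches
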
1 change: 1 addition & 0 deletions csrc/attention/attention_dtypes.h
@@ -4,3 +4,4 @@
#include "dtype_float16.cuh"
#include "dtype_float32.cuh"
#include "dtype_bfloat16.cuh"
#include "dtype_fp8.cuh"