Skip to content

Commit

Permalink
Merge pull request #3 from vllm-project/main
Browse files Browse the repository at this point in the history
Updating Branch
  • Loading branch information
Manikandan-Thangaraj-ZS0321 authored Aug 22, 2024
2 parents 08b8538 + eeee1c3 commit 5129c87
Show file tree
Hide file tree
Showing 69 changed files with 2,120 additions and 578 deletions.
4 changes: 2 additions & 2 deletions .buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.409
value: 0.419
- name: "exact_match,flexible-extract"
value: 0.406
value: 0.416
limit: 1000
num_fewshot: 5
15 changes: 12 additions & 3 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ steps:
- vllm/
commands:
- pip install -e ./plugins/vllm_add_dummy_model
- pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@a4987bba6e9e9b3f22bd3a6c1ecf0abd04fd5622#egg=lm_eval[api]
- pytest -v -s entrypoints/llm
- pytest -v -s entrypoints/openai

Expand Down Expand Up @@ -311,12 +312,20 @@ steps:
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py

- label: Multi-step Tests (4 GPUs) # 10min
- label: Multi-step Tests (4 GPUs) # 21min
working_dir: "/vllm-workspace/tests"
num_gpus: 4
source_file_dependencies:
- vllm/
- tests/multi_step/test_correctness.py
- vllm/model_executor/layers/sampler.py
- vllm/sequence.py
- vllm/worker/worker_base.py
- vllm/worker/worker.py
- vllm/worker/multi_step_worker.py
- vllm/worker/model_runner_base.py
- vllm/worker/model_runner.py
- vllm/worker/multi_step_model_runner.py
- vllm/engine
- tests/multi_step
commands:
- pytest -v -s multi_step/test_correctness.py

Expand Down
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Generate sources:
execute_process(
COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:$PYTHONPATH
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH
${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generate.py
RESULT_VARIABLE machete_generation_result
OUTPUT_VARIABLE machete_generation_output
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/backend_request_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,8 @@ async def async_request_openai_completions(
) -> RequestFuncOutput:
api_url = request_func_input.api_url
assert api_url.endswith(
"completions"
), "OpenAI Completions API URL must end with 'completions'."
("completions", "profile")
), "OpenAI Completions API URL must end with 'completions' or 'profile'."

async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert not request_func_input.use_beam_search
Expand Down
43 changes: 43 additions & 0 deletions benchmarks/benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,13 +295,15 @@ def calculate_metrics(
async def benchmark(
backend: str,
api_url: str,
base_url: str,
model_id: str,
tokenizer: PreTrainedTokenizerBase,
input_requests: List[Tuple[str, int, int]],
best_of: int,
use_beam_search: bool,
request_rate: float,
disable_tqdm: bool,
profile: bool,
):
if backend in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[backend]
Expand All @@ -326,6 +328,22 @@ async def benchmark(
f"are correctly specified. Error: {test_output.error}")
else:
print("Initial test run completed. Starting main benchmark run...")

if profile:
print("Starting profiler...")
profile_input = RequestFuncInput(
model=model_id,
prompt=test_prompt,
api_url=base_url + "/start_profile",
prompt_len=test_prompt_len,
output_len=test_output_len,
best_of=best_of,
use_beam_search=use_beam_search,
)
profile_output = await request_func(request_func_input=profile_input)
if profile_output.success:
print("Profiler started")

print(f"Traffic request rate: {request_rate}")

pbar = None if disable_tqdm else tqdm(total=len(input_requests))
Expand All @@ -349,6 +367,21 @@ async def benchmark(
pbar=pbar)))
outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)

if profile:
print("Stopping profiler...")
profile_input = RequestFuncInput(
model=model_id,
prompt=test_prompt,
api_url=base_url + "/stop_profile",
prompt_len=test_prompt_len,
output_len=test_output_len,
best_of=best_of,
use_beam_search=use_beam_search,
)
profile_output = await request_func(request_func_input=profile_input)
if profile_output.success:
print("Profiler stopped")

if pbar is not None:
pbar.close()

Expand Down Expand Up @@ -433,8 +466,10 @@ def main(args: argparse.Namespace):

if args.base_url is not None:
api_url = f"{args.base_url}{args.endpoint}"
base_url = f"{args.base_url}"
else:
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
base_url = f"http://{args.host}:{args.port}"

tokenizer = get_tokenizer(tokenizer_id,
trust_remote_code=args.trust_remote_code)
Expand Down Expand Up @@ -506,13 +541,15 @@ def main(args: argparse.Namespace):
benchmark(
backend=backend,
api_url=api_url,
base_url=base_url,
model_id=model_id,
tokenizer=tokenizer,
input_requests=input_requests,
best_of=args.best_of,
use_beam_search=args.use_beam_search,
request_rate=args.request_rate,
disable_tqdm=args.disable_tqdm,
profile=args.profile,
))

# Save config and results to json
Expand Down Expand Up @@ -693,6 +730,12 @@ def main(args: argparse.Namespace):
action="store_true",
help="Specify to disable tqdm progress bar.",
)
parser.add_argument(
"--profile",
action="store_true",
help="Use Torch Profiler. The endpoint must be launched with "
"VLLM_TORCH_PROFILER_DIR to enable profiler.",
)
parser.add_argument(
"--save-result",
action="store_true",
Expand Down
89 changes: 89 additions & 0 deletions benchmarks/kernels/benchmark_layernorm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import random
import time

import torch

from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser


@torch.inference_mode()
def main(num_tokens: int,
hidden_size: int,
add_residual: bool,
dtype: torch.dtype,
seed: int = 0,
do_profile: bool = False,
num_warmup_iters: int = 5,
num_iters: int = 100) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device("cuda")

layer = RMSNorm(hidden_size).to(dtype=dtype)
layer.weight.data.normal_(mean=1.0, std=0.1)
scale = 1 / (2 * hidden_size)
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
x *= scale
residual = torch.randn_like(x) * scale if add_residual else None

def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
torch.cuda.synchronize()
if profile:
torch.cuda.cudart().cudaProfilerStart()
start_time = time.perf_counter()

for _ in range(num_iters):
layer(x, residual)
torch.cuda.synchronize()

end_time = time.perf_counter()
if profile:
torch.cuda.cudart().cudaProfilerStart()
return (end_time - start_time) / num_iters

# Warmup.
print("Warming up...")
run_benchmark = run_cuda_benchmark
run_benchmark(num_iters=num_warmup_iters, profile=False)

# Benchmark.
if do_profile:
latency = run_benchmark(num_iters=1, profile=True)
else:
latency = run_benchmark(num_iters=num_iters, profile=False)
print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == '__main__':
parser = FlexibleArgumentParser(
description="Benchmark the layernorm kernel.")
parser.add_argument("--num-tokens", type=int, default=4096)
parser.add_argument("--hidden-size", type=int, default=8192)
parser.add_argument("--add-residual", action="store_true")
parser.add_argument("--dtype",
type=str,
choices=["half", "bfloat16", "float"],
default="half")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--profile", action="store_true")
parser.add_argument("--num-warmup-iters", type=int, default=5)
parser.add_argument("--num-iters",
type=int,
default=100,
help="Number of benchmark iterations. "
"If --profile is set, this number is ignored")

args = parser.parse_args()
print(args)

main(num_tokens=args.num_tokens,
hidden_size=args.hidden_size,
add_residual=args.add_residual,
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
seed=args.seed,
do_profile=args.profile,
num_warmup_iters=args.num_warmup_iters,
num_iters=args.num_iters)
103 changes: 103 additions & 0 deletions benchmarks/kernels/benchmark_quant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import random
import time

import torch

from vllm import _custom_ops as ops
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser


@torch.inference_mode()
def main(num_tokens: int,
hidden_size: int,
static_scale: bool,
quant_dtype: torch.dtype,
dtype: torch.dtype,
seed: int = 0,
do_profile: bool = False,
num_warmup_iters: int = 5,
num_iters: int = 100) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device("cuda")

x = torch.randn(num_tokens, hidden_size, dtype=dtype)
scale = torch.randn(1, 1, dtype=torch.float32) if static_scale else None

def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
torch.cuda.synchronize()
if profile:
torch.cuda.cudart().cudaProfilerStart()
start_time = time.perf_counter()

for _ in range(num_iters):
if quant_dtype == torch.int8:
ops.scaled_int8_quant(x, scale)
else:
ops.scaled_fp8_quant(x, scale)
torch.cuda.synchronize()

end_time = time.perf_counter()
if profile:
torch.cuda.cudart().cudaProfilerStart()
return (end_time - start_time) / num_iters

# Warmup.
print("Warming up...")
run_benchmark = run_cuda_benchmark
run_benchmark(num_iters=num_warmup_iters, profile=False)

# Benchmark.
if do_profile:
latency = run_benchmark(num_iters=1, profile=True)
else:
latency = run_benchmark(num_iters=num_iters, profile=False)
print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == '__main__':

def to_torch_dtype(dt):
if dt == "int8":
return torch.int8
if dt == "fp8":
return torch.float8_e4m3fn
raise ValueError(f"Unsupported dtype: {dt}")

parser = FlexibleArgumentParser(
description="Benchmark the quantization (fp8 or int8) kernel.")
parser.add_argument("--num-tokens", type=int, default=4096)
parser.add_argument("--hidden-size", type=int, default=8192)
parser.add_argument("--static-scale", action="store_true")
parser.add_argument("--quant-dtype",
type=str,
choices=["fp8", "int8"],
default="int8")
parser.add_argument("--dtype",
type=str,
choices=["half", "bfloat16", "float"],
default="half")

parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--profile", action="store_true")
parser.add_argument("--num-warmup-iters", type=int, default=5)
parser.add_argument("--num-iters",
type=int,
default=100,
help="Number of benchmark iterations. "
"If --profile is set, this number is ignored")

args = parser.parse_args()
print(args)

main(num_tokens=args.num_tokens,
hidden_size=args.hidden_size,
static_scale=args.static_scale,
quant_dtype=to_torch_dtype(args.quant_dtype),
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
seed=args.seed,
do_profile=args.profile,
num_warmup_iters=args.num_warmup_iters,
num_iters=args.num_iters)
2 changes: 1 addition & 1 deletion csrc/attention/attention_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
#pragma unroll
for (int ii = 1; ii < N; ++ii) {
qk_vec = fma(q[ii], k[ii], qk_vec);
qk_vec = vllm::fma(q[ii], k[ii], qk_vec);
}

// Finalize the reduction across lanes.
Expand Down
Loading

0 comments on commit 5129c87

Please sign in to comment.