diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py
index a378bc6baa5a..e29881fcbac0 100644
--- a/.buildkite/check-wheel-size.py
+++ b/.buildkite/check-wheel-size.py
@@ -8,12 +8,12 @@
# Note that we have 400 MiB quota, please use it wisely.
# See https://github.com/pypi/support/issues/3792 .
# Please also sync the value with the one in Dockerfile.
-VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 400))
+VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 400))
def print_top_10_largest_files(zip_file):
"""Print the top 10 largest files in the given zip file."""
- with zipfile.ZipFile(zip_file, 'r') as z:
+ with zipfile.ZipFile(zip_file, "r") as z:
file_sizes = [(f, z.getinfo(f).file_size) for f in z.namelist()]
file_sizes.sort(key=lambda x: x[1], reverse=True)
for f, size in file_sizes[:10]:
@@ -28,14 +28,18 @@ def check_wheel_size(directory):
wheel_path = os.path.join(root, file_name)
wheel_size_mb = os.path.getsize(wheel_path) / (1024 * 1024)
if wheel_size_mb > VLLM_MAX_SIZE_MB:
- print(f"Not allowed: Wheel {wheel_path} is larger "
- f"({wheel_size_mb:.2f} MB) than the limit "
- f"({VLLM_MAX_SIZE_MB} MB).")
+ print(
+ f"Not allowed: Wheel {wheel_path} is larger "
+ f"({wheel_size_mb:.2f} MB) than the limit "
+ f"({VLLM_MAX_SIZE_MB} MB)."
+ )
print_top_10_largest_files(wheel_path)
return 1
else:
- print(f"Wheel {wheel_path} is within the allowed size "
- f"({wheel_size_mb:.2f} MB).")
+ print(
+ f"Wheel {wheel_path} is within the allowed size "
+ f"({wheel_size_mb:.2f} MB)."
+ )
return 0
@@ -45,4 +49,4 @@ def check_wheel_size(directory):
sys.exit(1)
directory = sys.argv[1]
- sys.exit(check_wheel_size(directory))
\ No newline at end of file
+ sys.exit(check_wheel_size(directory))
diff --git a/.buildkite/generate_index.py b/.buildkite/generate_index.py
index 36e1b6c01326..270663c415c7 100644
--- a/.buildkite/generate_index.py
+++ b/.buildkite/generate_index.py
@@ -22,5 +22,5 @@
print(f"Generated index.html for {args.wheel}")
# cloudfront requires escaping the '+' character
f.write(
- template.format(wheel=filename,
- wheel_html_escaped=filename.replace("+", "%2B")))
+ template.format(wheel=filename, wheel_html_escaped=filename.replace("+", "%2B"))
+ )
diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-FP8-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-FP8-compressed-tensors.yaml
new file mode 100644
index 000000000000..cca58097e8aa
--- /dev/null
+++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-FP8-compressed-tensors.yaml
@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m RedHatAI/Llama-3.2-1B-Instruct-FP8 -b "auto" -l 1319 -f 5 -t 1
+model_name: "RedHatAI/Llama-3.2-1B-Instruct-FP8"
+tasks:
+- name: "gsm8k"
+ metrics:
+ - name: "exact_match,strict-match"
+ value: 0.335
+ - name: "exact_match,flexible-extract"
+ value: 0.323
+limit: 1319
+num_fewshot: 5
diff --git a/.buildkite/lm-eval-harness/configs/Qwen2.5-1.5B-Instruct.yaml b/.buildkite/lm-eval-harness/configs/Qwen2.5-1.5B-Instruct.yaml
new file mode 100644
index 000000000000..54579a63a9b8
--- /dev/null
+++ b/.buildkite/lm-eval-harness/configs/Qwen2.5-1.5B-Instruct.yaml
@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m Qwen/Qwen2.5-1.5B-Instruct -b auto -l 1319 -f 5 -t 1
+model_name: "Qwen/Qwen2.5-1.5B-Instruct"
+tasks:
+- name: "gsm8k"
+ metrics:
+ - name: "exact_match,strict-match"
+ value: 0.54
+ - name: "exact_match,flexible-extract"
+ value: 0.59
+limit: 1319
+num_fewshot: 5
diff --git a/.buildkite/lm-eval-harness/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml b/.buildkite/lm-eval-harness/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
new file mode 100644
index 000000000000..a2f235f48581
--- /dev/null
+++ b/.buildkite/lm-eval-harness/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic -b auto -l 1319 -f 5 -t 1
+model_name: "RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic"
+tasks:
+- name: "gsm8k"
+ metrics:
+ - name: "exact_match,strict-match"
+ value: 0.47
+ - name: "exact_match,flexible-extract"
+ value: 0.64
+limit: 1319
+num_fewshot: 5
diff --git a/.buildkite/lm-eval-harness/configs/models-large.txt b/.buildkite/lm-eval-harness/configs/models-large.txt
index 37eeac85c933..27a1a9a82bd3 100644
--- a/.buildkite/lm-eval-harness/configs/models-large.txt
+++ b/.buildkite/lm-eval-harness/configs/models-large.txt
@@ -3,3 +3,4 @@ Meta-Llama-3-70B-Instruct.yaml
Mixtral-8x7B-Instruct-v0.1.yaml
Qwen2-57B-A14-Instruct.yaml
DeepSeek-V2-Lite-Chat.yaml
+Meta-Llama-3-8B-QQQ.yaml
diff --git a/.buildkite/lm-eval-harness/configs/models-small.txt b/.buildkite/lm-eval-harness/configs/models-small.txt
index 254d01edf844..36e0543879b3 100644
--- a/.buildkite/lm-eval-harness/configs/models-small.txt
+++ b/.buildkite/lm-eval-harness/configs/models-small.txt
@@ -1,10 +1,6 @@
-Meta-Llama-3-8B-Instruct.yaml
-Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
+Qwen2.5-1.5B-Instruct.yaml
Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
-Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
+Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
Qwen1.5-MoE-W4A16-compressed-tensors.yaml
-Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
-Qwen2-1.5B-Instruct-FP8W8.yaml
-Meta-Llama-3-8B-QQQ.yaml
diff --git a/.buildkite/lm-eval-harness/conftest.py b/.buildkite/lm-eval-harness/conftest.py
index a0bcc993ed4a..769d2efda4ad 100644
--- a/.buildkite/lm-eval-harness/conftest.py
+++ b/.buildkite/lm-eval-harness/conftest.py
@@ -8,11 +8,14 @@ def pytest_addoption(parser):
parser.addoption(
"--config-list-file",
action="store",
- help="Path to the file listing model config YAMLs (one per line)")
- parser.addoption("--tp-size",
- action="store",
- default="1",
- help="Tensor parallel size to use for evaluation")
+ help="Path to the file listing model config YAMLs (one per line)",
+ )
+ parser.addoption(
+ "--tp-size",
+ action="store",
+ default="1",
+ help="Tensor parallel size to use for evaluation",
+ )
@pytest.fixture(scope="session")
@@ -33,7 +36,8 @@ def pytest_generate_tests(metafunc):
config_dir = config_list_file.parent
with open(config_list_file, encoding="utf-8") as f:
configs = [
- config_dir / line.strip() for line in f
+ config_dir / line.strip()
+ for line in f
if line.strip() and not line.startswith("#")
]
metafunc.parametrize("config_filename", configs)
diff --git a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py
index c5411daf0df6..409a6ca82008 100644
--- a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py
+++ b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py
@@ -16,19 +16,22 @@
def launch_lm_eval(eval_config, tp_size):
- trust_remote_code = eval_config.get('trust_remote_code', False)
- model_args = f"pretrained={eval_config['model_name']}," \
- f"tensor_parallel_size={tp_size}," \
- f"enforce_eager=true," \
- f"add_bos_token=true," \
- f"trust_remote_code={trust_remote_code}"
+ trust_remote_code = eval_config.get("trust_remote_code", False)
+ model_args = (
+ f"pretrained={eval_config['model_name']},"
+ f"tensor_parallel_size={tp_size},"
+ f"enforce_eager=true,"
+ f"add_bos_token=true,"
+ f"trust_remote_code={trust_remote_code}"
+ )
results = lm_eval.simple_evaluate(
model="vllm",
model_args=model_args,
tasks=[task["name"] for task in eval_config["tasks"]],
num_fewshot=eval_config["num_fewshot"],
limit=eval_config["limit"],
- batch_size="auto")
+ batch_size="auto",
+ )
return results
@@ -42,9 +45,10 @@ def test_lm_eval_correctness_param(config_filename, tp_size):
for metric in task["metrics"]:
ground_truth = metric["value"]
measured_value = results["results"][task["name"]][metric["name"]]
- print(f'{task["name"]} | {metric["name"]}: '
- f'ground_truth={ground_truth} | measured={measured_value}')
- success = success and np.isclose(
- ground_truth, measured_value, rtol=RTOL)
+ print(
+ f"{task['name']} | {metric['name']}: "
+ f"ground_truth={ground_truth} | measured={measured_value}"
+ )
+ success = success and np.isclose(ground_truth, measured_value, rtol=RTOL)
assert success
diff --git a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py
index 1030ec24e8d7..7f2a2d8dc296 100644
--- a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py
+++ b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py
@@ -65,18 +65,18 @@ def read_markdown(file):
def results_to_json(latency, throughput, serving):
- return json.dumps({
- 'latency': latency.to_dict(),
- 'throughput': throughput.to_dict(),
- 'serving': serving.to_dict()
- })
+ return json.dumps(
+ {
+ "latency": latency.to_dict(),
+ "throughput": throughput.to_dict(),
+ "serving": serving.to_dict(),
+ }
+ )
if __name__ == "__main__":
-
# collect results
for test_file in results_folder.glob("*.json"):
-
with open(test_file) as f:
raw_result = json.loads(f.read())
@@ -120,7 +120,8 @@ def results_to_json(latency, throughput, serving):
for perc in [10, 25, 50, 75, 90, 99]:
# Multiply 1000 to convert the time unit from s to ms
raw_result.update(
- {f"P{perc}": 1000 * raw_result["percentiles"][str(perc)]})
+ {f"P{perc}": 1000 * raw_result["percentiles"][str(perc)]}
+ )
raw_result["avg_latency"] = raw_result["avg_latency"] * 1000
# add the result to raw_result
@@ -153,26 +154,27 @@ def results_to_json(latency, throughput, serving):
serving_results = pd.DataFrame.from_dict(serving_results)
throughput_results = pd.DataFrame.from_dict(throughput_results)
- raw_results_json = results_to_json(latency_results, throughput_results,
- serving_results)
+ raw_results_json = results_to_json(
+ latency_results, throughput_results, serving_results
+ )
# remapping the key, for visualization purpose
if not latency_results.empty:
- latency_results = latency_results[list(
- latency_column_mapping.keys())].rename(
- columns=latency_column_mapping)
+ latency_results = latency_results[list(latency_column_mapping.keys())].rename(
+ columns=latency_column_mapping
+ )
if not serving_results.empty:
- serving_results = serving_results[list(
- serving_column_mapping.keys())].rename(
- columns=serving_column_mapping)
+ serving_results = serving_results[list(serving_column_mapping.keys())].rename(
+ columns=serving_column_mapping
+ )
if not throughput_results.empty:
- throughput_results = throughput_results[list(
- throughput_results_column_mapping.keys())].rename(
- columns=throughput_results_column_mapping)
+ throughput_results = throughput_results[
+ list(throughput_results_column_mapping.keys())
+ ].rename(columns=throughput_results_column_mapping)
- processed_results_json = results_to_json(latency_results,
- throughput_results,
- serving_results)
+ processed_results_json = results_to_json(
+ latency_results, throughput_results, serving_results
+ )
for df in [latency_results, serving_results, throughput_results]:
if df.empty:
@@ -184,38 +186,39 @@ def results_to_json(latency, throughput, serving):
# The GPUs sometimes come in format of "GPUTYPE\nGPUTYPE\n...",
# we want to turn it into "8xGPUTYPE"
df["GPU"] = df["GPU"].apply(
- lambda x: f"{len(x.split('\n'))}x{x.split('\n')[0]}")
+ lambda x: f"{len(x.split('\n'))}x{x.split('\n')[0]}"
+ )
# get markdown tables
- latency_md_table = tabulate(latency_results,
- headers='keys',
- tablefmt='pipe',
- showindex=False)
- serving_md_table = tabulate(serving_results,
- headers='keys',
- tablefmt='pipe',
- showindex=False)
- throughput_md_table = tabulate(throughput_results,
- headers='keys',
- tablefmt='pipe',
- showindex=False)
+ latency_md_table = tabulate(
+ latency_results, headers="keys", tablefmt="pipe", showindex=False
+ )
+ serving_md_table = tabulate(
+ serving_results, headers="keys", tablefmt="pipe", showindex=False
+ )
+ throughput_md_table = tabulate(
+ throughput_results, headers="keys", tablefmt="pipe", showindex=False
+ )
# document the result
with open(results_folder / "benchmark_results.md", "w") as f:
-
- results = read_markdown("../.buildkite/nightly-benchmarks/" +
- "performance-benchmarks-descriptions.md")
+ results = read_markdown(
+ "../.buildkite/nightly-benchmarks/"
+ + "performance-benchmarks-descriptions.md"
+ )
results = results.format(
latency_tests_markdown_table=latency_md_table,
throughput_tests_markdown_table=throughput_md_table,
serving_tests_markdown_table=serving_md_table,
- benchmarking_results_in_json_string=processed_results_json)
+ benchmarking_results_in_json_string=processed_results_json,
+ )
f.write(results)
# document benchmarking results in json
with open(results_folder / "benchmark_results.json", "w") as f:
-
- results = latency_results.to_dict(
- orient='records') + throughput_results.to_dict(
- orient='records') + serving_results.to_dict(orient='records')
+ results = (
+ latency_results.to_dict(orient="records")
+ + throughput_results.to_dict(orient="records")
+ + serving_results.to_dict(orient="records")
+ )
f.write(json.dumps(results))
diff --git a/.buildkite/nightly-benchmarks/scripts/download-tokenizer.py b/.buildkite/nightly-benchmarks/scripts/download-tokenizer.py
index 5e17b79d26a1..778a3a8d87f6 100644
--- a/.buildkite/nightly-benchmarks/scripts/download-tokenizer.py
+++ b/.buildkite/nightly-benchmarks/scripts/download-tokenizer.py
@@ -14,15 +14,12 @@ def main(model, cachedir):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
- description="Download and save Hugging Face tokenizer")
- parser.add_argument("--model",
- type=str,
- required=True,
- help="Name of the model")
- parser.add_argument("--cachedir",
- type=str,
- required=True,
- help="Directory to save the tokenizer")
+ description="Download and save Hugging Face tokenizer"
+ )
+ parser.add_argument("--model", type=str, required=True, help="Name of the model")
+ parser.add_argument(
+ "--cachedir", type=str, required=True, help="Directory to save the tokenizer"
+ )
args = parser.parse_args()
main(args.model, args.cachedir)
diff --git a/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py b/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py
index 0ff95a0911b1..10a7a2f5a467 100644
--- a/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py
+++ b/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py
@@ -11,33 +11,33 @@
def parse_arguments():
parser = argparse.ArgumentParser(
- description=
- 'Parse command line arguments for summary-nightly-results script.')
- parser.add_argument('--results-folder',
- type=str,
- required=True,
- help='The folder where the results are stored.')
- parser.add_argument('--description',
- type=str,
- required=True,
- help='Description of the results.')
+ description="Parse command line arguments for summary-nightly-results script."
+ )
+ parser.add_argument(
+ "--results-folder",
+ type=str,
+ required=True,
+ help="The folder where the results are stored.",
+ )
+ parser.add_argument(
+ "--description", type=str, required=True, help="Description of the results."
+ )
args = parser.parse_args()
return args
def get_perf(df, method, model, metric):
-
means = []
for qps in [2, 4, 8, 16, "inf"]:
- target = df['Test name'].str.contains(model)
- target = target & df['Engine'].str.contains(method)
- target = target & df['Test name'].str.contains("qps_" + str(qps))
+ target = df["Test name"].str.contains(model)
+ target = target & df["Engine"].str.contains(method)
+ target = target & df["Test name"].str.contains("qps_" + str(qps))
filtered_df = df[target]
if filtered_df.empty:
- means.append(0.)
+ means.append(0.0)
else:
means.append(filtered_df[metric].values[0])
@@ -45,7 +45,6 @@ def get_perf(df, method, model, metric):
def get_perf_w_std(df, method, model, metric):
-
if metric in ["TTFT", "ITL"]:
mean = get_perf(df, method, model, "Mean " + metric + " (ms)")
mean = mean.tolist()
@@ -60,7 +59,8 @@ def get_perf_w_std(df, method, model, metric):
else:
assert metric == "Tput"
mean = get_perf(df, method, model, "Input Tput (tok/s)") + get_perf(
- df, method, model, "Output Tput (tok/s)")
+ df, method, model, "Output Tput (tok/s)"
+ )
mean = mean.tolist()
std = None
@@ -80,18 +80,17 @@ def main(args):
# generate markdown table
df = pd.DataFrame.from_dict(results)
- md_table = tabulate(df, headers='keys', tablefmt='pipe', showindex=False)
+ md_table = tabulate(df, headers="keys", tablefmt="pipe", showindex=False)
with open(args.description) as f:
description = f.read()
- description = description.format(
- nightly_results_benchmarking_table=md_table)
+ description = description.format(nightly_results_benchmarking_table=md_table)
with open("nightly_results.md", "w") as f:
f.write(description)
-if __name__ == '__main__':
+if __name__ == "__main__":
args = parse_arguments()
main(args)
diff --git a/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py b/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py
index 62ee5e10b509..2a7b37991f31 100644
--- a/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py
+++ b/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py
@@ -34,10 +34,8 @@
}
if __name__ == "__main__":
-
# collect results
for test_file in results_folder.glob("*.json"):
-
with open(test_file) as f:
raw_result = json.loads(f.read())
@@ -56,17 +54,16 @@
serving_results = pd.DataFrame.from_dict(serving_results)
if not serving_results.empty:
- serving_results = serving_results[list(
- serving_column_mapping.keys())].rename(
- columns=serving_column_mapping)
+ serving_results = serving_results[list(serving_column_mapping.keys())].rename(
+ columns=serving_column_mapping
+ )
- serving_md_table_with_headers = tabulate(serving_results,
- headers='keys',
- tablefmt='pipe',
- showindex=False)
+ serving_md_table_with_headers = tabulate(
+ serving_results, headers="keys", tablefmt="pipe", showindex=False
+ )
# remove the first line of header
- serving_md_table_lines = serving_md_table_with_headers.split('\n')
- serving_md_table_without_header = '\n'.join(serving_md_table_lines[2:])
+ serving_md_table_lines = serving_md_table_with_headers.split("\n")
+ serving_md_table_without_header = "\n".join(serving_md_table_lines[2:])
prefix = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
prefix = prefix + "_" + os.environ.get("CURRENT_LLM_SERVING_ENGINE")
@@ -76,10 +73,9 @@
# document results with header.
# for those who wants to reproduce our benchmark.
f.write(serving_md_table_with_headers)
- f.write('\n')
+ f.write("\n")
# document benchmarking results in json
with open(results_folder / f"{prefix}_nightly_results.json", "w") as f:
-
- results = serving_results.to_dict(orient='records')
+ results = serving_results.to_dict(orient="records")
f.write(json.dumps(results))
diff --git a/.buildkite/pyproject.toml b/.buildkite/pyproject.toml
new file mode 100644
index 000000000000..d5cad1c73c6f
--- /dev/null
+++ b/.buildkite/pyproject.toml
@@ -0,0 +1,46 @@
+# This local pyproject file is part of the migration from yapf to ruff format.
+# It uses the same core rules as the main pyproject.toml file, but with the
+# following differences:
+# - ruff line length is overridden to 88
+# - deprecated typing ignores (UP006, UP035) have been removed
+
+[tool.ruff]
+line-length = 88
+
+[tool.ruff.lint.per-file-ignores]
+"vllm/third_party/**" = ["ALL"]
+"vllm/version.py" = ["F401"]
+"vllm/_version.py" = ["ALL"]
+
+[tool.ruff.lint]
+select = [
+ # pycodestyle
+ "E",
+ # Pyflakes
+ "F",
+ # pyupgrade
+ "UP",
+ # flake8-bugbear
+ "B",
+ # flake8-simplify
+ "SIM",
+ # isort
+ "I",
+ # flake8-logging-format
+ "G",
+]
+ignore = [
+ # star imports
+ "F405", "F403",
+ # lambda expression assignment
+ "E731",
+ # Loop control variable not used within loop body
+ "B007",
+ # f-string format
+ "UP032",
+ # Can remove once 3.10+ is the minimum Python version
+ "UP007",
+]
+
+[tool.ruff.format]
+docstring-code-format = true
diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml
index 4cc9c70a6adb..b3c27e2c99c2 100644
--- a/.buildkite/release-pipeline.yaml
+++ b/.buildkite/release-pipeline.yaml
@@ -14,7 +14,7 @@ steps:
agents:
queue: cpu_queue_postmerge
commands:
- - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.6.3 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
+ - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.6.3 --build-arg torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
- "mkdir artifacts"
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
- "bash .buildkite/scripts/upload-wheels.sh"
@@ -31,7 +31,7 @@ steps:
agents:
queue: cpu_queue_postmerge
commands:
- - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=11.8.0 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
+ - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=11.8.0 --build-arg torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
- "mkdir artifacts"
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
- "bash .buildkite/scripts/upload-wheels.sh"
@@ -64,7 +64,7 @@ steps:
- "docker push vllm/vllm-tpu:$BUILDKITE_COMMIT"
plugins:
- docker-login#v3.0.0:
- username: vllm
+ username: vllmbot
password-env: DOCKERHUB_TOKEN
env:
DOCKER_BUILDKIT: "1"
diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh
index d29903bf497f..bbc896ec6819 100755
--- a/.buildkite/scripts/hardware_ci/run-amd-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh
@@ -3,6 +3,9 @@
# This script runs test inside the corresponding ROCm docker container.
set -o pipefail
+# Export Python path
+export PYTHONPATH=".."
+
# Print ROCm version
echo "--- Confirming Clean Initial State"
while true; do
@@ -74,6 +77,23 @@ HF_MOUNT="/root/.cache/huggingface"
commands=$@
echo "Commands:$commands"
+
+if [[ $commands == *"pytest -v -s basic_correctness/test_basic_correctness.py"* ]]; then
+ commands=${commands//"pytest -v -s basic_correctness/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s basic_correctness/test_basic_correctness.py"}
+fi
+
+if [[ $commands == *"pytest -v -s models/test_registry.py"* ]]; then
+ commands=${commands//"pytest -v -s models/test_registry.py"/"pytest -v -s models/test_registry.py -k 'not BambaForCausalLM and not GritLM and not Mamba2ForCausalLM and not Zamba2ForCausalLM'"}
+fi
+
+if [[ $commands == *"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'"* ]]; then
+ commands=${commands//"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'"/"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2 and not BambaForCausalLM and not Gemma2ForCausalLM and not Grok1ModelForCausalLM and not Zamba2ForCausalLM and not Gemma2Model and not GritLM'"}
+fi
+
+if [[ $commands == *"pytest -v -s compile/test_basic_correctness.py"* ]]; then
+ commands=${commands//"pytest -v -s compile/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s compile/test_basic_correctness.py"}
+fi
+
#ignore certain kernels tests
if [[ $commands == *" kernels/core"* ]]; then
commands="${commands} \
@@ -161,6 +181,8 @@ fi
PARALLEL_JOB_COUNT=8
+MYPYTHONPATH=".."
+
# check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs.
if [[ $commands == *"--shard-id="* ]]; then
# assign job count as the number of shards used
@@ -181,6 +203,7 @@ if [[ $commands == *"--shard-id="* ]]; then
-e AWS_SECRET_ACCESS_KEY \
-v "${HF_CACHE}:${HF_MOUNT}" \
-e "HF_HOME=${HF_MOUNT}" \
+ -e "PYTHONPATH=${MYPYTHONPATH}" \
--name "${container_name}_${GPU}" \
"${image_name}" \
/bin/bash -c "${commands_gpu}" \
@@ -211,6 +234,7 @@ else
-e AWS_SECRET_ACCESS_KEY \
-v "${HF_CACHE}:${HF_MOUNT}" \
-e "HF_HOME=${HF_MOUNT}" \
+ -e "PYTHONPATH=${MYPYTHONPATH}" \
--name "${container_name}" \
"${image_name}" \
/bin/bash -c "${commands}"
diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh b/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh
index 5d863dd82e9b..077bd9914907 100755
--- a/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh
+++ b/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh
@@ -32,9 +32,12 @@ function cpu_tests() {
set -e
pip install pytest pytest-asyncio einops peft Pillow soundfile transformers_stream_generator matplotlib
pip install sentence-transformers datamodel_code_generator
- pytest -v -s tests/models/embedding/language/test_cls_models.py::test_classification_models[float-jason9693/Qwen2.5-1.5B-apeach]
- pytest -v -s tests/models/embedding/language/test_embedding.py::test_models[half-BAAI/bge-base-en-v1.5]
- pytest -v -s tests/models/encoder_decoder/language -m cpu_model"
+ pytest -v -s tests/models/language/generation/test_bart.py -m cpu_model
+ pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-openai-community/gpt2]
+ pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-facebook/opt-125m]
+ pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-google/gemma-1.1-2b-it]
+ pytest -v -s tests/models/language/pooling/test_classification.py::test_models[float-jason9693/Qwen2.5-1.5B-apeach]
+ pytest -v -s tests/models/language/pooling/test_embedding.py::test_models[half-BAAI/bge-base-en-v1.5]"
}
# All of CPU tests are expected to be finished less than 40 mins.
diff --git a/.buildkite/scripts/hardware_ci/run-hpu-test.sh b/.buildkite/scripts/hardware_ci/run-hpu-test.sh
index 95b6ac37f185..5efac3ddf469 100644
--- a/.buildkite/scripts/hardware_ci/run-hpu-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-hpu-test.sh
@@ -10,15 +10,17 @@ docker build -t hpu-test-env -f docker/Dockerfile.hpu .
# Setup cleanup
# certain versions of HPU software stack have a bug that can
# override the exit code of the script, so we need to use
-# separate remove_docker_container and remove_docker_container_and_exit
+# separate remove_docker_containers and remove_docker_containers_and_exit
# functions, while other platforms only need one remove_docker_container
# function.
EXITCODE=1
-remove_docker_container() { docker rm -f hpu-test || true; }
-remove_docker_container_and_exit() { remove_docker_container; exit $EXITCODE; }
-trap remove_docker_container_and_exit EXIT
-remove_docker_container
+remove_docker_containers() { docker rm -f hpu-test || true; docker rm -f hpu-test-tp2 || true; }
+remove_docker_containers_and_exit() { remove_docker_containers; exit $EXITCODE; }
+trap remove_docker_containers_and_exit EXIT
+remove_docker_containers
# Run the image and launch offline inference
docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
+docker run --runtime=habana --name=hpu-test-tp2 --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --tensor-parallel-size 2
+
EXITCODE=$?
diff --git a/.buildkite/scripts/hardware_ci/run-neuron-test.sh b/.buildkite/scripts/hardware_ci/run-neuron-test.sh
index ec6a080eb499..3d294ea5f8a7 100644
--- a/.buildkite/scripts/hardware_ci/run-neuron-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-neuron-test.sh
@@ -11,13 +11,14 @@ container_name="neuron_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)"
HF_CACHE="$(realpath ~)/huggingface"
mkdir -p "${HF_CACHE}"
HF_MOUNT="/root/.cache/huggingface"
+HF_TOKEN=$(aws secretsmanager get-secret-value --secret-id "ci/vllm-neuron/hf-token" --region us-west-2 --query 'SecretString' --output text | jq -r .VLLM_NEURON_CI_HF_TOKEN)
NEURON_COMPILE_CACHE_URL="$(realpath ~)/neuron_compile_cache"
mkdir -p "${NEURON_COMPILE_CACHE_URL}"
NEURON_COMPILE_CACHE_MOUNT="/root/.cache/neuron_compile_cache"
# Try building the docker image
-aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-west-2.amazonaws.com
+aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws
# prune old image and containers to save disk space, and only once a day
# by using a timestamp file in tmp.
@@ -47,8 +48,16 @@ trap remove_docker_container EXIT
docker run --rm -it --device=/dev/neuron0 --network bridge \
-v "${HF_CACHE}:${HF_MOUNT}" \
-e "HF_HOME=${HF_MOUNT}" \
+ -e "HF_TOKEN=${HF_TOKEN}" \
-v "${NEURON_COMPILE_CACHE_URL}:${NEURON_COMPILE_CACHE_MOUNT}" \
-e "NEURON_COMPILE_CACHE_URL=${NEURON_COMPILE_CACHE_MOUNT}" \
--name "${container_name}" \
${image_name} \
- /bin/bash -c "python3 /workspace/vllm/examples/offline_inference/neuron.py && python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys && python3 -m pytest /workspace/vllm/tests/neuron/2_core/ -v --capture=tee-sys"
+ /bin/bash -c "
+     set -e; python3 /workspace/vllm/examples/offline_inference/neuron.py;  # set -e restores the fail-fast behaviour of the old && chain
+     python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys;
+     for f in /workspace/vllm/tests/neuron/2_core/*.py; do  # one pytest process per test file
+         echo \"Running test file: \$f\";  # \$ is escaped so the container shell, not the host, expands f
+         python3 -m pytest \"\$f\" -v --capture=tee-sys;
+     done
+ "
\ No newline at end of file
diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
index 939daddad92b..2d375d7e9d87 100755
--- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
@@ -26,27 +26,27 @@ docker run --privileged --net host --shm-size=16G -it \
&& tpu-info \
&& { \
echo TEST_0: Running test_perf.py; \
- pytest -s -v /workspace/vllm/tests/tpu/test_perf.py; \
+ python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_perf.py; \
echo TEST_0_EXIT_CODE: \$?; \
} & \
- && { \
+ { \
echo TEST_1: Running test_compilation.py; \
- pytest -s -v /workspace/vllm/tests/tpu/test_compilation.py; \
+ python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_compilation.py; \
echo TEST_1_EXIT_CODE: \$?; \
} & \
{ \
echo TEST_2: Running test_basic.py; \
- pytest -s -v /workspace/vllm/tests/v1/tpu/test_basic.py; \
+ python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_basic.py; \
echo TEST_2_EXIT_CODE: \$?; \
} & \
{ \
echo TEST_3: Running test_accuracy.py::test_lm_eval_accuracy_v1_engine; \
- pytest -s -v /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine; \
+ python3 -m pytest -s -v /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine; \
echo TEST_3_EXIT_CODE: \$?; \
} & \
{ \
echo TEST_4: Running test_quantization_accuracy.py; \
- pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py; \
+ python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py; \
echo TEST_4_EXIT_CODE: \$?; \
} & \
{ \
@@ -56,43 +56,43 @@ docker run --privileged --net host --shm-size=16G -it \
} & \
{ \
echo TEST_6: Running test_tpu_model_runner.py; \
- pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py; \
+ python3 -m pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py; \
echo TEST_6_EXIT_CODE: \$?; \
} & \
- && { \
+ { \
echo TEST_7: Running test_sampler.py; \
- pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py; \
+ python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py; \
echo TEST_7_EXIT_CODE: \$?; \
} & \
- && { \
+ { \
echo TEST_8: Running test_topk_topp_sampler.py; \
- pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py; \
+ python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py; \
echo TEST_8_EXIT_CODE: \$?; \
} & \
- && { \
+ { \
echo TEST_9: Running test_multimodal.py; \
- pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py; \
+ python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py; \
echo TEST_9_EXIT_CODE: \$?; \
} & \
- && { \
+ { \
echo TEST_10: Running test_pallas.py; \
- pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py; \
+ python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py; \
echo TEST_10_EXIT_CODE: \$?; \
} & \
- && { \
+ { \
echo TEST_11: Running test_struct_output_generate.py; \
- pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py; \
+ python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py; \
echo TEST_11_EXIT_CODE: \$?; \
} & \
- && { \
+ { \
echo TEST_12: Running test_moe_pallas.py; \
- pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py; \
+ python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py; \
echo TEST_12_EXIT_CODE: \$?; \
} & \
# Disable the TPU LoRA tests until the feature is activated
- # && { \
+ # & { \
# echo TEST_13: Running test_moe_pallas.py; \
- # pytest -s -v /workspace/vllm/tests/tpu/lora/; \
+ # python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/; \
# echo TEST_13_EXIT_CODE: \$?; \
# } & \
wait \
diff --git a/.buildkite/scripts/upload-wheels.sh b/.buildkite/scripts/upload-wheels.sh
index 75e3ef264095..037897e53dbe 100644
--- a/.buildkite/scripts/upload-wheels.sh
+++ b/.buildkite/scripts/upload-wheels.sh
@@ -75,3 +75,4 @@ else
fi
aws s3 cp "$wheel" "s3://vllm-wheels/$version/"
+aws s3 cp index.html "s3://vllm-wheels/$version/vllm/index.html"
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 01d04759f536..80a5a610c8ac 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -32,16 +32,17 @@ steps:
##### fast check tests #####
- label: Documentation Build # 2min
- working_dir: "/vllm-workspace/test_docs/docs"
+ mirror_hardwares: [amdexperimental]
+ working_dir: "/vllm-workspace/test_docs"
fast_check: true
no_gpu: True
commands:
- - pip install -r ../../requirements/docs.txt
- - SPHINXOPTS=\"-W\" make html
- # Check API reference (if it fails, you may have missing mock imports)
- - grep \"sig sig-object py\" build/html/api/vllm/vllm.sampling_params.html
+ - pip install -r ../requirements/docs.txt
+ # TODO: add `--strict` once warnings in docstrings are fixed
+ - mkdocs build
- label: Async Engine, Inputs, Utils, Worker Test # 24min
+ mirror_hardwares: [amdexperimental]
source_file_dependencies:
- vllm/
- tests/mq_llm_engine
@@ -57,11 +58,13 @@ steps:
- pytest -v -s async_engine # AsyncLLMEngine
- NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
- pytest -v -s test_inputs.py
+ - pytest -v -s test_outputs.py
- pytest -v -s multimodal
- pytest -v -s test_utils.py # Utils
- pytest -v -s worker # Worker
- label: Python-only Installation Test
+ mirror_hardwares: [amdexperimental]
source_file_dependencies:
- tests/standalone_tests/python_only_compile.sh
- setup.py
@@ -69,7 +72,7 @@ steps:
- bash standalone_tests/python_only_compile.sh
- label: Basic Correctness Test # 30min
- #mirror_hardwares: [amd]
+ mirror_hardwares: [amdexperimental, amdproduction]
fast_check: true
torch_nightly: true
source_file_dependencies:
@@ -86,6 +89,7 @@ steps:
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
- label: Chunked Prefill Test
+ mirror_hardwares: [amdexperimental]
source_file_dependencies:
- vllm/
- tests/basic_correctness/test_chunked_prefill
@@ -94,7 +98,7 @@ steps:
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
- label: Core Test # 10min
- mirror_hardwares: [amd]
+ mirror_hardwares: [amdexperimental, amdproduction]
fast_check: true
source_file_dependencies:
- vllm/core
@@ -104,10 +108,10 @@ steps:
- pytest -v -s core
- label: Entrypoints Test # 40min
+ mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/tests"
fast_check: true
torch_nightly: true
- #mirror_hardwares: [amd]
source_file_dependencies:
- vllm/
- tests/entrypoints/llm
@@ -121,11 +125,12 @@ steps:
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
- VLLM_USE_V1=0 pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
- - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_openai_schema.py
+ - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/
- pytest -v -s entrypoints/test_chat_utils.py
- VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
- label: Distributed Tests (4 GPUs) # 10min
+ mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/tests"
num_gpus: 4
source_file_dependencies:
@@ -133,6 +138,7 @@ steps:
- vllm/core/
- tests/distributed/test_utils
- tests/distributed/test_pynccl
+ - tests/distributed/test_events
- tests/spec_decode/e2e/test_integration_dist_tp4
- tests/compile/test_basic_correctness
- examples/offline_inference/rlhf.py
@@ -143,22 +149,25 @@ steps:
# test with tp=2 and external_dp=2
- VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
- torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
+ # test with tp=2 and pp=2
+ - PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
# test with internal dp
- python3 ../examples/offline_inference/data_parallel.py
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
- pytest -v -s distributed/test_utils.py
- pytest -v -s compile/test_basic_correctness.py
- pytest -v -s distributed/test_pynccl.py
+ - pytest -v -s distributed/test_events.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
# TODO: create a dedicated test section for multi-GPU example tests
# when we have multiple distributed example tests
- pushd ../examples/offline_inference
- - python3 rlhf.py
- - RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
+ - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py
+ - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
- popd
- label: Metrics, Tracing Test # 10min
- mirror_hardwares: [amd]
+ mirror_hardwares: [amdexperimental, amdproduction]
num_gpus: 2
source_file_dependencies:
- vllm/
@@ -172,7 +181,7 @@ steps:
##### 1 GPU test #####
- label: Regression Test # 5min
- #mirror_hardwares: [amd]
+ mirror_hardwares: [amdexperimental, amdproduction]
source_file_dependencies:
- vllm/
- tests/test_regression
@@ -182,7 +191,7 @@ steps:
working_dir: "/vllm-workspace/tests" # optional
- label: Engine Test # 10min
- mirror_hardwares: [amd]
+ mirror_hardwares: [amdexperimental, amdproduction]
source_file_dependencies:
- vllm/
- tests/engine
@@ -196,7 +205,7 @@ steps:
- pytest -v -s tokenization
- label: V1 Test
- #mirror_hardwares: [amd]
+ mirror_hardwares: [amdexperimental]
source_file_dependencies:
- vllm/
- tests/v1
@@ -209,10 +218,11 @@ steps:
- pytest -v -s v1/worker
- pytest -v -s v1/structured_output
- pytest -v -s v1/spec_decode
+ - pytest -v -s v1/kv_connector/unit
- pytest -v -s v1/test_serial_utils.py
- - pytest -v -s v1/test_stats.py
- pytest -v -s v1/test_utils.py
- pytest -v -s v1/test_oracle.py
+ - pytest -v -s v1/test_metrics_reader.py
# TODO: accuracy does not match, whether setting
# VLLM_USE_FLASHINFER_SAMPLER or not on H100.
- pytest -v -s v1/e2e
@@ -221,8 +231,8 @@ steps:
- pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
- label: Examples Test # 25min
+ mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/examples"
- #mirror_hardwares: [amd]
source_file_dependencies:
- vllm/entrypoints
- examples/
@@ -237,7 +247,7 @@ steps:
- python3 offline_inference/vision_language.py --seed 0
- python3 offline_inference/vision_language_embedding.py --seed 0
- python3 offline_inference/vision_language_multi_image.py --seed 0
- - VLLM_USE_V1=0 python3 other/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 other/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
+ - VLLM_USE_V1=0 python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- python3 offline_inference/encoder_decoder.py
- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0
- python3 offline_inference/basic/classify.py
@@ -246,7 +256,7 @@ steps:
- VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2
- label: Prefix Caching Test # 9min
- mirror_hardwares: [amd]
+ mirror_hardwares: [amdexperimental, amdproduction]
source_file_dependencies:
- vllm/
- tests/prefix_caching
@@ -254,6 +264,7 @@ steps:
- pytest -v -s prefix_caching
- label: Samplers Test # 36min
+ mirror_hardwares: [amdexperimental]
source_file_dependencies:
- vllm/model_executor/layers
- vllm/sampling_metadata.py
@@ -264,7 +275,7 @@ steps:
- VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers
- label: LogitsProcessor Test # 5min
- mirror_hardwares: [amd]
+ mirror_hardwares: [amdexperimental, amdproduction]
source_file_dependencies:
- vllm/model_executor/layers
- vllm/model_executor/guided_decoding
@@ -275,6 +286,7 @@ steps:
- pytest -v -s model_executor/test_guided_processors.py
- label: Speculative decoding tests # 40min
+ mirror_hardwares: [amdexperimental]
source_file_dependencies:
- vllm/spec_decode
- tests/spec_decode
@@ -285,7 +297,7 @@ steps:
- pytest -v -s spec_decode/e2e/test_eagle_correctness.py
- label: LoRA Test %N # 15min each
- #mirror_hardwares: [amd]
+ mirror_hardwares: [amdexperimental]
source_file_dependencies:
- vllm/lora
- tests/lora
@@ -293,6 +305,7 @@ steps:
parallelism: 4
- label: PyTorch Compilation Unit Tests
+ mirror_hardwares: [amdexperimental, amdproduction]
torch_nightly: true
source_file_dependencies:
- vllm/
@@ -300,9 +313,12 @@ steps:
commands:
- pytest -v -s compile/test_pass_manager.py
- pytest -v -s compile/test_fusion.py
+ - pytest -v -s compile/test_silu_mul_quant_fusion.py
- pytest -v -s compile/test_sequence_parallelism.py
+ - pytest -v -s compile/test_async_tp.py
- label: PyTorch Fullgraph Smoke Test # 9min
+ mirror_hardwares: [amdexperimental, amdproduction]
torch_nightly: true
source_file_dependencies:
- vllm/
@@ -314,6 +330,7 @@ steps:
- pytest -v -s compile/piecewise/test_toy_llama.py
- label: PyTorch Fullgraph Test # 18min
+ mirror_hardwares: [amdexperimental, amdproduction]
torch_nightly: true
source_file_dependencies:
- vllm/
@@ -322,7 +339,7 @@ steps:
- pytest -v -s compile/test_full_graph.py
- label: Kernels Core Operation Test
- mirror_hardwares: [amd]
+ mirror_hardwares: [amdexperimental, amdproduction]
source_file_dependencies:
- csrc/
- tests/kernels/core
@@ -330,7 +347,7 @@ steps:
- pytest -v -s kernels/core
- label: Kernels Attention Test %N
- mirror_hardwares: [amd]
+ mirror_hardwares: [amdexperimental, amdproduction]
source_file_dependencies:
- csrc/attention/
- vllm/attention
@@ -341,7 +358,7 @@ steps:
parallelism: 2
- label: Kernels Quantization Test %N
- mirror_hardwares: [amd]
+ mirror_hardwares: [amdexperimental, amdproduction]
source_file_dependencies:
- csrc/quantization/
- vllm/model_executor/layers/quantization
@@ -351,7 +368,7 @@ steps:
parallelism: 2
- label: Kernels MoE Test
- #mirror_hardwares: [amd]
+ mirror_hardwares: [amdexperimental]
source_file_dependencies:
- csrc/moe/
- tests/kernels/moe
@@ -360,7 +377,7 @@ steps:
- pytest -v -s kernels/moe
- label: Kernels Mamba Test
- #mirror_hardwares: [amd]
+ mirror_hardwares: [amdexperimental]
source_file_dependencies:
- csrc/mamba/
- tests/kernels/mamba
@@ -368,25 +385,28 @@ steps:
- pytest -v -s kernels/mamba
- label: Tensorizer Test # 11min
- # mirror_hardwares: [amd]
+ mirror_hardwares: [amdexperimental, amdproduction]
soft_fail: true
source_file_dependencies:
- vllm/model_executor/model_loader
- tests/tensorizer_loader
+ - tests/entrypoints/openai/test_tensorizer_entrypoint.py
commands:
- apt-get update && apt-get install -y curl libsodium23
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s tensorizer_loader
+ - pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py
- label: Benchmarks # 9min
+ mirror_hardwares: [amdexperimental, amdproduction]
working_dir: "/vllm-workspace/.buildkite"
- mirror_hardwares: [amd]
source_file_dependencies:
- benchmarks/
commands:
- bash scripts/run-benchmarks.sh
- label: Benchmarks CLI Test # 10min
+ mirror_hardwares: [amdexperimental, amdproduction]
source_file_dependencies:
- vllm/
- tests/benchmarks/
@@ -394,6 +414,7 @@ steps:
- pytest -v -s benchmarks/
- label: Quantization Test
+ mirror_hardwares: [amdexperimental]
source_file_dependencies:
- csrc/
- vllm/model_executor/layers/quantization
@@ -402,6 +423,7 @@ steps:
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization
- label: LM Eval Small Models # 53min
+ mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
source_file_dependencies:
- csrc/
@@ -411,6 +433,7 @@ steps:
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
- label: OpenAI API correctness
+ mirror_hardwares: [amdexperimental]
source_file_dependencies:
- csrc/
- vllm/entrypoints/openai/
@@ -419,6 +442,7 @@ steps:
- pytest -s entrypoints/openai/correctness/
- label: Encoder Decoder tests # 5min
+ mirror_hardwares: [amdexperimental]
source_file_dependencies:
- vllm/
- tests/encoder_decoder
@@ -426,8 +450,8 @@ steps:
- pytest -v -s encoder_decoder
- label: OpenAI-Compatible Tool Use # 20 min
+ mirror_hardwares: [amdexperimental]
fast_check: false
- #mirror_hardwares: [ amd ]
source_file_dependencies:
- vllm/
- tests/tool_use
@@ -439,6 +463,7 @@ steps:
##### models test #####
- label: Basic Models Test # 24min
+ mirror_hardwares: [amdexperimental, amdproduction]
torch_nightly: true
source_file_dependencies:
- vllm/
@@ -448,43 +473,55 @@ steps:
- pytest -v -s models/test_registry.py
- pytest -v -s models/test_utils.py
- pytest -v -s models/test_vision.py
- # V1 Test: https://github.com/vllm-project/vllm/issues/14531
- - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'
- - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4'
- - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'plamo2'
+ - pytest -v -s models/test_initialization.py
- label: Language Models Test (Standard)
- #mirror_hardwares: [amd]
+ mirror_hardwares: [amdexperimental]
+ torch_nightly: true
source_file_dependencies:
- vllm/
- tests/models/language
commands:
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
- pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
+ - pip freeze | grep -E 'torch'
- pytest -v -s models/language -m core_model
-- label: Language Models Test (Extended)
+- label: Language Models Test (Extended Generation) # 1hr20min
+ mirror_hardwares: [amdexperimental]
optional: true
source_file_dependencies:
- vllm/
- - tests/models/language
+ - tests/models/language/generation
commands:
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
- pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
- - pytest -v -s models/language -m 'not core_model'
+ - pytest -v -s models/language/generation -m 'not core_model'
+
+- label: Language Models Test (Extended Pooling) # 36min
+ mirror_hardwares: [amdexperimental]
+ optional: true
+ source_file_dependencies:
+ - vllm/
+ - tests/models/language/pooling
+ commands:
+ - pytest -v -s models/language/pooling -m 'not core_model'
- label: Multi-Modal Models Test (Standard)
- #mirror_hardwares: [amd]
+ mirror_hardwares: [amdexperimental]
+ torch_nightly: true
source_file_dependencies:
- vllm/
- tests/models/multimodal
commands:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
+ - pip freeze | grep -E 'torch'
- pytest -v -s models/multimodal/processing
- pytest -v -s --ignore models/multimodal/generation/test_whisper.py models/multimodal -m core_model
- cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
- label: Multi-Modal Models Test (Extended) 1
+ mirror_hardwares: [amdexperimental]
optional: true
source_file_dependencies:
- vllm/
@@ -494,6 +531,7 @@ steps:
- pytest -v -s --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing models/multimodal -m 'not core_model'
- label: Multi-Modal Models Test (Extended) 2
+ mirror_hardwares: [amdexperimental]
optional: true
source_file_dependencies:
- vllm/
@@ -503,6 +541,7 @@ steps:
- pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model'
- label: Multi-Modal Models Test (Extended) 3
+ mirror_hardwares: [amdexperimental, amdproduction]
optional: true
source_file_dependencies:
- vllm/
@@ -512,7 +551,7 @@ steps:
- pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=1) and not core_model'
- label: Quantized Models Test
- #mirror_hardwares: [amd]
+ mirror_hardwares: [amdexperimental, amdproduction]
source_file_dependencies:
- vllm/model_executor/layers/quantization
- tests/models/quantization
@@ -521,7 +560,7 @@ steps:
# This test is used only in PR development phase to test individual models and should never run on main
- label: Custom Models Test
- mirror_hardwares: [amd]
+ mirror_hardwares: [amdexperimental, amdproduction]
optional: true
commands:
- echo 'Testing custom models...'
@@ -533,7 +572,7 @@ steps:
##### multi gpus test #####
- label: Distributed Comm Ops Test # 7min
- mirror_hardwares: [amd]
+ mirror_hardwares: [amdexperimental, amdproduction]
working_dir: "/vllm-workspace/tests"
num_gpus: 2
source_file_dependencies:
@@ -544,6 +583,7 @@ steps:
- pytest -v -s distributed/test_shm_broadcast.py
- label: 2 Node Tests (4 GPUs in total) # 16min
+ mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/tests"
num_gpus: 2
num_nodes: 2
@@ -562,7 +602,7 @@ steps:
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
- label: Distributed Tests (2 GPUs) # 40min
- #mirror_hardwares: [amd]
+ mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/tests"
num_gpus: 2
source_file_dependencies:
@@ -599,13 +639,14 @@ steps:
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
- label: Plugin Tests (2 GPUs) # 40min
+ mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/tests"
num_gpus: 2
source_file_dependencies:
- vllm/plugins/
- tests/plugins/
commands:
- # begin platform plugin tests, all the code in-between runs on dummy platform
+ # begin platform plugin and general plugin tests, all the code in-between runs on dummy platform
- pip install -e ./plugins/vllm_add_dummy_platform
- pytest -v -s plugins_tests/test_platform_plugins.py
- pip uninstall vllm_add_dummy_platform -y
@@ -616,8 +657,10 @@ steps:
- pytest -v -s distributed/test_distributed_oot.py
- pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
- pytest -v -s models/test_oot_registration.py # it needs a clean process
+ - pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins
- label: Multi-step Tests (4 GPUs) # 36min
+ mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/tests"
num_gpus: 4
source_file_dependencies:
@@ -638,6 +681,7 @@ steps:
- pytest -v -s multi_step/test_correctness_llm.py
- label: Pipeline Parallelism Test # 45min
+ mirror_hardwares: [amdexperimental, amdproduction]
working_dir: "/vllm-workspace/tests"
num_gpus: 4
source_file_dependencies:
@@ -651,6 +695,7 @@ steps:
- pytest -v -s distributed/test_pipeline_parallel.py
- label: LoRA TP Test (Distributed)
+ mirror_hardwares: [amdexperimental, amdproduction]
num_gpus: 4
source_file_dependencies:
- vllm/lora
@@ -666,6 +711,7 @@ steps:
- label: Weight Loading Multiple GPU Test # 33min
+ mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/tests"
num_gpus: 2
source_file_dependencies:
@@ -675,6 +721,7 @@ steps:
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt
- label: Weight Loading Multiple GPU Test - Large Models # optional
+ mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/tests"
num_gpus: 2
gpu: a100
diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index 76aa5f7a35d5..4452ce22d504 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -13,6 +13,7 @@
/vllm/model_executor/guided_decoding @mgoin @russellb
/vllm/multimodal @DarkLight1337 @ywang96
/vllm/vllm_flash_attn @LucasWilkinson
+/vllm/lora @jeejeelee
CMakeLists.txt @tlrmchlsmth
# vLLM V1
@@ -40,3 +41,8 @@ CMakeLists.txt @tlrmchlsmth
/tests/v1/entrypoints/llm/test_struct_output_generate.py @mgoin @russellb
/tests/v1/structured_output @mgoin @russellb
/tests/weight_loading @mgoin @youkaichao
+/tests/lora @jeejeelee
+
+# Docs
+/docs @hmellor
+mkdocs.yaml @hmellor
\ No newline at end of file
diff --git a/.github/ISSUE_TEMPLATE/400-bug-report.yml b/.github/ISSUE_TEMPLATE/400-bug-report.yml
index 00b0f024c0da..f05be2ba8707 100644
--- a/.github/ISSUE_TEMPLATE/400-bug-report.yml
+++ b/.github/ISSUE_TEMPLATE/400-bug-report.yml
@@ -81,14 +81,14 @@ body:
required: true
- type: markdown
attributes:
- value: >
⚠️ Please separate bugs of `transformers` implementation or usage from bugs of `vllm`. If you think anything is wrong with the models' output:
+ value: |
⚠️ Please separate bugs of `transformers` implementation or usage from bugs of `vllm`. If you think anything is wrong with the model's output:
- Try the counterpart of `transformers` first. If the error appears, please go to [their issues](https://github.com/huggingface/transformers/issues?q=is%3Aissue+is%3Aopen+sort%3Aupdated-desc).
- If the error only appears in vllm, please provide the detailed script of how you run `transformers` and `vllm`, also highlight the difference and what you expect.
- Thanks for contributing ๐!
+ Thanks for reporting ๐!
- type: checkboxes
id: askllm
attributes:
diff --git a/.github/ISSUE_TEMPLATE/450-ci-failure.yml b/.github/ISSUE_TEMPLATE/450-ci-failure.yml
new file mode 100644
index 000000000000..7af0e0673a2f
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/450-ci-failure.yml
@@ -0,0 +1,69 @@
+name: 🧪 CI failure report
+description: Report a failing test.
+title: "[CI Failure]: "
+labels: ["ci-failure"]
+
+body:
+- type: markdown
+ attributes:
+ value: >
+ #### Include the name of the failing Buildkite step and test file in the title.
+- type: input
+ attributes:
+ label: Name of failing test
+ description: |
+ Paste in the fully-qualified name of the failing test from the logs.
+ placeholder: |
+ `path/to/test_file.py::test_name[params]`
+ validations:
+ required: true
+- type: checkboxes
+ attributes:
+ label: Basic information
+ description: Select all items that apply to the failing test.
+ options:
+ - label: Flaky test
+ - label: Can reproduce locally
+ - label: Caused by external libraries (e.g. bug in `transformers`)
+- type: textarea
+ attributes:
+ label: 🧪 Describe the failing test
+ description: |
+ Please provide a clear and concise description of the failing test.
+ placeholder: |
+ A clear and concise description of the failing test.
+
+ ```
+ The error message you got, with the full traceback and the error logs with [dump_input.py:##] if present.
+ ```
+ validations:
+ required: true
+- type: textarea
+ attributes:
+ label: ๐ History of failing test
+ description: |
+ Since when did the test start to fail?
+ You can look up its history via [Buildkite Test Suites](https://buildkite.com/organizations/vllm/analytics/suites/ci-1/tests?branch=main).
+
+ If you have time, identify the PR that caused the test to fail on main. You can do so via the following methods:
+
+ - Use Buildkite Test Suites to find the PR where the test failure first occurred, and reproduce the failure locally.
+
+ - Run [`git bisect`](https://git-scm.com/docs/git-bisect) locally.
+
+ - Manually unblock Buildkite steps for suspected PRs on main and check the results. (authorized users only)
+ placeholder: |
+ Approximate timeline and/or problematic PRs
+
+ A link to the Buildkite analytics of the failing test (if available)
+ validations:
+ required: true
+- type: textarea
+ attributes:
+ label: CC List.
+ description: >
+ The list of people you want to CC. Usually, this includes those who worked on the PR that failed the test.
+- type: markdown
+ attributes:
+ value: >
+ Thanks for reporting ๐!
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index 7042e81a84da..65be771b94fb 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -3,4 +3,4 @@ FILL IN THE PR DESCRIPTION HERE
FIX #xxxx (*link existing issues this PR will resolve*)
-**BEFORE SUBMITTING, PLEASE READ ** (anything written below this line will be removed by GitHub Actions)
+**BEFORE SUBMITTING, PLEASE READ ** (anything written below this line will be removed by GitHub Actions)
diff --git a/.github/mergify.yml b/.github/mergify.yml
index 15fa3660a87d..e595060c325a 100644
--- a/.github/mergify.yml
+++ b/.github/mergify.yml
@@ -58,7 +58,7 @@ pull_request_rules:
- files~=^benchmarks/structured_schemas/
- files=benchmarks/benchmark_serving_structured_output.py
- files=benchmarks/run_structured_output_benchmark.sh
- - files=docs/source/features/structured_outputs.md
+ - files=docs/features/structured_outputs.md
- files=examples/offline_inference/structured_outputs.py
- files=examples/online_serving/openai_chat_completion_structured_outputs.py
- files=examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py
@@ -135,9 +135,7 @@ pull_request_rules:
- files~=^tests/entrypoints/openai/tool_parsers/
- files=tests/entrypoints/openai/test_chat_with_tool_reasoning.py
- files~=^vllm/entrypoints/openai/tool_parsers/
- - files=docs/source/features/tool_calling.md
- - files=docs/source/getting_started/examples/openai_chat_completion_client_with_tools.md
- - files=docs/source/getting_started/examples/chat_with_tools.md
+ - files=docs/features/tool_calling.md
- files~=^examples/tool_chat_*
- files=examples/offline_inference/chat_with_tools.py
- files=examples/online_serving/openai_chat_completion_client_with_tools_required.py
@@ -163,6 +161,17 @@ pull_request_rules:
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
+- name: assign reviewer for tensorizer changes
+ conditions:
+ - files~=^vllm/model_executor/model_loader/tensorizer.py
+ - files~=^vllm/model_executor/model_loader/tensorizer_loader.py
+ - files~=^tests/entrypoints/openai/test_tensorizer_entrypoint.py
+ - files~=^tests/tensorizer_loader/
+ actions:
+ assign:
+ users:
+ - "sangstar"
+
- name: remove 'needs-rebase' label when conflict is resolved
conditions:
- -conflict
diff --git a/.github/scripts/cleanup_pr_body.sh b/.github/scripts/cleanup_pr_body.sh
index 3246c6f9bc4b..8d65936fba1d 100755
--- a/.github/scripts/cleanup_pr_body.sh
+++ b/.github/scripts/cleanup_pr_body.sh
@@ -26,7 +26,7 @@ sed -i '/\*\*BEFORE SUBMITTING, PLEASE READ.*\*\*/,$d' "${NEW}"
# Remove HTML section that includes text of "PR Checklist (Click to Expand)"
python3 - <
-
-
+
+
@@ -16,18 +16,20 @@ Easy, fast, and cheap LLM serving for everyone
---
*Latest News* 🔥
+- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing).
+- [2025/05] vLLM is now a hosted project under PyTorch Foundation! Please find the announcement [here](https://pytorch.org/blog/pytorch-foundation-welcomes-vllm/).
- [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing).
+- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html).
+
+
+Previous News
+
- [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing).
- [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing).
- [2025/03] We hosted [the East Coast vLLM Meetup](https://lu.ma/7mu4k4xx)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1NHiv8EUFF1NLd3fEYODm56nDmL26lEeXCaDgyDlTsRs/edit#slide=id.g31441846c39_0_0).
- [2025/02] We hosted [the ninth vLLM meetup](https://lu.ma/h7g3kuj9) with Meta! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1jzC_PZVXrVNSFVCW-V4cFXb6pn7zZ2CyP_Flwo05aqg/edit?usp=sharing) and AMD [here](https://drive.google.com/file/d/1Zk5qEJIkTmlQ2eQcXQZlljAx3m9s7nwn/view?usp=sharing). The slides from Meta will not be posted.
-- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html).
- [2025/01] We hosted [the eighth vLLM meetup](https://lu.ma/zep56hui) with Google Cloud! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing), and Google Cloud team [here](https://drive.google.com/file/d/1h24pHewANyRL11xy5dXUbvRC9F9Kkjix/view?usp=sharing).
- [2024/12] vLLM joins [pytorch ecosystem](https://pytorch.org/blog/vllm-joins-pytorch)! Easy, Fast, and Cheap LLM Serving for Everyone!
-
-
-Previous News
-
- [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing), and Snowflake team [here](https://docs.google.com/presentation/d/1qF3RkDAbOULwz9WK5TOltt2fE9t6uIc_hVNLFAaQX6A/edit?usp=sharing).
- [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there!
- [2024/10] Ray Summit 2024 held a special track for vLLM! Please find the opening talk slides from the vLLM team [here](https://docs.google.com/presentation/d/1B_KQxpHBTRa_mDF-tR6i8rWdOU5QoTZNcEg2MKZxEHM/edit?usp=sharing). Learn more from the [talks](https://www.youtube.com/playlist?list=PLzTswPQNepXl6AQwifuwUImLPFRVpksjR) from other vLLM contributors and users!
@@ -56,7 +58,7 @@ vLLM is fast with:
- Efficient management of attention key and value memory with [**PagedAttention**](https://blog.vllm.ai/2023/06/20/vllm.html)
- Continuous batching of incoming requests
- Fast model execution with CUDA/HIP graph
-- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), INT4, INT8, and FP8.
+- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [AutoRound](https://arxiv.org/abs/2309.05516), INT4, INT8, and FP8.
- Optimized CUDA kernels, including integration with FlashAttention and FlashInfer.
- Speculative decoding
- Chunked prefill
@@ -72,7 +74,7 @@ vLLM is flexible and easy to use with:
- OpenAI-compatible API server
- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron.
- Prefix caching support
-- Multi-lora support
+- Multi-LoRA support
vLLM seamlessly supports most popular open-source models on HuggingFace, including:
- Transformer-like LLMs (e.g., Llama)
@@ -98,14 +100,14 @@ Visit our [documentation](https://docs.vllm.ai/en/latest/) to learn more.
## Contributing
We welcome and value any contributions and collaborations.
-Please check out [Contributing to vLLM](https://docs.vllm.ai/en/stable/contributing/overview.html) for how to get involved.
+Please check out [Contributing to vLLM](https://docs.vllm.ai/en/latest/contributing/index.html) for how to get involved.
## Sponsors
vLLM is a community project. Our compute resources for development and testing are supported by the following organizations. Thank you for your support!
-
+
Cash Donations:
- a16z
- Dropbox
diff --git a/benchmarks/README.md b/benchmarks/README.md
index 4a8ab895e18e..ecab570bb31c 100644
--- a/benchmarks/README.md
+++ b/benchmarks/README.md
@@ -146,10 +146,9 @@ python3 vllm/benchmarks/benchmark_serving.py \
``` bash
VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
- --speculative-model "[ngram]" \
--ngram_prompt_lookup_min 2 \
--ngram-prompt-lookup-max 5 \
- --num_speculative_tokens 5
+ --speculative_config '{"model": "[ngram]", "num_speculative_tokens": 5}'
```
``` bash
@@ -274,10 +273,9 @@ python3 vllm/benchmarks/benchmark_throughput.py \
--output-len=100 \
--num-prompts=2048 \
--async-engine \
- --speculative-model="[ngram]" \
--ngram_prompt_lookup_min=2 \
--ngram-prompt-lookup-max=5 \
- --num_speculative_tokens=5
+ --speculative_config '{"model": "[ngram]", "num_speculative_tokens": 5}'
```
```
diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py
index e6a67fda6827..88616e1108c5 100644
--- a/benchmarks/backend_request_func.py
+++ b/benchmarks/backend_request_func.py
@@ -12,8 +12,7 @@
import aiohttp
import huggingface_hub.constants
from tqdm.asyncio import tqdm
-from transformers import (AutoTokenizer, PreTrainedTokenizer,
- PreTrainedTokenizerFast)
+from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
# NOTE(simon): do not import vLLM here so the benchmark script
# can run without vLLM installed.
@@ -43,8 +42,7 @@ class RequestFuncOutput:
latency: float = 0.0
output_tokens: int = 0
ttft: float = 0.0 # Time to first token
- itl: list[float] = field(
- default_factory=list) # list of inter-token latencies
+ itl: list[float] = field(default_factory=list) # list of inter-token latencies
tpot: float = 0.0 # avg next-token latencies
prompt_len: int = 0
error: str = ""
@@ -57,8 +55,9 @@ async def async_request_tgi(
api_url = request_func_input.api_url
assert api_url.endswith("generate_stream")
- async with aiohttp.ClientSession(trust_env=True,
- timeout=AIOHTTP_TIMEOUT) as session:
+ async with aiohttp.ClientSession(
+ trust_env=True, timeout=AIOHTTP_TIMEOUT
+ ) as session:
params = {
"max_new_tokens": request_func_input.output_len,
"do_sample": True,
@@ -105,8 +104,7 @@ async def async_request_tgi(
# Decoding phase
else:
- output.itl.append(timestamp -
- most_recent_timestamp)
+ output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp = timestamp
@@ -133,8 +131,9 @@ async def async_request_trt_llm(
api_url = request_func_input.api_url
assert api_url.endswith("generate_stream")
- async with aiohttp.ClientSession(trust_env=True,
- timeout=AIOHTTP_TIMEOUT) as session:
+ async with aiohttp.ClientSession(
+ trust_env=True, timeout=AIOHTTP_TIMEOUT
+ ) as session:
payload = {
"accumulate_tokens": True,
"text_input": request_func_input.prompt,
@@ -159,8 +158,7 @@ async def async_request_trt_llm(
if not chunk_bytes:
continue
- chunk = chunk_bytes.decode("utf-8").removeprefix(
- "data:")
+ chunk = chunk_bytes.decode("utf-8").removeprefix("data:")
data = json.loads(chunk)
output.generated_text += data["text_output"]
@@ -172,8 +170,7 @@ async def async_request_trt_llm(
# Decoding phase
else:
- output.itl.append(timestamp -
- most_recent_timestamp)
+ output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp = timestamp
@@ -197,9 +194,14 @@ async def async_request_deepspeed_mii(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
- async with aiohttp.ClientSession(trust_env=True,
- timeout=AIOHTTP_TIMEOUT) as session:
+ api_url = request_func_input.api_url
+ assert api_url.endswith(("completions", "profile")), (
+ "OpenAI Completions API URL must end with 'completions' or 'profile'."
+ )
+ async with aiohttp.ClientSession(
+ trust_env=True, timeout=AIOHTTP_TIMEOUT
+ ) as session:
payload = {
"model": request_func_input.model,
"prompt": request_func_input.prompt,
@@ -207,6 +209,8 @@ async def async_request_deepspeed_mii(
"temperature": 0.01, # deepspeed-mii does not accept 0.0 temp.
"top_p": 1.0,
}
+ headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
+
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
@@ -217,19 +221,21 @@ async def async_request_deepspeed_mii(
st = time.perf_counter()
try:
- async with session.post(url=request_func_input.api_url,
- json=payload) as response:
+ async with session.post(
+ url=api_url, json=payload, headers=headers
+ ) as response:
if response.status == 200:
parsed_resp = await response.json()
output.latency = time.perf_counter() - st
if "choices" in parsed_resp:
- output.generated_text = parsed_resp["choices"][0][
- "text"]
+ output.generated_text = parsed_resp["choices"][0]["text"]
elif "text" in parsed_resp:
output.generated_text = parsed_resp["text"][0]
else:
- output.error = ("Unexpected response format: "
- "neither 'choices' nor 'text' found")
+ output.error = (
+ "Unexpected response format: "
+ "neither 'choices' nor 'text' found"
+ )
output.success = False
output.success = True
else:
@@ -250,15 +256,17 @@ async def async_request_openai_completions(
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
- assert api_url.endswith(
- ("completions", "profile")
- ), "OpenAI Completions API URL must end with 'completions' or 'profile'."
+ assert api_url.endswith(("completions", "profile")), (
+ "OpenAI Completions API URL must end with 'completions' or 'profile'."
+ )
- async with aiohttp.ClientSession(trust_env=True,
- timeout=AIOHTTP_TIMEOUT) as session:
+ async with aiohttp.ClientSession(
+ trust_env=True, timeout=AIOHTTP_TIMEOUT
+ ) as session:
payload = {
- "model": request_func_input.model_name \
- if request_func_input.model_name else request_func_input.model,
+ "model": request_func_input.model_name
+ if request_func_input.model_name
+ else request_func_input.model,
"prompt": request_func_input.prompt,
"temperature": 0.0,
"repetition_penalty": 1.0,
@@ -273,9 +281,7 @@ async def async_request_openai_completions(
payload["ignore_eos"] = request_func_input.ignore_eos
if request_func_input.extra_body:
payload.update(request_func_input.extra_body)
- headers = {
- "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
- }
+ headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
@@ -284,8 +290,9 @@ async def async_request_openai_completions(
st = time.perf_counter()
most_recent_timestamp = st
try:
- async with session.post(url=api_url, json=payload,
- headers=headers) as response:
+ async with session.post(
+ url=api_url, json=payload, headers=headers
+ ) as response:
if response.status == 200:
first_chunk_received = False
async for chunk_bytes in response.content:
@@ -293,8 +300,7 @@ async def async_request_openai_completions(
if not chunk_bytes:
continue
- chunk = chunk_bytes.decode("utf-8").removeprefix(
- "data: ")
+ chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
if chunk != "[DONE]":
data = json.loads(chunk)
@@ -314,21 +320,20 @@ async def async_request_openai_completions(
# Decoding phase
else:
- output.itl.append(timestamp -
- most_recent_timestamp)
+ output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp = timestamp
generated_text += text or ""
elif usage := data.get("usage"):
- output.output_tokens = usage.get(
- "completion_tokens")
+ output.output_tokens = usage.get("completion_tokens")
if first_chunk_received:
output.success = True
else:
output.success = False
output.error = (
"Never received a valid chunk to calculate TTFT."
- "This response will be marked as failed!")
+ "This response will be marked as failed!"
+ )
output.generated_text = generated_text
output.latency = most_recent_timestamp - st
else:
@@ -349,23 +354,22 @@ async def async_request_openai_chat_completions(
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
- assert api_url.endswith(
- ("chat/completions", "profile")
- ), "OpenAI Chat Completions API URL must end with 'chat/completions'."
+ assert api_url.endswith(("chat/completions", "profile")), (
+ "OpenAI Chat Completions API URL must end with 'chat/completions'."
+ )
- async with aiohttp.ClientSession(trust_env=True,
- timeout=AIOHTTP_TIMEOUT) as session:
+ async with aiohttp.ClientSession(
+ trust_env=True, timeout=AIOHTTP_TIMEOUT
+ ) as session:
content = [{"type": "text", "text": request_func_input.prompt}]
if request_func_input.multi_modal_content:
content.append(request_func_input.multi_modal_content)
payload = {
- "model": request_func_input.model_name \
- if request_func_input.model_name else request_func_input.model,
+ "model": request_func_input.model_name
+ if request_func_input.model_name
+ else request_func_input.model,
"messages": [
- {
- "role": "user",
- "content": content
- },
+ {"role": "user", "content": content},
],
"temperature": 0.0,
"max_completion_tokens": request_func_input.output_len,
@@ -391,16 +395,16 @@ async def async_request_openai_chat_completions(
st = time.perf_counter()
most_recent_timestamp = st
try:
- async with session.post(url=api_url, json=payload,
- headers=headers) as response:
+ async with session.post(
+ url=api_url, json=payload, headers=headers
+ ) as response:
if response.status == 200:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
- chunk = chunk_bytes.decode("utf-8").removeprefix(
- "data: ")
+ chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
if chunk != "[DONE]":
timestamp = time.perf_counter()
data = json.loads(chunk)
@@ -414,13 +418,11 @@ async def async_request_openai_chat_completions(
# Decoding phase
else:
- output.itl.append(timestamp -
- most_recent_timestamp)
+ output.itl.append(timestamp - most_recent_timestamp)
generated_text += content or ""
elif usage := data.get("usage"):
- output.output_tokens = usage.get(
- "completion_tokens")
+ output.output_tokens = usage.get("completion_tokens")
most_recent_timestamp = timestamp
@@ -446,25 +448,28 @@ async def async_request_openai_audio(
) -> RequestFuncOutput:
# Lazy import without PlaceholderModule to avoid vllm dep.
import soundfile
+
api_url = request_func_input.api_url
- assert api_url.endswith(
- ("transcriptions", "translations"
- )), "OpenAI Chat Completions API URL must end with 'transcriptions' "
+ assert api_url.endswith(("transcriptions", "translations")), (
+ "OpenAI Chat Completions API URL must end with 'transcriptions' "
+ )
"or `translations`."
- async with aiohttp.ClientSession(trust_env=True,
- timeout=AIOHTTP_TIMEOUT) as session:
+ async with aiohttp.ClientSession(
+ trust_env=True, timeout=AIOHTTP_TIMEOUT
+ ) as session:
content = [{"type": "text", "text": request_func_input.prompt}]
payload = {
- "model": request_func_input.model_name \
- if request_func_input.model_name else request_func_input.model,
+ "model": request_func_input.model_name
+ if request_func_input.model_name
+ else request_func_input.model,
"temperature": 0.0,
"max_completion_tokens": request_func_input.output_len,
"stream": True,
"language": "en",
# Flattened due to multipart/form-data
"stream_include_usage": True,
- "stream_continuous_usage_stats": True
+ "stream_continuous_usage_stats": True,
}
if request_func_input.extra_body:
payload.update(request_func_input.extra_body)
@@ -479,9 +484,9 @@ def to_bytes(y, sr):
buffer.seek(0)
return buffer
- with to_bytes(*request_func_input.multi_modal_content['audio']) as f:
+ with to_bytes(*request_func_input.multi_modal_content["audio"]) as f:
form = aiohttp.FormData()
- form.add_field('file', f, content_type='audio/wav')
+ form.add_field("file", f, content_type="audio/wav")
for key, value in payload.items():
form.add_field(key, str(value))
@@ -493,24 +498,22 @@ def to_bytes(y, sr):
st = time.perf_counter()
most_recent_timestamp = st
try:
- async with session.post(url=api_url,
- data=form,
- headers=headers) as response:
+ async with session.post(
+ url=api_url, data=form, headers=headers
+ ) as response:
if response.status == 200:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
- chunk = chunk_bytes.decode("utf-8").removeprefix(
- "data: ")
+ chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
if chunk != "[DONE]":
timestamp = time.perf_counter()
data = json.loads(chunk)
if choices := data.get("choices"):
- content = choices[0]["delta"].get(
- "content")
+ content = choices[0]["delta"].get("content")
# First token
if ttft == 0.0:
ttft = timestamp - st
@@ -519,12 +522,14 @@ def to_bytes(y, sr):
# Decoding phase
else:
output.itl.append(
- timestamp - most_recent_timestamp)
+ timestamp - most_recent_timestamp
+ )
generated_text += content or ""
elif usage := data.get("usage"):
output.output_tokens = usage.get(
- "completion_tokens")
+ "completion_tokens"
+ )
most_recent_timestamp = timestamp
@@ -545,7 +550,7 @@ def to_bytes(y, sr):
def get_model(pretrained_model_name_or_path: str) -> str:
- if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
+ if os.getenv("VLLM_USE_MODELSCOPE", "False").lower() == "true":
from modelscope import snapshot_download
from vllm.model_executor.model_loader.weight_utils import get_lock
@@ -556,7 +561,8 @@ def get_model(pretrained_model_name_or_path: str) -> str:
model_path = snapshot_download(
model_id=pretrained_model_name_or_path,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
- ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
+ ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"],
+ )
return model_path
return pretrained_model_name_or_path
@@ -569,23 +575,23 @@ def get_tokenizer(
**kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
if pretrained_model_name_or_path is not None and not os.path.exists(
- pretrained_model_name_or_path):
- pretrained_model_name_or_path = get_model(
- pretrained_model_name_or_path)
+ pretrained_model_name_or_path
+ ):
+ pretrained_model_name_or_path = get_model(pretrained_model_name_or_path)
if tokenizer_mode == "slow":
if kwargs.get("use_fast", False):
- raise ValueError(
- "Cannot use the fast tokenizer in slow tokenizer mode.")
+ raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
kwargs["use_fast"] = False
if tokenizer_mode == "mistral":
try:
from vllm.transformers_utils.tokenizer import MistralTokenizer
except ImportError as e:
- raise ImportError("MistralTokenizer requires vllm package.\n"
- "Please install it with `pip install vllm` "
- "to use mistral tokenizer mode.") from e
- return MistralTokenizer.from_pretrained(
- str(pretrained_model_name_or_path))
+ raise ImportError(
+ "MistralTokenizer requires vllm package.\n"
+ "Please install it with `pip install vllm` "
+ "to use mistral tokenizer mode."
+ ) from e
+ return MistralTokenizer.from_pretrained(str(pretrained_model_name_or_path))
else:
return AutoTokenizer.from_pretrained(
pretrained_model_name_or_path,
@@ -608,7 +614,7 @@ def get_tokenizer(
}
OPENAI_COMPATIBLE_BACKENDS = [
- k for k, v in ASYNC_REQUEST_FUNCS.items()
- if v in (async_request_openai_completions,
- async_request_openai_chat_completions)
+ k
+ for k, v in ASYNC_REQUEST_FUNCS.items()
+ if v in (async_request_openai_completions, async_request_openai_chat_completions)
]
diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py
index 98d3360cd6ff..5513a5f78f1c 100644
--- a/benchmarks/benchmark_dataset.py
+++ b/benchmarks/benchmark_dataset.py
@@ -35,6 +35,7 @@
from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict
+from vllm.multimodal.image import convert_image_mode
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
logger = logging.getLogger(__name__)
@@ -82,14 +83,12 @@ def __init__(
self.dataset_path = dataset_path
# Set the random seed, ensuring that a None value is replaced with the
# default seed.
- self.random_seed = (random_seed
- if random_seed is not None else self.DEFAULT_SEED)
+ self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED
self.data = None
def apply_multimodal_chat_transformation(
- self,
- prompt: str,
- mm_content: Optional[MultiModalDataDict] = None) -> list[dict]:
+ self, prompt: str, mm_content: Optional[MultiModalDataDict] = None
+ ) -> list[dict]:
"""
Transform a prompt and optional multimodal content into a chat format.
This method is used for chat models that expect a specific conversation
@@ -111,8 +110,7 @@ def load_data(self) -> None:
NotImplementedError: If a subclass does not implement this method.
"""
# TODO (jenniferzhao): add support for downloading data
- raise NotImplementedError(
- "load_data must be implemented in subclasses.")
+ raise NotImplementedError("load_data must be implemented in subclasses.")
def get_random_lora_request(
self,
@@ -158,8 +156,9 @@ def get_random_lora_request(
return lora_request, lora_tokenizer_cache[lora_id] or tokenizer
@abstractmethod
- def sample(self, tokenizer: PreTrainedTokenizerBase,
- num_requests: int) -> list[SampleRequest]:
+ def sample(
+ self, tokenizer: PreTrainedTokenizerBase, num_requests: int
+ ) -> list[SampleRequest]:
"""
Abstract method to generate sample requests from the dataset.
@@ -177,8 +176,9 @@ def sample(self, tokenizer: PreTrainedTokenizerBase,
"""
raise NotImplementedError("sample must be implemented in subclasses.")
- def maybe_oversample_requests(self, requests: list[SampleRequest],
- num_requests: int) -> None:
+ def maybe_oversample_requests(
+ self, requests: list[SampleRequest], num_requests: int
+ ) -> None:
"""
Oversamples the list of requests if its size is less than the desired
number.
@@ -189,11 +189,9 @@ def maybe_oversample_requests(self, requests: list[SampleRequest],
"""
if len(requests) < num_requests:
random.seed(self.random_seed)
- additional = random.choices(requests,
- k=num_requests - len(requests))
+ additional = random.choices(requests, k=num_requests - len(requests))
requests.extend(additional)
- logger.info("Oversampled requests to reach %d total samples.",
- num_requests)
+ logger.info("Oversampled requests to reach %d total samples.", num_requests)
# -----------------------------------------------------------------------------
@@ -218,14 +216,14 @@ def is_valid_sequence(
"""
# Check for invalid conditions
prompt_too_short = prompt_len < min_len
- output_too_short = (not skip_min_output_len_check) and (output_len
- < min_len)
+ output_too_short = (not skip_min_output_len_check) and (output_len < min_len)
prompt_too_long = prompt_len > max_prompt_len
combined_too_long = (prompt_len + output_len) > max_total_len
# Return True if none of the invalid conditions are met
- return not (prompt_too_short or output_too_short or prompt_too_long
- or combined_too_long)
+ return not (
+ prompt_too_short or output_too_short or prompt_too_long or combined_too_long
+ )
@cache
@@ -257,28 +255,28 @@ def process_image(image: Any) -> Mapping[str, Any]:
Raises:
ValueError: If the input is not a supported type.
"""
- if isinstance(image, dict) and 'bytes' in image:
- image = Image.open(BytesIO(image['bytes']))
+ if isinstance(image, dict) and "bytes" in image:
+ image = Image.open(BytesIO(image["bytes"]))
if isinstance(image, Image.Image):
- image = image.convert("RGB")
+ image = convert_image_mode(image, "RGB")
with io.BytesIO() as image_data:
image.save(image_data, format="JPEG")
- image_base64 = base64.b64encode(
- image_data.getvalue()).decode("utf-8")
+ image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
return {
"type": "image_url",
- "image_url": {
- "url": f"data:image/jpeg;base64,{image_base64}"
- },
+ "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
}
if isinstance(image, str):
- image_url = (image if image.startswith(
- ("http://", "file://")) else f"file://{image}")
+ image_url = (
+ image if image.startswith(("http://", "file://")) else f"file://{image}"
+ )
return {"type": "image_url", "image_url": {"url": image_url}}
- raise ValueError(f"Invalid image input {image}. Must be a PIL.Image.Image"
- " or str or dictionary with raw image bytes.")
+ raise ValueError(
+ f"Invalid image input {image}. Must be a PIL.Image.Image"
+ " or str or dictionary with raw image bytes."
+ )
# -----------------------------------------------------------------------------
@@ -318,8 +316,11 @@ def sample(
num_special_tokens = tokenizer.num_special_tokens_to_add()
real_input_len = input_len - num_special_tokens
- prefix_token_ids = (np.random.randint(
- 0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [])
+ prefix_token_ids = (
+ np.random.randint(0, vocab_size, size=prefix_len).tolist()
+ if prefix_len > 0
+ else []
+ )
# New sampling logic: [X * (1 - b), X * (1 + b)]
input_low = int(real_input_len * (1 - range_ratio))
@@ -329,21 +330,17 @@ def sample(
# Add logging for debugging
logger.info("Sampling input_len from [%s, %s]", input_low, input_high)
- logger.info("Sampling output_len from [%s, %s]", output_low,
- output_high)
-
- input_lens = np.random.randint(input_low,
- input_high + 1,
- size=num_requests)
- output_lens = np.random.randint(output_low,
- output_high + 1,
- size=num_requests)
+ logger.info("Sampling output_len from [%s, %s]", output_low, output_high)
+
+ input_lens = np.random.randint(input_low, input_high + 1, size=num_requests)
+ output_lens = np.random.randint(output_low, output_high + 1, size=num_requests)
offsets = np.random.randint(0, vocab_size, size=num_requests)
requests = []
for i in range(num_requests):
- inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) %
- vocab_size).tolist()
+ inner_seq = (
+ (offsets[i] + i + np.arange(input_lens[i])) % vocab_size
+ ).tolist()
token_sequence = prefix_token_ids + inner_seq
prompt = tokenizer.decode(token_sequence)
# After decoding the prompt we have to encode and decode it again.
@@ -354,8 +351,9 @@ def sample(
# [1650, 939, 486] -> ['ฤ call', 'sh', 'ere']
# To avoid uncontrolled change of the prompt length,
# the encoded sequence is truncated before being decode again.
- re_encoded_sequence = tokenizer.encode(
- prompt, add_special_tokens=False)[:input_lens[i]]
+ re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[
+ : input_lens[i]
+ ]
prompt = tokenizer.decode(re_encoded_sequence)
total_input_len = prefix_len + int(input_lens[i])
requests.append(
@@ -363,7 +361,8 @@ def sample(
prompt=prompt,
prompt_len=total_input_len,
expected_output_len=int(output_lens[i]),
- ))
+ )
+ )
return requests
@@ -390,7 +389,8 @@ def load_data(self) -> None:
self.data = json.load(f)
# Filter entries with at least two conversation turns.
self.data = [
- entry for entry in self.data
+ entry
+ for entry in self.data
if "conversations" in entry and len(entry["conversations"]) >= 2
]
random.seed(self.random_seed)
@@ -416,27 +416,28 @@ def sample(
)
lora_request, tokenizer = self.get_random_lora_request(
- tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
+ tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path
+ )
prompt_ids = tokenizer(prompt).input_ids
completion_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_ids)
- new_output_len = (len(completion_ids)
- if output_len is None else output_len)
- if not is_valid_sequence(prompt_len,
- new_output_len,
- skip_min_output_len_check=output_len
- is not None):
+ new_output_len = len(completion_ids) if output_len is None else output_len
+ if not is_valid_sequence(
+ prompt_len,
+ new_output_len,
+ skip_min_output_len_check=output_len is not None,
+ ):
continue
if enable_multimodal_chat:
- prompt = self.apply_multimodal_chat_transformation(
- prompt, None)
+ prompt = self.apply_multimodal_chat_transformation(prompt, None)
samples.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=new_output_len,
lora_request=lora_request,
- ))
+ )
+ )
self.maybe_oversample_requests(samples, num_requests)
return samples
@@ -482,20 +483,20 @@ def sample(
) -> list:
# Calculate average token length for a poem line.
tokenized_lines = [tokenizer(line).input_ids for line in self.data]
- avg_len = sum(len(tokens)
- for tokens in tokenized_lines) / len(tokenized_lines)
+ avg_len = sum(len(tokens) for tokens in tokenized_lines) / len(tokenized_lines)
# Build the base prompt.
base_prompt = "Pick as many lines as you can from these poem lines:\n"
base_msg = [{"role": "user", "content": base_prompt}]
- base_fmt = tokenizer.apply_chat_template(base_msg,
- add_generation_prompt=True,
- tokenize=False)
+ base_fmt = tokenizer.apply_chat_template(
+ base_msg, add_generation_prompt=True, tokenize=False
+ )
base_offset = len(tokenizer(base_fmt).input_ids)
if input_len <= base_offset:
raise ValueError(
f"'input_len' must be higher than the base prompt length "
- f"({base_offset}).")
+ f"({base_offset})."
+ )
# Determine how many poem lines to use.
num_input_lines = round((input_len - base_offset) / avg_len)
@@ -504,21 +505,23 @@ def sample(
samples = []
while len(samples) < num_requests:
- extra_lines = random.choices(self.data,
- k=num_input_lines - num_prefix_lines)
+ extra_lines = random.choices(
+ self.data, k=num_input_lines - num_prefix_lines
+ )
prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}"
msg = [{"role": "user", "content": prompt}]
prompt_formatted = tokenizer.apply_chat_template(
- msg, add_generation_prompt=True, tokenize=False)
+ msg, add_generation_prompt=True, tokenize=False
+ )
prompt_len = len(tokenizer(prompt_formatted).input_ids)
if prompt_len <= input_len:
samples.append(
SampleRequest(
- prompt=prompt_formatted
- if return_prompt_formatted else prompt,
+ prompt=prompt_formatted if return_prompt_formatted else prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
- ))
+ )
+ )
return samples
@@ -538,7 +541,9 @@ def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self.load_data()
- def load_data(self, ):
+ def load_data(
+ self,
+ ):
if self.dataset_path is None:
raise ValueError("dataset_path must be provided for loading data.")
@@ -552,8 +557,7 @@ def load_data(self, ):
def _sample_loaded_data(self, num_requests: int) -> list:
if num_requests <= len(self.data):
- data = self.data.sample(n=num_requests,
- random_state=self.random_seed)
+ data = self.data.sample(n=num_requests, random_state=self.random_seed)
else:
data = self.data.sample(
n=num_requests,
@@ -577,7 +581,8 @@ def sample(
input_len = int(data[i][2])
output_len = int(data[i][3])
lora_req, tokenizer = self.get_random_lora_request(
- tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
+ tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path
+ )
vocab_size = tokenizer.vocab_size
# Generate a synthetic prompt: a list of token IDs computed as (i +
# j) modulo vocab_size.
@@ -589,7 +594,8 @@ def sample(
prompt_len=input_len,
expected_output_len=output_len,
lora_request=lora_req,
- ))
+ )
+ )
return samples
@@ -632,20 +638,23 @@ def load_data(self) -> None:
class ConversationDataset(HuggingFaceDataset):
"""Dataset for conversation data with multimodal support."""
+
SUPPORTED_DATASET_PATHS = {
- 'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered'
+ "lmms-lab/LLaVA-OneVision-Data",
+ "Aeala/ShareGPT_Vicuna_unfiltered",
}
IS_MULTIMODAL = True
- def sample(self,
- tokenizer: PreTrainedTokenizerBase,
- num_requests: int,
- output_len: Optional[int] = None,
- enable_multimodal_chat: bool = False,
- **kwargs) -> list:
+ def sample(
+ self,
+ tokenizer: PreTrainedTokenizerBase,
+ num_requests: int,
+ output_len: Optional[int] = None,
+ enable_multimodal_chat: bool = False,
+ **kwargs,
+ ) -> list:
# Filter examples with at least 2 conversations
- filtered_data = self.data.filter(
- lambda x: len(x["conversations"]) >= 2)
+ filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2)
sampled_requests = []
dynamic_output = output_len is None
@@ -661,24 +670,22 @@ def sample(self,
completion_len = len(completion_ids)
output_len = completion_len if dynamic_output else output_len
assert isinstance(output_len, int) and output_len > 0
- if dynamic_output and not is_valid_sequence(
- prompt_len, completion_len):
+ if dynamic_output and not is_valid_sequence(prompt_len, completion_len):
continue
- mm_content = process_image(
- item["image"]) if "image" in item else None
+ mm_content = process_image(item["image"]) if "image" in item else None
if enable_multimodal_chat:
# Note: when chat is enabled the request prompt_len is no longer
# accurate and we will be using request output to count the
# actual prompt len and output len
- prompt = self.apply_multimodal_chat_transformation(
- prompt, mm_content)
+ prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
sampled_requests.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=mm_content,
- ))
+ )
+ )
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
@@ -695,10 +702,8 @@ class VisionArenaDataset(HuggingFaceDataset):
DEFAULT_OUTPUT_LEN = 128
SUPPORTED_DATASET_PATHS = {
- "lmarena-ai/VisionArena-Chat":
- lambda x: x["conversation"][0][0]["content"],
- "lmarena-ai/vision-arena-bench-v0.1":
- lambda x: x["turns"][0][0]["content"]
+ "lmarena-ai/VisionArena-Chat": lambda x: x["conversation"][0][0]["content"],
+ "lmarena-ai/vision-arena-bench-v0.1": lambda x: x["turns"][0][0]["content"],
}
IS_MULTIMODAL = True
@@ -710,16 +715,14 @@ def sample(
enable_multimodal_chat: bool = False,
**kwargs,
) -> list:
- output_len = (output_len
- if output_len is not None else self.DEFAULT_OUTPUT_LEN)
+ output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
sampled_requests = []
for item in self.data:
if len(sampled_requests) >= num_requests:
break
parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
if parser_fn is None:
- raise ValueError(
- f"Unsupported dataset path: {self.dataset_path}")
+ raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
prompt = parser_fn(item)
mm_content = process_image(item["images"][0])
prompt_len = len(tokenizer(prompt).input_ids)
@@ -727,15 +730,15 @@ def sample(
# Note: when chat is enabled the request prompt_len is no longer
# accurate and we will be using request output to count the
# actual prompt len
- prompt = self.apply_multimodal_chat_transformation(
- prompt, mm_content)
+ prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
sampled_requests.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=mm_content,
- ))
+ )
+ )
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
@@ -760,14 +763,15 @@ class InstructCoderDataset(HuggingFaceDataset):
"likaixin/InstructCoder",
}
- def sample(self,
- tokenizer: PreTrainedTokenizerBase,
- num_requests: int,
- output_len: Optional[int] = None,
- enable_multimodal_chat: bool = False,
- **kwargs) -> list:
- output_len = (output_len
- if output_len is not None else self.DEFAULT_OUTPUT_LEN)
+ def sample(
+ self,
+ tokenizer: PreTrainedTokenizerBase,
+ num_requests: int,
+ output_len: Optional[int] = None,
+ enable_multimodal_chat: bool = False,
+ **kwargs,
+ ) -> list:
+ output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
sampled_requests = []
for item in self.data:
if len(sampled_requests) >= num_requests:
@@ -779,7 +783,8 @@ def sample(self,
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
- ))
+ )
+ )
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
@@ -794,38 +799,38 @@ class MTBenchDataset(HuggingFaceDataset):
MT-Bench Dataset.
https://huggingface.co/datasets/philschmid/mt-bench
- We create a single turn dataset for MT-Bench.
+ We create a single turn dataset for MT-Bench.
This is similar to Spec decoding benchmark setup in vLLM
https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
- """ # noqa: E501
+ """ # noqa: E501
DEFAULT_OUTPUT_LEN = 256 # avg len used in SD bench in vLLM
SUPPORTED_DATASET_PATHS = {
"philschmid/mt-bench",
}
- def sample(self,
- tokenizer: PreTrainedTokenizerBase,
- num_requests: int,
- output_len: Optional[int] = None,
- enable_multimodal_chat: bool = False,
- **kwargs) -> list:
- output_len = (output_len
- if output_len is not None else self.DEFAULT_OUTPUT_LEN)
+ def sample(
+ self,
+ tokenizer: PreTrainedTokenizerBase,
+ num_requests: int,
+ output_len: Optional[int] = None,
+ enable_multimodal_chat: bool = False,
+ **kwargs,
+ ) -> list:
+ output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
sampled_requests = []
for item in self.data:
if len(sampled_requests) >= num_requests:
break
- prompt = item['turns'][0]
+ prompt = item["turns"][0]
# apply template
- prompt = tokenizer.apply_chat_template([{
- "role": "user",
- "content": prompt
- }],
- add_generation_prompt=True,
- tokenize=False)
+ prompt = tokenizer.apply_chat_template(
+ [{"role": "user", "content": prompt}],
+ add_generation_prompt=True,
+ tokenize=False,
+ )
prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests.append(
@@ -833,7 +838,8 @@ def sample(self,
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
- ))
+ )
+ )
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
@@ -847,23 +853,27 @@ class AIMODataset(HuggingFaceDataset):
"""
Dataset class for processing a AIMO dataset with reasoning questions.
"""
+
SUPPORTED_DATASET_PATHS = {
- "AI-MO/aimo-validation-aime", "AI-MO/NuminaMath-1.5",
- "AI-MO/NuminaMath-CoT"
+ "AI-MO/aimo-validation-aime",
+ "AI-MO/NuminaMath-1.5",
+ "AI-MO/NuminaMath-CoT",
}
- def sample(self,
- tokenizer: PreTrainedTokenizerBase,
- num_requests: int,
- output_len: Optional[int] = None,
- **kwargs) -> list:
+ def sample(
+ self,
+ tokenizer: PreTrainedTokenizerBase,
+ num_requests: int,
+ output_len: Optional[int] = None,
+ **kwargs,
+ ) -> list:
sampled_requests = []
dynamic_output = output_len is None
for item in self.data:
if len(sampled_requests) >= num_requests:
break
- prompt, completion = item['problem'], item["solution"]
+ prompt, completion = item["problem"], item["solution"]
prompt_ids = tokenizer(prompt).input_ids
completion_ids = tokenizer(completion).input_ids
@@ -871,10 +881,9 @@ def sample(self,
completion_len = len(completion_ids)
output_len = completion_len if dynamic_output else output_len
assert isinstance(output_len, int) and output_len > 0
- if dynamic_output and not is_valid_sequence(prompt_len,
- completion_len,
- max_prompt_len=2048,
- max_total_len=32000):
+ if dynamic_output and not is_valid_sequence(
+ prompt_len, completion_len, max_prompt_len=2048, max_total_len=32000
+ ):
continue
sampled_requests.append(
SampleRequest(
@@ -882,7 +891,8 @@ def sample(self,
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=None,
- ))
+ )
+ )
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
@@ -905,25 +915,25 @@ def sample(self,
### Response:
-""" # noqa: E501
+""" # noqa: E501
def _format_zeta_prompt(
- sample: dict,
- original_start_marker: str = "<|editable_region_start|>") -> dict:
+ sample: dict, original_start_marker: str = "<|editable_region_start|>"
+) -> dict:
"""Format the zeta prompt for the Next Edit Prediction (NEP) dataset.
-
- This function formats examples from the NEP dataset
- into prompts and expected outputs. It could be
+
+ This function formats examples from the NEP dataset
+ into prompts and expected outputs. It could be
further extended to support more NEP datasets.
-
+
Args:
- sample: The dataset sample containing events,
+ sample: The dataset sample containing events,
inputs, and outputs.
- original_start_marker: The marker indicating the
- start of the editable region. Defaults to
+ original_start_marker: The marker indicating the
+ start of the editable region. Defaults to
"<|editable_region_start|>".
-
+
Returns:
A dictionary with the formatted prompts and expected outputs.
"""
@@ -953,10 +963,8 @@ class NextEditPredictionDataset(HuggingFaceDataset):
"zed-industries/zeta": _format_zeta_prompt,
}
- def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int,
- **kwargs):
- formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(
- self.dataset_path)
+ def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, **kwargs):
+ formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path)
if formatting_prompt_func is None:
raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
samples = []
@@ -967,8 +975,10 @@ def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int,
prompt=sample["prompt"],
prompt_len=len(tokenizer(sample["prompt"]).input_ids),
expected_output_len=len(
- tokenizer(sample["expected_output"]).input_ids),
- ))
+ tokenizer(sample["expected_output"]).input_ids
+ ),
+ )
+ )
if len(samples) >= num_requests:
break
self.maybe_oversample_requests(samples, num_requests)
@@ -997,18 +1007,22 @@ class ASRDataset(HuggingFaceDataset):
| AMI | Meetings | Spontaneous | ihm, sdm |
+----------------+----------------------------------------+--------------------------+-----------------------------+
- """ # noqa: E501
+ """ # noqa: E501
+
SUPPORTED_DATASET_PATHS = {
- "openslr/librispeech_asr", "facebook/voxpopuli", "LIUM/tedlium",
- "edinburghcstr/ami", "speechcolab/gigaspeech", "kensho/spgispeech"
+ "openslr/librispeech_asr",
+ "facebook/voxpopuli",
+ "LIUM/tedlium",
+ "edinburghcstr/ami",
+ "speechcolab/gigaspeech",
+ "kensho/spgispeech",
}
DEFAULT_OUTPUT_LEN = 128
IS_MULTIMODAL = True
# TODO Whisper-specific. Abstract interface when more models are supported.
- TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|>"\
- "<|notimestamps|>"
+ TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
skip_long_audios: bool = True
def sample(
@@ -1019,8 +1033,8 @@ def sample(
**kwargs,
) -> list:
import librosa
- output_len = (output_len
- if output_len is not None else self.DEFAULT_OUTPUT_LEN)
+
+ output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests = []
@@ -1043,10 +1057,14 @@ def sample(
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=mm_content,
- ))
+ )
+ )
if skipped:
- logger.warning("%d samples discarded from dataset due to" \
- " their length being greater than" \
- " what Whisper supports.", skipped)
+ logger.warning(
+ "%d samples discarded from dataset due to"
+ " their length being greater than"
+ " what Whisper supports.",
+ skipped,
+ )
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py
index dfd9bb1e6a4d..84759c5c354d 100644
--- a/benchmarks/benchmark_latency.py
+++ b/benchmarks/benchmark_latency.py
@@ -11,9 +11,9 @@
import numpy as np
import torch
-from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from tqdm import tqdm
+from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
@@ -21,13 +21,14 @@
from vllm.utils import FlexibleArgumentParser
-def save_to_pytorch_benchmark_format(args: argparse.Namespace,
- results: dict[str, Any]) -> None:
+def save_to_pytorch_benchmark_format(
+ args: argparse.Namespace, results: dict[str, Any]
+) -> None:
pt_records = convert_to_pytorch_benchmark_format(
args=args,
metrics={"latency": results["latencies"]},
- extra_info={k: results[k]
- for k in ["avg_latency", "percentiles"]})
+ extra_info={k: results[k] for k in ["avg_latency", "percentiles"]},
+ )
if pt_records:
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
write_to_json(pt_file, pt_records)
@@ -42,9 +43,11 @@ def main(args: argparse.Namespace):
# the engine will automatically process the request in multiple batches.
llm = LLM(**dataclasses.asdict(engine_args))
assert llm.llm_engine.model_config.max_model_len >= (
- args.input_len +
- args.output_len), ("Please ensure that max_model_len is greater than"
- " the sum of input_len and output_len.")
+ args.input_len + args.output_len
+ ), (
+ "Please ensure that max_model_len is greater than"
+ " the sum of input_len and output_len."
+ )
sampling_params = SamplingParams(
n=args.n,
@@ -55,18 +58,16 @@ def main(args: argparse.Namespace):
detokenize=not args.disable_detokenize,
)
print(sampling_params)
- dummy_prompt_token_ids = np.random.randint(10000,
- size=(args.batch_size,
- args.input_len))
- dummy_prompts: list[PromptType] = [{
- "prompt_token_ids": batch
- } for batch in dummy_prompt_token_ids.tolist()]
+ dummy_prompt_token_ids = np.random.randint(
+ 10000, size=(args.batch_size, args.input_len)
+ )
+ dummy_prompts: list[PromptType] = [
+ {"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist()
+ ]
def llm_generate():
if not args.use_beam_search:
- llm.generate(dummy_prompts,
- sampling_params=sampling_params,
- use_tqdm=False)
+ llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False)
else:
llm.beam_search(
dummy_prompts,
@@ -80,12 +81,13 @@ def llm_generate():
def run_to_completion(profile_dir: Optional[str] = None):
if profile_dir:
with torch.profiler.profile(
- activities=[
- torch.profiler.ProfilerActivity.CPU,
- torch.profiler.ProfilerActivity.CUDA,
- ],
- on_trace_ready=torch.profiler.tensorboard_trace_handler(
- str(profile_dir)),
+ activities=[
+ torch.profiler.ProfilerActivity.CPU,
+ torch.profiler.ProfilerActivity.CUDA,
+ ],
+ on_trace_ready=torch.profiler.tensorboard_trace_handler(
+ str(profile_dir)
+ ),
) as p:
llm_generate()
print(p.key_averages().table(sort_by="self_cuda_time_total"))
@@ -103,8 +105,9 @@ def run_to_completion(profile_dir: Optional[str] = None):
if args.profile:
profile_dir = args.profile_result_dir
if not profile_dir:
- profile_dir = (Path(".") / "vllm_benchmark_result" /
- f"latency_result_{time.time()}")
+ profile_dir = (
+ Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}"
+ )
print(f"Profiling (results will be saved to '{profile_dir}')...")
run_to_completion(profile_dir=profile_dir)
return
@@ -135,7 +138,8 @@ def run_to_completion(profile_dir: Optional[str] = None):
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description="Benchmark the latency of processing a single batch of "
- "requests till completion.")
+ "requests till completion."
+ )
parser.add_argument("--input-len", type=int, default=32)
parser.add_argument("--output-len", type=int, default=128)
parser.add_argument("--batch-size", type=int, default=8)
@@ -152,10 +156,9 @@ def run_to_completion(profile_dir: Optional[str] = None):
default=10,
help="Number of iterations to run for warmup.",
)
- parser.add_argument("--num-iters",
- type=int,
- default=30,
- help="Number of iterations to run.")
+ parser.add_argument(
+ "--num-iters", type=int, default=30, help="Number of iterations to run."
+ )
parser.add_argument(
"--profile",
action="store_true",
@@ -165,8 +168,10 @@ def run_to_completion(profile_dir: Optional[str] = None):
"--profile-result-dir",
type=str,
default=None,
- help=("path to save the pytorch profiler output. Can be visualized "
- "with ui.perfetto.dev or Tensorboard."),
+ help=(
+ "path to save the pytorch profiler output. Can be visualized "
+ "with ui.perfetto.dev or Tensorboard."
+ ),
)
parser.add_argument(
"--output-json",
@@ -177,10 +182,15 @@ def run_to_completion(profile_dir: Optional[str] = None):
parser.add_argument(
"--disable-detokenize",
action="store_true",
- help=("Do not detokenize responses (i.e. do not include "
- "detokenization time in the latency measurement)"),
+ help=(
+ "Do not detokenize responses (i.e. do not include "
+ "detokenization time in the latency measurement)"
+ ),
)
parser = EngineArgs.add_cli_args(parser)
+ # V1 enables prefix caching by default which skews the latency
+ # numbers. We need to disable prefix caching by default.
+ parser.set_defaults(enable_prefix_caching=False)
args = parser.parse_args()
main(args)
diff --git a/benchmarks/benchmark_long_document_qa_throughput.py b/benchmarks/benchmark_long_document_qa_throughput.py
index 21480578edbd..109624c87789 100644
--- a/benchmarks/benchmark_long_document_qa_throughput.py
+++ b/benchmarks/benchmark_long_document_qa_throughput.py
@@ -76,7 +76,7 @@ def repeat_prompts(prompts, repeat_count, mode: str):
- 'random': Shuffle the prompts randomly after repetition.
- 'tile': Repeat the entire prompt list in sequence.
Example: [1, 2, 3] -> [1, 2, 3, 1, 2, 3].
- - 'interleave': Repeat each prompt consecutively before moving to
+ - 'interleave': Repeat each prompt consecutively before moving to
the next. Example: [1, 2, 3] -> [1, 1, 2, 2, 3, 3].
Returns:
@@ -86,20 +86,21 @@ def repeat_prompts(prompts, repeat_count, mode: str):
ValueError: If an invalid mode is provided.
"""
print("Repeat mode: ", mode)
- if mode == 'random':
+ if mode == "random":
repeated_prompts = prompts * repeat_count
random.shuffle(repeated_prompts)
return repeated_prompts
- elif mode == 'tile':
+ elif mode == "tile":
return prompts * repeat_count
- elif mode == 'interleave':
+ elif mode == "interleave":
repeated_prompts = []
for prompt in prompts:
repeated_prompts.extend([prompt] * repeat_count)
return repeated_prompts
else:
- raise ValueError(f"Invalid mode: {mode}, only support "
- "'random', 'tile', 'interleave'")
+ raise ValueError(
+ f"Invalid mode: {mode}, only support 'random', 'tile', 'interleave'"
+ )
def main(args):
@@ -109,16 +110,16 @@ def main(args):
# we append the document id at the beginning to avoid any of the document
# being the prefix of other documents
prompts = [
- str(i) + ' '.join(['hi'] * args.document_length)
+ str(i) + " ".join(["hi"] * args.document_length)
for i in range(args.num_documents)
]
prompts = repeat_prompts(prompts, args.repeat_count, mode=args.repeat_mode)
warmup_prompts = [
- "This is warm up request " + str(i) + \
- ' '.join(['hi'] * args.document_length)
- for i in range(args.num_documents)]
+ "This is warm up request " + str(i) + " ".join(["hi"] * args.document_length)
+ for i in range(args.num_documents)
+ ]
# Create the LLM engine
engine_args = EngineArgs.from_cli_args(args)
@@ -142,42 +143,52 @@ def main(args):
if __name__ == "__main__":
parser = FlexibleArgumentParser(
- description=
- 'Benchmark the performance with or without automatic prefix caching.')
+ description="Benchmark the performance with or "
+ "without automatic prefix caching."
+ )
parser.add_argument(
- '--document-length',
+ "--document-length",
type=int,
# Roughly the number of tokens for a system paper,
# excluding images
default=20000,
- help='Range of input lengths for sampling prompts,'
- 'specified as "min:max" (e.g., "128:256").')
-
- parser.add_argument('--num-documents',
- type=int,
- default=8,
- help='Range of input lengths for sampling prompts,'
- 'specified as "min:max" (e.g., "128:256").')
-
- parser.add_argument('--output-len', type=int, default=10)
-
- parser.add_argument('--repeat-count',
- type=int,
- default=2,
- help='Number of times to repeat each prompt')
-
- parser.add_argument("--repeat-mode",
- type=str,
- default='random',
- help='The mode to repeat prompts. The supported '
- 'modes are "random", "tile", and "interleave". '
- 'See repeat_prompts() in the source code for details.')
-
- parser.add_argument("--shuffle-seed",
- type=int,
- default=0,
- help='Random seed when the repeat mode is "random"')
+ help="Range of input lengths for sampling prompts, "
+ 'specified as "min:max" (e.g., "128:256").',
+ )
+
+ parser.add_argument(
+ "--num-documents",
+ type=int,
+ default=8,
+ help="Range of input lengths for sampling prompts, "
+ 'specified as "min:max" (e.g., "128:256").',
+ )
+
+ parser.add_argument("--output-len", type=int, default=10)
+
+ parser.add_argument(
+ "--repeat-count",
+ type=int,
+ default=2,
+ help="Number of times to repeat each prompt",
+ )
+
+ parser.add_argument(
+ "--repeat-mode",
+ type=str,
+ default="random",
+ help="The mode to repeat prompts. The supported "
+ 'modes are "random", "tile", and "interleave". '
+ "See repeat_prompts() in the source code for details.",
+ )
+
+ parser.add_argument(
+ "--shuffle-seed",
+ type=int,
+ default=0,
+ help='Random seed when the repeat mode is "random"',
+ )
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args()
diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py
index f44da95d3216..ffaa8035797c 100644
--- a/benchmarks/benchmark_prefix_caching.py
+++ b/benchmarks/benchmark_prefix_caching.py
@@ -63,8 +63,7 @@ class Request:
output_len: int
-def sample_tokens(tokenizer: PreTrainedTokenizerBase,
- length: int) -> list[int]:
+def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> list[int]:
vocab = tokenizer.get_vocab()
all_special_ids = set(tokenizer.all_special_ids)
@@ -91,8 +90,10 @@ def sample_requests_from_dataset(
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Only keep the first two turns of each conversation.
- dataset = [(data["conversations"][0]["value"],
- data["conversations"][1]["value"]) for data in dataset]
+ dataset = [
+ (data["conversations"][0]["value"], data["conversations"][1]["value"])
+ for data in dataset
+ ]
# Shuffle the dataset.
random.shuffle(dataset)
@@ -113,8 +114,9 @@ def sample_requests_from_dataset(
completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
- output_len = (len(completion_token_ids)
- if fixed_output_len is None else fixed_output_len)
+ output_len = (
+ len(completion_token_ids) if fixed_output_len is None else fixed_output_len
+ )
if min_len <= prompt_len <= max_len:
filtered_requests.append(Request(prompt, prompt_len, output_len))
@@ -128,27 +130,27 @@ def sample_requests_from_random(
fixed_output_len: Optional[int],
prefix_len: int,
) -> list[Request]:
-
requests = []
prefix_token_ids = sample_tokens(tokenizer, prefix_len)
min_len, max_len = input_length_range
for i in range(num_requests):
unique_part_token_ids = sample_tokens(
- tokenizer,
- random.randint(min_len - prefix_len, max_len - prefix_len))
+ tokenizer, random.randint(min_len - prefix_len, max_len - prefix_len)
+ )
prompt_token_ids = prefix_token_ids + unique_part_token_ids
prompt = tokenizer.decode(prompt_token_ids)
prompt_len = len(prompt_token_ids)
- assert (min_len <= prompt_len <= max_len
- ), f"prompt_len {prompt_len} out of range {min_len}:{max_len}"
+ assert min_len <= prompt_len <= max_len, (
+ f"prompt_len {prompt_len} out of range {min_len}:{max_len}"
+ )
requests.append(Request(prompt, prompt_len, fixed_output_len))
return requests
-def repeat_and_sort_requests(requests: list[Request],
- repeat_count: int,
- sort: bool = False) -> list[str]:
+def repeat_and_sort_requests(
+ requests: list[Request], repeat_count: int, sort: bool = False
+) -> list[str]:
repeated_requests = requests * repeat_count
if sort:
repeated_requests.sort(key=lambda x: x[1])
@@ -159,14 +161,14 @@ def repeat_and_sort_requests(requests: list[Request],
def main(args):
tokenizer = get_tokenizer(args.model, trust_remote_code=True)
- input_length_range = tuple(map(int, args.input_length_range.split(':')))
+ input_length_range = tuple(map(int, args.input_length_range.split(":")))
random.seed(args.seed)
if args.dataset_path is not None:
if args.prefix_len > 0:
- raise ValueError("prefix-len is not supported when "
- "dataset-path is provided.")
- print(f"Start to sample {args.num_prompts} prompts "
- f"from {args.dataset_path}")
+ raise ValueError(
+ "prefix-len is not supported when dataset-path is provided."
+ )
+ print(f"Start to sample {args.num_prompts} prompts from {args.dataset_path}")
filtered_requests = sample_requests_from_dataset(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
@@ -196,14 +198,16 @@ def main(args):
llm = LLM(**dataclasses.asdict(engine_args))
- sampling_params = SamplingParams(temperature=0,
- max_tokens=args.output_len,
- detokenize=not args.disable_detokenize)
+ sampling_params = SamplingParams(
+ temperature=0,
+ max_tokens=args.output_len,
+ detokenize=not args.disable_detokenize,
+ )
print("Testing filtered requests")
- prompts = repeat_and_sort_requests(filtered_requests,
- repeat_count=args.repeat_count,
- sort=args.sort)
+ prompts = repeat_and_sort_requests(
+ filtered_requests, repeat_count=args.repeat_count, sort=args.sort
+ )
print("------start generating------")
test_prefix(
@@ -215,29 +219,35 @@ def main(args):
if __name__ == "__main__":
parser = FlexibleArgumentParser(
- description=
- 'Benchmark the performance with or without automatic prefix caching.')
- parser.add_argument("--dataset-path",
- type=str,
- default=None,
- help="Path to the dataset.")
- parser.add_argument('--output-len', type=int, default=10)
- parser.add_argument('--num-prompts',
- type=int,
- required=True,
- help="Number of the prompts sampled from dataset")
- parser.add_argument('--repeat-count',
- type=int,
- default=1,
- help='Number of times to repeat each prompt')
- parser.add_argument('--sort',
- action='store_true',
- help='Sort prompts by input length')
- parser.add_argument('--input-length-range',
- type=str,
- required=True,
- help='Range of input lengths for sampling prompts,'
- 'specified as "min:max" (e.g., "128:256").')
+ description="Benchmark the performance with or without "
+ "automatic prefix caching."
+ )
+ parser.add_argument(
+ "--dataset-path", type=str, default=None, help="Path to the dataset."
+ )
+ parser.add_argument("--output-len", type=int, default=10)
+ parser.add_argument(
+ "--num-prompts",
+ type=int,
+ required=True,
+ help="Number of the prompts sampled from dataset",
+ )
+ parser.add_argument(
+ "--repeat-count",
+ type=int,
+ default=1,
+ help="Number of times to repeat each prompt",
+ )
+ parser.add_argument(
+ "--sort", action="store_true", help="Sort prompts by input length"
+ )
+ parser.add_argument(
+ "--input-length-range",
+ type=str,
+ required=True,
+ help="Range of input lengths for sampling prompts,"
+ 'specified as "min:max" (e.g., "128:256").',
+ )
parser.add_argument(
"--prefix-len",
type=int,
@@ -248,10 +258,12 @@ def main(args):
"when dataset-path is not provided.",
)
parser.add_argument(
- '--disable-detokenize',
- action='store_true',
- help=("Do not detokenize responses (i.e. do not include "
- "detokenization time in the latency measurement)"),
+ "--disable-detokenize",
+ action="store_true",
+ help=(
+ "Do not detokenize responses (i.e. do not include "
+ "detokenization time in the latency measurement)"
+ ),
)
parser = EngineArgs.add_cli_args(parser)
diff --git a/benchmarks/benchmark_prioritization.py b/benchmarks/benchmark_prioritization.py
index 76fe00ede249..a05dd24dece8 100644
--- a/benchmarks/benchmark_prioritization.py
+++ b/benchmarks/benchmark_prioritization.py
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
"""Benchmark offline prioritization."""
+
import argparse
import dataclasses
import json
@@ -13,7 +14,7 @@
from vllm.utils import FlexibleArgumentParser
-#Select a equi-probable random priority
+# Select a equi-probable random priority
def get_random_flag():
return 0 if random.random() < 0.5 else 1
@@ -33,8 +34,10 @@ def sample_requests(
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Only keep the first two turns of each conversation.
- dataset = [(data["conversations"][0]["value"],
- data["conversations"][1]["value"]) for data in dataset]
+ dataset = [
+ (data["conversations"][0]["value"], data["conversations"][1]["value"])
+ for data in dataset
+ ]
# Shuffle the dataset.
random.shuffle(dataset)
@@ -51,8 +54,9 @@ def sample_requests(
completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
- output_len = len(completion_token_ids
- ) if fixed_output_len is None else fixed_output_len
+ output_len = (
+ len(completion_token_ids) if fixed_output_len is None else fixed_output_len
+ )
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
continue
@@ -74,13 +78,16 @@ def run_vllm(
disable_detokenize: bool = False,
) -> float:
from vllm import LLM, SamplingParams
+
llm = LLM(**dataclasses.asdict(engine_args))
assert all(
llm.llm_engine.model_config.max_model_len >= (request[1] + request[2])
- for request in requests), (
- "Please ensure that max_model_len is greater than the sum of"
- " input_len and output_len for all requests.")
+ for request in requests
+ ), (
+ "Please ensure that max_model_len is greater than the sum of"
+ " input_len and output_len for all requests."
+ )
# Add the requests to the engine.
prompts = []
@@ -97,7 +104,8 @@ def run_vllm(
ignore_eos=True,
max_tokens=output_len,
detokenize=not disable_detokenize,
- ))
+ )
+ )
start = time.perf_counter()
llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True)
@@ -111,26 +119,33 @@ def main(args: argparse.Namespace):
# Sample the requests.
tokenizer = AutoTokenizer.from_pretrained(
- args.tokenizer, trust_remote_code=args.trust_remote_code)
+ args.tokenizer, trust_remote_code=args.trust_remote_code
+ )
if args.dataset is None:
# Synthesize a prompt with the given input length.
prompt = "hi" * (args.input_len - 1)
- requests = [(prompt, args.input_len, args.output_len,
- get_random_flag()) for _ in range(args.num_prompts)]
+ requests = [
+ (prompt, args.input_len, args.output_len, get_random_flag())
+ for _ in range(args.num_prompts)
+ ]
else:
- requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
- args.output_len)
+ requests = sample_requests(
+ args.dataset, args.num_prompts, tokenizer, args.output_len
+ )
if args.backend == "vllm":
- elapsed_time = run_vllm(requests, args.n,
- EngineArgs.from_cli_args(args),
- args.disable_detokenize)
+ elapsed_time = run_vllm(
+ requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize
+ )
else:
raise ValueError(f"Unknown backend: {args.backend}")
- total_num_tokens = sum(prompt_len + output_len
- for _, prompt_len, output_len, priority in requests)
- print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
- f"{total_num_tokens / elapsed_time:.2f} tokens/s")
+ total_num_tokens = sum(
+ prompt_len + output_len for _, prompt_len, output_len, priority in requests
+ )
+ print(
+ f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
+ f"{total_num_tokens / elapsed_time:.2f} tokens/s"
+ )
# Output JSON results if specified
if args.output_json:
@@ -147,41 +162,44 @@ def main(args: argparse.Namespace):
if __name__ == "__main__":
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
- parser.add_argument("--backend",
- type=str,
- choices=["vllm", "hf", "mii"],
- default="vllm")
- parser.add_argument("--dataset",
- type=str,
- default=None,
- help="Path to the dataset.")
- parser.add_argument("--input-len",
- type=int,
- default=None,
- help="Input prompt length for each request")
- parser.add_argument("--output-len",
- type=int,
- default=None,
- help="Output length for each request. Overrides the "
- "output length from the dataset.")
- parser.add_argument("--n",
- type=int,
- default=1,
- help="Number of generated sequences per prompt.")
- parser.add_argument("--num-prompts",
- type=int,
- default=200,
- help="Number of prompts to process.")
parser.add_argument(
- '--output-json',
+ "--backend", type=str, choices=["vllm", "hf", "mii"], default="vllm"
+ )
+ parser.add_argument(
+ "--dataset", type=str, default=None, help="Path to the dataset."
+ )
+ parser.add_argument(
+ "--input-len",
+ type=int,
+ default=None,
+ help="Input prompt length for each request",
+ )
+ parser.add_argument(
+ "--output-len",
+ type=int,
+ default=None,
+ help="Output length for each request. Overrides the "
+ "output length from the dataset.",
+ )
+ parser.add_argument(
+ "--n", type=int, default=1, help="Number of generated sequences per prompt."
+ )
+ parser.add_argument(
+ "--num-prompts", type=int, default=200, help="Number of prompts to process."
+ )
+ parser.add_argument(
+ "--output-json",
type=str,
default=None,
- help='Path to save the throughput results in JSON format.')
+ help="Path to save the throughput results in JSON format.",
+ )
parser.add_argument(
- '--disable-detokenize',
- action='store_true',
- help=("Do not detokenize responses (i.e. do not include "
- "detokenization time in the latency measurement)"),
+ "--disable-detokenize",
+ action="store_true",
+ help=(
+ "Do not detokenize responses (i.e. do not include "
+ "detokenization time in the latency measurement)"
+ ),
)
parser = EngineArgs.add_cli_args(parser)
diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
index 89fb0e1df035..a887e7150dc7 100644
--- a/benchmarks/benchmark_serving.py
+++ b/benchmarks/benchmark_serving.py
@@ -20,6 +20,7 @@
--endpoint /generate_stream
to the end of the command above.
"""
+
import argparse
import asyncio
import gc
@@ -34,12 +35,16 @@
from typing import Any, Optional
import numpy as np
-from backend_request_func import (ASYNC_REQUEST_FUNCS,
- OPENAI_COMPATIBLE_BACKENDS, RequestFuncInput,
- RequestFuncOutput)
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase
+from backend_request_func import (
+ ASYNC_REQUEST_FUNCS,
+ OPENAI_COMPATIBLE_BACKENDS,
+ RequestFuncInput,
+ RequestFuncOutput,
+)
+
try:
from vllm.transformers_utils.tokenizer import get_tokenizer
except ImportError:
@@ -50,12 +55,21 @@
except ImportError:
from argparse import ArgumentParser as FlexibleArgumentParser
-from benchmark_dataset import (AIMODataset, ASRDataset, BurstGPTDataset,
- ConversationDataset, HuggingFaceDataset,
- InstructCoderDataset, MTBenchDataset,
- NextEditPredictionDataset, RandomDataset,
- SampleRequest, ShareGPTDataset, SonnetDataset,
- VisionArenaDataset)
+from benchmark_dataset import (
+ AIMODataset,
+ ASRDataset,
+ BurstGPTDataset,
+ ConversationDataset,
+ HuggingFaceDataset,
+ InstructCoderDataset,
+ MTBenchDataset,
+ NextEditPredictionDataset,
+ RandomDataset,
+ SampleRequest,
+ ShareGPTDataset,
+ SonnetDataset,
+ VisionArenaDataset,
+)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
@@ -118,7 +132,8 @@ async def get_request(
# Calculate scale parameter theta to maintain the desired request_rate.
assert burstiness > 0, (
- f"A positive burstiness factor is expected, but given {burstiness}.")
+ f"A positive burstiness factor is expected, but given {burstiness}."
+ )
theta = 1.0 / (request_rate * burstiness)
for request in input_requests:
@@ -164,8 +179,10 @@ def calculate_metrics(
# bundled together
# Note : this may inflate the output token count slightly
output_len = len(
- tokenizer(outputs[i].generated_text,
- add_special_tokens=False).input_ids)
+ tokenizer(
+ outputs[i].generated_text, add_special_tokens=False
+ ).input_ids
+ )
actual_output_lens.append(output_len)
total_input += input_requests[i].prompt_len
tpot = 0
@@ -188,16 +205,19 @@ def calculate_metrics(
if "ttft" in goodput_config_dict:
valid_metrics.append(ttfts)
- slo_values.append(goodput_config_dict["ttft"] /
- MILLISECONDS_TO_SECONDS_CONVERSION)
+ slo_values.append(
+ goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION
+ )
if "tpot" in goodput_config_dict:
valid_metrics.append(all_tpots)
- slo_values.append(goodput_config_dict["tpot"] /
- MILLISECONDS_TO_SECONDS_CONVERSION)
+ slo_values.append(
+ goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION
+ )
if "e2el" in goodput_config_dict:
valid_metrics.append(e2els)
- slo_values.append(goodput_config_dict["e2el"] /
- MILLISECONDS_TO_SECONDS_CONVERSION)
+ slo_values.append(
+ goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION
+ )
for req_metric in zip(*valid_metrics):
is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
@@ -208,7 +228,8 @@ def calculate_metrics(
warnings.warn(
"All requests failed. This is likely due to a misconfiguration "
"on the benchmark arguments.",
- stacklevel=2)
+ stacklevel=2,
+ )
metrics = BenchmarkMetrics(
completed=completed,
total_input=total_input,
@@ -217,27 +238,31 @@ def calculate_metrics(
request_goodput=good_completed / dur_s,
output_throughput=sum(actual_output_lens) / dur_s,
total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
- mean_ttft_ms=np.mean(ttfts or 0) *
- 1000, # ttfts is empty if streaming is not supported by backend
+ mean_ttft_ms=np.mean(ttfts or 0)
+ * 1000, # ttfts is empty if streaming is not supported by backend
std_ttft_ms=np.std(ttfts or 0) * 1000,
median_ttft_ms=np.median(ttfts or 0) * 1000,
- percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000)
- for p in selected_percentiles],
+ percentiles_ttft_ms=[
+ (p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles
+ ],
mean_tpot_ms=np.mean(tpots or 0) * 1000,
std_tpot_ms=np.std(tpots or 0) * 1000,
median_tpot_ms=np.median(tpots or 0) * 1000,
- percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000)
- for p in selected_percentiles],
+ percentiles_tpot_ms=[
+ (p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles
+ ],
mean_itl_ms=np.mean(itls or 0) * 1000,
std_itl_ms=np.std(itls or 0) * 1000,
median_itl_ms=np.median(itls or 0) * 1000,
- percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000)
- for p in selected_percentiles],
+ percentiles_itl_ms=[
+ (p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles
+ ],
mean_e2el_ms=np.mean(e2els or 0) * 1000,
std_e2el_ms=np.std(e2els or 0) * 1000,
median_e2el_ms=np.median(e2els or 0) * 1000,
- percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
- for p in selected_percentiles],
+ percentiles_e2el_ms=[
+ (p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles
+ ],
)
return metrics, actual_output_lens
@@ -270,10 +295,12 @@ async def benchmark(
raise ValueError(f"Unknown backend: {backend}")
print("Starting initial single prompt test run...")
- test_prompt, test_prompt_len, test_output_len, test_mm_content = \
- input_requests[0].prompt, input_requests[0].prompt_len, \
- input_requests[0].expected_output_len, \
- input_requests[0].multi_modal_data
+ test_prompt, test_prompt_len, test_output_len, test_mm_content = (
+ input_requests[0].prompt,
+ input_requests[0].prompt_len,
+ input_requests[0].expected_output_len,
+ input_requests[0].multi_modal_data,
+ )
assert test_mm_content is None or isinstance(test_mm_content, dict)
test_input = RequestFuncInput(
@@ -293,36 +320,36 @@ async def benchmark(
if not test_output.success:
raise ValueError(
"Initial test run failed - Please make sure benchmark arguments "
- f"are correctly specified. Error: {test_output.error}")
+ f"are correctly specified. Error: {test_output.error}"
+ )
else:
print("Initial test run completed. Starting main benchmark run...")
if lora_modules:
# For each input request, choose a LoRA module at random.
lora_modules = iter(
- [random.choice(lora_modules) \
- for _ in range(len(input_requests))])
+ [random.choice(lora_modules) for _ in range(len(input_requests))]
+ )
if profile:
print("Starting profiler...")
- profile_input = RequestFuncInput(model=model_id,
- model_name=model_name,
- prompt=test_prompt,
- api_url=base_url + "/start_profile",
- prompt_len=test_prompt_len,
- output_len=test_output_len,
- logprobs=logprobs,
- multi_modal_content=test_mm_content,
- ignore_eos=ignore_eos,
- extra_body=extra_body)
+ profile_input = RequestFuncInput(
+ model=model_id,
+ model_name=model_name,
+ prompt=test_prompt,
+ api_url=base_url + "/start_profile",
+ prompt_len=test_prompt_len,
+ output_len=test_output_len,
+ logprobs=logprobs,
+ multi_modal_content=test_mm_content,
+ ignore_eos=ignore_eos,
+ extra_body=extra_body,
+ )
profile_output = await request_func(request_func_input=profile_input)
if profile_output.success:
print("Profiler started")
- if burstiness == 1.0:
- distribution = "Poisson process"
- else:
- distribution = "Gamma distribution"
+ distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution"
print(f"Traffic request rate: {request_rate}")
print(f"Burstiness factor: {burstiness} ({distribution})")
@@ -334,42 +361,45 @@ async def benchmark(
# and it will simplify the code in limited_request_func.
# semaphore = (asyncio.Semaphore(max_concurrency)
# if max_concurrency else contextlib.nullcontext())
- semaphore = (asyncio.Semaphore(max_concurrency)
- if max_concurrency else None)
+ semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
async def limited_request_func(request_func_input, pbar):
if semaphore is None:
- return await request_func(request_func_input=request_func_input,
- pbar=pbar)
+ return await request_func(request_func_input=request_func_input, pbar=pbar)
async with semaphore:
- return await request_func(request_func_input=request_func_input,
- pbar=pbar)
+ return await request_func(request_func_input=request_func_input, pbar=pbar)
benchmark_start_time = time.perf_counter()
tasks: list[asyncio.Task] = []
async for request in get_request(input_requests, request_rate, burstiness):
- prompt, prompt_len, output_len, mm_content = request.prompt, \
- request.prompt_len, request.expected_output_len, \
- request.multi_modal_data
+ prompt, prompt_len, output_len, mm_content = (
+ request.prompt,
+ request.prompt_len,
+ request.expected_output_len,
+ request.multi_modal_data,
+ )
req_model_id, req_model_name = model_id, model_name
if lora_modules:
req_lora_module = next(lora_modules)
req_model_id, req_model_name = req_lora_module, req_lora_module
- request_func_input = RequestFuncInput(model=req_model_id,
- model_name=req_model_name,
- prompt=prompt,
- api_url=api_url,
- prompt_len=prompt_len,
- output_len=output_len,
- logprobs=logprobs,
- multi_modal_content=mm_content,
- ignore_eos=ignore_eos,
- extra_body=extra_body)
+ request_func_input = RequestFuncInput(
+ model=req_model_id,
+ model_name=req_model_name,
+ prompt=prompt,
+ api_url=api_url,
+ prompt_len=prompt_len,
+ output_len=output_len,
+ logprobs=logprobs,
+ multi_modal_content=mm_content,
+ ignore_eos=ignore_eos,
+ extra_body=extra_body,
+ )
tasks.append(
asyncio.create_task(
- limited_request_func(request_func_input=request_func_input,
- pbar=pbar)))
+ limited_request_func(request_func_input=request_func_input, pbar=pbar)
+ )
+ )
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
if profile:
@@ -401,22 +431,32 @@ async def limited_request_func(request_func_input, pbar):
goodput_config_dict=goodput_config_dict,
)
- print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
+ print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
- print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
- benchmark_duration))
+ print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
- print("{:<40} {:<10}".format("Total generated tokens:",
- metrics.total_output))
- print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
- metrics.request_throughput))
+ print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
+ print(
+ "{:<40} {:<10.2f}".format(
+ "Request throughput (req/s):", metrics.request_throughput
+ )
+ )
if goodput_config_dict:
- print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
- metrics.request_goodput))
- print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
- metrics.output_throughput))
- print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
- metrics.total_token_throughput))
+ print(
+ "{:<40} {:<10.2f}".format(
+ "Request goodput (req/s):", metrics.request_goodput
+ )
+ )
+ print(
+ "{:<40} {:<10.2f}".format(
+ "Output token throughput (tok/s):", metrics.output_throughput
+ )
+ )
+ print(
+ "{:<40} {:<10.2f}".format(
+ "Total Token throughput (tok/s):", metrics.total_token_throughput
+ )
+ )
result = {
"duration": benchmark_duration,
@@ -424,8 +464,7 @@ async def limited_request_func(request_func_input, pbar):
"total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output,
"request_throughput": metrics.request_throughput,
- "request_goodput:":
- metrics.request_goodput if goodput_config_dict else None,
+ "request_goodput:": metrics.request_goodput if goodput_config_dict else None,
"output_throughput": metrics.output_throughput,
"total_token_throughput": metrics.total_token_throughput,
"input_lens": [output.prompt_len for output in outputs],
@@ -448,29 +487,35 @@ def process_one_metric(
# metric.
if metric_attribute_name not in selected_percentile_metrics:
return
- print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
- print("{:<40} {:<10.2f}".format(
- f"Mean {metric_name} (ms):",
- getattr(metrics, f"mean_{metric_attribute_name}_ms")))
- print("{:<40} {:<10.2f}".format(
- f"Median {metric_name} (ms):",
- getattr(metrics, f"median_{metric_attribute_name}_ms")))
+ print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
+ print(
+ "{:<40} {:<10.2f}".format(
+ f"Mean {metric_name} (ms):",
+ getattr(metrics, f"mean_{metric_attribute_name}_ms"),
+ )
+ )
+ print(
+ "{:<40} {:<10.2f}".format(
+ f"Median {metric_name} (ms):",
+ getattr(metrics, f"median_{metric_attribute_name}_ms"),
+ )
+ )
result[f"mean_{metric_attribute_name}_ms"] = getattr(
- metrics, f"mean_{metric_attribute_name}_ms")
+ metrics, f"mean_{metric_attribute_name}_ms"
+ )
result[f"median_{metric_attribute_name}_ms"] = getattr(
- metrics, f"median_{metric_attribute_name}_ms")
+ metrics, f"median_{metric_attribute_name}_ms"
+ )
result[f"std_{metric_attribute_name}_ms"] = getattr(
- metrics, f"std_{metric_attribute_name}_ms")
- for p, value in getattr(metrics,
- f"percentiles_{metric_attribute_name}_ms"):
+ metrics, f"std_{metric_attribute_name}_ms"
+ )
+ for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"):
p_word = str(int(p)) if int(p) == p else str(p)
- print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
- value))
+ print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
process_one_metric("ttft", "TTFT", "Time to First Token")
- process_one_metric("tpot", "TPOT",
- "Time per Output Token (excl. 1st token)")
+ process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
process_one_metric("itl", "ITL", "Inter-token Latency")
process_one_metric("e2el", "E2EL", "End-to-end Latency")
@@ -490,12 +535,14 @@ def check_goodput_args(args):
raise ValueError(
f"Invalid metric name found, {slo_name}: {slo_val}. "
"The service level objective name should be one of "
- f"{str(VALID_NAMES)}. ")
+ f"{str(VALID_NAMES)}. "
+ )
if slo_val < 0:
raise ValueError(
f"Invalid value found, {slo_name}: {slo_val}. "
"The service level objective value should be "
- "non-negative.")
+ "non-negative."
+ )
return goodput_config_dict
@@ -508,31 +555,42 @@ def parse_goodput(slo_pairs):
except ValueError as err:
raise argparse.ArgumentTypeError(
"Invalid format found for service level objectives. "
- "Specify service level objectives for goodput as \"KEY:VALUE\" "
+ 'Specify service level objectives for goodput as "KEY:VALUE" '
"pairs, where the key is a metric name, and the value is a "
- "number in milliseconds.") from err
+ "number in milliseconds."
+ ) from err
return goodput_config_dict
-def save_to_pytorch_benchmark_format(args: argparse.Namespace,
- results: dict[str, Any],
- file_name: str) -> None:
+def save_to_pytorch_benchmark_format(
+ args: argparse.Namespace, results: dict[str, Any], file_name: str
+) -> None:
metrics = [
- "median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms",
- "mean_tpot_ms", "median_tpot_ms", "std_tpot_ms", "p99_tpot_ms",
- "median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms"
+ "median_ttft_ms",
+ "mean_ttft_ms",
+ "std_ttft_ms",
+ "p99_ttft_ms",
+ "mean_tpot_ms",
+ "median_tpot_ms",
+ "std_tpot_ms",
+ "p99_tpot_ms",
+ "median_itl_ms",
+ "mean_itl_ms",
+ "std_itl_ms",
+ "p99_itl_ms",
]
# These raw data might be useful, but they are rather big. They can be added
# later if needed
ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"]
pt_records = convert_to_pytorch_benchmark_format(
args=args,
- metrics={k: [results[k]]
- for k in metrics},
+ metrics={k: [results[k]] for k in metrics},
extra_info={
k: results[k]
- for k in results if k not in metrics and k not in ignored_metrics
- })
+ for k in results
+ if k not in metrics and k not in ignored_metrics
+ },
+ )
if pt_records:
# Don't use json suffix here as we don't want CI to pick it up
pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json"
@@ -557,34 +615,42 @@ def main(args: argparse.Namespace):
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
base_url = f"http://{args.host}:{args.port}"
- tokenizer = get_tokenizer(tokenizer_id,
- tokenizer_mode=tokenizer_mode,
- trust_remote_code=args.trust_remote_code)
+ tokenizer = get_tokenizer(
+ tokenizer_id,
+ tokenizer_mode=tokenizer_mode,
+ trust_remote_code=args.trust_remote_code,
+ )
if args.dataset_name is None:
raise ValueError(
"Please specify '--dataset-name' and the corresponding "
- "'--dataset-path' if required.")
+ "'--dataset-path' if required."
+ )
if args.dataset_name == "sonnet":
dataset = SonnetDataset(dataset_path=args.dataset_path)
# For the "sonnet" dataset, formatting depends on the backend.
if args.backend == "openai-chat":
- input_requests = dataset.sample(num_requests=args.num_prompts,
- input_len=args.sonnet_input_len,
- output_len=args.sonnet_output_len,
- prefix_len=args.sonnet_prefix_len,
- tokenizer=tokenizer,
- return_prompt_formatted=False)
+ input_requests = dataset.sample(
+ num_requests=args.num_prompts,
+ input_len=args.sonnet_input_len,
+ output_len=args.sonnet_output_len,
+ prefix_len=args.sonnet_prefix_len,
+ tokenizer=tokenizer,
+ return_prompt_formatted=False,
+ )
else:
assert tokenizer.chat_template or tokenizer.default_chat_template, (
- "Tokenizer/model must have chat template for sonnet dataset.")
- input_requests = dataset.sample(num_requests=args.num_prompts,
- input_len=args.sonnet_input_len,
- output_len=args.sonnet_output_len,
- prefix_len=args.sonnet_prefix_len,
- tokenizer=tokenizer,
- return_prompt_formatted=True)
+ "Tokenizer/model must have chat template for sonnet dataset."
+ )
+ input_requests = dataset.sample(
+ num_requests=args.num_prompts,
+ input_len=args.sonnet_input_len,
+ output_len=args.sonnet_output_len,
+ prefix_len=args.sonnet_prefix_len,
+ tokenizer=tokenizer,
+ return_prompt_formatted=True,
+ )
elif args.dataset_name == "hf":
# all following datasets are implemented from the
@@ -611,23 +677,30 @@ def main(args: argparse.Namespace):
dataset_class = ASRDataset
args.hf_split = "train"
else:
- supported_datasets = set([
- dataset_name for cls in HuggingFaceDataset.__subclasses__()
- for dataset_name in cls.SUPPORTED_DATASET_PATHS
- ])
+ supported_datasets = set(
+ [
+ dataset_name
+ for cls in HuggingFaceDataset.__subclasses__()
+ for dataset_name in cls.SUPPORTED_DATASET_PATHS
+ ]
+ )
raise ValueError(
f"Unsupported dataset path: {args.dataset_path}. "
"Huggingface dataset only supports dataset_path"
f" from one of following: {supported_datasets}. "
"Please consider contributing if you would "
- "like to add support for additional dataset formats.")
+ "like to add support for additional dataset formats."
+ )
- if (dataset_class.IS_MULTIMODAL and backend not in \
- ["openai-chat", "openai-audio"]):
+ if dataset_class.IS_MULTIMODAL and backend not in [
+ "openai-chat",
+ "openai-audio",
+ ]:
# multi-modal benchmark is only available on OpenAI Chat backend.
raise ValueError(
- "Multi-modal content is only supported on 'openai-chat' and " \
- "'openai-audio' backend.")
+ "Multi-modal content is only supported on 'openai-chat' and "
+ "'openai-audio' backend."
+ )
input_requests = dataset_class(
dataset_path=args.dataset_path,
dataset_subset=args.hf_subset,
@@ -642,26 +715,24 @@ def main(args: argparse.Namespace):
else:
# For datasets that follow a similar structure, use a mapping.
dataset_mapping = {
- "sharegpt":
- lambda: ShareGPTDataset(random_seed=args.seed,
- dataset_path=args.dataset_path).sample(
- tokenizer=tokenizer,
- num_requests=args.num_prompts,
- output_len=args.sharegpt_output_len,
- ),
- "burstgpt":
- lambda: BurstGPTDataset(random_seed=args.seed,
- dataset_path=args.dataset_path).
- sample(tokenizer=tokenizer, num_requests=args.num_prompts),
- "random":
- lambda: RandomDataset(dataset_path=args.dataset_path).sample(
+ "sharegpt": lambda: ShareGPTDataset(
+ random_seed=args.seed, dataset_path=args.dataset_path
+ ).sample(
+ tokenizer=tokenizer,
+ num_requests=args.num_prompts,
+ output_len=args.sharegpt_output_len,
+ ),
+ "burstgpt": lambda: BurstGPTDataset(
+ random_seed=args.seed, dataset_path=args.dataset_path
+ ).sample(tokenizer=tokenizer, num_requests=args.num_prompts),
+ "random": lambda: RandomDataset(dataset_path=args.dataset_path).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
prefix_len=args.random_prefix_len,
input_len=args.random_input_len,
output_len=args.random_output_len,
range_ratio=args.random_range_ratio,
- )
+ ),
}
try:
@@ -677,15 +748,16 @@ def main(args: argparse.Namespace):
"top_p": args.top_p,
"top_k": args.top_k,
"min_p": args.min_p,
- "temperature": args.temperature
- }.items() if v is not None
+ "temperature": args.temperature,
+ }.items()
+ if v is not None
}
# Sampling parameters are only supported by openai-compatible backend.
if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS:
raise ValueError(
- "Sampling parameters are only supported by openai-compatible "
- "backends.")
+ "Sampling parameters are only supported by openai-compatible backends."
+ )
if "temperature" not in sampling_params:
sampling_params["temperature"] = 0.0 # Default to greedy decoding.
@@ -709,15 +781,14 @@ def main(args: argparse.Namespace):
disable_tqdm=args.disable_tqdm,
profile=args.profile,
selected_percentile_metrics=args.percentile_metrics.split(","),
- selected_percentiles=[
- float(p) for p in args.metric_percentiles.split(",")
- ],
+ selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")],
ignore_eos=args.ignore_eos,
goodput_config_dict=goodput_config_dict,
max_concurrency=args.max_concurrency,
lora_modules=args.lora_modules,
extra_body=sampling_params,
- ))
+ )
+ )
# Save config and results to json
if args.save_result or args.append_result:
@@ -742,8 +813,9 @@ def main(args: argparse.Namespace):
"Invalid metadata format. Please use KEY=VALUE format."
)
# Traffic
- result_json["request_rate"] = (args.request_rate if args.request_rate
- < float("inf") else "inf")
+ result_json["request_rate"] = (
+ args.request_rate if args.request_rate < float("inf") else "inf"
+ )
result_json["burstiness"] = args.burstiness
result_json["max_concurrency"] = args.max_concurrency
@@ -753,24 +825,31 @@ def main(args: argparse.Namespace):
if not args.save_detailed:
# Remove fields with too many data points
for field in [
- "input_lens", "output_lens", "ttfts", "itls",
- "generated_texts", "errors"
+ "input_lens",
+ "output_lens",
+ "ttfts",
+ "itls",
+ "generated_texts",
+ "errors",
]:
if field in result_json:
del result_json[field]
# Save to file
base_model_id = model_id.split("/")[-1]
- max_concurrency_str = (f"-concurrency{args.max_concurrency}"
- if args.max_concurrency is not None else "")
- file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" #noqa
+ max_concurrency_str = (
+ f"-concurrency{args.max_concurrency}"
+ if args.max_concurrency is not None
+ else ""
+ )
+ file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
if args.result_filename:
file_name = args.result_filename
if args.result_dir:
file_name = os.path.join(args.result_dir, file_name)
- with open(file_name,
- mode="a+" if args.append_result else "w",
- encoding='utf-8') as outfile:
+ with open(
+ file_name, mode="a+" if args.append_result else "w", encoding="utf-8"
+ ) as outfile:
# Append a newline.
if args.append_result and outfile.tell() != 0:
outfile.write("\n")
@@ -780,7 +859,8 @@ def main(args: argparse.Namespace):
if __name__ == "__main__":
parser = FlexibleArgumentParser(
- description="Benchmark the online serving throughput.")
+ description="Benchmark the online serving throughput."
+ )
parser.add_argument(
"--backend",
type=str,
@@ -809,11 +889,13 @@ def main(args: argparse.Namespace):
choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"],
help="Name of the dataset to benchmark on.",
)
- parser.add_argument("--dataset-path",
- type=str,
- default=None,
- help="Path to the sharegpt/sonnet dataset. "
- "Or the huggingface dataset ID if using HF dataset.")
+ parser.add_argument(
+ "--dataset-path",
+ type=str,
+ default=None,
+ help="Path to the sharegpt/sonnet dataset. "
+ "Or the huggingface dataset ID if using HF dataset.",
+ )
parser.add_argument(
"--max-concurrency",
type=int,
@@ -825,7 +907,8 @@ def main(args: argparse.Namespace):
"initiated, this argument will control how many are actually allowed "
"to execute at a time. This means that when used in combination, the "
"actual request rate may be lower than specified with --request-rate, "
- "if the server is not processing requests fast enough to keep up.")
+ "if the server is not processing requests fast enough to keep up.",
+ )
parser.add_argument(
"--model",
@@ -836,8 +919,7 @@ def main(args: argparse.Namespace):
parser.add_argument(
"--tokenizer",
type=str,
- help=
- "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
+ help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
)
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument(
@@ -850,11 +932,13 @@ def main(args: argparse.Namespace):
"--logprobs",
type=int,
default=None,
- help=("Number of logprobs-per-token to compute & return as part of "
- "the request. If unspecified, then either (1) if beam search "
- "is disabled, no logprobs are computed & a single dummy "
- "logprob is returned for each token; or (2) if beam search "
- "is enabled 1 logprob per token is computed"),
+ help=(
+ "Number of logprobs-per-token to compute & return as part of "
+ "the request. If unspecified, then either (1) if beam search "
+ "is disabled, no logprobs are computed & a single dummy "
+ "logprob is returned for each token; or (2) if beam search "
+ "is enabled 1 logprob per token is computed"
+ ),
)
parser.add_argument(
"--request-rate",
@@ -938,35 +1022,38 @@ def main(args: argparse.Namespace):
"--ignore-eos",
action="store_true",
help="Set ignore_eos flag when sending the benchmark request."
- "Warning: ignore_eos is not supported in deepspeed_mii and tgi.")
+ "Warning: ignore_eos is not supported in deepspeed_mii and tgi.",
+ )
parser.add_argument(
"--percentile-metrics",
type=str,
default="ttft,tpot,itl",
help="Comma-separated list of selected metrics to report percentils. "
"This argument specifies the metrics to report percentiles. "
- "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". "
- "Default value is \"ttft,tpot,itl\".")
+ 'Allowed metric names are "ttft", "tpot", "itl", "e2el". '
+ 'Default value is "ttft,tpot,itl".',
+ )
parser.add_argument(
"--metric-percentiles",
type=str,
default="99",
help="Comma-separated list of percentiles for selected metrics. "
- "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
- "Default value is \"99\". "
- "Use \"--percentile-metrics\" to select metrics.",
+ 'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". '
+ 'Default value is "99". '
+ 'Use "--percentile-metrics" to select metrics.',
)
parser.add_argument(
"--goodput",
nargs="+",
required=False,
- help="Specify service level objectives for goodput as \"KEY:VALUE\" "
+ help='Specify service level objectives for goodput as "KEY:VALUE" '
"pairs, where the key is a metric name, and the value is in "
- "milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, "
+ 'milliseconds. Multiple "KEY:VALUE" pairs can be provided, '
"separated by spaces. Allowed request level metric names are "
- "\"ttft\", \"tpot\", \"e2el\". For more context on the definition of "
+ '"ttft", "tpot", "e2el". For more context on the definition of '
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
- "and the blog: https://hao-ai-lab.github.io/blogs/distserve")
+ "and the blog: https://hao-ai-lab.github.io/blogs/distserve",
+ )
# group for dataset specific arguments
sonnet_group = parser.add_argument_group("sonnet dataset options")
@@ -974,22 +1061,19 @@ def main(args: argparse.Namespace):
"--sonnet-input-len",
type=int,
default=550,
- help=
- "Number of input tokens per request, used only for sonnet dataset.",
+ help="Number of input tokens per request, used only for sonnet dataset.",
)
sonnet_group.add_argument(
"--sonnet-output-len",
type=int,
default=150,
- help=
- "Number of output tokens per request, used only for sonnet dataset.",
+ help="Number of output tokens per request, used only for sonnet dataset.",
)
sonnet_group.add_argument(
"--sonnet-prefix-len",
type=int,
default=200,
- help=
- "Number of prefix tokens per request, used only for sonnet dataset.",
+ help="Number of prefix tokens per request, used only for sonnet dataset.",
)
sharegpt_group = parser.add_argument_group("sharegpt dataset options")
@@ -998,22 +1082,21 @@ def main(args: argparse.Namespace):
type=int,
default=None,
help="Output length for each request. Overrides the output length "
- "from the ShareGPT dataset.")
+ "from the ShareGPT dataset.",
+ )
random_group = parser.add_argument_group("random dataset options")
random_group.add_argument(
"--random-input-len",
type=int,
default=1024,
- help=
- "Number of input tokens per request, used only for random sampling.",
+ help="Number of input tokens per request, used only for random sampling.",
)
random_group.add_argument(
"--random-output-len",
type=int,
default=128,
- help=
- "Number of output tokens per request, used only for random sampling.",
+ help="Number of output tokens per request, used only for random sampling.",
)
random_group.add_argument(
"--random-range-ratio",
@@ -1028,23 +1111,23 @@ def main(args: argparse.Namespace):
"--random-prefix-len",
type=int,
default=0,
- help=("Number of fixed prefix tokens before the random context "
- "in a request. "
- "The total input length is the sum of `random-prefix-len` and "
- "a random "
- "context length sampled from [input_len * (1 - range_ratio), "
- "input_len * (1 + range_ratio)]."),
+ help=(
+ "Number of fixed prefix tokens before the random context "
+ "in a request. "
+ "The total input length is the sum of `random-prefix-len` and "
+ "a random "
+ "context length sampled from [input_len * (1 - range_ratio), "
+ "input_len * (1 + range_ratio)]."
+ ),
)
hf_group = parser.add_argument_group("hf dataset options")
- hf_group.add_argument("--hf-subset",
- type=str,
- default=None,
- help="Subset of the HF dataset.")
- hf_group.add_argument("--hf-split",
- type=str,
- default=None,
- help="Split of the HF dataset.")
+ hf_group.add_argument(
+ "--hf-subset", type=str, default=None, help="Subset of the HF dataset."
+ )
+ hf_group.add_argument(
+ "--hf-split", type=str, default=None, help="Split of the HF dataset."
+ )
hf_group.add_argument(
"--hf-output-len",
type=int,
@@ -1058,52 +1141,58 @@ def main(args: argparse.Namespace):
"--top-p",
type=float,
default=None,
- help="Top-p sampling parameter. Only has effect on openai-compatible "
- "backends.")
+ help="Top-p sampling parameter. Only has effect on openai-compatible backends.",
+ )
sampling_group.add_argument(
"--top-k",
type=int,
default=None,
- help="Top-k sampling parameter. Only has effect on openai-compatible "
- "backends.")
+ help="Top-k sampling parameter. Only has effect on openai-compatible backends.",
+ )
sampling_group.add_argument(
"--min-p",
type=float,
default=None,
- help="Min-p sampling parameter. Only has effect on openai-compatible "
- "backends.")
+ help="Min-p sampling parameter. Only has effect on openai-compatible backends.",
+ )
sampling_group.add_argument(
"--temperature",
type=float,
default=None,
help="Temperature sampling parameter. Only has effect on "
"openai-compatible backends. If not specified, default to greedy "
- "decoding (i.e. temperature==0.0).")
+ "decoding (i.e. temperature==0.0).",
+ )
parser.add_argument(
- '--tokenizer-mode',
+ "--tokenizer-mode",
type=str,
default="auto",
- choices=['auto', 'slow', 'mistral', 'custom'],
+ choices=["auto", "slow", "mistral", "custom"],
help='The tokenizer mode.\n\n* "auto" will use the '
'fast tokenizer if available.\n* "slow" will '
- 'always use the slow tokenizer. \n* '
+ "always use the slow tokenizer. \n* "
'"mistral" will always use the `mistral_common` tokenizer. \n*'
- '"custom" will use --tokenizer to select the preregistered tokenizer.')
-
- parser.add_argument("--served-model-name",
- type=str,
- default=None,
- help="The model name used in the API. "
- "If not specified, the model name will be the "
- "same as the ``--model`` argument. ")
-
- parser.add_argument("--lora-modules",
- nargs='+',
- default=None,
- help="A subset of LoRA module names passed in when "
- "launching the server. For each request, the "
- "script chooses a LoRA module at random.")
+ '"custom" will use --tokenizer to select the preregistered tokenizer.',
+ )
+
+ parser.add_argument(
+ "--served-model-name",
+ type=str,
+ default=None,
+ help="The model name used in the API. "
+ "If not specified, the model name will be the "
+ "same as the ``--model`` argument. ",
+ )
+
+ parser.add_argument(
+ "--lora-modules",
+ nargs="+",
+ default=None,
+ help="A subset of LoRA module names passed in when "
+ "launching the server. For each request, the "
+ "script chooses a LoRA module at random.",
+ )
args = parser.parse_args()
diff --git a/benchmarks/benchmark_serving_structured_output.py b/benchmarks/benchmark_serving_structured_output.py
index 9084255d2440..6a50f47d3951 100644
--- a/benchmarks/benchmark_serving_structured_output.py
+++ b/benchmarks/benchmark_serving_structured_output.py
@@ -19,6 +19,7 @@
--endpoint /generate_stream
to the end of the command above.
"""
+
import argparse
import asyncio
import copy
@@ -36,11 +37,15 @@
import datasets
import numpy as np
import pandas as pd
-from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
- RequestFuncOutput)
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase
+from backend_request_func import (
+ ASYNC_REQUEST_FUNCS,
+ RequestFuncInput,
+ RequestFuncOutput,
+)
+
try:
from vllm.transformers_utils.tokenizer import get_tokenizer
except ImportError:
@@ -52,7 +57,8 @@
from argparse import ArgumentParser as FlexibleArgumentParser
from vllm.v1.structured_output.backend_xgrammar import (
- has_xgrammar_unsupported_json_features)
+ has_xgrammar_unsupported_json_features,
+)
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
@@ -98,6 +104,7 @@ class SampleRequest:
prompt_len: The length of the prompt in tokens.
expected_output_len: The expected length of the output in tokens.
"""
+
prompt: str
prompt_len: int
expected_output_len: int
@@ -106,32 +113,28 @@ class SampleRequest:
completion: str = None
-def sample_requests(tokenizer: PreTrainedTokenizerBase,
- args: argparse.Namespace) -> list[SampleRequest]:
- if args.dataset == 'json' or args.dataset == 'json-unique':
+def sample_requests(
+ tokenizer: PreTrainedTokenizerBase, args: argparse.Namespace
+) -> list[SampleRequest]:
+ if args.dataset == "json" or args.dataset == "json-unique":
if args.json_schema_path is None:
dir_path = os.path.dirname(os.path.realpath(__file__))
- args.json_schema_path = os.path.join(dir_path,
- "structured_schemas",
- "structured_schema_1.json")
+ args.json_schema_path = os.path.join(
+ dir_path, "structured_schemas", "structured_schema_1.json"
+ )
json_schemas = []
with open(args.json_schema_path) as f:
schema = json.load(f)
- if args.dataset == 'json-unique':
- json_schemas = [
- copy.deepcopy(schema) for _ in range(args.num_prompts)
- ]
+ if args.dataset == "json-unique":
+ json_schemas = [copy.deepcopy(schema) for _ in range(args.num_prompts)]
for i in range(len(json_schemas)):
if "properties" not in json_schemas[i]:
json_schemas[i]["properties"] = {}
- json_schemas[i]["properties"][
- f"__optional_field_{uuid.uuid4()}"] = {
- "type":
- "string",
- "description":
- "An unique optional field to avoid cached schemas"
- }
+ json_schemas[i]["properties"][f"__optional_field_{uuid.uuid4()}"] = {
+ "type": "string",
+ "description": "An unique optional field to avoid cached schemas",
+ }
else:
json_schemas = [schema] * args.num_prompts
@@ -142,11 +145,13 @@ def get_schema(index: int):
return json_schemas[index % len(json_schemas)]
requests = [
- SampleRequest(prompt=gen_prompt(i),
- prompt_len=len(tokenizer(gen_prompt(i)).input_ids),
- expected_output_len=args.output_len,
- schema=get_schema(i),
- structure_type=args.structure_type)
+ SampleRequest(
+ prompt=gen_prompt(i),
+ prompt_len=len(tokenizer(gen_prompt(i)).input_ids),
+ expected_output_len=args.output_len,
+ schema=get_schema(i),
+ structure_type=args.structure_type,
+ )
for i in range(args.num_prompts)
]
@@ -170,11 +175,13 @@ def get_schema(index: int):
input_len = len(tokenizer(prompt).input_ids)
print(f"Input length of the prompt: {input_len} tokens")
requests = [
- SampleRequest(prompt=prompt,
- prompt_len=input_len,
- expected_output_len=args.output_len,
- schema=schema,
- structure_type=args.structure_type)
+ SampleRequest(
+ prompt=prompt,
+ prompt_len=input_len,
+ expected_output_len=args.output_len,
+ schema=schema,
+ structure_type=args.structure_type,
+ )
for _ in range(args.num_prompts)
]
@@ -188,11 +195,13 @@ def get_schema(index: int):
input_len = len(tokenizer(prompt).input_ids)
print(f"Input length of the prompt: {input_len} tokens")
requests = [
- SampleRequest(prompt=prompt,
- prompt_len=input_len,
- expected_output_len=args.output_len,
- schema=regex,
- structure_type=args.structure_type)
+ SampleRequest(
+ prompt=prompt,
+ prompt_len=input_len,
+ expected_output_len=args.output_len,
+ schema=regex,
+ structure_type=args.structure_type,
+ )
for _ in range(args.num_prompts)
]
@@ -203,48 +212,55 @@ def get_schema(index: int):
input_len = len(tokenizer(prompt).input_ids)
print(f"Input length of the prompt: {input_len} tokens")
requests = [
- SampleRequest(prompt=prompt,
- prompt_len=input_len,
- expected_output_len=args.output_len,
- schema=choice,
- structure_type=args.structure_type)
+ SampleRequest(
+ prompt=prompt,
+ prompt_len=input_len,
+ expected_output_len=args.output_len,
+ schema=choice,
+ structure_type=args.structure_type,
+ )
for _ in range(args.num_prompts)
]
elif args.dataset == "xgrammar_bench":
requests: list[SampleRequest] = []
- dataset = datasets.load_dataset("NousResearch/json-mode-eval",
- split="train")
+ dataset = datasets.load_dataset("NousResearch/json-mode-eval", split="train")
full_dataset_len = len(dataset)
def _filter_func(item):
import json
+
schema = json.loads(item["schema"])
return not has_xgrammar_unsupported_json_features(schema)
dataset = dataset.filter(_filter_func)
num_filtered_out = full_dataset_len - len(dataset)
- print(f"dataset has {len(dataset)} entries after filtering "
- f"out {num_filtered_out} entries with unsupported features")
+ print(
+ f"dataset has {len(dataset)} entries after filtering "
+ f"out {num_filtered_out} entries with unsupported features"
+ )
len_dataset = len(dataset)
for data_point_idx in range(args.num_prompts):
idx = data_point_idx
while idx >= len_dataset:
idx -= len_dataset
schema = dataset["schema"][idx]
- prompt = tokenizer.apply_chat_template(dataset["prompt"][idx],
- tokenize=False,
- add_generation_prompt=True)
+ prompt = tokenizer.apply_chat_template(
+ dataset["prompt"][idx], tokenize=False, add_generation_prompt=True
+ )
input_len = len(tokenizer(prompt).input_ids)
completion = dataset["completion"][idx]
requests.append(
- SampleRequest(prompt=prompt,
- prompt_len=input_len,
- expected_output_len=args.output_len,
- schema=schema,
- structure_type=args.structure_type,
- completion=completion))
+ SampleRequest(
+ prompt=prompt,
+ prompt_len=input_len,
+ expected_output_len=args.output_len,
+ schema=schema,
+ structure_type=args.structure_type,
+ completion=completion,
+ )
+ )
return requests
@@ -276,7 +292,8 @@ async def get_request(
# Calculate scale parameter theta to maintain the desired request_rate.
assert burstiness > 0, (
- f"A positive burstiness factor is expected, but given {burstiness}.")
+ f"A positive burstiness factor is expected, but given {burstiness}."
+ )
theta = 1.0 / (request_rate * burstiness)
for i, request in enumerate(input_requests):
@@ -318,8 +335,8 @@ def calculate_metrics(
# multiple output tokens may be bundled together
# Note : this may inflate the output token count slightly
output_len = len(
- tokenizer(outputs[i].generated_text,
- add_special_tokens=False).input_ids)
+ tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids
+ )
actual_output_lens.append(output_len)
total_input += input_requests[i].prompt_len
tpot = 0
@@ -343,16 +360,19 @@ def calculate_metrics(
if "ttft" in goodput_config_dict:
valid_metrics.append(ttfts)
- slo_values.append(goodput_config_dict["ttft"] /
- MILLISECONDS_TO_SECONDS_CONVERSION)
+ slo_values.append(
+ goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION
+ )
if "tpot" in goodput_config_dict:
valid_metrics.append(all_tpots)
- slo_values.append(goodput_config_dict["tpot"] /
- MILLISECONDS_TO_SECONDS_CONVERSION)
+ slo_values.append(
+ goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION
+ )
if "e2el" in goodput_config_dict:
valid_metrics.append(e2els)
- slo_values.append(goodput_config_dict["e2el"] /
- MILLISECONDS_TO_SECONDS_CONVERSION)
+ slo_values.append(
+ goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION
+ )
for req_metric in zip(*valid_metrics):
is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
@@ -363,7 +383,8 @@ def calculate_metrics(
warnings.warn(
"All requests failed. This is likely due to a misconfiguration "
"on the benchmark arguments.",
- stacklevel=2)
+ stacklevel=2,
+ )
metrics = BenchmarkMetrics(
completed=completed,
total_input=total_input,
@@ -372,27 +393,31 @@ def calculate_metrics(
request_goodput=good_completed / dur_s,
output_throughput=sum(actual_output_lens) / dur_s,
total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
- mean_ttft_ms=np.mean(ttfts or 0) *
- 1000, # ttfts is empty if streaming is not supported by backend
+ mean_ttft_ms=np.mean(ttfts or 0)
+ * 1000, # ttfts is empty if streaming is not supported by backend
std_ttft_ms=np.std(ttfts or 0) * 1000,
median_ttft_ms=np.median(ttfts or 0) * 1000,
- percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000)
- for p in selected_percentiles],
+ percentiles_ttft_ms=[
+ (p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles
+ ],
mean_tpot_ms=np.mean(tpots or 0) * 1000,
std_tpot_ms=np.std(tpots or 0) * 1000,
median_tpot_ms=np.median(tpots or 0) * 1000,
- percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000)
- for p in selected_percentiles],
+ percentiles_tpot_ms=[
+ (p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles
+ ],
mean_itl_ms=np.mean(itls or 0) * 1000,
std_itl_ms=np.std(itls or 0) * 1000,
median_itl_ms=np.median(itls or 0) * 1000,
- percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000)
- for p in selected_percentiles],
+ percentiles_itl_ms=[
+ (p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles
+ ],
mean_e2el_ms=np.mean(e2els or 0) * 1000,
std_e2el_ms=np.std(e2els or 0) * 1000,
median_e2el_ms=np.median(e2els or 0) * 1000,
- percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
- for p in selected_percentiles],
+ percentiles_e2el_ms=[
+ (p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles
+ ],
)
return metrics, actual_output_lens
@@ -429,12 +454,13 @@ def prepare_extra_body(request) -> dict:
print("Starting initial single prompt test run...")
structured_output_req_idx = random.sample(
- range(len(input_requests)),
- int(len(input_requests) * structured_output_ratio))
+ range(len(input_requests)), int(len(input_requests) * structured_output_ratio)
+ )
test_request = input_requests[0]
- test_req_extra_body = (prepare_extra_body(test_request)
- if 0 in structured_output_req_idx else None)
+ test_req_extra_body = (
+ prepare_extra_body(test_request) if 0 in structured_output_req_idx else None
+ )
test_input = RequestFuncInput(
model=model_id,
prompt=test_request.prompt,
@@ -448,7 +474,8 @@ def prepare_extra_body(request) -> dict:
if not test_output.success:
raise ValueError(
"Initial test run failed - Please make sure benchmark arguments "
- f"are correctly specified. Error: {test_output.error}")
+ f"are correctly specified. Error: {test_output.error}"
+ )
else:
print("Initial test run completed. Starting main benchmark run...")
@@ -467,10 +494,7 @@ def prepare_extra_body(request) -> dict:
if profile_output.success:
print("Profiler started")
- if burstiness == 1.0:
- distribution = "Poisson process"
- else:
- distribution = "Gamma distribution"
+ distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution"
print(f"Traffic request rate: {request_rate}")
print(f"Burstiness factor: {burstiness} ({distribution})")
@@ -482,24 +506,21 @@ def prepare_extra_body(request) -> dict:
# and it will simplify the code in limited_request_func.
# semaphore = (asyncio.Semaphore(max_concurrency)
# if max_concurrency else contextlib.nullcontext())
- semaphore = (asyncio.Semaphore(max_concurrency)
- if max_concurrency else None)
+ semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
async def limited_request_func(request_func_input, pbar):
if semaphore is None:
- return await request_func(request_func_input=request_func_input,
- pbar=pbar)
+ return await request_func(request_func_input=request_func_input, pbar=pbar)
async with semaphore:
- return await request_func(request_func_input=request_func_input,
- pbar=pbar)
+ return await request_func(request_func_input=request_func_input, pbar=pbar)
benchmark_start_time = time.perf_counter()
tasks: list[asyncio.Task] = []
expected: list[str] = []
- async for i, request in get_request(input_requests, request_rate,
- burstiness):
- extra_body = prepare_extra_body(
- request) if i in structured_output_req_idx else None
+ async for i, request in get_request(input_requests, request_rate, burstiness):
+ extra_body = (
+ prepare_extra_body(request) if i in structured_output_req_idx else None
+ )
request_func_input = RequestFuncInput(
model=model_id,
prompt=request.prompt,
@@ -512,8 +533,9 @@ async def limited_request_func(request_func_input, pbar):
expected.append(request.completion)
tasks.append(
asyncio.create_task(
- limited_request_func(request_func_input=request_func_input,
- pbar=pbar)))
+ limited_request_func(request_func_input=request_func_input, pbar=pbar)
+ )
+ )
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
if profile:
@@ -545,54 +567,58 @@ async def limited_request_func(request_func_input, pbar):
goodput_config_dict=goodput_config_dict,
)
- print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
+ print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
- print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
- benchmark_duration))
+ print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
- print("{:<40} {:<10}".format("Total generated tokens:",
- metrics.total_output))
- print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
- metrics.request_throughput))
+ print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
+ print(
+ "{:<40} {:<10.2f}".format(
+ "Request throughput (req/s):", metrics.request_throughput
+ )
+ )
if goodput_config_dict:
- print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
- metrics.request_goodput))
- print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
- metrics.output_throughput))
- print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
- metrics.total_token_throughput))
+ print(
+ "{:<40} {:<10.2f}".format(
+ "Request goodput (req/s):", metrics.request_goodput
+ )
+ )
+ print(
+ "{:<40} {:<10.2f}".format(
+ "Output token throughput (tok/s):", metrics.output_throughput
+ )
+ )
+ print(
+ "{:<40} {:<10.2f}".format(
+ "Total Token throughput (tok/s):", metrics.total_token_throughput
+ )
+ )
result = {
- "duration":
- benchmark_duration,
- "completed":
- metrics.completed,
- "total_input_tokens":
- metrics.total_input,
- "total_output_tokens":
- metrics.total_output,
- "request_throughput":
- metrics.request_throughput,
- "output_throughput":
- metrics.output_throughput,
- "total_token_throughput":
- metrics.total_token_throughput,
- "ttft_description":
- pd.Series([output.ttft for output in outputs]).describe().to_dict(),
- "tpot_description":
- pd.Series([output.tpot for output in outputs]).describe().to_dict(),
+ "duration": benchmark_duration,
+ "completed": metrics.completed,
+ "total_input_tokens": metrics.total_input,
+ "total_output_tokens": metrics.total_output,
+ "request_throughput": metrics.request_throughput,
+ "output_throughput": metrics.output_throughput,
+ "total_token_throughput": metrics.total_token_throughput,
+ "ttft_description": pd.Series([output.ttft for output in outputs])
+ .describe()
+ .to_dict(),
+ "tpot_description": pd.Series([output.tpot for output in outputs])
+ .describe()
+ .to_dict(),
"input_lens": [output.prompt_len for output in outputs],
- "output_lens":
- actual_output_lens,
+ "output_lens": actual_output_lens,
"ttfts": [output.ttft for output in outputs],
"itls": [output.itl for output in outputs],
"errors": [output.error for output in outputs],
}
- ret = [{
- 'generated': output.generated_text,
- 'expected': gt
- } for output, gt in zip(outputs, expected)]
+ ret = [
+ {"generated": output.generated_text, "expected": gt}
+ for output, gt in zip(outputs, expected)
+ ]
def process_one_metric(
# E.g., "ttft"
@@ -606,29 +632,35 @@ def process_one_metric(
# metric.
if metric_attribute_name not in selected_percentile_metrics:
return
- print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
- print("{:<40} {:<10.2f}".format(
- f"Mean {metric_name} (ms):",
- getattr(metrics, f"mean_{metric_attribute_name}_ms")))
- print("{:<40} {:<10.2f}".format(
- f"Median {metric_name} (ms):",
- getattr(metrics, f"median_{metric_attribute_name}_ms")))
+ print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
+ print(
+ "{:<40} {:<10.2f}".format(
+ f"Mean {metric_name} (ms):",
+ getattr(metrics, f"mean_{metric_attribute_name}_ms"),
+ )
+ )
+ print(
+ "{:<40} {:<10.2f}".format(
+ f"Median {metric_name} (ms):",
+ getattr(metrics, f"median_{metric_attribute_name}_ms"),
+ )
+ )
result[f"mean_{metric_attribute_name}_ms"] = getattr(
- metrics, f"mean_{metric_attribute_name}_ms")
+ metrics, f"mean_{metric_attribute_name}_ms"
+ )
result[f"median_{metric_attribute_name}_ms"] = getattr(
- metrics, f"median_{metric_attribute_name}_ms")
+ metrics, f"median_{metric_attribute_name}_ms"
+ )
result[f"std_{metric_attribute_name}_ms"] = getattr(
- metrics, f"std_{metric_attribute_name}_ms")
- for p, value in getattr(metrics,
- f"percentiles_{metric_attribute_name}_ms"):
+ metrics, f"std_{metric_attribute_name}_ms"
+ )
+ for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"):
p_word = str(int(p)) if int(p) == p else str(p)
- print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
- value))
+ print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
process_one_metric("ttft", "TTFT", "Time to First Token")
- process_one_metric("tpot", "TPOT",
- "Time per Output Token (excl. 1st token)")
+ process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
process_one_metric("itl", "ITL", "Inter-token Latency")
process_one_metric("e2el", "E2EL", "End-to-end Latency")
@@ -638,13 +670,13 @@ def process_one_metric(
def evaluate(ret, args):
-
def _eval_correctness_json(expected, actual):
# extract json string from string using regex
- import re
- actual = actual.replace('\n', '').replace(' ', '').strip()
+ import regex as re
+
+ actual = actual.replace("\n", "").replace(" ", "").strip()
try:
- actual = re.search(r'\{.*\}', actual).group()
+ actual = re.search(r"\{.*\}", actual).group()
actual = json.loads(actual)
except Exception:
return False
@@ -655,29 +687,33 @@ def _eval_correctness_choice(expected, actual):
return actual in args.choice
def _eval_correctness_regex(expected, actual):
- import re
+ import regex as re
+
return re.match(args.regex, actual) is not None
def _eval_correctness(expected, actual):
- if args.structure_type == 'guided_json':
+ if args.structure_type == "guided_json":
return _eval_correctness_json(expected, actual)
- elif args.structure_type == 'guided_regex':
+ elif args.structure_type == "guided_regex":
return _eval_correctness_regex(expected, actual)
- elif args.structure_type == 'guided_choice':
+ elif args.structure_type == "guided_choice":
return _eval_correctness_choice(expected, actual)
else:
return None
scores = []
for res in ret:
- score = _eval_correctness(res['expected'], res['generated'])
- res['correctness'] = score
+ score = _eval_correctness(res["expected"], res["generated"])
+ res["correctness"] = score
scores.append(score)
not_none_scores = [score for score in scores if score is not None]
- return (sum(not_none_scores) / len(not_none_scores) *
- 100) if len(not_none_scores) > 0 else None
+ return (
+ (sum(not_none_scores) / len(not_none_scores) * 100)
+ if len(not_none_scores) > 0
+ else None
+ )
def parse_goodput(slo_pairs):
@@ -689,9 +725,10 @@ def parse_goodput(slo_pairs):
except ValueError as err:
raise argparse.ArgumentTypeError(
"Invalid format found for service level objectives. "
- "Specify service level objectives for goodput as \"KEY:VALUE\" "
+ 'Specify service level objectives for goodput as "KEY:VALUE" '
"pairs, where the key is a metric name, and the value is a "
- "number in milliseconds.") from err
+ "number in milliseconds."
+ ) from err
return goodput_config_dict
@@ -705,12 +742,14 @@ def check_goodput_args(args):
raise ValueError(
f"Invalid metric name found, {slo_name}: {slo_val}. "
"The service level objective name should be one of "
- f"{str(VALID_NAMES)}. ")
+ f"{str(VALID_NAMES)}. "
+ )
if slo_val < 0:
raise ValueError(
f"Invalid value found, {slo_name}: {slo_val}. "
"The service level objective value should be "
- "non-negative.")
+ "non-negative."
+ )
return goodput_config_dict
@@ -736,19 +775,19 @@ def main(args: argparse.Namespace):
tokenizer_mode=args.tokenizer_mode,
)
- if args.dataset == 'grammar':
- args.structure_type = 'guided_grammar'
- elif args.dataset == 'regex':
- args.structure_type = 'guided_regex'
- elif args.dataset == 'choice':
- args.structure_type = 'guided_choice'
+ if args.dataset == "grammar":
+ args.structure_type = "guided_grammar"
+ elif args.dataset == "regex":
+ args.structure_type = "guided_regex"
+ elif args.dataset == "choice":
+ args.structure_type = "guided_choice"
else:
- args.structure_type = 'guided_json'
+ args.structure_type = "guided_json"
if args.no_structured_output:
args.structured_output_ratio = 0
if args.save_results:
- result_file_name = f'{args.structured_output_ratio}guided'
+ result_file_name = f"{args.structured_output_ratio}guided"
result_file_name += f"_{backend}"
result_file_name += f"_{args.request_rate}qps"
result_file_name += f"_{args.model.split('/')[-1]}"
@@ -776,36 +815,29 @@ def main(args: argparse.Namespace):
disable_tqdm=args.disable_tqdm,
profile=args.profile,
selected_percentile_metrics=args.percentile_metrics.split(","),
- selected_percentiles=[
- float(p) for p in args.metric_percentiles.split(",")
- ],
+ selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")],
ignore_eos=args.ignore_eos,
max_concurrency=args.max_concurrency,
structured_output_ratio=args.structured_output_ratio,
goodput_config_dict=goodput_config_dict,
- ))
+ )
+ )
# Save config and results to json
score = evaluate(ret, args)
- print("correct_rate(%)", score, '\n')
+ print("correct_rate(%)", score, "\n")
if args.save_results:
results = {
- "backend":
- backend,
- "model_id":
- model_id,
- "tokenizer_id":
- tokenizer_id,
- "num_prompts":
- args.num_prompts,
- "request_rate":
- args.request_rate if args.request_rate < float("inf") else "inf",
- "burstiness":
- args.burstiness,
- "max_concurrency":
- args.max_concurrency,
- "correct_rate(%)":
- score
+ "backend": backend,
+ "model_id": model_id,
+ "tokenizer_id": tokenizer_id,
+ "num_prompts": args.num_prompts,
+ "request_rate": args.request_rate
+ if args.request_rate < float("inf")
+ else "inf",
+ "burstiness": args.burstiness,
+ "max_concurrency": args.max_concurrency,
+ "correct_rate(%)": score,
}
results = {"outputs": ret, **results, **benchmark_result}
@@ -814,13 +846,14 @@ def main(args: argparse.Namespace):
result_file_name = args.result_filename
if args.result_dir:
result_file_name = os.path.join(args.result_dir, result_file_name)
- with open(result_file_name, "w", encoding='utf-8') as outfile:
+ with open(result_file_name, "w", encoding="utf-8") as outfile:
json.dump(results, outfile, indent=4)
if __name__ == "__main__":
parser = FlexibleArgumentParser(
- description="Benchmark the online serving throughput.")
+ description="Benchmark the online serving throughput."
+ )
parser.add_argument(
"--backend",
type=str,
@@ -842,16 +875,14 @@ def main(args: argparse.Namespace):
default="/v1/completions",
help="API endpoint.",
)
- parser.add_argument("--dataset",
- default='json',
- choices=[
- 'json', 'json-unique', 'grammar', 'regex',
- 'choice', 'xgrammar_bench'
- ])
- parser.add_argument("--json-schema-path",
- type=str,
- default=None,
- help="Path to json schema.")
+ parser.add_argument(
+ "--dataset",
+ default="json",
+ choices=["json", "json-unique", "grammar", "regex", "choice", "xgrammar_bench"],
+ )
+ parser.add_argument(
+ "--json-schema-path", type=str, default=None, help="Path to json schema."
+ )
parser.add_argument(
"--max-concurrency",
type=int,
@@ -863,7 +894,8 @@ def main(args: argparse.Namespace):
"initiated, this argument will control how many are actually allowed "
"to execute at a time. This means that when used in combination, the "
"actual request rate may be lower than specified with --request-rate, "
- "if the server is not processing requests fast enough to keep up.")
+ "if the server is not processing requests fast enough to keep up.",
+ )
parser.add_argument(
"--model",
type=str,
@@ -873,15 +905,13 @@ def main(args: argparse.Namespace):
parser.add_argument(
"--tokenizer",
type=str,
- help=
- "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
+ help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
)
parser.add_argument(
"--tokenizer-mode",
type=str,
default="auto",
- help=
- "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
+ help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
)
parser.add_argument(
"--num-prompts",
@@ -958,44 +988,51 @@ def main(args: argparse.Namespace):
"--ignore-eos",
action="store_true",
help="Set ignore_eos flag when sending the benchmark request."
- "Warning: ignore_eos is not supported in deepspeed_mii and tgi.")
+ "Warning: ignore_eos is not supported in deepspeed_mii and tgi.",
+ )
parser.add_argument(
"--percentile-metrics",
type=str,
default="ttft,tpot,itl",
help="Comma-separated list of selected metrics to report percentils. "
"This argument specifies the metrics to report percentiles. "
- "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". "
- "Default value is \"ttft,tpot,itl\".")
+ 'Allowed metric names are "ttft", "tpot", "itl", "e2el". '
+ 'Default value is "ttft,tpot,itl".',
+ )
parser.add_argument(
"--metric-percentiles",
type=str,
default="99",
help="Comma-separated list of percentiles for selected metrics. "
- "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
- "Default value is \"99\". "
- "Use \"--percentile-metrics\" to select metrics.",
+ 'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". '
+ 'Default value is "99". '
+ 'Use "--percentile-metrics" to select metrics.',
)
parser.add_argument(
"--goodput",
nargs="+",
required=False,
- help="Specify service level objectives for goodput as \"KEY:VALUE\" "
+ help='Specify service level objectives for goodput as "KEY:VALUE" '
"pairs, where the key is a metric name, and the value is in "
- "milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, "
+ 'milliseconds. Multiple "KEY:VALUE" pairs can be provided, '
"separated by spaces. Allowed request level metric names are "
- "\"ttft\", \"tpot\", \"e2el\". For more context on the definition of "
+ '"ttft", "tpot", "e2el". For more context on the definition of '
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
- "and the blog: https://hao-ai-lab.github.io/blogs/distserve")
-
- parser.add_argument("--no-structured-output",
- action='store_true',
- default=False,
- help="Whether to disable JSON decoding or not.")
- parser.add_argument("--structured-output-ratio",
- type=float,
- default=1.0,
- help="Ratio of Structured Outputs requests")
+ "and the blog: https://hao-ai-lab.github.io/blogs/distserve",
+ )
+
+ parser.add_argument(
+ "--no-structured-output",
+ action="store_true",
+ default=False,
+ help="Whether to disable JSON decoding or not.",
+ )
+ parser.add_argument(
+ "--structured-output-ratio",
+ type=float,
+ default=1.0,
+ help="Ratio of Structured Outputs requests",
+ )
args = parser.parse_args()
main(args)
diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index 1f65277e1bfe..7a13babda9d1 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
"""Benchmark offline inference throughput."""
+
import argparse
import dataclasses
import json
@@ -11,18 +12,25 @@
import torch
import uvloop
-from benchmark_dataset import (AIMODataset, BurstGPTDataset,
- ConversationDataset, InstructCoderDataset,
- RandomDataset, SampleRequest, ShareGPTDataset,
- SonnetDataset, VisionArenaDataset)
-from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from tqdm import tqdm
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- PreTrainedTokenizerBase)
-
+from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
+
+from benchmark_dataset import (
+ AIMODataset,
+ BurstGPTDataset,
+ ConversationDataset,
+ InstructCoderDataset,
+ RandomDataset,
+ SampleRequest,
+ ShareGPTDataset,
+ SonnetDataset,
+ VisionArenaDataset,
+)
+from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
- build_async_engine_client_from_engine_args)
+ build_async_engine_client_from_engine_args,
+)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
@@ -37,23 +45,30 @@ def run_vllm(
disable_detokenize: bool = False,
) -> tuple[float, Optional[list[RequestOutput]]]:
from vllm import LLM, SamplingParams
+
llm = LLM(**dataclasses.asdict(engine_args))
assert all(
- llm.llm_engine.model_config.max_model_len >= (
- request.prompt_len + request.expected_output_len)
- for request in requests), (
- "Please ensure that max_model_len is greater than the sum of"
- " prompt_len and expected_output_len for all requests.")
+ llm.llm_engine.model_config.max_model_len
+ >= (request.prompt_len + request.expected_output_len)
+ for request in requests
+ ), (
+ "Please ensure that max_model_len is greater than the sum of"
+ " prompt_len and expected_output_len for all requests."
+ )
# Add the requests to the engine.
prompts: list[Union[TextPrompt, TokensPrompt]] = []
sampling_params: list[SamplingParams] = []
for request in requests:
prompts.append(
- TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
- multi_modal_data=request.multi_modal_data)
- if "prompt_token_ids" in request.prompt else \
- TextPrompt(prompt=request.prompt,
- multi_modal_data=request.multi_modal_data))
+ TokensPrompt(
+ prompt_token_ids=request.prompt["prompt_token_ids"],
+ multi_modal_data=request.multi_modal_data,
+ )
+ if "prompt_token_ids" in request.prompt
+ else TextPrompt(
+ prompt=request.prompt, multi_modal_data=request.multi_modal_data
+ )
+ )
sampling_params.append(
SamplingParams(
n=n,
@@ -62,7 +77,8 @@ def run_vllm(
ignore_eos=True,
max_tokens=request.expected_output_len,
detokenize=not disable_detokenize,
- ))
+ )
+ )
lora_requests: Optional[list[LoRARequest]] = None
if engine_args.enable_lora:
lora_requests = [request.lora_request for request in requests]
@@ -72,10 +88,9 @@ def run_vllm(
outputs = None
if not use_beam_search:
start = time.perf_counter()
- outputs = llm.generate(prompts,
- sampling_params,
- lora_request=lora_requests,
- use_tqdm=True)
+ outputs = llm.generate(
+ prompts, sampling_params, lora_request=lora_requests, use_tqdm=True
+ )
end = time.perf_counter()
else:
assert lora_requests is None, "BeamSearch API does not support LoRA"
@@ -91,30 +106,35 @@ def run_vllm(
beam_width=n,
max_tokens=output_len,
ignore_eos=True,
- ))
+ ),
+ )
end = time.perf_counter()
return end - start, outputs
def run_vllm_chat(
- requests: list[SampleRequest],
- n: int,
- engine_args: EngineArgs,
- disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]:
+ requests: list[SampleRequest],
+ n: int,
+ engine_args: EngineArgs,
+ disable_detokenize: bool = False,
+) -> tuple[float, list[RequestOutput]]:
"""
Run vLLM chat benchmark. This function is recommended ONLY for benchmarking
multimodal models as it properly handles multimodal inputs and chat
formatting. For non-multimodal models, use run_vllm() instead.
"""
from vllm import LLM, SamplingParams
+
llm = LLM(**dataclasses.asdict(engine_args))
assert all(
- llm.llm_engine.model_config.max_model_len >= (
- request.prompt_len + request.expected_output_len)
- for request in requests), (
- "Please ensure that max_model_len is greater than the sum of "
- "prompt_len and expected_output_len for all requests.")
+ llm.llm_engine.model_config.max_model_len
+ >= (request.prompt_len + request.expected_output_len)
+ for request in requests
+ ), (
+ "Please ensure that max_model_len is greater than the sum of "
+ "prompt_len and expected_output_len for all requests."
+ )
prompts = []
sampling_params: list[SamplingParams] = []
@@ -128,7 +148,8 @@ def run_vllm_chat(
ignore_eos=True,
max_tokens=request.expected_output_len,
detokenize=not disable_detokenize,
- ))
+ )
+ )
start = time.perf_counter()
outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter()
@@ -145,13 +166,17 @@ async def run_vllm_async(
from vllm import SamplingParams
async with build_async_engine_client_from_engine_args(
- engine_args, disable_frontend_multiprocessing) as llm:
+ engine_args, disable_frontend_multiprocessing
+ ) as llm:
+ model_config = await llm.get_model_config()
assert all(
- llm.model_config.max_model_len >= (request.prompt_len +
- request.expected_output_len)
- for request in requests), (
- "Please ensure that max_model_len is greater than the sum of"
- " prompt_len and expected_output_len for all requests.")
+ model_config.max_model_len
+ >= (request.prompt_len + request.expected_output_len)
+ for request in requests
+ ), (
+ "Please ensure that max_model_len is greater than the sum of"
+ " prompt_len and expected_output_len for all requests."
+ )
# Add the requests to the engine.
prompts: list[Union[TextPrompt, TokensPrompt]] = []
@@ -159,11 +184,15 @@ async def run_vllm_async(
lora_requests: list[Optional[LoRARequest]] = []
for request in requests:
prompts.append(
- TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
- multi_modal_data=request.multi_modal_data)
- if "prompt_token_ids" in request.prompt else \
- TextPrompt(prompt=request.prompt,
- multi_modal_data=request.multi_modal_data))
+ TokensPrompt(
+ prompt_token_ids=request.prompt["prompt_token_ids"],
+ multi_modal_data=request.multi_modal_data,
+ )
+ if "prompt_token_ids" in request.prompt
+ else TextPrompt(
+ prompt=request.prompt, multi_modal_data=request.multi_modal_data
+ )
+ )
sampling_params.append(
SamplingParams(
n=n,
@@ -172,17 +201,16 @@ async def run_vllm_async(
ignore_eos=True,
max_tokens=request.expected_output_len,
detokenize=not disable_detokenize,
- ))
+ )
+ )
lora_requests.append(request.lora_request)
generators = []
start = time.perf_counter()
- for i, (prompt, sp,
- lr) in enumerate(zip(prompts, sampling_params, lora_requests)):
- generator = llm.generate(prompt,
- sp,
- lora_request=lr,
- request_id=f"test{i}")
+ for i, (prompt, sp, lr) in enumerate(
+ zip(prompts, sampling_params, lora_requests)
+ ):
+ generator = llm.generate(prompt, sp, lora_request=lr, request_id=f"test{i}")
generators.append(generator)
all_gens = merge_async_iterators(*generators)
async for i, res in all_gens:
@@ -201,7 +229,8 @@ def run_hf(
disable_detokenize: bool = False,
) -> float:
llm = AutoModelForCausalLM.from_pretrained(
- model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
+ model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
+ )
if llm.config.model_type == "llama":
# To enable padding in the HF backend.
tokenizer.pad_token = tokenizer.eos_token
@@ -224,14 +253,15 @@ def run_hf(
# Check if we can add more requests to the batch.
next_prompt_len = requests[i + 1].prompt_len
next_output_len = requests[i + 1].expected_output_len
- if (max(max_prompt_len, next_prompt_len) +
- max(max_output_len, next_output_len)) <= 2048:
+ if (
+ max(max_prompt_len, next_prompt_len)
+ + max(max_output_len, next_output_len)
+ ) <= 2048:
# We can add more requests to the batch.
continue
# Generate the sequences.
- input_ids = tokenizer(batch, return_tensors="pt",
- padding=True).input_ids
+ input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
llm_outputs = llm.generate(
input_ids=input_ids.cuda(),
do_sample=True,
@@ -261,6 +291,7 @@ def run_mii(
output_len: int,
) -> float:
from mii import client, serve
+
llm = serve(model, tensor_parallel=tensor_parallel_size)
prompts = [request.prompt for request in requests]
@@ -272,8 +303,9 @@ def run_mii(
return end - start
-def save_to_pytorch_benchmark_format(args: argparse.Namespace,
- results: dict[str, Any]) -> None:
+def save_to_pytorch_benchmark_format(
+ args: argparse.Namespace, results: dict[str, Any]
+) -> None:
pt_records = convert_to_pytorch_benchmark_format(
args=args,
metrics={
@@ -281,9 +313,9 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace,
"tokens_per_second": [results["tokens_per_second"]],
},
extra_info={
- k: results[k]
- for k in ["elapsed_time", "num_requests", "total_num_tokens"]
- })
+ k: results[k] for k in ["elapsed_time", "num_requests", "total_num_tokens"]
+ },
+ )
if pt_records:
# Don't use json suffix here as we don't want CI to pick it up
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
@@ -315,7 +347,8 @@ def get_requests(args, tokenizer):
sample_kwargs["enable_multimodal_chat"] = True
elif args.dataset_name == "sonnet":
assert tokenizer.chat_template or tokenizer.default_chat_template, (
- "Tokenizer/model must have chat template for sonnet dataset.")
+ "Tokenizer/model must have chat template for sonnet dataset."
+ )
dataset_cls = SonnetDataset
sample_kwargs["prefix_len"] = args.prefix_len
sample_kwargs["return_prompt_formatted"] = True
@@ -324,21 +357,21 @@ def get_requests(args, tokenizer):
elif args.dataset_name == "hf":
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = VisionArenaDataset
- common_kwargs['dataset_subset'] = None
- common_kwargs['dataset_split'] = "train"
+ common_kwargs["dataset_subset"] = None
+ common_kwargs["dataset_split"] = "train"
sample_kwargs["enable_multimodal_chat"] = True
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = InstructCoderDataset
- common_kwargs['dataset_split'] = "train"
+ common_kwargs["dataset_split"] = "train"
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = ConversationDataset
- common_kwargs['dataset_subset'] = args.hf_subset
- common_kwargs['dataset_split'] = args.hf_split
+ common_kwargs["dataset_subset"] = args.hf_subset
+ common_kwargs["dataset_split"] = args.hf_split
sample_kwargs["enable_multimodal_chat"] = True
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
dataset_cls = AIMODataset
- common_kwargs['dataset_subset'] = None
- common_kwargs['dataset_split'] = "train"
+ common_kwargs["dataset_subset"] = None
+ common_kwargs["dataset_split"] = "train"
else:
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
# Remove None values
@@ -353,10 +386,10 @@ def main(args: argparse.Namespace):
random.seed(args.seed)
# Sample the requests.
tokenizer = AutoTokenizer.from_pretrained(
- args.tokenizer, trust_remote_code=args.trust_remote_code)
+ args.tokenizer, trust_remote_code=args.trust_remote_code
+ )
requests = get_requests(args, tokenizer)
- is_multi_modal = any(request.multi_modal_data is not None
- for request in requests)
+ is_multi_modal = any(request.multi_modal_data is not None for request in requests)
request_outputs: Optional[list[RequestOutput]] = None
if args.backend == "vllm":
if args.async_engine:
@@ -367,23 +400,34 @@ def main(args: argparse.Namespace):
AsyncEngineArgs.from_cli_args(args),
args.disable_frontend_multiprocessing,
args.disable_detokenize,
- ))
+ )
+ )
else:
elapsed_time, request_outputs = run_vllm(
- requests, args.n, EngineArgs.from_cli_args(args),
- args.disable_detokenize)
+ requests,
+ args.n,
+ EngineArgs.from_cli_args(args),
+ args.disable_detokenize,
+ )
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
- elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
- args.hf_max_batch_size, args.trust_remote_code,
- args.disable_detokenize)
+ elapsed_time = run_hf(
+ requests,
+ args.model,
+ tokenizer,
+ args.n,
+ args.hf_max_batch_size,
+ args.trust_remote_code,
+ args.disable_detokenize,
+ )
elif args.backend == "mii":
- elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
- args.output_len)
+ elapsed_time = run_mii(
+ requests, args.model, args.tensor_parallel_size, args.output_len
+ )
elif args.backend == "vllm-chat":
elapsed_time, request_outputs = run_vllm_chat(
- requests, args.n, EngineArgs.from_cli_args(args),
- args.disable_detokenize)
+ requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize
+ )
else:
raise ValueError(f"Unknown backend: {args.backend}")
@@ -395,28 +439,31 @@ def main(args: argparse.Namespace):
for ro in request_outputs:
if not isinstance(ro, RequestOutput):
continue
- total_prompt_tokens += len(
- ro.prompt_token_ids) if ro.prompt_token_ids else 0
- total_output_tokens += sum(
- len(o.token_ids) for o in ro.outputs if o)
+ total_prompt_tokens += (
+ len(ro.prompt_token_ids) if ro.prompt_token_ids else 0
+ )
+ total_output_tokens += sum(len(o.token_ids) for o in ro.outputs if o)
total_num_tokens = total_prompt_tokens + total_output_tokens
else:
- total_num_tokens = sum(r.prompt_len + r.expected_output_len
- for r in requests)
+ total_num_tokens = sum(r.prompt_len + r.expected_output_len for r in requests)
total_output_tokens = sum(r.expected_output_len for r in requests)
total_prompt_tokens = total_num_tokens - total_output_tokens
if is_multi_modal and args.backend != "vllm-chat":
- print("\033[91mWARNING\033[0m: Multi-modal request with "
- f"{args.backend} backend detected. The "
- "following metrics are not accurate because image tokens are not"
- " counted. See vllm-project/vllm/issues/9778 for details.")
+ print(
+ "\033[91mWARNING\033[0m: Multi-modal request with "
+ f"{args.backend} backend detected. The "
+ "following metrics are not accurate because image tokens are not"
+ " counted. See vllm-project/vllm/issues/9778 for details."
+ )
# TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
# vllm-chat backend counts the image tokens now
- print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
- f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
- f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
+ print(
+ f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
+ f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
+ f"{total_output_tokens / elapsed_time:.2f} output tokens/s"
+ )
print(f"Total num prompt tokens: {total_prompt_tokens}")
print(f"Total num output tokens: {total_output_tokens}")
@@ -444,7 +491,8 @@ def validate_args(args):
warnings.warn(
"The '--dataset' argument will be deprecated in the next release. "
"Please use '--dataset-name' and '--dataset-path' instead.",
- stacklevel=2)
+ stacklevel=2,
+ )
args.dataset_path = args.dataset
if not getattr(args, "tokenizer", None):
@@ -457,9 +505,8 @@ def validate_args(args):
# === Dataset Configuration ===
if not args.dataset and not args.dataset_path:
- print(
- "When dataset path is not set, it will default to random dataset")
- args.dataset_name = 'random'
+ print("When dataset path is not set, it will default to random dataset")
+ args.dataset_name = "random"
if args.input_len is None:
raise ValueError("input_len must be provided for a random dataset")
@@ -467,41 +514,55 @@ def validate_args(args):
# --hf-subset and --hf-split: only used
# when dataset_name is 'hf'
if args.dataset_name != "hf" and (
- getattr(args, "hf_subset", None) is not None
- or getattr(args, "hf_split", None) is not None):
- warnings.warn("--hf-subset and --hf-split will be ignored \
+ getattr(args, "hf_subset", None) is not None
+ or getattr(args, "hf_split", None) is not None
+ ):
+ warnings.warn(
+ "--hf-subset and --hf-split will be ignored \
since --dataset-name is not 'hf'.",
- stacklevel=2)
+ stacklevel=2,
+ )
elif args.dataset_name == "hf":
if args.dataset_path in (
- VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
- | ConversationDataset.SUPPORTED_DATASET_PATHS):
- assert args.backend == "vllm-chat", f"{args.dataset_path} needs to use vllm-chat as the backend." #noqa: E501
- elif args.dataset_path in (InstructCoderDataset.SUPPORTED_DATASET_PATHS
- | AIMODataset.SUPPORTED_DATASET_PATHS):
- assert args.backend == "vllm", f"{args.dataset_path} needs to use vllm as the backend." #noqa: E501
+ VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
+ | ConversationDataset.SUPPORTED_DATASET_PATHS
+ ):
+ assert args.backend == "vllm-chat", (
+ f"{args.dataset_path} needs to use vllm-chat as the backend."
+ ) # noqa: E501
+ elif args.dataset_path in (
+ InstructCoderDataset.SUPPORTED_DATASET_PATHS
+ | AIMODataset.SUPPORTED_DATASET_PATHS
+ ):
+ assert args.backend == "vllm", (
+ f"{args.dataset_path} needs to use vllm as the backend."
+ ) # noqa: E501
else:
- raise ValueError(
- f"{args.dataset_path} is not supported by hf dataset.")
+ raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")
# --random-range-ratio: only used when dataset_name is 'random'
- if args.dataset_name != 'random' and args.random_range_ratio is not None:
- warnings.warn("--random-range-ratio will be ignored since \
+ if args.dataset_name != "random" and args.random_range_ratio is not None:
+ warnings.warn(
+ "--random-range-ratio will be ignored since \
--dataset-name is not 'random'.",
- stacklevel=2)
+ stacklevel=2,
+ )
# --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
# set.
- if args.dataset_name not in {"random", "sonnet", None
- } and args.prefix_len is not None:
- warnings.warn("--prefix-len will be ignored since --dataset-name\
+ if (
+ args.dataset_name not in {"random", "sonnet", None}
+ and args.prefix_len is not None
+ ):
+ warnings.warn(
+ "--prefix-len will be ignored since --dataset-name\
is not 'random', 'sonnet', or not set.",
- stacklevel=2)
+ stacklevel=2,
+ )
# === LoRA Settings ===
if getattr(args, "enable_lora", False) and args.backend != "vllm":
- raise ValueError(
- "LoRA benchmarking is only supported for vLLM backend")
+ raise ValueError("LoRA benchmarking is only supported for vLLM backend")
if getattr(args, "enable_lora", False) and args.lora_path is None:
raise ValueError("LoRA path must be provided when enable_lora is True")
@@ -511,8 +572,10 @@ def validate_args(args):
if args.backend != "hf" and args.hf_max_batch_size is not None:
raise ValueError("HF max batch size is only for HF backend.")
- if args.backend in {"hf", "mii"} and getattr(args, "quantization",
- None) is not None:
+ if (
+ args.backend in {"hf", "mii"}
+ and getattr(args, "quantization", None) is not None
+ ):
raise ValueError("Quantization is only for vLLM backend.")
if args.backend == "mii" and args.dtype != "auto":
@@ -520,29 +583,32 @@ def validate_args(args):
if args.backend == "mii" and args.n != 1:
raise ValueError("n must be 1 for MII backend.")
if args.backend == "mii" and args.tokenizer != args.model:
- raise ValueError(
- "Tokenizer must be the same as the model for MII backend.")
+ raise ValueError("Tokenizer must be the same as the model for MII backend.")
# --data-parallel is not supported currently.
# https://github.com/vllm-project/vllm/issues/16222
if args.data_parallel_size > 1:
raise ValueError(
"Data parallel is not supported in offline benchmark, \
- please use benchmark serving instead")
+ please use benchmark serving instead"
+ )
if __name__ == "__main__":
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
- parser.add_argument("--backend",
- type=str,
- choices=["vllm", "hf", "mii", "vllm-chat"],
- default="vllm")
+ parser.add_argument(
+ "--backend",
+ type=str,
+ choices=["vllm", "hf", "mii", "vllm-chat"],
+ default="vllm",
+ )
parser.add_argument(
"--dataset-name",
type=str,
choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
help="Name of the dataset to benchmark on.",
- default="sharegpt")
+ default="sharegpt",
+ )
parser.add_argument(
"--dataset",
type=str,
@@ -550,57 +616,70 @@ def validate_args(args):
help="Path to the ShareGPT dataset, will be deprecated in\
the next release. The dataset is expected to "
"be a json in form of list[dict[..., conversations: "
- "list[dict[..., value: ]]]]")
- parser.add_argument("--dataset-path",
- type=str,
- default=None,
- help="Path to the dataset")
- parser.add_argument("--input-len",
- type=int,
- default=None,
- help="Input prompt length for each request")
- parser.add_argument("--output-len",
- type=int,
- default=None,
- help="Output length for each request. Overrides the "
- "output length from the dataset.")
- parser.add_argument("--n",
- type=int,
- default=1,
- help="Number of generated sequences per prompt.")
- parser.add_argument("--num-prompts",
- type=int,
- default=1000,
- help="Number of prompts to process.")
- parser.add_argument("--hf-max-batch-size",
- type=int,
- default=None,
- help="Maximum batch size for HF backend.")
+ "list[dict[..., value: ]]]]",
+ )
+ parser.add_argument(
+ "--dataset-path", type=str, default=None, help="Path to the dataset"
+ )
+ parser.add_argument(
+ "--input-len",
+ type=int,
+ default=None,
+ help="Input prompt length for each request",
+ )
+ parser.add_argument(
+ "--output-len",
+ type=int,
+ default=None,
+ help="Output length for each request. Overrides the "
+ "output length from the dataset.",
+ )
+ parser.add_argument(
+ "--n", type=int, default=1, help="Number of generated sequences per prompt."
+ )
parser.add_argument(
- '--output-json',
+ "--num-prompts", type=int, default=1000, help="Number of prompts to process."
+ )
+ parser.add_argument(
+ "--hf-max-batch-size",
+ type=int,
+ default=None,
+ help="Maximum batch size for HF backend.",
+ )
+ parser.add_argument(
+ "--output-json",
type=str,
default=None,
- help='Path to save the throughput results in JSON format.')
- parser.add_argument("--async-engine",
- action='store_true',
- default=False,
- help="Use vLLM async engine rather than LLM class.")
- parser.add_argument("--disable-frontend-multiprocessing",
- action='store_true',
- default=False,
- help="Disable decoupled async engine frontend.")
+ help="Path to save the throughput results in JSON format.",
+ )
+ parser.add_argument(
+ "--async-engine",
+ action="store_true",
+ default=False,
+ help="Use vLLM async engine rather than LLM class.",
+ )
+ parser.add_argument(
+ "--disable-frontend-multiprocessing",
+ action="store_true",
+ default=False,
+ help="Disable decoupled async engine frontend.",
+ )
parser.add_argument(
"--disable-detokenize",
action="store_true",
- help=("Do not detokenize the response (i.e. do not include "
- "detokenization time in the measurement)"))
+ help=(
+ "Do not detokenize the response (i.e. do not include "
+ "detokenization time in the measurement)"
+ ),
+ )
# LoRA
parser.add_argument(
"--lora-path",
type=str,
default=None,
- help="Path to the lora adapters to use. This can be an absolute path, "
- "a relative path, or a Hugging Face model identifier.")
+ help="Path to the LoRA adapters to use. This can be an absolute path, "
+ "a relative path, or a Hugging Face model identifier.",
+ )
parser.add_argument(
"--prefix-len",
type=int,
@@ -614,7 +693,8 @@ def validate_args(args):
f"prefix_len (default: {SonnetDataset.DEFAULT_PREFIX_LEN}) "
"controls how much of the input is fixed lines versus "
"random lines, but the total input length remains approximately "
- "input_len tokens.")
+ "input_len tokens.",
+ )
# random dataset
parser.add_argument(
"--random-range-ratio",
@@ -628,14 +708,12 @@ def validate_args(args):
)
# hf dtaset
- parser.add_argument("--hf-subset",
- type=str,
- default=None,
- help="Subset of the HF dataset.")
- parser.add_argument("--hf-split",
- type=str,
- default=None,
- help="Split of the HF dataset.")
+ parser.add_argument(
+ "--hf-subset", type=str, default=None, help="Subset of the HF dataset."
+ )
+ parser.add_argument(
+ "--hf-split", type=str, default=None, help="Split of the HF dataset."
+ )
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
diff --git a/benchmarks/benchmark_utils.py b/benchmarks/benchmark_utils.py
index 45a0ddbd5d08..b0c4fca92c3d 100644
--- a/benchmarks/benchmark_utils.py
+++ b/benchmarks/benchmark_utils.py
@@ -7,9 +7,9 @@
from typing import Any
-def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
- metrics: dict[str, list],
- extra_info: dict[str, Any]) -> list:
+def convert_to_pytorch_benchmark_format(
+ args: argparse.Namespace, metrics: dict[str, list], extra_info: dict[str, Any]
+) -> list:
"""
Save the benchmark results in the format used by PyTorch OSS benchmark with
on metric per record
@@ -37,12 +37,12 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
},
}
- tp = record["benchmark"]["extra_info"]["args"].get(
- "tensor_parallel_size")
+ tp = record["benchmark"]["extra_info"]["args"].get("tensor_parallel_size")
# Save tensor_parallel_size parameter if it's part of the metadata
if not tp and "tensor_parallel_size" in extra_info:
- record["benchmark"]["extra_info"]["args"][
- "tensor_parallel_size"] = extra_info["tensor_parallel_size"]
+ record["benchmark"]["extra_info"]["args"]["tensor_parallel_size"] = (
+ extra_info["tensor_parallel_size"]
+ )
records.append(record)
@@ -50,7 +50,6 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
class InfEncoder(json.JSONEncoder):
-
def clear_inf(self, o: Any):
if isinstance(o, dict):
return {k: self.clear_inf(v) for k, v in o.items()}
diff --git a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py
index 9e36b0a9d3bb..da258f98e085 100644
--- a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py
+++ b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py
@@ -23,8 +23,9 @@
# bench
-def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
- **kwargs) -> TMeasurement:
+def bench_fn(
+ label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs
+) -> TMeasurement:
min_run_time = 1
globals = {
@@ -41,16 +42,18 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
).blocked_autorange(min_run_time=min_run_time)
-def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
- sub_label: str) -> Iterable[TMeasurement]:
+def bench_int8(
+ dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
+) -> Iterable[TMeasurement]:
assert dtype == torch.int8
b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
- bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
+ bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
- out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
- torch.bfloat16)
+ out = ops.cutlass_scaled_sparse_mm(
+ a, b_compressed, e, scale_a, scale_b, torch.bfloat16
+ )
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
if not torch.allclose(out, out_ref):
@@ -63,54 +66,107 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
timers = []
# pytorch impl - bfloat16
timers.append(
- bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
- torch.mm, a.to(dtype=torch.bfloat16),
- b.to(dtype=torch.bfloat16)))
+ bench_fn(
+ label,
+ sub_label,
+ "pytorch_bf16_bf16_bf16_matmul-no-scales",
+ torch.mm,
+ a.to(dtype=torch.bfloat16),
+ b.to(dtype=torch.bfloat16),
+ )
+ )
# pytorch impl - float16
timers.append(
- bench_fn(label, sub_label,
- "pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm,
- a.to(dtype=torch.float16), b.to(dtype=torch.float16)))
+ bench_fn(
+ label,
+ sub_label,
+ "pytorch_fp16_fp16_fp16_matmul-no-scales",
+ torch.mm,
+ a.to(dtype=torch.float16),
+ b.to(dtype=torch.float16),
+ )
+ )
# cutlass impl
timers.append(
- bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm",
- ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
- torch.bfloat16))
+ bench_fn(
+ label,
+ sub_label,
+ "cutlass_i8_i8_bf16_scaled_mm",
+ ops.cutlass_scaled_mm,
+ a,
+ b,
+ scale_a,
+ scale_b,
+ torch.bfloat16,
+ )
+ )
# cutlass with bias
timers.append(
- bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias",
- ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
- bias))
+ bench_fn(
+ label,
+ sub_label,
+ "cutlass_i8_i8_bf16_scaled_mm_bias",
+ ops.cutlass_scaled_mm,
+ a,
+ b,
+ scale_a,
+ scale_b,
+ torch.bfloat16,
+ bias,
+ )
+ )
# cutlass sparse impl
timers.append(
- bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm",
- ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
- scale_b, torch.bfloat16))
+ bench_fn(
+ label,
+ sub_label,
+ "cutlass_i8_i8_bf16_scaled_sparse_mm",
+ ops.cutlass_scaled_sparse_mm,
+ a,
+ b_compressed,
+ e,
+ scale_a,
+ scale_b,
+ torch.bfloat16,
+ )
+ )
# cutlass sparse with bias
timers.append(
- bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias",
- ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
- scale_b, torch.bfloat16, bias))
+ bench_fn(
+ label,
+ sub_label,
+ "cutlass_i8_i8_bf16_scaled_sparse_mm_bias",
+ ops.cutlass_scaled_sparse_mm,
+ a,
+ b_compressed,
+ e,
+ scale_a,
+ scale_b,
+ torch.bfloat16,
+ bias,
+ )
+ )
return timers
-def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
- sub_label: str) -> Iterable[TMeasurement]:
+def bench_fp8(
+ dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
+) -> Iterable[TMeasurement]:
assert dtype == torch.float8_e4m3fn
- b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n,
- k)
+ b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
- bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
+ bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
- out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
- torch.bfloat16)
+ out = ops.cutlass_scaled_sparse_mm(
+ a, b_compressed, e, scale_a, scale_b, torch.bfloat16
+ )
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
if not torch.allclose(out, out_ref):
@@ -124,97 +180,165 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
# pytorch impl w. bf16
timers.append(
- bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
- torch.mm, a.to(dtype=torch.bfloat16, device="cuda"),
- b.to(dtype=torch.bfloat16, device="cuda")))
+ bench_fn(
+ label,
+ sub_label,
+ "pytorch_bf16_bf16_bf16_matmul-no-scales",
+ torch.mm,
+ a.to(dtype=torch.bfloat16, device="cuda"),
+ b.to(dtype=torch.bfloat16, device="cuda"),
+ )
+ )
# pytorch impl: bf16 output, without fp8 fast accum
timers.append(
- bench_fn(label,
- sub_label,
- "pytorch_fp8_fp8_bf16_scaled_mm",
- torch._scaled_mm,
- a,
- b,
- scale_a=scale_a,
- scale_b=scale_b,
- out_dtype=torch.bfloat16))
+ bench_fn(
+ label,
+ sub_label,
+ "pytorch_fp8_fp8_bf16_scaled_mm",
+ torch._scaled_mm,
+ a,
+ b,
+ scale_a=scale_a,
+ scale_b=scale_b,
+ out_dtype=torch.bfloat16,
+ )
+ )
# pytorch impl: bf16 output, with fp8 fast accum
timers.append(
- bench_fn(label,
- sub_label,
- "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
- torch._scaled_mm,
- a,
- b,
- scale_a=scale_a,
- scale_b=scale_b,
- out_dtype=torch.bfloat16,
- use_fast_accum=True))
+ bench_fn(
+ label,
+ sub_label,
+ "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
+ torch._scaled_mm,
+ a,
+ b,
+ scale_a=scale_a,
+ scale_b=scale_b,
+ out_dtype=torch.bfloat16,
+ use_fast_accum=True,
+ )
+ )
# pytorch impl: fp16 output, without fp8 fast accum
timers.append(
- bench_fn(label,
- sub_label,
- "pytorch_fp8_fp8_fp16_scaled_mm",
- torch._scaled_mm,
- a,
- b,
- scale_a=scale_a,
- scale_b=scale_b,
- out_dtype=torch.float16))
+ bench_fn(
+ label,
+ sub_label,
+ "pytorch_fp8_fp8_fp16_scaled_mm",
+ torch._scaled_mm,
+ a,
+ b,
+ scale_a=scale_a,
+ scale_b=scale_b,
+ out_dtype=torch.float16,
+ )
+ )
# pytorch impl: fp16 output, with fp8 fast accum
timers.append(
- bench_fn(label,
- sub_label,
- "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
- torch._scaled_mm,
- a,
- b,
- scale_a=scale_a,
- scale_b=scale_b,
- out_dtype=torch.float16,
- use_fast_accum=True))
+ bench_fn(
+ label,
+ sub_label,
+ "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
+ torch._scaled_mm,
+ a,
+ b,
+ scale_a=scale_a,
+ scale_b=scale_b,
+ out_dtype=torch.float16,
+ use_fast_accum=True,
+ )
+ )
# cutlass impl: bf16 output
timers.append(
- bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm",
- ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
- torch.bfloat16))
+ bench_fn(
+ label,
+ sub_label,
+ "cutlass_fp8_fp8_bf16_scaled_mm",
+ ops.cutlass_scaled_mm,
+ a,
+ b,
+ scale_a,
+ scale_b,
+ torch.bfloat16,
+ )
+ )
# cutlass impl: bf16 output
timers.append(
- bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm",
- ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
- scale_b, torch.bfloat16))
+ bench_fn(
+ label,
+ sub_label,
+ "cutlass_fp8_fp8_bf16_scaled_sparse_mm",
+ ops.cutlass_scaled_sparse_mm,
+ a,
+ b_compressed,
+ e,
+ scale_a,
+ scale_b,
+ torch.bfloat16,
+ )
+ )
# cutlass impl: fp16 output
timers.append(
- bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm",
- ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
- scale_b, torch.float16))
+ bench_fn(
+ label,
+ sub_label,
+ "cutlass_fp8_fp8_fp16_scaled_sparse_mm",
+ ops.cutlass_scaled_sparse_mm,
+ a,
+ b_compressed,
+ e,
+ scale_a,
+ scale_b,
+ torch.float16,
+ )
+ )
# cutlass impl: bf16 output, with bias
timers.append(
- bench_fn(label, sub_label,
- "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias",
- ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
- scale_b, torch.bfloat16, bias))
+ bench_fn(
+ label,
+ sub_label,
+ "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias",
+ ops.cutlass_scaled_sparse_mm,
+ a,
+ b_compressed,
+ e,
+ scale_a,
+ scale_b,
+ torch.bfloat16,
+ bias,
+ )
+ )
# cutlass impl: fp16 output, with bias
timers.append(
- bench_fn(label, sub_label,
- "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias",
- ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
- scale_b, torch.float16, bias.to(dtype=torch.float16)))
+ bench_fn(
+ label,
+ sub_label,
+ "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias",
+ ops.cutlass_scaled_sparse_mm,
+ a,
+ b_compressed,
+ e,
+ scale_a,
+ scale_b,
+ torch.float16,
+ bias.to(dtype=torch.float16),
+ )
+ )
return timers
-def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str,
- sub_label: str) -> Iterable[TMeasurement]:
+def bench(
+ dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
+) -> Iterable[TMeasurement]:
if dtype == torch.int8:
return bench_int8(dtype, m, k, n, label, sub_label)
if dtype == torch.float8_e4m3fn:
@@ -228,12 +352,12 @@ def print_timers(timers: Iterable[TMeasurement]):
compare.print()
-def run(dtype: torch.dtype,
- MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]:
+def run(
+ dtype: torch.dtype, MKNs: Iterable[tuple[int, int, int]]
+) -> Iterable[TMeasurement]:
results = []
for m, k, n in MKNs:
- timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
- f"MKN=({m}x{k}x{n})")
+ timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", f"MKN=({m}x{k}x{n})")
print_timers(timers)
results.extend(timers)
@@ -241,10 +365,12 @@ def run(dtype: torch.dtype,
# output makers
-def make_output(data: Iterable[TMeasurement],
- MKNs: Iterable[tuple[int, int, int]],
- base_description: str,
- timestamp=None):
+def make_output(
+ data: Iterable[TMeasurement],
+ MKNs: Iterable[tuple[int, int, int]],
+ base_description: str,
+ timestamp=None,
+):
print(f"== All Results {base_description} ====")
print_timers(data)
@@ -258,8 +384,7 @@ def make_output(data: Iterable[TMeasurement],
def run_square_bench(args):
- dim_sizes = list(
- range(args.dim_start, args.dim_end + 1, args.dim_increment))
+ dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
data = run(args.dtype, MKNs)
@@ -319,7 +444,7 @@ def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]:
pkl.dump(all_data, f)
-if __name__ == '__main__':
+if __name__ == "__main__":
def to_torch_dtype(dt):
if dt == "int8":
@@ -344,12 +469,15 @@ def to_torch_dtype(dt):
Output:
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
""", # noqa: E501
- formatter_class=argparse.RawTextHelpFormatter)
-
- parser.add_argument("--dtype",
- type=to_torch_dtype,
- required=True,
- help="Available options are ['int8', 'fp8']")
+ formatter_class=argparse.RawTextHelpFormatter,
+ )
+
+ parser.add_argument(
+ "--dtype",
+ type=to_torch_dtype,
+ required=True,
+ help="Available options are ['int8', 'fp8']",
+ )
subparsers = parser.add_subparsers(dest="cmd")
square_parser = subparsers.add_parser("square_bench")
@@ -368,19 +496,19 @@ def to_torch_dtype(dt):
range_parser.set_defaults(func=run_range_bench)
model_parser = subparsers.add_parser("model_bench")
- model_parser.add_argument("--models",
- nargs="+",
- type=str,
- default=DEFAULT_MODELS,
- choices=WEIGHT_SHAPES.keys())
- model_parser.add_argument("--tp-sizes",
- nargs="+",
- type=int,
- default=DEFAULT_TP_SIZES)
- model_parser.add_argument("--batch-sizes",
- nargs="+",
- type=int,
- default=DEFAULT_BATCH_SIZES)
+ model_parser.add_argument(
+ "--models",
+ nargs="+",
+ type=str,
+ default=DEFAULT_MODELS,
+ choices=WEIGHT_SHAPES.keys(),
+ )
+ model_parser.add_argument(
+ "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
+ )
+ model_parser.add_argument(
+ "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
+ )
model_parser.set_defaults(func=run_model_bench)
args = parser.parse_args()
diff --git a/benchmarks/cutlass_benchmarks/utils.py b/benchmarks/cutlass_benchmarks/utils.py
index fe4d8fdfc066..7e9f5a7fc0f4 100644
--- a/benchmarks/cutlass_benchmarks/utils.py
+++ b/benchmarks/cutlass_benchmarks/utils.py
@@ -10,8 +10,9 @@
def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn)
- return torch.round(tensor.clamp(
- min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
+ return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
+ dtype=torch.float8_e4m3fn
+ )
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
@@ -26,10 +27,11 @@ def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(dtype=torch.float16)
-def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
- k: int) -> tuple[torch.Tensor, torch.Tensor]:
- a = torch.randn((m, k), device='cuda') * 5
- b = torch.randn((n, k), device='cuda').t() * 5
+def make_rand_tensors(
+ dtype: torch.dtype, m: int, n: int, k: int
+) -> tuple[torch.Tensor, torch.Tensor]:
+ a = torch.randn((m, k), device="cuda") * 5
+ b = torch.randn((n, k), device="cuda").t() * 5
if dtype == torch.int8:
return to_int8(a), to_int8(b)
@@ -49,9 +51,7 @@ def prune_to_2_4(tensor):
# Create binary mask
mask = torch.zeros_like(reshaped)
- mask.scatter_(dim=1,
- index=indices,
- src=torch.ones_like(indices, dtype=mask.dtype))
+ mask.scatter_(dim=1, index=indices, src=torch.ones_like(indices, dtype=mask.dtype))
# Apply mask and reshape back
pruned = reshaped * mask
@@ -62,10 +62,11 @@ def prune_to_2_4(tensor):
return pruned.reshape(original_shape)
-def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
- k: int) -> tuple[torch.Tensor, torch.Tensor]:
- a = torch.randn((m, k), device='cuda') * 5
- b = torch.randn((n, k), device='cuda').t() * 5
+def make_rand_sparse_tensors(
+ dtype: torch.dtype, m: int, n: int, k: int
+) -> tuple[torch.Tensor, torch.Tensor]:
+ a = torch.randn((m, k), device="cuda") * 5
+ b = torch.randn((n, k), device="cuda").t() * 5
b = prune_to_2_4(b.t()).t()
@@ -86,9 +87,9 @@ def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
return b_compressed, e, a, b
-def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype,
- m: int, n: int, k: int) -> \
- tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]:
+def make_n_rand_sparse_tensors(
+ num_tensors: int, dtype: torch.dtype, m: int, n: int, k: int
+) -> tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]:
ABs = []
for _ in range(num_tensors):
b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
index e7b742d8bec9..08e93837f7dd 100644
--- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
+++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
@@ -16,7 +16,8 @@
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
- w8a8_block_fp8_matmul)
+ w8a8_block_fp8_matmul,
+)
from vllm.utils import FlexibleArgumentParser
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
@@ -25,8 +26,9 @@
# bench
-def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
- **kwargs) -> TMeasurement:
+def bench_fn(
+ label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs
+) -> TMeasurement:
min_run_time = 1
globals = {
@@ -44,45 +46,48 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
def bench_int8(
- dtype: torch.dtype,
- m: int,
- k: int,
- n: int,
- label: str,
- sub_label: str,
- bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
+ dtype: torch.dtype,
+ m: int,
+ k: int,
+ n: int,
+ label: str,
+ sub_label: str,
+ bench_kernels: Optional[list[str]] = None,
+) -> Iterable[TMeasurement]:
"""Benchmark INT8-based kernels."""
assert dtype == torch.int8
a, b = make_rand_tensors(torch.int8, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
- bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
- azp = torch.zeros((m, ), device="cuda", dtype=torch.int32)
- azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32)
+ bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
+ azp = torch.zeros((m,), device="cuda", dtype=torch.int32)
+ azp_adj = torch.zeros((n,), device="cuda", dtype=torch.int32)
bench_fns = {
- "pytorch_bf16_bf16_bf16_matmul-no-scales":
- lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
- ),
- "pytorch_fp16_fp16_fp16_matmul-no-scales":
- lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)),
- "cutlass_i8_i8_bf16_scaled_mm":
- lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16),
- "cutlass_i8_i8_bf16_scaled_mm_bias":
- lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16,
- bias),
- "cutlass_i8_i8_bf16_scaled_mm_azp":
- lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
- bfloat16, azp_adj),
- "cutlass_i8_i8_bf16_scaled_mm_azp_bias":
- lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
- bfloat16, azp_adj, None, bias),
- "cutlass_i8_i8_bf16_scaled_mm_azp_pt":
- lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
- bfloat16, azp_adj, azp),
- "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias":
- lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
- bfloat16, azp_adj, azp, bias),
+ "pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm(
+ a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
+ ),
+ "pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm(
+ a.to(dtype=torch.float16), b.to(dtype=torch.float16)
+ ),
+ "cutlass_i8_i8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm(
+ a, b, scale_a, scale_b, torch.bfloat16
+ ),
+ "cutlass_i8_i8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
+ a, b, scale_a, scale_b, torch.bfloat16, bias
+ ),
+ "cutlass_i8_i8_bf16_scaled_mm_azp": lambda: ops.cutlass_scaled_mm_azp(
+ a, b, scale_a, scale_b, torch.bfloat16, azp_adj
+ ),
+ "cutlass_i8_i8_bf16_scaled_mm_azp_bias": lambda: ops.cutlass_scaled_mm_azp(
+ a, b, scale_a, scale_b, torch.bfloat16, azp_adj, None, bias
+ ),
+ "cutlass_i8_i8_bf16_scaled_mm_azp_pt": lambda: ops.cutlass_scaled_mm_azp(
+ a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp
+ ),
+ "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias": lambda: ops.cutlass_scaled_mm_azp(
+ a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp, bias
+ ),
}
timers = []
@@ -96,73 +101,73 @@ def bench_int8(
def bench_fp8(
- dtype: torch.dtype,
- m: int,
- k: int,
- n: int,
- label: str,
- sub_label: str,
- bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
+ dtype: torch.dtype,
+ m: int,
+ k: int,
+ n: int,
+ label: str,
+ sub_label: str,
+ bench_kernels: Optional[list[str]] = None,
+) -> Iterable[TMeasurement]:
"""Benchmark FP8-based kernels."""
assert dtype == torch.float8_e4m3fn
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
a_cont = a.contiguous()
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
- block_scale_a = torch.rand((m, k // 128),
- device="cuda",
- dtype=torch.float32)
- block_scale_b = torch.rand((k // 128, n // 128),
- device="cuda",
- dtype=torch.float32)
+
+ def ceil_div(x: int, y: int) -> int:
+ return (x + y - 1) // y
+
+ block_scale_a = torch.rand(
+ (m, ceil_div(k, 128)), device="cuda", dtype=torch.float32
+ )
+ block_scale_b = torch.rand(
+ ceil_div(k, 128), ceil_div(n, 128), device="cuda", dtype=torch.float32
+ )
block_scale_a_M_major = block_scale_a.t().contiguous().t()
block_scale_b_K_major = block_scale_b.t().contiguous().t()
- bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
+ bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
print(m, k, n)
bench_fns = {
- "pytorch_bf16_bf16_bf16_matmul-no-scales":
- lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
- ),
- "pytorch_fp16_fp16_fp16_matmul-no-scales":
- lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)),
- "pytorch_fp8_fp8_fp16_scaled_mm":
- lambda: torch._scaled_mm(
- a, b, scale_a, scale_b, out_dtype=torch.float16),
- "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum":
- lambda: torch._scaled_mm(a,
- b,
- scale_a,
- scale_b,
- out_dtype=torch.float16,
- use_fast_accum=True),
- "pytorch_fp8_fp8_bf16_scaled_mm":
- lambda: torch._scaled_mm(
- a, b, scale_a, scale_b, out_dtype=torch.bfloat16),
- "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum":
- lambda: torch._scaled_mm(a,
- b,
- scale_a,
- scale_b,
- out_dtype=torch.bfloat16,
- use_fast_accum=True),
- "cutlass_fp8_fp8_bf16_scaled_mm":
- lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16),
- "cutlass_fp8_fp8_fp16_scaled_mm":
- lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16),
- "cutlass_fp8_fp8_bf16_scaled_mm_bias":
- lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16,
- bias),
- "cutlass_fp8_fp8_fp16_scaled_mm_bias":
- lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16,
- bias.to(dtype=torch.float16)),
- "triton_fp8_fp8_fp16_scaled_mm_blockwise":
- lambda: w8a8_block_fp8_matmul(a_cont, b.t(), block_scale_a,
- block_scale_b.t(), (128, 128)),
- "cutlass_fp8_fp8_fp16_scaled_mm_blockwise":
- lambda: ops.cutlass_scaled_mm(a, b, block_scale_a_M_major,
- block_scale_b_K_major, torch.float16),
+ "pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm(
+ a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
+ ),
+ "pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm(
+ a.to(dtype=torch.float16), b.to(dtype=torch.float16)
+ ),
+ "pytorch_fp8_fp8_fp16_scaled_mm": lambda: torch._scaled_mm(
+ a, b, scale_a, scale_b, out_dtype=torch.float16
+ ),
+ "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum": lambda: torch._scaled_mm(
+ a, b, scale_a, scale_b, out_dtype=torch.float16, use_fast_accum=True
+ ),
+ "pytorch_fp8_fp8_bf16_scaled_mm": lambda: torch._scaled_mm(
+ a, b, scale_a, scale_b, out_dtype=torch.bfloat16
+ ),
+ "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum": lambda: torch._scaled_mm(
+ a, b, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=True
+ ),
+ "cutlass_fp8_fp8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm(
+ a, b, scale_a, scale_b, torch.bfloat16
+ ),
+ "cutlass_fp8_fp8_fp16_scaled_mm": lambda: ops.cutlass_scaled_mm(
+ a, b, scale_a, scale_b, torch.float16
+ ),
+ "cutlass_fp8_fp8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
+ a, b, scale_a, scale_b, torch.bfloat16, bias
+ ),
+ "cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
+ a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16)
+ ),
+ "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul(
+ a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128)
+ ),
+ "cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm(
+ a, b, block_scale_a_M_major, block_scale_b_K_major, torch.float16
+ ),
}
timers = []
@@ -175,13 +180,15 @@ def bench_fp8(
return timers
-def bench(dtype: torch.dtype,
- m: int,
- k: int,
- n: int,
- label: str,
- sub_label: str,
- bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
+def bench(
+ dtype: torch.dtype,
+ m: int,
+ k: int,
+ n: int,
+ label: str,
+ sub_label: str,
+ bench_kernels: Optional[list[str]] = None,
+) -> Iterable[TMeasurement]:
if dtype == torch.int8:
return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels)
if dtype == torch.float8_e4m3fn:
@@ -195,27 +202,33 @@ def print_timers(timers: Iterable[TMeasurement]):
compare.print()
-def run(dtype: torch.dtype,
- MKNs: Iterable[tuple[int, int, int]],
- bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
+def run(
+ dtype: torch.dtype,
+ MKNs: Iterable[tuple[int, int, int]],
+ bench_kernels: Optional[list[str]] = None,
+) -> Iterable[TMeasurement]:
results = []
for m, k, n in MKNs:
- timers = bench(dtype,
- m,
- k,
- n,
- f"scaled-{dtype}-gemm",
- f"MKN=({m}x{k}x{n})",
- bench_kernels=bench_kernels)
+ timers = bench(
+ dtype,
+ m,
+ k,
+ n,
+ f"scaled-{dtype}-gemm",
+ f"MKN=({m}x{k}x{n})",
+ bench_kernels=bench_kernels,
+ )
print_timers(timers)
results.extend(timers)
return results
-def make_output(data: Iterable[TMeasurement],
- MKNs: Iterable[tuple[int, int, int]],
- base_description: str,
- timestamp=None):
+def make_output(
+ data: Iterable[TMeasurement],
+ MKNs: Iterable[tuple[int, int, int]],
+ base_description: str,
+ timestamp=None,
+):
print(f"== All Results {base_description} ====")
print_timers(data)
@@ -226,8 +239,7 @@ def make_output(data: Iterable[TMeasurement],
def run_square_bench(args):
- dim_sizes = list(
- range(args.dim_start, args.dim_end + 1, args.dim_increment))
+ dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
data = run(args.dtype, MKNs, bench_kernels=args.kernels)
make_output(data, MKNs, f"square_bench-{args.dtype}")
@@ -285,7 +297,7 @@ def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]:
pkl.dump(all_data, f)
-if __name__ == '__main__':
+if __name__ == "__main__":
def to_torch_dtype(dt):
if dt == "int8":
@@ -310,19 +322,21 @@ def to_torch_dtype(dt):
Output:
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
""", # noqa: E501
- formatter_class=argparse.RawTextHelpFormatter)
+ formatter_class=argparse.RawTextHelpFormatter,
+ )
- parser.add_argument("--dtype",
- type=to_torch_dtype,
- required=True,
- help="Available options are ['int8', 'fp8']")
+ parser.add_argument(
+ "--dtype",
+ type=to_torch_dtype,
+ required=True,
+ help="Available options are ['int8', 'fp8']",
+ )
parser.add_argument(
"--kernels",
nargs="+",
type=str,
default=None,
- help=
- "Exact names of the kernels to benchmark. If not set, runs all kernels."
+ help="Exact names of the kernels to benchmark. If not set, runs all kernels.",
)
subparsers = parser.add_subparsers(dest="cmd")
@@ -343,19 +357,19 @@ def to_torch_dtype(dt):
range_parser.set_defaults(func=run_range_bench)
model_parser = subparsers.add_parser("model_bench")
- model_parser.add_argument("--models",
- nargs="+",
- type=str,
- default=DEFAULT_MODELS,
- choices=WEIGHT_SHAPES.keys())
- model_parser.add_argument("--tp-sizes",
- nargs="+",
- type=int,
- default=DEFAULT_TP_SIZES)
- model_parser.add_argument("--batch-sizes",
- nargs="+",
- type=int,
- default=DEFAULT_BATCH_SIZES)
+ model_parser.add_argument(
+ "--models",
+ nargs="+",
+ type=str,
+ default=DEFAULT_MODELS,
+ choices=WEIGHT_SHAPES.keys(),
+ )
+ model_parser.add_argument(
+ "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
+ )
+ model_parser.add_argument(
+ "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
+ )
model_parser.set_defaults(func=run_model_bench)
args = parser.parse_args()
diff --git a/benchmarks/cutlass_benchmarks/weight_shapes.py b/benchmarks/cutlass_benchmarks/weight_shapes.py
index 3d1121df40d0..d31b623a1ee6 100644
--- a/benchmarks/cutlass_benchmarks/weight_shapes.py
+++ b/benchmarks/cutlass_benchmarks/weight_shapes.py
@@ -42,4 +42,4 @@
([8192, 57344], 1),
([28672, 8192], 0),
],
-}
\ No newline at end of file
+}
diff --git a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py
index 980e68668911..fce156e1c96c 100644
--- a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py
+++ b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py
@@ -12,39 +12,37 @@
async def forward_request(url, data):
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
- headers = {
- "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
- }
- async with session.post(url=url, json=data,
- headers=headers) as response:
+ headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
+ async with session.post(url=url, json=data, headers=headers) as response:
if response.status == 200:
# if response.headers.get('Transfer-Encoding') == 'chunked':
if True:
- async for chunk_bytes in response.content.iter_chunked(
- 1024):
+ async for chunk_bytes in response.content.iter_chunked(1024):
yield chunk_bytes
else:
content = await response.read()
yield content
-@app.route('/v1/completions', methods=['POST'])
+@app.route("/v1/completions", methods=["POST"])
async def handle_request():
try:
original_request_data = await request.get_json()
prefill_request = original_request_data.copy()
# change max_tokens = 1 to let it only do prefill
- prefill_request['max_tokens'] = 1
+ prefill_request["max_tokens"] = 1
# finish prefill
- async for _ in forward_request('http://localhost:8100/v1/completions',
- prefill_request):
+ async for _ in forward_request(
+ "http://localhost:8100/v1/completions", prefill_request
+ ):
continue
# return decode
- generator = forward_request('http://localhost:8200/v1/completions',
- original_request_data)
+ generator = forward_request(
+ "http://localhost:8200/v1/completions", original_request_data
+ )
response = await make_response(generator)
response.timeout = None
@@ -53,11 +51,12 @@ async def handle_request():
except Exception as e:
import sys
import traceback
+
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server")
print(e)
print("".join(traceback.format_exception(*exc_info)))
-if __name__ == '__main__':
+if __name__ == "__main__":
app.run(port=8000)
diff --git a/benchmarks/disagg_benchmarks/round_robin_proxy.py b/benchmarks/disagg_benchmarks/round_robin_proxy.py
index c2ad4916bf07..fd19b40bf252 100644
--- a/benchmarks/disagg_benchmarks/round_robin_proxy.py
+++ b/benchmarks/disagg_benchmarks/round_robin_proxy.py
@@ -8,7 +8,6 @@
class RoundRobinProxy:
-
def __init__(self, target_ports):
self.target_ports = target_ports
self.port_cycle = itertools.cycle(self.target_ports)
@@ -21,14 +20,15 @@ async def handle_request(self, request):
try:
# Forward the request
async with session.request(
- method=request.method,
- url=target_url,
- headers=request.headers,
- data=request.content,
+ method=request.method,
+ url=target_url,
+ headers=request.headers,
+ data=request.content,
) as response:
# Start sending the response
- resp = web.StreamResponse(status=response.status,
- headers=response.headers)
+ resp = web.StreamResponse(
+ status=response.status, headers=response.headers
+ )
await resp.prepare(request)
# Stream the response content
@@ -45,11 +45,11 @@ async def handle_request(self, request):
async def main():
proxy = RoundRobinProxy([8100, 8200])
app = web.Application()
- app.router.add_route('*', '/{path:.*}', proxy.handle_request)
+ app.router.add_route("*", "/{path:.*}", proxy.handle_request)
runner = web.AppRunner(app)
await runner.setup()
- site = web.TCPSite(runner, 'localhost', 8000)
+ site = web.TCPSite(runner, "localhost", 8000)
await site.start()
print("Proxy server started on http://localhost:8000")
@@ -58,5 +58,5 @@ async def main():
await asyncio.Event().wait()
-if __name__ == '__main__':
+if __name__ == "__main__":
asyncio.run(main())
diff --git a/benchmarks/disagg_benchmarks/visualize_benchmark_results.py b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py
index a7b4b9e8bf30..484d0cb3cba7 100644
--- a/benchmarks/disagg_benchmarks/visualize_benchmark_results.py
+++ b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py
@@ -6,43 +6,41 @@
import pandas as pd
if __name__ == "__main__":
-
data = []
- for name in ['disagg_prefill', 'chunked_prefill']:
+ for name in ["disagg_prefill", "chunked_prefill"]:
for qps in [2, 4, 6, 8]:
with open(f"results/{name}-qps-{qps}.json") as f:
x = json.load(f)
- x['name'] = name
- x['qps'] = qps
+ x["name"] = name
+ x["qps"] = qps
data.append(x)
df = pd.DataFrame.from_dict(data)
- dis_df = df[df['name'] == 'disagg_prefill']
- chu_df = df[df['name'] == 'chunked_prefill']
+ dis_df = df[df["name"] == "disagg_prefill"]
+ chu_df = df[df["name"] == "chunked_prefill"]
- plt.style.use('bmh')
- plt.rcParams['font.size'] = 20
+ plt.style.use("bmh")
+ plt.rcParams["font.size"] = 20
for key in [
- 'mean_ttft_ms', 'median_ttft_ms', 'p99_ttft_ms', 'mean_itl_ms',
- 'median_itl_ms', 'p99_itl_ms'
+ "mean_ttft_ms",
+ "median_ttft_ms",
+ "p99_ttft_ms",
+ "mean_itl_ms",
+ "median_itl_ms",
+ "p99_itl_ms",
]:
-
fig, ax = plt.subplots(figsize=(11, 7))
- plt.plot(dis_df['qps'],
- dis_df[key],
- label='disagg_prefill',
- marker='o',
- linewidth=4)
- plt.plot(chu_df['qps'],
- chu_df[key],
- label='chunked_prefill',
- marker='o',
- linewidth=4)
+ plt.plot(
+ dis_df["qps"], dis_df[key], label="disagg_prefill", marker="o", linewidth=4
+ )
+ plt.plot(
+ chu_df["qps"], chu_df[key], label="chunked_prefill", marker="o", linewidth=4
+ )
ax.legend()
- ax.set_xlabel('QPS')
+ ax.set_xlabel("QPS")
ax.set_ylabel(key)
ax.set_ylim(bottom=0)
- fig.savefig(f'results/{key}.png')
+ fig.savefig(f"results/{key}.png")
plt.close(fig)
diff --git a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py
index 3da583a33448..37a9173a1a93 100644
--- a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py
+++ b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py
@@ -24,10 +24,12 @@ class bench_params_t:
dtype: torch.dtype
def description(self):
- return (f'N {self.num_tokens} '
- f'x D {self.hidden_size} '
- f'x R {self.add_residual} '
- f'x DT {self.dtype}')
+ return (
+ f"N {self.num_tokens} "
+ f"x D {self.hidden_size} "
+ f"x R {self.add_residual} "
+ f"x DT {self.dtype}"
+ )
def get_bench_params() -> list[bench_params_t]:
@@ -38,15 +40,19 @@ def get_bench_params() -> list[bench_params_t]:
DTYPES = [torch.bfloat16, torch.float]
combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES)
- bench_params = list(map(lambda x: \
- bench_params_t(x[0], x[1], x[2], x[3]), combinations))
+ bench_params = list(
+ map(lambda x: bench_params_t(x[0], x[1], x[2], x[3]), combinations)
+ )
return bench_params
# Reference impls
-def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
- residual: Optional[torch.Tensor],
- quant_dtype: torch.dtype):
+def unfused_int8_impl(
+ rms_norm_layer: RMSNorm,
+ x: torch.Tensor,
+ residual: Optional[torch.Tensor],
+ quant_dtype: torch.dtype,
+):
# Norm
torch_out = None
if residual is None:
@@ -58,9 +64,12 @@ def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
torch_out, _, _ = ops.scaled_int8_quant(torch_out)
-def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
- residual: Optional[torch.Tensor],
- quant_dtype: torch.dtype):
+def unfused_fp8_impl(
+ rms_norm_layer: RMSNorm,
+ x: torch.Tensor,
+ residual: Optional[torch.Tensor],
+ quant_dtype: torch.dtype,
+):
# Norm
torch_out = None
if residual is None:
@@ -73,22 +82,27 @@ def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
def fused_impl(
- rms_norm_layer: RMSNorm, # this stores the weights
- x: torch.Tensor,
- residual: Optional[torch.Tensor],
- quant_dtype: torch.dtype):
- out, _ = ops.rms_norm_dynamic_per_token_quant(x,
- rms_norm_layer.weight,
- 1e-6,
- quant_dtype,
- residual=residual)
+ rms_norm_layer: RMSNorm, # this stores the weights
+ x: torch.Tensor,
+ residual: Optional[torch.Tensor],
+ quant_dtype: torch.dtype,
+):
+ out, _ = ops.rms_norm_dynamic_per_token_quant(
+ x, rms_norm_layer.weight, 1e-6, quant_dtype, residual=residual
+ )
# Bench functions
-def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor,
- quant_dtype: torch.dtype, label: str, sub_label: str,
- fn: Callable, description: str) -> TMeasurement:
-
+def bench_fn(
+ rms_norm_layer: RMSNorm,
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ quant_dtype: torch.dtype,
+ label: str,
+ sub_label: str,
+ fn: Callable,
+ description: str,
+) -> TMeasurement:
min_run_time = 1
globals = {
@@ -106,43 +120,81 @@ def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor,
description=description,
).blocked_autorange(min_run_time=min_run_time)
-def bench(params: bench_params_t, label: str, sub_label: str) \
- -> Iterable[TMeasurement]:
+def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasurement]:
# Make inputs
layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype)
# Make weights
layer.weight.data.normal_(mean=1.0, std=0.1)
# Make inputs
scale = 1 / params.hidden_size
- x = torch.randn(params.num_tokens,
- params.hidden_size,
- dtype=params.dtype,
- device='cuda') * scale
- residual = (torch.randn_like(x) * scale).to(device='cuda') \
- if params.add_residual else None
+ x = (
+ torch.randn(
+ params.num_tokens, params.hidden_size, dtype=params.dtype, device="cuda"
+ )
+ * scale
+ )
+ residual = (
+ (torch.randn_like(x) * scale).to(device="cuda") if params.add_residual else None
+ )
timers = []
# unfused int8 impl.
timers.append(
- bench_fn(layer, x, residual, torch.int8, label, sub_label,
- unfused_int8_impl, "unfused_int8_impl"))
+ bench_fn(
+ layer,
+ x,
+ residual,
+ torch.int8,
+ label,
+ sub_label,
+ unfused_int8_impl,
+ "unfused_int8_impl",
+ )
+ )
# unfused fp8 impl.
timers.append(
- bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label,
- unfused_fp8_impl, "unfused_fp8_impl"))
+ bench_fn(
+ layer,
+ x,
+ residual,
+ torch.float8_e4m3fn,
+ label,
+ sub_label,
+ unfused_fp8_impl,
+ "unfused_fp8_impl",
+ )
+ )
# fused int8 impl.
timers.append(
- bench_fn(layer, x, residual, torch.int8, label, sub_label, fused_impl,
- "fused_int8_impl"))
+ bench_fn(
+ layer,
+ x,
+ residual,
+ torch.int8,
+ label,
+ sub_label,
+ fused_impl,
+ "fused_int8_impl",
+ )
+ )
# fused fp8 impl.
timers.append(
- bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label,
- fused_impl, "fused_fp8_impl"))
+ bench_fn(
+ layer,
+ x,
+ residual,
+ torch.float8_e4m3fn,
+ label,
+ sub_label,
+ fused_impl,
+ "fused_fp8_impl",
+ )
+ )
print_timers(timers)
@@ -157,13 +209,12 @@ def print_timers(timers: Iterable[TMeasurement]):
def main():
- torch.set_default_device('cuda')
+ torch.set_default_device("cuda")
bench_params = get_bench_params()
timers = []
for bp in tqdm(bench_params):
- timers.extend(
- bench(bp, "rms-norm-dynamic-per-token-quant", bp.description()))
+ timers.extend(bench(bp, "rms-norm-dynamic-per-token-quant", bp.description()))
print_timers(timers)
# pickle all the results
@@ -172,5 +223,5 @@ def main():
pkl.dump(timers, f)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/benchmarks/kernels/benchmark_aqlm.py b/benchmarks/kernels/benchmark_aqlm.py
index 8d20b91560dd..e9934aa479dd 100644
--- a/benchmarks/kernels/benchmark_aqlm.py
+++ b/benchmarks/kernels/benchmark_aqlm.py
@@ -9,32 +9,39 @@
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.aqlm import (
- dequantize_weight, generic_dequantize_gemm, get_int_dtype,
- optimized_dequantize_gemm)
+ dequantize_weight,
+ generic_dequantize_gemm,
+ get_int_dtype,
+ optimized_dequantize_gemm,
+)
from vllm.utils import FlexibleArgumentParser
-os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
def torch_mult(
- input: torch.Tensor, # [..., in_features]
- weights: torch.Tensor,
- scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
+ # [..., in_features]
+ input: torch.Tensor,
+ weights: torch.Tensor,
+ # [num_out_groups, 1, 1, 1]
+ scales: torch.Tensor,
) -> torch.Tensor:
output = F.linear(input, weights)
return output
def dequant_out_scale(
- input: torch.Tensor, # [..., in_features]
- codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
- codebooks: torch.
- Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
- scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
+ # [..., in_features]
+ input: torch.Tensor,
+ # [num_out_groups, num_in_groups, num_codebooks]
+ codes: torch.IntTensor,
+ # [num_codebooks, codebook_size, out_group_size, in_group_size]
+ codebooks: torch.Tensor,
+ # [num_out_groups, 1, 1, 1]
+ scales: torch.Tensor,
output_partition_sizes: torch.IntTensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
-
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
if bias is None:
@@ -46,40 +53,42 @@ def dequant_out_scale(
flattened_output *= b_scales
return flattened_output.view(orig_shape)
else:
- b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
- -1, weights.shape[1])
+ b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1])
weights *= b_scales
return F.linear(input, weights, bias)
def dequant_weight_scale(
- input: torch.Tensor, # [..., in_features]
- codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
- codebooks: torch.
- Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
- scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
+ # [..., in_features]
+ input: torch.Tensor,
+ # [num_out_groups, num_in_groups, num_codebooks]
+ codes: torch.IntTensor,
+ # [num_codebooks, codebook_size, out_group_size, in_group_size]
+ codebooks: torch.Tensor,
+ # [num_out_groups, 1, 1, 1]
+ scales: torch.Tensor,
output_partition_sizes: torch.IntTensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
-
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
- b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
- -1, weights.shape[1])
+ b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1])
weights *= b_scales
return F.linear(input, weights, bias)
def dequant_no_scale(
- input: torch.Tensor, # [..., in_features]
- codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
- codebooks: torch.
- Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
- scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
+ # [..., in_features]
+ input: torch.Tensor,
+ # [num_out_groups, num_in_groups, num_codebooks]
+ codes: torch.IntTensor,
+ # [num_codebooks, codebook_size, out_group_size, in_group_size]
+ codebooks: torch.Tensor,
+ # [num_out_groups, 1, 1, 1]
+ scales: torch.Tensor,
output_partition_sizes: torch.IntTensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
-
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
return F.linear(input, weights, bias)
@@ -89,23 +98,26 @@ def dequant_no_scale(
# the generic pytorch version.
# Just visual comparison.
def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:
-
n = int(parts.sum().item())
- device = torch.device('cuda:0')
+ device = torch.device("cuda:0")
code_range = (1 << bits) // 2
ingroups = 8
- codes = torch.randint(-code_range,
- code_range,
- size=(n, k // ingroups, nbooks),
- dtype=get_int_dtype(bits),
- device=device)
+ codes = torch.randint(
+ -code_range,
+ code_range,
+ size=(n, k // ingroups, nbooks),
+ dtype=get_int_dtype(bits),
+ device=device,
+ )
- codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
- dtype=torch.float16,
- device=device)
+ codebooks = torch.randn(
+ size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
+ dtype=torch.float16,
+ device=device,
+ )
count = 0
for index in range(16):
@@ -138,24 +150,25 @@ def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:
def main():
-
parser = FlexibleArgumentParser(description="Benchmark aqlm performance.")
# Add arguments
- parser.add_argument("--nbooks",
- type=int,
- default=1,
- help="Number of codebooks (default: 1)")
- parser.add_argument("--bits",
- type=int,
- default=16,
- help="Number of bits per code element (default: 16)")
+ parser.add_argument(
+ "--nbooks", type=int, default=1, help="Number of codebooks (default: 1)"
+ )
+ parser.add_argument(
+ "--bits",
+ type=int,
+ default=16,
+ help="Number of bits per code element (default: 16)",
+ )
parser.add_argument(
"--test",
type=bool,
default=False,
help="Run the decompression/dequant tester rather than benchmarking "
- "(default: False)")
+ "(default: False)",
+ )
# Parse the arguments
args = parser.parse_args()
@@ -165,7 +178,7 @@ def main():
bits = args.bits
if args.test:
- dequant_test(4096, torch.tensor((4096, )), nbooks, bits)
+ dequant_test(4096, torch.tensor((4096,)), nbooks, bits)
return
# Otherwise, benchmark.
@@ -184,31 +197,54 @@ def main():
with open(filename, "w") as f:
sys.stdout = f
- print('m | k | n | n parts', end='')
+ print("m | k | n | n parts", end="")
for method in methods:
- print(f" | {method.__name__.replace('_', ' ')} (µs)", end='')
- print('')
+ print(f" | {method.__name__.replace('_', ' ')} (µs)", end="")
+ print("")
# These are reasonable prefill sizes.
- ksandpartions = ((4096, (4096, 4096, 4096)), (4096, (4096, )),
- (4096, (11008, 11008)), (11008, (4096, )))
+ ksandpartions = (
+ (4096, (4096, 4096, 4096)),
+ (4096, (4096,)),
+ (4096, (11008, 11008)),
+ (11008, (4096,)),
+ )
# reasonable ranges for m.
for m in [
- 1, 2, 4, 8, 10, 12, 14, 16, 24, 32, 48, 52, 56, 64, 96, 112,
- 128, 256, 512, 1024, 1536, 2048, 3072, 4096
+ 1,
+ 2,
+ 4,
+ 8,
+ 10,
+ 12,
+ 14,
+ 16,
+ 24,
+ 32,
+ 48,
+ 52,
+ 56,
+ 64,
+ 96,
+ 112,
+ 128,
+ 256,
+ 512,
+ 1024,
+ 1536,
+ 2048,
+ 3072,
+ 4096,
]:
- print(f'{m}', file=sys.__stdout__)
+ print(f"{m}", file=sys.__stdout__)
for ksp in ksandpartions:
- run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits,
- methods)
+ run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits, methods)
sys.stdout = sys.__stdout__
-def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int,
- methods):
-
+def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, methods):
# I didn't see visible improvements from increasing these, but feel free :)
num_warmup_trials = 1
num_trials = 1
@@ -229,7 +265,7 @@ def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int,
)
n = parts.sum().item()
- print(f'{m} | {k} | {n} | {parts.tolist()}', end='')
+ print(f"{m} | {k} | {n} | {parts.tolist()}", end="")
for method in methods:
best_time_us = 1e20
@@ -249,32 +285,36 @@ def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int,
if kernel_dur_us < best_time_us:
best_time_us = kernel_dur_us
- print(f' | {kernel_dur_us:.0f}', end='')
+ print(f" | {kernel_dur_us:.0f}", end="")
- print('')
+ print("")
-def run_timing(num_calls: int, m: int, k: int, parts: torch.Tensor,
- nbooks: int, bits: int, method) -> float:
-
+def run_timing(
+ num_calls: int, m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, method
+) -> float:
n = int(parts.sum().item())
- device = torch.device('cuda:0')
+ device = torch.device("cuda:0")
input = torch.randn((1, m, k), dtype=torch.float16, device=device)
code_range = (1 << bits) // 2
ingroups = 8
- codes = torch.randint(-code_range,
- code_range,
- size=(n, k // ingroups, nbooks),
- dtype=get_int_dtype(bits),
- device=device)
-
- codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
- dtype=torch.float16,
- device=device)
+ codes = torch.randint(
+ -code_range,
+ code_range,
+ size=(n, k // ingroups, nbooks),
+ dtype=get_int_dtype(bits),
+ device=device,
+ )
+
+ codebooks = torch.randn(
+ size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
+ dtype=torch.float16,
+ device=device,
+ )
scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device)
diff --git a/benchmarks/kernels/benchmark_bitblas.py b/benchmarks/kernels/benchmark_bitblas.py
index b23b4f3ea685..d40ab70ec539 100644
--- a/benchmarks/kernels/benchmark_bitblas.py
+++ b/benchmarks/kernels/benchmark_bitblas.py
@@ -3,27 +3,33 @@
# Licensed under the MIT License.
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
- MINIMUM_BITBLAS_VERSION)
+ MINIMUM_BITBLAS_VERSION,
+)
try:
import bitblas
+
if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
- raise ImportError("bitblas version is wrong. Please "
- f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
+ raise ImportError(
+ "bitblas version is wrong. Please "
+ f"install bitblas>={MINIMUM_BITBLAS_VERSION}"
+ )
except ImportError as e:
bitblas_import_exception = e
- raise ValueError("Trying to use the bitblas backend, but could not import"
- f"with the following error: {bitblas_import_exception}. "
- "Please install bitblas through the following command: "
- f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
- ) from bitblas_import_exception
+ raise ValueError(
+ "Trying to use the bitblas backend, but could not import"
+ f"with the following error: {bitblas_import_exception}. "
+ "Please install bitblas through the following command: "
+ f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
+ ) from bitblas_import_exception
from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target
from vllm.utils import FlexibleArgumentParser
parser = FlexibleArgumentParser(
- description="Benchmark BitBLAS int4 on a specific target.")
+ description="Benchmark BitBLAS int4 on a specific target."
+)
# Add arguments to the parser
parser.add_argument(
@@ -32,10 +38,9 @@
default=auto_detect_nvidia_target(),
help="Specify the target device for benchmarking.",
)
-parser.add_argument("--group_size",
- type=int,
- default=None,
- help="Group size for grouped quantization.")
+parser.add_argument(
+ "--group_size", type=int, default=None, help="Group size for grouped quantization."
+)
parser.add_argument(
"--A_dtype",
type=str,
@@ -82,17 +87,17 @@
choices=["nt", "nn"],
help="Matrix layout, 'nt' for non-transpose A and transpose W.",
)
-parser.add_argument("--with_bias",
- action="store_true",
- help="Include bias in the benchmark.")
+parser.add_argument(
+ "--with_bias", action="store_true", help="Include bias in the benchmark."
+)
parser.add_argument(
"--with_scaling",
action="store_true",
help="Include scaling factor in the quantization.",
)
-parser.add_argument("--with_zeros",
- action="store_true",
- help="Include zeros in the quantization.")
+parser.add_argument(
+ "--with_zeros", action="store_true", help="Include zeros in the quantization."
+)
parser.add_argument(
"--zeros_mode",
type=str,
@@ -170,8 +175,7 @@
]
# Build test shapes with all the shared arguments
-test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args))
- for shape in shapes]
+test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args)) for shape in shapes]
benchmark_sets = []
benchmark_sets.extend(test_shapes)
@@ -206,12 +210,12 @@
func_name = args_split[0]
input_args_str = "-".join(args_split[1:])
col_widths[0] = max(col_widths[0], len(func_name) + 2, len(headers[0]) + 2)
- col_widths[1] = max(col_widths[1],
- len(input_args_str) + 2,
- len(headers[1]) + 2)
- col_widths[2] = max(col_widths[2],
- len(f"{values['BitBLAS_top20_latency']:.3f} ms") + 2,
- len(headers[2]) + 2)
+ col_widths[1] = max(col_widths[1], len(input_args_str) + 2, len(headers[1]) + 2)
+ col_widths[2] = max(
+ col_widths[2],
+ len(f"{values['BitBLAS_top20_latency']:.3f} ms") + 2,
+ len(headers[2]) + 2,
+ )
# break only if you want to measure widths from a single example;
# otherwise, let it loop over all items.
@@ -232,5 +236,6 @@
f"{values['BitBLAS_top20_latency']:.3f} ms",
]
row_str = "".join(
- [str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)])
+ [str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)]
+ )
print(row_str)
diff --git a/benchmarks/kernels/benchmark_cutlass_fp4_moe.py b/benchmarks/kernels/benchmark_cutlass_fp4_moe.py
new file mode 100644
index 000000000000..d39d8a6e3aba
--- /dev/null
+++ b/benchmarks/kernels/benchmark_cutlass_fp4_moe.py
@@ -0,0 +1,489 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Benchmark the performance of the cutlass_moe_fp4 kernel vs the triton_moe
+kernel. The cutlass_moe_fp4 kernel takes in fp4 quantized weights and 16-bit
+activations. The triton_moe kernel takes in fp8 weights(tensor scaled to fp8)
+and 16-bit activations.
+"""
+
+import nvtx
+import torch
+import torch.utils.benchmark as benchmark
+
+from vllm import _custom_ops as ops
+from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
+from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
+from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
+from vllm.scalar_type import scalar_types
+from vllm.utils import FlexibleArgumentParser
+
+WEIGHT_SHAPES_MOE = {
+ "nvidia/DeepSeek-R1-FP4": [
+ [256, 8, 2048, 7168],
+ ],
+}
+
+DEFAULT_MODELS = [
+ "nvidia/DeepSeek-R1-FP4",
+]
+
+DEFAULT_BATCH_SIZES = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
+DEFAULT_TP_SIZES = [1]
+
+PER_ACT_TOKEN_OPTS = [False]
+PER_OUT_CH_OPTS = [False]
+FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
+FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
+
+
+def to_fp8(tensor: torch.Tensor):
+ finfo = torch.finfo(torch.float8_e4m3fn)
+ return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
+ dtype=torch.float8_e4m3fn
+ )
+
+
+def bench_run(
+ results: list[benchmark.Measurement],
+ model: str,
+ num_experts: int,
+ topk: int,
+ per_act_token: bool,
+ per_out_ch: bool,
+ mkn: tuple[int, int, int],
+):
+ label = "NVFP4 Blockscaled CUTLASS MOE vs FP8 Tensor Scaled Triton"
+
+ sub_label = (
+ "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})".format(
+ model, num_experts, topk, per_act_token, per_out_ch, mkn
+ )
+ )
+
+ print(f"Testing: {sub_label}")
+
+ (m, k, n) = mkn
+
+ dtype = torch.half
+ device = "cuda"
+ a = torch.randn((m, k), device=device, dtype=dtype) / 10
+ w1 = torch.randn((num_experts, 2 * n, k), device=device, dtype=dtype) / 10
+ w2 = torch.randn((num_experts, k, n), device=device, dtype=dtype) / 10
+
+ _, a_fp8_scale = ops.scaled_fp8_quant(a)
+
+ w1_fp8q = torch.empty(
+ (num_experts, 2 * n, k), device=device, dtype=torch.float8_e4m3fn
+ )
+ w2_fp8q = torch.empty((num_experts, k, n), device=device, dtype=torch.float8_e4m3fn)
+ w1_fp8scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)
+ w2_fp8scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)
+
+ for expert in range(num_experts):
+ w1_fp8q[expert], w1_fp8scale[expert] = ops.scaled_fp8_quant(w1[expert])
+ w2_fp8q[expert], w2_fp8scale[expert] = ops.scaled_fp8_quant(w2[expert])
+
+ w1_fp8q_notransp = w1_fp8q.clone()
+ w2_fp8q_notransp = w2_fp8q.clone()
+ w1_fp8q = w1_fp8q.transpose(1, 2)
+ w2_fp8q = w2_fp8q.transpose(1, 2)
+
+ score = torch.randn((m, num_experts), device=device, dtype=dtype)
+
+ topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
+
+ quant_blocksize = 16
+ w1_blockscale = torch.empty(
+ (num_experts, 2 * n, k // quant_blocksize),
+ device=device,
+ dtype=torch.float8_e4m3fn,
+ )
+ w2_blockscale = torch.empty(
+ (num_experts, k, n // quant_blocksize), device=device, dtype=torch.float8_e4m3fn
+ )
+
+ # n_b_scales = 2 * n if per_out_ch else 1
+ # k_b_scales = k if per_out_ch else 1
+ w1_fp4 = torch.empty((num_experts, 2 * n, k // 2), device=device, dtype=torch.uint8)
+ w2_fp4 = torch.empty((num_experts, k, n // 2), device=device, dtype=torch.uint8)
+
+ w1_gs = torch.empty((num_experts,), device=device, dtype=torch.float32)
+ w2_gs = torch.empty((num_experts,), device=device, dtype=torch.float32)
+ a1_gs = torch.ones((num_experts,), device=device, dtype=torch.float32)
+ a2_gs = torch.ones((num_experts,), device=device, dtype=torch.float32)
+
+ for expert in range(num_experts):
+ w1_e = w1[expert]
+ w2_e = w2[expert]
+ w1_amax = torch.abs(w1_e).max().to(torch.float32)
+ w2_amax = torch.abs(w2_e).max().to(torch.float32)
+ w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
+ w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
+
+ w1_fp4[expert], w1_blockscale[expert] = ops.scaled_fp4_quant(
+ w1_e, w1_gs[expert]
+ )
+
+ w2_fp4[expert], w2_blockscale[expert] = ops.scaled_fp4_quant(
+ w2_e, w2_gs[expert]
+ )
+
+ def run_triton_moe(
+ a: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_scale: torch.Tensor,
+ w2_scale: torch.Tensor,
+ a_fp8_scale: torch.Tensor,
+ num_repeats: int,
+ ):
+ for _ in range(num_repeats):
+ fused_experts(
+ a,
+ w1,
+ w2,
+ topk_weights,
+ topk_ids,
+ use_fp8_w8a8=True,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ a1_scale=a_fp8_scale,
+ )
+
+ def run_cutlass_moe_fp4(
+ a: torch.Tensor,
+ w1_fp4: torch.Tensor,
+ w2_fp4: torch.Tensor,
+ w1_blockscale: torch.Tensor,
+ w2_blockscale: torch.Tensor,
+ w1_gs: torch.Tensor,
+ w2_gs: torch.Tensor,
+ a1_gs: torch.Tensor,
+ a2_gs: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ m: int,
+ n: int,
+ k: int,
+ e: int,
+ device: torch.device,
+ num_repeats: int,
+ ):
+ for _ in range(num_repeats):
+ with nvtx.annotate("cutlass_moe_fp4", color="green"):
+ cutlass_moe_fp4(
+ a=a,
+ a1_gscale=a1_gs,
+ a2_gscale=a2_gs,
+ w1_fp4=w1_fp4,
+ w1_blockscale=w1_blockscale,
+ w1_alphas=w1_gs,
+ w2_fp4=w2_fp4,
+ w2_blockscale=w2_blockscale,
+ w2_alphas=w2_gs,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ m=m,
+ n=n,
+ k=k,
+ e=num_experts,
+ device=device,
+ )
+
+ def run_cutlass_from_graph(
+ a: torch.Tensor,
+ a1_gscale: torch.Tensor,
+ w1_fp4: torch.Tensor,
+ w1_blockscale: torch.Tensor,
+ w1_alphas: torch.Tensor,
+ a2_gscale: torch.Tensor,
+ w2_fp4: torch.Tensor,
+ w2_blockscale: torch.Tensor,
+ w2_alphas: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ m: int,
+ n: int,
+ k: int,
+ e: int,
+ device: torch.device,
+ ):
+ with set_current_vllm_config(
+ VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
+ ):
+ return cutlass_moe_fp4(
+ a=a,
+ a1_gscale=a1_gs,
+ w1_fp4=w1_fp4,
+ w1_blockscale=w1_blockscale,
+ w1_alphas=w1_alphas,
+ a2_gscale=a2_gs,
+ w2_fp4=w2_fp4,
+ w2_blockscale=w2_blockscale,
+ w2_alphas=w2_alphas,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ m=m,
+ n=n,
+ k=k,
+ e=num_experts,
+ device=device,
+ )
+
+ def run_triton_from_graph(
+ a: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_scale: torch.Tensor,
+ w2_scale: torch.Tensor,
+ a_fp8_scale: torch.Tensor,
+ ):
+ with set_current_vllm_config(
+ VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
+ ):
+ return fused_experts(
+ a,
+ w1,
+ w2,
+ topk_weights,
+ topk_ids,
+ use_fp8_w8a8=True,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ a1_scale=a_fp8_scale,
+ )
+
+ def replay_graph(graph, num_repeats):
+ for _ in range(num_repeats):
+ graph.replay()
+ torch.cuda.synchronize()
+
+ cutlass_stream = torch.cuda.Stream()
+ cutlass_graph = torch.cuda.CUDAGraph()
+ with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
+ run_cutlass_from_graph(
+ a=a,
+ a1_gscale=a1_gs,
+ w1_fp4=w1_fp4,
+ w1_blockscale=w1_blockscale,
+ w1_alphas=w1_gs,
+ a2_gscale=a2_gs,
+ w2_fp4=w2_fp4,
+ w2_blockscale=w2_blockscale,
+ w2_alphas=w2_gs,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ m=m,
+ n=n,
+ k=k,
+ e=num_experts,
+ device=device,
+ )
+ torch.cuda.synchronize()
+
+ triton_stream = torch.cuda.Stream()
+ triton_graph = torch.cuda.CUDAGraph()
+ with torch.cuda.graph(triton_graph, stream=triton_stream):
+ run_triton_from_graph(
+ a,
+ w1_fp8q_notransp,
+ w2_fp8q_notransp,
+ topk_weights,
+ topk_ids,
+ w1_fp8scale,
+ w2_fp8scale,
+ a_fp8_scale,
+ )
+ torch.cuda.synchronize()
+
+ min_run_time = 5
+ num_warmup = 5
+ num_runs = 25
+
+ globals = {
+ # Baseline params
+ "w1": w1,
+ "w2": w2,
+ "score": score,
+ "topk": topk,
+ "w1_fp8q_notransp": w1_fp8q_notransp,
+ "w2_fp8q_notransp": w2_fp8q_notransp,
+ "w1_fp8scale": w1_fp8scale,
+ "w2_fp8scale": w2_fp8scale,
+ "a_fp8_scale": a_fp8_scale,
+ # Cutlass params
+ "a": a,
+ "a1_gscale": a1_gs,
+ "w1_fp4": w1_fp4,
+ "w1_blockscale": w1_blockscale,
+ "w1_alphas": w1_gs,
+ "a2_gscale": a2_gs,
+ "w2_fp4": w2_fp4,
+ "w2_blockscale": w2_blockscale,
+ "w2_alphas": w2_gs,
+ "topk_weights": topk_weights,
+ "topk_ids": topk_ids,
+ "m": m,
+ "n": n,
+ "k": k,
+ "e": num_experts,
+ "device": device,
+ # cuda graph params
+ "cutlass_graph": cutlass_graph,
+ "triton_graph": triton_graph,
+ # Gen params
+ "num_runs": num_runs,
+ # Kernels
+ "run_triton_moe": run_triton_moe,
+ "run_cutlass_moe_fp4": run_cutlass_moe_fp4,
+ "replay_graph": replay_graph,
+ }
+
+ # Warmup
+ run_triton_moe(
+ a,
+ w1_fp8q_notransp,
+ w2_fp8q_notransp,
+ topk_weights,
+ topk_ids,
+ w1_fp8scale,
+ w2_fp8scale,
+ a_fp8_scale,
+ num_warmup,
+ )
+
+ results.append(
+ benchmark.Timer(
+ stmt="run_triton_moe(a, w1_fp8q_notransp, w2_fp8q_notransp, topk_weights, topk_ids, w1_fp8scale, w2_fp8scale, a_fp8_scale, num_runs)", # noqa: E501
+ globals=globals,
+ label=label,
+ sub_label=sub_label,
+ description="triton_moe",
+ ).blocked_autorange(min_run_time=min_run_time)
+ )
+
+ # Warmup
+ replay_graph(triton_graph, num_warmup)
+
+ results.append(
+ benchmark.Timer(
+ stmt="replay_graph(triton_graph, num_runs)",
+ globals=globals,
+ label=label,
+ sub_label=sub_label,
+ description="triton_moe_cuda_graphs",
+ ).blocked_autorange(min_run_time=min_run_time)
+ )
+
+ # Warmup
+
+ run_cutlass_moe_fp4(
+ a,
+ w1_fp4,
+ w2_fp4,
+ w1_blockscale,
+ w2_blockscale,
+ w1_gs,
+ w2_gs,
+ a1_gs,
+ a2_gs,
+ topk_weights,
+ topk_ids,
+ m,
+ n,
+ k,
+ num_experts,
+ device,
+ num_warmup,
+ )
+
+ results.append(
+ benchmark.Timer(
+ stmt="run_cutlass_moe_fp4(a, w1_fp4, w2_fp4, w1_blockscale, w2_blockscale, w1_alphas, w2_alphas, a1_gscale, a2_gscale, topk_weights, topk_ids, m, n, k, e, device, num_runs)", # noqa: E501
+ globals=globals,
+ label=label,
+ sub_label=sub_label,
+ description="cutlass_moe_fp4",
+ ).blocked_autorange(min_run_time=min_run_time)
+ )
+
+ # Warmup
+ replay_graph(cutlass_graph, num_warmup)
+
+ results.append(
+ benchmark.Timer(
+ stmt="replay_graph(cutlass_graph, num_runs)",
+ globals=globals,
+ label=label,
+ sub_label=sub_label,
+ description="cutlass_moe_fp4_cuda_graphs",
+ ).blocked_autorange(min_run_time=min_run_time)
+ )
+
+
+def main(args):
+ print("Benchmarking models:")
+ for i, model in enumerate(args.models):
+ print(f"[{i}] {model}")
+
+ results: list[benchmark.Measurement] = []
+
+ for model in args.models:
+ for tp in args.tp_sizes:
+ for layer in WEIGHT_SHAPES_MOE[model]:
+ num_experts = layer[0]
+ topk = layer[1]
+ size_k = layer[2]
+ size_n = layer[3] // tp
+
+ if len(args.limit_k) > 0 and size_k not in args.limit_k:
+ continue
+
+ if len(args.limit_n) > 0 and size_n not in args.limit_n:
+ continue
+
+ for per_act_token in PER_ACT_TOKEN_OPTS:
+ for per_out_ch in PER_OUT_CH_OPTS:
+ for size_m in args.batch_sizes:
+ mkn = (size_m, size_k, size_n)
+ bench_run(
+ results,
+ model,
+ num_experts,
+ topk,
+ per_act_token,
+ per_out_ch,
+ mkn,
+ )
+
+ compare = benchmark.Compare(results)
+ compare.print()
+
+
+if __name__ == "__main__":
+ parser = FlexibleArgumentParser(
+ description="Benchmark NVFP4 CUTLASS MOE across specified models/shapes/batches"
+ )
+ parser.add_argument(
+ "--models",
+ nargs="+",
+ type=str,
+ default=DEFAULT_MODELS,
+ choices=WEIGHT_SHAPES_MOE.keys(),
+ )
+ parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
+ parser.add_argument(
+ "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
+ )
+ parser.add_argument("--limit-k", nargs="+", type=int, default=[])
+ parser.add_argument("--limit-n", nargs="+", type=int, default=[])
+ parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
+ parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[])
+ parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])
+
+ args = parser.parse_args()
+ main(args)
diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
index c92ea43e8260..2197bceabe6c 100644
--- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
+++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
@@ -6,14 +6,18 @@
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
-from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8,
- fused_experts,
- fused_topk)
+from vllm.model_executor.layers.fused_moe.fused_moe import (
+ cutlass_moe_fp8,
+ fused_experts,
+ fused_topk,
+)
from vllm.utils import FlexibleArgumentParser
DEFAULT_MODELS = [
- "nm-testing/Mixtral-8x7B-Instruct-v0.1", "nm-testing/deepseekv2-lite",
- "ibm-granite/granite-3.0-1b-a400m", "ibm-granite/granite-3.0-3b-a800m"
+ "nm-testing/Mixtral-8x7B-Instruct-v0.1",
+ "nm-testing/deepseekv2-lite",
+ "ibm-granite/granite-3.0-1b-a400m",
+ "ibm-granite/granite-3.0-3b-a800m",
]
DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]
@@ -24,19 +28,27 @@
def to_fp8(tensor: torch.Tensor):
finfo = torch.finfo(torch.float8_e4m3fn)
- return torch.round(tensor.clamp(
- min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
+ return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
+ dtype=torch.float8_e4m3fn
+ )
-def bench_run(results: list[benchmark.Measurement], model: str,
- num_experts: int, topk: int, per_act_token: bool,
- per_out_ch: bool, mkn: tuple[int, int, int]):
+def bench_run(
+ results: list[benchmark.Measurement],
+ model: str,
+ num_experts: int,
+ topk: int,
+ per_act_token: bool,
+ per_out_ch: bool,
+ mkn: tuple[int, int, int],
+):
label = "Quant Matmul"
sub_label = (
- "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, "
- "MKN=({})".format(model, num_experts, topk, per_act_token, per_out_ch,
- mkn))
+ "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})".format(
+ model, num_experts, topk, per_act_token, per_out_ch, mkn
+ )
+ )
print(f"Testing: {sub_label}")
@@ -50,35 +62,17 @@ def bench_run(results: list[benchmark.Measurement], model: str,
_, a_scale = ops.scaled_fp8_quant(a)
- w1_q = torch.empty((num_experts, 2 * n, k),
- device="cuda",
- dtype=torch.float8_e4m3fn)
- w2_q = torch.empty((num_experts, k, n),
- device="cuda",
- dtype=torch.float8_e4m3fn)
- w1_scale = torch.empty((num_experts, 1, 1),
- device="cuda",
- dtype=torch.float32)
- w2_scale = torch.empty((num_experts, 1, 1),
- device="cuda",
- dtype=torch.float32)
-
- ab_strides1 = torch.full((num_experts, ),
- k,
- device="cuda",
- dtype=torch.int64)
- c_strides1 = torch.full((num_experts, ),
- 2 * n,
- device="cuda",
- dtype=torch.int64)
- ab_strides2 = torch.full((num_experts, ),
- n,
- device="cuda",
- dtype=torch.int64)
- c_strides2 = torch.full((num_experts, ),
- k,
- device="cuda",
- dtype=torch.int64)
+ w1_q = torch.empty(
+ (num_experts, 2 * n, k), device="cuda", dtype=torch.float8_e4m3fn
+ )
+ w2_q = torch.empty((num_experts, k, n), device="cuda", dtype=torch.float8_e4m3fn)
+ w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
+ w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
+
+ ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
+ c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
+ ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
+ c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
for expert in range(num_experts):
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
@@ -91,82 +85,120 @@ def bench_run(results: list[benchmark.Measurement], model: str,
score = torch.randn((m, num_experts), device="cuda", dtype=dtype)
topk_weights, topk_ids, token_expert_indices = fused_topk(
- a, score, topk, renormalize=False)
+ a, score, topk, renormalize=False
+ )
- def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
- topk_weights: torch.Tensor, topk_ids: torch.Tensor,
- w1_scale: torch.Tensor, w2_scale: torch.Tensor,
- a_scale: torch.Tensor, num_repeats: int):
+ def run_triton_moe(
+ a: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_scale: torch.Tensor,
+ w2_scale: torch.Tensor,
+ a_scale: torch.Tensor,
+ num_repeats: int,
+ ):
for _ in range(num_repeats):
- fused_experts(a,
- w1,
- w2,
- topk_weights,
- topk_ids,
- use_fp8_w8a8=True,
- w1_scale=w1_scale,
- w2_scale=w2_scale,
- a1_scale=a_scale)
-
- def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor,
- w1: torch.Tensor, w2: torch.Tensor,
- w1_scale: torch.Tensor, w2_scale: torch.Tensor,
- topk_weights: torch.Tensor, topk_ids: torch.Tensor,
- ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
- ab_strides2: torch.Tensor, c_strides2: torch.Tensor,
- num_repeats: int):
+ fused_experts(
+ a,
+ w1,
+ w2,
+ topk_weights,
+ topk_ids,
+ use_fp8_w8a8=True,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ a1_scale=a_scale,
+ )
+
+ def run_cutlass_moe(
+ a: torch.Tensor,
+ a_scale: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_scale: torch.Tensor,
+ w2_scale: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ ab_strides1: torch.Tensor,
+ c_strides1: torch.Tensor,
+ ab_strides2: torch.Tensor,
+ c_strides2: torch.Tensor,
+ num_repeats: int,
+ ):
for _ in range(num_repeats):
- cutlass_moe_fp8(a,
- w1,
- w2,
- w1_scale,
- w2_scale,
- topk_weights,
- topk_ids,
- ab_strides1,
- c_strides1,
- ab_strides2,
- c_strides2,
- a1_scale=a_scale)
+ cutlass_moe_fp8(
+ a,
+ w1,
+ w2,
+ w1_scale,
+ w2_scale,
+ topk_weights,
+ topk_ids,
+ ab_strides1,
+ c_strides1,
+ ab_strides2,
+ c_strides2,
+ a1_scale=a_scale,
+ )
def run_cutlass_from_graph(
- a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor,
- w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor,
- topk_weights: torch.Tensor, topk_ids: torch.Tensor,
- ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
- ab_strides2: torch.Tensor, c_strides2: torch.Tensor):
+ a: torch.Tensor,
+ a_scale: torch.Tensor,
+ w1_q: torch.Tensor,
+ w2_q: torch.Tensor,
+ w1_scale: torch.Tensor,
+ w2_scale: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ ab_strides1: torch.Tensor,
+ c_strides1: torch.Tensor,
+ ab_strides2: torch.Tensor,
+ c_strides2: torch.Tensor,
+ ):
with set_current_vllm_config(
- VllmConfig(parallel_config=ParallelConfig(
- pipeline_parallel_size=1))):
- return cutlass_moe_fp8(a,
- w1_q,
- w2_q,
- w1_scale,
- w2_scale,
- topk_weights,
- topk_ids,
- ab_strides1,
- c_strides1,
- ab_strides2,
- c_strides2,
- a1_scale=a_scale)
-
- def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor,
- w2: torch.Tensor, topk_weights: torch.Tensor,
- topk_ids: torch.Tensor, w1_scale: torch.Tensor,
- w2_scale: torch.Tensor, a_scale: torch.Tensor):
+ VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
+ ):
+ return cutlass_moe_fp8(
+ a,
+ w1_q,
+ w2_q,
+ w1_scale,
+ w2_scale,
+ topk_weights,
+ topk_ids,
+ ab_strides1,
+ c_strides1,
+ ab_strides2,
+ c_strides2,
+ a1_scale=a_scale,
+ )
+
+ def run_triton_from_graph(
+ a: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_scale: torch.Tensor,
+ w2_scale: torch.Tensor,
+ a_scale: torch.Tensor,
+ ):
with set_current_vllm_config(
- VllmConfig(parallel_config=ParallelConfig(
- pipeline_parallel_size=1))):
- return fused_experts(a,
- w1,
- w2,
- topk_weights,
- topk_ids,
- use_fp8_w8a8=True,
- w1_scale=w1_scale,
- w2_scale=w2_scale,
- a1_scale=a_scale)
+ VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
+ ):
+ return fused_experts(
+ a,
+ w1,
+ w2,
+ topk_weights,
+ topk_ids,
+ use_fp8_w8a8=True,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ a1_scale=a_scale,
+ )
def replay_graph(graph, num_repeats):
for _ in range(num_repeats):
@@ -176,16 +208,35 @@ def replay_graph(graph, num_repeats):
cutlass_stream = torch.cuda.Stream()
cutlass_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
- run_cutlass_from_graph(a, a_scale, w1_q, w2_q, w1_scale, w2_scale,
- topk_weights, topk_ids, ab_strides1, c_strides1,
- ab_strides2, c_strides2)
+ run_cutlass_from_graph(
+ a,
+ a_scale,
+ w1_q,
+ w2_q,
+ w1_scale,
+ w2_scale,
+ topk_weights,
+ topk_ids,
+ ab_strides1,
+ c_strides1,
+ ab_strides2,
+ c_strides2,
+ )
torch.cuda.synchronize()
triton_stream = torch.cuda.Stream()
triton_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(triton_graph, stream=triton_stream):
- run_triton_from_graph(a, w1_q_notransp, w2_q_notransp, topk_weights,
- topk_ids, w1_scale, w2_scale, a_scale)
+ run_triton_from_graph(
+ a,
+ w1_q_notransp,
+ w2_q_notransp,
+ topk_weights,
+ topk_ids,
+ w1_scale,
+ w2_scale,
+ a_scale,
+ )
torch.cuda.synchronize()
min_run_time = 5
@@ -225,18 +276,27 @@ def replay_graph(graph, num_repeats):
}
# Warmup
- run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids,
- w1_scale, w2_scale, a_scale, num_warmup)
+ run_triton_moe(
+ a,
+ w1_q_notransp,
+ w2_q_notransp,
+ topk_weights,
+ topk_ids,
+ w1_scale,
+ w2_scale,
+ a_scale,
+ num_warmup,
+ )
results.append(
benchmark.Timer(
- stmt=
- "run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501
+ stmt="run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="triton_moe",
- ).blocked_autorange(min_run_time=min_run_time))
+ ).blocked_autorange(min_run_time=min_run_time)
+ )
# Warmup
replay_graph(triton_graph, num_warmup)
@@ -248,22 +308,35 @@ def replay_graph(graph, num_repeats):
label=label,
sub_label=sub_label,
description="triton_moe_cuda_graphs",
- ).blocked_autorange(min_run_time=min_run_time))
+ ).blocked_autorange(min_run_time=min_run_time)
+ )
# Warmup
- run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights,
- topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2,
- num_warmup)
+ run_cutlass_moe(
+ a,
+ a_scale,
+ w1_q,
+ w2_q,
+ w1_scale,
+ w2_scale,
+ topk_weights,
+ topk_ids,
+ ab_strides1,
+ c_strides1,
+ ab_strides2,
+ c_strides2,
+ num_warmup,
+ )
results.append(
benchmark.Timer(
- stmt=
- "run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501
+ stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="grouped_gemm_moe",
- ).blocked_autorange(min_run_time=min_run_time))
+ ).blocked_autorange(min_run_time=min_run_time)
+ )
# Warmup
replay_graph(cutlass_graph, num_warmup)
@@ -275,7 +348,8 @@ def replay_graph(graph, num_repeats):
label=label,
sub_label=sub_label,
description="grouped_gemm_moe_cuda_graphs",
- ).blocked_autorange(min_run_time=min_run_time))
+ ).blocked_autorange(min_run_time=min_run_time)
+ )
def main(args):
@@ -303,8 +377,15 @@ def main(args):
for per_out_ch in PER_OUT_CH_OPTS:
for size_m in DEFAULT_BATCH_SIZES:
mkn = (size_m, size_k, size_n)
- bench_run(results, model, num_experts, topk,
- per_act_token, per_out_ch, mkn)
+ bench_run(
+ results,
+ model,
+ num_experts,
+ topk,
+ per_act_token,
+ per_out_ch,
+ mkn,
+ )
compare = benchmark.Compare(results)
compare.print()
@@ -312,7 +393,8 @@ def main(args):
if __name__ == "__main__":
parser = FlexibleArgumentParser(
- description="Benchmark Marlin across specified models/shapes/batches")
+ description="Benchmark Marlin across specified models/shapes/batches"
+ )
parser.add_argument(
"--models",
nargs="+",
@@ -320,21 +402,14 @@ def main(args):
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES_MOE.keys(),
)
- parser.add_argument("--tp-sizes",
- nargs="+",
- type=int,
- default=DEFAULT_TP_SIZES)
- parser.add_argument("--batch-sizes",
- nargs="+",
- type=int,
- default=DEFAULT_BATCH_SIZES)
+ parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
+ parser.add_argument(
+ "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
+ )
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
- parser.add_argument("--limit-per-act-token",
- nargs="+",
- type=int,
- default=[])
+ parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[])
parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])
args = parser.parse_args()
diff --git a/benchmarks/kernels/benchmark_layernorm.py b/benchmarks/kernels/benchmark_layernorm.py
index e12d74c01e43..f21ca97eeb8a 100644
--- a/benchmarks/kernels/benchmark_layernorm.py
+++ b/benchmarks/kernels/benchmark_layernorm.py
@@ -10,14 +10,16 @@
@torch.inference_mode()
-def main(num_tokens: int,
- hidden_size: int,
- add_residual: bool,
- dtype: torch.dtype,
- seed: int = 0,
- do_profile: bool = False,
- num_warmup_iters: int = 5,
- num_iters: int = 100) -> None:
+def main(
+ num_tokens: int,
+ hidden_size: int,
+ add_residual: bool,
+ dtype: torch.dtype,
+ seed: int = 0,
+ do_profile: bool = False,
+ num_warmup_iters: int = 5,
+ num_iters: int = 100,
+) -> None:
current_platform.seed_everything(seed)
torch.set_default_device("cuda")
@@ -56,33 +58,35 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
print(f"Kernel running time: {latency * 1000000:.3f} us")
-if __name__ == '__main__':
- parser = FlexibleArgumentParser(
- description="Benchmark the layernorm kernel.")
+if __name__ == "__main__":
+ parser = FlexibleArgumentParser(description="Benchmark the layernorm kernel.")
parser.add_argument("--num-tokens", type=int, default=4096)
parser.add_argument("--hidden-size", type=int, default=8192)
parser.add_argument("--add-residual", action="store_true")
- parser.add_argument("--dtype",
- type=str,
- choices=["half", "bfloat16", "float"],
- default="half")
+ parser.add_argument(
+ "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
+ )
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--profile", action="store_true")
parser.add_argument("--num-warmup-iters", type=int, default=5)
- parser.add_argument("--num-iters",
- type=int,
- default=100,
- help="Number of benchmark iterations. "
- "If --profile is set, this number is ignored")
+ parser.add_argument(
+ "--num-iters",
+ type=int,
+ default=100,
+ help="Number of benchmark iterations. "
+ "If --profile is set, this number is ignored",
+ )
args = parser.parse_args()
print(args)
- main(num_tokens=args.num_tokens,
- hidden_size=args.hidden_size,
- add_residual=args.add_residual,
- dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
- seed=args.seed,
- do_profile=args.profile,
- num_warmup_iters=args.num_warmup_iters,
- num_iters=args.num_iters)
+ main(
+ num_tokens=args.num_tokens,
+ hidden_size=args.hidden_size,
+ add_residual=args.add_residual,
+ dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
+ seed=args.seed,
+ do_profile=args.profile,
+ num_warmup_iters=args.num_warmup_iters,
+ num_iters=args.num_iters,
+ )
diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py
index d382ede10b41..6c1284930c1e 100644
--- a/benchmarks/kernels/benchmark_lora.py
+++ b/benchmarks/kernels/benchmark_lora.py
@@ -20,18 +20,36 @@
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
- from vllm.lora.ops.triton_ops import (LoRAKernelMeta, lora_expand,
- lora_shrink)
- from vllm.lora.ops.triton_ops.utils import (_LORA_A_PTR_DICT,
- _LORA_B_PTR_DICT)
+ from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink
+ from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.utils import FlexibleArgumentParser
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
DEFAULT_TP_SIZES = [1]
DEFAULT_BATCH_SIZES = [
- 1, 16, 32, 64, 128, 192, 256, 320, 384, 448, 512, 640, 768, 896, 1024,
- 2048, 3072, 4096, 5120, 6144, 7168, 8192
+ 1,
+ 16,
+ 32,
+ 64,
+ 128,
+ 192,
+ 256,
+ 320,
+ 384,
+ 448,
+ 512,
+ 640,
+ 768,
+ 896,
+ 1024,
+ 2048,
+ 3072,
+ 4096,
+ 5120,
+ 6144,
+ 7168,
+ 8192,
]
DEFAULT_HIDDEN_SIZES = [1024, 2048, 4096, 8192, 16384]
DEFAULT_LORA_RANKS = [16]
@@ -52,12 +70,9 @@ def dtype_to_str(dtype: torch.dtype):
raise ValueError(f"Unsupported dtype {dtype}")
-def make_rand_lora_weight_tensor(k: int,
- n: int,
- num_loras: int,
- dtype: torch.dtype,
- device: str = "cuda") -> torch.Tensor:
-
+def make_rand_lora_weight_tensor(
+ k: int, n: int, num_loras: int, dtype: torch.dtype, device: str = "cuda"
+) -> torch.Tensor:
# LoRA weights column major
return torch.rand((num_loras, n, k), dtype=dtype).to(device)
@@ -78,18 +93,15 @@ def make_rand_tensors(
A = torch.rand(a_shape, dtype=a_dtype).to(device)
# LoRA weights column major
- Bs = [
- torch.rand(b_shape, dtype=b_dtype).to(device)
- for _ in range(num_slices)
- ]
+ Bs = [torch.rand(b_shape, dtype=b_dtype).to(device) for _ in range(num_slices)]
C = torch.zeros(c_shape, dtype=c_dtype).to(device)
return A, Bs, C
-def make_prompt_lora_mapping(num_prompts: int, num_active_loras: int,
- sort_by_lora_id: bool,
- device: str) -> torch.Tensor:
+def make_prompt_lora_mapping(
+ num_prompts: int, num_active_loras: int, sort_by_lora_id: bool, device: str
+) -> torch.Tensor:
"""
All prompts are mapped to a LoRA ID in range [0, num_active_loras).
where 0 refers to first lora, 1 refers to second lora and so on.
@@ -97,9 +109,7 @@ def make_prompt_lora_mapping(num_prompts: int, num_active_loras: int,
assert num_active_loras > 0
if not sort_by_lora_id:
- return torch.randint(0,
- num_active_loras, (num_prompts, ),
- dtype=torch.long)
+ return torch.randint(0, num_active_loras, (num_prompts,), dtype=torch.long)
# Divide LoRAs equally and in order.
part_size = num_prompts // num_active_loras
@@ -110,14 +120,18 @@ def make_prompt_lora_mapping(num_prompts: int, num_active_loras: int,
while len(prompt_lora_mapping) < num_prompts:
prompt_lora_mapping.extend([lora_id] * part_size)
lora_id = lora_id + 1 if lora_id + 1 < num_active_loras else lora_id
- return torch.tensor(prompt_lora_mapping[:num_prompts],
- dtype=torch.long,
- device=device)
-
-
-def make_token_lora_mapping(num_tokens: int, num_prompts: int,
- prompt_lora_mapping: torch.Tensor,
- seq_len_tensor: torch.Tensor, device: str):
+ return torch.tensor(
+ prompt_lora_mapping[:num_prompts], dtype=torch.long, device=device
+ )
+
+
+def make_token_lora_mapping(
+ num_tokens: int,
+ num_prompts: int,
+ prompt_lora_mapping: torch.Tensor,
+ seq_len_tensor: torch.Tensor,
+ device: str,
+):
"""
Make token_lora_mapping from prompt_lora_mapping and seq_lens_tensor
"""
@@ -136,11 +150,15 @@ def make_token_lora_mapping(num_tokens: int, num_prompts: int,
return torch.tensor(token_lora_mapping, dtype=torch.long, device=device)
-def ref_group_gemm(ref_out: torch.Tensor, input: torch.Tensor,
- lora_weights: list[torch.Tensor],
- seq_lens_cpu: torch.Tensor,
- prompt_lora_mapping_cpu: torch.Tensor, scaling: float,
- add_inputs: Optional[bool]):
+def ref_group_gemm(
+ ref_out: torch.Tensor,
+ input: torch.Tensor,
+ lora_weights: list[torch.Tensor],
+ seq_lens_cpu: torch.Tensor,
+ prompt_lora_mapping_cpu: torch.Tensor,
+ scaling: float,
+ add_inputs: Optional[bool],
+):
"""
Torch group gemm reference implementation to test correctness of
benchmarking operations.
@@ -149,7 +167,7 @@ def ref_group_gemm(ref_out: torch.Tensor, input: torch.Tensor,
out_list = []
current_offset = 0
for lora_index, b_length in zip(range(batches), seq_lens_cpu):
- x = input[current_offset:b_length + current_offset, :]
+ x = input[current_offset : b_length + current_offset, :]
current_offset += b_length
w = lora_weights[prompt_lora_mapping_cpu[lora_index]]
result = torch.nn.functional.linear(x, w)
@@ -168,6 +186,7 @@ class OpType(Enum):
"""
LoRA Ops to benchmark and its properties.
"""
+
LORA_SHRINK = auto()
LORA_EXPAND = auto()
@@ -188,8 +207,9 @@ def is_expand_fn(self) -> bool:
def num_slices(self) -> list[int]:
return [1, 2, 3]
- def mkn(self, batch_size: int, seq_length: int, hidden_size: int,
- lora_rank: int) -> tuple[int, int, int]:
+ def mkn(
+ self, batch_size: int, seq_length: int, hidden_size: int, lora_rank: int
+ ) -> tuple[int, int, int]:
num_tokens = batch_size * seq_length
if self.is_shrink_fn():
m = num_tokens
@@ -203,7 +223,7 @@ def mkn(self, batch_size: int, seq_length: int, hidden_size: int,
return m, k, n
def matmul_dtypes(
- self, op_dtype: torch.dtype
+ self, op_dtype: torch.dtype
) -> tuple[torch.dtype, torch.dtype, torch.dtype]:
"""
return a type, b type and c type for A x B = C
@@ -215,9 +235,14 @@ def matmul_dtypes(
return torch.float32, op_dtype, op_dtype
def matmul_shapes(
- self, batch_size: int, seq_length: int, hidden_size: int,
- lora_rank: int, num_loras: int,
- num_slices: int) -> tuple[tuple[int], tuple[int], tuple[int]]:
+ self,
+ batch_size: int,
+ seq_length: int,
+ hidden_size: int,
+ lora_rank: int,
+ num_loras: int,
+ num_slices: int,
+ ) -> tuple[tuple[int], tuple[int], tuple[int]]:
"""
Given num_slices, return the shapes of the A, B, and C matrices
in A x B = C, for the op_type
@@ -241,31 +266,38 @@ def bench_fn(self) -> Callable:
raise ValueError(f"Unrecognized optype {self}")
- def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
- lora_weights: list[torch.Tensor],
- **kwargs) -> Callable:
+ def run_ref_group_gemm(
+ self,
+ output: torch.Tensor,
+ input: torch.Tensor,
+ lora_weights: list[torch.Tensor],
+ **kwargs,
+ ) -> Callable:
"""Each benchmark operation expects the input, lora_weights and outputs
- in a slightly different format. Refer to self.matmul_shapes().
- run_ref_group_gemm accounts for those differences in executing a
- reference group gemm for correctness testing.
+ in a slightly different format. Refer to self.matmul_shapes().
+ run_ref_group_gemm accounts for those differences in executing a
+ reference group gemm for correctness testing.
"""
w_dtype = lora_weights[0].dtype
num_slices = len(lora_weights)
if self in [OpType.LORA_SHRINK]:
for slice_idx in range(num_slices):
- ref_group_gemm(ref_out=output[slice_idx, :],
- input=input,
- lora_weights=lora_weights[slice_idx],
- **kwargs)
+ ref_group_gemm(
+ ref_out=output[slice_idx, :],
+ input=input,
+ lora_weights=lora_weights[slice_idx],
+ **kwargs,
+ )
elif self in [OpType.LORA_EXPAND]:
hidden_size = lora_weights[0].shape[1]
for slice_idx in range(num_slices):
slice_offset = slice_idx * hidden_size
ref_group_gemm(
- ref_out=output[:, slice_offset:slice_offset + hidden_size],
+ ref_out=output[:, slice_offset : slice_offset + hidden_size],
input=input[slice_idx].clone().to(dtype=w_dtype),
lora_weights=lora_weights[slice_idx],
- **kwargs)
+ **kwargs,
+ )
else:
raise ValueError(f"Unrecognized optype {self}")
@@ -275,6 +307,7 @@ class BenchmarkContext:
"""
LoRA benchmark context
"""
+
batch_size: int
hidden_size: int
num_loras: int
@@ -299,17 +332,18 @@ def bench_label(self) -> str:
return f"lora-{self.dtype}"
def bench_sublabel(self, op_type: OpType) -> str:
- m, k, n = op_type.mkn(self.batch_size, self.seq_length,
- self.hidden_size, self.lora_rank)
+ m, k, n = op_type.mkn(
+ self.batch_size, self.seq_length, self.hidden_size, self.lora_rank
+ )
desc = {
- 'bs': self.batch_size,
- 'sl': self.seq_length,
- 'm': m,
- 'k': k,
- 'n': n,
- 'num_loras': self.num_loras,
- 'sort_by_lora': self.sort_by_lora_id,
- 'num_slices': self.num_slices,
+ "bs": self.batch_size,
+ "sl": self.seq_length,
+ "m": m,
+ "k": k,
+ "n": n,
+ "num_loras": self.num_loras,
+ "sort_by_lora": self.sort_by_lora_id,
+ "num_slices": self.num_slices,
}
return json.dumps(desc)
@@ -319,6 +353,7 @@ class BenchmarkTensors:
"""
Input/Output tensors used for benchmarks
"""
+
# matmul tensors
input: torch.Tensor
lora_weights_lst: list[torch.Tensor]
@@ -330,23 +365,29 @@ class BenchmarkTensors:
prompt_lora_mapping: torch.Tensor
def io_types(self) -> str:
- return (f"{dtype_to_str(self.input.dtype)}x"
- f"{dtype_to_str(self.lora_weights_lst[0].dtype)}=>"
- f"{dtype_to_str(self.output.dtype)}")
+ return (
+ f"{dtype_to_str(self.input.dtype)}x"
+ f"{dtype_to_str(self.lora_weights_lst[0].dtype)}=>"
+ f"{dtype_to_str(self.output.dtype)}"
+ )
@staticmethod
- def make(ctx: BenchmarkContext,
- op_type: OpType,
- device: str = "cuda") -> "BenchmarkTensors":
-
+ def make(
+ ctx: BenchmarkContext, op_type: OpType, device: str = "cuda"
+ ) -> "BenchmarkTensors":
# Make input / output matmul tensors.
a_shape, b_shape, c_shape = op_type.matmul_shapes(
- ctx.batch_size, ctx.seq_length, ctx.hidden_size, ctx.lora_rank,
- ctx.num_loras, ctx.num_slices)
+ ctx.batch_size,
+ ctx.seq_length,
+ ctx.hidden_size,
+ ctx.lora_rank,
+ ctx.num_loras,
+ ctx.num_slices,
+ )
a_type, b_type, c_type = op_type.matmul_dtypes(ctx.dtype)
- input_tensor, lora_weights, output_tensor = \
- make_rand_tensors(a_shape, b_shape, c_shape, a_type, b_type, c_type,
- num_slices = ctx.num_slices)
+ input_tensor, lora_weights, output_tensor = make_rand_tensors(
+ a_shape, b_shape, c_shape, a_type, b_type, c_type, num_slices=ctx.num_slices
+ )
# Make metadata tensors.
# Keep the metadata tensors in the CPU for further processing if needed.
@@ -356,27 +397,38 @@ def make(ctx: BenchmarkContext,
# Make metadata tensors involved in correctness testing.
# Prepare seq lens tensor
- seq_len_tensor = torch.randint(ctx.seq_length, ctx.seq_length + 1,
- (ctx.batch_size, ))
+ seq_len_tensor = torch.randint(
+ ctx.seq_length, ctx.seq_length + 1, (ctx.batch_size,)
+ )
assert total_tokens == seq_len_tensor.sum()
# Prepare prompt lora indices tensor
prompt_lora_indices_tensor = make_prompt_lora_mapping(
- ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu")
+ ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu"
+ )
# Make LoRAKernelMeta
token_lora_indices_tensor = make_token_lora_mapping(
- total_tokens, ctx.batch_size, prompt_lora_indices_tensor,
- seq_len_tensor, "cpu")
+ total_tokens,
+ ctx.batch_size,
+ prompt_lora_indices_tensor,
+ seq_len_tensor,
+ "cpu",
+ )
lora_kernel_meta = LoRAKernelMeta.make(
max_loras=ctx.num_loras,
max_num_tokens=token_lora_indices_tensor.size(0),
- device="cpu")
- lora_kernel_meta.prepare_tensors(
- token_lora_mapping=token_lora_indices_tensor)
-
- return BenchmarkTensors(input_tensor, lora_weights, output_tensor,
- lora_kernel_meta, seq_len_tensor,
- prompt_lora_indices_tensor)
+ device="cpu",
+ )
+ lora_kernel_meta.prepare_tensors(token_lora_mapping=token_lora_indices_tensor)
+
+ return BenchmarkTensors(
+ input_tensor,
+ lora_weights,
+ output_tensor,
+ lora_kernel_meta,
+ seq_len_tensor,
+ prompt_lora_indices_tensor,
+ )
def sanity_check(self) -> None:
"""
@@ -386,7 +438,7 @@ def sanity_check(self) -> None:
# check metadata tensors
assert torch.sum(self.seq_lens) == num_tokens
num_seqs = self.seq_lens.shape[0]
- #assert self.seq_start_loc.shape[0] == num_seqs
+ # assert self.seq_start_loc.shape[0] == num_seqs
assert self.prompt_lora_mapping.shape[0] == num_seqs
assert self.lora_kernel_meta.token_lora_mapping.shape[0] == num_tokens
@@ -430,8 +482,11 @@ def as_lora_shrink_kwargs(self) -> dict[str, Any]:
_, num_tokens, _, num_slices = self.metadata()
# Sanity check matrix shapes.
- i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
- 0].shape, self.output.shape
+ i_shape, lw_shape, o_shape = (
+ self.input.shape,
+ self.lora_weights_lst[0].shape,
+ self.output.shape,
+ )
# Expected input shape [num_tokens, hidden_size]
assert len(i_shape) == 2
assert i_shape[0] == num_tokens
@@ -445,16 +500,17 @@ def as_lora_shrink_kwargs(self) -> dict[str, Any]:
assert o_shape == (num_slices, num_tokens, lora_rank)
return {
- 'inputs': self.input,
- 'lora_a_weights': self.lora_weights_lst,
- 'output_tensor': self.output,
- 'token_lora_mapping': self.lora_kernel_meta.token_lora_mapping,
- 'token_indices_sorted_by_lora_ids':
- self.lora_kernel_meta.token_indices_sorted_by_lora_ids,
- 'num_tokens_per_lora': self.lora_kernel_meta.num_tokens_per_lora,
- 'lora_token_start_loc': self.lora_kernel_meta.lora_token_start_loc,
- 'lora_ids': self.lora_kernel_meta.active_lora_ids,
- 'scaling': 1.0,
+ "inputs": self.input,
+ "lora_a_weights": self.lora_weights_lst,
+ "output_tensor": self.output,
+ "token_lora_mapping": self.lora_kernel_meta.token_lora_mapping,
+ "token_indices_sorted_by_lora_ids": (
+ self.lora_kernel_meta.token_indices_sorted_by_lora_ids
+ ),
+ "num_tokens_per_lora": self.lora_kernel_meta.num_tokens_per_lora,
+ "lora_token_start_loc": self.lora_kernel_meta.lora_token_start_loc,
+ "lora_ids": self.lora_kernel_meta.active_lora_ids,
+ "scaling": 1.0,
}
def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
@@ -464,8 +520,11 @@ def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
_, num_tokens, _, num_slices = self.metadata()
# Sanity check matrix shapes.
- i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
- 0].shape, self.output.shape
+ i_shape, lw_shape, o_shape = (
+ self.input.shape,
+ self.lora_weights_lst[0].shape,
+ self.output.shape,
+ )
# Expected input shape : [num_slices, num_tokens, lora_rank]
assert len(i_shape) == 3
assert i_shape[0] == num_slices
@@ -480,22 +539,23 @@ def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
assert o_shape == (num_tokens, hidden_size * num_slices)
return {
- 'inputs': self.input,
- 'lora_b_weights': self.lora_weights_lst,
- 'output_tensor': self.output,
- 'token_lora_mapping': self.lora_kernel_meta.token_lora_mapping,
- 'token_indices_sorted_by_lora_ids':
- self.lora_kernel_meta.token_indices_sorted_by_lora_ids,
- 'num_tokens_per_lora': self.lora_kernel_meta.num_tokens_per_lora,
- 'lora_token_start_loc': self.lora_kernel_meta.lora_token_start_loc,
- 'lora_ids': self.lora_kernel_meta.active_lora_ids,
- 'offset_start': 0,
- 'add_inputs': add_inputs,
+ "inputs": self.input,
+ "lora_b_weights": self.lora_weights_lst,
+ "output_tensor": self.output,
+ "token_lora_mapping": self.lora_kernel_meta.token_lora_mapping,
+ "token_indices_sorted_by_lora_ids": (
+ self.lora_kernel_meta.token_indices_sorted_by_lora_ids
+ ),
+ "num_tokens_per_lora": self.lora_kernel_meta.num_tokens_per_lora,
+ "lora_token_start_loc": self.lora_kernel_meta.lora_token_start_loc,
+ "lora_ids": self.lora_kernel_meta.active_lora_ids,
+ "offset_start": 0,
+ "add_inputs": add_inputs,
}
- def bench_fn_kwargs(self,
- op_type: OpType,
- add_inputs: Optional[bool] = None) -> dict[str, Any]:
+ def bench_fn_kwargs(
+ self, op_type: OpType, add_inputs: Optional[bool] = None
+ ) -> dict[str, Any]:
if op_type.is_shrink_fn():
assert add_inputs is None
else:
@@ -507,8 +567,9 @@ def bench_fn_kwargs(self,
return self.as_lora_expand_kwargs(add_inputs)
raise ValueError(f"Unrecognized optype {self}")
- def test_correctness(self, op_type: OpType,
- expand_fn_add_inputs: Optional[bool]) -> bool:
+ def test_correctness(
+ self, op_type: OpType, expand_fn_add_inputs: Optional[bool]
+ ) -> bool:
"""
Test correctness of op_type implementation against a grouped gemm
reference implementation.
@@ -518,8 +579,7 @@ def test_correctness(self, op_type: OpType,
ref_output = self.output.clone()
self.output.zero_()
- op_type.bench_fn()(
- **self.bench_fn_kwargs(op_type, expand_fn_add_inputs))
+ op_type.bench_fn()(**self.bench_fn_kwargs(op_type, expand_fn_add_inputs))
op_type.run_ref_group_gemm(
ref_output,
@@ -528,7 +588,8 @@ def test_correctness(self, op_type: OpType,
seq_lens_cpu=seq_lens_cpu,
prompt_lora_mapping_cpu=prompt_lora_mapping_cpu,
scaling=1.0,
- add_inputs=expand_fn_add_inputs)
+ add_inputs=expand_fn_add_inputs,
+ )
rtol, atol = {
torch.float16: (6e-2, 6e-2),
@@ -539,13 +600,14 @@ def test_correctness(self, op_type: OpType,
return torch.allclose(ref_output, self.output, rtol=rtol, atol=atol)
-def bench_optype(ctx: BenchmarkContext,
- arg_pool_size: int,
- op_type: OpType,
- cuda_graph_nops: Optional[int] = None,
- expand_fn_add_inputs: Optional[bool] = None,
- test_correctness: bool = False) -> TMeasurement:
-
+def bench_optype(
+ ctx: BenchmarkContext,
+ arg_pool_size: int,
+ op_type: OpType,
+ cuda_graph_nops: Optional[int] = None,
+ expand_fn_add_inputs: Optional[bool] = None,
+ test_correctness: bool = False,
+) -> TMeasurement:
assert arg_pool_size >= 1
if op_type.is_shrink_fn():
assert expand_fn_add_inputs is None
@@ -553,17 +615,17 @@ def bench_optype(ctx: BenchmarkContext,
assert expand_fn_add_inputs is not None
# BenchmarkContext -> BenchmarkTensors
- bench_tensors : list[BenchmarkTensors] = \
- [BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size)]
+ bench_tensors: list[BenchmarkTensors] = [
+ BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size)
+ ]
for bt in bench_tensors:
bt.sanity_check()
# Test correctness of our implementation.
if test_correctness:
- assert all([
- bt.test_correctness(op_type, expand_fn_add_inputs)
- for bt in bench_tensors
- ])
+ assert all(
+ [bt.test_correctness(op_type, expand_fn_add_inputs) for bt in bench_tensors]
+ )
# BenchmarkTensors -> dict (kwargs)
kwargs_list = [
@@ -585,40 +647,49 @@ def bench_optype(ctx: BenchmarkContext,
for k, v in _kwargs.items():
kwargs[k].values.append(v)
- describe_args = (f"add_inputs={expand_fn_add_inputs}"
- if expand_fn_add_inputs is not None else "")
- description = (
- f"{op_type.name}({describe_args}) ({bench_tensors[0].io_types()})")
+ describe_args = (
+ f"add_inputs={expand_fn_add_inputs}" if expand_fn_add_inputs is not None else ""
+ )
+ description = f"{op_type.name}({describe_args}) ({bench_tensors[0].io_types()})"
cuda_graph_params = None
if cuda_graph_nops:
cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops)
timer = None
- with Bench(cuda_graph_params,
- ctx.bench_label(), ctx.bench_sublabel(op_type), description,
- op_type.bench_fn(), **kwargs) as bench:
+ with Bench(
+ cuda_graph_params,
+ ctx.bench_label(),
+ ctx.bench_sublabel(op_type),
+ description,
+ op_type.bench_fn(),
+ **kwargs,
+ ) as bench:
timer = bench.run()
return timer
-def bench_torch_mm(ctx: BenchmarkContext,
- arg_pool_size: int,
- op_type: OpType,
- cuda_graph_nops: Optional[int] = None) -> TMeasurement:
+def bench_torch_mm(
+ ctx: BenchmarkContext,
+ arg_pool_size: int,
+ op_type: OpType,
+ cuda_graph_nops: Optional[int] = None,
+) -> TMeasurement:
"""
Benchmark basic torch.mm as a roofline.
When all the input tokens have the same LoRA ID, the LoRA kernels are just
- a matmul. This torch.mm benchmark serves as a roofline for that case.
+ a matmul. This torch.mm benchmark serves as a roofline for that case.
input op_type is used in determining the m, k, n dimensions for the matmul.
"""
- batch_size, hidden_size, lora_rank, seq_length, dtype = (ctx.batch_size,
- ctx.hidden_size,
- ctx.lora_rank,
- ctx.seq_length,
- ctx.dtype)
+ batch_size, hidden_size, lora_rank, seq_length, dtype = (
+ ctx.batch_size,
+ ctx.hidden_size,
+ ctx.lora_rank,
+ ctx.seq_length,
+ ctx.dtype,
+ )
m, k, n = op_type.mkn(batch_size, seq_length, hidden_size, lora_rank)
# For a fairer comparison.
@@ -632,18 +703,24 @@ def bench_torch_mm(ctx: BenchmarkContext,
Cs.append(torch.rand((m, n), dtype=dtype).to("cuda"))
# Make torch.mm kwargs
- mm_kwargs = {'input': ArgPool(As), 'mat2': ArgPool(Bs), 'out': ArgPool(Cs)}
+ mm_kwargs = {"input": ArgPool(As), "mat2": ArgPool(Bs), "out": ArgPool(Cs)}
description = (
f"single-lora roofline using torch.mm ({dtype_to_str(dtype)}"
f"x{dtype_to_str(dtype)}"
- f"=>{dtype_to_str(dtype)})")
+ f"=>{dtype_to_str(dtype)})"
+ )
cuda_graph_params = None
if cuda_graph_nops:
cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops)
- with Bench(cuda_graph_params, ctx.bench_label(),
- ctx.bench_sublabel(op_type), description, torch.mm,
- **mm_kwargs) as bench:
+ with Bench(
+ cuda_graph_params,
+ ctx.bench_label(),
+ ctx.bench_sublabel(op_type),
+ description,
+ torch.mm,
+ **mm_kwargs,
+ ) as bench:
return bench.run()
@@ -660,8 +737,7 @@ def use_cuda_graph_recommendation() -> str:
"""
-def print_timers(timers: list[TMeasurement],
- args: Optional[argparse.Namespace] = None):
+def print_timers(timers: list[TMeasurement], args: Optional[argparse.Namespace] = None):
compare = TBenchmark.Compare(timers)
compare.print()
@@ -670,22 +746,23 @@ def print_timers(timers: list[TMeasurement],
f"Note : The timings reported above is for {args.cuda_graph_nops} "
"consecutive invocations of the benchmarking functions. "
f"Please divide by {args.cuda_graph_nops} for single invocation "
- "timings.")
+ "timings."
+ )
- print("Note on Comparison with torch.mm : The torch.mm numbers are "
- "benchmark numbers of a simple matmul emulating the single lora "
- "case. It is provided as a roofline for comparing our LoRA Kernel "
- "implementations. It is expected that the LoRA kernels will be "
- "slower than torch.mm in cases where num_loras is big. But for "
- "small num_loras the goal should be to match the torch.mm numbers.")
+ print(
+ "Note on Comparison with torch.mm : The torch.mm numbers are "
+ "benchmark numbers of a simple matmul emulating the single lora "
+ "case. It is provided as a roofline for comparing our LoRA Kernel "
+ "implementations. It is expected that the LoRA kernels will be "
+ "slower than torch.mm in cases where num_loras is big. But for "
+ "small num_loras the goal should be to match the torch.mm numbers."
+ )
def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
-
if args.cuda_graph_nops is not None:
assert args.cuda_graph_nops > 0
- print(f"Benchmarking {args.cuda_graph_nops} invocations inside a CUDA "
- "Graph")
+ print(f"Benchmarking {args.cuda_graph_nops} invocations inside a CUDA Graph")
else:
print(f"CUDA Graphs not enabled.\n{use_cuda_graph_recommendation()}")
@@ -697,21 +774,30 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
for bench_op in bench_ops:
for num_slices in bench_op.num_slices():
_ctx = bench_ctx.with_seq_length(seq_len).with_num_slices(
- num_slices)
+ num_slices
+ )
# Benchmark torch.mm as a roofline
seq_len_timers.append(
- bench_torch_mm(_ctx, args.arg_pool_size, bench_op,
- args.cuda_graph_nops))
+ bench_torch_mm(
+ _ctx, args.arg_pool_size, bench_op, args.cuda_graph_nops
+ )
+ )
# Benchmark bench_op
- expand_fn_add_inputs = [
- None
- ] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs
+ expand_fn_add_inputs = (
+ [None] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs
+ )
for add_input_arg in expand_fn_add_inputs:
seq_len_timers.append(
- bench_optype(_ctx, args.arg_pool_size, bench_op,
- args.cuda_graph_nops, add_input_arg,
- args.test_correctness))
+ bench_optype(
+ _ctx,
+ args.arg_pool_size,
+ bench_op,
+ args.cuda_graph_nops,
+ add_input_arg,
+ args.test_correctness,
+ )
+ )
print_timers(seq_len_timers)
timers.extend(seq_len_timers)
@@ -733,13 +819,17 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
pickle.dump(timers, f)
-def as_benchmark_contexts(hidden_sizes: list[int], lora_ranks: list[int],
- args: argparse.Namespace) -> list[BenchmarkContext]:
-
+def as_benchmark_contexts(
+ hidden_sizes: list[int], lora_ranks: list[int], args: argparse.Namespace
+) -> list[BenchmarkContext]:
ctxs: list[BenchmarkContext] = []
for batch_size, hidden_size, lora_rank, num_loras, sort_by_lora_id in product( # noqa
- args.batch_sizes, list(hidden_sizes), lora_ranks, args.num_loras,
- args.sort_by_lora_id):
+ args.batch_sizes,
+ list(hidden_sizes),
+ lora_ranks,
+ args.num_loras,
+ args.sort_by_lora_id,
+ ):
ctxs.append(
BenchmarkContext(
batch_size=batch_size,
@@ -747,13 +837,16 @@ def as_benchmark_contexts(hidden_sizes: list[int], lora_ranks: list[int],
lora_rank=lora_rank,
num_loras=num_loras,
num_active_loras=args.num_active_loras
- if args.num_active_loras else num_loras,
+ if args.num_active_loras
+ else num_loras,
# To be filled based on the OpType to benchmark
seq_length=None,
sort_by_lora_id=sort_by_lora_id,
dtype=args.dtype,
# To be filled based on the OpType to benchmark
- num_slices=None))
+ num_slices=None,
+ )
+ )
return ctxs
@@ -761,13 +854,16 @@ def as_benchmark_contexts(hidden_sizes: list[int], lora_ranks: list[int],
def run_list_bench(args: argparse.Namespace):
print(args)
- print("List bench :\n"
- f" Hidden Sizes {args.hidden_sizes}"
- f" LoRA Ranks {args.lora_ranks}")
+ print(
+ "List bench :\n"
+ f" Hidden Sizes {args.hidden_sizes}"
+ f" LoRA Ranks {args.lora_ranks}"
+ )
# Get all benchmarking contexts
bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
- hidden_sizes=args.hidden_sizes, lora_ranks=args.lora_ranks, args=args)
+ hidden_sizes=args.hidden_sizes, lora_ranks=args.lora_ranks, args=args
+ )
run(args, bench_contexts)
@@ -776,19 +872,22 @@ def run_range_bench(args: argparse.Namespace):
print(args)
hidden_sizes = list(
- range(args.hidden_sizes_start, args.hidden_sizes_end + 1,
- args.hidden_sizes_increment))
+ range(
+ args.hidden_sizes_start,
+ args.hidden_sizes_end + 1,
+ args.hidden_sizes_increment,
+ )
+ )
lora_ranks = list(
- range(args.lora_ranks_start, args.lora_ranks_end + 1,
- args.lora_ranks_increment))
+ range(args.lora_ranks_start, args.lora_ranks_end + 1, args.lora_ranks_increment)
+ )
- print("Range bench :\n"
- f" Hidden Sizes {hidden_sizes}"
- f" LoRA Ranks {lora_ranks}")
+ print(f"Range bench :\n Hidden Sizes {hidden_sizes} LoRA Ranks {lora_ranks}")
# Get all benchmarking contexts
bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
- hidden_sizes=hidden_sizes, lora_ranks=lora_ranks, args=args)
+ hidden_sizes=hidden_sizes, lora_ranks=lora_ranks, args=args
+ )
run(args, bench_contexts)
@@ -806,21 +905,19 @@ def hidden_sizes_from_model(model: str, tp_size: int) -> set[int]:
# Get all hidden sizes
hidden_sizes: set[int] = set()
for model_name, tp_size in product(args.models, args.tp_sizes):
- hidden_sizes = hidden_sizes.union(
- hidden_sizes_from_model(model_name, tp_size))
+ hidden_sizes = hidden_sizes.union(hidden_sizes_from_model(model_name, tp_size))
- print("Model bench :\n"
- f" Hidden Sizes {hidden_sizes}"
- f" LoRA Ranks {args.lora_ranks}")
+ print(f"Model bench :\n Hidden Sizes {hidden_sizes} LoRA Ranks {args.lora_ranks}")
# Get all benchmarking contexts
bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
- hidden_sizes=hidden_sizes, lora_ranks=args.lora_ranks, args=args)
+ hidden_sizes=hidden_sizes, lora_ranks=args.lora_ranks, args=args
+ )
run(args, bench_contexts)
-if __name__ == '__main__':
+if __name__ == "__main__":
def to_torch_dtype(dt):
if dt == "torch.float16":
@@ -830,14 +927,15 @@ def to_torch_dtype(dt):
raise ValueError("unsupported dtype")
def get_bool(s: str) -> bool:
- return s.lower() in ['true', '1']
+ return s.lower() in ["true", "1"]
def add_common_command_args(p: argparse.ArgumentParser):
p.add_argument(
"--dtype",
type=to_torch_dtype,
required=True,
- help="Available options are ['torch.float16', 'torch.bfloat16']")
+ help="Available options are ['torch.float16', 'torch.bfloat16']",
+ )
p.add_argument(
"--arg-pool-size",
@@ -845,56 +943,66 @@ def add_common_command_args(p: argparse.ArgumentParser):
default=32,
help="Run profiles with a pool of input/output/meta tensors instead"
"of simply reusing the same tensors for all runs. A bigger arg-pool"
- "mitigates hardware caching effects during benchmarking.")
+ "mitigates hardware caching effects during benchmarking.",
+ )
p.add_argument(
"--cuda-graph-nops",
type=int,
- help=("when set profiling is done using cudagraph, "
- "with the given number of operations in a graph."
- "Note that the measurement returned is the time "
- "taken for N consecutive executions of the benchmarking "
- "functions, where N is the value of this argument."))
- p.add_argument("--num-loras",
- nargs="+",
- type=int,
- default=DEFAULT_NUM_LORAS)
- p.add_argument("--num-active-loras",
- type=int,
- default=None,
- help="Active LoRAs. When None, all LoRAs are active")
- p.add_argument("--sort-by-lora-id",
- nargs="+",
- type=get_bool,
- default=DEFAULT_SORT_BY_LORA_IDS)
- p.add_argument("--op-types",
- nargs="+",
- type=OpType.from_str,
- default=list(OpType))
- p.add_argument('--seq-lengths',
- nargs="+",
- type=int,
- default=DEFAULT_SEQ_LENGTHS)
- p.add_argument("--batch-sizes",
- nargs="+",
- type=int,
- default=DEFAULT_BATCH_SIZES)
- p.add_argument("--expand-fn-add-inputs",
- nargs="+",
- type=get_bool,
- default=DEFAULT_EXPAND_FN_ADD_INPUTS)
+ help=(
+ "when set profiling is done using cudagraph, "
+ "with the given number of operations in a graph."
+ "Note that the measurement returned is the time "
+ "taken for N consecutive executions of the benchmarking "
+ "functions, where N is the value of this argument."
+ ),
+ )
+ p.add_argument("--num-loras", nargs="+", type=int, default=DEFAULT_NUM_LORAS)
+ p.add_argument(
+ "--num-active-loras",
+ type=int,
+ default=None,
+ help="Active LoRAs. When None, all LoRAs are active",
+ )
+ p.add_argument(
+ "--sort-by-lora-id",
+ nargs="+",
+ type=get_bool,
+ default=DEFAULT_SORT_BY_LORA_IDS,
+ )
+ p.add_argument(
+ "--op-types", nargs="+", type=OpType.from_str, default=list(OpType)
+ )
+ p.add_argument(
+ "--seq-lengths", nargs="+", type=int, default=DEFAULT_SEQ_LENGTHS
+ )
+ p.add_argument(
+ "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
+ )
+ p.add_argument(
+ "--expand-fn-add-inputs",
+ nargs="+",
+ type=get_bool,
+ default=DEFAULT_EXPAND_FN_ADD_INPUTS,
+ )
p.add_argument(
- '-o',
- '--output-directory',
+ "-o",
+ "--output-directory",
type=str,
- help=("Output directory to store a the list of benchmarking"
- "TMeasurement objects as a pickle file"))
+ help=(
+ "Output directory to store a the list of benchmarking"
+ "TMeasurement objects as a pickle file"
+ ),
+ )
p.add_argument(
"--test-correctness",
- action='store_true',
- help=("When enabled, the benchmarking functions are tested"
- "for correctness before the actual benchmarking"))
+ action="store_true",
+ help=(
+ "When enabled, the benchmarking functions are tested"
+ "for correctness before the actual benchmarking"
+ ),
+ )
parser = FlexibleArgumentParser(
description=f"""
@@ -910,50 +1018,45 @@ def add_common_command_args(p: argparse.ArgumentParser):
range_bench example:
python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8
""", # noqa: E501
- formatter_class=argparse.RawTextHelpFormatter)
+ formatter_class=argparse.RawTextHelpFormatter,
+ )
subparsers = parser.add_subparsers(dest="cmd", required=True)
list_parser = subparsers.add_parser("list_bench")
- list_parser.add_argument("--hidden-sizes",
- nargs="+",
- type=int,
- default=DEFAULT_HIDDEN_SIZES)
- list_parser.add_argument("--lora-ranks",
- nargs="+",
- type=int,
- default=DEFAULT_LORA_RANKS)
+ list_parser.add_argument(
+ "--hidden-sizes", nargs="+", type=int, default=DEFAULT_HIDDEN_SIZES
+ )
+ list_parser.add_argument(
+ "--lora-ranks", nargs="+", type=int, default=DEFAULT_LORA_RANKS
+ )
add_common_command_args(list_parser)
list_parser.set_defaults(func=run_list_bench)
range_parser = subparsers.add_parser("range_bench")
range_parser.add_argument("--hidden-sizes-start", type=int, required=True)
range_parser.add_argument("--hidden-sizes-end", type=int, required=True)
- range_parser.add_argument("--hidden-sizes-increment",
- type=int,
- required=True)
+ range_parser.add_argument("--hidden-sizes-increment", type=int, required=True)
range_parser.add_argument("--lora-ranks-start", type=int, required=True)
range_parser.add_argument("--lora-ranks-end", type=int, required=True)
- range_parser.add_argument("--lora-ranks-increment",
- type=int,
- required=True)
+ range_parser.add_argument("--lora-ranks-increment", type=int, required=True)
add_common_command_args(range_parser)
range_parser.set_defaults(func=run_range_bench)
model_parser = subparsers.add_parser("model_bench")
- model_parser.add_argument("--models",
- nargs="+",
- type=str,
- default=DEFAULT_MODELS,
- choices=WEIGHT_SHAPES.keys())
- model_parser.add_argument("--tp-sizes",
- nargs="+",
- type=int,
- default=DEFAULT_TP_SIZES)
- model_parser.add_argument("--lora-ranks",
- nargs="+",
- type=int,
- default=DEFAULT_LORA_RANKS)
+ model_parser.add_argument(
+ "--models",
+ nargs="+",
+ type=str,
+ default=DEFAULT_MODELS,
+ choices=WEIGHT_SHAPES.keys(),
+ )
+ model_parser.add_argument(
+ "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
+ )
+ model_parser.add_argument(
+ "--lora-ranks", nargs="+", type=int, default=DEFAULT_LORA_RANKS
+ )
add_common_command_args(model_parser)
model_parser.set_defaults(func=run_model_bench)
diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py
index a661ea9d7e60..f8f1db04790b 100644
--- a/benchmarks/kernels/benchmark_machete.py
+++ b/benchmarks/kernels/benchmark_machete.py
@@ -20,12 +20,18 @@
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
- GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales,
- marlin_zero_points)
+ GPTQ_MARLIN_MAX_PARALLEL,
+ GPTQ_MARLIN_MIN_THREAD_N,
+ marlin_permute_scales,
+ marlin_zero_points,
+)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
- MarlinWorkspace)
+ MarlinWorkspace,
+)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
- pack_rows, quantize_weights)
+ pack_rows,
+ quantize_weights,
+)
from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils import FlexibleArgumentParser
@@ -82,12 +88,14 @@ def rand_data(shape, dtype=torch.float16, scale=1):
return torch.randint(-15, 15, shape, dtype=dtype, device="cuda")
-def quantize_and_pack(atype: torch.dtype,
- w: torch.Tensor,
- wtype: ScalarType,
- stype: Optional[torch.dtype],
- group_size: Optional[int],
- zero_points: bool = False):
+def quantize_and_pack(
+ atype: torch.dtype,
+ w: torch.Tensor,
+ wtype: ScalarType,
+ stype: Optional[torch.dtype],
+ group_size: Optional[int],
+ zero_points: bool = False,
+):
assert wtype.is_integer(), "TODO: support floating point weights"
w_ref, w_q, w_s, w_zp = quantize_weights(
@@ -96,21 +104,24 @@ def quantize_and_pack(atype: torch.dtype,
group_size=group_size,
zero_points=zero_points,
# to match how the kernel applies zps
- ref_zero_points_after_scales=True)
+ ref_zero_points_after_scales=True,
+ )
w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
return w_ref, w_q, w_s, w_zp
-def create_bench_tensors(shape: tuple[int, int, int], types: TypeConfig,
- group_size: Optional[int]) -> list[BenchmarkTensors]:
+def create_bench_tensors(
+ shape: tuple[int, int, int], types: TypeConfig, group_size: Optional[int]
+) -> list[BenchmarkTensors]:
m, n, k = shape
# we want to make sure that weights don't fit into L2 cache between runs so
# we construct enough weights to exceed L2 cache, which is 50mb on a H100
# so we target total weight size > 2*50mb
- num_weights = math.ceil(2 * 50 * 1024**2 * 8 /
- (k * n * types.weight_type.size_bits))
+ num_weights = math.ceil(
+ 2 * 50 * 1024**2 * 8 / (k * n * types.weight_type.size_bits)
+ )
a = rand_data((m, k), types.act_type, scale=5)
@@ -124,8 +135,13 @@ def create_bench_tensors(shape: tuple[int, int, int], types: TypeConfig,
w = w.to(torch.float16)
w_ref, w_q_packed, w_s, w_zp = quantize_and_pack(
- a.dtype, w, types.weight_type, types.group_scale_type, group_size,
- types.group_zero_type is not None)
+ a.dtype,
+ w,
+ types.weight_type,
+ types.group_scale_type,
+ group_size,
+ types.group_zero_type is not None,
+ )
if not a.dtype.is_floating_point:
aiinfo = torch.iinfo(a.dtype)
@@ -133,21 +149,30 @@ def create_bench_tensors(shape: tuple[int, int, int], types: TypeConfig,
w_ref = w_ref.to(torch.float32)
- w_ch_s = None if types.channel_scale_type is None else\
- rand_data((n,), types.channel_scale_type)
- w_tok_s = None if types.token_scale_type is None else\
- rand_data((m,), types.token_scale_type)
+ w_ch_s = (
+ None
+ if types.channel_scale_type is None
+ else rand_data((n,), types.channel_scale_type)
+ )
+ w_tok_s = (
+ None
+ if types.token_scale_type is None
+ else rand_data((m,), types.token_scale_type)
+ )
benchmark_tensors.append(
- BenchmarkTensors(w_ref=w_ref,
- a=a,
- w_q=w_q_packed,
- wtype=types.weight_type,
- w_g_s=w_s,
- w_g_zp=w_zp,
- group_size=group_size,
- w_ch_s=w_ch_s,
- w_tok_s=w_tok_s))
+ BenchmarkTensors(
+ w_ref=w_ref,
+ a=a,
+ w_q=w_q_packed,
+ wtype=types.weight_type,
+ w_g_s=w_s,
+ w_g_zp=w_zp,
+ group_size=group_size,
+ w_ch_s=w_ch_s,
+ w_tok_s=w_tok_s,
+ )
+ )
return benchmark_tensors
@@ -170,50 +195,57 @@ def cutlass_scaled_mm_create_bench_fn(bt: BenchmarkTensors) -> Callable:
scale_b = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device)
w_col_major = bt.w_ref.to(bt.a.dtype).t().contiguous().t()
return lambda: ops.cutlass_scaled_mm(
- bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16)
+ bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16
+ )
def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
device = bt.a.device
- workspace = MarlinWorkspace(bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N,
- GPTQ_MARLIN_MAX_PARALLEL)
+ workspace = MarlinWorkspace(
+ bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
+ )
if bt.w_g_zp is None:
w_zp = torch.empty(0, dtype=torch.int, device=device)
else:
- w_zp = marlin_zero_points(bt.w_g_zp, bt.w_ref.shape[0],
- bt.w_ref.shape[1], bt.wtype.size_bits)
+ w_zp = marlin_zero_points(
+ bt.w_g_zp, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.wtype.size_bits
+ )
if bt.group_size is None:
w_s = torch.tensor([], device="cuda", dtype=torch.half)
else:
- w_s = marlin_permute_scales(bt.w_g_s, bt.w_ref.shape[0],
- bt.w_ref.shape[1], bt.group_size)
+ w_s = marlin_permute_scales(
+ bt.w_g_s, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.group_size
+ )
sort_indices = torch.empty(0, dtype=torch.int, device=device)
g_idx = torch.empty(0, dtype=torch.int, device=device)
- w_q = ops.gptq_marlin_repack(bt.w_q, sort_indices, bt.w_ref.shape[0],
- bt.w_ref.shape[1], bt.wtype.size_bits)
+ w_q = ops.gptq_marlin_repack(
+ bt.w_q, sort_indices, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.wtype.size_bits
+ )
if bt.a.dtype.is_floating_point:
assert bt.w_ch_s is None
assert bt.w_tok_s is None
assert bt.group_size is not None
- fn = lambda: ops.gptq_marlin_gemm(a=bt.a,
- b_q_weight=w_q,
- b_scales=w_s,
- b_zeros=w_zp,
- g_idx=g_idx,
- perm=sort_indices,
- workspace=workspace.scratch,
- b_q_type=bt.wtype,
- size_m=bt.a.shape[0],
- size_n=bt.w_ref.shape[1],
- size_k=bt.w_ref.shape[0],
- is_k_full=True,
- is_zp_float=False)
+ fn = lambda: ops.gptq_marlin_gemm(
+ a=bt.a,
+ b_q_weight=w_q,
+ b_scales=w_s,
+ b_zeros=w_zp,
+ g_idx=g_idx,
+ perm=sort_indices,
+ workspace=workspace.scratch,
+ b_q_type=bt.wtype,
+ size_m=bt.a.shape[0],
+ size_n=bt.w_ref.shape[1],
+ size_k=bt.w_ref.shape[0],
+ is_k_full=True,
+ is_zp_float=False,
+ )
else:
assert bt.a.dtype == torch.int8
assert bt.wtype == scalar_types.uint4b8
@@ -221,36 +253,35 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
if bt.w_ch_s is not None:
s_ch = bt.w_ch_s.to(torch.float32)
else:
- s_ch = torch.ones(bt.w_ref.shape[1],
- dtype=torch.float32,
- device=device)
+ s_ch = torch.ones(bt.w_ref.shape[1], dtype=torch.float32, device=device)
if bt.w_tok_s is not None:
s_tok = bt.w_tok_s.to(torch.float32)
else:
- s_tok = torch.ones(bt.a.shape[0],
- dtype=torch.float32,
- device=device)
-
- fn = lambda: ops.marlin_qqq_gemm(a=bt.a,
- b_q_weight=w_q,
- s_group=w_s,
- s_tok=s_tok,
- s_ch=s_ch,
- workspace=workspace.scratch,
- size_m=bt.a.shape[0],
- size_n=bt.w_ref.shape[1],
- size_k=bt.w_ref.shape[0])
+ s_tok = torch.ones(bt.a.shape[0], dtype=torch.float32, device=device)
+
+ fn = lambda: ops.marlin_qqq_gemm(
+ a=bt.a,
+ b_q_weight=w_q,
+ s_group=w_s,
+ s_tok=s_tok,
+ s_ch=s_ch,
+ workspace=workspace.scratch,
+ size_m=bt.a.shape[0],
+ size_n=bt.w_ref.shape[1],
+ size_k=bt.w_ref.shape[0],
+ )
return fn
-def machete_create_bench_fn(bt: BenchmarkTensors,
- out_type=torch.dtype,
- schedule=None) -> Callable:
+def machete_create_bench_fn(
+ bt: BenchmarkTensors, out_type=torch.dtype, schedule=None
+) -> Callable:
w_q = bt.w_q.t().contiguous().t() # make col major
- w_q = ops.machete_prepack_B(w_q, bt.a.dtype, bt.wtype,
- None if bt.w_g_s is None else bt.w_g_s.dtype)
+ w_q = ops.machete_prepack_B(
+ w_q, bt.a.dtype, bt.wtype, None if bt.w_g_s is None else bt.w_g_s.dtype
+ )
w_g_zp = bt.w_g_zp
if w_g_zp is not None:
@@ -275,26 +306,24 @@ def machete_create_bench_fn(bt: BenchmarkTensors,
# bench
-def bench_fns(label: str, sub_label: str, description: str,
- fns: list[Callable]):
-
+def bench_fns(label: str, sub_label: str, description: str, fns: list[Callable]):
min_run_time = 1 if not NVTX_PROFILE else 0.1
res = TBenchmark.Timer(
stmt="""
for fn in fns:
fn()
""",
- globals={
- "fns": fns
- },
+ globals={"fns": fns},
label=label,
sub_label=sub_label,
description=description,
).blocked_autorange(min_run_time=min_run_time)
if NVTX_PROFILE:
- with nvtx.annotate("mm-bench"), nvtx.annotate(
- f"{label}|{sub_label}|{description}"):
+ with (
+ nvtx.annotate("mm-bench"),
+ nvtx.annotate(f"{label}|{sub_label}|{description}"),
+ ):
fns[0]()
return res
@@ -304,19 +333,20 @@ def bench_fns(label: str, sub_label: str, description: str,
_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None
-def bench(types: TypeConfig,
- group_size: int,
- m: int,
- k: int,
- n: int,
- label: str,
- sub_label: str,
- sweep_schedules: bool = True) -> list[TMeasurement]:
+def bench(
+ types: TypeConfig,
+ group_size: int,
+ m: int,
+ k: int,
+ n: int,
+ label: str,
+ sub_label: str,
+ sweep_schedules: bool = True,
+) -> list[TMeasurement]:
benchmark_tensors = create_bench_tensors((m, n, k), types, group_size)
sub_label += f", L={len(benchmark_tensors)}"
- name_type_string = f"W{types.weight_type}"+\
- f"-A{terse_type_name(types.act_type)}"
+ name_type_string = f"W{types.weight_type}" + f"-A{terse_type_name(types.act_type)}"
if types.group_scale_type is not None:
name_type_string += f"-GS{terse_type_name(types.group_scale_type)}"
if types.group_zero_type is not None:
@@ -332,31 +362,45 @@ def bench(types: TypeConfig,
# pytorch impl
timers.append(
bench_fns(
- label, sub_label, "torch.matmul (fp16)",
- [torch_matmul_f16_create_bench_fn(bt)
- for bt in benchmark_tensors]))
+ label,
+ sub_label,
+ "torch.matmul (fp16)",
+ [torch_matmul_f16_create_bench_fn(bt) for bt in benchmark_tensors],
+ )
+ )
if types.act_type == torch.int8 or types.act_type == torch.float8_e4m3fn:
timers.append(
bench_fns(
- label, sub_label,
- f"cutlass_scaled_mm ({terse_type_name(types.act_type)})", [
- cutlass_scaled_mm_create_bench_fn(bt)
- for bt in benchmark_tensors
- ]))
+ label,
+ sub_label,
+ f"cutlass_scaled_mm ({terse_type_name(types.act_type)})",
+ [cutlass_scaled_mm_create_bench_fn(bt) for bt in benchmark_tensors],
+ )
+ )
if types.act_type != torch.float8_e4m3fn:
timers.append(
- bench_fns(label, sub_label, f"marlin ({name_type_string})",
- [marlin_create_bench_fn(bt)
- for bt in benchmark_tensors]))
+ bench_fns(
+ label,
+ sub_label,
+ f"marlin ({name_type_string})",
+ [marlin_create_bench_fn(bt) for bt in benchmark_tensors],
+ )
+ )
# machete
timers.append(
- bench_fns(label, sub_label, f"machete ({name_type_string})", [
- machete_create_bench_fn(bt, out_type=types.output_type)
- for bt in benchmark_tensors
- ]))
+ bench_fns(
+ label,
+ sub_label,
+ f"machete ({name_type_string})",
+ [
+ machete_create_bench_fn(bt, out_type=types.output_type)
+ for bt in benchmark_tensors
+ ],
+ )
+ )
if sweep_schedules:
global _SWEEP_SCHEDULES_RESULTS
@@ -371,7 +415,8 @@ def bench(types: TypeConfig,
group_zeros_type=types.group_zero_type,
token_scales_type=types.token_scale_type,
channel_scales_type=types.channel_scale_type,
- out_type=types.output_type)
+ out_type=types.output_type,
+ )
if schedules is None or len(schedules) == 0:
raise ValueError("No schedules found to sweep")
@@ -383,11 +428,17 @@ def bench(types: TypeConfig,
if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4:
continue
- res = bench_fns(label, sub_label, "machete_best", [
- machete_create_bench_fn(
- bt, out_type=types.output_type, schedule=schedule)
- for bt in benchmark_tensors
- ])
+ res = bench_fns(
+ label,
+ sub_label,
+ "machete_best",
+ [
+ machete_create_bench_fn(
+ bt, out_type=types.output_type, schedule=schedule
+ )
+ for bt in benchmark_tensors
+ ],
+ )
results_row = {
"M": m,
@@ -398,10 +449,8 @@ def bench(types: TypeConfig,
"median": res.median,
}
if _SWEEP_SCHEDULES_RESULTS is None:
- _SWEEP_SCHEDULES_RESULTS = pd.DataFrame(
- columns=results_row.keys())
- _SWEEP_SCHEDULES_RESULTS.\
- loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row
+ _SWEEP_SCHEDULES_RESULTS = pd.DataFrame(columns=results_row.keys())
+ _SWEEP_SCHEDULES_RESULTS.loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row
print(f" {res.median:5.5} ", schedule)
if not best or res.median < best.median:
@@ -422,8 +471,9 @@ def print_timers(timers: list[TMeasurement]):
def run(args, MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]:
types = TypeConfig(
act_type=args.act_type,
- weight_type=scalar_types.uint4b8 if args.group_zero_type is None \
- else scalar_types.uint4,
+ weight_type=scalar_types.uint4b8
+ if args.group_zero_type is None
+ else scalar_types.uint4,
output_type=args.out_type,
group_scale_type=args.group_scale_type,
group_zero_type=args.group_zero_type,
@@ -433,14 +483,16 @@ def run(args, MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]:
results: list[TMeasurement] = []
for m, k, n in MKNs:
- timers = bench(types,
- args.group_size,
- m,
- k,
- n,
- f"{args.act_type}-gemm",
- f"MKN=({m}x{k}x{n})",
- sweep_schedules=args.sweep_schedules)
+ timers = bench(
+ types,
+ args.group_size,
+ m,
+ k,
+ n,
+ f"{args.act_type}-gemm",
+ f"MKN=({m}x{k}x{n})",
+ sweep_schedules=args.sweep_schedules,
+ )
print_timers(timers)
results.extend(timers)
@@ -454,7 +506,6 @@ def make_output(
base_description: str,
timestamp=None,
):
-
print(f"== All Results {base_description} ====")
print_timers(data)
@@ -468,8 +519,7 @@ def make_output(
def run_square_bench(args):
- dim_sizes = list(
- range(args.dim_start, args.dim_end + 1, args.dim_increment))
+ dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
data = run(args.dtype, args.sweep_schedules, MKNs)
@@ -479,8 +529,9 @@ def run_square_bench(args):
def run_range_bench(args):
m_start, k_start, n_start = (int(x) for x in args.dim_start.split(","))
m_end, k_end, n_end = (int(x) for x in args.dim_end.split(","))
- m_increment, k_increment, n_increment = \
- (int(x) for x in args.dim_increment.split(","))
+ m_increment, k_increment, n_increment = (
+ int(x) for x in args.dim_increment.split(",")
+ )
Ms = list(range(m_start, m_end + 1, m_increment))
Ks = list(range(k_start, k_end + 1, k_increment))
Ns = list(range(n_start, n_end + 1, n_increment))
@@ -492,7 +543,6 @@ def run_range_bench(args):
def run_model_bench(args):
-
print("Benchmarking models:")
for i, model in enumerate(args.models):
print(f"[{i}] {model}")
@@ -535,10 +585,13 @@ def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]:
with open(f"model_bench-{type_string}-{timestr}.pkl", "wb") as f:
args_dict = vars(args)
args_dict.pop("func")
- pkl.dump({
- "args": args_dict,
- "results": all_results,
- }, f)
+ pkl.dump(
+ {
+ "args": args_dict,
+ "results": all_results,
+ },
+ f,
+ )
if __name__ == "__main__":
@@ -554,7 +607,6 @@ def to_torch_dtype(dt):
}[dt]
class ToTorchDtype(argparse.Action):
-
def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, to_torch_dtype(values))
@@ -580,32 +632,32 @@ def __call__(self, parser, namespace, values, option_string=None):
"--act-type",
action=ToTorchDtype,
required=True,
- choices=['bfloat16', 'float16', 'int8', 'float8_e4m3fn'],
+ choices=["bfloat16", "float16", "int8", "float8_e4m3fn"],
)
parser.add_argument(
"--group-scale-type",
action=ToTorchDtype,
- choices=['bfloat16', 'float16'],
+ choices=["bfloat16", "float16"],
)
parser.add_argument(
"--group-zero-type",
type=to_torch_dtype,
- choices=['bfloat16', 'float16'],
+ choices=["bfloat16", "float16"],
)
parser.add_argument(
"--channel-scale-type",
action=ToTorchDtype,
- choices=['float'],
+ choices=["float"],
)
parser.add_argument(
"--token-scale-type",
action=ToTorchDtype,
- choices=['float'],
+ choices=["float"],
)
parser.add_argument(
"--out-type",
action=ToTorchDtype,
- choices=['bfloat16', 'float16'],
+ choices=["bfloat16", "float16"],
)
parser.add_argument(
"--group-size",
@@ -618,9 +670,11 @@ def __call__(self, parser, namespace, values, option_string=None):
action="store_true",
help="Run a sweep over all supported schedules",
)
- parser.add_argument("--sweep-csv-out",
- help="CSV to store sweep results",
- default="sch_sweep_results.csv")
+ parser.add_argument(
+ "--sweep-csv-out",
+ help="CSV to store sweep results",
+ default="sch_sweep_results.csv",
+ )
subparsers = parser.add_subparsers(dest="cmd", required=True)
square_parser = subparsers.add_parser("square_bench")
@@ -634,17 +688,20 @@ def __call__(self, parser, namespace, values, option_string=None):
"--dim-start",
type=str,
required=True,
- help="Start value for M,K,N as common separated list")
+ help="Start value for M,K,N as common separated list",
+ )
range_parser.add_argument(
"--dim-end",
type=str,
required=True,
- help="End value (inclusive) for M,K,N as common separated list")
+ help="End value (inclusive) for M,K,N as common separated list",
+ )
range_parser.add_argument(
"--dim-increment",
type=str,
required=True,
- help="Increment value for M,K,N as common separated list")
+ help="Increment value for M,K,N as common separated list",
+ )
range_parser.set_defaults(func=run_range_bench)
model_parser = subparsers.add_parser("model_bench")
@@ -655,14 +712,12 @@ def __call__(self, parser, namespace, values, option_string=None):
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES.keys(),
)
- model_parser.add_argument("--tp-sizes",
- nargs="+",
- type=int,
- default=DEFAULT_TP_SIZES)
- model_parser.add_argument("--batch-sizes",
- nargs="+",
- type=int,
- default=DEFAULT_BATCH_SIZES)
+ model_parser.add_argument(
+ "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
+ )
+ model_parser.add_argument(
+ "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
+ )
model_parser.set_defaults(func=run_model_bench)
args = parser.parse_args()
diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py
index 1e785ac8fc73..b17baff2e5f5 100644
--- a/benchmarks/kernels/benchmark_marlin.py
+++ b/benchmarks/kernels/benchmark_marlin.py
@@ -6,19 +6,34 @@
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
- GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
- GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
+ GPTQ_MARLIN_24_MAX_PARALLEL,
+ GPTQ_MARLIN_24_MIN_THREAD_N,
+ GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
+ GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
+)
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
- ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, ALLSPARK_SUPPORTED_QUANT_TYPES)
+ ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
+ ALLSPARK_SUPPORTED_QUANT_TYPES,
+)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
- GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
- MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types)
+ GPTQ_MARLIN_MAX_PARALLEL,
+ GPTQ_MARLIN_MIN_THREAD_N,
+ MARLIN_SUPPORTED_GROUP_SIZES,
+ query_marlin_supported_quant_types,
+)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
- MarlinWorkspace, marlin_quantize)
+ MarlinWorkspace,
+ marlin_quantize,
+)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
- marlin_24_quantize)
+ marlin_24_quantize,
+)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
- gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
+ gptq_pack,
+ gptq_quantize_weights,
+ quantize_weights,
+ sort_weights,
+)
from vllm.scalar_type import ScalarType
from vllm.utils import FlexibleArgumentParser
@@ -29,22 +44,29 @@
K_FULL_OPTS = [False, True]
-def bench_run(results: list[benchmark.Measurement], model: str,
- act_order: bool, is_k_full: bool, quant_type: ScalarType,
- group_size: int, size_m: int, size_k: int, size_n: int):
+def bench_run(
+ results: list[benchmark.Measurement],
+ model: str,
+ act_order: bool,
+ is_k_full: bool,
+ quant_type: ScalarType,
+ group_size: int,
+ size_m: int,
+ size_k: int,
+ size_n: int,
+):
label = "Quant Matmul"
- sub_label = ("{}, act={} k_full={}, q={}, g={}, "
- "MKN=({}x{}x{})".format(model, act_order, is_k_full,
- str(quant_type), group_size, size_m,
- size_k, size_n))
+ sub_label = "{}, act={} k_full={}, q={}, g={}, MKN=({}x{}x{})".format(
+ model, act_order, is_k_full, str(quant_type), group_size, size_m, size_k, size_n
+ )
print(f"Testing: {sub_label}")
a = torch.randn(size_m, size_k).to(torch.half).cuda()
b = torch.rand(size_k, size_n).to(torch.half).cuda()
- a_tmp = (torch.zeros(size_m, size_k).to(torch.half).cuda())
+ a_tmp = torch.zeros(size_m, size_k).to(torch.half).cuda()
# Marlin quant
(
@@ -57,14 +79,16 @@ def bench_run(results: list[benchmark.Measurement], model: str,
) = marlin_quantize(b, quant_type, group_size, act_order)
# Marlin_24 quant
- (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
- marlin_24_s) = marlin_24_quantize(b, quant_type, group_size)
+ (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = (
+ marlin_24_quantize(b, quant_type, group_size)
+ )
marlin_zp = torch.empty(0, dtype=torch.int, device=b.device)
# GPTQ quant
- (w_ref, q_w, s, g_idx,
- rand_perm) = gptq_quantize_weights(b, quant_type, group_size, act_order)
+ (w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights(
+ b, quant_type, group_size, act_order
+ )
q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
# For act_order, sort the "weights" and "g_idx"
@@ -74,32 +98,37 @@ def bench_run(results: list[benchmark.Measurement], model: str,
(q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)
# Prepare
- marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
- GPTQ_MARLIN_MAX_PARALLEL)
+ marlin_workspace = MarlinWorkspace(
+ size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
+ )
- marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
- GPTQ_MARLIN_24_MAX_PARALLEL)
+ marlin_24_workspace = MarlinWorkspace(
+ size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
+ )
marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)
# AllSpark W8A16 quant
- as_supported_case = (quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
- and group_size == -1 and not act_order and is_k_full)
+ as_supported_case = (
+ quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
+ and group_size == -1
+ and not act_order
+ and is_k_full
+ )
if as_supported_case:
properties = torch.cuda.get_device_properties(b.device.index)
sm_count = properties.multi_processor_count
sm_version = properties.major * 10 + properties.minor
- supported_arch = (sm_version >= 80 and sm_version < 90)
+ supported_arch = sm_version >= 80 and sm_version < 90
as_supported_case = as_supported_case and supported_arch
if supported_arch:
has_zp = False
- w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size,
- has_zp)
+ w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp)
qw = qw.to(torch.uint8)
- qw_reorder, s_reorder, zp_reorder = \
- ops.allspark_repack_weight(
- qw, s, zp, has_zp)
+ qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
+ qw, s, zp, has_zp
+ )
CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
globals = {
@@ -136,8 +165,7 @@ def bench_run(results: list[benchmark.Measurement], model: str,
"zp_reorder": zp_reorder if as_supported_case else None,
"sm_count": sm_count if as_supported_case else None,
"sm_version": sm_version if as_supported_case else None,
- "CUBLAS_M_THRESHOLD":
- CUBLAS_M_THRESHOLD if as_supported_case else None,
+ "CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD if as_supported_case else None,
# Kernels
"gptq_marlin_gemm": ops.gptq_marlin_gemm,
"gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
@@ -158,60 +186,63 @@ def bench_run(results: list[benchmark.Measurement], model: str,
label=label,
sub_label=sub_label,
description="pytorch_gemm",
- ).blocked_autorange(min_run_time=min_run_time))
+ ).blocked_autorange(min_run_time=min_run_time)
+ )
results.append(
benchmark.Timer(
- stmt=
- "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
+ stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="gptq_marlin_gemm_fp16",
- ).blocked_autorange(min_run_time=min_run_time))
+ ).blocked_autorange(min_run_time=min_run_time)
+ )
results.append(
benchmark.Timer(
- stmt=
- "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
+ stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="gptq_marlin_gemm_fp32",
- ).blocked_autorange(min_run_time=min_run_time))
+ ).blocked_autorange(min_run_time=min_run_time)
+ )
- if (quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
- and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES):
+ if (
+ quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
+ and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
+ ):
results.append(
benchmark.Timer(
- stmt=
- "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501
+ stmt="output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="gptq_marlin_24_gemm",
- ).blocked_autorange(min_run_time=min_run_time))
+ ).blocked_autorange(min_run_time=min_run_time)
+ )
results.append(
benchmark.Timer(
- stmt=
- "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501
+ stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="gptq_marlin_repack",
- ).blocked_autorange(min_run_time=min_run_time))
+ ).blocked_autorange(min_run_time=min_run_time)
+ )
if as_supported_case:
results.append(
benchmark.Timer(
- stmt=
- "output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501
+ stmt="output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="allspark_w8a16_gemm_fp32",
- ).blocked_autorange(min_run_time=min_run_time))
+ ).blocked_autorange(min_run_time=min_run_time)
+ )
def main(args):
@@ -233,37 +264,50 @@ def main(args):
continue
for act_order in ACT_ORDER_OPTS:
- if len(args.limit_act_order
- ) > 0 and act_order not in args.limit_act_order:
+ if (
+ len(args.limit_act_order) > 0
+ and act_order not in args.limit_act_order
+ ):
continue
for is_k_full in K_FULL_OPTS:
- if len(args.limit_k_full
- ) > 0 and is_k_full not in args.limit_k_full:
+ if (
+ len(args.limit_k_full) > 0
+ and is_k_full not in args.limit_k_full
+ ):
continue
- for quant_type in query_marlin_supported_quant_types(
- False):
- if len(args.limit_num_bits) > 0 and \
- quant_type.size_bits not in args.limit_num_bits:
+ for quant_type in query_marlin_supported_quant_types(False):
+ if (
+ len(args.limit_num_bits) > 0
+ and quant_type.size_bits not in args.limit_num_bits
+ ):
continue
for group_size in MARLIN_SUPPORTED_GROUP_SIZES:
- if len(
- args.limit_group_size
- ) > 0 and group_size not in args.limit_group_size:
+ if (
+ len(args.limit_group_size) > 0
+ and group_size not in args.limit_group_size
+ ):
continue
# For act_order, the group_size must be less than
# size_k
- if act_order and (group_size == size_k
- or group_size == -1):
+ if act_order and (group_size == size_k or group_size == -1):
continue
for size_m in args.batch_sizes:
- bench_run(results, model, act_order, is_k_full,
- quant_type, group_size, size_m,
- size_k, size_n)
+ bench_run(
+ results,
+ model,
+ act_order,
+ is_k_full,
+ quant_type,
+ group_size,
+ size_m,
+ size_k,
+ size_n,
+ )
compare = benchmark.Compare(results)
compare.print()
@@ -274,7 +318,8 @@ def main(args):
#
if __name__ == "__main__":
parser = FlexibleArgumentParser(
- description="Benchmark Marlin across specified models/shapes/batches")
+ description="Benchmark Marlin across specified models/shapes/batches"
+ )
parser.add_argument(
"--models",
nargs="+",
@@ -282,10 +327,9 @@ def main(args):
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES.keys(),
)
- parser.add_argument("--batch-sizes",
- nargs="+",
- type=int,
- default=DEFAULT_BATCH_SIZES)
+ parser.add_argument(
+ "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
+ )
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
parser.add_argument("--limit-group-size", nargs="+", type=int, default=[])
diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py
index 4e328b4d49e5..c2f7660858f5 100644
--- a/benchmarks/kernels/benchmark_moe.py
+++ b/benchmarks/kernels/benchmark_moe.py
@@ -31,56 +31,60 @@ class BenchmarkConfig(TypedDict):
num_stages: int
-def benchmark_config(config: BenchmarkConfig,
- num_tokens: int,
- num_experts: int,
- shard_intermediate_size: int,
- hidden_size: int,
- topk: int,
- dtype: torch.dtype,
- use_fp8_w8a8: bool,
- use_int8_w8a16: bool,
- num_iters: int = 100,
- block_quant_shape: List[int] = None,
- use_deep_gemm: bool = False) -> float:
+def benchmark_config(
+ config: BenchmarkConfig,
+ num_tokens: int,
+ num_experts: int,
+ shard_intermediate_size: int,
+ hidden_size: int,
+ topk: int,
+ dtype: torch.dtype,
+ use_fp8_w8a8: bool,
+ use_int8_w8a16: bool,
+ num_iters: int = 100,
+ block_quant_shape: List[int] = None,
+ use_deep_gemm: bool = False,
+) -> float:
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
if use_int8_w8a16:
- w1 = torch.randint(-127,
- 127, (
- num_experts,
- shard_intermediate_size,
- hidden_size,
- ),
- dtype=torch.int8)
- w2 = torch.randint(-127,
- 127, (
- num_experts,
- hidden_size,
- shard_intermediate_size // 2,
- ),
- dtype=torch.int8)
+ w1 = torch.randint(
+ -127,
+ 127,
+ (
+ num_experts,
+ shard_intermediate_size,
+ hidden_size,
+ ),
+ dtype=torch.int8,
+ )
+ w2 = torch.randint(
+ -127,
+ 127,
+ (
+ num_experts,
+ hidden_size,
+ shard_intermediate_size // 2,
+ ),
+ dtype=torch.int8,
+ )
else:
- w1 = torch.randn(num_experts,
- shard_intermediate_size,
- hidden_size,
- dtype=init_dtype)
- w2 = torch.randn(num_experts,
- hidden_size,
- shard_intermediate_size // 2,
- dtype=init_dtype)
- gating_output = torch.randn(num_iters,
- num_tokens,
- num_experts,
- dtype=torch.float32)
+ w1 = torch.randn(
+ num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
+ )
+ w2 = torch.randn(
+ num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
+ )
+ gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
w1_scale = None
w2_scale = None
a1_scale = None
a2_scale = None
if use_int8_w8a16:
- w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size),
- dtype=torch.float32)
+ w1_scale = torch.randn(
+ (num_experts, 2 * shard_intermediate_size), dtype=torch.float32
+ )
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
if use_fp8_w8a8:
if block_quant_shape:
@@ -93,10 +97,14 @@ def benchmark_config(config: BenchmarkConfig,
n_tiles_w2 = (K + block_n - 1) // block_n
k_tiles_w1 = (K + block_k - 1) // block_k
k_tiles_w2 = (N + block_k - 1) // block_k
- w1_scale = torch.rand((E, n_tiles_w1, k_tiles_w1),
- dtype=torch.float32) * factor_for_scale
- w2_scale = torch.rand((E, n_tiles_w2, k_tiles_w2),
- dtype=torch.float32) * factor_for_scale
+ w1_scale = (
+ torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
+ * factor_for_scale
+ )
+ w2_scale = (
+ torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
+ * factor_for_scale
+ )
else:
w1_scale = torch.randn(num_experts, dtype=torch.float32)
w2_scale = torch.randn(num_experts, dtype=torch.float32)
@@ -114,10 +122,12 @@ def prepare(i: int):
def run():
from vllm.model_executor.layers.fused_moe import override_config
+
with override_config(config):
if use_deep_gemm:
topk_weights, topk_ids, token_expert_indices = fused_topk(
- x, input_gating, topk, False)
+ x, input_gating, topk, False
+ )
return fused_experts(
x,
w1,
@@ -213,8 +223,7 @@ def get_rocm_tuning_space(use_fp16):
return param_ranges
-def get_configs_compute_bound(use_fp16,
- block_quant_shape) -> list[dict[str, int]]:
+def get_configs_compute_bound(use_fp16, block_quant_shape) -> list[dict[str, int]]:
configs: list[BenchmarkConfig] = []
if current_platform.is_rocm():
@@ -250,20 +259,25 @@ def get_configs_compute_bound(use_fp16,
if block_quant_shape is not None and not use_fp16:
block_n, block_k = block_quant_shape[0], block_quant_shape[1]
for config in configs[:]:
- if config["BLOCK_SIZE_K"] % block_k != 0 or config[
- "BLOCK_SIZE_N"] % block_n != 0:
+ if (
+ config["BLOCK_SIZE_K"] % block_k != 0
+ or config["BLOCK_SIZE_N"] % block_n != 0
+ ):
configs.remove(config)
return configs
-def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size,
- search_space, is_fp16, topk):
+def prune_rocm_search_space(
+ num_tokens, shard_intermediate_size, hidden_size, search_space, is_fp16, topk
+):
N1, K1 = shard_intermediate_size, hidden_size
N2, K2 = hidden_size, shard_intermediate_size // 2
- pruned_space_1 = prune_rocm_configs(num_tokens * topk, N1, K1,
- search_space, is_fp16)
- pruned_space_2 = prune_rocm_configs(num_tokens * topk, N2, K2,
- search_space, is_fp16)
+ pruned_space_1 = prune_rocm_configs(
+ num_tokens * topk, N1, K1, search_space, is_fp16
+ )
+ pruned_space_2 = prune_rocm_configs(
+ num_tokens * topk, N2, K2, search_space, is_fp16
+ )
search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
return search_space
@@ -301,14 +315,14 @@ def prune_rocm_configs(M, N, K, configs, is_fp16=True):
SPLIT_K = config.get("SPLIT_K", 1)
GROUP_M = config.get("GROUP_SIZE_M")
if is_fp16:
- if (matrix_instr_nonkdim > BLOCK_SIZE_M
- or matrix_instr_nonkdim > BLOCK_SIZE_N):
+ if (
+ matrix_instr_nonkdim > BLOCK_SIZE_M
+ or matrix_instr_nonkdim > BLOCK_SIZE_N
+ ):
continue
- if (matrix_instr_nonkdim >= M
- and matrix_instr_nonkdim != BLOCK_SIZE_M):
+ if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
continue
- if (matrix_instr_nonkdim >= N
- and matrix_instr_nonkdim != BLOCK_SIZE_N):
+ if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
continue
# Skip BLOCK_SIZE that is too large compare to M/N
# unless BLOCK_SIZE is already small enough
@@ -329,8 +343,10 @@ def prune_rocm_configs(M, N, K, configs, is_fp16=True):
continue
# out of shared memory resource
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
- LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a +
- BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b)
+ LDS = (
+ BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
+ + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
+ )
if LDS > 65536:
continue
# Skip small block sizes and num_warps for large gemm
@@ -364,7 +380,6 @@ def merge_unique_dicts(list1, list2):
@ray.remote(num_gpus=1)
class BenchmarkWorker:
-
def __init__(self, seed: int) -> None:
torch.set_default_device("cuda")
current_platform.seed_everything(seed)
@@ -388,36 +403,40 @@ def benchmark(
use_deep_gemm: bool = False,
) -> tuple[dict[str, int], float]:
current_platform.seed_everything(self.seed)
- dtype_str = get_config_dtype_str(dtype,
- use_int8_w8a16=use_int8_w8a16,
- use_fp8_w8a8=use_fp8_w8a8)
+ dtype_str = get_config_dtype_str(
+ dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
+ )
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
- op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
- dtype_str)
+ op_config = get_moe_configs(
+ num_experts, shard_intermediate_size // 2, dtype_str
+ )
if op_config is None:
- config = get_default_config(num_tokens,
- num_experts,
- shard_intermediate_size,
- hidden_size,
- topk,
- dtype_str,
- is_marlin=False)
+ config = get_default_config(
+ num_tokens,
+ num_experts,
+ shard_intermediate_size,
+ hidden_size,
+ topk,
+ dtype_str,
+ is_marlin=False,
+ )
else:
- config = op_config[min(op_config.keys(),
- key=lambda x: abs(x - num_tokens))]
- kernel_time = benchmark_config(config,
- num_tokens,
- num_experts,
- shard_intermediate_size,
- hidden_size,
- topk,
- dtype,
- use_fp8_w8a8,
- use_int8_w8a16,
- num_iters=100,
- block_quant_shape=block_quant_shape,
- use_deep_gemm=use_deep_gemm)
+ config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
+ kernel_time = benchmark_config(
+ config,
+ num_tokens,
+ num_experts,
+ shard_intermediate_size,
+ hidden_size,
+ topk,
+ dtype,
+ use_fp8_w8a8,
+ use_int8_w8a16,
+ num_iters=100,
+ block_quant_shape=block_quant_shape,
+ use_deep_gemm=use_deep_gemm,
+ )
return config, kernel_time
def tune(
@@ -438,10 +457,14 @@ def tune(
best_time = float("inf")
if current_platform.is_rocm():
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
- search_space = prune_rocm_search_space(num_tokens,
- shard_intermediate_size,
- hidden_size, search_space,
- is_fp16, topk)
+ search_space = prune_rocm_search_space(
+ num_tokens,
+ shard_intermediate_size,
+ hidden_size,
+ search_space,
+ is_fp16,
+ topk,
+ )
need_device_guard = False
if current_platform.is_rocm():
@@ -449,8 +472,7 @@ def tune(
if visible_device != f"{self.device_id}":
need_device_guard = True
- with torch.cuda.device(
- self.device_id) if need_device_guard else nullcontext():
+ with torch.cuda.device(self.device_id) if need_device_guard else nullcontext():
for config in tqdm(search_space):
try:
kernel_time = benchmark_config(
@@ -465,7 +487,8 @@ def tune(
use_int8_w8a16,
num_iters=20,
block_quant_shape=block_quant_shape,
- use_deep_gemm=use_deep_gemm)
+ use_deep_gemm=use_deep_gemm,
+ )
except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile.
continue
@@ -481,42 +504,44 @@ def tune(
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
return {
- "BLOCK_SIZE_M":
- config["BLOCK_SIZE_M"],
- "BLOCK_SIZE_N":
- config["BLOCK_SIZE_N"],
- "BLOCK_SIZE_K":
- config["BLOCK_SIZE_K"],
- "GROUP_SIZE_M":
- config["GROUP_SIZE_M"],
- "num_warps":
- config["num_warps"],
- "num_stages":
- config["num_stages"],
- **({
- "waves_per_eu": config["waves_per_eu"]
- } if "waves_per_eu" in config else {}),
- **({
- "matrix_instr_nonkdim": config["matrix_instr_nonkdim"]
- } if "matrix_instr_nonkdim" in config else {}),
- **({
- "kpack": config["kpack"]
- } if "kpack" in config else {}),
+ "BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
+ "BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
+ "BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
+ "GROUP_SIZE_M": config["GROUP_SIZE_M"],
+ "num_warps": config["num_warps"],
+ "num_stages": config["num_stages"],
+ **(
+ {"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {}
+ ),
+ **(
+ {"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]}
+ if "matrix_instr_nonkdim" in config
+ else {}
+ ),
+ **({"kpack": config["kpack"]} if "kpack" in config else {}),
}
-def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
- shard_intermediate_size: int, hidden_size: int, topk: int,
- dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool,
- block_quant_shape: List[int]) -> None:
- dtype_str = get_config_dtype_str(dtype,
- use_int8_w8a16=use_int8_w8a16,
- use_fp8_w8a8=use_fp8_w8a8)
+def save_configs(
+ configs: dict[int, BenchmarkConfig],
+ num_experts: int,
+ shard_intermediate_size: int,
+ hidden_size: int,
+ topk: int,
+ dtype: torch.dtype,
+ use_fp8_w8a8: bool,
+ use_int8_w8a16: bool,
+ block_quant_shape: List[int],
+) -> None:
+ dtype_str = get_config_dtype_str(
+ dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
+ )
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
- filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
- dtype_str, block_quant_shape)
+ filename = get_config_file_name(
+ num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape
+ )
print(f"Writing best config to {filename}...")
with open(filename, "w") as f:
@@ -525,18 +550,16 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
def get_weight_block_size_safety(config, default_value=None):
-
- quantization_config = getattr(config, 'quantization_config', {})
+ quantization_config = getattr(config, "quantization_config", {})
if isinstance(quantization_config, dict):
- return quantization_config.get('weight_block_size', default_value)
+ return quantization_config.get("weight_block_size", default_value)
return default_value
def main(args: argparse.Namespace):
print(args)
- config = get_config(model=args.model,
- trust_remote_code=args.trust_remote_code)
+ config = get_config(model=args.model, trust_remote_code=args.trust_remote_code)
if args.model_prefix:
config = getattr(config, args.model_prefix)
config = SimpleNamespace(**config)
@@ -551,14 +574,12 @@ def main(args: argparse.Namespace):
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
- elif (config.architectures[0]
- in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM")):
+ elif config.architectures[0] in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"):
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
- elif config.architectures[0] in ("Qwen2MoeForCausalLM",
- "Qwen3MoeForCausalLM"):
+ elif config.architectures[0] in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"):
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
@@ -573,16 +594,35 @@ def main(args: argparse.Namespace):
shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = config.hidden_size
- dtype = torch.float16 if current_platform.is_rocm() else getattr(
- torch, config.torch_dtype)
+ dtype = (
+ torch.float16
+ if current_platform.is_rocm()
+ else getattr(torch, config.torch_dtype)
+ )
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16"
block_quant_shape = get_weight_block_size_safety(config)
if args.batch_size is None:
batch_sizes = [
- 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
- 2048, 3072, 4096
+ 1,
+ 2,
+ 4,
+ 8,
+ 16,
+ 24,
+ 32,
+ 48,
+ 64,
+ 96,
+ 128,
+ 256,
+ 512,
+ 1024,
+ 1536,
+ 2048,
+ 3072,
+ 4096,
]
else:
batch_sizes = [args.batch_size]
@@ -593,7 +633,8 @@ def main(args: argparse.Namespace):
# Ray will set ROCR_VISIBLE_DEVICES for device visibility
logger.warning(
"Ray uses ROCR_VISIBLE_DEVICES to control device accessibility."
- "Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES.")
+ "Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES."
+ )
val = os.environ["HIP_VISIBLE_DEVICES"]
os.environ["ROCR_VISIBLE_DEVICES"] = val
del os.environ["HIP_VISIBLE_DEVICES"]
@@ -620,25 +661,59 @@ def _distribute(method: str, inputs: list[Any]) -> list[Any]:
start = time.time()
configs = _distribute(
- "tune", [(batch_size, E, shard_intermediate_size, hidden_size,
- topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space,
- block_quant_shape, use_deep_gemm)
- for batch_size in batch_sizes])
+ "tune",
+ [
+ (
+ batch_size,
+ E,
+ shard_intermediate_size,
+ hidden_size,
+ topk,
+ dtype,
+ use_fp8_w8a8,
+ use_int8_w8a16,
+ search_space,
+ block_quant_shape,
+ use_deep_gemm,
+ )
+ for batch_size in batch_sizes
+ ],
+ )
best_configs = {
- M: sort_config(config)
- for M, config in zip(batch_sizes, configs)
+ M: sort_config(config) for M, config in zip(batch_sizes, configs)
}
- save_configs(best_configs, E, shard_intermediate_size, hidden_size,
- topk, dtype, use_fp8_w8a8, use_int8_w8a16,
- block_quant_shape)
+ save_configs(
+ best_configs,
+ E,
+ shard_intermediate_size,
+ hidden_size,
+ topk,
+ dtype,
+ use_fp8_w8a8,
+ use_int8_w8a16,
+ block_quant_shape,
+ )
end = time.time()
print(f"Tuning took {end - start:.2f} seconds")
else:
outputs = _distribute(
"benchmark",
- [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
- use_fp8_w8a8, use_int8_w8a16, block_quant_shape, use_deep_gemm)
- for batch_size in batch_sizes])
+ [
+ (
+ batch_size,
+ E,
+ shard_intermediate_size,
+ hidden_size,
+ topk,
+ dtype,
+ use_fp8_w8a8,
+ use_int8_w8a16,
+ block_quant_shape,
+ use_deep_gemm,
+ )
+ for batch_size in batch_sizes
+ ],
+ )
for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
print(f"Batch size: {batch_size}, config: {config}")
@@ -647,18 +722,15 @@ def _distribute(method: str, inputs: list[Any]) -> list[Any]:
if __name__ == "__main__":
parser = FlexibleArgumentParser()
- parser.add_argument("--model",
- type=str,
- default="mistralai/Mixtral-8x7B-Instruct-v0.1")
- parser.add_argument("--tp-size",
- "-tp",
- "--tensor-parallel-size",
- type=int,
- default=2)
- parser.add_argument("--dtype",
- type=str,
- choices=["auto", "fp8_w8a8", "int8_w8a16"],
- default="auto")
+ parser.add_argument(
+ "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
+ )
+ parser.add_argument(
+ "--tp-size", "-tp", "--tensor-parallel-size", type=int, default=2
+ )
+ parser.add_argument(
+ "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
+ )
parser.add_argument("--use-deep-gemm", action="store_true")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False)
diff --git a/benchmarks/kernels/benchmark_moe_permute_unpermute.py b/benchmarks/kernels/benchmark_moe_permute_unpermute.py
index 937df9624651..333986fdf5ef 100644
--- a/benchmarks/kernels/benchmark_moe_permute_unpermute.py
+++ b/benchmarks/kernels/benchmark_moe_permute_unpermute.py
@@ -8,7 +8,9 @@
from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
- _moe_permute, _moe_unpermute_and_reduce)
+ _moe_permute,
+ _moe_unpermute_and_reduce,
+)
from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import *
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
@@ -27,15 +29,17 @@ class BenchmarkConfig(TypedDict):
num_stages: int
-def benchmark_permute(num_tokens: int,
- num_experts: int,
- hidden_size: int,
- topk: int,
- dtype: torch.dtype,
- use_fp8_w8a8: bool,
- use_int8_w8a16: bool,
- num_iters: int = 100,
- use_customized_permute: bool = False) -> float:
+def benchmark_permute(
+ num_tokens: int,
+ num_experts: int,
+ hidden_size: int,
+ topk: int,
+ dtype: torch.dtype,
+ use_fp8_w8a8: bool,
+ use_int8_w8a16: bool,
+ num_iters: int = 100,
+ use_customized_permute: bool = False,
+) -> float:
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
# output_hidden_states = torch.empty_like(hidden_states)
@@ -46,36 +50,41 @@ def benchmark_permute(num_tokens: int,
align_block_size = None
qhidden_states = hidden_states
- gating_output = torch.randn(num_iters,
- num_tokens,
- num_experts,
- dtype=torch.float32)
+ gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
topk_weights, topk_ids, token_expert_indices = fused_topk(
- qhidden_states, input_gating, topk, False)
+ qhidden_states, input_gating, topk, False
+ )
def prepare(i: int):
input_gating.copy_(gating_output[i])
def run():
if use_customized_permute:
- (permuted_hidden_states, first_token_off, inv_perm_idx,
- m_indices) = moe_permute(
- qhidden_states,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- token_expert_indices=token_expert_indices,
- topk=topk,
- n_expert=num_experts,
- n_local_expert=num_experts,
- expert_map=None,
- align_block_size=align_block_size,
- )
+ (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = (
+ moe_permute(
+ qhidden_states,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ token_expert_indices=token_expert_indices,
+ topk=topk,
+ n_expert=num_experts,
+ n_local_expert=num_experts,
+ expert_map=None,
+ align_block_size=align_block_size,
+ )
+ )
else:
- (permuted_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
- inv_perm) = _moe_permute(qhidden_states, None, topk_ids,
- num_experts, None, align_block_size)
+ (
+ permuted_hidden_states,
+ a1q_scale,
+ sorted_token_ids,
+ expert_ids,
+ inv_perm,
+ ) = _moe_permute(
+ qhidden_states, None, topk_ids, num_experts, None, align_block_size
+ )
# JIT compilation & warmup
run()
@@ -111,15 +120,17 @@ def run():
return avg
-def benchmark_unpermute(num_tokens: int,
- num_experts: int,
- hidden_size: int,
- topk: int,
- dtype: torch.dtype,
- use_fp8_w8a8: bool,
- use_int8_w8a16: bool,
- num_iters: int = 100,
- use_customized_permute: bool = False) -> float:
+def benchmark_unpermute(
+ num_tokens: int,
+ num_experts: int,
+ hidden_size: int,
+ topk: int,
+ dtype: torch.dtype,
+ use_fp8_w8a8: bool,
+ use_int8_w8a16: bool,
+ num_iters: int = 100,
+ use_customized_permute: bool = False,
+) -> float:
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
output_hidden_states = torch.empty_like(hidden_states)
@@ -133,46 +144,74 @@ def benchmark_unpermute(num_tokens: int,
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
topk_weights, topk_ids, token_expert_indices = fused_topk(
- qhidden_states, input_gating, topk, False)
+ qhidden_states, input_gating, topk, False
+ )
def prepare():
if use_customized_permute:
- (permuted_hidden_states, first_token_off, inv_perm_idx,
- m_indices) = moe_permute(
- qhidden_states,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- token_expert_indices=token_expert_indices,
- topk=topk,
- n_expert=num_experts,
- n_local_expert=num_experts,
- expert_map=None,
- align_block_size=align_block_size,
- )
+ (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = (
+ moe_permute(
+ qhidden_states,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ token_expert_indices=token_expert_indices,
+ topk=topk,
+ n_expert=num_experts,
+ n_local_expert=num_experts,
+ expert_map=None,
+ align_block_size=align_block_size,
+ )
+ )
# convert to fp16/bf16 as gemm output
- return (permuted_hidden_states.to(dtype), first_token_off,
- inv_perm_idx, m_indices)
+ return (
+ permuted_hidden_states.to(dtype),
+ first_token_off,
+ inv_perm_idx,
+ m_indices,
+ )
else:
- (permuted_qhidden_states, a1q_scale, sorted_token_ids, expert_ids,
- inv_perm) = _moe_permute(qhidden_states, None, topk_ids,
- num_experts, None, align_block_size)
+ (
+ permuted_qhidden_states,
+ a1q_scale,
+ sorted_token_ids,
+ expert_ids,
+ inv_perm,
+ ) = _moe_permute(
+ qhidden_states, None, topk_ids, num_experts, None, align_block_size
+ )
# convert to fp16/bf16 as gemm output
- return (permuted_qhidden_states.to(dtype), a1q_scale,
- sorted_token_ids, expert_ids, inv_perm)
+ return (
+ permuted_qhidden_states.to(dtype),
+ a1q_scale,
+ sorted_token_ids,
+ expert_ids,
+ inv_perm,
+ )
def run(input: tuple):
if use_customized_permute:
- (permuted_hidden_states, first_token_off, inv_perm_idx,
- m_indices) = input
- moe_unpermute(permuted_hidden_states, topk_weights, topk_ids,
- inv_perm_idx, first_token_off, topk, num_experts,
- num_experts)
+ (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = input
+ moe_unpermute(
+ permuted_hidden_states,
+ topk_weights,
+ topk_ids,
+ inv_perm_idx,
+ first_token_off,
+ topk,
+ num_experts,
+ num_experts,
+ )
else:
- (permuted_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
- inv_perm) = input
- _moe_unpermute_and_reduce(output_hidden_states,
- permuted_hidden_states, inv_perm,
- topk_weights)
+ (
+ permuted_hidden_states,
+ a1q_scale,
+ sorted_token_ids,
+ expert_ids,
+ inv_perm,
+ ) = input
+ _moe_unpermute_and_reduce(
+ output_hidden_states, permuted_hidden_states, inv_perm, topk_weights
+ )
# JIT compilation & warmup
input = prepare()
@@ -209,7 +248,6 @@ def run(input: tuple):
@ray.remote(num_gpus=1)
class BenchmarkWorker:
-
def __init__(self, seed: int) -> None:
torch.set_default_device("cuda")
current_platform.seed_everything(seed)
@@ -241,7 +279,8 @@ def benchmark(
use_fp8_w8a8,
use_int8_w8a16,
num_iters=100,
- use_customized_permute=use_customized_permute)
+ use_customized_permute=use_customized_permute,
+ )
unpermute_time = benchmark_unpermute(
num_tokens,
num_experts,
@@ -251,15 +290,15 @@ def benchmark(
use_fp8_w8a8,
use_int8_w8a16,
num_iters=100,
- use_customized_permute=use_customized_permute)
+ use_customized_permute=use_customized_permute,
+ )
return permute_time, unpermute_time
def get_weight_block_size_safety(config, default_value=None):
-
- quantization_config = getattr(config, 'quantization_config', {})
+ quantization_config = getattr(config, "quantization_config", {})
if isinstance(quantization_config, dict):
- return quantization_config.get('weight_block_size', default_value)
+ return quantization_config.get("weight_block_size", default_value)
return default_value
@@ -267,20 +306,21 @@ def main(args: argparse.Namespace):
print(args)
config = AutoConfig.from_pretrained(
- args.model, trust_remote_code=args.trust_remote_code)
+ args.model, trust_remote_code=args.trust_remote_code
+ )
if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
- elif (config.architectures[0] == "DeepseekV3ForCausalLM"
- or config.architectures[0] == "DeepseekV2ForCausalLM"):
+ elif (
+ config.architectures[0] == "DeepseekV3ForCausalLM"
+ or config.architectures[0] == "DeepseekV2ForCausalLM"
+ ):
E = config.n_routed_experts
topk = config.num_experts_per_tok
- elif config.architectures[0] in [
- "Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"
- ]:
+ elif config.architectures[0] in ["Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"]:
E = config.num_experts
topk = config.num_experts_per_tok
@@ -299,8 +339,24 @@ def main(args: argparse.Namespace):
if args.batch_size is None:
batch_sizes = [
- 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
- 2048, 3072, 4096
+ 1,
+ 2,
+ 4,
+ 8,
+ 16,
+ 24,
+ 32,
+ 48,
+ 64,
+ 96,
+ 128,
+ 256,
+ 512,
+ 1024,
+ 1536,
+ 2048,
+ 3072,
+ 4096,
]
else:
batch_sizes = [args.batch_size]
@@ -321,9 +377,21 @@ def _distribute(method: str, inputs: list[Any]) -> list[Any]:
return ray.get(outputs)
outputs = _distribute(
- "benchmark", [(batch_size, E, hidden_size, topk, dtype, use_fp8_w8a8,
- use_int8_w8a16, use_customized_permute)
- for batch_size in batch_sizes])
+ "benchmark",
+ [
+ (
+ batch_size,
+ E,
+ hidden_size,
+ topk,
+ dtype,
+ use_fp8_w8a8,
+ use_int8_w8a16,
+ use_customized_permute,
+ )
+ for batch_size in batch_sizes
+ ],
+ )
for batch_size, (permute, unpermute) in zip(batch_sizes, outputs):
print(f"Batch size: {batch_size}")
@@ -333,13 +401,12 @@ def _distribute(method: str, inputs: list[Any]) -> list[Any]:
if __name__ == "__main__":
parser = FlexibleArgumentParser()
- parser.add_argument("--model",
- type=str,
- default="mistralai/Mixtral-8x7B-Instruct-v0.1")
- parser.add_argument("--dtype",
- type=str,
- choices=["auto", "fp8_w8a8", "int8_w8a16"],
- default="auto")
+ parser.add_argument(
+ "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
+ )
+ parser.add_argument(
+ "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
+ )
parser.add_argument("--use-customized-permute", action="store_true")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False)
diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py
index 2625239b08ef..54f05e723226 100644
--- a/benchmarks/kernels/benchmark_paged_attention.py
+++ b/benchmarks/kernels/benchmark_paged_attention.py
@@ -9,8 +9,11 @@
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
-from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
- create_kv_caches_with_random)
+from vllm.utils import (
+ STR_DTYPE_TO_TORCH_DTYPE,
+ FlexibleArgumentParser,
+ create_kv_caches_with_random,
+)
logger = init_logger(__name__)
@@ -38,19 +41,15 @@ def main(
current_platform.seed_everything(seed)
scale = float(1.0 / (head_size**0.5))
- query = torch.empty(num_seqs,
- num_query_heads,
- head_size,
- dtype=dtype,
- device=device)
+ query = torch.empty(
+ num_seqs, num_query_heads, head_size, dtype=dtype, device=device
+ )
query.uniform_(-scale, scale)
assert num_query_heads % num_kv_heads == 0
alibi_slopes = None
if use_alibi:
- alibi_slopes = torch.randn(num_query_heads,
- dtype=torch.float,
- device=device)
+ alibi_slopes = torch.randn(num_query_heads, dtype=torch.float, device=device)
seq_lens = [seq_len for _ in range(num_seqs)]
max_seq_len = max(seq_lens)
@@ -61,24 +60,23 @@ def main(
block_tables_lst: list[list[int]] = []
for _ in range(num_seqs):
block_table = [
- random.randint(0, NUM_BLOCKS - 1)
- for _ in range(max_num_blocks_per_seq)
+ random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)
]
block_tables_lst.append(block_table)
- block_tables = torch.tensor(block_tables_lst,
- dtype=torch.int,
- device=device)
+ block_tables = torch.tensor(block_tables_lst, dtype=torch.int, device=device)
# Create the KV cache.
- key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS,
- block_size,
- 1,
- num_kv_heads,
- head_size,
- kv_cache_dtype,
- dtype,
- device=device)
+ key_caches, value_caches = create_kv_caches_with_random(
+ NUM_BLOCKS,
+ block_size,
+ 1,
+ num_kv_heads,
+ head_size,
+ kv_cache_dtype,
+ dtype,
+ device=device,
+ )
key_cache, value_cache = key_caches[0], value_caches[0]
# Prepare for the paged attention kernel.
@@ -86,11 +84,11 @@ def main(
if version == "v2":
if current_platform.is_rocm():
global PARTITION_SIZE
- if not args.custom_paged_attn:
+ if not args.custom_paged_attn and not current_platform.is_navi():
PARTITION_SIZE = 1024
else:
PARTITION_SIZE = PARTITION_SIZE_ROCM
- num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
+ num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
tmp_output = torch.empty(
size=(num_seqs, num_query_heads, num_partitions, head_size),
dtype=output.dtype,
@@ -110,9 +108,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
start_time = time.perf_counter()
# Using default kv_scale
- k_scale = v_scale = torch.tensor(1.0,
- dtype=torch.float32,
- device=device)
+ k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
for _ in range(num_iters):
if version == "v1":
@@ -166,6 +162,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
scale,
block_tables,
seq_lens,
+ None,
block_size,
max_seq_len,
alibi_slopes,
@@ -195,30 +192,29 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
print(f"Kernel running time: {latency * 1000000:.3f} us")
-if __name__ == '__main__':
- logger.warning("This script benchmarks the paged attention kernel. "
- "By default this is no longer used in vLLM inference.")
+if __name__ == "__main__":
+ logger.warning(
+ "This script benchmarks the paged attention kernel. "
+ "By default this is no longer used in vLLM inference."
+ )
- parser = FlexibleArgumentParser(
- description="Benchmark the paged attention kernel.")
- parser.add_argument("--version",
- type=str,
- choices=["v1", "v2"],
- default="v2")
+ parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.")
+ parser.add_argument("--version", type=str, choices=["v1", "v2"], default="v2")
parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--seq-len", type=int, default=4096)
parser.add_argument("--num-query-heads", type=int, default=64)
parser.add_argument("--num-kv-heads", type=int, default=8)
- parser.add_argument("--head-size",
- type=int,
- choices=[64, 80, 96, 112, 120, 128, 192, 256],
- default=128)
+ parser.add_argument(
+ "--head-size",
+ type=int,
+ choices=[64, 80, 96, 112, 120, 128, 192, 256],
+ default=128,
+ )
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
parser.add_argument("--use-alibi", action="store_true")
- parser.add_argument("--dtype",
- type=str,
- choices=["half", "bfloat16", "float"],
- default="half")
+ parser.add_argument(
+ "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
+ )
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--profile", action="store_true")
parser.add_argument(
@@ -228,10 +224,11 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
default="auto",
help="Data type for kv cache storage. If 'auto', will use model "
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
- "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
- parser.add_argument("--custom-paged-attn",
- action="store_true",
- help="Use custom paged attention")
+ "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)",
+ )
+ parser.add_argument(
+ "--custom-paged-attn", action="store_true", help="Use custom paged attention"
+ )
args = parser.parse_args()
print(args)
diff --git a/benchmarks/kernels/benchmark_quant.py b/benchmarks/kernels/benchmark_quant.py
index b643897a60ee..2463dfebe83c 100644
--- a/benchmarks/kernels/benchmark_quant.py
+++ b/benchmarks/kernels/benchmark_quant.py
@@ -10,15 +10,17 @@
@torch.inference_mode()
-def main(num_tokens: int,
- hidden_size: int,
- static_scale: bool,
- quant_dtype: torch.dtype,
- dtype: torch.dtype,
- seed: int = 0,
- do_profile: bool = False,
- num_warmup_iters: int = 5,
- num_iters: int = 100) -> None:
+def main(
+ num_tokens: int,
+ hidden_size: int,
+ static_scale: bool,
+ quant_dtype: torch.dtype,
+ dtype: torch.dtype,
+ seed: int = 0,
+ do_profile: bool = False,
+ num_warmup_iters: int = 5,
+ num_iters: int = 100,
+) -> None:
current_platform.seed_everything(seed)
torch.set_default_device("cuda")
@@ -56,7 +58,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
print(f"Kernel running time: {latency * 1000000:.3f} us")
-if __name__ == '__main__':
+if __name__ == "__main__":
def to_torch_dtype(dt):
if dt == "int8":
@@ -66,37 +68,40 @@ def to_torch_dtype(dt):
raise ValueError(f"Unsupported dtype: {dt}")
parser = FlexibleArgumentParser(
- description="Benchmark the quantization (fp8 or int8) kernel.")
+ description="Benchmark the quantization (fp8 or int8) kernel."
+ )
parser.add_argument("--num-tokens", type=int, default=4096)
parser.add_argument("--hidden-size", type=int, default=8192)
parser.add_argument("--static-scale", action="store_true")
- parser.add_argument("--quant-dtype",
- type=str,
- choices=["fp8", "int8"],
- default="int8")
- parser.add_argument("--dtype",
- type=str,
- choices=["half", "bfloat16", "float"],
- default="half")
+ parser.add_argument(
+ "--quant-dtype", type=str, choices=["fp8", "int8"], default="int8"
+ )
+ parser.add_argument(
+ "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
+ )
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--profile", action="store_true")
parser.add_argument("--num-warmup-iters", type=int, default=5)
- parser.add_argument("--num-iters",
- type=int,
- default=100,
- help="Number of benchmark iterations. "
- "If --profile is set, this number is ignored")
+ parser.add_argument(
+ "--num-iters",
+ type=int,
+ default=100,
+ help="Number of benchmark iterations. "
+ "If --profile is set, this number is ignored",
+ )
args = parser.parse_args()
print(args)
- main(num_tokens=args.num_tokens,
- hidden_size=args.hidden_size,
- static_scale=args.static_scale,
- quant_dtype=to_torch_dtype(args.quant_dtype),
- dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
- seed=args.seed,
- do_profile=args.profile,
- num_warmup_iters=args.num_warmup_iters,
- num_iters=args.num_iters)
+ main(
+ num_tokens=args.num_tokens,
+ hidden_size=args.hidden_size,
+ static_scale=args.static_scale,
+ quant_dtype=to_torch_dtype(args.quant_dtype),
+ dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
+ seed=args.seed,
+ do_profile=args.profile,
+ num_warmup_iters=args.num_warmup_iters,
+ num_iters=args.num_iters,
+ )
diff --git a/benchmarks/kernels/benchmark_rmsnorm.py b/benchmarks/kernels/benchmark_rmsnorm.py
index 09a319ccf1d1..d720083b6150 100644
--- a/benchmarks/kernels/benchmark_rmsnorm.py
+++ b/benchmarks/kernels/benchmark_rmsnorm.py
@@ -12,7 +12,6 @@
class HuggingFaceRMSNorm(nn.Module):
-
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
@@ -114,23 +113,19 @@ def rmsnorm_vllm(
def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
dtype = torch.bfloat16
- x = torch.randn(batch_size,
- seq_len,
- hidden_size,
- dtype=dtype,
- device="cuda")
+ x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
residual = torch.randn_like(x) if use_residual else None
output_naive = rmsnorm_naive(
- x.clone(), weight,
- residual.clone() if residual is not None else None)
+ x.clone(), weight, residual.clone() if residual is not None else None
+ )
output_flashinfer = rmsnorm_flashinfer(
- x.clone(), weight,
- residual.clone() if residual is not None else None)
+ x.clone(), weight, residual.clone() if residual is not None else None
+ )
output_vllm = rmsnorm_vllm(
- x.clone(), weight,
- residual.clone() if residual is not None else None)
+ x.clone(), weight, residual.clone() if residual is not None else None
+ )
if use_residual:
output_naive = output_naive[0]
@@ -141,9 +136,9 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
print(f"FlashInfer output={output_flashinfer}")
print(f"vLLM output={output_vllm}")
- if torch.allclose(output_naive, output_flashinfer, atol=1e-2,
- rtol=1e-2) and torch.allclose(
- output_naive, output_vllm, atol=1e-2, rtol=1e-2):
+ if torch.allclose(
+ output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
+ ) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
print("✅ All implementations match")
else:
print("❌ Implementations differ")
@@ -152,12 +147,10 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
batch_size_range = [2**i for i in range(0, 7, 2)]
seq_length_range = [2**i for i in range(6, 11, 1)]
head_num_range = [32, 48]
-configs = list(
- itertools.product(head_num_range, batch_size_range, seq_length_range))
+configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range))
def get_benchmark(use_residual):
-
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["head_num", "batch_size", "seq_len"],
@@ -167,19 +160,15 @@ def get_benchmark(use_residual):
line_names=["HuggingFace", "FlashInfer", "vLLM"],
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
ylabel="us",
- plot_name=
- f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual",
+ plot_name=f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual",
args={},
- ))
+ )
+ )
def benchmark(head_num, batch_size, seq_len, provider):
dtype = torch.bfloat16
hidden_size = head_num * 128 # assuming head_dim = 128
- x = torch.randn(batch_size,
- seq_len,
- hidden_size,
- dtype=dtype,
- device="cuda")
+ x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
residual = torch.randn_like(x) if use_residual else None
@@ -240,9 +229,9 @@ def benchmark(head_num, batch_size, seq_len, provider):
default=4096,
help="Hidden size (2nd dimension) of the sequence",
)
- parser.add_argument("--use-residual",
- action="store_true",
- help="Whether to use residual connection")
+ parser.add_argument(
+ "--use-residual", action="store_true", help="Whether to use residual connection"
+ )
parser.add_argument(
"--save-path",
type=str,
@@ -253,10 +242,12 @@ def benchmark(head_num, batch_size, seq_len, provider):
args = parser.parse_args()
# Run correctness test
- calculate_diff(batch_size=args.batch_size,
- seq_len=args.seq_len,
- hidden_size=args.hidden_size,
- use_residual=args.use_residual)
+ calculate_diff(
+ batch_size=args.batch_size,
+ seq_len=args.seq_len,
+ hidden_size=args.hidden_size,
+ use_residual=args.use_residual,
+ )
# Get the benchmark function with proper use_residual setting
benchmark = get_benchmark(args.use_residual)
diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py
index 05d24fc4b16d..110d36db157f 100644
--- a/benchmarks/kernels/benchmark_rope.py
+++ b/benchmarks/kernels/benchmark_rope.py
@@ -6,8 +6,7 @@
import nvtx
import torch
-from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding,
- get_rope)
+from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding, get_rope
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser
@@ -32,40 +31,49 @@ def benchmark_rope_kernels_multi_lora(
# silulating serving 4 LoRAs
scaling_factors = [1, 2, 4, 8]
# batched RoPE can take multiple scaling factors
- batched_rope = get_rope(head_size, rotary_dim, max_position, base,
- is_neox_style, {
- "rope_type": "linear",
- "factor": tuple(scaling_factors)
- })
+ batched_rope = get_rope(
+ head_size,
+ rotary_dim,
+ max_position,
+ base,
+ is_neox_style,
+ {"rope_type": "linear", "factor": tuple(scaling_factors)},
+ )
# non-batched RoPE takes only one scaling factor, we create multiple
# instances to simulate the same behavior
non_batched_ropes: list[RotaryEmbedding] = []
for scaling_factor in scaling_factors:
non_batched_ropes.append(
- get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
- {
- "rope_type": "linear",
- "factor": (scaling_factor, )
- }))
+ get_rope(
+ head_size,
+ rotary_dim,
+ max_position,
+ base,
+ is_neox_style,
+ {"rope_type": "linear", "factor": (scaling_factor,)},
+ )
+ )
positions = torch.randint(0, max_position, (batch_size, seq_len))
- query = torch.randn(batch_size,
- seq_len,
- num_heads * head_size,
- dtype=dtype)
+ query = torch.randn(batch_size, seq_len, num_heads * head_size, dtype=dtype)
key = torch.randn_like(query)
# create query offsets for batched RoPE, we concat multiple kv cache
# together and each query needs to find the right kv cache of its type
offset_map = torch.tensor(
list(
- accumulate([0] + [
- max_position * scaling_factor * 2
- for scaling_factor in scaling_factors[:-1]
- ])))
- query_types = torch.randint(0,
- len(scaling_factors), (batch_size, seq_len),
- device=device)
+ accumulate(
+ [0]
+ + [
+ max_position * scaling_factor * 2
+ for scaling_factor in scaling_factors[:-1]
+ ]
+ )
+ )
+ )
+ query_types = torch.randint(
+ 0, len(scaling_factors), (batch_size, seq_len), device=device
+ )
# map query types to offsets
query_offsets = offset_map[query_types]
# the kernel takes flattened offsets
@@ -86,27 +94,28 @@ def benchmark_rope_kernels_multi_lora(
torch.cuda.synchronize()
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = FlexibleArgumentParser(
- description="Benchmark the rotary embedding kernels.")
+ description="Benchmark the rotary embedding kernels."
+ )
parser.add_argument("--is-neox-style", type=bool, default=True)
parser.add_argument("--batch-size", type=int, default=16)
parser.add_argument("--seq-len", type=int, default=512)
parser.add_argument("--num-heads", type=int, default=8)
- parser.add_argument("--head-size",
- type=int,
- choices=[64, 80, 96, 112, 120, 128, 192, 256],
- default=128)
+ parser.add_argument(
+ "--head-size",
+ type=int,
+ choices=[64, 80, 96, 112, 120, 128, 192, 256],
+ default=128,
+ )
parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
- parser.add_argument("--dtype",
- type=str,
- choices=["bfloat16", "float"],
- default="float")
+ parser.add_argument(
+ "--dtype", type=str, choices=["bfloat16", "float"], default="float"
+ )
parser.add_argument("--seed", type=int, default=0)
- parser.add_argument("--device",
- type=str,
- choices=["cuda:0", "cuda:1"],
- default="cuda:0")
+ parser.add_argument(
+ "--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0"
+ )
args = parser.parse_args()
print(args)
diff --git a/benchmarks/kernels/benchmark_w8a8_block_fp8.py b/benchmarks/kernels/benchmark_w8a8_block_fp8.py
index 8f07bc8ca52e..6315c1ee6cdd 100644
--- a/benchmarks/kernels/benchmark_w8a8_block_fp8.py
+++ b/benchmarks/kernels/benchmark_w8a8_block_fp8.py
@@ -14,14 +14,16 @@
import triton
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
- _w8a8_block_fp8_matmul)
+ _w8a8_block_fp8_matmul,
+)
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser
mp.set_start_method("spawn", force=True)
-assert current_platform.is_cuda(
-), "Only support tune w8a8 block fp8 kernel on CUDA device."
+assert current_platform.is_cuda(), (
+ "Only support tune w8a8 block fp8 kernel on CUDA device."
+)
DTYPE_MAP = {
"float32": torch.float32,
@@ -40,7 +42,7 @@ def w8a8_block_matmul(
config: dict[str, Any],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
- """This function performs matrix multiplication with
+ """This function performs matrix multiplication with
block-wise quantization.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
@@ -51,7 +53,7 @@ def w8a8_block_matmul(
B: The input tensor, e.g., weight.
As: The per-token-group quantization scale for `A`.
Bs: The per-block quantization scale for `B`.
- block_size: The block size for per-block quantization.
+ block_size: The block size for per-block quantization.
It should be 2-dim, e.g., [128, 128].
output_dytpe: The dtype of the returned tensor.
@@ -71,18 +73,18 @@ def w8a8_block_matmul(
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
- C_shape = A.shape[:-1] + (N, )
+ C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)
def grid(META):
- return (triton.cdiv(M, META["BLOCK_SIZE_M"]) *
- triton.cdiv(N, META["BLOCK_SIZE_N"]), )
+ return (
+ triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+ )
if A.dtype == torch.float8_e4m3fn:
kernel = _w8a8_block_fp8_matmul
else:
- raise RuntimeError(
- "Currently, only support tune w8a8 block fp8 kernel.")
+ raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.")
kernel[grid](
A,
@@ -119,14 +121,16 @@ def get_configs_compute_bound():
for block_n in [32, 64, 128, 256]:
for num_warps in [4, 8]:
for group_size in [1, 16, 32, 64]:
- configs.append({
- "BLOCK_SIZE_M": block_m,
- "BLOCK_SIZE_N": block_n,
- "BLOCK_SIZE_K": block_k,
- "GROUP_SIZE_M": group_size,
- "num_warps": num_warps,
- "num_stages": num_stages,
- })
+ configs.append(
+ {
+ "BLOCK_SIZE_M": block_m,
+ "BLOCK_SIZE_N": block_n,
+ "BLOCK_SIZE_K": block_k,
+ "GROUP_SIZE_M": group_size,
+ "num_warps": num_warps,
+ "num_stages": num_stages,
+ }
+ )
return configs
@@ -165,15 +169,9 @@ def get_weight_shapes(tp_size):
return weight_shapes
-def benchmark_config(A,
- B,
- As,
- Bs,
- block_size,
- config,
- out_dtype=torch.float16,
- num_iters=10):
-
+def benchmark_config(
+ A, B, As, Bs, block_size, config, out_dtype=torch.float16, num_iters=10
+):
def run():
w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype)
@@ -206,26 +204,26 @@ def tune(M, N, K, block_size, out_dtype, search_space, input_type):
fp8_max, fp8_min = fp8_info.max, fp8_info.min
A_fp32 = (
- (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 *
- fp8_max)
+ (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
+ )
A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
B_fp32 = (
- (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 *
- fp8_max)
+ (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
+ )
B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
else:
- raise RuntimeError(
- "Currently, only support tune w8a8 block fp8 kernel.")
+ raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.")
block_n, block_k = block_size[0], block_size[1]
n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k
- As = torch.rand(M, k_tiles, dtype=torch.float32,
- device="cuda") * factor_for_scale
- Bs = (torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda") *
- factor_for_scale)
+ As = torch.rand(M, k_tiles, dtype=torch.float32, device="cuda") * factor_for_scale
+ Bs = (
+ torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda")
+ * factor_for_scale
+ )
best_config = None
best_time = float("inf")
@@ -267,7 +265,8 @@ def save_configs(
device_name = current_platform.get_device_name().replace(" ", "_")
json_file_name = (
f"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8,"
- f"block_shape=[{block_n},{block_k}].json")
+ f"block_shape=[{block_n},{block_k}].json"
+ )
config_file_path = os.path.join(save_path, json_file_name)
print(f"Writing best config to {config_file_path}...")
@@ -295,8 +294,7 @@ def tune_on_gpu(args_dict):
search_space = get_configs_compute_bound()
search_space = [
- config for config in search_space
- if block_k % config["BLOCK_SIZE_K"] == 0
+ config for config in search_space if block_k % config["BLOCK_SIZE_K"] == 0
]
start = time.time()
@@ -312,15 +310,11 @@ def tune_on_gpu(args_dict):
out_dtype,
search_space,
input_type,
- ) for batch_size in tqdm(batch_sizes,
- desc=f"GPU {gpu_id} - Batch sizes")
+ )
+ for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes")
]
- best_configs = {
- M: config
- for M, config in zip(batch_sizes, benchmark_results)
- }
- save_configs(N, K, block_n, block_k, best_configs, save_path,
- input_type)
+ best_configs = {M: config for M, config in zip(batch_sizes, benchmark_results)}
+ save_configs(N, K, block_n, block_k, best_configs, save_path, input_type)
end = time.time()
print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds")
@@ -376,13 +370,14 @@ def main(args):
process_args = []
for gpu_id in range(num_gpus):
- process_args.append({
- "gpu_id": gpu_id,
- "batch_sizes": batches_per_gpu[gpu_id],
- "weight_shapes":
- weight_shapes, # Each GPU processes all weight shapes
- "args": args,
- })
+ process_args.append(
+ {
+ "gpu_id": gpu_id,
+ "batch_sizes": batches_per_gpu[gpu_id],
+ "weight_shapes": weight_shapes, # Each GPU processes all weight shapes
+ "args": args,
+ }
+ )
ctx = mp.get_context("spawn")
with ctx.Pool(num_gpus) as pool:
@@ -398,13 +393,11 @@ def main(args):
python3 benchmark_w8a8_block_fp8.py --tp-size 8 --input-type fp8
Then copy to model_executor/layers/quantization/utils/configs
""",
- formatter_class=argparse.RawTextHelpFormatter)
+ formatter_class=argparse.RawTextHelpFormatter,
+ )
parser.add_argument("--tp-size", "-tp", type=int, default=8)
- parser.add_argument("--input-type",
- type=str,
- choices=["fp8"],
- default="fp8")
+ parser.add_argument("--input-type", type=str, choices=["fp8"], default="fp8")
parser.add_argument(
"--out-dtype",
type=str,
diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
index 5fa55bb974e1..e37764825451 100644
--- a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
+++ b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
@@ -11,7 +11,9 @@
# Import vLLM functions
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
- per_token_group_quant_fp8, w8a8_block_fp8_matmul)
+ per_token_group_quant_fp8,
+ w8a8_block_fp8_matmul,
+)
from vllm.triton_utils import triton
diff --git a/benchmarks/kernels/graph_machete_bench.py b/benchmarks/kernels/graph_machete_bench.py
index bd62173a7b3a..0c86e4072957 100644
--- a/benchmarks/kernels/graph_machete_bench.py
+++ b/benchmarks/kernels/graph_machete_bench.py
@@ -2,11 +2,11 @@
import math
import pickle
-import re
from collections import defaultdict
import matplotlib.pyplot as plt
import pandas as pd
+import regex as re
import seaborn as sns
from torch.utils.benchmark import Measurement as TMeasurement
@@ -14,13 +14,14 @@
if __name__ == "__main__":
parser = FlexibleArgumentParser(
- description='Benchmark the latency of processing a single batch of '
- 'requests till completion.')
- parser.add_argument('filename', type=str)
+ description="Benchmark the latency of processing a single batch of "
+ "requests till completion."
+ )
+ parser.add_argument("filename", type=str)
args = parser.parse_args()
- with open(args.filename, 'rb') as f:
+ with open(args.filename, "rb") as f:
data = pickle.load(f)
raw_results: list[TMeasurement] = data["results"]
@@ -38,11 +39,7 @@
raise Exception("MKN not found")
kernel = v.task_spec.description
- results[KN].append({
- "kernel": kernel,
- "batch_size": M,
- "median": v.median
- })
+ results[KN].append({"kernel": kernel, "batch_size": M, "median": v.median})
rows = int(math.ceil(len(results) / 2))
fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows))
@@ -50,14 +47,16 @@
for axs_idx, (shape, data) in enumerate(results.items()):
plt.sca(axs[axs_idx])
df = pd.DataFrame(data)
- sns.lineplot(data=df,
- x="batch_size",
- y="median",
- hue="kernel",
- style="kernel",
- markers=True,
- dashes=False,
- palette="Dark2")
+ sns.lineplot(
+ data=df,
+ x="batch_size",
+ y="median",
+ hue="kernel",
+ style="kernel",
+ markers=True,
+ dashes=False,
+ palette="Dark2",
+ )
plt.title(f"Shape: {shape}")
plt.ylabel("time (median, s)")
plt.tight_layout()
diff --git a/benchmarks/kernels/utils.py b/benchmarks/kernels/utils.py
index ac64f786f184..877a29feed9d 100644
--- a/benchmarks/kernels/utils.py
+++ b/benchmarks/kernels/utils.py
@@ -23,6 +23,7 @@ class ArgPool:
For every invocation during a benchmarking run, it will choose a
different value from the list.
"""
+
values: Iterable[Any]
def __getitem__(self, index):
@@ -30,9 +31,7 @@ def __getitem__(self, index):
class Bench:
-
class ArgsIterator:
-
def __init__(self, args_list, kwargs_list):
assert len(args_list) == len(kwargs_list)
self.args_list = args_list
@@ -53,10 +52,16 @@ def reset(self):
def n_args(self):
return self.n
- def __init__(self, cuda_graph_params: Optional[CudaGraphBenchParams],
- label: str, sub_label: str, description: str, fn: Callable,
- *args, **kwargs):
-
+ def __init__(
+ self,
+ cuda_graph_params: Optional[CudaGraphBenchParams],
+ label: str,
+ sub_label: str,
+ description: str,
+ fn: Callable,
+ *args,
+ **kwargs,
+ ):
self.cuda_graph_params = cuda_graph_params
self.use_cuda_graph = self.cuda_graph_params is not None
self.label = label
@@ -67,10 +72,8 @@ def __init__(self, cuda_graph_params: Optional[CudaGraphBenchParams],
# Process args
self._args = args
self._kwargs = kwargs
- self.args_list, self.kwargs_list = self.collapse_argpool(
- *args, **kwargs)
- self.args_iterator = self.ArgsIterator(self.args_list,
- self.kwargs_list)
+ self.args_list, self.kwargs_list = self.collapse_argpool(*args, **kwargs)
+ self.args_iterator = self.ArgsIterator(self.args_list, self.kwargs_list)
# Cudagraph runner
self.g = None
@@ -100,16 +103,13 @@ def collapse_argpool(self, *args, **kwargs):
for i in range(argpool_size):
# collapse args; Just pick the ith value
- args_list[i] = tuple([
- arg[i] if isinstance(arg, ArgPool) else arg
- for arg in args_list[i]
- ])
+ args_list[i] = tuple(
+ [arg[i] if isinstance(arg, ArgPool) else arg for arg in args_list[i]]
+ )
# collapse kwargs
kwargs_i = kwargs_list[i]
- arg_pool_keys = [
- k for k, v in kwargs_i.items() if isinstance(v, ArgPool)
- ]
+ arg_pool_keys = [k for k, v in kwargs_i.items() if isinstance(v, ArgPool)]
for k in arg_pool_keys:
# again just pick the ith value
kwargs_i[k] = kwargs_i[k][i]
@@ -142,7 +142,7 @@ def get_cuda_graph_runner(self):
def run_cudagrah(self) -> TMeasurement:
assert self.use_cuda_graph
- globals = {'g': self.g}
+ globals = {"g": self.g}
return TBenchmark.Timer(
stmt="g.replay()",
@@ -162,15 +162,15 @@ def run_eager(self) -> TMeasurement:
has_arg_pool = self.args_iterator.n_args > 1
if has_arg_pool:
- setup = '''
+ setup = """
args_iterator.reset()
args_it = args_iterator.__next__()
- '''
- stmt = '''
+ """
+ stmt = """
args, kwargs = next(args_it)
fn(*args, **kwargs)
- '''
- globals = {'fn': self.fn, 'args_iterator': self.args_iterator}
+ """
+ globals = {"fn": self.fn, "args_iterator": self.args_iterator}
else:
# no arg pool. Just use the args and kwargs directly
self.args_iterator.reset()
@@ -178,10 +178,10 @@ def run_eager(self) -> TMeasurement:
args, kwargs = next(args_it)
setup = ""
- stmt = '''
+ stmt = """
fn(*args, **kwargs)
- '''
- globals = {'fn': self.fn, 'args': args, 'kwargs': kwargs}
+ """
+ globals = {"fn": self.fn, "args": args, "kwargs": kwargs}
return TBenchmark.Timer(
stmt=stmt,
diff --git a/benchmarks/overheads/benchmark_hashing.py b/benchmarks/overheads/benchmark_hashing.py
index 5f94552e9dc8..d5701a8fbd6d 100644
--- a/benchmarks/overheads/benchmark_hashing.py
+++ b/benchmarks/overheads/benchmark_hashing.py
@@ -7,9 +7,8 @@
from vllm.utils import FlexibleArgumentParser
# A very long prompt, total number of tokens is about 15k.
-LONG_PROMPT = ["You are an expert in large language models, aren't you?"
- ] * 1000
-LONG_PROMPT = ' '.join(LONG_PROMPT)
+LONG_PROMPT = ["You are an expert in large language models, aren't you?"] * 1000
+LONG_PROMPT = " ".join(LONG_PROMPT)
def main(args):
@@ -30,32 +29,35 @@ def main(args):
print("------start generating------")
for i in range(3):
- profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)',
- globals(), locals())
+ profiler.runctx(
+ "llm.generate(LONG_PROMPT, sampling_params)", globals(), locals()
+ )
# analyze the runtime of hashing function
stats = pstats.Stats(profiler)
- stats.sort_stats('cumulative')
+ stats.sort_stats("cumulative")
total_time = 0
total_calls = 0
for func in stats.stats:
- if 'hash_of_block' in func[2]:
+ if "hash_of_block" in func[2]:
total_time = stats.stats[func][3]
total_calls = stats.stats[func][0]
percentage = (total_time / stats.total_tt) * 100
- print(f"Hashing took {total_time:.2f} seconds,"
- f"{percentage:.2f}% of the total runtime.")
+ print(
+ f"Hashing took {total_time:.2f} seconds,{percentage:.2f}% of the total runtime."
+ )
if __name__ == "__main__":
parser = FlexibleArgumentParser(
- description='Benchmark the performance of hashing function in'
- 'automatic prefix caching.')
- parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k')
- parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
- parser.add_argument('--output-len', type=int, default=10)
- parser.add_argument('--enable-prefix-caching',
- action='store_true',
- help='enable prefix caching')
+ description="Benchmark the performance of hashing function in"
+ "automatic prefix caching."
+ )
+ parser.add_argument("--model", type=str, default="lmsys/longchat-7b-16k")
+ parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
+ parser.add_argument("--output-len", type=int, default=10)
+ parser.add_argument(
+ "--enable-prefix-caching", action="store_true", help="enable prefix caching"
+ )
args = parser.parse_args()
main(args)
diff --git a/benchmarks/pyproject.toml b/benchmarks/pyproject.toml
new file mode 100644
index 000000000000..65b1e09a247e
--- /dev/null
+++ b/benchmarks/pyproject.toml
@@ -0,0 +1,49 @@
+# This local pyproject file is part of the migration from yapf to ruff format.
+# It uses the same core rules as the main pyproject.toml file, but with the
+# following differences:
+# - ruff line length is overridden to 88
+# - deprecated typing ignores (UP006, UP035) have been removed
+
+[tool.ruff]
+line-length = 88
+
+[tool.ruff.lint.per-file-ignores]
+"vllm/third_party/**" = ["ALL"]
+"vllm/version.py" = ["F401"]
+"vllm/_version.py" = ["ALL"]
+
+[tool.ruff.lint]
+select = [
+ # pycodestyle
+ "E",
+ # Pyflakes
+ "F",
+ # pyupgrade
+ "UP",
+ # flake8-bugbear
+ "B",
+ # flake8-simplify
+ "SIM",
+ # isort
+ "I",
+ # flake8-logging-format
+ "G",
+]
+ignore = [
+ # star imports
+ "F405", "F403",
+ # lambda expression assignment
+ "E731",
+ # Loop control variable not used within loop body
+ "B007",
+ # f-string format
+ "UP032",
+ # Can remove once 3.10+ is the minimum Python version
+ "UP007",
+]
+
+[tool.ruff.lint.isort]
+known-first-party = ["vllm"]
+
+[tool.ruff.format]
+docstring-code-format = true
\ No newline at end of file
diff --git a/benchmarks/run_structured_output_benchmark.sh b/benchmarks/run_structured_output_benchmark.sh
index 53dc7ed70b9c..b043ab83e460 100755
--- a/benchmarks/run_structured_output_benchmark.sh
+++ b/benchmarks/run_structured_output_benchmark.sh
@@ -1,32 +1,98 @@
#!/bin/bash
-# Define the model to use
-MODEL=${1:-"Qwen/Qwen2.5-7B-Instruct"}
-
-# Define the backend to use
-BACKEND=${2:-"vllm"}
-
-# Define the dataset to use
-DATASET=${3:-"xgrammar_bench"}
-
+# default values
+MODEL=${MODEL:-"Qwen/Qwen2.5-7B-Instruct"}
+BACKEND=${BACKEND:-"vllm"}
+DATASET=${DATASET:-"xgrammar_bench"}
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
-OUTPUT_DIR=${4:-"$SCRIPT_DIR/structured_output_benchmark_results"}
+OUTPUT_DIR=${OUTPUT_DIR:-"$SCRIPT_DIR/structured_output_benchmark_results"}
+PORT=${PORT:-8000}
+STRUCTURED_OUTPUT_RATIO=${STRUCTURED_OUTPUT_RATIO:-1}
+TOTAL_SECONDS=${TOTAL_SECONDS:-90}
+MAX_NEW_TOKENS=${MAX_NEW_TOKENS:-300}
+TOKENIZER_MODE=${TOKENIZER_MODE:-"auto"}
-GUIDED_RATIO=${5:-0.5}
+usage() {
+    # Print the option summary, then exit with the status given as $1
+    # (0 when omitted, so `-h`/`--help` still reports success).
+    echo "Usage: $0 [options]"
+    echo "Options:"
+    echo " --model MODEL Model to benchmark (default: $MODEL)"
+    echo " --backend BACKEND Backend to use (default: $BACKEND)"
+    echo " --dataset DATASET Dataset to use (default: $DATASET)"
+    echo " --max-new-tokens N Maximum number of tokens to generate (default: $MAX_NEW_TOKENS)"
+    echo " --output-dir DIR Output directory for results (default: $OUTPUT_DIR)"
+    echo " --port PORT Port to use (default: $PORT)"
+    echo " --structured-output-ratio N Ratio of structured outputs (default: $STRUCTURED_OUTPUT_RATIO)"
+    echo " --tokenizer-mode MODE Tokenizer mode to use (default: $TOKENIZER_MODE)"
+    echo " --total-seconds N Total seconds to run the benchmark (default: $TOTAL_SECONDS)"
+    echo " -h, --help Show this help message and exit"
+    exit "${1:-0}"
+}
+
+# Abort with the usage text unless the option in $1 is followed by a value.
+# Call as `require_value "$@"`. Without this guard a trailing option such as
+# `--model` makes `shift 2` fail without consuming anything, so the parsing
+# loop below would never terminate.
+require_value() {
+    if [[ $# -lt 2 ]]; then
+        echo "Missing value for option: $1" >&2
+        usage 1 >&2
+    fi
+}
+
+# parse command line arguments
+while [[ $# -gt 0 ]]; do
+    case $1 in
+        --model)
+            require_value "$@"
+            MODEL="$2"
+            shift 2
+            ;;
+        --backend)
+            require_value "$@"
+            BACKEND="$2"
+            shift 2
+            ;;
+        --dataset)
+            require_value "$@"
+            DATASET="$2"
+            shift 2
+            ;;
+        --max-new-tokens)
+            require_value "$@"
+            MAX_NEW_TOKENS="$2"
+            shift 2
+            ;;
+        --output-dir)
+            require_value "$@"
+            OUTPUT_DIR="$2"
+            shift 2
+            ;;
+        --port)
+            require_value "$@"
+            PORT="$2"
+            shift 2
+            ;;
+        --structured-output-ratio)
+            require_value "$@"
+            STRUCTURED_OUTPUT_RATIO="$2"
+            shift 2
+            ;;
+        --tokenizer-mode)
+            require_value "$@"
+            TOKENIZER_MODE="$2"
+            shift 2
+            ;;
+        --total-seconds)
+            require_value "$@"
+            TOTAL_SECONDS="$2"
+            shift 2
+            ;;
+        -h|--help)
+            usage
+            ;;
+        *)
+            # plain `echo` does not expand "\n", so print the blank line with
+            # printf; report the error on stderr and exit non-zero
+            printf 'Unknown argument: %s\n\n' "$1" >&2
+            usage 1 >&2
+            ;;
+    esac
+done
# Create output directory if it doesn't exist
mkdir -p "$OUTPUT_DIR"
# Define QPS values to test
-QPS_VALUES=(70 60 50 25 20 15 10)
+QPS_VALUES=(25 20 15 10 5 1)
# Common parameters
COMMON_PARAMS="--backend $BACKEND \
--model $MODEL \
--dataset $DATASET \
- --structured-output-ratio $GUIDED_RATIO \
+ --structured-output-ratio $STRUCTURED_OUTPUT_RATIO \
--save-results \
- --result-dir $OUTPUT_DIR"
+ --result-dir $OUTPUT_DIR \
+ --output-len $MAX_NEW_TOKENS \
+ --port $PORT \
+ --tokenizer-mode $TOKENIZER_MODE"
echo "Starting structured output benchmark with model: $MODEL"
echo "Backend: $BACKEND"
@@ -45,12 +111,15 @@ for qps in "${QPS_VALUES[@]}"; do
# Construct filename for this run
FILENAME="${BACKEND}_${qps}qps_$(basename $MODEL)_${DATASET}_${GIT_HASH}.json"
+ NUM_PROMPTS=$(echo "$TOTAL_SECONDS * $qps" | bc)
+ NUM_PROMPTS=${NUM_PROMPTS%.*} # Remove fractional part
+ echo "Running benchmark with $NUM_PROMPTS prompts"
+
# Run the benchmark
python "$SCRIPT_DIR/benchmark_serving_structured_output.py" $COMMON_PARAMS \
--request-rate $qps \
--result-filename "$FILENAME" \
- --tokenizer-mode ${TOKENIZER_MODE:-"auto"} \
- --port ${PORT:-8000}
+ --num-prompts $NUM_PROMPTS
echo "Completed benchmark with QPS: $qps"
echo "----------------------------------------"
diff --git a/cmake/utils.cmake b/cmake/utils.cmake
index c9cd099b82a7..12e4e39024f5 100644
--- a/cmake/utils.cmake
+++ b/cmake/utils.cmake
@@ -228,11 +228,26 @@ macro(set_gencode_flags_for_srcs)
"${multiValueArgs}" ${ARGN} )
foreach(_ARCH ${arg_CUDA_ARCHS})
- string(REPLACE "." "" _ARCH "${_ARCH}")
- set_gencode_flag_for_srcs(
- SRCS ${arg_SRCS}
- ARCH "compute_${_ARCH}"
- CODE "sm_${_ARCH}")
+ # handle +PTX suffix: generate both sm and ptx codes if requested
+ string(FIND "${_ARCH}" "+PTX" _HAS_PTX)
+ if(NOT _HAS_PTX EQUAL -1)
+ string(REPLACE "+PTX" "" _BASE_ARCH "${_ARCH}")
+ string(REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH}")
+ set_gencode_flag_for_srcs(
+ SRCS ${arg_SRCS}
+ ARCH "compute_${_STRIPPED_ARCH}"
+ CODE "sm_${_STRIPPED_ARCH}")
+ set_gencode_flag_for_srcs(
+ SRCS ${arg_SRCS}
+ ARCH "compute_${_STRIPPED_ARCH}"
+ CODE "compute_${_STRIPPED_ARCH}")
+ else()
+ string(REPLACE "." "" _STRIPPED_ARCH "${_ARCH}")
+ set_gencode_flag_for_srcs(
+ SRCS ${arg_SRCS}
+ ARCH "compute_${_STRIPPED_ARCH}"
+ CODE "sm_${_STRIPPED_ARCH}")
+ endif()
endforeach()
if (${arg_BUILD_PTX_FOR_ARCH})
@@ -251,7 +266,10 @@ endmacro()
#
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
# `<major>.<minor>[letter]` compute the "loose intersection" with the
-# `TGT_CUDA_ARCHS` list of gencodes.
+# `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
+# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
+# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
+# architecture in `SRC_CUDA_ARCHS`.
# The loose intersection is defined as:
# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
# where `<=` is the version comparison operator.
@@ -268,44 +286,63 @@ endmacro()
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
#
+# Example With PTX:
+# SRC_CUDA_ARCHS="8.0+PTX"
+# TGT_CUDA_ARCHS="9.0"
+# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
+# OUT_CUDA_ARCHS="8.0+PTX"
+#
function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
- list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)
- set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS})
+ set(_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS}")
+ set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS})
+
+ # handle +PTX suffix: separate base arch for matching, record PTX requests
+ set(_PTX_ARCHS)
+ foreach(_arch ${_SRC_CUDA_ARCHS})
+ if(_arch MATCHES "\\+PTX$")
+ string(REPLACE "+PTX" "" _base "${_arch}")
+ list(APPEND _PTX_ARCHS "${_base}")
+ list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
+ list(APPEND _SRC_CUDA_ARCHS "${_base}")
+ endif()
+ endforeach()
+ list(REMOVE_DUPLICATES _PTX_ARCHS)
+ list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS)
# if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
# remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
set(_CUDA_ARCHS)
- if ("9.0a" IN_LIST SRC_CUDA_ARCHS)
- list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a")
- if ("9.0" IN_LIST TGT_CUDA_ARCHS_)
- list(REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0")
+ if ("9.0a" IN_LIST _SRC_CUDA_ARCHS)
+ list(REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a")
+ if ("9.0" IN_LIST TGT_CUDA_ARCHS)
+ list(REMOVE_ITEM _TGT_CUDA_ARCHS "9.0")
set(_CUDA_ARCHS "9.0a")
endif()
endif()
- if ("10.0a" IN_LIST SRC_CUDA_ARCHS)
- list(REMOVE_ITEM SRC_CUDA_ARCHS "10.0a")
+ if ("10.0a" IN_LIST _SRC_CUDA_ARCHS)
+ list(REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a")
if ("10.0" IN_LIST TGT_CUDA_ARCHS)
- list(REMOVE_ITEM TGT_CUDA_ARCHS_ "10.0")
+ list(REMOVE_ITEM _TGT_CUDA_ARCHS "10.0")
- set(_CUDA_ARCHS "10.0a")
+ list(APPEND _CUDA_ARCHS "10.0a") # append, not set(): keep a "9.0a" recorded by the branch above
endif()
endif()
- list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
+ list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
# for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
# is less or equal to ARCH (but has the same major version since SASS binary
# compatibility is only forward compatible within the same major version).
- foreach(_ARCH ${TGT_CUDA_ARCHS_})
+ foreach(_ARCH ${_TGT_CUDA_ARCHS})
set(_TMP_ARCH)
# Extract the major version of the target arch
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
- foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
+ foreach(_SRC_ARCH ${_SRC_CUDA_ARCHS})
# Extract the major version of the source arch
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
- # Check major-version match AND version-less-or-equal
+ # Check version-less-or-equal, and allow PTX arches to match across majors
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
- if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
+ if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
set(_TMP_ARCH "${_SRC_ARCH}")
endif()
else()
@@ -321,6 +358,18 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
endforeach()
list(REMOVE_DUPLICATES _CUDA_ARCHS)
+
+ # reapply +PTX suffix to architectures that requested PTX
+ set(_FINAL_ARCHS)
+ foreach(_arch ${_CUDA_ARCHS})
+ if(_arch IN_LIST _PTX_ARCHS)
+ list(APPEND _FINAL_ARCHS "${_arch}+PTX")
+ else()
+ list(APPEND _FINAL_ARCHS "${_arch}")
+ endif()
+ endforeach()
+ set(_CUDA_ARCHS ${_FINAL_ARCHS})
+
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
endfunction()
diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu
index 88275dbdd83a..55e659679701 100644
--- a/csrc/activation_kernels.cu
+++ b/csrc/activation_kernels.cu
@@ -70,6 +70,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
+ if (num_tokens == 0) { \
+ return; \
+ } \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh
index eb216dc8baf1..79a546554fa1 100644
--- a/csrc/attention/attention_kernels.cuh
+++ b/csrc/attention/attention_kernels.cuh
@@ -172,7 +172,7 @@ __device__ void paged_attention_kernel(
// Load the query to registers.
// Each thread in a thread group has a different part of the query.
- // For example, if the the thread group size is 4, then the first thread in
+ // For example, if the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the query, and the second thread
// has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
// q is split from a qkv tensor, it may not be contiguous.
@@ -259,7 +259,7 @@ __device__ void paged_attention_kernel(
// Load a key to registers.
// Each thread in a thread group has a different part of the key.
- // For example, if the the thread group size is 4, then the first thread in
+ // For example, if the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the key, and the second thread
// has 1, 5, 9, ... th vectors of the key, and so on.
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
diff --git a/csrc/attention/vertical_slash_index.cu b/csrc/attention/vertical_slash_index.cu
new file mode 100644
index 000000000000..c1b45b143f4e
--- /dev/null
+++ b/csrc/attention/vertical_slash_index.cu
@@ -0,0 +1,401 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+#include
+
+#include
+
+#include
+
+// Appends the start offsets of the blocks covering [range_start, range_end)
+// to `block_offset`, one every `block_size` positions, beginning at slot
+// `input_block_count`. The range is clipped to `kv_seqlen`; a range starting
+// at or beyond `kv_seqlen` adds nothing. Returns the updated block count.
+__device__ int64_t save_blocks(int* block_offset, int64_t range_start,
+                               int64_t range_end, int64_t block_size,
+                               int64_t input_block_count, int64_t kv_seqlen) {
+  if (range_start >= kv_seqlen) {
+    return input_block_count;
+  }
+  if (range_end > kv_seqlen) {
+    range_end = kv_seqlen;
+  }
+  int64_t current_block_count = input_block_count;
+  // Iterate in int64_t like the bounds: an `int` index would truncate
+  // `range_start` and could overflow on `idx += block_size`.
+  for (int64_t idx = range_start; idx < range_end; idx += block_size) {
+    // the output buffer holds 32-bit offsets, so narrow explicitly on store
+    block_offset[current_block_count++] = static_cast<int>(idx);
+  }
+  return current_block_count;
+}
+
+// Converts the vertical and slash indexes of one (batch, head) pair into, per
+// row-block of queries, a list of block start offsets (block_offset /
+// block_count) and a list of standalone column indexes (column_index /
+// column_count). Each thread handles one row-block: blockIdx.x selects the
+// head, blockIdx.y the batch entry, and blockIdx.z * blockDim.x + threadIdx.x
+// the row-block.
+__global__ void convert_vertical_slash_indexes_kernel(
+ const int* q_seqlens, // [BATCH, ]
+ const int* kv_seqlens, // [BATCH, ]
+ const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
+ const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
+ int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
+ int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
+ int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
+ int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
+ int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
+ int64_t NNZ_V, int64_t NNZ_S,
+ bool causal // True for intra, False for succ
+) {
+ const int batch_idx = blockIdx.y;
+ const int head_idx = blockIdx.x;
+ const int group_idx = blockIdx.z;
+
+ int64_t q_seqlen = q_seqlens[batch_idx];
+ int64_t kv_seqlen = kv_seqlens[batch_idx];
+ int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
+ int64_t start_m = block_idx_m * BLOCK_SIZE_M;
+ // row-blocks that start at or past the query length produce no output
+ if (start_m >= q_seqlen) {
+ return;
+ }
+ int64_t end_m = start_m + BLOCK_SIZE_M;
+ // advance every pointer to this thread's (batch, head) slice and, for the
+ // outputs, to this thread's row-block
+ vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
+ slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
+ int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
+ block_count += row_offset;
+ block_offset += row_offset * NNZ_S;
+ column_count += row_offset;
+ column_index += row_offset * NNZ_V;
+
+ bool has_slash = true;
+ int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
+ int64_t s = 0, v = 0;
+ // NOTE(review): the first vertical and slash entries are read without
+ // checking NNZ_V / NNZ_S, so both are presumably required to be >= 1 --
+ // confirm with callers.
+ int64_t v_idx = vertical_indexes[v++];
+ int64_t s_idx = slash_indexes[s++];
+ // Skip slash entries that lie at or beyond the end of this row-block, then
+ // turn the first remaining one into a range end that is never smaller than
+ // BLOCK_SIZE_M. The causal case offsets by (kv_seqlen - q_seqlen), the
+ // non-causal case by kv_seqlen.
+ if (causal) {
+ while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
+ s_idx = slash_indexes[s++];
+ }
+ if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
+ s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
+ } else {
+ while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
+ s_idx = slash_indexes[s++];
+ }
+ if (s_idx > end_m + kv_seqlen) has_slash = false;
+ s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
+ }
+
+ // the initial range is BLOCK_SIZE_M wide and ends at s_idx
+ int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
+ // no usable slash entry: replace the initial range with one that starts
+ // just past the row-block (causal) or at kv_seqlen (non-causal)
+ if (!has_slash) {
+ if (causal) {
+ range_start = (kv_seqlen - q_seqlen) + end_m;
+ range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
+ } else {
+ range_start = kv_seqlen;
+ range_end = kv_seqlen + BLOCK_SIZE_N;
+ }
+ }
+
+ // Merge loop: consume vertical entries while they fall below range_end;
+ // otherwise advance to the next slash entry and extend or flush the range.
+ bool slash_finished = false;
+ while (1) {
+ if (v_idx < range_end) {
+ // a vertical entry before the current range is recorded as a column;
+ // one inside the range is not recorded separately
+ if (v_idx < range_start) {
+ column_index[tmp_col_cnt++] = v_idx;
+ }
+ if (v < NNZ_V) {
+ v_idx = vertical_indexes[v++];
+ } else {
+ // no vertical entries left: move v_idx to end_m + BLOCK_SIZE_N plus the
+ // same offset used for the slash entries
+ if (causal)
+ v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
+ else
+ v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
+ }
+ } else {
+ if ((s < NNZ_S && causal) ||
+ (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
+ if (causal)
+ s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++],
+ BLOCK_SIZE_M);
+ else
+ s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
+ } else {
+ if (v == NNZ_V || (v_idx > range_start && causal)) {
+ // add the last vertical if no more slash
+ if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
+ column_index[tmp_col_cnt++] = v_idx;
+ }
+ tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
+ BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
+ break;
+ } else {
+ if (causal) {
+ range_start = (kv_seqlen - q_seqlen) + end_m;
+ range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
+ } else {
+ // if slash_finished but there are vertical left, save current
+ // blocks
+ tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
+ BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
+ range_start = kv_seqlen;
+ range_end = kv_seqlen + BLOCK_SIZE_N;
+ }
+ slash_finished = true;
+ }
+ }
+ // a slash range more than BLOCK_SIZE_M past the current range flushes it
+ // and starts a new one; a nearer one extends the current range by
+ // BLOCK_SIZE_M
+ if (!slash_finished) {
+ if (s_idx > range_end + BLOCK_SIZE_M) {
+ tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
+ BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
+ range_start = s_idx - BLOCK_SIZE_M;
+ range_end = s_idx;
+ } else if (s_idx > range_end) {
+ range_end += BLOCK_SIZE_M;
+ }
+ }
+ }
+ }
+
+ // publish the counts for this row-block
+ block_count[0] = tmp_blk_cnt;
+ column_count[0] = tmp_col_cnt;
+}
+
+// Host-side launcher for convert_vertical_slash_indexes_kernel. Uses 64
+// threads per block with one thread per row-block: grid.x = N_HEADS,
+// grid.y = BATCH_SIZE, grid.z = ceil(N_ROWS / 64). All pointers and sizes are
+// forwarded to the kernel unchanged.
+void convert_vertical_slash_indexes_64x64(
+    const int* q_seqlens,         // [BATCH, ]
+    const int* kv_seqlens,        // [BATCH, ]
+    const int* vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
+    const int* slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
+    int* block_count,   // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
+    int* block_offset,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
+    int* column_count,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
+    int* column_index,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
+    int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M,
+    int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) {
+  const int N_THREADS = 64;
+  const dim3 dimBlock(N_THREADS);
+  const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
+  // a kernel launch must name its grid and block dimensions
+  convert_vertical_slash_indexes_kernel<<<dimGrid, dimBlock>>>(
+      q_seqlens, kv_seqlens, vertical_indexes, slash_indexes, block_count,
+      block_offset, column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M,
+      BLOCK_SIZE_N, NNZ_V, NNZ_S, causal);
+}
+
+/**
+ * Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490.
+ *
+ * This function builds the index of each row of blocks from vertical indices
+ * and slash indices. The vertical indices are treated as points, while the
+ * slash indices are converted as ranges. The output consists of the merged
+ * ranges and separate column indices, where the ranges are represented by
+ * block indices.
+ *
+ * The batch size, head count and NNZ sizes are taken from the shapes of
+ * `slash_indexes` and `vertical_indexes`; the number of rows is
+ * ceil(context_size / block_size_M). The four output tensors are written in
+ * place by the kernel.
+ *
+ * The implementation is referenced from the original MInference repo:
+ * https://github.com/microsoft/MInference/blob/main/csrc/vertical_slash_index.cu.
+ */
+void convert_vertical_slash_indexes(
+    torch::Tensor& block_count,      // [BATCH, N_HEADS, NUM_ROWS]
+    torch::Tensor& block_offset,     // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
+    torch::Tensor& column_count,     // [BATCH, N_HEADS, NUM_ROWS]
+    torch::Tensor& column_index,     // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
+    torch::Tensor q_seqlens,         // [BATCH, ]
+    torch::Tensor kv_seqlens,        // [BATCH, ]
+    torch::Tensor vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
+    torch::Tensor slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
+    int64_t context_size, int64_t block_size_M, int64_t block_size_N,
+    bool causal) {
+  cudaSetDevice(q_seqlens.get_device());
+
+  int batch_size = slash_indexes.size(0);
+  int num_heads = slash_indexes.size(1);
+  int nnz_slash = slash_indexes.size(2);
+  int nnz_vertical = vertical_indexes.size(2);
+  int num_rows = (context_size + block_size_M - 1) / block_size_M;
+
+  // The launcher takes `int*` buffers. The untyped data_ptr() yields `void*`,
+  // which does not convert implicitly, so request typed pointers.
+  convert_vertical_slash_indexes_64x64(
+      q_seqlens.data_ptr<int>(), kv_seqlens.data_ptr<int>(),
+      vertical_indexes.data_ptr<int>(), slash_indexes.data_ptr<int>(),
+      block_count.data_ptr<int>(), block_offset.data_ptr<int>(),
+      column_count.data_ptr<int>(), column_index.data_ptr<int>(), batch_size,
+      num_heads, num_rows, block_size_M, block_size_N, nnz_vertical, nnz_slash,
+      causal);
+}
+
+// Variant of convert_vertical_slash_indexes_kernel in which the number of
+// vertical / slash entries actually used is read per head from
+// per_head_vertical_topkv / per_head_slash_topkv, while the NNZ_V / NNZ_S
+// arguments are used only to compute the pointer offsets. Each thread handles
+// one row-block: blockIdx.x selects the head, blockIdx.y the batch entry, and
+// blockIdx.z * blockDim.x + threadIdx.x the row-block.
+__global__ void convert_vertical_slash_indexes_kernel_mergehead(
+ const int* q_seqlens, // [BATCH, ]
+ const int* kv_seqlens, // [BATCH, ]
+ const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
+ const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
+ const int* per_head_vertical_topkv, const int* per_head_slash_topkv,
+ int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
+ int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
+ int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
+ int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
+ int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
+ int64_t NNZ_V, int64_t NNZ_S,
+ bool causal // True for intra, False for succ
+) {
+ const int batch_idx = blockIdx.y;
+ const int head_idx = blockIdx.x;
+ const int group_idx = blockIdx.z;
+
+ int64_t q_seqlen = q_seqlens[batch_idx];
+ int64_t kv_seqlen = kv_seqlens[batch_idx];
+ int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
+ int64_t start_m = block_idx_m * BLOCK_SIZE_M;
+ // row-blocks that start at or past the query length produce no output
+ if (start_m >= q_seqlen) {
+ return;
+ }
+ int64_t end_m = start_m + BLOCK_SIZE_M;
+ // advance every pointer to this thread's (batch, head) slice and, for the
+ // outputs, to this thread's row-block
+ vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
+ slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
+ int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
+ block_count += row_offset;
+ block_offset += row_offset * NNZ_S;
+ column_count += row_offset;
+ column_index += row_offset * NNZ_V;
+
+ // MergeHead: each head has its own max topk NNZ_V, NNZ_S. (The NNZ_V, NNZ_S
+ // used above are buffer sizes, needed only to compute the offsets.)
+ NNZ_S = per_head_slash_topkv[head_idx];
+ NNZ_V = per_head_vertical_topkv[head_idx];
+
+ bool has_slash = true;
+ int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
+ int64_t s = 0, v = 0;
+ // NOTE(review): the first vertical and slash entries are read without
+ // checking the per-head NNZ_V / NNZ_S, so both are presumably required to
+ // be >= 1 -- confirm with callers.
+ int64_t v_idx = vertical_indexes[v++];
+ int64_t s_idx = slash_indexes[s++];
+ // Skip slash entries that lie at or beyond the end of this row-block, then
+ // turn the first remaining one into a range end that is never smaller than
+ // BLOCK_SIZE_M. The causal case offsets by (kv_seqlen - q_seqlen), the
+ // non-causal case by kv_seqlen.
+ if (causal) {
+ while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
+ s_idx = slash_indexes[s++];
+ }
+ if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
+ s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
+ } else {
+ while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
+ s_idx = slash_indexes[s++];
+ }
+ if (s_idx > end_m + kv_seqlen) has_slash = false;
+ s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
+ }
+
+ // the initial range is BLOCK_SIZE_M wide and ends at s_idx
+ int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
+ // no usable slash entry: replace the initial range with one that starts
+ // just past the row-block (causal) or at kv_seqlen (non-causal)
+ if (!has_slash) {
+ if (causal) {
+ range_start = (kv_seqlen - q_seqlen) + end_m;
+ range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
+ } else {
+ range_start = kv_seqlen;
+ range_end = kv_seqlen + BLOCK_SIZE_N;
+ }
+ }
+
+ // Merge loop: consume vertical entries while they fall below range_end;
+ // otherwise advance to the next slash entry and extend or flush the range.
+ bool slash_finished = false;
+ while (1) {
+ if (v_idx < range_end) {
+ // a vertical entry before the current range is recorded as a column;
+ // one inside the range is not recorded separately
+ if (v_idx < range_start) {
+ column_index[tmp_col_cnt++] = v_idx;
+ }
+ if (v < NNZ_V) {
+ v_idx = vertical_indexes[v++];
+ } else {
+ // no vertical entries left: move v_idx to end_m + BLOCK_SIZE_N plus the
+ // same offset used for the slash entries
+ if (causal)
+ v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
+ else
+ v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
+ }
+ } else {
+ if ((s < NNZ_S && causal) ||
+ (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
+ if (causal)
+ s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++],
+ BLOCK_SIZE_M);
+ else
+ s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
+ } else {
+ if (v == NNZ_V || (v_idx > range_start && causal)) {
+ // add the last vertical if no more slash
+ if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
+ column_index[tmp_col_cnt++] = v_idx;
+ }
+ tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
+ BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
+ break;
+ } else {
+ if (causal) {
+ range_start = (kv_seqlen - q_seqlen) + end_m;
+ range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
+ } else {
+ // if slash_finished but there are vertical left, save current
+ // blocks
+ tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
+ BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
+ range_start = kv_seqlen;
+ range_end = kv_seqlen + BLOCK_SIZE_N;
+ }
+ slash_finished = true;
+ }
+ }
+ // a slash range more than BLOCK_SIZE_M past the current range flushes it
+ // and starts a new one; a nearer one extends the current range by
+ // BLOCK_SIZE_M
+ if (!slash_finished) {
+ if (s_idx > range_end + BLOCK_SIZE_M) {
+ tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
+ BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
+ range_start = s_idx - BLOCK_SIZE_M;
+ range_end = s_idx;
+ } else if (s_idx > range_end) {
+ range_end += BLOCK_SIZE_M;
+ }
+ }
+ }
+ }
+
+ // publish the counts for this row-block
+ block_count[0] = tmp_blk_cnt;
+ column_count[0] = tmp_col_cnt;
+}
+
+// Host-side launcher for convert_vertical_slash_indexes_kernel_mergehead.
+// Uses 64 threads per block with one thread per row-block: grid.x = N_HEADS,
+// grid.y = BATCH_SIZE, grid.z = ceil(N_ROWS / 64). All pointers and sizes,
+// including the two per-head count arrays, are forwarded unchanged.
+void convert_vertical_slash_indexes_64x64_mergehead(
+    const int* q_seqlens,         // [BATCH, ]
+    const int* kv_seqlens,        // [BATCH, ]
+    const int* vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
+    const int* slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
+    int* per_head_vertical_topkv, int* per_head_slash_topkv,
+    int* block_count,   // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
+    int* block_offset,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
+    int* column_count,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
+    int* column_index,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
+    int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M,
+    int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) {
+  const int N_THREADS = 64;
+  const dim3 dimBlock(N_THREADS);
+  const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
+  // a kernel launch must name its grid and block dimensions
+  convert_vertical_slash_indexes_kernel_mergehead<<<dimGrid, dimBlock>>>(
+      q_seqlens, kv_seqlens, vertical_indexes, slash_indexes,
+      per_head_vertical_topkv, per_head_slash_topkv, block_count, block_offset,
+      column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N,
+      NNZ_V, NNZ_S, causal);
+}
+
+/**
+ * Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490.
+ *
+ * Like the above convert_vertical_slash_indexes, but with
+ * pre-computed vertical and slash counts.
+ *
+ * `vertical_indices_count` and `slash_indices_count` hold one count per head
+ * and are forwarded to the kernel, which uses them in place of the NNZ_V /
+ * NNZ_S buffer sizes when walking the index lists. The four output tensors
+ * are written in place by the kernel.
+ */
+void convert_vertical_slash_indexes_mergehead(
+    torch::Tensor& block_count,            // [BATCH, N_HEADS, NUM_ROWS]
+    torch::Tensor& block_offset,           // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
+    torch::Tensor& column_count,           // [BATCH, N_HEADS, NUM_ROWS]
+    torch::Tensor& column_index,           // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
+    torch::Tensor q_seqlens,               // [BATCH, ]
+    torch::Tensor kv_seqlens,              // [BATCH, ]
+    torch::Tensor vertical_indexes,        // [BATCH, N_HEADS, NNZ_V]
+    torch::Tensor slash_indexes,           // [BATCH, N_HEADS, NNZ_S]
+    torch::Tensor vertical_indices_count,  // [N_HEADS, ]
+    torch::Tensor slash_indices_count,     // [N_HEADS, ]
+    int64_t context_size, int64_t block_size_M, int64_t block_size_N,
+    bool causal) {
+  cudaSetDevice(q_seqlens.get_device());
+
+  int batch_size = slash_indexes.size(0);
+  int num_heads = slash_indexes.size(1);
+  int nnz_slash = slash_indexes.size(2);
+  int nnz_vertical = vertical_indexes.size(2);
+  int num_rows = (context_size + block_size_M - 1) / block_size_M;
+
+  // The launcher takes `int*` buffers. The untyped data_ptr() yields `void*`,
+  // which does not convert implicitly, so request typed pointers.
+  convert_vertical_slash_indexes_64x64_mergehead(
+      q_seqlens.data_ptr<int>(), kv_seqlens.data_ptr<int>(),
+      vertical_indexes.data_ptr<int>(), slash_indexes.data_ptr<int>(),
+      vertical_indices_count.data_ptr<int>(),
+      slash_indices_count.data_ptr<int>(), block_count.data_ptr<int>(),
+      block_offset.data_ptr<int>(), column_count.data_ptr<int>(),
+      column_index.data_ptr<int>(), batch_size, num_heads, num_rows,
+      block_size_M, block_size_N, nnz_vertical, nnz_slash, causal);
+}
diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp
index c2ae554c9f8e..d0f85e23609b 100644
--- a/csrc/core/scalar_type.hpp
+++ b/csrc/core/scalar_type.hpp
@@ -315,6 +315,8 @@ static inline constexpr auto kS8 = ScalarType::int_(8);
static inline constexpr auto kU8 = ScalarType::uint(8);
static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);
+static inline constexpr auto kFE2M1f =
+ ScalarType::float_(2, 1, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE3M2f =
ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE4M3fn =
@@ -332,6 +334,7 @@ static inline constexpr auto kInt8 = kS8;
static inline constexpr auto kUint8 = kU8;
static inline constexpr auto kUint8b128 = kU8B128;
+static inline constexpr auto kFloat4_e2m1f = kFE2M1f;
static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
static inline constexpr auto kFloat8_e5m2 = kFE5M2;
diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp
index cf67847b45ba..9a613ba588dd 100644
--- a/csrc/cpu/cpu_types_x86.hpp
+++ b/csrc/cpu/cpu_types_x86.hpp
@@ -19,6 +19,7 @@ namespace vec_op {
#define VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
+ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp
index dbe0e30f5cbf..195872e8edd3 100644
--- a/csrc/cutlass_extensions/common.hpp
+++ b/csrc/cutlass_extensions/common.hpp
@@ -15,15 +15,6 @@
cutlassGetStatusString(error)); \
}
-/**
- * Panic wrapper for unwinding CUDA runtime errors
- */
-#define CUDA_CHECK(status) \
- { \
- cudaError_t error = status; \
- TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \
- }
-
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
int max_shared_mem_per_block_opt_in = 0;
cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
@@ -59,3 +50,13 @@ struct enable_sm90_only : Kernel {
#endif
}
};
+
+// Wraps `Kernel` so that its call operator only does work in device code
+// compiled for __CUDA_ARCH__ == 1000; for any other target the body is
+// compiled out and operator() is a no-op.
+template <typename Kernel>
+struct enable_sm100_only : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE void operator()(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1000
+    // forward the arguments unchanged to the wrapped kernel
+    Kernel::operator()(std::forward<Args>(args)...);
+#endif
+  }
+};
diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h
index dc6e0769b878..f7b75c48373f 100644
--- a/csrc/dispatch_utils.h
+++ b/csrc/dispatch_utils.h
@@ -65,5 +65,19 @@
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
+// Dispatch cases for Byte, Char, Short, Int and Long plus the unsigned
+// UInt16, UInt32 and UInt64 scalar types.
+#define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \
+ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
+ AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
+ AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
+ AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
+ AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
+ AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__) \
+ AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__) \
+ AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__)
+
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
+
+// Switch on TYPE over the cases listed in
+// VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES; NAME is passed through to
+// AT_DISPATCH_SWITCH.
+#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
+ AT_DISPATCH_SWITCH( \
+ TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu
index 98daf1a1b8e6..f62d08c17c6d 100644
--- a/csrc/mamba/causal_conv1d/causal_conv1d.cu
+++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu
@@ -13,6 +13,10 @@
#include
#include
+#ifdef USE_ROCM
+ namespace cub = hipcub;
+#endif
+
#include "static_switch.h"
@@ -501,15 +505,12 @@ void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
auto kernel = &causal_conv1d_fwd_kernel;
if (kSmemSize >= 48 * 1024) {
- #ifndef USE_ROCM
- C10_CUDA_CHECK(cudaFuncSetAttribute(
- kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
- #else
- // There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function.
C10_CUDA_CHECK(cudaFuncSetAttribute(
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
+ // The warning below describes AMD GPUs only, so keep it out of CUDA builds.
+ #ifdef USE_ROCM
std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
#endif
}
kernel<<>>(params);
diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
index bd0a34119c82..0c9df925bdbf 100644
--- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
+++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
@@ -321,7 +321,7 @@ void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) {
auto kernel = &selective_scan_fwd_kernel;
if (kSmemSize >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
- kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
+ (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
}
kernel<<>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
diff --git a/csrc/moe/marlin_moe_wna16/generate_kernels.py b/csrc/moe/marlin_moe_wna16/generate_kernels.py
index 902bcd9dfd21..15f008d4f61e 100644
--- a/csrc/moe/marlin_moe_wna16/generate_kernels.py
+++ b/csrc/moe/marlin_moe_wna16/generate_kernels.py
@@ -31,7 +31,10 @@
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
-SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"]
+SCALAR_TYPES = [
+ "vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
+ "vllm::kFE2M1f"
+]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
@@ -39,7 +42,7 @@
# = 0 : act order case
# = -1 : channelwise quantization
# > 0 : group_size=16*group_blocks
-GROUP_BLOCKS = [0, -1, 2, 4, 8]
+GROUP_BLOCKS = [0, -1, 1, 2, 4, 8]
DTYPES = ["fp16", "bf16"]
@@ -72,6 +75,12 @@ def generate_new_kernels():
# for fp8
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
continue
+ # kFE2M1f (fp4) kernels are only generated for group_blocks 1 and 2, i.e. group_size 16 and 32
+ if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
+ continue
+ # other quantization methods don't support group_size = 16
+ if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
+ continue
k_blocks = thread_configs[0] // 16
n_blocks = thread_configs[1] // 16
diff --git a/csrc/moe/marlin_moe_wna16/kernel.h b/csrc/moe/marlin_moe_wna16/kernel.h
index c40c33d01f37..537282aba8c8 100644
--- a/csrc/moe/marlin_moe_wna16/kernel.h
+++ b/csrc/moe/marlin_moe_wna16/kernel.h
@@ -7,17 +7,18 @@
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
-#define MARLIN_KERNEL_PARAMS \
- const int4 *__restrict__ A, const int4 *__restrict__ B, \
- int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
- const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \
- const int *__restrict__ g_idx, \
- const int32_t *__restrict__ sorted_token_ids_ptr, \
- const int32_t *__restrict__ expert_ids_ptr, \
- const int32_t *__restrict__ num_tokens_past_padded_ptr, \
- const float *__restrict__ topk_weights_ptr, int top_k, \
- bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
- int prob_n, int prob_k, int *locks, bool use_atomic_add, \
+#define MARLIN_KERNEL_PARAMS \
+ const int4 *__restrict__ A, const int4 *__restrict__ B, \
+ int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
+ const int4 *__restrict__ scales_ptr, \
+ const uint16_t *__restrict__ scale2_ptr, \
+ const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
+ const int32_t *__restrict__ sorted_token_ids_ptr, \
+ const int32_t *__restrict__ expert_ids_ptr, \
+ const int32_t *__restrict__ num_tokens_past_padded_ptr, \
+ const float *__restrict__ topk_weights_ptr, int top_k, \
+ bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
+ int prob_n, int prob_k, int *locks, bool use_atomic_add, \
bool use_fp32_reduce, int max_shared_mem
namespace MARLIN_NAMESPACE_NAME {
diff --git a/csrc/moe/marlin_moe_wna16/marlin_template.h b/csrc/moe/marlin_moe_wna16/marlin_template.h
index c9e199bcea1f..1c255396099d 100644
--- a/csrc/moe/marlin_moe_wna16/marlin_template.h
+++ b/csrc/moe/marlin_moe_wna16/marlin_template.h
@@ -301,9 +301,11 @@ __global__ void Marlin(
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn
- const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
- // (k/groupsize)x(n/pack_factor)
- const int* __restrict__ g_idx, // int32 group indices of shape k
+ const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4
+ // only)
+ const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
+ // (k/groupsize)x(n/pack_factor)
+ const int* __restrict__ g_idx, // int32 group indices of shape k
const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids
const int32_t* __restrict__ expert_ids_ptr, // moe expert ids
const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens
@@ -341,6 +343,16 @@ __global__ void Marlin(
extern __shared__ int4 sh[];
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
+ constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
+ w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
+ // see comments of dequant.h for more details
+ constexpr bool dequant_skip_flop =
+ !is_int_type ||
+ has_zp && !is_zp_float && !std::is_same::value ||
+ has_zp && !is_zp_float && !(w_type == vllm::kU8);
+
+ scalar_t2 global_scale;
+
constexpr bool has_act_order = group_blocks == 0;
constexpr int pack_factor = 32 / w_type.size_bits();
@@ -348,7 +360,8 @@ __global__ void Marlin(
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
const int group_size =
(!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups;
- const int scales_expert_stride = prob_n * prob_k / group_size / 8;
+ const int scales_expert_stride =
+ prob_n * prob_k / group_size / (w_type == vllm::kFE2M1f ? 16 : 8);
const int zp_expert_stride =
is_zp_float ? prob_n * prob_k / group_size / 8
: prob_n * prob_k / group_size / (pack_factor * 4);
@@ -460,9 +473,16 @@ __global__ void Marlin(
if (mul_topk_weights) {
#pragma unroll
for (int i = 0; i < 4; i++) {
- sh_block_topk_weights[tid4 * 4 + i] =
- Dtype::num2num2(Dtype::float2num(
- topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]]));
+ int idx = tid4 * 4 + i;
+ idx = idx < block_num_valid_tokens ? idx : 0;
+ if constexpr (w_type == vllm::kFE2M1f) {
+ sh_block_topk_weights[idx] = __hmul2(
+ global_scale, Dtype::num2num2(Dtype::float2num(
+ topk_weights_ptr[sh_block_sorted_ids[idx]])));
+ } else {
+ sh_block_topk_weights[idx] = Dtype::num2num2(
+ Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]]));
+ }
}
}
}
@@ -493,6 +513,11 @@ __global__ void Marlin(
expert_id = expert_ids_ptr[block_id];
}
+ if constexpr (w_type == vllm::kFE2M1f) {
+ uint16_t val = scale2_ptr[expert_id];
+ global_scale = Dtype::num2num2(*reinterpret_cast(&val));
+ }
+
B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4);
scales_ptr += (expert_id - old_expert_id) * scales_expert_stride;
if constexpr (has_zp) {
@@ -606,7 +631,7 @@ __global__ void Marlin(
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
constexpr int s_tb_groups =
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
- ? thread_k_blocks / group_blocks
+ ? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1)
: 1;
constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
int s_gl_rd_delta = s_gl_stride;
@@ -664,7 +689,8 @@ __global__ void Marlin(
if constexpr (group_blocks == -1) {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
} else {
- s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
+ s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) /
+ (w_type == vllm::kFE2M1f ? 2 : 1) +
s_sh_stride * slice_col + threadIdx.x;
}
}
@@ -688,10 +714,20 @@ __global__ void Marlin(
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
int s_sh_rd;
- if constexpr (group_blocks != -1)
+ if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) {
+ auto warp_id = threadIdx.x / 32;
+ int n_warps = thread_n_blocks / 4;
+ int warp_row = warp_id / n_warps;
+
+ s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
+ (threadIdx.x % 32) / 4;
+ s_sh_rd = s_sh_rd * 2 + warp_row % 2;
+
+ } else if constexpr (group_blocks != -1)
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
- else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp))
+ else if constexpr (group_blocks == -1 &&
+ (m_block_size_8 || (has_zp && !dequant_skip_flop)))
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 8;
else
@@ -801,7 +837,7 @@ __global__ void Marlin(
sh_first_group_id = first_group_id;
sh_num_groups = last_group_id - first_group_id + 1;
- if (sh_num_groups < act_s_max_num_groups) {
+ if (sh_num_groups > act_s_max_num_groups) {
sh_num_groups = act_s_max_num_groups;
}
@@ -1021,12 +1057,19 @@ __global__ void Marlin(
cur_k += k_iter_size * (k % b_sh_wr_iters);
int k_blocks = cur_k / 16;
- int cur_group_id = k_blocks / group_blocks;
+ int cur_group_id =
+ k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 2 : 1));
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
- reinterpret_cast(&frag_s[k % 2])[0] =
- sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
+ if constexpr (w_type_id != vllm::kFE2M1f.id()) {
+ reinterpret_cast(&frag_s[k % 2])[0] =
+ sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
+ } else {
+ reinterpret_cast(&frag_s[k % 2])[0] =
+ reinterpret_cast(
+ sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
+ }
}
}
@@ -1199,22 +1242,7 @@ __global__ void Marlin(
};
auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) {
- if constexpr (has_zp && is_zp_float || !has_zp) {
- dequant(q, frag_b_ptr);
- } else {
- static_assert(has_zp && !is_zp_float);
- static_assert(w_type_id == vllm::kU4.id() || w_type_id == vllm::kU8.id());
- // If (has_zp && !is_zp_float),
- // we use not-zp version `dequant` function
- // to improve numerical accuracy.
- // Since both weight and zero point are dequanted using this logic,
- // the final dequanted weight would be correct.
- if constexpr (w_type_id == vllm::kU4.id()) {
- dequant(q, frag_b_ptr);
- } else if constexpr (w_type_id == vllm::kU8.id()) {
- dequant(q, frag_b_ptr);
- }
- }
+ dequant(q, frag_b_ptr);
};
// Execute the actual tensor core matmul of a sub-tile.
@@ -1244,13 +1272,23 @@ __global__ void Marlin(
dequant_data(zp_quant_1, reinterpret_cast(&frag_zp) + 2);
}
}
- if constexpr (has_zp && is_zp_float) {
+ if constexpr (!dequant_skip_flop && has_zp && is_zp_float) {
if (is_new_zp) {
reinterpret_cast(&frag_zp)[0] =
reinterpret_cast(&frag_zpf[k2])[0];
}
}
+ if constexpr (w_type == vllm::kFE2M1f) {
+ int s_quant_0 = reinterpret_cast(frag_s[k2])[0];
+ int s_quant_1 = reinterpret_cast(frag_s[k2])[1];
+
+ dequant_fp8_scales(s_quant_0,
+ reinterpret_cast(&frag_s[k2]));
+ dequant_fp8_scales(
+ s_quant_1, reinterpret_cast(&frag_s[k2]) + 2);
+ }
+
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
@@ -1259,7 +1297,10 @@ __global__ void Marlin(
FragB frag_b1;
int b_quant_0, b_quant_1;
- if constexpr (w_type.size_bits() == 4) {
+ if constexpr (w_type_id == vllm::kFE2M1f.id()) {
+ b_quant_1 = frag_b_quant[k2][0][j];
+ b_quant_0 = b_quant_1 << 8;
+ } else if constexpr (w_type.size_bits() == 4) {
b_quant_0 = frag_b_quant[k2][0][j];
b_quant_1 = b_quant_0 >> 8;
} else {
@@ -1272,6 +1313,11 @@ __global__ void Marlin(
dequant_data(b_quant_0, reinterpret_cast(&frag_b0));
dequant_data(b_quant_1, reinterpret_cast(&frag_b1));
+ if constexpr (dequant_skip_flop && has_zp && !is_zp_float) {
+ sub_zp(frag_b0, frag_zp[j], 0);
+ sub_zp(frag_b1, frag_zp[j], 1);
+ }
+
// Apply scale to frag_b0
if constexpr (has_act_order) {
static_assert(group_blocks != -1);
@@ -1279,7 +1325,8 @@ __global__ void Marlin(
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);
- } else if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
+ } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float &&
+ group_blocks == -1) {
int idx = (threadIdx.x / 4) % 2;
scalar_t2 s2 = Dtype::nums2num2(
reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx],
@@ -1287,7 +1334,7 @@ __global__ void Marlin(
if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2);
scale_and_sub(frag_b0, s2.x, frag_zp[j].x);
scale_and_sub(frag_b1, s2.y, frag_zp[j].y);
- } else if constexpr (has_zp && group_blocks != -1) {
+ } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) {
if (is_new_zp)
frag_zp[j] = __hmul2(frag_zp[j],
*reinterpret_cast(&frag_s[k2][j]));
@@ -1554,10 +1601,17 @@ __global__ void Marlin(
// For per-column quantization we finally apply the scale here (only for
// 4-bit)
if constexpr (!has_act_order && group_blocks == -1 &&
- w_type.size_bits() == 4 && !has_zp) {
+ w_type.size_bits() == 4 &&
+ (has_zp && dequant_skip_flop || !has_zp)) {
res = __hmul2(res, s[0]);
}
+ if constexpr (w_type == vllm::kFE2M1f) {
+ if (!mul_topk_weights) {
+ res = __hmul2(res, global_scale);
+ }
+ }
+
if constexpr (m_block_size_8) {
((scalar_t*)sh_red)[idx] = res.x;
((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y;
@@ -1648,7 +1702,9 @@ __global__ void Marlin(
if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
if (i == 0) {
fetch_col_zp_to_shared();
- fetch_col_scale_to_shared();
+ if constexpr (!dequant_skip_flop) {
+ fetch_col_scale_to_shared();
+ }
}
}
fetch_to_shared(i, i, i < slice_iters, i);
@@ -1711,17 +1767,20 @@ __global__ void Marlin(
if constexpr (has_act_order) {
slice_k_start += tb_k * stages;
- slice_k_start_shared_fetch += tb_k * stages;
- int first_group_id = g_idx[slice_k_start];
- int last_g_idx = slice_k_start + stages * tb_k * 2;
- if (last_g_idx >= prob_k) {
- last_g_idx = prob_k - 1;
- }
- int last_group_id = g_idx[last_g_idx];
- if (last_group_id >= sh_first_group_id + sh_num_groups) {
- fetch_act_order_scales_to_shared(false, first_group_id,
- last_group_id);
- __syncthreads();
+
+ if (slice_k_start < prob_k) {
+ slice_k_start_shared_fetch += tb_k * stages;
+ int first_group_id = g_idx[slice_k_start];
+ int last_g_idx = slice_k_start + stages * tb_k * 2;
+ if (last_g_idx >= prob_k) {
+ last_g_idx = prob_k - 1;
+ }
+ int last_group_id = g_idx[last_g_idx];
+ if (last_group_id >= sh_first_group_id + sh_num_groups) {
+ fetch_act_order_scales_to_shared(false, first_group_id,
+ last_group_id);
+ __syncthreads();
+ }
}
}
if (slice_iters == 0) {
@@ -1737,7 +1796,8 @@ __global__ void Marlin(
bool last = slice_idx == slice_count - 1;
// For per-column scales, we only fetch them here in the final step before
// write-out
- if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
+ if constexpr (!has_act_order && group_blocks == -1 &&
+ (has_zp && dequant_skip_flop || !has_zp)) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
if (s_sh_wr_pred) {
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
@@ -1747,7 +1807,8 @@ __global__ void Marlin(
}
thread_block_reduce();
- if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
+ if constexpr (!has_act_order && group_blocks == -1 &&
+ (has_zp && dequant_skip_flop || !has_zp)) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
cp_async_wait<0>();
__syncthreads();
@@ -1771,7 +1832,8 @@ __global__ void Marlin(
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if constexpr (!has_act_order && group_blocks == -1 &&
- w_type.size_bits() == 8 && !has_zp) {
+ w_type.size_bits() == 8 &&
+ (has_zp && dequant_skip_flop || !has_zp)) {
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
diff --git a/csrc/moe/marlin_moe_wna16/ops.cu b/csrc/moe/marlin_moe_wna16/ops.cu
index 00b4e934cc39..2cff04f699b0 100644
--- a/csrc/moe/marlin_moe_wna16/ops.cu
+++ b/csrc/moe/marlin_moe_wna16/ops.cu
@@ -291,6 +291,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
// FZP: cases for float-zero-point (is_zp_float = true)
// ACT: cases for act order case (group_blocks == 0)
+ // FP4: cases for nvfp4(e2m1) (group_blocks == 1)
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
@@ -338,6 +339,21 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
+ #define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
+ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
+ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
+
+ #define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
+ _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
+ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
+ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
+
+ #define FP4_GET_IF(W_TYPE) \
+ FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
+ FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
+ FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
+ FP4_GET_IF_M234(W_TYPE, 8, 4, 128)
+
#define BIGGROUP_GET_IF(W_TYPE) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
@@ -394,6 +410,8 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
BIGGROUP_GET_IF(vllm::kFE4M3fn)
+ FP4_GET_IF(vllm::kFE2M1f)
+
ACT_GET_IF(vllm::kU4B8)
ACT_GET_IF(vllm::kU8B128)
@@ -465,7 +483,7 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
template
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
- void* zp, void* g_idx, void* perm, void* a_tmp,
+ void* s2, void* zp, void* g_idx, void* perm, void* a_tmp,
void* sorted_token_ids, void* expert_ids,
void* num_tokens_past_padded, void* topk_weights,
int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep,
@@ -479,14 +497,16 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
bool m_block_size_8 = moe_block_size == 8;
if (has_zp) {
- TORCH_CHECK(q_type == vllm::kU4,
- "q_type must be u4 when has_zp = True. Got = ", q_type.str());
+ TORCH_CHECK(
+ q_type == vllm::kU4 || q_type == vllm::kU8,
+ "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
} else {
- TORCH_CHECK(q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
- q_type == vllm::kFE4M3fn,
- "q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = "
- "False. Got = ",
- q_type.str());
+ TORCH_CHECK(
+ q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
+ q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f,
+ "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when "
+ "has_zp = False. Got = ",
+ q_type.str());
}
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
@@ -519,6 +539,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
int4* C_ptr = (int4*)C;
int4* C_tmp_ptr = (int4*)C_tmp;
const int4* s_ptr = (const int4*)s;
+ const uint16_t* s2_ptr = (const uint16_t*)s2;
const int4* zp_ptr = (const int4*)zp;
const int* g_idx_ptr = (const int*)g_idx;
const int* perm_ptr = (const int*)perm;
@@ -627,7 +648,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
// avoid ">>>" being formatted to "> > >"
// clang-format off
kernel<<>>(
- A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,
+ A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr,
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce, max_shared_mem);
@@ -639,6 +660,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a, std::optional const& c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
+ std::optional const& global_scale_or_none,
std::optional const& b_zeros_or_none,
std::optional const& g_idx_or_none,
std::optional const& perm_or_none, torch::Tensor& workspace,
@@ -790,6 +812,17 @@ torch::Tensor moe_wna16_marlin_gemm(
}
}
+ torch::Tensor global_scale;
+ if (global_scale_or_none.has_value()) {
+ global_scale = global_scale_or_none.value();
+ TORCH_CHECK(b_q_type == vllm::kFE2M1f,
+ "global_scale can only be used for float4_e2m1f.");
+ } else {
+ global_scale = torch::empty({0}, options);
+ TORCH_CHECK(!(b_q_type == vllm::kFE2M1f),
+ "the global_scale parameter must be passed for float4_e2m1f.");
+ }
+
torch::Tensor b_zeros;
if (b_zeros_or_none.has_value()) {
b_zeros = b_zeros_or_none.value();
@@ -802,13 +835,14 @@ torch::Tensor moe_wna16_marlin_gemm(
if (has_zp) {
TORCH_CHECK(
- b_q_type == vllm::kU4,
- "b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
+ b_q_type == vllm::kU4 || b_q_type == vllm::kU8,
+ "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str());
} else {
TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 ||
- b_q_type == vllm::kFE4M3fn,
- "b_q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = "
- "False. Got = ",
+ b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f,
+ "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or "
+ "float4_e2m1f when "
+ "has_zp = False. Got = ",
b_q_type.str());
}
@@ -854,9 +888,16 @@ torch::Tensor moe_wna16_marlin_gemm(
int dev = a.get_device();
if (a.scalar_type() == at::ScalarType::Half) {
+ void* scales_ptr;
+ if (b_q_type == vllm::kFE2M1f) {
+ scales_ptr = b_scales.data_ptr();
+ } else {
+ scales_ptr = b_scales.data_ptr();
+ }
+
MARLIN_NAMESPACE_NAME::marlin_mm(
a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(),
- c_tmp.data_ptr(), b_scales.data_ptr(),
+ c_tmp.data_ptr(), scales_ptr, global_scale.data_ptr(),
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
a_tmp.data_ptr(), sorted_token_ids.data_ptr(),
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
@@ -866,11 +907,18 @@ torch::Tensor moe_wna16_marlin_gemm(
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
+ void* scales_ptr;
+ if (b_q_type == vllm::kFE2M1f) {
+ scales_ptr = b_scales.data_ptr();
+ } else {
+ scales_ptr = b_scales.data_ptr();
+ }
+
MARLIN_NAMESPACE_NAME::marlin_mm(
a.data_ptr(), b_q_weight.data_ptr(),
- c.data_ptr(), c_tmp.data_ptr(),
- b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(),
- perm.data_ptr(), a_tmp.data_ptr(),
+ c.data_ptr(), c_tmp.data_ptr(), scales_ptr,
+ global_scale.data_ptr(), b_zeros.data_ptr(),
+ g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(),
sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu
index d7be769458e3..6b6a9d04a60f 100644
--- a/csrc/moe/moe_align_sum_kernels.cu
+++ b/csrc/moe/moe_align_sum_kernels.cu
@@ -326,7 +326,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
}
if (use_global_memory) {
- VLLM_DISPATCH_INTEGRAL_TYPES(
+ VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
@@ -351,7 +351,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
cumsum_buffer.data_ptr());
});
} else if (use_i16) {
- VLLM_DISPATCH_INTEGRAL_TYPES(
+ VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// set dynamic shared mem
auto kernel =
@@ -366,7 +366,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
topk_ids.numel());
});
} else {
- VLLM_DISPATCH_INTEGRAL_TYPES(
+ VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
auto kernel =
vllm::moe::moe_align_block_size_kernel;
@@ -391,7 +391,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
TORCH_CHECK(num_experts == 256,
"sgl_moe_align_block_size kernel only supports deepseek v3.");
- VLLM_DISPATCH_INTEGRAL_TYPES(
+ VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `cumsum` tensors
auto options_int =
diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h
index 0bae119a7c46..8fda434d452f 100644
--- a/csrc/moe/moe_ops.h
+++ b/csrc/moe/moe_ops.h
@@ -28,4 +28,6 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
torch::Tensor num_tokens_post_pad, int64_t top_k,
int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
int64_t BLOCK_SIZE_K, int64_t bit);
-#endif
\ No newline at end of file
+#endif
+
+bool moe_permute_unpermute_supported();
\ No newline at end of file
diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu
index 76d5f0eab021..9a7465261abf 100644
--- a/csrc/moe/moe_permute_unpermute_op.cu
+++ b/csrc/moe/moe_permute_unpermute_op.cu
@@ -5,6 +5,9 @@
#include "permute_unpermute_kernels/dispatch.h"
#include "core/registration.h"
+// moe_permute kernels require at least CUDA 12.0
+#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
+
void moe_permute(
const torch::Tensor& input, // [n_token, hidden]
const torch::Tensor& topk_weights, //[n_token, topk]
@@ -127,7 +130,45 @@ void moe_unpermute(
});
}
+#else
+
+void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
+ torch::Tensor& topk_ids,
+ const torch::Tensor& token_expert_indicies,
+ const std::optional& expert_map,
+ int64_t n_expert, int64_t n_local_expert, int64_t topk,
+ const std::optional& align_block_size,
+ torch::Tensor& permuted_input,
+ torch::Tensor& expert_first_token_offset,
+ torch::Tensor& src_row_id2dst_row_id_map,
+ torch::Tensor& m_indices) {  // fallback stub: compiled only when CUDA_VERSION is undefined or < 12000; all arguments are ignored
+ TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0");  // always raises; message must name this op, not moe_unpermute
+}
+
+void moe_unpermute(const torch::Tensor& input,
+ const torch::Tensor& topk_weights, torch::Tensor& topk_ids,
+ const torch::Tensor& token_expert_indicies,
+ const std::optional& expert_map,
+ int64_t n_expert, int64_t n_local_expert, int64_t topk,
+ const std::optional& align_block_size,
+ torch::Tensor& permuted_input,
+ torch::Tensor& expert_first_token_offset,
+ torch::Tensor& src_row_id2dst_row_id_map,
+ torch::Tensor& m_indices) {  // fallback stub for CUDA < 12.0 builds. NOTE(review): parameter list is identical to the moe_permute stub -- confirm it matches the registered moe_unpermute schema
+ TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");  // always raises; all arguments are ignored
+}
+
+#endif
+
+bool moe_permute_unpermute_supported() {  // compile-time capability probe: reports which branch of the CUDA_VERSION guard this file was built with
+#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
+ return true;  // real moe_permute / moe_unpermute implementations were compiled in
+#else
+ return false;  // only the always-raising stubs are available
+#endif  // CUDA_VERSION >= 12000
+}
+
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("moe_permute", &moe_permute);
m.impl("moe_unpermute", &moe_unpermute);
-}
\ No newline at end of file
+}
diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu
index aa353d0f0437..de2c153882d9 100644
--- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu
+++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu
@@ -1,6 +1,9 @@
#include "moe_permute_unpermute_kernel.h"
+// moe_permute kernels require at least CUDA 12.0
+#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
+
// CubKeyValueSorter definition begin
CubKeyValueSorter::CubKeyValueSorter()
: num_experts_(0), num_bits_(sizeof(int) * 8) {}
@@ -131,9 +134,6 @@ __global__ void preprocessTopkIdKernel(int* topk_id_ptr, int size,
int num_experts) {
auto tidx = threadIdx.x;
auto bidx = blockIdx.x;
- auto lidx = tidx & 31;
- auto widx = tidx >> 5;
- auto warp_count = (blockDim.x + 31) >> 5;
auto offset = bidx * blockDim.x;
auto bound = min(offset + blockDim.x, size);
extern __shared__ int smem_expert_map[];
@@ -226,4 +226,6 @@ void getMIndices(int64_t* expert_first_token_offset,
expert_first_token_offset, align_expert_first_token_offset, m_indices,
num_local_expert, align_block_size);
}
-}
\ No newline at end of file
+}
+
+#endif
diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu
index de9747b60252..a9379032245d 100644
--- a/csrc/moe/topk_softmax_kernels.cu
+++ b/csrc/moe/topk_softmax_kernels.cu
@@ -108,9 +108,17 @@ __launch_bounds__(TPB) __global__
}
}
-template
-__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
- int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
+template
+__launch_bounds__(TPB) __global__ void moeTopK(
+ const float* inputs_after_softmax,
+ const bool* finished,
+ float* output,
+ IndType* indices,
+ int* source_rows,
+ const int num_experts,
+ const int k,
+ const int start_expert,
+ const int end_expert)
{
using cub_kvp = cub::KeyValuePair;
@@ -182,9 +190,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax
2) This implementation assumes k is small, but will work for any k.
*/
-template
+template
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
- void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
+ void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices,
int* source_rows, const int k, const int start_expert, const int end_expert)
{
// We begin by enforcing compile time assertions and setting up compile time constants.
@@ -397,8 +405,8 @@ struct TopkConstants
};
} // namespace detail
-template
-void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
+template
+void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
{
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
@@ -421,10 +429,11 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
token_expert_indices, num_tokens, topk, 0, num_experts, \
stream);
+template
void topkGatingSoftmaxKernelLauncher(
const float* gating_output,
float* topk_weights,
- int* topk_indicies,
+ IndType* topk_indicies,
int* token_expert_indices,
float* softmax_workspace,
const int num_tokens,
@@ -493,14 +502,32 @@ void topk_softmax(
const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
- vllm::moe::topkGatingSoftmaxKernelLauncher(
- gating_output.data_ptr(),
- topk_weights.data_ptr(),
- topk_indices.data_ptr(),
- token_expert_indices.data_ptr(),
- softmax_workspace.data_ptr(),
- num_tokens,
- num_experts,
- topk,
- stream);
+
+ if(topk_indices.scalar_type() == at::ScalarType::Int)
+ {
+ vllm::moe::topkGatingSoftmaxKernelLauncher(
+ gating_output.data_ptr(),
+ topk_weights.data_ptr(),
+ topk_indices.data_ptr(),
+ token_expert_indices.data_ptr(),
+ softmax_workspace.data_ptr(),
+ num_tokens,
+ num_experts,
+ topk,
+ stream);
+ }
+ else
+ {
+ assert(topk_indices.scalar_type() == at::ScalarType::UInt32);
+ vllm::moe::topkGatingSoftmaxKernelLauncher(
+ gating_output.data_ptr(),
+ topk_weights.data_ptr(),
+ topk_indices.data_ptr(),
+ token_expert_indices.data_ptr(),
+ softmax_workspace.data_ptr(),
+ num_tokens,
+ num_experts,
+ topk,
+ stream);
+ }
}
diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp
index 2a8b9bb39caa..7d35ec79ead4 100644
--- a/csrc/moe/torch_bindings.cpp
+++ b/csrc/moe/torch_bindings.cpp
@@ -10,7 +10,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Calculate the result of moe by summing up the partial results
// from all selected experts.
- m.def("moe_sum(Tensor! input, Tensor output) -> ()");
+ m.def("moe_sum(Tensor input, Tensor! output) -> ()");
m.impl("moe_sum", torch::kCUDA, &moe_sum);
// Aligning the number of tokens to be processed by each expert such
@@ -44,7 +44,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def(
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
- "Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
+ "Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale, Tensor? "
+ "b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
@@ -76,7 +77,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor "
"expert_first_token_offset, int n_expert, int n_local_expert,int "
"topk, Tensor! hidden_states)->()");
- // conditionally compiled so impl registration is in source file
+
+ m.def("moe_permute_unpermute_supported() -> bool");
+ m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported);
#endif
}
diff --git a/csrc/ops.h b/csrc/ops.h
index 1dfd2e067e85..7044b4588b81 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -59,6 +59,31 @@ void merge_attn_states(torch::Tensor& output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse);
+
+void convert_vertical_slash_indexes(
+ torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
+ torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
+ torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
+ torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
+ torch::Tensor q_seqlens, // [BATCH, ]
+ torch::Tensor kv_seqlens, // [BATCH, ]
+ torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
+ torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
+ int64_t context_size, int64_t block_size_M, int64_t block_size_N,
+ bool causal);
+
+void convert_vertical_slash_indexes_mergehead(
+ torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
+ torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
+ torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
+ torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
+ torch::Tensor q_seqlens, // [BATCH, ]
+ torch::Tensor kv_seqlens, // [BATCH, ]
+ torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
+ torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
+ torch::Tensor vertical_indices_count, // [N_HEADS, ]
+ torch::Tensor slash_indices_count, int64_t context_size,
+ int64_t block_size_M, int64_t block_size_N, bool causal);
#endif
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
@@ -208,6 +233,12 @@ void cutlass_moe_mm(
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides);
+void cutlass_fp4_group_mm(
+ torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
+ const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
+ const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
+ const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets);
+
void get_cutlass_moe_mm_data(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
@@ -235,6 +266,12 @@ std::vector cutlass_sparse_compress(torch::Tensor const& a);
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
torch::Tensor& output_scale,
torch::Tensor const& input_scale);
+
+void scaled_fp4_experts_quant(
+ torch::Tensor& output, torch::Tensor& output_scale,
+ torch::Tensor const& input, torch::Tensor const& input_global_scale,
+ torch::Tensor const& input_offset_by_experts,
+ torch::Tensor const& output_scale_offset_by_experts);
#endif
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu
index ef6dd1c0978d..266f2a0667a2 100644
--- a/csrc/pos_encoding_kernels.cu
+++ b/csrc/pos_encoding_kernels.cu
@@ -44,7 +44,8 @@ inline __device__ void apply_rotary_embedding(
// head_size]
const scalar_t* cache_ptr, const int head_size, const int num_heads,
const int num_kv_heads, const int rot_dim, const int token_idx,
- const int64_t query_stride, const int64_t key_stride) {
+ const int64_t query_stride, const int64_t key_stride,
+ const int64_t head_stride) {
const int embed_dim = rot_dim / 2;
const scalar_t* cos_ptr = cache_ptr;
const scalar_t* sin_ptr = cache_ptr + embed_dim;
@@ -52,7 +53,8 @@ inline __device__ void apply_rotary_embedding(
const int nq = num_heads * embed_dim;
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
const int head_idx = i / embed_dim;
- const int64_t token_head = token_idx * query_stride + head_idx * head_size;
+ const int64_t token_head =
+ token_idx * query_stride + head_idx * head_stride;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding(
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
@@ -62,7 +64,8 @@ inline __device__ void apply_rotary_embedding(
const int nk = num_kv_heads * embed_dim;
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int head_idx = i / embed_dim;
- const int64_t token_head = token_idx * key_stride + head_idx * head_size;
+ const int64_t token_head =
+ token_idx * key_stride + head_idx * head_stride;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding(
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
@@ -84,7 +87,8 @@ __global__ void rotary_embedding_kernel(
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
// 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
- const int num_heads, const int num_kv_heads, const int head_size) {
+ const int64_t head_stride, const int num_heads, const int num_kv_heads,
+ const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx];
@@ -92,7 +96,7 @@ __global__ void rotary_embedding_kernel(
apply_rotary_embedding(
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
- token_idx, query_stride, key_stride);
+ token_idx, query_stride, key_stride, head_stride);
}
template
@@ -109,9 +113,9 @@ __global__ void batched_rotary_embedding_kernel(
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
// 2]
const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len]
- // or [num_tokens]
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
- const int num_heads, const int num_kv_heads, const int head_size) {
+ const int64_t head_stride, const int num_heads, const int num_kv_heads,
+ const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx];
@@ -121,7 +125,7 @@ __global__ void batched_rotary_embedding_kernel(
apply_rotary_embedding(
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
- token_idx, query_stride, key_stride);
+ token_idx, query_stride, key_stride, head_stride);
}
} // namespace vllm
@@ -179,6 +183,12 @@ void rotary_embedding(
int seq_dim_idx = positions_ndim - 1;
int64_t query_stride = query.stride(seq_dim_idx);
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
+ // Determine head stride: for [*, heads, head_size] use the stride of the
+ // heads dim (stride(-2)); for flat [*, heads*head_size], head blocks are
+ // contiguous, so the stride is head_size
+ int query_ndim = query.dim();
+ int64_t head_stride =
+ (query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim / 2, 512));
@@ -190,14 +200,14 @@ void rotary_embedding(
positions.data_ptr(), query.data_ptr(),
key.has_value() ? key->data_ptr() : nullptr,
cos_sin_cache.data_ptr(), rot_dim, query_stride, key_stride,
- num_heads, num_kv_heads, head_size);
+ head_stride, num_heads, num_kv_heads, head_size);
} else {
vllm::rotary_embedding_kernel
<<>>(
positions.data_ptr(), query.data_ptr(),
key.has_value() ? key->data_ptr() : nullptr,
cos_sin_cache.data_ptr(), rot_dim, query_stride,
- key_stride, num_heads, num_kv_heads, head_size);
+ key_stride, head_stride, num_heads, num_kv_heads, head_size);
}
});
}
@@ -263,6 +273,12 @@ void batched_rotary_embedding(
int seq_dim_idx = positions_ndim - 1;
int64_t query_stride = query.stride(seq_dim_idx);
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
+ // Determine head stride: for [*, heads, head_size] use the stride of the
+ // heads dim (stride(-2)); for flat [*, heads*head_size], head blocks are
+ // contiguous, so the stride is head_size
+ int query_ndim = query.dim();
+ int64_t head_stride =
+ (query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim / 2, 512));
@@ -276,7 +292,7 @@ void batched_rotary_embedding(
key.has_value() ? key->data_ptr() : nullptr,
cos_sin_cache.data_ptr(),
cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride,
- key_stride, num_heads, num_kv_heads, head_size);
+ key_stride, head_stride, num_heads, num_kv_heads, head_size);
} else {
vllm::batched_rotary_embedding_kernel
<<>>(
@@ -284,7 +300,7 @@ void batched_rotary_embedding(
key.has_value() ? key->data_ptr() : nullptr,
cos_sin_cache.data_ptr(),
cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride,
- key_stride, num_heads, num_kv_heads, head_size);
+ key_stride, head_stride, num_heads, num_kv_heads, head_size);
}
});
}
diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu
index acc3d6722028..67e9149c1379 100644
--- a/csrc/quantization/activation_kernels.cu
+++ b/csrc/quantization/activation_kernels.cu
@@ -112,7 +112,8 @@ __global__ void act_and_mul_quant_kernel(
void silu_and_mul_quant(torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., 2 * d]
torch::Tensor& scale) {
- TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn);
+ TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn ||
+ out.dtype() == torch::kFloat8_e4m3fnuz);
TORCH_CHECK(input.dtype() == torch::kFloat16 ||
input.dtype() == torch::kBFloat16);
TORCH_CHECK(input.size(-1) % 2 == 0);
diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
index e79785827189..bf46cce60a23 100644
--- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
@@ -26,7 +26,13 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
float dst = std::nearbyint(x);
// saturate
- dst = std::clamp(dst, i8_min, i8_max);
+
+ // See https://github.com/pytorch/pytorch/issues/127666
+ // See https://github.com/llvm/llvm-project/issues/95183
+ // hip-clang std::clamp __glibcxx_assert_fail host function when building on
+ // Arch/gcc14. The following replaces std::clamp usage with similar logic
+ // dst = std::clamp(dst, i8_min, i8_max);
+ dst = (dst < i8_min) ? i8_min : (dst > i8_max) ? i8_max : dst;
return static_cast(dst);
#else
// CUDA path
@@ -79,7 +85,13 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
static_cast(std::numeric_limits::max());
// saturate
- int32_t dst = std::clamp(x, i8_min, i8_max);
+
+ // See https://github.com/pytorch/pytorch/issues/127666
+ // See https://github.com/llvm/llvm-project/issues/95183
+ // hip-clang std::clamp __glibcxx_assert_fail host function when building on
+ // Arch/gcc14. The following replaces std::clamp usage with similar logic
+ // int32_t dst = std::clamp(x, i8_min, i8_max);
+ int32_t dst = (x < i8_min) ? i8_min : (x > i8_max) ? i8_max : x;
return static_cast(dst);
#else
// CUDA path
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu
new file mode 100644
index 000000000000..84492553c02f
--- /dev/null
+++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu
@@ -0,0 +1,27 @@
+#include "scaled_mm_kernels.hpp"
+#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"
+#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
+
+namespace vllm {
+
+void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,  // validates inputs, then dispatches on out's dtype
+ torch::Tensor const& a,
+ torch::Tensor const& b,
+ torch::Tensor const& a_scales,
+ torch::Tensor const& b_scales) {
+ TORCH_CHECK(
+ a.size(0) % 4 == 0,  // the row count of `a` must be divisible by 4
+ "Input tensor must have a number of rows that is a multiple of 4. ",
+ "but got: ", a.size(0), " rows.");
+ if (out.dtype() == torch::kBFloat16) {  // bfloat16 output path
+ cutlass_gemm_blockwise_sm100_fp8_dispatch(
+ out, a, b, a_scales, b_scales);
+
+ } else {  // the only other accepted output dtype is float16
+ TORCH_CHECK(out.dtype() == torch::kFloat16);
+ cutlass_gemm_blockwise_sm100_fp8_dispatch(
+ out, a, b, a_scales, b_scales);
+ }
+}
+
+} // namespace vllm
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
new file mode 100644
index 000000000000..ef324364c6d5
--- /dev/null
+++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
@@ -0,0 +1,205 @@
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/numeric_types.h"
+
+#include "cute/tensor.hpp"
+#include "cutlass/tensor_ref.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+#include "cutlass/gemm/kernel/tile_scheduler_params.h"
+#include "cutlass/epilogue/dispatch_policy.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+
+#include "cutlass_extensions/gemm/dispatch_policy.hpp"
+#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
+
+#include "cutlass_gemm_caller.cuh"
+
+namespace vllm {
+
+using namespace cute;
+
+template
+struct cutlass_3x_gemm_fp8_blockwise {  // type bundle describing a blockwise-scaled FP8 GEMM kernel for ArchTag Sm100
+ using ElementAB = cutlass::float_e4m3_t;  // A and B share this FP8 (e4m3) element type
+
+ using ElementA = ElementAB;
+ using LayoutA = cutlass::layout::RowMajor;
+ static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value;
+
+ using ElementB = ElementAB;
+ using LayoutB = cutlass::layout::ColumnMajor;
+ static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value;
+
+ using ElementC = void;  // NOTE(review): void presumably means no C operand is loaded - confirm
+ using ElementD = OutType;
+ using LayoutD = cutlass::layout::RowMajor;
+ static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value;
+
+ using LayoutC = LayoutD;
+ static constexpr int AlignmentC = AlignmentD;
+
+ using ElementAccumulator = float;
+ using ElementCompute = float;
+ using ElementBlockScale = float;
+
+ // MMA and Cluster Tile Shapes
+ // Shape of the tile computed by tcgen05 MMA; it may span 2 SMs when the
+ // cluster shape is divisible by 2, e.g. MmaTileShape = Shape<_128,_128,_128>.
+ static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{});
+ static constexpr int ScaleGranularityM =  // tile extent along M covered by one scale factor
+ size<0>(MmaTileShape{}) / ScaleMsPerTile;
+ static constexpr int ScaleGranularityN =  // tile extent along N covered by one scale factor
+ size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{});
+ static constexpr int ScaleGranularityK =  // tile extent along K covered by one scale factor
+ size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{});
+
+ // Shape of the threadblocks in a cluster
+ using ClusterShape_MNK = ClusterShape;
+
+ using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
+ ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
+ cute::UMMA::Major::MN, cute::UMMA::Major::K>;
+ using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());  // layout type of the scale factors for A
+ using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());  // layout type of the scale factors for B
+
+ using ArchTag = cutlass::arch::Sm100;
+ using OperatorClass = cutlass::arch::OpClassTensorOp;
+
+ static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
+ using ElementScalar = float;
+ // clang-format off
+ using DefaultOperation = cutlass::epilogue::fusion::LinearCombination;
+ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+ ArchTag,
+ OperatorClass,
+ MmaTileShape,
+ ClusterShape,
+ cutlass::epilogue::collective::EpilogueTileAuto,
+ ElementAccumulator,
+ ElementCompute,
+ ElementC,
+ LayoutC,
+ AlignmentC,
+ ElementD,
+ LayoutD,
+ AlignmentD,
+ EpilogueScheduler,
+ DefaultOperation
+ >::CollectiveOp;
+
+ using StageCountType = cutlass::gemm::collective::StageCountAuto;
+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+ ArchTag,
+ OperatorClass,
+ ElementA,
+ cute::tuple,
+ AlignmentA,
+ ElementB,
+ cute::tuple,
+ AlignmentB,
+ ElementAccumulator,
+ MmaTileShape,
+ ClusterShape,
+
+ cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+ MainloopScheduler
+ >::CollectiveOp;
+ // clang-format on
+
+ using KernelType = enable_sm100_only, CollectiveMainloop, CollectiveEpilogue>>;
+
+ struct GemmKernel : public KernelType {};  // distinct named type deriving from KernelType
+};
+
+template
+void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,  // builds the mainloop/epilogue arguments for Gemm and launches it
+ torch::Tensor const& b,
+ torch::Tensor const& a_scales,
+ torch::Tensor const& b_scales) {
+ using GemmKernel = typename Gemm::GemmKernel;
+ using StrideA = typename Gemm::GemmKernel::StrideA;
+ using StrideB = typename Gemm::GemmKernel::StrideB;
+ using StrideD = typename Gemm::GemmKernel::StrideD;
+ using StrideC = typename Gemm::GemmKernel::StrideC;
+ using LayoutSFA = typename Gemm::LayoutSFA;
+ using LayoutSFB = typename Gemm::LayoutSFB;
+ using ScaleConfig = typename Gemm::ScaleConfig;
+
+ using ElementAB = typename Gemm::ElementAB;
+ using ElementD = typename Gemm::ElementD;
+
+ int32_t m = a.size(0), n = b.size(1), k = a.size(1);  // problem dims taken from a's two sizes and b's second size
+ auto prob_shape = cute::make_shape(m, n, k, 1);  // NOTE(review): trailing 1 is presumably the batch count - confirm
+
+ StrideA a_stride;
+ StrideB b_stride;
+ StrideC c_stride;
+ a_stride =
+ cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
+ b_stride =
+ cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));  // B is described with shape (n, k)
+ c_stride =
+ cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
+
+ LayoutSFA layout_SFA =
+ ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
+ LayoutSFB layout_SFB =
+ ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
+
+ auto a_ptr = static_cast(a.data_ptr());
+ auto b_ptr = static_cast(b.data_ptr());
+ auto a_scales_ptr = static_cast(a_scales.data_ptr());
+ auto b_scales_ptr = static_cast(b_scales.data_ptr());
+
+ typename GemmKernel::MainloopArguments mainloop_args{
+ a_ptr, a_stride, b_ptr, b_stride,
+ a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB};
+
+ auto c_ptr = static_cast(out.data_ptr());  // points at the output tensor's storage
+ typename GemmKernel::EpilogueArguments epilogue_args{
+ {}, c_ptr, c_stride, c_ptr, c_stride};  // same pointer/stride passed twice - presumably the C and D slots; confirm
+ c3x::cutlass_gemm_caller(a.device(), prob_shape, mainloop_args,
+ epilogue_args);
+}
+
+template
+void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,  // picks the 1-SM or 2-SM kernel variant from the problem size and SM count
+ torch::Tensor const& a,
+ torch::Tensor const& b,
+ torch::Tensor const& a_scales,
+ torch::Tensor const& b_scales) {
+ auto m = a.size(0);
+ auto k = a.size(1);
+ auto n = b.size(1);
+ int sms;  // NOTE(review): left uninitialised if the query below fails; its return code is not checked
+ cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
+
+ auto should_use_2sm = [&sms](int m, int n, int tile1SM = 128) {  // true when the count of tile1SM x tile1SM output tiles is >= sms
+ return std::ceil(static_cast(m) / tile1SM) *
+ std::ceil(static_cast(n) / tile1SM) >=
+ sms;
+ };
+ bool use_2sm = should_use_2sm(m, n);
+ if (use_2sm) {  // 2-SM schedules with a 2x2x1 cluster
+ cutlass_gemm_caller_blockwise, Shape<_256, _1, _1>,
+ Shape<_2, _2, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm,
+ cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
+ out, a, b, a_scales, b_scales);
+ } else {  // 1-SM schedules with a 1x1x1 cluster
+ cutlass_gemm_caller_blockwise, Shape<_128, _1, _1>,
+ Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
+ cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+ out, a, b, a_scales, b_scales);
+ }
+}
+
+} // namespace vllm
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
new file mode 100644
index 000000000000..2ee6a19407f9
--- /dev/null
+++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
@@ -0,0 +1,75 @@
+#include
+#include "cuda_utils.h"
+#include "cutlass_extensions/common.hpp"
+
+// Routes a scaled matmul to the kernel that matches the layout of the scale
+// tensors.
+//
+//  * a_scales has 1 or a.size(0) elements and b_scales has 1 or b.size(1)
+//    elements: call fp8_func when `a` is float8_e4m3fn, otherwise int8_func
+//    (`a` must then be int8). `bias` is forwarded. Passing nullptr as
+//    int8_func marks int8 as unsupported for the calling architecture.
+//  * anything else: treated as blockwise scaling and sent to blockwise_func.
+//    Both scale tensors must be 2D, inputs must be float8_e4m3fn, the scale
+//    group shapes must be [1, 128] for `a` and [128, 128] for `b`, and `bias`
+//    must be empty.
+//
+// Every violated precondition is reported through TORCH_CHECK.
+template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
+void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
+                        torch::Tensor const& b, torch::Tensor const& a_scales,
+                        torch::Tensor const& b_scales,
+                        std::optional<torch::Tensor> const& bias,
+                        Fp8Func fp8_func, Int8Func int8_func,
+                        BlockwiseFunc blockwise_func) {
+  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
+  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+
+  if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
+      (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
+    // Standard per-tensor/per-token/per-channel scaling
+    TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
+    if (a.dtype() == torch::kFloat8_e4m3fn) {
+      fp8_func(c, a, b, a_scales, b_scales, bias);
+    } else {
+      TORCH_CHECK(a.dtype() == torch::kInt8);
+      // A nullptr int8_func is resolved at compile time so the call below is
+      // never instantiated for architectures without an int8 kernel.
+      if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
+        int8_func(c, a, b, a_scales, b_scales, bias);
+      } else {
+        TORCH_CHECK(false, "Int8 not supported for this architecture");
+      }
+    }
+  } else {
+    TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
+    TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
+    int32_t version_num = get_sm_version_num();
+    if (version_num >= 100) {
+      // Blockwise scaling is fp8-only; the pre-SM100 branch below enforces
+      // the same dtype requirement, so reject other dtypes here as well.
+      TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn &&
+                      b.dtype() == torch::kFloat8_e4m3fn,
+                  "cutlass_scaled_mm only supports datatype float8_e4m3fn.");
+      TORCH_CHECK(
+          a.size(0) == a_scales.size(0) &&
+              cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1),
+          "a_scale_group_shape must be [1, 128].");
+      TORCH_CHECK(
+          cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) &&
+              cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1),
+          "b_scale_group_shape must be [128, 128].");
+    } else {
+      // TODO: Remove this after using cutlass sm90 blockwise scaling gemm
+      // kernel, or introducing ceil_div to the load_init() of mainloop.
+      using GroupShape = std::array<int64_t, 2>;
+      auto make_group_shape = [](torch::Tensor const& x,
+                                 torch::Tensor const& s) -> GroupShape {
+        TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
+        return {cuda_utils::ceil_div(x.size(0), s.size(0)),
+                cuda_utils::ceil_div(x.size(1), s.size(1))};
+      };
+
+      GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
+      GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
+
+      // 1x128 per-token group scales for activations
+      // 128x128 blockwise scales for weights
+      TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
+                   b_scale_group_shape == GroupShape{128, 128} &&
+                   a.dtype() == torch::kFloat8_e4m3fn &&
+                   b.dtype() == torch::kFloat8_e4m3fn),
+                  "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
+                  "a_scale_group_shape must be [1, 128]. Got: [",
+                  a_scale_group_shape[0], ", ", a_scale_group_shape[1],
+                  "]\n"
+                  "b_scale_group_shape must be [128, 128]. Got: [",
+                  b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
+    }
+
+    TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
+    blockwise_func(c, a, b, a_scales, b_scales);
+  }
+}
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
index 85272804774d..c1242fdb39da 100644
--- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
+++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
@@ -36,4 +36,9 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b_scales,
std::optional const& bias);
+void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
+ torch::Tensor const& a,
+ torch::Tensor const& b,
+ torch::Tensor const& a_scales,
+ torch::Tensor const& b_scales);
} // namespace vllm
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu
index 459eb1bb76eb..0cbd5305e3c2 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu
@@ -1,8 +1,6 @@
-#include
+#include "c3x/scaled_mm_helper.hpp"
#include "c3x/scaled_mm_kernels.hpp"
-#include "cuda_utils.h"
-
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm100 (Blackwell).
@@ -15,20 +13,10 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional const& bias) {
- TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
- TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
-
- int M = a.size(0), N = b.size(1), K = a.size(1);
- TORCH_CHECK(
- (a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
- (b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
- "Currently, block scaled fp8 gemm is not implemented for Blackwell");
-
- // Standard per-tensor/per-token/per-channel scaling
- TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
- TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
- "Currently, only fp8 gemm is implemented for Blackwell");
- vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
+ dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
+ vllm::cutlass_scaled_mm_sm100_fp8,
+ nullptr, // int8 not supported on SM100
+ vllm::cutlass_scaled_mm_blockwise_sm100_fp8);
}
#endif
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu
index bcb91040d5e2..211302171f07 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu
@@ -1,8 +1,6 @@
-#include
+#include "c3x/scaled_mm_helper.hpp"
#include "c3x/scaled_mm_kernels.hpp"
-#include "cuda_utils.h"
-
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper).
@@ -15,49 +13,10 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional const& bias) {
- TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
- TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
-
- int M = a.size(0), N = b.size(1), K = a.size(1);
-
- if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
- (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
- // Standard per-tensor/per-token/per-channel scaling
- TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
- if (a.dtype() == torch::kFloat8_e4m3fn) {
- vllm::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias);
- } else {
- TORCH_CHECK(a.dtype() == torch::kInt8);
- vllm::cutlass_scaled_mm_sm90_int8(c, a, b, a_scales, b_scales, bias);
- }
- } else {
- using GroupShape = std::array;
- auto make_group_shape = [](torch::Tensor const& x,
- torch::Tensor const& s) -> GroupShape {
- TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
- return {cuda_utils::ceil_div(x.size(0), s.size(0)),
- cuda_utils::ceil_div(x.size(1), s.size(1))};
- };
-
- GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
- GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
-
- // 1x128 per-token group scales for activations
- // 128x128 blockwise scales for weights
- TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
- b_scale_group_shape == GroupShape{128, 128} &&
- a.dtype() == torch::kFloat8_e4m3fn &&
- b.dtype() == torch::kFloat8_e4m3fn),
- "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
- "a_scale_group_shape must be [1, 128]. Got: [",
- a_scale_group_shape[0], ", ", a_scale_group_shape[1],
- "]\n"
- "b_scale_group_shape must be [128, 128]. Got: [",
- b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
- TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
-
- vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
- }
+ dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
+ vllm::cutlass_scaled_mm_sm90_fp8,
+ vllm::cutlass_scaled_mm_sm90_int8,
+ vllm::cutlass_scaled_mm_blockwise_sm90_fp8);
}
void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
index 54b63894e4cb..e9b408fbf2ee 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
@@ -29,7 +29,8 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional const& bias);
-
+#endif
+#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
void cutlass_moe_mm_sm90(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
@@ -37,12 +38,6 @@ void cutlass_moe_mm_sm90(
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides);
-void get_cutlass_moe_mm_data_caller(
- const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
- torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
- torch::Tensor& input_permutation, torch::Tensor& output_permutation,
- const int64_t num_experts, const int64_t n, const int64_t k);
-
#endif
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
@@ -53,6 +48,15 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
std::optional const& bias);
#endif
+#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \
+ defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100
+void get_cutlass_moe_mm_data_caller(
+ const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
+ torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
+ torch::Tensor& input_permutation, torch::Tensor& output_permutation,
+ const int64_t num_experts, const int64_t n, const int64_t k);
+#endif
+
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
@@ -110,6 +114,8 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
#if defined CUDA_VERSION
if (cuda_device_capability >= 90 && cuda_device_capability < 100) {
return CUDA_VERSION >= 12000;
+ } else if (cuda_device_capability >= 100) {
+ return CUDA_VERSION >= 12080;
}
#endif
@@ -117,7 +123,7 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
}
bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
- // CUTLASS groped FP8 kernels need at least CUDA 12.3
+ // CUTLASS grouped FP8 kernels need at least CUDA 12.3
// and SM90 (Hopper)
#if defined CUDA_VERSION
@@ -222,7 +228,8 @@ void get_cutlass_moe_mm_data(
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
-#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
+  // The SM100 term must test ENABLE_SCALED_MM_SM100 for both definition and
+  // value, so that SM100-only builds still compile the call below.
+#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
+    (defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100)
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, input_permutation,
output_permutation, num_experts, n, k);
diff --git a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
new file mode 100644
index 000000000000..45ec3d29ce04
--- /dev/null
+++ b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
@@ -0,0 +1,402 @@
+#include
+#include
+
+#include
+#include
+#include
+
+#include "cute/tensor.hpp"
+#include "cutlass/tensor_ref.h"
+#include "cutlass/epilogue/collective/default_epilogue.hpp"
+#include "cutlass/epilogue/thread/linear_combination.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
+#include "cutlass/gemm/group_array_problem_shape.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+
+#include "cutlass/util/command_line.h"
+#include "cutlass/util/distribution.h"
+#include "cutlass/util/host_tensor.h"
+#include "cutlass/util/packed_stride.hpp"
+#include "cutlass/util/tensor_view_io.h"
+#include "cutlass/util/reference/device/gemm.h"
+#include "cutlass/util/reference/device/tensor_compare.h"
+#include "cutlass/util/reference/host/tensor_fill.h"
+#include "cutlass/util/reference/host/gett.hpp"
+#include "cutlass/util/reference/host/tensor_norm.h"
+#include "cutlass/util/reference/host/tensor_compare.h"
+#include
+
+using namespace cute;
+
+// Kernel that fills, for one expert per thread, the per-expert start pointers
+// (A, B, output, A/B scale factors, alpha) and the per-expert scale-factor
+// layouts used by the grouped GEMM.
+//
+// Outputs (indexed by expert_id):
+//   a_offsets, b_offsets, out_offsets, a_scales_offsets, b_scales_offsets,
+//   alpha_offsets            - pointer arrays written by this kernel.
+//   layout_sfa_base_as_int, layout_sfb_base_as_int
+//                            - layout arrays written by this kernel.
+// Inputs:
+//   *_base_as_int            - base addresses the per-expert offsets are
+//                              added to.
+//   expert_offsets           - per-expert row offset into A and the output.
+//   sf_offsets               - per-expert row offset into the A scale factors.
+//   problem_sizes_as_shapes  - flat array read as (m, n, k) triples, three
+//                              int32 values per expert.
+//   K, N                     - expected k and n of every problem (asserted).
+template
+__global__ void __get_group_gemm_starts(
+ ElementAB** a_offsets, ElementAB** b_offsets, ElementC** out_offsets,
+ ElementSF** a_scales_offsets, ElementSF** b_scales_offsets,
+ ElementAccumulator** alpha_offsets, LayoutSFA* layout_sfa_base_as_int,
+ LayoutSFB* layout_sfb_base_as_int, ElementAB* a_base_as_int,
+ ElementAB* b_base_as_int, ElementC* out_base_as_int,
+ ElementSF* a_scales_base_as_int, ElementSF* b_scales_base_as_int,
+ ElementAccumulator* alphas_base_as_int, const int32_t* expert_offsets,
+ const int32_t* sf_offsets, const int32_t* problem_sizes_as_shapes,
+ const int K, const int N) {
+ // One thread per expert; only threadIdx.x is used, blockIdx is ignored.
+ int64_t expert_id = threadIdx.x;
+ // Compares against the total number of launched threads.
+ // NOTE(review): threadIdx.x is always < blockDim.x, so this guard looks
+ // like it can never trigger - confirm whether a num_experts bound was meant.
+ if (expert_id >= gridDim.x * blockDim.x) {
+ return;
+ }
+ // Originally int32_t but upcasting to int64_t to avoid overflow
+ // during offset calculations
+ int64_t expert_offset = static_cast(expert_offsets[expert_id]);
+ int64_t sf_offset = static_cast(sf_offsets[expert_id]);
+ // size for block in block scale.
+ int64_t group_size = 16;
+ // Problem sizes are stored as consecutive (m, n, k) triples per expert.
+ int64_t m = static_cast(problem_sizes_as_shapes[expert_id * 3]);
+ int64_t n = static_cast(problem_sizes_as_shapes[expert_id * 3 + 1]);
+ int64_t k = static_cast(problem_sizes_as_shapes[expert_id * 3 + 2]);
+ // Every expert must share the same n and k as the caller-provided N and K,
+ // and k must be even because k / 2 is used as the packed row length below.
+ assert((m >= 0 && n == N && k == K && k % 2 == 0) &&
+ "unexpected problem sizes");
+
+ int64_t half_k = static_cast(k / 2);
+ int64_t group_k = static_cast(k / group_size);
+ // Shape of A as uint8/byte = [M, K // 2]
+ // Shape of B as uint8/byte = [E, N, K // 2]
+ a_offsets[expert_id] = a_base_as_int + expert_offset * half_k;
+
+ b_offsets[expert_id] = b_base_as_int + expert_id * n * half_k;
+ // Shape of C = [M, N]
+ out_offsets[expert_id] = out_base_as_int + expert_offset * n;
+ // Shape of a_scale = [sum(sf_sizes), K // group_size]
+ a_scales_offsets[expert_id] = a_scales_base_as_int + sf_offset * group_k;
+
+ // The computed A scale-factor address must be a multiple of 128 bytes.
+ assert((reinterpret_cast(a_scales_offsets[expert_id]) % 128) ==
+ 0 &&
+ "TMA requires 128-byte alignment");
+
+ // Shape of B scale = [E, N, K // group_size]
+ b_scales_offsets[expert_id] = b_scales_base_as_int + expert_id * n * group_k;
+ // Same 128-byte alignment requirement for the B scale-factor address.
+ assert((reinterpret_cast(b_scales_offsets[expert_id]) % 128) ==
+ 0 &&
+ "TMA requires 128-byte alignment");
+ // Shape of alpha = [E]
+ alpha_offsets[expert_id] = alphas_base_as_int + expert_id;
+
+ // Slots in the layout arrays that belong to this expert.
+ LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id;
+ LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id;
+
+ // Build both scale-factor layouts from this expert's (m, n, k, 1) shape.
+ *layout_sfa_ptr = ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(
+ static_cast(m), static_cast(n), static_cast(k), 1));
+ *layout_sfb_ptr = ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(
+ static_cast(m), static_cast(n), static_cast(k), 1));
+}
+
+// Expands to one `else if` branch of a dtype dispatch chain: when
+// out_tensors.dtype() equals TENSOR_C_TYPE it launches __get_group_gemm_starts
+// with a single block of num_experts threads on `stream`.
+//
+// Because the expansion starts with `else`, it must directly follow an `if`
+// (or a previous expansion of this macro). It also relies on these names being
+// in scope at the expansion site: out_tensors, a_starts, b_starts, out_starts,
+// a_scales_starts, b_scales_starts, alpha_starts, layout_sfa, layout_sfb,
+// a_tensors, b_tensors, a_scales, b_scales, alphas, expert_offsets,
+// sf_offsets, problem_sizes, num_experts, stream, K and N.
+//
+// NOTE(review): all experts are handled by one thread block, so num_experts
+// presumably has to stay within the device's threads-per-block limit - confirm.
+#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE(ELEMENT_AB_TYPE, SF_TYPE, \
+ TENSOR_C_TYPE, C_TYPE, LayoutSFA, \
+ LayoutSFB, ScaleConfig) \
+ else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
+ __get_group_gemm_starts \
+ <<<1, num_experts, 0, stream>>>( \
+ static_cast(a_starts.data_ptr()), \
+ static_cast(b_starts.data_ptr()), \
+ static_cast(out_starts.data_ptr()), \
+ static_cast(a_scales_starts.data_ptr()), \
+ static_cast(b_scales_starts.data_ptr()), \
+ static_cast(alpha_starts.data_ptr()), \
+ reinterpret_cast(layout_sfa.data_ptr()), \
+ reinterpret_cast(layout_sfb.data_ptr()), \
+ static_cast(a_tensors.data_ptr()), \
+ static_cast(b_tensors.data_ptr()), \
+ static_cast(out_tensors.data_ptr()), \
+ static_cast(a_scales.data_ptr()), \
+ static_cast(b_scales.data_ptr()), \
+ static_cast(alphas.data_ptr()), \
+ static_cast(expert_offsets.data_ptr()), \
+ static_cast(sf_offsets.data_ptr()), \
+ static_cast(problem_sizes.data_ptr()), K, N); \
+ }
+
+// Host-side launcher for __get_group_gemm_starts.
+//
+// Checks that the output and B tensor shapes agree with N and K, then
+// dispatches on the dtype of out_tensors (bfloat16 or float16) to launch the
+// kernel on the current CUDA stream of a_tensors' device. Any other output
+// dtype fails a TORCH_CHECK.
+//
+// The *_starts, layout_sfa and layout_sfb tensors are the buffers whose
+// data_ptr() is passed to the kernel as its output arrays; the remaining
+// tensors supply base addresses and per-expert metadata. The number of
+// experts is taken from expert_offsets.size(0). M is not referenced in this
+// function body.
+template
+void run_get_group_gemm_starts(
+ const torch::Tensor& a_starts, const torch::Tensor& b_starts,
+ const torch::Tensor& out_starts, const torch::Tensor& a_scales_starts,
+ const torch::Tensor& b_scales_starts, const torch::Tensor& alpha_starts,
+ const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
+ /*these are used for their base addresses*/
+ torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
+ torch::Tensor const& out_tensors, torch::Tensor const& a_scales,
+ torch::Tensor const& b_scales, torch::Tensor const& alphas,
+ torch::Tensor const& expert_offsets, torch::Tensor const& sf_offsets,
+ torch::Tensor const& problem_sizes, int M, int N, int K) {
+ // One kernel thread is launched per entry of expert_offsets.
+ int num_experts = (int)expert_offsets.size(0);
+ auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
+
+ TORCH_CHECK(out_tensors.size(1) == N,
+ "Output tensor shape doesn't match expected shape");
+ // The last dimension of b_tensors is expected to be K / 2.
+ TORCH_CHECK(K / 2 == b_tensors.size(2),
+ "b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
+ " dimension must match");
+ // Empty `if` so that each macro expansion below can begin with `else if`.
+ if (false) {
+ }
+ // Macro arguments: (ELEMENT_AB_TYPE, SF_TYPE, TENSOR_C_TYPE, C_TYPE,
+ // LayoutSFA, LayoutSFB, ScaleConfig)
+ __CALL_GET_STARTS_KERNEL_BLOCKSCALE(
+ cutlass::float_e2m1_t, cutlass::float_ue4m3_t, torch::kBFloat16,
+ cutlass::bfloat16_t, LayoutSFA, LayoutSFB, ScaleConfig)
+ __CALL_GET_STARTS_KERNEL_BLOCKSCALE(cutlass::float_e2m1_t,
+ cutlass::float_ue4m3_t, torch::kFloat16,
+ half, LayoutSFA, LayoutSFB, ScaleConfig)
+ else {
+ TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
+ }
+}
+
+template
+void run_fp4_blockwise_scaled_group_mm(
+ torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
+ const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
+ const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
+ const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
+ int N, int K) {
+ using ProblemShape =
+ cutlass::gemm::GroupProblemShape>;
+ using ElementType = cutlass::float_e2m1_t;
+ using ElementSFType = cutlass::float_ue4m3_t;
+ using ElementA = cutlass::nv_float4_t;
+ using ElementB = cutlass::nv_float4_t;
+
+ using ElementC = OutType;
+ using ElementD = ElementC;
+ using ElementAccumulator = float;
+ // Layout definitions
+ using LayoutA = cutlass::layout::RowMajor;
+ using LayoutB = cutlass::layout::ColumnMajor;
+ using LayoutC = cutlass::layout::RowMajor;
+ using LayoutD = LayoutC;
+
+ // Alignment constraints
+ static constexpr int AlignmentA = 32;
+ static constexpr int AlignmentB = 32;
+ static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value;
+ static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value;
+
+ // Architecture definitions
+ using ArchTag = cutlass::arch::Sm100;
+ using EpilogueOperatorClass =
+ cutlass::arch::OpClassTensorOp; // Epilogue Operator class tag
+ using MainloopOperatorClass =
+ cutlass::arch::OpClassBlockScaledTensorOp; // Mainloop Operator class tag
+ using StageCountType =
+ cutlass::gemm::collective::StageCountAuto; // Stage count maximized based
+ // on the tile size
+
+ using ClusterShape = Shape<_1, _1, _1>;
+ struct MMA1SMConfig {
+ using MmaTileShape = Shape<_128, _128, _128>;
+ using KernelSchedule = cutlass::gemm::
+ KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100; // Kernel to launch
+ using EpilogueSchedule =
+ cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch
+ };
+
+ using CollectiveEpilogue =
+ typename cutlass::epilogue::collective::CollectiveBuilder<
+ ArchTag, EpilogueOperatorClass, typename MMA1SMConfig::MmaTileShape,
+ ClusterShape, Shape<_128, _64>, ElementAccumulator,
+ ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD,
+ LayoutC*, AlignmentD,
+ typename MMA1SMConfig::EpilogueSchedule>::CollectiveOp;
+
+ using CollectiveMainloop =
+ typename cutlass::gemm::collective::CollectiveBuilder<
+ ArchTag, MainloopOperatorClass, ElementA, LayoutA*, AlignmentA,
+ ElementB, LayoutB*, AlignmentB, ElementAccumulator,
+ typename MMA1SMConfig::MmaTileShape, ClusterShape,
+ cutlass::gemm::collective::StageCountAutoCarveout(
+ sizeof(typename CollectiveEpilogue::SharedStorage))>,
+ typename MMA1SMConfig::KernelSchedule>::CollectiveOp;
+
+ using GemmKernel =
+ cutlass::gemm::kernel::GemmUniversal;
+
+ using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter;
+ using Gemm = Gemm1SM;
+ using StrideA = typename Gemm::GemmKernel::InternalStrideA;
+ using StrideB = typename Gemm::GemmKernel::InternalStrideB;
+ using StrideC = typename Gemm::GemmKernel::InternalStrideC;
+ using StrideD = typename Gemm::GemmKernel::InternalStrideD;
+
+ using LayoutSFA =
+ typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
+ using LayoutSFB =
+ typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
+ using ScaleConfig =
+ typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
+
+ using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
+ int num_experts = static_cast(expert_offsets.size(0));
+ auto options_int =
+ torch::TensorOptions().dtype(torch::kInt64).device(a.device());
+
+ torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
+ torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
+ torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
+ torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
+ torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
+ torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
+ torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
+ torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
+ torch::Tensor c_strides1 =
+ torch::full({num_experts}, output.stride(0), options_int);
+ torch::Tensor a_strides1 =
+ torch::full({num_experts}, a.stride(0) * 2, options_int);
+ torch::Tensor b_strides1 =
+ torch::full({num_experts}, b.stride(1) * 2, options_int);
+
+ run_get_group_gemm_starts(
+ a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
+ layout_sfa, layout_sfb, a, b, output, a_blockscale, b_blockscales, alphas,
+ expert_offsets, sf_offsets, problem_sizes, M, N, K);
+
+ // Create an instance of the GEMM
+ Gemm gemm_op;
+
+ // Initialize problem_sizes_as_shapes correctly
+ UnderlyingProblemShape* problem_sizes_as_shapes =
+ static_cast(problem_sizes.data_ptr());
+
+ // Set the Scheduler info
+ cutlass::KernelHardwareInfo hw_info;
+ using RasterOrderOptions = typename cutlass::gemm::kernel::detail::
+ PersistentTileSchedulerSm100GroupParams<
+ typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
+ typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
+ scheduler.raster_order = RasterOrderOptions::AlongM;
+ hw_info.device_id = a.get_device();
+ static std::unordered_map cached_sm_counts;
+ if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
+ cached_sm_counts[hw_info.device_id] =
+ cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
+ hw_info.device_id);
+ }
+ hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX);
+
+ // Mainloop Arguments
+ typename GemmKernel::MainloopArguments mainloop_args{
+ static_cast(a_ptrs.data_ptr()),
+ static_cast(a_strides1.data_ptr()),
+ static_cast(b_ptrs.data_ptr()),
+ static_cast(b_strides1.data_ptr()),
+ static_cast(a_scales_ptrs.data_ptr()),
+ reinterpret_cast(layout_sfa.data_ptr()),
+ static_cast(b_scales_ptrs.data_ptr()),
+ reinterpret_cast(layout_sfb.data_ptr())};
+
+ // Epilogue Arguments
+ typename GemmKernel::EpilogueArguments epilogue_args{
+ {}, // epilogue.thread
+ nullptr,
+ static_cast(c_strides1.data_ptr()),
+ static_cast