diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py index a378bc6baa5a..e29881fcbac0 100644 --- a/.buildkite/check-wheel-size.py +++ b/.buildkite/check-wheel-size.py @@ -8,12 +8,12 @@ # Note that we have 400 MiB quota, please use it wisely. # See https://github.com/pypi/support/issues/3792 . # Please also sync the value with the one in Dockerfile. -VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 400)) +VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 400)) def print_top_10_largest_files(zip_file): """Print the top 10 largest files in the given zip file.""" - with zipfile.ZipFile(zip_file, 'r') as z: + with zipfile.ZipFile(zip_file, "r") as z: file_sizes = [(f, z.getinfo(f).file_size) for f in z.namelist()] file_sizes.sort(key=lambda x: x[1], reverse=True) for f, size in file_sizes[:10]: @@ -28,14 +28,18 @@ def check_wheel_size(directory): wheel_path = os.path.join(root, file_name) wheel_size_mb = os.path.getsize(wheel_path) / (1024 * 1024) if wheel_size_mb > VLLM_MAX_SIZE_MB: - print(f"Not allowed: Wheel {wheel_path} is larger " - f"({wheel_size_mb:.2f} MB) than the limit " - f"({VLLM_MAX_SIZE_MB} MB).") + print( + f"Not allowed: Wheel {wheel_path} is larger " + f"({wheel_size_mb:.2f} MB) than the limit " + f"({VLLM_MAX_SIZE_MB} MB)." + ) print_top_10_largest_files(wheel_path) return 1 else: - print(f"Wheel {wheel_path} is within the allowed size " - f"({wheel_size_mb:.2f} MB).") + print( + f"Wheel {wheel_path} is within the allowed size " + f"({wheel_size_mb:.2f} MB)." 
+ ) return 0 @@ -45,4 +49,4 @@ def check_wheel_size(directory): sys.exit(1) directory = sys.argv[1] - sys.exit(check_wheel_size(directory)) \ No newline at end of file + sys.exit(check_wheel_size(directory)) diff --git a/.buildkite/generate_index.py b/.buildkite/generate_index.py index 36e1b6c01326..270663c415c7 100644 --- a/.buildkite/generate_index.py +++ b/.buildkite/generate_index.py @@ -22,5 +22,5 @@ print(f"Generated index.html for {args.wheel}") # cloudfront requires escaping the '+' character f.write( - template.format(wheel=filename, - wheel_html_escaped=filename.replace("+", "%2B"))) + template.format(wheel=filename, wheel_html_escaped=filename.replace("+", "%2B")) + ) diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-FP8-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-FP8-compressed-tensors.yaml new file mode 100644 index 000000000000..cca58097e8aa --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-FP8-compressed-tensors.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m RedHatAI/Llama-3.2-1B-Instruct-FP8 -b "auto" -l 1319 -f 5 -t 1 +model_name: "RedHatAI/Llama-3.2-1B-Instruct-FP8" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.335 + - name: "exact_match,flexible-extract" + value: 0.323 +limit: 1319 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Qwen2.5-1.5B-Instruct.yaml b/.buildkite/lm-eval-harness/configs/Qwen2.5-1.5B-Instruct.yaml new file mode 100644 index 000000000000..54579a63a9b8 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Qwen2.5-1.5B-Instruct.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m Qwen/Qwen2.5-1.5B-Instruct -b auto -l 1319 -f 5 -t 1 +model_name: "Qwen/Qwen2.5-1.5B-Instruct" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.54 + - name: 
"exact_match,flexible-extract" + value: 0.59 +limit: 1319 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml b/.buildkite/lm-eval-harness/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml new file mode 100644 index 000000000000..a2f235f48581 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic -b auto -l 1319 -f 5 -t 1 +model_name: "RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.47 + - name: "exact_match,flexible-extract" + value: 0.64 +limit: 1319 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/models-large.txt b/.buildkite/lm-eval-harness/configs/models-large.txt index 37eeac85c933..27a1a9a82bd3 100644 --- a/.buildkite/lm-eval-harness/configs/models-large.txt +++ b/.buildkite/lm-eval-harness/configs/models-large.txt @@ -3,3 +3,4 @@ Meta-Llama-3-70B-Instruct.yaml Mixtral-8x7B-Instruct-v0.1.yaml Qwen2-57B-A14-Instruct.yaml DeepSeek-V2-Lite-Chat.yaml +Meta-Llama-3-8B-QQQ.yaml diff --git a/.buildkite/lm-eval-harness/configs/models-small.txt b/.buildkite/lm-eval-harness/configs/models-small.txt index 254d01edf844..36e0543879b3 100644 --- a/.buildkite/lm-eval-harness/configs/models-small.txt +++ b/.buildkite/lm-eval-harness/configs/models-small.txt @@ -1,10 +1,6 @@ -Meta-Llama-3-8B-Instruct.yaml -Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml +Qwen2.5-1.5B-Instruct.yaml Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml -Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml +Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml Qwen1.5-MoE-W4A16-compressed-tensors.yaml -Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml 
-Qwen2-1.5B-Instruct-FP8W8.yaml -Meta-Llama-3-8B-QQQ.yaml diff --git a/.buildkite/lm-eval-harness/conftest.py b/.buildkite/lm-eval-harness/conftest.py index a0bcc993ed4a..769d2efda4ad 100644 --- a/.buildkite/lm-eval-harness/conftest.py +++ b/.buildkite/lm-eval-harness/conftest.py @@ -8,11 +8,14 @@ def pytest_addoption(parser): parser.addoption( "--config-list-file", action="store", - help="Path to the file listing model config YAMLs (one per line)") - parser.addoption("--tp-size", - action="store", - default="1", - help="Tensor parallel size to use for evaluation") + help="Path to the file listing model config YAMLs (one per line)", + ) + parser.addoption( + "--tp-size", + action="store", + default="1", + help="Tensor parallel size to use for evaluation", + ) @pytest.fixture(scope="session") @@ -33,7 +36,8 @@ def pytest_generate_tests(metafunc): config_dir = config_list_file.parent with open(config_list_file, encoding="utf-8") as f: configs = [ - config_dir / line.strip() for line in f + config_dir / line.strip() + for line in f if line.strip() and not line.startswith("#") ] metafunc.parametrize("config_filename", configs) diff --git a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py index c5411daf0df6..409a6ca82008 100644 --- a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py +++ b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py @@ -16,19 +16,22 @@ def launch_lm_eval(eval_config, tp_size): - trust_remote_code = eval_config.get('trust_remote_code', False) - model_args = f"pretrained={eval_config['model_name']}," \ - f"tensor_parallel_size={tp_size}," \ - f"enforce_eager=true," \ - f"add_bos_token=true," \ - f"trust_remote_code={trust_remote_code}" + trust_remote_code = eval_config.get("trust_remote_code", False) + model_args = ( + f"pretrained={eval_config['model_name']}," + f"tensor_parallel_size={tp_size}," + f"enforce_eager=true," + f"add_bos_token=true," + 
f"trust_remote_code={trust_remote_code}" + ) results = lm_eval.simple_evaluate( model="vllm", model_args=model_args, tasks=[task["name"] for task in eval_config["tasks"]], num_fewshot=eval_config["num_fewshot"], limit=eval_config["limit"], - batch_size="auto") + batch_size="auto", + ) return results @@ -42,9 +45,10 @@ def test_lm_eval_correctness_param(config_filename, tp_size): for metric in task["metrics"]: ground_truth = metric["value"] measured_value = results["results"][task["name"]][metric["name"]] - print(f'{task["name"]} | {metric["name"]}: ' - f'ground_truth={ground_truth} | measured={measured_value}') - success = success and np.isclose( - ground_truth, measured_value, rtol=RTOL) + print( + f"{task['name']} | {metric['name']}: " + f"ground_truth={ground_truth} | measured={measured_value}" + ) + success = success and np.isclose(ground_truth, measured_value, rtol=RTOL) assert success diff --git a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py index 1030ec24e8d7..7f2a2d8dc296 100644 --- a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py +++ b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py @@ -65,18 +65,18 @@ def read_markdown(file): def results_to_json(latency, throughput, serving): - return json.dumps({ - 'latency': latency.to_dict(), - 'throughput': throughput.to_dict(), - 'serving': serving.to_dict() - }) + return json.dumps( + { + "latency": latency.to_dict(), + "throughput": throughput.to_dict(), + "serving": serving.to_dict(), + } + ) if __name__ == "__main__": - # collect results for test_file in results_folder.glob("*.json"): - with open(test_file) as f: raw_result = json.loads(f.read()) @@ -120,7 +120,8 @@ def results_to_json(latency, throughput, serving): for perc in [10, 25, 50, 75, 90, 99]: # Multiply 1000 to convert the time unit from s to ms raw_result.update( - {f"P{perc}": 1000 * 
raw_result["percentiles"][str(perc)]}) + {f"P{perc}": 1000 * raw_result["percentiles"][str(perc)]} + ) raw_result["avg_latency"] = raw_result["avg_latency"] * 1000 # add the result to raw_result @@ -153,26 +154,27 @@ def results_to_json(latency, throughput, serving): serving_results = pd.DataFrame.from_dict(serving_results) throughput_results = pd.DataFrame.from_dict(throughput_results) - raw_results_json = results_to_json(latency_results, throughput_results, - serving_results) + raw_results_json = results_to_json( + latency_results, throughput_results, serving_results + ) # remapping the key, for visualization purpose if not latency_results.empty: - latency_results = latency_results[list( - latency_column_mapping.keys())].rename( - columns=latency_column_mapping) + latency_results = latency_results[list(latency_column_mapping.keys())].rename( + columns=latency_column_mapping + ) if not serving_results.empty: - serving_results = serving_results[list( - serving_column_mapping.keys())].rename( - columns=serving_column_mapping) + serving_results = serving_results[list(serving_column_mapping.keys())].rename( + columns=serving_column_mapping + ) if not throughput_results.empty: - throughput_results = throughput_results[list( - throughput_results_column_mapping.keys())].rename( - columns=throughput_results_column_mapping) + throughput_results = throughput_results[ + list(throughput_results_column_mapping.keys()) + ].rename(columns=throughput_results_column_mapping) - processed_results_json = results_to_json(latency_results, - throughput_results, - serving_results) + processed_results_json = results_to_json( + latency_results, throughput_results, serving_results + ) for df in [latency_results, serving_results, throughput_results]: if df.empty: @@ -184,38 +186,39 @@ def results_to_json(latency, throughput, serving): # The GPUs sometimes come in format of "GPUTYPE\nGPUTYPE\n...", # we want to turn it into "8xGPUTYPE" df["GPU"] = df["GPU"].apply( - lambda x: 
f"{len(x.split('\n'))}x{x.split('\n')[0]}") + lambda x: f"{len(x.split('\n'))}x{x.split('\n')[0]}" + ) # get markdown tables - latency_md_table = tabulate(latency_results, - headers='keys', - tablefmt='pipe', - showindex=False) - serving_md_table = tabulate(serving_results, - headers='keys', - tablefmt='pipe', - showindex=False) - throughput_md_table = tabulate(throughput_results, - headers='keys', - tablefmt='pipe', - showindex=False) + latency_md_table = tabulate( + latency_results, headers="keys", tablefmt="pipe", showindex=False + ) + serving_md_table = tabulate( + serving_results, headers="keys", tablefmt="pipe", showindex=False + ) + throughput_md_table = tabulate( + throughput_results, headers="keys", tablefmt="pipe", showindex=False + ) # document the result with open(results_folder / "benchmark_results.md", "w") as f: - - results = read_markdown("../.buildkite/nightly-benchmarks/" + - "performance-benchmarks-descriptions.md") + results = read_markdown( + "../.buildkite/nightly-benchmarks/" + + "performance-benchmarks-descriptions.md" + ) results = results.format( latency_tests_markdown_table=latency_md_table, throughput_tests_markdown_table=throughput_md_table, serving_tests_markdown_table=serving_md_table, - benchmarking_results_in_json_string=processed_results_json) + benchmarking_results_in_json_string=processed_results_json, + ) f.write(results) # document benchmarking results in json with open(results_folder / "benchmark_results.json", "w") as f: - - results = latency_results.to_dict( - orient='records') + throughput_results.to_dict( - orient='records') + serving_results.to_dict(orient='records') + results = ( + latency_results.to_dict(orient="records") + + throughput_results.to_dict(orient="records") + + serving_results.to_dict(orient="records") + ) f.write(json.dumps(results)) diff --git a/.buildkite/nightly-benchmarks/scripts/download-tokenizer.py b/.buildkite/nightly-benchmarks/scripts/download-tokenizer.py index 5e17b79d26a1..778a3a8d87f6 100644 
--- a/.buildkite/nightly-benchmarks/scripts/download-tokenizer.py +++ b/.buildkite/nightly-benchmarks/scripts/download-tokenizer.py @@ -14,15 +14,12 @@ def main(model, cachedir): if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Download and save Hugging Face tokenizer") - parser.add_argument("--model", - type=str, - required=True, - help="Name of the model") - parser.add_argument("--cachedir", - type=str, - required=True, - help="Directory to save the tokenizer") + description="Download and save Hugging Face tokenizer" + ) + parser.add_argument("--model", type=str, required=True, help="Name of the model") + parser.add_argument( + "--cachedir", type=str, required=True, help="Directory to save the tokenizer" + ) args = parser.parse_args() main(args.model, args.cachedir) diff --git a/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py b/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py index 0ff95a0911b1..10a7a2f5a467 100644 --- a/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py +++ b/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py @@ -11,33 +11,33 @@ def parse_arguments(): parser = argparse.ArgumentParser( - description= - 'Parse command line arguments for summary-nightly-results script.') - parser.add_argument('--results-folder', - type=str, - required=True, - help='The folder where the results are stored.') - parser.add_argument('--description', - type=str, - required=True, - help='Description of the results.') + description="Parse command line arguments for summary-nightly-results script." + ) + parser.add_argument( + "--results-folder", + type=str, + required=True, + help="The folder where the results are stored.", + ) + parser.add_argument( + "--description", type=str, required=True, help="Description of the results." 
+ ) args = parser.parse_args() return args def get_perf(df, method, model, metric): - means = [] for qps in [2, 4, 8, 16, "inf"]: - target = df['Test name'].str.contains(model) - target = target & df['Engine'].str.contains(method) - target = target & df['Test name'].str.contains("qps_" + str(qps)) + target = df["Test name"].str.contains(model) + target = target & df["Engine"].str.contains(method) + target = target & df["Test name"].str.contains("qps_" + str(qps)) filtered_df = df[target] if filtered_df.empty: - means.append(0.) + means.append(0.0) else: means.append(filtered_df[metric].values[0]) @@ -45,7 +45,6 @@ def get_perf(df, method, model, metric): def get_perf_w_std(df, method, model, metric): - if metric in ["TTFT", "ITL"]: mean = get_perf(df, method, model, "Mean " + metric + " (ms)") mean = mean.tolist() @@ -60,7 +59,8 @@ def get_perf_w_std(df, method, model, metric): else: assert metric == "Tput" mean = get_perf(df, method, model, "Input Tput (tok/s)") + get_perf( - df, method, model, "Output Tput (tok/s)") + df, method, model, "Output Tput (tok/s)" + ) mean = mean.tolist() std = None @@ -80,18 +80,17 @@ def main(args): # generate markdown table df = pd.DataFrame.from_dict(results) - md_table = tabulate(df, headers='keys', tablefmt='pipe', showindex=False) + md_table = tabulate(df, headers="keys", tablefmt="pipe", showindex=False) with open(args.description) as f: description = f.read() - description = description.format( - nightly_results_benchmarking_table=md_table) + description = description.format(nightly_results_benchmarking_table=md_table) with open("nightly_results.md", "w") as f: f.write(description) -if __name__ == '__main__': +if __name__ == "__main__": args = parse_arguments() main(args) diff --git a/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py b/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py index 62ee5e10b509..2a7b37991f31 100644 --- a/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py +++ 
b/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py @@ -34,10 +34,8 @@ } if __name__ == "__main__": - # collect results for test_file in results_folder.glob("*.json"): - with open(test_file) as f: raw_result = json.loads(f.read()) @@ -56,17 +54,16 @@ serving_results = pd.DataFrame.from_dict(serving_results) if not serving_results.empty: - serving_results = serving_results[list( - serving_column_mapping.keys())].rename( - columns=serving_column_mapping) + serving_results = serving_results[list(serving_column_mapping.keys())].rename( + columns=serving_column_mapping + ) - serving_md_table_with_headers = tabulate(serving_results, - headers='keys', - tablefmt='pipe', - showindex=False) + serving_md_table_with_headers = tabulate( + serving_results, headers="keys", tablefmt="pipe", showindex=False + ) # remove the first line of header - serving_md_table_lines = serving_md_table_with_headers.split('\n') - serving_md_table_without_header = '\n'.join(serving_md_table_lines[2:]) + serving_md_table_lines = serving_md_table_with_headers.split("\n") + serving_md_table_without_header = "\n".join(serving_md_table_lines[2:]) prefix = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") prefix = prefix + "_" + os.environ.get("CURRENT_LLM_SERVING_ENGINE") @@ -76,10 +73,9 @@ # document results with header. # for those who wants to reproduce our benchmark. f.write(serving_md_table_with_headers) - f.write('\n') + f.write("\n") # document benchmarking results in json with open(results_folder / f"{prefix}_nightly_results.json", "w") as f: - - results = serving_results.to_dict(orient='records') + results = serving_results.to_dict(orient="records") f.write(json.dumps(results)) diff --git a/.buildkite/pyproject.toml b/.buildkite/pyproject.toml new file mode 100644 index 000000000000..d5cad1c73c6f --- /dev/null +++ b/.buildkite/pyproject.toml @@ -0,0 +1,46 @@ +# This local pyproject file is part of the migration from yapf to ruff format. 
+# It uses the same core rules as the main pyproject.toml file, but with the +# following differences: +# - ruff line length is overridden to 88 +# - deprecated typing ignores (UP006, UP035) have been removed + +[tool.ruff] +line-length = 88 + +[tool.ruff.lint.per-file-ignores] +"vllm/third_party/**" = ["ALL"] +"vllm/version.py" = ["F401"] +"vllm/_version.py" = ["ALL"] + +[tool.ruff.lint] +select = [ + # pycodestyle + "E", + # Pyflakes + "F", + # pyupgrade + "UP", + # flake8-bugbear + "B", + # flake8-simplify + "SIM", + # isort + "I", + # flake8-logging-format + "G", +] +ignore = [ + # star imports + "F405", "F403", + # lambda expression assignment + "E731", + # Loop control variable not used within loop body + "B007", + # f-string format + "UP032", + # Can remove once 3.10+ is the minimum Python version + "UP007", +] + +[tool.ruff.format] +docstring-code-format = true diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 4cc9c70a6adb..b3c27e2c99c2 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -14,7 +14,7 @@ steps: agents: queue: cpu_queue_postmerge commands: - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.6.3 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.6.3 --build-arg torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." 
- "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - "bash .buildkite/scripts/upload-wheels.sh" @@ -31,7 +31,7 @@ steps: agents: queue: cpu_queue_postmerge commands: - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=11.8.0 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=11.8.0 --build-arg torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." - "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - "bash .buildkite/scripts/upload-wheels.sh" @@ -64,7 +64,7 @@ steps: - "docker push vllm/vllm-tpu:$BUILDKITE_COMMIT" plugins: - docker-login#v3.0.0: - username: vllm + username: vllmbot password-env: DOCKERHUB_TOKEN env: DOCKER_BUILDKIT: "1" diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh index d29903bf497f..bbc896ec6819 100755 --- a/.buildkite/scripts/hardware_ci/run-amd-test.sh +++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh @@ -3,6 +3,9 @@ # This script runs test inside the corresponding ROCm docker container. set -o pipefail +# Export Python path +export PYTHONPATH=".." 
+ # Print ROCm version echo "--- Confirming Clean Initial State" while true; do @@ -74,6 +77,23 @@ HF_MOUNT="/root/.cache/huggingface" commands=$@ echo "Commands:$commands" + +if [[ $commands == *"pytest -v -s basic_correctness/test_basic_correctness.py"* ]]; then + commands=${commands//"pytest -v -s basic_correctness/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s basic_correctness/test_basic_correctness.py"} +fi + +if [[ $commands == *"pytest -v -s models/test_registry.py"* ]]; then + commands=${commands//"pytest -v -s models/test_registry.py"/"pytest -v -s models/test_registry.py -k 'not BambaForCausalLM and not GritLM and not Mamba2ForCausalLM and not Zamba2ForCausalLM'"} +fi + +if [[ $commands == *"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'"* ]]; then + commands=${commands//"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'"/"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2 and not BambaForCausalLM and not Gemma2ForCausalLM and not Grok1ModelForCausalLM and not Zamba2ForCausalLM and not Gemma2Model and not GritLM'"} +fi + +if [[ $commands == *"pytest -v -s compile/test_basic_correctness.py"* ]]; then + commands=${commands//"pytest -v -s compile/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s compile/test_basic_correctness.py"} +fi + #ignore certain kernels tests if [[ $commands == *" kernels/core"* ]]; then commands="${commands} \ @@ -161,6 +181,8 @@ fi PARALLEL_JOB_COUNT=8 +MYPYTHONPATH=".." + # check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs. 
if [[ $commands == *"--shard-id="* ]]; then # assign job count as the number of shards used @@ -181,6 +203,7 @@ if [[ $commands == *"--shard-id="* ]]; then -e AWS_SECRET_ACCESS_KEY \ -v "${HF_CACHE}:${HF_MOUNT}" \ -e "HF_HOME=${HF_MOUNT}" \ + -e "PYTHONPATH=${MYPYTHONPATH}" \ --name "${container_name}_${GPU}" \ "${image_name}" \ /bin/bash -c "${commands_gpu}" \ @@ -211,6 +234,7 @@ else -e AWS_SECRET_ACCESS_KEY \ -v "${HF_CACHE}:${HF_MOUNT}" \ -e "HF_HOME=${HF_MOUNT}" \ + -e "PYTHONPATH=${MYPYTHONPATH}" \ --name "${container_name}" \ "${image_name}" \ /bin/bash -c "${commands}" diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh b/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh index 5d863dd82e9b..077bd9914907 100755 --- a/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh @@ -32,9 +32,12 @@ function cpu_tests() { set -e pip install pytest pytest-asyncio einops peft Pillow soundfile transformers_stream_generator matplotlib pip install sentence-transformers datamodel_code_generator - pytest -v -s tests/models/embedding/language/test_cls_models.py::test_classification_models[float-jason9693/Qwen2.5-1.5B-apeach] - pytest -v -s tests/models/embedding/language/test_embedding.py::test_models[half-BAAI/bge-base-en-v1.5] - pytest -v -s tests/models/encoder_decoder/language -m cpu_model" + pytest -v -s tests/models/language/generation/test_bart.py -m cpu_model + pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-openai-community/gpt2] + pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-facebook/opt-125m] + pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-google/gemma-1.1-2b-it] + pytest -v -s tests/models/language/pooling/test_classification.py::test_models[float-jason9693/Qwen2.5-1.5B-apeach] + pytest -v -s 
tests/models/language/pooling/test_embedding.py::test_models[half-BAAI/bge-base-en-v1.5]" } # All of CPU tests are expected to be finished less than 40 mins. diff --git a/.buildkite/scripts/hardware_ci/run-hpu-test.sh b/.buildkite/scripts/hardware_ci/run-hpu-test.sh index 95b6ac37f185..5efac3ddf469 100644 --- a/.buildkite/scripts/hardware_ci/run-hpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-hpu-test.sh @@ -10,15 +10,17 @@ docker build -t hpu-test-env -f docker/Dockerfile.hpu . # Setup cleanup # certain versions of HPU software stack have a bug that can # override the exit code of the script, so we need to use -# separate remove_docker_container and remove_docker_container_and_exit +# separate remove_docker_containers and remove_docker_containers_and_exit # functions, while other platforms only need one remove_docker_container # function. EXITCODE=1 -remove_docker_container() { docker rm -f hpu-test || true; } -remove_docker_container_and_exit() { remove_docker_container; exit $EXITCODE; } -trap remove_docker_container_and_exit EXIT -remove_docker_container +remove_docker_containers() { docker rm -f hpu-test || true; docker rm -f hpu-test-tp2 || true; } +remove_docker_containers_and_exit() { remove_docker_containers; exit $EXITCODE; } +trap remove_docker_containers_and_exit EXIT +remove_docker_containers # Run the image and launch offline inference docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m +docker run --runtime=habana --name=hpu-test-tp2 --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --tensor-parallel-size 2 + EXITCODE=$? 
diff --git a/.buildkite/scripts/hardware_ci/run-neuron-test.sh b/.buildkite/scripts/hardware_ci/run-neuron-test.sh index ec6a080eb499..3d294ea5f8a7 100644 --- a/.buildkite/scripts/hardware_ci/run-neuron-test.sh +++ b/.buildkite/scripts/hardware_ci/run-neuron-test.sh @@ -11,13 +11,14 @@ container_name="neuron_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)" HF_CACHE="$(realpath ~)/huggingface" mkdir -p "${HF_CACHE}" HF_MOUNT="/root/.cache/huggingface" +HF_TOKEN=$(aws secretsmanager get-secret-value --secret-id "ci/vllm-neuron/hf-token" --region us-west-2 --query 'SecretString' --output text | jq -r .VLLM_NEURON_CI_HF_TOKEN) NEURON_COMPILE_CACHE_URL="$(realpath ~)/neuron_compile_cache" mkdir -p "${NEURON_COMPILE_CACHE_URL}" NEURON_COMPILE_CACHE_MOUNT="/root/.cache/neuron_compile_cache" # Try building the docker image -aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-west-2.amazonaws.com +aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws # prune old image and containers to save disk space, and only once a day # by using a timestamp file in tmp. 
@@ -47,8 +48,16 @@ trap remove_docker_container EXIT docker run --rm -it --device=/dev/neuron0 --network bridge \ -v "${HF_CACHE}:${HF_MOUNT}" \ -e "HF_HOME=${HF_MOUNT}" \ + -e "HF_TOKEN=${HF_TOKEN}" \ -v "${NEURON_COMPILE_CACHE_URL}:${NEURON_COMPILE_CACHE_MOUNT}" \ -e "NEURON_COMPILE_CACHE_URL=${NEURON_COMPILE_CACHE_MOUNT}" \ --name "${container_name}" \ ${image_name} \ - /bin/bash -c "python3 /workspace/vllm/examples/offline_inference/neuron.py && python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys && python3 -m pytest /workspace/vllm/tests/neuron/2_core/ -v --capture=tee-sys" + /bin/bash -c " + python3 /workspace/vllm/examples/offline_inference/neuron.py; + python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys; + for f in /workspace/vllm/tests/neuron/2_core/*.py; do + echo 'Running test file: '$f; + python3 -m pytest \$f -v --capture=tee-sys; + done + " \ No newline at end of file diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh index 939daddad92b..2d375d7e9d87 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh @@ -26,27 +26,27 @@ docker run --privileged --net host --shm-size=16G -it \ && tpu-info \ && { \ echo TEST_0: Running test_perf.py; \ - pytest -s -v /workspace/vllm/tests/tpu/test_perf.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_perf.py; \ echo TEST_0_EXIT_CODE: \$?; \ } & \ - && { \ + { \ echo TEST_1: Running test_compilation.py; \ - pytest -s -v /workspace/vllm/tests/tpu/test_compilation.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_compilation.py; \ echo TEST_1_EXIT_CODE: \$?; \ } & \ { \ echo TEST_2: Running test_basic.py; \ - pytest -s -v /workspace/vllm/tests/v1/tpu/test_basic.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_basic.py; \ echo TEST_2_EXIT_CODE: \$?; \ } & \ { \ echo TEST_3: Running 
test_accuracy.py::test_lm_eval_accuracy_v1_engine; \ - pytest -s -v /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine; \ + python3 -m pytest -s -v /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine; \ echo TEST_3_EXIT_CODE: \$?; \ } & \ { \ echo TEST_4: Running test_quantization_accuracy.py; \ - pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py; \ echo TEST_4_EXIT_CODE: \$?; \ } & \ { \ @@ -56,43 +56,43 @@ docker run --privileged --net host --shm-size=16G -it \ } & \ { \ echo TEST_6: Running test_tpu_model_runner.py; \ - pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py; \ echo TEST_6_EXIT_CODE: \$?; \ } & \ - && { \ + { \ echo TEST_7: Running test_sampler.py; \ - pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py; \ echo TEST_7_EXIT_CODE: \$?; \ } & \ - && { \ + { \ echo TEST_8: Running test_topk_topp_sampler.py; \ - pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py; \ echo TEST_8_EXIT_CODE: \$?; \ } & \ - && { \ + { \ echo TEST_9: Running test_multimodal.py; \ - pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py; \ echo TEST_9_EXIT_CODE: \$?; \ } & \ - && { \ + { \ echo TEST_10: Running test_pallas.py; \ - pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py; \ echo TEST_10_EXIT_CODE: \$?; \ } & \ - && { \ + { \ echo TEST_11: Running test_struct_output_generate.py; \ - pytest -s -v 
/workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py; \ echo TEST_11_EXIT_CODE: \$?; \ } & \ - && { \ + { \ echo TEST_12: Running test_moe_pallas.py; \ - pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py; \ echo TEST_12_EXIT_CODE: \$?; \ } & \ # Disable the TPU LoRA tests until the feature is activated - # && { \ + # & { \ # echo TEST_13: Running test_moe_pallas.py; \ - # pytest -s -v /workspace/vllm/tests/tpu/lora/; \ + # python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/; \ # echo TEST_13_EXIT_CODE: \$?; \ # } & \ wait \ diff --git a/.buildkite/scripts/upload-wheels.sh b/.buildkite/scripts/upload-wheels.sh index 75e3ef264095..037897e53dbe 100644 --- a/.buildkite/scripts/upload-wheels.sh +++ b/.buildkite/scripts/upload-wheels.sh @@ -75,3 +75,4 @@ else fi aws s3 cp "$wheel" "s3://vllm-wheels/$version/" +aws s3 cp index.html "s3://vllm-wheels/$version/vllm/index.html" diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 01d04759f536..80a5a610c8ac 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -32,16 +32,17 @@ steps: ##### fast check tests ##### - label: Documentation Build # 2min - working_dir: "/vllm-workspace/test_docs/docs" + mirror_hardwares: [amdexperimental] + working_dir: "/vllm-workspace/test_docs" fast_check: true no_gpu: True commands: - - pip install -r ../../requirements/docs.txt - - SPHINXOPTS=\"-W\" make html - # Check API reference (if it fails, you may have missing mock imports) - - grep \"sig sig-object py\" build/html/api/vllm/vllm.sampling_params.html + - pip install -r ../requirements/docs.txt + # TODO: add `--strict` once warnings in docstrings are fixed + - mkdocs build - label: Async Engine, Inputs, Utils, Worker Test # 24min + mirror_hardwares: [amdexperimental] 
source_file_dependencies: - vllm/ - tests/mq_llm_engine @@ -57,11 +58,13 @@ steps: - pytest -v -s async_engine # AsyncLLMEngine - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py - pytest -v -s test_inputs.py + - pytest -v -s test_outputs.py - pytest -v -s multimodal - pytest -v -s test_utils.py # Utils - pytest -v -s worker # Worker - label: Python-only Installation Test + mirror_hardwares: [amdexperimental] source_file_dependencies: - tests/standalone_tests/python_only_compile.sh - setup.py @@ -69,7 +72,7 @@ steps: - bash standalone_tests/python_only_compile.sh - label: Basic Correctness Test # 30min - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] fast_check: true torch_nightly: true source_file_dependencies: @@ -86,6 +89,7 @@ steps: - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py - label: Chunked Prefill Test + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/basic_correctness/test_chunked_prefill @@ -94,7 +98,7 @@ steps: - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - label: Core Test # 10min - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] fast_check: true source_file_dependencies: - vllm/core @@ -104,10 +108,10 @@ steps: - pytest -v -s core - label: Entrypoints Test # 40min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" fast_check: true torch_nightly: true - #mirror_hardwares: [amd] source_file_dependencies: - vllm/ - tests/entrypoints/llm @@ -121,11 +125,12 @@ steps: - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process - VLLM_USE_V1=0 pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process - - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py 
--ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_openai_schema.py + - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ - pytest -v -s entrypoints/test_chat_utils.py - VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests - label: Distributed Tests (4 GPUs) # 10min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 4 source_file_dependencies: @@ -133,6 +138,7 @@ steps: - vllm/core/ - tests/distributed/test_utils - tests/distributed/test_pynccl + - tests/distributed/test_events - tests/spec_decode/e2e/test_integration_dist_tp4 - tests/compile/test_basic_correctness - examples/offline_inference/rlhf.py @@ -143,22 +149,25 @@ steps: # test with tp=2 and external_dp=2 - VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py + # test with tp=2 and pp=2 + - PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py # test with internal dp - python3 ../examples/offline_inference/data_parallel.py - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py - pytest -v -s distributed/test_utils.py - pytest -v -s compile/test_basic_correctness.py - pytest -v -s distributed/test_pynccl.py + - pytest -v -s distributed/test_events.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py # TODO: create a dedicated test section for multi-GPU example tests # when we have multiple distributed example tests - pushd ../examples/offline_inference - - python3 rlhf.py - - RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 
python3 rlhf_colocate.py - popd - label: Metrics, Tracing Test # 10min - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] num_gpus: 2 source_file_dependencies: - vllm/ @@ -172,7 +181,7 @@ steps: ##### 1 GPU test ##### - label: Regression Test # 5min - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - vllm/ - tests/test_regression @@ -182,7 +191,7 @@ steps: working_dir: "/vllm-workspace/tests" # optional - label: Engine Test # 10min - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - vllm/ - tests/engine @@ -196,7 +205,7 @@ steps: - pytest -v -s tokenization - label: V1 Test - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/v1 @@ -209,10 +218,11 @@ steps: - pytest -v -s v1/worker - pytest -v -s v1/structured_output - pytest -v -s v1/spec_decode + - pytest -v -s v1/kv_connector/unit - pytest -v -s v1/test_serial_utils.py - - pytest -v -s v1/test_stats.py - pytest -v -s v1/test_utils.py - pytest -v -s v1/test_oracle.py + - pytest -v -s v1/test_metrics_reader.py # TODO: accuracy does not match, whether setting # VLLM_USE_FLASHINFER_SAMPLER or not on H100. 
- pytest -v -s v1/e2e @@ -221,8 +231,8 @@ steps: - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine - label: Examples Test # 25min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/examples" - #mirror_hardwares: [amd] source_file_dependencies: - vllm/entrypoints - examples/ @@ -237,7 +247,7 @@ steps: - python3 offline_inference/vision_language.py --seed 0 - python3 offline_inference/vision_language_embedding.py --seed 0 - python3 offline_inference/vision_language_multi_image.py --seed 0 - - VLLM_USE_V1=0 python3 other/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 other/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors + - VLLM_USE_V1=0 python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 offline_inference/encoder_decoder.py - python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 - python3 offline_inference/basic/classify.py @@ -246,7 +256,7 @@ steps: - VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2 - label: Prefix Caching Test # 9min - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - vllm/ - tests/prefix_caching @@ -254,6 +264,7 @@ steps: - pytest -v -s prefix_caching - label: Samplers Test # 36min + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/model_executor/layers - vllm/sampling_metadata.py @@ -264,7 +275,7 @@ steps: - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers - label: LogitsProcessor Test # 5min - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] 
source_file_dependencies: - vllm/model_executor/layers - vllm/model_executor/guided_decoding @@ -275,6 +286,7 @@ steps: - pytest -v -s model_executor/test_guided_processors.py - label: Speculative decoding tests # 40min + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/spec_decode - tests/spec_decode @@ -285,7 +297,7 @@ steps: - pytest -v -s spec_decode/e2e/test_eagle_correctness.py - label: LoRA Test %N # 15min each - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/lora - tests/lora @@ -293,6 +305,7 @@ steps: parallelism: 4 - label: PyTorch Compilation Unit Tests + mirror_hardwares: [amdexperimental, amdproduction] torch_nightly: true source_file_dependencies: - vllm/ @@ -300,9 +313,12 @@ steps: commands: - pytest -v -s compile/test_pass_manager.py - pytest -v -s compile/test_fusion.py + - pytest -v -s compile/test_silu_mul_quant_fusion.py - pytest -v -s compile/test_sequence_parallelism.py + - pytest -v -s compile/test_async_tp.py - label: PyTorch Fullgraph Smoke Test # 9min + mirror_hardwares: [amdexperimental, amdproduction] torch_nightly: true source_file_dependencies: - vllm/ @@ -314,6 +330,7 @@ steps: - pytest -v -s compile/piecewise/test_toy_llama.py - label: PyTorch Fullgraph Test # 18min + mirror_hardwares: [amdexperimental, amdproduction] torch_nightly: true source_file_dependencies: - vllm/ @@ -322,7 +339,7 @@ steps: - pytest -v -s compile/test_full_graph.py - label: Kernels Core Operation Test - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - csrc/ - tests/kernels/core @@ -330,7 +347,7 @@ steps: - pytest -v -s kernels/core - label: Kernels Attention Test %N - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - csrc/attention/ - vllm/attention @@ -341,7 +358,7 @@ steps: parallelism: 2 - label: Kernels Quantization Test %N - mirror_hardwares: [amd] + mirror_hardwares: 
[amdexperimental, amdproduction] source_file_dependencies: - csrc/quantization/ - vllm/model_executor/layers/quantization @@ -351,7 +368,7 @@ steps: parallelism: 2 - label: Kernels MoE Test - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/moe/ - tests/kernels/moe @@ -360,7 +377,7 @@ steps: - pytest -v -s kernels/moe - label: Kernels Mamba Test - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/mamba/ - tests/kernels/mamba @@ -368,25 +385,28 @@ steps: - pytest -v -s kernels/mamba - label: Tensorizer Test # 11min - # mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] soft_fail: true source_file_dependencies: - vllm/model_executor/model_loader - tests/tensorizer_loader + - tests/entrypoints/openai/test_tensorizer_entrypoint.py commands: - apt-get update && apt-get install -y curl libsodium23 - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -v -s tensorizer_loader + - pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py - label: Benchmarks # 9min + mirror_hardwares: [amdexperimental, amdproduction] working_dir: "/vllm-workspace/.buildkite" - mirror_hardwares: [amd] source_file_dependencies: - benchmarks/ commands: - bash scripts/run-benchmarks.sh - label: Benchmarks CLI Test # 10min + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - vllm/ - tests/benchmarks/ @@ -394,6 +414,7 @@ steps: - pytest -v -s benchmarks/ - label: Quantization Test + mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/ - vllm/model_executor/layers/quantization @@ -402,6 +423,7 @@ steps: - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization - label: LM Eval Small Models # 53min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" source_file_dependencies: - csrc/ @@ -411,6 +433,7 @@ steps: - pytest -s -v test_lm_eval_correctness.py 
--config-list-file=configs/models-small.txt --tp-size=1 - label: OpenAI API correctness + mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/ - vllm/entrypoints/openai/ @@ -419,6 +442,7 @@ steps: - pytest -s entrypoints/openai/correctness/ - label: Encoder Decoder tests # 5min + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/encoder_decoder @@ -426,8 +450,8 @@ steps: - pytest -v -s encoder_decoder - label: OpenAI-Compatible Tool Use # 20 min + mirror_hardwares: [amdexperimental] fast_check: false - #mirror_hardwares: [ amd ] source_file_dependencies: - vllm/ - tests/tool_use @@ -439,6 +463,7 @@ steps: ##### models test ##### - label: Basic Models Test # 24min + mirror_hardwares: [amdexperimental, amdproduction] torch_nightly: true source_file_dependencies: - vllm/ @@ -448,43 +473,55 @@ steps: - pytest -v -s models/test_registry.py - pytest -v -s models/test_utils.py - pytest -v -s models/test_vision.py - # V1 Test: https://github.com/vllm-project/vllm/issues/14531 - - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2' - - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4' - - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'plamo2' + - pytest -v -s models/test_initialization.py - label: Language Models Test (Standard) - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental] + torch_nightly: true source_file_dependencies: - vllm/ - tests/models/language commands: # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile. 
- pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8' + - pip freeze | grep -E 'torch' - pytest -v -s models/language -m core_model -- label: Language Models Test (Extended) +- label: Language Models Test (Extended Generation) # 1hr20min + mirror_hardwares: [amdexperimental] optional: true source_file_dependencies: - vllm/ - - tests/models/language + - tests/models/language/generation commands: # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile. - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8' - - pytest -v -s models/language -m 'not core_model' + - pytest -v -s models/language/generation -m 'not core_model' + +- label: Language Models Test (Extended Pooling) # 36min + mirror_hardwares: [amdexperimental] + optional: true + source_file_dependencies: + - vllm/ + - tests/models/language/pooling + commands: + - pytest -v -s models/language/pooling -m 'not core_model' - label: Multi-Modal Models Test (Standard) - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental] + torch_nightly: true source_file_dependencies: - vllm/ - tests/models/multimodal commands: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pip freeze | grep -E 'torch' - pytest -v -s models/multimodal/processing - pytest -v -s --ignore models/multimodal/generation/test_whisper.py models/multimodal -m core_model - cd .. 
&& pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work - label: Multi-Modal Models Test (Extended) 1 + mirror_hardwares: [amdexperimental] optional: true source_file_dependencies: - vllm/ @@ -494,6 +531,7 @@ steps: - pytest -v -s --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing models/multimodal -m 'not core_model' - label: Multi-Modal Models Test (Extended) 2 + mirror_hardwares: [amdexperimental] optional: true source_file_dependencies: - vllm/ @@ -503,6 +541,7 @@ steps: - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model' - label: Multi-Modal Models Test (Extended) 3 + mirror_hardwares: [amdexperimental, amdproduction] optional: true source_file_dependencies: - vllm/ @@ -512,7 +551,7 @@ steps: - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=1) and not core_model' - label: Quantized Models Test - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - vllm/model_executor/layers/quantization - tests/models/quantization @@ -521,7 +560,7 @@ steps: # This test is used only in PR development phase to test individual models and should never run on main - label: Custom Models Test - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] optional: true commands: - echo 'Testing custom models...' 
@@ -533,7 +572,7 @@ steps: ##### multi gpus test ##### - label: Distributed Comm Ops Test # 7min - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] working_dir: "/vllm-workspace/tests" num_gpus: 2 source_file_dependencies: @@ -544,6 +583,7 @@ steps: - pytest -v -s distributed/test_shm_broadcast.py - label: 2 Node Tests (4 GPUs in total) # 16min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 num_nodes: 2 @@ -562,7 +602,7 @@ steps: - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' - label: Distributed Tests (2 GPUs) # 40min - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 source_file_dependencies: @@ -599,13 +639,14 @@ steps: - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown - label: Plugin Tests (2 GPUs) # 40min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 source_file_dependencies: - vllm/plugins/ - tests/plugins/ commands: - # begin platform plugin tests, all the code in-between runs on dummy platform + # begin platform plugin and general plugin tests, all the code in-between runs on dummy platform - pip install -e ./plugins/vllm_add_dummy_platform - pytest -v -s plugins_tests/test_platform_plugins.py - pip uninstall vllm_add_dummy_platform -y @@ -616,8 +657,10 @@ steps: - pytest -v -s distributed/test_distributed_oot.py - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process - pytest -v -s models/test_oot_registration.py # it needs a clean process + - pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins - label: Multi-step Tests (4 GPUs) # 36min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 4 source_file_dependencies: @@ -638,6 +681,7 @@ steps: - pytest -v -s 
multi_step/test_correctness_llm.py - label: Pipeline Parallelism Test # 45min + mirror_hardwares: [amdexperimental, amdproduction] working_dir: "/vllm-workspace/tests" num_gpus: 4 source_file_dependencies: @@ -651,6 +695,7 @@ steps: - pytest -v -s distributed/test_pipeline_parallel.py - label: LoRA TP Test (Distributed) + mirror_hardwares: [amdexperimental, amdproduction] num_gpus: 4 source_file_dependencies: - vllm/lora @@ -666,6 +711,7 @@ steps: - label: Weight Loading Multiple GPU Test # 33min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 source_file_dependencies: @@ -675,6 +721,7 @@ steps: - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt - label: Weight Loading Multiple GPU Test - Large Models # optional + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 gpu: a100 diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 76aa5f7a35d5..4452ce22d504 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -13,6 +13,7 @@ /vllm/model_executor/guided_decoding @mgoin @russellb /vllm/multimodal @DarkLight1337 @ywang96 /vllm/vllm_flash_attn @LucasWilkinson +/vllm/lora @jeejeelee CMakeLists.txt @tlrmchlsmth # vLLM V1 @@ -40,3 +41,8 @@ CMakeLists.txt @tlrmchlsmth /tests/v1/entrypoints/llm/test_struct_output_generate.py @mgoin @russellb /tests/v1/structured_output @mgoin @russellb /tests/weight_loading @mgoin @youkaichao +/tests/lora @jeejeelee + +# Docs +/docs @hmellor +mkdocs.yaml @hmellor \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/400-bug-report.yml b/.github/ISSUE_TEMPLATE/400-bug-report.yml index 00b0f024c0da..f05be2ba8707 100644 --- a/.github/ISSUE_TEMPLATE/400-bug-report.yml +++ b/.github/ISSUE_TEMPLATE/400-bug-report.yml @@ -81,14 +81,14 @@ body: required: true - type: markdown attributes: - value: > - ⚠️ Please separate bugs of `transformers` implementation or usage from bugs of `vllm`.
If you think anything is wrong with the models' output: + value: | + ⚠️ Please separate bugs of `transformers` implementation or usage from bugs of `vllm`. If you think anything is wrong with the model's output: - Try the counterpart of `transformers` first. If the error appears, please go to [their issues](https://github.com/huggingface/transformers/issues?q=is%3Aissue+is%3Aopen+sort%3Aupdated-desc). - If the error only appears in vllm, please provide the detailed script of how you run `transformers` and `vllm`, also highlight the difference and what you expect. - Thanks for contributing 🎉! + Thanks for reporting 🙏! - type: checkboxes id: askllm attributes: diff --git a/.github/ISSUE_TEMPLATE/450-ci-failure.yml b/.github/ISSUE_TEMPLATE/450-ci-failure.yml new file mode 100644 index 000000000000..7af0e0673a2f --- /dev/null +++ b/.github/ISSUE_TEMPLATE/450-ci-failure.yml @@ -0,0 +1,69 @@ +name: 🧪 CI failure report +description: Report a failing test. +title: "[CI Failure]: " +labels: ["ci-failure"] + +body: +- type: markdown + attributes: + value: > + #### Include the name of the failing Buildkite step and test file in the title. +- type: input + attributes: + label: Name of failing test + description: | + Paste in the fully-qualified name of the failing test from the logs. + placeholder: | + `path/to/test_file.py::test_name[params]` + validations: + required: true +- type: checkboxes + attributes: + label: Basic information + description: Select all items that apply to the failing test. + options: + - label: Flaky test + - label: Can reproduce locally + - label: Caused by external libraries (e.g. bug in `transformers`) +- type: textarea + attributes: + label: 🧪 Describe the failing test + description: | + Please provide a clear and concise description of the failing test. + placeholder: | + A clear and concise description of the failing test.
+ + ``` + The error message you got, with the full traceback and the error logs with [dump_input.py:##] if present. + ``` + validations: + required: true +- type: textarea + attributes: + label: 📝 History of failing test + description: | + Since when did the test start to fail? + You can look up its history via [Buildkite Test Suites](https://buildkite.com/organizations/vllm/analytics/suites/ci-1/tests?branch=main). + + If you have time, identify the PR that caused the test to fail on main. You can do so via the following methods: + + - Use Buildkite Test Suites to find the PR where the test failure first occurred, and reproduce the failure locally. + + - Run [`git bisect`](https://git-scm.com/docs/git-bisect) locally. + + - Manually unblock Buildkite steps for suspected PRs on main and check the results. (authorized users only) + placeholder: | + Approximate timeline and/or problematic PRs + + A link to the Buildkite analytics of the failing test (if available) + validations: + required: true +- type: textarea + attributes: + label: CC List. + description: > + The list of people you want to CC. Usually, this includes those who worked on the PR that failed the test. +- type: markdown + attributes: + value: > + Thanks for reporting 🙏!
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 7042e81a84da..65be771b94fb 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -3,4 +3,4 @@ FILL IN THE PR DESCRIPTION HERE FIX #xxxx (*link existing issues this PR will resolve*) -**BEFORE SUBMITTING, PLEASE READ ** (anything written below this line will be removed by GitHub Actions) +**BEFORE SUBMITTING, PLEASE READ ** (anything written below this line will be removed by GitHub Actions) diff --git a/.github/mergify.yml b/.github/mergify.yml index 15fa3660a87d..e595060c325a 100644 --- a/.github/mergify.yml +++ b/.github/mergify.yml @@ -58,7 +58,7 @@ pull_request_rules: - files~=^benchmarks/structured_schemas/ - files=benchmarks/benchmark_serving_structured_output.py - files=benchmarks/run_structured_output_benchmark.sh - - files=docs/source/features/structured_outputs.md + - files=docs/features/structured_outputs.md - files=examples/offline_inference/structured_outputs.py - files=examples/online_serving/openai_chat_completion_structured_outputs.py - files=examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py @@ -135,9 +135,7 @@ pull_request_rules: - files~=^tests/entrypoints/openai/tool_parsers/ - files=tests/entrypoints/openai/test_chat_with_tool_reasoning.py - files~=^vllm/entrypoints/openai/tool_parsers/ - - files=docs/source/features/tool_calling.md - - files=docs/source/getting_started/examples/openai_chat_completion_client_with_tools.md - - files=docs/source/getting_started/examples/chat_with_tools.md + - files=docs/features/tool_calling.md - files~=^examples/tool_chat_* - files=examples/offline_inference/chat_with_tools.py - files=examples/online_serving/openai_chat_completion_client_with_tools_required.py @@ -163,6 +161,17 @@ pull_request_rules: https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork +- name: assign reviewer for tensorizer changes + 
conditions: + - files~=^vllm/model_executor/model_loader/tensorizer.py + - files~=^vllm/model_executor/model_loader/tensorizer_loader.py + - files~=^tests/entrypoints/openai/test_tensorizer_entrypoint.py + - files~=^tests/tensorizer_loader/ + actions: + assign: + users: + - "sangstar" + - name: remove 'needs-rebase' label when conflict is resolved conditions: - -conflict diff --git a/.github/scripts/cleanup_pr_body.sh b/.github/scripts/cleanup_pr_body.sh index 3246c6f9bc4b..8d65936fba1d 100755 --- a/.github/scripts/cleanup_pr_body.sh +++ b/.github/scripts/cleanup_pr_body.sh @@ -26,7 +26,7 @@ sed -i '/\*\*BEFORE SUBMITTING, PLEASE READ.*\*\*/,$d' "${NEW}" # Remove HTML
section that includes text of "PR Checklist (Click to Expand)" python3 - < - - vLLM + + vLLM

@@ -16,18 +16,20 @@ Easy, fast, and cheap LLM serving for everyone --- *Latest News* 🔥 +- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing). +- [2025/05] vLLM is now a hosted project under PyTorch Foundation! Please find the announcement [here](https://pytorch.org/blog/pytorch-foundation-welcomes-vllm/). - [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing). +- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html). + +
+Previous News + - [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing). - [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing). - [2025/03] We hosted [the East Coast vLLM Meetup](https://lu.ma/7mu4k4xx)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1NHiv8EUFF1NLd3fEYODm56nDmL26lEeXCaDgyDlTsRs/edit#slide=id.g31441846c39_0_0). - [2025/02] We hosted [the ninth vLLM meetup](https://lu.ma/h7g3kuj9) with Meta! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1jzC_PZVXrVNSFVCW-V4cFXb6pn7zZ2CyP_Flwo05aqg/edit?usp=sharing) and AMD [here](https://drive.google.com/file/d/1Zk5qEJIkTmlQ2eQcXQZlljAx3m9s7nwn/view?usp=sharing). The slides from Meta will not be posted. -- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html). - [2025/01] We hosted [the eighth vLLM meetup](https://lu.ma/zep56hui) with Google Cloud! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing), and Google Cloud team [here](https://drive.google.com/file/d/1h24pHewANyRL11xy5dXUbvRC9F9Kkjix/view?usp=sharing). - [2024/12] vLLM joins [pytorch ecosystem](https://pytorch.org/blog/vllm-joins-pytorch)! Easy, Fast, and Cheap LLM Serving for Everyone! - -
-Previous News - - [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing), and Snowflake team [here](https://docs.google.com/presentation/d/1qF3RkDAbOULwz9WK5TOltt2fE9t6uIc_hVNLFAaQX6A/edit?usp=sharing). - [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there! - [2024/10] Ray Summit 2024 held a special track for vLLM! Please find the opening talk slides from the vLLM team [here](https://docs.google.com/presentation/d/1B_KQxpHBTRa_mDF-tR6i8rWdOU5QoTZNcEg2MKZxEHM/edit?usp=sharing). Learn more from the [talks](https://www.youtube.com/playlist?list=PLzTswPQNepXl6AQwifuwUImLPFRVpksjR) from other vLLM contributors and users! @@ -56,7 +58,7 @@ vLLM is fast with: - Efficient management of attention key and value memory with [**PagedAttention**](https://blog.vllm.ai/2023/06/20/vllm.html) - Continuous batching of incoming requests - Fast model execution with CUDA/HIP graph -- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), INT4, INT8, and FP8. +- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [AutoRound](https://arxiv.org/abs/2309.05516), INT4, INT8, and FP8. - Optimized CUDA kernels, including integration with FlashAttention and FlashInfer. - Speculative decoding - Chunked prefill @@ -72,7 +74,7 @@ vLLM is flexible and easy to use with: - OpenAI-compatible API server - Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron. 
- Prefix caching support -- Multi-lora support +- Multi-LoRA support vLLM seamlessly supports most popular open-source models on HuggingFace, including: - Transformer-like LLMs (e.g., Llama) @@ -98,14 +100,14 @@ Visit our [documentation](https://docs.vllm.ai/en/latest/) to learn more. ## Contributing We welcome and value any contributions and collaborations. -Please check out [Contributing to vLLM](https://docs.vllm.ai/en/stable/contributing/overview.html) for how to get involved. +Please check out [Contributing to vLLM](https://docs.vllm.ai/en/latest/contributing/index.html) for how to get involved. ## Sponsors vLLM is a community project. Our compute resources for development and testing are supported by the following organizations. Thank you for your support! - + Cash Donations: - a16z - Dropbox diff --git a/benchmarks/README.md b/benchmarks/README.md index 4a8ab895e18e..ecab570bb31c 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -146,10 +146,9 @@ python3 vllm/benchmarks/benchmark_serving.py \ ``` bash VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct \ - --speculative-model "[ngram]" \ --ngram_prompt_lookup_min 2 \ --ngram-prompt-lookup-max 5 \ - --num_speculative_tokens 5 + --speculative_config '{"model": "[ngram]", "num_speculative_tokens": 5}' ``` ``` bash @@ -274,10 +273,9 @@ python3 vllm/benchmarks/benchmark_throughput.py \ --output-len=100 \ --num-prompts=2048 \ --async-engine \ - --speculative-model="[ngram]" \ --ngram_prompt_lookup_min=2 \ --ngram-prompt-lookup-max=5 \ - --num_speculative_tokens=5 + --speculative_config '{"model": "[ngram]", "num_speculative_tokens": 5}' ``` ``` diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index e6a67fda6827..88616e1108c5 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -12,8 +12,7 @@ import aiohttp import huggingface_hub.constants from tqdm.asyncio import tqdm -from transformers import (AutoTokenizer, 
PreTrainedTokenizer, - PreTrainedTokenizerFast) +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast # NOTE(simon): do not import vLLM here so the benchmark script # can run without vLLM installed. @@ -43,8 +42,7 @@ class RequestFuncOutput: latency: float = 0.0 output_tokens: int = 0 ttft: float = 0.0 # Time to first token - itl: list[float] = field( - default_factory=list) # list of inter-token latencies + itl: list[float] = field(default_factory=list) # list of inter-token latencies tpot: float = 0.0 # avg next-token latencies prompt_len: int = 0 error: str = "" @@ -57,8 +55,9 @@ async def async_request_tgi( api_url = request_func_input.api_url assert api_url.endswith("generate_stream") - async with aiohttp.ClientSession(trust_env=True, - timeout=AIOHTTP_TIMEOUT) as session: + async with aiohttp.ClientSession( + trust_env=True, timeout=AIOHTTP_TIMEOUT + ) as session: params = { "max_new_tokens": request_func_input.output_len, "do_sample": True, @@ -105,8 +104,7 @@ async def async_request_tgi( # Decoding phase else: - output.itl.append(timestamp - - most_recent_timestamp) + output.itl.append(timestamp - most_recent_timestamp) most_recent_timestamp = timestamp @@ -133,8 +131,9 @@ async def async_request_trt_llm( api_url = request_func_input.api_url assert api_url.endswith("generate_stream") - async with aiohttp.ClientSession(trust_env=True, - timeout=AIOHTTP_TIMEOUT) as session: + async with aiohttp.ClientSession( + trust_env=True, timeout=AIOHTTP_TIMEOUT + ) as session: payload = { "accumulate_tokens": True, "text_input": request_func_input.prompt, @@ -159,8 +158,7 @@ async def async_request_trt_llm( if not chunk_bytes: continue - chunk = chunk_bytes.decode("utf-8").removeprefix( - "data:") + chunk = chunk_bytes.decode("utf-8").removeprefix("data:") data = json.loads(chunk) output.generated_text += data["text_output"] @@ -172,8 +170,7 @@ async def async_request_trt_llm( # Decoding phase else: - output.itl.append(timestamp - - 
most_recent_timestamp) + output.itl.append(timestamp - most_recent_timestamp) most_recent_timestamp = timestamp @@ -197,9 +194,14 @@ async def async_request_deepspeed_mii( request_func_input: RequestFuncInput, pbar: Optional[tqdm] = None, ) -> RequestFuncOutput: - async with aiohttp.ClientSession(trust_env=True, - timeout=AIOHTTP_TIMEOUT) as session: + api_url = request_func_input.api_url + assert api_url.endswith(("completions", "profile")), ( + "OpenAI Completions API URL must end with 'completions' or 'profile'." + ) + async with aiohttp.ClientSession( + trust_env=True, timeout=AIOHTTP_TIMEOUT + ) as session: payload = { "model": request_func_input.model, "prompt": request_func_input.prompt, @@ -207,6 +209,8 @@ async def async_request_deepspeed_mii( "temperature": 0.01, # deepspeed-mii does not accept 0.0 temp. "top_p": 1.0, } + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -217,19 +221,21 @@ async def async_request_deepspeed_mii( st = time.perf_counter() try: - async with session.post(url=request_func_input.api_url, - json=payload) as response: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: if response.status == 200: parsed_resp = await response.json() output.latency = time.perf_counter() - st if "choices" in parsed_resp: - output.generated_text = parsed_resp["choices"][0][ - "text"] + output.generated_text = parsed_resp["choices"][0]["text"] elif "text" in parsed_resp: output.generated_text = parsed_resp["text"][0] else: - output.error = ("Unexpected response format: " - "neither 'choices' nor 'text' found") + output.error = ( + "Unexpected response format: " + "neither 'choices' nor 'text' found" + ) output.success = False output.success = True else: @@ -250,15 +256,17 @@ async def async_request_openai_completions( pbar: Optional[tqdm] = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url - 
assert api_url.endswith( - ("completions", "profile") - ), "OpenAI Completions API URL must end with 'completions' or 'profile'." + assert api_url.endswith(("completions", "profile")), ( + "OpenAI Completions API URL must end with 'completions' or 'profile'." + ) - async with aiohttp.ClientSession(trust_env=True, - timeout=AIOHTTP_TIMEOUT) as session: + async with aiohttp.ClientSession( + trust_env=True, timeout=AIOHTTP_TIMEOUT + ) as session: payload = { - "model": request_func_input.model_name \ - if request_func_input.model_name else request_func_input.model, + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, "prompt": request_func_input.prompt, "temperature": 0.0, "repetition_penalty": 1.0, @@ -273,9 +281,7 @@ async def async_request_openai_completions( payload["ignore_eos"] = request_func_input.ignore_eos if request_func_input.extra_body: payload.update(request_func_input.extra_body) - headers = { - "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" - } + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -284,8 +290,9 @@ async def async_request_openai_completions( st = time.perf_counter() most_recent_timestamp = st try: - async with session.post(url=api_url, json=payload, - headers=headers) as response: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: if response.status == 200: first_chunk_received = False async for chunk_bytes in response.content: @@ -293,8 +300,7 @@ async def async_request_openai_completions( if not chunk_bytes: continue - chunk = chunk_bytes.decode("utf-8").removeprefix( - "data: ") + chunk = chunk_bytes.decode("utf-8").removeprefix("data: ") if chunk != "[DONE]": data = json.loads(chunk) @@ -314,21 +320,20 @@ async def async_request_openai_completions( # Decoding phase else: - output.itl.append(timestamp - - 
most_recent_timestamp) + output.itl.append(timestamp - most_recent_timestamp) most_recent_timestamp = timestamp generated_text += text or "" elif usage := data.get("usage"): - output.output_tokens = usage.get( - "completion_tokens") + output.output_tokens = usage.get("completion_tokens") if first_chunk_received: output.success = True else: output.success = False output.error = ( "Never received a valid chunk to calculate TTFT." - "This response will be marked as failed!") + "This response will be marked as failed!" + ) output.generated_text = generated_text output.latency = most_recent_timestamp - st else: @@ -349,23 +354,22 @@ async def async_request_openai_chat_completions( pbar: Optional[tqdm] = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url - assert api_url.endswith( - ("chat/completions", "profile") - ), "OpenAI Chat Completions API URL must end with 'chat/completions'." + assert api_url.endswith(("chat/completions", "profile")), ( + "OpenAI Chat Completions API URL must end with 'chat/completions'." 
+ ) - async with aiohttp.ClientSession(trust_env=True, - timeout=AIOHTTP_TIMEOUT) as session: + async with aiohttp.ClientSession( + trust_env=True, timeout=AIOHTTP_TIMEOUT + ) as session: content = [{"type": "text", "text": request_func_input.prompt}] if request_func_input.multi_modal_content: content.append(request_func_input.multi_modal_content) payload = { - "model": request_func_input.model_name \ - if request_func_input.model_name else request_func_input.model, + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, "messages": [ - { - "role": "user", - "content": content - }, + {"role": "user", "content": content}, ], "temperature": 0.0, "max_completion_tokens": request_func_input.output_len, @@ -391,16 +395,16 @@ async def async_request_openai_chat_completions( st = time.perf_counter() most_recent_timestamp = st try: - async with session.post(url=api_url, json=payload, - headers=headers) as response: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: if response.status == 200: async for chunk_bytes in response.content: chunk_bytes = chunk_bytes.strip() if not chunk_bytes: continue - chunk = chunk_bytes.decode("utf-8").removeprefix( - "data: ") + chunk = chunk_bytes.decode("utf-8").removeprefix("data: ") if chunk != "[DONE]": timestamp = time.perf_counter() data = json.loads(chunk) @@ -414,13 +418,11 @@ async def async_request_openai_chat_completions( # Decoding phase else: - output.itl.append(timestamp - - most_recent_timestamp) + output.itl.append(timestamp - most_recent_timestamp) generated_text += content or "" elif usage := data.get("usage"): - output.output_tokens = usage.get( - "completion_tokens") + output.output_tokens = usage.get("completion_tokens") most_recent_timestamp = timestamp @@ -446,25 +448,28 @@ async def async_request_openai_audio( ) -> RequestFuncOutput: # Lazy import without PlaceholderModule to avoid vllm dep. 
import soundfile + api_url = request_func_input.api_url - assert api_url.endswith( - ("transcriptions", "translations" - )), "OpenAI Chat Completions API URL must end with 'transcriptions' " + assert api_url.endswith(("transcriptions", "translations")), ( + "OpenAI Chat Completions API URL must end with 'transcriptions' " + ) "or `translations`." - async with aiohttp.ClientSession(trust_env=True, - timeout=AIOHTTP_TIMEOUT) as session: + async with aiohttp.ClientSession( + trust_env=True, timeout=AIOHTTP_TIMEOUT + ) as session: content = [{"type": "text", "text": request_func_input.prompt}] payload = { - "model": request_func_input.model_name \ - if request_func_input.model_name else request_func_input.model, + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, "temperature": 0.0, "max_completion_tokens": request_func_input.output_len, "stream": True, "language": "en", # Flattened due to multipart/form-data "stream_include_usage": True, - "stream_continuous_usage_stats": True + "stream_continuous_usage_stats": True, } if request_func_input.extra_body: payload.update(request_func_input.extra_body) @@ -479,9 +484,9 @@ def to_bytes(y, sr): buffer.seek(0) return buffer - with to_bytes(*request_func_input.multi_modal_content['audio']) as f: + with to_bytes(*request_func_input.multi_modal_content["audio"]) as f: form = aiohttp.FormData() - form.add_field('file', f, content_type='audio/wav') + form.add_field("file", f, content_type="audio/wav") for key, value in payload.items(): form.add_field(key, str(value)) @@ -493,24 +498,22 @@ def to_bytes(y, sr): st = time.perf_counter() most_recent_timestamp = st try: - async with session.post(url=api_url, - data=form, - headers=headers) as response: + async with session.post( + url=api_url, data=form, headers=headers + ) as response: if response.status == 200: async for chunk_bytes in response.content: chunk_bytes = chunk_bytes.strip() if not chunk_bytes: continue - chunk = 
chunk_bytes.decode("utf-8").removeprefix( - "data: ") + chunk = chunk_bytes.decode("utf-8").removeprefix("data: ") if chunk != "[DONE]": timestamp = time.perf_counter() data = json.loads(chunk) if choices := data.get("choices"): - content = choices[0]["delta"].get( - "content") + content = choices[0]["delta"].get("content") # First token if ttft == 0.0: ttft = timestamp - st @@ -519,12 +522,14 @@ def to_bytes(y, sr): # Decoding phase else: output.itl.append( - timestamp - most_recent_timestamp) + timestamp - most_recent_timestamp + ) generated_text += content or "" elif usage := data.get("usage"): output.output_tokens = usage.get( - "completion_tokens") + "completion_tokens" + ) most_recent_timestamp = timestamp @@ -545,7 +550,7 @@ def to_bytes(y, sr): def get_model(pretrained_model_name_or_path: str) -> str: - if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true': + if os.getenv("VLLM_USE_MODELSCOPE", "False").lower() == "true": from modelscope import snapshot_download from vllm.model_executor.model_loader.weight_utils import get_lock @@ -556,7 +561,8 @@ def get_model(pretrained_model_name_or_path: str) -> str: model_path = snapshot_download( model_id=pretrained_model_name_or_path, local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, - ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"]) + ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"], + ) return model_path return pretrained_model_name_or_path @@ -569,23 +575,23 @@ def get_tokenizer( **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: if pretrained_model_name_or_path is not None and not os.path.exists( - pretrained_model_name_or_path): - pretrained_model_name_or_path = get_model( - pretrained_model_name_or_path) + pretrained_model_name_or_path + ): + pretrained_model_name_or_path = get_model(pretrained_model_name_or_path) if tokenizer_mode == "slow": if kwargs.get("use_fast", False): - raise ValueError( - "Cannot use the fast tokenizer in slow tokenizer mode.") + 
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") kwargs["use_fast"] = False if tokenizer_mode == "mistral": try: from vllm.transformers_utils.tokenizer import MistralTokenizer except ImportError as e: - raise ImportError("MistralTokenizer requires vllm package.\n" - "Please install it with `pip install vllm` " - "to use mistral tokenizer mode.") from e - return MistralTokenizer.from_pretrained( - str(pretrained_model_name_or_path)) + raise ImportError( + "MistralTokenizer requires vllm package.\n" + "Please install it with `pip install vllm` " + "to use mistral tokenizer mode." + ) from e + return MistralTokenizer.from_pretrained(str(pretrained_model_name_or_path)) else: return AutoTokenizer.from_pretrained( pretrained_model_name_or_path, @@ -608,7 +614,7 @@ def get_tokenizer( } OPENAI_COMPATIBLE_BACKENDS = [ - k for k, v in ASYNC_REQUEST_FUNCS.items() - if v in (async_request_openai_completions, - async_request_openai_chat_completions) + k + for k, v in ASYNC_REQUEST_FUNCS.items() + if v in (async_request_openai_completions, async_request_openai_chat_completions) ] diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index 98d3360cd6ff..5513a5f78f1c 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -35,6 +35,7 @@ from vllm.lora.request import LoRARequest from vllm.lora.utils import get_adapter_absolute_path from vllm.multimodal import MultiModalDataDict +from vllm.multimodal.image import convert_image_mode from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer logger = logging.getLogger(__name__) @@ -82,14 +83,12 @@ def __init__( self.dataset_path = dataset_path # Set the random seed, ensuring that a None value is replaced with the # default seed. 
- self.random_seed = (random_seed - if random_seed is not None else self.DEFAULT_SEED) + self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED self.data = None def apply_multimodal_chat_transformation( - self, - prompt: str, - mm_content: Optional[MultiModalDataDict] = None) -> list[dict]: + self, prompt: str, mm_content: Optional[MultiModalDataDict] = None + ) -> list[dict]: """ Transform a prompt and optional multimodal content into a chat format. This method is used for chat models that expect a specific conversation @@ -111,8 +110,7 @@ def load_data(self) -> None: NotImplementedError: If a subclass does not implement this method. """ # TODO (jenniferzhao): add support for downloading data - raise NotImplementedError( - "load_data must be implemented in subclasses.") + raise NotImplementedError("load_data must be implemented in subclasses.") def get_random_lora_request( self, @@ -158,8 +156,9 @@ def get_random_lora_request( return lora_request, lora_tokenizer_cache[lora_id] or tokenizer @abstractmethod - def sample(self, tokenizer: PreTrainedTokenizerBase, - num_requests: int) -> list[SampleRequest]: + def sample( + self, tokenizer: PreTrainedTokenizerBase, num_requests: int + ) -> list[SampleRequest]: """ Abstract method to generate sample requests from the dataset. @@ -177,8 +176,9 @@ def sample(self, tokenizer: PreTrainedTokenizerBase, """ raise NotImplementedError("sample must be implemented in subclasses.") - def maybe_oversample_requests(self, requests: list[SampleRequest], - num_requests: int) -> None: + def maybe_oversample_requests( + self, requests: list[SampleRequest], num_requests: int + ) -> None: """ Oversamples the list of requests if its size is less than the desired number. 
@@ -189,11 +189,9 @@ def maybe_oversample_requests(self, requests: list[SampleRequest], """ if len(requests) < num_requests: random.seed(self.random_seed) - additional = random.choices(requests, - k=num_requests - len(requests)) + additional = random.choices(requests, k=num_requests - len(requests)) requests.extend(additional) - logger.info("Oversampled requests to reach %d total samples.", - num_requests) + logger.info("Oversampled requests to reach %d total samples.", num_requests) # ----------------------------------------------------------------------------- @@ -218,14 +216,14 @@ def is_valid_sequence( """ # Check for invalid conditions prompt_too_short = prompt_len < min_len - output_too_short = (not skip_min_output_len_check) and (output_len - < min_len) + output_too_short = (not skip_min_output_len_check) and (output_len < min_len) prompt_too_long = prompt_len > max_prompt_len combined_too_long = (prompt_len + output_len) > max_total_len # Return True if none of the invalid conditions are met - return not (prompt_too_short or output_too_short or prompt_too_long - or combined_too_long) + return not ( + prompt_too_short or output_too_short or prompt_too_long or combined_too_long + ) @cache @@ -257,28 +255,28 @@ def process_image(image: Any) -> Mapping[str, Any]: Raises: ValueError: If the input is not a supported type. 
""" - if isinstance(image, dict) and 'bytes' in image: - image = Image.open(BytesIO(image['bytes'])) + if isinstance(image, dict) and "bytes" in image: + image = Image.open(BytesIO(image["bytes"])) if isinstance(image, Image.Image): - image = image.convert("RGB") + image = convert_image_mode(image, "RGB") with io.BytesIO() as image_data: image.save(image_data, format="JPEG") - image_base64 = base64.b64encode( - image_data.getvalue()).decode("utf-8") + image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8") return { "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{image_base64}" - }, + "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}, } if isinstance(image, str): - image_url = (image if image.startswith( - ("http://", "file://")) else f"file://{image}") + image_url = ( + image if image.startswith(("http://", "file://")) else f"file://{image}" + ) return {"type": "image_url", "image_url": {"url": image_url}} - raise ValueError(f"Invalid image input {image}. Must be a PIL.Image.Image" - " or str or dictionary with raw image bytes.") + raise ValueError( + f"Invalid image input {image}. Must be a PIL.Image.Image" + " or str or dictionary with raw image bytes." 
+ ) # ----------------------------------------------------------------------------- @@ -318,8 +316,11 @@ def sample( num_special_tokens = tokenizer.num_special_tokens_to_add() real_input_len = input_len - num_special_tokens - prefix_token_ids = (np.random.randint( - 0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else []) + prefix_token_ids = ( + np.random.randint(0, vocab_size, size=prefix_len).tolist() + if prefix_len > 0 + else [] + ) # New sampling logic: [X * (1 - b), X * (1 + b)] input_low = int(real_input_len * (1 - range_ratio)) @@ -329,21 +330,17 @@ def sample( # Add logging for debugging logger.info("Sampling input_len from [%s, %s]", input_low, input_high) - logger.info("Sampling output_len from [%s, %s]", output_low, - output_high) - - input_lens = np.random.randint(input_low, - input_high + 1, - size=num_requests) - output_lens = np.random.randint(output_low, - output_high + 1, - size=num_requests) + logger.info("Sampling output_len from [%s, %s]", output_low, output_high) + + input_lens = np.random.randint(input_low, input_high + 1, size=num_requests) + output_lens = np.random.randint(output_low, output_high + 1, size=num_requests) offsets = np.random.randint(0, vocab_size, size=num_requests) requests = [] for i in range(num_requests): - inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) % - vocab_size).tolist() + inner_seq = ( + (offsets[i] + i + np.arange(input_lens[i])) % vocab_size + ).tolist() token_sequence = prefix_token_ids + inner_seq prompt = tokenizer.decode(token_sequence) # After decoding the prompt we have to encode and decode it again. @@ -354,8 +351,9 @@ def sample( # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] # To avoid uncontrolled change of the prompt length, # the encoded sequence is truncated before being decode again. 
- re_encoded_sequence = tokenizer.encode( - prompt, add_special_tokens=False)[:input_lens[i]] + re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[ + : input_lens[i] + ] prompt = tokenizer.decode(re_encoded_sequence) total_input_len = prefix_len + int(input_lens[i]) requests.append( @@ -363,7 +361,8 @@ def sample( prompt=prompt, prompt_len=total_input_len, expected_output_len=int(output_lens[i]), - )) + ) + ) return requests @@ -390,7 +389,8 @@ def load_data(self) -> None: self.data = json.load(f) # Filter entries with at least two conversation turns. self.data = [ - entry for entry in self.data + entry + for entry in self.data if "conversations" in entry and len(entry["conversations"]) >= 2 ] random.seed(self.random_seed) @@ -416,27 +416,28 @@ def sample( ) lora_request, tokenizer = self.get_random_lora_request( - tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) + tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path + ) prompt_ids = tokenizer(prompt).input_ids completion_ids = tokenizer(completion).input_ids prompt_len = len(prompt_ids) - new_output_len = (len(completion_ids) - if output_len is None else output_len) - if not is_valid_sequence(prompt_len, - new_output_len, - skip_min_output_len_check=output_len - is not None): + new_output_len = len(completion_ids) if output_len is None else output_len + if not is_valid_sequence( + prompt_len, + new_output_len, + skip_min_output_len_check=output_len is not None, + ): continue if enable_multimodal_chat: - prompt = self.apply_multimodal_chat_transformation( - prompt, None) + prompt = self.apply_multimodal_chat_transformation(prompt, None) samples.append( SampleRequest( prompt=prompt, prompt_len=prompt_len, expected_output_len=new_output_len, lora_request=lora_request, - )) + ) + ) self.maybe_oversample_requests(samples, num_requests) return samples @@ -482,20 +483,20 @@ def sample( ) -> list: # Calculate average token length for a poem line. 
tokenized_lines = [tokenizer(line).input_ids for line in self.data] - avg_len = sum(len(tokens) - for tokens in tokenized_lines) / len(tokenized_lines) + avg_len = sum(len(tokens) for tokens in tokenized_lines) / len(tokenized_lines) # Build the base prompt. base_prompt = "Pick as many lines as you can from these poem lines:\n" base_msg = [{"role": "user", "content": base_prompt}] - base_fmt = tokenizer.apply_chat_template(base_msg, - add_generation_prompt=True, - tokenize=False) + base_fmt = tokenizer.apply_chat_template( + base_msg, add_generation_prompt=True, tokenize=False + ) base_offset = len(tokenizer(base_fmt).input_ids) if input_len <= base_offset: raise ValueError( f"'input_len' must be higher than the base prompt length " - f"({base_offset}).") + f"({base_offset})." + ) # Determine how many poem lines to use. num_input_lines = round((input_len - base_offset) / avg_len) @@ -504,21 +505,23 @@ def sample( samples = [] while len(samples) < num_requests: - extra_lines = random.choices(self.data, - k=num_input_lines - num_prefix_lines) + extra_lines = random.choices( + self.data, k=num_input_lines - num_prefix_lines + ) prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}" msg = [{"role": "user", "content": prompt}] prompt_formatted = tokenizer.apply_chat_template( - msg, add_generation_prompt=True, tokenize=False) + msg, add_generation_prompt=True, tokenize=False + ) prompt_len = len(tokenizer(prompt_formatted).input_ids) if prompt_len <= input_len: samples.append( SampleRequest( - prompt=prompt_formatted - if return_prompt_formatted else prompt, + prompt=prompt_formatted if return_prompt_formatted else prompt, prompt_len=prompt_len, expected_output_len=output_len, - )) + ) + ) return samples @@ -538,7 +541,9 @@ def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self.load_data() - def load_data(self, ): + def load_data( + self, + ): if self.dataset_path is None: raise ValueError("dataset_path must be provided for loading data.") @@ 
-552,8 +557,7 @@ def load_data(self, ): def _sample_loaded_data(self, num_requests: int) -> list: if num_requests <= len(self.data): - data = self.data.sample(n=num_requests, - random_state=self.random_seed) + data = self.data.sample(n=num_requests, random_state=self.random_seed) else: data = self.data.sample( n=num_requests, @@ -577,7 +581,8 @@ def sample( input_len = int(data[i][2]) output_len = int(data[i][3]) lora_req, tokenizer = self.get_random_lora_request( - tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) + tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path + ) vocab_size = tokenizer.vocab_size # Generate a synthetic prompt: a list of token IDs computed as (i + # j) modulo vocab_size. @@ -589,7 +594,8 @@ def sample( prompt_len=input_len, expected_output_len=output_len, lora_request=lora_req, - )) + ) + ) return samples @@ -632,20 +638,23 @@ def load_data(self) -> None: class ConversationDataset(HuggingFaceDataset): """Dataset for conversation data with multimodal support.""" + SUPPORTED_DATASET_PATHS = { - 'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered' + "lmms-lab/LLaVA-OneVision-Data", + "Aeala/ShareGPT_Vicuna_unfiltered", } IS_MULTIMODAL = True - def sample(self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - output_len: Optional[int] = None, - enable_multimodal_chat: bool = False, - **kwargs) -> list: + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs, + ) -> list: # Filter examples with at least 2 conversations - filtered_data = self.data.filter( - lambda x: len(x["conversations"]) >= 2) + filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2) sampled_requests = [] dynamic_output = output_len is None @@ -661,24 +670,22 @@ def sample(self, completion_len = len(completion_ids) output_len = completion_len if dynamic_output else output_len assert 
isinstance(output_len, int) and output_len > 0 - if dynamic_output and not is_valid_sequence( - prompt_len, completion_len): + if dynamic_output and not is_valid_sequence(prompt_len, completion_len): continue - mm_content = process_image( - item["image"]) if "image" in item else None + mm_content = process_image(item["image"]) if "image" in item else None if enable_multimodal_chat: # Note: when chat is enabled the request prompt_len is no longer # accurate and we will be using request output to count the # actual prompt len and output len - prompt = self.apply_multimodal_chat_transformation( - prompt, mm_content) + prompt = self.apply_multimodal_chat_transformation(prompt, mm_content) sampled_requests.append( SampleRequest( prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=mm_content, - )) + ) + ) self.maybe_oversample_requests(sampled_requests, num_requests) return sampled_requests @@ -695,10 +702,8 @@ class VisionArenaDataset(HuggingFaceDataset): DEFAULT_OUTPUT_LEN = 128 SUPPORTED_DATASET_PATHS = { - "lmarena-ai/VisionArena-Chat": - lambda x: x["conversation"][0][0]["content"], - "lmarena-ai/vision-arena-bench-v0.1": - lambda x: x["turns"][0][0]["content"] + "lmarena-ai/VisionArena-Chat": lambda x: x["conversation"][0][0]["content"], + "lmarena-ai/vision-arena-bench-v0.1": lambda x: x["turns"][0][0]["content"], } IS_MULTIMODAL = True @@ -710,16 +715,14 @@ def sample( enable_multimodal_chat: bool = False, **kwargs, ) -> list: - output_len = (output_len - if output_len is not None else self.DEFAULT_OUTPUT_LEN) + output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] for item in self.data: if len(sampled_requests) >= num_requests: break parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path) if parser_fn is None: - raise ValueError( - f"Unsupported dataset path: {self.dataset_path}") + raise ValueError(f"Unsupported dataset path: {self.dataset_path}") prompt = parser_fn(item) 
mm_content = process_image(item["images"][0]) prompt_len = len(tokenizer(prompt).input_ids) @@ -727,15 +730,15 @@ def sample( # Note: when chat is enabled the request prompt_len is no longer # accurate and we will be using request output to count the # actual prompt len - prompt = self.apply_multimodal_chat_transformation( - prompt, mm_content) + prompt = self.apply_multimodal_chat_transformation(prompt, mm_content) sampled_requests.append( SampleRequest( prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=mm_content, - )) + ) + ) self.maybe_oversample_requests(sampled_requests, num_requests) return sampled_requests @@ -760,14 +763,15 @@ class InstructCoderDataset(HuggingFaceDataset): "likaixin/InstructCoder", } - def sample(self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - output_len: Optional[int] = None, - enable_multimodal_chat: bool = False, - **kwargs) -> list: - output_len = (output_len - if output_len is not None else self.DEFAULT_OUTPUT_LEN) + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs, + ) -> list: + output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] for item in self.data: if len(sampled_requests) >= num_requests: @@ -779,7 +783,8 @@ def sample(self, prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, - )) + ) + ) self.maybe_oversample_requests(sampled_requests, num_requests) return sampled_requests @@ -794,38 +799,38 @@ class MTBenchDataset(HuggingFaceDataset): MT-Bench Dataset. https://huggingface.co/datasets/philschmid/mt-bench - We create a single turn dataset for MT-Bench. + We create a single turn dataset for MT-Bench. 
This is similar to Spec decoding benchmark setup in vLLM https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18 - """ # noqa: E501 + """ # noqa: E501 DEFAULT_OUTPUT_LEN = 256 # avg len used in SD bench in vLLM SUPPORTED_DATASET_PATHS = { "philschmid/mt-bench", } - def sample(self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - output_len: Optional[int] = None, - enable_multimodal_chat: bool = False, - **kwargs) -> list: - output_len = (output_len - if output_len is not None else self.DEFAULT_OUTPUT_LEN) + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs, + ) -> list: + output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] for item in self.data: if len(sampled_requests) >= num_requests: break - prompt = item['turns'][0] + prompt = item["turns"][0] # apply template - prompt = tokenizer.apply_chat_template([{ - "role": "user", - "content": prompt - }], - add_generation_prompt=True, - tokenize=False) + prompt = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + add_generation_prompt=True, + tokenize=False, + ) prompt_len = len(tokenizer(prompt).input_ids) sampled_requests.append( @@ -833,7 +838,8 @@ def sample(self, prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, - )) + ) + ) self.maybe_oversample_requests(sampled_requests, num_requests) return sampled_requests @@ -847,23 +853,27 @@ class AIMODataset(HuggingFaceDataset): """ Dataset class for processing a AIMO dataset with reasoning questions. 
""" + SUPPORTED_DATASET_PATHS = { - "AI-MO/aimo-validation-aime", "AI-MO/NuminaMath-1.5", - "AI-MO/NuminaMath-CoT" + "AI-MO/aimo-validation-aime", + "AI-MO/NuminaMath-1.5", + "AI-MO/NuminaMath-CoT", } - def sample(self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - output_len: Optional[int] = None, - **kwargs) -> list: + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + **kwargs, + ) -> list: sampled_requests = [] dynamic_output = output_len is None for item in self.data: if len(sampled_requests) >= num_requests: break - prompt, completion = item['problem'], item["solution"] + prompt, completion = item["problem"], item["solution"] prompt_ids = tokenizer(prompt).input_ids completion_ids = tokenizer(completion).input_ids @@ -871,10 +881,9 @@ def sample(self, completion_len = len(completion_ids) output_len = completion_len if dynamic_output else output_len assert isinstance(output_len, int) and output_len > 0 - if dynamic_output and not is_valid_sequence(prompt_len, - completion_len, - max_prompt_len=2048, - max_total_len=32000): + if dynamic_output and not is_valid_sequence( + prompt_len, completion_len, max_prompt_len=2048, max_total_len=32000 + ): continue sampled_requests.append( SampleRequest( @@ -882,7 +891,8 @@ def sample(self, prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=None, - )) + ) + ) self.maybe_oversample_requests(sampled_requests, num_requests) return sampled_requests @@ -905,25 +915,25 @@ def sample(self, ### Response: -""" # noqa: E501 +""" # noqa: E501 def _format_zeta_prompt( - sample: dict, - original_start_marker: str = "<|editable_region_start|>") -> dict: + sample: dict, original_start_marker: str = "<|editable_region_start|>" +) -> dict: """Format the zeta prompt for the Next Edit Prediction (NEP) dataset. - - This function formats examples from the NEP dataset - into prompts and expected outputs. 
It could be + + This function formats examples from the NEP dataset + into prompts and expected outputs. It could be further extended to support more NEP datasets. - + Args: - sample: The dataset sample containing events, + sample: The dataset sample containing events, inputs, and outputs. - original_start_marker: The marker indicating the - start of the editable region. Defaults to + original_start_marker: The marker indicating the + start of the editable region. Defaults to "<|editable_region_start|>". - + Returns: A dictionary with the formatted prompts and expected outputs. """ @@ -953,10 +963,8 @@ class NextEditPredictionDataset(HuggingFaceDataset): "zed-industries/zeta": _format_zeta_prompt, } - def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, - **kwargs): - formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get( - self.dataset_path) + def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, **kwargs): + formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path) if formatting_prompt_func is None: raise ValueError(f"Unsupported dataset path: {self.dataset_path}") samples = [] @@ -967,8 +975,10 @@ def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, prompt=sample["prompt"], prompt_len=len(tokenizer(sample["prompt"]).input_ids), expected_output_len=len( - tokenizer(sample["expected_output"]).input_ids), - )) + tokenizer(sample["expected_output"]).input_ids + ), + ) + ) if len(samples) >= num_requests: break self.maybe_oversample_requests(samples, num_requests) @@ -997,18 +1007,22 @@ class ASRDataset(HuggingFaceDataset): | AMI | Meetings | Spontaneous | ihm, sdm | +----------------+----------------------------------------+--------------------------+-----------------------------+ - """ # noqa: E501 + """ # noqa: E501 + SUPPORTED_DATASET_PATHS = { - "openslr/librispeech_asr", "facebook/voxpopuli", "LIUM/tedlium", - "edinburghcstr/ami", "speechcolab/gigaspeech", "kensho/spgispeech" + 
"openslr/librispeech_asr", + "facebook/voxpopuli", + "LIUM/tedlium", + "edinburghcstr/ami", + "speechcolab/gigaspeech", + "kensho/spgispeech", } DEFAULT_OUTPUT_LEN = 128 IS_MULTIMODAL = True # TODO Whisper-specific. Abstract interface when more models are supported. - TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|>"\ - "<|notimestamps|>" + TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>" skip_long_audios: bool = True def sample( @@ -1019,8 +1033,8 @@ def sample( **kwargs, ) -> list: import librosa - output_len = (output_len - if output_len is not None else self.DEFAULT_OUTPUT_LEN) + + output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN prompt = ASRDataset.TRANSCRIPTION_PREAMBLE prompt_len = len(tokenizer(prompt).input_ids) sampled_requests = [] @@ -1043,10 +1057,14 @@ def sample( prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=mm_content, - )) + ) + ) if skipped: - logger.warning("%d samples discarded from dataset due to" \ - " their length being greater than" \ - " what Whisper supports.", skipped) + logger.warning( + "%d samples discarded from dataset due to" + " their length being greater than" + " what Whisper supports.", + skipped, + ) self.maybe_oversample_requests(sampled_requests, num_requests) return sampled_requests diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index dfd9bb1e6a4d..84759c5c354d 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -11,9 +11,9 @@ import numpy as np import torch -from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json from tqdm import tqdm +from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json from vllm import LLM, SamplingParams from vllm.engine.arg_utils import EngineArgs from vllm.inputs import PromptType @@ -21,13 +21,14 @@ from vllm.utils import FlexibleArgumentParser -def 
save_to_pytorch_benchmark_format(args: argparse.Namespace, - results: dict[str, Any]) -> None: +def save_to_pytorch_benchmark_format( + args: argparse.Namespace, results: dict[str, Any] +) -> None: pt_records = convert_to_pytorch_benchmark_format( args=args, metrics={"latency": results["latencies"]}, - extra_info={k: results[k] - for k in ["avg_latency", "percentiles"]}) + extra_info={k: results[k] for k in ["avg_latency", "percentiles"]}, + ) if pt_records: pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" write_to_json(pt_file, pt_records) @@ -42,9 +43,11 @@ def main(args: argparse.Namespace): # the engine will automatically process the request in multiple batches. llm = LLM(**dataclasses.asdict(engine_args)) assert llm.llm_engine.model_config.max_model_len >= ( - args.input_len + - args.output_len), ("Please ensure that max_model_len is greater than" - " the sum of input_len and output_len.") + args.input_len + args.output_len + ), ( + "Please ensure that max_model_len is greater than" + " the sum of input_len and output_len." 
+ ) sampling_params = SamplingParams( n=args.n, @@ -55,18 +58,16 @@ def main(args: argparse.Namespace): detokenize=not args.disable_detokenize, ) print(sampling_params) - dummy_prompt_token_ids = np.random.randint(10000, - size=(args.batch_size, - args.input_len)) - dummy_prompts: list[PromptType] = [{ - "prompt_token_ids": batch - } for batch in dummy_prompt_token_ids.tolist()] + dummy_prompt_token_ids = np.random.randint( + 10000, size=(args.batch_size, args.input_len) + ) + dummy_prompts: list[PromptType] = [ + {"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist() + ] def llm_generate(): if not args.use_beam_search: - llm.generate(dummy_prompts, - sampling_params=sampling_params, - use_tqdm=False) + llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False) else: llm.beam_search( dummy_prompts, @@ -80,12 +81,13 @@ def llm_generate(): def run_to_completion(profile_dir: Optional[str] = None): if profile_dir: with torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - on_trace_ready=torch.profiler.tensorboard_trace_handler( - str(profile_dir)), + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + on_trace_ready=torch.profiler.tensorboard_trace_handler( + str(profile_dir) + ), ) as p: llm_generate() print(p.key_averages().table(sort_by="self_cuda_time_total")) @@ -103,8 +105,9 @@ def run_to_completion(profile_dir: Optional[str] = None): if args.profile: profile_dir = args.profile_result_dir if not profile_dir: - profile_dir = (Path(".") / "vllm_benchmark_result" / - f"latency_result_{time.time()}") + profile_dir = ( + Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}" + ) print(f"Profiling (results will be saved to '{profile_dir}')...") run_to_completion(profile_dir=profile_dir) return @@ -135,7 +138,8 @@ def run_to_completion(profile_dir: Optional[str] = None): if __name__ == "__main__": parser 
= FlexibleArgumentParser( description="Benchmark the latency of processing a single batch of " - "requests till completion.") + "requests till completion." + ) parser.add_argument("--input-len", type=int, default=32) parser.add_argument("--output-len", type=int, default=128) parser.add_argument("--batch-size", type=int, default=8) @@ -152,10 +156,9 @@ def run_to_completion(profile_dir: Optional[str] = None): default=10, help="Number of iterations to run for warmup.", ) - parser.add_argument("--num-iters", - type=int, - default=30, - help="Number of iterations to run.") + parser.add_argument( + "--num-iters", type=int, default=30, help="Number of iterations to run." + ) parser.add_argument( "--profile", action="store_true", @@ -165,8 +168,10 @@ def run_to_completion(profile_dir: Optional[str] = None): "--profile-result-dir", type=str, default=None, - help=("path to save the pytorch profiler output. Can be visualized " - "with ui.perfetto.dev or Tensorboard."), + help=( + "path to save the pytorch profiler output. Can be visualized " + "with ui.perfetto.dev or Tensorboard." + ), ) parser.add_argument( "--output-json", @@ -177,10 +182,15 @@ def run_to_completion(profile_dir: Optional[str] = None): parser.add_argument( "--disable-detokenize", action="store_true", - help=("Do not detokenize responses (i.e. do not include " - "detokenization time in the latency measurement)"), + help=( + "Do not detokenize responses (i.e. do not include " + "detokenization time in the latency measurement)" + ), ) parser = EngineArgs.add_cli_args(parser) + # V1 enables prefix caching by default which skews the latency + # numbers. We need to disable prefix caching by default. 
+ parser.set_defaults(enable_prefix_caching=False) args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_long_document_qa_throughput.py b/benchmarks/benchmark_long_document_qa_throughput.py index 21480578edbd..109624c87789 100644 --- a/benchmarks/benchmark_long_document_qa_throughput.py +++ b/benchmarks/benchmark_long_document_qa_throughput.py @@ -76,7 +76,7 @@ def repeat_prompts(prompts, repeat_count, mode: str): - 'random': Shuffle the prompts randomly after repetition. - 'tile': Repeat the entire prompt list in sequence. Example: [1, 2, 3] -> [1, 2, 3, 1, 2, 3]. - - 'interleave': Repeat each prompt consecutively before moving to + - 'interleave': Repeat each prompt consecutively before moving to the next. Example: [1, 2, 3] -> [1, 1, 2, 2, 3, 3]. Returns: @@ -86,20 +86,21 @@ def repeat_prompts(prompts, repeat_count, mode: str): ValueError: If an invalid mode is provided. """ print("Repeat mode: ", mode) - if mode == 'random': + if mode == "random": repeated_prompts = prompts * repeat_count random.shuffle(repeated_prompts) return repeated_prompts - elif mode == 'tile': + elif mode == "tile": return prompts * repeat_count - elif mode == 'interleave': + elif mode == "interleave": repeated_prompts = [] for prompt in prompts: repeated_prompts.extend([prompt] * repeat_count) return repeated_prompts else: - raise ValueError(f"Invalid mode: {mode}, only support " - "'random', 'tile', 'interleave'") + raise ValueError( + f"Invalid mode: {mode}, only support 'random', 'tile', 'interleave'" + ) def main(args): @@ -109,16 +110,16 @@ def main(args): # we append the document id at the beginning to avoid any of the document # being the prefix of other documents prompts = [ - str(i) + ' '.join(['hi'] * args.document_length) + str(i) + " ".join(["hi"] * args.document_length) for i in range(args.num_documents) ] prompts = repeat_prompts(prompts, args.repeat_count, mode=args.repeat_mode) warmup_prompts = [ - "This is warm up request " + str(i) + \ - ' 
'.join(['hi'] * args.document_length) - for i in range(args.num_documents)] + "This is warm up request " + str(i) + " ".join(["hi"] * args.document_length) + for i in range(args.num_documents) + ] # Create the LLM engine engine_args = EngineArgs.from_cli_args(args) @@ -142,42 +143,52 @@ def main(args): if __name__ == "__main__": parser = FlexibleArgumentParser( - description= - 'Benchmark the performance with or without automatic prefix caching.') + description="Benchmark the performance with or " + "without automatic prefix caching." + ) parser.add_argument( - '--document-length', + "--document-length", type=int, # Roughly the number of tokens for a system paper, # excluding images default=20000, - help='Range of input lengths for sampling prompts,' - 'specified as "min:max" (e.g., "128:256").') - - parser.add_argument('--num-documents', - type=int, - default=8, - help='Range of input lengths for sampling prompts,' - 'specified as "min:max" (e.g., "128:256").') - - parser.add_argument('--output-len', type=int, default=10) - - parser.add_argument('--repeat-count', - type=int, - default=2, - help='Number of times to repeat each prompt') - - parser.add_argument("--repeat-mode", - type=str, - default='random', - help='The mode to repeat prompts. The supported ' - 'modes are "random", "tile", and "interleave". 
' - 'See repeat_prompts() in the source code for details.') - - parser.add_argument("--shuffle-seed", - type=int, - default=0, - help='Random seed when the repeat mode is "random"') + help="Range of input lengths for sampling prompts, " + 'specified as "min:max" (e.g., "128:256").', + ) + + parser.add_argument( + "--num-documents", + type=int, + default=8, + help="Range of input lengths for sampling prompts, " + 'specified as "min:max" (e.g., "128:256").', + ) + + parser.add_argument("--output-len", type=int, default=10) + + parser.add_argument( + "--repeat-count", + type=int, + default=2, + help="Number of times to repeat each prompt", + ) + + parser.add_argument( + "--repeat-mode", + type=str, + default="random", + help="The mode to repeat prompts. The supported " + 'modes are "random", "tile", and "interleave". ' + "See repeat_prompts() in the source code for details.", + ) + + parser.add_argument( + "--shuffle-seed", + type=int, + default=0, + help='Random seed when the repeat mode is "random"', + ) parser = EngineArgs.add_cli_args(parser) args = parser.parse_args() diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index f44da95d3216..ffaa8035797c 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -63,8 +63,7 @@ class Request: output_len: int -def sample_tokens(tokenizer: PreTrainedTokenizerBase, - length: int) -> list[int]: +def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> list[int]: vocab = tokenizer.get_vocab() all_special_ids = set(tokenizer.all_special_ids) @@ -91,8 +90,10 @@ def sample_requests_from_dataset( # Filter out the conversations with less than 2 turns. dataset = [data for data in dataset if len(data["conversations"]) >= 2] # Only keep the first two turns of each conversation. 
- dataset = [(data["conversations"][0]["value"], - data["conversations"][1]["value"]) for data in dataset] + dataset = [ + (data["conversations"][0]["value"], data["conversations"][1]["value"]) + for data in dataset + ] # Shuffle the dataset. random.shuffle(dataset) @@ -113,8 +114,9 @@ def sample_requests_from_dataset( completion = dataset[i][1] completion_token_ids = tokenizer(completion).input_ids prompt_len = len(prompt_token_ids) - output_len = (len(completion_token_ids) - if fixed_output_len is None else fixed_output_len) + output_len = ( + len(completion_token_ids) if fixed_output_len is None else fixed_output_len + ) if min_len <= prompt_len <= max_len: filtered_requests.append(Request(prompt, prompt_len, output_len)) @@ -128,27 +130,27 @@ def sample_requests_from_random( fixed_output_len: Optional[int], prefix_len: int, ) -> list[Request]: - requests = [] prefix_token_ids = sample_tokens(tokenizer, prefix_len) min_len, max_len = input_length_range for i in range(num_requests): unique_part_token_ids = sample_tokens( - tokenizer, - random.randint(min_len - prefix_len, max_len - prefix_len)) + tokenizer, random.randint(min_len - prefix_len, max_len - prefix_len) + ) prompt_token_ids = prefix_token_ids + unique_part_token_ids prompt = tokenizer.decode(prompt_token_ids) prompt_len = len(prompt_token_ids) - assert (min_len <= prompt_len <= max_len - ), f"prompt_len {prompt_len} out of range {min_len}:{max_len}" + assert min_len <= prompt_len <= max_len, ( + f"prompt_len {prompt_len} out of range {min_len}:{max_len}" + ) requests.append(Request(prompt, prompt_len, fixed_output_len)) return requests -def repeat_and_sort_requests(requests: list[Request], - repeat_count: int, - sort: bool = False) -> list[str]: +def repeat_and_sort_requests( + requests: list[Request], repeat_count: int, sort: bool = False +) -> list[str]: repeated_requests = requests * repeat_count if sort: repeated_requests.sort(key=lambda x: x[1]) @@ -159,14 +161,14 @@ def 
repeat_and_sort_requests(requests: list[Request], def main(args): tokenizer = get_tokenizer(args.model, trust_remote_code=True) - input_length_range = tuple(map(int, args.input_length_range.split(':'))) + input_length_range = tuple(map(int, args.input_length_range.split(":"))) random.seed(args.seed) if args.dataset_path is not None: if args.prefix_len > 0: - raise ValueError("prefix-len is not supported when " - "dataset-path is provided.") - print(f"Start to sample {args.num_prompts} prompts " - f"from {args.dataset_path}") + raise ValueError( + "prefix-len is not supported when dataset-path is provided." + ) + print(f"Start to sample {args.num_prompts} prompts from {args.dataset_path}") filtered_requests = sample_requests_from_dataset( dataset_path=args.dataset_path, num_requests=args.num_prompts, @@ -196,14 +198,16 @@ def main(args): llm = LLM(**dataclasses.asdict(engine_args)) - sampling_params = SamplingParams(temperature=0, - max_tokens=args.output_len, - detokenize=not args.disable_detokenize) + sampling_params = SamplingParams( + temperature=0, + max_tokens=args.output_len, + detokenize=not args.disable_detokenize, + ) print("Testing filtered requests") - prompts = repeat_and_sort_requests(filtered_requests, - repeat_count=args.repeat_count, - sort=args.sort) + prompts = repeat_and_sort_requests( + filtered_requests, repeat_count=args.repeat_count, sort=args.sort + ) print("------start generating------") test_prefix( @@ -215,29 +219,35 @@ def main(args): if __name__ == "__main__": parser = FlexibleArgumentParser( - description= - 'Benchmark the performance with or without automatic prefix caching.') - parser.add_argument("--dataset-path", - type=str, - default=None, - help="Path to the dataset.") - parser.add_argument('--output-len', type=int, default=10) - parser.add_argument('--num-prompts', - type=int, - required=True, - help="Number of the prompts sampled from dataset") - parser.add_argument('--repeat-count', - type=int, - default=1, - help='Number of 
times to repeat each prompt') - parser.add_argument('--sort', - action='store_true', - help='Sort prompts by input length') - parser.add_argument('--input-length-range', - type=str, - required=True, - help='Range of input lengths for sampling prompts,' - 'specified as "min:max" (e.g., "128:256").') + description="Benchmark the performance with or without " + "automatic prefix caching." + ) + parser.add_argument( + "--dataset-path", type=str, default=None, help="Path to the dataset." + ) + parser.add_argument("--output-len", type=int, default=10) + parser.add_argument( + "--num-prompts", + type=int, + required=True, + help="Number of the prompts sampled from dataset", + ) + parser.add_argument( + "--repeat-count", + type=int, + default=1, + help="Number of times to repeat each prompt", + ) + parser.add_argument( + "--sort", action="store_true", help="Sort prompts by input length" + ) + parser.add_argument( + "--input-length-range", + type=str, + required=True, + help="Range of input lengths for sampling prompts," + 'specified as "min:max" (e.g., "128:256").', + ) parser.add_argument( "--prefix-len", type=int, @@ -248,10 +258,12 @@ def main(args): "when dataset-path is not provided.", ) parser.add_argument( - '--disable-detokenize', - action='store_true', - help=("Do not detokenize responses (i.e. do not include " - "detokenization time in the latency measurement)"), + "--disable-detokenize", + action="store_true", + help=( + "Do not detokenize responses (i.e. 
do not include " + "detokenization time in the latency measurement)" + ), ) parser = EngineArgs.add_cli_args(parser) diff --git a/benchmarks/benchmark_prioritization.py b/benchmarks/benchmark_prioritization.py index 76fe00ede249..a05dd24dece8 100644 --- a/benchmarks/benchmark_prioritization.py +++ b/benchmarks/benchmark_prioritization.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """Benchmark offline prioritization.""" + import argparse import dataclasses import json @@ -13,7 +14,7 @@ from vllm.utils import FlexibleArgumentParser -#Select a equi-probable random priority +# Select a equi-probable random priority def get_random_flag(): return 0 if random.random() < 0.5 else 1 @@ -33,8 +34,10 @@ def sample_requests( # Filter out the conversations with less than 2 turns. dataset = [data for data in dataset if len(data["conversations"]) >= 2] # Only keep the first two turns of each conversation. - dataset = [(data["conversations"][0]["value"], - data["conversations"][1]["value"]) for data in dataset] + dataset = [ + (data["conversations"][0]["value"], data["conversations"][1]["value"]) + for data in dataset + ] # Shuffle the dataset. random.shuffle(dataset) @@ -51,8 +54,9 @@ def sample_requests( completion = dataset[i][1] completion_token_ids = tokenizer(completion).input_ids prompt_len = len(prompt_token_ids) - output_len = len(completion_token_ids - ) if fixed_output_len is None else fixed_output_len + output_len = ( + len(completion_token_ids) if fixed_output_len is None else fixed_output_len + ) if prompt_len < 4 or output_len < 4: # Prune too short sequences. 
continue @@ -74,13 +78,16 @@ def run_vllm( disable_detokenize: bool = False, ) -> float: from vllm import LLM, SamplingParams + llm = LLM(**dataclasses.asdict(engine_args)) assert all( llm.llm_engine.model_config.max_model_len >= (request[1] + request[2]) - for request in requests), ( - "Please ensure that max_model_len is greater than the sum of" - " input_len and output_len for all requests.") + for request in requests + ), ( + "Please ensure that max_model_len is greater than the sum of" + " input_len and output_len for all requests." + ) # Add the requests to the engine. prompts = [] @@ -97,7 +104,8 @@ def run_vllm( ignore_eos=True, max_tokens=output_len, detokenize=not disable_detokenize, - )) + ) + ) start = time.perf_counter() llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True) @@ -111,26 +119,33 @@ def main(args: argparse.Namespace): # Sample the requests. tokenizer = AutoTokenizer.from_pretrained( - args.tokenizer, trust_remote_code=args.trust_remote_code) + args.tokenizer, trust_remote_code=args.trust_remote_code + ) if args.dataset is None: # Synthesize a prompt with the given input length. 
prompt = "hi" * (args.input_len - 1) - requests = [(prompt, args.input_len, args.output_len, - get_random_flag()) for _ in range(args.num_prompts)] + requests = [ + (prompt, args.input_len, args.output_len, get_random_flag()) + for _ in range(args.num_prompts) + ] else: - requests = sample_requests(args.dataset, args.num_prompts, tokenizer, - args.output_len) + requests = sample_requests( + args.dataset, args.num_prompts, tokenizer, args.output_len + ) if args.backend == "vllm": - elapsed_time = run_vllm(requests, args.n, - EngineArgs.from_cli_args(args), - args.disable_detokenize) + elapsed_time = run_vllm( + requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize + ) else: raise ValueError(f"Unknown backend: {args.backend}") - total_num_tokens = sum(prompt_len + output_len - for _, prompt_len, output_len, priority in requests) - print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " - f"{total_num_tokens / elapsed_time:.2f} tokens/s") + total_num_tokens = sum( + prompt_len + output_len for _, prompt_len, output_len, priority in requests + ) + print( + f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " + f"{total_num_tokens / elapsed_time:.2f} tokens/s" + ) # Output JSON results if specified if args.output_json: @@ -147,41 +162,44 @@ def main(args: argparse.Namespace): if __name__ == "__main__": parser = FlexibleArgumentParser(description="Benchmark the throughput.") - parser.add_argument("--backend", - type=str, - choices=["vllm", "hf", "mii"], - default="vllm") - parser.add_argument("--dataset", - type=str, - default=None, - help="Path to the dataset.") - parser.add_argument("--input-len", - type=int, - default=None, - help="Input prompt length for each request") - parser.add_argument("--output-len", - type=int, - default=None, - help="Output length for each request. 
Overrides the " - "output length from the dataset.") - parser.add_argument("--n", - type=int, - default=1, - help="Number of generated sequences per prompt.") - parser.add_argument("--num-prompts", - type=int, - default=200, - help="Number of prompts to process.") parser.add_argument( - '--output-json', + "--backend", type=str, choices=["vllm", "hf", "mii"], default="vllm" + ) + parser.add_argument( + "--dataset", type=str, default=None, help="Path to the dataset." + ) + parser.add_argument( + "--input-len", + type=int, + default=None, + help="Input prompt length for each request", + ) + parser.add_argument( + "--output-len", + type=int, + default=None, + help="Output length for each request. Overrides the " + "output length from the dataset.", + ) + parser.add_argument( + "--n", type=int, default=1, help="Number of generated sequences per prompt." + ) + parser.add_argument( + "--num-prompts", type=int, default=200, help="Number of prompts to process." + ) + parser.add_argument( + "--output-json", type=str, default=None, - help='Path to save the throughput results in JSON format.') + help="Path to save the throughput results in JSON format.", + ) parser.add_argument( - '--disable-detokenize', - action='store_true', - help=("Do not detokenize responses (i.e. do not include " - "detokenization time in the latency measurement)"), + "--disable-detokenize", + action="store_true", + help=( + "Do not detokenize responses (i.e. do not include " + "detokenization time in the latency measurement)" + ), ) parser = EngineArgs.add_cli_args(parser) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 89fb0e1df035..a887e7150dc7 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -20,6 +20,7 @@ --endpoint /generate_stream to the end of the command above. 
""" + import argparse import asyncio import gc @@ -34,12 +35,16 @@ from typing import Any, Optional import numpy as np -from backend_request_func import (ASYNC_REQUEST_FUNCS, - OPENAI_COMPATIBLE_BACKENDS, RequestFuncInput, - RequestFuncOutput) from tqdm.asyncio import tqdm from transformers import PreTrainedTokenizerBase +from backend_request_func import ( + ASYNC_REQUEST_FUNCS, + OPENAI_COMPATIBLE_BACKENDS, + RequestFuncInput, + RequestFuncOutput, +) + try: from vllm.transformers_utils.tokenizer import get_tokenizer except ImportError: @@ -50,12 +55,21 @@ except ImportError: from argparse import ArgumentParser as FlexibleArgumentParser -from benchmark_dataset import (AIMODataset, ASRDataset, BurstGPTDataset, - ConversationDataset, HuggingFaceDataset, - InstructCoderDataset, MTBenchDataset, - NextEditPredictionDataset, RandomDataset, - SampleRequest, ShareGPTDataset, SonnetDataset, - VisionArenaDataset) +from benchmark_dataset import ( + AIMODataset, + ASRDataset, + BurstGPTDataset, + ConversationDataset, + HuggingFaceDataset, + InstructCoderDataset, + MTBenchDataset, + NextEditPredictionDataset, + RandomDataset, + SampleRequest, + ShareGPTDataset, + SonnetDataset, + VisionArenaDataset, +) from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json MILLISECONDS_TO_SECONDS_CONVERSION = 1000 @@ -118,7 +132,8 @@ async def get_request( # Calculate scale parameter theta to maintain the desired request_rate. assert burstiness > 0, ( - f"A positive burstiness factor is expected, but given {burstiness}.") + f"A positive burstiness factor is expected, but given {burstiness}." 
+ ) theta = 1.0 / (request_rate * burstiness) for request in input_requests: @@ -164,8 +179,10 @@ def calculate_metrics( # bundled together # Note : this may inflate the output token count slightly output_len = len( - tokenizer(outputs[i].generated_text, - add_special_tokens=False).input_ids) + tokenizer( + outputs[i].generated_text, add_special_tokens=False + ).input_ids + ) actual_output_lens.append(output_len) total_input += input_requests[i].prompt_len tpot = 0 @@ -188,16 +205,19 @@ def calculate_metrics( if "ttft" in goodput_config_dict: valid_metrics.append(ttfts) - slo_values.append(goodput_config_dict["ttft"] / - MILLISECONDS_TO_SECONDS_CONVERSION) + slo_values.append( + goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION + ) if "tpot" in goodput_config_dict: valid_metrics.append(all_tpots) - slo_values.append(goodput_config_dict["tpot"] / - MILLISECONDS_TO_SECONDS_CONVERSION) + slo_values.append( + goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION + ) if "e2el" in goodput_config_dict: valid_metrics.append(e2els) - slo_values.append(goodput_config_dict["e2el"] / - MILLISECONDS_TO_SECONDS_CONVERSION) + slo_values.append( + goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION + ) for req_metric in zip(*valid_metrics): is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)]) @@ -208,7 +228,8 @@ def calculate_metrics( warnings.warn( "All requests failed. 
This is likely due to a misconfiguration " "on the benchmark arguments.", - stacklevel=2) + stacklevel=2, + ) metrics = BenchmarkMetrics( completed=completed, total_input=total_input, @@ -217,27 +238,31 @@ def calculate_metrics( request_goodput=good_completed / dur_s, output_throughput=sum(actual_output_lens) / dur_s, total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, - mean_ttft_ms=np.mean(ttfts or 0) * - 1000, # ttfts is empty if streaming is not supported by backend + mean_ttft_ms=np.mean(ttfts or 0) + * 1000, # ttfts is empty if streaming is not supported by backend std_ttft_ms=np.std(ttfts or 0) * 1000, median_ttft_ms=np.median(ttfts or 0) * 1000, - percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) - for p in selected_percentiles], + percentiles_ttft_ms=[ + (p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles + ], mean_tpot_ms=np.mean(tpots or 0) * 1000, std_tpot_ms=np.std(tpots or 0) * 1000, median_tpot_ms=np.median(tpots or 0) * 1000, - percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) - for p in selected_percentiles], + percentiles_tpot_ms=[ + (p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles + ], mean_itl_ms=np.mean(itls or 0) * 1000, std_itl_ms=np.std(itls or 0) * 1000, median_itl_ms=np.median(itls or 0) * 1000, - percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) - for p in selected_percentiles], + percentiles_itl_ms=[ + (p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles + ], mean_e2el_ms=np.mean(e2els or 0) * 1000, std_e2el_ms=np.std(e2els or 0) * 1000, median_e2el_ms=np.median(e2els or 0) * 1000, - percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) - for p in selected_percentiles], + percentiles_e2el_ms=[ + (p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles + ], ) return metrics, actual_output_lens @@ -270,10 +295,12 @@ async def benchmark( raise ValueError(f"Unknown backend: {backend}") print("Starting initial 
single prompt test run...") - test_prompt, test_prompt_len, test_output_len, test_mm_content = \ - input_requests[0].prompt, input_requests[0].prompt_len, \ - input_requests[0].expected_output_len, \ - input_requests[0].multi_modal_data + test_prompt, test_prompt_len, test_output_len, test_mm_content = ( + input_requests[0].prompt, + input_requests[0].prompt_len, + input_requests[0].expected_output_len, + input_requests[0].multi_modal_data, + ) assert test_mm_content is None or isinstance(test_mm_content, dict) test_input = RequestFuncInput( @@ -293,36 +320,36 @@ async def benchmark( if not test_output.success: raise ValueError( "Initial test run failed - Please make sure benchmark arguments " - f"are correctly specified. Error: {test_output.error}") + f"are correctly specified. Error: {test_output.error}" + ) else: print("Initial test run completed. Starting main benchmark run...") if lora_modules: # For each input request, choose a LoRA module at random. lora_modules = iter( - [random.choice(lora_modules) \ - for _ in range(len(input_requests))]) + [random.choice(lora_modules) for _ in range(len(input_requests))] + ) if profile: print("Starting profiler...") - profile_input = RequestFuncInput(model=model_id, - model_name=model_name, - prompt=test_prompt, - api_url=base_url + "/start_profile", - prompt_len=test_prompt_len, - output_len=test_output_len, - logprobs=logprobs, - multi_modal_content=test_mm_content, - ignore_eos=ignore_eos, - extra_body=extra_body) + profile_input = RequestFuncInput( + model=model_id, + model_name=model_name, + prompt=test_prompt, + api_url=base_url + "/start_profile", + prompt_len=test_prompt_len, + output_len=test_output_len, + logprobs=logprobs, + multi_modal_content=test_mm_content, + ignore_eos=ignore_eos, + extra_body=extra_body, + ) profile_output = await request_func(request_func_input=profile_input) if profile_output.success: print("Profiler started") - if burstiness == 1.0: - distribution = "Poisson process" - else: - 
distribution = "Gamma distribution" + distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution" print(f"Traffic request rate: {request_rate}") print(f"Burstiness factor: {burstiness} ({distribution})") @@ -334,42 +361,45 @@ async def benchmark( # and it will simplify the code in limited_request_func. # semaphore = (asyncio.Semaphore(max_concurrency) # if max_concurrency else contextlib.nullcontext()) - semaphore = (asyncio.Semaphore(max_concurrency) - if max_concurrency else None) + semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None async def limited_request_func(request_func_input, pbar): if semaphore is None: - return await request_func(request_func_input=request_func_input, - pbar=pbar) + return await request_func(request_func_input=request_func_input, pbar=pbar) async with semaphore: - return await request_func(request_func_input=request_func_input, - pbar=pbar) + return await request_func(request_func_input=request_func_input, pbar=pbar) benchmark_start_time = time.perf_counter() tasks: list[asyncio.Task] = [] async for request in get_request(input_requests, request_rate, burstiness): - prompt, prompt_len, output_len, mm_content = request.prompt, \ - request.prompt_len, request.expected_output_len, \ - request.multi_modal_data + prompt, prompt_len, output_len, mm_content = ( + request.prompt, + request.prompt_len, + request.expected_output_len, + request.multi_modal_data, + ) req_model_id, req_model_name = model_id, model_name if lora_modules: req_lora_module = next(lora_modules) req_model_id, req_model_name = req_lora_module, req_lora_module - request_func_input = RequestFuncInput(model=req_model_id, - model_name=req_model_name, - prompt=prompt, - api_url=api_url, - prompt_len=prompt_len, - output_len=output_len, - logprobs=logprobs, - multi_modal_content=mm_content, - ignore_eos=ignore_eos, - extra_body=extra_body) + request_func_input = RequestFuncInput( + model=req_model_id, + model_name=req_model_name, + 
prompt=prompt, + api_url=api_url, + prompt_len=prompt_len, + output_len=output_len, + logprobs=logprobs, + multi_modal_content=mm_content, + ignore_eos=ignore_eos, + extra_body=extra_body, + ) tasks.append( asyncio.create_task( - limited_request_func(request_func_input=request_func_input, - pbar=pbar))) + limited_request_func(request_func_input=request_func_input, pbar=pbar) + ) + ) outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) if profile: @@ -401,22 +431,32 @@ async def limited_request_func(request_func_input, pbar): goodput_config_dict=goodput_config_dict, ) - print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) + print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="=")) print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) - print("{:<40} {:<10.2f}".format("Benchmark duration (s):", - benchmark_duration)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) - print("{:<40} {:<10}".format("Total generated tokens:", - metrics.total_output)) - print("{:<40} {:<10.2f}".format("Request throughput (req/s):", - metrics.request_throughput)) + print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output)) + print( + "{:<40} {:<10.2f}".format( + "Request throughput (req/s):", metrics.request_throughput + ) + ) if goodput_config_dict: - print("{:<40} {:<10.2f}".format("Request goodput (req/s):", - metrics.request_goodput)) - print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", - metrics.output_throughput)) - print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", - metrics.total_token_throughput)) + print( + "{:<40} {:<10.2f}".format( + "Request goodput (req/s):", metrics.request_goodput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Output token throughput (tok/s):", metrics.output_throughput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Total Token 
throughput (tok/s):", metrics.total_token_throughput + ) + ) result = { "duration": benchmark_duration, @@ -424,8 +464,7 @@ async def limited_request_func(request_func_input, pbar): "total_input_tokens": metrics.total_input, "total_output_tokens": metrics.total_output, "request_throughput": metrics.request_throughput, - "request_goodput:": - metrics.request_goodput if goodput_config_dict else None, + "request_goodput:": metrics.request_goodput if goodput_config_dict else None, "output_throughput": metrics.output_throughput, "total_token_throughput": metrics.total_token_throughput, "input_lens": [output.prompt_len for output in outputs], @@ -448,29 +487,35 @@ def process_one_metric( # metric. if metric_attribute_name not in selected_percentile_metrics: return - print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) - print("{:<40} {:<10.2f}".format( - f"Mean {metric_name} (ms):", - getattr(metrics, f"mean_{metric_attribute_name}_ms"))) - print("{:<40} {:<10.2f}".format( - f"Median {metric_name} (ms):", - getattr(metrics, f"median_{metric_attribute_name}_ms"))) + print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-")) + print( + "{:<40} {:<10.2f}".format( + f"Mean {metric_name} (ms):", + getattr(metrics, f"mean_{metric_attribute_name}_ms"), + ) + ) + print( + "{:<40} {:<10.2f}".format( + f"Median {metric_name} (ms):", + getattr(metrics, f"median_{metric_attribute_name}_ms"), + ) + ) result[f"mean_{metric_attribute_name}_ms"] = getattr( - metrics, f"mean_{metric_attribute_name}_ms") + metrics, f"mean_{metric_attribute_name}_ms" + ) result[f"median_{metric_attribute_name}_ms"] = getattr( - metrics, f"median_{metric_attribute_name}_ms") + metrics, f"median_{metric_attribute_name}_ms" + ) result[f"std_{metric_attribute_name}_ms"] = getattr( - metrics, f"std_{metric_attribute_name}_ms") - for p, value in getattr(metrics, - f"percentiles_{metric_attribute_name}_ms"): + metrics, f"std_{metric_attribute_name}_ms" + ) + for p, value in getattr(metrics, 
f"percentiles_{metric_attribute_name}_ms"): p_word = str(int(p)) if int(p) == p else str(p) - print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", - value)) + print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value)) result[f"p{p_word}_{metric_attribute_name}_ms"] = value process_one_metric("ttft", "TTFT", "Time to First Token") - process_one_metric("tpot", "TPOT", - "Time per Output Token (excl. 1st token)") + process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)") process_one_metric("itl", "ITL", "Inter-token Latency") process_one_metric("e2el", "E2EL", "End-to-end Latency") @@ -490,12 +535,14 @@ def check_goodput_args(args): raise ValueError( f"Invalid metric name found, {slo_name}: {slo_val}. " "The service level objective name should be one of " - f"{str(VALID_NAMES)}. ") + f"{str(VALID_NAMES)}. " + ) if slo_val < 0: raise ValueError( f"Invalid value found, {slo_name}: {slo_val}. " "The service level objective value should be " - "non-negative.") + "non-negative." + ) return goodput_config_dict @@ -508,31 +555,42 @@ def parse_goodput(slo_pairs): except ValueError as err: raise argparse.ArgumentTypeError( "Invalid format found for service level objectives. " - "Specify service level objectives for goodput as \"KEY:VALUE\" " + 'Specify service level objectives for goodput as "KEY:VALUE" ' "pairs, where the key is a metric name, and the value is a " - "number in milliseconds.") from err + "number in milliseconds." 
+ ) from err return goodput_config_dict -def save_to_pytorch_benchmark_format(args: argparse.Namespace, - results: dict[str, Any], - file_name: str) -> None: +def save_to_pytorch_benchmark_format( + args: argparse.Namespace, results: dict[str, Any], file_name: str +) -> None: metrics = [ - "median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms", - "mean_tpot_ms", "median_tpot_ms", "std_tpot_ms", "p99_tpot_ms", - "median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms" + "median_ttft_ms", + "mean_ttft_ms", + "std_ttft_ms", + "p99_ttft_ms", + "mean_tpot_ms", + "median_tpot_ms", + "std_tpot_ms", + "p99_tpot_ms", + "median_itl_ms", + "mean_itl_ms", + "std_itl_ms", + "p99_itl_ms", ] # These raw data might be useful, but they are rather big. They can be added # later if needed ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"] pt_records = convert_to_pytorch_benchmark_format( args=args, - metrics={k: [results[k]] - for k in metrics}, + metrics={k: [results[k]] for k in metrics}, extra_info={ k: results[k] - for k in results if k not in metrics and k not in ignored_metrics - }) + for k in results + if k not in metrics and k not in ignored_metrics + }, + ) if pt_records: # Don't use json suffix here as we don't want CI to pick it up pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json" @@ -557,34 +615,42 @@ def main(args: argparse.Namespace): api_url = f"http://{args.host}:{args.port}{args.endpoint}" base_url = f"http://{args.host}:{args.port}" - tokenizer = get_tokenizer(tokenizer_id, - tokenizer_mode=tokenizer_mode, - trust_remote_code=args.trust_remote_code) + tokenizer = get_tokenizer( + tokenizer_id, + tokenizer_mode=tokenizer_mode, + trust_remote_code=args.trust_remote_code, + ) if args.dataset_name is None: raise ValueError( "Please specify '--dataset-name' and the corresponding " - "'--dataset-path' if required.") + "'--dataset-path' if required." 
+ ) if args.dataset_name == "sonnet": dataset = SonnetDataset(dataset_path=args.dataset_path) # For the "sonnet" dataset, formatting depends on the backend. if args.backend == "openai-chat": - input_requests = dataset.sample(num_requests=args.num_prompts, - input_len=args.sonnet_input_len, - output_len=args.sonnet_output_len, - prefix_len=args.sonnet_prefix_len, - tokenizer=tokenizer, - return_prompt_formatted=False) + input_requests = dataset.sample( + num_requests=args.num_prompts, + input_len=args.sonnet_input_len, + output_len=args.sonnet_output_len, + prefix_len=args.sonnet_prefix_len, + tokenizer=tokenizer, + return_prompt_formatted=False, + ) else: assert tokenizer.chat_template or tokenizer.default_chat_template, ( - "Tokenizer/model must have chat template for sonnet dataset.") - input_requests = dataset.sample(num_requests=args.num_prompts, - input_len=args.sonnet_input_len, - output_len=args.sonnet_output_len, - prefix_len=args.sonnet_prefix_len, - tokenizer=tokenizer, - return_prompt_formatted=True) + "Tokenizer/model must have chat template for sonnet dataset." + ) + input_requests = dataset.sample( + num_requests=args.num_prompts, + input_len=args.sonnet_input_len, + output_len=args.sonnet_output_len, + prefix_len=args.sonnet_prefix_len, + tokenizer=tokenizer, + return_prompt_formatted=True, + ) elif args.dataset_name == "hf": # all following datasets are implemented from the @@ -611,23 +677,30 @@ def main(args: argparse.Namespace): dataset_class = ASRDataset args.hf_split = "train" else: - supported_datasets = set([ - dataset_name for cls in HuggingFaceDataset.__subclasses__() - for dataset_name in cls.SUPPORTED_DATASET_PATHS - ]) + supported_datasets = set( + [ + dataset_name + for cls in HuggingFaceDataset.__subclasses__() + for dataset_name in cls.SUPPORTED_DATASET_PATHS + ] + ) raise ValueError( f"Unsupported dataset path: {args.dataset_path}. " "Huggingface dataset only supports dataset_path" f" from one of following: {supported_datasets}. 
" "Please consider contributing if you would " - "like to add support for additional dataset formats.") + "like to add support for additional dataset formats." + ) - if (dataset_class.IS_MULTIMODAL and backend not in \ - ["openai-chat", "openai-audio"]): + if dataset_class.IS_MULTIMODAL and backend not in [ + "openai-chat", + "openai-audio", + ]: # multi-modal benchmark is only available on OpenAI Chat backend. raise ValueError( - "Multi-modal content is only supported on 'openai-chat' and " \ - "'openai-audio' backend.") + "Multi-modal content is only supported on 'openai-chat' and " + "'openai-audio' backend." + ) input_requests = dataset_class( dataset_path=args.dataset_path, dataset_subset=args.hf_subset, @@ -642,26 +715,24 @@ def main(args: argparse.Namespace): else: # For datasets that follow a similar structure, use a mapping. dataset_mapping = { - "sharegpt": - lambda: ShareGPTDataset(random_seed=args.seed, - dataset_path=args.dataset_path).sample( - tokenizer=tokenizer, - num_requests=args.num_prompts, - output_len=args.sharegpt_output_len, - ), - "burstgpt": - lambda: BurstGPTDataset(random_seed=args.seed, - dataset_path=args.dataset_path). 
- sample(tokenizer=tokenizer, num_requests=args.num_prompts), - "random": - lambda: RandomDataset(dataset_path=args.dataset_path).sample( + "sharegpt": lambda: ShareGPTDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + output_len=args.sharegpt_output_len, + ), + "burstgpt": lambda: BurstGPTDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample(tokenizer=tokenizer, num_requests=args.num_prompts), + "random": lambda: RandomDataset(dataset_path=args.dataset_path).sample( tokenizer=tokenizer, num_requests=args.num_prompts, prefix_len=args.random_prefix_len, input_len=args.random_input_len, output_len=args.random_output_len, range_ratio=args.random_range_ratio, - ) + ), } try: @@ -677,15 +748,16 @@ def main(args: argparse.Namespace): "top_p": args.top_p, "top_k": args.top_k, "min_p": args.min_p, - "temperature": args.temperature - }.items() if v is not None + "temperature": args.temperature, + }.items() + if v is not None } # Sampling parameters are only supported by openai-compatible backend. if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS: raise ValueError( - "Sampling parameters are only supported by openai-compatible " - "backends.") + "Sampling parameters are only supported by openai-compatible backends." + ) if "temperature" not in sampling_params: sampling_params["temperature"] = 0.0 # Default to greedy decoding. 
@@ -709,15 +781,14 @@ def main(args: argparse.Namespace): disable_tqdm=args.disable_tqdm, profile=args.profile, selected_percentile_metrics=args.percentile_metrics.split(","), - selected_percentiles=[ - float(p) for p in args.metric_percentiles.split(",") - ], + selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")], ignore_eos=args.ignore_eos, goodput_config_dict=goodput_config_dict, max_concurrency=args.max_concurrency, lora_modules=args.lora_modules, extra_body=sampling_params, - )) + ) + ) # Save config and results to json if args.save_result or args.append_result: @@ -742,8 +813,9 @@ def main(args: argparse.Namespace): "Invalid metadata format. Please use KEY=VALUE format." ) # Traffic - result_json["request_rate"] = (args.request_rate if args.request_rate - < float("inf") else "inf") + result_json["request_rate"] = ( + args.request_rate if args.request_rate < float("inf") else "inf" + ) result_json["burstiness"] = args.burstiness result_json["max_concurrency"] = args.max_concurrency @@ -753,24 +825,31 @@ def main(args: argparse.Namespace): if not args.save_detailed: # Remove fields with too many data points for field in [ - "input_lens", "output_lens", "ttfts", "itls", - "generated_texts", "errors" + "input_lens", + "output_lens", + "ttfts", + "itls", + "generated_texts", + "errors", ]: if field in result_json: del result_json[field] # Save to file base_model_id = model_id.split("/")[-1] - max_concurrency_str = (f"-concurrency{args.max_concurrency}" - if args.max_concurrency is not None else "") - file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" #noqa + max_concurrency_str = ( + f"-concurrency{args.max_concurrency}" + if args.max_concurrency is not None + else "" + ) + file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa if args.result_filename: file_name = args.result_filename if args.result_dir: file_name = 
os.path.join(args.result_dir, file_name) - with open(file_name, - mode="a+" if args.append_result else "w", - encoding='utf-8') as outfile: + with open( + file_name, mode="a+" if args.append_result else "w", encoding="utf-8" + ) as outfile: # Append a newline. if args.append_result and outfile.tell() != 0: outfile.write("\n") @@ -780,7 +859,8 @@ def main(args: argparse.Namespace): if __name__ == "__main__": parser = FlexibleArgumentParser( - description="Benchmark the online serving throughput.") + description="Benchmark the online serving throughput." + ) parser.add_argument( "--backend", type=str, @@ -809,11 +889,13 @@ def main(args: argparse.Namespace): choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"], help="Name of the dataset to benchmark on.", ) - parser.add_argument("--dataset-path", - type=str, - default=None, - help="Path to the sharegpt/sonnet dataset. " - "Or the huggingface dataset ID if using HF dataset.") + parser.add_argument( + "--dataset-path", + type=str, + default=None, + help="Path to the sharegpt/sonnet dataset. " + "Or the huggingface dataset ID if using HF dataset.", + ) parser.add_argument( "--max-concurrency", type=int, @@ -825,7 +907,8 @@ def main(args: argparse.Namespace): "initiated, this argument will control how many are actually allowed " "to execute at a time. 
This means that when used in combination, the " "actual request rate may be lower than specified with --request-rate, " - "if the server is not processing requests fast enough to keep up.") + "if the server is not processing requests fast enough to keep up.", + ) parser.add_argument( "--model", @@ -836,8 +919,7 @@ def main(args: argparse.Namespace): parser.add_argument( "--tokenizer", type=str, - help= - "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 + help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 ) parser.add_argument("--use-beam-search", action="store_true") parser.add_argument( @@ -850,11 +932,13 @@ def main(args: argparse.Namespace): "--logprobs", type=int, default=None, - help=("Number of logprobs-per-token to compute & return as part of " - "the request. If unspecified, then either (1) if beam search " - "is disabled, no logprobs are computed & a single dummy " - "logprob is returned for each token; or (2) if beam search " - "is enabled 1 logprob per token is computed"), + help=( + "Number of logprobs-per-token to compute & return as part of " + "the request. If unspecified, then either (1) if beam search " + "is disabled, no logprobs are computed & a single dummy " + "logprob is returned for each token; or (2) if beam search " + "is enabled 1 logprob per token is computed" + ), ) parser.add_argument( "--request-rate", @@ -938,35 +1022,38 @@ def main(args: argparse.Namespace): "--ignore-eos", action="store_true", help="Set ignore_eos flag when sending the benchmark request." - "Warning: ignore_eos is not supported in deepspeed_mii and tgi.") + "Warning: ignore_eos is not supported in deepspeed_mii and tgi.", + ) parser.add_argument( "--percentile-metrics", type=str, default="ttft,tpot,itl", help="Comma-separated list of selected metrics to report percentils. " "This argument specifies the metrics to report percentiles. 
" - "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". " - "Default value is \"ttft,tpot,itl\".") + 'Allowed metric names are "ttft", "tpot", "itl", "e2el". ' + 'Default value is "ttft,tpot,itl".', + ) parser.add_argument( "--metric-percentiles", type=str, default="99", help="Comma-separated list of percentiles for selected metrics. " - "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". " - "Default value is \"99\". " - "Use \"--percentile-metrics\" to select metrics.", + 'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". ' + 'Default value is "99". ' + 'Use "--percentile-metrics" to select metrics.', ) parser.add_argument( "--goodput", nargs="+", required=False, - help="Specify service level objectives for goodput as \"KEY:VALUE\" " + help='Specify service level objectives for goodput as "KEY:VALUE" ' "pairs, where the key is a metric name, and the value is in " - "milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, " + 'milliseconds. Multiple "KEY:VALUE" pairs can be provided, ' "separated by spaces. Allowed request level metric names are " - "\"ttft\", \"tpot\", \"e2el\". For more context on the definition of " + '"ttft", "tpot", "e2el". 
For more context on the definition of ' "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " - "and the blog: https://hao-ai-lab.github.io/blogs/distserve") + "and the blog: https://hao-ai-lab.github.io/blogs/distserve", + ) # group for dataset specific arguments sonnet_group = parser.add_argument_group("sonnet dataset options") @@ -974,22 +1061,19 @@ def main(args: argparse.Namespace): "--sonnet-input-len", type=int, default=550, - help= - "Number of input tokens per request, used only for sonnet dataset.", + help="Number of input tokens per request, used only for sonnet dataset.", ) sonnet_group.add_argument( "--sonnet-output-len", type=int, default=150, - help= - "Number of output tokens per request, used only for sonnet dataset.", + help="Number of output tokens per request, used only for sonnet dataset.", ) sonnet_group.add_argument( "--sonnet-prefix-len", type=int, default=200, - help= - "Number of prefix tokens per request, used only for sonnet dataset.", + help="Number of prefix tokens per request, used only for sonnet dataset.", ) sharegpt_group = parser.add_argument_group("sharegpt dataset options") @@ -998,22 +1082,21 @@ def main(args: argparse.Namespace): type=int, default=None, help="Output length for each request. 
Overrides the output length " - "from the ShareGPT dataset.") + "from the ShareGPT dataset.", + ) random_group = parser.add_argument_group("random dataset options") random_group.add_argument( "--random-input-len", type=int, default=1024, - help= - "Number of input tokens per request, used only for random sampling.", + help="Number of input tokens per request, used only for random sampling.", ) random_group.add_argument( "--random-output-len", type=int, default=128, - help= - "Number of output tokens per request, used only for random sampling.", + help="Number of output tokens per request, used only for random sampling.", ) random_group.add_argument( "--random-range-ratio", @@ -1028,23 +1111,23 @@ def main(args: argparse.Namespace): "--random-prefix-len", type=int, default=0, - help=("Number of fixed prefix tokens before the random context " - "in a request. " - "The total input length is the sum of `random-prefix-len` and " - "a random " - "context length sampled from [input_len * (1 - range_ratio), " - "input_len * (1 + range_ratio)]."), + help=( + "Number of fixed prefix tokens before the random context " + "in a request. " + "The total input length is the sum of `random-prefix-len` and " + "a random " + "context length sampled from [input_len * (1 - range_ratio), " + "input_len * (1 + range_ratio)]." + ), ) hf_group = parser.add_argument_group("hf dataset options") - hf_group.add_argument("--hf-subset", - type=str, - default=None, - help="Subset of the HF dataset.") - hf_group.add_argument("--hf-split", - type=str, - default=None, - help="Split of the HF dataset.") + hf_group.add_argument( + "--hf-subset", type=str, default=None, help="Subset of the HF dataset." + ) + hf_group.add_argument( + "--hf-split", type=str, default=None, help="Split of the HF dataset." + ) hf_group.add_argument( "--hf-output-len", type=int, @@ -1058,52 +1141,58 @@ def main(args: argparse.Namespace): "--top-p", type=float, default=None, - help="Top-p sampling parameter. 
Only has effect on openai-compatible " - "backends.") + help="Top-p sampling parameter. Only has effect on openai-compatible backends.", + ) sampling_group.add_argument( "--top-k", type=int, default=None, - help="Top-k sampling parameter. Only has effect on openai-compatible " - "backends.") + help="Top-k sampling parameter. Only has effect on openai-compatible backends.", + ) sampling_group.add_argument( "--min-p", type=float, default=None, - help="Min-p sampling parameter. Only has effect on openai-compatible " - "backends.") + help="Min-p sampling parameter. Only has effect on openai-compatible backends.", + ) sampling_group.add_argument( "--temperature", type=float, default=None, help="Temperature sampling parameter. Only has effect on " "openai-compatible backends. If not specified, default to greedy " - "decoding (i.e. temperature==0.0).") + "decoding (i.e. temperature==0.0).", + ) parser.add_argument( - '--tokenizer-mode', + "--tokenizer-mode", type=str, default="auto", - choices=['auto', 'slow', 'mistral', 'custom'], + choices=["auto", "slow", "mistral", "custom"], help='The tokenizer mode.\n\n* "auto" will use the ' 'fast tokenizer if available.\n* "slow" will ' - 'always use the slow tokenizer. \n* ' + "always use the slow tokenizer. \n* " '"mistral" will always use the `mistral_common` tokenizer. \n*' - '"custom" will use --tokenizer to select the preregistered tokenizer.') - - parser.add_argument("--served-model-name", - type=str, - default=None, - help="The model name used in the API. " - "If not specified, the model name will be the " - "same as the ``--model`` argument. ") - - parser.add_argument("--lora-modules", - nargs='+', - default=None, - help="A subset of LoRA module names passed in when " - "launching the server. 
For each request, the " - "script chooses a LoRA module at random.") + '"custom" will use --tokenizer to select the preregistered tokenizer.', + ) + + parser.add_argument( + "--served-model-name", + type=str, + default=None, + help="The model name used in the API. " + "If not specified, the model name will be the " + "same as the ``--model`` argument. ", + ) + + parser.add_argument( + "--lora-modules", + nargs="+", + default=None, + help="A subset of LoRA module names passed in when " + "launching the server. For each request, the " + "script chooses a LoRA module at random.", + ) args = parser.parse_args() diff --git a/benchmarks/benchmark_serving_structured_output.py b/benchmarks/benchmark_serving_structured_output.py index 9084255d2440..6a50f47d3951 100644 --- a/benchmarks/benchmark_serving_structured_output.py +++ b/benchmarks/benchmark_serving_structured_output.py @@ -19,6 +19,7 @@ --endpoint /generate_stream to the end of the command above. """ + import argparse import asyncio import copy @@ -36,11 +37,15 @@ import datasets import numpy as np import pandas as pd -from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput, - RequestFuncOutput) from tqdm.asyncio import tqdm from transformers import PreTrainedTokenizerBase +from backend_request_func import ( + ASYNC_REQUEST_FUNCS, + RequestFuncInput, + RequestFuncOutput, +) + try: from vllm.transformers_utils.tokenizer import get_tokenizer except ImportError: @@ -52,7 +57,8 @@ from argparse import ArgumentParser as FlexibleArgumentParser from vllm.v1.structured_output.backend_xgrammar import ( - has_xgrammar_unsupported_json_features) + has_xgrammar_unsupported_json_features, +) MILLISECONDS_TO_SECONDS_CONVERSION = 1000 @@ -98,6 +104,7 @@ class SampleRequest: prompt_len: The length of the prompt in tokens. expected_output_len: The expected length of the output in tokens. 
""" + prompt: str prompt_len: int expected_output_len: int @@ -106,32 +113,28 @@ class SampleRequest: completion: str = None -def sample_requests(tokenizer: PreTrainedTokenizerBase, - args: argparse.Namespace) -> list[SampleRequest]: - if args.dataset == 'json' or args.dataset == 'json-unique': +def sample_requests( + tokenizer: PreTrainedTokenizerBase, args: argparse.Namespace +) -> list[SampleRequest]: + if args.dataset == "json" or args.dataset == "json-unique": if args.json_schema_path is None: dir_path = os.path.dirname(os.path.realpath(__file__)) - args.json_schema_path = os.path.join(dir_path, - "structured_schemas", - "structured_schema_1.json") + args.json_schema_path = os.path.join( + dir_path, "structured_schemas", "structured_schema_1.json" + ) json_schemas = [] with open(args.json_schema_path) as f: schema = json.load(f) - if args.dataset == 'json-unique': - json_schemas = [ - copy.deepcopy(schema) for _ in range(args.num_prompts) - ] + if args.dataset == "json-unique": + json_schemas = [copy.deepcopy(schema) for _ in range(args.num_prompts)] for i in range(len(json_schemas)): if "properties" not in json_schemas[i]: json_schemas[i]["properties"] = {} - json_schemas[i]["properties"][ - f"__optional_field_{uuid.uuid4()}"] = { - "type": - "string", - "description": - "An unique optional field to avoid cached schemas" - } + json_schemas[i]["properties"][f"__optional_field_{uuid.uuid4()}"] = { + "type": "string", + "description": "An unique optional field to avoid cached schemas", + } else: json_schemas = [schema] * args.num_prompts @@ -142,11 +145,13 @@ def get_schema(index: int): return json_schemas[index % len(json_schemas)] requests = [ - SampleRequest(prompt=gen_prompt(i), - prompt_len=len(tokenizer(gen_prompt(i)).input_ids), - expected_output_len=args.output_len, - schema=get_schema(i), - structure_type=args.structure_type) + SampleRequest( + prompt=gen_prompt(i), + prompt_len=len(tokenizer(gen_prompt(i)).input_ids), + 
expected_output_len=args.output_len, + schema=get_schema(i), + structure_type=args.structure_type, + ) for i in range(args.num_prompts) ] @@ -170,11 +175,13 @@ def get_schema(index: int): input_len = len(tokenizer(prompt).input_ids) print(f"Input length of the prompt: {input_len} tokens") requests = [ - SampleRequest(prompt=prompt, - prompt_len=input_len, - expected_output_len=args.output_len, - schema=schema, - structure_type=args.structure_type) + SampleRequest( + prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=schema, + structure_type=args.structure_type, + ) for _ in range(args.num_prompts) ] @@ -188,11 +195,13 @@ def get_schema(index: int): input_len = len(tokenizer(prompt).input_ids) print(f"Input length of the prompt: {input_len} tokens") requests = [ - SampleRequest(prompt=prompt, - prompt_len=input_len, - expected_output_len=args.output_len, - schema=regex, - structure_type=args.structure_type) + SampleRequest( + prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=regex, + structure_type=args.structure_type, + ) for _ in range(args.num_prompts) ] @@ -203,48 +212,55 @@ def get_schema(index: int): input_len = len(tokenizer(prompt).input_ids) print(f"Input length of the prompt: {input_len} tokens") requests = [ - SampleRequest(prompt=prompt, - prompt_len=input_len, - expected_output_len=args.output_len, - schema=choice, - structure_type=args.structure_type) + SampleRequest( + prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=choice, + structure_type=args.structure_type, + ) for _ in range(args.num_prompts) ] elif args.dataset == "xgrammar_bench": requests: list[SampleRequest] = [] - dataset = datasets.load_dataset("NousResearch/json-mode-eval", - split="train") + dataset = datasets.load_dataset("NousResearch/json-mode-eval", split="train") full_dataset_len = len(dataset) def _filter_func(item): import json + schema = json.loads(item["schema"]) return 
not has_xgrammar_unsupported_json_features(schema) dataset = dataset.filter(_filter_func) num_filtered_out = full_dataset_len - len(dataset) - print(f"dataset has {len(dataset)} entries after filtering " - f"out {num_filtered_out} entries with unsupported features") + print( + f"dataset has {len(dataset)} entries after filtering " + f"out {num_filtered_out} entries with unsupported features" + ) len_dataset = len(dataset) for data_point_idx in range(args.num_prompts): idx = data_point_idx while idx >= len_dataset: idx -= len_dataset schema = dataset["schema"][idx] - prompt = tokenizer.apply_chat_template(dataset["prompt"][idx], - tokenize=False, - add_generation_prompt=True) + prompt = tokenizer.apply_chat_template( + dataset["prompt"][idx], tokenize=False, add_generation_prompt=True + ) input_len = len(tokenizer(prompt).input_ids) completion = dataset["completion"][idx] requests.append( - SampleRequest(prompt=prompt, - prompt_len=input_len, - expected_output_len=args.output_len, - schema=schema, - structure_type=args.structure_type, - completion=completion)) + SampleRequest( + prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=schema, + structure_type=args.structure_type, + completion=completion, + ) + ) return requests @@ -276,7 +292,8 @@ async def get_request( # Calculate scale parameter theta to maintain the desired request_rate. assert burstiness > 0, ( - f"A positive burstiness factor is expected, but given {burstiness}.") + f"A positive burstiness factor is expected, but given {burstiness}." 
+ ) theta = 1.0 / (request_rate * burstiness) for i, request in enumerate(input_requests): @@ -318,8 +335,8 @@ def calculate_metrics( # multiple output tokens may be bundled together # Note : this may inflate the output token count slightly output_len = len( - tokenizer(outputs[i].generated_text, - add_special_tokens=False).input_ids) + tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids + ) actual_output_lens.append(output_len) total_input += input_requests[i].prompt_len tpot = 0 @@ -343,16 +360,19 @@ def calculate_metrics( if "ttft" in goodput_config_dict: valid_metrics.append(ttfts) - slo_values.append(goodput_config_dict["ttft"] / - MILLISECONDS_TO_SECONDS_CONVERSION) + slo_values.append( + goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION + ) if "tpot" in goodput_config_dict: valid_metrics.append(all_tpots) - slo_values.append(goodput_config_dict["tpot"] / - MILLISECONDS_TO_SECONDS_CONVERSION) + slo_values.append( + goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION + ) if "e2el" in goodput_config_dict: valid_metrics.append(e2els) - slo_values.append(goodput_config_dict["e2el"] / - MILLISECONDS_TO_SECONDS_CONVERSION) + slo_values.append( + goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION + ) for req_metric in zip(*valid_metrics): is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)]) @@ -363,7 +383,8 @@ def calculate_metrics( warnings.warn( "All requests failed. 
This is likely due to a misconfiguration " "on the benchmark arguments.", - stacklevel=2) + stacklevel=2, + ) metrics = BenchmarkMetrics( completed=completed, total_input=total_input, @@ -372,27 +393,31 @@ def calculate_metrics( request_goodput=good_completed / dur_s, output_throughput=sum(actual_output_lens) / dur_s, total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, - mean_ttft_ms=np.mean(ttfts or 0) * - 1000, # ttfts is empty if streaming is not supported by backend + mean_ttft_ms=np.mean(ttfts or 0) + * 1000, # ttfts is empty if streaming is not supported by backend std_ttft_ms=np.std(ttfts or 0) * 1000, median_ttft_ms=np.median(ttfts or 0) * 1000, - percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) - for p in selected_percentiles], + percentiles_ttft_ms=[ + (p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles + ], mean_tpot_ms=np.mean(tpots or 0) * 1000, std_tpot_ms=np.std(tpots or 0) * 1000, median_tpot_ms=np.median(tpots or 0) * 1000, - percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) - for p in selected_percentiles], + percentiles_tpot_ms=[ + (p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles + ], mean_itl_ms=np.mean(itls or 0) * 1000, std_itl_ms=np.std(itls or 0) * 1000, median_itl_ms=np.median(itls or 0) * 1000, - percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) - for p in selected_percentiles], + percentiles_itl_ms=[ + (p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles + ], mean_e2el_ms=np.mean(e2els or 0) * 1000, std_e2el_ms=np.std(e2els or 0) * 1000, median_e2el_ms=np.median(e2els or 0) * 1000, - percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) - for p in selected_percentiles], + percentiles_e2el_ms=[ + (p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles + ], ) return metrics, actual_output_lens @@ -429,12 +454,13 @@ def prepare_extra_body(request) -> dict: print("Starting initial single prompt test run...") 
structured_output_req_idx = random.sample( - range(len(input_requests)), - int(len(input_requests) * structured_output_ratio)) + range(len(input_requests)), int(len(input_requests) * structured_output_ratio) + ) test_request = input_requests[0] - test_req_extra_body = (prepare_extra_body(test_request) - if 0 in structured_output_req_idx else None) + test_req_extra_body = ( + prepare_extra_body(test_request) if 0 in structured_output_req_idx else None + ) test_input = RequestFuncInput( model=model_id, prompt=test_request.prompt, @@ -448,7 +474,8 @@ def prepare_extra_body(request) -> dict: if not test_output.success: raise ValueError( "Initial test run failed - Please make sure benchmark arguments " - f"are correctly specified. Error: {test_output.error}") + f"are correctly specified. Error: {test_output.error}" + ) else: print("Initial test run completed. Starting main benchmark run...") @@ -467,10 +494,7 @@ def prepare_extra_body(request) -> dict: if profile_output.success: print("Profiler started") - if burstiness == 1.0: - distribution = "Poisson process" - else: - distribution = "Gamma distribution" + distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution" print(f"Traffic request rate: {request_rate}") print(f"Burstiness factor: {burstiness} ({distribution})") @@ -482,24 +506,21 @@ def prepare_extra_body(request) -> dict: # and it will simplify the code in limited_request_func. 
# semaphore = (asyncio.Semaphore(max_concurrency) # if max_concurrency else contextlib.nullcontext()) - semaphore = (asyncio.Semaphore(max_concurrency) - if max_concurrency else None) + semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None async def limited_request_func(request_func_input, pbar): if semaphore is None: - return await request_func(request_func_input=request_func_input, - pbar=pbar) + return await request_func(request_func_input=request_func_input, pbar=pbar) async with semaphore: - return await request_func(request_func_input=request_func_input, - pbar=pbar) + return await request_func(request_func_input=request_func_input, pbar=pbar) benchmark_start_time = time.perf_counter() tasks: list[asyncio.Task] = [] expected: list[str] = [] - async for i, request in get_request(input_requests, request_rate, - burstiness): - extra_body = prepare_extra_body( - request) if i in structured_output_req_idx else None + async for i, request in get_request(input_requests, request_rate, burstiness): + extra_body = ( + prepare_extra_body(request) if i in structured_output_req_idx else None + ) request_func_input = RequestFuncInput( model=model_id, prompt=request.prompt, @@ -512,8 +533,9 @@ async def limited_request_func(request_func_input, pbar): expected.append(request.completion) tasks.append( asyncio.create_task( - limited_request_func(request_func_input=request_func_input, - pbar=pbar))) + limited_request_func(request_func_input=request_func_input, pbar=pbar) + ) + ) outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) if profile: @@ -545,54 +567,58 @@ async def limited_request_func(request_func_input, pbar): goodput_config_dict=goodput_config_dict, ) - print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) + print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="=")) print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) - print("{:<40} {:<10.2f}".format("Benchmark duration (s):", - 
benchmark_duration)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) - print("{:<40} {:<10}".format("Total generated tokens:", - metrics.total_output)) - print("{:<40} {:<10.2f}".format("Request throughput (req/s):", - metrics.request_throughput)) + print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output)) + print( + "{:<40} {:<10.2f}".format( + "Request throughput (req/s):", metrics.request_throughput + ) + ) if goodput_config_dict: - print("{:<40} {:<10.2f}".format("Request goodput (req/s):", - metrics.request_goodput)) - print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", - metrics.output_throughput)) - print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", - metrics.total_token_throughput)) + print( + "{:<40} {:<10.2f}".format( + "Request goodput (req/s):", metrics.request_goodput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Output token throughput (tok/s):", metrics.output_throughput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Total Token throughput (tok/s):", metrics.total_token_throughput + ) + ) result = { - "duration": - benchmark_duration, - "completed": - metrics.completed, - "total_input_tokens": - metrics.total_input, - "total_output_tokens": - metrics.total_output, - "request_throughput": - metrics.request_throughput, - "output_throughput": - metrics.output_throughput, - "total_token_throughput": - metrics.total_token_throughput, - "ttft_description": - pd.Series([output.ttft for output in outputs]).describe().to_dict(), - "tpot_description": - pd.Series([output.tpot for output in outputs]).describe().to_dict(), + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "total_output_tokens": metrics.total_output, + "request_throughput": metrics.request_throughput, + "output_throughput": metrics.output_throughput, + 
"total_token_throughput": metrics.total_token_throughput, + "ttft_description": pd.Series([output.ttft for output in outputs]) + .describe() + .to_dict(), + "tpot_description": pd.Series([output.tpot for output in outputs]) + .describe() + .to_dict(), "input_lens": [output.prompt_len for output in outputs], - "output_lens": - actual_output_lens, + "output_lens": actual_output_lens, "ttfts": [output.ttft for output in outputs], "itls": [output.itl for output in outputs], "errors": [output.error for output in outputs], } - ret = [{ - 'generated': output.generated_text, - 'expected': gt - } for output, gt in zip(outputs, expected)] + ret = [ + {"generated": output.generated_text, "expected": gt} + for output, gt in zip(outputs, expected) + ] def process_one_metric( # E.g., "ttft" @@ -606,29 +632,35 @@ def process_one_metric( # metric. if metric_attribute_name not in selected_percentile_metrics: return - print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) - print("{:<40} {:<10.2f}".format( - f"Mean {metric_name} (ms):", - getattr(metrics, f"mean_{metric_attribute_name}_ms"))) - print("{:<40} {:<10.2f}".format( - f"Median {metric_name} (ms):", - getattr(metrics, f"median_{metric_attribute_name}_ms"))) + print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-")) + print( + "{:<40} {:<10.2f}".format( + f"Mean {metric_name} (ms):", + getattr(metrics, f"mean_{metric_attribute_name}_ms"), + ) + ) + print( + "{:<40} {:<10.2f}".format( + f"Median {metric_name} (ms):", + getattr(metrics, f"median_{metric_attribute_name}_ms"), + ) + ) result[f"mean_{metric_attribute_name}_ms"] = getattr( - metrics, f"mean_{metric_attribute_name}_ms") + metrics, f"mean_{metric_attribute_name}_ms" + ) result[f"median_{metric_attribute_name}_ms"] = getattr( - metrics, f"median_{metric_attribute_name}_ms") + metrics, f"median_{metric_attribute_name}_ms" + ) result[f"std_{metric_attribute_name}_ms"] = getattr( - metrics, f"std_{metric_attribute_name}_ms") - for p, value in getattr(metrics, - 
f"percentiles_{metric_attribute_name}_ms"): + metrics, f"std_{metric_attribute_name}_ms" + ) + for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"): p_word = str(int(p)) if int(p) == p else str(p) - print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", - value)) + print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value)) result[f"p{p_word}_{metric_attribute_name}_ms"] = value process_one_metric("ttft", "TTFT", "Time to First Token") - process_one_metric("tpot", "TPOT", - "Time per Output Token (excl. 1st token)") + process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)") process_one_metric("itl", "ITL", "Inter-token Latency") process_one_metric("e2el", "E2EL", "End-to-end Latency") @@ -638,13 +670,13 @@ def process_one_metric( def evaluate(ret, args): - def _eval_correctness_json(expected, actual): # extract json string from string using regex - import re - actual = actual.replace('\n', '').replace(' ', '').strip() + import regex as re + + actual = actual.replace("\n", "").replace(" ", "").strip() try: - actual = re.search(r'\{.*\}', actual).group() + actual = re.search(r"\{.*\}", actual).group() actual = json.loads(actual) except Exception: return False @@ -655,29 +687,33 @@ def _eval_correctness_choice(expected, actual): return actual in args.choice def _eval_correctness_regex(expected, actual): - import re + import regex as re + return re.match(args.regex, actual) is not None def _eval_correctness(expected, actual): - if args.structure_type == 'guided_json': + if args.structure_type == "guided_json": return _eval_correctness_json(expected, actual) - elif args.structure_type == 'guided_regex': + elif args.structure_type == "guided_regex": return _eval_correctness_regex(expected, actual) - elif args.structure_type == 'guided_choice': + elif args.structure_type == "guided_choice": return _eval_correctness_choice(expected, actual) else: return None scores = [] for res in ret: - score = 
_eval_correctness(res['expected'], res['generated']) - res['correctness'] = score + score = _eval_correctness(res["expected"], res["generated"]) + res["correctness"] = score scores.append(score) not_none_scores = [score for score in scores if score is not None] - return (sum(not_none_scores) / len(not_none_scores) * - 100) if len(not_none_scores) > 0 else None + return ( + (sum(not_none_scores) / len(not_none_scores) * 100) + if len(not_none_scores) > 0 + else None + ) def parse_goodput(slo_pairs): @@ -689,9 +725,10 @@ def parse_goodput(slo_pairs): except ValueError as err: raise argparse.ArgumentTypeError( "Invalid format found for service level objectives. " - "Specify service level objectives for goodput as \"KEY:VALUE\" " + 'Specify service level objectives for goodput as "KEY:VALUE" ' "pairs, where the key is a metric name, and the value is a " - "number in milliseconds.") from err + "number in milliseconds." + ) from err return goodput_config_dict @@ -705,12 +742,14 @@ def check_goodput_args(args): raise ValueError( f"Invalid metric name found, {slo_name}: {slo_val}. " "The service level objective name should be one of " - f"{str(VALID_NAMES)}. ") + f"{str(VALID_NAMES)}. " + ) if slo_val < 0: raise ValueError( f"Invalid value found, {slo_name}: {slo_val}. " "The service level objective value should be " - "non-negative.") + "non-negative." 
+ ) return goodput_config_dict @@ -736,19 +775,19 @@ def main(args: argparse.Namespace): tokenizer_mode=args.tokenizer_mode, ) - if args.dataset == 'grammar': - args.structure_type = 'guided_grammar' - elif args.dataset == 'regex': - args.structure_type = 'guided_regex' - elif args.dataset == 'choice': - args.structure_type = 'guided_choice' + if args.dataset == "grammar": + args.structure_type = "guided_grammar" + elif args.dataset == "regex": + args.structure_type = "guided_regex" + elif args.dataset == "choice": + args.structure_type = "guided_choice" else: - args.structure_type = 'guided_json' + args.structure_type = "guided_json" if args.no_structured_output: args.structured_output_ratio = 0 if args.save_results: - result_file_name = f'{args.structured_output_ratio}guided' + result_file_name = f"{args.structured_output_ratio}guided" result_file_name += f"_{backend}" result_file_name += f"_{args.request_rate}qps" result_file_name += f"_{args.model.split('/')[-1]}" @@ -776,36 +815,29 @@ def main(args: argparse.Namespace): disable_tqdm=args.disable_tqdm, profile=args.profile, selected_percentile_metrics=args.percentile_metrics.split(","), - selected_percentiles=[ - float(p) for p in args.metric_percentiles.split(",") - ], + selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")], ignore_eos=args.ignore_eos, max_concurrency=args.max_concurrency, structured_output_ratio=args.structured_output_ratio, goodput_config_dict=goodput_config_dict, - )) + ) + ) # Save config and results to json score = evaluate(ret, args) - print("correct_rate(%)", score, '\n') + print("correct_rate(%)", score, "\n") if args.save_results: results = { - "backend": - backend, - "model_id": - model_id, - "tokenizer_id": - tokenizer_id, - "num_prompts": - args.num_prompts, - "request_rate": - args.request_rate if args.request_rate < float("inf") else "inf", - "burstiness": - args.burstiness, - "max_concurrency": - args.max_concurrency, - "correct_rate(%)": - score + 
"backend": backend, + "model_id": model_id, + "tokenizer_id": tokenizer_id, + "num_prompts": args.num_prompts, + "request_rate": args.request_rate + if args.request_rate < float("inf") + else "inf", + "burstiness": args.burstiness, + "max_concurrency": args.max_concurrency, + "correct_rate(%)": score, } results = {"outputs": ret, **results, **benchmark_result} @@ -814,13 +846,14 @@ def main(args: argparse.Namespace): result_file_name = args.result_filename if args.result_dir: result_file_name = os.path.join(args.result_dir, result_file_name) - with open(result_file_name, "w", encoding='utf-8') as outfile: + with open(result_file_name, "w", encoding="utf-8") as outfile: json.dump(results, outfile, indent=4) if __name__ == "__main__": parser = FlexibleArgumentParser( - description="Benchmark the online serving throughput.") + description="Benchmark the online serving throughput." + ) parser.add_argument( "--backend", type=str, @@ -842,16 +875,14 @@ def main(args: argparse.Namespace): default="/v1/completions", help="API endpoint.", ) - parser.add_argument("--dataset", - default='json', - choices=[ - 'json', 'json-unique', 'grammar', 'regex', - 'choice', 'xgrammar_bench' - ]) - parser.add_argument("--json-schema-path", - type=str, - default=None, - help="Path to json schema.") + parser.add_argument( + "--dataset", + default="json", + choices=["json", "json-unique", "grammar", "regex", "choice", "xgrammar_bench"], + ) + parser.add_argument( + "--json-schema-path", type=str, default=None, help="Path to json schema." + ) parser.add_argument( "--max-concurrency", type=int, @@ -863,7 +894,8 @@ def main(args: argparse.Namespace): "initiated, this argument will control how many are actually allowed " "to execute at a time. 
This means that when used in combination, the " "actual request rate may be lower than specified with --request-rate, " - "if the server is not processing requests fast enough to keep up.") + "if the server is not processing requests fast enough to keep up.", + ) parser.add_argument( "--model", type=str, @@ -873,15 +905,13 @@ def main(args: argparse.Namespace): parser.add_argument( "--tokenizer", type=str, - help= - "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 + help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 ) parser.add_argument( "--tokenizer-mode", type=str, default="auto", - help= - "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 + help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 ) parser.add_argument( "--num-prompts", @@ -958,44 +988,51 @@ def main(args: argparse.Namespace): "--ignore-eos", action="store_true", help="Set ignore_eos flag when sending the benchmark request." - "Warning: ignore_eos is not supported in deepspeed_mii and tgi.") + "Warning: ignore_eos is not supported in deepspeed_mii and tgi.", + ) parser.add_argument( "--percentile-metrics", type=str, default="ttft,tpot,itl", help="Comma-separated list of selected metrics to report percentils. " "This argument specifies the metrics to report percentiles. " - "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". " - "Default value is \"ttft,tpot,itl\".") + 'Allowed metric names are "ttft", "tpot", "itl", "e2el". ' + 'Default value is "ttft,tpot,itl".', + ) parser.add_argument( "--metric-percentiles", type=str, default="99", help="Comma-separated list of percentiles for selected metrics. " - "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". " - "Default value is \"99\". " - "Use \"--percentile-metrics\" to select metrics.", + 'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". ' + 'Default value is "99". 
' + 'Use "--percentile-metrics" to select metrics.', ) parser.add_argument( "--goodput", nargs="+", required=False, - help="Specify service level objectives for goodput as \"KEY:VALUE\" " + help='Specify service level objectives for goodput as "KEY:VALUE" ' "pairs, where the key is a metric name, and the value is in " - "milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, " + 'milliseconds. Multiple "KEY:VALUE" pairs can be provided, ' "separated by spaces. Allowed request level metric names are " - "\"ttft\", \"tpot\", \"e2el\". For more context on the definition of " + '"ttft", "tpot", "e2el". For more context on the definition of ' "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " - "and the blog: https://hao-ai-lab.github.io/blogs/distserve") - - parser.add_argument("--no-structured-output", - action='store_true', - default=False, - help="Whether to disable JSON decoding or not.") - parser.add_argument("--structured-output-ratio", - type=float, - default=1.0, - help="Ratio of Structured Outputs requests") + "and the blog: https://hao-ai-lab.github.io/blogs/distserve", + ) + + parser.add_argument( + "--no-structured-output", + action="store_true", + default=False, + help="Whether to disable JSON decoding or not.", + ) + parser.add_argument( + "--structured-output-ratio", + type=float, + default=1.0, + help="Ratio of Structured Outputs requests", + ) args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 1f65277e1bfe..7a13babda9d1 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """Benchmark offline inference throughput.""" + import argparse import dataclasses import json @@ -11,18 +12,25 @@ import torch import uvloop -from benchmark_dataset import (AIMODataset, BurstGPTDataset, - ConversationDataset, InstructCoderDataset, - RandomDataset, SampleRequest, 
ShareGPTDataset, - SonnetDataset, VisionArenaDataset) -from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json from tqdm import tqdm -from transformers import (AutoModelForCausalLM, AutoTokenizer, - PreTrainedTokenizerBase) - +from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase + +from benchmark_dataset import ( + AIMODataset, + BurstGPTDataset, + ConversationDataset, + InstructCoderDataset, + RandomDataset, + SampleRequest, + ShareGPTDataset, + SonnetDataset, + VisionArenaDataset, +) +from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.entrypoints.openai.api_server import ( - build_async_engine_client_from_engine_args) + build_async_engine_client_from_engine_args, +) from vllm.inputs import TextPrompt, TokensPrompt from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput @@ -37,23 +45,30 @@ def run_vllm( disable_detokenize: bool = False, ) -> tuple[float, Optional[list[RequestOutput]]]: from vllm import LLM, SamplingParams + llm = LLM(**dataclasses.asdict(engine_args)) assert all( - llm.llm_engine.model_config.max_model_len >= ( - request.prompt_len + request.expected_output_len) - for request in requests), ( - "Please ensure that max_model_len is greater than the sum of" - " prompt_len and expected_output_len for all requests.") + llm.llm_engine.model_config.max_model_len + >= (request.prompt_len + request.expected_output_len) + for request in requests + ), ( + "Please ensure that max_model_len is greater than the sum of" + " prompt_len and expected_output_len for all requests." + ) # Add the requests to the engine. 
prompts: list[Union[TextPrompt, TokensPrompt]] = [] sampling_params: list[SamplingParams] = [] for request in requests: prompts.append( - TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"], - multi_modal_data=request.multi_modal_data) - if "prompt_token_ids" in request.prompt else \ - TextPrompt(prompt=request.prompt, - multi_modal_data=request.multi_modal_data)) + TokensPrompt( + prompt_token_ids=request.prompt["prompt_token_ids"], + multi_modal_data=request.multi_modal_data, + ) + if "prompt_token_ids" in request.prompt + else TextPrompt( + prompt=request.prompt, multi_modal_data=request.multi_modal_data + ) + ) sampling_params.append( SamplingParams( n=n, @@ -62,7 +77,8 @@ def run_vllm( ignore_eos=True, max_tokens=request.expected_output_len, detokenize=not disable_detokenize, - )) + ) + ) lora_requests: Optional[list[LoRARequest]] = None if engine_args.enable_lora: lora_requests = [request.lora_request for request in requests] @@ -72,10 +88,9 @@ def run_vllm( outputs = None if not use_beam_search: start = time.perf_counter() - outputs = llm.generate(prompts, - sampling_params, - lora_request=lora_requests, - use_tqdm=True) + outputs = llm.generate( + prompts, sampling_params, lora_request=lora_requests, use_tqdm=True + ) end = time.perf_counter() else: assert lora_requests is None, "BeamSearch API does not support LoRA" @@ -91,30 +106,35 @@ def run_vllm( beam_width=n, max_tokens=output_len, ignore_eos=True, - )) + ), + ) end = time.perf_counter() return end - start, outputs def run_vllm_chat( - requests: list[SampleRequest], - n: int, - engine_args: EngineArgs, - disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]: + requests: list[SampleRequest], + n: int, + engine_args: EngineArgs, + disable_detokenize: bool = False, +) -> tuple[float, list[RequestOutput]]: """ Run vLLM chat benchmark. This function is recommended ONLY for benchmarking multimodal models as it properly handles multimodal inputs and chat formatting. 
For non-multimodal models, use run_vllm() instead. """ from vllm import LLM, SamplingParams + llm = LLM(**dataclasses.asdict(engine_args)) assert all( - llm.llm_engine.model_config.max_model_len >= ( - request.prompt_len + request.expected_output_len) - for request in requests), ( - "Please ensure that max_model_len is greater than the sum of " - "prompt_len and expected_output_len for all requests.") + llm.llm_engine.model_config.max_model_len + >= (request.prompt_len + request.expected_output_len) + for request in requests + ), ( + "Please ensure that max_model_len is greater than the sum of " + "prompt_len and expected_output_len for all requests." + ) prompts = [] sampling_params: list[SamplingParams] = [] @@ -128,7 +148,8 @@ def run_vllm_chat( ignore_eos=True, max_tokens=request.expected_output_len, detokenize=not disable_detokenize, - )) + ) + ) start = time.perf_counter() outputs = llm.chat(prompts, sampling_params, use_tqdm=True) end = time.perf_counter() @@ -145,13 +166,17 @@ async def run_vllm_async( from vllm import SamplingParams async with build_async_engine_client_from_engine_args( - engine_args, disable_frontend_multiprocessing) as llm: + engine_args, disable_frontend_multiprocessing + ) as llm: + model_config = await llm.get_model_config() assert all( - llm.model_config.max_model_len >= (request.prompt_len + - request.expected_output_len) - for request in requests), ( - "Please ensure that max_model_len is greater than the sum of" - " prompt_len and expected_output_len for all requests.") + model_config.max_model_len + >= (request.prompt_len + request.expected_output_len) + for request in requests + ), ( + "Please ensure that max_model_len is greater than the sum of" + " prompt_len and expected_output_len for all requests." + ) # Add the requests to the engine. 
prompts: list[Union[TextPrompt, TokensPrompt]] = [] @@ -159,11 +184,15 @@ async def run_vllm_async( lora_requests: list[Optional[LoRARequest]] = [] for request in requests: prompts.append( - TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"], - multi_modal_data=request.multi_modal_data) - if "prompt_token_ids" in request.prompt else \ - TextPrompt(prompt=request.prompt, - multi_modal_data=request.multi_modal_data)) + TokensPrompt( + prompt_token_ids=request.prompt["prompt_token_ids"], + multi_modal_data=request.multi_modal_data, + ) + if "prompt_token_ids" in request.prompt + else TextPrompt( + prompt=request.prompt, multi_modal_data=request.multi_modal_data + ) + ) sampling_params.append( SamplingParams( n=n, @@ -172,17 +201,16 @@ async def run_vllm_async( ignore_eos=True, max_tokens=request.expected_output_len, detokenize=not disable_detokenize, - )) + ) + ) lora_requests.append(request.lora_request) generators = [] start = time.perf_counter() - for i, (prompt, sp, - lr) in enumerate(zip(prompts, sampling_params, lora_requests)): - generator = llm.generate(prompt, - sp, - lora_request=lr, - request_id=f"test{i}") + for i, (prompt, sp, lr) in enumerate( + zip(prompts, sampling_params, lora_requests) + ): + generator = llm.generate(prompt, sp, lora_request=lr, request_id=f"test{i}") generators.append(generator) all_gens = merge_async_iterators(*generators) async for i, res in all_gens: @@ -201,7 +229,8 @@ def run_hf( disable_detokenize: bool = False, ) -> float: llm = AutoModelForCausalLM.from_pretrained( - model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) + model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code + ) if llm.config.model_type == "llama": # To enable padding in the HF backend. tokenizer.pad_token = tokenizer.eos_token @@ -224,14 +253,15 @@ def run_hf( # Check if we can add more requests to the batch. 
next_prompt_len = requests[i + 1].prompt_len next_output_len = requests[i + 1].expected_output_len - if (max(max_prompt_len, next_prompt_len) + - max(max_output_len, next_output_len)) <= 2048: + if ( + max(max_prompt_len, next_prompt_len) + + max(max_output_len, next_output_len) + ) <= 2048: # We can add more requests to the batch. continue # Generate the sequences. - input_ids = tokenizer(batch, return_tensors="pt", - padding=True).input_ids + input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids llm_outputs = llm.generate( input_ids=input_ids.cuda(), do_sample=True, @@ -261,6 +291,7 @@ def run_mii( output_len: int, ) -> float: from mii import client, serve + llm = serve(model, tensor_parallel=tensor_parallel_size) prompts = [request.prompt for request in requests] @@ -272,8 +303,9 @@ def run_mii( return end - start -def save_to_pytorch_benchmark_format(args: argparse.Namespace, - results: dict[str, Any]) -> None: +def save_to_pytorch_benchmark_format( + args: argparse.Namespace, results: dict[str, Any] +) -> None: pt_records = convert_to_pytorch_benchmark_format( args=args, metrics={ @@ -281,9 +313,9 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace, "tokens_per_second": [results["tokens_per_second"]], }, extra_info={ - k: results[k] - for k in ["elapsed_time", "num_requests", "total_num_tokens"] - }) + k: results[k] for k in ["elapsed_time", "num_requests", "total_num_tokens"] + }, + ) if pt_records: # Don't use json suffix here as we don't want CI to pick it up pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" @@ -315,7 +347,8 @@ def get_requests(args, tokenizer): sample_kwargs["enable_multimodal_chat"] = True elif args.dataset_name == "sonnet": assert tokenizer.chat_template or tokenizer.default_chat_template, ( - "Tokenizer/model must have chat template for sonnet dataset.") + "Tokenizer/model must have chat template for sonnet dataset." 
+ ) dataset_cls = SonnetDataset sample_kwargs["prefix_len"] = args.prefix_len sample_kwargs["return_prompt_formatted"] = True @@ -324,21 +357,21 @@ def get_requests(args, tokenizer): elif args.dataset_name == "hf": if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: dataset_cls = VisionArenaDataset - common_kwargs['dataset_subset'] = None - common_kwargs['dataset_split'] = "train" + common_kwargs["dataset_subset"] = None + common_kwargs["dataset_split"] = "train" sample_kwargs["enable_multimodal_chat"] = True elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS: dataset_cls = InstructCoderDataset - common_kwargs['dataset_split'] = "train" + common_kwargs["dataset_split"] = "train" elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS: dataset_cls = ConversationDataset - common_kwargs['dataset_subset'] = args.hf_subset - common_kwargs['dataset_split'] = args.hf_split + common_kwargs["dataset_subset"] = args.hf_subset + common_kwargs["dataset_split"] = args.hf_split sample_kwargs["enable_multimodal_chat"] = True elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: dataset_cls = AIMODataset - common_kwargs['dataset_subset'] = None - common_kwargs['dataset_split'] = "train" + common_kwargs["dataset_subset"] = None + common_kwargs["dataset_split"] = "train" else: raise ValueError(f"Unknown dataset name: {args.dataset_name}") # Remove None values @@ -353,10 +386,10 @@ def main(args: argparse.Namespace): random.seed(args.seed) # Sample the requests. 
tokenizer = AutoTokenizer.from_pretrained( - args.tokenizer, trust_remote_code=args.trust_remote_code) + args.tokenizer, trust_remote_code=args.trust_remote_code + ) requests = get_requests(args, tokenizer) - is_multi_modal = any(request.multi_modal_data is not None - for request in requests) + is_multi_modal = any(request.multi_modal_data is not None for request in requests) request_outputs: Optional[list[RequestOutput]] = None if args.backend == "vllm": if args.async_engine: @@ -367,23 +400,34 @@ def main(args: argparse.Namespace): AsyncEngineArgs.from_cli_args(args), args.disable_frontend_multiprocessing, args.disable_detokenize, - )) + ) + ) else: elapsed_time, request_outputs = run_vllm( - requests, args.n, EngineArgs.from_cli_args(args), - args.disable_detokenize) + requests, + args.n, + EngineArgs.from_cli_args(args), + args.disable_detokenize, + ) elif args.backend == "hf": assert args.tensor_parallel_size == 1 - elapsed_time = run_hf(requests, args.model, tokenizer, args.n, - args.hf_max_batch_size, args.trust_remote_code, - args.disable_detokenize) + elapsed_time = run_hf( + requests, + args.model, + tokenizer, + args.n, + args.hf_max_batch_size, + args.trust_remote_code, + args.disable_detokenize, + ) elif args.backend == "mii": - elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size, - args.output_len) + elapsed_time = run_mii( + requests, args.model, args.tensor_parallel_size, args.output_len + ) elif args.backend == "vllm-chat": elapsed_time, request_outputs = run_vllm_chat( - requests, args.n, EngineArgs.from_cli_args(args), - args.disable_detokenize) + requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize + ) else: raise ValueError(f"Unknown backend: {args.backend}") @@ -395,28 +439,31 @@ def main(args: argparse.Namespace): for ro in request_outputs: if not isinstance(ro, RequestOutput): continue - total_prompt_tokens += len( - ro.prompt_token_ids) if ro.prompt_token_ids else 0 - total_output_tokens += sum( - 
len(o.token_ids) for o in ro.outputs if o) + total_prompt_tokens += ( + len(ro.prompt_token_ids) if ro.prompt_token_ids else 0 + ) + total_output_tokens += sum(len(o.token_ids) for o in ro.outputs if o) total_num_tokens = total_prompt_tokens + total_output_tokens else: - total_num_tokens = sum(r.prompt_len + r.expected_output_len - for r in requests) + total_num_tokens = sum(r.prompt_len + r.expected_output_len for r in requests) total_output_tokens = sum(r.expected_output_len for r in requests) total_prompt_tokens = total_num_tokens - total_output_tokens if is_multi_modal and args.backend != "vllm-chat": - print("\033[91mWARNING\033[0m: Multi-modal request with " - f"{args.backend} backend detected. The " - "following metrics are not accurate because image tokens are not" - " counted. See vllm-project/vllm/issues/9778 for details.") + print( + "\033[91mWARNING\033[0m: Multi-modal request with " + f"{args.backend} backend detected. The " + "following metrics are not accurate because image tokens are not" + " counted. See vllm-project/vllm/issues/9778 for details." + ) # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length. # vllm-chat backend counts the image tokens now - print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " - f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " - f"{total_output_tokens / elapsed_time:.2f} output tokens/s") + print( + f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " + f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " + f"{total_output_tokens / elapsed_time:.2f} output tokens/s" + ) print(f"Total num prompt tokens: {total_prompt_tokens}") print(f"Total num output tokens: {total_output_tokens}") @@ -444,7 +491,8 @@ def validate_args(args): warnings.warn( "The '--dataset' argument will be deprecated in the next release. 
" "Please use '--dataset-name' and '--dataset-path' instead.", - stacklevel=2) + stacklevel=2, + ) args.dataset_path = args.dataset if not getattr(args, "tokenizer", None): @@ -457,9 +505,8 @@ def validate_args(args): # === Dataset Configuration === if not args.dataset and not args.dataset_path: - print( - "When dataset path is not set, it will default to random dataset") - args.dataset_name = 'random' + print("When dataset path is not set, it will default to random dataset") + args.dataset_name = "random" if args.input_len is None: raise ValueError("input_len must be provided for a random dataset") @@ -467,41 +514,55 @@ def validate_args(args): # --hf-subset and --hf-split: only used # when dataset_name is 'hf' if args.dataset_name != "hf" and ( - getattr(args, "hf_subset", None) is not None - or getattr(args, "hf_split", None) is not None): - warnings.warn("--hf-subset and --hf-split will be ignored \ + getattr(args, "hf_subset", None) is not None + or getattr(args, "hf_split", None) is not None + ): + warnings.warn( + "--hf-subset and --hf-split will be ignored \ since --dataset-name is not 'hf'.", - stacklevel=2) + stacklevel=2, + ) elif args.dataset_name == "hf": if args.dataset_path in ( - VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys() - | ConversationDataset.SUPPORTED_DATASET_PATHS): - assert args.backend == "vllm-chat", f"{args.dataset_path} needs to use vllm-chat as the backend." #noqa: E501 - elif args.dataset_path in (InstructCoderDataset.SUPPORTED_DATASET_PATHS - | AIMODataset.SUPPORTED_DATASET_PATHS): - assert args.backend == "vllm", f"{args.dataset_path} needs to use vllm as the backend." #noqa: E501 + VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys() + | ConversationDataset.SUPPORTED_DATASET_PATHS + ): + assert args.backend == "vllm-chat", ( + f"{args.dataset_path} needs to use vllm-chat as the backend." 
+ ) # noqa: E501 + elif args.dataset_path in ( + InstructCoderDataset.SUPPORTED_DATASET_PATHS + | AIMODataset.SUPPORTED_DATASET_PATHS + ): + assert args.backend == "vllm", ( + f"{args.dataset_path} needs to use vllm as the backend." + ) # noqa: E501 else: - raise ValueError( - f"{args.dataset_path} is not supported by hf dataset.") + raise ValueError(f"{args.dataset_path} is not supported by hf dataset.") # --random-range-ratio: only used when dataset_name is 'random' - if args.dataset_name != 'random' and args.random_range_ratio is not None: - warnings.warn("--random-range-ratio will be ignored since \ + if args.dataset_name != "random" and args.random_range_ratio is not None: + warnings.warn( + "--random-range-ratio will be ignored since \ --dataset-name is not 'random'.", - stacklevel=2) + stacklevel=2, + ) # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not # set. - if args.dataset_name not in {"random", "sonnet", None - } and args.prefix_len is not None: - warnings.warn("--prefix-len will be ignored since --dataset-name\ + if ( + args.dataset_name not in {"random", "sonnet", None} + and args.prefix_len is not None + ): + warnings.warn( + "--prefix-len will be ignored since --dataset-name\ is not 'random', 'sonnet', or not set.", - stacklevel=2) + stacklevel=2, + ) # === LoRA Settings === if getattr(args, "enable_lora", False) and args.backend != "vllm": - raise ValueError( - "LoRA benchmarking is only supported for vLLM backend") + raise ValueError("LoRA benchmarking is only supported for vLLM backend") if getattr(args, "enable_lora", False) and args.lora_path is None: raise ValueError("LoRA path must be provided when enable_lora is True") @@ -511,8 +572,10 @@ def validate_args(args): if args.backend != "hf" and args.hf_max_batch_size is not None: raise ValueError("HF max batch size is only for HF backend.") - if args.backend in {"hf", "mii"} and getattr(args, "quantization", - None) is not None: + if ( + args.backend in {"hf", "mii"} + 
and getattr(args, "quantization", None) is not None + ): raise ValueError("Quantization is only for vLLM backend.") if args.backend == "mii" and args.dtype != "auto": @@ -520,29 +583,32 @@ def validate_args(args): if args.backend == "mii" and args.n != 1: raise ValueError("n must be 1 for MII backend.") if args.backend == "mii" and args.tokenizer != args.model: - raise ValueError( - "Tokenizer must be the same as the model for MII backend.") + raise ValueError("Tokenizer must be the same as the model for MII backend.") # --data-parallel is not supported currently. # https://github.com/vllm-project/vllm/issues/16222 if args.data_parallel_size > 1: raise ValueError( "Data parallel is not supported in offline benchmark, \ - please use benchmark serving instead") + please use benchmark serving instead" + ) if __name__ == "__main__": parser = FlexibleArgumentParser(description="Benchmark the throughput.") - parser.add_argument("--backend", - type=str, - choices=["vllm", "hf", "mii", "vllm-chat"], - default="vllm") + parser.add_argument( + "--backend", + type=str, + choices=["vllm", "hf", "mii", "vllm-chat"], + default="vllm", + ) parser.add_argument( "--dataset-name", type=str, choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"], help="Name of the dataset to benchmark on.", - default="sharegpt") + default="sharegpt", + ) parser.add_argument( "--dataset", type=str, @@ -550,57 +616,70 @@ def validate_args(args): help="Path to the ShareGPT dataset, will be deprecated in\ the next release. The dataset is expected to " "be a json in form of list[dict[..., conversations: " - "list[dict[..., value: ]]]]") - parser.add_argument("--dataset-path", - type=str, - default=None, - help="Path to the dataset") - parser.add_argument("--input-len", - type=int, - default=None, - help="Input prompt length for each request") - parser.add_argument("--output-len", - type=int, - default=None, - help="Output length for each request. 
Overrides the " - "output length from the dataset.") - parser.add_argument("--n", - type=int, - default=1, - help="Number of generated sequences per prompt.") - parser.add_argument("--num-prompts", - type=int, - default=1000, - help="Number of prompts to process.") - parser.add_argument("--hf-max-batch-size", - type=int, - default=None, - help="Maximum batch size for HF backend.") + "list[dict[..., value: ]]]]", + ) + parser.add_argument( + "--dataset-path", type=str, default=None, help="Path to the dataset" + ) + parser.add_argument( + "--input-len", + type=int, + default=None, + help="Input prompt length for each request", + ) + parser.add_argument( + "--output-len", + type=int, + default=None, + help="Output length for each request. Overrides the " + "output length from the dataset.", + ) + parser.add_argument( + "--n", type=int, default=1, help="Number of generated sequences per prompt." + ) parser.add_argument( - '--output-json', + "--num-prompts", type=int, default=1000, help="Number of prompts to process." 
+ ) + parser.add_argument( + "--hf-max-batch-size", + type=int, + default=None, + help="Maximum batch size for HF backend.", + ) + parser.add_argument( + "--output-json", type=str, default=None, - help='Path to save the throughput results in JSON format.') - parser.add_argument("--async-engine", - action='store_true', - default=False, - help="Use vLLM async engine rather than LLM class.") - parser.add_argument("--disable-frontend-multiprocessing", - action='store_true', - default=False, - help="Disable decoupled async engine frontend.") + help="Path to save the throughput results in JSON format.", + ) + parser.add_argument( + "--async-engine", + action="store_true", + default=False, + help="Use vLLM async engine rather than LLM class.", + ) + parser.add_argument( + "--disable-frontend-multiprocessing", + action="store_true", + default=False, + help="Disable decoupled async engine frontend.", + ) parser.add_argument( "--disable-detokenize", action="store_true", - help=("Do not detokenize the response (i.e. do not include " - "detokenization time in the measurement)")) + help=( + "Do not detokenize the response (i.e. do not include " + "detokenization time in the measurement)" + ), + ) # LoRA parser.add_argument( "--lora-path", type=str, default=None, - help="Path to the lora adapters to use. This can be an absolute path, " - "a relative path, or a Hugging Face model identifier.") + help="Path to the LoRA adapters to use. 
This can be an absolute path, " + "a relative path, or a Hugging Face model identifier.", + ) parser.add_argument( "--prefix-len", type=int, @@ -614,7 +693,8 @@ def validate_args(args): f"prefix_len (default: {SonnetDataset.DEFAULT_PREFIX_LEN}) " "controls how much of the input is fixed lines versus " "random lines, but the total input length remains approximately " - "input_len tokens.") + "input_len tokens.", + ) # random dataset parser.add_argument( "--random-range-ratio", @@ -628,14 +708,12 @@ def validate_args(args): ) # hf dtaset - parser.add_argument("--hf-subset", - type=str, - default=None, - help="Subset of the HF dataset.") - parser.add_argument("--hf-split", - type=str, - default=None, - help="Split of the HF dataset.") + parser.add_argument( + "--hf-subset", type=str, default=None, help="Subset of the HF dataset." + ) + parser.add_argument( + "--hf-split", type=str, default=None, help="Split of the HF dataset." + ) parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() diff --git a/benchmarks/benchmark_utils.py b/benchmarks/benchmark_utils.py index 45a0ddbd5d08..b0c4fca92c3d 100644 --- a/benchmarks/benchmark_utils.py +++ b/benchmarks/benchmark_utils.py @@ -7,9 +7,9 @@ from typing import Any -def convert_to_pytorch_benchmark_format(args: argparse.Namespace, - metrics: dict[str, list], - extra_info: dict[str, Any]) -> list: +def convert_to_pytorch_benchmark_format( + args: argparse.Namespace, metrics: dict[str, list], extra_info: dict[str, Any] +) -> list: """ Save the benchmark results in the format used by PyTorch OSS benchmark with on metric per record @@ -37,12 +37,12 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace, }, } - tp = record["benchmark"]["extra_info"]["args"].get( - "tensor_parallel_size") + tp = record["benchmark"]["extra_info"]["args"].get("tensor_parallel_size") # Save tensor_parallel_size parameter if it's part of the metadata if not tp and "tensor_parallel_size" in extra_info: - 
record["benchmark"]["extra_info"]["args"][ - "tensor_parallel_size"] = extra_info["tensor_parallel_size"] + record["benchmark"]["extra_info"]["args"]["tensor_parallel_size"] = ( + extra_info["tensor_parallel_size"] + ) records.append(record) @@ -50,7 +50,6 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace, class InfEncoder(json.JSONEncoder): - def clear_inf(self, o: Any): if isinstance(o, dict): return {k: self.clear_inf(v) for k, v in o.items()} diff --git a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py index 9e36b0a9d3bb..da258f98e085 100644 --- a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py @@ -23,8 +23,9 @@ # bench -def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, - **kwargs) -> TMeasurement: +def bench_fn( + label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs +) -> TMeasurement: min_run_time = 1 globals = { @@ -41,16 +42,18 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, ).blocked_autorange(min_run_time=min_run_time) -def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, - sub_label: str) -> Iterable[TMeasurement]: +def bench_int8( + dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str +) -> Iterable[TMeasurement]: assert dtype == torch.int8 b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) - bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16) - out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, - torch.bfloat16) + out = ops.cutlass_scaled_sparse_mm( + a, b_compressed, e, scale_a, scale_b, torch.bfloat16 + ) out_ref = ops.cutlass_scaled_mm(a, b, scale_a, 
scale_b, torch.bfloat16) if not torch.allclose(out, out_ref): @@ -63,54 +66,107 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, timers = [] # pytorch impl - bfloat16 timers.append( - bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", - torch.mm, a.to(dtype=torch.bfloat16), - b.to(dtype=torch.bfloat16))) + bench_fn( + label, + sub_label, + "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, + a.to(dtype=torch.bfloat16), + b.to(dtype=torch.bfloat16), + ) + ) # pytorch impl - float16 timers.append( - bench_fn(label, sub_label, - "pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm, - a.to(dtype=torch.float16), b.to(dtype=torch.float16))) + bench_fn( + label, + sub_label, + "pytorch_fp16_fp16_fp16_matmul-no-scales", + torch.mm, + a.to(dtype=torch.float16), + b.to(dtype=torch.float16), + ) + ) # cutlass impl timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm", - ops.cutlass_scaled_mm, a, b, scale_a, scale_b, - torch.bfloat16)) + bench_fn( + label, + sub_label, + "cutlass_i8_i8_bf16_scaled_mm", + ops.cutlass_scaled_mm, + a, + b, + scale_a, + scale_b, + torch.bfloat16, + ) + ) # cutlass with bias timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias", - ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, - bias)) + bench_fn( + label, + sub_label, + "cutlass_i8_i8_bf16_scaled_mm_bias", + ops.cutlass_scaled_mm, + a, + b, + scale_a, + scale_b, + torch.bfloat16, + bias, + ) + ) # cutlass sparse impl timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm", - ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, - scale_b, torch.bfloat16)) + bench_fn( + label, + sub_label, + "cutlass_i8_i8_bf16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, + a, + b_compressed, + e, + scale_a, + scale_b, + torch.bfloat16, + ) + ) # cutlass sparse with bias timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias", - 
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, - scale_b, torch.bfloat16, bias)) + bench_fn( + label, + sub_label, + "cutlass_i8_i8_bf16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, + a, + b_compressed, + e, + scale_a, + scale_b, + torch.bfloat16, + bias, + ) + ) return timers -def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, - sub_label: str) -> Iterable[TMeasurement]: +def bench_fp8( + dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str +) -> Iterable[TMeasurement]: assert dtype == torch.float8_e4m3fn - b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, - k) + b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) - bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16) - out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, - torch.bfloat16) + out = ops.cutlass_scaled_sparse_mm( + a, b_compressed, e, scale_a, scale_b, torch.bfloat16 + ) out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) if not torch.allclose(out, out_ref): @@ -124,97 +180,165 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, # pytorch impl w. 
bf16 timers.append( - bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", - torch.mm, a.to(dtype=torch.bfloat16, device="cuda"), - b.to(dtype=torch.bfloat16, device="cuda"))) + bench_fn( + label, + sub_label, + "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, + a.to(dtype=torch.bfloat16, device="cuda"), + b.to(dtype=torch.bfloat16, device="cuda"), + ) + ) # pytorch impl: bf16 output, without fp8 fast accum timers.append( - bench_fn(label, - sub_label, - "pytorch_fp8_fp8_bf16_scaled_mm", - torch._scaled_mm, - a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.bfloat16)) + bench_fn( + label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16, + ) + ) # pytorch impl: bf16 output, with fp8 fast accum timers.append( - bench_fn(label, - sub_label, - "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", - torch._scaled_mm, - a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.bfloat16, - use_fast_accum=True)) + bench_fn( + label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16, + use_fast_accum=True, + ) + ) # pytorch impl: fp16 output, without fp8 fast accum timers.append( - bench_fn(label, - sub_label, - "pytorch_fp8_fp8_fp16_scaled_mm", - torch._scaled_mm, - a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.float16)) + bench_fn( + label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16, + ) + ) # pytorch impl: fp16 output, with fp8 fast accum timers.append( - bench_fn(label, - sub_label, - "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", - torch._scaled_mm, - a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.float16, - use_fast_accum=True)) + bench_fn( + label, + sub_label, + 
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16, + use_fast_accum=True, + ) + ) # cutlass impl: bf16 output timers.append( - bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm", - ops.cutlass_scaled_mm, a, b, scale_a, scale_b, - torch.bfloat16)) + bench_fn( + label, + sub_label, + "cutlass_fp8_fp8_bf16_scaled_mm", + ops.cutlass_scaled_mm, + a, + b, + scale_a, + scale_b, + torch.bfloat16, + ) + ) # cutlass impl: bf16 output timers.append( - bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm", - ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, - scale_b, torch.bfloat16)) + bench_fn( + label, + sub_label, + "cutlass_fp8_fp8_bf16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, + a, + b_compressed, + e, + scale_a, + scale_b, + torch.bfloat16, + ) + ) # cutlass impl: fp16 output timers.append( - bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm", - ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, - scale_b, torch.float16)) + bench_fn( + label, + sub_label, + "cutlass_fp8_fp8_fp16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, + a, + b_compressed, + e, + scale_a, + scale_b, + torch.float16, + ) + ) # cutlass impl: bf16 output, with bias timers.append( - bench_fn(label, sub_label, - "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias", - ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, - scale_b, torch.bfloat16, bias)) + bench_fn( + label, + sub_label, + "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, + a, + b_compressed, + e, + scale_a, + scale_b, + torch.bfloat16, + bias, + ) + ) # cutlass impl: fp16 output, with bias timers.append( - bench_fn(label, sub_label, - "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias", - ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, - scale_b, torch.float16, bias.to(dtype=torch.float16))) + bench_fn( + label, + sub_label, + 
"cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, + a, + b_compressed, + e, + scale_a, + scale_b, + torch.float16, + bias.to(dtype=torch.float16), + ) + ) return timers -def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str, - sub_label: str) -> Iterable[TMeasurement]: +def bench( + dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str +) -> Iterable[TMeasurement]: if dtype == torch.int8: return bench_int8(dtype, m, k, n, label, sub_label) if dtype == torch.float8_e4m3fn: @@ -228,12 +352,12 @@ def print_timers(timers: Iterable[TMeasurement]): compare.print() -def run(dtype: torch.dtype, - MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]: +def run( + dtype: torch.dtype, MKNs: Iterable[tuple[int, int, int]] +) -> Iterable[TMeasurement]: results = [] for m, k, n in MKNs: - timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", - f"MKN=({m}x{k}x{n})") + timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", f"MKN=({m}x{k}x{n})") print_timers(timers) results.extend(timers) @@ -241,10 +365,12 @@ def run(dtype: torch.dtype, # output makers -def make_output(data: Iterable[TMeasurement], - MKNs: Iterable[tuple[int, int, int]], - base_description: str, - timestamp=None): +def make_output( + data: Iterable[TMeasurement], + MKNs: Iterable[tuple[int, int, int]], + base_description: str, + timestamp=None, +): print(f"== All Results {base_description} ====") print_timers(data) @@ -258,8 +384,7 @@ def make_output(data: Iterable[TMeasurement], def run_square_bench(args): - dim_sizes = list( - range(args.dim_start, args.dim_end + 1, args.dim_increment)) + dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment)) MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) data = run(args.dtype, MKNs) @@ -319,7 +444,7 @@ def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]: pkl.dump(all_data, f) -if __name__ == '__main__': +if __name__ == "__main__": def to_torch_dtype(dt): if dt == 
"int8": @@ -344,12 +469,15 @@ def to_torch_dtype(dt): Output: - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. """, # noqa: E501 - formatter_class=argparse.RawTextHelpFormatter) - - parser.add_argument("--dtype", - type=to_torch_dtype, - required=True, - help="Available options are ['int8', 'fp8']") + formatter_class=argparse.RawTextHelpFormatter, + ) + + parser.add_argument( + "--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['int8', 'fp8']", + ) subparsers = parser.add_subparsers(dest="cmd") square_parser = subparsers.add_parser("square_bench") @@ -368,19 +496,19 @@ def to_torch_dtype(dt): range_parser.set_defaults(func=run_range_bench) model_parser = subparsers.add_parser("model_bench") - model_parser.add_argument("--models", - nargs="+", - type=str, - default=DEFAULT_MODELS, - choices=WEIGHT_SHAPES.keys()) - model_parser.add_argument("--tp-sizes", - nargs="+", - type=int, - default=DEFAULT_TP_SIZES) - model_parser.add_argument("--batch-sizes", - nargs="+", - type=int, - default=DEFAULT_BATCH_SIZES) + model_parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys(), + ) + model_parser.add_argument( + "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES + ) + model_parser.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) model_parser.set_defaults(func=run_model_bench) args = parser.parse_args() diff --git a/benchmarks/cutlass_benchmarks/utils.py b/benchmarks/cutlass_benchmarks/utils.py index fe4d8fdfc066..7e9f5a7fc0f4 100644 --- a/benchmarks/cutlass_benchmarks/utils.py +++ b/benchmarks/cutlass_benchmarks/utils.py @@ -10,8 +10,9 @@ def to_fp8(tensor: torch.Tensor) -> torch.Tensor: finfo = torch.finfo(torch.float8_e4m3fn) - return torch.round(tensor.clamp( - min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + return 
torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to( + dtype=torch.float8_e4m3fn + ) def to_int8(tensor: torch.Tensor) -> torch.Tensor: @@ -26,10 +27,11 @@ def to_fp16(tensor: torch.Tensor) -> torch.Tensor: return tensor.to(dtype=torch.float16) -def make_rand_tensors(dtype: torch.dtype, m: int, n: int, - k: int) -> tuple[torch.Tensor, torch.Tensor]: - a = torch.randn((m, k), device='cuda') * 5 - b = torch.randn((n, k), device='cuda').t() * 5 +def make_rand_tensors( + dtype: torch.dtype, m: int, n: int, k: int +) -> tuple[torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device="cuda") * 5 + b = torch.randn((n, k), device="cuda").t() * 5 if dtype == torch.int8: return to_int8(a), to_int8(b) @@ -49,9 +51,7 @@ def prune_to_2_4(tensor): # Create binary mask mask = torch.zeros_like(reshaped) - mask.scatter_(dim=1, - index=indices, - src=torch.ones_like(indices, dtype=mask.dtype)) + mask.scatter_(dim=1, index=indices, src=torch.ones_like(indices, dtype=mask.dtype)) # Apply mask and reshape back pruned = reshaped * mask @@ -62,10 +62,11 @@ def prune_to_2_4(tensor): return pruned.reshape(original_shape) -def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int, - k: int) -> tuple[torch.Tensor, torch.Tensor]: - a = torch.randn((m, k), device='cuda') * 5 - b = torch.randn((n, k), device='cuda').t() * 5 +def make_rand_sparse_tensors( + dtype: torch.dtype, m: int, n: int, k: int +) -> tuple[torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device="cuda") * 5 + b = torch.randn((n, k), device="cuda").t() * 5 b = prune_to_2_4(b.t()).t() @@ -86,9 +87,9 @@ def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int, return b_compressed, e, a, b -def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype, - m: int, n: int, k: int) -> \ - tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]: +def make_n_rand_sparse_tensors( + num_tensors: int, dtype: torch.dtype, m: int, n: int, k: int +) -> tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]: 
ABs = [] for _ in range(num_tensors): b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k) diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index e7b742d8bec9..08e93837f7dd 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -16,7 +16,8 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - w8a8_block_fp8_matmul) + w8a8_block_fp8_matmul, +) from vllm.utils import FlexibleArgumentParser DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) @@ -25,8 +26,9 @@ # bench -def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, - **kwargs) -> TMeasurement: +def bench_fn( + label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs +) -> TMeasurement: min_run_time = 1 globals = { @@ -44,45 +46,48 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, def bench_int8( - dtype: torch.dtype, - m: int, - k: int, - n: int, - label: str, - sub_label: str, - bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: + dtype: torch.dtype, + m: int, + k: int, + n: int, + label: str, + sub_label: str, + bench_kernels: Optional[list[str]] = None, +) -> Iterable[TMeasurement]: """Benchmark INT8-based kernels.""" assert dtype == torch.int8 a, b = make_rand_tensors(torch.int8, m, n, k) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) - bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) - azp = torch.zeros((m, ), device="cuda", dtype=torch.int32) - azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32) + bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16) + azp = torch.zeros((m,), device="cuda", dtype=torch.int32) + azp_adj = torch.zeros((n,), device="cuda", dtype=torch.int32) bench_fns = { - 
"pytorch_bf16_bf16_bf16_matmul-no-scales": - lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16) - ), - "pytorch_fp16_fp16_fp16_matmul-no-scales": - lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)), - "cutlass_i8_i8_bf16_scaled_mm": - lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16), - "cutlass_i8_i8_bf16_scaled_mm_bias": - lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16, - bias), - "cutlass_i8_i8_bf16_scaled_mm_azp": - lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. - bfloat16, azp_adj), - "cutlass_i8_i8_bf16_scaled_mm_azp_bias": - lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. - bfloat16, azp_adj, None, bias), - "cutlass_i8_i8_bf16_scaled_mm_azp_pt": - lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. - bfloat16, azp_adj, azp), - "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias": - lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. - bfloat16, azp_adj, azp, bias), + "pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm( + a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16) + ), + "pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm( + a.to(dtype=torch.float16), b.to(dtype=torch.float16) + ), + "cutlass_i8_i8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm( + a, b, scale_a, scale_b, torch.bfloat16 + ), + "cutlass_i8_i8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm( + a, b, scale_a, scale_b, torch.bfloat16, bias + ), + "cutlass_i8_i8_bf16_scaled_mm_azp": lambda: ops.cutlass_scaled_mm_azp( + a, b, scale_a, scale_b, torch.bfloat16, azp_adj + ), + "cutlass_i8_i8_bf16_scaled_mm_azp_bias": lambda: ops.cutlass_scaled_mm_azp( + a, b, scale_a, scale_b, torch.bfloat16, azp_adj, None, bias + ), + "cutlass_i8_i8_bf16_scaled_mm_azp_pt": lambda: ops.cutlass_scaled_mm_azp( + a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp + ), + "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias": lambda: ops.cutlass_scaled_mm_azp( + a, b, 
scale_a, scale_b, torch.bfloat16, azp_adj, azp, bias + ), } timers = [] @@ -96,73 +101,73 @@ def bench_int8( def bench_fp8( - dtype: torch.dtype, - m: int, - k: int, - n: int, - label: str, - sub_label: str, - bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: + dtype: torch.dtype, + m: int, + k: int, + n: int, + label: str, + sub_label: str, + bench_kernels: Optional[list[str]] = None, +) -> Iterable[TMeasurement]: """Benchmark FP8-based kernels.""" assert dtype == torch.float8_e4m3fn a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k) a_cont = a.contiguous() scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) - block_scale_a = torch.rand((m, k // 128), - device="cuda", - dtype=torch.float32) - block_scale_b = torch.rand((k // 128, n // 128), - device="cuda", - dtype=torch.float32) + + def ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + block_scale_a = torch.rand( + (m, ceil_div(k, 128)), device="cuda", dtype=torch.float32 + ) + block_scale_b = torch.rand( + ceil_div(k, 128), ceil_div(n, 128), device="cuda", dtype=torch.float32 + ) block_scale_a_M_major = block_scale_a.t().contiguous().t() block_scale_b_K_major = block_scale_b.t().contiguous().t() - bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16) print(m, k, n) bench_fns = { - "pytorch_bf16_bf16_bf16_matmul-no-scales": - lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16) - ), - "pytorch_fp16_fp16_fp16_matmul-no-scales": - lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)), - "pytorch_fp8_fp8_fp16_scaled_mm": - lambda: torch._scaled_mm( - a, b, scale_a, scale_b, out_dtype=torch.float16), - "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum": - lambda: torch._scaled_mm(a, - b, - scale_a, - scale_b, - out_dtype=torch.float16, - use_fast_accum=True), - "pytorch_fp8_fp8_bf16_scaled_mm": - lambda: 
torch._scaled_mm( - a, b, scale_a, scale_b, out_dtype=torch.bfloat16), - "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum": - lambda: torch._scaled_mm(a, - b, - scale_a, - scale_b, - out_dtype=torch.bfloat16, - use_fast_accum=True), - "cutlass_fp8_fp8_bf16_scaled_mm": - lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16), - "cutlass_fp8_fp8_fp16_scaled_mm": - lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16), - "cutlass_fp8_fp8_bf16_scaled_mm_bias": - lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16, - bias), - "cutlass_fp8_fp8_fp16_scaled_mm_bias": - lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16, - bias.to(dtype=torch.float16)), - "triton_fp8_fp8_fp16_scaled_mm_blockwise": - lambda: w8a8_block_fp8_matmul(a_cont, b.t(), block_scale_a, - block_scale_b.t(), (128, 128)), - "cutlass_fp8_fp8_fp16_scaled_mm_blockwise": - lambda: ops.cutlass_scaled_mm(a, b, block_scale_a_M_major, - block_scale_b_K_major, torch.float16), + "pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm( + a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16) + ), + "pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm( + a.to(dtype=torch.float16), b.to(dtype=torch.float16) + ), + "pytorch_fp8_fp8_fp16_scaled_mm": lambda: torch._scaled_mm( + a, b, scale_a, scale_b, out_dtype=torch.float16 + ), + "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum": lambda: torch._scaled_mm( + a, b, scale_a, scale_b, out_dtype=torch.float16, use_fast_accum=True + ), + "pytorch_fp8_fp8_bf16_scaled_mm": lambda: torch._scaled_mm( + a, b, scale_a, scale_b, out_dtype=torch.bfloat16 + ), + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum": lambda: torch._scaled_mm( + a, b, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=True + ), + "cutlass_fp8_fp8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm( + a, b, scale_a, scale_b, torch.bfloat16 + ), + "cutlass_fp8_fp8_fp16_scaled_mm": lambda: ops.cutlass_scaled_mm( + a, b, scale_a, scale_b, 
torch.float16 + ), + "cutlass_fp8_fp8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm( + a, b, scale_a, scale_b, torch.bfloat16, bias + ), + "cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm( + a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16) + ), + "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul( + a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128) + ), + "cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm( + a, b, block_scale_a_M_major, block_scale_b_K_major, torch.float16 + ), } timers = [] @@ -175,13 +180,15 @@ def bench_fp8( return timers -def bench(dtype: torch.dtype, - m: int, - k: int, - n: int, - label: str, - sub_label: str, - bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: +def bench( + dtype: torch.dtype, + m: int, + k: int, + n: int, + label: str, + sub_label: str, + bench_kernels: Optional[list[str]] = None, +) -> Iterable[TMeasurement]: if dtype == torch.int8: return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels) if dtype == torch.float8_e4m3fn: @@ -195,27 +202,33 @@ def print_timers(timers: Iterable[TMeasurement]): compare.print() -def run(dtype: torch.dtype, - MKNs: Iterable[tuple[int, int, int]], - bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: +def run( + dtype: torch.dtype, + MKNs: Iterable[tuple[int, int, int]], + bench_kernels: Optional[list[str]] = None, +) -> Iterable[TMeasurement]: results = [] for m, k, n in MKNs: - timers = bench(dtype, - m, - k, - n, - f"scaled-{dtype}-gemm", - f"MKN=({m}x{k}x{n})", - bench_kernels=bench_kernels) + timers = bench( + dtype, + m, + k, + n, + f"scaled-{dtype}-gemm", + f"MKN=({m}x{k}x{n})", + bench_kernels=bench_kernels, + ) print_timers(timers) results.extend(timers) return results -def make_output(data: Iterable[TMeasurement], - MKNs: Iterable[tuple[int, int, int]], - base_description: str, - timestamp=None): +def make_output( + data: 
Iterable[TMeasurement], + MKNs: Iterable[tuple[int, int, int]], + base_description: str, + timestamp=None, +): print(f"== All Results {base_description} ====") print_timers(data) @@ -226,8 +239,7 @@ def make_output(data: Iterable[TMeasurement], def run_square_bench(args): - dim_sizes = list( - range(args.dim_start, args.dim_end + 1, args.dim_increment)) + dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment)) MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) data = run(args.dtype, MKNs, bench_kernels=args.kernels) make_output(data, MKNs, f"square_bench-{args.dtype}") @@ -285,7 +297,7 @@ def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]: pkl.dump(all_data, f) -if __name__ == '__main__': +if __name__ == "__main__": def to_torch_dtype(dt): if dt == "int8": @@ -310,19 +322,21 @@ def to_torch_dtype(dt): Output: - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. """, # noqa: E501 - formatter_class=argparse.RawTextHelpFormatter) + formatter_class=argparse.RawTextHelpFormatter, + ) - parser.add_argument("--dtype", - type=to_torch_dtype, - required=True, - help="Available options are ['int8', 'fp8']") + parser.add_argument( + "--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['int8', 'fp8']", + ) parser.add_argument( "--kernels", nargs="+", type=str, default=None, - help= - "Exact names of the kernels to benchmark. If not set, runs all kernels." + help="Exact names of the kernels to benchmark. 
If not set, runs all kernels.", ) subparsers = parser.add_subparsers(dest="cmd") @@ -343,19 +357,19 @@ def to_torch_dtype(dt): range_parser.set_defaults(func=run_range_bench) model_parser = subparsers.add_parser("model_bench") - model_parser.add_argument("--models", - nargs="+", - type=str, - default=DEFAULT_MODELS, - choices=WEIGHT_SHAPES.keys()) - model_parser.add_argument("--tp-sizes", - nargs="+", - type=int, - default=DEFAULT_TP_SIZES) - model_parser.add_argument("--batch-sizes", - nargs="+", - type=int, - default=DEFAULT_BATCH_SIZES) + model_parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys(), + ) + model_parser.add_argument( + "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES + ) + model_parser.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) model_parser.set_defaults(func=run_model_bench) args = parser.parse_args() diff --git a/benchmarks/cutlass_benchmarks/weight_shapes.py b/benchmarks/cutlass_benchmarks/weight_shapes.py index 3d1121df40d0..d31b623a1ee6 100644 --- a/benchmarks/cutlass_benchmarks/weight_shapes.py +++ b/benchmarks/cutlass_benchmarks/weight_shapes.py @@ -42,4 +42,4 @@ ([8192, 57344], 1), ([28672, 8192], 0), ], -} \ No newline at end of file +} diff --git a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py index 980e68668911..fce156e1c96c 100644 --- a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py +++ b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py @@ -12,39 +12,37 @@ async def forward_request(url, data): async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: - headers = { - "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" - } - async with session.post(url=url, json=data, - headers=headers) as response: + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + async with 
session.post(url=url, json=data, headers=headers) as response: if response.status == 200: # if response.headers.get('Transfer-Encoding') == 'chunked': if True: - async for chunk_bytes in response.content.iter_chunked( - 1024): + async for chunk_bytes in response.content.iter_chunked(1024): yield chunk_bytes else: content = await response.read() yield content -@app.route('/v1/completions', methods=['POST']) +@app.route("/v1/completions", methods=["POST"]) async def handle_request(): try: original_request_data = await request.get_json() prefill_request = original_request_data.copy() # change max_tokens = 1 to let it only do prefill - prefill_request['max_tokens'] = 1 + prefill_request["max_tokens"] = 1 # finish prefill - async for _ in forward_request('http://localhost:8100/v1/completions', - prefill_request): + async for _ in forward_request( + "http://localhost:8100/v1/completions", prefill_request + ): continue # return decode - generator = forward_request('http://localhost:8200/v1/completions', - original_request_data) + generator = forward_request( + "http://localhost:8200/v1/completions", original_request_data + ) response = await make_response(generator) response.timeout = None @@ -53,11 +51,12 @@ async def handle_request(): except Exception as e: import sys import traceback + exc_info = sys.exc_info() print("Error occurred in disagg prefill proxy server") print(e) print("".join(traceback.format_exception(*exc_info))) -if __name__ == '__main__': +if __name__ == "__main__": app.run(port=8000) diff --git a/benchmarks/disagg_benchmarks/round_robin_proxy.py b/benchmarks/disagg_benchmarks/round_robin_proxy.py index c2ad4916bf07..fd19b40bf252 100644 --- a/benchmarks/disagg_benchmarks/round_robin_proxy.py +++ b/benchmarks/disagg_benchmarks/round_robin_proxy.py @@ -8,7 +8,6 @@ class RoundRobinProxy: - def __init__(self, target_ports): self.target_ports = target_ports self.port_cycle = itertools.cycle(self.target_ports) @@ -21,14 +20,15 @@ async def 
handle_request(self, request): try: # Forward the request async with session.request( - method=request.method, - url=target_url, - headers=request.headers, - data=request.content, + method=request.method, + url=target_url, + headers=request.headers, + data=request.content, ) as response: # Start sending the response - resp = web.StreamResponse(status=response.status, - headers=response.headers) + resp = web.StreamResponse( + status=response.status, headers=response.headers + ) await resp.prepare(request) # Stream the response content @@ -45,11 +45,11 @@ async def handle_request(self, request): async def main(): proxy = RoundRobinProxy([8100, 8200]) app = web.Application() - app.router.add_route('*', '/{path:.*}', proxy.handle_request) + app.router.add_route("*", "/{path:.*}", proxy.handle_request) runner = web.AppRunner(app) await runner.setup() - site = web.TCPSite(runner, 'localhost', 8000) + site = web.TCPSite(runner, "localhost", 8000) await site.start() print("Proxy server started on http://localhost:8000") @@ -58,5 +58,5 @@ async def main(): await asyncio.Event().wait() -if __name__ == '__main__': +if __name__ == "__main__": asyncio.run(main()) diff --git a/benchmarks/disagg_benchmarks/visualize_benchmark_results.py b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py index a7b4b9e8bf30..484d0cb3cba7 100644 --- a/benchmarks/disagg_benchmarks/visualize_benchmark_results.py +++ b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py @@ -6,43 +6,41 @@ import pandas as pd if __name__ == "__main__": - data = [] - for name in ['disagg_prefill', 'chunked_prefill']: + for name in ["disagg_prefill", "chunked_prefill"]: for qps in [2, 4, 6, 8]: with open(f"results/{name}-qps-{qps}.json") as f: x = json.load(f) - x['name'] = name - x['qps'] = qps + x["name"] = name + x["qps"] = qps data.append(x) df = pd.DataFrame.from_dict(data) - dis_df = df[df['name'] == 'disagg_prefill'] - chu_df = df[df['name'] == 'chunked_prefill'] + dis_df = df[df["name"] == 
"disagg_prefill"] + chu_df = df[df["name"] == "chunked_prefill"] - plt.style.use('bmh') - plt.rcParams['font.size'] = 20 + plt.style.use("bmh") + plt.rcParams["font.size"] = 20 for key in [ - 'mean_ttft_ms', 'median_ttft_ms', 'p99_ttft_ms', 'mean_itl_ms', - 'median_itl_ms', 'p99_itl_ms' + "mean_ttft_ms", + "median_ttft_ms", + "p99_ttft_ms", + "mean_itl_ms", + "median_itl_ms", + "p99_itl_ms", ]: - fig, ax = plt.subplots(figsize=(11, 7)) - plt.plot(dis_df['qps'], - dis_df[key], - label='disagg_prefill', - marker='o', - linewidth=4) - plt.plot(chu_df['qps'], - chu_df[key], - label='chunked_prefill', - marker='o', - linewidth=4) + plt.plot( + dis_df["qps"], dis_df[key], label="disagg_prefill", marker="o", linewidth=4 + ) + plt.plot( + chu_df["qps"], chu_df[key], label="chunked_prefill", marker="o", linewidth=4 + ) ax.legend() - ax.set_xlabel('QPS') + ax.set_xlabel("QPS") ax.set_ylabel(key) ax.set_ylim(bottom=0) - fig.savefig(f'results/{key}.png') + fig.savefig(f"results/{key}.png") plt.close(fig) diff --git a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py index 3da583a33448..37a9173a1a93 100644 --- a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py +++ b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py @@ -24,10 +24,12 @@ class bench_params_t: dtype: torch.dtype def description(self): - return (f'N {self.num_tokens} ' - f'x D {self.hidden_size} ' - f'x R {self.add_residual} ' - f'x DT {self.dtype}') + return ( + f"N {self.num_tokens} " + f"x D {self.hidden_size} " + f"x R {self.add_residual} " + f"x DT {self.dtype}" + ) def get_bench_params() -> list[bench_params_t]: @@ -38,15 +40,19 @@ def get_bench_params() -> list[bench_params_t]: DTYPES = [torch.bfloat16, torch.float] combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES) - bench_params = list(map(lambda x: \ - bench_params_t(x[0], x[1], x[2], x[3]), combinations)) + bench_params = list( + map(lambda x: bench_params_t(x[0], x[1], 
x[2], x[3]), combinations) + ) return bench_params # Reference impls -def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, - residual: Optional[torch.Tensor], - quant_dtype: torch.dtype): +def unfused_int8_impl( + rms_norm_layer: RMSNorm, + x: torch.Tensor, + residual: Optional[torch.Tensor], + quant_dtype: torch.dtype, +): # Norm torch_out = None if residual is None: @@ -58,9 +64,12 @@ def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, torch_out, _, _ = ops.scaled_int8_quant(torch_out) -def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, - residual: Optional[torch.Tensor], - quant_dtype: torch.dtype): +def unfused_fp8_impl( + rms_norm_layer: RMSNorm, + x: torch.Tensor, + residual: Optional[torch.Tensor], + quant_dtype: torch.dtype, +): # Norm torch_out = None if residual is None: @@ -73,22 +82,27 @@ def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, def fused_impl( - rms_norm_layer: RMSNorm, # this stores the weights - x: torch.Tensor, - residual: Optional[torch.Tensor], - quant_dtype: torch.dtype): - out, _ = ops.rms_norm_dynamic_per_token_quant(x, - rms_norm_layer.weight, - 1e-6, - quant_dtype, - residual=residual) + rms_norm_layer: RMSNorm, # this stores the weights + x: torch.Tensor, + residual: Optional[torch.Tensor], + quant_dtype: torch.dtype, +): + out, _ = ops.rms_norm_dynamic_per_token_quant( + x, rms_norm_layer.weight, 1e-6, quant_dtype, residual=residual + ) # Bench functions -def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor, - quant_dtype: torch.dtype, label: str, sub_label: str, - fn: Callable, description: str) -> TMeasurement: - +def bench_fn( + rms_norm_layer: RMSNorm, + x: torch.Tensor, + residual: torch.Tensor, + quant_dtype: torch.dtype, + label: str, + sub_label: str, + fn: Callable, + description: str, +) -> TMeasurement: min_run_time = 1 globals = { @@ -106,43 +120,81 @@ def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor, 
description=description, ).blocked_autorange(min_run_time=min_run_time) -def bench(params: bench_params_t, label: str, sub_label: str) \ - -> Iterable[TMeasurement]: +def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasurement]: # Make inputs layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype) # Make weights layer.weight.data.normal_(mean=1.0, std=0.1) # Make inputs scale = 1 / params.hidden_size - x = torch.randn(params.num_tokens, - params.hidden_size, - dtype=params.dtype, - device='cuda') * scale - residual = (torch.randn_like(x) * scale).to(device='cuda') \ - if params.add_residual else None + x = ( + torch.randn( + params.num_tokens, params.hidden_size, dtype=params.dtype, device="cuda" + ) + * scale + ) + residual = ( + (torch.randn_like(x) * scale).to(device="cuda") if params.add_residual else None + ) timers = [] # unfused int8 impl. timers.append( - bench_fn(layer, x, residual, torch.int8, label, sub_label, - unfused_int8_impl, "unfused_int8_impl")) + bench_fn( + layer, + x, + residual, + torch.int8, + label, + sub_label, + unfused_int8_impl, + "unfused_int8_impl", + ) + ) # unfused fp8 impl. timers.append( - bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label, - unfused_fp8_impl, "unfused_fp8_impl")) + bench_fn( + layer, + x, + residual, + torch.float8_e4m3fn, + label, + sub_label, + unfused_fp8_impl, + "unfused_fp8_impl", + ) + ) # fused int8 impl. timers.append( - bench_fn(layer, x, residual, torch.int8, label, sub_label, fused_impl, - "fused_int8_impl")) + bench_fn( + layer, + x, + residual, + torch.int8, + label, + sub_label, + fused_impl, + "fused_int8_impl", + ) + ) # fused fp8 impl. 
timers.append( - bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label, - fused_impl, "fused_fp8_impl")) + bench_fn( + layer, + x, + residual, + torch.float8_e4m3fn, + label, + sub_label, + fused_impl, + "fused_fp8_impl", + ) + ) print_timers(timers) @@ -157,13 +209,12 @@ def print_timers(timers: Iterable[TMeasurement]): def main(): - torch.set_default_device('cuda') + torch.set_default_device("cuda") bench_params = get_bench_params() timers = [] for bp in tqdm(bench_params): - timers.extend( - bench(bp, "rms-norm-dynamic-per-token-quant", bp.description())) + timers.extend(bench(bp, "rms-norm-dynamic-per-token-quant", bp.description())) print_timers(timers) # pickle all the results @@ -172,5 +223,5 @@ def main(): pkl.dump(timers, f) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/benchmarks/kernels/benchmark_aqlm.py b/benchmarks/kernels/benchmark_aqlm.py index 8d20b91560dd..e9934aa479dd 100644 --- a/benchmarks/kernels/benchmark_aqlm.py +++ b/benchmarks/kernels/benchmark_aqlm.py @@ -9,32 +9,39 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.aqlm import ( - dequantize_weight, generic_dequantize_gemm, get_int_dtype, - optimized_dequantize_gemm) + dequantize_weight, + generic_dequantize_gemm, + get_int_dtype, + optimized_dequantize_gemm, +) from vllm.utils import FlexibleArgumentParser -os.environ['CUDA_VISIBLE_DEVICES'] = '0' +os.environ["CUDA_VISIBLE_DEVICES"] = "0" def torch_mult( - input: torch.Tensor, # [..., in_features] - weights: torch.Tensor, - scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + # [..., in_features] + input: torch.Tensor, + weights: torch.Tensor, + # [num_out_groups, 1, 1, 1] + scales: torch.Tensor, ) -> torch.Tensor: output = F.linear(input, weights) return output def dequant_out_scale( - input: torch.Tensor, # [..., in_features] - codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] - codebooks: torch. 
- Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] - scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + # [..., in_features] + input: torch.Tensor, + # [num_out_groups, num_in_groups, num_codebooks] + codes: torch.IntTensor, + # [num_codebooks, codebook_size, out_group_size, in_group_size] + codebooks: torch.Tensor, + # [num_out_groups, 1, 1, 1] + scales: torch.Tensor, output_partition_sizes: torch.IntTensor, bias: Optional[torch.Tensor], ) -> torch.Tensor: - weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) if bias is None: @@ -46,40 +53,42 @@ def dequant_out_scale( flattened_output *= b_scales return flattened_output.view(orig_shape) else: - b_scales = scales.view(scales.shape[:-3] + (-1, )).expand( - -1, weights.shape[1]) + b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1]) weights *= b_scales return F.linear(input, weights, bias) def dequant_weight_scale( - input: torch.Tensor, # [..., in_features] - codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] - codebooks: torch. 
- Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] - scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + # [..., in_features] + input: torch.Tensor, + # [num_out_groups, num_in_groups, num_codebooks] + codes: torch.IntTensor, + # [num_codebooks, codebook_size, out_group_size, in_group_size] + codebooks: torch.Tensor, + # [num_out_groups, 1, 1, 1] + scales: torch.Tensor, output_partition_sizes: torch.IntTensor, bias: Optional[torch.Tensor], ) -> torch.Tensor: - weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) - b_scales = scales.view(scales.shape[:-3] + (-1, )).expand( - -1, weights.shape[1]) + b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1]) weights *= b_scales return F.linear(input, weights, bias) def dequant_no_scale( - input: torch.Tensor, # [..., in_features] - codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] - codebooks: torch. - Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] - scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + # [..., in_features] + input: torch.Tensor, + # [num_out_groups, num_in_groups, num_codebooks] + codes: torch.IntTensor, + # [num_codebooks, codebook_size, out_group_size, in_group_size] + codebooks: torch.Tensor, + # [num_out_groups, 1, 1, 1] + scales: torch.Tensor, output_partition_sizes: torch.IntTensor, bias: Optional[torch.Tensor], ) -> torch.Tensor: - weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) return F.linear(input, weights, bias) @@ -89,23 +98,26 @@ def dequant_no_scale( # the generic pytorch version. # Just visual comparison. 
def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None: - n = int(parts.sum().item()) - device = torch.device('cuda:0') + device = torch.device("cuda:0") code_range = (1 << bits) // 2 ingroups = 8 - codes = torch.randint(-code_range, - code_range, - size=(n, k // ingroups, nbooks), - dtype=get_int_dtype(bits), - device=device) + codes = torch.randint( + -code_range, + code_range, + size=(n, k // ingroups, nbooks), + dtype=get_int_dtype(bits), + device=device, + ) - codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8), - dtype=torch.float16, - device=device) + codebooks = torch.randn( + size=(parts.shape[0] * nbooks, 1 << bits, 1, 8), + dtype=torch.float16, + device=device, + ) count = 0 for index in range(16): @@ -138,24 +150,25 @@ def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None: def main(): - parser = FlexibleArgumentParser(description="Benchmark aqlm performance.") # Add arguments - parser.add_argument("--nbooks", - type=int, - default=1, - help="Number of codebooks (default: 1)") - parser.add_argument("--bits", - type=int, - default=16, - help="Number of bits per code element (default: 16)") + parser.add_argument( + "--nbooks", type=int, default=1, help="Number of codebooks (default: 1)" + ) + parser.add_argument( + "--bits", + type=int, + default=16, + help="Number of bits per code element (default: 16)", + ) parser.add_argument( "--test", type=bool, default=False, help="Run the decompression/dequant tester rather than benchmarking " - "(default: False)") + "(default: False)", + ) # Parse the arguments args = parser.parse_args() @@ -165,7 +178,7 @@ def main(): bits = args.bits if args.test: - dequant_test(4096, torch.tensor((4096, )), nbooks, bits) + dequant_test(4096, torch.tensor((4096,)), nbooks, bits) return # Otherwise, benchmark. 
@@ -184,31 +197,54 @@ def main(): with open(filename, "w") as f: sys.stdout = f - print('m | k | n | n parts', end='') + print("m | k | n | n parts", end="") for method in methods: - print(f" | {method.__name__.replace('_', ' ')} (ยตs)", end='') - print('') + print(f" | {method.__name__.replace('_', ' ')} (ยตs)", end="") + print("") # These are reasonable prefill sizes. - ksandpartions = ((4096, (4096, 4096, 4096)), (4096, (4096, )), - (4096, (11008, 11008)), (11008, (4096, ))) + ksandpartions = ( + (4096, (4096, 4096, 4096)), + (4096, (4096,)), + (4096, (11008, 11008)), + (11008, (4096,)), + ) # reasonable ranges for m. for m in [ - 1, 2, 4, 8, 10, 12, 14, 16, 24, 32, 48, 52, 56, 64, 96, 112, - 128, 256, 512, 1024, 1536, 2048, 3072, 4096 + 1, + 2, + 4, + 8, + 10, + 12, + 14, + 16, + 24, + 32, + 48, + 52, + 56, + 64, + 96, + 112, + 128, + 256, + 512, + 1024, + 1536, + 2048, + 3072, + 4096, ]: - print(f'{m}', file=sys.__stdout__) + print(f"{m}", file=sys.__stdout__) for ksp in ksandpartions: - run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits, - methods) + run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits, methods) sys.stdout = sys.__stdout__ -def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, - methods): - +def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, methods): # I didn't see visible improvements from increasing these, but feel free :) num_warmup_trials = 1 num_trials = 1 @@ -229,7 +265,7 @@ def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, ) n = parts.sum().item() - print(f'{m} | {k} | {n} | {parts.tolist()}', end='') + print(f"{m} | {k} | {n} | {parts.tolist()}", end="") for method in methods: best_time_us = 1e20 @@ -249,32 +285,36 @@ def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, if kernel_dur_us < best_time_us: best_time_us = kernel_dur_us - print(f' | {kernel_dur_us:.0f}', end='') + print(f" | {kernel_dur_us:.0f}", end="") - print('') + print("") 
-def run_timing(num_calls: int, m: int, k: int, parts: torch.Tensor, - nbooks: int, bits: int, method) -> float: - +def run_timing( + num_calls: int, m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, method +) -> float: n = int(parts.sum().item()) - device = torch.device('cuda:0') + device = torch.device("cuda:0") input = torch.randn((1, m, k), dtype=torch.float16, device=device) code_range = (1 << bits) // 2 ingroups = 8 - codes = torch.randint(-code_range, - code_range, - size=(n, k // ingroups, nbooks), - dtype=get_int_dtype(bits), - device=device) - - codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8), - dtype=torch.float16, - device=device) + codes = torch.randint( + -code_range, + code_range, + size=(n, k // ingroups, nbooks), + dtype=get_int_dtype(bits), + device=device, + ) + + codebooks = torch.randn( + size=(parts.shape[0] * nbooks, 1 << bits, 1, 8), + dtype=torch.float16, + device=device, + ) scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device) diff --git a/benchmarks/kernels/benchmark_bitblas.py b/benchmarks/kernels/benchmark_bitblas.py index b23b4f3ea685..d40ab70ec539 100644 --- a/benchmarks/kernels/benchmark_bitblas.py +++ b/benchmarks/kernels/benchmark_bitblas.py @@ -3,27 +3,33 @@ # Licensed under the MIT License. from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( - MINIMUM_BITBLAS_VERSION) + MINIMUM_BITBLAS_VERSION, +) try: import bitblas + if bitblas.__version__ < MINIMUM_BITBLAS_VERSION: - raise ImportError("bitblas version is wrong. Please " - f"install bitblas>={MINIMUM_BITBLAS_VERSION}") + raise ImportError( + "bitblas version is wrong. Please " + f"install bitblas>={MINIMUM_BITBLAS_VERSION}" + ) except ImportError as e: bitblas_import_exception = e - raise ValueError("Trying to use the bitblas backend, but could not import" - f"with the following error: {bitblas_import_exception}. 
" - "Please install bitblas through the following command: " - f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`" - ) from bitblas_import_exception + raise ValueError( + "Trying to use the bitblas backend, but could not import" + f"with the following error: {bitblas_import_exception}. " + "Please install bitblas through the following command: " + f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`" + ) from bitblas_import_exception from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target from vllm.utils import FlexibleArgumentParser parser = FlexibleArgumentParser( - description="Benchmark BitBLAS int4 on a specific target.") + description="Benchmark BitBLAS int4 on a specific target." +) # Add arguments to the parser parser.add_argument( @@ -32,10 +38,9 @@ default=auto_detect_nvidia_target(), help="Specify the target device for benchmarking.", ) -parser.add_argument("--group_size", - type=int, - default=None, - help="Group size for grouped quantization.") +parser.add_argument( + "--group_size", type=int, default=None, help="Group size for grouped quantization." +) parser.add_argument( "--A_dtype", type=str, @@ -82,17 +87,17 @@ choices=["nt", "nn"], help="Matrix layout, 'nt' for non-transpose A and transpose W.", ) -parser.add_argument("--with_bias", - action="store_true", - help="Include bias in the benchmark.") +parser.add_argument( + "--with_bias", action="store_true", help="Include bias in the benchmark." +) parser.add_argument( "--with_scaling", action="store_true", help="Include scaling factor in the quantization.", ) -parser.add_argument("--with_zeros", - action="store_true", - help="Include zeros in the quantization.") +parser.add_argument( + "--with_zeros", action="store_true", help="Include zeros in the quantization." 
+) parser.add_argument( "--zeros_mode", type=str, @@ -170,8 +175,7 @@ ] # Build test shapes with all the shared arguments -test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args)) - for shape in shapes] +test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args)) for shape in shapes] benchmark_sets = [] benchmark_sets.extend(test_shapes) @@ -206,12 +210,12 @@ func_name = args_split[0] input_args_str = "-".join(args_split[1:]) col_widths[0] = max(col_widths[0], len(func_name) + 2, len(headers[0]) + 2) - col_widths[1] = max(col_widths[1], - len(input_args_str) + 2, - len(headers[1]) + 2) - col_widths[2] = max(col_widths[2], - len(f"{values['BitBLAS_top20_latency']:.3f} ms") + 2, - len(headers[2]) + 2) + col_widths[1] = max(col_widths[1], len(input_args_str) + 2, len(headers[1]) + 2) + col_widths[2] = max( + col_widths[2], + len(f"{values['BitBLAS_top20_latency']:.3f} ms") + 2, + len(headers[2]) + 2, + ) # break only if you want to measure widths from a single example; # otherwise, let it loop over all items. @@ -232,5 +236,6 @@ f"{values['BitBLAS_top20_latency']:.3f} ms", ] row_str = "".join( - [str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)]) + [str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)] + ) print(row_str) diff --git a/benchmarks/kernels/benchmark_cutlass_fp4_moe.py b/benchmarks/kernels/benchmark_cutlass_fp4_moe.py new file mode 100644 index 000000000000..d39d8a6e3aba --- /dev/null +++ b/benchmarks/kernels/benchmark_cutlass_fp4_moe.py @@ -0,0 +1,489 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Benchmark the performance of the cutlass_moe_fp4 kernel vs the triton_moe +kernel. The cutlass_moe_fp4 kernel takes in fp4 quantized weights and 16-bit +activations. The triton_moe kernel takes in fp8 weights(tensor scaled to fp8) +and 16-bit activations. 
+""" + +import nvtx +import torch +import torch.utils.benchmark as benchmark + +from vllm import _custom_ops as ops +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 +from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk +from vllm.scalar_type import scalar_types +from vllm.utils import FlexibleArgumentParser + +WEIGHT_SHAPES_MOE = { + "nvidia/DeepSeek-R1-FP4": [ + [256, 8, 2048, 7168], + ], +} + +DEFAULT_MODELS = [ + "nvidia/DeepSeek-R1-FP4", +] + +DEFAULT_BATCH_SIZES = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] +DEFAULT_TP_SIZES = [1] + +PER_ACT_TOKEN_OPTS = [False] +PER_OUT_CH_OPTS = [False] +FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + + +def to_fp8(tensor: torch.Tensor): + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to( + dtype=torch.float8_e4m3fn + ) + + +def bench_run( + results: list[benchmark.Measurement], + model: str, + num_experts: int, + topk: int, + per_act_token: bool, + per_out_ch: bool, + mkn: tuple[int, int, int], +): + label = "NVFP4 Blockscaled CUTLASS MOE vs FP8 Tensor Scaled Triton" + + sub_label = ( + "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})".format( + model, num_experts, topk, per_act_token, per_out_ch, mkn + ) + ) + + print(f"Testing: {sub_label}") + + (m, k, n) = mkn + + dtype = torch.half + device = "cuda" + a = torch.randn((m, k), device=device, dtype=dtype) / 10 + w1 = torch.randn((num_experts, 2 * n, k), device=device, dtype=dtype) / 10 + w2 = torch.randn((num_experts, k, n), device=device, dtype=dtype) / 10 + + _, a_fp8_scale = ops.scaled_fp8_quant(a) + + w1_fp8q = torch.empty( + (num_experts, 2 * n, k), device=device, dtype=torch.float8_e4m3fn + ) + w2_fp8q = torch.empty((num_experts, k, n), device=device, dtype=torch.float8_e4m3fn) + 
w1_fp8scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32) + w2_fp8scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32) + + for expert in range(num_experts): + w1_fp8q[expert], w1_fp8scale[expert] = ops.scaled_fp8_quant(w1[expert]) + w2_fp8q[expert], w2_fp8scale[expert] = ops.scaled_fp8_quant(w2[expert]) + + w1_fp8q_notransp = w1_fp8q.clone() + w2_fp8q_notransp = w2_fp8q.clone() + w1_fp8q = w1_fp8q.transpose(1, 2) + w2_fp8q = w2_fp8q.transpose(1, 2) + + score = torch.randn((m, num_experts), device=device, dtype=dtype) + + topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + + quant_blocksize = 16 + w1_blockscale = torch.empty( + (num_experts, 2 * n, k // quant_blocksize), + device=device, + dtype=torch.float8_e4m3fn, + ) + w2_blockscale = torch.empty( + (num_experts, k, n // quant_blocksize), device=device, dtype=torch.float8_e4m3fn + ) + + # n_b_scales = 2 * n if per_out_ch else 1 + # k_b_scales = k if per_out_ch else 1 + w1_fp4 = torch.empty((num_experts, 2 * n, k // 2), device=device, dtype=torch.uint8) + w2_fp4 = torch.empty((num_experts, k, n // 2), device=device, dtype=torch.uint8) + + w1_gs = torch.empty((num_experts,), device=device, dtype=torch.float32) + w2_gs = torch.empty((num_experts,), device=device, dtype=torch.float32) + a1_gs = torch.ones((num_experts,), device=device, dtype=torch.float32) + a2_gs = torch.ones((num_experts,), device=device, dtype=torch.float32) + + for expert in range(num_experts): + w1_e = w1[expert] + w2_e = w2[expert] + w1_amax = torch.abs(w1_e).max().to(torch.float32) + w2_amax = torch.abs(w2_e).max().to(torch.float32) + w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax + w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax + + w1_fp4[expert], w1_blockscale[expert] = ops.scaled_fp4_quant( + w1_e, w1_gs[expert] + ) + + w2_fp4[expert], w2_blockscale[expert] = ops.scaled_fp4_quant( + w2_e, w2_gs[expert] + ) + + def run_triton_moe( + a: 
torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a_fp8_scale: torch.Tensor, + num_repeats: int, + ): + for _ in range(num_repeats): + fused_experts( + a, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_fp8_scale, + ) + + def run_cutlass_moe_fp4( + a: torch.Tensor, + w1_fp4: torch.Tensor, + w2_fp4: torch.Tensor, + w1_blockscale: torch.Tensor, + w2_blockscale: torch.Tensor, + w1_gs: torch.Tensor, + w2_gs: torch.Tensor, + a1_gs: torch.Tensor, + a2_gs: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + m: int, + n: int, + k: int, + e: int, + device: torch.device, + num_repeats: int, + ): + for _ in range(num_repeats): + with nvtx.annotate("cutlass_moe_fp4", color="green"): + cutlass_moe_fp4( + a=a, + a1_gscale=a1_gs, + a2_gscale=a2_gs, + w1_fp4=w1_fp4, + w1_blockscale=w1_blockscale, + w1_alphas=w1_gs, + w2_fp4=w2_fp4, + w2_blockscale=w2_blockscale, + w2_alphas=w2_gs, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=m, + n=n, + k=k, + e=num_experts, + device=device, + ) + + def run_cutlass_from_graph( + a: torch.Tensor, + a1_gscale: torch.Tensor, + w1_fp4: torch.Tensor, + w1_blockscale: torch.Tensor, + w1_alphas: torch.Tensor, + a2_gscale: torch.Tensor, + w2_fp4: torch.Tensor, + w2_blockscale: torch.Tensor, + w2_alphas: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + m: int, + n: int, + k: int, + e: int, + device: torch.device, + ): + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) + ): + return cutlass_moe_fp4( + a=a, + a1_gscale=a1_gs, + w1_fp4=w1_fp4, + w1_blockscale=w1_blockscale, + w1_alphas=w1_alphas, + a2_gscale=a2_gs, + w2_fp4=w2_fp4, + w2_blockscale=w2_blockscale, + w2_alphas=w2_alphas, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=m, + n=n, + k=k, + e=num_experts, + 
device=device, + ) + + def run_triton_from_graph( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a_fp8_scale: torch.Tensor, + ): + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) + ): + return fused_experts( + a, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_fp8_scale, + ) + + def replay_graph(graph, num_repeats): + for _ in range(num_repeats): + graph.replay() + torch.cuda.synchronize() + + cutlass_stream = torch.cuda.Stream() + cutlass_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(cutlass_graph, stream=cutlass_stream): + run_cutlass_from_graph( + a=a, + a1_gscale=a1_gs, + w1_fp4=w1_fp4, + w1_blockscale=w1_blockscale, + w1_alphas=w1_gs, + a2_gscale=a2_gs, + w2_fp4=w2_fp4, + w2_blockscale=w2_blockscale, + w2_alphas=w2_gs, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=m, + n=n, + k=k, + e=num_experts, + device=device, + ) + torch.cuda.synchronize() + + triton_stream = torch.cuda.Stream() + triton_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(triton_graph, stream=triton_stream): + run_triton_from_graph( + a, + w1_fp8q_notransp, + w2_fp8q_notransp, + topk_weights, + topk_ids, + w1_fp8scale, + w2_fp8scale, + a_fp8_scale, + ) + torch.cuda.synchronize() + + min_run_time = 5 + num_warmup = 5 + num_runs = 25 + + globals = { + # Baseline params + "w1": w1, + "w2": w2, + "score": score, + "topk": topk, + "w1_fp8q_notransp": w1_fp8q_notransp, + "w2_fp8q_notransp": w2_fp8q_notransp, + "w1_fp8scale": w1_fp8scale, + "w2_fp8scale": w2_fp8scale, + "a_fp8_scale": a_fp8_scale, + # Cutlass params + "a": a, + "a1_gscale": a1_gs, + "w1_fp4": w1_fp4, + "w1_blockscale": w1_blockscale, + "w1_alphas": w1_gs, + "a2_gscale": a2_gs, + "w2_fp4": w2_fp4, + "w2_blockscale": w2_blockscale, + "w2_alphas": w2_gs, + "topk_weights": 
topk_weights, + "topk_ids": topk_ids, + "m": m, + "n": n, + "k": k, + "e": num_experts, + "device": device, + # cuda graph params + "cutlass_graph": cutlass_graph, + "triton_graph": triton_graph, + # Gen params + "num_runs": num_runs, + # Kernels + "run_triton_moe": run_triton_moe, + "run_cutlass_moe_fp4": run_cutlass_moe_fp4, + "replay_graph": replay_graph, + } + + # Warmup + run_triton_moe( + a, + w1_fp8q_notransp, + w2_fp8q_notransp, + topk_weights, + topk_ids, + w1_fp8scale, + w2_fp8scale, + a_fp8_scale, + num_warmup, + ) + + results.append( + benchmark.Timer( + stmt="run_triton_moe(a, w1_fp8q_notransp, w2_fp8q_notransp, topk_weights, topk_ids, w1_fp8scale, w2_fp8scale, a_fp8_scale, num_runs)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="triton_moe", + ).blocked_autorange(min_run_time=min_run_time) + ) + + # Warmup + replay_graph(triton_graph, num_warmup) + + results.append( + benchmark.Timer( + stmt="replay_graph(triton_graph, num_runs)", + globals=globals, + label=label, + sub_label=sub_label, + description="triton_moe_cuda_graphs", + ).blocked_autorange(min_run_time=min_run_time) + ) + + # Warmup + + run_cutlass_moe_fp4( + a, + w1_fp4, + w2_fp4, + w1_blockscale, + w2_blockscale, + w1_gs, + w2_gs, + a1_gs, + a2_gs, + topk_weights, + topk_ids, + m, + n, + k, + num_experts, + device, + num_warmup, + ) + + results.append( + benchmark.Timer( + stmt="run_cutlass_moe_fp4(a, w1_fp4, w2_fp4, w1_blockscale, w2_blockscale, w1_alphas, w2_alphas, a1_gscale, a2_gscale, topk_weights, topk_ids, m, n, k, e, device, num_runs)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="cutlass_moe_fp4", + ).blocked_autorange(min_run_time=min_run_time) + ) + + # Warmup + replay_graph(cutlass_graph, num_warmup) + + results.append( + benchmark.Timer( + stmt="replay_graph(cutlass_graph, num_runs)", + globals=globals, + label=label, + sub_label=sub_label, + description="cutlass_moe_fp4_cuda_graphs", + 
).blocked_autorange(min_run_time=min_run_time) + ) + + +def main(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + results: list[benchmark.Measurement] = [] + + for model in args.models: + for tp in args.tp_sizes: + for layer in WEIGHT_SHAPES_MOE[model]: + num_experts = layer[0] + topk = layer[1] + size_k = layer[2] + size_n = layer[3] // tp + + if len(args.limit_k) > 0 and size_k not in args.limit_k: + continue + + if len(args.limit_n) > 0 and size_n not in args.limit_n: + continue + + for per_act_token in PER_ACT_TOKEN_OPTS: + for per_out_ch in PER_OUT_CH_OPTS: + for size_m in args.batch_sizes: + mkn = (size_m, size_k, size_n) + bench_run( + results, + model, + num_experts, + topk, + per_act_token, + per_out_ch, + mkn, + ) + + compare = benchmark.Compare(results) + compare.print() + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description="Benchmark NVFP4 CUTLASS MOE across specified models/shapes/batches" + ) + parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES_MOE.keys(), + ) + parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES) + parser.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) + parser.add_argument("--limit-k", nargs="+", type=int, default=[]) + parser.add_argument("--limit-n", nargs="+", type=int, default=[]) + parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[]) + parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[]) + parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[]) + + args = parser.parse_args() + main(args) diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index c92ea43e8260..2197bceabe6c 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ 
b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -6,14 +6,18 @@ from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8, - fused_experts, - fused_topk) +from vllm.model_executor.layers.fused_moe.fused_moe import ( + cutlass_moe_fp8, + fused_experts, + fused_topk, +) from vllm.utils import FlexibleArgumentParser DEFAULT_MODELS = [ - "nm-testing/Mixtral-8x7B-Instruct-v0.1", "nm-testing/deepseekv2-lite", - "ibm-granite/granite-3.0-1b-a400m", "ibm-granite/granite-3.0-3b-a800m" + "nm-testing/Mixtral-8x7B-Instruct-v0.1", + "nm-testing/deepseekv2-lite", + "ibm-granite/granite-3.0-1b-a400m", + "ibm-granite/granite-3.0-3b-a800m", ] DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512] DEFAULT_TP_SIZES = [1] @@ -24,19 +28,27 @@ def to_fp8(tensor: torch.Tensor): finfo = torch.finfo(torch.float8_e4m3fn) - return torch.round(tensor.clamp( - min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to( + dtype=torch.float8_e4m3fn + ) -def bench_run(results: list[benchmark.Measurement], model: str, - num_experts: int, topk: int, per_act_token: bool, - per_out_ch: bool, mkn: tuple[int, int, int]): +def bench_run( + results: list[benchmark.Measurement], + model: str, + num_experts: int, + topk: int, + per_act_token: bool, + per_out_ch: bool, + mkn: tuple[int, int, int], +): label = "Quant Matmul" sub_label = ( - "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, " - "MKN=({})".format(model, num_experts, topk, per_act_token, per_out_ch, - mkn)) + "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})".format( + model, num_experts, topk, per_act_token, per_out_ch, mkn + ) + ) print(f"Testing: {sub_label}") @@ -50,35 +62,17 @@ def bench_run(results: list[benchmark.Measurement], model: str, _, a_scale = ops.scaled_fp8_quant(a) - w1_q = 
torch.empty((num_experts, 2 * n, k), - device="cuda", - dtype=torch.float8_e4m3fn) - w2_q = torch.empty((num_experts, k, n), - device="cuda", - dtype=torch.float8_e4m3fn) - w1_scale = torch.empty((num_experts, 1, 1), - device="cuda", - dtype=torch.float32) - w2_scale = torch.empty((num_experts, 1, 1), - device="cuda", - dtype=torch.float32) - - ab_strides1 = torch.full((num_experts, ), - k, - device="cuda", - dtype=torch.int64) - c_strides1 = torch.full((num_experts, ), - 2 * n, - device="cuda", - dtype=torch.int64) - ab_strides2 = torch.full((num_experts, ), - n, - device="cuda", - dtype=torch.int64) - c_strides2 = torch.full((num_experts, ), - k, - device="cuda", - dtype=torch.int64) + w1_q = torch.empty( + (num_experts, 2 * n, k), device="cuda", dtype=torch.float8_e4m3fn + ) + w2_q = torch.empty((num_experts, k, n), device="cuda", dtype=torch.float8_e4m3fn) + w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32) + w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32) + + ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64) for expert in range(num_experts): w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert]) @@ -91,82 +85,120 @@ def bench_run(results: list[benchmark.Measurement], model: str, score = torch.randn((m, num_experts), device="cuda", dtype=dtype) topk_weights, topk_ids, token_expert_indices = fused_topk( - a, score, topk, renormalize=False) + a, score, topk, renormalize=False + ) - def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - w1_scale: torch.Tensor, w2_scale: torch.Tensor, - a_scale: torch.Tensor, num_repeats: int): + def run_triton_moe( + a: 
torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a_scale: torch.Tensor, + num_repeats: int, + ): for _ in range(num_repeats): - fused_experts(a, - w1, - w2, - topk_weights, - topk_ids, - use_fp8_w8a8=True, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_scale) - - def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor, - w1: torch.Tensor, w2: torch.Tensor, - w1_scale: torch.Tensor, w2_scale: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - ab_strides1: torch.Tensor, c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, c_strides2: torch.Tensor, - num_repeats: int): + fused_experts( + a, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_scale, + ) + + def run_cutlass_moe( + a: torch.Tensor, + a_scale: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ab_strides1: torch.Tensor, + c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides2: torch.Tensor, + num_repeats: int, + ): for _ in range(num_repeats): - cutlass_moe_fp8(a, - w1, - w2, - w1_scale, - w2_scale, - topk_weights, - topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - a1_scale=a_scale) + cutlass_moe_fp8( + a, + w1, + w2, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + a1_scale=a_scale, + ) def run_cutlass_from_graph( - a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, - w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - ab_strides1: torch.Tensor, c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, c_strides2: torch.Tensor): + a: torch.Tensor, + a_scale: torch.Tensor, + w1_q: torch.Tensor, + w2_q: torch.Tensor, 
+ w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ab_strides1: torch.Tensor, + c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides2: torch.Tensor, + ): with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - return cutlass_moe_fp8(a, - w1_q, - w2_q, - w1_scale, - w2_scale, - topk_weights, - topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - a1_scale=a_scale) - - def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor, - w2: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, w1_scale: torch.Tensor, - w2_scale: torch.Tensor, a_scale: torch.Tensor): + VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) + ): + return cutlass_moe_fp8( + a, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + a1_scale=a_scale, + ) + + def run_triton_from_graph( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a_scale: torch.Tensor, + ): with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - return fused_experts(a, - w1, - w2, - topk_weights, - topk_ids, - use_fp8_w8a8=True, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_scale) + VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) + ): + return fused_experts( + a, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_scale, + ) def replay_graph(graph, num_repeats): for _ in range(num_repeats): @@ -176,16 +208,35 @@ def replay_graph(graph, num_repeats): cutlass_stream = torch.cuda.Stream() cutlass_graph = torch.cuda.CUDAGraph() with torch.cuda.graph(cutlass_graph, stream=cutlass_stream): - run_cutlass_from_graph(a, a_scale, w1_q, 
w2_q, w1_scale, w2_scale, - topk_weights, topk_ids, ab_strides1, c_strides1, - ab_strides2, c_strides2) + run_cutlass_from_graph( + a, + a_scale, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + ) torch.cuda.synchronize() triton_stream = torch.cuda.Stream() triton_graph = torch.cuda.CUDAGraph() with torch.cuda.graph(triton_graph, stream=triton_stream): - run_triton_from_graph(a, w1_q_notransp, w2_q_notransp, topk_weights, - topk_ids, w1_scale, w2_scale, a_scale) + run_triton_from_graph( + a, + w1_q_notransp, + w2_q_notransp, + topk_weights, + topk_ids, + w1_scale, + w2_scale, + a_scale, + ) torch.cuda.synchronize() min_run_time = 5 @@ -225,18 +276,27 @@ def replay_graph(graph, num_repeats): } # Warmup - run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, - w1_scale, w2_scale, a_scale, num_warmup) + run_triton_moe( + a, + w1_q_notransp, + w2_q_notransp, + topk_weights, + topk_ids, + w1_scale, + w2_scale, + a_scale, + num_warmup, + ) results.append( benchmark.Timer( - stmt= - "run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501 + stmt="run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, description="triton_moe", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) # Warmup replay_graph(triton_graph, num_warmup) @@ -248,22 +308,35 @@ def replay_graph(graph, num_repeats): label=label, sub_label=sub_label, description="triton_moe_cuda_graphs", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) # Warmup - run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, - topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, - num_warmup) + run_cutlass_moe( + a, + 
a_scale, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + num_warmup, + ) results.append( benchmark.Timer( - stmt= - "run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501 + stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, description="grouped_gemm_moe", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) # Warmup replay_graph(cutlass_graph, num_warmup) @@ -275,7 +348,8 @@ def replay_graph(graph, num_repeats): label=label, sub_label=sub_label, description="grouped_gemm_moe_cuda_graphs", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) def main(args): @@ -303,8 +377,15 @@ def main(args): for per_out_ch in PER_OUT_CH_OPTS: for size_m in DEFAULT_BATCH_SIZES: mkn = (size_m, size_k, size_n) - bench_run(results, model, num_experts, topk, - per_act_token, per_out_ch, mkn) + bench_run( + results, + model, + num_experts, + topk, + per_act_token, + per_out_ch, + mkn, + ) compare = benchmark.Compare(results) compare.print() @@ -312,7 +393,8 @@ def main(args): if __name__ == "__main__": parser = FlexibleArgumentParser( - description="Benchmark Marlin across specified models/shapes/batches") + description="Benchmark Marlin across specified models/shapes/batches" + ) parser.add_argument( "--models", nargs="+", @@ -320,21 +402,14 @@ def main(args): default=DEFAULT_MODELS, choices=WEIGHT_SHAPES_MOE.keys(), ) - parser.add_argument("--tp-sizes", - nargs="+", - type=int, - default=DEFAULT_TP_SIZES) - parser.add_argument("--batch-sizes", - nargs="+", - type=int, - default=DEFAULT_BATCH_SIZES) + parser.add_argument("--tp-sizes", 
nargs="+", type=int, default=DEFAULT_TP_SIZES) + parser.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) parser.add_argument("--limit-k", nargs="+", type=int, default=[]) parser.add_argument("--limit-n", nargs="+", type=int, default=[]) parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[]) - parser.add_argument("--limit-per-act-token", - nargs="+", - type=int, - default=[]) + parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[]) parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[]) args = parser.parse_args() diff --git a/benchmarks/kernels/benchmark_layernorm.py b/benchmarks/kernels/benchmark_layernorm.py index e12d74c01e43..f21ca97eeb8a 100644 --- a/benchmarks/kernels/benchmark_layernorm.py +++ b/benchmarks/kernels/benchmark_layernorm.py @@ -10,14 +10,16 @@ @torch.inference_mode() -def main(num_tokens: int, - hidden_size: int, - add_residual: bool, - dtype: torch.dtype, - seed: int = 0, - do_profile: bool = False, - num_warmup_iters: int = 5, - num_iters: int = 100) -> None: +def main( + num_tokens: int, + hidden_size: int, + add_residual: bool, + dtype: torch.dtype, + seed: int = 0, + do_profile: bool = False, + num_warmup_iters: int = 5, + num_iters: int = 100, +) -> None: current_platform.seed_everything(seed) torch.set_default_device("cuda") @@ -56,33 +58,35 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: print(f"Kernel running time: {latency * 1000000:.3f} us") -if __name__ == '__main__': - parser = FlexibleArgumentParser( - description="Benchmark the layernorm kernel.") +if __name__ == "__main__": + parser = FlexibleArgumentParser(description="Benchmark the layernorm kernel.") parser.add_argument("--num-tokens", type=int, default=4096) parser.add_argument("--hidden-size", type=int, default=8192) parser.add_argument("--add-residual", action="store_true") - parser.add_argument("--dtype", - type=str, - choices=["half", "bfloat16", 
"float"], - default="half") + parser.add_argument( + "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half" + ) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--profile", action="store_true") parser.add_argument("--num-warmup-iters", type=int, default=5) - parser.add_argument("--num-iters", - type=int, - default=100, - help="Number of benchmark iterations. " - "If --profile is set, this number is ignored") + parser.add_argument( + "--num-iters", + type=int, + default=100, + help="Number of benchmark iterations. " + "If --profile is set, this number is ignored", + ) args = parser.parse_args() print(args) - main(num_tokens=args.num_tokens, - hidden_size=args.hidden_size, - add_residual=args.add_residual, - dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], - seed=args.seed, - do_profile=args.profile, - num_warmup_iters=args.num_warmup_iters, - num_iters=args.num_iters) + main( + num_tokens=args.num_tokens, + hidden_size=args.hidden_size, + add_residual=args.add_residual, + dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], + seed=args.seed, + do_profile=args.profile, + num_warmup_iters=args.num_warmup_iters, + num_iters=args.num_iters, + ) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index d382ede10b41..6c1284930c1e 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -20,18 +20,36 @@ from vllm.triton_utils import HAS_TRITON if HAS_TRITON: - from vllm.lora.ops.triton_ops import (LoRAKernelMeta, lora_expand, - lora_shrink) - from vllm.lora.ops.triton_ops.utils import (_LORA_A_PTR_DICT, - _LORA_B_PTR_DICT) + from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink + from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT from vllm.utils import FlexibleArgumentParser DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) DEFAULT_TP_SIZES = [1] DEFAULT_BATCH_SIZES = [ - 1, 16, 32, 64, 128, 192, 256, 320, 384, 448, 512, 640, 
768, 896, 1024, - 2048, 3072, 4096, 5120, 6144, 7168, 8192 + 1, + 16, + 32, + 64, + 128, + 192, + 256, + 320, + 384, + 448, + 512, + 640, + 768, + 896, + 1024, + 2048, + 3072, + 4096, + 5120, + 6144, + 7168, + 8192, ] DEFAULT_HIDDEN_SIZES = [1024, 2048, 4096, 8192, 16384] DEFAULT_LORA_RANKS = [16] @@ -52,12 +70,9 @@ def dtype_to_str(dtype: torch.dtype): raise ValueError(f"Unsupported dtype {dtype}") -def make_rand_lora_weight_tensor(k: int, - n: int, - num_loras: int, - dtype: torch.dtype, - device: str = "cuda") -> torch.Tensor: - +def make_rand_lora_weight_tensor( + k: int, n: int, num_loras: int, dtype: torch.dtype, device: str = "cuda" +) -> torch.Tensor: # LoRA weights column major return torch.rand((num_loras, n, k), dtype=dtype).to(device) @@ -78,18 +93,15 @@ def make_rand_tensors( A = torch.rand(a_shape, dtype=a_dtype).to(device) # LoRA weights column major - Bs = [ - torch.rand(b_shape, dtype=b_dtype).to(device) - for _ in range(num_slices) - ] + Bs = [torch.rand(b_shape, dtype=b_dtype).to(device) for _ in range(num_slices)] C = torch.zeros(c_shape, dtype=c_dtype).to(device) return A, Bs, C -def make_prompt_lora_mapping(num_prompts: int, num_active_loras: int, - sort_by_lora_id: bool, - device: str) -> torch.Tensor: +def make_prompt_lora_mapping( + num_prompts: int, num_active_loras: int, sort_by_lora_id: bool, device: str +) -> torch.Tensor: """ All prompts are mapped to a LoRA ID in range [0, num_active_loras). where 0 refers to first lora, 1 refers to second lora and so on. @@ -97,9 +109,7 @@ def make_prompt_lora_mapping(num_prompts: int, num_active_loras: int, assert num_active_loras > 0 if not sort_by_lora_id: - return torch.randint(0, - num_active_loras, (num_prompts, ), - dtype=torch.long) + return torch.randint(0, num_active_loras, (num_prompts,), dtype=torch.long) # Divide LoRAs equally and in order. 
part_size = num_prompts // num_active_loras @@ -110,14 +120,18 @@ def make_prompt_lora_mapping(num_prompts: int, num_active_loras: int, while len(prompt_lora_mapping) < num_prompts: prompt_lora_mapping.extend([lora_id] * part_size) lora_id = lora_id + 1 if lora_id + 1 < num_active_loras else lora_id - return torch.tensor(prompt_lora_mapping[:num_prompts], - dtype=torch.long, - device=device) - - -def make_token_lora_mapping(num_tokens: int, num_prompts: int, - prompt_lora_mapping: torch.Tensor, - seq_len_tensor: torch.Tensor, device: str): + return torch.tensor( + prompt_lora_mapping[:num_prompts], dtype=torch.long, device=device + ) + + +def make_token_lora_mapping( + num_tokens: int, + num_prompts: int, + prompt_lora_mapping: torch.Tensor, + seq_len_tensor: torch.Tensor, + device: str, +): """ Make token_lora_mapping from prompt_lora_mapping and seq_lens_tensor """ @@ -136,11 +150,15 @@ def make_token_lora_mapping(num_tokens: int, num_prompts: int, return torch.tensor(token_lora_mapping, dtype=torch.long, device=device) -def ref_group_gemm(ref_out: torch.Tensor, input: torch.Tensor, - lora_weights: list[torch.Tensor], - seq_lens_cpu: torch.Tensor, - prompt_lora_mapping_cpu: torch.Tensor, scaling: float, - add_inputs: Optional[bool]): +def ref_group_gemm( + ref_out: torch.Tensor, + input: torch.Tensor, + lora_weights: list[torch.Tensor], + seq_lens_cpu: torch.Tensor, + prompt_lora_mapping_cpu: torch.Tensor, + scaling: float, + add_inputs: Optional[bool], +): """ Torch group gemm reference implementation to test correctness of benchmarking operations. 
@@ -149,7 +167,7 @@ def ref_group_gemm(ref_out: torch.Tensor, input: torch.Tensor, out_list = [] current_offset = 0 for lora_index, b_length in zip(range(batches), seq_lens_cpu): - x = input[current_offset:b_length + current_offset, :] + x = input[current_offset : b_length + current_offset, :] current_offset += b_length w = lora_weights[prompt_lora_mapping_cpu[lora_index]] result = torch.nn.functional.linear(x, w) @@ -168,6 +186,7 @@ class OpType(Enum): """ LoRA Ops to benchmark and its properties. """ + LORA_SHRINK = auto() LORA_EXPAND = auto() @@ -188,8 +207,9 @@ def is_expand_fn(self) -> bool: def num_slices(self) -> list[int]: return [1, 2, 3] - def mkn(self, batch_size: int, seq_length: int, hidden_size: int, - lora_rank: int) -> tuple[int, int, int]: + def mkn( + self, batch_size: int, seq_length: int, hidden_size: int, lora_rank: int + ) -> tuple[int, int, int]: num_tokens = batch_size * seq_length if self.is_shrink_fn(): m = num_tokens @@ -203,7 +223,7 @@ def mkn(self, batch_size: int, seq_length: int, hidden_size: int, return m, k, n def matmul_dtypes( - self, op_dtype: torch.dtype + self, op_dtype: torch.dtype ) -> tuple[torch.dtype, torch.dtype, torch.dtype]: """ return a type, b type and c type for A x B = C @@ -215,9 +235,14 @@ def matmul_dtypes( return torch.float32, op_dtype, op_dtype def matmul_shapes( - self, batch_size: int, seq_length: int, hidden_size: int, - lora_rank: int, num_loras: int, - num_slices: int) -> tuple[tuple[int], tuple[int], tuple[int]]: + self, + batch_size: int, + seq_length: int, + hidden_size: int, + lora_rank: int, + num_loras: int, + num_slices: int, + ) -> tuple[tuple[int], tuple[int], tuple[int]]: """ Given num_slices, return the shapes of the A, B, and C matrices in A x B = C, for the op_type @@ -241,31 +266,38 @@ def bench_fn(self) -> Callable: raise ValueError(f"Unrecognized optype {self}") - def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor, - lora_weights: list[torch.Tensor], - **kwargs) -> 
Callable: + def run_ref_group_gemm( + self, + output: torch.Tensor, + input: torch.Tensor, + lora_weights: list[torch.Tensor], + **kwargs, + ) -> Callable: """Each benchmark operation expects the input, lora_weights and outputs - in a slightly different format. Refer to self.matmul_shapes(). - run_ref_group_gemm accounts for those differences in executing a - reference group gemm for correctness testing. + in a slightly different format. Refer to self.matmul_shapes(). + run_ref_group_gemm accounts for those differences in executing a + reference group gemm for correctness testing. """ w_dtype = lora_weights[0].dtype num_slices = len(lora_weights) if self in [OpType.LORA_SHRINK]: for slice_idx in range(num_slices): - ref_group_gemm(ref_out=output[slice_idx, :], - input=input, - lora_weights=lora_weights[slice_idx], - **kwargs) + ref_group_gemm( + ref_out=output[slice_idx, :], + input=input, + lora_weights=lora_weights[slice_idx], + **kwargs, + ) elif self in [OpType.LORA_EXPAND]: hidden_size = lora_weights[0].shape[1] for slice_idx in range(num_slices): slice_offset = slice_idx * hidden_size ref_group_gemm( - ref_out=output[:, slice_offset:slice_offset + hidden_size], + ref_out=output[:, slice_offset : slice_offset + hidden_size], input=input[slice_idx].clone().to(dtype=w_dtype), lora_weights=lora_weights[slice_idx], - **kwargs) + **kwargs, + ) else: raise ValueError(f"Unrecognized optype {self}") @@ -275,6 +307,7 @@ class BenchmarkContext: """ LoRA benchmark context """ + batch_size: int hidden_size: int num_loras: int @@ -299,17 +332,18 @@ def bench_label(self) -> str: return f"lora-{self.dtype}" def bench_sublabel(self, op_type: OpType) -> str: - m, k, n = op_type.mkn(self.batch_size, self.seq_length, - self.hidden_size, self.lora_rank) + m, k, n = op_type.mkn( + self.batch_size, self.seq_length, self.hidden_size, self.lora_rank + ) desc = { - 'bs': self.batch_size, - 'sl': self.seq_length, - 'm': m, - 'k': k, - 'n': n, - 'num_loras': self.num_loras, - 
'sort_by_lora': self.sort_by_lora_id, - 'num_slices': self.num_slices, + "bs": self.batch_size, + "sl": self.seq_length, + "m": m, + "k": k, + "n": n, + "num_loras": self.num_loras, + "sort_by_lora": self.sort_by_lora_id, + "num_slices": self.num_slices, } return json.dumps(desc) @@ -319,6 +353,7 @@ class BenchmarkTensors: """ Input/Output tensors used for benchmarks """ + # matmul tensors input: torch.Tensor lora_weights_lst: list[torch.Tensor] @@ -330,23 +365,29 @@ class BenchmarkTensors: prompt_lora_mapping: torch.Tensor def io_types(self) -> str: - return (f"{dtype_to_str(self.input.dtype)}x" - f"{dtype_to_str(self.lora_weights_lst[0].dtype)}=>" - f"{dtype_to_str(self.output.dtype)}") + return ( + f"{dtype_to_str(self.input.dtype)}x" + f"{dtype_to_str(self.lora_weights_lst[0].dtype)}=>" + f"{dtype_to_str(self.output.dtype)}" + ) @staticmethod - def make(ctx: BenchmarkContext, - op_type: OpType, - device: str = "cuda") -> "BenchmarkTensors": - + def make( + ctx: BenchmarkContext, op_type: OpType, device: str = "cuda" + ) -> "BenchmarkTensors": # Make input / output matmul tensors. a_shape, b_shape, c_shape = op_type.matmul_shapes( - ctx.batch_size, ctx.seq_length, ctx.hidden_size, ctx.lora_rank, - ctx.num_loras, ctx.num_slices) + ctx.batch_size, + ctx.seq_length, + ctx.hidden_size, + ctx.lora_rank, + ctx.num_loras, + ctx.num_slices, + ) a_type, b_type, c_type = op_type.matmul_dtypes(ctx.dtype) - input_tensor, lora_weights, output_tensor = \ - make_rand_tensors(a_shape, b_shape, c_shape, a_type, b_type, c_type, - num_slices = ctx.num_slices) + input_tensor, lora_weights, output_tensor = make_rand_tensors( + a_shape, b_shape, c_shape, a_type, b_type, c_type, num_slices=ctx.num_slices + ) # Make metadata tensors. # Keep the metadata tensors in the CPU for further processing if needed. @@ -356,27 +397,38 @@ def make(ctx: BenchmarkContext, # Make metadata tensors involved in correctness testing. 
# Prepare seq lens tensor - seq_len_tensor = torch.randint(ctx.seq_length, ctx.seq_length + 1, - (ctx.batch_size, )) + seq_len_tensor = torch.randint( + ctx.seq_length, ctx.seq_length + 1, (ctx.batch_size,) + ) assert total_tokens == seq_len_tensor.sum() # Prepare prompt lora indices tensor prompt_lora_indices_tensor = make_prompt_lora_mapping( - ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu") + ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu" + ) # Make LoRAKernelMeta token_lora_indices_tensor = make_token_lora_mapping( - total_tokens, ctx.batch_size, prompt_lora_indices_tensor, - seq_len_tensor, "cpu") + total_tokens, + ctx.batch_size, + prompt_lora_indices_tensor, + seq_len_tensor, + "cpu", + ) lora_kernel_meta = LoRAKernelMeta.make( max_loras=ctx.num_loras, max_num_tokens=token_lora_indices_tensor.size(0), - device="cpu") - lora_kernel_meta.prepare_tensors( - token_lora_mapping=token_lora_indices_tensor) - - return BenchmarkTensors(input_tensor, lora_weights, output_tensor, - lora_kernel_meta, seq_len_tensor, - prompt_lora_indices_tensor) + device="cpu", + ) + lora_kernel_meta.prepare_tensors(token_lora_mapping=token_lora_indices_tensor) + + return BenchmarkTensors( + input_tensor, + lora_weights, + output_tensor, + lora_kernel_meta, + seq_len_tensor, + prompt_lora_indices_tensor, + ) def sanity_check(self) -> None: """ @@ -386,7 +438,7 @@ def sanity_check(self) -> None: # check metadata tensors assert torch.sum(self.seq_lens) == num_tokens num_seqs = self.seq_lens.shape[0] - #assert self.seq_start_loc.shape[0] == num_seqs + # assert self.seq_start_loc.shape[0] == num_seqs assert self.prompt_lora_mapping.shape[0] == num_seqs assert self.lora_kernel_meta.token_lora_mapping.shape[0] == num_tokens @@ -430,8 +482,11 @@ def as_lora_shrink_kwargs(self) -> dict[str, Any]: _, num_tokens, _, num_slices = self.metadata() # Sanity check matrix shapes. 
- i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[ - 0].shape, self.output.shape + i_shape, lw_shape, o_shape = ( + self.input.shape, + self.lora_weights_lst[0].shape, + self.output.shape, + ) # Expected input shape [num_tokens, hidden_size] assert len(i_shape) == 2 assert i_shape[0] == num_tokens @@ -445,16 +500,17 @@ def as_lora_shrink_kwargs(self) -> dict[str, Any]: assert o_shape == (num_slices, num_tokens, lora_rank) return { - 'inputs': self.input, - 'lora_a_weights': self.lora_weights_lst, - 'output_tensor': self.output, - 'token_lora_mapping': self.lora_kernel_meta.token_lora_mapping, - 'token_indices_sorted_by_lora_ids': - self.lora_kernel_meta.token_indices_sorted_by_lora_ids, - 'num_tokens_per_lora': self.lora_kernel_meta.num_tokens_per_lora, - 'lora_token_start_loc': self.lora_kernel_meta.lora_token_start_loc, - 'lora_ids': self.lora_kernel_meta.active_lora_ids, - 'scaling': 1.0, + "inputs": self.input, + "lora_a_weights": self.lora_weights_lst, + "output_tensor": self.output, + "token_lora_mapping": self.lora_kernel_meta.token_lora_mapping, + "token_indices_sorted_by_lora_ids": ( + self.lora_kernel_meta.token_indices_sorted_by_lora_ids + ), + "num_tokens_per_lora": self.lora_kernel_meta.num_tokens_per_lora, + "lora_token_start_loc": self.lora_kernel_meta.lora_token_start_loc, + "lora_ids": self.lora_kernel_meta.active_lora_ids, + "scaling": 1.0, } def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]: @@ -464,8 +520,11 @@ def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]: _, num_tokens, _, num_slices = self.metadata() # Sanity check matrix shapes. 
- i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[ - 0].shape, self.output.shape + i_shape, lw_shape, o_shape = ( + self.input.shape, + self.lora_weights_lst[0].shape, + self.output.shape, + ) # Expected input shape : [num_slices, num_tokens, lora_rank] assert len(i_shape) == 3 assert i_shape[0] == num_slices @@ -480,22 +539,23 @@ def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]: assert o_shape == (num_tokens, hidden_size * num_slices) return { - 'inputs': self.input, - 'lora_b_weights': self.lora_weights_lst, - 'output_tensor': self.output, - 'token_lora_mapping': self.lora_kernel_meta.token_lora_mapping, - 'token_indices_sorted_by_lora_ids': - self.lora_kernel_meta.token_indices_sorted_by_lora_ids, - 'num_tokens_per_lora': self.lora_kernel_meta.num_tokens_per_lora, - 'lora_token_start_loc': self.lora_kernel_meta.lora_token_start_loc, - 'lora_ids': self.lora_kernel_meta.active_lora_ids, - 'offset_start': 0, - 'add_inputs': add_inputs, + "inputs": self.input, + "lora_b_weights": self.lora_weights_lst, + "output_tensor": self.output, + "token_lora_mapping": self.lora_kernel_meta.token_lora_mapping, + "token_indices_sorted_by_lora_ids": ( + self.lora_kernel_meta.token_indices_sorted_by_lora_ids + ), + "num_tokens_per_lora": self.lora_kernel_meta.num_tokens_per_lora, + "lora_token_start_loc": self.lora_kernel_meta.lora_token_start_loc, + "lora_ids": self.lora_kernel_meta.active_lora_ids, + "offset_start": 0, + "add_inputs": add_inputs, } - def bench_fn_kwargs(self, - op_type: OpType, - add_inputs: Optional[bool] = None) -> dict[str, Any]: + def bench_fn_kwargs( + self, op_type: OpType, add_inputs: Optional[bool] = None + ) -> dict[str, Any]: if op_type.is_shrink_fn(): assert add_inputs is None else: @@ -507,8 +567,9 @@ def bench_fn_kwargs(self, return self.as_lora_expand_kwargs(add_inputs) raise ValueError(f"Unrecognized optype {self}") - def test_correctness(self, op_type: OpType, - expand_fn_add_inputs: Optional[bool]) -> 
bool: + def test_correctness( + self, op_type: OpType, expand_fn_add_inputs: Optional[bool] + ) -> bool: """ Test correctness of op_type implementation against a grouped gemm reference implementation. @@ -518,8 +579,7 @@ def test_correctness(self, op_type: OpType, ref_output = self.output.clone() self.output.zero_() - op_type.bench_fn()( - **self.bench_fn_kwargs(op_type, expand_fn_add_inputs)) + op_type.bench_fn()(**self.bench_fn_kwargs(op_type, expand_fn_add_inputs)) op_type.run_ref_group_gemm( ref_output, @@ -528,7 +588,8 @@ def test_correctness(self, op_type: OpType, seq_lens_cpu=seq_lens_cpu, prompt_lora_mapping_cpu=prompt_lora_mapping_cpu, scaling=1.0, - add_inputs=expand_fn_add_inputs) + add_inputs=expand_fn_add_inputs, + ) rtol, atol = { torch.float16: (6e-2, 6e-2), @@ -539,13 +600,14 @@ def test_correctness(self, op_type: OpType, return torch.allclose(ref_output, self.output, rtol=rtol, atol=atol) -def bench_optype(ctx: BenchmarkContext, - arg_pool_size: int, - op_type: OpType, - cuda_graph_nops: Optional[int] = None, - expand_fn_add_inputs: Optional[bool] = None, - test_correctness: bool = False) -> TMeasurement: - +def bench_optype( + ctx: BenchmarkContext, + arg_pool_size: int, + op_type: OpType, + cuda_graph_nops: Optional[int] = None, + expand_fn_add_inputs: Optional[bool] = None, + test_correctness: bool = False, +) -> TMeasurement: assert arg_pool_size >= 1 if op_type.is_shrink_fn(): assert expand_fn_add_inputs is None @@ -553,17 +615,17 @@ def bench_optype(ctx: BenchmarkContext, assert expand_fn_add_inputs is not None # BenchmarkContext -> BenchmarkTensors - bench_tensors : list[BenchmarkTensors] = \ - [BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size)] + bench_tensors: list[BenchmarkTensors] = [ + BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size) + ] for bt in bench_tensors: bt.sanity_check() # Test correctness of our implementation. 
if test_correctness: - assert all([ - bt.test_correctness(op_type, expand_fn_add_inputs) - for bt in bench_tensors - ]) + assert all( + [bt.test_correctness(op_type, expand_fn_add_inputs) for bt in bench_tensors] + ) # BenchmarkTensors -> dict (kwargs) kwargs_list = [ @@ -585,40 +647,49 @@ def bench_optype(ctx: BenchmarkContext, for k, v in _kwargs.items(): kwargs[k].values.append(v) - describe_args = (f"add_inputs={expand_fn_add_inputs}" - if expand_fn_add_inputs is not None else "") - description = ( - f"{op_type.name}({describe_args}) ({bench_tensors[0].io_types()})") + describe_args = ( + f"add_inputs={expand_fn_add_inputs}" if expand_fn_add_inputs is not None else "" + ) + description = f"{op_type.name}({describe_args}) ({bench_tensors[0].io_types()})" cuda_graph_params = None if cuda_graph_nops: cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops) timer = None - with Bench(cuda_graph_params, - ctx.bench_label(), ctx.bench_sublabel(op_type), description, - op_type.bench_fn(), **kwargs) as bench: + with Bench( + cuda_graph_params, + ctx.bench_label(), + ctx.bench_sublabel(op_type), + description, + op_type.bench_fn(), + **kwargs, + ) as bench: timer = bench.run() return timer -def bench_torch_mm(ctx: BenchmarkContext, - arg_pool_size: int, - op_type: OpType, - cuda_graph_nops: Optional[int] = None) -> TMeasurement: +def bench_torch_mm( + ctx: BenchmarkContext, + arg_pool_size: int, + op_type: OpType, + cuda_graph_nops: Optional[int] = None, +) -> TMeasurement: """ Benchmark basic torch.mm as a roofline. When all the input tokens have the same LoRA ID, the LoRA kernels are just - a matmul. This torch.mm benchmark serves as a roofline for that case. + a matmul. This torch.mm benchmark serves as a roofline for that case. input op_type is used in determining the m, k, n dimensions for the matmul. 
""" - batch_size, hidden_size, lora_rank, seq_length, dtype = (ctx.batch_size, - ctx.hidden_size, - ctx.lora_rank, - ctx.seq_length, - ctx.dtype) + batch_size, hidden_size, lora_rank, seq_length, dtype = ( + ctx.batch_size, + ctx.hidden_size, + ctx.lora_rank, + ctx.seq_length, + ctx.dtype, + ) m, k, n = op_type.mkn(batch_size, seq_length, hidden_size, lora_rank) # For a fairer comparison. @@ -632,18 +703,24 @@ def bench_torch_mm(ctx: BenchmarkContext, Cs.append(torch.rand((m, n), dtype=dtype).to("cuda")) # Make torch.mm kwargs - mm_kwargs = {'input': ArgPool(As), 'mat2': ArgPool(Bs), 'out': ArgPool(Cs)} + mm_kwargs = {"input": ArgPool(As), "mat2": ArgPool(Bs), "out": ArgPool(Cs)} description = ( f"single-lora roofline using torch.mm ({dtype_to_str(dtype)}" f"x{dtype_to_str(dtype)}" - f"=>{dtype_to_str(dtype)})") + f"=>{dtype_to_str(dtype)})" + ) cuda_graph_params = None if cuda_graph_nops: cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops) - with Bench(cuda_graph_params, ctx.bench_label(), - ctx.bench_sublabel(op_type), description, torch.mm, - **mm_kwargs) as bench: + with Bench( + cuda_graph_params, + ctx.bench_label(), + ctx.bench_sublabel(op_type), + description, + torch.mm, + **mm_kwargs, + ) as bench: return bench.run() @@ -660,8 +737,7 @@ def use_cuda_graph_recommendation() -> str: """ -def print_timers(timers: list[TMeasurement], - args: Optional[argparse.Namespace] = None): +def print_timers(timers: list[TMeasurement], args: Optional[argparse.Namespace] = None): compare = TBenchmark.Compare(timers) compare.print() @@ -670,22 +746,23 @@ def print_timers(timers: list[TMeasurement], f"Note : The timings reported above is for {args.cuda_graph_nops} " "consecutive invocations of the benchmarking functions. " f"Please divide by {args.cuda_graph_nops} for single invocation " - "timings.") + "timings." + ) - print("Note on Comparison with torch.mm : The torch.mm numbers are " - "benchmark numbers of a simple matmul emulating the single lora " - "case. 
It is provided as a roofline for comparing our LoRA Kernel " - "implementations. It is expected that the LoRA kernels will be " - "slower than torch.mm in cases where num_loras is big. But for " - "small num_loras the goal should be to match the torch.mm numbers.") + print( + "Note on Comparison with torch.mm : The torch.mm numbers are " + "benchmark numbers of a simple matmul emulating the single lora " + "case. It is provided as a roofline for comparing our LoRA Kernel " + "implementations. It is expected that the LoRA kernels will be " + "slower than torch.mm in cases where num_loras is big. But for " + "small num_loras the goal should be to match the torch.mm numbers." + ) def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]): - if args.cuda_graph_nops is not None: assert args.cuda_graph_nops > 0 - print(f"Benchmarking {args.cuda_graph_nops} invocations inside a CUDA " - "Graph") + print(f"Benchmarking {args.cuda_graph_nops} invocations inside a CUDA Graph") else: print(f"CUDA Graphs not enabled.\n{use_cuda_graph_recommendation()}") @@ -697,21 +774,30 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]): for bench_op in bench_ops: for num_slices in bench_op.num_slices(): _ctx = bench_ctx.with_seq_length(seq_len).with_num_slices( - num_slices) + num_slices + ) # Benchmark torch.mm as a roofline seq_len_timers.append( - bench_torch_mm(_ctx, args.arg_pool_size, bench_op, - args.cuda_graph_nops)) + bench_torch_mm( + _ctx, args.arg_pool_size, bench_op, args.cuda_graph_nops + ) + ) # Benchmark bench_op - expand_fn_add_inputs = [ - None - ] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs + expand_fn_add_inputs = ( + [None] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs + ) for add_input_arg in expand_fn_add_inputs: seq_len_timers.append( - bench_optype(_ctx, args.arg_pool_size, bench_op, - args.cuda_graph_nops, add_input_arg, - args.test_correctness)) + bench_optype( + _ctx, + args.arg_pool_size, + bench_op, + 
args.cuda_graph_nops, + add_input_arg, + args.test_correctness, + ) + ) print_timers(seq_len_timers) timers.extend(seq_len_timers) @@ -733,13 +819,17 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]): pickle.dump(timers, f) -def as_benchmark_contexts(hidden_sizes: list[int], lora_ranks: list[int], - args: argparse.Namespace) -> list[BenchmarkContext]: - +def as_benchmark_contexts( + hidden_sizes: list[int], lora_ranks: list[int], args: argparse.Namespace +) -> list[BenchmarkContext]: ctxs: list[BenchmarkContext] = [] for batch_size, hidden_size, lora_rank, num_loras, sort_by_lora_id in product( # noqa - args.batch_sizes, list(hidden_sizes), lora_ranks, args.num_loras, - args.sort_by_lora_id): + args.batch_sizes, + list(hidden_sizes), + lora_ranks, + args.num_loras, + args.sort_by_lora_id, + ): ctxs.append( BenchmarkContext( batch_size=batch_size, @@ -747,13 +837,16 @@ def as_benchmark_contexts(hidden_sizes: list[int], lora_ranks: list[int], lora_rank=lora_rank, num_loras=num_loras, num_active_loras=args.num_active_loras - if args.num_active_loras else num_loras, + if args.num_active_loras + else num_loras, # To be filled based on the OpType to benchmark seq_length=None, sort_by_lora_id=sort_by_lora_id, dtype=args.dtype, # To be filled based on the OpType to benchmark - num_slices=None)) + num_slices=None, + ) + ) return ctxs @@ -761,13 +854,16 @@ def as_benchmark_contexts(hidden_sizes: list[int], lora_ranks: list[int], def run_list_bench(args: argparse.Namespace): print(args) - print("List bench :\n" - f" Hidden Sizes {args.hidden_sizes}" - f" LoRA Ranks {args.lora_ranks}") + print( + "List bench :\n" + f" Hidden Sizes {args.hidden_sizes}" + f" LoRA Ranks {args.lora_ranks}" + ) # Get all benchmarking contexts bench_contexts: list[BenchmarkContext] = as_benchmark_contexts( - hidden_sizes=args.hidden_sizes, lora_ranks=args.lora_ranks, args=args) + hidden_sizes=args.hidden_sizes, lora_ranks=args.lora_ranks, args=args + ) run(args, 
bench_contexts) @@ -776,19 +872,22 @@ def run_range_bench(args: argparse.Namespace): print(args) hidden_sizes = list( - range(args.hidden_sizes_start, args.hidden_sizes_end + 1, - args.hidden_sizes_increment)) + range( + args.hidden_sizes_start, + args.hidden_sizes_end + 1, + args.hidden_sizes_increment, + ) + ) lora_ranks = list( - range(args.lora_ranks_start, args.lora_ranks_end + 1, - args.lora_ranks_increment)) + range(args.lora_ranks_start, args.lora_ranks_end + 1, args.lora_ranks_increment) + ) - print("Range bench :\n" - f" Hidden Sizes {hidden_sizes}" - f" LoRA Ranks {lora_ranks}") + print(f"Range bench :\n Hidden Sizes {hidden_sizes} LoRA Ranks {lora_ranks}") # Get all benchmarking contexts bench_contexts: list[BenchmarkContext] = as_benchmark_contexts( - hidden_sizes=hidden_sizes, lora_ranks=lora_ranks, args=args) + hidden_sizes=hidden_sizes, lora_ranks=lora_ranks, args=args + ) run(args, bench_contexts) @@ -806,21 +905,19 @@ def hidden_sizes_from_model(model: str, tp_size: int) -> set[int]: # Get all hidden sizes hidden_sizes: set[int] = set() for model_name, tp_size in product(args.models, args.tp_sizes): - hidden_sizes = hidden_sizes.union( - hidden_sizes_from_model(model_name, tp_size)) + hidden_sizes = hidden_sizes.union(hidden_sizes_from_model(model_name, tp_size)) - print("Model bench :\n" - f" Hidden Sizes {hidden_sizes}" - f" LoRA Ranks {args.lora_ranks}") + print(f"Model bench :\n Hidden Sizes {hidden_sizes} LoRA Ranks {args.lora_ranks}") # Get all benchmarking contexts bench_contexts: list[BenchmarkContext] = as_benchmark_contexts( - hidden_sizes=hidden_sizes, lora_ranks=args.lora_ranks, args=args) + hidden_sizes=hidden_sizes, lora_ranks=args.lora_ranks, args=args + ) run(args, bench_contexts) -if __name__ == '__main__': +if __name__ == "__main__": def to_torch_dtype(dt): if dt == "torch.float16": @@ -830,14 +927,15 @@ def to_torch_dtype(dt): raise ValueError("unsupported dtype") def get_bool(s: str) -> bool: - return s.lower() in ['true', '1'] 
+ return s.lower() in ["true", "1"] def add_common_command_args(p: argparse.ArgumentParser): p.add_argument( "--dtype", type=to_torch_dtype, required=True, - help="Available options are ['torch.float16', 'torch.bfloat16']") + help="Available options are ['torch.float16', 'torch.bfloat16']", + ) p.add_argument( "--arg-pool-size", @@ -845,56 +943,66 @@ def add_common_command_args(p: argparse.ArgumentParser): default=32, help="Run profiles with a pool of input/output/meta tensors instead" "of simply reusing the same tensors for all runs. A bigger arg-pool" - "mitigates hardware caching effects during benchmarking.") + "mitigates hardware caching effects during benchmarking.", + ) p.add_argument( "--cuda-graph-nops", type=int, - help=("when set profiling is done using cudagraph, " - "with the given number of operations in a graph." - "Note that the measurement returned is the time " - "taken for N consecutive executions of the benchmarking " - "functions, where N is the value of this argument.")) - p.add_argument("--num-loras", - nargs="+", - type=int, - default=DEFAULT_NUM_LORAS) - p.add_argument("--num-active-loras", - type=int, - default=None, - help="Active LoRAs. When None, all LoRAs are active") - p.add_argument("--sort-by-lora-id", - nargs="+", - type=get_bool, - default=DEFAULT_SORT_BY_LORA_IDS) - p.add_argument("--op-types", - nargs="+", - type=OpType.from_str, - default=list(OpType)) - p.add_argument('--seq-lengths', - nargs="+", - type=int, - default=DEFAULT_SEQ_LENGTHS) - p.add_argument("--batch-sizes", - nargs="+", - type=int, - default=DEFAULT_BATCH_SIZES) - p.add_argument("--expand-fn-add-inputs", - nargs="+", - type=get_bool, - default=DEFAULT_EXPAND_FN_ADD_INPUTS) + help=( + "when set profiling is done using cudagraph, " + "with the given number of operations in a graph." + "Note that the measurement returned is the time " + "taken for N consecutive executions of the benchmarking " + "functions, where N is the value of this argument." 
+ ), + ) + p.add_argument("--num-loras", nargs="+", type=int, default=DEFAULT_NUM_LORAS) + p.add_argument( + "--num-active-loras", + type=int, + default=None, + help="Active LoRAs. When None, all LoRAs are active", + ) + p.add_argument( + "--sort-by-lora-id", + nargs="+", + type=get_bool, + default=DEFAULT_SORT_BY_LORA_IDS, + ) + p.add_argument( + "--op-types", nargs="+", type=OpType.from_str, default=list(OpType) + ) + p.add_argument( + "--seq-lengths", nargs="+", type=int, default=DEFAULT_SEQ_LENGTHS + ) + p.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) + p.add_argument( + "--expand-fn-add-inputs", + nargs="+", + type=get_bool, + default=DEFAULT_EXPAND_FN_ADD_INPUTS, + ) p.add_argument( - '-o', - '--output-directory', + "-o", + "--output-directory", type=str, - help=("Output directory to store a the list of benchmarking" - "TMeasurement objects as a pickle file")) + help=( + "Output directory to store a the list of benchmarking" + "TMeasurement objects as a pickle file" + ), + ) p.add_argument( "--test-correctness", - action='store_true', - help=("When enabled, the benchmarking functions are tested" - "for correctness before the actual benchmarking")) + action="store_true", + help=( + "When enabled, the benchmarking functions are tested" + "for correctness before the actual benchmarking" + ), + ) parser = FlexibleArgumentParser( description=f""" @@ -910,50 +1018,45 @@ def add_common_command_args(p: argparse.ArgumentParser): range_bench example: python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8 """, # noqa: E501 - formatter_class=argparse.RawTextHelpFormatter) + 
formatter_class=argparse.RawTextHelpFormatter, + ) subparsers = parser.add_subparsers(dest="cmd", required=True) list_parser = subparsers.add_parser("list_bench") - list_parser.add_argument("--hidden-sizes", - nargs="+", - type=int, - default=DEFAULT_HIDDEN_SIZES) - list_parser.add_argument("--lora-ranks", - nargs="+", - type=int, - default=DEFAULT_LORA_RANKS) + list_parser.add_argument( + "--hidden-sizes", nargs="+", type=int, default=DEFAULT_HIDDEN_SIZES + ) + list_parser.add_argument( + "--lora-ranks", nargs="+", type=int, default=DEFAULT_LORA_RANKS + ) add_common_command_args(list_parser) list_parser.set_defaults(func=run_list_bench) range_parser = subparsers.add_parser("range_bench") range_parser.add_argument("--hidden-sizes-start", type=int, required=True) range_parser.add_argument("--hidden-sizes-end", type=int, required=True) - range_parser.add_argument("--hidden-sizes-increment", - type=int, - required=True) + range_parser.add_argument("--hidden-sizes-increment", type=int, required=True) range_parser.add_argument("--lora-ranks-start", type=int, required=True) range_parser.add_argument("--lora-ranks-end", type=int, required=True) - range_parser.add_argument("--lora-ranks-increment", - type=int, - required=True) + range_parser.add_argument("--lora-ranks-increment", type=int, required=True) add_common_command_args(range_parser) range_parser.set_defaults(func=run_range_bench) model_parser = subparsers.add_parser("model_bench") - model_parser.add_argument("--models", - nargs="+", - type=str, - default=DEFAULT_MODELS, - choices=WEIGHT_SHAPES.keys()) - model_parser.add_argument("--tp-sizes", - nargs="+", - type=int, - default=DEFAULT_TP_SIZES) - model_parser.add_argument("--lora-ranks", - nargs="+", - type=int, - default=DEFAULT_LORA_RANKS) + model_parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys(), + ) + model_parser.add_argument( + "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES + ) + 
model_parser.add_argument( + "--lora-ranks", nargs="+", type=int, default=DEFAULT_LORA_RANKS + ) add_common_command_args(model_parser) model_parser.set_defaults(func=run_model_bench) diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index a661ea9d7e60..f8f1db04790b 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -20,12 +20,18 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales, - marlin_zero_points) + GPTQ_MARLIN_MAX_PARALLEL, + GPTQ_MARLIN_MIN_THREAD_N, + marlin_permute_scales, + marlin_zero_points, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( - MarlinWorkspace) + MarlinWorkspace, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - pack_rows, quantize_weights) + pack_rows, + quantize_weights, +) from vllm.scalar_type import ScalarType, scalar_types from vllm.utils import FlexibleArgumentParser @@ -82,12 +88,14 @@ def rand_data(shape, dtype=torch.float16, scale=1): return torch.randint(-15, 15, shape, dtype=dtype, device="cuda") -def quantize_and_pack(atype: torch.dtype, - w: torch.Tensor, - wtype: ScalarType, - stype: Optional[torch.dtype], - group_size: Optional[int], - zero_points: bool = False): +def quantize_and_pack( + atype: torch.dtype, + w: torch.Tensor, + wtype: ScalarType, + stype: Optional[torch.dtype], + group_size: Optional[int], + zero_points: bool = False, +): assert wtype.is_integer(), "TODO: support floating point weights" w_ref, w_q, w_s, w_zp = quantize_weights( @@ -96,21 +104,24 @@ def quantize_and_pack(atype: torch.dtype, group_size=group_size, zero_points=zero_points, # to match how the kernel applies zps - ref_zero_points_after_scales=True) + ref_zero_points_after_scales=True, + ) w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape) return w_ref, w_q, 
w_s, w_zp -def create_bench_tensors(shape: tuple[int, int, int], types: TypeConfig, - group_size: Optional[int]) -> list[BenchmarkTensors]: +def create_bench_tensors( + shape: tuple[int, int, int], types: TypeConfig, group_size: Optional[int] +) -> list[BenchmarkTensors]: m, n, k = shape # we want to make sure that weights don't fit into L2 cache between runs so # we construct enough weights to exceed L2 cache, which is 50mb on a H100 # so we target total weight size > 2*50mb - num_weights = math.ceil(2 * 50 * 1024**2 * 8 / - (k * n * types.weight_type.size_bits)) + num_weights = math.ceil( + 2 * 50 * 1024**2 * 8 / (k * n * types.weight_type.size_bits) + ) a = rand_data((m, k), types.act_type, scale=5) @@ -124,8 +135,13 @@ def create_bench_tensors(shape: tuple[int, int, int], types: TypeConfig, w = w.to(torch.float16) w_ref, w_q_packed, w_s, w_zp = quantize_and_pack( - a.dtype, w, types.weight_type, types.group_scale_type, group_size, - types.group_zero_type is not None) + a.dtype, + w, + types.weight_type, + types.group_scale_type, + group_size, + types.group_zero_type is not None, + ) if not a.dtype.is_floating_point: aiinfo = torch.iinfo(a.dtype) @@ -133,21 +149,30 @@ def create_bench_tensors(shape: tuple[int, int, int], types: TypeConfig, w_ref = w_ref.to(torch.float32) - w_ch_s = None if types.channel_scale_type is None else\ - rand_data((n,), types.channel_scale_type) - w_tok_s = None if types.token_scale_type is None else\ - rand_data((m,), types.token_scale_type) + w_ch_s = ( + None + if types.channel_scale_type is None + else rand_data((n,), types.channel_scale_type) + ) + w_tok_s = ( + None + if types.token_scale_type is None + else rand_data((m,), types.token_scale_type) + ) benchmark_tensors.append( - BenchmarkTensors(w_ref=w_ref, - a=a, - w_q=w_q_packed, - wtype=types.weight_type, - w_g_s=w_s, - w_g_zp=w_zp, - group_size=group_size, - w_ch_s=w_ch_s, - w_tok_s=w_tok_s)) + BenchmarkTensors( + w_ref=w_ref, + a=a, + w_q=w_q_packed, + 
wtype=types.weight_type, + w_g_s=w_s, + w_g_zp=w_zp, + group_size=group_size, + w_ch_s=w_ch_s, + w_tok_s=w_tok_s, + ) + ) return benchmark_tensors @@ -170,50 +195,57 @@ def cutlass_scaled_mm_create_bench_fn(bt: BenchmarkTensors) -> Callable: scale_b = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device) w_col_major = bt.w_ref.to(bt.a.dtype).t().contiguous().t() return lambda: ops.cutlass_scaled_mm( - bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16) + bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16 + ) def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: device = bt.a.device - workspace = MarlinWorkspace(bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MAX_PARALLEL) + workspace = MarlinWorkspace( + bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL + ) if bt.w_g_zp is None: w_zp = torch.empty(0, dtype=torch.int, device=device) else: - w_zp = marlin_zero_points(bt.w_g_zp, bt.w_ref.shape[0], - bt.w_ref.shape[1], bt.wtype.size_bits) + w_zp = marlin_zero_points( + bt.w_g_zp, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.wtype.size_bits + ) if bt.group_size is None: w_s = torch.tensor([], device="cuda", dtype=torch.half) else: - w_s = marlin_permute_scales(bt.w_g_s, bt.w_ref.shape[0], - bt.w_ref.shape[1], bt.group_size) + w_s = marlin_permute_scales( + bt.w_g_s, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.group_size + ) sort_indices = torch.empty(0, dtype=torch.int, device=device) g_idx = torch.empty(0, dtype=torch.int, device=device) - w_q = ops.gptq_marlin_repack(bt.w_q, sort_indices, bt.w_ref.shape[0], - bt.w_ref.shape[1], bt.wtype.size_bits) + w_q = ops.gptq_marlin_repack( + bt.w_q, sort_indices, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.wtype.size_bits + ) if bt.a.dtype.is_floating_point: assert bt.w_ch_s is None assert bt.w_tok_s is None assert bt.group_size is not None - fn = lambda: ops.gptq_marlin_gemm(a=bt.a, - b_q_weight=w_q, - b_scales=w_s, - b_zeros=w_zp, - g_idx=g_idx, - perm=sort_indices, - 
workspace=workspace.scratch, - b_q_type=bt.wtype, - size_m=bt.a.shape[0], - size_n=bt.w_ref.shape[1], - size_k=bt.w_ref.shape[0], - is_k_full=True, - is_zp_float=False) + fn = lambda: ops.gptq_marlin_gemm( + a=bt.a, + b_q_weight=w_q, + b_scales=w_s, + b_zeros=w_zp, + g_idx=g_idx, + perm=sort_indices, + workspace=workspace.scratch, + b_q_type=bt.wtype, + size_m=bt.a.shape[0], + size_n=bt.w_ref.shape[1], + size_k=bt.w_ref.shape[0], + is_k_full=True, + is_zp_float=False, + ) else: assert bt.a.dtype == torch.int8 assert bt.wtype == scalar_types.uint4b8 @@ -221,36 +253,35 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: if bt.w_ch_s is not None: s_ch = bt.w_ch_s.to(torch.float32) else: - s_ch = torch.ones(bt.w_ref.shape[1], - dtype=torch.float32, - device=device) + s_ch = torch.ones(bt.w_ref.shape[1], dtype=torch.float32, device=device) if bt.w_tok_s is not None: s_tok = bt.w_tok_s.to(torch.float32) else: - s_tok = torch.ones(bt.a.shape[0], - dtype=torch.float32, - device=device) - - fn = lambda: ops.marlin_qqq_gemm(a=bt.a, - b_q_weight=w_q, - s_group=w_s, - s_tok=s_tok, - s_ch=s_ch, - workspace=workspace.scratch, - size_m=bt.a.shape[0], - size_n=bt.w_ref.shape[1], - size_k=bt.w_ref.shape[0]) + s_tok = torch.ones(bt.a.shape[0], dtype=torch.float32, device=device) + + fn = lambda: ops.marlin_qqq_gemm( + a=bt.a, + b_q_weight=w_q, + s_group=w_s, + s_tok=s_tok, + s_ch=s_ch, + workspace=workspace.scratch, + size_m=bt.a.shape[0], + size_n=bt.w_ref.shape[1], + size_k=bt.w_ref.shape[0], + ) return fn -def machete_create_bench_fn(bt: BenchmarkTensors, - out_type=torch.dtype, - schedule=None) -> Callable: +def machete_create_bench_fn( + bt: BenchmarkTensors, out_type=torch.dtype, schedule=None +) -> Callable: w_q = bt.w_q.t().contiguous().t() # make col major - w_q = ops.machete_prepack_B(w_q, bt.a.dtype, bt.wtype, - None if bt.w_g_s is None else bt.w_g_s.dtype) + w_q = ops.machete_prepack_B( + w_q, bt.a.dtype, bt.wtype, None if bt.w_g_s is None else 
bt.w_g_s.dtype + ) w_g_zp = bt.w_g_zp if w_g_zp is not None: @@ -275,26 +306,24 @@ def machete_create_bench_fn(bt: BenchmarkTensors, # bench -def bench_fns(label: str, sub_label: str, description: str, - fns: list[Callable]): - +def bench_fns(label: str, sub_label: str, description: str, fns: list[Callable]): min_run_time = 1 if not NVTX_PROFILE else 0.1 res = TBenchmark.Timer( stmt=""" for fn in fns: fn() """, - globals={ - "fns": fns - }, + globals={"fns": fns}, label=label, sub_label=sub_label, description=description, ).blocked_autorange(min_run_time=min_run_time) if NVTX_PROFILE: - with nvtx.annotate("mm-bench"), nvtx.annotate( - f"{label}|{sub_label}|{description}"): + with ( + nvtx.annotate("mm-bench"), + nvtx.annotate(f"{label}|{sub_label}|{description}"), + ): fns[0]() return res @@ -304,19 +333,20 @@ def bench_fns(label: str, sub_label: str, description: str, _SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None -def bench(types: TypeConfig, - group_size: int, - m: int, - k: int, - n: int, - label: str, - sub_label: str, - sweep_schedules: bool = True) -> list[TMeasurement]: +def bench( + types: TypeConfig, + group_size: int, + m: int, + k: int, + n: int, + label: str, + sub_label: str, + sweep_schedules: bool = True, +) -> list[TMeasurement]: benchmark_tensors = create_bench_tensors((m, n, k), types, group_size) sub_label += f", L={len(benchmark_tensors)}" - name_type_string = f"W{types.weight_type}"+\ - f"-A{terse_type_name(types.act_type)}" + name_type_string = f"W{types.weight_type}" + f"-A{terse_type_name(types.act_type)}" if types.group_scale_type is not None: name_type_string += f"-GS{terse_type_name(types.group_scale_type)}" if types.group_zero_type is not None: @@ -332,31 +362,45 @@ def bench(types: TypeConfig, # pytorch impl timers.append( bench_fns( - label, sub_label, "torch.matmul (fp16)", - [torch_matmul_f16_create_bench_fn(bt) - for bt in benchmark_tensors])) + label, + sub_label, + "torch.matmul (fp16)", + 
[torch_matmul_f16_create_bench_fn(bt) for bt in benchmark_tensors], + ) + ) if types.act_type == torch.int8 or types.act_type == torch.float8_e4m3fn: timers.append( bench_fns( - label, sub_label, - f"cutlass_scaled_mm ({terse_type_name(types.act_type)})", [ - cutlass_scaled_mm_create_bench_fn(bt) - for bt in benchmark_tensors - ])) + label, + sub_label, + f"cutlass_scaled_mm ({terse_type_name(types.act_type)})", + [cutlass_scaled_mm_create_bench_fn(bt) for bt in benchmark_tensors], + ) + ) if types.act_type != torch.float8_e4m3fn: timers.append( - bench_fns(label, sub_label, f"marlin ({name_type_string})", - [marlin_create_bench_fn(bt) - for bt in benchmark_tensors])) + bench_fns( + label, + sub_label, + f"marlin ({name_type_string})", + [marlin_create_bench_fn(bt) for bt in benchmark_tensors], + ) + ) # machete timers.append( - bench_fns(label, sub_label, f"machete ({name_type_string})", [ - machete_create_bench_fn(bt, out_type=types.output_type) - for bt in benchmark_tensors - ])) + bench_fns( + label, + sub_label, + f"machete ({name_type_string})", + [ + machete_create_bench_fn(bt, out_type=types.output_type) + for bt in benchmark_tensors + ], + ) + ) if sweep_schedules: global _SWEEP_SCHEDULES_RESULTS @@ -371,7 +415,8 @@ def bench(types: TypeConfig, group_zeros_type=types.group_zero_type, token_scales_type=types.token_scale_type, channel_scales_type=types.channel_scale_type, - out_type=types.output_type) + out_type=types.output_type, + ) if schedules is None or len(schedules) == 0: raise ValueError("No schedules found to sweep") @@ -383,11 +428,17 @@ def bench(types: TypeConfig, if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4: continue - res = bench_fns(label, sub_label, "machete_best", [ - machete_create_bench_fn( - bt, out_type=types.output_type, schedule=schedule) - for bt in benchmark_tensors - ]) + res = bench_fns( + label, + sub_label, + "machete_best", + [ + machete_create_bench_fn( + bt, out_type=types.output_type, schedule=schedule + ) + for bt 
in benchmark_tensors + ], + ) results_row = { "M": m, @@ -398,10 +449,8 @@ def bench(types: TypeConfig, "median": res.median, } if _SWEEP_SCHEDULES_RESULTS is None: - _SWEEP_SCHEDULES_RESULTS = pd.DataFrame( - columns=results_row.keys()) - _SWEEP_SCHEDULES_RESULTS.\ - loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row + _SWEEP_SCHEDULES_RESULTS = pd.DataFrame(columns=results_row.keys()) + _SWEEP_SCHEDULES_RESULTS.loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row print(f" {res.median:5.5} ", schedule) if not best or res.median < best.median: @@ -422,8 +471,9 @@ def print_timers(timers: list[TMeasurement]): def run(args, MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]: types = TypeConfig( act_type=args.act_type, - weight_type=scalar_types.uint4b8 if args.group_zero_type is None \ - else scalar_types.uint4, + weight_type=scalar_types.uint4b8 + if args.group_zero_type is None + else scalar_types.uint4, output_type=args.out_type, group_scale_type=args.group_scale_type, group_zero_type=args.group_zero_type, @@ -433,14 +483,16 @@ def run(args, MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]: results: list[TMeasurement] = [] for m, k, n in MKNs: - timers = bench(types, - args.group_size, - m, - k, - n, - f"{args.act_type}-gemm", - f"MKN=({m}x{k}x{n})", - sweep_schedules=args.sweep_schedules) + timers = bench( + types, + args.group_size, + m, + k, + n, + f"{args.act_type}-gemm", + f"MKN=({m}x{k}x{n})", + sweep_schedules=args.sweep_schedules, + ) print_timers(timers) results.extend(timers) @@ -454,7 +506,6 @@ def make_output( base_description: str, timestamp=None, ): - print(f"== All Results {base_description} ====") print_timers(data) @@ -468,8 +519,7 @@ def make_output( def run_square_bench(args): - dim_sizes = list( - range(args.dim_start, args.dim_end + 1, args.dim_increment)) + dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment)) MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) data = run(args.dtype, 
args.sweep_schedules, MKNs) @@ -479,8 +529,9 @@ def run_square_bench(args): def run_range_bench(args): m_start, k_start, n_start = (int(x) for x in args.dim_start.split(",")) m_end, k_end, n_end = (int(x) for x in args.dim_end.split(",")) - m_increment, k_increment, n_increment = \ - (int(x) for x in args.dim_increment.split(",")) + m_increment, k_increment, n_increment = ( + int(x) for x in args.dim_increment.split(",") + ) Ms = list(range(m_start, m_end + 1, m_increment)) Ks = list(range(k_start, k_end + 1, k_increment)) Ns = list(range(n_start, n_end + 1, n_increment)) @@ -492,7 +543,6 @@ def run_range_bench(args): def run_model_bench(args): - print("Benchmarking models:") for i, model in enumerate(args.models): print(f"[{i}] {model}") @@ -535,10 +585,13 @@ def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]: with open(f"model_bench-{type_string}-{timestr}.pkl", "wb") as f: args_dict = vars(args) args_dict.pop("func") - pkl.dump({ - "args": args_dict, - "results": all_results, - }, f) + pkl.dump( + { + "args": args_dict, + "results": all_results, + }, + f, + ) if __name__ == "__main__": @@ -554,7 +607,6 @@ def to_torch_dtype(dt): }[dt] class ToTorchDtype(argparse.Action): - def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, to_torch_dtype(values)) @@ -580,32 +632,32 @@ def __call__(self, parser, namespace, values, option_string=None): "--act-type", action=ToTorchDtype, required=True, - choices=['bfloat16', 'float16', 'int8', 'float8_e4m3fn'], + choices=["bfloat16", "float16", "int8", "float8_e4m3fn"], ) parser.add_argument( "--group-scale-type", action=ToTorchDtype, - choices=['bfloat16', 'float16'], + choices=["bfloat16", "float16"], ) parser.add_argument( "--group-zero-type", type=to_torch_dtype, - choices=['bfloat16', 'float16'], + choices=["bfloat16", "float16"], ) parser.add_argument( "--channel-scale-type", action=ToTorchDtype, - choices=['float'], + choices=["float"], ) 
parser.add_argument( "--token-scale-type", action=ToTorchDtype, - choices=['float'], + choices=["float"], ) parser.add_argument( "--out-type", action=ToTorchDtype, - choices=['bfloat16', 'float16'], + choices=["bfloat16", "float16"], ) parser.add_argument( "--group-size", @@ -618,9 +670,11 @@ def __call__(self, parser, namespace, values, option_string=None): action="store_true", help="Run a sweep over all supported schedules", ) - parser.add_argument("--sweep-csv-out", - help="CSV to store sweep results", - default="sch_sweep_results.csv") + parser.add_argument( + "--sweep-csv-out", + help="CSV to store sweep results", + default="sch_sweep_results.csv", + ) subparsers = parser.add_subparsers(dest="cmd", required=True) square_parser = subparsers.add_parser("square_bench") @@ -634,17 +688,20 @@ def __call__(self, parser, namespace, values, option_string=None): "--dim-start", type=str, required=True, - help="Start value for M,K,N as common separated list") + help="Start value for M,K,N as common separated list", + ) range_parser.add_argument( "--dim-end", type=str, required=True, - help="End value (inclusive) for M,K,N as common separated list") + help="End value (inclusive) for M,K,N as common separated list", + ) range_parser.add_argument( "--dim-increment", type=str, required=True, - help="Increment value for M,K,N as common separated list") + help="Increment value for M,K,N as common separated list", + ) range_parser.set_defaults(func=run_range_bench) model_parser = subparsers.add_parser("model_bench") @@ -655,14 +712,12 @@ def __call__(self, parser, namespace, values, option_string=None): default=DEFAULT_MODELS, choices=WEIGHT_SHAPES.keys(), ) - model_parser.add_argument("--tp-sizes", - nargs="+", - type=int, - default=DEFAULT_TP_SIZES) - model_parser.add_argument("--batch-sizes", - nargs="+", - type=int, - default=DEFAULT_BATCH_SIZES) + model_parser.add_argument( + "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES + ) + model_parser.add_argument( + 
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) model_parser.set_defaults(func=run_model_bench) args = parser.parse_args() diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py index 1e785ac8fc73..b17baff2e5f5 100644 --- a/benchmarks/kernels/benchmark_marlin.py +++ b/benchmarks/kernels/benchmark_marlin.py @@ -6,19 +6,34 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( - GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, - GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES) + GPTQ_MARLIN_24_MAX_PARALLEL, + GPTQ_MARLIN_24_MIN_THREAD_N, + GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, + GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES, +) from vllm.model_executor.layers.quantization.utils.allspark_utils import ( - ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, ALLSPARK_SUPPORTED_QUANT_TYPES) + ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, + ALLSPARK_SUPPORTED_QUANT_TYPES, +) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, - MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types) + GPTQ_MARLIN_MAX_PARALLEL, + GPTQ_MARLIN_MIN_THREAD_N, + MARLIN_SUPPORTED_GROUP_SIZES, + query_marlin_supported_quant_types, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( - MarlinWorkspace, marlin_quantize) + MarlinWorkspace, + marlin_quantize, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( - marlin_24_quantize) + marlin_24_quantize, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights) + gptq_pack, + gptq_quantize_weights, + quantize_weights, + sort_weights, +) from vllm.scalar_type import ScalarType from vllm.utils import FlexibleArgumentParser @@ -29,22 +44,29 @@ K_FULL_OPTS = [False, True] -def bench_run(results: 
list[benchmark.Measurement], model: str, - act_order: bool, is_k_full: bool, quant_type: ScalarType, - group_size: int, size_m: int, size_k: int, size_n: int): +def bench_run( + results: list[benchmark.Measurement], + model: str, + act_order: bool, + is_k_full: bool, + quant_type: ScalarType, + group_size: int, + size_m: int, + size_k: int, + size_n: int, +): label = "Quant Matmul" - sub_label = ("{}, act={} k_full={}, q={}, g={}, " - "MKN=({}x{}x{})".format(model, act_order, is_k_full, - str(quant_type), group_size, size_m, - size_k, size_n)) + sub_label = "{}, act={} k_full={}, q={}, g={}, MKN=({}x{}x{})".format( + model, act_order, is_k_full, str(quant_type), group_size, size_m, size_k, size_n + ) print(f"Testing: {sub_label}") a = torch.randn(size_m, size_k).to(torch.half).cuda() b = torch.rand(size_k, size_n).to(torch.half).cuda() - a_tmp = (torch.zeros(size_m, size_k).to(torch.half).cuda()) + a_tmp = torch.zeros(size_m, size_k).to(torch.half).cuda() # Marlin quant ( @@ -57,14 +79,16 @@ def bench_run(results: list[benchmark.Measurement], model: str, ) = marlin_quantize(b, quant_type, group_size, act_order) # Marlin_24 quant - (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, - marlin_24_s) = marlin_24_quantize(b, quant_type, group_size) + (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = ( + marlin_24_quantize(b, quant_type, group_size) + ) marlin_zp = torch.empty(0, dtype=torch.int, device=b.device) # GPTQ quant - (w_ref, q_w, s, g_idx, - rand_perm) = gptq_quantize_weights(b, quant_type, group_size, act_order) + (w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights( + b, quant_type, group_size, act_order + ) q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n) # For act_order, sort the "weights" and "g_idx" @@ -74,32 +98,37 @@ def bench_run(results: list[benchmark.Measurement], model: str, (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx) # Prepare - marlin_workspace = MarlinWorkspace(size_n, 
GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MAX_PARALLEL) + marlin_workspace = MarlinWorkspace( + size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL + ) - marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N, - GPTQ_MARLIN_24_MAX_PARALLEL) + marlin_24_workspace = MarlinWorkspace( + size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL + ) marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int) # AllSpark W8A16 quant - as_supported_case = (quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES - and group_size == -1 and not act_order and is_k_full) + as_supported_case = ( + quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES + and group_size == -1 + and not act_order + and is_k_full + ) if as_supported_case: properties = torch.cuda.get_device_properties(b.device.index) sm_count = properties.multi_processor_count sm_version = properties.major * 10 + properties.minor - supported_arch = (sm_version >= 80 and sm_version < 90) + supported_arch = sm_version >= 80 and sm_version < 90 as_supported_case = as_supported_case and supported_arch if supported_arch: has_zp = False - w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, - has_zp) + w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp) qw = qw.to(torch.uint8) - qw_reorder, s_reorder, zp_reorder = \ - ops.allspark_repack_weight( - qw, s, zp, has_zp) + qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight( + qw, s, zp, has_zp + ) CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD globals = { @@ -136,8 +165,7 @@ def bench_run(results: list[benchmark.Measurement], model: str, "zp_reorder": zp_reorder if as_supported_case else None, "sm_count": sm_count if as_supported_case else None, "sm_version": sm_version if as_supported_case else None, - "CUBLAS_M_THRESHOLD": - CUBLAS_M_THRESHOLD if as_supported_case else None, + "CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD if as_supported_case else None, # Kernels "gptq_marlin_gemm": ops.gptq_marlin_gemm, "gptq_marlin_24_gemm": 
ops.gptq_marlin_24_gemm, @@ -158,60 +186,63 @@ def bench_run(results: list[benchmark.Measurement], model: str, label=label, sub_label=sub_label, description="pytorch_gemm", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) results.append( benchmark.Timer( - stmt= - "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501 + stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, description="gptq_marlin_gemm_fp16", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) results.append( benchmark.Timer( - stmt= - "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501 + stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, description="gptq_marlin_gemm_fp32", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) - if (quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES - and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES): + if ( + quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES + and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES + ): results.append( benchmark.Timer( - stmt= - "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, 
size_n, size_k)", # noqa: E501 + stmt="output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, description="gptq_marlin_24_gemm", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) results.append( benchmark.Timer( - stmt= - "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501 + stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, description="gptq_marlin_repack", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) if as_supported_case: results.append( benchmark.Timer( - stmt= - "output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501 + stmt="output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, description="allspark_w8a16_gemm_fp32", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) def main(args): @@ -233,37 +264,50 @@ def main(args): continue for act_order in ACT_ORDER_OPTS: - if len(args.limit_act_order - ) > 0 and act_order not in args.limit_act_order: + if ( + len(args.limit_act_order) > 0 + and act_order not in args.limit_act_order + ): continue for is_k_full in K_FULL_OPTS: - if len(args.limit_k_full - ) > 0 and is_k_full not in args.limit_k_full: + if ( + len(args.limit_k_full) > 0 + and is_k_full not in args.limit_k_full + ): continue - for quant_type in query_marlin_supported_quant_types( - False): - if len(args.limit_num_bits) 
> 0 and \ - quant_type.size_bits not in args.limit_num_bits: + for quant_type in query_marlin_supported_quant_types(False): + if ( + len(args.limit_num_bits) > 0 + and quant_type.size_bits not in args.limit_num_bits + ): continue for group_size in MARLIN_SUPPORTED_GROUP_SIZES: - if len( - args.limit_group_size - ) > 0 and group_size not in args.limit_group_size: + if ( + len(args.limit_group_size) > 0 + and group_size not in args.limit_group_size + ): continue # For act_order, the group_size must be less than # size_k - if act_order and (group_size == size_k - or group_size == -1): + if act_order and (group_size == size_k or group_size == -1): continue for size_m in args.batch_sizes: - bench_run(results, model, act_order, is_k_full, - quant_type, group_size, size_m, - size_k, size_n) + bench_run( + results, + model, + act_order, + is_k_full, + quant_type, + group_size, + size_m, + size_k, + size_n, + ) compare = benchmark.Compare(results) compare.print() @@ -274,7 +318,8 @@ def main(args): # if __name__ == "__main__": parser = FlexibleArgumentParser( - description="Benchmark Marlin across specified models/shapes/batches") + description="Benchmark Marlin across specified models/shapes/batches" + ) parser.add_argument( "--models", nargs="+", @@ -282,10 +327,9 @@ def main(args): default=DEFAULT_MODELS, choices=WEIGHT_SHAPES.keys(), ) - parser.add_argument("--batch-sizes", - nargs="+", - type=int, - default=DEFAULT_BATCH_SIZES) + parser.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) parser.add_argument("--limit-k", nargs="+", type=int, default=[]) parser.add_argument("--limit-n", nargs="+", type=int, default=[]) parser.add_argument("--limit-group-size", nargs="+", type=int, default=[]) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 4e328b4d49e5..c2f7660858f5 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -31,56 +31,60 @@ class 
BenchmarkConfig(TypedDict): num_stages: int -def benchmark_config(config: BenchmarkConfig, - num_tokens: int, - num_experts: int, - shard_intermediate_size: int, - hidden_size: int, - topk: int, - dtype: torch.dtype, - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - num_iters: int = 100, - block_quant_shape: List[int] = None, - use_deep_gemm: bool = False) -> float: +def benchmark_config( + config: BenchmarkConfig, + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + num_iters: int = 100, + block_quant_shape: List[int] = None, + use_deep_gemm: bool = False, +) -> float: init_dtype = torch.float16 if use_fp8_w8a8 else dtype x = torch.randn(num_tokens, hidden_size, dtype=dtype) if use_int8_w8a16: - w1 = torch.randint(-127, - 127, ( - num_experts, - shard_intermediate_size, - hidden_size, - ), - dtype=torch.int8) - w2 = torch.randint(-127, - 127, ( - num_experts, - hidden_size, - shard_intermediate_size // 2, - ), - dtype=torch.int8) + w1 = torch.randint( + -127, + 127, + ( + num_experts, + shard_intermediate_size, + hidden_size, + ), + dtype=torch.int8, + ) + w2 = torch.randint( + -127, + 127, + ( + num_experts, + hidden_size, + shard_intermediate_size // 2, + ), + dtype=torch.int8, + ) else: - w1 = torch.randn(num_experts, - shard_intermediate_size, - hidden_size, - dtype=init_dtype) - w2 = torch.randn(num_experts, - hidden_size, - shard_intermediate_size // 2, - dtype=init_dtype) - gating_output = torch.randn(num_iters, - num_tokens, - num_experts, - dtype=torch.float32) + w1 = torch.randn( + num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype + ) + w2 = torch.randn( + num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype + ) + gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32) w1_scale = None w2_scale = None a1_scale = None a2_scale = None if use_int8_w8a16: - w1_scale = 
torch.randn((num_experts, 2 * shard_intermediate_size), - dtype=torch.float32) + w1_scale = torch.randn( + (num_experts, 2 * shard_intermediate_size), dtype=torch.float32 + ) w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32) if use_fp8_w8a8: if block_quant_shape: @@ -93,10 +97,14 @@ def benchmark_config(config: BenchmarkConfig, n_tiles_w2 = (K + block_n - 1) // block_n k_tiles_w1 = (K + block_k - 1) // block_k k_tiles_w2 = (N + block_k - 1) // block_k - w1_scale = torch.rand((E, n_tiles_w1, k_tiles_w1), - dtype=torch.float32) * factor_for_scale - w2_scale = torch.rand((E, n_tiles_w2, k_tiles_w2), - dtype=torch.float32) * factor_for_scale + w1_scale = ( + torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) + * factor_for_scale + ) + w2_scale = ( + torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) + * factor_for_scale + ) else: w1_scale = torch.randn(num_experts, dtype=torch.float32) w2_scale = torch.randn(num_experts, dtype=torch.float32) @@ -114,10 +122,12 @@ def prepare(i: int): def run(): from vllm.model_executor.layers.fused_moe import override_config + with override_config(config): if use_deep_gemm: topk_weights, topk_ids, token_expert_indices = fused_topk( - x, input_gating, topk, False) + x, input_gating, topk, False + ) return fused_experts( x, w1, @@ -213,8 +223,7 @@ def get_rocm_tuning_space(use_fp16): return param_ranges -def get_configs_compute_bound(use_fp16, - block_quant_shape) -> list[dict[str, int]]: +def get_configs_compute_bound(use_fp16, block_quant_shape) -> list[dict[str, int]]: configs: list[BenchmarkConfig] = [] if current_platform.is_rocm(): @@ -250,20 +259,25 @@ def get_configs_compute_bound(use_fp16, if block_quant_shape is not None and not use_fp16: block_n, block_k = block_quant_shape[0], block_quant_shape[1] for config in configs[:]: - if config["BLOCK_SIZE_K"] % block_k != 0 or config[ - "BLOCK_SIZE_N"] % block_n != 0: + if ( + config["BLOCK_SIZE_K"] % block_k != 0 + or config["BLOCK_SIZE_N"] % 
block_n != 0 + ): configs.remove(config) return configs -def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size, - search_space, is_fp16, topk): +def prune_rocm_search_space( + num_tokens, shard_intermediate_size, hidden_size, search_space, is_fp16, topk +): N1, K1 = shard_intermediate_size, hidden_size N2, K2 = hidden_size, shard_intermediate_size // 2 - pruned_space_1 = prune_rocm_configs(num_tokens * topk, N1, K1, - search_space, is_fp16) - pruned_space_2 = prune_rocm_configs(num_tokens * topk, N2, K2, - search_space, is_fp16) + pruned_space_1 = prune_rocm_configs( + num_tokens * topk, N1, K1, search_space, is_fp16 + ) + pruned_space_2 = prune_rocm_configs( + num_tokens * topk, N2, K2, search_space, is_fp16 + ) search_space = merge_unique_dicts(pruned_space_1, pruned_space_2) return search_space @@ -301,14 +315,14 @@ def prune_rocm_configs(M, N, K, configs, is_fp16=True): SPLIT_K = config.get("SPLIT_K", 1) GROUP_M = config.get("GROUP_SIZE_M") if is_fp16: - if (matrix_instr_nonkdim > BLOCK_SIZE_M - or matrix_instr_nonkdim > BLOCK_SIZE_N): + if ( + matrix_instr_nonkdim > BLOCK_SIZE_M + or matrix_instr_nonkdim > BLOCK_SIZE_N + ): continue - if (matrix_instr_nonkdim >= M - and matrix_instr_nonkdim != BLOCK_SIZE_M): + if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M: continue - if (matrix_instr_nonkdim >= N - and matrix_instr_nonkdim != BLOCK_SIZE_N): + if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N: continue # Skip BLOCK_SIZE that is too large compare to M/N # unless BLOCK_SIZE is already small enough @@ -329,8 +343,10 @@ def prune_rocm_configs(M, N, K, configs, is_fp16=True): continue # out of shared memory resource # TODO (zhanglx): This does not consider the LDS usage in the epilogue - LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + - BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b) + LDS = ( + BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b + ) if LDS > 65536: 
continue # Skip small block sizes and num_warps for large gemm @@ -364,7 +380,6 @@ def merge_unique_dicts(list1, list2): @ray.remote(num_gpus=1) class BenchmarkWorker: - def __init__(self, seed: int) -> None: torch.set_default_device("cuda") current_platform.seed_everything(seed) @@ -388,36 +403,40 @@ def benchmark( use_deep_gemm: bool = False, ) -> tuple[dict[str, int], float]: current_platform.seed_everything(self.seed) - dtype_str = get_config_dtype_str(dtype, - use_int8_w8a16=use_int8_w8a16, - use_fp8_w8a8=use_fp8_w8a8) + dtype_str = get_config_dtype_str( + dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 + ) # NOTE(woosuk): The current naming convention uses w2.shape[2], which # is the intermediate size after silu_and_mul. - op_config = get_moe_configs(num_experts, shard_intermediate_size // 2, - dtype_str) + op_config = get_moe_configs( + num_experts, shard_intermediate_size // 2, dtype_str + ) if op_config is None: - config = get_default_config(num_tokens, - num_experts, - shard_intermediate_size, - hidden_size, - topk, - dtype_str, - is_marlin=False) + config = get_default_config( + num_tokens, + num_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype_str, + is_marlin=False, + ) else: - config = op_config[min(op_config.keys(), - key=lambda x: abs(x - num_tokens))] - kernel_time = benchmark_config(config, - num_tokens, - num_experts, - shard_intermediate_size, - hidden_size, - topk, - dtype, - use_fp8_w8a8, - use_int8_w8a16, - num_iters=100, - block_quant_shape=block_quant_shape, - use_deep_gemm=use_deep_gemm) + config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))] + kernel_time = benchmark_config( + config, + num_tokens, + num_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + num_iters=100, + block_quant_shape=block_quant_shape, + use_deep_gemm=use_deep_gemm, + ) return config, kernel_time def tune( @@ -438,10 +457,14 @@ def tune( best_time = 
float("inf") if current_platform.is_rocm(): is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16) - search_space = prune_rocm_search_space(num_tokens, - shard_intermediate_size, - hidden_size, search_space, - is_fp16, topk) + search_space = prune_rocm_search_space( + num_tokens, + shard_intermediate_size, + hidden_size, + search_space, + is_fp16, + topk, + ) need_device_guard = False if current_platform.is_rocm(): @@ -449,8 +472,7 @@ def tune( if visible_device != f"{self.device_id}": need_device_guard = True - with torch.cuda.device( - self.device_id) if need_device_guard else nullcontext(): + with torch.cuda.device(self.device_id) if need_device_guard else nullcontext(): for config in tqdm(search_space): try: kernel_time = benchmark_config( @@ -465,7 +487,8 @@ def tune( use_int8_w8a16, num_iters=20, block_quant_shape=block_quant_shape, - use_deep_gemm=use_deep_gemm) + use_deep_gemm=use_deep_gemm, + ) except triton.runtime.autotuner.OutOfResources: # Some configurations may be invalid and fail to compile. 
continue @@ -481,42 +504,44 @@ def tune( def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: return { - "BLOCK_SIZE_M": - config["BLOCK_SIZE_M"], - "BLOCK_SIZE_N": - config["BLOCK_SIZE_N"], - "BLOCK_SIZE_K": - config["BLOCK_SIZE_K"], - "GROUP_SIZE_M": - config["GROUP_SIZE_M"], - "num_warps": - config["num_warps"], - "num_stages": - config["num_stages"], - **({ - "waves_per_eu": config["waves_per_eu"] - } if "waves_per_eu" in config else {}), - **({ - "matrix_instr_nonkdim": config["matrix_instr_nonkdim"] - } if "matrix_instr_nonkdim" in config else {}), - **({ - "kpack": config["kpack"] - } if "kpack" in config else {}), + "BLOCK_SIZE_M": config["BLOCK_SIZE_M"], + "BLOCK_SIZE_N": config["BLOCK_SIZE_N"], + "BLOCK_SIZE_K": config["BLOCK_SIZE_K"], + "GROUP_SIZE_M": config["GROUP_SIZE_M"], + "num_warps": config["num_warps"], + "num_stages": config["num_stages"], + **( + {"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {} + ), + **( + {"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]} + if "matrix_instr_nonkdim" in config + else {} + ), + **({"kpack": config["kpack"]} if "kpack" in config else {}), } -def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int, - shard_intermediate_size: int, hidden_size: int, topk: int, - dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, - block_quant_shape: List[int]) -> None: - dtype_str = get_config_dtype_str(dtype, - use_int8_w8a16=use_int8_w8a16, - use_fp8_w8a8=use_fp8_w8a8) +def save_configs( + configs: dict[int, BenchmarkConfig], + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + block_quant_shape: List[int], +) -> None: + dtype_str = get_config_dtype_str( + dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 + ) # NOTE(woosuk): The current naming convention uses w2.shape[2], which # is the intermediate size after silu_and_mul. 
- filename = get_config_file_name(num_experts, shard_intermediate_size // 2, - dtype_str, block_quant_shape) + filename = get_config_file_name( + num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape + ) print(f"Writing best config to {filename}...") with open(filename, "w") as f: @@ -525,18 +550,16 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int, def get_weight_block_size_safety(config, default_value=None): - - quantization_config = getattr(config, 'quantization_config', {}) + quantization_config = getattr(config, "quantization_config", {}) if isinstance(quantization_config, dict): - return quantization_config.get('weight_block_size', default_value) + return quantization_config.get("weight_block_size", default_value) return default_value def main(args: argparse.Namespace): print(args) - config = get_config(model=args.model, - trust_remote_code=args.trust_remote_code) + config = get_config(model=args.model, trust_remote_code=args.trust_remote_code) if args.model_prefix: config = getattr(config, args.model_prefix) config = SimpleNamespace(**config) @@ -551,14 +574,12 @@ def main(args: argparse.Namespace): topk = config.num_experts_per_tok intermediate_size = config.intermediate_size shard_intermediate_size = 2 * intermediate_size // args.tp_size - elif (config.architectures[0] - in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM")): + elif config.architectures[0] in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"): E = config.n_routed_experts topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size shard_intermediate_size = 2 * intermediate_size // args.tp_size - elif config.architectures[0] in ("Qwen2MoeForCausalLM", - "Qwen3MoeForCausalLM"): + elif config.architectures[0] in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"): E = config.num_experts topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size @@ -573,16 +594,35 @@ def main(args: argparse.Namespace): 
shard_intermediate_size = 2 * intermediate_size // args.tp_size hidden_size = config.hidden_size - dtype = torch.float16 if current_platform.is_rocm() else getattr( - torch, config.torch_dtype) + dtype = ( + torch.float16 + if current_platform.is_rocm() + else getattr(torch, config.torch_dtype) + ) use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_int8_w8a16 = args.dtype == "int8_w8a16" block_quant_shape = get_weight_block_size_safety(config) if args.batch_size is None: batch_sizes = [ - 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, - 2048, 3072, 4096 + 1, + 2, + 4, + 8, + 16, + 24, + 32, + 48, + 64, + 96, + 128, + 256, + 512, + 1024, + 1536, + 2048, + 3072, + 4096, ] else: batch_sizes = [args.batch_size] @@ -593,7 +633,8 @@ def main(args: argparse.Namespace): # Ray will set ROCR_VISIBLE_DEVICES for device visibility logger.warning( "Ray uses ROCR_VISIBLE_DEVICES to control device accessibility." - "Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES.") + "Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES." 
+ ) val = os.environ["HIP_VISIBLE_DEVICES"] os.environ["ROCR_VISIBLE_DEVICES"] = val del os.environ["HIP_VISIBLE_DEVICES"] @@ -620,25 +661,59 @@ def _distribute(method: str, inputs: list[Any]) -> list[Any]: start = time.time() configs = _distribute( - "tune", [(batch_size, E, shard_intermediate_size, hidden_size, - topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space, - block_quant_shape, use_deep_gemm) - for batch_size in batch_sizes]) + "tune", + [ + ( + batch_size, + E, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + search_space, + block_quant_shape, + use_deep_gemm, + ) + for batch_size in batch_sizes + ], + ) best_configs = { - M: sort_config(config) - for M, config in zip(batch_sizes, configs) + M: sort_config(config) for M, config in zip(batch_sizes, configs) } - save_configs(best_configs, E, shard_intermediate_size, hidden_size, - topk, dtype, use_fp8_w8a8, use_int8_w8a16, - block_quant_shape) + save_configs( + best_configs, + E, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + block_quant_shape, + ) end = time.time() print(f"Tuning took {end - start:.2f} seconds") else: outputs = _distribute( "benchmark", - [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype, - use_fp8_w8a8, use_int8_w8a16, block_quant_shape, use_deep_gemm) - for batch_size in batch_sizes]) + [ + ( + batch_size, + E, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + block_quant_shape, + use_deep_gemm, + ) + for batch_size in batch_sizes + ], + ) for batch_size, (config, kernel_time) in zip(batch_sizes, outputs): print(f"Batch size: {batch_size}, config: {config}") @@ -647,18 +722,15 @@ def _distribute(method: str, inputs: list[Any]) -> list[Any]: if __name__ == "__main__": parser = FlexibleArgumentParser() - parser.add_argument("--model", - type=str, - default="mistralai/Mixtral-8x7B-Instruct-v0.1") - parser.add_argument("--tp-size", - 
"-tp", - "--tensor-parallel-size", - type=int, - default=2) - parser.add_argument("--dtype", - type=str, - choices=["auto", "fp8_w8a8", "int8_w8a16"], - default="auto") + parser.add_argument( + "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1" + ) + parser.add_argument( + "--tp-size", "-tp", "--tensor-parallel-size", type=int, default=2 + ) + parser.add_argument( + "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto" + ) parser.add_argument("--use-deep-gemm", action="store_true") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--batch-size", type=int, required=False) diff --git a/benchmarks/kernels/benchmark_moe_permute_unpermute.py b/benchmarks/kernels/benchmark_moe_permute_unpermute.py index 937df9624651..333986fdf5ef 100644 --- a/benchmarks/kernels/benchmark_moe_permute_unpermute.py +++ b/benchmarks/kernels/benchmark_moe_permute_unpermute.py @@ -8,7 +8,9 @@ from transformers import AutoConfig from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - _moe_permute, _moe_unpermute_and_reduce) + _moe_permute, + _moe_unpermute_and_reduce, +) from vllm.model_executor.layers.fused_moe.fused_moe import * from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import * from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize @@ -27,15 +29,17 @@ class BenchmarkConfig(TypedDict): num_stages: int -def benchmark_permute(num_tokens: int, - num_experts: int, - hidden_size: int, - topk: int, - dtype: torch.dtype, - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - num_iters: int = 100, - use_customized_permute: bool = False) -> float: +def benchmark_permute( + num_tokens: int, + num_experts: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + num_iters: int = 100, + use_customized_permute: bool = False, +) -> float: # init_dtype = torch.float16 if use_fp8_w8a8 else dtype hidden_states = torch.randn(num_tokens, hidden_size, 
dtype=dtype) # output_hidden_states = torch.empty_like(hidden_states) @@ -46,36 +50,41 @@ def benchmark_permute(num_tokens: int, align_block_size = None qhidden_states = hidden_states - gating_output = torch.randn(num_iters, - num_tokens, - num_experts, - dtype=torch.float32) + gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32) input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32) topk_weights, topk_ids, token_expert_indices = fused_topk( - qhidden_states, input_gating, topk, False) + qhidden_states, input_gating, topk, False + ) def prepare(i: int): input_gating.copy_(gating_output[i]) def run(): if use_customized_permute: - (permuted_hidden_states, first_token_off, inv_perm_idx, - m_indices) = moe_permute( - qhidden_states, - topk_weights=topk_weights, - topk_ids=topk_ids, - token_expert_indices=token_expert_indices, - topk=topk, - n_expert=num_experts, - n_local_expert=num_experts, - expert_map=None, - align_block_size=align_block_size, - ) + (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = ( + moe_permute( + qhidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + token_expert_indices=token_expert_indices, + topk=topk, + n_expert=num_experts, + n_local_expert=num_experts, + expert_map=None, + align_block_size=align_block_size, + ) + ) else: - (permuted_hidden_states, a1q_scale, sorted_token_ids, expert_ids, - inv_perm) = _moe_permute(qhidden_states, None, topk_ids, - num_experts, None, align_block_size) + ( + permuted_hidden_states, + a1q_scale, + sorted_token_ids, + expert_ids, + inv_perm, + ) = _moe_permute( + qhidden_states, None, topk_ids, num_experts, None, align_block_size + ) # JIT compilation & warmup run() @@ -111,15 +120,17 @@ def run(): return avg -def benchmark_unpermute(num_tokens: int, - num_experts: int, - hidden_size: int, - topk: int, - dtype: torch.dtype, - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - num_iters: int = 100, - use_customized_permute: bool = 
False) -> float: +def benchmark_unpermute( + num_tokens: int, + num_experts: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + num_iters: int = 100, + use_customized_permute: bool = False, +) -> float: # init_dtype = torch.float16 if use_fp8_w8a8 else dtype hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) output_hidden_states = torch.empty_like(hidden_states) @@ -133,46 +144,74 @@ def benchmark_unpermute(num_tokens: int, input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32) topk_weights, topk_ids, token_expert_indices = fused_topk( - qhidden_states, input_gating, topk, False) + qhidden_states, input_gating, topk, False + ) def prepare(): if use_customized_permute: - (permuted_hidden_states, first_token_off, inv_perm_idx, - m_indices) = moe_permute( - qhidden_states, - topk_weights=topk_weights, - topk_ids=topk_ids, - token_expert_indices=token_expert_indices, - topk=topk, - n_expert=num_experts, - n_local_expert=num_experts, - expert_map=None, - align_block_size=align_block_size, - ) + (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = ( + moe_permute( + qhidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + token_expert_indices=token_expert_indices, + topk=topk, + n_expert=num_experts, + n_local_expert=num_experts, + expert_map=None, + align_block_size=align_block_size, + ) + ) # convert to fp16/bf16 as gemm output - return (permuted_hidden_states.to(dtype), first_token_off, - inv_perm_idx, m_indices) + return ( + permuted_hidden_states.to(dtype), + first_token_off, + inv_perm_idx, + m_indices, + ) else: - (permuted_qhidden_states, a1q_scale, sorted_token_ids, expert_ids, - inv_perm) = _moe_permute(qhidden_states, None, topk_ids, - num_experts, None, align_block_size) + ( + permuted_qhidden_states, + a1q_scale, + sorted_token_ids, + expert_ids, + inv_perm, + ) = _moe_permute( + qhidden_states, None, topk_ids, num_experts, None, 
align_block_size + ) # convert to fp16/bf16 as gemm output - return (permuted_qhidden_states.to(dtype), a1q_scale, - sorted_token_ids, expert_ids, inv_perm) + return ( + permuted_qhidden_states.to(dtype), + a1q_scale, + sorted_token_ids, + expert_ids, + inv_perm, + ) def run(input: tuple): if use_customized_permute: - (permuted_hidden_states, first_token_off, inv_perm_idx, - m_indices) = input - moe_unpermute(permuted_hidden_states, topk_weights, topk_ids, - inv_perm_idx, first_token_off, topk, num_experts, - num_experts) + (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = input + moe_unpermute( + permuted_hidden_states, + topk_weights, + topk_ids, + inv_perm_idx, + first_token_off, + topk, + num_experts, + num_experts, + ) else: - (permuted_hidden_states, a1q_scale, sorted_token_ids, expert_ids, - inv_perm) = input - _moe_unpermute_and_reduce(output_hidden_states, - permuted_hidden_states, inv_perm, - topk_weights) + ( + permuted_hidden_states, + a1q_scale, + sorted_token_ids, + expert_ids, + inv_perm, + ) = input + _moe_unpermute_and_reduce( + output_hidden_states, permuted_hidden_states, inv_perm, topk_weights + ) # JIT compilation & warmup input = prepare() @@ -209,7 +248,6 @@ def run(input: tuple): @ray.remote(num_gpus=1) class BenchmarkWorker: - def __init__(self, seed: int) -> None: torch.set_default_device("cuda") current_platform.seed_everything(seed) @@ -241,7 +279,8 @@ def benchmark( use_fp8_w8a8, use_int8_w8a16, num_iters=100, - use_customized_permute=use_customized_permute) + use_customized_permute=use_customized_permute, + ) unpermute_time = benchmark_unpermute( num_tokens, num_experts, @@ -251,15 +290,15 @@ def benchmark( use_fp8_w8a8, use_int8_w8a16, num_iters=100, - use_customized_permute=use_customized_permute) + use_customized_permute=use_customized_permute, + ) return permute_time, unpermute_time def get_weight_block_size_safety(config, default_value=None): - - quantization_config = getattr(config, 'quantization_config', {}) + 
quantization_config = getattr(config, "quantization_config", {}) if isinstance(quantization_config, dict): - return quantization_config.get('weight_block_size', default_value) + return quantization_config.get("weight_block_size", default_value) return default_value @@ -267,20 +306,21 @@ def main(args: argparse.Namespace): print(args) config = AutoConfig.from_pretrained( - args.model, trust_remote_code=args.trust_remote_code) + args.model, trust_remote_code=args.trust_remote_code + ) if config.architectures[0] == "DbrxForCausalLM": E = config.ffn_config.moe_num_experts topk = config.ffn_config.moe_top_k elif config.architectures[0] == "JambaForCausalLM": E = config.num_experts topk = config.num_experts_per_tok - elif (config.architectures[0] == "DeepseekV3ForCausalLM" - or config.architectures[0] == "DeepseekV2ForCausalLM"): + elif ( + config.architectures[0] == "DeepseekV3ForCausalLM" + or config.architectures[0] == "DeepseekV2ForCausalLM" + ): E = config.n_routed_experts topk = config.num_experts_per_tok - elif config.architectures[0] in [ - "Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM" - ]: + elif config.architectures[0] in ["Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"]: E = config.num_experts topk = config.num_experts_per_tok @@ -299,8 +339,24 @@ def main(args: argparse.Namespace): if args.batch_size is None: batch_sizes = [ - 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, - 2048, 3072, 4096 + 1, + 2, + 4, + 8, + 16, + 24, + 32, + 48, + 64, + 96, + 128, + 256, + 512, + 1024, + 1536, + 2048, + 3072, + 4096, ] else: batch_sizes = [args.batch_size] @@ -321,9 +377,21 @@ def _distribute(method: str, inputs: list[Any]) -> list[Any]: return ray.get(outputs) outputs = _distribute( - "benchmark", [(batch_size, E, hidden_size, topk, dtype, use_fp8_w8a8, - use_int8_w8a16, use_customized_permute) - for batch_size in batch_sizes]) + "benchmark", + [ + ( + batch_size, + E, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + 
use_customized_permute, + ) + for batch_size in batch_sizes + ], + ) for batch_size, (permute, unpermute) in zip(batch_sizes, outputs): print(f"Batch size: {batch_size}") @@ -333,13 +401,12 @@ def _distribute(method: str, inputs: list[Any]) -> list[Any]: if __name__ == "__main__": parser = FlexibleArgumentParser() - parser.add_argument("--model", - type=str, - default="mistralai/Mixtral-8x7B-Instruct-v0.1") - parser.add_argument("--dtype", - type=str, - choices=["auto", "fp8_w8a8", "int8_w8a16"], - default="auto") + parser.add_argument( + "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1" + ) + parser.add_argument( + "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto" + ) parser.add_argument("--use-customized-permute", action="store_true") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--batch-size", type=int, required=False) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 2625239b08ef..54f05e723226 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -9,8 +9,11 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, - create_kv_caches_with_random) +from vllm.utils import ( + STR_DTYPE_TO_TORCH_DTYPE, + FlexibleArgumentParser, + create_kv_caches_with_random, +) logger = init_logger(__name__) @@ -38,19 +41,15 @@ def main( current_platform.seed_everything(seed) scale = float(1.0 / (head_size**0.5)) - query = torch.empty(num_seqs, - num_query_heads, - head_size, - dtype=dtype, - device=device) + query = torch.empty( + num_seqs, num_query_heads, head_size, dtype=dtype, device=device + ) query.uniform_(-scale, scale) assert num_query_heads % num_kv_heads == 0 alibi_slopes = None if use_alibi: - alibi_slopes = 
torch.randn(num_query_heads, - dtype=torch.float, - device=device) + alibi_slopes = torch.randn(num_query_heads, dtype=torch.float, device=device) seq_lens = [seq_len for _ in range(num_seqs)] max_seq_len = max(seq_lens) @@ -61,24 +60,23 @@ def main( block_tables_lst: list[list[int]] = [] for _ in range(num_seqs): block_table = [ - random.randint(0, NUM_BLOCKS - 1) - for _ in range(max_num_blocks_per_seq) + random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq) ] block_tables_lst.append(block_table) - block_tables = torch.tensor(block_tables_lst, - dtype=torch.int, - device=device) + block_tables = torch.tensor(block_tables_lst, dtype=torch.int, device=device) # Create the KV cache. - key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS, - block_size, - 1, - num_kv_heads, - head_size, - kv_cache_dtype, - dtype, - device=device) + key_caches, value_caches = create_kv_caches_with_random( + NUM_BLOCKS, + block_size, + 1, + num_kv_heads, + head_size, + kv_cache_dtype, + dtype, + device=device, + ) key_cache, value_cache = key_caches[0], value_caches[0] # Prepare for the paged attention kernel. 
@@ -86,11 +84,11 @@ def main( if version == "v2": if current_platform.is_rocm(): global PARTITION_SIZE - if not args.custom_paged_attn: + if not args.custom_paged_attn and not current_platform.is_navi(): PARTITION_SIZE = 1024 else: PARTITION_SIZE = PARTITION_SIZE_ROCM - num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) + num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE tmp_output = torch.empty( size=(num_seqs, num_query_heads, num_partitions, head_size), dtype=output.dtype, @@ -110,9 +108,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: start_time = time.perf_counter() # Using default kv_scale - k_scale = v_scale = torch.tensor(1.0, - dtype=torch.float32, - device=device) + k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) for _ in range(num_iters): if version == "v1": @@ -166,6 +162,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: scale, block_tables, seq_lens, + None, block_size, max_seq_len, alibi_slopes, @@ -195,30 +192,29 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: print(f"Kernel running time: {latency * 1000000:.3f} us") -if __name__ == '__main__': - logger.warning("This script benchmarks the paged attention kernel. " - "By default this is no longer used in vLLM inference.") +if __name__ == "__main__": + logger.warning( + "This script benchmarks the paged attention kernel. " + "By default this is no longer used in vLLM inference." 
+ ) - parser = FlexibleArgumentParser( - description="Benchmark the paged attention kernel.") - parser.add_argument("--version", - type=str, - choices=["v1", "v2"], - default="v2") + parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.") + parser.add_argument("--version", type=str, choices=["v1", "v2"], default="v2") parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--seq-len", type=int, default=4096) parser.add_argument("--num-query-heads", type=int, default=64) parser.add_argument("--num-kv-heads", type=int, default=8) - parser.add_argument("--head-size", - type=int, - choices=[64, 80, 96, 112, 120, 128, 192, 256], - default=128) + parser.add_argument( + "--head-size", + type=int, + choices=[64, 80, 96, 112, 120, 128, 192, 256], + default=128, + ) parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) parser.add_argument("--use-alibi", action="store_true") - parser.add_argument("--dtype", - type=str, - choices=["half", "bfloat16", "float"], - default="half") + parser.add_argument( + "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half" + ) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--profile", action="store_true") parser.add_argument( @@ -228,10 +224,11 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: default="auto", help="Data type for kv cache storage. If 'auto', will use model " "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. 
" - "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)") - parser.add_argument("--custom-paged-attn", - action="store_true", - help="Use custom paged attention") + "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)", + ) + parser.add_argument( + "--custom-paged-attn", action="store_true", help="Use custom paged attention" + ) args = parser.parse_args() print(args) diff --git a/benchmarks/kernels/benchmark_quant.py b/benchmarks/kernels/benchmark_quant.py index b643897a60ee..2463dfebe83c 100644 --- a/benchmarks/kernels/benchmark_quant.py +++ b/benchmarks/kernels/benchmark_quant.py @@ -10,15 +10,17 @@ @torch.inference_mode() -def main(num_tokens: int, - hidden_size: int, - static_scale: bool, - quant_dtype: torch.dtype, - dtype: torch.dtype, - seed: int = 0, - do_profile: bool = False, - num_warmup_iters: int = 5, - num_iters: int = 100) -> None: +def main( + num_tokens: int, + hidden_size: int, + static_scale: bool, + quant_dtype: torch.dtype, + dtype: torch.dtype, + seed: int = 0, + do_profile: bool = False, + num_warmup_iters: int = 5, + num_iters: int = 100, +) -> None: current_platform.seed_everything(seed) torch.set_default_device("cuda") @@ -56,7 +58,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: print(f"Kernel running time: {latency * 1000000:.3f} us") -if __name__ == '__main__': +if __name__ == "__main__": def to_torch_dtype(dt): if dt == "int8": @@ -66,37 +68,40 @@ def to_torch_dtype(dt): raise ValueError(f"Unsupported dtype: {dt}") parser = FlexibleArgumentParser( - description="Benchmark the quantization (fp8 or int8) kernel.") + description="Benchmark the quantization (fp8 or int8) kernel." 
+ ) parser.add_argument("--num-tokens", type=int, default=4096) parser.add_argument("--hidden-size", type=int, default=8192) parser.add_argument("--static-scale", action="store_true") - parser.add_argument("--quant-dtype", - type=str, - choices=["fp8", "int8"], - default="int8") - parser.add_argument("--dtype", - type=str, - choices=["half", "bfloat16", "float"], - default="half") + parser.add_argument( + "--quant-dtype", type=str, choices=["fp8", "int8"], default="int8" + ) + parser.add_argument( + "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half" + ) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--profile", action="store_true") parser.add_argument("--num-warmup-iters", type=int, default=5) - parser.add_argument("--num-iters", - type=int, - default=100, - help="Number of benchmark iterations. " - "If --profile is set, this number is ignored") + parser.add_argument( + "--num-iters", + type=int, + default=100, + help="Number of benchmark iterations. 
" + "If --profile is set, this number is ignored", + ) args = parser.parse_args() print(args) - main(num_tokens=args.num_tokens, - hidden_size=args.hidden_size, - static_scale=args.static_scale, - quant_dtype=to_torch_dtype(args.quant_dtype), - dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], - seed=args.seed, - do_profile=args.profile, - num_warmup_iters=args.num_warmup_iters, - num_iters=args.num_iters) + main( + num_tokens=args.num_tokens, + hidden_size=args.hidden_size, + static_scale=args.static_scale, + quant_dtype=to_torch_dtype(args.quant_dtype), + dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], + seed=args.seed, + do_profile=args.profile, + num_warmup_iters=args.num_warmup_iters, + num_iters=args.num_iters, + ) diff --git a/benchmarks/kernels/benchmark_rmsnorm.py b/benchmarks/kernels/benchmark_rmsnorm.py index 09a319ccf1d1..d720083b6150 100644 --- a/benchmarks/kernels/benchmark_rmsnorm.py +++ b/benchmarks/kernels/benchmark_rmsnorm.py @@ -12,7 +12,6 @@ class HuggingFaceRMSNorm(nn.Module): - def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) @@ -114,23 +113,19 @@ def rmsnorm_vllm( def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True): dtype = torch.bfloat16 - x = torch.randn(batch_size, - seq_len, - hidden_size, - dtype=dtype, - device="cuda") + x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda") weight = torch.ones(hidden_size, dtype=dtype, device="cuda") residual = torch.randn_like(x) if use_residual else None output_naive = rmsnorm_naive( - x.clone(), weight, - residual.clone() if residual is not None else None) + x.clone(), weight, residual.clone() if residual is not None else None + ) output_flashinfer = rmsnorm_flashinfer( - x.clone(), weight, - residual.clone() if residual is not None else None) + x.clone(), weight, residual.clone() if residual is not None else None + ) output_vllm = rmsnorm_vllm( - x.clone(), weight, - 
residual.clone() if residual is not None else None) + x.clone(), weight, residual.clone() if residual is not None else None + ) if use_residual: output_naive = output_naive[0] @@ -141,9 +136,9 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True): print(f"FlashInfer output={output_flashinfer}") print(f"vLLM output={output_vllm}") - if torch.allclose(output_naive, output_flashinfer, atol=1e-2, - rtol=1e-2) and torch.allclose( - output_naive, output_vllm, atol=1e-2, rtol=1e-2): + if torch.allclose( + output_naive, output_flashinfer, atol=1e-2, rtol=1e-2 + ) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2): print("โœ… All implementations match") else: print("โŒ Implementations differ") @@ -152,12 +147,10 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True): batch_size_range = [2**i for i in range(0, 7, 2)] seq_length_range = [2**i for i in range(6, 11, 1)] head_num_range = [32, 48] -configs = list( - itertools.product(head_num_range, batch_size_range, seq_length_range)) +configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range)) def get_benchmark(use_residual): - @triton.testing.perf_report( triton.testing.Benchmark( x_names=["head_num", "batch_size", "seq_len"], @@ -167,19 +160,15 @@ def get_benchmark(use_residual): line_names=["HuggingFace", "FlashInfer", "vLLM"], styles=[("blue", "-"), ("green", "-"), ("red", "-")], ylabel="us", - plot_name= - f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual", + plot_name=f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual", args={}, - )) + ) + ) def benchmark(head_num, batch_size, seq_len, provider): dtype = torch.bfloat16 hidden_size = head_num * 128 # assuming head_dim = 128 - x = torch.randn(batch_size, - seq_len, - hidden_size, - dtype=dtype, - device="cuda") + x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda") weight = torch.ones(hidden_size, dtype=dtype, device="cuda") residual = 
torch.randn_like(x) if use_residual else None @@ -240,9 +229,9 @@ def benchmark(head_num, batch_size, seq_len, provider): default=4096, help="Hidden size (2nd dimension) of the sequence", ) - parser.add_argument("--use-residual", - action="store_true", - help="Whether to use residual connection") + parser.add_argument( + "--use-residual", action="store_true", help="Whether to use residual connection" + ) parser.add_argument( "--save-path", type=str, @@ -253,10 +242,12 @@ def benchmark(head_num, batch_size, seq_len, provider): args = parser.parse_args() # Run correctness test - calculate_diff(batch_size=args.batch_size, - seq_len=args.seq_len, - hidden_size=args.hidden_size, - use_residual=args.use_residual) + calculate_diff( + batch_size=args.batch_size, + seq_len=args.seq_len, + hidden_size=args.hidden_size, + use_residual=args.use_residual, + ) # Get the benchmark function with proper use_residual setting benchmark = get_benchmark(args.use_residual) diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index 05d24fc4b16d..110d36db157f 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -6,8 +6,7 @@ import nvtx import torch -from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding, - get_rope) +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding, get_rope from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser @@ -32,40 +31,49 @@ def benchmark_rope_kernels_multi_lora( # silulating serving 4 LoRAs scaling_factors = [1, 2, 4, 8] # batched RoPE can take multiple scaling factors - batched_rope = get_rope(head_size, rotary_dim, max_position, base, - is_neox_style, { - "rope_type": "linear", - "factor": tuple(scaling_factors) - }) + batched_rope = get_rope( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + {"rope_type": "linear", "factor": tuple(scaling_factors)}, + ) # non-batched RoPE takes only one 
scaling factor, we create multiple # instances to simulate the same behavior non_batched_ropes: list[RotaryEmbedding] = [] for scaling_factor in scaling_factors: non_batched_ropes.append( - get_rope(head_size, rotary_dim, max_position, base, is_neox_style, - { - "rope_type": "linear", - "factor": (scaling_factor, ) - })) + get_rope( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + {"rope_type": "linear", "factor": (scaling_factor,)}, + ) + ) positions = torch.randint(0, max_position, (batch_size, seq_len)) - query = torch.randn(batch_size, - seq_len, - num_heads * head_size, - dtype=dtype) + query = torch.randn(batch_size, seq_len, num_heads * head_size, dtype=dtype) key = torch.randn_like(query) # create query offsets for batched RoPE, we concat multiple kv cache # together and each query needs to find the right kv cache of its type offset_map = torch.tensor( list( - accumulate([0] + [ - max_position * scaling_factor * 2 - for scaling_factor in scaling_factors[:-1] - ]))) - query_types = torch.randint(0, - len(scaling_factors), (batch_size, seq_len), - device=device) + accumulate( + [0] + + [ + max_position * scaling_factor * 2 + for scaling_factor in scaling_factors[:-1] + ] + ) + ) + ) + query_types = torch.randint( + 0, len(scaling_factors), (batch_size, seq_len), device=device + ) # map query types to offsets query_offsets = offset_map[query_types] # the kernel takes flattened offsets @@ -86,27 +94,28 @@ def benchmark_rope_kernels_multi_lora( torch.cuda.synchronize() -if __name__ == '__main__': +if __name__ == "__main__": parser = FlexibleArgumentParser( - description="Benchmark the rotary embedding kernels.") + description="Benchmark the rotary embedding kernels." 
+ ) parser.add_argument("--is-neox-style", type=bool, default=True) parser.add_argument("--batch-size", type=int, default=16) parser.add_argument("--seq-len", type=int, default=512) parser.add_argument("--num-heads", type=int, default=8) - parser.add_argument("--head-size", - type=int, - choices=[64, 80, 96, 112, 120, 128, 192, 256], - default=128) + parser.add_argument( + "--head-size", + type=int, + choices=[64, 80, 96, 112, 120, 128, 192, 256], + default=128, + ) parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32) - parser.add_argument("--dtype", - type=str, - choices=["bfloat16", "float"], - default="float") + parser.add_argument( + "--dtype", type=str, choices=["bfloat16", "float"], default="float" + ) parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--device", - type=str, - choices=["cuda:0", "cuda:1"], - default="cuda:0") + parser.add_argument( + "--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0" + ) args = parser.parse_args() print(args) diff --git a/benchmarks/kernels/benchmark_w8a8_block_fp8.py b/benchmarks/kernels/benchmark_w8a8_block_fp8.py index 8f07bc8ca52e..6315c1ee6cdd 100644 --- a/benchmarks/kernels/benchmark_w8a8_block_fp8.py +++ b/benchmarks/kernels/benchmark_w8a8_block_fp8.py @@ -14,14 +14,16 @@ import triton from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - _w8a8_block_fp8_matmul) + _w8a8_block_fp8_matmul, +) from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser mp.set_start_method("spawn", force=True) -assert current_platform.is_cuda( -), "Only support tune w8a8 block fp8 kernel on CUDA device." +assert current_platform.is_cuda(), ( + "Only support tune w8a8 block fp8 kernel on CUDA device." 
+) DTYPE_MAP = { "float32": torch.float32, @@ -40,7 +42,7 @@ def w8a8_block_matmul( config: dict[str, Any], output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: - """This function performs matrix multiplication with + """This function performs matrix multiplication with block-wise quantization. It takes two input tensors `A` and `B` with scales `As` and `Bs`. @@ -51,7 +53,7 @@ def w8a8_block_matmul( B: The input tensor, e.g., weight. As: The per-token-group quantization scale for `A`. Bs: The per-block quantization scale for `B`. - block_size: The block size for per-block quantization. + block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128]. output_dytpe: The dtype of the returned tensor. @@ -71,18 +73,18 @@ def w8a8_block_matmul( assert triton.cdiv(N, block_n) == Bs.shape[0] assert triton.cdiv(K, block_k) == Bs.shape[1] - C_shape = A.shape[:-1] + (N, ) + C_shape = A.shape[:-1] + (N,) C = A.new_empty(C_shape, dtype=output_dtype) def grid(META): - return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * - triton.cdiv(N, META["BLOCK_SIZE_N"]), ) + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) if A.dtype == torch.float8_e4m3fn: kernel = _w8a8_block_fp8_matmul else: - raise RuntimeError( - "Currently, only support tune w8a8 block fp8 kernel.") + raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.") kernel[grid]( A, @@ -119,14 +121,16 @@ def get_configs_compute_bound(): for block_n in [32, 64, 128, 256]: for num_warps in [4, 8]: for group_size in [1, 16, 32, 64]: - configs.append({ - "BLOCK_SIZE_M": block_m, - "BLOCK_SIZE_N": block_n, - "BLOCK_SIZE_K": block_k, - "GROUP_SIZE_M": group_size, - "num_warps": num_warps, - "num_stages": num_stages, - }) + configs.append( + { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_size, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) return configs @@ -165,15 
+169,9 @@ def get_weight_shapes(tp_size): return weight_shapes -def benchmark_config(A, - B, - As, - Bs, - block_size, - config, - out_dtype=torch.float16, - num_iters=10): - +def benchmark_config( + A, B, As, Bs, block_size, config, out_dtype=torch.float16, num_iters=10 +): def run(): w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype) @@ -206,26 +204,26 @@ def tune(M, N, K, block_size, out_dtype, search_space, input_type): fp8_max, fp8_min = fp8_info.max, fp8_info.min A_fp32 = ( - (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * - fp8_max) + (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max + ) A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) B_fp32 = ( - (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * - fp8_max) + (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max + ) B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) else: - raise RuntimeError( - "Currently, only support tune w8a8 block fp8 kernel.") + raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.") block_n, block_k = block_size[0], block_size[1] n_tiles = (N + block_n - 1) // block_n k_tiles = (K + block_k - 1) // block_k - As = torch.rand(M, k_tiles, dtype=torch.float32, - device="cuda") * factor_for_scale - Bs = (torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda") * - factor_for_scale) + As = torch.rand(M, k_tiles, dtype=torch.float32, device="cuda") * factor_for_scale + Bs = ( + torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda") + * factor_for_scale + ) best_config = None best_time = float("inf") @@ -267,7 +265,8 @@ def save_configs( device_name = current_platform.get_device_name().replace(" ", "_") json_file_name = ( f"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8," - f"block_shape=[{block_n},{block_k}].json") + f"block_shape=[{block_n},{block_k}].json" + ) config_file_path = os.path.join(save_path, 
json_file_name) print(f"Writing best config to {config_file_path}...") @@ -295,8 +294,7 @@ def tune_on_gpu(args_dict): search_space = get_configs_compute_bound() search_space = [ - config for config in search_space - if block_k % config["BLOCK_SIZE_K"] == 0 + config for config in search_space if block_k % config["BLOCK_SIZE_K"] == 0 ] start = time.time() @@ -312,15 +310,11 @@ def tune_on_gpu(args_dict): out_dtype, search_space, input_type, - ) for batch_size in tqdm(batch_sizes, - desc=f"GPU {gpu_id} - Batch sizes") + ) + for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes") ] - best_configs = { - M: config - for M, config in zip(batch_sizes, benchmark_results) - } - save_configs(N, K, block_n, block_k, best_configs, save_path, - input_type) + best_configs = {M: config for M, config in zip(batch_sizes, benchmark_results)} + save_configs(N, K, block_n, block_k, best_configs, save_path, input_type) end = time.time() print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds") @@ -376,13 +370,14 @@ def main(args): process_args = [] for gpu_id in range(num_gpus): - process_args.append({ - "gpu_id": gpu_id, - "batch_sizes": batches_per_gpu[gpu_id], - "weight_shapes": - weight_shapes, # Each GPU processes all weight shapes - "args": args, - }) + process_args.append( + { + "gpu_id": gpu_id, + "batch_sizes": batches_per_gpu[gpu_id], + "weight_shapes": weight_shapes, # Each GPU processes all weight shapes + "args": args, + } + ) ctx = mp.get_context("spawn") with ctx.Pool(num_gpus) as pool: @@ -398,13 +393,11 @@ def main(args): python3 benchmark_w8a8_block_fp8.py --tp-size 8 --input-type fp8 Then copy to model_executor/layers/quantization/utils/configs """, - formatter_class=argparse.RawTextHelpFormatter) + formatter_class=argparse.RawTextHelpFormatter, + ) parser.add_argument("--tp-size", "-tp", type=int, default=8) - parser.add_argument("--input-type", - type=str, - choices=["fp8"], - default="fp8") + parser.add_argument("--input-type", type=str, 
choices=["fp8"], default="fp8") parser.add_argument( "--out-dtype", type=str, diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py index 5fa55bb974e1..e37764825451 100644 --- a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py +++ b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py @@ -11,7 +11,9 @@ # Import vLLM functions from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8, w8a8_block_fp8_matmul) + per_token_group_quant_fp8, + w8a8_block_fp8_matmul, +) from vllm.triton_utils import triton diff --git a/benchmarks/kernels/graph_machete_bench.py b/benchmarks/kernels/graph_machete_bench.py index bd62173a7b3a..0c86e4072957 100644 --- a/benchmarks/kernels/graph_machete_bench.py +++ b/benchmarks/kernels/graph_machete_bench.py @@ -2,11 +2,11 @@ import math import pickle -import re from collections import defaultdict import matplotlib.pyplot as plt import pandas as pd +import regex as re import seaborn as sns from torch.utils.benchmark import Measurement as TMeasurement @@ -14,13 +14,14 @@ if __name__ == "__main__": parser = FlexibleArgumentParser( - description='Benchmark the latency of processing a single batch of ' - 'requests till completion.') - parser.add_argument('filename', type=str) + description="Benchmark the latency of processing a single batch of " + "requests till completion." 
+ ) + parser.add_argument("filename", type=str) args = parser.parse_args() - with open(args.filename, 'rb') as f: + with open(args.filename, "rb") as f: data = pickle.load(f) raw_results: list[TMeasurement] = data["results"] @@ -38,11 +39,7 @@ raise Exception("MKN not found") kernel = v.task_spec.description - results[KN].append({ - "kernel": kernel, - "batch_size": M, - "median": v.median - }) + results[KN].append({"kernel": kernel, "batch_size": M, "median": v.median}) rows = int(math.ceil(len(results) / 2)) fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows)) @@ -50,14 +47,16 @@ for axs_idx, (shape, data) in enumerate(results.items()): plt.sca(axs[axs_idx]) df = pd.DataFrame(data) - sns.lineplot(data=df, - x="batch_size", - y="median", - hue="kernel", - style="kernel", - markers=True, - dashes=False, - palette="Dark2") + sns.lineplot( + data=df, + x="batch_size", + y="median", + hue="kernel", + style="kernel", + markers=True, + dashes=False, + palette="Dark2", + ) plt.title(f"Shape: {shape}") plt.ylabel("time (median, s)") plt.tight_layout() diff --git a/benchmarks/kernels/utils.py b/benchmarks/kernels/utils.py index ac64f786f184..877a29feed9d 100644 --- a/benchmarks/kernels/utils.py +++ b/benchmarks/kernels/utils.py @@ -23,6 +23,7 @@ class ArgPool: For every invocation during a benchmarking run, it will choose a different value from the list. 
""" + values: Iterable[Any] def __getitem__(self, index): @@ -30,9 +31,7 @@ def __getitem__(self, index): class Bench: - class ArgsIterator: - def __init__(self, args_list, kwargs_list): assert len(args_list) == len(kwargs_list) self.args_list = args_list @@ -53,10 +52,16 @@ def reset(self): def n_args(self): return self.n - def __init__(self, cuda_graph_params: Optional[CudaGraphBenchParams], - label: str, sub_label: str, description: str, fn: Callable, - *args, **kwargs): - + def __init__( + self, + cuda_graph_params: Optional[CudaGraphBenchParams], + label: str, + sub_label: str, + description: str, + fn: Callable, + *args, + **kwargs, + ): self.cuda_graph_params = cuda_graph_params self.use_cuda_graph = self.cuda_graph_params is not None self.label = label @@ -67,10 +72,8 @@ def __init__(self, cuda_graph_params: Optional[CudaGraphBenchParams], # Process args self._args = args self._kwargs = kwargs - self.args_list, self.kwargs_list = self.collapse_argpool( - *args, **kwargs) - self.args_iterator = self.ArgsIterator(self.args_list, - self.kwargs_list) + self.args_list, self.kwargs_list = self.collapse_argpool(*args, **kwargs) + self.args_iterator = self.ArgsIterator(self.args_list, self.kwargs_list) # Cudagraph runner self.g = None @@ -100,16 +103,13 @@ def collapse_argpool(self, *args, **kwargs): for i in range(argpool_size): # collapse args; Just pick the ith value - args_list[i] = tuple([ - arg[i] if isinstance(arg, ArgPool) else arg - for arg in args_list[i] - ]) + args_list[i] = tuple( + [arg[i] if isinstance(arg, ArgPool) else arg for arg in args_list[i]] + ) # collapse kwargs kwargs_i = kwargs_list[i] - arg_pool_keys = [ - k for k, v in kwargs_i.items() if isinstance(v, ArgPool) - ] + arg_pool_keys = [k for k, v in kwargs_i.items() if isinstance(v, ArgPool)] for k in arg_pool_keys: # again just pick the ith value kwargs_i[k] = kwargs_i[k][i] @@ -142,7 +142,7 @@ def get_cuda_graph_runner(self): def run_cudagrah(self) -> TMeasurement: assert 
self.use_cuda_graph - globals = {'g': self.g} + globals = {"g": self.g} return TBenchmark.Timer( stmt="g.replay()", @@ -162,15 +162,15 @@ def run_eager(self) -> TMeasurement: has_arg_pool = self.args_iterator.n_args > 1 if has_arg_pool: - setup = ''' + setup = """ args_iterator.reset() args_it = args_iterator.__next__() - ''' - stmt = ''' + """ + stmt = """ args, kwargs = next(args_it) fn(*args, **kwargs) - ''' - globals = {'fn': self.fn, 'args_iterator': self.args_iterator} + """ + globals = {"fn": self.fn, "args_iterator": self.args_iterator} else: # no arg pool. Just use the args and kwargs directly self.args_iterator.reset() @@ -178,10 +178,10 @@ def run_eager(self) -> TMeasurement: args, kwargs = next(args_it) setup = "" - stmt = ''' + stmt = """ fn(*args, **kwargs) - ''' - globals = {'fn': self.fn, 'args': args, 'kwargs': kwargs} + """ + globals = {"fn": self.fn, "args": args, "kwargs": kwargs} return TBenchmark.Timer( stmt=stmt, diff --git a/benchmarks/overheads/benchmark_hashing.py b/benchmarks/overheads/benchmark_hashing.py index 5f94552e9dc8..d5701a8fbd6d 100644 --- a/benchmarks/overheads/benchmark_hashing.py +++ b/benchmarks/overheads/benchmark_hashing.py @@ -7,9 +7,8 @@ from vllm.utils import FlexibleArgumentParser # A very long prompt, total number of tokens is about 15k. -LONG_PROMPT = ["You are an expert in large language models, aren't you?" 
- ] * 1000 -LONG_PROMPT = ' '.join(LONG_PROMPT) +LONG_PROMPT = ["You are an expert in large language models, aren't you?"] * 1000 +LONG_PROMPT = " ".join(LONG_PROMPT) def main(args): @@ -30,32 +29,35 @@ def main(args): print("------start generating------") for i in range(3): - profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)', - globals(), locals()) + profiler.runctx( + "llm.generate(LONG_PROMPT, sampling_params)", globals(), locals() + ) # analyze the runtime of hashing function stats = pstats.Stats(profiler) - stats.sort_stats('cumulative') + stats.sort_stats("cumulative") total_time = 0 total_calls = 0 for func in stats.stats: - if 'hash_of_block' in func[2]: + if "hash_of_block" in func[2]: total_time = stats.stats[func][3] total_calls = stats.stats[func][0] percentage = (total_time / stats.total_tt) * 100 - print(f"Hashing took {total_time:.2f} seconds," - f"{percentage:.2f}% of the total runtime.") + print( + f"Hashing took {total_time:.2f} seconds,{percentage:.2f}% of the total runtime." + ) if __name__ == "__main__": parser = FlexibleArgumentParser( - description='Benchmark the performance of hashing function in' - 'automatic prefix caching.') - parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k') - parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) - parser.add_argument('--output-len', type=int, default=10) - parser.add_argument('--enable-prefix-caching', - action='store_true', - help='enable prefix caching') + description="Benchmark the performance of hashing function in" + "automatic prefix caching." 
+ ) + parser.add_argument("--model", type=str, default="lmsys/longchat-7b-16k") + parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) + parser.add_argument("--output-len", type=int, default=10) + parser.add_argument( + "--enable-prefix-caching", action="store_true", help="enable prefix caching" + ) args = parser.parse_args() main(args) diff --git a/benchmarks/pyproject.toml b/benchmarks/pyproject.toml new file mode 100644 index 000000000000..65b1e09a247e --- /dev/null +++ b/benchmarks/pyproject.toml @@ -0,0 +1,49 @@ +# This local pyproject file is part of the migration from yapf to ruff format. +# It uses the same core rules as the main pyproject.toml file, but with the +# following differences: +# - ruff line length is overridden to 88 +# - deprecated typing ignores (UP006, UP035) have been removed + +[tool.ruff] +line-length = 88 + +[tool.ruff.lint.per-file-ignores] +"vllm/third_party/**" = ["ALL"] +"vllm/version.py" = ["F401"] +"vllm/_version.py" = ["ALL"] + +[tool.ruff.lint] +select = [ + # pycodestyle + "E", + # Pyflakes + "F", + # pyupgrade + "UP", + # flake8-bugbear + "B", + # flake8-simplify + "SIM", + # isort + "I", + # flake8-logging-format + "G", +] +ignore = [ + # star imports + "F405", "F403", + # lambda expression assignment + "E731", + # Loop control variable not used within loop body + "B007", + # f-string format + "UP032", + # Can remove once 3.10+ is the minimum Python version + "UP007", +] + +[tool.ruff.lint.isort] +known-first-party = ["vllm"] + +[tool.ruff.format] +docstring-code-format = true \ No newline at end of file diff --git a/benchmarks/run_structured_output_benchmark.sh b/benchmarks/run_structured_output_benchmark.sh index 53dc7ed70b9c..b043ab83e460 100755 --- a/benchmarks/run_structured_output_benchmark.sh +++ b/benchmarks/run_structured_output_benchmark.sh @@ -1,32 +1,98 @@ #!/bin/bash -# Define the model to use -MODEL=${1:-"Qwen/Qwen2.5-7B-Instruct"} - -# Define the backend to use -BACKEND=${2:-"vllm"} - -# 
Define the dataset to use -DATASET=${3:-"xgrammar_bench"} - +# default values +MODEL=${MODEL:-"Qwen/Qwen2.5-7B-Instruct"} +BACKEND=${BACKEND:-"vllm"} +DATASET=${DATASET:-"xgrammar_bench"} SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -OUTPUT_DIR=${4:-"$SCRIPT_DIR/structured_output_benchmark_results"} +OUTPUT_DIR=${OUTPUT_DIR:-"$SCRIPT_DIR/structured_output_benchmark_results"} +PORT=${PORT:-8000} +STRUCTURED_OUTPUT_RATIO=${STRUCTURED_OUTPUT_RATIO:-1} +TOTAL_SECONDS=${TOTAL_SECONDS:-90} +MAX_NEW_TOKENS=${MAX_NEW_TOKENS:-300} +TOKENIZER_MODE=${TOKENIZER_MODE:-"auto"} -GUIDED_RATIO=${5:-0.5} +usage() { + echo "Usage: $0 [options]" + echo "Options:" + echo " --model MODEL Model to benchmark (default: $MODEL)" + echo " --backend BACKEND Backend to use (default: $BACKEND)" + echo " --dataset DATASET Dataset to use (default: $DATASET)" + echo " --max-new-tokens N Maximum number of tokens to generate (default: $MAX_NEW_TOKENS)" + echo " --output-dir DIR Output directory for results (default: $OUTPUT_DIR)" + echo " --port PORT Port to use (default: $PORT)" + echo " --structured-output-ratio N Ratio of structured outputs (default: $STRUCTURED_OUTPUT_RATIO)" + echo " --tokenizer-mode MODE Tokenizer mode to use (default: $TOKENIZER_MODE)" + echo " --total-seconds N Total seconds to run the benchmark (default: $TOTAL_SECONDS)" + echo " -h, --help Show this help message and exit" + exit 0 +} + +# parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --model) + MODEL="$2" + shift 2 + ;; + --backend) + BACKEND="$2" + shift 2 + ;; + --dataset) + DATASET="$2" + shift 2 + ;; + --max-new-tokens) + MAX_NEW_TOKENS="$2" + shift 2 + ;; + --output-dir) + OUTPUT_DIR="$2" + shift 2 + ;; + --port) + PORT="$2" + shift 2 + ;; + --structured-output-ratio) + STRUCTURED_OUTPUT_RATIO="$2" + shift 2 + ;; + --tokenizer-mode) + TOKENIZER_MODE="$2" + shift 2 + ;; + --total-seconds) + TOTAL_SECONDS="$2" + shift 2 + ;; + -h|--help) + usage + ;; + *) + echo "Unknown argument: 
$1\n" + usage + ;; + esac +done # Create output directory if it doesn't exist mkdir -p "$OUTPUT_DIR" # Define QPS values to test -QPS_VALUES=(70 60 50 25 20 15 10) +QPS_VALUES=(25 20 15 10 5 1) # Common parameters COMMON_PARAMS="--backend $BACKEND \ --model $MODEL \ --dataset $DATASET \ - --structured-output-ratio $GUIDED_RATIO \ + --structured-output-ratio $STRUCTURED_OUTPUT_RATIO \ --save-results \ - --result-dir $OUTPUT_DIR" + --result-dir $OUTPUT_DIR \ + --output-len $MAX_NEW_TOKENS \ + --port $PORT \ + --tokenizer-mode $TOKENIZER_MODE" echo "Starting structured output benchmark with model: $MODEL" echo "Backend: $BACKEND" @@ -45,12 +111,15 @@ for qps in "${QPS_VALUES[@]}"; do # Construct filename for this run FILENAME="${BACKEND}_${qps}qps_$(basename $MODEL)_${DATASET}_${GIT_HASH}.json" + NUM_PROMPTS=$(echo "$TOTAL_SECONDS * $qps" | bc) + NUM_PROMPTS=${NUM_PROMPTS%.*} # Remove fractional part + echo "Running benchmark with $NUM_PROMPTS prompts" + # Run the benchmark python "$SCRIPT_DIR/benchmark_serving_structured_output.py" $COMMON_PARAMS \ --request-rate $qps \ --result-filename "$FILENAME" \ - --tokenizer-mode ${TOKENIZER_MODE:-"auto"} \ - --port ${PORT:-8000} + --num-prompts $NUM_PROMPTS echo "Completed benchmark with QPS: $qps" echo "----------------------------------------" diff --git a/cmake/utils.cmake b/cmake/utils.cmake index c9cd099b82a7..12e4e39024f5 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -228,11 +228,26 @@ macro(set_gencode_flags_for_srcs) "${multiValueArgs}" ${ARGN} ) foreach(_ARCH ${arg_CUDA_ARCHS}) - string(REPLACE "." "" _ARCH "${_ARCH}") - set_gencode_flag_for_srcs( - SRCS ${arg_SRCS} - ARCH "compute_${_ARCH}" - CODE "sm_${_ARCH}") + # handle +PTX suffix: generate both sm and ptx codes if requested + string(FIND "${_ARCH}" "+PTX" _HAS_PTX) + if(NOT _HAS_PTX EQUAL -1) + string(REPLACE "+PTX" "" _BASE_ARCH "${_ARCH}") + string(REPLACE "." 
"" _STRIPPED_ARCH "${_BASE_ARCH}") + set_gencode_flag_for_srcs( + SRCS ${arg_SRCS} + ARCH "compute_${_STRIPPED_ARCH}" + CODE "sm_${_STRIPPED_ARCH}") + set_gencode_flag_for_srcs( + SRCS ${arg_SRCS} + ARCH "compute_${_STRIPPED_ARCH}" + CODE "compute_${_STRIPPED_ARCH}") + else() + string(REPLACE "." "" _STRIPPED_ARCH "${_ARCH}") + set_gencode_flag_for_srcs( + SRCS ${arg_SRCS} + ARCH "compute_${_STRIPPED_ARCH}" + CODE "sm_${_STRIPPED_ARCH}") + endif() endforeach() if (${arg_BUILD_PTX_FOR_ARCH}) @@ -251,7 +266,10 @@ endmacro() # # For the given `SRC_CUDA_ARCHS` list of gencode versions in the form # `.[letter]` compute the "loose intersection" with the -# `TGT_CUDA_ARCHS` list of gencodes. +# `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in +# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there +# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the +# architecture in `SRC_CUDA_ARCHS`. # The loose intersection is defined as: # { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} } # where `<=` is the version comparison operator. 
@@ -268,44 +286,63 @@ endmacro() # cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) # OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a" # +# Example With PTX: +# SRC_CUDA_ARCHS="8.0+PTX" +# TGT_CUDA_ARCHS="9.0" +# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) +# OUT_CUDA_ARCHS="8.0+PTX" +# function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) - list(REMOVE_DUPLICATES SRC_CUDA_ARCHS) - set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS}) + set(_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS}") + set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS}) + + # handle +PTX suffix: separate base arch for matching, record PTX requests + set(_PTX_ARCHS) + foreach(_arch ${_SRC_CUDA_ARCHS}) + if(_arch MATCHES "\\+PTX$") + string(REPLACE "+PTX" "" _base "${_arch}") + list(APPEND _PTX_ARCHS "${_base}") + list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}") + list(APPEND _SRC_CUDA_ARCHS "${_base}") + endif() + endforeach() + list(REMOVE_DUPLICATES _PTX_ARCHS) + list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS) # if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS set(_CUDA_ARCHS) - if ("9.0a" IN_LIST SRC_CUDA_ARCHS) - list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a") - if ("9.0" IN_LIST TGT_CUDA_ARCHS_) - list(REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0") + if ("9.0a" IN_LIST _SRC_CUDA_ARCHS) + list(REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a") + if ("9.0" IN_LIST TGT_CUDA_ARCHS) + list(REMOVE_ITEM _TGT_CUDA_ARCHS "9.0") set(_CUDA_ARCHS "9.0a") endif() endif() - if ("10.0a" IN_LIST SRC_CUDA_ARCHS) - list(REMOVE_ITEM SRC_CUDA_ARCHS "10.0a") + if ("10.0a" IN_LIST _SRC_CUDA_ARCHS) + list(REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a") if ("10.0" IN_LIST TGT_CUDA_ARCHS) - list(REMOVE_ITEM TGT_CUDA_ARCHS_ "10.0") + list(REMOVE_ITEM _TGT_CUDA_ARCHS "10.0") set(_CUDA_ARCHS "10.0a") endif() endif() - list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING) + list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING) # for each ARCH in 
TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that # is less or equal to ARCH (but has the same major version since SASS binary # compatibility is only forward compatible within the same major version). - foreach(_ARCH ${TGT_CUDA_ARCHS_}) + foreach(_ARCH ${_TGT_CUDA_ARCHS}) set(_TMP_ARCH) # Extract the major version of the target arch string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}") - foreach(_SRC_ARCH ${SRC_CUDA_ARCHS}) + foreach(_SRC_ARCH ${_SRC_CUDA_ARCHS}) # Extract the major version of the source arch string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}") - # Check major-version match AND version-less-or-equal + # Check version-less-or-equal, and allow PTX arches to match across majors if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH) - if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR) + if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR) set(_TMP_ARCH "${_SRC_ARCH}") endif() else() @@ -321,6 +358,18 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR endforeach() list(REMOVE_DUPLICATES _CUDA_ARCHS) + + # reapply +PTX suffix to architectures that requested PTX + set(_FINAL_ARCHS) + foreach(_arch ${_CUDA_ARCHS}) + if(_arch IN_LIST _PTX_ARCHS) + list(APPEND _FINAL_ARCHS "${_arch}+PTX") + else() + list(APPEND _FINAL_ARCHS "${_arch}") + endif() + endforeach() + set(_CUDA_ARCHS ${_FINAL_ARCHS}) + set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE) endfunction() diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 88275dbdd83a..55e659679701 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -70,6 +70,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { int64_t num_tokens = input.numel() / input.size(-1); \ dim3 grid(num_tokens); \ dim3 block(std::min(d, 1024)); \ + if (num_tokens == 0) { \ + return; \ + } \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const cudaStream_t stream = 
at::cuda::getCurrentCUDAStream(); \ VLLM_DISPATCH_FLOATING_TYPES( \ diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh index eb216dc8baf1..79a546554fa1 100644 --- a/csrc/attention/attention_kernels.cuh +++ b/csrc/attention/attention_kernels.cuh @@ -172,7 +172,7 @@ __device__ void paged_attention_kernel( // Load the query to registers. // Each thread in a thread group has a different part of the query. - // For example, if the the thread group size is 4, then the first thread in + // For example, if the thread group size is 4, then the first thread in // the group has 0, 4, 8, ... th vectors of the query, and the second thread // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because // q is split from a qkv tensor, it may not be contiguous. @@ -259,7 +259,7 @@ __device__ void paged_attention_kernel( // Load a key to registers. // Each thread in a thread group has a different part of the key. - // For example, if the the thread group size is 4, then the first thread in + // For example, if the thread group size is 4, then the first thread in // the group has 0, 4, 8, ... th vectors of the key, and the second thread // has 1, 5, 9, ... th vectors of the key, and so on. for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { diff --git a/csrc/attention/vertical_slash_index.cu b/csrc/attention/vertical_slash_index.cu new file mode 100644 index 000000000000..c1b45b143f4e --- /dev/null +++ b/csrc/attention/vertical_slash_index.cu @@ -0,0 +1,401 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+ +#include + +#include + +#include + +__device__ int64_t save_blocks(int* block_offset, int64_t range_start, + int64_t range_end, int64_t block_size, + int64_t input_block_count, int64_t kv_seqlen) { + if (range_start >= kv_seqlen) { + return input_block_count; + } + if (range_end > kv_seqlen) { + range_end = kv_seqlen; + } + int64_t current_block_count = input_block_count; + for (int idx = range_start; idx < range_end; idx += block_size) { + block_offset[current_block_count++] = idx; + } + return current_block_count; +} + +__global__ void convert_vertical_slash_indexes_kernel( + const int* q_seqlens, // [BATCH, ] + const int* kv_seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, + int64_t NNZ_V, int64_t NNZ_S, + bool causal // True for intra, False for succ +) { + const int batch_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int group_idx = blockIdx.z; + + int64_t q_seqlen = q_seqlens[batch_idx]; + int64_t kv_seqlen = kv_seqlens[batch_idx]; + int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x; + int64_t start_m = block_idx_m * BLOCK_SIZE_M; + if (start_m >= q_seqlen) { + return; + } + int64_t end_m = start_m + BLOCK_SIZE_M; + vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V; + slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S; + int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m; + block_count += row_offset; + block_offset += row_offset * NNZ_S; + column_count += row_offset; + column_index += row_offset * NNZ_V; + + bool has_slash = true; + int64_t 
tmp_col_cnt = 0, tmp_blk_cnt = 0; + int64_t s = 0, v = 0; + int64_t v_idx = vertical_indexes[v++]; + int64_t s_idx = slash_indexes[s++]; + if (causal) { + while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) { + s_idx = slash_indexes[s++]; + } + if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false; + s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M); + } else { + while (s_idx >= end_m + kv_seqlen && s < NNZ_S) { + s_idx = slash_indexes[s++]; + } + if (s_idx > end_m + kv_seqlen) has_slash = false; + s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M); + } + + int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx; + if (!has_slash) { + if (causal) { + range_start = (kv_seqlen - q_seqlen) + end_m; + range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N; + } else { + range_start = kv_seqlen; + range_end = kv_seqlen + BLOCK_SIZE_N; + } + } + + bool slash_finished = false; + while (1) { + if (v_idx < range_end) { + if (v_idx < range_start) { + column_index[tmp_col_cnt++] = v_idx; + } + if (v < NNZ_V) { + v_idx = vertical_indexes[v++]; + } else { + if (causal) + v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen); + else + v_idx = end_m + BLOCK_SIZE_N + kv_seqlen; + } + } else { + if ((s < NNZ_S && causal) || + (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) { + if (causal) + s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++], + BLOCK_SIZE_M); + else + s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M); + } else { + if (v == NNZ_V || (v_idx > range_start && causal)) { + // add the last vertical if no more slash + if (v == NNZ_V && !causal && v_idx < kv_seqlen) { + column_index[tmp_col_cnt++] = v_idx; + } + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + break; + } else { + if (causal) { + range_start = (kv_seqlen - q_seqlen) + end_m; + range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N; + } else { + // if slash_finished but 
there are vertical left, save current + // blocks + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + range_start = kv_seqlen; + range_end = kv_seqlen + BLOCK_SIZE_N; + } + slash_finished = true; + } + } + if (!slash_finished) { + if (s_idx > range_end + BLOCK_SIZE_M) { + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + range_start = s_idx - BLOCK_SIZE_M; + range_end = s_idx; + } else if (s_idx > range_end) { + range_end += BLOCK_SIZE_M; + } + } + } + } + + block_count[0] = tmp_blk_cnt; + column_count[0] = tmp_col_cnt; +} + +void convert_vertical_slash_indexes_64x64( + const int* q_seqlens, // [BATCH, ] + const int* kv_seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, + int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) { + const int N_THREADS = 64; + const dim3 dimBlock(N_THREADS); + const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS); + convert_vertical_slash_indexes_kernel<<>>( + q_seqlens, kv_seqlens, vertical_indexes, slash_indexes, block_count, + block_offset, column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M, + BLOCK_SIZE_N, NNZ_V, NNZ_S, causal); +} + +/** + * Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490. + * + * This function builds the index of each row of blocks from vertical indices + * and slash indices. The vertical indices are treated as points, while the + * slash indices are converted as ranges. 
The output consists of the merged + * ranges and separate column indices, where the ranges are represented by + * block indices. + * + * The implementation is referenced from the original MInference repo: + * https://github.com/microsoft/MInference/blob/main/csrc/vertical_slash_index.cu. + */ +void convert_vertical_slash_indexes( + torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] + torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V] + torch::Tensor q_seqlens, // [BATCH, ] + torch::Tensor kv_seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int64_t context_size, int64_t block_size_M, int64_t block_size_N, + bool causal) { + cudaSetDevice(q_seqlens.get_device()); + + int batch_size = slash_indexes.size(0); + int num_heads = slash_indexes.size(1); + int nnz_slash = slash_indexes.size(2); + int nnz_vertical = vertical_indexes.size(2); + int num_rows = (context_size + block_size_M - 1) / block_size_M; + + convert_vertical_slash_indexes_64x64( + q_seqlens.data_ptr(), kv_seqlens.data_ptr(), + vertical_indexes.data_ptr(), slash_indexes.data_ptr(), + block_count.data_ptr(), block_offset.data_ptr(), + column_count.data_ptr(), column_index.data_ptr(), batch_size, + num_heads, num_rows, block_size_M, block_size_N, nnz_vertical, nnz_slash, + causal); +} + +__global__ void convert_vertical_slash_indexes_kernel_mergehead( + const int* q_seqlens, // [BATCH, ] + const int* kv_seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + const int* per_head_vertical_topkv, const int* per_head_slash_topkv, + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // 
[BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, + int64_t NNZ_V, int64_t NNZ_S, + bool causal // True for intra, False for succ +) { + const int batch_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int group_idx = blockIdx.z; + + int64_t q_seqlen = q_seqlens[batch_idx]; + int64_t kv_seqlen = kv_seqlens[batch_idx]; + int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x; + int64_t start_m = block_idx_m * BLOCK_SIZE_M; + if (start_m >= q_seqlen) { + return; + } + int64_t end_m = start_m + BLOCK_SIZE_M; + vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V; + slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S; + int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m; + block_count += row_offset; + block_offset += row_offset * NNZ_S; + column_count += row_offset; + column_index += row_offset * NNZ_V; + + // MergeHead: each head has it's unique max topk NNZ_V๏ผŒNNZ_S. 
(NNZ_V๏ผŒNNZ_S + // above is buffer size, use to compute offset) + NNZ_S = per_head_slash_topkv[head_idx]; + NNZ_V = per_head_vertical_topkv[head_idx]; + + bool has_slash = true; + int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0; + int64_t s = 0, v = 0; + int64_t v_idx = vertical_indexes[v++]; + int64_t s_idx = slash_indexes[s++]; + if (causal) { + while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) { + s_idx = slash_indexes[s++]; + } + if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false; + s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M); + } else { + while (s_idx >= end_m + kv_seqlen && s < NNZ_S) { + s_idx = slash_indexes[s++]; + } + if (s_idx > end_m + kv_seqlen) has_slash = false; + s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M); + } + + int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx; + if (!has_slash) { + if (causal) { + range_start = (kv_seqlen - q_seqlen) + end_m; + range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N; + } else { + range_start = kv_seqlen; + range_end = kv_seqlen + BLOCK_SIZE_N; + } + } + + bool slash_finished = false; + while (1) { + if (v_idx < range_end) { + if (v_idx < range_start) { + column_index[tmp_col_cnt++] = v_idx; + } + if (v < NNZ_V) { + v_idx = vertical_indexes[v++]; + } else { + if (causal) + v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen); + else + v_idx = end_m + BLOCK_SIZE_N + kv_seqlen; + } + } else { + if ((s < NNZ_S && causal) || + (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) { + if (causal) + s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++], + BLOCK_SIZE_M); + else + s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M); + } else { + if (v == NNZ_V || (v_idx > range_start && causal)) { + // add the last vertical if no more slash + if (v == NNZ_V && !causal && v_idx < kv_seqlen) { + column_index[tmp_col_cnt++] = v_idx; + } + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, 
kv_seqlen); + break; + } else { + if (causal) { + range_start = (kv_seqlen - q_seqlen) + end_m; + range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N; + } else { + // if slash_finished but there are vertical left, save current + // blocks + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + range_start = kv_seqlen; + range_end = kv_seqlen + BLOCK_SIZE_N; + } + slash_finished = true; + } + } + if (!slash_finished) { + if (s_idx > range_end + BLOCK_SIZE_M) { + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + range_start = s_idx - BLOCK_SIZE_M; + range_end = s_idx; + } else if (s_idx > range_end) { + range_end += BLOCK_SIZE_M; + } + } + } + } + + block_count[0] = tmp_blk_cnt; + column_count[0] = tmp_col_cnt; +} + +void convert_vertical_slash_indexes_64x64_mergehead( + const int* q_seqlens, // [BATCH, ] + const int* kv_seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* per_head_vertical_topkv, int* per_head_slash_topkv, + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, + int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) { + const int N_THREADS = 64; + const dim3 dimBlock(N_THREADS); + const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS); + convert_vertical_slash_indexes_kernel_mergehead<<>>( + q_seqlens, kv_seqlens, vertical_indexes, slash_indexes, + per_head_vertical_topkv, per_head_slash_topkv, block_count, block_offset, + column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N, + NNZ_V, NNZ_S, 
causal); +} + +/** + * Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490. + * + * Like the above convert_vertical_slash_indexes, but with + * pre-computed vertical and slash counts. + */ +void convert_vertical_slash_indexes_mergehead( + torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] + torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V] + torch::Tensor q_seqlens, // [BATCH, ] + torch::Tensor kv_seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + torch::Tensor vertical_indices_count, // [N_HEADS, ] + torch::Tensor slash_indices_count, // [N_HEADS, ] + int64_t context_size, int64_t block_size_M, int64_t block_size_N, + bool causal) { + cudaSetDevice(q_seqlens.get_device()); + + int batch_size = slash_indexes.size(0); + int num_heads = slash_indexes.size(1); + int nnz_slash = slash_indexes.size(2); + int nnz_vertical = vertical_indexes.size(2); + int num_rows = (context_size + block_size_M - 1) / block_size_M; + + convert_vertical_slash_indexes_64x64_mergehead( + q_seqlens.data_ptr(), kv_seqlens.data_ptr(), + vertical_indexes.data_ptr(), slash_indexes.data_ptr(), + vertical_indices_count.data_ptr(), + slash_indices_count.data_ptr(), block_count.data_ptr(), + block_offset.data_ptr(), column_count.data_ptr(), + column_index.data_ptr(), batch_size, num_heads, num_rows, + block_size_M, block_size_N, nnz_vertical, nnz_slash, causal); +} diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp index c2ae554c9f8e..d0f85e23609b 100644 --- a/csrc/core/scalar_type.hpp +++ b/csrc/core/scalar_type.hpp @@ -315,6 +315,8 @@ static inline constexpr auto kS8 = ScalarType::int_(8); static inline constexpr auto kU8 = ScalarType::uint(8); static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); +static inline 
constexpr auto kFE2M1f = + ScalarType::float_(2, 1, true, ScalarType::NAN_NONE); static inline constexpr auto kFE3M2f = ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); static inline constexpr auto kFE4M3fn = @@ -332,6 +334,7 @@ static inline constexpr auto kInt8 = kS8; static inline constexpr auto kUint8 = kU8; static inline constexpr auto kUint8b128 = kU8B128; +static inline constexpr auto kFloat4_e2m1f = kFE2M1f; static inline constexpr auto kFloat6_e3m2f = kFE3M2f; static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; static inline constexpr auto kFloat8_e5m2 = kFE5M2; diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp index cf67847b45ba..9a613ba588dd 100644 --- a/csrc/cpu/cpu_types_x86.hpp +++ b/csrc/cpu/cpu_types_x86.hpp @@ -19,6 +19,7 @@ namespace vec_op { #define VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(...) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp index dbe0e30f5cbf..195872e8edd3 100644 --- a/csrc/cutlass_extensions/common.hpp +++ b/csrc/cutlass_extensions/common.hpp @@ -15,15 +15,6 @@ cutlassGetStatusString(error)); \ } -/** - * Panic wrapper for unwinding CUDA runtime errors - */ -#define CUDA_CHECK(status) \ - { \ - cudaError_t error = status; \ - TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \ - } - inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { int max_shared_mem_per_block_opt_in = 0; cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, @@ -59,3 +50,13 @@ struct enable_sm90_only : Kernel { #endif } }; + +template +struct enable_sm100_only : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... 
args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1000 + Kernel::operator()(std::forward(args)...); +#endif + } +}; diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index dc6e0769b878..f7b75c48373f 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -65,5 +65,19 @@ AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) +#define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__) + #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) + +#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__)) diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index 98daf1a1b8e6..f62d08c17c6d 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -13,6 +13,10 @@ #include #include +#ifdef USE_ROCM + namespace cub = hipcub; +#endif + #include "static_switch.h" @@ -501,15 +505,9 @@ void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { auto kernel = &causal_conv1d_fwd_kernel; if (kSmemSize >= 48 * 1024) { - #ifndef USE_ROCM - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - #else - // There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function. 
C10_CUDA_CHECK(cudaFuncSetAttribute( (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl; - #endif } kernel<<>>(params); diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index bd0a34119c82..0c9df925bdbf 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -321,7 +321,7 @@ void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { auto kernel = &selective_scan_fwd_kernel; if (kSmemSize >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); } kernel<<>>(params); C10_CUDA_KERNEL_LAUNCH_CHECK(); diff --git a/csrc/moe/marlin_moe_wna16/generate_kernels.py b/csrc/moe/marlin_moe_wna16/generate_kernels.py index 902bcd9dfd21..15f008d4f61e 100644 --- a/csrc/moe/marlin_moe_wna16/generate_kernels.py +++ b/csrc/moe/marlin_moe_wna16/generate_kernels.py @@ -31,7 +31,10 @@ # int8 with zero point case (vllm::kU8) is also supported, # we don't add it to reduce wheel size. 
-SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"] +SCALAR_TYPES = [ + "vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn", + "vllm::kFE2M1f" +] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] @@ -39,7 +42,7 @@ # = 0 : act order case # = -1 : channelwise quantization # > 0 : group_size=16*group_blocks -GROUP_BLOCKS = [0, -1, 2, 4, 8] +GROUP_BLOCKS = [0, -1, 1, 2, 4, 8] DTYPES = ["fp16", "bf16"] @@ -72,6 +75,12 @@ def generate_new_kernels(): # for fp8 if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: continue + # nvfp4 only supports group_size == 16 + if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]: + continue + # other quantization methods don't support group_size = 16 + if scalar_type != "vllm::kFE2M1f" and group_blocks == 1: + continue k_blocks = thread_configs[0] // 16 n_blocks = thread_configs[1] // 16 diff --git a/csrc/moe/marlin_moe_wna16/kernel.h b/csrc/moe/marlin_moe_wna16/kernel.h index c40c33d01f37..537282aba8c8 100644 --- a/csrc/moe/marlin_moe_wna16/kernel.h +++ b/csrc/moe/marlin_moe_wna16/kernel.h @@ -7,17 +7,18 @@ #include "quantization/gptq_marlin/marlin_dtypes.cuh" #include "core/scalar_type.hpp" -#define MARLIN_KERNEL_PARAMS \ - const int4 *__restrict__ A, const int4 *__restrict__ B, \ - int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ - const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \ - const int *__restrict__ g_idx, \ - const int32_t *__restrict__ sorted_token_ids_ptr, \ - const int32_t *__restrict__ expert_ids_ptr, \ - const int32_t *__restrict__ num_tokens_past_padded_ptr, \ - const float *__restrict__ topk_weights_ptr, int top_k, \ - bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ - int prob_n, int prob_k, int *locks, bool use_atomic_add, \ +#define MARLIN_KERNEL_PARAMS \ + const int4 *__restrict__ A, const int4 *__restrict__ B, \ + int4 *__restrict__ C, int4 *__restrict__ 
C_tmp, \ + const int4 *__restrict__ scales_ptr, \ + const uint16_t *__restrict__ scale2_ptr, \ + const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ + const int32_t *__restrict__ sorted_token_ids_ptr, \ + const int32_t *__restrict__ expert_ids_ptr, \ + const int32_t *__restrict__ num_tokens_past_padded_ptr, \ + const float *__restrict__ topk_weights_ptr, int top_k, \ + bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ + int prob_n, int prob_k, int *locks, bool use_atomic_add, \ bool use_fp32_reduce, int max_shared_mem namespace MARLIN_NAMESPACE_NAME { diff --git a/csrc/moe/marlin_moe_wna16/marlin_template.h b/csrc/moe/marlin_moe_wna16/marlin_template.h index c9e199bcea1f..1c255396099d 100644 --- a/csrc/moe/marlin_moe_wna16/marlin_template.h +++ b/csrc/moe/marlin_moe_wna16/marlin_template.h @@ -301,9 +301,11 @@ __global__ void Marlin( int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape // (k/groupsize)xn - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k + const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4 + // only) + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids const int32_t* __restrict__ expert_ids_ptr, // moe expert ids const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens @@ -341,6 +343,16 @@ __global__ void Marlin( extern __shared__ int4 sh[]; static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8; + constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 || + w_type == vllm::kU4B8 || w_type == 
vllm::kU8B128; + // see comments of dequant.h for more details + constexpr bool dequant_skip_flop = + !is_int_type || + has_zp && !is_zp_float && !std::is_same::value || + has_zp && !is_zp_float && !(w_type == vllm::kU8); + + scalar_t2 global_scale; + constexpr bool has_act_order = group_blocks == 0; constexpr int pack_factor = 32 / w_type.size_bits(); @@ -348,7 +360,8 @@ __global__ void Marlin( constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); const int group_size = (!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups; - const int scales_expert_stride = prob_n * prob_k / group_size / 8; + const int scales_expert_stride = + prob_n * prob_k / group_size / (w_type == vllm::kFE2M1f ? 16 : 8); const int zp_expert_stride = is_zp_float ? prob_n * prob_k / group_size / 8 : prob_n * prob_k / group_size / (pack_factor * 4); @@ -460,9 +473,16 @@ __global__ void Marlin( if (mul_topk_weights) { #pragma unroll for (int i = 0; i < 4; i++) { - sh_block_topk_weights[tid4 * 4 + i] = - Dtype::num2num2(Dtype::float2num( - topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]])); + int idx = tid4 * 4 + i; + idx = idx < block_num_valid_tokens ? 
idx : 0; + if constexpr (w_type == vllm::kFE2M1f) { + sh_block_topk_weights[idx] = __hmul2( + global_scale, Dtype::num2num2(Dtype::float2num( + topk_weights_ptr[sh_block_sorted_ids[idx]]))); + } else { + sh_block_topk_weights[idx] = Dtype::num2num2( + Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]])); + } } } } @@ -493,6 +513,11 @@ __global__ void Marlin( expert_id = expert_ids_ptr[block_id]; } + if constexpr (w_type == vllm::kFE2M1f) { + uint16_t val = scale2_ptr[expert_id]; + global_scale = Dtype::num2num2(*reinterpret_cast(&val)); + } + B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4); scales_ptr += (expert_id - old_expert_id) * scales_expert_stride; if constexpr (has_zp) { @@ -606,7 +631,7 @@ __global__ void Marlin( constexpr int s_sh_stride = 16 * thread_n_blocks / 8; constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks + ? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1) : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; @@ -664,7 +689,8 @@ __global__ void Marlin( if constexpr (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / + (w_type == vllm::kFE2M1f ? 2 : 1) + s_sh_stride * slice_col + threadIdx.x; } } @@ -688,10 +714,20 @@ __global__ void Marlin( // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. 
int s_sh_rd; - if constexpr (group_blocks != -1) + if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + int warp_row = warp_id / n_warps; + + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + s_sh_rd = s_sh_rd * 2 + warp_row % 2; + + } else if constexpr (group_blocks != -1) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp)) + else if constexpr (group_blocks == -1 && + (m_block_size_8 || (has_zp && !dequant_skip_flop))) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; else @@ -801,7 +837,7 @@ __global__ void Marlin( sh_first_group_id = first_group_id; sh_num_groups = last_group_id - first_group_id + 1; - if (sh_num_groups < act_s_max_num_groups) { + if (sh_num_groups > act_s_max_num_groups) { sh_num_groups = act_s_max_num_groups; } @@ -1021,12 +1057,19 @@ __global__ void Marlin( cur_k += k_iter_size * (k % b_sh_wr_iters); int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; + int cur_group_id = + k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 
2 : 1)); int4* sh_s_stage = sh_s + s_sh_stage * pipe; - reinterpret_cast(&frag_s[k % 2])[0] = - sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + if constexpr (w_type_id != vllm::kFE2M1f.id()) { + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } else { + reinterpret_cast(&frag_s[k % 2])[0] = + reinterpret_cast( + sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } } } @@ -1199,22 +1242,7 @@ __global__ void Marlin( }; auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) { - if constexpr (has_zp && is_zp_float || !has_zp) { - dequant(q, frag_b_ptr); - } else { - static_assert(has_zp && !is_zp_float); - static_assert(w_type_id == vllm::kU4.id() || w_type_id == vllm::kU8.id()); - // If (has_zp && !is_zp_float), - // we use not-zp version `dequant` function - // to improve numerical accuracy. - // Since both weight and zero point are dequanted using this logic, - // the final dequanted weight would be correct. - if constexpr (w_type_id == vllm::kU4.id()) { - dequant(q, frag_b_ptr); - } else if constexpr (w_type_id == vllm::kU8.id()) { - dequant(q, frag_b_ptr); - } - } + dequant(q, frag_b_ptr); }; // Execute the actual tensor core matmul of a sub-tile. @@ -1244,13 +1272,23 @@ __global__ void Marlin( dequant_data(zp_quant_1, reinterpret_cast(&frag_zp) + 2); } } - if constexpr (has_zp && is_zp_float) { + if constexpr (!dequant_skip_flop && has_zp && is_zp_float) { if (is_new_zp) { reinterpret_cast(&frag_zp)[0] = reinterpret_cast(&frag_zpf[k2])[0]; } } + if constexpr (w_type == vllm::kFE2M1f) { + int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; + int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; + + dequant_fp8_scales(s_quant_0, + reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales( + s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + } + // We have the m dimension as the inner loop in order to encourage overlapping // dequantization and matmul operations. 
#pragma unroll @@ -1259,7 +1297,10 @@ __global__ void Marlin( FragB frag_b1; int b_quant_0, b_quant_1; - if constexpr (w_type.size_bits() == 4) { + if constexpr (w_type_id == vllm::kFE2M1f.id()) { + b_quant_1 = frag_b_quant[k2][0][j]; + b_quant_0 = b_quant_1 << 8; + } else if constexpr (w_type.size_bits() == 4) { b_quant_0 = frag_b_quant[k2][0][j]; b_quant_1 = b_quant_0 >> 8; } else { @@ -1272,6 +1313,11 @@ __global__ void Marlin( dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); + if constexpr (dequant_skip_flop && has_zp && !is_zp_float) { + sub_zp(frag_b0, frag_zp[j], 0); + sub_zp(frag_b1, frag_zp[j], 1); + } + // Apply scale to frag_b0 if constexpr (has_act_order) { static_assert(group_blocks != -1); @@ -1279,7 +1325,8 @@ __global__ void Marlin( act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); - } else if constexpr (has_zp && !is_zp_float && group_blocks == -1) { + } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && + group_blocks == -1) { int idx = (threadIdx.x / 4) % 2; scalar_t2 s2 = Dtype::nums2num2( reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], @@ -1287,7 +1334,7 @@ __global__ void Marlin( if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); scale_and_sub(frag_b0, s2.x, frag_zp[j].x); scale_and_sub(frag_b1, s2.y, frag_zp[j].y); - } else if constexpr (has_zp && group_blocks != -1) { + } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) { if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast(&frag_s[k2][j])); @@ -1554,10 +1601,17 @@ __global__ void Marlin( // For per-column quantization we finally apply the scale here (only for // 4-bit) if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4 && !has_zp) { + w_type.size_bits() == 4 && + (has_zp && dequant_skip_flop || !has_zp)) { res = __hmul2(res, s[0]); } + if 
constexpr (w_type == vllm::kFE2M1f) { + if (!mul_topk_weights) { + res = __hmul2(res, global_scale); + } + } + if constexpr (m_block_size_8) { ((scalar_t*)sh_red)[idx] = res.x; ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; @@ -1648,7 +1702,9 @@ __global__ void Marlin( if constexpr (has_zp && !is_zp_float && group_blocks == -1) { if (i == 0) { fetch_col_zp_to_shared(); - fetch_col_scale_to_shared(); + if constexpr (!dequant_skip_flop) { + fetch_col_scale_to_shared(); + } } } fetch_to_shared(i, i, i < slice_iters, i); @@ -1711,17 +1767,20 @@ __global__ void Marlin( if constexpr (has_act_order) { slice_k_start += tb_k * stages; - slice_k_start_shared_fetch += tb_k * stages; - int first_group_id = g_idx[slice_k_start]; - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - int last_group_id = g_idx[last_g_idx]; - if (last_group_id >= sh_first_group_id + sh_num_groups) { - fetch_act_order_scales_to_shared(false, first_group_id, - last_group_id); - __syncthreads(); + + if (slice_k_start < prob_k) { + slice_k_start_shared_fetch += tb_k * stages; + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_act_order_scales_to_shared(false, first_group_id, + last_group_id); + __syncthreads(); + } } } if (slice_iters == 0) { @@ -1737,7 +1796,8 @@ __global__ void Marlin( bool last = slice_idx == slice_count - 1; // For per-column scales, we only fetch them here in the final step before // write-out - if constexpr (!has_act_order && group_blocks == -1 && !has_zp) { + if constexpr (!has_act_order && group_blocks == -1 && + (has_zp && dequant_skip_flop || !has_zp)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); @@ -1747,7 
+1807,8 @@ __global__ void Marlin( } thread_block_reduce(); - if constexpr (!has_act_order && group_blocks == -1 && !has_zp) { + if constexpr (!has_act_order && group_blocks == -1 && + (has_zp && dequant_skip_flop || !has_zp)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) { cp_async_wait<0>(); __syncthreads(); @@ -1771,7 +1832,8 @@ __global__ void Marlin( // that converts the fp32 results to fp16 (so that we avoid possible // overflow in fp16) if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8 && !has_zp) { + w_type.size_bits() == 8 && + (has_zp && dequant_skip_flop || !has_zp)) { if (threadIdx.x / 32 < thread_n_blocks / 4) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { diff --git a/csrc/moe/marlin_moe_wna16/ops.cu b/csrc/moe/marlin_moe_wna16/ops.cu index 00b4e934cc39..2cff04f699b0 100644 --- a/csrc/moe/marlin_moe_wna16/ops.cu +++ b/csrc/moe/marlin_moe_wna16/ops.cu @@ -291,6 +291,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, // BIGGROUP: cases for big group size (group_blocks in [-1, 8]) // FZP: cases for float-zero-point (is_zp_float = true) // ACT: cases for act order case (group_blocks == 0) + // FP4: cases for nvfp4(e2m1) (group_blocks == 1) #define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ @@ -338,6 +339,21 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + #define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + + #define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + 
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + + #define FP4_GET_IF(W_TYPE) \ + FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FP4_GET_IF_M234(W_TYPE, 8, 4, 128) + #define BIGGROUP_GET_IF(W_TYPE) \ BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ @@ -394,6 +410,8 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, BIGGROUP_GET_IF(vllm::kFE4M3fn) + FP4_GET_IF(vllm::kFE2M1f) + ACT_GET_IF(vllm::kU4B8) ACT_GET_IF(vllm::kU8B128) @@ -465,7 +483,7 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, template void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, - void* zp, void* g_idx, void* perm, void* a_tmp, + void* s2, void* zp, void* g_idx, void* perm, void* a_tmp, void* sorted_token_ids, void* expert_ids, void* num_tokens_past_padded, void* topk_weights, int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep, @@ -479,14 +497,16 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, bool m_block_size_8 = moe_block_size == 8; if (has_zp) { - TORCH_CHECK(q_type == vllm::kU4, - "q_type must be u4 when has_zp = True. Got = ", q_type.str()); + TORCH_CHECK( + q_type == vllm::kU4 || q_type == vllm::kU8, + "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); } else { - TORCH_CHECK(q_type == vllm::kU4B8 || q_type == vllm::kU8B128 || - q_type == vllm::kFE4M3fn, - "q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = " - "False. Got = ", - q_type.str()); + TORCH_CHECK( + q_type == vllm::kU4B8 || q_type == vllm::kU8B128 || + q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f, + "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " + "has_zp = False. 
Got = ", + q_type.str()); } TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, @@ -519,6 +539,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, int4* C_ptr = (int4*)C; int4* C_tmp_ptr = (int4*)C_tmp; const int4* s_ptr = (const int4*)s; + const uint16_t* s2_ptr = (const uint16_t*)s2; const int4* zp_ptr = (const int4*)zp; const int* g_idx_ptr = (const int*)g_idx; const int* perm_ptr = (const int*)perm; @@ -627,7 +648,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, // avoid ">>>" being formatted to "> > >" // clang-format off kernel<<>>( - A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, + A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr, topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m, prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce, max_shared_mem); @@ -639,6 +660,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, torch::Tensor moe_wna16_marlin_gemm( torch::Tensor& a, std::optional const& c_or_none, torch::Tensor& b_q_weight, torch::Tensor& b_scales, + std::optional const& global_scale_or_none, std::optional const& b_zeros_or_none, std::optional const& g_idx_or_none, std::optional const& perm_or_none, torch::Tensor& workspace, @@ -790,6 +812,17 @@ torch::Tensor moe_wna16_marlin_gemm( } } + torch::Tensor global_scale; + if (global_scale_or_none.has_value()) { + global_scale = global_scale_or_none.value(); + TORCH_CHECK(b_q_type == vllm::kFE2M1f, + "global_scale can only be used for float4_e2m1f."); + } else { + global_scale = torch::empty({0}, options); + TORCH_CHECK(!(b_q_type == vllm::kFE2M1f), + "the global_scale parameter must be passed for float4_e2m1f."); + } + torch::Tensor b_zeros; if (b_zeros_or_none.has_value()) { b_zeros = b_zeros_or_none.value(); @@ -802,13 +835,14 @@ torch::Tensor moe_wna16_marlin_gemm( if (has_zp) { 
TORCH_CHECK( - b_q_type == vllm::kU4, - "b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str()); + b_q_type == vllm::kU4 || b_q_type == vllm::kU8, + "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); } else { TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 || - b_q_type == vllm::kFE4M3fn, - "b_q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = " - "False. Got = ", + b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f, + "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or " + "float4_e2m1f when " + "has_zp = False. Got = ", b_q_type.str()); } @@ -854,9 +888,16 @@ torch::Tensor moe_wna16_marlin_gemm( int dev = a.get_device(); if (a.scalar_type() == at::ScalarType::Half) { + void* scales_ptr; + if (b_q_type == vllm::kFE2M1f) { + scales_ptr = b_scales.data_ptr(); + } else { + scales_ptr = b_scales.data_ptr(); + } + MARLIN_NAMESPACE_NAME::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - c_tmp.data_ptr(), b_scales.data_ptr(), + c_tmp.data_ptr(), scales_ptr, global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(), @@ -866,11 +907,18 @@ torch::Tensor moe_wna16_marlin_gemm( at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); } else if (a.scalar_type() == at::ScalarType::BFloat16) { + void* scales_ptr; + if (b_q_type == vllm::kFE2M1f) { + scales_ptr = b_scales.data_ptr(); + } else { + scales_ptr = b_scales.data_ptr(); + } + MARLIN_NAMESPACE_NAME::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), - c.data_ptr(), c_tmp.data_ptr(), - b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), - perm.data_ptr(), a_tmp.data_ptr(), + c.data_ptr(), c_tmp.data_ptr(), scales_ptr, + global_scale.data_ptr(), b_zeros.data_ptr(), + g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), sorted_token_ids.data_ptr(), 
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k, diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index d7be769458e3..6b6a9d04a60f 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -326,7 +326,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, } if (use_global_memory) { - VLLM_DISPATCH_INTEGRAL_TYPES( + VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] { // calc needed amount of shared mem for `tokens_cnts` and `cumsum` // tensors @@ -351,7 +351,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, cumsum_buffer.data_ptr()); }); } else if (use_i16) { - VLLM_DISPATCH_INTEGRAL_TYPES( + VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { // set dynamic shared mem auto kernel = @@ -366,7 +366,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, topk_ids.numel()); }); } else { - VLLM_DISPATCH_INTEGRAL_TYPES( + VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { auto kernel = vllm::moe::moe_align_block_size_kernel; @@ -391,7 +391,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, TORCH_CHECK(num_experts == 256, "sgl_moe_align_block_size kernel only supports deepseek v3."); - VLLM_DISPATCH_INTEGRAL_TYPES( + VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] { // calc needed amount of shared mem for `cumsum` tensors auto options_int = diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 0bae119a7c46..8fda434d452f 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -28,4 +28,6 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, torch::Tensor 
num_tokens_post_pad, int64_t top_k, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, int64_t BLOCK_SIZE_K, int64_t bit); -#endif \ No newline at end of file +#endif + +bool moe_permute_unpermute_supported(); \ No newline at end of file diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu index 76d5f0eab021..9a7465261abf 100644 --- a/csrc/moe/moe_permute_unpermute_op.cu +++ b/csrc/moe/moe_permute_unpermute_op.cu @@ -5,6 +5,9 @@ #include "permute_unpermute_kernels/dispatch.h" #include "core/registration.h" +// moe_permute kernels require at least CUDA 12.0 +#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000) + void moe_permute( const torch::Tensor& input, // [n_token, hidden] const torch::Tensor& topk_weights, //[n_token, topk] @@ -127,7 +130,45 @@ void moe_unpermute( }); } +#else + +void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights, + torch::Tensor& topk_ids, + const torch::Tensor& token_expert_indicies, + const std::optional& expert_map, + int64_t n_expert, int64_t n_local_expert, int64_t topk, + const std::optional& align_block_size, + torch::Tensor& permuted_input, + torch::Tensor& expert_first_token_offset, + torch::Tensor& src_row_id2dst_row_id_map, + torch::Tensor& m_indices) { + TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0"); +} + +void moe_unpermute(const torch::Tensor& input, + const torch::Tensor& topk_weights, torch::Tensor& topk_ids, + const torch::Tensor& token_expert_indicies, + const std::optional& expert_map, + int64_t n_expert, int64_t n_local_expert, int64_t topk, + const std::optional& align_block_size, + torch::Tensor& permuted_input, + torch::Tensor& expert_first_token_offset, + torch::Tensor& src_row_id2dst_row_id_map, + torch::Tensor& m_indices) { + TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0"); +} + +#endif + +bool moe_permute_unpermute_supported() { +#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000) + return true; +#else + return false; 
+#endif +} + TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { m.impl("moe_permute", &moe_permute); m.impl("moe_unpermute", &moe_unpermute); -} \ No newline at end of file +} diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu index aa353d0f0437..de2c153882d9 100644 --- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu @@ -1,6 +1,9 @@ #include "moe_permute_unpermute_kernel.h" +// moe_permute kernels require at least CUDA 12.0 +#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000) + // CubKeyValueSorter definition begin CubKeyValueSorter::CubKeyValueSorter() : num_experts_(0), num_bits_(sizeof(int) * 8) {} @@ -131,9 +134,6 @@ __global__ void preprocessTopkIdKernel(int* topk_id_ptr, int size, int num_experts) { auto tidx = threadIdx.x; auto bidx = blockIdx.x; - auto lidx = tidx & 31; - auto widx = tidx >> 5; - auto warp_count = (blockDim.x + 31) >> 5; auto offset = bidx * blockDim.x; auto bound = min(offset + blockDim.x, size); extern __shared__ int smem_expert_map[]; @@ -226,4 +226,6 @@ void getMIndices(int64_t* expert_first_token_offset, expert_first_token_offset, align_expert_first_token_offset, m_indices, num_local_expert, align_block_size); } -} \ No newline at end of file +} + +#endif diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index de9747b60252..a9379032245d 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -108,9 +108,17 @@ __launch_bounds__(TPB) __global__ } } -template -__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output, - int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert) +template +__launch_bounds__(TPB) __global__ void moeTopK( + const float* 
inputs_after_softmax, + const bool* finished, + float* output, + IndType* indices, + int* source_rows, + const int num_experts, + const int k, + const int start_expert, + const int end_expert) { using cub_kvp = cub::KeyValuePair; @@ -182,9 +190,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax 2) This implementation assumes k is small, but will work for any k. */ -template +template __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ - void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices, + void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices, int* source_rows, const int k, const int start_expert, const int end_expert) { // We begin by enforcing compile time assertions and setting up compile time constants. @@ -397,8 +405,8 @@ struct TopkConstants }; } // namespace detail -template -void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices, +template +void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices, int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream) { static constexpr std::size_t MAX_BYTES_PER_LDG = 16; @@ -421,10 +429,11 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f token_expert_indices, num_tokens, topk, 0, num_experts, \ stream); +template void topkGatingSoftmaxKernelLauncher( const float* gating_output, float* topk_weights, - int* topk_indicies, + IndType* topk_indicies, int* token_expert_indices, float* softmax_workspace, const int num_tokens, @@ -493,14 +502,32 @@ void topk_softmax( const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); torch::Tensor softmax_workspace = torch::empty({workspace_size}, 
gating_output.options()); - vllm::moe::topkGatingSoftmaxKernelLauncher( - gating_output.data_ptr(), - topk_weights.data_ptr(), - topk_indices.data_ptr(), - token_expert_indices.data_ptr(), - softmax_workspace.data_ptr(), - num_tokens, - num_experts, - topk, - stream); + + if(topk_indices.scalar_type() == at::ScalarType::Int) + { + vllm::moe::topkGatingSoftmaxKernelLauncher( + gating_output.data_ptr(), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + token_expert_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, + num_experts, + topk, + stream); + } + else + { + assert(topk_indices.scalar_type() == at::ScalarType::UInt32); + vllm::moe::topkGatingSoftmaxKernelLauncher( + gating_output.data_ptr(), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + token_expert_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, + num_experts, + topk, + stream); + } } diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 2a8b9bb39caa..7d35ec79ead4 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -10,7 +10,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // Calculate the result of moe by summing up the partial results // from all selected experts. - m.def("moe_sum(Tensor! input, Tensor output) -> ()"); + m.def("moe_sum(Tensor input, Tensor! output) -> ()"); m.impl("moe_sum", torch::kCUDA, &moe_sum); // Aligning the number of tokens to be processed by each expert such @@ -44,7 +44,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.def( "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none," - "Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none," + "Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale, Tensor? " + "b_zeros_or_none," "Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace," "Tensor sorted_token_ids," "Tensor! expert_ids, Tensor! 
num_tokens_past_padded," @@ -76,7 +77,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor " "expert_first_token_offset, int n_expert, int n_local_expert,int " "topk, Tensor! hidden_states)->()"); - // conditionally compiled so impl registration is in source file + + m.def("moe_permute_unpermute_supported() -> bool"); + m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported); #endif } diff --git a/csrc/ops.h b/csrc/ops.h index 1dfd2e067e85..7044b4588b81 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -59,6 +59,31 @@ void merge_attn_states(torch::Tensor& output, const torch::Tensor& prefix_lse, const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse); + +void convert_vertical_slash_indexes( + torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] + torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V] + torch::Tensor q_seqlens, // [BATCH, ] + torch::Tensor kv_seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int64_t context_size, int64_t block_size_M, int64_t block_size_N, + bool causal); + +void convert_vertical_slash_indexes_mergehead( + torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] + torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V] + torch::Tensor q_seqlens, // [BATCH, ] + torch::Tensor kv_seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + torch::Tensor vertical_indices_count, // [N_HEADS, ] + torch::Tensor slash_indices_count, int64_t context_size, + int64_t block_size_M, int64_t block_size_N, bool 
causal); #endif void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, @@ -208,6 +233,12 @@ void cutlass_moe_mm( torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& b_strides, torch::Tensor const& c_strides); +void cutlass_fp4_group_mm( + torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, + const torch::Tensor& alphas, const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets); + void get_cutlass_moe_mm_data( const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, @@ -235,6 +266,12 @@ std::vector cutlass_sparse_compress(torch::Tensor const& a); void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_scale, torch::Tensor const& input_scale); + +void scaled_fp4_experts_quant( + torch::Tensor& output, torch::Tensor& output_scale, + torch::Tensor const& input, torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts); #endif void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index ef6dd1c0978d..266f2a0667a2 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -44,7 +44,8 @@ inline __device__ void apply_rotary_embedding( // head_size] const scalar_t* cache_ptr, const int head_size, const int num_heads, const int num_kv_heads, const int rot_dim, const int token_idx, - const int64_t query_stride, const int64_t key_stride) { + const int64_t query_stride, const int64_t key_stride, + const int64_t head_stride) { const int embed_dim = rot_dim / 2; const scalar_t* cos_ptr = cache_ptr; const scalar_t* sin_ptr = cache_ptr + embed_dim; @@ -52,7 +53,8 @@ inline 
__device__ void apply_rotary_embedding( const int nq = num_heads * embed_dim; for (int i = threadIdx.x; i < nq; i += blockDim.x) { const int head_idx = i / embed_dim; - const int64_t token_head = token_idx * query_stride + head_idx * head_size; + const int64_t token_head = + token_idx * query_stride + head_idx * head_stride; const int rot_offset = i % embed_dim; apply_token_rotary_embedding( query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); @@ -62,7 +64,8 @@ inline __device__ void apply_rotary_embedding( const int nk = num_kv_heads * embed_dim; for (int i = threadIdx.x; i < nk; i += blockDim.x) { const int head_idx = i / embed_dim; - const int64_t token_head = token_idx * key_stride + head_idx * head_size; + const int64_t token_head = + token_idx * key_stride + head_idx * head_stride; const int rot_offset = i % embed_dim; apply_token_rotary_embedding( key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); @@ -84,7 +87,8 @@ __global__ void rotary_embedding_kernel( const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // // 2] const int rot_dim, const int64_t query_stride, const int64_t key_stride, - const int num_heads, const int num_kv_heads, const int head_size) { + const int64_t head_stride, const int num_heads, const int num_kv_heads, + const int head_size) { // Each thread block is responsible for one token. 
const int token_idx = blockIdx.x; int64_t pos = positions[token_idx]; @@ -92,7 +96,7 @@ __global__ void rotary_embedding_kernel( apply_rotary_embedding( query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, - token_idx, query_stride, key_stride); + token_idx, query_stride, key_stride, head_stride); } template @@ -109,9 +113,9 @@ __global__ void batched_rotary_embedding_kernel( const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // // 2] const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] - // or [num_tokens] const int rot_dim, const int64_t query_stride, const int64_t key_stride, - const int num_heads, const int num_kv_heads, const int head_size) { + const int64_t head_stride, const int num_heads, const int num_kv_heads, + const int head_size) { // Each thread block is responsible for one token. const int token_idx = blockIdx.x; int64_t pos = positions[token_idx]; @@ -121,7 +125,7 @@ __global__ void batched_rotary_embedding_kernel( apply_rotary_embedding( query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, - token_idx, query_stride, key_stride); + token_idx, query_stride, key_stride, head_stride); } } // namespace vllm @@ -179,6 +183,12 @@ void rotary_embedding( int seq_dim_idx = positions_ndim - 1; int64_t query_stride = query.stride(seq_dim_idx); int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0; + // Determine head stride: for [*, heads, head_size] use stride of last dim; + // for flat [*, heads*head_size], heads blocks are contiguous of size + // head_size + int query_ndim = query.dim(); + int64_t head_stride = + (query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size; dim3 grid(num_tokens); dim3 block(std::min(num_heads * rot_dim / 2, 512)); @@ -190,14 +200,14 @@ void rotary_embedding( positions.data_ptr(), query.data_ptr(), key.has_value() ? 
key->data_ptr() : nullptr, cos_sin_cache.data_ptr(), rot_dim, query_stride, key_stride, - num_heads, num_kv_heads, head_size); + head_stride, num_heads, num_kv_heads, head_size); } else { vllm::rotary_embedding_kernel <<>>( positions.data_ptr(), query.data_ptr(), key.has_value() ? key->data_ptr() : nullptr, cos_sin_cache.data_ptr(), rot_dim, query_stride, - key_stride, num_heads, num_kv_heads, head_size); + key_stride, head_stride, num_heads, num_kv_heads, head_size); } }); } @@ -263,6 +273,12 @@ void batched_rotary_embedding( int seq_dim_idx = positions_ndim - 1; int64_t query_stride = query.stride(seq_dim_idx); int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0; + // Determine head stride: for [*, heads, head_size] use stride of last dim; + // for flat [*, heads*head_size], heads blocks are contiguous of size + // head_size + int query_ndim = query.dim(); + int64_t head_stride = + (query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size; dim3 grid(num_tokens); dim3 block(std::min(num_heads * rot_dim / 2, 512)); @@ -276,7 +292,7 @@ void batched_rotary_embedding( key.has_value() ? key->data_ptr() : nullptr, cos_sin_cache.data_ptr(), cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, - key_stride, num_heads, num_kv_heads, head_size); + key_stride, head_stride, num_heads, num_kv_heads, head_size); } else { vllm::batched_rotary_embedding_kernel <<>>( @@ -284,7 +300,7 @@ void batched_rotary_embedding( key.has_value() ? 
key->data_ptr() : nullptr, cos_sin_cache.data_ptr(), cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, - key_stride, num_heads, num_kv_heads, head_size); + key_stride, head_stride, num_heads, num_kv_heads, head_size); } }); } diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu index acc3d6722028..67e9149c1379 100644 --- a/csrc/quantization/activation_kernels.cu +++ b/csrc/quantization/activation_kernels.cu @@ -112,7 +112,8 @@ __global__ void act_and_mul_quant_kernel( void silu_and_mul_quant(torch::Tensor& out, // [..., d] torch::Tensor& input, // [..., 2 * d] torch::Tensor& scale) { - TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn || + out.dtype() == torch::kFloat8_e4m3fnuz); TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16); TORCH_CHECK(input.size(-1) % 2 == 0); diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index e79785827189..bf46cce60a23 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -26,7 +26,13 @@ static inline __device__ int8_t float_to_int8_rn(float x) { float dst = std::nearbyint(x); // saturate - dst = std::clamp(dst, i8_min, i8_max); + + // See https://github.com/pytorch/pytorch/issues/127666 + // See https://github.com/llvm/llvm-project/issues/95183 + // hip-clang std::clamp __glibcxx_assert_fail host function when building on + // Arch/gcc14. The following replaces std::clamp usage with similar logic + // dst = std::clamp(dst, i8_min, i8_max); + dst = (dst < i8_min) ? i8_min : (dst > i8_max) ? 
i8_max : dst; return static_cast(dst); #else // CUDA path @@ -79,7 +85,13 @@ static inline __device__ int8_t int32_to_int8(int32_t x) { static_cast(std::numeric_limits::max()); // saturate - int32_t dst = std::clamp(x, i8_min, i8_max); + + // See https://github.com/pytorch/pytorch/issues/127666 + // See https://github.com/llvm/llvm-project/issues/95183 + // hip-clang std::clamp __glibcxx_assert_fail host function when building on + // Arch/gcc14. The following replaces std::clamp usage with similar logic + // int32_t dst = std::clamp(x, i8_min, i8_max); + int32_t dst = (x < i8_min) ? i8_min : (x > i8_max) ? i8_max : x; return static_cast(dst); #else // CUDA path diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu new file mode 100644 index 000000000000..84492553c02f --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu @@ -0,0 +1,27 @@ +#include "scaled_mm_kernels.hpp" +#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +namespace vllm { + +void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + TORCH_CHECK( + a.size(0) % 4 == 0, + "Input tensor must have a number of rows that is a multiple of 4. 
", + "but got: ", a.size(0), " rows."); + if (out.dtype() == torch::kBFloat16) { + cutlass_gemm_blockwise_sm100_fp8_dispatch( + out, a, b, a_scales, b_scales); + + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + cutlass_gemm_blockwise_sm100_fp8_dispatch( + out, a, b, a_scales, b_scales); + } +} + +} // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh new file mode 100644 index 000000000000..ef324364c6d5 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh @@ -0,0 +1,205 @@ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass_extensions/gemm/dispatch_policy.hpp" +#include "cutlass_extensions/gemm/collective/collective_builder.hpp" + +#include "cutlass_gemm_caller.cuh" + +namespace vllm { + +using namespace cute; + +template +struct cutlass_3x_gemm_fp8_blockwise { + using ElementAB = cutlass::float_e4m3_t; + + using ElementA = ElementAB; + using LayoutA = cutlass::layout::RowMajor; + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + + using ElementB = ElementAB; + using LayoutB = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + + using ElementC = void; + using ElementD = OutType; + using LayoutD = cutlass::layout::RowMajor; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + using LayoutC = LayoutD; 
+ static constexpr int AlignmentC = AlignmentD; + + using ElementAccumulator = float; + using ElementCompute = float; + using ElementBlockScale = float; + + // MMA and Cluster Tile Shapes + // Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster + // Shape %2 == 0 using MmaTileShape_MNK = Shape<_128,_128,_128>; + static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{}); + static constexpr int ScaleGranularityM = + size<0>(MmaTileShape{}) / ScaleMsPerTile; + static constexpr int ScaleGranularityN = + size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{}); + static constexpr int ScaleGranularityK = + size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{}); + + // Shape of the threadblocks in a cluster + using ClusterShape_MNK = ClusterShape; + + using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig< + ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, + cute::UMMA::Major::MN, cute::UMMA::Major::K>; + using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); + using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); + + using ArchTag = cutlass::arch::Sm100; + using OperatorClass = cutlass::arch::OpClassTensorOp; + + static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + using ElementScalar = float; + // clang-format off + using DefaultOperation = cutlass::epilogue::fusion::LinearCombination; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + MmaTileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementC, + LayoutC, + AlignmentC, + ElementD, + LayoutD, + AlignmentD, + EpilogueScheduler, + DefaultOperation + >::CollectiveOp; + + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementA, + cute::tuple, + AlignmentA, + ElementB, + cute::tuple, 
+ AlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopScheduler + >::CollectiveOp; + // clang-format on + + using KernelType = enable_sm100_only, CollectiveMainloop, CollectiveEpilogue>>; + + struct GemmKernel : public KernelType {}; +}; + +template +void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + using GemmKernel = typename Gemm::GemmKernel; + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideD = typename Gemm::GemmKernel::StrideD; + using StrideC = typename Gemm::GemmKernel::StrideC; + using LayoutSFA = typename Gemm::LayoutSFA; + using LayoutSFB = typename Gemm::LayoutSFB; + using ScaleConfig = typename Gemm::ScaleConfig; + + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + int32_t m = a.size(0), n = b.size(1), k = a.size(1); + auto prob_shape = cute::make_shape(m, n, k, 1); + + StrideA a_stride; + StrideB b_stride; + StrideC c_stride; + a_stride = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); + b_stride = + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1)); + c_stride = + cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1)); + + LayoutSFA layout_SFA = + ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1)); + LayoutSFB layout_SFB = + ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1)); + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto a_scales_ptr = static_cast(a_scales.data_ptr()); + auto b_scales_ptr = static_cast(b_scales.data_ptr()); + + typename GemmKernel::MainloopArguments mainloop_args{ + a_ptr, a_stride, b_ptr, b_stride, + a_scales_ptr, layout_SFA, b_scales_ptr, 
layout_SFB}; + + auto c_ptr = static_cast(out.data_ptr()); + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, c_ptr, c_stride, c_ptr, c_stride}; + c3x::cutlass_gemm_caller(a.device(), prob_shape, mainloop_args, + epilogue_args); +} + +template +void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + auto m = a.size(0); + auto k = a.size(1); + auto n = b.size(1); + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device()); + + auto should_use_2sm = [&sms](int m, int n, int tile1SM = 128) { + return std::ceil(static_cast(m) / tile1SM) * + std::ceil(static_cast(n) / tile1SM) >= + sms; + }; + bool use_2sm = should_use_2sm(m, n); + if (use_2sm) { + cutlass_gemm_caller_blockwise, Shape<_256, _1, _1>, + Shape<_2, _2, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm, + cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>( + out, a, b, a_scales, b_scales); + } else { + cutlass_gemm_caller_blockwise, Shape<_128, _1, _1>, + Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm, + cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>( + out, a, b, a_scales, b_scales); + } +} + +} // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp new file mode 100644 index 000000000000..2ee6a19407f9 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp @@ -0,0 +1,75 @@ +#include +#include "cuda_utils.h" +#include "cutlass_extensions/common.hpp" + +template +void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + std::optional const& bias, + Fp8Func fp8_func, Int8Func int8_func, + BlockwiseFunc blockwise_func) { + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == 
torch::kFloat32); + + int M = a.size(0), N = b.size(1), K = a.size(1); + + if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) && + (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) { + // Standard per-tensor/per-token/per-channel scaling + TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + if (a.dtype() == torch::kFloat8_e4m3fn) { + fp8_func(c, a, b, a_scales, b_scales, bias); + } else { + TORCH_CHECK(a.dtype() == torch::kInt8); + if constexpr (!std::is_same_v) { + int8_func(c, a, b, a_scales, b_scales, bias); + } else { + TORCH_CHECK(false, "Int8 not supported for this architecture"); + } + } + } else { + TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor."); + TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor."); + int32_t version_num = get_sm_version_num(); + if (version_num >= 100) { + TORCH_CHECK( + a.size(0) == a_scales.size(0) && + cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1), + "a_scale_group_shape must be [1, 128]."); + TORCH_CHECK( + cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) && + cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1), + "b_scale_group_shape must be [128, 128]."); + } else { + // TODO: Remove this after using cutlass sm90 blockwise scaling gemm + // kernel, or introducing ceil_div to the load_init() of mainloop. 
+ using GroupShape = std::array; + auto make_group_shape = [](torch::Tensor const& x, + torch::Tensor const& s) -> GroupShape { + TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D"); + return {cuda_utils::ceil_div(x.size(0), s.size(0)), + cuda_utils::ceil_div(x.size(1), s.size(1))}; + }; + + GroupShape a_scale_group_shape = make_group_shape(a, a_scales); + GroupShape b_scale_group_shape = make_group_shape(b, b_scales); + + // 1x128 per-token group scales for activations + // 128x128 blockwise scales for weights + TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} && + b_scale_group_shape == GroupShape{128, 128} && + a.dtype() == torch::kFloat8_e4m3fn && + b.dtype() == torch::kFloat8_e4m3fn), + "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n" + "a_scale_group_shape must be [1, 128]. Got: [", + a_scale_group_shape[0], ", ", a_scale_group_shape[1], + "]\n" + "b_scale_group_shape must be [128, 128]. Got: [", + b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]"); + } + + TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm"); + blockwise_func(c, a, b, a_scales, b_scales); + } +} diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp index 85272804774d..c1242fdb39da 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp @@ -36,4 +36,9 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b_scales, std::optional const& bias); +void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales); } // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu index 459eb1bb76eb..0cbd5305e3c2 100644 --- 
a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu @@ -1,8 +1,6 @@ -#include +#include "c3x/scaled_mm_helper.hpp" #include "c3x/scaled_mm_kernels.hpp" -#include "cuda_utils.h" - /* This file defines quantized GEMM operations using the CUTLASS 3.x API, for NVIDIA GPUs with sm100 (Blackwell). @@ -15,20 +13,10 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& a_scales, torch::Tensor const& b_scales, std::optional const& bias) { - TORCH_CHECK(a_scales.dtype() == torch::kFloat32); - TORCH_CHECK(b_scales.dtype() == torch::kFloat32); - - int M = a.size(0), N = b.size(1), K = a.size(1); - TORCH_CHECK( - (a_scales.numel() == 1 || a_scales.numel() == a.size(0)) && - (b_scales.numel() == 1 || b_scales.numel() == b.size(1)), - "Currently, block scaled fp8 gemm is not implemented for Blackwell"); - - // Standard per-tensor/per-token/per-channel scaling - TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); - TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn, - "Currently, only fp8 gemm is implemented for Blackwell"); - vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias); + dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias, + vllm::cutlass_scaled_mm_sm100_fp8, + nullptr, // int8 not supported on SM100 + vllm::cutlass_scaled_mm_blockwise_sm100_fp8); } #endif diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu index bcb91040d5e2..211302171f07 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu @@ -1,8 +1,6 @@ -#include +#include "c3x/scaled_mm_helper.hpp" #include "c3x/scaled_mm_kernels.hpp" -#include "cuda_utils.h" - /* This file defines quantized GEMM operations using the CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper). 
@@ -15,49 +13,10 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& a_scales, torch::Tensor const& b_scales, std::optional const& bias) { - TORCH_CHECK(a_scales.dtype() == torch::kFloat32); - TORCH_CHECK(b_scales.dtype() == torch::kFloat32); - - int M = a.size(0), N = b.size(1), K = a.size(1); - - if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) && - (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) { - // Standard per-tensor/per-token/per-channel scaling - TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); - if (a.dtype() == torch::kFloat8_e4m3fn) { - vllm::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias); - } else { - TORCH_CHECK(a.dtype() == torch::kInt8); - vllm::cutlass_scaled_mm_sm90_int8(c, a, b, a_scales, b_scales, bias); - } - } else { - using GroupShape = std::array; - auto make_group_shape = [](torch::Tensor const& x, - torch::Tensor const& s) -> GroupShape { - TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D"); - return {cuda_utils::ceil_div(x.size(0), s.size(0)), - cuda_utils::ceil_div(x.size(1), s.size(1))}; - }; - - GroupShape a_scale_group_shape = make_group_shape(a, a_scales); - GroupShape b_scale_group_shape = make_group_shape(b, b_scales); - - // 1x128 per-token group scales for activations - // 128x128 blockwise scales for weights - TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} && - b_scale_group_shape == GroupShape{128, 128} && - a.dtype() == torch::kFloat8_e4m3fn && - b.dtype() == torch::kFloat8_e4m3fn), - "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n" - "a_scale_group_shape must be [1, 128]. Got: [", - a_scale_group_shape[0], ", ", a_scale_group_shape[1], - "]\n" - "b_scale_group_shape must be [128, 128]. 
Got: [", - b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]"); - TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm"); - - vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales); - } + dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias, + vllm::cutlass_scaled_mm_sm90_fp8, + vllm::cutlass_scaled_mm_sm90_int8, + vllm::cutlass_scaled_mm_blockwise_sm90_fp8); } void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a, diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 54b63894e4cb..e9b408fbf2ee 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -29,7 +29,8 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& a_scales, torch::Tensor const& b_scales, std::optional const& bias); - +#endif +#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 void cutlass_moe_mm_sm90( torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, torch::Tensor const& a_scales, @@ -37,12 +38,6 @@ void cutlass_moe_mm_sm90( torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& b_strides, torch::Tensor const& c_strides); -void get_cutlass_moe_mm_data_caller( - const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, - torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, - torch::Tensor& input_permutation, torch::Tensor& output_permutation, - const int64_t num_experts, const int64_t n, const int64_t k); - #endif #if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100 @@ -53,6 +48,15 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a, std::optional const& bias); #endif +#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \ + defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100 +void get_cutlass_moe_mm_data_caller( + const 
torch::Tensor& topk_ids, torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, + torch::Tensor& input_permutation, torch::Tensor& output_permutation, + const int64_t num_experts, const int64_t n, const int64_t k); +#endif + void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, @@ -110,6 +114,8 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) { #if defined CUDA_VERSION if (cuda_device_capability >= 90 && cuda_device_capability < 100) { return CUDA_VERSION >= 12000; + } else if (cuda_device_capability >= 100) { + return CUDA_VERSION >= 12080; } #endif @@ -117,7 +123,7 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) { } bool cutlass_group_gemm_supported(int64_t cuda_device_capability) { - // CUTLASS groped FP8 kernels need at least CUDA 12.3 + // CUTLASS grouped FP8 kernels need at least CUDA 12.3 // and SM90 (Hopper) #if defined CUDA_VERSION @@ -222,7 +228,8 @@ void get_cutlass_moe_mm_data( // This function currently gets compiled only if we have a valid cutlass moe // mm to run it for. 
int32_t version_num = get_sm_version_num(); -#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 +#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ + (defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90) get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1, problem_sizes2, input_permutation, output_permutation, num_experts, n, k); diff --git a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu new file mode 100644 index 000000000000..45ec3d29ce04 --- /dev/null +++ b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu @@ -0,0 +1,402 @@ +#include +#include + +#include +#include +#include + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include + +using namespace cute; + +template +__global__ void __get_group_gemm_starts( + ElementAB** a_offsets, ElementAB** b_offsets, ElementC** out_offsets, + ElementSF** a_scales_offsets, ElementSF** b_scales_offsets, + ElementAccumulator** 
alpha_offsets, LayoutSFA* layout_sfa_base_as_int, + LayoutSFB* layout_sfb_base_as_int, ElementAB* a_base_as_int, + ElementAB* b_base_as_int, ElementC* out_base_as_int, + ElementSF* a_scales_base_as_int, ElementSF* b_scales_base_as_int, + ElementAccumulator* alphas_base_as_int, const int32_t* expert_offsets, + const int32_t* sf_offsets, const int32_t* problem_sizes_as_shapes, + const int K, const int N) { + int64_t expert_id = threadIdx.x; + if (expert_id >= gridDim.x * blockDim.x) { + return; + } + // Originally int32_t but upcasting to int64_t to avoid overflow + // during offset calculations + int64_t expert_offset = static_cast(expert_offsets[expert_id]); + int64_t sf_offset = static_cast(sf_offsets[expert_id]); + // size for block in block scale. + int64_t group_size = 16; + int64_t m = static_cast(problem_sizes_as_shapes[expert_id * 3]); + int64_t n = static_cast(problem_sizes_as_shapes[expert_id * 3 + 1]); + int64_t k = static_cast(problem_sizes_as_shapes[expert_id * 3 + 2]); + assert((m >= 0 && n == N && k == K && k % 2 == 0) && + "unexpected problem sizes"); + + int64_t half_k = static_cast(k / 2); + int64_t group_k = static_cast(k / group_size); + // Shape of A as uint8/byte = [M, K // 2] + // Shape of B as uint8/byte = [E, N, K // 2] + a_offsets[expert_id] = a_base_as_int + expert_offset * half_k; + + b_offsets[expert_id] = b_base_as_int + expert_id * n * half_k; + // Shape of C = [M, N] + out_offsets[expert_id] = out_base_as_int + expert_offset * n; + // Shape of a_scale = [sum(sf_sizes), K // group_size] + a_scales_offsets[expert_id] = a_scales_base_as_int + sf_offset * group_k; + + assert((reinterpret_cast(a_scales_offsets[expert_id]) % 128) == + 0 && + "TMA requires 128-byte alignment"); + + // Shape of B scale = [E, N, K // group_size] + b_scales_offsets[expert_id] = b_scales_base_as_int + expert_id * n * group_k; + assert((reinterpret_cast(b_scales_offsets[expert_id]) % 128) == + 0 && + "TMA requires 128-byte alignment"); + // Shape of alpha = [E] + 
alpha_offsets[expert_id] = alphas_base_as_int + expert_id; + + LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id; + LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id; + + *layout_sfa_ptr = ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape( + static_cast(m), static_cast(n), static_cast(k), 1)); + *layout_sfb_ptr = ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape( + static_cast(m), static_cast(n), static_cast(k), 1)); +} + +#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE(ELEMENT_AB_TYPE, SF_TYPE, \ + TENSOR_C_TYPE, C_TYPE, LayoutSFA, \ + LayoutSFB, ScaleConfig) \ + else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ + __get_group_gemm_starts \ + <<<1, num_experts, 0, stream>>>( \ + static_cast(a_starts.data_ptr()), \ + static_cast(b_starts.data_ptr()), \ + static_cast(out_starts.data_ptr()), \ + static_cast(a_scales_starts.data_ptr()), \ + static_cast(b_scales_starts.data_ptr()), \ + static_cast(alpha_starts.data_ptr()), \ + reinterpret_cast(layout_sfa.data_ptr()), \ + reinterpret_cast(layout_sfb.data_ptr()), \ + static_cast(a_tensors.data_ptr()), \ + static_cast(b_tensors.data_ptr()), \ + static_cast(out_tensors.data_ptr()), \ + static_cast(a_scales.data_ptr()), \ + static_cast(b_scales.data_ptr()), \ + static_cast(alphas.data_ptr()), \ + static_cast(expert_offsets.data_ptr()), \ + static_cast(sf_offsets.data_ptr()), \ + static_cast(problem_sizes.data_ptr()), K, N); \ + } + +template +void run_get_group_gemm_starts( + const torch::Tensor& a_starts, const torch::Tensor& b_starts, + const torch::Tensor& out_starts, const torch::Tensor& a_scales_starts, + const torch::Tensor& b_scales_starts, const torch::Tensor& alpha_starts, + const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb, + /*these are used for their base addresses*/ + torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, + torch::Tensor const& out_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& alphas, + 
torch::Tensor const& expert_offsets, torch::Tensor const& sf_offsets, + torch::Tensor const& problem_sizes, int M, int N, int K) { + int num_experts = (int)expert_offsets.size(0); + auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); + + TORCH_CHECK(out_tensors.size(1) == N, + "Output tensor shape doesn't match expected shape"); + TORCH_CHECK(K / 2 == b_tensors.size(2), + "b_tensors(dim = 2) and a_tensors(dim = 1) trailing" + " dimension must match"); + if (false) { + } + //(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, + // ScaleConfig) + __CALL_GET_STARTS_KERNEL_BLOCKSCALE( + cutlass::float_e2m1_t, cutlass::float_ue4m3_t, torch::kBFloat16, + cutlass::bfloat16_t, LayoutSFA, LayoutSFB, ScaleConfig) + __CALL_GET_STARTS_KERNEL_BLOCKSCALE(cutlass::float_e2m1_t, + cutlass::float_ue4m3_t, torch::kFloat16, + half, LayoutSFA, LayoutSFB, ScaleConfig) + else { + TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)"); + } +} + +template +void run_fp4_blockwise_scaled_group_mm( + torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, + const torch::Tensor& alphas, const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M, + int N, int K) { + using ProblemShape = + cutlass::gemm::GroupProblemShape>; + using ElementType = cutlass::float_e2m1_t; + using ElementSFType = cutlass::float_ue4m3_t; + using ElementA = cutlass::nv_float4_t; + using ElementB = cutlass::nv_float4_t; + + using ElementC = OutType; + using ElementD = ElementC; + using ElementAccumulator = float; + // Layout definitions + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = LayoutC; + + // Alignment constraints + static constexpr int AlignmentA = 32; + static constexpr int AlignmentB = 32; + static 
constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + // Architecture definitions + using ArchTag = cutlass::arch::Sm100; + using EpilogueOperatorClass = + cutlass::arch::OpClassTensorOp; // Epilogue Operator class tag + using MainloopOperatorClass = + cutlass::arch::OpClassBlockScaledTensorOp; // Mainloop Operator class tag + using StageCountType = + cutlass::gemm::collective::StageCountAuto; // Stage count maximized based + // on the tile size + + using ClusterShape = Shape<_1, _1, _1>; + struct MMA1SMConfig { + using MmaTileShape = Shape<_128, _128, _128>; + using KernelSchedule = cutlass::gemm:: + KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100; // Kernel to launch + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + }; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, EpilogueOperatorClass, typename MMA1SMConfig::MmaTileShape, + ClusterShape, Shape<_128, _64>, ElementAccumulator, + ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD, + LayoutC*, AlignmentD, + typename MMA1SMConfig::EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, MainloopOperatorClass, ElementA, LayoutA*, AlignmentA, + ElementB, LayoutB*, AlignmentB, ElementAccumulator, + typename MMA1SMConfig::MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + typename MMA1SMConfig::KernelSchedule>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal; + + using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter; + using Gemm = Gemm1SM; + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using 
StrideD = typename Gemm::GemmKernel::InternalStrideD; + + using LayoutSFA = + typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA; + using LayoutSFB = + typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB; + using ScaleConfig = + typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape; + int num_experts = static_cast(expert_offsets.size(0)); + auto options_int = + torch::TensorOptions().dtype(torch::kInt64).device(a.device()); + + torch::Tensor a_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_ptrs = torch::empty(num_experts, options_int); + torch::Tensor out_ptrs = torch::empty(num_experts, options_int); + torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int); + torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int); + torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int); + torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int); + torch::Tensor c_strides1 = + torch::full({num_experts}, output.stride(0), options_int); + torch::Tensor a_strides1 = + torch::full({num_experts}, a.stride(0) * 2, options_int); + torch::Tensor b_strides1 = + torch::full({num_experts}, b.stride(1) * 2, options_int); + + run_get_group_gemm_starts( + a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs, + layout_sfa, layout_sfb, a, b, output, a_blockscale, b_blockscales, alphas, + expert_offsets, sf_offsets, problem_sizes, M, N, K); + + // Create an instance of the GEMM + Gemm gemm_op; + + // Initialize problem_sizes_as_shapes correctly + UnderlyingProblemShape* problem_sizes_as_shapes = + static_cast(problem_sizes.data_ptr()); + + // Set the Scheduler info + cutlass::KernelHardwareInfo hw_info; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail:: + PersistentTileSchedulerSm100GroupParams< + typename 
ProblemShape::UnderlyingProblemShape>::RasterOrderOptions; + typename Gemm::GemmKernel::TileSchedulerArguments scheduler; + scheduler.raster_order = RasterOrderOptions::AlongM; + hw_info.device_id = a.get_device(); + static std::unordered_map cached_sm_counts; + if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) { + cached_sm_counts[hw_info.device_id] = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + } + hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX); + + // Mainloop Arguments + typename GemmKernel::MainloopArguments mainloop_args{ + static_cast(a_ptrs.data_ptr()), + static_cast(a_strides1.data_ptr()), + static_cast(b_ptrs.data_ptr()), + static_cast(b_strides1.data_ptr()), + static_cast(a_scales_ptrs.data_ptr()), + reinterpret_cast(layout_sfa.data_ptr()), + static_cast(b_scales_ptrs.data_ptr()), + reinterpret_cast(layout_sfb.data_ptr())}; + + // Epilogue Arguments + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, // epilogue.thread + nullptr, + static_cast(c_strides1.data_ptr()), + static_cast(out_ptrs.data_ptr()), + static_cast(c_strides1.data_ptr())}; + auto& fusion_args = epilogue_args.thread; + fusion_args.alpha_ptr_array = + reinterpret_cast(alpha_ptrs.data_ptr()); + fusion_args.dAlpha = {_0{}, _0{}, 1}; + + // Gemm Arguments + typename GemmKernel::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {num_experts, problem_sizes_as_shapes, nullptr}, + mainloop_args, + epilogue_args, + hw_info, + scheduler}; + + size_t workspace_size = Gemm::get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + auto can_implement_status = gemm_op.can_implement(args); + TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, + "Failed to implement 
GEMM"); + + // Run the GEMM + auto status = gemm_op.initialize(args, workspace.data_ptr()); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM"); + + status = gemm_op.run(args, workspace.data_ptr(), stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); +} + +constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte; +constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn; + +#define CHECK_TYPE(x, st, m) \ + TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m) +#define CHECK_TH_CUDA(x, m) \ + TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.") +#define CHECK_CONTIGUOUS(x, m) \ + TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.") +#define CHECK_INPUT(x, st, m) \ + CHECK_TH_CUDA(x, m); \ + CHECK_CONTIGUOUS(x, m); \ + CHECK_TYPE(x, st, m) + +void cutlass_fp4_group_mm( + torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, + const torch::Tensor& alphas, const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) { +#if defined ENABLE_NVFP4 && ENABLE_NVFP4 + // Input validation + CHECK_INPUT(a, FLOAT4_E2M1X2, "a"); + CHECK_INPUT(b, FLOAT4_E2M1X2, "b"); + CHECK_INPUT(a_blockscale, SF_DTYPE, "a_blockscale"); + CHECK_INPUT(b_blockscales, SF_DTYPE, "b_blockscales"); + CHECK_INPUT(alphas, at::ScalarType::Float, "alphas"); + + TORCH_CHECK(a_blockscale.dim() == 2, + "expected a_blockscale to be of shape [num_experts, rounded_m," + " k // group_size], observed rank: ", + a_blockscale.dim()) + TORCH_CHECK(b_blockscales.dim() == 3, + "expected b_blockscale to be of shape: " + " [num_experts, n, k // group_size], observed rank: ", + b_blockscales.dim()) + TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor"); + TORCH_CHECK(problem_sizes.size(1) == 3, + "problem_sizes must have the shape (num_experts, 3)"); + TORCH_CHECK(problem_sizes.size(0) 
== expert_offsets.size(0), + "Number of experts in problem_sizes must match expert_offsets"); + TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, + "problem_sizes must be int32."); + + int M = static_cast(a.size(0)); + int N = static_cast(b.size(1)); + int E = static_cast(b.size(0)); + int K = static_cast(2 * b.size(2)); + + if (output.scalar_type() == torch::kBFloat16) { + run_fp4_blockwise_scaled_group_mm( + output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes, + expert_offsets, sf_offsets, M, N, K); + } else { + run_fp4_blockwise_scaled_group_mm( + output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes, + expert_offsets, sf_offsets, M, N, K); + } +#else + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled cutlass_fp4_group_mm kernel, vLLM must " + "be compiled with ENABLE_NVFP4 for SM100+ and CUDA " + "12.8 or above."); +#endif +} diff --git a/csrc/quantization/fp4/nvfp4_experts_quant.cu b/csrc/quantization/fp4/nvfp4_experts_quant.cu new file mode 100644 index 000000000000..076c4a085337 --- /dev/null +++ b/csrc/quantization/fp4/nvfp4_experts_quant.cu @@ -0,0 +1,404 @@ +#include + +#include +#include + +#include +#include + +template +struct TypeConverter { + using Type = half2; +}; // keep for generality + +template <> +struct TypeConverter { + using Type = half; +}; + +template <> +struct TypeConverter { + using Type = half2; +}; + +template <> +struct TypeConverter<__nv_bfloat162> { + using Type = __nv_bfloat16; +}; + +template <> +struct TypeConverter<__nv_bfloat16> { + using Type = __nv_bfloat162; +}; + +#define ELTS_PER_THREAD 8 + +constexpr int CVT_FP4_ELTS_PER_THREAD = 8; +constexpr int CVT_FP4_SF_VEC_SIZE = 16; + +// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). 
+inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), + "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7])); + return val; +#else + return 0; +#endif +} + +// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), + "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); + return val; +#else + return 0; +#endif +} + +// Fast reciprocal. +inline __device__ float reciprocal_approximate_ftz(float a) { + float b; + asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); + return b; +} + +template +__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, + int numCols, + SFType* SFout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || + CVT_FP4_NUM_THREADS_PER_SF == 2); + + // One pair of threads write one SF to global memory. 
+ // TODO: stage through smem for packed STG.32 + // is it better than STG.8 from 4 threads ? + if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { + // SF vector index (16 elements share one SF in the K dimension). + int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; + int32_t mIdx = rowIdx; + + // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] + // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] + + int32_t mTileIdx = mIdx / (32 * 4); + // SF vector size 16. + int factor = CVT_FP4_SF_VEC_SIZE * 4; + int32_t numKTiles = (numCols + factor - 1) / factor; + int64_t mTileStride = numKTiles * 32 * 4 * 4; + + int32_t kTileIdx = (kIdx / 4); + int64_t kTileStride = 32 * 4 * 4; + + // M tile layout [32, 4] is column-major. + int32_t outerMIdx = (mIdx % 32); + int64_t outerMStride = 4 * 4; + + int32_t innerMIdx = (mIdx % (32 * 4)) / 32; + int64_t innerMStride = 4; + + int32_t innerKIdx = (kIdx % 4); + int64_t innerKStride = 1; + + // Compute the global offset. + int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + + outerMIdx * outerMStride + innerMIdx * innerMStride + + innerKIdx * innerKStride; + + return reinterpret_cast(SFout) + SFOffset; + } +#endif + return nullptr; +} + +// Define a 16 bytes packed data type. +template +struct PackedVec { + typename TypeConverter::Type elts[4]; +}; + +template <> +struct PackedVec<__nv_fp8_e4m3> { + __nv_fp8x2_e4m3 elts[8]; +}; + +// Quantizes the provided PackedVec into the uint32_t output +template +__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, + uint8_t* SFout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + // Get absolute maximum values among the local 8 values. + auto localMax = __habs2(vec.elts[0]); + + // Local maximum value. + #pragma unroll + for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + localMax = __hmax2(localMax, __habs2(vec.elts[i])); + } + + // Get the absolute maximum among all 16 values (two threads). 
+ localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); + // Get the final absolute maximum values. + float vecMax = float(__hmax(localMax.x, localMax.y)); + + // Get the SF (max value of the vector / max value of e2m1). + // maximum value of e2m1 = 6.0. + // TODO: use half as compute data type. + float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); + // 8 bits representation of the SF. + uint8_t fp8SFVal; + // Write the SF to global memory (STG.8). + if constexpr (UE8M0_SF) { + // Extract the 8 exponent bits from float32. + // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits. + uint32_t tmp = reinterpret_cast(SFValue) >> 23; + fp8SFVal = tmp & 0xff; + // Convert back to fp32. + reinterpret_cast(SFValue) = tmp << 23; + } else { + // Here SFValue is always positive, so E4M3 is the same as UE4M3. + __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); + reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp; + // Convert back to fp32. + SFValue = float(tmp); + } + // Get the output scale. + // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * + // reciprocal(SFScaleVal)) + float outputScale = + SFValue != 0 ? reciprocal_approximate_ftz( + SFValue * reciprocal_approximate_ftz(SFScaleVal)) + : 0.0f; + + if (SFout) { + // Write the SF to global memory (STG.8). + *SFout = fp8SFVal; + } + + // Convert the input to float. + float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; + + #pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + if constexpr (std::is_same_v) { + fp2Vals[i] = __half22float2(vec.elts[i]); + } else { + fp2Vals[i] = __bfloat1622float2(vec.elts[i]); + } + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + // Convert to e2m1 values. + uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); + + // Write the e2m1 values to global memory. + return e2m1Vec; +#else + return 0; +#endif +} + +// Use UE4M3 by default. 
+template +__global__ void +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__launch_bounds__(512, 4) cvt_fp16_to_fp4( +#else +cvt_fp16_to_fp4( +#endif + int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, + uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts, + uint32_t* output_scale_offset_by_experts, int n_experts) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + using PackedVec = PackedVec; + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = + (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, + "Vec size is not matched."); + + // Input tensor row/col loops. + for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { + for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD; + colIdx += blockDim.x) { + int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; + PackedVec in_vec = reinterpret_cast(in)[inOffset]; + // Get the output tensor offset. + // Same as inOffset because 8 elements are packed into one uint32_t. + int64_t outOffset = inOffset; + auto& out_pos = out[outOffset]; + + // Find index within the experts. + int rowIdx_in_expert = 0; + int expert_idx = 0; + for (int i = 0; i < n_experts; i++) { + if (rowIdx >= input_offset_by_experts[i] && + rowIdx < input_offset_by_experts[i + 1]) { + rowIdx_in_expert = rowIdx - input_offset_by_experts[i]; + expert_idx = i; + break; + } + } + + // Get the global scaling factor, which will be applied to the SF. + // Note SFScale is the same as next GEMM's alpha, which is + // (448.f / (Alpha_A / 6.f)). + float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx]; + + int factor = CVT_FP4_SF_VEC_SIZE * 4; + // The actual output_scales dim is computed from the padded numCols. 
+ int32_t numCols_padded = (numCols + factor - 1) / factor * factor; + int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4; + uint32_t* SFout_in_expert = + SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout; + + auto sf_out = + cvt_quant_to_fp4_get_sf_out_offset( + rowIdx_in_expert, colIdx, numCols, SFout_in_expert); + + out_pos = + cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); + } + } +#endif +} + +template +void quant_impl(void* output, void* output_scale, void* input, + void* input_global_scale, void* input_offset_by_experts, + void* output_scale_offset_by_experts, int m_topk, int k, + int n_experts, cudaStream_t stream) { + // TODO: this multiProcessorCount should be cached. + int device; + cudaGetDevice(&device); + int multiProcessorCount; + cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount, + device); + + // Grid, Block size. + // Each thread converts 8 values. + dim3 block(std::min(int(k / ELTS_PER_THREAD), 512)); + // Get number of blocks per SM (assume we can fully utilize the SM). 
+ int const numBlocksPerSM = 2048 / block.x; + dim3 grid(std::min(int(m_topk), multiProcessorCount * numBlocksPerSM)); + + cvt_fp16_to_fp4<<>>( + m_topk, k, reinterpret_cast(input), + reinterpret_cast(input_global_scale), + reinterpret_cast(output), + reinterpret_cast(output_scale), + reinterpret_cast(input_offset_by_experts), + reinterpret_cast(output_scale_offset_by_experts), n_experts); +} + +/*Quantization entry for fp4 experts quantization*/ +#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x, m) \ + TORCH_CHECK(x.is_contiguous(), m, "must be contiguous") +#define CHECK_INPUT(x, m) \ + CHECK_TH_CUDA(x, m); \ + CHECK_CONTIGUOUS(x, m); + +constexpr auto HALF = at::ScalarType::Half; +constexpr auto BF16 = at::ScalarType::BFloat16; +constexpr auto FLOAT = at::ScalarType::Float; +constexpr auto INT = at::ScalarType::Int; +constexpr auto UINT8 = at::ScalarType::Byte; + +void scaled_fp4_experts_quant_sm100a( + torch::Tensor& output, torch::Tensor& output_scale, + torch::Tensor const& input, torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts) { + CHECK_INPUT(output, "output must be a CUDA tensor"); + CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor"); + CHECK_INPUT(input, "input must be a CUDA tensor"); + CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor"); + CHECK_INPUT(input_offset_by_experts, + "input_offset_by_experts must be a CUDA tensor"); + CHECK_INPUT(output_scale_offset_by_experts, + "output_scale_offset_by_experts must be a CUDA tensor"); + + TORCH_CHECK(output.dim() == 2); + TORCH_CHECK(output_scale.dim() == 2); + TORCH_CHECK(input.dim() == 2); + TORCH_CHECK(input_global_scale.dim() == 1); + TORCH_CHECK(input_offset_by_experts.dim() == 1); + TORCH_CHECK(output_scale_offset_by_experts.dim() == 1); + + TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16); 
+ TORCH_CHECK(input_global_scale.scalar_type() == FLOAT); + TORCH_CHECK(input_offset_by_experts.scalar_type() == INT); + TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT); + // output is uint8 (two nvfp4 values are packed into one uint8) + // output_scale is int32 (four fp8 values are packed into one int32) + TORCH_CHECK(output.scalar_type() == UINT8); + TORCH_CHECK(output_scale.scalar_type() == INT); + + const int BLOCK_SIZE = 16; + auto m_topk = input.size(0); + auto k = input.size(1); + TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16"); + auto n_experts = input_global_scale.size(0); + TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1); + TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1); + TORCH_CHECK(output.size(0) == m_topk); + TORCH_CHECK(output.size(1) == k / 2); + int scales_k = k / BLOCK_SIZE; + // 4 means the swizzle requirement by nvidia nvfp4. + int padded_k = (scales_k + (4 - 1)) / 4 * 4; + // 4 means 4 fp8 values are packed into one int32 + TORCH_CHECK(output_scale.size(1) * 4 == padded_k); + + auto in_dtype = input.dtype(); + at::cuda::CUDAGuard device_guard{(char)input.get_device()}; + const cudaStream_t stream = + at::cuda::getCurrentCUDAStream(input.get_device()); + if (in_dtype == at::ScalarType::Half) { + quant_impl(output.data_ptr(), output_scale.data_ptr(), + input.data_ptr(), input_global_scale.data_ptr(), + input_offset_by_experts.data_ptr(), + output_scale_offset_by_experts.data_ptr(), m_topk, k, + n_experts, stream); + } else if (in_dtype == at::ScalarType::BFloat16) { + quant_impl<__nv_bfloat16>(output.data_ptr(), output_scale.data_ptr(), + input.data_ptr(), input_global_scale.data_ptr(), + input_offset_by_experts.data_ptr(), + output_scale_offset_by_experts.data_ptr(), m_topk, + k, n_experts, stream); + } else { + TORCH_CHECK(false, "Expected input data type to be half or bfloat16"); + } +} \ No newline at end of file diff --git a/csrc/quantization/fp4/nvfp4_quant_entry.cu 
b/csrc/quantization/fp4/nvfp4_quant_entry.cu index b1426c43b456..badbb7e310df 100644 --- a/csrc/quantization/fp4/nvfp4_quant_entry.cu +++ b/csrc/quantization/fp4/nvfp4_quant_entry.cu @@ -23,10 +23,32 @@ void scaled_fp4_quant_sm100a(torch::Tensor const& output, torch::Tensor const& input_sf); #endif +#if defined ENABLE_NVFP4 && ENABLE_NVFP4 +void scaled_fp4_experts_quant_sm100a( + torch::Tensor& output, torch::Tensor& output_scale, + torch::Tensor const& input, torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts); +#endif + void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf) { #if defined ENABLE_NVFP4 && ENABLE_NVFP4 return scaled_fp4_quant_sm100a(output, input, output_sf, input_sf); #endif - TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization"); + TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel"); +} + +void scaled_fp4_experts_quant( + torch::Tensor& output, torch::Tensor& output_scale, + torch::Tensor const& input, torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts) { +#if defined ENABLE_NVFP4 && ENABLE_NVFP4 + return scaled_fp4_experts_quant_sm100a( + output, output_scale, input, input_global_scale, input_offset_by_experts, + output_scale_offset_by_experts); +#endif + TORCH_CHECK_NOT_IMPLEMENTED(false, + "No compiled nvfp4 experts quantization kernel"); } diff --git a/csrc/quantization/fused_kernels/quant_conversions.cuh b/csrc/quantization/fused_kernels/quant_conversions.cuh index 7c10aaa81cf7..4e6118e52e8d 100644 --- a/csrc/quantization/fused_kernels/quant_conversions.cuh +++ b/csrc/quantization/fused_kernels/quant_conversions.cuh @@ -21,7 +21,13 @@ static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) { // round float dst = std::nearbyint(x); // 
saturate - dst = std::clamp(dst, i8_min, i8_max); + + // See https://github.com/pytorch/pytorch/issues/127666 + // See https://github.com/llvm/llvm-project/issues/95183 + // hip-clang std::clamp __glibcxx_assert_fail host function when building on + // Arch/gcc14. The following replaces std::clamp usage with similar logic + // dst = std::clamp(dst, i8_min, i8_max); + dst = (dst < i8_min) ? i8_min : (dst > i8_max) ? i8_max : dst; return static_cast(dst); #else // CUDA path diff --git a/csrc/quantization/gptq_marlin/dequant.h b/csrc/quantization/gptq_marlin/dequant.h index 3c0d77ac345d..ae0d6c0f2002 100644 --- a/csrc/quantization/gptq_marlin/dequant.h +++ b/csrc/quantization/gptq_marlin/dequant.h @@ -1,3 +1,67 @@ +/* +Fast Dequantization (Converting INT4/INT8/FP4/FP8 to FP16/BF16) + +The process of fast dequantization can be summarized as a combination +of bitwise operations and floating-point computations: + +weight =>(bit_op / bitwise operations)=> +f16_value =>(flop / floating-point computation)=> +dequantized_weight + +Since the dequantized weights typically require subtracting the zero point and +applying a scale factor, the floating-point computation step can be fused with +the zero-point subtraction and scaling operations. + +The following are the parts that need to be modified for the fused operation +of zero-point subtraction and scaling. + +## INT4 => FP16/BF16 or INT8 => FP16 + +The floating-point computation is `__hsub2` + +If has zero points: + + flop(bit_op(weight)) - flop(bit_op(zp)) + = sub(bit_op(weight), bias) - sub(bit_op(zp), bias) + = bit_op(weight) - bit_op(zp) + +so we don't need additional modification. + +If has float zero points: + + flop(bit_op(weight)) - fzp + = sub(bit_op(weight), bias) - fzp + = bit_op(weight) - (fzp + bias) + +where the `fzp + bias` can be computed at weight loading. But this +may have accuracy issue, so we should not use this in most cases. 
+ +If has not zero points: + + scale(flop(bit_op(weight))) + = scale(sub(bit_op(weight), bias)) + = scale(bit_op(weight)) - scale(bias) + = fma(bit_op(weight), scale_factor, scale(bias)) + +where the `scale(bias)` can be cached. But this may have accuracy issue, +so we should not use this in most cases. + + +## INT8 => BF16 + +INT8 => BF16 is a special case, it use byte_perm instead of flop. +We cannot fused byte_perm with scaling. + + +## FP4/FP8 => FP16/BF16 + + scale(flop(bit_op(weight))) + = scale(mul(bit_op(weight), multiplier)) + = mul(bit_op(weight), scale_factor * multiplier) + +where `scale_factor * multiplier` can be computed at weight loading. + +*/ #include "marlin_dtypes.cuh" @@ -27,7 +91,8 @@ __device__ inline uint32_t prmt(uint32_t a) { return res; } -template +template __device__ inline void dequant(int q, scalar_t2* frag_b); // @@ -40,7 +105,22 @@ __device__ inline void dequant(int q, scalar_t2* frag_b); // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 // template <> -__device__ inline void dequant(int q, half2* frag_b) { +__device__ inline void dequant(int q, + half2* frag_b) { + const int MASK = 0x000f000f; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. 
+ int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template <> +__device__ inline void dequant(int q, + half2* frag_b) { const int LO = 0x000f000f; const int HI = 0x00f000f0; const int EX = 0x64006400; @@ -62,7 +142,14 @@ __device__ inline void dequant(int q, half2* frag_b) { } template <> -__device__ inline void dequant(int q, half2* frag_b) { +__device__ inline void dequant(int q, + half2* frag_b) { + dequant(q, frag_b); +} + +template <> +__device__ inline void dequant(int q, + half2* frag_b) { const int LO = 0x000f000f; const int HI = 0x00f000f0; const int EX = 0x64006400; @@ -84,7 +171,7 @@ __device__ inline void dequant(int q, half2* frag_b) { } template <> -__device__ inline void dequant( +__device__ inline void dequant( int q, nv_bfloat162* frag_b) { static constexpr uint32_t MASK = 0x000f000f; static constexpr uint32_t EX = 0x43004300; @@ -96,39 +183,36 @@ __device__ inline void dequant( int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); // clang-format on - static constexpr uint32_t MUL = 0x3F803F80; - static constexpr uint32_t ADD = 0xC308C308; + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); - frag_b[0] = __hfma2(*reinterpret_cast(&lo), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); + static constexpr uint32_t SUB = 0x43084308; + + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); } template <> -__device__ inline void dequant( +__device__ inline void dequant( int q, nv_bfloat162* frag_b) { - static constexpr uint32_t MASK = 0x000f000f; - static constexpr uint32_t EX = 0x43004300; + dequant(q, frag_b); +} - 
// Guarantee that the `(a & b) | c` operations are LOP3s. - // clang-format off - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - q >>= 4; - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - // clang-format on +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); - static constexpr uint32_t MUL = 0x3F803F80; - static constexpr uint32_t ADD = 0xC300C300; + static constexpr uint32_t SUB = 0x43004300; - frag_b[0] = __hfma2(*reinterpret_cast(&lo), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); } // @@ -140,8 +224,8 @@ __device__ inline void dequant( // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 // template <> -__device__ inline void dequant(int q, - half2* frag_b) { +__device__ inline void dequant(int q, + half2* frag_b) { static constexpr uint32_t mask_for_elt_01 = 0x5250; static constexpr uint32_t mask_for_elt_23 = 0x5351; static constexpr uint32_t start_byte_for_fp16 = 0x64646464; @@ -149,33 +233,42 @@ __device__ inline void dequant(int q, uint32_t lo = prmt(q); uint32_t hi = prmt(q); - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} - frag_b[0] = __hsub2(*reinterpret_cast(&lo), +template <> +__device__ inline void dequant( + int q, half2* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); } template <> -__device__ inline void 
dequant(int q, half2* frag_b) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; +__device__ inline void dequant(int q, + half2* frag_b) { + dequant(q, frag_b); +} - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); +template <> +__device__ inline void dequant(int q, + half2* frag_b) { + dequant(q, frag_b); static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; - - frag_b[0] = __hsub2(*reinterpret_cast(&lo), + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); } template <> -__device__ inline void dequant( +__device__ inline void dequant( int q, nv_bfloat162* frag_b) { float fp32_intermediates[4]; uint32_t* fp32_intermediates_casted = @@ -200,7 +293,7 @@ __device__ inline void dequant( } template <> -__device__ inline void dequant( +__device__ inline void dequant( int q, nv_bfloat162* frag_b) { float fp32_intermediates[4]; uint32_t* fp32_intermediates_casted = @@ -225,22 +318,30 @@ __device__ inline void dequant( } template <> -__device__ inline void dequant(int q, - half2* frag_b) { +__device__ inline void dequant( + int q, half2* frag_b) { // Constants for FP8 (E4M3) and FP16 formats - constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, FP16_EXPONENT = 5; + constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; - - // Calculate MASK for extracting mantissa and exponent - constexpr int MASK1 = 0x80000000; - constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA); - constexpr int MASK3 = MASK2 & 0x7fffffff; - constexpr int MASK = MASK3 | (MASK3 >> 16); - // Final MASK value: 0x7F007F00 + constexpr int MASK = 0x7F007F00; // Extract and shift FP8 values to FP16 format int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - int Out2 = ((q 
<< 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant( + int q, half2* frag_b) { + dequant(q, frag_b); + + // Constants for FP8 (E4M3) and FP16 formats + constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; // Construct and apply exponent bias constexpr int BIAS_OFFSET = @@ -248,28 +349,36 @@ __device__ inline void dequant(int q, const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); // Convert to half2 and apply bias - // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = __hmul2(*reinterpret_cast(&Out1), bias_reg); - frag_b[0] = __hmul2(*reinterpret_cast(&Out2), bias_reg); + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); } template <> -__device__ inline void dequant( +__device__ inline void dequant( int q, nv_bfloat162* frag_b) { // Constants for FP8 (E4M3) and BF16 formats - constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, BF16_EXPONENT = 8; + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; - // Calculate MASK for extracting mantissa and exponent - constexpr int MASK1 = 0x80000000; - constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA); - constexpr int MASK3 = MASK2 & 0x7fffffff; - constexpr int MASK = MASK3 | (MASK3 >> 16); - // Final MASK value: 0x7F007F00 + constexpr int MASK = 0x7F007F00; // Extract and shift FP8 values to BF16 format int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = 
*reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + // Constants for FP8 (E4M3) and BF16 formats + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; // Construct and apply exponent bias constexpr int BIAS_OFFSET = @@ -281,9 +390,116 @@ __device__ inline void dequant( __float2bfloat162_rn(*reinterpret_cast(&BIAS)); // Convert to bfloat162 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template <> +__device__ inline void dequant(int q, + half2* frag_b) { + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT; + constexpr int MASK = 0x70007000; + + // Extract and shift FP4 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant( + int q, half2* frag_b) { + dequant(q, frag_b); + + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = + (1 << (FP16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP4_EXPONENT; + constexpr int MASK = 0x70007000; + + // Extract and 
shift FP4 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + // Constants for FP4 (E2M1) and BF16 formats + constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = + (1 << (BF16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); + // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent + // position + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const nv_bfloat162 bias_reg = + __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template +__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b); + +template <> +__device__ inline void dequant_fp8_scales(int q, half2* frag_b) { + int Out1 = (q & 0xFF00FF00) >> 1; + ; + q <<= 8; + int Out2 = (q & 0xFF00FF00) >> 1; + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +}; + +template <> +__device__ inline void dequant_fp8_scales(int q, + nv_bfloat162* frag_b) { + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to BF16 format + int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = __hmul2(*reinterpret_cast(&Out1), bias_reg); - 
frag_b[0] = __hmul2(*reinterpret_cast(&Out2), bias_reg); + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); } #endif diff --git a/csrc/quantization/gptq_marlin/generate_kernels.py b/csrc/quantization/gptq_marlin/generate_kernels.py index 8b4b951f3d86..4ac7121ab4e1 100644 --- a/csrc/quantization/gptq_marlin/generate_kernels.py +++ b/csrc/quantization/gptq_marlin/generate_kernels.py @@ -31,7 +31,10 @@ # int8 with zero point case (vllm::kU8) is also supported, # we don't add it to reduce wheel size. -SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"] +SCALAR_TYPES = [ + "vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn", + "vllm::kFE2M1f" +] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)] @@ -40,7 +43,7 @@ # = 0 : act order case # = -1 : channelwise quantization # > 0 : group_size=16*group_blocks -GROUP_BLOCKS = [0, -1, 2, 4, 8] +GROUP_BLOCKS = [0, 1, -1, 2, 4, 8] DTYPES = ["fp16", "bf16"] @@ -73,6 +76,12 @@ def generate_new_kernels(): # for fp8 if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: continue + # nvfp4 only supports group_size == 16 + if scalar_type == "vllm::kFE2M1f" and group_blocks != 1: + continue + # other quantization methods don't support group_size = 16 + if scalar_type != "vllm::kFE2M1f" and group_blocks == 1: + continue k_blocks = thread_configs[0] // 16 n_blocks = thread_configs[1] // 16 diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 02527a481661..4a242f2050d5 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -258,6 +258,7 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, // BIGGROUP: cases for big group size (group_blocks in [-1, 8]) // FZP: cases for float-zero-point (is_zp_float = true) // ACT: cases for act order case (group_blocks == 0) + // FP4: cases for nvfp4(e2m1) 
(group_blocks == 1) #define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ @@ -314,6 +315,23 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \ BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128) + #define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + + #define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + + #define FP4_GET_IF(W_TYPE) \ + FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ + FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M234(W_TYPE, 4, 8, 128) + // We currently have 4-bit models only with group_blocks == 4 #define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ @@ -366,6 +384,8 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, COMMON_GET_IF(vllm::kU4B8) COMMON_GET_IF(vllm::kU8B128) + FP4_GET_IF(vllm::kFE2M1f) + BIGGROUP_GET_IF(vllm::kFE4M3fn) ACT_GET_IF(vllm::kU4B8) @@ -434,8 +454,8 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, template void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, - void* zp, void* g_idx, void* perm, void* a_tmp, int prob_m, - int prob_n, int prob_k, int lda, void* workspace, + void* s2, void* zp, void* g_idx, void* perm, void* a_tmp, + int prob_m, int prob_n, int prob_k, int lda, 
void* workspace, vllm::ScalarType const& q_type, bool has_act_order, bool is_k_full, bool has_zp, int num_groups, int group_size, int dev, cudaStream_t stream, int thread_k_init, @@ -446,11 +466,12 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, q_type == vllm::kU4 || q_type == vllm::kU8, "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); } else { - TORCH_CHECK(q_type == vllm::kU4B8 || q_type == vllm::kU8B128 || - q_type == vllm::kFE4M3fn, - "q_type must be uint4b8, uint8b128 or float8_e4m3fn when " - "has_zp = False. Got = ", - q_type.str()); + TORCH_CHECK( + q_type == vllm::kU4B8 || q_type == vllm::kU8B128 || + q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f, + "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " + "has_zp = False. Got = ", + q_type.str()); } TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, @@ -483,6 +504,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, int4* C_ptr = (int4*)C; int4* C_tmp_ptr = (int4*)C_tmp; const int4* s_ptr = (const int4*)s; + const uint16_t* s2_ptr = (const uint16_t*)s2; const int4* zp_ptr = (const int4*)zp; const int* g_idx_ptr = (const int*)g_idx; const int* perm_ptr = (const int*)perm; @@ -601,7 +623,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, // avoid ">>>" being formatted to "> > >" // clang-format off kernel<<>>( - A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, num_groups, + A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, num_groups, prob_m_split, prob_n, prob_k, lda, locks, part_use_atomic_add, use_fp32_reduce, max_shared_mem_new); // clang-format on @@ -617,6 +639,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, torch::Tensor gptq_marlin_gemm( torch::Tensor& a, std::optional c_or_none, torch::Tensor& b_q_weight, torch::Tensor& b_scales, + std::optional const& global_scale_or_none, 
std::optional const& b_zeros_or_none, std::optional const& g_idx_or_none, std::optional const& perm_or_none, torch::Tensor& workspace, @@ -759,6 +782,17 @@ torch::Tensor gptq_marlin_gemm( } } + torch::Tensor global_scale; + if (global_scale_or_none.has_value()) { + global_scale = global_scale_or_none.value(); + TORCH_CHECK(b_q_type == vllm::kFE2M1f, + "global_scale can only be used for float4_e2m1f."); + } else { + global_scale = torch::empty({0}, options); + TORCH_CHECK(!(b_q_type == vllm::kFE2M1f), + "the global_scale parameter must be passed for float4_e2m1f."); + } + torch::Tensor b_zeros; if (b_zeros_or_none.has_value()) { b_zeros = b_zeros_or_none.value(); @@ -774,8 +808,9 @@ torch::Tensor gptq_marlin_gemm( "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); } else { TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 || - b_q_type == vllm::kFE4M3fn, - "b_q_type must be uint4b8, uint8b128 or float8_e4m3fn when " + b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f, + "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or " + "float4_e2m1f when " "has_zp = False. 
Got = ", b_q_type.str()); } @@ -820,22 +855,36 @@ torch::Tensor gptq_marlin_gemm( int dev = a.get_device(); if (a.scalar_type() == at::ScalarType::Half) { + void* scales_ptr; + if (b_q_type == vllm::kFE2M1f) { + scales_ptr = b_scales.data_ptr(); + } else { + scales_ptr = b_scales.data_ptr(); + } + marlin::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - c_tmp.data_ptr(), b_scales.data_ptr(), + c_tmp.data_ptr(), scales_ptr, global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); } else if (a.scalar_type() == at::ScalarType::BFloat16) { + void* scales_ptr; + if (b_q_type == vllm::kFE2M1f) { + scales_ptr = b_scales.data_ptr(); + } else { + scales_ptr = b_scales.data_ptr(); + } + marlin::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), - c.data_ptr(), c_tmp.data_ptr(), - b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), - perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, - a.stride(0), workspace.data_ptr(), b_q_type, has_act_order, is_k_full, - has_zp, num_groups, group_size, dev, + c.data_ptr(), c_tmp.data_ptr(), scales_ptr, + global_scale.data_ptr(), b_zeros.data_ptr(), + g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), + size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type, + has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); } else { diff --git a/csrc/quantization/gptq_marlin/kernel.h b/csrc/quantization/gptq_marlin/kernel.h index eb2700c95e86..f92056589d20 100644 --- a/csrc/quantization/gptq_marlin/kernel.h +++ b/csrc/quantization/gptq_marlin/kernel.h @@ -7,13 +7,14 @@ #include "marlin_dtypes.cuh" #include 
"core/scalar_type.hpp" -#define MARLIN_KERNEL_PARAMS \ - const int4 *__restrict__ A, const int4 *__restrict__ B, \ - int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ - const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \ - const int *__restrict__ g_idx, int num_groups, int prob_m, int prob_n, \ - int prob_k, int lda, int *locks, bool use_atomic_add, \ - bool use_fp32_reduce, int max_shared_mem +#define MARLIN_KERNEL_PARAMS \ + const int4 *__restrict__ A, const int4 *__restrict__ B, \ + int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ + const int4 *__restrict__ scales_ptr, \ + const uint16_t *__restrict__ scale2_ptr, \ + const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ + int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \ + bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem namespace MARLIN_NAMESPACE_NAME { template ::value || + has_zp && !is_zp_float && !(w_type == vllm::kU8); + + scalar_t2 global_scale; + + if constexpr (w_type == vllm::kFE2M1f) { + uint16_t val = scale2_ptr[0]; + global_scale = Dtype::num2num2(*reinterpret_cast(&val)); + } + constexpr bool has_act_order = group_blocks == 0; constexpr int m_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); @@ -481,7 +498,7 @@ __global__ void Marlin( constexpr int s_sh_stride = 16 * thread_n_blocks / 8; constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks + ? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1) : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; @@ -540,7 +557,8 @@ __global__ void Marlin( if constexpr (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / + (w_type == vllm::kFE2M1f ? 
2 : 1) + s_sh_stride * slice_col + threadIdx.x; } } @@ -564,10 +582,20 @@ __global__ void Marlin( // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. int s_sh_rd; - if constexpr (group_blocks != -1) + if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + int warp_row = warp_id / n_warps; + + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + s_sh_rd = s_sh_rd * 2 + warp_row % 2; + + } else if constexpr (group_blocks != -1) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp)) + else if constexpr (group_blocks == -1 && + (m_block_size_8 || (has_zp && !dequant_skip_flop))) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; else @@ -681,7 +709,7 @@ __global__ void Marlin( sh_first_group_id = first_group_id; sh_num_groups = last_group_id - first_group_id + 1; - if (sh_num_groups < act_s_max_num_groups) { + if (sh_num_groups > act_s_max_num_groups) { sh_num_groups = act_s_max_num_groups; } @@ -887,12 +915,19 @@ __global__ void Marlin( cur_k += k_iter_size * (k % b_sh_wr_iters); int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; + int cur_group_id = + k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 
2 : 1)); int4* sh_s_stage = sh_s + s_sh_stage * pipe; - reinterpret_cast(&frag_s[k % 2])[0] = - sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + if constexpr (w_type_id != vllm::kFE2M1f.id()) { + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } else { + reinterpret_cast(&frag_s[k % 2])[0] = + reinterpret_cast( + sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } } } @@ -1065,22 +1100,7 @@ __global__ void Marlin( }; auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) { - if constexpr (has_zp && is_zp_float || !has_zp) { - dequant(q, frag_b_ptr); - } else { - static_assert(has_zp && !is_zp_float); - static_assert(w_type_id == vllm::kU4.id() || w_type_id == vllm::kU8.id()); - // If (has_zp && !is_zp_float), - // we use not-zp version `dequant` function - // to improve numerical accuracy. - // Since both weight and zero point are dequanted using this logic, - // the final dequanted weight would be correct. - if constexpr (w_type_id == vllm::kU4.id()) { - dequant(q, frag_b_ptr); - } else if constexpr (w_type_id == vllm::kU8.id()) { - dequant(q, frag_b_ptr); - } - } + dequant(q, frag_b_ptr); }; // Execute the actual tensor core matmul of a sub-tile. @@ -1110,13 +1130,23 @@ __global__ void Marlin( dequant_data(zp_quant_1, reinterpret_cast(&frag_zp) + 2); } } - if constexpr (has_zp && is_zp_float) { + if constexpr (!dequant_skip_flop && has_zp && is_zp_float) { if (is_new_zp) { reinterpret_cast(&frag_zp)[0] = reinterpret_cast(&frag_zpf[k2])[0]; } } + if constexpr (w_type == vllm::kFE2M1f) { + int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; + int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; + + dequant_fp8_scales(s_quant_0, + reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales( + s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + } + // We have the m dimension as the inner loop in order to encourage overlapping // dequantization and matmul operations. 
#pragma unroll @@ -1125,7 +1155,10 @@ __global__ void Marlin( FragB frag_b1; int b_quant_0, b_quant_1; - if constexpr (w_type.size_bits() == 4) { + if constexpr (w_type_id == vllm::kFE2M1f.id()) { + b_quant_1 = frag_b_quant[k2][0][j]; + b_quant_0 = b_quant_1 << 8; + } else if constexpr (w_type.size_bits() == 4) { b_quant_0 = frag_b_quant[k2][0][j]; b_quant_1 = b_quant_0 >> 8; } else { @@ -1138,6 +1171,11 @@ __global__ void Marlin( dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); + if constexpr (dequant_skip_flop && has_zp && !is_zp_float) { + sub_zp(frag_b0, frag_zp[j], 0); + sub_zp(frag_b1, frag_zp[j], 1); + } + // Apply scale to frag_b0 if constexpr (has_act_order) { static_assert(group_blocks != -1); @@ -1145,7 +1183,8 @@ __global__ void Marlin( act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); - } else if constexpr (has_zp && !is_zp_float && group_blocks == -1) { + } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && + group_blocks == -1) { int idx = (threadIdx.x / 4) % 2; scalar_t2 s2 = Dtype::nums2num2( reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], @@ -1153,7 +1192,7 @@ __global__ void Marlin( if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); scale_and_sub(frag_b0, s2.x, frag_zp[j].x); scale_and_sub(frag_b1, s2.y, frag_zp[j].y); - } else if constexpr (has_zp && group_blocks != -1) { + } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) { if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast(&frag_s[k2][j])); @@ -1408,10 +1447,15 @@ __global__ void Marlin( // For per-column quantization we finally apply the scale here (only for // 4-bit) if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4 && !has_zp) { + w_type.size_bits() == 4 && + (has_zp && dequant_skip_flop || !has_zp)) { res = __hmul2(res, s[0]); } + if 
constexpr (w_type == vllm::kFE2M1f) { + res = __hmul2(res, global_scale); + } + if constexpr (m_block_size_8) { ((scalar_t*)sh_red)[idx] = res.x; ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; @@ -1488,7 +1532,9 @@ __global__ void Marlin( if constexpr (has_zp && !is_zp_float && group_blocks == -1) { if (i == 0) { fetch_col_zp_to_shared(); - fetch_col_scale_to_shared(); + if constexpr (!dequant_skip_flop) { + fetch_col_scale_to_shared(); + } } } fetch_to_shared(i, i, i < slice_iters); @@ -1542,16 +1588,20 @@ __global__ void Marlin( if constexpr (has_act_order) { slice_k_start += tb_k * stages; - slice_k_start_shared_fetch += tb_k * stages; - int first_group_id = g_idx[slice_k_start]; - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - int last_group_id = g_idx[last_g_idx]; - if (last_group_id >= sh_first_group_id + sh_num_groups) { - fetch_act_order_scales_to_shared(false, first_group_id, last_group_id); - __syncthreads(); + + if (slice_k_start < prob_k) { + slice_k_start_shared_fetch += tb_k * stages; + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_act_order_scales_to_shared(false, first_group_id, + last_group_id); + __syncthreads(); + } } } @@ -1563,7 +1613,8 @@ __global__ void Marlin( bool last = slice_idx == slice_count - 1; // For per-column scales, we only fetch them here in the final step before // write-out - if constexpr (!has_act_order && group_blocks == -1 && !has_zp) { + if constexpr (!has_act_order && group_blocks == -1 && + (has_zp && dequant_skip_flop || !has_zp)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); @@ -1573,7 +1624,8 @@ __global__ void Marlin( } thread_block_reduce(); - if 
constexpr (!has_act_order && group_blocks == -1 && !has_zp) { + if constexpr (!has_act_order && group_blocks == -1 && + (has_zp && dequant_skip_flop || !has_zp)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) { cp_async_wait<0>(); __syncthreads(); @@ -1597,7 +1649,8 @@ __global__ void Marlin( // that converts the fp32 results to fp16 (so that we avoid possible // overflow in fp16) if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8 && !has_zp) { + w_type.size_bits() == 8 && + (has_zp && dequant_skip_flop || !has_zp)) { if (threadIdx.x / 32 < thread_n_blocks / 4) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 8cc5a0f4f218..f1e7da164199 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -30,6 +30,14 @@ #define __HIP__GFX9__ #endif +#if defined(__HIPCC__) && (defined(__gfx1100__) || defined(__gfx1101__)) + #define __HIP__GFX11__ +#endif + +#if defined(__HIPCC__) && (defined(__gfx1200__) || defined(__gfx1201__)) + #define __HIP__GFX12__ +#endif + #if defined(NDEBUG) #undef NDEBUG #include @@ -43,7 +51,7 @@ #define MIN(a, b) ((a) < (b) ? 
(a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) -#if defined(__HIP__GFX9__) // TODO: Add NAVI support +#if defined(__HIP__GFX9__) #define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32 #define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16 @@ -1482,191 +1490,1690 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( } } -#else // !defined(__HIP__GFX9__) TODO: Add NAVI support +#elif defined(__HIP__GFX11__) -// clang-format off -template -__global__ -__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int num_kv_heads, - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] - const int* __restrict__ query_start_loc_ptr, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride, - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] - OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] - int max_ctx_blocks, const float* k_scale, const float* v_scale) { - UNREACHABLE_CODE +using floatx8 = __attribute__((__vector_size__(8 * sizeof(float)))) float; + +using bit16_t = uint16_t; +using bit16x4 = __attribute__((__vector_size__(4 * sizeof(uint16_t)))) uint16_t; +typedef bit16x4 _B16x4; + +using bit16x8 = __attribute__((__vector_size__(8 * sizeof(uint16_t)))) uint16_t; +union b16x8_u { + bit16x8 u16x8; + _B16x4 xy[2]; +}; 
+typedef b16x8_u _B16x8; + +using bit16x16 = + __attribute__((__vector_size__(16 * sizeof(uint16_t)))) uint16_t; +union b16x16_u { + bit16x16 u16x16; + _B16x8 xy[2]; +}; +typedef b16x16_u _B16x16; + +using _B8x8 = uint2; +using bit8_t = uint8_t; + +typedef struct _B8x16 { + _B8x8 xy[2]; +} _B8x16; + +template +__device__ __forceinline__ floatx8 gcn_wmma16x16x16_instr(const bit16x16& inpA, + const bit16x16& inpB, + const floatx8& inpC) { + if constexpr (std::is_same::value) { + return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(inpA, inpB, inpC); + } else if constexpr (std::is_same::value) { + return __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(inpA, inpB, inpC); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ float to_float(const T& inp) { + if constexpr (std::is_same::value) { + return (float)inp; + } else if constexpr (std::is_same::value) { + return __bfloat162float(inp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ T from_float(const float& inp) { + if constexpr (std::is_same::value) { + return (_Float16)inp; + } else if constexpr (std::is_same::value) { + return __float2bfloat16(inp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ _B16x8 from_floatx8(const floatx8& inp) { + if constexpr (std::is_same::value) { + union h2cvt { + __half2 h2[4]; + _B16x8 b16x8; + } u; + u.h2[0] = __float22half2_rn(make_float2(inp[0], inp[1])); + u.h2[1] = __float22half2_rn(make_float2(inp[2], inp[3])); + u.h2[2] = __float22half2_rn(make_float2(inp[4], inp[5])); + u.h2[3] = __float22half2_rn(make_float2(inp[6], inp[7])); + return u.b16x8; + } else if constexpr (std::is_same::value) { + union b2cvt { + __hip_bfloat162 b2[4]; + _B16x8 b16x8; + } u; + + u.b2[0] = __float22bfloat162_rn(make_float2(inp[0], inp[1])); + u.b2[1] = __float22bfloat162_rn(make_float2(inp[2], inp[3])); + u.b2[2] = 
__float22bfloat162_rn(make_float2(inp[4], inp[5])); + u.b2[3] = __float22bfloat162_rn(make_float2(inp[6], inp[7])); + + return u.b16x8; + } else { + static_assert(false, "unsupported 16b dtype"); + } } +// clang-format off template + int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO> __global__ -__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int num_kv_heads, - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] - const int* __restrict__ query_start_loc_ptr, // [num_seqs] +__launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride, - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] - OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + const int q_stride, const int 
kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] int max_ctx_blocks, const float* k_scale, const float* v_scale) { - UNREACHABLE_CODE -} + // clang-format on + constexpr int NWARPS = NUM_THREADS / WARP_SIZE; // 8 warps on gfx11 + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + const int lane2id = laneid % 2; + const int lane4id = laneid % 4; + const int lane16id = laneid % 16; + const int rowid = laneid / 16; -// Grid: (num_heads, num_seqs). -template -__global__ -__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( - OUTT* __restrict__ out, // [num_seqs, num_heads, head_size] - const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const int* __restrict__ context_lens, // [num_seqs] - const int* __restrict__ query_start_loc_ptr, // [num_seqs] - const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) { - UNREACHABLE_CODE -} -// clang-format on + const int seq_idx = blockIdx.x; + // NOTE queries with sequence len > 1 are prefills and taken care by another + // kernel. 
+ if (query_start_loc_ptr != nullptr && + (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx]) != 1) { + return; + } -#endif // defined(__HIP__GFX9__) TODO: Add NAVI support + const int partition_idx = blockIdx.y; -#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \ - paged_attention_ll4mi_QKV_mfma16_kernel \ - <<>>( \ - query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ - block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \ - max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \ - kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \ - max_ctx_blocks, k_scale_ptr, v_scale_ptr); + constexpr int T_PAR_SIZE = 256; // token partition size set to 256 -#define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \ - paged_attention_ll4mi_QKV_mfma4_kernel \ - <<>>( \ - query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ - block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \ - max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \ - kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \ - max_ctx_blocks, k_scale_ptr, v_scale_ptr); + const int max_num_partitions = gridDim.y; -#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \ - paged_attention_ll4mi_reduce_kernel \ - <<>>( \ - out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \ - context_lens_ptr, query_start_loc_ptr, max_num_partitions, \ - fp8_out_scale_ptr); + const int context_len = context_lens[seq_idx]; // length of a seq -template -void paged_attention_custom_launcher( - torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, - torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, const int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& context_lens, - const std::optional& query_start_loc, int max_context_len, - const std::optional& alibi_slopes, torch::Tensor& k_scale, - torch::Tensor& v_scale, const std::optional& 
fp8_out_scale) { - int num_seqs = block_tables.size(0); - int num_heads = query.size(1); - int head_size = query.size(2); - int max_num_blocks_per_seq = block_tables.size(1); - int q_stride = query.stride(0); - int kv_block_stride = key_cache.stride(0); - int kv_head_stride = key_cache.stride(1); + const int partition_start_token_idx = partition_idx * T_PAR_SIZE; + // exit if partition is out of context for seq + if (partition_start_token_idx >= context_len) { + return; + } - // NOTE: query start location is optional for V0 decode should not be used. - // If batch contains mix of prefills and decode, prefills should be skipped. - const int* query_start_loc_ptr = - query_start_loc - ? reinterpret_cast(query_start_loc.value().data_ptr()) - : nullptr; + constexpr int GQA_RATIO2 = DIVIDE_ROUND_UP(GQA_RATIO, 2); - // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = - alibi_slopes - ? reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; + __shared__ float shared_qk_max[NWARPS][16 + 1]; + __shared__ float shared_exp_sum[NWARPS][16 + 1]; + // shared_logits is used for multiple purposes + __shared__ _B16x16 shared_logits[NWARPS][2][16][2]; - float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); - float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); - T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); - T* query_ptr = reinterpret_cast(query.data_ptr()); - KVT* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - KVT* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int* block_tables_ptr = block_tables.data_ptr(); - int* context_lens_ptr = context_lens.data_ptr(); - const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); - const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); - // NOTE: fp8_out_scale is optional. - const auto fp8_out_scale_ptr = - fp8_out_scale - ? 
static_cast(fp8_out_scale.value().data_ptr()) - : nullptr; - OUTT* out_ptr = reinterpret_cast(out.data_ptr()); + // for QK wmma16x16, layout is QHead/Tokenx16 across every 16 lanes, + // 32 Bytes HeadElements in each lane, 2x16B HeadElements across a row of warp + constexpr int ROWS_PER_WARP = + WARP_SIZE / 16 / 2; // rows refers to 16 lanes; refer dpp terminology + constexpr int CONTIGUOUS_KV_ELEMS_16B_LOAD = + 16 / sizeof(cache_t); // 8 for 16 bit cache type, 16 for 8 bit types + constexpr int QKHE_PER_FETCH = + CONTIGUOUS_KV_ELEMS_16B_LOAD * + ROWS_PER_WARP; // each fetch across a warp fetches these many elements + constexpr int QKHELOOP = HEAD_SIZE / QKHE_PER_FETCH; // 2xQKHE_16B across + // warp - const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); + _B16x16 Qlocal[QKHELOOP / 2]; // note that 16 contiguous elements of Q should + // be fetched per lane for 16 bit cache types - // partition size is fixed at 256 since both mfma4 and mfma16 kernels support - // it mfma4 kernel also supports partition size 512 - constexpr int PARTITION_SIZE = 256; - const int max_num_partitions = - DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); - const int gqa_ratio = num_heads / num_kv_heads; - assert(num_heads % num_kv_heads == 0); - assert(head_size == HEAD_SIZE); + constexpr int CONTIGUOUS_SCALAR_ELEMS_16B = 16 / sizeof(scalar_t); - constexpr int NTHR = 256; - dim3 grid(num_seqs, max_num_partitions, num_kv_heads); - dim3 block(NTHR); - const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + constexpr int TOKENS_PER_WARP = + T_PAR_SIZE / + NWARPS; // sub partition of tokens per warp for qk calculation + constexpr int TLOOP = + TOKENS_PER_WARP / + 16; // each wmma16x16x16 instruction processes 16 tokens - // mfma4 kernel is faster than mfma16 for gqa_ratio <= 4 - switch (gqa_ratio) { - case 1: - LAUNCH_CUSTOM_ATTENTION_MFMA4(1); - break; - case 2: - 
LAUNCH_CUSTOM_ATTENTION_MFMA4(2); - break; - case 3: - LAUNCH_CUSTOM_ATTENTION_MFMA4(3); - break; - case 4: - LAUNCH_CUSTOM_ATTENTION_MFMA4(4); - break; + _B16x16 Klocal[TLOOP] + [QKHELOOP / 2]; // can be interpreted as B8x16 for 8 bit types + + const int wg_start_head_idx = blockIdx.z * GQA_RATIO; + const int wg_start_kv_head_idx = blockIdx.z; + const int total_num_heads = gridDim.z * GQA_RATIO; + + // for QK wmma, tokens in multiples of TOKENS_PER_WARP are spread across warps + // each wmma takes QH16xT16x16HE across warp + // repeat wmma across QKHELOOP dimension + // output layout from QKwmma : QH16xT8x2 16 qheads across 16 lanes, 16 tokens + // across 2 rows x 8 tokens per lane + + const int64_t query_start_off = static_cast( + query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx); + + if (GQA_RATIO == 1) { + const int local_qhead_idx = lane16id % GQA_RATIO; + const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; + const scalar_t* q_ptr = + q + query_start_off * q_stride + global_qhead_idx * HEAD_SIZE; + if (lane16id < GQA_RATIO) { + #pragma unroll + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP / 2; qkhe_depth++) { + const scalar_t* q_fetch_ptr = q_ptr + qkhe_depth * QKHE_PER_FETCH * 2; + const _B16x16* q_fetch_ptr_32B = + reinterpret_cast(q_fetch_ptr); + Qlocal[qkhe_depth] = *q_fetch_ptr_32B; + } + } + } else { + // fetch Q in shared across warps and then write to registers + const int local_qhead_idx = 2 * warpid + rowid; + const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; + const scalar_t* q_ptr = + q + query_start_off * q_stride + global_qhead_idx * HEAD_SIZE; + + const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B; + if ((local_qhead_idx < GQA_RATIO) && (qhead_element < HEAD_SIZE)) { + const scalar_t* q_fetch_ptr = q_ptr + qhead_element; + const _B16x8* q_fetch_ptr_16B = + reinterpret_cast(q_fetch_ptr); + _B16x8 tmp = *q_fetch_ptr_16B; + + const int offset1 = + lane16id / + 2; // 16 contiguous chunks of 
head elems are spread across 8x2lanes + shared_logits[offset1][lane2id][local_qhead_idx][0].xy[0] = tmp; + } + + __syncthreads(); + + #pragma unroll + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP / 2; qkhe_depth++) { + Qlocal[qkhe_depth].xy[0] = + shared_logits[qkhe_depth][0][lane16id % GQA_RATIO][0].xy[0]; + Qlocal[qkhe_depth].xy[1] = + shared_logits[qkhe_depth][1][lane16id % GQA_RATIO][0].xy[0]; + } + } + + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int last_ctx_block = num_context_blocks - 1; + + const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; + + int kphysical_block_number[TLOOP]; + + // fetch k physical block numbers + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int klocal_token_idx = + TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; + const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; + const int kblock_idx = (kglobal_token_idx < context_len) + ? kglobal_token_idx / BLOCK_SIZE + : last_ctx_block; + kphysical_block_number[token_depth] = block_table_seq[kblock_idx]; + } + + constexpr int KX = 16 / sizeof(cache_t); + const cache_t* k_ptr = k_cache + wg_start_kv_head_idx * kv_head_stride; + + const int row_head_elem = 0; + + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int64_t kblock_number = + static_cast(kphysical_block_number[token_depth]); + const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride; + const int klocal_token_idx = + TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; + const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; + const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE; + const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX; + + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + const int head_elem = row_head_elem + qkhe_depth * QKHE_PER_FETCH; + const int offset1 = head_elem / KX; + const int offset2 = head_elem % 
KX; + const cache_t* k_fetch_ptr = k_ptr3 + offset1 * BLOCK_SIZE * KX + offset2; + const _B16x8* k_fetch_ptr_16B = + reinterpret_cast(k_fetch_ptr); + Klocal[token_depth][qkhe_depth / 2].xy[qkhe_depth % 2] = *k_fetch_ptr_16B; + } + } + + constexpr int VTOKENS_PER_LANE = + TOKENS_PER_WARP / ROWS_PER_WARP; // 32/1 = 32 vtokens per lane + constexpr int VBLOCKS_PER_LANE = 2; // assumes block size >=16 + constexpr int VTLOOP = NWARPS; // corresponds to tokens across warps + constexpr int VTLANELOOP = DIVIDE_ROUND_UP( + VTOKENS_PER_LANE, + CONTIGUOUS_KV_ELEMS_16B_LOAD); // optimized for 16B fetches; assumes + // minimum block size is 16 + constexpr int VHELOOP = DIVIDE_ROUND_UP( + (HEAD_SIZE / 16), NWARPS); // head_size distributed across warps; each + // wmma instr works on 16 head elements + + int vphysical_block_number[VTLOOP][VBLOCKS_PER_LANE]; + + // fetch v physical block numbers + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE; + vblock_depth++) { + const int vlocal_token_idx = + vtoken_depth * VTOKENS_PER_LANE * ROWS_PER_WARP + + vblock_depth * BLOCK_SIZE; + const int vglobal_token_idx = + partition_start_token_idx + vlocal_token_idx; + const int vblock_idx = (vglobal_token_idx < context_len) + ? 
vglobal_token_idx / BLOCK_SIZE + : last_ctx_block; + vphysical_block_number[vtoken_depth][vblock_depth] = + block_table_seq[vblock_idx]; + } + } + + _B16x16 Vlocal[VTLOOP][VHELOOP] + [VTLANELOOP / 2]; // this can be interpreted as B8x16 too + + const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; + // v fetches are 16head elems across lanes x (16x2) tokens per lane + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + const int vhead_elem = vhe_depth * NWARPS * 16 + warpid * 16 + lane16id; + const cache_t* v_ptr2 = v_ptr + vhead_elem * BLOCK_SIZE; + + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + const int64_t vblock_number = static_cast( + vphysical_block_number[vtoken_depth] + [vfetch_depth / VBLOCKS_PER_LANE]); + const cache_t* v_ptr3 = v_ptr2 + (vblock_number * kv_block_stride); + + const cache_t* v_fetch_ptr = + v_ptr3 + + (vfetch_depth % VBLOCKS_PER_LANE) * CONTIGUOUS_KV_ELEMS_16B_LOAD; + const _B16x8* v_fetch_ptr_16B = + reinterpret_cast(v_fetch_ptr); + Vlocal[vtoken_depth][vhe_depth][vfetch_depth / 2].xy[vfetch_depth % 2] = + *v_fetch_ptr_16B; + } + } + } + + floatx8 dout[TLOOP]; + // qk wmma + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + dout[token_depth] = {0}; + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP / 2; qkhe_depth++) { + dout[token_depth] = gcn_wmma16x16x16_instr( + Klocal[token_depth][qkhe_depth].u16x16, Qlocal[qkhe_depth].u16x16, + dout[token_depth]); + } + dout[token_depth] *= scale; + } + + // calculate qk_max and exp_sum per warp and write to shared memory + float qk_max = -FLT_MAX; + float exp_sum = 0.0f; + const int qkout_token_idx = + partition_start_token_idx + TOKENS_PER_WARP * warpid + rowid; + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for (int i = 0; i < 8; i++) { + const float tmp = 
(local_token_idx + 2 * i < context_len) + ? dout[token_depth][i] + : -FLT_MAX; + qk_max = fmaxf(qk_max, tmp); + } + } + + qk_max = fmaxf(qk_max, __shfl_xor(qk_max, 16)); + + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for (int i = 0; i < 8; i++) { + const float tmp = (local_token_idx + 2 * i < context_len) + ? __expf(dout[token_depth][i] - qk_max) + : 0.0f; + dout[token_depth][i] = tmp; + exp_sum += tmp; + } + } + + exp_sum += __shfl_xor(exp_sum, 16); + + __syncthreads(); + + if (laneid < 16) { + shared_qk_max[warpid][lane16id] = qk_max; + shared_exp_sum[warpid][lane16id] = exp_sum; + } + + __syncthreads(); + + // calculate partition qk_max and exp_sum + float partition_qk_max = -FLT_MAX; + float warp_qk_max_exp[NWARPS]; + float partition_exp_sum = 0.0f; + + #pragma unroll + for (int w = 0; w < NWARPS; w++) { + warp_qk_max_exp[w] = shared_qk_max[w][lane16id]; + partition_qk_max = fmaxf(partition_qk_max, warp_qk_max_exp[w]); + } + + for (int w = 0; w < NWARPS; w++) { + warp_qk_max_exp[w] = __expf(warp_qk_max_exp[w] - partition_qk_max); + partition_exp_sum += shared_exp_sum[w][lane16id] * warp_qk_max_exp[w]; + } + + const float inv_sum_scale = + __fdividef(1.f, partition_exp_sum + 1e-6f) * warp_qk_max_exp[warpid]; + + __syncthreads(); + + // write logits to shared mem + #pragma unroll + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + dout[token_depth] *= inv_sum_scale; + shared_logits[warpid][token_depth][lane16id][0].xy[rowid] = + from_floatx8(dout[token_depth]); + } + __syncthreads(); + + _B16x8 swp_buf[TLOOP][2]; + #pragma unroll + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + swp_buf[token_depth][0] = + shared_logits[warpid][token_depth][lane16id][0].xy[0]; + swp_buf[token_depth][1] = + shared_logits[warpid][token_depth][lane16id][0].xy[1]; + } + + #pragma unroll + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + #pragma 
unroll + for (int i = 0; i < 8; i++) { + shared_logits[warpid][token_depth][lane16id][0].xy[rowid].u16x8[i] = + swp_buf[token_depth][i % 2].u16x8[4 * rowid + (i / 2)]; + } + } + + // write out partition max_logits and exp_sum + if (threadIdx.x < GQA_RATIO) { + const int qhead_idx = lane16id; + const int offset = seq_idx * total_num_heads * max_num_partitions + + (wg_start_head_idx + qhead_idx) * max_num_partitions + + partition_idx; + max_logits[offset] = partition_qk_max; + exp_sums[offset] = partition_exp_sum; + } + + __syncthreads(); + + _B16x8 outelems[VHELOOP]; + // Softmax V wmma + // v layout: 16he across lanes x (16x2) tokens per lane + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + floatx8 tmp_out = {0}; + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP / 2; + vfetch_depth++) { + const int offset = vfetch_depth; + // if output format is 16 qheads across 16 lanes, 16 head elems spread + // across rows + tmp_out = gcn_wmma16x16x16_instr( + Vlocal[vtoken_depth][vhe_depth][vfetch_depth].u16x16, + shared_logits[vtoken_depth][offset][lane16id][0].u16x16, tmp_out); + } + } + outelems[vhe_depth] = from_floatx8(tmp_out); + } + + __syncthreads(); + + #pragma unroll + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + shared_logits[warpid][vhe_depth][lane16id][0].xy[rowid] = + outelems[vhe_depth]; // lane16 id head dimension; rowid head element + // dimension + } + + __syncthreads(); + + #pragma unroll + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + swp_buf[vhe_depth][0] = shared_logits[warpid][vhe_depth][lane16id][0].xy[0]; + swp_buf[vhe_depth][1] = shared_logits[warpid][vhe_depth][lane16id][0].xy[1]; + } + + #pragma unroll + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + #pragma unroll + for (int i = 0; i < 8; i++) { + shared_logits[warpid][vhe_depth][lane16id][0].xy[rowid].u16x8[i] = + swp_buf[vhe_depth][i % 2].u16x8[4 * rowid + (i / 
2)]; + } + } + + __syncthreads(); + + // write to tmp_out with coalesced writes after reading from shared mem + if (warpid == 0) { + _B16x8 vout[GQA_RATIO2]; + // each lane writes out 16Bytes of tmp_out along head elem dimension + const int head_elem_idx = lane16id * 8; + if (head_elem_idx < HEAD_SIZE) { + for (int h = 0; h < GQA_RATIO2; h++) { + const int local_head_idx = 2 * h + rowid; + const int offset1 = (head_elem_idx / 16) % NWARPS; + const int offset2 = head_elem_idx / 16 / NWARPS; + const int offset3 = (head_elem_idx / 8) % 2; // num_he % num_row + vout[h] = + shared_logits[offset1][offset2][local_head_idx][0].xy[offset3]; + } + + const int hsz_maxp_mult = HEAD_SIZE * max_num_partitions; + scalar_t* out_ptr = out + seq_idx * total_num_heads * hsz_maxp_mult + + partition_idx * HEAD_SIZE; + for (int h = 0; h < GQA_RATIO2; h++) { + const int local_head_idx = 2 * h + rowid; + if (local_head_idx < GQA_RATIO) { + const int out_head_idx = wg_start_head_idx + local_head_idx; + scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult; + scalar_t* out_ptr3 = out_ptr2 + head_elem_idx; + _B16x8* out_ptr_B16x8 = reinterpret_cast<_B16x8*>(out_ptr3); + *out_ptr_B16x8 = vout[h]; + } + } + } + } +} + +template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ query_start_loc_ptr, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* 
__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, const float* k_scale, const float* v_scale) { + UNREACHABLE_CODE +} + +// Grid: (num_heads, num_seqs). +template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( + OUTT* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ query_start_loc_ptr, // [num_seqs] + const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) { + const auto num_heads = gridDim.x; + const auto head_idx = blockIdx.x; + const auto seq_idx = blockIdx.y; + + // NOTE queries with sequence len > 1 are prefills and taken care by another + // kernel. 
+ if (query_start_loc_ptr != nullptr && + (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) { + return; + } + + const int context_len = context_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + [[maybe_unused]] constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int warpid = threadIdx.x / WARP_SIZE; + [[maybe_unused]] const int laneid = threadIdx.x % WARP_SIZE; + + __shared__ float shared_global_exp_sum; + // max num partitions supported is warp_size * NPAR_LOOPS + __shared__ float shared_exp_sums[NPAR_LOOPS * WARP_SIZE]; + + if (warpid == 0) { + const float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + + // valid partition is the last valid partition in case threadid > num + // partitions + int valid_partition[NPAR_LOOPS]; + float reg_max_logit[NPAR_LOOPS]; + const int last_valid_partition = num_partitions - 1; + + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + valid_partition[i] = + (partition_no < num_partitions) ? 
partition_no : last_valid_partition; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + reg_max_logit[i] = max_logits_ptr[valid_partition[i]]; + } + float max_logit = reg_max_logit[0]; + #pragma unroll + for (int i = 1; i < NPAR_LOOPS; i++) { + max_logit = fmaxf(max_logit, reg_max_logit[i]); + } + + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask)); + } + + const float* exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + + float rescaled_exp_sum[NPAR_LOOPS]; + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + rescaled_exp_sum[i] = exp_sums_ptr[valid_partition[i]]; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + rescaled_exp_sum[i] *= (partition_no < num_partitions) + ? expf(reg_max_logit[i] - max_logit) + : 0.0f; + } + float global_exp_sum = rescaled_exp_sum[0]; + #pragma unroll + for (int i = 1; i < NPAR_LOOPS; i++) { + global_exp_sum += rescaled_exp_sum[i]; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + shared_exp_sums[partition_no] = rescaled_exp_sum[i]; + } + + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + global_exp_sum += __shfl_xor(global_exp_sum, mask); + } + if (threadIdx.x == 0) { + shared_global_exp_sum = global_exp_sum; + } + } // warpid == 0 + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x; + constexpr int MAX_NPAR = 32; + scalar_t tmps[MAX_NPAR]; + const float dzero = 0.0f; + #pragma unroll + for (int j = 0; j < MAX_NPAR; j++) { + tmps[j] = from_float(dzero); + } + const int last_partition_offset = (num_partitions - 1) * HEAD_SIZE; + const int num_partition_offset = (num_partitions)*HEAD_SIZE; + int idx = 0; + + constexpr 
int JCHUNK = 16; + + #pragma unroll + for (int j = 0; j < JCHUNK * HEAD_SIZE; j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + __syncthreads(); + + if (num_partitions > JCHUNK) { + #pragma unroll + for (int j = JCHUNK * HEAD_SIZE; j < 2 * JCHUNK * HEAD_SIZE; + j += HEAD_SIZE) { + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + + if (num_partitions > 2 * JCHUNK) { + #pragma unroll + for (int j = 2 * JCHUNK * HEAD_SIZE; j < MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + } + } // num_partitions > JCHUNK + + // Aggregate tmp_out to out. + float acc = 0.0f; + #pragma unroll + for (int j = 0; j < JCHUNK; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + if (num_partitions > JCHUNK) { + #pragma unroll + for (int j = JCHUNK; j < 2 * JCHUNK; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + if (num_partitions > 2 * JCHUNK) { + #pragma unroll + for (int j = 2 * JCHUNK; j < MAX_NPAR; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + } + } + + for (int p = 1; p < NPAR_LOOPS; p++) { + if (num_partitions > p * MAX_NPAR) { + idx = 0; + #pragma unroll + for (int j = p * MAX_NPAR * HEAD_SIZE; j < (p + 1) * MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? 
j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + + #pragma unroll + for (int j = 0; j < MAX_NPAR; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j + p * MAX_NPAR]; + } + } + } + + const float inv_global_exp_sum = + __fdividef(1.0f, shared_global_exp_sum + 1e-6f); + acc *= inv_global_exp_sum; + + const int64_t query_start_off = static_cast( + query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx); + OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE + + static_cast(head_idx) * HEAD_SIZE; + out_ptr[threadIdx.x] = from_float(acc); +} + +#elif defined(__HIP__GFX12__) + +using floatx8 = __attribute__((__vector_size__(8 * sizeof(float)))) float; + +using bit16_t = uint16_t; +using bit16x4 = __attribute__((__vector_size__(4 * sizeof(uint16_t)))) uint16_t; +typedef bit16x4 _B16x4; + +using bit16x8 = __attribute__((__vector_size__(8 * sizeof(uint16_t)))) uint16_t; +union b16x8_u { + bit16x8 u16x8; + _B16x4 xy[2]; +}; +typedef b16x8_u _B16x8; + +using _B8x8 = uint2; +using bit8_t = uint8_t; + +typedef struct _B8x16 { + _B8x8 xy[2]; +} _B8x16; + +template +__device__ __forceinline__ floatx8 gcn_wmma16x16x16_instr(const bit16x8& inpA, + const bit16x8& inpB, + const floatx8& inpC) { + if constexpr (std::is_same::value) { + return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(inpA, inpB, inpC); + } else if constexpr (std::is_same::value) { + return __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(inpA, inpB, inpC); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ float to_float(const T& inp) { + if constexpr (std::is_same::value) { + return (float)inp; + } else if constexpr (std::is_same::value) { + return __bfloat162float(inp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ float to_float_b16(const bit16_t& inp) { + union tmpcvt { + bit16_t u; + _Float16 f; + __hip_bfloat16 b; + } t16; + t16.u = inp; + 
if constexpr (std::is_same::value) { + return (float)t16.f; + } else if constexpr (std::is_same::value) { + return __bfloat162float(t16.b); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ T from_float(const float& inp) { + if constexpr (std::is_same::value) { + return (_Float16)inp; + } else if constexpr (std::is_same::value) { + return __float2bfloat16(inp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ _B16x8 from_floatx8(const floatx8& inp) { + if constexpr (std::is_same::value) { + union h2cvt { + __half2 h2[4]; + _B16x8 b16x8; + } u; + u.h2[0] = __float22half2_rn(make_float2(inp[0], inp[1])); + u.h2[1] = __float22half2_rn(make_float2(inp[2], inp[3])); + u.h2[2] = __float22half2_rn(make_float2(inp[4], inp[5])); + u.h2[3] = __float22half2_rn(make_float2(inp[6], inp[7])); + return u.b16x8; + } else if constexpr (std::is_same::value) { + union b2cvt { + __hip_bfloat162 b2[4]; + _B16x8 b16x8; + } u; + + u.b2[0] = __float22bfloat162_rn(make_float2(inp[0], inp[1])); + u.b2[1] = __float22bfloat162_rn(make_float2(inp[2], inp[3])); + u.b2[2] = __float22bfloat162_rn(make_float2(inp[4], inp[5])); + u.b2[3] = __float22bfloat162_rn(make_float2(inp[6], inp[7])); + + return u.b16x8; + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +// clang-format off +template +__global__ +__launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ query_start_loc_ptr, // 
[num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, const float* k_scale, const float* v_scale) { + // clang-format on + constexpr int NWARPS = NUM_THREADS / WARP_SIZE; // 8 warps on gfx12 + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + const int lane2id = laneid % 2; + const int lane4id = laneid % 4; + const int lane16id = laneid % 16; + const int rowid = laneid / 16; + + const int seq_idx = blockIdx.x; + // NOTE queries with sequence len > 1 are prefills and taken care by another + // kernel. 
+ if (query_start_loc_ptr != nullptr && + (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) { + return; + } + const int partition_idx = blockIdx.y; + + constexpr int T_PAR_SIZE = 256; // token partition size set to 256 + + const int max_num_partitions = gridDim.y; + + const int context_len = context_lens[seq_idx]; // length of a seq + + const int partition_start_token_idx = partition_idx * T_PAR_SIZE; + // exit if partition is out of context for seq + if (partition_start_token_idx >= context_len) { + return; + } + + constexpr int GQA_RATIO2 = DIVIDE_ROUND_UP(GQA_RATIO, 2); + + __shared__ float shared_qk_max[NWARPS][16 + 1]; + __shared__ float shared_exp_sum[NWARPS][16 + 1]; + // shared_logits is used for multiple purposes + __shared__ _B16x8 shared_logits[NWARPS][2][16][2]; + + // for QK wmma16x16_gfx12, layout is QHead/Tokenx16 across every 16 lanes, + // 16 Bytes HeadElements in each lane, 2x16B HeadElements across 2 rows of + // warp + constexpr int ROWS_PER_WARP = + WARP_SIZE / 16; // rows refers to 16 lanes; refer dpp terminology + constexpr int CONTIGUOUS_KV_ELEMS_16B_LOAD = + 16 / sizeof(cache_t); // 8 for 16 bit cache type, 16 for 8 bit types + constexpr int QKHE_PER_FETCH = + CONTIGUOUS_KV_ELEMS_16B_LOAD * + ROWS_PER_WARP; // each fetch across a warp fetches these many elements + constexpr int QKHELOOP = HEAD_SIZE / QKHE_PER_FETCH; // 2xQKHE_16B across + // warp + + _B16x8 Qlocal[QKHELOOP]; // note that 16 contiguous elements of Q should + // be fetched per lane for 16 bit cache types + + constexpr int CONTIGUOUS_SCALAR_ELEMS_16B = 16 / sizeof(scalar_t); + + constexpr int TOKENS_PER_WARP = + T_PAR_SIZE / + NWARPS; // sub partition of tokens per warp for qk calculation + constexpr int TLOOP = + TOKENS_PER_WARP / + 16; // each wmma16x16x16 instruction processes 16 tokens + + _B16x8 Klocal[TLOOP] + [QKHELOOP]; // can be interpreted as B8x16 for 8 bit types + + const int wg_start_head_idx = blockIdx.z * GQA_RATIO; + const int 
wg_start_kv_head_idx = blockIdx.z; + const int total_num_heads = gridDim.z * GQA_RATIO; + + // for QK wmma, tokens in multiples of TOKENS_PER_WARP are spread across warps + // each wmma takes QH16xT16x16HE across warp + // repeat wmma across QKHELOOP dimension + // output layout from QKwmma : QH16xT8x2 16 qheads across 16 lanes, 16 tokens + // across 2 rows x 8 tokens per lane + + const int64_t query_start_off = static_cast( + query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx); + + if (GQA_RATIO == 1) { + const int local_qhead_idx = lane16id % GQA_RATIO; + const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; + const scalar_t* q_ptr = q + query_start_off * q_stride + + global_qhead_idx * HEAD_SIZE + + rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD; + if (lane16id < GQA_RATIO) { + #pragma unroll + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + const scalar_t* q_fetch_ptr = q_ptr + qkhe_depth * QKHE_PER_FETCH; + const _B16x8* q_fetch_ptr_16B = + reinterpret_cast(q_fetch_ptr); + Qlocal[qkhe_depth] = *q_fetch_ptr_16B; + } + } + } else { + // fetch Q in shared across warps and then write to registers + const int local_qhead_idx = 2 * warpid + rowid; + const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; + const scalar_t* q_ptr = + q + query_start_off * q_stride + global_qhead_idx * HEAD_SIZE; + + const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B; + if ((local_qhead_idx < GQA_RATIO) && (qhead_element < HEAD_SIZE)) { + const scalar_t* q_fetch_ptr = q_ptr + qhead_element; + const _B16x8* q_fetch_ptr_16B = + reinterpret_cast(q_fetch_ptr); + _B16x8 tmp = *q_fetch_ptr_16B; + + const int offset1 = + lane16id / + 2; // 16 contiguous chunks of head elems are spread across 8x2lanes + shared_logits[offset1][lane2id][local_qhead_idx][0] = tmp; + } + + __syncthreads(); + + #pragma unroll + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + Qlocal[qkhe_depth] = + shared_logits[qkhe_depth][rowid][lane16id % 
GQA_RATIO][0]; + } + } + + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int last_ctx_block = num_context_blocks - 1; + + const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; + + int kphysical_block_number[TLOOP]; + + // fetch k physical block numbers + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int klocal_token_idx = + TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; + const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; + const int kblock_idx = (kglobal_token_idx < context_len) + ? kglobal_token_idx / BLOCK_SIZE + : last_ctx_block; + kphysical_block_number[token_depth] = block_table_seq[kblock_idx]; + } + + constexpr int KX = 16 / sizeof(cache_t); + const cache_t* k_ptr = k_cache + wg_start_kv_head_idx * kv_head_stride; + + const int row_head_elem = rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD; + + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int64_t kblock_number = + static_cast(kphysical_block_number[token_depth]); + const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride; + const int klocal_token_idx = + TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; + const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; + const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE; + const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX; + + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + const int head_elem = row_head_elem + qkhe_depth * QKHE_PER_FETCH; + const int offset1 = head_elem / KX; + const int offset2 = head_elem % KX; + const cache_t* k_fetch_ptr = k_ptr3 + offset1 * BLOCK_SIZE * KX + offset2; + const _B16x8* k_fetch_ptr_16B = + reinterpret_cast(k_fetch_ptr); + Klocal[token_depth][qkhe_depth] = *k_fetch_ptr_16B; + } + } + + constexpr int VTOKENS_PER_LANE = + TOKENS_PER_WARP / ROWS_PER_WARP; // 32/2 = 16 vtokens per lane + constexpr int VBLOCKS_PER_LANE = 1; // 
assumes block size >=16 + constexpr int VTLOOP = NWARPS; // corresponds to tokens across warps + constexpr int VTLANELOOP = DIVIDE_ROUND_UP( + VTOKENS_PER_LANE, + CONTIGUOUS_KV_ELEMS_16B_LOAD); // optimized for 16B fetches; assumes + // minimum block size is 16 + constexpr int VHELOOP = DIVIDE_ROUND_UP( + (HEAD_SIZE / 16), NWARPS); // head_size distributed across warps; each + // wmma instr works on 16 head elements + + int vphysical_block_number[VTLOOP][VBLOCKS_PER_LANE]; + + // fetch v physical block numbers + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE; + vblock_depth++) { + const int vlocal_token_idx = + vtoken_depth * VTOKENS_PER_LANE * ROWS_PER_WARP + + rowid * VTOKENS_PER_LANE + vblock_depth * BLOCK_SIZE; + const int vglobal_token_idx = + partition_start_token_idx + vlocal_token_idx; + const int vblock_idx = (vglobal_token_idx < context_len) + ? vglobal_token_idx / BLOCK_SIZE + : last_ctx_block; + vphysical_block_number[vtoken_depth][vblock_depth] = + block_table_seq[vblock_idx]; + } + } + + _B16x8 Vlocal[VTLOOP][VHELOOP] + [VTLANELOOP]; // this can be interpreted as B8x16 too + + const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride + + ((rowid * VTOKENS_PER_LANE) % BLOCK_SIZE); + + // v fetches are 16head elems across lanes x 16 tokens per lane + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + const int vhead_elem = vhe_depth * NWARPS * 16 + warpid * 16 + lane16id; + const cache_t* v_ptr2 = v_ptr + vhead_elem * BLOCK_SIZE; + + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + const int vblock_depth = 0; + const int64_t vblock_number = static_cast( + vphysical_block_number[vtoken_depth][vblock_depth]); + const cache_t* v_ptr3 = v_ptr2 + (vblock_number * kv_block_stride); + + const cache_t* v_fetch_ptr = + v_ptr3 + vfetch_depth * 
CONTIGUOUS_KV_ELEMS_16B_LOAD; + const _B16x8* v_fetch_ptr_16B = + reinterpret_cast(v_fetch_ptr); + Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = *v_fetch_ptr_16B; + } + } + } + + floatx8 dout[TLOOP]; + // qk wmma + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + dout[token_depth] = {0}; + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + dout[token_depth] = gcn_wmma16x16x16_instr( + Klocal[token_depth][qkhe_depth].u16x8, Qlocal[qkhe_depth].u16x8, + dout[token_depth]); + } + dout[token_depth] *= scale; + } + + // calculate qk_max and exp_sum per warp and write to shared memory + float qk_max = -FLT_MAX; + float exp_sum = 0.0f; + const int qkout_token_idx = + partition_start_token_idx + TOKENS_PER_WARP * warpid + rowid * 8; + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for (int i = 0; i < 8; i++) { + const float tmp = + (local_token_idx + i < context_len) ? dout[token_depth][i] : -FLT_MAX; + qk_max = fmaxf(qk_max, tmp); + } + } + + qk_max = fmaxf(qk_max, __shfl_xor(qk_max, 16)); + + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for (int i = 0; i < 8; i++) { + const float tmp = (local_token_idx + i < context_len) + ? 
__expf(dout[token_depth][i] - qk_max) + : 0.0f; + dout[token_depth][i] = tmp; + exp_sum += tmp; + } + } + + exp_sum += __shfl_xor(exp_sum, 16); + + __syncthreads(); + + if (laneid < 16) { + shared_qk_max[warpid][lane16id] = qk_max; + shared_exp_sum[warpid][lane16id] = exp_sum; + } + + __syncthreads(); + + // calculate partition qk_max and exp_sum + float partition_qk_max = -FLT_MAX; + float warp_qk_max_exp[NWARPS]; + float partition_exp_sum = 0.0f; + + #pragma unroll + for (int w = 0; w < NWARPS; w++) { + warp_qk_max_exp[w] = shared_qk_max[w][lane16id]; + partition_qk_max = fmaxf(partition_qk_max, warp_qk_max_exp[w]); + } + + for (int w = 0; w < NWARPS; w++) { + warp_qk_max_exp[w] = __expf(warp_qk_max_exp[w] - partition_qk_max); + partition_exp_sum += shared_exp_sum[w][lane16id] * warp_qk_max_exp[w]; + } + + const float inv_sum_scale = + __fdividef(1.f, partition_exp_sum + 1e-6f) * warp_qk_max_exp[warpid]; + + __syncthreads(); + + // write logits to shared mem + #pragma unroll + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + dout[token_depth] *= inv_sum_scale; + shared_logits[warpid][token_depth][lane16id][rowid] = + from_floatx8(dout[token_depth]); + } + + // write out partition max_logits and exp_sum + if (threadIdx.x < GQA_RATIO) { + const int qhead_idx = lane16id; + const int offset = seq_idx * total_num_heads * max_num_partitions + + (wg_start_head_idx + qhead_idx) * max_num_partitions + + partition_idx; + max_logits[offset] = partition_qk_max; + exp_sums[offset] = partition_exp_sum; + } + + __syncthreads(); + + _B16x8 outelems[VHELOOP]; + // Softmax V wmma + // v layout: 16he across lanes x 16 tokens per lane + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + floatx8 tmp_out = {0}; + + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + const int offset = rowid * VTLANELOOP + vfetch_depth; + const int offset1 = offset % ROWS_PER_WARP; + 
const int offset2 = offset / ROWS_PER_WARP; + // if output format is 16 qheads across 16 lanes, 16 head elems spread + // across rows + tmp_out = gcn_wmma16x16x16_instr( + Vlocal[vtoken_depth][vhe_depth][vfetch_depth].u16x8, + shared_logits[vtoken_depth][offset2][lane16id][offset1].u16x8, + tmp_out); + } + } + outelems[vhe_depth] = from_floatx8(tmp_out); + } + + __syncthreads(); + + #pragma unroll + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + shared_logits[warpid][vhe_depth][lane16id][rowid] = + outelems[vhe_depth]; // lane16 id head dimension; rowid head element + // dimension + } + + __syncthreads(); + + // write to tmp_out with coalesced writes after reading from shared mem + if (warpid == 0) { + _B16x8 vout[GQA_RATIO2]; + // each lane writes out 16Bytes of tmp_out along head elem dimension + const int head_elem_idx = lane16id * 8; + if (head_elem_idx < HEAD_SIZE) { + for (int h = 0; h < GQA_RATIO2; h++) { + const int local_head_idx = 2 * h + rowid; + const int offset1 = (head_elem_idx / 16) % NWARPS; + const int offset2 = head_elem_idx / 16 / NWARPS; + const int offset3 = (head_elem_idx / 8) % 2; // num_he % num_row + vout[h] = shared_logits[offset1][offset2][local_head_idx][offset3]; + } + + const int hsz_maxp_mult = HEAD_SIZE * max_num_partitions; + scalar_t* out_ptr = out + seq_idx * total_num_heads * hsz_maxp_mult + + partition_idx * HEAD_SIZE; + for (int h = 0; h < GQA_RATIO2; h++) { + const int local_head_idx = 2 * h + rowid; + if (local_head_idx < GQA_RATIO) { + const int out_head_idx = wg_start_head_idx + local_head_idx; + scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult; + scalar_t* out_ptr3 = out_ptr2 + head_elem_idx; + _B16x8* out_ptr_B16x8 = reinterpret_cast<_B16x8*>(out_ptr3); + *out_ptr_B16x8 = vout[h]; + } + } + } + } +} + +template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ 
k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ query_start_loc_ptr, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, const float* k_scale, const float* v_scale) { + UNREACHABLE_CODE +} + +// Grid: (num_heads, num_seqs). +template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( + OUTT* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ query_start_loc_ptr, // [num_seqs] + const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) { + const auto num_heads = gridDim.x; + const auto head_idx = blockIdx.x; + const auto seq_idx = blockIdx.y; + + // NOTE queries with sequence len > 1 are prefills and taken care by another + // kernel. 
+ if (query_start_loc_ptr != nullptr && + (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) { + return; + } + + const int context_len = context_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + [[maybe_unused]] constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int warpid = threadIdx.x / WARP_SIZE; + [[maybe_unused]] const int laneid = threadIdx.x % WARP_SIZE; + + __shared__ float shared_global_exp_sum; + // max num partitions supported is warp_size * NPAR_LOOPS + __shared__ float shared_exp_sums[NPAR_LOOPS * WARP_SIZE]; + + if (warpid == 0) { + const float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + + // partition indices >= num_partitions are clamped to the last valid + // partition + int valid_partition[NPAR_LOOPS]; + float reg_max_logit[NPAR_LOOPS]; + const int last_valid_partition = num_partitions - 1; + + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + valid_partition[i] = + (partition_no < num_partitions) ? 
partition_no : last_valid_partition; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + reg_max_logit[i] = max_logits_ptr[valid_partition[i]]; + } + float max_logit = reg_max_logit[0]; + #pragma unroll + for (int i = 1; i < NPAR_LOOPS; i++) { + max_logit = fmaxf(max_logit, reg_max_logit[i]); + } + + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask)); + } + + const float* exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + + float rescaled_exp_sum[NPAR_LOOPS]; + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + rescaled_exp_sum[i] = exp_sums_ptr[valid_partition[i]]; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + rescaled_exp_sum[i] *= (partition_no < num_partitions) + ? expf(reg_max_logit[i] - max_logit) + : 0.0f; + } + float global_exp_sum = rescaled_exp_sum[0]; + #pragma unroll + for (int i = 1; i < NPAR_LOOPS; i++) { + global_exp_sum += rescaled_exp_sum[i]; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + shared_exp_sums[partition_no] = rescaled_exp_sum[i]; + } + + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + global_exp_sum += __shfl_xor(global_exp_sum, mask); + } + if (threadIdx.x == 0) { + shared_global_exp_sum = global_exp_sum; + } + } // warpid == 0 + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x; + constexpr int MAX_NPAR = 32; + scalar_t tmps[MAX_NPAR]; + const float dzero = 0.0f; + #pragma unroll + for (int j = 0; j < MAX_NPAR; j++) { + tmps[j] = from_float(dzero); + } + const int last_partition_offset = (num_partitions - 1) * HEAD_SIZE; + const int num_partition_offset = (num_partitions)*HEAD_SIZE; + int idx = 0; + + constexpr 
int JCHUNK = 16; + + #pragma unroll + for (int j = 0; j < JCHUNK * HEAD_SIZE; j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + __syncthreads(); + + if (num_partitions > JCHUNK) { + #pragma unroll + for (int j = JCHUNK * HEAD_SIZE; j < 2 * JCHUNK * HEAD_SIZE; + j += HEAD_SIZE) { + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + + if (num_partitions > 2 * JCHUNK) { + #pragma unroll + for (int j = 2 * JCHUNK * HEAD_SIZE; j < MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + } + } // num_partitions > JCHUNK + + // Aggregate tmp_out to out. + float acc = 0.0f; + #pragma unroll + for (int j = 0; j < JCHUNK; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + if (num_partitions > JCHUNK) { + #pragma unroll + for (int j = JCHUNK; j < 2 * JCHUNK; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + if (num_partitions > 2 * JCHUNK) { + #pragma unroll + for (int j = 2 * JCHUNK; j < MAX_NPAR; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + } + } + + for (int p = 1; p < NPAR_LOOPS; p++) { + if (num_partitions > p * MAX_NPAR) { + idx = 0; + #pragma unroll + for (int j = p * MAX_NPAR * HEAD_SIZE; j < (p + 1) * MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? 
j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + + #pragma unroll + for (int j = 0; j < MAX_NPAR; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j + p * MAX_NPAR]; + } + } + } + + const float inv_global_exp_sum = + __fdividef(1.0f, shared_global_exp_sum + 1e-6f); + acc *= inv_global_exp_sum; + + const int64_t query_start_off = static_cast( + query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx); + OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE + + static_cast(head_idx) * HEAD_SIZE; + out_ptr[threadIdx.x] = from_float(acc); +} + +#else + +// clang-format off +template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int num_kv_heads, + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ query_start_loc_ptr, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, const float* k_scale, const float* v_scale) { + UNREACHABLE_CODE +} + +template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ 
k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int num_kv_heads, + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ query_start_loc_ptr, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, const float* k_scale, const float* v_scale) { + UNREACHABLE_CODE +} + +// Grid: (num_heads, num_seqs). 
+template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( + OUTT* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ query_start_loc_ptr, // [num_seqs] + const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) { + UNREACHABLE_CODE +} +// clang-format on + +#endif + +#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \ + paged_attention_ll4mi_QKV_mfma16_kernel \ + <<>>( \ + query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ + block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \ + max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \ + kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \ + max_ctx_blocks, k_scale_ptr, v_scale_ptr); + +#define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \ + paged_attention_ll4mi_QKV_mfma4_kernel \ + <<>>( \ + query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ + block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \ + max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \ + kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \ + max_ctx_blocks, k_scale_ptr, v_scale_ptr); + +#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \ + paged_attention_ll4mi_reduce_kernel \ + <<>>( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \ + context_lens_ptr, query_start_loc_ptr, max_num_partitions, \ + fp8_out_scale_ptr); + +template +void paged_attention_custom_launcher( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, const 
int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& context_lens, + const std::optional& query_start_loc, int max_context_len, + const std::optional& alibi_slopes, torch::Tensor& k_scale, + torch::Tensor& v_scale, const std::optional& fp8_out_scale) { + int num_seqs = block_tables.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + // NOTE: query start location is optional for V0 decode should not be used. + // If batch contains mix of prefills and decode, prefills should be skipped. + const int* query_start_loc_ptr = + query_start_loc + ? reinterpret_cast(query_start_loc.value().data_ptr()) + : nullptr; + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + KVT* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + KVT* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* context_lens_ptr = context_lens.data_ptr(); + const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); + const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); + // NOTE: fp8_out_scale is optional. + const auto fp8_out_scale_ptr = + fp8_out_scale + ? 
static_cast(fp8_out_scale.value().data_ptr()) + : nullptr; + OUTT* out_ptr = reinterpret_cast(out.data_ptr()); + + const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); + + // partition size is fixed at 256 since both mfma4 and mfma16 kernels support + // it mfma4 kernel also supports partition size 512 + constexpr int PARTITION_SIZE = 256; + const int max_num_partitions = + DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + const int gqa_ratio = num_heads / num_kv_heads; + assert(num_heads % num_kv_heads == 0); + assert(head_size == HEAD_SIZE); + + constexpr int NTHR = 256; + dim3 grid(num_seqs, max_num_partitions, num_kv_heads); + dim3 block(NTHR); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // mfma4 kernel is faster than mfma16 for gqa_ratio <= 4 + switch (gqa_ratio) { + case 1: + LAUNCH_CUSTOM_ATTENTION_MFMA4(1); + break; + case 2: + LAUNCH_CUSTOM_ATTENTION_MFMA4(2); + break; + case 3: + LAUNCH_CUSTOM_ATTENTION_MFMA4(3); + break; + case 4: + LAUNCH_CUSTOM_ATTENTION_MFMA4(4); + break; case 5: LAUNCH_CUSTOM_ATTENTION_MFMA16(5); break; @@ -1744,13 +3251,195 @@ void paged_attention_custom_launcher( } } -#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \ - PSIZE, ALIBI_ENABLED) \ - paged_attention_custom_launcher( \ - out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, context_lens, query_start_loc, \ - max_context_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); +template +void paged_attention_custom_launcher_navi( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, const int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& context_lens, + const std::optional& query_start_loc, int max_context_len, + const std::optional& 
alibi_slopes, torch::Tensor& k_scale, + torch::Tensor& v_scale) { + int num_seqs = block_tables.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + // NOTE: query start location is optional for V0 decode should not be used. + // If batch contains mix of prefills and decode, prefills should be skipped. + const int* query_start_loc_ptr = + query_start_loc + ? reinterpret_cast(query_start_loc.value().data_ptr()) + : nullptr; + + // NOTE: Navi does not support alibi_slopes. + const float* alibi_slopes_ptr = nullptr; + + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + KVT* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + KVT* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* context_lens_ptr = context_lens.data_ptr(); + + const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); + const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); + // NOTE: Navi does not support fp8. 
+ const auto fp8_out_scale_ptr = nullptr; + OUTT* out_ptr = reinterpret_cast(out.data_ptr()); + + const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); + + constexpr int PARTITION_SIZE = 256; + const int max_num_partitions = + DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + const int gqa_ratio = num_heads / num_kv_heads; + assert(num_heads % num_kv_heads == 0); + assert(head_size == HEAD_SIZE); + + constexpr int NTHR = 256; + dim3 grid(num_seqs, max_num_partitions, num_kv_heads); + dim3 block(NTHR); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + switch (gqa_ratio) { + case 1: + LAUNCH_CUSTOM_ATTENTION_MFMA16(1); + break; + case 2: + LAUNCH_CUSTOM_ATTENTION_MFMA16(2); + break; + case 3: + LAUNCH_CUSTOM_ATTENTION_MFMA16(3); + break; + case 4: + LAUNCH_CUSTOM_ATTENTION_MFMA16(4); + break; + case 5: + LAUNCH_CUSTOM_ATTENTION_MFMA16(5); + break; + case 6: + LAUNCH_CUSTOM_ATTENTION_MFMA16(6); + break; + case 7: + LAUNCH_CUSTOM_ATTENTION_MFMA16(7); + break; + case 8: + LAUNCH_CUSTOM_ATTENTION_MFMA16(8); + break; + case 9: + LAUNCH_CUSTOM_ATTENTION_MFMA16(9); + break; + case 10: + LAUNCH_CUSTOM_ATTENTION_MFMA16(10); + break; + case 11: + LAUNCH_CUSTOM_ATTENTION_MFMA16(11); + break; + case 12: + LAUNCH_CUSTOM_ATTENTION_MFMA16(12); + break; + case 13: + LAUNCH_CUSTOM_ATTENTION_MFMA16(13); + break; + case 14: + LAUNCH_CUSTOM_ATTENTION_MFMA16(14); + break; + case 15: + LAUNCH_CUSTOM_ATTENTION_MFMA16(15); + break; + case 16: + LAUNCH_CUSTOM_ATTENTION_MFMA16(16); + break; + default: + TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio); + break; + } + + dim3 reduce_grid(num_heads, num_seqs); + dim3 reduce_block(head_size); + const int warp_size = 32; + const int npar_loops = DIVIDE_ROUND_UP(max_num_partitions, warp_size); + // reduction kernel supports upto 16 NPAR_loops * 32 (warp_size) * 256 + // (partition size) = 128K context length + switch (npar_loops) { + 
case 1: + LAUNCH_CUSTOM_REDUCTION(1); + break; + case 2: + LAUNCH_CUSTOM_REDUCTION(2); + break; + case 3: + LAUNCH_CUSTOM_REDUCTION(3); + break; + case 4: + LAUNCH_CUSTOM_REDUCTION(4); + break; + case 5: + LAUNCH_CUSTOM_REDUCTION(5); + break; + case 6: + LAUNCH_CUSTOM_REDUCTION(6); + break; + case 7: + LAUNCH_CUSTOM_REDUCTION(7); + break; + case 8: + LAUNCH_CUSTOM_REDUCTION(8); + break; + case 9: + LAUNCH_CUSTOM_REDUCTION(9); + break; + case 10: + LAUNCH_CUSTOM_REDUCTION(10); + break; + case 11: + LAUNCH_CUSTOM_REDUCTION(11); + break; + case 12: + LAUNCH_CUSTOM_REDUCTION(12); + break; + case 13: + LAUNCH_CUSTOM_REDUCTION(13); + break; + case 14: + LAUNCH_CUSTOM_REDUCTION(14); + break; + case 15: + LAUNCH_CUSTOM_REDUCTION(15); + break; + case 16: + LAUNCH_CUSTOM_REDUCTION(16); + break; + default: + TORCH_CHECK(false, "Unsupported npar_loops: ", npar_loops); + break; + } +} + +#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \ + PSIZE, ALIBI_ENABLED) \ + if (!is_navi) { \ + paged_attention_custom_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, context_lens, query_start_loc, \ + max_context_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); \ + } else { \ + paged_attention_custom_launcher_navi< \ + T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, ALIBI_ENABLED>( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, context_lens, query_start_loc, \ + max_context_len, alibi_slopes, k_scale, v_scale); \ + } #define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ OUTT, PSIZE) \ @@ -1807,6 +3496,24 @@ void paged_attention_custom_launcher( break; \ } +bool is_navi_gpu() { + static bool is_cached = false; + static bool result; + + if (!is_cached) { + int device_id; + hipDeviceProp_t deviceProp; + hipGetDevice(&device_id); + hipGetDeviceProperties(&deviceProp, device_id); + + std::string arch = 
deviceProp.gcnArchName; + result = arch.find("gfx11") == 0 || arch.find("gfx12") == 0; + is_cached = true; + } + + return result; +} + // clang-format off void paged_attention( torch::Tensor& out, // [num_seqs, num_heads, head_size] @@ -1827,6 +3534,8 @@ void paged_attention( torch::Tensor& v_scale, const std::optional& fp8_out_scale) { // clang-format on + bool is_navi = is_navi_gpu(); + const int head_size = query.size(2); if (kv_cache_dtype == "auto") { if (query.dtype() == at::ScalarType::Half) { diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh index 9c8a50332ad0..c22523da4e43 100644 --- a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh +++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh @@ -8,6 +8,8 @@ #include +#include "cuda_utils.h" + #include "cutlass/cutlass.h" #include "cutlass/gemm/device/gemm_universal_adapter.h" @@ -95,9 +97,9 @@ struct cutlass_sparse_3x_gemm { // clang-format off using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, - ElementAB, cutlass::layout::RowMajor, AlignmentAB, - ElementAB, cutlass::layout::ColumnMajor, AlignmentAB, + cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, + ElementAB, cutlass::layout::RowMajor, AlignmentAB, + ElementAB, cutlass::layout::ColumnMajor, AlignmentAB, ElementAcc, TileShape, ClusterShape, Stages, KernelSchedule>::CollectiveOp; diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 7ca40a5e7827..371894c56a79 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -77,6 +77,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor suffix_output," " Tensor suffix_lse) -> ()"); ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states); + + ops.def( + "convert_vertical_slash_indexes(" + " Tensor! block_count, Tensor! block_offset, " + " Tensor! column_count, Tensor! 
column_index, " + " Tensor q_seqlens, Tensor q_seqlens, " + " Tensor vertical_indexes, Tensor slash_indexes, " + " int context_size, int block_size_M, int block_size_N, " + " bool causal) -> ()"); + ops.impl("convert_vertical_slash_indexes", torch::kCUDA, + &convert_vertical_slash_indexes); + + ops.def( + "convert_vertical_slash_indexes_mergehead(" + " Tensor! block_count, Tensor! block_offset, " + " Tensor! column_count, Tensor! column_index, " + " Tensor q_seqlens, Tensor q_seqlens, " + " Tensor vertical_indexes, Tensor slash_indexes, " + " Tensor vertical_indices_count, Tensor slash_indices_count, " + " int context_size, int block_size_M, int block_size_N, " + " bool causal) -> ()"); + ops.impl("convert_vertical_slash_indexes_mergehead", torch::kCUDA, + &convert_vertical_slash_indexes_mergehead); #endif // Activation ops @@ -292,8 +315,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // gptq_marlin Optimized Quantized GEMM for GPTQ. ops.def( "gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, " - "Tensor b_scales, Tensor? b_zeros_or_none, Tensor? g_idx_or_none, " - "Tensor? perm_or_none, Tensor workspace, int b_q_type, " + "Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? " + "g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, " "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, " "bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor", {stride_tag}); @@ -363,6 +386,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { {stride_tag}); ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm); + // cutlass nvfp4 block scaled group GEMM + ops.def( + "cutlass_fp4_group_mm(Tensor! 
out, Tensor a, Tensor b," + " Tensor a_blockscale, Tensor b_blockscales, Tensor alphas," + " Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()", + {stride_tag}); + ops.impl("cutlass_fp4_group_mm", torch::kCUDA, &cutlass_fp4_group_mm); + // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column // quantization, as well as bias ops.def( @@ -451,47 +482,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor page_table, float scale) -> ()"); ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode); - // Mamba selective scan kernel - ops.def( - "selective_scan_fwd(Tensor! u, Tensor! delta," - "Tensor! A, Tensor! B, Tensor! C," - "Tensor? D_, Tensor!? z_, Tensor? delta_bias_," - "bool delta_softplus," - "Tensor? query_start_loc," - "Tensor? cache_indices," - "Tensor? has_initial_state," - "Tensor! ssm_states," - "int pad_slot_id) -> ()"); - ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); - - ops.def( - "causal_conv1d_update(Tensor! x," - "Tensor! conv_state," - "Tensor! weight," - "Tensor? bias_," - "bool silu_activation," - "Tensor? cache_seqlens_," - "Tensor? conv_state_indices," - "int pad_slot_id) -> ()"); - ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update); - - ops.def( - "causal_conv1d_fwd(Tensor! x, Tensor! weight," - "Tensor? bias_," - "Tensor!? conv_states," - "Tensor? query_start_loc," - "Tensor? cache_indices," - "Tensor? has_initial_state," - "bool silu_activation," - "int pad_slot_id) -> ()"); - ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd); - // Compute NVFP4 block quantized tensor. ops.def( "scaled_fp4_quant(Tensor! output, Tensor input," " Tensor! output_scale, Tensor input_scale) -> ()"); ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant); + // Compute NVFP4 experts quantization. + ops.def( + "scaled_fp4_experts_quant(Tensor! output, Tensor! 
output_scale," + "Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts," + "Tensor output_scale_offset_by_experts) -> ()"); + ops.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant); + // Check if cutlass_scaled_mm_fp4 is supported for CUDA devices // of the given capability ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool"); @@ -546,6 +549,41 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("dynamic_scaled_int8_quant", torch::kCUDA, &dynamic_scaled_int8_quant); + // Mamba selective scan kernel + ops.def( + "selective_scan_fwd(Tensor! u, Tensor! delta," + "Tensor! A, Tensor! B, Tensor! C," + "Tensor? D_, Tensor!? z_, Tensor? delta_bias_," + "bool delta_softplus," + "Tensor? query_start_loc," + "Tensor? cache_indices," + "Tensor? has_initial_state," + "Tensor! ssm_states," + "int pad_slot_id) -> ()"); + ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); + + ops.def( + "causal_conv1d_update(Tensor! x," + "Tensor! conv_state," + "Tensor! weight," + "Tensor? bias_," + "bool silu_activation," + "Tensor? cache_seqlens_," + "Tensor? conv_state_indices," + "int pad_slot_id) -> ()"); + ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update); + + ops.def( + "causal_conv1d_fwd(Tensor! x, Tensor! weight," + "Tensor? bias_," + "Tensor!? conv_states," + "Tensor? query_start_loc," + "Tensor? cache_indices," + "Tensor? has_initial_state," + "bool silu_activation," + "int pad_slot_id) -> ()"); + ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd); + #ifndef USE_ROCM // reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel ops.def( diff --git a/docker/Dockerfile b/docker/Dockerfile index 17adb7a92dc1..24986a1b73b1 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -2,8 +2,8 @@ # to run the OpenAI compatible server. 
# Please update any changes made here to -# docs/source/contributing/dockerfile/dockerfile.md and -# docs/source/assets/contributing/dockerfile-stages-dependency.png +# docs/contributing/dockerfile/dockerfile.md and +# docs/assets/contributing/dockerfile-stages-dependency.png ARG CUDA_VERSION=12.8.1 #################### BASE BUILD IMAGE #################### @@ -77,7 +77,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ # can be useful for both `dev` and `test` # explicitly set the list to avoid issues with torch 2.2 # see https://github.com/pytorch/pytorch/pull/123243 -ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX' +ARG torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0 10.0+PTX' ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list} # Override the arch list for flash-attn to reduce the binary size ARG vllm_fa_cmake_gpu_arches='80-real;90-real' @@ -189,6 +189,8 @@ WORKDIR /vllm-workspace ENV DEBIAN_FRONTEND=noninteractive ARG TARGETPLATFORM +SHELL ["/bin/bash", "-c"] + RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \ echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment @@ -255,9 +257,17 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist RUN --mount=type=cache,target=/root/.cache/uv \ . /etc/environment && \ if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \ - # TESTING: install FlashInfer from source to test 2.7.0 final RC - FLASHINFER_ENABLE_AOT=1 TORCH_CUDA_ARCH_LIST='7.5 8.0 8.6 8.9 9.0+PTX' \ - uv pip install --system --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@v0.2.2.post1" ; \ + # FlashInfer already has a wheel for PyTorch 2.7.0 and CUDA 12.8. 
This is enough for CI use + if [[ "$CUDA_VERSION" == 12.8* ]]; then \ + uv pip install --system https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.5%2Bcu128torch2.7-cp38-abi3-linux_x86_64.whl; \ + else \ + export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX'; \ + CUDA_MAJOR="${CUDA_VERSION%%.*}"; \ + if [ "$CUDA_MAJOR" -lt 12 ]; then \ + export FLASHINFER_ENABLE_SM90=0; \ + fi; \ + uv pip install --system --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@21ea1d2545f74782b91eb8c08fd503ac4c0743fc" ; \ + fi \ fi COPY examples examples COPY benchmarks benchmarks @@ -267,7 +277,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ . /etc/environment && \ uv pip list -# Although we build Flashinfer with AOT mode, there's still +# Even when we build Flashinfer with AOT mode, there's still # some issues w.r.t. JIT compilation. Therefore we need to # install build dependencies for JIT compilation. # TODO: Remove this once FlashInfer AOT wheel is fixed @@ -295,8 +305,11 @@ RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system --no-build-isolation "git+https://github.com/state-spaces/mamba@v2.2.4" # install development dependencies (for testing) -RUN --mount=type=cache,target=/root/.cache/uv \ - uv pip install --system -r requirements/dev.txt +RUN --mount=type=cache,target=/root/.cache/uv \ + CUDA_MAJOR="${CUDA_VERSION%%.*}"; \ + if [ "$CUDA_MAJOR" -ge 12 ]; then \ + uv pip install --system -r requirements/dev.txt; \ + fi # install development dependencies (for testing) RUN --mount=type=cache,target=/root/.cache/uv \ @@ -315,7 +328,9 @@ COPY vllm/v1 /usr/local/lib/python3.12/dist-packages/vllm/v1 # will not be imported by other tests RUN mkdir test_docs RUN mv docs test_docs/ +RUN cp -r examples test_docs/ RUN mv vllm test_docs/ +RUN mv mkdocs.yaml test_docs/ #################### TEST IMAGE #################### #################### OPENAI API SERVER #################### diff --git a/docker/Dockerfile.cpu 
b/docker/Dockerfile.cpu index c647d9036f40..5395b3884fb5 100644 --- a/docker/Dockerfile.cpu +++ b/docker/Dockerfile.cpu @@ -51,9 +51,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --upgrade pip && \ uv pip install -r requirements/cpu.txt -RUN --mount=type=cache,target=/root/.cache/uv \ - uv pip install intel-openmp==2024.2.1 intel_extension_for_pytorch==2.6.0 - ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/opt/venv/lib/libiomp5.so:$LD_PRELOAD" RUN echo 'ulimit -c 0' >> ~/.bashrc diff --git a/docker/Dockerfile.neuron b/docker/Dockerfile.neuron index 2b63fe301bac..259dc5a23f78 100644 --- a/docker/Dockerfile.neuron +++ b/docker/Dockerfile.neuron @@ -1,6 +1,6 @@ # default base image # https://gallery.ecr.aws/neuron/pytorch-inference-neuronx -ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.5.1-neuronx-py310-sdk2.22.0-ubuntu22.04" +ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.6.0-neuronx-py310-sdk2.23.0-ubuntu22.04" FROM $BASE_IMAGE @@ -22,8 +22,7 @@ WORKDIR ${APP_MOUNT}/vllm RUN python3 -m pip install --upgrade pip RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas tenacity -RUN python3 -m pip install sentencepiece transformers==4.48.0 -U -RUN python3 -m pip install neuronx-cc==2.17.194.0 --extra-index-url=https://pip.repos.neuron.amazonaws.com -U +RUN python3 -m pip install neuronx-cc==2.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U RUN python3 -m pip install pytest # uninstall transformers-neuronx package explicitly to avoid version conflict @@ -49,6 +48,8 @@ RUN python3 -m pip install -e tests/vllm_test_utils # FIXME: `--no-deps` argument is temporarily added to resolve transformers package version conflict RUN python3 -m pip install transformers-neuronx==0.13.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U --no-deps +RUN python3 -m pip install sentencepiece transformers==4.48.0 -U + # overwrite entrypoint to run bash script RUN echo 
"import subprocess; import sys; subprocess.check_call(sys.argv[1:])" > /usr/local/bin/dockerd-entrypoint.py diff --git a/docker/Dockerfile.ppc64le b/docker/Dockerfile.ppc64le index ec979227871c..14043eb7a8e3 100644 --- a/docker/Dockerfile.ppc64le +++ b/docker/Dockerfile.ppc64le @@ -21,12 +21,8 @@ ENV UV_LINK_MODE=copy # Note: A dummy file 'control' is created in /tmp/ to artificially create dependencies between stages when building stages in parallel # when `--jobs=` is passed with podman build command RUN microdnf install -y openssl-devel dnf \ - && dnf install -y https://mirror.stream.centos.org/9-stream/BaseOS/`arch`/os/Packages/centos-gpg-keys-9.0-24.el9.noarch.rpm \ - https://mirror.stream.centos.org/9-stream/BaseOS/`arch`/os/Packages/centos-stream-repos-9.0-24.el9.noarch.rpm \ - https://dl.fedoraproject.org/pub/epel/epel-release-latest-9.noarch.rpm \ - && dnf config-manager --add-repo https://mirror.stream.centos.org/9-stream/BaseOS/`arch`/os \ - && dnf config-manager --add-repo https://mirror.stream.centos.org/9-stream/AppStream/`arch`/os \ - && dnf config-manager --set-enabled crb \ + && dnf install -y https://dl.fedoraproject.org/pub/epel/epel-release-latest-9.noarch.rpm \ + && dnf config-manager --set-enabled codeready-builder-for-rhel-9-ppc64le-rpms \ && dnf install -y \ git tar gcc-toolset-13 automake libtool numactl-devel lapack-devel \ pkgconfig xsimd zeromq-devel kmod findutils protobuf* \ diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base index 12009b8aa046..45efcbde698b 100644 --- a/docker/Dockerfile.rocm_base +++ b/docker/Dockerfile.rocm_base @@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" ARG FA_BRANCH="1a7f4dfa" ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" -ARG AITER_BRANCH="7e1ed08" +ARG AITER_BRANCH="c1debd8" ARG AITER_REPO="https://github.com/ROCm/aiter.git" FROM ${BASE_IMAGE} AS base diff --git 
a/docker/Dockerfile.s390x b/docker/Dockerfile.s390x index 9c10cd56b594..4e89bb3057c5 100644 --- a/docker/Dockerfile.s390x +++ b/docker/Dockerfile.s390x @@ -84,16 +84,40 @@ RUN curl https://sh.rustup.rs -sSf | sh -s -- -y && \ rustup default stable && \ rustup show +FROM python-install AS torch +ARG TORCH_VERSION=2.7.0 +ENV export _GLIBCXX_USE_CXX11_ABI=1 +ENV CARGO_HOME=/root/.cargo +ENV RUSTUP_HOME=/root/.rustup +ENV PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH" + +WORKDIR /tmp + +RUN --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,from=rust,source=/root/.cargo,target=/root/.cargo,rw \ + --mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \ + git clone https://github.com/pytorch/pytorch.git && \ + cd pytorch && \ + git checkout v2.7.0 && \ + git submodule sync && \ + git submodule update --init --recursive && \ + uv pip install cmake ninja && \ + uv pip install -r requirements.txt && \ + python setup.py bdist_wheel + + FROM python-install AS torch-vision # Install torchvision -ARG TORCH_VERSION=2.7.0.dev20250304 +ARG TORCH_VERSION=2.7.0 ARG TORCH_VISION_VERSION=v0.20.1 WORKDIR /tmp RUN --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,from=torch,source=/tmp/pytorch/dist,target=/tmp/torch-wheels/ \ git clone https://github.com/pytorch/vision.git && \ cd vision && \ git checkout $TORCH_VISION_VERSION && \ - uv pip install -v torch==${TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/nightly/cpu && \ + TORCH_WHL_FILE=$(ls /tmp/torch-wheels/*.whl | head -n 1) && \ + uv pip install -v $TORCH_WHL_FILE && \ python setup.py bdist_wheel FROM python-install AS hf-xet-builder @@ -138,15 +162,17 @@ RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,from=pyarrow,source=/tmp/arrow/python/dist,target=/tmp/arrow-wheels \ --mount=type=bind,from=torch-vision,source=/tmp/vision/dist,target=/tmp/vision-wheels/ \ --mount=type=bind,from=hf-xet-builder,source=/tmp/hf-xet/dist,target=/tmp/hf-xet-wheels/ \ + 
--mount=type=bind,from=torch,source=/tmp/pytorch/dist,target=/tmp/torch-wheels/ \ sed -i '/^torch/d' requirements/build.txt && \ ARROW_WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl | head -n 1) && \ VISION_WHL_FILE=$(ls /tmp/vision-wheels/*.whl | head -n 1) && \ HF_XET_WHL_FILE=$(ls /tmp/hf-xet-wheels/*.whl | head -n 1) && \ + TORCH_WHL_FILE=$(ls /tmp/torch-wheels/*.whl | head -n 1) && \ uv pip install -v \ $ARROW_WHL_FILE \ $VISION_WHL_FILE \ $HF_XET_WHL_FILE \ - --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ + $TORCH_WHL_FILE \ --index-strategy unsafe-best-match \ -r requirements/build.txt \ -r requirements/cpu.txt diff --git a/docs/.nav.yml b/docs/.nav.yml new file mode 100644 index 000000000000..42aba9775360 --- /dev/null +++ b/docs/.nav.yml @@ -0,0 +1,63 @@ +nav: + - Home: + - vLLM: README.md + - Getting Started: + - getting_started/quickstart.md + - getting_started/installation + - Examples: + - Offline Inference: examples/offline_inference + - Online Serving: examples/online_serving + - Others: examples/others + - Quick Links: + - User Guide: usage/README.md + - Developer Guide: contributing/README.md + - API Reference: api/README.md + - Timeline: + - Roadmap: https://roadmap.vllm.ai + - Releases: https://github.com/vllm-project/vllm/releases + - User Guide: + - Summary: usage/README.md + - usage/v1_guide.md + - General: + - usage/* + - Inference and Serving: + - serving/offline_inference.md + - serving/openai_compatible_server.md + - serving/* + - serving/integrations + - Deployment: + - deployment/* + - deployment/frameworks + - deployment/integrations + - Training: training + - Configuration: + - Summary: configuration/README.md + - configuration/* + - Models: + - models/supported_models.md + - models/generative_models.md + - models/pooling_models.md + - models/extensions + - Features: + - features/compatibility_matrix.md + - features/* + - features/quantization + - Developer Guide: + - Summary: contributing/README.md + - General: + - 
glob: contributing/* + flatten_single_child_sections: true + - Model Implementation: contributing/model + - Design Documents: + - V0: design + - V1: design/v1 + - API Reference: + - Summary: api/README.md + - Contents: + - glob: api/vllm/* + preserve_directory_names: true + - Community: + - community/* + - Blog: https://blog.vllm.ai + - Forum: https://discuss.vllm.ai + - Slack: https://slack.vllm.ai diff --git a/docs/Makefile b/docs/Makefile deleted file mode 100644 index d3b429dfb925..000000000000 --- a/docs/Makefile +++ /dev/null @@ -1,25 +0,0 @@ -# Minimal makefile for Sphinx documentation -# - -# You can set these variables from the command line, and also -# from the environment for the first two. -SPHINXOPTS ?= -SPHINXBUILD ?= sphinx-build -SOURCEDIR = source -BUILDDIR = build - -# Put it first so that "make" without argument is like "make help". -help: - @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) - -.PHONY: help Makefile - -# Catch-all target: route all unknown targets to Sphinx using the new -# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
-%: Makefile - @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) - -clean: - @$(SPHINXBUILD) -M clean "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) - rm -rf "$(SOURCEDIR)/getting_started/examples" - rm -rf "$(SOURCEDIR)/api/vllm" diff --git a/docs/README.md b/docs/README.md index dcd5e759dfa8..57b1d03deee2 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,43 +1,50 @@ -# vLLM documents - -## Build the docs - -- Make sure in `docs` directory - -```bash -cd docs -``` - -- Install the dependencies: - -```bash -pip install -r ../requirements/docs.txt -``` - -- Clean the previous build (optional but recommended): - -```bash -make clean -``` - -- Generate the HTML documentation: - -```bash -make html -``` - -## Open the docs with your browser - -- Serve the documentation locally: - -```bash -python -m http.server -d build/html/ -``` - -This will start a local server at http://localhost:8000. You can now open your browser and view the documentation. - -If port 8000 is already in use, you can specify a different port, for example: - -```bash -python -m http.server 3000 -d build/html/ -``` +# Welcome to vLLM + +
+ ![](./assets/logos/vllm-logo-text-light.png){ align="center" alt="vLLM" class="no-scaled-link" width="60%" } +
+ +

+Easy, fast, and cheap LLM serving for everyone + +

+ +

+ +Star +Watch +Fork +

+ +vLLM is a fast and easy-to-use library for LLM inference and serving. + +Originally developed in the [Sky Computing Lab](https://sky.cs.berkeley.edu) at UC Berkeley, vLLM has evolved into a community-driven project with contributions from both academia and industry. + +vLLM is fast with: + +- State-of-the-art serving throughput +- Efficient management of attention key and value memory with [**PagedAttention**](https://blog.vllm.ai/2023/06/20/vllm.html) +- Continuous batching of incoming requests +- Fast model execution with CUDA/HIP graph +- Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), INT4, INT8, and FP8 +- Optimized CUDA kernels, including integration with FlashAttention and FlashInfer. +- Speculative decoding +- Chunked prefill + +vLLM is flexible and easy to use with: + +- Seamless integration with popular HuggingFace models +- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more +- Tensor parallelism and pipeline parallelism support for distributed inference +- Streaming outputs +- OpenAI-compatible API server +- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs, Gaudi® accelerators and GPUs, IBM Power CPUs, TPU, and AWS Trainium and Inferentia Accelerators. +- Prefix caching support +- Multi-lora support + +For more information, check out the following: + +- [vLLM announcing blog post](https://vllm.ai) (intro to PagedAttention) +- [vLLM paper](https://arxiv.org/abs/2309.06180) (SOSP 2023) +- [How continuous batching enables 23x throughput in LLM inference while reducing p50 latency](https://www.anyscale.com/blog/continuous-batching-llm-inference) by Cade Daniel et al. 
+- [vLLM Meetups][meetups] diff --git a/docs/api/README.md b/docs/api/README.md new file mode 100644 index 000000000000..5c7b2ca79ee2 --- /dev/null +++ b/docs/api/README.md @@ -0,0 +1,107 @@ +# Summary + +[](){ #configuration } + +## Configuration + +API documentation for vLLM's configuration classes. + +- [vllm.config.ModelConfig][] +- [vllm.config.CacheConfig][] +- [vllm.config.TokenizerPoolConfig][] +- [vllm.config.LoadConfig][] +- [vllm.config.ParallelConfig][] +- [vllm.config.SchedulerConfig][] +- [vllm.config.DeviceConfig][] +- [vllm.config.SpeculativeConfig][] +- [vllm.config.LoRAConfig][] +- [vllm.config.PromptAdapterConfig][] +- [vllm.config.MultiModalConfig][] +- [vllm.config.PoolerConfig][] +- [vllm.config.DecodingConfig][] +- [vllm.config.ObservabilityConfig][] +- [vllm.config.KVTransferConfig][] +- [vllm.config.CompilationConfig][] +- [vllm.config.VllmConfig][] + +[](){ #offline-inference-api } + +## Offline Inference + +LLM Class. + +- [vllm.LLM][] + +LLM Inputs. + +- [vllm.inputs.PromptType][] +- [vllm.inputs.TextPrompt][] +- [vllm.inputs.TokensPrompt][] + +## vLLM Engines + +Engine classes for offline and online inference. + +- [vllm.LLMEngine][] +- [vllm.AsyncLLMEngine][] + +## Inference Parameters + +Inference parameters for vLLM APIs. + +[](){ #sampling-params } +[](){ #pooling-params } + +- [vllm.SamplingParams][] +- [vllm.PoolingParams][] + +[](){ #multi-modality } + +## Multi-Modality + +vLLM provides experimental support for multi-modal models through the [vllm.multimodal][] package. + +Multi-modal inputs can be passed alongside text and token prompts to [supported models][supported-mm-models] +via the `multi_modal_data` field in [vllm.inputs.PromptType][]. + +Looking to add your own multi-modal model? Please follow the instructions listed [here][supports-multimodal]. + +- [vllm.multimodal.MULTIMODAL_REGISTRY][] + +### Inputs + +User-facing inputs. + +- [vllm.multimodal.inputs.MultiModalDataDict][] + +Internal data structures. 
+ +- [vllm.multimodal.inputs.PlaceholderRange][] +- [vllm.multimodal.inputs.NestedTensors][] +- [vllm.multimodal.inputs.MultiModalFieldElem][] +- [vllm.multimodal.inputs.MultiModalFieldConfig][] +- [vllm.multimodal.inputs.MultiModalKwargsItem][] +- [vllm.multimodal.inputs.MultiModalKwargs][] +- [vllm.multimodal.inputs.MultiModalInputs][] + +### Data Parsing + +- [vllm.multimodal.parse][] + +### Data Processing + +- [vllm.multimodal.processing][] + +### Memory Profiling + +- [vllm.multimodal.profiling][] + +### Registry + +- [vllm.multimodal.registry][] + +## Model Development + +- [vllm.model_executor.models.interfaces_base][] +- [vllm.model_executor.models.interfaces][] +- [vllm.model_executor.models.adapters][] diff --git a/docs/api/vllm/.meta.yml b/docs/api/vllm/.meta.yml new file mode 100644 index 000000000000..c15adfec644c --- /dev/null +++ b/docs/api/vllm/.meta.yml @@ -0,0 +1,2 @@ +search: + boost: 0.5 diff --git a/docs/source/assets/contributing/dockerfile-stages-dependency.png b/docs/assets/contributing/dockerfile-stages-dependency.png similarity index 100% rename from docs/source/assets/contributing/dockerfile-stages-dependency.png rename to docs/assets/contributing/dockerfile-stages-dependency.png diff --git a/docs/source/assets/deployment/anything-llm-chat-with-doc.png b/docs/assets/deployment/anything-llm-chat-with-doc.png similarity index 100% rename from docs/source/assets/deployment/anything-llm-chat-with-doc.png rename to docs/assets/deployment/anything-llm-chat-with-doc.png diff --git a/docs/source/assets/deployment/anything-llm-chat-without-doc.png b/docs/assets/deployment/anything-llm-chat-without-doc.png similarity index 100% rename from docs/source/assets/deployment/anything-llm-chat-without-doc.png rename to docs/assets/deployment/anything-llm-chat-without-doc.png diff --git a/docs/source/assets/deployment/anything-llm-provider.png b/docs/assets/deployment/anything-llm-provider.png similarity index 100% rename from 
docs/source/assets/deployment/anything-llm-provider.png rename to docs/assets/deployment/anything-llm-provider.png diff --git a/docs/source/assets/deployment/anything-llm-upload-doc.png b/docs/assets/deployment/anything-llm-upload-doc.png similarity index 100% rename from docs/source/assets/deployment/anything-llm-upload-doc.png rename to docs/assets/deployment/anything-llm-upload-doc.png diff --git a/docs/source/assets/deployment/architecture_helm_deployment.png b/docs/assets/deployment/architecture_helm_deployment.png similarity index 100% rename from docs/source/assets/deployment/architecture_helm_deployment.png rename to docs/assets/deployment/architecture_helm_deployment.png diff --git a/docs/source/assets/deployment/chatbox-chat.png b/docs/assets/deployment/chatbox-chat.png similarity index 100% rename from docs/source/assets/deployment/chatbox-chat.png rename to docs/assets/deployment/chatbox-chat.png diff --git a/docs/source/assets/deployment/chatbox-settings.png b/docs/assets/deployment/chatbox-settings.png similarity index 100% rename from docs/source/assets/deployment/chatbox-settings.png rename to docs/assets/deployment/chatbox-settings.png diff --git a/docs/assets/deployment/dify-chat.png b/docs/assets/deployment/dify-chat.png new file mode 100644 index 000000000000..dfea23309c1c Binary files /dev/null and b/docs/assets/deployment/dify-chat.png differ diff --git a/docs/assets/deployment/dify-create-chatbot.png b/docs/assets/deployment/dify-create-chatbot.png new file mode 100644 index 000000000000..07bbde5ba285 Binary files /dev/null and b/docs/assets/deployment/dify-create-chatbot.png differ diff --git a/docs/assets/deployment/dify-settings.png b/docs/assets/deployment/dify-settings.png new file mode 100644 index 000000000000..7900cc774741 Binary files /dev/null and b/docs/assets/deployment/dify-settings.png differ diff --git a/docs/source/assets/deployment/open_webui.png b/docs/assets/deployment/open_webui.png similarity index 100% rename from 
docs/source/assets/deployment/open_webui.png rename to docs/assets/deployment/open_webui.png diff --git a/docs/source/assets/deployment/streamlit-chat.png b/docs/assets/deployment/streamlit-chat.png similarity index 100% rename from docs/source/assets/deployment/streamlit-chat.png rename to docs/assets/deployment/streamlit-chat.png diff --git a/docs/source/assets/design/arch_overview/entrypoints.excalidraw.png b/docs/assets/design/arch_overview/entrypoints.excalidraw.png similarity index 100% rename from docs/source/assets/design/arch_overview/entrypoints.excalidraw.png rename to docs/assets/design/arch_overview/entrypoints.excalidraw.png diff --git a/docs/source/assets/design/arch_overview/llm_engine.excalidraw.png b/docs/assets/design/arch_overview/llm_engine.excalidraw.png similarity index 100% rename from docs/source/assets/design/arch_overview/llm_engine.excalidraw.png rename to docs/assets/design/arch_overview/llm_engine.excalidraw.png diff --git a/docs/source/assets/design/hierarchy.png b/docs/assets/design/hierarchy.png similarity index 100% rename from docs/source/assets/design/hierarchy.png rename to docs/assets/design/hierarchy.png diff --git a/docs/source/assets/design/v1/metrics/intervals-1.png b/docs/assets/design/v1/metrics/intervals-1.png similarity index 100% rename from docs/source/assets/design/v1/metrics/intervals-1.png rename to docs/assets/design/v1/metrics/intervals-1.png diff --git a/docs/source/assets/design/v1/metrics/intervals-2.png b/docs/assets/design/v1/metrics/intervals-2.png similarity index 100% rename from docs/source/assets/design/v1/metrics/intervals-2.png rename to docs/assets/design/v1/metrics/intervals-2.png diff --git a/docs/source/assets/design/v1/metrics/intervals-3.png b/docs/assets/design/v1/metrics/intervals-3.png similarity index 100% rename from docs/source/assets/design/v1/metrics/intervals-3.png rename to docs/assets/design/v1/metrics/intervals-3.png diff --git 
a/docs/source/assets/design/v1/prefix_caching/example-time-1.png b/docs/assets/design/v1/prefix_caching/example-time-1.png similarity index 100% rename from docs/source/assets/design/v1/prefix_caching/example-time-1.png rename to docs/assets/design/v1/prefix_caching/example-time-1.png diff --git a/docs/source/assets/design/v1/prefix_caching/example-time-3.png b/docs/assets/design/v1/prefix_caching/example-time-3.png similarity index 100% rename from docs/source/assets/design/v1/prefix_caching/example-time-3.png rename to docs/assets/design/v1/prefix_caching/example-time-3.png diff --git a/docs/source/assets/design/v1/prefix_caching/example-time-4.png b/docs/assets/design/v1/prefix_caching/example-time-4.png similarity index 100% rename from docs/source/assets/design/v1/prefix_caching/example-time-4.png rename to docs/assets/design/v1/prefix_caching/example-time-4.png diff --git a/docs/source/assets/design/v1/prefix_caching/example-time-5.png b/docs/assets/design/v1/prefix_caching/example-time-5.png similarity index 100% rename from docs/source/assets/design/v1/prefix_caching/example-time-5.png rename to docs/assets/design/v1/prefix_caching/example-time-5.png diff --git a/docs/source/assets/design/v1/prefix_caching/example-time-6.png b/docs/assets/design/v1/prefix_caching/example-time-6.png similarity index 100% rename from docs/source/assets/design/v1/prefix_caching/example-time-6.png rename to docs/assets/design/v1/prefix_caching/example-time-6.png diff --git a/docs/source/assets/design/v1/prefix_caching/example-time-7.png b/docs/assets/design/v1/prefix_caching/example-time-7.png similarity index 100% rename from docs/source/assets/design/v1/prefix_caching/example-time-7.png rename to docs/assets/design/v1/prefix_caching/example-time-7.png diff --git a/docs/source/assets/design/v1/prefix_caching/free.png b/docs/assets/design/v1/prefix_caching/free.png similarity index 100% rename from docs/source/assets/design/v1/prefix_caching/free.png rename to 
docs/assets/design/v1/prefix_caching/free.png diff --git a/docs/source/assets/design/v1/prefix_caching/overview.png b/docs/assets/design/v1/prefix_caching/overview.png similarity index 100% rename from docs/source/assets/design/v1/prefix_caching/overview.png rename to docs/assets/design/v1/prefix_caching/overview.png diff --git a/docs/source/assets/features/disagg_prefill/abstraction.jpg b/docs/assets/features/disagg_prefill/abstraction.jpg similarity index 100% rename from docs/source/assets/features/disagg_prefill/abstraction.jpg rename to docs/assets/features/disagg_prefill/abstraction.jpg diff --git a/docs/source/assets/features/disagg_prefill/overview.jpg b/docs/assets/features/disagg_prefill/overview.jpg similarity index 100% rename from docs/source/assets/features/disagg_prefill/overview.jpg rename to docs/assets/features/disagg_prefill/overview.jpg diff --git a/docs/source/assets/kernel/k_vecs.png b/docs/assets/kernel/k_vecs.png similarity index 100% rename from docs/source/assets/kernel/k_vecs.png rename to docs/assets/kernel/k_vecs.png diff --git a/docs/source/assets/kernel/key.png b/docs/assets/kernel/key.png similarity index 100% rename from docs/source/assets/kernel/key.png rename to docs/assets/kernel/key.png diff --git a/docs/source/assets/kernel/logits_vec.png b/docs/assets/kernel/logits_vec.png similarity index 100% rename from docs/source/assets/kernel/logits_vec.png rename to docs/assets/kernel/logits_vec.png diff --git a/docs/source/assets/kernel/q_vecs.png b/docs/assets/kernel/q_vecs.png similarity index 100% rename from docs/source/assets/kernel/q_vecs.png rename to docs/assets/kernel/q_vecs.png diff --git a/docs/source/assets/kernel/query.png b/docs/assets/kernel/query.png similarity index 100% rename from docs/source/assets/kernel/query.png rename to docs/assets/kernel/query.png diff --git a/docs/source/assets/kernel/v_vec.png b/docs/assets/kernel/v_vec.png similarity index 100% rename from docs/source/assets/kernel/v_vec.png rename to 
docs/assets/kernel/v_vec.png diff --git a/docs/source/assets/kernel/value.png b/docs/assets/kernel/value.png similarity index 100% rename from docs/source/assets/kernel/value.png rename to docs/assets/kernel/value.png diff --git a/docs/source/assets/logos/vllm-logo-only-light.ico b/docs/assets/logos/vllm-logo-only-light.ico similarity index 100% rename from docs/source/assets/logos/vllm-logo-only-light.ico rename to docs/assets/logos/vllm-logo-only-light.ico diff --git a/docs/source/assets/logos/vllm-logo-only-light.png b/docs/assets/logos/vllm-logo-only-light.png similarity index 100% rename from docs/source/assets/logos/vllm-logo-only-light.png rename to docs/assets/logos/vllm-logo-only-light.png diff --git a/docs/source/assets/logos/vllm-logo-text-dark.png b/docs/assets/logos/vllm-logo-text-dark.png similarity index 100% rename from docs/source/assets/logos/vllm-logo-text-dark.png rename to docs/assets/logos/vllm-logo-text-dark.png diff --git a/docs/source/assets/logos/vllm-logo-text-light.png b/docs/assets/logos/vllm-logo-text-light.png similarity index 100% rename from docs/source/assets/logos/vllm-logo-text-light.png rename to docs/assets/logos/vllm-logo-text-light.png diff --git a/docs/source/community/meetups.md b/docs/community/meetups.md similarity index 94% rename from docs/source/community/meetups.md rename to docs/community/meetups.md index 085918bed2b0..8ea42e3cad18 100644 --- a/docs/source/community/meetups.md +++ b/docs/community/meetups.md @@ -1,9 +1,11 @@ -(meetups)= - -# vLLM Meetups +--- +title: Meetups +--- +[](){ #meetups } We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below: +- [NYC vLLM Meetup](https://lu.ma/c1rqyf1f), May 7th, 2025. 
[[Slides]](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing) - [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day), April 3rd 2025. [[Slides]](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing). - [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama), March 27th 2025. [[Slides]](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing). - [The first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg), March 16th 2025. [[Slides]](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing). diff --git a/docs/source/community/sponsors.md b/docs/community/sponsors.md similarity index 100% rename from docs/source/community/sponsors.md rename to docs/community/sponsors.md diff --git a/docs/configuration/README.md b/docs/configuration/README.md new file mode 100644 index 000000000000..6a8fbc79f4af --- /dev/null +++ b/docs/configuration/README.md @@ -0,0 +1,9 @@ +# Configuration Options + +This section lists the most common options for running vLLM. + +There are three main levels of configuration, from highest priority to lowest priority: + +- [Request parameters][completions-api] and [input arguments][sampling-params] +- [Engine arguments](./engine_args.md) +- [Environment variables](./env_vars.md) diff --git a/docs/configuration/conserving_memory.md b/docs/configuration/conserving_memory.md new file mode 100644 index 000000000000..a1283a503a6d --- /dev/null +++ b/docs/configuration/conserving_memory.md @@ -0,0 +1,144 @@ +# Conserving Memory + +Large models might cause your machine to run out of memory (OOM). Here are some options that help alleviate this problem. 
+ +## Tensor Parallelism (TP) + +Tensor parallelism (`tensor_parallel_size` option) can be used to split the model across multiple GPUs. + +The following code splits the model across 2 GPUs. + +```python +from vllm import LLM + +llm = LLM(model="ibm-granite/granite-3.1-8b-instruct", + tensor_parallel_size=2) +``` + +!!! warning + To ensure that vLLM initializes CUDA correctly, you should avoid calling related functions (e.g. [torch.cuda.set_device][]) + before initializing vLLM. Otherwise, you may run into an error like `RuntimeError: Cannot re-initialize CUDA in forked subprocess`. + + To control which devices are used, please instead set the `CUDA_VISIBLE_DEVICES` environment variable. + +!!! note + With tensor parallelism enabled, each process will read the whole model and split it into chunks, which makes the disk reading time even longer (proportional to the size of tensor parallelism). + + You can convert the model checkpoint to a sharded checkpoint using <gh-file:examples/offline_inference/save_sharded_state.py>. The conversion process might take some time, but later you can load the sharded checkpoint much faster. The model loading time should remain constant regardless of the size of tensor parallelism. + +## Quantization + +Quantized models take less memory at the cost of lower precision. + +Statically quantized models can be downloaded from HF Hub (some popular ones are available at [Red Hat AI](https://huggingface.co/RedHatAI)) +and used directly without extra configuration. + +Dynamic quantization is also supported via the `quantization` option -- see [here][quantization-index] for more details. + +## Context length and batch size + +You can further reduce memory usage by limiting the context length of the model (`max_model_len` option) +and the maximum batch size (`max_num_seqs` option). 
+ +```python +from vllm import LLM + +llm = LLM(model="adept/fuyu-8b", + max_model_len=2048, + max_num_seqs=2) +``` + +## Reduce CUDA Graphs + +By default, we optimize model inference using CUDA graphs which take up extra memory in the GPU. + +!!! warning + CUDA graph capture takes up more memory in V1 than in V0. + +You can adjust `compilation_config` to achieve a better balance between inference speed and memory usage: + +```python +from vllm import LLM +from vllm.config import CompilationConfig, CompilationLevel + +llm = LLM( + model="meta-llama/Llama-3.1-8B-Instruct", + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + # By default, it goes up to max_num_seqs + cudagraph_capture_sizes=[1, 2, 4, 8, 16], + ), +) +``` + +You can disable graph capturing completely via the `enforce_eager` flag: + +```python +from vllm import LLM + +llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", + enforce_eager=True) +``` + +## Adjust cache size + +If you run out of CPU RAM, try the following options: + +- (Multi-modal models only) you can set the size of multi-modal input cache using `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB). +- (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB). + +## Multi-modal input limits + +You can allow a smaller number of multi-modal items per prompt to reduce the memory footprint of the model: + +```python +from vllm import LLM + +# Accept up to 3 images and 1 video per prompt +llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", + limit_mm_per_prompt={"image": 3, "video": 1}) +``` + +You can go a step further and disable unused modalities completely by setting its limit to zero. +For example, if your application only accepts image input, there is no need to allocate any memory for videos. 
+ +```python +from vllm import LLM + +# Accept any number of images but no videos +llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", + limit_mm_per_prompt={"video": 0}) +``` + +You can even run a multi-modal model for text-only inference: + +```python +from vllm import LLM + +# Don't accept images. Just text. +llm = LLM(model="google/gemma-3-27b-it", + limit_mm_per_prompt={"image": 0}) +``` + +## Multi-modal processor arguments + +For certain models, you can adjust the multi-modal processor arguments to +reduce the size of the processed multi-modal inputs, which in turn saves memory. + +Here are some examples: + +```python +from vllm import LLM + +# Available for Qwen2-VL series models +llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", + mm_processor_kwargs={ + "max_pixels": 768 * 768, # Default is 1280 * 28 * 28 + }) + +# Available for InternVL series models +llm = LLM(model="OpenGVLab/InternVL2-2B", + mm_processor_kwargs={ + "max_dynamic_patch": 4, # Default is 12 + }) +``` diff --git a/docs/configuration/engine_args.md b/docs/configuration/engine_args.md new file mode 100644 index 000000000000..fb2689a56391 --- /dev/null +++ b/docs/configuration/engine_args.md @@ -0,0 +1,18 @@ +--- +title: Engine Arguments +--- +[](){ #engine-args } + +Engine arguments control the behavior of the vLLM engine. + +- For [offline inference][offline-inference], they are part of the arguments to [LLM][vllm.LLM] class. +- For [online serving][openai-compatible-server], they are part of the arguments to `vllm serve`. + +You can look at [EngineArgs][vllm.engine.arg_utils.EngineArgs] and [AsyncEngineArgs][vllm.engine.arg_utils.AsyncEngineArgs] to see the available engine arguments. + +However, these classes are a combination of the configuration classes defined in [vllm.config][]. Therefore, we would recommend you read about them there where they are best documented. 
+ +For offline inference you will have access to these configuration classes, and for online serving you can cross-reference the configs with `vllm serve --help`, which has its arguments grouped by config. + +!!! note + Additional arguments are available to the [AsyncLLMEngine][vllm.engine.async_llm_engine.AsyncLLMEngine] which is used for online serving. These can be found by running `vllm serve --help`. diff --git a/docs/configuration/env_vars.md b/docs/configuration/env_vars.md new file mode 100644 index 000000000000..f6d548a19d91 --- /dev/null +++ b/docs/configuration/env_vars.md @@ -0,0 +1,12 @@ +# Environment Variables + +vLLM uses the following environment variables to configure the system: + +!!! warning + Please note that `VLLM_PORT` and `VLLM_HOST_IP` set the port and IP for vLLM's **internal usage**. They are not the port and IP for the API server. If you use `--host $VLLM_HOST_IP` and `--port $VLLM_PORT` to start the API server, it will not work. + + All environment variables used by vLLM are prefixed with `VLLM_`. **Special care should be taken for Kubernetes users**: please do not name the service as `vllm`, otherwise environment variables set by Kubernetes might conflict with vLLM's environment variables, because [Kubernetes sets environment variables for each service with the capitalized service name as the prefix](https://kubernetes.io/docs/concepts/services-networking/service/#environment-variables). + +```python +--8<-- "vllm/envs.py:env-vars-definition" +``` diff --git a/docs/configuration/model_resolution.md b/docs/configuration/model_resolution.md new file mode 100644 index 000000000000..8757c257d3e9 --- /dev/null +++ b/docs/configuration/model_resolution.md @@ -0,0 +1,23 @@ +# Model Resolution + +vLLM loads HuggingFace-compatible models by inspecting the `architectures` field in `config.json` of the model repository +and finding the corresponding implementation that is registered to vLLM. 
+Nevertheless, our model resolution may fail for the following reasons: + +- The `config.json` of the model repository lacks the `architectures` field. +- Unofficial repositories refer to a model using alternative names which are not recorded in vLLM. +- The same architecture name is used for multiple models, creating ambiguity as to which model should be loaded. + +To fix this, explicitly specify the model architecture by passing `config.json` overrides to the `hf_overrides` option. +For example: + +```python +from vllm import LLM + +model = LLM( + model="cerebras/Cerebras-GPT-1.3B", + hf_overrides={"architectures": ["GPT2LMHeadModel"]}, # GPT-2 +) +``` + +Our [list of supported models][supported-models] shows the model architectures that are recognized by vLLM. diff --git a/docs/source/performance/optimization.md b/docs/configuration/optimization.md similarity index 99% rename from docs/source/performance/optimization.md rename to docs/configuration/optimization.md index 4160f0784962..811925c19e63 100644 --- a/docs/source/performance/optimization.md +++ b/docs/configuration/optimization.md @@ -1,5 +1,3 @@ -(optimization-and-tuning)= - # Optimization and Tuning This guide covers optimization strategies and performance tuning for vLLM V1. @@ -26,7 +24,7 @@ You can monitor the number of preemption requests through Prometheus metrics exp In vLLM V1, the default preemption mode is `RECOMPUTE` rather than `SWAP`, as recomputation has lower overhead in the V1 architecture. -(chunked-prefill)= +[](){ #chunked-prefill } ## Chunked Prefill diff --git a/docs/configuration/serve_args.md b/docs/configuration/serve_args.md new file mode 100644 index 000000000000..16b4b29f45d9 --- /dev/null +++ b/docs/configuration/serve_args.md @@ -0,0 +1,38 @@ +--- +title: Server Arguments +--- +[](){ #serve-args } + +The `vllm serve` command is used to launch the OpenAI-compatible server. + +## CLI Arguments + +The `vllm serve` command is used to launch the OpenAI-compatible server. 
+To see the available CLI arguments, run `vllm serve --help`! + +## Configuration file + +You can load CLI arguments via a [YAML](https://yaml.org/) config file. +The argument names must be the long form of those outlined [above][serve-args]. + +For example: + +```yaml +# config.yaml + +model: meta-llama/Llama-3.1-8B-Instruct +host: "127.0.0.1" +port: 6379 +uvicorn-log-level: "info" +``` + +To use the above config file: + +```bash +vllm serve --config config.yaml +``` + +!!! note + In case an argument is supplied simultaneously using command line and the config file, the value from the command line will take precedence. + The order of priorities is `command line > config file values > defaults`. + e.g. `vllm serve SOME_MODEL --config config.yaml`, SOME_MODEL takes precedence over `model` in config file. diff --git a/docs/source/contributing/overview.md b/docs/contributing/README.md similarity index 84% rename from docs/source/contributing/overview.md rename to docs/contributing/README.md index 89b31f0311e2..2517436afcc1 100644 --- a/docs/source/contributing/overview.md +++ b/docs/contributing/README.md @@ -16,9 +16,9 @@ Finally, one of the most impactful ways to support us is by raising awareness ab Unsure on where to start? Check out the following links for tasks to work on: - [Good first issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue%20state%3Aopen%20label%3A%22good%20first%20issue%22) - - [Selected onboarding tasks](gh-project:6) + - [Selected onboarding tasks](gh-project:6) - [New model requests](https://github.com/vllm-project/vllm/issues?q=is%3Aissue%20state%3Aopen%20label%3A%22new-model%22) - - [Models with multi-modal capabilities](gh-project:10) + - [Models with multi-modal capabilities](gh-project:10) ## License @@ -27,7 +27,21 @@ See . ## Developing Depending on the kind of development you'd like to do (e.g. Python, CUDA), you can choose to build vLLM with or without compilation. 
-Check out the [building from source](#build-from-source) documentation for details. +Check out the [building from source][build-from-source] documentation for details. + +### Building the docs + +Install the dependencies: + +```bash +pip install -r requirements/docs.txt +``` + +Start the autoreloading MkDocs server: + +```bash +mkdocs serve +``` ## Testing @@ -48,29 +62,25 @@ pre-commit run mypy-3.9 --hook-stage manual --all-files pytest tests/ ``` -:::{tip} -Since the ships with Python 3.12, all tests in CI (except `mypy`) are run with Python 3.12. +!!! tip + Since the ships with Python 3.12, all tests in CI (except `mypy`) are run with Python 3.12. -Therefore, we recommend developing with Python 3.12 to minimise the chance of your local environment clashing with our CI environment. -::: + Therefore, we recommend developing with Python 3.12 to minimise the chance of your local environment clashing with our CI environment. -:::{note} -Currently, the repository is not fully checked by `mypy`. -::: +!!! note + Currently, the repository is not fully checked by `mypy`. -:::{note} -Currently, not all unit tests pass when run on CPU platforms. If you don't have access to a GPU -platform to run unit tests locally, rely on the continuous integration system to run the tests for -now. -::: +!!! note + Currently, not all unit tests pass when run on CPU platforms. If you don't have access to a GPU + platform to run unit tests locally, rely on the continuous integration system to run the tests for + now. ## Issues If you encounter a bug or have a feature request, please [search existing issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue) first to see if it has already been reported. If not, please [file a new issue](https://github.com/vllm-project/vllm/issues/new/choose), providing as much relevant information as possible. 
-:::{important} -If you discover a security vulnerability, please follow the instructions [here](gh-file:SECURITY.md#reporting-a-vulnerability). -::: +!!! warning + If you discover a security vulnerability, please follow the instructions [here](gh-file:SECURITY.md#reporting-a-vulnerability). ## Pull Requests & Code Reviews @@ -106,9 +116,8 @@ appropriately to indicate the type of change. Please use one of the following: - `[Misc]` for PRs that do not fit the above categories. Please use this sparingly. -:::{note} -If the PR spans more than one category, please include all relevant prefixes. -::: +!!! note + If the PR spans more than one category, please include all relevant prefixes. ### Code Quality @@ -121,9 +130,8 @@ The PR needs to meet the following code quality standards: understand the code. - Include sufficient tests to ensure the project stays correct and robust. This includes both unit tests and integration tests. -- Please add documentation to `docs/source/` if the PR modifies the - user-facing behaviors of vLLM. It helps vLLM users understand and utilize the - new features or changes. +- Please add documentation to `docs/` if the PR modifies the user-facing behaviors of vLLM. + It helps vLLM users understand and utilize the new features or changes. 
### Adding or Changing Kernels diff --git a/docs/source/performance/benchmarks.md b/docs/contributing/benchmarks.md similarity index 86% rename from docs/source/performance/benchmarks.md rename to docs/contributing/benchmarks.md index 39dc470a1c70..00505fc6f2a9 100644 --- a/docs/source/performance/benchmarks.md +++ b/docs/contributing/benchmarks.md @@ -1,13 +1,14 @@ -(benchmarks)= - -# Benchmark Suites +--- +title: Benchmark Suites +--- +[](){ #benchmarks } vLLM contains two sets of benchmarks: -- [Performance benchmarks](#performance-benchmarks) -- [Nightly benchmarks](#nightly-benchmarks) +- [Performance benchmarks][performance-benchmarks] +- [Nightly benchmarks][nightly-benchmarks] -(performance-benchmarks)= +[](){ #performance-benchmarks } ## Performance Benchmarks @@ -17,7 +18,7 @@ The latest performance results are hosted on the public [vLLM Performance Dashbo More information on the performance benchmarks and their parameters can be found [here](gh-file:.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md). -(nightly-benchmarks)= +[](){ #nightly-benchmarks } ## Nightly Benchmarks diff --git a/docs/source/contributing/deprecation_policy.md b/docs/contributing/deprecation_policy.md similarity index 100% rename from docs/source/contributing/deprecation_policy.md rename to docs/contributing/deprecation_policy.md diff --git a/docs/source/contributing/dockerfile/dockerfile.md b/docs/contributing/dockerfile/dockerfile.md similarity index 82% rename from docs/source/contributing/dockerfile/dockerfile.md rename to docs/contributing/dockerfile/dockerfile.md index 90b9a33cfbe6..a39f335c87b8 100644 --- a/docs/source/contributing/dockerfile/dockerfile.md +++ b/docs/contributing/dockerfile/dockerfile.md @@ -1,7 +1,7 @@ # Dockerfile We provide a to construct the image for running an OpenAI compatible server with vLLM. -More information about deploying with Docker can be found [here](#deployment-docker). 
+More information about deploying with Docker can be found [here][deployment-docker]. Below is a visual representation of the multi-stage Dockerfile. The build graph contains the following nodes: @@ -17,18 +17,21 @@ The edges of the build graph represent: - `RUN --mount=(.\*)from=...` dependencies (with a dotted line and an empty diamond arrow head) - > :::{figure} /assets/contributing/dockerfile-stages-dependency.png - > :align: center - > :alt: query - > :width: 100% - > ::: + > <figure markdown="span">
+ > ![](../../assets/contributing/dockerfile-stages-dependency.png){ align="center" alt="query" width="100%" } + > </figure>
> > Made using: <https://github.com/patrickhoefler/dockerfilegraph> > > Commands to regenerate the build graph (make sure to run it **from the \`root\` directory of the vLLM repository** where the dockerfile is present): > > ```bash - > dockerfilegraph -o png --legend --dpi 200 --max-label-length 50 --filename docker/Dockerfile + > dockerfilegraph \ + > -o png \ + > --legend \ + > --dpi 200 \ + > --max-label-length 50 \ + > --filename docker/Dockerfile > ``` > > or in case you want to run it directly with the docker image: diff --git a/docs/contributing/model/README.md b/docs/contributing/model/README.md new file mode 100644 index 000000000000..b7727f02c11b --- /dev/null +++ b/docs/contributing/model/README.md @@ -0,0 +1,23 @@ +--- +title: Adding a New Model +--- +[](){ #new-model } + +This section provides more information on how to integrate a [PyTorch](https://pytorch.org/) model into vLLM. + +Contents: + +- [Basic](basic.md) +- [Registration](registration.md) +- [Tests](tests.md) +- [Multimodal](multimodal.md) + +!!! note + The complexity of adding a new model depends heavily on the model's architecture. + The process is considerably straightforward if the model shares a similar architecture with an existing model in vLLM. + However, for models that include new operators (e.g., a new attention mechanism), the process can be a bit more complex. + +!!! tip + If you are encountering issues while integrating your model into vLLM, feel free to open a [GitHub issue](https://github.com/vllm-project/vllm/issues) + or ask on our [developer slack](https://slack.vllm.ai). + We will be happy to help you out!
diff --git a/docs/source/contributing/model/basic.md b/docs/contributing/model/basic.md similarity index 82% rename from docs/source/contributing/model/basic.md rename to docs/contributing/model/basic.md index ad31995f76be..0c0ba3379257 100644 --- a/docs/source/contributing/model/basic.md +++ b/docs/contributing/model/basic.md @@ -1,6 +1,7 @@ -(new-model-basic)= - -# Implementing a Basic Model +--- +title: Implementing a Basic Model +--- +[](){ #new-model-basic } This guide walks you through the steps to implement a basic vLLM model. @@ -10,9 +11,8 @@ First, clone the PyTorch model code from the source repository. For instance, vLLM's [OPT model](gh-file:vllm/model_executor/models/opt.py) was adapted from HuggingFace's [modeling_opt.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py) file. -:::{warning} -Make sure to review and adhere to the original code's copyright and licensing terms! -::: +!!! warning + Make sure to review and adhere to the original code's copyright and licensing terms! ## 2. Make your code compatible with vLLM @@ -67,7 +67,7 @@ class MyModel(nn.Module): ... ``` -- Rewrite the {meth}`~torch.nn.Module.forward` method of your model to remove any unnecessary code, such as training-specific code. Modify the input parameters to treat `input_ids` and `positions` as flattened tensors with a single batch size dimension, without a max-sequence length dimension. +- Rewrite the [forward][torch.nn.Module.forward] method of your model to remove any unnecessary code, such as training-specific code. Modify the input parameters to treat `input_ids` and `positions` as flattened tensors with a single batch size dimension, without a max-sequence length dimension. ```python def forward( @@ -78,10 +78,9 @@ def forward( ... ``` -:::{note} -Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings. 
-If your model employs a different attention mechanism, you will need to implement a new attention layer in vLLM. -::: +!!! note + Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings. + If your model employs a different attention mechanism, you will need to implement a new attention layer in vLLM. For reference, check out our [Llama implementation](gh-file:vllm/model_executor/models/llama.py). vLLM already supports a large number of models. It is recommended to find a model similar to yours and adapt it to your model's architecture. Check out <gh-dir:vllm/model_executor/models> for more examples. @@ -89,7 +88,7 @@ For reference, check out our [Llama implementation](gh-file:vllm/model_executor/ If your model is too large to fit into a single GPU, you can use tensor parallelism to manage it. To do this, substitute your model's linear and embedding layers with their tensor-parallel versions. -For the embedding layer, you can simply replace {class}`torch.nn.Embedding` with `VocabParallelEmbedding`. For the output LM head, you can use `ParallelLMHead`. +For the embedding layer, you can simply replace [torch.nn.Embedding][] with `VocabParallelEmbedding`. For the output LM head, you can use `ParallelLMHead`. When it comes to the linear layers, we provide the following options to parallelize them: - `ReplicatedLinear`: Replicates the inputs and weights across multiple GPUs. No memory saving. @@ -107,7 +106,7 @@ This method should load the weights from the HuggingFace's checkpoint file and a ## 5. Register your model -See [this page](#new-model-registration) for instructions on how to register your new model to be used by vLLM. +See [this page][new-model-registration] for instructions on how to register your new model to be used by vLLM. ## Frequently Asked Questions @@ -117,7 +116,7 @@ For models with interleaving sliding windows (e.g.
`google/gemma-2-2b-it` and `m To support a model with interleaving sliding windows, we need to take care of the following details: -- Make sure [this line](https://github.com/vllm-project/vllm/blob/996357e4808ca5eab97d4c97c7d25b3073f46aab/vllm/config.py#L308) evaluates `has_interleaved_attention` to `True` for this model, and set `self.hf_text_config.interleaved_sliding_window` to the format of interleaving sliding windows the model can understand. Then, `self.hf_text_config.sliding_window` will be deleted, and the model will be treated as a full-attention model. +- Make sure the model's `config.json` contains `sliding_window_pattern`. vLLM then sets `self.hf_text_config.interleaved_sliding_window` to the value of `self.hf_text_config.sliding_window` and deletes `sliding_window` from `self.hf_text_config`. The model will then be treated as a full-attention model. - In the modeling code, parse the correct sliding window value for every layer, and pass it to the attention layer's `per_layer_sliding_window` argument. For reference, check [this line](https://github.com/vllm-project/vllm/blob/996357e4808ca5eab97d4c97c7d25b3073f46aab/vllm/model_executor/models/llama.py#L171). With these two steps, interleave sliding windows should work with the model. diff --git a/docs/contributing/model/multimodal.md b/docs/contributing/model/multimodal.md new file mode 100644 index 000000000000..892ab9098407 --- /dev/null +++ b/docs/contributing/model/multimodal.md @@ -0,0 +1,803 @@ +--- +title: Multi-Modal Support +--- +[](){ #supports-multimodal } + +This document walks you through the steps to extend a basic model so that it accepts [multi-modal inputs][multimodal-inputs]. + +## 1. Update the base vLLM model + +It is assumed that you have already implemented the model in vLLM according to [these steps][new-model-basic]. 
+Further update the model as follows: + +- Reserve a keyword parameter in [forward][torch.nn.Module.forward] for each input tensor that corresponds to a multi-modal input, as shown in the following example: + + ```diff + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + + pixel_values: torch.Tensor, + ) -> SamplerOutput: + ``` + + More conveniently, you can simply pass `**kwargs` to the [forward][torch.nn.Module.forward] method and retrieve the keyword parameters for multimodal inputs from it. + +- Implement [get_multimodal_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_multimodal_embeddings] that returns the embeddings from running the multimodal inputs through the multimodal tokenizer of the model. Below we provide a boilerplate of a typical implementation pattern, but feel free to adjust it to your own needs. + + ```python + class YourModelForImage2Seq(nn.Module): + ... + + def _process_image_input(self, image_input: YourModelImageInputs) -> torch.Tensor: + + assert self.vision_encoder is not None + image_features = self.vision_encoder(image_input) + return self.multi_modal_projector(image_features) + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + + # Validate the multimodal input keyword arguments + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + + # Run multimodal inputs through encoder and projector + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + ``` + +!!! warning + The returned `multimodal_embeddings` must be either a **3D [torch.Tensor][]** of shape `(num_items, feature_size, hidden_size)`, or a **list / tuple of 2D [torch.Tensor][]'s** of shape `(feature_size, hidden_size)`, so that `multimodal_embeddings[i]` retrieves the embeddings generated from the `i`-th multimodal data item (e.g., image) of the request.
+ +- Implement [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings] to merge `multimodal_embeddings` with text embeddings from the `input_ids`. If input processing for the model is implemented correctly (see sections below), then you can leverage the utility function we provide to easily merge the embeddings. + + ```python + from .utils import merge_multimodal_embeddings + + class YourModelForImage2Seq(nn.Module): + ... + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + + # `get_input_embeddings` should already be implemented for the language + # model as one of the requirements of basic vLLM model implementation. + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + placeholder_token_id=self.config.image_token_index) + + return inputs_embeds + ``` + +- Implement [get_language_model][vllm.model_executor.models.interfaces.SupportsMultiModal.get_language_model] getter to provide stable access to the underlying language model. + + ```python + class YourModelForImage2Seq(nn.Module): + ... + + def get_language_model(self) -> torch.nn.Module: + # Change `language_model` according to your implementation. + return self.language_model + ``` + +- Once the above steps are done, update the model class with the [SupportsMultiModal][vllm.model_executor.models.interfaces.SupportsMultiModal] interface. + + ```diff + + from vllm.model_executor.models.interfaces import SupportsMultiModal + + - class YourModelForImage2Seq(nn.Module): + + class YourModelForImage2Seq(nn.Module, SupportsMultiModal): + ``` + +!!! note + The model class does not have to be named `*ForCausalLM`. 
+ Check out [the HuggingFace Transformers documentation](https://huggingface.co/docs/transformers/model_doc/auto#multimodal) for some examples. + +## 2. Specify processing information + +Next, create a subclass of [BaseProcessingInfo][vllm.multimodal.processing.BaseProcessingInfo] +to provide basic information related to HF processing. + +### Maximum number of input items + +You need to override the abstract method [get_supported_mm_limits][vllm.multimodal.processing.BaseProcessingInfo.get_supported_mm_limits] +to return the maximum number of input items for each modality supported by the model. + +For example, if the model supports any number of images but only one video per prompt: + +```python +def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None, "video": 1} +``` + +## 3. Specify dummy inputs + +Then, inherit [BaseDummyInputsBuilder][vllm.multimodal.profiling.BaseDummyInputsBuilder] to construct dummy inputs for +HF processing as well as memory profiling. + +### For memory profiling + +Override the abstract methods [get_dummy_text][vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_text] and [get_dummy_mm_data][vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_mm_data] to construct dummy inputs for memory profiling. These dummy inputs should result in the worst-case memory usage of the model so that vLLM can reserve the correct amount of memory for it. + +Assuming that the memory usage increases with the number of tokens, the dummy inputs can be constructed to maximize the number of output embeddings, which is the same number as placeholder feature tokens. 
+ +=== "Basic example: LLaVA" + + Looking at the code of HF's `LlavaForConditionalGeneration`: + + ```python + # https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/modeling_llava.py#L530-L544 + n_image_tokens = (input_ids == self.config.image_token_index).sum().item() + n_image_features = image_features.shape[0] * image_features.shape[1] + + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + special_image_mask = ( + (input_ids == self.config.image_token_index) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + ``` + + The number of placeholder feature tokens per image is `image_features.shape[1]`. + `image_features` is calculated inside the `get_image_features` method: + + ```python + # https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/modeling_llava.py#L290-L300 + image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) + + selected_image_feature = image_outputs.hidden_states[vision_feature_layer] + if vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + else: + raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}") + image_features = self.multi_modal_projector(selected_image_feature) + return image_features + ``` + + We can infer that `image_features.shape[1]` is based on `image_outputs.hidden_states.shape[1]` from the vision tower + (`CLIPVisionModel` for the [`llava-hf/llava-1.5-7b-hf`](https://huggingface.co/llava-hf/llava-1.5-7b-hf) model). 
+ Moreover, we only need the sequence length (the second dimension of the tensor) to get `image_features.shape[1]`. + The sequence length is determined by the initial hidden states in `CLIPVisionTransformer` since the attention + mechanism doesn't change the sequence length of the output hidden states. + + ```python + # https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/modeling_clip.py#L1094-L1102 + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + hidden_states = self.pre_layrnorm(hidden_states) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + ``` + + To find the sequence length, we turn to the code of `CLIPVisionEmbeddings`: + + ```python + # https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/modeling_clip.py#L247-L257 + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + ``` + + We can infer that `embeddings.shape[1] == self.num_positions`, where + + ```python + # https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/modeling_clip.py#L195-L196 + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + ``` + + Overall, the number of placeholder feature tokens for an image can be calculated as: + + ```python + def get_num_image_tokens( 
+ self, + *, + image_width: int, + image_height: int, + ) -> int: + hf_config = self.get_hf_config() + hf_processor = self.get_hf_processor() + + image_size = hf_config.vision_config.image_size + patch_size = hf_config.vision_config.patch_size + + num_image_tokens = (image_size // patch_size) ** 2 + 1 + if hf_processor.vision_feature_select_strategy == "default": + num_image_tokens -= 1 + + return num_image_tokens + ``` + + Notice that the number of image tokens doesn't depend on the image width and height. + We can simply use a dummy `image_size` to calculate the multimodal profiling data: + + ```python + # NOTE: In actuality, this is usually implemented as part of the + # model's subclass of `BaseProcessingInfo`, but we show it as is + # here for simplicity. + def get_image_size_with_most_features(self) -> ImageSize: + hf_config = self.get_hf_config() + width = height = hf_config.image_size + return ImageSize(width=width, height=height) + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = \ + self.info.get_image_size_with_most_features() + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + ``` + + For the text, we simply expand the multimodal image token from the model config to match the desired number of images. 
+ + ```python + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + image_token = processor.image_token + + return image_token * num_images + ``` + +=== "No input placeholders: Fuyu" + + Looking at the code of HF's `FuyuForCausalLM`: + + ```python + # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/modeling_fuyu.py#L311-L322 + if image_patches is not None and past_key_values is None: + patch_embeddings = [ + self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype)) + .squeeze(0) + .to(inputs_embeds.device) + for patch in image_patches + ] + inputs_embeds = self.gather_continuous_embeddings( + word_embeddings=inputs_embeds, + continuous_embeddings=patch_embeddings, + image_patch_input_indices=image_patches_indices, + ) + ``` + + The number of placeholder feature tokens for the `i`th item in the batch is `patch_embeddings[i].shape[0]`, + which is the same as `image_patches[i].shape[0]`, i.e. `num_total_patches`. + + Unlike LLaVA, Fuyu does not define the number of patches inside the modeling file. Where can we get more information? + Considering that the model input comes from the output of `FuyuProcessor`, let's **look at the preprocessing files**. + + The image outputs are obtained by calling `FuyuImageProcessor.preprocess` and then + `FuyuImageProcessor.preprocess_with_tokenizer_info` inside `FuyuProcessor`. + + In `FuyuImageProcessor.preprocess`, the images are resized and padded to the target `FuyuImageProcessor.size`, + returning the dimensions after resizing (but before padding) as metadata. 
+ + ```python + # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L541-L544 + image_encoding = self.image_processor.preprocess(images, **output_kwargs["images_kwargs"]) + batch_images = image_encoding["images"] + image_unpadded_heights = image_encoding["image_unpadded_heights"] + image_unpadded_widths = image_encoding["image_unpadded_widths"] + + # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L480-L + if do_resize: + batch_images = [ + [self.resize(image, size=size, input_data_format=input_data_format) for image in images] + for images in batch_images + ] + + image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images] + image_unpadded_heights = [[image_size[0]] for image_size in image_sizes] + image_unpadded_widths = [[image_size[1]] for image_size in image_sizes] + + if do_pad: + batch_images = [ + [ + self.pad_image( + image, + size=size, + mode=padding_mode, + constant_values=padding_value, + input_data_format=input_data_format, + ) + for image in images + ] + for images in batch_images + ] + ``` + + In `FuyuImageProcessor.preprocess_with_tokenizer_info`, the images are split into patches based on this metadata: + + ```python + # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L417-L425 + model_image_input = self.image_processor.preprocess_with_tokenizer_info( + image_input=tensor_batch_images, + image_present=image_present, + image_unpadded_h=image_unpadded_heights, + image_unpadded_w=image_unpadded_widths, + image_placeholder_id=image_placeholder_id, + image_newline_id=image_newline_id, + variable_sized=True, + ) + + # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L638-L658 + image_height, image_width = image.shape[1], image.shape[2] + if variable_sized: # variable_sized=True + 
new_h = min( + image_height, + math.ceil(image_unpadded_h[batch_index, subseq_index] / patch_height) * patch_height, + ) + new_w = min( + image_width, + math.ceil(image_unpadded_w[batch_index, subseq_index] / patch_width) * patch_width, + ) + image = image[:, :new_h, :new_w] + image_height, image_width = new_h, new_w + + num_patches = self.get_num_patches(image_height=image_height, image_width=image_width) + tensor_of_image_ids = torch.full( + [num_patches], image_placeholder_id, dtype=torch.int32, device=image_input.device + ) + patches = self.patchify_image(image=image.unsqueeze(0)).squeeze(0) + assert num_patches == patches.shape[0] + ``` + + The number of patches is in turn defined by `FuyuImageProcessor.get_num_patches`: + + ```python + # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L552-L562 + patch_size = patch_size if patch_size is not None else self.patch_size + patch_height, patch_width = self.patch_size["height"], self.patch_size["width"] + + if image_height % patch_height != 0: + raise ValueError(f"{image_height=} must be divisible by {patch_height}") + if image_width % patch_width != 0: + raise ValueError(f"{image_width=} must be divisible by {patch_width}") + + num_patches_per_dim_h = image_height // patch_height + num_patches_per_dim_w = image_width // patch_width + num_patches = num_patches_per_dim_h * num_patches_per_dim_w + ``` + + These image patches correspond to placeholder tokens (`|SPEAKER|`). So, we just need to maximize the number of image patches. Since input images are first resized + to fit within `image_processor.size`, we can maximize the number of image patches by inputting an image with size equal to `image_processor.size`. 
+ + ```python + def get_image_size_with_most_features(self) -> ImageSize: + image_processor = self.get_image_processor() + return ImageSize(width=image_processor.size["width"], + height=image_processor.size["height"]) + ``` + + Fuyu does not expect image placeholders in the inputs to HF processor, so + the dummy prompt text is empty regardless of the number of images. + + ```python + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + ``` + + For the multimodal image profiling data, the logic is very similar to LLaVA: + + ```python + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + target_width, target_height = \ + self.info.get_image_size_with_most_features() + num_images = mm_counts.get("image", 0) + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + ``` + +## 4. Specify processing details + +Afterwards, create a subclass of [BaseMultiModalProcessor][vllm.multimodal.processing.BaseMultiModalProcessor] +to fill in the missing details about HF processing. + +!!! info + [Multi-Modal Data Processing][mm-processing] + +### Multi-modal fields + +Override [_get_mm_fields_config][vllm.multimodal.processing.BaseMultiModalProcessor._get_mm_fields_config] to +return a schema of the tensors outputted by the HF processor that are related to the input multi-modal items. 
+ +=== "Basic example: LLaVA" + + The output of `CLIPImageProcessor` is a simple tensor with shape + `(num_images, num_channels, image_height, image_width)`: + + ```python + # https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/image_processing_clip.py#L339-L345 + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + for image in all_images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + ``` + + So, we override [_get_mm_fields_config][vllm.multimodal.processing.BaseMultiModalProcessor._get_mm_fields_config] as follows: + + ```python + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + ) + ``` + + !!! note + Our [actual code](gh-file:vllm/model_executor/models/llava.py) additionally supports + pre-computed image embeddings, which can be passed to the model via the `image_embeds` argument. + +=== "With postprocessing: Fuyu" + + The `image_patches` output of `FuyuImageProcessor.preprocess_with_tokenizer_info` concatenates + the patches from each image belonging to an item in the batch: + + ```python + # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L673-L679 + image_input_ids.append(tensor_of_image_ids) + image_patches.append(patches) + else: + image_input_ids.append(torch.tensor([], dtype=torch.int32, device=image_input.device)) + + batch_image_input_ids.append(image_input_ids) + batch_image_patches.append(image_patches) + ``` + + The shape of `image_patches` outputted by `FuyuImageProcessor` is therefore + `(1, num_images, num_patches, patch_width * patch_height * num_channels)`.
+ + In order to support the use of [MultiModalFieldConfig.batched][] like in LLaVA, + we remove the extra batch dimension by overriding [BaseMultiModalProcessor._call_hf_processor][]: + + ```python + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + image_patches = processed_outputs.get("image_patches") + if image_patches is not None: + images = mm_data["images"] + assert isinstance(images, list) + + # Original output: (1, num_images, Pn, Px * Py * C) + # New output: (num_images, Pn, Px * Py * C) + assert (isinstance(image_patches, list) + and len(image_patches) == 1) + assert (isinstance(image_patches[0], torch.Tensor) + and len(image_patches[0]) == len(images)) + + processed_outputs["image_patches"] = image_patches[0] + + return processed_outputs + ``` + + !!! note + Our [actual code](gh-file:vllm/model_executor/models/fuyu.py) has special handling + for text-only inputs to prevent unnecessary warnings from HF processor. + + This lets us override [_get_mm_fields_config][vllm.multimodal.processing.BaseMultiModalProcessor._get_mm_fields_config] as follows: + + ```python + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(image_patches=MultiModalFieldConfig.batched("image")) + ``` + +### Prompt updates + +Override [_get_prompt_updates][vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates] to +return a list of [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instances. + +Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies an update operation +(e.g.: insertion, replacement) performed by the HF processor. 
+ +=== "Basic example: LLaVA" + + Looking at HF's `LlavaProcessor`: + + ```python + # https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/processing_llava.py#L167-L170 + prompt_strings = [] + for sample in text: + sample = sample.replace(self.image_token, self.image_token * num_image_tokens) + prompt_strings.append(sample) + ``` + + It simply repeats each input `image_token` a number of times equal to the number of placeholder feature tokens (`num_image_tokens`). + Based on this, we override [_get_prompt_updates][vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates] as follows: + + ```python + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_config = self.info.get_hf_config() + image_token_id = hf_config.image_token_index + + def get_replacement(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + + image_size = images.get_image_size(item_idx) + num_image_tokens = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + ) + + return [image_token_id] * num_image_tokens + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=get_replacement, + ), + ] + ``` + +=== "Handling additional tokens: Fuyu" + + Recall the layout of feature tokens from Step 2: + + ``` + |SPEAKER||SPEAKER|...|SPEAKER||NEWLINE| + |SPEAKER||SPEAKER|...|SPEAKER||NEWLINE| + ... 
+ |SPEAKER||SPEAKER|...|SPEAKER||NEWLINE| + ``` + + We define a helper function to return `ncols` and `nrows` directly: + + ```python + def get_image_feature_grid_size( + self, + *, + image_width: int, + image_height: int, + ) -> tuple[int, int]: + image_processor = self.get_image_processor() + target_width = image_processor.size["width"] + target_height = image_processor.size["height"] + patch_width = image_processor.patch_size["width"] + patch_height = image_processor.patch_size["height"] + + if not (image_width <= target_width and image_height <= target_height): + height_scale_factor = target_height / image_height + width_scale_factor = target_width / image_width + optimal_scale_factor = min(height_scale_factor, width_scale_factor) + + image_height = int(image_height * optimal_scale_factor) + image_width = int(image_width * optimal_scale_factor) + + ncols = math.ceil(image_width / patch_width) + nrows = math.ceil(image_height / patch_height) + return ncols, nrows + ``` + + Based on this, we can initially define our replacement tokens as: + + ```python + def get_replacement(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) + + ncols, nrows = self.info.get_image_feature_grid_size( + image_width=image_size.width, + image_height=image_size.height, + ) + + # `_IMAGE_TOKEN_ID` corresponds to `|SPEAKER|` + # `_NEWLINE_TOKEN_ID` corresponds to `|NEWLINE|` + return ([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows + ``` + + However, this is not entirely correct. 
After `FuyuImageProcessor.preprocess_with_tokenizer_info` is called, + a BOS token (`<s>`) is also added to the prompt: + + ```python + # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L417-L435 + model_image_input = self.image_processor.preprocess_with_tokenizer_info( + image_input=tensor_batch_images, + image_present=image_present, + image_unpadded_h=image_unpadded_heights, + image_unpadded_w=image_unpadded_widths, + image_placeholder_id=image_placeholder_id, + image_newline_id=image_newline_id, + variable_sized=True, + ) + prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch( + tokenizer=self.tokenizer, + prompts=prompts, + scale_factors=scale_factors, + max_tokens_to_generate=self.max_tokens_to_generate, + max_position_embeddings=self.max_position_embeddings, + add_BOS=True, + add_beginning_of_answer_token=True, + ) + ``` + + To assign the vision embeddings to only the image tokens, instead of a string + you can return an instance of [PromptUpdateDetails][vllm.multimodal.processing.PromptUpdateDetails]: + + ```python + hf_config = self.info.get_hf_config() + bos_token_id = hf_config.bos_token_id # `<s>` + assert isinstance(bos_token_id, int) + + def get_replacement_fuyu(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) + + ncols, nrows = self.info.get_image_feature_grid_size( + image_width=image_size.width, + image_height=image_size.height, + ) + image_tokens = ([_IMAGE_TOKEN_ID] * ncols + + [_NEWLINE_TOKEN_ID]) * nrows + + return PromptUpdateDetails.select_token_id( + image_tokens + [bos_token_id], + embed_token_id=_IMAGE_TOKEN_ID, + ) + ``` + + Finally, noticing that the HF processor removes the `|ENDOFTEXT|` token from the tokenized prompt, + we can search for it to conduct the replacement at the start of the string: + + ```python + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, +
hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_config = self.info.get_hf_config() + bos_token_id = hf_config.bos_token_id + assert isinstance(bos_token_id, int) + + tokenizer = self.info.get_tokenizer() + eot_token_id = tokenizer.bos_token_id + assert isinstance(eot_token_id, int) + + def get_replacement_fuyu(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) + + ncols, nrows = self.info.get_image_feature_grid_size( + image_width=image_size.width, + image_height=image_size.height, + ) + image_tokens = ([_IMAGE_TOKEN_ID] * ncols + + [_NEWLINE_TOKEN_ID]) * nrows + + return PromptUpdateDetails.select_token_id( + image_tokens + [bos_token_id], + embed_token_id=_IMAGE_TOKEN_ID, + ) + + return [ + PromptReplacement( + modality="image", + target=[eot_token_id], + replacement=get_replacement_fuyu, + ) + ] + ``` + +## 5. Register processor-related classes + +After you have defined [BaseProcessingInfo][vllm.multimodal.processing.BaseProcessingInfo] (Step 2), +[BaseDummyInputsBuilder][vllm.multimodal.profiling.BaseDummyInputsBuilder] (Step 3), +and [BaseMultiModalProcessor][vllm.multimodal.processing.BaseMultiModalProcessor] (Step 4), +decorate the model class with [MULTIMODAL_REGISTRY.register_processor][vllm.multimodal.registry.MultiModalRegistry.register_processor] +to register them to the multi-modal registry: + +```diff + from vllm.model_executor.models.interfaces import SupportsMultiModal ++ from vllm.multimodal import MULTIMODAL_REGISTRY + ++ @MULTIMODAL_REGISTRY.register_processor(YourMultiModalProcessor, ++ info=YourProcessingInfo, ++ dummy_inputs=YourDummyInputsBuilder) + class YourModelForImage2Seq(nn.Module, SupportsMultiModal): +``` + +## Notes + +### Inserting feature tokens without replacement + +Some HF processors directly insert feature tokens without replacing anything in the original prompt. 
In that case, you can use [PromptInsertion][vllm.multimodal.processing.PromptInsertion] instead of [PromptReplacement][vllm.multimodal.processing.PromptReplacement] inside [_get_prompt_updates][vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates]. + +Examples: + +- BLIP-2 (insert at start of prompt): +- Florence2 (insert at start of prompt): +- Molmo (insert after `<|endoftext|>` token): + +### Handling prompt updates unrelated to multi-modal data + +[_get_prompt_updates][vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates] assumes that each application of prompt update corresponds to one multi-modal item. If the HF processor performs additional processing regardless of how many multi-modal items there are, you should override [_apply_hf_processor_tokens_only][vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_tokens_only] so that the processed token inputs are consistent with the result of applying the HF processor on text inputs. This is because token inputs bypass the HF processor according to [our design][mm-processing]. + +Examples: + +- Chameleon (appends `sep_token`): +- Fuyu (appends `boa_token`): +- Molmo (applies chat template which is not defined elsewhere): + +### Custom HF processor + +Some models don't define a HF processor class on HF Hub. In that case, you can define a custom HF processor that has the same call signature as HF processors and pass it to [_call_hf_processor][vllm.multimodal.processing.BaseMultiModalProcessor._call_hf_processor]. + +Examples: + +- DeepSeek-VL2: +- InternVL: +- Qwen-VL: diff --git a/docs/contributing/model/registration.md b/docs/contributing/model/registration.md new file mode 100644 index 000000000000..7a7bd7914058 --- /dev/null +++ b/docs/contributing/model/registration.md @@ -0,0 +1,54 @@ +--- +title: Registering a Model to vLLM +--- +[](){ #new-model-registration } + +vLLM relies on a model registry to determine how to run each model. 
+A list of pre-registered architectures can be found [here][supported-models]. + +If your model is not on this list, you must register it to vLLM. +This page provides detailed instructions on how to do so. + +## Built-in models + +To add a model directly to the vLLM library, start by forking our [GitHub repository](https://github.com/vllm-project/vllm) and then [build it from source][build-from-source]. +This gives you the ability to modify the codebase and test your model. + +After you have implemented your model (see [tutorial][new-model-basic]), put it into the <gh-dir:vllm/model_executor/models> directory. +Then, add your model class to `_VLLM_MODELS` in <gh-file:vllm/model_executor/models/registry.py> so that it is automatically registered upon importing vLLM. +Finally, update our [list of supported models][supported-models] to promote your model! + +!!! warning + The list of models in each section should be maintained in alphabetical order. + +## Out-of-tree models + +You can load an external model [using a plugin][plugin-system] without modifying the vLLM codebase. + +To register the model, use the following code: + +```python +# The entrypoint of your plugin +def register(): + from vllm import ModelRegistry + from your_code import YourModelForCausalLM + + ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM) +``` + +If your model imports modules that initialize CUDA, consider lazy-importing it to avoid errors like `RuntimeError: Cannot re-initialize CUDA in forked subprocess`: + +```python +# The entrypoint of your plugin +def register(): + from vllm import ModelRegistry + + ModelRegistry.register_model( + "YourModelForCausalLM", + "your_code:YourModelForCausalLM" + ) +``` + +!!! warning + If your model is a multimodal model, ensure the model class implements the [SupportsMultiModal][vllm.model_executor.models.interfaces.SupportsMultiModal] interface. + Read more about that [here][supports-multimodal]. 
diff --git a/docs/source/contributing/model/tests.md b/docs/contributing/model/tests.md similarity index 74% rename from docs/source/contributing/model/tests.md rename to docs/contributing/model/tests.md index 68d51d89f7cf..67f8eda61dc5 100644 --- a/docs/source/contributing/model/tests.md +++ b/docs/contributing/model/tests.md @@ -1,6 +1,7 @@ -(new-model-tests)= - -# Writing Unit Tests +--- +title: Writing Unit Tests +--- +[](){ #new-model-tests } This page explains how to write unit tests to verify the implementation of your model. @@ -14,14 +15,12 @@ Without them, the CI for your PR will fail. Include an example HuggingFace repository for your model in . This enables a unit test that loads dummy weights to ensure that the model can be initialized in vLLM. -:::{important} -The list of models in each section should be maintained in alphabetical order. -::: +!!! warning + The list of models in each section should be maintained in alphabetical order. -:::{tip} -If your model requires a development version of HF Transformers, you can set -`min_transformers_version` to skip the test in CI until the model is released. -::: +!!! tip + If your model requires a development version of HF Transformers, you can set + `min_transformers_version` to skip the test in CI until the model is released. ## Optional Tests @@ -34,16 +33,16 @@ These tests compare the model outputs of vLLM against [HF Transformers](https:// #### Generative models -For [generative models](#generative-models), there are two levels of correctness tests, as defined in : +For [generative models](../../models/generative_models.md), there are two levels of correctness tests, as defined in : - Exact correctness (`check_outputs_equal`): The text outputted by vLLM should exactly match the text outputted by HF. - Logprobs similarity (`check_logprobs_close`): The logprobs outputted by vLLM should be in the top-k logprobs outputted by HF, and vice versa. 
#### Pooling models -For [pooling models](#pooling-models), we simply check the cosine similarity, as defined in . +For [pooling models](../../models/pooling_models.md), we simply check the cosine similarity, as defined in . -(mm-processing-tests)= +[](){ #mm-processing-tests } ### Multi-modal processing diff --git a/docs/source/contributing/profiling/profiling_index.md b/docs/contributing/profiling.md similarity index 90% rename from docs/source/contributing/profiling/profiling_index.md rename to docs/contributing/profiling.md index ce25daa39c5c..be01b9b65f65 100644 --- a/docs/source/contributing/profiling/profiling_index.md +++ b/docs/contributing/profiling.md @@ -1,8 +1,7 @@ # Profiling vLLM -:::{warning} -Profiling is only intended for vLLM developers and maintainers to understand the proportion of time spent in different parts of the codebase. **vLLM end-users should never turn on profiling** as it will significantly slow down the inference. -::: +!!! warning + Profiling is only intended for vLLM developers and maintainers to understand the proportion of time spent in different parts of the codebase. **vLLM end-users should never turn on profiling** as it will significantly slow down the inference. ## Profile with PyTorch Profiler @@ -14,15 +13,13 @@ When using `benchmarks/benchmark_serving.py`, you can enable profiling by passin Traces can be visualized using . -:::{tip} -Only send a few requests through vLLM when profiling, as the traces can get quite large. Also, no need to untar the traces, they can be viewed directly. -::: +!!! tip + Only send a few requests through vLLM when profiling, as the traces can get quite large. Also, no need to untar the traces, they can be viewed directly. -:::{tip} -To stop the profiler - it flushes out all the profile trace files to the directory. This takes time, for example for about 100 requests worth of data for a llama 70b, it takes about 10 minutes to flush out on a H100. 
-Set the env variable VLLM_RPC_TIMEOUT to a big number before you start the server. Say something like 30 minutes. -`export VLLM_RPC_TIMEOUT=1800000` -::: +!!! tip + To stop the profiler - it flushes out all the profile trace files to the directory. This takes time, for example for about 100 requests worth of data for a llama 70b, it takes about 10 minutes to flush out on a H100. + Set the env variable VLLM_RPC_TIMEOUT to a big number before you start the server. Say something like 30 minutes. + `export VLLM_RPC_TIMEOUT=1800000` ### Example commands and usage diff --git a/docs/source/contributing/vulnerability_management.md b/docs/contributing/vulnerability_management.md similarity index 100% rename from docs/source/contributing/vulnerability_management.md rename to docs/contributing/vulnerability_management.md diff --git a/docs/deployment/docker.md b/docs/deployment/docker.md new file mode 100644 index 000000000000..516640f6fd3c --- /dev/null +++ b/docs/deployment/docker.md @@ -0,0 +1,129 @@ +--- +title: Using Docker +--- +[](){ #deployment-docker } + +[](){ #deployment-docker-pre-built-image } + +## Use vLLM's Official Docker Image + +vLLM offers an official Docker image for deployment. +The image can be used to run OpenAI compatible server and is available on Docker Hub as [vllm/vllm-openai](https://hub.docker.com/r/vllm/vllm-openai/tags). + +```console +docker run --runtime nvidia --gpus all \ + -v ~/.cache/huggingface:/root/.cache/huggingface \ + --env "HUGGING_FACE_HUB_TOKEN=" \ + -p 8000:8000 \ + --ipc=host \ + vllm/vllm-openai:latest \ + --model mistralai/Mistral-7B-v0.1 +``` + +This image can also be used with other container engines such as [Podman](https://podman.io/). 
+ +```console +podman run --gpus all \ + -v ~/.cache/huggingface:/root/.cache/huggingface \ + --env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \ + -p 8000:8000 \ + --ipc=host \ + vllm/vllm-openai:latest \ + --model mistralai/Mistral-7B-v0.1 +``` + +You can add any other [engine-args][engine-args] you need after the image tag (`vllm/vllm-openai:latest`). + +!!! note + You can either use the `ipc=host` flag or `--shm-size` flag to allow the + container to access the host's shared memory. vLLM uses PyTorch, which uses shared + memory to share data between processes under the hood, particularly for tensor parallel inference. + +!!! note + Optional dependencies are not included in order to avoid licensing issues (e.g. ). + + If you need to use those dependencies (having accepted the license terms), + create a custom Dockerfile on top of the base image with an extra layer that installs them: + + ```Dockerfile + FROM vllm/vllm-openai:v0.8.3 + + # e.g. install the `audio` optional dependencies + # NOTE: Make sure the version of vLLM matches the base image! + RUN uv pip install --system vllm[audio]==0.8.3 + ``` + +!!! tip + Some new models may only be available on the main branch of [HF Transformers](https://github.com/huggingface/transformers). + + To use the development version of `transformers`, create a custom Dockerfile on top of the base image + with an extra layer that installs their code from source: + + ```Dockerfile + FROM vllm/vllm-openai:latest + + RUN uv pip install --system git+https://github.com/huggingface/transformers.git + ``` + +[](){ #deployment-docker-build-image-from-source } + +## Building vLLM's Docker Image from Source + +You can build and run vLLM from source via the provided . To build vLLM: + +```console +# optionally specifies: --build-arg max_jobs=8 --build-arg nvcc_threads=2 +DOCKER_BUILDKIT=1 docker build . \ + --target vllm-openai \ + --tag vllm/vllm-openai \ + --file docker/Dockerfile +``` + +!!! 
note + By default vLLM will build for all GPU types for widest distribution. If you are just building for the + current GPU type the machine is running on, you can add the argument `--build-arg torch_cuda_arch_list=""` + for vLLM to find the current GPU type and build for that. + + If you are using Podman instead of Docker, you might need to disable SELinux labeling by + adding `--security-opt label=disable` when running `podman build` command to avoid certain [existing issues](https://github.com/containers/buildah/discussions/4184). + +## Building for Arm64/aarch64 + +A docker container can be built for aarch64 systems such as the Nvidia Grace-Hopper. At time of this writing, this requires the use +of PyTorch Nightly and should be considered **experimental**. Using the flag `--platform "linux/arm64"` will attempt to build for arm64. + +!!! note + Multiple modules must be compiled, so this process can take a while. Recommend using `--build-arg max_jobs=` & `--build-arg nvcc_threads=` + flags to speed up build process. However, ensure your `max_jobs` is substantially larger than `nvcc_threads` to get the most benefits. + Keep an eye on memory usage with parallel jobs as it can be substantial (see example below). + +```console +# Example of building on Nvidia GH200 server. (Memory usage: ~15GB, Build time: ~1475s / ~25 min, Image size: 6.93GB) +python3 use_existing_torch.py +DOCKER_BUILDKIT=1 docker build . 
\ + --file docker/Dockerfile \ + --target vllm-openai \ + --platform "linux/arm64" \ + -t vllm/vllm-gh200-openai:latest \ + --build-arg max_jobs=66 \ + --build-arg nvcc_threads=2 \ + --build-arg torch_cuda_arch_list="9.0+PTX" \ + --build-arg vllm_fa_cmake_gpu_arches="90-real" +``` + +## Use the custom-built vLLM Docker image + +To run vLLM with the custom-built Docker image: + +```console +docker run --runtime nvidia --gpus all \ + -v ~/.cache/huggingface:/root/.cache/huggingface \ + -p 8000:8000 \ + --env "HUGGING_FACE_HUB_TOKEN=" \ + vllm/vllm-openai +``` + +The argument `vllm/vllm-openai` specifies the image to run, and should be replaced with the name of the custom-built image (the `-t` tag from the build command). + +!!! note + **For version 0.4.1 and 0.4.2 only** - the vLLM docker images under these versions are supposed to be run under the root user since a library under the root user's home directory, i.e. `/root/.config/vllm/nccl/cu12/libnccl.so.2.18.1` is required to be loaded during runtime. If you are running the container under a different user, you may need to first change the permissions of the library (and all the parent directories) to allow the user to access it, then run vLLM with environment variable `VLLM_NCCL_SO_PATH=/root/.config/vllm/nccl/cu12/libnccl.so.2.18.1` . 
diff --git a/docs/source/deployment/frameworks/anything-llm.md b/docs/deployment/frameworks/anything-llm.md similarity index 78% rename from docs/source/deployment/frameworks/anything-llm.md rename to docs/deployment/frameworks/anything-llm.md index d430c170ef54..a89e633c086e 100644 --- a/docs/source/deployment/frameworks/anything-llm.md +++ b/docs/deployment/frameworks/anything-llm.md @@ -1,6 +1,7 @@ -(deployment-anything-llm)= - -# Anything LLM +--- +title: Anything LLM +--- +[](){ #deployment-anything-llm } [Anything LLM](https://github.com/Mintplex-Labs/anything-llm) is a full-stack application that enables you to turn any document, resource, or piece of content into context that any LLM can use as references during chatting. @@ -25,23 +26,19 @@ vllm serve Qwen/Qwen1.5-32B-Chat-AWQ --max-model-len 4096 - Base URL: http://{vllm server host}:{vllm server port}/v1 - Chat Model Name: `Qwen/Qwen1.5-32B-Chat-AWQ` -:::{image} /assets/deployment/anything-llm-provider.png -::: +![](../../assets/deployment/anything-llm-provider.png) - Back to home page, New Workspace --> create `vllm` workspace, and start to chat: -:::{image} /assets/deployment/anything-llm-chat-without-doc.png -::: +![](../../assets/deployment/anything-llm-chat-without-doc.png) - Click the upload button: - upload the doc - select the doc and move to the workspace - save and embed -:::{image} /assets/deployment/anything-llm-upload-doc.png -::: +![](../../assets/deployment/anything-llm-upload-doc.png) - Chat again: -:::{image} /assets/deployment/anything-llm-chat-with-doc.png -::: +![](../../assets/deployment/anything-llm-chat-with-doc.png) diff --git a/docs/deployment/frameworks/autogen.md b/docs/deployment/frameworks/autogen.md new file mode 100644 index 000000000000..ad8c167659ef --- /dev/null +++ b/docs/deployment/frameworks/autogen.md @@ -0,0 +1,83 @@ +--- +title: AutoGen +--- +[](){ #deployment-autogen } + +[AutoGen](https://github.com/microsoft/autogen) is a framework for creating multi-agent AI 
applications that can act autonomously or work alongside humans. + +## Prerequisites + +- Setup vLLM environment + +- Setup [AutoGen](https://microsoft.github.io/autogen/0.2/docs/installation/) environment + +```console +pip install vllm + +# Install AgentChat and OpenAI client from Extensions +# AutoGen requires Python 3.10 or later. +pip install -U "autogen-agentchat" "autogen-ext[openai]" +``` + +## Deploy + +- Start the vLLM server with the supported chat completion model, e.g. + +```console +python -m vllm.entrypoints.openai.api_server \ + --model mistralai/Mistral-7B-Instruct-v0.2 +``` + +- Call it with AutoGen: + +```python +import asyncio +from autogen_core.models import UserMessage +from autogen_ext.models.openai import OpenAIChatCompletionClient +from autogen_core.models import ModelFamily + + +async def main() -> None: + # Create a model client + model_client = OpenAIChatCompletionClient( + model="mistralai/Mistral-7B-Instruct-v0.2", + base_url="http://{your-vllm-host-ip}:{your-vllm-host-port}/v1", + api_key="EMPTY", + model_info={ + "vision": False, + "function_calling": False, + "json_output": False, + "family": ModelFamily.MISTRAL, + "structured_output": True, + }, + ) + + messages = [UserMessage(content="Write a very short story about a dragon.", source="user")] + + # Create a stream. + stream = model_client.create_stream(messages=messages) + + # Iterate over the stream and print the responses. + print("Streamed responses:") + async for response in stream: + if isinstance(response, str): + # A partial response is a string. + print(response, flush=True, end="") + else: + # The last response is a CreateResult object with the complete message. + print("\n\n------------\n") + print("The complete response:", flush=True) + print(response.content, flush=True) + + # Close the client when done. 
+ await model_client.close() + + +asyncio.run(main()) +``` + +For details, see the tutorial: + +- [Using vLLM in AutoGen](https://microsoft.github.io/autogen/0.2/docs/topics/non-openai-models/local-vllm/) + +- [OpenAI-compatible API examples](https://microsoft.github.io/autogen/stable/reference/python/autogen_ext.models.openai.html#autogen_ext.models.openai.OpenAIChatCompletionClient) diff --git a/docs/source/deployment/frameworks/bentoml.md b/docs/deployment/frameworks/bentoml.md similarity index 89% rename from docs/source/deployment/frameworks/bentoml.md rename to docs/deployment/frameworks/bentoml.md index 2bf435bda838..7e64b6eb6fb0 100644 --- a/docs/source/deployment/frameworks/bentoml.md +++ b/docs/deployment/frameworks/bentoml.md @@ -1,6 +1,7 @@ -(deployment-bentoml)= - -# BentoML +--- +title: BentoML +--- +[](){ #deployment-bentoml } [BentoML](https://github.com/bentoml/BentoML) allows you to deploy a large language model (LLM) server with vLLM as the backend, which exposes OpenAI-compatible endpoints. You can serve the model locally or containerize it as an OCI-compliant image and deploy it on Kubernetes. diff --git a/docs/source/deployment/frameworks/cerebrium.md b/docs/deployment/frameworks/cerebrium.md similarity index 98% rename from docs/source/deployment/frameworks/cerebrium.md rename to docs/deployment/frameworks/cerebrium.md index b20c95137b6e..84cb2304fac2 100644 --- a/docs/source/deployment/frameworks/cerebrium.md +++ b/docs/deployment/frameworks/cerebrium.md @@ -1,12 +1,11 @@ -(deployment-cerebrium)= +--- +title: Cerebrium +--- +[](){ #deployment-cerebrium } -# Cerebrium - -:::{raw} html

vLLM_plus_cerebrium

-::: vLLM can be run on a cloud based GPU machine with [Cerebrium](https://www.cerebrium.ai/), a serverless AI infrastructure platform that makes it easier for companies to build and deploy AI based applications. diff --git a/docs/source/deployment/frameworks/chatbox.md b/docs/deployment/frameworks/chatbox.md similarity index 84% rename from docs/source/deployment/frameworks/chatbox.md rename to docs/deployment/frameworks/chatbox.md index e62f4647150f..10da2fc71002 100644 --- a/docs/source/deployment/frameworks/chatbox.md +++ b/docs/deployment/frameworks/chatbox.md @@ -1,6 +1,7 @@ -(deployment-chatbox)= - -# Chatbox +--- +title: Chatbox +--- +[](){ #deployment-chatbox } [Chatbox](https://github.com/chatboxai/chatbox) is a desktop client for LLMs, available on Windows, Mac, Linux. @@ -27,10 +28,8 @@ vllm serve qwen/Qwen1.5-0.5B-Chat - API Path: `/chat/completions` - Model: `qwen/Qwen1.5-0.5B-Chat` -:::{image} /assets/deployment/chatbox-settings.png -::: +![](../../assets/deployment/chatbox-settings.png) - Go to `Just chat`, and start to chat: -:::{image} /assets/deployment/chatbox-chat.png -::: +![](../../assets/deployment/chatbox-chat.png) diff --git a/docs/deployment/frameworks/dify.md b/docs/deployment/frameworks/dify.md new file mode 100644 index 000000000000..886484b54347 --- /dev/null +++ b/docs/deployment/frameworks/dify.md @@ -0,0 +1,54 @@ +--- +title: Dify +--- +[](){ #deployment-dify } + +[Dify](https://github.com/langgenius/dify) is an open-source LLM app development platform. Its intuitive interface combines agentic AI workflow, RAG pipeline, agent capabilities, model management, observability features, and more, allowing you to quickly move from prototype to production. + +It supports vLLM as a model provider to efficiently serve large language models. + +This guide walks you through deploying Dify using a vLLM backend. 
+ +## Prerequisites + +- Setup vLLM environment +- Install [Docker](https://docs.docker.com/engine/install/) and [Docker Compose](https://docs.docker.com/compose/install/) + +## Deploy + +- Start the vLLM server with the supported chat completion model, e.g. + +```console +vllm serve Qwen/Qwen1.5-7B-Chat +``` + +- Start the Dify server with docker compose ([details](https://github.com/langgenius/dify?tab=readme-ov-file#quick-start)): + +```console +git clone https://github.com/langgenius/dify.git +cd dify +cd docker +cp .env.example .env +docker compose up -d +``` + +- Open the browser to access `http://localhost/install`, configure the basic login information, and log in. + +- In the top-right user menu (under the profile icon), go to Settings, then click `Model Provider`, and locate the `vLLM` provider to install it. + +- Fill in the model provider details as follows: + - **Model Type**: `LLM` + - **Model Name**: `Qwen/Qwen1.5-7B-Chat` + - **API Endpoint URL**: `http://{vllm_server_host}:{vllm_server_port}/v1` + - **Model Name for API Endpoint**: `Qwen/Qwen1.5-7B-Chat` + - **Completion Mode**: `Completion` + +![](../../assets/deployment/dify-settings.png) + +- To create a test chatbot, go to `Studio → Chatbot → Create from Blank`, then select Chatbot as the type: + +![](../../assets/deployment/dify-create-chatbot.png) + +- Click the chatbot you just created to open the chat interface and start interacting with the model: + +![](../../assets/deployment/dify-chat.png) diff --git a/docs/source/deployment/frameworks/dstack.md b/docs/deployment/frameworks/dstack.md similarity index 83% rename from docs/source/deployment/frameworks/dstack.md rename to docs/deployment/frameworks/dstack.md index a16e28f2d898..7de92855745b 100644 --- a/docs/source/deployment/frameworks/dstack.md +++ b/docs/deployment/frameworks/dstack.md @@ -1,12 +1,11 @@ -(deployment-dstack)= +--- +title: dstack +--- +[](){ #deployment-dstack } -# dstack - -:::{raw} html

vLLM_plus_dstack

-::: vLLM can be run on a cloud based GPU machine with [dstack](https://dstack.ai/), an open-source framework for running LLMs on any cloud. This tutorial assumes that you have already configured credentials, gateway, and GPU quotas on your cloud environment. @@ -97,6 +96,5 @@ completion = client.chat.completions.create( print(completion.choices[0].message.content) ``` -:::{note} -dstack automatically handles authentication on the gateway using dstack's tokens. Meanwhile, if you don't want to configure a gateway, you can provision dstack `Task` instead of `Service`. The `Task` is for development purpose only. If you want to know more about hands-on materials how to serve vLLM using dstack, check out [this repository](https://github.com/dstackai/dstack-examples/tree/main/deployment/vllm) -::: +!!! note + dstack automatically handles authentication on the gateway using dstack's tokens. Meanwhile, if you don't want to configure a gateway, you can provision dstack `Task` instead of `Service`. The `Task` is for development purpose only. If you want to know more about hands-on materials how to serve vLLM using dstack, check out [this repository](https://github.com/dstackai/dstack-examples/tree/main/deployment/vllm) diff --git a/docs/deployment/frameworks/haystack.md b/docs/deployment/frameworks/haystack.md new file mode 100644 index 000000000000..2eac4a5279fd --- /dev/null +++ b/docs/deployment/frameworks/haystack.md @@ -0,0 +1,60 @@ +--- +title: Haystack +--- +[](){ #deployment-haystack } + +# Haystack + +[Haystack](https://github.com/deepset-ai/haystack) is an end-to-end LLM framework that allows you to build applications powered by LLMs, Transformer models, vector search and more. Whether you want to perform retrieval-augmented generation (RAG), document search, question answering or answer generation, Haystack can orchestrate state-of-the-art embedding models and LLMs into pipelines to build end-to-end NLP applications and solve your use case. 
+ +It allows you to deploy a large language model (LLM) server with vLLM as the backend, which exposes OpenAI-compatible endpoints. + +## Prerequisites + +- Setup vLLM and Haystack environment + +```console +pip install vllm haystack-ai +``` + +## Deploy + +- Start the vLLM server with the supported chat completion model, e.g. + +```console +vllm serve mistralai/Mistral-7B-Instruct-v0.1 +``` + +- Use the `OpenAIGenerator` and `OpenAIChatGenerator` components in Haystack to query the vLLM server. + +```python +from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.dataclasses import ChatMessage +from haystack.utils import Secret + +generator = OpenAIChatGenerator( + # for compatibility with the OpenAI API, a placeholder api_key is needed + api_key=Secret.from_token("VLLM-PLACEHOLDER-API-KEY"), + model="mistralai/Mistral-7B-Instruct-v0.1", + api_base_url="http://{your-vLLM-host-ip}:{your-vLLM-host-port}/v1", + generation_kwargs = {"max_tokens": 512} +) + +response = generator.run( + messages=[ChatMessage.from_user("Hi. Can you help me plan my next trip to Italy?")] +) + +print("-"*30) +print(response) +print("-"*30) +``` + +Output e.g.: + +```console +------------------------------ +{'replies': [ChatMessage(_role=, _content=[TextContent(text=' Of course! Where in Italy would you like to go and what type of trip are you looking to plan?')], _name=None, _meta={'model': 'mistralai/Mistral-7B-Instruct-v0.1', 'index': 0, 'finish_reason': 'stop', 'usage': {'completion_tokens': 23, 'prompt_tokens': 21, 'total_tokens': 44, 'completion_tokens_details': None, 'prompt_tokens_details': None}})]} +------------------------------ +``` + +For details, see the tutorial [Using vLLM in Haystack](https://github.com/deepset-ai/haystack-integrations/blob/main/integrations/vllm.md). 
diff --git a/docs/deployment/frameworks/helm.md b/docs/deployment/frameworks/helm.md new file mode 100644 index 000000000000..192b90438acf --- /dev/null +++ b/docs/deployment/frameworks/helm.md @@ -0,0 +1,95 @@ +--- +title: Helm +--- +[](){ #deployment-helm } + +A Helm chart to deploy vLLM for Kubernetes + +Helm is a package manager for Kubernetes. It will help you to deploy vLLM on k8s and automate the deployment of vLLM Kubernetes applications. With Helm, you can deploy the same framework architecture with different configurations to multiple namespaces by overriding variable values. + +This guide will walk you through the process of deploying vLLM with Helm, including the necessary prerequisites, steps for helm installation and documentation on architecture and values file. + +## Prerequisites + +Before you begin, ensure that you have the following: + +- A running Kubernetes cluster +- NVIDIA Kubernetes Device Plugin (`k8s-device-plugin`): This can be found at [https://github.com/NVIDIA/k8s-device-plugin](https://github.com/NVIDIA/k8s-device-plugin) +- Available GPU resources in your cluster +- S3 with the model which will be deployed + +## Installing the chart + +To install the chart with the release name `test-vllm`: + +```console +helm upgrade --install --create-namespace --namespace=ns-vllm test-vllm . -f values.yaml --set secrets.s3endpoint=$ACCESS_POINT --set secrets.s3bucketname=$BUCKET --set secrets.s3accesskeyid=$ACCESS_KEY --set secrets.s3accesskey=$SECRET_KEY +``` + +## Uninstalling the Chart + +To uninstall the `test-vllm` deployment: + +```console +helm uninstall test-vllm --namespace=ns-vllm +``` + +The command removes all the Kubernetes components associated with the +chart **including persistent volumes** and deletes the release. 
+ +## Architecture + +![](../../assets/deployment/architecture_helm_deployment.png) + +## Values + +| Key | Type | Default | Description | +|--------------------------------------------|---------|----------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------| +| autoscaling | object | {"enabled":false,"maxReplicas":100,"minReplicas":1,"targetCPUUtilizationPercentage":80} | Autoscaling configuration | +| autoscaling.enabled | bool | false | Enable autoscaling | +| autoscaling.maxReplicas | int | 100 | Maximum replicas | +| autoscaling.minReplicas | int | 1 | Minimum replicas | +| autoscaling.targetCPUUtilizationPercentage | int | 80 | Target CPU utilization for autoscaling | +| configs | object | {} | Configmap | +| containerPort | int | 8000 | Container port | +| customObjects | list | [] | Custom Objects configuration | +| deploymentStrategy | object | {} | Deployment strategy configuration | +| externalConfigs | list | [] | External configuration | +| extraContainers | list | [] | Additional containers configuration | +| extraInit | object | {"pvcStorage":"1Gi","s3modelpath":"relative_s3_model_path/opt-125m", "awsEc2MetadataDisabled": true} | Additional configuration for the init container | +| extraInit.pvcStorage | string | "50Gi" | Storage size of the s3 | +| extraInit.s3modelpath | string | "relative_s3_model_path/opt-125m" | Path of the model on the s3 which hosts model weights and config files | +| extraInit.awsEc2MetadataDisabled | boolean | true | Disables the use of the Amazon EC2 instance metadata service | +| extraPorts | list | [] | Additional ports configuration | +| gpuModels | list | ["TYPE_GPU_USED"] | Type of gpu used | +| image | object | 
{"command":["vllm","serve","/data/","--served-model-name","opt-125m","--host","0.0.0.0","--port","8000"],"repository":"vllm/vllm-openai","tag":"latest"} | Image configuration | +| image.command | list | ["vllm","serve","/data/","--served-model-name","opt-125m","--host","0.0.0.0","--port","8000"] | Container launch command | +| image.repository | string | "vllm/vllm-openai" | Image repository | +| image.tag | string | "latest" | Image tag | +| livenessProbe | object | {"failureThreshold":3,"httpGet":{"path":"/health","port":8000},"initialDelaySeconds":15,"periodSeconds":10} | Liveness probe configuration | +| livenessProbe.failureThreshold | int | 3 | Number of times after which if a probe fails in a row, Kubernetes considers that the overall check has failed: the container is not alive | +| livenessProbe.httpGet | object | {"path":"/health","port":8000} | Configuration of the Kubelet http request on the server | +| livenessProbe.httpGet.path | string | "/health" | Path to access on the HTTP server | +| livenessProbe.httpGet.port | int | 8000 | Name or number of the port to access on the container, on which the server is listening | +| livenessProbe.initialDelaySeconds | int | 15 | Number of seconds after the container has started before liveness probe is initiated | +| livenessProbe.periodSeconds | int | 10 | How often (in seconds) to perform the liveness probe | +| maxUnavailablePodDisruptionBudget | string | "" | Disruption Budget Configuration | +| readinessProbe | object | {"failureThreshold":3,"httpGet":{"path":"/health","port":8000},"initialDelaySeconds":5,"periodSeconds":5} | Readiness probe configuration | +| readinessProbe.failureThreshold | int | 3 | Number of times after which if a probe fails in a row, Kubernetes considers that the overall check has failed: the container is not ready | +| readinessProbe.httpGet | object | {"path":"/health","port":8000} | Configuration of the Kubelet http request on the server | +| readinessProbe.httpGet.path | string | 
"/health" | Path to access on the HTTP server | +| readinessProbe.httpGet.port | int | 8000 | Name or number of the port to access on the container, on which the server is listening | +| readinessProbe.initialDelaySeconds | int | 5 | Number of seconds after the container has started before readiness probe is initiated | +| readinessProbe.periodSeconds | int | 5 | How often (in seconds) to perform the readiness probe | +| replicaCount | int | 1 | Number of replicas | +| resources | object | {"limits":{"cpu":4,"memory":"16Gi","nvidia.com/gpu":1},"requests":{"cpu":4,"memory":"16Gi","nvidia.com/gpu":1}} | Resource configuration | +| resources.limits."nvidia.com/gpu" | int | 1 | Number of gpus used | +| resources.limits.cpu | int | 4 | Number of CPUs | +| resources.limits.memory | string | "16Gi" | CPU memory configuration | +| resources.requests."nvidia.com/gpu" | int | 1 | Number of gpus used | +| resources.requests.cpu | int | 4 | Number of CPUs | +| resources.requests.memory | string | "16Gi" | CPU memory configuration | +| secrets | object | {} | Secrets configuration | +| serviceName | string | Service name | | +| servicePort | int | 80 | Service port | +| labels.environment | string | test | Environment name | diff --git a/docs/deployment/frameworks/litellm.md b/docs/deployment/frameworks/litellm.md new file mode 100644 index 000000000000..3011cde83018 --- /dev/null +++ b/docs/deployment/frameworks/litellm.md @@ -0,0 +1,76 @@ +--- +title: LiteLLM +--- +[](){ #deployment-litellm } + +[LiteLLM](https://github.com/BerriAI/litellm) call all LLM APIs using the OpenAI format [Bedrock, Huggingface, VertexAI, TogetherAI, Azure, OpenAI, Groq etc.] + +LiteLLM manages: + +- Translate inputs to provider's `completion`, `embedding`, and `image_generation` endpoints +- [Consistent output](https://docs.litellm.ai/docs/completion/output), text responses will always be available at `['choices'][0]['message']['content']` +- Retry/fallback logic across multiple deployments (e.g. 
Azure/OpenAI) - [Router](https://docs.litellm.ai/docs/routing) +- Set Budgets & Rate limits per project, api key, model [LiteLLM Proxy Server (LLM Gateway)](https://docs.litellm.ai/docs/simple_proxy) + +LiteLLM supports all models on vLLM. + +## Prerequisites + +- Set up the vLLM and litellm environment + +```console +pip install vllm litellm +``` + +## Deploy + +### Chat completion + +- Start the vLLM server with the supported chat completion model, e.g. + +```console +vllm serve qwen/Qwen1.5-0.5B-Chat +``` + +- Call it with litellm: + +```python +import litellm + +messages = [{ "content": "Hello, how are you?","role": "user"}] + +# the hosted_vllm prefix is a required keyword +response = litellm.completion( + model="hosted_vllm/qwen/Qwen1.5-0.5B-Chat", # pass the vllm model name + messages=messages, + api_base="http://{your-vllm-server-host}:{your-vllm-server-port}/v1", + temperature=0.2, + max_tokens=80) + +print(response) +``` + +### Embeddings + +- Start the vLLM server with the supported embedding model, e.g. + +```console +vllm serve BAAI/bge-base-en-v1.5 +``` + +- Call it with litellm: + +```python +from litellm import embedding +import os + +os.environ["HOSTED_VLLM_API_BASE"] = "http://{your-vllm-server-host}:{your-vllm-server-port}/v1" + +# the hosted_vllm prefix is a required keyword +# pass the vllm model name +embedding = embedding(model="hosted_vllm/BAAI/bge-base-en-v1.5", input=["Hello world"]) + +print(embedding) +``` + +For details, see the tutorial [Using vLLM in LiteLLM](https://docs.litellm.ai/docs/providers/vllm). diff --git a/docs/deployment/frameworks/lobe-chat.md b/docs/deployment/frameworks/lobe-chat.md new file mode 100644 index 000000000000..cd95c028155e --- /dev/null +++ b/docs/deployment/frameworks/lobe-chat.md @@ -0,0 +1,14 @@ +--- +title: Lobe Chat +--- +[](){ #deployment-lobe-chat } + +[Lobe Chat](https://github.com/lobehub/lobe-chat) is an open-source, modern-design ChatGPT/LLMs UI/Framework.
+ +Supports speech-synthesis, multi-modal, and extensible (function call) plugin system. + +One-click FREE deployment of your private OpenAI ChatGPT/Claude/Gemini/Groq/Ollama chat application. + +It supports vLLM as an AI model provider to efficiently serve large language models. + +For details, see the tutorial [Using vLLM in LobeChat](https://lobehub.com/docs/usage/providers/vllm). diff --git a/docs/source/deployment/frameworks/lws.md b/docs/deployment/frameworks/lws.md similarity index 99% rename from docs/source/deployment/frameworks/lws.md rename to docs/deployment/frameworks/lws.md index 4e9a03b5c4c1..18282a89ddff 100644 --- a/docs/source/deployment/frameworks/lws.md +++ b/docs/deployment/frameworks/lws.md @@ -1,6 +1,7 @@ -(deployment-lws)= - -# LWS +--- +title: LWS +--- +[](){ #deployment-lws } LeaderWorkerSet (LWS) is a Kubernetes API that aims to address common deployment patterns of AI/ML inference workloads. A major use case is for multi-host/multi-node distributed inference. diff --git a/docs/source/deployment/frameworks/modal.md b/docs/deployment/frameworks/modal.md similarity index 85% rename from docs/source/deployment/frameworks/modal.md rename to docs/deployment/frameworks/modal.md index e7c42088e36a..dbdb739a1000 100644 --- a/docs/source/deployment/frameworks/modal.md +++ b/docs/deployment/frameworks/modal.md @@ -1,6 +1,7 @@ -(deployment-modal)= - -# Modal +--- +title: Modal +--- +[](){ #deployment-modal } vLLM can be run on cloud GPUs with [Modal](https://modal.com), a serverless computing platform designed for fast auto-scaling.
diff --git a/docs/source/deployment/frameworks/open-webui.md b/docs/deployment/frameworks/open-webui.md similarity index 87% rename from docs/source/deployment/frameworks/open-webui.md rename to docs/deployment/frameworks/open-webui.md index 83e5303a00ef..1ab1931068fa 100644 --- a/docs/source/deployment/frameworks/open-webui.md +++ b/docs/deployment/frameworks/open-webui.md @@ -1,6 +1,7 @@ -(deployment-open-webui)= - -# Open WebUI +--- +title: Open WebUI +--- +[](){ #deployment-open-webui } 1. Install the [Docker](https://docs.docker.com/engine/install/) @@ -25,5 +26,4 @@ ghcr.io/open-webui/open-webui:main On the top of the web page, you can see the model `qwen/Qwen1.5-0.5B-Chat`. -:::{image} /assets/deployment/open_webui.png -::: +![](../../assets/deployment/open_webui.png) diff --git a/docs/source/deployment/frameworks/retrieval_augmented_generation.md b/docs/deployment/frameworks/retrieval_augmented_generation.md similarity index 96% rename from docs/source/deployment/frameworks/retrieval_augmented_generation.md rename to docs/deployment/frameworks/retrieval_augmented_generation.md index f84451fafe91..cb26c8378dee 100644 --- a/docs/source/deployment/frameworks/retrieval_augmented_generation.md +++ b/docs/deployment/frameworks/retrieval_augmented_generation.md @@ -1,6 +1,7 @@ -(deployment-retrieval-augmented-generation)= - -# Retrieval-Augmented Generation +--- +title: Retrieval-Augmented Generation +--- +[](){ #deployment-retrieval-augmented-generation } [Retrieval-augmented generation (RAG)](https://en.wikipedia.org/wiki/Retrieval-augmented_generation) is a technique that enables generative artificial intelligence (Gen AI) models to retrieve and incorporate new information. It modifies interactions with a large language model (LLM) so that the model responds to user queries with reference to a specified set of documents, using this information to supplement information from its pre-existing training data. 
This allows LLMs to use domain-specific and/or updated information. Use cases include providing chatbot access to internal company data or generating responses based on authoritative sources. diff --git a/docs/source/deployment/frameworks/skypilot.md b/docs/deployment/frameworks/skypilot.md similarity index 94% rename from docs/source/deployment/frameworks/skypilot.md rename to docs/deployment/frameworks/skypilot.md index 5e101b900103..9763745f2378 100644 --- a/docs/source/deployment/frameworks/skypilot.md +++ b/docs/deployment/frameworks/skypilot.md @@ -1,12 +1,11 @@ -(deployment-skypilot)= +--- +title: SkyPilot +--- +[](){ #deployment-skypilot } -# SkyPilot - -:::{raw} html

vLLM

-::: vLLM can be **run and scaled to multiple service replicas on clouds and Kubernetes** with [SkyPilot](https://github.com/skypilot-org/skypilot), an open-source framework for running LLMs on any cloud. More examples for various open models, such as Llama-3, Mixtral, etc, can be found in [SkyPilot AI gallery](https://skypilot.readthedocs.io/en/latest/gallery/index.html). @@ -83,7 +82,11 @@ Check the output of the command. There will be a shareable gradio link (like the **Optional**: Serve the 70B model instead of the default 8B and use more GPU: ```console -HF_TOKEN="your-huggingface-token" sky launch serving.yaml --gpus A100:8 --env HF_TOKEN --env MODEL_NAME=meta-llama/Meta-Llama-3-70B-Instruct +HF_TOKEN="your-huggingface-token" \ + sky launch serving.yaml \ + --gpus A100:8 \ + --env HF_TOKEN \ + --env MODEL_NAME=meta-llama/Meta-Llama-3-70B-Instruct ``` ## Scale up to multiple replicas @@ -104,10 +107,8 @@ service: max_completion_tokens: 1 ``` -:::{raw} html
Click to see the full recipe YAML -::: ```yaml service: @@ -153,14 +154,14 @@ run: | 2>&1 | tee api_server.log ``` -:::{raw} html
-::: Start the serving the Llama-3 8B model on multiple replicas: ```console -HF_TOKEN="your-huggingface-token" sky serve up -n vllm serving.yaml --env HF_TOKEN +HF_TOKEN="your-huggingface-token" \ + sky serve up -n vllm serving.yaml \ + --env HF_TOKEN ``` Wait until the service is ready: @@ -169,10 +170,8 @@ Wait until the service is ready: watch -n10 sky serve status vllm ``` -:::{raw} html
Example outputs: -::: ```console Services @@ -185,9 +184,7 @@ vllm 1 1 xx.yy.zz.121 18 mins ago 1x GCP([Spot]{'L4': 1}) R vllm 2 1 xx.yy.zz.245 18 mins ago 1x GCP([Spot]{'L4': 1}) READY us-east4 ``` -:::{raw} html
-::: After the service is READY, you can find a single endpoint for the service and access the service with the endpoint: @@ -223,10 +220,8 @@ service: This will scale the service up to when the QPS exceeds 2 for each replica. -:::{raw} html
Click to see the full recipe YAML -::: ```yaml service: @@ -275,9 +270,7 @@ run: | 2>&1 | tee api_server.log ``` -:::{raw} html
-::: To update the service with the new config: @@ -295,10 +288,8 @@ sky serve down vllm It is also possible to access the Llama-3 service with a separate GUI frontend, so the user requests send to the GUI will be load-balanced across replicas. -:::{raw} html
Click to see the full GUI YAML -::: ```yaml envs: @@ -328,14 +319,14 @@ run: | --stop-token-ids 128009,128001 | tee ~/gradio.log ``` -:::{raw} html
-::: 1. Start the chat web UI: ```console - sky launch -c gui ./gui.yaml --env ENDPOINT=$(sky serve status --endpoint vllm) + sky launch \ + -c gui ./gui.yaml \ + --env ENDPOINT=$(sky serve status --endpoint vllm) ``` 2. Then, we can access the GUI at the returned gradio link: diff --git a/docs/source/deployment/frameworks/streamlit.md b/docs/deployment/frameworks/streamlit.md similarity index 81% rename from docs/source/deployment/frameworks/streamlit.md rename to docs/deployment/frameworks/streamlit.md index 084550ec991e..33ed8c5f5b54 100644 --- a/docs/source/deployment/frameworks/streamlit.md +++ b/docs/deployment/frameworks/streamlit.md @@ -1,6 +1,7 @@ -(deployment-streamlit)= - -# Streamlit +--- +title: Streamlit +--- +[](){ #deployment-streamlit } [Streamlit](https://github.com/streamlit/streamlit) lets you transform Python scripts into interactive web apps in minutes, instead of weeks. Build dashboards, generate reports, or create chat apps. @@ -32,11 +33,11 @@ pip install streamlit openai streamlit run streamlit_openai_chatbot_webserver.py # or specify the VLLM_API_BASE or VLLM_API_KEY -VLLM_API_BASE="http://vllm-server-host:vllm-server-port/v1" streamlit run streamlit_openai_chatbot_webserver.py +VLLM_API_BASE="http://vllm-server-host:vllm-server-port/v1" \ + streamlit run streamlit_openai_chatbot_webserver.py # start with debug mode to view more details streamlit run streamlit_openai_chatbot_webserver.py --logger.level=debug ``` -:::{image} /assets/deployment/streamlit-chat.png -::: +![](../../assets/deployment/streamlit-chat.png) diff --git a/docs/source/deployment/frameworks/triton.md b/docs/deployment/frameworks/triton.md similarity index 87% rename from docs/source/deployment/frameworks/triton.md rename to docs/deployment/frameworks/triton.md index 94d87120159c..082bc24d85aa 100644 --- a/docs/source/deployment/frameworks/triton.md +++ b/docs/deployment/frameworks/triton.md @@ -1,5 +1,6 @@ -(deployment-triton)= - -# NVIDIA Triton +--- +title: NVIDIA 
Triton +--- +[](){ #deployment-triton } The [Triton Inference Server](https://github.com/triton-inference-server) hosts a tutorial demonstrating how to quickly deploy a simple [facebook/opt-125m](https://huggingface.co/facebook/opt-125m) model using vLLM. Please see [Deploying a vLLM model in Triton](https://github.com/triton-inference-server/tutorials/blob/main/Quick_Deploy/vLLM/README.md#deploying-a-vllm-model-in-triton) for more details. diff --git a/docs/source/deployment/integrations/kserve.md b/docs/deployment/integrations/kserve.md similarity index 85% rename from docs/source/deployment/integrations/kserve.md rename to docs/deployment/integrations/kserve.md index c780fd74e8f5..754b983dee92 100644 --- a/docs/source/deployment/integrations/kserve.md +++ b/docs/deployment/integrations/kserve.md @@ -1,6 +1,7 @@ -(deployment-kserve)= - -# KServe +--- +title: KServe +--- +[](){ #deployment-kserve } vLLM can be deployed with [KServe](https://github.com/kserve/kserve) on Kubernetes for highly scalable distributed model serving. diff --git a/docs/source/deployment/integrations/kubeai.md b/docs/deployment/integrations/kubeai.md similarity index 93% rename from docs/source/deployment/integrations/kubeai.md rename to docs/deployment/integrations/kubeai.md index 2f5772e075d8..ba0a3c52cca7 100644 --- a/docs/source/deployment/integrations/kubeai.md +++ b/docs/deployment/integrations/kubeai.md @@ -1,6 +1,7 @@ -(deployment-kubeai)= - -# KubeAI +--- +title: KubeAI +--- +[](){ #deployment-kubeai } [KubeAI](https://github.com/substratusai/kubeai) is a Kubernetes operator that enables you to deploy and manage AI models on Kubernetes. It provides a simple and scalable way to deploy vLLM in production. Functionality such as scale-from-zero, load based autoscaling, model caching, and much more is provided out of the box with zero external dependencies. 
diff --git a/docs/source/deployment/integrations/llamastack.md b/docs/deployment/integrations/llamastack.md similarity index 94% rename from docs/source/deployment/integrations/llamastack.md rename to docs/deployment/integrations/llamastack.md index a6c3569637ab..2ae600a423ff 100644 --- a/docs/source/deployment/integrations/llamastack.md +++ b/docs/deployment/integrations/llamastack.md @@ -1,6 +1,7 @@ -(deployment-llamastack)= - -# Llama Stack +--- +title: Llama Stack +--- +[](){ #deployment-llamastack } vLLM is also available via [Llama Stack](https://github.com/meta-llama/llama-stack) . diff --git a/docs/source/deployment/integrations/llmaz.md b/docs/deployment/integrations/llmaz.md similarity index 87% rename from docs/source/deployment/integrations/llmaz.md rename to docs/deployment/integrations/llmaz.md index cd4a76353d26..03d284c34769 100644 --- a/docs/source/deployment/integrations/llmaz.md +++ b/docs/deployment/integrations/llmaz.md @@ -1,6 +1,7 @@ -(deployment-llmaz)= - -# llmaz +--- +title: llmaz +--- +[](){ #deployment-llmaz } [llmaz](https://github.com/InftyAI/llmaz) is an easy-to-use and advanced inference platform for large language models on Kubernetes, aimed for production use. It uses vLLM as the default model serving backend. diff --git a/docs/source/deployment/integrations/production-stack.md b/docs/deployment/integrations/production-stack.md similarity index 98% rename from docs/source/deployment/integrations/production-stack.md rename to docs/deployment/integrations/production-stack.md index 05f1568306cc..8288a4b6e6be 100644 --- a/docs/source/deployment/integrations/production-stack.md +++ b/docs/deployment/integrations/production-stack.md @@ -1,6 +1,7 @@ -(deployment-production-stack)= - -# Production stack +--- +title: Production stack +--- +[](){ #deployment-production-stack } Deploying vLLM on Kubernetes is a scalable and efficient way to serve machine learning models. 
This guide walks you through deploying vLLM using the [vLLM production stack](https://github.com/vllm-project/production-stack). Born out of a Berkeley-UChicago collaboration, [vLLM production stack](https://github.com/vllm-project/production-stack) is an officially released, production-optimized codebase under the [vLLM project](https://github.com/vllm-project), designed for LLM deployment with: @@ -114,7 +115,7 @@ To remove the deployment, run: sudo helm uninstall vllm ``` ------- +--- ### (Advanced) Configuring vLLM production stack diff --git a/docs/source/deployment/k8s.md b/docs/deployment/k8s.md similarity index 98% rename from docs/source/deployment/k8s.md rename to docs/deployment/k8s.md index 9079cfa8e1b6..6b08c4960d02 100644 --- a/docs/source/deployment/k8s.md +++ b/docs/deployment/k8s.md @@ -1,6 +1,7 @@ -(deployment-k8s)= - -# Using Kubernetes +--- +title: Using Kubernetes +--- +[](){ #deployment-k8s } Deploying vLLM on Kubernetes is a scalable and efficient way to serve machine learning models. This guide walks you through deploying vLLM using native Kubernetes. @@ -8,6 +9,7 @@ Deploying vLLM on Kubernetes is a scalable and efficient way to serve machine le * [Deployment with GPUs](#deployment-with-gpus) Alternatively, you can deploy vLLM to Kubernetes using any of the following: + * [Helm](frameworks/helm.md) * [InftyAI/llmaz](integrations/llmaz.md) * [KServe](integrations/kserve.md) @@ -19,9 +21,8 @@ Alternatively, you can deploy vLLM to Kubernetes using any of the following: ## Deployment with CPUs -:::{note} -The use of CPUs here is for demonstration and testing purposes only and its performance will not be on par with GPUs. -::: +!!! note + The use of CPUs here is for demonstration and testing purposes only and its performance will not be on par with GPUs. 
First, create a Kubernetes PVC and Secret for downloading and storing Hugging Face model: diff --git a/docs/source/deployment/nginx.md b/docs/deployment/nginx.md similarity index 60% rename from docs/source/deployment/nginx.md rename to docs/deployment/nginx.md index bf404f1098c3..80242919ba5b 100644 --- a/docs/source/deployment/nginx.md +++ b/docs/deployment/nginx.md @@ -1,20 +1,21 @@ -(nginxloadbalancer)= - -# Using Nginx +--- +title: Using Nginx +--- +[](){ #nginxloadbalancer } This document shows how to launch multiple vLLM serving containers and use Nginx to act as a load balancer between the servers. Table of contents: -1. [Build Nginx Container](#nginxloadbalancer-nginx-build) -2. [Create Simple Nginx Config file](#nginxloadbalancer-nginx-conf) -3. [Build vLLM Container](#nginxloadbalancer-nginx-vllm-container) -4. [Create Docker Network](#nginxloadbalancer-nginx-docker-network) -5. [Launch vLLM Containers](#nginxloadbalancer-nginx-launch-container) -6. [Launch Nginx](#nginxloadbalancer-nginx-launch-nginx) -7. [Verify That vLLM Servers Are Ready](#nginxloadbalancer-nginx-verify-nginx) +1. [Build Nginx Container][nginxloadbalancer-nginx-build] +2. [Create Simple Nginx Config file][nginxloadbalancer-nginx-conf] +3. [Build vLLM Container][nginxloadbalancer-nginx-vllm-container] +4. [Create Docker Network][nginxloadbalancer-nginx-docker-network] +5. [Launch vLLM Containers][nginxloadbalancer-nginx-launch-container] +6. [Launch Nginx][nginxloadbalancer-nginx-launch-nginx] +7. [Verify That vLLM Servers Are Ready][nginxloadbalancer-nginx-verify-nginx] -(nginxloadbalancer-nginx-build)= +[](){ #nginxloadbalancer-nginx-build } ## Build Nginx Container @@ -39,7 +40,7 @@ Build the container: docker build . 
-f Dockerfile.nginx --tag nginx-lb ``` -(nginxloadbalancer-nginx-conf)= +[](){ #nginxloadbalancer-nginx-conf } ## Create Simple Nginx Config file @@ -63,7 +64,7 @@ server { } ``` -(nginxloadbalancer-nginx-vllm-container)= +[](){ #nginxloadbalancer-nginx-vllm-container } ## Build vLLM Container @@ -76,10 +77,14 @@ If you are behind proxy, you can pass the proxy settings to the docker build com ```console cd $vllm_root -docker build -f docker/Dockerfile . --tag vllm --build-arg http_proxy=$http_proxy --build-arg https_proxy=$https_proxy +docker build \ + -f docker/Dockerfile . \ + --tag vllm \ + --build-arg http_proxy=$http_proxy \ + --build-arg https_proxy=$https_proxy ``` -(nginxloadbalancer-nginx-docker-network)= +[](){ #nginxloadbalancer-nginx-docker-network } ## Create Docker Network @@ -87,7 +92,7 @@ docker build -f docker/Dockerfile . --tag vllm --build-arg http_proxy=$http_prox docker network create vllm_nginx ``` -(nginxloadbalancer-nginx-launch-container)= +[](){ #nginxloadbalancer-nginx-launch-container } ## Launch vLLM Containers @@ -101,23 +106,45 @@ Notes: ```console mkdir -p ~/.cache/huggingface/hub/ hf_cache_dir=~/.cache/huggingface/ -docker run -itd --ipc host --network vllm_nginx --gpus device=0 --shm-size=10.24gb -v $hf_cache_dir:/root/.cache/huggingface/ -p 8081:8000 --name vllm0 vllm --model meta-llama/Llama-2-7b-chat-hf -docker run -itd --ipc host --network vllm_nginx --gpus device=1 --shm-size=10.24gb -v $hf_cache_dir:/root/.cache/huggingface/ -p 8082:8000 --name vllm1 vllm --model meta-llama/Llama-2-7b-chat-hf +docker run \ + -itd \ + --ipc host \ + --network vllm_nginx \ + --gpus device=0 \ + --shm-size=10.24gb \ + -v $hf_cache_dir:/root/.cache/huggingface/ \ + -p 8081:8000 \ + --name vllm0 vllm \ + --model meta-llama/Llama-2-7b-chat-hf +docker run \ + -itd \ + --ipc host \ + --network vllm_nginx \ + --gpus device=1 \ + --shm-size=10.24gb \ + -v $hf_cache_dir:/root/.cache/huggingface/ \ + -p 8082:8000 \ + --name vllm1 vllm \ + --model 
meta-llama/Llama-2-7b-chat-hf ``` -:::{note} -If you are behind proxy, you can pass the proxy settings to the docker run command via `-e http_proxy=$http_proxy -e https_proxy=$https_proxy`. -::: +!!! note + If you are behind proxy, you can pass the proxy settings to the docker run command via `-e http_proxy=$http_proxy -e https_proxy=$https_proxy`. -(nginxloadbalancer-nginx-launch-nginx)= +[](){ #nginxloadbalancer-nginx-launch-nginx } ## Launch Nginx ```console -docker run -itd -p 8000:80 --network vllm_nginx -v ./nginx_conf/:/etc/nginx/conf.d/ --name nginx-lb nginx-lb:latest +docker run \ + -itd \ + -p 8000:80 \ + --network vllm_nginx \ + -v ./nginx_conf/:/etc/nginx/conf.d/ \ + --name nginx-lb nginx-lb:latest ``` -(nginxloadbalancer-nginx-verify-nginx)= +[](){ #nginxloadbalancer-nginx-verify-nginx } ## Verify That vLLM Servers Are Ready diff --git a/docs/source/design/arch_overview.md b/docs/design/arch_overview.md similarity index 81% rename from docs/source/design/arch_overview.md rename to docs/design/arch_overview.md index 94bda8b5c58d..75d3e1b7ccc7 100644 --- a/docs/source/design/arch_overview.md +++ b/docs/design/arch_overview.md @@ -1,22 +1,18 @@ -(arch-overview)= - -# Architecture Overview +--- +title: Architecture Overview +--- +[](){ #arch-overview } This document provides an overview of the vLLM architecture. -:::{contents} Table of Contents -:depth: 2 -:local: true -::: +[TOC] ## Entrypoints vLLM provides a number of entrypoints for interacting with the system. The following diagram shows the relationship between them. -:::{image} /assets/design/arch_overview/entrypoints.excalidraw.png -:alt: Entrypoints Diagram -::: +![Entrypoints Diagram](../assets/design/arch_overview/entrypoints.excalidraw.png) ### LLM Class @@ -77,16 +73,14 @@ python -m vllm.entrypoints.openai.api_server --model That code can be found in . -More details on the API server can be found in the [OpenAI-Compatible Server](#openai-compatible-server) document. 
+More details on the API server can be found in the [OpenAI-Compatible Server][openai-compatible-server] document. ## LLM Engine The `LLMEngine` and `AsyncLLMEngine` classes are central to the functioning of the vLLM system, handling model inference and asynchronous request processing. -:::{image} /assets/design/arch_overview/llm_engine.excalidraw.png -:alt: LLMEngine Diagram -::: +![LLMEngine Diagram](../assets/design/arch_overview/llm_engine.excalidraw.png) ### LLMEngine @@ -137,18 +131,16 @@ input tensors and capturing cudagraphs. ## Model Every model runner object has one model object, which is the actual -`torch.nn.Module` instance. See [huggingface_integration](#huggingface-integration) for how various +`torch.nn.Module` instance. See [huggingface_integration][huggingface-integration] for how various configurations affect the class we ultimately get. ## Class Hierarchy The following figure shows the class hierarchy of vLLM: -> :::{figure} /assets/design/hierarchy.png -> :align: center -> :alt: query -> :width: 100% -> ::: +>
+> ![](../assets/design/hierarchy.png){ align="center" alt="query" width="100%" } +>
There are several important design choices behind this class hierarchy: @@ -178,44 +170,43 @@ of a vision model and a language model. By making the constructor uniform, we can easily create a vision model and a language model and compose them into a vision-language model. -:::{note} -To support this change, all vLLM models' signatures have been updated to: - -```python -def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): -``` - -To avoid accidentally passing incorrect arguments, the constructor is now keyword-only. This ensures that the constructor will raise an error if old configurations are passed. vLLM developers have already made this change for all models within vLLM. For out-of-tree registered models, developers need to update their models, for example by adding shim code to adapt the old constructor signature to the new one: +!!! note + To support this change, all vLLM models' signatures have been updated to: -```python -class MyOldModel(nn.Module): - def __init__( - self, - config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - prefix: str = "", - ) -> None: - ... - -from vllm.config import VllmConfig -class MyNewModel(MyOldModel): + ```python def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - super().__init__(config, cache_config, quant_config, lora_config, prefix) - -if __version__ >= "0.6.4": - MyModel = MyNewModel -else: - MyModel = MyOldModel -``` - -This way, the model can work with both old and new versions of vLLM. -::: + ``` + + To avoid accidentally passing incorrect arguments, the constructor is now keyword-only. This ensures that the constructor will raise an error if old configurations are passed. 
vLLM developers have already made this change for all models within vLLM. For out-of-tree registered models, developers need to update their models, for example by adding shim code to adapt the old constructor signature to the new one: + + ```python + class MyOldModel(nn.Module): + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + ... + + from vllm.config import VllmConfig + class MyNewModel(MyOldModel): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + super().__init__(config, cache_config, quant_config, lora_config, prefix) + + if __version__ >= "0.6.4": + MyModel = MyNewModel + else: + MyModel = MyOldModel + ``` + + This way, the model can work with both old and new versions of vLLM. 3\. **Sharding and Quantization at Initialization**: Certain features require changing the model weights. For example, tensor parallelism needs to shard the diff --git a/docs/source/design/automatic_prefix_caching.md b/docs/design/automatic_prefix_caching.md similarity index 98% rename from docs/source/design/automatic_prefix_caching.md rename to docs/design/automatic_prefix_caching.md index 3928e0c16568..80883bb1d90d 100644 --- a/docs/source/design/automatic_prefix_caching.md +++ b/docs/design/automatic_prefix_caching.md @@ -1,6 +1,7 @@ -(design-automatic-prefix-caching)= - -# Automatic Prefix Caching +--- +title: Automatic Prefix Caching +--- +[](){ #design-automatic-prefix-caching } The core idea of [PagedAttention](https://blog.vllm.ai/2023/06/20/vllm.html) is to partition the KV cache of each request into KV Blocks. Each block contains the attention keys and values for a fixed number of tokens. 
The PagedAttention algorithm allows these blocks to be stored in non-contiguous physical memory so that we can eliminate memory fragmentation by allocating the memory on demand. diff --git a/docs/source/design/huggingface_integration.md b/docs/design/huggingface_integration.md similarity index 64% rename from docs/source/design/huggingface_integration.md rename to docs/design/huggingface_integration.md index 7d271b1cfb3a..2d462ccb6535 100644 --- a/docs/source/design/huggingface_integration.md +++ b/docs/design/huggingface_integration.md @@ -1,23 +1,22 @@ -(huggingface-integration)= - -# Integration with HuggingFace +--- +title: Integration with HuggingFace +--- +[](){ #huggingface-integration } This document describes how vLLM integrates with HuggingFace libraries. We will explain step by step what happens under the hood when we run `vllm serve`. Let's say we want to serve the popular QWen model by running `vllm serve Qwen/Qwen2-7B`. 1. The `model` argument is `Qwen/Qwen2-7B`. vLLM determines whether this model exists by checking for the corresponding config file `config.json`. See this [code snippet](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L162-L182) for the implementation. Within this process: - - - If the `model` argument corresponds to an existing local path, vLLM will load the config file directly from this path. - - If the `model` argument is a HuggingFace model ID consisting of a username and model name, vLLM will first try to use the config file from the HuggingFace local cache, using the `model` argument as the model name and the `--revision` argument as the revision. See [their website](https://huggingface.co/docs/huggingface_hub/en/package_reference/environment_variables#hfhome) for more information on how the HuggingFace cache works. 
- - If the `model` argument is a HuggingFace model ID but it is not found in the cache, vLLM will download the config file from the HuggingFace model hub. Refer to [this function](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L91) for the implementation. The input arguments include the `model` argument as the model name, the `--revision` argument as the revision, and the environment variable `HF_TOKEN` as the token to access the model hub. In our case, vLLM will download the [config.json](https://huggingface.co/Qwen/Qwen2-7B/blob/main/config.json) file. + - If the `model` argument corresponds to an existing local path, vLLM will load the config file directly from this path. + - If the `model` argument is a HuggingFace model ID consisting of a username and model name, vLLM will first try to use the config file from the HuggingFace local cache, using the `model` argument as the model name and the `--revision` argument as the revision. See [their website](https://huggingface.co/docs/huggingface_hub/en/package_reference/environment_variables#hfhome) for more information on how the HuggingFace cache works. + - If the `model` argument is a HuggingFace model ID but it is not found in the cache, vLLM will download the config file from the HuggingFace model hub. Refer to [this function](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L91) for the implementation. The input arguments include the `model` argument as the model name, the `--revision` argument as the revision, and the environment variable `HF_TOKEN` as the token to access the model hub. In our case, vLLM will download the [config.json](https://huggingface.co/Qwen/Qwen2-7B/blob/main/config.json) file. 2. After confirming the existence of the model, vLLM loads its config file and converts it into a dictionary. 
See this [code snippet](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L185-L186) for the implementation. 3. Next, vLLM [inspects](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L189) the `model_type` field in the config dictionary to [generate](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L190-L216) the config object to use. There are some `model_type` values that vLLM directly supports; see [here](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L48) for the list. If the `model_type` is not in the list, vLLM will use [AutoConfig.from_pretrained](https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoConfig.from_pretrained) to load the config class, with `model`, `--revision`, and `--trust_remote_code` as the arguments. Please note that: - - - HuggingFace also has its own logic to determine the config class to use. It will again use the `model_type` field to search for the class name in the transformers library; see [here](https://github.com/huggingface/transformers/tree/main/src/transformers/models) for the list of supported models. If the `model_type` is not found, HuggingFace will use the `auto_map` field from the config JSON file to determine the class name. Specifically, it is the `AutoConfig` field under `auto_map`. See [DeepSeek](https://huggingface.co/deepseek-ai/DeepSeek-V2.5/blob/main/config.json) for an example. - - The `AutoConfig` field under `auto_map` points to a module path in the model's repository. To create the config class, HuggingFace will import the module and use the `from_pretrained` method to load the config class. This can generally cause arbitrary code execution, so it is only executed when `--trust_remote_code` is enabled. 
+ - HuggingFace also has its own logic to determine the config class to use. It will again use the `model_type` field to search for the class name in the transformers library; see [here](https://github.com/huggingface/transformers/tree/main/src/transformers/models) for the list of supported models. If the `model_type` is not found, HuggingFace will use the `auto_map` field from the config JSON file to determine the class name. Specifically, it is the `AutoConfig` field under `auto_map`. See [DeepSeek](https://huggingface.co/deepseek-ai/DeepSeek-V2.5/blob/main/config.json) for an example. + - The `AutoConfig` field under `auto_map` points to a module path in the model's repository. To create the config class, HuggingFace will import the module and use the `from_pretrained` method to load the config class. This can generally cause arbitrary code execution, so it is only executed when `--trust_remote_code` is enabled. 4. Subsequently, vLLM applies some historical patches to the config object. These are mostly related to RoPE configuration; see [here](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/config.py#L244) for the implementation. @@ -28,8 +27,7 @@ Beyond that, there are two more things vLLM depends on HuggingFace for. 1. **Tokenizer**: vLLM uses the tokenizer from HuggingFace to tokenize the input text. The tokenizer is loaded using [AutoTokenizer.from_pretrained](https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained) with the `model` argument as the model name and the `--revision` argument as the revision. It is also possible to use a tokenizer from another model by specifying the `--tokenizer` argument in the `vllm serve` command. Other relevant arguments are `--tokenizer-revision` and `--tokenizer-mode`. Please check HuggingFace's documentation for the meaning of these arguments. 
This part of the logic can be found in the [get_tokenizer](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/tokenizer.py#L87) function. After obtaining the tokenizer, notably, vLLM will cache some expensive attributes of the tokenizer in [get_cached_tokenizer](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/tokenizer.py#L24). 2. **Model weight**: vLLM downloads the model weight from the HuggingFace model hub using the `model` argument as the model name and the `--revision` argument as the revision. vLLM provides the argument `--load-format` to control what files to download from the model hub. By default, it will try to load the weights in the safetensors format and fall back to the PyTorch bin format if the safetensors format is not available. We can also pass `--load-format dummy` to skip downloading the weights. - - - It is recommended to use the safetensors format, as it is efficient for loading in distributed inference and also safe from arbitrary code execution. See the [documentation](https://huggingface.co/docs/safetensors/en/index) for more information on the safetensors format. This part of the logic can be found [here](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/model_executor/model_loader/loader.py#L385). Please note that: + - It is recommended to use the safetensors format, as it is efficient for loading in distributed inference and also safe from arbitrary code execution. See the [documentation](https://huggingface.co/docs/safetensors/en/index) for more information on the safetensors format. This part of the logic can be found [here](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/model_executor/model_loader/loader.py#L385). Please note that: This completes the integration between vLLM and HuggingFace. 
diff --git a/docs/design/kernel/paged_attention.md b/docs/design/kernel/paged_attention.md new file mode 100644 index 000000000000..6ebe1ee48acf --- /dev/null +++ b/docs/design/kernel/paged_attention.md @@ -0,0 +1,498 @@ +--- +title: vLLM Paged Attention +--- +[](){ #design-paged-attention } + +Currently, vLLM utilizes its own implementation of a multi-head query +attention kernel (`csrc/attention/attention_kernels.cu`). +This kernel is designed to be compatible with +vLLM's paged KV caches, where the key and value cache are stored in +separate blocks (note that this block concept differs from the GPU +thread block. So in a later document, I will refer to vLLM paged +attention block as "block", while refer to GPU thread block as +"thread block"). + +To achieve high performance, this kernel relies on a specially +designed memory layout and access method, specifically when threads +read data from global memory to shared memory. The purpose of this +document is to provide a high-level explanation of the kernel +implementation step by step, aiding those who wish to learn about the +vLLM multi-head query attention kernel. After going through this +document, users will likely have a better understanding and feel easier +to follow the actual implementation. + +Please note that this document may not cover all details, such as how +to calculate the correct index for the corresponding data or the dot +multiplication implementation. However, after reading this document +and becoming familiar with the high-level logic flow, it should be +easier for you to read the actual code and understand the details. + +## Inputs + +The kernel function takes a list of arguments for the current thread +to perform its assigned work. The three most important arguments are +the input pointers `q`, `k_cache`, and `v_cache`, which point +to query, key, and value data on global memory that need to be read +and processed. 
The output pointer `out` points to global memory +where the result should be written. These four pointers actually +refer to multi-dimensional arrays, but each thread only accesses the +portion of data assigned to it. I have omitted all other runtime +parameters here for simplicity. + +```cpp +template< +typename scalar_t, +int HEAD_SIZE, +int BLOCK_SIZE, +int NUM_THREADS, +int PARTITION_SIZE = 0> +__device__ void paged_attention_kernel( + ... // Other side args. + const scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + ... // Other side args. +) +``` + +There is also a list of template arguments above the function +signature that are determined at compile time. `scalar_t` +represents the data type of the query, key, and value data elements, +such as FP16. `HEAD_SIZE` indicates the number of elements in each +head. `BLOCK_SIZE` refers to the number of tokens in each block. +`NUM_THREADS` denotes the number of threads in each thread block. +`PARTITION_SIZE` represents the number of tensor parallel GPUs (For +simplicity, we assume this is 0 and tensor parallel is disabled). + +With these arguments, we need to perform a sequence of preparations. +This includes calculating the current head index, block index, and +other necessary variables. However, for now, we can ignore these +preparations and proceed directly to the actual calculations. It will +be easier to understand them once we grasp the entire flow. + +## Concepts + +Just before we dive into the calculation flow, I want to describe a +few concepts that are needed for later sections. However, you may +skip this section and return later if you encounter any confusing +terminology. + +- **Sequence**: A sequence represents a client request.
For example, + the data pointed to by `q` has a shape of + `[num_seqs, num_heads, head_size]`. This means that `q` points to a + total of `num_seqs` query sequences. Since this + kernel is a single query attention kernel, each sequence only has one + query token. Hence, the `num_seqs` equals the total number of tokens + that are processed in the batch. +- **Context**: The context consists of the generated tokens from the + sequence. For instance, `["What", "is", "your"]` are the context + tokens, and the input query token is `"name"`. The model might + generate the token `"?"`. +- **Vec**: The vec is a list of elements that are fetched and + calculated together. For query and key data, the vec size + (`VEC_SIZE`) is determined so that each thread group can fetch and + calculate 16 bytes of data at a time. For value data, the vec size + (`V_VEC_SIZE`) is determined so that each thread can fetch and + calculate 16 bytes of data at a time. For example, if the + `scalar_t` is FP16 (2 bytes) and `THREAD_GROUP_SIZE` is 2, the + `VEC_SIZE` will be 4, while the `V_VEC_SIZE` will be 8. +- **Thread group**: The thread group is a small group of + threads (`THREAD_GROUP_SIZE`) that fetches and calculates one + query token and one key token at a time. Each thread handles only a + portion of the token data. The total number of elements processed by + one thread group is referred to as `x`. For example, if the thread + group contains 2 threads and the head size is 8, then thread 0 + handles the query and key elements at index 0, 2, 4, 6, while thread + 1 handles the elements at index 1, 3, 5, 7. +- **Block**: The key and value cache data in vLLM are split into + blocks. Each block stores data for a fixed number (`BLOCK_SIZE`) + of tokens at one head. Each block may contain only a portion of the + whole context tokens. For example, if the block size is 16 and the + head size is 128, then for one head, one block can store 16 * 128 = + 2048 elements.
+- **Warp**: A warp is a group of 32 threads (`WARP_SIZE`) that + execute simultaneously on a streaming multiprocessor (SM). In this + kernel, each warp processes the calculation between one query token + and key tokens of one entire block at a time (it may process multiple + blocks in multiple iterations). For example, if there are 4 warps and + 6 blocks for one context, the assignment would be like warp 0 handles + the 0th, 4th blocks, warp 1 handles the 1st, 5th blocks, warp 2 + handles the 2nd block and warp 3 handles the 3rd block. +- **Thread block**: A thread block is a group of + threads (`NUM_THREADS`) that can access the same shared memory. + Each thread block contains multiple warps (`NUM_WARPS`), and in + this kernel, each thread block processes the calculation between one + query token and key tokens of a whole context. +- **Grid**: A grid is a collection of thread blocks and defines the + shape of the collection. In this kernel, the shape is + `(num_heads, num_seqs, max_num_partitions)`. Therefore, each thread + block only handles the calculation for one head, one sequence, and + one partition. + +## Query + +This section will introduce how query data is stored in memory and +fetched by each thread. As mentioned above, each thread group fetches +one query token data, while each thread itself only handles a part of +one query token data. Within each warp, every thread group will fetch +the same query token data, but will multiply it with different key +token data. + +```cpp +const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; +``` + +
+ ![](../../assets/kernel/query.png){ align="center" alt="query" width="70%" } +
+ +Each thread defines its own `q_ptr` which points to the assigned +query token data on global memory. For example, if `VEC_SIZE` is 4 +and `HEAD_SIZE` is 128, the `q_ptr` points to data that contains +total of 128 elements divided into 128 / 4 = 32 vecs. + +
+ ![](../../assets/kernel/q_vecs.png){ align="center" alt="q_vecs" width="70%" } +
+ +```cpp +__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; +``` + +Next, we need to read the global memory data pointed to by `q_ptr` +into shared memory as `q_vecs`. It is important to note that each +vecs is assigned to a different row. For example, if the +`THREAD_GROUP_SIZE` is 2, thread 0 will handle the 0th row vecs, +while thread 1 handles the 1st row vecs. By reading the query data in +this way, neighboring threads like thread 0 and thread 1 can read +neighbor memory, achieving the memory coalescing to improve +performance. + +## Key + +Similar to the "Query" section, this section introduces memory layout +and assignment for keys. While each thread group only handle one +query token one kernel run, it may handle multiple key tokens across +multiple iterations. Meanwhile, each warp will process multiple blocks +of key tokens in multiple iterations, ensuring that all context +tokens are processed by the entire thread group after the kernel run. +In this context, "handle" refers to performing the dot multiplication +between query data and key data. + +```cpp +const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + + physical_block_offset * x; +``` + +Unlike to `q_ptr`, `k_ptr` in each thread will point to different +key token at different iterations. As shown above, that `k_ptr` +points to key token data based on `k_cache` at assigned block, +assigned head and assigned token. + +
+ ![](../../assets/kernel/key.png){ align="center" alt="key" width="70%" } +
+ +The diagram above illustrates the memory layout for key data. It +assumes that the `BLOCK_SIZE` is 16, `HEAD_SIZE` is 128, `x` is +8, `THREAD_GROUP_SIZE` is 2, and there are a total of 4 warps. Each +rectangle represents all the elements for one key token at one head, +which will be processed by one thread group. The left half shows the +total 16 blocks of key token data for warp 0, while the right half +represents the remaining key token data for other warps or +iterations. Inside each rectangle, there are a total 32 vecs (128 +elements for one token) that will be processed by 2 threads (one +thread group) separately. + +
+ ![](../../assets/kernel/k_vecs.png){ align="center" alt="k_vecs" width="70%" } +
+ +```cpp +K_vec k_vecs[NUM_VECS_PER_THREAD] +``` + +Next, we need to read the key token data from `k_ptr` and store +them on register memory as `k_vecs`. We use register memory for +`k_vecs` because it will only be accessed by one thread once, +whereas `q_vecs` will be accessed by multiple threads multiple +times. Each `k_vecs` will contain multiple vectors for later +calculation. Each vec will be set at each inner iteration. The +assignment of vecs allows neighboring threads in a warp to read +neighboring memory together, which again promotes the memory +coalescing. For instance, thread 0 will read vec 0, while thread 1 +will read vec 1. In the next inner loop, thread 0 will read vec 2, +while thread 1 will read vec 3, and so on. + +You may still be a little confused about the overall flow. Don't +worry, please keep reading the next "QK" section. It will illustrate +the query and key calculation flow in a clearer and higher-level +manner. + +## QK + +As shown in the pseudocode below, before the entire for loop block, we +fetch the query data for one token and store it in `q_vecs`. Then, +in the outer for loop, we iterate through different `k_ptrs` that +point to different tokens and prepare the `k_vecs` in the inner for +loop. Finally, we perform the dot multiplication between the +`q_vecs` and each `k_vecs`. + +```cpp +q_vecs = ... +for ... { + k_ptr = ... + for ... { + k_vecs[i] = ... + } + ... + float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs); +} +``` + +As mentioned before, each thread only fetches part of the +query and key token data at a time. However, a cross-thread-group +reduction happens inside `Qk_dot<>::dot`. So the `qk` +returned here is not just the dot product of part of the query and key +token data, but the full result over the entire query and +key token data. + +For example, if the value of `HEAD_SIZE` is 128 and +`THREAD_GROUP_SIZE` is 2, each thread's `k_vecs` will contain +a total of 64 elements.
However, the returned `qk` is actually the +result of dot multiplication between 128 query elements and 128 key +elements. If you want to learn more about the details of the dot +multiplication and reduction, you may refer to the implementation of +`Qk_dot<>::dot`. However, for the sake of simplicity, I will not +cover it in this document. + +## Softmax + +Next, we need to calculate the normalized softmax for all `qk`s, +as shown above, where each $x$ represents a `qk`. To do this, +we must obtain the reduced value of `qk_max`($m(x)$) and +the `exp_sum`($\ell(x)$) of all `qk`s. The reduction +should be performed across the entire thread block, encompassing +results between the query token and all context key tokens. + +$$ +\begin{gather*} +m(x):=\max _i \quad x_i \\ \quad f(x):=\left[\begin{array}{lll}e^{x_1-m(x)} & \ldots & e^{x_B-m(x)}\end{array}\right]\\ \quad \ell(x):=\sum_i f(x)_i \\ +\quad \operatorname{softmax}(x):=\frac{f(x)}{\ell(x)} +\end{gather*} +$$ + +### `qk_max` and `logits` + +Just right after we get the `qk` result, we can set the temporary +`logits` result with `qk` (In the end, the `logits` should +store the normalized softmax result). Also we can compare and collect +the `qk_max` for all `qk`s that are calculated by current +thread group. + +```cpp +if (thread_group_offset == 0) { + const bool mask = token_idx >= context_len; + logits[token_idx - start_token_idx] = mask ? 0.f : qk; + qk_max = mask ? qk_max : fmaxf(qk_max, qk); +} +``` + +Please note that the `logits` here is on shared memory, so each +thread group will set the fields for its own assigned context tokens. +Overall, the size of logits should be number of context tokens. + +```cpp +for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); +} + +if (lane == 0) { + red_smem[warp_idx] = qk_max; +} +``` + +Then we need to get the reduced `qk_max` across each warp. 
The main +idea is to let the threads in a warp communicate with each other and +get the final max `qk`. + +```cpp +for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); +} +qk_max = VLLM_SHFL_SYNC(qk_max, 0); +``` + +Finally, we can get the reduced `qk_max` from the whole thread block by +comparing the `qk_max` from all warps in this thread block. Then we +need to broadcast the final result to each thread. + +### `exp_sum` + +Similar to `qk_max`, we need to get the reduced sum value from the +entire thread block too. + +```cpp +for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + float val = __expf(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; +} +... +exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum); +``` + +Firstly, sum all exp values from each thread group, and meanwhile, +convert each entry of `logits` from `qk` to `exp(qk - qk_max)`. +Please note, the `qk_max` here is already the max `qk` across the +whole thread block. And then we can do reduction for `exp_sum` +across the whole thread block just like the `qk_max`. + +```cpp +const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); +for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + logits[i] *= inv_sum; +} +``` + +Finally, with the reduced `qk_max` and `exp_sum`, we can obtain +the final normalized softmax result as `logits`. This `logits` +variable will be used for dot multiplication with the value data in +later steps. Now, it should store the normalized softmax result of +`qk` for all assigned context tokens. + +## Value + +
+ ![](../../assets/kernel/value.png){ align="center" alt="value" width="70%" } +
+ +
+ ![](../../assets/kernel/logits_vec.png){ align="center" alt="logits_vec" width="50%" } +
+ +
+ ![](../../assets/kernel/v_vec.png){ align="center" alt="v_vec" width="70%" } +
+ +Now we need to retrieve the value data and perform dot multiplication +with `logits`. Unlike query and key, there is no thread group +concept for value data. As shown in diagram, different from key token +memory layout, elements from the same column correspond to the same +value token. For one block of value data, there are `HEAD_SIZE` of +rows and `BLOCK_SIZE` of columns that are split into multiple +`v_vecs`. + +Each thread always fetches `V_VEC_SIZE` elements from the same +`V_VEC_SIZE` of tokens at a time. As a result, a single thread +retrieves multiple `v_vec`s from different rows and the same +columns through multiple inner iterations. For each `v_vec`, it +needs to be dot multiplied with the corresponding `logits_vec`, +which is also `V_VEC_SIZE` elements from `logits`. Overall, with +multiple inner iterations, each warp will process one block of value +tokens. And with multiple outer iterations, the whole context value +tokens are processed + +```cpp +float accs[NUM_ROWS_PER_THREAD]; +for ... { // Iteration over different blocks. + logits_vec = ... + for ... { // Iteration over different rows. + v_vec = ... + ... + accs[i] += dot(logits_vec, v_vec); + } +} +``` + +As shown in the above pseudo code, in the outer loop, similar to +`k_ptr`, `logits_vec` iterates over different blocks and reads +`V_VEC_SIZE` elements from `logits`. In the inner loop, each +thread reads `V_VEC_SIZE` elements from the same tokens as a +`v_vec` and performs dot multiplication. It is important to note +that in each inner iteration, the thread fetches different head +position elements for the same tokens. The dot result is then +accumulated in `accs`. Therefore, each entry of `accs` is mapped +to a head position assigned to the current thread. + +For example, if `BLOCK_SIZE` is 16 and `V_VEC_SIZE` is 8, each +thread fetches 8 value elements for 8 tokens at a time. Each element +is from different tokens at the same head position. 
If `HEAD_SIZE` +is 128 and `WARP_SIZE` is 32, for each inner loop, a warp needs to +fetch `WARP_SIZE * V_VEC_SIZE = 256` elements. This means there are +a total of 128 * 16 / 256 = 8 inner iterations for a warp to handle +a whole block of value tokens. And each `accs` in each thread +contains 8 elements that are accumulated at 8 different head positions. +For thread 0, the `accs` variable will have 8 elements, which +are the 0th, 32nd, …, 224th elements of a value head that are accumulated +from all assigned 8 tokens. + +## LV + +Now, we need to perform reduction for `accs` within each warp. This +process allows each thread to accumulate the `accs` for the +assigned head positions of all tokens in one block. + +```cpp +for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { + acc += VLLM_SHFL_XOR_SYNC(acc, mask); + } + accs[i] = acc; +} +``` + +Next, we perform reduction for `accs` across all warps, allowing +each thread to have the accumulation of `accs` for the assigned +head positions of all context tokens. Please note that each `accs` +in every thread only stores the accumulation for a portion of +elements of the entire head for all context tokens. However, overall, +all results for output have been calculated but are just stored in +different thread register memory. + +```cpp +float* out_smem = reinterpret_cast<float*>(shared_mem); +for (int i = NUM_WARPS; i > 1; i /= 2) { + // Upper warps write to shared memory. + ... + float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + ... + dst[row_idx] = accs[i]; + } + + // Lower warps update the output. + const float* src = &out_smem[warp_idx * HEAD_SIZE]; + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + ... + accs[i] += src[row_idx]; + } + + // Write out the accs. +} +``` + +## Output + +Now we can write all of the calculated results from local register memory +to final output global memory.
+ +```cpp +scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; +``` + +First, we need to define the `out_ptr` variable, which points to +the start address of the assigned sequence and assigned head. + +```cpp +for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + from_float(*(out_ptr + row_idx), accs[i]); + } +} +``` + +Finally, we need to iterate over different assigned head positions +and write out the corresponding accumulated result based on the +`out_ptr`. diff --git a/docs/source/design/mm_processing.md b/docs/design/mm_processing.md similarity index 61% rename from docs/source/design/mm_processing.md rename to docs/design/mm_processing.md index dc92a3c2c511..f3685ce76a4b 100644 --- a/docs/source/design/mm_processing.md +++ b/docs/design/mm_processing.md @@ -1,10 +1,11 @@ -(mm-processing)= +--- +title: Multi-Modal Data Processing +--- +[](){ #mm-processing } -# Multi-Modal Data Processing +To enable various optimizations in vLLM such as [chunked prefill][chunked-prefill] and [prefix caching][automatic-prefix-caching], we use [BaseMultiModalProcessor][vllm.multimodal.processing.BaseMultiModalProcessor] to provide the correspondence between placeholder feature tokens (e.g. `<image>`) and multi-modal inputs (e.g. the raw input image) based on the outputs of HF processor. -To enable various optimizations in vLLM such as [chunked prefill](#chunked-prefill) and [prefix caching](#automatic-prefix-caching), we use {class}`~vllm.multimodal.processing.BaseMultiModalProcessor` to provide the correspondence between placeholder feature tokens (e.g. `<image>`) and multi-modal inputs (e.g. the raw input image) based on the outputs of HF processor.
- -Here are the main features of {class}`~vllm.multimodal.processing.BaseMultiModalProcessor`: +Here are the main features of [BaseMultiModalProcessor][vllm.multimodal.processing.BaseMultiModalProcessor]: ## Prompt Update Detection @@ -15,7 +16,7 @@ One of the main responsibilities of HF processor is to update the prompt with pl The information about which tokens have been updated is key to finding the correspondence between placeholder feature tokens and multi-modal inputs. -In vLLM, this information is specified using {class}`~vllm.multimodal.processing.PromptUpdate` in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates`. We can automatically detect whether HF has updated the prompt by checking the existence of the updated tokens. +In vLLM, this information is specified using [PromptUpdate][vllm.multimodal.processing.PromptUpdate] in [_get_prompt_updates][vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates]. We can automatically detect whether HF has updated the prompt by checking the existence of the updated tokens. ## Tokenized Prompt Inputs @@ -43,22 +44,22 @@ While HF processors support text + multi-modal inputs natively, this is not so f Moreover, since the tokenized text has not passed through the HF processor, we have to apply Step 3 by ourselves to keep the output tokens and multi-modal data consistent with each other. -(mm-dummy-text)= +[](){ #mm-dummy-text } ### Dummy text -We work around the first issue by requiring each model to define how to generate dummy text based on the number of multi-modal inputs, via {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_text`. This lets us generate dummy text corresponding to the multi-modal inputs and input them together to obtain the processed multi-modal data. 
+We work around the first issue by requiring each model to define how to generate dummy text based on the number of multi-modal inputs, via [get_dummy_text][vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_text]. This lets us generate dummy text corresponding to the multi-modal inputs and input them together to obtain the processed multi-modal data. -(mm-automatic-prompt-updating)= +[](){ #mm-automatic-prompt-updating } ### Automatic prompt updating We address the second issue by implementing model-agnostic code in -{meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_prompt_updates` to automatically update the prompt with feature placeholder tokens based on the specification outputted by {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates`. +[_apply_prompt_updates][vllm.multimodal.processing.BaseMultiModalProcessor._apply_prompt_updates] to automatically update the prompt with feature placeholder tokens based on the specification outputted by [_get_prompt_updates][vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates]. ### Summary -With the help of dummy text and automatic prompt updating, our multi-modal processor can finally accept both text and token prompts with multi-modal data. The detailed logic is shown in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_main`. +With the help of dummy text and automatic prompt updating, our multi-modal processor can finally accept both text and token prompts with multi-modal data. The detailed logic is shown in [_apply_hf_processor_main][vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_main]. ## Processor Output Caching @@ -66,4 +67,4 @@ Some HF processors, such as the one for Qwen2-VL, are [very slow](gh-issue:9238) When new data is passed in, we first check which items are in the cache, and which ones are missing. 
The missing items are passed into the HF processor in a single batch and cached, before being merged with the existing items in the cache. -Since we only process the missing multi-modal data items, the number of input placeholder tokens no longer corresponds to the number of the multi-modal inputs, so they can't be passed alongside the text prompt to HF processor. Therefore, we process the text and multi-modal inputs separately, using [dummy text](#mm-dummy-text) to avoid HF errors. Since this skips HF's prompt updating code, we apply [automatic prompt updating](#mm-automatic-prompt-updating) afterwards to keep the output tokens and multi-modal data consistent with each other. +Since we only process the missing multi-modal data items, the number of input placeholder tokens no longer corresponds to the number of the multi-modal inputs, so they can't be passed alongside the text prompt to HF processor. Therefore, we process the text and multi-modal inputs separately, using [dummy text][mm-dummy-text] to avoid HF errors. Since this skips HF's prompt updating code, we apply [automatic prompt updating][mm-automatic-prompt-updating] afterwards to keep the output tokens and multi-modal data consistent with each other. diff --git a/docs/source/design/multiprocessing.md b/docs/design/multiprocessing.md similarity index 96% rename from docs/source/design/multiprocessing.md rename to docs/design/multiprocessing.md index 43fe5fe2e5e9..412c42fd580e 100644 --- a/docs/source/design/multiprocessing.md +++ b/docs/design/multiprocessing.md @@ -2,14 +2,13 @@ ## Debugging -Please see the [Troubleshooting](#troubleshooting-python-multiprocessing) +Please see the [Troubleshooting][troubleshooting-python-multiprocessing] page for information on known issues and how to solve them. ## Introduction -:::{important} -The source code references are to the state of the code at the time of writing in December, 2024. -::: +!!! 
warning + The source code references are to the state of the code at the time of writing in December, 2024. The use of Python multiprocessing in vLLM is complicated by: @@ -124,7 +123,7 @@ what is happening. First, a log message from vLLM: WARNING 12-11 14:50:37 multiproc_worker_utils.py:281] CUDA was previously initialized. We must use the `spawn` multiprocessing start method. Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. See - https://docs.vllm.ai/en/latest/getting_started/debugging.html#python-multiprocessing + https://docs.vllm.ai/en/latest/usage/debugging.html#python-multiprocessing for more information. ``` diff --git a/docs/source/design/plugin_system.md b/docs/design/plugin_system.md similarity index 83% rename from docs/source/design/plugin_system.md rename to docs/design/plugin_system.md index 225030885f62..0764dfb6501b 100644 --- a/docs/source/design/plugin_system.md +++ b/docs/design/plugin_system.md @@ -1,12 +1,13 @@ -(plugin-system)= - -# vLLM's Plugin System +--- +title: vLLM's Plugin System +--- +[](){ #plugin-system } The community frequently requests the ability to extend vLLM with custom features. To facilitate this, vLLM includes a plugin system that allows users to add custom features without modifying the vLLM codebase. This document explains how plugins work in vLLM and how to create a plugin for vLLM. ## How Plugins Work in vLLM -Plugins are user-registered code that vLLM executes. Given vLLM's architecture (see [](#arch-overview)), multiple processes may be involved, especially when using distributed inference with various parallelism techniques. To enable plugins successfully, every process created by vLLM needs to load the plugin. This is done by the [load_general_plugins](https://github.com/vllm-project/vllm/blob/c76ac49d266e27aa3fea84ef2df1f813d24c91c7/vllm/plugins/__init__.py#L16) function in the `vllm.plugins` module. This function is called for every process created by vLLM before it starts any work. 
+Plugins are user-registered code that vLLM executes. Given vLLM's architecture (see [Arch Overview][arch-overview]), multiple processes may be involved, especially when using distributed inference with various parallelism techniques. To enable plugins successfully, every process created by vLLM needs to load the plugin. This is done by the [load_general_plugins](https://github.com/vllm-project/vllm/blob/c76ac49d266e27aa3fea84ef2df1f813d24c91c7/vllm/plugins/__init__.py#L16) function in the `vllm.plugins` module. This function is called for every process created by vLLM before it starts any work. ## How vLLM Discovers Plugins @@ -29,8 +30,10 @@ def register(): from vllm import ModelRegistry if "MyLlava" not in ModelRegistry.get_supported_archs(): - ModelRegistry.register_model("MyLlava", - "vllm_add_dummy_model.my_llava:MyLlava") + ModelRegistry.register_model( + "MyLlava", + "vllm_add_dummy_model.my_llava:MyLlava", + ) ``` For more information on adding entry points to your package, please check the [official documentation](https://setuptools.pypa.io/en/latest/userguide/entry_point.html). diff --git a/docs/source/design/v1/metrics.md b/docs/design/v1/metrics.md similarity index 97% rename from docs/source/design/v1/metrics.md rename to docs/design/v1/metrics.md index 7e7c8b925e21..7156ee9dd3ec 100644 --- a/docs/source/design/v1/metrics.md +++ b/docs/design/v1/metrics.md @@ -57,11 +57,11 @@ In v0, the following metrics are exposed via a Prometheus-compatible `/metrics` - `vllm:spec_decode_num_draft_tokens_total` (Counter) - `vllm:spec_decode_num_emitted_tokens_total` (Counter) -These are documented under [Inferencing and Serving -> Production Metrics](project:../../serving/metrics.md). +These are documented under [Inferencing and Serving -> Production Metrics](../../usage/metrics.md). 
### Grafana Dashboard -vLLM also provides [a reference example](https://docs.vllm.ai/en/latest/getting_started/examples/prometheus_grafana.html) for how to collect and store these metrics using Prometheus and visualize them using a Grafana dashboard. +vLLM also provides [a reference example](https://docs.vllm.ai/en/latest/examples/prometheus_grafana.html) for how to collect and store these metrics using Prometheus and visualize them using a Grafana dashboard. The subset of metrics exposed in the Grafana dashboard gives us an indication of which metrics are especially important: @@ -222,9 +222,7 @@ And the calculated intervals are: Put another way: -:::{image} /assets/design/v1/metrics/intervals-1.png -:alt: Interval calculations - common case -::: +![Interval calculations - common case](../../assets/design/v1/metrics/intervals-1.png) We explored the possibility of having the frontend calculate these intervals using the timing of events visible by the frontend. However, @@ -239,17 +237,13 @@ When a preemption occurs during decode, since any already generated tokens are reused, we consider the preemption as affecting the inter-token, decode, and inference intervals. -:::{image} /assets/design/v1/metrics/intervals-2.png -:alt: Interval calculations - preempted decode -::: +![Interval calculations - preempted decode](../../assets/design/v1/metrics/intervals-2.png) When a preemption occurs during prefill (assuming such an event is possible), we consider the preemption as affecting the time-to-first-token and prefill intervals. -:::{image} /assets/design/v1/metrics/intervals-3.png -:alt: Interval calculations - preempted prefill -::: +![Interval calculations - preempted prefill](../../assets/design/v1/metrics/intervals-3.png) ### Frontend Stats Collection @@ -415,8 +409,8 @@ The discussion in about adding prefix cache metrics yielded some interesting points which may be relevant to how we approach future metrics. 
-Every time the prefix cache is queried, we record the number of blocks -queried and the number of queried blocks present in the cache +Every time the prefix cache is queried, we record the number of tokens +queried and the number of queried tokens present in the cache (i.e. hits). However, the metric of interest is the hit rate - i.e. the number of @@ -467,7 +461,7 @@ In general: hatch](https://kubernetes.io/docs/concepts/cluster-administration/system-metrics/#show-hidden-metrics) for some time before deleting them. -See the [deprecation policy](project:../../contributing/deprecation_policy.md) for +See the [deprecation policy](../../contributing/deprecation_policy.md) for the project-wide deprecation policy. ### Unimplemented - `vllm:tokens_total` @@ -679,7 +673,7 @@ v0 has support for OpenTelemetry tracing: - [OpenTelemetry blog post](https://opentelemetry.io/blog/2024/llm-observability/) - [User-facing - docs](https://docs.vllm.ai/en/latest/getting_started/examples/opentelemetry.html) + docs](https://docs.vllm.ai/en/latest/examples/opentelemetry.html) - [Blog post](https://medium.com/@ronen.schaffer/follow-the-trail-supercharging-vllm-with-opentelemetry-distributed-tracing-aa655229b46f) - [IBM product diff --git a/docs/source/design/v1/prefix_caching.md b/docs/design/v1/prefix_caching.md similarity index 94% rename from docs/source/design/v1/prefix_caching.md rename to docs/design/v1/prefix_caching.md index ec661d8ec641..ad041b0059f5 100644 --- a/docs/source/design/v1/prefix_caching.md +++ b/docs/design/v1/prefix_caching.md @@ -86,7 +86,7 @@ To improve privacy in shared environments, vLLM supports isolating prefix cache {"role": "user", "content": "Here is a document with details about the world series: ..."}, {"role": "user", "content": "Who won the world series in 2020?"} ], - "cache_salt": "Z3V2bmV3aGxza3ZubGFoZ3Zud3V3ZWZ2bmd0b3V2bnZmc2xpZ3RoZ2x2aQ==" + "cache_salt": "your-cache-salt" } ``` @@ -122,9 +122,7 @@ There are two design points to highlight: As a 
result, we will have the following components when the KV cache manager is initialized: -:::{image} /assets/design/v1/prefix_caching/overview.png -:alt: Component Overview -::: +![Component Overview](../../assets/design/v1/prefix_caching/overview.png) * Block Pool: A list of KVCacheBlock. * Free Block Queue: Only store the pointers of head and tail blocks for manipulations. @@ -194,9 +192,7 @@ As can be seen, block 3 is a new full block and is cached. However, it is redund When a request is finished, we free all its blocks if no other requests are using them (reference count = 0). In this example, we free request 1 and block 2, 3, 4, 8 associated with it. We can see that the freed blocks are added to the tail of the free queue in the *reverse* order. This is because the last block of a request must hash more tokens and is less likely to be reused by other requests. As a result, it should be evicted first. -:::{image} /assets/design/v1/prefix_caching/free.png -:alt: Free Queue after Free a Request -::: +![Free queue after a request is freed](../../assets/design/v1/prefix_caching/free.png) ### Eviction (LRU) @@ -212,36 +208,24 @@ In this example, we assume the block size is 4 (each block can cache 4 tokens), **Time 1: The cache is empty and a new request comes in.** We allocate 4 blocks. 3 of them are already full and cached. The fourth block is partially full with 3 of 4 tokens. -:::{image} /assets/design/v1/prefix_caching/example-time-1.png -:alt: Example Time 1 -::: +![Example Time 1](../../assets/design/v1/prefix_caching/example-time-1.png) **Time 3: Request 0 makes the block 3 full and asks for a new block to keep decoding.** We cache block 3 and allocate block 4. 
-:::{image} /assets/design/v1/prefix_caching/example-time-3.png -:alt: Example Time 3 -::: +![Example Time 3](../../assets/design/v1/prefix_caching/example-time-3.png) **Time 4: Request 1 comes in with the 14 prompt tokens, where the first 10 tokens are the same as request 0.** We can see that only the first 2 blocks (8 tokens) hit the cache, because the 3rd block only matches 2 of 4 tokens. -:::{image} /assets/design/v1/prefix_caching/example-time-4.png -:alt: Example Time 4 -::: +![Example Time 4](../../assets/design/v1/prefix_caching/example-time-4.png) **Time 5: Request 0 is finished and free.** Blocks 2, 3 and 4 are added to the free queue in the reverse order (but block 2 and 3 are still cached). Block 0 and 1 are not added to the free queue because they are being used by Request 1. -:::{image} /assets/design/v1/prefix_caching/example-time-5.png -:alt: Example Time 5 -::: +![Example Time 5](../../assets/design/v1/prefix_caching/example-time-5.png) **Time 6: Request 1 is finished and free.** -:::{image} /assets/design/v1/prefix_caching/example-time-6.png -:alt: Example Time 6 -::: +![Example Time 6](../../assets/design/v1/prefix_caching/example-time-6.png) **Time 7: Request 2 comes in with the 29 prompt tokens, where the first 12 tokens are the same as request 0\.** Note that even the block order in the free queue was `7 - 8 - 9 - 4 - 3 - 2 - 6 - 5 - 1 - 0`, the cache hit blocks (i.e., 0, 1, 2) are touched and removed from the queue before allocation, so the free queue becomes `7 - 8 - 9 - 4 - 3 - 6 - 5`. As a result, the allocated blocks are 0 (cached), 1 (cached), 2 (cached), 7, 8, 9, 4, 3 (evicted). 
-:::{image} /assets/design/v1/prefix_caching/example-time-7.png -:alt: Example Time 7 -::: +![Example Time 7](../../assets/design/v1/prefix_caching/example-time-7.png) diff --git a/docs/source/design/v1/torch_compile.md b/docs/design/v1/torch_compile.md similarity index 98% rename from docs/source/design/v1/torch_compile.md rename to docs/design/v1/torch_compile.md index 4d8ce0fd9227..64b6f0cc0a9b 100644 --- a/docs/source/design/v1/torch_compile.md +++ b/docs/design/v1/torch_compile.md @@ -99,7 +99,9 @@ This time, Inductor compilation is completely bypassed, and we will load from di The above example just uses Inductor to compile for a general shape (i.e. symbolic shape). We can also use Inductor to compile for some of the specific shapes, for example: -`vllm serve meta-llama/Llama-3.2-1B --compilation_config "{'compile_sizes': [1, 2, 4, 8]}"` +``` +vllm serve meta-llama/Llama-3.2-1B --compilation_config '{"compile_sizes": [1, 2, 4, 8]}' +``` Then it will also compile a specific kernel just for batch size `1, 2, 4, 8`. At this time, all of the shapes in the computation graph are static and known, and we will turn on auto-tuning to tune for max performance. This can be slow when you run it for the first time, but the next time you run it, we can directly bypass the tuning and run the tuned kernel. @@ -134,12 +136,14 @@ The cudagraphs are captured and managed by the compiler backend, and replayed wh By default, vLLM will try to determine a set of sizes to capture cudagraph. You can also override it using the config `cudagraph_capture_sizes`: -`vllm serve meta-llama/Llama-3.2-1B --compilation-config "{'cudagraph_capture_sizes': [1, 2, 4, 8]}"` +``` +vllm serve meta-llama/Llama-3.2-1B --compilation-config '{"cudagraph_capture_sizes": [1, 2, 4, 8]}' +``` Then it will only capture cudagraph for the specified sizes. It can be useful to have fine-grained control over the cudagraph capture. 
### Full Cudagraph capture -It is possible to include attention as part of the cudagraph if using an attention backend that is cudagraph compatible. This can improve performance in some cases such as decode speed for smaller models. Enable this using `--compilation-config "{'full_cuda_graph': True}"` +It is possible to include attention as part of the cudagraph if using an attention backend that is cudagraph compatible. This can improve performance in some cases such as decode speed for smaller models. Enable this using `--compilation-config '{"full_cuda_graph": true}'`. Currently only FlashAttention 3 is compatible, and only when cascade attention is disabled. diff --git a/docs/features/automatic_prefix_caching.md b/docs/features/automatic_prefix_caching.md new file mode 100644 index 000000000000..5e92796ddda7 --- /dev/null +++ b/docs/features/automatic_prefix_caching.md @@ -0,0 +1,28 @@ +--- +title: Automatic Prefix Caching +--- +[](){ #automatic-prefix-caching } + +## Introduction + +Automatic Prefix Caching (APC in short) caches the KV cache of existing queries, so that a new query can directly reuse the KV cache if it shares the same prefix with one of the existing queries, allowing the new query to skip the computation of the shared part. + +!!! note + Technical details on how vLLM implements APC can be found [here][design-automatic-prefix-caching]. + +## Enabling APC in vLLM + +Set `enable_prefix_caching=True` in vLLM engine to enable APC. Here is an example: + + + +## Example workloads + +We describe two example workloads, where APC can provide huge performance benefit: + +- Long document query, where the user repeatedly queries the same long document (e.g. software manual or annual report) with different queries. In this case, instead of processing the long document again and again, APC allows vLLM to process this long document *only once*, and all future requests can avoid recomputing this long document by reusing its KV cache. 
This allows vLLM to serve future requests with much higher throughput and much lower latency. +- Multi-round conversation, where the user may chat with the application multiple times in the same chatting session. In this case, instead of processing the whole chatting history again and again, APC allows vLLM to reuse the processing results of the chat history across all future rounds of conversation, allowing vLLM to serve future requests with much higher throughput and much lower latency. + +## Limits + +APC in general does not reduce the performance of vLLM. With that being said, APC only reduces the time of processing the queries (the prefilling phase) and does not reduce the time of generating new tokens (the decoding phase). So APC does not bring performance gain when vLLM spends most of the time generating answers to the queries (e.g. when the length of the answer is long), or new queries do not share the same prefix with any of existing queries (so that the computation cannot be reused). diff --git a/docs/features/compatibility_matrix.md b/docs/features/compatibility_matrix.md new file mode 100644 index 000000000000..77ceea49f173 --- /dev/null +++ b/docs/features/compatibility_matrix.md @@ -0,0 +1,77 @@ +--- +title: Compatibility Matrix +--- +[](){ #compatibility-matrix } + +The tables below show mutually exclusive features and the support on some hardware. + +The symbols used have the following meanings: + +- ✅ = Full compatibility +- 🟠 = Partial compatibility +- ❌ = No compatibility + +!!! note + Check the ❌ or 🟠 with links to see tracking issue for unsupported feature/hardware combination. 
+ +## Feature x Feature + + + +| Feature | [CP][chunked-prefill] | [APC][automatic-prefix-caching] | [LoRA][lora-adapter] | prmpt adptr | [SD][spec-decode] | CUDA graph | pooling | enc-dec | logP | prmpt logP | async output | multi-step | mm | best-of | beam-search | +|-----------------------------------------------------------|-------------------------|-----------------------------------|------------------------|---------------------------------------------------|---------------------|--------------|-----------------------------------------------|-------------------------------------------------------|--------------------------------------|---------------------------------------------------|-------------------------------------------------------------|--------------------|---------------------------------------------|-----------|---------------| +| [CP][chunked-prefill] | ✅ | | | | | | | | | | | | | | | +| [APC][automatic-prefix-caching] | ✅ | ✅ | | | | | | | | | | | | | | +| [LoRA][lora-adapter] | ✅ | ✅ | ✅ | | | | | | | | | | | | | +| prmpt adptr | ✅ | ✅ | ✅ | ✅ | | | | | | | | | | | | +| [SD][spec-decode] | ✅ | ✅ | ❌ | ✅ | ✅ | | | | | | | | | | | +| CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | | | +| pooling | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | | | | | | | | | +| enc-dec | ❌ | [❌](gh-issue:7366) | ❌ | ❌ | [❌](gh-issue:7366) | ✅ | ✅ | ✅ | | | | | | | | +| logP | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | | | | | | +| prmpt logP | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | | | | | | +| async output | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | | | | | +| multi-step | ❌ | ✅ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | | | | +| mm | ✅ | [🟠](gh-pr:8348) | [🟠](gh-pr:4194) | ❔ | ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | | +| best-of | ✅ | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ 
| ✅ | ❔ | [❌](gh-issue:7968) | ✅ | ✅ | | +| beam-search | ✅ | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ❔ | ✅ | ✅ | + +[](){ #feature-x-hardware } + +## Feature x Hardware + +| Feature | Volta | Turing | Ampere | Ada | Hopper | CPU | AMD | +|-----------------------------------------------------------|--------------------|----------|----------|-------|----------|--------------------|-------| +| [CP][chunked-prefill] | [❌](gh-issue:2729) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [APC][automatic-prefix-caching] | [❌](gh-issue:3687) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [LoRA][lora-adapter] | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| prmpt adptr | ✅ | ✅ | ✅ | ✅ | ✅ | [❌](gh-issue:8475) | ✅ | +| [SD][spec-decode] | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | +| pooling | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | +| enc-dec | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | +| mm | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| logP | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| prmpt logP | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| async output | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | +| multi-step | ✅ | ✅ | ✅ | ✅ | ✅ | [❌](gh-issue:8477) | ✅ | +| best-of | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| beam-search | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | diff --git a/docs/source/features/disagg_prefill.md b/docs/features/disagg_prefill.md similarity index 87% rename from docs/source/features/disagg_prefill.md rename to docs/features/disagg_prefill.md index 2fa20140c086..54be05647d94 100644 --- a/docs/source/features/disagg_prefill.md +++ b/docs/features/disagg_prefill.md @@ -1,12 +1,12 @@ -(disagg-prefill)= - -# Disaggregated Prefilling (experimental) +--- +title: Disaggregated Prefilling (experimental) +--- +[](){ #disagg-prefill } This page introduces you the disaggregated prefilling feature in vLLM. 
-:::{note} -This feature is experimental and subject to change. -::: +!!! note + This feature is experimental and subject to change. ## Why disaggregated prefilling? @@ -15,9 +15,8 @@ Two main reasons: - **Tuning time-to-first-token (TTFT) and inter-token-latency (ITL) separately**. Disaggregated prefilling put prefill and decode phase of LLM inference inside different vLLM instances. This gives you the flexibility to assign different parallel strategies (e.g. `tp` and `pp`) to tune TTFT without affecting ITL, or to tune ITL without affecting TTFT. - **Controlling tail ITL**. Without disaggregated prefilling, vLLM may insert some prefill jobs during the decoding of one request. This results in higher tail latency. Disaggregated prefilling helps you solve this issue and control tail ITL. Chunked prefill with a proper chunk size also can achieve the same goal, but in practice it's hard to figure out the correct chunk size value. So disaggregated prefilling is a much more reliable way to control tail ITL. -:::{note} -Disaggregated prefill DOES NOT improve throughput. -::: +!!! note + Disaggregated prefill DOES NOT improve throughput. ## Usage example @@ -39,21 +38,16 @@ Key abstractions for disaggregated prefilling: - **LookupBuffer**: LookupBuffer provides two API: `insert` KV cache and `drop_select` KV cache. The semantics of `insert` and `drop_select` are similar to SQL, where `insert` inserts a KV cache into the buffer, and `drop_select` returns the KV cache that matches the given condition and drop it from the buffer. - **Pipe**: A single-direction FIFO pipe for tensor transmission. It supports `send_tensor` and `recv_tensor`. -:::{note} -`insert` is non-blocking operation but `drop_select` is blocking operation. -::: +!!! note + `insert` is non-blocking operation but `drop_select` is blocking operation. 
Here is a figure illustrating how the above 3 abstractions are organized: -:::{image} /assets/features/disagg_prefill/abstraction.jpg -:alt: Disaggregated prefilling abstractions -::: +![Disaggregated prefilling abstractions](../assets/features/disagg_prefill/abstraction.jpg) The workflow of disaggregated prefilling is as follows: -:::{image} /assets/features/disagg_prefill/overview.jpg -:alt: Disaggregated prefilling workflow -::: +![Disaggregated prefilling workflow](../assets/features/disagg_prefill/overview.jpg) The `buffer` corresponds to `insert` API in LookupBuffer, and the `drop_select` corresponds to `drop_select` API in LookupBuffer. diff --git a/docs/source/features/lora.md b/docs/features/lora.md similarity index 84% rename from docs/source/features/lora.md rename to docs/features/lora.md index b5b51095b3a7..642462f7c455 100644 --- a/docs/source/features/lora.md +++ b/docs/features/lora.md @@ -1,10 +1,11 @@ -(lora-adapter)= - -# LoRA Adapters +--- +title: LoRA Adapters +--- +[](){ #lora-adapter } This document shows you how to use [LoRA adapters](https://arxiv.org/abs/2106.09685) with vLLM on top of a base model. -LoRA adapters can be used with any vLLM model that implements {class}`~vllm.model_executor.models.interfaces.SupportsLoRA`. +LoRA adapters can be used with any vLLM model that implements [SupportsLoRA][vllm.model_executor.models.interfaces.SupportsLoRA]. Adapters can be efficiently served on a per request basis with minimal overhead. First we download the adapter(s) and save them locally with @@ -60,13 +61,12 @@ vllm serve meta-llama/Llama-2-7b-hf \ --lora-modules sql-lora=$HOME/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/ ``` -:::{note} -The commit ID `0dfa347e8877a4d4ed19ee56c140fa518470028c` may change over time. Please check the latest commit ID in your environment to ensure you are using the correct one. -::: +!!! 
note + The commit ID `0dfa347e8877a4d4ed19ee56c140fa518470028c` may change over time. Please check the latest commit ID in your environment to ensure you are using the correct one. The server entrypoint accepts all other LoRA configuration parameters (`max_loras`, `max_lora_rank`, `max_cpu_loras`, etc.), which will apply to all forthcoming requests. Upon querying the `/models` endpoint, we should see our LoRA along -with its base model: +with its base model (if `jq` is not installed, you can follow [this guide](https://jqlang.org/download/) to install it.): ```bash curl localhost:8000/v1/models | jq . @@ -134,7 +134,7 @@ curl -X POST http://localhost:8000/v1/load_lora_adapter \ }' ``` -Upon a successful request, the API will respond with a 200 OK status code. If an error occurs, such as if the adapter +Upon a successful request, the API will respond with a `200 OK` status code from `vllm serve`, and `curl` returns the response body: `Success: LoRA adapter 'sql_adapter' added successfully`. If an error occurs, such as if the adapter cannot be found or loaded, an appropriate error message will be returned. Unloading a LoRA Adapter: @@ -142,6 +142,8 @@ Unloading a LoRA Adapter: To unload a LoRA adapter that has been previously loaded, send a POST request to the `/v1/unload_lora_adapter` endpoint with the name or ID of the adapter to be unloaded. +Upon a successful request, the API responds with a `200 OK` status code from `vllm serve`, and `curl` returns the response body: `Success: LoRA adapter 'sql_adapter' removed successfully`. + Example request to unload a LoRA adapter: ```bash @@ -157,9 +159,12 @@ Alternatively, you can use the LoRAResolver plugin to dynamically load LoRA adap You can set up multiple LoRAResolver plugins if you want to load LoRA adapters from different sources. For example, you might have one resolver for local files and another for S3 storage. vLLM will load the first LoRA adapter that it finds. 
-You can either install existing plugins or implement your own. +You can either install existing plugins or implement your own. By default, vLLM comes with a [resolver plugin to load LoRA adapters from a local directory.](https://github.com/vllm-project/vllm/tree/main/vllm/plugins/lora_resolvers) +To enable this resolver, set `VLLM_ALLOW_RUNTIME_LORA_UPDATING` to True, set `VLLM_PLUGINS` to include `lora_filesystem_resolver`, and then set `VLLM_LORA_RESOLVER_CACHE_DIR` to a local directory. When vLLM receives a request using a LoRA adapter `foobar`, +it will first look in the local directory for a directory `foobar`, and attempt to load the contents of that directory as a LoRA adapter. If successful, the request will complete as normal and +that adapter will then be available for normal use on the server. -Steps to implement your own LoRAResolver plugin: +Alternatively, follow these example steps to implement your own plugin: 1. Implement the LoRAResolver interface. Example of a simple S3 LoRAResolver implementation: diff --git a/docs/source/serving/multimodal_inputs.md b/docs/features/multimodal_inputs.md similarity index 84% rename from docs/source/serving/multimodal_inputs.md rename to docs/features/multimodal_inputs.md index bb2997f008ed..19b668172902 100644 --- a/docs/source/serving/multimodal_inputs.md +++ b/docs/features/multimodal_inputs.md @@ -1,20 +1,20 @@ -(multimodal-inputs)= +--- +title: Multimodal Inputs +--- +[](){ #multimodal-inputs } -# Multimodal Inputs +This page teaches you how to pass multi-modal inputs to [multi-modal models][supported-mm-models] in vLLM. -This page teaches you how to pass multi-modal inputs to [multi-modal models](#supported-mm-models) in vLLM. - -:::{note} -We are actively iterating on multi-modal support. See [this RFC](gh-issue:4194) for upcoming changes, -and [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) if you have any feedback or feature requests. -::: +!!! 
note + We are actively iterating on multi-modal support. See [this RFC](gh-issue:4194) for upcoming changes, + and [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) if you have any feedback or feature requests. ## Offline Inference -To input multi-modal data, follow this schema in {class}`vllm.inputs.PromptType`: +To input multi-modal data, follow this schema in [vllm.inputs.PromptType][]: - `prompt`: The prompt should follow the format that is documented on HuggingFace. -- `multi_modal_data`: This is a dictionary that follows the schema defined in {class}`vllm.multimodal.inputs.MultiModalDataDict`. +- `multi_modal_data`: This is a dictionary that follows the schema defined in [vllm.multimodal.inputs.MultiModalDataDict][]. ### Image Inputs @@ -211,16 +211,15 @@ for o in outputs: Our OpenAI-compatible server accepts multi-modal data via the [Chat Completions API](https://platform.openai.com/docs/api-reference/chat). -:::{important} -A chat template is **required** to use Chat Completions API. -For HF format models, the default chat template is defined inside `chat_template.json` or `tokenizer_config.json`. +!!! warning + A chat template is **required** to use Chat Completions API. + For HF format models, the default chat template is defined inside `chat_template.json` or `tokenizer_config.json`. -If no default chat template is available, we will first look for a built-in fallback in . -If no fallback is available, an error is raised and you have to provide the chat template manually via the `--chat-template` argument. + If no default chat template is available, we will first look for a built-in fallback in . + If no fallback is available, an error is raised and you have to provide the chat template manually via the `--chat-template` argument. -For certain models, we provide alternative chat templates inside . -For example, VLM2Vec uses which is different from the default one for Phi-3-Vision. 
-::: + For certain models, we provide alternative chat templates inside . + For example, VLM2Vec uses which is different from the default one for Phi-3-Vision. ### Image Inputs @@ -284,25 +283,21 @@ print("Chat completion output:", chat_response.choices[0].message.content) Full example: -:::{tip} -Loading from local file paths is also supported on vLLM: You can specify the allowed local media path via `--allowed-local-media-path` when launching the API server/engine, -and pass the file path as `url` in the API request. -::: - -:::{tip} -There is no need to place image placeholders in the text content of the API request - they are already represented by the image content. -In fact, you can place image placeholders in the middle of the text by interleaving text and image content. -::: +!!! tip + Loading from local file paths is also supported on vLLM: You can specify the allowed local media path via `--allowed-local-media-path` when launching the API server/engine, + and pass the file path as `url` in the API request. -:::{note} -By default, the timeout for fetching images through HTTP URL is `5` seconds. -You can override this by setting the environment variable: +!!! tip + There is no need to place image placeholders in the text content of the API request - they are already represented by the image content. + In fact, you can place image placeholders in the middle of the text by interleaving text and image content. -```console -export VLLM_IMAGE_FETCH_TIMEOUT= -``` +!!! note + By default, the timeout for fetching images through HTTP URL is `5` seconds. + You can override this by setting the environment variable: -::: + ```console + export VLLM_IMAGE_FETCH_TIMEOUT= + ``` ### Video Inputs @@ -357,15 +352,13 @@ print("Chat completion output from image url:", result) Full example: -:::{note} -By default, the timeout for fetching videos through HTTP URL is `30` seconds. -You can override this by setting the environment variable: +!!! 
note + By default, the timeout for fetching videos through HTTP URL is `30` seconds. + You can override this by setting the environment variable: -```console -export VLLM_VIDEO_FETCH_TIMEOUT= -``` - -::: + ```console + export VLLM_VIDEO_FETCH_TIMEOUT= + ``` ### Audio Inputs @@ -461,15 +454,13 @@ print("Chat completion output from audio url:", result) Full example: -:::{note} -By default, the timeout for fetching audios through HTTP URL is `10` seconds. -You can override this by setting the environment variable: - -```console -export VLLM_AUDIO_FETCH_TIMEOUT= -``` +!!! note + By default, the timeout for fetching audios through HTTP URL is `10` seconds. + You can override this by setting the environment variable: -::: + ```console + export VLLM_AUDIO_FETCH_TIMEOUT= + ``` ### Embedding Inputs @@ -535,7 +526,6 @@ chat_completion = client.chat.completions.create( ) ``` -:::{note} -Only one message can contain `{"type": "image_embeds"}`. -If used with a model that requires additional parameters, you must also provide a tensor for each of them, e.g. `image_grid_thw`, `image_sizes`, etc. -::: +!!! note + Only one message can contain `{"type": "image_embeds"}`. + If used with a model that requires additional parameters, you must also provide a tensor for each of them, e.g. `image_grid_thw`, `image_sizes`, etc. diff --git a/docs/features/prompt_embeds.md b/docs/features/prompt_embeds.md new file mode 100644 index 000000000000..6f5616e05d8c --- /dev/null +++ b/docs/features/prompt_embeds.md @@ -0,0 +1,43 @@ +# Prompt Embedding Inputs + +This page teaches you how to pass prompt embedding inputs to vLLM. + +## What are prompt embeddings? + +The traditional flow of text data for a Large Language Model goes from text to token ids (via a tokenizer) then from token ids to prompt embeddings. 
For a traditional decoder-only model (such as meta-llama/Llama-3.1-8B-Instruct), this step of converting token ids to prompt embeddings happens via a look-up from a learned embedding matrix, but the model is not limited to processing only the embeddings corresponding to its token vocabulary. + +!!! note + Prompt embeddings are currently only supported in the v0 engine. + +## Offline Inference + +To input prompt embeddings, follow this schema in [vllm.inputs.EmbedsPrompt][]: + +- `prompt_embeds`: A torch tensor representing a sequence of prompt/token embeddings. This has the shape (sequence_length, hidden_size), where sequence_length is the number of token embeddings and hidden_size is the hidden size (embedding size) of the model. + +### Hugging Face Transformers Inputs + +You can pass prompt embeddings from Hugging Face Transformers models to the `'prompt_embeds'` field of the prompt embedding dictionary, as shown in the following examples: + + + +## Online Serving + +Our OpenAI-compatible server accepts prompt embeddings inputs via the [Completions API](https://platform.openai.com/docs/api-reference/completions). Prompt embeddings inputs are added via a new `'prompt_embeds'` key in the JSON payload. + +When a mixture of `'prompt_embeds'` and `'prompt'` inputs is provided in a single request, the prompt embeds are always returned first. + +Prompt embeddings are passed in as base64 encoded torch tensors. 
+ +### Transformers Inputs via OpenAI Client + +First, launch the OpenAI-compatible server: + +```bash +vllm serve meta-llama/Llama-3.2-1B-Instruct --task generate \ + --max-model-len 4096 --enable-prompt-embeds +``` + +Then, you can use the OpenAI client as follows: + + diff --git a/docs/features/quantization/README.md b/docs/features/quantization/README.md new file mode 100644 index 000000000000..71f62065f63d --- /dev/null +++ b/docs/features/quantization/README.md @@ -0,0 +1,22 @@ +--- +title: Quantization +--- +[](){ #quantization-index } + +Quantization trades off model precision for smaller memory footprint, allowing large models to be run on a wider range of devices. + +Contents: + +- [Supported_Hardware](supported_hardware.md) +- [Auto_Awq](auto_awq.md) +- [Bnb](bnb.md) +- [Bitblas](bitblas.md) +- [Gguf](gguf.md) +- [Gptqmodel](gptqmodel.md) +- [Int4](int4.md) +- [Int8](int8.md) +- [Fp8](fp8.md) +- [Modelopt](modelopt.md) +- [Quark](quark.md) +- [Quantized_Kvcache](quantized_kvcache.md) +- [Torchao](torchao.md) diff --git a/docs/source/features/quantization/auto_awq.md b/docs/features/quantization/auto_awq.md similarity index 93% rename from docs/source/features/quantization/auto_awq.md rename to docs/features/quantization/auto_awq.md index b4ac597f5a79..4366a080f52c 100644 --- a/docs/source/features/quantization/auto_awq.md +++ b/docs/features/quantization/auto_awq.md @@ -1,6 +1,7 @@ -(auto-awq)= - -# AutoAWQ +--- +title: AutoAWQ +--- +[](){ #auto-awq } To create a new 4-bit quantized model, you can leverage [AutoAWQ](https://github.com/casper-hansen/AutoAWQ). Quantization reduces the model's precision from BF16/FP16 to INT4 which effectively reduces the total model memory footprint. 
@@ -41,7 +42,9 @@ print(f'Model is quantized and saved at "{quant_path}"') To run an AWQ model with vLLM, you can use [TheBloke/Llama-2-7b-Chat-AWQ](https://huggingface.co/TheBloke/Llama-2-7b-Chat-AWQ) with the following command: ```console -python examples/offline_inference/llm_engine_example.py --model TheBloke/Llama-2-7b-Chat-AWQ --quantization awq +python examples/offline_inference/llm_engine_example.py \ + --model TheBloke/Llama-2-7b-Chat-AWQ \ + --quantization awq ``` AWQ models are also supported directly through the LLM entrypoint: diff --git a/docs/source/features/quantization/bitblas.md b/docs/features/quantization/bitblas.md similarity index 62% rename from docs/source/features/quantization/bitblas.md rename to docs/features/quantization/bitblas.md index d0b2bf858c9b..9001725d9c02 100644 --- a/docs/source/features/quantization/bitblas.md +++ b/docs/features/quantization/bitblas.md @@ -1,14 +1,14 @@ -(bitblas)= - -# BitBLAS +--- +title: BitBLAS +--- +[](){ #bitblas } vLLM now supports [BitBLAS](https://github.com/microsoft/BitBLAS) for more efficient and flexible model inference. Compared to other quantization frameworks, BitBLAS provides more precision combinations. -:::{note} -Ensure your hardware supports the selected `dtype` (`torch.bfloat16` or `torch.float16`). -Most recent NVIDIA GPUs support `float16`, while `bfloat16` is more common on newer architectures like Ampere or Hopper. -For details see [supported hardware](https://docs.vllm.ai/en/latest/features/quantization/supported_hardware.html). -::: +!!! note + Ensure your hardware supports the selected `dtype` (`torch.bfloat16` or `torch.float16`). + Most recent NVIDIA GPUs support `float16`, while `bfloat16` is more common on newer architectures like Ampere or Hopper. + For details see [supported hardware](https://docs.vllm.ai/en/latest/features/quantization/supported_hardware.html). Below are the steps to utilize BitBLAS with vLLM. 
@@ -33,7 +33,12 @@ import torch # "hxbgsyxh/llama-13b-4bit-g-1-bitblas" is a pre-quantized checkpoint. model_id = "hxbgsyxh/llama-13b-4bit-g-1-bitblas" -llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True, quantization="bitblas") +llm = LLM( + model=model_id, + dtype=torch.bfloat16, + trust_remote_code=True, + quantization="bitblas" +) ``` ## Read gptq format checkpoint @@ -44,5 +49,11 @@ import torch # "hxbgsyxh/llama-13b-4bit-g-1" is a pre-quantized checkpoint. model_id = "hxbgsyxh/llama-13b-4bit-g-1" -llm = LLM(model=model_id, dtype=torch.float16, trust_remote_code=True, quantization="bitblas", max_model_len=1024) +llm = LLM( + model=model_id, + dtype=torch.float16, + trust_remote_code=True, + quantization="bitblas", + max_model_len=1024 +) ``` diff --git a/docs/source/features/quantization/bnb.md b/docs/features/quantization/bnb.md similarity index 79% rename from docs/source/features/quantization/bnb.md rename to docs/features/quantization/bnb.md index 1843a33a3dfd..a8dc2476f30a 100644 --- a/docs/source/features/quantization/bnb.md +++ b/docs/features/quantization/bnb.md @@ -1,6 +1,7 @@ -(bits-and-bytes)= - -# BitsAndBytes +--- +title: BitsAndBytes +--- +[](){ #bits-and-bytes } vLLM now supports [BitsAndBytes](https://github.com/TimDettmers/bitsandbytes) for more efficient model inference. BitsAndBytes quantizes models to reduce memory usage and enhance performance without significantly sacrificing accuracy. @@ -14,7 +15,7 @@ pip install bitsandbytes>=0.45.3 vLLM reads the model's config file and supports both in-flight quantization and pre-quantized checkpoint. -You can find bitsandbytes quantized models on . +You can find bitsandbytes quantized models on [Hugging Face](https://huggingface.co/models?search=bitsandbytes). And usually, these repositories have a config.json file that includes a quantization_config section. 
## Read quantized checkpoint @@ -26,7 +27,11 @@ from vllm import LLM import torch # unsloth/tinyllama-bnb-4bit is a pre-quantized checkpoint. model_id = "unsloth/tinyllama-bnb-4bit" -llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True) +llm = LLM( + model=model_id, + dtype=torch.bfloat16, + trust_remote_code=True +) ``` ## Inflight quantization: load as 4bit quantization @@ -37,8 +42,12 @@ For inflight 4bit quantization with BitsAndBytes, you need to explicitly specify from vllm import LLM import torch model_id = "huggyllama/llama-7b" -llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True, \ -quantization="bitsandbytes") +llm = LLM( + model=model_id, + dtype=torch.bfloat16, + trust_remote_code=True, + quantization="bitsandbytes" +) ``` ## OpenAI Compatible Server diff --git a/docs/source/features/quantization/fp8.md b/docs/features/quantization/fp8.md similarity index 88% rename from docs/source/features/quantization/fp8.md rename to docs/features/quantization/fp8.md index cb304d54726c..01d5d9da046d 100644 --- a/docs/source/features/quantization/fp8.md +++ b/docs/features/quantization/fp8.md @@ -1,6 +1,7 @@ -(fp8)= - -# FP8 W8A8 +--- +title: FP8 W8A8 +--- +[](){ #fp8 } vLLM supports FP8 (8-bit floating point) weight and activation quantization using hardware acceleration on GPUs such as Nvidia H100 and AMD MI300x. Currently, only Hopper and Ada Lovelace GPUs are officially supported for W8A8. @@ -14,10 +15,9 @@ The FP8 types typically supported in hardware have two distinct representations, - **E4M3**: Consists of 1 sign bit, 4 exponent bits, and 3 bits of mantissa. It can store values up to +/-448 and `nan`. - **E5M2**: Consists of 1 sign bit, 5 exponent bits, and 2 bits of mantissa. It can store values up to +/-57344, +/- `inf`, and `nan`. The tradeoff for the increased dynamic range is lower precision of the stored values. 
-:::{note} -FP8 computation is supported on NVIDIA GPUs with compute capability > 8.9 (Ada Lovelace, Hopper). -FP8 models will run on compute capability > 8.0 (Ampere) as weight-only W8A16, utilizing FP8 Marlin. -::: +!!! note + FP8 computation is supported on NVIDIA GPUs with compute capability > 8.9 (Ada Lovelace, Hopper). + FP8 models will run on compute capability > 8.0 (Ampere) as weight-only W8A16, utilizing FP8 Marlin. ## Installation @@ -94,9 +94,8 @@ print(result[0].outputs[0].text) Evaluate accuracy with `lm_eval` (for example on 250 samples of `gsm8k`): -:::{note} -Quantized models can be sensitive to the presence of the `bos` token. `lm_eval` does not add a `bos` token by default, so make sure to include the `add_bos_token=True` argument when running your evaluations. -::: +!!! note + Quantized models can be sensitive to the presence of the `bos` token. `lm_eval` does not add a `bos` token by default, so make sure to include the `add_bos_token=True` argument when running your evaluations. ```console $ MODEL=$PWD/Meta-Llama-3-8B-Instruct-FP8-Dynamic @@ -133,6 +132,5 @@ result = model.generate("Hello, my name is") print(result[0].outputs[0].text) ``` -:::{warning} -Currently, we load the model at original precision before quantizing down to 8-bits, so you need enough memory to load the whole model. -::: +!!! warning + Currently, we load the model at original precision before quantizing down to 8-bits, so you need enough memory to load the whole model. diff --git a/docs/source/features/quantization/gguf.md b/docs/features/quantization/gguf.md similarity index 64% rename from docs/source/features/quantization/gguf.md rename to docs/features/quantization/gguf.md index e93e4dcd3b57..72f758f653a8 100644 --- a/docs/source/features/quantization/gguf.md +++ b/docs/features/quantization/gguf.md @@ -1,39 +1,42 @@ -(gguf)= +--- +title: GGUF +--- +[](){ #gguf } -# GGUF +!!! 
warning + Please note that GGUF support in vLLM is highly experimental and under-optimized at the moment, it might be incompatible with other features. Currently, you can use GGUF as a way to reduce memory footprint. If you encounter any issues, please report them to the vLLM team. -:::{warning} -Please note that GGUF support in vLLM is highly experimental and under-optimized at the moment, it might be incompatible with other features. Currently, you can use GGUF as a way to reduce memory footprint. If you encounter any issues, please report them to the vLLM team. -::: - -:::{warning} -Currently, vllm only supports loading single-file GGUF models. If you have a multi-files GGUF model, you can use [gguf-split](https://github.com/ggerganov/llama.cpp/pull/6135) tool to merge them to a single-file model. -::: +!!! warning + Currently, vllm only supports loading single-file GGUF models. If you have a multi-files GGUF model, you can use [gguf-split](https://github.com/ggerganov/llama.cpp/pull/6135) tool to merge them to a single-file model. To run a GGUF model with vLLM, you can download and use the local GGUF model from [TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF](https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF) with the following command: ```console wget https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf # We recommend using the tokenizer from base model to avoid long-time and buggy tokenizer conversion. -vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf --tokenizer TinyLlama/TinyLlama-1.1B-Chat-v1.0 +vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf \ + --tokenizer TinyLlama/TinyLlama-1.1B-Chat-v1.0 ``` You can also add `--tensor-parallel-size 2` to enable tensor parallelism inference with 2 GPUs: ```console # We recommend using the tokenizer from base model to avoid long-time and buggy tokenizer conversion. 
-vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf --tokenizer TinyLlama/TinyLlama-1.1B-Chat-v1.0 --tensor-parallel-size 2 +vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf \ + --tokenizer TinyLlama/TinyLlama-1.1B-Chat-v1.0 \ + --tensor-parallel-size 2 ``` -:::{warning} -We recommend using the tokenizer from base model instead of GGUF model. Because the tokenizer conversion from GGUF is time-consuming and unstable, especially for some models with large vocab size. -::: +!!! warning + We recommend using the tokenizer from base model instead of GGUF model. Because the tokenizer conversion from GGUF is time-consuming and unstable, especially for some models with large vocab size. GGUF assumes that huggingface can convert the metadata to a config file. In case huggingface doesn't support your model you can manually create a config and pass it as hf-config-path ```console # If you model is not supported by huggingface you can manually provide a huggingface compatible config path -vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf --tokenizer TinyLlama/TinyLlama-1.1B-Chat-v1.0 --hf-config-path Tinyllama/TInyLlama-1.1B-Chat-v1.0 +vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf \ + --tokenizer TinyLlama/TinyLlama-1.1B-Chat-v1.0 \ + --hf-config-path Tinyllama/TInyLlama-1.1B-Chat-v1.0 ``` You can also use the GGUF model directly through the LLM entrypoint: diff --git a/docs/source/features/quantization/gptqmodel.md b/docs/features/quantization/gptqmodel.md similarity index 95% rename from docs/source/features/quantization/gptqmodel.md rename to docs/features/quantization/gptqmodel.md index 9771d5a4fe9e..53e938d2cbd7 100644 --- a/docs/source/features/quantization/gptqmodel.md +++ b/docs/features/quantization/gptqmodel.md @@ -1,6 +1,7 @@ -(gptqmodel)= - -# GPTQModel +--- +title: GPTQModel +--- +[](){ #gptqmodel } To create a new 4-bit or 8-bit GPTQ quantized model, you can leverage [GPTQModel](https://github.com/ModelCloud/GPTQModel) from ModelCloud.AI. 
@@ -58,7 +59,8 @@ model.save(quant_path) To run an GPTQModel quantized model with vLLM, you can use [DeepSeek-R1-Distill-Qwen-7B-gptqmodel-4bit-vortex-v2](https://huggingface.co/ModelCloud/DeepSeek-R1-Distill-Qwen-7B-gptqmodel-4bit-vortex-v2) with the following command: ```console -python examples/offline_inference/llm_engine_example.py --model ModelCloud/DeepSeek-R1-Distill-Qwen-7B-gptqmodel-4bit-vortex-v2 +python examples/offline_inference/llm_engine_example.py \ + --model ModelCloud/DeepSeek-R1-Distill-Qwen-7B-gptqmodel-4bit-vortex-v2 ``` ## Using GPTQModel with vLLM's Python API diff --git a/docs/source/features/quantization/int4.md b/docs/features/quantization/int4.md similarity index 94% rename from docs/source/features/quantization/int4.md rename to docs/features/quantization/int4.md index 7a0ab4ad229e..b7d09206365f 100644 --- a/docs/source/features/quantization/int4.md +++ b/docs/features/quantization/int4.md @@ -1,14 +1,14 @@ -(int4)= - -# INT4 W4A16 +--- +title: INT4 W4A16 +--- +[](){ #int4 } vLLM supports quantizing weights to INT4 for memory savings and inference acceleration. This quantization method is particularly useful for reducing model size and maintaining low latency in workloads with low queries per second (QPS). Please visit the HF collection of [quantized INT4 checkpoints of popular LLMs ready to use with vLLM](https://huggingface.co/collections/neuralmagic/int4-llms-for-vllm-668ec34bf3c9fa45f857df2c). -:::{note} -INT4 computation is supported on NVIDIA GPUs with compute capability > 8.0 (Ampere, Ada Lovelace, Hopper, Blackwell). -::: +!!! note + INT4 computation is supported on NVIDIA GPUs with compute capability > 8.0 (Ampere, Ada Lovelace, Hopper, Blackwell). ## Prerequisites @@ -121,9 +121,8 @@ $ lm_eval --model vllm \ --batch_size 'auto' ``` -:::{note} -Quantized models can be sensitive to the presence of the `bos` token. Make sure to include the `add_bos_token=True` argument when running evaluations. -::: +!!! 
note + Quantized models can be sensitive to the presence of the `bos` token. Make sure to include the `add_bos_token=True` argument when running evaluations. ## Best Practices diff --git a/docs/source/features/quantization/int8.md b/docs/features/quantization/int8.md similarity index 92% rename from docs/source/features/quantization/int8.md rename to docs/features/quantization/int8.md index 1e4b01d35575..1d9fba9dc87f 100644 --- a/docs/source/features/quantization/int8.md +++ b/docs/features/quantization/int8.md @@ -1,15 +1,15 @@ -(int8)= - -# INT8 W8A8 +--- +title: INT8 W8A8 +--- +[](){ #int8 } vLLM supports quantizing weights and activations to INT8 for memory savings and inference acceleration. This quantization method is particularly useful for reducing model size while maintaining good performance. Please visit the HF collection of [quantized INT8 checkpoints of popular LLMs ready to use with vLLM](https://huggingface.co/collections/neuralmagic/int8-llms-for-vllm-668ec32c049dca0369816415). -:::{note} -INT8 computation is supported on NVIDIA GPUs with compute capability > 7.5 (Turing, Ampere, Ada Lovelace, Hopper, Blackwell). -::: +!!! note + INT8 computation is supported on NVIDIA GPUs with compute capability > 7.5 (Turing, Ampere, Ada Lovelace, Hopper, Blackwell). ## Prerequisites @@ -125,9 +125,8 @@ $ lm_eval --model vllm \ --batch_size 'auto' ``` -:::{note} -Quantized models can be sensitive to the presence of the `bos` token. Make sure to include the `add_bos_token=True` argument when running evaluations. -::: +!!! note + Quantized models can be sensitive to the presence of the `bos` token. Make sure to include the `add_bos_token=True` argument when running evaluations. 
## Best Practices diff --git a/docs/source/features/quantization/modelopt.md b/docs/features/quantization/modelopt.md similarity index 100% rename from docs/source/features/quantization/modelopt.md rename to docs/features/quantization/modelopt.md diff --git a/docs/source/features/quantization/quantized_kvcache.md b/docs/features/quantization/quantized_kvcache.md similarity index 98% rename from docs/source/features/quantization/quantized_kvcache.md rename to docs/features/quantization/quantized_kvcache.md index 86e6354ec82e..e3ebd024bab3 100644 --- a/docs/source/features/quantization/quantized_kvcache.md +++ b/docs/features/quantization/quantized_kvcache.md @@ -1,6 +1,7 @@ -(quantized-kvcache)= - -# Quantized KV Cache +--- +title: Quantized KV Cache +--- +[](){ #quantized-kvcache } ## FP8 KV Cache diff --git a/docs/source/features/quantization/quark.md b/docs/features/quantization/quark.md similarity index 94% rename from docs/source/features/quantization/quark.md rename to docs/features/quantization/quark.md index 955890dbc75b..51da98cc09d3 100644 --- a/docs/source/features/quantization/quark.md +++ b/docs/features/quantization/quark.md @@ -1,6 +1,7 @@ -(quark)= - -# AMD QUARK +--- +title: AMD QUARK +--- +[](){ #quark } Quantization can effectively reduce memory and bandwidth usage, accelerate computation and improve throughput while with minimal accuracy loss. vLLM can leverage [Quark](https://quark.docs.amd.com/latest/), @@ -86,13 +87,12 @@ We need to set the quantization configuration, you can check for further details. Here we use FP8 per-tensor quantization on weight, activation, kv-cache and the quantization algorithm is AutoSmoothQuant. -:::{note} -Note the quantization algorithm needs a JSON config file and the config file is located in -[Quark Pytorch examples](https://quark.docs.amd.com/latest/pytorch/pytorch_examples.html), -under the directory `examples/torch/language_modeling/llm_ptq/models`. 
For example, -AutoSmoothQuant config file for Llama is -`examples/torch/language_modeling/llm_ptq/models/llama/autosmoothquant_config.json`. -::: +!!! note + Note the quantization algorithm needs a JSON config file and the config file is located in + [Quark Pytorch examples](https://quark.docs.amd.com/latest/pytorch/pytorch_examples.html), + under the directory `examples/torch/language_modeling/llm_ptq/models`. For example, + AutoSmoothQuant config file for Llama is + `examples/torch/language_modeling/llm_ptq/models/llama/autosmoothquant_config.json`. ```python from quark.torch.quantization import (Config, QuantizationConfig, diff --git a/docs/features/quantization/supported_hardware.md b/docs/features/quantization/supported_hardware.md new file mode 100644 index 000000000000..2967bf9c7504 --- /dev/null +++ b/docs/features/quantization/supported_hardware.md @@ -0,0 +1,28 @@ +--- +title: Supported Hardware +--- +[](){ #quantization-supported-hardware } + +The table below shows the compatibility of various quantization implementations with different hardware platforms in vLLM: + +| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | x86 CPU | AWS Inferentia | Google TPU | +|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-----------|------------------|--------------| +| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | ❌ | ❌ | +| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | ❌ | ❌ | +| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | +| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ✅︎ | ❌ | ✅︎ | +| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | +| BitBLAS (GPTQ) | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | +| AQLM | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ 
| +| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | +| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | +| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | + +- Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0. +- ✅︎ indicates that the quantization method is supported on the specified hardware. +- ❌ indicates that the quantization method is not supported on the specified hardware. + +!!! note + This compatibility chart is subject to change as vLLM continues to evolve and expand its support for different hardware platforms and quantization methods. + + For the most up-to-date information on hardware support and quantization methods, please refer to or consult with the vLLM development team. diff --git a/docs/source/features/quantization/torchao.md b/docs/features/quantization/torchao.md similarity index 86% rename from docs/source/features/quantization/torchao.md rename to docs/features/quantization/torchao.md index 82100c6ddcac..a7a517af85aa 100644 --- a/docs/source/features/quantization/torchao.md +++ b/docs/features/quantization/torchao.md @@ -7,7 +7,9 @@ We recommend installing the latest torchao nightly with ```console # Install the latest TorchAO nightly build # Choose the CUDA version that matches your system (cu126, cu128, etc.) 
-pip install --pre torchao>=10.0.0 --index-url https://download.pytorch.org/whl/nightly/cu126 +pip install \ + --pre torchao>=10.0.0 \ + --index-url https://download.pytorch.org/whl/nightly/cu126 ``` ## Quantizing HuggingFace Models @@ -20,7 +22,12 @@ from torchao.quantization import Int8WeightOnlyConfig model_name = "meta-llama/Meta-Llama-3-8B" quantization_config = TorchAoConfig(Int8WeightOnlyConfig()) -quantized_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto", quantization_config=quantization_config) +quantized_model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype="auto", + device_map="auto", + quantization_config=quantization_config +) tokenizer = AutoTokenizer.from_pretrained(model_name) input_text = "What are we having for dinner?" input_ids = tokenizer(input_text, return_tensors="pt").to("cuda") diff --git a/docs/source/features/reasoning_outputs.md b/docs/features/reasoning_outputs.md similarity index 87% rename from docs/source/features/reasoning_outputs.md rename to docs/features/reasoning_outputs.md index a079eb8b77e7..cbcb246912f4 100644 --- a/docs/source/features/reasoning_outputs.md +++ b/docs/features/reasoning_outputs.md @@ -1,6 +1,7 @@ -(reasoning-outputs)= - -# Reasoning Outputs +--- +title: Reasoning Outputs +--- +[](){ #reasoning-outputs } vLLM offers support for reasoning models like [DeepSeek R1](https://huggingface.co/deepseek-ai/DeepSeek-R1), which are designed to generate outputs containing both reasoning steps and final conclusions. 
@@ -17,14 +18,17 @@ vLLM currently supports the following reasoning models: | [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` | ❌ | ❌ | | [Qwen3 series](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `qwen3` | `guided_json`, `guided_regex` | ✅ | -- IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`. +!!! note + IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`. + The reasoning feature for the Qwen3 series is enabled by default. To disable it, you must pass `enable_thinking=False` in your `chat_template_kwargs`. ## Quickstart To use reasoning models, you need to specify the `--reasoning-parser` flags when making a request to the chat completion endpoint. The `--reasoning-parser` flag specifies the reasoning parser to use for extracting reasoning content from the model output. ```bash -vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --reasoning-parser deepseek_r1 +vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \ + --reasoning-parser deepseek_r1 ``` Next, make a request to the model that should return the reasoning content in the response. @@ -47,6 +51,8 @@ model = models.data[0].id # Round 1 messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}] # For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}` +# For Qwen3 series, if you want to disable thinking in reasoning mode, add: +# extra_body={"chat_template_kwargs": {"enable_thinking": False}} response = client.chat.completions.create(model=model, messages=messages) reasoning_content = response.choices[0].message.reasoning_content @@ -83,7 +89,7 @@ Streaming chat completions are also supported for reasoning models.
The `reasoni } ``` -OpenAI Python client library does not officially support `reasoning_content` attribute for streaming output. But the client support extra attributes in the response. You can use `hasattr` to check if the `reasoning_content` attribute is present in the response. For example: +OpenAI Python client library does not officially support `reasoning_content` attribute for streaming output. But the client supports extra attributes in the response. You can use `hasattr` to check if the `reasoning_content` attribute is present in the response. For example: ```python from openai import OpenAI @@ -102,6 +108,8 @@ model = models.data[0].id messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}] # For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}` +# For Qwen3 series, if you want to disable thinking in reasoning mode, add: +# extra_body={"chat_template_kwargs": {"enable_thinking": False}} stream = client.chat.completions.create(model=model, messages=messages, stream=True) @@ -139,10 +147,10 @@ Remember to check whether the `reasoning_content` exists in the response before The reasoning content is also available in the structured output. The structured output engine like `xgrammar` will use the reasoning content to generate structured output. It is only supported in v0 engine now. ```bash -VLLM_USE_V1=0 vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --reasoning-parser deepseek_r1 +vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --reasoning-parser deepseek_r1 ``` -Please note that the `VLLM_USE_V1` environment variable must be set to `0` to use the v0 engine. 
+The following is an example client: ```python from openai import OpenAI @@ -160,12 +168,10 @@ client = OpenAI( models = client.models.list() model = models.data[0].id - class People(BaseModel): name: str age: int - json_schema = People.model_json_schema() prompt = ("Generate a JSON with the name and age of one random person.") @@ -221,7 +227,7 @@ print(f"Function called: {tool_call.name}") print(f"Arguments: {tool_call.arguments}") ``` -For more examples, please refer to . +For more examples, please refer to . ## Limitations @@ -229,13 +235,12 @@ For more examples, please refer to . ```python # import the required packages -from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import ( - ReasoningParser, ReasoningParserManager) +from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaMessage) @@ -286,7 +291,7 @@ class ExampleParser(ReasoningParser): """ ``` -Additionally, to enable structured output, you'll need to create a new `Reasoner` similar to the one in `vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py`. +Additionally, to enable structured output, you'll need to create a new `Reasoner` similar to the one in . ```python @dataclass @@ -312,7 +317,7 @@ class DeepSeekReasoner(Reasoner): ... ``` -The structured output engine like `xgrammar` will use `end_token_id` to check if the reasoning content is present in the model output and skip the structured output if it is the case. +The structured output engine like [xgrammar](https://github.com/mlc-ai/xgrammar) will use `end_token_id` to check if the reasoning content is present in the model output and skip the structured output if it is the case. Finally, you can enable reasoning for the model by using the `--reasoning-parser` flags. 
diff --git a/docs/source/features/spec_decode.md b/docs/features/spec_decode.md similarity index 91% rename from docs/source/features/spec_decode.md rename to docs/features/spec_decode.md index f16e0d96522d..5080960f72dd 100644 --- a/docs/source/features/spec_decode.md +++ b/docs/features/spec_decode.md @@ -1,16 +1,15 @@ -(spec-decode)= +--- +title: Speculative Decoding +--- +[](){ #spec-decode } -# Speculative Decoding +!!! warning + Please note that speculative decoding in vLLM is not yet optimized and does + not usually yield inter-token latency reductions for all prompt datasets or sampling parameters. + The work to optimize it is ongoing and can be followed here: -:::{warning} -Please note that speculative decoding in vLLM is not yet optimized and does -not usually yield inter-token latency reductions for all prompt datasets or sampling parameters. -The work to optimize it is ongoing and can be followed here: -::: - -:::{warning} -Currently, speculative decoding in vLLM is not compatible with pipeline parallelism. -::: +!!! warning + Currently, speculative decoding in vLLM is not compatible with pipeline parallelism. This document shows how to use [Speculative Decoding](https://x.com/karpathy/status/1697318534555336961) with vLLM. Speculative decoding is a technique which improves inter-token latency in memory-bound LLM inference. @@ -46,14 +45,18 @@ for output in outputs: To perform the same with an online mode launch the server: ```bash -python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 --model facebook/opt-6.7b \ - --seed 42 -tp 1 --gpu_memory_utilization 0.8 \ +python -m vllm.entrypoints.openai.api_server \ + --host 0.0.0.0 \ + --port 8000 \ + --model facebook/opt-6.7b \ + --seed 42 \ + -tp 1 \ + --gpu_memory_utilization 0.8 \ --speculative_config '{"model": "facebook/opt-125m", "num_speculative_tokens": 5}' ``` -:::{warning} -Note: Please use `--speculative_config` to set all configurations related to speculative decoding. 
The previous method of specifying the model through `--speculative_model` and adding related parameters (e.g., `--num_speculative_tokens`) separately has been deprecated now. -::: +!!! warning + Note: Please use `--speculative_config` to set all configurations related to speculative decoding. The previous method of specifying the model through `--speculative_model` and adding related parameters (e.g., `--num_speculative_tokens`) separately has been deprecated now. Then use a client: @@ -172,7 +175,7 @@ A variety of speculative models of this type are available on HF hub: ## Speculating using EAGLE based draft models The following code configures vLLM to use speculative decoding where proposals are generated by -an [EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)](https://arxiv.org/pdf/2401.15077) based draft model. A more detailed example for offline mode, including how to extract request level acceptance rate, can be found [here](). +an [EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)](https://arxiv.org/pdf/2401.15077) based draft model. A more detailed example for offline mode, including how to extract request level acceptance rate, can be found [here](gh-file:examples/offline_inference/eagle.py). ```python from vllm import LLM, SamplingParams @@ -255,7 +258,7 @@ speculative decoding, breaking down the guarantees into three key areas: 3. **vLLM Logprob Stability** \- vLLM does not currently guarantee stable token log probabilities (logprobs). This can result in different outputs for the same request across runs. For more details, see the FAQ section - titled *Can the output of a prompt vary across runs in vLLM?* in the [FAQs](#faq). + titled *Can the output of a prompt vary across runs in vLLM?* in the [FAQs][faq]. 
While vLLM strives to ensure losslessness in speculative decoding, variations in generated outputs with and without speculative decoding can occur due to following factors: @@ -264,7 +267,7 @@ can occur due to following factors: - **Batch Size and Numerical Stability**: Changes in batch size may cause variations in logprobs and output probabilities, potentially due to non-deterministic behavior in batched operations or numerical instability. -For mitigation strategies, please refer to the FAQ entry *Can the output of a prompt vary across runs in vLLM?* in the [FAQs](#faq). +For mitigation strategies, please refer to the FAQ entry *Can the output of a prompt vary across runs in vLLM?* in the [FAQs][faq]. ## Resources for vLLM contributors diff --git a/docs/source/features/structured_outputs.md b/docs/features/structured_outputs.md similarity index 96% rename from docs/source/features/structured_outputs.md rename to docs/features/structured_outputs.md index 03119ec7441c..f96b598cff98 100644 --- a/docs/source/features/structured_outputs.md +++ b/docs/features/structured_outputs.md @@ -1,6 +1,7 @@ -(structured-outputs)= - -# Structured Outputs +--- +title: Structured Outputs +--- +[](){ #structured-outputs } vLLM supports the generation of structured outputs using [xgrammar](https://github.com/mlc-ai/xgrammar) or @@ -20,7 +21,7 @@ The following parameters are supported, which must be added as extra parameters: - `guided_grammar`: the output will follow the context free grammar. - `structural_tag`: Follow a JSON schema within a set of specified tags within the generated text. -You can see the complete list of supported parameters on the [OpenAI-Compatible Server](#openai-compatible-server) page. +You can see the complete list of supported parameters on the [OpenAI-Compatible Server][openai-compatible-server] page. Structured outputs are supported by default in the OpenAI-Compatible Server. 
You may choose to specify the backend to use by setting the @@ -83,13 +84,11 @@ class CarType(str, Enum): truck = "Truck" coupe = "Coupe" - class CarDescription(BaseModel): brand: str model: str car_type: CarType - json_schema = CarDescription.model_json_schema() completion = client.chat.completions.create( @@ -105,11 +104,10 @@ completion = client.chat.completions.create( print(completion.choices[0].message.content) ``` -:::{tip} -While not strictly necessary, normally it's better to indicate in the prompt the -JSON schema and how the fields should be populated. This can improve the -results notably in most cases. -::: +!!! tip + While not strictly necessary, normally it's better to indicate in the prompt the + JSON schema and how the fields should be populated. This can improve the + results notably in most cases. Finally we have the `guided_grammar` option, which is probably the most difficult to use, but it's really powerful. It allows us to define complete @@ -160,12 +158,10 @@ Here is a simple example demonstrating how to get structured output using Pydant from pydantic import BaseModel from openai import OpenAI - class Info(BaseModel): name: str age: int - client = OpenAI(base_url="http://0.0.0.0:8000/v1", api_key="dummy") completion = client.beta.chat.completions.parse( model="meta-llama/Llama-3.1-8B-Instruct", @@ -199,17 +195,14 @@ from typing import List from pydantic import BaseModel from openai import OpenAI - class Step(BaseModel): explanation: str output: str - class MathResponse(BaseModel): steps: list[Step] final_answer: str - client = OpenAI(base_url="http://0.0.0.0:8000/v1", api_key="dummy") completion = client.beta.chat.completions.parse( model="meta-llama/Llama-3.1-8B-Instruct", diff --git a/docs/source/features/tool_calling.md b/docs/features/tool_calling.md similarity index 93% rename from docs/source/features/tool_calling.md rename to docs/features/tool_calling.md index f3b808b3d2b7..6ee1060dd050 100644 ---
a/docs/source/features/tool_calling.md +++ b/docs/features/tool_calling.md @@ -93,7 +93,7 @@ specify the `name` of one of the tools in the `tool_choice` parameter of the cha ## Required Function Calling -vLLM supports the `tool_choice='required'` option in the chat completion API. Similar to the named function calling, it also uses guided decoding, so this is enabled by default and will work with any supported model. The required guided decoding features (JSON schema with `anyOf`) are currently only supported in the V0 engine with the guided decoding backend `outlines`. However, support for alternative decoding backends are on the [roadmap](https://docs.vllm.ai/en/latest/getting_started/v1_user_guide.html#feature-model) for the V1 engine. +vLLM supports the `tool_choice='required'` option in the chat completion API. Similar to the named function calling, it also uses guided decoding, so this is enabled by default and will work with any supported model. The required guided decoding features (JSON schema with `anyOf`) are currently only supported in the V0 engine with the guided decoding backend `outlines`. However, support for alternative decoding backends is on the [roadmap](https://docs.vllm.ai/en/latest/usage/v1_guide.html#feature-model) for the V1 engine. When tool_choice='required' is set, the model is guaranteed to generate one or more tool calls based on the specified tool list in the `tools` parameter. The number of tool calls depends on the user's query. The output format strictly follows the schema defined in the `tools` parameter. @@ -158,13 +158,13 @@ All Llama 3.1, 3.2 and 4 models should be supported. * `meta-llama/Llama-3.2-*` * `meta-llama/Llama-4-*` -The tool calling that is supported is the [JSON based tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling).
For [pythonic tool calling](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#zero-shot-function-calling) introduced by the Llama-3.2 models, see the `pythonic` tool parser below. +The tool calling that is supported is the [JSON based tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling). For [pythonic tool calling](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#zero-shot-function-calling) introduced by the Llama-3.2 models, see the `pythonic` tool parser below. As for Llama 4 models, it is recommended to use the `llama4_pythonic` tool parser. Other tool calling formats like the built in python tool calling or custom tool calling are not supported. Known issues: -1. Parallel tool calls are not supported. +1. Parallel tool calls are not supported for Llama 3, but they are supported in Llama 4 models. 2. The model can generate parameters with a wrong format, such as generating an array serialized as string instead of an array. @@ -177,11 +177,10 @@ images. Recommended flags: `--tool-call-parser llama3_json --chat-template {see_above}` -VLLM also provides a JSON based chat template for Llama 4: -* - this is based on the "official" chat template for the Llama 4 -models, but tweaked so that it works better with vLLM. +vLLM also provides pythonic and JSON based chat templates for Llama 4, but pythonic tool calling is recommended: +* <gh-file:examples/tool_chat_template_llama4_pythonic.jinja> - this is based on the [official chat template](https://www.llama.com/docs/model-cards-and-prompt-formats/llama4/) for the Llama 4 models. -For Llama 4 use `--tool-call-parser llama4_json examples/tool_chat_template_llama4_json.jinja`. +For Llama 4 models, use `--tool-call-parser llama4_pythonic --chat-template examples/tool_chat_template_llama4_pythonic.jinja`.
#### IBM Granite @@ -236,6 +235,13 @@ For Qwen2.5, the chat template in tokenizer_config.json has already included sup Flags: `--tool-call-parser hermes` +### DeepSeek-V3 Models (`deepseek_v3`) + +Supported models: +* `deepseek-ai/DeepSeek-V3-0324` + +Flags: `--tool-call-parser deepseek_v3 --chat-template examples/tool_chat_template_deepseekv3.jinja` + ### Models with Pythonic Tool Calls (`pythonic`) A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models. @@ -316,7 +322,6 @@ class ExampleToolParser(ToolParser): tool_calls=[], content=text) - ``` Then you can use this plugin in the command line like this. diff --git a/docs/getting_started/installation/.nav.yml b/docs/getting_started/installation/.nav.yml new file mode 100644 index 000000000000..7acfc015ff50 --- /dev/null +++ b/docs/getting_started/installation/.nav.yml @@ -0,0 +1,5 @@ +nav: + - README.md + - gpu.md + - cpu.md + - ai_accelerator.md \ No newline at end of file diff --git a/docs/getting_started/installation/README.md b/docs/getting_started/installation/README.md new file mode 100644 index 000000000000..36bb16cc0224 --- /dev/null +++ b/docs/getting_started/installation/README.md @@ -0,0 +1,20 @@ +--- +title: Installation +--- +[](){ #installation-index } + +vLLM supports the following hardware platforms: + +- [GPU](gpu.md) + - [NVIDIA CUDA](gpu.md#nvidia-cuda) + - [AMD ROCm](gpu.md#amd-rocm) + - [Intel XPU](gpu.md#intel-xpu) +- [CPU](cpu.md) + - [Intel/AMD x86](cpu.md#intelamd-x86) + - [ARM AArch64](cpu.md#arm-aarch64) + - [Apple silicon](cpu.md#apple-silicon) + - [IBM Z (S390X)](cpu.md#ibm-z-s390x) +- [Other AI accelerators](ai_accelerator.md) + - [Google TPU](ai_accelerator.md#google-tpu) + - [Intel Gaudi](ai_accelerator.md#intel-gaudi) + - [AWS 
Neuron](ai_accelerator.md#aws-neuron) diff --git a/docs/getting_started/installation/ai_accelerator.md b/docs/getting_started/installation/ai_accelerator.md new file mode 100644 index 000000000000..a4f136a172fe --- /dev/null +++ b/docs/getting_started/installation/ai_accelerator.md @@ -0,0 +1,117 @@ +# Other AI accelerators + +vLLM is a Python library that supports the following AI accelerators. Select your AI accelerator type to see vendor specific instructions: + +=== "Google TPU" + + --8<-- "docs/getting_started/installation/ai_accelerator/tpu.inc.md:installation" + +=== "Intel Gaudi" + + --8<-- "docs/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md:installation" + +=== "AWS Neuron" + + --8<-- "docs/getting_started/installation/ai_accelerator/neuron.inc.md:installation" + +## Requirements + +=== "Google TPU" + + --8<-- "docs/getting_started/installation/ai_accelerator/tpu.inc.md:requirements" + +=== "Intel Gaudi" + + --8<-- "docs/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md:requirements" + +=== "AWS Neuron" + + --8<-- "docs/getting_started/installation/ai_accelerator/neuron.inc.md:requirements" + +## Configure a new environment + +=== "Google TPU" + + --8<-- "docs/getting_started/installation/ai_accelerator/tpu.inc.md:configure-a-new-environment" + +=== "Intel Gaudi" + + --8<-- "docs/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md:configure-a-new-environment" + +=== "AWS Neuron" + + --8<-- "docs/getting_started/installation/ai_accelerator/neuron.inc.md:configure-a-new-environment" + +## Set up using Python + +### Pre-built wheels + +=== "Google TPU" + + --8<-- "docs/getting_started/installation/ai_accelerator/tpu.inc.md:pre-built-wheels" + +=== "Intel Gaudi" + + --8<-- "docs/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md:pre-built-wheels" + +=== "AWS Neuron" + + --8<-- "docs/getting_started/installation/ai_accelerator/neuron.inc.md:pre-built-wheels" + +### Build wheel from source + +=== "Google TPU" + + 
--8<-- "docs/getting_started/installation/ai_accelerator/tpu.inc.md:build-wheel-from-source" + +=== "Intel Gaudi" + + --8<-- "docs/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md:build-wheel-from-source" + +=== "AWS Neuron" + + --8<-- "docs/getting_started/installation/ai_accelerator/neuron.inc.md:build-wheel-from-source" + +## Set up using Docker + +### Pre-built images + +=== "Google TPU" + + --8<-- "docs/getting_started/installation/ai_accelerator/tpu.inc.md:pre-built-images" + +=== "Intel Gaudi" + + --8<-- "docs/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md:pre-built-images" + +=== "AWS Neuron" + + --8<-- "docs/getting_started/installation/ai_accelerator/neuron.inc.md:pre-built-images" + +### Build image from source + +=== "Google TPU" + + --8<-- "docs/getting_started/installation/ai_accelerator/tpu.inc.md:build-image-from-source" + +=== "Intel Gaudi" + + --8<-- "docs/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md:build-image-from-source" + +=== "AWS Neuron" + + --8<-- "docs/getting_started/installation/ai_accelerator/neuron.inc.md:build-image-from-source" + +## Extra information + +=== "Google TPU" + + --8<-- "docs/getting_started/installation/ai_accelerator/tpu.inc.md:extra-information" + +=== "Intel Gaudi" + + --8<-- "docs/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md:extra-information" + +=== "AWS Neuron" + + --8<-- "docs/getting_started/installation/ai_accelerator/neuron.inc.md:extra-information" diff --git a/docs/source/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md b/docs/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md similarity index 83% rename from docs/source/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md rename to docs/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md index 78938de317c4..00935a37417e 100644 --- a/docs/source/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md +++ 
b/docs/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md @@ -1,12 +1,12 @@ -# Installation +# --8<-- [start:installation] This tab provides instructions on running vLLM with Intel Gaudi devices. -:::{attention} -There are no pre-built wheels or images for this device, so you must build vLLM from source. -::: +!!! warning + There are no pre-built wheels or images for this device, so you must build vLLM from source. -## Requirements +# --8<-- [end:installation] +# --8<-- [start:requirements] - OS: Ubuntu 22.04 LTS - Python: 3.10 @@ -45,16 +45,27 @@ Use the following commands to run a Docker image: ```console docker pull vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest -docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest +docker run \ + -it \ + --runtime=habana \ + -e HABANA_VISIBLE_DEVICES=all \ + -e OMPI_MCA_btl_vader_single_copy_mechanism=none \ + --cap-add=sys_nice \ + --net=host \ + --ipc=host \ + vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest ``` -## Set up using Python +# --8<-- [end:requirements] +# --8<-- [start:set-up-using-python] -### Pre-built wheels +# --8<-- [end:set-up-using-python] +# --8<-- [start:pre-built-wheels] Currently, there are no pre-built Intel Gaudi wheels. -### Build wheel from source +# --8<-- [end:pre-built-wheels] +# --8<-- [start:build-wheel-from-source] To build and install vLLM from source, run: @@ -75,29 +86,39 @@ pip install -r requirements/hpu.txt python setup.py develop ``` -## Set up using Docker +# --8<-- [end:build-wheel-from-source] +# --8<-- [start:set-up-using-docker] -### Pre-built images +# --8<-- [end:set-up-using-docker] +# --8<-- [start:pre-built-images] Currently, there are no pre-built Intel Gaudi images. 
-### Build image from source +# --8<-- [end:pre-built-images] +# --8<-- [start:build-image-from-source] ```console docker build -f docker/Dockerfile.hpu -t vllm-hpu-env . -docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --rm vllm-hpu-env +docker run \ + -it \ + --runtime=habana \ + -e HABANA_VISIBLE_DEVICES=all \ + -e OMPI_MCA_btl_vader_single_copy_mechanism=none \ + --cap-add=sys_nice \ + --net=host \ + --rm vllm-hpu-env ``` -:::{tip} -If you're observing the following error: `docker: Error response from daemon: Unknown runtime specified habana.`, please refer to "Install Using Containers" section of [Intel Gaudi Software Stack and Driver Installation](https://docs.habana.ai/en/v1.18.0/Installation_Guide/Bare_Metal_Fresh_OS.html). Make sure you have `habana-container-runtime` package installed and that `habana` container runtime is registered. -::: +!!! tip + If you're observing the following error: `docker: Error response from daemon: Unknown runtime specified habana.`, please refer to "Install Using Containers" section of [Intel Gaudi Software Stack and Driver Installation](https://docs.habana.ai/en/v1.18.0/Installation_Guide/Bare_Metal_Fresh_OS.html). Make sure you have `habana-container-runtime` package installed and that `habana` container runtime is registered. -## Extra information +# --8<-- [end:build-image-from-source] +# --8<-- [start:extra-information] ## Supported features -- [Offline inference](#offline-inference) -- Online serving via [OpenAI-Compatible Server](#openai-compatible-server) +- [Offline inference][offline-inference] +- Online serving via [OpenAI-Compatible Server][openai-compatible-server] - HPU autodetection - no need to manually select device within vLLM - Paged KV cache with algorithms enabled for Intel Gaudi accelerators - Custom Intel Gaudi implementations of Paged Attention, KV cache ops, @@ -157,41 +178,25 @@ Gaudi2 devices. 
Configurations that are not listed may or may not work. Currently in vLLM for HPU we support four execution modes, depending on selected HPU PyTorch Bridge backend (via `PT_HPU_LAZY_MODE` environment variable), and `--enforce-eager` flag. -:::{list-table} vLLM execution modes -:widths: 25 25 50 -:header-rows: 1 - -- * `PT_HPU_LAZY_MODE` - * `enforce_eager` - * execution mode -- * 0 - * 0 - * torch.compile -- * 0 - * 1 - * PyTorch eager mode -- * 1 - * 0 - * HPU Graphs -- * 1 - * 1 - * PyTorch lazy mode -::: - -:::{warning} -In 1.18.0, all modes utilizing `PT_HPU_LAZY_MODE=0` are highly experimental and should be only used for validating functional correctness. Their performance will be improved in the next releases. For obtaining the best performance in 1.18.0, please use HPU Graphs, or PyTorch lazy mode. -::: - -(gaudi-bucketing-mechanism)= +| `PT_HPU_LAZY_MODE` | `enforce_eager` | execution mode | +|----------------------|-------------------|--------------------| +| 0 | 0 | torch.compile | +| 0 | 1 | PyTorch eager mode | +| 1 | 0 | HPU Graphs | +
vLLM execution modes
+ +!!! warning + In 1.18.0, all modes utilizing `PT_HPU_LAZY_MODE=0` are highly experimental and should be only used for validating functional correctness. Their performance will be improved in the next releases. For obtaining the best performance in 1.18.0, please use HPU Graphs, or PyTorch lazy mode. + +[](){ #gaudi-bucketing-mechanism } ### Bucketing mechanism Intel Gaudi accelerators work best when operating on models with fixed tensor shapes. [Intel Gaudi Graph Compiler](https://docs.habana.ai/en/latest/Gaudi_Overview/Intel_Gaudi_Software_Suite.html#graph-compiler-and-runtime) is responsible for generating optimized binary code that implements the given model topology on Gaudi. In its default configuration, the produced binary code may be heavily dependent on input and output tensor shapes, and can require graph recompilation when encountering differently shaped tensors within the same topology. While the resulting binaries utilize Gaudi efficiently, the compilation itself may introduce a noticeable overhead in end-to-end execution. In a dynamic inference serving scenario, there is a need to minimize the number of graph compilations and reduce the risk of graph compilation occurring during server runtime. Currently it is achieved by "bucketing" model's forward pass across two dimensions - `batch_size` and `sequence_length`. -:::{note} -Bucketing allows us to reduce the number of required graphs significantly, but it does not handle any graph compilation and device code generation - this is done in warmup and HPUGraph capture phase. -::: +!!! note + Bucketing allows us to reduce the number of required graphs significantly, but it does not handle any graph compilation and device code generation - this is done in warmup and HPUGraph capture phase. Bucketing ranges are determined with 3 parameters - `min`, `step` and `max`. They can be set separately for prompt and decode phase, and for batch size and sequence length dimension. 
These parameters can be observed in logs during vLLM startup: @@ -224,15 +229,13 @@ min = 128, step = 128, max = 512 In the logged scenario, 24 buckets were generated for prompt (prefill) runs, and 48 buckets for decode runs. Each bucket corresponds to a separate optimized device binary for a given model with specified tensor shapes. Whenever a batch of requests is processed, it is padded across batch and sequence length dimension to the smallest possible bucket. -:::{warning} -If a request exceeds maximum bucket size in any dimension, it will be processed without padding, and its processing may require a graph compilation, potentially significantly increasing end-to-end latency. The boundaries of the buckets are user-configurable via environment variables, and upper bucket boundaries can be increased to avoid such scenario. -::: +!!! warning + If a request exceeds the maximum bucket size in any dimension, it will be processed without padding, and its processing may require a graph compilation, potentially significantly increasing end-to-end latency. The boundaries of the buckets are user-configurable via environment variables, and upper bucket boundaries can be increased to avoid such a scenario. As an example, if a request of 3 sequences, with max sequence length of 412 comes in to an idle vLLM server, it will be padded and executed as a `(4, 512)` prefill bucket, as `batch_size` (number of sequences) will be padded to 4 (closest batch_size dimension higher than 3), and max sequence length will be padded to 512 (closest sequence length dimension higher than 412). After prefill stage, it will be executed as a `(4, 512)` decode bucket and will continue as that bucket until either batch dimension changes (due to request being finished) - in which case it will become a `(2, 512)` bucket, or context length increases above 512 tokens, in which case it will become a `(4, 640)` bucket.
-:::{note} -Bucketing is transparent to a client -- padding in sequence length dimension is never returned to the client, and padding in batch dimension does not create new requests. -::: +!!! note + Bucketing is transparent to a client -- padding in sequence length dimension is never returned to the client, and padding in batch dimension does not create new requests. ### Warmup @@ -252,11 +255,10 @@ INFO 08-01 22:27:16 hpu_model_runner.py:1066] [Warmup][Decode][47/48] batch_size INFO 08-01 22:27:16 hpu_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB ``` -This example uses the same buckets as in the [Bucketing Mechanism](#gaudi-bucketing-mechanism) section. Each output line corresponds to execution of a single bucket. When bucket is executed for the first time, its graph is compiled and can be reused later on, skipping further graph compilations. +This example uses the same buckets as in the [Bucketing Mechanism][gaudi-bucketing-mechanism] section. Each output line corresponds to execution of a single bucket. When bucket is executed for the first time, its graph is compiled and can be reused later on, skipping further graph compilations. -:::{tip} -Compiling all the buckets might take some time and can be turned off with `VLLM_SKIP_WARMUP=true` environment variable. Keep in mind that if you do that, you may face graph compilations once executing a given bucket for the first time. It is fine to disable warmup for development, but it's highly recommended to enable it in deployment. -::: +!!! tip + Compiling all the buckets might take some time and can be turned off with `VLLM_SKIP_WARMUP=true` environment variable. Keep in mind that if you do that, you may face graph compilations once executing a given bucket for the first time. It is fine to disable warmup for development, but it's highly recommended to enable it in deployment. 
### HPU Graph capture @@ -271,9 +273,8 @@ With its default value (`VLLM_GRAPH_RESERVED_MEM=0.1`), 10% of usable memory wil Environment variable `VLLM_GRAPH_PROMPT_RATIO` determines the ratio of usable graph memory reserved for prefill and decode graphs. By default (`VLLM_GRAPH_PROMPT_RATIO=0.3`), both stages have equal memory constraints. Lower value corresponds to less usable graph memory reserved for prefill stage, e.g. `VLLM_GRAPH_PROMPT_RATIO=0.2` will reserve 20% of usable graph memory for prefill graphs, and 80% of usable graph memory for decode graphs. -:::{note} -`gpu_memory_utilization` does not correspond to the absolute memory usage across HPU. It specifies the memory margin after loading the model and performing a profile run. If device has 100 GiB of total memory, and 50 GiB of free memory after loading model weights and executing profiling run, `gpu_memory_utilization` at its default value will mark 90% of 50 GiB as usable, leaving 5 GiB of margin, regardless of total device memory. -::: +!!! note + `gpu_memory_utilization` does not correspond to the absolute memory usage across HPU. It specifies the memory margin after loading the model and performing a profile run. If device has 100 GiB of total memory, and 50 GiB of free memory after loading model weights and executing profiling run, `gpu_memory_utilization` at its default value will mark 90% of 50 GiB as usable, leaving 5 GiB of margin, regardless of total device memory. User can also configure the strategy for capturing HPU Graphs for prompt and decode stages separately. Strategy affects the order of capturing graphs. There are two strategies implemented: @@ -282,9 +283,8 @@ User can also configure the strategy for capturing HPU Graphs for prompt and dec When there's large amount of requests pending, vLLM scheduler will attempt to fill the maximum batch size for decode as soon as possible. When a request is finished, decode batch size decreases. 
When that happens, vLLM will attempt to schedule a prefill iteration for requests in the waiting queue, to fill the decode batch size to its previous state. This means that in a full load scenario, decode batch size is often at its maximum, which makes large batch size HPU Graphs crucial to capture, as reflected by `max_bs` strategy. On the other hand, prefills will be executed most frequently with very low batch sizes (1-4), which is reflected in `min_tokens` strategy. -:::{note} -`VLLM_GRAPH_PROMPT_RATIO` does not set a hard limit on memory taken by graphs for each stage (prefill and decode). vLLM will first attempt to use up entirety of usable prefill graph memory (usable graph memory * `VLLM_GRAPH_PROMPT_RATIO`) for capturing prefill HPU Graphs, next it will attempt do the same for decode graphs and usable decode graph memory pool. If one stage is fully captured, and there is unused memory left within usable graph memory pool, vLLM will attempt further graph capture for the other stage, until no more HPU Graphs can be captured without exceeding reserved memory pool. The behavior on that mechanism can be observed in the example below. -::: +!!! note + `VLLM_GRAPH_PROMPT_RATIO` does not set a hard limit on memory taken by graphs for each stage (prefill and decode). vLLM will first attempt to use up entirety of usable prefill graph memory (usable graph memory * `VLLM_GRAPH_PROMPT_RATIO`) for capturing prefill HPU Graphs, next it will attempt do the same for decode graphs and usable decode graph memory pool. If one stage is fully captured, and there is unused memory left within usable graph memory pool, vLLM will attempt further graph capture for the other stage, until no more HPU Graphs can be captured without exceeding reserved memory pool. The behavior on that mechanism can be observed in the example below. 
Each described step is logged by vLLM server, as follows (negative values correspond to memory being released): @@ -401,3 +401,4 @@ the below: higher batches. You can do that by adding `--enforce-eager` flag to server (for online serving), or by passing `enforce_eager=True` argument to LLM constructor (for offline inference). +# --8<-- [end:extra-information] diff --git a/docs/source/getting_started/installation/ai_accelerator/neuron.inc.md b/docs/getting_started/installation/ai_accelerator/neuron.inc.md similarity index 74% rename from docs/source/getting_started/installation/ai_accelerator/neuron.inc.md rename to docs/getting_started/installation/ai_accelerator/neuron.inc.md index b4bfb696faa2..f08c78fba6c8 100644 --- a/docs/source/getting_started/installation/ai_accelerator/neuron.inc.md +++ b/docs/getting_started/installation/ai_accelerator/neuron.inc.md @@ -1,14 +1,14 @@ -# Installation +# --8<-- [start:installation] vLLM 0.3.3 onwards supports model inferencing and serving on AWS Trainium/Inferentia with Neuron SDK with continuous batching. Paged Attention and Chunked Prefill are currently in development and will be available soon. Data types currently supported in Neuron SDK are FP16 and BF16. -:::{attention} -There are no pre-built wheels or images for this device, so you must build vLLM from source. -::: +!!! warning + There are no pre-built wheels or images for this device, so you must build vLLM from source. -## Requirements +# --8<-- [end:installation] +# --8<-- [start:requirements] - OS: Linux - Python: 3.9 -- 3.11 @@ -38,7 +38,8 @@ The installation of drivers and tools wouldn't be necessary, if [Deep Learning A sudo tee /etc/apt/sources.list.d/neuron.list > /dev/null <= 0.5.3`. You may see an error `cannot import name 'default_dump_dir...`. To work around this, run a `pip install --upgrade triton==3.0.0` after installing the vLLM wheel. -::: +!!! note + The currently supported version of Pytorch for Neuron installs `triton` version `2.1.0`. 
This is incompatible with `vllm >= 0.5.3`. You may see an error `cannot import name 'default_dump_dir...`. To work around this, run a `pip install --upgrade triton==3.0.0` after installing the vLLM wheel. Following instructions are applicable to Neuron SDK 2.16 and beyond. @@ -94,12 +97,17 @@ source aws_neuron_venv_pytorch/bin/activate # Install Jupyter notebook kernel pip install ipykernel -python3.10 -m ipykernel install --user --name aws_neuron_venv_pytorch --display-name "Python (torch-neuronx)" +python3.10 -m ipykernel install \ + --user \ + --name aws_neuron_venv_pytorch \ + --display-name "Python (torch-neuronx)" pip install jupyter notebook pip install environment_kernels # Set pip repository pointing to the Neuron repository -python -m pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com +python -m pip config set \ + global.extra-index-url \ + https://pip.repos.neuron.amazonaws.com # Install wget, awscli python -m pip install wget @@ -122,18 +130,23 @@ VLLM_TARGET_DEVICE="neuron" pip install . If neuron packages are detected correctly in the installation process, `vllm-0.3.0+neuron212` will be installed. -## Set up using Docker +# --8<-- [end:build-wheel-from-source] +# --8<-- [start:set-up-using-docker] -### Pre-built images +# --8<-- [end:set-up-using-docker] +# --8<-- [start:pre-built-images] Currently, there are no pre-built Neuron images. -### Build image from source +# --8<-- [end:pre-built-images] +# --8<-- [start:build-image-from-source] -See for instructions on building the Docker image. +See [deployment-docker-build-image-from-source][deployment-docker-build-image-from-source] for instructions on building the Docker image. Make sure to use in place of the default Dockerfile. -## Extra information +# --8<-- [end:build-image-from-source] +# --8<-- [start:extra-information] There is no extra information for this device. 
+# --8<-- [end:extra-information] diff --git a/docs/source/getting_started/installation/ai_accelerator/tpu.inc.md b/docs/getting_started/installation/ai_accelerator/tpu.inc.md similarity index 55% rename from docs/source/getting_started/installation/ai_accelerator/tpu.inc.md rename to docs/getting_started/installation/ai_accelerator/tpu.inc.md index 4459cc61e1cd..d0b168120137 100644 --- a/docs/source/getting_started/installation/ai_accelerator/tpu.inc.md +++ b/docs/getting_started/installation/ai_accelerator/tpu.inc.md @@ -1,4 +1,4 @@ -# Installation +# --8<-- [start:installation] Tensor Processing Units (TPUs) are Google's custom-developed application-specific integrated circuits (ASICs) used to accelerate machine learning workloads. TPUs @@ -30,11 +30,11 @@ For TPU pricing information, see [Cloud TPU pricing](https://cloud.google.com/tp You may need additional persistent storage for your TPU VMs. For more information, see [Storage options for Cloud TPU data](https://cloud.devsite.corp.google.com/tpu/docs/storage-options). -:::{attention} -There are no pre-built wheels for this device, so you must either use the pre-built Docker image or build vLLM from source. -::: +!!! warning + There are no pre-built wheels for this device, so you must either use the pre-built Docker image or build vLLM from source. -## Requirements +# --8<-- [end:installation] +# --8<-- [start:requirements] - Google Cloud TPU VM - TPU versions: v6e, v5e, v5p, v4 @@ -51,10 +51,9 @@ When you request queued resources, the request is added to a queue maintained by the Cloud TPU service. When the requested resource becomes available, it's assigned to your Google Cloud project for your immediate exclusive use. -:::{note} -In all of the following commands, replace the ALL CAPS parameter names with -appropriate values. See the parameter descriptions table for more information. -::: +!!! note + In all of the following commands, replace the ALL CAPS parameter names with + appropriate values. 
See the parameter descriptions table for more information. ### Provision Cloud TPUs with GKE @@ -79,33 +78,15 @@ gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \ --service-account SERVICE_ACCOUNT ``` -:::{list-table} Parameter descriptions -:header-rows: 1 - -- * Parameter name - * Description -- * QUEUED_RESOURCE_ID - * The user-assigned ID of the queued resource request. -- * TPU_NAME - * The user-assigned name of the TPU which is created when the queued - resource request is allocated. -- * PROJECT_ID - * Your Google Cloud project -- * ZONE - * The GCP zone where you want to create your Cloud TPU. The value you use - depends on the version of TPUs you are using. For more information, see - `TPU regions and zones `_ -- * ACCELERATOR_TYPE - * The TPU version you want to use. Specify the TPU version, for example - `v5litepod-4` specifies a v5e TPU with 4 cores, `v6e-1` specifies a v6e TPU with 1 core. For more information, - see [TPU versions](https://cloud.devsite.corp.google.com/tpu/docs/system-architecture-tpu-vm#versions). -- * RUNTIME_VERSION - * The TPU VM runtime version to use. For example, use `v2-alpha-tpuv6e` for a VM loaded with one or more v6e TPU(s). For more information see [TPU VM images](https://cloud.google.com/tpu/docs/runtimes). -- * SERVICE_ACCOUNT - * The email address for your service account. You can find it in the IAM - Cloud Console under *Service Accounts*. For example: - `tpu-service-account@.iam.gserviceaccount.com` -::: +| Parameter name | Description | +|--------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| QUEUED_RESOURCE_ID | The user-assigned ID of the queued resource request. 
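| +| TPU_NAME | The user-assigned name of the TPU which is created when the queued resource request is allocated. | +| PROJECT_ID | Your Google Cloud project | +| ZONE | The GCP zone where you want to create your Cloud TPU. The value you use depends on the version of TPUs you are using. For more information, see [TPU regions and zones](https://cloud.google.com/tpu/docs/regions-zones). | +| ACCELERATOR_TYPE | The TPU version you want to use. Specify the TPU version, for example `v5litepod-4` specifies a v5e TPU with 4 cores, `v6e-1` specifies a v6e TPU with 1 core. For more information, see [TPU versions](https://cloud.devsite.corp.google.com/tpu/docs/system-architecture-tpu-vm#versions). | +| RUNTIME_VERSION | The TPU VM runtime version to use. For example, use `v2-alpha-tpuv6e` for a VM loaded with one or more v6e TPU(s). For more information see [TPU VM images](https://cloud.google.com/tpu/docs/runtimes). | +| SERVICE_ACCOUNT | The email address for your service account. You can find it in the IAM Cloud Console under *Service Accounts*. For example: `tpu-service-account@<your_project_ID>.iam.gserviceaccount.com` | +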
Parameter descriptions
Connect to your TPU using SSH: @@ -113,13 +94,16 @@ Connect to your TPU using SSH: gcloud compute tpus tpu-vm ssh TPU_NAME --zone ZONE ``` -## Set up using Python +# --8<-- [end:requirements] +# --8<-- [start:set-up-using-python] -### Pre-built wheels +# --8<-- [end:set-up-using-python] +# --8<-- [start:pre-built-wheels] Currently, there are no pre-built TPU wheels. -### Build wheel from source +# --8<-- [end:pre-built-wheels] +# --8<-- [start:build-wheel-from-source] Install Miniconda: @@ -161,13 +145,16 @@ Run the setup script: VLLM_TARGET_DEVICE="tpu" python -m pip install -e . ``` -## Set up using Docker +# --8<-- [end:build-wheel-from-source] +# --8<-- [start:set-up-using-docker] -### Pre-built images +# --8<-- [end:set-up-using-docker] +# --8<-- [start:pre-built-images] -See for instructions on using the official Docker image, making sure to substitute the image name `vllm/vllm-openai` with `vllm/vllm-tpu`. +See [deployment-docker-pre-built-image][deployment-docker-pre-built-image] for instructions on using the official Docker image, making sure to substitute the image name `vllm/vllm-openai` with `vllm/vllm-tpu`. -### Build image from source +# --8<-- [end:pre-built-images] +# --8<-- [start:build-image-from-source] You can use to build a Docker image with TPU support. @@ -182,31 +169,30 @@ Run the Docker image with the following command: docker run --privileged --net host --shm-size=16G -it vllm-tpu ``` -:::{note} -Since TPU relies on XLA which requires static shapes, vLLM bucketizes the -possible input shapes and compiles an XLA graph for each shape. The -compilation time may take 20~30 minutes in the first run. However, the -compilation time reduces to ~5 minutes afterwards because the XLA graphs are -cached in the disk (in {code}`VLLM_XLA_CACHE_PATH` or {code}`~/.cache/vllm/xla_cache` by default). -::: +!!! note + Since TPU relies on XLA which requires static shapes, vLLM bucketizes the + possible input shapes and compiles an XLA graph for each shape. 
The + compilation time may take 20~30 minutes in the first run. However, the + compilation time reduces to ~5 minutes afterwards because the XLA graphs are + cached in the disk (in `VLLM_XLA_CACHE_PATH` or `~/.cache/vllm/xla_cache` by default). -:::{tip} -If you encounter the following error: +!!! tip + If you encounter the following error: -```console -from torch._C import * # noqa: F403 -ImportError: libopenblas.so.0: cannot open shared object file: No such -file or directory -``` - -Install OpenBLAS with the following command: + ```console + from torch._C import * # noqa: F403 + ImportError: libopenblas.so.0: cannot open shared object file: No such + file or directory + ``` -```console -sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev -``` + Install OpenBLAS with the following command: -::: + ```console + sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev + ``` -## Extra information +# --8<-- [end:build-image-from-source] +# --8<-- [start:extra-information] There is no extra information for this device. +# --8<-- [end:extra-information] diff --git a/docs/source/getting_started/installation/cpu.md b/docs/getting_started/installation/cpu.md similarity index 74% rename from docs/source/getting_started/installation/cpu.md rename to docs/getting_started/installation/cpu.md index 2c0ec60d7100..18c96b264ad8 100644 --- a/docs/source/getting_started/installation/cpu.md +++ b/docs/getting_started/installation/cpu.md @@ -2,107 +2,47 @@ vLLM is a Python library that supports the following CPU variants. 
Select your CPU type to see vendor specific instructions: -:::::{tab-set} -:sync-group: device +=== "Intel/AMD x86" -::::{tab-item} Intel/AMD x86 -:selected: -:sync: x86 + --8<-- "docs/getting_started/installation/cpu/x86.inc.md:installation" -:::{include} cpu/x86.inc.md -:start-after: "# Installation" -:end-before: "## Requirements" -::: +=== "ARM AArch64" -:::: + --8<-- "docs/getting_started/installation/cpu/arm.inc.md:installation" -::::{tab-item} ARM AArch64 -:sync: arm +=== "Apple silicon" -:::{include} cpu/arm.inc.md -:start-after: "# Installation" -:end-before: "## Requirements" -::: + --8<-- "docs/getting_started/installation/cpu/apple.inc.md:installation" -:::: +=== "IBM Z (S390X)" -::::{tab-item} Apple silicon -:sync: apple - -:::{include} cpu/apple.inc.md -:start-after: "# Installation" -:end-before: "## Requirements" -::: - -:::: - -::::{tab-item} IBM Z (S390X) -:sync: s390x - -:::{include} cpu/s390x.inc.md -:start-after: "# Installation" -:end-before: "## Requirements" -::: - -:::: - -::::: + --8<-- "docs/getting_started/installation/cpu/s390x.inc.md:installation" ## Requirements - Python: 3.9 -- 3.12 -:::::{tab-set} -:sync-group: device - -::::{tab-item} Intel/AMD x86 -:sync: x86 - -:::{include} cpu/x86.inc.md -:start-after: "## Requirements" -:end-before: "## Set up using Python" -::: - -:::: - -::::{tab-item} ARM AArch64 -:sync: arm - -:::{include} cpu/arm.inc.md -:start-after: "## Requirements" -:end-before: "## Set up using Python" -::: +=== "Intel/AMD x86" -:::: + --8<-- "docs/getting_started/installation/cpu/x86.inc.md:requirements" -::::{tab-item} Apple silicon -:sync: apple +=== "ARM AArch64" -:::{include} cpu/apple.inc.md -:start-after: "## Requirements" -:end-before: "## Set up using Python" -::: + --8<-- "docs/getting_started/installation/cpu/arm.inc.md:requirements" -:::: +=== "Apple silicon" -::::{tab-item} IBM Z (S390X) -:sync: s390x + --8<-- "docs/getting_started/installation/cpu/apple.inc.md:requirements" -:::{include} cpu/s390x.inc.md 
-:start-after: "## Requirements" -:end-before: "## Set up using Python" -::: +=== "IBM Z (S390X)" -:::: - -::::: + --8<-- "docs/getting_started/installation/cpu/s390x.inc.md:requirements" ## Set up using Python ### Create a new Python environment -:::{include} python_env_setup.inc.md -::: +--8<-- "docs/getting_started/installation/python_env_setup.inc.md" ### Pre-built wheels @@ -110,69 +50,29 @@ Currently, there are no pre-built CPU wheels. ### Build wheel from source -:::::{tab-set} -:sync-group: device - -::::{tab-item} Intel/AMD x86 -:sync: x86 - -:::{include} cpu/x86.inc.md -:start-after: "### Build wheel from source" -:end-before: "## Set up using Docker" -::: - -:::: - -::::{tab-item} ARM AArch64 -:sync: arm +=== "Intel/AMD x86" -:::{include} cpu/arm.inc.md -:start-after: "### Build wheel from source" -:end-before: "## Set up using Docker" -::: + --8<-- "docs/getting_started/installation/cpu/x86.inc.md:build-wheel-from-source" -:::: +=== "ARM AArch64" -::::{tab-item} Apple silicon -:sync: apple + --8<-- "docs/getting_started/installation/cpu/arm.inc.md:build-wheel-from-source" -:::{include} cpu/apple.inc.md -:start-after: "### Build wheel from source" -:end-before: "## Set up using Docker" -::: +=== "Apple silicon" -:::: + --8<-- "docs/getting_started/installation/cpu/apple.inc.md:build-wheel-from-source" -::::{tab-item} IBM Z (s390x) -:sync: s390x +=== "IBM Z (s390x)" -:::{include} cpu/s390x.inc.md -:start-after: "### Build wheel from source" -:end-before: "## Set up using Docker" -::: - -:::: - -::::: + --8<-- "docs/getting_started/installation/cpu/s390x.inc.md:build-wheel-from-source" ## Set up using Docker ### Pre-built images -:::::{tab-set} -:sync-group: device - -::::{tab-item} Intel/AMD x86 -:sync: x86 - -:::{include} cpu/x86.inc.md -:start-after: "### Pre-built images" -:end-before: "### Build image from source" -::: - -:::: +=== "Intel/AMD x86" -::::: + --8<-- "docs/getting_started/installation/cpu/x86.inc.md:pre-built-images" ### Build image from 
source @@ -192,13 +92,11 @@ $ docker run --rm \ other vLLM OpenAI server arguments ``` -::::{tip} -For ARM or Apple silicon, use `docker/Dockerfile.arm` -:::: +!!! tip + For ARM or Apple silicon, use `docker/Dockerfile.arm` -::::{tip} -For IBM Z (s390x), use `docker/Dockerfile.s390x` and in `docker run` use flag `--dtype float` -:::: +!!! tip + For IBM Z (s390x), use `docker/Dockerfile.s390x` and in `docker run` use flag `--dtype float` ## Supported features diff --git a/docs/source/getting_started/installation/cpu/apple.inc.md b/docs/getting_started/installation/cpu/apple.inc.md similarity index 58% rename from docs/source/getting_started/installation/cpu/apple.inc.md rename to docs/getting_started/installation/cpu/apple.inc.md index 7bc9e85ecd96..7a91e3ce5e5b 100644 --- a/docs/source/getting_started/installation/cpu/apple.inc.md +++ b/docs/getting_started/installation/cpu/apple.inc.md @@ -1,24 +1,27 @@ -# Installation +# --8<-- [start:installation] vLLM has experimental support for macOS with Apple silicon. For now, users shall build from the source vLLM to natively run on macOS. Currently the CPU implementation for macOS supports FP32 and FP16 datatypes. -:::{attention} -There are no pre-built wheels or images for this device, so you must build vLLM from source. -::: +!!! warning + There are no pre-built wheels or images for this device, so you must build vLLM from source. 
-## Requirements +# --8<-- [end:installation] +# --8<-- [start:requirements] - OS: `macOS Sonoma` or later - SDK: `XCode 15.4` or later with Command Line Tools - Compiler: `Apple Clang >= 15.0.0` -## Set up using Python +# --8<-- [end:requirements] +# --8<-- [start:set-up-using-python] -### Pre-built wheels +# --8<-- [end:set-up-using-python] +# --8<-- [start:pre-built-wheels] -### Build wheel from source +# --8<-- [end:pre-built-wheels] +# --8<-- [start:build-wheel-from-source] After installation of XCode and the Command Line Tools, which include Apple Clang, execute the following commands to build and install vLLM from the source. @@ -29,9 +32,8 @@ pip install -r requirements/cpu.txt pip install -e . ``` -:::{note} -On macOS the `VLLM_TARGET_DEVICE` is automatically set to `cpu`, which currently is the only supported device. -::: +!!! note + On macOS the `VLLM_TARGET_DEVICE` is automatically set to `cpu`, which currently is the only supported device. #### Troubleshooting @@ -51,10 +53,15 @@ If the build has error like the following snippet where standard C++ headers can 1 error generated. ``` -## Set up using Docker +# --8<-- [end:build-wheel-from-source] +# --8<-- [start:set-up-using-docker] -### Pre-built images +# --8<-- [end:set-up-using-docker] +# --8<-- [start:pre-built-images] -### Build image from source +# --8<-- [end:pre-built-images] +# --8<-- [start:build-image-from-source] -## Extra information +# --8<-- [end:build-image-from-source] +# --8<-- [start:extra-information] +# --8<-- [end:extra-information] diff --git a/docs/getting_started/installation/cpu/arm.inc.md b/docs/getting_started/installation/cpu/arm.inc.md new file mode 100644 index 000000000000..59b71dcaf911 --- /dev/null +++ b/docs/getting_started/installation/cpu/arm.inc.md @@ -0,0 +1,41 @@ +# --8<-- [start:installation] + +vLLM has been adapted to work on ARM64 CPUs with NEON support, leveraging the CPU backend initially developed for the x86 platform. 
+ +ARM CPU backend currently supports Float32, FP16 and BFloat16 datatypes. + +!!! warning + There are no pre-built wheels or images for this device, so you must build vLLM from source. + +# --8<-- [end:installation] +# --8<-- [start:requirements] + +- OS: Linux +- Compiler: `gcc/g++ >= 12.3.0` (optional, recommended) +- Instruction Set Architecture (ISA): NEON support is required + +# --8<-- [end:requirements] +# --8<-- [start:set-up-using-python] + +# --8<-- [end:set-up-using-python] +# --8<-- [start:pre-built-wheels] + +# --8<-- [end:pre-built-wheels] +# --8<-- [start:build-wheel-from-source] + +--8<-- "docs/getting_started/installation/cpu/build.inc.md" + +Testing has been conducted on AWS Graviton3 instances for compatibility. + +# --8<-- [end:build-wheel-from-source] +# --8<-- [start:set-up-using-docker] + +# --8<-- [end:set-up-using-docker] +# --8<-- [start:pre-built-images] + +# --8<-- [end:pre-built-images] +# --8<-- [start:build-image-from-source] + +# --8<-- [end:build-image-from-source] +# --8<-- [start:extra-information] +# --8<-- [end:extra-information] diff --git a/docs/source/getting_started/installation/cpu/build.inc.md b/docs/getting_started/installation/cpu/build.inc.md similarity index 96% rename from docs/source/getting_started/installation/cpu/build.inc.md rename to docs/getting_started/installation/cpu/build.inc.md index f385f3d5b198..7d6472afa7ea 100644 --- a/docs/source/getting_started/installation/cpu/build.inc.md +++ b/docs/getting_started/installation/cpu/build.inc.md @@ -32,3 +32,5 @@ If you want to develop vllm, install it in editable mode instead. 
```console VLLM_TARGET_DEVICE=cpu python setup.py develop ``` + +# --8<-- [end:extra-information] diff --git a/docs/source/getting_started/installation/cpu/s390x.inc.md b/docs/getting_started/installation/cpu/s390x.inc.md similarity index 64% rename from docs/source/getting_started/installation/cpu/s390x.inc.md rename to docs/getting_started/installation/cpu/s390x.inc.md index 9b41173b44ce..670485feefb6 100644 --- a/docs/source/getting_started/installation/cpu/s390x.inc.md +++ b/docs/getting_started/installation/cpu/s390x.inc.md @@ -1,25 +1,28 @@ -# Installation +# --8<-- [start:installation] vLLM has experimental support for s390x architecture on IBM Z platform. For now, users shall build from the vLLM source to natively run on IBM Z platform. Currently the CPU implementation for s390x architecture supports FP32 datatype only. -:::{attention} -There are no pre-built wheels or images for this device, so you must build vLLM from source. -::: +!!! warning + There are no pre-built wheels or images for this device, so you must build vLLM from source. -## Requirements +# --8<-- [end:installation] +# --8<-- [start:requirements] - OS: `Linux` - SDK: `gcc/g++ >= 12.3.0` or later with Command Line Tools - Instruction Set Architecture (ISA): VXE support is required. Works with Z14 and above. - Build install python packages: `pyarrow`, `torch` and `torchvision` -## Set up using Python +# --8<-- [end:requirements] +# --8<-- [start:set-up-using-python] -### Pre-built wheels +# --8<-- [end:set-up-using-python] +# --8<-- [start:pre-built-wheels] -### Build wheel from source +# --8<-- [end:pre-built-wheels] +# --8<-- [start:build-wheel-from-source] Install the following packages from the package manager before building the vLLM. For example on RHEL 9.4: @@ -39,9 +42,8 @@ curl https://sh.rustup.rs -sSf | sh -s -- -y && \ Execute the following commands to build and install vLLM from the source. 
-::::{tip} -Please build the following dependencies, `torchvision`, `pyarrow` from the source before building vLLM. -:::: +!!! tip + Please build the following dependencies, `torchvision`, `pyarrow` from the source before building vLLM. ```console sed -i '/^torch/d' requirements-build.txt # remove torch from requirements-build.txt since we use nightly builds @@ -53,10 +55,15 @@ Please build the following dependencies, `torchvision`, `pyarrow` from the sourc pip install dist/*.whl ``` -## Set up using Docker +# --8<-- [end:build-wheel-from-source] +# --8<-- [start:set-up-using-docker] -### Pre-built images +# --8<-- [end:set-up-using-docker] +# --8<-- [start:pre-built-images] -### Build image from source +# --8<-- [end:pre-built-images] +# --8<-- [start:build-image-from-source] -## Extra information +# --8<-- [end:build-image-from-source] +# --8<-- [start:extra-information] +# --8<-- [end:extra-information] diff --git a/docs/getting_started/installation/cpu/x86.inc.md b/docs/getting_started/installation/cpu/x86.inc.md new file mode 100644 index 000000000000..9434eeea8b4a --- /dev/null +++ b/docs/getting_started/installation/cpu/x86.inc.md @@ -0,0 +1,46 @@ +# --8<-- [start:installation] + +vLLM initially supports basic model inferencing and serving on x86 CPU platform, with data types FP32, FP16 and BF16. + +!!! warning + There are no pre-built wheels or images for this device, so you must build vLLM from source. + +# --8<-- [end:installation] +# --8<-- [start:requirements] + +- OS: Linux +- Compiler: `gcc/g++ >= 12.3.0` (optional, recommended) +- Instruction Set Architecture (ISA): AVX512 (optional, recommended) + +!!! tip + [Intel Extension for PyTorch (IPEX)](https://github.com/intel/intel-extension-for-pytorch) extends PyTorch with up-to-date features optimizations for an extra performance boost on Intel hardware. 
+ +# --8<-- [end:requirements] +# --8<-- [start:set-up-using-python] + +# --8<-- [end:set-up-using-python] +# --8<-- [start:pre-built-wheels] + +# --8<-- [end:pre-built-wheels] +# --8<-- [start:build-wheel-from-source] + +--8<-- "docs/getting_started/installation/cpu/build.inc.md" + +!!! note + - AVX512_BF16 is an extension ISA that provides native BF16 data type conversion and vector product instructions, which brings some performance improvement compared with pure AVX512. The CPU backend build script will check the host CPU flags to determine whether to enable AVX512_BF16. + - If you want to force enable AVX512_BF16 for the cross-compilation, please set environment variable `VLLM_CPU_AVX512BF16=1` before building. + +# --8<-- [end:build-wheel-from-source] +# --8<-- [start:set-up-using-docker] + +# --8<-- [end:set-up-using-docker] +# --8<-- [start:pre-built-images] + +See [https://gallery.ecr.aws/q9t5s3a7/vllm-cpu-release-repo](https://gallery.ecr.aws/q9t5s3a7/vllm-cpu-release-repo) + +# --8<-- [end:pre-built-images] +# --8<-- [start:build-image-from-source] + +# --8<-- [end:build-image-from-source] +# --8<-- [start:extra-information] +# --8<-- [end:extra-information] diff --git a/docs/source/getting_started/installation/device.template.md b/docs/getting_started/installation/device.template.md similarity index 100% rename from docs/source/getting_started/installation/device.template.md rename to docs/getting_started/installation/device.template.md diff --git a/docs/getting_started/installation/gpu.md b/docs/getting_started/installation/gpu.md new file mode 100644 index 000000000000..3c983f600673 --- /dev/null +++ b/docs/getting_started/installation/gpu.md @@ -0,0 +1,124 @@ +# GPU + +vLLM is a Python library that supports the following GPU variants. 
Select your GPU type to see vendor specific instructions: + +=== "NVIDIA CUDA" + + --8<-- "docs/getting_started/installation/gpu/cuda.inc.md:installation" + +=== "AMD ROCm" + + --8<-- "docs/getting_started/installation/gpu/rocm.inc.md:installation" + +=== "Intel XPU" + + --8<-- "docs/getting_started/installation/gpu/xpu.inc.md:installation" + +## Requirements + +- OS: Linux +- Python: 3.9 -- 3.12 + +=== "NVIDIA CUDA" + + --8<-- "docs/getting_started/installation/gpu/cuda.inc.md:requirements" + +=== "AMD ROCm" + + --8<-- "docs/getting_started/installation/gpu/rocm.inc.md:requirements" + +=== "Intel XPU" + + --8<-- "docs/getting_started/installation/gpu/xpu.inc.md:requirements" + +## Set up using Python + +### Create a new Python environment + +--8<-- "docs/getting_started/installation/python_env_setup.inc.md" + +=== "NVIDIA CUDA" + + --8<-- "docs/getting_started/installation/gpu/cuda.inc.md:create-a-new-python-environment" + +=== "AMD ROCm" + + There is no extra information on creating a new Python environment for this device. + +=== "Intel XPU" + + There is no extra information on creating a new Python environment for this device. 
+ +### Pre-built wheels + +=== "NVIDIA CUDA" + + --8<-- "docs/getting_started/installation/gpu/cuda.inc.md:pre-built-wheels" + +=== "AMD ROCm" + + --8<-- "docs/getting_started/installation/gpu/rocm.inc.md:pre-built-wheels" + +=== "Intel XPU" + + --8<-- "docs/getting_started/installation/gpu/xpu.inc.md:pre-built-wheels" + +[](){ #build-from-source } + +### Build wheel from source + +=== "NVIDIA CUDA" + + --8<-- "docs/getting_started/installation/gpu/cuda.inc.md:build-wheel-from-source" + +=== "AMD ROCm" + + --8<-- "docs/getting_started/installation/gpu/rocm.inc.md:build-wheel-from-source" + +=== "Intel XPU" + + --8<-- "docs/getting_started/installation/gpu/xpu.inc.md:build-wheel-from-source" + +## Set up using Docker + +### Pre-built images + +=== "NVIDIA CUDA" + + --8<-- "docs/getting_started/installation/gpu/cuda.inc.md:pre-built-images" + +=== "AMD ROCm" + + --8<-- "docs/getting_started/installation/gpu/rocm.inc.md:pre-built-images" + +=== "Intel XPU" + + --8<-- "docs/getting_started/installation/gpu/xpu.inc.md:pre-built-images" + +### Build image from source + +=== "NVIDIA CUDA" + + --8<-- "docs/getting_started/installation/gpu/cuda.inc.md:build-image-from-source" + +=== "AMD ROCm" + + --8<-- "docs/getting_started/installation/gpu/rocm.inc.md:build-image-from-source" + +=== "Intel XPU" + + --8<-- "docs/getting_started/installation/gpu/xpu.inc.md:build-image-from-source" + +## Supported features + +=== "NVIDIA CUDA" + + --8<-- "docs/getting_started/installation/gpu/cuda.inc.md:supported-features" + +=== "AMD ROCm" + + --8<-- "docs/getting_started/installation/gpu/rocm.inc.md:supported-features" + +=== "Intel XPU" + + --8<-- "docs/getting_started/installation/gpu/xpu.inc.md:supported-features" diff --git a/docs/source/getting_started/installation/gpu/cuda.inc.md b/docs/getting_started/installation/gpu/cuda.inc.md similarity index 62% rename from docs/source/getting_started/installation/gpu/cuda.inc.md rename to docs/getting_started/installation/gpu/cuda.inc.md 
index 06915f09dd51..64dccef63d73 100644 --- a/docs/source/getting_started/installation/gpu/cuda.inc.md +++ b/docs/getting_started/installation/gpu/cuda.inc.md @@ -1,43 +1,52 @@ -# Installation +# --8<-- [start:installation] -vLLM contains pre-compiled C++ and CUDA (12.6) binaries. +vLLM contains pre-compiled C++ and CUDA (12.8) binaries. -## Requirements +# --8<-- [end:installation] +# --8<-- [start:requirements] - GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, H100, etc.) -## Set up using Python +# --8<-- [end:requirements] +# --8<-- [start:set-up-using-python] ### Create a new Python environment -:::{note} -PyTorch installed via `conda` will statically link `NCCL` library, which can cause issues when vLLM tries to use `NCCL`. See for more details. -::: +!!! note + PyTorch installed via `conda` will statically link `NCCL` library, which can cause issues when vLLM tries to use `NCCL`. See for more details. In order to be performant, vLLM has to compile many cuda kernels. The compilation unfortunately introduces binary incompatibility with other CUDA versions and PyTorch versions, even for the same PyTorch version with different building configurations. -Therefore, it is recommended to install vLLM with a **fresh new** environment. If either you have a different CUDA version or you want to use an existing PyTorch installation, you need to build vLLM from source. See [below](#build-from-source) for more details. +Therefore, it is recommended to install vLLM with a **fresh new** environment. If either you have a different CUDA version or you want to use an existing PyTorch installation, you need to build vLLM from source. See [below][build-from-source] for more details. -### Pre-built wheels +# --8<-- [end:set-up-using-python] +# --8<-- [start:pre-built-wheels] You can install vLLM using either `pip` or `uv pip`: ```console -# Install vLLM with CUDA 12.6. -pip install vllm # If you are using pip. -uv pip install vllm # If you are using uv. 
+# Install vLLM with CUDA 12.8. +# If you are using pip. +pip install vllm --extra-index-url https://download.pytorch.org/whl/cu128 +# If you are using uv. +uv pip install vllm --torch-backend=auto ``` -As of now, vLLM's binaries are compiled with CUDA 12.6 and public PyTorch release versions by default. We also provide vLLM binaries compiled with CUDA 12.8, 11.8, and public PyTorch release versions: +We recommend leveraging `uv` to [automatically select the appropriate PyTorch index at runtime](https://docs.astral.sh/uv/guides/integration/pytorch/#automatic-backend-selection) by inspecting the installed CUDA driver version via `--torch-backend=auto` (or `UV_TORCH_BACKEND=auto`). To select a specific backend (e.g., `cu126`), set `--torch-backend=cu126` (or `UV_TORCH_BACKEND=cu126`). If this doesn't work, try running `uv self update` to update `uv` first. + +!!! note + NVIDIA Blackwell GPUs (B200, GB200) require a minimum of CUDA 12.8, so make sure you are installing PyTorch wheels with at least that version. PyTorch itself offers a [dedicated interface](https://pytorch.org/get-started/locally/) to determine the appropriate pip command to run for a given target configuration. + +As of now, vLLM's binaries are compiled with CUDA 12.8 and public PyTorch release versions by default. We also provide vLLM binaries compiled with CUDA 12.6, 11.8, and public PyTorch release versions: ```console # Install vLLM with CUDA 11.8. 
export VLLM_VERSION=0.6.1.post1 -export PYTHON_VERSION=310 -pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux1_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118 +export PYTHON_VERSION=312 +uv pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux1_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118 ``` -(install-the-latest-code)= +[](){ #install-the-latest-code } #### Install the latest code @@ -46,40 +55,47 @@ LLM inference is a fast-evolving field, and the latest code may contain bug fixe ##### Install the latest code using `pip` ```console -pip install -U vllm --pre --extra-index-url https://wheels.vllm.ai/nightly +pip install -U vllm \ + --pre \ + --extra-index-url https://wheels.vllm.ai/nightly ``` `--pre` is required for `pip` to consider pre-released versions. -If you want to access the wheels for previous commits (e.g. to bisect the behavior change, performance regression), due to the limitation of `pip`, you have to specify the full URL of the wheel file by embedding the commit hash in the URL: +Another way to install the latest code is to use `uv`: ```console -export VLLM_COMMIT=33f460b17a54acb3b6cc0b03f4a17876cff5eafd # use full commit hash from the main branch -pip install https://wheels.vllm.ai/${VLLM_COMMIT}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl +uv pip install -U vllm \ + --torch-backend=auto \ + --extra-index-url https://wheels.vllm.ai/nightly ``` -Note that the wheels are built with Python 3.8 ABI (see [PEP 425](https://peps.python.org/pep-0425/) for more details about ABI), so **they are compatible with Python 3.8 and later**. 
The version string in the wheel file name (`1.0.0.dev`) is just a placeholder to have a unified URL for the wheels, the actual versions of wheels are contained in the wheel metadata (the wheels listed in the extra index url have correct versions). Although we don't support Python 3.8 any more (because PyTorch 2.5 dropped support for Python 3.8), the wheels are still built with Python 3.8 ABI to keep the same wheel name as before. - -##### Install the latest code using `uv` +##### Install specific revisions using `pip` -Another way to install the latest code is to use `uv`: +If you want to access the wheels for previous commits (e.g. to bisect a behavior change or performance regression), due to the limitation of `pip`, you have to specify the full URL of the wheel file by embedding the commit hash in the URL: ```console -uv pip install -U vllm --extra-index-url https://wheels.vllm.ai/nightly +export VLLM_COMMIT=33f460b17a54acb3b6cc0b03f4a17876cff5eafd # use full commit hash from the main branch +pip install https://wheels.vllm.ai/${VLLM_COMMIT}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl ``` +Note that the wheels are built with Python 3.8 ABI (see [PEP 425](https://peps.python.org/pep-0425/) for more details about ABI), so **they are compatible with Python 3.8 and later**. The version string in the wheel file name (`1.0.0.dev`) is just a placeholder to have a unified URL for the wheels; the actual versions of the wheels are contained in the wheel metadata (the wheels listed in the extra index URL have the correct versions). Although we don't support Python 3.8 any more (because PyTorch 2.5 dropped support for Python 3.8), the wheels are still built with Python 3.8 ABI to keep the same wheel name as before. + ##### Install specific revisions using `uv` If you want to access the wheels for previous commits (e.g.
to bisect the behavior change, performance regression), you can specify the commit hash in the URL: ```console export VLLM_COMMIT=72d9c316d3f6ede485146fe5aabd4e61dbc59069 # use full commit hash from the main branch -uv pip install vllm --extra-index-url https://wheels.vllm.ai/${VLLM_COMMIT} +uv pip install vllm \ + --torch-backend=auto \ + --extra-index-url https://wheels.vllm.ai/${VLLM_COMMIT} ``` The `uv` approach works for vLLM `v0.6.6` and later and offers an easy-to-remember command. A unique feature of `uv` is that packages in `--extra-index-url` have [higher priority than the default index](https://docs.astral.sh/uv/pip/compatibility/#packages-that-exist-on-multiple-indexes). If the latest public release is `v0.6.6.post1`, `uv`'s behavior allows installing a commit before `v0.6.6.post1` by specifying the `--extra-index-url`. In contrast, `pip` combines packages from `--extra-index-url` and the default index, choosing only the latest version, which makes it difficult to install a development version prior to the released version. -### Build wheel from source +# --8<-- [end:pre-built-wheels] +# --8<-- [start:build-wheel-from-source] #### Set up using Python-only build (without compilation) @@ -92,15 +108,15 @@ VLLM_USE_PRECOMPILED=1 pip install --editable . ``` This command will do the following: + 1. Look for the current branch in your vLLM clone. -2. Identify the corresponding base commit in the main branch. -3. Download the pre-built wheel of the base commit. -4. Use its compiled libraries in the installation. +1. Identify the corresponding base commit in the main branch. +1. Download the pre-built wheel of the base commit. +1. Use its compiled libraries in the installation. -:::{note} -1. If you change C++ or kernel code, you cannot use Python-only build; otherwise you will see an import error about library not found or undefined symbol. -2. 
If you rebase your dev branch, it is recommended to uninstall vllm and re-run the above command to make sure your libraries are up to date. -::: +!!! note + 1. If you change C++ or kernel code, you cannot use Python-only build; otherwise you will see an import error about library not found or undefined symbol. + 2. If you rebase your dev branch, it is recommended to uninstall vllm and re-run the above command to make sure your libraries are up to date. In case you see an error about wheel not found when running the above command, it might be because the commit you based on in the main branch was just merged and the wheel is being built. In this case, you can wait for around an hour to try again, or manually assign the previous commit in the installation using the `VLLM_PRECOMPILED_WHEEL_LOCATION` environment variable. @@ -110,12 +126,11 @@ export VLLM_PRECOMPILED_WHEEL_LOCATION=https://wheels.vllm.ai/${VLLM_COMMIT}/vll pip install --editable . ``` -You can find more information about vLLM's wheels in . +You can find more information about vLLM's wheels in [install-the-latest-code][install-the-latest-code]. -:::{note} -There is a possibility that your source code may have a different commit ID compared to the latest vLLM wheel, which could potentially lead to unknown errors. -It is recommended to use the same commit ID for the source code as the vLLM wheel you have installed. Please refer to for instructions on how to install a specified wheel. -::: +!!! note + There is a possibility that your source code may have a different commit ID compared to the latest vLLM wheel, which could potentially lead to unknown errors. + It is recommended to use the same commit ID for the source code as the vLLM wheel you have installed. Please refer to [install-the-latest-code][install-the-latest-code] for instructions on how to install a specified wheel. #### Full build (with compilation) @@ -127,17 +142,16 @@ cd vllm pip install -e . 
``` -:::{tip} -Building from source requires a lot of compilation. If you are building from source repeatedly, it's more efficient to cache the compilation results. +!!! tip + Building from source requires a lot of compilation. If you are building from source repeatedly, it's more efficient to cache the compilation results. -For example, you can install [ccache](https://github.com/ccache/ccache) using `conda install ccache` or `apt install ccache` . -As long as `which ccache` command can find the `ccache` binary, it will be used automatically by the build system. After the first build, subsequent builds will be much faster. + For example, you can install [ccache](https://github.com/ccache/ccache) using `conda install ccache` or `apt install ccache` . + As long as `which ccache` command can find the `ccache` binary, it will be used automatically by the build system. After the first build, subsequent builds will be much faster. -When using `ccache` with `pip install -e .`, you should run `CCACHE_NOHASHDIR="true" pip install --no-build-isolation -e .`. This is because `pip` creates a new folder with a random name for each build, preventing `ccache` from recognizing that the same files are being built. + When using `ccache` with `pip install -e .`, you should run `CCACHE_NOHASHDIR="true" pip install --no-build-isolation -e .`. This is because `pip` creates a new folder with a random name for each build, preventing `ccache` from recognizing that the same files are being built. -[sccache](https://github.com/mozilla/sccache) works similarly to `ccache`, but has the capability to utilize caching in remote storage environments. -The following environment variables can be set to configure the vLLM `sccache` remote: `SCCACHE_BUCKET=vllm-build-sccache SCCACHE_REGION=us-west-2 SCCACHE_S3_NO_CREDENTIALS=1`. We also recommend setting `SCCACHE_IDLE_TIMEOUT=0`. 
-::: + [sccache](https://github.com/mozilla/sccache) works similarly to `ccache`, but has the capability to utilize caching in remote storage environments. + The following environment variables can be set to configure the vLLM `sccache` remote: `SCCACHE_BUCKET=vllm-build-sccache SCCACHE_REGION=us-west-2 SCCACHE_S3_NO_CREDENTIALS=1`. We also recommend setting `SCCACHE_IDLE_TIMEOUT=0`. ##### Use an existing PyTorch installation @@ -184,7 +198,11 @@ Additionally, if you have trouble building vLLM, we recommend using the NVIDIA P ```console # Use `--ipc=host` to make sure the shared memory is large enough. -docker run --gpus all -it --rm --ipc=host nvcr.io/nvidia/pytorch:23.10-py3 +docker run \ + --gpus all \ + -it \ + --rm \ + --ipc=host nvcr.io/nvidia/pytorch:23.10-py3 ``` If you don't want to use docker, it is recommended to have a full installation of CUDA Toolkit. You can download and install it from [the official website](https://developer.nvidia.com/cuda-toolkit-archive). After installation, set the environment variable `CUDA_HOME` to the installation path of CUDA Toolkit, and make sure that the `nvcc` compiler is in your `PATH`, e.g.: @@ -212,11 +230,13 @@ export VLLM_TARGET_DEVICE=empty pip install -e . ``` -## Set up using Docker +# --8<-- [end:build-wheel-from-source] +# --8<-- [start:set-up-using-docker] -### Pre-built images +# --8<-- [end:set-up-using-docker] +# --8<-- [start:pre-built-images] -See for instructions on using the official Docker image. +See [deployment-docker-pre-built-image][deployment-docker-pre-built-image] for instructions on using the official Docker image. Another way to access the latest code is to use the docker images: @@ -229,10 +249,12 @@ These docker images are used for CI and testing only, and they are not intended The latest code can contain bugs and may not be stable. Please use it with caution. 
-### Build image from source +# --8<-- [end:pre-built-images] +# --8<-- [start:build-image-from-source] -See for instructions on building the Docker image. +See [deployment-docker-build-image-from-source][deployment-docker-build-image-from-source] for instructions on building the Docker image. ## Supported features -See compatibility matrix for feature support information. +See [feature-x-hardware][feature-x-hardware] compatibility matrix for feature support information. +# --8<-- [end:extra-information] diff --git a/docs/source/getting_started/installation/gpu/rocm.inc.md b/docs/getting_started/installation/gpu/rocm.inc.md similarity index 66% rename from docs/source/getting_started/installation/gpu/rocm.inc.md rename to docs/getting_started/installation/gpu/rocm.inc.md index dc74368fe2c9..0029b3a24496 100644 --- a/docs/source/getting_started/installation/gpu/rocm.inc.md +++ b/docs/getting_started/installation/gpu/rocm.inc.md @@ -1,28 +1,31 @@ -# Installation +# --8<-- [start:installation] vLLM supports AMD GPUs with ROCm 6.3. -:::{attention} -There are no pre-built wheels for this device, so you must either use the pre-built Docker image or build vLLM from source. -::: +!!! warning + There are no pre-built wheels for this device, so you must either use the pre-built Docker image or build vLLM from source. -## Requirements +# --8<-- [end:installation] +# --8<-- [start:requirements] - GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100/1101), Radeon RX 9000 series (gfx1200/1201) - ROCm 6.3 -## Set up using Python +# --8<-- [end:requirements] +# --8<-- [start:set-up-using-python] -### Pre-built wheels +# --8<-- [end:set-up-using-python] +# --8<-- [start:pre-built-wheels] Currently, there are no pre-built ROCm wheels. -### Build wheel from source +# --8<-- [end:pre-built-wheels] +# --8<-- [start:build-wheel-from-source] 0. 
Install prerequisites (skip if you are already in an environment/docker with the following installed): -- [ROCm](https://rocm.docs.amd.com/en/latest/deploy/linux/index.html) -- [PyTorch](https://pytorch.org/) + - [ROCm](https://rocm.docs.amd.com/en/latest/deploy/linux/index.html) + - [PyTorch](https://pytorch.org/) For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.3_ubuntu24.04_py3.12_pytorch_release_2.4.0`, `rocm/pytorch-nightly`. If you are using docker image, you can skip to Step 3. @@ -49,9 +52,8 @@ Currently, there are no pre-built ROCm wheels. cd ../.. ``` - :::{note} - If you see HTTP issue related to downloading packages during building triton, please try again as the HTTP error is intermittent. - ::: + !!! note + If you see HTTP issue related to downloading packages during building triton, please try again as the HTTP error is intermittent. 2. Optionally, if you choose to use CK flash attention, you can install [flash attention for ROCm](https://github.com/ROCm/flash-attention) @@ -69,9 +71,8 @@ Currently, there are no pre-built ROCm wheels. cd .. ``` - :::{note} - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) - ::: + !!! note + You might need to downgrade the "ninja" version to 1.10 as it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) 3. If you choose to build AITER yourself to use a certain branch or commit, you can build AITER using the following steps: @@ -84,55 +85,56 @@ Currently, there are no pre-built ROCm wheels. python3 setup.py develop ``` - :::{note} - You will need to config the `$AITER_BRANCH_OR_COMMIT` for your purpose. - ::: + !!! note + You will need to config the `$AITER_BRANCH_OR_COMMIT` for your purpose. 4. Build vLLM. 
For example, vLLM on ROCM 6.3 can be built with the following steps: ```bash - $ pip install --upgrade pip + pip install --upgrade pip # Build & install AMD SMI - $ pip install /opt/rocm/share/amd_smi + pip install /opt/rocm/share/amd_smi # Install dependencies - $ pip install --upgrade numba scipy huggingface-hub[cli,hf_transfer] setuptools_scm - $ pip install "numpy<2" - $ pip install -r requirements/rocm.txt + pip install --upgrade numba \ + scipy \ + huggingface-hub[cli,hf_transfer] \ + setuptools_scm + pip install "numpy<2" + pip install -r requirements/rocm.txt # Build vLLM for MI210/MI250/MI300. - $ export PYTORCH_ROCM_ARCH="gfx90a;gfx942" - $ python3 setup.py develop + export PYTORCH_ROCM_ARCH="gfx90a;gfx942" + python3 setup.py develop ``` This may take 5-10 minutes. Currently, `pip install .` does not work for ROCm installation. - :::{tip} - - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers. - - Triton flash attention does not currently support sliding window attention. If using half precision, please use CK flash-attention for sliding window support. - - To use CK flash-attention or PyTorch naive attention, please use this flag `export VLLM_USE_TRITON_FLASH_ATTN=0` to turn off triton flash attention. - - The ROCm version of PyTorch, ideally, should match the ROCm driver version. - ::: + !!! tip + - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers. + - Triton flash attention does not currently support sliding window attention. If using half precision, please use CK flash-attention for sliding window support. + - To use CK flash-attention or PyTorch naive attention, please use this flag `export VLLM_USE_TRITON_FLASH_ATTN=0` to turn off triton flash attention. + - The ROCm version of PyTorch, ideally, should match the ROCm driver version. 
-:::{tip} -- For MI300x (gfx942) users, to achieve optimal performance, please refer to [MI300x tuning guide](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/index.html) for performance optimization and tuning tips on system and workflow level. - For vLLM, please refer to [vLLM performance optimization](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html#vllm-performance-optimization). -::: +!!! tip + - For MI300x (gfx942) users, to achieve optimal performance, please refer to [MI300x tuning guide](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/index.html) for performance optimization and tuning tips on system and workflow level. + For vLLM, please refer to [vLLM performance optimization](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html#vllm-performance-optimization). ## Set up using Docker (Recommended) -### Pre-built images +# --8<-- [end:set-up-using-docker] +# --8<-- [start:pre-built-images] The [AMD Infinity hub for vLLM](https://hub.docker.com/r/rocm/vllm/tags) offers a prebuilt, optimized docker image designed for validating inference performance on the AMD Instinct™ MI300X accelerator. -:::{tip} -Please check [LLM inference performance validation on AMD Instinct MI300X](https://rocm.docs.amd.com/en/latest/how-to/performance-validation/mi300x/vllm-benchmark.html) -for instructions on how to use this prebuilt docker image. -::: +!!! tip + Please check [LLM inference performance validation on AMD Instinct MI300X](https://rocm.docs.amd.com/en/latest/how-to/performance-validation/mi300x/vllm-benchmark.html) + for instructions on how to use this prebuilt docker image. -### Build image from source +# --8<-- [end:pre-built-images] +# --8<-- [start:build-image-from-source] Building the Docker image from source is the recommended way to use vLLM with ROCm. @@ -155,7 +157,9 @@ It is important that the user kicks off the docker build using buildkit.
Either To build vllm on ROCm 6.3 for MI200 and MI300 series, you can use the default: ```console -DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile.rocm_base -t rocm/vllm-dev:base . +DOCKER_BUILDKIT=1 docker build \ + -f docker/Dockerfile.rocm_base \ + -t rocm/vllm-dev:base . ``` #### Build an image with vLLM @@ -190,7 +194,11 @@ DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile.rocm -t vllm-rocm . To build vllm on ROCm 6.3 for Radeon RX7900 series (gfx1100), you should pick the alternative base image: ```console -DOCKER_BUILDKIT=1 docker build --build-arg BASE_IMAGE="rocm/vllm-dev:navi_base" -f docker/Dockerfile.rocm -t vllm-rocm . +DOCKER_BUILDKIT=1 docker build \ + --build-arg BASE_IMAGE="rocm/vllm-dev:navi_base" \ + -f docker/Dockerfile.rocm \ + -t vllm-rocm \ + . ``` To run the above docker image `vllm-rocm`, use the below command: @@ -213,4 +221,5 @@ Where the `` is the location where the model is stored, for examp ## Supported features -See compatibility matrix for feature support information. +See [feature-x-hardware][feature-x-hardware] compatibility matrix for feature support information. +# --8<-- [end:extra-information] diff --git a/docs/source/getting_started/installation/gpu/xpu.inc.md b/docs/getting_started/installation/gpu/xpu.inc.md similarity index 67% rename from docs/source/getting_started/installation/gpu/xpu.inc.md rename to docs/getting_started/installation/gpu/xpu.inc.md index 4ab41a21c2a1..bee9a7ebb717 100644 --- a/docs/source/getting_started/installation/gpu/xpu.inc.md +++ b/docs/getting_started/installation/gpu/xpu.inc.md @@ -1,23 +1,26 @@ -# Installation +# --8<-- [start:installation] vLLM initially supports basic model inference and serving on Intel GPU platform. -:::{attention} -There are no pre-built wheels or images for this device, so you must build vLLM from source. -::: +!!! warning + There are no pre-built wheels or images for this device, so you must build vLLM from source. 
-## Requirements +# --8<-- [end:installation] +# --8<-- [start:requirements] - Supported Hardware: Intel Data Center GPU, Intel ARC GPU - OneAPI requirements: oneAPI 2025.0 -## Set up using Python +# --8<-- [end:requirements] +# --8<-- [start:set-up-using-python] -### Pre-built wheels +# --8<-- [end:set-up-using-python] +# --8<-- [start:pre-built-wheels] Currently, there are no pre-built XPU wheels. -### Build wheel from source +# --8<-- [end:pre-built-wheels] +# --8<-- [start:build-wheel-from-source] - First, install required driver and Intel OneAPI 2025.0 or later. - Second, install Python packages for vLLM XPU backend building: @@ -35,18 +38,20 @@ pip install -v -r requirements/xpu.txt VLLM_TARGET_DEVICE=xpu python setup.py install ``` -:::{note} -- FP16 is the default data type in the current XPU backend. The BF16 data - type is supported on Intel Data Center GPU, not supported on Intel Arc GPU yet. -::: +!!! note + - FP16 is the default data type in the current XPU backend. The BF16 data + type is supported on Intel Data Center GPU, not supported on Intel Arc GPU yet. -## Set up using Docker +# --8<-- [end:build-wheel-from-source] +# --8<-- [start:set-up-using-docker] -### Pre-built images +# --8<-- [end:set-up-using-docker] +# --8<-- [start:pre-built-images] Currently, there are no pre-built XPU images. -### Build image from source +# --8<-- [end:pre-built-images] +# --8<-- [start:build-image-from-source] ```console $ docker build -f docker/Dockerfile.xpu -t vllm-xpu-env --shm-size=4g . 
@@ -66,7 +71,6 @@ XPU platform supports **tensor parallel** inference/serving and also supports ** python -m vllm.entrypoints.openai.api_server \ --model=facebook/opt-13b \ --dtype=bfloat16 \ - --device=xpu \ --max_model_len=1024 \ --distributed-executor-backend=ray \ --pipeline-parallel-size=2 \ @@ -74,3 +78,4 @@ python -m vllm.entrypoints.openai.api_server \ ``` By default, a ray instance will be launched automatically if no existing one is detected in the system, with `num-gpus` equal to `parallel_config.world_size`. We recommend properly starting a ray cluster before execution, referring to the helper script. +# --8<-- [end:extra-information] diff --git a/docs/getting_started/installation/python_env_setup.inc.md b/docs/getting_started/installation/python_env_setup.inc.md new file mode 100644 index 000000000000..911301d68335 --- /dev/null +++ b/docs/getting_started/installation/python_env_setup.inc.md @@ -0,0 +1,6 @@ +It's recommended to use [uv](https://docs.astral.sh/uv/), a very fast Python environment manager, to create and manage Python environments. Please follow the [documentation](https://docs.astral.sh/uv/#getting-started) to install `uv`.
After installing `uv`, you can create and activate a new Python environment using the following commands: + +```console +uv venv --python 3.12 --seed +source .venv/bin/activate +``` diff --git a/docs/source/getting_started/quickstart.md b/docs/getting_started/quickstart.md similarity index 67% rename from docs/source/getting_started/quickstart.md rename to docs/getting_started/quickstart.md index 25189b006c26..d24e75e8141d 100644 --- a/docs/source/getting_started/quickstart.md +++ b/docs/getting_started/quickstart.md @@ -1,11 +1,12 @@ -(quickstart)= - -# Quickstart +--- +title: Quickstart +--- +[](){ #quickstart } This guide will help you quickly get started with vLLM to perform: -- [Offline batched inference](#quickstart-offline) -- [Online serving using OpenAI-compatible server](#quickstart-online) +- [Offline batched inference][quickstart-offline] +- [Online serving using OpenAI-compatible server][quickstart-online] ## Prerequisites @@ -19,50 +20,51 @@ If you are using NVIDIA GPUs, you can install vLLM using [pip](https://pypi.org/ It's recommended to use [uv](https://docs.astral.sh/uv/), a very fast Python environment manager, to create and manage Python environments. Please follow the [documentation](https://docs.astral.sh/uv/#getting-started) to install `uv`.
After installing `uv`, you can create a new Python environment and install vLLM using the following commands: ```console -uv venv myenv --python 3.12 --seed -source myenv/bin/activate -uv pip install vllm +uv venv --python 3.12 --seed +source .venv/bin/activate +uv pip install vllm --torch-backend=auto ``` -Another delightful way is to use `uv run` with `--with [dependency]` option, which allows you to run commands such as `vllm serve` without creating an environment: +`uv` can [automatically select the appropriate PyTorch index at runtime](https://docs.astral.sh/uv/guides/integration/pytorch/#automatic-backend-selection) by inspecting the installed CUDA driver version via `--torch-backend=auto` (or `UV_TORCH_BACKEND=auto`). To select a specific backend (e.g., `cu126`), set `--torch-backend=cu126` (or `UV_TORCH_BACKEND=cu126`). + +Another delightful way is to use `uv run` with `--with [dependency]` option, which allows you to run commands such as `vllm serve` without creating any permanent environment: ```console uv run --with vllm vllm --help ``` -You can also use [conda](https://docs.conda.io/projects/conda/en/latest/user-guide/getting-started.html) to create and manage Python environments. +You can also use [conda](https://docs.conda.io/projects/conda/en/latest/user-guide/getting-started.html) to create and manage Python environments. You can install `uv` to the conda environment through `pip` if you want to manage it within the environment. ```console conda create -n myenv python=3.12 -y conda activate myenv -pip install vllm +pip install --upgrade uv +uv pip install vllm --torch-backend=auto ``` -:::{note} -For non-CUDA platforms, please refer [here](#installation-index) for specific instructions on how to install vLLM. -::: +!!! note + For more detail and non-CUDA platforms, please refer [here][installation-index] for specific instructions on how to install vLLM. 
-(quickstart-offline)= +[](){ #quickstart-offline } ## Offline Batched Inference With vLLM installed, you can start generating text for a list of input prompts (i.e. offline batch inferencing). See the example script: -The first line of this example imports the classes {class}`~vllm.LLM` and {class}`~vllm.SamplingParams`: +The first line of this example imports the classes [LLM][vllm.LLM] and [SamplingParams][vllm.SamplingParams]: -- {class}`~vllm.LLM` is the main class for running offline inference with vLLM engine. -- {class}`~vllm.SamplingParams` specifies the parameters for the sampling process. +- [LLM][vllm.LLM] is the main class for running offline inference with the vLLM engine. +- [SamplingParams][vllm.SamplingParams] specifies the parameters for the sampling process. ```python from vllm import LLM, SamplingParams ``` -The next section defines a list of input prompts and sampling parameters for text generation. The [sampling temperature](https://arxiv.org/html/2402.05201v1) is set to `0.8` and the [nucleus sampling probability](https://en.wikipedia.org/wiki/Top-p_sampling) is set to `0.95`. You can find more information about the sampling parameters [here](#sampling-params). -:::{important} -By default, vLLM will use sampling parameters recommended by model creator by applying the `generation_config.json` from the Hugging Face model repository if it exists. In most cases, this will provide you with the best results by default if {class}`~vllm.SamplingParams` is not specified. +The next section defines a list of input prompts and sampling parameters for text generation. The [sampling temperature](https://arxiv.org/html/2402.05201v1) is set to `0.8` and the [nucleus sampling probability](https://en.wikipedia.org/wiki/Top-p_sampling) is set to `0.95`. You can find more information about the sampling parameters [here][sampling-params]. +!!!
warning + By default, vLLM will use sampling parameters recommended by model creator by applying the `generation_config.json` from the Hugging Face model repository if it exists. In most cases, this will provide you with the best results by default if [SamplingParams][vllm.SamplingParams] is not specified. -However, if vLLM's default sampling parameters are preferred, please set `generation_config="vllm"` when creating the {class}`~vllm.LLM` instance. -::: + However, if vLLM's default sampling parameters are preferred, please set `generation_config="vllm"` when creating the [LLM][vllm.LLM] instance. ```python prompts = [ @@ -74,15 +76,18 @@ prompts = [ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) ``` -The {class}`~vllm.LLM` class initializes vLLM's engine and the [OPT-125M model](https://arxiv.org/abs/2205.01068) for offline inference. The list of supported models can be found [here](#supported-models). +The [LLM][vllm.LLM] class initializes vLLM's engine and the [OPT-125M model](https://arxiv.org/abs/2205.01068) for offline inference. The list of supported models can be found [here][supported-models]. ```python llm = LLM(model="facebook/opt-125m") ``` -:::{note} -By default, vLLM downloads models from [Hugging Face](https://huggingface.co/). If you would like to use models from [ModelScope](https://www.modelscope.cn), set the environment variable `VLLM_USE_MODELSCOPE` before initializing the engine. -::: +!!! note + By default, vLLM downloads models from [Hugging Face](https://huggingface.co/). If you would like to use models from [ModelScope](https://www.modelscope.cn), set the environment variable `VLLM_USE_MODELSCOPE` before initializing the engine. + + ```shell + export VLLM_USE_MODELSCOPE=True + ``` Now, the fun part! The outputs are generated using `llm.generate`. It adds the input prompts to the vLLM engine's waiting queue and executes the vLLM engine to generate the outputs with high throughput. 
The outputs are returned as a list of `RequestOutput` objects, which include all of the output tokens. @@ -95,7 +100,7 @@ for output in outputs: print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` -(quickstart-online)= +[](){ #quickstart-online } ## OpenAI-Compatible Server @@ -108,15 +113,13 @@ Run the following command to start the vLLM server with the [Qwen2.5-1.5B-Instru vllm serve Qwen/Qwen2.5-1.5B-Instruct ``` -:::{note} -By default, the server uses a predefined chat template stored in the tokenizer. -You can learn about overriding it [here](#chat-template). -::: -:::{important} -By default, the server applies `generation_config.json` from the huggingface model repository if it exists. This means the default values of certain sampling parameters can be overridden by those recommended by the model creator. +!!! note + By default, the server uses a predefined chat template stored in the tokenizer. + You can learn about overriding it [here][chat-template]. +!!! warning + By default, the server applies `generation_config.json` from the huggingface model repository if it exists. This means the default values of certain sampling parameters can be overridden by those recommended by the model creator. -To disable this behavior, please pass `--generation-config vllm` when launching the server. -::: + To disable this behavior, please pass `--generation-config vllm` when launching the server. This server can be queried in the same format as OpenAI API. For example, to list the models: @@ -207,6 +210,5 @@ Currently, vLLM supports multiple backends for efficient Attention computation a If desired, you can also manually set the backend of your choice by configuring the environment variable `VLLM_ATTENTION_BACKEND` to one of the following options: `FLASH_ATTN`, `FLASHINFER` or `XFORMERS`. -```{attention} -There are no pre-built vllm wheels containing Flash Infer, so you must install it in your environment first. 
Refer to the [Flash Infer official docs](https://docs.flashinfer.ai/) or see for instructions on how to install it. -``` +!!! warning + There are no pre-built vllm wheels containing Flash Infer, so you must install it in your environment first. Refer to the [Flash Infer official docs](https://docs.flashinfer.ai/) or see for instructions on how to install it. diff --git a/docs/make.bat b/docs/make.bat deleted file mode 100644 index 747ffb7b3033..000000000000 --- a/docs/make.bat +++ /dev/null @@ -1,35 +0,0 @@ -@ECHO OFF - -pushd %~dp0 - -REM Command file for Sphinx documentation - -if "%SPHINXBUILD%" == "" ( - set SPHINXBUILD=sphinx-build -) -set SOURCEDIR=source -set BUILDDIR=build - -%SPHINXBUILD% >NUL 2>NUL -if errorlevel 9009 ( - echo. - echo.The 'sphinx-build' command was not found. Make sure you have Sphinx - echo.installed, then set the SPHINXBUILD environment variable to point - echo.to the full path of the 'sphinx-build' executable. Alternatively you - echo.may add the Sphinx directory to PATH. - echo. - echo.If you don't have Sphinx installed, grab it from - echo.https://www.sphinx-doc.org/ - exit /b 1 -) - -if "%1" == "" goto help - -%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% -goto end - -:help -%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% - -:end -popd diff --git a/docs/mkdocs/hooks/generate_examples.py b/docs/mkdocs/hooks/generate_examples.py new file mode 100644 index 000000000000..6f290efe45c2 --- /dev/null +++ b/docs/mkdocs/hooks/generate_examples.py @@ -0,0 +1,162 @@ +# SPDX-License-Identifier: Apache-2.0 +import itertools +from dataclasses import dataclass, field +from pathlib import Path +from typing import Literal + +import regex as re + +ROOT_DIR = Path(__file__).parent.parent.parent.parent +ROOT_DIR_RELATIVE = '../../../../..' 
+EXAMPLE_DIR = ROOT_DIR / "examples" +EXAMPLE_DOC_DIR = ROOT_DIR / "docs/examples" +print(ROOT_DIR.resolve()) +print(EXAMPLE_DIR.resolve()) +print(EXAMPLE_DOC_DIR.resolve()) + + +def fix_case(text: str) -> str: + subs = { + "api": "API", + "cli": "CLI", + "cpu": "CPU", + "llm": "LLM", + "mae": "MAE", + "tpu": "TPU", + "aqlm": "AQLM", + "gguf": "GGUF", + "lora": "LoRA", + "rlhf": "RLHF", + "vllm": "vLLM", + "openai": "OpenAI", + "lmcache": "LMCache", + "multilora": "MultiLoRA", + "mlpspeculator": "MLPSpeculator", + r"fp\d+": lambda x: x.group(0).upper(), # e.g. fp16, fp32 + r"int\d+": lambda x: x.group(0).upper(), # e.g. int8, int16 + } + for pattern, repl in subs.items(): + text = re.sub(rf'\b{pattern}\b', repl, text, flags=re.IGNORECASE) + return text + + +@dataclass +class Example: + """ + Example class for generating documentation content from a given path. + + Attributes: + path (Path): The path to the main directory or file. + category (str): The category of the document. + main_file (Path): The main file in the directory. + other_files (list[Path]): list of other files in the directory. + title (str): The title of the document. + + Methods: + __post_init__(): Initializes the main_file, other_files, and title attributes. + determine_main_file() -> Path: Determines the main file in the given path. + determine_other_files() -> list[Path]: Determines other files in the directory excluding the main file. + determine_title() -> str: Determines the title of the document. + generate() -> str: Generates the documentation content. + """ # noqa: E501 + path: Path + category: str = None + main_file: Path = field(init=False) + other_files: list[Path] = field(init=False) + title: str = field(init=False) + + def __post_init__(self): + self.main_file = self.determine_main_file() + self.other_files = self.determine_other_files() + self.title = self.determine_title() + + def determine_main_file(self) -> Path: + """ + Determines the main file in the given path. 
+ If the path is a file, it returns the path itself. Otherwise, it searches + for Markdown files (*.md) in the directory and returns the first one found. + Returns: + Path: The main file path, either the original path if it's a file or the first + Markdown file found in the directory. + Raises: + IndexError: If no Markdown files are found in the directory. + """ # noqa: E501 + return self.path if self.path.is_file() else list( + self.path.glob("*.md")).pop() + + def determine_other_files(self) -> list[Path]: + """ + Determine other files in the directory excluding the main file. + + This method checks if the given path is a file. If it is, it returns an empty list. + Otherwise, it recursively searches through the directory and returns a list of all + files that are not the main file. + + Returns: + list[Path]: A list of Path objects representing the other files in the directory. + """ # noqa: E501 + if self.path.is_file(): + return [] + is_other_file = lambda file: file.is_file() and file != self.main_file + return [file for file in self.path.rglob("*") if is_other_file(file)] + + def determine_title(self) -> str: + return fix_case(self.path.stem.replace("_", " ").title()) + + def generate(self) -> str: + content = f"---\ntitle: {self.title}\n---\n\n" + content += f"Source .\n\n" + + # Use long code fence to avoid issues with + # included files containing code fences too + code_fence = "``````" + is_code = self.main_file.suffix != ".md" + if is_code: + content += f"{code_fence}{self.main_file.suffix[1:]}\n" + content += f'--8<-- "{self.main_file}"\n' + if is_code: + content += f"{code_fence}\n" + content += "\n" + + if not self.other_files: + return content + + content += "## Example materials\n\n" + for file in sorted(self.other_files): + content += f'??? 
abstract "{file.relative_to(self.path)}"\n' + if file.suffix != ".md": + content += f" {code_fence}{file.suffix[1:]}\n" + content += f' --8<-- "{file}"\n' + if file.suffix != ".md": + content += f" {code_fence}\n" + + return content + + +def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool): + # Create the EXAMPLE_DOC_DIR if it doesn't exist + if not EXAMPLE_DOC_DIR.exists(): + EXAMPLE_DOC_DIR.mkdir(parents=True) + + categories = sorted(p for p in EXAMPLE_DIR.iterdir() if p.is_dir()) + + examples = [] + glob_patterns = ["*.py", "*.md", "*.sh"] + # Find categorised examples + for category in categories: + globs = [category.glob(pattern) for pattern in glob_patterns] + for path in itertools.chain(*globs): + examples.append(Example(path, category.stem)) + # Find examples in subdirectories + for path in category.glob("*/*.md"): + examples.append(Example(path.parent, category.stem)) + + # Generate the example documentation + for example in sorted(examples, key=lambda e: e.path.stem): + example_name = f"{example.path.stem}.md" + doc_path = EXAMPLE_DOC_DIR / example.category / example_name + print(doc_path) + if not doc_path.parent.exists(): + doc_path.parent.mkdir(parents=True) + with open(doc_path, "w+") as f: + f.write(example.generate()) diff --git a/docs/mkdocs/hooks/remove_announcement.py b/docs/mkdocs/hooks/remove_announcement.py new file mode 100644 index 000000000000..e5f8549d8383 --- /dev/null +++ b/docs/mkdocs/hooks/remove_announcement.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 +import os +from typing import Literal + + +def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool): + # see https://docs.readthedocs.io/en/stable/reference/environment-variables.html # noqa + if os.getenv('READTHEDOCS_VERSION_TYPE') == "tag": + # remove the warning banner if the version is a tagged release + docs_dir = os.path.dirname(__file__) + announcement_path = os.path.join(docs_dir, + "mkdocs/overrides/main.html") + # The 
file might be removed already if the build is triggered multiple + # times (readthedocs build both HTML and PDF versions separately) + if os.path.exists(announcement_path): + os.remove(announcement_path) diff --git a/docs/mkdocs/hooks/url_schemes.py b/docs/mkdocs/hooks/url_schemes.py new file mode 100644 index 000000000000..c738828085ba --- /dev/null +++ b/docs/mkdocs/hooks/url_schemes.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: Apache-2.0 +import regex as re +from mkdocs.config.defaults import MkDocsConfig +from mkdocs.structure.files import Files +from mkdocs.structure.pages import Page + + +def on_page_markdown(markdown: str, *, page: Page, config: MkDocsConfig, + files: Files): + gh_icon = ":octicons-mark-github-16:" + gh_url = "https://github.com" + repo_url = f"{gh_url}/vllm-project/vllm" + org_url = f"{gh_url}/orgs/vllm-project" + urls = { + "issue": f"{repo_url}/issues", + "pr": f"{repo_url}/pull", + "project": f"{org_url}/projects", + "dir": f"{repo_url}/tree/main", + "file": f"{repo_url}/blob/main", + } + titles = { + "issue": "Issue #", + "pr": "Pull Request #", + "project": "Project #", + "dir": "", + "file": "", + } + + scheme = r"gh-(?P<type>.+?):(?P<path>.+?)(#(?P<fragment>.+?))?" 
+ inline_link = re.compile(r"\[(?P<title>[^\[]+?)\]\(" + scheme + r"\)") + auto_link = re.compile(f"<{scheme}>") + + def replace_inline_link(match: re.Match) -> str: + url = f'{urls[match.group("type")]}/{match.group("path")}' + if fragment := match.group("fragment"): + url += f"#{fragment}" + + return f'[{gh_icon} {match.group("title")}]({url})' + + def replace_auto_link(match: re.Match) -> str: + type = match.group("type") + path = match.group("path") + title = f"{titles[type]}{path}" + url = f"{urls[type]}/{path}" + if fragment := match.group("fragment"): + url += f"#{fragment}" + + return f"[{gh_icon} {title}]({url})" + + markdown = inline_link.sub(replace_inline_link, markdown) + markdown = auto_link.sub(replace_auto_link, markdown) + + return markdown diff --git a/docs/source/_static/custom.js b/docs/mkdocs/javascript/run_llm_widget.js similarity index 54% rename from docs/source/_static/custom.js rename to docs/mkdocs/javascript/run_llm_widget.js index 58bc2ebb9614..d0e5560e92b4 100644 --- a/docs/source/_static/custom.js +++ b/docs/mkdocs/javascript/run_llm_widget.js @@ -17,22 +17,3 @@ document.addEventListener("DOMContentLoaded", function () { script.async = true; document.head.appendChild(script); }); - -// Update URL search params when tab is clicked - document.addEventListener("DOMContentLoaded", function () { - const tabs = document.querySelectorAll(".sd-tab-label"); - - function updateURL(tab) { - const syncGroup = tab.getAttribute("data-sync-group"); - const syncId = tab.getAttribute("data-sync-id"); - if (syncGroup && syncId) { - const url = new URL(window.location); - url.searchParams.set(syncGroup, syncId); - window.history.replaceState(null, "", url); - } - } - - tabs.forEach(tab => { - tab.addEventListener("click", () => updateURL(tab)); - }); -}); diff --git a/docs/mkdocs/overrides/main.html b/docs/mkdocs/overrides/main.html new file mode 100644 index 000000000000..bdd62ebc158d --- /dev/null +++ b/docs/mkdocs/overrides/main.html @@ -0,0 +1,5 @@ +{% 
extends "base.html" %} + +{% block announce %} + <p>You are viewing the latest developer preview docs. <a href="https://docs.vllm.ai/en/stable/">Click here</a> to view docs for the latest stable release.</p> +{% endblock %} diff --git a/docs/mkdocs/stylesheets/extra.css b/docs/mkdocs/stylesheets/extra.css new file mode 100644 index 000000000000..088143ed5956 --- /dev/null +++ b/docs/mkdocs/stylesheets/extra.css @@ -0,0 +1,36 @@ +/* Warning for latest docs */ +.md-banner { + background-color: var(--md-warning-bg-color); + color: var(--md-warning-fg-color); +} + +/* https://christianoliff.com/blog/styling-external-links-with-an-icon-in-css/ */ +a:not(:has(svg)):not(.md-icon):not(.autorefs-external) { + align-items: center; + + &[href^="//"]::after, + &[href^="http://"]::after, + &[href^="https://"]::after { + content: ""; + width: 12px; + height: 12px; + margin-left: 4px; + background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='16' height='16' stroke='gray' viewBox='0 0 16 16'%3E%3Cpath fill-rule='evenodd' d='M8.636 3.5a.5.5 0 0 0-.5-.5H1.5A1.5 1.5 0 0 0 0 4.5v10A1.5 1.5 0 0 0 1.5 16h10a1.5 1.5 0 0 0 1.5-1.5V7.864a.5.5 0 0 0-1 0V14.5a.5.5 0 0 1-.5.5h-10a.5.5 0 0 1-.5-.5v-10a.5.5 0 0 1 .5-.5h6.636a.5.5 0 0 0 .5-.5z'/%3E%3Cpath fill-rule='evenodd' d='M16 .5a.5.5 0 0 0-.5-.5h-5a.5.5 0 0 0 0 1h3.793L6.146 9.146a.5.5 0 1 0 .708.708L15 1.707V5.5a.5.5 0 0 0 1 0v-5z'/%3E%3C/svg%3E"); + background-position: center; + background-repeat: no-repeat; + background-size: contain; + display: inline-block; + } +} + +/* Light mode: darker section titles */ +body[data-md-color-scheme="default"] .md-nav__item--section > label.md-nav__link .md-ellipsis { + color: rgba(0, 0, 0, 0.7) !important; + font-weight: 700; +} + +/* Dark mode: lighter gray section titles */ +body[data-md-color-scheme="slate"] .md-nav__item--section > label.md-nav__link .md-ellipsis { + color: rgba(255, 255, 255, 0.75) !important; + font-weight: 700; +} diff --git 
a/docs/source/models/extensions/fastsafetensor.md b/docs/models/extensions/fastsafetensor.md similarity index 100% rename from docs/source/models/extensions/fastsafetensor.md rename to docs/models/extensions/fastsafetensor.md diff --git a/docs/source/models/extensions/runai_model_streamer.md b/docs/models/extensions/runai_model_streamer.md similarity index 61% rename from docs/source/models/extensions/runai_model_streamer.md rename to docs/models/extensions/runai_model_streamer.md index e0daa6f86dde..6755b574ea67 100644 --- a/docs/source/models/extensions/runai_model_streamer.md +++ b/docs/models/extensions/runai_model_streamer.md @@ -1,6 +1,7 @@ -(runai-model-streamer)= - -# Loading models with Run:ai Model Streamer +--- +title: Loading models with Run:ai Model Streamer +--- +[](){ #runai-model-streamer } Run:ai Model Streamer is a library to read tensors in concurrency, while streaming it to GPU memory. Further reading can be found in [Run:ai Model Streamer Documentation](https://github.com/run-ai/runai-model-streamer/blob/master/docs/README.md). 
@@ -15,19 +16,25 @@ pip3 install vllm[runai] To run it as an OpenAI-compatible server, add the `--load-format runai_streamer` flag: ```console -vllm serve /home/meta-llama/Llama-3.2-3B-Instruct --load-format runai_streamer +vllm serve /home/meta-llama/Llama-3.2-3B-Instruct \ + --load-format runai_streamer ``` To run model from AWS S3 object store run: ```console -vllm serve s3://core-llm/Llama-3-8b --load-format runai_streamer +vllm serve s3://core-llm/Llama-3-8b \ + --load-format runai_streamer ``` To run model from a S3 compatible object store run: ```console -RUNAI_STREAMER_S3_USE_VIRTUAL_ADDRESSING=0 AWS_EC2_METADATA_DISABLED=true AWS_ENDPOINT_URL=https://storage.googleapis.com vllm serve s3://core-llm/Llama-3-8b --load-format runai_streamer +RUNAI_STREAMER_S3_USE_VIRTUAL_ADDRESSING=0 \ +AWS_EC2_METADATA_DISABLED=true \ +AWS_ENDPOINT_URL=https://storage.googleapis.com \ +vllm serve s3://core-llm/Llama-3-8b \ + --load-format runai_streamer ``` ## Tunable parameters @@ -38,19 +45,22 @@ You can tune `concurrency` that controls the level of concurrency and number of For reading from S3, it will be the number of client instances the host is opening to the S3 server. ```console -vllm serve /home/meta-llama/Llama-3.2-3B-Instruct --load-format runai_streamer --model-loader-extra-config '{"concurrency":16}' +vllm serve /home/meta-llama/Llama-3.2-3B-Instruct \ + --load-format runai_streamer \ + --model-loader-extra-config '{"concurrency":16}' ``` You can control the size of the CPU Memory buffer to which tensors are read from the file, and limit this size. You can read further about CPU buffer memory limiting [here](https://github.com/run-ai/runai-model-streamer/blob/master/docs/src/env-vars.md#runai_streamer_memory_limit). 
```console -vllm serve /home/meta-llama/Llama-3.2-3B-Instruct --load-format runai_streamer --model-loader-extra-config '{"memory_limit":5368709120}' +vllm serve /home/meta-llama/Llama-3.2-3B-Instruct \ + --load-format runai_streamer \ + --model-loader-extra-config '{"memory_limit":5368709120}' ``` -:::{note} -For further instructions about tunable parameters and additional parameters configurable through environment variables, read the [Environment Variables Documentation](https://github.com/run-ai/runai-model-streamer/blob/master/docs/src/env-vars.md). -::: +!!! note + For further instructions about tunable parameters and additional parameters configurable through environment variables, read the [Environment Variables Documentation](https://github.com/run-ai/runai-model-streamer/blob/master/docs/src/env-vars.md). ## Sharded Model Loading @@ -63,7 +73,9 @@ vllm serve /path/to/sharded/model --load-format runai_streamer_sharded The sharded loader expects model files to follow the same naming pattern as the regular sharded state loader: `model-rank-{rank}-part-{part}.safetensors`. You can customize this pattern using the `pattern` parameter in `--model-loader-extra-config`: ```console -vllm serve /path/to/sharded/model --load-format runai_streamer_sharded --model-loader-extra-config '{"pattern":"custom-model-rank-{rank}-part-{part}.safetensors"}' +vllm serve /path/to/sharded/model \ + --load-format runai_streamer_sharded \ + --model-loader-extra-config '{"pattern":"custom-model-rank-{rank}-part-{part}.safetensors"}' ``` To create sharded model files, you can use the script provided in <gh-file:examples/offline_inference/save_sharded_state.py>. This script demonstrates how to save a model in the sharded format that is compatible with the Run:ai Model Streamer sharded loader. 
@@ -71,9 +83,10 @@ To create sharded model files, you can use the script provided in <gh-file:examp The sharded loader supports all the same tunable parameters as the regular Run:ai Model Streamer, including `concurrency` and `memory_limit`. These can be configured in the same way: ```console -vllm serve /path/to/sharded/model --load-format runai_streamer_sharded --model-loader-extra-config '{"concurrency":16, "memory_limit":5368709120}' +vllm serve /path/to/sharded/model \ + --load-format runai_streamer_sharded \ + --model-loader-extra-config '{"concurrency":16, "memory_limit":5368709120}' ``` -:::{note} -The sharded loader is particularly efficient for tensor or pipeline parallel models where each worker only needs to read its own shard rather than the entire checkpoint. -::: +!!! note + The sharded loader is particularly efficient for tensor or pipeline parallel models where each worker only needs to read its own shard rather than the entire checkpoint. diff --git a/docs/source/models/extensions/tensorizer.md b/docs/models/extensions/tensorizer.md similarity index 69% rename from docs/source/models/extensions/tensorizer.md rename to docs/models/extensions/tensorizer.md index cd94c81e620a..b6feb405c6ca 100644 --- a/docs/source/models/extensions/tensorizer.md +++ b/docs/models/extensions/tensorizer.md @@ -1,6 +1,7 @@ -(tensorizer)= - -# Loading models with CoreWeave's Tensorizer +--- +title: Loading models with CoreWeave's Tensorizer +--- +[](){ #tensorizer } vLLM supports loading models with [CoreWeave's Tensorizer](https://docs.coreweave.com/coreweave-machine-learning-and-ai/inference/tensorizer). vLLM model tensors that have been serialized to disk, an HTTP/HTTPS endpoint, or S3 endpoint can be deserialized @@ -9,8 +10,7 @@ shorter Pod startup times and CPU memory usage. Tensor encryption is also suppor For more information on CoreWeave's Tensorizer, please refer to [CoreWeave's Tensorizer documentation](https://github.com/coreweave/tensorizer). 
For more information on serializing a vLLM model, as well a general usage guide to using Tensorizer with vLLM, see -the [vLLM example script](https://docs.vllm.ai/en/latest/getting_started/examples/tensorize_vllm_model.html). +the [vLLM example script](https://docs.vllm.ai/en/latest/examples/tensorize_vllm_model.html). -:::{note} -Note that to use this feature you will need to install `tensorizer` by running `pip install vllm[tensorizer]`. -::: +!!! note + Note that to use this feature you will need to install `tensorizer` by running `pip install vllm[tensorizer]`. diff --git a/docs/source/models/generative_models.md b/docs/models/generative_models.md similarity index 63% rename from docs/source/models/generative_models.md rename to docs/models/generative_models.md index dd765e4a9765..566b1c29fca9 100644 --- a/docs/source/models/generative_models.md +++ b/docs/models/generative_models.md @@ -1,24 +1,25 @@ -(generative-models)= - -# Generative Models +--- +title: Generative Models +--- +[](){ #generative-models } vLLM provides first-class support for generative models, which covers most of LLMs. -In vLLM, generative models implement the {class}`~vllm.model_executor.models.VllmModelForTextGeneration` interface. +In vLLM, generative models implement the [VllmModelForTextGeneration][vllm.model_executor.models.VllmModelForTextGeneration] interface. Based on the final hidden states of the input, these models output log probabilities of the tokens to generate, -which are then passed through {class}`~vllm.model_executor.layers.Sampler` to obtain the final text. +which are then passed through [Sampler][vllm.model_executor.layers.Sampler] to obtain the final text. For generative models, the only supported `--task` option is `"generate"`. Usually, this is automatically inferred so you don't have to specify it. ## Offline Inference -The {class}`~vllm.LLM` class provides various methods for offline inference. 
-See <project:#configuration> for a list of options when initializing the model. +The [LLM][vllm.LLM] class provides various methods for offline inference. +See [configuration][configuration] for a list of options when initializing the model. ### `LLM.generate` -The {class}`~vllm.LLM.generate` method is available to all generative models in vLLM. +The [generate][vllm.LLM.generate] method is available to all generative models in vLLM. It is similar to [its counterpart in HF Transformers](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate), except that tokenization and detokenization are also performed automatically. @@ -34,7 +35,7 @@ for output in outputs: print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` -You can optionally control the language generation by passing {class}`~vllm.SamplingParams`. +You can optionally control the language generation by passing [SamplingParams][vllm.SamplingParams]. For example, you can use greedy sampling by setting `temperature=0`: ```python @@ -50,16 +51,15 @@ for output in outputs: print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` -:::{important} -By default, vLLM will use sampling parameters recommended by model creator by applying the `generation_config.json` from the huggingface model repository if it exists. In most cases, this will provide you with the best results by default if {class}`~vllm.SamplingParams` is not specified. +!!! warning + By default, vLLM will use sampling parameters recommended by model creator by applying the `generation_config.json` from the huggingface model repository if it exists. In most cases, this will provide you with the best results by default if [SamplingParams][vllm.SamplingParams] is not specified. -However, if vLLM's default sampling parameters are preferred, please pass `generation_config="vllm"` when creating the {class}`~vllm.LLM` instance. 
-::: + However, if vLLM's default sampling parameters are preferred, please pass `generation_config="vllm"` when creating the [LLM][vllm.LLM] instance. A code example can be found here: <gh-file:examples/offline_inference/basic/basic.py> ### `LLM.beam_search` -The {class}`~vllm.LLM.beam_search` method implements [beam search](https://huggingface.co/docs/transformers/en/generation_strategies#beam-search) on top of {class}`~vllm.LLM.generate`. +The [beam_search][vllm.LLM.beam_search] method implements [beam search](https://huggingface.co/docs/transformers/en/generation_strategies#beam-search) on top of [generate][vllm.LLM.generate]. For example, to search using 5 beams and output at most 50 tokens: ```python @@ -77,14 +77,13 @@ for output in outputs: ### `LLM.chat` -The {class}`~vllm.LLM.chat` method implements chat functionality on top of {class}`~vllm.LLM.generate`. +The [chat][vllm.LLM.chat] method implements chat functionality on top of [generate][vllm.LLM.generate]. In particular, it accepts input similar to [OpenAI Chat Completions API](https://platform.openai.com/docs/api-reference/chat) and automatically applies the model's [chat template](https://huggingface.co/docs/transformers/en/chat_templating) to format the prompt. -:::{important} -In general, only instruction-tuned models have a chat template. -Base models may perform poorly as they are not trained to respond to the chat conversation. -::: +!!! warning + In general, only instruction-tuned models have a chat template. + Base models may perform poorly as they are not trained to respond to the chat conversation. 
```python from vllm import LLM @@ -133,7 +132,7 @@ outputs = llm.chat(conversation, chat_template=custom_template) ## Online Serving -Our [OpenAI-Compatible Server](#openai-compatible-server) provides endpoints that correspond to the offline APIs: +Our [OpenAI-Compatible Server][openai-compatible-server] provides endpoints that correspond to the offline APIs: -- [Completions API](#completions-api) is similar to `LLM.generate` but only accepts text. -- [Chat API](#chat-api) is similar to `LLM.chat`, accepting both text and [multi-modal inputs](#multimodal-inputs) for models with a chat template. +- [Completions API][completions-api] is similar to `LLM.generate` but only accepts text. +- [Chat API][chat-api] is similar to `LLM.chat`, accepting both text and [multi-modal inputs][multimodal-inputs] for models with a chat template. diff --git a/docs/source/models/pooling_models.md b/docs/models/pooling_models.md similarity index 62% rename from docs/source/models/pooling_models.md rename to docs/models/pooling_models.md index 8c8d1832d382..89a128915a76 100644 --- a/docs/source/models/pooling_models.md +++ b/docs/models/pooling_models.md @@ -1,70 +1,48 @@ -(pooling-models)= - -# Pooling Models +--- +title: Pooling Models +--- +[](){ #pooling-models } vLLM also supports pooling models, including embedding, reranking and reward models. -In vLLM, pooling models implement the {class}`~vllm.model_executor.models.VllmModelForPooling` interface. -These models use a {class}`~vllm.model_executor.layers.Pooler` to extract the final hidden states of the input +In vLLM, pooling models implement the [VllmModelForPooling][vllm.model_executor.models.VllmModelForPooling] interface. +These models use a [Pooler][vllm.model_executor.layers.Pooler] to extract the final hidden states of the input before returning them. -:::{note} -We currently support pooling models primarily as a matter of convenience. 
-As shown in the [Compatibility Matrix](#compatibility-matrix), most vLLM features are not applicable to -pooling models as they only work on the generation or decode stage, so performance may not improve as much. -::: +!!! note + We currently support pooling models primarily as a matter of convenience. + As shown in the [Compatibility Matrix][compatibility-matrix], most vLLM features are not applicable to + pooling models as they only work on the generation or decode stage, so performance may not improve as much. For pooling models, we support the following `--task` options. The selected option sets the default pooler used to extract the final hidden states: -:::{list-table} -:widths: 50 25 25 25 -:header-rows: 1 - -- * Task - * Pooling Type - * Normalization - * Softmax -- * Embedding (`embed`) - * `LAST` - * ✅︎ - * ❌ -- * Classification (`classify`) - * `LAST` - * ❌ - * ✅︎ -- * Sentence Pair Scoring (`score`) - * \* - * \* - * \* -- * Reward Modeling (`reward`) - * `ALL` - * ❌ - * ❌ -::: +| Task | Pooling Type | Normalization | Softmax | +|---------------------------------|----------------|-----------------|-----------| +| Embedding (`embed`) | `LAST` | ✅︎ | ❌ | +| Classification (`classify`) | `LAST` | ❌ | ✅︎ | +| Sentence Pair Scoring (`score`) | \* | \* | \* | \*The default pooler is always defined by the model. -:::{note} -If the model's implementation in vLLM defines its own pooler, the default pooler is set to that instead of the one specified in this table. -::: +!!! note + If the model's implementation in vLLM defines its own pooler, the default pooler is set to that instead of the one specified in this table. When loading [Sentence Transformers](https://huggingface.co/sentence-transformers) models, we attempt to override the default pooler based on its Sentence Transformers configuration file (`modules.json`). 
-:::{tip} -You can customize the model's pooling method via the `--override-pooler-config` option, -which takes priority over both the model's and Sentence Transformers's defaults. -::: +!!! tip + You can customize the model's pooling method via the `--override-pooler-config` option, + which takes priority over both the model's and Sentence Transformers's defaults. ## Offline Inference -The {class}`~vllm.LLM` class provides various methods for offline inference. -See <project:#configuration> for a list of options when initializing the model. +The [LLM][vllm.LLM] class provides various methods for offline inference. +See [configuration][configuration] for a list of options when initializing the model. ### `LLM.encode` -The {class}`~vllm.LLM.encode` method is available to all pooling models in vLLM. +The [encode][vllm.LLM.encode] method is available to all pooling models in vLLM. It returns the extracted hidden states directly, which is useful for reward models. ```python @@ -79,7 +57,7 @@ print(f"Data: {data!r}") ### `LLM.embed` -The {class}`~vllm.LLM.embed` method outputs an embedding vector for each prompt. +The [embed][vllm.LLM.embed] method outputs an embedding vector for each prompt. It is primarily designed for embedding models. ```python @@ -96,7 +74,7 @@ A code example can be found here: <gh-file:examples/offline_inference/basic/embe ### `LLM.classify` -The {class}`~vllm.LLM.classify` method outputs a probability vector for each prompt. +The [classify][vllm.LLM.classify] method outputs a probability vector for each prompt. It is primarily designed for classification models. ```python @@ -113,13 +91,12 @@ A code example can be found here: <gh-file:examples/offline_inference/basic/clas ### `LLM.score` -The {class}`~vllm.LLM.score` method outputs similarity scores between sentence pairs. +The [score][vllm.LLM.score] method outputs similarity scores between sentence pairs. It is designed for embedding models and cross encoder models. 
Embedding models use cosine similarity, and [cross-encoder models](https://www.sbert.net/examples/applications/cross-encoder/README.html) serve as rerankers between candidate query-document pairs in RAG systems. -:::{note} -vLLM can only perform the model inference component (e.g. embedding, reranking) of RAG. -To handle RAG at a higher level, you should use integration frameworks such as [LangChain](https://github.com/langchain-ai/langchain). -::: +!!! note + vLLM can only perform the model inference component (e.g. embedding, reranking) of RAG. + To handle RAG at a higher level, you should use integration frameworks such as [LangChain](https://github.com/langchain-ai/langchain). ```python from vllm import LLM @@ -136,26 +113,25 @@ A code example can be found here: <gh-file:examples/offline_inference/basic/scor ## Online Serving -Our [OpenAI-Compatible Server](#openai-compatible-server) provides endpoints that correspond to the offline APIs: +Our [OpenAI-Compatible Server][openai-compatible-server] provides endpoints that correspond to the offline APIs: -- [Pooling API](#pooling-api) is similar to `LLM.encode`, being applicable to all types of pooling models. -- [Embeddings API](#embeddings-api) is similar to `LLM.embed`, accepting both text and [multi-modal inputs](#multimodal-inputs) for embedding models. -- [Score API](#score-api) is similar to `LLM.score` for cross-encoder models. +- [Pooling API][pooling-api] is similar to `LLM.encode`, being applicable to all types of pooling models. +- [Embeddings API][embeddings-api] is similar to `LLM.embed`, accepting both text and [multi-modal inputs][multimodal-inputs] for embedding models. +- [Classification API][classification-api] is similar to `LLM.classify` and is applicable to sequence classification models. +- [Score API][score-api] is similar to `LLM.score` for cross-encoder models. 
## Matryoshka Embeddings [Matryoshka Embeddings](https://sbert.net/examples/sentence_transformer/training/matryoshka/README.html#matryoshka-embeddings) or [Matryoshka Representation Learning (MRL)](https://arxiv.org/abs/2205.13147) is a technique used in training embedding models. It allows user to trade off between performance and cost. -:::{warning} -Not all embedding models are trained using Matryoshka Representation Learning. To avoid misuse of the `dimensions` parameter, vLLM returns an error for requests that attempt to change the output dimension of models that do not support Matryoshka Embeddings. - -For example, setting `dimensions` parameter while using the `BAAI/bge-m3` model will result in the following error. +!!! warning + Not all embedding models are trained using Matryoshka Representation Learning. To avoid misuse of the `dimensions` parameter, vLLM returns an error for requests that attempt to change the output dimension of models that do not support Matryoshka Embeddings. -```json -{"object":"error","message":"Model \"BAAI/bge-m3\" does not support matryoshka representation, changing output dimensions will lead to poor results.","type":"BadRequestError","param":null,"code":400} -``` + For example, setting `dimensions` parameter while using the `BAAI/bge-m3` model will result in the following error. -::: + ```json + {"object":"error","message":"Model \"BAAI/bge-m3\" does not support matryoshka representation, changing output dimensions will lead to poor results.","type":"BadRequestError","param":null,"code":400} + ``` ### Manually enable Matryoshka Embeddings @@ -171,7 +147,7 @@ vllm serve Snowflake/snowflake-arctic-embed-m-v1.5 --hf_overrides '{"matryoshka_ ### Offline Inference -You can change the output dimensions of embedding models that support Matryoshka Embeddings by using the dimensions parameter in {class}`~vllm.PoolingParams`. 
+You can change the output dimensions of embedding models that support Matryoshka Embeddings by using the dimensions parameter in [PoolingParams][vllm.PoolingParams]. ```python from vllm import LLM, PoolingParams diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md new file mode 100644 index 000000000000..7594c6e6fbf1 --- /dev/null +++ b/docs/models/supported_models.md @@ -0,0 +1,690 @@ +--- +title: Supported Models +--- +[](){ #supported-models } + +vLLM supports [generative](./generative_models.md) and [pooling](./pooling_models.md) models across various tasks. +If a model supports more than one task, you can set the task via the `--task` argument. + +For each task, we list the model architectures that have been implemented in vLLM. +Alongside each architecture, we include some popular models that use it. + +## Model Implementation + +### vLLM + +If vLLM natively supports a model, its implementation can be found in <gh-file:vllm/model_executor/models>. + +These models are what we list in [supported-text-models][supported-text-models] and [supported-mm-models][supported-mm-models]. + +[](){ #transformers-backend } + +### Transformers + +vLLM also supports model implementations that are available in Transformers. This does not currently work for all models, but most decoder language models are supported, and vision language model support is planned! + +To check if the modeling backend is Transformers, you can simply do this: + +```python +from vllm import LLM +llm = LLM(model=..., task="generate") # Name or path of your model +llm.apply_model(lambda model: print(type(model))) +``` + +If it is `TransformersForCausalLM` then it means it's based on Transformers! + +!!! tip + You can force the use of `TransformersForCausalLM` by setting `model_impl="transformers"` for [offline-inference][offline-inference] or `--model-impl transformers` for the [openai-compatible-server][openai-compatible-server]. + +!!! 
note + vLLM may not fully optimise the Transformers implementation so you may see degraded performance if comparing a native model to a Transformers model in vLLM. + +#### Custom models + +If a model is supported natively by neither vLLM nor Transformers, it can still be used in vLLM! + +For a model to be compatible with the Transformers backend for vLLM it must: + +- be a Transformers compatible custom model (see [Transformers - Customizing models](https://huggingface.co/docs/transformers/en/custom_models)): + * The model directory must have the correct structure (e.g. `config.json` is present). + * `config.json` must contain `auto_map.AutoModel`. +- be a Transformers backend for vLLM compatible model (see [writing-custom-models][writing-custom-models]): + * Customisation should be done in the base model (e.g. in `MyModel`, not `MyModelForCausalLM`). + +If the compatible model is: + +- on the Hugging Face Model Hub, simply set `trust_remote_code=True` for [offline-inference][offline-inference] or `--trust-remote-code` for the [openai-compatible-server][openai-compatible-server]. +- in a local directory, simply pass the directory path to `model=<MODEL_DIR>` for [offline-inference][offline-inference] or `vllm serve <MODEL_DIR>` for the [openai-compatible-server][openai-compatible-server]. + +This means that, with the Transformers backend for vLLM, new models can be used before they are officially supported in Transformers or vLLM! + +[](){ #writing-custom-models } + +#### Writing custom models + +This section details the necessary modifications to make to a Transformers compatible custom model that make it compatible with the Transformers backend for vLLM. (We assume that a Transformers compatible custom model has already been created, see [Transformers - Customizing models](https://huggingface.co/docs/transformers/en/custom_models)). + +To make your model compatible with the Transformers backend, it needs: + +1.
`kwargs` passed down through all modules from `MyModel` to `MyAttention`. +2. `MyAttention` must use `ALL_ATTENTION_FUNCTIONS` to call attention. +3. `MyModel` must contain `_supports_attention_backend = True`. + +```python title="modeling_my_model.py" + +from transformers import PreTrainedModel +from torch import nn + +class MyAttention(nn.Module): + + def forward(self, hidden_states, **kwargs): + ... + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + **kwargs, + ) + ... + +class MyModel(PreTrainedModel): + _supports_attention_backend = True +``` + +Here is what happens in the background when this model is loaded: + +1. The config is loaded. +2. `MyModel` Python class is loaded from the `auto_map` in config, and we check that the model `is_backend_compatible()`. +3. `MyModel` is loaded into `TransformersForCausalLM` (see <gh-file:vllm/model_executor/models/transformers.py>) which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used. + +That's it! 
+ +For your model to be compatible with vLLM's tensor parallel and/or pipeline parallel features, you must add `base_model_tp_plan` and/or `base_model_pp_plan` to your model's config class: + +```python title="configuration_my_model.py" + +from transformers import PretrainedConfig + +class MyConfig(PretrainedConfig): + base_model_tp_plan = { + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } +``` + +- `base_model_tp_plan` is a `dict` that maps fully qualified layer name patterns to tensor parallel styles (currently only `"colwise"` and `"rowwise"` are supported). +- `base_model_pp_plan` is a `dict` that maps direct child layer names to `tuple`s of `list`s of `str`s: + * You only need to do this for layers which are not present on all pipeline stages + * vLLM assumes that there will be only one `nn.ModuleList`, which is distributed across the pipeline stages + * The `list` in the first element of the `tuple` contains the names of the input arguments + * The `list` in the last element of the `tuple` contains the names of the variables the layer outputs to in your modeling code + +## Loading a Model + +### Hugging Face Hub + +By default, vLLM loads models from [Hugging Face (HF) Hub](https://huggingface.co/models). To change the download path for models, you can set the `HF_HOME` environment variable; for more details, refer to [their official documentation](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hfhome). + +To determine whether a given model is natively supported, you can check the `config.json` file inside the HF repository. 
+If the `"architectures"` field contains a model architecture listed below, then it should be natively supported. + +Models do not _need_ to be natively supported to be used in vLLM. +The [Transformers backend][transformers-backend] enables you to run models directly using their Transformers implementation (or even remote code on the Hugging Face Model Hub!). + +!!! tip + The easiest way to check if your model is really supported at runtime is to run the program below: + + ```python + from vllm import LLM + + # For generative models (task=generate) only + llm = LLM(model=..., task="generate") # Name or path of your model + output = llm.generate("Hello, my name is") + print(output) + + # For pooling models (task={embed,classify,reward,score}) only + llm = LLM(model=..., task="embed") # Name or path of your model + output = llm.encode("Hello, my name is") + print(output) + ``` + + If vLLM successfully returns text (for generative models) or hidden states (for pooling models), it indicates that your model is supported. + +Otherwise, please refer to [Adding a New Model][new-model] for instructions on how to implement your model in vLLM. +Alternatively, you can [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) to request vLLM support. 
+ +#### Download a model + +If you prefer, you can use the Hugging Face CLI to [download a model](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-download) or specific files from a model repository: + +```console +# Download a model +huggingface-cli download HuggingFaceH4/zephyr-7b-beta + +# Specify a custom cache directory +huggingface-cli download HuggingFaceH4/zephyr-7b-beta --cache-dir ./path/to/cache + +# Download a specific file from a model repo +huggingface-cli download HuggingFaceH4/zephyr-7b-beta eval_results.json +``` + +#### List the downloaded models + +Use the Hugging Face CLI to [manage models](https://huggingface.co/docs/huggingface_hub/guides/manage-cache#scan-your-cache) stored in local cache: + +```console +# List cached models +huggingface-cli scan-cache + +# Show detailed (verbose) output +huggingface-cli scan-cache -v + +# Specify a custom cache directory +huggingface-cli scan-cache --dir ~/.cache/huggingface/hub +``` + +#### Delete a cached model + +Use the Hugging Face CLI to interactively [delete downloaded model](https://huggingface.co/docs/huggingface_hub/guides/manage-cache#clean-your-cache) from the cache: + +```console +# The `delete-cache` command requires extra dependencies to work with the TUI. +# Please run `pip install huggingface_hub[cli]` to install them. + +# Launch the interactive TUI to select models to delete +$ huggingface-cli delete-cache +? Select revisions to delete: 1 revisions selected counting for 438.9M. + ○ None of the following (if selected, nothing will be deleted). +Model BAAI/bge-base-en-v1.5 (438.9M, used 1 week ago) +❯ ◉ a5beb1e3: main # modified 1 week ago + +Model BAAI/bge-large-en-v1.5 (1.3G, used 1 week ago) + ○ d4aa6901: main # modified 1 week ago + +Model BAAI/bge-reranker-base (1.1G, used 4 weeks ago) + ○ 2cfc18c9: main # modified 4 weeks ago + +Press <space> to select, <enter> to validate and <ctrl+c> to quit without modification. + +# Need to confirm after selected +?
Select revisions to delete: 1 revision(s) selected. +? 1 revisions selected counting for 438.9M. Confirm deletion ? Yes +Start deletion. +Done. Deleted 1 repo(s) and 0 revision(s) for a total of 438.9M. +``` + +#### Using a proxy + +Here are some tips for loading/downloading models from Hugging Face using a proxy: + +- Set the proxy globally for your session (or set it in the profile file): + +```shell +export http_proxy=http://your.proxy.server:port +export https_proxy=http://your.proxy.server:port +``` + +- Set the proxy for just the current command: + +```shell +https_proxy=http://your.proxy.server:port huggingface-cli download <model_name> + +# or use vllm cmd directly +https_proxy=http://your.proxy.server:port vllm serve <model_name> --disable-log-requests +``` + +- Set the proxy in Python interpreter: + +```python +import os + +os.environ['http_proxy'] = 'http://your.proxy.server:port' +os.environ['https_proxy'] = 'http://your.proxy.server:port' +``` + +### ModelScope + +To use models from [ModelScope](https://www.modelscope.cn) instead of Hugging Face Hub, set an environment variable: + +```shell +export VLLM_USE_MODELSCOPE=True +``` + +And use with `trust_remote_code=True`. + +```python +from vllm import LLM + +llm = LLM(model=..., revision=..., task=..., trust_remote_code=True) + +# For generative models (task=generate) only +output = llm.generate("Hello, my name is") +print(output) + +# For pooling models (task={embed,classify,reward,score}) only +output = llm.encode("Hello, my name is") +print(output) +``` + +[](){ #feature-status-legend } + +## Feature Status Legend + +- ✅︎ indicates that the feature is supported for the model. + +- 🚧 indicates that the feature is planned but not yet supported for the model. + +- ⚠️ indicates that the feature is available but may have known issues or limitations.
+ +[](){ #supported-text-models } + +## List of Text-only Language Models + +### Generative Models + +See [this page][generative-models] for more information on how to use generative models. + +#### Text Generation + +Specified using `--task generate`. + +| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | +|---------------------------------------------------|-----------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------|-----------------------------| +| `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | โœ…๏ธŽ | | +| `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | | | +| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | โœ…๏ธŽ | | +| `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | +| `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R | `CohereForAI/c4ai-command-r-v01`, `CohereForAI/c4ai-command-r7b-12-2024`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. | โœ…๏ธŽ | | +| `DeciLMForCausalLM` | DeciLM | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, etc. | โœ…๏ธŽ | | +| `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat` etc. 
| โœ…๏ธŽ | | +| `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat` etc. | โœ…๏ธŽ | | +| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3-Base`, `deepseek-ai/DeepSeek-V3` etc. | โœ…๏ธŽ | | +| `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | โœ…๏ธŽ | | +| `FalconMambaForCausalLM` | FalconMamba | `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `FalconH1ForCausalLM` | Falcon-H1 | `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `GlmForCausalLM` | GLM-4 | `THUDM/glm-4-9b-chat-hf`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `Glm4ForCausalLM` | GLM-4-0414 | `THUDM/GLM-4-32B-0414`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `GPT2LMHeadModel` | GPT-2 | `gpt2`, `gpt2-xl`, etc. | โœ…๏ธŽ | | +| `GPTBigCodeForCausalLM` | StarCoder, SantaCoder, WizardCoder | `bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, `WizardLM/WizardCoder-15B-V1.0`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `GPTJForCausalLM` | GPT-J | `EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc. | โœ…๏ธŽ | | +| `GPTNeoXForCausalLM` | GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM | `EleutherAI/gpt-neox-20b`, `EleutherAI/pythia-12b`, `OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc. | โœ…๏ธŽ | | +| `GraniteForCausalLM` | Granite 3.0, Granite 3.1, PowerLM | `ibm-granite/granite-3.0-2b-base`, `ibm-granite/granite-3.1-8b-instruct`, `ibm/PowerLM-3b`, etc. 
| โœ…๏ธŽ | โœ…๏ธŽ | +| `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | โœ…๏ธŽ | โœ…๏ธŽ | +| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | โœ…๏ธŽ | โœ…๏ธŽ | +| `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | โœ…๏ธŽ | โœ…๏ธŽ | +| `InternLMForCausalLM` | InternLM | `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | โœ…๏ธŽ | | +| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | โœ…๏ธŽ | | +| `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. 
| โœ…๏ธŽ | โœ…๏ธŽ | +| `MistralForCausalLM` | Mistral, Mistral-Instruct | `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | โœ…๏ธŽ | | +| `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | โœ…๏ธŽ | | +| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | โœ…๏ธŽ | | +| `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | โœ…๏ธŽ | | +| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | โœ…๏ธŽ | | +| `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `Phi3SmallForCausalLM` | Phi-3-Small | `microsoft/Phi-3-small-8k-instruct`, `microsoft/Phi-3-small-128k-instruct`, etc. | โœ…๏ธŽ | | +| `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | โœ…๏ธŽ | | +| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. 
| | | +| `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/QwQ-32B-Preview`, `Qwen/Qwen2-7B-Instruct`, `Qwen/Qwen2-7B`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | โœ…๏ธŽ | | +| `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | โœ…๏ธŽ | | +| `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | โœ…๏ธŽ | | +| `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | โœ…๏ธŽ | | +| `SolarForCausalLM` | Solar Pro | `upstage/solar-pro-preview-instruct`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `TeleChat2ForCausalLM` | TeleChat2 | `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `TeleFLMForCausalLM` | TeleFLM | `CofeAI/FLM-2-52B-Instruct-2407`, `CofeAI/Tele-FLM`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | โœ…๏ธŽ | | +| `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | + +!!! note + Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096. + +### Pooling Models + +See [this page](./pooling_models.md) for more information on how to use pooling models. + +!!! warning + Since some model architectures support both generative and pooling tasks, + you should explicitly specify the task type to ensure that the model is used in pooling mode instead of generative mode. + +#### Text Embedding + +Specified using `--task embed`. 
+ +| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | +|--------------------------------------------------------|---------------------|---------------------------------------------------------------------------------------------------------------------|------------------------|-----------------------------| +| `BertModel` | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | +| `Gemma2Model` | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | โœ…๏ธŽ | | +| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | โœ…๏ธŽ | โœ…๏ธŽ | +| `GteModel` | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | ๏ธŽ | | +| `GteNewModel` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | ๏ธŽ | ๏ธŽ | +| `ModernBertModel` | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | ๏ธŽ | ๏ธŽ | +| `NomicBertModel` | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | ๏ธŽ | ๏ธŽ | +| `LlamaModel`, `LlamaForCausalLM`, `MistralModel`, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `Qwen2Model`, `Qwen2ForCausalLM` | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | โœ…๏ธŽ | โœ…๏ธŽ | +| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | | + +!!! note + `ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config. + You should manually set mean pooling by passing `--override-pooler-config '{"pooling_type": "MEAN"}'`. + +!!! note + For `Alibaba-NLP/gte-Qwen2-*`, you need to enable `--trust-remote-code` for the correct tokenizer to be loaded. + See [relevant issue on HF Transformers](https://github.com/huggingface/transformers/issues/34882). + +!!! 
note + `jinaai/jina-embeddings-v3` supports multiple tasks through LoRA, while vLLM temporarily only supports text-matching tasks by merging LoRA weights. + +!!! note + The second-generation GTE model (mGTE-TRM) is named `NewModel`. The name `NewModel` is too generic, so you should set `--hf-overrides '{"architectures": ["GteNewModel"]}'` to specify the use of the `GteNewModel` architecture. + +If your model is not in the above list, we will try to automatically convert the model using +[as_embedding_model][vllm.model_executor.models.adapters.as_embedding_model]. By default, the embeddings +of the whole prompt are extracted from the normalized hidden state corresponding to the last token. + +#### Reward Modeling + +Specified using `--task reward`. + +| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | +|---------------------------|-----------------|------------------------------------------------------------------------|------------------------|-----------------------------| +| `InternLM2ForRewardModel` | InternLM2-based | `internlm/internlm2-1_8b-reward`, `internlm/internlm2-7b-reward`, etc. | ✅︎ | ✅︎ | +| `LlamaForCausalLM` | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ | +| `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B`, etc. | ✅︎ | ✅︎ | + +If your model is not in the above list, we will try to automatically convert the model using +[as_reward_model][vllm.model_executor.models.adapters.as_reward_model]. By default, we return the hidden states of each token directly. + +!!! warning + For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly, + e.g.: `--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`. + +#### Classification + +Specified using `--task classify`.
+ +| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | +|----------------------------------|----------|----------------------------------------|------------------------|-----------------------------| +| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | + +If your model is not in the above list, we will try to automatically convert the model using +[as_classification_model][vllm.model_executor.models.adapters.as_classification_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token. + +#### Sentence Pair Scoring + +Specified using `--task score`. + +| Architecture | Models | Example HF Models | +|---------------------------------------|-------------------|----------------------------------------------| +| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | +| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | +| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | + +[](){ #supported-mm-models } + +## List of Multimodal Language Models + +The following modalities are supported depending on the model: + +- **T**ext +- **I**mage +- **V**ideo +- **A**udio + +Any combination of modalities joined by `+` are supported. + +- e.g.: `T + I` means that the model supports text-only, image-only, and text-with-image inputs. + +On the other hand, modalities separated by `/` are mutually exclusive. + +- e.g.: `T / I` means that the model supports text-only and image-only inputs, but not text-with-image inputs. + +See [this page][multimodal-inputs] on how to pass multi-modal inputs to the model. + +!!! warning + **To enable multiple multi-modal items per text prompt in vLLM V0**, you have to set `limit_mm_per_prompt` (offline inference) + or `--limit-mm-per-prompt` (online serving).
For example, to enable passing up to 4 images per text prompt: + + Offline inference: + + ```python + from vllm import LLM + + llm = LLM( + model="Qwen/Qwen2-VL-7B-Instruct", + limit_mm_per_prompt={"image": 4}, + ) + ``` + + Online serving: + + ```bash + vllm serve Qwen/Qwen2-VL-7B-Instruct --limit-mm-per-prompt '{"image":4}' + ``` + + **This is no longer required if you are using vLLM V1.** + +!!! note + vLLM currently only supports adding LoRA to the language backbone of multimodal models. + +### Generative Models + +See [this page][generative-models] for more information on how to use generative models. + +#### Text Generation + +Specified using `--task generate`. + +| Architecture | Models | Inputs | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) | +|----------------------------------------------|--------------------------------------------------------------------------|-----------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------|-----------------------------|-----------------------| +| `AriaForConditionalGeneration` | Aria | T + I<sup>+</sup> | `rhymes-ai/Aria` | ✅︎ | ✅︎ | | +| `AyaVisionForConditionalGeneration` | Aya Vision | T + I<sup>+</sup> | `CohereForAI/aya-vision-8b`, `CohereForAI/aya-vision-32b`, etc. | ✅︎ | ✅︎ | | +| `Blip2ForConditionalGeneration` | BLIP-2 | T + I<sup>E</sup> | `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. | ✅︎ | ✅︎ | | +| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b` etc. | ✅︎ | ✅︎ | | +| `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2` etc. 
| ✅︎ | ✅︎ | | +| `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large` etc. | | | | +| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b` etc. | ✅︎ | ✅︎ | | +| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ | +| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `THUDM/glm-4v-9b`, `THUDM/cogagent-9b-20241220` etc. | ✅︎ | ✅︎ | ✅︎ | +| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ | +| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | ✅︎ | ✅︎\* | | +| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3` etc. | ✅︎ | ✅︎ | | +| `InternVLChatModel` | InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | | +| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | ✅︎ | | | +| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | ✅︎ | ✅︎ | | +| `LlavaForConditionalGeneration` | LLaVA-1.5 | T + I<sup>E+</sup> | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), etc. | ✅︎ | ✅︎ | | +| `LlavaNextForConditionalGeneration` | LLaVA-NeXT | T + I<sup>E+</sup> | `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc. 
| ✅︎ | ✅︎ | | +| `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | ✅︎ | ✅︎ | | +| `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I<sup>+</sup> + V<sup>+</sup> | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | ✅︎ | ✅︎ | | +| `MiniCPMO` | MiniCPM-O | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>E+</sup> | `openbmb/MiniCPM-o-2_6`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `MiniCPMV` | MiniCPM-V | T + I<sup>E+</sup> + V<sup>E+</sup> | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + I<sup>E+</sup> | `MiniMaxAI/MiniMax-VL-01`, etc. | ✅︎ | ✅︎ | | +| `Mistral3ForConditionalGeneration` | Mistral3 | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `MllamaForConditionalGeneration` | Llama 3.2 | T + I<sup>+</sup> | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | | +| `MolmoForCausalLM` | Molmo | T + I<sup>+</sup> | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `NVLM_D_Model` | NVLM-D 1.0 | T + I<sup>+</sup> | `nvidia/NVLM-D-72B`, etc. | ✅︎ | ✅︎ | | +| `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | ✅︎ | | | +| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + I<sup>E</sup> | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | ✅︎ | ⚠️ | | +| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. 
| ✅︎ | ✅︎ | | +| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | | +| `PixtralForConditionalGeneration` | Pixtral | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistral-community/pixtral-12b`, etc. | ✅︎ | ✅︎ | | +| `QwenVLForConditionalGeneration`<sup>^</sup> | Qwen-VL | T + I<sup>E+</sup> | `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A<sup>+</sup> | `Qwen/Qwen2-Audio-7B-Instruct` | ✅︎ | ✅︎ | | +| `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-7B` | ✅︎ | ✅︎\* | | +| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | ✅︎ | ✅︎ | | +| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | ✅︎ | | + +<sup>^</sup> You need to set the architecture name via `--hf-overrides` to match the one in vLLM. +    • For example, to use DeepSeek-VL2 series models: +      `--hf-overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` +<sup>E</sup> Pre-computed embeddings can be inputted for this modality. +<sup>+</sup> Multiple items can be inputted per text prompt for this modality. + +!!! warning + Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs. 
+ However, there are differences in how they handle text + image inputs: + + V0 correctly implements the model's attention pattern: + - Uses bidirectional attention between the image tokens corresponding to the same image + - Uses causal attention for other tokens + - Implemented via (naive) PyTorch SDPA with masking tensors + - Note: May use significant memory for long prompts with images + + V1 currently uses a simplified attention pattern: + - Uses causal attention for all tokens, including image tokens + - Generates reasonable outputs but does not match the original model's attention for text + image inputs, especially when `{"do_pan_and_scan": true}` is set + - Will be updated in the future to support the correct behavior + + This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends. + +!!! note + Currently, only `InternVLChatModel` models with a Qwen2.5 text backbone (`OpenGVLab/InternVL3-2B`, `OpenGVLab/InternVL2.5-1B`, etc.) support video inputs. + +!!! note + `h2oai/h2ovl-mississippi-2b` will be available in V1 once we support head size 80. + +!!! note + To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf-overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM. + +!!! warning + The output quality of `AllenAI/Molmo-7B-D-0924` (especially in object localization tasks) has deteriorated in recent updates. 
+ + For the best results, we recommend using the following dependency versions (tested on A10 and L40): + + ```text + # Core vLLM-compatible dependencies with Molmo accuracy setup (tested on L40) + torch==2.5.1 + torchvision==0.20.1 + transformers==4.48.1 + tokenizers==0.21.0 + tiktoken==0.7.0 + vllm==0.7.0 + + # Optional but recommended for improved performance and stability + triton==3.1.0 + xformers==0.0.28.post3 + uvloop==0.21.0 + protobuf==5.29.3 + openai==1.60.2 + opencv-python-headless==4.11.0.86 + pillow==10.4.0 + + # Installed FlashAttention (for float16 only) + flash-attn>=2.5.6 # Not used in float32, but should be documented + ``` + + **Note:** Make sure you understand the security implications of using outdated packages. + +!!! note + The official `openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (`HwwwH/MiniCPM-V-2`) for now. + For more details, please see: <gh-pr:4087#issuecomment-2250397630> + +!!! warning + Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1. + +!!! note + To use Qwen2.5-Omni, you have to install Hugging Face Transformers library from source via + `pip install git+https://github.com/huggingface/transformers.git`. + + Reading audio from video during pre-processing is currently supported on V0 (but not V1), because overlapping modalities are not yet supported in V1. To enable it, pass + `--mm-processor-kwargs '{"use_audio_in_video": true}'`. + +### Pooling Models + +See [this page](./pooling_models.md) for more information on how to use pooling models. + +!!! warning + Since some model architectures support both generative and pooling tasks, + you should explicitly specify the task type to ensure that the model is used in pooling mode instead of generative mode. + +#### Text Embedding + +Specified using `--task embed`. + +Any text generation model can be converted into an embedding model by passing `--task embed`. + +!!! 
note + To get the best results, you should use pooling models that are specifically trained as such. + +The following table lists those that are tested in vLLM. + +| Architecture | Models | Inputs | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | +|-------------------------------------|--------------------|----------|--------------------------|------------------------|-----------------------------| +| `LlavaNextForConditionalGeneration` | LLaVA-NeXT-based | T / I | `royokong/e5-v` | ✅︎ | | +| `Phi3VForCausalLM` | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | 🚧 | ✅︎ | + +#### Transcription + +Specified using `--task transcription`. + +Speech2Text models trained specifically for Automatic Speech Recognition. + +| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | +|----------------|----------|---------------------|------------------------|-----------------------------| + +--- + +## Model Support Policy + +At vLLM, we are committed to facilitating the integration and support of third-party models within our ecosystem. Our approach is designed to balance the need for robustness and the practical limitations of supporting a wide range of models. Here’s how we manage third-party model support: + +1. **Community-Driven Support**: We encourage community contributions for adding new models. When a user requests support for a new model, we welcome pull requests (PRs) from the community. These contributions are evaluated primarily on the sensibility of the output they generate, rather than strict consistency with existing implementations such as those in transformers. **Call for contribution:** PRs coming directly from model vendors are greatly appreciated! + +2. **Best-Effort Consistency**: While we aim to maintain a level of consistency between the models implemented in vLLM and other frameworks like transformers, complete alignment is not always feasible. 
Factors like acceleration techniques and the use of low-precision computations can introduce discrepancies. Our commitment is to ensure that the implemented models are functional and produce sensible results. + + !!! tip + When comparing the output of `model.generate` from Hugging Face Transformers with the output of `llm.generate` from vLLM, note that the former reads the model's generation config file (i.e., [generation_config.json](https://github.com/huggingface/transformers/blob/19dabe96362803fb0a9ae7073d03533966598b17/src/transformers/generation/utils.py#L1945)) and applies the default parameters for generation, while the latter only uses the parameters passed to the function. Ensure all sampling parameters are identical when comparing outputs. + +3. **Issue Resolution and Model Updates**: Users are encouraged to report any bugs or issues they encounter with third-party models. Proposed fixes should be submitted via PRs, with a clear explanation of the problem and the rationale behind the proposed solution. If a fix for one model impacts another, we rely on the community to highlight and address these cross-model dependencies. Note: for bugfix PRs, it is good etiquette to inform the original author to seek their feedback. + +4. **Monitoring and Updates**: Users interested in specific models should monitor the commit history for those models (e.g., by tracking changes in the main/vllm/model_executor/models directory). This proactive approach helps users stay informed about updates and changes that may affect the models they use. + +5. **Selective Focus**: Our resources are primarily directed towards models with significant user interest and impact. Models that are less frequently used may receive less attention, and we rely on the community to play a more active role in their upkeep and improvement. 
+ +Through this approach, vLLM fosters a collaborative environment where both the core development team and the broader community contribute to the robustness and diversity of the third-party models supported in our ecosystem. + +Note that, as an inference engine, vLLM does not introduce new models. Therefore, all models supported by vLLM are third-party models in this regard. + +We have the following levels of testing for models: + +1. **Strict Consistency**: We compare the output of the model with the output of the model in the HuggingFace Transformers library under greedy decoding. This is the most stringent test. Please refer to [models tests](https://github.com/vllm-project/vllm/blob/main/tests/models) for the models that have passed this test. +2. **Output Sensibility**: We check if the output of the model is sensible and coherent, by measuring the perplexity of the output and checking for any obvious errors. This is a less stringent test. +3. **Runtime Functionality**: We check if the model can be loaded and run without errors. This is the least stringent test. Please refer to [functionality tests](gh-dir:tests) and [examples](gh-dir:examples) for the models that have passed this test. +4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category. diff --git a/docs/seed_parameter_behavior.md b/docs/seed_parameter_behavior.md deleted file mode 100644 index ff17525cf8e2..000000000000 --- a/docs/seed_parameter_behavior.md +++ /dev/null @@ -1,51 +0,0 @@ -# Seed Parameter Behavior in vLLM - -## Overview - -The `seed` parameter in vLLM is used to control the random states for various random number generators. This parameter can affect the behavior of random operations in user code, especially when working with models in vLLM. 
- -## Default Behavior - -By default, the `seed` parameter is set to `None`. When the `seed` parameter is `None`, the global random states for `random`, `np.random`, and `torch.manual_seed` are not set. This means that the random operations will behave as expected, without any fixed random states. - -## Specifying a Seed - -If a specific seed value is provided, the global random states for `random`, `np.random`, and `torch.manual_seed` will be set accordingly. This can be useful for reproducibility, as it ensures that the random operations produce the same results across multiple runs. - -## Example Usage - -### Without Specifying a Seed - -```python -import random -from vllm import LLM - -# Initialize a vLLM model without specifying a seed -model = LLM(model="Qwen/Qwen2.5-0.5B-Instruct") - -# Try generating random numbers -print(random.randint(0, 100)) # Outputs different numbers across runs -``` - -### Specifying a Seed - -```python -import random -from vllm import LLM - -# Initialize a vLLM model with a specific seed -model = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", seed=42) - -# Try generating random numbers -print(random.randint(0, 100)) # Outputs the same number across runs -``` - -## Important Notes - -- If the `seed` parameter is not specified, the behavior of global random states remains unaffected. -- If a specific seed value is provided, the global random states for `random`, `np.random`, and `torch.manual_seed` will be set to that value. -- This behavior can be useful for reproducibility but may lead to non-intuitive behavior if the user is not explicitly aware of it. - -## Conclusion - -Understanding the behavior of the `seed` parameter in vLLM is crucial for ensuring the expected behavior of random operations in your code. By default, the `seed` parameter is set to `None`, which means that the global random states are not affected. However, specifying a seed value can help achieve reproducibility in your experiments. 
diff --git a/docs/source/serving/distributed_serving.md b/docs/serving/distributed_serving.md similarity index 73% rename from docs/source/serving/distributed_serving.md rename to docs/serving/distributed_serving.md index c285ef3e8e1c..259af5cabcb8 100644 --- a/docs/source/serving/distributed_serving.md +++ b/docs/serving/distributed_serving.md @@ -1,6 +1,7 @@ -(distributed-serving)= - -# Distributed Inference and Serving +--- +title: Distributed Inference and Serving +--- +[](){ #distributed-serving } ## How to decide the distributed inference strategy? @@ -14,9 +15,8 @@ In short, you should increase the number of GPUs and the number of nodes until y After adding enough GPUs and nodes to hold the model, you can run vLLM first, which will print some logs like `# GPU blocks: 790`. Multiply the number by `16` (the block size), and you can get roughly the maximum number of tokens that can be served on the current configuration. If this number is not satisfying, e.g. you want higher throughput, you can further increase the number of GPUs or nodes, until the number of blocks is enough. -:::{note} -There is one edge case: if the model fits in a single node with multiple GPUs, but the number of GPUs cannot divide the model size evenly, you can use pipeline parallelism, which splits the model along layers and supports uneven splits. In this case, the tensor parallel size should be 1 and the pipeline parallel size should be the number of GPUs. -::: +!!! note + There is one edge case: if the model fits in a single node with multiple GPUs, but the number of GPUs cannot divide the model size evenly, you can use pipeline parallelism, which splits the model along layers and supports uneven splits. In this case, the tensor parallel size should be 1 and the pipeline parallel size should be the number of GPUs. ## Running vLLM on a single node @@ -77,13 +77,11 @@ bash run_cluster.sh \ Then you get a ray cluster of **containers**. 
Note that you need to keep the shells running these commands alive to hold the cluster. Any shell disconnect will terminate the cluster. In addition, please note that the argument `ip_of_head_node` should be the IP address of the head node, which is accessible by all the worker nodes. The IP addresses of each worker node should be specified in the `VLLM_HOST_IP` environment variable, and should be different for each worker node. Please check the network configuration of your cluster to make sure the nodes can communicate with each other through the specified IP addresses. -:::{warning} -It is considered best practice to set `VLLM_HOST_IP` to an address on a private network segment for the vLLM cluster. The traffic sent here is not encrypted. The endpoints are also exchanging data in a format that could be exploited to execute arbitrary code should a malicious party gain access to the network. Please ensure that this network is not reachable by any untrusted parties. -::: +!!! warning + It is considered best practice to set `VLLM_HOST_IP` to an address on a private network segment for the vLLM cluster. The traffic sent here is not encrypted. The endpoints are also exchanging data in a format that could be exploited to execute arbitrary code should a malicious party gain access to the network. Please ensure that this network is not reachable by any untrusted parties. -:::{warning} -Since this is a ray cluster of **containers**, all the following commands should be executed in the **containers**, otherwise you are executing the commands on the host machine, which is not connected to the ray cluster. To enter the container, you can use `docker exec -it node /bin/bash`. -::: +!!! warning + Since this is a ray cluster of **containers**, all the following commands should be executed in the **containers**, otherwise you are executing the commands on the host machine, which is not connected to the ray cluster. 
To enter the container, you can use `docker exec -it node /bin/bash`. Then, on any node, use `docker exec -it node /bin/bash` to enter the container, execute `ray status` and `ray list nodes` to check the status of the Ray cluster. You should see the right number of nodes and GPUs. @@ -104,16 +102,13 @@ vllm serve /path/to/the/model/in/the/container \ To make tensor parallel performant, you should make sure the communication between nodes is efficient, e.g. using high-speed network cards like Infiniband. To correctly set up the cluster to use Infiniband, append additional arguments like `--privileged -e NCCL_IB_HCA=mlx5` to the `run_cluster.sh` script. Please contact your system administrator for more information on how to set up the flags. One way to confirm if the Infiniband is working is to run vLLM with `NCCL_DEBUG=TRACE` environment variable set, e.g. `NCCL_DEBUG=TRACE vllm serve ...` and check the logs for the NCCL version and the network used. If you find `[send] via NET/Socket` in the logs, it means NCCL uses raw TCP Socket, which is not efficient for cross-node tensor parallel. If you find `[send] via NET/IB/GDRDMA` in the logs, it means NCCL uses Infiniband with GPU-Direct RDMA, which is efficient. -:::{warning} -After you start the Ray cluster, you'd better also check the GPU-GPU communication between nodes. It can be non-trivial to set up. Please refer to the [sanity check script](#troubleshooting-incorrect-hardware-driver) for more information. If you need to set some environment variables for the communication configuration, you can append them to the `run_cluster.sh` script, e.g. `-e NCCL_SOCKET_IFNAME=eth0`. Note that setting environment variables in the shell (e.g. `NCCL_SOCKET_IFNAME=eth0 vllm serve ...`) only works for the processes in the same node, not for the processes in the other nodes. Setting environment variables when you create the cluster is the recommended way. See <gh-issue:6803> for more information. -::: +!!! 
warning + After you start the Ray cluster, you'd better also check the GPU-GPU communication between nodes. It can be non-trivial to set up. Please refer to the [sanity check script][troubleshooting-incorrect-hardware-driver] for more information. If you need to set some environment variables for the communication configuration, you can append them to the `run_cluster.sh` script, e.g. `-e NCCL_SOCKET_IFNAME=eth0`. Note that setting environment variables in the shell (e.g. `NCCL_SOCKET_IFNAME=eth0 vllm serve ...`) only works for the processes in the same node, not for the processes in the other nodes. Setting environment variables when you create the cluster is the recommended way. See <gh-issue:6803> for more information. -:::{warning} -Please make sure you downloaded the model to all the nodes (with the same path), or the model is downloaded to some distributed file system that is accessible by all nodes. +!!! warning + Please make sure you downloaded the model to all the nodes (with the same path), or the model is downloaded to some distributed file system that is accessible by all nodes. -When you use huggingface repo id to refer to the model, you should append your huggingface token to the `run_cluster.sh` script, e.g. `-e HF_TOKEN=`. The recommended way is to download the model first, and then use the path to refer to the model. -::: + When you use huggingface repo id to refer to the model, you should append your huggingface token to the `run_cluster.sh` script, e.g. `-e HF_TOKEN=`. The recommended way is to download the model first, and then use the path to refer to the model. -:::{warning} -If you keep receiving the error message `Error: No available node types can fulfill resource request` but you have enough GPUs in the cluster, chances are your nodes have multiple IP addresses and vLLM cannot find the right one, especially when you are using multi-node inference. Please make sure vLLM and ray use the same IP address. 
You can set the `VLLM_HOST_IP` environment variable to the right IP address in the `run_cluster.sh` script (different for each node!), and check `ray status` and `ray list nodes` to see the IP address used by Ray. See <gh-issue:7815> for more information. -::: +!!! warning + If you keep receiving the error message `Error: No available node types can fulfill resource request` but you have enough GPUs in the cluster, chances are your nodes have multiple IP addresses and vLLM cannot find the right one, especially when you are using multi-node inference. Please make sure vLLM and ray use the same IP address. You can set the `VLLM_HOST_IP` environment variable to the right IP address in the `run_cluster.sh` script (different for each node!), and check `ray status` and `ray list nodes` to see the IP address used by Ray. See <gh-issue:7815> for more information. diff --git a/docs/source/serving/integrations/langchain.md b/docs/serving/integrations/langchain.md similarity index 93% rename from docs/source/serving/integrations/langchain.md rename to docs/serving/integrations/langchain.md index 03142d23b145..14ea6a044341 100644 --- a/docs/source/serving/integrations/langchain.md +++ b/docs/serving/integrations/langchain.md @@ -1,6 +1,7 @@ -(serving-langchain)= - -# LangChain +--- +title: LangChain +--- +[](){ #serving-langchain } vLLM is also available via [LangChain](https://github.com/langchain-ai/langchain) . diff --git a/docs/source/serving/integrations/llamaindex.md b/docs/serving/integrations/llamaindex.md similarity index 91% rename from docs/source/serving/integrations/llamaindex.md rename to docs/serving/integrations/llamaindex.md index 8c72605202cf..251b7155c556 100644 --- a/docs/source/serving/integrations/llamaindex.md +++ b/docs/serving/integrations/llamaindex.md @@ -1,6 +1,7 @@ -(serving-llamaindex)= - -# LlamaIndex +--- +title: LlamaIndex +--- +[](){ #serving-llamaindex } vLLM is also available via [LlamaIndex](https://github.com/run-llama/llama_index) . 
diff --git a/docs/serving/offline_inference.md b/docs/serving/offline_inference.md new file mode 100644 index 000000000000..b238199e4144 --- /dev/null +++ b/docs/serving/offline_inference.md @@ -0,0 +1,29 @@ +--- +title: Offline Inference +--- +[](){ #offline-inference } + +You can run vLLM in your own code on a list of prompts. + +The offline API is based on the [LLM][vllm.LLM] class. +To initialize the vLLM engine, create a new instance of `LLM` and specify the model to run. + +For example, the following code downloads the [`facebook/opt-125m`](https://huggingface.co/facebook/opt-125m) model from HuggingFace +and runs it in vLLM using the default configuration. + +```python +from vllm import LLM + +llm = LLM(model="facebook/opt-125m") +``` + +After initializing the `LLM` instance, you can perform model inference using various APIs. +The available APIs depend on the type of model that is being run: + +- [Generative models][generative-models] output logprobs which are sampled from to obtain the final output text. +- [Pooling models][pooling-models] output their hidden states directly. + +Please refer to the above pages for more details about each API. + +!!! info + [API Reference][offline-inference-api] diff --git a/docs/source/serving/openai_compatible_server.md b/docs/serving/openai_compatible_server.md similarity index 53% rename from docs/source/serving/openai_compatible_server.md rename to docs/serving/openai_compatible_server.md index 34382c87a484..c2e39d029dd5 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/serving/openai_compatible_server.md @@ -1,13 +1,16 @@ -(openai-compatible-server)= - -# OpenAI-Compatible Server +--- +title: OpenAI-Compatible Server +--- +[](){ #openai-compatible-server } vLLM provides an HTTP server that implements OpenAI's [Completions API](https://platform.openai.com/docs/api-reference/completions), [Chat API](https://platform.openai.com/docs/api-reference/chat), and more! 
This functionality lets you serve models and interact with them using an HTTP client. -In your terminal, you can [install](../getting_started/installation.md) vLLM, then start the server with the [`vllm serve`](#vllm-serve) command. (You can also use our [Docker](#deployment-docker) image.) +In your terminal, you can [install](../getting_started/installation/README.md) vLLM, then start the server with the [`vllm serve`][serve-args] command. (You can also use our [Docker][deployment-docker] image.) ```bash -vllm serve NousResearch/Meta-Llama-3-8B-Instruct --dtype auto --api-key token-abc123 +vllm serve NousResearch/Meta-Llama-3-8B-Instruct \ + --dtype auto \ + --api-key token-abc123 ``` To call the server, in your preferred text editor, create a script that uses an HTTP client. Include any messages that you want to send to the model. Then run that script. Below is an example script using the [official OpenAI Python client](https://github.com/openai/openai-python). @@ -20,56 +23,56 @@ client = OpenAI( ) completion = client.chat.completions.create( - model="NousResearch/Meta-Llama-3-8B-Instruct", - messages=[ - {"role": "user", "content": "Hello!"} - ] + model="NousResearch/Meta-Llama-3-8B-Instruct", + messages=[ + {"role": "user", "content": "Hello!"} + ] ) print(completion.choices[0].message) ``` -:::{tip} -vLLM supports some parameters that are not supported by OpenAI, `top_k` for example. -You can pass these parameters to vLLM using the OpenAI client in the `extra_body` parameter of your requests, i.e. `extra_body={"top_k": 50}` for `top_k`. -::: +!!! tip + vLLM supports some parameters that are not supported by OpenAI, `top_k` for example. + You can pass these parameters to vLLM using the OpenAI client in the `extra_body` parameter of your requests, i.e. `extra_body={"top_k": 50}` for `top_k`. -:::{important} -By default, the server applies `generation_config.json` from the Hugging Face model repository if it exists. 
This means the default values of certain sampling parameters can be overridden by those recommended by the model creator. +!!! warning + By default, the server applies `generation_config.json` from the Hugging Face model repository if it exists. This means the default values of certain sampling parameters can be overridden by those recommended by the model creator. -To disable this behavior, please pass `--generation-config vllm` when launching the server. -::: + To disable this behavior, please pass `--generation-config vllm` when launching the server. ## Supported APIs We currently support the following OpenAI APIs: -- [Completions API](#completions-api) (`/v1/completions`) - - Only applicable to [text generation models](../models/generative_models.md) (`--task generate`). - - *Note: `suffix` parameter is not supported.* -- [Chat Completions API](#chat-api) (`/v1/chat/completions`) - - Only applicable to [text generation models](../models/generative_models.md) (`--task generate`) with a [chat template](#chat-template). - - *Note: `parallel_tool_calls` and `user` parameters are ignored.* -- [Embeddings API](#embeddings-api) (`/v1/embeddings`) - - Only applicable to [embedding models](../models/pooling_models.md) (`--task embed`). -- [Transcriptions API](#transcriptions-api) (`/v1/audio/transcriptions`) - - Only applicable to Automatic Speech Recognition (ASR) models (OpenAI Whisper) (`--task generate`). +- [Completions API][completions-api] (`/v1/completions`) + - Only applicable to [text generation models](../models/generative_models.md) (`--task generate`). + - *Note: `suffix` parameter is not supported.* +- [Chat Completions API][chat-api] (`/v1/chat/completions`) + - Only applicable to [text generation models](../models/generative_models.md) (`--task generate`) with a [chat template][chat-template]. 
+ - *Note: `parallel_tool_calls` and `user` parameters are ignored.* +- [Embeddings API][embeddings-api] (`/v1/embeddings`) + - Only applicable to [embedding models](../models/pooling_models.md) (`--task embed`). +- [Transcriptions API][transcriptions-api] (`/v1/audio/transcriptions`) + - Only applicable to Automatic Speech Recognition (ASR) models (OpenAI Whisper) (`--task generate`). In addition, we have the following custom APIs: -- [Tokenizer API](#tokenizer-api) (`/tokenize`, `/detokenize`) - - Applicable to any model with a tokenizer. -- [Pooling API](#pooling-api) (`/pooling`) - - Applicable to all [pooling models](../models/pooling_models.md). -- [Score API](#score-api) (`/score`) - - Applicable to embedding models and [cross-encoder models](../models/pooling_models.md) (`--task score`). -- [Re-rank API](#rerank-api) (`/rerank`, `/v1/rerank`, `/v2/rerank`) - - Implements [Jina AI's v1 re-rank API](https://jina.ai/reranker/) - - Also compatible with [Cohere's v1 & v2 re-rank APIs](https://docs.cohere.com/v2/reference/rerank) - - Jina and Cohere's APIs are very similar; Jina's includes extra information in the rerank endpoint's response. - - Only applicable to [cross-encoder models](../models/pooling_models.md) (`--task score`). - -(chat-template)= +- [Tokenizer API][tokenizer-api] (`/tokenize`, `/detokenize`) + - Applicable to any model with a tokenizer. +- [Pooling API][pooling-api] (`/pooling`) + - Applicable to all [pooling models](../models/pooling_models.md). +- [Classification API][classification-api] (`/classify`) + - Only applicable to [classification models](../models/pooling_models.md) (`--task classify`). +- [Score API][score-api] (`/score`) + - Applicable to embedding models and [cross-encoder models](../models/pooling_models.md) (`--task score`). 
+- [Re-rank API][rerank-api] (`/rerank`, `/v1/rerank`, `/v2/rerank`) + - Implements [Jina AI's v1 re-rank API](https://jina.ai/reranker/) + - Also compatible with [Cohere's v1 & v2 re-rank APIs](https://docs.cohere.com/v2/reference/rerank) + - Jina and Cohere's APIs are very similar; Jina's includes extra information in the rerank endpoint's response. + - Only applicable to [cross-encoder models](../models/pooling_models.md) (`--task score`). + +[](){ #chat-template } ## Chat Template @@ -95,10 +98,10 @@ both a `type` and a `text` field. An example is provided below: ```python completion = client.chat.completions.create( - model="NousResearch/Meta-Llama-3-8B-Instruct", - messages=[ - {"role": "user", "content": [{"type": "text", "text": "Classify this sentiment: vLLM is wonderful!"}]} - ] + model="NousResearch/Meta-Llama-3-8B-Instruct", + messages=[ + {"role": "user", "content": [{"type": "text", "text": "Classify this sentiment: vLLM is wonderful!"}]} + ] ) ``` @@ -109,9 +112,9 @@ request. vLLM provides best-effort support to detect this automatically, which i the detected format, which can be one of: - `"string"`: A string. - - Example: `"Hello world"` + - Example: `"Hello world"` - `"openai"`: A list of dictionaries, similar to OpenAI schema. - - Example: `[{"type": "text", "text": "Hello world!"}]` + - Example: `[{"type": "text", "text": "Hello world!"}]` If the result is not what you expect, you can set the `--chat-template-content-format` CLI argument to override which format to use. 
@@ -124,13 +127,13 @@ Or directly merge them into the JSON payload if you are using HTTP call directly ```python completion = client.chat.completions.create( - model="NousResearch/Meta-Llama-3-8B-Instruct", - messages=[ - {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"} - ], - extra_body={ - "guided_choice": ["positive", "negative"] - } + model="NousResearch/Meta-Llama-3-8B-Instruct", + messages=[ + {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"} + ], + extra_body={ + "guided_choice": ["positive", "negative"] + } ) ``` @@ -146,77 +149,29 @@ with `--enable-request-id-headers`. ```python completion = client.chat.completions.create( - model="NousResearch/Meta-Llama-3-8B-Instruct", - messages=[ - {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"} - ], - extra_headers={ - "x-request-id": "sentiment-classification-00001", - } + model="NousResearch/Meta-Llama-3-8B-Instruct", + messages=[ + {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"} + ], + extra_headers={ + "x-request-id": "sentiment-classification-00001", + } ) print(completion._request_id) completion = client.completions.create( - model="NousResearch/Meta-Llama-3-8B-Instruct", - prompt="A robot may not injure a human being", - extra_headers={ - "x-request-id": "completion-test", - } + model="NousResearch/Meta-Llama-3-8B-Instruct", + prompt="A robot may not injure a human being", + extra_headers={ + "x-request-id": "completion-test", + } ) print(completion._request_id) ``` -## CLI Reference - -(vllm-serve)= - -### `vllm serve` - -The `vllm serve` command is used to launch the OpenAI-compatible server. - -:::{tip} -The vast majority of command-line arguments are based on those for offline inference. - -See [here](configuration-options) for some common options. 
-::: - -:::{argparse} -:module: vllm.entrypoints.openai.cli_args -:func: create_parser_for_docs -:prog: vllm serve -::: - -#### Configuration file - -You can load CLI arguments via a [YAML](https://yaml.org/) config file. -The argument names must be the long form of those outlined [above](#vllm-serve). - -For example: - -```yaml -# config.yaml - -model: meta-llama/Llama-3.1-8B-Instruct -host: "127.0.0.1" -port: 6379 -uvicorn-log-level: "info" -``` - -To use the above config file: - -```bash -vllm serve --config config.yaml -``` - -:::{note} -In case an argument is supplied simultaneously using command line and the config file, the value from the command line will take precedence. -The order of priorities is `command line > config file values > defaults`. -e.g. `vllm serve SOME_MODEL --config config.yaml`, SOME_MODEL takes precedence over `model` in config file. -::: - ## API Reference -(completions-api)= +[](){ #completions-api } ### Completions API @@ -227,23 +182,19 @@ Code example: <gh-file:examples/online_serving/openai_completion_client.py> #### Extra parameters -The following [sampling parameters](#sampling-params) are supported. +The following [sampling parameters][sampling-params] are supported. 
-:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py -:language: python -:start-after: begin-completion-sampling-params -:end-before: end-completion-sampling-params -::: +```python +--8<-- "vllm/entrypoints/openai/protocol.py:completion-sampling-params" +``` The following extra parameters are supported: -:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py -:language: python -:start-after: begin-completion-extra-params -:end-before: end-completion-extra-params -::: +```python +--8<-- "vllm/entrypoints/openai/protocol.py:completion-extra-params" +``` -(chat-api)= +[](){ #chat-api } ### Chat API @@ -252,37 +203,33 @@ you can use the [official OpenAI Python client](https://github.com/openai/openai We support both [Vision](https://platform.openai.com/docs/guides/vision)- and [Audio](https://platform.openai.com/docs/guides/audio?audio-generation-quickstart-example=audio-in)-related parameters; -see our [Multimodal Inputs](#multimodal-inputs) guide for more information. +see our [Multimodal Inputs][multimodal-inputs] guide for more information. - *Note: `image_url.detail` parameter is not supported.* Code example: <gh-file:examples/online_serving/openai_chat_completion_client.py> #### Extra parameters -The following [sampling parameters](#sampling-params) are supported. +The following [sampling parameters][sampling-params] are supported. 
-:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py -:language: python -:start-after: begin-chat-completion-sampling-params -:end-before: end-chat-completion-sampling-params -::: +```python +--8<-- "vllm/entrypoints/openai/protocol.py:chat-completion-sampling-params" +``` The following extra parameters are supported: -:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py -:language: python -:start-after: begin-chat-completion-extra-params -:end-before: end-chat-completion-extra-params -::: +```python +--8<-- "vllm/entrypoints/openai/protocol.py:chat-completion-extra-params" +``` -(embeddings-api)= +[](){ #embeddings-api } ### Embeddings API Our Embeddings API is compatible with [OpenAI's Embeddings API](https://platform.openai.com/docs/api-reference/embeddings); you can use the [official OpenAI Python client](https://github.com/openai/openai-python) to interact with it. -If the model has a [chat template](#chat-template), you can replace `inputs` with a list of `messages` (same schema as [Chat API](#chat-api)) +If the model has a [chat template][chat-template], you can replace `inputs` with a list of `messages` (same schema as [Chat API][chat-api]) which will be treated as a single prompt to the model. Code example: <gh-file:examples/online_serving/openai_embedding_client.py> @@ -292,138 +239,121 @@ Code example: <gh-file:examples/online_serving/openai_embedding_client.py> You can pass multi-modal inputs to embedding models by defining a custom chat template for the server and passing a list of `messages` in the request. Refer to the examples below for illustration. 
-:::::{tab-set} -::::{tab-item} VLM2Vec - -To serve the model: +=== "VLM2Vec" -```bash -vllm serve TIGER-Lab/VLM2Vec-Full --task embed \ - --trust-remote-code --max-model-len 4096 --chat-template examples/template_vlm2vec.jinja -``` + To serve the model: -:::{important} -Since VLM2Vec has the same model architecture as Phi-3.5-Vision, we have to explicitly pass `--task embed` -to run this model in embedding mode instead of text generation mode. + ```bash + vllm serve TIGER-Lab/VLM2Vec-Full --task embed \ + --trust-remote-code \ + --max-model-len 4096 \ + --chat-template examples/template_vlm2vec.jinja + ``` -The custom chat template is completely different from the original one for this model, -and can be found here: <gh-file:examples/template_vlm2vec.jinja> -::: + !!! warning + Since VLM2Vec has the same model architecture as Phi-3.5-Vision, we have to explicitly pass `--task embed` + to run this model in embedding mode instead of text generation mode. -Since the request schema is not defined by OpenAI client, we post a request to the server using the lower-level `requests` library: + The custom chat template is completely different from the original one for this model, + and can be found here: <gh-file:examples/template_vlm2vec.jinja> -```python -import requests - -image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" - -response = requests.post( - "http://localhost:8000/v1/embeddings", - json={ - "model": "TIGER-Lab/VLM2Vec-Full", - "messages": [{ - "role": "user", - "content": [ - {"type": "image_url", "image_url": {"url": image_url}}, - {"type": "text", "text": "Represent the given image."}, - ], - }], - "encoding_format": "float", - }, -) -response.raise_for_status() -response_json = response.json() -print("Embedding output:", response_json["data"][0]["embedding"]) -``` + Since the request schema is not defined by OpenAI client, we post a 
request to the server using the lower-level `requests` library: -:::: + ```python + import requests -::::{tab-item} DSE-Qwen2-MRL + image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" -To serve the model: + response = requests.post( + "http://localhost:8000/v1/embeddings", + json={ + "model": "TIGER-Lab/VLM2Vec-Full", + "messages": [{ + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "Represent the given image."}, + ], + }], + "encoding_format": "float", + }, + ) + response.raise_for_status() + response_json = response.json() + print("Embedding output:", response_json["data"][0]["embedding"]) + ``` -```bash -vllm serve MrLight/dse-qwen2-2b-mrl-v1 --task embed \ - --trust-remote-code --max-model-len 8192 --chat-template examples/template_dse_qwen2_vl.jinja -``` +=== "DSE-Qwen2-MRL" -:::{important} -Like with VLM2Vec, we have to explicitly pass `--task embed`. + To serve the model: -Additionally, `MrLight/dse-qwen2-2b-mrl-v1` requires an EOS token for embeddings, which is handled -by a custom chat template: <gh-file:examples/template_dse_qwen2_vl.jinja> -::: + ```bash + vllm serve MrLight/dse-qwen2-2b-mrl-v1 --task embed \ + --trust-remote-code \ + --max-model-len 8192 \ + --chat-template examples/template_dse_qwen2_vl.jinja + ``` -:::{important} -`MrLight/dse-qwen2-2b-mrl-v1` requires a placeholder image of the minimum image size for text query embeddings. See the full code -example below for details. -::: + !!! warning + Like with VLM2Vec, we have to explicitly pass `--task embed`. -:::: + Additionally, `MrLight/dse-qwen2-2b-mrl-v1` requires an EOS token for embeddings, which is handled + by a custom chat template: <gh-file:examples/template_dse_qwen2_vl.jinja> -::::: + !!! 
warning + `MrLight/dse-qwen2-2b-mrl-v1` requires a placeholder image of the minimum image size for text query embeddings. See the full code + example below for details. Full example: <gh-file:examples/online_serving/openai_chat_embedding_client_for_multimodal.py> #### Extra parameters -The following [pooling parameters](#pooling-params) are supported. +The following [pooling parameters][pooling-params] are supported. -:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py -:language: python -:start-after: begin-embedding-pooling-params -:end-before: end-embedding-pooling-params -::: +```python +--8<-- "vllm/entrypoints/openai/protocol.py:embedding-pooling-params" +``` The following extra parameters are supported by default: -:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py -:language: python -:start-after: begin-embedding-extra-params -:end-before: end-embedding-extra-params -::: +```python +--8<-- "vllm/entrypoints/openai/protocol.py:embedding-extra-params" +``` For chat-like input (i.e. if `messages` is passed), these extra parameters are supported instead: -:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py -:language: python -:start-after: begin-chat-embedding-extra-params -:end-before: end-chat-embedding-extra-params -::: +```python +--8<-- "vllm/entrypoints/openai/protocol.py:chat-embedding-extra-params" +``` -(transcriptions-api)= +[](){ #transcriptions-api } ### Transcriptions API Our Transcriptions API is compatible with [OpenAI's Transcriptions API](https://platform.openai.com/docs/api-reference/audio/createTranscription); you can use the [official OpenAI Python client](https://github.com/openai/openai-python) to interact with it. -:::{note} -To use the Transcriptions API, please install with extra audio dependencies using `pip install vllm[audio]`. -::: +!!! note + To use the Transcriptions API, please install with extra audio dependencies using `pip install vllm[audio]`. 
Code example: <gh-file:examples/online_serving/openai_transcription_client.py> <!-- TODO: api enforced limits + uploading audios --> #### Extra Parameters -The following [sampling parameters](#sampling-params) are supported. +The following [sampling parameters][sampling-params] are supported. -:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py -:language: python -:start-after: begin-transcription-sampling-params -:end-before: end-transcription-sampling-params -::: +```python +--8<-- "vllm/entrypoints/openai/protocol.py:transcription-sampling-params" +``` The following extra parameters are supported: -:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py -:language: python -:start-after: begin-transcription-extra-params -:end-before: end-transcription-extra-params -::: +```python +--8<-- "vllm/entrypoints/openai/protocol.py:transcription-extra-params" +``` -(tokenizer-api)= +[](){ #tokenizer-api } ### Tokenizer API @@ -433,17 +363,137 @@ It consists of two endpoints: - `/tokenize` corresponds to calling `tokenizer.encode()`. - `/detokenize` corresponds to calling `tokenizer.decode()`. -(pooling-api)= +[](){ #pooling-api } ### Pooling API Our Pooling API encodes input prompts using a [pooling model](../models/pooling_models.md) and returns the corresponding hidden states. -The input format is the same as [Embeddings API](#embeddings-api), but the output data can contain an arbitrary nested list, not just a 1-D list of floats. +The input format is the same as [Embeddings API][embeddings-api], but the output data can contain an arbitrary nested list, not just a 1-D list of floats. 
Code example: <gh-file:examples/online_serving/openai_pooling_client.py> -(score-api)= +[](){ #classification-api } + +### Classification API + +Our Classification API directly supports Hugging Face sequence-classification models such as [ai21labs/Jamba-tiny-reward-dev](https://huggingface.co/ai21labs/Jamba-tiny-reward-dev) and [jason9693/Qwen2.5-1.5B-apeach](https://huggingface.co/jason9693/Qwen2.5-1.5B-apeach). + +We automatically wrap any other transformer via `as_classification_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities. + +Code example: <gh-file:examples/online_serving/openai_classification_client.py> + +#### Example Requests + +You can classify multiple texts by passing an array of strings: + +Request: + +```bash +curl -v "http://127.0.0.1:8000/classify" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "jason9693/Qwen2.5-1.5B-apeach", + "input": [ + "Loved the new café—coffee was great.", + "This update broke everything. Frustrating." + ] + }' +``` + +Response: + +```bash +{ + "id": "classify-7c87cac407b749a6935d8c7ce2a8fba2", + "object": "list", + "created": 1745383065, + "model": "jason9693/Qwen2.5-1.5B-apeach", + "data": [ + { + "index": 0, + "label": "Default", + "probs": [ + 0.565970778465271, + 0.4340292513370514 + ], + "num_classes": 2 + }, + { + "index": 1, + "label": "Spoiled", + "probs": [ + 0.26448777318000793, + 0.7355121970176697 + ], + "num_classes": 2 + } + ], + "usage": { + "prompt_tokens": 20, + "total_tokens": 20, + "completion_tokens": 0, + "prompt_tokens_details": null + } +} +``` + +You can also pass a string directly to the `input` field: + +Request: + +```bash +curl -v "http://127.0.0.1:8000/classify" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "jason9693/Qwen2.5-1.5B-apeach", + "input": "Loved the new café—coffee was great."
+ }' +``` + +Response: + +```bash +{ + "id": "classify-9bf17f2847b046c7b2d5495f4b4f9682", + "object": "list", + "created": 1745383213, + "model": "jason9693/Qwen2.5-1.5B-apeach", + "data": [ + { + "index": 0, + "label": "Default", + "probs": [ + 0.565970778465271, + 0.4340292513370514 + ], + "num_classes": 2 + } + ], + "usage": { + "prompt_tokens": 10, + "total_tokens": 10, + "completion_tokens": 0, + "prompt_tokens_details": null + } +} +``` + +#### Extra parameters + +The following [pooling parameters][pooling-params] are supported. + +```python +--8<-- "vllm/entrypoints/openai/protocol.py:classification-pooling-params" +``` + +The following extra parameters are supported: + +```python +--8<-- "vllm/entrypoints/openai/protocol.py:classification-extra-params" +``` + +[](){ #score-api } ### Score API @@ -590,23 +640,19 @@ Response: #### Extra parameters -The following [pooling parameters](#pooling-params) are supported. +The following [pooling parameters][pooling-params] are supported. -:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py -:language: python -:start-after: begin-score-pooling-params -:end-before: end-score-pooling-params -::: +```python +--8<-- "vllm/entrypoints/openai/protocol.py:score-pooling-params" +``` The following extra parameters are supported: -:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py -:language: python -:start-after: begin-score-extra-params -:end-before: end-score-extra-params -::: +```python +--8<-- "vllm/entrypoints/openai/protocol.py:score-extra-params" +``` -(rerank-api)= +[](){ #rerank-api } ### Re-rank API @@ -677,18 +723,14 @@ Response: #### Extra parameters -The following [pooling parameters](#pooling-params) are supported. +The following [pooling parameters][pooling-params] are supported. 
-:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py -:language: python -:start-after: begin-rerank-pooling-params -:end-before: end-rerank-pooling-params -::: +```python +--8<-- "vllm/entrypoints/openai/protocol.py:rerank-pooling-params" +``` The following extra parameters are supported: -:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py -:language: python -:start-after: begin-rerank-extra-params -:end-before: end-rerank-extra-params -::: +```python +--8<-- "vllm/entrypoints/openai/protocol.py:rerank-extra-params" +``` diff --git a/docs/source/_static/custom.css b/docs/source/_static/custom.css deleted file mode 100644 index 79bd2082b49e..000000000000 --- a/docs/source/_static/custom.css +++ /dev/null @@ -1,8 +0,0 @@ -.vertical-table-header th.head:not(.stub) { - writing-mode: sideways-lr; - white-space: nowrap; - max-width: 0; - p { - margin: 0; - } -} diff --git a/docs/source/_templates/sections/header.html b/docs/source/_templates/sections/header.html deleted file mode 100644 index 7174431b1027..000000000000 --- a/docs/source/_templates/sections/header.html +++ /dev/null @@ -1,39 +0,0 @@ -<style> - .notification-bar { - width: 100vw; - display: flex; - justify-content: center; - align-items: center; - font-size: 16px; - padding: 0 6px 0 6px; - } - .notification-bar p { - margin: 0; - } - .notification-bar a { - font-weight: bold; - text-decoration: none; - } - - /* Light mode styles (default) */ - .notification-bar { - background-color: #fff3cd; - color: #856404; - } - .notification-bar a { - color: #d97706; - } - - /* Dark mode styles */ - html[data-theme=dark] .notification-bar { - background-color: #333; - color: #ddd; - } - html[data-theme=dark] .notification-bar a { - color: #ffa500; /* Brighter color for visibility */ - } -</style> - -<div class="notification-bar"> - <p>You are viewing the latest developer preview docs. 
<a href="https://docs.vllm.ai/en/stable/">Click here</a> to view docs for the latest stable release.</p> -</div> diff --git a/docs/source/api/summary.md b/docs/source/api/summary.md deleted file mode 100644 index 46de545f9ded..000000000000 --- a/docs/source/api/summary.md +++ /dev/null @@ -1,133 +0,0 @@ -# Summary - -(configuration)= - -## Configuration - -API documentation for vLLM's configuration classes. - -```{autodoc2-summary} - vllm.config.ModelConfig - vllm.config.CacheConfig - vllm.config.TokenizerPoolConfig - vllm.config.LoadConfig - vllm.config.ParallelConfig - vllm.config.SchedulerConfig - vllm.config.DeviceConfig - vllm.config.SpeculativeConfig - vllm.config.LoRAConfig - vllm.config.PromptAdapterConfig - vllm.config.MultiModalConfig - vllm.config.PoolerConfig - vllm.config.DecodingConfig - vllm.config.ObservabilityConfig - vllm.config.KVTransferConfig - vllm.config.CompilationConfig - vllm.config.VllmConfig -``` - -(offline-inference-api)= - -## Offline Inference - -LLM Class. - -```{autodoc2-summary} - vllm.LLM -``` - -LLM Inputs. - -```{autodoc2-summary} - vllm.inputs.PromptType - vllm.inputs.TextPrompt - vllm.inputs.TokensPrompt -``` - -## vLLM Engines - -Engine classes for offline and online inference. - -```{autodoc2-summary} - vllm.LLMEngine - vllm.AsyncLLMEngine -``` - -## Inference Parameters - -Inference parameters for vLLM APIs. - -(sampling-params)= -(pooling-params)= - -```{autodoc2-summary} - vllm.SamplingParams - vllm.PoolingParams -``` - -(multi-modality)= - -## Multi-Modality - -vLLM provides experimental support for multi-modal models through the {mod}`vllm.multimodal` package. - -Multi-modal inputs can be passed alongside text and token prompts to [supported models](#supported-mm-models) -via the `multi_modal_data` field in {class}`vllm.inputs.PromptType`. - -Looking to add your own multi-modal model? Please follow the instructions listed [here](#supports-multimodal). 
- -```{autodoc2-summary} - vllm.multimodal.MULTIMODAL_REGISTRY -``` - -### Inputs - -User-facing inputs. - -```{autodoc2-summary} - vllm.multimodal.inputs.MultiModalDataDict -``` - -Internal data structures. - -```{autodoc2-summary} - vllm.multimodal.inputs.PlaceholderRange - vllm.multimodal.inputs.NestedTensors - vllm.multimodal.inputs.MultiModalFieldElem - vllm.multimodal.inputs.MultiModalFieldConfig - vllm.multimodal.inputs.MultiModalKwargsItem - vllm.multimodal.inputs.MultiModalKwargs - vllm.multimodal.inputs.MultiModalInputs -``` - -### Data Parsing - -```{autodoc2-summary} - vllm.multimodal.parse -``` - -### Data Processing - -```{autodoc2-summary} - vllm.multimodal.processing -``` - -### Memory Profiling - -```{autodoc2-summary} - vllm.multimodal.profiling -``` - -### Registry - -```{autodoc2-summary} - vllm.multimodal.registry -``` - -## Model Development - -```{autodoc2-summary} - vllm.model_executor.models.interfaces_base - vllm.model_executor.models.interfaces - vllm.model_executor.models.adapters -``` diff --git a/docs/source/autodoc2_docstring_parser.py b/docs/source/autodoc2_docstring_parser.py deleted file mode 100644 index 41c49ed1c545..000000000000 --- a/docs/source/autodoc2_docstring_parser.py +++ /dev/null @@ -1,21 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -from docutils import nodes -from myst_parser.parsers.sphinx_ import MystParser -from sphinx.ext.napoleon import docstring - - -class NapoleonParser(MystParser): - - def parse(self, input_string: str, document: nodes.document) -> None: - # Get the Sphinx configuration - config = document.settings.env.config - - parsed_content = str( - docstring.GoogleDocstring( - str(docstring.NumpyDocstring(input_string, config)), - config, - )) - return super().parse(parsed_content, document) - - -Parser = NapoleonParser diff --git a/docs/source/community/blog.md b/docs/source/community/blog.md deleted file mode 100644 index e8030edfa02e..000000000000 --- a/docs/source/community/blog.md +++ /dev/null @@ 
-1,3 +0,0 @@ -# vLLM Blog - -vLLM blog posts are published [here](https://blog.vllm.ai/). diff --git a/docs/source/conf.py b/docs/source/conf.py deleted file mode 100644 index 5620d6de2c59..000000000000 --- a/docs/source/conf.py +++ /dev/null @@ -1,263 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -# Configuration file for the Sphinx documentation builder. -# -# This file only contains a selection of the most common options. For a full -# list see the documentation: -# https://www.sphinx-doc.org/en/master/usage/configuration.html - -# -- Path setup -------------------------------------------------------------- - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. - -import datetime -import logging -import os -import re -import sys -from pathlib import Path - -import requests - -logger = logging.getLogger(__name__) -REPO_ROOT = Path(__file__).resolve().parent.parent.parent -sys.path.append(os.path.abspath(REPO_ROOT)) - -# -- Project information ----------------------------------------------------- - -project = 'vLLM' -copyright = f'{datetime.datetime.now().year}, vLLM Team' -author = 'the vLLM Team' - -# -- General configuration --------------------------------------------------- - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. 
-extensions = [ - "sphinx.ext.napoleon", - "sphinx.ext.linkcode", - "sphinx.ext.intersphinx", - "sphinx_copybutton", - "autodoc2", - "myst_parser", - "sphinxarg.ext", - "sphinx_design", - "sphinx_togglebutton", -] -myst_enable_extensions = [ - "colon_fence", - "fieldlist", -] -autodoc2_packages = [ - { - "path": "../../vllm", - "exclude_dirs": ["__pycache__", "third_party"], - }, -] -autodoc2_output_dir = "api" -autodoc2_render_plugin = "myst" -autodoc2_hidden_objects = ["dunder", "private", "inherited"] -autodoc2_sort_names = True -autodoc2_index_template = None - -# Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This pattern also affects html_static_path and html_extra_path. -exclude_patterns: list[str] = ["**/*.template.md", "**/*.inc.md"] - -# Exclude the prompt "$" when copying code -copybutton_prompt_text = r"\$ " -copybutton_prompt_is_regexp = True - -# -- Options for HTML output ------------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -# -html_title = project -html_theme = 'sphinx_book_theme' -html_logo = 'assets/logos/vllm-logo-text-light.png' -html_favicon = 'assets/logos/vllm-logo-only-light.ico' -html_theme_options = { - 'path_to_docs': 'docs/source', - 'repository_url': 'https://github.com/vllm-project/vllm', - 'use_repository_button': True, - 'use_edit_page_button': True, - # Prevents the full API being added to the left sidebar of every page. - # Reduces build time by 2.5x and reduces build size from ~225MB to ~95MB. - 'collapse_navbar': True, - # Makes API visible in the right sidebar on API reference pages. - 'show_toc_level': 3, -} -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. 
They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ["_static"] -html_js_files = ["custom.js"] -html_css_files = ["custom.css"] - -myst_heading_anchors = 2 -myst_url_schemes = { - 'http': None, - 'https': None, - 'mailto': None, - 'ftp': None, - "gh-issue": { - "url": - "https://github.com/vllm-project/vllm/issues/{{path}}#{{fragment}}", - "title": "Issue #{{path}}", - "classes": ["github"], - }, - "gh-pr": { - "url": - "https://github.com/vllm-project/vllm/pull/{{path}}#{{fragment}}", - "title": "Pull Request #{{path}}", - "classes": ["github"], - }, - "gh-project": { - "url": "https://github.com/orgs/vllm-project/projects/{{path}}", - "title": "Project #{{path}}", - "classes": ["github"], - }, - "gh-dir": { - "url": "https://github.com/vllm-project/vllm/tree/main/{{path}}", - "title": "{{path}}", - "classes": ["github"], - }, - "gh-file": { - "url": "https://github.com/vllm-project/vllm/blob/main/{{path}}", - "title": "{{path}}", - "classes": ["github"], - }, -} - -# see https://docs.readthedocs.io/en/stable/reference/environment-variables.html # noqa -READTHEDOCS_VERSION_TYPE = os.environ.get('READTHEDOCS_VERSION_TYPE') -if READTHEDOCS_VERSION_TYPE == "tag": - # remove the warning banner if the version is a tagged release - header_file = os.path.join(os.path.dirname(__file__), - "_templates/sections/header.html") - # The file might be removed already if the build is triggered multiple times - # (readthedocs build both HTML and PDF versions separately) - if os.path.exists(header_file): - os.remove(header_file) - - -# Generate additional rst documentation here. 
-def setup(app): - from docs.source.generate_examples import generate_examples - generate_examples() - - -_cached_base: str = "" -_cached_branch: str = "" - - -def get_repo_base_and_branch(pr_number): - global _cached_base, _cached_branch - if _cached_base and _cached_branch: - return _cached_base, _cached_branch - - url = f"https://api.github.com/repos/vllm-project/vllm/pulls/{pr_number}" - response = requests.get(url) - if response.status_code == 200: - data = response.json() - _cached_base = data['head']['repo']['full_name'] - _cached_branch = data['head']['ref'] - return _cached_base, _cached_branch - else: - logger.error("Failed to fetch PR details: %s", response) - return None, None - - -def linkcode_resolve(domain, info): - if domain != 'py': - return None - if not info['module']: - return None - - # Get path from module name - file = Path(f"{info['module'].replace('.', '/')}.py") - path = REPO_ROOT / file - if not path.exists(): - path = REPO_ROOT / file.with_suffix("") / "__init__.py" - if not path.exists(): - return None - - # Get the line number of the object - with open(path) as f: - lines = f.readlines() - name = info['fullname'].split(".")[-1] - pattern = fr"^( {{4}})*((def|class) )?{name}\b.*" - for lineno, line in enumerate(lines, 1): - if not line or line.startswith("#"): - continue - if re.match(pattern, line): - break - - # If the line number is not found, return None - if lineno == len(lines): - return None - - # If the line number is found, create the URL - filename = path.relative_to(REPO_ROOT) - if "checkouts" in path.parts: - # a PR build on readthedocs - pr_number = REPO_ROOT.name - base, branch = get_repo_base_and_branch(pr_number) - if base and branch: - return f"https://github.com/{base}/blob/{branch}/{filename}#L{lineno}" - # Otherwise, link to the source file on the main branch - return f"https://github.com/vllm-project/vllm/blob/main/{filename}#L{lineno}" - - -# Mock out external dependencies here, otherwise sphinx-argparse won't 
work. -autodoc_mock_imports = [ - "huggingface_hub", - "pydantic", - "zmq", - "cloudpickle", - "aiohttp", - "starlette", - "blake3", - "cpuinfo", - "transformers", - "psutil", - "vllm._C", - "PIL", - "numpy", - "tqdm", - # The mocks below are required by - # docs/source/serving/openai_compatible_server.md's - # vllm.entrypoints.openai.cli_args - "openai", - "fastapi", - "partial_json_parser", -] - -for mock_target in autodoc_mock_imports: - if mock_target in sys.modules: - logger.info( - "Potentially problematic mock target (%s) found; " - "autodoc_mock_imports cannot mock modules that have already " - "been loaded into sys.modules when the sphinx build starts.", - mock_target) - -intersphinx_mapping = { - "python": ("https://docs.python.org/3", None), - "typing_extensions": - ("https://typing-extensions.readthedocs.io/en/latest", None), - "aiohttp": ("https://docs.aiohttp.org/en/stable", None), - "pillow": ("https://pillow.readthedocs.io/en/stable", None), - "numpy": ("https://numpy.org/doc/stable", None), - "torch": ("https://pytorch.org/docs/stable", None), - "psutil": ("https://psutil.readthedocs.io/en/stable", None), -} - -navigation_with_keys = False diff --git a/docs/source/contributing/model/index.md b/docs/source/contributing/model/index.md deleted file mode 100644 index 721ee3cd2047..000000000000 --- a/docs/source/contributing/model/index.md +++ /dev/null @@ -1,27 +0,0 @@ -(new-model)= - -# Adding a New Model - -This section provides more information on how to integrate a [PyTorch](https://pytorch.org/) model into vLLM. - -:::{toctree} -:caption: Contents -:maxdepth: 1 - -basic -registration -tests -multimodal -::: - -:::{note} -The complexity of adding a new model depends heavily on the model's architecture. -The process is considerably straightforward if the model shares a similar architecture with an existing model in vLLM. -However, for models that include new operators (e.g., a new attention mechanism), the process can be a bit more complex. 
-::: - -:::{tip} -If you are encountering issues while integrating your model into vLLM, feel free to open a [GitHub issue](https://github.com/vllm-project/vllm/issues) -or ask on our [developer slack](https://slack.vllm.ai). -We will be happy to help you out! -::: diff --git a/docs/source/contributing/model/multimodal.md b/docs/source/contributing/model/multimodal.md deleted file mode 100644 index b42536f054d7..000000000000 --- a/docs/source/contributing/model/multimodal.md +++ /dev/null @@ -1,834 +0,0 @@ -(supports-multimodal)= - -# Multi-Modal Support - -This document walks you through the steps to extend a basic model so that it accepts [multi-modal inputs](#multimodal-inputs). - -## 1. Update the base vLLM model - -It is assumed that you have already implemented the model in vLLM according to [these steps](#new-model-basic). -Further update the model as follows: - -- Reserve a keyword parameter in {meth}`~torch.nn.Module.forward` for each input tensor that corresponds to a multi-modal input, as shown in the following example: - - ```diff - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - + pixel_values: torch.Tensor, - ) -> SamplerOutput: - ``` - - More conveniently, you can simply pass `**kwargs` to the {meth}`~torch.nn.Module.forward` method and retrieve the keyword parameters for multimodal inputs from it. - -- Implement {meth}`~vllm.model_executor.models.interfaces.SupportsMultiModal.get_multimodal_embeddings` that returns the embeddings from running the multimodal inputs through the multimodal tokenizer of the model. Below we provide a boilerplate of a typical implementation pattern, but feel free to adjust it to your own needs. - - ```python - class YourModelForImage2Seq(nn.Module): - ... 
- - def _process_image_input(self, image_input: YourModelImageInputs) -> torch.Tensor: - - assert self.vision_encoder is not None - image_features = self.vision_encoder(image_input) - return self.multi_modal_projector(image_features) - - def get_multimodal_embeddings( - self, **kwargs: object) -> Optional[MultiModalEmbeddings]: - - # Validate the multimodal input keyword arguments - image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is None: - return None - - # Run multimodal inputs through encoder and projector - vision_embeddings = self._process_image_input(image_input) - return vision_embeddings - ``` - - :::{important} - The returned `multimodal_embeddings` must be either a **3D {class}`torch.Tensor`** of shape `(num_items, feature_size, hidden_size)`, or a **list / tuple of 2D {class}`torch.Tensor`'s** of shape `(feature_size, hidden_size)`, so that `multimodal_embeddings[i]` retrieves the embeddings generated from the `i`-th multimodal data item (e.g, image) of the request. - ::: - -- Implement {meth}`~vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings` to merge `multimodal_embeddings` with text embeddings from the `input_ids`. If input processing for the model is implemented correctly (see sections below), then you can leverage the utility function we provide to easily merge the embeddings. - - ```python - from .utils import merge_multimodal_embeddings - - class YourModelForImage2Seq(nn.Module): - ... - - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - - # `get_input_embeddings` should already be implemented for the language - # model as one of the requirements of basic vLLM model implementation. 
- inputs_embeds = self.language_model.get_input_embeddings(input_ids) - - if multimodal_embeddings is not None: - inputs_embeds = merge_multimodal_embeddings( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - multimodal_embeddings=multimodal_embeddings, - placeholder_token_id=self.config.image_token_index) - - return inputs_embeds - ``` - -- Implement {meth}`~vllm.model_executor.models.interfaces.SupportsMultiModal.get_language_model` getter to provide stable access to the underlying language model. - - ```python - class YourModelForImage2Seq(nn.Module): - ... - - def get_language_model(self) -> torch.nn.Module: - # Change `language_model` according to your implementation. - return self.language_model - ``` - -- Once the above steps are done, update the model class with the {class}`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface. - - ```diff - + from vllm.model_executor.models.interfaces import SupportsMultiModal - - - class YourModelForImage2Seq(nn.Module): - + class YourModelForImage2Seq(nn.Module, SupportsMultiModal): - ``` - - :::{note} - The model class does not have to be named {code}`*ForCausalLM`. - Check out [the HuggingFace Transformers documentation](https://huggingface.co/docs/transformers/model_doc/auto#multimodal) for some examples. - ::: - -## 2. Specify processing information - -Next, create a subclass of {class}`~vllm.multimodal.processing.BaseProcessingInfo` -to provide basic information related to HF processing. - -### Maximum number of input items - -You need to override the abstract method {meth}`~vllm.multimodal.processing.BaseProcessingInfo.get_supported_mm_limits` -to return the maximum number of input items for each modality supported by the model. - -For example, if the model supports any number of images but only one video per prompt: - -```python -def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"image": None, "video": 1} -``` - -## 3. 
Specify dummy inputs - -Then, inherit {class}`~vllm.multimodal.profiling.BaseDummyInputsBuilder` to construct dummy inputs for -HF processing as well as memory profiling. - -### For memory profiling - -Override the abstract methods {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_text` and {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_mm_data` to construct dummy inputs for memory profiling. These dummy inputs should result in the worst-case memory usage of the model so that vLLM can reserve the correct amount of memory for it. - -Assuming that the memory usage increases with the number of tokens, the dummy inputs can be constructed to maximize the number of output embeddings, which is the same number as placeholder feature tokens. - -::::{tab-set} -:::{tab-item} Basic example: LLaVA -:sync: llava - -Looking at the code of HF's `LlavaForConditionalGeneration`: - -```python -# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/modeling_llava.py#L530-L544 -n_image_tokens = (input_ids == self.config.image_token_index).sum().item() -n_image_features = image_features.shape[0] * image_features.shape[1] - -if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) -special_image_mask = ( - (input_ids == self.config.image_token_index) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) -) -image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) -inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) -``` - -The number of placeholder feature tokens per image is `image_features.shape[1]`. 
-`image_features` is calculated inside the `get_image_features` method: - -```python -# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/modeling_llava.py#L290-L300 -image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) - -selected_image_feature = image_outputs.hidden_states[vision_feature_layer] -if vision_feature_select_strategy == "default": - selected_image_feature = selected_image_feature[:, 1:] -elif vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature -else: - raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}") -image_features = self.multi_modal_projector(selected_image_feature) -return image_features -``` - -We can infer that `image_features.shape[1]` is based on `image_outputs.hidden_states.shape[1]` from the vision tower -(`CLIPVisionModel` for the [`llava-hf/llava-1.5-7b-hf`](https://huggingface.co/llava-hf/llava-1.5-7b-hf) model). -Moreover, we only need the sequence length (the second dimension of the tensor) to get `image_features.shape[1]`. -The sequence length is determined by the initial hidden states in `CLIPVisionTransformer` since the attention -mechanism doesn't change the sequence length of the output hidden states. 
- -```python -# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/modeling_clip.py#L1094-L1102 -hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) -hidden_states = self.pre_layrnorm(hidden_states) - -encoder_outputs = self.encoder( - inputs_embeds=hidden_states, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, -) -``` - -To find the sequence length, we turn to the code of `CLIPVisionEmbeddings`: - -```python -# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/modeling_clip.py#L247-L257 -target_dtype = self.patch_embedding.weight.dtype -patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] -patch_embeds = patch_embeds.flatten(2).transpose(1, 2) - -class_embeds = self.class_embedding.expand(batch_size, 1, -1) -embeddings = torch.cat([class_embeds, patch_embeds], dim=1) -if interpolate_pos_encoding: - embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) -else: - embeddings = embeddings + self.position_embedding(self.position_ids) -return embeddings -``` - -We can infer that `embeddings.shape[1] == self.num_positions`, where - -```python -# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/modeling_clip.py#L195-L196 -self.num_patches = (self.image_size // self.patch_size) ** 2 -self.num_positions = self.num_patches + 1 -``` - -Overall, the number of placeholder feature tokens for an image can be calculated as: - -```python -def get_num_image_tokens( - self, - *, - image_width: int, - image_height: int, -) -> int: - hf_config = self.get_hf_config() - hf_processor = self.get_hf_processor() - - image_size = hf_config.vision_config.image_size - patch_size = hf_config.vision_config.patch_size - - num_image_tokens = (image_size // patch_size) ** 2 + 1 - if 
hf_processor.vision_feature_select_strategy == "default": - num_image_tokens -= 1 - - return num_image_tokens -``` - -Notice that the number of image tokens doesn't depend on the image width and height. -We can simply use a dummy `image_size` to calculate the multimodal profiling data: - -```python -# NOTE: In actuality, this is usually implemented as part of the -# model's subclass of `BaseProcessingInfo`, but we show it as is -# here for simplicity. -def get_image_size_with_most_features(self) -> ImageSize: - hf_config = self.get_hf_config() - width = height = hf_config.image_size - return ImageSize(width=width, height=height) - -def get_dummy_mm_data( - self, - seq_len: int, - mm_counts: Mapping[str, int], -) -> MultiModalDataDict: - num_images = mm_counts.get("image", 0) - - target_width, target_height = \ - self.info.get_image_size_with_most_features() - - return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images) - } -``` - -For the text, we simply expand the multimodal image token from the model config to match the desired number of images. 
- -```python -def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: - num_images = mm_counts.get("image", 0) - - processor = self.info.get_hf_processor() - image_token = processor.image_token - - return image_token * num_images -``` - -::: - -:::{tab-item} No input placeholders: Fuyu -:sync: fuyu - -Looking at the code of HF's `FuyuForCausalLM`: - -```python -# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/modeling_fuyu.py#L311-L322 -if image_patches is not None and past_key_values is None: - patch_embeddings = [ - self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype)) - .squeeze(0) - .to(inputs_embeds.device) - for patch in image_patches - ] - inputs_embeds = self.gather_continuous_embeddings( - word_embeddings=inputs_embeds, - continuous_embeddings=patch_embeddings, - image_patch_input_indices=image_patches_indices, - ) -``` - -The number of placeholder feature tokens for the `i`th item in the batch is `patch_embeddings[i].shape[0]`, -which is the same as `image_patches[i].shape[0]`, i.e. `num_total_patches`. - -Unlike LLaVA, Fuyu does not define the number of patches inside the modeling file. Where can we get more information? -Considering that the model input comes from the output of `FuyuProcessor`, let's **look at the preprocessing files**. - -The image outputs are obtained by calling `FuyuImageProcessor.preprocess` and then -`FuyuImageProcessor.preprocess_with_tokenizer_info` inside `FuyuProcessor`. - -In `FuyuImageProcessor.preprocess`, the images are resized and padded to the target `FuyuImageProcessor.size`, -returning the dimensions after resizing (but before padding) as metadata. 
- -```python -# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L541-L544 -image_encoding = self.image_processor.preprocess(images, **output_kwargs["images_kwargs"]) -batch_images = image_encoding["images"] -image_unpadded_heights = image_encoding["image_unpadded_heights"] -image_unpadded_widths = image_encoding["image_unpadded_widths"] - -# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L480-L -if do_resize: - batch_images = [ - [self.resize(image, size=size, input_data_format=input_data_format) for image in images] - for images in batch_images - ] - -image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images] -image_unpadded_heights = [[image_size[0]] for image_size in image_sizes] -image_unpadded_widths = [[image_size[1]] for image_size in image_sizes] - -if do_pad: - batch_images = [ - [ - self.pad_image( - image, - size=size, - mode=padding_mode, - constant_values=padding_value, - input_data_format=input_data_format, - ) - for image in images - ] - for images in batch_images - ] -``` - -In `FuyuImageProcessor.preprocess_with_tokenizer_info`, the images are split into patches based on this metadata: - -```python -# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L417-L425 -model_image_input = self.image_processor.preprocess_with_tokenizer_info( - image_input=tensor_batch_images, - image_present=image_present, - image_unpadded_h=image_unpadded_heights, - image_unpadded_w=image_unpadded_widths, - image_placeholder_id=image_placeholder_id, - image_newline_id=image_newline_id, - variable_sized=True, -) - -# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L638-L658 -image_height, image_width = image.shape[1], image.shape[2] -if variable_sized: # variable_sized=True - new_h = min( - 
image_height, - math.ceil(image_unpadded_h[batch_index, subseq_index] / patch_height) * patch_height, - ) - new_w = min( - image_width, - math.ceil(image_unpadded_w[batch_index, subseq_index] / patch_width) * patch_width, - ) - image = image[:, :new_h, :new_w] - image_height, image_width = new_h, new_w - -num_patches = self.get_num_patches(image_height=image_height, image_width=image_width) -tensor_of_image_ids = torch.full( - [num_patches], image_placeholder_id, dtype=torch.int32, device=image_input.device -) -patches = self.patchify_image(image=image.unsqueeze(0)).squeeze(0) -assert num_patches == patches.shape[0] -``` - -The number of patches is in turn defined by `FuyuImageProcessor.get_num_patches`: - -```python -# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L552-L562 -patch_size = patch_size if patch_size is not None else self.patch_size -patch_height, patch_width = self.patch_size["height"], self.patch_size["width"] - -if image_height % patch_height != 0: - raise ValueError(f"{image_height=} must be divisible by {patch_height}") -if image_width % patch_width != 0: - raise ValueError(f"{image_width=} must be divisible by {patch_width}") - -num_patches_per_dim_h = image_height // patch_height -num_patches_per_dim_w = image_width // patch_width -num_patches = num_patches_per_dim_h * num_patches_per_dim_w -``` - -These image patches correspond to placeholder tokens (`|SPEAKER|`). So, we just need to maximize the number of image patches. Since input images are first resized -to fit within `image_processor.size`, we can maximize the number of image patches by inputting an image with size equal to `image_processor.size`. 
- -```python -def get_image_size_with_most_features(self) -> ImageSize: - image_processor = self.get_image_processor() - return ImageSize(width=image_processor.size["width"], - height=image_processor.size["height"]) -``` - -Fuyu does not expect image placeholders in the inputs to HF processor, so -the dummy prompt text is empty regardless of the number of images. - -```python -def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: - return "" -``` - -For the multimodal image profiling data, the logic is very similar to LLaVA: - -```python -def get_dummy_mm_data( - self, - seq_len: int, - mm_counts: Mapping[str, int], -) -> MultiModalDataDict: - target_width, target_height = \ - self.info.get_image_size_with_most_features() - num_images = mm_counts.get("image", 0) - - return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images) - } -``` - -::: - -:::: - -## 4. Specify processing details - -Afterwards, create a subclass of {class}`~vllm.multimodal.processing.BaseMultiModalProcessor` -to fill in the missing details about HF processing. - -:::{seealso} -[Multi-Modal Data Processing](#mm-processing) -::: - -### Multi-modal fields - -Override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_mm_fields_config` to -return a schema of the tensors outputted by the HF processor that are related to the input multi-modal items. 
- -:::::{tab-set} -::::{tab-item} Basic example: LLaVA -:sync: llava - -The output of `CLIPImageProcessor` is a simple tensor with shape -`(num_images, num_channels, image_height, image_width)`: - -```python -# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/image_processing_clip.py#L339-L345 -images = [ - to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) - for image in all_images -] - -data = {"pixel_values": images} -return BatchFeature(data=data, tensor_type=return_tensors) -``` - -So, we override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_mm_fields_config` as follows: - -```python -def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], -) -> Mapping[str, MultiModalFieldConfig]: - return dict( - pixel_values=MultiModalFieldConfig.batched("image"), - ) -``` - -:::{note} -Our [actual code](gh-file:vllm/model_executor/models/llava.py) additionally supports -pre-computed image embeddings, which can be passed to be model via the `image_embeds` argument. -::: - -:::: - -::::{tab-item} With postprocessing: Fuyu -:sync: fuyu - -The `image_patches` output of `FuyuImageProcessor.preprocess_with_tokenizer_info` concatenates -the patches from each image belonging to an item in the batch: - -```python -# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L673-L679 - image_input_ids.append(tensor_of_image_ids) - image_patches.append(patches) - else: - image_input_ids.append(torch.tensor([], dtype=torch.int32, device=image_input.device)) - -batch_image_input_ids.append(image_input_ids) -batch_image_patches.append(image_patches) -``` - -The shape of `image_patches` outputted by `FuyuImageProcessor` is therefore -`(1, num_images, num_patches, patch_width * patch_height * num_channels)`. 
- -In order to support the use of {func}`MultiModalFieldConfig.batched` like in LLaVA, -we remove the extra batch dimension by overriding {meth}`BaseMultiModalProcessor._call_hf_processor`: - -```python -def _call_hf_processor( - self, - prompt: str, - mm_data: Mapping[str, object], - mm_kwargs: Mapping[str, object], -) -> BatchFeature: - processed_outputs = super()._call_hf_processor( - prompt=prompt, - mm_data=mm_data, - mm_kwargs=mm_kwargs, - ) - - image_patches = processed_outputs.get("image_patches") - if image_patches is not None: - images = mm_data["images"] - assert isinstance(images, list) - - # Original output: (1, num_images, Pn, Px * Py * C) - # New output: (num_images, Pn, Px * Py * C) - assert (isinstance(image_patches, list) - and len(image_patches) == 1) - assert (isinstance(image_patches[0], torch.Tensor) - and len(image_patches[0]) == len(images)) - - processed_outputs["image_patches"] = image_patches[0] - - return processed_outputs -``` - -:::{note} -Our [actual code](gh-file:vllm/model_executor/models/fuyu.py) has special handling -for text-only inputs to prevent unnecessary warnings from HF processor. -::: - -This lets us override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_mm_fields_config` as follows: - -```python -def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], -) -> Mapping[str, MultiModalFieldConfig]: - return dict(image_patches=MultiModalFieldConfig.batched("image")) -``` - -:::: - -::::: - -### Prompt updates - -Override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates` to -return a list of {class}`~vllm.multimodal.processing.PromptUpdate` instances. - -Each {class}`~vllm.multimodal.processing.PromptUpdate` instance specifies an update operation -(e.g.: insertion, replacement) performed by the HF processor. 
- -::::{tab-set} -:::{tab-item} Basic example: LLaVA -:sync: llava - -Looking at HF's `LlavaProcessor`: - -```python -# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/processing_llava.py#L167-L170 -prompt_strings = [] -for sample in text: - sample = sample.replace(self.image_token, self.image_token * num_image_tokens) - prompt_strings.append(sample) -``` - -It simply repeats each input `image_token` a number of times equal to the number of placeholder feature tokens (`num_image_tokens`). -Based on this, we override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates` as follows: - -```python -def _get_prompt_updates( - self, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, -) -> Sequence[PromptUpdate]: - hf_config = self.info.get_hf_config() - image_token_id = hf_config.image_token_index - - def get_replacement(item_idx: int): - images = mm_items.get_items("image", ImageProcessorItems) - - image_size = images.get_image_size(item_idx) - num_image_tokens = self.info.get_num_image_tokens( - image_width=image_size.width, - image_height=image_size.height, - ) - - return [image_token_id] * num_image_tokens - - return [ - PromptReplacement( - modality="image", - target=[image_token_id], - replacement=get_replacement, - ), - ] -``` - -::: - -:::{tab-item} Handling additional tokens: Fuyu -:sync: fuyu - -Recall the layout of feature tokens from Step 2: - -``` -|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE| -|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE| -... 
-|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE| -``` - -We define a helper function to return `ncols` and `nrows` directly: - -```python -def get_image_feature_grid_size( - self, - *, - image_width: int, - image_height: int, -) -> tuple[int, int]: - image_processor = self.get_image_processor() - target_width = image_processor.size["width"] - target_height = image_processor.size["height"] - patch_width = image_processor.patch_size["width"] - patch_height = image_processor.patch_size["height"] - - if not (image_width <= target_width and image_height <= target_height): - height_scale_factor = target_height / image_height - width_scale_factor = target_width / image_width - optimal_scale_factor = min(height_scale_factor, width_scale_factor) - - image_height = int(image_height * optimal_scale_factor) - image_width = int(image_width * optimal_scale_factor) - - ncols = math.ceil(image_width / patch_width) - nrows = math.ceil(image_height / patch_height) - return ncols, nrows -``` - -Based on this, we can initially define our replacement tokens as: - -```python -def get_replacement(item_idx: int): - images = mm_items.get_items("image", ImageProcessorItems) - image_size = images.get_image_size(item_idx) - - ncols, nrows = self.info.get_image_feature_grid_size( - image_width=image_size.width, - image_height=image_size.height, - ) - - # `_IMAGE_TOKEN_ID` corresponds to `|SPEAKER|` - # `_NEWLINE_TOKEN_ID` corresponds to `|NEWLINE|` - return ([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows -``` - -However, this is not entirely correct. 
After `FuyuImageProcessor.preprocess_with_tokenizer_info` is called, -a BOS token (`<s>`) is also added to the promopt: - -```python -# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L417-L435 -model_image_input = self.image_processor.preprocess_with_tokenizer_info( - image_input=tensor_batch_images, - image_present=image_present, - image_unpadded_h=image_unpadded_heights, - image_unpadded_w=image_unpadded_widths, - image_placeholder_id=image_placeholder_id, - image_newline_id=image_newline_id, - variable_sized=True, -) -prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch( - tokenizer=self.tokenizer, - prompts=prompts, - scale_factors=scale_factors, - max_tokens_to_generate=self.max_tokens_to_generate, - max_position_embeddings=self.max_position_embeddings, - add_BOS=True, - add_beginning_of_answer_token=True, -) -``` - -To assign the vision embeddings to only the image tokens, instead of a string -you can return an instance of {class}`~vllm.multimodal.processing.PromptUpdateDetails`: - -```python -hf_config = self.info.get_hf_config() -bos_token_id = hf_config.bos_token_id # `<s>` -assert isinstance(bos_token_id, int) - -def get_replacement_fuyu(item_idx: int): - images = mm_items.get_items("image", ImageProcessorItems) - image_size = images.get_image_size(item_idx) - - ncols, nrows = self.info.get_image_feature_grid_size( - image_width=image_size.width, - image_height=image_size.height, - ) - image_tokens = ([_IMAGE_TOKEN_ID] * ncols + - [_NEWLINE_TOKEN_ID]) * nrows - - return PromptUpdateDetails.select_token_id( - image_tokens + [bos_token_id], - embed_token_id=_IMAGE_TOKEN_ID, - ) -``` - -Finally, noticing that the HF processor removes the `|ENDOFTEXT|` token from the tokenized prompt, -we can search for it to conduct the replacement at the start of the string: - -```python -def _get_prompt_updates( - self, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, 
object], - out_mm_kwargs: MultiModalKwargs, -) -> Sequence[PromptUpdate]: - hf_config = self.info.get_hf_config() - bos_token_id = hf_config.bos_token_id - assert isinstance(bos_token_id, int) - - tokenizer = self.info.get_tokenizer() - eot_token_id = tokenizer.bos_token_id - assert isinstance(eot_token_id, int) - - def get_replacement_fuyu(item_idx: int): - images = mm_items.get_items("image", ImageProcessorItems) - image_size = images.get_image_size(item_idx) - - ncols, nrows = self.info.get_image_feature_grid_size( - image_width=image_size.width, - image_height=image_size.height, - ) - image_tokens = ([_IMAGE_TOKEN_ID] * ncols + - [_NEWLINE_TOKEN_ID]) * nrows - - return PromptUpdateDetails.select_token_id( - image_tokens + [bos_token_id], - embed_token_id=_IMAGE_TOKEN_ID, - ) - - return [ - PromptReplacement( - modality="image", - target=[eot_token_id], - replacement=get_replacement_fuyu, - ) - ] -``` - -::: - -:::: - -## 5. Register processor-related classes - -After you have defined {class}`~vllm.multimodal.processing.BaseProcessingInfo` (Step 2), -{class}`~vllm.multimodal.profiling.BaseDummyInputsBuilder` (Step 3), -and {class}`~vllm.multimodal.processing.BaseMultiModalProcessor` (Step 4), -decorate the model class with {meth}`MULTIMODAL_REGISTRY.register_processor <vllm.multimodal.registry.MultiModalRegistry.register_processor>` -to register them to the multi-modal registry: - -```diff - from vllm.model_executor.models.interfaces import SupportsMultiModal -+ from vllm.multimodal import MULTIMODAL_REGISTRY - -+ @MULTIMODAL_REGISTRY.register_processor(YourMultiModalProcessor, -+ info=YourProcessingInfo, -+ dummy_inputs=YourDummyInputsBuilder) - class YourModelForImage2Seq(nn.Module, SupportsMultiModal): -``` - -## Notes - -### Inserting feature tokens without replacement - -Some HF processors directly insert feature tokens without replacing anything in the original prompt. 
In that case, you can use {class}`~vllm.multimodal.processing.PromptInsertion` instead of {class}`~vllm.multimodal.processing.PromptReplacement` inside {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates`. - -Examples: - -- BLIP-2 (insert at start of prompt): <gh-file:vllm/model_executor/models/blip2.py> -- Florence2 (insert at start of prompt): <gh-file:vllm/model_executor/models/florence2.py> -- Molmo (insert after `<|endoftext|>` token): <gh-file:vllm/model_executor/models/molmo.py> - -### Handling prompt updates unrelated to multi-modal data - -{meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates` assumes that each application of prompt update corresponds to one multi-modal item. If the HF processor performs additional processing regardless of how many multi-modal items there are, you should override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_tokens_only` so that the processed token inputs are consistent with the result of applying the HF processor on text inputs. This is because token inputs bypass the HF processor according to [our design](#mm-processing). - -Examples: - -- Chameleon (appends `sep_token`): <gh-file:vllm/model_executor/models/chameleon.py> -- Fuyu (appends `boa_token`): <gh-file:vllm/model_executor/models/fuyu.py> -- Molmo (applies chat template which is not defined elsewhere): <gh-file:vllm/model_executor/models/molmo.py> - -### Custom HF processor - -Some models don't define a HF processor class on HF Hub. In that case, you can define a custom HF processor that has the same call signature as HF processors and pass it to {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._call_hf_processor`. 
- -Examples: - -- DeepSeek-VL2: <gh-file:vllm/model_executor/models/deepseek_vl2.py> -- InternVL: <gh-file:vllm/model_executor/models/internvl.py> -- Qwen-VL: <gh-file:vllm/model_executor/models/qwen_vl.py> diff --git a/docs/source/contributing/model/registration.md b/docs/source/contributing/model/registration.md deleted file mode 100644 index 64cd25b53807..000000000000 --- a/docs/source/contributing/model/registration.md +++ /dev/null @@ -1,55 +0,0 @@ -(new-model-registration)= - -# Registering a Model to vLLM - -vLLM relies on a model registry to determine how to run each model. -A list of pre-registered architectures can be found [here](#supported-models). - -If your model is not on this list, you must register it to vLLM. -This page provides detailed instructions on how to do so. - -## Built-in models - -To add a model directly to the vLLM library, start by forking our [GitHub repository](https://github.com/vllm-project/vllm) and then [build it from source](#build-from-source). -This gives you the ability to modify the codebase and test your model. - -After you have implemented your model (see [tutorial](#new-model-basic)), put it into the <gh-dir:vllm/model_executor/models> directory. -Then, add your model class to `_VLLM_MODELS` in <gh-file:vllm/model_executor/models/registry.py> so that it is automatically registered upon importing vLLM. -Finally, update our [list of supported models](#supported-models) to promote your model! - -:::{important} -The list of models in each section should be maintained in alphabetical order. -::: - -## Out-of-tree models - -You can load an external model using a plugin without modifying the vLLM codebase. 
- -:::{seealso} -[vLLM's Plugin System](#plugin-system) -::: - -To register the model, use the following code: - -```python -from vllm import ModelRegistry -from your_code import YourModelForCausalLM -ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM) -``` - -If your model imports modules that initialize CUDA, consider lazy-importing it to avoid errors like `RuntimeError: Cannot re-initialize CUDA in forked subprocess`: - -```python -from vllm import ModelRegistry - -ModelRegistry.register_model("YourModelForCausalLM", "your_code:YourModelForCausalLM") -``` - -:::{important} -If your model is a multimodal model, ensure the model class implements the {class}`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface. -Read more about that [here](#supports-multimodal). -::: - -:::{note} -Although you can directly put these code snippets in your script using `vllm.LLM`, the recommended way is to place these snippets in a vLLM plugin. This ensures compatibility with various vLLM features like distributed inference and the API server. -::: diff --git a/docs/source/deployment/docker.md b/docs/source/deployment/docker.md deleted file mode 100644 index ca56710bc2ef..000000000000 --- a/docs/source/deployment/docker.md +++ /dev/null @@ -1,133 +0,0 @@ -(deployment-docker)= - -# Using Docker - -(deployment-docker-pre-built-image)= - -## Use vLLM's Official Docker Image - -vLLM offers an official Docker image for deployment. -The image can be used to run OpenAI compatible server and is available on Docker Hub as [vllm/vllm-openai](https://hub.docker.com/r/vllm/vllm-openai/tags). - -```console -$ docker run --runtime nvidia --gpus all \ - -v ~/.cache/huggingface:/root/.cache/huggingface \ - --env "HUGGING_FACE_HUB_TOKEN=<secret>" \ - -p 8000:8000 \ - --ipc=host \ - vllm/vllm-openai:latest \ - --model mistralai/Mistral-7B-v0.1 -``` - -This image can also be used with other container engines such as [Podman](https://podman.io/). 
- -```console -$ podman run --gpus all \ - -v ~/.cache/huggingface:/root/.cache/huggingface \ - --env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \ - -p 8000:8000 \ - --ipc=host \ - vllm/vllm-openai:latest \ - --model mistralai/Mistral-7B-v0.1 -``` - -You can add any other <project:#engine-args> you need after the image tag (`vllm/vllm-openai:latest`). - -:::{note} -You can either use the `ipc=host` flag or `--shm-size` flag to allow the -container to access the host's shared memory. vLLM uses PyTorch, which uses shared -memory to share data between processes under the hood, particularly for tensor parallel inference. -::: - -:::{note} -Optional dependencies are not included in order to avoid licensing issues (e.g. <gh-issue:8030>). - -If you need to use those dependencies (having accepted the license terms), -create a custom Dockerfile on top of the base image with an extra layer that installs them: - -```Dockerfile -FROM vllm/vllm-openai:v0.8.3 - -# e.g. install the `audio` optional dependencies -# NOTE: Make sure the version of vLLM matches the base image! -RUN uv pip install --system vllm[audio]==0.8.3 -``` - -::: - -:::{tip} -Some new models may only be available on the main branch of [HF Transformers](https://github.com/huggingface/transformers). - -To use the development version of `transformers`, create a custom Dockerfile on top of the base image -with an extra layer that installs their code from source: - -```Dockerfile -FROM vllm/vllm-openai:latest - -RUN uv pip install --system git+https://github.com/huggingface/transformers.git -``` - -::: - -(deployment-docker-build-image-from-source)= - -## Building vLLM's Docker Image from Source - -You can build and run vLLM from source via the provided <gh-file:docker/Dockerfile>. To build vLLM: - -```console -# optionally specifies: --build-arg max_jobs=8 --build-arg nvcc_threads=2 -DOCKER_BUILDKIT=1 docker build . 
--target vllm-openai --tag vllm/vllm-openai --file docker/Dockerfile -``` - -:::{note} -By default vLLM will build for all GPU types for widest distribution. If you are just building for the -current GPU type the machine is running on, you can add the argument `--build-arg torch_cuda_arch_list=""` -for vLLM to find the current GPU type and build for that. - -If you are using Podman instead of Docker, you might need to disable SELinux labeling by -adding `--security-opt label=disable` when running `podman build` command to avoid certain [existing issues](https://github.com/containers/buildah/discussions/4184). -::: - -## Building for Arm64/aarch64 - -A docker container can be built for aarch64 systems such as the Nvidia Grace-Hopper. At time of this writing, this requires the use -of PyTorch Nightly and should be considered **experimental**. Using the flag `--platform "linux/arm64"` will attempt to build for arm64. - -:::{note} -Multiple modules must be compiled, so this process can take a while. Recommend using `--build-arg max_jobs=` & `--build-arg nvcc_threads=` -flags to speed up build process. However, ensure your `max_jobs` is substantially larger than `nvcc_threads` to get the most benefits. -Keep an eye on memory usage with parallel jobs as it can be substantial (see example below). -::: - -```console -# Example of building on Nvidia GH200 server. (Memory usage: ~15GB, Build time: ~1475s / ~25 min, Image size: 6.93GB) -$ python3 use_existing_torch.py -$ DOCKER_BUILDKIT=1 docker build . 
\ - --file docker/Dockerfile \ - --target vllm-openai \ - --platform "linux/arm64" \ - -t vllm/vllm-gh200-openai:latest \ - --build-arg max_jobs=66 \ - --build-arg nvcc_threads=2 \ - --build-arg torch_cuda_arch_list="9.0+PTX" \ - --build-arg vllm_fa_cmake_gpu_arches="90-real" -``` - -## Use the custom-built vLLM Docker image - -To run vLLM with the custom-built Docker image: - -```console -$ docker run --runtime nvidia --gpus all \ - -v ~/.cache/huggingface:/root/.cache/huggingface \ - -p 8000:8000 \ - --env "HUGGING_FACE_HUB_TOKEN=<secret>" \ - vllm/vllm-openai <args...> -``` - -The argument `vllm/vllm-openai` specifies the image to run, and should be replaced with the name of the custom-built image (the `-t` tag from the build command). - -:::{note} -**For version 0.4.1 and 0.4.2 only** - the vLLM docker images under these versions are supposed to be run under the root user since a library under the root user's home directory, i.e. `/root/.config/vllm/nccl/cu12/libnccl.so.2.18.1` is required to be loaded during runtime. If you are running the container under a different user, you may need to first change the permissions of the library (and all the parent directories) to allow the user to access it, then run vLLM with environment variable `VLLM_NCCL_SO_PATH=/root/.config/vllm/nccl/cu12/libnccl.so.2.18.1` . -::: diff --git a/docs/source/deployment/frameworks/helm.md b/docs/source/deployment/frameworks/helm.md deleted file mode 100644 index 7320d727fbaa..000000000000 --- a/docs/source/deployment/frameworks/helm.md +++ /dev/null @@ -1,250 +0,0 @@ -(deployment-helm)= - -# Helm - -A Helm chart to deploy vLLM for Kubernetes - -Helm is a package manager for Kubernetes. It will help you to deploy vLLM on k8s and automate the deployment of vLLM Kubernetes applications. With Helm, you can deploy the same framework architecture with different configurations to multiple namespaces by overriding variable values. 
- -This guide will walk you through the process of deploying vLLM with Helm, including the necessary prerequisites, steps for helm installation and documentation on architecture and values file. - -## Prerequisites - -Before you begin, ensure that you have the following: - -- A running Kubernetes cluster -- NVIDIA Kubernetes Device Plugin (`k8s-device-plugin`): This can be found at [https://github.com/NVIDIA/k8s-device-plugin](https://github.com/NVIDIA/k8s-device-plugin) -- Available GPU resources in your cluster -- S3 with the model which will be deployed - -## Installing the chart - -To install the chart with the release name `test-vllm`: - -```console -helm upgrade --install --create-namespace --namespace=ns-vllm test-vllm . -f values.yaml --set secrets.s3endpoint=$ACCESS_POINT --set secrets.s3bucketname=$BUCKET --set secrets.s3accesskeyid=$ACCESS_KEY --set secrets.s3accesskey=$SECRET_KEY -``` - -## Uninstalling the Chart - -To uninstall the `test-vllm` deployment: - -```console -helm uninstall test-vllm --namespace=ns-vllm -``` - -The command removes all the Kubernetes components associated with the -chart **including persistent volumes** and deletes the release. 
- -## Architecture - -:::{image} /assets/deployment/architecture_helm_deployment.png -::: - -## Values - -:::{list-table} -:widths: 25 25 25 25 -:header-rows: 1 - -- * Key - * Type - * Default - * Description -- * autoscaling - * object - * {"enabled":false,"maxReplicas":100,"minReplicas":1,"targetCPUUtilizationPercentage":80} - * Autoscaling configuration -- * autoscaling.enabled - * bool - * false - * Enable autoscaling -- * autoscaling.maxReplicas - * int - * 100 - * Maximum replicas -- * autoscaling.minReplicas - * int - * 1 - * Minimum replicas -- * autoscaling.targetCPUUtilizationPercentage - * int - * 80 - * Target CPU utilization for autoscaling -- * configs - * object - * {} - * Configmap -- * containerPort - * int - * 8000 - * Container port -- * customObjects - * list - * [] - * Custom Objects configuration -- * deploymentStrategy - * object - * {} - * Deployment strategy configuration -- * externalConfigs - * list - * [] - * External configuration -- * extraContainers - * list - * [] - * Additional containers configuration -- * extraInit - * object - * {"pvcStorage":"1Gi","s3modelpath":"relative_s3_model_path/opt-125m", "awsEc2MetadataDisabled": true} - * Additional configuration for the init container -- * extraInit.pvcStorage - * string - * "50Gi" - * Storage size of the s3 -- * extraInit.s3modelpath - * string - * "relative_s3_model_path/opt-125m" - * Path of the model on the s3 which hosts model weights and config files -- * extraInit.awsEc2MetadataDisabled - * boolean - * true - * Disables the use of the Amazon EC2 instance metadata service -- * extraPorts - * list - * [] - * Additional ports configuration -- * gpuModels - * list - * ["TYPE_GPU_USED"] - * Type of gpu used -- * image - * object - * {"command":["vllm","serve","/data/","--served-model-name","opt-125m","--host","0.0.0.0","--port","8000"],"repository":"vllm/vllm-openai","tag":"latest"} - * Image configuration -- * image.command - * list - * 
["vllm","serve","/data/","--served-model-name","opt-125m","--host","0.0.0.0","--port","8000"] - * Container launch command -- * image.repository - * string - * "vllm/vllm-openai" - * Image repository -- * image.tag - * string - * "latest" - * Image tag -- * livenessProbe - * object - * {"failureThreshold":3,"httpGet":{"path":"/health","port":8000},"initialDelaySeconds":15,"periodSeconds":10} - * Liveness probe configuration -- * livenessProbe.failureThreshold - * int - * 3 - * Number of times after which if a probe fails in a row, Kubernetes considers that the overall check has failed: the container is not alive -- * livenessProbe.httpGet - * object - * {"path":"/health","port":8000} - * Configuration of the Kubelet http request on the server -- * livenessProbe.httpGet.path - * string - * "/health" - * Path to access on the HTTP server -- * livenessProbe.httpGet.port - * int - * 8000 - * Name or number of the port to access on the container, on which the server is listening -- * livenessProbe.initialDelaySeconds - * int - * 15 - * Number of seconds after the container has started before liveness probe is initiated -- * livenessProbe.periodSeconds - * int - * 10 - * How often (in seconds) to perform the liveness probe -- * maxUnavailablePodDisruptionBudget - * string - * "" - * Disruption Budget Configuration -- * readinessProbe - * object - * {"failureThreshold":3,"httpGet":{"path":"/health","port":8000},"initialDelaySeconds":5,"periodSeconds":5} - * Readiness probe configuration -- * readinessProbe.failureThreshold - * int - * 3 - * Number of times after which if a probe fails in a row, Kubernetes considers that the overall check has failed: the container is not ready -- * readinessProbe.httpGet - * object - * {"path":"/health","port":8000} - * Configuration of the Kubelet http request on the server -- * readinessProbe.httpGet.path - * string - * "/health" - * Path to access on the HTTP server -- * readinessProbe.httpGet.port - * int - * 8000 - * Name or number of 
the port to access on the container, on which the server is listening -- * readinessProbe.initialDelaySeconds - * int - * 5 - * Number of seconds after the container has started before readiness probe is initiated -- * readinessProbe.periodSeconds - * int - * 5 - * How often (in seconds) to perform the readiness probe -- * replicaCount - * int - * 1 - * Number of replicas -- * resources - * object - * {"limits":{"cpu":4,"memory":"16Gi","nvidia.com/gpu":1},"requests":{"cpu":4,"memory":"16Gi","nvidia.com/gpu":1}} - * Resource configuration -- * resources.limits."nvidia.com/gpu" - * int - * 1 - * Number of gpus used -- * resources.limits.cpu - * int - * 4 - * Number of CPUs -- * resources.limits.memory - * string - * "16Gi" - * CPU memory configuration -- * resources.requests."nvidia.com/gpu" - * int - * 1 - * Number of gpus used -- * resources.requests.cpu - * int - * 4 - * Number of CPUs -- * resources.requests.memory - * string - * "16Gi" - * CPU memory configuration -- * secrets - * object - * {} - * Secrets configuration -- * serviceName - * string - * - * Service name -- * servicePort - * int - * 80 - * Service port -- * labels.environment - * string - * test - * Environment name -- * labels.release - * string - * test - * Release name -::: diff --git a/docs/source/deployment/frameworks/index.md b/docs/source/deployment/frameworks/index.md deleted file mode 100644 index aa3394c377d5..000000000000 --- a/docs/source/deployment/frameworks/index.md +++ /dev/null @@ -1,19 +0,0 @@ -# Using other frameworks - -:::{toctree} -:maxdepth: 1 - -anything-llm -bentoml -cerebrium -chatbox -dstack -helm -lws -modal -open-webui -retrieval_augmented_generation -skypilot -streamlit -triton -::: diff --git a/docs/source/deployment/integrations/index.md b/docs/source/deployment/integrations/index.md deleted file mode 100644 index 410742b88c73..000000000000 --- a/docs/source/deployment/integrations/index.md +++ /dev/null @@ -1,11 +0,0 @@ -# External Integrations - -:::{toctree} 
-:maxdepth: 1 - -kserve -kubeai -llamastack -llmaz -production-stack -::: diff --git a/docs/source/design/kernel/paged_attention.md b/docs/source/design/kernel/paged_attention.md deleted file mode 100644 index e1770c822643..000000000000 --- a/docs/source/design/kernel/paged_attention.md +++ /dev/null @@ -1,529 +0,0 @@ -(design-paged-attention)= - -# vLLM Paged Attention - -- Currently, vLLM utilizes its own implementation of a multi-head query - attention kernel (`csrc/attention/attention_kernels.cu`). - This kernel is designed to be compatible with - vLLM's paged KV caches, where the key and value cache are stored in - separate blocks (note that this block concept differs from the GPU - thread block. So in a later document, I will refer to vLLM paged - attention block as "block", while refer to GPU thread block as - "thread block"). -- To achieve high performance, this kernel relies on a specially - designed memory layout and access method, specifically when threads - read data from global memory to shared memory. The purpose of this - document is to provide a high-level explanation of the kernel - implementation step by step, aiding those who wish to learn about the - vLLM multi-head query attention kernel. After going through this - document, users will likely have a better understanding and feel easier - to follow the actual implementation. -- Please note that this document may not cover all details, such as how - to calculate the correct index for the corresponding data or the dot - multiplication implementation. However, after reading this document - and becoming familiar with the high-level logic flow, it should be - easier for you to read the actual code and understand the details. - -## Inputs - -- The kernel function takes a list of arguments for the current thread - to perform its assigned work. 
The three most important arguments are - the input pointers `q`, `k_cache`, and `v_cache`, which point - to query, key, and value data on global memory that need to be read - and processed. The output pointer `out` points to global memory - where the result should be written. These four pointers actually - refer to multi-dimensional arrays, but each thread only accesses the - portion of data assigned to it. I have omitted all other runtime - parameters here for simplicity. - - ```cpp - template< - typename scalar_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS, - int PARTITION_SIZE = 0> - __device__ void paged_attention_kernel( - ... // Other side args. - const scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - ... // Other side args. - ) - ``` - -- There are also a list of template arguments above the function - signature that are determined during compilation time. `scalar_t` - represents the data type of the query, key, and value data elements, - such as FP16. `HEAD_SIZE` indicates the number of elements in each - head. `BLOCK_SIZE` refers to the number of tokens in each block. - `NUM_THREADS` denotes the number of threads in each thread block. - `PARTITION_SIZE` represents the number of tensor parallel GPUs (For - simplicity, we assume this is 0 and tensor parallel is disabled). - -- With these arguments, we need to perform a sequence of preparations. - This includes calculating the current head index, block index, and - other necessary variables. However, for now, we can ignore these - preparations and proceed directly to the actual calculations. It will - be easier to understand them once we grasp the entire flow. 
- -## Concepts - -- Just before we dive into the calculation flow, I want to describe a - few concepts that are needed for later sections. However, you may - skip this section and return later if you encounter any confusing - terminology. -- **Sequence**: A sequence represents a client request. For example, - the data pointed to by `q` has a shape of - `[num_seqs, num_heads, head_size]`. This means that `q` points to a - total of `num_seqs` query sequences. Since this - kernel is a single query attention kernel, each sequence only has one - query token. Hence, the `num_seqs` equals the total number of tokens - that are processed in the batch. -- **Context**: The context consists of the generated tokens from the - sequence. For instance, `["What", "is", "your"]` are the context - tokens, and the input query token is `"name"`. The model might - generate the token `"?"`. -- **Vec**: The vec is a list of elements that are fetched and - calculated together. For query and key data, the vec size - (`VEC_SIZE`) is determined so that each thread group can fetch and - calculate 16 bytes of data at a time. For value data, the vec size - (`V_VEC_SIZE`) is determined so that each thread can fetch and - calculate 16 bytes of data at a time. For example, if the - `scalar_t` is FP16 (2 bytes) and `THREAD_GROUP_SIZE` is 2, the - `VEC_SIZE` will be 4, while the `V_VEC_SIZE` will be 8. -- **Thread group**: The thread group is a small group of - threads (`THREAD_GROUP_SIZE`) that fetches and calculates one - query token and one key token at a time. Each thread handles only a - portion of the token data. The total number of elements processed by - one thread group is referred to as `x`. For example, if the thread - group contains 2 threads and the head size is 8, then thread 0 - handles the query and key elements at index 0, 2, 4, 6, while thread - 1 handles the elements at index 1, 3, 5, 7. -- **Block**: The key and value cache data in vLLM are split into - blocks.
Each block stores data for a fixed number(`BLOCK_SIZE`) - of tokens at one head. Each block may contain only a portion of the - whole context tokens. For example, if the block size is 16 and the - head size is 128, then for one head, one block can store 16 * 128 = - 2048 elements. -- **Warp**: A warp is a group of 32 threads(`WARP_SIZE`) that - execute simultaneously on a stream multiprocessor (SM). In this - kernel, each warp processes the calculation between one query token - and key tokens of one entire block at a time (it may process multiple - blocks in multiple iterations). For example, if there are 4 warps and - 6 blocks for one context, the assignment would be like warp 0 handles - the 0th, 4th blocks, warp 1 handles the 1st, 5th blocks, warp 2 - handles the 2nd block and warp 3 handles the 3rd block. -- **Thread block**: A thread block is a group of - threads(`NUM_THREADS`) that can access the same shared memory. - Each thread block contains multiple warps(`NUM_WARPS`), and in - this kernel, each thread block processes the calculation between one - query token and key tokens of a whole context. -- **Grid**: A grid is a collection of thread blocks and defines the - shape of the collection. In this kernel, the shape is - `(num_heads, num_seqs, max_num_partitions)`. Therefore, each thread - block only handles the calculation for one head, one sequence, and - one partition. - -## Query - -- This section will introduce how query data is stored in memory and - fetched by each thread. As mentioned above, each thread group fetches - one query token data, while each thread itself only handles a part of - one query token data. Within each warp, every thread group will fetch - the same query token data, but will multiply it with different key - token data. 
- - ```cpp - const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; - ``` - - :::{figure} ../../assets/kernel/query.png - :align: center - :alt: query - :width: 70% - - Query data of one token at one head - ::: - -- Each thread defines its own `q_ptr` which points to the assigned - query token data on global memory. For example, if `VEC_SIZE` is 4 - and `HEAD_SIZE` is 128, the `q_ptr` points to data that contains - a total of 128 elements divided into 128 / 4 = 32 vecs. - - :::{figure} ../../assets/kernel/q_vecs.png - :align: center - :alt: q_vecs - :width: 70% - - `q_vecs` for one thread group - ::: - - ```cpp - __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; - ``` - -- Next, we need to read the global memory data pointed to by `q_ptr` - into shared memory as `q_vecs`. It is important to note that the - vecs of each thread are assigned to a different row. For example, if the - `THREAD_GROUP_SIZE` is 2, thread 0 will handle the 0th row vecs, - while thread 1 handles the 1st row vecs. By reading the query data in - this way, neighboring threads like thread 0 and thread 1 can read - neighboring memory, achieving memory coalescing to improve - performance. - -## Key - -- Similar to the "Query" section, this section introduces memory layout - and assignment for keys. While each thread group only handles one - query token per kernel run, it may handle multiple key tokens across - multiple iterations. Meanwhile, each warp will process multiple blocks - of key tokens in multiple iterations, ensuring that all context - tokens are processed by the entire thread group after the kernel run. - In this context, "handle" refers to performing the dot multiplication - between query data and key data. - - ```cpp - const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride - + kv_head_idx * kv_head_stride - + physical_block_offset * x; - ``` - -- Unlike `q_ptr`, `k_ptr` in each thread will point to a different - key token at different iterations.
As shown above, `k_ptr` - points to the key token data in `k_cache` at the assigned block, - assigned head and assigned token. - - :::{figure} ../../assets/kernel/key.png - :align: center - :alt: key - :width: 70% - - Key data of all context tokens at one head - ::: - -- The diagram above illustrates the memory layout for key data. It - assumes that the `BLOCK_SIZE` is 16, `HEAD_SIZE` is 128, `x` is - 8, `THREAD_GROUP_SIZE` is 2, and there are a total of 4 warps. Each - rectangle represents all the elements for one key token at one head, - which will be processed by one thread group. The left half shows the - total 16 blocks of key token data for warp 0, while the right half - represents the remaining key token data for other warps or - iterations. Inside each rectangle, there are a total of 32 vecs (128 - elements for one token) that will be processed by 2 threads (one - thread group) separately. - - :::{figure} ../../assets/kernel/k_vecs.png - :align: center - :alt: k_vecs - :width: 70% - - `k_vecs` for one thread - ::: - - ```cpp - K_vec k_vecs[NUM_VECS_PER_THREAD] - ``` - -- Next, we need to read the key token data from `k_ptr` and store - them on register memory as `k_vecs`. We use register memory for - `k_vecs` because it will only be accessed by one thread once, - whereas `q_vecs` will be accessed by multiple threads multiple - times. Each `k_vecs` will contain multiple vectors for later - calculation. Each vec will be set at each inner iteration. The - assignment of vecs allows neighboring threads in a warp to read - neighboring memory together, which again promotes memory - coalescing. For instance, thread 0 will read vec 0, while thread 1 - will read vec 1. In the next inner loop, thread 0 will read vec 2, - while thread 1 will read vec 3, and so on. - -- You may still be a little confused about the overall flow. Don't - worry, please keep reading the next "QK" section.
It will illustrate - the query and key calculation flow in a clearer and higher-level - manner. - -## QK - -- As shown in the pseudocode below, before the entire for loop block, we - fetch the query data for one token and store it in `q_vecs`. Then, - in the outer for loop, we iterate through different `k_ptrs` that - point to different tokens and prepare the `k_vecs` in the inner for - loop. Finally, we perform the dot multiplication between the - `q_vecs` and each `k_vecs`. - - ```cpp - q_vecs = ... - for ... { - k_ptr = ... - for ... { - k_vecs[i] = ... - } - ... - float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs); - } - ``` - -- As mentioned before, each thread only fetches part of the - query and key token data at a time. However, a reduction across the - thread group happens inside `Qk_dot<>::dot`. So the `qk` - returned here is not just the dot product of part of the query and key - token data, but the full result over the entire query and - key token data. - -- For example, if the value of `HEAD_SIZE` is 128 and - `THREAD_GROUP_SIZE` is 2, each thread's `k_vecs` will contain - a total of 64 elements. However, the returned `qk` is actually the - result of dot multiplication between 128 query elements and 128 key - elements. If you want to learn more about the details of the dot - multiplication and reduction, you may refer to the implementation of - `Qk_dot<>::dot`. However, for the sake of simplicity, I will not - cover it in this document. - -## Softmax - -- Next, we need to calculate the normalized softmax for all `qk`s, - as shown below, where each $x$ represents a `qk`. To do this, - we must obtain the reduced value of `qk_max`($m(x)$) and - the `exp_sum`($\ell(x)$) of all `qk`s. The reduction - should be performed across the entire thread block, encompassing - results between the query token and all context key tokens.
- - :::{math} - :nowrap: true - - \begin{gather*} - m(x):=\max _i \quad x_i \\ \quad f(x):=\left[\begin{array}{lll}e^{x_1-m(x)} & \ldots & e^{x_B-m(x)}\end{array}\right]\\ \quad \ell(x):=\sum_i f(x)_i \\ - \quad \operatorname{softmax}(x):=\frac{f(x)}{\ell(x)} - \end{gather*} - ::: - -### `qk_max` and `logits` - -- Just right after we get the `qk` result, we can set the temporary - `logits` result with `qk` (In the end, the `logits` should - store the normalized softmax result). Also we can compare and collect - the `qk_max` for all `qk`s that are calculated by current - thread group. - - ```cpp - if (thread_group_offset == 0) { - const bool mask = token_idx >= context_len; - logits[token_idx - start_token_idx] = mask ? 0.f : qk; - qk_max = mask ? qk_max : fmaxf(qk_max, qk); - } - ``` - -- Please note that the `logits` here is on shared memory, so each - thread group will set the fields for its own assigned context tokens. - Overall, the size of logits should be number of context tokens. - - ```cpp - for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { - qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); - } - - if (lane == 0) { - red_smem[warp_idx] = qk_max; - } - ``` - -- Then we need to get the reduced `qk_max` across each warp. The main - idea is to make threads in warp to communicate with each other and - get the final max `qk` . - - ```cpp - for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); - } - qk_max = VLLM_SHFL_SYNC(qk_max, 0); - ``` - -- Finally, we can get the reduced `qk_max` from whole thread block by - compare the `qk_max` from all warps in this thread block. Then we - need to broadcast the final result to each thread. - -### `exp_sum` - -- Similar to `qk_max`, we need to get the reduced sum value from the - entire thread block too. 
- - ```cpp - for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { - float val = __expf(logits[i] - qk_max); - logits[i] = val; - exp_sum += val; - } - ... - exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum); - ``` - -- Firstly, sum all exp values from each thread group, and meanwhile, - convert each entry of `logits` from `qk` to `exp(qk - qk_max)`. - Please note, the `qk_max` here is already the max `qk` across the - whole thread block. And then we can do reduction for `exp_sum` - across whole thread block just like the `qk_max`. - - ```cpp - const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); - for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { - logits[i] *= inv_sum; - } - ``` - -- Finally, with the reduced `qk_max` and `exp_sum`, we can obtain - the final normalized softmax result as `logits`. This `logits` - variable will be used for dot multiplication with the value data in - later steps. Now, it should store the normalized softmax result of - `qk` for all assigned context tokens. - -## Value - -:::{figure} ../../assets/kernel/value.png -:align: center -:alt: value -:width: 70% - -Value data of all context tokens at one head -::: - -:::{figure} ../../assets/kernel/logits_vec.png -:align: center -:alt: logits_vec -:width: 50% - -`logits_vec` for one thread -::: - -:::{figure} ../../assets/kernel/v_vec.png -:align: center -:alt: v_vec -:width: 70% - -List of `v_vec` for one thread -::: - -- Now we need to retrieve the value data and perform dot multiplication - with `logits`. Unlike query and key, there is no thread group - concept for value data. As shown in diagram, different from key token - memory layout, elements from the same column correspond to the same - value token. For one block of value data, there are `HEAD_SIZE` of - rows and `BLOCK_SIZE` of columns that are split into multiple - `v_vecs`. - -- Each thread always fetches `V_VEC_SIZE` elements from the same - `V_VEC_SIZE` of tokens at a time. 
As a result, a single thread - retrieves multiple `v_vec`s from different rows and the same - columns through multiple inner iterations. For each `v_vec`, it - needs to be dot multiplied with the corresponding `logits_vec`, - which is also `V_VEC_SIZE` elements from `logits`. Overall, with - multiple inner iterations, each warp will process one block of value - tokens. And with multiple outer iterations, the whole context value - tokens are processed. - - ```cpp - float accs[NUM_ROWS_PER_THREAD]; - for ... { // Iteration over different blocks. - logits_vec = ... - for ... { // Iteration over different rows. - v_vec = ... - ... - accs[i] += dot(logits_vec, v_vec); - } - } - ``` - -- As shown in the above pseudo code, in the outer loop, similar to - `k_ptr`, `logits_vec` iterates over different blocks and reads - `V_VEC_SIZE` elements from `logits`. In the inner loop, each - thread reads `V_VEC_SIZE` elements from the same tokens as a - `v_vec` and performs dot multiplication. It is important to note - that in each inner iteration, the thread fetches different head - position elements for the same tokens. The dot result is then - accumulated in `accs`. Therefore, each entry of `accs` is mapped - to a head position assigned to the current thread. - -- For example, if `BLOCK_SIZE` is 16 and `V_VEC_SIZE` is 8, each - thread fetches 8 value elements for 8 tokens at a time. Each element - is from different tokens at the same head position. If `HEAD_SIZE` - is 128 and `WARP_SIZE` is 32, for each inner loop, a warp needs to - fetch `WARP_SIZE * V_VEC_SIZE = 256` elements. This means there are - a total of 128 * 16 / 256 = 8 inner iterations for a warp to handle - a whole block of value tokens. And each `accs` in each thread - contains 8 elements that are accumulated at 8 different head positions. - For thread 0, the `accs` variable will have 8 elements, which - are the 0th, 32nd, …, 224th elements of a value head that are accumulated - from all assigned 8 tokens.
- -## LV - -- Now, we need to perform reduction for `accs` within each warp. This - process allows each thread to accumulate the `accs` for the - assigned head positions of all tokens in one block. - - ```cpp - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - float acc = accs[i]; - for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { - acc += VLLM_SHFL_XOR_SYNC(acc, mask); - } - accs[i] = acc; - } - ``` - -- Next, we perform reduction for `accs` across all warps, allowing - each thread to have the accumulation of `accs` for the assigned - head positions of all context tokens. Please note that each `accs` - in every thread only stores the accumulation for a portion of - elements of the entire head for all context tokens. However, overall, - all results for output have been calculated but are just stored in - different thread register memory. - - ```cpp - float* out_smem = reinterpret_cast<float*>(shared_mem); - for (int i = NUM_WARPS; i > 1; i /= 2) { - // Upper warps write to shared memory. - ... - float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - ... - dst[row_idx] = accs[i]; - } - - // Lower warps update the output. - const float* src = &out_smem[warp_idx * HEAD_SIZE]; - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - ... - accs[i] += src[row_idx]; - } - - // Write out the accs. - } - ``` - -## Output - -- Now we can write all of calculated result from local register memory - to final output global memory. - - ```cpp - scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE - + partition_idx * HEAD_SIZE; - ``` - -- First, we need to define the `out_ptr` variable, which points to - the start address of the assigned sequence and assigned head. 
- - ```cpp - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { - from_float(*(out_ptr + row_idx), accs[i]); - } - } - ``` - -- Finally, we need to iterate over different assigned head positions - and write out the corresponding accumulated result based on the - `out_ptr`. diff --git a/docs/source/features/compatibility_matrix.md b/docs/source/features/compatibility_matrix.md deleted file mode 100644 index 8865d26deaed..000000000000 --- a/docs/source/features/compatibility_matrix.md +++ /dev/null @@ -1,476 +0,0 @@ -(compatibility-matrix)= - -# Compatibility Matrix - -The tables below show mutually exclusive features and the support on some hardware. - -The symbols used have the following meanings: - -- โœ… = Full compatibility -- ๐ŸŸ  = Partial compatibility -- โŒ = No compatibility - -:::{note} -Check the โŒ or ๐ŸŸ  with links to see tracking issue for unsupported feature/hardware combination. 
-::: - -## Feature x Feature - -:::{raw} html -<style> - /* Make smaller to try to improve readability */ - td { - font-size: 0.8rem; - text-align: center; - } - - th { - text-align: center; - font-size: 0.8rem; - } -</style> -::: - -:::{list-table} -:header-rows: 1 -:stub-columns: 1 -:widths: auto -:class: vertical-table-header - -- * Feature - * [CP](#chunked-prefill) - * [APC](#automatic-prefix-caching) - * [LoRA](#lora-adapter) - * <abbr title="Prompt Adapter">prmpt adptr</abbr> - * [SD](#spec-decode) - * CUDA graph - * <abbr title="Pooling Models">pooling</abbr> - * <abbr title="Encoder-Decoder Models">enc-dec</abbr> - * <abbr title="Logprobs">logP</abbr> - * <abbr title="Prompt Logprobs">prmpt logP</abbr> - * <abbr title="Async Output Processing">async output</abbr> - * multi-step - * <abbr title="Multimodal Inputs">mm</abbr> - * best-of - * beam-search - * <abbr title="Guided Decoding">guided dec</abbr> -- * [CP](#chunked-prefill) - * โœ… - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * -- * [APC](#automatic-prefix-caching) - * โœ… - * โœ… - * - * - * - * - * - * - * - * - * - * - * - * - * - * -- * [LoRA](#lora-adapter) - * โœ… - * โœ… - * โœ… - * - * - * - * - * - * - * - * - * - * - * - * - * -- * <abbr title="Prompt Adapter">prmpt adptr</abbr> - * โœ… - * โœ… - * โœ… - * โœ… - * - * - * - * - * - * - * - * - * - * - * - * -- * [SD](#spec-decode) - * โœ… - * โœ… - * โŒ - * โœ… - * โœ… - * - * - * - * - * - * - * - * - * - * - * -- * CUDA graph - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * - * - * - * - * - * - * - * - * - * -- * <abbr title="Pooling Models">pooling</abbr> - * โŒ - * โŒ - * โŒ - * โŒ - * โŒ - * โŒ - * โœ… - * - * - * - * - * - * - * - * - * -- * <abbr title="Encoder-Decoder Models">enc-dec</abbr> - * โŒ - * [โŒ](gh-issue:7366) - * โŒ - * โŒ - * [โŒ](gh-issue:7366) - * โœ… - * โœ… - * โœ… - * - * - * - * - * - * - * - * -- * <abbr title="Logprobs">logP</abbr> - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * 
โŒ - * โœ… - * โœ… - * - * - * - * - * - * - * -- * <abbr title="Prompt Logprobs">prmpt logP</abbr> - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * โŒ - * โœ… - * โœ… - * โœ… - * - * - * - * - * - * -- * <abbr title="Async Output Processing">async output</abbr> - * โœ… - * โœ… - * โœ… - * โœ… - * โŒ - * โœ… - * โŒ - * โŒ - * โœ… - * โœ… - * โœ… - * - * - * - * - * -- * multi-step - * โŒ - * โœ… - * โŒ - * โœ… - * โŒ - * โœ… - * โŒ - * โŒ - * โœ… - * โœ… - * โœ… - * โœ… - * - * - * - * -- * <abbr title="Multimodal Inputs">mm</abbr> - * โœ… - * [๐ŸŸ ](gh-pr:8348) - * [๐ŸŸ ](gh-pr:4194) - * โ” - * โ” - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * โ” - * โœ… - * - * - * -- * best-of - * โœ… - * โœ… - * โœ… - * โœ… - * [โŒ](gh-issue:6137) - * โœ… - * โŒ - * โœ… - * โœ… - * โœ… - * โ” - * [โŒ](gh-issue:7968) - * โœ… - * โœ… - * - * -- * beam-search - * โœ… - * โœ… - * โœ… - * โœ… - * [โŒ](gh-issue:6137) - * โœ… - * โŒ - * โœ… - * โœ… - * โœ… - * โ” - * [โŒ](gh-issue:7968) - * โ” - * โœ… - * โœ… - * -- * <abbr title="Guided Decoding">guided dec</abbr> - * โœ… - * โœ… - * โ” - * โ” - * [โŒ](gh-issue:11484) - * โœ… - * โŒ - * โ” - * โœ… - * โœ… - * โœ… - * [โŒ](gh-issue:9893) - * โ” - * โœ… - * โœ… - * โœ… -::: - -(feature-x-hardware)= - -## Feature x Hardware - -:::{list-table} -:header-rows: 1 -:stub-columns: 1 -:widths: auto - -- * Feature - * Volta - * Turing - * Ampere - * Ada - * Hopper - * CPU - * AMD -- * [CP](#chunked-prefill) - * [โŒ](gh-issue:2729) - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… -- * [APC](#automatic-prefix-caching) - * [โŒ](gh-issue:3687) - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… -- * [LoRA](#lora-adapter) - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… -- * <abbr title="Prompt Adapter">prmpt adptr</abbr> - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * [โŒ](gh-issue:8475) - * โœ… -- * [SD](#spec-decode) - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… -- * CUDA graph - * โœ… - * โœ… - * 
โœ… - * โœ… - * โœ… - * โŒ - * โœ… -- * <abbr title="Pooling Models">pooling</abbr> - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * โ” -- * <abbr title="Encoder-Decoder Models">enc-dec</abbr> - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * โŒ -- * <abbr title="Multimodal Inputs">mm</abbr> - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… -- * <abbr title="Logprobs">logP</abbr> - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… -- * <abbr title="Prompt Logprobs">prmpt logP</abbr> - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… -- * <abbr title="Async Output Processing">async output</abbr> - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * โŒ - * โŒ -- * multi-step - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * [โŒ](gh-issue:8477) - * โœ… -- * best-of - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… -- * beam-search - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… -- * <abbr title="Guided Decoding">guided dec</abbr> - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… - * โœ… -::: diff --git a/docs/source/features/quantization/index.md b/docs/source/features/quantization/index.md deleted file mode 100644 index 7ad46b7094ee..000000000000 --- a/docs/source/features/quantization/index.md +++ /dev/null @@ -1,24 +0,0 @@ -(quantization-index)= - -# Quantization - -Quantization trades off model precision for smaller memory footprint, allowing large models to be run on a wider range of devices. 
- -:::{toctree} -:caption: Contents -:maxdepth: 1 - -supported_hardware -auto_awq -bnb -bitblas -gguf -gptqmodel -int4 -int8 -fp8 -modelopt -quark -quantized_kvcache -torchao -::: diff --git a/docs/source/features/quantization/supported_hardware.md b/docs/source/features/quantization/supported_hardware.md deleted file mode 100644 index f8af1ba60b12..000000000000 --- a/docs/source/features/quantization/supported_hardware.md +++ /dev/null @@ -1,153 +0,0 @@ -(quantization-supported-hardware)= - -# Supported Hardware - -The table below shows the compatibility of various quantization implementations with different hardware platforms in vLLM: - -:::{list-table} -:header-rows: 1 -:widths: 20 8 8 8 8 8 8 8 8 8 8 - -- * Implementation - * Volta - * Turing - * Ampere - * Ada - * Hopper - * AMD GPU - * Intel GPU - * x86 CPU - * AWS Inferentia - * Google TPU -- * AWQ - * โŒ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โŒ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โŒ - * โŒ -- * GPTQ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โŒ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โŒ - * โŒ -- * Marlin (GPTQ/AWQ/FP8) - * โŒ - * โŒ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โŒ - * โŒ - * โŒ - * โŒ - * โŒ -- * INT8 (W8A8) - * โŒ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โŒ - * โŒ - * โœ…๏ธŽ - * โŒ - * โœ…๏ธŽ -- * FP8 (W8A8) - * โŒ - * โŒ - * โŒ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โŒ - * โŒ - * โŒ - * โŒ -- * BitBLAS (GPTQ) - * โœ…๏ธŽ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โŒ - * โŒ - * โŒ - * โŒ - * โŒ -- * AQLM - * โœ…๏ธŽ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โŒ - * โŒ - * โŒ - * โŒ - * โŒ -- * bitsandbytes - * โœ…๏ธŽ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โŒ - * โŒ - * โŒ - * โŒ - * โŒ -- * DeepSpeedFP - * โœ…๏ธŽ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โŒ - * โŒ - * โŒ - * โŒ - * โŒ -- * GGUF - * โœ…๏ธŽ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โŒ - * โŒ - * โŒ - * โŒ -- * modelopt - * โœ…๏ธŽ - * 
โœ…๏ธŽ - * โœ…๏ธŽ - * โœ…๏ธŽ - * โœ…๏ธŽ๏ธŽ - * โŒ - * โŒ - * โŒ - * โŒ - * โŒ -::: - -- Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0. -- โœ…๏ธŽ indicates that the quantization method is supported on the specified hardware. -- โŒ indicates that the quantization method is not supported on the specified hardware. - -:::{note} -This compatibility chart is subject to change as vLLM continues to evolve and expand its support for different hardware platforms and quantization methods. - -For the most up-to-date information on hardware support and quantization methods, please refer to <gh-dir:vllm/model_executor/layers/quantization> or consult with the vLLM development team. -::: diff --git a/docs/source/generate_examples.py b/docs/source/generate_examples.py deleted file mode 100644 index f77dbefb0a01..000000000000 --- a/docs/source/generate_examples.py +++ /dev/null @@ -1,244 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import itertools -import re -from dataclasses import dataclass, field -from pathlib import Path - -ROOT_DIR = Path(__file__).parent.parent.parent.resolve() -ROOT_DIR_RELATIVE = '../../../..' -EXAMPLE_DIR = ROOT_DIR / "examples" -EXAMPLE_DOC_DIR = ROOT_DIR / "docs/source/getting_started/examples" - - -def fix_case(text: str) -> str: - subs = { - "api": "API", - "cli": "CLI", - "cpu": "CPU", - "llm": "LLM", - "mae": "MAE", - "tpu": "TPU", - "aqlm": "AQLM", - "gguf": "GGUF", - "lora": "LoRA", - "rlhf": "RLHF", - "vllm": "vLLM", - "openai": "OpenAI", - "lmcache": "LMCache", - "multilora": "MultiLoRA", - "mlpspeculator": "MLPSpeculator", - r"fp\d+": lambda x: x.group(0).upper(), # e.g. fp16, fp32 - r"int\d+": lambda x: x.group(0).upper(), # e.g. int8, int16 - } - for pattern, repl in subs.items(): - text = re.sub(rf'\b{pattern}\b', repl, text, flags=re.IGNORECASE) - return text - - -@dataclass -class Index: - """ - Index class to generate a structured document index. 
- - Attributes: - path (Path): The path save the index file to. - title (str): The title of the index. - description (str): A brief description of the index. - caption (str): An optional caption for the table of contents. - maxdepth (int): The maximum depth of the table of contents. Defaults to 1. - documents (list[str]): A list of document paths to include in the index. Defaults to an empty list. - - Methods: - generate() -> str: - Generates the index content as a string in the specified format. - """ # noqa: E501 - path: Path - title: str - description: str - caption: str - maxdepth: int = 1 - documents: list[str] = field(default_factory=list) - - def generate(self) -> str: - content = f"# {self.title}\n\n{self.description}\n\n" - content += ":::{toctree}\n" - content += f":caption: {self.caption}\n:maxdepth: {self.maxdepth}\n" - content += "\n".join(self.documents) + "\n:::\n" - return content - - -@dataclass -class Example: - """ - Example class for generating documentation content from a given path. - - Attributes: - path (Path): The path to the main directory or file. - category (str): The category of the document. - main_file (Path): The main file in the directory. - other_files (list[Path]): list of other files in the directory. - title (str): The title of the document. - - Methods: - __post_init__(): Initializes the main_file, other_files, and title attributes. - determine_main_file() -> Path: Determines the main file in the given path. - determine_other_files() -> list[Path]: Determines other files in the directory excluding the main file. - determine_title() -> str: Determines the title of the document. - generate() -> str: Generates the documentation content. 
- """ # noqa: E501 - path: Path - category: str = None - main_file: Path = field(init=False) - other_files: list[Path] = field(init=False) - title: str = field(init=False) - - def __post_init__(self): - self.main_file = self.determine_main_file() - self.other_files = self.determine_other_files() - self.title = self.determine_title() - - def determine_main_file(self) -> Path: - """ - Determines the main file in the given path. - If the path is a file, it returns the path itself. Otherwise, it searches - for Markdown files (*.md) in the directory and returns the first one found. - Returns: - Path: The main file path, either the original path if it's a file or the first - Markdown file found in the directory. - Raises: - IndexError: If no Markdown files are found in the directory. - """ # noqa: E501 - return self.path if self.path.is_file() else list( - self.path.glob("*.md")).pop() - - def determine_other_files(self) -> list[Path]: - """ - Determine other files in the directory excluding the main file. - - This method checks if the given path is a file. If it is, it returns an empty list. - Otherwise, it recursively searches through the directory and returns a list of all - files that are not the main file. - - Returns: - list[Path]: A list of Path objects representing the other files in the directory. 
- """ # noqa: E501 - if self.path.is_file(): - return [] - is_other_file = lambda file: file.is_file() and file != self.main_file - return [file for file in self.path.rglob("*") if is_other_file(file)] - - def determine_title(self) -> str: - return fix_case(self.path.stem.replace("_", " ").title()) - - def generate(self) -> str: - # Convert the path to a relative path from __file__ - make_relative = lambda path: ROOT_DIR_RELATIVE / path.relative_to( - ROOT_DIR) - - content = f"Source <gh-file:{self.path.relative_to(ROOT_DIR)}>.\n\n" - include = "include" if self.main_file.suffix == ".md" else \ - "literalinclude" - if include == "literalinclude": - content += f"# {self.title}\n\n" - content += f":::{{{include}}} {make_relative(self.main_file)}\n" - if include == "literalinclude": - content += f":language: {self.main_file.suffix[1:]}\n" - content += ":::\n\n" - - if not self.other_files: - return content - - content += "## Example materials\n\n" - for file in sorted(self.other_files): - include = "include" if file.suffix == ".md" else "literalinclude" - content += f":::{{admonition}} {file.relative_to(self.path)}\n" - content += ":class: dropdown\n\n" - content += f":::{{{include}}} {make_relative(file)}\n:::\n" - content += ":::\n\n" - - return content - - -def generate_examples(): - # Create the EXAMPLE_DOC_DIR if it doesn't exist - if not EXAMPLE_DOC_DIR.exists(): - EXAMPLE_DOC_DIR.mkdir(parents=True) - - # Create empty indices - examples_index = Index( - path=EXAMPLE_DOC_DIR / "examples_index.md", - title="Examples", - description= - "A collection of examples demonstrating usage of vLLM.\nAll documented examples are autogenerated using <gh-file:docs/source/generate_examples.py> from examples found in <gh-file:examples>.", # noqa: E501 - caption="Examples", - maxdepth=2) - # Category indices stored in reverse order because they are inserted into - # examples_index.documents at index 0 in order - category_indices = { - "other": - Index( - path=EXAMPLE_DOC_DIR / 
"examples_other_index.md", - title="Other", - description= - "Other examples that don't strongly fit into the online or offline serving categories.", # noqa: E501 - caption="Examples", - ), - "online_serving": - Index( - path=EXAMPLE_DOC_DIR / "examples_online_serving_index.md", - title="Online Serving", - description= - "Online serving examples demonstrate how to use vLLM in an online setting, where the model is queried for predictions in real-time.", # noqa: E501 - caption="Examples", - ), - "offline_inference": - Index( - path=EXAMPLE_DOC_DIR / "examples_offline_inference_index.md", - title="Offline Inference", - description= - "Offline inference examples demonstrate how to use vLLM in an offline setting, where the model is queried for predictions in batches. We recommend starting with <project:basic.md>.", # noqa: E501 - caption="Examples", - ), - } - - examples = [] - glob_patterns = ["*.py", "*.md", "*.sh"] - # Find categorised examples - for category in category_indices: - category_dir = EXAMPLE_DIR / category - globs = [category_dir.glob(pattern) for pattern in glob_patterns] - for path in itertools.chain(*globs): - examples.append(Example(path, category)) - # Find examples in subdirectories - for path in category_dir.glob("*/*.md"): - examples.append(Example(path.parent, category)) - # Find uncategorised examples - globs = [EXAMPLE_DIR.glob(pattern) for pattern in glob_patterns] - for path in itertools.chain(*globs): - examples.append(Example(path)) - # Find examples in subdirectories - for path in EXAMPLE_DIR.glob("*/*.md"): - # Skip categorised examples - if path.parent.name in category_indices: - continue - examples.append(Example(path.parent)) - - # Generate the example documentation - for example in sorted(examples, key=lambda e: e.path.stem): - doc_path = EXAMPLE_DOC_DIR / f"{example.path.stem}.md" - with open(doc_path, "w+") as f: - f.write(example.generate()) - # Add the example to the appropriate index - index = 
category_indices.get(example.category, examples_index) - index.documents.append(example.path.stem) - - # Generate the index files - for category_index in category_indices.values(): - if category_index.documents: - examples_index.documents.insert(0, category_index.path.name) - with open(category_index.path, "w+") as f: - f.write(category_index.generate()) - - with open(examples_index.path, "w+") as f: - f.write(examples_index.generate()) diff --git a/docs/source/getting_started/installation.md b/docs/source/getting_started/installation.md deleted file mode 100644 index 44134bf01b76..000000000000 --- a/docs/source/getting_started/installation.md +++ /dev/null @@ -1,28 +0,0 @@ -(installation-index)= - -# Installation - -vLLM supports the following hardware platforms: - -:::{toctree} -:maxdepth: 1 -:hidden: - -installation/gpu -installation/cpu -installation/ai_accelerator -::: - -- <project:installation/gpu.md> - - NVIDIA CUDA - - AMD ROCm - - Intel XPU -- <project:installation/cpu.md> - - Intel/AMD x86 - - ARM AArch64 - - Apple silicon - - IBM Z (S390X) -- <project:installation/ai_accelerator.md> - - Google TPU - - Intel Gaudi - - AWS Neuron diff --git a/docs/source/getting_started/installation/ai_accelerator.md b/docs/source/getting_started/installation/ai_accelerator.md deleted file mode 100644 index 0a207af1a4c7..000000000000 --- a/docs/source/getting_started/installation/ai_accelerator.md +++ /dev/null @@ -1,299 +0,0 @@ -# Other AI accelerators - -vLLM is a Python library that supports the following AI accelerators. 
Select your AI accelerator type to see vendor specific instructions: - -:::::{tab-set} -:sync-group: device - -::::{tab-item} Google TPU -:selected: -:sync: tpu - -:::{include} ai_accelerator/tpu.inc.md -:start-after: "# Installation" -:end-before: "## Requirements" -::: - -:::: - -::::{tab-item} Intel Gaudi -:sync: hpu-gaudi - -:::{include} ai_accelerator/hpu-gaudi.inc.md -:start-after: "# Installation" -:end-before: "## Requirements" -::: - -:::: - -::::{tab-item} AWS Neuron -:sync: neuron - -:::{include} ai_accelerator/neuron.inc.md -:start-after: "# Installation" -:end-before: "## Requirements" -::: - -:::: - -::::: - -## Requirements - -:::::{tab-set} -:sync-group: device - -::::{tab-item} Google TPU -:sync: tpu - -:::{include} ai_accelerator/tpu.inc.md -:start-after: "## Requirements" -:end-before: "## Configure a new environment" -::: - -:::: - -::::{tab-item} Intel Gaudi -:sync: hpu-gaudi - -:::{include} ai_accelerator/hpu-gaudi.inc.md -:start-after: "## Requirements" -:end-before: "## Configure a new environment" -::: - -:::: - -::::{tab-item} AWS Neuron -:sync: neuron - -:::{include} ai_accelerator/neuron.inc.md -:start-after: "## Requirements" -:end-before: "## Configure a new environment" -::: - -:::: - -::::: - -## Configure a new environment - -:::::{tab-set} -:sync-group: device - -::::{tab-item} Google TPU -:sync: tpu - -:::{include} ai_accelerator/tpu.inc.md -:start-after: "## Configure a new environment" -:end-before: "## Set up using Python" -::: - -:::: - -::::{tab-item} Intel Gaudi -:sync: hpu-gaudi - -:::{include} ai_accelerator/hpu-gaudi.inc.md -:start-after: "## Configure a new environment" -:end-before: "## Set up using Python" -::: - -:::: - -::::{tab-item} AWS Neuron -:sync: neuron - -:::{include} ai_accelerator/neuron.inc.md -:start-after: "## Configure a new environment" -:end-before: "## Set up using Python" -::: - -:::: - -::::: - -## Set up using Python - -### Pre-built wheels - -:::::{tab-set} -:sync-group: device - -::::{tab-item} 
Google TPU -:sync: tpu - -:::{include} ai_accelerator/tpu.inc.md -:start-after: "### Pre-built wheels" -:end-before: "### Build wheel from source" -::: - -:::: - -::::{tab-item} Intel Gaudi -:sync: hpu-gaudi - -:::{include} ai_accelerator/hpu-gaudi.inc.md -:start-after: "### Pre-built wheels" -:end-before: "### Build wheel from source" -::: - -:::: - -::::{tab-item} AWS Neuron -:sync: neuron - -:::{include} ai_accelerator/neuron.inc.md -:start-after: "### Pre-built wheels" -:end-before: "### Build wheel from source" -::: - -:::: - -::::: - -### Build wheel from source - -:::::{tab-set} -:sync-group: device - -::::{tab-item} Google TPU -:sync: tpu - -:::{include} ai_accelerator/tpu.inc.md -:start-after: "### Build wheel from source" -:end-before: "## Set up using Docker" -::: - -:::: - -::::{tab-item} Intel Gaudi -:sync: hpu-gaudi - -:::{include} ai_accelerator/hpu-gaudi.inc.md -:start-after: "### Build wheel from source" -:end-before: "## Set up using Docker" -::: - -:::: - -::::{tab-item} AWS Neuron -:sync: neuron - -:::{include} ai_accelerator/neuron.inc.md -:start-after: "### Build wheel from source" -:end-before: "## Set up using Docker" -::: - -:::: - -::::: - -## Set up using Docker - -### Pre-built images - -:::::{tab-set} -:sync-group: device - -::::{tab-item} Google TPU -:sync: tpu - -:::{include} ai_accelerator/tpu.inc.md -:start-after: "### Pre-built images" -:end-before: "### Build image from source" -::: - -:::: - -::::{tab-item} Intel Gaudi -:sync: hpu-gaudi - -:::{include} ai_accelerator/hpu-gaudi.inc.md -:start-after: "### Pre-built images" -:end-before: "### Build image from source" -::: - -:::: - -::::{tab-item} AWS Neuron -:sync: neuron - -:::{include} ai_accelerator/neuron.inc.md -:start-after: "### Pre-built images" -:end-before: "### Build image from source" -::: - -:::: - -::::: - -### Build image from source - -:::::{tab-set} -:sync-group: device - -::::{tab-item} Google TPU -:sync: tpu - -:::{include} ai_accelerator/tpu.inc.md 
-:start-after: "### Build image from source" -:end-before: "## Extra information" -::: - -:::: - -::::{tab-item} Intel Gaudi -:sync: hpu-gaudi - -:::{include} ai_accelerator/hpu-gaudi.inc.md -:start-after: "### Build image from source" -:end-before: "## Extra information" -::: - -:::: - -::::{tab-item} AWS Neuron -:sync: neuron - -:::{include} ai_accelerator/neuron.inc.md -:start-after: "### Build image from source" -:end-before: "## Extra information" -::: - -:::: - -::::: - -## Extra information - -:::::{tab-set} -:sync-group: device - -::::{tab-item} Google TPU -:sync: tpu - -:::{include} ai_accelerator/tpu.inc.md -:start-after: "## Extra information" -::: - -:::: - -::::{tab-item} Intel Gaudi -:sync: hpu-gaudi - -:::{include} ai_accelerator/hpu-gaudi.inc.md -:start-after: "## Extra information" -::: - -:::: - -::::{tab-item} AWS Neuron -:sync: neuron - -:::{include} ai_accelerator/neuron.inc.md -:start-after: "## Extra information" -::: - -:::: - -::::: diff --git a/docs/source/getting_started/installation/cpu/arm.inc.md b/docs/source/getting_started/installation/cpu/arm.inc.md deleted file mode 100644 index e7d8d60630dc..000000000000 --- a/docs/source/getting_started/installation/cpu/arm.inc.md +++ /dev/null @@ -1,34 +0,0 @@ -# Installation - -vLLM has been adapted to work on ARM64 CPUs with NEON support, leveraging the CPU backend initially developed for the x86 platform. - -ARM CPU backend currently supports Float32, FP16 and BFloat16 datatypes. - -:::{attention} -There are no pre-built wheels or images for this device, so you must build vLLM from source. -::: - -## Requirements - -- OS: Linux -- Compiler: `gcc/g++ >= 12.3.0` (optional, recommended) -- Instruction Set Architecture (ISA): NEON support is required - -## Set up using Python - -### Pre-built wheels - -### Build wheel from source - -:::{include} cpu/build.inc.md -::: - -Testing has been conducted on AWS Graviton3 instances for compatibility. 
- -## Set up using Docker - -### Pre-built images - -### Build image from source - -## Extra information diff --git a/docs/source/getting_started/installation/cpu/x86.inc.md b/docs/source/getting_started/installation/cpu/x86.inc.md deleted file mode 100644 index 9ae2035db543..000000000000 --- a/docs/source/getting_started/installation/cpu/x86.inc.md +++ /dev/null @@ -1,41 +0,0 @@ -# Installation - -vLLM initially supports basic model inferencing and serving on the x86 CPU platform, with data types FP32, FP16 and BF16. - -:::{attention} -There are no pre-built wheels or images for this device, so you must build vLLM from source. -::: - -## Requirements - -- OS: Linux -- Compiler: `gcc/g++ >= 12.3.0` (optional, recommended) -- Instruction Set Architecture (ISA): AVX512 (optional, recommended) - -:::{tip} -[Intel Extension for PyTorch (IPEX)](https://github.com/intel/intel-extension-for-pytorch) extends PyTorch with up-to-date feature optimizations for an extra performance boost on Intel hardware.
-::: - -## Set up using Docker - -### Pre-built images - -See [https://gallery.ecr.aws/q9t5s3a7/vllm-cpu-release-repo](https://gallery.ecr.aws/q9t5s3a7/vllm-cpu-release-repo) - -### Build image from source - -## Extra information diff --git a/docs/source/getting_started/installation/gpu.md b/docs/source/getting_started/installation/gpu.md deleted file mode 100644 index 22db992354fb..000000000000 --- a/docs/source/getting_started/installation/gpu.md +++ /dev/null @@ -1,301 +0,0 @@ -# GPU - -vLLM is a Python library that supports the following GPU variants. Select your GPU type to see vendor specific instructions: - -:::::{tab-set} -:sync-group: device - -::::{tab-item} NVIDIA CUDA -:selected: -:sync: cuda - -:::{include} gpu/cuda.inc.md -:start-after: "# Installation" -:end-before: "## Requirements" -::: - -:::: - -::::{tab-item} AMD ROCm -:sync: rocm - -:::{include} gpu/rocm.inc.md -:start-after: "# Installation" -:end-before: "## Requirements" -::: - -:::: - -::::{tab-item} Intel XPU -:sync: xpu - -:::{include} gpu/xpu.inc.md -:start-after: "# Installation" -:end-before: "## Requirements" -::: - -:::: - -::::: - -## Requirements - -- OS: Linux -- Python: 3.9 -- 3.12 - -:::::{tab-set} -:sync-group: device - -::::{tab-item} NVIDIA CUDA -:sync: cuda - -:::{include} gpu/cuda.inc.md -:start-after: "## Requirements" -:end-before: "## Set up using Python" -::: - -:::: - -::::{tab-item} AMD ROCm -:sync: rocm - -:::{include} gpu/rocm.inc.md -:start-after: "## Requirements" -:end-before: "## Set up using Python" -::: - -:::: - -::::{tab-item} Intel XPU -:sync: xpu - -:::{include} gpu/xpu.inc.md -:start-after: "## Requirements" -:end-before: "## Set up using Python" -::: - -:::: - -::::: - -## Set up using Python - -### Create a new Python environment - -:::{include} python_env_setup.inc.md -::: - -:::::{tab-set} -:sync-group: device - -::::{tab-item} NVIDIA CUDA -:sync: cuda - -:::{include} gpu/cuda.inc.md -:start-after: "## Create a new Python environment" -:end-before: 
"### Pre-built wheels" -::: - -:::: - -::::{tab-item} AMD ROCm -:sync: rocm - -There is no extra information on creating a new Python environment for this device. - -:::: - -::::{tab-item} Intel XPU -:sync: xpu - -There is no extra information on creating a new Python environment for this device. - -:::: - -::::: - -### Pre-built wheels - -:::::{tab-set} -:sync-group: device - -::::{tab-item} NVIDIA CUDA -:sync: cuda - -:::{include} gpu/cuda.inc.md -:start-after: "### Pre-built wheels" -:end-before: "### Build wheel from source" -::: - -:::: - -::::{tab-item} AMD ROCm -:sync: rocm - -:::{include} gpu/rocm.inc.md -:start-after: "### Pre-built wheels" -:end-before: "### Build wheel from source" -::: - -:::: - -::::{tab-item} Intel XPU -:sync: xpu - -:::{include} gpu/xpu.inc.md -:start-after: "### Pre-built wheels" -:end-before: "### Build wheel from source" -::: - -:::: - -::::: - -(build-from-source)= - -### Build wheel from source - -:::::{tab-set} -:sync-group: device - -::::{tab-item} NVIDIA CUDA -:sync: cuda - -:::{include} gpu/cuda.inc.md -:start-after: "### Build wheel from source" -:end-before: "## Set up using Docker" -::: - -:::: - -::::{tab-item} AMD ROCm -:sync: rocm - -:::{include} gpu/rocm.inc.md -:start-after: "### Build wheel from source" -:end-before: "## Set up using Docker" -::: - -:::: - -::::{tab-item} Intel XPU -:sync: xpu - -:::{include} gpu/xpu.inc.md -:start-after: "### Build wheel from source" -:end-before: "## Set up using Docker" -::: - -:::: - -::::: - -## Set up using Docker - -### Pre-built images - -:::::{tab-set} -:sync-group: device - -::::{tab-item} NVIDIA CUDA -:sync: cuda - -:::{include} gpu/cuda.inc.md -:start-after: "### Pre-built images" -:end-before: "### Build image from source" -::: - -:::: - -::::{tab-item} AMD ROCm -:sync: rocm - -:::{include} gpu/rocm.inc.md -:start-after: "### Pre-built images" -:end-before: "### Build image from source" -::: - -:::: - -::::{tab-item} Intel XPU -:sync: xpu - -:::{include} gpu/xpu.inc.md 
-:start-after: "### Pre-built images" -:end-before: "### Build image from source" -::: - -:::: - -::::: - -### Build image from source - -:::::{tab-set} -:sync-group: device - -::::{tab-item} NVIDIA CUDA -:sync: cuda - -:::{include} gpu/cuda.inc.md -:start-after: "### Build image from source" -:end-before: "## Supported features" -::: - -:::: - -::::{tab-item} AMD ROCm -:sync: rocm - -:::{include} gpu/rocm.inc.md -:start-after: "### Build image from source" -:end-before: "## Supported features" -::: - -:::: - -::::{tab-item} Intel XPU -:sync: xpu - -:::{include} gpu/xpu.inc.md -:start-after: "### Build image from source" -:end-before: "## Supported features" -::: - -:::: - -::::: - -## Supported features - -:::::{tab-set} -:sync-group: device - -::::{tab-item} NVIDIA CUDA -:sync: cuda - -:::{include} gpu/cuda.inc.md -:start-after: "## Supported features" -::: - -:::: - -::::{tab-item} AMD ROCm -:sync: rocm - -:::{include} gpu/rocm.inc.md -:start-after: "## Supported features" -::: - -:::: - -::::{tab-item} Intel XPU -:sync: xpu - -:::{include} gpu/xpu.inc.md -:start-after: "## Supported features" -::: - -:::: - -::::: diff --git a/docs/source/getting_started/installation/python_env_setup.inc.md b/docs/source/getting_started/installation/python_env_setup.inc.md deleted file mode 100644 index 00b61ea5c826..000000000000 --- a/docs/source/getting_started/installation/python_env_setup.inc.md +++ /dev/null @@ -1,19 +0,0 @@ -You can create a new Python environment using [conda](https://docs.conda.io/projects/conda/en/stable/user-guide/getting-started.html): - -```console -# (Recommended) Create a new conda environment. -conda create -n vllm python=3.12 -y -conda activate vllm -``` - -:::{note} -[PyTorch has deprecated the conda release channel](https://github.com/pytorch/pytorch/issues/138506). If you use `conda`, please only use it to create Python environment rather than installing packages. 
-::: - -Or you can create a new Python environment using [uv](https://docs.astral.sh/uv/), a very fast Python environment manager. Please follow the [documentation](https://docs.astral.sh/uv/#getting-started) to install `uv`. After installing `uv`, you can create a new Python environment using the following command: - -```console -# (Recommended) Create a new uv environment. Use `--seed` to install `pip` and `setuptools` in the environment. -uv venv --python 3.12 --seed -source .venv/bin/activate -``` diff --git a/docs/source/index.md b/docs/source/index.md deleted file mode 100644 index bbff7361f752..000000000000 --- a/docs/source/index.md +++ /dev/null @@ -1,215 +0,0 @@ -# Welcome to vLLM - -:::{figure} ./assets/logos/vllm-logo-text-light.png -:align: center -:alt: vLLM -:class: no-scaled-link -:width: 60% -::: - -:::{raw} html -<p style="text-align:center"> -<strong>Easy, fast, and cheap LLM serving for everyone -</strong> -</p> - -<p style="text-align:center"> -<script async defer src="https://buttons.github.io/buttons.js"></script> -<a class="github-button" href="https://github.com/vllm-project/vllm" data-show-count="true" data-size="large" aria-label="Star">Star</a> -<a class="github-button" href="https://github.com/vllm-project/vllm/subscription" data-icon="octicon-eye" data-size="large" aria-label="Watch">Watch</a> -<a class="github-button" href="https://github.com/vllm-project/vllm/fork" data-icon="octicon-repo-forked" data-size="large" aria-label="Fork">Fork</a> -</p> -::: - -vLLM is a fast and easy-to-use library for LLM inference and serving. - -Originally developed in the [Sky Computing Lab](https://sky.cs.berkeley.edu) at UC Berkeley, vLLM has evolved into a community-driven project with contributions from both academia and industry. 
- -vLLM is fast with: - -- State-of-the-art serving throughput -- Efficient management of attention key and value memory with [**PagedAttention**](https://blog.vllm.ai/2023/06/20/vllm.html) -- Continuous batching of incoming requests -- Fast model execution with CUDA/HIP graph -- Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), INT4, INT8, and FP8 -- Optimized CUDA kernels, including integration with FlashAttention and FlashInfer. -- Speculative decoding -- Chunked prefill - -vLLM is flexible and easy to use with: - -- Seamless integration with popular HuggingFace models -- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more -- Tensor parallelism and pipeline parallelism support for distributed inference -- Streaming outputs -- OpenAI-compatible API server -- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs, Gaudi® accelerators and GPUs, IBM Power CPUs, TPU, and AWS Trainium and Inferentia Accelerators. -- Prefix caching support -- Multi-LoRA support - -For more information, check out the following: - -- [vLLM announcing blog post](https://vllm.ai) (intro to PagedAttention) -- [vLLM paper](https://arxiv.org/abs/2309.06180) (SOSP 2023) -- [How continuous batching enables 23x throughput in LLM inference while reducing p50 latency](https://www.anyscale.com/blog/continuous-batching-llm-inference) by Cade Daniel et al. -- [vLLM Meetups](#meetups) - -## Documentation - -% How to start using vLLM? - -:::{toctree} -:caption: Getting Started -:maxdepth: 1 - -getting_started/installation -getting_started/quickstart -getting_started/examples/examples_index -getting_started/troubleshooting -getting_started/faq -getting_started/v1_user_guide - -::: - -% What does vLLM support?
- -:::{toctree} -:caption: Models -:maxdepth: 1 - -models/supported_models -models/generative_models -models/pooling_models -models/extensions/index -::: - -% Additional capabilities - -:::{toctree} -:caption: Features -:maxdepth: 1 - -features/quantization/index -features/lora -features/tool_calling -features/reasoning_outputs -features/structured_outputs -features/automatic_prefix_caching -features/disagg_prefill -features/spec_decode -features/compatibility_matrix -::: - -% Details about running vLLM - -:::{toctree} -:caption: Training -:maxdepth: 1 - -training/trl.md -training/rlhf.md - -::: - -:::{toctree} -:caption: Inference and Serving -:maxdepth: 1 - -serving/offline_inference -serving/openai_compatible_server -serving/multimodal_inputs -serving/distributed_serving -serving/metrics -serving/engine_args -serving/env_vars -serving/usage_stats -serving/integrations/index -::: - -% Scaling up vLLM for production - -:::{toctree} -:caption: Deployment -:maxdepth: 1 - -deployment/security -deployment/docker -deployment/k8s -deployment/nginx -deployment/frameworks/index -deployment/integrations/index -::: - -% Making the most out of vLLM - -:::{toctree} -:caption: Performance -:maxdepth: 1 - -performance/optimization -performance/benchmarks -::: - -% Explanation of vLLM internals - -:::{toctree} -:caption: Design Documents -:maxdepth: 2 - -design/arch_overview -design/huggingface_integration -design/plugin_system -design/kernel/paged_attention -design/mm_processing -design/automatic_prefix_caching -design/multiprocessing -::: - -:::{toctree} -:caption: V1 Design Documents -:maxdepth: 2 - -design/v1/torch_compile -design/v1/prefix_caching -design/v1/metrics -::: - -% How to contribute to the vLLM project - -:::{toctree} -:caption: Developer Guide -:maxdepth: 2 - -contributing/overview -contributing/deprecation_policy -contributing/profiling/profiling_index -contributing/dockerfile/dockerfile -contributing/model/index -contributing/vulnerability_management -::: - -% 
Technical API specifications - -:::{toctree} -:caption: API Reference -:maxdepth: 2 - -api/summary -api/vllm/vllm -::: - -% Latest news and acknowledgements - -:::{toctree} -:caption: Community -:maxdepth: 1 - -community/blog -community/meetups -community/sponsors -::: - -## Indices and tables - -- {ref}`genindex` -- {ref}`modindex` diff --git a/docs/source/models/extensions/index.md b/docs/source/models/extensions/index.md deleted file mode 100644 index cdcdaa5b3501..000000000000 --- a/docs/source/models/extensions/index.md +++ /dev/null @@ -1,9 +0,0 @@ -# Built-in Extensions - -:::{toctree} -:maxdepth: 1 - -runai_model_streamer -tensorizer -fastsafetensor -::: diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md deleted file mode 100644 index 287947feb3d0..000000000000 --- a/docs/source/models/supported_models.md +++ /dev/null @@ -1,1308 +0,0 @@ -(supported-models)= - -# Supported Models - -vLLM supports [generative](generative-models) and [pooling](pooling-models) models across various tasks. -If a model supports more than one task, you can set the task via the `--task` argument. - -For each task, we list the model architectures that have been implemented in vLLM. -Alongside each architecture, we include some popular models that use it. - -## Model Implementation - -### vLLM - -If vLLM natively supports a model, its implementation can be found in <gh-file:vllm/model_executor/models>. - -These models are what we list in <project:#supported-text-models> and <project:#supported-mm-models>. - -(transformers-backend)= - -### Transformers - -vLLM also supports model implementations that are available in Transformers. This does not currently work for all models, but most decoder language models are supported, and vision language model support is planned! 
- -To check if the modeling backend is Transformers, you can simply do this: - -```python -from vllm import LLM -llm = LLM(model=..., task="generate") # Name or path of your model -llm.apply_model(lambda model: print(type(model))) -``` - -If it is `TransformersForCausalLM` then it means it's based on Transformers! - -:::{tip} -You can force the use of `TransformersForCausalLM` by setting `model_impl="transformers"` for <project:#offline-inference> or `--model-impl transformers` for the <project:#openai-compatible-server>. -::: - -:::{note} -vLLM may not fully optimise the Transformers implementation so you may see degraded performance if comparing a native model to a Transformers model in vLLM. -::: - -#### Custom models - -If a model is supported natively by neither vLLM nor Transformers, it can still be used in vLLM! - -For a model to be compatible with the Transformers backend for vLLM it must: - -- be a Transformers compatible custom model (see [Transformers - Customizing models](https://huggingface.co/docs/transformers/en/custom_models)): - * The model directory must have the correct structure (e.g. `config.json` is present). - * `config.json` must contain `auto_map.AutoModel`. -- be a Transformers backend for vLLM compatible model (see <project:#writing-custom-models>): - * Customisation should be done in the base model (e.g. in `MyModel`, not `MyModelForCausalLM`). - -If the compatible model is: - -- on the Hugging Face Model Hub, simply set `trust_remote_code=True` for <project:#offline-inference> or `--trust-remote-code` for the <project:#openai-compatible-server>. -- in a local directory, simply pass the directory path to `model=<MODEL_DIR>` for <project:#offline-inference> or `vllm serve <MODEL_DIR>` for the <project:#openai-compatible-server>. - -This means that, with the Transformers backend for vLLM, new models can be used before they are officially supported in Transformers or vLLM!
- -(writing-custom-models)= - -#### Writing custom models - -This section details the necessary modifications to make to a Transformers compatible custom model that make it compatible with the Transformers backend for vLLM. (We assume that a Transformers compatible custom model has already been created, see [Transformers - Customizing models](https://huggingface.co/docs/transformers/en/custom_models)). - -To make your model compatible with the Transformers backend, it needs: - -1. `kwargs` passed down through all modules from `MyModel` to `MyAttention`. -2. `MyAttention` must use `ALL_ATTENTION_FUNCTIONS` to call attention. -3. `MyModel` must contain `_supports_attention_backend = True`. - -```{code-block} python -:caption: modeling_my_model.py - -from transformers import PreTrainedModel -from torch import nn - -class MyAttention(nn.Module): - - def forward(self, hidden_states, **kwargs): - ... - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - **kwargs, - ) - ... - -class MyModel(PreTrainedModel): - _supports_attention_backend = True -``` - -Here is what happens in the background when this model is loaded: - -1. The config is loaded. -2. `MyModel` Python class is loaded from the `auto_map` in config, and we check that the model `is_backend_compatible()`. -3. `MyModel` is loaded into `TransformersForCausalLM` (see <gh-file:vllm/model_executor/models/transformers.py>) which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used. - -That's it! 
- -For your model to be compatible with vLLM's tensor parallel and/or pipeline parallel features, you must add `base_model_tp_plan` and/or `base_model_pp_plan` to your model's config class: - -```{code-block} python -:caption: configuration_my_model.py - -from transformers import PretrainedConfig - -class MyConfig(PretrainedConfig): - base_model_tp_plan = { - "layers.*.self_attn.k_proj": "colwise", - "layers.*.self_attn.v_proj": "colwise", - "layers.*.self_attn.o_proj": "rowwise", - "layers.*.mlp.gate_proj": "colwise", - "layers.*.mlp.up_proj": "colwise", - "layers.*.mlp.down_proj": "rowwise", - } - base_model_pp_plan = { - "embed_tokens": (["input_ids"], ["inputs_embeds"]), - "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), - "norm": (["hidden_states"], ["hidden_states"]), - } -``` - -- `base_model_tp_plan` is a `dict` that maps fully qualified layer name patterns to tensor parallel styles (currently only `"colwise"` and `"rowwise"` are supported). -- `base_model_pp_plan` is a `dict` that maps direct child layer names to `tuple`s of `list`s of `str`s: - * You only need to do this for layers which are not present on all pipeline stages - * vLLM assumes that there will be only one `nn.ModuleList`, which is distributed across the pipeline stages - * The `list` in the first element of the `tuple` contains the names of the input arguments - * The `list` in the last element of the `tuple` contains the names of the variables the layer outputs to in your modeling code - -## Loading a Model - -### Hugging Face Hub - -By default, vLLM loads models from [Hugging Face (HF) Hub](https://huggingface.co/models). To change the download path for models, you can set the `HF_HOME` environment variable; for more details, refer to [their official documentation](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hfhome). 
- -To determine whether a given model is natively supported, you can check the `config.json` file inside the HF repository. -If the `"architectures"` field contains a model architecture listed below, then it should be natively supported. - -Models do not _need_ to be natively supported to be used in vLLM. -The [Transformers backend](#transformers-backend) enables you to run models directly using their Transformers implementation (or even remote code on the Hugging Face Model Hub!). - -:::{tip} -The easiest way to check if your model is really supported at runtime is to run the program below: - -```python -from vllm import LLM - -# For generative models (task=generate) only -llm = LLM(model=..., task="generate") # Name or path of your model -output = llm.generate("Hello, my name is") -print(output) - -# For pooling models (task={embed,classify,reward,score}) only -llm = LLM(model=..., task="embed") # Name or path of your model -output = llm.encode("Hello, my name is") -print(output) -``` - -If vLLM successfully returns text (for generative models) or hidden states (for pooling models), it indicates that your model is supported. -::: - -Otherwise, please refer to [Adding a New Model](#new-model) for instructions on how to implement your model in vLLM. -Alternatively, you can [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) to request vLLM support. 
- -#### Using a proxy - -Here are some tips for loading/downloading models from Hugging Face using a proxy: - -- Set the proxy globally for your session (or set it in the profile file): - -```shell -export http_proxy=http://your.proxy.server:port -export https_proxy=http://your.proxy.server:port -``` - -- Set the proxy for just the current command: - -```shell -https_proxy=http://your.proxy.server:port huggingface-cli download <model_name> - -# or use vllm cmd directly -https_proxy=http://your.proxy.server:port vllm serve <model_name> --disable-log-requests -``` - -- Set the proxy in Python interpreter: - -```python -import os - -os.environ['http_proxy'] = 'http://your.proxy.server:port' -os.environ['https_proxy'] = 'http://your.proxy.server:port' -``` - -### ModelScope - -To use models from [ModelScope](https://www.modelscope.cn) instead of Hugging Face Hub, set an environment variable: - -```shell -export VLLM_USE_MODELSCOPE=True -``` - -And use with `trust_remote_code=True`. - -```python -from vllm import LLM - -llm = LLM(model=..., revision=..., task=..., trust_remote_code=True) - -# For generative models (task=generate) only -output = llm.generate("Hello, my name is") -print(output) - -# For pooling models (task={embed,classify,reward,score}) only -output = llm.encode("Hello, my name is") -print(output) -``` - -(feature-status-legend)= - -## Feature Status Legend - -- โœ…๏ธŽ indicates that the feature is supported for the model. - -- ๐Ÿšง indicates that the feature is planned but not yet supported for the model. - -- โš ๏ธ indicates that the feature is available but may have known issues or limitations. - -(supported-text-models)= - -## List of Text-only Language Models - -### Generative Models - -See [this page](#generative-models) for more information on how to use generative models. - -#### Text Generation - -Specified using `--task generate`. 
- -:::{list-table} -:widths: 25 25 50 5 5 -:header-rows: 1 - -- * Architecture - * Models - * Example HF Models - * [LoRA](#lora-adapter) - * [PP](#distributed-serving) -- * `AquilaForCausalLM` - * Aquila, Aquila2 - * `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `ArcticForCausalLM` - * Arctic - * `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. - * - * โœ…๏ธŽ -- * `BaiChuanForCausalLM` - * Baichuan2, Baichuan - * `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `BambaForCausalLM` - * Bamba - * `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` - * - * -- * `BloomForCausalLM` - * BLOOM, BLOOMZ, BLOOMChat - * `bigscience/bloom`, `bigscience/bloomz`, etc. - * - * โœ…๏ธŽ -- * `BartForConditionalGeneration` - * BART - * `facebook/bart-base`, `facebook/bart-large-cnn`, etc. - * - * -- * `ChatGLMModel`, `ChatGLMForConditionalGeneration` - * ChatGLM - * `THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `CohereForCausalLM`, `Cohere2ForCausalLM` - * Command-R - * `CohereForAI/c4ai-command-r-v01`, `CohereForAI/c4ai-command-r7b-12-2024`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `DbrxForCausalLM` - * DBRX - * `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. - * - * โœ…๏ธŽ -- * `DeciLMForCausalLM` - * DeciLM - * `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, etc. - * - * โœ…๏ธŽ -- * `DeepseekForCausalLM` - * DeepSeek - * `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat` etc. - * - * โœ…๏ธŽ -- * `DeepseekV2ForCausalLM` - * DeepSeek-V2 - * `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat` etc. - * - * โœ…๏ธŽ -- * `DeepseekV3ForCausalLM` - * DeepSeek-V3 - * `deepseek-ai/DeepSeek-V3-Base`, `deepseek-ai/DeepSeek-V3` etc. - * - * โœ…๏ธŽ -- * `ExaoneForCausalLM` - * EXAONE-3 - * `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. 
- * โœ…๏ธŽ - * โœ…๏ธŽ -- * `FalconForCausalLM` - * Falcon - * `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. - * - * โœ…๏ธŽ -- * `FalconMambaForCausalLM` - * FalconMamba - * `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `GemmaForCausalLM` - * Gemma - * `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `Gemma2ForCausalLM` - * Gemma 2 - * `google/gemma-2-9b`, `google/gemma-2-27b`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `Gemma3ForCausalLM` - * Gemma 3 - * `google/gemma-3-1b-it`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `GlmForCausalLM` - * GLM-4 - * `THUDM/glm-4-9b-chat-hf`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `Glm4ForCausalLM` - * GLM-4-0414 - * `THUDM/GLM-4-32B-0414`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `GPT2LMHeadModel` - * GPT-2 - * `gpt2`, `gpt2-xl`, etc. - * - * โœ…๏ธŽ -- * `GPTBigCodeForCausalLM` - * StarCoder, SantaCoder, WizardCoder - * `bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, `WizardLM/WizardCoder-15B-V1.0`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `GPTJForCausalLM` - * GPT-J - * `EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc. - * - * โœ…๏ธŽ -- * `GPTNeoXForCausalLM` - * GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM - * `EleutherAI/gpt-neox-20b`, `EleutherAI/pythia-12b`, `OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc. - * - * โœ…๏ธŽ -- * `GraniteForCausalLM` - * Granite 3.0, Granite 3.1, PowerLM - * `ibm-granite/granite-3.0-2b-base`, `ibm-granite/granite-3.1-8b-instruct`, `ibm/PowerLM-3b`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `GraniteMoeForCausalLM` - * Granite 3.0 MoE, PowerMoE - * `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `GraniteMoeHybridForCausalLM` - * Granite 4.0 MoE Hybrid - * `ibm-granite/granite-4.0-tiny-preview`, etc. 
- * โœ…๏ธŽ - * โœ…๏ธŽ -- * `GraniteMoeSharedForCausalLM` - * Granite MoE Shared - * `ibm-research/moe-7b-1b-active-shared-experts` (test model) - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `GritLM` - * GritLM - * `parasail-ai/GritLM-7B-vllm`. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `Grok1ModelForCausalLM` - * Grok1 - * `hpcai-tech/grok-1`. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `InternLMForCausalLM` - * InternLM - * `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `InternLM2ForCausalLM` - * InternLM2 - * `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `InternLM3ForCausalLM` - * InternLM3 - * `internlm/internlm3-8b-instruct`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `JAISLMHeadModel` - * Jais - * `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. - * - * โœ…๏ธŽ -- * `JambaForCausalLM` - * Jamba - * `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `LlamaForCausalLM` - * Llama 3.1, Llama 3, Llama 2, LLaMA, Yi - * `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `MambaForCausalLM` - * Mamba - * `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. - * - * โœ…๏ธŽ -- * `MiniCPMForCausalLM` - * MiniCPM - * `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `MiniCPM3ForCausalLM` - * MiniCPM3 - * `openbmb/MiniCPM3-4B`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `MistralForCausalLM` - * Mistral, Mistral-Instruct - * `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. 
- * โœ…๏ธŽ - * โœ…๏ธŽ -- * `MixtralForCausalLM` - * Mixtral-8x7B, Mixtral-8x7B-Instruct - * `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `MPTForCausalLM` - * MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter - * `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. - * - * โœ…๏ธŽ -- * `NemotronForCausalLM` - * Nemotron-3, Nemotron-4, Minitron - * `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `OLMoForCausalLM` - * OLMo - * `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. - * - * โœ…๏ธŽ -- * `OLMo2ForCausalLM` - * OLMo2 - * `allenai/OLMo2-7B-1124`, etc. - * - * โœ…๏ธŽ -- * `OLMoEForCausalLM` - * OLMoE - * `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `OPTForCausalLM` - * OPT, OPT-IML - * `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. - * - * โœ…๏ธŽ -- * `OrionForCausalLM` - * Orion - * `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. - * - * โœ…๏ธŽ -- * `PhiForCausalLM` - * Phi - * `microsoft/phi-1_5`, `microsoft/phi-2`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `Phi3ForCausalLM` - * Phi-4, Phi-3 - * `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `Phi3SmallForCausalLM` - * Phi-3-Small - * `microsoft/Phi-3-small-8k-instruct`, `microsoft/Phi-3-small-128k-instruct`, etc. - * - * โœ…๏ธŽ -- * `PhiMoEForCausalLM` - * Phi-3.5-MoE - * `microsoft/Phi-3.5-MoE-instruct`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `PersimmonForCausalLM` - * Persimmon - * `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. - * - * โœ…๏ธŽ -- * `Plamo2ForCausalLM` - * PLaMo2 - * `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. - * - * -- * `QWenLMHeadModel` - * Qwen - * `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. 
- * โœ…๏ธŽ - * โœ…๏ธŽ -- * `Qwen2ForCausalLM` - * QwQ, Qwen2 - * `Qwen/QwQ-32B-Preview`, `Qwen/Qwen2-7B-Instruct`, `Qwen/Qwen2-7B`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `Qwen2MoeForCausalLM` - * Qwen2MoE - * `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. - * - * โœ…๏ธŽ -- * `Qwen3ForCausalLM` - * Qwen3 - * `Qwen/Qwen3-8B`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `Qwen3MoeForCausalLM` - * Qwen3MoE - * `Qwen/Qwen3-30B-A3B`, etc. - * - * โœ…๏ธŽ -- * `StableLmForCausalLM` - * StableLM - * `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. - * - * โœ…๏ธŽ -- * `Starcoder2ForCausalLM` - * Starcoder2 - * `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. - * - * โœ…๏ธŽ -- * `SolarForCausalLM` - * Solar Pro - * `upstage/solar-pro-preview-instruct`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `TeleChat2ForCausalLM` - * TeleChat2 - * `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `TeleFLMForCausalLM` - * TeleFLM - * `CofeAI/FLM-2-52B-Instruct-2407`, `CofeAI/Tele-FLM`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `XverseForCausalLM` - * XVERSE - * `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `MiniMaxText01ForCausalLM` - * MiniMax-Text - * `MiniMaxAI/MiniMax-Text-01`, etc. - * - * โœ…๏ธŽ -- * `Zamba2ForCausalLM` - * Zamba2 - * `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. - * - * -::: - -:::{note} -Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096. -::: - -### Pooling Models - -See [this page](pooling-models) for more information on how to use pooling models. - -:::{important} -Since some model architectures support both generative and pooling tasks, -you should explicitly specify the task type to ensure that the model is used in pooling mode instead of generative mode. 
-::: - -#### Text Embedding - -Specified using `--task embed`. - -:::{list-table} -:widths: 25 25 50 5 5 -:header-rows: 1 - -- * Architecture - * Models - * Example HF Models - * [LoRA](#lora-adapter) - * [PP](#distributed-serving) -- * `BertModel` - * BERT-based - * `BAAI/bge-base-en-v1.5`, etc. - * - * -- * `Gemma2Model` - * Gemma 2-based - * `BAAI/bge-multilingual-gemma2`, etc. - * - * โœ…๏ธŽ -- * `GritLM` - * GritLM - * `parasail-ai/GritLM-7B-vllm`. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `LlamaModel`, `LlamaForCausalLM`, `MistralModel`, etc. - * Llama-based - * `intfloat/e5-mistral-7b-instruct`, etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `Qwen2Model`, `Qwen2ForCausalLM` - * Qwen2-based - * `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. - * โœ…๏ธŽ - * โœ…๏ธŽ -- * `RobertaModel`, `RobertaForMaskedLM` - * RoBERTa-based - * `sentence-transformers/all-roberta-large-v1`, `sentence-transformers/all-roberta-large-v1`, etc. - * - * -- * `XLMRobertaModel` - * XLM-RoBERTa-based - * `intfloat/multilingual-e5-large`, `jinaai/jina-reranker-v2-base-multilingual`, etc. - * - * -::: - -:::{note} -`ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config. -You should manually set mean pooling by passing `--override-pooler-config '{"pooling_type": "MEAN"}'`. -::: - -:::{note} -The HF implementation of `Alibaba-NLP/gte-Qwen2-1.5B-instruct` is hardcoded to use causal attention despite what is shown in `config.json`. To compare vLLM vs HF results, -you should set `--hf-overrides '{"is_causal": true}'` in vLLM so that the two implementations are consistent with each other. - -For both the 1.5B and 7B variants, you also need to enable `--trust-remote-code` for the correct tokenizer to be loaded. -See [relevant issue on HF Transformers](https://github.com/huggingface/transformers/issues/34882). 
-::: - -If your model is not in the above list, we will try to automatically convert the model using -{func}`~vllm.model_executor.models.adapters.as_embedding_model`. By default, the embeddings -of the whole prompt are extracted from the normalized hidden state corresponding to the last token. - -#### Reward Modeling - -Specified using `--task reward`. - -:::{list-table} -:widths: 25 25 50 5 5 -:header-rows: 1 - -- * Architecture - * Models - * Example HF Models - * [LoRA](#lora-adapter) - * [PP](#distributed-serving) -- * `InternLM2ForRewardModel` - * InternLM2-based - * `internlm/internlm2-1_8b-reward`, `internlm/internlm2-7b-reward`, etc. - * ✅︎ - * ✅︎ -- * `LlamaForCausalLM` - * Llama-based - * `peiyi9979/math-shepherd-mistral-7b-prm`, etc. - * ✅︎ - * ✅︎ -- * `Qwen2ForRewardModel` - * Qwen2-based - * `Qwen/Qwen2.5-Math-RM-72B`, etc. - * ✅︎ - * ✅︎ -- * `Qwen2ForProcessRewardModel` - * Qwen2-based - * `Qwen/Qwen2.5-Math-PRM-7B`, `Qwen/Qwen2.5-Math-PRM-72B`, etc. - * ✅︎ - * ✅︎ -::: - -If your model is not in the above list, we will try to automatically convert the model using -{func}`~vllm.model_executor.models.adapters.as_reward_model`. By default, we return the hidden states of each token directly. - -:::{important} -For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly, -e.g.: `--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`. -::: - -#### Classification - -Specified using `--task classify`. - -:::{list-table} -:widths: 25 25 50 5 5 -:header-rows: 1 - -- * Architecture - * Models - * Example HF Models - * [LoRA](#lora-adapter) - * [PP](#distributed-serving) -- * `JambaForSequenceClassification` - * Jamba - * `ai21labs/Jamba-tiny-reward-dev`, etc. - * ✅︎ - * ✅︎ -- * `Qwen2ForSequenceClassification` - * Qwen2-based - * `jason9693/Qwen2.5-1.5B-apeach`, etc. 
- * ✅︎ - * ✅︎ -::: - -If your model is not in the above list, we will try to automatically convert the model using -{func}`~vllm.model_executor.models.adapters.as_classification_model`. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token. - -#### Sentence Pair Scoring - -Specified using `--task score`. - -:::{list-table} -:widths: 25 25 50 5 5 -:header-rows: 1 - -- * Architecture - * Models - * Example HF Models - * [LoRA](#lora-adapter) - * [PP](#distributed-serving) -- * `BertForSequenceClassification` - * BERT-based - * `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. - * - * -- * `RobertaForSequenceClassification` - * RoBERTa-based - * `cross-encoder/quora-roberta-base`, etc. - * - * -- * `XLMRobertaForSequenceClassification` - * XLM-RoBERTa-based - * `BAAI/bge-reranker-v2-m3`, etc. - * - * -- * `ModernBertForSequenceClassification` - * ModernBert-based - * `Alibaba-NLP/gte-reranker-modernbert-base`, etc. - * - * -::: - -(supported-mm-models)= - -## List of Multimodal Language Models - -The following modalities are supported depending on the model: - -- **T**ext -- **I**mage -- **V**ideo -- **A**udio - -Any combination of modalities joined by `+` is supported. - -- e.g.: `T + I` means that the model supports text-only, image-only, and text-with-image inputs. - -On the other hand, modalities separated by `/` are mutually exclusive. - -- e.g.: `T / I` means that the model supports text-only and image-only inputs, but not text-with-image inputs. - -See [this page](#multimodal-inputs) on how to pass multi-modal inputs to the model. - -:::{important} -**To enable multiple multi-modal items per text prompt in vLLM V0**, you have to set `limit_mm_per_prompt` (offline inference) -or `--limit-mm-per-prompt` (online serving). 
For example, to enable passing up to 4 images per text prompt: - -Offline inference: - -```python -from vllm import LLM - -llm = LLM( - model="Qwen/Qwen2-VL-7B-Instruct", - limit_mm_per_prompt={"image": 4}, -) -``` - -Online serving: - -```bash -vllm serve Qwen/Qwen2-VL-7B-Instruct --limit-mm-per-prompt '{"image":4}' -``` - -**This is no longer required if you are using vLLM V1.** - -::: - -:::{note} -vLLM currently only supports adding LoRA to the language backbone of multimodal models. -::: - -### Generative Models - -See [this page](#generative-models) for more information on how to use generative models. - -#### Text Generation - -Specified using `--task generate`. - -:::{list-table} -:widths: 25 25 15 20 5 5 5 -:header-rows: 1 - -- * Architecture - * Models - * Inputs - * Example HF Models - * [LoRA](#lora-adapter) - * [PP](#distributed-serving) - * [V1](gh-issue:8779) -- * `AriaForConditionalGeneration` - * Aria - * T + I<sup>+</sup> - * `rhymes-ai/Aria` - * - * ✅︎ - * ✅︎ -- * `AyaVisionForConditionalGeneration` - * Aya Vision - * T + I<sup>+</sup> - * `CohereForAI/aya-vision-8b`, `CohereForAI/aya-vision-32b`, etc. - * - * ✅︎ - * ✅︎ -- * `Blip2ForConditionalGeneration` - * BLIP-2 - * T + I<sup>E</sup> - * `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. - * - * ✅︎ - * ✅︎ -- * `ChameleonForConditionalGeneration` - * Chameleon - * T + I - * `facebook/chameleon-7b` etc. - * - * ✅︎ - * ✅︎ -- * `DeepseekVLV2ForCausalLM`<sup>^</sup> - * DeepSeek-VL2 - * T + I<sup>+</sup> - * `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2` etc. - * - * ✅︎ - * ✅︎ -- * `Florence2ForConditionalGeneration` - * Florence-2 - * T + I - * `microsoft/Florence-2-base`, `microsoft/Florence-2-large` etc. - * - * - * -- * `FuyuForCausalLM` - * Fuyu - * T + I - * `adept/fuyu-8b` etc. 
- * - * ✅︎ - * ✅︎ -- * `Gemma3ForConditionalGeneration` - * Gemma 3 - * T + I<sup>+</sup> - * `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. - * ✅︎ - * ✅︎ - * ⚠️ -- * `GLM4VForCausalLM`<sup>^</sup> - * GLM-4V - * T + I - * `THUDM/glm-4v-9b`, `THUDM/cogagent-9b-20241220` etc. - * ✅︎ - * ✅︎ - * ✅︎ -- * `GraniteSpeechForConditionalGeneration` - * Granite Speech - * T + A - * `ibm-granite/granite-speech-3.3-8b` - * ✅︎ - * ✅︎ - * ✅︎ -- * `H2OVLChatModel` - * H2OVL - * T + I<sup>E+</sup> - * `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. - * - * ✅︎ - * ✅︎\* -- * `Idefics3ForConditionalGeneration` - * Idefics3 - * T + I - * `HuggingFaceM4/Idefics3-8B-Llama3` etc. - * ✅︎ - * - * ✅︎ -- * `InternVLChatModel` - * InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 - * T + I<sup>E+</sup> - * `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. - * - * ✅︎ - * ✅︎ -- * `KimiVLForConditionalGeneration` - * Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking - * T + I<sup>+</sup> - * `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` - * - * - * ✅︎ -- * `Llama4ForConditionalGeneration` - * Llama 4 - * T + I<sup>+</sup> - * `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. - * - * ✅︎ - * ✅︎ -- * `LlavaForConditionalGeneration` - * LLaVA-1.5 - * T + I<sup>E+</sup> - * `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), etc. - * - * ✅︎ - * ✅︎ -- * `LlavaNextForConditionalGeneration` - * LLaVA-NeXT - * T + I<sup>E+</sup> - * `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc. 
- * - * ✅︎ - * ✅︎ -- * `LlavaNextVideoForConditionalGeneration` - * LLaVA-NeXT-Video - * T + V - * `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. - * - * ✅︎ - * ✅︎ -- * `LlavaOnevisionForConditionalGeneration` - * LLaVA-Onevision - * T + I<sup>+</sup> + V<sup>+</sup> - * `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. - * - * ✅︎ - * ✅︎ -- * `MiniCPMO` - * MiniCPM-O - * T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>E+</sup> - * `openbmb/MiniCPM-o-2_6`, etc. - * ✅︎ - * ✅︎ - * ✅︎ -- * `MiniCPMV` - * MiniCPM-V - * T + I<sup>E+</sup> + V<sup>E+</sup> - * `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc. - * ✅︎ - * ✅︎ - * ✅︎ -- * `MiniMaxVL01ForConditionalGeneration` - * MiniMax-VL - * T + I<sup>E+</sup> - * `MiniMaxAI/MiniMax-VL-01`, etc. - * - * ✅︎ - * ✅︎ -- * `Mistral3ForConditionalGeneration` - * Mistral3 - * T + I<sup>+</sup> - * `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. - * ✅︎ - * ✅︎ - * ✅︎ -- * `MllamaForConditionalGeneration` - * Llama 3.2 - * T + I<sup>+</sup> - * `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. - * - * - * -- * `MolmoForCausalLM` - * Molmo - * T + I<sup>+</sup> - * `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. - * ✅︎ - * ✅︎ - * ✅︎ -- * `NVLM_D_Model` - * NVLM-D 1.0 - * T + I<sup>+</sup> - * `nvidia/NVLM-D-72B`, etc. - * - * ✅︎ - * ✅︎ -- * `Ovis2ForConditionalGeneration`<sup>^</sup> - * Ovis2 - * T + I<sup>+</sup> - * `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis2-2B`, etc. - * - * - * ✅︎ -- * `PaliGemmaForConditionalGeneration` - * PaliGemma, PaliGemma 2 - * T + I<sup>E</sup> - * `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. 
- * - * ✅︎ - * ⚠️ -- * `Phi3VForCausalLM` - * Phi-3-Vision, Phi-3.5-Vision - * T + I<sup>E+</sup> - * `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. - * - * ✅︎ - * ✅︎ -- * `Phi4MMForCausalLM` - * Phi-4-multimodal - * T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> - * `microsoft/Phi-4-multimodal-instruct`, etc. - * ✅︎ - * - * ✅︎ -- * `PixtralForConditionalGeneration` - * Pixtral - * T + I<sup>+</sup> - * `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistral-community/pixtral-12b`, etc. - * - * ✅︎ - * ✅︎ -- * `QwenVLForConditionalGeneration`<sup>^</sup> - * Qwen-VL - * T + I<sup>E+</sup> - * `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc. - * ✅︎ - * ✅︎ - * ✅︎ -- * `Qwen2AudioForConditionalGeneration` - * Qwen2-Audio - * T + A<sup>+</sup> - * `Qwen/Qwen2-Audio-7B-Instruct` - * - * ✅︎ - * ✅︎ -- * `Qwen2VLForConditionalGeneration` - * QVQ, Qwen2-VL - * T + I<sup>E+</sup> + V<sup>E+</sup> - * `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. - * ✅︎ - * ✅︎ - * ✅︎ -- * `Qwen2_5_VLForConditionalGeneration` - * Qwen2.5-VL - * T + I<sup>E+</sup> + V<sup>E+</sup> - * `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. - * ✅︎ - * ✅︎ - * ✅︎ -- * `Qwen2_5OmniThinkerForConditionalGeneration` - * Qwen2.5-Omni - * T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> - * `Qwen/Qwen2.5-Omni-7B` - * - * ✅︎ - * ✅︎\* -- * `SkyworkR1VChatModel` - * Skywork-R1V-38B - * T + I - * `Skywork/Skywork-R1V-38B` - * - * ✅︎ - * ✅︎ -- * `SmolVLMForConditionalGeneration` - * SmolVLM2 - * T + I - * `SmolVLM2-2.2B-Instruct` - * - * ✅︎ - * ✅︎ -- * `UltravoxModel` - * Ultravox - * T + A<sup>E+</sup> - * `fixie-ai/ultravox-v0_5-llama-3_2-1b` - * ✅︎ - * ✅︎ - * ✅︎ -::: - -<sup>^</sup> You need to set the architecture name via `--hf-overrides` to match the one in vLLM. 
-    • For example, to use DeepSeek-VL2 series models: -      `--hf-overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` -<sup>E</sup> Pre-computed embeddings can be inputted for this modality. -<sup>+</sup> Multiple items can be inputted per text prompt for this modality. - -:::{warning} -Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs. -However, there are differences in how they handle text + image inputs: - -V0 correctly implements the model's attention pattern: -- Uses bidirectional attention between the image tokens corresponding to the same image -- Uses causal attention for other tokens -- Implemented via (naive) PyTorch SDPA with masking tensors -- Note: May use significant memory for long prompts with image - -V1 currently uses a simplified attention pattern: -- Uses causal attention for all tokens, including image tokens -- Generates reasonable outputs but does not match the original model's attention for text + image inputs, especially when `{"do_pan_and_scan": true}` -- Will be updated in the future to support the correct behavior - -This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends. -::: - -:::{note} -`h2oai/h2ovl-mississippi-2b` will be available in V1 once we support head size 80. -::: - -:::{note} -To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM. -::: - -:::{warning} -The output quality of `AllenAI/Molmo-7B-D-0924` (especially in object localization tasks) has deteriorated in recent updates. 
- -For the best results, we recommend using the following dependency versions (tested on A10 and L40): - -```text -# Core vLLM-compatible dependencies with Molmo accuracy setup (tested on L40) -torch==2.5.1 -torchvision==0.20.1 -transformers==4.48.1 -tokenizers==0.21.0 -tiktoken==0.7.0 -vllm==0.7.0 - -# Optional but recommended for improved performance and stability -triton==3.1.0 -xformers==0.0.28.post3 -uvloop==0.21.0 -protobuf==5.29.3 -openai==1.60.2 -opencv-python-headless==4.11.0.86 -pillow==10.4.0 - -# Installed FlashAttention (for float16 only) -flash-attn>=2.5.6 # Not used in float32, but should be documented -``` - -**Note:** Make sure you understand the security implications of using outdated packages. -::: - -:::{note} -The official `openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (`HwwwH/MiniCPM-V-2`) for now. -For more details, please see: <gh-pr:4087#issuecomment-2250397630> -::: - -:::{warning} -Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1. -::: - -:::{note} -To use Qwen2.5-Omni, you have to install Hugging Face Transformers library from source via -`pip install git+https://github.com/huggingface/transformers.git`. - -Reading audio from video during pre-processing is currently supported on V0 (but not V1), because overlapping modalities are not yet supported in V1. -To enable it, pass `--mm-processor-kwargs '{"use_audio_in_video": true}'`. -::: - -### Pooling Models - -See [this page](pooling-models) for more information on how to use pooling models. - -:::{important} -Since some model architectures support both generative and pooling tasks, -you should explicitly specify the task type to ensure that the model is used in pooling mode instead of generative mode. -::: - -#### Text Embedding - -Specified using `--task embed`. - -Any text generation model can be converted into an embedding model by passing `--task embed`. 
- -:::{note} -To get the best results, you should use pooling models that are specifically trained as such. -::: - -The following table lists those that are tested in vLLM. - -:::{list-table} -:widths: 25 25 15 25 5 5 -:header-rows: 1 - -- * Architecture - * Models - * Inputs - * Example HF Models - * [LoRA](#lora-adapter) - * [PP](#distributed-serving) -- * `LlavaNextForConditionalGeneration` - * LLaVA-NeXT-based - * T / I - * `royokong/e5-v` - * - * ✅︎ -- * `Phi3VForCausalLM` - * Phi-3-Vision-based - * T + I - * `TIGER-Lab/VLM2Vec-Full` - * 🚧 - * ✅︎ -- * `Qwen2VLForConditionalGeneration` - * Qwen2-VL-based - * T + I - * `MrLight/dse-qwen2-2b-mrl-v1` - * - * ✅︎ -::: - -#### Transcription - -Specified using `--task transcription`. - -Speech2Text models trained specifically for Automatic Speech Recognition. - -:::{list-table} -:widths: 25 25 25 5 5 -:header-rows: 1 - -- * Architecture - * Models - * Example HF Models - * [LoRA](#lora-adapter) - * [PP](#distributed-serving) -- * `Whisper` - * Whisper-based - * `openai/whisper-large-v3-turbo` - * 🚧 - * 🚧 -::: - -_________________ - -## Model Support Policy - -At vLLM, we are committed to facilitating the integration and support of third-party models within our ecosystem. Our approach is designed to balance the need for robustness and the practical limitations of supporting a wide range of models. Here’s how we manage third-party model support: - -1. **Community-Driven Support**: We encourage community contributions for adding new models. When a user requests support for a new model, we welcome pull requests (PRs) from the community. These contributions are evaluated primarily on the sensibility of the output they generate, rather than strict consistency with existing implementations such as those in transformers. **Call for contribution:** PRs coming directly from model vendors are greatly appreciated! - -2. 
**Best-Effort Consistency**: While we aim to maintain a level of consistency between the models implemented in vLLM and other frameworks like transformers, complete alignment is not always feasible. Factors like acceleration techniques and the use of low-precision computations can introduce discrepancies. Our commitment is to ensure that the implemented models are functional and produce sensible results. - - :::{tip} - When comparing the output of `model.generate` from Hugging Face Transformers with the output of `llm.generate` from vLLM, note that the former reads the model's generation config file (i.e., [generation_config.json](https://github.com/huggingface/transformers/blob/19dabe96362803fb0a9ae7073d03533966598b17/src/transformers/generation/utils.py#L1945)) and applies the default parameters for generation, while the latter only uses the parameters passed to the function. Ensure all sampling parameters are identical when comparing outputs. - ::: - -3. **Issue Resolution and Model Updates**: Users are encouraged to report any bugs or issues they encounter with third-party models. Proposed fixes should be submitted via PRs, with a clear explanation of the problem and the rationale behind the proposed solution. If a fix for one model impacts another, we rely on the community to highlight and address these cross-model dependencies. Note: for bugfix PRs, it is good etiquette to inform the original author to seek their feedback. - -4. **Monitoring and Updates**: Users interested in specific models should monitor the commit history for those models (e.g., by tracking changes in the main/vllm/model_executor/models directory). This proactive approach helps users stay informed about updates and changes that may affect the models they use. - -5. **Selective Focus**: Our resources are primarily directed towards models with significant user interest and impact. 
Models that are less frequently used may receive less attention, and we rely on the community to play a more active role in their upkeep and improvement. - -Through this approach, vLLM fosters a collaborative environment where both the core development team and the broader community contribute to the robustness and diversity of the third-party models supported in our ecosystem. - -Note that, as an inference engine, vLLM does not introduce new models. Therefore, all models supported by vLLM are third-party models in this regard. - -We have the following levels of testing for models: - -1. **Strict Consistency**: We compare the output of the model with the output of the model in the HuggingFace Transformers library under greedy decoding. This is the most stringent test. Please refer to [models tests](https://github.com/vllm-project/vllm/blob/main/tests/models) for the models that have passed this test. -2. **Output Sensibility**: We check if the output of the model is sensible and coherent, by measuring the perplexity of the output and checking for any obvious errors. This is a less stringent test. -3. **Runtime Functionality**: We check if the model can be loaded and run without errors. This is the least stringent test. Please refer to [functionality tests](gh-dir:tests) and [examples](gh-dir:examples) for the models that have passed this test. -4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category. diff --git a/docs/source/serving/engine_args.md b/docs/source/serving/engine_args.md deleted file mode 100644 index 97ea01cd3b2e..000000000000 --- a/docs/source/serving/engine_args.md +++ /dev/null @@ -1,34 +0,0 @@ -(engine-args)= - -# Engine Arguments - -Engine arguments control the behavior of the vLLM engine. 
- -- For [offline inference](#offline-inference), they are part of the arguments to `LLM` class. -- For [online serving](#openai-compatible-server), they are part of the arguments to `vllm serve`. - -Below, you can find an explanation of every engine argument: - -<!--- pyml disable-num-lines 7 no-space-in-emphasis --> -```{eval-rst} -.. argparse:: - :module: vllm.engine.arg_utils - :func: _engine_args_parser - :prog: vllm serve - :nodefaultconst: - :markdownhelp: -``` - -## Async Engine Arguments - -Additional arguments are available to the asynchronous engine which is used for online serving: - -<!--- pyml disable-num-lines 7 no-space-in-emphasis --> -```{eval-rst} -.. argparse:: - :module: vllm.engine.arg_utils - :func: _async_engine_args_parser - :prog: vllm serve - :nodefaultconst: - :markdownhelp: -``` diff --git a/docs/source/serving/env_vars.md b/docs/source/serving/env_vars.md deleted file mode 100644 index 9845241930a4..000000000000 --- a/docs/source/serving/env_vars.md +++ /dev/null @@ -1,15 +0,0 @@ -# Environment Variables - -vLLM uses the following environment variables to configure the system: - -:::{warning} -Please note that `VLLM_PORT` and `VLLM_HOST_IP` set the port and ip for vLLM's **internal usage**. It is not the port and ip for the API server. If you use `--host $VLLM_HOST_IP` and `--port $VLLM_PORT` to start the API server, it will not work. - -All environment variables used by vLLM are prefixed with `VLLM_`. **Special care should be taken for Kubernetes users**: please do not name the service as `vllm`, otherwise environment variables set by Kubernetes might conflict with vLLM's environment variables, because [Kubernetes sets environment variables for each service with the capitalized service name as the prefix](https://kubernetes.io/docs/concepts/services-networking/service/#environment-variables). 
-::: - -:::{literalinclude} ../../../vllm/envs.py -:end-before: end-env-vars-definition -:language: python -:start-after: begin-env-vars-definition -::: diff --git a/docs/source/serving/integrations/index.md b/docs/source/serving/integrations/index.md deleted file mode 100644 index e2b4c0814605..000000000000 --- a/docs/source/serving/integrations/index.md +++ /dev/null @@ -1,8 +0,0 @@ -# External Integrations - -:::{toctree} -:maxdepth: 1 - -langchain -llamaindex -::: diff --git a/docs/source/serving/offline_inference.md b/docs/source/serving/offline_inference.md deleted file mode 100644 index e46361955c73..000000000000 --- a/docs/source/serving/offline_inference.md +++ /dev/null @@ -1,215 +0,0 @@ -(offline-inference)= - -# Offline Inference - -You can run vLLM in your own code on a list of prompts. - -The offline API is based on the {class}`~vllm.LLM` class. -To initialize the vLLM engine, create a new instance of `LLM` and specify the model to run. - -For example, the following code downloads the [`facebook/opt-125m`](https://huggingface.co/facebook/opt-125m) model from HuggingFace -and runs it in vLLM using the default configuration. - -```python -from vllm import LLM - -llm = LLM(model="facebook/opt-125m") -``` - -After initializing the `LLM` instance, you can perform model inference using various APIs. -The available APIs depend on the type of model that is being run: - -- [Generative models](#generative-models) output logprobs which are sampled from to obtain the final output text. -- [Pooling models](#pooling-models) output their hidden states directly. - -Please refer to the above pages for more details about each API. - -:::{seealso} -[API Reference](#offline-inference-api) -::: - -(configuration-options)= - -## Configuration Options - -This section lists the most common options for running the vLLM engine. -For a full list, refer to the <project:#configuration> page. 
- -(model-resolution)= - -### Model resolution - -vLLM loads HuggingFace-compatible models by inspecting the `architectures` field in `config.json` of the model repository -and finding the corresponding implementation that is registered to vLLM. -Nevertheless, our model resolution may fail for the following reasons: - -- The `config.json` of the model repository lacks the `architectures` field. -- Unofficial repositories refer to a model using alternative names which are not recorded in vLLM. -- The same architecture name is used for multiple models, creating ambiguity as to which model should be loaded. - -To fix this, explicitly specify the model architecture by passing `config.json` overrides to the `hf_overrides` option. -For example: - -```python -from vllm import LLM - -model = LLM( - model="cerebras/Cerebras-GPT-1.3B", - hf_overrides={"architectures": ["GPT2LMHeadModel"]}, # GPT-2 -) -``` - -Our [list of supported models](#supported-models) shows the model architectures that are recognized by vLLM. - -(reducing-memory-usage)= - -### Reducing memory usage - -Large models might cause your machine to run out of memory (OOM). Here are some options that help alleviate this problem. - -#### Tensor Parallelism (TP) - -Tensor parallelism (`tensor_parallel_size` option) can be used to split the model across multiple GPUs. - -The following code splits the model across 2 GPUs. - -```python -llm = LLM(model="ibm-granite/granite-3.1-8b-instruct", - tensor_parallel_size=2) -``` - -:::{important} -To ensure that vLLM initializes CUDA correctly, you should avoid calling related functions (e.g. {func}`torch.cuda.set_device`) -before initializing vLLM. Otherwise, you may run into an error like `RuntimeError: Cannot re-initialize CUDA in forked subprocess`. - -To control which devices are used, please instead set the `CUDA_VISIBLE_DEVICES` environment variable. 
-::: - -:::{note} -With tensor parallelism enabled, each process will read the whole model and split it into chunks, which makes the disk reading time even longer (proportional to the size of tensor parallelism). - -You can convert the model checkpoint to a sharded checkpoint using <gh-file:examples/offline_inference/save_sharded_state.py>. The conversion process might take some time, but later you can load the sharded checkpoint much faster. The model loading time should remain constant regardless of the size of tensor parallelism. -::: - -#### Quantization - -Quantized models take less memory at the cost of lower precision. - -Statically quantized models can be downloaded from HF Hub (some popular ones are available at [Red Hat AI](https://huggingface.co/RedHatAI)) -and used directly without extra configuration. - -Dynamic quantization is also supported via the `quantization` option -- see [here](#quantization-index) for more details. - -#### Context length and batch size - -You can further reduce memory usage by limiting the context length of the model (`max_model_len` option) -and the maximum batch size (`max_num_seqs` option). - -```python -from vllm import LLM - -llm = LLM(model="adept/fuyu-8b", - max_model_len=2048, - max_num_seqs=2) -``` - -#### Reduce CUDA Graphs - -By default, we optimize model inference using CUDA graphs which take up extra memory in the GPU. - -:::{important} -CUDA graph capture takes up more memory in V1 than in V0. 
-::: - -You can adjust `compilation_config` to achieve a better balance between inference speed and memory usage: - -```python -from vllm import LLM -from vllm.config import CompilationConfig, CompilationLevel - -llm = LLM( - model="meta-llama/Llama-3.1-8B-Instruct", - compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - # By default, it goes up to max_num_seqs - cudagraph_capture_sizes=[1, 2, 4, 8, 16], - ), -) -``` - -You can disable graph capturing completely via the `enforce_eager` flag: - -```python -from vllm import LLM - -llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", - enforce_eager=True) -``` - -#### Adjust cache size - -If you run out of CPU RAM, try the following options: - -- (Multi-modal models only) you can set the size of multi-modal input cache using `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB). -- (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB). - -#### Multi-modal input limits - -You can allow a smaller number of multi-modal items per prompt to reduce the memory footprint of the model: - -```python -from vllm import LLM - -# Accept up to 3 images and 1 video per prompt -llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", - limit_mm_per_prompt={"image": 3, "video": 1}) -``` - -You can go a step further and disable unused modalities completely by setting its limit to zero. -For example, if your application only accepts image input, there is no need to allocate any memory for videos. - -```python -from vllm import LLM - -# Accept any number of images but no videos -llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", - limit_mm_per_prompt={"video": 0}) -``` - -You can even run a multi-modal model for text-only inference: - -```python -from vllm import LLM - -# Don't accept images. Just text. 
-llm = LLM(model="google/gemma-3-27b-it", - limit_mm_per_prompt={"image": 0}) -``` - -#### Multi-modal processor arguments - -For certain models, you can adjust the multi-modal processor arguments to -reduce the size of the processed multi-modal inputs, which in turn saves memory. - -Here are some examples: - -```python -from vllm import LLM - -# Available for Qwen2-VL series models -llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", - mm_processor_kwargs={ - "max_pixels": 768 * 768, # Default is 1280 * 28 * 28 - }) - -# Available for InternVL series models -llm = LLM(model="OpenGVLab/InternVL2-2B", - mm_processor_kwargs={ - "max_dynamic_patch": 4, # Default is 12 - }) -``` - -### Performance optimization and tuning - -You can potentially improve the performance of vLLM by finetuning various options. -Please refer to [this guide](#optimization-and-tuning) for more details. diff --git a/docs/source/training/rlhf.md b/docs/training/rlhf.md similarity index 69% rename from docs/source/training/rlhf.md rename to docs/training/rlhf.md index 72e89c0c7478..4f75e4e01495 100644 --- a/docs/source/training/rlhf.md +++ b/docs/training/rlhf.md @@ -6,6 +6,6 @@ vLLM can be used to generate the completions for RLHF. 
The best way to do this i See the following basic examples to get started if you don't want to use an existing library: -- [Training and inference processes are located on separate GPUs (inspired by OpenRLHF)](https://docs.vllm.ai/en/latest/getting_started/examples/rlhf.html) -- [Training and inference processes are colocated on the same GPUs using Ray](https://docs.vllm.ai/en/latest/getting_started/examples/rlhf_colocate.html) -- [Utilities for performing RLHF with vLLM](https://docs.vllm.ai/en/latest/getting_started/examples/rlhf_utils.html) +- [Training and inference processes are located on separate GPUs (inspired by OpenRLHF)](../examples/offline_inference/rlhf.md) +- [Training and inference processes are colocated on the same GPUs using Ray](../examples/offline_inference/rlhf_colocate.md) +- [Utilities for performing RLHF with vLLM](../examples/offline_inference/rlhf_utils.md) diff --git a/docs/source/training/trl.md b/docs/training/trl.md similarity index 66% rename from docs/source/training/trl.md rename to docs/training/trl.md index ebdf593dbde5..c7c1a5a3bbd1 100644 --- a/docs/source/training/trl.md +++ b/docs/training/trl.md @@ -6,8 +6,7 @@ Online methods such as GRPO or Online DPO require the model to generate completi See the guide [vLLM for fast generation in online methods](https://huggingface.co/docs/trl/main/en/speeding_up_training#vllm-for-fast-generation-in-online-methods) in the TRL documentation for more information. -:::{seealso} -For more information on the `use_vllm` flag you can provide to the configs of these online methods, see: -- [`trl.GRPOConfig.use_vllm`](https://huggingface.co/docs/trl/main/en/grpo_trainer#trl.GRPOConfig.use_vllm) -- [`trl.OnlineDPOConfig.use_vllm`](https://huggingface.co/docs/trl/main/en/online_dpo_trainer#trl.OnlineDPOConfig.use_vllm) -::: +!!! 
info + For more information on the `use_vllm` flag you can provide to the configs of these online methods, see: + - [`trl.GRPOConfig.use_vllm`](https://huggingface.co/docs/trl/main/en/grpo_trainer#trl.GRPOConfig.use_vllm) + - [`trl.OnlineDPOConfig.use_vllm`](https://huggingface.co/docs/trl/main/en/online_dpo_trainer#trl.OnlineDPOConfig.use_vllm) diff --git a/docs/usage/README.md b/docs/usage/README.md new file mode 100644 index 000000000000..681db57d8e0f --- /dev/null +++ b/docs/usage/README.md @@ -0,0 +1,7 @@ +# Using vLLM + +vLLM supports the following usage patterns: + +- [Inference and Serving](../serving/offline_inference.md): Run a single instance of a model. +- [Deployment](../deployment/docker.md): Scale up model instances for production. +- [Training](../training/rlhf.md): Train or fine-tune a model. diff --git a/docs/source/getting_started/faq.md b/docs/usage/faq.md similarity index 91% rename from docs/source/getting_started/faq.md rename to docs/usage/faq.md index c1bb28937c14..51977d4434f5 100644 --- a/docs/source/getting_started/faq.md +++ b/docs/usage/faq.md @@ -1,23 +1,24 @@ -(faq)= - -# Frequently Asked Questions +--- +title: Frequently Asked Questions +--- +[](){ #faq } > Q: How can I serve multiple models on a single port using the OpenAI API? A: Assuming that you're referring to using OpenAI compatible server to serve multiple models at once, that is not currently supported, you can run multiple instances of the server (each serving a different model) at the same time, and have another layer to route the incoming request to the correct server accordingly. -______________________________________________________________________ +--- > Q: Which model to use for offline inference embedding? A: You can try [e5-mistral-7b-instruct](https://huggingface.co/intfloat/e5-mistral-7b-instruct) and [BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5); -more are listed [here](#supported-models). +more are listed [here][supported-models]. 
By extracting hidden states, vLLM can automatically convert text generation models like [Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B), [Mistral-7B-Instruct-v0.3](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3) into embedding models, but they are expected to be inferior to models that are specifically trained on embedding tasks. -______________________________________________________________________ +--- > Q: Can the output of a prompt vary across runs in vLLM? diff --git a/docs/source/serving/metrics.md b/docs/usage/metrics.md similarity index 90% rename from docs/source/serving/metrics.md rename to docs/usage/metrics.md index 647ece3f85f0..9ad7253184d9 100644 --- a/docs/source/serving/metrics.md +++ b/docs/usage/metrics.md @@ -4,7 +4,7 @@ vLLM exposes a number of metrics that can be used to monitor the health of the system. These metrics are exposed via the `/metrics` endpoint on the vLLM OpenAI compatible API server. -You can start the server using Python, or using [Docker](#deployment-docker): +You can start the server using Python, or using [Docker][deployment-docker]: ```console vllm serve unsloth/Llama-3.2-1B-Instruct @@ -31,11 +31,9 @@ vllm:iteration_tokens_total_bucket{le="512.0",model_name="unsloth/Llama-3.2-1B-I The following metrics are exposed: -:::{literalinclude} ../../../vllm/engine/metrics.py -:end-before: end-metrics-definitions -:language: python -:start-after: begin-metrics-definitions -::: +```python +--8<-- "vllm/engine/metrics.py:metrics-definitions" +``` The following metrics are deprecated and due to be removed in a future version: diff --git a/docs/usage/reproducibility.md b/docs/usage/reproducibility.md new file mode 100644 index 000000000000..a494dcf19191 --- /dev/null +++ b/docs/usage/reproducibility.md @@ -0,0 +1,52 @@ +# Reproducibility + +vLLM does not guarantee the reproducibility of the results by default, for the sake of performance. 
You need to do the following to achieve +reproducible results: + +- For V1: Turn off multiprocessing to make the scheduling deterministic by setting `VLLM_ENABLE_V1_MULTIPROCESSING=0`. +- For V0: Set the global seed (see below). + +Example: <gh-file:examples/offline_inference/reproducibility.py> + +!!! warning + + Applying the above settings [changes the random state in user code](#locality-of-random-state). + +!!! note + + Even with the above settings, vLLM only provides reproducibility + when it runs on the same hardware and the same vLLM version. + Also, the online serving API (`vllm serve`) does not support reproducibility + because it is almost impossible to make the scheduling deterministic in the + online setting. + +## Setting the global seed + +The `seed` parameter in vLLM is used to control the random states for various random number generators. + +If a specific seed value is provided, the random states for `random`, `np.random`, and `torch.manual_seed` will be set accordingly. + +However, in some cases, setting the seed will also [change the random state in user code](#locality-of-random-state). + +### Default Behavior + +In V0, the `seed` parameter defaults to `None`. When the `seed` parameter is `None`, the random states for `random`, `np.random`, and `torch.manual_seed` are not set. This means that each run of vLLM will produce different results if `temperature > 0`, as expected. + +In V1, the `seed` parameter defaults to `0` which sets the random state for each worker, so the results will remain consistent for each vLLM run even if `temperature > 0`. + +!!! note + + It is impossible to un-specify a seed for V1 because different workers need to sample the same outputs + for workflows such as speculative decoding. + + For more information, see: <gh-pr:17929> + +### Locality of random state + +The random state in user code (i.e. 
the code that constructs [LLM][vllm.LLM] class) is updated by vLLM under the following conditions: + +- For V0: The seed is specified. +- For V1: The workers are run in the same process as user code, i.e.: `VLLM_ENABLE_V1_MULTIPROCESSING=0`. + +By default, these conditions are not active so you can use vLLM without having to worry about +accidentally making deterministic subsequent operations that rely on random state. diff --git a/docs/source/deployment/security.md b/docs/usage/security.md similarity index 60% rename from docs/source/deployment/security.md rename to docs/usage/security.md index e2ef8196c167..f1661828d68a 100644 --- a/docs/source/deployment/security.md +++ b/docs/usage/security.md @@ -1,4 +1,4 @@ -# Security Guide +# Security ## Inter-Node Communication @@ -53,6 +53,45 @@ Key points from the PyTorch security guide: - Implement proper authentication and authorization for management interfaces - Follow the principle of least privilege for all system components +## Security and Firewalls: Protecting Exposed vLLM Systems + +While vLLM is designed to allow unsafe network services to be isolated to +private networks, there are components—such as dependencies and underlying +frameworks—that may open insecure services listening on all network interfaces, +sometimes outside of vLLM's direct control. + +A major concern is the use of `torch.distributed`, which vLLM leverages for +distributed communication, including when using vLLM on a single host. When vLLM +uses TCP initialization (see [PyTorch TCP Initialization +documentation](https://docs.pytorch.org/docs/stable/distributed.html#tcp-initialization)), +PyTorch creates a `TCPStore` that, by default, listens on all network +interfaces. This means that unless additional protections are put in place, +these services may be accessible to any host that can reach your machine via any +network interface.
+ +**From a PyTorch perspective, any use of `torch.distributed` should be +considered insecure by default.** This is a known and intentional behavior from +the PyTorch team. + +### Firewall Configuration Guidance + +The best way to protect your vLLM system is to carefully configure a firewall to +expose only the minimum network surface area necessary. In most cases, this +means: + +- **Block all incoming connections except to the TCP port the API server is +listening on.** + +- Ensure that ports used for internal communication (such as those for +`torch.distributed` and KV cache transfer) are only accessible from trusted +hosts or networks. + +- Never expose these internal ports to the public internet or untrusted +networks. + +Consult your operating system or application platform documentation for specific +firewall configuration instructions. + ## Reporting Security Vulnerabilities If you believe you have found a security vulnerability in vLLM, please report it following the project's security policy. For more information on how to report security issues and the project's security policy, please see the [vLLM Security Policy](https://github.com/vllm-project/vllm/blob/main/SECURITY.md). diff --git a/docs/source/getting_started/troubleshooting.md b/docs/usage/troubleshooting.md similarity index 85% rename from docs/source/getting_started/troubleshooting.md rename to docs/usage/troubleshooting.md index a4744827f226..889cfccdacac 100644 --- a/docs/source/getting_started/troubleshooting.md +++ b/docs/usage/troubleshooting.md @@ -1,12 +1,12 @@ -(troubleshooting)= - -# Troubleshooting +--- +title: Troubleshooting +--- +[](){ #troubleshooting } This document outlines some troubleshooting strategies you can consider. If you think you've discovered a bug, please [search existing issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue) first to see if it has already been reported. 
If not, please [file a new issue](https://github.com/vllm-project/vllm/issues/new/choose), providing as much relevant information as possible. -:::{note} -Once you've debugged a problem, remember to turn off any debugging environment variables defined, or simply start a new shell to avoid being affected by lingering debugging settings. Otherwise, the system might be slow with debugging functionalities left activated. -::: +!!! note + Once you've debugged a problem, remember to turn off any debugging environment variables defined, or simply start a new shell to avoid being affected by lingering debugging settings. Otherwise, the system might be slow with debugging functionalities left activated. ## Hangs downloading a model @@ -18,13 +18,12 @@ It's recommended to download the model first using the [huggingface-cli](https:/ If the model is large, it can take a long time to load it from disk. Pay attention to where you store the model. Some clusters have shared filesystems across nodes, e.g. a distributed filesystem or a network filesystem, which can be slow. It'd be better to store the model in a local disk. Additionally, have a look at the CPU memory usage, when the model is too large it might take a lot of CPU memory, slowing down the operating system because it needs to frequently swap between disk and memory. -:::{note} -To isolate the model downloading and loading issue, you can use the `--load-format dummy` argument to skip loading the model weights. This way, you can check if the model downloading and loading is the bottleneck. -::: +!!! note + To isolate the model downloading and loading issue, you can use the `--load-format dummy` argument to skip loading the model weights. This way, you can check if the model downloading and loading is the bottleneck. ## Out of memory -If the model is too large to fit in a single GPU, you will get an out-of-memory (OOM) error. Consider adopting [these options](#reducing-memory-usage) to reduce the memory consumption. 
+If the model is too large to fit in a single GPU, you will get an out-of-memory (OOM) error. Consider adopting [these options](../configuration/conserving_memory.md) to reduce the memory consumption. ## Generation quality changed @@ -53,9 +52,9 @@ You might also need to set `export NCCL_SOCKET_IFNAME=<your_network_interface>` ## Error near `self.graph.replay()` If vLLM crashes and the error trace captures it somewhere around `self.graph.replay()` in `vllm/worker/model_runner.py`, it is a CUDA error inside CUDAGraph. -To identify the particular CUDA operation that causes the error, you can add `--enforce-eager` to the command line, or `enforce_eager=True` to the {class}`~vllm.LLM` class to disable the CUDAGraph optimization and isolate the exact CUDA operation that causes the error. +To identify the particular CUDA operation that causes the error, you can add `--enforce-eager` to the command line, or `enforce_eager=True` to the [LLM][vllm.LLM] class to disable the CUDAGraph optimization and isolate the exact CUDA operation that causes the error. -(troubleshooting-incorrect-hardware-driver)= +[](){ #troubleshooting-incorrect-hardware-driver } ## Incorrect hardware/driver @@ -140,16 +139,15 @@ If the script runs successfully, you should see the message `sanity check is suc If the test script hangs or crashes, usually it means the hardware/drivers are broken in some sense. You should try to contact your system administrator or hardware vendor for further assistance. As a common workaround, you can try to tune some NCCL environment variables, such as `export NCCL_P2P_DISABLE=1` to see if it helps. Please check [their documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html) for more information. Please only use these environment variables as a temporary workaround, as they might affect the performance of the system. The best solution is still to fix the hardware/drivers so that the test script can run successfully. 
-:::{note} -A multi-node environment is more complicated than a single-node one. If you see errors such as `torch.distributed.DistNetworkError`, it is likely that the network/DNS setup is incorrect. In that case, you can manually assign node rank and specify the IP via command line arguments: +!!! note + A multi-node environment is more complicated than a single-node one. If you see errors such as `torch.distributed.DistNetworkError`, it is likely that the network/DNS setup is incorrect. In that case, you can manually assign node rank and specify the IP via command line arguments: -- In the first node, run `NCCL_DEBUG=TRACE torchrun --nnodes 2 --nproc-per-node=2 --node-rank 0 --master_addr $MASTER_ADDR test.py`. -- In the second node, run `NCCL_DEBUG=TRACE torchrun --nnodes 2 --nproc-per-node=2 --node-rank 1 --master_addr $MASTER_ADDR test.py`. + - In the first node, run `NCCL_DEBUG=TRACE torchrun --nnodes 2 --nproc-per-node=2 --node-rank 0 --master_addr $MASTER_ADDR test.py`. + - In the second node, run `NCCL_DEBUG=TRACE torchrun --nnodes 2 --nproc-per-node=2 --node-rank 1 --master_addr $MASTER_ADDR test.py`. -Adjust `--nproc-per-node`, `--nnodes`, and `--node-rank` according to your setup, being sure to execute different commands (with different `--node-rank`) on different nodes. -::: + Adjust `--nproc-per-node`, `--nnodes`, and `--node-rank` according to your setup, being sure to execute different commands (with different `--node-rank`) on different nodes. -(troubleshooting-python-multiprocessing)= +[](){ #troubleshooting-python-multiprocessing } ## Python multiprocessing @@ -161,7 +159,7 @@ If you have seen a warning in your logs like this: WARNING 12-11 14:50:37 multiproc_worker_utils.py:281] CUDA was previously initialized. We must use the `spawn` multiprocessing start method. Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. 
See - https://docs.vllm.ai/en/latest/getting_started/troubleshooting.html#python-multiprocessing + https://docs.vllm.ai/en/latest/usage/troubleshooting.html#python-multiprocessing for more information. ``` @@ -260,7 +258,7 @@ or: ValueError: Model architectures ['<arch>'] are not supported for now. Supported architectures: [...] ``` -But you are sure that the model is in the [list of supported models](#supported-models), there may be some issue with vLLM's model resolution. In that case, please follow [these steps](#model-resolution) to explicitly specify the vLLM implementation for the model. +But you are sure that the model is in the [list of supported models][supported-models], there may be some issue with vLLM's model resolution. In that case, please follow [these steps](../configuration/model_resolution.md) to explicitly specify the vLLM implementation for the model. ## Failed to infer device type diff --git a/docs/source/serving/usage_stats.md b/docs/usage/usage_stats.md similarity index 100% rename from docs/source/serving/usage_stats.md rename to docs/usage/usage_stats.md diff --git a/docs/source/getting_started/v1_user_guide.md b/docs/usage/v1_guide.md similarity index 99% rename from docs/source/getting_started/v1_user_guide.md rename to docs/usage/v1_guide.md index de90b8a7851e..3d5d7ce45cce 100644 --- a/docs/source/getting_started/v1_user_guide.md +++ b/docs/usage/v1_guide.md @@ -1,4 +1,4 @@ -# vLLM V1 User Guide +# vLLM V1 V1 is now enabled by default for all supported use cases, and we will gradually enable it for every use case we plan to support. Please share any feedback on [GitHub](https://github.com/vllm-project/vllm) or in the [vLLM Slack](https://inviter.co/vllm-slack). 
diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py index bab41c915c32..56cdd6861baa 100644 --- a/examples/offline_inference/audio_language.py +++ b/examples/offline_inference/audio_language.py @@ -1,11 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 """ -This example shows how to use vLLM for running offline inference +This example shows how to use vLLM for running offline inference with the correct prompt format on audio language models. For most models, the prompt format should follow corresponding examples on HuggingFace model repository. """ + import os from dataclasses import asdict from typing import NamedTuple, Optional @@ -22,7 +23,7 @@ question_per_audio_count = { 0: "What is 1+1?", 1: "What is recited in the audio?", - 2: "What sport and what nursery rhyme are referenced?" + 2: "What sport and what nursery rhyme are referenced?", } @@ -72,8 +73,7 @@ def run_granite_speech(question: str, audio_count: int) -> ModelRequestData: # MiniCPM-O def run_minicpmo(question: str, audio_count: int) -> ModelRequestData: model_name = "openbmb/MiniCPM-o-2_6" - tokenizer = AutoTokenizer.from_pretrained(model_name, - trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) engine_args = EngineArgs( model=model_name, trust_remote_code=True, @@ -82,19 +82,18 @@ def run_minicpmo(question: str, audio_count: int) -> ModelRequestData: limit_mm_per_prompt={"audio": audio_count}, ) - stop_tokens = ['<|im_end|>', '<|endoftext|>'] + stop_tokens = ["<|im_end|>", "<|endoftext|>"] stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] audio_placeholder = "(<audio>./</audio>)" * audio_count audio_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}{% endif %}" # noqa: E501 - 
messages = [{ - 'role': 'user', - 'content': f'{audio_placeholder}\n{question}' - }] - prompt = tokenizer.apply_chat_template(messages, - tokenize=False, - add_generation_prompt=True, - chat_template=audio_chat_template) + messages = [{"role": "user", "content": f"{audio_placeholder}\n{question}"}] + prompt = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + chat_template=audio_chat_template, + ) return ModelRequestData( engine_args=engine_args, @@ -113,7 +112,7 @@ def run_phi4mm(question: str, audio_count: int) -> ModelRequestData: # Since the vision-lora and speech-lora co-exist with the base model, # we have to manually specify the path of the lora weights. speech_lora_path = os.path.join(model_path, "speech-lora") - placeholders = "".join([f"<|audio_{i+1}|>" for i in range(audio_count)]) + placeholders = "".join([f"<|audio_{i + 1}|>" for i in range(audio_count)]) prompts = f"<|user|>{placeholders}{question}<|end|><|assistant|>" @@ -145,15 +144,19 @@ def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData: limit_mm_per_prompt={"audio": audio_count}, ) - audio_in_prompt = "".join([ - f"Audio {idx+1}: " - f"<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count) - ]) + audio_in_prompt = "".join( + [ + f"Audio {idx + 1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n" + for idx in range(audio_count) + ] + ) - prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" - "<|im_start|>user\n" - f"{audio_in_prompt}{question}<|im_end|>\n" - "<|im_start|>assistant\n") + prompt = ( + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + "<|im_start|>user\n" + f"{audio_in_prompt}{question}<|im_end|>\n" + "<|im_start|>assistant\n" + ) return ModelRequestData( engine_args=engine_args, @@ -172,19 +175,22 @@ def run_qwen2_5_omni(question: str, audio_count: int): limit_mm_per_prompt={"audio": audio_count}, ) - audio_in_prompt = "".join([ - "<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in 
range(audio_count) - ]) + audio_in_prompt = "".join( + ["<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count)] + ) default_system = ( "You are Qwen, a virtual human developed by the Qwen Team, Alibaba " "Group, capable of perceiving auditory and visual inputs, as well as " - "generating text and speech.") + "generating text and speech." + ) - prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n" - "<|im_start|>user\n" - f"{audio_in_prompt}{question}<|im_end|>\n" - "<|im_start|>assistant\n") + prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n" + f"{audio_in_prompt}{question}<|im_end|>\n" + "<|im_start|>assistant\n" + ) return ModelRequestData( engine_args=engine_args, prompt=prompt, @@ -196,13 +202,10 @@ def run_ultravox(question: str, audio_count: int) -> ModelRequestData: model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b" tokenizer = AutoTokenizer.from_pretrained(model_name) - messages = [{ - 'role': 'user', - 'content': "<|audio|>\n" * audio_count + question - }] - prompt = tokenizer.apply_chat_template(messages, - tokenize=False, - add_generation_prompt=True) + messages = [{"role": "user", "content": "<|audio|>\n" * audio_count + question}] + prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) engine_args = EngineArgs( model=model_name, @@ -220,8 +223,7 @@ def run_ultravox(question: str, audio_count: int) -> ModelRequestData: # Whisper def run_whisper(question: str, audio_count: int) -> ModelRequestData: - assert audio_count == 1, ( - "Whisper only support single audio input per prompt") + assert audio_count == 1, "Whisper only support single audio input per prompt" model_name = "openai/whisper-large-v3-turbo" prompt = "<|startoftranscript|>" @@ -252,27 +254,33 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData: def parse_args(): parser = FlexibleArgumentParser( - description='Demo on using vLLM for offline inference with ' - 'audio 
language models') - parser.add_argument('--model-type', - '-m', - type=str, - default="ultravox", - choices=model_example_map.keys(), - help='Huggingface "model_type".') - parser.add_argument('--num-prompts', - type=int, - default=1, - help='Number of prompts to run.') - parser.add_argument("--num-audios", - type=int, - default=1, - choices=[0, 1, 2], - help="Number of audio items per prompt.") - parser.add_argument("--seed", - type=int, - default=None, - help="Set the seed when initializing `vllm.LLM`.") + description="Demo on using vLLM for offline inference with " + "audio language models" + ) + parser.add_argument( + "--model-type", + "-m", + type=str, + default="ultravox", + choices=model_example_map.keys(), + help='Huggingface "model_type".', + ) + parser.add_argument( + "--num-prompts", type=int, default=1, help="Number of prompts to run." + ) + parser.add_argument( + "--num-audios", + type=int, + default=1, + choices=[0, 1, 2], + help="Number of audio items per prompt.", + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="Set the seed when initializing `vllm.LLM`.", + ) return parser.parse_args() @@ -283,29 +291,30 @@ def main(args): raise ValueError(f"Model type {model} is not supported.") audio_count = args.num_audios - req_data = model_example_map[model](question_per_audio_count[audio_count], - audio_count) + req_data = model_example_map[model]( + question_per_audio_count[audio_count], audio_count + ) # Disable other modalities to save memory default_limits = {"image": 0, "video": 0, "audio": 0} req_data.engine_args.limit_mm_per_prompt = default_limits | dict( - req_data.engine_args.limit_mm_per_prompt or {}) + req_data.engine_args.limit_mm_per_prompt or {} + ) engine_args = asdict(req_data.engine_args) | {"seed": args.seed} llm = LLM(**engine_args) # We set temperature to 0.2 so that outputs can be different # even when all prompts are identical when running batch inference. 
- sampling_params = SamplingParams(temperature=0.2, - max_tokens=64, - stop_token_ids=req_data.stop_token_ids) + sampling_params = SamplingParams( + temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids + ) mm_data = {} if audio_count > 0: mm_data = { "audio": [ - asset.audio_and_sample_rate - for asset in audio_assets[:audio_count] + asset.audio_and_sample_rate for asset in audio_assets[:audio_count] ] } @@ -315,8 +324,9 @@ def main(args): # Batch inference inputs = [inputs] * args.num_prompts # Add LoRA request if applicable - lora_request = (req_data.lora_requests * - args.num_prompts if req_data.lora_requests else None) + lora_request = ( + req_data.lora_requests * args.num_prompts if req_data.lora_requests else None + ) outputs = llm.generate( inputs, diff --git a/docs/source/features/automatic_prefix_caching.md b/examples/offline_inference/automatic_prefix_caching.py similarity index 63% rename from docs/source/features/automatic_prefix_caching.md rename to examples/offline_inference/automatic_prefix_caching.py index 59016d7fcf6b..0d8c73304237 100644 --- a/docs/source/features/automatic_prefix_caching.md +++ b/examples/offline_inference/automatic_prefix_caching.py @@ -1,26 +1,31 @@ -(automatic-prefix-caching)= - -# Automatic Prefix Caching - -## Introduction +# SPDX-License-Identifier: Apache-2.0 +""" +Demonstration script for Automatic Prefix Caching (APC) in vLLM. -Automatic Prefix Caching (APC in short) caches the KV cache of existing queries, so that a new query can directly reuse the KV cache if it shares the same prefix with one of the existing queries, allowing the new query to skip the computation of the shared part. +Automatic Prefix Caching (APC) allows the vLLM engine to reuse cached +KV (key-value) pairs from previous prompts if a new query shares the same +prefix. This reduces redundant computation and improves inference speed. 
-:::{note} -Technical details on how vLLM implements APC can be found [here](#design-automatic-prefix-caching). -::: +To enable APC, set `enable_prefix_caching=True` when initializing the +vLLM engine. -## Enabling APC in vLLM +This script uses a long Markdown table as the shared prompt prefix and +compares the generation time for two queries that share the same prefix +but ask different questions. -Set `enable_prefix_caching=True` in vLLM engine to enable APC. Here is an example: +Run: +python examples/offline_inference/automatic_prefix_caching.py +""" -```python import time -from vllm import LLM, SamplingParams +from vllm import LLM, SamplingParams +# ruff: noqa: E501 # A prompt containing a large markdown table. The table is randomly generated by GPT-4. -LONG_PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as follows.\n# Table\n" + """ +LONG_PROMPT = ( + "You are a helpful assistant in recognizes the content of tables in markdown format. 
Here is a table as follows.\n# Table\n" + + """ | ID | Name | Age | Occupation | Country | Email | Phone Number | Address | |-----|---------------|-----|---------------|---------------|------------------------|----------------|------------------------------| | 1 | John Doe | 29 | Engineer | USA | john.doe@example.com | 555-1234 | 123 Elm St, Springfield, IL | @@ -54,6 +59,7 @@ | 29 | Amy White | 33 | Musician | New Zealand | amy.w@example.com | 555-5658 | 159 Maple St, Wellington, NZ | | 30 | Ben Black | 38 | Chef | Ireland | ben.b@example.com | 555-7870 | 246 Fir St, Waterford, IE | """ +) def get_generation_time(llm, sampling_params, prompts): @@ -62,41 +68,35 @@ def get_generation_time(llm, sampling_params, prompts): output = llm.generate(prompts, sampling_params=sampling_params) end_time = time.time() # print the output and generation time + print("-" * 30) print(f"Output: {output[0].outputs[0].text}") print(f"Generation time: {end_time - start_time} seconds.") + print("-" * 30) -# set enable_prefix_caching=True to enable APC -llm = LLM( - model='lmsys/longchat-13b-16k', - enable_prefix_caching=True -) - -sampling_params = SamplingParams(temperature=0, max_tokens=100) - -# Querying the age of John Doe -get_generation_time( - llm, - sampling_params, - LONG_PROMPT + "Question: what is the age of John Doe? Your answer: The age of John Doe is ", -) - -# Querying the age of Zack Blue -# This query will be faster since vllm avoids computing the KV cache of LONG_PROMPT again. -get_generation_time( - llm, - sampling_params, - LONG_PROMPT + "Question: what is the age of Zack Blue? 
Your answer: The age of Zack Blue is ", -) -``` +def main(): + # set enable_prefix_caching=True to enable APC + llm = LLM(model="lmsys/longchat-13b-16k", enable_prefix_caching=True) -## Example workloads + sampling_params = SamplingParams(temperature=0, max_tokens=100) -We describe two example workloads, where APC can provide huge performance benefit: + # Querying the age of John Doe + get_generation_time( + llm, + sampling_params, + LONG_PROMPT + + "Question: what is the age of John Doe? Your answer: The age of John Doe is ", + ) -- Long document query, where the user repeatedly queries the same long document (e.g. software manual or annual report) with different queries. In this case, instead of processing the long document again and again, APC allows vLLM to process this long document *only once*, and all future requests can avoid recomputing this long document by reusing its KV cache. This allows vLLM to serve future requests with much higher throughput and much lower latency. -- Multi-round conversation, where the user may chat with the application multiple times in the same chatting session. In this case, instead of processing the whole chatting history again and again, APC allows vLLM to reuse the processing results of the chat history across all future rounds of conversation, allowing vLLM to serve future requests with much higher throughput and much lower latency. + # Querying the age of Zack Blue + # This query will be faster since vllm avoids computing the KV cache of LONG_PROMPT again. + get_generation_time( + llm, + sampling_params, + LONG_PROMPT + + "Question: what is the age of Zack Blue? Your answer: The age of Zack Blue is ", + ) -## Limits -APC in general does not reduce the performance of vLLM. With that being said, APC only reduces the time of processing the queries (the prefilling phase) and does not reduce the time of generating new tokens (the decoding phase). 
So APC does not bring performance gain when vLLM spends most of the time generating answers to the queries (e.g. when the length of the answer is long), or new queries do not share the same prefix with any of existing queries (so that the computation cannot be reused). +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/basic/chat.py b/examples/offline_inference/basic/chat.py index 8e6f78ed7de2..b0bb5aa71b8a 100644 --- a/examples/offline_inference/basic/chat.py +++ b/examples/offline_inference/basic/chat.py @@ -56,22 +56,12 @@ def print_outputs(outputs): # In this script, we demonstrate how to pass input to the chat method: conversation = [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": "Hello" - }, - { - "role": "assistant", - "content": "Hello! How can I assist you today?" - }, + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hello! 
How can I assist you today?"}, { "role": "user", - "content": - "Write an essay about the importance of higher education.", + "content": "Write an essay about the importance of higher education.", }, ] outputs = llm.chat(conversation, sampling_params, use_tqdm=False) diff --git a/examples/offline_inference/basic/classify.py b/examples/offline_inference/basic/classify.py index 5b6dcb41eee1..40ccb1294e42 100644 --- a/examples/offline_inference/basic/classify.py +++ b/examples/offline_inference/basic/classify.py @@ -10,9 +10,9 @@ def parse_args(): parser = FlexibleArgumentParser() parser = EngineArgs.add_cli_args(parser) # Set example specific arguments - parser.set_defaults(model="jason9693/Qwen2.5-1.5B-apeach", - task="classify", - enforce_eager=True) + parser.set_defaults( + model="jason9693/Qwen2.5-1.5B-apeach", task="classify", enforce_eager=True + ) return parser.parse_args() @@ -36,10 +36,11 @@ def main(args: Namespace): print("\nGenerated Outputs:\n" + "-" * 60) for prompt, output in zip(prompts, outputs): probs = output.outputs.probs - probs_trimmed = ((str(probs[:16])[:-1] + - ", ...]") if len(probs) > 16 else probs) - print(f"Prompt: {prompt!r} \n" - f"Class Probabilities: {probs_trimmed} (size={len(probs)})") + probs_trimmed = (str(probs[:16])[:-1] + ", ...]") if len(probs) > 16 else probs + print( + f"Prompt: {prompt!r} \n" + f"Class Probabilities: {probs_trimmed} (size={len(probs)})" + ) print("-" * 60) diff --git a/examples/offline_inference/basic/embed.py b/examples/offline_inference/basic/embed.py index cb5f923ffb69..38a73ccca251 100644 --- a/examples/offline_inference/basic/embed.py +++ b/examples/offline_inference/basic/embed.py @@ -10,9 +10,9 @@ def parse_args(): parser = FlexibleArgumentParser() parser = EngineArgs.add_cli_args(parser) # Set example specific arguments - parser.set_defaults(model="intfloat/e5-mistral-7b-instruct", - task="embed", - enforce_eager=True) + parser.set_defaults( + model="intfloat/e5-mistral-7b-instruct", task="embed", 
enforce_eager=True + ) return parser.parse_args() @@ -36,10 +36,10 @@ def main(args: Namespace): print("\nGenerated Outputs:\n" + "-" * 60) for prompt, output in zip(prompts, outputs): embeds = output.outputs.embedding - embeds_trimmed = ((str(embeds[:16])[:-1] + - ", ...]") if len(embeds) > 16 else embeds) - print(f"Prompt: {prompt!r} \n" - f"Embeddings: {embeds_trimmed} (size={len(embeds)})") + embeds_trimmed = ( + (str(embeds[:16])[:-1] + ", ...]") if len(embeds) > 16 else embeds + ) + print(f"Prompt: {prompt!r} \nEmbeddings: {embeds_trimmed} (size={len(embeds)})") print("-" * 60) diff --git a/examples/offline_inference/basic/score.py b/examples/offline_inference/basic/score.py index d2bda8b3180c..3da73c6c407d 100644 --- a/examples/offline_inference/basic/score.py +++ b/examples/offline_inference/basic/score.py @@ -10,9 +10,9 @@ def parse_args(): parser = FlexibleArgumentParser() parser = EngineArgs.add_cli_args(parser) # Set example specific arguments - parser.set_defaults(model="BAAI/bge-reranker-v2-m3", - task="score", - enforce_eager=True) + parser.set_defaults( + model="BAAI/bge-reranker-v2-m3", task="score", enforce_eager=True + ) return parser.parse_args() diff --git a/examples/offline_inference/batch_llm_inference.py b/examples/offline_inference/batch_llm_inference.py index 6548857b6d11..c1edfb52ff70 100644 --- a/examples/offline_inference/batch_llm_inference.py +++ b/examples/offline_inference/batch_llm_inference.py @@ -17,12 +17,14 @@ Learn more about Ray Data's LLM integration: https://docs.ray.io/en/latest/data/working-with-llms.html """ + import ray from packaging.version import Version from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig -assert Version(ray.__version__) >= Version( - "2.44.1"), "Ray version must be at least 2.44.1" +assert Version(ray.__version__) >= Version("2.44.1"), ( + "Ray version must be at least 2.44.1" +) # Uncomment to reduce clutter in stdout # ray.init(log_to_driver=False) @@ -53,20 +55,18 @@ 
vllm_processor = build_llm_processor( config, preprocess=lambda row: dict( - messages=[{ - "role": "system", - "content": "You are a bot that responds with haikus." - }, { - "role": "user", - "content": row["text"] - }], + messages=[ + {"role": "system", "content": "You are a bot that responds with haikus."}, + {"role": "user", "content": row["text"]}, + ], sampling_params=dict( temperature=0.3, max_tokens=250, - )), + ), + ), postprocess=lambda row: dict( answer=row["generated_text"], - **row # This will return all the original columns in the dataset. + **row, # This will return all the original columns in the dataset. ), ) diff --git a/examples/offline_inference/chat_with_tools.py b/examples/offline_inference/chat_with_tools.py index 15519bfed9cb..61230d895584 100644 --- a/examples/offline_inference/chat_with_tools.py +++ b/examples/offline_inference/chat_with_tools.py @@ -50,87 +50,93 @@ # or any other mistral model with function calling ability sampling_params = SamplingParams(max_tokens=8192, temperature=0.0) -llm = LLM(model=model_name, - tokenizer_mode="mistral", - config_format="mistral", - load_format="mistral") +llm = LLM( + model=model_name, + tokenizer_mode="mistral", + config_format="mistral", + load_format="mistral", +) def generate_random_id(length=9): characters = string.ascii_letters + string.digits - random_id = ''.join(random.choice(characters) for _ in range(length)) + random_id = "".join(random.choice(characters) for _ in range(length)) return random_id # simulate an API that can be called -def get_current_weather(city: str, state: str, unit: 'str'): - return (f"The weather in {city}, {state} is 85 degrees {unit}. 
It is " - "partly cloudly, with highs in the 90's.") - - -tool_funtions = {"get_current_weather": get_current_weather} - -tools = [{ - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": - "string", - "description": - "The city to find the weather for, e.g. 'San Francisco'" +def get_current_weather(city: str, state: str, unit: "str"): + return ( + f"The weather in {city}, {state} is 85 degrees {unit}. It is " + "partly cloudly, with highs in the 90's." + ) + + +tool_functions = {"get_current_weather": get_current_weather} + +tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city to find the weather for, e.g. 'San Francisco'", + }, + "state": { + "type": "string", + "description": "the two-letter abbreviation for the state that the city is" + " in, e.g. 'CA' which would mean 'California'", + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, }, - "state": { - "type": - "string", - "description": - "the two-letter abbreviation for the state that the city is" - " in, e.g. 'CA' which would mean 'California'" - }, - "unit": { - "type": "string", - "description": "The unit to fetch the temperature in", - "enum": ["celsius", "fahrenheit"] - } + "required": ["city", "state", "unit"], }, - "required": ["city", "state", "unit"] - } + }, } -}] +] -messages = [{ - "role": - "user", - "content": - "Can you tell me what the temperate will be in Dallas, in fahrenheit?" 
-}] +messages = [ + { + "role": "user", + "content": "Can you tell me what the temperate will be in Dallas, in fahrenheit?", + } +] outputs = llm.chat(messages, sampling_params=sampling_params, tools=tools) output = outputs[0].outputs[0].text.strip() # append the assistant message -messages.append({ - "role": "assistant", - "content": output, -}) +messages.append( + { + "role": "assistant", + "content": output, + } +) # let's now actually parse and execute the model's output simulating an API call by using the # above defined function tool_calls = json.loads(output) tool_answers = [ - tool_funtions[call['name']](**call['arguments']) for call in tool_calls + tool_functions[call["name"]](**call["arguments"]) for call in tool_calls ] # append the answer as a tool message and let the LLM give you an answer -messages.append({ - "role": "tool", - "content": "\n\n".join(tool_answers), - "tool_call_id": generate_random_id(), -}) +messages.append( + { + "role": "tool", + "content": "\n\n".join(tool_answers), + "tool_call_id": generate_random_id(), + } +) outputs = llm.chat(messages, sampling_params, tools=tools) diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index 965915beaf58..bf60d883c410 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -27,6 +27,7 @@ --master-addr=10.99.48.128 \ --master-port=13345 """ + import os from time import sleep @@ -36,40 +37,46 @@ def parse_args(): import argparse + parser = argparse.ArgumentParser(description="Data Parallel Inference") - parser.add_argument("--model", - type=str, - default="ibm-research/PowerMoE-3b", - help="Model name or path") - parser.add_argument("--dp-size", - type=int, - default=2, - help="Data parallel size") - parser.add_argument("--tp-size", - type=int, - default=2, - help="Tensor parallel size") - parser.add_argument("--node-size", - type=int, - default=1, - help="Total number of nodes") - 
parser.add_argument("--node-rank", - type=int, - default=0, - help="Rank of the current node") - parser.add_argument("--master-addr", - type=str, - default="", - help="Master node IP address") - parser.add_argument("--master-port", - type=int, - default=0, - help="Master node port") + parser.add_argument( + "--model", + type=str, + default="ibm-research/PowerMoE-3b", + help="Model name or path", + ) + parser.add_argument("--dp-size", type=int, default=2, help="Data parallel size") + parser.add_argument("--tp-size", type=int, default=2, help="Tensor parallel size") + parser.add_argument( + "--node-size", type=int, default=1, help="Total number of nodes" + ) + parser.add_argument( + "--node-rank", type=int, default=0, help="Rank of the current node" + ) + parser.add_argument( + "--master-addr", type=str, default="", help="Master node IP address" + ) + parser.add_argument("--master-port", type=int, default=0, help="Master node port") + parser.add_argument( + "--enforce-eager", action="store_true", help="Enforce eager mode execution." + ) + parser.add_argument( + "--trust-remote-code", action="store_true", help="Trust remote code." + ) return parser.parse_args() -def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, - dp_master_port, GPUs_per_dp_rank): +def main( + model, + dp_size, + local_dp_rank, + global_dp_rank, + dp_master_ip, + dp_master_port, + GPUs_per_dp_rank, + enforce_eager, + trust_remote_code, +): os.environ["VLLM_DP_RANK"] = str(global_dp_rank) os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank) os.environ["VLLM_DP_SIZE"] = str(dp_size) @@ -104,15 +111,18 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, # since we are doing data parallel, every rank can have different # sampling params. here we set different max_tokens for different # ranks for demonstration. 
- sampling_params = SamplingParams(temperature=0.8, - top_p=0.95, - max_tokens=[16, 20][global_dp_rank % 2]) + sampling_params = SamplingParams( + temperature=0.8, top_p=0.95, max_tokens=[16, 20][global_dp_rank % 2] + ) # Create an LLM. - llm = LLM(model=model, - tensor_parallel_size=GPUs_per_dp_rank, - enforce_eager=True, - enable_expert_parallel=True) + llm = LLM( + model=model, + tensor_parallel_size=GPUs_per_dp_rank, + enforce_eager=enforce_eager, + enable_expert_parallel=True, + trust_remote_code=trust_remote_code, + ) outputs = llm.generate(prompts, sampling_params) # Print the outputs. for i, output in enumerate(outputs): @@ -121,15 +131,16 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, break prompt = output.prompt generated_text = output.outputs[0].text - print(f"DP rank {global_dp_rank}, Prompt: {prompt!r}, " - f"Generated text: {generated_text!r}") + print( + f"DP rank {global_dp_rank}, Prompt: {prompt!r}, " + f"Generated text: {generated_text!r}" + ) # Give engines time to pause their processing loops before exiting. 
sleep(1) if __name__ == "__main__": - args = parse_args() dp_size = args.dp_size @@ -151,19 +162,29 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, procs = [] for local_dp_rank, global_dp_rank in enumerate( - range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)): - proc = Process(target=main, - args=(args.model, dp_size, local_dp_rank, - global_dp_rank, dp_master_ip, dp_master_port, - tp_size)) + range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node) + ): + proc = Process( + target=main, + args=( + args.model, + dp_size, + local_dp_rank, + global_dp_rank, + dp_master_ip, + dp_master_port, + tp_size, + args.enforce_eager, + args.trust_remote_code, + ), + ) proc.start() procs.append(proc) exit_code = 0 for proc in procs: proc.join(timeout=300) if proc.exitcode is None: - print(f"Killing process {proc.pid} that " - f"didn't stop within 5 minutes.") + print(f"Killing process {proc.pid} that didn't stop within 5 minutes.") proc.kill() exit_code = 1 elif proc.exitcode: diff --git a/examples/offline_inference/disaggregated-prefill-v1/README.md b/examples/offline_inference/disaggregated-prefill-v1/README.md new file mode 100644 index 000000000000..9cbdb19820f5 --- /dev/null +++ b/examples/offline_inference/disaggregated-prefill-v1/README.md @@ -0,0 +1,10 @@ +# Disaggregated Prefill V1 + +This example contains scripts that demonstrate disaggregated prefill in the offline setting of vLLM. + +## Files + +- `run.sh` - A helper script that will run `prefill_example.py` and `decode_example.py` sequentially. + - Make sure you are in the `examples/offline_inference/disaggregated-prefill-v1` directory before running `run.sh`. +- `prefill_example.py` - A script which performs prefill only, saving the KV state to the `local_storage` directory and the prompts to `output.txt`. +- `decode_example.py` - A script which performs decode only, loading the KV state from the `local_storage` directory and the prompts from `output.txt`. 
diff --git a/examples/offline_inference/disaggregated-prefill-v1/decode_example.py b/examples/offline_inference/disaggregated-prefill-v1/decode_example.py index 66efbc0c9dee..4ae5d3310e0b 100644 --- a/examples/offline_inference/disaggregated-prefill-v1/decode_example.py +++ b/examples/offline_inference/disaggregated-prefill-v1/decode_example.py @@ -3,34 +3,48 @@ from vllm import LLM, SamplingParams from vllm.config import KVTransferConfig -# Read prompts from output.txt -prompts = [] -try: - with open("output.txt") as f: - for line in f: - prompts.append(line.strip()) - print(f"Loaded {len(prompts)} prompts from output.txt") -except FileNotFoundError: - print("Error: output.txt file not found") - exit(-1) - -sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) - -llm = LLM( - model="meta-llama/Llama-3.2-1B-Instruct", - enforce_eager=True, - gpu_memory_utilization=0.8, - max_num_batched_tokens=64, - max_num_seqs=16, - kv_transfer_config=KVTransferConfig.from_cli( - '{"kv_connector":"SharedStorageConnector","kv_role":"kv_both",' - '"kv_connector_extra_config": {"shared_storage_path": "local_storage"}}' - )) #, max_model_len=2048, max_num_batched_tokens=2048) - -# 1ST generation (prefill instance) -outputs = llm.generate(prompts, sampling_params) - -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + +def read_prompts(): + """Read prompts from output.txt""" + prompts = [] + try: + with open("output.txt") as f: + for line in f: + prompts.append(line.strip()) + print(f"Loaded {len(prompts)} prompts from output.txt") + return prompts + except FileNotFoundError: + print("Error: output.txt file not found") + exit(-1) + + +def main(): + prompts = read_prompts() + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) + + llm = LLM( + model="meta-llama/Llama-3.2-1B-Instruct", + enforce_eager=True, + gpu_memory_utilization=0.8, + 
max_num_batched_tokens=64, + max_num_seqs=16, + kv_transfer_config=KVTransferConfig( + kv_connector="SharedStorageConnector", + kv_role="kv_both", + kv_connector_extra_config={"shared_storage_path": "local_storage"}, + ), + ) # , max_model_len=2048, max_num_batched_tokens=2048) + + # 1ST generation (prefill instance) + outputs = llm.generate(prompts, sampling_params) + + print("-" * 30) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-" * 30) + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py b/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py index f7cbf6557d54..5757a8a84b86 100644 --- a/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py +++ b/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py @@ -3,41 +3,55 @@ from vllm import LLM, SamplingParams from vllm.config import KVTransferConfig -context = "Hi " * 1000 -context2 = "Hey " * 500 -prompts = [ - context + "Hello, my name is", - context + "The capital of France is", - context2 + "Your name is", - context2 + "The capital of China is", -] - -sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) - -llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", - enforce_eager=True, - gpu_memory_utilization=0.8, - kv_transfer_config=KVTransferConfig.from_cli( - '{"kv_connector":"SharedStorageConnector","kv_role":"kv_both", ' - '"kv_connector_extra_config": ' - '{"shared_storage_path": "local_storage"}}') - ) #, max_model_len=2048, max_num_batched_tokens=2048) - -# 1ST generation (prefill instance) -outputs = llm.generate( - prompts, - sampling_params, -) - -new_prompts = [] -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - new_prompts.append(prompt + generated_text) - print(f"Prompt: {prompt!r}, Generated text: 
{generated_text!r}") - -# Write new_prompts to output.txt -with open("output.txt", "w") as f: - for prompt in new_prompts: - f.write(prompt + "\n") -print(f"Saved {len(new_prompts)} prompts to output.txt") + +def read_prompts(): + context = "Hi " * 1000 + context2 = "Hey " * 500 + return [ + context + "Hello, my name is", + context + "The capital of France is", + context2 + "Your name is", + context2 + "The capital of China is", + ] + + +def main(): + prompts = read_prompts() + + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) + + llm = LLM( + model="meta-llama/Llama-3.2-1B-Instruct", + enforce_eager=True, + gpu_memory_utilization=0.8, + kv_transfer_config=KVTransferConfig( + kv_connector="SharedStorageConnector", + kv_role="kv_both", + kv_connector_extra_config={"shared_storage_path": "local_storage"}, + ), + ) # , max_model_len=2048, max_num_batched_tokens=2048) + + # 1ST generation (prefill instance) + outputs = llm.generate( + prompts, + sampling_params, + ) + + new_prompts = [] + print("-" * 30) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + new_prompts.append(prompt + generated_text) + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-" * 30) + + # Write new_prompts to output.txt + with open("output.txt", "w") as f: + for prompt in new_prompts: + f.write(prompt + "\n") + print(f"Saved {len(new_prompts)} prompts to output.txt") + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/disaggregated_prefill.py b/examples/offline_inference/disaggregated_prefill.py index d60985146c5c..3ccab0dcd6d3 100644 --- a/examples/offline_inference/disaggregated_prefill.py +++ b/examples/offline_inference/disaggregated_prefill.py @@ -4,6 +4,7 @@ We will launch 2 vllm instances (GPU 0 for prefill and GPU 1 for decode), and then transfer the KV cache between them. 
""" + import os import time from multiprocessing import Event, Process @@ -32,16 +33,21 @@ def run_prefill(prefill_done): # This instance is the prefill node (kv_producer, rank 0). # The number of parallel instances for KV cache transfer is set to 2, # as required for PyNcclConnector. - ktc = KVTransferConfig.from_cli( - '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' + ktc = KVTransferConfig( + kv_connector="PyNcclConnector", + kv_role="kv_producer", + kv_rank=0, + kv_parallel_size=2, ) # Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB # memory. You may need to adjust the value to fit your GPU. - llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct", - kv_transfer_config=ktc, - max_model_len=2000, - gpu_memory_utilization=0.8) + llm = LLM( + model="meta-llama/Meta-Llama-3.1-8B-Instruct", + kv_transfer_config=ktc, + max_model_len=2000, + gpu_memory_utilization=0.8, + ) llm.generate(prompts, sampling_params) print("Prefill node is finished.") @@ -71,16 +77,21 @@ def run_decode(prefill_done): # This instance is the decode node (kv_consumer, rank 1). # The number of parallel instances for KV cache transfer is set to 2, # as required for PyNcclConnector. - ktc = KVTransferConfig.from_cli( - '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' + ktc = KVTransferConfig( + kv_connector="PyNcclConnector", + kv_role="kv_consumer", + kv_rank=1, + kv_parallel_size=2, ) # Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB # memory. You may need to adjust the value to fit your GPU. 
- llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct", - kv_transfer_config=ktc, - max_model_len=2000, - gpu_memory_utilization=0.8) + llm = LLM( + model="meta-llama/Meta-Llama-3.1-8B-Instruct", + kv_transfer_config=ktc, + max_model_len=2000, + gpu_memory_utilization=0.8, + ) # Wait for the producer to start the pipe print("Waiting for prefill node to finish...") @@ -97,8 +108,8 @@ def run_decode(prefill_done): def main(): prefill_done = Event() - prefill_process = Process(target=run_prefill, args=(prefill_done, )) - decode_process = Process(target=run_decode, args=(prefill_done, )) + prefill_process = Process(target=run_prefill, args=(prefill_done,)) + decode_process = Process(target=run_decode, args=(prefill_done,)) # Start prefill node prefill_process.start() diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py index 91e2f68ecffb..606ce7799a88 100644 --- a/examples/offline_inference/eagle.py +++ b/examples/offline_inference/eagle.py @@ -6,6 +6,7 @@ from transformers import AutoTokenizer from vllm import LLM, SamplingParams +from vllm.v1.metrics.reader import Counter, Vector def load_prompts(dataset_path, num_prompts): @@ -20,9 +21,7 @@ def load_prompts(dataset_path, num_prompts): print(f"Error reading dataset: {e}") return [] else: - prompts = [ - "The future of AI is", "The president of the United States is" - ] + prompts = ["The future of AI is", "The president of the United States is"] return prompts[:num_prompts] @@ -33,34 +32,32 @@ def parse_args(): "--dataset", type=str, default="./examples/data/gsm8k.jsonl", - help="downloaded from the eagle repo " \ - "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/" + help="downloaded from the eagle repo " + "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/", + ) + parser.add_argument( + "--method", type=str, default="eagle", choices=["eagle", "eagle3"] ) - parser.add_argument("--method", - type=str, - default='eagle', - choices=['eagle', 'eagle3']) 
parser.add_argument("--max_num_seqs", type=int, default=8) parser.add_argument("--num_prompts", type=int, default=80) parser.add_argument("--num_spec_tokens", type=int, default=2) parser.add_argument("--tp", type=int, default=1) parser.add_argument("--draft_tp", type=int, default=1) - parser.add_argument("--enforce_eager", action='store_true') - parser.add_argument("--enable_chunked_prefill", action='store_true') + parser.add_argument("--enforce_eager", action="store_true") + parser.add_argument("--enable_chunked_prefill", action="store_true") parser.add_argument("--max_num_batched_tokens", type=int, default=2048) parser.add_argument("--temp", type=float, default=0) return parser.parse_args() def main(): - args = parse_args() model_dir = "meta-llama/Llama-3.1-8B-Instruct" - if args.method == 'eagle': + if args.method == "eagle": eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" - elif args.method == 'eagle3': + elif args.method == "eagle3": eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" else: raise ValueError(f"unknown method: {args.method}") @@ -72,11 +69,9 @@ def main(): prompts = load_prompts(args.dataset, args.num_prompts) prompt_ids = [ - tokenizer.apply_chat_template([{ - "role": "user", - "content": prompt - }], - add_generation_prompt=True) + tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], add_generation_prompt=True + ) for prompt in prompts ] @@ -102,30 +97,42 @@ def main(): sampling_params = SamplingParams(temperature=args.temp, max_tokens=256) - outputs = llm.generate(prompt_token_ids=prompt_ids, - sampling_params=sampling_params) + outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params) - if not hasattr(outputs, "metrics") or outputs.metrics is None: + # print the generated text + for output in outputs: + print("-" * 50) + print(f"prompt: {output.prompt}") + print(f"generated text: {output.outputs[0].text}") + print("-" * 50) + + try: + metrics = llm.get_metrics() + except AssertionError: + 
print("Metrics are not supported in the V0 engine.") return - # calculate the average number of accepted tokens per forward pass, +1 is - # to account for the token from the target model that's always going to be - # accepted - acceptance_counts = [0] * (args.num_spec_tokens + 1) - for output in outputs: - for step, count in enumerate( - output.metrics.spec_token_acceptance_counts): - acceptance_counts[step] += count + num_drafts = num_accepted = 0 + acceptance_counts = [0] * args.num_spec_tokens + for metric in metrics: + if metric.name == "vllm:spec_decode_num_drafts": + assert isinstance(metric, Counter) + num_drafts += metric.value + elif metric.name == "vllm:spec_decode_num_accepted_tokens": + assert isinstance(metric, Counter) + num_accepted += metric.value + elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos": + assert isinstance(metric, Vector) + for pos in range(len(metric.values)): + acceptance_counts[pos] += metric.values[pos] print("-" * 50) - print(f"mean acceptance length: \ - {sum(acceptance_counts) / acceptance_counts[0]:.2f}") + print(f"mean acceptance length: {1 + (num_accepted / num_drafts):.2f}") print("-" * 50) # print acceptance at each token position for i in range(len(acceptance_counts)): - print(f"acceptance at token {i}:" - f"{acceptance_counts[i] / (acceptance_counts[0]):.2f}") + print(f"acceptance at token {i}:{acceptance_counts[i] / num_drafts:.2f}") if __name__ == "__main__": diff --git a/examples/offline_inference/embed_jina_embeddings_v3.py b/examples/offline_inference/embed_jina_embeddings_v3.py index b347ddbf3197..23f60c431fc2 100644 --- a/examples/offline_inference/embed_jina_embeddings_v3.py +++ b/examples/offline_inference/embed_jina_embeddings_v3.py @@ -10,9 +10,9 @@ def parse_args(): parser = FlexibleArgumentParser() parser = EngineArgs.add_cli_args(parser) # Set example specific arguments - parser.set_defaults(model="jinaai/jina-embeddings-v3", - task="embed", - trust_remote_code=True) + parser.set_defaults( + 
model="jinaai/jina-embeddings-v3", task="embed", trust_remote_code=True + ) return parser.parse_args() @@ -41,11 +41,14 @@ def main(args: Namespace): print("-" * 60) for prompt, output in zip(prompts, outputs): embeds = output.outputs.embedding - embeds_trimmed = ((str(embeds[:16])[:-1] + - ", ...]") if len(embeds) > 16 else embeds) - print(f"Prompt: {prompt!r} \n" - f"Embeddings for text matching: {embeds_trimmed} " - f"(size={len(embeds)})") + embeds_trimmed = ( + (str(embeds[:16])[:-1] + ", ...]") if len(embeds) > 16 else embeds + ) + print( + f"Prompt: {prompt!r} \n" + f"Embeddings for text matching: {embeds_trimmed} " + f"(size={len(embeds)})" + ) print("-" * 60) diff --git a/examples/offline_inference/embed_matryoshka_fy.py b/examples/offline_inference/embed_matryoshka_fy.py index 7a6cb02556d9..59c0592ae9e2 100644 --- a/examples/offline_inference/embed_matryoshka_fy.py +++ b/examples/offline_inference/embed_matryoshka_fy.py @@ -10,9 +10,9 @@ def parse_args(): parser = FlexibleArgumentParser() parser = EngineArgs.add_cli_args(parser) # Set example specific arguments - parser.set_defaults(model="jinaai/jina-embeddings-v3", - task="embed", - trust_remote_code=True) + parser.set_defaults( + model="jinaai/jina-embeddings-v3", task="embed", trust_remote_code=True + ) return parser.parse_args() @@ -39,11 +39,10 @@ def main(args: Namespace): print("-" * 60) for prompt, output in zip(prompts, outputs): embeds = output.outputs.embedding - embeds_trimmed = ((str(embeds[:16])[:-1] + - ", ...]") if len(embeds) > 16 else embeds) - print(f"Prompt: {prompt!r} \n" - f"Embeddings: {embeds_trimmed} " - f"(size={len(embeds)})") + embeds_trimmed = ( + (str(embeds[:16])[:-1] + ", ...]") if len(embeds) > 16 else embeds + ) + print(f"Prompt: {prompt!r} \nEmbeddings: {embeds_trimmed} (size={len(embeds)})") print("-" * 60) diff --git a/examples/offline_inference/encoder_decoder.py b/examples/offline_inference/encoder_decoder.py index c4916e00f473..83dd1f667eb5 100644 --- 
a/examples/offline_inference/encoder_decoder.py +++ b/examples/offline_inference/encoder_decoder.py @@ -1,12 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 -''' +""" Demonstrate prompting of text-to-text encoder/decoder models, specifically BART -''' +""" from vllm import LLM, SamplingParams -from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt, - TokensPrompt, zip_enc_dec_prompts) +from vllm.inputs import ( + ExplicitEncoderDecoderPrompt, + TextPrompt, + TokensPrompt, + zip_enc_dec_prompts, +) def create_prompts(tokenizer): @@ -18,8 +22,9 @@ def create_prompts(tokenizer): # - Helpers for building prompts text_prompt_raw = "Hello, my name is" text_prompt = TextPrompt(prompt="The president of the United States is") - tokens_prompt = TokensPrompt(prompt_token_ids=tokenizer.encode( - prompt="The capital of France is")) + tokens_prompt = TokensPrompt( + prompt_token_ids=tokenizer.encode(prompt="The capital of France is") + ) # - Pass a single prompt to encoder/decoder model # (implicitly encoder input prompt); # decoder input prompt is assumed to be None @@ -57,14 +62,19 @@ def create_prompts(tokenizer): # decoder prompts together into a list of ExplicitEncoderDecoderPrompt # instances zipped_prompt_list = zip_enc_dec_prompts( - ['An encoder prompt', 'Another encoder prompt'], - ['A decoder prompt', 'Another decoder prompt']) + ["An encoder prompt", "Another encoder prompt"], + ["A decoder prompt", "Another decoder prompt"], + ) # - Let's put all of the above example prompts together into one list # which we will pass to the encoder/decoder LLM. 
return [ - single_text_prompt_raw, single_text_prompt, single_tokens_prompt, - enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3 + single_text_prompt_raw, + single_text_prompt, + single_tokens_prompt, + enc_dec_prompt1, + enc_dec_prompt2, + enc_dec_prompt3, ] + zipped_prompt_list @@ -85,10 +95,12 @@ def print_outputs(outputs): prompt = output.prompt encoder_prompt = output.encoder_prompt generated_text = output.outputs[0].text - print(f"Output {i+1}:") - print(f"Encoder prompt: {encoder_prompt!r}\n" - f"Decoder prompt: {prompt!r}\n" - f"Generated text: {generated_text!r}") + print(f"Output {i + 1}:") + print( + f"Encoder prompt: {encoder_prompt!r}\n" + f"Decoder prompt: {prompt!r}\n" + f"Generated text: {generated_text!r}" + ) print("-" * 50) diff --git a/examples/offline_inference/encoder_decoder_multimodal.py b/examples/offline_inference/encoder_decoder_multimodal.py index 2883c37ca236..ae3737e37594 100644 --- a/examples/offline_inference/encoder_decoder_multimodal.py +++ b/examples/offline_inference/encoder_decoder_multimodal.py @@ -3,6 +3,7 @@ This example shows how to use vLLM for running offline inference with the explicit/implicit prompt format on enc-dec LMMs for text generation. 
""" + import time from collections.abc import Sequence from dataclasses import asdict @@ -30,18 +31,14 @@ def run_florence2(): ) prompts = [ - { # implicit prompt with task token + { # implicit prompt with task token "prompt": "<DETAILED_CAPTION>", - "multi_modal_data": { - "image": ImageAsset("stop_sign").pil_image - }, + "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image}, }, - { # explicit encoder/decoder prompt + { # explicit encoder/decoder prompt "encoder_prompt": { "prompt": "Describe in detail what is shown in the image.", - "multi_modal_data": { - "image": ImageAsset("cherry_blossom").pil_image - }, + "multi_modal_data": {"image": ImageAsset("cherry_blossom").pil_image}, }, "decoder_prompt": "", }, @@ -63,20 +60,20 @@ def run_mllama(): ) prompts = [ - { # Implicit prompt - "prompt": "<|image|><|begin_of_text|>What is the content of this image?", # noqa: E501 + { # Implicit prompt + "prompt": "<|image|><|begin_of_text|>What is the content of this image?", # noqa: E501 "multi_modal_data": { "image": ImageAsset("stop_sign").pil_image, }, }, - { # Explicit prompt + { # Explicit prompt "encoder_prompt": { "prompt": "<|image|>", "multi_modal_data": { "image": ImageAsset("stop_sign").pil_image, }, }, - "decoder_prompt": "<|image|><|begin_of_text|>Please describe the image.", # noqa: E501 + "decoder_prompt": "<|image|><|begin_of_text|>Please describe the image.", # noqa: E501 }, ] @@ -96,13 +93,13 @@ def run_whisper(): ) prompts = [ - { # Test implicit prompt + { # Test implicit prompt "prompt": "<|startoftranscript|>", "multi_modal_data": { "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate, }, }, - { # Test explicit encoder/decoder prompt + { # Test explicit encoder/decoder prompt "encoder_prompt": { "prompt": "", "multi_modal_data": { @@ -110,7 +107,7 @@ def run_whisper(): }, }, "decoder_prompt": "<|startoftranscript|>", - } + }, ] return ModelRequestData( @@ -128,18 +125,23 @@ def run_whisper(): def parse_args(): parser = 
FlexibleArgumentParser( - description='Demo on using vLLM for offline inference with ' - 'vision language models for text generation') - parser.add_argument('--model-type', - '-m', - type=str, - default="mllama", - choices=model_example_map.keys(), - help='Huggingface "model_type".') - parser.add_argument("--seed", - type=int, - default=None, - help="Set the seed when initializing `vllm.LLM`.") + description="Demo on using vLLM for offline inference with " + "vision language models for text generation" + ) + parser.add_argument( + "--model-type", + "-m", + type=str, + default="mllama", + choices=model_example_map.keys(), + help='Huggingface "model_type".', + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="Set the seed when initializing `vllm.LLM`.", + ) return parser.parse_args() @@ -153,7 +155,8 @@ def main(args): # Disable other modalities to save memory default_limits = {"image": 0, "video": 0, "audio": 0} req_data.engine_args.limit_mm_per_prompt = default_limits | dict( - req_data.engine_args.limit_mm_per_prompt or {}) + req_data.engine_args.limit_mm_per_prompt or {} + ) engine_args = asdict(req_data.engine_args) | {"seed": args.seed} llm = LLM(**engine_args) @@ -179,8 +182,7 @@ def main(args): for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text - print(f"Decoder prompt: {prompt!r}, " - f"Generated text: {generated_text!r}") + print(f"Decoder prompt: {prompt!r}, Generated text: {generated_text!r}") duration = time.time() - start diff --git a/examples/offline_inference/llm_engine_example.py b/examples/offline_inference/llm_engine_example.py index d84cd9ee9f52..5d5e55a83d22 100644 --- a/examples/offline_inference/llm_engine_example.py +++ b/examples/offline_inference/llm_engine_example.py @@ -3,6 +3,7 @@ This file demonstrates using the `LLMEngine` for processing prompts with various sampling parameters. 
""" + import argparse from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams @@ -12,24 +13,26 @@ def create_test_prompts() -> list[tuple[str, SamplingParams]]: """Create a list of test prompts with their sampling parameters.""" return [ - ("A robot may not injure a human being", - SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1)), - ("To be or not to be,", - SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)), - ("What is the meaning of life?", - SamplingParams(n=2, - temperature=0.8, - top_p=0.95, - frequency_penalty=0.1)), + ( + "A robot may not injure a human being", + SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1), + ), + ( + "To be or not to be,", + SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2), + ), + ( + "What is the meaning of life?", + SamplingParams(n=2, temperature=0.8, top_p=0.95, frequency_penalty=0.1), + ), ] -def process_requests(engine: LLMEngine, - test_prompts: list[tuple[str, SamplingParams]]): +def process_requests(engine: LLMEngine, test_prompts: list[tuple[str, SamplingParams]]): """Continuously process a list of prompts and handle the outputs.""" request_id = 0 - print('-' * 50) + print("-" * 50) while test_prompts or engine.has_unfinished_requests(): if test_prompts: prompt, sampling_params = test_prompts.pop(0) @@ -41,7 +44,7 @@ def process_requests(engine: LLMEngine, for request_output in request_outputs: if request_output.finished: print(request_output) - print('-' * 50) + print("-" * 50) def initialize_engine(args: argparse.Namespace) -> LLMEngine: @@ -52,7 +55,8 @@ def initialize_engine(args: argparse.Namespace) -> LLMEngine: def parse_args(): parser = FlexibleArgumentParser( - description='Demo on using the LLMEngine class directly') + description="Demo on using the LLMEngine class directly" + ) parser = EngineArgs.add_cli_args(parser) return parser.parse_args() @@ -64,6 +68,6 @@ def main(args: argparse.Namespace): process_requests(engine, test_prompts) -if 
__name__ == '__main__': +if __name__ == "__main__": args = parse_args() main(args) diff --git a/examples/offline_inference/load_sharded_state.py b/examples/offline_inference/load_sharded_state.py index 7e90d5d25e29..5bb2327a3f83 100644 --- a/examples/offline_inference/load_sharded_state.py +++ b/examples/offline_inference/load_sharded_state.py @@ -36,22 +36,21 @@ def parse_args(): parser.set_defaults(load_format="sharded_state") # Add validation arguments - parser.add_argument("--prompt", - type=str, - default="Hello, world!", - help="Prompt for validation") - parser.add_argument("--max-tokens", - type=int, - default=100, - help="Maximum number of tokens to generate") - parser.add_argument("--temperature", - type=float, - default=0.7, - help="Sampling temperature") - parser.add_argument("--top-p", - type=float, - default=1.0, - help="Top-p sampling parameter") + parser.add_argument( + "--prompt", type=str, default="Hello, world!", help="Prompt for validation" + ) + parser.add_argument( + "--max-tokens", + type=int, + default=100, + help="Maximum number of tokens to generate", + ) + parser.add_argument( + "--temperature", type=float, default=0.7, help="Sampling temperature" + ) + parser.add_argument( + "--top-p", type=float, default=1.0, help="Top-p sampling parameter" + ) return parser.parse_args() @@ -60,8 +59,9 @@ def main(): args = parse_args() engine_args = EngineArgs.from_cli_args(args) - print(f"Loading model from {engine_args.model} " - f"using format {engine_args.load_format}") + print( + f"Loading model from {engine_args.model} using format {engine_args.load_format}" + ) print(f"Tensor parallel size: {engine_args.tensor_parallel_size}") # Load the model using engine args @@ -90,4 +90,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/offline_inference/lora_with_quantization_inference.py b/examples/offline_inference/lora_with_quantization_inference.py index b6608ec6e958..33c660015ba7 100644 --- 
a/examples/offline_inference/lora_with_quantization_inference.py +++ b/examples/offline_inference/lora_with_quantization_inference.py @@ -17,50 +17,55 @@ def create_test_prompts( - lora_path: str + lora_path: str, ) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]: return [ # this is an example of using quantization without LoRA - ("My name is", - SamplingParams(temperature=0.0, - logprobs=1, - prompt_logprobs=1, - max_tokens=128), None), + ( + "My name is", + SamplingParams( + temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128 + ), + None, + ), # the next three examples use quantization with LoRA - ("my name is", - SamplingParams(temperature=0.0, - logprobs=1, - prompt_logprobs=1, - max_tokens=128), - LoRARequest("lora-test-1", 1, lora_path)), - ("The capital of USA is", - SamplingParams(temperature=0.0, - logprobs=1, - prompt_logprobs=1, - max_tokens=128), - LoRARequest("lora-test-2", 1, lora_path)), - ("The capital of France is", - SamplingParams(temperature=0.0, - logprobs=1, - prompt_logprobs=1, - max_tokens=128), - LoRARequest("lora-test-3", 1, lora_path)), + ( + "my name is", + SamplingParams( + temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128 + ), + LoRARequest("lora-test-1", 1, lora_path), + ), + ( + "The capital of USA is", + SamplingParams( + temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128 + ), + LoRARequest("lora-test-2", 1, lora_path), + ), + ( + "The capital of France is", + SamplingParams( + temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128 + ), + LoRARequest("lora-test-3", 1, lora_path), + ), ] -def process_requests(engine: LLMEngine, - test_prompts: list[tuple[str, SamplingParams, - Optional[LoRARequest]]]): +def process_requests( + engine: LLMEngine, + test_prompts: list[tuple[str, SamplingParams, Optional[LoRARequest]]], +): """Continuously process a list of prompts and handle the outputs.""" request_id = 0 while test_prompts or engine.has_unfinished_requests(): if test_prompts: 
prompt, sampling_params, lora_request = test_prompts.pop(0) - engine.add_request(str(request_id), - prompt, - sampling_params, - lora_request=lora_request) + engine.add_request( + str(request_id), prompt, sampling_params, lora_request=lora_request + ) request_id += 1 request_outputs: list[RequestOutput] = engine.step() @@ -71,15 +76,18 @@ def process_requests(engine: LLMEngine, print(f"Output: {request_output.outputs[0].text}") -def initialize_engine(model: str, quantization: str, - lora_repo: Optional[str]) -> LLMEngine: +def initialize_engine( + model: str, quantization: str, lora_repo: Optional[str] +) -> LLMEngine: """Initialize the LLMEngine.""" - engine_args = EngineArgs(model=model, - quantization=quantization, - enable_lora=True, - max_lora_rank=64, - max_loras=4) + engine_args = EngineArgs( + model=model, + quantization=quantization, + enable_lora=True, + max_lora_rank=64, + max_loras=4, + ) return LLMEngine.from_engine_args(engine_args) @@ -90,32 +98,30 @@ def main(): # QLoRA (https://arxiv.org/abs/2305.14314) { "name": "qlora_inference_example", - 'model': "huggyllama/llama-7b", - 'quantization': "bitsandbytes", - 'lora_repo': 'timdettmers/qlora-flan-7b' + "model": "huggyllama/llama-7b", + "quantization": "bitsandbytes", + "lora_repo": "timdettmers/qlora-flan-7b", }, { "name": "AWQ_inference_with_lora_example", - 'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ', - 'quantization': "awq", - 'lora_repo': 'jashing/tinyllama-colorist-lora' + "model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", + "quantization": "awq", + "lora_repo": "jashing/tinyllama-colorist-lora", }, { "name": "GPTQ_inference_with_lora_example", - 'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ', - 'quantization': "gptq", - 'lora_repo': 'jashing/tinyllama-colorist-lora' - } + "model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", + "quantization": "gptq", + "lora_repo": "jashing/tinyllama-colorist-lora", + }, ] for test_config in test_configs: - print( - f"~~~~~~~~~~~~~~~~ Running: 
{test_config['name']} ~~~~~~~~~~~~~~~~" + print(f"~~~~~~~~~~~~~~~~ Running: {test_config['name']} ~~~~~~~~~~~~~~~~") + engine = initialize_engine( + test_config["model"], test_config["quantization"], test_config["lora_repo"] ) - engine = initialize_engine(test_config['model'], - test_config['quantization'], - test_config['lora_repo']) - lora_path = snapshot_download(repo_id=test_config['lora_repo']) + lora_path = snapshot_download(repo_id=test_config["lora_repo"]) test_prompts = create_test_prompts(lora_path) process_requests(engine, test_prompts) @@ -125,5 +131,5 @@ def main(): torch.cuda.empty_cache() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/offline_inference/metrics.py b/examples/offline_inference/metrics.py new file mode 100644 index 000000000000..7927f758cb57 --- /dev/null +++ b/examples/offline_inference/metrics.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm import LLM, SamplingParams +from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Vector + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + + +def main(): + # Create an LLM. + llm = LLM(model="facebook/opt-125m", disable_log_stats=False) + + # Generate texts from the prompts. + outputs = llm.generate(prompts, sampling_params) + + # Print the outputs. 
+ print("-" * 50) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-" * 50) + + # Dump all metrics + for metric in llm.get_metrics(): + if isinstance(metric, Gauge): + print(f"{metric.name} (gauge) = {metric.value}") + elif isinstance(metric, Counter): + print(f"{metric.name} (counter) = {metric.value}") + elif isinstance(metric, Vector): + print(f"{metric.name} (vector) = {metric.values}") + elif isinstance(metric, Histogram): + print(f"{metric.name} (histogram)") + print(f" sum = {metric.sum}") + print(f" count = {metric.count}") + for bucket_le, value in metric.buckets.items(): + print(f" {bucket_le} = {value}") + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/mistral-small.py b/examples/offline_inference/mistral-small.py index 37c3181dc5fa..98fef2648f6b 100644 --- a/examples/offline_inference/mistral-small.py +++ b/examples/offline_inference/mistral-small.py @@ -74,19 +74,10 @@ def run_simple_demo(args: argparse.Namespace): messages = [ { - "role": - "user", + "role": "user", "content": [ - { - "type": "text", - "text": prompt - }, - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": {"url": image_url}}, ], }, ] @@ -121,25 +112,11 @@ def run_advanced_demo(args: argparse.Namespace): messages = [ { - "role": - "user", + "role": "user", "content": [ - { - "type": "text", - "text": prompt - }, - { - "type": "image_url", - "image_url": { - "url": url_1 - } - }, - { - "type": "image_url", - "image_url": { - "url": url_2 - } - }, + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": {"url": url_1}}, + {"type": "image_url", "image_url": {"url": url_2}}, ], }, { @@ -153,12 +130,7 @@ def run_advanced_demo(args: argparse.Namespace): { "role": "user", "content": [ - { - "type": "image_url", - "image_url": { - "url": 
url_3 - } - }, + {"type": "image_url", "image_url": {"url": url_3}}, ], }, ] @@ -171,7 +143,8 @@ def run_advanced_demo(args: argparse.Namespace): def parse_args(): parser = argparse.ArgumentParser( - description="Run a demo in simple or advanced mode.") + description="Run a demo in simple or advanced mode." + ) parser.add_argument( "mode", @@ -179,15 +152,18 @@ def parse_args(): help="Specify the demo mode: 'simple' or 'advanced'", ) - parser.add_argument('--format', - choices=["mistral", "hf"], - default="mistral", - help='Specify the format of the model to load.') + parser.add_argument( + "--format", + choices=["mistral", "hf"], + default="mistral", + help="Specify the format of the model to load.", + ) parser.add_argument( - '--disable-mm-preprocessor-cache', - action='store_true', - help='If True, disables caching of multi-modal preprocessor/mapper.') + "--disable-mm-preprocessor-cache", + action="store_true", + help="If True, disables caching of multi-modal preprocessor/mapper.", + ) return parser.parse_args() diff --git a/examples/offline_inference/mlpspeculator.py b/examples/offline_inference/mlpspeculator.py index 53c58a76d9dc..b750397f45b8 100644 --- a/examples/offline_inference/mlpspeculator.py +++ b/examples/offline_inference/mlpspeculator.py @@ -13,8 +13,9 @@ from vllm import LLM, SamplingParams -def time_generation(llm: LLM, prompts: list[str], - sampling_params: SamplingParams, title: str): +def time_generation( + llm: LLM, prompts: list[str], sampling_params: SamplingParams, title: str +): # Generate texts from the prompts. The output is a list of RequestOutput # objects that contain the prompt, generated text, and other information. # Warmup first @@ -25,8 +26,7 @@ def time_generation(llm: LLM, prompts: list[str], end = time.time() print("-" * 50) print(title) - print("time: ", - (end - start) / sum(len(o.outputs[0].token_ids) for o in outputs)) + print("time: ", (end - start) / sum(len(o.outputs[0].token_ids) for o in outputs)) # Print the outputs. 
for output in outputs: generated_text = output.outputs[0].text @@ -38,7 +38,8 @@ def main(): template = ( "Below is an instruction that describes a task. Write a response " "that appropriately completes the request.\n\n### Instruction:\n{}" - "\n\n### Response:\n") + "\n\n### Response:\n" + ) # Sample prompts. prompts = [ diff --git a/examples/offline_inference/multilora_inference.py b/examples/offline_inference/multilora_inference.py index de409740292a..1fa2f16f82a8 100644 --- a/examples/offline_inference/multilora_inference.py +++ b/examples/offline_inference/multilora_inference.py @@ -15,7 +15,7 @@ def create_test_prompts( - lora_path: str + lora_path: str, ) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]: """Create a list of test prompts with their sampling parameters. @@ -26,38 +26,49 @@ def create_test_prompts( first adapter have finished. """ return [ - ("A robot may not injure a human being", - SamplingParams(temperature=0.0, - logprobs=1, - prompt_logprobs=1, - max_tokens=128), None), - ("To be or not to be,", - SamplingParams(temperature=0.8, - top_k=5, - presence_penalty=0.2, - max_tokens=128), None), + ( + "A robot may not injure a human being", + SamplingParams( + temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128 + ), + None, + ), + ( + "To be or not to be,", + SamplingParams( + temperature=0.8, top_k=5, presence_penalty=0.2, max_tokens=128 + ), + None, + ), ( "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 - SamplingParams(temperature=0.0, - logprobs=1, - prompt_logprobs=1, - max_tokens=128, - stop_token_ids=[32003]), - LoRARequest("sql-lora", 1, lora_path)), + SamplingParams( + temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128, + stop_token_ids=[32003], + ), + LoRARequest("sql-lora", 1, lora_path), + ), ( "[user] 
Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 - SamplingParams(temperature=0.0, - logprobs=1, - prompt_logprobs=1, - max_tokens=128, - stop_token_ids=[32003]), - LoRARequest("sql-lora2", 2, lora_path)), + SamplingParams( + temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128, + stop_token_ids=[32003], + ), + LoRARequest("sql-lora2", 2, lora_path), + ), ] -def process_requests(engine: LLMEngine, - test_prompts: list[tuple[str, SamplingParams, - Optional[LoRARequest]]]): +def process_requests( + engine: LLMEngine, + test_prompts: list[tuple[str, SamplingParams, Optional[LoRARequest]]], +): """Continuously process a list of prompts and handle the outputs.""" request_id = 0 @@ -65,10 +76,9 @@ def process_requests(engine: LLMEngine, while test_prompts or engine.has_unfinished_requests(): if test_prompts: prompt, sampling_params, lora_request = test_prompts.pop(0) - engine.add_request(str(request_id), - prompt, - sampling_params, - lora_request=lora_request) + engine.add_request( + str(request_id), prompt, sampling_params, lora_request=lora_request + ) request_id += 1 request_outputs: list[RequestOutput] = engine.step() @@ -88,12 +98,14 @@ def initialize_engine() -> LLMEngine: # numbers will cause higher memory usage. If you know that all LoRAs will # use the same rank, it is recommended to set this as low as possible. # max_cpu_loras: controls the size of the CPU LoRA cache. 
- engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf", - enable_lora=True, - max_loras=1, - max_lora_rank=8, - max_cpu_loras=2, - max_num_seqs=256) + engine_args = EngineArgs( + model="meta-llama/Llama-2-7b-hf", + enable_lora=True, + max_loras=1, + max_lora_rank=8, + max_cpu_loras=2, + max_num_seqs=256, + ) return LLMEngine.from_engine_args(engine_args) @@ -105,5 +117,5 @@ def main(): process_requests(engine, test_prompts) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/offline_inference/neuron.py b/examples/offline_inference/neuron.py index 5906c7b2c6b3..f2d7698f22d7 100644 --- a/examples/offline_inference/neuron.py +++ b/examples/offline_inference/neuron.py @@ -30,7 +30,8 @@ def main(): # The device argument can be either unspecified for automated detection, # or explicitly assigned. device="neuron", - tensor_parallel_size=2) + tensor_parallel_size=2, + ) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) diff --git a/examples/offline_inference/neuron_eagle.py b/examples/offline_inference/neuron_eagle.py index 4f63f1a2fb3c..5d7fb819d347 100644 --- a/examples/offline_inference/neuron_eagle.py +++ b/examples/offline_inference/neuron_eagle.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """ -This example shows how to run offline inference with an EAGLE speculative +This example shows how to run offline inference with an EAGLE speculative decoding model on neuron. To use EAGLE speculative decoding, you must use a draft model that is specifically fine-tuned for EAGLE speculation. Additionally, to use EAGLE with NxD Inference, the draft model must include @@ -15,40 +15,46 @@ "What is annapurna labs?", ] -# Create a sampling params object. -sampling_params = SamplingParams(top_k=1, max_tokens=500, ignore_eos=True) - -# Create an LLM. 
-llm = LLM( - model="/home/ubuntu/model_hf/Meta-Llama-3.1-70B-Instruct", - speculative_config={ - "model": "/home/ubuntu/model_hf/Llama-3.1-70B-Instruct-EAGLE-Draft", - "num_speculative_tokens": 5, - "max_model_len": 2048 - }, - max_num_seqs=4, - # The max_model_len and block_size arguments are required to be same as - # max sequence length when targeting neuron device. - # Currently, this is a known limitation in continuous batching support - # in neuronx-distributed-inference. - max_model_len=2048, - block_size=2048, - # The device can be automatically detected when AWS Neuron SDK is installed. - # The device argument can be either unspecified for automated detection, - # or explicitly assigned. - device="neuron", - tensor_parallel_size=32, - override_neuron_config={ - "enable_eagle_speculation": True, - "enable_fused_speculation": True - }, -) - -# Generate texts from the prompts. The output is a list of RequestOutput objects -# that contain the prompt, generated text, and other information. -outputs = llm.generate(prompts, sampling_params) -# Print the outputs. -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, \n\n\n\ Generated text: {generated_text!r}") + +def main(): + # Create a sampling params object. + sampling_params = SamplingParams(top_k=1, max_tokens=500, ignore_eos=True) + + # Create an LLM. + llm = LLM( + model="/home/ubuntu/model_hf/Meta-Llama-3.1-70B-Instruct", + speculative_config={ + "model": "/home/ubuntu/model_hf/Llama-3.1-70B-Instruct-EAGLE-Draft", + "num_speculative_tokens": 5, + "max_model_len": 2048, + }, + max_num_seqs=4, + # The max_model_len and block_size arguments are required to be same as + # max sequence length when targeting neuron device. + # Currently, this is a known limitation in continuous batching support + # in neuronx-distributed-inference. 
+ max_model_len=2048, + block_size=2048, + # The device can be automatically detected when AWS Neuron SDK is installed. + # The device argument can be either unspecified for automated detection, + # or explicitly assigned. + device="neuron", + tensor_parallel_size=32, + override_neuron_config={ + "enable_eagle_speculation": True, + "enable_fused_speculation": True, + }, + ) + + # Generate texts from the prompts. The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params) + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, \n\n\n\ Generated text: {generated_text!r}") + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/neuron_int8_quantization.py b/examples/offline_inference/neuron_int8_quantization.py index af21274a3a5b..ec38525b9daf 100644 --- a/examples/offline_inference/neuron_int8_quantization.py +++ b/examples/offline_inference/neuron_int8_quantization.py @@ -5,12 +5,12 @@ from vllm import LLM, SamplingParams # creates XLA hlo graphs for all the context length buckets. -os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048" +os.environ["NEURON_CONTEXT_LENGTH_BUCKETS"] = "128,512,1024,2048" # creates XLA hlo graphs for all the token gen buckets. -os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048" +os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128,512,1024,2048" # Quantizes neuron model weight to int8 , # The default config for quantization is int8 dtype. -os.environ['NEURON_QUANT_DTYPE'] = "s8" +os.environ["NEURON_QUANT_DTYPE"] = "s8" # Sample prompts. prompts = [ @@ -44,7 +44,8 @@ def main(): override_neuron_config={ "cast_logits_dtype": "bfloat16", }, - tensor_parallel_size=2) + tensor_parallel_size=2, + ) # Generate texts from the prompts. 
The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) diff --git a/examples/offline_inference/neuron_speculation.py b/examples/offline_inference/neuron_speculation.py index bef434bae5ba..ecacbab771c2 100644 --- a/examples/offline_inference/neuron_speculation.py +++ b/examples/offline_inference/neuron_speculation.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """ -This example shows how to run offline inference with a speculative +This example shows how to run offline inference with a speculative decoding model on neuron. """ @@ -19,9 +19,9 @@ def config_buckets(): """Configure context length and token gen buckets.""" # creates XLA hlo graphs for all the context length buckets. - os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048" + os.environ["NEURON_CONTEXT_LENGTH_BUCKETS"] = "128,512,1024,2048" # creates XLA hlo graphs for all the token gen buckets. - os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048" + os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128,512,1024,2048" def initialize_model(): @@ -31,7 +31,7 @@ def initialize_model(): speculative_config={ "model": "openlm-research/open_llama_3b", "num_speculative_tokens": 4, - "max_model_len": 2048 + "max_model_len": 2048, }, max_num_seqs=4, max_model_len=2048, @@ -60,5 +60,5 @@ def main(): process_requests(model, sampling_params) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/offline_inference/openai/openai_batch.md b/examples/offline_inference/openai_batch/README.md similarity index 94% rename from examples/offline_inference/openai/openai_batch.md rename to examples/offline_inference/openai_batch/README.md index d271573aa96f..42a19f71e9de 100644 --- a/examples/offline_inference/openai/openai_batch.md +++ b/examples/offline_inference/openai_batch/README.md @@ -8,7 +8,7 @@ This is a guide to performing batch inference using the OpenAI batch 
file format The OpenAI batch file format consists of a series of json objects on new lines. -[See here for an example file.](https://github.com/vllm-project/vllm/blob/main/examples/offline_inference/openai/openai_example_batch.jsonl) +[See here for an example file.](https://github.com/vllm-project/vllm/blob/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl) Each line represents a separate request. See the [OpenAI package reference](https://platform.openai.com/docs/api-reference/batch/requestInput) for more details. @@ -30,13 +30,13 @@ We currently support `/v1/chat/completions`, `/v1/embeddings`, and `/v1/score` e To follow along with this example, you can download the example batch, or create your own batch file in your working directory. ```console -wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai/openai_example_batch.jsonl +wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl ``` Once you've created your batch file it should look like this ```console -$ cat offline_inference/openai/openai_example_batch.jsonl +$ cat offline_inference/openai_batch/openai_example_batch.jsonl {"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}} {"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}} ``` @@ -48,7 +48,7 @@ The batch running tool is designed to be used from the command line. 
You can run the batch with the following command, which will write its results to a file called `results.jsonl` ```console -python -m vllm.entrypoints.openai.run_batch -i offline_inference/openai/openai_example_batch.jsonl -o results.jsonl --model meta-llama/Meta-Llama-3-8B-Instruct +python -m vllm.entrypoints.openai.run_batch -i offline_inference/openai_batch/openai_example_batch.jsonl -o results.jsonl --model meta-llama/Meta-Llama-3-8B-Instruct ``` ### Step 3: Check your results @@ -65,10 +65,10 @@ $ cat results.jsonl The batch runner supports remote input and output urls that are accessible via http/https. -For example, to run against our example input file located at `https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai/openai_example_batch.jsonl`, you can run +For example, to run against our example input file located at `https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl`, you can run ```console -python -m vllm.entrypoints.openai.run_batch -i https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai/openai_example_batch.jsonl -o results.jsonl --model meta-llama/Meta-Llama-3-8B-Instruct +python -m vllm.entrypoints.openai.run_batch -i https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl -o results.jsonl --model meta-llama/Meta-Llama-3-8B-Instruct ``` ## Example 3: Integrating with AWS S3 @@ -89,13 +89,13 @@ To integrate with cloud blob storage, we recommend using presigned urls. To follow along with this example, you can download the example batch, or create your own batch file in your working directory. 
```console -wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai/openai_example_batch.jsonl +wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl ``` Once you've created your batch file it should look like this ```console -$ cat offline_inference/openai/openai_example_batch.jsonl +$ cat offline_inference/openai_batch/openai_example_batch.jsonl {"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}} {"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}} ``` @@ -103,7 +103,7 @@ $ cat offline_inference/openai/openai_example_batch.jsonl Now upload your batch file to your S3 bucket. 
```console -aws s3 cp offline_inference/openai/openai_example_batch.jsonl s3://MY_BUCKET/MY_INPUT_FILE.jsonl +aws s3 cp offline_inference/openai_batch/openai_example_batch.jsonl s3://MY_BUCKET/MY_INPUT_FILE.jsonl ``` ### Step 2: Generate your presigned urls diff --git a/examples/offline_inference/openai/openai_example_batch.jsonl b/examples/offline_inference/openai_batch/openai_example_batch.jsonl similarity index 100% rename from examples/offline_inference/openai/openai_example_batch.jsonl rename to examples/offline_inference/openai_batch/openai_example_batch.jsonl diff --git a/examples/offline_inference/prefix_caching.py b/examples/offline_inference/prefix_caching.py index f0bec387d3a9..d3dad24956a6 100644 --- a/examples/offline_inference/prefix_caching.py +++ b/examples/offline_inference/prefix_caching.py @@ -16,7 +16,8 @@ "teaching role. They have 5 years of previous teaching experience " "as an assistant teacher at a co-ed, public school with experience " "in middle school math teaching. Based on these information, fulfill " - "the following paragraph: ") + "the following paragraph: " +) # Sample prompts. prompts = [ @@ -58,9 +59,11 @@ def main(): cleanup_dist_env_and_memory() # Create an LLM with prefix caching enabled. - prefix_cached_llm = LLM(model="facebook/opt-125m", - enable_prefix_caching=True, - gpu_memory_utilization=0.4) + prefix_cached_llm = LLM( + model="facebook/opt-125m", + enable_prefix_caching=True, + gpu_memory_utilization=0.4, + ) # Warmup so that the shared prompt's KV cache is computed. 
prefix_cached_llm.generate(generating_prompts[0], sampling_params) @@ -81,10 +84,12 @@ def main(): print("-" * 50) # Compare the results and display the speedup - generated_same = all([ - regular_generated_texts[i] == cached_generated_texts[i] - for i in range(len(prompts)) - ]) + generated_same = all( + [ + regular_generated_texts[i] == cached_generated_texts[i] + for i in range(len(prompts)) + ] + ) print(f"Generated answers are the same: {generated_same}") diff --git a/examples/offline_inference/prithvi_geospatial_mae.py b/examples/offline_inference/prithvi_geospatial_mae.py index f97a1f32e621..21f7668adc86 100644 --- a/examples/offline_inference/prithvi_geospatial_mae.py +++ b/examples/offline_inference/prithvi_geospatial_mae.py @@ -16,16 +16,17 @@ Run the example: python prithvi_geospatial_mae.py -""" # noqa: E501 +""" # noqa: E501 + import argparse import datetime import os -import re from typing import Union import albumentations import numpy as np import rasterio +import regex as re import torch from einops import rearrange from terratorch.datamodules import Sen1Floods11NonGeoDataModule @@ -110,77 +111,67 @@ # Temporarily creating the "config.json" for the model. 
# This is going to disappear once the correct config.json is available on HF -with open(os.path.join(os.path.dirname(__file__), "./model/config.json"), - 'w') as config_file: +with open( + os.path.join(os.path.dirname(__file__), "./model/config.json"), "w" +) as config_file: config_file.write(model_config) datamodule_config = { - 'bands': ['BLUE', 'GREEN', 'RED', 'NIR_NARROW', 'SWIR_1', 'SWIR_2'], - 'batch_size': - 16, - 'constant_scale': - 0.0001, - 'data_root': - '/dccstor/geofm-finetuning/datasets/sen1floods11', - 'drop_last': - True, - 'no_data_replace': - 0.0, - 'no_label_replace': - -1, - 'num_workers': - 8, - 'test_transform': [ - albumentations.Resize(always_apply=False, - height=448, - interpolation=1, - p=1, - width=448), - albumentations.pytorch.ToTensorV2(transpose_mask=False, - always_apply=True, - p=1.0) + "bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"], + "batch_size": 16, + "constant_scale": 0.0001, + "data_root": "/dccstor/geofm-finetuning/datasets/sen1floods11", + "drop_last": True, + "no_data_replace": 0.0, + "no_label_replace": -1, + "num_workers": 8, + "test_transform": [ + albumentations.Resize( + always_apply=False, height=448, interpolation=1, p=1, width=448 + ), + albumentations.pytorch.ToTensorV2( + transpose_mask=False, always_apply=True, p=1.0 + ), ], } class PrithviMAE: - def __init__(self): print("Initializing PrithviMAE model") - self.model = LLM(model=os.path.join(os.path.dirname(__file__), - "./model"), - skip_tokenizer_init=True, - dtype="float32") + self.model = LLM( + model=os.path.join(os.path.dirname(__file__), "./model"), + skip_tokenizer_init=True, + dtype="float32", + ) def run(self, input_data, location_coords): print("################ Running inference on vLLM ##############") # merge the inputs into one data structure mm_data = { - "pixel_values": - torch.empty(0) if input_data is None else input_data, - "location_coords": - torch.empty(0) if location_coords is None else location_coords + 
"pixel_values": torch.empty(0) if input_data is None else input_data, + "location_coords": torch.empty(0) + if location_coords is None + else location_coords, } prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data} outputs = self.model.encode(prompt, use_tqdm=False) - print( - "################ Inference done (it took seconds) ##############" - ) + print("################ Inference done (it took seconds) ##############") return outputs[0].outputs.data def generate_datamodule(): datamodule = Sen1Floods11NonGeoDataModule( - data_root=datamodule_config['data_root'], + data_root=datamodule_config["data_root"], batch_size=datamodule_config["batch_size"], num_workers=datamodule_config["num_workers"], bands=datamodule_config["bands"], drop_last=datamodule_config["drop_last"], - test_transform=datamodule_config["test_transform" - ""]) + test_transform=datamodule_config["test_transform"], + ) return datamodule @@ -204,8 +195,7 @@ def process_channel_group(orig_img, channels): max_value = max(3000, np.percentile(orig_img[valid_mask], PERCENTILE)) min_value = OFFSET - orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0, - 1) + orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0, 1) # No data as zeros orig_img[~valid_mask] = 0 @@ -300,18 +290,21 @@ def load_example( location_coords.append(coords) try: - match = re.search(r'(\d{7,8}T\d{6})', file) + match = re.search(r"(\d{7,8}T\d{6})", file) if match: year = int(match.group(1)[:4]) - julian_day = match.group(1).split('T')[0][4:] + julian_day = match.group(1).split("T")[0][4:] if len(julian_day) == 3: julian_day = int(julian_day) else: - julian_day = datetime.datetime.strptime( - julian_day, '%m%d').timetuple().tm_yday + julian_day = ( + datetime.datetime.strptime(julian_day, "%m%d") + .timetuple() + .tm_yday + ) temporal_coords.append([year, julian_day]) except Exception as e: - print(f'Could not extract timestamp for {file} ({e})') + print(f"Could not extract timestamp 
for {file} ({e})") imgs = np.stack(imgs, axis=0) # num_frames, H, W, C imgs = np.moveaxis(imgs, -1, 0).astype("float32") @@ -320,50 +313,44 @@ def load_example( return imgs, temporal_coords, location_coords, metas -def run_model(input_data, - temporal_coords, - location_coords, - model, - datamodule, - img_size, - lightning_model=None): +def run_model( + input_data, + temporal_coords, + location_coords, + model, + datamodule, + img_size, + lightning_model=None, +): # Reflect pad if not divisible by img_size original_h, original_w = input_data.shape[-2:] pad_h = (img_size - (original_h % img_size)) % img_size pad_w = (img_size - (original_w % img_size)) % img_size - input_data = np.pad(input_data, - ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), - mode="reflect") + input_data = np.pad( + input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode="reflect" + ) # Build sliding window batch_size = 1 batch = torch.tensor(input_data, device="cpu") - windows = (batch.unfold(3, img_size, - img_size).unfold(4, img_size, img_size)) + windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size) h1, w1 = windows.shape[3:5] - windows = rearrange(windows, - "b c t h1 w1 h w -> (b h1 w1) c t h w", - h=img_size, - w=img_size) + windows = rearrange( + windows, "b c t h1 w1 h w -> (b h1 w1) c t h w", h=img_size, w=img_size + ) # Split into batches if number of windows > batch_size - num_batches = windows.shape[0] // batch_size if windows.shape[ - 0] > batch_size else 1 + num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1 windows = torch.tensor_split(windows, num_batches, dim=0) - if torch.cuda.is_available(): - device = torch.device('cuda') - else: - device = torch.device('cpu') + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") if temporal_coords: - temporal_coords = torch.tensor(temporal_coords, - device=device).unsqueeze(0) + temporal_coords = torch.tensor(temporal_coords, 
device=device).unsqueeze(0) else: temporal_coords = None if location_coords: - location_coords = torch.tensor(location_coords[0], - device=device).unsqueeze(0) + location_coords = torch.tensor(location_coords[0], device=device).unsqueeze(0) else: location_coords = None @@ -371,26 +358,24 @@ def run_model(input_data, pred_imgs = [] for x in windows: # Apply standardization - x = datamodule.test_transform( - image=x.squeeze().numpy().transpose(1, 2, 0)) - x = datamodule.aug(x)['image'] + x = datamodule.test_transform(image=x.squeeze().numpy().transpose(1, 2, 0)) + x = datamodule.aug(x)["image"] with torch.no_grad(): x = x.to(device) pred = model.run(x, location_coords=location_coords) if lightning_model: pred_lightning = lightning_model( - x, - temporal_coords=temporal_coords, - location_coords=location_coords) + x, temporal_coords=temporal_coords, location_coords=location_coords + ) pred_lightning = pred_lightning.output.detach().cpu() if not torch.equal(pred, pred_lightning): print("Inference output is not equal") y_hat = pred.argmax(dim=1) - y_hat = torch.nn.functional.interpolate(y_hat.unsqueeze(1).float(), - size=img_size, - mode="nearest") + y_hat = torch.nn.functional.interpolate( + y_hat.unsqueeze(1).float(), size=img_size, mode="nearest" + ) pred_imgs.append(y_hat) @@ -437,8 +422,7 @@ def parse_args(): default=[1, 2, 3, 8, 11, 12], type=int, nargs="+", - help= - "0-based indices of the six Prithvi channels to be selected from the " + help="0-based indices of the six Prithvi channels to be selected from the " "input. 
By default selects [1,2,3,8,11,12] for S2L1C data.", ) parser.add_argument( @@ -478,17 +462,18 @@ def main( # Running model ------------------------------------------------------------ channels = [ - datamodule_config['bands'].index(b) for b in ["RED", "GREEN", "BLUE"] + datamodule_config["bands"].index(b) for b in ["RED", "GREEN", "BLUE"] ] # BGR -> RGB - pred = run_model(input_data, temporal_coords, location_coords, model_obj, - datamodule, img_size) + pred = run_model( + input_data, temporal_coords, location_coords, model_obj, datamodule, img_size + ) # Save pred meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0) pred_file = os.path.join( - output_dir, - f"pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff") + output_dir, f"pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff" + ) save_geotiff(_convert_np_uint8(pred), pred_file, meta_data) # Save image + pred @@ -502,13 +487,13 @@ def main( channels=channels, ) - pred[pred == 0.] = np.nan + pred[pred == 0.0] = np.nan img_pred = rgb_orig * 0.7 + pred * 0.3 img_pred[img_pred.isnan()] = rgb_orig[img_pred.isnan()] img_pred_file = os.path.join( - output_dir, - f"rgb_pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff") + output_dir, f"rgb_pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff" + ) save_geotiff( image=_convert_np_uint8(img_pred), output_path=img_pred_file, @@ -518,8 +503,9 @@ def main( # Save image rgb if rgb_outputs: rgb_file = os.path.join( - output_dir, "original_rgb_" - f"{os.path.splitext(os.path.basename(data_file))[0]}.tiff") + output_dir, + f"original_rgb_{os.path.splitext(os.path.basename(data_file))[0]}.tiff", + ) save_geotiff( image=_convert_np_uint8(rgb_orig), output_path=rgb_file, @@ -528,7 +514,6 @@ def main( if __name__ == "__main__": - args = parse_args() main(**vars(args)) diff --git a/examples/offline_inference/profiling.py b/examples/offline_inference/profiling.py index 99303950d39d..244a64b891c9 100644 --- 
a/examples/offline_inference/profiling.py +++ b/examples/offline_inference/profiling.py @@ -44,14 +44,17 @@ def get_dtype(dtype: str): OutputLen_NumReqs_Map: TypeAlias = dict[int, int] -def compute_request_output_lengths(batch_size: int, step_requests: list[int]) \ - -> OutputLen_NumReqs_Map: + + +def compute_request_output_lengths( + batch_size: int, step_requests: list[int] +) -> OutputLen_NumReqs_Map: """ Given the number of requests, batch_size, and the number of requests that each engine-step should process, step_requests, determine the output lengths of the requests such that step_request is honoured. - Example: + Example: if batch size = 128 and step_request = [128, 128, 96, 64, 32, 1] then return, {2 : 32, 3 : 32, 4 : 32, 5 : 31, 6 : 1}, meaning, @@ -100,17 +103,19 @@ def compute_request_output_lengths(batch_size: int, step_requests: list[int]) \ output_length -= 1 # sanity checks. - assert sum(ol_nr.values()) == batch_size, \ - ("Number of requests in output-length assignment does not match " - f"batch-size.\n batch size {batch_size} - " - f"step requests {step_requests} - assignments {ol_nr}") + assert sum(ol_nr.values()) == batch_size, ( + "Number of requests in output-length assignment does not match " + f"batch-size.\n batch size {batch_size} - " + f"step requests {step_requests} - assignments {ol_nr}" + ) # Check that the output-length is in [1, num-steps]. Output length must be # at least 1 as all requests must participate in the prefill-step. 
- assert all(ol >= 1 and ol <= num_steps for ol in ol_nr), \ - ("Output lengths of requests should be in range " - f"[1, num-engine-steps].\n batch size {batch_size} - " - f"step requests {step_requests} - assignments {ol_nr}") + assert all(ol >= 1 and ol <= num_steps for ol in ol_nr), ( + "Output lengths of requests should be in range " + f"[1, num-engine-steps].\n batch size {batch_size} - " + f"step requests {step_requests} - assignments {ol_nr}" + ) return ol_nr @@ -131,7 +136,7 @@ def determine_requests_per_step(context: ProfileContext) -> list[int]: context: ProfileContext object. Returns: - list[int]: Number of requests to process for all engine-steps. + list[int]: Number of requests to process for all engine-steps. output[i], contains the number of requests that the ith step should process. """ @@ -140,10 +145,13 @@ def determine_requests_per_step(context: ProfileContext) -> list[int]: # that their output lengths must be equal to num_engine_steps. return [context.batch_size] * context.num_steps - assert context.complete_num_requests_per_step and \ - context.complete_num_requests_per_step > 0, \ - (f"Expected a positive complete_num_requests_per_step argument." - f"Instead got {context.complete_num_requests_per_step}") + assert ( + context.complete_num_requests_per_step + and context.complete_num_requests_per_step > 0 + ), ( + f"Expected a positive complete_num_requests_per_step argument." + f"Instead got {context.complete_num_requests_per_step}" + ) # We start dropping after the first decode step. 
step_requests = [ @@ -165,8 +173,9 @@ def determine_requests_per_step(context: ProfileContext) -> list[int]: return step_requests -def run_profile(context: ProfileContext, csv_output: Optional[str], - json_output: Optional[str]): +def run_profile( + context: ProfileContext, csv_output: Optional[str], json_output: Optional[str] +): print("Run profile with:") for key, value in asdict(context).items(): print(f" {key} = {value}") @@ -174,7 +183,8 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], requests_per_step: list[int] = determine_requests_per_step(context) ol_nr: OutputLen_NumReqs_Map = compute_request_output_lengths( - context.batch_size, requests_per_step) + context.batch_size, requests_per_step + ) num_steps_to_profile: int = len(requests_per_step) max_output_len: int = max(ol_nr.keys()) @@ -186,44 +196,51 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], top_p=0.95, # max_tokens is set on a per-request basis. max_tokens=None, - ignore_eos=True) + ignore_eos=True, + ) # Create LLM llm = LLM(**asdict(context.engine_args)) batch_size = context.batch_size prompt_len = context.prompt_len - scheduler_config = llm.llm_engine.scheduler_config + scheduler_config = llm.llm_engine.vllm_config.scheduler_config max_model_len = llm.llm_engine.model_config.max_model_len max_num_batched_tokens = scheduler_config.max_num_batched_tokens max_num_seqs = scheduler_config.max_num_seqs if batch_size * prompt_len > max_num_batched_tokens: - print(f"ERROR: chosen batch_size * prompt_len " - f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is " - f"larger than max_num_batched_tokens ({max_num_batched_tokens}) " - f"and therefore cannot be run in a single profile step, please " - f"choose a smaller batch size or prompt length, or increase " - f"--max-num-batched-tokens") + print( + f"ERROR: chosen batch_size * prompt_len " + f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is " + f"larger than max_num_batched_tokens 
({max_num_batched_tokens}) " + f"and therefore cannot be run in a single profile step, please " + f"choose a smaller batch size or prompt length, or increase " + f"--max-num-batched-tokens" + ) sys.exit(-1) if batch_size > max_num_seqs: print( f"ERROR: chosen batch_size ({batch_size}) is larger than " f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a " - f"single profile step, please choose a smaller batch size") + f"single profile step, please choose a smaller batch size" + ) sys.exit(-1) - print("llm.llm_engine.model_config.max_model_len: ", - llm.llm_engine.model_config.max_model_len) + print( + "llm.llm_engine.model_config.max_model_len: ", + llm.llm_engine.model_config.max_model_len, + ) if prompt_len + max_output_len > llm.llm_engine.model_config.max_model_len: - print(f"ERROR: chosen prompt_len + max_output_len ({prompt_len} + " - f"{max_output_len} = {prompt_len + max_output_len}) is larger " - f"than the model's max_model_len ({max_model_len}), please " - f"choose a smaller prompt_len or max_output_len, or increase " - f"--max-model-len") + print( + f"ERROR: chosen prompt_len + max_output_len ({prompt_len} + " + f"{max_output_len} = {prompt_len + max_output_len}) is larger " + f"than the model's max_model_len ({max_model_len}), please " + f"choose a smaller prompt_len or max_output_len, or increase " + f"--max-model-len" + ) sys.exit(-1) def add_requests(): - def get_output_len_generator() -> Generator[int, Any, Any]: for output_len, num_reqs in ol_nr.items(): for _ in range(num_reqs): @@ -234,13 +251,15 @@ def get_output_len_generator() -> Generator[int, Any, Any]: sampling_params.max_tokens = next(output_len_generator) assert isinstance(sampling_params.max_tokens, int) - prompt_token_ids = torch.randint(llm.get_tokenizer().vocab_size, - size=(prompt_len, )).tolist() + prompt_token_ids = torch.randint( + llm.get_tokenizer().vocab_size, size=(prompt_len,) + ).tolist() llm.llm_engine.add_request( request_id=f"seq{i}", - 
prompt={'prompt_token_ids': prompt_token_ids}, - params=sampling_params) + prompt={"prompt_token_ids": prompt_token_ids}, + params=sampling_params, + ) def abort_requests(): for i in range(batch_size): @@ -261,10 +280,8 @@ def abort_requests(): decode_profs = [] for _ in tqdm.tqdm(range(num_steps_to_profile - 1)): - num_running_seqs = llm.llm_engine.scheduler[ - 0].get_num_unfinished_seq_groups() - with layerwise_profile( - num_running_seqs=num_running_seqs) as decode_prof: + num_running_seqs = llm.llm_engine.scheduler[0].get_num_unfinished_seq_groups() + with layerwise_profile(num_running_seqs=num_running_seqs) as decode_prof: llm.llm_engine.step() decode_profs.append(decode_prof) @@ -274,8 +291,7 @@ def abort_requests(): LINE_WIDTH = 80 print("=" * LINE_WIDTH) - print(f"= Prefill Model Table " - f"(prompt_len={prompt_len}, batch_size={batch_size})") + print(f"= Prefill Model Table (prompt_len={prompt_len}, batch_size={batch_size})") print("=" * LINE_WIDTH) print() prefill_results.print_model_table() @@ -283,16 +299,17 @@ def abort_requests(): if has_decode: print() print("=" * LINE_WIDTH) - print(f"= First Decode Step Model Table " - f"(prompt_len={prompt_len}, batch_size={batch_size})") + print( + f"= First Decode Step Model Table " + f"(prompt_len={prompt_len}, batch_size={batch_size})" + ) print("=" * LINE_WIDTH) print() decode_results_list[0].print_model_table() print() print("=" * LINE_WIDTH) - print(f"= Prefill Summary Table " - f"(prompt_len={prompt_len}, batch_size={batch_size})") + print(f"= Prefill Summary Table (prompt_len={prompt_len}, batch_size={batch_size})") print("=" * LINE_WIDTH) print() prefill_results.print_summary_table() @@ -300,25 +317,32 @@ def abort_requests(): if has_decode: print() print("=" * LINE_WIDTH) - print(f"= First Decode Step Summary Table " - f"(prompt_len={prompt_len}, batch_size={batch_size})") + print( + f"= First Decode Step Summary Table " + f"(prompt_len={prompt_len}, batch_size={batch_size})" + ) print("=" * LINE_WIDTH) 
print() decode_results_list[0].print_summary_table() if csv_output: - csv_filename_base = csv_output[:-4] \ - if csv_output.endswith('.csv') else csv_output + csv_filename_base = ( + csv_output[:-4] if csv_output.endswith(".csv") else csv_output + ) prefill_results.export_model_stats_table_csv( - csv_filename_base + "_prefill_model_table.csv") + csv_filename_base + "_prefill_model_table.csv" + ) prefill_results.export_summary_stats_table_csv( - csv_filename_base + "_prefill_summary_table.csv") + csv_filename_base + "_prefill_summary_table.csv" + ) if has_decode: - decode_results_list[0].export_model_stats_table_csv(\ - csv_filename_base + "_decode_model_table.csv") + decode_results_list[0].export_model_stats_table_csv( + csv_filename_base + "_decode_model_table.csv" + ) decode_results_list[0].export_summary_stats_table_csv( - csv_filename_base + "_decode_summary_table.csv") + csv_filename_base + "_decode_summary_table.csv" + ) if json_output: cuda_devices = [ @@ -332,7 +356,7 @@ def abort_requests(): "torch_version": f"{torch.__version__}", "torch_cuda_version": f"{torch.version.cuda}", "cuda_devices": f"{cuda_devices}", - **asdict(context) + **asdict(context), }, "prefill": prefill_results.convert_stats_to_dict(), } @@ -342,8 +366,9 @@ def abort_requests(): json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict() # Add .json to json_output filename if it doesn't exist already. 
- json_output_file = json_output if json_output.endswith( - '.json') else json_output + '.json' + json_output_file = ( + json_output if json_output.endswith(".json") else json_output + ".json" + ) with open(json_output_file, "w+") as f: json.dump(json_dict, f, indent=2) pass @@ -351,16 +376,21 @@ def abort_requests(): if context.save_chrome_traces_folder is not None: os.makedirs(context.save_chrome_traces_folder, exist_ok=True) prefill_prof.profiler.export_chrome_trace( - context.save_chrome_traces_folder + "/prefill.json") + context.save_chrome_traces_folder + "/prefill.json" + ) for idx, decode_prof in enumerate(decode_profs): decode_prof.profiler.export_chrome_trace( - context.save_chrome_traces_folder + f"/decode_{idx + 1}.json") - print("Traces saved as prefill.json and decode_1.json, etc." - f" in folder {context.save_chrome_traces_folder}") + context.save_chrome_traces_folder + f"/decode_{idx + 1}.json" + ) + print( + "Traces saved as prefill.json and decode_1.json, etc." + f" in folder {context.save_chrome_traces_folder}" + ) def parse_args(): - parser = FlexibleArgumentParser(description=""" + parser = FlexibleArgumentParser( + description=""" Profile a model example: @@ -384,7 +414,8 @@ def parse_args(): --output-directory profile_breakdown --plot-metric pct_cuda_time ``` """, - formatter_class=RawTextHelpFormatter) + formatter_class=RawTextHelpFormatter, + ) parser.add_argument( "--csv", type=str, @@ -393,59 +424,68 @@ def parse_args(): "filename, will create <filename>_prefill_model_table.csv, " "<filename>_prefill_summary_table.csv, " "<filename>_decode_model_table.csv, and " - "<filename>_decode_summary_table.csv") + "<filename>_decode_summary_table.csv", + ) parser.add_argument( "--json", type=str, default=None, - help="Export the results as a json file. 
This should be the filename") - parser.add_argument("--save-chrome-traces-folder", - type=str, - help="Save chrome traces for the prefill and decode " - "will save traces as prefill.json and decode_1.json, " - "etc. inside this folder") + help="Export the results as a json file. This should be the filename", + ) + parser.add_argument( + "--save-chrome-traces-folder", + type=str, + help="Save chrome traces for the prefill and decode " + "will save traces as prefill.json and decode_1.json, " + "etc. inside this folder", + ) parser.add_argument( "--prompt-len", type=int, default=PROMPT_LEN_DEFAULT, help=f"Length of the random prompt to use when profiling, all batched " - f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}") - parser.add_argument("--batch-size", - type=int, - default=BATCH_SIZE_DEFAULT, - help=f"Number of requests to run as a single batch, " - f"default={BATCH_SIZE_DEFAULT}") + f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}", + ) + parser.add_argument( + "--batch-size", + type=int, + default=BATCH_SIZE_DEFAULT, + help=f"Number of requests to run as a single batch, " + f"default={BATCH_SIZE_DEFAULT}", + ) subparsers = parser.add_subparsers(dest="cmd") run_num_steps_parser = subparsers.add_parser( - "run_num_steps", - help="This variation profiles n engine.step() invocations.") + "run_num_steps", help="This variation profiles n engine.step() invocations." 
+ ) run_num_steps_parser.add_argument( - '-n', - '--num-steps', + "-n", + "--num-steps", type=int, help="Number of engine steps to profile.\n" "Setting it to 1, profiles only the prefill step.\n" "Setting it to 2, profiles the prefill and first decode step\n" "Setting it to 3, profiles the prefill, 1st and 2nd decode steps\n" - "and so on ...") + "and so on ...", + ) run_to_completion_parser = subparsers.add_parser( "run_to_completion", help="This variation profiles all the engine.step() invocations" - "until the engine exhausts all submitted requests.") + "until the engine exhausts all submitted requests.", + ) run_to_completion_parser.add_argument( - '-n', - '--complete-num-requests-per-step', + "-n", + "--complete-num-requests-per-step", type=int, - help= - "Complete complete_num_requests_per_step requests every decode step." + help="Complete complete_num_requests_per_step requests every decode step." "For e.g., with batch_size 128 and complete_num_requests_per_step 32," "the profiler is run for 6 engine steps, with the steps processing, " "128, 128, 96, 64, 32, 1 requests respectively.\n" "Note that we tack-on a one-request step at the end as it is often " - "useful.") + "useful.", + ) EngineArgs.add_cli_args(parser) @@ -459,7 +499,8 @@ def main(args): k: v for k, v in vars(args).items() if k in inspect.signature(ProfileContext).parameters - }) + }, + ) run_profile(context, csv_output=args.csv, json_output=args.json) diff --git a/examples/offline_inference/profiling_tpu/profiling.py b/examples/offline_inference/profiling_tpu/profiling.py index 61da4705e18e..82737d538df4 100644 --- a/examples/offline_inference/profiling_tpu/profiling.py +++ b/examples/offline_inference/profiling_tpu/profiling.py @@ -31,18 +31,16 @@ def main(args: argparse.Namespace): max_tokens=args.output_len, ) print(sampling_params) - dummy_prompt_token_ids = np.random.randint(10000, - size=(args.batch_size, - args.input_len)) - dummy_prompts: list[PromptType] = [{ - "prompt_token_ids": batch 
- } for batch in dummy_prompt_token_ids.tolist()] + dummy_prompt_token_ids = np.random.randint( + 10000, size=(args.batch_size, args.input_len) + ) + dummy_prompts: list[PromptType] = [ + {"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist() + ] def run_to_completion(): start_time = time.perf_counter() - llm.generate(dummy_prompts, - sampling_params=sampling_params, - use_tqdm=False) + llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False) end_time = time.perf_counter() latency = end_time - start_time return latency @@ -58,10 +56,9 @@ def run_to_completion(): profile_dir = args.profile_result_dir print(f"Profiling (results will be saved to '{profile_dir}')...") # Enable tracing on server - xp.trace_detached("localhost:9012", - profile_dir, - delay_ms=DELAY_MS, - duration_ms=DURATION_MS) + xp.trace_detached( + "localhost:9012", profile_dir, delay_ms=DELAY_MS, duration_ms=DURATION_MS + ) if DELAY_MS == 0: time.sleep(1.0) profile_latencies = [] @@ -72,30 +69,36 @@ def run_to_completion(): return -if __name__ == '__main__': +if __name__ == "__main__": parser = FlexibleArgumentParser( - description='Benchmark the latency of processing a single batch of ' - 'requests till completion.') - parser.add_argument('--input-len', type=int, default=32) - parser.add_argument('--output-len', type=int, default=128) - parser.add_argument('--batch-size', type=int, default=8) - parser.add_argument('--num-iters-warmup', - type=int, - default=5, - help='Number of iterations to run for warmup.') - parser.add_argument('--num-iters', - type=int, - default=1, - help='Number of iterations to run for profiling.') + description="Benchmark the latency of processing a single batch of " + "requests till completion." 
+ ) + parser.add_argument("--input-len", type=int, default=32) + parser.add_argument("--output-len", type=int, default=128) + parser.add_argument("--batch-size", type=int, default=8) + parser.add_argument( + "--num-iters-warmup", + type=int, + default=5, + help="Number of iterations to run for warmup.", + ) + parser.add_argument( + "--num-iters", + type=int, + default=1, + help="Number of iterations to run for profiling.", + ) parser.add_argument( - '--profile-result-dir', + "--profile-result-dir", type=str, default="profiles", - help= - ('path to save the pytorch profiler output. Can be visualized ' - 'with ui.perfetto.dev or Tensorboard ' - '(https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm).' - )) + help=( + "path to save the pytorch profiler output. Can be visualized " + "with ui.perfetto.dev or Tensorboard " + "(https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm)." + ), + ) parser = EngineArgs.add_cli_args(parser) args = parser.parse_args() diff --git a/examples/offline_inference/prompt_embed_inference.py b/examples/offline_inference/prompt_embed_inference.py new file mode 100644 index 000000000000..9f6a602233f8 --- /dev/null +++ b/examples/offline_inference/prompt_embed_inference.py @@ -0,0 +1,96 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Demonstrates how to generate prompt embeddings using +Hugging Face Transformers and use them as input to vLLM +for both single and batch inference. + +Model: meta-llama/Llama-3.2-1B-Instruct +Note: This model is gated on Hugging Face Hub. 
+ You must request access to use it: + https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct + +Requirements: +- vLLM +- transformers + +Run: + python examples/offline_inference/prompt_embed_inference.py +""" + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer + +from vllm import LLM + + +def init_tokenizer_and_llm(model_name: str): + tokenizer = AutoTokenizer.from_pretrained(model_name) + transformers_model = AutoModelForCausalLM.from_pretrained(model_name) + embedding_layer = transformers_model.get_input_embeddings() + llm = LLM(model=model_name, enable_prompt_embeds=True) + return tokenizer, embedding_layer, llm + + +def get_prompt_embeds( + chat: list[dict[str, str]], + tokenizer: PreTrainedTokenizer, + embedding_layer: torch.nn.Module, +): + token_ids = tokenizer.apply_chat_template( + chat, add_generation_prompt=True, return_tensors="pt" + ) + prompt_embeds = embedding_layer(token_ids).squeeze(0) + return prompt_embeds + + +def single_prompt_inference( + llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module +): + chat = [{"role": "user", "content": "Please tell me about the capital of France."}] + prompt_embeds = get_prompt_embeds(chat, tokenizer, embedding_layer) + + outputs = llm.generate( + { + "prompt_embeds": prompt_embeds, + } + ) + + print("\n[Single Inference Output]") + print("-" * 30) + for o in outputs: + print(o.outputs[0].text) + print("-" * 30) + + +def batch_prompt_inference( + llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module +): + chats = [ + [{"role": "user", "content": "Please tell me about the capital of France."}], + [{"role": "user", "content": "When is the day longest during the year?"}], + [{"role": "user", "content": "Where is bigger, the moon or the sun?"}], + ] + + prompt_embeds_list = [ + get_prompt_embeds(chat, tokenizer, embedding_layer) for chat in chats + ] + + outputs = llm.generate([{"prompt_embeds": embeds} for embeds in 
prompt_embeds_list]) + + print("\n[Batch Inference Outputs]") + print("-" * 30) + for i, o in enumerate(outputs): + print(f"Q{i + 1}: {chats[i][0]['content']}") + print(f"A{i + 1}: {o.outputs[0].text}\n") + print("-" * 30) + + +def main(): + model_name = "meta-llama/Llama-3.2-1B-Instruct" + tokenizer, embedding_layer, llm = init_tokenizer_and_llm(model_name) + single_prompt_inference(llm, tokenizer, embedding_layer) + batch_prompt_inference(llm, tokenizer, embedding_layer) + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/qwen2_5_omni/README.md b/examples/offline_inference/qwen2_5_omni/README.md index c30541a598ce..16d44cbadbc9 100644 --- a/examples/offline_inference/qwen2_5_omni/README.md +++ b/examples/offline_inference/qwen2_5_omni/README.md @@ -6,14 +6,19 @@ This folder provides several example scripts on how to inference Qwen2.5-Omni of ```bash # Audio + image + video -python examples/offline_inference/qwen2_5_omni/only_thinker.py -q mixed_modalities +python examples/offline_inference/qwen2_5_omni/only_thinker.py \ + -q mixed_modalities # Read vision and audio inputs from a single video file # NOTE: V1 engine does not support interleaved modalities yet. -VLLM_USE_V1=0 python examples/offline_inference/qwen2_5_omni/only_thinker.py -q use_audio_in_video +VLLM_USE_V1=0 \ +python examples/offline_inference/qwen2_5_omni/only_thinker.py \ + -q use_audio_in_video # Multiple audios -VLLM_USE_V1=0 python examples/offline_inference/qwen2_5_omni/only_thinker.py -q multi_audios +VLLM_USE_V1=0 \ +python examples/offline_inference/qwen2_5_omni/only_thinker.py \ + -q multi_audios ``` This script will run the thinker part of Qwen2.5-Omni, and generate text response. 
@@ -22,11 +27,16 @@ You can also test Qwen2.5-Omni on a single modality: ```bash # Process audio inputs -python examples/offline_inference/audio_language.py --model-type qwen2_5_omni +python examples/offline_inference/audio_language.py \ + --model-type qwen2_5_omni # Process image inputs -python examples/offline_inference/vision_language.py --modality image --model-type qwen2_5_omni +python examples/offline_inference/vision_language.py \ + --modality image \ + --model-type qwen2_5_omni # Process video inputs -python examples/offline_inference/vision_language.py --modality video --model-type qwen2_5_omni +python examples/offline_inference/vision_language.py \ + --modality video \ + --model-type qwen2_5_omni ``` diff --git a/examples/offline_inference/qwen2_5_omni/only_thinker.py b/examples/offline_inference/qwen2_5_omni/only_thinker.py index c2c28d5ae6ae..6482490d1a93 100644 --- a/examples/offline_inference/qwen2_5_omni/only_thinker.py +++ b/examples/offline_inference/qwen2_5_omni/only_thinker.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """ -This example shows how to use vLLM for running offline inference +This example shows how to use vLLM for running offline inference with the correct prompt format on Qwen2.5-Omni (thinker only). """ @@ -11,6 +11,7 @@ from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset +from vllm.multimodal.image import convert_image_mode from vllm.utils import FlexibleArgumentParser @@ -26,50 +27,55 @@ class QueryResult(NamedTuple): default_system = ( "You are Qwen, a virtual human developed by the Qwen Team, Alibaba " "Group, capable of perceiving auditory and visual inputs, as well as " - "generating text and speech.") + "generating text and speech." +) def get_mixed_modalities_query() -> QueryResult: - question = ("What is recited in the audio? " - "What is the content of this image? 
Why is this video funny?") - prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n" - "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>" - "<|vision_bos|><|IMAGE|><|vision_eos|>" - "<|vision_bos|><|VIDEO|><|vision_eos|>" - f"{question}<|im_end|>\n" - f"<|im_start|>assistant\n") + question = ( + "What is recited in the audio? " + "What is the content of this image? Why is this video funny?" + ) + prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>" + "<|vision_bos|><|IMAGE|><|vision_eos|>" + "<|vision_bos|><|VIDEO|><|vision_eos|>" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n" + ) return QueryResult( inputs={ "prompt": prompt, "multi_modal_data": { - "audio": - AudioAsset("mary_had_lamb").audio_and_sample_rate, - "image": - ImageAsset("cherry_blossom").pil_image.convert("RGB"), - "video": - VideoAsset(name="baby_reading", num_frames=16).np_ndarrays, + "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate, + "image": convert_image_mode( + ImageAsset("cherry_blossom").pil_image, "RGB" + ), + "video": VideoAsset(name="baby_reading", num_frames=16).np_ndarrays, }, }, - limit_mm_per_prompt={ - "audio": 1, - "image": 1, - "video": 1 - }, + limit_mm_per_prompt={"audio": 1, "image": 1, "video": 1}, ) def get_use_audio_in_video_query() -> QueryResult: - question = ("Describe the content of the video, " - "then convert what the baby say into text.") - prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n" - "<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>" - f"{question}<|im_end|>\n" - f"<|im_start|>assistant\n") + question = ( + "Describe the content of the video, then convert what the baby say into text." 
+ ) + prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n" + ) asset = VideoAsset(name="baby_reading", num_frames=16) audio = asset.get_audio(sampling_rate=16000) - assert not envs.VLLM_USE_V1, ("V1 does not support use_audio_in_video. " - "Please launch this example with " - "`VLLM_USE_V1=0`.") + assert not envs.VLLM_USE_V1, ( + "V1 does not support use_audio_in_video. " + "Please launch this example with " + "`VLLM_USE_V1=0`." + ) return QueryResult( inputs={ "prompt": prompt, @@ -81,20 +87,19 @@ def get_use_audio_in_video_query() -> QueryResult: "use_audio_in_video": True, }, }, - limit_mm_per_prompt={ - "audio": 1, - "video": 1 - }, + limit_mm_per_prompt={"audio": 1, "video": 1}, ) def get_multi_audios_query() -> QueryResult: question = "Are these two audio clips the same?" - prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n" - "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>" - "<|audio_bos|><|AUDIO|><|audio_eos|>" - f"{question}<|im_end|>\n" - f"<|im_start|>assistant\n") + prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>" + "<|audio_bos|><|AUDIO|><|audio_eos|>" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n" + ) return QueryResult( inputs={ "prompt": prompt, @@ -122,38 +127,48 @@ def main(args): model_name = "Qwen/Qwen2.5-Omni-7B" query_result = query_map[args.query_type]() - llm = LLM(model=model_name, - max_model_len=5632, - max_num_seqs=5, - limit_mm_per_prompt=query_result.limit_mm_per_prompt, - seed=args.seed) + llm = LLM( + model=model_name, + max_model_len=5632, + max_num_seqs=5, + limit_mm_per_prompt=query_result.limit_mm_per_prompt, + seed=args.seed, + ) # We set temperature to 0.2 so that outputs can be different # even when all prompts are identical when running batch inference. 
sampling_params = SamplingParams(temperature=0.2, max_tokens=64) - outputs = llm.generate(query_result.inputs, - sampling_params=sampling_params) + outputs = llm.generate(query_result.inputs, sampling_params=sampling_params) for o in outputs: generated_text = o.outputs[0].text print(generated_text) -if __name__ == "__main__": +def parse_args(): parser = FlexibleArgumentParser( - description='Demo on using vLLM for offline inference with ' - 'audio language models') - parser.add_argument('--query-type', - '-q', - type=str, - default="mixed_modalities", - choices=query_map.keys(), - help='Query type.') - parser.add_argument("--seed", - type=int, - default=None, - help="Set the seed when initializing `vllm.LLM`.") - - args = parser.parse_args() + description="Demo on using vLLM for offline inference with " + "audio language models" + ) + parser.add_argument( + "--query-type", + "-q", + type=str, + default="mixed_modalities", + choices=query_map.keys(), + help="Query type.", + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="Set the seed when initializing `vllm.LLM`.", + ) + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() main(args) diff --git a/examples/offline_inference/qwen_1m.py b/examples/offline_inference/qwen_1m.py new file mode 100644 index 000000000000..856a35b0e59b --- /dev/null +++ b/examples/offline_inference/qwen_1m.py @@ -0,0 +1,70 @@ +# SPDX-License-Identifier: Apache-2.0 +import os +from urllib.request import urlopen + +from vllm import LLM, SamplingParams + +os.environ["VLLM_ATTENTION_BACKEND"] = "DUAL_CHUNK_FLASH_ATTN" +os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1" + + +def load_prompt() -> str: + # Test cases with various lengths can be found at: + # + # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/64k.txt + # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/200k.txt + # 
https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt + # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt + + with urlopen( + "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt", + timeout=5, + ) as response: + prompt = response.read().decode("utf-8") + return prompt + + +# Processing the prompt. +def process_requests(llm: LLM, prompts: list[str]) -> None: + # Create a sampling params object. + sampling_params = SamplingParams( + temperature=0.7, + top_p=0.8, + top_k=20, + repetition_penalty=1.05, + detokenize=True, + max_tokens=256, + ) + # Generate texts from the prompts. + outputs = llm.generate(prompts, sampling_params) + # Print the outputs. + for output in outputs: + prompt_token_ids = output.prompt_token_ids + generated_text = output.outputs[0].text + print( + f"Prompt length: {len(prompt_token_ids)}, " + f"Generated text: {generated_text!r}" + ) + + +# Create an LLM. +def initialize_engine() -> LLM: + llm = LLM( + model="Qwen/Qwen2.5-7B-Instruct-1M", + max_model_len=1048576, + tensor_parallel_size=4, + enforce_eager=True, + enable_chunked_prefill=True, + max_num_batched_tokens=131072, + ) + return llm + + +def main(): + llm = initialize_engine() + prompt = load_prompt() + process_requests(llm, [prompt]) + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/reproducibility.py b/examples/offline_inference/reproducibility.py index b2be117d1a0a..6d048986e710 100644 --- a/examples/offline_inference/reproducibility.py +++ b/examples/offline_inference/reproducibility.py @@ -1,24 +1,22 @@ # SPDX-License-Identifier: Apache-2.0 +""" +Demonstrates how to achieve reproducibility in vLLM. + +Main article: https://docs.vllm.ai/en/latest/usage/reproducibility.html +""" + import os +import random from vllm import LLM, SamplingParams -# vLLM does not guarantee the reproducibility of the results by default, -# for the sake of performance. 
You need to do the following to achieve -# reproducible results: -# 1. Turn off multiprocessing to make the scheduling deterministic. -# NOTE(woosuk): This is not needed and will be ignored for V0. +# V1 only: Turn off multiprocessing to make the scheduling deterministic. os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" -# 2. Fix the global seed for reproducibility. The default seed is None, which is + +# V0 only: Set the global seed. The default seed is None, which is # not reproducible. SEED = 42 -# NOTE(woosuk): Even with the above two settings, vLLM only provides -# reproducibility when it runs on the same hardware and the same vLLM version. -# Also, the online serving API (`vllm serve`) does not support reproducibility -# because it is almost impossible to make the scheduling deterministic in the -# online serving setting. - prompts = [ "Hello, my name is", "The president of the United States is", @@ -38,6 +36,11 @@ def main(): print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") print("-" * 50) + # Try generating random numbers outside vLLM + # The same number is output across runs, meaning that the random state + # in the user code has been updated by vLLM + print(random.randint(0, 100)) + if __name__ == "__main__": main() diff --git a/examples/offline_inference/rlhf.py b/examples/offline_inference/rlhf.py index e0ed0ac49754..a8f6977e29a4 100644 --- a/examples/offline_inference/rlhf.py +++ b/examples/offline_inference/rlhf.py @@ -12,6 +12,7 @@ and multiple inference instances. For the full implementation, please refer to the OpenRLHF framework. """ + import os import ray @@ -26,7 +27,6 @@ class MyLLM(LLM): - def __init__(self, *args, **kwargs): # a hack to make the script work. 
# stop ray from manipulating CUDA_VISIBLE_DEVICES @@ -89,8 +89,7 @@ def __init__(self, *args, **kwargs): for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}\n" - f"Generated text: {generated_text!r}") + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") print("-" * 50) # set up the communication between the training process @@ -98,11 +97,13 @@ def __init__(self, *args, **kwargs): master_address = get_ip() master_port = get_open_port() -handle = llm.collective_rpc.remote("init_weight_update_group", - args=(master_address, master_port, 1, 3)) +handle = llm.collective_rpc.remote( + "init_weight_update_group", args=(master_address, master_port, 1, 3) +) -model_update_group = stateless_init_process_group(master_address, master_port, - 0, 3, torch.device("cuda:0")) +model_update_group = stateless_init_process_group( + master_address, master_port, 0, 3, torch.device("cuda:0") +) ray.get(handle) # simulate training, modify the weights of the model. @@ -111,8 +112,7 @@ def __init__(self, *args, **kwargs): # sync weight from the training process to the inference engine. 
for name, p in train_model.named_parameters(): - handle = llm.collective_rpc.remote("update_weight", - args=(name, p.dtype, p.shape)) + handle = llm.collective_rpc.remote("update_weight", args=(name, p.dtype, p.shape)) model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream()) ray.get(handle) @@ -126,6 +126,5 @@ def __init__(self, *args, **kwargs): for output in outputs_updated: prompt = output.prompt generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}\n" - f"Generated text: {generated_text!r}") + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") print("-" * 50) diff --git a/examples/offline_inference/rlhf_colocate.py b/examples/offline_inference/rlhf_colocate.py index 3ceac0fa2e20..76eafdca1f6c 100644 --- a/examples/offline_inference/rlhf_colocate.py +++ b/examples/offline_inference/rlhf_colocate.py @@ -9,6 +9,7 @@ - Use cuda-ipc to pass tensors, since NCCL does not work when we have multiple processes on the same GPU. """ + import os import ray @@ -20,7 +21,6 @@ class MyLLM(LLM): - def __init__(self, *args, bundle_indices: list, **kwargs): # a hack to make the script work. # stop ray from manipulating CUDA_VISIBLE_DEVICES @@ -29,17 +29,16 @@ def __init__(self, *args, bundle_indices: list, **kwargs): # every worker will use 0.4 GPU, so that we can schedule # 2 instances on the same GPUs. 
os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4" - os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join( - map(str, bundle_indices)) + os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices)) print(f"creating LLM with bundle_indices={bundle_indices}") super().__init__(*args, **kwargs) class RayTrainingActor: - def __init__(self): # ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs from transformers import AutoModelForCausalLM + self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m") self.model.to("cuda:0") for name, p in self.model.named_parameters(): @@ -48,6 +47,7 @@ def __init__(self): # the argument for get_device_uuid is the index # of the GPU in the visible devices. from vllm.platforms import current_platform + self.device_uuid = current_platform.get_device_uuid(0) def report_device_id(self) -> str: @@ -55,6 +55,7 @@ def report_device_id(self) -> str: def get_weight_ipc_handles(self): from torch.multiprocessing.reductions import reduce_tensor + data = {} for name, p in self.model.named_parameters(): # the training actor might only have a subset of the weights @@ -101,7 +102,7 @@ def get_weight_ipc_handles(self): print(f"training actor {bundle_index} is on {device_id}") training_actor_device_ids.append(device_id) -for (i, bundle_indices) in enumerate([[0, 1], [2, 3]]): +for i, bundle_indices in enumerate([[0, 1], [2, 3]]): # IMPORTANT: when creating vLLM instances, we need to # make sure there are no GPU activities on the target GPUs, # otherwise, they will interfere with the vLLM memory profiling, @@ -128,7 +129,8 @@ def get_weight_ipc_handles(self): for i, llm in enumerate(inference_engines): inference_engine_device_ids.append( - ray.get(llm.collective_rpc.remote("report_device_id", args=tuple()))) + ray.get(llm.collective_rpc.remote("report_device_id", args=tuple())) + ) print(f"inference engine {i} is on {inference_engine_device_ids[-1]}") # check the placement @@ -147,9 +149,10 @@ def get_weight_ipc_handles(self): 
print("update the weights of the inference engines") for llm in inference_engines: ray.get( - llm.collective_rpc.remote("update_weights_from_ipc_handles", - args=(ipc_handles, ))) + llm.collective_rpc.remote( + "update_weights_from_ipc_handles", args=(ipc_handles,) + ) + ) print("check if the weights are updated") for llm in inference_engines: - assert ray.get( - llm.collective_rpc.remote("check_weights_changed", args=tuple())) + assert ray.get(llm.collective_rpc.remote("check_weights_changed", args=tuple())) diff --git a/examples/offline_inference/rlhf_utils.py b/examples/offline_inference/rlhf_utils.py index 11b73b7c4a0a..3461af707eba 100644 --- a/examples/offline_inference/rlhf_utils.py +++ b/examples/offline_inference/rlhf_utils.py @@ -2,21 +2,20 @@ import torch -def stateless_init_process_group(master_address, master_port, rank, world_size, - device): +def stateless_init_process_group(master_address, master_port, rank, world_size, device): """ vLLM provides `StatelessProcessGroup` to create a process group without considering the global process group in torch.distributed. It is recommended to create `StatelessProcessGroup`, and then initialize - the data-plane communication (NCCL) between external (train processes) + the data-plane communication (NCCL) between external (train processes) and vLLM workers. """ from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.utils import StatelessProcessGroup - pg = StatelessProcessGroup.create(host=master_address, - port=master_port, - rank=rank, - world_size=world_size) + + pg = StatelessProcessGroup.create( + host=master_address, port=master_port, rank=rank, world_size=world_size + ) pynccl = PyNcclCommunicator(pg, device=device) return pynccl @@ -31,9 +30,11 @@ class WorkerExtension: should pass the full qualified name as `worker_extension_cls` argument. 
""" - def init_weight_update_group(self, master_address, master_port, - rank_offset, world_size): + def init_weight_update_group( + self, master_address, master_port, rank_offset, world_size + ): from vllm.distributed.parallel_state import get_world_group + rank = get_world_group().rank + rank_offset self.model_update_group = stateless_init_process_group( master_address, @@ -45,9 +46,9 @@ def init_weight_update_group(self, master_address, master_port, def update_weight(self, name, dtype, shape): weight = torch.empty(shape, dtype=dtype, device="cuda") - self.model_update_group.broadcast(weight, - src=0, - stream=torch.cuda.current_stream()) + self.model_update_group.broadcast( + weight, src=0, stream=torch.cuda.current_stream() + ) self.model_runner.model.load_weights(weights=[(name, weight)]) @@ -59,8 +60,7 @@ def check_weights_changed(self): """ weights_updated = True for name, p in self.model_runner.model.named_parameters(): - weights_updated = weights_updated and torch.allclose( - p, torch.zeros_like(p)) + weights_updated = weights_updated and torch.allclose(p, torch.zeros_like(p)) return weights_updated @@ -76,6 +76,7 @@ class ColocateWorkerExtension: def report_device_id(self) -> str: from vllm.platforms import current_platform + self.device_uuid = current_platform.get_device_uuid(self.device.index) return self.device_uuid @@ -100,6 +101,5 @@ def check_weights_changed(self): """ weights_updated = True for name, p in self.model_runner.model.named_parameters(): - weights_updated = weights_updated and torch.allclose( - p, torch.zeros_like(p)) + weights_updated = weights_updated and torch.allclose(p, torch.zeros_like(p)) return weights_updated diff --git a/examples/offline_inference/save_sharded_state.py b/examples/offline_inference/save_sharded_state.py index 338380cc9684..860fe2b5fe06 100644 --- a/examples/offline_inference/save_sharded_state.py +++ b/examples/offline_inference/save_sharded_state.py @@ -21,6 +21,7 @@ tensor_parallel_size=8, ) """ + import 
dataclasses import os import shutil @@ -33,18 +34,18 @@ def parse_args(): parser = FlexibleArgumentParser() EngineArgs.add_cli_args(parser) - parser.add_argument("--output", - "-o", - required=True, - type=str, - help="path to output checkpoint") - parser.add_argument("--file-pattern", - type=str, - help="string pattern of saved filenames") - parser.add_argument("--max-file-size", - type=str, - default=5 * 1024**3, - help="max size (in bytes) of each safetensors file") + parser.add_argument( + "--output", "-o", required=True, type=str, help="path to output checkpoint" + ) + parser.add_argument( + "--file-pattern", type=str, help="string pattern of saved filenames" + ) + parser.add_argument( + "--max-file-size", + type=str, + default=5 * 1024**3, + help="max size (in bytes) of each safetensors file", + ) return parser.parse_args() @@ -68,23 +69,23 @@ def main(args): # For V1 engine, we need to use engine_core.save_sharded_state print("Using V1 engine save path") llm.llm_engine.engine_core.save_sharded_state( - path=args.output, - pattern=args.file_pattern, - max_size=args.max_file_size) + path=args.output, pattern=args.file_pattern, max_size=args.max_file_size + ) else: # For V0 engine print("Using V0 engine save path") model_executor = llm.llm_engine.model_executor - model_executor.save_sharded_state(path=args.output, - pattern=args.file_pattern, - max_size=args.max_file_size) + model_executor.save_sharded_state( + path=args.output, pattern=args.file_pattern, max_size=args.max_file_size + ) # Copy metadata files to output directory for file in os.listdir(model_path): if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"): if os.path.isdir(os.path.join(model_path, file)): - shutil.copytree(os.path.join(model_path, file), - os.path.join(args.output, file)) + shutil.copytree( + os.path.join(model_path, file), os.path.join(args.output, file) + ) else: shutil.copy(os.path.join(model_path, file), args.output) diff --git 
a/examples/offline_inference/structured_outputs.py b/examples/offline_inference/structured_outputs.py index 363b500e0adf..9ed7299606b7 100644 --- a/examples/offline_inference/structured_outputs.py +++ b/examples/offline_inference/structured_outputs.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 """ -This file demonstrates the example usage of guided decoding -to generate structured outputs using vLLM. It shows how to apply -different guided decoding techniques such as Choice, Regex, JSON schema, -and Grammar to produce structured and formatted results +This file demonstrates the example usage of guided decoding +to generate structured outputs using vLLM. It shows how to apply +different guided decoding techniques such as Choice, Regex, JSON schema, +and Grammar to produce structured and formatted results based on specific prompts. """ @@ -15,20 +15,20 @@ from vllm.sampling_params import GuidedDecodingParams # Guided decoding by Choice (list of possible options) -guided_decoding_params_choice = GuidedDecodingParams( - choice=["Positive", "Negative"]) -sampling_params_choice = SamplingParams( - guided_decoding=guided_decoding_params_choice) +guided_decoding_params_choice = GuidedDecodingParams(choice=["Positive", "Negative"]) +sampling_params_choice = SamplingParams(guided_decoding=guided_decoding_params_choice) prompt_choice = "Classify this sentiment: vLLM is wonderful!" # Guided decoding by Regex guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n") sampling_params_regex = SamplingParams( - guided_decoding=guided_decoding_params_regex, stop=["\n"]) + guided_decoding=guided_decoding_params_regex, stop=["\n"] +) prompt_regex = ( "Generate an email address for Alan Turing, who works in Enigma." "End in .com and new line. 
Example result:" - "alan.turing@enigma.com\n") + "alan.turing@enigma.com\n" +) # Guided decoding by JSON using Pydantic schema @@ -47,10 +47,11 @@ class CarDescription(BaseModel): json_schema = CarDescription.model_json_schema() guided_decoding_params_json = GuidedDecodingParams(json=json_schema) -sampling_params_json = SamplingParams( - guided_decoding=guided_decoding_params_json) -prompt_json = ("Generate a JSON with the brand, model and car_type of" - "the most iconic car from the 90's") +sampling_params_json = SamplingParams(guided_decoding=guided_decoding_params_json) +prompt_json = ( + "Generate a JSON with the brand, model and car_type of" + "the most iconic car from the 90's" +) # Guided decoding by Grammar simplified_sql_grammar = """ @@ -61,12 +62,11 @@ class CarDescription(BaseModel): condition ::= column "= " number number ::= "1 " | "2 " """ -guided_decoding_params_grammar = GuidedDecodingParams( - grammar=simplified_sql_grammar) -sampling_params_grammar = SamplingParams( - guided_decoding=guided_decoding_params_grammar) -prompt_grammar = ("Generate an SQL query to show the 'username' and 'email'" - "from the 'users' table.") +guided_decoding_params_grammar = GuidedDecodingParams(grammar=simplified_sql_grammar) +sampling_params_grammar = SamplingParams(guided_decoding=guided_decoding_params_grammar) +prompt_grammar = ( + "Generate an SQL query to show the 'username' and 'email'from the 'users' table." 
+) def format_output(title: str, output: str): @@ -90,8 +90,7 @@ def main(): json_output = generate_output(prompt_json, sampling_params_json, llm) format_output("Guided decoding by JSON", json_output) - grammar_output = generate_output(prompt_grammar, sampling_params_grammar, - llm) + grammar_output = generate_output(prompt_grammar, sampling_params_grammar, llm) format_output("Guided decoding by Grammar", grammar_output) diff --git a/examples/offline_inference/torchrun_example.py b/examples/offline_inference/torchrun_example.py index c6d9e6b47e21..2fa49c0835e3 100644 --- a/examples/offline_inference/torchrun_example.py +++ b/examples/offline_inference/torchrun_example.py @@ -8,6 +8,8 @@ see `tests/distributed/test_torchrun_example.py` for the unit test. """ +import torch.distributed as dist + from vllm import LLM, SamplingParams # Create prompts, the same across all ranks @@ -27,23 +29,25 @@ # all ranks have the same random seed, so that sampling can be # deterministic across ranks. llm = LLM( - model="facebook/opt-125m", + model="meta-llama/Llama-3.1-8B", tensor_parallel_size=2, + pipeline_parallel_size=2, distributed_executor_backend="external_launcher", - seed=0, + max_model_len=32768, + seed=1, ) outputs = llm.generate(prompts, sampling_params) # all ranks will have the same outputs -print("-" * 50) -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}\n" - f"Generated text: {generated_text!r}") +if dist.get_rank() == 0: print("-" * 50) -""" + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}\n") + print("-" * 50) + """ Further tips: 1. 
to communicate control messages across all ranks, use the cpu group, diff --git a/examples/offline_inference/tpu.py b/examples/offline_inference/tpu.py index 71cd88f2788a..e4a75b3f9380 100644 --- a/examples/offline_inference/tpu.py +++ b/examples/offline_inference/tpu.py @@ -20,10 +20,12 @@ def main(): # Set `enforce_eager=True` to avoid ahead-of-time compilation. # In real workloads, `enforace_eager` should be `False`. - llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", - max_num_batched_tokens=64, - max_num_seqs=4, - max_model_len=128) + llm = LLM( + model="Qwen/Qwen2-1.5B-Instruct", + max_num_batched_tokens=64, + max_num_seqs=4, + max_model_len=128, + ) outputs = llm.generate(prompts, sampling_params) print("-" * 50) for output, answer in zip(outputs, answers): diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 5c173ab1abb9..f0504501639d 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -6,6 +6,7 @@ For most models, the prompt format should follow corresponding examples on HuggingFace model repository. 
""" + import os import random from contextlib import contextmanager @@ -19,6 +20,7 @@ from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset from vllm.lora.request import LoRARequest +from vllm.multimodal.image import convert_image_mode from vllm.utils import FlexibleArgumentParser @@ -48,9 +50,13 @@ def run_aria(questions: list[str], modality: str) -> ModelRequestData: limit_mm_per_prompt={modality: 1}, ) - prompts = [(f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}" - "<|im_end|>\n<|im_start|>assistant\n") - for question in questions] + prompts = [ + ( + f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}" + "<|im_end|>\n<|im_start|>assistant\n" + ) + for question in questions + ] stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519] @@ -134,8 +140,7 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData: ) prompts = [ - f"<|User|>: <image>\n{question}\n\n<|Assistant|>:" - for question in questions + f"<|User|>: <image>\n{question}\n\n<|Assistant|>:" for question in questions ] return ModelRequestData( @@ -197,9 +202,14 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData: limit_mm_per_prompt={modality: 1}, ) - prompts = [("<bos><start_of_turn>user\n" - f"<start_of_image>{question}<end_of_turn>\n" - "<start_of_turn>model\n") for question in questions] + prompts = [ + ( + "<bos><start_of_turn>user\n" + f"<start_of_image>{question}<end_of_turn>\n" + "<start_of_turn>model\n" + ) + for question in questions + ] return ModelRequestData( engine_args=engine_args, @@ -224,7 +234,8 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData: prompts = [ f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\ - {question}<|assistant|>" for question in questions + {question}<|assistant|>" + for question in questions ] stop_token_ids = [151329, 151336, 151338] @@ -249,15 +260,13 @@ def run_h2ovl(questions: list[str], modality: str) -> 
ModelRequestData: limit_mm_per_prompt={modality: 1}, ) - tokenizer = AutoTokenizer.from_pretrained(model_name, - trust_remote_code=True) - messages = [[{ - 'role': 'user', - 'content': f"<image>\n{question}" - }] for question in questions] - prompts = tokenizer.apply_chat_template(messages, - tokenize=False, - add_generation_prompt=True) + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + messages = [ + [{"role": "user", "content": f"<image>\n{question}"}] for question in questions + ] + prompts = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) # Stop tokens for H2OVL-Mississippi # https://huggingface.co/h2oai/h2ovl-mississippi-800m @@ -283,15 +292,14 @@ def run_idefics3(questions: list[str], modality: str) -> ModelRequestData: # if you are running out of memory, you can reduce the "longest_edge". # see: https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations mm_processor_kwargs={ - "size": { - "longest_edge": 3 * 364 - }, + "size": {"longest_edge": 3 * 364}, }, limit_mm_per_prompt={modality: 1}, ) - prompts = [( - f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:" - ) for question in questions] + prompts = [ + (f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:") + for question in questions + ] return ModelRequestData( engine_args=engine_args, @@ -310,9 +318,7 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData: max_num_seqs=2, enforce_eager=True, mm_processor_kwargs={ - "max_image_size": { - "longest_edge": 384 - }, + "max_image_size": {"longest_edge": 384}, }, limit_mm_per_prompt={modality: 1}, ) @@ -329,26 +335,28 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData: # InternVL def run_internvl(questions: list[str], modality: str) -> ModelRequestData: - assert modality == "image" - - model_name = "OpenGVLab/InternVL2-2B" + model_name = "OpenGVLab/InternVL3-2B" engine_args = 
EngineArgs( model=model_name, trust_remote_code=True, - max_model_len=4096, + max_model_len=8192, limit_mm_per_prompt={modality: 1}, ) - tokenizer = AutoTokenizer.from_pretrained(model_name, - trust_remote_code=True) - messages = [[{ - 'role': 'user', - 'content': f"<image>\n{question}" - }] for question in questions] - prompts = tokenizer.apply_chat_template(messages, - tokenize=False, - add_generation_prompt=True) + if modality == "image": + placeholder = "<image>" + elif modality == "video": + placeholder = "<video>" + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + messages = [ + [{"role": "user", "content": f"{placeholder}\n{question}"}] + for question in questions + ] + prompts = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) # Stop tokens for InternVL # models variants may have different stop tokens @@ -356,6 +364,7 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData: # https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"] stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] + stop_token_ids = [token_id for token_id in stop_token_ids if token_id is not None] return ModelRequestData( engine_args=engine_args, @@ -371,7 +380,8 @@ def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData: prompts = [ "<|im_user|>user<|im_middle|><|media_start|>image<|media_content|>" f"<|media_pad|><|media_end|>{question}<|im_end|>" - "<|im_assistant|>assistant<|im_middle|>" for question in questions + "<|im_assistant|>assistant<|im_middle|>" + for question in questions ] engine_args = EngineArgs( @@ -391,9 +401,7 @@ def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData: def run_llava(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" - prompts = [ - f"USER: <image>\n{question}\nASSISTANT:" for 
question in questions - ] + prompts = [f"USER: <image>\n{question}\nASSISTANT:" for question in questions] engine_args = EngineArgs( model="llava-hf/llava-1.5-7b-hf", @@ -426,13 +434,10 @@ def run_llava_next(questions: list[str], modality: str) -> ModelRequestData: # LlaVA-NeXT-Video # Currently only support for video input -def run_llava_next_video(questions: list[str], - modality: str) -> ModelRequestData: +def run_llava_next_video(questions: list[str], modality: str) -> ModelRequestData: assert modality == "video" - prompts = [ - f"USER: <video>\n{question} ASSISTANT:" for question in questions - ] + prompts = [f"USER: <video>\n{question} ASSISTANT:" for question in questions] engine_args = EngineArgs( model="llava-hf/LLaVA-NeXT-Video-7B-hf", max_model_len=8192, @@ -447,19 +452,19 @@ def run_llava_next_video(questions: list[str], # LLaVA-OneVision -def run_llava_onevision(questions: list[str], - modality: str) -> ModelRequestData: - +def run_llava_onevision(questions: list[str], modality: str) -> ModelRequestData: if modality == "video": prompts = [ f"<|im_start|>user <video>\n{question}<|im_end|> \ - <|im_start|>assistant\n" for question in questions + <|im_start|>assistant\n" + for question in questions ] elif modality == "image": prompts = [ f"<|im_start|>user <image>\n{question}<|im_end|> \ - <|im_start|>assistant\n" for question in questions + <|im_start|>assistant\n" + for question in questions ] engine_args = EngineArgs( @@ -478,11 +483,8 @@ def run_llava_onevision(questions: list[str], def run_mantis(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" - llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' # noqa: E501 - prompts = [ - llama3_template.format(f"{question}\n<image>") - for question in questions - ] + llama3_template = "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" # noqa: E501 
+ prompts = [llama3_template.format(f"{question}\n<image>") for question in questions] engine_args = EngineArgs( model="TIGER-Lab/Mantis-8B-siglip-llama3", @@ -522,8 +524,7 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name): # 2.6: image, video # o2.6: image, video, audio # model_name = "openbmb/MiniCPM-o-2_6" - tokenizer = AutoTokenizer.from_pretrained(model_name, - trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) engine_args = EngineArgs( model=model_name, max_model_len=4096, @@ -539,7 +540,7 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name): # stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id] # 2.6 / o2.6 - stop_tokens = ['<|im_end|>', '<|endoftext|>'] + stop_tokens = ["<|im_end|>", "<|endoftext|>"] stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] modality_placeholder = { @@ -549,12 +550,16 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name): prompts = [ tokenizer.apply_chat_template( - [{ - 'role': 'user', - 'content': f"{modality_placeholder[modality]}\n{question}" - }], + [ + { + "role": "user", + "content": f"{modality_placeholder[modality]}\n{question}", + } + ], tokenize=False, - add_generation_prompt=True) for question in questions + add_generation_prompt=True, + ) + for question in questions ] return ModelRequestData( @@ -614,19 +619,18 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData: ) tokenizer = AutoTokenizer.from_pretrained(model_name) - messages = [[{ - "role": - "user", - "content": [{ - "type": "image" - }, { - "type": "text", - "text": question - }] - }] for question in questions] - prompts = tokenizer.apply_chat_template(messages, - add_generation_prompt=True, - tokenize=False) + messages = [ + [ + { + "role": "user", + "content": [{"type": "image"}, {"type": "text", "text": question}], + } + ] + for question in questions + ] + prompts = tokenizer.apply_chat_template( 
+ messages, add_generation_prompt=True, tokenize=False + ) return ModelRequestData( engine_args=engine_args, @@ -649,19 +653,18 @@ def run_llama4(questions: list[str], modality: str) -> ModelRequestData: ) tokenizer = AutoTokenizer.from_pretrained(model_name) - messages = [[{ - "role": - "user", - "content": [{ - "type": "image" - }, { - "type": "text", - "text": f"{question}" - }] - }] for question in questions] - prompts = tokenizer.apply_chat_template(messages, - add_generation_prompt=True, - tokenize=False) + messages = [ + [ + { + "role": "user", + "content": [{"type": "image"}, {"type": "text", "text": f"{question}"}], + } + ] + for question in questions + ] + prompts = tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=False + ) stop_token_ids = None return ModelRequestData( engine_args=engine_args, @@ -685,7 +688,8 @@ def run_molmo(questions: list[str], modality: str) -> ModelRequestData: prompts = [ f"<|im_start|>user <image>\n{question}<|im_end|> \ - <|im_start|>assistant\n" for question in questions + <|im_start|>assistant\n" + for question in questions ] return ModelRequestData( @@ -709,15 +713,13 @@ def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData: limit_mm_per_prompt={modality: 1}, ) - tokenizer = AutoTokenizer.from_pretrained(model_name, - trust_remote_code=True) - messages = [[{ - 'role': 'user', - 'content': f"<image>\n{question}" - }] for question in questions] - prompts = tokenizer.apply_chat_template(messages, - tokenize=False, - add_generation_prompt=True) + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + messages = [ + [{"role": "user", "content": f"<image>\n{question}"}] for question in questions + ] + prompts = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) return ModelRequestData( engine_args=engine_args, @@ -725,8 +727,8 @@ def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData: ) -# Ovis2 -def 
run_ovis2(questions: list[str], modality: str) -> ModelRequestData: +# Ovis +def run_ovis(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" model_name = "AIDC-AI/Ovis2-1B" @@ -737,15 +739,16 @@ def run_ovis2(questions: list[str], modality: str) -> ModelRequestData: max_num_seqs=2, trust_remote_code=True, dtype="half", - hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]}, limit_mm_per_prompt={modality: 1}, ) - placeholder = "<image>\n" - prompts = [("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" - f"<|im_start|>user\n{placeholder}" - f"{question}<|im_end|>\n" - "<|im_start|>assistant\n") for question in questions] + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + messages = [ + [{"role": "user", "content": f"<image>\n{question}"}] for question in questions + ] + prompts = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) return ModelRequestData( engine_args=engine_args, @@ -836,8 +839,7 @@ def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData: # we have to manually specify the path of the lora weights. 
vision_lora_path = os.path.join(model_path, "vision-lora") prompts = [ - f"<|user|><|image_1|>{question}<|end|><|assistant|>" - for question in questions + f"<|user|><|image_1|>{question}<|end|><|assistant|>" for question in questions ] engine_args = EngineArgs( model=model_path, @@ -904,7 +906,6 @@ def run_qwen_vl(questions: list[str], modality: str) -> ModelRequestData: # Qwen2-VL def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData: - model_name = "Qwen/Qwen2-VL-7B-Instruct" engine_args = EngineArgs( @@ -925,10 +926,13 @@ def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData: placeholder = "<|video_pad|>" prompts = [ - ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" - f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>" - f"{question}<|im_end|>\n" - "<|im_start|>assistant\n") for question in questions + ( + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>" + f"{question}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + for question in questions ] return ModelRequestData( @@ -939,7 +943,6 @@ def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData: # Qwen2.5-VL def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData: - model_name = "Qwen/Qwen2.5-VL-3B-Instruct" engine_args = EngineArgs( @@ -960,10 +963,13 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData: placeholder = "<|video_pad|>" prompts = [ - ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" - f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>" - f"{question}<|im_end|>\n" - "<|im_start|>assistant\n") for question in questions + ( + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>" + f"{question}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + for question in questions ] return ModelRequestData( 
@@ -996,12 +1002,18 @@ def run_qwen2_5_omni(questions: list[str], modality: str): default_system = ( "You are Qwen, a virtual human developed by the Qwen Team, Alibaba " "Group, capable of perceiving auditory and visual inputs, as well as " - "generating text and speech.") + "generating text and speech." + ) - prompts = [(f"<|im_start|>system\n{default_system}<|im_end|>\n" - f"<|im_start|>user\n<|vision_bos|>{placeholder}<|vision_eos|>" - f"{question}<|im_end|>\n" - "<|im_start|>assistant\n") for question in questions] + prompts = [ + ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + f"<|im_start|>user\n<|vision_bos|>{placeholder}<|vision_eos|>" + f"{question}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + for question in questions + ] return ModelRequestData( engine_args=engine_args, prompts=prompts, @@ -1021,15 +1033,13 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData: limit_mm_per_prompt={modality: 1}, ) - tokenizer = AutoTokenizer.from_pretrained(model_name, - trust_remote_code=True) - messages = [[{ - 'role': 'user', - 'content': f"<image>\n{question}" - }] for question in questions] - prompts = tokenizer.apply_chat_template(messages, - tokenize=False, - add_generation_prompt=True) + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + messages = [ + [{"role": "user", "content": f"<image>\n{question}"}] for question in questions + ] + prompts = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) # Stop tokens for SkyworkR1V # https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/conversation.py @@ -1069,7 +1079,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData: "llama4": run_llama4, "molmo": run_molmo, "NVLM_D": run_nvlm_d, - "ovis2": run_ovis2, + "ovis": run_ovis, "paligemma": run_paligemma, "paligemma2": run_paligemma2, "phi3_v": run_phi3v, @@ -1093,8 +1103,7 @@ def get_multi_modal_input(args): """ if args.modality == 
"image": # Input image and question - image = ImageAsset("cherry_blossom") \ - .pil_image.convert("RGB") + image = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB") img_questions = [ "What is the content of this image?", "Describe the content of this image in detail.", @@ -1109,8 +1118,7 @@ def get_multi_modal_input(args): if args.modality == "video": # Input video and question - video = VideoAsset(name="baby_reading", - num_frames=args.num_frames).np_ndarrays + video = VideoAsset(name="baby_reading", num_frames=args.num_frames).np_ndarrays vid_questions = ["Why is this video funny?"] return { @@ -1122,12 +1130,13 @@ def get_multi_modal_input(args): raise ValueError(msg) -def apply_image_repeat(image_repeat_prob, num_prompts, data, - prompts: list[str], modality): - """Repeats images with provided probability of "image_repeat_prob". +def apply_image_repeat( + image_repeat_prob, num_prompts, data, prompts: list[str], modality +): + """Repeats images with provided probability of "image_repeat_prob". Used to simulate hit/miss for the MM preprocessor cache. 
""" - assert (image_repeat_prob <= 1.0 and image_repeat_prob >= 0) + assert image_repeat_prob <= 1.0 and image_repeat_prob >= 0 no_yes = [0, 1] probs = [1.0 - image_repeat_prob, image_repeat_prob] @@ -1142,12 +1151,12 @@ def apply_image_repeat(image_repeat_prob, num_prompts, data, new_val = (i // 256 // 256, i // 256, i % 256) cur_image.putpixel((0, 0), new_val) - inputs.append({ - "prompt": prompts[i % len(prompts)], - "multi_modal_data": { - modality: cur_image + inputs.append( + { + "prompt": prompts[i % len(prompts)], + "multi_modal_data": {modality: cur_image}, } - }) + ) return inputs @@ -1156,6 +1165,7 @@ def apply_image_repeat(image_repeat_prob, num_prompts, data, def time_counter(enable: bool): if enable: import time + start_time = time.time() yield elapsed_time = time.time() - start_time @@ -1168,54 +1178,65 @@ def time_counter(enable: bool): def parse_args(): parser = FlexibleArgumentParser( - description='Demo on using vLLM for offline inference with ' - 'vision language models for text generation') - parser.add_argument('--model-type', - '-m', - type=str, - default="llava", - choices=model_example_map.keys(), - help='Huggingface "model_type".') - parser.add_argument('--num-prompts', - type=int, - default=4, - help='Number of prompts to run.') - parser.add_argument('--modality', - type=str, - default="image", - choices=['image', 'video'], - help='Modality of the input.') - parser.add_argument('--num-frames', - type=int, - default=16, - help='Number of frames to extract from the video.') - parser.add_argument("--seed", - type=int, - default=None, - help="Set the seed when initializing `vllm.LLM`.") + description="Demo on using vLLM for offline inference with " + "vision language models for text generation" + ) + parser.add_argument( + "--model-type", + "-m", + type=str, + default="llava", + choices=model_example_map.keys(), + help='Huggingface "model_type".', + ) + parser.add_argument( + "--num-prompts", type=int, default=4, help="Number of prompts to 
run." + ) + parser.add_argument( + "--modality", + type=str, + default="image", + choices=["image", "video"], + help="Modality of the input.", + ) + parser.add_argument( + "--num-frames", + type=int, + default=16, + help="Number of frames to extract from the video.", + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="Set the seed when initializing `vllm.LLM`.", + ) parser.add_argument( - '--image-repeat-prob', + "--image-repeat-prob", type=float, default=None, - help='Simulates the hit-ratio for multi-modal preprocessor cache' - ' (if enabled)') + help="Simulates the hit-ratio for multi-modal preprocessor cache (if enabled)", + ) parser.add_argument( - '--disable-mm-preprocessor-cache', - action='store_true', - help='If True, disables caching of multi-modal preprocessor/mapper.') + "--disable-mm-preprocessor-cache", + action="store_true", + help="If True, disables caching of multi-modal preprocessor/mapper.", + ) parser.add_argument( - '--time-generate', - action='store_true', - help='If True, then print the total generate() call time') + "--time-generate", + action="store_true", + help="If True, then print the total generate() call time", + ) parser.add_argument( - '--use-different-prompt-per-request', - action='store_true', - help='If True, then use different prompt (with the same multi-modal ' - 'data) for each request.') + "--use-different-prompt-per-request", + action="store_true", + help="If True, then use different prompt (with the same multi-modal " + "data) for each request.", + ) return parser.parse_args() @@ -1234,7 +1255,8 @@ def main(args): # Disable other modalities to save memory default_limits = {"image": 0, "video": 0, "audio": 0} req_data.engine_args.limit_mm_per_prompt = default_limits | dict( - req_data.engine_args.limit_mm_per_prompt or {}) + req_data.engine_args.limit_mm_per_prompt or {} + ) engine_args = asdict(req_data.engine_args) | { "seed": args.seed, @@ -1243,44 +1265,46 @@ def main(args): llm = LLM(**engine_args) 
# Don't want to check the flag multiple times, so just hijack `prompts`. - prompts = req_data.prompts if args.use_different_prompt_per_request else [ - req_data.prompts[0] - ] + prompts = ( + req_data.prompts + if args.use_different_prompt_per_request + else [req_data.prompts[0]] + ) # We set temperature to 0.2 so that outputs can be different # even when all prompts are identical when running batch inference. - sampling_params = SamplingParams(temperature=0.2, - max_tokens=64, - stop_token_ids=req_data.stop_token_ids) + sampling_params = SamplingParams( + temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids + ) assert args.num_prompts > 0 if args.num_prompts == 1: # Single inference inputs = { "prompt": prompts[0], - "multi_modal_data": { - modality: data - }, + "multi_modal_data": {modality: data}, } else: # Batch inference if args.image_repeat_prob is not None: # Repeat images with specified probability of "image_repeat_prob" - inputs = apply_image_repeat(args.image_repeat_prob, - args.num_prompts, data, prompts, - modality) + inputs = apply_image_repeat( + args.image_repeat_prob, args.num_prompts, data, prompts, modality + ) else: # Use the same image for all prompts - inputs = [{ - "prompt": prompts[i % len(prompts)], - "multi_modal_data": { - modality: data - }, - } for i in range(args.num_prompts)] + inputs = [ + { + "prompt": prompts[i % len(prompts)], + "multi_modal_data": {modality: data}, + } + for i in range(args.num_prompts) + ] # Add LoRA request if applicable - lora_request = (req_data.lora_requests * - args.num_prompts if req_data.lora_requests else None) + lora_request = ( + req_data.lora_requests * args.num_prompts if req_data.lora_requests else None + ) with time_counter(args.time_generate): outputs = llm.generate( diff --git a/examples/offline_inference/vision_language_embedding.py b/examples/offline_inference/vision_language_embedding.py index 2637949551a1..cee02d06c607 100644 --- 
a/examples/offline_inference/vision_language_embedding.py +++ b/examples/offline_inference/vision_language_embedding.py @@ -6,6 +6,7 @@ For most models, the prompt format should follow corresponding examples on HuggingFace model repository. """ + from argparse import Namespace from dataclasses import asdict from typing import Literal, NamedTuple, Optional, TypedDict, Union, get_args @@ -44,19 +45,17 @@ class ModelRequestData(NamedTuple): def run_e5_v(query: Query) -> ModelRequestData: - llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n' # noqa: E501 + llama3_template = "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n" # noqa: E501 if query["modality"] == "text": text = query["text"] - prompt = llama3_template.format( - f"{text}\nSummary above sentence in one word: ") + prompt = llama3_template.format(f"{text}\nSummary above sentence in one word: ") image = None elif query["modality"] == "image": - prompt = llama3_template.format( - "<image>\nSummary above image in one word: ") + prompt = llama3_template.format("<image>\nSummary above image in one word: ") image = query["image"] else: - modality = query['modality'] + modality = query["modality"] raise ValueError(f"Unsupported query modality: '{modality}'") engine_args = EngineArgs( @@ -83,10 +82,12 @@ def run_vlm2vec(query: Query) -> ModelRequestData: image = query["image"] elif query["modality"] == "text+image": text = query["text"] - prompt = f"<|image_1|> Represent the given image with the following question: {text}" # noqa: E501 + prompt = ( + f"<|image_1|> Represent the given image with the following question: {text}" # noqa: E501 + ) image = query["image"] else: - modality = query['modality'] + modality = query["modality"] raise ValueError(f"Unsupported query modality: '{modality}'") engine_args = EngineArgs( @@ -136,7 +137,8 @@ def run_encode(model: str, modality: 
QueryModality, seed: Optional[int]): # Disable other modalities to save memory default_limits = {"image": 0, "video": 0, "audio": 0} req_data.engine_args.limit_mm_per_prompt = default_limits | dict( - req_data.engine_args.limit_mm_per_prompt or {}) + req_data.engine_args.limit_mm_per_prompt or {} + ) engine_args = asdict(req_data.engine_args) | {"seed": seed} llm = LLM(**engine_args) @@ -145,10 +147,12 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]): if req_data.image is not None: mm_data["image"] = req_data.image - outputs = llm.embed({ - "prompt": req_data.prompt, - "multi_modal_data": mm_data, - }) + outputs = llm.embed( + { + "prompt": req_data.prompt, + "multi_modal_data": mm_data, + } + ) print("-" * 50) for output in outputs: @@ -164,23 +168,30 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]): def parse_args(): parser = FlexibleArgumentParser( - description='Demo on using vLLM for offline inference with ' - 'vision language models for multimodal embedding') - parser.add_argument('--model-name', - '-m', - type=str, - default="vlm2vec", - choices=model_example_map.keys(), - help='The name of the embedding model.') - parser.add_argument('--modality', - type=str, - default="image", - choices=get_args(QueryModality), - help='Modality of the input.') - parser.add_argument("--seed", - type=int, - default=None, - help="Set the seed when initializing `vllm.LLM`.") + description="Demo on using vLLM for offline inference with " + "vision language models for multimodal embedding" + ) + parser.add_argument( + "--model-name", + "-m", + type=str, + default="vlm2vec", + choices=model_example_map.keys(), + help="The name of the embedding model.", + ) + parser.add_argument( + "--modality", + type=str, + default="image", + choices=get_args(QueryModality), + help="Modality of the input.", + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="Set the seed when initializing `vllm.LLM`.", + ) return 
parser.parse_args() diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index 48d590b05b06..e776ff7fe6ae 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -4,6 +4,7 @@ multi-image input on vision language models for text generation, using the chat template defined by the model. """ + import os from argparse import Namespace from dataclasses import asdict @@ -59,8 +60,9 @@ def load_aria(question: str, image_urls: list[str]) -> ModelRequestData: limit_mm_per_prompt={"image": len(image_urls)}, ) placeholders = "<fim_prefix><|img|><fim_suffix>\n" * len(image_urls) - prompt = (f"<|im_start|>user\n{placeholders}{question}<|im_end|>\n" - "<|im_start|>assistant\n") + prompt = ( + f"<|im_start|>user\n{placeholders}{question}<|im_end|>\n<|im_start|>assistant\n" + ) stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519] return ModelRequestData( @@ -81,23 +83,21 @@ def load_aya_vision(question: str, image_urls: list[str]) -> ModelRequestData: ) placeholders = [{"type": "image", "image": url} for url in image_urls] - messages = [{ - "role": - "user", - "content": [ - *placeholders, - { - "type": "text", - "text": question - }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + } + ] processor = AutoProcessor.from_pretrained(model_name) - prompt = processor.apply_chat_template(messages, - tokenize=False, - add_generation_prompt=True) + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) return ModelRequestData( engine_args=engine_args, @@ -106,8 +106,7 @@ def load_aya_vision(question: str, image_urls: list[str]) -> ModelRequestData: ) -def load_deepseek_vl2(question: str, - image_urls: list[str]) -> ModelRequestData: +def load_deepseek_vl2(question: str, image_urls: list[str]) -> 
ModelRequestData: model_name = "deepseek-ai/deepseek-vl2-tiny" engine_args = EngineArgs( @@ -118,8 +117,9 @@ def load_deepseek_vl2(question: str, limit_mm_per_prompt={"image": len(image_urls)}, ) - placeholder = "".join(f"image_{i}:<image>\n" - for i, _ in enumerate(image_urls, start=1)) + placeholder = "".join( + f"image_{i}:<image>\n" for i, _ in enumerate(image_urls, start=1) + ) prompt = f"<|User|>: {placeholder}{question}\n\n<|Assistant|>:" return ModelRequestData( @@ -140,23 +140,21 @@ def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData: ) placeholders = [{"type": "image", "image": url} for url in image_urls] - messages = [{ - "role": - "user", - "content": [ - *placeholders, - { - "type": "text", - "text": question - }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + } + ] processor = AutoProcessor.from_pretrained(model_name) - prompt = processor.apply_chat_template(messages, - tokenize=False, - add_generation_prompt=True) + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) return ModelRequestData( engine_args=engine_args, @@ -176,15 +174,15 @@ def load_h2ovl(question: str, image_urls: list[str]) -> ModelRequestData: mm_processor_kwargs={"max_dynamic_patch": 4}, ) - placeholders = "\n".join(f"Image-{i}: <image>\n" - for i, _ in enumerate(image_urls, start=1)) - messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}] + placeholders = "\n".join( + f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1) + ) + messages = [{"role": "user", "content": f"{placeholders}\n{question}"}] - tokenizer = AutoTokenizer.from_pretrained(model_name, - trust_remote_code=True) - prompt = tokenizer.apply_chat_template(messages, - tokenize=False, - add_generation_prompt=True) + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + prompt = tokenizer.apply_chat_template( + messages, 
tokenize=False, add_generation_prompt=True + ) # Stop tokens for H2OVL-Mississippi # https://huggingface.co/h2oai/h2ovl-mississippi-800m @@ -211,14 +209,13 @@ def load_idefics3(question: str, image_urls: list[str]) -> ModelRequestData: # if you are running out of memory, you can reduce the "longest_edge". # see: https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations mm_processor_kwargs={ - "size": { - "longest_edge": 2 * 364 - }, + "size": {"longest_edge": 2 * 364}, }, ) - placeholders = "\n".join(f"Image-{i}: <image>\n" - for i, _ in enumerate(image_urls, start=1)) + placeholders = "\n".join( + f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1) + ) prompt = f"<|begin_of_text|>User:{placeholders}\n{question}<end_of_utterance>\nAssistant:" # noqa: E501 return ModelRequestData( engine_args=engine_args, @@ -238,15 +235,16 @@ def load_smolvlm(question: str, image_urls: list[str]) -> ModelRequestData: enforce_eager=True, limit_mm_per_prompt={"image": len(image_urls)}, mm_processor_kwargs={ - "max_image_size": { - "longest_edge": 384 - }, + "max_image_size": {"longest_edge": 384}, }, ) - placeholders = "\n".join(f"Image-{i}: <image>\n" - for i, _ in enumerate(image_urls, start=1)) - prompt = f"<|im_start|>User:{placeholders}\n{question}<end_of_utterance>\nAssistant:" # noqa: E501 + placeholders = "\n".join( + f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1) + ) + prompt = ( + f"<|im_start|>User:{placeholders}\n{question}<end_of_utterance>\nAssistant:" # noqa: E501 + ) return ModelRequestData( engine_args=engine_args, prompt=prompt, @@ -265,15 +263,15 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData: mm_processor_kwargs={"max_dynamic_patch": 4}, ) - placeholders = "\n".join(f"Image-{i}: <image>\n" - for i, _ in enumerate(image_urls, start=1)) - messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}] + placeholders = "\n".join( + f"Image-{i}: <image>\n" for i, _ in 
enumerate(image_urls, start=1) + ) + messages = [{"role": "user", "content": f"{placeholders}\n{question}"}] - tokenizer = AutoTokenizer.from_pretrained(model_name, - trust_remote_code=True) - prompt = tokenizer.apply_chat_template(messages, - tokenize=False, - add_generation_prompt=True) + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) # Stop tokens for InternVL # models variants may have different stop tokens @@ -301,23 +299,21 @@ def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData: ) placeholders = [{"type": "image", "image": url} for url in image_urls] - messages = [{ - "role": - "user", - "content": [ - *placeholders, - { - "type": "text", - "text": question - }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + } + ] processor = AutoProcessor.from_pretrained(model_name) - prompt = processor.apply_chat_template(messages, - tokenize=False, - add_generation_prompt=True) + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) return ModelRequestData( engine_args=engine_args, @@ -338,24 +334,21 @@ def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData: ) placeholders = [{"type": "image", "image": url} for url in image_urls] - messages = [{ - "role": - "user", - "content": [ - *placeholders, - { - "type": "text", - "text": question - }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + } + ] - processor = AutoProcessor.from_pretrained(model_name, - trust_remote_code=True) + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) - prompt = processor.apply_chat_template(messages, - tokenize=False, - add_generation_prompt=True) + prompt = processor.apply_chat_template( + messages, 
tokenize=False, add_generation_prompt=True + ) return ModelRequestData( engine_args=engine_args, @@ -419,15 +412,15 @@ def load_nvlm_d(question: str, image_urls: list[str]) -> ModelRequestData: mm_processor_kwargs={"max_dynamic_patch": 4}, ) - placeholders = "\n".join(f"Image-{i}: <image>\n" - for i, _ in enumerate(image_urls, start=1)) - messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}] + placeholders = "\n".join( + f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1) + ) + messages = [{"role": "user", "content": f"{placeholders}\n{question}"}] - tokenizer = AutoTokenizer.from_pretrained(model_name, - trust_remote_code=True) - prompt = tokenizer.apply_chat_template(messages, - tokenize=False, - add_generation_prompt=True) + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) return ModelRequestData( engine_args=engine_args, @@ -436,8 +429,8 @@ def load_nvlm_d(question: str, image_urls: list[str]) -> ModelRequestData: ) -# Ovis2 -def load_ovis2(question: str, image_urls: list[str]) -> ModelRequestData: +# Ovis +def load_ovis(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "AIDC-AI/Ovis2-1B" engine_args = EngineArgs( @@ -447,15 +440,17 @@ def load_ovis2(question: str, image_urls: list[str]) -> ModelRequestData: trust_remote_code=True, dtype="half", limit_mm_per_prompt={"image": len(image_urls)}, - hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]}, ) - placeholder = '\n'.join( - [f'Image {i+1}: <image>' for i in range(len(image_urls))]) + '\n' - prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" - f"<|im_start|>user\n{placeholder}" - f"{question}<|im_end|>\n" - "<|im_start|>assistant\n") + placeholders = "\n".join( + f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1) + ) + messages = [{"role": "user", "content": 
f"{placeholders}\n{question}"}] + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) return ModelRequestData( engine_args=engine_args, @@ -507,8 +502,9 @@ def load_phi3v(question: str, image_urls: list[str]) -> ModelRequestData: limit_mm_per_prompt={"image": len(image_urls)}, mm_processor_kwargs={"num_crops": 4}, ) - placeholders = "\n".join(f"<|image_{i}|>" - for i, _ in enumerate(image_urls, start=1)) + placeholders = "\n".join( + f"<|image_{i}|>" for i, _ in enumerate(image_urls, start=1) + ) prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n" return ModelRequestData( @@ -540,8 +536,7 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData: mm_processor_kwargs={"dynamic_hd": 4}, ) - placeholders = "".join(f"<|image_{i}|>" - for i, _ in enumerate(image_urls, start=1)) + placeholders = "".join(f"<|image_{i}|>" for i, _ in enumerate(image_urls, start=1)) prompt = f"<|user|>{placeholders}{question}<|end|><|assistant|>" return ModelRequestData( @@ -552,8 +547,7 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData: ) -def load_qwen_vl_chat(question: str, - image_urls: list[str]) -> ModelRequestData: +def load_qwen_vl_chat(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "Qwen/Qwen-VL-Chat" engine_args = EngineArgs( model=model_name, @@ -563,24 +557,26 @@ def load_qwen_vl_chat(question: str, hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]}, limit_mm_per_prompt={"image": len(image_urls)}, ) - placeholders = "".join(f"Picture {i}: <img></img>\n" - for i, _ in enumerate(image_urls, start=1)) + placeholders = "".join( + f"Picture {i}: <img></img>\n" for i, _ in enumerate(image_urls, start=1) + ) # This model does not have a chat_template attribute on its tokenizer, # so we need to explicitly pass it. 
We use ChatML since it's used in the # generation utils of the model: # https://huggingface.co/Qwen/Qwen-VL-Chat/blob/main/qwen_generation_utils.py#L265 - tokenizer = AutoTokenizer.from_pretrained(model_name, - trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) # Copied from: https://huggingface.co/docs/transformers/main/en/chat_templating chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" # noqa: E501 - messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}] - prompt = tokenizer.apply_chat_template(messages, - tokenize=False, - add_generation_prompt=True, - chat_template=chat_template) + messages = [{"role": "user", "content": f"{placeholders}\n{question}"}] + prompt = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + chat_template=chat_template, + ) stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"] stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] @@ -598,9 +594,11 @@ def load_qwen2_vl(question: str, image_urls: list[str]) -> ModelRequestData: try: from qwen_vl_utils import process_vision_info except ModuleNotFoundError: - print('WARNING: `qwen-vl-utils` not installed, input images will not ' - 'be automatically resized. You can enable this functionality by ' - '`pip install qwen-vl-utils`.') + print( + "WARNING: `qwen-vl-utils` not installed, input images will not " + "be automatically resized. You can enable this functionality by " + "`pip install qwen-vl-utils`." 
+ ) process_vision_info = None model_name = "Qwen/Qwen2-VL-7B-Instruct" @@ -614,26 +612,22 @@ def load_qwen2_vl(question: str, image_urls: list[str]) -> ModelRequestData: ) placeholders = [{"type": "image", "image": url} for url in image_urls] - messages = [{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": - "user", - "content": [ - *placeholders, - { - "type": "text", - "text": question - }, - ], - }] + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + }, + ] processor = AutoProcessor.from_pretrained(model_name) - prompt = processor.apply_chat_template(messages, - tokenize=False, - add_generation_prompt=True) + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) if process_vision_info is None: image_data = [fetch_image(url) for url in image_urls] @@ -651,9 +645,11 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData: try: from qwen_vl_utils import process_vision_info except ModuleNotFoundError: - print('WARNING: `qwen-vl-utils` not installed, input images will not ' - 'be automatically resized. You can enable this functionality by ' - '`pip install qwen-vl-utils`.') + print( + "WARNING: `qwen-vl-utils` not installed, input images will not " + "be automatically resized. You can enable this functionality by " + "`pip install qwen-vl-utils`." + ) process_vision_info = None model_name = "Qwen/Qwen2.5-VL-3B-Instruct" @@ -666,32 +662,27 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData: ) placeholders = [{"type": "image", "image": url} for url in image_urls] - messages = [{ - "role": "system", - "content": "You are a helpful assistant." 
- }, { - "role": - "user", - "content": [ - *placeholders, - { - "type": "text", - "text": question - }, - ], - }] + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + }, + ] processor = AutoProcessor.from_pretrained(model_name) - prompt = processor.apply_chat_template(messages, - tokenize=False, - add_generation_prompt=True) + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) if process_vision_info is None: image_data = [fetch_image(url) for url in image_urls] else: - image_data, _ = process_vision_info(messages, - return_video_kwargs=False) + image_data, _ = process_vision_info(messages, return_video_kwargs=False) return ModelRequestData( engine_args=engine_args, @@ -713,7 +704,7 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData: "mistral3": load_mistral3, "mllama": load_mllama, "NVLM_D": load_nvlm_d, - "ovis2": load_ovis2, + "ovis": load_ovis, "phi3_v": load_phi3v, "phi4_mm": load_phi4mm, "pixtral_hf": load_pixtral_hf, @@ -724,23 +715,20 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData: } -def run_generate(model, question: str, image_urls: list[str], - seed: Optional[int]): +def run_generate(model, question: str, image_urls: list[str], seed: Optional[int]): req_data = model_example_map[model](question, image_urls) engine_args = asdict(req_data.engine_args) | {"seed": args.seed} llm = LLM(**engine_args) - sampling_params = SamplingParams(temperature=0.0, - max_tokens=256, - stop_token_ids=req_data.stop_token_ids) + sampling_params = SamplingParams( + temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids + ) outputs = llm.generate( { "prompt": req_data.prompt, - "multi_modal_data": { - "image": req_data.image_data - }, + "multi_modal_data": {"image": req_data.image_data}, }, sampling_params=sampling_params, 
lora_request=req_data.lora_requests, @@ -753,38 +741,40 @@ def run_generate(model, question: str, image_urls: list[str], print("-" * 50) -def run_chat(model: str, question: str, image_urls: list[str], - seed: Optional[int]): +def run_chat(model: str, question: str, image_urls: list[str], seed: Optional[int]): req_data = model_example_map[model](question, image_urls) # Disable other modalities to save memory default_limits = {"image": 0, "video": 0, "audio": 0} req_data.engine_args.limit_mm_per_prompt = default_limits | dict( - req_data.engine_args.limit_mm_per_prompt or {}) + req_data.engine_args.limit_mm_per_prompt or {} + ) engine_args = asdict(req_data.engine_args) | {"seed": seed} llm = LLM(**engine_args) - sampling_params = SamplingParams(temperature=0.0, - max_tokens=256, - stop_token_ids=req_data.stop_token_ids) + sampling_params = SamplingParams( + temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids + ) outputs = llm.chat( - [{ - "role": - "user", - "content": [ - { - "type": "text", - "text": question, - }, - *({ - "type": "image_url", - "image_url": { - "url": image_url + [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": question, }, - } for image_url in image_urls), - ], - }], + *( + { + "type": "image_url", + "image_url": {"url": image_url}, + } + for image_url in image_urls + ), + ], + } + ], sampling_params=sampling_params, chat_template=req_data.chat_template, lora_request=req_data.lora_requests, @@ -799,32 +789,39 @@ def run_chat(model: str, question: str, image_urls: list[str], def parse_args(): parser = FlexibleArgumentParser( - description='Demo on using vLLM for offline inference with ' - 'vision language models that support multi-image input for text ' - 'generation') - parser.add_argument('--model-type', - '-m', - type=str, - default="phi3_v", - choices=model_example_map.keys(), - help='Huggingface "model_type".') - parser.add_argument("--method", - type=str, - default="generate", - choices=["generate", 
"chat"], - help="The method to run in `vllm.LLM`.") - parser.add_argument("--seed", - type=int, - default=None, - help="Set the seed when initializing `vllm.LLM`.") + description="Demo on using vLLM for offline inference with " + "vision language models that support multi-image input for text " + "generation" + ) + parser.add_argument( + "--model-type", + "-m", + type=str, + default="phi3_v", + choices=model_example_map.keys(), + help='Huggingface "model_type".', + ) + parser.add_argument( + "--method", + type=str, + default="generate", + choices=["generate", "chat"], + help="The method to run in `vllm.LLM`.", + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="Set the seed when initializing `vllm.LLM`.", + ) parser.add_argument( "--num-images", "-n", type=int, - choices=list(range(1, - len(IMAGE_URLS) + 1)), # the max number of images + choices=list(range(1, len(IMAGE_URLS) + 1)), # the max number of images default=2, - help="Number of images to use for the demo.") + help="Number of images to use for the demo.", + ) return parser.parse_args() @@ -833,7 +830,7 @@ def main(args: Namespace): method = args.method seed = args.seed - image_urls = IMAGE_URLS[:args.num_images] + image_urls = IMAGE_URLS[: args.num_images] if method == "generate": run_generate(model, QUESTION, image_urls, seed) diff --git a/examples/online_serving/api_client.py b/examples/online_serving/api_client.py index 36079ff11d07..cc190e91c141 100644 --- a/examples/online_serving/api_client.py +++ b/examples/online_serving/api_client.py @@ -17,16 +17,15 @@ def clear_line(n: int = 1) -> None: - LINE_UP = '\033[1A' - LINE_CLEAR = '\x1b[2K' + LINE_UP = "\033[1A" + LINE_CLEAR = "\x1b[2K" for _ in range(n): print(LINE_UP, end=LINE_CLEAR, flush=True) -def post_http_request(prompt: str, - api_url: str, - n: int = 1, - stream: bool = False) -> requests.Response: +def post_http_request( + prompt: str, api_url: str, n: int = 1, stream: bool = False +) -> requests.Response: headers = 
{"User-Agent": "Test Client"} pload = { "prompt": prompt, @@ -35,17 +34,14 @@ def post_http_request(prompt: str, "max_tokens": 16, "stream": stream, } - response = requests.post(api_url, - headers=headers, - json=pload, - stream=stream) + response = requests.post(api_url, headers=headers, json=pload, stream=stream) return response def get_streaming_response(response: requests.Response) -> Iterable[list[str]]: - for chunk in response.iter_lines(chunk_size=8192, - decode_unicode=False, - delimiter=b"\n"): + for chunk in response.iter_lines( + chunk_size=8192, decode_unicode=False, delimiter=b"\n" + ): if chunk: data = json.loads(chunk.decode("utf-8")) output = data["text"] diff --git a/examples/online_serving/cohere_rerank_client.py b/examples/online_serving/cohere_rerank_client.py index c2d4ef08ddbb..e57b94e8805f 100644 --- a/examples/online_serving/cohere_rerank_client.py +++ b/examples/online_serving/cohere_rerank_client.py @@ -6,6 +6,7 @@ run: vllm serve BAAI/bge-reranker-base """ + from typing import Union import cohere @@ -16,28 +17,28 @@ query = "What is the capital of France?" 
documents = [ - "The capital of France is Paris", "Reranking is fun!", - "vLLM is an open-source framework for fast AI serving" + "The capital of France is Paris", + "Reranking is fun!", + "vLLM is an open-source framework for fast AI serving", ] -def cohere_rerank(client: Union[Client, ClientV2], model: str, query: str, - documents: list[str]) -> dict: +def cohere_rerank( + client: Union[Client, ClientV2], model: str, query: str, documents: list[str] +) -> dict: return client.rerank(model=model, query=query, documents=documents) def main(): # cohere v1 client - cohere_v1 = cohere.Client(base_url="http://localhost:8000", - api_key="sk-fake-key") + cohere_v1 = cohere.Client(base_url="http://localhost:8000", api_key="sk-fake-key") rerank_v1_result = cohere_rerank(cohere_v1, model, query, documents) print("-" * 50) print("rerank_v1_result:\n", rerank_v1_result) print("-" * 50) # or the v2 - cohere_v2 = cohere.ClientV2("sk-fake-key", - base_url="http://localhost:8000") + cohere_v2 = cohere.ClientV2("sk-fake-key", base_url="http://localhost:8000") rerank_v2_result = cohere_rerank(cohere_v2, model, query, documents) print("rerank_v2_result:\n", rerank_v2_result) print("-" * 50) diff --git a/examples/online_serving/disaggregated_serving/README.md b/examples/online_serving/disaggregated_serving/README.md new file mode 100644 index 000000000000..090afd7515ee --- /dev/null +++ b/examples/online_serving/disaggregated_serving/README.md @@ -0,0 +1,8 @@ +# Disaggregated Serving + +This example contains scripts that demonstrate the disaggregated serving features of vLLM. + +## Files + +- `disagg_proxy_demo.py` - Demonstrates XpYd (X prefill instances, Y decode instances). +- `kv_events.sh` - Demonstrates KV cache event publishing. 
diff --git a/examples/online_serving/disagg_examples/disagg_proxy_demo.py b/examples/online_serving/disaggregated_serving/disagg_proxy_demo.py similarity index 72% rename from examples/online_serving/disagg_examples/disagg_proxy_demo.py rename to examples/online_serving/disaggregated_serving/disagg_proxy_demo.py index a701636f357a..2ffba4a7ed3f 100644 --- a/examples/online_serving/disagg_examples/disagg_proxy_demo.py +++ b/examples/online_serving/disaggregated_serving/disagg_proxy_demo.py @@ -4,7 +4,7 @@ example usage of XpYd disaggregated prefilling. We can launch multiple vllm instances (2 for prefill and 2 for decode), and launch this proxy demo through: - python3 examples/online_serving/disagg_examples/disagg_proxy_demo.py \ + python3 examples/online_serving/disaggregated_serving/disagg_proxy_demo.py \ --model $model_name \ --prefill localhost:8100 localhost:8101 \ --decode localhost:8200 localhost:8201 \ @@ -13,6 +13,7 @@ Note: This demo will be removed once the PDController implemented in PR 15343 (https://github.com/vllm-project/vllm/pull/15343) supports XpYd. 
""" + import argparse import ipaddress import itertools @@ -26,8 +27,7 @@ import aiohttp import requests import uvicorn -from fastapi import (APIRouter, Depends, FastAPI, Header, HTTPException, - Request, status) +from fastapi import APIRouter, Depends, FastAPI, Header, HTTPException, Request, status from fastapi.responses import JSONResponse, StreamingResponse AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) @@ -36,24 +36,24 @@ class SchedulingPolicy(ABC): - @abstractmethod def schedule(self, cycler: itertools.cycle): raise NotImplementedError("Scheduling Proxy is not set.") class Proxy: - def __init__( self, prefill_instances: list[str], decode_instances: list[str], model: str, scheduling_policy: SchedulingPolicy, - custom_create_completion: Optional[Callable[[Request], - StreamingResponse]] = None, - custom_create_chat_completion: Optional[Callable[ - [Request], StreamingResponse]] = None, + custom_create_completion: Optional[ + Callable[[Request], StreamingResponse] + ] = None, + custom_create_chat_completion: Optional[ + Callable[[Request], StreamingResponse] + ] = None, ): self.prefill_instances = prefill_instances self.decode_instances = decode_instances @@ -68,30 +68,30 @@ def __init__( def setup_routes(self): self.router.post( - "/v1/completions", - dependencies=[ - Depends(self.validate_json_request) - ])(self.custom_create_completion if self. 
- custom_create_completion else self.create_completion) + "/v1/completions", dependencies=[Depends(self.validate_json_request)] + )( + self.custom_create_completion + if self.custom_create_completion + else self.create_completion + ) + self.router.post( + "/v1/chat/completions", dependencies=[Depends(self.validate_json_request)] + )( + self.custom_create_chat_completion + if self.custom_create_chat_completion + else self.create_chat_completion + ) + self.router.get("/status", response_class=JSONResponse)(self.get_status) self.router.post( - "/v1/chat/completions", - dependencies=[ - Depends(self.validate_json_request) - ])(self.custom_create_chat_completion if self. - custom_create_chat_completion else self.create_chat_completion) - self.router.get("/status", - response_class=JSONResponse)(self.get_status) - self.router.post("/instances/add", - dependencies=[Depends(self.api_key_authenticate) - ])(self.add_instance_endpoint) + "/instances/add", dependencies=[Depends(self.api_key_authenticate)] + )(self.add_instance_endpoint) async def validate_json_request(self, raw_request: Request): content_type = raw_request.headers.get("content-type", "").lower() if content_type != "application/json": raise HTTPException( status_code=415, - detail= - "Unsupported Media Type: Only 'application/json' is allowed", + detail="Unsupported Media Type: Only 'application/json' is allowed", ) def api_key_authenticate(self, x_api_key: str = Header(...)): @@ -103,8 +103,7 @@ def api_key_authenticate(self, x_api_key: str = Header(...)): detail="Server configuration error.", ) if x_api_key != expected_api_key: - logger.warning("Unauthorized access attempt with API Key: %s", - x_api_key) + logger.warning("Unauthorized access attempt with API Key: %s", x_api_key) raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden: Invalid API Key.", @@ -113,8 +112,7 @@ def api_key_authenticate(self, x_api_key: str = Header(...)): async def validate_instance(self, instance: str) -> 
bool: url = f"http://{instance}/v1/models" try: - async with aiohttp.ClientSession( - timeout=AIOHTTP_TIMEOUT) as client: + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as client: logger.info("Verifying %s ...", instance) async with client.get(url) as response: if response.status == 200: @@ -122,12 +120,15 @@ async def validate_instance(self, instance: str) -> bool: if "data" in data and len(data["data"]) > 0: model_cur = data["data"][0].get("id", "") if model_cur == self.model: - logger.info("Instance: %s could be added.", - instance) + logger.info("Instance: %s could be added.", instance) return True else: - logger.warning("Mismatch model %s : %s != %s", - instance, model_cur, self.model) + logger.warning( + "Mismatch model %s : %s != %s", + instance, + model_cur, + self.model, + ) return False else: return False @@ -147,48 +148,47 @@ async def add_instance_endpoint(self, request: Request): instance_type = data.get("type") instance = data.get("instance") if instance_type not in ["prefill", "decode"]: - raise HTTPException(status_code=400, - detail="Invalid instance type.") + raise HTTPException(status_code=400, detail="Invalid instance type.") if not instance or ":" not in instance: - raise HTTPException(status_code=400, - detail="Invalid instance format.") + raise HTTPException(status_code=400, detail="Invalid instance format.") host, port_str = instance.split(":") try: if host != "localhost": ipaddress.ip_address(host) port = int(port_str) if not (0 < port < 65536): - raise HTTPException(status_code=400, - detail="Invalid port number.") + raise HTTPException(status_code=400, detail="Invalid port number.") except Exception as e: - raise HTTPException(status_code=400, - detail="Invalid instance address.") from e + raise HTTPException( + status_code=400, detail="Invalid instance address." 
+ ) from e is_valid = await self.validate_instance(instance) if not is_valid: - raise HTTPException(status_code=400, - detail="Instance validation failed.") + raise HTTPException( + status_code=400, detail="Instance validation failed." + ) if instance_type == "prefill": if instance not in self.prefill_instances: self.prefill_instances.append(instance) - self.prefill_cycler = itertools.cycle( - self.prefill_instances) + self.prefill_cycler = itertools.cycle(self.prefill_instances) else: - raise HTTPException(status_code=400, - detail="Instance already exists.") + raise HTTPException( + status_code=400, detail="Instance already exists." + ) else: if instance not in self.decode_instances: self.decode_instances.append(instance) self.decode_cycler = itertools.cycle(self.decode_instances) else: - raise HTTPException(status_code=400, - detail="Instance already exists.") + raise HTTPException( + status_code=400, detail="Instance already exists." + ) - return JSONResponse(content={ - "message": - f"Added {instance} to {instance_type}_instances." 
- }) + return JSONResponse( + content={"message": f"Added {instance} to {instance_type}_instances."} + ) except HTTPException as http_exc: raise http_exc except Exception as e: @@ -197,16 +197,16 @@ async def add_instance_endpoint(self, request: Request): async def forward_request(self, url, data, use_chunked=True): async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: - headers = { - "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" - } + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} try: - async with session.post(url=url, json=data, - headers=headers) as response: + async with session.post( + url=url, json=data, headers=headers + ) as response: if 200 <= response.status < 300 or 400 <= response.status < 500: # noqa: E501 if use_chunked: async for chunk_bytes in response.content.iter_chunked( # noqa: E501 - 1024): + 1024 + ): yield chunk_bytes else: content = await response.read() @@ -217,20 +217,21 @@ async def forward_request(self, url, data, use_chunked=True): error_content = json.loads(error_content) except json.JSONDecodeError: error_content = error_content - logger.error("Request failed with status %s: %s", - response.status, error_content) + logger.error( + "Request failed with status %s: %s", + response.status, + error_content, + ) raise HTTPException( status_code=response.status, - detail= - f"Request failed with status {response.status}: " + detail=f"Request failed with status {response.status}: " f"{error_content}", ) except aiohttp.ClientError as e: logger.error("ClientError occurred: %s", str(e)) raise HTTPException( status_code=502, - detail= - "Bad Gateway: Error communicating with upstream server.", + detail="Bad Gateway: Error communicating with upstream server.", ) from e except Exception as e: logger.error("Unexpected error: %s", str(e)) @@ -258,8 +259,8 @@ async def create_completion(self, raw_request: Request): prefill_instance = self.schedule(self.prefill_cycler) try: async for _ in 
self.forward_request( - f"http://{prefill_instance}/v1/completions", - kv_prepare_request): + f"http://{prefill_instance}/v1/completions", kv_prepare_request + ): continue except HTTPException as http_exc: self.remove_instance_endpoint("prefill", prefill_instance) @@ -270,7 +271,8 @@ async def create_completion(self, raw_request: Request): try: generator = self.forward_request( - f"http://{decode_instance}/v1/completions", request) + f"http://{decode_instance}/v1/completions", request + ) except HTTPException as http_exc: self.remove_instance_endpoint("decode", decode_instance) raise http_exc @@ -295,8 +297,8 @@ async def create_chat_completion(self, raw_request: Request): prefill_instance = self.schedule(self.prefill_cycler) try: async for _ in self.forward_request( - f"http://{prefill_instance}/v1/chat/completions", - kv_prepare_request): + f"http://{prefill_instance}/v1/chat/completions", kv_prepare_request + ): continue except HTTPException as http_exc: self.remove_instance_endpoint("prefill", prefill_instance) @@ -306,8 +308,8 @@ async def create_chat_completion(self, raw_request: Request): try: generator = self.forward_request( - "http://" + decode_instance + "/v1/chat/completions", - request) + "http://" + decode_instance + "/v1/chat/completions", request + ) except HTTPException as http_exc: self.remove_instance_endpoint("decode", decode_instance) raise http_exc @@ -318,20 +320,20 @@ async def create_chat_completion(self, raw_request: Request): error_messages = [str(e) for e in exc_info if e] print("Error occurred in disagg proxy server") print(error_messages) - return StreamingResponse(content=iter(error_messages), - media_type="text/event-stream") + return StreamingResponse( + content=iter(error_messages), media_type="text/event-stream" + ) def remove_instance_endpoint(self, instance_type, instance): - if (instance_type == "decode" and instance in self.decode_instances): + if instance_type == "decode" and instance in self.decode_instances: 
self.decode_instances.remove(instance) self.decode_cycler = itertools.cycle(self.decode_instances) - if (instance_type == "prefill" and instance in self.decode_instances): + if instance_type == "prefill" and instance in self.decode_instances: self.prefill_instances.remove(instance) self.prefill_cycler = itertools.cycle(self.decode_instances) class RoundRobinSchedulingPolicy(SchedulingPolicy): - def __init__(self): super().__init__() @@ -340,15 +342,12 @@ def schedule(self, cycler: itertools.cycle) -> str: class ProxyServer: - def __init__( self, args: argparse.Namespace, scheduling_policy: Optional[SchedulingPolicy] = None, - create_completion: Optional[Callable[[Request], - StreamingResponse]] = None, - create_chat_completion: Optional[Callable[[Request], - StreamingResponse]] = None, + create_completion: Optional[Callable[[Request], StreamingResponse]] = None, + create_chat_completion: Optional[Callable[[Request], StreamingResponse]] = None, ): self.validate_parsed_serve_args(args) self.port = args.port @@ -356,8 +355,11 @@ def __init__( prefill_instances=[] if args.prefill is None else args.prefill, decode_instances=[] if args.decode is None else args.decode, model=args.model, - scheduling_policy=(scheduling_policy if scheduling_policy - is not None else RoundRobinSchedulingPolicy()), + scheduling_policy=( + scheduling_policy + if scheduling_policy is not None + else RoundRobinSchedulingPolicy() + ), custom_create_completion=create_completion, custom_create_chat_completion=create_chat_completion, ) @@ -382,11 +384,9 @@ def validate_instances(self, instances: list): ipaddress.ip_address(host) port = int(port) if not (0 < port < 65536): - raise ValueError( - f"Invalid port number in instance: {instance}") + raise ValueError(f"Invalid port number in instance: {instance}") except Exception as e: - raise ValueError( - f"Invalid instance {instance}: {str(e)}") from e + raise ValueError(f"Invalid instance {instance}: {str(e)}") from e def verify_model_config(self, 
instances: list, model: str) -> None: model_suffix = model.split("/")[-1] @@ -399,12 +399,14 @@ def verify_model_config(self, instances: list, model: str) -> None: if model_cur_suffix != model_suffix: raise ValueError( f"{instance} serves a different model: " - f"{model_cur} != {model}") + f"{model_cur} != {model}" + ) else: raise ValueError(f"Cannot get model id from {instance}!") except requests.RequestException as e: raise ValueError( - f"Error communicating with {instance}: {str(e)}") from e + f"Error communicating with {instance}: {str(e)}" + ) from e def run_server(self): app = FastAPI() @@ -414,14 +416,10 @@ def run_server(self): server.run() -if __name__ == "__main__": +def parse_args(): # Todo: allow more config parser = argparse.ArgumentParser("vLLM disaggregated proxy server.") - parser.add_argument("--model", - "-m", - type=str, - required=True, - help="Model name") + parser.add_argument("--model", "-m", type=str, required=True, help="Model name") parser.add_argument( "--prefill", @@ -445,6 +443,10 @@ def run_server(self): default=8000, help="Server port number", ) - args = parser.parse_args() + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() proxy_server = ProxyServer(args=args) proxy_server.run_server() diff --git a/examples/online_serving/kv_events.sh b/examples/online_serving/disaggregated_serving/kv_events.sh similarity index 100% rename from examples/online_serving/kv_events.sh rename to examples/online_serving/disaggregated_serving/kv_events.sh diff --git a/examples/online_serving/gradio_openai_chatbot_webserver.py b/examples/online_serving/gradio_openai_chatbot_webserver.py index 314f1c5b7395..3f2a3d01b456 100644 --- a/examples/online_serving/gradio_openai_chatbot_webserver.py +++ b/examples/online_serving/gradio_openai_chatbot_webserver.py @@ -17,6 +17,7 @@ 2. Rename the downloaded file to: frpc_linux_amd64_v0.3 3. 
Move the file to this location: /home/user/.cache/huggingface/gradio/frpc """ + import argparse import gradio as gr @@ -24,16 +25,12 @@ def format_history_to_openai(history): - history_openai_format = [{ - "role": "system", - "content": "You are a great AI assistant." - }] + history_openai_format = [ + {"role": "system", "content": "You are a great AI assistant."} + ] for human, assistant in history: history_openai_format.append({"role": "user", "content": human}) - history_openai_format.append({ - "role": "assistant", - "content": assistant - }) + history_openai_format.append({"role": "assistant", "content": assistant}) return history_openai_format @@ -49,17 +46,17 @@ def predict(message, history, client, model_name, temp, stop_token_ids): temperature=temp, stream=True, extra_body={ - 'repetition_penalty': - 1, - 'stop_token_ids': - [int(id.strip()) - for id in stop_token_ids.split(',')] if stop_token_ids else [] - }) + "repetition_penalty": 1, + "stop_token_ids": [int(id.strip()) for id in stop_token_ids.split(",")] + if stop_token_ids + else [], + }, + ) # Collect all chunks and concatenate them into a full message full_message = "" for chunk in stream: - full_message += (chunk.choices[0].delta.content or "") + full_message += chunk.choices[0].delta.content or "" # Return the full message as a single response return full_message @@ -67,38 +64,34 @@ def predict(message, history, client, model_name, temp, stop_token_ids): def parse_args(): parser = argparse.ArgumentParser( - description='Chatbot Interface with Customizable Parameters') - parser.add_argument('--model-url', - type=str, - default='http://localhost:8000/v1', - help='Model URL') - parser.add_argument('-m', - '--model', - type=str, - required=True, - help='Model name for the chatbot') - parser.add_argument('--temp', - type=float, - default=0.8, - help='Temperature for text generation') - parser.add_argument('--stop-token-ids', - type=str, - default='', - help='Comma-separated stop token IDs') + 
description="Chatbot Interface with Customizable Parameters" + ) + parser.add_argument( + "--model-url", type=str, default="http://localhost:8000/v1", help="Model URL" + ) + parser.add_argument( + "-m", "--model", type=str, required=True, help="Model name for the chatbot" + ) + parser.add_argument( + "--temp", type=float, default=0.8, help="Temperature for text generation" + ) + parser.add_argument( + "--stop-token-ids", type=str, default="", help="Comma-separated stop token IDs" + ) parser.add_argument("--host", type=str, default=None) parser.add_argument("--port", type=int, default=8001) return parser.parse_args() def build_gradio_interface(client, model_name, temp, stop_token_ids): - def chat_predict(message, history): - return predict(message, history, client, model_name, temp, - stop_token_ids) + return predict(message, history, client, model_name, temp, stop_token_ids) - return gr.ChatInterface(fn=chat_predict, - title="Chatbot Interface", - description="A simple chatbot powered by vLLM") + return gr.ChatInterface( + fn=chat_predict, + title="Chatbot Interface", + description="A simple chatbot powered by vLLM", + ) def main(): @@ -113,12 +106,13 @@ def main(): client = OpenAI(api_key=openai_api_key, base_url=openai_api_base) # Define the Gradio chatbot interface using the predict function - gradio_interface = build_gradio_interface(client, args.model, args.temp, - args.stop_token_ids) + gradio_interface = build_gradio_interface( + client, args.model, args.temp, args.stop_token_ids + ) - gradio_interface.queue().launch(server_name=args.host, - server_port=args.port, - share=True) + gradio_interface.queue().launch( + server_name=args.host, server_port=args.port, share=True + ) if __name__ == "__main__": diff --git a/examples/online_serving/gradio_webserver.py b/examples/online_serving/gradio_webserver.py index 2e7c2a0c5838..fd341ff493b5 100644 --- a/examples/online_serving/gradio_webserver.py +++ b/examples/online_serving/gradio_webserver.py @@ -17,6 +17,7 @@ 
2. Rename the downloaded file to: frpc_linux_amd64_v0.3 3. Move the file to this location: /home/user/.cache/huggingface/gradio/frpc """ + import argparse import json @@ -31,14 +32,11 @@ def http_bot(prompt): "stream": True, "max_tokens": 128, } - response = requests.post(args.model_url, - headers=headers, - json=pload, - stream=True) - - for chunk in response.iter_lines(chunk_size=8192, - decode_unicode=False, - delimiter=b"\n"): + response = requests.post(args.model_url, headers=headers, json=pload, stream=True) + + for chunk in response.iter_lines( + chunk_size=8192, decode_unicode=False, delimiter=b"\n" + ): if chunk: data = json.loads(chunk.decode("utf-8")) output = data["text"][0] @@ -48,10 +46,10 @@ def http_bot(prompt): def build_demo(): with gr.Blocks() as demo: gr.Markdown("# vLLM text completion demo\n") - inputbox = gr.Textbox(label="Input", - placeholder="Enter text and press ENTER") - outputbox = gr.Textbox(label="Output", - placeholder="Generated result from the model") + inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER") + outputbox = gr.Textbox( + label="Output", placeholder="Generated result from the model" + ) inputbox.submit(http_bot, [inputbox], [outputbox]) return demo @@ -60,17 +58,15 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default=None) parser.add_argument("--port", type=int, default=8001) - parser.add_argument("--model-url", - type=str, - default="http://localhost:8000/generate") + parser.add_argument( + "--model-url", type=str, default="http://localhost:8000/generate" + ) return parser.parse_args() def main(args): demo = build_demo() - demo.queue().launch(server_name=args.host, - server_port=args.port, - share=True) + demo.queue().launch(server_name=args.host, server_port=args.port, share=True) if __name__ == "__main__": diff --git a/examples/online_serving/jinaai_rerank_client.py b/examples/online_serving/jinaai_rerank_client.py index 
3076bba765ce..7eb3d2193f41 100644 --- a/examples/online_serving/jinaai_rerank_client.py +++ b/examples/online_serving/jinaai_rerank_client.py @@ -5,6 +5,7 @@ run: vllm serve BAAI/bge-reranker-base """ + import json import requests @@ -14,14 +15,13 @@ headers = {"accept": "application/json", "Content-Type": "application/json"} data = { - "model": - "BAAI/bge-reranker-base", - "query": - "What is the capital of France?", + "model": "BAAI/bge-reranker-base", + "query": "What is the capital of France?", "documents": [ "The capital of Brazil is Brasilia.", - "The capital of France is Paris.", "Horses and cows are both animals" - ] + "The capital of France is Paris.", + "Horses and cows are both animals", + ], } diff --git a/examples/online_serving/kv_events_subscriber.py b/examples/online_serving/kv_events_subscriber.py index 88bbbebd7478..65d74dccab80 100644 --- a/examples/online_serving/kv_events_subscriber.py +++ b/examples/online_serving/kv_events_subscriber.py @@ -9,17 +9,14 @@ # # Types copied from vllm.distributed.kv_events # -class EventBatch(msgspec.Struct, array_like=True, omit_defaults=True, - gc=False): +class EventBatch(msgspec.Struct, array_like=True, omit_defaults=True, gc=False): ts: float events: list[Any] -class KVCacheEvent(msgspec.Struct, - array_like=True, - omit_defaults=True, - gc=False, - tag=True): +class KVCacheEvent( + msgspec.Struct, array_like=True, omit_defaults=True, gc=False, tag=True +): """Base class for all KV cache-related events""" @@ -77,8 +74,9 @@ def main(): if last_seq >= 0 and seq > last_seq + 1: missed = seq - last_seq - 1 - print(f"Missed {missed} messages" - f" (last: {last_seq}, current: {seq})") + print( + f"Missed {missed} messages (last: {last_seq}, current: {seq})" + ) replay.send((last_seq + 1).to_bytes(8, "big")) diff --git a/examples/online_serving/openai_chat_completion_client.py b/examples/online_serving/openai_chat_completion_client.py index 74e0c045d621..2856e3be3e2d 100644 --- 
a/examples/online_serving/openai_chat_completion_client.py +++ b/examples/online_serving/openai_chat_completion_client.py @@ -3,28 +3,35 @@ NOTE: start a supported chat completion model server with `vllm serve`, e.g. vllm serve meta-llama/Llama-2-7b-chat-hf """ + +import argparse + from openai import OpenAI # Modify OpenAI's API key and API base to use vLLM's API server. openai_api_key = "EMPTY" openai_api_base = "http://localhost:8000/v1" -messages = [{ - "role": "system", - "content": "You are a helpful assistant." -}, { - "role": "user", - "content": "Who won the world series in 2020?" -}, { - "role": "assistant", - "content": "The Los Angeles Dodgers won the World Series in 2020." -}, { - "role": "user", - "content": "Where was it played?" -}] - - -def main(): +messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Who won the world series in 2020?"}, + { + "role": "assistant", + "content": "The Los Angeles Dodgers won the World Series in 2020.", + }, + {"role": "user", "content": "Where was it played?"}, +] + + +def parse_args(): + parser = argparse.ArgumentParser(description="Client for vLLM API server") + parser.add_argument( + "--stream", action="store_true", help="Enable streaming response" + ) + return parser.parse_args() + + +def main(args): client = OpenAI( # defaults to os.environ.get("OPENAI_API_KEY") api_key=openai_api_key, @@ -34,16 +41,23 @@ def main(): models = client.models.list() model = models.data[0].id + # Chat Completion API chat_completion = client.chat.completions.create( messages=messages, model=model, + stream=args.stream, ) print("-" * 50) print("Chat completion results:") - print(chat_completion) + if args.stream: + for c in chat_completion: + print(c) + else: + print(chat_completion) print("-" * 50) if __name__ == "__main__": - main() + args = parse_args() + main(args) diff --git a/examples/online_serving/openai_chat_completion_client_for_multimodal.py 
b/examples/online_serving/openai_chat_completion_client_for_multimodal.py index cffd093c983a..8c3c6ecdd4b0 100644 --- a/examples/online_serving/openai_chat_completion_client_for_multimodal.py +++ b/examples/online_serving/openai_chat_completion_client_for_multimodal.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -"""An example showing how to use vLLM to serve multimodal models +"""An example showing how to use vLLM to serve multimodal models and run online serving with OpenAI client. Launch the vLLM server with the following command: @@ -12,12 +12,18 @@ --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt '{"image":2}' (audio inference with Ultravox) -vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b --max-model-len 4096 +vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b \ + --max-model-len 4096 --trust-remote-code + +run the script with +python openai_chat_completion_client_for_multimodal.py --chat-type audio """ + import base64 import requests from openai import OpenAI +from utils import get_first_model from vllm.utils import FlexibleArgumentParser @@ -31,27 +37,21 @@ base_url=openai_api_base, ) -models = client.models.list() -model = models.data[0].id - def encode_base64_content_from_url(content_url: str) -> str: """Encode a content retrieved from a remote url to base64 format.""" with requests.get(content_url) as response: response.raise_for_status() - result = base64.b64encode(response.content).decode('utf-8') + result = base64.b64encode(response.content).decode("utf-8") return result # Text-only inference -def run_text_only() -> None: +def run_text_only(model: str) -> None: chat_completion = client.chat.completions.create( - messages=[{ - "role": "user", - "content": "What's the capital of France?" 
- }], + messages=[{"role": "user", "content": "What's the capital of France?"}], model=model, max_completion_tokens=64, ) @@ -61,27 +61,22 @@ def run_text_only() -> None: # Single-image input inference -def run_single_image() -> None: - +def run_single_image(model: str) -> None: ## Use image url in the payload image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" chat_completion_from_url = client.chat.completions.create( - messages=[{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "What's in this image?" - }, - { - "type": "image_url", - "image_url": { - "url": image_url + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": {"url": image_url}, }, - }, - ], - }], + ], + } + ], model=model, max_completion_tokens=64, ) @@ -92,22 +87,18 @@ def run_single_image() -> None: ## Use base64 encoded image in the payload image_base64 = encode_base64_content_from_url(image_url) chat_completion_from_base64 = client.chat.completions.create( - messages=[{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "What's in this image?" 
- }, - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{image_base64}" + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}, }, - }, - ], - }], + ], + } + ], model=model, max_completion_tokens=64, ) @@ -117,32 +108,26 @@ def run_single_image() -> None: # Multi-image input inference -def run_multi_image() -> None: +def run_multi_image(model: str) -> None: image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg" image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg" chat_completion_from_url = client.chat.completions.create( - messages=[{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "What are the animals in these images?" - }, - { - "type": "image_url", - "image_url": { - "url": image_url_duck + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What are the animals in these images?"}, + { + "type": "image_url", + "image_url": {"url": image_url_duck}, }, - }, - { - "type": "image_url", - "image_url": { - "url": image_url_lion + { + "type": "image_url", + "image_url": {"url": image_url_lion}, }, - }, - ], - }], + ], + } + ], model=model, max_completion_tokens=64, ) @@ -152,28 +137,24 @@ def run_multi_image() -> None: # Video input inference -def run_video() -> None: +def run_video(model: str) -> None: video_url = "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/ForBiggerFun.mp4" video_base64 = encode_base64_content_from_url(video_url) ## Use video url in the payload chat_completion_from_url = client.chat.completions.create( - messages=[{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "What's in this video?" 
- }, - { - "type": "video_url", - "video_url": { - "url": video_url + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this video?"}, + { + "type": "video_url", + "video_url": {"url": video_url}, }, - }, - ], - }], + ], + } + ], model=model, max_completion_tokens=64, ) @@ -183,22 +164,18 @@ def run_video() -> None: ## Use base64 encoded video in the payload chat_completion_from_base64 = client.chat.completions.create( - messages=[{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "What's in this video?" - }, - { - "type": "video_url", - "video_url": { - "url": f"data:video/mp4;base64,{video_base64}" + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this video?"}, + { + "type": "video_url", + "video_url": {"url": f"data:video/mp4;base64,{video_base64}"}, }, - }, - ], - }], + ], + } + ], model=model, max_completion_tokens=64, ) @@ -208,7 +185,7 @@ def run_video() -> None: # Audio input inference -def run_audio() -> None: +def run_audio(model: str) -> None: from vllm.assets.audio import AudioAsset audio_url = AudioAsset("winning_call").url @@ -216,24 +193,22 @@ def run_audio() -> None: # OpenAI-compatible schema (`input_audio`) chat_completion_from_base64 = client.chat.completions.create( - messages=[{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "What's in this audio?" 
- }, - { - "type": "input_audio", - "input_audio": { - # Any format supported by librosa is supported - "data": audio_base64, - "format": "wav" + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this audio?"}, + { + "type": "input_audio", + "input_audio": { + # Any format supported by librosa is supported + "data": audio_base64, + "format": "wav", + }, }, - }, - ], - }], + ], + } + ], model=model, max_completion_tokens=64, ) @@ -243,23 +218,21 @@ def run_audio() -> None: # HTTP URL chat_completion_from_url = client.chat.completions.create( - messages=[{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "What's in this audio?" - }, - { - "type": "audio_url", - "audio_url": { - # Any format supported by librosa is supported - "url": audio_url + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this audio?"}, + { + "type": "audio_url", + "audio_url": { + # Any format supported by librosa is supported + "url": audio_url + }, }, - }, - ], - }], + ], + } + ], model=model, max_completion_tokens=64, ) @@ -269,23 +242,21 @@ def run_audio() -> None: # base64 URL chat_completion_from_base64 = client.chat.completions.create( - messages=[{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "What's in this audio?" 
- }, - { - "type": "audio_url", - "audio_url": { - # Any format supported by librosa is supported - "url": f"data:audio/ogg;base64,{audio_base64}" + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this audio?"}, + { + "type": "audio_url", + "audio_url": { + # Any format supported by librosa is supported + "url": f"data:audio/ogg;base64,{audio_base64}" + }, }, - }, - ], - }], + ], + } + ], model=model, max_completion_tokens=64, ) @@ -305,20 +276,24 @@ def run_audio() -> None: def parse_args(): parser = FlexibleArgumentParser( - description='Demo on using OpenAI client for online serving with ' - 'multimodal language models served with vLLM.') - parser.add_argument('--chat-type', - '-c', - type=str, - default="single-image", - choices=list(example_function_map.keys()), - help='Conversation type with multimodal data.') + description="Demo on using OpenAI client for online serving with " + "multimodal language models served with vLLM." + ) + parser.add_argument( + "--chat-type", + "-c", + type=str, + default="single-image", + choices=list(example_function_map.keys()), + help="Conversation type with multimodal data.", + ) return parser.parse_args() def main(args) -> None: chat_type = args.chat_type - example_function_map[chat_type]() + model = get_first_model(client) + example_function_map[chat_type](model) if __name__ == "__main__": diff --git a/examples/online_serving/openai_chat_completion_client_with_tools.py b/examples/online_serving/openai_chat_completion_client_with_tools.py index c25203860ff3..a0d7841f644f 100644 --- a/examples/online_serving/openai_chat_completion_client_with_tools.py +++ b/examples/online_serving/openai_chat_completion_client_with_tools.py @@ -7,15 +7,16 @@ templates, or your own - the model default doesn't work for tool calls with vLLM See the vLLM docs on OpenAI server & tool calling for more details. 
-vllm serve --model mistralai/Mistral-7B-Instruct-v0.3 \ +vllm serve mistralai/Mistral-7B-Instruct-v0.3 \ --chat-template examples/tool_chat_template_mistral.jinja \ --enable-auto-tool-choice --tool-call-parser mistral OR -vllm serve --model NousResearch/Hermes-2-Pro-Llama-3-8B \ +vllm serve NousResearch/Hermes-2-Pro-Llama-3-8B \ --chat-template examples/tool_chat_template_hermes.jinja \ --enable-auto-tool-choice --tool-call-parser hermes """ + import json from typing import Any @@ -25,55 +26,55 @@ openai_api_key = "EMPTY" openai_api_base = "http://localhost:8000/v1" -tools = [{ - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": - "string", - "description": - "The city to find the weather for, e.g. 'San Francisco'" - }, - "state": { - "type": - "string", - "description": - "the two-letter abbreviation for the state that the city is" - " in, e.g. 'CA' which would mean 'California'" - }, - "unit": { - "type": "string", - "description": "The unit to fetch the temperature in", - "enum": ["celsius", "fahrenheit"] - } +properties = { + "city": { + "type": "string", + "description": "The city to find the weather for, e.g. 'San Francisco'", + }, + "state": { + "type": "string", + "description": "the two-letter abbreviation for the state that the city is" + " in, e.g. 'CA' which would mean 'California'", + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, +} + +tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": properties, + "required": ["city", "state", "unit"], }, - "required": ["city", "state", "unit"] - } + }, } -}] - -messages = [{ - "role": "user", - "content": "Hi! 
How are you doing today?" -}, { - "role": "assistant", - "content": "I'm doing well! How can I help you?" -}, { - "role": - "user", - "content": - "Can you tell me what the temperate will be in Dallas, in fahrenheit?" -}] - - -def get_current_weather(city: str, state: str, unit: 'str'): - return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is " - "partly cloudly, with highs in the 90's.") +] + +messages = [ + {"role": "user", "content": "Hi! How are you doing today?"}, + {"role": "assistant", "content": "I'm doing well! How can I help you?"}, + { + "role": "user", + "content": ( + "Can you tell me what the temperate will be in Dallas, in fahrenheit?" + ), + }, +] + + +def get_current_weather(city: str, state: str, unit: "str"): + return ( + "The weather in Dallas, Texas is 85 degrees fahrenheit. It is " + "partly cloudly, with highs in the 90's." + ) def handle_tool_calls_stream( @@ -82,10 +83,9 @@ def handle_tool_calls_stream( model: str, tools: list[dict[str, Any]], ) -> list[Any]: - tool_calls_stream = client.chat.completions.create(messages=messages, - model=model, - tools=tools, - stream=True) + tool_calls_stream = client.chat.completions.create( + messages=messages, model=model, tools=tools, stream=True + ) chunks = [] print("chunks: ") for chunk in tool_calls_stream: @@ -106,8 +106,7 @@ def handle_tool_calls_arguments(chunks: list[Any]) -> list[str]: tool_call = chunk.choices[0].delta.tool_calls[0] if tool_call.index != tool_call_idx: if tool_call_idx >= 0: - print(f"streamed tool call arguments: " - f"{arguments[tool_call_idx]}") + print(f"streamed tool call arguments: {arguments[tool_call_idx]}") tool_call_idx = chunk.choices[0].delta.tool_calls[0].index arguments.append("") if tool_call.id: @@ -115,8 +114,7 @@ def handle_tool_calls_arguments(chunks: list[Any]) -> list[str]: if tool_call.function: if tool_call.function.name: - print( - f"streamed tool call name: {tool_call.function.name}") + print(f"streamed tool call name: 
{tool_call.function.name}") if tool_call.function.arguments: arguments[tool_call_idx] += tool_call.function.arguments @@ -136,9 +134,9 @@ def main(): models = client.models.list() model = models.data[0].id - chat_completion = client.chat.completions.create(messages=messages, - model=model, - tools=tools) + chat_completion = client.chat.completions.create( + messages=messages, model=model, tools=tools + ) print("-" * 70) print("Chat completion results:") @@ -158,10 +156,12 @@ def main(): print("-" * 70) # Add tool call results to the conversation - messages.append({ - "role": "assistant", - "tool_calls": chat_completion.choices[0].message.tool_calls - }) + messages.append( + { + "role": "assistant", + "tool_calls": chat_completion.choices[0].message.tool_calls, + } + ) # Now, simulate a tool call available_tools = {"get_current_weather": get_current_weather} @@ -172,17 +172,18 @@ def main(): args = json.loads(call.function.arguments) result = tool_to_call(**args) print("tool_to_call result: ", result) - messages.append({ - "role": "tool", - "content": result, - "tool_call_id": call.id, - "name": call.function.name - }) - - chat_completion_2 = client.chat.completions.create(messages=messages, - model=model, - tools=tools, - stream=False) + messages.append( + { + "role": "tool", + "content": result, + "tool_call_id": call.id, + "name": call.function.name, + } + ) + + chat_completion_2 = client.chat.completions.create( + messages=messages, model=model, tools=tools, stream=False + ) print("Chat completion2 results:") print(chat_completion_2) print("-" * 70) diff --git a/examples/online_serving/openai_chat_completion_client_with_tools_required.py b/examples/online_serving/openai_chat_completion_client_with_tools_required.py index 97d900bb75f1..45c4232fe1de 100644 --- a/examples/online_serving/openai_chat_completion_client_with_tools_required.py +++ b/examples/online_serving/openai_chat_completion_client_with_tools_required.py @@ -28,18 +28,16 @@ "type": "object", 
"properties": { "city": { - "type": - "string", - "description": - "The city to find the weather for" + "type": "string", + "description": "The city to find the weather for" ", e.g. 'San Francisco'", }, "state": { - "type": - "string", - "description": - "the two-letter abbreviation for the state that the " - "city is in, e.g. 'CA' which would mean 'California'", + "type": "string", + "description": ( + "the two-letter abbreviation for the state that the " + "city is in, e.g. 'CA' which would mean 'California'" + ), }, "unit": { "type": "string", @@ -60,22 +58,20 @@ "type": "object", "properties": { "city": { - "type": - "string", - "description": - "The city to get the forecast for, e.g. 'New York'", + "type": "string", + "description": ( + "The city to get the forecast for, e.g. 'New York'" + ), }, "state": { - "type": - "string", - "description": - "The two-letter abbreviation for the state, e.g. 'NY'", + "type": "string", + "description": ( + "The two-letter abbreviation for the state, e.g. 'NY'" + ), }, "days": { - "type": - "integer", - "description": - "Number of days to get the forecast for (1-7)", + "type": "integer", + "description": "Number of days to get the forecast for (1-7)", }, "unit": { "type": "string", @@ -90,19 +86,11 @@ ] messages = [ + {"role": "user", "content": "Hi! How are you doing today?"}, + {"role": "assistant", "content": "I'm doing well! How can I help you?"}, { "role": "user", - "content": "Hi! How are you doing today?" - }, - { - "role": "assistant", - "content": "I'm doing well! How can I help you?" 
- }, - { - "role": - "user", - "content": - "Can you tell me what the current weather is in Dallas \ + "content": "Can you tell me what the current weather is in Dallas \ and the forecast for the next 5 days, in fahrenheit?", }, ] @@ -123,17 +111,16 @@ def main(): model=model, tools=tools, tool_choice="required", - stream=True # Enable streaming response + stream=True, # Enable streaming response ) for chunk in chat_completion: if chunk.choices and chunk.choices[0].delta.tool_calls: print(chunk.choices[0].delta.tool_calls) - chat_completion = client.chat.completions.create(messages=messages, - model=model, - tools=tools, - tool_choice="required") + chat_completion = client.chat.completions.create( + messages=messages, model=model, tools=tools, tool_choice="required" + ) print(chat_completion.choices[0].message.tool_calls) diff --git a/examples/online_serving/openai_chat_completion_structured_outputs.py b/examples/online_serving/openai_chat_completion_structured_outputs.py index 660369e55d40..a4134ea43c4b 100644 --- a/examples/online_serving/openai_chat_completion_structured_outputs.py +++ b/examples/online_serving/openai_chat_completion_structured_outputs.py @@ -12,15 +12,17 @@ from openai import BadRequestError, OpenAI from pydantic import BaseModel +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" + # Guided decoding by Choice (list of possible options) def guided_choice_completion(client: OpenAI, model: str): completion = client.chat.completions.create( model=model, - messages=[{ - "role": "user", - "content": "Classify this sentiment: vLLM is wonderful!" 
- }], + messages=[ + {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"} + ], extra_body={"guided_choice": ["positive", "negative"]}, ) return completion.choices[0].message.content @@ -28,20 +30,21 @@ def guided_choice_completion(client: OpenAI, model: str): # Guided decoding by Regex def guided_regex_completion(client: OpenAI, model: str): - prompt = ("Generate an email address for Alan Turing, who works in Enigma." - "End in .com and new line. Example result:" - "alan.turing@enigma.com\n") + prompt = ( + "Generate an email address for Alan Turing, who works in Enigma." + "End in .com and new line. Example result:" + "alan.turing@enigma.com\n" + ) completion = client.chat.completions.create( model=model, - messages=[{ - "role": "user", - "content": prompt, - }], - extra_body={ - "guided_regex": r"\w+@\w+\.com\n", - "stop": ["\n"] - }, + messages=[ + { + "role": "user", + "content": prompt, + } + ], + extra_body={"guided_regex": r"\w+@\w+\.com\n", "stop": ["\n"]}, ) return completion.choices[0].message.content @@ -63,14 +66,18 @@ class CarDescription(BaseModel): def guided_json_completion(client: OpenAI, model: str): json_schema = CarDescription.model_json_schema() - prompt = ("Generate a JSON with the brand, model and car_type of" - "the most iconic car from the 90's") + prompt = ( + "Generate a JSON with the brand, model and car_type of" + "the most iconic car from the 90's" + ) completion = client.chat.completions.create( model=model, - messages=[{ - "role": "user", - "content": prompt, - }], + messages=[ + { + "role": "user", + "content": prompt, + } + ], extra_body={"guided_json": json_schema}, ) return completion.choices[0].message.content @@ -92,14 +99,18 @@ def guided_grammar_completion(client: OpenAI, model: str): number ::= "1 " | "2 " """ - prompt = ("Generate an SQL query to show the 'username' and 'email'" - "from the 'users' table.") + prompt = ( + "Generate an SQL query to show the 'username' and 'email'" + "from the 'users' 
table." + ) completion = client.chat.completions.create( model=model, - messages=[{ - "role": "user", - "content": prompt, - }], + messages=[ + { + "role": "user", + "content": prompt, + } + ], extra_body={"guided_grammar": simplified_sql_grammar}, ) return completion.choices[0].message.content @@ -107,19 +118,23 @@ def guided_grammar_completion(client: OpenAI, model: str): # Extra backend options def extra_backend_options_completion(client: OpenAI, model: str): - prompt = ("Generate an email address for Alan Turing, who works in Enigma." - "End in .com and new line. Example result:" - "alan.turing@enigma.com\n") + prompt = ( + "Generate an email address for Alan Turing, who works in Enigma." + "End in .com and new line. Example result:" + "alan.turing@enigma.com\n" + ) try: # The guided_decoding_disable_fallback option forces vLLM to use # xgrammar, so when it fails you get a 400 with the reason why completion = client.chat.completions.create( model=model, - messages=[{ - "role": "user", - "content": prompt, - }], + messages=[ + { + "role": "user", + "content": prompt, + } + ], extra_body={ "guided_regex": r"\w+@\w+\.com\n", "stop": ["\n"], @@ -134,8 +149,8 @@ def extra_backend_options_completion(client: OpenAI, model: str): def main(): client: OpenAI = OpenAI( - base_url="http://localhost:8000/v1", - api_key="-", + base_url=openai_api_base, + api_key=openai_api_key, ) model = client.models.list().data[0].id diff --git a/examples/online_serving/openai_chat_completion_structured_outputs_structural_tag.py b/examples/online_serving/openai_chat_completion_structured_outputs_structural_tag.py index 42aa12c451c0..c73208abe600 100644 --- a/examples/online_serving/openai_chat_completion_structured_outputs_structural_tag.py +++ b/examples/online_serving/openai_chat_completion_structured_outputs_structural_tag.py @@ -7,18 +7,20 @@ # to enforce the format of a tool call response, but it could be used for # any structured output within a subset of the response. 
+openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" + def main(): client = OpenAI( - base_url="http://localhost:8000/v1", - api_key="-", + base_url=openai_api_base, + api_key=openai_api_key, ) - messages = [{ - "role": - "user", - "content": - """ + messages = [ + { + "role": "user", + "content": """ You have access to the following function to retrieve the weather in a city: { @@ -55,29 +57,28 @@ def main(): Given the previous instructions, what is the weather in New York City, Boston, and San Francisco? -""" - }] +""", + } + ] response = client.chat.completions.create( model=client.models.list().data[0].id, messages=messages, response_format={ - "type": - "structural_tag", - "structures": [{ - "begin": "<function=get_weather>", - "schema": { - "type": "object", - "properties": { - "city": { - "type": "string" - } - } - }, - "end": "</function>" - }], - "triggers": ["<function="] - }) + "type": "structural_tag", + "structures": [ + { + "begin": "<function=get_weather>", + "schema": { + "type": "object", + "properties": {"city": {"type": "string"}}, + }, + "end": "</function>", + } + ], + "triggers": ["<function="], + }, + ) print(response) diff --git a/examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py b/examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py index a04f0cdf12f7..1ca61a8d5895 100644 --- a/examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py +++ b/examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py @@ -27,21 +27,22 @@ def print_completion_details(completion): - print("reasoning_content: ", - completion.choices[0].message.reasoning_content) + print("reasoning_content: ", completion.choices[0].message.reasoning_content) print("content: ", completion.choices[0].message.content) # Guided decoding by Regex def guided_regex_completion(client: OpenAI, model: str): - prompt = ("What is the capital of France?") + 
prompt = "What is the capital of France?" completion = client.chat.completions.create( model=model, - messages=[{ - "role": "user", - "content": prompt, - }], + messages=[ + { + "role": "user", + "content": prompt, + } + ], extra_body={ "guided_regex": "(Paris|London)", }, @@ -57,13 +58,15 @@ class People(BaseModel): def guided_json_completion(client: OpenAI, model: str): json_schema = People.model_json_schema() - prompt = ("Generate a JSON with the name and age of one random person.") + prompt = "Generate a JSON with the name and age of one random person." completion = client.chat.completions.create( model=model, - messages=[{ - "role": "user", - "content": prompt, - }], + messages=[ + { + "role": "user", + "content": prompt, + } + ], extra_body={"guided_json": json_schema}, ) print_completion_details(completion) @@ -86,14 +89,18 @@ class CarDescription(BaseModel): def guided_car_json_completion(client: OpenAI, model: str): json_schema = CarDescription.model_json_schema() - prompt = ("Generate a JSON with the brand, model and car_type of" - "the most iconic car from the 90's") + prompt = ( + "Generate a JSON with the brand, model and car_type of" + "the most iconic car from the 90's" + ) completion = client.chat.completions.create( model=model, - messages=[{ - "role": "user", - "content": prompt, - }], + messages=[ + { + "role": "user", + "content": prompt, + } + ], extra_body={"guided_json": json_schema}, ) print_completion_details(completion) @@ -116,14 +123,18 @@ def guided_grammar_completion(client: OpenAI, model: str): """ # This may be very slow https://github.com/vllm-project/vllm/issues/12122 - prompt = ("Generate an SQL query to show the 'username' and 'email'" - "from the 'users' table.") + prompt = ( + "Generate an SQL query to show the 'username' and 'email'" + "from the 'users' table." 
+ ) completion = client.chat.completions.create( model=model, - messages=[{ - "role": "user", - "content": prompt, - }], + messages=[ + { + "role": "user", + "content": prompt, + } + ], extra_body={"guided_grammar": simplified_sql_grammar}, ) print_completion_details(completion) diff --git a/examples/online_serving/openai_chat_completion_tool_calls_with_reasoning.py b/examples/online_serving/openai_chat_completion_tool_calls_with_reasoning.py index 9417abd3989a..a5febad45863 100644 --- a/examples/online_serving/openai_chat_completion_tool_calls_with_reasoning.py +++ b/examples/online_serving/openai_chat_completion_tool_calls_with_reasoning.py @@ -20,9 +20,11 @@ # Now, simulate a tool call -def get_current_weather(city: str, state: str, unit: 'str'): - return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is " - "partly cloudly, with highs in the 90's.") +def get_current_weather(city: str, state: str, unit: "str"): + return ( + "The weather in Dallas, Texas is 85 degrees fahrenheit. It is " + "partly cloudly, with highs in the 90's." + ) available_tools = {"get_current_weather": get_current_weather} @@ -31,49 +33,47 @@ def get_current_weather(city: str, state: str, unit: 'str'): openai_api_key = "EMPTY" openai_api_base = "http://localhost:8000/v1" -tools = [{ - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": - "string", - "description": - "The city to find the weather for, e.g. 'San Francisco'" - }, - "state": { - "type": - "string", - "description": - "the two-letter abbreviation for the state that the city is" - " in, e.g. 
'CA' which would mean 'California'" - }, - "unit": { - "type": "string", - "description": "The unit to fetch the temperature in", - "enum": ["celsius", "fahrenheit"] - } +properties = { + "city": { + "type": "string", + "description": "The city to find the weather for, e.g. 'San Francisco'", + }, + "state": { + "type": "string", + "description": "the two-letter abbreviation for the state that the city is" + " in, e.g. 'CA' which would mean 'California'", + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, +} + +tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": properties, + "required": ["city", "state", "unit"], }, - "required": ["city", "state", "unit"] - } + }, } -}] -messages = [{ - "role": "user", - "content": "Hi! How are you doing today?" -}, { - "role": "assistant", - "content": "I'm doing well! How can I help you?" -}, { - "role": - "user", - "content": - "Can you tell me what the temperate will be in Dallas, in fahrenheit?" -}] +] +messages = [ + {"role": "user", "content": "Hi! How are you doing today?"}, + {"role": "assistant", "content": "I'm doing well! How can I help you?"}, + { + "role": "user", + "content": ( + "Can you tell me what the temperate will be in Dallas, in fahrenheit?" 
+ ), + }, +] def extract_reasoning_and_calls(chunks: list): @@ -110,73 +110,55 @@ def main(): models = client.models.list() model = models.data[0].id + print("---------Full Generate With Automatic Function Calling-------------") + tool_calls = client.chat.completions.create( + messages=messages, model=model, tools=tools + ) + print(f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}") + print(f"function name: {tool_calls.choices[0].message.tool_calls[0].function.name}") print( - "---------Full Generate With Automatic Function Calling-------------") - tool_calls = client.chat.completions.create(messages=messages, - model=model, - tools=tools) - print( - f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}" + f"function arguments: " + f"{tool_calls.choices[0].message.tool_calls[0].function.arguments}" ) - print(f"function name: " - f"{tool_calls.choices[0].message.tool_calls[0].function.name}") - print(f"function arguments: " - f"{tool_calls.choices[0].message.tool_calls[0].function.arguments}") - print( - "----------Stream Generate With Automatic Function Calling-----------") - tool_calls_stream = client.chat.completions.create(messages=messages, - model=model, - tools=tools, - stream=True) + print("----------Stream Generate With Automatic Function Calling-----------") + tool_calls_stream = client.chat.completions.create( + messages=messages, model=model, tools=tools, stream=True + ) chunks = list(tool_calls_stream) - reasoning_content, arguments, function_names = extract_reasoning_and_calls( - chunks) + reasoning_content, arguments, function_names = extract_reasoning_and_calls(chunks) print(f"reasoning_content: {reasoning_content}") print(f"function name: {function_names[0]}") print(f"function arguments: {arguments[0]}") - print( - "----------Full Generate With Named Function Calling-----------------") - tool_calls = client.chat.completions.create(messages=messages, - model=model, - tools=tools, - tool_choice={ - "type": 
"function", - "function": { - "name": - "get_current_weather" - } - }) + print("----------Full Generate With Named Function Calling-----------------") + tool_calls = client.chat.completions.create( + messages=messages, + model=model, + tools=tools, + tool_choice={"type": "function", "function": {"name": "get_current_weather"}}, + ) tool_call = tool_calls.choices[0].message.tool_calls[0].function - print( - f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}" - ) + print(f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}") print(f"function name: {tool_call.name}") print(f"function arguments: {tool_call.arguments}") - print( - "----------Stream Generate With Named Function Calling--------------") + print("----------Stream Generate With Named Function Calling--------------") tool_calls_stream = client.chat.completions.create( messages=messages, model=model, tools=tools, - tool_choice={ - "type": "function", - "function": { - "name": "get_current_weather" - } - }, - stream=True) + tool_choice={"type": "function", "function": {"name": "get_current_weather"}}, + stream=True, + ) chunks = list(tool_calls_stream) - reasoning_content, arguments, function_names = extract_reasoning_and_calls( - chunks) + reasoning_content, arguments, function_names = extract_reasoning_and_calls(chunks) print(f"reasoning_content: {reasoning_content}") print(f"function name: {function_names[0]}") print(f"function arguments: {arguments[0]}") diff --git a/examples/online_serving/openai_chat_completion_with_reasoning.py b/examples/online_serving/openai_chat_completion_with_reasoning.py index 4bf7731cb41e..f6b8082115f1 100644 --- a/examples/online_serving/openai_chat_completion_with_reasoning.py +++ b/examples/online_serving/openai_chat_completion_with_reasoning.py @@ -45,12 +45,12 @@ def main(): # Round 2 messages.append({"role": "assistant", "content": content}) - messages.append({ - "role": - "user", - "content": - "How many Rs are there in the word 
'strawberry'?", - }) + messages.append( + { + "role": "user", + "content": "How many Rs are there in the word 'strawberry'?", + } + ) response = client.chat.completions.create(model=model, messages=messages) reasoning_content = response.choices[0].message.reasoning_content diff --git a/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py b/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py index 9cc0a5f2476b..f984fbabf24f 100644 --- a/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py +++ b/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py @@ -43,9 +43,7 @@ def main(): # ruff: noqa: E501 # For granite: add: `extra_body={"chat_template_kwargs": {"thinking": True}}` - stream = client.chat.completions.create(model=model, - messages=messages, - stream=True) + stream = client.chat.completions.create(model=model, messages=messages, stream=True) print("client: Start streaming chat completions...") printed_reasoning_content = False diff --git a/examples/online_serving/openai_chat_embedding_client_for_multimodal.py b/examples/online_serving/openai_chat_embedding_client_for_multimodal.py index c850b5aa2f80..ee519e555ff7 100644 --- a/examples/online_serving/openai_chat_embedding_client_for_multimodal.py +++ b/examples/online_serving/openai_chat_embedding_client_for_multimodal.py @@ -14,26 +14,17 @@ def vlm2vec(): response = requests.post( "http://localhost:8000/v1/embeddings", json={ - "model": - "TIGER-Lab/VLM2Vec-Full", - "messages": [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": "Represent the given image." 
- }, - ], - }], - "encoding_format": - "float", + "model": "TIGER-Lab/VLM2Vec-Full", + "messages": [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "Represent the given image."}, + ], + } + ], + "encoding_format": "float", }, ) response.raise_for_status() @@ -45,19 +36,20 @@ def vlm2vec(): def dse_qwen2_vl(inp: dict): # Embedding an Image if inp["type"] == "image": - messages = [{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": inp["image_url"], - } - }, { - "type": "text", - "text": "What is shown in this image?" - }] - }] + messages = [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": inp["image_url"], + }, + }, + {"type": "text", "text": "What is shown in this image?"}, + ], + } + ] # Embedding a Text Query else: # MrLight/dse-qwen2-2b-mrl-v1 requires a placeholder image @@ -66,23 +58,21 @@ def dse_qwen2_vl(inp: dict): image_placeholder = Image.new("RGB", (56, 56)) image_placeholder.save(buffer, "png") buffer.seek(0) - image_placeholder = base64.b64encode(buffer.read()).decode('utf-8') - messages = [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{image_placeholder}", - } - }, - { - "type": "text", - "text": f"Query: {inp['content']}" - }, - ] - }] + image_placeholder = base64.b64encode(buffer.read()).decode("utf-8") + messages = [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_placeholder}", + }, + }, + {"type": "text", "text": f"Query: {inp['content']}"}, + ], + } + ] response = requests.post( "http://localhost:8000/v1/embeddings", @@ -101,12 +91,15 @@ def dse_qwen2_vl(inp: dict): def parse_args(): parser = argparse.ArgumentParser( "Script to call a specified VLM through the API. 
Make sure to serve " - "the model with --task embed before running this.") - parser.add_argument("--model", - type=str, - choices=["vlm2vec", "dse_qwen2_vl"], - required=True, - help="Which model to call.") + "the model with --task embed before running this." + ) + parser.add_argument( + "--model", + type=str, + choices=["vlm2vec", "dse_qwen2_vl"], + required=True, + help="Which model to call.", + ) return parser.parse_args() @@ -114,16 +107,20 @@ def main(args): if args.model == "vlm2vec": vlm2vec() elif args.model == "dse_qwen2_vl": - dse_qwen2_vl({ - "type": "image", - "image_url": image_url, - }) - dse_qwen2_vl({ - "type": "text", - "content": "What is the weather like today?", - }) + dse_qwen2_vl( + { + "type": "image", + "image_url": image_url, + } + ) + dse_qwen2_vl( + { + "type": "text", + "content": "What is the weather like today?", + } + ) -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() main(args) diff --git a/examples/online_serving/openai_classification_client.py b/examples/online_serving/openai_classification_client.py new file mode 100644 index 000000000000..649cfa5d6686 --- /dev/null +++ b/examples/online_serving/openai_classification_client.py @@ -0,0 +1,47 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import pprint + +import requests + + +def post_http_request(payload: dict, api_url: str) -> requests.Response: + headers = {"User-Agent": "Test Client"} + response = requests.post(api_url, headers=headers, json=payload) + return response + + +def parse_args(): + parse = argparse.ArgumentParser() + parse.add_argument("--host", type=str, default="localhost") + parse.add_argument("--port", type=int, default=8000) + parse.add_argument("--model", type=str, default="jason9693/Qwen2.5-1.5B-apeach") + return parse.parse_args() + + +def main(args): + host = args.host + port = args.port + model_name = args.model + + api_url = f"http://{host}:{port}/classify" + prompts = [ + "Hello, my name is", + "The president of 
the United States is", + "The capital of France is", + "The future of AI is", + ] + + payload = { + "model": model_name, + "input": prompts, + } + + classify_response = post_http_request(payload=payload, api_url=api_url) + pprint.pprint(classify_response.json()) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/online_serving/openai_completion_client.py b/examples/online_serving/openai_completion_client.py index 6ab7619bff19..b1d21b5e4b9f 100644 --- a/examples/online_serving/openai_completion_client.py +++ b/examples/online_serving/openai_completion_client.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +import argparse + from openai import OpenAI # Modify OpenAI's API key and API base to use vLLM's API server. @@ -7,7 +9,15 @@ openai_api_base = "http://localhost:8000/v1" -def main(): +def parse_args(): + parser = argparse.ArgumentParser(description="Client for vLLM API server") + parser.add_argument( + "--stream", action="store_true", help="Enable streaming response" + ) + return parser.parse_args() + + +def main(args): client = OpenAI( # defaults to os.environ.get("OPENAI_API_KEY") api_key=openai_api_key, @@ -18,18 +28,18 @@ def main(): model = models.data[0].id # Completion API - stream = False completion = client.completions.create( model=model, prompt="A robot may not injure a human being", echo=False, n=2, - stream=stream, - logprobs=3) + stream=args.stream, + logprobs=3, + ) print("-" * 50) print("Completion results:") - if stream: + if args.stream: for c in completion: print(c) else: @@ -38,4 +48,5 @@ def main(): if __name__ == "__main__": - main() + args = parse_args() + main(args) diff --git a/examples/online_serving/openai_cross_encoder_score.py b/examples/online_serving/openai_cross_encoder_score.py index 20a64ddb2141..7891e14cb71e 100644 --- a/examples/online_serving/openai_cross_encoder_score.py +++ b/examples/online_serving/openai_cross_encoder_score.py @@ -4,6 +4,7 @@ Run `vllm serve <model> --task 
score` to start up the server in vLLM. """ + import argparse import pprint @@ -38,9 +39,7 @@ def main(args): pprint.pprint(score_response.json()) text_1 = "What is the capital of France?" - text_2 = [ - "The capital of Brazil is Brasilia.", "The capital of France is Paris." - ] + text_2 = ["The capital of Brazil is Brasilia.", "The capital of France is Paris."] prompt = {"model": model_name, "text_1": text_1, "text_2": text_2} score_response = post_http_request(prompt=prompt, api_url=api_url) print("\nPrompt when text_1 is string and text_2 is a list:") @@ -48,12 +47,8 @@ def main(args): print("\nScore Response:") pprint.pprint(score_response.json()) - text_1 = [ - "What is the capital of Brazil?", "What is the capital of France?" - ] - text_2 = [ - "The capital of Brazil is Brasilia.", "The capital of France is Paris." - ] + text_1 = ["What is the capital of Brazil?", "What is the capital of France?"] + text_2 = ["The capital of Brazil is Brasilia.", "The capital of France is Paris."] prompt = {"model": model_name, "text_1": text_1, "text_2": text_2} score_response = post_http_request(prompt=prompt, api_url=api_url) print("\nPrompt when text_1 and text_2 are both lists:") diff --git a/examples/online_serving/openai_embedding_client.py b/examples/online_serving/openai_embedding_client.py index bc217f7ca7a0..a055654e9133 100644 --- a/examples/online_serving/openai_embedding_client.py +++ b/examples/online_serving/openai_embedding_client.py @@ -21,7 +21,7 @@ def main(): # ruff: noqa: E501 input=[ "Hello my name is", - "The best thing about vLLM is that it supports many different models" + "The best thing about vLLM is that it supports many different models", ], model=model, ) diff --git a/examples/online_serving/openai_pooling_client.py b/examples/online_serving/openai_pooling_client.py index abcfe27c2769..2620a1232024 100644 --- a/examples/online_serving/openai_pooling_client.py +++ b/examples/online_serving/openai_pooling_client.py @@ -5,6 +5,7 @@ Run `vllm serve 
<model> --task <embed|classify|reward|score>` to start up the server in vLLM. """ + import argparse import pprint @@ -21,9 +22,7 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, default=8000) - parser.add_argument("--model", - type=str, - default="jason9693/Qwen2.5-1.5B-apeach") + parser.add_argument("--model", type=str, default="jason9693/Qwen2.5-1.5B-apeach") return parser.parse_args() @@ -42,15 +41,13 @@ def main(args): # Input like Chat API prompt = { - "model": - model_name, - "messages": [{ - "role": "user", - "content": [{ - "type": "text", - "text": "vLLM is great!" - }], - }] + "model": model_name, + "messages": [ + { + "role": "user", + "content": [{"type": "text", "text": "vLLM is great!"}], + } + ], } pooling_response = post_http_request(prompt=prompt, api_url=api_url) print("Pooling Response:") diff --git a/examples/online_serving/openai_transcription_client.py b/examples/online_serving/openai_transcription_client.py index 66e622672ef2..eb501ae72aa9 100644 --- a/examples/online_serving/openai_transcription_client.py +++ b/examples/online_serving/openai_transcription_client.py @@ -7,8 +7,8 @@ from vllm.assets.audio import AudioAsset -mary_had_lamb = AudioAsset('mary_had_lamb').get_local_path() -winning_call = AudioAsset('winning_call').get_local_path() +mary_had_lamb = AudioAsset("mary_had_lamb").get_local_path() +winning_call = AudioAsset("winning_call").get_local_path() # Modify OpenAI's API key and API base to use vLLM's API server. 
openai_api_key = "EMPTY" @@ -31,7 +31,8 @@ def sync_openai(): extra_body=dict( seed=4419, repetition_penalty=1.3, - )) + ), + ) print("transcription result:", transcription.text) @@ -42,33 +43,30 @@ def sync_openai(): async def stream_openai_response(): data = { "language": "en", - 'stream': True, + "stream": True, "model": "openai/whisper-large-v3", } url = openai_api_base + "/audio/transcriptions" headers = {"Authorization": f"Bearer {openai_api_key}"} - print("transcription result:", end=' ') + print("transcription result:", end=" ") async with httpx.AsyncClient() as client: with open(str(winning_call), "rb") as f: - async with client.stream('POST', - url, - files={'file': f}, - data=data, - headers=headers) as response: + async with client.stream( + "POST", url, files={"file": f}, data=data, headers=headers + ) as response: async for line in response.aiter_lines(): # Each line is a JSON object prefixed with 'data: ' if line: - if line.startswith('data: '): - line = line[len('data: '):] + if line.startswith("data: "): + line = line[len("data: ") :] # Last chunk, stream ends - if line.strip() == '[DONE]': + if line.strip() == "[DONE]": break # Parse the JSON response chunk = json.loads(line) # Extract and print the content - content = chunk['choices'][0].get('delta', - {}).get('content') - print(content, end='') + content = chunk["choices"][0].get("delta", {}).get("content") + print(content, end="") # Run the asynchronous function diff --git a/examples/online_serving/opentelemetry/Otel.md b/examples/online_serving/opentelemetry/README.md similarity index 100% rename from examples/online_serving/opentelemetry/Otel.md rename to examples/online_serving/opentelemetry/README.md diff --git a/examples/online_serving/opentelemetry/dummy_client.py b/examples/online_serving/opentelemetry/dummy_client.py index a8b353090d79..33d365f0caa5 100644 --- a/examples/online_serving/opentelemetry/dummy_client.py +++ b/examples/online_serving/opentelemetry/dummy_client.py @@ -1,14 
+1,11 @@ # SPDX-License-Identifier: Apache-2.0 import requests -from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( - OTLPSpanExporter) +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter from opentelemetry.sdk.trace import TracerProvider -from opentelemetry.sdk.trace.export import (BatchSpanProcessor, - ConsoleSpanExporter) +from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter from opentelemetry.trace import SpanKind, set_tracer_provider -from opentelemetry.trace.propagation.tracecontext import ( - TraceContextTextMapPropagator) +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator trace_provider = TracerProvider() set_tracer_provider(trace_provider) diff --git a/examples/online_serving/prompt_embed_inference_with_openai_client.py b/examples/online_serving/prompt_embed_inference_with_openai_client.py new file mode 100644 index 000000000000..85ea2340736e --- /dev/null +++ b/examples/online_serving/prompt_embed_inference_with_openai_client.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +vLLM OpenAI-Compatible Client with Prompt Embeddings + +This script demonstrates how to: +1. Generate prompt embeddings using Hugging Face Transformers +2. Encode them in base64 format +3. Send them to a vLLM server via the OpenAI-compatible Completions API + +Run the vLLM server first: +vllm serve meta-llama/Llama-3.2-1B-Instruct \ + --task generate \ + --max-model-len 4096 \ + --enable-prompt-embeds + +Run the client: +python examples/online_serving/prompt_embed_inference_with_openai_client.py + +Model: meta-llama/Llama-3.2-1B-Instruct +Note: This model is gated on Hugging Face Hub. 
+ You must request access to use it: + https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct + +Dependencies: +- transformers +- torch +- openai +""" + +import base64 +import io + +import torch +import transformers +from openai import OpenAI + + +def main(): + client = OpenAI( + api_key="EMPTY", + base_url="http://localhost:8000/v1", + ) + + model_name = "meta-llama/Llama-3.2-1B-Instruct" + + # Transformers + tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) + transformers_model = transformers.AutoModelForCausalLM.from_pretrained(model_name) + + # Refer to the HuggingFace repo for the correct format to use + chat = [{"role": "user", "content": "Please tell me about the capital of France."}] + token_ids = tokenizer.apply_chat_template( + chat, add_generation_prompt=True, return_tensors="pt" + ) + + embedding_layer = transformers_model.get_input_embeddings() + prompt_embeds = embedding_layer(token_ids).squeeze(0) + + # Prompt embeddings + buffer = io.BytesIO() + torch.save(prompt_embeds, buffer) + buffer.seek(0) + binary_data = buffer.read() + encoded_embeds = base64.b64encode(binary_data).decode("utf-8") + + completion = client.completions.create( + model=model_name, + # NOTE: The OpenAI client does not allow `None` as an input to + # `prompt`. Use an empty string if you have no text prompts. + prompt="", + max_tokens=5, + temperature=0.0, + # NOTE: The OpenAI client allows passing in extra JSON body via the + # `extra_body` argument. 
+ extra_body={"prompt_embeds": encoded_embeds}, + ) + + print("-" * 30) + print(completion.choices[0].text) + print("-" * 30) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/ray_serve_deepseek.py b/examples/online_serving/ray_serve_deepseek.py index f9ef3e2da1a1..a76020130c3a 100644 --- a/examples/online_serving/ray_serve_deepseek.py +++ b/examples/online_serving/ray_serve_deepseek.py @@ -1,7 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 """ Example to deploy DeepSeek R1 or V3 with Ray Serve LLM. -See Ray Serve LLM documentation at: +See more details at: +https://docs.ray.io/en/latest/serve/tutorials/serve-deepseek.html +And see Ray Serve LLM documentation at: https://docs.ray.io/en/latest/serve/llm/serving-llms.html Run `python3 ray_serve_deepseek.py` to deploy the model. @@ -26,9 +28,7 @@ }, # Change to the accelerator type of the node accelerator_type="H100", - runtime_env={"env_vars": { - "VLLM_USE_V1": "1" - }}, + runtime_env={"env_vars": {"VLLM_USE_V1": "1"}}, # Customize engine arguments as needed (e.g. 
vLLM engine kwargs) engine_kwargs={ "tensor_parallel_size": 8, diff --git a/examples/online_serving/retrieval_augmented_generation_with_langchain.py b/examples/online_serving/retrieval_augmented_generation_with_langchain.py index 73063065cb36..37af3b3887f5 100644 --- a/examples/online_serving/retrieval_augmented_generation_with_langchain.py +++ b/examples/online_serving/retrieval_augmented_generation_with_langchain.py @@ -55,7 +55,7 @@ def load_and_split_documents(config: dict[str, Any]): Load and split documents from web URL """ try: - loader = WebBaseLoader(web_paths=(config["url"], )) + loader = WebBaseLoader(web_paths=(config["url"],)) docs = loader.load() text_splitter = RecursiveCharacterTextSplitter( @@ -121,64 +121,71 @@ def create_qa_chain(retriever: Any, llm: ChatOpenAI, prompt: PromptTemplate): """ Set up question answering chain """ - return ({ - "context": retriever | format_docs, - "question": RunnablePassthrough(), - } - | prompt - | llm - | StrOutputParser()) + return ( + { + "context": retriever | format_docs, + "question": RunnablePassthrough(), + } + | prompt + | llm + | StrOutputParser() + ) def get_parser() -> argparse.ArgumentParser: """ Parse command line arguments """ - parser = argparse.ArgumentParser(description='RAG with vLLM and langchain') + parser = argparse.ArgumentParser(description="RAG with vLLM and langchain") # Add command line arguments - parser.add_argument('--vllm-api-key', - default="EMPTY", - help='API key for vLLM compatible services') - parser.add_argument('--vllm-embedding-endpoint', - default="http://localhost:8000/v1", - help='Base URL for embedding service') - parser.add_argument('--vllm-chat-endpoint', - default="http://localhost:8001/v1", - help='Base URL for chat service') - parser.add_argument('--uri', - default="./milvus.db", - help='URI for Milvus database') parser.add_argument( - '--url', - default=("https://docs.vllm.ai/en/latest/getting_started/" - "quickstart.html"), - help='URL of the document to process') - 
parser.add_argument('--embedding-model', - default="ssmits/Qwen2-7B-Instruct-embed-base", - help='Model name for embeddings') - parser.add_argument('--chat-model', - default="qwen/Qwen1.5-0.5B-Chat", - help='Model name for chat') - parser.add_argument('-i', - '--interactive', - action='store_true', - help='Enable interactive Q&A mode') - parser.add_argument('-k', - '--top-k', - type=int, - default=3, - help='Number of top results to retrieve') - parser.add_argument('-c', - '--chunk-size', - type=int, - default=1000, - help='Chunk size for document splitting') - parser.add_argument('-o', - '--chunk-overlap', - type=int, - default=200, - help='Chunk overlap for document splitting') + "--vllm-api-key", default="EMPTY", help="API key for vLLM compatible services" + ) + parser.add_argument( + "--vllm-embedding-endpoint", + default="http://localhost:8000/v1", + help="Base URL for embedding service", + ) + parser.add_argument( + "--vllm-chat-endpoint", + default="http://localhost:8001/v1", + help="Base URL for chat service", + ) + parser.add_argument("--uri", default="./milvus.db", help="URI for Milvus database") + parser.add_argument( + "--url", + default=("https://docs.vllm.ai/en/latest/getting_started/quickstart.html"), + help="URL of the document to process", + ) + parser.add_argument( + "--embedding-model", + default="ssmits/Qwen2-7B-Instruct-embed-base", + help="Model name for embeddings", + ) + parser.add_argument( + "--chat-model", default="qwen/Qwen1.5-0.5B-Chat", help="Model name for chat" + ) + parser.add_argument( + "-i", "--interactive", action="store_true", help="Enable interactive Q&A mode" + ) + parser.add_argument( + "-k", "--top-k", type=int, default=3, help="Number of top results to retrieve" + ) + parser.add_argument( + "-c", + "--chunk-size", + type=int, + default=1000, + help="Chunk size for document splitting", + ) + parser.add_argument( + "-o", + "--chunk-overlap", + type=int, + default=200, + help="Chunk overlap for document splitting", + ) return 
parser @@ -198,7 +205,7 @@ def init_config(args: Namespace): "url": args.url, "chunk_size": args.chunk_size, "chunk_overlap": args.chunk_overlap, - "top_k": args.top_k + "top_k": args.top_k, } @@ -230,7 +237,7 @@ def main(): while True: question = input("\nPlease enter your question: ") - if question.lower() in ['q', 'quit']: + if question.lower() in ["q", "quit"]: print("\nThank you for using! Goodbye!") break @@ -238,7 +245,7 @@ def main(): print(output) else: # Default single question mode - question = ("How to install vLLM?") + question = "How to install vLLM?" output = qa_chain.invoke(question) print("-" * 50) print(output) diff --git a/examples/online_serving/retrieval_augmented_generation_with_llamaindex.py b/examples/online_serving/retrieval_augmented_generation_with_llamaindex.py index a8f76dfe4c69..08796b1b3a54 100644 --- a/examples/online_serving/retrieval_augmented_generation_with_llamaindex.py +++ b/examples/online_serving/retrieval_augmented_generation_with_llamaindex.py @@ -35,6 +35,7 @@ - Default ports: 8000 (embedding), 8001 (chat) - First run may take time to download models """ + import argparse from argparse import Namespace from typing import Any @@ -59,7 +60,7 @@ def init_config(args: Namespace): "db_path": args.db_path, "chunk_size": args.chunk_size, "chunk_overlap": args.chunk_overlap, - "top_k": args.top_k + "top_k": args.top_k, } @@ -117,52 +118,58 @@ def query_document(index: VectorStoreIndex, question: str, top_k: int): def get_parser() -> argparse.ArgumentParser: """Parse command line arguments""" - parser = argparse.ArgumentParser( - description='RAG with vLLM and LlamaIndex') + parser = argparse.ArgumentParser(description="RAG with vLLM and LlamaIndex") # Add command line arguments parser.add_argument( - '--url', - default=("https://docs.vllm.ai/en/latest/getting_started/" - "quickstart.html"), - help='URL of the document to process') - parser.add_argument('--embedding-model', - default="ssmits/Qwen2-7B-Instruct-embed-base", - 
help='Model name for embeddings') - parser.add_argument('--chat-model', - default="qwen/Qwen1.5-0.5B-Chat", - help='Model name for chat') - parser.add_argument('--vllm-api-key', - default="EMPTY", - help='API key for vLLM compatible services') - parser.add_argument('--embedding-endpoint', - default="http://localhost:8000/v1", - help='Base URL for embedding service') - parser.add_argument('--chat-endpoint', - default="http://localhost:8001/v1", - help='Base URL for chat service') - parser.add_argument('--db-path', - default="./milvus_demo.db", - help='Path to Milvus database') - parser.add_argument('-i', - '--interactive', - action='store_true', - help='Enable interactive Q&A mode') - parser.add_argument('-c', - '--chunk-size', - type=int, - default=1000, - help='Chunk size for document splitting') - parser.add_argument('-o', - '--chunk-overlap', - type=int, - default=200, - help='Chunk overlap for document splitting') - parser.add_argument('-k', - '--top-k', - type=int, - default=3, - help='Number of top results to retrieve') + "--url", + default=("https://docs.vllm.ai/en/latest/getting_started/quickstart.html"), + help="URL of the document to process", + ) + parser.add_argument( + "--embedding-model", + default="ssmits/Qwen2-7B-Instruct-embed-base", + help="Model name for embeddings", + ) + parser.add_argument( + "--chat-model", default="qwen/Qwen1.5-0.5B-Chat", help="Model name for chat" + ) + parser.add_argument( + "--vllm-api-key", default="EMPTY", help="API key for vLLM compatible services" + ) + parser.add_argument( + "--embedding-endpoint", + default="http://localhost:8000/v1", + help="Base URL for embedding service", + ) + parser.add_argument( + "--chat-endpoint", + default="http://localhost:8001/v1", + help="Base URL for chat service", + ) + parser.add_argument( + "--db-path", default="./milvus_demo.db", help="Path to Milvus database" + ) + parser.add_argument( + "-i", "--interactive", action="store_true", help="Enable interactive Q&A mode" + ) + 
parser.add_argument( + "-c", + "--chunk-size", + type=int, + default=1000, + help="Chunk size for document splitting", + ) + parser.add_argument( + "-o", + "--chunk-overlap", + type=int, + default=200, + help="Chunk overlap for document splitting", + ) + parser.add_argument( + "-k", "--top-k", type=int, default=3, help="Number of top results to retrieve" + ) return parser @@ -193,7 +200,7 @@ def main(): question = input("\nEnter your question: ") # Check for exit command - if question.lower() in ['quit', 'exit', 'q']: + if question.lower() in ["quit", "exit", "q"]: print("Exiting interactive mode...") break diff --git a/examples/online_serving/streamlit_openai_chatbot_webserver.py b/examples/online_serving/streamlit_openai_chatbot_webserver.py index d8a0f211d44d..0722aa671f66 100644 --- a/examples/online_serving/streamlit_openai_chatbot_webserver.py +++ b/examples/online_serving/streamlit_openai_chatbot_webserver.py @@ -26,6 +26,7 @@ streamlit run streamlit_openai_chatbot_webserver.py \ --logger.level=debug """ + import os from datetime import datetime @@ -33,8 +34,8 @@ from openai import OpenAI # Get command line arguments from environment variables -openai_api_key = os.getenv('VLLM_API_KEY', "EMPTY") -openai_api_base = os.getenv('VLLM_API_BASE', "http://localhost:8000/v1") +openai_api_key = os.getenv("VLLM_API_KEY", "EMPTY") +openai_api_base = os.getenv("VLLM_API_BASE", "http://localhost:8000/v1") # Initialize session states for managing chat sessions if "sessions" not in st.session_state: @@ -81,9 +82,9 @@ def get_llm_response(messages, model): Streaming response object or error message string """ try: - response = client.chat.completions.create(model=model, - messages=messages, - stream=True) + response = client.chat.completions.create( + model=model, messages=messages, stream=True + ) return response except Exception as e: st.error(f"Error details: {str(e)}") @@ -92,8 +93,9 @@ def get_llm_response(messages, model): # Sidebar - API Settings first 
st.sidebar.title("API Settings") -new_api_base = st.sidebar.text_input("API Base URL:", - value=st.session_state.api_base_url) +new_api_base = st.sidebar.text_input( + "API Base URL:", value=st.session_state.api_base_url +) if new_api_base != st.session_state.api_base_url: st.session_state.api_base_url = new_api_base st.rerun() @@ -109,16 +111,20 @@ def get_llm_response(messages, model): for session_id in sorted(st.session_state.sessions.keys(), reverse=True): # Mark the active session with a pinned button if session_id == st.session_state.active_session: - st.sidebar.button(f"๐Ÿ“ {session_id}", - key=session_id, - type="primary", - on_click=switch_to_chat_session, - args=(session_id, )) + st.sidebar.button( + f"๐Ÿ“ {session_id}", + key=session_id, + type="primary", + on_click=switch_to_chat_session, + args=(session_id,), + ) else: - st.sidebar.button(f"Session {session_id}", - key=session_id, - on_click=switch_to_chat_session, - args=(session_id, )) + st.sidebar.button( + f"Session {session_id}", + key=session_id, + on_click=switch_to_chat_session, + args=(session_id,), + ) # Main interface st.title("vLLM Chat Assistant") @@ -145,18 +151,18 @@ def get_llm_response(messages, model): if prompt := st.chat_input("Type your message here..."): # Save user message to session st.session_state.messages.append({"role": "user", "content": prompt}) - st.session_state.sessions[ - st.session_state.current_session] = st.session_state.messages + st.session_state.sessions[st.session_state.current_session] = ( + st.session_state.messages + ) # Display user message with st.chat_message("user"): st.write(prompt) # Prepare messages for llm - messages_for_llm = [{ - "role": m["role"], - "content": m["content"] - } for m in st.session_state.messages] + messages_for_llm = [ + {"role": m["role"], "content": m["content"]} for m in st.session_state.messages + ] # Generate and display llm response with st.chat_message("assistant"): @@ -179,7 +185,4 @@ def get_llm_response(messages, model): 
message_placeholder.markdown(full_response) # Save llm response to session history - st.session_state.messages.append({ - "role": "assistant", - "content": full_response - }) + st.session_state.messages.append({"role": "assistant", "content": full_response}) diff --git a/examples/online_serving/utils.py b/examples/online_serving/utils.py new file mode 100644 index 000000000000..0781a27f19c5 --- /dev/null +++ b/examples/online_serving/utils.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: Apache-2.0 +from openai import APIConnectionError, OpenAI +from openai.pagination import SyncPage +from openai.types.model import Model + + +def get_first_model(client: OpenAI) -> str: + """ + Get the first model from the vLLM server. + """ + try: + models: SyncPage[Model] = client.models.list() + except APIConnectionError as e: + raise RuntimeError( + "Failed to get the list of models from the vLLM server at " + f"{client.base_url} with API key {client.api_key}. Check\n" + "1. the server is running\n" + "2. the server URL is correct\n" + "3. 
the API key is correct" + ) from e + + if len(models.data) == 0: + raise RuntimeError(f"No models found on the vLLM server at {client.base_url}") + + return models.data[0].id diff --git a/examples/lmcache/README.md b/examples/others/lmcache/README.md similarity index 100% rename from examples/lmcache/README.md rename to examples/others/lmcache/README.md diff --git a/examples/lmcache/cpu_offload_lmcache.py b/examples/others/lmcache/cpu_offload_lmcache.py similarity index 85% rename from examples/lmcache/cpu_offload_lmcache.py rename to examples/others/lmcache/cpu_offload_lmcache.py index bf191960b080..98eafb31ed4f 100644 --- a/examples/lmcache/cpu_offload_lmcache.py +++ b/examples/others/lmcache/cpu_offload_lmcache.py @@ -20,6 +20,7 @@ Learn more about LMCache environment setup, please refer to: https://docs.lmcache.ai/getting_started/installation.html """ + import argparse import contextlib import os @@ -34,7 +35,7 @@ from vllm.engine.arg_utils import EngineArgs -def setup_environment_variables(): +def setup_environment_variables(vllm_version: str): # LMCache-related environment variables # Use experimental features in LMCache os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True" @@ -44,11 +45,12 @@ def setup_environment_variables(): os.environ["LMCACHE_LOCAL_CPU"] = "True" # Set local CPU memory limit to 5.0 GB os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0" + if vllm_version == "v0": + os.environ["VLLM_USE_V1"] = "0" @contextlib.contextmanager -def build_llm_with_lmcache(lmcache_connector: str, model: str, - vllm_version: str): +def build_llm_with_lmcache(lmcache_connector: str, model: str, vllm_version: str): ktc = KVTransferConfig( kv_connector=lmcache_connector, kv_role="kv_both", @@ -95,18 +97,19 @@ def print_output( for output in outputs: generated_text = output.outputs[0].text print(f"Generated text: {generated_text!r}") - print(f"Generation took {time.time() - start:.2f} seconds, " - f"{req_str} request done.") + print(f"Generation took {time.time() - 
start:.2f} seconds, {req_str} request done.") print("-" * 50) def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument("-v", - "--version", - choices=["v0", "v1"], - default="v1", - help="Specify vLLM version (default: v1)") + parser.add_argument( + "-v", + "--version", + choices=["v0", "v1"], + default="v1", + help="Specify vLLM version (default: v1)", + ) return parser.parse_args() @@ -120,10 +123,9 @@ def main(): lmcache_connector = "LMCacheConnectorV1" model = "meta-llama/Meta-Llama-3.1-8B-Instruct" - setup_environment_variables() + setup_environment_variables(args.version) with build_llm_with_lmcache(lmcache_connector, model, args.version) as llm: - # This example script runs two requests with a shared prefix. # Define the shared prompt and specific prompts shared_prompt = "Hello, how are you?" * 1000 @@ -134,9 +136,7 @@ def main(): shared_prompt + "Tell me a very long story", ] - sampling_params = SamplingParams(temperature=0, - top_p=0.95, - max_tokens=10) + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) # Print the first output print_output(llm, first_prompt, sampling_params, "first") diff --git a/examples/lmcache/disagg_prefill_lmcache_v0.py b/examples/others/lmcache/disagg_prefill_lmcache_v0.py similarity index 79% rename from examples/lmcache/disagg_prefill_lmcache_v0.py rename to examples/others/lmcache/disagg_prefill_lmcache_v0.py index 7da6fb7aaa23..b2b7b3b2c1f9 100644 --- a/examples/lmcache/disagg_prefill_lmcache_v0.py +++ b/examples/others/lmcache/disagg_prefill_lmcache_v0.py @@ -4,12 +4,13 @@ with LMCache. We will launch 2 vllm instances (GPU 0 for prefill and GPU 1 for decode), and launch an additional LMCache server. -KV cache is transferred in the following manner: +KV cache is transferred in the following manner: vLLM prefill node -> LMCache server -> vLLM decode node. Note that `pip install lmcache` is needed to run this example. Learn more about LMCache in https://github.com/LMCache/LMCache. 
""" + import os import subprocess import time @@ -49,18 +50,23 @@ def run_prefill(prefill_done, prompts): sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) - ktc = KVTransferConfig.from_cli( - '{"kv_connector":"LMCacheConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' + ktc = KVTransferConfig( + kv_connector="LMCacheConnector", + kv_role="kv_producer", + kv_rank=0, + kv_parallel_size=2, ) # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB # memory. Reduce the value if your GPU has less memory. - llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", - kv_transfer_config=ktc, - max_model_len=8000, - gpu_memory_utilization=0.8, - enforce_eager=True) + llm = LLM( + model="mistralai/Mistral-7B-Instruct-v0.2", + kv_transfer_config=ktc, + max_model_len=8000, + gpu_memory_utilization=0.8, + enforce_eager=True, + ) - #llm.generate(prompts, sampling_params) + # llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params) for output in outputs: generated_text = output.outputs[0].text @@ -78,16 +84,21 @@ def run_decode(prefill_done, prompts, timeout=1): sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) - ktc = KVTransferConfig.from_cli( - '{"kv_connector":"LMCacheConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' + ktc = KVTransferConfig( + kv_connector="LMCacheConnector", + kv_role="kv_consumer", + kv_rank=1, + kv_parallel_size=2, ) # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB # of memory. Reduce the value if your GPU has less memory. 
- llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", - kv_transfer_config=ktc, - max_model_len=8000, - gpu_memory_utilization=0.8, - enforce_eager=True) + llm = LLM( + model="mistralai/Mistral-7B-Instruct-v0.2", + kv_transfer_config=ktc, + max_model_len=8000, + gpu_memory_utilization=0.8, + enforce_eager=True, + ) print("Waiting for prefill node to finish...") prefill_done.wait() @@ -103,10 +114,9 @@ def run_decode(prefill_done, prompts, timeout=1): def run_lmcache_server(port): - server_proc = subprocess.Popen([ - "python", "-m", "lmcache.experimental.server", "localhost", - str(port) - ]) + server_proc = subprocess.Popen( + ["python", "-m", "lmcache.experimental.server", "localhost", str(port)] + ) return server_proc diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-decoder-config.yaml b/examples/others/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-decoder-config.yaml similarity index 100% rename from examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-decoder-config.yaml rename to examples/others/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-decoder-config.yaml diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-prefiller-config.yaml b/examples/others/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-prefiller-config.yaml similarity index 100% rename from examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-prefiller-config.yaml rename to examples/others/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-prefiller-config.yaml diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh b/examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh similarity index 100% rename from examples/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh rename to examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/disagg_proxy_server.py 
b/examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_proxy_server.py similarity index 59% rename from examples/lmcache/disagg_prefill_lmcache_v1/disagg_proxy_server.py rename to examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_proxy_server.py index 8db93bc8931b..20155c203658 100644 --- a/examples/lmcache/disagg_prefill_lmcache_v1/disagg_proxy_server.py +++ b/examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_proxy_server.py @@ -17,13 +17,17 @@ async def lifespan(app: FastAPI): Lifespan context manager to handle startup and shutdown events. """ # Startup: Initialize clients - prefiller_base_url = f'http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1' - decoder_base_url = f'http://{global_args.decoder_host}:{global_args.decoder_port}/v1' - - app.state.prefill_client = httpx.AsyncClient(timeout=None, - base_url=prefiller_base_url) - app.state.decode_client = httpx.AsyncClient(timeout=None, - base_url=decoder_base_url) + prefiller_base_url = ( + f"http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1" + ) + decoder_base_url = ( + f"http://{global_args.decoder_host}:{global_args.decoder_port}/v1" + ) + + app.state.prefill_client = httpx.AsyncClient( + timeout=None, base_url=prefiller_base_url + ) + app.state.decode_client = httpx.AsyncClient(timeout=None, base_url=decoder_base_url) yield @@ -37,7 +41,6 @@ async def lifespan(app: FastAPI): class StatsCalculator: - def __init__(self): self._stats = [] self._last_log_time = time.time() @@ -51,13 +54,18 @@ def add(self, value): def _log_stats(self): # Print average, median, and 99th percentile np_arr = np.array(self._stats) - output_str = f"\nNum requests: {len(self._stats)}" + \ - "\nPrefill node TTFT stats:" + \ - f"\n - Average (ms): {np.mean(np_arr)}" + \ - f"\n - Median (ms): {np.median(np_arr)}" + \ - f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n" - print("===============================", output_str, - "===============================") + output_str = ( + 
f"\nNum requests: {len(self._stats)}" + + "\nPrefill node TTFT stats:" + + f"\n - Average (ms): {np.mean(np_arr)}" + + f"\n - Median (ms): {np.median(np_arr)}" + + f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n" + ) + print( + "===============================", + output_str, + "===============================", + ) stats_calculator = StatsCalculator() @@ -82,15 +90,16 @@ def parse_args(): app.state.decode_client = None -async def send_request_to_service(client: httpx.AsyncClient, endpoint: str, - req_data: dict): +async def send_request_to_service( + client: httpx.AsyncClient, endpoint: str, req_data: dict +): """ Send a request to a service using a persistent client. """ req_data = req_data.copy() - req_data['max_tokens'] = 1 - if 'max_completion_tokens' in req_data: - req_data['max_completion_tokens'] = 1 + req_data["max_tokens"] = 1 + if "max_completion_tokens" in req_data: + req_data["max_completion_tokens"] = 1 headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} response = await client.post(endpoint, json=req_data, headers=headers) @@ -98,14 +107,16 @@ async def send_request_to_service(client: httpx.AsyncClient, endpoint: str, return response -async def stream_service_response(client: httpx.AsyncClient, endpoint: str, - req_data: dict): +async def stream_service_response( + client: httpx.AsyncClient, endpoint: str, req_data: dict +): """ Asynchronously stream the response from a service using a persistent client. 
""" headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} - async with client.stream("POST", endpoint, json=req_data, - headers=headers) as response: + async with client.stream( + "POST", endpoint, json=req_data, headers=headers + ) as response: response.raise_for_status() async for chunk in response.aiter_bytes(): yield chunk @@ -121,28 +132,28 @@ async def handle_completions(request: Request): req_data = await request.json() # Send request to prefill service, ignore the response - await send_request_to_service(app.state.prefill_client, "/completions", - req_data) + await send_request_to_service( + app.state.prefill_client, "/completions", req_data + ) et = time.time() stats_calculator.add(et - st) # Stream response from decode service async def generate_stream(): - async for chunk in stream_service_response(app.state.decode_client, - "/completions", - req_data): + async for chunk in stream_service_response( + app.state.decode_client, "/completions", req_data + ): yield chunk - return StreamingResponse(generate_stream(), - media_type="application/json") + return StreamingResponse(generate_stream(), media_type="text/event-stream") except Exception as e: import sys import traceback + exc_info = sys.exc_info() - print("Error occurred in disagg prefill proxy server" - " - completions endpoint") + print("Error occurred in disagg prefill proxy server - completions endpoint") print(e) print("".join(traceback.format_exception(*exc_info))) raise @@ -158,36 +169,39 @@ async def handle_chat_completions(request: Request): req_data = await request.json() # Send request to prefill service, ignore the response - await send_request_to_service(app.state.prefill_client, - "/chat/completions", req_data) + await send_request_to_service( + app.state.prefill_client, "/chat/completions", req_data + ) et = time.time() stats_calculator.add(et - st) # Stream response from decode service async def generate_stream(): - async for chunk in 
stream_service_response(app.state.decode_client, - "/chat/completions", - req_data): + async for chunk in stream_service_response( + app.state.decode_client, "/chat/completions", req_data + ): yield chunk - return StreamingResponse(generate_stream(), - media_type="application/json") + return StreamingResponse(generate_stream(), media_type="text/event-stream") except Exception as e: import sys import traceback + exc_info = sys.exc_info() - print("Error occurred in disagg prefill proxy server " - " - chat completions endpoint") + print( + "Error occurred in disagg prefill proxy server - chat completions endpoint" + ) print(e) print("".join(traceback.format_exception(*exc_info))) raise -if __name__ == '__main__': +if __name__ == "__main__": global global_args global_args = parse_args() import uvicorn + uvicorn.run(app, host=global_args.host, port=global_args.port) diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh b/examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh similarity index 97% rename from examples/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh rename to examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh index 831ef0bb574b..5719fa821292 100644 --- a/examples/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh +++ b/examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh @@ -54,6 +54,6 @@ elif [[ $1 == "decoder" ]]; then else echo "Invalid role: $1" - echo "Should be either prefill, decode" + echo "Should be either prefiller, decoder" exit 1 fi diff --git a/examples/lmcache/kv_cache_sharing_lmcache_v1.py b/examples/others/lmcache/kv_cache_sharing_lmcache_v1.py similarity index 81% rename from examples/lmcache/kv_cache_sharing_lmcache_v1.py rename to examples/others/lmcache/kv_cache_sharing_lmcache_v1.py index af1b4351dd54..89945d67a6f3 100644 --- a/examples/lmcache/kv_cache_sharing_lmcache_v1.py +++ 
b/examples/others/lmcache/kv_cache_sharing_lmcache_v1.py @@ -3,13 +3,14 @@ This file demonstrates the example usage of remote KV cache sharing with LMCache. We will launch 2 vllm instances, and launch an additional LMCache server. -KV cache is transferred in the following manner: +KV cache is transferred in the following manner: (1) vLLM instance 1 -> LMCache server (KV cache store). (2) LMCache server -> vLLM instance 2 (KV cache reuse/retrieve). Note that lmcache needs to be installed to run this example. Learn more about LMCache in https://github.com/LMCache/LMCache. """ + import os import subprocess import time @@ -49,15 +50,16 @@ def run_store(store_done, prompts): sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) - ktc = KVTransferConfig.from_cli( - '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}') + ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1", kv_role="kv_both") # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB # memory. Reduce the value if your GPU has less memory. - llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", - kv_transfer_config=ktc, - max_model_len=8000, - gpu_memory_utilization=0.8, - enforce_eager=True) + llm = LLM( + model="mistralai/Mistral-7B-Instruct-v0.2", + kv_transfer_config=ktc, + max_model_len=8000, + gpu_memory_utilization=0.8, + enforce_eager=True, + ) outputs = llm.generate(prompts, sampling_params) for output in outputs: @@ -76,15 +78,16 @@ def run_retrieve(store_done, prompts, timeout=1): sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) - ktc = KVTransferConfig.from_cli( - '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}') + ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1", kv_role="kv_both") # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB # of memory. Reduce the value if your GPU has less memory. 
- llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", - kv_transfer_config=ktc, - max_model_len=8000, - gpu_memory_utilization=0.8, - enforce_eager=True) + llm = LLM( + model="mistralai/Mistral-7B-Instruct-v0.2", + kv_transfer_config=ktc, + max_model_len=8000, + gpu_memory_utilization=0.8, + enforce_eager=True, + ) print("Waiting for KV cache store to finish...") store_done.wait() @@ -100,10 +103,9 @@ def run_retrieve(store_done, prompts, timeout=1): def run_lmcache_server(port): - server_proc = subprocess.Popen([ - "python", "-m", "lmcache.experimental.server", "localhost", - str(port) - ]) + server_proc = subprocess.Popen( + ["python", "-m", "lmcache.experimental.server", "localhost", str(port)] + ) return server_proc diff --git a/examples/other/logging_configuration.md b/examples/others/logging_configuration.md similarity index 100% rename from examples/other/logging_configuration.md rename to examples/others/logging_configuration.md diff --git a/examples/other/tensorize_vllm_model.py b/examples/others/tensorize_vllm_model.py similarity index 72% rename from examples/other/tensorize_vllm_model.py rename to examples/others/tensorize_vllm_model.py index 7d11ba51a094..175777630833 100644 --- a/examples/other/tensorize_vllm_model.py +++ b/examples/others/tensorize_vllm_model.py @@ -6,11 +6,15 @@ import os import uuid -from vllm import LLM +from vllm import LLM, SamplingParams from vllm.engine.arg_utils import EngineArgs -from vllm.model_executor.model_loader.tensorizer import (TensorizerArgs, - TensorizerConfig, - tensorize_vllm_model) +from vllm.lora.request import LoRARequest +from vllm.model_executor.model_loader.tensorizer import ( + TensorizerArgs, + TensorizerConfig, + tensorize_lora_adapter, + tensorize_vllm_model, +) from vllm.utils import FlexibleArgumentParser # yapf conflicts with isort for this docstring @@ -27,7 +31,7 @@ To serialize a model, install vLLM from source, then run something like this from the root level of this repository: -python -m 
examples.other.tensorize_vllm_model \ +python examples/others/tensorize_vllm_model.py \ --model facebook/opt-125m \ serialize \ --serialized-directory s3://my-bucket \ @@ -47,7 +51,7 @@ To deserialize a model, you can run something like this from the root level of this repository: -python -m examples.other.tensorize_vllm_model \ +python examples/others/tensorize_vllm_model.py \ --model EleutherAI/gpt-j-6B \ --dtype float16 \ deserialize \ @@ -65,11 +69,11 @@ model-rank-%03d.tensors For more information on the available arguments for serializing, run -`python -m examples.other.tensorize_vllm_model serialize --help`. +`python -m examples.others.tensorize_vllm_model serialize --help`. Or for deserializing: -`python -m examples.other.tensorize_vllm_model deserialize --help`. +`python examples/others/tensorize_vllm_model.py deserialize --help`. Once a model is serialized, tensorizer can be invoked with the `LLM` class directly to load models: @@ -90,11 +94,27 @@ In order to see all of the available arguments usable to configure loading with tensorizer that are given to `TensorizerConfig`, run: -`python -m examples.other.tensorize_vllm_model deserialize --help` +`python examples/others/tensorize_vllm_model.py deserialize --help` under the `tensorizer options` section. These can also be used for deserialization in this example script, although `--tensorizer-uri` and `--path-to-tensors` are functionally the same in this case. + +Tensorizer can also be used to save and load LoRA adapters. A LoRA adapter +can be serialized directly with the path to the LoRA adapter on HF Hub and +a TensorizerConfig object. In this script, passing a HF id to a LoRA adapter +will serialize the LoRA adapter artifacts to `--serialized-directory`. + +You can then use the LoRA adapter with `vllm serve`, for instance, by ensuring +the LoRA artifacts are in your model artifacts directory and specifying +`--enable-lora`. 
For instance: + +``` +vllm serve <model_path> \ + --load-format tensorizer \ + --model-loader-extra-config '{"tensorizer_uri": "<model_path>.tensors"}' \ + --enable-lora +``` """ @@ -107,6 +127,19 @@ def parse_args(): "also supported, although libsodium must be installed to " "use it.") parser = EngineArgs.add_cli_args(parser) + + parser.add_argument( + "--lora-path", + type=str, + required=False, + help="Path to a LoRA adapter to " + "serialize along with model tensors. This can then be deserialized " + "along with the model by passing a tensorizer_config kwarg to " + "LoRARequest with type TensorizerConfig. See the docstring for this " + "for a usage example." + + ) + subparsers = parser.add_subparsers(dest='command') serialize_parser = subparsers.add_parser( @@ -169,11 +202,42 @@ def parse_args(): def deserialize(): - llm = LLM(model=args.model, - load_format="tensorizer", - tensor_parallel_size=args.tensor_parallel_size, - model_loader_extra_config=tensorizer_config - ) + if args.lora_path: + tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir + llm = LLM(model=args.model, + load_format="tensorizer", + tensor_parallel_size=args.tensor_parallel_size, + model_loader_extra_config=tensorizer_config, + enable_lora=True, + ) + sampling_params = SamplingParams( + temperature=0, + max_tokens=256, + stop=["[/assistant]"] + ) + + # Truncating this as the extra text isn't necessary + prompts = [ + "[user] Write a SQL query to answer the question based on ..." 
+ ] + + # Test LoRA load + print( + llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest("sql-lora", + 1, + args.lora_path, + tensorizer_config = tensorizer_config) + ) + ) + else: + llm = LLM(model=args.model, + load_format="tensorizer", + tensor_parallel_size=args.tensor_parallel_size, + model_loader_extra_config=tensorizer_config + ) return llm @@ -197,7 +261,10 @@ def deserialize(): model_name = model_ref.split("/")[1] - keyfile = args.keyfile if args.keyfile else None + if args.command == "serialize" or args.command == "deserialize": + keyfile = args.keyfile + else: + keyfile = None if args.model_loader_extra_config: config = json.loads(args.model_loader_extra_config) @@ -228,6 +295,10 @@ def deserialize(): encryption_keyfile=keyfile, **credentials) + if args.lora_path: + tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir + tensorize_lora_adapter(args.lora_path, tensorizer_config) + tensorize_vllm_model(engine_args, tensorizer_config) elif args.command == "deserialize": diff --git a/examples/pyproject.toml b/examples/pyproject.toml new file mode 100644 index 000000000000..f825cb203269 --- /dev/null +++ b/examples/pyproject.toml @@ -0,0 +1,54 @@ +# This local pyproject file is part of the migration from yapf to ruff format. 
+# It uses the same core rules as the main pyproject.toml file, but with the +# following differences: +# - ruff line length is overridden to 88 +# - deprecated typing ignores (UP006, UP035) have been removed + +[tool.ruff] +line-length = 88 +exclude = [ + # External file, leaving license intact + "examples/other/fp8/quantizer/quantize.py", + "vllm/vllm_flash_attn/flash_attn_interface.pyi" +] + +[tool.ruff.lint.per-file-ignores] +"vllm/third_party/**" = ["ALL"] +"vllm/version.py" = ["F401"] +"vllm/_version.py" = ["ALL"] + +[tool.ruff.lint] +select = [ + # pycodestyle + "E", + # Pyflakes + "F", + # pyupgrade + "UP", + # flake8-bugbear + "B", + # flake8-simplify + "SIM", + # isort + "I", + # flake8-logging-format + "G", +] +ignore = [ + # star imports + "F405", "F403", + # lambda expression assignment + "E731", + # Loop control variable not used within loop body + "B007", + # f-string format + "UP032", + # Can remove once 3.10+ is the minimum Python version + "UP007", +] + +[tool.ruff.lint.isort] +known-first-party = ["vllm"] + +[tool.ruff.format] +docstring-code-format = true \ No newline at end of file diff --git a/examples/tool_chat_template_deepseekv3.jinja b/examples/tool_chat_template_deepseekv3.jinja new file mode 100644 index 000000000000..36f3781439ed --- /dev/null +++ b/examples/tool_chat_template_deepseekv3.jinja @@ -0,0 +1,96 @@ +{% if not add_generation_prompt is defined %} + {% set add_generation_prompt = false %} +{% endif %} + +{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true, is_last_user=false) %} + +{%- for message in messages %} + {%- if message['role'] == 'system' %} + {%- if ns.is_first_sp %} + {% set ns.system_prompt = ns.system_prompt + message['content'] %} + {% set ns.is_first_sp = false %} + {%- else %} + {% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %} + {%- endif %} + {%- endif %} +{%- endfor %} + +{{ bos_token }} +{{ ns.system_prompt }} +{%- if tools 
%} + {{"\n\n# Tools\n\nYou may call one or more functions to assist with the user query." }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{"\n</tools>\n\n"}} + + {{"For function call returns, you should first print <๏ฝœtoolโ–callsโ–begin๏ฝœ>"}} + + {{"For each function call, you should return object like:\n" }} + {{"<๏ฝœtoolโ–callโ–begin๏ฝœ>function<๏ฝœtoolโ–sep๏ฝœ><function_name>\n```json\n<function_arguments_in_json_format>\n```<๏ฝœtoolโ–callโ–end๏ฝœ>"}} + + {{"At the end of function call returns, you should print <๏ฝœtoolโ–callsโ–end๏ฝœ><๏ฝœendโ–ofโ–sentence๏ฝœ>"}} +{%- endif %} + +{%- for message in messages %} + {%- if message['role'] == 'user' %} + {%- set ns.is_tool = false -%} + {%- set ns.is_first = false -%} + {%- set ns.is_last_user = true -%} + {{'<๏ฝœUser๏ฝœ>' + message['content'] + '<๏ฝœAssistant๏ฝœ>'}} + {%- endif %} + + {%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %} + {%- set ns.is_last_user = false -%} + {%- if ns.is_tool %} + {{'<๏ฝœtoolโ–outputsโ–end๏ฝœ>'}} + {%- endif %} + {%- set ns.is_first = false %} + {%- set ns.is_tool = false -%} + {%- set ns.is_output_first = true %} + + {%- for tool in message['tool_calls'] %} + {%- if not ns.is_first %} + {%- if message['content'] is none %} + {{'<๏ฝœtoolโ–callsโ–begin๏ฝœ><๏ฝœtoolโ–callโ–begin๏ฝœ>' + tool['type'] + '<๏ฝœtoolโ–sep๏ฝœ>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<๏ฝœtoolโ–callโ–end๏ฝœ>'}} + {%- else %} + {{message['content'] + '<๏ฝœtoolโ–callsโ–begin๏ฝœ><๏ฝœtoolโ–callโ–begin๏ฝœ>' + tool['type'] + '<๏ฝœtoolโ–sep๏ฝœ>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<๏ฝœtoolโ–callโ–end๏ฝœ>'}} + {%- endif %} + {%- set ns.is_first = true -%} + {%- else %} + {{'\n' + '<๏ฝœtoolโ–callโ–begin๏ฝœ>' + tool['type'] + '<๏ฝœtoolโ–sep๏ฝœ>' + 
tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<๏ฝœtoolโ–callโ–end๏ฝœ>'}} + {%- endif %} + {%- endfor %} + {{'<๏ฝœtoolโ–callsโ–end๏ฝœ><๏ฝœendโ–ofโ–sentence๏ฝœ>'}} + {%- endif %} + {%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none)%} + {%- set ns.is_last_user = false -%} + {%- if ns.is_tool %} + {{'<๏ฝœtoolโ–outputsโ–end๏ฝœ>' + message['content'] + '<๏ฝœendโ–ofโ–sentence๏ฝœ>'}} + {%- set ns.is_tool = false -%} + {%- else %} + {% set content = message['content'] %} + {{content + '<๏ฝœendโ–ofโ–sentence๏ฝœ>'}} + {%- endif %} + {%- endif %} + + {%- if message['role'] == 'tool' %} + {%- set ns.is_last_user = false -%} + {%- set ns.is_tool = true -%} + {%- if ns.is_output_first %} + {{'<๏ฝœtoolโ–outputsโ–begin๏ฝœ><๏ฝœtoolโ–outputโ–begin๏ฝœ>' + message['content'] + '<๏ฝœtoolโ–outputโ–end๏ฝœ>'}} + {%- set ns.is_output_first = false %} + {%- else %} + {{'\n<๏ฝœtoolโ–outputโ–begin๏ฝœ>' + message['content'] + '<๏ฝœtoolโ–outputโ–end๏ฝœ>'}} + {%- endif %} + {%- endif %} +{%- endfor -%} + +{% if ns.is_tool %} + {{'<๏ฝœtoolโ–outputsโ–end๏ฝœ>'}} +{% endif %} + +{% if add_generation_prompt and not ns.is_last_user and not ns.is_tool %} + {{'<๏ฝœAssistant๏ฝœ>'}} +{% endif %} diff --git a/examples/tool_chat_template_llama4_pythonic.jinja b/examples/tool_chat_template_llama4_pythonic.jinja index bd18a35bdda9..bbed3d8205e0 100644 --- a/examples/tool_chat_template_llama4_pythonic.jinja +++ b/examples/tool_chat_template_llama4_pythonic.jinja @@ -1,16 +1,17 @@ {{- bos_token }} -{%- if custom_tools is defined %} +{%- if custom_tools is defined and custom_tools%} {%- set tools = custom_tools %} {%- endif %} -{%- if not tools_in_user_message is defined %} - {%- set tools_in_user_message = false %} -{%- endif %} -{%- if not tools is defined %} +{%- if tools is defined and tools %} + {%- set tool_definition = tool_definition ~ (tools | 
tojson(indent=4)) %} +{%- else %} {%- set tools = none %} {%- endif %} + {#- This block extracts the system message, so we can slot it into the right place. #} {%- if messages[0]['role'] == 'system' %} + {%- set user_provided_system_message = true %} {%- if messages[0]['content'] is string %} {%- set system_message = messages[0]['content']|trim %} {%- else %} @@ -18,68 +19,33 @@ {%- endif %} {%- set messages = messages[1:] %} {%- else %} - {%- if tools is not none %} - {#- Add default tool system message when tools are provided #} - {%- set system_message = "You are a helpful assistant with tool calling " - "capabilities. Only reply with a tool call if the function exists in the " - "library provided by the user. If it doesn't exist, just reply directly in " - "natural language. When you receive a tool call response, use the output to " - "format an answer to the original user question." %} + {%- if tools is not none %} + {#- Since not system_message was provided by user, if tool is provided, system_message is now default tool system message #} + {#- This system message is from llama website:https://www.llama.com/docs/model-cards-and-prompt-formats/llama4/ #} + {%- set system_message = "You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines:\n\n1. 
FUNCTION CALLS:\n- ONLY use functions that are EXPLICITLY listed in the function list below\n- If NO functions are listed (empty function list []), respond ONLY with internal knowledge or \"I don't have access to [Unavailable service] information\"\n- If a function is not in the list, respond ONLY with internal knowledge or \"I don't have access to [Unavailable service] information\"\n- If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s)\n- Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]\nExamples:\nCORRECT: [get_weather(location=\"Vancouver\"), calculate_route(start=\"Boston\", end=\"New York\")] <- Only if get_weather and calculate_route are in function list\nINCORRECT: get_weather(location=\"New York\")\nINCORRECT: Let me check the weather: [get_weather(location=\"New York\")]\nINCORRECT: [get_events(location=\"Singapore\")] <- If function not in list\n\n2. RESPONSE RULES:\n- For pure function requests matching a listed function: ONLY output the function call(s)\n- For knowledge questions: ONLY output text\n- For missing parameters: ONLY request the specific missing parameters\n- For unavailable services (not in function list): output ONLY with internal knowledge or \"I don't have access to [Unavailable service] information\". Do NOT execute a function call.\n- If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations\n- NEVER combine text and function calls in the same response\n- NEVER suggest alternative functions when the requested service is unavailable\n- NEVER create or invent new functions not listed below\n\n3. 
STRICT BOUNDARIES:\n- ONLY use functions from the list below - no exceptions\n- NEVER use a function as an alternative to unavailable information\n- NEVER call functions not present in the function list\n- NEVER add explanatory text to function calls\n- NEVER respond with empty brackets\n- Use proper Python/JSON syntax for function calls\n- Check the function list carefully before responding\n\n4. TOOL RESPONSE HANDLING:\n- When receiving tool responses: provide concise, natural language responses\n- Don't repeat tool response verbatim\n- Don't add supplementary information\n\nHere is a list of functions in JSON format that you can invoke:\n" %} {%- else %} {%- set system_message = "" %} {%- endif %} {%- endif %} - -{#- System message if the user supplied one, or if tools are used (default tool system message) #} +{#- Now writing the system message: use the user provided system message if user_provided_system_message, else default tool system message if tools presented #} {%- if system_message %} {#- always use user provided system message to override default tool system message #} {{- "<|header_start|>system<|header_end|>\n\n" }} {{- system_message }} - {%- if tools is not none and not tools_in_user_message %} - {{- "Tools: You have access to the following tools. You might need to use one " - "or more function/tool calls to fulfill the task. \n" - "If none are needed, then proceed to the response.\n\n" - "Tool Call Syntax: You can call tools using the following syntax:\n" - "[func_name1(params_name1=params_value1, params_name2=params_value2, ...), ...]\n" - "Do not include anything else when calling the tools with the syntax above.\n\n" - "Here is a list of functions in JSON format that you can invoke.\n " }} - {%- for t in tools %} - {{- t | tojson(indent=4) }} - {{- "\n\n" }} - {%- endfor %} + {%- if user_provided_system_message and tools %} + {{- "\nHere is a list of functions in JSON format that you can invoke. 
Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]\n" }} + {{- tool_definition -}} + {%- elif tool_definition %} + {{- tool_definition -}} {%- endif %} {{- "<|eot|>" }} {%- endif %} -{#- Custom tools are passed in a user message with some extra guidance #} -{%- if tools_in_user_message and tools is not none %} - {#- Extract the first user message so we can plug it in here #} - {%- if messages | length != 0 %} - {%- if messages[0]['content'] is string %} - {%- set first_user_message = messages[0]['content']|trim %} - {%- else %} - {%- set first_user_message = messages[0]['content'] | selectattr('type', 'equalto', 'text') | map(attribute='text') | map('trim') | join('\n') %} - {%- endif %} - {%- set messages = messages[1:] %} - {%- else %} - {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} - {%- endif %} - {{- '<|header_start|>user<|header_end|>\n\n' -}} - {{- first_user_message}} - {{- "\nHere is a list of functions in JSON format that you can invoke:"}} - {%- for t in tools %} - {{- t | tojson(indent=4) }} - {{- "\n\n" }} - {%- endfor %} - {{- "Should you decide to return the function call(s), put them in the format " - "of [func_name1(params_name1=params_value1, params_name2=params_value2, " - "...), ...]\nDo not include anything else when calling the tools with the " - "syntax above." 
}} -{%- endif %} - +{#- Now deal with all other messages #} {%- for message in messages %} - {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} - {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }} + {#- Base case: messages that are not from tool role and has empty tool_call list #} + {%- if not (message.role == 'ipython' or message.role == 'tool' or ('tool_calls' in message and message.tool_calls|length != 0 )) %} + {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }} {%- if message['content'] is string %} {{- message['content'] }} {%- else %} @@ -91,10 +57,12 @@ {%- endif %} {%- endfor %} {%- endif %} - {{- "<|eot|>" }} - {%- elif 'tool_calls' in message and message.tool_calls|length > 0 %} - {%- set tool_call = message.tool_calls[0].function %} - {{- '<|header_start|>assistant<|header_end|>\n\n' -}} + {{- "<|eot|>" }} + {#- Tool case: messages has non-empty tool_call list, must from assistant #} + {%- elif 'tool_calls' in message %} + {#- assume tool_calls are always coming from assistant #} + {%- if message.role == 'assistant' %} + {{- '<|header_start|>assistant<|header_end|>\n\n' -}} {%- if message['content'] is string %} {{- message['content'] }} {%- else %} @@ -106,32 +74,36 @@ {%- endif %} {%- endfor %} {%- endif %} + {{- "[" }} {%- for tool_call in message.tool_calls %} {%- if tool_call.function is defined %} {%- set tool_call = tool_call.function %} {%- endif %} - {{- tool_call.name + '(' -}} + {{- tool_call.name + '(' -}} {%- for param in tool_call.arguments %} - {{- param + '=' -}} + {{- param + '="' -}} {{- "%s" | format(tool_call.arguments[param]) -}} + {{- '"' -}} {% if not loop.last %}, {% endif %} {%- endfor %} {{- ')' -}} {% if not loop.last %}, {% endif %} {%- endfor %} - {{- "<|eom|>" }} + {{- "]<|eot|>" }} +{%- endif %} +{#- Tool_response case: messages are from tool_response #} {%- elif message.role == "tool" or message.role == "ipython" %} {{- 
"<|header_start|>ipython<|header_end|>\n\n" }} {%- if message.content is string %} - {{- message.content | tojson }} + {{- message.content | tojson }} {%- else %} {%- for content in message['content'] %} {%- if content['type'] == 'text' %} - {{- content['text'] | tojson }} + {{- content['text'] | tojson }} {%- endif %} {%- endfor %} {%- endif %} - {{- "<|eom|>" }} + {{- "<|eot|>" }} {%- endif %} {%- endfor %} {%- if add_generation_prompt %} diff --git a/mkdocs.yaml b/mkdocs.yaml new file mode 100644 index 000000000000..52de643f5e2b --- /dev/null +++ b/mkdocs.yaml @@ -0,0 +1,130 @@ +site_name: vLLM +site_url: https://docs.vllm.ai +repo_url: https://github.com/vllm-project/vllm +exclude_docs: | + *.inc.md + *.template.md +theme: + name: material + logo: assets/logos/vllm-logo-only-light.ico + favicon: assets/logos/vllm-logo-only-light.ico + palette: + # Palette toggle for automatic mode + - media: "(prefers-color-scheme)" + toggle: + icon: material/brightness-auto + name: Switch to light mode + # Palette toggle for light mode + - media: "(prefers-color-scheme: light)" + scheme: default + primary: white + toggle: + icon: material/brightness-7 + name: Switch to dark mode + # Palette toggle for dark mode + - media: "(prefers-color-scheme: dark)" + scheme: slate + primary: black + toggle: + icon: material/brightness-2 + name: Switch to system preference + features: + - content.code.copy + - content.tabs.link + - navigation.tracking + - navigation.tabs + - navigation.sections + - navigation.prune + - navigation.top + - search.highlight + - search.share + - toc.follow + custom_dir: docs/mkdocs/overrides + +hooks: + - docs/mkdocs/hooks/remove_announcement.py + - docs/mkdocs/hooks/generate_examples.py + - docs/mkdocs/hooks/url_schemes.py + +# Required to stop api-autonav from raising an error +# https://github.com/tlambert03/mkdocs-api-autonav/issues/16 +nav: + - api + +plugins: + - meta + - search + - autorefs + - awesome-nav + # For API reference generation + - 
api-autonav: + modules: ["vllm"] + api_root_uri: "api" + exclude: + - "re:vllm\\._.*" # Internal modules + - "vllm.third_party" + - "vllm.vllm_flash_attn" + - mkdocstrings: + handlers: + python: + options: + show_symbol_type_heading: true + show_symbol_type_toc: true + filters: [] + summary: + modules: true + show_if_no_docstring: true + show_signature_annotations: true + separate_signature: true + show_overloads: true + signature_crossrefs: true + inventories: + - https://docs.python.org/3/objects.inv + - https://typing-extensions.readthedocs.io/en/latest/objects.inv + - https://docs.aiohttp.org/en/stable/objects.inv + - https://pillow.readthedocs.io/en/stable/objects.inv + - https://numpy.org/doc/stable/objects.inv + - https://pytorch.org/docs/stable/objects.inv + - https://psutil.readthedocs.io/en/stable/objects.inv + +markdown_extensions: + - attr_list + - md_in_html + - admonition + - pymdownx.details + # For content tabs + - pymdownx.superfences + - pymdownx.tabbed: + slugify: !!python/object/apply:pymdownx.slugs.slugify + kwds: + case: lower + alternate_style: true + # For code highlighting + - pymdownx.highlight: + anchor_linenums: true + line_spans: __span + pygments_lang_class: true + - pymdownx.inlinehilite + - pymdownx.snippets + # For emoji and icons + - pymdownx.emoji: + emoji_index: !!python/name:material.extensions.emoji.twemoji + emoji_generator: !!python/name:material.extensions.emoji.to_svg + # For in page [TOC] (not sidebar) + - toc: + permalink: true + # For math rendering + - mdx_math: + enable_dollar_delimiter: true + +extra_css: + - mkdocs/stylesheets/extra.css + +extra_javascript: + - mkdocs/javascript/run_llm_widget.js + - https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS_HTML + +# Makes the url format end in .html rather than act as a dir +# So index.md generates as index.html and is available under URL /index.html +# https://www.mkdocs.org/user-guide/configuration/#use_directory_urls +use_directory_urls: false diff --git 
a/pyproject.toml b/pyproject.toml index 069e295bfb93..62a734d795d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ requires = [ "setuptools-scm>=8.0", "torch == 2.7.0", "wheel", + "regex", "jinja2", ] build-backend = "setuptools.build_meta" @@ -35,12 +36,15 @@ dynamic = [ "version", "dependencies", "optional-dependencies"] [project.urls] Homepage="https://github.com/vllm-project/vllm" -Documentation="https://vllm.readthedocs.io/en/latest/" -Slack="http://slack.vllm.ai/" +Documentation="https://docs.vllm.ai/en/latest/" +Slack="https://slack.vllm.ai/" [project.scripts] vllm = "vllm.entrypoints.cli.main:main" +[project.entry-points."vllm.general_plugins"] +lora_filesystem_resolver = "vllm.plugins.lora_resolvers.filesystem_resolver:register_filesystem_resolver" + [tool.setuptools_scm] # no extra settings needed, presence enables setuptools-scm @@ -50,42 +54,29 @@ include = ["vllm*"] [tool.yapfignore] ignore_patterns = [ + ".buildkite/**", + "benchmarks/**", "build/**", + "examples/**", ] [tool.ruff] # Allow lines to be as long as 80. line-length = 80 -exclude = [ - # External file, leaving license intact - "examples/other/fp8/quantizer/quantize.py", - "vllm/vllm_flash_attn/flash_attn_interface.pyi" -] [tool.ruff.lint.per-file-ignores] "vllm/third_party/**" = ["ALL"] "vllm/version.py" = ["F401"] "vllm/_version.py" = ["ALL"] -# Python 3.8 typing. 
TODO: Remove these excludes after v1.0.0 -"vllm/adapter_commons/**/*.py" = ["UP006", "UP035"] +# Python 3.8 typing - skip V0 code "vllm/attention/**/*.py" = ["UP006", "UP035"] -"vllm/compilation/**/*.py" = ["UP006", "UP035"] "vllm/core/**/*.py" = ["UP006", "UP035"] -"vllm/device_allocator/**/*.py" = ["UP006", "UP035"] -"vllm/distributed/**/*.py" = ["UP006", "UP035"] "vllm/engine/**/*.py" = ["UP006", "UP035"] "vllm/executor/**/*.py" = ["UP006", "UP035"] -"vllm/lora/**/*.py" = ["UP006", "UP035"] -"vllm/model_executor/**/*.py" = ["UP006", "UP035"] -"vllm/platforms/**/*.py" = ["UP006", "UP035"] -"vllm/plugins/**/*.py" = ["UP006", "UP035"] -"vllm/profiler/**/*.py" = ["UP006", "UP035"] "vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"] "vllm/spec_decode/**/*.py" = ["UP006", "UP035"] -"vllm/transformers_utils/**/*.py" = ["UP006", "UP035"] -"vllm/triton_utils/**/*.py" = ["UP006", "UP035"] -"vllm/vllm_flash_attn/**/*.py" = ["UP006", "UP035"] "vllm/worker/**/*.py" = ["UP006", "UP035"] +# Python 3.8 typing - skip utils for ROCm "vllm/utils.py" = ["UP006", "UP035"] [tool.ruff.lint] @@ -102,6 +93,7 @@ select = [ "SIM", # isort # "I", + # flake8-logging-format "G", ] ignore = [ @@ -150,6 +142,11 @@ ignore-words-list = "dout, te, indicies, subtile, ElementE" skip = "tests/models/fixtures/*,tests/prompts/*,benchmarks/sonnet.txt,tests/lora/data/*,build/*,vllm/third_party/*" [tool.isort] +skip_glob = [ + ".buildkite/*", + "benchmarks/*", + "examples/*", +] use_parentheses = true skip_gitignore = true @@ -166,7 +163,16 @@ markers = [ [tool.pymarkdown] plugins.md004.style = "sublist" # ul-style +plugins.md007.indent = 4 # ul-indent +plugins.md007.start_indented = true # ul-indent plugins.md013.enabled = false # line-length plugins.md041.enabled = false # first-line-h1 plugins.md033.enabled = false # inline-html +plugins.md046.enabled = false # code-block-style plugins.md024.allow_different_nesting = true # no-duplicate-headers + +[tool.ty] +respect-ignore-files = true + 
+[tool.ty.environment] +python = "./.venv" diff --git a/requirements/build.txt b/requirements/build.txt index 5edc593b9270..320e5b892584 100644 --- a/requirements/build.txt +++ b/requirements/build.txt @@ -7,3 +7,4 @@ setuptools-scm>=8 torch==2.7.0 wheel jinja2>=3.1.6 +regex diff --git a/requirements/common.txt b/requirements/common.txt index d6f59ad0b1ab..625efc3366f4 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -1,3 +1,4 @@ +regex # Replace re for higher-performance regex matching cachetools psutil sentencepiece # Required for LLaMA tokenizer. @@ -7,7 +8,7 @@ tqdm blake3 py-cpuinfo transformers >= 4.51.1 -huggingface-hub[hf_xet] >= 0.30.0 # Required for Xet downloads. +huggingface-hub[hf_xet] >= 0.32.0 # Required for Xet downloads. tokenizers >= 0.21.1 # Required for fast incremental detokenization. protobuf # Required by LlamaTokenizer. fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint. @@ -19,10 +20,10 @@ pillow # Required for image processing prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer lm-format-enforcer >= 0.10.11, < 0.11 -llguidance >= 0.7.9, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64" +llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64" outlines == 0.1.11 lark == 1.2.2 -xgrammar == 0.1.18; platform_machine == "x86_64" or platform_machine == "aarch64" +xgrammar == 0.1.19; platform_machine == "x86_64" or platform_machine == "aarch64" typing_extensions >= 4.10 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 partial-json-parser # used for parsing partial JSON outputs @@ -40,7 +41,7 @@ compressed-tensors == 0.9.4 # required for compressed-tensors depyf==0.18.0 # required for profiling and debugging with compilation config cloudpickle # allows pickling lambda 
functions in model_executor/models/registry.py watchfiles # required for http server to monitor the updates of TLS files -python-json-logger # Used by logging as per examples/other/logging_configuration.md +python-json-logger # Used by logging as per examples/others/logging_configuration.md scipy # Required for phi-4-multimodal-instruct ninja # Required for xgrammar, rocm, tpu, xpu opentelemetry-sdk>=1.26.0 # vllm.tracing diff --git a/requirements/cpu.txt b/requirements/cpu.txt index 752931158a05..1213301584ce 100644 --- a/requirements/cpu.txt +++ b/requirements/cpu.txt @@ -2,11 +2,12 @@ -r common.txt # Dependencies for CPUs +packaging>=24.2 +setuptools>=77.0.3,<80.0.0 --extra-index-url https://download.pytorch.org/whl/cpu torch==2.7.0+cpu; platform_machine == "x86_64" torch==2.7.0; platform_system == "Darwin" torch==2.7.0; platform_machine == "ppc64le" or platform_machine == "aarch64" -torch==2.7.0.dev20250304; platform_machine == "s390x" # required for the image processor of minicpm-o-2_6, this must be updated alongside torch torchaudio; platform_machine != "ppc64le" and platform_machine != "s390x" @@ -19,3 +20,7 @@ datasets # for benchmark scripts # cpu cannot use triton 3.3.0 triton==3.2.0; platform_machine == "x86_64" + +# Intel Extension for PyTorch, only for x86_64 CPUs +intel-openmp==2024.2.1; platform_machine == "x86_64" +intel_extension_for_pytorch==2.7.0; platform_machine == "x86_64" diff --git a/requirements/docs.txt b/requirements/docs.txt index 9c267edaceaf..64c70cb65c55 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -1,19 +1,9 @@ -sphinx==7.4.7 -sphinx-argparse==0.5.2 -sphinx-book-theme==1.1.4 -sphinx-copybutton==0.5.2 -sphinx-design==0.6.1 -sphinx-togglebutton==0.3.2 -myst-parser==3.0.1 # `myst-parser==4.0.1` breaks inline code in titles -msgspec -snowballstemmer<3 # https://github.com/snowballstem/snowball/issues/229 -commonmark # Required by sphinx-argparse when using :markdownhelp: - -# Custom autodoc2 is necessary for faster 
docstring processing -# see: https://github.com/sphinx-extensions2/sphinx-autodoc2/issues/33#issuecomment-2856386035 -git+https://github.com/hmellor/sphinx-autodoc2.git # sphinx-autodoc2==0.5.0 - -# packages to install to build the documentation -cachetools --f https://download.pytorch.org/whl/cpu -torch \ No newline at end of file +mkdocs +mkdocs-api-autonav +mkdocs-material +mkdocstrings-python +mkdocs-gen-files +mkdocs-awesome-nav +python-markdown-math +regex +ruff diff --git a/requirements/nightly_torch_test.txt b/requirements/nightly_torch_test.txt index 3aebcaa623c0..e9b466d3a82d 100644 --- a/requirements/nightly_torch_test.txt +++ b/requirements/nightly_torch_test.txt @@ -38,4 +38,4 @@ matplotlib # required for qwen-vl test # required for Multi-Modal Models Test (Standard) num2words # required for smolvlm test pqdm -timm # required for internvl test +timm # required for internvl test \ No newline at end of file diff --git a/requirements/rocm-test.txt b/requirements/rocm-test.txt index 52fbf787f1df..25f950a99ece 100644 --- a/requirements/rocm-test.txt +++ b/requirements/rocm-test.txt @@ -1,3 +1,5 @@ +# Common dependencies +-r common.txt # entrypoints test # librosa==0.10.2.post1 # required by audio tests in entrypoints/openai @@ -20,4 +22,10 @@ decord==0.6.0 #sentence-transformers # required by entrypoints/openai/test_score.py sentence-transformers==3.4.1 +# Basic Models Test +matplotlib==3.10.3 + +# Multi-Modal Models Test (Extended) 3 +blobfile==3.0.0 + diff --git a/requirements/test.in b/requirements/test.in index cdc7c563f087..87af61769038 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -33,6 +33,7 @@ num2words # required for smolvlm test opencv-python-headless >= 4.11.0 # required for video test datamodel_code_generator # required for minicpm3 test lm-eval[api]==0.4.8 # required for model evaluation test +mteb>=1.38.11, <2 # required for mteb test transformers==4.51.3 tokenizers==0.21.1 huggingface-hub[hf_xet]>=0.30.0 # Required for Xet 
downloads. diff --git a/requirements/test.txt b/requirements/test.txt index 9a15d9a0d824..89d477017342 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -99,6 +99,7 @@ datasets==3.0.2 # via # evaluate # lm-eval + # mteb decorator==5.1.1 # via librosa dill==0.3.8 @@ -124,6 +125,8 @@ email-validator==2.2.0 # via pydantic encodec==0.1.1 # via vocos +eval-type-backport==0.2.2 + # via mteb evaluate==0.4.3 # via lm-eval fastparquet==2024.11.0 @@ -291,6 +294,8 @@ msgpack==1.1.0 # via # librosa # ray +mteb==1.38.11 + # via -r requirements/test.in multidict==6.1.0 # via # aiohttp @@ -331,6 +336,7 @@ numpy==1.26.4 # librosa # matplotlib # mistral-common + # mteb # numba # numexpr # opencv-python-headless @@ -443,6 +449,8 @@ plotly==5.24.1 # via genai-perf pluggy==1.5.0 # via pytest +polars==1.29.0 + # via mteb pooch==1.8.2 # via librosa portalocker==2.10.1 @@ -476,6 +484,7 @@ pydantic==2.9.2 # via # datamodel-code-generator # mistral-common + # mteb pydantic-core==2.23.4 # via pydantic pygments==2.18.0 @@ -522,6 +531,8 @@ python-dateutil==2.9.0.post0 # typepy python-rapidjson==1.20 # via tritonclient +pytrec-eval-terrier==0.5.7 + # via mteb pytz==2024.2 # via # pandas @@ -564,6 +575,7 @@ requests==2.32.3 # huggingface-hub # lm-eval # mistral-common + # mteb # pooch # ray # responses @@ -580,6 +592,7 @@ rfc3987==1.3.8 rich==13.9.4 # via # genai-perf + # mteb # typer rouge-score==0.1.2 # via lm-eval @@ -607,16 +620,20 @@ scikit-learn==1.5.2 # via # librosa # lm-eval + # mteb # sentence-transformers scipy==1.13.1 # via # librosa + # mteb # scikit-learn # sentence-transformers # statsmodels # vocos sentence-transformers==3.2.1 - # via -r requirements/test.in + # via + # -r requirements/test.in + # mteb sentencepiece==0.2.0 # via mistral-common setuptools==77.0.3 @@ -696,6 +713,7 @@ torch==2.7.0+cu128 # fastsafetensors # lm-eval # mamba-ssm + # mteb # peft # runai-model-streamer # sentence-transformers @@ -720,6 +738,7 @@ tqdm==4.66.6 # evaluate # huggingface-hub 
# lm-eval + # mteb # nltk # peft # pqdm @@ -759,6 +778,7 @@ typing-extensions==4.12.2 # huggingface-hub # librosa # mistral-common + # mteb # pqdm # pydantic # pydantic-core diff --git a/requirements/tpu.txt b/requirements/tpu.txt index 11501bc5d92f..3b204a8f9905 100644 --- a/requirements/tpu.txt +++ b/requirements/tpu.txt @@ -18,9 +18,9 @@ setuptools==78.1.0 --find-links https://storage.googleapis.com/libtpu-releases/index.html --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -torch==2.8.0.dev20250430 -torchvision==0.22.0.dev20250430 -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" +torch==2.8.0.dev20250518 +torchvision==0.22.0.dev20250518 +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" diff --git a/setup.py b/setup.py old mode 100755 new mode 100644 index 7675fbdf3efe..180f2f978501 --- a/setup.py +++ b/setup.py @@ -5,12 +5,12 @@ import json import logging import os -import re import subprocess import sys from pathlib 
import Path from shutil import which +import regex as re import torch from packaging.version import Version, parse from setuptools import Extension, setup @@ -389,7 +389,6 @@ def run(self) -> None: # vllm_flash_attn python code: # Regex from # `glob.translate('vllm/vllm_flash_attn/**/*.py', recursive=True)` - import re compiled_regex = re.compile( r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py") file_members += list( diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index 48e2e31e5db8..b6f44871497c 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -41,7 +41,7 @@ def __init__(self): self.abort_request_calls = 0 self.request_id = None # Ugly, remove dependency when possible - self.parallel_config = ParallelConfig(1, 1, False) + self.parallel_config = ParallelConfig() self.model_config = MockModelConfig() async def step_async(self, virtual_engine): diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 9f3b0e8ae079..86b5e1e0ab7c 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -8,12 +8,13 @@ from unittest.mock import Mock import pytest +import torch -from vllm import LLM +from vllm import LLM, envs from vllm.platforms import current_platform from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1 -from ..conftest import VllmRunner +from ..conftest import HfRunner, VllmRunner from ..models.utils import check_outputs_equal from ..utils import multi_gpu_test @@ -43,11 +44,26 @@ def test_vllm_gc_ed(): assert weak_llm() is None +def _fix_prompt_embed_outputs( + vllm_outputs: list[tuple[list[int], str]], hf_model: HfRunner, + example_prompts: list[str]) -> list[tuple[list[int], str]]: + fixed_vllm_outputs = [] + for vllm_output, hf_input, prompt in zip( + vllm_outputs, hf_model.get_inputs(example_prompts), + 
example_prompts): + hf_input_ids = hf_input["input_ids"].tolist()[0] + fixed_vllm_outputs.append( + (hf_input_ids + vllm_output[0][len(hf_input_ids):], + prompt + vllm_output[1])) + return fixed_vllm_outputs + + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("backend", ["FLASH_ATTN"]) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("enforce_eager", [False]) +@pytest.mark.parametrize("enable_prompt_embeds", [True, False]) def test_models( monkeypatch: pytest.MonkeyPatch, hf_runner, @@ -56,8 +72,13 @@ def test_models( dtype: str, max_tokens: int, enforce_eager: bool, + enable_prompt_embeds: bool, ) -> None: + if enable_prompt_embeds and envs.is_set( + "VLLM_USE_V1") and envs.VLLM_USE_V1: + pytest.skip("enable_prompt_embeds is not supported in v1.") + if backend == "FLASHINFER" and current_platform.is_rocm(): pytest.skip("Flashinfer does not support ROCm/HIP.") @@ -78,14 +99,25 @@ def test_models( with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + if enable_prompt_embeds: + with torch.no_grad(): + prompt_embeds = hf_model.get_prompt_embeddings( + example_prompts) with VllmRunner(model, max_model_len=8192, dtype=dtype, enforce_eager=enforce_eager, + enable_prompt_embeds=enable_prompt_embeds, gpu_memory_utilization=0.7) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, - max_tokens) + if enable_prompt_embeds: + vllm_outputs = vllm_model.generate_greedy( + prompt_embeds, max_tokens) + vllm_outputs = _fix_prompt_embed_outputs( + vllm_outputs, hf_model, example_prompts) + else: + vllm_outputs = vllm_model.generate_greedy( + example_prompts, max_tokens) check_outputs_equal( outputs_0_lst=hf_outputs, @@ -108,6 +140,7 @@ def test_models( ("distilbert/distilgpt2", "mp", "FLASHINFER", "A100"), ("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100"), ]) 
+@pytest.mark.parametrize("enable_prompt_embeds", [True, False]) def test_models_distributed( monkeypatch: pytest.MonkeyPatch, hf_runner, @@ -117,14 +150,22 @@ def test_models_distributed( distributed_executor_backend: str, attention_backend: str, test_suite: str, + enable_prompt_embeds: bool, ) -> None: + if enable_prompt_embeds and envs.is_set( + "VLLM_USE_V1") and envs.VLLM_USE_V1: + pytest.skip("enable_prompt_embeds is not supported in v1.") + if test_suite != TARGET_TEST_SUITE: pytest.skip(f"Skip test for {test_suite}") with monkeypatch.context() as monkeypatch_context: if model == "meta-llama/Llama-3.2-1B-Instruct" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4": # noqa - # test Ray Compiled Graph + if enable_prompt_embeds: + pytest.skip( + "enable_prompt_embeds does not work with ray compiled dag." + ) monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1") monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1") @@ -147,12 +188,26 @@ def test_models_distributed( dtype=dtype, tensor_parallel_size=2, distributed_executor_backend=distributed_executor_backend, + enable_prompt_embeds=enable_prompt_embeds, + gpu_memory_utilization=0.7, ) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, - max_tokens) - - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + if enable_prompt_embeds: + with hf_runner(model, dtype=dtype) as hf_model: + with torch.no_grad(): + prompt_embeds = hf_model.get_prompt_embeddings( + example_prompts) + vllm_outputs = vllm_model.generate_greedy( + prompt_embeds, max_tokens) + vllm_outputs = _fix_prompt_embed_outputs( + vllm_outputs, hf_model, example_prompts) + hf_outputs = hf_model.generate_greedy( + example_prompts, max_tokens) + else: + vllm_outputs = vllm_model.generate_greedy( + example_prompts, max_tokens) + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy( 
+ example_prompts, max_tokens) check_outputs_equal( outputs_0_lst=hf_outputs, diff --git a/tests/compile/backend.py b/tests/compile/backend.py index a21e8eca3a6e..5a02c4e2b378 100644 --- a/tests/compile/backend.py +++ b/tests/compile/backend.py @@ -5,6 +5,8 @@ from torch import fx +from vllm.compilation.fx_utils import (find_specified_fn, + find_specified_fn_maybe) from vllm.compilation.inductor_pass import InductorPass from vllm.config import get_current_vllm_config @@ -44,3 +46,19 @@ def post_pass(self, graph: fx.Graph): self.graph_post_pass = deepcopy(graph) # assign by reference, will reflect the final state of the graph self.final_graph = graph + + def check_before_ops(self, ops, + find_fn=find_specified_fn, \ + find_fn_maybe=find_specified_fn_maybe, \ + ops_fully_replaced=True): + for op in ops: + find_fn(self.graph_pre_pass.nodes, op) + if ops_fully_replaced: + assert find_fn_maybe(self.graph_post_pass.nodes, op) is None + + def check_after_ops(self, ops, + find_fn=find_specified_fn, \ + find_fn_maybe=find_specified_fn_maybe): + for op in ops: + find_fn(self.graph_post_pass.nodes, op) + assert find_fn_maybe(self.graph_pre_pass.nodes, op) is None diff --git a/tests/compile/test_async_tp.py b/tests/compile/test_async_tp.py new file mode 100644 index 000000000000..8e4e0ba83579 --- /dev/null +++ b/tests/compile/test_async_tp.py @@ -0,0 +1,248 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json + +import pytest +import torch + +import vllm.envs as envs +from vllm.compilation.collective_fusion import AsyncTPPass +from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig, + PassConfig, VllmConfig) +from vllm.distributed import (tensor_model_parallel_all_gather, + tensor_model_parallel_reduce_scatter) +from vllm.distributed.parallel_state import (init_distributed_environment, + initialize_model_parallel) +from vllm.platforms import current_platform +from vllm.utils import update_environment_variables + +from ..models.registry import 
HF_EXAMPLE_MODELS +from ..utils import (compare_two_settings, create_new_process_for_each_test, + multi_gpu_test) +from .backend import TestBackend + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + + +class TestMMRSModel(torch.nn.Module): + + def __init__(self, hidden_size=16): + super().__init__() + self.hidden_size = hidden_size + self.gate_proj = torch.nn.Parameter(torch.empty( + (self.hidden_size * 2, hidden_size)), + requires_grad=False) + # Initialize weights + torch.nn.init.normal_(self.gate_proj, std=0.02) + + def forward(self, hidden_states): + """ + Forward pass implementing the mm + reduce scatter in the FX graph + + """ + # Reshape input + view = hidden_states.reshape(-1, self.hidden_size) + + # matrix multiplication + permute = self.gate_proj.permute(1, 0) + mm = torch.mm(view, permute) + reduce_scatter = tensor_model_parallel_reduce_scatter(mm, dim=0) + return reduce_scatter + + def ops_in_model_before(self): + return [torch.ops.vllm.reduce_scatter.default] + + def ops_in_model_after(self): + return [torch.ops.symm_mem.fused_matmul_reduce_scatter.default] + + +class TestAGMMModel(torch.nn.Module): + + def __init__(self, hidden_size=16): + super().__init__() + self.hidden_size = hidden_size + self.weight = torch.nn.Parameter(torch.empty( + (hidden_size, hidden_size)), + requires_grad=False) + # Initialize weights + torch.nn.init.normal_(self.weight, std=0.02) + + def forward(self, hidden_states): + """ + Forward pass implementing the mm + all gather in the FX graph + """ + # Reshape input + view = hidden_states.reshape(-1, self.hidden_size) + all_gather = tensor_model_parallel_all_gather(view, dim=0) + permute = self.weight.permute(1, 0) + mm = torch.mm(all_gather, permute) + return mm + + def ops_in_model_before(self): + return [torch.ops.vllm.all_gather.default] + + def ops_in_model_after(self): + return [torch.ops.symm_mem.fused_all_gather_matmul.default] + + 
+@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("test_model", [TestMMRSModel, TestAGMMModel]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seq_len", [16]) +@pytest.mark.parametrize("hidden_size", [16]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], + reason="Only test on CUDA") +def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int, + hidden_size: int, dtype: torch.dtype): + num_processes = 2 + + def run_torch_spawn(fn, nprocs): + # need to use torch.mp.spawn otherwise will have problems with + # torch.distributed and cuda + torch.multiprocessing.spawn(fn, + args=(num_processes, test_model, + batch_size, seq_len, hidden_size, + dtype), + nprocs=nprocs) + + run_torch_spawn(async_tp_pass_on_test_model, num_processes) + + +def async_tp_pass_on_test_model(local_rank: int, world_size: int, + test_model_cls: torch.nn.Module, + batch_size: int, seq_len: int, + hidden_size: int, dtype: torch.dtype): + current_platform.seed_everything(0) + + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + torch.set_default_device(device) + torch.set_default_dtype(dtype) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': '12345', + }) + + # initialize distributed + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # configure vllm config for SequenceParallelismPass + vllm_config = VllmConfig() + vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig( + enable_async_tp=True, ), ) + vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) + + # this is a fake model name to construct the model config + # in the vllm_config, it's not really used. 
+ model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" + vllm_config.model_config = ModelConfig(model=model_name, + task="auto", + tokenizer=model_name, + tokenizer_mode="auto", + trust_remote_code=True, + dtype=dtype, + seed=42) + + async_tp_pass = AsyncTPPass(vllm_config) + backend = TestBackend(async_tp_pass) + + model = test_model_cls(hidden_size) + + hidden_states = torch.randn((batch_size * seq_len, hidden_size), + dtype=dtype, + requires_grad=False) + + compiled_model = torch.compile(model, backend=backend) + compiled_model(hidden_states) + + # In pre-nodes, all gather or reduce scatter should exist, + # fused_matmul_reduce_scatter or fused_all_gather_matmul should not + backend.check_before_ops(model.ops_in_model_before(), + ops_fully_replaced=False) + + # In post-nodes, fused_matmul_reduce_scatter or \ + # fused_all_gather_matmul should exist + backend.check_after_ops(model.ops_in_model_after()) + + +@create_new_process_for_each_test() +@pytest.mark.parametrize("model_id", ["meta-llama/Llama-3.2-1B-Instruct"]) +@pytest.mark.parametrize("tp_size", [2]) +@pytest.mark.parametrize("async_tp_enabled", [True]) +@pytest.mark.parametrize("distributed_backend", ["mp"]) +@pytest.mark.parametrize("eager_mode", [False, True]) +def test_async_tp_pass_correctness( + model_id: str, + tp_size: int, + async_tp_enabled: bool, + distributed_backend: str, + eager_mode: bool, + num_gpus_available: int, +): + model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) + model_info.check_transformers_version(on_fail="skip") + model_info.check_available_online(on_fail="skip") + + pp_size = 1 + if num_gpus_available < tp_size: + pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") + + common_args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--max-num-seqs", + "8", + ] + if eager_mode: + common_args.append("--enforce-eager") + + compilation_config = { + 'level': 3, + 'compile_sizes': [2, 4, 8], + 'splitting_ops': [], + 'pass_config': { + 'enable_async_tp': 
async_tp_enabled + }, + } + + async_tp_env = tp_env = { + "VLLM_USE_V1": "1", + } + + aysnc_tp_args = [ + *common_args, + "--tensor-parallel-size", + str(tp_size), + "--distributed-executor-backend", + distributed_backend, + "--compilation_config", + json.dumps(compilation_config), + ] + + tp_args = [ + *common_args, + "--tensor-parallel-size", + str(tp_size), + "--distributed-executor-backend", + "mp", + ] + + compare_two_settings(model_id, + aysnc_tp_args, + tp_args, + async_tp_env, + tp_env, + method="generate") diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index c09406385987..397517b8665b 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -9,7 +9,7 @@ from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams -from vllm.config import CompilationConfig, CompilationLevel +from vllm.config import CompilationConfig, CompilationLevel, PassConfig from vllm.platforms import current_platform from ..utils import create_new_process_for_each_test @@ -95,9 +95,6 @@ def test_full_graph( run_model(optimization_level, model, model_kwargs) -PassConfig = CompilationConfig.PassConfig - - # TODO(luka) add other supported compilation config scenarios here @pytest.mark.parametrize( "compilation_config, model_info", diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index 1e1364ce7bf6..5d38ff91490e 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -11,7 +11,7 @@ kFp8DynamicTokenSym, kFp8StaticTensorSym) from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.config import CompilationConfig, VllmConfig +from vllm.config import CompilationConfig, PassConfig, VllmConfig from .backend import TestBackend @@ -53,9 +53,8 @@ def test_fix_functionalization(model: str, quant_key: 
QuantKey, torch.set_default_device("cuda") vllm_config = VllmConfig() - vllm_config.compilation_config = CompilationConfig(pass_config= \ - CompilationConfig.PassConfig(enable_fusion=do_fusion, - enable_noop=True)) + vllm_config.compilation_config = CompilationConfig( + pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True)) noop_pass = NoOpEliminationPass(vllm_config) fusion_pass = FusionPass.instance(vllm_config) act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 6a696fe0226b..509593e7328d 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -9,7 +9,8 @@ FusionPass, QuantKey) from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.config import CompilationConfig, CompilationLevel, VllmConfig +from vllm.config import (CompilationConfig, CompilationLevel, PassConfig, + VllmConfig) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity) @@ -28,6 +29,10 @@ def __init__(self, hidden_size: int, eps: float, static: bool, self.cutlass_fp8_enabled = cutlass_fp8_enabled self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] + self.key = QuantKey(dtype=FP8_DTYPE, + static=static, + per_tensor=static, + symmetric=True) if static: self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] else: @@ -58,6 +63,15 @@ def forward(self, x): y3, resid = self.norm[2](x3, resid) # use resid here return y3 + def ops_in_model_before(self): + return [QUANT_OPS[self.key]] + + def ops_in_model_after(self): + return [ + FUSED_OPS[FusedRMSQuantKey(self.key, False)], + FUSED_OPS[FusedRMSQuantKey(self.key, True)] + ] + @pytest.mark.parametrize("dtype", 
[torch.float16, torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [64, 3392, 4096]) @@ -78,8 +92,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, vllm_config = VllmConfig(compilation_config=CompilationConfig( level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"])) vllm_config.compilation_config.pass_config = \ - CompilationConfig.PassConfig(enable_fusion=True, - enable_noop=True) + PassConfig(enable_fusion=True, enable_noop=True) with vllm.config.set_current_vllm_config(vllm_config): # Reshape pass is needed for the fusion pass to work noop_pass = NoOpEliminationPass(vllm_config) @@ -107,25 +120,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) - # Check substitution worked - pre_nodes = backend.graph_pre_pass.nodes - post_nodes = backend.graph_post_pass.nodes - - # static is per-tensor, dynamic is per-token - key = QuantKey(dtype=FP8_DTYPE, - static=static, - per_tensor=static, - symmetric=True) - rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)] - add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)] - fp8_quant = QUANT_OPS[key] - # In pre-nodes, fp8 quant should be there and fused kernels should not - assert find_auto_fn_maybe(pre_nodes, rms_quant) is None - assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None - find_auto_fn(pre_nodes, fp8_quant) + backend.check_before_ops(model.ops_in_model_before(), find_auto_fn, + find_auto_fn_maybe) # In post-nodes, fused kernels should be there and fp8 quant should not - find_auto_fn(post_nodes, rms_quant) - find_auto_fn(post_nodes, add_rms_quant) - assert find_auto_fn_maybe(post_nodes, fp8_quant) is None + backend.check_after_ops(model.ops_in_model_after(), find_auto_fn, + find_auto_fn_maybe) diff --git a/tests/compile/test_pass_manager.py b/tests/compile/test_pass_manager.py index 673ebe8b6fdc..b630d0e85d31 100644 --- a/tests/compile/test_pass_manager.py +++ 
b/tests/compile/test_pass_manager.py @@ -22,7 +22,7 @@ def test_bad_callable(): pass_manager.configure(config) with pytest.raises(AssertionError): - pass_manager.add(simple_callable) # noqa, type wrong on purpose + pass_manager.add(simple_callable) # Pass that inherits from InductorPass diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index 79f5486dadcd..2cd7ebaacec0 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -5,12 +5,10 @@ import vllm.envs as envs from vllm.compilation.fix_functionalization import FixFunctionalizationPass -from vllm.compilation.fx_utils import (find_auto_fn, find_auto_fn_maybe, - find_specified_fn, - find_specified_fn_maybe, is_func) +from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.sequence_parallelism import SequenceParallelismPass from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig, - VllmConfig) + PassConfig, VllmConfig) from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import (init_distributed_environment, initialize_model_parallel) @@ -21,17 +19,6 @@ from ..utils import multi_gpu_test from .backend import TestBackend -OPS_IN_MODEL_BEFORE = [ - torch.ops.vllm.all_reduce.default, -] - -OPS_IN_MODEL_AFTER = [ - torch.ops.vllm.reduce_scatter.default, - torch.ops.vllm.all_gather.default, -] - -OPS_IN_MODEL = [torch.ops._C.fused_add_rms_norm.default] - prompts = [ "Hello, my name is", "The president of the United States is", @@ -78,6 +65,18 @@ def forward(self, hidden_states, residual): return norm_output, residual_output + def ops_in_model_before(self): + return [torch.ops.vllm.all_reduce.default] + + def ops_in_model_after(self): + return [ + torch.ops.vllm.reduce_scatter.default, + torch.ops.vllm.all_gather.default + ] + + def ops_in_model(self): + return [torch.ops._C.fused_add_rms_norm.default] + 
@multi_gpu_test(num_gpus=2) @pytest.mark.parametrize("batch_size", [8]) @@ -126,9 +125,8 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int, # configure vllm config for SequenceParallelismPass vllm_config = VllmConfig() - vllm_config.compilation_config = CompilationConfig( - pass_config=CompilationConfig.PassConfig( - enable_sequence_parallelism=True, ), ) + vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig( + enable_sequence_parallelism=True)) vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) # this is a fake model name to construct the model config @@ -157,26 +155,16 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int, compiled_model_func = torch.compile(model, backend=backend_func) compiled_model_func(hidden_states, residual) - # Check substitution worked - pre_nodes = backend_no_func.graph_pre_pass.nodes - post_nodes = backend_no_func.graph_post_pass.nodes - # In pre-nodes, all reduce should be there, # reduce scatter and all gather should not - for op in OPS_IN_MODEL_BEFORE: - find_specified_fn(pre_nodes, op) - for op in OPS_IN_MODEL_AFTER: - assert find_specified_fn_maybe(pre_nodes, op) is None + backend_no_func.check_before_ops(model.ops_in_model_before()) # In post-nodes, reduce scatter and all gather should be there, # all reduce should not - for op in OPS_IN_MODEL_AFTER: - find_specified_fn(post_nodes, op) - for op in OPS_IN_MODEL_BEFORE: - assert find_specified_fn_maybe(post_nodes, op) is None + backend_no_func.check_after_ops(model.ops_in_model_after()) # check if the functionalization pass is applied - for op in OPS_IN_MODEL: + for op in model.ops_in_model(): find_auto_fn(backend_no_func.graph_post_pass.nodes, op) assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None # noqa: E501 @@ -184,7 +172,7 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int, # make sure the ops were all de-functionalized found = 
dict() for node in backend_func.graph_post_pass.nodes: - for op in OPS_IN_MODEL: + for op in model.ops_in_model(): if is_func(node, op): found[op] = True - assert all(found[op] for op in OPS_IN_MODEL) + assert all(found[op] for op in model.ops_in_model()) diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index 313848372e04..9eae48d60f36 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -6,7 +6,7 @@ from vllm._custom_ops import scaled_fp8_quant from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe -from vllm.config import CompilationConfig, VllmConfig +from vllm.config import CompilationConfig, PassConfig, VllmConfig from vllm.model_executor.layers.activation import SiluAndMul from .backend import TestBackend @@ -27,8 +27,8 @@ def forward(self, x): @pytest.mark.parametrize("num_tokens", [256]) @pytest.mark.parametrize("hidden_size", [64]) -@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", - reason="Only test on CUDA") +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], + reason="Only test on CUDA and ROCm") def test_fusion_silu_and_mul_quant(num_tokens, hidden_size): torch.set_default_device("cuda") torch.set_default_dtype(torch.float16) @@ -36,8 +36,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size): # Reshape pass is needed for the fusion pass to work config = VllmConfig() config.compilation_config = CompilationConfig( - pass_config=CompilationConfig.PassConfig(enable_fusion=True, - enable_reshape=True)) + pass_config=PassConfig(enable_fusion=True, enable_noop=True)) fusion_pass = ActivationQuantFusionPass(config) backend = TestBackend(fusion_pass) diff --git a/tests/conftest.py b/tests/conftest.py index fa979f1093be..19c2c6247129 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -355,10 +355,16 @@ def __init__( 
**model_kwargs, ) + # in case some unquantized custom models are not in same dtype + if (getattr(model, "quantization_method", None) is None + and any(p.dtype != self.dtype + for p in model.parameters())): + model = model.to(dtype=self.dtype) + if (getattr(model, "quantization_method", None) != "bitsandbytes" and len({p.device for p in model.parameters()}) < 2): - model = model.to(self.device) + model = model.to(device=self.device) self.model = model @@ -424,6 +430,15 @@ def get_inputs( return all_inputs + def get_prompt_embeddings(self, prompts: list[str]) -> list[torch.Tensor]: + all_inputs = self.get_inputs(prompts) + embeddings = [] + for inputs in all_inputs: + input_ids = self.wrap_device(inputs)["input_ids"] + embedding = self.model.get_input_embeddings()(input_ids).squeeze(0) + embeddings.append(embedding) + return embeddings + def classify(self, prompts: list[str]) -> list[str]: # output is final logits all_inputs = self.get_inputs(prompts) diff --git a/tests/distributed/test_events.py b/tests/distributed/test_events.py index 15bcfdb8555f..8de1aa20eabd 100644 --- a/tests/distributed/test_events.py +++ b/tests/distributed/test_events.py @@ -119,13 +119,12 @@ def test_topic_filtering(publisher_config): """ publisher_config.replay_endpoint = None - cfg = publisher_config.model_copy() - cfg.topic = "foo" - pub = EventPublisherFactory.create(cfg) + publisher_config.topic = "foo" + pub = EventPublisherFactory.create(publisher_config) from .conftest import MockSubscriber - sub_foo = MockSubscriber(cfg.endpoint, None, "foo") - sub_bar = MockSubscriber(cfg.endpoint, None, "bar") + sub_foo = MockSubscriber(publisher_config.endpoint, None, "foo") + sub_bar = MockSubscriber(publisher_config.endpoint, None, "bar") try: time.sleep(0.1) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 9c90fe381bb2..5346d67b10d1 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ 
b/tests/distributed/test_pipeline_parallel.py @@ -185,7 +185,7 @@ def iter_params(self, model_id: str): "mosaicml/mpt-7b": PPTestSettings.fast(), "nvidia/Minitron-8B-Base": PPTestSettings.fast(), "allenai/OLMo-1B-hf": PPTestSettings.fast(), - "shanearora/OLMo-7B-1124-hf": PPTestSettings.fast(), + "allenai/OLMo-2-0425-1B": PPTestSettings.fast(), "allenai/OLMoE-1B-7B-0924-Instruct": PPTestSettings.fast(), "facebook/opt-iml-max-1.3b": PPTestSettings.fast(), "OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(), diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py index 19497ad9c140..c9eba2b43788 100644 --- a/tests/distributed/test_sequence_parallel.py +++ b/tests/distributed/test_sequence_parallel.py @@ -26,6 +26,7 @@ class ParallelSetup(NamedTuple): tp_size: int + pp_size: int sp_enabled: bool eager_mode: bool chunked_prefill: bool @@ -60,6 +61,7 @@ def __post_init__(self): def detailed( *, tp_base: int = 2, + pp_base: int = 1, multi_node_only: bool = False, task: TaskOption = "auto", load_format: Optional[str] = None, @@ -67,18 +69,42 @@ def detailed( return SPTestSettings( parallel_setups=[ ParallelSetup(tp_size=tp_base, + pp_size=pp_base, sp_enabled=True, eager_mode=False, chunked_prefill=False), ParallelSetup(tp_size=tp_base, + pp_size=pp_base, sp_enabled=True, eager_mode=False, chunked_prefill=True), ParallelSetup(tp_size=tp_base, + pp_size=pp_base, sp_enabled=True, eager_mode=True, chunked_prefill=False), ParallelSetup(tp_size=tp_base, + pp_size=pp_base, + sp_enabled=True, + eager_mode=True, + chunked_prefill=True), + ParallelSetup(tp_size=tp_base, + pp_size=2 * pp_base, + sp_enabled=True, + eager_mode=False, + chunked_prefill=False), + ParallelSetup(tp_size=tp_base, + pp_size=2 * pp_base, + sp_enabled=True, + eager_mode=False, + chunked_prefill=True), + ParallelSetup(tp_size=tp_base, + pp_size=2 * pp_base, + sp_enabled=True, + eager_mode=True, + chunked_prefill=False), + ParallelSetup(tp_size=tp_base, + pp_size=2 
* pp_base, sp_enabled=True, eager_mode=True, chunked_prefill=True) @@ -94,6 +120,7 @@ def detailed( def fast( *, tp_base: int = 2, + pp_base: int = 1, task: TaskOption = "auto", multi_node_only: bool = False, load_format: Optional[str] = None, @@ -101,6 +128,12 @@ def fast( return SPTestSettings( parallel_setups=[ ParallelSetup(tp_size=tp_base, + pp_size=pp_base, + sp_enabled=True, + eager_mode=False, + chunked_prefill=False), + ParallelSetup(tp_size=tp_base, + pp_size=2 * pp_base, sp_enabled=True, eager_mode=False, chunked_prefill=False), @@ -136,6 +169,7 @@ def _compare_sp( ): ( tp_size, + pp_size, sp_enabled, eager_mode, chunked_prefill, @@ -167,7 +201,6 @@ def _compare_sp( else: model_info.check_available_online(on_fail="skip") - pp_size = 1 if num_gpus_available < tp_size * pp_size: pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") if VLLM_MULTI_NODE and distributed_backend == "mp": @@ -206,7 +239,7 @@ def _compare_sp( 'compile_sizes': [4, 8], 'splitting_ops': [], 'pass_config': { - 'enable_sequence_parallism': sp_enabled, + 'enable_sequence_parallelism': sp_enabled, 'enable_noop': True, 'enable_fusion': True, }, @@ -223,7 +256,7 @@ def _compare_sp( "--distributed-executor-backend", distributed_backend, "--compilation_config", - str(compilation_config), + json.dumps(compilation_config), ] tp_env = { @@ -256,7 +289,7 @@ def _compare_sp( SP_TEXT_GENERATION_MODELS = { # [Decoder-only] - "meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.detailed(), + "meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.fast(), } SP_TEST_MODELS = [ diff --git a/tests/distributed/test_shm_broadcast.py b/tests/distributed/test_shm_broadcast.py index 711c2441f34b..f9eacc11d75f 100644 --- a/tests/distributed/test_shm_broadcast.py +++ b/tests/distributed/test_shm_broadcast.py @@ -9,7 +9,7 @@ from vllm.distributed.device_communicators.shm_broadcast import MessageQueue from vllm.distributed.utils import StatelessProcessGroup -from vllm.utils import get_ip, get_open_port, 
update_environment_variables +from vllm.utils import get_open_port, update_environment_variables def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]: @@ -60,12 +60,12 @@ def worker_fn(): rank = dist.get_rank() if rank == 0: port = get_open_port() - ip = get_ip() + ip = '127.0.0.1' dist.broadcast_object_list([ip, port], src=0) else: recv = [None, None] dist.broadcast_object_list(recv, src=0) - ip, port = recv + ip, port = recv # type: ignore stateless_pg = StatelessProcessGroup.create(ip, port, rank, dist.get_world_size()) @@ -107,10 +107,10 @@ def worker_fn(): if pg == dist.group.WORLD: dist.barrier() - print("torch distributed passed the test!") + print(f"torch distributed passed the test! Rank {rank}") else: pg.barrier() - print("StatelessProcessGroup passed the test!") + print(f"StatelessProcessGroup passed the test! Rank {rank}") def test_shm_broadcast(): diff --git a/tests/distributed/test_torchrun_example.py b/tests/distributed/test_torchrun_example.py index 0420a6454d46..bb38e908b734 100644 --- a/tests/distributed/test_torchrun_example.py +++ b/tests/distributed/test_torchrun_example.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # unit test for `examples/offline_inference/torchrun_example.py` - +import os import random import torch.distributed as dist @@ -25,6 +25,7 @@ # to test if all ranks agree on the same kv cache configuration. 
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2, + pipeline_parallel_size=int(os.getenv("PP_SIZE", 1)), distributed_executor_backend="external_launcher", gpu_memory_utilization=random.uniform(0.7, 0.9), swap_space=random.randint(1, 4), diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 65471cb3af38..05d9cfc7ab74 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -8,21 +8,18 @@ import pytest -from vllm.config import config +from vllm.config import CompilationConfig, config from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs, get_type, is_not_builtin, is_type, literal_to_kwargs, nullable_kvs, - optional_type) + optional_type, parse_type) from vllm.utils import FlexibleArgumentParser @pytest.mark.parametrize(("type", "value", "expected"), [ (int, "42", 42), - (int, "None", None), (float, "3.14", 3.14), - (float, "None", None), (str, "Hello World!", "Hello World!"), - (str, "None", None), (json.loads, '{"foo":1,"bar":2}', { "foo": 1, "bar": 2 @@ -31,15 +28,20 @@ "foo": 1, "bar": 2 }), - (json.loads, "None", None), ]) -def test_optional_type(type, value, expected): - optional_type_func = optional_type(type) +def test_parse_type(type, value, expected): + parse_type_func = parse_type(type) context = nullcontext() if value == "foo=1,bar=2": context = pytest.warns(DeprecationWarning) with context: - assert optional_type_func(value) == expected + assert parse_type_func(value) == expected + + +def test_optional_type(): + optional_type_func = optional_type(int) + assert optional_type_func("None") is None + assert optional_type_func("42") == 42 @pytest.mark.parametrize(("type_hint", "type", "expected"), [ @@ -89,7 +91,40 @@ def test_literal_to_kwargs(type_hints, expected): @config @dataclass -class DummyConfigClass: +class NestedConfig: + field: int = 1 + """field""" + + +@config +@dataclass +class FromCliConfig1: + field: int = 1 + """field""" + + @classmethod + def from_cli(cls, 
cli_value: str): + inst = cls(**json.loads(cli_value)) + inst.field += 1 + return inst + + +@config +@dataclass +class FromCliConfig2: + field: int = 1 + """field""" + + @classmethod + def from_cli(cls, cli_value: str): + inst = cls(**json.loads(cli_value)) + inst.field += 2 + return inst + + +@config +@dataclass +class DummyConfig: regular_bool: bool = True """Regular bool with default True""" optional_bool: Optional[bool] = None @@ -108,18 +143,24 @@ class DummyConfigClass: """Literal of literals with default 1""" json_tip: dict = field(default_factory=dict) """Dict which will be JSON in CLI""" + nested_config: NestedConfig = field(default_factory=NestedConfig) + """Nested config""" + from_cli_config1: FromCliConfig1 = field(default_factory=FromCliConfig1) + """Config with from_cli method""" + from_cli_config2: FromCliConfig2 = field(default_factory=FromCliConfig2) + """Different config with from_cli method""" @pytest.mark.parametrize(("type_hint", "expected"), [ (int, False), - (DummyConfigClass, True), + (DummyConfig, True), ]) def test_is_not_builtin(type_hint, expected): assert is_not_builtin(type_hint) == expected def test_get_kwargs(): - kwargs = get_kwargs(DummyConfigClass) + kwargs = get_kwargs(DummyConfig) print(kwargs) # bools should not have their type set @@ -140,8 +181,13 @@ def test_get_kwargs(): # literals of literals should have merged choices assert kwargs["literal_literal"]["choices"] == [1, 2] # dict should have json tip in help - json_tip = "\n\nShould be a valid JSON string." 
- assert kwargs["json_tip"]["help"].endswith(json_tip) + json_tip = "Should either be a valid JSON string or JSON keys" + assert json_tip in kwargs["json_tip"]["help"] + # nested config should should construct the nested config + assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2) + # from_cli configs should be constructed with the correct method + assert kwargs["from_cli_config1"]["type"]('{"field": 2}').field == 3 + assert kwargs["from_cli_config2"]["type"]('{"field": 2}').field == 4 @pytest.mark.parametrize(("arg", "expected"), [ @@ -177,7 +223,7 @@ def test_compilation_config(): # default value args = parser.parse_args([]) - assert args.compilation_config is None + assert args.compilation_config == CompilationConfig() # set to O3 args = parser.parse_args(["-O3"]) @@ -194,7 +240,7 @@ def test_compilation_config(): # set to string form of a dict args = parser.parse_args([ "--compilation-config", - "{'level': 3, 'cudagraph_capture_sizes': [1, 2, 4, 8]}", + '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}', ]) assert (args.compilation_config.level == 3 and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]) @@ -202,7 +248,7 @@ def test_compilation_config(): # set to string form of a dict args = parser.parse_args([ "--compilation-config=" - "{'level': 3, 'cudagraph_capture_sizes': [1, 2, 4, 8]}", + '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}', ]) assert (args.compilation_config.level == 3 and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]) diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index fdbdccd4654c..dd5d17885eb9 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -1,12 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 import json -import re import weakref from enum import Enum import jsonschema import pytest +import regex as re from pydantic import BaseModel from 
vllm.distributed import cleanup_dist_env_and_memory diff --git a/tests/entrypoints/openai/correctness/test_mteb.py b/tests/entrypoints/openai/correctness/test_mteb.py new file mode 100644 index 000000000000..ebf2f829b583 --- /dev/null +++ b/tests/entrypoints/openai/correctness/test_mteb.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +import os + +import pytest + +from tests.models.language.pooling.mteb_utils import (MTEB_EMBED_TASKS, + OpenAIClientMtebEncoder, + run_mteb_embed_task, + run_mteb_embed_task_st) +from tests.utils import RemoteOpenAIServer + +os.environ["VLLM_LOGGING_LEVEL"] = "WARNING" + +MODEL_NAME = "BAAI/bge-m3" +DTYPE = "float16" +MAIN_SCORE = 0.7873427091972599 + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--task", "embed", "--dtype", DTYPE, "--enforce-eager", + "--max-model-len", "512" + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +def test_mteb(server): + client = server.get_client() + encoder = OpenAIClientMtebEncoder(MODEL_NAME, client) + vllm_main_score = run_mteb_embed_task(encoder, MTEB_EMBED_TASKS) + st_main_score = MAIN_SCORE or run_mteb_embed_task_st( + MODEL_NAME, MTEB_EMBED_TASKS) + + print("VLLM main score: ", vllm_main_score) + print("SentenceTransformer main score: ", st_main_score) + print("Difference: ", st_main_score - vllm_main_score) + + assert st_main_score == pytest.approx(vllm_main_score, rel=1e-4) diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/test_audio.py index 72e616656775..7f959f312019 100644 --- a/tests/entrypoints/openai/test_audio.py +++ b/tests/entrypoints/openai/test_audio.py @@ -272,7 +272,7 @@ async def test_chat_streaming_audio(client: openai.AsyncOpenAI, chat_completion = await client.chat.completions.create( model=model_name, messages=messages, - max_completion_tokens=10, + max_completion_tokens=8, temperature=0.0, ) output = chat_completion.choices[0].message.content @@ -282,7 +282,7 @@ async def 
test_chat_streaming_audio(client: openai.AsyncOpenAI, stream = await client.chat.completions.create( model=model_name, messages=messages, - max_completion_tokens=10, + max_completion_tokens=8, temperature=0.0, stream=True, ) @@ -332,7 +332,7 @@ async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI, chat_completion = await client.chat.completions.create( model=model_name, messages=messages, - max_completion_tokens=10, + max_completion_tokens=8, temperature=0.0, ) output = chat_completion.choices[0].message.content @@ -342,7 +342,7 @@ async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI, stream = await client.chat.completions.create( model=model_name, messages=messages, - max_completion_tokens=10, + max_completion_tokens=8, temperature=0.0, stream=True, ) diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index a10b42ea3a4b..2509ef0d280a 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -2,13 +2,13 @@ # imports for guided decoding tests import json -import re from typing import Optional import jsonschema import openai # use the official client for correctness check import pytest import pytest_asyncio +import regex as re import requests import torch from openai import BadRequestError, OpenAI diff --git a/tests/entrypoints/openai/test_chat_template.py b/tests/entrypoints/openai/test_chat_template.py index 48ede50e98f7..f18fbb0a9c71 100644 --- a/tests/entrypoints/openai/test_chat_template.py +++ b/tests/entrypoints/openai/test_chat_template.py @@ -122,10 +122,10 @@ def test_get_gen_prompt(model, template, add_generation_prompt, # Call the function and get the result result = apply_hf_chat_template( - model_config, - tokenizer, + tokenizer=tokenizer, conversation=mock_request.messages, chat_template=mock_request.chat_template or template_content, + model_config=model_config, tools=None, add_generation_prompt=mock_request.add_generation_prompt, 
continue_final_message=mock_request.continue_final_message, diff --git a/tests/entrypoints/openai/test_classification.py b/tests/entrypoints/openai/test_classification.py new file mode 100644 index 000000000000..97124c85e0d3 --- /dev/null +++ b/tests/entrypoints/openai/test_classification.py @@ -0,0 +1,156 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import requests + +from vllm.entrypoints.openai.protocol import ClassificationResponse + +from ...utils import RemoteOpenAIServer + +MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach" +DTYPE = "float32" # Use float32 to avoid NaN issue + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--enforce-eager", + "--max-model-len", + "512", + "--dtype", + DTYPE, + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_single_input_classification(server: RemoteOpenAIServer, + model_name: str): + input_text = "This product was excellent and exceeded my expectations" + + classification_response = requests.post( + server.url_for("classify"), + json={ + "model": model_name, + "input": input_text + }, + ) + + classification_response.raise_for_status() + output = ClassificationResponse.model_validate( + classification_response.json()) + + assert output.object == "list" + assert output.model == MODEL_NAME + assert len(output.data) == 1 + assert hasattr(output.data[0], "label") + assert hasattr(output.data[0], "probs") + + +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_multiple_inputs_classification(server: RemoteOpenAIServer, + model_name: str): + input_texts = [ + "The product arrived on time and works perfectly", + "I'm very satisfied with my purchase, would buy again", + "The customer service was helpful and resolved my issue quickly", + "This product broke after one week, terrible quality", + "I'm very disappointed with this purchase, complete waste of money", + "The customer service was 
rude and unhelpful", + ] + + classification_response = requests.post( + server.url_for("classify"), + json={ + "model": model_name, + "input": input_texts + }, + ) + output = ClassificationResponse.model_validate( + classification_response.json()) + + assert len(output.data) == len(input_texts) + for i, item in enumerate(output.data): + assert item.index == i + assert hasattr(item, "label") + assert hasattr(item, "probs") + assert len(item.probs) == item.num_classes + assert item.label in ["Default", "Spoiled"] + + +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_truncate_prompt_tokens(server: RemoteOpenAIServer, model_name: str): + long_text = "hello " * 600 + + classification_response = requests.post( + server.url_for("classify"), + json={ + "model": model_name, + "input": long_text, + "truncate_prompt_tokens": 5 + }, + ) + + classification_response.raise_for_status() + output = ClassificationResponse.model_validate( + classification_response.json()) + + assert len(output.data) == 1 + assert output.data[0].index == 0 + assert hasattr(output.data[0], "probs") + assert output.usage.prompt_tokens == 5 + assert output.usage.total_tokens == 5 + + +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_invalid_truncate_prompt_tokens_error(server: RemoteOpenAIServer, + model_name: str): + classification_response = requests.post( + server.url_for("classify"), + json={ + "model": model_name, + "input": "test", + "truncate_prompt_tokens": 513 + }, + ) + + error = classification_response.json() + assert classification_response.status_code == 400 + assert error["object"] == "error" + assert "truncate_prompt_tokens" in error["message"] + + +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_empty_input_error(server: RemoteOpenAIServer, model_name: str): + classification_response = requests.post( + server.url_for("classify"), + json={ + "model": model_name, + "input": "" + }, + ) + + error = classification_response.json() + assert 
classification_response.status_code == 400 + assert error["object"] == "error" + + +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_batch_classification_empty_list(server: RemoteOpenAIServer, + model_name: str): + classification_response = requests.post( + server.url_for("classify"), + json={ + "model": model_name, + "input": [] + }, + ) + classification_response.raise_for_status() + output = ClassificationResponse.model_validate( + classification_response.json()) + + assert output.object == "list" + assert isinstance(output.data, list) + assert len(output.data) == 0 diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index 1d9aa4972b70..9d12f27a2b87 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 - # imports for guided decoding tests import json -import re import shutil from tempfile import TemporaryDirectory from typing import Optional @@ -11,6 +9,7 @@ import openai # use the official client for correctness check import pytest import pytest_asyncio +import regex as re # downloading lora to test lora requests from huggingface_hub import snapshot_download from openai import BadRequestError diff --git a/tests/entrypoints/openai/test_completion_with_function_calling.py b/tests/entrypoints/openai/test_completion_with_function_calling.py new file mode 100644 index 000000000000..dad76b54c5e9 --- /dev/null +++ b/tests/entrypoints/openai/test_completion_with_function_calling.py @@ -0,0 +1,154 @@ +# SPDX-License-Identifier: Apache-2.0 + +import openai # use the official client for correctness check +import pytest +import pytest_asyncio + +# downloading lora to test lora requests +from ...utils import RemoteOpenAIServer + +# any model with a chat template should work here +MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" + + +@pytest.fixture(scope="module") +def server(): # noqa: F811 + args = [ + # use 
half precision for speed and memory savings in CI environment + "--dtype", + "half", + "--enable-auto-tool-choice", + "--guided-decoding-backend", + "xgrammar", + "--tool-call-parser", + "hermes" + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_required_tool_use(client: openai.AsyncOpenAI, model_name: str): + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": + "The city to find the weather for, e.g. 'Vienna'", + "default": "Vienna", + }, + "country": { + "type": + "string", + "description": + "The country that the city is in, e.g. 'Austria'", + }, + "unit": { + "type": "string", + "description": + "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["country", "unit"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_forecast", + "description": "Get the weather forecast for a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": + "The city to get the forecast for, e.g. 'Vienna'", + "default": "Vienna", + }, + "country": { + "type": + "string", + "description": + "The country that the city is in, e.g. 
'Austria'", + }, + "days": { + "type": + "integer", + "description": + "Number of days to get the forecast for (1-7)", + }, + "unit": { + "type": "string", + "description": + "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["country", "days", "unit"], + }, + }, + }, + ] + + messages = [ + { + "role": "user", + "content": "Hi! How are you doing today?" + }, + { + "role": "assistant", + "content": "I'm doing well! How can I help you?" + }, + { + "role": + "user", + "content": + "Can you tell me what the current weather is in Berlin and the "\ + "forecast for the next 5 days, in fahrenheit?", + }, + ] + + # Non-streaming test + chat_completion = await client.chat.completions.create( + messages=messages, + model=model_name, + tools=tools, + tool_choice="required", + ) + + assert chat_completion.choices[0].message.tool_calls is not None + assert len(chat_completion.choices[0].message.tool_calls) > 0 + + # Streaming test + stream = await client.chat.completions.create( + messages=messages, + model=model_name, + tools=tools, + tool_choice="required", + stream=True, + ) + + output = [] + async for chunk in stream: + if chunk.choices and chunk.choices[0].delta.tool_calls: + output.extend(chunk.choices[0].delta.tool_calls) + + assert len(output) > 0 diff --git a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py new file mode 100644 index 000000000000..b7ee3e33c2d2 --- /dev/null +++ b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py @@ -0,0 +1,257 @@ +# SPDX-License-Identifier: Apache-2.0 + +import base64 +import io +import shutil +from tempfile import TemporaryDirectory + +import openai # use the official client for correctness check +import pytest +import pytest_asyncio +import torch +# downloading lora to test lora requests +from huggingface_hub import snapshot_download +from openai import BadRequestError +from transformers import 
AutoConfig, AutoTokenizer + +from ...utils import RemoteOpenAIServer + +# any model with a chat template should work here +MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" +LORA_NAME = "typeof/zephyr-7b-beta-lora" + +CONFIG = AutoConfig.from_pretrained(MODEL_NAME) + + +@pytest.fixture(scope="module") +def zephyr_lora_files(): + return snapshot_download(repo_id=LORA_NAME) + + +@pytest.fixture(scope="module") +def zephyr_lora_added_tokens_files(zephyr_lora_files): + tmp_dir = TemporaryDirectory() + tmp_model_dir = f"{tmp_dir.name}/zephyr" + shutil.copytree(zephyr_lora_files, tmp_model_dir) + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + # Copy tokenizer to adapter and add some unique tokens + # 32000, 32001, 32002 + added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"], + special_tokens=True) + assert added == 3 + tokenizer.save_pretrained(tmp_model_dir) + yield tmp_model_dir + tmp_dir.cleanup() + + +@pytest.fixture(scope="module") +def default_server_args( + zephyr_lora_files, + zephyr_lora_added_tokens_files, +) -> list[str]: + return [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--max-num-seqs", + "128", + "--enforce-eager", + # Prompt Embeds server args + "--enable-prompt-embeds", + "--no-enable-chunked-prefill", + ] + + +@pytest.fixture(scope="module", + params=["", "--disable-frontend-multiprocessing"]) +def server_with_prompt_embeds(default_server_args, request): + if request.param: + default_server_args.append(request.param) + + with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client_with_prompt_embeds(server_with_prompt_embeds): + async with server_with_prompt_embeds.get_async_client() as async_client: + yield async_client + + +def create_dummy_embeds(num_tokens: int = 5) -> str: + """Create dummy embeddings and return them as base64 encoded string.""" + dummy_embeds = 
torch.randn(num_tokens, CONFIG.hidden_size) + buffer = io.BytesIO() + torch.save(dummy_embeds, buffer) + return base64.b64encode(buffer.getvalue()).decode('utf-8') + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_completions_with_prompt_embeds( + client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str): + # Test case: Single prompt embeds input + encoded_embeds = create_dummy_embeds() + completion = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": encoded_embeds}) + assert len(completion.choices[0].text) >= 1 + assert completion.choices[0].prompt_logprobs is None + + # Test case: batch completion with prompt_embeds + encoded_embeds2 = create_dummy_embeds() + completion = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) + assert len(completion.choices) == 2 + assert len(completion.choices[0].text) >= 1 + assert len(completion.choices[1].text) >= 1 + + # Test case: streaming with prompt_embeds + encoded_embeds = create_dummy_embeds() + single_completion = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": encoded_embeds}) + single_output = single_completion.choices[0].text + + stream = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + stream=True, + extra_body={"prompt_embeds": encoded_embeds}) + chunks = [] + finish_reason_count = 0 + async for chunk in stream: + chunks.append(chunk.choices[0].text) + if chunk.choices[0].finish_reason is not None: 
+ finish_reason_count += 1 + assert finish_reason_count == 1 + assert chunk.choices[0].finish_reason == "length" + assert chunk.choices[0].text + assert "".join(chunks) == single_output + + # Test case: batch streaming with prompt_embeds + encoded_embeds2 = create_dummy_embeds() + stream = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + stream=True, + extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) + chunks_stream_embeds: list[list[str]] = [[], []] + finish_reason_count = 0 + async for chunk in stream: + chunks_stream_embeds[chunk.choices[0].index].append( + chunk.choices[0].text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert finish_reason_count == 2 + assert chunk.choices[0].finish_reason == "length" + assert chunk.choices[0].text + assert len(chunks_stream_embeds[0]) > 0 + assert len(chunks_stream_embeds[1]) > 0 + + # Test case: mixed text and prompt_embeds + encoded_embeds = create_dummy_embeds() + completion_mixed = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="This is a prompt", + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": encoded_embeds}) + assert len(completion.choices) == 2 + completion_text_only = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="This is a prompt", + max_tokens=5, + temperature=0.0, + ) + completion_embeds_only = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": encoded_embeds}) + # Embeddings responses should be handled first + assert completion_mixed.choices[0].text == completion_embeds_only.choices[ + 0].text + assert completion_mixed.choices[1].text == completion_text_only.choices[ + 0].text + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def 
test_completions_errors_with_prompt_embeds( + client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str): + # Test error case: invalid prompt_embeds + with pytest.raises(BadRequestError): + await client_with_prompt_embeds.completions.create( + prompt="", + model=model_name, + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": "invalid_base64"}) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("logprobs_arg", [1, 0]) +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_completions_with_logprobs_and_prompt_embeds( + client_with_prompt_embeds: openai.AsyncOpenAI, logprobs_arg: int, + model_name: str): + # Test case: Logprobs using prompt_embeds + encoded_embeds = create_dummy_embeds() + completion = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + echo=False, + logprobs=logprobs_arg, + extra_body={"prompt_embeds": encoded_embeds}) + + logprobs = completion.choices[0].logprobs + assert logprobs is not None + assert len(logprobs.text_offset) == 5 + assert len(logprobs.token_logprobs) == 5 + assert len(logprobs.top_logprobs) == 5 + for top_logprobs in logprobs.top_logprobs[1:]: + assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1 + assert len(logprobs.tokens) == 5 + + # Test case: Log probs with batch completion and prompt_embeds + encoded_embeds2 = create_dummy_embeds() + completion = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + echo=False, + logprobs=logprobs_arg, + extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) + + assert len(completion.choices) == 2 + for choice in completion.choices: + logprobs = choice.logprobs + assert logprobs is not None + assert len(logprobs.text_offset) == 5 + assert len(logprobs.token_logprobs) == 5 + assert len(logprobs.top_logprobs) == 5 + 
for top_logprobs in logprobs.top_logprobs[1:]: + assert max(logprobs_arg, + 1) <= len(top_logprobs) <= logprobs_arg + 1 + assert len(logprobs.tokens) == 5 diff --git a/tests/entrypoints/openai/test_openai_schema.py b/tests/entrypoints/openai/test_openai_schema.py index 1ccb803a328d..cae2a3b59553 100644 --- a/tests/entrypoints/openai/test_openai_schema.py +++ b/tests/entrypoints/openai/test_openai_schema.py @@ -1,6 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 +from typing import Final + import pytest import schemathesis +from hypothesis import settings from schemathesis import GenerationConfig from ...utils import RemoteOpenAIServer @@ -9,6 +12,8 @@ MODEL_NAME = "HuggingFaceTB/SmolVLM-256M-Instruct" MAXIMUM_IMAGES = 2 +DEFAULT_TIMEOUT_SECONDS: Final[int] = 10 +LONG_TIMEOUT_SECONDS: Final[int] = 60 @pytest.fixture(scope="module") @@ -42,8 +47,58 @@ def get_schema(server): schema = schemathesis.from_pytest_fixture("get_schema") +@schemathesis.hook +def before_generate_case(context: schemathesis.hooks.HookContext, strategy): + op = context.operation + assert op is not None + + def no_file_type(case: schemathesis.models.Case): + """ + This filter skips test cases for the `POST /tokenize` endpoint where the + HTTP request body uses `"type": "file"` in any message's content. 
+ We expect these cases to fail because that type isn't implemented here + https://github.com/vllm-project/vllm/blob/0b34593017953051b3225b1483ce0f4670e3eb0e/vllm/entrypoints/chat_utils.py#L1038-L1095 + + Example test cases that are skipped: + curl -X POST -H 'Content-Type: application/json' \ + -d '{"messages": [{"role": "assistant"}, {"content": [{"file": {}, "type": "file"}], "role": "user"}]}' \ + http://localhost:8000/tokenize + + curl -X POST -H 'Content-Type: application/json' \ + -d '{"messages": [{"content": [{"file": {}, "type": "file"}], "role": "user"}]}' \ + http://localhost:8000/tokenize + """ # noqa: E501 + if (op.method.lower() == "post" and op.path == "/tokenize" + and hasattr(case, "body") and isinstance(case.body, dict) + and "messages" in case.body + and isinstance(case.body["messages"], list) + and len(case.body["messages"]) > 0): + for message in case.body["messages"]: + if not isinstance(message, dict): + continue + content = message.get("content", []) + if not isinstance(content, list) or len(content) == 0: + continue + if any(item.get("type") == "file" for item in content): + return False + return True + + return strategy.filter(no_file_type) + + @schema.parametrize() @schema.override(headers={"Content-Type": "application/json"}) -async def test_openapi_stateless(case): +@settings(deadline=LONG_TIMEOUT_SECONDS * 1000) +def test_openapi_stateless(case: schemathesis.Case): + key = ( + case.operation.method.upper(), + case.operation.path, + ) + timeout = { + # requires a longer timeout + ("POST", "/v1/chat/completions"): + LONG_TIMEOUT_SECONDS, + }.get(key, DEFAULT_TIMEOUT_SECONDS) + #No need to verify SSL certificate for localhost - await case.call_and_validate(verify=False) + case.call_and_validate(verify=False, timeout=timeout) diff --git a/tests/entrypoints/openai/test_prompt_validation.py b/tests/entrypoints/openai/test_prompt_validation.py index f889189a9968..e384915899d3 100644 --- a/tests/entrypoints/openai/test_prompt_validation.py 
+++ b/tests/entrypoints/openai/test_prompt_validation.py @@ -1,10 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # imports for guided decoding tests -import re - import openai import pytest +import regex as re from ...utils import RemoteOpenAIServer @@ -32,7 +31,7 @@ async def test_out_of_vocab_token_ids(): client = remote_server.get_async_client() with pytest.raises(openai.BadRequestError, - match=re.compile('.*out of vocabulary.*')): + match=re.compile('.*out of vocabulary.*').pattern): await client.completions.create(model=model_name, prompt=[999999], max_tokens=5, @@ -46,9 +45,10 @@ async def test_reject_multistep_with_guided_decoding(): with RemoteOpenAIServer(model_name, server_args) as remote_server: client = remote_server.get_async_client() - with pytest.raises(openai.BadRequestError, - match=re.compile( - '.*Guided decoding .* multi-step decoding.*')): + with pytest.raises( + openai.BadRequestError, + match=re.compile( + '.*Guided decoding .* multi-step decoding.*').pattern): await client.completions.create( model=model_name, prompt="Hello", diff --git a/tests/entrypoints/openai/test_score.py b/tests/entrypoints/openai/test_score.py index b756680ea9f2..b373f2912752 100644 --- a/tests/entrypoints/openai/test_score.py +++ b/tests/entrypoints/openai/test_score.py @@ -1,6 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 - -import math from typing import Any import pytest @@ -92,7 +90,7 @@ def test_text_1_str_text_2_list(self, server: RemoteOpenAIServer, hf_outputs = run_transformers(runner, model, text_pairs) for i in range(len(vllm_outputs)): - assert math.isclose(hf_outputs[i], vllm_outputs[i], rel_tol=0.01) + assert hf_outputs[i] == pytest.approx(vllm_outputs[i], rel=0.01) def test_text_1_list_text_2_list(self, server: RemoteOpenAIServer, model: dict[str, Any], runner): @@ -124,7 +122,7 @@ def test_text_1_list_text_2_list(self, server: RemoteOpenAIServer, hf_outputs = run_transformers(runner, model, text_pairs) for i in range(len(vllm_outputs)): - assert 
math.isclose(hf_outputs[i], vllm_outputs[i], rel_tol=0.01) + assert hf_outputs[i] == pytest.approx(vllm_outputs[i], rel=0.01) def test_text_1_str_text_2_str(self, server: RemoteOpenAIServer, model: dict[str, Any], runner): @@ -150,7 +148,7 @@ def test_text_1_str_text_2_str(self, server: RemoteOpenAIServer, hf_outputs = run_transformers(runner, model, text_pairs) for i in range(len(vllm_outputs)): - assert math.isclose(hf_outputs[i], vllm_outputs[i], rel_tol=0.01) + assert hf_outputs[i] == pytest.approx(vllm_outputs[i], rel=0.01) def test_score_max_model_len(self, server: RemoteOpenAIServer, model: dict[str, Any]): diff --git a/tests/entrypoints/openai/test_tensorizer_entrypoint.py b/tests/entrypoints/openai/test_tensorizer_entrypoint.py new file mode 100644 index 000000000000..f1ab7223048d --- /dev/null +++ b/tests/entrypoints/openai/test_tensorizer_entrypoint.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 +import gc +import json +import tempfile + +import openai +import pytest +import pytest_asyncio +import torch.cuda + +from vllm.engine.arg_utils import EngineArgs +from vllm.model_executor.model_loader.tensorizer import ( + TensorizerConfig, tensorize_lora_adapter, tensorize_vllm_model) + +from ...utils import RemoteOpenAIServer + +MODEL_NAME = "unsloth/llama-3.2-1b-Instruct" +LORA_PATH = "davzoku/finqa_adapter_1b" + + +def _cleanup(): + gc.collect() + torch.cuda.empty_cache() + + +@pytest.fixture(autouse=True) +def cleanup(): + _cleanup() + + +@pytest.fixture(scope='module') +def tmp_dir(): + with tempfile.TemporaryDirectory() as path: + yield path + + +@pytest.fixture(scope='module') +def model_uri(tmp_dir): + yield f"{tmp_dir}/model.tensors" + + +@pytest.fixture(scope="module") +def tensorize_model_and_lora(tmp_dir, model_uri): + tensorizer_config = TensorizerConfig(tensorizer_uri=model_uri, + lora_dir=tmp_dir) + args = EngineArgs(model=MODEL_NAME, device="cuda") + + tensorize_lora_adapter(LORA_PATH, tensorizer_config) + tensorize_vllm_model(args, 
tensorizer_config) + + # Manually invoke a _cleanup() here, as the cleanup() + # fixture won't be guaranteed to be called after this + # when this fixture is used for a test + _cleanup() + yield + + +@pytest.fixture(scope="module") +def server(model_uri, tensorize_model_and_lora): + model_loader_extra_config = { + "tensorizer_uri": model_uri, + } + + ## Start OpenAI API server + args = [ + "--load-format", "tensorizer", "--device", "cuda", + "--model-loader-extra-config", + json.dumps(model_loader_extra_config), "--enable-lora" + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_single_completion(client: openai.AsyncOpenAI, model_name: str): + _cleanup() + completion = await client.completions.create(model=model_name, + prompt="Hello, my name is", + max_tokens=5, + temperature=0.0) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + assert completion.model == MODEL_NAME + assert len(completion.choices) == 1 + assert len(completion.choices[0].text) >= 5 + assert completion.choices[0].finish_reason == "length" + assert completion.usage == openai.types.CompletionUsage( + completion_tokens=5, prompt_tokens=6, total_tokens=11) diff --git a/tests/entrypoints/openai/test_tokenization.py b/tests/entrypoints/openai/test_tokenization.py index 663b722426c5..9773f3e45b99 100644 --- a/tests/entrypoints/openai/test_tokenization.py +++ b/tests/entrypoints/openai/test_tokenization.py @@ -145,6 +145,83 @@ async def test_tokenize_chat( } +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name,tokenizer_name", + [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], + indirect=["tokenizer_name"], +) +async def test_tokenize_chat_with_tools( + 
server: RemoteOpenAIServer, + model_name: str, + tokenizer_name: str, +): + tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, + tokenizer_mode="fast") + + for add_generation in [False, True]: + for add_special in [False, True]: + conversation = [{ + "role": + "user", + "content": + "What's the weather like in Paris today?", + }] + + tools = [{ + "type": "function", + "function": { + "name": "get_weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string" + } + }, + }, + }, + }] + + for continue_final in [False, True]: + if add_generation and continue_final: + continue + if continue_final: + conversation.append({ + "role": "assistant", + "content": "Sure," + }) + + prompt = tokenizer.apply_chat_template( + add_generation_prompt=add_generation, + continue_final_message=continue_final, + conversation=conversation, + tools=tools, + tokenize=False, + ) + tokens = tokenizer.encode(prompt, + add_special_tokens=add_special) + + response = requests.post( + server.url_for("tokenize"), + json={ + "add_generation_prompt": add_generation, + "continue_final_message": continue_final, + "add_special_tokens": add_special, + "messages": conversation, + "model": model_name, + "tools": tools, + }, + ) + response.raise_for_status() + + assert response.json() == { + "tokens": tokens, + "count": len(tokens), + "max_model_len": 8192, + } + + @pytest.mark.asyncio @pytest.mark.parametrize( "model_name,tokenizer_name", diff --git a/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py new file mode 100644 index 000000000000..92ba1376e200 --- /dev/null +++ b/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py @@ -0,0 +1,193 @@ +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import MagicMock + +import pytest + +from tests.entrypoints.openai.tool_parsers.utils import ( + run_tool_extraction, 
run_tool_extraction_streaming) +from vllm.entrypoints.openai.protocol import FunctionCall +from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager + +# Test cases similar to pythonic parser but with Llama4 specific format +SIMPLE_FUNCTION_OUTPUT = "[get_weather(city='LA', metric='C')]" +SIMPLE_FUNCTION_CALL = FunctionCall( + name="get_weather", + arguments='{"city": "LA", "metric": "C"}', +) +MORE_TYPES_FUNCTION_OUTPUT = ("[register_user(name='Doe', " + "age=9, " + "address={'city': 'LA', 'state': 'CA'}, " + "role=None, " + "passed_test=True, " + "aliases=['John', 'Johnny'])]") +MORE_TYPES_FUNCTION_CALL = FunctionCall( + name="register_user", + arguments='{"name": "Doe", ' + '"age": 9, ' + '"address": {"city": "LA", "state": "CA"}, ' + '"role": null, ' + '"passed_test": true, ' + '"aliases": ["John", "Johnny"]}', +) +PARAMETERLESS_FUNCTION_OUTPUT = "[get_weather()]" +PARAMETERLESS_FUNCTION_CALL = FunctionCall( + name="get_weather", + arguments='{}', +) +EMPTY_DICT_FUNCTION_OUTPUT = "[do_something_cool(additional_data={})]" +EMPTY_DICT_FUNCTION_CALL = FunctionCall( + name="do_something_cool", + arguments='{"additional_data": {}}', +) +EMPTY_LIST_FUNCTION_OUTPUT = "[do_something_cool(steps=[])]" +EMPTY_LIST_FUNCTION_CALL = FunctionCall( + name="do_something_cool", + arguments='{"steps": []}', +) +ESCAPED_STRING_FUNCTION_OUTPUT = ( + r"[get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')]") +ESCAPED_STRING_FUNCTION_CALL = FunctionCall( + name="get_weather", + arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}', +) +PYTHON_TAG_FUNCTION_OUTPUT = ( + "<|python_start|>[get_weather(city='LA', metric='C')]<|python_end|>") + + +@pytest.mark.parametrize("streaming", [True, False]) +def test_no_tool_call(streaming: bool): + mock_tokenizer = MagicMock() + tool_parser: ToolParser = ToolParserManager.get_tool_parser( + "llama4_pythonic")(mock_tokenizer) + model_output = "How can I help you today?" 
+ + content, tool_calls = run_tool_extraction(tool_parser, + model_output, + streaming=streaming) + + assert content == model_output + assert len(tool_calls) == 0 + + +test_str = "<|python_start|>" +test_str += "[get_weather(city='LA', metric='C')," +test_str += "register_user(name='Doe', age=9)]" +TEST_CASES = [ + pytest.param(True, + ESCAPED_STRING_FUNCTION_OUTPUT, + [ESCAPED_STRING_FUNCTION_CALL], + id="simple_streaming"), + pytest.param(False, + SIMPLE_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL], + id="simple_nonstreaming"), + pytest.param(True, + MORE_TYPES_FUNCTION_OUTPUT, [MORE_TYPES_FUNCTION_CALL], + id="more_types_streaming"), + pytest.param(False, + MORE_TYPES_FUNCTION_OUTPUT, [MORE_TYPES_FUNCTION_CALL], + id="more_types_nonstreaming"), + pytest.param(True, + PARAMETERLESS_FUNCTION_OUTPUT, [PARAMETERLESS_FUNCTION_CALL], + id="parameterless_streaming"), + pytest.param(False, + PARAMETERLESS_FUNCTION_OUTPUT, [PARAMETERLESS_FUNCTION_CALL], + id="parameterless_nonstreaming"), + pytest.param(True, + EMPTY_DICT_FUNCTION_OUTPUT, [EMPTY_DICT_FUNCTION_CALL], + id="empty_dict_streaming"), + pytest.param(False, + EMPTY_DICT_FUNCTION_OUTPUT, [EMPTY_DICT_FUNCTION_CALL], + id="empty_dict_nonstreaming"), + pytest.param(True, + EMPTY_LIST_FUNCTION_OUTPUT, [EMPTY_LIST_FUNCTION_CALL], + id="empty_list_streaming"), + pytest.param(False, + EMPTY_LIST_FUNCTION_OUTPUT, [EMPTY_LIST_FUNCTION_CALL], + id="empty_list_nonstreaming"), + pytest.param(True, + ESCAPED_STRING_FUNCTION_OUTPUT, + [ESCAPED_STRING_FUNCTION_CALL], + id="escaped_string_streaming"), + pytest.param(False, + ESCAPED_STRING_FUNCTION_OUTPUT, + [ESCAPED_STRING_FUNCTION_CALL], + id="escaped_string_nonstreaming"), + pytest.param( + True, + "[get_weather(city='LA',metric='C'),register_user(name='Doe',age=9)]", + [ + SIMPLE_FUNCTION_CALL, + FunctionCall(name="register_user", + arguments='{"name": "Doe", "age": 9}') + ], + id="parallel_calls_streaming"), + pytest.param( + False, + 
"[get_weather(city='LA',metric='C'),register_user(name='Doe',age=9)]", + [ + SIMPLE_FUNCTION_CALL, + FunctionCall(name="register_user", + arguments='{"name": "Doe", "age": 9}') + ], + id="parallel_calls_nonstreaming"), + pytest.param(True, + PYTHON_TAG_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL], + id="python_tag_streaming"), + pytest.param(False, + PYTHON_TAG_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL], + id="python_tag_nonstreaming"), + pytest.param(True, + test_str, [ + SIMPLE_FUNCTION_CALL, + FunctionCall(name="register_user", + arguments='{"name": "Doe", "age": 9}') + ], + id="parallel_calls_streaming"), + pytest.param(False, + "<|python_start|>[get_weather(city='LA', metric='C'), " + + "register_user(name='Doe', age=9)]", [ + SIMPLE_FUNCTION_CALL, + FunctionCall(name="register_user", + arguments='{"name": "Doe", "age": 9}') + ], + id="parallel_calls_nonstreaming"), +] + + +@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", + TEST_CASES) +def test_tool_call(streaming: bool, model_output: str, + expected_tool_calls: list[FunctionCall]): + mock_tokenizer = MagicMock() + tool_parser: ToolParser = ToolParserManager.get_tool_parser( + "llama4_pythonic")(mock_tokenizer) + + content, tool_calls = run_tool_extraction(tool_parser, + model_output, + streaming=streaming) + + assert len(tool_calls) == len(expected_tool_calls) + for actual, expected in zip(tool_calls, expected_tool_calls): + assert actual.type == "function" + assert actual.function == expected + + +def test_streaming_tool_call_with_large_steps(): + mock_tokenizer = MagicMock() + tool_parser: ToolParser = ToolParserManager.get_tool_parser( + "llama4_pythonic")(mock_tokenizer) + model_output_deltas = [ + "<|python_start|>[get_weather(city='LA', metric='C'), " + "get_weather(), " + "do_something_cool(steps=[])]<|python_end|>", + ] + + reconstructor = run_tool_extraction_streaming( + tool_parser, model_output_deltas, assert_one_tool_per_delta=False) + + assert reconstructor.other_content == "" + 
assert len(reconstructor.tool_calls) == 3 + assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL + assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL + assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL diff --git a/tests/entrypoints/openai/tool_parsers/utils.py b/tests/entrypoints/openai/tool_parsers/utils.py index 6ad5aa26ffa1..ab8f4bd678fd 100644 --- a/tests/entrypoints/openai/tool_parsers/utils.py +++ b/tests/entrypoints/openai/tool_parsers/utils.py @@ -32,7 +32,7 @@ def append_delta(self, delta: DeltaMessage): assert len(delta.tool_calls) < 2, ( "Streaming should include only one tool call per update.") for call_delta in delta.tool_calls: - assert call_delta.type == "function", ( + assert call_delta.type is None or call_delta.type == "function", ( "Streaming tool calls should only emit function calls. Got " f"{call_delta.type}") current_tool_call = self.tool_calls[ diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index bcb25ed99062..9f1f2321d9e6 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -793,10 +793,10 @@ def get_conversation(is_hf: bool): ) vllm_result = apply_hf_chat_template( - model_config, - tokenizer, + tokenizer=tokenizer, conversation=conversation, chat_template=None, + model_config=model_config, tools=None, add_generation_prompt=True, ) @@ -845,10 +845,10 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): # Test detecting the tokenizer's chat_template chat_template = resolve_hf_chat_template( - model_config, tokenizer, chat_template=None, tools=tools, + model_config=model_config, ) assert isinstance(chat_template, str) @@ -890,10 +890,10 @@ def test_resolve_content_format_hf_defined(model, expected_format): # Test detecting the tokenizer's chat_template chat_template = resolve_hf_chat_template( - model_config, tokenizer, chat_template=None, tools=None, + 
model_config=model_config, ) assert isinstance(chat_template, str) @@ -903,11 +903,11 @@ def test_resolve_content_format_hf_defined(model, expected_format): print(_try_extract_ast(chat_template)) resolved_format = resolve_chat_template_content_format( - model_config, None, # Test detecting the tokenizer's chat_template None, "auto", tokenizer, + model_config=model_config, ) assert resolved_format == expected_format @@ -949,10 +949,10 @@ def test_resolve_content_format_fallbacks(model, expected_format): # Test detecting the tokenizer's chat_template chat_template = resolve_hf_chat_template( - model_config, tokenizer, chat_template=None, tools=None, + model_config=model_config, ) assert isinstance(chat_template, str) @@ -962,11 +962,11 @@ def test_resolve_content_format_fallbacks(model, expected_format): print(_try_extract_ast(chat_template)) resolved_format = resolve_chat_template_content_format( - model_config, None, # Test detecting the tokenizer's chat_template None, "auto", tokenizer, + model_config=model_config, ) assert resolved_format == expected_format @@ -1021,11 +1021,11 @@ def test_resolve_content_format_examples(template_path, expected_format): print(_try_extract_ast(chat_template)) resolved_format = resolve_chat_template_content_format( - model_config, chat_template, None, "auto", dummy_tokenizer, + model_config=model_config, ) assert resolved_format == expected_format diff --git a/tests/kernels/attention/test_attention.py b/tests/kernels/attention/test_attention.py index e5650136f258..d9f956fbc7c0 100644 --- a/tests/kernels/attention/test_attention.py +++ b/tests/kernels/attention/test_attention.py @@ -148,6 +148,11 @@ def test_paged_attention( or (version == "rocm" and head_size not in (64, 128))): pytest.skip() + if (version == "rocm" and current_platform.is_navi() + and (kv_cache_dtype == "fp8" or head_size != 128 + or block_size != 16 or use_alibi)): + pytest.skip() + global PARTITION_SIZE current_platform.seed_everything(seed) @@ -275,6 +280,7 @@ 
def test_paged_attention( scale, block_tables, seq_lens, + None, block_size, max_seq_len, alibi_slopes, @@ -286,7 +292,7 @@ def test_paged_attention( opcheck(torch.ops._rocm_C.paged_attention, (output, exp_sums, max_logits, tmp_output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, - seq_lens, block_size, max_seq_len, alibi_slopes, + seq_lens, None, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale), cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0])) diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index b0414244c215..58da01f0ebbf 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -102,7 +102,10 @@ def test_env( block_size, False, use_mla=use_mla) - assert backend.get_name() == name + if use_v1 and name != "TRITON_MLA": + assert backend.get_name() == f"{name}_VLLM_V1" + else: + assert backend.get_name() == name else: with pytest.raises(ValueError) as exc_info: get_attn_backend(16, @@ -185,8 +188,9 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch): m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL) # Unsupported CUDA arch - monkeypatch.setattr(torch.cuda, "get_device_capability", lambda: - (7, 5)) + monkeypatch.setattr(torch.cuda, + "get_device_capability", + lambda _=None: (7, 5)) backend = get_attn_backend(16, torch.float16, None, 16, False) assert backend.get_name() != STR_FLASH_ATTN_VAL diff --git a/tests/kernels/attention/test_rocm_attention_selector.py b/tests/kernels/attention/test_rocm_attention_selector.py index 4cf7bcb01d4d..6ffe27abf709 100644 --- a/tests/kernels/attention/test_rocm_attention_selector.py +++ b/tests/kernels/attention/test_rocm_attention_selector.py @@ -48,7 +48,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA") backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, False, 
True) - assert backend.get_name() == "ROCM_AITER_MLA" + assert (backend.get_name() == "ROCM_AITER_MLA" + or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1") # If attention backend is None # If use_mla is true @@ -58,4 +59,5 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): m.setenv("VLLM_ROCM_USE_AITER", "1") backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, False, True) - assert backend.get_name() == "ROCM_AITER_MLA" + assert (backend.get_name() == "ROCM_AITER_MLA" + or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1") diff --git a/tests/kernels/test_triton_unified_attention.py b/tests/kernels/attention/test_triton_unified_attention.py similarity index 98% rename from tests/kernels/test_triton_unified_attention.py rename to tests/kernels/attention/test_triton_unified_attention.py index 50da8e5fd5cd..4e15d00255a4 100644 --- a/tests/kernels/test_triton_unified_attention.py +++ b/tests/kernels/attention/test_triton_unified_attention.py @@ -99,6 +99,9 @@ def test_triton_unified_attn( ) -> None: torch.set_default_device("cuda") + if q_dtype is not None and q_dtype.itemsize < 2 and block_size < 32: + pytest.skip("block size must be at least 32 for fp8") + current_platform.seed_everything(0) num_seqs = len(seq_lens) query_lens = [x[0] for x in seq_lens] diff --git a/tests/kernels/core/test_pos_encoding.py b/tests/kernels/core/test_pos_encoding.py index d81c7487b88c..f327deb0e549 100644 --- a/tests/kernels/core/test_pos_encoding.py +++ b/tests/kernels/core/test_pos_encoding.py @@ -29,12 +29,20 @@ def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int, return (batch_size, seq_len, num_heads * head_size) +# For testing sliced tensors +def _get_padded_tensor_shape(batch_size: int, seq_len: int, num_heads: int, + head_size: int) -> tuple[int, ...]: + return (batch_size, seq_len, num_heads, head_size + 64) + + def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int, head_size: int) -> tuple[int, ...]: return (batch_size, 
seq_len, num_heads, head_size) -TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape] +TENSORS_SHAPES_FN = [ + _get_batch_tensor_shape, _get_flat_tensor_shape, _get_padded_tensor_shape +] @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) @@ -79,6 +87,10 @@ def test_rotary_embedding( query = torch.randn(query_shape, dtype=dtype) key = torch.randn_like(query) if use_key else None + # slice tensor if required, noop otherwise + query = query[..., :head_size] + key = key[..., :head_size] if use_key else None + # NOTE(woosuk): The reference implementation should be executed first # because the custom kernel is in-place. ref_query, ref_key = rope.forward_native(positions, query, key) @@ -140,6 +152,10 @@ def test_batched_rotary_embedding( query = torch.randn(query_shape, dtype=dtype) key = torch.randn_like(query) if use_key else None + # slice tensor if required, noop otherwise + query = query[..., :head_size] + key = key[..., :head_size] if use_key else None + # NOTE(woosuk): The reference implementation should be executed first # because the custom kernel is in-place. 
ref_query, ref_key = rope.forward_native(positions, query, key) diff --git a/tests/kernels/core/test_rotary_embedding.py b/tests/kernels/core/test_rotary_embedding.py index 4e54861005f2..8383f943b9fa 100644 --- a/tests/kernels/core/test_rotary_embedding.py +++ b/tests/kernels/core/test_rotary_embedding.py @@ -38,9 +38,10 @@ def rotary_embedding_opcheck(rot, @pytest.mark.parametrize("head_size", [32, 108]) @pytest.mark.parametrize("seq_len", [11, 1024]) @pytest.mark.parametrize("use_key", [True, False]) +@pytest.mark.parametrize("head_stride_is_contingous", [True, False]) def test_rotary_embedding_opcheck(dist_init, device, max_position, is_neox_style, rotary_dim, head_size, - seq_len, use_key): + seq_len, use_key, head_stride_is_contingous): batch_size = 1 base = 10000 num_heads = 7 @@ -50,15 +51,27 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position, positions = torch.randint(0, max_position, (batch_size, seq_len), device=device) + head_stride = head_size + (64 if head_stride_is_contingous else 0) + query = torch.randn(batch_size, seq_len, - num_heads * head_size, + num_heads, + head_stride, dtype=torch.float32, device=device) key = torch.randn_like(query) if use_key else None + query = query[..., :head_size] + key = key[..., :head_size] if use_key else None rotary_embedding_opcheck(rot, positions, query, key) offsets = torch.zeros(batch_size * seq_len, device=device, dtype=torch.long) rotary_embedding_opcheck(rot, positions, query, key, offsets) + + # if we have a contiguous head stride, test the alternate + # [..., num_heads * head_dim] shape/layout + if head_stride_is_contingous: + rotary_embedding_opcheck( + rot, positions, query.flatten(start_dim=-2), + key.flatten(start_dim=-2) if use_key else None) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py new file mode 100644 index 000000000000..7d369edfc86a --- /dev/null +++ b/tests/kernels/moe/test_batched_moe.py @@ -0,0 +1,114 @@ +# 
SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass + +import pytest +import torch +import triton.language as tl + +from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + invoke_moe_batched_triton_kernel) + + +@dataclass +class BatchedMMConfig: + dtype: torch.dtype + num_experts: int + max_tokens_per_expert: int + K: int + N: int + + +@dataclass +class BatchedMMTensors: + A: torch.Tensor # [E, max_tokens, K] + B: torch.Tensor # [E, K, N] - column major + C: torch.Tensor # [E, max_tokens, N] + num_expert_tokens: torch.Tensor # [E] + + @staticmethod + def make_tensors(config: BatchedMMConfig): + A = torch.randn( + (config.num_experts, config.max_tokens_per_expert, config.K), + device="cuda", + dtype=config.dtype) / 10 + B = torch.randn((config.num_experts, config.N, config.K), + device="cuda", + dtype=config.dtype) + C = torch.zeros( + (config.num_experts, config.max_tokens_per_expert, config.N), + device="cuda", + dtype=config.dtype) + num_expert_tokens = torch.randint(low=0, + high=config.max_tokens_per_expert, + size=(config.num_experts, ), + device="cuda", + dtype=torch.int32) + return BatchedMMTensors(A, B, C, num_expert_tokens) + + +def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, + num_expert_tokens: torch.Tensor) -> torch.Tensor: + + num_expert_tokens_cpu = num_expert_tokens.clone() + num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu") + num_experts = num_expert_tokens.size(0) + + for e in range(num_experts): + num_tokens = num_expert_tokens_cpu[e] + C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1) + + return C + + +@pytest.mark.parametrize("num_experts", [16, 32]) +@pytest.mark.parametrize("max_tokens_per_expert", + [32, 64, 128, 192, 224, 256, 512]) +@pytest.mark.parametrize("K", [128, 256, 1024]) +@pytest.mark.parametrize("N", [128, 256, 512, 1024]) +@pytest.mark.parametrize("dtype", + [torch.float32, torch.float16, torch.bfloat16]) +def test_batched_mm(num_experts: int, 
max_tokens_per_expert: int, K: int, + N: int, dtype: torch.dtype): + + config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N) + tensors = BatchedMMTensors.make_tensors(config) + + test_output = tensors.C + ref_output = test_output.clone() + + compute_tl_dtype = { + torch.float16: tl.float16, + torch.bfloat16: tl.bfloat16, + torch.float32: tl.float32 + }[test_output.dtype] + invoke_moe_batched_triton_kernel( + tensors.A, + tensors.B, + test_output, + tensors.num_expert_tokens, + compute_tl_dtype, + # Quantization data + None, + None, + None, + # Quantization schemes + False, + False, + False, + config={ + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 16 + }) + + ref_output = ref_impl(tensors.A, tensors.B, ref_output, + tensors.num_expert_tokens) + + rtol, atol = { + torch.float16: (6e-2, 6e-2), + torch.bfloat16: (6e-2, 6e-2), + torch.float32: (1e-2, 1e-2), + }[test_output.dtype] + + torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol) diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index 975cd418a171..7db4fe0f46e3 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -30,6 +30,11 @@ (224, 3072, 1536), ] +vllm_config = VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1)) +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + @dataclasses.dataclass class MOETensors: @@ -190,7 +195,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit, 'w1_q': moe_tensors.w1_q.transpose(1, 2), # type: ignore[union-attr] 'w2_q': moe_tensors.w2_q.transpose(1, 2), # type: ignore[union-attr] 'topk_weights': topk_weights, - 'topk_ids_': topk_ids, + 'topk_ids': topk_ids, 'ab_strides1': moe_tensors.ab_strides1, 'c_strides1': moe_tensors.c_strides1, 'ab_strides2': moe_tensors.ab_strides2, @@ -231,18 +236,15 @@ def test_cutlass_moe_8_bit_no_graph( per_out_ch: bool, ): 
current_platform.seed_everything(7) - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - + with set_current_vllm_config(vllm_config): mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch) score = torch.randn((m, e), device="cuda", dtype=torch.half) - topk_weights, topk_ids = fused_topk(mt.a, - score, - topk, - renormalize=False) + topk_weights, topk_ids, _ = fused_topk(mt.a, + score, + topk, + renormalize=False) # Note that we are using the dequantized versions of the tensors. # Using a, w1 and w2 directly results in minor output differences. @@ -276,20 +278,17 @@ def test_cutlass_moe_8_bit_cuda_graph( per_out_ch: bool, ): current_platform.seed_everything(7) - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - + with set_current_vllm_config(vllm_config): dtype = torch.half mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch) score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids = fused_topk(mt.a, - score, - topk, - renormalize=False) + topk_weights, topk_ids, _ = fused_topk(mt.a, + score, + topk, + renormalize=False) # Note that we are using the dequantized versions of the tensors. # Using a, w1 and w2 directly results in minor output differences. 
@@ -334,18 +333,15 @@ def test_cutlass_moe_8_bit_EP( ep_size: int, ): current_platform.seed_everything(7) - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - + with set_current_vllm_config(vllm_config): mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_channel) score = torch.randn((m, e), device="cuda", dtype=torch.half) - topk_weights, topk_ids = fused_topk(mt.a, - score, - topk, - renormalize=False) + topk_weights, topk_ids, _ = fused_topk(mt.a, + score, + topk, + renormalize=False) # Note that we are using the dequantized versions of the tensors. # Using a, w1 and w2 directly results in minor output differences. diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index abf3e3667a75..299279390fe0 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -12,10 +12,13 @@ import vllm.model_executor.layers.fused_moe # noqa from tests.kernels.utils import opcheck, stack_and_dev, torch_moe +from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( fused_moe as iterative_moe) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + rand_marlin_weight_fp4_like) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( marlin_quant_fp8_torch) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( @@ -30,6 +33,10 @@ EP_SIZE = [1, 4] TOP_KS = [2, 6] +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + @pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @@ -68,31 +75,33 @@ def test_fused_moe( else: e_map = None - torch_output = torch_moe(a, w1, w2, 
score, topk, e_map) - iterative_output = iterative_moe(a, - w1, - w2, - score, - topk, - global_num_experts=e, - expert_map=e_map, - renormalize=False) + with set_current_vllm_config(vllm_config): + torch_output = torch_moe(a, w1, w2, score, topk, e_map) + iterative_output = iterative_moe(a, + w1, + w2, + score, + topk, + global_num_experts=e, + expert_map=e_map, + renormalize=False) + + # Pad the weight if moe padding is enabled + if padding: + w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128] + torch.cuda.empty_cache() + w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128] + torch.cuda.empty_cache() + + triton_output = fused_moe(a, + w1, + w2, + score, + topk, + global_num_experts=e, + expert_map=e_map, + renormalize=False) - # Pad the weight if moe padding is enabled - if padding: - w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128] - torch.cuda.empty_cache() - w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128] - torch.cuda.empty_cache() - - triton_output = fused_moe(a, - w1, - w2, - score, - topk, - global_num_experts=e, - expert_map=e_map, - renormalize=False) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) torch.testing.assert_close(iterative_output, torch_output, @@ -113,7 +122,6 @@ def test_fused_moe( def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, ep_size: int, dtype: torch.dtype, group_size: int, has_zp: bool, weight_bits: int): - print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits) a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 @@ -192,22 +200,24 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, else: e_map = None - triton_output = fused_moe(a, - w1_qweight, - w2_qweight, - score, - topk, - renormalize=False, - use_int4_w4a16=weight_bits == 4, - use_int8_w8a16=weight_bits == 8, - global_num_experts=e, - expert_map=e_map, - 
w1_scale=w1_scales, - w2_scale=w2_scales, - w1_zp=w1_qzeros if has_zp else None, - w2_zp=w2_qzeros if has_zp else None, - block_shape=[0, group_size]) - torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map) + with set_current_vllm_config(vllm_config): + triton_output = fused_moe(a, + w1_qweight, + w2_qweight, + score, + topk, + renormalize=False, + use_int4_w4a16=weight_bits == 4, + use_int8_w8a16=weight_bits == 8, + global_num_experts=e, + expert_map=e_map, + w1_scale=w1_scales, + w2_scale=w2_scales, + w1_zp=w1_qzeros if has_zp else None, + w2_zp=w2_qzeros if has_zp else None, + block_shape=[0, group_size]) + torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map) + torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) @@ -222,9 +232,16 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, """Make sure our Mixtral MoE implementation agrees with the one from huggingface.""" + # clear the cache before every test + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + is_rocm_aiter_moe_enabled) + is_rocm_aiter_moe_enabled.cache_clear() if use_rocm_aiter: monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + if dtype == torch.float32: + pytest.skip("AITER ROCm test skip for float32") + # Instantiate our and huggingface's MoE blocks config = MixtralConfig() hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda") @@ -286,20 +303,64 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, atol=mixtral_moe_tol[dtype]) -@pytest.mark.parametrize("m", [1, 123, 666]) -@pytest.mark.parametrize("n", [128, 1024]) -@pytest.mark.parametrize("k", [256, 2048]) -@pytest.mark.parametrize("e", [4, 12]) -@pytest.mark.parametrize("topk", [2, 3]) -@pytest.mark.parametrize("ep_size", [1, 4]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("group_size", [-1, 32, 128]) -@pytest.mark.parametrize("act_order", [True, False]) 
-@pytest.mark.parametrize("quant_type", [ - scalar_types.uint4, scalar_types.uint8b128, scalar_types.uint4b8, - scalar_types.float8_e4m3fn -]) -@pytest.mark.parametrize("is_k_full", [True, False]) +def marlin_moe_generate_valid_test_cases(): + import itertools + m_list = [1, 123, 666] + n_list = [128, 1024] + k_list = [256, 2048] + e_list = [4, 12] + topk_list = [2, 3] + ep_size_list = [1, 4] + dtype_list = [torch.half, torch.bfloat16] + group_size_list = [-1, 16, 32, 128] + act_order_list = [True, False] + quant_type_list = [ + scalar_types.float4_e2m1f, + scalar_types.float8_e4m3fn, + scalar_types.uint4, + scalar_types.uint4b8, + scalar_types.uint8b128, + ] + is_k_full_list = [True, False] + + all_combinations = itertools.product(m_list, n_list, k_list, e_list, + topk_list, ep_size_list, dtype_list, + group_size_list, act_order_list, + quant_type_list, is_k_full_list) + + def is_invalid(m, n, k, e, topk, ep_size, dtype, group_size, act_order, + quant_type, is_k_full): + + if quant_type == scalar_types.float8_e4m3fn and \ + group_size not in [-1, 128]: + return False + if quant_type == scalar_types.float4_e2m1f and group_size != 16: + return False + if quant_type != scalar_types.float4_e2m1f and group_size == 16: + return False + + # Filter act_order + if act_order: + if group_size in (-1, k, n): + return False + if quant_type not in [scalar_types.uint4b8]: + return False + elif not is_k_full: + return False + + return True + + cases = [] + for case in all_combinations: + if is_invalid(*case): + cases.append(case) + return cases + + +@pytest.mark.flaky(reruns=2) +@pytest.mark.parametrize(("m, n, k, e, topk, ep_size, dtype, group_size," + "act_order, quant_type, is_k_full"), + marlin_moe_generate_valid_test_cases()) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") def test_fused_marlin_moe( m: int, @@ -337,6 +398,11 @@ def test_fused_marlin_moe( if not is_k_full: return + if quant_type == scalar_types.float4_e2m1f and group_size != 16: + 
return + if quant_type != scalar_types.float4_e2m1f and group_size == 16: + return + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20 @@ -354,12 +420,27 @@ def test_fused_marlin_moe( w_ref1_l = [] qweight1_l = [] scales1_l = [] + global_scale1_l = [] zeros1_l = [] g_idx1_l = [] sort_indices1_l = [] for i in range(w1.shape[0]): - if has_zp: + if quant_type == scalar_types.float4_e2m1f: + w_ref1, qweight1, scales1, global_scale1 = \ + rand_marlin_weight_fp4_like(w1[i], group_size) + + w_ref1_l.append(w_ref1.T) + qweight1_l.append(qweight1) + scales1_l.append(scales1) + global_scale1_l.append(global_scale1) + elif quant_type == scalar_types.float8_e4m3fn: + w_ref1, qweight1, scales1 = marlin_quant_fp8_torch( + w1[i], group_size) + w_ref1_l.append(w_ref1.T) + qweight1_l.append(qweight1) + scales1_l.append(scales1) + elif has_zp: w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize( w1[i].transpose(1, 0), quant_type, group_size) @@ -367,7 +448,7 @@ def test_fused_marlin_moe( qweight1_l.append(qweight1) scales1_l.append(scales1) zeros1_l.append(zeros1) - elif quant_type != scalar_types.float8_e4m3fn: + else: test_perm = torch.randperm(k) w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \ marlin_quantize(w1[i].transpose(1, 0), quant_type, @@ -378,16 +459,11 @@ def test_fused_marlin_moe( scales1_l.append(scales1) g_idx1_l.append(g_idx1) sort_indices1_l.append(sort_indices1) - else: - w_ref1, qweight1, scales1 = marlin_quant_fp8_torch( - w1[i], group_size) - w_ref1_l.append(w_ref1.T) - qweight1_l.append(qweight1) - scales1_l.append(scales1) w_ref1 = stack_and_dev(w_ref1_l) qweight1 = stack_and_dev(qweight1_l).contiguous() scales1 = stack_and_dev(scales1_l) + global_scale1 = stack_and_dev(global_scale1_l) if global_scale1_l else None g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None zeros1 = stack_and_dev(zeros1_l) if zeros1_l 
else None sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None @@ -395,12 +471,27 @@ def test_fused_marlin_moe( w_ref2_l = [] qweight2_l = [] scales2_l = [] + global_scale2_l = [] zeros2_l = [] g_idx2_l = [] sort_indices2_l = [] for i in range(w2.shape[0]): - if has_zp: + if quant_type == scalar_types.float4_e2m1f: + w_ref2, qweight2, scales2, global_scale2 = \ + rand_marlin_weight_fp4_like(w2[i], group_size) + + w_ref2_l.append(w_ref2.T) + qweight2_l.append(qweight2) + scales2_l.append(scales2) + global_scale2_l.append(global_scale2) + elif quant_type == scalar_types.float8_e4m3fn: + w_ref2, qweight2, scales2 = marlin_quant_fp8_torch( + w2[i], group_size) + w_ref2_l.append(w_ref2.T) + qweight2_l.append(qweight2) + scales2_l.append(scales2) + elif has_zp: w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize( w2[i].transpose(1, 0), quant_type, group_size) @@ -408,7 +499,7 @@ def test_fused_marlin_moe( qweight2_l.append(qweight2) scales2_l.append(scales2) zeros2_l.append(zeros2) - elif quant_type != scalar_types.float8_e4m3fn: + else: test_perm = torch.randperm(n) w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \ marlin_quantize(w2[i].transpose(1, 0), quant_type, @@ -419,26 +510,21 @@ def test_fused_marlin_moe( scales2_l.append(scales2) g_idx2_l.append(g_idx2) sort_indices2_l.append(sort_indices2) - else: - w_ref2, qweight2, scales2 = marlin_quant_fp8_torch( - w2[i], group_size) - w_ref2_l.append(w_ref2.T) - qweight2_l.append(qweight2) - scales2_l.append(scales2) w_ref2 = stack_and_dev(w_ref2_l) qweight2 = stack_and_dev(qweight2_l).contiguous() scales2 = stack_and_dev(scales2_l) + global_scale2 = stack_and_dev(global_scale2_l) if global_scale2_l else None g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids, 
token_expert_indices = fused_topk( - a, score, topk, False) + topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) - torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map) + with set_current_vllm_config(vllm_config): + torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map) marlin_output = torch.ops.vllm.fused_marlin_moe( a, @@ -451,6 +537,8 @@ def test_fused_marlin_moe( topk_ids, global_num_experts=e, expert_map=e_map, + global_scale1=global_scale1, + global_scale2=global_scale2, g_idx1=g_idx1, g_idx2=g_idx2, sort_indices1=sort_indices1, @@ -487,3 +575,21 @@ def test_moe_align_block_size_opcheck(): opcheck(torch.ops._moe_C.moe_align_block_size, (topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad)) + + +@pytest.mark.parametrize("m", [1, 33, 64, 222]) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("k", [128, 511, 1024]) +@pytest.mark.parametrize("dtype", + [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") +def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype): + input = torch.randn((m, topk, k), device="cuda", dtype=dtype) + actual = torch.empty((m, k), device="cuda", dtype=dtype) + + expected = input.sum(dim=1) + torch.ops._moe_C.moe_sum(input, actual) + + torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0) + + opcheck(torch.ops._moe_C.moe_sum, (input, actual)) diff --git a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py index dfcd61f77587..10e6ac64df87 100644 --- a/tests/kernels/moe/test_moe_permute_unpermute.py +++ b/tests/kernels/moe/test_moe_permute_unpermute.py @@ -13,7 +13,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.layer import determine_expert_map from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( - moe_permute, moe_unpermute) + moe_permute, 
moe_permute_unpermute_supported, moe_unpermute) from vllm.platforms import current_platform NUM_EXPERTS = [16, 64] @@ -167,6 +167,8 @@ def torch_unpermute(permuted_hidden_states: torch.Tensor, def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int, n_expert: int, ep_size: int, dtype: torch.dtype, align_block_size: Optional[int]): + if not moe_permute_unpermute_supported(): + pytest.skip("moe_permute_unpermute is not supported on this platform.") fill_invalid_expert = 0 ep_rank = np.random.randint(0, ep_size) expert_map = None diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py new file mode 100644 index 000000000000..ae63b379f39d --- /dev/null +++ b/tests/kernels/moe/test_nvfp4_moe.py @@ -0,0 +1,144 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch + +from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype) +from tests.kernels.utils import torch_moe +from vllm import _custom_ops as ops +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.platforms import current_platform + +if not current_platform.has_device_capability(100): + pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.", + allow_module_level=True) + +MNK_FACTORS = [ + (2, 1024, 1024), + (2, 1024, 1536), + (2, 3072, 1024), + (2, 3072, 1536), + (64, 1024, 1024), + (64, 1024, 1536), + (64, 3072, 1024), + (64, 2048, 1536), + (224, 1024, 1024), + (224, 1024, 1536), +] + + +@pytest.mark.parametrize("m,n,k", MNK_FACTORS) +@pytest.mark.parametrize("e", [40, 64, 256]) +@pytest.mark.parametrize("topk", [1, 6, 8]) +@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) +@torch.inference_mode() +def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, + dtype: 
torch.dtype): + current_platform.seed_everything(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + quant_blocksize = 16 + round_up = lambda x, y: (x + y - 1) // y * y + sf_w1_2n = round_up(2 * n, 128) + sf_w1_k = round_up(k // quant_blocksize, 4) + w1_blockscale = torch.empty((e, sf_w1_2n, sf_w1_k), + device="cuda", + dtype=torch.float8_e4m3fn) + + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + sf_w2_k = round_up(k, 128) + sf_w2_n = round_up(n // quant_blocksize, 4) + w2_blockscale = torch.empty((e, sf_w2_k, sf_w2_n), + device="cuda", + dtype=torch.float8_e4m3fn) + + w1_q = torch.empty((e, 2 * n, k // 2), + device="cuda", + dtype=torch.uint8) + w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8) + w1_gs = torch.empty((e, ), device="cuda", dtype=torch.float32) + w2_gs = torch.empty((e, ), device="cuda", dtype=torch.float32) + + for expert in range(e): + w1_amax = torch.abs(w1).max().to(torch.float32) + w2_amax = torch.abs(w2).max().to(torch.float32) + w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax + w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax + + w1_q[expert], w1_blockscale[expert] = ops.scaled_fp4_quant( + w1[expert], w1_gs[expert]) + + w2_q[expert], w2_blockscale[expert] = ops.scaled_fp4_quant( + w2[expert], w2_gs[expert]) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + + a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) + a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) + + cutlass_output = cutlass_moe_fp4( + a=a, + a1_gscale=a1_gs, + w1_fp4=w1_q, + w1_blockscale=w1_blockscale, + w1_alphas=(1 / w1_gs), + a2_gscale=a2_gs, + w2_fp4=w2_q, + w2_blockscale=w2_blockscale, + w2_alphas=(1 / w2_gs), + 
topk_weights=topk_weights, + topk_ids=topk_ids, + m=m, + n=n, + k=k, + e=e, + device=a.device, + ) + + # Reference check: + a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / + torch.amax(a.flatten(), dim=-1)).to(torch.float32) + a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale) + _, m_k = a_fp4.shape + a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4, + a_scale_interleaved, + a_global_scale, + dtype=a.dtype, + device=a.device, + block_size=quant_blocksize) + + w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype) + w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype) + + for idx in range(0, e): + w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx], + w1_blockscale[idx], + w1_gs[idx], + dtype=w1.dtype, + device=w1.device, + block_size=quant_blocksize) + w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx], + w2_blockscale[idx], + w2_gs[idx], + dtype=w2.dtype, + device=w2.device, + block_size=quant_blocksize) + + torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None) + + torch.testing.assert_close(torch_output, + cutlass_output, + atol=1e-1, + rtol=1e-1) + + +if __name__ == "__main__": + test_cutlass_fp4_moe_no_graph((2, 1024, 1024), 40, 1, torch.half) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py new file mode 100644 index 000000000000..8c4a2c3fa440 --- /dev/null +++ b/tests/kernels/moe/test_pplx_moe.py @@ -0,0 +1,691 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the MOE layers. + +Run `pytest tests/kernels/moe/test_pplx_moe.py`. 
+""" +import dataclasses +import os +import traceback +from typing import Callable, Optional + +import pytest +import torch + +try: + from pplx_kernels import AllToAll + from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, + nvshmem_finalize, nvshmem_get_unique_id, + nvshmem_init) + has_pplx = True +except ImportError: + has_pplx = False + +from torch.multiprocessing import ( + spawn) # pyright: ignore[reportPrivateImportUsage] +from typing_extensions import Concatenate, ParamSpec + +from vllm.config import VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import override_config +from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts) +from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk, + get_default_config) +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEModularKernel) +from vllm.platforms import current_platform + +PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512), + (222, 2048, 1024)] + +PPLX_MOE_COMBOS = [ + (1, 128, 128), + (2, 128, 512), + (3, 1024, 2048), + (32, 128, 1024), + (45, 512, 2048), + (64, 1024, 1024), + (222, 1024, 2048), +] + +NUM_EXPERTS = [8, 64] +EP_SIZE = [1, 4] +TOP_KS = [1, 2, 6] + +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + +P = ParamSpec("P") + +requires_pplx = pytest.mark.skipif( + not has_pplx, + reason="Requires PPLX kernels", +) + + +@dataclasses.dataclass +class ProcessGroupInfo: + world_size: int + world_local_size: int + rank: int + node_rank: int + local_rank: int + device: torch.device + + +def _worker_parallel_launch( + local_rank: int, + world_size: int, + world_local_size: int, + node_rank: int, + init_method: str, + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + 
**kwargs: P.kwargs, +) -> None: + rank = node_rank * world_local_size + local_rank + torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + torch.distributed.init_process_group( + backend="cpu:gloo,cuda:nccl", + init_method=init_method, + rank=rank, + world_size=world_size, + device_id=device, + ) + barrier = torch.tensor([rank], device=device) + torch.distributed.all_reduce(barrier) + + try: + worker( + ProcessGroupInfo( + world_size=world_size, + world_local_size=world_local_size, + rank=rank, + node_rank=node_rank, + local_rank=local_rank, + device=device, + ), + *args, + **kwargs, + ) + except Exception as ex: + print(ex) + traceback.print_exc() + raise + finally: + torch.distributed.destroy_process_group() + + +def parallel_launch( + world_size: int, + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + assert not kwargs + spawn( + _worker_parallel_launch, + args=( + world_size, + world_size, + 0, + "tcp://localhost:29500", + worker, + ) + args, + nprocs=world_size, + join=True, + ) + + +def parallel_launch_from_env( + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + """ + Launches a worker function in parallel across all processes in the current + environment. The environment must have the following variables set: + - WORLD_SIZE: The total number of processes. + - WORLD_LOCAL_SIZE: The number of processes on the current node. + - NODE_RANK: The rank of the current node. + - MASTER_ADDR: The address of the master process. + - MASTER_PORT: The port of the master process. 
+ """ + assert not kwargs + world_size = int(os.environ["WORLD_SIZE"]) + world_local_size = int(os.environ["WORLD_LOCAL_SIZE"]) + node_rank = int(os.environ["NODE_RANK"]) + assert "MASTER_ADDR" in os.environ + assert "MASTER_PORT" in os.environ + spawn( + _worker_parallel_launch, + args=( + world_size, + world_local_size, + node_rank, + "env://", + worker, + ) + args, + nprocs=world_local_size, + join=True, + ) + + +def torch_prepare( + a: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + max_num_tokens: Optional[int] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + assert topk_ids.dim() == 2 + assert topk_ids.shape[0] == a.shape[0] + + num_tokens, hidden_dim = a.shape + topk = topk_ids.shape[1] + + tokens_per_expert = torch.bincount(topk_ids.view(-1), + minlength=num_experts) + + assert tokens_per_expert.numel() == num_experts + + if max_num_tokens is None: + max_num_tokens = int(tokens_per_expert.max().item()) + + b_a = torch.zeros((num_experts, max_num_tokens, hidden_dim), + dtype=a.dtype, + device=a.device) + + token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device) + + for token in range(num_tokens): + for j in range(topk): + expert_id = topk_ids[token, j] + idx = token_counts[expert_id] + b_a[expert_id, idx:idx + 1, :] = a[token, :] + token_counts[expert_id] = token_counts[expert_id] + 1 + + return b_a, tokens_per_expert + + +def torch_finalize(b_out: torch.Tensor, topk_weight: torch.Tensor, + topk_ids: torch.Tensor) -> torch.Tensor: + num_tokens = topk_ids.shape[0] + num_experts = b_out.shape[0] + K = b_out.shape[-1] + out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device) + expert_counts = torch.zeros(num_experts, + dtype=torch.int, + device=b_out.device) + for token in range(num_tokens): + expert_ids = topk_ids[token] + for i in range(expert_ids.numel()): + expert_id = expert_ids[i] + idx = expert_counts[expert_id] + out[token, :] = out[token, :] + b_out[expert_id, idx:idx + + 1, :] * topk_weight[token, 
i] + expert_counts[expert_id] = expert_counts[expert_id] + 1 + + return out + + +def torch_batched_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, +) -> torch.Tensor: + num_experts = w1.shape[0] + b_a, tokens_per_expert = torch_prepare(a, topk_ids, num_experts) + assert b_a.dim() == 3 + num_tokens, topk = topk_ids.shape + _, max_num_tokens, K = b_a.shape + assert num_experts == b_a.shape[0] and w2.shape[1] == K + out = torch.zeros((num_experts, max_num_tokens, K), + dtype=b_a.dtype, + device=b_a.device) + tmp = torch.empty((max_num_tokens, w1.shape[1] // 2), + dtype=b_a.dtype, + device=b_a.device) + for expert in range(num_experts): + num = tokens_per_expert[expert] + if num > 0: + torch.ops._C.silu_and_mul( + tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1)) + out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1) + + return torch_finalize(out, topk_weight, topk_ids) + + +def batched_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, +) -> torch.Tensor: + num_experts = w1.shape[0] + + fused_experts = FusedMoEModularKernel( + BatchedPrepareAndFinalize(a.shape[0], world_size=1, dp_size=1, rank=0), + BatchedExperts(max_num_tokens=a.shape[0], dp_size=1, world_size=1)) + + return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts) + + +# Note: same as torch_moe but with fused_topk factored out. 
+def torch_moe2( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, +) -> torch.Tensor: + M, K = a.shape + topk = topk_ids.shape[1] + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) + out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) + num_experts = w1.shape[0] + for i in range(num_experts): + mask = (topk_ids == i).view(-1) + if mask.sum(): + out[mask] = SiluAndMul()( + a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) + + return (out.view(M, -1, w2.shape[1]) * + topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + + +@pytest.mark.parametrize("m", [1, 33, 64, 222]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 512, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +def test_fused_moe_batched_experts( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, +): + current_platform.seed_everything(7) + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + score = torch.randn((m, e), device="cuda", dtype=dtype) + + with set_current_vllm_config(vllm_config): + topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) + baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids) + batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids) + + torch.testing.assert_close(baseline_output, + torch_output, + atol=2e-2, + rtol=0) + torch.testing.assert_close(baseline_output, + batched_output, + atol=2e-2, + rtol=0) + + +def rank_chunk(num: int, r: int, w: int) -> int: + rem = num % w + return (num // w) + (1 if r < rem else 0) + + +def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor: + chunk = 
rank_chunk(t.shape[0], r, w) + return t[(r * chunk):(r + 1) * chunk] + + +def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor, + topk_weight: torch.Tensor, topk_ids: torch.Tensor, + num_experts: int) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( + PplxPrepareAndFinalize) + + assert torch.cuda.current_device() == pgi.local_rank + + topk = topk_ids.shape[1] + num_tokens, hidden_dim = a.shape + block_size = 128 + device = pgi.device + rank = pgi.rank + world_size = pgi.world_size + max_num_tokens = rank_chunk(num_tokens, 0, world_size) + + ata = AllToAll.internode( + max_num_tokens=max_num_tokens, + num_experts=num_experts, + experts_per_token=topk, + rank=rank, + world_size=world_size, + dp_size=dp_size, + hidden_dim=hidden_dim, + hidden_dim_bytes=hidden_dim * a.dtype.itemsize, + hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else + ((hidden_dim + block_size - 1) // block_size * + torch.float32.itemsize)), + ) + + topk_ids = topk_ids.to(dtype=torch.uint32) + + prepare_finalize = PplxPrepareAndFinalize( + ata, + max_num_tokens, + world_size, + rank, + dp_size, + a.dtype, + ) + + a_chunk = chunk_by_rank(a, rank, world_size).to(device) + chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device) + chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device) + + b_a, b_a_scale, expert_num_tokens = prepare_finalize.prepare( + a_chunk, + None, + None, + chunk_topk_weight, + chunk_topk_ids, + num_experts, + None, + False, + ) + + b_a = b_a * 1.5 + + out = torch.full( + (max_num_tokens, hidden_dim), + torch.nan, + dtype=a.dtype, + device=device, + ) + + prepare_finalize.finalize( + out, + b_a, + chunk_topk_weight, + chunk_topk_ids, + False, + ) + + torch.cuda.synchronize() + + ata.destroy() + + num_tokens = a_chunk.shape[0] + + return out[:num_tokens] + + +def _pplx_prepare_finalize( + pgi: ProcessGroupInfo, + dp_size: int, + a: torch.Tensor, + score: torch.Tensor, + topk: 
torch.Tensor, + num_experts: int, +): + uid = nvshmem_get_unique_id( + ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + torch.distributed.broadcast(uid, src=0) + nvshmem_init(uid, pgi.rank, pgi.world_size) + device = pgi.device + + topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) + k = a.shape[1] + + a_rep = torch.repeat_interleave(a, topk, dim=0).to(device) + + torch_output = (a_rep.view(-1, topk, k) * 1.5 * + topk_weight.view(-1, topk, 1).to(device)).sum(dim=1).to( + a.dtype) + + pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight, topk_ids, + num_experts) + + torch_output = chunk_by_rank(torch_output, pgi.rank, + pgi.world_size).to(pplx_output.device) + + torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) + + nvshmem_finalize() + + +# TODO (bnell): this test point does not work for odd M due to how the test is +# written, not due to limitations of the pplx kernels. The pplx_moe +# test below is able to deal with odd M. +@pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) +@requires_pplx +def test_pplx_prepare_finalize( + mnk: tuple[int, int, int], + e: int, + topk: int, + dtype: torch.dtype, + world_dp_size: tuple[int, int], +): + current_platform.seed_everything(7) + m, n, k = mnk + world_size, dp_size = world_dp_size + device = "cuda" + a = torch.randn((m, k), device=device, dtype=dtype) / 10 + score = torch.randn((m, e), device=device, dtype=dtype) + + parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, score, + topk, e) + + +def pplx_moe( + rank: int, + world_size: int, + dp_size: int, + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + use_compile: bool = True, + use_cudagraphs: bool = True, +) -> torch.Tensor: + from 
vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( + PplxPrepareAndFinalize) + + device = torch.device("cuda", rank) + hidden_dim = a.shape[1] + num_experts = w1.shape[0] + block_size = 128 + topk = topk_ids.shape[1] + max_num_tokens = rank_chunk(a.shape[0], 0, world_size) + + ata = AllToAll.internode( + max_num_tokens=max_num_tokens, + num_experts=num_experts, + experts_per_token=topk, + rank=rank, + world_size=world_size, + dp_size=dp_size, + hidden_dim=hidden_dim, + hidden_dim_bytes=hidden_dim * a.dtype.itemsize, + hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else + ((hidden_dim + block_size - 1) // block_size * + torch.float32.itemsize)), + ) + + topk_ids = topk_ids.to(dtype=torch.uint32) + + prepare_finalize = PplxPrepareAndFinalize( + ata, + max_num_tokens, + world_size, + rank, + dp_size, + ) + + experts = BatchedTritonExperts(max_num_tokens=a.shape[0], + world_size=world_size, + dp_size=dp_size) + + fused_experts = FusedMoEModularKernel( + prepare_finalize, + experts, + ) + + # Note: workers with the same dp_rank must use the exact same inputs. 
+ a_chunk = chunk_by_rank(a, rank, world_size).to(device) + chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device) + chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device) + + # Chunking weights like this only works for batched format + w1_chunk = chunk_by_rank(w1, rank, world_size).to(device) + w2_chunk = chunk_by_rank(w2, rank, world_size).to(device) + + if use_compile: + _fused_experts = torch.compile(fused_experts, + backend='inductor', + fullgraph=True) + else: + _fused_experts = fused_experts + + out = _fused_experts(a_chunk, + w1_chunk, + w2_chunk, + chunk_topk_weight, + chunk_topk_ids, + global_num_experts=num_experts) + + if use_cudagraphs: + out.fill_(0) + stream = torch.cuda.Stream() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + out = _fused_experts(a_chunk, + w1_chunk, + w2_chunk, + chunk_topk_weight, + chunk_topk_ids, + global_num_experts=num_experts) + + torch.cuda.synchronize() + graph.replay() + + torch.cuda.synchronize() + + ata.destroy() + + return out + + +def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): + assert torch.cuda.current_device() == pgi.local_rank + + num_experts = w1.shape[0] + device = pgi.device + rank = pgi.rank + world_size = pgi.world_size + max_num_tokens = rank_chunk(a.shape[0], 0, world_size) + + prepare_finalize = BatchedPrepareAndFinalize( + max_num_tokens=max_num_tokens, + world_size=world_size, + dp_size=dp_size, + rank=rank, + ) + + experts = BatchedExperts(max_num_tokens=a.shape[0], + world_size=1, + dp_size=1) + + fused_experts = FusedMoEModularKernel( + prepare_finalize, + experts, + ) + + # Note: workers with the same dp_rank must use the exact same inputs. 
+ a_chunk = chunk_by_rank(a, rank, world_size).to(device) + chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device) + chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device) + + out = fused_experts( + a_chunk, + # Chunking weights like this only works for batched format + chunk_by_rank(w1, rank, world_size).to(device), + chunk_by_rank(w2, rank, world_size).to(device), + chunk_topk_weight, + chunk_topk_ids, + global_num_experts=num_experts) + + return out + + +def _pplx_moe( + pgi: ProcessGroupInfo, + dp_size: int, + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, + topk: int, +): + uid = nvshmem_get_unique_id( + ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + torch.distributed.broadcast(uid, src=0) + nvshmem_init(uid, pgi.rank, pgi.world_size) + + m, k = a.shape + e, _, n = w2.shape + + moe_config = get_default_config(m, e, n, k, topk, a.dtype, False) + + with set_current_vllm_config(vllm_config), override_config(moe_config): + topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) + torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + pplx_output = pplx_moe(pgi.rank, pgi.world_size, dp_size, a, w1, w2, + topk_weight, topk_ids) + # TODO (bnell): fix + re-enable + #batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, + # topk_ids) + + torch_output = chunk_by_rank(torch_output, pgi.rank, + pgi.world_size).to(pplx_output.device) + + torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) + #torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0) + + nvshmem_finalize() + + +@pytest.mark.parametrize("mnk", PPLX_MOE_COMBOS) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) +@requires_pplx +def test_pplx_moe( + mnk: tuple[int, int, int], + e: int, + topk: int, + dtype: torch.dtype, + 
world_dp_size: tuple[int, int], +): + current_platform.seed_everything(7) + m, n, k = mnk + world_size, dp_size = world_dp_size + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + score = torch.randn((m, e), device="cuda", dtype=dtype) + + parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk) diff --git a/tests/kernels/moe/test_rocm_aiter_topk.py b/tests/kernels/moe/test_rocm_aiter_topk.py new file mode 100644 index 000000000000..b0d34ddfd423 --- /dev/null +++ b/tests/kernels/moe/test_rocm_aiter_topk.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: Apache-2.0 +# This is a test for the AITER ops. +# It tests if the AITER ops are +# 1. correctly registered as custom ops +# 2. correctly defined the relationship between +# implementation and fake function +# 3. can be used with torch.compile +# This file will be skipped if AITER is not installed +# and the platform is not ROCm. 
+ +import importlib.util + +import pytest +import torch + +# this import statement is needed to ensure the ops are registered +import vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe # noqa: F401 +from vllm.platforms import current_platform + +# need to import once to ensure the ops are registered +# Check if aiter package is installed +aiter_available = importlib.util.find_spec("aiter") is not None + +pytestmark = pytest.mark.skipif( + not (current_platform.is_rocm() and aiter_available), + reason="AITER ops are only available on ROCm with aiter package installed") + + +def test_rocm_aiter_biased_grouped_topk_custom_op_registration(): + """Test that the custom op is correctly registered.""" + # Check if the op exists in torch.ops.vllm + assert hasattr(torch.ops.vllm, 'rocm_aiter_biased_grouped_topk') + + # Check if the op is callable + assert callable(torch.ops.vllm.rocm_aiter_biased_grouped_topk) + + +def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility(): + """Test that the op can be used with torch.compile.""" + # Create test tensors + token = 64 + expert = 256 + num_expert_group = 8 + topk = 8 + topk_group = 4 + renormalize = True + scale_factor = 1.0 + + gating_output = torch.randn((token, expert), + dtype=torch.bfloat16, + device="cuda") + e_score_correction_bias = torch.randn((expert, ), + dtype=torch.bfloat16, + device="cuda") + + device = gating_output.device + topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device) + topk_weights = torch.empty((token, topk), + dtype=torch.float32, + device=device) + + # Define a function that uses the op + def biased_grouped_topk_fn(gating_output, e_score_correction_bias, + topk_weights, topk_ids): + return torch.ops.vllm.rocm_aiter_biased_grouped_topk( + gating_output, e_score_correction_bias, topk_weights, topk_ids, + num_expert_group, topk_group, renormalize, scale_factor) + + # Verify the op's fake implementation + torch.library.opcheck( + 
torch.ops.vllm.rocm_aiter_biased_grouped_topk, + (gating_output, e_score_correction_bias, topk_weights, topk_ids), + kwargs={ + "num_expert_group": num_expert_group, + "topk_group": topk_group, + "need_renorm": renormalize, + "routed_scaling_factor": scale_factor + }, + test_utils=("test_faketensor")) + + # Compile the function with appropriate settings + compiled_fn = torch.compile(biased_grouped_topk_fn, + fullgraph=True, + backend="inductor", + mode="reduce-overhead", + dynamic=False) + + topk_weights_original = torch.empty((token, topk), + dtype=torch.float32, + device=device) + topk_ids_original = torch.empty((token, topk), + dtype=torch.int32, + device=device) + + topk_weights_compiled = torch.empty((token, topk), + dtype=torch.float32, + device=device) + topk_ids_compiled = torch.empty((token, topk), + dtype=torch.int32, + device=device) + + # Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode) + biased_grouped_topk_fn(gating_output, e_score_correction_bias, + topk_weights_original, topk_ids_original) + compiled_fn(gating_output, e_score_correction_bias, topk_weights_compiled, + topk_ids_compiled) + + # Sort the results for comparison since the order might not be deterministic + topk_ids_original, indices_original = torch.sort(topk_ids_original) + topk_weights_original = torch.gather(topk_weights_original, 1, + indices_original) + + topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled) + topk_weights_compiled = torch.gather(topk_weights_compiled, 1, + indices_compiled) + + # Verify results match + assert torch.allclose(topk_weights_original, + topk_weights_compiled, + rtol=1e-2, + atol=1e-2) + assert torch.allclose(topk_ids_original, topk_ids_compiled) diff --git a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py index 44734e9340aa..3b5838a99fa1 100644 --- a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py +++ b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py @@ -7,6 +7,7 @@ import 
torch from vllm import _custom_ops as ops +from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.platforms import current_platform @@ -15,6 +16,10 @@ pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): """Matrix multiplication function that supports per-token input @@ -137,20 +142,21 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale score = torch.randn((M, E), dtype=dtype) - ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk) - out = fused_moe( - a, - w1, - w2, - score, - topk, - renormalize=False, - use_fp8_w8a8=True, # using fp8 - per_channel_quant=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=None, # Not using block quantization - ) + with set_current_vllm_config(vllm_config): + ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk) + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, # using fp8 + per_channel_quant=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=None, # Not using block quantization + ) # Check results rel_diff = (torch.mean( diff --git a/tests/kernels/quantization/nvfp4_utils.py b/tests/kernels/quantization/nvfp4_utils.py new file mode 100644 index 000000000000..58eaeee1c0b8 --- /dev/null +++ b/tests/kernels/quantization/nvfp4_utils.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +import torch + +from vllm.scalar_type import scalar_types + +FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + +kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 
4., 6.], + dtype=torch.float32) + + +def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): + m_tiles = (m + 128 - 1) // 128 + f = block_size * 4 + k_tiles = (k + f - 1) // f + tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4)) + tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5)) + out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size) + return out[0:m, 0:k] + + +def dequantize_nvfp4_to_dtype(tensor_fp4, + tensor_sf, + global_scale, + dtype, + device, + block_size=16): + """Dequantize the fp4 tensor back to high precision.""" + # Two fp4 values are packed into one uint8. + assert tensor_fp4.dtype == torch.uint8 + m, packed_k = tensor_fp4.shape + k = packed_k * 2 + tensor_f32 = break_fp4_bytes(tensor_fp4, dtype) + tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) + tensor_sf = tensor_sf.view(torch.float8_e4m3fn) + tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) + tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale + + # scale the tensor + out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k) + return out.to(dtype=dtype) + + +def break_fp4_bytes(a, dtype): + assert a.dtype == torch.uint8 + m, n = a.shape + + # Vectorized nibble processing + a_flat = a.flatten() + high = (a_flat & 0xF0) >> 4 # Upper nibbles + low = a_flat & 0x0F # Lower nibbles + + # Combine nibbles for batch processing + combined = torch.stack((low, high), dim=1).flatten() + + # Vectorized sign and magnitude extraction + signs = (combined & 0x08).to(torch.bool) # Sign bits + abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices + + # Device-aware lookup and sign application + kE2M1 = kE2M1ToFloat.to(device=a.device) + values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0) + + # Reshape to final form + return values.reshape(m, n * 2).to(dtype=dtype) diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index 38c7e461bb9c..ae05d61173f3 100644 --- 
a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -11,7 +11,7 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - deep_gemm_moe_fp8) + _valid_deep_gemm_shape, deep_gemm_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) @@ -30,18 +30,22 @@ pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + # Test configurations DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32] -NUM_TOKENS = [7, 83, 2048] +NUM_TOKENS = [7, 2050] D = [512, 4096, 5120, 13824] -GROUP_SIZE = [64, 128, 256, 512] -M = [1, 7, 8, 83, 84, 512, 2048, 4096] -N = [128, 512, 1024, 4096, 7168, 7748, 13824] -K = [256, 4096, 5120, 3884, 13824, 16384] +GROUP_SIZE = [64, 128, 512] +M = [1, 7, 8, 83, 84, 4096] +N = [128, 512, 7168, 7748, 13824] +K = [256, 3884, 4096, 13824, 16384] # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. -M_moe = [1, 2, 7, 83, 128, 512, 2048] -M_moe_dg = [128, 192, 512, 1335, 2048] +M_moe = [1, 2, 7, 83, 128, 2048] +M_moe_dg = [128, 192, 1335, 2048] N_moe = [128, 256, 1024, 4608] # [13824] K_moe = [256, 512, 7168] # [13824] BLOCK_SIZE = [[128, 128]] @@ -210,7 +214,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): score = torch.randn((M, E), dtype=dtype) # Set the context to avoid lots of warning spam. 
- vllm_config = VllmConfig() with set_current_vllm_config(vllm_config): out = fused_moe( a, @@ -258,6 +261,7 @@ def per_block_cast_to_fp8( @pytest.mark.parametrize( "M,N,K,block_size,out_dtype,seed", itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) +@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): # only aligned sizes @@ -381,15 +385,11 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): block_size = [block_m, block_m] dtype = torch.bfloat16 - # only aligned sizes - if (N % block_m != 0 or K % block_m != 0 or topk > E): - pytest.skip( - f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") - - if N <= 512: - pytest.skip("Skipping N <= 512 until performance issues solved.") + if topk > E: + pytest.skip(f"Skipping test: topk={topk} > E={E}") - vllm_config = VllmConfig() + if not _valid_deep_gemm_shape(M, N, K): + pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}") torch.manual_seed(seed) fp8_info = torch.finfo(torch.float8_e4m3fn) diff --git a/tests/kernels/quantization/test_block_int8.py b/tests/kernels/quantization/test_block_int8.py index 104f23fd7cd2..a4e9f83f0eaf 100644 --- a/tests/kernels/quantization/test_block_int8.py +++ b/tests/kernels/quantization/test_block_int8.py @@ -18,6 +18,10 @@ pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True) +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + # For test def native_per_token_group_quant_int8(x, @@ -174,7 +178,6 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): score = torch.randn((M, E), dtype=dtype) # Set the context to avoid lots of warning spam. 
- vllm_config = VllmConfig() with set_current_vllm_config(vllm_config): out = fused_moe( a, diff --git a/tests/kernels/quantization/test_cutlass_scaled_mm.py b/tests/kernels/quantization/test_cutlass_scaled_mm.py index 8084d9bf2c2d..633addd421f4 100644 --- a/tests/kernels/quantization/test_cutlass_scaled_mm.py +++ b/tests/kernels/quantization/test_cutlass_scaled_mm.py @@ -95,7 +95,7 @@ def cutlass_fp8_gemm_helper(m: int, out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2) + torch.testing.assert_close(out, baseline, rtol=1e-2, atol=1.5e-1) opcheck(torch.ops._C.cutlass_scaled_mm, (out, a, b, scale_a, scale_b, bias)) @@ -161,6 +161,8 @@ def test_cutlass_fp8_blockwise_scale_gemm(m: int, n: int, k: int, return if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0: return + if m % 4 != 0 and current_platform.has_device_capability(100): + return cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias) diff --git a/tests/kernels/quantization/test_gguf.py b/tests/kernels/quantization/test_gguf.py index 6cf88604ec65..ad755fe7f7a0 100644 --- a/tests/kernels/quantization/test_gguf.py +++ b/tests/kernels/quantization/test_gguf.py @@ -8,7 +8,6 @@ from huggingface_hub import snapshot_download import vllm._custom_ops as ops -from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.quantization.gguf import _fused_moe_gguf from vllm.platforms import current_platform @@ -35,11 +34,11 @@ def get_gguf_MoE_tensors( return GGUFReader(sample_file).tensors -DTYPES = [torch.half, torch.bfloat16, torch.float32] +DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32] # Hidden_size for testing, must match the sample file in HF repo, # we have `hidden_size = 256, 1024` for test in HF 
repo currently. HIDDEN_SIZES = [256, 1024] -NUM_TOKENS = [7, 83, 128, 2048] # Arbitrary values for testing +NUM_TOKENS = [7, 2050] # Arbitrary values for testing SEEDS = [0] QUANT_TYPES = [ # i-matrix @@ -176,12 +175,11 @@ def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype, w2_dequant = torch.tensor(dequantize(w2.data, quant_type), device="cuda").to(dtype) - act = SiluAndMul() output = _fused_moe_gguf(x, torch.tensor(w13.data, device="cuda"), torch.tensor(w2.data, device="cuda"), topk_weights, - topk_ids, quant_type, quant_type, act) + topk_ids, quant_type, quant_type, "silu") ref_output = fused_experts(x, w13_dequant, w2_dequant, topk_weights, topk_ids).reshape(output.shape) diff --git a/tests/kernels/quantization/test_marlin_gemm.py b/tests/kernels/quantization/test_marlin_gemm.py index c125e0b5ec75..52507b375c27 100644 --- a/tests/kernels/quantization/test_marlin_gemm.py +++ b/tests/kernels/quantization/test_marlin_gemm.py @@ -20,6 +20,8 @@ MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx, marlin_make_workspace_new, marlin_permute_scales, query_marlin_supported_quant_types) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_fp4_like) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( marlin_quant_fp8_torch) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( @@ -190,9 +192,10 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("quant_type", - query_marlin_supported_quant_types(False)) -@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) +@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types()) +@pytest.mark.parametrize( + "group_size", + set(MARLIN_SUPPORTED_GROUP_SIZES + 
FP4_MARLIN_SUPPORTED_GROUP_SIZES)) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @pytest.mark.parametrize("is_k_full", K_FULL_OPTS) @@ -210,6 +213,7 @@ def test_gptq_marlin_gemm( use_fp32_reduce, ): m_factor, n_factor, k_factor = mnk_factors + has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] size_m = m_factor size_k = k_chunk * k_factor @@ -220,6 +224,8 @@ def test_gptq_marlin_gemm( return if group_size == size_k: return + if has_zp: + return if size_k % group_size != 0: return @@ -227,7 +233,15 @@ def test_gptq_marlin_gemm( a_input = rand_data((size_m, size_k)) b_weight = rand_data((size_k, size_n)) - if quant_type == scalar_types.float8_e4m3fn: + if quant_type == scalar_types.float4_e2m1f: + if group_size != 16 or act_order: + return + w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like( + b_weight.T, group_size) + g_idx = None + sort_indices = None + marlin_zp = None + elif quant_type == scalar_types.float8_e4m3fn: if group_size not in [-1, 128]: return if act_order: @@ -236,26 +250,39 @@ def test_gptq_marlin_gemm( b_weight.T, group_size) g_idx = None sort_indices = None + marlin_zp = None + marlin_s2 = None + elif has_zp: + if group_size == 16: + return + w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize( + b_weight, quant_type, group_size) + g_idx = None + sort_indices = None + marlin_s2 = None else: + if group_size == 16: + return w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( b_weight, quant_type, group_size, act_order) - - marlin_zp = marlin_make_empty_g_idx(marlin_s.device) + marlin_zp = None + marlin_s2 = None workspace = marlin_make_workspace_new(w_ref.device) - opcheck( - torch.ops._C.gptq_marlin_gemm, - (a_input, None, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices, - workspace, quant_type.id, a_input.shape[0], b_weight.shape[1], - a_input.shape[1], is_k_full, use_atomic_add, use_fp32_reduce, False), - 
test_utils=DEFAULT_OPCHECK_TEST_UTILS) + opcheck(torch.ops._C.gptq_marlin_gemm, + (a_input, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, g_idx, + sort_indices, workspace, quant_type.id, a_input.shape[0], + b_weight.shape[1], a_input.shape[1], is_k_full, use_atomic_add, + use_fp32_reduce, False), + test_utils=DEFAULT_OPCHECK_TEST_UTILS) output = ops.gptq_marlin_gemm( a_input, None, marlin_q_w, marlin_s, + marlin_s2, marlin_zp, g_idx, sort_indices, @@ -339,67 +366,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, assert max_diff < 0.04 -@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), - reason="Marlin is not supported on this GPU type.") -@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) -@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("quant_type", - query_marlin_supported_quant_types(True)) -@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) -@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS) -def test_awq_marlin_gemm( - k_chunk, - n_chunk, - quant_type, - group_size, - mnk_factors, - use_fp32_reduce, -): - m_factor, n_factor, k_factor = mnk_factors - - size_m = m_factor - size_k = k_chunk * k_factor - size_n = n_chunk * n_factor - - a_input = rand_data((size_m, size_k)) - b_weight = rand_data((size_k, size_n)) - - w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize( - b_weight, quant_type, group_size) - - g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device) - sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device) - is_k_full = True - - workspace = marlin_make_workspace_new(a_input.device) - - output = ops.gptq_marlin_gemm( - a_input, - None, - marlin_q_w, - marlin_s, - marlin_zp, - g_idx, - sort_indices, - workspace, - quant_type, - a_input.shape[0], - b_weight.shape[1], - a_input.shape[1], - is_k_full=is_k_full, - use_fp32_reduce=use_fp32_reduce, - 
is_zp_float=False, - ) - output_ref = torch.matmul(a_input, w_ref) - - torch.cuda.synchronize() - - max_diff = compute_max_diff(output, output_ref) - - assert max_diff < 0.04 - - @pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @@ -452,6 +418,7 @@ def test_hqq_marlin_gemm( None, marlin_w_q, marlin_s, + None, marlin_zp, g_idx, g_idx_sort_indices, @@ -564,6 +531,7 @@ def test_marlin_gemm_subset_input(): None, marlin_q_w, marlin_s, + None, marlin_zp, g_idx, sort_indices, diff --git a/tests/kernels/quantization/test_nvfp4_scaled_mm.py b/tests/kernels/quantization/test_nvfp4_scaled_mm.py index b08026c5867d..1f49900b2d90 100644 --- a/tests/kernels/quantization/test_nvfp4_scaled_mm.py +++ b/tests/kernels/quantization/test_nvfp4_scaled_mm.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 import pytest import torch +from nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype) from vllm import _custom_ops as ops from vllm.platforms import current_platform -from vllm.scalar_type import scalar_types if not current_platform.has_device_capability(100): pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.", @@ -19,95 +20,24 @@ SEEDS = [42] CUDA_DEVICES = ['cuda:0'] -FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max() -FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max - -kE2M1ToFloatArray = [ - 0., - 0.5, - 1., - 1.5, - 2., - 3., - 4., - 6., -] - - -def e2m1_to_fp32(int4_value): - signBit = (int4_value & 0x8) - int4_absValue = int4_value & 0x7 - float_result = kE2M1ToFloatArray[int4_absValue] - if (signBit): - float_result = -float_result - return float_result - - -def break_fp4_bytes(a, dtype): - assert (a.dtype == torch.uint8) - m, n = a.shape - a = a.flatten() - # Get upper 4 bits - highHalfByte = (a & 0xF0) >> 4 - # Get lower 4 bits - lowHalfByte = a & 0x0F - fH = torch.tensor([e2m1_to_fp32(x) for x 
in highHalfByte]).to(a.device) - fL = torch.tensor([e2m1_to_fp32(x) for x in lowHalfByte]).to(a.device) - # [0xAB, 0xCD] -> [0xB, 0xA, 0xD, 0xC] - out = torch.stack((fL, fH), dim=-1).reshape(m, n * 2) - return out - - -def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): - sf_m, sf_k = a_sf_swizzled.shape - m_tiles = (m + 128 - 1) // 128 - f = block_size * 4 - k_tiles = (k + f - 1) // f - tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4)) - tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5)) - out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size) - return out[0:m, 0:k] - - -def dequantize_to_dtype(tensor_fp4, - tensor_sf, - global_scale, - dtype, - device, - block_size=16): - """Dequantize the fp4 tensor back to high precision.""" - # Two fp4 values are packed into one uint8. - assert tensor_fp4.dtype == torch.uint8 - m, packed_k = tensor_fp4.shape - k = packed_k * 2 - tensor_f32 = break_fp4_bytes(tensor_fp4, dtype) - tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) - tensor_sf = tensor_sf.view(torch.float8_e4m3fn) - tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) - tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale - - # scale the tensor - out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k) - return out - def get_ref_results(a_fp4, b_fp4, a_sf, b_sf, a_global_scale, b_global_scale, m, n, dtype, block_size, device): _, m_k = a_fp4.shape _, n_k = b_fp4.shape assert (m_k == n_k) - a_in_dtype = dequantize_to_dtype(a_fp4, - a_sf, - a_global_scale, - dtype=dtype, - device=device, - block_size=block_size) - b_in_dtype = dequantize_to_dtype(b_fp4, - b_sf, - b_global_scale, - dtype=dtype, - device=device, - block_size=block_size) + a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4, + a_sf, + a_global_scale, + dtype=dtype, + device=device, + block_size=block_size) + b_in_dtype = dequantize_nvfp4_to_dtype(b_fp4, + b_sf, + b_global_scale, + dtype=dtype, + device=device, + 
block_size=block_size) return torch.matmul(a_in_dtype, b_in_dtype.t()) diff --git a/tests/kernels/quantization/test_rocm_skinny_gemms.py b/tests/kernels/quantization/test_rocm_skinny_gemms.py index 76d33169081a..c7eee899896a 100644 --- a/tests/kernels/quantization/test_rocm_skinny_gemms.py +++ b/tests/kernels/quantization/test_rocm_skinny_gemms.py @@ -58,8 +58,9 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed): @pytest.mark.parametrize("m", M + [28672]) # m >= 16 @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.skipif(not current_platform.is_rocm(), - reason="only test for rocm") +@pytest.mark.skipif( + not (current_platform.is_rocm() and current_platform.supports_fp8()), + reason="only test for rocm fp8") def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed): torch.manual_seed(seed) diff --git a/tests/kernels/quantization/test_triton_scaled_mm.py b/tests/kernels/quantization/test_triton_scaled_mm.py index 45f10b0eb1d5..30e6eeb8d566 100644 --- a/tests/kernels/quantization/test_triton_scaled_mm.py +++ b/tests/kernels/quantization/test_triton_scaled_mm.py @@ -13,8 +13,13 @@ device = "cuda" +triton_scaled_mm_module = importlib.import_module( + "vllm.model_executor.layers.quantization.compressed_tensors." + "triton_scaled_mm") +triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm -def scaled_mm_torch(a: torch.Tensor, + +def torch_scaled_mm(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, @@ -101,21 +106,8 @@ def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a, if use_bias: bias = torch.rand((N, ), device=device, dtype=out_dtype) - triton_scaled_mm_module = importlib.import_module( - "vllm.model_executor.layers.quantization.compressed_tensors." 
- "triton_scaled_mm") - triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - c_check = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - a_cpu = a.cpu() - b_cpu = b.cpu() - scale_a_cpu = scale_a.cpu() - scale_b_cpu = scale_b.cpu() - bias_cpu = None if bias is None else bias.cpu() - - c_actual = scaled_mm_torch(a_cpu, b_cpu, scale_a_cpu, scale_b_cpu, - out_dtype, bias_cpu) + c_actual = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - c_check_cpu = c_check.cpu() - torch.testing.assert_close(c_check_cpu, c_actual, rtol=1e-1, atol=1e-1) + torch.testing.assert_close(c_check, c_actual, rtol=1e-1, atol=1e-1) diff --git a/tests/kernels/test_fused_quant_activation.py b/tests/kernels/test_fused_quant_activation.py index fa84ad74cd88..faa8d49ce41b 100644 --- a/tests/kernels/test_fused_quant_activation.py +++ b/tests/kernels/test_fused_quant_activation.py @@ -5,9 +5,10 @@ import vllm._custom_ops as ops from tests.kernels.utils import opcheck from vllm.model_executor.layers.activation import SiluAndMul +from vllm.platforms import current_platform DTYPES = [torch.bfloat16, torch.float16] -QUANT_DTYPES = [torch.float8_e4m3fn] +QUANT_DTYPES = [current_platform.fp8_dtype()] NUM_TOKENS = [1, 17, 86, 1234, 3045] # Arbitrary values for testing HIDDEN_SIZES = [16, 48, 128, 1562, 4096] # Arbitrary values for testing SEEDS = [0] @@ -26,7 +27,7 @@ def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor, def ops_impl(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: out_shape = (x.shape[0], x.shape[1] // 2) out = torch.empty(out_shape, - dtype=torch.torch.float8_e4m3fn, + dtype=current_platform.fp8_dtype(), device=x.device) torch.ops._C.silu_and_mul_quant(out, x, scale) return out diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index b940f7190bb2..399311ce65bb 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -139,6 +139,12 @@ def dummy_model_gate_up() -> nn.Module: return model +@pytest.fixture(scope="session") +def 
llama_2_7b_base_huggingface_id(): + # used as a base model for testing with sql lora adapter + return "meta-llama/Llama-2-7b-hf" + + @pytest.fixture(scope="session") def sql_lora_huggingface_id(): # huggingface repo id is used to test lora runtime downloading. @@ -198,6 +204,12 @@ def qwen2vl_lora_files(): return snapshot_download(repo_id="jeeejeee/qwen2-vl-lora-pokemon") +@pytest.fixture(scope="session") +def qwen25vl_base_huggingface_id(): + # used as a base model for testing with qwen25vl lora adapter + return "Qwen/Qwen2.5-VL-3B-Instruct" + + @pytest.fixture(scope="session") def qwen25vl_lora_files(): return snapshot_download(repo_id="jeeejeee/qwen25-vl-lora-pokemon") @@ -261,8 +273,8 @@ def run_with_both_engines_lora(request, monkeypatch): @pytest.fixture def reset_default_device(): """ - Some tests, such as `test_punica_ops.py`, explicitly set the - default device, which can affect subsequent tests. Adding this fixture + Some tests, such as `test_punica_ops.py`, explicitly set the + default device, which can affect subsequent tests. Adding this fixture helps avoid this problem. 
""" original_device = torch.get_default_device() diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index e3a054bd6206..580992dea53d 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -1,12 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 +import subprocess +import sys +from typing import Union import pytest import ray import vllm +from vllm import LLM from vllm.lora.request import LoRARequest +from vllm.model_executor.model_loader.tensorizer import TensorizerConfig -from ..utils import create_new_process_for_each_test, multi_gpu_test +from ..utils import VLLM_PATH, create_new_process_for_each_test, multi_gpu_test MODEL_PATH = "meta-llama/Llama-2-7b-hf" @@ -36,7 +41,10 @@ def v1(run_with_both_engines_lora): pass -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: +def do_sample(llm: vllm.LLM, + lora_path: str, + lora_id: int, + tensorizer_config_dict: Union[dict, None] = None) -> list[str]: prompts = [ "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501 @@ -45,15 +53,28 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? 
[/user] [assistant]", # noqa: E501 "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501 ] + sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256, skip_special_tokens=False, stop=["[/assistant]"]) - outputs = llm.generate( - prompts, - sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None) + + if tensorizer_config_dict is not None: + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest( + str(lora_id), + lora_id, + lora_path, + tensorizer_config_dict=tensorizer_config_dict) + if lora_id else None) + else: + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None) # Print the outputs. generated_texts: list[str] = [] for output in outputs: @@ -64,18 +85,32 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: return generated_texts -def generate_and_test(llm, sql_lora_files): +def generate_and_test(llm, + sql_lora_files, + tensorizer_config_dict: Union[dict, None] = None): print("lora adapter created") - assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT + assert do_sample(llm, + sql_lora_files, + tensorizer_config_dict=tensorizer_config_dict, + lora_id=0) == EXPECTED_NO_LORA_OUTPUT print("lora 1") - assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT + assert do_sample(llm, + sql_lora_files, + tensorizer_config_dict=tensorizer_config_dict, + lora_id=1) == EXPECTED_LORA_OUTPUT print("no lora") - assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT + assert do_sample(llm, + sql_lora_files, + tensorizer_config_dict=tensorizer_config_dict, + lora_id=0) == EXPECTED_NO_LORA_OUTPUT print("lora 2") - assert 
do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT + assert do_sample(llm, + sql_lora_files, + tensorizer_config_dict=tensorizer_config_dict, + lora_id=2) == EXPECTED_LORA_OUTPUT print("removing lora") @@ -153,3 +188,64 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files): enable_chunked_prefill=True, ) generate_and_test(llm, sql_lora_files) + + +@multi_gpu_test(num_gpus=2) +@create_new_process_for_each_test() +def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files, + sql_lora_huggingface_id): + + # Run the tensorizing of the LoRA adapter and the model in a subprocess + # to guarantee cleanup + + tp_size = 2 + model_name = "model-rank-%03d.tensors" + + model_ref = MODEL_PATH + lora_path = sql_lora_huggingface_id + suffix = "test" + try: + result = subprocess.run([ + sys.executable, + f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", "--model", + MODEL_PATH, "--lora-path", lora_path, "--tensor-parallel-size", + str(tp_size), "serialize", "--serialized-directory", + str(tmp_path), "--suffix", suffix + ], + check=True, + capture_output=True, + text=True) + except subprocess.CalledProcessError as e: + print("Tensorizing failed.") + print("STDOUT:\n", e.stdout) + print("STDERR:\n", e.stderr) + raise + + print("STDOUT:\n", result.stdout) + + model_uri = tmp_path / "vllm" / model_ref / suffix / model_name + tensorizer_config = TensorizerConfig(tensorizer_uri=str(model_uri)) + tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir + + loaded_vllm_model = LLM(model=model_ref, + load_format="tensorizer", + enable_lora=True, + enforce_eager=True, + model_loader_extra_config=tensorizer_config, + max_num_seqs=13, + tensor_parallel_size=2, + max_loras=2) + + tensorizer_config_dict = tensorizer_config.to_dict() + + print("lora adapter created") + assert do_sample(loaded_vllm_model, + sql_lora_files, + tensorizer_config_dict=tensorizer_config_dict, + lora_id=0) == EXPECTED_NO_LORA_OUTPUT + + print("lora 1") + assert 
do_sample(loaded_vllm_model, + sql_lora_files, + tensorizer_config_dict=tensorizer_config_dict, + lora_id=1) == EXPECTED_LORA_OUTPUT diff --git a/tests/lora/test_lora_allowed_token_ids.py b/tests/lora/test_lora_allowed_token_ids.py new file mode 100644 index 000000000000..094541aef02b --- /dev/null +++ b/tests/lora/test_lora_allowed_token_ids.py @@ -0,0 +1,134 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, + VllmConfig) +from vllm.lora.request import LoRARequest +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.v1.engine.processor import Processor + + +def test_allowed_token_ids_with_lora_vocab(llama_2_7b_base_huggingface_id, + sql_lora_files): + """ + Test that we properly resolve the range of allowed token ids for lora + adapters that define additional tokens. + """ + + # Setup a base model compatible with the sql_lora_files adapter and + # a known number of tokens in the base model. 
+ model_config = ModelConfig( + model=llama_2_7b_base_huggingface_id, + tokenizer=llama_2_7b_base_huggingface_id, + tokenizer_mode="auto", + ) + + vllm_config = VllmConfig( + model_config=model_config, + cache_config=CacheConfig(), + device_config=DeviceConfig(), + lora_config=LoRAConfig(), + ) + + tokenizer = init_tokenizer_from_configs( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + lora_config=vllm_config.lora_config) + processor = Processor(vllm_config, tokenizer) + + lora_request = LoRARequest("1", 1, str(sql_lora_files)) + request_id = "1" + prompt = "a prompt" + + # tokens added in the lora adapter should not raise an error + lora_token_ids = [32000, 32001, 32002, 32003] + processor.process_inputs( + request_id, + prompt, + params=SamplingParams(allowed_token_ids=lora_token_ids), + lora_request=lora_request) + + # tokens in the base model should not raise an error + base_token_ids = [1000, 1001, 1002, 1003] + processor.process_inputs( + request_id, + prompt, + params=SamplingParams(allowed_token_ids=base_token_ids), + lora_request=lora_request) + + # tokens not in the lora adapter should raise an error + invalid_token_ids = [35000, 35001, 35002, 35003] + with pytest.raises(ValueError): + processor.process_inputs( + request_id, + prompt, + params=SamplingParams(allowed_token_ids=invalid_token_ids), + lora_request=lora_request) + + # tokens in the lora adapter with no lora request should raise an error + with pytest.raises(ValueError): + processor.process_inputs( + request_id, + prompt, + params=SamplingParams(allowed_token_ids=lora_token_ids), + ) + + +def test_allowed_token_ids_with_lora_adapter_no_vocab( + qwen25vl_base_huggingface_id, qwen25vl_lora_files): + """ + Test that we properly resolve the range of allowed token ids for lora + adapters that do not define additional tokens. + """ + + # Setup a base model compatible with the qwen25vl_lora_files adapter and + # a known number of tokens in the base model. 
+ model_config = ModelConfig( + model=qwen25vl_base_huggingface_id, + tokenizer=qwen25vl_base_huggingface_id, + tokenizer_mode="auto", + ) + + vllm_config = VllmConfig( + model_config=model_config, + cache_config=CacheConfig(), + device_config=DeviceConfig(), + lora_config=LoRAConfig(), + ) + + tokenizer = init_tokenizer_from_configs( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + lora_config=vllm_config.lora_config) + processor = Processor(vllm_config, tokenizer) + + lora_request = LoRARequest("1", 1, str(qwen25vl_lora_files)) + request_id = "1" + prompt = "a prompt" + + # tokens in the base model should not raise an error + base_token_ids = [1000, 1001, 1002, 1003] + processor.process_inputs( + request_id, + prompt, + params=SamplingParams(allowed_token_ids=base_token_ids), + lora_request=lora_request) + + # tokens in the base model with no lora request should not raise an error + base_token_ids = [1000, 1001, 1002, 1003] + processor.process_inputs( + request_id, + prompt, + params=SamplingParams(allowed_token_ids=base_token_ids), + ) + + # tokens not in the base model should raise an error + invalid_token_ids = [200000, 200001, 200002, 200003] + with pytest.raises(ValueError): + processor.process_inputs( + request_id, + prompt, + params=SamplingParams(allowed_token_ids=invalid_token_ids), + lora_request=lora_request) diff --git a/tests/lora/test_lora_functions.py b/tests/lora/test_lora_functions.py index 204624a0540a..7ae33a848a0a 100644 --- a/tests/lora/test_lora_functions.py +++ b/tests/lora/test_lora_functions.py @@ -69,7 +69,7 @@ def run_check(fn, args, expected: list): run_check(llm.add_lora, make_lora_request(12), [12, 9, 10, 11]) run_check(llm.add_lora, make_lora_request(13), [12, 13, 10, 11]) - # Remove all LoRAs + # Remove all LoRAs. 
run_check(llm.remove_lora, 13, [12, 10, 11]) run_check(llm.remove_lora, 12, [10, 11]) run_check(llm.remove_lora, 11, [10]) diff --git a/tests/lora/test_lora_huggingface.py b/tests/lora/test_lora_huggingface.py index 0875128c4ff1..90498c47fb10 100644 --- a/tests/lora/test_lora_huggingface.py +++ b/tests/lora/test_lora_huggingface.py @@ -30,7 +30,7 @@ def test_load_checkpoints_from_huggingface(lora_fixture_name, request): lora_path = get_adapter_absolute_path(lora_name) - # lora loading should work for either absolute path and hugggingface id. + # lora loading should work for either absolute path and huggingface id. peft_helper = PEFTHelper.from_local_dir(lora_path, 4096) lora_model = LoRAModel.from_local_checkpoint( lora_path, diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index 30b74ce3ef70..e5ae660af140 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -58,13 +58,19 @@ def set_active_loras(worker: Union[Worker, V1Worker], download_dir=None, load_format="dummy", ), - parallel_config=ParallelConfig(1, 1, False), + parallel_config=ParallelConfig( + pipeline_parallel_size=1, + tensor_parallel_size=1, + data_parallel_size=1, + ), scheduler_config=SchedulerConfig("generate", 32, 32, 32), device_config=DeviceConfig("cuda"), - cache_config=CacheConfig(block_size=16, - gpu_memory_utilization=1., - swap_space=0, - cache_dtype="auto"), + cache_config=CacheConfig( + block_size=16, + gpu_memory_utilization=1.0, + swap_space=0, + cache_dtype="auto", + ), lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32, max_loras=32), ) diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 2d9cf1d48fd5..e957db5b3f16 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -1,21 +1,22 @@ # SPDX-License-Identifier: Apache-2.0 import pytest +import torch from vllm.config import CompilationConfig, VllmConfig, 
set_current_vllm_config from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import (GeluAndMul, ReLUSquaredActivation, SiluAndMul) -from vllm.model_executor.layers.fused_moe.fused_moe import ( - dispatch_fused_experts_func, dispatch_topk_func, - torch_vllm_inplace_fused_experts, torch_vllm_outplace_fused_experts, - vllm_topk_softmax) +from vllm.model_executor.layers.fused_moe.fused_moe import (dispatch_topk_func, + vllm_topk_softmax) from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( is_rocm_aiter_moe_enabled) from vllm.model_executor.layers.layernorm import ( RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm, rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul) from vllm.platforms import current_platform @@ -98,35 +99,45 @@ def test_enabled_ops_invalid(env: str): RMSNorm(1024).enabled() +@pytest.mark.skipif( + not current_platform.is_rocm() or not current_platform.is_fp8_fnuz(), + reason="AITER is a feature exclusive for ROCm and FP8_FNUZ") +@pytest.mark.parametrize("use_cutlass", [True, False]) @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) -def test_topk_dispatch(use_rocm_aiter: str, monkeypatch): +@pytest.mark.parametrize("use_rocm_aiter_gemm_w8a8_blockscale", ["0", "1"]) +def test_w8a8_blockscale_dispatch(use_cutlass: bool, use_rocm_aiter: str, + use_rocm_aiter_gemm_w8a8_blockscale: str, + monkeypatch): + monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) - topk_func = dispatch_topk_func() - is_rocm_aiter_moe_enabled.cache_clear() - if current_platform.is_rocm() and int(use_rocm_aiter): - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - rocm_aiter_topk_softmax) - assert topk_func == rocm_aiter_topk_softmax + monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", + use_rocm_aiter_gemm_w8a8_blockscale) + + 
use_aiter_and_is_supported = (bool(int(use_rocm_aiter)) and bool( + int(use_rocm_aiter_gemm_w8a8_blockscale))) + block_scale_func = dispatch_w8a8_blockscale_func( + use_cutlass, use_aiter_and_is_supported=use_aiter_and_is_supported) + if use_cutlass: + assert block_scale_func == cutlass_scaled_mm + elif current_platform.is_rocm() and int(use_rocm_aiter) and int( + use_rocm_aiter_gemm_w8a8_blockscale): + assert block_scale_func == ( + torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale) else: - assert topk_func == vllm_topk_softmax + assert block_scale_func == w8a8_block_fp8_matmul @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) -@pytest.mark.parametrize("inplace", [True, False]) -def test_fused_experts_dispatch(use_rocm_aiter: str, inplace: bool, - monkeypatch): - +def test_topk_dispatch(use_rocm_aiter: str, monkeypatch): monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) + topk_func = dispatch_topk_func() is_rocm_aiter_moe_enabled.cache_clear() - fused_experts_func = dispatch_fused_experts_func(inplace) if current_platform.is_rocm() and int(use_rocm_aiter): from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - rocm_aiter_fused_experts) - assert fused_experts_func == rocm_aiter_fused_experts - elif inplace: - assert fused_experts_func == torch_vllm_inplace_fused_experts + rocm_aiter_topk_softmax) + assert topk_func == rocm_aiter_topk_softmax else: - assert fused_experts_func == torch_vllm_outplace_fused_experts + assert topk_func == vllm_topk_softmax @pytest.mark.parametrize("add_residual", [True, False]) diff --git a/tests/model_executor/weight_utils.py b/tests/model_executor/weight_utils.py index 11dfe4d4995d..bdaba22c3c7a 100644 --- a/tests/model_executor/weight_utils.py +++ b/tests/model_executor/weight_utils.py @@ -20,11 +20,11 @@ def test_hf_transfer_auto_activation(): try: # enable hf hub transfer if available import hf_transfer # type: ignore # noqa - HF_TRANFER_ACTIVE = True + HF_TRANSFER_ACTIVE = True except ImportError: - 
HF_TRANFER_ACTIVE = False + HF_TRANSFER_ACTIVE = False assert (huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER == - HF_TRANFER_ACTIVE) + HF_TRANSFER_ACTIVE) def test_download_weights_from_hf(): diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py index c755593c9acb..05dd18fbdf8b 100644 --- a/tests/models/language/generation/test_common.py +++ b/tests/models/language/generation/test_common.py @@ -28,6 +28,7 @@ "Qwen/Qwen-7B-Chat", "Qwen/Qwen2.5-0.5B-Instruct", "TitanML/tiny-mixtral", + "Qwen/Qwen3-8B", ] @@ -78,6 +79,9 @@ "Qwen/Qwen2.5-0.5B-Instruct", # qwen2 marks=[pytest.mark.core_model], ), + pytest.param( + "Qwen/Qwen3-8B", # qwen (text-only) + ), pytest.param("stabilityai/stablelm-3b-4e1t"), # stablelm pytest.param("bigcode/starcoder2-3b"), # starcoder2 pytest.param( diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 9b7a42acece5..604cb854b32f 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -31,7 +31,7 @@ # not compatible with pip-compile. 
"pfnet/plamo-2-1b", "Zyphra/Zamba2-1.2B-instruct", - "hmellor/bamba-tiny-random", + "hmellor/tiny-random-BambaForCausalLM", ] # Avoid OOM diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py new file mode 100644 index 000000000000..f83c9940d524 --- /dev/null +++ b/tests/models/language/pooling/mteb_utils.py @@ -0,0 +1,117 @@ +# SPDX-License-Identifier: Apache-2.0 +from collections.abc import Sequence + +import mteb +import numpy as np +import pytest + +from tests.models.utils import EmbedModelInfo +from vllm.model_executor.model_loader.utils import set_default_torch_dtype + +# Most models on the STS12 task (See #17175): +# - Model implementation and minor changes in tensor dtype +# results in differences less than 1e-4 +# - Different model results in differences more than 1e-3 +# 1e-4 is a good tolerance threshold +MTEB_EMBED_TASKS = ["STS12"] +MTEB_EMBED_TOL = 1e-4 + + +class VllmMtebEncoder(mteb.Encoder): + + def __init__(self, vllm_model): + super().__init__() + self.model = vllm_model + self.rng = np.random.default_rng(seed=42) + + def encode( + self, + sentences: Sequence[str], + *args, + **kwargs, + ) -> np.ndarray: + # Hoping to discover potential scheduling + # issues by randomizing the order. + r = self.rng.permutation(len(sentences)) + sentences = [sentences[i] for i in r] + outputs = self.model.encode(sentences, use_tqdm=False) + embeds = np.array(outputs) + embeds = embeds[np.argsort(r)] + return embeds + + +class OpenAIClientMtebEncoder(mteb.Encoder): + + def __init__(self, model_name: str, client): + super().__init__() + self.model_name = model_name + self.client = client + self.rng = np.random.default_rng(seed=42) + + def encode(self, sentences: Sequence[str], *args, **kwargs) -> np.ndarray: + # Hoping to discover potential scheduling + # issues by randomizing the order. 
+ r = self.rng.permutation(len(sentences)) + sentences = [sentences[i] for i in r] + + embeddings = self.client.embeddings.create(model=self.model_name, + input=sentences) + outputs = [d.embedding for d in embeddings.data] + embeds = np.array(outputs) + embeds = embeds[np.argsort(r)] + return embeds + + +def run_mteb_embed_task(encoder, tasks): + tasks = mteb.get_tasks(tasks=tasks) + evaluation = mteb.MTEB(tasks=tasks) + results = evaluation.run(encoder, verbosity=0, output_folder=None) + + main_score = results[0].scores["test"][0]["main_score"] + return main_score + + +def run_mteb_embed_task_st(model_name, tasks): + from sentence_transformers import SentenceTransformer + model = SentenceTransformer(model_name) + return run_mteb_embed_task(model, tasks) + + +def mteb_test_embed_models(hf_runner, + vllm_runner, + model_info: EmbedModelInfo, + vllm_extra_kwargs=None): + if not model_info.enable_test: + # A model family has many models with the same architecture, + # and we don't need to test each one. 
+ pytest.skip("Skipping test.") + + vllm_extra_kwargs = vllm_extra_kwargs or {} + + with vllm_runner(model_info.name, + task="embed", + max_model_len=None, + dtype=model_info.dtype, + **vllm_extra_kwargs) as vllm_model: + + if model_info.architecture: + assert (model_info.architecture + in vllm_model.model.llm_engine.model_config.architectures) + + vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model), + MTEB_EMBED_TASKS) + vllm_dtype = vllm_model.model.llm_engine.model_config.dtype + model_dtype = getattr( + vllm_model.model.llm_engine.model_config.hf_config, "torch_dtype", + vllm_dtype) + + with set_default_torch_dtype(model_dtype), hf_runner( + model_info.name, is_sentence_transformer=True, + dtype=model_dtype) as hf_model: + st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS) + + print("VLLM:", vllm_dtype, vllm_main_score) + print("SentenceTransformer:", model_dtype, st_main_score) + print("Difference:", st_main_score - vllm_main_score) + + assert st_main_score == pytest.approx(vllm_main_score, rel=MTEB_EMBED_TOL) diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py index 9db385e77bdb..a44b2154b137 100644 --- a/tests/models/language/pooling/test_embedding.py +++ b/tests/models/language/pooling/test_embedding.py @@ -15,13 +15,12 @@ marks=[pytest.mark.core_model, pytest.mark.cpu_model]), pytest.param("sentence-transformers/all-MiniLM-L12-v2"), pytest.param("intfloat/multilingual-e5-small"), - pytest.param("Alibaba-NLP/gte-Qwen2-7B-instruct"), + pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"), # [Decoder-only] pytest.param("BAAI/bge-multilingual-gemma2", marks=[pytest.mark.core_model]), pytest.param("intfloat/e5-mistral-7b-instruct", marks=[pytest.mark.core_model, pytest.mark.cpu_model]), - pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"), pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"), # [Cross-Encoder] pytest.param("sentence-transformers/stsb-roberta-base-v2"), @@
-47,9 +46,6 @@ def test_models( vllm_extra_kwargs["override_pooler_config"] = \ PoolerConfig(pooling_type="MEAN") - if model == "Alibaba-NLP/gte-Qwen2-1.5B-instruct": - vllm_extra_kwargs["hf_overrides"] = {"is_causal": True} - # The example_prompts has ending "\n", for example: # "Write a short story about a robot that dreams for the first time.\n" # sentence_transformers will strip the input texts, see: diff --git a/tests/models/language/pooling/test_gritlm.py b/tests/models/language/pooling/test_gritlm.py index 3ad6e7190942..f450edd82162 100644 --- a/tests/models/language/pooling/test_gritlm.py +++ b/tests/models/language/pooling/test_gritlm.py @@ -2,7 +2,6 @@ from __future__ import annotations import importlib.util -import math from array import array import openai @@ -11,7 +10,6 @@ from vllm import LLM, SamplingParams from vllm.config import ModelConfig -from vllm.utils import STR_BACKEND_ENV_VAR from ....utils import RemoteOpenAIServer @@ -105,56 +103,49 @@ def get_test_data(): def validate_embed_output(q_rep: list[list[float]], d_rep: list[list[float]]): cosine_sim_q0_d0 = 1 - cosine(q_rep[0], d_rep[0]) - assert math.isclose(cosine_sim_q0_d0, 0.609, abs_tol=0.001) + assert cosine_sim_q0_d0 == pytest.approx(0.609, abs=0.001) cosine_sim_q0_d1 = 1 - cosine(q_rep[0], d_rep[1]) - assert math.isclose(cosine_sim_q0_d1, 0.101, abs_tol=0.001) + assert cosine_sim_q0_d1 == pytest.approx(0.101, abs=0.001) cosine_sim_q1_d0 = 1 - cosine(q_rep[1], d_rep[0]) - assert math.isclose(cosine_sim_q1_d0, 0.120, abs_tol=0.001) + assert cosine_sim_q1_d0 == pytest.approx(0.120, abs=0.001) cosine_sim_q1_d1 = 1 - cosine(q_rep[1], d_rep[1]) - assert math.isclose(cosine_sim_q1_d1, 0.534, abs_tol=0.001) + assert cosine_sim_q1_d1 == pytest.approx(0.534, abs=0.001) -def test_gritlm_offline_embedding(monkeypatch: pytest.MonkeyPatch, - vllm_runner): - # GritLM embedding implementation is only supported by XFormers backend. 
- with monkeypatch.context() as m: - m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS") - - queries, q_instruction, documents, d_instruction = get_test_data() +def test_gritlm_offline_embedding(vllm_runner): + queries, q_instruction, documents, d_instruction = get_test_data() - with vllm_runner( - MODEL_NAME, - task="embed", - max_model_len=MAX_MODEL_LEN, - ) as vllm_model: - llm = vllm_model.model + with vllm_runner( + MODEL_NAME, + task="embed", + max_model_len=MAX_MODEL_LEN, + ) as vllm_model: + llm = vllm_model.model - d_rep = run_llm_encode( - llm, - documents, - d_instruction, - ) - q_rep = run_llm_encode( - llm, - queries, - q_instruction, - ) + d_rep = run_llm_encode( + llm, + documents, + d_instruction, + ) + q_rep = run_llm_encode( + llm, + queries, + q_instruction, + ) - validate_embed_output(q_rep, d_rep) + validate_embed_output(q_rep, d_rep) @pytest.mark.asyncio async def test_gritlm_api_server_embedding(): queries, q_instruction, documents, d_instruction = get_test_data() - # GritLM embedding implementation is only supported by XFormers backend. args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)] - env_dict = {STR_BACKEND_ENV_VAR: "XFORMERS"} - with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as server: + with RemoteOpenAIServer(MODEL_NAME, args) as server: client_embedding = server.get_async_client() d_rep = await run_client_embeddings( @@ -172,35 +163,28 @@ async def test_gritlm_api_server_embedding(): def test_gritlm_offline_generate(monkeypatch: pytest.MonkeyPatch, vllm_runner): - # GritLM embedding implementation is only supported by XFormers backend. 
- with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "0") - m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS") - - input = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n" + input = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n" - with vllm_runner( - MODEL_NAME, - task="generate", - max_model_len=MAX_MODEL_LEN, - ) as vllm_model: - llm = vllm_model.model + with vllm_runner( + MODEL_NAME, + task="generate", + max_model_len=MAX_MODEL_LEN, + ) as vllm_model: + llm = vllm_model.model - sampling_params = SamplingParams(temperature=0.0, max_tokens=256) - outputs = llm.generate(input, sampling_params=sampling_params) + sampling_params = SamplingParams(temperature=0.0, max_tokens=256) + outputs = llm.generate(input, sampling_params=sampling_params) - assert outputs[0].outputs[0].text == "The capital of France is Paris." + assert outputs[0].outputs[0].text == "The capital of France is Paris." @pytest.mark.asyncio async def test_gritlm_api_server_generate(): input = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n" - # GritLM embedding implementation is only supported by XFormers backend. 
args = ["--task", "generate", "--max_model_len", str(MAX_MODEL_LEN)] - env_dict = {"VLLM_USE_V1": "0", STR_BACKEND_ENV_VAR: "XFORMERS"} - with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as server: + with RemoteOpenAIServer(MODEL_NAME, args) as server: client_generate = server.get_async_client() outputs = await client_generate.completions.create( diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py new file mode 100644 index 000000000000..91d10f529cd6 --- /dev/null +++ b/tests/models/language/pooling/test_gte.py @@ -0,0 +1,93 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import pytest + +from ...utils import EmbedModelInfo, run_embedding_correctness_test + +MODELS = [ + ########## BertModel + EmbedModelInfo("thenlper/gte-large", + architecture="BertModel", + dtype="float32", + enable_test=True), + EmbedModelInfo("thenlper/gte-base", + architecture="BertModel", + dtype="float32", + enable_test=False), + EmbedModelInfo("thenlper/gte-small", + architecture="BertModel", + dtype="float32", + enable_test=False), + EmbedModelInfo("thenlper/gte-large-zh", + architecture="BertModel", + dtype="float32", + enable_test=False), + EmbedModelInfo("thenlper/gte-base-zh", + architecture="BertModel", + dtype="float32", + enable_test=False), + EmbedModelInfo("thenlper/gte-small-zh", + architecture="BertModel", + dtype="float32", + enable_test=False), + ########### NewModel + EmbedModelInfo("Alibaba-NLP/gte-multilingual-base", + architecture="GteNewModel", + enable_test=True), + EmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5", + architecture="GteNewModel", + enable_test=True), + EmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5", + architecture="GteNewModel", + enable_test=True), + ########### Qwen2ForCausalLM + EmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct", + architecture="Qwen2ForCausalLM", + enable_test=True), + ########## ModernBertModel + EmbedModelInfo("Alibaba-NLP/gte-modernbert-base", + 
architecture="ModernBertModel", + enable_test=True), +] + + +@pytest.mark.parametrize("model_info", MODELS) +def test_models_mteb(hf_runner, vllm_runner, + model_info: EmbedModelInfo) -> None: + from .mteb_utils import mteb_test_embed_models + + vllm_extra_kwargs: dict[str, Any] = {} + if model_info.architecture == "GteNewModel": + vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]} + + mteb_test_embed_models(hf_runner, vllm_runner, model_info, + vllm_extra_kwargs) + + +@pytest.mark.parametrize("model_info", MODELS) +def test_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo, + example_prompts) -> None: + if not model_info.enable_test: + pytest.skip("Skipping test.") + + # ST will strip the input texts, see test_embedding.py + example_prompts = [str(s).strip() for s in example_prompts] + + vllm_extra_kwargs: dict[str, Any] = {} + if model_info.architecture == "GteNewModel": + vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]} + + with vllm_runner(model_info.name, + task="embed", + dtype=model_info.dtype, + max_model_len=None, + **vllm_extra_kwargs) as vllm_model: + vllm_outputs = vllm_model.encode(example_prompts) + + with hf_runner( + model_info.name, + dtype=model_info.dtype, + is_sentence_transformer=True, + ) as hf_model: + run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs) diff --git a/tests/models/language/pooling/test_jina.py b/tests/models/language/pooling/test_jina.py index 5287ca37c0fb..0ddff2146caa 100644 --- a/tests/models/language/pooling/test_jina.py +++ b/tests/models/language/pooling/test_jina.py @@ -1,6 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -import math - import pytest from vllm import PoolingParams @@ -60,7 +58,7 @@ def test_llm_1_to_1(vllm_runner, hf_runner, model_name, dtype: str): assert len(vllm_outputs) == 1 assert len(hf_outputs) == 1 - assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01) + assert hf_outputs[0] == 
pytest.approx(vllm_outputs[0], rel=0.01) @pytest.mark.parametrize("dtype", ["half"]) @@ -78,8 +76,8 @@ def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str): assert len(vllm_outputs) == 10 assert len(hf_outputs) == 10 - assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01) - assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01) + assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01) + assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01) @pytest.fixture(scope="module", params=EMBEDDING_MODELS) diff --git a/tests/models/language/pooling/test_nomic.py b/tests/models/language/pooling/test_nomic.py new file mode 100644 index 000000000000..28df32e0c230 --- /dev/null +++ b/tests/models/language/pooling/test_nomic.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from ...utils import EmbedModelInfo, run_embedding_correctness_test + +MODELS = [ + EmbedModelInfo("nomic-ai/nomic-embed-text-v1", + architecture="NomicBertModel", + dtype="float32", + enable_test=True), + EmbedModelInfo("nomic-ai/nomic-embed-text-v1.5", + architecture="NomicBertModel", + dtype="float32", + enable_test=False), + EmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe", + architecture="NomicBertModel", + dtype="float32", + enable_test=True) +] + + +@pytest.mark.parametrize("model_info", MODELS) +def test_models_mteb(hf_runner, vllm_runner, + model_info: EmbedModelInfo) -> None: + from .mteb_utils import mteb_test_embed_models + mteb_test_embed_models(hf_runner, vllm_runner, model_info) + + +@pytest.mark.parametrize("model_info", MODELS) +def test_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo, + example_prompts) -> None: + if not model_info.enable_test: + pytest.skip("Skipping test.") + + # ST will strip the input texts, see test_embedding.py + example_prompts = [str(s).strip() for s in example_prompts] + + with vllm_runner(model_info.name, + task="embed", + dtype=model_info.dtype, + 
max_model_len=None) as vllm_model: + vllm_outputs = vllm_model.encode(example_prompts) + + with hf_runner( + model_info.name, + dtype=model_info.dtype, + is_sentence_transformer=True, + ) as hf_model: + run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs) diff --git a/tests/models/language/pooling/test_scoring.py b/tests/models/language/pooling/test_scoring.py index e9527700c3ca..6b10aeffc4b7 100644 --- a/tests/models/language/pooling/test_scoring.py +++ b/tests/models/language/pooling/test_scoring.py @@ -1,6 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -import math - import pytest import torch import torch.nn.functional as F @@ -45,7 +43,7 @@ def test_cross_encoder_1_to_1(vllm_runner, hf_runner, model_name): assert len(vllm_outputs) == 1 assert len(hf_outputs) == 1 - assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01) + assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01) def test_cross_encoder_1_to_N(vllm_runner, hf_runner, model_name): @@ -64,8 +62,8 @@ def test_cross_encoder_1_to_N(vllm_runner, hf_runner, model_name): assert len(vllm_outputs) == 2 assert len(hf_outputs) == 2 - assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01) - assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01) + assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01) + assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01) def test_cross_encoder_N_to_N(vllm_runner, hf_runner, model_name): @@ -84,8 +82,8 @@ def test_cross_encoder_N_to_N(vllm_runner, hf_runner, model_name): assert len(vllm_outputs) == 2 assert len(hf_outputs) == 2 - assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01) - assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01) + assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01) + assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01) @pytest.fixture(scope="module", params=EMBEDDING_MODELS) @@ -112,7 +110,7 @@ def 
test_embedding_1_to_1(vllm_runner, hf_runner, emb_model_name): assert len(vllm_outputs) == 1 assert len(hf_outputs) == 1 - assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01) + assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01) def test_embedding_1_to_N(vllm_runner, hf_runner, emb_model_name): @@ -140,8 +138,8 @@ def test_embedding_1_to_N(vllm_runner, hf_runner, emb_model_name): assert len(vllm_outputs) == 2 assert len(hf_outputs) == 2 - assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01) - assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01) + assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01) + assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01) def test_embedding_N_to_N(vllm_runner, hf_runner, emb_model_name): @@ -169,5 +167,5 @@ def test_embedding_N_to_N(vllm_runner, hf_runner, emb_model_name): assert len(vllm_outputs) == 2 assert len(hf_outputs) == 2 - assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01) - assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01) + assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01) + assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01) diff --git a/tests/models/language/pooling/test_snowflake_arctic_embed.py b/tests/models/language/pooling/test_snowflake_arctic_embed.py index c050b35b76ba..5679e0e1ce00 100644 --- a/tests/models/language/pooling/test_snowflake_arctic_embed.py +++ b/tests/models/language/pooling/test_snowflake_arctic_embed.py @@ -1,12 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 -import pytest -from ...utils import EmbedModelInfo, check_embeddings_close +import pytest -EMBEDDING_PROMPTS = [ - 'what is snowflake?', 'Where can I get the best tacos?', 'The Data Cloud!', - 'Mexico City of Course!' 
-] +from ...utils import EmbedModelInfo, run_embedding_correctness_test MODELS = [ EmbedModelInfo("Snowflake/snowflake-arctic-embed-xs", @@ -45,51 +41,37 @@ @pytest.mark.parametrize("model_info", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -def test_models( +def test_models_mteb( hf_runner, vllm_runner, - example_prompts, model_info: EmbedModelInfo, - dtype: str, - monkeypatch, ) -> None: - if not model_info.enable_test: - # A model family has many models with the same architecture, - # and we don't need to test each one. - pytest.skip("Skipping test.") + from .mteb_utils import mteb_test_embed_models + mteb_test_embed_models(hf_runner, vllm_runner, model_info) - example_prompts = example_prompts + EMBEDDING_PROMPTS - vllm_extra_kwargs = { - "hf_overrides": { - "is_matryoshka": model_info.is_matryoshka - } - } +@pytest.mark.parametrize("model_info", MODELS) +def test_models_correctness( + hf_runner, + vllm_runner, + model_info: EmbedModelInfo, + example_prompts, +) -> None: + if not model_info.enable_test: + pytest.skip("Skipping test.") - with hf_runner(model_info.name, dtype=dtype, - is_sentence_transformer=True) as hf_model: - hf_outputs = hf_model.encode(example_prompts) + # ST will strip the input texts, see test_embedding.py + example_prompts = [str(s).strip() for s in example_prompts] with vllm_runner(model_info.name, task="embed", - dtype=dtype, - max_model_len=None, - **vllm_extra_kwargs) as vllm_model: - - assert (vllm_model.model.llm_engine.model_config.is_matryoshka == - model_info.is_matryoshka) - - if model_info.architecture: - assert (model_info.architecture - in vllm_model.model.llm_engine.model_config.architectures) - + dtype=model_info.dtype, + max_model_len=None) as vllm_model: vllm_outputs = vllm_model.encode(example_prompts) - check_embeddings_close( - embeddings_0_lst=hf_outputs, - embeddings_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - tol=1e-2, - ) + with hf_runner( + model_info.name, + dtype=model_info.dtype, + 
is_sentence_transformer=True, + ) as hf_model: + run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs) diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index 6e915a9f6005..e4e48f9951cf 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -8,14 +8,14 @@ from pathlib import PosixPath import pytest -from transformers import (AutoModelForImageTextToText, +from transformers import (AutoModel, AutoModelForImageTextToText, AutoModelForTextToWaveform, AutoModelForVision2Seq) from vllm.platforms import current_platform from vllm.utils import identity -from ....conftest import (IMAGE_ASSETS, HfRunner, ImageTestAssets, - VideoTestAssets, VllmRunner) +from ....conftest import (IMAGE_ASSETS, AudioTestAssets, HfRunner, + ImageTestAssets, VideoTestAssets, VllmRunner) from ....utils import (create_new_process_for_each_test, large_gpu_mark, multi_gpu_marks) from ...utils import check_outputs_equal @@ -158,6 +158,17 @@ image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), + "ultravox": VLMTestInfo( + models = ["fixie-ai/ultravox-v0_5-llama-3_2-1b"], + test_type=VLMTestType.AUDIO, + prompt_formatter=lambda audio_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{audio_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 + audio_idx_to_prompt=lambda idx: "<|audio|>", + max_model_len=4096, + max_num_seqs=2, + auto_cls=AutoModel, + hf_output_post_proc=model_utils.ultravox_trunc_hf_output, + marks=[pytest.mark.core_model, pytest.mark.cpu_model], + ), #### Extended model tests "aria": VLMTestInfo( models=["rhymes-ai/Aria"], @@ -338,6 +349,17 @@ use_tokenizer_eos=True, patch_hf_runner=model_utils.internvl_patch_hf_runner, ), + "intern_vl-video": VLMTestInfo( + models=[ + "OpenGVLab/InternVL3-1B", + ], + 
test_type=VLMTestType.VIDEO, + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 + video_idx_to_prompt=lambda idx: "<video>", + max_model_len=8192, + use_tokenizer_eos=True, + patch_hf_runner=model_utils.internvl_patch_hf_runner, + ), "kimi_vl": VLMTestInfo( models=["moonshotai/Kimi-VL-A3B-Instruct"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), @@ -393,7 +415,6 @@ formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 ), limit_mm_per_prompt={"video": 4}, - runner_mm_key="videos", )], ), "llava_next_video": VLMTestInfo( @@ -476,6 +497,31 @@ max_num_seqs=2, patch_hf_runner=model_utils.molmo_patch_hf_runner, ), + "ovis1_6-gemma2": VLMTestInfo( + models=["AIDC-AI/Ovis1.6-Gemma2-9B"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<bos><start_of_turn>user\n{img_prompt}<end_of_turn>\n<start_of_turn>model\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<image>\n", # noqa: E501 + max_model_len=4096, + max_num_seqs=2, + dtype="half", + # use sdpa mode for hf runner since ovis2 didn't work with flash_attn + hf_model_kwargs={"llm_attn_implementation": "sdpa"}, + patch_hf_runner=model_utils.ovis_patch_hf_runner, + marks=[large_gpu_mark(min_gb=32)], + ), + "ovis1_6": VLMTestInfo( + models=["AIDC-AI/Ovis1.6-Llama3.2-3B"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful and honest multimodal assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<image>\n", # noqa: E501 + max_model_len=4096, + max_num_seqs=2, + dtype="half", + # use sdpa mode for hf runner since ovis2 didn't work with flash_attn + 
hf_model_kwargs={"llm_attn_implementation": "sdpa"}, + patch_hf_runner=model_utils.ovis_patch_hf_runner, + ), "ovis2": VLMTestInfo( models=["AIDC-AI/Ovis2-1B"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), @@ -486,7 +532,7 @@ dtype="half", # use sdpa mode for hf runner since ovis2 didn't work with flash_attn hf_model_kwargs={"llm_attn_implementation": "sdpa"}, - patch_hf_runner=model_utils.ovis2_patch_hf_runner, + patch_hf_runner=model_utils.ovis_patch_hf_runner, ), "phi3v": VLMTestInfo( models=["microsoft/Phi-3.5-vision-instruct"], @@ -681,6 +727,7 @@ def _mark_splits( # - multi-image # - image embeddings # - video +# - audio # - custom inputs @pytest.mark.parametrize( "model_type,test_case", @@ -778,6 +825,28 @@ def test_video_models(model_type: str, test_case: ExpandableVLMTestArgs, ) +@pytest.mark.parametrize( + "model_type,test_case", + get_parametrized_options( + VLM_TEST_SETTINGS, + test_type=VLMTestType.AUDIO, + create_new_process_for_each_test=False, + )) +def test_audio_models(model_type: str, test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], + audio_assets: AudioTestAssets, monkeypatch): + if model_type in REQUIRES_V0_MODELS: + monkeypatch.setenv("VLLM_USE_V1", "0") + model_test_info = VLM_TEST_SETTINGS[model_type] + runners.run_audio_test( + model_test_info=model_test_info, + test_case=test_case, + hf_runner=hf_runner, + vllm_runner=vllm_runner, + audio_assets=audio_assets, + ) + + @pytest.mark.parametrize( "model_type,test_case", get_parametrized_options( @@ -905,6 +974,29 @@ def test_video_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs, ) +@pytest.mark.parametrize( + "model_type,test_case", + get_parametrized_options( + VLM_TEST_SETTINGS, + test_type=VLMTestType.AUDIO, + create_new_process_for_each_test=True, + )) +def test_audio_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + audio_assets: 
AudioTestAssets, monkeypatch): + if model_type in REQUIRES_V0_MODELS: + monkeypatch.setenv("VLLM_USE_V1", "0") + model_test_info = VLM_TEST_SETTINGS[model_type] + runners.run_audio_test( + model_test_info=model_test_info, + test_case=test_case, + hf_runner=hf_runner, + vllm_runner=vllm_runner, + audio_assets=audio_assets, + ) + + @pytest.mark.parametrize( "model_type,test_case", get_parametrized_options( diff --git a/tests/models/multimodal/generation/test_interleaved.py b/tests/models/multimodal/generation/test_interleaved.py index eec84751e450..972db40e8bd6 100644 --- a/tests/models/multimodal/generation/test_interleaved.py +++ b/tests/models/multimodal/generation/test_interleaved.py @@ -4,6 +4,7 @@ from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset +from vllm.multimodal.image import convert_image_mode models = ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"] @@ -26,8 +27,9 @@ def test_models(vllm_runner, model, dtype: str, max_tokens: int) -> None: give the same result. 
""" - image_cherry = ImageAsset("cherry_blossom").pil_image.convert("RGB") - image_stop = ImageAsset("stop_sign").pil_image.convert("RGB") + image_cherry = convert_image_mode( + ImageAsset("cherry_blossom").pil_image, "RGB") + image_stop = convert_image_mode(ImageAsset("stop_sign").pil_image, "RGB") images = [image_cherry, image_stop] video = VideoAsset(name="baby_reading", num_frames=16).np_ndarrays diff --git a/tests/models/multimodal/generation/test_phi4mm.py b/tests/models/multimodal/generation/test_phi4mm.py index 11460a1a8d2b..e51dbee479c5 100644 --- a/tests/models/multimodal/generation/test_phi4mm.py +++ b/tests/models/multimodal/generation/test_phi4mm.py @@ -1,18 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 import os -import re from collections.abc import Sequence from typing import Optional import librosa import pytest +import regex as re from huggingface_hub import snapshot_download from transformers import AutoTokenizer from vllm.assets.image import ImageAsset from vllm.lora.request import LoRARequest -from vllm.multimodal.image import rescale_image_size +from vllm.multimodal.image import convert_image_mode, rescale_image_size from vllm.platforms import current_platform from vllm.sequence import SampleLogprobs @@ -267,7 +267,7 @@ def test_vision_speech_models(hf_runner, vllm_runner, model, dtype: str, # use the example speech question so that the model outputs are reasonable audio = librosa.load(speech_question, sr=None) - image = ImageAsset("cherry_blossom").pil_image.convert("RGB") + image = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB") inputs_vision_speech = [ ( diff --git a/tests/models/multimodal/generation/test_ultravox.py b/tests/models/multimodal/generation/test_ultravox.py index 322d886a593d..2c8a06688ca0 100644 --- a/tests/models/multimodal/generation/test_ultravox.py +++ b/tests/models/multimodal/generation/test_ultravox.py @@ -1,20 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 import json -from typing import Any, 
Optional +from typing import Any import numpy as np import pytest import pytest_asyncio -from transformers import AutoModel, AutoTokenizer +from transformers import AutoTokenizer -from vllm.multimodal.audio import resample_audio_librosa -from vllm.sequence import SampleLogprobs - -from ....conftest import AUDIO_ASSETS, AudioTestAssets, HfRunner, VllmRunner +from ....conftest import AUDIO_ASSETS, AudioTestAssets, VllmRunner from ....utils import RemoteOpenAIServer from ...registry import HF_EXAMPLE_MODELS -from ...utils import check_logprobs_close MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b" @@ -88,79 +84,6 @@ def _get_prompt(audio_count, question, placeholder): add_generation_prompt=True) -def vllm_to_hf_output(vllm_output: tuple[list[int], str, - Optional[SampleLogprobs]], - model: str): - """Sanitize vllm output to be comparable with hf output.""" - output_ids, output_str, out_logprobs = vllm_output - - tokenizer = AutoTokenizer.from_pretrained(model) - eos_token_id = tokenizer.eos_token_id - - hf_output_ids = output_ids[:] - hf_output_str = output_str - if hf_output_ids[-1] == eos_token_id: - hf_output_str = hf_output_str + tokenizer.decode(eos_token_id) - - return hf_output_ids, hf_output_str, out_logprobs - - -def run_test( - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - prompts_and_audios: list[tuple[str, str, AudioTuple]], - model: str, - *, - dtype: str, - max_tokens: int, - num_logprobs: int, - **kwargs, -): - """Inference result should be the same between hf and vllm.""" - model_info = HF_EXAMPLE_MODELS.find_hf_info(model) - model_info.check_available_online(on_fail="skip") - model_info.check_transformers_version(on_fail="skip") - - # NOTE: take care of the order. run vLLM first, and then run HF. - # vLLM needs a fresh new process without cuda initialization. - # if we run HF first, the cuda initialization will be done and it - # will hurt multiprocessing backend with fork method (the default method). 
- - with vllm_runner(model, dtype=dtype, enforce_eager=True, - **kwargs) as vllm_model: - vllm_outputs_per_audio = [ - vllm_model.generate_greedy_logprobs([vllm_prompt], - max_tokens, - num_logprobs=num_logprobs, - audios=[audio]) - for vllm_prompt, _, audio in prompts_and_audios - ] - - with hf_runner(model, dtype=dtype, auto_cls=AutoModel) as hf_model: - hf_outputs_per_audio = [ - hf_model.generate_greedy_logprobs_limit( - [hf_prompt], - max_tokens, - num_logprobs=num_logprobs, - audios=[(resample_audio_librosa(audio[0], - orig_sr=audio[1], - target_sr=16000), 16000)]) - for _, hf_prompt, audio in prompts_and_audios - ] - - for hf_outputs, vllm_outputs in zip(hf_outputs_per_audio, - vllm_outputs_per_audio): - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=[ - vllm_to_hf_output(vllm_output, model) - for vllm_output in vllm_outputs - ], - name_0="hf", - name_1="vllm", - ) - - def run_multi_audio_test( vllm_runner: type[VllmRunner], prompts_and_audios: list[tuple[str, list[AudioTuple]]], @@ -194,35 +117,6 @@ def run_multi_audio_test( assert all(tokens for tokens, *_ in vllm_outputs) -@pytest.mark.core_model -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("vllm_kwargs", [ - pytest.param({}, marks=pytest.mark.cpu_model), - pytest.param(CHUNKED_PREFILL_KWARGS), -]) -def test_models(hf_runner, vllm_runner, audio_assets: AudioTestAssets, - dtype: str, max_tokens: int, num_logprobs: int, - vllm_kwargs: dict) -> None: - audio_inputs = [( - _get_prompt(1, audio, VLLM_PLACEHOLDER), - _get_prompt(1, audio, HF_PLACEHOLDER), - audio.audio_and_sample_rate, - ) for audio in audio_assets] - - run_test( - hf_runner, - vllm_runner, - audio_inputs, - MODEL_NAME, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - **vllm_kwargs, - ) - - @pytest.mark.core_model @pytest.mark.parametrize("dtype", ["half"]) 
@pytest.mark.parametrize("max_tokens", [128]) diff --git a/tests/models/multimodal/generation/vlm_utils/builders.py b/tests/models/multimodal/generation/vlm_utils/builders.py index e3ba955a96a6..32117c8d8dca 100644 --- a/tests/models/multimodal/generation/vlm_utils/builders.py +++ b/tests/models/multimodal/generation/vlm_utils/builders.py @@ -7,18 +7,21 @@ import torch +from vllm.multimodal.audio import AudioResampler from vllm.multimodal.image import rescale_image_size from vllm.multimodal.video import (rescale_video_size, resize_video, sample_frames_from_video) -from .....conftest import ImageTestAssets, VideoTestAssets -from .types import (SINGLE_IMAGE_BASE_PROMPTS, TEST_IMG_PLACEHOLDER, +from .....conftest import AudioTestAssets, ImageTestAssets, VideoTestAssets +from .types import (SINGLE_AUDIO_BASE_PROMPT, SINGLE_IMAGE_BASE_PROMPTS, + TEST_AUDIO_PLACEHOLDER, TEST_IMG_PLACEHOLDER, TEST_VIDEO_PLACEHOLDER, VIDEO_BASE_PROMPT, - ImageSizeWrapper, SizeType, VLMTestInfo) + ImageSizeWrapper, PromptWithMultiModalInput, SizeType, + VLMTestInfo) -def replace_test_placeholder(prompt: str, img_idx_to_prompt: Callable[[int], - str], +def replace_test_placeholder(prompt: str, mm_idx_to_prompt: Callable[[int], + str], test_placeholder: str) -> str: """Given a prompt, replaces each test placeholder with the model-specific tag. 
@@ -26,7 +29,7 @@ def replace_test_placeholder(prompt: str, img_idx_to_prompt: Callable[[int], prompt_segments = prompt.split(test_placeholder) img_prompt = prompt_segments[0] for placeholder_idx, next_seg in enumerate(prompt_segments[1:], start=1): - img_prompt += img_idx_to_prompt(placeholder_idx) + img_prompt += mm_idx_to_prompt(placeholder_idx) img_prompt += next_seg return img_prompt @@ -34,6 +37,7 @@ def replace_test_placeholder(prompt: str, img_idx_to_prompt: Callable[[int], def get_model_prompts(base_prompts: Iterable[str], img_idx_to_prompt: Optional[Callable[[int], str]], video_idx_to_prompt: Optional[Callable[[int], str]], + audio_idx_to_prompt: Optional[Callable[[int], str]], prompt_formatter: Callable[[str], str]) -> list[str]: """Given a model-agnostic base prompt and test configuration for a model(s) to be tested, update the media placeholders and apply the prompt formatting @@ -60,6 +64,11 @@ def get_model_prompts(base_prompts: Iterable[str], video_idx_to_prompt, TEST_VIDEO_PLACEHOLDER) + if audio_idx_to_prompt: + base_prompt = replace_test_placeholder(base_prompt, + audio_idx_to_prompt, + TEST_AUDIO_PLACEHOLDER) + # Apply the prompt formatter to wrap the base prompt with # the correct media placeholders to get the model test prompt model_prompt = prompt_formatter(base_prompt) @@ -68,10 +77,11 @@ def get_model_prompts(base_prompts: Iterable[str], def build_single_image_inputs_from_test_info( - test_info: VLMTestInfo, - image_assets: ImageTestAssets, - size_wrapper: ImageSizeWrapper, - tmp_path: Optional[PosixPath] = None): + test_info: VLMTestInfo, + image_assets: ImageTestAssets, + size_wrapper: ImageSizeWrapper, + tmp_path: Optional[PosixPath] = None, +) -> list[PromptWithMultiModalInput]: if test_info.prompt_formatter is None: raise ValueError( "Prompt formatter must be set to build single image inputs") @@ -79,6 +89,7 @@ def build_single_image_inputs_from_test_info( model_prompts = get_model_prompts(test_info.single_image_prompts, 
test_info.img_idx_to_prompt, test_info.video_idx_to_prompt, + test_info.audio_idx_to_prompt, test_info.prompt_formatter) # For models that require a local path / URL encoded in the image; export @@ -97,28 +108,32 @@ def build_single_image_inputs_from_test_info( return build_single_image_inputs(images, model_prompts, size_wrapper) -def build_single_image_inputs(images, model_prompts, - size_wrapper: ImageSizeWrapper): +def build_single_image_inputs( + images, model_prompts, + size_wrapper: ImageSizeWrapper) -> list[PromptWithMultiModalInput]: # For every image / prompt pair, get a pair containing two lists of # length size_factors, where the first contains duplicates of the model # prompt [str], and the second contains copies of the image after being # scaled by one of the size factors. # # NOTE: rescaling preserves the image aspect ratio. - return [( - [prompt for _ in size_wrapper.data], - [ - apply_image_size_scaling(image, size, size_wrapper.type) - for size in size_wrapper.data - ], - ) for image, prompt in zip(images, model_prompts)] + return [ + PromptWithMultiModalInput( + prompts=[prompt for _ in size_wrapper.data], + image_data=[ + apply_image_size_scaling(image, size, size_wrapper.type) + for size in size_wrapper.data + ], + ) for image, prompt in zip(images, model_prompts) + ] def build_multi_image_inputs_from_test_info( - test_info: VLMTestInfo, - image_assets: ImageTestAssets, - size_wrapper: ImageSizeWrapper, - tmp_path: Optional[PosixPath] = None): + test_info: VLMTestInfo, + image_assets: ImageTestAssets, + size_wrapper: ImageSizeWrapper, + tmp_path: Optional[PosixPath] = None, +) -> list[PromptWithMultiModalInput]: if test_info.prompt_formatter is None: raise ValueError( "Prompt formatter must be set to build multi image inputs") @@ -126,6 +141,7 @@ def build_multi_image_inputs_from_test_info( model_prompts = get_model_prompts([test_info.multi_image_prompt], test_info.img_idx_to_prompt, test_info.video_idx_to_prompt, + 
test_info.audio_idx_to_prompt, test_info.prompt_formatter) if test_info.prompt_path_encoder is not None: @@ -146,15 +162,18 @@ def build_multi_image_inputs_from_test_info( ) -def build_multi_image_inputs(image_lists, model_prompts, - size_wrapper: ImageSizeWrapper): - return [( - [prompt for _ in size_wrapper.data], - [[ - apply_image_size_scaling(image, size, size_wrapper.type) - for image in images - ] for size in size_wrapper.data], - ) for images, prompt in zip(image_lists, model_prompts)] +def build_multi_image_inputs( + image_lists, model_prompts, + size_wrapper: ImageSizeWrapper) -> list[PromptWithMultiModalInput]: + return [ + PromptWithMultiModalInput( + prompts=[prompt for _ in size_wrapper.data], + image_data=[[ + apply_image_size_scaling(image, size, size_wrapper.type) + for image in images + ] for size in size_wrapper.data], + ) for images, prompt in zip(image_lists, model_prompts) + ] def build_embedding_inputs_from_test_info( @@ -177,6 +196,7 @@ def build_embedding_inputs_from_test_info( SINGLE_IMAGE_BASE_PROMPTS, test_info.img_idx_to_prompt, test_info.video_idx_to_prompt, + test_info.audio_idx_to_prompt, test_info.prompt_formatter, ) @@ -195,13 +215,14 @@ def build_video_inputs_from_test_info( video_assets: VideoTestAssets, size_wrapper: ImageSizeWrapper, num_frames: int, -): +) -> list[PromptWithMultiModalInput]: if test_info.prompt_formatter is None: raise ValueError("Prompt formatter must be set to build video inputs") model_prompts = get_model_prompts( [VIDEO_BASE_PROMPT], test_info.img_idx_to_prompt, test_info.video_idx_to_prompt, + test_info.audio_idx_to_prompt, test_info.prompt_formatter, ) @@ -213,10 +234,14 @@ def build_video_inputs_from_test_info( video_scaler = (resize_video if size_wrapper.type == SizeType.FIXED_SIZE else rescale_video_size) - return [( - [prompt for _ in size_wrapper.data], - [video_scaler(video, size) for size in size_wrapper.data], - ) for video, prompt in zip(sampled_vids, model_prompts)] + return [ + 
PromptWithMultiModalInput( + prompts=[prompt for _ in size_wrapper.data], + video_data=[ + video_scaler(video, size) for size in size_wrapper.data + ], + ) for video, prompt in zip(sampled_vids, model_prompts) + ] def apply_image_size_scaling(image, size: Union[float, tuple[int, int]], @@ -236,3 +261,37 @@ def apply_image_size_scaling(image, size: Union[float, tuple[int, int]], # We have a list of fixed sizes return image.resize(size) raise ValueError("ImageSizeWrapper type must be FIXED_SIZE or SIZE_FACTOR") + + +def build_audio_inputs_from_test_info( + test_info: VLMTestInfo, + audio_assets: AudioTestAssets, +) -> list[PromptWithMultiModalInput]: + if test_info.prompt_formatter is None: + raise ValueError("Prompt formatter must be set to build audio inputs") + model_prompts = get_model_prompts( + SINGLE_AUDIO_BASE_PROMPT, + test_info.img_idx_to_prompt, + test_info.video_idx_to_prompt, + test_info.audio_idx_to_prompt, + test_info.prompt_formatter, + ) + resampler = AudioResampler( + target_sr=16000, + method="librosa", + ) + audios = [asset.audio_and_sample_rate for asset in audio_assets] + resampled_audios = [( + resampler.resample( + audio, + orig_sr=sr, + ), + int(resampler.target_sr), + ) for audio, sr in audios] + + return [ + PromptWithMultiModalInput( + prompts=model_prompts, + audio_data=resampled_audios, + ) + ] diff --git a/tests/models/multimodal/generation/vlm_utils/case_filtering.py b/tests/models/multimodal/generation/vlm_utils/case_filtering.py index 8e825676b8f4..a5077a090b52 100644 --- a/tests/models/multimodal/generation/vlm_utils/case_filtering.py +++ b/tests/models/multimodal/generation/vlm_utils/case_filtering.py @@ -83,7 +83,7 @@ def get_model_type_cases(model_type: str, test_info: VLMTestInfo): test_info.num_video_frames) # No sizes passed for custom inputs, since inputs are directly provided - if test_type != VLMTestType.CUSTOM_INPUTS: + if test_type not in (VLMTestType.CUSTOM_INPUTS, VLMTestType.AUDIO): wrapped_sizes = 
get_wrapped_test_sizes(test_info, test_type) if wrapped_sizes is None: raise ValueError( @@ -91,7 +91,7 @@ def get_model_type_cases(model_type: str, test_info: VLMTestInfo): iter_kwargs["size_wrapper"] = wrapped_sizes #Otherwise expand the custom test options instead - else: + elif test_type == VLMTestType.CUSTOM_INPUTS: if test_info.custom_test_opts is None: raise ValueError("Test has type CUSTOM_INPUTS, but none given") iter_kwargs["custom_test_opts"] = test_info.custom_test_opts @@ -136,8 +136,8 @@ def get_wrapped_test_sizes( ImageSizeWrapper(type=SizeType.SIZE_FACTOR, data=factor) for factor in EMBEDDING_SIZE_FACTORS ]) - # Custom inputs have preprocessed inputs - elif test_type == VLMTestType.CUSTOM_INPUTS: + # Audio and Custom inputs have preprocessed inputs + elif test_type in (VLMTestType.AUDIO, VLMTestType.CUSTOM_INPUTS): return tuple() size_factors = test_info.image_size_factors \ diff --git a/tests/models/multimodal/generation/vlm_utils/core.py b/tests/models/multimodal/generation/vlm_utils/core.py index c3d20f56855f..ccd2799abd90 100644 --- a/tests/models/multimodal/generation/vlm_utils/core.py +++ b/tests/models/multimodal/generation/vlm_utils/core.py @@ -1,9 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 """Core test implementation to be shared across modalities.""" -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional import torch -from PIL.Image import Image from transformers.models.auto.auto_factory import _BaseAutoModelClass from vllm.config import TaskOption @@ -11,14 +10,14 @@ from .....conftest import HfRunner, VllmRunner from ....registry import HF_EXAMPLE_MODELS -from .types import RunnerOutput +from .types import PromptWithMultiModalInput, RunnerOutput def run_test( *, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - inputs: list[tuple[list[str], list[Union[list[Image], Image]]]], + inputs: list[PromptWithMultiModalInput], model: str, dtype: str, max_tokens: int, @@ -38,7 +37,6 @@ def 
run_test( hf_model_kwargs: Optional[dict[str, Any]], patch_hf_runner: Optional[Callable[[HfRunner], HfRunner]], task: TaskOption = "auto", - runner_mm_key: str = "images", distributed_executor_backend: Optional[str] = None, tensor_parallel_size: int = 1, vllm_embeddings: Optional[torch.Tensor] = None, @@ -94,10 +92,16 @@ def run_test( if stop_str: vllm_kwargs["stop"] = stop_str - for prompts, media in vllm_inputs: - vllm_kwargs[runner_mm_key] = media + for prompts, image_data, video_data, audio_data in vllm_inputs: + mm_data = dict(images=image_data, + videos=video_data, + audios=audio_data) + vllm_kwargs_with_mm_data = vllm_kwargs | mm_data vllm_output = vllm_model.generate_greedy_logprobs( - prompts, max_tokens, num_logprobs=num_logprobs, **vllm_kwargs) + prompts, + max_tokens, + num_logprobs=num_logprobs, + **vllm_kwargs_with_mm_data) vllm_outputs_per_mm.append(vllm_output) hf_model = hf_runner(model, @@ -122,14 +126,17 @@ def run_test( if stop_str: hf_kwargs["stop_strings"] = stop_str - for prompts, media in inputs: - hf_kwargs[runner_mm_key] = media + for prompts, image_data, video_data, audio_data in inputs: + mm_data = dict(images=image_data, + videos=video_data, + audios=audio_data) + hf_kwargs_with_mm_data = hf_kwargs | mm_data hf_output = hf_model.generate_greedy_logprobs_limit( prompts, max_tokens, num_logprobs=num_logprobs, tokenizer=tokenizer, - **hf_kwargs) + **hf_kwargs_with_mm_data) hf_outputs_per_mm.append(hf_output) # Apply output processing / sanitation to the vLLM and HF runner results diff --git a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py index 235618ae547e..cc1045561138 100644 --- a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py +++ b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py @@ -12,7 +12,7 @@ from .....conftest import IMAGE_ASSETS, VIDEO_ASSETS from .builders import build_multi_image_inputs, build_single_image_inputs -from 
.types import ImageSizeWrapper, SizeType +from .types import ImageSizeWrapper, PromptWithMultiModalInput, SizeType def multi_image_multi_aspect_ratio_inputs(formatter: Callable[[str], str]): @@ -32,24 +32,28 @@ def multi_image_multi_aspect_ratio_inputs(formatter: Callable[[str], str]): "<image>\nWhat is the season?", ] formatted_prompts = [formatter(prompt) for prompt in img_prompts] - - return [( - formatted_prompts, + aspect_ratio_images = [ + [stop_sign, cherry_blossom], + # Images with different sizes and aspect-ratios + [ + rescale_image_size(stop_sign, 0.1), + stop_sign, + ], [ - [stop_sign, cherry_blossom], - # Images with different sizes and aspect-ratios - [ - rescale_image_size(stop_sign, 0.1), - stop_sign, - ], - [ - stop_sign, - rescale_image_size(stop_sign, 0.25), - cherry_blossom.resize((183, 488)), - cherry_blossom.resize((488, 183)) - ], - cherry_blossom, - ])] + stop_sign, + rescale_image_size(stop_sign, 0.25), + cherry_blossom.resize((183, 488)), + cherry_blossom.resize((488, 183)) + ], + cherry_blossom, + ] + + return [ + PromptWithMultiModalInput( + prompts=formatted_prompts, + image_data=aspect_ratio_images, + ) + ] def multi_video_multi_aspect_ratio_inputs(formatter: Callable[[str], str], @@ -68,24 +72,28 @@ def multi_video_multi_aspect_ratio_inputs(formatter: Callable[[str], str], "<video>\nWhy is this video funny?", ] formatted_prompts = [formatter(prompt) for prompt in video_prompts] - - return [( - formatted_prompts, + aspect_ratio_videos = [ + [video, video], + # Videos with different sizes and aspect-ratios [ - [video, video], - # Videos with different sizes and aspect-ratios - [ - rescale_video_size(video, 0.1), - video, - ], - [ - video, - rescale_video_size(video, 0.25), - resize_video(video, (183, 488)), - resize_video(video, (488, 183)) - ], + rescale_video_size(video, 0.1), video, - ])] + ], + [ + video, + rescale_video_size(video, 0.25), + resize_video(video, (183, 488)), + resize_video(video, (488, 183)) + ], + video, + ] + + 
return [ + PromptWithMultiModalInput( + prompts=formatted_prompts, + video_data=aspect_ratio_videos, + ) + ] def different_patch_input_cases_internvl(): diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py index f0f4ed989241..dc1ea5208240 100644 --- a/tests/models/multimodal/generation/vlm_utils/model_utils.py +++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py @@ -3,11 +3,13 @@ for manipulating the input / output of HF & vLLM test runners, which are typically specific to a small subset of models. """ -import re import types from pathlib import PosixPath from typing import Optional, Union +import numpy as np +import numpy.typing as npt +import regex as re import torch from PIL.Image import Image from transformers import (AutoConfig, AutoTokenizer, BatchFeature, @@ -237,6 +239,18 @@ def minimax_vl_01_hf_output(hf_output: RunnerOutput, return output_ids, output_str, out_logprobs +def ultravox_trunc_hf_output(hf_output: RunnerOutput, + model: str) -> RunnerOutput: + output_ids, output_str, out_logprobs = hf_output + + tokenizer = AutoTokenizer.from_pretrained(model) + eos_token_id = tokenizer.eos_token_id + eos_token = tokenizer.decode(eos_token_id) + if output_str.endswith(eos_token): + output_str = output_str.split(eos_token)[0] + return output_ids, output_str, out_logprobs + + ####### Functions for converting image assets to embeddings def get_llava_embeddings(image_assets: ImageTestAssets): return [asset.image_embeds for asset in image_assets] @@ -483,30 +497,74 @@ def __init__(self, hf_runner: HfRunner): self.max_num = self.config.max_dynamic_patch self.image_size = self.vision_config.image_size - def __call__(self, text: str, images: Union[Image, list[Image]], - **kwargs): + def __call__( + self, + text: str, + images: Union[Image, list[Image]] = None, + videos: Union[npt.NDArray, list[npt.NDArray]] = None, + **kwargs, + ): from vllm.model_executor.models.internvl 
import ( IMG_CONTEXT, IMG_END, IMG_START, - image_to_pixel_values_internvl) + image_to_pixel_values_internvl, video_to_pixel_values_internvl) images = [images] if isinstance(images, Image) else images - pixel_values = [ - image_to_pixel_values_internvl( - image, - input_size=self.image_size, - min_num=self.min_num, - max_num=self.max_num, - use_thumbnail=self.use_thumbnail, - ) for image in images - ] - num_patches_list = [ - pixel_value.shape[0] for pixel_value in pixel_values - ] + videos = [videos] if isinstance(videos, np.ndarray) else videos + if images is not None: + pixel_values_images = [ + image_to_pixel_values_internvl( + image, + input_size=self.image_size, + min_num=self.min_num, + max_num=self.max_num, + use_thumbnail=self.use_thumbnail, + ) for image in images + ] + num_patches_images = [ + pixel_value.shape[0] for pixel_value in pixel_values_images + ] + else: + pixel_values_images, num_patches_images = [], [] + + if videos is not None: + pixel_values_videos = [ + video_to_pixel_values_internvl( + video, + input_size=self.image_size, + min_num=1, + max_num=1, + use_thumbnail=False, + ) for video in videos + ] + num_patches_videos = [ + pixel_value.shape[0] for pixel_value in pixel_values_videos + ] + else: + pixel_values_videos, num_patches_videos = [], [] + + pixel_values = [] + while ("<image>" in text) or ("<video>" in text): + image_index = text.find("<image>") + video_index = text.find("<video>") + if image_index == -1 or (video_index > -1 + and video_index < image_index): + num_patches = num_patches_videos.pop(0) + pixel_values.append(pixel_values_videos.pop(0)) + context_tokens = IMG_START + \ + IMG_CONTEXT * self.num_image_token + IMG_END + video_tokens = ''.join([ + f'Frame{i+1}: {context_tokens}' + for i in range(num_patches) + ]) + text = text.replace('<video>', video_tokens, 1) + else: + num_patches = num_patches_images.pop(0) + pixel_values.append(pixel_values_images.pop(0)) + context_tokens = IMG_CONTEXT * self.num_image_token \ + * 
num_patches + image_tokens = IMG_START + context_tokens + IMG_END + text = text.replace('<image>', image_tokens, 1) pixel_values = torch.cat(pixel_values, dim=0) - for num_patches in num_patches_list: - context_tokens = IMG_CONTEXT * self.num_image_token \ - * num_patches - image_tokens = IMG_START + context_tokens + IMG_END - text = text.replace('<image>', image_tokens, 1) + prompt = self.tokenizer(text, return_tensors="pt") prompt.update({"pixel_values": pixel_values}) return prompt @@ -678,12 +736,8 @@ def _generate(self, max_new_tokens=None, do_sample=None, **kwargs): return hf_model -def ovis2_patch_hf_runner(hf_model: HfRunner) -> HfRunner: +def ovis_patch_hf_runner(hf_model: HfRunner) -> HfRunner: """Patches and returns an instance of the HfRunner to use for Ovis2.""" - hf_model.model.visual_tokenizer.to(hf_model.dtype) - hf_model.model.vte.to(hf_model.dtype) - hf_model.model.llm.to(hf_model.dtype) - hf_model.model.get_output_embeddings = lambda: \ hf_model.model.llm.get_output_embeddings() @@ -691,7 +745,16 @@ def processor(*args, text="", images=None, **kwargs): text_tokenizer = hf_model.model.get_text_tokenizer() images = [images] if isinstance(images, Image) else images - text = text.split("<|im_start|>user\n")[1].split("<|im_end|>\n")[0] + prompt_start_and_end = { + "qwen2": ("<|im_start|>user\n", "<|im_end|>\n"), + "llama": + ("<|start_header_id|>user<|end_header_id|>\n\n", "<|eot_id|>"), + "gemma2": ("<start_of_turn>user\n", "<end_of_turn>\n"), + } + for start, end in prompt_start_and_end.values(): + if start in text and end in text: + text = text.split(start)[1].split(end)[0] + break prompt, input_ids, pixel_values = hf_model.model.preprocess_inputs( text_or_conversations=text, images=images) diff --git a/tests/models/multimodal/generation/vlm_utils/runners.py b/tests/models/multimodal/generation/vlm_utils/runners.py index 34753121ea90..9e8a1262e8c1 100644 --- a/tests/models/multimodal/generation/vlm_utils/runners.py +++ 
b/tests/models/multimodal/generation/vlm_utils/runners.py @@ -4,8 +4,8 @@ """ from pathlib import PosixPath -from .....conftest import (HfRunner, ImageTestAssets, VideoTestAssets, - VllmRunner) +from .....conftest import (AudioTestAssets, HfRunner, ImageTestAssets, + VideoTestAssets, VllmRunner) from . import builders, core from .types import ExpandableVLMTestArgs, VLMTestInfo @@ -30,7 +30,6 @@ def run_single_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo, num_logprobs=test_case.num_logprobs, limit_mm_per_prompt={"image": 1}, distributed_executor_backend=test_case.distributed_executor_backend, - runner_mm_key="images", **model_test_info.get_non_parametrized_runner_kwargs()) @@ -53,7 +52,6 @@ def run_multi_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo, num_logprobs=test_case.num_logprobs, limit_mm_per_prompt={"image": len(image_assets)}, distributed_executor_backend=test_case.distributed_executor_backend, - runner_mm_key="images", **model_test_info.get_non_parametrized_runner_kwargs()) @@ -77,7 +75,6 @@ def run_embedding_test(*, model_test_info: VLMTestInfo, limit_mm_per_prompt={"image": 1}, vllm_embeddings=vllm_embeddings, distributed_executor_backend=test_case.distributed_executor_backend, - runner_mm_key="images", **model_test_info.get_non_parametrized_runner_kwargs()) @@ -105,7 +102,30 @@ def run_video_test( num_logprobs=test_case.num_logprobs, limit_mm_per_prompt={"video": len(video_assets)}, distributed_executor_backend=test_case.distributed_executor_backend, - runner_mm_key="videos", + **model_test_info.get_non_parametrized_runner_kwargs()) + + +def run_audio_test( + *, + model_test_info: VLMTestInfo, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + audio_assets: AudioTestAssets, +): + inputs = builders.build_audio_inputs_from_test_info( + model_test_info, audio_assets) + + core.run_test( + hf_runner=hf_runner, + vllm_runner=vllm_runner, + inputs=inputs, + 
model=test_case.model, + dtype=test_case.dtype, + max_tokens=test_case.max_tokens, + num_logprobs=test_case.num_logprobs, + limit_mm_per_prompt={"audio": 1}, + distributed_executor_backend=test_case.distributed_executor_backend, **model_test_info.get_non_parametrized_runner_kwargs()) @@ -120,11 +140,9 @@ def run_custom_inputs_test(*, model_test_info: VLMTestInfo, inputs = test_case.custom_test_opts.inputs limit_mm_per_prompt = test_case.custom_test_opts.limit_mm_per_prompt - runner_mm_key = test_case.custom_test_opts.runner_mm_key - # Inputs, limit_mm_per_prompt, and runner_mm_key should all be set + # Inputs and limit_mm_per_prompt should all be set assert inputs is not None assert limit_mm_per_prompt is not None - assert runner_mm_key is not None core.run_test( hf_runner=hf_runner, @@ -136,5 +154,4 @@ def run_custom_inputs_test(*, model_test_info: VLMTestInfo, num_logprobs=test_case.num_logprobs, limit_mm_per_prompt=limit_mm_per_prompt, distributed_executor_backend=test_case.distributed_executor_backend, - runner_mm_key=runner_mm_key, **model_test_info.get_non_parametrized_runner_kwargs()) diff --git a/tests/models/multimodal/generation/vlm_utils/types.py b/tests/models/multimodal/generation/vlm_utils/types.py index 56629323394d..1c2bb4d6222b 100644 --- a/tests/models/multimodal/generation/vlm_utils/types.py +++ b/tests/models/multimodal/generation/vlm_utils/types.py @@ -6,7 +6,6 @@ from typing import Any, Callable, NamedTuple, Optional, Union import torch -from PIL.Image import Image from pytest import MarkDecorator from transformers import AutoModelForCausalLM from transformers.models.auto.auto_factory import _BaseAutoModelClass @@ -15,18 +14,25 @@ from vllm.sequence import SampleLogprobs from vllm.transformers_utils.tokenizer import AnyTokenizer -from .....conftest import IMAGE_ASSETS, HfRunner, ImageAsset, ImageTestAssets +from .....conftest import (AUDIO_ASSETS, IMAGE_ASSETS, HfRunner, ImageAsset, + ImageTestAssets, PromptAudioInput, PromptImageInput, + 
PromptVideoInput) from ....utils import check_logprobs_close # meta image tag; will be replaced by the appropriate tag for the model TEST_IMG_PLACEHOLDER = "<vlm_image>" TEST_VIDEO_PLACEHOLDER = "<vlm_video>" +TEST_AUDIO_PLACEHOLDER = "<lmm_audio>" # yapf: disable SINGLE_IMAGE_BASE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": f"{TEST_IMG_PLACEHOLDER}What's the content of the image?", "cherry_blossom": f"{TEST_IMG_PLACEHOLDER}What is the season?", }) +SINGLE_AUDIO_BASE_PROMPT = AUDIO_ASSETS.prompts({ + "mary_had_lamb": f"{TEST_AUDIO_PLACEHOLDER}Transcribe this audio into English.", # noqa: E501 + "winning_call": f"{TEST_AUDIO_PLACEHOLDER}What is happening in this audio clip?", # noqa: E501 +}) MULTI_IMAGE_BASE_PROMPT = f"Image-1: {TEST_IMG_PLACEHOLDER}Image-2: {TEST_IMG_PLACEHOLDER}Describe the two images in detail.\n" # noqa: E501 VIDEO_BASE_PROMPT = f"{TEST_VIDEO_PLACEHOLDER}Why is this video funny?" @@ -38,12 +44,21 @@ # yapf: enable +class PromptWithMultiModalInput(NamedTuple): + """Holds the multimodal input for a single test case.""" + prompts: list[str] + image_data: Optional[PromptImageInput] = None + video_data: Optional[PromptVideoInput] = None + audio_data: Optional[PromptAudioInput] = None + + class VLMTestType(Enum): IMAGE = 1 MULTI_IMAGE = 2 EMBEDDING = 3 VIDEO = 4 - CUSTOM_INPUTS = 5 + AUDIO = 5 + CUSTOM_INPUTS = 6 class SizeType(Enum): @@ -52,10 +67,8 @@ class SizeType(Enum): class CustomTestOptions(NamedTuple): - inputs: list[tuple[list[str], list[Union[list[Image], Image]]]] + inputs: list[PromptWithMultiModalInput] limit_mm_per_prompt: dict[str, int] - # kwarg to pass multimodal data in as to vllm/hf runner instances. 
- runner_mm_key: str = "images" class ImageSizeWrapper(NamedTuple): @@ -75,6 +88,7 @@ class VLMTestInfo(NamedTuple): prompt_formatter: Optional[Callable[[str], str]] = None img_idx_to_prompt: Callable[[int], str] = lambda idx: "<image>\n" video_idx_to_prompt: Callable[[int], str] = lambda idx: "<video>\n" + audio_idx_to_prompt: Callable[[int], str] = lambda idx: "<audio>\n" # Most models work on the single / multi-image prompts above, but in some # cases the log prob check fails, e.g., for paligemma. We allow passing diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 772a2db3e48a..572fa366d332 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -9,15 +9,15 @@ UserMessage) from mistral_common.protocol.instruct.request import ChatCompletionRequest from PIL import Image -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from vllm.config import ModelConfig from vllm.inputs import InputProcessingContext from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict from vllm.multimodal.inputs import MultiModalInputs from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache -from vllm.transformers_utils.tokenizer import (MistralTokenizer, - cached_tokenizer_from_config) +from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, + cached_tokenizer_from_config, + encode_tokens) from ....multimodal.utils import random_audio, random_image, random_video from ...registry import HF_EXAMPLE_MODELS @@ -28,7 +28,6 @@ def _test_processing_correctness( hit_rate: float, num_batches: int, simplify_rate: float, - ignore_mm_keys: Optional[set[str]] = None, ): model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) model_info.check_available_online(on_fail="skip") @@ -99,10 +98,23 @@ def _test_processing_correctness( } mm_counts = {k: len(vs) for k, vs in mm_data.items()} - prompt = 
dummy_inputs.get_dummy_processor_inputs( - model_config.max_model_len, - mm_counts, - ).prompt_text + + # Mistral chat outputs tokens directly, rather than text prompts + if isinstance(tokenizer, MistralTokenizer): + images = mm_data.get("image", []) + request = ChatCompletionRequest(messages=[ + UserMessage(content=[ + TextChunk(text=""), + *(ImageChunk(image=image) for image in images), + ]), + ]) + res = tokenizer.mistral.encode_chat_completion(request) + prompt = res.tokens + else: + prompt = dummy_inputs.get_dummy_processor_inputs( + model_config.max_model_len, + mm_counts, + ).prompt # Drop unnecessary keys and test single -> multi conversion if rng.rand() < simplify_rate: @@ -112,66 +124,59 @@ def _test_processing_correctness( elif len(mm_data[k]) == 1: mm_data[k] = mm_data[k][0] - if isinstance(tokenizer, MistralTokenizer): - _test_processing_correctness_mistral( - model_config, - tokenizer, - prompt, - mm_data, - baseline_processor, - cached_processor, - batch_idx, - ignore_mm_keys=ignore_mm_keys, - ) - else: - _test_processing_correctness_hf( - model_config, - tokenizer, - prompt, - mm_data, - baseline_processor, - cached_processor, - batch_idx, - ignore_mm_keys=ignore_mm_keys, - ) - - -def _test_processing_correctness_hf( + _test_processing_correctness_one( + model_config, + tokenizer, + prompt, + mm_data, + baseline_processor, + cached_processor, + batch_idx, + ) + + +# For some multimodal models, tokenizer will always add bos_token +# at the beginning of prompt by default, causing hf_processor outputs +# incorrect token ids. So we need use `add_special_tokens=False` here +# to leave bos_token to be added by the processor. +_ADD_SPECIAL_TOKENS_OVERRIDES = { + "mllama": False, + "ovis": False, + "ultravox": False, + "whisper": False, +} + +_IGNORE_MM_KEYS = { + # In Ultravox, the audio_features can be different depending on padding + # The slight difference should not be a problem though, since + # attention_mask lets us ignore the difference. 
+ "ultravox": {"audio_features"}, +} + + +def _test_processing_correctness_one( model_config: ModelConfig, - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - prompt: str, + tokenizer: AnyTokenizer, + prompt: Union[str, list[int]], mm_data: MultiModalDataDict, baseline_processor: BaseMultiModalProcessor, cached_processor: BaseMultiModalProcessor, batch_idx: int, - ignore_mm_keys: Optional[set[str]] = None, ): - if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"): - # For some multimodal models, tokenizer will always add bos_token - # at the beginning of prompt by default, causing hf_processor outputs - # incorrect token ids. So we need use `add_special_tokens=False` here - # to leave bos_token to be added by the processor. - token_prompt = tokenizer.encode(prompt, add_special_tokens=False) + model_type = model_config.hf_config.model_type + ignore_mm_keys = _IGNORE_MM_KEYS.get(model_type, set[str]()) + + if isinstance(prompt, str): + text_prompt = prompt + token_prompt = encode_tokens( + tokenizer, + prompt, + add_special_tokens=_ADD_SPECIAL_TOKENS_OVERRIDES.get(model_type), + ) else: - token_prompt = tokenizer.encode(prompt) - - baseline_result = baseline_processor.apply( - prompt, - mm_data=mm_data, - hf_processor_mm_kwargs={}, - ) - cached_result = cached_processor.apply( - prompt, - mm_data=mm_data, - hf_processor_mm_kwargs={}, - ) - - _assert_inputs_equal( - baseline_result, - cached_result, - ignore_mm_keys=ignore_mm_keys, - msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})", - ) + # Mistral does not support decode_tokens with skip_special_tokens=False + text_prompt = None + token_prompt = prompt baseline_tokenized_result = baseline_processor.apply( token_prompt, @@ -179,56 +184,6 @@ def _test_processing_correctness_hf( hf_processor_mm_kwargs={}, ) - _assert_inputs_equal( - baseline_result, - baseline_tokenized_result, - ignore_mm_keys=ignore_mm_keys, - msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})", - ) - - 
cached_tokenized_result = cached_processor.apply( - token_prompt, - mm_data=mm_data, - hf_processor_mm_kwargs={}, - ) - - _assert_inputs_equal( - cached_result, - cached_tokenized_result, - ignore_mm_keys=ignore_mm_keys, - msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})", - ) - - -def _test_processing_correctness_mistral( - model_config: ModelConfig, - tokenizer: MistralTokenizer, - prompt: str, - mm_data: MultiModalDataDict, - baseline_processor: BaseMultiModalProcessor, - cached_processor: BaseMultiModalProcessor, - batch_idx: int, - ignore_mm_keys: Optional[set[str]] = None, -): - images = mm_data.get("image", []) - if not isinstance(images, list): - images = [images] - - request = ChatCompletionRequest(messages=[ - UserMessage(content=[ - TextChunk(text=prompt), - *(ImageChunk(image=image) for image in images), - ]), - ]) - res = tokenizer.mistral.encode_chat_completion(request) - token_prompt = res.tokens - - # Mistral chat outputs tokens directly, rather than text prompts - baseline_tokenized_result = baseline_processor.apply( - token_prompt, - mm_data=mm_data, - hf_processor_mm_kwargs={}, - ) cached_tokenized_result = cached_processor.apply( token_prompt, mm_data=mm_data, @@ -239,9 +194,44 @@ def _test_processing_correctness_mistral( baseline_tokenized_result, cached_tokenized_result, ignore_mm_keys=ignore_mm_keys, - msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})", + msg=f"Failed ({batch_idx=}, {token_prompt=}, {mm_data=})", ) + if text_prompt is not None: + baseline_text_result = baseline_processor.apply( + text_prompt, + mm_data=mm_data, + hf_processor_mm_kwargs={}, + ) + cached_text_result = cached_processor.apply( + text_prompt, + mm_data=mm_data, + hf_processor_mm_kwargs={}, + ) + + _assert_inputs_equal( + baseline_text_result, + cached_text_result, + ignore_mm_keys=ignore_mm_keys, + msg=f"Failed ({batch_idx=}, {text_prompt=}, {mm_data=})", + ) + + _assert_inputs_equal( + baseline_text_result, + baseline_tokenized_result, + 
ignore_mm_keys=ignore_mm_keys, + msg=f"Failed ({batch_idx=}, {text_prompt=}, " + f"{token_prompt=}, {mm_data=})", + ) + + _assert_inputs_equal( + cached_text_result, + cached_tokenized_result, + ignore_mm_keys=ignore_mm_keys, + msg=f"Failed ({batch_idx=}, {text_prompt=}, " + f"{token_prompt=}, {mm_data=})", + ) + # yapf: disable @pytest.mark.parametrize("model_id", [ @@ -257,6 +247,7 @@ def _test_processing_correctness_mistral( "ibm-granite/granite-speech-3.3-8b", "h2oai/h2ovl-mississippi-800m", "OpenGVLab/InternVL2-1B", + "OpenGVLab/InternVL3-1B", "HuggingFaceM4/Idefics3-8B-Llama3", "HuggingFaceTB/SmolVLM2-2.2B-Instruct", "moonshotai/Kimi-VL-A3B-Instruct", @@ -274,9 +265,12 @@ def _test_processing_correctness_mistral( "allenai/Molmo-7B-D-0924", "allenai/Molmo-7B-O-0924", "nvidia/NVLM-D-72B", + "AIDC-AI/Ovis1.6-Gemma2-9B", + "AIDC-AI/Ovis1.6-Llama3.2-3B", "AIDC-AI/Ovis2-1B", "google/paligemma-3b-mix-224", "google/paligemma2-3b-ft-docci-448", + "microsoft/Phi-3.5-vision-instruct", "microsoft/Phi-4-multimodal-instruct", "mistralai/Pixtral-12B-2409", "mistral-community/pixtral-12b", @@ -299,41 +293,6 @@ def test_processing_correctness( num_batches: int, simplify_rate: float, ): - ignore_mm_keys = None - if 'ultravox' in model_id: - # In Ultravox, the audio_features can be different depending on padding - # The slight difference should not be a problem though, since - # attention_mask lets us ignore the difference. 
- ignore_mm_keys = {"audio_features"} - - _test_processing_correctness( - model_id, - hit_rate=hit_rate, - num_batches=num_batches, - simplify_rate=simplify_rate, - ignore_mm_keys=ignore_mm_keys, - ) - - -# yapf: disable -@pytest.mark.parametrize("model_id", ["microsoft/Phi-3.5-vision-instruct"]) -@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) -@pytest.mark.parametrize("num_batches", [32]) -@pytest.mark.parametrize("simplify_rate", [1.0]) -# yapf: enable -def test_processing_correctness_phi3v( - model_id: str, - hit_rate: float, - num_batches: int, - simplify_rate: float, -): - # HACK - this is an attempted workaround for the following bug - # https://github.com/huggingface/transformers/issues/34307 - from transformers import AutoImageProcessor # noqa: F401 - from transformers import AutoProcessor # noqa: F401 - - AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True) - _test_processing_correctness( model_id, hit_rate=hit_rate, @@ -352,16 +311,10 @@ def _assert_inputs_equal( if ignore_mm_keys is None: ignore_mm_keys = set() - if msg is None: - assert "mm_kwargs" in a and "mm_kwargs" in b - else: - assert "mm_kwargs" in a and "mm_kwargs" in b, msg + assert "mm_kwargs" in a and "mm_kwargs" in b, msg for key in ignore_mm_keys: a["mm_kwargs"].pop(key, None) b["mm_kwargs"].pop(key, None) - if msg is None: - assert a == b - else: - assert a == b, msg + assert a == b, msg diff --git a/tests/models/multimodal/processing/test_mllama.py b/tests/models/multimodal/processing/test_mllama.py index b89376cf1722..d4794396f6d2 100644 --- a/tests/models/multimodal/processing/test_mllama.py +++ b/tests/models/multimodal/processing/test_mllama.py @@ -49,7 +49,7 @@ def test_profiling( ] * max_num_seqs mm_kwargs = processor.apply( - prompt=dummy_mm_data.prompt_text, + prompt=dummy_mm_data.prompt, mm_data=dummy_mm_data.mm_data, hf_processor_mm_kwargs=dict(), )["mm_kwargs"] diff --git a/tests/models/quantization/test_aqlm.py b/tests/models/quantization/test_aqlm.py 
index 548053b7ae43..1272a62974cc 100644 --- a/tests/models/quantization/test_aqlm.py +++ b/tests/models/quantization/test_aqlm.py @@ -2,6 +2,7 @@ import pytest from tests.quantization.utils import is_quant_method_supported +from vllm.platforms import current_platform # These ground truth generations were generated using `transformers==4.38.1 # aqlm==1.1.0 torch==2.2.0` @@ -34,7 +35,9 @@ ] -@pytest.mark.skipif(not is_quant_method_supported("aqlm"), +@pytest.mark.skipif(not is_quant_method_supported("aqlm") + or current_platform.is_rocm() + or not current_platform.is_cuda(), reason="AQLM is not supported on this GPU type.") @pytest.mark.parametrize("model", ["ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"]) @pytest.mark.parametrize("dtype", ["half"]) diff --git a/tests/models/quantization/test_fp8.py b/tests/models/quantization/test_fp8.py index 4d15675a3ab2..e01ee2026393 100644 --- a/tests/models/quantization/test_fp8.py +++ b/tests/models/quantization/test_fp8.py @@ -55,6 +55,14 @@ def test_models( Only checks log probs match to cover the discrepancy in numerical sensitive kernels. 
""" + + if backend == "FLASHINFER" and current_platform.is_rocm(): + pytest.skip("Flashinfer does not support ROCm/HIP.") + + if kv_cache_dtype == "fp8_e5m2" and current_platform.is_rocm(): + pytest.skip( + f"{kv_cache_dtype} is currently not supported on ROCm/HIP.") + with monkeypatch.context() as m: m.setenv("TOKENIZERS_PARALLELISM", 'true') m.setenv(STR_BACKEND_ENV_VAR, backend) diff --git a/tests/models/quantization/test_gguf.py b/tests/models/quantization/test_gguf.py index 3ff36502df57..5f17d12284a0 100644 --- a/tests/models/quantization/test_gguf.py +++ b/tests/models/quantization/test_gguf.py @@ -78,8 +78,12 @@ def gguf_model(self): ) MODELS = [ - LLAMA_CONFIG, QWEN2_CONFIG, PHI3_CONFIG, GPT2_CONFIG, STABLELM_CONFIG, - DOLPHIN_CONFIG + LLAMA_CONFIG, + QWEN2_CONFIG, + PHI3_CONFIG, + GPT2_CONFIG, + # STABLELM_CONFIG, # enable this when v1 support head_size=80 + DOLPHIN_CONFIG, # STARCODER_CONFIG, # broken ] diff --git a/tests/models/quantization/test_gptq_marlin.py b/tests/models/quantization/test_gptq_marlin.py index 680134c6eae8..397bdb98123f 100644 --- a/tests/models/quantization/test_gptq_marlin.py +++ b/tests/models/quantization/test_gptq_marlin.py @@ -14,6 +14,7 @@ from tests.quantization.utils import is_quant_method_supported from vllm.model_executor.layers.rotary_embedding import _ROPE_DICT +from vllm.platforms import current_platform from ..utils import check_logprobs_close @@ -34,7 +35,9 @@ @pytest.mark.flaky(reruns=3) -@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), +@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin") + or current_platform.is_rocm() + or not current_platform.is_cuda(), reason="gptq_marlin is not supported on this GPU type.") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half", "bfloat16"]) diff --git a/tests/models/quantization/test_gptq_marlin_24.py b/tests/models/quantization/test_gptq_marlin_24.py index ce28f964d544..6fb24b1f432e 100644 --- 
a/tests/models/quantization/test_gptq_marlin_24.py +++ b/tests/models/quantization/test_gptq_marlin_24.py @@ -10,6 +10,7 @@ import pytest from tests.quantization.utils import is_quant_method_supported +from vllm.platforms import current_platform from ..utils import check_logprobs_close @@ -38,7 +39,9 @@ class ModelPair: @pytest.mark.flaky(reruns=2) -@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin_24"), +@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin_24") + or current_platform.is_rocm() + or not current_platform.is_cuda(), reason="Marlin24 is not supported on this GPU type.") @pytest.mark.parametrize("model_pair", model_pairs) @pytest.mark.parametrize("dtype", ["half"]) diff --git a/tests/models/quantization/test_nvfp4.py b/tests/models/quantization/test_nvfp4.py index f94f3457c377..510858c2d7ef 100644 --- a/tests/models/quantization/test_nvfp4.py +++ b/tests/models/quantization/test_nvfp4.py @@ -41,8 +41,8 @@ reason= "Prevent unstable test based on golden strings from breaking the build " " and test input model being too large and hanging the system.") -@pytest.mark.skipif(not is_quant_method_supported("nvfp4"), - reason="nvfp4 is not supported on this GPU type.") +@pytest.mark.skipif(not is_quant_method_supported("modelopt_fp4"), + reason="modelopt_fp4 is not supported on this GPU type.") @pytest.mark.parametrize("model_name", MODELS) def test_models(example_prompts, model_name) -> None: model = LLM( @@ -50,7 +50,7 @@ def test_models(example_prompts, model_name) -> None: max_model_len=MAX_MODEL_LEN, trust_remote_code=True, enforce_eager=True, - quantization="nvfp4", + quantization="modelopt_fp4", ) tokenizer = AutoTokenizer.from_pretrained(model_name) diff --git a/tests/models/registry.py b/tests/models/registry.py index a1f2edac02b9..a49e3ad6b20e 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -8,6 +8,8 @@ from packaging.version import Version from transformers import __version__ as TRANSFORMERS_VERSION +from 
vllm.config import TokenizerMode + @dataclass(frozen=True) class _HfExamplesInfo: @@ -20,7 +22,7 @@ class _HfExamplesInfo: tokenizer: Optional[str] = None """Set the tokenizer to load for this architecture.""" - tokenizer_mode: str = "auto" + tokenizer_mode: TokenizerMode = "auto" """Set the tokenizer type for this architecture.""" speculative_model: Optional[str] = None @@ -55,9 +57,18 @@ class _HfExamplesInfo: trust_remote_code: bool = False """The ``trust_remote_code`` level required to load the model.""" + v0_only: bool = False + """The model is only available with the vLLM V0 engine.""" + hf_overrides: dict[str, Any] = field(default_factory=dict) """The ``hf_overrides`` required to load the model.""" + max_model_len: Optional[int] = None + """ + The maximum model length to use for this model. Some models default to a + length that is too large to fit into memory in CI. + """ + def check_transformers_version( self, *, @@ -124,7 +135,7 @@ def check_available_online( "BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat", trust_remote_code=True), "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B", - extras={"tiny": "hmellor/bamba-tiny-random"}), # noqa: E501 + extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}), # noqa: E501 "BloomForCausalLM": _HfExamplesInfo("bigscience/bloom-560m", {"1b": "bigscience/bloomz-1b1"}), "ChatGLMModel": _HfExamplesInfo("THUDM/chatglm3-6b", @@ -147,6 +158,9 @@ def check_available_online( "ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501 "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501 "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"), + "FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-1.5B-Instruct", + is_available_online=False, + min_transformers_version="4.52.2"), "GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"), "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"), 
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"), @@ -205,17 +219,18 @@ def check_available_online( "MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b"), "NemotronForCausalLM": _HfExamplesInfo("nvidia/Minitron-8B-Base"), "OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"), - "Olmo2ForCausalLM": _HfExamplesInfo("shanearora/OLMo-7B-1124-hf"), + "Olmo2ForCausalLM": _HfExamplesInfo("allenai/OLMo-2-0425-1B"), "OlmoeForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"), "OPTForCausalLM": _HfExamplesInfo("facebook/opt-125m", {"1b": "facebook/opt-iml-max-1.3b"}), "OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat", trust_remote_code=True), "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"), - "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"), + "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2", v0_only=True), "Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"), "Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct", - trust_remote_code=True), + trust_remote_code=True, + v0_only=True), "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct", trust_remote_code=True), "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b", @@ -231,7 +246,8 @@ def check_available_online( is_available_online=False), "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b", # noqa: E501 is_available_online=False), - "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"), + "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t", + v0_only=True), "Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"), "SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"), "TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B", @@ -242,6 +258,8 @@ def check_available_online( is_available_online=False, trust_remote_code=True), "Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct"), + 
"MiMoForCausalLM": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", + trust_remote_code=True), # [Encoder-decoder] "BartModel": _HfExamplesInfo("facebook/bart-base"), "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"), @@ -254,11 +272,17 @@ def check_available_online( "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"), "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0", trust_remote_code=True), + "GteNewModel": _HfExamplesInfo("Alibaba-NLP/gte-base-en-v1.5", + trust_remote_code=True, + hf_overrides={"architectures": + ["GteNewModel"]}), "InternLM2ForRewardModel": _HfExamplesInfo("internlm/internlm2-1_8b-reward", trust_remote_code=True), "JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), # noqa: E501 "LlamaModel": _HfExamplesInfo("llama", is_available_online=False), "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"), + "ModernBertModel": _HfExamplesInfo("Alibaba-NLP/gte-modernbert-base", + trust_remote_code=True), "NomicBertModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-long", # noqa: E501 trust_remote_code=True), "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"), @@ -292,7 +316,8 @@ def check_available_online( "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"), "AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereForAI/aya-vision-8b"), # noqa: E501 "Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b", # noqa: E501 - extras={"6b": "Salesforce/blip2-opt-6.7b"}), # noqa: E501 + extras={"6b": "Salesforce/blip2-opt-6.7b"}, # noqa: E501 + v0_only=True), "ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501 "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501 extras={"fork": "Isotr0py/deepseek-vl2-tiny"}, # noqa: E501 @@ -311,15 +336,18 @@ def check_available_online( max_transformers_version="4.48", # noqa: E501 
transformers_version_reason="HF model is not compatible."), # noqa: E501 "InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B", - extras={"2B": "OpenGVLab/InternVL2-2B"}, # noqa: E501 + extras={"2B": "OpenGVLab/InternVL2-2B", + "3.0": "OpenGVLab/InternVL3-1B"}, # noqa: E501 trust_remote_code=True), "Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501 {"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}), # noqa: E501 "KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501 extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501 - trust_remote_code=True), + trust_remote_code=True, + v0_only=True), "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501 - min_transformers_version="4.51"), + min_transformers_version="4.51", + max_model_len=10240), "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf", extras={"mistral": "mistral-community/pixtral-12b", # noqa: E501 "mistral-fp8": "nm-testing/pixtral-12b-FP8-dynamic"}), # noqa: E501 @@ -338,7 +366,8 @@ def check_available_online( extras={"2.6": "openbmb/MiniCPM-V-2_6"}, # noqa: E501 trust_remote_code=True), "MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo("MiniMaxAI/MiniMax-VL-01", # noqa: E501 - trust_remote_code=True), + trust_remote_code=True, + v0_only=True), "Mistral3ForConditionalGeneration": _HfExamplesInfo("mistralai/Mistral-Small-3.1-24B-Instruct-2503", # noqa: E501 extras={"fp8": "nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"}), # noqa: E501 "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924", @@ -355,9 +384,9 @@ def check_available_online( max_transformers_version="4.48", transformers_version_reason="Use of deprecated imports which have been removed.", # noqa: E501 extras={"phi3.5": "microsoft/Phi-3.5-vision-instruct"}), # noqa: E501 - "Ovis2ForConditionalGeneration": _HfExamplesInfo("AIDC-AI/Ovis2-1B", 
- trust_remote_code=True, - hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]}), # noqa: E501 + "Ovis": _HfExamplesInfo("AIDC-AI/Ovis2-1B", trust_remote_code=True, + extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B", + "1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}), # noqa: E501 "Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct", trust_remote_code=True), "PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409", # noqa: E501 @@ -371,6 +400,8 @@ def check_available_online( "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct"), # noqa: E501 "Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B", min_transformers_version="4.52"), + "Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ", # noqa: E501 + min_transformers_version="4.52"), "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"), "SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501 "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501 @@ -403,6 +434,9 @@ def check_available_online( trust_remote_code=True, speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", tokenizer="meta-llama/Llama-3.1-8B-Instruct"), + "MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", + trust_remote_code=True, + speculative_model="XiaomiMiMo/MiMo-7B-RL") } _TRANSFORMERS_MODELS = { diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index 446c4efbf6af..d403cb392fe0 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -15,12 +15,12 @@ @pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs()) -def test_can_initialize(model_arch): +def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch): model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch) model_info.check_available_online(on_fail="skip") 
model_info.check_transformers_version(on_fail="skip") - # Avoid OOM + # Avoid OOM and reduce initialization time by only using 1 layer def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig: hf_config.update(model_info.hf_overrides) @@ -34,6 +34,12 @@ def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig: "num_local_experts": 2, }) + if hasattr(hf_config, "vision_config"): + hf_config.vision_config.update({ + "num_layers": 1, + "num_hidden_layers": 1, + }) + return hf_config # Avoid calling model.forward() @@ -46,7 +52,7 @@ def _initialize_kv_caches_v1(self, vllm_config): scheduler_kv_cache_config = get_kv_cache_config( vllm_config, kv_cache_specs[0], - 20 * GiB_bytes, + 10 * GiB_bytes, ) # gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config @@ -55,7 +61,9 @@ def _initialize_kv_caches_v1(self, vllm_config): with (patch.object(V0LLMEngine, "_initialize_kv_caches", _initialize_kv_caches_v0), patch.object(V1EngineCore, "_initialize_kv_caches", - _initialize_kv_caches_v1)): + _initialize_kv_caches_v1), monkeypatch.context() as m): + if model_info.v0_only: + m.setenv("VLLM_USE_V1", "0") LLM( model_info.default, tokenizer=model_info.tokenizer, @@ -65,6 +73,7 @@ def _initialize_kv_caches_v1(self, vllm_config): "num_speculative_tokens": 1, } if model_info.speculative_model else None, trust_remote_code=model_info.trust_remote_code, + max_model_len=model_info.max_model_len, load_format="dummy", hf_overrides=hf_overrides, ) diff --git a/tests/models/test_oot_registration.py b/tests/models/test_oot_registration.py index b45a87d94b86..b62720caa9cb 100644 --- a/tests/models/test_oot_registration.py +++ b/tests/models/test_oot_registration.py @@ -4,6 +4,7 @@ from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset +from vllm.multimodal.image import convert_image_mode from ..utils import create_new_process_for_each_test @@ -58,7 +59,7 @@ def test_oot_registration_embedding( assert all(v == 0 for v in output.outputs.embedding) -image 
= ImageAsset("cherry_blossom").pil_image.convert("RGB") +image = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB") @create_new_process_for_each_test() diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py index 6da488897be5..1a51b4aeab04 100644 --- a/tests/models/test_transformers.py +++ b/tests/models/test_transformers.py @@ -1,38 +1,56 @@ # SPDX-License-Identifier: Apache-2.0 """Test the functionality of the Transformers backend.""" +from typing import Any, Optional, Union + import pytest +from vllm.platforms import current_platform + from ..conftest import HfRunner, VllmRunner +from ..core.block.e2e.test_correctness_sliding_window import prep_prompts from ..utils import multi_gpu_test from .utils import check_logprobs_close def check_implementation( - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], + runner_ref: type[Union[HfRunner, VllmRunner]], + runner_test: type[VllmRunner], example_prompts: list[str], model: str, + kwargs_ref: Optional[dict[str, Any]] = None, + kwargs_test: Optional[dict[str, Any]] = None, **kwargs, ): + if kwargs_ref is None: + kwargs_ref = {} + if kwargs_test is None: + kwargs_test = {} + max_tokens = 32 num_logprobs = 5 - with vllm_runner(model, **kwargs) as vllm_model: - vllm_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + args = (example_prompts, max_tokens, num_logprobs) - with hf_runner(model) as hf_model: - hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) + with runner_test(model, **kwargs_test, **kwargs) as model_test: + outputs_test = model_test.generate_greedy_logprobs(*args) + + with runner_ref(model, **kwargs_ref) as model_ref: + if isinstance(model_ref, VllmRunner): + outputs_ref = model_ref.generate_greedy_logprobs(*args) + else: + outputs_ref = model_ref.generate_greedy_logprobs_limit(*args) check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - 
name_0="hf", - name_1="vllm", + outputs_0_lst=outputs_ref, + outputs_1_lst=outputs_test, + name_0="ref", + name_1="test", ) +@pytest.mark.skipif( + current_platform.is_rocm(), + reason="Llama-3.2-1B-Instruct, Ilama-3.2-1B produce memory access fault.") @pytest.mark.parametrize( "model,model_impl", [ @@ -53,6 +71,18 @@ def test_models( model_impl=model_impl) +def test_hybrid_attention(vllm_runner: type[VllmRunner]) -> None: + prompts, _, _ = prep_prompts(4, (800, 801)) + kwargs_ref = {"max_model_len": 8192, "enforce_eager": True} + kwargs_test = {"model_impl": "transformers", **kwargs_ref} + check_implementation(vllm_runner, + vllm_runner, + prompts, + model="hmellor/tiny-random-Gemma2ForCausalLM", + kwargs_ref=kwargs_ref, + kwargs_test=kwargs_test) + + @multi_gpu_test(num_gpus=2) def test_distributed( hf_runner: type[HfRunner], @@ -60,10 +90,16 @@ def test_distributed( example_prompts, ): kwargs = {"model_impl": "transformers", "tensor_parallel_size": 2} - check_implementation(hf_runner, vllm_runner, example_prompts, - "meta-llama/Llama-3.2-1B-Instruct", **kwargs) + check_implementation(hf_runner, + vllm_runner, + example_prompts, + "meta-llama/Llama-3.2-1B-Instruct", + kwargs_test=kwargs) +@pytest.mark.skipif( + current_platform.is_rocm(), + reason="bitsandbytes quantization is currently not supported in rocm.") @pytest.mark.parametrize("model, quantization_kwargs", [ ( "meta-llama/Llama-3.2-1B-Instruct", diff --git a/tests/models/test_utils.py b/tests/models/test_utils.py index d61c7d2d5000..a16384efe195 100644 --- a/tests/models/test_utils.py +++ b/tests/models/test_utils.py @@ -77,3 +77,73 @@ def weight_generator(): assert torch.all( new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var) assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1 + + +def test_module_skip_prefix(): + """Ensure the auto weight loader can skip prefix.""" + mod = ModuleWithNestedBatchNorm() + # Run some data through the module with batchnorm + mod(torch.Tensor([[1, 
2], [3, 4]])) + + # Try to load the weights to a new instance + def weight_generator(): + # weights needed to be filtered out + redundant_weights = { + "prefix.bn.weight": torch.Tensor([1, 2]), + "prefix.bn.bias": torch.Tensor([3, 4]), + } + yield from (mod.state_dict() | redundant_weights).items() + + new_mod = ModuleWithNestedBatchNorm() + + assert not torch.all( + new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean) + assert not torch.all( + new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var) + assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0 + + loader = AutoWeightsLoader(new_mod, skip_prefixes=["prefix."]) + loader.load_weights(weight_generator()) + + # Ensure the stats are updated + assert torch.all( + new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean) + assert torch.all( + new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var) + assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1 + + +def test_module_skip_substr(): + """Ensure the auto weight loader can skip prefix.""" + mod = ModuleWithNestedBatchNorm() + # Run some data through the module with batchnorm + mod(torch.Tensor([[1, 2], [3, 4]])) + + # Try to load the weights to a new instance + def weight_generator(): + # weights needed to be filtered out + redundant_weights = { + "nested_mod.0.substr.weight": torch.Tensor([1, 2]), + "nested_mod.0.substr.bias": torch.Tensor([3, 4]), + "nested_mod.substr.weight": torch.Tensor([1, 2]), + "nested_mod.substr.bias": torch.Tensor([3, 4]), + } + yield from (mod.state_dict() | redundant_weights).items() + + new_mod = ModuleWithNestedBatchNorm() + + assert not torch.all( + new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean) + assert not torch.all( + new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var) + assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0 + + loader = AutoWeightsLoader(new_mod, skip_substrs=["substr."]) + 
loader.load_weights(weight_generator()) + + # Ensure the stats are updated + assert torch.all( + new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean) + assert torch.all( + new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var) + assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1 diff --git a/tests/models/utils.py b/tests/models/utils.py index bb87863d076e..a43fd77c6d79 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -332,9 +332,10 @@ def matryoshka_fy(tensor: torch.Tensor, dimensions: int): class EmbedModelInfo(NamedTuple): name: str - is_matryoshka: bool + is_matryoshka: bool = False matryoshka_dimensions: Optional[list[int]] = None architecture: str = "" + dtype: str = "auto" enable_test: bool = True diff --git a/tests/multimodal/assets/rgba.png b/tests/multimodal/assets/rgba.png new file mode 100644 index 000000000000..11eb81857a65 Binary files /dev/null and b/tests/multimodal/assets/rgba.png differ diff --git a/tests/multimodal/test_image.py b/tests/multimodal/test_image.py new file mode 100644 index 000000000000..56b5475c9ca0 --- /dev/null +++ b/tests/multimodal/test_image.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: Apache-2.0 +from pathlib import Path + +import numpy as np +from PIL import Image, ImageChops + +from vllm.multimodal.image import convert_image_mode + +ASSETS_DIR = Path(__file__).parent / "assets" +assert ASSETS_DIR.exists() + + +def test_rgb_to_rgb(): + # Start with an RGB image. + original_image = Image.open(ASSETS_DIR / "image1.png").convert("RGB") + converted_image = convert_image_mode(original_image, "RGB") + + # RGB to RGB should be a no-op. 
+ diff = ImageChops.difference(original_image, converted_image) + assert diff.getbbox() is None + + +def test_rgba_to_rgb(): + original_image = Image.open(ASSETS_DIR / "rgba.png") + original_image_numpy = np.array(original_image) + + converted_image = convert_image_mode(original_image, "RGB") + converted_image_numpy = np.array(converted_image) + + for i in range(original_image_numpy.shape[0]): + for j in range(original_image_numpy.shape[1]): + # Verify that all transparent pixels are converted to white. + if original_image_numpy[i][j][3] == 0: + assert converted_image_numpy[i][j][0] == 255 + assert converted_image_numpy[i][j][1] == 255 + assert converted_image_numpy[i][j][2] == 255 diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index 478184c34b91..f1e45da30eda 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -10,6 +10,7 @@ import pytest from PIL import Image, ImageChops +from vllm.multimodal.image import convert_image_mode from vllm.multimodal.inputs import PlaceholderRange from vllm.multimodal.utils import (MediaConnector, merge_and_sort_multimodal_metadata) @@ -53,7 +54,7 @@ def get_supported_suffixes() -> tuple[str, ...]: def _image_equals(a: Image.Image, b: Image.Image) -> bool: - return (np.asarray(a) == np.asarray(b.convert(a.mode))).all() + return (np.asarray(a) == np.asarray(convert_image_mode(b, a.mode))).all() @pytest.mark.asyncio diff --git a/tests/multimodal/test_video.py b/tests/multimodal/test_video.py new file mode 100644 index 000000000000..e67624ecefcb --- /dev/null +++ b/tests/multimodal/test_video.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +import numpy as np +import numpy.typing as npt +import pytest + +from vllm.multimodal.video import VIDEO_LOADER_REGISTRY, VideoLoader + +NUM_FRAMES = 10 +FAKE_OUTPUT_1 = np.random.rand(NUM_FRAMES, 1280, 720, 3) +FAKE_OUTPUT_2 = np.random.rand(NUM_FRAMES, 1280, 720, 3) + + +@VIDEO_LOADER_REGISTRY.register("test_video_loader_1") 
+class TestVideoLoader1(VideoLoader): + + @classmethod + def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray: + return FAKE_OUTPUT_1 + + +@VIDEO_LOADER_REGISTRY.register("test_video_loader_2") +class TestVideoLoader2(VideoLoader): + + @classmethod + def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray: + return FAKE_OUTPUT_2 + + +def test_video_loader_registry(): + custom_loader_1 = VIDEO_LOADER_REGISTRY.load("test_video_loader_1") + output_1 = custom_loader_1.load_bytes(b"test") + np.testing.assert_array_equal(output_1, FAKE_OUTPUT_1) + + custom_loader_2 = VIDEO_LOADER_REGISTRY.load("test_video_loader_2") + output_2 = custom_loader_2.load_bytes(b"test") + np.testing.assert_array_equal(output_2, FAKE_OUTPUT_2) + + +def test_video_loader_type_doesnt_exist(): + with pytest.raises(AssertionError): + VIDEO_LOADER_REGISTRY.load("non_existing_video_loader") diff --git a/tests/neuron/2_core/test_eagle.py b/tests/neuron/2_core/test_eagle.py new file mode 100644 index 000000000000..d71c88689a99 --- /dev/null +++ b/tests/neuron/2_core/test_eagle.py @@ -0,0 +1,82 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json +import os +import shutil +import tempfile + +import torch +from huggingface_hub import snapshot_download +from safetensors import safe_open + +from vllm import LLM, SamplingParams + + +def patch_eagle_draft_with_lm_head(target_model_id: str, + draft_model_id: str) -> str: + # In NxDI, draft model checkpoint must include lm_head weights from target + # model. 
For more details see https://awsdocs-neuron.readthedocs-hosted.com + # /en/latest/libraries/nxd-inference/developer_guides/feature-guide.html + # #eagle-checkpoint-compatibility + final_draft_dir = "/tmp/patched_eagle_draft" + + with tempfile.TemporaryDirectory() as tmp_dir: + target_dir = snapshot_download(repo_id=target_model_id, + local_dir=os.path.join( + tmp_dir, "target")) + draft_dir = snapshot_download(repo_id=draft_model_id, + local_dir=os.path.join(tmp_dir, "draft")) + + lm_head_key = "lm_head.weight" + index_path = os.path.join(target_dir, "model.safetensors.index.json") + with open(index_path) as f: + index = json.load(f) + shard_name = index["weight_map"][lm_head_key] + target_safetensor_path = os.path.join(target_dir, shard_name) + + with safe_open(target_safetensor_path, framework="pt") as f: + target_lm_head = f.get_tensor(lm_head_key) + + draft_path = os.path.join(draft_dir, "pytorch_model.bin") + draft_state_dict = torch.load(draft_path, map_location="cpu") + draft_state_dict[lm_head_key] = target_lm_head.to(torch.float16) + torch.save(draft_state_dict, draft_path) + + shutil.copytree(draft_dir, final_draft_dir, dirs_exist_ok=True) + + return final_draft_dir + + +def test_eagle(): + patched_draft_path = patch_eagle_draft_with_lm_head( + target_model_id="meta-llama/Llama-2-7b-hf", + draft_model_id="yuhuili/EAGLE-llama2-chat-7B") + llm = LLM( + model="meta-llama/Llama-2-7b-hf", + speculative_config={ + "model": patched_draft_path, + "num_speculative_tokens": 5, + "max_model_len": 128 + }, + max_num_seqs=1, + max_model_len=128, + tensor_parallel_size=2, + override_neuron_config={ + "enable_eagle_speculation": True, + "enable_fused_speculation": True, + "fused_qkv": True + }, + ) + prompts = [ + "The president of the United States is", + ] + outputs = llm.generate(prompts, SamplingParams(top_k=1)) + expected_output = " the head of state and head of government of " \ + "the United States. 
The president direct" + + for output in outputs: + generated_text = output.outputs[0].text + print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}") + assert (expected_output == generated_text) + + print("Neuron Eagle speculation test passed.") diff --git a/tests/neuron/2_core/test_mistral.py b/tests/neuron/2_core/test_mistral.py new file mode 100644 index 000000000000..3e651502d1e2 --- /dev/null +++ b/tests/neuron/2_core/test_mistral.py @@ -0,0 +1,64 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm import LLM, SamplingParams + + +def test_mistral(): + llm = LLM(model="mistralai/Mistral-7B-v0.1", + tensor_parallel_size=2, + max_num_seqs=4, + max_model_len=128, + use_v2_block_manager=True, + override_neuron_config={ + "sequence_parallel_enabled": False, + "skip_warmup": True + }) + + # Send more prompts than the compiled batch size (4) and request + # varying generation lengths to test accuracy related to Neuron + # specific sequence id sorting. + prompts = [ + "The president of the United States is", + "The capital of France is", + "What is Annapurna labs?", + "I believe the meaning of life is", + "Tell me a story about a brave knight", + "Hello, my name is Llama", + ] + + sampling_params = [ + SamplingParams(top_k=1, max_tokens=10), + SamplingParams(top_k=1, max_tokens=20), + SamplingParams(top_k=1, max_tokens=30), + SamplingParams(top_k=1, max_tokens=40), + SamplingParams(top_k=1, max_tokens=50), + SamplingParams(top_k=1, max_tokens=60) + ] + + outputs = llm.generate(prompts, sampling_params) + + expected_outputs = [ + " the most powerful person in the world. He is", + " a city of many faces. It is a city of history, culture, art, " + "fashion, and", + "\n\nAnnapurna Labs is a semiconductor company that was founded " + "in 2013 by Amazon. 
The company is", + " to be happy.\n\nI believe that happiness is a choice.\n\nI " + "believe that happiness is a state of mind.\n\nI believe that " + "happiness is a journey.\n\nI believe", + " who rescued a princess from a dragon.\n\nTell me a story about" + " a princess who rescued herself from a dragon.\n\nTell me a " + "story about a princess who rescued herself from a dragon and " + "then rescued a knight from", + " and I am a 10 year old male. I am a very friendly and " + "affectionate boy who loves to be around people. I am a very " + "active boy who loves to play and run around. I am a very smart " + "boy who loves to learn new things. I am a very loyal boy" + ] + + for expected_output, output in zip(expected_outputs, outputs): + generated_text = output.outputs[0].text + print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}") + assert (expected_output == generated_text) + + print("Neuron Mistral test passed.") diff --git a/vllm/v1/stats/__init__.py b/tests/plugins/lora_resolvers/__init__.py similarity index 100% rename from vllm/v1/stats/__init__.py rename to tests/plugins/lora_resolvers/__init__.py diff --git a/tests/plugins/lora_resolvers/test_filesystem_resolver.py b/tests/plugins/lora_resolvers/test_filesystem_resolver.py new file mode 100644 index 000000000000..cb0f0c3c5fa6 --- /dev/null +++ b/tests/plugins/lora_resolvers/test_filesystem_resolver.py @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: Apache-2.0 +import os +import shutil + +import pytest +from huggingface_hub import snapshot_download + +from vllm.plugins.lora_resolvers.filesystem_resolver import FilesystemResolver + +MODEL_NAME = "mistralai/Mistral-7B-v0.1" +LORA_NAME = "typeof/zephyr-7b-beta-lora" +PA_NAME = "swapnilbp/llama_tweet_ptune" + + +@pytest.fixture(scope='module') +def adapter_cache(request, tmpdir_factory): + # Create dir that mimics the structure of the adapter cache + adapter_cache = tmpdir_factory.mktemp( + request.module.__name__) / "adapter_cache" + return 
adapter_cache + + +@pytest.fixture(scope="module") +def zephyr_lora_files(): + return snapshot_download(repo_id=LORA_NAME) + + +@pytest.fixture(scope="module") +def pa_files(): + return snapshot_download(repo_id=PA_NAME) + + +@pytest.mark.asyncio +async def test_filesystem_resolver(adapter_cache, zephyr_lora_files): + model_files = adapter_cache / LORA_NAME + shutil.copytree(zephyr_lora_files, model_files) + + fs_resolver = FilesystemResolver(adapter_cache) + assert fs_resolver is not None + + lora_request = await fs_resolver.resolve_lora(MODEL_NAME, LORA_NAME) + assert lora_request is not None + assert lora_request.lora_name == LORA_NAME + assert lora_request.lora_path == os.path.join(adapter_cache, LORA_NAME) + + +@pytest.mark.asyncio +async def test_missing_adapter(adapter_cache): + fs_resolver = FilesystemResolver(adapter_cache) + assert fs_resolver is not None + + missing_lora_request = await fs_resolver.resolve_lora(MODEL_NAME, "foobar") + assert missing_lora_request is None + + +@pytest.mark.asyncio +async def test_nonlora_adapter(adapter_cache, pa_files): + model_files = adapter_cache / PA_NAME + shutil.copytree(pa_files, model_files) + + fs_resolver = FilesystemResolver(adapter_cache) + assert fs_resolver is not None + + pa_request = await fs_resolver.resolve_lora(MODEL_NAME, PA_NAME) + assert pa_request is None diff --git a/tests/plugins_tests/test_platform_plugins.py b/tests/plugins_tests/test_platform_plugins.py index 9d6872e0e077..207de53abd8d 100644 --- a/tests/plugins_tests/test_platform_plugins.py +++ b/tests/plugins_tests/test_platform_plugins.py @@ -29,5 +29,5 @@ def test_oot_attention_backend(monkeypatch: pytest.MonkeyPatch): # ignore the backend env variable if it is set with monkeypatch.context() as m: m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL) - backend = get_attn_backend(16, torch.float16, torch.float16, 16, False) + backend = get_attn_backend(16, torch.float16, "auto", 16, False) assert backend.get_name() == "Dummy_Backend" diff --git 
a/tests/quantization/test_auto_round.py b/tests/quantization/test_auto_round.py new file mode 100644 index 000000000000..81ceecdb45d6 --- /dev/null +++ b/tests/quantization/test_auto_round.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Test model set-up and inference for quantized HF models supported + on the AutoRound. + + Validating the configuration and printing results for manual checking. + + Run `pytest tests/quantization/test_auto_round.py`. +""" + +import pytest + +from vllm.platforms import current_platform + +MODELS = [ + "OPEA/Qwen2.5-0.5B-Instruct-int4-sym-inc", ##auto_round:auto_gptq + "Intel/Qwen2-0.5B-Instruct-int4-sym-AutoRound" ##auto_round:auto_awq +] + + +@pytest.mark.skipif(not current_platform.is_cpu() + and not current_platform.is_xpu() + and not current_platform.is_cuda(), + reason="only supports CPU/XPU/CUDA backend.") +@pytest.mark.parametrize("model", MODELS) +def test_auto_round(vllm_runner, model): + with vllm_runner(model) as llm: + output = llm.generate_greedy(["The capital of France is"], + max_tokens=8) + assert output + print(f"{output[0][1]}") diff --git a/tests/quantization/test_bitsandbytes.py b/tests/quantization/test_bitsandbytes.py index 8d9ae282153c..e8ddfd7fc779 100644 --- a/tests/quantization/test_bitsandbytes.py +++ b/tests/quantization/test_bitsandbytes.py @@ -8,9 +8,11 @@ import pytest import torch +from transformers import BitsAndBytesConfig from tests.quantization.utils import is_quant_method_supported +from ..models.utils import check_embeddings_close from ..utils import compare_two_settings, create_new_process_for_each_test models_4bit_to_test = [ @@ -19,6 +21,10 @@ "quantize inflight model with both HF and Mistral format weights") ] +models_4bit_to_embedding_test = [ + ("intfloat/e5-mistral-7b-instruct", "quantize embedding model inflight"), +] + models_pre_qaunt_4bit_to_test = [ ('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed', 'read pre-quantized 4-bit FP4 model'), @@ -39,7 +45,8 @@ def 
test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, model_name, description) -> None: - hf_model_kwargs = {"load_in_4bit": True} + hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig( + load_in_4bit=True)) validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1], model_name, False, hf_model_kwargs) @@ -77,7 +84,8 @@ def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts, def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, model_name, description) -> None: - hf_model_kwargs = {"load_in_4bit": True} + hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig( + load_in_4bit=True)) validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1], @@ -113,6 +121,54 @@ def test_load_pp_4bit_bnb_model(model_name, description) -> None: compare_two_settings(model_name, common_args, pp_args) +@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), + reason='bitsandbytes is not supported on this GPU type.') +@pytest.mark.parametrize("model_name, description", + models_4bit_to_embedding_test) +@pytest.mark.parametrize("dtype", ["half"]) +@create_new_process_for_each_test() +def test_4bit_bnb_embedding_model( + model_name, + description, + hf_runner, + vllm_runner, + example_prompts, + dtype: str, +) -> None: + + # The example_prompts has ending "\n", for example: + # "Write a short story about a robot that dreams for the first time.\n" + # sentence_transformers will strip the input texts, see: + # https://github.com/UKPLab/sentence-transformers/blob/v3.1.1/sentence_transformers/models/Transformer.py#L159 + # This makes the input_ids different between hf_model and vllm_model. + # So we need to strip the input texts to avoid test failing. 
+ example_prompts = [str(s).strip() for s in example_prompts] + + # Inflight 4bit quantization + hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig( + load_in_4bit=True)) + with hf_runner( + model_name, + dtype=dtype, + model_kwargs=hf_model_kwargs, + is_sentence_transformer=True, + ) as hf_model: + hf_outputs = hf_model.encode(example_prompts) + + with vllm_runner(model_name, + task="embed", + dtype=dtype, + quantization="bitsandbytes") as vllm_model: + vllm_outputs = vllm_model.encode(example_prompts) + check_embeddings_close( + embeddings_0_lst=hf_outputs, + embeddings_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + tol=5e-2, + ) + + def log_generated_texts(prompts, outputs, runner_name): logged_texts = [] for i, (_, generated_text) in enumerate(outputs): diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 70f716f95e89..c968a68f1a8e 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -13,9 +13,9 @@ from tests.models.utils import check_logprobs_close from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 CompressedTensors24, CompressedTensorsLinearMethod, - CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, - CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, - CompressedTensorsWNA16) + CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24, + CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, + CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( sparse_cutlass_supported) from vllm.platforms import current_platform @@ -648,3 +648,23 @@ def check_model(model): output = llm.generate_greedy("Hello my name is", max_tokens=20) print(output) assert output + + +def test_compressed_tensors_nvfp4a16(vllm_runner): + # run weight only example + model = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP4" + with 
vllm_runner(model, enforce_eager=True) as llm: + + def check_model(model): + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + assert isinstance(qkv_proj.quant_method, + CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Fp4) + assert qkv_proj.scheme.group_size == 16 + + llm.apply_model(check_model) + output = llm.generate_greedy("Hello my name is", max_tokens=20) + print(output) + assert output diff --git a/tests/quantization/test_torchao.py b/tests/quantization/test_torchao.py index 1a20228765e8..6571fc9e471b 100644 --- a/tests/quantization/test_torchao.py +++ b/tests/quantization/test_torchao.py @@ -31,9 +31,6 @@ def test_pre_quantized_model(vllm_runner): ]) def test_opt_125m_int4wo_model_loading_with_params(vllm_runner, pt_load_map_location): - """ - Test loading roberta-base model with no lm_head. - """ torch._dynamo.reset() model_name = "jerryzh168/opt-125m-int4wo" with vllm_runner(model_name=model_name, @@ -47,5 +44,20 @@ def test_opt_125m_int4wo_model_loading_with_params(vllm_runner, print(output) +@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available") +def test_opt_125m_int4wo_model_per_module_quant(vllm_runner): + torch._dynamo.reset() + model_name = "jerryzh168/opt-125m-int4wo-per-module" + with vllm_runner(model_name=model_name, + quantization="torchao", + dtype="bfloat16", + pt_load_map_location="cuda:0") as llm: + output = llm.generate_greedy(["The capital of France is"], + max_tokens=32) + + assert output + print(output) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/runai_model_streamer_test/test_weight_utils.py b/tests/runai_model_streamer_test/test_weight_utils.py index 4afa76c51693..06e506c35761 100644 --- a/tests/runai_model_streamer_test/test_weight_utils.py +++ b/tests/runai_model_streamer_test/test_weight_utils.py @@ -23,10 +23,11 @@ def test_runai_model_loader(): runai_model_streamer_tensors = {} hf_safetensors_tensors = {} - for 
name, tensor in runai_safetensors_weights_iterator(safetensors): + for name, tensor in runai_safetensors_weights_iterator( + safetensors, True): runai_model_streamer_tensors[name] = tensor - for name, tensor in safetensors_weights_iterator(safetensors): + for name, tensor in safetensors_weights_iterator(safetensors, True): hf_safetensors_tensors[name] = tensor assert len(runai_model_streamer_tensors) == len(hf_safetensors_tensors) diff --git a/tests/samplers/test_rejection_sampler.py b/tests/samplers/test_rejection_sampler.py index 8884f8ae70b8..6ef61f2ff406 100644 --- a/tests/samplers/test_rejection_sampler.py +++ b/tests/samplers/test_rejection_sampler.py @@ -169,7 +169,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, @pytest.mark.parametrize("batch_size", [1, 8, 32, 128]) @pytest.mark.parametrize("n_rep", [100]) @pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.parametrize("use_flashinfer", [True, False]) +# @pytest.mark.parametrize("use_flashinfer", [True, False]) +# Not testing FlashInfer now, since 0.2.3 API removed the ability +# to pass in uniform samples. +@pytest.mark.parametrize("use_flashinfer", [False]) @torch.inference_mode() def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int, frac_seeded: float, n_rep: int, device: str, @@ -214,7 +217,10 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int, @pytest.mark.parametrize("vocab_size", [30_000, 50_000]) @pytest.mark.parametrize("batch_size", [3, 8, 32, 128]) @pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.parametrize("use_flashinfer", [True, False]) +# @pytest.mark.parametrize("use_flashinfer", [True, False]) +# Not testing FlashInfer now, since 0.2.3 API removed the ability +# to pass in uniform samples. 
+@pytest.mark.parametrize("use_flashinfer", [False]) @torch.inference_mode() def test_mixed_seeded_batch(k: int, vocab_size: int, batch_size: int, device: str, use_flashinfer: bool): @@ -284,6 +290,10 @@ def test_compare_nonflashinfer_backend(k: int, vocab_size: int, Test the flashinfer and nonflashinfer backend generate the same output metrics. """ + + pytest.skip("Not testing FlashInfer now, since 0.2.3 API removed " + "the ability to pass in uniform samples.") + torch.set_default_device(device) torch.manual_seed(0) draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 6924aba11576..7b19d5750906 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -478,7 +478,7 @@ def test_sampler_mixed(seed: int, device: str): sampling_params = SamplingParams( temperature=random.random() + 0.1, top_p=min(random.random() + 0.1, 1), - top_k=random.randint(0, 10) or -1, + top_k=random.randint(0, 10), n=n, presence_penalty=random.randint(0, 1), ) @@ -647,6 +647,8 @@ def test_flashinfer_fallback(seed: int, device: str): if not envs.VLLM_USE_FLASHINFER_SAMPLER: pytest.skip("Flashinfer sampler is disabled") + pytest.skip("After FlashInfer 0.2.3, sampling will never fail") + set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) diff --git a/tests/tensorizer_loader/conftest.py b/tests/tensorizer_loader/conftest.py index a88ae8cda73d..ce8689f5b89c 100644 --- a/tests/tensorizer_loader/conftest.py +++ b/tests/tensorizer_loader/conftest.py @@ -1,56 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 - -import functools -import gc -from typing import Callable, TypeVar - import pytest -import torch -from typing_extensions import ParamSpec from vllm.distributed import cleanup_dist_env_and_memory from vllm.model_executor.model_loader.tensorizer import TensorizerConfig -@pytest.fixture(scope="function", autouse=True) -def 
use_v0_only(monkeypatch): - """ - Tensorizer only tested on V0 so far. - """ - monkeypatch.setenv('VLLM_USE_V1', '0') - - @pytest.fixture(autouse=True) def cleanup(): cleanup_dist_env_and_memory(shutdown_ray=True) -_P = ParamSpec("_P") -_R = TypeVar("_R") - - -def retry_until_skip(n: int): - - def decorator_retry(func: Callable[_P, _R]) -> Callable[_P, _R]: - - @functools.wraps(func) - def wrapper_retry(*args: _P.args, **kwargs: _P.kwargs) -> _R: - for i in range(n): - try: - return func(*args, **kwargs) - except AssertionError: - gc.collect() - torch.cuda.empty_cache() - if i == n - 1: - pytest.skip(f"Skipping test after {n} attempts.") - - raise AssertionError("Code should not be reached") - - return wrapper_retry - - return decorator_retry - - @pytest.fixture(autouse=True) def tensorizer_config(): config = TensorizerConfig(tensorizer_uri="vllm") diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py index 5b9661bf6b05..b6286e148397 100644 --- a/tests/tensorizer_loader/test_tensorizer.py +++ b/tests/tensorizer_loader/test_tensorizer.py @@ -1,17 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 import gc -import json import os import pathlib import subprocess -from functools import partial from unittest.mock import MagicMock, patch -import openai import pytest import torch -from huggingface_hub import snapshot_download from vllm import SamplingParams from vllm.engine.arg_utils import EngineArgs @@ -22,13 +18,11 @@ is_vllm_tensorized, load_with_tensorizer, open_stream, - serialize_vllm_model, tensorize_vllm_model) # yapf: enable -from vllm.utils import PlaceholderModule, import_from_path +from vllm.utils import PlaceholderModule -from ..utils import VLLM_PATH, RemoteOpenAIServer -from .conftest import retry_until_skip +from ..utils import VLLM_PATH try: from tensorizer import EncryptionParams @@ -104,6 +98,7 @@ def test_can_deserialize_s3(vllm_runner): @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not 
installed") def test_deserialized_encrypted_vllm_model_has_same_outputs( vllm_runner, tmp_path): + args = EngineArgs(model=model_ref) with vllm_runner(model_ref) as vllm_model: model_path = tmp_path / (model_ref + ".tensors") key_path = tmp_path / (model_ref + ".key") @@ -111,15 +106,13 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs( outputs = vllm_model.generate(prompts, sampling_params) - config_for_serializing = TensorizerConfig(tensorizer_uri=model_path, - encryption_keyfile=key_path) + config_for_serializing = TensorizerConfig(tensorizer_uri=str(model_path), + encryption_keyfile=str(key_path)) - vllm_model.apply_model( - partial(serialize_vllm_model, - tensorizer_config=config_for_serializing)) + tensorize_vllm_model(args, config_for_serializing) - config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path, - encryption_keyfile=key_path) + config_for_deserializing = TensorizerConfig( + tensorizer_uri=str(model_path), encryption_keyfile=str(key_path)) with vllm_runner(model_ref, load_format="tensorizer", @@ -155,113 +148,46 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner, assert outputs == deserialized_outputs -def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): - multilora_inference = import_from_path( - "examples.offline_inference.multilora_inference", - EXAMPLES_PATH / "offline_inference/multilora_inference.py", - ) - - model_ref = "meta-llama/Llama-2-7b-hf" - lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") - test_prompts = multilora_inference.create_test_prompts(lora_path) - - # Serialize model before deserializing and binding LoRA adapters - with vllm_runner(model_ref) as vllm_model: - model_path = tmp_path / (model_ref + ".tensors") - - vllm_model.apply_model( - partial( - serialize_vllm_model, - tensorizer_config=TensorizerConfig(tensorizer_uri=model_path))) - - with vllm_runner( - model_ref, - load_format="tensorizer", - model_loader_extra_config=TensorizerConfig( - 
tensorizer_uri=model_path, - num_readers=1, - ), - enable_lora=True, - max_loras=1, - max_lora_rank=8, - max_cpu_loras=2, - max_num_seqs=50, - max_model_len=1000, - ) as loaded_vllm_model: - multilora_inference.process_requests( - loaded_vllm_model.model.llm_engine, test_prompts) - - assert loaded_vllm_model - - -def test_load_without_tensorizer_load_format(vllm_runner): +def test_load_without_tensorizer_load_format(vllm_runner, capfd): model = None - with pytest.raises(ValueError): + try: model = vllm_runner( model_ref, model_loader_extra_config=TensorizerConfig(tensorizer_uri="test")) - del model - gc.collect() - torch.cuda.empty_cache() - - -@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") -def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path): - ## Serialize model - with vllm_runner(model_ref) as vllm_model: - model_path = tmp_path / (model_ref + ".tensors") - - vllm_model.apply_model( - partial( - serialize_vllm_model, - tensorizer_config=TensorizerConfig(tensorizer_uri=model_path))) - - model_loader_extra_config = { - "tensorizer_uri": str(model_path), - } - - ## Start OpenAI API server - openai_args = [ - "--dtype", - "float16", - "--load-format", - "tensorizer", - "--model-loader-extra-config", - json.dumps(model_loader_extra_config), - ] - - with RemoteOpenAIServer(model_ref, openai_args) as server: - print("Server ready.") - - client = server.get_client() - completion = client.completions.create(model=model_ref, - prompt="Hello, my name is", - max_tokens=5, - temperature=0.0) - - assert completion.id is not None - assert len(completion.choices) == 1 - assert len(completion.choices[0].text) >= 5 - assert completion.choices[0].finish_reason == "length" - assert completion.usage == openai.types.CompletionUsage( - completion_tokens=5, prompt_tokens=6, total_tokens=11) - - -def test_raise_value_error_on_invalid_load_format(vllm_runner): + except RuntimeError: + out, err = capfd.readouterr() + combined_output = out + err + 
assert ("ValueError: Model loader extra config " + "is not supported for load " + "format LoadFormat.AUTO") in combined_output + finally: + del model + gc.collect() + torch.cuda.empty_cache() + + +def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd): model = None - with pytest.raises(ValueError): + try: model = vllm_runner( model_ref, load_format="safetensors", model_loader_extra_config=TensorizerConfig(tensorizer_uri="test")) - del model - gc.collect() - torch.cuda.empty_cache() + except RuntimeError: + out, err = capfd.readouterr() + + combined_output = out + err + assert ("ValueError: Model loader extra config is not supported " + "for load format LoadFormat.SAFETENSORS") in combined_output + finally: + del model + gc.collect() + torch.cuda.empty_cache() @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs") -def test_tensorizer_with_tp_path_without_template(vllm_runner): - with pytest.raises(ValueError): +def test_tensorizer_with_tp_path_without_template(vllm_runner, capfd): + try: model_ref = "EleutherAI/pythia-1.4b" tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors" @@ -276,6 +202,13 @@ def test_tensorizer_with_tp_path_without_template(vllm_runner): tensor_parallel_size=2, disable_custom_all_reduce=True, ) + except RuntimeError: + out, err = capfd.readouterr() + combined_output = out + err + assert ("ValueError: For a sharded model, tensorizer_uri " + "should include a string format template like '%04d' " + "to be formatted with the rank " + "of the shard") in combined_output @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs") @@ -289,7 +222,6 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs( enforce_eager=True, ) as base_model: outputs = base_model.generate(prompts, sampling_params) - base_model.model.llm_engine.model_executor.shutdown() # load model with two shards and serialize with encryption model_path = str(tmp_path / (model_ref + "-%02d.tensors")) @@ 
-297,7 +229,7 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs( tensorizer_config = TensorizerConfig( tensorizer_uri=model_path, - encryption_keyfile=key_path, + encryption_keyfile=str(key_path), ) tensorize_vllm_model( @@ -325,21 +257,20 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs( assert outputs == deserialized_outputs -@retry_until_skip(3) +@pytest.mark.flaky(reruns=3) def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path): gc.collect() torch.cuda.empty_cache() model_ref = "facebook/opt-125m" model_path = tmp_path / (model_ref + ".tensors") config = TensorizerConfig(tensorizer_uri=str(model_path)) + args = EngineArgs(model=model_ref, device="cuda") with vllm_runner(model_ref) as vllm_model: outputs = vllm_model.generate(prompts, sampling_params) - vllm_model.apply_model( - partial(serialize_vllm_model, tensorizer_config=config)) - - assert is_vllm_tensorized(config) + tensorize_vllm_model(args, config) + assert is_vllm_tensorized(config) with vllm_runner(model_ref, load_format="tensorizer", diff --git a/tests/test_logger.py b/tests/test_logger.py index 11deae309ac8..046f70504c89 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 - +import enum import json import logging import os import sys import tempfile +from dataclasses import dataclass from json.decoder import JSONDecodeError from tempfile import NamedTemporaryFile from typing import Any @@ -16,6 +17,7 @@ from vllm.logger import (_DATE_FORMAT, _FORMAT, _configure_vllm_root_logger, enable_trace_function_call, init_logger) from vllm.logging_utils import NewLineFormatter +from vllm.logging_utils.dump_input import prepare_object_to_dump def f1(x): @@ -216,3 +218,37 @@ def test_custom_logging_config_causes_an_error_if_configure_logging_is_off(): assert other_logger.handlers != root_logger.handlers assert other_logger.level != root_logger.level assert other_logger.propagate + + +def 
test_prepare_object_to_dump(): + str_obj = 'str' + assert prepare_object_to_dump(str_obj) == "'str'" + + list_obj = [1, 2, 3] + assert prepare_object_to_dump(list_obj) == '[1, 2, 3]' + + dict_obj = {'a': 1, 'b': 'b'} + assert prepare_object_to_dump(dict_obj) in [ + "{a: 1, b: 'b'}", "{b: 'b', a: 1}" + ] + + set_obj = {1, 2, 3} + assert prepare_object_to_dump(set_obj) == '[1, 2, 3]' + + tuple_obj = ('a', 'b', 'c') + assert prepare_object_to_dump(tuple_obj) == "['a', 'b', 'c']" + + class CustomEnum(enum.Enum): + A = enum.auto() + B = enum.auto() + C = enum.auto() + + assert prepare_object_to_dump(CustomEnum.A) == repr(CustomEnum.A) + + @dataclass + class CustomClass: + a: int + b: str + + assert (prepare_object_to_dump(CustomClass( + 1, 'b')) == "CustomClass(a=1, b='b')") diff --git a/tests/test_outputs.py b/tests/test_outputs.py new file mode 100644 index 000000000000..c41bd6723ba1 --- /dev/null +++ b/tests/test_outputs.py @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.outputs import RequestOutput + + +def test_request_output_forward_compatible(): + output = RequestOutput(request_id="test_request_id", + prompt="test prompt", + prompt_token_ids=[1, 2, 3], + prompt_logprobs=None, + outputs=[], + finished=False, + example_arg_added_in_new_version="some_value") + assert output is not None diff --git a/tests/test_regression.py b/tests/test_regression.py index 8c9d4a91c73b..e092945422ed 100644 --- a/tests/test_regression.py +++ b/tests/test_regression.py @@ -60,6 +60,9 @@ def test_model_from_modelscope(monkeypatch: pytest.MonkeyPatch): # model: https://modelscope.cn/models/qwen/Qwen1.5-0.5B-Chat/summary with monkeypatch.context() as m: m.setenv("VLLM_USE_MODELSCOPE", "True") + # Don't use HF_TOKEN for ModelScope repos, otherwise it will fail + # with 400 Client Error: Bad Request. 
+ m.setenv("HF_TOKEN", "") llm = LLM(model="qwen/Qwen1.5-0.5B-Chat") prompts = [ diff --git a/tests/test_utils.py b/tests/test_utils.py index deff33e5c3ca..0b88d05efeaa 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,6 +3,7 @@ import asyncio import hashlib +import json import pickle import socket from collections.abc import AsyncIterator @@ -17,7 +18,7 @@ from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache, MemorySnapshot, PlaceholderModule, StoreBoolean, bind_kv_cache, deprecate_kwargs, get_open_port, - make_zmq_socket, memory_profiling, + make_zmq_path, make_zmq_socket, memory_profiling, merge_async_iterators, sha256, split_zmq_path, supports_kw, swap_dict_values) @@ -138,6 +139,7 @@ def parser(): parser.add_argument('--model-name') parser.add_argument('--batch-size', type=int) parser.add_argument('--enable-feature', action='store_true') + parser.add_argument('--hf-overrides', type=json.loads) return parser @@ -251,6 +253,29 @@ def test_no_model_tag(parser_with_config, cli_config_file): parser_with_config.parse_args(['serve', '--config', cli_config_file]) +def test_dict_args(parser): + args = [ + "--model-name=something.something", + "--hf-overrides.key1", + "val1", + "--hf-overrides.key2.key3", + "val2", + "--hf-overrides.key2.key4", + "val3", + "--hf-overrides.key5=val4", + ] + parsed_args = parser.parse_args(args) + assert parsed_args.model_name == "something.something" + assert parsed_args.hf_overrides == { + "key1": "val1", + "key2": { + "key3": "val2", + "key4": "val3", + }, + "key5": "val4", + } + + # yapf: enable @pytest.mark.parametrize( "callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported", @@ -714,3 +739,8 @@ def test_make_zmq_socket_ipv6(): # Clean up zsock.close() ctx.term() + + +def test_make_zmq_path(): + assert make_zmq_path("tcp", "127.0.0.1", "5555") == "tcp://127.0.0.1:5555" + assert make_zmq_path("tcp", "::1", "5555") == "tcp://[::1]:5555" diff --git a/tests/test_vllm_port.py 
b/tests/test_vllm_port.py new file mode 100644 index 000000000000..ccbb36bf4c06 --- /dev/null +++ b/tests/test_vllm_port.py @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +from unittest.mock import patch + +import pytest + +from vllm.envs import get_vllm_port + + +def test_get_vllm_port_not_set(): + """Test when VLLM_PORT is not set.""" + with patch.dict(os.environ, {}, clear=True): + assert get_vllm_port() is None + + +def test_get_vllm_port_valid(): + """Test when VLLM_PORT is set to a valid integer.""" + with patch.dict(os.environ, {"VLLM_PORT": "5678"}, clear=True): + assert get_vllm_port() == 5678 + + +def test_get_vllm_port_invalid(): + """Test when VLLM_PORT is set to a non-integer value.""" + with (patch.dict(os.environ, {"VLLM_PORT": "abc"}, clear=True), + pytest.raises(ValueError, match="must be a valid integer")): + get_vllm_port() + + +def test_get_vllm_port_uri(): + """Test when VLLM_PORT is set to a URI.""" + with (patch.dict(os.environ, {"VLLM_PORT": "tcp://localhost:5678"}, + clear=True), + pytest.raises(ValueError, match="appears to be a URI")): + get_vllm_port() diff --git a/tests/tokenization/test_mistral_tokenizer.py b/tests/tokenization/test_mistral_tokenizer.py index f1c880286951..b16d9af35be9 100644 --- a/tests/tokenization/test_mistral_tokenizer.py +++ b/tests/tokenization/test_mistral_tokenizer.py @@ -1,15 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 import pytest -from mistral_common.protocol.instruct.messages import UserMessage +from mistral_common.protocol.instruct.messages import (AssistantMessage, + ToolMessage, + UserMessage) from mistral_common.protocol.instruct.request import ChatCompletionRequest -from mistral_common.protocol.instruct.tool_calls import Function, Tool +from mistral_common.protocol.instruct.tool_calls import (Function, + FunctionCall, Tool, + ToolCall) from vllm.transformers_utils.tokenizers.mistral import ( make_mistral_chat_completion_request) -# yapf: enable @pytest.mark.parametrize( 
"openai_request,expected_mistral_request", [( @@ -78,6 +81,107 @@ ) def test_make_mistral_chat_completion_request(openai_request, expected_mistral_request): - assert (make_mistral_chat_completion_request( - openai_request["messages"], - openai_request["tools"]) == expected_mistral_request) + actual_request = make_mistral_chat_completion_request( + openai_request["messages"], openai_request["tools"]) + assert actual_request == expected_mistral_request + + +# Tool use with list content and reasoning_content +@pytest.mark.parametrize("openai_request,expected_mistral_request", [( + { + "messages": [ + { + "role": "user", + "content": "What's the weather in Paris?", + }, + { + "role": + "assistant", + "reasoning_content": + None, + "content": + None, + "tool_calls": [{ + "id": "call123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"city": "Paris"}', + }, + }], + }, + { + "role": "tool", + "content": [{ + "type": "text", + "text": "Rainy" + }], + "name": "get_weather", + "tool_call_id": "call123", + }, + ], + "tools": [{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Gets the current weather in a city.", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city name" + } + }, + "required": ["city"], + }, + }, + }], + }, + ChatCompletionRequest( + messages=[ + UserMessage(content="What's the weather in Paris?"), + AssistantMessage( + content=None, + tool_calls=[ + ToolCall( + id="call123", + function=FunctionCall( + name="get_weather", + arguments='{"city": "Paris"}', + ), + ) + ], + ), + ToolMessage( + content="Rainy", + tool_call_id="call123", + name="get_weather", + ), + ], + tools=[ + Tool( + type="function", + function=Function( + name="get_weather", + description="Gets the current weather in a city.", + parameters={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city name" + } + }, + "required": 
["city"], + }, + ), + ) + ], + ), +)]) +def test_make_mistral_chat_completion_request_list_content( + openai_request, expected_mistral_request): + actual_request = make_mistral_chat_completion_request( + openai_request["messages"], openai_request["tools"]) + assert actual_request == expected_mistral_request diff --git a/tests/tool_use/test_tool_choice_required.py b/tests/tool_use/test_tool_choice_required.py index 2ab87a0ef41f..291769848145 100644 --- a/tests/tool_use/test_tool_choice_required.py +++ b/tests/tool_use/test_tool_choice_required.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 import json -import re from copy import deepcopy from unittest.mock import MagicMock import pytest +import regex as re from pydantic import TypeAdapter from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, @@ -333,4 +333,4 @@ def test_streaming_output_valid(output, empty_params, delta_len): combined_messages += message.tool_calls[0].function.arguments combined_messages += "}]" assert json.loads(combined_messages) == output - assert json.dumps(json.loads(combined_messages)) == output_json + assert json.dumps(json.loads(combined_messages)) == output_json \ No newline at end of file diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index c14eaf71e978..efa6455c41df 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -88,7 +88,7 @@ def ensure_system_prompt(messages: list[dict[str, Any]], "meta-llama/Llama-4-Scout-17B-16E-Instruct", "arguments": [ "--enforce-eager", "--no-enable-prefix-caching", - "--tool-call-parser", "pythonic", "--chat-template", + "--tool-call-parser", "llama4_pythonic", "--chat-template", str(VLLM_PATH / "examples/tool_chat_template_llama4_pythonic.jinja"), "-tp", "4" diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index df487ec2ccaa..43a27da2dbe4 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -1,4 +1,5 @@ # 
SPDX-License-Identifier: Apache-2.0 +import importlib import pytest import torch @@ -10,8 +11,7 @@ from vllm.v1.core.kv_cache_manager import KVCacheManager # disable yapf here as it formats differently than isort such that both fail # yapf: disable -from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHashType, - FreeKVCacheBlockQueue, KVCacheBlock, +from vllm.v1.core.kv_cache_utils import (FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics, estimate_max_model_len, generate_block_hash_extra_keys, @@ -19,7 +19,8 @@ hash_request_tokens, unify_kv_cache_configs) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheTensor) + KVCacheGroupSpec, KVCacheTensor, + SlidingWindowSpec) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request @@ -54,21 +55,39 @@ def new_kv_cache_spec(block_size=16, num_kv_heads=2, head_size=64, dtype=torch.float32, - use_mla=False): + use_mla=False, + sliding_window=None): return FullAttentionSpec(block_size=block_size, num_kv_heads=num_kv_heads, head_size=head_size, dtype=dtype, - use_mla=use_mla) + use_mla=use_mla, + sliding_window=sliding_window) -def test_none_hash(): - assert NONE_HASH is not None - assert isinstance(NONE_HASH, int) - assert NONE_HASH != 0 +def test_none_hash(monkeypatch): + import vllm.v1.core.kv_cache_utils + + # case 1: PYTHONHASHSEED is not set, use random + with monkeypatch.context() as m: + m.delenv('PYTHONHASHSEED', raising=False) + reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils) + assert reloaded_kv_cache_utils.NONE_HASH is not None + assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int) + assert reloaded_kv_cache_utils.NONE_HASH != 0 + + # case 2: PYTHONHASHSEED is set, use the seed + with monkeypatch.context() as m: + m.setenv('PYTHONHASHSEED', 'python hash seed') + reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils) + assert reloaded_kv_cache_utils.NONE_HASH is not None + assert 
isinstance(reloaded_kv_cache_utils.NONE_HASH, int) + assert sha256('python hash seed') == reloaded_kv_cache_utils.NONE_HASH def test_kv_cache_block(): + import vllm.v1.core.kv_cache_utils + # Test KVCacheBlock initialization block = KVCacheBlock(block_id=0) assert block.block_id == 0 @@ -82,7 +101,8 @@ def test_kv_cache_block(): assert block.ref_cnt == 0 # Test block hash setting and resetting - block_hash = BlockHashType(hash_value=123, token_ids=(1, 2, 3)) + block_hash = vllm.v1.core.kv_cache_utils.BlockHashType(hash_value=123, + token_ids=(1, 2, 3)) block.block_hash = block_hash assert block.block_hash == block_hash @@ -256,13 +276,14 @@ def test_generate_block_hash_extra_keys_cache_salt(): @pytest.mark.parametrize("hash_fn", [sha256, hash]) def test_hash_block_tokens(hash_fn): + import vllm.v1.core.kv_cache_utils parent_block_hash = 123 curr_block_token_ids = (1, 2, 3) extra_keys = ("key1", "key2") block_hash = hash_block_tokens(hash_fn, parent_block_hash, curr_block_token_ids, extra_keys) - assert isinstance(block_hash, BlockHashType) + assert isinstance(block_hash, vllm.v1.core.kv_cache_utils.BlockHashType) assert block_hash.hash_value == hash_fn( (parent_block_hash, curr_block_token_ids, extra_keys)) assert block_hash.token_ids == curr_block_token_ids @@ -271,6 +292,7 @@ def test_hash_block_tokens(hash_fn): @pytest.mark.parametrize("hash_fn", [sha256, hash]) def test_hash_request_tokens(hash_fn): + import vllm.v1.core.kv_cache_utils request = make_request( request_id=0, prompt_token_ids=[_ for _ in range(6)], @@ -285,8 +307,10 @@ def test_hash_request_tokens(hash_fn): block_hashes = hash_request_tokens(hash_fn, block_size, request) assert len(block_hashes) == 2 - assert isinstance(block_hashes[0], BlockHashType) - assert isinstance(block_hashes[1], BlockHashType) + assert isinstance(block_hashes[0], + vllm.v1.core.kv_cache_utils.BlockHashType) + assert isinstance(block_hashes[1], + vllm.v1.core.kv_cache_utils.BlockHashType) # Check the first block assert 
block_hashes[0].token_ids == (0, 1, 2) @@ -471,6 +495,68 @@ def test_unify_kv_cache_configs(): unify_kv_cache_configs(diff_kv_cache_config) +def test_merge_kv_cache_spec(): + same_layer_specs = [ + new_kv_cache_spec(num_kv_heads=32), + new_kv_cache_spec(num_kv_heads=32), + ] + merged_layer_spec = same_layer_specs[0].merge(same_layer_specs) + assert merged_layer_spec.block_size == 16 + assert merged_layer_spec.num_kv_heads == 32 + assert merged_layer_spec.head_size == 64 + assert merged_layer_spec.dtype == torch.float32 + assert merged_layer_spec.sliding_window is None + + different_layer_specs = [ + new_kv_cache_spec(num_kv_heads=32), + new_kv_cache_spec(num_kv_heads=16), + ] + with pytest.raises(AssertionError): + different_layer_specs[0].merge(different_layer_specs) + + full_spec = new_kv_cache_spec(num_kv_heads=32) + different_type_layer_specs = [ + full_spec, + SlidingWindowSpec( + block_size=full_spec.block_size, + num_kv_heads=full_spec.num_kv_heads, + head_size=full_spec.head_size, + dtype=full_spec.dtype, + use_mla=full_spec.use_mla, + sliding_window=1, + ), + ] + with pytest.raises(AssertionError): + different_type_layer_specs[0].merge(different_type_layer_specs) + with pytest.raises(AssertionError): + different_type_layer_specs[1].merge(different_type_layer_specs) + + different_sliding_window_layer_specs = [ + new_kv_cache_spec(num_kv_heads=32), + new_kv_cache_spec(num_kv_heads=32, sliding_window=1), + new_kv_cache_spec(num_kv_heads=32, sliding_window=2), + ] + with pytest.raises(ValueError): + different_sliding_window_layer_specs[0].merge( + different_sliding_window_layer_specs) + + same_sliding_window_layer_specs = [ + new_kv_cache_spec(num_kv_heads=32, sliding_window=1), + new_kv_cache_spec(num_kv_heads=32, sliding_window=1), + ] + merged_layer_spec = same_sliding_window_layer_specs[0].merge( + same_sliding_window_layer_specs) + assert merged_layer_spec.sliding_window == 1 + + same_sliding_window_layer_spec_with_none = [ + 
new_kv_cache_spec(num_kv_heads=32, sliding_window=1), + new_kv_cache_spec(num_kv_heads=32, sliding_window=None), + ] + merged_layer_spec = same_sliding_window_layer_spec_with_none[0].merge( + same_sliding_window_layer_spec_with_none) + assert merged_layer_spec.sliding_window == 1 + + @pytest.mark.parametrize( ("model_id", "max_model_len", "want_estimated_max_len"), [ ("Qwen/Qwen1.5-7B", 16385, 16384), @@ -539,7 +625,7 @@ def test_allocate_with_lookahead(): max_model_len=100) blocks = kv_cache_manager.allocate_slots( request, - num_tokens=3, + num_new_tokens=3, num_lookahead_tokens=2, # Total required: 3+2=5 tokens ) assert len(blocks.blocks) == 2 # ceil(5/4)=2 blocks @@ -550,7 +636,7 @@ def test_allocate_with_lookahead(): # required_blocks = ceil((3 + 2) /4) = 2 blocks = kv_cache_manager.allocate_slots( request, - num_tokens=3, + num_new_tokens=3, num_lookahead_tokens=2, ) assert len(blocks.blocks) == 2 @@ -561,7 +647,7 @@ def test_allocate_with_lookahead(): max_model_len=100) blocks = kv_cache_manager.allocate_slots( request, - num_tokens=3, + num_new_tokens=3, num_lookahead_tokens=4, ) assert len(blocks.blocks) == 2 diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 01295e848ee9..3da27786b1f2 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -81,8 +81,10 @@ def test_prefill(hash_algo): assert len(manager.req_to_block_hashes[req0.request_id]) == 3 assert not computed_blocks.blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, 55, computed_blocks) - assert blocks.get_block_ids() == [1, 2, 3, 4] + blocks = manager.allocate_slots(req0, 55, + len(computed_blocks.blocks) * 16, + computed_blocks) + assert blocks.get_block_ids() == [[1, 2, 3, 4]] # Check full block metadata parent_block_hash = None @@ -105,11 +107,13 @@ def test_prefill(hash_algo): req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = 
manager.get_computed_blocks(req1) assert len(manager.req_to_block_hashes[req1.request_id]) == 3 - assert computed_blocks.get_block_ids() == [1, 2, 3] + assert computed_blocks.get_block_ids() == [[1, 2, 3]] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 - blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) - assert blocks.get_block_ids() == [5] + blocks = manager.allocate_slots(req1, num_new_tokens, + len(computed_blocks.blocks) * 16, + computed_blocks) + assert blocks.get_block_ids() == [[5]] for block in computed_blocks.blocks: assert block.ref_cnt == 2 @@ -137,11 +141,13 @@ def test_prefill(hash_algo): req2 = make_request("2", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(manager.req_to_block_hashes[req2.request_id]) == 3 - assert computed_blocks.get_block_ids() == [1, 2, 3] + assert computed_blocks.get_block_ids() == [[1, 2, 3]] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 - blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks) - assert blocks.get_block_ids() == [6] + blocks = manager.allocate_slots(req2, num_new_tokens, + len(computed_blocks.blocks) * 16, + computed_blocks) + assert blocks.get_block_ids() == [[6]] # Although we only have 6 free blocks, we have 8 blocks in # the free block queue due to lazy removal. @@ -161,9 +167,11 @@ def test_prefill(hash_algo): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks.blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req3, 16 * 10, computed_blocks) + blocks = manager.allocate_slots(req3, 16 * 10, + len(computed_blocks.blocks) * 16, + computed_blocks) # This block ID order also checks the eviction order. 
- assert blocks.get_block_ids() == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1] + assert blocks.get_block_ids() == [[7, 8, 9, 10, 4, 5, 6, 3, 2, 1]] assert manager.block_pool.free_block_queue.num_free_blocks == 0 assert manager.block_pool.free_block_queue.free_list_head is None assert manager.block_pool.free_block_queue.free_list_tail is None @@ -194,11 +202,13 @@ def test_prefill_plp(): all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids, prompt_logprobs=5) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert len(manager.req_to_block_hashes[req0.request_id]) == 3 + assert len(manager.req_to_block_hashes[req0.request_id]) == 0 assert not computed_blocks.blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, 55, computed_blocks) - assert blocks.get_block_ids() == [1, 2, 3, 4] + blocks = manager.allocate_slots(req0, 55, + len(computed_blocks.blocks) * 16, + computed_blocks) + assert blocks.get_block_ids() == [[1, 2, 3, 4]] req0_block_hashes = [b.block_hash for b in blocks.blocks] # Check full block metadata @@ -223,11 +233,13 @@ def test_prefill_plp(): req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(manager.req_to_block_hashes[req1.request_id]) == 3 - assert computed_blocks.get_block_ids() == [1, 2, 3] + assert computed_blocks.get_block_ids() == [[1, 2, 3]] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 - blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) - assert blocks.get_block_ids() == [5] + blocks = manager.allocate_slots(req1, num_new_tokens, + len(computed_blocks.blocks) * 16, + computed_blocks) + assert blocks.get_block_ids() == [[5]] for block in computed_blocks.blocks: assert block.ref_cnt == 2 @@ -256,18 +268,20 @@ def test_prefill_plp(): common_token_ids + unique_token_ids, prompt_logprobs=5) computed_blocks, num_computed_tokens = 
manager.get_computed_blocks(req2) - assert len(manager.req_to_block_hashes[req2.request_id]) == 3 + assert len(manager.req_to_block_hashes[req2.request_id]) == 0 assert not computed_blocks.blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req2, 55, computed_blocks) + blocks = manager.allocate_slots(req2, 55, + len(computed_blocks.blocks) * 16, + computed_blocks) block_ids = blocks.get_block_ids() # Duplicate cached blocks have different ids but same hashes vs request #0 assert [b.block_hash for b in blocks.blocks] == req0_block_hashes - assert block_ids != [1, 2, 3, 4] + assert block_ids != [[1, 2, 3, 4]] # Request #2 block hashes are valid since request #0 hashes are. # Check block reference counts. - for block_id in block_ids: + for block_id in block_ids[0]: assert manager.block_pool.blocks[block_id].ref_cnt == 1 manager.free(req2) @@ -290,16 +304,21 @@ def test_decode(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, 55, computed_blocks) - assert blocks.get_block_ids() == [1, 2, 3, 4] + blocks = manager.allocate_slots(req0, 55, + len(computed_blocks.blocks) * 16, + computed_blocks) + assert blocks.get_block_ids() == [[1, 2, 3, 4]] # Append slots without allocating a new block. req0.num_computed_tokens = 55 for _ in range(4): req0.append_output_token_ids(8) - new_blocks = manager.allocate_slots(req0, 4) + new_blocks = manager.allocate_slots(req0, 4, + len(computed_blocks.blocks) * 16, + computed_blocks) assert new_blocks is not None and len(new_blocks.blocks) == 0 - assert manager.req_to_blocks[req0.request_id][-1].block_hash is None + assert manager.single_type_manager.req_to_blocks[ + req0.request_id][-1].block_hash is None # Append slots with allocating a new block. req0.num_computed_tokens = 59 @@ -307,10 +326,14 @@ def test_decode(): # the preallocated block. 
for _ in range(9 + 10): req0.append_output_token_ids(7) - new_blocks = manager.allocate_slots(req0, 19) + new_blocks = manager.allocate_slots(req0, 19, + len(computed_blocks.blocks) * 16, + computed_blocks) assert new_blocks is not None and len(new_blocks.blocks) == 1 - assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None - assert manager.req_to_blocks[req0.request_id][-1].block_hash is None + assert manager.single_type_manager.req_to_blocks[ + req0.request_id][-2].block_hash is not None + assert manager.single_type_manager.req_to_blocks[ + req0.request_id][-1].block_hash is None def test_evict(): @@ -325,7 +348,9 @@ def test_evict(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks) + blocks = manager.allocate_slots(req0, 5 * 16 + 7, + len(computed_blocks.blocks) * 16, + computed_blocks) assert len(blocks.blocks) == 6 # 5 full + 1 partial # 3 blocks. @@ -334,7 +359,9 @@ def test_evict(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks) + blocks = manager.allocate_slots(req1, 3 * 16, + len(computed_blocks.blocks) * 16, + computed_blocks) assert len(blocks.blocks) == 3 # 3 full blocks last_token_id += 3 * 16 @@ -352,10 +379,12 @@ def test_evict(): # Touch the first 2 blocks. 
req2 = make_request("2", list(range(2 * 16 + 3))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert computed_blocks.get_block_ids() == [1, 2] + assert computed_blocks.get_block_ids() == [[1, 2]] assert num_computed_tokens == 2 * 16 - blocks = manager.allocate_slots(req2, 3, computed_blocks) - assert blocks.get_block_ids() == [10] + blocks = manager.allocate_slots(req2, 3, + len(computed_blocks.blocks) * 16, + computed_blocks) + assert blocks.get_block_ids() == [[10]] assert manager.block_pool.free_block_queue.num_free_blocks == 7 @@ -377,7 +406,9 @@ def test_hash_block_correct_reuse(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req, num_tokens, computed_blocks) + blocks = manager.allocate_slots(req, num_tokens, + len(computed_blocks.blocks) * 16, + computed_blocks) assert len(blocks.blocks) == 1 # Deallocate the block. @@ -389,7 +420,9 @@ def test_hash_block_correct_reuse(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks) + blocks = manager.allocate_slots(req, num_tokens - 1, + len(computed_blocks.blocks) * 16, + computed_blocks) assert len(blocks.blocks) == 1 assert manager.block_pool.blocks[ @@ -414,7 +447,9 @@ def test_computed_blocks_not_evicted(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, num_tokens, computed_blocks) + blocks = manager.allocate_slots(req0, num_tokens, + len(computed_blocks.blocks) * 16, + computed_blocks) assert len(blocks.blocks) == 1 assert blocks.blocks[0].block_id == 1 @@ -423,7 +458,9 @@ def test_computed_blocks_not_evicted(): computed_blocks, num_computed_tokens = 
manager.get_computed_blocks(req1) assert not computed_blocks.blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req1, num_tokens, computed_blocks) + blocks = manager.allocate_slots(req1, num_tokens, + len(computed_blocks.blocks) * 16, + computed_blocks) assert len(blocks.blocks) == 1 assert blocks.blocks[0].block_id == 2 @@ -440,6 +477,7 @@ def test_computed_blocks_not_evicted(): assert num_computed_tokens == block_size blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens, + len(computed_blocks.blocks) * 16, computed_blocks) assert len(blocks.blocks) == 1 assert blocks.blocks[0].block_id == 2 @@ -461,7 +499,9 @@ def test_basic_prefix_caching_disabled(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req1, 10, computed_blocks) + blocks = manager.allocate_slots(req1, 10, + len(computed_blocks.blocks) * 16, + computed_blocks) assert len(blocks.blocks) == 3 # Free the blocks. @@ -472,7 +512,9 @@ def test_basic_prefix_caching_disabled(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks.blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req2, 16, computed_blocks) + blocks = manager.allocate_slots(req2, 16, + len(computed_blocks.blocks) * 16, + computed_blocks) assert len(blocks.blocks) == 4 # New requests should not have any blocks. 
@@ -480,7 +522,9 @@ def test_basic_prefix_caching_disabled(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks.blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req3, 4, computed_blocks) + blocks = manager.allocate_slots(req3, 4, + len(computed_blocks.blocks) * 16, + computed_blocks) assert not blocks @@ -578,14 +622,18 @@ def test_mm_prefix_caching(): assert block_hashes[1].extra_keys == ("aaa", "bbb") assert block_hashes[2].extra_keys == ("bbb", ) - blocks = manager.allocate_slots(req0, 59, computed_blocks) - assert blocks.get_block_ids() == [1, 2, 3, 4] + blocks = manager.allocate_slots(req0, 59, + len(computed_blocks.blocks) * 16, + computed_blocks) + assert blocks.get_block_ids() == [[1, 2, 3, 4]] req0.num_computed_tokens = 59 # Append slots without allocating a new block. for _ in range(5): req0.append_output_token_ids(8) - new_blocks = manager.allocate_slots(req0, 5) + new_blocks = manager.allocate_slots(req0, 5, + len(computed_blocks.blocks) * 16, + computed_blocks) assert new_blocks is not None and len(new_blocks.blocks) == 0 # The just completed block should have hashes with extra keys. @@ -635,14 +683,18 @@ def test_cache_key_salting(): assert block_hashes[1].extra_keys is None assert block_hashes[2].extra_keys is None - blocks = manager.allocate_slots(req0, 59, computed_blocks) - assert blocks.get_block_ids() == [1, 2, 3, 4] + blocks = manager.allocate_slots(req0, 59, + len(computed_blocks.blocks) * 16, + computed_blocks) + assert blocks.get_block_ids() == [[1, 2, 3, 4]] req0.num_computed_tokens = 59 # Append slots without allocating a new block. for _ in range(5): req0.append_output_token_ids(8) - new_blocks = manager.allocate_slots(req0, 5) + new_blocks = manager.allocate_slots(req0, 5, + len(computed_blocks.blocks) * 16, + computed_blocks) assert new_blocks is not None and len(new_blocks.blocks) == 0 # Now one more block that should not have extra keys. 
@@ -688,16 +740,18 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks assert num_computed_tokens == 0 - manager.allocate_slots(req0, 48, computed_blocks) - block_part0 = manager.req_to_blocks[req0.request_id] + manager.allocate_slots(req0, 48, + len(computed_blocks.blocks) * 16, computed_blocks) + block_part0 = manager.single_type_manager.req_to_blocks[req0.request_id] # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... | req1 = make_request("1", common_token_ids * 2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert computed_blocks.blocks == block_part0 assert num_computed_tokens == 3 * 16 - manager.allocate_slots(req1, 48, computed_blocks) - block_part1 = manager.req_to_blocks[req1.request_id] + manager.allocate_slots(req1, 48, + len(computed_blocks.blocks) * 16, computed_blocks) + block_part1 = manager.single_type_manager.req_to_blocks[req1.request_id] # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Req1-5(F)| ... | manager.free(req1) @@ -710,7 +764,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks.blocks assert num_computed_tokens == 0 - manager.allocate_slots(req2, block_size * 2, computed_blocks) + manager.allocate_slots(req2, block_size * 2, + len(computed_blocks.blocks) * 16, computed_blocks) # Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed, # but it cannot be allocated due to insufficient free blocks (2). @@ -721,7 +776,9 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): assert computed_blocks.blocks == block_part1 assert num_computed_tokens == 6 * 16 # Req3 cannot be allocated. 
- assert manager.allocate_slots(req3, 48, computed_blocks) is None + assert manager.allocate_slots(req3, 48, + len(computed_blocks.blocks) * 16, + computed_blocks) is None # Block 0-2 are used by Req 1. assert {block.ref_cnt for block in block_part1[:3]} == {1} # Block 3-5 are free. @@ -740,7 +797,7 @@ def test_reset_prefix_cache(): all_token_ids = full_block_token_ids + unique_token_ids req0 = make_request("0", all_token_ids) blocks = manager.allocate_slots(req0, 55) - assert blocks.get_block_ids() == [1, 2, 3, 4] + assert blocks.get_block_ids() == [[1, 2, 3, 4]] unique_token_ids = [4] * 7 all_token_ids = full_block_token_ids + unique_token_ids @@ -748,8 +805,10 @@ def test_reset_prefix_cache(): computed_blocks, _ = manager.get_computed_blocks(req1) assert len(manager.req_to_block_hashes[req1.request_id]) == 3 assert len(computed_blocks.blocks) == 3 - blocks = manager.allocate_slots(req1, 7, computed_blocks) - assert blocks.get_block_ids() == [5] + blocks = manager.allocate_slots(req1, 7, + len(computed_blocks.blocks) * 16, + computed_blocks) + assert blocks.get_block_ids() == [[5]] # Failed to reset prefix cache because some blocks are not freed yet. 
assert not manager.reset_prefix_cache() @@ -779,7 +838,8 @@ def test_prefix_cache_stats_disabled(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks assert num_computed_tokens == 0 - manager.allocate_slots(req, 16, computed_blocks) + manager.allocate_slots(req, 16, + len(computed_blocks.blocks) * 16, computed_blocks) manager.reset_prefix_cache() # Ensure prefix_cache_stats remains None @@ -857,7 +917,8 @@ def test_eagle_enabled_removes_last_block(): # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) - manager.allocate_slots(req, len(token_ids), computed_blocks) + manager.allocate_slots(req, len(token_ids), + len(computed_blocks.blocks) * 16, computed_blocks) manager.free(req) # New request with same tokens + Eagle enabled @@ -886,7 +947,8 @@ def test_eagle_with_partial_blocks(): # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) - manager.allocate_slots(req, len(token_ids), computed_blocks) + manager.allocate_slots(req, len(token_ids), + len(computed_blocks.blocks) * 16, computed_blocks) manager.free(req) # New request with Eagle enabled @@ -925,7 +987,8 @@ def test_eagle_with_sliding_window(): # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) - manager.allocate_slots(req, len(token_ids), computed_blocks) + manager.allocate_slots(req, len(token_ids), + len(computed_blocks.blocks) * 16, computed_blocks) # record the block hash of the first block in the request for later use block_hash_first_block = manager.req_to_block_hashes[req.request_id][0] assert block_hash_first_block is not None diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index bfe9df10d4d1..f40d477a0036 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -812,10 +812,11 @@ def _assert_right_kv_cache_manager( # Make sure the request stats are right. 
EXPECTED_TOTAL_BLOCKS = num_tokens // block_size for req_id in req_ids: - blocks = scheduler.kv_cache_manager.req_to_blocks[req_id] + blocks = (scheduler.kv_cache_manager.single_type_manager. + req_to_blocks[req_id]) hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id] - assert (scheduler.kv_cache_manager.num_cached_block[req_id] == - EXPECTED_TOTAL_BLOCKS) + assert (scheduler.kv_cache_manager.single_type_manager. + num_cached_block[req_id] == EXPECTED_TOTAL_BLOCKS) assert len(blocks) == EXPECTED_TOTAL_BLOCKS assert len(hashes) == EXPECTED_TOTAL_BLOCKS @@ -869,7 +870,7 @@ def test_kv_connector_basic(): NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 scheduler.connector.get_num_new_matched_tokens = Mock(name="method") scheduler.connector.get_num_new_matched_tokens.return_value = ( - NUM_MATCHED_NEW_TOKENS) + NUM_MATCHED_NEW_TOKENS, False) ###################################################### # FIRST SET OF REQUESTS - External Hit Only @@ -980,7 +981,7 @@ def test_kv_connector_unable_to_allocate(): NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 scheduler.connector.get_num_new_matched_tokens = Mock(name="method") scheduler.connector.get_num_new_matched_tokens.return_value = ( - NUM_MATCHED_NEW_TOKENS) + NUM_MATCHED_NEW_TOKENS, False) # Create two requests. The second request will not be able to # allocate slots because it will not have enough blocks. @@ -1059,7 +1060,7 @@ def test_kv_connector_handles_preemption(): NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE scheduler.connector.get_num_new_matched_tokens = Mock(name="method") scheduler.connector.get_num_new_matched_tokens.return_value = ( - NUM_MATCHED_NEW_TOKENS) + NUM_MATCHED_NEW_TOKENS, False) # Create two requests. # Both can be scheduled at first, but the second request @@ -1195,9 +1196,11 @@ def assert_scheduler_empty(scheduler: Scheduler): assert len(scheduler.encoder_cache_manager.cached) == 0 # KVCache Manager. 
- assert len(scheduler.kv_cache_manager.req_to_blocks) == 0 + assert len( + scheduler.kv_cache_manager.single_type_manager.req_to_blocks) == 0 assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0 - assert len(scheduler.kv_cache_manager.num_cached_block) == 0 + assert len( + scheduler.kv_cache_manager.single_type_manager.num_cached_block) == 0 num_free_blocks = ( scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) assert num_free_blocks == ( diff --git a/tests/v1/core/test_scheduler_e2e.py b/tests/v1/core/test_scheduler_e2e.py index 0a79424a30b7..511d57d405ba 100644 --- a/tests/v1/core/test_scheduler_e2e.py +++ b/tests/v1/core/test_scheduler_e2e.py @@ -19,7 +19,8 @@ def model() -> LLM: enable_prefix_caching=True, long_prefill_token_threshold=2, max_num_batched_tokens=6, - max_num_seqs=3) + max_num_seqs=3, + block_size=16) def test_concurrent_partial_prefill(model): @@ -27,3 +28,11 @@ def test_concurrent_partial_prefill(model): assert len(outputs) == 3 for output in outputs: assert len(output.outputs) == 1 + + +def test_prefix_cache_stats_is_recorded(model): + # 17 tokens will make sure first 16 tokens are cached in a block + input_tokens = {"prompt_token_ids": [101] * 17} + _ = model.generate([input_tokens]) + outputs = model.generate([input_tokens]) + assert outputs[0].num_cached_tokens == 16 diff --git a/tests/v1/core/test_specialized_manager.py b/tests/v1/core/test_specialized_manager.py index 595c8608fc64..101a2379be37 100644 --- a/tests/v1/core/test_specialized_manager.py +++ b/tests/v1/core/test_specialized_manager.py @@ -4,13 +4,22 @@ from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock -from vllm.v1.core.specialized_manager import SlidingWindowManager +from vllm.v1.core.single_type_kv_cache_manager import SlidingWindowManager from vllm.v1.kv_cache_interface import SlidingWindowSpec +def get_sliding_window_manager(sliding_window_spec, block_pool): + return 
SlidingWindowManager(sliding_window_spec, + block_pool, + use_eagle=False, + num_kv_cache_groups=1, + caching_hash_fn=lambda x: x) + + def test_sliding_window_possible_cached_prefix(): + block_size = 2 sliding_window_spec = SlidingWindowSpec( - block_size=2, + block_size=block_size, num_kv_heads=1, head_size=1, dtype=torch.float32, @@ -19,9 +28,7 @@ def test_sliding_window_possible_cached_prefix(): ) block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) - manager = SlidingWindowManager(sliding_window_spec, - block_pool, - use_eagle=False) + manager = get_sliding_window_manager(sliding_window_spec, block_pool) def run_one_case(block_is_cached, expect_length): block_hash_list = [ @@ -38,7 +45,9 @@ def run_one_case(block_is_cached, expect_length): i: block_pool.blocks[i + 10] } - computed_blocks = manager.find_longest_cache_hit(block_hash_list) + computed_blocks = manager.find_longest_cache_hit( + block_hash_list, + len(block_hash_list) * block_size) assert len(computed_blocks) == expect_length assert all(block == block_pool.null_block @@ -81,9 +90,7 @@ def test_sliding_window_remove_skipped_blocks(): block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True) - manager = SlidingWindowManager(sliding_window_spec, - block_pool, - use_eagle=False) + manager = get_sliding_window_manager(sliding_window_spec, block_pool) null_block_id = block_pool.null_block.block_id @@ -104,39 +111,35 @@ def assert_block_id(block_table, ids): 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010 ] block_table = id_to_block_table(original_block_ids) - removed = manager.remove_skipped_blocks(block_table, 0) - assert_block_id(removed, []) + manager.req_to_blocks["test"] = block_table + + manager.remove_skipped_blocks("test", 0) assert_block_id(block_table, original_block_ids) # 4 tokens are computed. Only token 0 is out of the sliding window. As # block 1000 also contains token 1 that is in the sliding window, block 1000 # cannot be removed. 
- removed = manager.remove_skipped_blocks(block_table, 4) - assert_block_id(removed, []) + manager.remove_skipped_blocks("test", 4) assert_block_id(block_table, original_block_ids) # 5 tokens are computed. Token 0 & 1 are out of the sliding window. # Block 1000 can be removed. - removed = manager.remove_skipped_blocks(block_table, 5) - assert_block_id(removed, [original_block_ids[0]]) + manager.remove_skipped_blocks("test", 5) assert_block_id(block_table, [null_block_id] + original_block_ids[1:]) # 6 tokens are computed. Token 0-2 are out of the sliding window. # Cannot remove new block as the block 1001 is still used by token 3. - removed = manager.remove_skipped_blocks(block_table, 6) - assert_block_id(removed, []) + manager.remove_skipped_blocks("test", 6) assert_block_id(block_table, [null_block_id] + original_block_ids[1:]) # 7 tokens are computed. Token 0-3 are out of the sliding window. # Block 1001 can be removed and block 1000 is already removed. - removed = manager.remove_skipped_blocks(block_table, 7) - assert_block_id(removed, [original_block_ids[1]]) + manager.remove_skipped_blocks("test", 7) assert_block_id(block_table, [null_block_id] * 2 + original_block_ids[2:]) # 11 tokens are computed. Token 0-7 are out of the sliding window. # Block 1002 & 1003 can be removed now. Block 1003 represents a longer # sequence, and is expected to be evicted earlier than 1002, so the order # of removed blocks should be [1003, 1002]. 
- removed = manager.remove_skipped_blocks(block_table, 11) - assert_block_id(removed, [original_block_ids[3], original_block_ids[2]]) + manager.remove_skipped_blocks("test", 11) assert_block_id(block_table, [null_block_id] * 4 + original_block_ids[4:]) diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index fd8d1fd7ff48..8bea032f656f 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -18,9 +18,10 @@ from vllm.usage.usage_lib import UsageContext from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core import EngineCore -from vllm.v1.engine.core_client import (AsyncMPClient, CoreEngine, - EngineCoreClient, SyncMPClient) +from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient, + SyncMPClient) from vllm.v1.executor.abstract import Executor +from vllm.v1.utils import CoreEngineProcManager from ...distributed.conftest import MockSubscriber from ...utils import create_new_process_for_each_test @@ -289,7 +290,6 @@ def test_kv_cache_events( log_stats=False, ) endpoint = publisher_config.endpoint.replace("*", "127.0.0.1") - time.sleep(0.1) subscriber = MockSubscriber(endpoint, topic=publisher_config.topic, decode_type=KVEventBatch) @@ -348,13 +348,13 @@ def test_startup_failure(monkeypatch: pytest.MonkeyPatch): # Monkey-patch to extract core process pid while it's starting. 
core_proc_pid = [None] - ce_ctor = CoreEngine.__init__ + cepm_ctor = CoreEngineProcManager.__init__ - def patched_ce_ctor(self, *args, **kwargs): - ce_ctor(self, *args, **kwargs) - core_proc_pid[0] = self.proc_handle.proc.pid + def patched_cepm_ctor(self: CoreEngineProcManager, *args, **kwargs): + cepm_ctor(self, *args, **kwargs) + core_proc_pid[0] = self.processes[0].pid - m.setattr(CoreEngine, "__init__", patched_ce_ctor) + m.setattr(CoreEngineProcManager, "__init__", patched_cepm_ctor) t = time.time() engine_args = EngineArgs(model=MODEL_NAME) diff --git a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py index cefb89eb652b..e77916f95823 100644 --- a/tests/v1/engine/test_llm_engine.py +++ b/tests/v1/engine/test_llm_engine.py @@ -6,6 +6,7 @@ import pytest from vllm import LLM, SamplingParams +from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Metric, Vector MODEL = "facebook/opt-125m" DTYPE = "half" @@ -97,3 +98,67 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None: raise AssertionError( f"{len(completion_counts)} unique completions; expected" f" {n}. 
Repeats: {repeats}") + + +def test_engine_metrics(vllm_runner, monkeypatch, example_prompts): + max_tokens = 100 + # Use spec decoding to test num_accepted_tokens_per_pos + speculative_config = { + "method": "ngram", + "prompt_lookup_max": 5, + "prompt_lookup_min": 3, + "num_speculative_tokens": 5, + } + monkeypatch.setenv("VLLM_USE_V1", "1") + with vllm_runner( + MODEL, + speculative_config=speculative_config, + disable_log_stats=False, + ) as vllm_model: + model: LLM = vllm_model.model + sampling_params = SamplingParams(temperature=0.0, + max_tokens=max_tokens) + outputs = model.generate(example_prompts, sampling_params) + + n_prompts = len(example_prompts) + assert len(outputs) == n_prompts + + total_tokens = 0 + for out in outputs: + assert len(out.outputs) == 1 + total_tokens += len(out.outputs[0].token_ids) + assert total_tokens == max_tokens * n_prompts + + metrics = model.get_metrics() + + def find_metric(name) -> list[Metric]: + found = [] + for metric in metrics: + if metric.name == name: + found.append(metric) + return found + + num_requests_running = find_metric("vllm:num_requests_running") + assert len(num_requests_running) == 1 + assert isinstance(num_requests_running[0], Gauge) + assert num_requests_running[0].value == .0 + + generation_tokens = find_metric("vllm:generation_tokens") + assert len(generation_tokens) == 1 + assert isinstance(generation_tokens[0], Counter) + assert generation_tokens[0].value == total_tokens + + request_generation_tokens = find_metric( + "vllm:request_generation_tokens") + assert len(request_generation_tokens) == 1 + assert isinstance(request_generation_tokens[0], Histogram) + assert "+Inf" in request_generation_tokens[0].buckets + assert request_generation_tokens[0].buckets["+Inf"] == n_prompts + assert request_generation_tokens[0].count == n_prompts + assert request_generation_tokens[0].sum == total_tokens + + num_accepted_tokens_per_pos = find_metric( + "vllm:spec_decode_num_accepted_tokens_per_pos") + assert 
len(num_accepted_tokens_per_pos) == 1 + assert isinstance(num_accepted_tokens_per_pos[0], Vector) + assert len(num_accepted_tokens_per_pos[0].values) == 5 diff --git a/tests/v1/entrypoints/conftest.py b/tests/v1/entrypoints/conftest.py index d84b2b22db12..8c03f04330dd 100644 --- a/tests/v1/entrypoints/conftest.py +++ b/tests/v1/entrypoints/conftest.py @@ -72,12 +72,16 @@ def sample_json_schema(): "type": "string" } }, - "required": ["company", "duration", "position"] - } + "required": ["company", "duration", "position"], + "additionalProperties": False + }, + "minItems": 0, + "maxItems": 3 } }, "required": - ["name", "age", "skills", "grade", "email", "work_history"] + ["name", "age", "skills", "grade", "email", "work_history"], + "additionalProperties": False } @@ -100,7 +104,8 @@ def unsupported_json_schema(): } } }, - "required": ["score", "tags"] + "required": ["score", "tags"], + "additionalProperties": False } @@ -139,7 +144,8 @@ def sample_definition_json_schema(): }, 'required': ['steps', 'final_answer'], 'title': 'MathReasoning', - 'type': 'object' + 'type': 'object', + "additionalProperties": False } diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 81601c87ad8b..5f1fff200de3 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -1,21 +1,27 @@ +# ruff: noqa: E501 # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations import json -import re from enum import Enum -from typing import Any +from typing import TYPE_CHECKING, Any import jsonschema import pytest +import regex as re from pydantic import BaseModel +from tests.reasoning.utils import run_reasoning_extraction from vllm.entrypoints.llm import LLM from vllm.outputs import RequestOutput from vllm.platforms import current_platform +from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager from vllm.sampling_params 
import GuidedDecodingParams, SamplingParams +if TYPE_CHECKING: + from vllm.config import TokenizerMode + NGRAM_SPEC_CONFIG = { "model": "[ngram]", "num_speculative_tokens": 5, @@ -62,6 +68,16 @@ class CarDescription(BaseModel): car_type: CarType +def _load_json(s: str, backend: str) -> str: + if backend != "xgrammar": + return json.loads(s) + + # xgrammar specific workarounds + # https://github.com/mlc-ai/xgrammar/issues/286 + s = re.sub(r'[\x00-\x1F\x7F-\xFF]', '', s) + return json.loads(s) + + @pytest.mark.skip_global_cleanup @pytest.mark.parametrize( "model_name, guided_decoding_backend, tokenizer_mode, speculative_config", @@ -102,7 +118,7 @@ def test_structured_output( # sampling_params = SamplingParams( temperature=1.0, - max_tokens=1000, + max_tokens=4096, guided_decoding=GuidedDecodingParams(json=sample_json_schema)) outputs = llm.generate(prompts=[ (f"Give an example JSON for an employee profile that fits this " @@ -131,7 +147,7 @@ def test_structured_output( # sampling_params = SamplingParams( temperature=1.0, - max_tokens=100, + max_tokens=4096, n=2, guided_decoding=GuidedDecodingParams(json_object=True)) @@ -161,7 +177,7 @@ def test_structured_output( # sampling_params = SamplingParams( temperature=1.0, - max_tokens=1000, + max_tokens=4096, guided_decoding=GuidedDecodingParams(json=unsupported_json_schema)) if guided_decoding_backend.startswith("xgrammar"): with pytest.raises(ValueError, @@ -376,12 +392,13 @@ def test_structured_output( "minLength": min_length } }, - "required": ["description"] + "required": ["description"], + "additionalProperties": False } sampling_params = SamplingParams( temperature=1.0, - max_tokens=1000, + max_tokens=4096, guided_decoding=GuidedDecodingParams(json=json_schema)) outputs = llm.generate( @@ -417,7 +434,8 @@ def test_structured_output( "city": { "type": "string" } - } + }, + "additionalProperties": False }, "end": "</function>" }], @@ -426,13 +444,13 @@ def test_structured_output( sampling_params = SamplingParams( 
temperature=0.0, - max_tokens=100, + max_tokens=4096, guided_decoding=GuidedDecodingParams( structural_tag=json.dumps(structural_tag_config))) prompt = """ You have access to the following function to retrieve the weather in a city: - + { "name": "get_weather", "parameters": { @@ -443,7 +461,7 @@ def test_structured_output( } } } - + If a you choose to call a function ONLY reply in the following format: <{start_tag}={function_name}>{parameters}{end_tag} where @@ -464,7 +482,7 @@ def test_structured_output( - Always add your sources when using search results to answer the user query You are a helpful assistant. - + Given the previous instructions, what is the weather in New York City? \ Make the response as short as possible. """ @@ -502,6 +520,88 @@ def test_structured_output( f"{generated_text!r}\nError: {str(e)}") +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize( + "model_name, guided_decoding_backend, tokenizer_mode, reasoning_parser, speculative_config", # noqa: E501 + [ + ("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "xgrammar", "auto", + "deepseek_r1", NGRAM_SPEC_CONFIG), + ("Qwen/Qwen3-1.7B", "xgrammar", "auto", "deepseek_r1", None), + ], +) +def test_structured_output_with_reasoning_matrices( + monkeypatch: pytest.MonkeyPatch, + guided_decoding_backend: str, + tokenizer_mode: TokenizerMode, + reasoning_parser: str, + model_name: str, + speculative_config: dict[str, Any] | None, +): + monkeypatch.setenv("VLLM_USE_V1", "1") + + if current_platform.is_tpu() and speculative_config: + pytest.skip("TPU does not support speculative decoding") + + # Use a single LLM instance for several scenarios to + # speed up the test suite. 
+ llm = LLM( + model=model_name, + # Don't use eager execution on TPUs because we want to test for no + # recompilation at runtime + enforce_eager=bool(not current_platform.is_tpu()), + max_model_len=1024, + max_num_seqs=16, + guided_decoding_backend=guided_decoding_backend, + guided_decoding_disable_any_whitespace=True, + tokenizer_mode=tokenizer_mode, + reasoning_parser=reasoning_parser, + speculative_config=speculative_config, + ) + tokenizer = llm.get_tokenizer(None) + reasoner = ReasoningParserManager.get_reasoning_parser(reasoning_parser)( + tokenizer=tokenizer) + + reasoning_prompt = "Solve the following math problem step-by-step, then provide the final answer as JSON object with a single key 'result'. Make sure to correct your reasoning if there are any issue should it arise.\nProblem: What is 5 * 8 + 2?" # noqa: E501 + reasoning_schema = { + "type": "object", + "properties": { + "result": { + "type": "integer" + } + }, + "required": ["result"], + "additionalProperties": False + } + if "Qwen3" in model_name: + reasoning_prompt += "<think>\n" + + sampling_params = SamplingParams( + temperature=0.1, + max_tokens=8192, + guided_decoding=GuidedDecodingParams(json=reasoning_schema), + ) + outputs = llm.generate( + [reasoning_prompt], + sampling_params=sampling_params, + use_tqdm=True, + ) + + assert outputs is not None + output = outputs[0] + assert output is not None and isinstance(output, RequestOutput) + prompt = output.prompt + generated_text = output.outputs[0].text + reasoning_content, content = run_reasoning_extraction( + reasoner, [generated_text]) + print( + f"Prompt: {prompt!r}\nReasoning: {reasoning_content!r}\nContent: {content!r}" + ) + + assert content is not None and reasoning_content is not None + output_json = json.loads(content) + jsonschema.validate(instance=output_json, schema=reasoning_schema) + + @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("model_name, tokenizer_mode", PARAMS_MODELS_TOKENIZER_MODE) diff --git 
a/tests/v1/entrypoints/openai/test_chat_completion.py b/tests/v1/entrypoints/openai/test_chat_completion.py new file mode 100644 index 000000000000..c650ccd0ccd7 --- /dev/null +++ b/tests/v1/entrypoints/openai/test_chat_completion.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 + +import openai # use the official client for correctness check +import pytest +import pytest_asyncio + +from tests.utils import RemoteOpenAIServer + +# any model with a chat template defined in tokenizer_config should work here +MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" + + +@pytest.fixture(scope="module") +def default_server_args(): + return [ + # use half precision for speed and memory savings in CI environment + "--max-model-len", + "2048", + "--max-num-seqs", + "128", + "--enforce-eager", + ] + + +@pytest.fixture(scope="module") +def server(default_server_args): + with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_invalid_json_schema(client: openai.AsyncOpenAI, + model_name: str) -> None: + invalid_json_schema = { + "$defs": { + "CarType": { + "enum": ["sedan", "SUV", "Truck", "Coupe"], + "title": "CarType", + "type": "string", + } + }, + "properties": { + "brand": { + "title": "Brand", + "type": "string" + }, + "model": { + "title": "Model", + "type": "string" + }, + "car_type": { + "$ref": "#/$defs/CarType" + }, + "foo": "bar", + }, + "required": ["brand", "model", "car_type"], + "title": "CarDescription", + "type": "object", + } + prompt = ("Generate a JSON with the brand, model and car_type of" + "the most iconic car from the 90's") + with pytest.raises((openai.BadRequestError, openai.APIError)): + await client.chat.completions.create( + model=model_name, + messages=[{ + "role": "user", + 
"content": prompt, + }], + extra_body={"guided_json": invalid_json_schema}, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_invalid_regex(client: openai.AsyncOpenAI, model_name: str): + prompt = ("Generate an email address for Alan Turing, who works in Enigma." + "End in .com and new line. Example result:" + "alan.turing@enigma.com\n") + + with pytest.raises((openai.BadRequestError, openai.APIError)): + await client.chat.completions.create( + model=model_name, + messages=[{ + "role": "user", + "content": prompt, + }], + extra_body={ + "guided_regex": r"[.*", + "stop": ["\n"] + }, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str): + invalid_simplified_sql_grammar = """ + root ::= select_statementinvalidsyntax + + select_statement ::= "SELECT " column " from " table " where " condition + + column ::= "col_1 " | "col_2 " + + table ::= "table_1 " | "table_2 " + + condition ::= column "= " number + + number ::= "1 " | "2 " + """ + + prompt = ("Generate an SQL query to show the 'username' and 'email'" + "from the 'users' table.") + with pytest.raises((openai.BadRequestError, openai.APIError)): + await client.chat.completions.create( + model=model_name, + messages=[{ + "role": "user", + "content": prompt, + }], + extra_body={"guided_grammar": invalid_simplified_sql_grammar}, + ) diff --git a/tests/v1/entrypoints/openai/test_completion.py b/tests/v1/entrypoints/openai/test_completion.py index 57ca99e1f68c..333ad23795f3 100644 --- a/tests/v1/entrypoints/openai/test_completion.py +++ b/tests/v1/entrypoints/openai/test_completion.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 -import re from typing import Optional import openai # use the official client for correctness check import pytest import pytest_asyncio +import regex as re from openai import BadRequestError from tests.utils import 
RemoteOpenAIServer @@ -584,3 +584,97 @@ async def test_echo_logprob_completion(client: openai.AsyncOpenAI, assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1 assert len(logprobs.tokens) > 5 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_invalid_json_schema(client: openai.AsyncOpenAI, + model_name: str) -> None: + invalid_json_schema = { + "$defs": { + "CarType": { + "enum": ["sedan", "SUV", "Truck", "Coupe"], + "title": "CarType", + "type": "string", + } + }, + "properties": { + "brand": { + "title": "Brand", + "type": "string" + }, + "model": { + "title": "Model", + "type": "string" + }, + "car_type": { + "$ref": "#/$defs/CarType" + }, + "foo": "bar", + }, + "required": ["brand", "model", "car_type"], + "title": "CarDescription", + "type": "object", + } + prompt = ("Generate a JSON with the brand, model and car_type of" + "the most iconic car from the 90's") + with pytest.raises((openai.BadRequestError, openai.APIError)): + await client.completions.create( + model=model_name, + prompt=prompt, + extra_body={"guided_json": invalid_json_schema}, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_invalid_regex(client: openai.AsyncOpenAI, model_name: str): + prompt = ("Generate an email address for Alan Turing, who works in Enigma." + "End in .com and new line. 
Example result:" + "alan.turing@enigma.com\n") + + with pytest.raises((openai.BadRequestError, openai.APIError)): + await client.completions.create( + model=model_name, + prompt=prompt, + extra_body={ + "guided_regex": r"[.*", + "stop": ["\n"] + }, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str): + invalid_simplified_sql_grammar = """ + root ::= select_statementinvalidsyntax + + select_statement ::= "SELECT " column " from " table " where " condition + + column ::= "col_1 " | "col_2 " + + table ::= "table_1 " | "table_2 " + + condition ::= column "= " number + + number ::= "1 " | "2 " + """ + + prompt = ("Generate an SQL query to show the 'username' and 'email'" + "from the 'users' table.") + with pytest.raises((openai.BadRequestError, openai.APIError)): + await client.completions.create( + model=model_name, + prompt=prompt, + extra_body={"guided_grammar": invalid_simplified_sql_grammar}, + ) diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh new file mode 100755 index 000000000000..c17784e0a263 --- /dev/null +++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh @@ -0,0 +1,180 @@ +#!/bin/bash +set -xe + +# Models to run +MODELS=( + "Qwen/Qwen3-0.6B" +) + +# Number of prefill and decode instances to create +NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1} # Default to 1 +NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-2} # Default to 2 + +# Find the git repository root directory +GIT_ROOT=$(git rev-parse --show-toplevel) + +SMI_BIN=$(which nvidia-smi || which rocm-smi) + +# Trap the SIGINT signal (triggered by Ctrl+C) +trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT + +# Waits for vLLM to start. 
+wait_for_server() { + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + +# Function to clean up previous instances +cleanup_instances() { + echo "Cleaning up any running vLLM instances..." + pkill -f "vllm serve" || true + sleep 2 +} + +# Handle to get model-specific arguments for deepseek +get_model_args() { + local model_name=$1 + local extra_args="" + + if [[ "$model_name" == "deepseek-ai/deepseek-vl2-tiny" ]]; then + extra_args="--hf_overrides '{\"architectures\": [\"DeepseekVLV2ForCausalLM\"]}' --trust-remote-code" + fi + + echo "$extra_args" +} + +get_num_gpus() { + if [[ "$SMI_BIN" == *"nvidia"* ]]; then + echo "$($SMI_BIN --query-gpu=name --format=csv,noheader | wc -l)" + else + echo "$($SMI_BIN -l | grep GPU | wc -l)" + fi +} + +# Function to run tests for a specific model +run_tests_for_model() { + local model_name=$1 + echo "================================" + echo "Testing model: $model_name" + echo "================================" + + # Get model-specific arguments + local model_args=$(get_model_args "$model_name") + + # Arrays to store all hosts and ports + PREFILL_HOSTS=() + PREFILL_PORTS=() + DECODE_HOSTS=() + DECODE_PORTS=() + + # Start prefill instances + for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do + # Calculate GPU ID - we'll distribute across available GPUs + GPU_ID=$((i % $(get_num_gpus))) + # Calculate port number (base port + instance number) + PORT=$((8100 + i)) + # Calculate side channel port + SIDE_CHANNEL_PORT=$((5559 + i)) + + echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT" + + # Build the command with or without model-specific args + BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \ + --port $PORT \ + --enforce-eager \ + --disable-log-requests \ + --gpu-memory-utilization 0.2 \ + --kv-transfer-config 
'{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'" + + if [ -n "$model_args" ]; then + FULL_CMD="$BASE_CMD $model_args" + else + FULL_CMD="$BASE_CMD" + fi + + eval "$FULL_CMD &" + + # Store host and port for proxy configuration + PREFILL_HOSTS+=("localhost") + PREFILL_PORTS+=($PORT) + done + + # Start decode instances + for i in $(seq 0 $((NUM_DECODE_INSTANCES-1))); do + # Calculate GPU ID - we'll distribute across available GPUs, starting from after prefill GPUs + GPU_ID=$(((i + NUM_PREFILL_INSTANCES) % $(get_num_gpus))) + # Calculate port number (base port + instance number) + PORT=$((8200 + i)) + # Calculate side channel port + SIDE_CHANNEL_PORT=$((5659 + i)) + + echo "Starting decode instance $i on GPU $GPU_ID, port $PORT" + + # Build the command with or without model-specific args + BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \ + --port $PORT \ + --enforce-eager \ + --disable-log-requests \ + --gpu-memory-utilization 0.2 \ + --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'" + + if [ -n "$model_args" ]; then + FULL_CMD="$BASE_CMD $model_args" + else + FULL_CMD="$BASE_CMD" + fi + + eval "$FULL_CMD &" + + # Store host and port for proxy configuration + DECODE_HOSTS+=("localhost") + DECODE_PORTS+=($PORT) + done + + # Wait for all instances to start + for PORT in "${PREFILL_PORTS[@]}"; do + echo "Waiting for prefill instance on port $PORT to start..." + wait_for_server $PORT + done + + for PORT in "${DECODE_PORTS[@]}"; do + echo "Waiting for decode instance on port $PORT to start..." 
+ wait_for_server $PORT + done + + # Build the command for the proxy server with all the hosts and ports + PROXY_CMD="python ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port 8192" + + # Add all prefill hosts and ports + PROXY_CMD+=" --prefiller-hosts ${PREFILL_HOSTS[@]}" + PROXY_CMD+=" --prefiller-ports ${PREFILL_PORTS[@]}" + + # Add all decode hosts and ports + PROXY_CMD+=" --decoder-hosts ${DECODE_HOSTS[@]}" + PROXY_CMD+=" --decoder-ports ${DECODE_PORTS[@]}" + + # Start the proxy server + echo "Starting proxy server with command: $PROXY_CMD" + $PROXY_CMD & + + # Wait for the proxy to start + sleep 5 + + # Run lm eval for this model + echo "Running tests for $model_name" + TEST_MODEL=$model_name python -m pytest -s -x ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_accuracy.py + + # Clean up before running next model + cleanup_instances + sleep 3 +} + +# Run tests for each model +for model in "${MODELS[@]}"; do + run_tests_for_model "$model" +done + +echo "All tests completed!" diff --git a/tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh b/tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh new file mode 100644 index 000000000000..98903a176e28 --- /dev/null +++ b/tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh @@ -0,0 +1,123 @@ +#!/bin/bash +set -xe + +# Models to run +MODELS=( + "Qwen/Qwen3-0.6B" +) + +# Find the git repository root directory +GIT_ROOT=$(git rev-parse --show-toplevel) + +# Trap the SIGINT signal (triggered by Ctrl+C) +trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT + +# Waits for vLLM to start. +wait_for_server() { + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + +# Function to clean up previous instances +cleanup_instances() { + echo "Cleaning up any running vLLM instances..." 
+ pkill -f "vllm serve" || true + sleep 2 +} + +# Handle to get model-specific arguments for deepseek +get_model_args() { + local model_name=$1 + local extra_args="" + + if [[ "$model_name" == "deepseek-ai/deepseek-vl2-tiny" ]]; then + extra_args="--hf_overrides '{\"architectures\": [\"DeepseekVLV2ForCausalLM\"]}' --trust-remote-code" + fi + + echo "$extra_args" +} + + +# Function to run tests for a specific model +run_tests_for_model() { + local model_name=$1 + echo "================================" + echo "Testing model: $model_name" + echo "================================" + + # Get model-specific arguments + local model_args=$(get_model_args "$model_name") + + # Start prefill instance + PREFILL_PORT=8001 + + BASE_CMD="CUDA_VISIBLE_DEVICES=0 VLLM_NIXL_SIDE_CHANNEL_PORT=5559 vllm serve $model_name \ + --port $PREFILL_PORT \ + --enforce-eager \ + --disable-log-requests \ + --gpu-memory-utilization 0.2 \ + --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'" + + if [ -n "$model_args" ]; then + FULL_CMD="$BASE_CMD $model_args" + else + FULL_CMD="$BASE_CMD" + fi + + eval "$FULL_CMD &" + + # Start decode instance + DECODE_PORT=8002 + + # Build the command with or without model-specific args + BASE_CMD="CUDA_VISIBLE_DEVICES=1 VLLM_NIXL_SIDE_CHANNEL_PORT=6000 vllm serve $model_name \ + --port $DECODE_PORT \ + --enforce-eager \ + --disable-log-requests \ + --gpu-memory-utilization 0.2 \ + --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'" + + if [ -n "$model_args" ]; then + FULL_CMD="$BASE_CMD $model_args" + else + FULL_CMD="$BASE_CMD" + fi + + eval "$FULL_CMD &" + + # Wait for all instances to start + echo "Waiting for prefill instance on port $PORT to start..." + wait_for_server $PREFILL_PORT + echo "Waiting for decode instance on port $PORT to start..." 
+ wait_for_server $DECODE_PORT + + # Build the command for the proxy server with all the hosts and ports + PROXY_PORT=8192 + PROXY_CMD="python ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port $PROXY_PORT" + PROXY_CMD+=" --prefiller-ports ${PREFILL_PORT}" + PROXY_CMD+=" --decoder-ports ${DECODE_PORT}" + # Start the proxy server + echo "Starting proxy server with command: $PROXY_CMD" + $PROXY_CMD & + + # Wait for the proxy to start + sleep 5 + + # Run lm eval for this model + echo "Running tests for $model_name" + PREFILL_PORT=$PREFILL_PORT DECODE_PORT=$DECODE_PORT PROXY_PORT=$PROXY_PORT python -m pytest -s -v ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_edge_cases.py + + # Clean up before running next model + cleanup_instances + sleep 3 +} + +# Run tests for each model +for model in "${MODELS[@]}"; do + run_tests_for_model "$model" +done + +echo "All tests completed!" diff --git a/tests/v1/kv_connector/nixl_integration/test_accuracy.py b/tests/v1/kv_connector/nixl_integration/test_accuracy.py new file mode 100644 index 000000000000..be2d84f3bb17 --- /dev/null +++ b/tests/v1/kv_connector/nixl_integration/test_accuracy.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: Apache-2.0 +import os + +import lm_eval +import openai + +BASE_URL = "http://localhost:8192/v1" +NUM_CONCURRENT = 100 +TASK = "gsm8k" +FILTER = "exact_match,strict-match" +RTOL = 0.03 + +# Model-specific expected values +EXPECTED_VALUES = { + "Qwen/Qwen3-0.6B": 0.41, +} + +SIMPLE_PROMPT = "The best part about working on vLLM is that I got to meet so many people across various different organizations like UCB, Google, and Meta which means", # noqa: E501 + +# Get model name from environment variable +MODEL_NAME = os.environ.get("TEST_MODEL", "Qwen/Qwen3-0.6B") + + +def run_simple_prompt(): + client = openai.OpenAI(api_key="EMPTY", base_url=BASE_URL) + completion = client.completions.create(model=MODEL_NAME, + prompt=SIMPLE_PROMPT) + + print("-" * 50) + 
print(f"Completion results for {MODEL_NAME}:") + print(completion) + print("-" * 50) + + +def test_accuracy(): + """Run the end to end accuracy test.""" + run_simple_prompt() + + model_args = (f"model={MODEL_NAME}," + f"base_url={BASE_URL}/completions," + f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") + + results = lm_eval.simple_evaluate( + model="local-completions", + model_args=model_args, + tasks=TASK, + ) + + measured_value = results["results"][TASK][FILTER] + expected_value = EXPECTED_VALUES.get(MODEL_NAME) + + if expected_value is None: + print(f"Warning: No expected value found for {MODEL_NAME}. " + "Skipping accuracy check.") + print(f"Measured value: {measured_value}") + return + + assert (measured_value - RTOL < expected_value + and measured_value + RTOL > expected_value + ), f"Expected: {expected_value} | Measured: {measured_value}" diff --git a/tests/v1/kv_connector/nixl_integration/test_edge_cases.py b/tests/v1/kv_connector/nixl_integration/test_edge_cases.py new file mode 100644 index 000000000000..5363fbde0096 --- /dev/null +++ b/tests/v1/kv_connector/nixl_integration/test_edge_cases.py @@ -0,0 +1,77 @@ +# SPDX-License-Identifier: Apache-2.0 +import os + +import openai + +PREFILL_PORT = os.getenv("PREFILL_PORT", None) +DECODE_PORT = os.getenv("DECODE_PORT", None) +PROXY_PORT = os.getenv("PROXY_PORT", None) + +if PREFILL_PORT is None or DECODE_PORT is None or PROXY_PORT is None: + raise ValueError( + "Please set the PREFILL_PORT, DECODE_PORT, and PROXY_PORT.") + +LONG_PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. 
As a result, when working on projects like vLLM we are able to meet many amazing people from various organizations like AMD, Google, NVIDIA, " # noqa: E501 +PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result," # noqa: E501 +SHORT_PROMPT = "Red Hat is " + + +def test_edge_cases(): + # Set the OpenAI API key and base URL + decode_client = openai.OpenAI( + api_key="MY_KEY", + base_url=f"http://localhost:{DECODE_PORT}/v1", + ) + prefill_client = openai.OpenAI( + api_key="MY_KEY", + base_url=f"http://localhost:{PREFILL_PORT}/v1", + ) + proxy_client = openai.OpenAI( + api_key="MY_KEY", + base_url=f"http://localhost:{PROXY_PORT}/v1", + ) + + # Get the list of models + models = decode_client.models.list() + MODEL = models.data[0].id + + # (1) Check that we can handle a very short prompt, + # less than the length of the block size. + completion = proxy_client.completions.create(model=MODEL, + prompt=SHORT_PROMPT, + temperature=0) + proxy_response = completion.choices[0].text + completion = prefill_client.completions.create(model=MODEL, + prompt=SHORT_PROMPT, + temperature=0) + prefill_response = completion.choices[0].text + print(f"SMALL PROMPT: {proxy_response=}") + assert proxy_response == prefill_response + + # (2) Check that we can handle a full prefix cache + # hit on the D worker but not on the P worker. + # (2a): prime the D worker. + completion = decode_client.completions.create(model=MODEL, + prompt=PROMPT, + temperature=0) + decode_response = completion.choices[0].text + # (2b): send via the P/D setup + completion = proxy_client.completions.create(model=MODEL, + prompt=PROMPT, + temperature=0) + proxy_response = completion.choices[0].text + print(f"FULL CACHE HIT: {proxy_response=}") + assert proxy_response == decode_response + + # (3) Check that we can handle a partial prefix cache + # hit on the D worker. 
+ completion = proxy_client.completions.create(model=MODEL, + prompt=LONG_PROMPT, + temperature=0) + proxy_response = completion.choices[0].text + completion = prefill_client.completions.create(model=MODEL, + prompt=LONG_PROMPT, + temperature=0) + prefill_response = completion.choices[0].text + print(f"PARTIAL CACHE HIT: {proxy_response=}") + assert proxy_response == prefill_response diff --git a/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py new file mode 100644 index 000000000000..13071f581375 --- /dev/null +++ b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py @@ -0,0 +1,260 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import itertools +import os +import uuid +from contextlib import asynccontextmanager + +import httpx +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + Lifespan context manager to handle startup and shutdown events. 
+ """ + # Startup: Initialize client pools for prefiller and decoder services + app.state.prefill_clients = [] + app.state.decode_clients = [] + + # Create prefill clients + for i, (host, port) in enumerate(global_args.prefiller_instances): + prefiller_base_url = f'http://{host}:{port}/v1' + app.state.prefill_clients.append({ + 'client': + httpx.AsyncClient(timeout=None, base_url=prefiller_base_url), + 'host': + host, + 'port': + port, + 'id': + i + }) + + # Create decode clients + for i, (host, port) in enumerate(global_args.decoder_instances): + decoder_base_url = f'http://{host}:{port}/v1' + app.state.decode_clients.append({ + 'client': + httpx.AsyncClient(timeout=None, base_url=decoder_base_url), + 'host': + host, + 'port': + port, + 'id': + i + }) + + # Initialize round-robin iterators + app.state.prefill_iterator = itertools.cycle( + range(len(app.state.prefill_clients))) + app.state.decode_iterator = itertools.cycle( + range(len(app.state.decode_clients))) + + print(f"Initialized {len(app.state.prefill_clients)} prefill clients " + f"and {len(app.state.decode_clients)} decode clients.") + + yield + + # Shutdown: Close all clients + for client_info in app.state.prefill_clients: + await client_info['client'].aclose() + + for client_info in app.state.decode_clients: + await client_info['client'].aclose() + + +# Update FastAPI app initialization to use lifespan +app = FastAPI(lifespan=lifespan) + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--host", type=str, default="localhost") + + # For prefiller instances + parser.add_argument("--prefiller-hosts", + "--prefiller-host", + type=str, + nargs="+", + default=["localhost"]) + parser.add_argument("--prefiller-ports", + "--prefiller-port", + type=int, + nargs="+", + default=[8100]) + + # For decoder instances + parser.add_argument("--decoder-hosts", + "--decoder-host", + type=str, + nargs="+", + default=["localhost"]) + 
parser.add_argument("--decoder-ports", + "--decoder-port", + type=int, + nargs="+", + default=[8200]) + + args = parser.parse_args() + + # Validate and pair hosts with ports + if len(args.prefiller_hosts) != len(args.prefiller_ports): + raise ValueError( + "Number of prefiller hosts must match number of prefiller ports") + + if len(args.decoder_hosts) != len(args.decoder_ports): + raise ValueError( + "Number of decoder hosts must match number of decoder ports") + + # Create tuples of (host, port) for each service type + args.prefiller_instances = list( + zip(args.prefiller_hosts, args.prefiller_ports)) + args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports)) + + return args + + +def get_next_client(app, service_type: str): + """ + Get the next client in round-robin fashion. + + Args: + app: The FastAPI app instance + service_type: Either 'prefill' or 'decode' + + Returns: + The next client to use + """ + if service_type == 'prefill': + client_idx = next(app.state.prefill_iterator) + return app.state.prefill_clients[client_idx] + elif service_type == 'decode': + client_idx = next(app.state.decode_iterator) + return app.state.decode_clients[client_idx] + else: + raise ValueError(f"Unknown service type: {service_type}") + + +async def send_request_to_service(client_info: dict, endpoint: str, + req_data: dict, request_id: str): + """ + Send a request to a service using a client from the pool. 
+ """ + req_data = req_data.copy() + req_data['kv_transfer_params'] = { + "do_remote_decode": True, + "do_remote_prefill": False, + "remote_engine_id": None, + "remote_block_ids": None, + "remote_host": None, + "remote_port": None + } + req_data["stream"] = False + req_data["max_tokens"] = 1 + if "stream_options" in req_data: + del req_data["stream_options"] + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id + } + + response = await client_info['client'].post(endpoint, + json=req_data, + headers=headers) + response.raise_for_status() + + return response + + +async def stream_service_response(client_info: dict, endpoint: str, + req_data: dict, request_id: str): + """ + Asynchronously stream response from a service using a client from the pool. + """ + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id + } + + async with client_info['client'].stream("POST", + endpoint, + json=req_data, + headers=headers) as response: + response.raise_for_status() + async for chunk in response.aiter_bytes(): + yield chunk + + +@app.post("/v1/completions") +async def handle_completions(request: Request): + try: + req_data = await request.json() + request_id = str(uuid.uuid4()) + + # Get the next prefill client in round-robin fashion + prefill_client_info = get_next_client(request.app, 'prefill') + + # Send request to prefill service + response = await send_request_to_service(prefill_client_info, + "/completions", req_data, + request_id) + + # Extract the needed fields + response_json = response.json() + kv_transfer_params = response_json.get('kv_transfer_params', {}) + if kv_transfer_params: + req_data["kv_transfer_params"] = kv_transfer_params + + # Get the next decode client in round-robin fashion + decode_client_info = get_next_client(request.app, 'decode') + + logger.debug("Using %s %s", prefill_client_info, decode_client_info) + + # Stream response from decode service + 
async def generate_stream(): + async for chunk in stream_service_response(decode_client_info, + "/completions", + req_data, + request_id=request_id): + yield chunk + + return StreamingResponse(generate_stream(), + media_type="application/json") + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server" + " - completions endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + raise + + +@app.get("/healthcheck") +async def healthcheck(): + """Simple endpoint to check if the server is running.""" + return { + "status": "ok", + "prefill_instances": len(app.state.prefill_clients), + "decode_instances": len(app.state.decode_clients) + } + + +if __name__ == '__main__': + global global_args + global_args = parse_args() + + import uvicorn + uvicorn.run(app, host=global_args.host, port=global_args.port) diff --git a/tests/v1/kv_connector/unit/__init__.py b/tests/v1/kv_connector/unit/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py new file mode 100644 index 000000000000..a21d92c52244 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_multi_connector.py @@ -0,0 +1,249 @@ +# SPDX-License-Identifier: Apache-2.0 +import filecmp +import shutil +import tempfile +from collections import defaultdict +from pathlib import Path + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig, VllmConfig +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) +from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa + SharedStorageConnector) + +MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" + +PROMPT_CONTEXT = "Hi " * 100 +PROMPTS = [ + PROMPT_CONTEXT + "Hello, my name is", + PROMPT_CONTEXT + "The capital of France is", +] + +SAMPLING_PARAMS = 
SamplingParams(temperature=0, max_tokens=20) + + +class TestSharedStorageConnector(SharedStorageConnector): + + def __init__(self, config: VllmConfig, role): + self.name = config.kv_transfer_config.kv_connector_extra_config["name"] + self._connector = SharedStorageConnector(config, role) + self.call_record: dict[str, int] = defaultdict(int) + # Use a unique temp file per connector + self._event_file = tempfile.gettempdir( + ) + f"/connector_{self.name}_events.log" + # Start with an empty file + with open(self._event_file, "w") as _: + pass + + def __getattribute__(self, name): + if name in ("_connector", "call_record", "name", "_event_file", + "__class__", "__dict__", "__getattribute__", + "__init__"): # avoid recursion + return object.__getattribute__(self, name) + if not hasattr(self._connector, name): + return object.__getattribute__(self, name) + attr = getattr(self._connector, name) + + # Intercept calls to the connector interface and write an event + # for each one to a file, which can be read back in the main test proc. 
+ if callable(attr): + + def wrapper(*args, **kwargs): + self.call_record[name] += 1 + # Log the event as a line to the file + try: + with open(self._event_file, "a") as f: + f.write(name + "\n") + except Exception as e: + print(f"[ERROR] Could not log event {name} " + f"for {self.name}: {e}") + return attr(*args, **kwargs) + + return wrapper + return attr + + +KVConnectorFactory.register_connector("TestSharedStorageConnector", + TestSharedStorageConnector.__module__, + TestSharedStorageConnector.__name__) + + +# Helper function to compare directories recursively +def _compare_directories(dir1: Path, dir2: Path) -> bool: + """Compares two directories recursively for identical content.""" + dcmp = filecmp.dircmp(dir1, dir2) + if dcmp.left_only or dcmp.right_only or dcmp.diff_files: + print(f"Differences found between {dir1} and {dir2}:") + print(f" Left only: {dcmp.left_only}") + print(f" Right only: {dcmp.right_only}") + print(f" Different files: {dcmp.diff_files}") + return False + for sub_dir in dcmp.common_dirs: + if not _compare_directories(dir1 / sub_dir, dir2 / sub_dir): + return False + return True + + +def test_multi_shared_storage_connector_consistency(): + """ + Tests that MultiConnector with two SharedStorageConnectors saves + identical KV cache data to separate storage locations. 
+ """ + storage_1_path = Path("storage_1/") + storage_2_path = Path("storage_2/") + shutil.rmtree(storage_1_path, ignore_errors=True) + shutil.rmtree(storage_2_path, ignore_errors=True) + storage_1_path.mkdir() + storage_2_path.mkdir() + + # Configure MultiConnector with two SharedStorageConnectors + kv_transfer_config = KVTransferConfig( + kv_connector="MultiConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "connectors": [{ + "kv_connector": "TestSharedStorageConnector", + "kv_role": "kv_both", + "kv_connector_extra_config": { + "shared_storage_path": str(storage_1_path), + "name": "storage1", + } + }, { + "kv_connector": "TestSharedStorageConnector", + "kv_role": "kv_both", + "kv_connector_extra_config": { + "shared_storage_path": str(storage_2_path), + "name": "storage2", + } + }] + }, + ) + + llm = LLM( + model=MODEL_NAME, + enforce_eager=True, + gpu_memory_utilization=0.5, + kv_transfer_config=kv_transfer_config, + ) + # Run generation - this should trigger saving KV cache + _ = llm.generate(PROMPTS, SAMPLING_PARAMS) + + # --- Verification --- + + # Check that both storage directories were populated + local_subdirs = list(storage_1_path.iterdir()) + external_subdirs = list(storage_2_path.iterdir()) + + assert len( + local_subdirs + ) > 0, f"Local storage path {storage_1_path} is empty after generation." 
+ assert len(external_subdirs) > 0, ( + f"External storage path {storage_2_path} is empty after generation.") + assert len(local_subdirs) == len(external_subdirs), ( + f"Mismatch in number of cache entries: " + f"Local={len(local_subdirs)}, External={len(external_subdirs)}") + + # The subdirectories should correspond to the prompt hashes + # Since prompts are the same, the hash directories should be the same name + local_subdir_names = sorted([d.name for d in local_subdirs]) + external_subdir_names = sorted([d.name for d in external_subdirs]) + assert local_subdir_names == external_subdir_names, ( + "Cache directory names do not match between local and external storage" + ) + + # Compare the contents of each corresponding cache directory + for subdir_name in local_subdir_names: + print(f"Comparing contents of cache directory: {subdir_name}") + assert _compare_directories(storage_1_path / subdir_name, + storage_2_path / subdir_name), \ + (f"Contents differ for cache directory '{subdir_name}' between " + f"{storage_1_path} and {storage_2_path}") + + events = get_connector_events() + # get_num_new_matched_tokens will be called on each connector in turn. + # neither of them have hits so update_state_after_alloc won't be called. + assert events["storage1"][:3] == [ + 'get_num_new_matched_tokens', 'build_connector_meta', + 'bind_connector_metadata' + ] + assert events["storage2"][:3] == [ + 'get_num_new_matched_tokens', 'build_connector_meta', + 'bind_connector_metadata' + ] + + # Reset prefix cache or else we'll just get the tokens back from there. + llm.reset_prefix_cache() + + # Run generation again - this should trigger loading from the first + # connector. + _ = llm.generate(PROMPTS, SAMPLING_PARAMS) + + events = get_connector_events() + # get_num_new_matched_tokens will return new tokens from the first + # connector so update_state_after_alloc will be called once blocks + # are allocated for the first connector. 
+ # get_num_new_matched_tokens *won't* be called on the second connector + # in this case. + assert events["storage1"][:4] == [ + 'get_num_new_matched_tokens', 'update_state_after_alloc', + 'build_connector_meta', 'bind_connector_metadata' + ] + assert events["storage2"][:2] == [ + 'build_connector_meta', 'bind_connector_metadata' + ] + + # Delete storage1 connector state + shutil.rmtree(storage_1_path) + + # Reset prefix cache or else we'll just get the tokens back from there. + llm.reset_prefix_cache() + + # Run generation again - this should trigger loading from the second + # connector. + _ = llm.generate(PROMPTS, SAMPLING_PARAMS) + + events = get_connector_events() + # get_num_new_matched_tokens will be called for the first connector but it + # won't have a hit so update_state_after_alloc won't be called. + # get_num_new_matched_tokens will also be called on the second connector, + # but it should have a hit so update_state_after_alloc will be called. + assert events["storage1"][:3] == [ + 'get_num_new_matched_tokens', 'build_connector_meta', + 'bind_connector_metadata' + ] + assert events["storage2"][:4] == [ + 'get_num_new_matched_tokens', 'update_state_after_alloc', + 'build_connector_meta', 'bind_connector_metadata' + ] + + # Clean up + shutil.rmtree(storage_1_path) + shutil.rmtree(storage_2_path) + + +def get_connector_events() -> dict[str, list[str]]: + # Read in connector events and reset the files.
+ import glob + event_files = glob.glob(tempfile.gettempdir() + "/connector_*_events.log") + connector_events = {} + for fname in event_files: + name = fname.split("connector_")[1].split("_events.log")[0] + try: + with open(fname, "r+") as f: + connector_events[name] = [ + line.strip() for line in f if line.strip() + ] + f.truncate(0) + except Exception as e: + print(f"[ERROR] Could not read connector events for {name}: {e}") + + return connector_events + + +def test_engine_id_conflict(): + configs = [KVTransferConfig() for _ in range(2)] + ids = [config.engine_id for config in configs] + assert ids[0] != ids[1], ( + "Engine IDs should be different for different configs. " + f"Got {ids}") diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py new file mode 100644 index 000000000000..9b2a720c11c4 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( + NixlConnectorMetadata) + +from .utils import create_request, create_scheduler, create_vllm_config + + +def test_basic_inferface(): + """Unit test for basic NixlConnector interface functionality.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True) + request_id = request.request_id + + scheduler.add_request(request) + + # Remote Prefill, triggers NixlConnectorMetadata.
+ scheduler_output = scheduler.schedule() + kv_connector_metadata = scheduler_output.kv_connector_metadata + assert kv_connector_metadata is not None + assert isinstance(kv_connector_metadata, NixlConnectorMetadata) + + assert len(kv_connector_metadata.requests) == 1 + assert request_id in kv_connector_metadata.requests + req_meta = kv_connector_metadata.requests[request_id] + + for block_id, block in zip( + req_meta.local_block_ids, scheduler.kv_cache_manager. + single_type_manager.req_to_blocks[request_id]): + assert block_id == block.block_id + + +def test_prompt_less_than_block_size(): + """ + Test that we can handle case where prompt is < block. + + In this case, the P worker will send empty remote_block_ids. + The D worker should not schedule an async read in this case, + since there is nothing to pull. + """ + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # Half of a block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_TOKENS = int(BLOCK_SIZE * 0.5) + + # Request will have 0 remote blocks. + request = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True, + num_remote_blocks=0) + scheduler.add_request(request) + scheduler_output = scheduler.schedule() + + # This request should not have to read async. + kv_connector_metadata = scheduler_output.kv_connector_metadata + assert kv_connector_metadata is not None + assert isinstance(kv_connector_metadata, NixlConnectorMetadata) + assert len(kv_connector_metadata.requests) == 0 + + # This request should be scheduled regularly. 
+ assert len(scheduler_output.scheduled_new_reqs) == 1 diff --git a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py new file mode 100644 index 000000000000..77098140343a --- /dev/null +++ b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py @@ -0,0 +1,181 @@ +# SPDX-License-Identifier: Apache-2.0 +import copy + +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT +from vllm.v1.request import FinishReason, RequestStatus + +from .utils import (assert_scheduler_empty, create_model_runner_output, + create_request, create_scheduler, create_vllm_config) + + +def test_basic_lifecycle(): + """Test lifecycle of a Remote Decode request.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request = create_request(request_id=1, + max_tokens=1, + num_tokens=NUM_TOKENS, + do_remote_decode=True) + + scheduler.add_request(request) + request_id = request.request_id + + # STEP (1): Prefill. + # (1a): schedule() + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 1 + + # (1b): execute_model() + model_runner_output = create_model_runner_output(reqs=[request]) + + # (1c): update_from_output() + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) + + # Ensure the request is finished after 1 token. + assert request.is_finished() + assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED + output = engine_core_outputs.outputs[0] + assert output.finish_reason == FinishReason.LENGTH + assert output.kv_transfer_params is not None + + # Request freed in Scheduler and in Persistent Batch ...
+ assert request_id in scheduler.finished_req_ids + assert len(scheduler.running) == 0 + assert len(scheduler.waiting) == 0 + + # ... but blocks should not be freed. + blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[ + request_id] + for block in blocks: + assert block.ref_cnt == 1 + + # STEP (2): Send Finished to PB. + # (2a): schedule() - pass finished request to PB. + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 0 + assert len(scheduler_output.finished_req_ids) == 1 + assert request_id in scheduler_output.finished_req_ids + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 0 + assert len(scheduler.finished_req_ids) == 0 + + # (2b): execute_model() + model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT + + # (2c): update_from_output() + scheduler.update_from_output(scheduler_output, model_runner_output) + + # STEP (3): Finished sending. + # (3a): schedule() - pass finished request to PB. + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 0 + assert len(scheduler_output.finished_req_ids) == 0 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 0 + assert len(scheduler.finished_req_ids) == 0 + + # (3b): execute_model() + model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + model_runner_output.finished_sending = [request_id] + + # (3c): update_from_output() + scheduler.update_from_output(scheduler_output, model_runner_output) + + # Confirm we do not have any memory leaks after req lifecycle. + assert_scheduler_empty(scheduler) + + +def test_short_prompt_lifecycle(): + """Test lifecycle of a Remote Decode request with short prompt.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # Not enough tokens for full block. 
+ NUM_TOKENS = vllm_config.cache_config.block_size // 2 + request = create_request(request_id=1, + max_tokens=1, + num_tokens=NUM_TOKENS, + do_remote_decode=True) + + scheduler.add_request(request) + + # STEP (1): Prefill. + # (1a): schedule() + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 1 + + # (1b): execute_model() + model_runner_output = create_model_runner_output(reqs=[request]) + + # (1c): update_from_output() + # Since tokens < block_size, there will be no kv xfer. + # So this should be cleaned up immediately. + _ = scheduler.update_from_output(scheduler_output, model_runner_output) + + # Confirm we do not have any memory leaks after req lifecycle. + # We need one more call to schedule() to clear data for persistent batch. + _ = scheduler.schedule() + assert_scheduler_empty(scheduler) + + +def test_prefix_cache_lifecycle(): + """Test that remote decode params still works with a prefix cache hit.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # Prime the KVCache. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 3 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS) + + scheduler.add_request(request_normal) + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_normal], + use_eos=True) + scheduler.update_from_output(scheduler_output, model_runner_output) + scheduler.schedule() + scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT) + + ##################### + # Actual Test: confirm we send all blocks. + + # Step (1): Send the KV Transfer. 
+ NUM_EXTERNAL_FULL_BLOCKS -= 1 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request_remote = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_decode=True) + + scheduler.add_request(request_remote) + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_remote]) + eco = scheduler.update_from_output(scheduler_output, model_runner_output) + kv_transfer_params = eco.outputs[0].kv_transfer_params + + # Ensure we send all block ids, even if there is a cache hit. + assert (len( + kv_transfer_params["remote_block_ids"]) == NUM_EXTERNAL_FULL_BLOCKS) + + # STEP (2): Ensure it is freed. + scheduler_output = scheduler.schedule() + scheduler.schedule() + model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + model_runner_output.finished_sending = [request_remote.request_id] + scheduler.update_from_output(scheduler_output, model_runner_output) + _ = scheduler.schedule() + assert_scheduler_empty(scheduler) diff --git a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py new file mode 100644 index 000000000000..6fcff0d62045 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py @@ -0,0 +1,423 @@ +# SPDX-License-Identifier: Apache-2.0 +import copy + +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT +from vllm.v1.request import FinishReason, RequestStatus + +from .utils import (assert_scheduler_empty, create_model_runner_output, + create_request, create_scheduler, create_vllm_config) + + +def test_basic_lifecycle(): + """Test lifecycle of a remote prefill.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. 
+ BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + START_FREE_BLOCK_QUEUE_SIZE = ( + scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) + + request = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True) + + scheduler.add_request(request) + request_id = request.request_id + + # STEP (1): + # (1a): schedule() + scheduler_output = scheduler.schedule() + + # Nothing running and empty scheduler output. + assert len(scheduler.running) == 0 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 0 + assert len(scheduler_output.num_scheduled_tokens) == 0 + assert scheduler_output.total_num_scheduled_tokens == 0 + + # Req waiting for KVs with no computed/scheduled toks ... + assert len(scheduler.waiting) == 1 + assert request in scheduler.waiting + assert (request.status == RequestStatus.WAITING_FOR_REMOTE_KVS) + assert (request.num_computed_tokens == 0) + + # ... but should have (uncached) blocks allocated to it. + block_pool = scheduler.kv_cache_manager.block_pool + assert (block_pool.free_block_queue.num_free_blocks + < START_FREE_BLOCK_QUEUE_SIZE) + assert len(block_pool.cached_block_hash_to_block) == 0 + blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[ + request_id] + for block in blocks: + assert block._block_hash is None + + # (1b): forward() + model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT + + # (1c): update_from_output() + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) + assert len(engine_core_outputs.outputs) == 0 + + # STEP (2): + # (2a): schedule(): nothing happens! + scheduler_output = scheduler.schedule() + assert len(scheduler.waiting) == 1 + assert len(scheduler.running) == 0 + + # (2b): forward(): request finishes recv. 
+ model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + model_runner_output.finished_recving = [request_id] + + # (2c): update_from_output(): + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) + assert len(scheduler.waiting) == 1 + assert (request_id in scheduler.finished_recving_kv_req_ids) + + # STEP (3): + # (3a): schedule(): this should actually schedule. + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + + # Confirm the blocks are actually allocated. + num_hashed_blocks = 0 + blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[ + request_id] + for block in blocks: + assert block.ref_cnt == 1 + num_hashed_blocks += (1 if block._block_hash is not None else 0) + assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS + + # Confirm the rest of the prompt is scheduled in this step. + scheduled_req = scheduler_output.scheduled_new_reqs[0] + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[request_id] + num_computed_tokens = scheduled_req.num_computed_tokens + total_prompt_tokens = len(scheduled_req.prompt_token_ids) + assert (num_scheduled_tokens == total_prompt_tokens - num_computed_tokens) + + # (3b): execute_model() + model_runner_output = create_model_runner_output([request]) + # (3c): update_from_output() + scheduler.update_from_output(scheduler_output, model_runner_output) + + # Step (4): Hit EOS.
+ scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output([request], use_eos=True) + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) + scheduler.schedule() + + outputs = engine_core_outputs.outputs + assert len(outputs) == 1 + output = outputs[0] + assert output.finish_reason == FinishReason.STOP + assert_scheduler_empty(scheduler) + + +def test_interleaved_lifecycle(): + """Test Remote Prefills Work Well With Other Requests.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request_remote = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True) + request_local_a = create_request( + request_id=2, + num_tokens=NUM_TOKENS, + ) + request_local_b = create_request( + request_id=3, + num_tokens=NUM_TOKENS, + ) + + # STEP 1: Regular request is running. + scheduler.add_request(request_local_a) + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + + model_runner_output = create_model_runner_output([request_local_a]) + scheduler.update_from_output(scheduler_output, model_runner_output) + + # STEP 2: Add a local and remote request. + scheduler.add_request(request_local_b) + scheduler.add_request(request_remote) + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 2 + assert len(scheduler.waiting) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 1 + assert len(scheduler_output.scheduled_cached_reqs) == 1 + + model_runner_output = create_model_runner_output( + [request_local_a, request_local_b]) + scheduler.update_from_output(scheduler_output, model_runner_output) + + # STEP 3: continue running, KVs not arrived yet. 
+ scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 2 + assert len(scheduler.waiting) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 2 + + model_runner_output = create_model_runner_output( + reqs=[request_local_a, request_local_b]) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 2 + assert len(scheduler.waiting) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 2 + + # STEP 4: KVs arrive. + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 2 + assert len(scheduler.waiting) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 2 + + model_runner_output = create_model_runner_output( + [request_local_a, request_local_b], + finished_recving=[request_remote.request_id]) + scheduler.update_from_output(scheduler_output, model_runner_output) + + # STEP 5: RECVed KVs are sent to ModelRunner. + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 3 + assert len(scheduler.waiting) == 0 + assert len(scheduler_output.scheduled_new_reqs) == 1 + assert len(scheduler_output.scheduled_cached_reqs) == 2 + + model_runner_output = create_model_runner_output( + [request_local_a, request_local_b, request_remote]) + scheduler.update_from_output(scheduler_output, model_runner_output) + + # STEP 6: Hit EOS and free. + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output( + [request_local_a, request_local_b, request_remote], + use_eos=True, + ) + scheduler.update_from_output(scheduler_output, model_runner_output) + scheduler.schedule() + assert_scheduler_empty(scheduler) + + +def test_no_spurious_prefix_caching(): + """ + With P/D, blocks can be allocated but uncomputed for + multiple engine steps. 
This test confirms that we do + not accidentally have cache hits against uncomputed + blocks. + """ + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 and a half full external blocks. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + # Both of these requests have prompts like [1,1,1,1,1, ...] + request_remote = create_request( + request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True, + use_all_1s_for_prompt_tokens=True, + ) + + request_local = create_request( + request_id=2, + num_tokens=NUM_TOKENS, + do_remote_prefill=False, + use_all_1s_for_prompt_tokens=True, + ) + + # Schedule the remote prefill request. This should not + # cause any blocks to be cached. + scheduler.add_request(request_remote) + scheduler_output = scheduler.schedule() + scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT) + assert len(scheduler.waiting) == 1 + + # Schedule the local prefill request. This should + # cause blocks to be cached, but separately from the remote request. + scheduler.add_request(request_local) + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 1 + + local_blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[ + request_local.request_id] + remote_blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[ # noqa: E501 + request_remote.request_id] + + # Local should have cached blocks (but not all due to preallocate). + num_hashed_blocks = 0 + for block in local_blocks: + assert block.ref_cnt == 1 + num_hashed_blocks += (1 if block._block_hash is not None else 0) + assert num_hashed_blocks > 0 + + # Remote blocks should not be cached.
+ for block in remote_blocks: + assert block.ref_cnt == 1 + assert block._block_hash is None + + +def test_full_block_prompt(): + """Test that we handle a prompt that is the full block size.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # Exactly 2 Full Blocks (no partial block). + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS) + + request = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True) + + scheduler.add_request(request) + request_id = request.request_id + + # STEP (1): Initialize a recv. + scheduler_output = scheduler.schedule() + # All blocks should be allocated. + num_blocks = len(scheduler.kv_cache_manager.single_type_manager. + req_to_blocks[request_id]) + assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS + model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT + scheduler.update_from_output(scheduler_output, model_runner_output) + + # STEP (2): Recv. + scheduler_output = scheduler.schedule() + model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + model_runner_output.finished_recving = [request_id] + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.waiting) == 1 + assert (request_id in scheduler.finished_recving_kv_req_ids) + + # STEP (3): Run as usual. + scheduler_output = scheduler.schedule() + + # We need to recompute the final token of the prompt to generate + # the first new token, so we should not have a new block. + num_blocks = len(scheduler.kv_cache_manager.single_type_manager.
+ req_to_blocks[request_id]) + assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS + assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens == + NUM_TOKENS - 1) + assert (scheduler_output.num_scheduled_tokens[request_id] == 1) + + model_runner_output = create_model_runner_output([request]) + scheduler.update_from_output(scheduler_output, model_runner_output) + + # # Step (4): Hit EOS. + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output([request], use_eos=True) + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) + scheduler.schedule() + + outputs = engine_core_outputs.outputs + assert len(outputs) == 1 + output = outputs[0] + assert output.finish_reason == FinishReason.STOP + assert_scheduler_empty(scheduler) + + +def test_cannot_schedule_after_recv(): + """ + Test that we can handle no schedule after recv due to not + enough remaining KV blocks. + """ + + # NOTE: the KVCacheManager will use 1 null block. + # So there are 5 total working blocks. + TOTAL_NUM_BLOCKS = 6 + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config, num_blocks=TOTAL_NUM_BLOCKS) + + # Prime the KVCache. + NUM_PROMPT_BLOCKS = 2 + BLOCK_SIZE = vllm_config.cache_config.block_size + # Prompt will use 2 blocks + 1 block after we schedule. + NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS) + NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5)) + + request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS_LOCAL) + request_remote = create_request(request_id=2, + num_tokens=NUM_TOKENS_REMOTE, + do_remote_prefill=True) + + # STEP 1: 3 blocks are in use (2 for prompt, 1 for decode). 
+ scheduler.add_request(request_normal) + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_normal]) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 0 + + # Step 2: 5 blocks are in use (2 new for remote blocks). + scheduler.add_request(request_remote) + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_normal]) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 1 + + # Step 3: finish recving (5 blocks in use) + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output( + reqs=[request_normal], finished_recving=[request_remote.request_id]) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 1 + + # Step 4: try to schedule, not enough blocks. + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_normal]) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 1 + + # Step 5: finish the request, free it. + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_normal], + use_eos=True) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 0 + assert len(scheduler.waiting) == 1 + + # Step 6: now we can schedule (with 2 blocks computed). 
+ scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_remote]) + assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens == + NUM_PROMPT_BLOCKS * BLOCK_SIZE) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 0 + + # Step 7: free everything. + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_remote], + use_eos=True) + scheduler.update_from_output(scheduler_output, model_runner_output) + _ = scheduler.schedule() + assert_scheduler_empty(scheduler) diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py new file mode 100644 index 000000000000..53e2d6fda1ae --- /dev/null +++ b/tests/v1/kv_connector/unit/utils.py @@ -0,0 +1,188 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Any, Optional + +import torch + +from vllm import SamplingParams +from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig, + ModelConfig, SchedulerConfig, VllmConfig) +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheGroupSpec) +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.request import Request +from vllm.v1.structured_output import StructuredOutputManager + +EOS_TOKEN_ID = 50256 + + +def assert_scheduler_empty(scheduler: Scheduler): + """Confirm the scheduler is "empty" - i.e. no leaks.""" + # Scheduler Metadata. + assert len(scheduler.requests) == 0 + assert len(scheduler.waiting) == 0 + assert len(scheduler.running) == 0 + assert len(scheduler.finished_req_ids) == 0 + assert len(scheduler.finished_recving_kv_req_ids) == 0 + assert len(scheduler._cached_reqs_data) == 0 + + # EncoderCacheManager. + assert len(scheduler.encoder_cache_manager.freed) == 0 + assert len(scheduler.encoder_cache_manager.cached) == 0 + + # KVCache Manager. 
+ assert len( + scheduler.kv_cache_manager.single_type_manager.req_to_blocks) == 0 + assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0 + assert len( + scheduler.kv_cache_manager.single_type_manager.num_cached_block) == 0 + num_free_blocks = ( + scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) + assert num_free_blocks == ( + scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1) + + # NOTE(rob): just the ref count on blocks will be 0. The hash + # value, etc will remain since we lazily evict for prefix cache. + for block in scheduler.kv_cache_manager.block_pool.blocks: + assert block.ref_cnt == 0 + + +def create_vllm_config( + model: str = "facebook/opt-125m", + max_num_seqs: int = 16, + max_num_batched_tokens: int = 64, + block_size: int = 16, +) -> VllmConfig: + """Initialize VllmConfig For Testing.""" + scheduler_config = SchedulerConfig( + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + max_model_len=max_num_batched_tokens, + ) + model_config = ModelConfig( + model=model, + task="auto", + tokenizer=model, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="float16", + seed=42, + ) + # Cache config, optionally force APC + cache_config = CacheConfig( + block_size=block_size, + gpu_memory_utilization=0.9, + swap_space=0, + cache_dtype="auto", + enable_prefix_caching=True, + ) + kv_transfer_config = KVTransferConfig( + kv_connector="NixlConnector", + kv_role="kv_both", + ) + return VllmConfig(scheduler_config=scheduler_config, + model_config=model_config, + cache_config=cache_config, + kv_transfer_config=kv_transfer_config, + device_config=DeviceConfig("cpu")) + + +def create_scheduler( + vllm_config: VllmConfig, + num_blocks: int = 10000, +) -> Scheduler: + """Initialize Scheduler For Testing.""" + block_size = vllm_config.cache_config.block_size + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, # A large number of blocks to hold all requests + tensors={}, + kv_cache_groups=[ + 
KVCacheGroupSpec(['layer'], + FullAttentionSpec(block_size, 1, 1, torch.float32, + False)) + ], + ) + vllm_config.cache_config.num_gpu_blocks = num_blocks + return Scheduler( + vllm_config=vllm_config, + kv_cache_config=kv_cache_config, + log_stats=True, + structured_output_manager=StructuredOutputManager(vllm_config), + ) + + +def create_request( + request_id: int, + num_tokens: int = 10, + max_tokens: int = 16, + do_remote_decode: bool = False, + do_remote_prefill: bool = False, + use_all_1s_for_prompt_tokens: bool = False, + num_remote_blocks: int = 3, +) -> Request: + """Make dummy request for testing.""" + + kv_transfer_params: Optional[dict[str, Any]] = None + + if do_remote_decode: + assert not do_remote_prefill + kv_transfer_params = dict(do_remote_prefill=False, + do_remote_decode=True) + elif do_remote_prefill: + kv_transfer_params = dict(do_remote_prefill=True, + do_remote_decode=False, + remote_engine_id="my-engine-id", + remote_block_ids=list( + range(num_remote_blocks)), + remote_host="my-host", + remote_port=1234) + + max_tokens = 1 if do_remote_decode else max_tokens + sampling_params = SamplingParams(max_tokens=max_tokens) + + if use_all_1s_for_prompt_tokens: + prompt_token_ids = [1] * num_tokens + else: + prompt_token_ids = [i * request_id for i in range(num_tokens)] + + req = Request( + request_id=f"id-{request_id}", + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params, + multi_modal_inputs=None, + multi_modal_placeholders=None, + multi_modal_hashes=None, + eos_token_id=EOS_TOKEN_ID, + arrival_time=0, + ) + req.kv_transfer_params = kv_transfer_params + return req + + +def create_model_runner_output( + reqs: list[Request], + finished_sending: Optional[list[str]] = None, + finished_recving: Optional[list[str]] = None, + use_eos: bool = False, +) -> ModelRunnerOutput: + """Make dummy model runner output for testing.""" + + # Make request data. 
+ req_ids = [req.request_id for req in reqs] + req_id_to_index = {req_id: idx for idx, req_id in enumerate(req_ids)} + + # Make sampled tokens. + sampled_token = EOS_TOKEN_ID if use_eos else 0 + sampled_token_ids = [[sampled_token] for _ in req_ids] + + # Make output data structure. + return ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_id_to_index, + sampled_token_ids=sampled_token_ids, + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + finished_sending=finished_sending, + finished_recving=finished_recving, + ) diff --git a/tests/v1/metrics/test_ray_metrics.py b/tests/v1/metrics/test_ray_metrics.py new file mode 100644 index 000000000000..02475f7c150b --- /dev/null +++ b/tests/v1/metrics/test_ray_metrics.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import ray + +from vllm.sampling_params import SamplingParams +from vllm.v1.engine.async_llm import AsyncEngineArgs, AsyncLLM +from vllm.v1.metrics.ray_wrappers import RayPrometheusStatLogger + + +@pytest.fixture(scope="function", autouse=True) +def use_v1_only(monkeypatch): + """ + The change relies on V1 APIs, so set VLLM_USE_V1=1. + """ + monkeypatch.setenv('VLLM_USE_V1', '1') + + +MODELS = [ + "distilbert/distilgpt2", +] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [16]) +def test_engine_log_metrics_ray( + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + """ Simple smoke test, verifying this can be used without exceptions. 
+ Need to start a Ray cluster in order to verify outputs.""" + + @ray.remote(num_gpus=1) + class EngineTestActor: + + async def run(self): + engine_args = AsyncEngineArgs( + model=model, + dtype=dtype, + disable_log_stats=False, + ) + + engine = AsyncLLM.from_engine_args( + engine_args, stat_loggers=[RayPrometheusStatLogger]) + + for i, prompt in enumerate(example_prompts): + engine.generate( + request_id=f"request-id-{i}", + prompt=prompt, + sampling_params=SamplingParams(max_tokens=max_tokens), + ) + + # Create the actor and call the async method + actor = EngineTestActor.remote() # type: ignore[attr-defined] + ray.get(actor.run.remote()) diff --git a/tests/v1/sample/test_topk_topp_sampler.py b/tests/v1/sample/test_topk_topp_sampler.py index 8a5076412cfa..220f05c7ff1c 100644 --- a/tests/v1/sample/test_topk_topp_sampler.py +++ b/tests/v1/sample/test_topk_topp_sampler.py @@ -1,37 +1,115 @@ # SPDX-License-Identifier: Apache-2.0 +import pytest import torch +from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs from torch import Generator -from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p +from vllm.platforms import current_platform +from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p, + is_flashinfer_available) DEVICE = "cuda" BATCH_SIZE = 1024 VOCAB_SIZE = 128 * 1024 +FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available + + +@pytest.fixture(autouse=True) +def reset_default_device(): + """ + Explicitly set the default device, which can affect subsequent tests. + Adding this fixture helps avoid this problem. 
+ """ + original_device = torch.get_default_device() + yield + torch.set_default_device(original_device) + def test_topk_impl_equivalance(): - with torch.device(DEVICE): - generator = Generator(device=DEVICE).manual_seed(33) + torch.set_default_device(DEVICE) + generator = Generator(device=DEVICE).manual_seed(33) + + logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator) + + # Random top-k values between 1 and 9. + k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator) + + # Set k=vocab_size for ~50% of requests in the batch (top-k disabled). + k.masked_fill_( + torch.randint(0, 2, (BATCH_SIZE, ), generator=generator, dtype=bool), + VOCAB_SIZE) + + # Top-k only implementation + result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None) + + # Top-p + top-k + no_op_top_p = torch.tensor([1.0]) + result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p) + + assert torch.allclose(result1, result2) + + +def test_flashinfer_sampler(): + ''' + This test verifies that the FlashInfer top-k and top-p sampling + implementation produces the same results as the Python implementation. + + NOTE: FlashInfer did not directly expose an interface for fused top-k and + top-p prob renorm (it did provide fused sampling but we cannot compare + sampling results due to randomness), so we will compare the probability + renormed consequently by top-k and then top-p of FlashInfer implementation. 
+ ''' + + if not FLASHINFER_ENABLED: + pytest.skip( + "FlashInfer not installed or not available on this platform.") + + torch.set_default_device(DEVICE) + generator = Generator(device=DEVICE).manual_seed(42) + + # Generate random logits + logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator) - logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator) + # Generate various top-k and top-p values + k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator) + p_values = torch.rand( + (BATCH_SIZE, ), generator=generator) * 0.5 + 0.5 # range in [0.5, 1.0] - # Random top-k values between 1 and 9. - k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator) + # Sometimes disable top-k (k=vocab_size) + k_values.masked_fill_( + torch.randint(0, + 2, (BATCH_SIZE, ), + generator=generator, + dtype=torch.bool), VOCAB_SIZE) - # Set k=vocab_size for ~50% of requests in the batch (top-k disabled). - k.masked_fill_( - torch.randint(0, - 2, (BATCH_SIZE, ), - generator=generator, - dtype=bool), VOCAB_SIZE) + # Sometimes disable top-p (p=1.0) + p_values.masked_fill_( + torch.randint(0, + 2, (BATCH_SIZE, ), + generator=generator, + dtype=torch.bool), 1.0) - # Top-k only implementation - result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None) + python_logits = apply_top_k_top_p( + logits=logits.clone(), + k=k_values, + p=p_values, + ) + python_probs = torch.softmax(python_logits, dim=-1) - # Top-p + top-k - no_op_top_p = torch.tensor([1.0]) - result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p) + # FlashInfer only exposed renorm interfaces for probs so convert first + flashinfer_probs = torch.softmax(logits.clone(), dim=-1) + flashinfer_probs = top_k_renorm_probs( + probs=flashinfer_probs, + top_k=k_values, + ) + flashinfer_probs = top_p_renorm_probs( + probs=flashinfer_probs, + top_p=p_values, + ) - assert torch.allclose(result1, result2) + # Compare the results + assert torch.allclose(python_probs, flashinfer_probs, 
atol=2e-2), \ + "FlashInfer and Python sampling implementations do not match!" diff --git a/tests/v1/sample/utils.py b/tests/v1/sample/utils.py index f540895bbf14..932b652aea32 100644 --- a/tests/v1/sample/utils.py +++ b/tests/v1/sample/utils.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 -import re from enum import Enum from typing import Optional +import regex as re + from vllm import CompletionOutput diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py new file mode 100644 index 000000000000..b49ac45f3129 --- /dev/null +++ b/tests/v1/spec_decode/test_eagle.py @@ -0,0 +1,307 @@ +# SPDX-License-Identifier: Apache-2.0 + +from unittest import mock + +import pytest +import torch + +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, + ParallelConfig, SchedulerConfig, SpeculativeConfig, + VllmConfig) +from vllm.v1.spec_decode.eagle import EagleProposer + +model_dir = "meta-llama/Llama-3.1-8B-Instruct" +eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" +eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" + + +def _create_proposer(method: str, k: int) -> EagleProposer: + model_config = ModelConfig(model=model_dir, + task="generate", + max_model_len=100, + tokenizer=model_dir, + tokenizer_mode="auto", + dtype="auto", + seed=None, + trust_remote_code=False) + + # Choose model directory based on method + draft_model_dir = eagle_dir if method == "eagle" else eagle3_dir + + speculative_config = SpeculativeConfig( + target_model_config=model_config, + target_parallel_config=ParallelConfig(), + model=draft_model_dir, + method=method, + num_speculative_tokens=k, + ) + + vllm_config = VllmConfig(model_config=model_config, + cache_config=CacheConfig(), + speculative_config=speculative_config, + device_config=DeviceConfig(device="cuda"), + parallel_config=ParallelConfig(), + load_config=LoadConfig(), + scheduler_config=SchedulerConfig()) + + return EagleProposer(vllm_config=vllm_config, device='cuda') + + +def 
test_prepare_inputs(): + """ + cu_target_query_lens: [0, a, a + b, a + b + c] + num_rejected_tokens: [n1, n2, n3] + num_tokens_per_req: [a - n1, b - n2, c - n3] + cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] + token_indices: [0, 1, ..., a - n1 - 1, + a, a + 1, ..., a + b - n2 - 1, + a + b, a + b + 1, ..., a + b + c - n3 - 1] + """ + device = torch.device('cuda') + + # a = 4, b = 7, c = 5 + # n1 = 1, n2 = 3, n3 = 2 + + # Cumulative lengths: [0, 4, 11, 16] + cu_target_query_lens = torch.tensor([0, 4, 11, 16], + dtype=torch.int32, + device=device) + + # Rejected tokens per request: [1, 3, 2] + num_rejected_tokens = torch.tensor([1, 3, 2], + dtype=torch.int32, + device=device) + + # Expected calculations: + # query_len_per_req = [4, 7, 5] + # num_tokens_per_req = [3, 4, 3] (after subtracting rejected tokens) + # Expected cumulative counts: [0, 3, 7, 10] + expected_cu_num_tokens = torch.tensor([0, 3, 7, 10], + dtype=torch.int32, + device=device) + + # Expected token indices (mapped from original positions): + # First request: indices 0, 1, 2 (keeping first 3 from positions 0-3) + # Second request: indices 4, 5, 6, 7 (keeping first 4 from positions 4-10) + # Third request: indices 11, 12, 13 (keeping first 3 from positions 11-15) + expected_token_indices = torch.tensor( + [ + 0, + 1, + 2, # First request: 3 tokens (4-1) + 4, + 5, + 6, + 7, # Second request: 4 tokens (7-3) + 11, + 12, + 13 # Third request: 3 tokens (5-2) + ], + dtype=torch.int32, + device=device) + + # n1 + n2 + n3 - a - b -c + num_tokens = cu_target_query_lens[-1].item() - num_rejected_tokens.sum( + ).item() + + cu_num_tokens, token_indices = EagleProposer.prepare_inputs( + cu_target_query_lens, num_rejected_tokens, num_tokens) + + assert torch.equal(cu_num_tokens, expected_cu_num_tokens) + assert token_indices.shape[0] == expected_cu_num_tokens[-1].item() + assert torch.equal(token_indices, expected_token_indices) + + +@pytest.mark.parametrize( + 
"method,proposer_helper,draft_model_dir,target_attribute_path", [ + ("eagle", lambda k: _create_proposer("eagle", k), eagle_dir, + ('lm_head', )), + ("eagle3", lambda k: _create_proposer("eagle3", k), eagle3_dir, + ('model', 'embed_tokens')), + ]) +@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group') +@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config') +@mock.patch('vllm.v1.spec_decode.eagle.get_model') +def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method, + proposer_helper, draft_model_dir, target_attribute_path): + + # Setup model mock + mock_model = mock.MagicMock() + mock_get_model.return_value = mock_model + + # Setup mocks for attention layers + target_attn_layers = { + "target_attn_1": mock.MagicMock(), + "target_attn_2": mock.MagicMock() + } + # Draft model has one extra attention layer compared to target model + all_attn_layers = { + **target_attn_layers, "draft_extra_attn": mock.MagicMock() + } + + # Make mock_get_layers return different values for each call + mock_get_layers.side_effect = [target_attn_layers, all_attn_layers] + + # Setup mock for pp group to return the appropriate value for world size + mock_pp_group = mock.MagicMock() + mock_pp_group.world_size = 2 if method == "eagle" else 1 + mock_get_pp_group.return_value = mock_pp_group + + # Setup target model with the appropriate attributes + target_model = mock.MagicMock() + + # Create the necessary attributes on the target model + current_obj = target_model + for i, attr in enumerate(target_attribute_path): + if i == len(target_attribute_path) - 1: + # Set the last attribute in the path to a MagicMock + setattr(current_obj, attr, mock.MagicMock()) + else: + # Create intermediate objects if needed + setattr(current_obj, attr, mock.MagicMock()) + current_obj = getattr(current_obj, attr) + + # Create proposer using the helper function + proposer = proposer_helper(k=8) + + # Call the method under test + proposer.load_model(target_model) + + # Verify 
common interactions + mock_get_model.assert_called_once() + + # Verify the specific attribute sharing based on the method + if method == "eagle": + assert proposer.model.lm_head == target_model.lm_head + else: + assert proposer.model.model.embed_tokens == \ + target_model.model.embed_tokens + + +@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8]) +def test_propose(num_speculative_tokens): + # Use GPU device + device = torch.device('cuda') + + # Setup test parameters + batch_size = 2 + seq_len_1 = 5 + seq_len_2 = 3 + total_tokens = seq_len_1 + seq_len_2 + vocab_size = 100 + + # Create proposer first so we can use its actual hidden_size + proposer = _create_proposer("eagle", num_speculative_tokens) + # Get the hidden_size from the proposer to ensure consistency + hidden_size = proposer.hidden_size + + # Helper to create deterministic logits that will produce specific tokens + def create_deterministic_logits(token_ids): + logits = torch.full((batch_size, vocab_size), -100.0, device=device) + for i, token_id in enumerate(token_ids): + logits[i, token_id] = 100.0 + return logits + + # We mock a model that returns deterministic logits + # Sequence 1: 42, 43, 44, ... + # Sequence 2: 60, 61, 62, ... 
+ base_token_ids = [42, 60] + + # Skip loading the model and replace it with a mock directly + # Create the mock model with deterministic outputs + model_mock = mock.MagicMock() + + # Setup for model forward calls + forward_returns = [] + for i in range(num_speculative_tokens): + if i == 0: + # First call uses all tokens + h_logits = torch.zeros(total_tokens, hidden_size, device=device) + h_states = torch.zeros(total_tokens, hidden_size, device=device) + else: + # Subsequent calls use batch_size tokens + h_logits = torch.zeros(batch_size, hidden_size, device=device) + h_states = torch.zeros(batch_size, hidden_size, device=device) + forward_returns.append((h_logits, h_states)) + + # For single token case, we only need the first item; + # for multi-token, we need the sequence + if num_speculative_tokens == 1: + model_mock.return_value = forward_returns[0] + else: + model_mock.side_effect = forward_returns + + # Setup for compute_logits calls + logits_returns = [] + for i in range(num_speculative_tokens): + # For each call, increment the base token IDs + current_tokens = [base_id + i for base_id in base_token_ids] + logits_returns.append(create_deterministic_logits(current_tokens)) + + if num_speculative_tokens == 1: + model_mock.compute_logits.return_value = logits_returns[0] + else: + model_mock.compute_logits.side_effect = logits_returns + + # Assign the mock to the proposer + proposer.model = model_mock + + # Assign draft attn_layer_names since load_model is not invoked + proposer.attn_layer_names = ["layer.0"] + + # Create input tensors + cu_num_tokens = torch.tensor([0, seq_len_1, total_tokens], + dtype=torch.int32, + device=device) + + target_token_ids = torch.randint(0, + vocab_size, (total_tokens, ), + device=device) + target_positions = torch.cat([ + torch.arange(seq_len_1, device=device), + torch.arange(seq_len_2, device=device) + ]) + target_hidden_states = torch.randn(total_tokens, + hidden_size, + device=device) + target_slot_mapping = torch.randint(0, + 
100, (total_tokens, ), + device=device) + next_token_ids = torch.randint(0, + vocab_size, (batch_size, ), + dtype=torch.int32, + device=device) + block_table = torch.randint(0, 10, (batch_size, 10), device=device) + + sampling_metadata = mock.MagicMock() + + # Call the method under test + result = proposer.propose(target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + target_slot_mapping=target_slot_mapping, + next_token_ids=next_token_ids, + cu_num_tokens=cu_num_tokens, + block_table=block_table, + sampling_metadata=sampling_metadata) + + assert result.shape == (batch_size, num_speculative_tokens) + + # Create expected tokens based on our token pattern + if num_speculative_tokens == 1: + # Example for num_speculative_tokens=1: + # [[42], [60]] + expected_tokens = torch.tensor( + [[base_token_ids[0]], [base_token_ids[1]]], device=device) + else: + # Example for num_speculative_tokens=3: + # [[42, 43, 44], [60, 61, 62]] + expected_tokens = torch.zeros((batch_size, num_speculative_tokens), + dtype=torch.int64, + device=device) + for i in range(batch_size): + for j in range(num_speculative_tokens): + expected_tokens[i, j] = base_token_ids[i] + j + + # Verify all tokens match our expectations + assert torch.equal(result, expected_tokens) diff --git a/tests/v1/structured_output/test_utils.py b/tests/v1/structured_output/test_utils.py index 1cefe8726df7..ffc0bceeee49 100644 --- a/tests/v1/structured_output/test_utils.py +++ b/tests/v1/structured_output/test_utils.py @@ -57,14 +57,6 @@ def unsupported_array_schemas(): "type": "array", "maxContains": 5 }, - { - "type": "array", - "minItems": 1 - }, - { - "type": "array", - "maxItems": 10 - }, ] diff --git a/tests/v1/test_metrics_reader.py b/tests/v1/test_metrics_reader.py new file mode 100644 index 000000000000..68539c80b59c --- /dev/null +++ b/tests/v1/test_metrics_reader.py @@ -0,0 +1,112 @@ +# SPDX-License-Identifier: Apache-2.0 + +import prometheus_client 
+import pytest + +from vllm.v1.metrics.reader import (Counter, Gauge, Histogram, Vector, + get_metrics_snapshot) + + +@pytest.fixture(autouse=True) +def test_registry(monkeypatch): + # Use a custom registry for tests + test_registry = prometheus_client.CollectorRegistry(auto_describe=True) + monkeypatch.setattr("vllm.v1.metrics.reader.REGISTRY", test_registry) + return test_registry + + +@pytest.mark.parametrize("num_engines", [1, 4]) +def test_gauge_metric(test_registry, num_engines): + g = prometheus_client.Gauge("vllm:test_gauge", + "Test gauge metric", + labelnames=["model", "engine_index"], + registry=test_registry) + for i in range(num_engines): + g.labels(model="foo", engine_index=str(i)).set(98.5) + + metrics = get_metrics_snapshot() + assert len(metrics) == num_engines + engine_labels = [str(i) for i in range(num_engines)] + for m in metrics: + assert isinstance(m, Gauge) + assert m.name == "vllm:test_gauge" + assert m.value == 98.5 + assert m.labels["model"] == "foo" + assert m.labels["engine_index"] in engine_labels + engine_labels.remove(m.labels["engine_index"]) + + +@pytest.mark.parametrize("num_engines", [1, 4]) +def test_counter_metric(test_registry, num_engines): + c = prometheus_client.Counter("vllm:test_counter", + "Test counter metric", + labelnames=["model", "engine_index"], + registry=test_registry) + for i in range(num_engines): + c.labels(model="bar", engine_index=str(i)).inc(19) + + metrics = get_metrics_snapshot() + assert len(metrics) == num_engines + engine_labels = [str(i) for i in range(num_engines)] + for m in metrics: + assert isinstance(m, Counter) + assert m.name == "vllm:test_counter" + assert m.value == 19 + assert m.labels["model"] == "bar" + assert m.labels["engine_index"] in engine_labels + engine_labels.remove(m.labels["engine_index"]) + + +@pytest.mark.parametrize("num_engines", [1, 4]) +def test_histogram_metric(test_registry, num_engines): + h = prometheus_client.Histogram("vllm:test_histogram", + "Test histogram metric", 
+ labelnames=["model", "engine_index"], + buckets=[10, 20, 30, 40, 50], + registry=test_registry) + for i in range(num_engines): + hist = h.labels(model="blaa", engine_index=str(i)) + hist.observe(42) + hist.observe(21) + hist.observe(7) + + metrics = get_metrics_snapshot() + assert len(metrics) == num_engines + engine_labels = [str(i) for i in range(num_engines)] + for m in metrics: + assert isinstance(m, Histogram) + assert m.name == "vllm:test_histogram" + assert m.count == 3 + assert m.sum == 70 + assert m.buckets["10.0"] == 1 + assert m.buckets["20.0"] == 1 + assert m.buckets["30.0"] == 2 + assert m.buckets["40.0"] == 2 + assert m.buckets["50.0"] == 3 + assert m.labels["model"] == "blaa" + assert m.labels["engine_index"] in engine_labels + engine_labels.remove(m.labels["engine_index"]) + + +@pytest.mark.parametrize("num_engines", [1, 4]) +def test_vector_metric(test_registry, num_engines): + c = prometheus_client.Counter( + "vllm:spec_decode_num_accepted_tokens_per_pos", + "Vector-like counter metric", + labelnames=["position", "model", "engine_index"], + registry=test_registry) + for i in range(num_engines): + c.labels(position="0", model="llama", engine_index=str(i)).inc(10) + c.labels(position="1", model="llama", engine_index=str(i)).inc(5) + c.labels(position="2", model="llama", engine_index=str(i)).inc(1) + + metrics = get_metrics_snapshot() + assert len(metrics) == num_engines + engine_labels = [str(i) for i in range(num_engines)] + for m in metrics: + assert isinstance(m, Vector) + assert m.name == "vllm:spec_decode_num_accepted_tokens_per_pos" + assert m.values == [10, 5, 1] + assert m.labels["model"] == "llama" + assert m.labels["engine_index"] in engine_labels + engine_labels.remove(m.labels["engine_index"]) diff --git a/tests/v1/test_oracle.py b/tests/v1/test_oracle.py index c34c673e985e..1b77417a1bd3 100644 --- a/tests/v1/test_oracle.py +++ b/tests/v1/test_oracle.py @@ -12,7 +12,7 @@ "openai/whisper-large-v3", # transcription 
"facebook/bart-large-cnn", # encoder decoder "mistralai/Mamba-Codestral-7B-v0.1", # mamba - "hmellor/bamba-tiny-random", # hybrid + "hmellor/tiny-random-BambaForCausalLM", # hybrid "BAAI/bge-m3", # embedding ] diff --git a/tests/v1/test_stats.py b/tests/v1/test_stats.py deleted file mode 100644 index 48419d8a2791..000000000000 --- a/tests/v1/test_stats.py +++ /dev/null @@ -1,302 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import pytest - -from vllm.sampling_params import SamplingParams -from vllm.v1.stats.common import RequestStats, RequestStatsUpdate - - -def make_update( - request_id: str, - update_type: RequestStatsUpdate.Type, - monotonic_ts_s: float, - **kwargs, -): - if update_type == RequestStatsUpdate.Type.INPUT_PROCESSED: - kwargs.setdefault("sampling_params", SamplingParams(n=1)) - kwargs.setdefault("num_prompt_tokens", 10) - elif update_type == RequestStatsUpdate.Type.PREFILLING: - kwargs.setdefault("num_computed_tokens", 10) - kwargs.setdefault("num_cached_tokens", 10) - elif update_type == RequestStatsUpdate.Type.DETOKENIZED: - kwargs.setdefault("num_new_tokens", 10) - elif update_type == RequestStatsUpdate.Type.FINISHED: - kwargs.setdefault("finish_reason", "test_reason") - - return RequestStatsUpdate( - request_id=request_id, - type=update_type, - monotonic_ts_s=monotonic_ts_s, - **kwargs, - ) - - -def test_invalid_request_update(): - request_id = "test_request" - update_specific_required_fields = { - RequestStatsUpdate.Type.INPUT_PROCESSED: [ - "sampling_params", - "num_prompt_tokens", - ], - RequestStatsUpdate.Type.PREFILLING: [ - "num_computed_tokens", - "num_cached_tokens", - ], - RequestStatsUpdate.Type.DETOKENIZED: ["num_new_tokens"], - RequestStatsUpdate.Type.FINISHED: ["finish_reason"], - } - - # Missing a required field should raise an assertion error. - for update_type in RequestStatsUpdate.Type: - required_fields = update_specific_required_fields.get(update_type, []) - - # Try to miss one of the required fields. 
- kwargs = {field: object() for field in required_fields} - for field in required_fields: - copy_kwargs = kwargs.copy() - copy_kwargs.pop(field) - with pytest.raises(ValueError): - RequestStatsUpdate( - request_id=request_id, - type=update_type, - **copy_kwargs, - ) - - -def test_invalid_request_update_transition(): - # Test invalid transition type. - for src in RequestStatsUpdate.Type: - for dst in RequestStatsUpdate.Type: - if dst not in RequestStatsUpdate._VALID_TRANSITIONS[src]: - with pytest.raises(AssertionError): - RequestStatsUpdate.check_valid_update( - make_update( - update_type=dst, - request_id="test_request", - monotonic_ts_s=1, - ), - last_update_type=src, - last_updated_ts_s=0, - ) - else: - RequestStatsUpdate.check_valid_update( - make_update( - request_id="test_request", - update_type=dst, - monotonic_ts_s=1, - ), - last_update_type=src, - last_updated_ts_s=0, - ) - - # Test invalid timestamp. - with pytest.raises(AssertionError): - RequestStatsUpdate.check_valid_update( - make_update( - request_id="test_request", - update_type=RequestStatsUpdate.Type.ARRIVED, - monotonic_ts_s=1, - ), - last_update_type=None, - last_updated_ts_s=2, - ) - - -def test_lifecycle_updates(): - request_id = "test_request" - stats = RequestStats(request_id=request_id) - - # Test the below scenario: - arrived_ts = 0 - input_processed_ts = 1 - queued_ts = 2 - prefilling_ts = 3 - decoded_ts = 5 - detokenized_ts = 6 - decoded_2_ts = 7 - detokenized_2_ts = 8 - preempted_ts = 9 - resumed_ts = 10 - decoded_3_ts = 11 - detokenized_3_ts = 12 - finished_ts = 13 - - # Test ARRIVED - arrived_update = RequestStatsUpdate( - request_id=request_id, - type=RequestStatsUpdate.Type.ARRIVED, - monotonic_ts_s=arrived_ts, - ) - stats.update_from(arrived_update) - assert stats.arrival_ts_s == arrived_ts - assert stats.last_updated_ts_s == arrived_ts - - # Test INPUT_PROCESSED - sampling_params = SamplingParams(n=1) - input_processed_update = RequestStatsUpdate( - request_id=request_id, - 
type=RequestStatsUpdate.Type.INPUT_PROCESSED, - monotonic_ts_s=input_processed_ts, - sampling_params=sampling_params, - num_prompt_tokens=6, - ) - stats.update_from(input_processed_update) - assert stats.input_processor_end_ts_s == input_processed_ts - assert stats.last_updated_ts_s == input_processed_ts - assert stats.num_prompt_tokens == 6 - assert stats.sampling_params == sampling_params - - assert stats.first_token_ts_s is None - assert stats.prefill_ts_s is None - - # Test QUEUED - queued_update = RequestStatsUpdate( - request_id=request_id, - type=RequestStatsUpdate.Type.QUEUED, - monotonic_ts_s=queued_ts, - ) - stats.update_from(queued_update) - assert stats.queued_ts_s == queued_ts - assert stats.last_updated_ts_s == queued_ts - - # Test PREFILLING - prefilling_update = RequestStatsUpdate( - request_id=request_id, - type=RequestStatsUpdate.Type.PREFILLING, - monotonic_ts_s=prefilling_ts, - num_computed_tokens=3, - num_cached_tokens=1, - ) - stats.update_from(prefilling_update) - assert stats.prefill_ts_s == prefilling_ts - assert stats.num_computed_tokens == 3 - assert stats.num_cached_tokens == 1 - assert stats.queue_duration_s == prefilling_ts - queued_ts - - # Test DECODING - decoded_update = RequestStatsUpdate( - request_id=request_id, - type=RequestStatsUpdate.Type.DECODING, - monotonic_ts_s=decoded_ts, - ) - stats.update_from(decoded_update) - assert stats.last_updated_ts_s == decoded_ts - - # Test DETOKENIZED - detokenized_update = RequestStatsUpdate( - request_id=request_id, - type=RequestStatsUpdate.Type.DETOKENIZED, - monotonic_ts_s=detokenized_ts, - num_new_tokens=1, - ) - stats.update_from(detokenized_update) - assert stats.last_updated_ts_s == detokenized_ts - assert stats.num_output_tokens == 1 - # Since arrival - assert stats.first_token_latency_s == detokenized_ts - arrived_ts - # Since first scheduled - assert stats.prefill_latency_s == detokenized_ts - prefilling_ts - - # Test another DECODING and DETOKENIZED should - # yield correct inter 
token latency - decoded_update = RequestStatsUpdate( - request_id=request_id, - type=RequestStatsUpdate.Type.DECODING, - monotonic_ts_s=decoded_2_ts, - ) - stats.update_from(decoded_update) - - detokenized_update = RequestStatsUpdate( - request_id=request_id, - type=RequestStatsUpdate.Type.DETOKENIZED, - monotonic_ts_s=detokenized_2_ts, - num_new_tokens=1, - ) - stats.update_from(detokenized_update) - assert stats.output_token_latency_s_lst == [ - detokenized_2_ts - detokenized_ts, - ] - assert stats.num_output_tokens == 2 - - # Test PREEMPTED - preempted_update = RequestStatsUpdate( - request_id=request_id, - type=RequestStatsUpdate.Type.PREEMPTED, - monotonic_ts_s=preempted_ts, - ) - stats.update_from(preempted_update) - assert stats.last_updated_ts_s == preempted_ts - assert stats.preempted_ts_s_lst == [preempted_ts] - # States should be reset - assert stats.num_computed_tokens == 0 - assert stats.num_cached_tokens == 0 - # These states should not be reset - assert stats.num_output_tokens == 2 - assert stats.output_token_latency_s_lst == [ - detokenized_2_ts - detokenized_ts, - ] - assert stats.prefill_latency_s == prefilling_ts - arrived_ts - assert stats.num_prompt_tokens == 6 - assert stats.prefill_start_ts_s_lst == [prefilling_ts] - - # Test resumed - resumed_update = RequestStatsUpdate( - request_id=request_id, - type=RequestStatsUpdate.Type.PREFILLING, - monotonic_ts_s=resumed_ts, - num_computed_tokens=6, - num_cached_tokens=2, - ) - stats.update_from(resumed_update) - # prefill timestamp should not be updated since it's a resumed prefill - assert stats.prefill_ts_s == prefilling_ts - assert stats.num_computed_tokens == 6 - assert stats.num_cached_tokens == 2 - assert stats.prefill_start_ts_s_lst == [ - prefilling_ts, - resumed_ts, - ] - assert stats.last_updated_ts_s == resumed_ts - - # Test another DECODED/DETOKENIZED should yield correct first token latency. 
- decoded_update = RequestStatsUpdate( - request_id=request_id, - type=RequestStatsUpdate.Type.DECODING, - monotonic_ts_s=decoded_3_ts, - ) - detokenized_update = RequestStatsUpdate( - request_id=request_id, - type=RequestStatsUpdate.Type.DETOKENIZED, - monotonic_ts_s=detokenized_3_ts, - num_new_tokens=1, - ) - stats.update_from(decoded_update) - stats.update_from(detokenized_update) - assert stats.first_token_ts_s == detokenized_ts - arrived_ts - assert stats.num_output_tokens == 3 - assert stats.output_token_latency_s_lst == [ - detokenized_2_ts - detokenized_ts, - detokenized_3_ts - detokenized_2_ts, - ] - - # Test FINISHED - finished_update = RequestStatsUpdate( - request_id=request_id, - type=RequestStatsUpdate.Type.FINISHED, - monotonic_ts_s=finished_ts, - finish_reason="test_reason", - ) - stats.update_from(finished_update) - assert stats.last_updated_ts_s == finished_ts - assert stats.e2e_latency_s == finished_ts - arrived_ts - assert stats.inference_latency_s == finished_ts - prefilling_ts - assert stats.prefill_latency_s == detokenized_ts - prefilling_ts - assert stats.decode_latency_s == finished_ts - detokenized_ts - assert stats.first_token_latency_s == detokenized_ts - arrived_ts - assert stats.queue_duration_s == prefilling_ts - queued_ts - assert stats.is_finished - assert stats.finish_reason == "test_reason" - - # TODO(rickyx): Add model forward/execute time. 
- assert stats.model_forward_duration_s == 0.0 - assert stats.model_execute_duration_s == 0.0 diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 915ec2914a82..27741bd156be 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -9,9 +9,11 @@ from vllm.sampling_params import SamplingParams from vllm.utils import is_pin_memory_available, make_tensor_with_pad +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheGroupSpec, KVCacheTensor) from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.worker.gpu_input_batch import (BlockTable, CachedRequestState, - InputBatch) +from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch VOCAB_SIZE = 1024 NUM_OUTPUT_TOKENS = 20 @@ -22,6 +24,27 @@ MAX_NUM_PROMPT_TOKENS = 64 +def get_kv_cache_config() -> KVCacheConfig: + return KVCacheConfig( + num_blocks=10, + tensors={ + "layer.0": KVCacheTensor(size=1024), + }, + kv_cache_groups=[ + KVCacheGroupSpec( + layer_names=["layer.0"], + kv_cache_spec=FullAttentionSpec( + block_size=1, + num_kv_heads=1, + head_size=16, + dtype=torch.float16, + use_mla=False, + ), + ), + ], + ) + + def _compare_objs(obj1, obj2): attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a))) attr_names = set([ @@ -41,6 +64,10 @@ def _compare_objs(obj1, obj2): elif isinstance(a, np.ndarray): if np.allclose(a, b): is_same = True + elif isinstance(a, MultiGroupBlockTable): + for a_i, b_i in zip(a.block_tables, b.block_tables): + _compare_objs(a_i, b_i) + is_same = True elif isinstance(a, (BlockTable, SamplingMetadata)): _compare_objs(a, b) is_same = True # if we make it here must be same @@ -198,7 +225,7 @@ def _construct_cached_request_state(req_id_suffix: int): sampling_params=_create_sampling_params(), mm_inputs=[], mm_positions=[], - block_ids=[], + block_ids=[[]], 
generator=None, num_computed_tokens=len(output_token_ids), output_token_ids=output_token_ids, @@ -220,10 +247,11 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): input_batch: InputBatch = InputBatch( max_num_reqs=batch_size, max_model_len=1024, - max_num_blocks_per_req=10, + max_num_batched_tokens=1024, device=torch.device(device), pin_memory=is_pin_memory_available(), vocab_size=1024, + block_size=1, ) reqs: list[CachedRequestState] = [] req_id_reqs = {} @@ -309,18 +337,20 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, input_batch: InputBatch = InputBatch( max_num_reqs=batch_size, max_model_len=1024, - max_num_blocks_per_req=10, + max_num_batched_tokens=1024, device=torch.device(device), pin_memory=is_pin_memory_available(), vocab_size=1024, + block_size=1, ) ref_input_batch: InputBatch = InputBatch( max_num_reqs=batch_size, max_model_len=1024, - max_num_blocks_per_req=10, + max_num_batched_tokens=1024, device=torch.device(device), pin_memory=is_pin_memory_available(), vocab_size=1024, + block_size=1, ) reqs: list[CachedRequestState] = [] diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 68e34cfacc58..b8c3d88617d0 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -1,14 +1,53 @@ # SPDX-License-Identifier: Apache-2.0 + import pytest -from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig +from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, + SchedulerConfig, VllmConfig) from vllm.sampling_params import SamplingParams from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheGroupSpec, KVCacheTensor) from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_model_runner import 
GPUModelRunner +def initialize_kv_cache(runner: GPUModelRunner): + """ + Only perform necessary steps in GPUModelRunner.initialize_kv_cache() + """ + kv_cache_config = KVCacheConfig( + num_blocks=10, + tensors={ + "layer.0": KVCacheTensor(size=1024), + }, + kv_cache_groups=[ + KVCacheGroupSpec( + layer_names=["layer.0"], + kv_cache_spec=FullAttentionSpec( + block_size=16, + num_kv_heads=runner.model_config.get_num_kv_heads( + runner.parallel_config), + head_size=runner.model_config.get_head_size(), + dtype=runner.kv_cache_dtype, + use_mla=False, + )) + ]) + runner.kv_cache_config = kv_cache_config + runner.input_batch = InputBatch( + max_num_reqs=runner.max_num_reqs, + max_model_len=runner.max_model_len, + max_num_batched_tokens=runner.max_num_tokens, + device=runner.device, + pin_memory=runner.pin_memory, + vocab_size=runner.model_config.get_vocab_size(), + block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size, + ) + runner.initialize_attn_backend(kv_cache_config) + + @pytest.fixture def model_runner(): scheduler_config = SchedulerConfig( @@ -31,14 +70,18 @@ def model_runner(): swap_space=0, cache_dtype="auto", ) + parallel_config = ParallelConfig() vllm_config = VllmConfig( model_config=model_config, cache_config=cache_config, scheduler_config=scheduler_config, + parallel_config=parallel_config, ) device = "cuda" - return GPUModelRunner(vllm_config, device) + runner = GPUModelRunner(vllm_config, device) + initialize_kv_cache(runner) + return runner def _schedule_new_request(*req_ids: str) -> SchedulerOutput: @@ -54,7 +97,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: mm_hashes=[], mm_positions=[], sampling_params=SamplingParams(), - block_ids=[0], + block_ids=[[0]], num_computed_tokens=0, lora_request=None, )) @@ -92,13 +135,14 @@ def _is_sampling_metadata_changed(model_runner, def _is_req_state_block_table_match(model_runner, req_id: str) -> bool: req_index = model_runner.input_batch.req_id_to_index[req_id] - block_table = 
model_runner.input_batch.block_table + block_table = model_runner.input_batch.block_table[0] req_state = model_runner.requests[req_id] - if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids): + if block_table.num_blocks_per_row[req_index] != len( + req_state.block_ids[0]): return False num_blocks = block_table.num_blocks_per_row[req_index] return (block_table.block_table_np[req_index, :num_blocks] == - req_state.block_ids).all() + req_state.block_ids[0]).all() def test_update_states_new_request(model_runner): @@ -181,7 +225,7 @@ def test_update_states_request_resumed(model_runner): req_id=req_id, resumed_from_preemption=False, new_token_ids=[], - new_block_ids=[], + new_block_ids=[[]], num_computed_tokens=0, ) diff --git a/tests/weight_loading/models-large.txt b/tests/weight_loading/models-large.txt index 9c1c11da572e..ee98aed2684d 100644 --- a/tests/weight_loading/models-large.txt +++ b/tests/weight_loading/models-large.txt @@ -4,4 +4,5 @@ compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main compressed-tensors, nm-testing/test-w4a16-mixtral-actorder-group, main gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, gptq-8bit-128g-actorder_True -awq_marlin, casperhansen/deepseek-coder-v2-instruct-awq, main \ No newline at end of file +awq_marlin, casperhansen/deepseek-coder-v2-instruct-awq, main +compressed-tensors, RedHatAI/Llama-4-Scout-17B-16E-Instruct-quantized.w4a16, main \ No newline at end of file diff --git a/tools/check_triton_import.py b/tools/check_triton_import.py new file mode 100644 index 000000000000..18c9726a11ac --- /dev/null +++ b/tools/check_triton_import.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 +import subprocess +import sys + +import regex as re + +FORBIDDEN_IMPORT_RE = re.compile(r"^(from|import)\s+triton(\s|\.|$)") + +# the way allowed to import triton +ALLOWED_LINES = { + "from vllm.triton_utils import triton", + "from vllm.triton_utils 
import tl", + "from vllm.triton_utils import tl, triton", +} + + +def is_forbidden_import(line: str) -> bool: + stripped = line.strip() + return bool( + FORBIDDEN_IMPORT_RE.match(stripped)) and stripped not in ALLOWED_LINES + + +def parse_diff(diff: str) -> list[str]: + violations = [] + current_file = None + current_lineno = None + + for line in diff.splitlines(): + if line.startswith("+++ b/"): + current_file = line[6:] + elif line.startswith("@@"): + match = re.search(r"\+(\d+)", line) + if match: + current_lineno = int( + match.group(1)) - 1 # next "+ line" is here + elif line.startswith("+") and not line.startswith("++"): + current_lineno += 1 + code_line = line[1:] + if is_forbidden_import(code_line): + violations.append( + f"{current_file}:{current_lineno}: {code_line.strip()}") + return violations + + +def get_diff(diff_type: str) -> str: + if diff_type == "staged": + return subprocess.check_output( + ["git", "diff", "--cached", "--unified=0"], text=True) + elif diff_type == "unstaged": + return subprocess.check_output(["git", "diff", "--unified=0"], + text=True) + else: + raise ValueError(f"Unknown diff_type: {diff_type}") + + +def main(): + all_violations = [] + for diff_type in ["staged", "unstaged"]: + try: + diff_output = get_diff(diff_type) + violations = parse_diff(diff_output) + all_violations.extend(violations) + except subprocess.CalledProcessError as e: + print(f"[{diff_type}] Git diff failed: {e}", file=sys.stderr) + + if all_violations: + print("โŒ Forbidden direct `import triton` detected." 
+ " โžค Use `from vllm.triton_utils import triton` instead.\n") + for v in all_violations: + print(f"โŒ {v}") + return 1 + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tools/enforce_regex_import.py b/tools/enforce_regex_import.py new file mode 100644 index 000000000000..b55c4a94eac8 --- /dev/null +++ b/tools/enforce_regex_import.py @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import subprocess +from pathlib import Path + +import regex as re + +FORBIDDEN_PATTERNS = re.compile( + r'^\s*(?:import\s+re(?:$|\s|,)|from\s+re\s+import)') +ALLOWED_PATTERNS = [ + re.compile(r'^\s*import\s+regex\s+as\s+re\s*$'), + re.compile(r'^\s*import\s+regex\s*$'), +] + + +def get_staged_python_files() -> list[str]: + try: + result = subprocess.run( + ['git', 'diff', '--cached', '--name-only', '--diff-filter=AM'], + capture_output=True, + text=True, + check=True) + files = result.stdout.strip().split( + '\n') if result.stdout.strip() else [] + return [f for f in files if f.endswith('.py')] + except subprocess.CalledProcessError: + return [] + + +def is_forbidden_import(line: str) -> bool: + line = line.strip() + return bool( + FORBIDDEN_PATTERNS.match(line) + and not any(pattern.match(line) for pattern in ALLOWED_PATTERNS)) + + +def check_file(filepath: str) -> list[tuple[int, str]]: + violations = [] + try: + with open(filepath, encoding='utf-8') as f: + for line_num, line in enumerate(f, 1): + if is_forbidden_import(line): + violations.append((line_num, line.strip())) + except (OSError, UnicodeDecodeError): + pass + return violations + + +def main() -> int: + files = get_staged_python_files() + if not files: + return 0 + + total_violations = 0 + + for filepath in files: + if not Path(filepath).exists(): + continue + + violations = check_file(filepath) + if violations: + print(f"\nโŒ {filepath}:") + for line_num, line in violations: + print(f" Line {line_num}: {line}") + total_violations += 1 + + if 
total_violations > 0: + print(f"\n๐Ÿ’ก Found {total_violations} violation(s).") + print("โŒ Please replace 'import re' with 'import regex as re'") + print( + " Also replace 'from re import ...' with 'from regex import ...'" + ) # noqa: E501 + print("โœ… Allowed imports:") + print(" - import regex as re") + print(" - import regex") # noqa: E501 + return 1 + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tools/ep_kernels/README.md b/tools/ep_kernels/README.md new file mode 100644 index 000000000000..5c98e999da33 --- /dev/null +++ b/tools/ep_kernels/README.md @@ -0,0 +1,27 @@ +Large-scale cluster-level expert parallel, as described in the [DeepSeek-V3 Technical Report](http://arxiv.org/abs/2412.19437), is an efficient way to deploy sparse MoE models with many experts. However, such deployment requires many components beyond a normal Python package, including system package support and system driver support. It is impossible to bundle all these components into a Python package. + +Here we break down the requirements in 3 steps: +1. Build and install the Python libraries (both [pplx-kernels](https://github.com/ppl-ai/pplx-kernels) and [DeepEP](https://github.com/deepseek-ai/DeepEP)), including necessary dependencies like NVSHMEM. This step does not require any privileged access. Any user can do this. +2. Build and install the system libraries (GDR Copy). This step requires root access. You can do it inside a Docker container so that they can be shipped as a single image. +3. Build and install the system drivers (GDR Copy, and necessary modifications to NVIDIA driver to enable IBGDA). This step requires root access, and must be done on the host machine. + +2 and 3 are necessary for multi-node deployment. + +All scripts accept a positional argument as workspace path for staging the build, defaulting to `$(pwd)/ep_kernels_workspace`. 
+ +# Usage + +## Single-node + +```bash +bash install_python_libraries.sh +``` + +## Multi-node + +```bash +bash install_python_libraries.sh +sudo bash install_system_libraries.sh +sudo bash install_system_drivers.sh +sudo reboot # Reboot is required to load the new driver +``` diff --git a/tools/ep_kernels/install_python_libraries.sh b/tools/ep_kernels/install_python_libraries.sh new file mode 100644 index 000000000000..e5632f4b5875 --- /dev/null +++ b/tools/ep_kernels/install_python_libraries.sh @@ -0,0 +1,77 @@ +set -ex + +# prepare workspace directory +WORKSPACE=$1 +if [ -z "$WORKSPACE" ]; then + export WORKSPACE=$(pwd)/ep_kernels_workspace +fi + +if [ ! -d "$WORKSPACE" ]; then + mkdir -p $WORKSPACE +fi + +# install dependencies if not installed +pip3 install cmake torch ninja + +# build gdrcopy, required by nvshmem +pushd $WORKSPACE +wget https://github.com/NVIDIA/gdrcopy/archive/refs/tags/v2.4.4.tar.gz +mkdir -p gdrcopy_src +tar -xvf v2.4.4.tar.gz -C gdrcopy_src --strip-components=1 +pushd gdrcopy_src +make -j$(nproc) +make prefix=$WORKSPACE/gdrcopy_install install +popd + +# build nvshmem +pushd $WORKSPACE +mkdir -p nvshmem_src +wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.2.5/source/nvshmem_src_3.2.5-1.txz +tar -xvf nvshmem_src_3.2.5-1.txz -C nvshmem_src --strip-components=1 +pushd nvshmem_src +wget https://github.com/deepseek-ai/DeepEP/raw/main/third-party/nvshmem.patch +git init +git apply -vvv nvshmem.patch + +# assume CUDA_HOME is set correctly +export GDRCOPY_HOME=$WORKSPACE/gdrcopy_install +export NVSHMEM_SHMEM_SUPPORT=0 +export NVSHMEM_UCX_SUPPORT=0 +export NVSHMEM_USE_NCCL=0 +export NVSHMEM_IBGDA_SUPPORT=1 +export NVSHMEM_PMIX_SUPPORT=0 +export NVSHMEM_TIMEOUT_DEVICE_POLLING=0 +export NVSHMEM_USE_GDRCOPY=1 +export NVSHMEM_IBRC_SUPPORT=1 + +# remove MPI dependency +export NVSHMEM_BUILD_TESTS=0 +export NVSHMEM_BUILD_EXAMPLES=0 +export NVSHMEM_MPI_SUPPORT=0 + +cmake -S . 
-B $WORKSPACE/nvshmem_build/ -DCMAKE_INSTALL_PREFIX=$WORKSPACE/nvshmem_install + +cd $WORKSPACE/nvshmem_build/ +make -j$(nproc) +make install + +popd + +export CMAKE_PREFIX_PATH=$WORKSPACE/nvshmem_install:$CMAKE_PREFIX_PATH + +# build and install pplx, require pytorch installed +pushd $WORKSPACE +git clone https://github.com/ppl-ai/pplx-kernels +cd pplx-kernels +# see https://github.com/pypa/pip/issues/9955#issuecomment-838065925 +# PIP_NO_BUILD_ISOLATION=0 disables build isolation +PIP_NO_BUILD_ISOLATION=0 TORCH_CUDA_ARCH_LIST=9.0a+PTX pip install -vvv -e . +popd + +# build and install deepep, require pytorch installed +pushd $WORKSPACE +git clone https://github.com/deepseek-ai/DeepEP +cd DeepEP +export NVSHMEM_DIR=$WORKSPACE/nvshmem_install +PIP_NO_BUILD_ISOLATION=0 pip install -vvv -e . +popd diff --git a/tools/ep_kernels/install_system_drivers.sh b/tools/ep_kernels/install_system_drivers.sh new file mode 100644 index 000000000000..8b0669ef404f --- /dev/null +++ b/tools/ep_kernels/install_system_drivers.sh @@ -0,0 +1,24 @@ +set -ex + +# prepare workspace directory +WORKSPACE=$1 +if [ -z "$WORKSPACE" ]; then + export WORKSPACE=$(pwd)/ep_kernels_workspace +fi + +if [ ! 
-d "$WORKSPACE" ]; then + mkdir -p $WORKSPACE +fi + +# build and install gdrcopy driver +pushd $WORKSPACE +cd gdrcopy_src +./insmod.sh +# run gdrcopy_copybw to test the installation +$WORKSPACE/gdrcopy_install/bin/gdrcopy_copybw + +# turn on IBGDA +echo 'options nvidia NVreg_EnableStreamMemOPs=1 NVreg_RegistryDwords="PeerMappingOverride=1;"' | tee -a /etc/modprobe.d/nvidia.conf +update-initramfs -u + +echo "Please reboot the system to apply the changes" diff --git a/tools/ep_kernels/install_system_libraries.sh b/tools/ep_kernels/install_system_libraries.sh new file mode 100644 index 000000000000..c148d5443900 --- /dev/null +++ b/tools/ep_kernels/install_system_libraries.sh @@ -0,0 +1,18 @@ +set -ex + +# prepare workspace directory +WORKSPACE=$1 +if [ -z "$WORKSPACE" ]; then + export WORKSPACE=$(pwd)/ep_kernels_workspace +fi + +if [ ! -d "$WORKSPACE" ]; then + mkdir -p $WORKSPACE +fi + +# build and install gdrcopy system packages +pushd $WORKSPACE +cd gdrcopy_src/packages +apt install devscripts -y +CUDA=${CUDA_HOME:-/usr/local/cuda} ./build-deb-packages.sh +dpkg -i *.deb diff --git a/tools/install_nixl.sh b/tools/install_nixl.sh new file mode 100644 index 000000000000..56717cfb77f7 --- /dev/null +++ b/tools/install_nixl.sh @@ -0,0 +1,109 @@ +#!/bin/bash +# Usage: ./install_nixl.sh [--force] + +FORCE=false +if [ "$1" == "--force" ]; then + FORCE=true +fi + +SUDO=false +if command -v sudo >/dev/null 2>&1 && sudo -n true 2>/dev/null; then + SUDO=true +fi + +ARCH=$(uname -m) + +ROOT_DIR="/usr/local" +mkdir -p "$ROOT_DIR" +GDR_HOME="$ROOT_DIR/gdrcopy" +UCX_HOME="$ROOT_DIR/ucx" +NIXL_HOME="$ROOT_DIR/nixl" +CUDA_HOME=/usr/local/cuda + +export PATH="$GDR_HOME/bin:$UCX_HOME/bin:$NIXL_HOME/bin:$PATH" +export LD_LIBRARY_PATH="$GDR_HOME/lib:$UCX_HOME/lib:$NIXL_HOME/lib/$ARCH-linux-gnu:$LD_LIBRARY_PATH" + +TEMP_DIR="nixl_installer" +mkdir -p "$TEMP_DIR" +cd "$TEMP_DIR" + +pip install meson ninja pybind11 + +if [ ! 
-e "/dev/gdrdrv" ] || [ "$FORCE" = true ]; then + echo "Installing gdrcopy\n" + wget https://github.com/NVIDIA/gdrcopy/archive/refs/tags/v2.5.tar.gz + tar xzf v2.5.tar.gz; rm v2.5.tar.gz + cd gdrcopy-2.5 + make prefix=$GDR_HOME CUDA=$CUDA_HOME all install + + if $SUDO; then + echo "Running insmod.sh with sudo" + sudo ./insmod.sh + else + echo "Skipping insmod.sh - sudo not available" + echo "Please run 'sudo ./gdrcopy-2.5/insmod.sh' manually if needed" + fi + + cd .. +else + echo "Found /dev/gdrdrv. Skipping gdrcopy installation" +fi + +if ! command -v ucx_info &> /dev/null || [ "$FORCE" = true ]; then + echo "Installing UCX" + wget https://github.com/openucx/ucx/releases/download/v1.18.0/ucx-1.18.0.tar.gz + tar xzf ucx-1.18.0.tar.gz; rm ucx-1.18.0.tar.gz + cd ucx-1.18.0 + + # Checking Mellanox NICs + MLX_OPTS="" + if lspci | grep -i mellanox > /dev/null || command -v ibstat > /dev/null; then + echo "Mellanox NIC detected, adding Mellanox-specific options" + MLX_OPTS="--with-rdmacm \ + --with-mlx5-dv \ + --with-ib-hw-tm" + fi + + ./configure --prefix=$UCX_HOME \ + --enable-shared \ + --disable-static \ + --disable-doxygen-doc \ + --enable-optimizations \ + --enable-cma \ + --enable-devel-headers \ + --with-cuda=$CUDA_HOME \ + --with-dm \ + --with-gdrcopy=$GDR_HOME \ + --with-verbs \ + --enable-mt \ + $MLX_OPTS + make -j + make -j install-strip + + if $SUDO; then + echo "Running ldconfig with sudo" + sudo ldconfig + else + echo "Skipping ldconfig - sudo not available" + echo "Please run 'sudo ldconfig' manually if needed" + fi + + cd .. +else + echo "Found existing UCX. Skipping UCX installation" +fi + +if ! command -v nixl_test &> /dev/null || [ "$FORCE" = true ]; then + echo "Installing NIXL" + wget https://github.com/ai-dynamo/nixl/archive/refs/tags/0.2.0.tar.gz + tar xzf 0.2.0.tar.gz; rm 0.2.0.tar.gz + cd nixl-0.2.0 + meson setup build --prefix=$NIXL_HOME -Ducx_path=$UCX_HOME + cd build + ninja + ninja install + + cd ../.. +else + echo "Found existing NIXL. 
Skipping NIXL installation" +fi diff --git a/tools/update-dockerfile-graph.sh b/tools/update-dockerfile-graph.sh index a1e22a69cdc7..88189e8ab208 100755 --- a/tools/update-dockerfile-graph.sh +++ b/tools/update-dockerfile-graph.sh @@ -24,7 +24,7 @@ if printf '%s\n' "${FILES[@]}" | grep -q "^docker/Dockerfile$"; then fi # Define the target file path - TARGET_GRAPH_FILE="docs/source/assets/contributing/dockerfile-stages-dependency.png" + TARGET_GRAPH_FILE="docs/assets/contributing/dockerfile-stages-dependency.png" # Ensure target directory exists mkdir -p "$(dirname "$TARGET_GRAPH_FILE")" diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 0206d4552c8b..3c8e6b95ce76 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -150,6 +150,101 @@ def merge_attn_states(output: torch.Tensor, prefix_lse, suffix_output, suffix_lse) +def convert_vertical_slash_indexes( + q_seqlens: torch.Tensor, # [BATCH, ] + kv_seqlens: torch.Tensor, # [BATCH, ] + vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V] + slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S] + context_size: int, + block_size_M: int, + block_size_N: int, + causal: bool = True, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + batch_size = slash_indexes.size(0) + num_heads = slash_indexes.size(1) + nnz_slash = slash_indexes.size(2) + nnz_vertical = vertical_indexes.size(2) + num_rows = (context_size + block_size_M - 1) // block_size_M + + block_count = torch.zeros(batch_size, + num_heads, + num_rows, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + block_offset = torch.zeros(batch_size, + num_heads, + num_rows, + nnz_slash, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + column_count = torch.zeros(batch_size, + num_heads, + num_rows, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + column_index = torch.zeros(batch_size, + num_heads, + num_rows, + nnz_vertical, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + + torch.ops._C.convert_vertical_slash_indexes( 
+ block_count, block_offset, column_count, column_index, q_seqlens, + kv_seqlens, vertical_indexes, slash_indexes, context_size, + block_size_M, block_size_N, causal) + return block_count, block_offset, column_count, column_index + + +def convert_vertical_slash_indexes_mergehead( + q_seqlens: torch.Tensor, # [BATCH, ] + kv_seqlens: torch.Tensor, # [BATCH, ] + vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V] + slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S] + # [N_HEADS] : different head use different number of indices + vertical_indices_count: torch.Tensor, + slash_indices_count: torch.Tensor, + context_size: int, + block_size_M: int, + block_size_N: int, + causal: bool = True, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + batch_size = slash_indexes.size(0) + num_heads = slash_indexes.size(1) + nnz_slash = slash_indexes.size(2) + nnz_vertical = vertical_indexes.size(2) + num_rows = (context_size + block_size_M - 1) // block_size_M + + block_count = torch.empty(batch_size, + num_heads, + num_rows, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + block_offset = torch.empty(batch_size, + num_heads, + num_rows, + nnz_slash, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + column_count = torch.empty(batch_size, + num_heads, + num_rows, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + column_index = torch.empty(batch_size, + num_heads, + num_rows, + nnz_vertical, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + + torch.ops._C.convert_vertical_slash_indexes_mergehead( + block_count, block_offset, column_count, column_index, q_seqlens, + kv_seqlens, vertical_indexes, slash_indexes, vertical_indices_count, + slash_indices_count, context_size, block_size_M, block_size_N, causal) + return block_count, block_offset, column_count, column_index + + # pos encoding ops def rotary_embedding( positions: torch.Tensor, @@ -159,14 +254,8 @@ def rotary_embedding( cos_sin_cache: torch.Tensor, is_neox: bool, ) -> None: - # TODO: 
Remove this contiguous call when the kernel is updated to support tensor slices - query_contiguous = query.contiguous() - key_contiguous = key.contiguous() if key is not None else None - torch.ops._C.rotary_embedding(positions, query_contiguous, key_contiguous, - head_size, cos_sin_cache, is_neox) - query.copy_(query_contiguous) - if key is not None: - key.copy_(key_contiguous) + torch.ops._C.rotary_embedding(positions, query, key, head_size, + cos_sin_cache, is_neox) def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, @@ -174,16 +263,9 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, cos_sin_cache: torch.Tensor, is_neox: bool, rot_dim: int, cos_sin_cache_offsets: torch.Tensor) -> None: - # TODO: Remove this contiguous call when the kernel is updated to support tensor slices - query_contiguous = query.contiguous() - key_contiguous = key.contiguous() if key is not None else None - torch.ops._C.batched_rotary_embedding(positions, query_contiguous, - key_contiguous, head_size, + torch.ops._C.batched_rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox, rot_dim, cos_sin_cache_offsets) - query.copy_(query_contiguous) - if key is not None: - key.copy_(key_contiguous) # layer norm ops @@ -333,6 +415,7 @@ def _gptq_marlin_gemm_fake(a: torch.Tensor, c: Optional[torch.Tensor], b_q_weight: torch.Tensor, b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], b_zeros: Optional[torch.Tensor], g_idx: Optional[torch.Tensor], perm: Optional[torch.Tensor], @@ -745,10 +828,11 @@ def get_cutlass_moe_mm_data( - output_permutation: Permutation that must be used to shuffle the output after executing the MMs. 
""" - torch.ops._C.get_cutlass_moe_mm_data(topk_ids, expert_offsets, - problem_sizes1, problem_sizes2, - input_permutation, output_permutation, - num_experts, n, k) + return torch.ops._C.get_cutlass_moe_mm_data(topk_ids, expert_offsets, + problem_sizes1, problem_sizes2, + input_permutation, + output_permutation, + num_experts, n, k) def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, @@ -767,9 +851,41 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, MMs used in the fused MoE operation. - a/b/c_strides: The data strides passed to grouped matrix multiplication. """ - torch.ops._C.cutlass_moe_mm(out_tensors, a_tensors, b_tensors, a_scales, - b_scales, expert_offsets, problem_sizes, - a_strides, b_strides, c_strides) + return torch.ops._C.cutlass_moe_mm(out_tensors, a_tensors, b_tensors, + a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, + c_strides) + + +def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor, + a_scales: torch.Tensor, b_scales: torch.Tensor, + alphas: torch.Tensor, problem_sizes: torch.Tensor, + expert_offsets: torch.Tensor, sf_offsets: torch.Tensor, + out_dtype: torch.dtype, device: torch.device): + """ + An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs + the gemms for each combination based on the specified problem sizes. + + This is used as the MoE gemm during NVFP4 Quantized FusedMoE forward. + - a/b_tensors: the NVFP4 a_ptrs and b_ptrs tensors which are quantized + input and expert weights. + - a_/b_scales: The blockscales in FP8-E4M3 precision + - expert_offsets/sf_offsets: Indices that mark at which token index + each expert begins its computation. The number of tokens + computed with expert E is expert_offsets[E + 1] - + expert_offsets[E] And the sf_size per expert is + sf_offset[E+1] - sf_offset[E] + - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped + MMs used in the fused MoE operation. 
+ """ + m_topk = a_tensors.shape[0] + n = b_tensors.shape[1] + c_shape = (m_topk, n) + c = torch.empty(c_shape, device=device, dtype=out_dtype) + torch.ops._C.cutlass_fp4_group_mm(c, a_tensors, b_tensors, a_scales, + b_scales, alphas, problem_sizes, + expert_offsets, sf_offsets) + return c.to(out_dtype) # aqlm @@ -833,6 +949,7 @@ def gptq_marlin_gemm(a: torch.Tensor, c: Optional[torch.Tensor], b_q_weight: torch.Tensor, b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], b_zeros: Optional[torch.Tensor], g_idx: Optional[torch.Tensor], perm: Optional[torch.Tensor], @@ -845,9 +962,10 @@ def gptq_marlin_gemm(a: torch.Tensor, use_atomic_add: bool = False, use_fp32_reduce: bool = False, is_zp_float: bool = False) -> torch.Tensor: - return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_scales, b_zeros, - g_idx, perm, workspace, b_q_type.id, - size_m, size_n, size_k, is_k_full, + return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_scales, + global_scale, b_zeros, g_idx, perm, + workspace, b_q_type.id, size_m, + size_n, size_k, is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float) @@ -960,6 +1078,63 @@ def scaled_fp4_quant( return output, output_scale +def scaled_fp4_experts_quant( + input_tensor: torch.Tensor, + input_global_scale: torch.Tensor, + expert_offsets: torch.Tensor, + blockscale_offsets: torch.Tensor, + topk: int, + expert_map: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP4 and return quantized tensor and scale, for + packed MoE Inputs. + Args: + input: The input tensor to be quantized to FP4 + expert_map: The expert map tensor + input_global_scale: A scalar scaling factor for the entire tensor. 
+ expert_offsets: The expert offsets tensor + blockscale_offsets: The blockscale offsets tensor + Outputs: + output: The quantized tensor in FP4 + output_scales: The blockscale tensor in FP8-E4M3 + """ + assert not current_platform.is_rocm() + assert input_tensor.ndim == 2, ( + f'input.ndim needs to be == 2, but got {input_tensor.ndim}.') + + input_tensor = input_tensor[ + expert_map] if expert_map is not None else input_tensor + m_numtopk, k = input_tensor.shape + # Control the maximum number of tokens per expert supported by the + # NVFP4 MoE Expert Quantization. This is used to prevent the kernel + # from running out of memory. This value can also be increased to support + # larger models. + MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE + assert (m_numtopk <= MAX_TOKENS_PER_EXPERT * topk), ( + f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT(" + f"{MAX_TOKENS_PER_EXPERT})" + f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. Use" + f" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to set this value.") + scales_k = k // 16 + padded_k = (scales_k + (4 - 1)) // 4 + + # output is uint8 and packed fp4 values + output = torch.empty(m_numtopk, + k // 2, + device=input_tensor.device, + dtype=torch.uint8) + output_scales = torch.empty(MAX_TOKENS_PER_EXPERT * topk, + padded_k, + dtype=torch.int32, + device=input_tensor.device) + torch.ops._C.scaled_fp4_experts_quant(output, output_scales, input_tensor, + input_global_scale, expert_offsets, + blockscale_offsets) + output_scales = output_scales.view(torch.float8_e4m3fn) + return output, output_scales + + # fp8 def scaled_fp8_quant( input: torch.Tensor, @@ -1297,6 +1472,7 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor], b_qweight: torch.Tensor, b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], b_qzeros: Optional[torch.Tensor], g_idx: Optional[torch.Tensor], perm: Optional[torch.Tensor], @@ 
-1311,11 +1487,11 @@ def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor], use_fp32_reduce: bool, is_zp_float: bool) -> torch.Tensor: return torch.ops._moe_C.moe_wna16_marlin_gemm( - input, output, b_qweight, b_scales, b_qzeros, g_idx, perm, workspace, - sorted_token_ids, expert_ids, num_tokens_past_padded, topk_weights, - moe_block_size, top_k, mul_topk_weights, is_ep, b_q_type.id, size_m, - size_n, size_k, is_k_full, use_atomic_add, use_fp32_reduce, - is_zp_float) + input, output, b_qweight, b_scales, global_scale, b_qzeros, g_idx, + perm, workspace, sorted_token_ids, expert_ids, num_tokens_past_padded, + topk_weights, moe_block_size, top_k, mul_topk_weights, is_ep, + b_q_type.id, size_m, size_n, size_k, is_k_full, use_atomic_add, + use_fp32_reduce, is_zp_float) if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): diff --git a/vllm/adapter_commons/layers.py b/vllm/adapter_commons/layers.py index 18e0c5227d45..9cc2b181fc7c 100644 --- a/vllm/adapter_commons/layers.py +++ b/vllm/adapter_commons/layers.py @@ -1,15 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import Tuple @dataclass class AdapterMapping: # Per every token in input_ids: - index_mapping: Tuple[int, ...] + index_mapping: tuple[int, ...] # Per sampled token: - prompt_mapping: Tuple[int, ...] + prompt_mapping: tuple[int, ...] def __post_init__(self): self.index_mapping = tuple(self.index_mapping) diff --git a/vllm/adapter_commons/models.py b/vllm/adapter_commons/models.py index f9a5d2fffad5..a84fbea2e444 100644 --- a/vllm/adapter_commons/models.py +++ b/vllm/adapter_commons/models.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Optional, TypeVar +from typing import Any, Callable, Optional, TypeVar from torch import nn @@ -49,9 +49,9 @@ def __init__( model: the model to be adapted. 
""" self.model: nn.Module = model - self._registered_adapters: Dict[int, Any] = {} + self._registered_adapters: dict[int, Any] = {} # Dict instead of a Set for compatibility with LRUCache. - self._active_adapters: Dict[int, None] = {} + self._active_adapters: dict[int, None] = {} self.adapter_type = 'Adapter' self._last_mapping = None @@ -97,7 +97,7 @@ def get_adapter(self, adapter_id: int) -> Optional[Any]: raise NotImplementedError @abstractmethod - def list_adapters(self) -> Dict[int, Any]: + def list_adapters(self) -> dict[int, Any]: raise NotImplementedError @abstractmethod diff --git a/vllm/adapter_commons/utils.py b/vllm/adapter_commons/utils.py index c2dc5433cc65..46e9629e1f55 100644 --- a/vllm/adapter_commons/utils.py +++ b/vllm/adapter_commons/utils.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Callable, Dict, Optional, Set +from typing import Any, Callable, Optional ## model functions -def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None], +def deactivate_adapter(adapter_id: int, active_adapters: dict[int, None], deactivate_func: Callable) -> bool: if adapter_id in active_adapters: deactivate_func(adapter_id) @@ -13,7 +13,7 @@ def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None], return False -def add_adapter(adapter: Any, registered_adapters: Dict[int, Any], +def add_adapter(adapter: Any, registered_adapters: dict[int, Any], capacity: int, add_func: Callable) -> bool: if adapter.id not in registered_adapters: if len(registered_adapters) >= capacity: @@ -32,23 +32,23 @@ def set_adapter_mapping(mapping: Any, last_mapping: Any, return last_mapping -def remove_adapter(adapter_id: int, registered_adapters: Dict[int, Any], +def remove_adapter(adapter_id: int, registered_adapters: dict[int, Any], deactivate_func: Callable) -> bool: deactivate_func(adapter_id) return bool(registered_adapters.pop(adapter_id, None)) -def list_adapters(registered_adapters: Dict[int, Any]) -> Dict[int, 
Any]: +def list_adapters(registered_adapters: dict[int, Any]) -> dict[int, Any]: return dict(registered_adapters) def get_adapter(adapter_id: int, - registered_adapters: Dict[int, Any]) -> Optional[Any]: + registered_adapters: dict[int, Any]) -> Optional[Any]: return registered_adapters.get(adapter_id) ## worker functions -def set_active_adapters_worker(requests: Set[Any], mapping: Optional[Any], +def set_active_adapters_worker(requests: set[Any], mapping: Optional[Any], apply_adapters_func, set_adapter_mapping_func) -> None: apply_adapters_func(requests) @@ -66,7 +66,7 @@ def add_adapter_worker(adapter_request: Any, list_adapters_func, return loaded -def apply_adapters_worker(adapter_requests: Set[Any], list_adapters_func, +def apply_adapters_worker(adapter_requests: set[Any], list_adapters_func, adapter_slots: int, remove_adapter_func, add_adapter_func) -> None: models_that_exist = list_adapters_func() @@ -88,5 +88,5 @@ def apply_adapters_worker(adapter_requests: Set[Any], list_adapters_func, add_adapter_func(models_map[adapter_id]) -def list_adapters_worker(adapter_manager_list_adapters_func) -> Set[int]: +def list_adapters_worker(adapter_manager_list_adapters_func) -> set[int]: return set(adapter_manager_list_adapters_func()) diff --git a/vllm/adapter_commons/worker_manager.py b/vllm/adapter_commons/worker_manager.py index ce24e08a5b56..3c1d26404c99 100644 --- a/vllm/adapter_commons/worker_manager.py +++ b/vllm/adapter_commons/worker_manager.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from typing import Any, Optional, Set +from typing import Any, Optional import torch @@ -17,7 +17,7 @@ def is_enabled(self) -> bool: raise NotImplementedError @abstractmethod - def set_active_adapters(self, requests: Set[Any], + def set_active_adapters(self, requests: set[Any], mapping: Optional[Any]) -> None: raise NotImplementedError @@ -34,5 +34,5 @@ def remove_all_adapters(self) -> None: raise NotImplementedError 
@abstractmethod - def list_adapters(self) -> Set[int]: + def list_adapters(self) -> set[int]: raise NotImplementedError diff --git a/vllm/attention/backends/dual_chunk_flash_attn.py b/vllm/attention/backends/dual_chunk_flash_attn.py new file mode 100644 index 000000000000..eceab1f1ac9a --- /dev/null +++ b/vllm/attention/backends/dual_chunk_flash_attn.py @@ -0,0 +1,1494 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Attention layer with Dual chunk flash attention and sparse attention. +""" +import math +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type + +import torch +import torch.distributed +import torch.nn.functional as F + +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import AttentionLayer, AttentionType +from vllm.attention.backends.flash_attn import (FlashAttentionBackend, + FlashAttentionImpl, + FlashAttentionMetadata, + FlashAttentionMetadataBuilder) +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank +from vllm.logger import init_logger +from vllm.utils import async_tensor_h2d +from vllm.vllm_flash_attn import (flash_attn_varlen_func, + flash_attn_with_kvcache, sparse_attn_func) + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUBuilder + +logger = init_logger(__name__) + + +class DualChunkFlashAttentionBackend(FlashAttentionBackend): + + accept_output_buffer: bool = False + + @staticmethod + def get_name() -> str: + return "DUAL_CHUNK_FLASH_ATTN" + + @staticmethod + def get_impl_cls() -> Type["DualChunkFlashAttentionImpl"]: + return DualChunkFlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["DualChunkFlashAttentionMetadata"]: + return DualChunkFlashAttentionMetadata + + @staticmethod + def get_builder_cls() -> Type["DualChunkFlashAttentionMetadataBuilder"]: + return DualChunkFlashAttentionMetadataBuilder + + +@dataclass +class DualChunkFlashAttentionMetadata(FlashAttentionMetadata): + # Block size of 
the paged kv cache. + block_size: int = 16 + + # Original max position embeddings. + original_max_position_embeddings: int = 0 + + # Chunk size + chunk_size: int = 8192 + + # Local size + local_size: int = 1024 + + # (batch_size,). The orig sequence length per sequence. + orig_seq_lens: Optional[List[int]] = None + + # orig_seq_lens stored as a tensor. + orig_seq_lens_tensor: Optional[torch.Tensor] = None + + # Length scaling factor + scaling_factor: Optional[torch.Tensor] = None + + # (batch_size,). Sequence lengths for intra attention. + seq_lens_intra: Optional[torch.Tensor] = None + + # Max sequence length for intra attention. + max_seq_len_intra: Optional[int] = None + + # (batch_size, num_blocks). Block table for intra attention. + block_tables_intra: Optional[torch.Tensor] = None + + # (batch_size,). Sequence lengths for succ attention. + seq_lens_succ: Optional[torch.Tensor] = None + + # Max sequence length for succ attention. + max_seq_len_succ: Optional[int] = None + + # (batch_size, num_blocks). Block table for succ attention. + block_tables_succ: Optional[torch.Tensor] = None + + # (batch_size,). Sequence lengths for inter attention. + seq_lens_inter: Optional[torch.Tensor] = None + + # Max sequence length for inter attention. 
+ max_seq_len_inter: Optional[int] = None + + _cached_prefill_metadata: Optional[ + "DualChunkFlashAttentionMetadata"] = None + _cached_decode_metadata: Optional["DualChunkFlashAttentionMetadata"] = None + + @property + def prefill_metadata(self) -> Optional["DualChunkFlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + prefill_metadata = super().prefill_metadata + if prefill_metadata is None: + return None + + prefill_metadata = DualChunkFlashAttentionMetadata( + **prefill_metadata.asdict_zerocopy()) + + prefill_metadata.orig_seq_lens = ( + None if self.orig_seq_lens is None else + self.orig_seq_lens[:self.num_prefills]) + prefill_metadata.orig_seq_lens_tensor = ( + None if self.orig_seq_lens_tensor is None else + self.orig_seq_lens_tensor[:self.num_prefills]) + + if self.original_max_position_embeddings > 0: + assert prefill_metadata.orig_seq_lens_tensor is not None + prefill_metadata.scaling_factor = ( + 0.1 * torch.log(prefill_metadata.orig_seq_lens_tensor / + self.original_max_position_embeddings) + + 1.0).clip(min=1) + + self._cached_prefill_metadata = prefill_metadata + return prefill_metadata + + @property + def decode_metadata(self) -> Optional["DualChunkFlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + + decode_metadata = super().decode_metadata + if decode_metadata is None: + return None + + decode_metadata = DualChunkFlashAttentionMetadata( + **decode_metadata.asdict_zerocopy()) + + decode_metadata.orig_seq_lens_tensor = ( + None if self.orig_seq_lens_tensor is None else + self.orig_seq_lens_tensor[self.num_prefills:]) + + assert decode_metadata.orig_seq_lens_tensor is not None + assert decode_metadata.block_tables is not None + + cache_seq_lens = decode_metadata.orig_seq_lens_tensor + chunk_len = self.chunk_size - self.local_size + 
chunk_num_curr = (cache_seq_lens - 1) // chunk_len + batch_size = decode_metadata.num_decode_tokens + + if self.original_max_position_embeddings > 0: + decode_metadata.scaling_factor = (0.1 * torch.log( + cache_seq_lens / self.original_max_position_embeddings) + + 1.0).clip(min=1) + + seq_lens_intra = cache_seq_lens - chunk_num_curr * chunk_len + max_seq_len_intra = seq_lens_intra.max().item() + decode_metadata.seq_lens_intra = seq_lens_intra + decode_metadata.max_seq_len_intra = max_seq_len_intra + + block_tables_intra = torch.zeros( + batch_size, + (max_seq_len_intra - 1) // self.block_size + 1, + dtype=decode_metadata.block_tables.dtype, + device=decode_metadata.block_tables.device, + ) + for i in range(batch_size): + st = chunk_num_curr[i] * chunk_len // self.block_size + ed = min( + st + (max_seq_len_intra - 1) // self.block_size + 1, + (cache_seq_lens[i] - 1) // self.block_size + 1, + ) + block_tables_intra[i, :ed - + st] = decode_metadata.block_tables[i, st:ed] + decode_metadata.block_tables_intra = block_tables_intra + + seq_lens_succ = (chunk_num_curr - + (chunk_num_curr - 1).clip(min=0)) * chunk_len + max_seq_len_succ = seq_lens_succ.max().item() + decode_metadata.seq_lens_succ = seq_lens_succ + decode_metadata.max_seq_len_succ = max_seq_len_succ + if max_seq_len_succ: + block_tables_succ = torch.zeros( + batch_size, + (max_seq_len_succ - 1) // self.block_size + 1, + dtype=decode_metadata.block_tables.dtype, + device=decode_metadata.block_tables.device, + ) + for i in range(batch_size): + start = ((chunk_num_curr[i] - 1).clip(min=0) * chunk_len // + self.block_size) + end = min( + start + (max_seq_len_succ - 1) // self.block_size + 1, + (cache_seq_lens[i] - 1) // self.block_size + 1, + ) + block_tables_succ[ + i, :end - start] = decode_metadata.block_tables[i, + start:end] + decode_metadata.block_tables_succ = block_tables_succ + + seq_lens_inter = (chunk_num_curr - 1).clip(min=0) * chunk_len + max_seq_len_inter = seq_lens_inter.max().item() + 
decode_metadata.seq_lens_inter = seq_lens_inter + decode_metadata.max_seq_len_inter = max_seq_len_inter + + self._cached_decode_metadata = decode_metadata + return decode_metadata + + +class DualChunkFlashAttentionMetadataBuilder(FlashAttentionMetadataBuilder): + + def prepare(self): + super().prepare() + self.orig_seq_lens: List[int] = [] + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool, prefix_cache_hit: bool): + super()._add_seq_group(inter_data, chunked_prefill_enabled, + prefix_cache_hit) + for prompt_len, seq_len in zip(inter_data.prompt_lens, + inter_data.seq_lens): + self.orig_seq_lens.append(max(prompt_len, seq_len)) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + attn_metadata = super().build(seq_lens, query_lens, + cuda_graph_pad_size, batch_size) + attn_metadata = DualChunkFlashAttentionMetadata( + **attn_metadata.asdict_zerocopy()) + + device = self.runner.device + attn_metadata.orig_seq_lens = self.orig_seq_lens + attn_metadata.orig_seq_lens_tensor = async_tensor_h2d( + self.orig_seq_lens, torch.int, device, self.runner.pin_memory) + + attn_metadata.block_size = self.runner.block_size + dual_chunk_attn_config = getattr(self.runner.model_config.hf_config, + "dual_chunk_attention_config", {}) + attn_metadata.original_max_position_embeddings = \ + dual_chunk_attn_config.get("original_max_position_embeddings", 0) + attn_metadata.chunk_size = dual_chunk_attn_config.get( + "chunk_size", 8192) + attn_metadata.local_size = dual_chunk_attn_config.get( + "local_size", 1024) + + return attn_metadata + + +class DualChunkFlashAttentionImpl(FlashAttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| + Otherwise, the layout is as follows: + |<----------------- 
num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + The prompts might have different lengths, while the generation tokens + always have length 1. + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + layer_idx: int = -1, + dual_chunk_attention_config: Optional[Dict[str, Any]] = None, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + if sliding_window is not None: + # NOTE(woosuk): flash-attn's sliding window does not work with + # paged KV cache. 
+ raise ValueError( + "Sliding window is not supported in FlashAttention.") + + support_head_sizes = ( + DualChunkFlashAttentionBackend.get_supported_head_sizes()) + + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by FlashAttention. " + f"Supported head sizes are: {support_head_sizes}.") + + assert dual_chunk_attention_config is not None + self.chunk_size = dual_chunk_attention_config.get("chunk_size", 8192) + self.local_size = dual_chunk_attention_config.get("local_size", 1024) + self.original_max_position_embeddings = dual_chunk_attention_config.get( + "original_max_position_embeddings", 0) + self.sparse_attention_config = dual_chunk_attention_config.get( + "sparse_attention_config", None) + if not self.sparse_attention_config: + logger.warning_once("Sparse attention will not be enabled as " + "sparse attention config is not provided.") + self.sparse_attention_enabled = dual_chunk_attention_config.get( + "sparse_attention_enabled", self.sparse_attention_config + is not None) + self.sparse_attention_threshold = dual_chunk_attention_config.get( + "sparse_attention_threshold", 32768) + self.sparse_attention_last_q = dual_chunk_attention_config.get( + "sparse_attention_last_q", 64) + self.layer_idx = layer_idx + self.dual_chunk_attention_config = dual_chunk_attention_config + + if self.sparse_attention_config: + self.sparse_attention_config = { + int(i): j + for i, j in self.sparse_attention_config[ + self.layer_idx].items() + } + start_head = self.num_heads * get_tensor_model_parallel_rank() + end_head = start_head + self.num_heads + self.sparse_attention_config = [ + self.sparse_attention_config[i] + for i in range(start_head, end_head) + ] + + if self.sparse_attention_enabled: + self.arange = torch.arange(self.sparse_attention_last_q, + device="cuda") + self.last_q_mask = (self.arange[None, None, :, None] + >= self.arange[None, None, None, :]) + + def forward( # type: ignore + self, + layer: AttentionLayer, 
+ query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: DualChunkFlashAttentionMetadata, + ) -> torch.Tensor: + """Forward pass with DualChunkFlashAttention. + Args: + query: shape = [num_tokens, num_heads * head_size] + query_succ: shape = [num_tokens, num_heads * head_size] + query_inter: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads * head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + ( + query, + query_succ, + query_inter, + query_succ_critical, + query_inter_critical, + ) = torch.split(query, query.shape[-1] // 5, dim=-1) + + assert ( + query_succ is not None and query_inter is not None + ), "query_succ and query_inter are required in Dual Chunk Attention." + + num_tokens, hidden_size = query.shape + + # Reshape the query, key, and value tensors. 
+ query = query.view(-1, self.num_heads, self.head_size) + query_succ = query_succ.view(-1, self.num_heads, self.head_size) + query_inter = query_inter.view(-1, self.num_heads, self.head_size) + query_succ_critical = query_succ_critical.view(-1, self.num_heads, + self.head_size) + query_inter_critical = query_inter_critical.view( + -1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if self.original_max_position_embeddings > 0: + if prefill_meta := attn_metadata.prefill_metadata: + assert prefill_meta.scaling_factor is not None + assert prefill_meta.query_start_loc is not None + assert prefill_meta.orig_seq_lens is not None + current_start = 0 + query_start_loc_cpu = prefill_meta.query_start_loc.cpu() + for i in range(len(prefill_meta.orig_seq_lens)): + current_end = (current_start + + (query_start_loc_cpu[i + 1] - + query_start_loc_cpu[i]).item()) + key[current_start:current_end].mul_( + prefill_meta.scaling_factor[i]) + current_start = current_end + assert current_end <= attn_metadata.num_prefill_tokens + if decode_meta := attn_metadata.decode_metadata: + assert decode_meta.scaling_factor is not None + scaling_factor = decode_meta.scaling_factor + key[attn_metadata.num_prefill_tokens:].mul_( + scaling_factor.unsqueeze(-1).unsqueeze(-1)) + + if kv_cache is not None and kv_cache.numel() > 0: + key_cache = kv_cache[0] + value_cache = kv_cache[1] + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. 
+ ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping.flatten(), + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + output = torch.empty_like(query) + + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + decode_query_succ = query_succ[num_prefill_tokens:] + decode_query_inter = query_inter[num_prefill_tokens:] + + # QKV for prefill. + query = query[:num_prefill_tokens] + query_succ = query_succ[:num_prefill_tokens] + query_inter = query_inter[:num_prefill_tokens] + query_succ_critical = query_succ_critical[:num_prefill_tokens] + query_inter_critical = query_inter_critical[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. + if (kv_cache is None or prefill_meta.block_tables is None + or prefill_meta.block_tables.numel() == 0): + # normal attention, called during the profiling run. 
+ out = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + ) + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out + else: + # prefix-enabled attention + assert prefill_meta.seq_lens is not None + assert prefill_meta.orig_seq_lens is not None + output[:num_prefill_tokens] = ( + self._dual_chunk_flash_attn_prefill( + q=query, + q_succ=query_succ, + q_inter=query_inter, + q_succ_critical=query_succ_critical, + q_inter_critical=query_inter_critical, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.query_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + orig_seq_lens=prefill_meta.orig_seq_lens, + scaling_factor=prefill_meta.scaling_factor, + softmax_scale=self.scale, + causal=True, + window_size=(-1, -1), + alibi_slopes=self.alibi_slopes, + block_table=prefill_meta.block_tables, + chunk_size=self.chunk_size, + local_size=self.local_size, + )) + + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + output[num_prefill_tokens:] = ( + self._dual_chunk_flash_attn_decoding( + decode_query.unsqueeze(1), + decode_query_succ.unsqueeze(1), + decode_query_inter.unsqueeze(1), + key_cache, + value_cache, + block_table=decode_meta.block_tables, + cache_seqlens=decode_meta.seq_lens_tensor, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + chunk_size=self.chunk_size, + local_size=self.local_size, + original_max_position_embeddings=self. + original_max_position_embeddings, + decode_meta=decode_meta, + ).squeeze(1)) + # Reshape the output tensor. 
+ return output.view(num_tokens, hidden_size) + + def _dual_chunk_flash_attn_prefill( + self, + q, + q_succ, + q_inter, + q_succ_critical, + q_inter_critical, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + orig_seq_lens: List[int], + scaling_factor: torch.Tensor, + softmax_scale: float, + causal: Optional[bool] = True, + window_size: Tuple[int, int] = (-1, -1), + alibi_slopes: Optional[torch.Tensor] = None, + block_table: Optional[torch.Tensor] = None, + chunk_size: int = 8192, + local_size: int = 1024, + ): + if alibi_slopes is not None: + raise ValueError( + "Dual Chunk Attention does not support alibi_slopes") + if not causal: + raise ValueError( + "Dual Chunk Attention does not support causal=False") + if window_size != (-1, -1): + raise ValueError( + "Dual Chunk Attention does not support window_size") + + cu_seqlens_q_cpu = cu_seqlens_q.cpu().tolist() + cu_seqlens_k_cpu = cu_seqlens_k.cpu().tolist() + all_outputs = [] + + for i in range(0, len(cu_seqlens_q_cpu) - 1): + qs = cu_seqlens_q_cpu[i] + qe = cu_seqlens_q_cpu[i:i + 2][-1] + ks = cu_seqlens_k_cpu[i] + ke = cu_seqlens_k_cpu[i:i + 2][-1] + + current_q = q[qs:qe] + current_q_succ = q_succ[qs:qe] + current_q_inter = q_inter[qs:qe] + current_q_succ_critical = q_succ_critical[qs:qe] + current_q_inter_critical = q_inter_critical[qs:qe] + + if block_table is None: + current_k = k[ks:ke] + current_v = v[ks:ke] + current_block_table = None + current_orig_seq_len = orig_seq_lens[i] + else: + current_block_table = block_table[i] + current_orig_seq_len = orig_seq_lens[i] + current_k = k + current_v = v + sparse_attn_enabled = (self.sparse_attention_enabled + and current_orig_seq_len + > self.sparse_attention_threshold) + + if current_q.shape[0] == 0: + continue + + if current_k.shape[0] == 0: + all_outputs.append( + torch.zeros( + (current_q.shape[0], current_q.shape[1], v.shape[2]), + device=q.device, + dtype=q.dtype, + )) + continue + + current_output = torch.empty_like(current_q) + group_size = 
int(current_q.size(-2) / current_k.size(-2)) + + if sparse_attn_enabled: + num_device_q_heads = current_q.size(-2) + heads_vertical_size = torch.empty(size=(num_device_q_heads, ), + dtype=torch.int32) + heads_slash_size = torch.empty(size=(num_device_q_heads, ), + dtype=torch.int32) + for head_id in range(current_q.size(-2)): + ( + ty, + vertical_size, + slash_size, + _, + ) = self.sparse_attention_config[head_id] + assert ty == "vertical_and_slash", "only support slash mode" + + if vertical_size == 30: + vertical_size += 100 + heads_vertical_size[head_id] = vertical_size + heads_slash_size[head_id] = slash_size + + current_output = self._dual_chunk_flash_attn_prefill_func( + current_q, # allheads + current_q_succ, + current_q_inter, + current_q_succ_critical, + current_q_inter_critical, + current_k, + current_v, + current_block_table, + softmax_scale, + chunk_size, + local_size, + scaling_factor[i].item(), + ke - ks, + sparse_attn_enabled=sparse_attn_enabled, + heads_vertical_size=heads_vertical_size, + heads_slash_size=heads_slash_size, + group_size=group_size) + else: + for head_id in range(current_q.size(-2)): + # (seq_len, num_heads, head_size) + current_q_head = current_q[:, head_id, :].unsqueeze(1) + current_q_succ_head = \ + current_q_succ[:, head_id, :].unsqueeze(1) + current_q_inter_head = \ + current_q_inter[:, head_id, :].unsqueeze(1) + current_q_succ_head_critical = \ + current_q_succ_critical[:, head_id, :].unsqueeze(1) + current_q_inter_head_critical = \ + current_q_inter_critical[:, head_id, :].unsqueeze(1) + if block_table is not None: + current_k_head = current_k[..., head_id // + group_size, :].unsqueeze(2) + current_v_head = current_v[..., head_id // + group_size, :].unsqueeze(2) + + else: + current_k_head = current_k[:, head_id, :].unsqueeze(1) + current_v_head = current_v[:, head_id, :].unsqueeze(1) + + current_out = self._dual_chunk_flash_attn_prefill_func( + current_q_head, + current_q_succ_head, + current_q_inter_head, + 
current_q_succ_head_critical, + current_q_inter_head_critical, + current_k_head, + current_v_head, + current_block_table, + softmax_scale, + chunk_size, + local_size, + scaling_factor[i].item(), + ke - ks, + sparse_attn_enabled=sparse_attn_enabled, + ) + current_output[:, head_id:head_id + 1, :] = current_out + all_outputs.append(current_output) + return torch.cat(all_outputs, dim=0) + + def _dual_chunk_flash_attn_prefill_func( + self, + q, + q_succ, + q_inter, + q_succ_critical, + q_inter_critical, + k, + v, + block_table, + softmax_scale: float, + chunk_size: int, + local_size: int, + scaling_factor: float, + k_length: int, + sparse_attn_enabled: Optional[bool] = True, + heads_vertical_size=None, + heads_slash_size=None, + group_size=None, + ): + flash_results = [] + chunk_len = chunk_size - local_size + + if block_table is not None: + block_size = v.shape[1] + if chunk_len % block_size != 0: + raise ValueError("chunk_len must be divisible by block_size.") + else: + block_size = 1 + + if self.original_max_position_embeddings > 0: + softmax_scale = softmax_scale * scaling_factor + + begin = k_length - q.shape[0] + while begin < k_length: + flash_per_chunk = [] + + prev_chunk_end_pos = (begin // chunk_len) * chunk_len + next_chunk_end_pos = prev_chunk_end_pos + chunk_len + end = min(next_chunk_end_pos, k_length) + qbegin = begin - (k_length - q.shape[0]) + qend = end - (k_length - q.shape[0]) + + qk_chunks = [] + q_states_intra = q[qbegin:qend] + # choose critical token + if block_table is not None: + block_tables_intra = _get_block(block_table, block_size, + prev_chunk_end_pos, end) + k_states_intra = k[block_tables_intra].view( + -1, *k.shape[-2:])[:(end - prev_chunk_end_pos)] + v_states_intra = v[block_tables_intra].view( + -1, *v.shape[-2:])[:(end - prev_chunk_end_pos)] + else: + block_tables_intra = None + k_states_intra = k[prev_chunk_end_pos:end] + v_states_intra = v[prev_chunk_end_pos:end] + + if sparse_attn_enabled: + last_q_size = min(qend - qbegin, 
self.sparse_attention_last_q) + _, num_device_k_heads, head_dim = k_states_intra.shape + k_states_intra = (k_states_intra.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, head_dim)) + v_states_intra = (v_states_intra.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, head_dim)) + qk_chunks.append( + (q_states_intra.transpose(0, 1)[:, -last_q_size:] * + softmax_scale) @ k_states_intra.permute(1, 2, 0)) + + if prev_chunk_end_pos - chunk_len >= 0: + q_states_succ = q_succ[qbegin:qend] + q_states_succ_critical = q_succ_critical[qbegin:qend] + if block_table is not None: + block_tables_succ = _get_block( + block_table, block_size, + prev_chunk_end_pos - chunk_len, prev_chunk_end_pos) + k_states_succ = k[block_tables_succ].view( + -1, *k.shape[-2:])[:chunk_len] + v_states_succ = v[block_tables_succ].view( + -1, *v.shape[-2:])[:chunk_len] + else: + k_states_succ = k[prev_chunk_end_pos - + chunk_len:prev_chunk_end_pos] + v_states_succ = v[prev_chunk_end_pos - + chunk_len:prev_chunk_end_pos] + + if sparse_attn_enabled: + k_states_succ = (k_states_succ.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, + head_dim)) + v_states_succ = (v_states_succ.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, + head_dim)) + qk_chunks.append((q_states_succ_critical.transpose( + 0, 1)[:, -last_q_size:] * softmax_scale) + @ k_states_succ.permute(1, 2, 0)) + + if prev_chunk_end_pos - chunk_len * 2 >= 0: + q_states_inter = q_inter[qbegin:qend] + q_states_inter_critical = q_inter_critical[qbegin:qend] + if block_table is not None: + block_tables_inter = _get_block( + block_table, block_size, 0, + prev_chunk_end_pos - chunk_len) + k_states_inter = k[block_tables_inter].view( + -1, *k.shape[-2:])[:(prev_chunk_end_pos - chunk_len)] + v_states_inter = v[block_tables_inter].view( + -1, *v.shape[-2:])[:(prev_chunk_end_pos - chunk_len)] + 
else: + k_states_inter = k[:prev_chunk_end_pos - chunk_len] + v_states_inter = v[:prev_chunk_end_pos - chunk_len] + + if sparse_attn_enabled: + k_states_inter = (k_states_inter.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, + head_dim)) + v_states_inter = (v_states_inter.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, + head_dim)) + qk_chunks.append((q_states_inter_critical.transpose( + 0, 1)[:, -last_q_size:] * softmax_scale) + @ k_states_inter.permute(1, 2, 0)) + + if sparse_attn_enabled: + reversed_qk = qk_chunks[::-1] + qk = torch.cat(reversed_qk, dim=-1) + + qk[:, :, -last_q_size:] = torch.where( + self.last_q_mask[..., -last_q_size:, + -last_q_size:].to(qk.device), + qk[:, :, -last_q_size:], -torch.inf) + qk = F.softmax(qk, dim=-1, dtype=torch.float32) + + vertical = qk.sum(-2, keepdim=True) + vertical[..., :30] = torch.inf + + # Avoid sorting by using the min/max ints to fill the indexer + # buffers. 
+ int32_max = torch.iinfo(torch.int32).max + int32_min = torch.iinfo(torch.int32).min + n_heads = qk.size()[0] + max_slash_topk = torch.max(heads_slash_size).item() + max_vertical_topk = torch.max(heads_vertical_size).item() + # store each head's slash topk, vertical topk + vertical = vertical.reshape((n_heads, -1)) + # prevent out of range when prompt size < max_vertical_topk + max_vertical_topk = min(vertical.shape[-1], max_vertical_topk) + vertical_topk_buffer = torch.topk(vertical, max_vertical_topk, + -1).indices + slash_topk_buffer = torch.empty(size=(n_heads, max_slash_topk), + dtype=torch.int64, + device=qk.device) + for head_i in range(n_heads): + # (nqheads=1, lastq, k_len) + head_score = qk[head_i:head_i + 1, :, :] + slash_scores = _sum_all_diagonal_matrix(head_score) + if head_score.size(1) != 1: + # drop right up corner + slash_scores = slash_scores[..., :-last_q_size + 1] + slash_scores[..., -100:] = torch.inf + + head_slash_size = heads_slash_size[head_i] + head_slash_size = min(head_slash_size, vertical.size(-1)) + slash_topk = torch.topk(slash_scores, head_slash_size, + -1).indices + #๏ผˆnheads, max_topk๏ผ‰ + slash_topk_buffer[head_i, :head_slash_size] = slash_topk + + # reset heads topk + heads_slash_size[head_i] = head_slash_size + heads_vertical_size[head_i] = min( + heads_vertical_size[head_i], max_vertical_topk) + + # store + vertical_buffer = torch.full((n_heads, max_vertical_topk), + int32_max, + dtype=torch.int64, + device=q.device) + slash_buffer = torch.full((n_heads, max_slash_topk), + int32_min, + dtype=torch.int64, + device=q.device) + succ_vertical_buffer = torch.full((n_heads, max_vertical_topk), + int32_max, + dtype=torch.int64, + device=q.device) + succ_slash_buffer = torch.full((n_heads, max_slash_topk), + int32_min, + dtype=torch.int64, + device=q.device) + inter_vertical_buffer = torch.full( + (n_heads, max_vertical_topk), + int32_max, + dtype=torch.int64, + device=q.device) + inter_slash_buffer = torch.full((n_heads, 
max_slash_topk), + int32_min, + dtype=torch.int64, + device=q.device) + + vertical_size_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + slash_sizes_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + succ_vertical_size_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + succ_slash_sizes_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + inter_vertical_size_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + inter_slash_sizes_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + + for head_i in range(n_heads): + vertical_topk = vertical_topk_buffer[ + head_i, :heads_vertical_size[head_i]] + # intra + intra_vertical_indices = vertical_topk[ + vertical_topk >= + prev_chunk_end_pos] - prev_chunk_end_pos + if intra_vertical_indices.nelement() == 0: + intra_vertical_indices = torch.cat([ + intra_vertical_indices, + torch.arange(0, + k_states_intra.size(0), + max(1, + k_states_intra.size(0) / 5), + dtype=torch.int32, + device=intra_vertical_indices.device) + ]) + slash_topk = slash_topk_buffer[ + head_i, :heads_slash_size[head_i]] + intra_slash_indices = ( + (qk.size(-1) - 1) - + slash_topk[slash_topk >= prev_chunk_end_pos]) + # fill buffer + v_count = intra_vertical_indices.nelement() + s_count = intra_slash_indices.nelement() + vertical_size_buffer[head_i] = v_count + slash_sizes_buffer[head_i] = s_count + vertical_buffer[head_i, :v_count].copy_( + intra_vertical_indices) + slash_buffer[head_i, :s_count].copy_(intra_slash_indices) + # succ + if prev_chunk_end_pos - chunk_len >= 0: + succ_vertical_indices = vertical_topk[ + (vertical_topk < prev_chunk_end_pos) + & (vertical_topk >= prev_chunk_end_pos - + chunk_len)] - (prev_chunk_end_pos - chunk_len) + # TODO: support no vertical + if succ_vertical_indices.nelement() == 0: + succ_vertical_indices = torch.cat([ + succ_vertical_indices, + 
torch.arange( + 0, + k_states_succ.size(0), + max(1, + k_states_succ.size(0) / 5), + dtype=torch.int32, + device=intra_vertical_indices.device) + ]) + succ_slash_indices = ( + (prev_chunk_end_pos + (qend - qbegin) - 1) - + slash_topk[((slash_topk >= + (prev_chunk_end_pos - chunk_len)) & + (slash_topk < (prev_chunk_end_pos + + (qend - qbegin))))]) + if succ_slash_indices.nelement() == 0: + succ_slash_indices = torch.cat([ + succ_slash_indices, + torch.arange( + 0, + k_states_succ.size(0), + max(1, + k_states_succ.size(0) / 5), + dtype=torch.int32, + device=intra_vertical_indices.device) + ]) + # fill buffer + v_count = succ_vertical_indices.nelement() + s_count = succ_slash_indices.nelement() + succ_vertical_size_buffer[head_i] = v_count + succ_slash_sizes_buffer[head_i] = s_count + succ_vertical_buffer[head_i, :v_count].copy_( + succ_vertical_indices) + succ_slash_buffer[head_i, :s_count].copy_( + succ_slash_indices) + + if prev_chunk_end_pos - 2 * chunk_len >= 0: + inter_vertical_indices = vertical_topk[ + vertical_topk < prev_chunk_end_pos - chunk_len] + + if inter_vertical_indices.nelement() == 0: + inter_vertical_indices = torch.cat([ + inter_vertical_indices, + torch.arange( + 0, + k_states_inter.size(0), + max(1, + k_states_inter.size(0) / 5), + dtype=torch.int32, + device=intra_vertical_indices.device) + ]) + inter_slash_indices = ( + (prev_chunk_end_pos - chunk_len + + (qend - qbegin) - 1) - + slash_topk[slash_topk < (prev_chunk_end_pos - + chunk_len + + (qend - qbegin))]) + if inter_slash_indices.nelement() == 0: + inter_slash_indices = torch.cat([ + inter_slash_indices, + torch.arange( + 0, + k_states_inter.size(0), + max(1, + k_states_inter.size(0) / 5), + dtype=torch.int32, + device=intra_vertical_indices.device) + ]) + # fill buffer + v_count = inter_vertical_indices.nelement() + s_count = inter_slash_indices.nelement() + inter_vertical_size_buffer[head_i] = v_count + inter_slash_sizes_buffer[head_i] = s_count + inter_vertical_buffer[head_i, 
:v_count].copy_( + inter_vertical_indices) + inter_slash_buffer[head_i, :s_count].copy_( + inter_slash_indices) + else: + intra_vertical_indices, intra_slash_indices = None, None + succ_vertical_indices, succ_slash_indices = None, None + inter_vertical_indices, inter_slash_indices = None, None + + if sparse_attn_enabled: + flash_result = self._do_flash_attn( + q_states_intra, + k_states_intra, + v_states_intra, + softmax_scale=softmax_scale, + causal=True, + block_table=block_table, + stage="intra", + vertical_indices=vertical_buffer, + slash_indices=slash_buffer, + vertical_indices_count=vertical_size_buffer, + slash_indices_count=slash_sizes_buffer, + mergehead_softmax_scale=softmax_scale, + sparse_attn_enabled=sparse_attn_enabled) + else: + flash_result = self._do_flash_attn( + q_states_intra, + k_states_intra, + v_states_intra, + softmax_scale=softmax_scale, + causal=True, + block_table=block_table, + stage="intra", + vertical_indices=intra_vertical_indices, + slash_indices=intra_slash_indices, + sparse_attn_enabled=sparse_attn_enabled) + flash_per_chunk.append(flash_result) + + if prev_chunk_end_pos - chunk_len >= 0: + if sparse_attn_enabled: + flash_result = self._do_flash_attn( + q_states_succ, + k_states_succ, + v_states_succ, + softmax_scale=softmax_scale, + causal=False, + block_table=block_table, + stage="succ", + vertical_indices=succ_vertical_buffer, + slash_indices=succ_slash_buffer, + vertical_indices_count=succ_vertical_size_buffer, + slash_indices_count=succ_slash_sizes_buffer, + mergehead_softmax_scale=softmax_scale, + sparse_attn_enabled=sparse_attn_enabled) + else: + flash_result = self._do_flash_attn( + q_states_succ, + k_states_succ, + v_states_succ, + softmax_scale=softmax_scale, + causal=False, + block_table=block_table, + stage="succ", + vertical_indices=succ_vertical_indices, + slash_indices=succ_slash_indices, + sparse_attn_enabled=sparse_attn_enabled) + flash_per_chunk.append(flash_result) + + if prev_chunk_end_pos - chunk_len * 2 >= 0: 
+ if sparse_attn_enabled: + flash_result = self._do_flash_attn( + q_states_inter, + k_states_inter, + v_states_inter, + softmax_scale=softmax_scale, + causal=False, + block_table=block_table, + stage="inter", + vertical_indices=inter_vertical_buffer, + slash_indices=inter_slash_buffer, + vertical_indices_count=inter_vertical_size_buffer, + slash_indices_count=inter_slash_sizes_buffer, + mergehead_softmax_scale=softmax_scale, + sparse_attn_enabled=sparse_attn_enabled) + else: + flash_result = self._do_flash_attn( + q_states_inter, + k_states_inter, + v_states_inter, + softmax_scale=softmax_scale, + causal=False, + block_table=block_table, + stage="inter", + vertical_indices=inter_vertical_indices, + slash_indices=inter_slash_indices, + sparse_attn_enabled=sparse_attn_enabled) + flash_per_chunk.append(flash_result) + + flash_results.append(flash_per_chunk) + begin = end + + attn_output = self._merge_attn_outputs(flash_results) + del flash_results + return attn_output + + def _do_flash_attn( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + softmax_scale: float, + causal: bool = True, + block_table: torch.Tensor = None, + max_seqlen_k: Optional[int] = None, + stage: str = "intra", + vertical_indices: Optional[torch.Tensor] = None, + slash_indices: Optional[torch.Tensor] = None, + vertical_indices_count: Optional[torch.Tensor] = None, + slash_indices_count: Optional[torch.Tensor] = None, + mergehead_softmax_scale: Optional[float] = None, + sparse_attn_enabled: Optional[bool] = False, + ): + if max_seqlen_k is None: + max_seqlen_k = key_states.shape[0] + + q_len = query_states.shape[0] + q_heads = query_states.shape[1] + h_dim = query_states.shape[-1] + + if sparse_attn_enabled: + assert slash_indices is not None + if stage == "intra": + assert causal + else: + assert not causal + + query_states = query_states.unsqueeze(0).transpose(1, 2) + key_states = key_states.unsqueeze(0).transpose(1, 2) + value_states = 
value_states.unsqueeze(0).transpose(1, 2) + + q = query_states + k = key_states + v = value_states + + if (vertical_indices_count is not None and \ + slash_indices_count is not None): + assert mergehead_softmax_scale is not None + + res, s_lse = _vertical_slash_sparse_attention( + q, + k, + v, + vertical_indices, + slash_indices, + mergehead_softmax_scale, + causal=causal, + stage=stage, + vertical_indices_count=vertical_indices_count, + slash_indices_count=slash_indices_count) + res = res.view(q_heads, q_len, + h_dim).transpose(0, 1) # (qlen,nhead,h_dim) + s_lse = s_lse.view( + q_heads, q_len, + 1).squeeze(-1).unsqueeze(0).float() # (1, nhead,qlen) + else: + res, s_lse = _vertical_slash_sparse_attention(q, + k, + v, + vertical_indices, + slash_indices, + softmax_scale, + causal=causal, + stage=stage) + res = res.view(q_len, q_heads, h_dim) + s_lse = s_lse.view(q_len, q_heads, 1).transpose(0, 2).float() + return res, s_lse + + output, softmax_lse = flash_attn_varlen_func( + q=query_states, + k=key_states, + v=value_states, + softmax_scale=softmax_scale, + cu_seqlens_q=torch.tensor([0, query_states.shape[0]], + dtype=torch.int32, + device=query_states.device), + max_seqlen_q=query_states.shape[0], + cu_seqlens_k=torch.tensor([0, max_seqlen_k], + dtype=torch.int32, + device=query_states.device), + max_seqlen_k=max_seqlen_k, + causal=causal, + block_table=block_table.unsqueeze(0), + return_softmax_lse=True, + ) + softmax_lse = softmax_lse.view(q_len, q_heads, 1).transpose(0, + 2).float() + return output, softmax_lse + + def _merge_attn_outputs( + self, + flash_results: List[List[Tuple[torch.Tensor, torch.Tensor]]], + return_lse: Optional[bool] = False, + ) -> torch.Tensor: + attn_outputs_all = [] + logits_all = [] + + for flash_per_chunk in flash_results: + if len(flash_per_chunk) == 1: + attn_outputs_all.append(flash_per_chunk[0][0]) + if return_lse: + logits_all.append(flash_per_chunk[0][1]) + continue + + attn_outputs = torch.stack([ + flash_attn_output[0] for 
flash_attn_output in flash_per_chunk + ]) + logits = torch.stack([ + flash_attn_output[1] for flash_attn_output in flash_per_chunk + ]) + logits = logits.to(torch.float32) + + if return_lse: + max_val = torch.max(logits, dim=0).values + diff = torch.abs(logits[0] - logits[1]) + log_sum_exp = max_val + torch.log1p(torch.exp(-diff)) + logits_all.append(log_sum_exp) + + max_logits = torch.max(logits, dim=0).values + stable_logits = logits - max_logits.unsqueeze(0) + lse_s = torch.exp(stable_logits).detach() + lse_sum = torch.sum(lse_s, dim=0) + lse_s /= lse_sum + attn_outputs *= lse_s.unsqueeze(-1).transpose(2, 3).squeeze(1) + attn_outputs_all.append(attn_outputs.sum(dim=0)) + + if return_lse: + return (torch.cat(attn_outputs_all, + dim=0), torch.cat(logits_all, dim=-1)) + else: + return torch.cat(attn_outputs_all, dim=0) + + def _dual_chunk_flash_attn_decoding( + self, + query: torch.Tensor, + query_succ: torch.Tensor, + query_inter: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + softmax_scale: float, + causal: bool, + alibi_slopes: Optional[torch.Tensor], + chunk_size: int, + local_size: int, + original_max_position_embeddings: int, + decode_meta: DualChunkFlashAttentionMetadata, + ): + if not causal: + raise ValueError( + "Dual Chunk Attention does not support causal=False") + + block_size = value_cache.shape[1] + chunk_len = chunk_size - local_size + if chunk_len % block_size != 0: + raise ValueError("chunk_len must be divisible by block_size.") + if original_max_position_embeddings > 0: + assert decode_meta.scaling_factor is not None + scaling_factor = decode_meta.scaling_factor + query = (query * scaling_factor.view(-1, 1, 1, 1)).to( + query.dtype + ) # possible for numerical issue, need to fused in the kernel + query_succ = (query_succ * scaling_factor.view(-1, 1, 1, 1)).to( + query.dtype) + query_inter = (query_inter * scaling_factor.view(-1, 1, 1, 1)).to( + query.dtype) + 
outputs_list = [] + softmax_lses_list = [] + + # intra-attention + intra_output, intra_softmax_lse = ( + self._dual_chunk_flash_attn_decoding_with_exp_sums( + query, + key_cache, + value_cache, + decode_meta.block_tables_intra, + decode_meta.seq_lens_intra, + softmax_scale, + alibi_slopes, + causal=False, + )) + outputs_list.append(intra_output) + softmax_lses_list.append(intra_softmax_lse) + + # succ-attention + if decode_meta.max_seq_len_succ: + succ_output, succ_softmax_lse = ( + self._dual_chunk_flash_attn_decoding_with_exp_sums( + query_succ, + key_cache, + value_cache, + decode_meta.block_tables_succ, + decode_meta.seq_lens_succ, + softmax_scale, + alibi_slopes, + causal=False, + )) + outputs_list.append(succ_output) + softmax_lses_list.append(succ_softmax_lse) + + # inter-attention + if decode_meta.max_seq_len_inter: + inter_output, inter_softmax_lse = ( + self._dual_chunk_flash_attn_decoding_with_exp_sums( + query_inter, + key_cache, + value_cache, + block_table[:, :decode_meta.max_seq_len_inter], + decode_meta.seq_lens_inter, + softmax_scale, + alibi_slopes, + causal=False, + )) + outputs_list.append(inter_output) + softmax_lses_list.append(inter_softmax_lse) + outputs = torch.stack(outputs_list, dim=0) + del outputs_list + softmax_lses = torch.stack(softmax_lses_list, dim=0).to(torch.float32) + del softmax_lses_list + max_logits = torch.max(softmax_lses, dim=0).values + stable_logits = softmax_lses - max_logits.unsqueeze(0) + lse_s = torch.exp(stable_logits).detach() + lse_sum = torch.sum(lse_s, dim=0) + lse_s /= lse_sum + outputs *= lse_s.unsqueeze(-1).transpose(2, 3) + return outputs.sum(0) + + def _dual_chunk_flash_attn_decoding_with_exp_sums( + self, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + softmax_scale: float, + alibi_slopes: Optional[torch.Tensor], + causal: bool, + ): + out, softmax_lse = flash_attn_with_kvcache( + q=query, + k_cache=key_cache, + 
v_cache=value_cache, + block_table=block_table, + cache_seqlens=cache_seqlens, + softmax_scale=softmax_scale, + alibi_slopes=alibi_slopes, + causal=causal, + return_softmax_lse=True, + ) + mask = (cache_seqlens == 0) + out[mask] = 0 + softmax_lse[mask] = -float("inf") + return out, softmax_lse + + +def _vertical_slash_sparse_attention( + query: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD] + key: torch.Tensor, # [BATCH, N_HEADS, N_KV_CTX, D_HEAD] + value: torch.Tensor, # [BATCH, N_HEADS, N_KV_CTX, D_HEAD] + v_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_V] + s_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_S] + softmax_scale: float, + causal: bool = True, + stage: str = "intra", + block_size_M: int = 64, + block_size_N: int = 64, + vertical_indices_count: torch.Tensor = None, # [N_HEADS,] + slash_indices_count: torch.Tensor = None, +): + if stage == "intra": + assert causal + else: + assert not causal + + batch_size, num_heads, context_size, head_dim = query.shape + _, _, kv_seq_len, _ = key.shape + + if head_dim not in [16, 32, 64, 128, 256, 512]: + target_dim = 2**math.ceil(math.log2(head_dim)) - head_dim + query = F.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0]) + key = F.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0]) + value = F.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0]) + + v_idx = v_idx.to(torch.int32).reshape( + (batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0] + s_idx = s_idx.to(torch.int32).reshape( + (batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0] + q_seqlens = torch.tensor([context_size], + dtype=torch.int32, + device=query.device) + kv_seqlens = torch.tensor([kv_seq_len], + dtype=torch.int32, + device=query.device) + + if vertical_indices_count is not None and slash_indices_count is not None: + ( + block_count, + block_offset, + column_count, + column_index, + ) = ops.convert_vertical_slash_indexes_mergehead( + q_seqlens, kv_seqlens, v_idx, s_idx, vertical_indices_count, + slash_indices_count, context_size, block_size_M, 
block_size_N, + causal) + else: + ( + block_count, + block_offset, + column_count, + column_index, + ) = ops.convert_vertical_slash_indexes(q_seqlens, kv_seqlens, v_idx, + s_idx, context_size, + block_size_M, block_size_N, + causal) + + q = query.transpose(1, 2).contiguous() + k = key.transpose(1, 2).contiguous() + v = value.transpose(1, 2).contiguous() + out, lse = sparse_attn_func( + q, + k, + v, + block_count, + block_offset, + column_count, + column_index, + causal=causal, + softmax_scale=softmax_scale, + return_softmax_lse=True, + ) + out = out.transpose(1, 2).contiguous() + softmax_lse = lse.reshape(*lse.shape, 1) + return (out[..., :context_size, :head_dim], + softmax_lse[..., :context_size, :]) + + +def _sum_all_diagonal_matrix(mat: torch.tensor): + h, n, m = mat.shape + # Zero matrix used for padding + zero_mat = torch.zeros((h, n, n), device=mat.device) + # pads the matrix on left and right + mat_padded = torch.cat((zero_mat, mat, zero_mat), -1) + # Change the strides + mat_strided = mat_padded.as_strided((1, n, n + m), + (n * (2 * n + m), 2 * n + m + 1, 1)) + # Sums the resulting matrix's columns + sum_diags = torch.sum(mat_strided, 1) + return sum_diags[:, 1:] # drop left bottom corner + + +def _get_block(block_table: torch.Tensor, block_size: int, begin: int, + end: int): + begin_block = begin // block_size + end_block = (end - 1) // block_size + 1 + return block_table[begin_block:end_block] diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 0100c082aa21..d48462684906 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -211,8 +211,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearBase, UnquantizedLinearMethod) -from vllm.model_executor.layers.rotary_embedding import ( - DeepseekScalingRotaryEmbedding, RotaryEmbedding) from vllm.multimodal import MultiModalPlaceholderMap from vllm.platforms import current_platform from 
vllm.triton_utils import HAS_TRITON @@ -377,7 +375,6 @@ def graph_capture_get_metadata_for_batch( seq_start_loc=None, context_lens_tensor=None, block_tables=self._graph_block_tables[:batch_size], - input_positions=self._positions[:batch_size], head_dim=self.runner.model_config.get_head_size()) if is_encoder_decoder_model: @@ -393,7 +390,6 @@ def get_graph_input_buffers(self, "slot_mapping": attn_metadata.slot_mapping, "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, "block_tables": attn_metadata.decode_metadata.block_tables, - "input_positions": attn_metadata.decode_metadata.input_positions, } if is_encoder_decoder_model: raise NotImplementedError( @@ -405,16 +401,10 @@ def prepare_graph_input_buffers(self, input_buffers, attn_metadata, is_encoder_decoder_model: bool = False): - input_positions = attn_metadata.input_positions - num_positions = input_positions.shape[0] input_buffers["seq_lens_tensor"].copy_( attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) - # CUDA graph buffer is padded so only perform a partial copy based on - # num_positions - input_buffers["input_positions"][:num_positions].copy_( - input_positions, non_blocking=True) if is_encoder_decoder_model: raise NotImplementedError( "TritonMLAState does not support encoder/decoder yet") @@ -456,11 +446,6 @@ class MLACommonMetadata(AttentionMetadata): # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool - # New for MLA (compared to FlashAttention) - # Input positions for rotrary embeddings since for MLA the rotary - # position embeddings are applied inside the attention backend - input_positions: torch.Tensor - # NOTE(sang): Definition of context_len, query_len, and seq_len. 
# |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| @@ -563,8 +548,6 @@ def prefill_metadata(self): self.context_lens_tensor[:self.num_prefills]) block_tables = (None if self.block_tables is None else self.block_tables[:self.num_prefills]) - input_positions = (None if self.input_positions is None else - self.input_positions[:self.num_prefill_tokens]) self._cached_prefill_metadata = self.__class__( # Required by ModelRunner @@ -578,7 +561,6 @@ def prefill_metadata(self): multi_modal_placeholder_index_maps=None, enable_kv_scales_calculation=False, # MLACommonMetadata - input_positions=input_positions, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_query_len=self.max_query_len, @@ -615,8 +597,6 @@ def decode_metadata(self): self.seq_lens_tensor[self.num_prefills:]) block_tables = (None if self.block_tables is None else self.block_tables[self.num_prefills:]) - input_positions = (None if self.input_positions is None else - self.input_positions[self.num_prefill_tokens:]) self._cached_decode_metadata = self.__class__( # Required by ModelRunner @@ -646,7 +626,6 @@ def decode_metadata(self): if self.seq_start_loc is not None else None, context_lens_tensor=None, block_tables=block_tables, - input_positions=input_positions, head_dim=self.head_dim, is_profile_run=self.is_profile_run) return self._cached_decode_metadata @@ -765,7 +744,6 @@ def prepare(self): self.context_lens: List[int] = [] self.block_tables: List[List[int]] = [] self.curr_seq_lens: List[int] = [] - self.input_positions: List[int] = [] self.multimodal_placeholder_maps: Dict[ str, MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) @@ -786,13 +764,11 @@ def _add_seq_group( block_tables = inter_data.block_tables for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, - curr_sliding_window_block, input_positions) in zip( + curr_sliding_window_block) in zip( inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], 
inter_data.orig_seq_lens, inter_data.seq_lens, inter_data.query_lens, inter_data.context_lens, - inter_data.curr_sliding_window_blocks, - inter_data.input_positions): - self.input_positions.extend(input_positions) + inter_data.curr_sliding_window_blocks): self.context_lens.append(context_len) if is_prompt: self.num_prefills += 1 @@ -912,8 +888,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], device, self.runner.pin_memory) seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, self.runner.pin_memory) - input_positions = async_tensor_h2d(self.input_positions, torch.long, - device, self.runner.pin_memory) slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, device, self.runner.pin_memory) query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, @@ -987,7 +961,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], multi_modal_placeholder_index_maps=None, # Not Attention Related enable_kv_scales_calculation=False, # MLACommonMetadata - input_positions=input_positions, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_query_len=max_query_len, @@ -1033,7 +1006,6 @@ def __init__( qk_rope_head_dim: int, qk_head_dim: int, v_head_dim: int, - rotary_emb: RotaryEmbedding, kv_b_proj: ColumnParallelLinear, ) -> None: self.num_heads = num_heads @@ -1048,10 +1020,6 @@ def __init__( self.qk_rope_head_dim = qk_rope_head_dim self.qk_head_dim = qk_head_dim self.v_head_dim = v_head_dim - - self.rotary_emb = rotary_emb - self.use_yarn_rope = isinstance(rotary_emb, - DeepseekScalingRotaryEmbedding) self.kv_b_proj = kv_b_proj self.triton_fa_func = triton_attention @@ -1095,7 +1063,7 @@ def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale, softmax_scale, None, # bias ) - if is_vllm_fa: + elif is_vllm_fa: attn_out = self.flash_attn_varlen_func( q=q, k=k, @@ -1367,41 +1335,15 @@ def forward( has_decode = attn_metadata.decode_metadata is not None has_prefill = attn_metadata.prefill_metadata is not None - # 
Restore head dim (for rotary embedding) - k_pe = k_pe.unsqueeze(1) - assert hasattr(attn_metadata, "input_positions") - num_prefill_tokens: int = attn_metadata.num_prefill_tokens q = q.view(-1, self.num_heads, self.qk_head_dim) decode_q = q[num_prefill_tokens:] - decode_k_pe = k_pe[num_prefill_tokens:] - decode_input_positions = \ - attn_metadata.input_positions[num_prefill_tokens:] prefill_q = q[:num_prefill_tokens] prefill_k_pe = k_pe[:num_prefill_tokens] - prefill_input_positions = \ - attn_metadata.input_positions[:num_prefill_tokens] prefill_k_c_normed = k_c_normed[:num_prefill_tokens] - if has_decode: - decode_q_nope, decode_q_pe = decode_q.split( - [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - # Convert from (B, N, P) to (N, B, P) - decode_q_nope = decode_q_nope.transpose(0, 1) - # Multiply (N, B, P) x (N, P, L) -> (N, B, L) - decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) - # Convert from (N, B, L) to (B, N, L) - decode_ql_nope = decode_ql_nope.transpose(0, 1) - decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( - decode_input_positions, decode_q_pe, decode_k_pe) - - if has_prefill: - prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] - prefill_q_pe[...], prefill_k_pe[...] 
= self.rotary_emb( - prefill_input_positions, prefill_q_pe, prefill_k_pe) - # write the latent and rope to kv cache if kv_cache.numel() > 0: ops.concat_and_cache_mla( @@ -1424,6 +1366,15 @@ def forward( attn_metadata) if has_decode: + decode_q_nope, decode_q_pe = decode_q.split( + [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + # Convert from (B, N, P) to (N, B, P) + decode_q_nope = decode_q_nope.transpose(0, 1) + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + decode_ql_nope = decode_ql_nope.transpose(0, 1) + output[num_prefill_tokens:] = self._forward_decode( decode_ql_nope, decode_q_pe, kv_cache, attn_metadata) diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py index 2984bc1dad64..b048220020f1 100644 --- a/vllm/attention/backends/rocm_aiter_mla.py +++ b/vllm/attention/backends/rocm_aiter_mla.py @@ -53,7 +53,7 @@ def get_state_cls() -> Type["AiterMLAState"]: @dataclass class AiterMLAMetadata(MLACommonMetadata): - # The following 4 tensors are for current version of AITER MLA + # The following 5 tensors are for current version of AITER MLA block_table_bound: Optional[torch.Tensor] = None # The indptr of the paged kv cache, shape: [batch_size + 1] paged_kv_indptr: Optional[torch.Tensor] = None @@ -63,6 +63,10 @@ class AiterMLAMetadata(MLACommonMetadata): # the paged kv cache, shape: [batch_size] paged_kv_last_page_lens: Optional[torch.Tensor] = None + # This is just to make new AITER MLA API work + # -- MTP support is not added yet. 
+ qo_indptr: Optional[torch.Tensor] = None + @property def prefill_metadata(self): prefill_metadata = super().prefill_metadata @@ -74,6 +78,7 @@ def prefill_metadata(self): prefill_metadata\ .paged_kv_last_page_lens = self.paged_kv_last_page_lens prefill_metadata.block_table_bound = self.block_table_bound + prefill_metadata.qo_indptr = self.qo_indptr # update the cache self._cached_prefill_metadata = self.__class__( @@ -93,6 +98,7 @@ def decode_metadata(self): decode_metadata\ .paged_kv_last_page_lens = self.paged_kv_last_page_lens decode_metadata.block_table_bound = self.block_table_bound + decode_metadata.qo_indptr = self.qo_indptr # update the cache self._cached_decode_metadata = self.__class__( @@ -136,6 +142,7 @@ def prepare(self): self.paged_kv_indptr: list[int] = [0] self.paged_kv_last_page_lens: list[int] = [] self.total_blocks = 0 + self.qo_indptr: list[int] = [0] def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool, prefix_cache_hit: bool): @@ -148,13 +155,11 @@ def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool, block_tables = inter_data.block_tables for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, - curr_sliding_window_block, input_positions) in zip( + curr_sliding_window_block) in zip( inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], inter_data.orig_seq_lens, inter_data.seq_lens, inter_data.query_lens, inter_data.context_lens, - inter_data.curr_sliding_window_blocks, - inter_data.input_positions): - self.input_positions.extend(input_positions) + inter_data.curr_sliding_window_blocks): self.context_lens.append(context_len) if is_prompt: self.num_prefills += 1 @@ -210,6 +215,7 @@ def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int): self.paged_kv_indices.extend(block_table[:block_table_bound]) self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + block_table_bound) + self.qo_indptr.append(self.qo_indptr[-1] + 1) last_page_len = seq_len % self.block_size if 
last_page_len == 0: @@ -228,6 +234,8 @@ def build(self, seq_lens: list[int], query_lens: list[int], self.paged_kv_indptr.extend([last_paged_kv_indptr] * cuda_graph_pad_size) self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size) + last_qo_indptr = self.qo_indptr[-1] + self.qo_indptr.extend([last_qo_indptr] * cuda_graph_pad_size) # For current version of AITER MLA if len(self.paged_kv_indptr) > 0: @@ -247,16 +255,22 @@ def build(self, seq_lens: list[int], query_lens: list[int], 1, device=device, dtype=torch.int) + + qo_indptr = torch.tensor(self.qo_indptr, + device=device, + dtype=torch.int) else: paged_kv_indices_tensor = None paged_kv_indptr_tensor = None paged_kv_last_page_lens_tensor = None block_table_bound_tensor = None + qo_indptr = None metadata.paged_kv_indptr = paged_kv_indptr_tensor metadata.paged_kv_indices = paged_kv_indices_tensor metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor metadata.block_table_bound = block_table_bound_tensor + metadata.qo_indptr = qo_indptr return metadata @@ -265,14 +279,17 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]): @contextmanager def graph_capture(self, max_batch_size: int): - kv_indices, kv_indptr, last_page_lens = get_aiter_mla_metadata( - max_batch_size=max_batch_size, - block_size=self.runner.block_size, - max_block_per_batch=self.runner.get_max_block_per_batch(), - device=self.runner.device) + kv_indices, kv_indptr, last_page_lens, qo_indptr = \ + get_aiter_mla_metadata( + max_batch_size=max_batch_size, + block_size=self.runner.block_size, + max_block_per_batch=\ + self.runner.get_max_block_per_batch(), + device=self.runner.device) self._paged_kv_indices_tensor = kv_indices self._paged_kv_indptr_tensor = kv_indptr self._paged_kv_last_page_lens_tensor = last_page_lens + self._qo_indptr_tensor = qo_indptr with super().graph_capture(max_batch_size): yield @@ -280,6 +297,7 @@ def graph_capture(self, max_batch_size: int): del self._paged_kv_indices_tensor del 
self._paged_kv_indptr_tensor del self._paged_kv_last_page_lens_tensor + del self._qo_indptr_tensor def graph_capture_get_metadata_for_batch( self, @@ -293,10 +311,12 @@ def graph_capture_get_metadata_for_batch( paged_kv_indices = self._paged_kv_indices_tensor paged_kv_last_page_lens = self._paged_kv_last_page_lens_tensor[: batch_size] + qo_indptr = self._qo_indptr_tensor[:batch_size + 1] metadata.paged_kv_indptr = paged_kv_indptr metadata.paged_kv_indices = paged_kv_indices metadata.paged_kv_last_page_lens = paged_kv_last_page_lens + metadata.qo_indptr = qo_indptr return metadata @@ -313,6 +333,7 @@ def get_graph_input_buffers(self, input_buffers[ "paged_kv_last_page_lens"] = attn_metadata.\ decode_metadata.paged_kv_last_page_lens + input_buffers['qo_indptr'] = attn_metadata.qo_indptr return input_buffers @@ -332,6 +353,8 @@ def prepare_graph_input_buffers(self, input_buffers["paged_kv_last_page_lens"].copy_( attn_metadata.decode_metadata.paged_kv_last_page_lens, non_blocking=True) + input_buffers["qo_indptr"].copy_( + attn_metadata.decode_metadata.qo_indptr, non_blocking=True) class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): @@ -372,11 +395,9 @@ def _flash_attn_varlen_diff_headdims( softmax_scale: float, return_softmax_lse: bool, **kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]: output = self.flash_attn_varlen_func( - q=q, - k=k, - v=v, - softmax_scale=softmax_scale, - return_lse=return_softmax_lse, + q, + k, + v, **kwargs, ) @@ -396,7 +417,7 @@ def _forward_decode( B = q_nope.shape[0] q = torch.cat([q_nope, q_pe], dim=-1) - o = torch.zeros(B, + o = torch.empty(B, self.num_heads, self.kv_lora_rank, dtype=q.dtype, @@ -405,6 +426,8 @@ def _forward_decode( kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) aiter_mla_decode_fwd(q, kv_buffer, o, self.scale, + attn_metadata.qo_indptr, + attn_metadata.max_query_len, attn_metadata.paged_kv_indptr, attn_metadata.paged_kv_indices, attn_metadata.paged_kv_last_page_lens) diff --git 
a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 8076c4791d3c..abcb68911a8b 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -861,7 +861,8 @@ def forward( gqa_ratio = num_heads // self.num_kv_heads use_custom = use_rocm_custom_paged_attention( decode_query.dtype, head_size, block_size, gqa_ratio, - decode_meta.max_decode_seq_len, self.sliding_window) + decode_meta.max_decode_seq_len, self.sliding_window, + self.kv_cache_dtype, self.alibi_slopes) if use_custom: max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type != AttentionType.ENCODER_DECODER else diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 54ffd5c45ff9..a281c9771a82 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -345,10 +345,10 @@ def graph_capture_get_metadata_for_batch( if is_encoder_decoder_model: # The encoder decoder model works only with XFormers and # Flash Attention backend. Assert the same. - assert self.runner.attn_backend.get_name() in\ - ["XFORMERS", "FLASH_ATTN"], \ - f"Expected attn_backend name to be either 'XFORMERS' or " \ - f"'FLASH_ATTN', but "\ + assert self.runner.attn_backend.get_name() in \ + ["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \ + f"Expected attn_backend name to be either 'XFORMERS'," \ + f"'ROCM_FLASH', or 'FLASH_ATTN', but " \ f"got '{self.runner.attn_backend.get_name()}'" self._update_captured_metadata_for_enc_dec_model( batch_size=batch_size, attn_metadata=attn_metadata) @@ -367,10 +367,10 @@ def get_graph_input_buffers( if is_encoder_decoder_model: # The encoder decoder model works only with XFormers and # Flash Attention backend. Assert the same. 
- assert self.runner.attn_backend.get_name() in\ - ["XFORMERS", "FLASH_ATTN"], \ - f"Expected attn_backend name to be either 'XFORMERS' or "\ - f"'FLASH_ATTN', but "\ + assert self.runner.attn_backend.get_name() in \ + ["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \ + f"Expected attn_backend name to be either 'XFORMERS'," \ + f"'ROCM_FLASH', or 'FLASH_ATTN', but " \ f"got '{self.runner.attn_backend.get_name()}'" self._add_additonal_input_buffers_for_enc_dec_model( attn_metadata=attn_metadata, input_buffers=input_buffers) diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py index dc039a0259aa..785799b6bf68 100644 --- a/vllm/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/attention/ops/chunked_prefill_paged_decode.py @@ -9,6 +9,7 @@ import torch from vllm import _custom_ops as ops +from vllm.platforms import current_platform from vllm.platforms.rocm import use_rocm_custom_paged_attention from vllm.triton_utils import tl, triton @@ -267,7 +268,7 @@ def chunked_prefill_paged_decode( assert value_cache.dtype == torch.uint8 if kv_cache_dtype in ("fp8", "fp8_e4m3"): - target_dtype = torch.float8_e4m3fn + target_dtype = current_platform.fp8_dtype() elif kv_cache_dtype == "fp8_e5m2": target_dtype = torch.float8_e5m2 else: @@ -282,7 +283,8 @@ def chunked_prefill_paged_decode( use_custom = use_rocm_custom_paged_attention(query.dtype, head_size, block_size, num_queries_per_kv, - max_seq_len, sliding_window) + max_seq_len, sliding_window, + kv_cache_dtype, alibi_slopes) if use_custom: _PARTITION_SIZE_ROCM = 256 max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) // diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py index 1c90f8c19b09..421891ab6b73 100644 --- a/vllm/attention/ops/rocm_aiter_mla.py +++ b/vllm/attention/ops/rocm_aiter_mla.py @@ -4,6 +4,9 @@ import torch +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op + 
def get_aiter_mla_metadata(max_batch_size: int, block_size: int, max_block_per_batch: int, @@ -17,7 +20,8 @@ def get_aiter_mla_metadata(max_batch_size: int, block_size: int, paged_kv_last_page_lens = torch.full((max_batch_size, ), block_size, dtype=torch.int32) - return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens + qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device) + return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr def aiter_mla_decode_fwd( @@ -25,18 +29,71 @@ def aiter_mla_decode_fwd( kv_buffer: torch.Tensor, o: torch.Tensor, sm_scale: float, + qo_indptr: torch.Tensor, + max_seqlen_qo: int, kv_indptr: Optional[torch.Tensor] = None, kv_indices: Optional[torch.Tensor] = None, kv_last_page_lens: Optional[torch.Tensor] = None, logit_cap: float = 0.0, ): + + torch.ops.vllm.rocm_aiter_mla_decode_fwd(q, + kv_buffer.view( + -1, 1, 1, q.shape[-1]), + o, + qo_indptr, + max_seqlen_qo, + kv_indptr, + kv_indices, + kv_last_page_lens, + sm_scale=sm_scale, + logit_cap=logit_cap) + + +def mla_decode_fwd_impl( + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + qo_indptr: torch.Tensor, + max_seqlen_qo: int, + kv_indptr: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, + kv_last_page_lens: Optional[torch.Tensor] = None, + sm_scale: float = 1.0, + logit_cap: float = 0.0, +) -> None: from aiter.mla import mla_decode_fwd mla_decode_fwd(q, kv_buffer.view(-1, 1, 1, q.shape[-1]), o, + qo_indptr, kv_indptr, kv_indices, kv_last_page_lens, + max_seqlen_qo, sm_scale=sm_scale, logit_cap=logit_cap) + + +def mla_decode_fwd_fake( + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + qo_indptr: torch.Tensor, + max_seqlen_qo: int, + kv_indptr: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, + kv_last_page_lens: Optional[torch.Tensor] = None, + sm_scale: float = 1.0, + logit_cap: float = 0.0, +) -> None: + pass + + +if current_platform.is_rocm(): + 
direct_register_custom_op(op_name="rocm_aiter_mla_decode_fwd", + op_func=mla_decode_fwd_impl, + mutates_args=["o"], + fake_impl=mla_decode_fwd_fake, + tags=[torch.Tag.needs_fixed_stride_order]) diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index 8c0cf9267f35..4bced779785a 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -31,8 +31,8 @@ def apply_softcap(S, x): def kernel_unified_attention_2d( output_ptr, # [num_tokens, num_query_heads, head_size] query_ptr, # [num_tokens, num_query_heads, head_size] - key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] - value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] + key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] seq_lens_ptr, # [num_seqs] alibi_slopes_ptr, # [num_query_heads] @@ -56,11 +56,11 @@ def kernel_unified_attention_2d( stride_k_cache_0: tl.int64, # int stride_k_cache_1: tl.int64, # int stride_k_cache_2: tl.int64, # int - stride_k_cache_3: tl.int64, # int + stride_k_cache_3: tl.constexpr, # int stride_v_cache_0: tl.int64, # int stride_v_cache_1: tl.int64, # int stride_v_cache_2: tl.int64, # int - stride_v_cache_3: tl.int64, # int + stride_v_cache_3: tl.constexpr, # int query_start_len_ptr, # [num_seqs+1] BLOCK_Q: tl.constexpr, # int num_seqs: tl.int32, @@ -268,6 +268,10 @@ def unified_attention( assert causal, "Only causal attention is supported" assert q_descale is None, "Q scales not supported" + block_size = v.shape[1] + assert q.element_size() >= 2 or block_size >= 32, \ + "Block size must be at least 32 for fp8" + use_alibi_slopes = alibi_slopes is not None block_size = v.shape[1] diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index fab44fb6062d..74a9b2b03391 100644 --- 
a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -13,7 +13,6 @@ TODO: Implement CustomDataset to parse a JSON file and convert its contents into SampleRequest instances, similar to the approach used in ShareGPT. """ - import base64 import io import json @@ -33,6 +32,7 @@ from vllm.lora.request import LoRARequest from vllm.lora.utils import get_adapter_absolute_path from vllm.multimodal import MultiModalDataDict +from vllm.multimodal.image import convert_image_mode from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer logger = logging.getLogger(__name__) @@ -129,16 +129,17 @@ def get_random_lora_request( Args: tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if no - LoRA is selected. max_loras (Optional[int]): The maximum number of - LoRAs available. If None, LoRA is not used. lora_path - (Optional[str]): Path to the LoRA parameters on disk. If None, LoRA - is not used. + LoRA is selected. + max_loras (Optional[int]): The maximum number of LoRAs available. + If `None`, LoRA is not used. + lora_path (Optional[str]): Path to the LoRA parameters on disk. + If `None`, LoRA is not used. Returns: - tuple[Optional[LoRARequest], AnyTokenizer]: A tuple where the first - element is a LoRARequest (or None if not applicable) and the second - element is the tokenizer associated with the LoRA request (or the - base tokenizer). + A tuple with the following elements: + - A new [LoRARequest][] (or `None` if not applicable). + - The tokenizer associated with the LoRA request + (or the base tokenizer). """ if max_loras is None or lora_path is None: return None, tokenizer @@ -167,7 +168,7 @@ def sample(self, tokenizer: PreTrainedTokenizerBase, Args: tokenizer (PreTrainedTokenizerBase): The tokenizer to be used - for processing the dataset's text. + for processing the dataset's text. num_requests (int): The number of sample requests to generate. 
Returns: @@ -184,7 +185,8 @@ def maybe_oversample_requests(self, requests: list[SampleRequest], Args: requests (List[SampleRequest]): The current list of sampled - requests. num_requests (int): The target number of requests. + requests. + num_requests (int): The target number of requests. """ if len(requests) < num_requests: random.seed(self.random_seed) @@ -259,7 +261,7 @@ def process_image(image: Any) -> Mapping[str, Any]: if isinstance(image, dict) and 'bytes' in image: image = Image.open(BytesIO(image['bytes'])) if isinstance(image, Image.Image): - image = image.convert("RGB") + image = convert_image_mode(image, "RGB") with io.BytesIO() as image_data: image.save(image_data, format="JPEG") image_base64 = base64.b64encode( diff --git a/vllm/benchmarks/latency.py b/vllm/benchmarks/latency.py index 06f6848f50cb..2c992727b139 100644 --- a/vllm/benchmarks/latency.py +++ b/vllm/benchmarks/latency.py @@ -80,6 +80,9 @@ def add_cli_args(parser: argparse.ArgumentParser): ) parser = EngineArgs.add_cli_args(parser) + # V1 enables prefix caching by default which skews the latency + # numbers. We need to disable prefix caching by default. 
+ parser.set_defaults(enable_prefix_caching=False) def main(args: argparse.Namespace): diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py index b3e24911cc98..13110a8b4db3 100644 --- a/vllm/benchmarks/throughput.py +++ b/vllm/benchmarks/throughput.py @@ -148,9 +148,10 @@ async def run_vllm_async( async with build_async_engine_client_from_engine_args( engine_args, disable_frontend_multiprocessing) as llm: + model_config = await llm.get_model_config() assert all( - llm.model_config.max_model_len >= (request.prompt_len + - request.expected_output_len) + model_config.max_model_len >= (request.prompt_len + + request.expected_output_len) for request in requests), ( "Please ensure that max_model_len is greater than the sum of" " prompt_len and expected_output_len for all requests.") diff --git a/vllm/collect_env.py b/vllm/collect_env.py index 9cfceb7c45cc..86eb465b8f65 100644 --- a/vllm/collect_env.py +++ b/vllm/collect_env.py @@ -637,33 +637,50 @@ def get_version_or_na(cfg, prefix): env_info_fmt = """ -PyTorch version: {torch_version} -Is debug build: {is_debug_build} -CUDA used to build PyTorch: {cuda_compiled_version} -ROCM used to build PyTorch: {hip_compiled_version} - -OS: {os} -GCC version: {gcc_version} -Clang version: {clang_version} -CMake version: {cmake_version} -Libc version: {libc_version} - -Python version: {python_version} -Python platform: {python_platform} -Is CUDA available: {is_cuda_available} -CUDA runtime version: {cuda_runtime_version} -CUDA_MODULE_LOADING set to: {cuda_module_loading} -GPU models and configuration: {nvidia_gpu_models} -Nvidia driver version: {nvidia_driver_version} -cuDNN version: {cudnn_version} -HIP runtime version: {hip_runtime_version} -MIOpen runtime version: {miopen_runtime_version} -Is XNNPACK available: {is_xnnpack_available} - -CPU: +============================== + System Info +============================== +OS : {os} +GCC version : {gcc_version} +Clang version : {clang_version} +CMake version :
{cmake_version} +Libc version : {libc_version} + +============================== + PyTorch Info +============================== +PyTorch version : {torch_version} +Is debug build : {is_debug_build} +CUDA used to build PyTorch : {cuda_compiled_version} +ROCM used to build PyTorch : {hip_compiled_version} + +============================== + Python Environment +============================== +Python version : {python_version} +Python platform : {python_platform} + +============================== + CUDA / GPU Info +============================== +Is CUDA available : {is_cuda_available} +CUDA runtime version : {cuda_runtime_version} +CUDA_MODULE_LOADING set to : {cuda_module_loading} +GPU models and configuration : {nvidia_gpu_models} +Nvidia driver version : {nvidia_driver_version} +cuDNN version : {cudnn_version} +HIP runtime version : {hip_runtime_version} +MIOpen runtime version : {miopen_runtime_version} +Is XNNPACK available : {is_xnnpack_available} + +============================== + CPU Info +============================== {cpu_info} -Versions of relevant libraries: +============================== +Versions of relevant libraries +============================== {pip_packages} {conda_packages} """.strip() @@ -671,17 +688,23 @@ def get_version_or_na(cfg, prefix): # both the above code and the following code use `strip()` to # remove leading/trailing whitespaces, so we need to add a newline # in between to separate the two sections -env_info_fmt += "\n" +env_info_fmt += "\n\n" env_info_fmt += """ -ROCM Version: {rocm_version} -Neuron SDK Version: {neuron_sdk_version} -vLLM Version: {vllm_version} +============================== + vLLM Info +============================== +ROCM Version : {rocm_version} +Neuron SDK Version : {neuron_sdk_version} +vLLM Version : {vllm_version} vLLM Build Flags: -{vllm_build_flags} + {vllm_build_flags} GPU Topology: -{gpu_topo} + {gpu_topo} +============================== + Environment Variables +============================== 
{env_vars} """.strip() @@ -792,4 +815,4 @@ def main(): if __name__ == '__main__': - main() + main() \ No newline at end of file diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py index 1917ed8bbebb..dc3e1482e2b4 100644 --- a/vllm/compilation/activation_quant_fusion.py +++ b/vllm/compilation/activation_quant_fusion.py @@ -7,6 +7,7 @@ from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.platforms import current_platform from .vllm_inductor_pass import VllmInductorPass @@ -41,7 +42,7 @@ def empty_bf16(*args, **kwargs): def empty_fp8(*args, **kwargs): - fp8 = torch.float8_e4m3fn + fp8 = current_platform.fp8_dtype() return torch.empty(*args, **kwargs, dtype=fp8, device="cuda") diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index a1ff5fb1196b..8114cddcd9fa 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -5,9 +5,8 @@ import os import pprint import time -from contextlib import ExitStack -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple -from unittest.mock import patch +from collections.abc import Sequence +from typing import Any, Callable, Optional import torch import torch.fx as fx @@ -15,17 +14,31 @@ import vllm.envs as envs from vllm.config import CompilationConfig, VllmConfig from vllm.logger import init_logger -from vllm.utils import weak_ref_tensors +from vllm.platforms import current_platform +from vllm.utils import resolve_obj_by_qualname -from .compiler_interface import EagerAdaptor, InductorAdaptor +from .compiler_interface import (CompilerInterface, EagerAdaptor, + InductorAdaptor, InductorStandaloneAdaptor) from .counter import compilation_counter from .inductor_pass import InductorPass -from .monitor import end_monitoring_torch_compile from .pass_manager import PostGradPassManager logger = init_logger(__name__) +def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface: + if 
compilation_config.use_inductor: + if envs.VLLM_TEST_STANDALONE_COMPILE: + logger.info("Using InductorStandaloneAdaptor") + return InductorStandaloneAdaptor() + else: + logger.info("Using InductorAdaptor") + return InductorAdaptor() + else: + logger.info("Using EagerAdaptor") + return EagerAdaptor() + + class CompilerManager: """ A manager to manage the compilation process, including @@ -41,11 +54,11 @@ class CompilerManager: support int as key. """ - def __init__(self, use_inductor: bool): - self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict() - cls = InductorAdaptor if use_inductor else EagerAdaptor - self.compiler = cls() + def __init__(self, compilation_config: CompilationConfig): + self.cache: dict[tuple[Optional[int], int, str], Any] = dict() self.is_cache_updated = False + self.compilation_config = compilation_config + self.compiler = make_compiler(compilation_config) def compute_hash(self, vllm_config: VllmConfig) -> str: return self.compiler.compute_hash(vllm_config) @@ -76,7 +89,7 @@ def save_to_file(self): def load(self, graph: fx.GraphModule, - example_inputs: List[Any], + example_inputs: list[Any], graph_index: int, runtime_shape: Optional[int] = None) -> Optional[Callable]: if (runtime_shape, graph_index, self.compiler.name) not in self.cache: @@ -123,8 +136,15 @@ def compile(self, # no compiler cached the graph, or the cache is disabled, # we need to compile it + if isinstance(self.compiler, InductorAdaptor): + # Let compile_fx generate a key for us + maybe_key = None + else: + maybe_key = \ + f"artifact_shape_{runtime_shape}_subgraph_{graph_index}" compiled_graph, handle = self.compiler.compile( - graph, example_inputs, additional_inductor_config, runtime_shape) + graph, example_inputs, additional_inductor_config, runtime_shape, + maybe_key) assert compiled_graph is not None, "Failed to compile the graph" @@ -165,7 +185,7 @@ class SplitItem: def split_graph(graph: fx.GraphModule, - ops: List[str]) -> Tuple[fx.GraphModule, List[SplitItem]]: 
+ ops: list[str]) -> tuple[fx.GraphModule, list[SplitItem]]: # split graph by ops subgraph_id = 0 node_to_subgraph_id = {} @@ -231,7 +251,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): """ def __init__(self, module: torch.fx.GraphModule, - compile_submod_names: List[str], vllm_config: VllmConfig, + compile_submod_names: list[str], vllm_config: VllmConfig, graph_pool, vllm_backend: "VllmBackend"): super().__init__(module) from torch._guards import detect_fake_mode @@ -253,8 +273,8 @@ def run(self, *args): return super().run(*fake_args) def call_module(self, target: torch.fx.node.Target, - args: Tuple[torch.fx.node.Argument, - ...], kwargs: Dict[str, Any]) -> Any: + args: tuple[torch.fx.node.Argument, + ...], kwargs: dict[str, Any]) -> Any: assert isinstance(target, str) output = super().call_module(target, args, kwargs) @@ -275,7 +295,9 @@ def call_module(self, target: torch.fx.node.Target, num_graphs=len(self.compile_submod_names), runtime_shape=None) - self.module.__dict__[target] = PiecewiseBackend( + piecewise_backend = resolve_obj_by_qualname( + current_platform.get_piecewise_backend_cls()) + self.module.__dict__[target] = piecewise_backend( submod, self.vllm_config, self.graph_pool, index, len(self.compile_submod_names), sym_shape_indices, compiled_graph_for_general_shape, self.vllm_backend) @@ -305,12 +327,12 @@ class VllmBackend: graph: fx.GraphModule # the stiching graph module for all the piecewise graphs split_gm: fx.GraphModule - piecewise_graphs: List[SplitItem] + piecewise_graphs: list[SplitItem] returned_callable: Callable # Inductor passes to run on the graph pre-defunctionalization post_grad_passes: Sequence[Callable] - sym_tensor_indices: List[int] - input_buffers: List[torch.Tensor] + sym_tensor_indices: list[int] + input_buffers: list[torch.Tensor] compiler_manager: CompilerManager def __init__( @@ -319,7 +341,7 @@ def __init__( ): global global_graph_pool if global_graph_pool is None: - global_graph_pool = 
torch.cuda.graph_pool_handle() + global_graph_pool = current_platform.graph_pool_handle() # TODO: in the future, if we want to use multiple # streams, it might not be safe to share a global pool. @@ -336,7 +358,7 @@ def __init__( self.compilation_config = vllm_config.compilation_config self.compiler_manager: CompilerManager = CompilerManager( - self.compilation_config.use_inductor) + self.compilation_config) # `torch.compile` is JIT compiled, so we don't need to # do anything here @@ -536,197 +558,3 @@ def copy_and_call(*args): return self.split_gm(*list_args) return copy_and_call - - -@dataclasses.dataclass -class ConcreteSizeEntry: - runtime_shape: int - need_to_compile: bool # the size is in compile_sizes - use_cudagraph: bool # the size is in cudagraph_capture_sizes - - compiled: bool = False - runnable: Callable = None # type: ignore - num_finished_warmup: int = 0 - cudagraph: Optional[torch.cuda.CUDAGraph] = None - output: Optional[Any] = None - - # for cudagraph debugging, track the input addresses - # during capture, and check if they are the same during replay - input_addresses: Optional[List[int]] = None - - -class PiecewiseBackend: - - def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, - graph_pool: Any, piecewise_compile_index: int, - total_piecewise_compiles: int, sym_shape_indices: List[int], - compiled_graph_for_general_shape: Callable, - vllm_backend: VllmBackend): - """ - The backend for piecewise compilation. - It mainly handles the compilation and cudagraph capturing. - - We will compile `self.graph` once for the general shape, - and then compile for different shapes specified in - `compilation_config.compile_sizes`. - - Independently, we will capture cudagraph for different shapes. - - If a shape needs both compilation and cudagraph, we will - compile it first, and then capture cudagraph. 
- """ - self.graph = graph - self.vllm_config = vllm_config - self.compilation_config = vllm_config.compilation_config - self.graph_pool = graph_pool - self.piecewise_compile_index = piecewise_compile_index - self.total_piecewise_compiles = total_piecewise_compiles - self.vllm_backend = vllm_backend - - self.is_first_graph = piecewise_compile_index == 0 - self.is_last_graph = ( - piecewise_compile_index == total_piecewise_compiles - 1) - - self.compile_sizes: Set[int] = set( - self.compilation_config.compile_sizes) - self.cudagraph_capture_sizes: Set[int] = set( - self.compilation_config.cudagraph_capture_sizes - ) if self.compilation_config.use_cudagraph else set() - - self.first_run_finished = False - - self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa - - self.sym_shape_indices = sym_shape_indices - - self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" - - # the entries for different shapes that we need to either - # compile or capture cudagraph - self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {} - - # to_be_compiled_sizes tracks the remaining sizes to compile, - # and updates during the compilation process, so we need to copy it - self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy() - for shape in self.compile_sizes.union(self.cudagraph_capture_sizes): - self.concrete_size_entries[shape] = ConcreteSizeEntry( - runtime_shape=shape, - need_to_compile=shape in self.compile_sizes, - use_cudagraph=shape in self.cudagraph_capture_sizes, - ) - - def check_for_ending_compilation(self): - if self.is_last_graph and not self.to_be_compiled_sizes: - # no specific sizes to compile - # save the hash of the inductor graph for the next run - self.vllm_backend.compiler_manager.save_to_file() - end_monitoring_torch_compile(self.vllm_config) - - def __call__(self, *args) -> Any: - if not self.first_run_finished: - self.first_run_finished = True - self.check_for_ending_compilation() - return 
self.compiled_graph_for_general_shape(*args) - - runtime_shape = args[self.sym_shape_indices[0]] - if runtime_shape not in self.concrete_size_entries: - # we don't need to do anything for this shape - return self.compiled_graph_for_general_shape(*args) - - entry = self.concrete_size_entries[runtime_shape] - - if entry.runnable is None: - entry.runnable = self.compiled_graph_for_general_shape - - if entry.need_to_compile and not entry.compiled: - entry.compiled = True - self.to_be_compiled_sizes.remove(runtime_shape) - # args are real arguments - entry.runnable = self.vllm_backend.compiler_manager.compile( - self.graph, - args, - self.compilation_config.inductor_compile_config, - self.compilation_config, - graph_index=self.piecewise_compile_index, - num_graphs=self.total_piecewise_compiles, - runtime_shape=runtime_shape) - - # finished compilations for all required shapes - if self.is_last_graph and not self.to_be_compiled_sizes: - self.check_for_ending_compilation() - - if not entry.use_cudagraph: - return entry.runnable(*args) - - if entry.cudagraph is None: - if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa - entry.num_finished_warmup += 1 - if self.is_first_graph: - logger.debug( - "Warming up %s/%s for shape %s", - entry.num_finished_warmup, - self.compilation_config.cudagraph_num_of_warmups, - runtime_shape) - return entry.runnable(*args) - - if self.is_first_graph: - # Since we capture cudagraph for many different shapes and - # capturing is fast, we don't need to log it for every shape. - # We only log it in the debug mode. - logger.debug("Capturing a cudagraph for shape %s", - runtime_shape) - - input_addresses = [ - x.data_ptr() for x in args if isinstance(x, torch.Tensor) - ] - entry.input_addresses = input_addresses - cudagraph = torch.cuda.CUDAGraph() - - with ExitStack() as stack: - if not self.is_first_graph: - # during every model forward, we will capture - # many pieces of cudagraphs (roughly one per layer). 
- # running gc again and again across layers will - # make the cudagraph capture very slow. - # therefore, we only run gc for the first graph, - # and disable gc for the rest of the graphs. - stack.enter_context(patch("gc.collect", lambda: None)) - stack.enter_context( - patch("torch.cuda.empty_cache", lambda: None)) - - # mind-exploding: carefully manage the reference and memory. - with torch.cuda.graph(cudagraph, pool=self.graph_pool): - # `output` is managed by pytorch's cudagraph pool - output = entry.runnable(*args) - if self.is_last_graph: - # by converting it to weak ref, - # the original `output` will immediately be released - # to save memory. It is only safe to do this for - # the last graph, because the output of the last graph - # will not be used by any other cuda graph. - output = weak_ref_tensors(output) - - # here we always use weak ref for the output - # to save memory - entry.output = weak_ref_tensors(output) - entry.cudagraph = cudagraph - - compilation_counter.num_cudagraph_caputured += 1 - - # important: we need to return the output, rather than - # the weak ref of the output, so that pytorch can correctly - # manage the memory during cuda graph capture - return output - - if self.is_debugging_mode: - # check if the input addresses are the same - new_input_addresses = [ - x.data_ptr() for x in args if isinstance(x, torch.Tensor) - ] - assert new_input_addresses == entry.input_addresses, ( - "Input addresses for cudagraphs are different during replay." 
- f" Expected {entry.input_addresses}, got {new_input_addresses}" - ) - - entry.cudagraph.replay() - return entry.output diff --git a/vllm/compilation/base_piecewise_backend.py b/vllm/compilation/base_piecewise_backend.py new file mode 100644 index 000000000000..84d1e1f77739 --- /dev/null +++ b/vllm/compilation/base_piecewise_backend.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Protocol + +import torch.fx as fx + +from vllm.compilation.backends import VllmBackend +from vllm.config import VllmConfig + + +class AbstractPiecewiseBackend(Protocol): + """ + PiecewiseBackend interface that allows platforms to extend + piecewise static graph. + """ + + def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, + graph_pool: Any, piecewise_compile_index: int, + total_piecewise_compiles: int, sym_shape_indices: list[int], + compiled_graph_for_general_shape: Callable, + vllm_backend: VllmBackend, **kwargs): + """ + Initializes the PiecewiseBackend class with compilation and + execution-related configurations. + + This class handles piecewise compilation, graph capturing, + and dispatching for specific input shapes. + + Args: + graph (fx.GraphModule): The graph represented in fx. + vllm_config (VllmConfig): Global configuration for vLLM. + graph_pool (Any): + Graph memory pool handle, e.g., + `torch.cuda.graph_pool_handle()`. + piecewise_compile_index (int): + Index of the current piecewise subgraph. + total_piecewise_compiles (int): + Total number of piecewise-compiled graphs. + sym_shape_indices (list[int]): + Indices of symbolic shape. + compiled_graph_for_general_shape (Callable): + Callable that executes the graph compiled for general shapes. + vllm_backend (VllmBackend): + Backend compiler that manages compilation and graph runtime + for vLLM. + + Keyword Args: + kwargs: Additional keyword arguments reserved for future + extensions or custom platforms. 
+ """ + raise NotImplementedError + + def __call__(self, *args) -> Any: + """Executes the compiled graph for given input args. + + If this is the first invocation, executes the general compiled graph + and initiates the compilation process tracking. For subsequent calls, + dynamically dispatches execution to either a compiled graph or a static + graph based on the input shape. + + Args: + *args: Variable length input arguments to be passed into the + graph. The symbolic shape is expected to be in position + `sym_shape_indices[0]`. + + Returns: + Any: Output of the executed graph. This can be from the general + compiled graph, a specialized compiled version for the given shape, + or a replayed static graph. + """ + raise NotImplementedError diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py new file mode 100644 index 000000000000..f651ee6912ab --- /dev/null +++ b/vllm/compilation/collective_fusion.py @@ -0,0 +1,126 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import torch +import torch._inductor.pattern_matcher as pm +import torch.fx as fx +from torch._inductor.pattern_matcher import PatternMatcherPass +from torch.distributed._symmetric_memory import enable_symm_mem_for_group + +from vllm.config import VllmConfig +from vllm.distributed import get_tp_group +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_world_size) +from vllm.logger import init_logger + +from .vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) + + +class BasePattern: + + def __init__(self, dtype: torch.dtype, device: str): + self.dtype = dtype + self.device = device + self.tp = get_tp_group() + self.tp_size = get_tensor_model_parallel_world_size() + + +class GEMMReduceScatterPattern(BasePattern): + + def get_inputs(self): + mul = torch.empty([16, 4], device=self.device, dtype=self.dtype) + mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) + return [mul, 
mm_weight] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern(mul: torch.Tensor, mm_weight: torch.Tensor): + mm = torch.ops.aten.mm.default(mul, mm_weight) + reduce_scatter = torch.ops.vllm.reduce_scatter.default( + mm, + dim=0, + world_size=self.tp_size, + group_name=self.tp.unique_name) + return reduce_scatter + + def replacement(mul: torch.Tensor, mm_weight: torch.Tensor): + gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter( + mul, + mm_weight, + "avg", + scatter_dim=0, + group_name=self.tp.device_group.group_name, + ) + + return gemm_rs + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class AllGatherGEMMPattern(BasePattern): + + def get_inputs(self): + x = torch.empty([4, 4], device=self.device, dtype=self.dtype) + weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + return [x, weight] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + x: torch.Tensor, + weight: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + all_gather = torch.ops.vllm.all_gather.default( + x, + dim=0, + world_size=self.tp_size, + group_name=self.tp.unique_name) + + return torch.ops.aten.mm.default(all_gather, weight) + + def replacement( + x: torch.Tensor, + weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul( + x, + [weight], + gather_dim=0, + group_name=self.tp.device_group.group_name, + ) + return mm_outputs + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class AsyncTPPass(VllmInductorPass): + + def __init__(self, config: VllmConfig): + super().__init__(config) + + # Enable symmetric memory for the TP process group + enable_symm_mem_for_group(get_tp_group().device_group.group_name) + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="async_tp_pass") + GEMMReduceScatterPattern(self.model_dtype, + 
self.device).register(self.patterns) + + AllGatherGEMMPattern(self.model_dtype, + self.device).register(self.patterns) + + def is_applicable_for_shape(self, shape: Optional[int]) -> bool: + # only do replace for specific shapes + tp_size = get_tensor_model_parallel_world_size() + return shape is not None and shape % tp_size == 0 + + def __call__(self, graph: fx.Graph): + self.begin() + self.dump_graph(graph, "before_async_tp_pass") + count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", count) + self.dump_graph(graph, "after_async_tp_pass") + self.end_and_log() diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index b7e7a79bef0b..21af5eb76ee8 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -4,7 +4,7 @@ import hashlib import os from contextlib import ExitStack -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Optional from unittest.mock import patch import torch @@ -39,7 +39,8 @@ def compute_hash(self, vllm_config: VllmConfig) -> str: Gather all the relevant information from the vLLM config, to compute a hash so that we can cache the compiled model. - See {meth}`VllmConfig.compute_hash` to check what information + See [`VllmConfig.compute_hash`][vllm.config.VllmConfig.compute_hash] + to check what information is already considered by default. This function should only consider the information that is specific to the compiler. 
""" @@ -48,10 +49,11 @@ def compute_hash(self, vllm_config: VllmConfig) -> str: def compile( self, graph: fx.GraphModule, - example_inputs: List[Any], - compiler_config: Dict[str, Any], - runtime_shape: Optional[int] = None - ) -> Tuple[Optional[Callable], Optional[Any]]: + example_inputs: list[Any], + compiler_config: dict[str, Any], + runtime_shape: Optional[int] = None, + key: Optional[str] = None, + ) -> tuple[Optional[Callable], Optional[Any]]: """ Compile the graph with the given example inputs and compiler config, with a runtime shape. If the `runtime_shape` is None, it means @@ -71,13 +73,17 @@ def compile( If the compiler doesn't support caching, it should return None for the handle. If the compiler fails to compile the graph, it should return None for the compiled function as well. + + `key` is required for StandaloneInductorAdapter, it specifies where to + save the compiled artifact. The compiled artifact gets saved to + `cache_dir/key`. """ return None, None def load(self, handle: Any, graph: fx.GraphModule, - example_inputs: List[Any], + example_inputs: list[Any], graph_index: int, runtime_shape: Optional[int] = None) -> Callable: """ @@ -115,7 +121,7 @@ class AlwaysHitShapeEnv: """ def __init__(self) -> None: - self.guards: List[Any] = [] + self.guards: list[Any] = [] def evaluate_guards_expression(self, *args, **kwargs): return True @@ -127,23 +133,108 @@ def produce_guards_expression(self, *args, **kwargs): return "" +def get_inductor_factors() -> list[Any]: + factors: list[Any] = [] + # summarize system state + from torch._inductor.codecache import CacheBase + system_factors = CacheBase.get_system() + factors.append(system_factors) + + # summarize pytorch state + from torch._inductor.codecache import torch_key + torch_factors = torch_key() + factors.append(torch_factors) + return factors + + +class InductorStandaloneAdaptor(CompilerInterface): + """ + The adaptor for the Inductor compiler. + Requires PyTorch 2.8+. 
+ This is not on by default yet, but we plan to turn it on by default for + PyTorch 2.8. + + Use VLLM_TEST_STANDALONE_COMPILE to toggle this on or off. + """ + name = "inductor_standalone" + + def compute_hash(self, vllm_config: VllmConfig) -> str: + factors = get_inductor_factors() + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest()[:10] + return hash_str + + def initialize_cache(self, cache_dir: str, disable_cache: bool = False): + self.cache_dir = cache_dir + + def compile( + self, + graph: fx.GraphModule, + example_inputs: list[Any], + compiler_config: dict[str, Any], + runtime_shape: Optional[int] = None, + key: Optional[str] = None, + ) -> tuple[Optional[Callable], Optional[Any]]: + current_config = {} + if compiler_config is not None: + current_config.update(compiler_config) + set_inductor_config(current_config, runtime_shape) + + if isinstance(runtime_shape, int): + dynamic_shapes = "from_example_inputs" + else: + dynamic_shapes = "from_tracing_context" + + from torch._inductor import standalone_compile + with pass_context(runtime_shape): + compiled_graph = standalone_compile( + graph, + example_inputs, + dynamic_shapes=dynamic_shapes, + options={"config_patches": current_config}) + + # Save the compiled artifact to disk in the specified path + assert key is not None + path = os.path.join(self.cache_dir, key) + compiled_graph.save(path=path, format="unpacked") + return compiled_graph, (key, path) + + def load(self, + handle: Any, + graph: fx.GraphModule, + example_inputs: list[Any], + graph_index: int, + runtime_shape: Optional[int] = None) -> Callable: + assert isinstance(handle, tuple) + assert isinstance(handle[0], str) + assert isinstance(handle[1], str) + path = handle[1] + inductor_compiled_graph = torch._inductor.CompiledArtifact.load( + path=path, format="unpacked") + from torch._inductor.compile_fx import graph_returns_tuple + returns_tuple = graph_returns_tuple(graph) + + def compiled_graph_wrapper(*args): + 
graph_output = inductor_compiled_graph(*args) + # unpack the tuple if needed + # TODO(rzou): the implication is that we're not + # reading the python bytecode correctly in vLLM? + if returns_tuple: + return graph_output + else: + return graph_output[0] + + return compiled_graph_wrapper + + class InductorAdaptor(CompilerInterface): """ - The adaptor for the Inductor compiler, version 2.5 and 2.6. + The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7. """ name = "inductor" def compute_hash(self, vllm_config: VllmConfig) -> str: - factors: List[Any] = [] - # summarize system state - from torch._inductor.codecache import CacheBase - system_factors = CacheBase.get_system() - factors.append(system_factors) - - # summarize pytorch state - from torch._inductor.codecache import torch_key - torch_factors = torch_key() - factors.append(torch_factors) + factors = get_inductor_factors() hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()[:10] return hash_str @@ -166,25 +257,21 @@ def initialize_cache(self, cache_dir: str, disable_cache: bool = False): def compile( self, graph: fx.GraphModule, - example_inputs: List[Any], - compiler_config: Dict[str, Any], - runtime_shape: Optional[int] = None - ) -> Tuple[Optional[Callable], Optional[Any]]: - current_config = {} + example_inputs: list[Any], + compiler_config: dict[str, Any], + runtime_shape: Optional[int] = None, + key: Optional[str] = None, + ) -> tuple[Optional[Callable], Optional[Any]]: from torch._inductor.compile_fx import compile_fx + current_config = {} + if compiler_config is not None: + current_config.update(compiler_config) # disable remote cache current_config["fx_graph_cache"] = True current_config["fx_graph_remote_cache"] = False - if compiler_config is not None: - current_config.update(compiler_config) - - if isinstance(runtime_shape, int): - # for a specific batchsize, tuning triton kernel parameters - # can be beneficial - current_config["max_autotune"] = True - 
current_config["coordinate_descent_tuning"] = True + set_inductor_config(current_config, runtime_shape) # inductor can inplace modify the graph, so we need to copy it # see https://github.com/pytorch/pytorch/issues/138980 @@ -334,7 +421,7 @@ def _get_shape_env() -> AlwaysHitShapeEnv: def load(self, handle: Any, graph: fx.GraphModule, - example_inputs: List[Any], + example_inputs: list[Any], graph_index: int, runtime_shape: Optional[int] = None) -> Callable: assert isinstance(handle, tuple) @@ -422,16 +509,25 @@ def metrics_context(self) -> contextlib.AbstractContextManager: return contextlib.nullcontext() +def set_inductor_config(config, runtime_shape): + if isinstance(runtime_shape, int): + # for a specific batchsize, tuning triton kernel parameters + # can be beneficial + config["max_autotune"] = True + config["coordinate_descent_tuning"] = True + + class EagerAdaptor(CompilerInterface): name = "eager" def compile( self, graph: fx.GraphModule, - example_inputs: List[Any], - compiler_config: Dict[str, Any], - runtime_shape: Optional[int] = None - ) -> Tuple[Optional[Callable], Optional[Any]]: + example_inputs: list[Any], + compiler_config: dict[str, Any], + runtime_shape: Optional[int] = None, + key: Optional[str] = None, + ) -> tuple[Optional[Callable], Optional[Any]]: # we don't need to compile the graph, just return the graph itself. # It does not support caching, return None for the handle. 
return graph, None diff --git a/vllm/compilation/cuda_piecewise_backend.py b/vllm/compilation/cuda_piecewise_backend.py new file mode 100644 index 000000000000..0ad480e28cd7 --- /dev/null +++ b/vllm/compilation/cuda_piecewise_backend.py @@ -0,0 +1,213 @@ +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +from contextlib import ExitStack +from typing import Any, Callable, Optional +from unittest.mock import patch + +import torch +import torch.fx as fx + +import vllm.envs as envs +from vllm.compilation.backends import VllmBackend +from vllm.compilation.counter import compilation_counter +from vllm.compilation.monitor import end_monitoring_torch_compile +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.utils import weak_ref_tensors + +logger = init_logger(__name__) + + +@dataclasses.dataclass +class ConcreteSizeEntry: + runtime_shape: int + need_to_compile: bool # the size is in compile_sizes + use_cudagraph: bool # the size is in cudagraph_capture_sizes + + compiled: bool = False + runnable: Callable = None # type: ignore + num_finished_warmup: int = 0 + cudagraph: Optional[torch.cuda.CUDAGraph] = None + output: Optional[Any] = None + + # for cudagraph debugging, track the input addresses + # during capture, and check if they are the same during replay + input_addresses: Optional[list[int]] = None + + +class CUDAPiecewiseBackend: + + def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, + graph_pool: Any, piecewise_compile_index: int, + total_piecewise_compiles: int, sym_shape_indices: list[int], + compiled_graph_for_general_shape: Callable, + vllm_backend: VllmBackend): + """ + The backend for piecewise compilation. + It mainly handles the compilation and cudagraph capturing. + + We will compile `self.graph` once for the general shape, + and then compile for different shapes specified in + `compilation_config.compile_sizes`. + + Independently, we will capture cudagraph for different shapes. 
+ + If a shape needs both compilation and cudagraph, we will + compile it first, and then capture cudagraph. + """ + self.graph = graph + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config + self.graph_pool = graph_pool + self.piecewise_compile_index = piecewise_compile_index + self.total_piecewise_compiles = total_piecewise_compiles + self.vllm_backend = vllm_backend + + self.is_first_graph = piecewise_compile_index == 0 + self.is_last_graph = ( + piecewise_compile_index == total_piecewise_compiles - 1) + + self.compile_sizes: set[int] = set( + self.compilation_config.compile_sizes) + self.cudagraph_capture_sizes: set[int] = set( + self.compilation_config.cudagraph_capture_sizes + ) if self.compilation_config.use_cudagraph else set() + + self.first_run_finished = False + + self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa + + self.sym_shape_indices = sym_shape_indices + + self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" + + # the entries for different shapes that we need to either + # compile or capture cudagraph + self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {} + + # to_be_compiled_sizes tracks the remaining sizes to compile, + # and updates during the compilation process, so we need to copy it + self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy() + for shape in self.compile_sizes.union(self.cudagraph_capture_sizes): + self.concrete_size_entries[shape] = ConcreteSizeEntry( + runtime_shape=shape, + need_to_compile=shape in self.compile_sizes, + use_cudagraph=shape in self.cudagraph_capture_sizes, + ) + + def check_for_ending_compilation(self): + if self.is_last_graph and not self.to_be_compiled_sizes: + # no specific sizes to compile + # save the hash of the inductor graph for the next run + self.vllm_backend.compiler_manager.save_to_file() + end_monitoring_torch_compile(self.vllm_config) + + def __call__(self, *args) -> Any: + if not self.first_run_finished: 
+ self.first_run_finished = True + self.check_for_ending_compilation() + return self.compiled_graph_for_general_shape(*args) + + runtime_shape = args[self.sym_shape_indices[0]] + if runtime_shape not in self.concrete_size_entries: + # we don't need to do anything for this shape + return self.compiled_graph_for_general_shape(*args) + + entry = self.concrete_size_entries[runtime_shape] + + if entry.runnable is None: + entry.runnable = self.compiled_graph_for_general_shape + + if entry.need_to_compile and not entry.compiled: + entry.compiled = True + self.to_be_compiled_sizes.remove(runtime_shape) + # args are real arguments + entry.runnable = self.vllm_backend.compiler_manager.compile( + self.graph, + args, + self.compilation_config.inductor_compile_config, + self.compilation_config, + graph_index=self.piecewise_compile_index, + num_graphs=self.total_piecewise_compiles, + runtime_shape=runtime_shape) + + # finished compilations for all required shapes + if self.is_last_graph and not self.to_be_compiled_sizes: + self.check_for_ending_compilation() + + if not entry.use_cudagraph: + return entry.runnable(*args) + + if entry.cudagraph is None: + if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa + entry.num_finished_warmup += 1 + if self.is_first_graph: + logger.debug( + "Warming up %s/%s for shape %s", + entry.num_finished_warmup, + self.compilation_config.cudagraph_num_of_warmups, + runtime_shape) + return entry.runnable(*args) + + if self.is_first_graph: + # Since we capture cudagraph for many different shapes and + # capturing is fast, we don't need to log it for every shape. + # We only log it in the debug mode. 
+ logger.debug("Capturing a cudagraph for shape %s", + runtime_shape) + + input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + entry.input_addresses = input_addresses + cudagraph = torch.cuda.CUDAGraph() + + with ExitStack() as stack: + if not self.is_first_graph: + # during every model forward, we will capture + # many pieces of cudagraphs (roughly one per layer). + # running gc again and again across layers will + # make the cudagraph capture very slow. + # therefore, we only run gc for the first graph, + # and disable gc for the rest of the graphs. + stack.enter_context(patch("gc.collect", lambda: None)) + stack.enter_context( + patch("torch.cuda.empty_cache", lambda: None)) + + # mind-exploding: carefully manage the reference and memory. + with torch.cuda.graph(cudagraph, pool=self.graph_pool): + # `output` is managed by pytorch's cudagraph pool + output = entry.runnable(*args) + if self.is_last_graph: + # by converting it to weak ref, + # the original `output` will immediately be released + # to save memory. It is only safe to do this for + # the last graph, because the output of the last graph + # will not be used by any other cuda graph. + output = weak_ref_tensors(output) + + # here we always use weak ref for the output + # to save memory + entry.output = weak_ref_tensors(output) + entry.cudagraph = cudagraph + + compilation_counter.num_cudagraph_caputured += 1 + + # important: we need to return the output, rather than + # the weak ref of the output, so that pytorch can correctly + # manage the memory during cuda graph capture + return output + + if self.is_debugging_mode: + # check if the input addresses are the same + new_input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + assert new_input_addresses == entry.input_addresses, ( + "Input addresses for cudagraphs are different during replay." 
+ f" Expected {entry.input_addresses}, got {new_input_addresses}" + ) + + entry.cudagraph.replay() + return entry.output diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 20afe6967df3..f02994c55527 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import inspect -from typing import Callable, Dict, List, Optional, TypeVar, Union, overload +from typing import Callable, Optional, TypeVar, Union, overload from unittest.mock import patch import torch @@ -25,7 +25,7 @@ @overload def support_torch_compile( *, - dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]], + dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]], ) -> Callable[[_T], _T]: ... @@ -38,7 +38,7 @@ def support_torch_compile(cls: _T) -> _T: def support_torch_compile( cls: Optional[_T] = None, *, - dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None, + dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None, ) -> Union[Callable[[_T], _T], _T]: """ A decorator to add support for compiling the forward method of a class. @@ -131,7 +131,7 @@ def cls_decorator_helper(cls: _T) -> _T: def _support_torch_compile( cls: _T, - dynamic_arg_dims: Dict[str, Union[int, List[int]]], + dynamic_arg_dims: dict[str, Union[int, list[int]]], ) -> _T: """ A decorator to add support for compiling the forward method of a class. 
diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index 7f3120660329..70f3b8b6df94 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 import operator -from typing import Dict, Iterable, List, Optional, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch from torch._higher_order_ops.auto_functionalize import auto_functionalized @@ -27,7 +28,7 @@ def __call__(self, graph: torch.fx.Graph): self.begin() self.dump_graph(graph, "before_fix_functionalization") - self.nodes_to_remove: List[torch.fx.Node] = [] + self.nodes_to_remove: list[torch.fx.Node] = [] count = 0 for node in graph.nodes: if not is_func(node, auto_functionalized): @@ -117,8 +118,8 @@ def _remove(self, node_or_nodes: Union[torch.fx.Node, def defunctionalize(self, graph: torch.fx.Graph, node: torch.fx.Node, - mutated_args: Dict[int, Union[torch.fx.Node, str]], - args: Optional[Tuple[Union[torch.fx.Node, str], + mutated_args: dict[int, Union[torch.fx.Node, str]], + args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None): """ De-functionalize a node by replacing it with a call to the original. @@ -130,7 +131,7 @@ def defunctionalize(self, self._remove(node) def replace_users_with_mutated_args(self, node: torch.fx.Node, - mutated_args: Dict[int, + mutated_args: dict[int, Union[torch.fx.Node, str]]): """ @@ -146,7 +147,7 @@ def replace_users_with_mutated_args(self, node: torch.fx.Node, user.replace_all_uses_with(arg) self._remove(user) - def getitem_users(self, node: torch.fx.Node) -> Dict[int, torch.fx.Node]: + def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]: """ Returns the operator.getitem users of the auto-functionalized node, indexed by the index they are getting. 
@@ -161,7 +162,7 @@ def getitem_users(self, node: torch.fx.Node) -> Dict[int, torch.fx.Node]: def insert_defunctionalized(self, graph: torch.fx.Graph, node: torch.fx.Node, - args: Optional[Tuple[Union[torch.fx.Node, str], + args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None): """ Insert a new defunctionalized node into the graph before node. diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 8f32fdb03f8b..618b2fe94d3a 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, Dict, List, NamedTuple, Optional, Tuple +from typing import Callable, NamedTuple, Optional import torch import torch._inductor.pattern_matcher as pm @@ -57,7 +57,7 @@ def __str__(self): kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, True, True) kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, False, True) -QUANT_OPS: Dict[QuantKey, OpOverload] = { +QUANT_OPS: dict[QuantKey, OpOverload] = { kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa @@ -80,7 +80,7 @@ def __str__(self): f"{'' if self.fused_add else 'out'} residual)") -FUSED_OPS: Dict[FusedRMSQuantKey, OpOverload] = { +FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = { FusedRMSQuantKey(kFp8StaticTensorSym, False): torch.ops._C.rms_norm_static_fp8_quant.default, # noqa FusedRMSQuantKey(kFp8StaticTensorSym, True): @@ -101,7 +101,7 @@ def __init__(self, match: pm.Match, quant_op, fused_op): self.QUANT_OP = quant_op # in-place quant op self.FUSED_OP = fused_op # in-place fused quant op - def insert_fused_node(self, fused_return_mapping: Dict[int, Tuple[fx.Node, + def insert_fused_node(self, fused_return_mapping: dict[int, tuple[fx.Node, int]], **kwargs): """ @@ -548,7 +548,7 @@ def __init__(self, config: VllmConfig): "FusionPass singleton instance already exists" super().__init__(config) - self.matches: 
List[MultiOutputMatch] = [] + self.matches: list[MultiOutputMatch] = [] self.patterns: PatternMatcherPass = PatternMatcherPass( pass_name="fusion_pass") diff --git a/vllm/compilation/fx_utils.py b/vllm/compilation/fx_utils.py index f9427e48ac31..b9eeb0c8d2af 100644 --- a/vllm/compilation/fx_utils.py +++ b/vllm/compilation/fx_utils.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 import operator -from typing import Iterable, Optional +from collections.abc import Iterable +from typing import Optional from torch import fx from torch._higher_order_ops.auto_functionalize import auto_functionalized diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 6cd7720fca2f..a9359fe1e117 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -5,7 +5,7 @@ import json import types from contextlib import contextmanager -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Optional, Union import torch from torch import fx @@ -16,7 +16,7 @@ from torch._inductor.custom_graph_pass import CustomGraphPass else: # CustomGraphPass is not present in 2.5 or lower, import our version - from .torch25_custom_graph_pass import ( # noqa: yapf + from .torch25_custom_graph_pass import ( # noqa: E501 Torch25CustomGraphPass as CustomGraphPass) _pass_context = None @@ -83,7 +83,7 @@ def hash_source(*srcs: Union[str, Any]): return hasher.hexdigest() @staticmethod - def hash_dict(dict_: Dict[Any, Any]): + def hash_dict(dict_: dict[Any, Any]): """ Utility method to hash a dictionary, can alternatively be used for uuid. :return: A sha256 hash of the json rep of the dictionary. 
diff --git a/vllm/compilation/multi_output_match.py b/vllm/compilation/multi_output_match.py index e6f6a60b2595..cef19f9257ed 100644 --- a/vllm/compilation/multi_output_match.py +++ b/vllm/compilation/multi_output_match.py @@ -3,7 +3,7 @@ import abc import operator from abc import abstractmethod -from typing import Iterable, List, Tuple +from collections.abc import Iterable from torch import fx from torch._higher_order_ops.auto_functionalize import auto_functionalized @@ -56,7 +56,7 @@ def process(self): raise NotImplementedError @property - def nodes(self) -> List[fx.Node]: + def nodes(self) -> list[fx.Node]: return self.match.nodes @property @@ -87,7 +87,7 @@ def inserting_after_match(self): return self.graph.inserting_after(last_node_in_match) def insert_getitems(self, tuple_node: fx.Node, - indices: Iterable[int]) -> Tuple[fx.Node, ...]: + indices: Iterable[int]) -> tuple[fx.Node, ...]: """ Insert operator.getitem nodes to extract elements from a tuple node. diff --git a/vllm/compilation/noop_elimination.py b/vllm/compilation/noop_elimination.py index 19127e933ec4..13e4cd73f8ce 100644 --- a/vllm/compilation/noop_elimination.py +++ b/vllm/compilation/noop_elimination.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Iterable, Union +from collections.abc import Iterable +from typing import Union import torch.fx from torch import SymInt diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index b1646914c7ed..07ebd3e1b7dd 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -1,13 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List - from torch import fx as fx from vllm.config import VllmConfig from vllm.logger import init_logger from .activation_quant_fusion import ActivationQuantFusionPass +from .collective_fusion import AsyncTPPass from .fix_functionalization import FixFunctionalizationPass from .fusion import FusionPass from .inductor_pass import 
CustomGraphPass, InductorPass, get_pass_context @@ -34,7 +33,7 @@ class PostGradPassManager(CustomGraphPass): """ def __init__(self): - self.passes: List[VllmInductorPass] = [] + self.passes: list[VllmInductorPass] = [] def __call__(self, graph: fx.Graph): shape = get_pass_context().runtime_shape @@ -56,6 +55,8 @@ def configure(self, config: VllmConfig): if self.pass_config.enable_sequence_parallelism: self.passes += [SequenceParallelismPass(config)] + if self.pass_config.enable_async_tp: + self.passes += [AsyncTPPass(config)] self.fix_functionalization = FixFunctionalizationPass(config) diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index 95db63d34f7e..17dded87fe8d 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Tuple +from typing import Optional import torch import torch._inductor.pattern_matcher as pm @@ -125,7 +125,7 @@ def pattern( residual: torch.Tensor, mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: all_reduce = tensor_model_parallel_all_reduce(mm_1) rmsnorm = torch.ops.higher_order.auto_functionalized( @@ -142,7 +142,7 @@ def replacement( residual: torch.Tensor, mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: tp = get_tp_group() tp_size = get_tensor_model_parallel_world_size() reduce_scatter = torch.ops.vllm.reduce_scatter.default( @@ -190,7 +190,7 @@ def pattern( residual: torch.Tensor, mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: all_reduce = tensor_model_parallel_all_reduce(mm_1) rmsnorm = torch.ops.higher_order.auto_functionalized( @@ -207,7 +207,7 @@ def replacement( residual: torch.Tensor, mm_1: 
torch.Tensor, rms_norm_weights: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: tp = get_tp_group() tp_size = get_tensor_model_parallel_world_size() reduce_scatter = torch.ops.vllm.reduce_scatter.default( @@ -243,24 +243,25 @@ def __init__(self, config: VllmConfig): pass_name="sequence_parallelism_pass") for epsilon in [1e-5, 1e-6]: EmbeddingAllReduceRMSNormPattern( - epsilon, self.dtype, self.device).register(self.patterns) + epsilon, self.model_dtype, self.device).register(self.patterns) - MiddleAllReduceRMSNormPattern(epsilon, self.dtype, + MiddleAllReduceRMSNormPattern(epsilon, self.model_dtype, self.device).register(self.patterns) - LastAllReduceRMSNormPattern(epsilon, self.dtype, + LastAllReduceRMSNormPattern(epsilon, self.model_dtype, self.device).register(self.patterns) # WARNING: This is a hack to clear the pattern matcher cache # and allow multiple values of epsilon. torch._inductor.pattern_matcher._seen_patterns.clear() def is_applicable_for_shape(self, shape: Optional[int]) -> bool: - # only do replace for specific shapes tp_size = get_tensor_model_parallel_world_size() return shape is not None and shape % tp_size == 0 def __call__(self, graph: fx.Graph): + self.begin() self.dump_graph(graph, "before_sequence_parallelism_pass") count = self.patterns.apply(graph) logger.debug("Replaced %s patterns", count) self.dump_graph(graph, "after_sequence_parallelism_pass") + self.end_and_log() diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py index e8bffb406f14..0fe73b72b1de 100644 --- a/vllm/compilation/vllm_inductor_pass.py +++ b/vllm/compilation/vllm_inductor_pass.py @@ -4,7 +4,7 @@ import torch -from vllm.config import CompilationConfig, VllmConfig +from vllm.config import PassConfig, VllmConfig # yapf: disable from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank from vllm.distributed import ( @@ -26,7 +26,8 @@ class 
VllmInductorPass(InductorPass): def __init__(self, config: VllmConfig): self.pass_config = config.compilation_config.pass_config - self.dtype = config.model_config.dtype if config.model_config else None + self.model_dtype = config.model_config.dtype if config.model_config \ + else None self.device = config.device_config.device if config.device_config \ else None self.pass_name = self.__class__.__name__ @@ -56,10 +57,7 @@ def end_and_log(self): class PrinterInductorPass(VllmInductorPass): - def __init__(self, - name: str, - config: CompilationConfig.PassConfig, - always=False): + def __init__(self, name: str, config: PassConfig, always=False): super().__init__(config) self.name = name self.always = always diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index a8a283ddd8c0..1a8211f0ab7c 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -5,7 +5,7 @@ from abc import abstractmethod from contextlib import contextmanager from types import CodeType -from typing import Callable, List, Optional +from typing import Callable, Optional import torch @@ -48,7 +48,7 @@ def __init__(self, self.compiled_callable = compiled_callable self.original_code_object = self.__class__.forward.__code__ - self.compiled_codes: List[CodeType] = [] + self.compiled_codes: list[CodeType] = [] torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook) # read the env var to determine whether to use the custom dispatcher diff --git a/vllm/config.py b/vllm/config.py index fca2865f85d5..db35c848b33a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -6,22 +6,21 @@ import hashlib import inspect import json -import re -import sys import textwrap +import uuid import warnings from collections import Counter from contextlib import contextmanager -from dataclasses import (MISSING, dataclass, field, fields, is_dataclass, - replace) +from dataclasses import (MISSING, Field, asdict, dataclass, field, fields, + is_dataclass, replace) from functools 
import cached_property from importlib.util import find_spec from pathlib import Path from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional, Protocol, TypeVar, Union, cast, get_args, get_origin) +import regex as re import torch -from pydantic import BaseModel, Field, PrivateAttr from torch.distributed import ProcessGroup, ReduceOp from transformers import PretrainedConfig from typing_extensions import deprecated @@ -34,7 +33,7 @@ QuantizationMethods, get_quantization_config) from vllm.model_executor.models import ModelRegistry -from vllm.platforms import CpuArchEnum, current_platform +from vllm.platforms import current_platform from vllm.tracing import is_otel_available, otel_import_error_traceback from vllm.transformers_utils.config import ( ConfigFormat, get_config, get_hf_image_processor_config, @@ -43,7 +42,10 @@ try_get_generation_config, uses_mrope) from vllm.transformers_utils.s3_utils import S3Model from vllm.transformers_utils.utils import is_s3, maybe_model_redirect -from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless, +from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS, + MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, + POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, GiB_bytes, + LayerBlockType, cuda_device_count_stateless, get_cpu_memory, get_open_port, is_torch_equal_or_newer, random_uuid, resolve_obj_by_qualname) @@ -58,19 +60,13 @@ ConfigType = type[DataclassInstance] else: - QuantizationConfig = None + QuantizationConfig = Any ConfigType = type logger = init_logger(__name__) ConfigT = TypeVar("ConfigT", bound=ConfigType) -# This value is chosen to have a balance between ITL and TTFT. Note it is -# not optimized for throughput. 
-_DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048 -_POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 -_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 - TaskOption = Literal["auto", "generate", "embedding", "embed", "classify", "score", "reward", "transcription"] @@ -170,6 +166,12 @@ def config(cls: ConfigT) -> ConfigT: """ A decorator that ensures all fields in a dataclass have default values and that each field has a docstring. + + If a `ConfigT` is used as a CLI argument itself, the default value provided + by `get_kwargs` will be the result parsing a JSON string as the kwargs + (i.e. `ConfigT(**json.loads(cli_arg))`). However, if a particular `ConfigT` + requires custom construction from CLI (i.e. `CompilationConfig`), it can + have a `from_cli` method, which will be called instead. """ if not is_dataclass(cls): raise TypeError("The decorated class must be a dataclass.") @@ -203,7 +205,7 @@ def get_field(cls: ConfigType, name: str) -> Field: cls_fields = {f.name: f for f in fields(cls)} if name not in cls_fields: raise ValueError(f"Field '{name}' not found in {cls.__name__}.") - named_field: Field = cls_fields.get(name) + named_field: Field = cls_fields[name] if (default_factory := named_field.default_factory) is not MISSING: return field(default_factory=default_factory) if (default := named_field.default) is not MISSING: @@ -212,6 +214,10 @@ def get_field(cls: ConfigType, name: str) -> Field: f"{cls.__name__}.{name} must have a default value or default factory.") +def is_init_field(cls: ConfigType, name: str) -> bool: + return next(f for f in fields(cls) if f.name == name).init + + TokenizerMode = Literal["auto", "slow", "mistral", "custom"] ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"] @@ -252,7 +258,8 @@ class ModelConfig: - "float" is shorthand for FP32 precision.\n - "float32" for FP32 precision.""" seed: Optional[int] = None - """Random seed for reproducibility.""" + """Random seed for reproducibility. 
Initialized to None in V0, but + initialized to 0 in V1.""" hf_config_path: Optional[str] = None """Name or path of the Hugging Face config to use. If unspecified, model name or path will be used.""" @@ -287,7 +294,7 @@ class ModelConfig: - 1K -> 1024\n - 25.6k -> 25,600""" spec_target_max_model_len: Optional[int] = None - """Specify the the maximum length for spec decoding draft models.""" + """Specify the maximum length for spec decoding draft models.""" quantization: Optional[QuantizationMethods] = None """Method used to quantize the weights. If `None`, we first check the `quantization_config` attribute in the model config file. If that is @@ -432,6 +439,24 @@ def compute_hash(self) -> str: return hashlib.sha256(str(factors).encode()).hexdigest() def __post_init__(self) -> None: + # Set the default seed to 0 in V1. + # NOTE(woosuk): In V0, we set the default seed to None because the + # driver worker shares the same process as the user process, and thus + # setting a seed affects the user process as well. + # In V1, we use separate processes for workers (unless + # VLLM_ENABLE_V1_MULTIPROCESSING=0), so setting a seed here + # doesn't affect the user process. However, without a consistent seed, + # different tensor parallel workers would sample different tokens, + # leading to inconsistent results. + if envs.VLLM_USE_V1 and self.seed is None: + self.seed = 0 + if not envs.VLLM_ENABLE_V1_MULTIPROCESSING: + logger.warning( + "The global random seed is set to %d. Since " + "VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may " + "affect the random state of the Python process that " + "launched vLLM.", self.seed) + self.model = maybe_model_redirect(self.model) # The tokenizer is consistent with the model by default. 
if self.tokenizer is None: @@ -508,13 +533,19 @@ def __post_init__(self) -> None: self.model, hf_token=self.hf_token, revision=self.revision) self.dtype = _get_and_verify_dtype(self.hf_config, self.dtype) - interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"] + # Workaround for Gemma 2 which uses interleaved sliding window + # attention, but it's not specified in its config. TODO: remove this + # when Gemma 2 is fixed in Transformers. + if self.hf_text_config.model_type == "gemma2": + self.hf_text_config.sliding_window_pattern = 2 + sliding_window = getattr(self.hf_text_config, "sliding_window", None) - has_interleaved_attention = (sliding_window is not None) and ( - isinstance(sliding_window, list) or - (self.hf_text_config.model_type in interleaved_attn_models)) + sliding_window_pattern = getattr(self.hf_text_config, + "sliding_window_pattern", None) + has_interleaved_attention = sliding_window_pattern is not None or ( + isinstance(sliding_window, list)) - if (not self.disable_sliding_window and has_interleaved_attention): + if not self.disable_sliding_window and has_interleaved_attention: if (backend := envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER"): sliding_window_len_min = get_min_sliding_window( @@ -534,7 +565,10 @@ def __post_init__(self) -> None: # only the attention layer itself is aware of the sliding # window, and use the window size to compute the attention. self.hf_text_config.interleaved_sliding_window = sliding_window - delattr(self.hf_text_config, "sliding_window") + + if hasattr(self.hf_text_config, "sliding_window"): + delattr(self.hf_text_config, "sliding_window") + sliding_window = None self.max_model_len = _get_and_verify_max_len( @@ -583,28 +617,35 @@ def architectures(self) -> list[str]: def maybe_pull_model_tokenizer_for_s3(self, model: str, tokenizer: str) -> None: + """Pull model/tokenizer from S3 to temporary directory when needed. 
+ + Args: + model: Model name or path + tokenizer: Tokenizer name or path """ - Pull the model config or tokenizer to a temporary - directory in case of S3. + if not (is_s3(model) or is_s3(tokenizer)): + return - Args: - model: The model name or path. - tokenizer: The tokenizer name or path. + if is_s3(model): + s3_model = S3Model() + s3_model.pull_files(model, + allow_pattern=["*.model", "*.py", "*.json"]) + self.model_weights = model + self.model = s3_model.dir - """ - if is_s3(model) or is_s3(tokenizer): - if is_s3(model): - s3_model = S3Model() + # If tokenizer is same as model, download to same directory + if model == tokenizer: s3_model.pull_files( - model, allow_pattern=["*.model", "*.py", "*.json"]) - self.model_weights = self.model - self.model = s3_model.dir - - if is_s3(tokenizer): - s3_tokenizer = S3Model() - s3_tokenizer.pull_files( model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) - self.tokenizer = s3_tokenizer.dir + self.tokenizer = s3_model.dir + return + + # Only download tokenizer if needed and not already handled + if is_s3(tokenizer): + s3_tokenizer = S3Model() + s3_tokenizer.pull_files( + model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) + self.tokenizer = s3_tokenizer.dir def _init_multimodal_config(self) -> Optional["MultiModalConfig"]: if self.registry.is_multimodal_model(self.architectures): @@ -789,7 +830,7 @@ def _verify_quantization(self) -> None: optimized_quantization_methods = [ "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", "awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8", - "quark", "nvfp4", "bitblas", "gptq_bitblas" + "quark", "modelopt_fp4", "bitblas", "gptq_bitblas" ] if self.quantization is not None: self.quantization = cast(QuantizationMethods, @@ -871,12 +912,17 @@ def _verify_quantization(self) -> None: def _verify_cuda_graph(self) -> None: self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, self.max_model_len) + # CUDAGraph capture not supported for enc-dec models 
and mllama on ROCm ROCM_UNSUPPORTED_MODELS = ['mllama'] - if (self.hf_config.model_type in ROCM_UNSUPPORTED_MODELS - and not self.enforce_eager and current_platform.is_rocm()): + unsupported_rocm = (self.hf_config.model_type + in ROCM_UNSUPPORTED_MODELS + or self.is_encoder_decoder) + + if (unsupported_rocm and not self.enforce_eager + and current_platform.is_rocm()): logger.warning( "CUDA graph is not supported for %s on ROCm yet, fallback " - "to the eager mode.", self.hf_config.model_type) + "to eager mode.", self.hf_config.model_type) self.enforce_eager = True def _verify_bnb_config(self) -> None: @@ -920,6 +966,23 @@ def _verify_with_expert_parallelism(self) -> None: "Number of experts in the model must be greater than 0 " "when expert parallelism is enabled.") + def verify_dual_chunk_attention_config( + self, + load_config: "LoadConfig", + ) -> None: + if hasattr(self.hf_config, "dual_chunk_attention_config"): + # Try loading the sparse attention config + from vllm.model_executor.model_loader.weight_utils import ( + get_sparse_attention_config) + sparse_attn_config = get_sparse_attention_config(self, load_config) + if sparse_attn_config: + self.hf_config.dual_chunk_attention_config[ + "sparse_attention_config"] = sparse_attn_config + if "sparse_attention_enabled" not in \ + self.hf_config.dual_chunk_attention_config: + self.hf_config.dual_chunk_attention_config[ + "sparse_attention_enabled"] = True + def verify_async_output_proc(self, parallel_config, speculative_config, device_config) -> None: if not self.use_async_output_proc: @@ -930,7 +993,7 @@ def verify_async_output_proc(self, parallel_config, speculative_config, self.use_async_output_proc = False return - # Reminder: Please update docs/source/features/compatibility_matrix.md + # Reminder: Please update docs/features/compatibility_matrix.md # If the feature combo become valid from vllm.platforms import current_platform if not current_platform.is_async_output_supported(self.enforce_eager): @@ -946,7 
+1009,7 @@ def verify_async_output_proc(self, parallel_config, speculative_config, if self.runner_type == "pooling": self.use_async_output_proc = False - # Reminder: Please update docs/source/features/compatibility_matrix.md + # Reminder: Please update docs/features/compatibility_matrix.md # If the feature combo become valid if speculative_config: self.use_async_output_proc = False @@ -1130,7 +1193,8 @@ def get_num_attention_heads(self, def get_layers_start_end_indices( self, parallel_config: "ParallelConfig") -> tuple[int, int]: from vllm.distributed.utils import get_pp_indices - if self.hf_text_config.model_type == "deepseek_mtp": + if (self.hf_text_config.model_type == "deepseek_mtp" + or self.hf_config.model_type == "mimo_mtp"): total_num_hidden_layers = getattr(self.hf_text_config, "num_nextn_predict_layers", 0) else: @@ -1622,30 +1686,21 @@ class ParallelConfig: data_parallel_size: int = 1 """Number of data parallel groups. MoE layers will be sharded according to the product of the tensor parallel size and data parallel size.""" + data_parallel_size_local: int = 1 + """Number of local data parallel groups.""" data_parallel_rank: int = 0 """Rank of the data parallel group.""" - _data_parallel_rank_local: Optional[int] = field(default=None, init=False) - """Private field to store the local rank of the data parallel group.""" - - @property - def data_parallel_rank_local(self) -> int: - """Local rank of the data parallel group, defaults to global rank.""" - if self._data_parallel_rank_local is None: - return self.data_parallel_rank - return self._data_parallel_rank_local - - @data_parallel_rank_local.setter - def data_parallel_rank_local(self, value: int) -> None: - """Set the local rank of the data parallel group.""" - self._data_parallel_rank_local = value - + data_parallel_rank_local: Optional[int] = None + """Local rank of the data parallel group, + set only in SPMD mode.""" data_parallel_master_ip: str = "127.0.0.1" """IP of the data parallel master.""" + 
data_parallel_rpc_port: int = 29550 + """Port for data parallel messaging.""" data_parallel_master_port: int = 29500 """Port of the data parallel master.""" enable_expert_parallel: bool = False """Use expert parallelism instead of tensor parallelism for MoE layers.""" - max_parallel_loading_workers: Optional[int] = None """Maximum number of parallel loading workers when loading model sequentially in multiple batches. To avoid RAM OOM when using tensor @@ -1688,13 +1743,16 @@ class is dynamically inherited by the worker class. This is used to inject world_size: int = field(init=False) """world_size is TPxPP, it affects the number of workers we create.""" - world_size_across_dp: int = field(init=False) - """world_size_across_dp is TPxPPxDP, it is the size of the world - including data parallelism.""" rank: int = 0 """Global rank in distributed setup.""" + @property + def world_size_across_dp(self) -> int: + """world_size_across_dp is TPxPPxDP, it is the size of the world + including data parallelism.""" + return self.world_size * self.data_parallel_size + def get_next_dp_init_port(self) -> int: """ We might need to initialize process groups in multiple @@ -1754,10 +1812,14 @@ def __post_init__(self) -> None: self.world_size = self.pipeline_parallel_size * \ self.tensor_parallel_size - if self.data_parallel_size > 1: + if self.data_parallel_size_local > self.data_parallel_size: + raise ValueError( + f"data_parallel_size_local ({self.data_parallel_size_local}) " + f"must be <= data_parallel_size ({self.data_parallel_size})") + + if self.data_parallel_size > 1 or self.data_parallel_size_local == 0: # Data parallel was specified in the engine args. self.data_parallel_master_port = get_open_port() - # TODO multi-node else: # Otherwise fall back to env vars (e.g. for offline SPMD case). 
self.data_parallel_size = envs.VLLM_DP_SIZE @@ -1766,8 +1828,6 @@ def __post_init__(self) -> None: self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT - self.world_size_across_dp = self.world_size * self.data_parallel_size - if self.distributed_executor_backend == "external_launcher": import os os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" @@ -2008,15 +2068,9 @@ def compute_hash(self) -> str: def __post_init__(self) -> None: if self.max_model_len is None: self.max_model_len = 8192 - logger.warning( - "max_model_len was is not set. Defaulting to arbitrary value " - "of %d.", self.max_model_len) if self.max_num_seqs is None: self.max_num_seqs = 128 - logger.warning( - "max_num_seqs was is not set. Defaulting to arbitrary value " - "of %d.", self.max_num_seqs) if self.max_num_batched_tokens is None: if self.enable_chunked_prefill: @@ -2026,30 +2080,37 @@ def __post_init__(self) -> None: # so we don't reject sequences on account of a short # max_num_batched_tokens. self.max_num_batched_tokens = max( - self.max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS) + self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS) else: self.max_num_batched_tokens = ( - _DEFAULT_MAX_NUM_BATCHED_TOKENS) + DEFAULT_MAX_NUM_BATCHED_TOKENS) else: # If max_model_len is too short, use - # _DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value + # DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value # for higher throughput. 
self.max_num_batched_tokens = max( - self.max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS) + self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS) if self.runner_type == "pooling": # Choose specific value for higher throughput self.max_num_batched_tokens = max( self.max_num_batched_tokens, - _POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, + POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, ) if self.is_multimodal_model: # The value needs to be at least the number of multimodal tokens self.max_num_batched_tokens = max( self.max_num_batched_tokens, - _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, + MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, ) + # When using default settings, + # Ensure max_num_batched_tokens does not exceed model limit. + # Some models (e.g., Whisper) have embeddings tied to max length. + self.max_num_batched_tokens = min( + self.max_num_seqs * self.max_model_len, + self.max_num_batched_tokens) + self.max_num_encoder_input_tokens = self.max_num_batched_tokens self.encoder_cache_size = self.max_num_batched_tokens @@ -2090,6 +2151,13 @@ def _verify_args(self) -> None: "be greater than or equal to max_num_seqs " f"({self.max_num_seqs}).") + if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len: + logger.warning( + "max_num_batched_tokens (%d) exceeds max_num_seqs" + "* max_model_len (%d). This may lead to unexpected behavior.", + self.max_num_batched_tokens, + self.max_num_seqs * self.max_model_len) + if self.num_lookahead_slots < 0: raise ValueError( "num_lookahead_slots " @@ -2139,7 +2207,11 @@ class DeviceConfig: """Configuration for the device to use for vLLM execution.""" device: Union[Device, torch.device] = "auto" - """Device type for vLLM execution.""" + """Device type for vLLM execution. + This parameter is deprecated and will be + removed in a future release. + It will now be set automatically based + on the current platform.""" device_type: str = field(init=False) """Device type from the current platform. 
This is set in `__post_init__`.""" @@ -2189,7 +2261,7 @@ def __post_init__(self): SpeculativeMethod = Literal["ngram", "eagle", "medusa", "mlp_speculator", - "draft_model"] + "draft_model", "deepseek_mtp"] SpeculativeAcceptanceMethod = Literal["rejection_sampler", "typical_acceptance_sampler"] @@ -2274,7 +2346,7 @@ class SpeculativeConfig: `TypicalAcceptanceSampler`.""" speculative_token_tree: Optional[str] = None - """Specifies the tree structure for speculative token generation. + """Specifies the tree structure for speculative token generation. """ # required configuration params passed from engine target_model_config: ModelConfig = field(default=None, @@ -2334,6 +2406,17 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: "n_predict": n_predict, "architectures": ["DeepSeekMTPModel"] }) + + if hf_config.architectures[0] == "MiMoForCausalLM": + hf_config.model_type = "mimo_mtp" + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "num_hidden_layers": 0, + "n_predict": n_predict, + "architectures": ["MiMoMTPModel"] + }) + return hf_config + return hf_config def __post_init__(self): @@ -2350,8 +2433,10 @@ def __post_init__(self): # TODO(Shangming): Refactor mtp configuration logic when supporting # mtp acceleration for more models besides deepseek_v3 if self.target_model_config and \ - self.target_model_config.hf_text_config.model_type \ - == "deepseek_v3": + (self.target_model_config.hf_text_config.model_type \ + == "deepseek_v3" or + self.target_model_config.hf_text_config.model_type \ + == "mimo"): # use the draft model from the same model: self.model = self.target_model_config.model elif self.method in ("ngram", "[ngram]"): @@ -2440,6 +2525,15 @@ def __post_init__(self): elif (self.draft_model_config.hf_config.model_type == "mlp_speculator"): self.method = "mlp_speculator" + elif (self.draft_model_config.hf_config.model_type == + "deepseek_mtp"): + self.method = "deepseek_mtp" + if 
self.num_speculative_tokens > 1: + logger.warning( + "All Deepseek MTP models only have " \ + "one layer. Might need some code changes " \ + "to support multiple layers." + ) else: self.method = "draft_model" @@ -2450,11 +2544,10 @@ def __post_init__(self): "Chunked prefill and EAGLE are not compatible " "when using V0.") - from vllm.platforms import current_platform from vllm.transformers_utils.configs.eagle import ( EAGLEConfig) if isinstance(self.draft_model_config.hf_config, - EAGLEConfig) or current_platform.is_neuron(): + EAGLEConfig): pass else: eagle_config = EAGLEConfig( @@ -2660,7 +2753,7 @@ def num_lookahead_slots(self) -> int: return self.num_speculative_tokens def use_eagle(self) -> bool: - return self.method in ("eagle", "eagle3") + return self.method in ("eagle", "eagle3", "deepseek_mtp") def __repr__(self) -> str: method = self.method @@ -2827,8 +2920,8 @@ def verify_with_model_config(self, model_config: ModelConfig): class MultiModalConfig: """Controls the behavior of multimodal models.""" - limit_per_prompt: dict[str, int] = get_field(ModelConfig, - "limit_mm_per_prompt") + limit_per_prompt: dict[str, int] = \ + cast(dict[str, int], get_field(ModelConfig, "limit_mm_per_prompt")) """ The maximum number of input items allowed per prompt for each modality. Defaults to 1 (V0) or 999 (V1) for each modality. @@ -2893,7 +2986,7 @@ class PoolerConfig: pooling_type: Optional[str] = None """ The pooling method of the pooling model. This should be a key in - {class}`vllm.model_executor.layers.pooler.PoolingType`. + [`vllm.model_executor.layers.pooler.PoolingType`][]. 
""" normalize: Optional[bool] = None @@ -2974,6 +3067,7 @@ def _get_and_verify_dtype( if isinstance(dtype, str): dtype = dtype.lower() if dtype == "auto": + # Set default dtype from model config if config_dtype == torch.float32: # Following common practice, we use float16 for float32 models torch_dtype = torch.float16 @@ -2981,37 +3075,33 @@ def _get_and_verify_dtype( torch_dtype = config_dtype if config.model_type == "plamo2": - logger.info( + logger.warning( "For PLaMo2, we cast models to bfloat16 instead of using " "float16 by default. This is because float16 does not work." ) torch_dtype = torch.bfloat16 + # Deal with torch dtype fallback for device compatibility. from vllm.platforms import current_platform - if (current_platform.is_cpu() - and current_platform.get_cpu_architecture() - == CpuArchEnum.POWERPC - and (config_dtype == torch.float16 - or config_dtype == torch.float32)): - logger.info( - "For POWERPC, we cast models to bfloat16 instead of " - "using float16 by default. Float16 is not currently " - "supported for POWERPC.") - torch_dtype = torch.bfloat16 + if torch_dtype not in current_platform.supported_dtypes: + device_name = current_platform.get_device_name() - # TODO: change this condition to check if the platform support bf16 - # instead of checking the OS. For instance M2 shall supports bf16 - # already. But we need to modify `cpu_extension.cmake` to activate - # the feature in the build. - if (current_platform.is_cpu() and sys.platform.startswith("darwin") - and current_platform.get_cpu_architecture() - == CpuArchEnum.ARM and config_dtype == torch.bfloat16): - logger.info("For macOS with Apple Silicon, currently bfloat16 " - "is not supported. 
Setting dtype to float16.") - torch_dtype = torch.float16 + if ((capability := current_platform.get_device_capability()) + is None): + compute_str = "" + else: + version_str = capability.as_version_str() + compute_str = f" (with compute capability {version_str})" + fallback_dtype = current_platform.supported_dtypes[0] + logger.warning( + "Your %s device%s doesn't support %s. " \ + "Falling back to %s for compatibility.", + device_name, compute_str, torch_dtype, fallback_dtype + ) + torch_dtype = fallback_dtype - if current_platform.is_hpu() and config_dtype == torch.float16: - logger.info( + if current_platform.is_hpu() and torch_dtype == torch.float16: + logger.warning( "For HPU, we cast models to bfloat16 instead of " "using float16 by default. Please specify `dtype` if you " "want to use float16.") @@ -3405,41 +3495,56 @@ def _parse_collect_detailed_traces(self): self.collect_detailed_traces[0].split(",")) -class KVTransferConfig(BaseModel): +KVProducer = Literal["kv_producer", "kv_both"] +KVConsumer = Literal["kv_consumer", "kv_both"] +KVRole = Literal[KVProducer, KVConsumer] + + +@config +@dataclass +class KVTransferConfig: """Configuration for distributed KV cache transfer.""" - # The KV connector for vLLM to transmit KV caches between vLLM instances. kv_connector: Optional[str] = None + """The KV connector for vLLM to transmit KV caches between vLLM instances. + """ + + engine_id: Optional[str] = None + """The engine id for KV transfers.""" - # The device used by kv connector to buffer the KV cache. - # Currently only support 'cuda'. kv_buffer_device: Optional[str] = "cuda" + """The device used by kv connector to buffer the KV cache. + Currently only support 'cuda'.""" - # The buffer size for TorchDistributedConnector. Measured in number of - # bytes. Recommended value: 1e9 (about 1GB). kv_buffer_size: float = 1e9 + """The buffer size for TorchDistributedConnector. Measured in number of + bytes. 
Recommended value: 1e9 (about 1GB).""" - # Whether this vLLM instance produces, consumes KV cache, or both. Choices - # are 'kv_producer', 'kv_consumer', and 'both'. - kv_role: Optional[str] = None + kv_role: Optional[KVRole] = None + """Whether this vLLM instance produces, consumes KV cache, or both. Choices + are 'kv_producer', 'kv_consumer', and 'kv_both'.""" - # The rank of this vLLM instance in the KV cache transfer. Typical value: - # 0 for prefill instance, 1 for decode instance. - # Currently only 1P1D is supported. kv_rank: Optional[int] = None + """The rank of this vLLM instance in the KV cache transfer. Typical value: + 0 for prefill instance, 1 for decode instance. + Currently only 1P1D is supported.""" - # The number of parallel instances for KV cache transfer. For - # PyNcclConnector, this should be 2. kv_parallel_size: int = 1 + """The number of parallel instances for KV cache transfer. For + PyNcclConnector, this should be 2.""" - # The KV connector ip, used to build distributed connection kv_ip: str = "127.0.0.1" + """The KV connector ip, used to build distributed connection.""" - # The KV connector port, used to build distributed connection kv_port: int = 14579 + """The KV connector port, used to build distributed connection.""" + + kv_connector_extra_config: dict[str, Any] = field(default_factory=dict) + """any extra config that the connector may need.""" - # any extra config that the connector may need - kv_connector_extra_config: dict[str, Any] = {} + kv_connector_module_path: Optional[str] = None + """The Python module path to dynamically load the KV connector from. 
+ Only supported in V1.""" def compute_hash(self) -> str: """ @@ -3460,46 +3565,40 @@ def compute_hash(self) -> str: usedforsecurity=False).hexdigest() return hash_str - @classmethod - def from_cli(cls, cli_value: str) -> "KVTransferConfig": - """Parse the CLI value for the kv cache transfer config.""" - return KVTransferConfig.model_validate_json(cli_value) - - def model_post_init(self, __context: Any) -> None: + def __post_init__(self) -> None: + if self.engine_id is None: + self.engine_id = str(uuid.uuid4()) - if self.kv_role is not None and self.kv_role not in [ - "kv_producer", "kv_consumer", "kv_both" - ]: - raise ValueError( - f"Unsupported kv_role: {self.kv_role}. " - f"Supported roles are `kv_producer`, `kv_consumer`, " - f"and `kv_both`") + if self.kv_role is not None and self.kv_role not in get_args(KVRole): + raise ValueError(f"Unsupported kv_role: {self.kv_role}. " + f"Supported roles are {get_args(KVRole)}") if self.kv_connector is not None and self.kv_role is None: raise ValueError("Please specify kv_disagg_role when kv_connector " - "is set, supported roles are `kv_producer`, " - "`kv_consumer`, and `kv_both`") + f"is set, supported roles are {get_args(KVRole)}") @property def is_kv_transfer_instance(self) -> bool: return self.kv_connector is not None and \ - self.kv_role in ["kv_producer", "kv_consumer", "kv_both"] + self.kv_role in get_args(KVRole) @property def is_kv_producer(self) -> bool: return self.kv_connector is not None and \ - self.kv_role in ["kv_producer", "kv_both"] + self.kv_role in get_args(KVProducer) @property def is_kv_consumer(self) -> bool: return self.kv_connector is not None and \ - self.kv_role in ["kv_consumer", "kv_both"] + self.kv_role in get_args(KVConsumer) def get_from_extra_config(self, key, default) -> Any: return self.kv_connector_extra_config.get(key, default) -class KVEventsConfig(BaseModel): +@config +@dataclass +class KVEventsConfig: """Configuration for KV event publishing.""" enable_kv_cache_events: bool = 
False @@ -3538,11 +3637,6 @@ class KVEventsConfig(BaseModel): this topic to receive events. """ - @classmethod - def from_cli(cls, cli_value: str) -> "KVEventsConfig": - """Parse the CLI value for the event publisher config.""" - return KVEventsConfig.model_validate_json(cli_value) - class CompilationLevel: # constants for the levels of the compilation process @@ -3552,80 +3646,79 @@ class CompilationLevel: PIECEWISE = 3 -class CompilationConfig(BaseModel): - """ - Configuration for compilation. - It has three parts: +@config +@dataclass +class PassConfig: + """Configuration for custom Inductor passes. + + This is separate from general `CompilationConfig` so that inductor passes + don't all have access to full configuration - that would create a cycle as + the `PassManager` is set as a property of config.""" + + dump_graph_stages: list[str] = field(default_factory=list) + """List of stages for which we want to dump the graph. Each pass defines + its own stages (before, after, maybe in-between).""" + dump_graph_dir: Path = Path(".") + """Directory to dump the graphs.""" + # TODO(luka) better pass enabling system. + enable_fusion: bool = True + """Whether to enable the custom fusion pass.""" + enable_noop: bool = True + """Whether to enable the custom no-op elimination pass.""" + enable_sequence_parallelism: bool = False + """Whether to enable sequence parallelism.""" + enable_async_tp: bool = False + """Whether to enable async TP.""" + + def uuid(self): + """ + Produces a hash unique to the pass configuration. + Any new fields that affect compilation should be added to the hash. + Do not include dump_graph_* in the hash - they don't affect + compilation. 
+ """ + include = { + "enable_fusion", "enable_noop", "enable_sequence_parallelism", + "enable_async_tp" + } + dict_ = {k: v for k, v in asdict(self).items() if k in include} + return InductorPass.hash_dict(dict_) + + def __post_init__(self) -> None: + if not self.enable_noop and self.enable_fusion: + logger.warning_once( + "Fusion enabled but reshape elimination disabled. " + "RMSNorm + quant (fp8) fusion might not work") + + +@config +@dataclass +class CompilationConfig: + """Configuration for compilation. It has three parts: + - Top-level Compilation control: - - level: the level of compilation. - - 0: no compilation. - - 1: dynamo as is. - - 2: dynamo once. - - 3: piecewise compilation. - - debug_dump_path: the path to dump the debug information. - - cache_dir: the directory to store the compiled graph, to - accelerate Inductor compilation. By default, it will use - model-related information to generate a cache directory. - - backend: the backend for compilation. It needs to be a string. - - "" (empty string): use the default backend. - - "eager"/"openxla"/...: use the specified backend registered in PyTorch. - - "full.module.name": a qualified name which can be used to import the backend function. - We use string to avoid serialization issues when using compilation in a distributed setting. - When the compilation level is 1 or 2, the backend is used for the compilation directly (it sees the whole graph). - When the compilation level is 3, the backend is used for the piecewise compilation (it sees a part of the graph). - - custom_ops: fine-grained control over which custom ops to enable/disable. - Use 'all' to enable all, 'none' to disable all. - Also specify a list of custom op names to enable (prefixed with a '+'), - or disable (prefixed with a '-'). 
- Examples: - - 'all,-op1' to enable all except op1 - - 'none,+op1,+op2' to enable only op1 and op2 - By default, all custom ops are enabled when running without Inductor - and disabled when running with Inductor (compile_level >= Inductor). - - splitting_ops: a list of ops to split the full graph into subgraphs, used in piecewise compilation. + - [`level`][vllm.config.CompilationConfig.level] + - [`debug_dump_path`][vllm.config.CompilationConfig.debug_dump_path] + - [`cache_dir`][vllm.config.CompilationConfig.cache_dir] + - [`backend`][vllm.config.CompilationConfig.backend] + - [`custom_ops`][vllm.config.CompilationConfig.custom_ops] + - [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops] - CudaGraph capture: - - use_cudagraph: whether to use cudagraph inside compilation. - - False: cudagraph inside compilation is not used. - - True: cudagraph inside compilation is used. It requires - that all input buffers have fixed addresses, and all - splitting ops write their outputs to input buffers. - Note that this is orthogonal to the cudagraph capture logic - outside of compilation. - TODO: move outside cudagraph logic into compilation. - torch.compile will handle cudagraph capture logic in the future. - - cudagraph_capture_sizes: sizes to capture cudagraph. - - None (default): capture sizes are inferred from vllm config. - - list[int]: capture sizes are specified as given. - - cudagraph_num_of_warmups: number of warmup runs for cudagraph. - It means the first several runs will be treated as warmup runs. - Only after that, the execution will be recorded, and the recorded - cudagraph will be used for subsequent runs. - - cudagraph_copy_inputs: whether to copy input tensors for - cudagraph. If the caller can guarantee that the same input buffers - are always used, it can set this to False. Otherwise, it should - set this to True, and the compiler will copy the input to an - internally managed buffer. Default is False. 
- - full_cuda_graph: whether to use a full cuda graph for the entire forward - pass rather than splitting certain operations such as attention into subgraphs. - Thus this flag cannot be used together with splitting_ops. This may provide - performance benefits for smaller models. + - [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph] + - [`cudagraph_capture_sizes`] + [vllm.config.CompilationConfig.cudagraph_capture_sizes] + - [`cudagraph_num_of_warmups`] + [vllm.config.CompilationConfig.cudagraph_num_of_warmups] + - [`cudagraph_copy_inputs`] + [vllm.config.CompilationConfig.cudagraph_copy_inputs] + - [`full_cuda_graph`][vllm.config.CompilationConfig.full_cuda_graph] - Inductor compilation: - - use_inductor: whether to use inductor compilation. - - False: inductor compilation is not used. graph runs in eager. - - True: inductor compilation is used. one graph for symbolic shape - is compiled. In addition, compile for compile_sizes, - using configurations in inductor_compile_config. - - compile_sizes: sizes to compile for inductor. In addition - to integers, it also supports "cudagraph_capture_sizes" to - specify the sizes for cudagraph capture. - - inductor_compile_config: additional configurations for inductor. - - None: use default configurations. - - inductor_passes: additional passes for inductor. It is a dictionary - from pass name to pass function qualified name. We use function - name because the config uses json format. If we pass the config - from Python, functions can also be passed directly via Python object - constructor, e.g. 
`CompilationConfig(inductor_passes={"a": func})` - - custom inductor passes: see PassConfig for more details + - [`use_inductor`][vllm.config.CompilationConfig.use_inductor] + - [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes] + - [`inductor_compile_config`] + [vllm.config.CompilationConfig.inductor_compile_config] + - [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes] + - custom inductor passes Why we have different sizes for cudagraph and inductor: - cudagraph: a cudagraph captured for a specific size can only be used @@ -3636,83 +3729,135 @@ class CompilationConfig(BaseModel): static shapes. However, we find the general shape compilation is sufficient for most cases. It might be beneficial to compile for certain small batchsizes, where inductor is good at optimizing. - """ # noqa + """ + # Top-level Compilation control level: int = 0 + """The level of compilation: + + - 0: no compilation. + - 1: dynamo as is. + - 2: dynamo once. + - 3: piecewise compilation.""" debug_dump_path: str = "" + """The path to dump the debug information.""" cache_dir: str = "" + """The directory to store the compiled graph, to accelerate Inductor + compilation. By default, it will use model-related information to generate + a cache directory.""" backend: str = "" - custom_ops: list[str] = Field(default_factory=list) - splitting_ops: list[str] = Field(default=None) # type: ignore - + """The backend for compilation. It needs to be a string: + + - "" (empty string): use the default backend. + - "eager"/"openxla"/...: use the specified backend registered in PyTorch. + - "full.module.name": a qualified name which can be used to import the + + backend function. + We use string to avoid serialization issues when using compilation in a + distributed setting. When the compilation level is 1 or 2, the backend is + used for the compilation directly (it sees the whole graph). 
When the + compilation level is 3, the backend is used for the piecewise compilation + (it sees a part of the graph).""" + custom_ops: list[str] = field(default_factory=list) + """Fine-grained control over which custom ops to enable/disable. Use 'all' + to enable all, 'none' to disable all. Also specify a list of custom op + names to enable (prefixed with a '+'), or disable (prefixed with a '-'). + Examples: + + - 'all,-op1' to enable all except op1 + - 'none,+op1,+op2' to enable only op1 and op2 + + By default, all custom ops are enabled when running without Inductor and + disabled when running with Inductor (compile_level >= Inductor).""" + splitting_ops: list[str] = field(default_factory=list) + """A list of ops to split the full graph into subgraphs, used in piecewise + compilation.""" + + # Inductor capture use_inductor: bool = True - compile_sizes: Optional[list[Union[int, str]]] = Field(default=None) - inductor_compile_config: dict = Field(default_factory=dict) - inductor_passes: dict[str, str] = Field(default_factory=dict) - + """Whether to use inductor compilation: + + - False: inductor compilation is not used. graph runs in eager. + - True: inductor compilation is used. one graph for symbolic shape + is compiled. In addition, compile for compile_sizes, + using configurations in inductor_compile_config.""" + compile_sizes: Optional[list[Union[int, str]]] = None + """Sizes to compile for inductor. In addition + to integers, it also supports "cudagraph_capture_sizes" to + specify the sizes for cudagraph capture.""" + inductor_compile_config: dict = field(default_factory=dict) + """Additional configurations for inductor. + - None: use default configurations.""" + inductor_passes: dict[str, str] = field(default_factory=dict) + """Additional passes for inductor. It is a dictionary + from pass name to pass function qualified name. We use function + name because the config uses JSON format. 
If we pass the config + from Python, functions can also be passed directly via Python object + constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`.""" + + # CudaGraph compilation use_cudagraph: bool = False + """Whether to use cudagraph inside compilation. + - False: cudagraph inside compilation is not used. + - True: cudagraph inside compilation is used. It requires + that all input buffers have fixed addresses, and all + splitting ops write their outputs to input buffers. + Note that this is orthogonal to the cudagraph capture logic + outside of compilation. + TODO: move outside cudagraph logic into compilation. + torch.compile will handle cudagraph capture logic in the future.""" cudagraph_num_of_warmups: int = 0 + """Number of warmup runs for cudagraph. + It means the first several runs will be treated as warmup runs. + Only after that, the execution will be recorded, and the recorded + cudagraph will be used for subsequent runs.""" cudagraph_capture_sizes: Optional[list[int]] = None + """Sizes to capture cudagraph. + - None (default): capture sizes are inferred from vllm config. + - list[int]: capture sizes are specified as given.""" cudagraph_copy_inputs: bool = False + """Whether to copy input tensors for + cudagraph. If the caller can guarantee that the same input buffers + are always used, it can set this to False. Otherwise, it should + set this to True, and the compiler will copy the input to an + internally managed buffer. Default is False.""" full_cuda_graph: bool = False - - class PassConfig(BaseModel): - """ - Configuration for custom Inductor passes. - This is separate from general CompilationConfig so that inductor passes - don't all have access to full configuration - that would create a cycle - as the PassManager is set as a property of config. - - dump_graph_stages: list of stages for which we want to dump the graph. - Each pass defines its own stages (before, after, maybe in-between). 
- - dump_graph_dir: directory to dump the graphs. Default is . - - enable_fusion: whether to enable the custom fusion pass. - - enable_noop: whether to enable the custom no-op elimination pass. - TODO(luka) better pass enabling system. - - enable_sequence_parallelism: whether to enable sequence parallelism. - """ - dump_graph_stages: list[str] = Field(default_factory=list) - dump_graph_dir: Path = Field(default=Path(".")) - enable_fusion: bool = True - enable_noop: bool = True - enable_sequence_parallelism: bool = False - - def uuid(self): - """ - Produces a hash unique to the pass configuration. - Any new fields that affect compilation should be added to the hash. - Do not include dump_graph_* in the hash - they don't affect - compilation. - """ - dict_ = self.model_dump(include={"enable_fusion", "enable_noop", \ - "enable_sequence_parallelism"}) - return InductorPass.hash_dict(dict_) - - def model_post_init(self, __context: Any) -> None: - if not self.enable_noop and self.enable_fusion: - logger.warning_once( - "Fusion enabled but reshape elimination disabled. " - "RMSNorm + quant (fp8) fusion might not work") - - pass_config: PassConfig = Field(default_factory=PassConfig) - - # not configurable, computed after init - max_capture_size: int = PrivateAttr - local_cache_dir: str = PrivateAttr # local cache dir for each rank - # optimization: - # Intuitively, bs_to_padded_graph_size should be dict[int, int]. - # since we know all keys are in a range [0, max_capture_size], - # we can optimize it to list[int] for better lookup performance. - bs_to_padded_graph_size: list[int] = PrivateAttr + """whether to use a full cuda graph for the entire forward pass rather than + splitting certain operations such as attention into subgraphs. Thus this + flag cannot be used together with splitting_ops. 
This may provide + performance benefits for smaller models.""" + + pass_config: PassConfig = field(default_factory=PassConfig) + """Custom inductor passes, see PassConfig for more details""" + + max_capture_size: int = field(default=None, init=False) # type: ignore + """not configurable, computed after init""" + local_cache_dir: str = field(default=None, init=False) # type: ignore + """local cache dir for each rank""" + bs_to_padded_graph_size: list[int] = field( + default=None, # type: ignore + init=False) + """optimization: + Intuitively, bs_to_padded_graph_size should be dict[int, int]. + since we know all keys are in a range [0, max_capture_size], + we can optimize it to list[int] for better lookup performance.""" # keep track of enabled and disabled custom ops - enabled_custom_ops: Counter[str] = PrivateAttr - disabled_custom_ops: Counter[str] = PrivateAttr - traced_files: set[str] = PrivateAttr - compilation_time: float = PrivateAttr - - # Per-model forward context - # Map from layer name to layer objects that need to be accessed outside - # model code, e.g., Attention, FusedMOE when dp_size>1. 
- static_forward_context: dict[str, Any] = PrivateAttr + enabled_custom_ops: Counter[str] = field(default_factory=Counter, + init=False) + """custom ops that are enabled""" + disabled_custom_ops: Counter[str] = field(default_factory=Counter, + init=False) + """custom ops that are disabled""" + traced_files: set[str] = field(default_factory=set, init=False) + """files that are traced for compilation""" + compilation_time: float = field(default=0.0, init=False) + """time taken for compilation""" + + static_forward_context: dict[str, Any] = field(default_factory=dict, + init=False) + """Per-model forward context + Map from layer name to layer objects that need to be accessed outside + model code, e.g., Attention, FusedMOE when dp_size>1.""" def compute_hash(self) -> str: """ @@ -3747,7 +3892,17 @@ def __repr__(self) -> str: "pass_config", "traced_files", } - return self.model_dump_json(exclude=exclude, exclude_unset=True) + include = dict() + for k, v in asdict(self).items(): + if k in exclude: + continue + f = get_field(CompilationConfig, k) + if (d := f.default) is not MISSING and d == v: + continue + if (df := f.default_factory) is not MISSING and df() == v: + continue + include[k] = v + return json.dumps(include) __str__ = __repr__ @@ -3756,12 +3911,9 @@ def from_cli(cls, cli_value: str) -> "CompilationConfig": """Parse the CLI value for the compilation config.""" if cli_value in ["0", "1", "2", "3"]: return cls(level=int(cli_value)) - # do not use `eval`, it is dangerous and can execute arbitrary code - dict_value = ast.literal_eval(cli_value) - return CompilationConfig.model_validate(dict_value) - - def model_post_init(self, __context: Any) -> None: + return cls(**json.loads(cli_value)) + def __post_init__(self) -> None: count_none = self.custom_ops.count("none") count_all = self.custom_ops.count("all") assert count_none + count_all <= 1, "Can only specify 'none' or 'all'" @@ -3779,9 +3931,6 @@ def model_post_init(self, __context: Any) -> None: if KEY not in 
self.inductor_compile_config: self.inductor_compile_config[KEY] = False - if self.splitting_ops is None: - self.splitting_ops = [] - for k, v in self.inductor_passes.items(): if not isinstance(v, str): assert callable(v), ( @@ -3798,11 +3947,8 @@ def model_post_init(self, __context: Any) -> None: self.inductor_compile_config[k] = func if isinstance( func, InductorPass) else CallableInductorPass(func) - self.enabled_custom_ops = Counter() - self.disabled_custom_ops = Counter() - self.traced_files = set() - self.static_forward_context = {} - self.compilation_time = 0.0 + if isinstance(self.pass_config, dict): + self.pass_config = PassConfig(**self.pass_config) def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: if self.level == CompilationLevel.NO_COMPILATION: @@ -3835,11 +3981,12 @@ def init_with_cudagraph_sizes(self, self.cudagraph_capture_sizes = cudagraph_capture_sizes else: # de-duplicate the sizes provided by the config - self.cudagraph_capture_sizes = list( - set(self.cudagraph_capture_sizes)) - logger.info(("cudagraph sizes specified by model runner" - " %s is overridden by config %s"), - cudagraph_capture_sizes, self.cudagraph_capture_sizes) + dedup_sizes = list(set(self.cudagraph_capture_sizes)) + if len(dedup_sizes) < len(self.cudagraph_capture_sizes): + logger.info(("cudagraph sizes specified by model runner" + " %s is overridden by config %s"), + cudagraph_capture_sizes, dedup_sizes) + self.cudagraph_capture_sizes = dedup_sizes computed_compile_sizes = [] if self.compile_sizes is not None: @@ -3889,39 +4036,69 @@ def set_splitting_ops_for_v1(self): ] +@config @dataclass class VllmConfig: """Dataclass which contains all vllm-related configuration. This simplifies passing around the distinct configurations in the codebase. 
""" - model_config: ModelConfig = field(default=None, init=True) # type: ignore - cache_config: CacheConfig = field(default=None, init=True) # type: ignore - parallel_config: ParallelConfig = field(default_factory=ParallelConfig, - init=True) - scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig, - init=True) - device_config: DeviceConfig = field(default=None, - init=True) # type: ignore - load_config: LoadConfig = field(default=None, init=True) # type: ignore + # TODO: use default_factory once default constructing ModelConfig doesn't + # try to download a model + model_config: ModelConfig = None # type: ignore + """Model configuration.""" + cache_config: CacheConfig = field(default_factory=CacheConfig) + """Cache configuration.""" + parallel_config: ParallelConfig = field(default_factory=ParallelConfig) + """Parallel configuration.""" + scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig) + """Scheduler configuration.""" + device_config: DeviceConfig = field(default_factory=DeviceConfig) + """Device configuration.""" + load_config: LoadConfig = field(default_factory=LoadConfig) + """Load configuration.""" lora_config: Optional[LoRAConfig] = None - speculative_config: SpeculativeConfig = field(default=None, - init=True) # type: ignore - decoding_config: Optional[DecodingConfig] = None + """LoRA configuration.""" + speculative_config: Optional[SpeculativeConfig] = None + """Speculative decoding configuration.""" + decoding_config: DecodingConfig = field(default_factory=DecodingConfig) + """Decoding configuration.""" observability_config: Optional[ObservabilityConfig] = None + """Observability configuration.""" prompt_adapter_config: Optional[PromptAdapterConfig] = None + """Prompt adapter configuration.""" quant_config: Optional[QuantizationConfig] = None - compilation_config: CompilationConfig = field(default=None, - init=True) # type: ignore - kv_transfer_config: KVTransferConfig = field(default=None, - init=True) # 
type: ignore + """Quantization configuration.""" + compilation_config: CompilationConfig = field( + default_factory=CompilationConfig) + """`torch.compile` configuration for the model. + + When it is a number (0, 1, 2, 3), it will be interpreted as the + optimization level. + + NOTE: level 0 is the default level without any optimization. level 1 and 2 + are for internal testing only. level 3 is the recommended level for + production. + + Following the convention of traditional compilers, using `-O` without space + is also supported. `-O3` is equivalent to `-O 3`. + + You can specify the full compilation config like so: + `{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}` + """ + kv_transfer_config: Optional[KVTransferConfig] = None + """The configurations for distributed KV cache transfer.""" kv_events_config: Optional[KVEventsConfig] = None + """The configurations for event publishing.""" # some opaque config, only used to provide additional information # for the hash computation, mainly used for testing, debugging or out of # tree config registration. - additional_config: SupportsHash = field(default=None, - init=True) # type: ignore + additional_config: Union[dict, SupportsHash] = field(default_factory=dict) + """Additional config for specified platform. Different platforms may + support different configs. Make sure the configs are valid for the platform + you are using. 
Contents must be hashable.""" instance_id: str = "" + """The ID of the vLLM instance.""" def compute_hash(self) -> str: """ @@ -4002,7 +4179,14 @@ def compute_hash(self) -> str: else: vllm_factors.append("None") if self.additional_config: - vllm_factors.append(self.additional_config.compute_hash()) + if isinstance(additional_config := self.additional_config, dict): + additional_config_hash = hashlib.md5( + json.dumps(additional_config, sort_keys=True).encode(), + usedforsecurity=False, + ).hexdigest() + else: + additional_config_hash = additional_config.compute_hash() + vllm_factors.append(additional_config_hash) else: vllm_factors.append("None") factors.append(vllm_factors) @@ -4080,6 +4264,8 @@ def __post_init__(self): self.speculative_config, self.device_config) self.model_config.verify_with_parallel_config(self.parallel_config) + self.model_config.verify_dual_chunk_attention_config( + self.load_config) if self.cache_config is not None: self.cache_config.verify_with_parallel_config(self.parallel_config) @@ -4110,6 +4296,12 @@ def __post_init__(self): if self.compilation_config is None: self.compilation_config = CompilationConfig() + + # async tp is built on top of sequence parallelism + # and requires it to be enabled. 
+ if self.compilation_config.pass_config.enable_async_tp: + self.compilation_config.pass_config.enable_sequence_parallelism = \ + True if self.compilation_config.pass_config.enable_sequence_parallelism: self.compilation_config.custom_ops.append("+rms_norm") if envs.VLLM_USE_V1 and self.model_config is not None and \ @@ -4129,18 +4321,6 @@ def __post_init__(self): self.compilation_config.level = CompilationLevel.PIECEWISE self.compilation_config.set_splitting_ops_for_v1() - if self.parallel_config is not None and \ - self.parallel_config.tensor_parallel_size > 1 and \ - self.parallel_config.pipeline_parallel_size > 1 and \ - self.compilation_config is not None and \ - self.compilation_config.pass_config is not None and \ - self.compilation_config.pass_config.enable_sequence_parallelism: - logger.warning_once( - "Sequence parallelism is not supported with pipeline " - "parallelism. Disabling sequence parallelism.") - self.compilation_config.pass_config.\ - enable_sequence_parallelism = False - self._set_cudagraph_sizes() if self.cache_config is not None and \ @@ -4166,18 +4346,6 @@ def __post_init__(self): "full_cuda_graph is not supported with " "cascade attention. 
Disabling cascade attention.") self.model_config.disable_cascade_attn = True - - if self.model_config and self.model_config.use_mla and \ - not (current_platform.is_cuda() or current_platform.is_rocm()): - logger.info( - "MLA is enabled on a non-GPU platform; forcing chunked " - "prefill and prefix caching to be disabled.") - self.scheduler_config.enable_chunked_prefill = False - self.scheduler_config.chunked_prefill_enabled = False - self.scheduler_config.max_num_batched_tokens = max( - self.scheduler_config.max_model_len, - _DEFAULT_MAX_NUM_BATCHED_TOKENS) - if self.cache_config is not None: self.cache_config.enable_prefix_caching = False @@ -4403,7 +4571,7 @@ def contains_object_print(text): text (str): The text to check Returns: - bool: True if a match is found, False otherwise + result (bool): `True` if a match is found, `False` otherwise. """ pattern = r'at 0x[a-fA-F0-9]{2,16}>' match = re.search(pattern, text) diff --git a/vllm/connections.py b/vllm/connections.py index 9abc66050e18..84e32a4d5ca9 100644 --- a/vllm/connections.py +++ b/vllm/connections.py @@ -167,4 +167,7 @@ async def async_download_file( global_http_connection = HTTPConnection() -"""The global {class}`HTTPConnection` instance used by vLLM.""" +""" +The global [`HTTPConnection`][vllm.connections.HTTPConnection] instance used +by vLLM. 
+""" diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 9ff77f14a5e8..6fcbca628c6a 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -11,7 +11,7 @@ import gc import os from contextlib import contextmanager -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union import torch @@ -63,7 +63,7 @@ def find_loaded_library(lib_name) -> Optional[str]: libcudart = None # py_device, py_alignedSize, py_d_mem, py_p_memHandle -HandleType = Tuple[int, int, int, int] +HandleType = tuple[int, int, int, int] @dataclasses.dataclass @@ -148,9 +148,9 @@ def __init__(self): "Please track https://github.com/pytorch/pytorch/issues/147851 " "for the latest updates.") - self.pointer_to_data: Dict[int, AllocationData] = {} + self.pointer_to_data: dict[int, AllocationData] = {} self.current_tag: str = CuMemAllocator.default_tag - self.allocator_and_pools: Dict[str, Any] = {} + self.allocator_and_pools: dict[str, Any] = {} def python_malloc_callback(self, allocation_handle: HandleType) -> None: """ @@ -172,7 +172,7 @@ def python_free_callback(self, ptr: int) -> HandleType: def sleep( self, - offload_tags: Optional[Union[Tuple[str, ...], + offload_tags: Optional[Union[tuple[str, ...], str]] = None) -> None: """ Put the allocator in sleep mode. 
diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 894a0fafb640..d85a41ddac22 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union import torch import torch.distributed @@ -32,7 +32,7 @@ def tensor_model_parallel_gather(input_: torch.Tensor, return get_tp_group().gather(input_, dst, dim) -def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor, +def broadcast_tensor_dict(tensor_dict: Optional[dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0): if not torch.distributed.is_initialized(): diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py new file mode 100644 index 000000000000..a250ec89cd5b --- /dev/null +++ b/vllm/distributed/device_communicators/all2all.py @@ -0,0 +1,126 @@ +# SPDX-License-Identifier: Apache-2.0 +import importlib.util +from typing import TYPE_CHECKING + +import torch +import torch.distributed as dist + +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger + +from .base_device_communicator import All2AllManagerBase, Cache + +logger = init_logger(__name__) + +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.layer import FusedMoE +else: + FusedMoE = None + + +class NaiveAll2AllManager(All2AllManagerBase): + """ + A naive implementation of all2all communication. + It uses all-reduce under the hood, which is not + efficient at all. The main purpose is for testing and + debugging. 
+ """ + + def __init__(self, cpu_group): + super().__init__(cpu_group) + + def naive_multicast(self, x: torch.Tensor, + cu_tokens_across_dp_cpu: torch.Tensor): + assert (len(x.shape) == 2) + buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)), + device=x.device, + dtype=x.dtype) + + start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ + self.dp_rank - 1] + end = cu_tokens_across_dp_cpu[self.dp_rank] + buffer[start:end, :].copy_(x) + for idx in range(self.dp_world_size): + start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1] + end = cu_tokens_across_dp_cpu[idx] + self.dp_group.broadcast(buffer[start:end, :], idx) + + return buffer + + def dispatch(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + cu_tokens_across_dp_cpu = get_forward_context( + ).dp_metadata.cu_tokens_across_dp_cpu + + hidden_states = self.naive_multicast(hidden_states, + cu_tokens_across_dp_cpu) + router_logits = self.naive_multicast(router_logits, + cu_tokens_across_dp_cpu) + return hidden_states, router_logits + + def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + cu_tokens_across_dp_cpu = get_forward_context( + ).dp_metadata.cu_tokens_across_dp_cpu + start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ + self.dp_rank - 1] + end = cu_tokens_across_dp_cpu[self.dp_rank] + + all_hidden_states = self.dp_group.all_reduce(hidden_states) + hidden_states = all_hidden_states[start:end, :] + return hidden_states + + def destroy(self): + pass + + +class PPLXAll2AllManager(All2AllManagerBase): + """ + All2All communication based on PPLX kernels. + """ + + def __init__(self, cpu_group): + has_pplx = importlib.util.find_spec("pplx_kernels") is not None + assert has_pplx, "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." 
# noqa + super().__init__(cpu_group) + + if self.internode: + # inter-node communication needs nvshmem, + # intra-node communication uses p2p mapping directly + from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, + nvshmem_get_unique_id, + nvshmem_init) + logger.debug( + "Initialize NVSHMEM for pplx_kernels: " + "rank=%d, world size=%d", self.rank, self.world_size) + uid = nvshmem_get_unique_id( + ) if self.rank == 0 else nvshmem_alloc_empty_unique_id() + dist.broadcast(uid, + src=dist.get_process_group_ranks(self.cpu_group)[0], + group=self.cpu_group) + logger.debug("PPLX NVSHMEM UID = %s", uid) + nvshmem_init(uid, self.rank, self.world_size) + + self.handle_cache = Cache() + + def get_handle(self, kwargs): + import pplx_kernels as pplx + return self.handle_cache.get_or_create( + kwargs, pplx.AllToAll.internode + if self.internode else pplx.AllToAll.intranode) + + def dispatch(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + raise NotImplementedError + + def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + def destroy(self): + with self.handle_cache._lock: + for _, handle in self.handle_cache._cache.items(): + handle.destroy() + + if self.internode: + from pplx_kernels.nvshmem import nvshmem_finalize + logger.debug("PPLX NVSHMEM finalize") + nvshmem_finalize() diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 240313b98c88..52b970949144 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -1,11 +1,76 @@ # SPDX-License-Identifier: Apache-2.0 +import threading from typing import Optional +from weakref import WeakValueDictionary import torch import torch.distributed as dist from torch.distributed import ProcessGroup +class Cache: + + def __init__(self): + self._cache: WeakValueDictionary = 
WeakValueDictionary() + self._lock = threading.RLock() # Reentrant lock for thread safety + + def get_or_create(self, kwargs, func): + # Create a hashable key from the kwargs + key = tuple(sorted((k, v) for k, v in kwargs.items())) + + with self._lock: + instance = self._cache.get(key) + if instance is None: + instance = func(**kwargs) + self._cache[key] = instance + return instance + + +class All2AllManagerBase: + + def __init__(self, cpu_group): + self.cpu_group = cpu_group + + # compute some common properties + from vllm.distributed.parallel_state import (get_dp_group, + get_tp_group, + in_the_same_node_as) + + # all2all lives in ep group, which is merged from dp and tp group + self.dp_group = get_dp_group() + self.tp_group = get_tp_group() + # no self.ep_group since self.ep_group is still in construction + # when we create this object + self.dp_rank = self.dp_group.rank_in_group + self.dp_world_size = self.dp_group.world_size + self.rank = dist.get_rank(cpu_group) + self.world_size = dist.get_world_size(cpu_group) + + # all2all communication often has separate implementations for + # intra-node and inter-node communication + self.intranode = in_the_same_node_as(cpu_group, source_rank=0) + self.internode = not self.intranode + + def get_handle(self, kwargs): + # get a handle for the all2all communication, + # based on the kwargs. + # different layers can have different configs, + # e.g. one layer has hidden size 1024, another has 2048. + # usually the underlying implementation caches the handle + # and reuse it for the same config. + raise NotImplementedError + + def dispatch(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + raise NotImplementedError + + def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + def destroy(self): + pass + + class DeviceCommunicatorBase: """ Base class for device-specific communicator. 
@@ -31,6 +96,18 @@ def __init__(self, self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank) + use_ep = False + from vllm.config import get_current_vllm_config + config = get_current_vllm_config() + if config is not None: + # as long as we use data parallel (coupled data parallel + # where all data parallel ranks execute forward together), + # we initialize the all2all manager used in expert parallel. + use_ep = config.parallel_config.data_parallel_size > 1 + + self.use_all2all = "ep" in unique_name and use_ep + self.all2all_manager: Optional[All2AllManagerBase] = None + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: dist.all_reduce(input_, group=self.device_group) return input_ @@ -149,3 +226,35 @@ def recv(self, def destroy(self): pass + + def prepare_communication_buffer_for_model(self, + model: torch.nn.Module) -> None: + """ + Prepare the communication buffer for the model. + """ + if not self.use_all2all: + return + + moe_modules = [ + module for module in model.modules() + if module.__class__.__name__ == "FusedMoE" + ] + for module in moe_modules: + module.quant_method.init_prepare_finalize(module.moe_config, + module.quant_config) + + def dispatch( + self, hidden_states: torch.Tensor, + router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Dispatch the hidden states and router logits to the appropriate device. + This is a no-op in the base class. + """ + return hidden_states, router_logits + + def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Combine the hidden states and router logits from the appropriate device. + This is a no-op in the base class. 
+ """ + return hidden_states diff --git a/vllm/distributed/device_communicators/cpu_communicator.py b/vllm/distributed/device_communicators/cpu_communicator.py index 1f4b4faf1190..c04218cb9f39 100644 --- a/vllm/distributed/device_communicators/cpu_communicator.py +++ b/vllm/distributed/device_communicators/cpu_communicator.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import os -from typing import List, Optional +from typing import Optional import torch from torch.distributed import ProcessGroup @@ -22,7 +22,10 @@ def __init__(self, super().__init__(cpu_group, device, device_group, unique_name) self.dist_module = torch.distributed - if current_platform.get_cpu_architecture() == CpuArchEnum.X86: + if (current_platform.get_cpu_architecture() + == CpuArchEnum.X86) and hasattr( + torch.ops._C, + "init_shm_manager") and unique_name.startswith("tp"): self.dist_module = _CPUSHMDistributed(self) def all_reduce(self, input_): @@ -95,6 +98,8 @@ class _CPUSHMDistributed: def __init__(self, communicator: CpuCommunicator): instance_identifier = os.environ["VLLM_DIST_IDENT"] + unique_name = communicator.unique_name + instance_identifier = f"{instance_identifier}-{unique_name}" self.communicator = communicator group_ranks = [str(rank) for rank in self.communicator.ranks] @@ -125,7 +130,7 @@ def all_reduce(self, def gather(self, input: torch.Tensor, - gather_list: Optional[List[torch.Tensor]], + gather_list: Optional[list[torch.Tensor]], dst: int = -1, group: Optional[ProcessGroup] = None) -> None: # Note: different from the torch gather, here we use local dst rank. 
diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 8bca278f3888..a05a13f51d4b 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -5,8 +5,13 @@ import torch from torch.distributed import ProcessGroup +import vllm.envs as envs +from vllm.logger import init_logger + from .base_device_communicator import DeviceCommunicatorBase +logger = init_logger(__name__) + class CudaCommunicator(DeviceCommunicatorBase): @@ -23,7 +28,9 @@ def __init__(self, from vllm.distributed.parallel_state import ( _ENABLE_CUSTOM_ALL_REDUCE) use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE - use_pynccl = True + + # ep does not use pynccl + use_pynccl = "ep" not in unique_name self.use_pynccl = use_pynccl self.use_custom_allreduce = use_custom_allreduce @@ -49,6 +56,19 @@ def __init__(self, device=self.device, ) + if self.use_all2all: + all2all_backend = envs.VLLM_ALL2ALL_BACKEND + if all2all_backend == "naive": + from .all2all import NaiveAll2AllManager + self.all2all_manager = NaiveAll2AllManager(self.cpu_group) + logger.info("Using naive all2all manager.") + elif all2all_backend == "pplx": + from .all2all import PPLXAll2AllManager + self.all2all_manager = PPLXAll2AllManager(self.cpu_group) + logger.info("Using PPLX all2all manager.") + else: + raise ValueError(f"Unknown all2all backend: {all2all_backend}") + def all_reduce(self, input_): # always try custom allreduce first, # and then pynccl. 
@@ -129,3 +149,19 @@ def destroy(self): self.pynccl_comm = None if self.ca_comm is not None: self.ca_comm = None + if self.all2all_manager is not None: + self.all2all_manager.destroy() + self.all2all_manager = None + + def dispatch( + self, hidden_states: torch.Tensor, + router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + assert self.all2all_manager is not None + hidden_states, router_logits = self.all2all_manager.dispatch( + hidden_states, router_logits) + return hidden_states, router_logits + + def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + assert self.all2all_manager is not None + hidden_states = self.all2all_manager.combine(hidden_states) + return hidden_states diff --git a/vllm/distributed/device_communicators/cuda_wrapper.py b/vllm/distributed/device_communicators/cuda_wrapper.py index 1d53b1c5b809..6c15ef644b8c 100644 --- a/vllm/distributed/device_communicators/cuda_wrapper.py +++ b/vllm/distributed/device_communicators/cuda_wrapper.py @@ -6,7 +6,7 @@ import ctypes from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Any, Optional # this line makes it possible to directly load `libcudart.so` using `ctypes` import torch # noqa @@ -32,7 +32,7 @@ class cudaIpcMemHandle_t(ctypes.Structure): class Function: name: str restype: Any - argtypes: List[Any] + argtypes: list[Any] def find_loaded_library(lib_name) -> Optional[str]: @@ -97,11 +97,11 @@ class CudaRTLibrary: # class attribute to store the mapping from the path to the library # to avoid loading the same library multiple times - path_to_library_cache: Dict[str, Any] = {} + path_to_library_cache: dict[str, Any] = {} # class attribute to store the mapping from library path # to the corresponding dictionary - path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} + path_to_dict_mapping: dict[str, dict[str, Any]] = {} def __init__(self, so_file: Optional[str] = None): if so_file is None: diff --git 
a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 45fc2a7118b7..5c2dbcc27b13 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from contextlib import contextmanager -from typing import List, Optional, Union +from typing import Optional, Union import torch import torch.distributed as dist @@ -265,7 +265,8 @@ def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: def close(self): if not self.disabled and self._ptr: - ops.dispose(self._ptr) + if ops is not None: + ops.dispose(self._ptr) self._ptr = 0 self.free_shared_buffer(self.meta_ptrs, rank=self.rank) self.free_shared_buffer(self.buffer_ptrs, rank=self.rank) @@ -276,7 +277,7 @@ def __del__(self): @staticmethod def create_shared_buffer(size_in_bytes: int, group: Optional[ProcessGroup] = None, - uncached: Optional[bool] = False) -> List[int]: + uncached: Optional[bool] = False) -> list[int]: pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes) world_size = dist.get_world_size(group=group) @@ -284,7 +285,7 @@ def create_shared_buffer(size_in_bytes: int, handles = [None] * world_size dist.all_gather_object(handles, handle, group=group) - pointers: List[int] = [] + pointers: list[int] = [] for i, h in enumerate(handles): if i == rank: pointers.append(pointer) # type: ignore @@ -293,9 +294,10 @@ def create_shared_buffer(size_in_bytes: int, return pointers @staticmethod - def free_shared_buffer(pointers: List[int], + def free_shared_buffer(pointers: list[int], group: Optional[ProcessGroup] = None, rank: Optional[int] = 0) -> None: if rank is None: rank = dist.get_rank(group=group) - ops.free_shared_buffer(pointers[rank]) + if ops is not None: + ops.free_shared_buffer(pointers[rank]) diff --git a/vllm/distributed/device_communicators/custom_all_reduce_utils.py 
b/vllm/distributed/device_communicators/custom_all_reduce_utils.py index d8d6eed2dd7e..11b8b57fe2ae 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce_utils.py +++ b/vllm/distributed/device_communicators/custom_all_reduce_utils.py @@ -7,8 +7,9 @@ import subprocess import sys import tempfile +from collections.abc import Sequence from itertools import product -from typing import Dict, List, Optional, Sequence +from typing import Optional import torch.distributed as dist import torch.multiprocessing as mp @@ -149,7 +150,7 @@ def can_actually_p2p( p_src.join() p_tgt.join() assert p_src.exitcode == 0 and p_tgt.exitcode == 0 - result: List[bool] = [] + result: list[bool] = [] for src, tgt in zip(batch_src, batch_tgt): a = result_queue.get() b = result_queue.get() @@ -175,7 +176,7 @@ def can_actually_p2p( # e.g. used by different vllm engines. The device id in the cache file is a # **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number # of visible devices in the vllm engine. 
-_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None +_gpu_p2p_access_cache: Optional[dict[str, bool]] = None def gpu_p2p_access_check(src: int, tgt: int) -> bool: @@ -204,7 +205,7 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool: # only the local master process (with local_rank == 0) can # enter this block to calculate the cache logger.info("generating GPU P2P access cache in %s", path) - cache: Dict[str, bool] = {} + cache: dict[str, bool] = {} ids = list(range(num_dev)) # batch of all pairs of GPUs batch_src, batch_tgt = zip(*list(product(ids, ids))) diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index 4f04899e92e6..6f69089b6196 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -24,7 +24,7 @@ import ctypes import platform from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Any, Optional import torch from torch.distributed import ReduceOp @@ -121,7 +121,7 @@ def from_torch(cls, op: ReduceOp) -> int: class Function: name: str restype: Any - argtypes: List[Any] + argtypes: list[Any] class NCCLLibrary: @@ -210,11 +210,11 @@ class NCCLLibrary: # class attribute to store the mapping from the path to the library # to avoid loading the same library multiple times - path_to_library_cache: Dict[str, Any] = {} + path_to_library_cache: dict[str, Any] = {} # class attribute to store the mapping from library path # to the corresponding dictionary - path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} + path_to_dict_mapping: dict[str, dict[str, Any]] = {} def __init__(self, so_file: Optional[str] = None): @@ -238,7 +238,7 @@ def __init__(self, so_file: Optional[str] = None): raise e if so_file not in NCCLLibrary.path_to_dict_mapping: - _funcs: Dict[str, Any] = {} + _funcs: dict[str, Any] = {} for func in NCCLLibrary.exported_functions: f = getattr(self.lib, 
func.name) f.restype = func.restype diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index e33cfee21970..40e57e6624d1 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -1,14 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 -import os import pickle -import sys import time from contextlib import contextmanager from dataclasses import dataclass, field from multiprocessing import shared_memory from threading import Event -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Optional, Union from unittest.mock import patch import torch @@ -19,7 +17,7 @@ from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore import vllm.envs as envs -from vllm.distributed.utils import StatelessProcessGroup +from vllm.distributed.utils import StatelessProcessGroup, sched_yield from vllm.logger import init_logger from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path, is_valid_ipv6_address) @@ -28,20 +26,6 @@ logger = init_logger(__name__) -# We prefer to use os.sched_yield as it results in tighter polling loops, -# measured to be around 3e-7 seconds. 
However on earlier versions of Python -# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0) -USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1)) - or (sys.version_info[:2] == (3, 10) - and sys.version_info[2] >= 8)) - - -def sched_yield(): - if USE_SCHED_YIELD: - os.sched_yield() - else: - time.sleep(0) - class ShmRingBuffer: @@ -173,9 +157,9 @@ def get_metadata(self, current_idx: int): @dataclass class Handle: - local_reader_ranks: List[int] = field(default_factory=list) + local_reader_ranks: list[int] = field(default_factory=list) - buffer_handle: Optional[Tuple[int, int, int, str]] = None + buffer_handle: Optional[tuple[int, int, int, str]] = None local_subscribe_addr: Optional[str] = None remote_subscribe_addr: Optional[str] = None remote_addr_ipv6: bool = False @@ -187,7 +171,7 @@ def __init__( self, n_reader, # number of all readers n_local_reader, # number of local readers through shared memory - local_reader_ranks: Optional[List[int]] = None, + local_reader_ranks: Optional[list[int]] = None, max_chunk_bytes: int = 1024 * 1024 * 10, max_chunks: int = 10, connect_ip: Optional[str] = None, diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index de66ceaeef6f..a1775279661d 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -91,3 +91,12 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: assert dim == -1, "TPUs only support dim=-1 for all-gather." 
return xm.all_gather(input_, dim=dim) + + +try: + from tpu_commons.distributed.device_communicators import ( + TpuCommunicator as TpuCommonsCommunicator) + TpuCommunicator = TpuCommonsCommunicator # type: ignore +except ImportError: + logger.info("tpu_commons not found, using vLLM's TpuCommunicator") + pass diff --git a/vllm/distributed/kv_events.py b/vllm/distributed/kv_events.py index 960913858527..29c6a70c4d26 100644 --- a/vllm/distributed/kv_events.py +++ b/vllm/distributed/kv_events.py @@ -5,6 +5,7 @@ import time from abc import ABC, abstractmethod from collections import deque +from dataclasses import asdict from itertools import count from queue import Queue from typing import Any, Callable, Optional, Union @@ -129,6 +130,7 @@ def __init__( self._endpoint = endpoint self._replay_endpoint = replay_endpoint self._hwm = hwm + self._socket_setup() # Payload self._seq_gen = count() @@ -206,7 +208,6 @@ def _socket_setup(self) -> None: def _publisher_thread(self) -> None: """Background thread that processes the event queue.""" self._pack = msgspec.msgpack.Encoder() - self._socket_setup() assert self._pub is not None # narrows type for mypy @@ -284,7 +285,7 @@ def create(cls, config: Optional[KVEventsConfig]) -> EventPublisher: if not config: return NullEventPublisher() - config_dict = config.model_dump() + config_dict = asdict(config) kind = config_dict.pop("publisher", "null") config_dict.pop("enable_kv_cache_events") diff --git a/vllm/distributed/kv_transfer/__init__.py b/vllm/distributed/kv_transfer/__init__.py index a9f26607de49..8b6abf5a80dd 100644 --- a/vllm/distributed/kv_transfer/__init__.py +++ b/vllm/distributed/kv_transfer/__init__.py @@ -1,8 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 -from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType from vllm.distributed.kv_transfer.kv_transfer_state import ( - ensure_kv_transfer_initialized, get_kv_transfer_group, + KVConnectorBaseType, ensure_kv_transfer_initialized, 
get_kv_transfer_group, has_kv_transfer_group, is_v1_kv_transfer_group) __all__ = [ diff --git a/vllm/distributed/kv_transfer/kv_connector/base.py b/vllm/distributed/kv_transfer/kv_connector/base.py index 0d1a3d40af41..e9b70610e8cd 100644 --- a/vllm/distributed/kv_transfer/kv_connector/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/base.py @@ -8,7 +8,7 @@ """ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, List, Tuple, Union +from typing import TYPE_CHECKING, Union import torch @@ -55,7 +55,7 @@ def send_kv_caches_and_hidden_states( self, model_executable: torch.nn.Module, model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: List[torch.Tensor], + kv_caches: list[torch.Tensor], hidden_or_intermediate_states: Union[torch.Tensor, IntermediateTensors], ) -> None: @@ -71,7 +71,7 @@ def send_kv_caches_and_hidden_states( start and end layer information. model_input (ModelInputForGPUWithSamplingMetadata): The input metadata from vLLM. - kv_caches (List[torch.Tensor]): List of KV caches (keys and values) + kv_caches (list[torch.Tensor]): List of KV caches (keys and values) for each layer. hidden_or_intermediate_states (Union[torch.Tensor, IntermediateTensors]): @@ -88,8 +88,8 @@ def send_kv_caches_and_hidden_states( def recv_kv_caches_and_hidden_states( self, model_executable: torch.nn.Module, model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: List[torch.Tensor] - ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + kv_caches: list[torch.Tensor] + ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, "ModelInputForGPUWithSamplingMetadata"]: """ Receive KV caches and hidden states from the connector. @@ -104,7 +104,7 @@ def recv_kv_caches_and_hidden_states( The model executable from vLLM modelrunner. model_input (ModelInputForGPUWithSamplingMetadata): The model input from vLLM modelrunner. - kv_caches (List[torch.Tensor]): + kv_caches (list[torch.Tensor]): List of KV caches for each layer. 
Returns: diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 6532c101a4f6..06b3983ed68b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import importlib -from typing import TYPE_CHECKING, Callable, Dict, Type +from typing import TYPE_CHECKING, Callable import vllm.envs as envs from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType @@ -18,7 +18,7 @@ class KVConnectorFactory: - _registry: Dict[str, Callable[[], Type[KVConnectorBaseType]]] = {} + _registry: dict[str, Callable[[], type[KVConnectorBaseType]]] = {} @classmethod def register_connector(cls, name: str, module_path: str, @@ -27,7 +27,7 @@ def register_connector(cls, name: str, module_path: str, if name in cls._registry: raise ValueError(f"Connector '{name}' is already registered.") - def loader() -> Type[KVConnectorBaseType]: + def loader() -> type[KVConnectorBaseType]: module = importlib.import_module(module_path) return getattr(module, class_name) @@ -58,8 +58,17 @@ def create_connector_v1( raise ValueError("Attempting to initialize a V1 Connector, " f"but found {envs.VLLM_USE_V1=}") - connector_name = config.kv_transfer_config.kv_connector - connector_cls = cls._registry[connector_name]() + kv_transfer_config = config.kv_transfer_config + connector_name = kv_transfer_config.kv_connector + if connector_name in cls._registry: + connector_cls = cls._registry[connector_name]() + else: + connector_module_path = kv_transfer_config.kv_connector_module_path + if connector_module_path is None: + raise ValueError( + f"Unsupported connector type: {connector_name}") + connector_module = importlib.import_module(connector_module_path) + connector_cls = getattr(connector_module, connector_name) assert issubclass(connector_cls, KVConnectorBase_V1) logger.info("Creating v1 connector with name: 
%s", connector_name) # NOTE(Kuntai): v1 connector is explicitly separated into two roles. @@ -105,3 +114,13 @@ def create_connector_v1( "LMCacheConnectorV1", "vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector", "LMCacheConnectorV1") + +KVConnectorFactory.register_connector( + "NixlConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector", + "NixlConnector") + +KVConnectorFactory.register_connector( + "MultiConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.multi_connector", + "MultiConnector") diff --git a/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py index 42de227b6c30..d121cb701bef 100644 --- a/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py @@ -7,7 +7,7 @@ (2) offload and share KV caches. """ -from typing import TYPE_CHECKING, List, Tuple, Union +from typing import TYPE_CHECKING, Union import torch @@ -63,8 +63,8 @@ def __init__( def recv_kv_caches_and_hidden_states( self, model_executable: torch.nn.Module, model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: List[torch.Tensor] - ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + kv_caches: list[torch.Tensor] + ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, "ModelInputForGPUWithSamplingMetadata"]: retrieve_status = self.lmcache_should_retrieve(model_input) @@ -78,7 +78,7 @@ def send_kv_caches_and_hidden_states( self, model_executable: torch.nn.Module, model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: List[torch.Tensor], + kv_caches: list[torch.Tensor], hidden_or_intermediate_states: Union[torch.Tensor, IntermediateTensors], ) -> None: diff --git a/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py b/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py index 7b26aec23239..58eabd0a37eb 100644 --- 
a/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py @@ -6,7 +6,7 @@ database-style KVStore. """ import hashlib -from typing import TYPE_CHECKING, List, Tuple, Union +from typing import TYPE_CHECKING, Union import torch @@ -31,12 +31,12 @@ def __init__( local_rank: int, config: VllmConfig, ): - self.config = config.kv_transfer_config + self.kv_transfer_config = config.kv_transfer_config self.kv_helper = kv_helper(config) self.local_tp_rank = local_rank # Init kv_store - if self.config.kv_connector == "MooncakeStoreConnector": + if self.kv_transfer_config.kv_connector == "MooncakeStoreConnector": # Check if MOONCAKE_CONFIG_PATH is set import os use_mooncake_store = os.getenv('MOONCAKE_CONFIG_PATH') is not None @@ -50,10 +50,11 @@ def __init__( MooncakeStore) logger.info( "Initializing KVStoreConnector under kv_transfer_config %s", - self.config) + self.kv_transfer_config) self.kv_store = MooncakeStore(config) else: - logger.error("Can not find %s", self.config.kv_connector) + logger.error("Can not find %s", + self.kv_transfer_config.kv_connector) assert self.kv_store is not None @@ -70,7 +71,7 @@ def send_kv_caches_and_hidden_states( self, model_executable: torch.nn.Module, model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: List[torch.Tensor], + kv_caches: list[torch.Tensor], hidden_or_intermediate_states: Union[torch.Tensor, IntermediateTensors], ) -> None: @@ -113,8 +114,8 @@ def send_kv_caches_and_hidden_states( def recv_kv_caches_and_hidden_states( self, model_executable: torch.nn.Module, model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: List[torch.Tensor] - ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + kv_caches: list[torch.Tensor] + ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, "ModelInputForGPUWithSamplingMetadata"]: bypass_model_exec = True input_tokens_tensor = model_input.input_tokens diff --git 
a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py index 0464a7585138..ed8fe38161e9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py @@ -8,7 +8,7 @@ But the logic can be extended to support other pipe and lookup buffer. """ -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Union import torch @@ -106,7 +106,7 @@ def __init__( else: # the current vLLM instance is KV consumer, so it needs to connect - # its recv pipe to the send pipe of KV producder + # its recv pipe to the send pipe of KV producer if self.config.kv_connector == "PyNcclConnector": self.consumer_data_pipe = PyNcclPipe( local_rank=local_rank, @@ -133,7 +133,7 @@ def __init__( ) def select(self, input_tokens: Optional[torch.Tensor], - roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]: assert self.consumer_buffer is not None, "Please initialize the "\ "consumer buffer before calling select." 
@@ -152,7 +152,7 @@ def send_kv_caches_and_hidden_states( self, model_executable: torch.nn.Module, model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: List[torch.Tensor], + kv_caches: list[torch.Tensor], hidden_or_intermediate_states: Union[torch.Tensor, IntermediateTensors], ) -> None: @@ -207,8 +207,8 @@ def send_kv_caches_and_hidden_states( def recv_kv_caches_and_hidden_states( self, model_executable: torch.nn.Module, model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: List[torch.Tensor] - ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + kv_caches: list[torch.Tensor] + ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, "ModelInputForGPUWithSamplingMetadata"]: # When bypass_model_exec is set to False, it means that at least for one diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 0b0ce9828a74..b1c9c9af6e23 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -44,8 +44,9 @@ def get_model_args(self, model_executable: torch.nn.Module): head_size = model_config.qk_nope_head_dim + \ model_config.qk_rope_head_dim else: - head_size = getattr(model_config, "head_dim", - int(hidden_size // num_attention_heads)) + head_size = getattr(model_config, "head_dim", None) + if head_size is None: + head_size = int(hidden_size // num_attention_heads) return num_heads, head_size diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py b/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py index a017b140e090..e66aaa7f8af8 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py @@ -2,7 +2,4 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorBase_V1, KVConnectorRole) -__all__ = [ - "KVConnectorRole", - "KVConnectorBase_V1", -] +__all__ = ["KVConnectorRole", 
"KVConnectorBase_V1"] diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 95967d2ca919..bc9258e9d07b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -22,8 +22,7 @@ import enum from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Optional import torch @@ -34,6 +33,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import VllmConfig from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request logger = init_logger(__name__) @@ -47,8 +47,11 @@ class KVConnectorRole(enum.Enum): WORKER = 1 -@dataclass class KVConnectorMetadata: + """ + Abstract Metadata used to communicate between the + Scheduler KVConnector and Worker KVConnector. + """ pass @@ -66,6 +69,10 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): def role(self) -> KVConnectorRole: return self._role + # ============================== + # Worker-side methods + # ============================== + def bind_connector_metadata( self, connector_metadata: KVConnectorMetadata) -> None: """Set the connector metadata from the scheduler. @@ -97,9 +104,15 @@ def _get_connector_metadata(self) -> KVConnectorMetadata: """ return self._connector_metadata - # ============================== - # Worker-side methods - # ============================== + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """ + Initialize with the KV caches. Useful for pre-registering the + KV Caches in the KVConnector (e.g. for NIXL). 
+ + Args: kv_caches: + dictionary of layer names, kv cache + """ + return @abstractmethod def start_load_kv(self, forward_context: "ForwardContext", @@ -162,15 +175,31 @@ def wait_for_save(self): """ pass + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + """ + Notifies worker-side connector ids of requests that have + finished generating tokens. + + Returns: + ids of requests that have finished asynchronous transfer, + tuple of (sending/saving ids, recving/loading ids). + The finished saves/sends req ids must belong to a set provided in a + call to this method (this call or a prior one). + """ + return None, None + # ============================== # Scheduler-side methods # ============================== + @abstractmethod def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int, - ) -> int: + ) -> tuple[int, bool]: """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. @@ -181,13 +210,17 @@ def get_num_new_matched_tokens( computed tokens for this request Returns: - the number of tokens that can be loaded from the - external KV cache beyond what is already computed. + A tuple with the following elements: + - The number of tokens that can be loaded from the + external KV cache beyond what is already computed. + - `True` if external KV cache tokens will be loaded + asynchronously (between scheduler steps). """ pass @abstractmethod def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", num_external_tokens: int): """ Update KVConnector state after block allocation. @@ -207,3 +240,20 @@ def build_connector_meta( scheduler_output (SchedulerOutput): the scheduler output object. """ pass + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Called when a request has finished, before its blocks are freed. 
+ + Returns: + True if the request is being saved/sent asynchronously and blocks + should not be freed until the request_id is returned from + get_finished(). + Optional KVTransferParams to be included in the request outputs + returned by the engine. + """ + return False, None diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index e07f185f0dd8..2cb68dc1ff67 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -13,6 +13,7 @@ if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request logger = init_logger(__name__) @@ -92,7 +93,7 @@ def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int, - ) -> int: + ) -> tuple[int, bool]: """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. @@ -107,9 +108,10 @@ def get_num_new_matched_tokens( external KV cache beyond what is already computed. """ return self._lmcache_engine.get_num_new_matched_tokens( - request, num_computed_tokens) + request, num_computed_tokens), False def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", num_external_tokens: int): """ Update KVConnector state after block allocation. 
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py new file mode 100644 index 000000000000..0aabb260fd3d --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -0,0 +1,189 @@ +# SPDX-License-Identifier: Apache-2.0 +import copy +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +import torch + +from vllm.config import KVTransferConfig, VllmConfig +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +@dataclass +class MultiKVConnectorMetadata(KVConnectorMetadata): + metadata: tuple[KVConnectorMetadata, ...] + extra_async_saves: Optional[dict[str, int]] = None + + +class MultiConnector(KVConnectorBase_V1): + """ + A wrapper for using multiple KVConnectors at the same time. + + The current logic is: + - Load KV from the first connector that advertises available tokens from + get_num_new_matched_tokens(), based on the order in the config. + - Save to all connectors. 
+ """ + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + self._connectors: list[KVConnectorBase_V1] = [] + ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get( + "connectors") + assert ktcs is not None + for ktc in ktcs: + temp_config = copy.copy(vllm_config) + temp_config.kv_transfer_config = KVTransferConfig(**ktc) + self._connectors.append( + KVConnectorFactory.create_connector_v1(temp_config, role)) + + # A mapping from request id to the connector that is assigned to it. + self._requests_to_connector: dict[str, KVConnectorBase_V1] = {} + + # Keeps track of *additional* remaining async saves (beyond 1) to be + # finished per request. Not needed for async loads since we only allow + # a single connector to load. + # Propagated from scheduler to worker side via the connector metadata. + self._extra_async_saves: dict[str, int] = {} + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + for c in self._connectors: + c.register_kv_caches(kv_caches) + + # We must override the base class method here because we need to bind + # the metadata to each connector in the order of the connectors in the + # MultiKVConnectorMetadata. 
+ def bind_connector_metadata( + self, connector_metadata: KVConnectorMetadata) -> None: + assert isinstance(connector_metadata, MultiKVConnectorMetadata) + if connector_metadata.extra_async_saves: + self._extra_async_saves.update( + connector_metadata.extra_async_saves) + for c, cm in zip(self._connectors, connector_metadata.metadata): + c.bind_connector_metadata(cm) + + def clear_connector_metadata(self) -> None: + for c in self._connectors: + c.clear_connector_metadata() + + # ============================== + # Worker-side methods + # ============================== + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + for c in self._connectors: + c.start_load_kv(forward_context, **kwargs) + + def wait_for_layer_load(self, layer_name: str) -> None: + for c in self._connectors: + c.wait_for_layer_load(layer_name) + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + for c in self._connectors: + c.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs) + + def wait_for_save(self): + for c in self._connectors: + c.wait_for_save() + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + finished_sending: set[str] = set() + finished_recving: set[str] = set() + for c in self._connectors: + sending, recving = c.get_finished(finished_req_ids) + if not recving and not sending: + continue + # Aggregate finished recving request ids. + finished_recving.update(recving or ()) + # Aggregate finished sending request ids - only include + # once we've drained the "extra" count (for cases where + # more than one connector is async-saving the same request). 
+ for req_id in sending or (): + extra_pending = self._extra_async_saves.get(req_id) + if extra_pending is None: + finished_sending.add(req_id) + continue + assert extra_pending > 0 + if extra_pending == 1: + del self._extra_async_saves[req_id] + else: + self._extra_async_saves[req_id] = extra_pending - 1 + + return finished_sending or None, finished_recving or None + + # ============================== + # Scheduler-side methods + # ============================== + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + for c in self._connectors: + toks, load_async = c.get_num_new_matched_tokens( + request, num_computed_tokens) + # The first connector that has new matched tokens will be assigned + # to this request. + if toks > 0: + self._requests_to_connector[request.request_id] = c + return toks, load_async + return 0, False + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + # If the request is not assigned to any connector, we do nothing. + if request.request_id not in self._requests_to_connector: + return + # We assume that the request is assigned to only one connector. 
+ c = self._requests_to_connector.pop(request.request_id) + c.update_state_after_alloc(request, blocks, num_external_tokens) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata: + metadata = MultiKVConnectorMetadata(metadata=tuple( + c.build_connector_meta(scheduler_output) + for c in self._connectors)) + if self._extra_async_saves: + metadata.extra_async_saves = self._extra_async_saves + self._extra_async_saves = {} + return metadata + + def request_finished( + self, + request: "Request", + blocks: "KVCacheBlocks", + ) -> tuple[bool, Optional[dict[str, Any]]]: + async_saves = 0 + kv_txfer_params = None + for c in self._connectors: + async_save, txfer_params = c.request_finished(request, blocks) + if async_save: + async_saves += 1 + if txfer_params is not None: + if kv_txfer_params is not None: + #TODO we can probably change this to merge the dicts here, + # checking for key clashes. + raise RuntimeError( + "Only one connector can produce KV transfer params") + kv_txfer_params = txfer_params + if async_saves > 1: + self._extra_async_saves[request.request_id] = async_saves - 1 + return async_saves > 0, kv_txfer_params diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py new file mode 100644 index 000000000000..6303d77ad305 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -0,0 +1,851 @@ +# SPDX-License-Identifier: Apache-2.0 +import contextlib +import math +import threading +import time +import uuid +from collections import defaultdict +from collections.abc import Iterator +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +import msgspec +import torch +import zmq + +from vllm import envs +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from 
vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, + get_tp_group) +from vllm.logger import init_logger +from vllm.utils import make_zmq_path, make_zmq_socket, round_down +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.request import RequestStatus + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.request import Request + +GET_META_MSG = b"get_meta_msg" + +logger = init_logger(__name__) + +# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used +try: + from nixl._api import nixl_agent as NixlWrapper + logger.info("NIXL is available") +except ImportError: + logger.warning("NIXL is not available") + NixlWrapper = None + + +class NixlAgentMetadata( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. 
+ dict=True): + engine_id: str + agent_metadata: bytes + kv_caches_base_addr: list[int] + num_blocks: int + + +@dataclass +class ReqMeta: + local_block_ids: list[int] + remote_block_ids: list[int] + remote_host: str + remote_port: int + remote_engine_id: str + + +class NixlConnectorMetadata(KVConnectorMetadata): + + def __init__(self): + self.requests: dict[str, ReqMeta] = {} + + def add_new_req( + self, + request_id: str, + local_block_ids: list[int], + kv_transfer_params: dict[str, Any], + ): + self.requests[request_id] = ReqMeta( + local_block_ids=local_block_ids, + remote_block_ids=kv_transfer_params["remote_block_ids"], + remote_engine_id=kv_transfer_params["remote_engine_id"], + remote_host=kv_transfer_params["remote_host"], + remote_port=kv_transfer_params["remote_port"], + ) + + +class NixlConnector(KVConnectorBase_V1): + + def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + assert vllm_config.kv_transfer_config is not None + self.engine_id = vllm_config.kv_transfer_config.engine_id + + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler : Optional[NixlConnectorScheduler] = \ + NixlConnectorScheduler(vllm_config, str(self.engine_id)) + self.connector_worker: Optional[NixlConnectorWorker] = None + elif role == KVConnectorRole.WORKER: + self.connector_scheduler = None + self.connector_worker = NixlConnectorWorker( + vllm_config, str(self.engine_id)) + + ############################################################ + # Scheduler Side Methods + ############################################################ + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + assert self.connector_scheduler is not None + return 
self.connector_scheduler.update_state_after_alloc( + request, blocks, num_external_tokens) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, block_ids) + + ############################################################ + # Worker Side Methods + ############################################################ + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def get_finished(self, + finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + """Get the finished recving and sending requests.""" + assert self.connector_worker is not None + return self.connector_worker.get_finished() + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, NixlConnectorMetadata) + self.connector_worker.start_load_kv(self._connector_metadata) + + def wait_for_layer_load(self, layer_name: str) -> None: + """NixlConnector does not do layerwise saving.""" + pass + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """NixlConnector does not save explicitly.""" + pass + + def wait_for_save(self): + """NixlConnector does not save explicitly.""" + pass + + +class NixlConnectorScheduler: + """Implementation of Scheduler side methods""" + + def __init__(self, vllm_config: VllmConfig, engine_id: str): + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + 
self.engine_id = engine_id + logger.info("Initializing NIXL Scheduler %s", engine_id) + + # Requests that need to start recv. + # New requests are added by update_state_after_alloc in + # the scheduler. Used to make metadata passed to Worker. + self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {} + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + """ + For remote prefill, pull all prompt blocks from remote + asynchronously relative to engine execution. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + Returns: + * the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + * true if the external KV cache tokens will be loaded + asynchronously (between scheduler steps). + """ + + params = request.kv_transfer_params + logger.debug( + "NIXLConnector get_num_new_matched_tokens: " + "num_computed_tokens=%s, kv_transfer_params=%s", + num_computed_tokens, params) + + if params is not None and params.get("do_remote_prefill"): + # Remote prefill: get all prompt blocks from remote. + assert num_computed_tokens % self.block_size == 0 + rounded_num_prompt_tokens = round_down( + len(request.prompt_token_ids), self.block_size) + count = max(rounded_num_prompt_tokens - num_computed_tokens, 0) + if count > 0: + return count, True + + # NOTE: if count is 0 here, we have less than block_size + # tokens to pull after subtracting the local prefix cache hit. + # The remote only sends fully computed blocks, so there is + # nothing to transfer but we still need to notify the + # prefill worker so that the remote blocks are freed. + if all(p in params for p in ("remote_engine_id", "remote_host", + "remote_port")): + self._reqs_need_recv[request.request_id] = (request, []) + + # No remote prefill for this request. 
+ return 0, False + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + + params = request.kv_transfer_params + logger.debug( + "NIXLConnector update_state_after_alloc: " + "num_external_tokens=%s, kv_transfer_params=%s", + num_external_tokens, params) + + if params is not None and params.get("do_remote_prefill"): + if params.get("remote_block_ids"): + if all(p in params for p in ("remote_engine_id", "remote_host", + "remote_port")): + # Get unhashed blocks to pull from remote. + self._reqs_need_recv[request.request_id] = ( + request, blocks.get_unhashed_block_ids()) + else: + logger.warning( + "Got invalid KVTransferParams: %s. This " + "request will not utilize KVTransfer", params) + else: + assert num_external_tokens == 0 + # Only trigger 1 KV transfer per request. + params["do_remote_prefill"] = False + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + meta = NixlConnectorMetadata() + + # Loop through scheduled reqs and convert to ReqMeta. + for req_id, (req, block_ids) in self._reqs_need_recv.items(): + assert req.kv_transfer_params is not None + # For the case where there are no remote blocks to pull + # (block_ids is empty), we don't need to schedule + # an async read on the worker side. + if not block_ids: + logger.debug( + "Skipping adding request %s to NixlConnectorMetadata, " + "as there are no remote blocks to pull", req_id) + continue + + meta.add_new_req( + request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + ) + + # Clear the list once workers start the transfers + self._reqs_need_recv.clear() + + return meta + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Once a request is finished, determine whether request blocks + should be freed now or will be sent asynchronously and freed later. 
+ """ + + params = request.kv_transfer_params + logger.debug( + "NIXLConnector request_finished, request_status=%s, " + "kv_transfer_params=%s", request.status, params) + + if (params is None or not params.get("do_remote_decode") + or request.status != RequestStatus.FINISHED_LENGTH_CAPPED): + return False, None + + # Get computed blocks. + all_full = request.num_computed_tokens % self.block_size == 0 + computed_block_ids = block_ids if all_full else block_ids[:-1] + + # If prompt < block_size, no xfer so free blocks immediately. + delay_free_blocks = len(computed_block_ids) > 0 + + return delay_free_blocks, dict( + do_remote_prefill=True, + do_remote_decode=False, + remote_block_ids=computed_block_ids, + remote_engine_id=self.engine_id, + remote_host=envs.VLLM_NIXL_SIDE_CHANNEL_HOST, + remote_port=envs.VLLM_NIXL_SIDE_CHANNEL_PORT, + ) + + +class NixlConnectorWorker: + """Implementation of Worker side methods""" + + def __init__(self, vllm_config: VllmConfig, engine_id: str): + if NixlWrapper is None: + logger.error("NIXL is not available") + raise RuntimeError("NIXL is not available") + logger.info("Initializing NIXL wrapper") + logger.info("Initializing NIXL worker %s", engine_id) + + # Agent. + self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) + # Map of engine_id -> agent_name. + self._remote_agents: dict[str, str] = {} + + # Metadata. + self.engine_id = engine_id + self.rank = get_tensor_model_parallel_rank() + self.world_size = get_tensor_model_parallel_world_size() + self.tp_group = get_tp_group() + + # KV Caches and nixl tracking data. + self.kv_caches: dict[str, torch.Tensor] = {} + + # Map of engine_id -> kv_caches_base_addr + self.kv_caches_base_addr: dict[str, list[int]] = {} + + # Number of NIXL regions. Currently one region per cache + # (so 1 per layer for MLA, otherwise 2 per layer) + self.num_regions = 0 + self.num_layers = 0 + + # nixl_prepped_dlist_handle (int). 
+ self.src_xfer_side_handle: int = 0 + # Map of engine_id -> nixl_prepped_dlist_handle (int)]. + self.dst_xfer_side_handles: dict[str, int] = {} + + # Map of engine_id -> num_blocks. + self.dst_num_blocks: dict[str, int] = {} + self._registered_descs: list[Any] = [] + + # In progress transfers. + # [req_id -> list[handle]] + self._recving_transfers: defaultdict[str, list[Any]] = defaultdict( + list[Any]) + + # Complete transfer tracker. Used by the rank 0 to track finished + # transactions on ranks 1 to N-1. + # [req_id -> count] + self._done_recving_count: defaultdict[str, + int] = defaultdict(lambda: 0) + self._done_sending_count: defaultdict[str, + int] = defaultdict(lambda: 0) + + # Background thread for establishing new connections. + self._nixl_handshake_listener_t: Optional[threading.Thread] = None + + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + + # TODO(mgoin): remove this once we have hybrid memory allocator + # Optimization for models with local attention (Llama 4) + # List of block window sizes for each layer for local attention + self.block_window_per_layer: list[Optional[int]] = [] + + @staticmethod + def _nixl_handshake_listener(metadata: NixlAgentMetadata, + ready_event: threading.Event, rank: int): + """Background thread for getting new NIXL handshakes.""" + # NOTE(rob): this is a simple implementation. We will move + # to a better approach like an ETCD server in the future. + + # NOTE(rob): to support heterogeneous TP, we will have to + # move this into the scheduler rather than worker, since + # each rank needs the metadata of all other ranks (whereas + # in this setup, each rank only gets one other rank's meta. + + encoder = msgspec.msgpack.Encoder() + encoded_data = encoder.encode(metadata) + size_in_bytes = len(encoded_data) + logger.debug("Size of encoded NixlAgentMetadata: %s bytes", + str(size_in_bytes)) + + # Listen for new requests for metadata. 
+ host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST + # NOTE(rob): we need each rank to have a unique port. This is + # a hack to keep us moving. We will switch when moving to etcd + # or where we have a single ZMQ socket in the scheduler. + port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + rank + path = make_zmq_path("tcp", host, port) + logger.debug("Starting listening on path: %s", path) + with zmq_ctx(zmq.ROUTER, path) as sock: + ready_event.set() + while True: + identity, _, msg = sock.recv_multipart() + if msg != GET_META_MSG: + logger.warning( + "Connection listener got unexpected message %s", msg) + sock.send_multipart((identity, b"", encoded_data)) + + def _nixl_handshake(self, host: str, port: int): + """Do a NIXL handshake with a remote instance.""" + + start_time = time.perf_counter() + # NOTE(rob): we need each rank to have a unique port. This is + # a hack to keep us moving. We will switch when moving to etcd + # or where we have a single ZMQ socket in the scheduler. + path = make_zmq_path("tcp", host, port + self.rank) + logger.debug("Querying metadata on path: %s", path) + with zmq_ctx(zmq.REQ, path) as sock: + # Send query for the request. + sock.send(GET_META_MSG) + metadata_bytes = sock.recv() + decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) + metadata = decoder.decode(metadata_bytes) + got_metadata_time = time.perf_counter() + + # Register Remote agent. + self.add_remote_agent(metadata) + setup_agent_time = time.perf_counter() + + logger.debug("NIXL handshake: get metadata took: %s", + got_metadata_time - start_time) + logger.debug("NIXL handshake: add agent took: %s", + setup_agent_time - got_metadata_time) + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """Register the KV Cache data in nixl.""" + + _, first_kv_cache = next(iter(kv_caches.items())) + kv_elem_size = first_kv_cache.element_size() + + # TODO(tms): Find a more robust way to detect and handle MLA + use_mla = len(first_kv_cache.shape) == 3 + if use_mla: + # MLA case.
+ self.num_blocks = first_kv_cache.shape[0] + block_rank = 2 # [block_size, latent_dim] + block_shape = first_kv_cache.shape[-block_rank:] + else: + # [2 (k and v), num_blocks, ...] + self.num_blocks = first_kv_cache.shape[1] + block_rank = 3 # [block_size, kv_heads, head_dim] + block_shape = first_kv_cache.shape[-block_rank:] + + # TODO(tms): self.block_len needs to be per-layer for sliding window, + # hybrid attn, etc + self.block_len = kv_elem_size * math.prod(block_shape) + + logger.debug("Registering KV_Caches. use_mla: %s, shape %s", use_mla, + first_kv_cache.shape) + logger.debug("num_blocks: %s, block_shape: %s", self.num_blocks, + block_shape) + logger.debug("Per layer kv cache size: %s", first_kv_cache.shape) + self.dst_num_blocks[self.engine_id] = self.num_blocks + self.kv_caches = kv_caches + kv_caches_base_addr = [] + caches_data = [] + + # Note(tms): I modified this from the original region setup code. + # K and V are now in different regions. Advantage is that we can + # elegantly support MLA and any cases where the K and V tensors + # are non-contiguous (it's not locally guaranteed that they will be) + # Disadvantage is that the encoded NixlAgentMetadata is now larger + # (roughly 8KB vs 5KB). 
+ for cache_or_caches in kv_caches.values(): + # Normalize to always be a list of caches + cache_list = [cache_or_caches] if use_mla else cache_or_caches + for cache in cache_list: + base_addr = cache.data_ptr() + region_len = self.num_blocks * self.block_len + caches_data.append((base_addr, region_len, self.rank, "")) + kv_caches_base_addr.append(base_addr) + self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr + self.num_regions = len(caches_data) + self.num_layers = len(self.kv_caches.keys()) + + # TODO(mgoin): remove this once we have hybrid memory allocator + # Optimization for models with local attention (Llama 4) + if self.vllm_config.model_config.hf_config.model_type == "llama4": + from transformers import Llama4TextConfig + assert isinstance(self.vllm_config.model_config.hf_text_config, + Llama4TextConfig) + llama4_config = self.vllm_config.model_config.hf_text_config + no_rope_layers = llama4_config.no_rope_layers + chunk_size = llama4_config.attention_chunk_size + chunk_block_size = math.ceil(chunk_size / self.block_size) + for layer_idx in range(self.num_layers): + # no_rope_layers[layer_idx] == 0 means NoPE (global) + # Any other value means RoPE (local chunked) + is_local_attention = no_rope_layers[layer_idx] != 0 + block_window = chunk_block_size if is_local_attention else None + self.block_window_per_layer.append(block_window) + logger.debug("Llama 4 block window per layer mapping: %s", + self.block_window_per_layer) + assert len(self.block_window_per_layer) == self.num_layers + + descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM") + logger.debug("Registering descs: %s", caches_data) + self.nixl_wrapper.register_memory(descs) + logger.debug("Done registering descs") + + self._registered_descs.append(descs) + + # After KV Caches registered, listen for new connections. 
+ metadata = NixlAgentMetadata( + engine_id=self.engine_id, + agent_metadata=self.nixl_wrapper.get_agent_metadata(), + kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], + num_blocks=self.num_blocks, + ) + ready_event = threading.Event() + self._nixl_handshake_listener_t = threading.Thread( + target=self._nixl_handshake_listener, + args=(metadata, ready_event, self.rank), + daemon=True, + name="nixl_handshake_listener") + self._nixl_handshake_listener_t.start() + ready_event.wait() + + def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata): + engine_id = nixl_agent_meta.engine_id + assert engine_id != self.engine_id, "Conflict engine id found!" + if engine_id in self._remote_agents: + return + + self._remote_agents[engine_id] = self.nixl_wrapper.add_remote_agent( + nixl_agent_meta.agent_metadata) + self.kv_caches_base_addr[ + engine_id] = nixl_agent_meta.kv_caches_base_addr + + # Create src descs and xfer side handles. + blocks_data = [] + for base_addr in self.kv_caches_base_addr[self.engine_id]: + for block_id in range(self.num_blocks): + block_offset = block_id * self.block_len + # (addr, len, device id) + blocks_data.append( + (base_addr + block_offset, self.block_len, self.rank)) + logger.debug("Created %s blocks for src engine %s and rank %s", + len(blocks_data), self.engine_id, self.rank) + + # Register with NIXL. + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( + "NIXL_INIT_AGENT", descs) + + # Create dst descs and xfer side handles. 
+ self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks + blocks_data = [] + for base_addr in self.kv_caches_base_addr[engine_id]: + for block_id in range(nixl_agent_meta.num_blocks): + block_offset = block_id * self.block_len + # (addr, len, device id) + blocks_data.append( + (base_addr + block_offset, self.block_len, self.rank)) + logger.debug("Created %s blocks for dst engine %s and rank %s", + len(blocks_data), engine_id, self.rank) + + # Register with NIXL. + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + self.dst_xfer_side_handles[ + engine_id] = self.nixl_wrapper.prep_xfer_dlist( + self._remote_agents[engine_id], descs) + + def get_finished(self) -> tuple[set[str], set[str]]: + """ + Get requests that are done sending or recving. + + In TP>1 setup, each rank exchanges KVs with its counterpart + ranks independently. get_finished() runs in a worker creates + the done_sending and done_recving sets that are sent to the + scheduler via ModelRunnerOutput by Rank 0. To ensure trnxs + are done before adding to finished, Ranks 1 to N-1 communicate + to Rank 0 once their transaction is done + Rank 0 returns + finished sets to Scheduler only once all ranks are done. + """ + done_sending = self._get_new_notifs() + done_recving = self._pop_done_transfers(self._recving_transfers) + if len(done_sending) > 0 or len(done_recving) > 0: + logger.debug( + "Rank %s, get_finished: %s requests done sending " + "and %s requests done recving", self.rank, len(done_sending), + len(done_recving)) + + if self.world_size == 1: + return done_sending, done_recving + + # Rank 0: get finished from all other ranks. + if self.rank == 0: + for req_id in done_sending: + self._done_sending_count[req_id] += 1 + for req_id in done_recving: + self._done_recving_count[req_id] += 1 + + # Keep track of how many other ranks have finished. 
+ other_ranks_finished_ids: list[str] = [] + for i in range(1, self.world_size): + other_ranks_finished_ids.extend( + self.tp_group.recv_object(src=i)) + for req_id in other_ranks_finished_ids: + if (req_id in self._done_recving_count + or req_id in self._recving_transfers): + self._done_recving_count[req_id] += 1 + else: + self._done_sending_count[req_id] += 1 + + # Return ids that finished on all ranks to the scheduler. + all_done_recving: set[str] = set() + for req_id in list(self._done_recving_count.keys()): + if self._done_recving_count[req_id] == self.world_size: + del self._done_recving_count[req_id] + all_done_recving.add(req_id) + + all_done_sending: set[str] = set() + for req_id in list(self._done_sending_count.keys()): + if self._done_sending_count[req_id] == self.world_size: + del self._done_sending_count[req_id] + all_done_sending.add(req_id) + + return all_done_sending, all_done_recving + + # Ranks 1 to N-1: send finished ids to Rank 0. + else: + finished_req_ids = list(done_recving.union(done_sending)) + self.tp_group.send_object(finished_req_ids, dst=0) + + # Unused as only Rank 0 results are sent to scheduler. + return done_sending, done_recving + + def _get_new_notifs(self) -> set[str]: + """Get req_ids which got a remote xfer message.""" + + notified_req_ids: set[str] = set() + for req_ids in self.nixl_wrapper.get_new_notifs().values(): + for req_id in req_ids: + assert req_id not in notified_req_ids + notified_req_ids.add(req_id.decode("utf-8")) + return notified_req_ids + + def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]: + """ + Pop completed xfers by checking for DONE state. 
+ Args: + transfers: dict of req_id -> list[running_xfer] + Returns: + set of req_ids that have all done xfers + """ + done_req_ids: set[str] = set() + for req_id, handles in list(transfers.items()): + running_reqs = [] + for handle in handles: + xfer_state = self.nixl_wrapper.check_xfer_state(handle) + if xfer_state == "DONE": + # TODO ptarasiewicz: why abort is throwing errors? + # self.nixl_wrapper.release_xfer_handle(handle) + continue + if xfer_state == "PROC": + running_reqs.append(handle) + else: + raise RuntimeError( + f"Transfer failed with state {xfer_state}") + if len(running_reqs) == 0: + done_req_ids.add(req_id) + del transfers[req_id] + else: + transfers[req_id] = running_reqs + return done_req_ids + + def start_load_kv(self, metadata: NixlConnectorMetadata): + """ + Start loading by triggering non-blocking nixl_xfer. + We check for these trnxs to complete in each step(). + """ + for req_id, meta in metadata.requests.items(): + logger.debug( + "start_load_kv for request %s from remote engine %s. " + "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id, + meta.remote_engine_id, len(meta.local_block_ids), + len(meta.remote_block_ids)) + self._read_blocks( + request_id=req_id, + dst_engine_id=meta.remote_engine_id, + local_block_ids=meta.local_block_ids, + remote_block_ids=meta.remote_block_ids, + remote_host=meta.remote_host, + remote_port=meta.remote_port, + ) + + def _read_blocks( + self, + local_block_ids: list[int], + remote_block_ids: list[int], + remote_host: str, + remote_port: int, + dst_engine_id: str, + request_id: str, + ): + # NOTE(rob): this takes ~2s. We need to get this off the hotpath. + if dst_engine_id not in self._remote_agents: + self._nixl_handshake(remote_host, remote_port) + + # NOTE(rob): having the staging blocks be on the READER side is + # not going to work well (since we will have to call rearrange tensors). + # after we detect the txn is complete (which means we cannot make the + # read trxn async easily).
If we want to make "READ" happen cleanly, + # then we will need to have the staging blocks on the remote side. + + # NOTE(rob): according to nvidia the staging blocks are used to + # saturate IB with heterogeneous TP sizes. We should remove the staging + # blocks until we are ready. + + # Full prefix cache hit: do not need to read remote blocks, + # just notify P worker that we have the blocks we need. + num_local_blocks = len(local_block_ids) + if num_local_blocks == 0: + self.nixl_wrapper.send_notif(dst_engine_id, + notif_msg=request_id.encode("utf-8")) + return + + # Partial prefix cache hit: just read uncomputed blocks. + num_remote_blocks = len(remote_block_ids) + assert num_local_blocks <= num_remote_blocks + if num_local_blocks < num_remote_blocks: + remote_block_ids = remote_block_ids[-num_local_blocks:] + + # Get side handles. + local_xfer_side_handle = self.src_xfer_side_handle + remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id] + + # Get descs ids. + local_block_descs_ids: list[int] = [] + remote_block_descs_ids: list[int] = [] + if not self.block_window_per_layer: + # Default case: assume global attention + remote_block_descs_ids = self._get_block_descs_ids( + dst_engine_id, remote_block_ids) + local_block_descs_ids = self._get_block_descs_ids( + self.engine_id, local_block_ids) + else: + # TODO(mgoin): remove this once we have hybrid memory allocator + # Optimization for models with local attention (Llama 4) + for layer_idx, block_window in enumerate( + self.block_window_per_layer): + # For each layer: + if block_window is None: + # If not chunked, we just use the + # full block lists (global attention) + layer_local_block_ids = local_block_ids + layer_remote_block_ids = remote_block_ids + else: + # If chunked, get the last block_window blocks + layer_local_block_ids = local_block_ids[-block_window:] + layer_remote_block_ids = remote_block_ids[-block_window:] + + # Get descs ids for the layer. 
+ layer_local_desc_ids = self._get_block_descs_ids( + self.engine_id, layer_local_block_ids, layer_idx) + layer_remote_desc_ids = self._get_block_descs_ids( + dst_engine_id, layer_remote_block_ids, layer_idx) + + local_block_descs_ids.extend(layer_local_desc_ids) + remote_block_descs_ids.extend(layer_remote_desc_ids) + + assert len(local_block_descs_ids) == len(remote_block_descs_ids) + + # Prepare transfer with Nixl. + handle = self.nixl_wrapper.make_prepped_xfer( + "READ", + local_xfer_side_handle, + local_block_descs_ids, + remote_xfer_side_handle, + remote_block_descs_ids, + notif_msg=request_id.encode("utf-8"), + ) + + # Begin async xfer. + self.nixl_wrapper.transfer(handle) + + # Use handle to check completion in future step(). + self._recving_transfers[request_id].append(handle) + + def _get_block_descs_ids(self, + engine_id: str, + block_ids: list[int], + layer_idx: Optional[int] = None) -> list[int]: + """ + Get the descs ids for a set of block ids. + If layer_idx is provided, we use the region_ids for the given layer. + Otherwise, we use all regions. + """ + + if layer_idx is None: + region_ids = range(self.num_regions) + else: + assert layer_idx < self.num_layers + if self.num_layers < self.num_regions: + # If we have more regions than layers, we assume that + # the regions are organized as [K0, V0, K1, V1, ...] + # and we select K_i and V_i + assert 2 * self.num_layers == self.num_regions + region_ids = range(2 * layer_idx, 2 * layer_idx + 2) + else: + # Otherwise, we assume we have MLA and select i-th layer + assert self.num_layers == self.num_regions + region_ids = range(layer_idx, layer_idx + 1) + + num_blocks = self.dst_num_blocks[engine_id] + + # Compute the desc ids for each block. 
+ descs_ids: list[int] = [] + for reg_id in region_ids: + for block_id in block_ids: + descs_ids.append(reg_id * num_blocks + block_id) + return descs_ids + + +@contextlib.contextmanager +def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: + """Context manager for a ZMQ socket""" + + if socket_type not in (zmq.ROUTER, zmq.REQ): + raise ValueError(f"Unexpected socket type: {socket_type}") + + ctx: Optional[zmq.Context] = None + try: + ctx = zmq.Context() # type: ignore[attr-defined] + yield make_zmq_socket(ctx=ctx, + path=addr, + socket_type=socket_type, + bind=socket_type == zmq.ROUTER) + finally: + if ctx is not None: + ctx.destroy(linger=0) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index f91ffbc720e7..0421a65a2c81 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -17,6 +17,7 @@ if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request logger = init_logger(__name__) @@ -132,8 +133,7 @@ def inject_kv_into_layer( dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) # Get the metadata - metadata: KVConnectorMetadata = \ - self._get_connector_metadata() + metadata: KVConnectorMetadata = self._get_connector_metadata() assert isinstance(metadata, SharedStorageConnectorMetadata) if metadata is None: @@ -225,7 +225,7 @@ def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int, - ) -> int: + ) -> tuple[int, bool]: """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. 
@@ -239,7 +239,6 @@ def get_num_new_matched_tokens( the number of tokens that can be loaded from the external KV cache beyond what is already computed. """ - # NOTE: in this debug implementation, we assume that the prompt is # cached_prompt + newly_generated_single_token # Therefore, we use prompt_token_ids[:-1] to determine the folder name @@ -248,7 +247,7 @@ def get_num_new_matched_tokens( # with the block granularity. And it expects the returned blocks and # num_computed_tokens to also be aligned with the block granularity. if not self._found_match_for_request(request): - return 0 + return 0, False logger.info("External Cache Hit!") @@ -257,9 +256,10 @@ def get_num_new_matched_tokens( num_tokens_to_check = align_to_block_size( len(request.prompt_token_ids) - 1, self._block_size) - return num_tokens_to_check - num_computed_tokens + return num_tokens_to_check - num_computed_tokens, False def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", num_external_tokens: int): """ Update KVConnector state after block allocation. @@ -288,7 +288,7 @@ def build_connector_meta( for new_req in scheduler_output.scheduled_new_reqs: if new_req.req_id in self._requests_need_load: meta.add_request(token_ids=new_req.prompt_token_ids, - block_ids=new_req.block_ids, + block_ids=new_req.block_ids[0], block_size=self._block_size, is_store=False) total_need_load += 1 @@ -299,7 +299,7 @@ def build_connector_meta( # the original prompt tokens. if not self._found_match_for_request(new_req): meta.add_request(token_ids=new_req.prompt_token_ids, - block_ids=new_req.block_ids, + block_ids=new_req.block_ids[0], block_size=self._block_size, is_store=True) @@ -319,7 +319,7 @@ def build_connector_meta( # NOTE(rob): For resumed req, new_block_ids is all # of the block_ids for the request. 
- block_ids = cached_req.new_block_ids + block_ids = cached_req.new_block_ids[0] meta.add_request(token_ids=token_ids, block_ids=block_ids, diff --git a/vllm/distributed/kv_transfer/kv_connector_agent.py b/vllm/distributed/kv_transfer/kv_connector_agent.py index 9d7145098105..819c06805ee4 100644 --- a/vllm/distributed/kv_transfer/kv_connector_agent.py +++ b/vllm/distributed/kv_transfer/kv_connector_agent.py @@ -5,7 +5,7 @@ 1. `send_kv_caches_and_hidden_states` 2. `recv_kv_caches_and_hidden_states """ -from typing import TYPE_CHECKING, List, Tuple, Union +from typing import TYPE_CHECKING, Union if TYPE_CHECKING: from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata @@ -53,7 +53,7 @@ def send_kv_caches_and_hidden_states( self, model_executable: torch.nn.Module, model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: List[torch.Tensor], + kv_caches: list[torch.Tensor], hidden_or_intermediate_states: Union[torch.Tensor, IntermediateTensors], ) -> None: @@ -68,8 +68,8 @@ def close(self) -> None: def recv_kv_caches_and_hidden_states( self, model_executable: torch.nn.Module, model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: List[torch.Tensor] - ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + kv_caches: list[torch.Tensor] + ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, "ModelInputForGPUWithSamplingMetadata"]: return self.connector.recv_kv_caches_and_hidden_states( diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py index bea42846e9e4..d1ffb8092dfc 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py @@ -13,7 +13,7 @@ """ from abc import ABC, abstractmethod -from typing import List, Optional +from typing import Optional import torch @@ -93,7 +93,7 @@ def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, @abstractmethod def drop_select( self, 
input_tokens: Optional[torch.Tensor], - roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]: """Select and *drop* KV cache entries from the lookup buffer. The functionality is similar to the following python statements @@ -111,7 +111,7 @@ def drop_select( roi (torch.Tensor): A binary mask on top of the input tokens Returns: - List[Optional[torch.Tensor]]: A list of tensors. Can be None. + list[Optional[torch.Tensor]]: A list of tensors. Can be None. Raises: NotImplementedError: This method must be implemented in subclasses. diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py index 10bbfe1ddd8a..e3b2274bd8a4 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py @@ -11,7 +11,7 @@ """ import threading from collections import deque -from typing import Deque, List, Optional, Union +from typing import Optional, Union import torch @@ -38,7 +38,7 @@ def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, data_pipe: on device (e.g. 
GPU) """ - self.buffer: Deque[List[torch.Tensor]] = deque() + self.buffer: deque[list[torch.Tensor]] = deque() self.buffer_size = 0 self.buffer_size_threshold = buffer_size_thresh @@ -50,8 +50,8 @@ def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, self.normal_signal = torch.tensor([0], device="cpu") self.end_signal = None - def _matches(self, tokens_roi_sender: List[torch.Tensor], - tokens_roi_recver: List[torch.Tensor]): + def _matches(self, tokens_roi_sender: list[torch.Tensor], + tokens_roi_recver: list[torch.Tensor]): # tokens_roi_sender: tokens and roi of the producer (in the buffer) # tokens_roi_recver: tokens and roi of the consumer (query) @@ -88,7 +88,7 @@ def _send_tensor_and_dec_size(self, tensor = tensor.float() self.data_pipe.send_tensor(tensor) - def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]): + def _get_element_size(self, data: Optional[Union[list, torch.Tensor]]): if isinstance(data, torch.Tensor): return data.element_size() * data.numel() @@ -151,7 +151,7 @@ def drop_select_handler(self): tokens_roi_recver = [input_tokens, roi] def is_buffer_available( - tokens_roi_recver: List[torch.Tensor], ) -> bool: + tokens_roi_recver: list[torch.Tensor], ) -> bool: # perform input tokens and roi matching # FIXME: this matching is O(n), ideally it should be O(1) # but this buffer size won't (and shouldn't) be too large so @@ -184,7 +184,7 @@ def is_buffer_available( def drop_select( self, input_tokens: Optional[torch.Tensor], - roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]: assert self.request_handling_thread is None, \ "drop_select should be called by the KV cache consumer "\ diff --git a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py index e8bf607eb899..761c56f7e41f 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py +++ 
b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py @@ -15,7 +15,7 @@ import threading import time from concurrent.futures import ThreadPoolExecutor -from typing import Callable, Dict, Optional, Tuple +from typing import Callable, Optional import torch @@ -35,7 +35,7 @@ def __init__(self, message): super().__init__(self.message) -Metadata = Dict[str, Optional[torch.Tensor]] +Metadata = dict[str, Optional[torch.Tensor]] class PyNcclPipe(KVPipeBase): @@ -83,7 +83,7 @@ def __init__(self, def _get_device_send_recv_impl( self, group: StatelessProcessGroup - ) -> Tuple[Callable[[torch.Tensor, int], None], Callable[ + ) -> tuple[Callable[[torch.Tensor, int], None], Callable[ [torch.Tensor, int], None]]: send: Callable[[torch.Tensor, int], None] @@ -118,11 +118,11 @@ def _make_metadata(self, tensor: Optional[torch.Tensor]) -> Metadata: """ Create the metadata as a dictionary based on the input tensor. - Parameters: - - tensor: The input tensor or None if no tensor is provided. + Args: + tensor: The input tensor or None if no tensor is provided. Returns: - - metadata: A dictionary with the following keys: + metadata: A dictionary with the following keys: - "dtype": The data type of the tensor or None. - "shape": The shape of the tensor or None. """ @@ -135,13 +135,13 @@ def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor: """ Create a buffer to receive the tensor based on the provided metadata. - Parameters: - - metadata: A dictionary with keys "dtype" and "shape", describing - the tensor's data type and shape. + Args: + metadata: A dictionary with keys "dtype" and "shape", + describing the tensor's data type and shape. Returns: - - buffer: A tensor of the specified type and shape, allocated on - self.device. + buffer: A tensor of the specified type and shape, + allocated on `self.device`. 
""" return torch.empty(metadata["shape"], dtype=metadata["dtype"], @@ -151,8 +151,8 @@ def _send_metadata(self, metadata: Metadata): """ Send the metadata dictionary to the target rank. - Parameters: - - metadata: A dictionary with keys "dtype" and "shape". + Args: + metadata: A dictionary with keys "dtype" and "shape". """ self.group.send_obj(metadata, self.target_rank_for_send) @@ -161,8 +161,8 @@ def _recv_metadata(self) -> Metadata: Receive the metadata dictionary from the target rank. Returns: - - metadata: A dictionary with keys "dtype" and "shape" describing - the tensor. + metadata: A dictionary with keys "dtype" and "shape" + describing the tensor. """ return self.group.recv_obj(self.target_rank_for_recv) @@ -171,9 +171,9 @@ def _send_impl(self, tensor: Optional[torch.Tensor]) -> None: The actual implementation of sending the tensor and its metadata to the target rank. - Parameters: - - tensor: The input tensor to be sent, or None if no tensor is - being sent. + Args: + tensor: The input tensor to be sent, or `None` if no tensor is + being sent. """ metadata = self._make_metadata(tensor) self._send_metadata(metadata) @@ -187,7 +187,7 @@ def _recv_impl(self) -> Optional[torch.Tensor]: the target rank. Returns: - - buffer: The received tensor, or None if no tensor is received. + buffer: The received tensor, or `None` if no tensor is received. """ metadata = self._recv_metadata() if metadata["dtype"] is None: @@ -227,8 +227,8 @@ def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: Sends a tensor and its metadata to the destination rank in a non-blocking way. - Parameters: - - tensor: The tensor to send, or None if no tensor is being sent. + Args: + tensor: The tensor to send, or `None` if no tensor is being sent. """ if self.transport_thread is None: self.transport_thread = ThreadPoolExecutor(max_workers=1) @@ -250,8 +250,8 @@ def recv_tensor(self) -> Optional[torch.Tensor]: """ Receives a tensor and its metadata from the source rank. Blocking call. 
- Returns: - - tensor: The received tensor, or None if no tensor is received. + Returns: + tensor: The received tensor, or `None` if no tensor is received. """ if self.transport_thread is None: self.transport_thread = ThreadPoolExecutor(max_workers=1) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index cb9658ce1004..b674d05a7771 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -29,7 +29,7 @@ from contextlib import contextmanager, nullcontext from dataclasses import dataclass from multiprocessing import shared_memory -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union from unittest.mock import patch import torch @@ -54,15 +54,15 @@ class GraphCaptureContext: def _split_tensor_dict( - tensor_dict: Dict[str, Union[torch.Tensor, Any]] -) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: + tensor_dict: dict[str, Union[torch.Tensor, Any]] +) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]: """Split the tensor dictionary into two parts: 1. A list of (key, value) pairs. If the value is a tensor, it is replaced by its metadata. 2. A list of tensors.
""" - metadata_list: List[Tuple[str, Any]] = [] - tensor_list: List[torch.Tensor] = [] + metadata_list: list[tuple[str, Any]] = [] + tensor_list: list[torch.Tensor] = [] for key, value in tensor_dict.items(): if isinstance(value, torch.Tensor): # Note: we cannot use `value.device` here, @@ -78,7 +78,7 @@ def _split_tensor_dict( return metadata_list, tensor_list -_group_name_counter: Dict[str, int] = {} +_group_name_counter: dict[str, int] = {} def _get_unique_name(name: str) -> str: @@ -94,7 +94,7 @@ def _get_unique_name(name: str) -> str: return newname -_groups: Dict[str, Callable[[], Optional["GroupCoordinator"]]] = {} +_groups: dict[str, Callable[[], Optional["GroupCoordinator"]]] = {} def _register_group(group: "GroupCoordinator") -> None: @@ -119,7 +119,7 @@ def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int, group = _groups[group_name]() if group is None: raise ValueError(f"Group {group_name} is destroyed.") - return group.reduce_scatter(tensor, dim) + return group._reduce_scatter_out_place(tensor, dim) def reduce_scatter_fake(tensor: torch.Tensor, dim: int, world_size: int, @@ -135,7 +135,7 @@ def all_gather(tensor: torch.Tensor, dim: int, world_size: int, group = _groups[group_name]() if group is None: raise ValueError(f"Group {group_name} is destroyed.") - return group.all_gather(tensor, dim) + return group._all_gather_out_place(tensor, dim) def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int, @@ -160,6 +160,7 @@ def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int, op_func=reduce_scatter, mutates_args=[], fake_impl=reduce_scatter_fake, + dispatch_key=current_platform.dispatch_key, ) direct_register_custom_op( @@ -167,6 +168,7 @@ def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int, op_func=all_gather, mutates_args=[], fake_impl=all_gather_fake, + dispatch_key=current_platform.dispatch_key, ) @@ -182,7 +184,7 @@ class GroupCoordinator: # available attributes: rank: int # global rank - ranks: 
List[int] # global ranks in the group + ranks: list[int] # global ranks in the group world_size: int # size of the group # difference between `local_rank` and `rank_in_group`: # if we have a group of size 4 across two nodes: @@ -201,7 +203,7 @@ class GroupCoordinator: def __init__( self, - group_ranks: List[List[int]], + group_ranks: list[list[int]], local_rank: int, torch_distributed_backend: Union[str, Backend], use_device_communicator: bool, @@ -366,6 +368,16 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: assert -input_.dim() <= dim < input_.dim(), ( f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + if self.use_custom_op_call: + return torch.ops.vllm.all_gather(input_, + dim, + world_size, + group_name=self.unique_name) + else: + return self._all_gather_out_place(input_, dim) + + def _all_gather_out_place(self, input_: torch.Tensor, + dim: int) -> torch.Tensor: return self.device_communicator.all_gather(input_, dim) def reduce_scatter(self, @@ -378,6 +390,16 @@ def reduce_scatter(self, assert -input_.dim() <= dim < input_.dim(), ( f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + if self.use_custom_op_call: + return torch.ops.vllm.reduce_scatter(input_, + dim, + world_size, + group_name=self.unique_name) + else: + return self._reduce_scatter_out_place(input_, dim) + + def _reduce_scatter_out_place(self, input_: torch.Tensor, + dim: int) -> torch.Tensor: return self.device_communicator.reduce_scatter(input_, dim) def gather(self, @@ -435,7 +457,7 @@ def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): return recv[0] def broadcast_object_list(self, - obj_list: List[Any], + obj_list: list[Any], src: int = 0, group: Optional[ProcessGroup] = None): """Broadcast the input object list. 
@@ -518,11 +540,11 @@ def recv_object(self, src: int) -> Any: def broadcast_tensor_dict( self, - tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None, + tensor_dict: Optional[dict[str, Union[torch.Tensor, Any]]] = None, src: int = 0, group: Optional[ProcessGroup] = None, metadata_group: Optional[ProcessGroup] = None - ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + ) -> Optional[dict[str, Union[torch.Tensor, Any]]]: """Broadcast the input tensor dictionary. NOTE: `src` is the local rank of the source rank. """ @@ -536,7 +558,7 @@ def broadcast_tensor_dict( rank_in_group = self.rank_in_group if rank_in_group == src: - metadata_list: List[Tuple[Any, Any]] = [] + metadata_list: list[tuple[Any, Any]] = [] assert isinstance( tensor_dict, dict), (f"Expecting a dictionary, got {type(tensor_dict)}") @@ -603,10 +625,10 @@ def broadcast_tensor_dict( def send_tensor_dict( self, - tensor_dict: Dict[str, Union[torch.Tensor, Any]], + tensor_dict: dict[str, Union[torch.Tensor, Any]], dst: Optional[int] = None, all_gather_group: Optional["GroupCoordinator"] = None, - ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + ) -> Optional[dict[str, Union[torch.Tensor, Any]]]: """Send the input tensor dictionary. NOTE: `dst` is the local rank of the source rank. """ @@ -626,7 +648,7 @@ def send_tensor_dict( dst = (self.rank_in_group + 1) % self.world_size assert dst < self.world_size, f"Invalid dst rank ({dst})" - metadata_list: List[Tuple[Any, Any]] = [] + metadata_list: list[tuple[Any, Any]] = [] assert isinstance( tensor_dict, dict), f"Expecting a dictionary, got {type(tensor_dict)}" @@ -661,7 +683,7 @@ def recv_tensor_dict( self, src: Optional[int] = None, all_gather_group: Optional["GroupCoordinator"] = None, - ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + ) -> Optional[dict[str, Union[torch.Tensor, Any]]]: """Recv the input tensor dictionary. NOTE: `src` is the local rank of the source rank. 
""" @@ -682,7 +704,7 @@ def recv_tensor_dict( assert src < self.world_size, f"Invalid src rank ({src})" recv_metadata_list = self.recv_object(src=src) - tensor_dict: Dict[str, Any] = {} + tensor_dict: dict[str, Any] = {} for key, value in recv_metadata_list: if isinstance(value, TensorMetadata): tensor = torch.empty(value.size, @@ -757,6 +779,26 @@ def destroy(self): if self.mq_broadcaster is not None: self.mq_broadcaster = None + def prepare_communication_buffer_for_model(self, model: torch.nn.Module): + if self.device_communicator is not None: + self.device_communicator.prepare_communication_buffer_for_model( + model) + + def dispatch( + self, hidden_states: torch.Tensor, + router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + if self.device_communicator is not None: + return self.device_communicator.dispatch(hidden_states, + router_logits) + else: + return hidden_states, router_logits + + def combine(self, hidden_states) -> torch.Tensor: + if self.device_communicator is not None: + return self.device_communicator.combine(hidden_states) + else: + return hidden_states + _WORLD: Optional[GroupCoordinator] = None @@ -766,7 +808,7 @@ def get_world_group() -> GroupCoordinator: return _WORLD -def init_world_group(ranks: List[int], local_rank: int, +def init_world_group(ranks: list[int], local_rank: int, backend: str) -> GroupCoordinator: return GroupCoordinator( group_ranks=[ranks], @@ -778,7 +820,7 @@ def init_world_group(ranks: List[int], local_rank: int, def init_model_parallel_group( - group_ranks: List[List[int]], + group_ranks: list[list[int]], local_rank: int, backend: str, use_message_queue_broadcaster: bool = False, @@ -816,6 +858,14 @@ def get_dp_group() -> GroupCoordinator: return _DP +_EP: Optional[GroupCoordinator] = None + + +def get_ep_group() -> GroupCoordinator: + assert _EP is not None, ("expert parallel group is not initialized") + return _EP + + def get_pp_group() -> GroupCoordinator: assert _PP is not None, ( "pipeline model parallel 
group is not initialized") @@ -1001,10 +1051,21 @@ def initialize_model_parallel( backend, group_name="dp") + global _EP + assert _EP is None, ("expert parallel group is already initialized") + group_ranks = all_ranks.transpose(1, 2).reshape( + -1, data_parallel_size * tensor_model_parallel_size).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] + _EP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + group_name="ep") + logger.info( "rank %s in world size %s is assigned as " - "DP rank %s, PP rank %s, TP rank %s", rank, world_size, - _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group) + "DP rank %s, PP rank %s, TP rank %s, EP rank %s", rank, world_size, + _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group, + _EP.rank_in_group) def ensure_model_parallel_initialized( @@ -1035,6 +1096,23 @@ def ensure_model_parallel_initialized( f"{pipeline_model_parallel_size=}") +def prepare_communication_buffer_for_model(model: torch.nn.Module): + """Prepare the communication buffer for the model. + Traditional communication libraries like NCCL are almost + model agnostic. However, emerging new communication libraries like + MoE all2all (DeepEP) usually allocate the communication buffer + based on the model shape for optimal performance. 
+ """ + if _TP is not None: + _TP.prepare_communication_buffer_for_model(model) + if _PP is not None: + _PP.prepare_communication_buffer_for_model(model) + if _DP is not None: + _DP.prepare_communication_buffer_for_model(model) + if _EP is not None: + _EP.prepare_communication_buffer_for_model(model) + + def model_parallel_is_initialized(): """Check if tensor and pipeline parallel groups are initialized.""" return (_TP is not None and _PP is not None) @@ -1081,6 +1159,7 @@ def get_tensor_model_parallel_rank(): def destroy_model_parallel(): """Set the groups to none and destroy them.""" global _TP + if _TP: _TP.destroy() _TP = None @@ -1095,6 +1174,11 @@ def destroy_model_parallel(): _DP.destroy() _DP = None + global _EP + if _EP: + _EP.destroy() + _EP = None + def destroy_distributed_environment(): global _WORLD @@ -1115,8 +1199,9 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False): ray.shutdown() gc.collect() from vllm.platforms import current_platform - if not current_platform.is_cpu(): - torch.cuda.empty_cache() + empty_cache = current_platform.empty_cache + if empty_cache is not None: + empty_cache() try: torch._C._host_emptyCache() except AttributeError: @@ -1125,7 +1210,7 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False): def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup], - source_rank: int = 0) -> List[bool]: + source_rank: int = 0) -> list[bool]: """ This is a collective operation that returns if each rank is in the same node as the source rank. It tests if processes are attached to the same diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index e4d4008cd0a6..93a069d36c4b 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -6,11 +6,15 @@ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
import dataclasses import datetime +import os import pickle import socket +import sys import time +import uuid from collections import deque -from typing import Any, Deque, Dict, Optional, Sequence, Tuple +from collections.abc import Sequence +from typing import Any, Optional import torch from torch.distributed import ProcessGroup, TCPStore @@ -22,9 +26,24 @@ import vllm.envs as envs from vllm.logger import init_logger +from vllm.utils import get_tcp_uri, is_torch_equal_or_newer logger = init_logger(__name__) +# We prefer to use os.sched_yield as it results in tighter polling loops, +# measured to be around 3e-7 seconds. However on earlier versions of Python +# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0) +USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1)) + or (sys.version_info[:2] == (3, 10) + and sys.version_info[2] >= 8)) + + +def sched_yield(): + if USE_SCHED_YIELD: + os.sched_yield() + else: + time.sleep(0) + def ensure_divisibility(numerator, denominator): """Ensure that numerator is divisible by the denominator.""" @@ -68,7 +87,7 @@ def split_tensor_along_last_dim( def get_pp_indices(num_hidden_layers: int, pp_rank: int, - pp_size: int) -> Tuple[int, int]: + pp_size: int) -> tuple[int, int]: """Try to evenly distribute layers across partitions. 
If the number of layers is not divisible by the number of partitions, @@ -131,15 +150,15 @@ class StatelessProcessGroup: data_expiration_seconds: int = 3600 # 1 hour # dst rank -> counter - send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict) + send_dst_counter: dict[int, int] = dataclasses.field(default_factory=dict) # src rank -> counter - recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict) + recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict) broadcast_send_counter: int = 0 - broadcast_recv_src_counter: Dict[int, int] = dataclasses.field( + broadcast_recv_src_counter: dict[int, int] = dataclasses.field( default_factory=dict) # A deque to store the data entries, with key and timestamp. - entries: Deque[Tuple[str, + entries: deque[tuple[str, float]] = dataclasses.field(default_factory=deque) def __post_init__(self): @@ -210,10 +229,141 @@ def all_gather_obj(self, obj: Any) -> list[Any]: gathered_objs.append(recv_obj) return gathered_objs - def barrier(self): - """A barrier to synchronize all ranks.""" + def barrier(self, timeout: float = 30.0): + """A robust barrier to synchronize all ranks. + + + Uses a multi-phase approach to ensure all processes reach the barrier + before proceeding: + + 1. Each process signals it has reached the barrier + + 2. Each process signals that it has confirmed the arrival of all other + ranks. + + 3. Rank 0 waits for all other ranks to signal their departure to ensure + that all ranks have departed the barrier first.
+ + Args: + timeout: Maximum time to wait for each phase (in seconds) + + + Raises: + RuntimeError: If coordination fails or times out + """ + # Generate a barrier ID that is globally unique + try: + if self.rank == 0: + barrier_id = f"barrier_{uuid.uuid4()}" + self.broadcast_obj(barrier_id, src=0) + else: + barrier_id = self.broadcast_obj(None, src=0) + except Exception as e: + raise RuntimeError("Failed to broadcast barrier_id") from e + + # Phase 1: Signal arrival at barrier + # Wait for all processes to arrive + # We need all ranks to confirm the arrival of all other ranks. + # This is the key synchronization point. + arrival_key = f"arrival_{barrier_id}_{self.rank}" + try: + self.store.set(arrival_key, b"1") + except Exception as e: + raise RuntimeError("Failed to signal barrier arrival") from e + + start_time = time.time() + processes_arrived: set[int] = set() + + while len(processes_arrived) < self.world_size: + # Check for timeout + cur_time = time.time() + if cur_time - start_time > timeout: + raise RuntimeError( + f"Barrier timed out after {timeout:f} seconds") + + # Check for each process + for i in range(self.world_size): + if i in processes_arrived: + continue + + key = f"arrival_{barrier_id}_{i}" + try: + # Try to get the key - if it exists, we'll get a value + # If it doesn't exist, it will throw an exception + self.store.get(key) + processes_arrived.add(i) + except KeyError: + # Key doesn't exist yet + pass + except Exception as check_e: + logger.debug("Error checking key existence: %s", check_e) + sched_yield() + + # Short sleep to avoid tight polling + if len(processes_arrived) < self.world_size: + sched_yield() + + # Phase 2: Signal departure from barrier + # We only care to block at this stage in rank 0, which runs the + # server side of the TCPStore. We want to make sure that all + # clients have departed the barrier before rank 0 in case the + # next thing after the barrier is a shutdown, including tearing + # down the TCPStore.
Other ranks can exit the barrier immediately + # after signaling their departure. + departure_key = f"departure_{barrier_id}_{self.rank}" + try: + self.store.set(departure_key, b"1") + except Exception as e: + raise RuntimeError("Failed to signal barrier departure") from e + + if self.rank != 0: + return + + # Make rank 0 wait for all processes to signal departure + start_time = time.time() + processes_departed: set[int] = set() + + while len(processes_departed) < self.world_size: + # Check for timeout + if time.time() - start_time > timeout: + raise RuntimeError( + f"Barrier departure timed out after {timeout:f} s") + + # Check for each process + for i in range(self.world_size): + if i in processes_departed: + continue + + key = f"departure_{barrier_id}_{i}" + try: + # Try to get the key - if it exists, we'll get a value + # If it doesn't exist, it will throw an exception + self.store.get(key) + processes_departed.add(i) + except KeyError: + # Key doesn't exist yet + pass + except Exception as check_e: + logger.debug("Error checking key existence: %s", check_e) + sched_yield() + + # Short sleep to avoid tight polling + if len(processes_departed) < self.world_size: + sched_yield() + + # Clean up keys to avoid leaking memory in the store for i in range(self.world_size): - self.broadcast_obj(None, src=i) + try: + self.store.delete_key(f"arrival_{barrier_id}_{i}") + except Exception: + logger.debug("Error deleting key: %s", + f'arrival_{barrier_id}_{i}') + + try: + self.store.delete_key(f"departure_{barrier_id}_{i}") + except Exception: + logger.debug("Error deleting key: %s", + f'departure_{barrier_id}_{i}') @staticmethod def create( @@ -303,7 +453,7 @@ def stateless_init_torch_distributed_process_group( always formed with process 1, 2, ..., 8, and the additional communication channel is formed with process 9 and 10.
""" - init_method = f"tcp://{host}:{port}" + init_method = get_tcp_uri(host, port) backend = Backend(backend) # it is basically string timeout = _get_default_timeout(backend) @@ -360,7 +510,11 @@ def stateless_destroy_torch_distributed_process_group( Destroy ProcessGroup returned by stateless_init_torch_distributed_process_group(). """ - # Lazy import for non-CUDA backends. - from torch.distributed.distributed_c10d import _shutdown_backend - _shutdown_backend(pg) + if is_torch_equal_or_newer("2.7"): + pg.shutdown() + else: + # Lazy import for non-CUDA backends. + from torch.distributed.distributed_c10d import _shutdown_backend + _shutdown_backend(pg) + _unregister_process_group(pg.group_name) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 27af74e2e349..442e4100fea1 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -4,14 +4,15 @@ import argparse import dataclasses import json -import re +import sys import threading import warnings -from dataclasses import MISSING, dataclass, fields +from dataclasses import MISSING, dataclass, fields, is_dataclass from itertools import permutations -from typing import (Any, Callable, Dict, List, Literal, Optional, Type, - TypeVar, Union, cast, get_args, get_origin) +from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional, + Type, TypeVar, Union, cast, get_args, get_origin) +import regex as re import torch from typing_extensions import TypeIs, deprecated @@ -36,7 +37,8 @@ from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 from vllm.transformers_utils.utils import check_gguf_file from vllm.usage.usage_lib import UsageContext -from vllm.utils import FlexibleArgumentParser, GiB_bytes, is_in_ray_actor +from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser, + GiB_bytes, is_in_doc_build, is_in_ray_actor) # yapf: enable @@ -48,12 +50,9 @@ TypeHintT = Union[type[T], object] -def optional_type( - return_type: Callable[[str], T]) -> 
Callable[[str], Optional[T]]: +def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]: - def _optional_type(val: str) -> Optional[T]: - if val == "" or val == "None": - return None + def _parse_type(val: str) -> T: try: if return_type is json.loads and not re.match("^{.*}$", val): return cast(T, nullable_kvs(val)) @@ -62,14 +61,24 @@ def _optional_type(val: str) -> Optional[T]: raise argparse.ArgumentTypeError( f"Value {val} cannot be converted to {return_type}.") from e + return _parse_type + + +def optional_type( + return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]: + + def _optional_type(val: str) -> Optional[T]: + if val == "" or val == "None": + return None + return parse_type(return_type)(val) + return _optional_type def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]: if not re.match("^{.*}$", val): return str(val) - else: - return optional_type(json.loads)(val) + return optional_type(json.loads)(val) @deprecated( @@ -144,10 +153,25 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: cls_docs = get_attr_docs(cls) kwargs = {} for field in fields(cls): + # Get the set of possible types for the field + type_hints: set[TypeHint] = set() + if get_origin(field.type) in {Union, Annotated}: + type_hints.update(get_args(field.type)) + else: + type_hints.add(field.type) + + # If the field is a dataclass, we can use the model_validate_json + generator = (th for th in type_hints if is_dataclass(th)) + dataclass_cls = next(generator, None) + # Get the default value of the field - default = field.default - if field.default_factory is not MISSING: - default = field.default_factory() + if field.default is not MISSING: + default = field.default + elif field.default_factory is not MISSING: + if is_dataclass(field.default_factory) and is_in_doc_build(): + default = {} + else: + default = field.default_factory() # Get the help text for the field name = field.name @@ -158,16 +182,21 @@ def get_kwargs(cls: ConfigType) -> 
dict[str, Any]: # Initialise the kwargs dictionary for the field kwargs[name] = {"default": default, "help": help} - # Get the set of possible types for the field - type_hints: set[TypeHint] = set() - if get_origin(field.type) is Union: - type_hints.update(get_args(field.type)) - else: - type_hints.add(field.type) - # Set other kwargs based on the type hints - json_tip = "\n\nShould be a valid JSON string." - if contains_type(type_hints, bool): + json_tip = """\n\nShould either be a valid JSON string or JSON keys + passed individually. For example, the following sets of arguments are + equivalent:\n\n + - `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n + - `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n\n""" + if dataclass_cls is not None: + dataclass_init = lambda x, f=dataclass_cls: f(**json.loads(x)) + # Special case for configs with a from_cli method + if hasattr(dataclass_cls, "from_cli"): + from_cli = dataclass_cls.from_cli + dataclass_init = lambda x, f=from_cli: f(x) + kwargs[name]["type"] = dataclass_init + kwargs[name]["help"] += json_tip + elif contains_type(type_hints, bool): # Creates --no-<name> and --<name> flags kwargs[name]["action"] = argparse.BooleanOptionalAction elif contains_type(type_hints, Literal): @@ -202,7 +231,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: kwargs[name]["type"] = union_dict_and_str elif contains_type(type_hints, dict): # Dict arguments will always be optional - kwargs[name]["type"] = optional_type(json.loads) + kwargs[name]["type"] = parse_type(json.loads) kwargs[name]["help"] += json_tip elif (contains_type(type_hints, str) or any(is_not_builtin(th) for th in type_hints)): @@ -258,6 +287,9 @@ class EngineArgs: pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size tensor_parallel_size: int = ParallelConfig.tensor_parallel_size data_parallel_size: int = ParallelConfig.data_parallel_size + data_parallel_size_local: Optional[int] = None + data_parallel_address: Optional[str] = 
None + data_parallel_rpc_port: Optional[int] = None enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel max_parallel_loading_workers: Optional[ int] = ParallelConfig.max_parallel_loading_workers @@ -416,7 +448,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: title="ModelConfig", description=ModelConfig.__doc__, ) - model_group.add_argument("--model", **model_kwargs["model"]) + if 'serve' not in sys.argv[1:] and '--help' not in sys.argv[1:]: + model_group.add_argument("--model", **model_kwargs["model"]) model_group.add_argument("--task", **model_kwargs["task"]) model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"]) model_group.add_argument("--tokenizer-mode", @@ -544,7 +577,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: action=argparse.BooleanOptionalAction, deprecated=True, help="[DEPRECATED] The `--enable-reasoning` flag is deprecated as " - "of v0.8.6. Use `--reasoning-parser` to specify the reasoning " + "of v0.9.0. Use `--reasoning-parser` to specify the reasoning " "parser backend instead. This flag (`--enable-reasoning`) will be " "removed in v0.10.0. 
When `--reasoning-parser` is specified, " "reasoning mode is automatically enabled.") @@ -570,6 +603,21 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **parallel_kwargs["tensor_parallel_size"]) parallel_group.add_argument("--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]) + parallel_group.add_argument('--data-parallel-size-local', + '-dpl', + type=int, + help='Number of data parallel replicas ' + 'to run on this node.') + parallel_group.add_argument('--data-parallel-address', + '-dpa', + type=str, + help='Address of data parallel cluster ' + 'head-node.') + parallel_group.add_argument('--data-parallel-rpc-port', + '-dpp', + type=int, + help='Port for data parallel RPC ' + 'communication.') parallel_group.add_argument( "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"]) @@ -689,7 +737,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: title="DeviceConfig", description=DeviceConfig.__doc__, ) - device_group.add_argument("--device", **device_kwargs["device"]) + device_group.add_argument("--device", + **device_kwargs["device"], + deprecated=True) # Speculative arguments speculative_group = parser.add_argument_group( @@ -771,63 +821,20 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: scheduler_group.add_argument("--scheduler-cls", **scheduler_kwargs["scheduler_cls"]) - # Compilation arguments - # compilation_kwargs = get_kwargs(CompilationConfig) - compilation_group = parser.add_argument_group( - title="CompilationConfig", - description=CompilationConfig.__doc__, - ) - compilation_group.add_argument( - "--compilation-config", - "-O", - type=CompilationConfig.from_cli, - default=None, - help="torch.compile configuration for the model. " - "When it is a number (0, 1, 2, 3), it will be " - "interpreted as the optimization level.\n" - "NOTE: level 0 is the default level without " - "any optimization. 
level 1 and 2 are for internal " - "testing only. level 3 is the recommended level " - "for production.\n" - "To specify the full compilation config, " - "use a JSON string, e.g. ``{\"level\": 3, " - "\"cudagraph_capture_sizes\": [1, 2, 4, 8]}``\n" - "Following the convention of traditional " - "compilers, using ``-O`` without space is also " - "supported. ``-O3`` is equivalent to ``-O 3``.") - - # KVTransfer arguments - # kv_transfer_kwargs = get_kwargs(KVTransferConfig) - kv_transfer_group = parser.add_argument_group( - title="KVTransferConfig", - description=KVTransferConfig.__doc__, - ) - kv_transfer_group.add_argument( - "--kv-transfer-config", - type=KVTransferConfig.from_cli, - default=None, - help="The configurations for distributed KV cache " - "transfer. Should be a JSON string.") - kv_transfer_group.add_argument( - '--kv-events-config', - type=KVEventsConfig.from_cli, - default=None, - help='The configurations for event publishing.') - # vLLM arguments - # vllm_kwargs = get_kwargs(VllmConfig) + vllm_kwargs = get_kwargs(VllmConfig) vllm_group = parser.add_argument_group( title="VllmConfig", description=VllmConfig.__doc__, ) - vllm_group.add_argument( - "--additional-config", - type=json.loads, - default=None, - help="Additional config for specified platform in JSON format. " - "Different platforms may support different configs. Make sure the " - "configs are valid for the platform you are using. 
The input format" - " is like '{\"config_key\":\"config_value\"}'") + vllm_group.add_argument("--kv-transfer-config", + **vllm_kwargs["kv_transfer_config"]) + vllm_group.add_argument('--kv-events-config', + **vllm_kwargs["kv_events_config"]) + vllm_group.add_argument("--compilation-config", "-O", + **vllm_kwargs["compilation_config"]) + vllm_group.add_argument("--additional-config", + **vllm_kwargs["additional_config"]) # Other arguments parser.add_argument('--use-v2-block-manager', @@ -972,7 +979,7 @@ def create_engine_config( from vllm.platforms import current_platform current_platform.pre_register_and_update() - device_config = DeviceConfig(device=self.device) + device_config = DeviceConfig(device=current_platform.device_type) model_config = self.create_model_config() # * If VLLM_USE_V1 is unset, we enable V1 for "supported features" @@ -1000,6 +1007,17 @@ def create_engine_config( assert self.enable_chunked_prefill is not None + if envs.VLLM_ATTENTION_BACKEND in [STR_DUAL_CHUNK_FLASH_ATTN_VAL]: + assert self.enforce_eager, ( + "Cuda graph is not supported with DualChunkFlashAttention. " + "To run the model in eager mode, set 'enforce_eager=True' " + "or use '--enforce-eager' in the CLI.") + assert current_platform.is_cuda(), ( + "DualChunkFlashAttention is only supported on CUDA platform.") + assert not use_v1, ( + "DualChunkFlashAttention is not supported on V1 engine. " + "To run the model in V0 engine, try set 'VLLM_USE_V1=0'") + cache_config = CacheConfig( block_size=self.block_size, gpu_memory_utilization=self.gpu_memory_utilization, @@ -1025,10 +1043,30 @@ def create_engine_config( # but we should not do this here. placement_group = ray.util.get_current_placement_group() + # Local DP size defaults to global DP size if not set. + data_parallel_size_local = self.data_parallel_size if ( + self.data_parallel_size_local + is None) else self.data_parallel_size_local + + # DP address, used in multi-node case for torch distributed group + # and ZMQ sockets. 
+ data_parallel_address = self.data_parallel_address if ( + self.data_parallel_address + is not None) else ParallelConfig.data_parallel_master_ip + + # This port is only used when there are remote data parallel engines, + # otherwise the local IPC transport is used. + data_parallel_rpc_port = self.data_parallel_rpc_port if ( + self.data_parallel_rpc_port + is not None) else ParallelConfig.data_parallel_rpc_port + parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, tensor_parallel_size=self.tensor_parallel_size, data_parallel_size=self.data_parallel_size, + data_parallel_size_local=data_parallel_size_local, + data_parallel_master_ip=data_parallel_address, + data_parallel_rpc_port=data_parallel_rpc_port, enable_expert_parallel=self.enable_expert_parallel, max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, @@ -1046,7 +1084,7 @@ def create_engine_config( disable_log_stats=self.disable_log_stats, ) - # Reminder: Please update docs/source/features/compatibility_matrix.md + # Reminder: Please update docs/features/compatibility_matrix.md # If the feature combo become valid if self.num_scheduler_steps > 1: if speculative_config is not None: @@ -1157,8 +1195,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: ############################################################# # Unsupported Feature Flags on V1. 
- if (self.load_format == LoadFormat.TENSORIZER.value - or self.load_format == LoadFormat.SHARDED_STATE.value): + if self.load_format == LoadFormat.SHARDED_STATE.value: _raise_or_fallback( feature_name=f"--load_format {self.load_format}", recommend_to_remove=False) @@ -1224,7 +1261,9 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: and not envs.is_set("VLLM_ATTENTION_BACKEND") ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1" supported = False - if fp8_attention and will_use_fa: + if current_platform.is_rocm(): + supported = True + elif fp8_attention and will_use_fa: from vllm.attention.utils.fa_utils import ( flash_attn_supports_fp8) supported = flash_attn_supports_fp8() @@ -1252,14 +1291,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: recommend_to_remove=False) return False - # Some quantization is not compatible with torch.compile. - V1_UNSUPPORTED_QUANT = ["gguf"] - if model_config.quantization in V1_UNSUPPORTED_QUANT: - _raise_or_fallback( - feature_name=f"--quantization {model_config.quantization}", - recommend_to_remove=False) - return False - # No Embedding Models so far. if model_config.task not in ["generate"]: _raise_or_fallback(feature_name=f"--task {model_config.task}", @@ -1287,22 +1318,25 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: recommend_to_remove=False) return False - # Only Ngram speculative decoding so far. + # V1 supports N-gram, Medusa, and Eagle speculative decoding. is_ngram_enabled = False is_eagle_enabled = False + is_medusa_enabled = False if self.speculative_config is not None: # This is supported but experimental (handled below). 
speculative_method = self.speculative_config.get("method") if speculative_method: if speculative_method in ("ngram", "[ngram]"): is_ngram_enabled = True - elif speculative_method in ("eagle", "eagle3"): + elif speculative_method == "medusa": + is_medusa_enabled = True + elif speculative_method in ("eagle", "eagle3", "deepseek_mtp"): is_eagle_enabled = True else: speculative_model = self.speculative_config.get("model") if speculative_model in ("ngram", "[ngram]"): is_ngram_enabled = True - if not (is_ngram_enabled or is_eagle_enabled): + if not (is_ngram_enabled or is_eagle_enabled or is_medusa_enabled): # Other speculative decoding methods are not supported yet. _raise_or_fallback(feature_name="Speculative Decoding", recommend_to_remove=False) @@ -1319,6 +1353,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: "FLASHMLA", "FLASHINFER", "FLASHINFER_VLLM_V1", + "ROCM_AITER_MLA", ] if (envs.is_set("VLLM_ATTENTION_BACKEND") and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS): @@ -1341,20 +1376,13 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: return False if (self.pipeline_parallel_size > 1 - and self.distributed_executor_backend not in ["ray", "mp"]): + and self.distributed_executor_backend + not in ("ray", "mp", "external_launcher")): name = "Pipeline Parallelism without Ray distributed executor " \ - "or multiprocessing executor" + "or multiprocessing executor or external launcher" _raise_or_fallback(feature_name=name, recommend_to_remove=False) return False - # ngram is supported on V1, but off by default for now. - if is_ngram_enabled and _warn_or_fallback("ngram"): - return False - - # Eagle is under development, so we don't support it yet. - if is_eagle_enabled and _warn_or_fallback("Eagle"): - return False - # Non-[CUDA, TPU] may be supported on V1, but off by default for now. 
v0_hardware = not any( (current_platform.is_cuda(), current_platform.is_tpu())) @@ -1456,11 +1484,15 @@ def _set_default_args_v1(self, usage_context: UsageContext) -> None: from vllm.platforms import current_platform try: device_memory = current_platform.get_device_total_memory() + device_name = current_platform.get_device_name().lower() except Exception: # This is only used to set default_max_num_batched_tokens device_memory = 0 - if device_memory >= 70 * GiB_bytes: + # NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces + # throughput, see PR #17885 for more details. + # So here we do an extra device name check to prevent such regression. + if device_memory >= 70 * GiB_bytes and "a100" not in device_name: # For GPUs like H100 and MI300x, use larger default values. default_max_num_batched_tokens = { UsageContext.LLM_CLASS: 16384, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 37bb12d44287..19b219b674f3 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -475,7 +475,8 @@ async def add_request_async( *, inputs: Optional[PromptType] = None, # DEPRECATED ) -> None: - """Async version of {meth}`add_request`.""" + """Async version of + [`add_request`][vllm.engine.llm_engine.LLMEngine.add_request].""" if inputs is not None: prompt = inputs assert prompt is not None and params is not None @@ -582,20 +583,21 @@ async def build_guided_decoding_logits_processor_async( class AsyncLLMEngine(EngineClient): - """An asynchronous wrapper for {class}`LLMEngine`. + """An asynchronous wrapper for [`LLMEngine`][vllm.LLMEngine]. - This class is used to wrap the {class}`LLMEngine` class to make it - asynchronous. It uses asyncio to create a background loop that keeps - processing incoming requests. The {class}`LLMEngine` is kicked by the - generate method when there are requests in the waiting queue. The generate - method yields the outputs from the {class}`LLMEngine` to the caller. 
+ This class is used to wrap the [`LLMEngine`][vllm.LLMEngine] class to + make it asynchronous. It uses asyncio to create a background loop that keeps + processing incoming requests. The [`LLMEngine`][vllm.LLMEngine] is kicked + by the generate method when there are requests in the waiting queue. The + generate method yields the outputs from the [`LLMEngine`][vllm.LLMEngine] + to the caller. Args: log_requests: Whether to log the requests. start_engine_loop: If True, the background task to run the engine will be automatically started in the generate call. - *args: Arguments for {class}`LLMEngine`. - **kwargs: Arguments for {class}`LLMEngine`. + *args: Arguments for [`LLMEngine`][vllm.LLMEngine]. + **kwargs: Arguments for [`LLMEngine`][vllm.LLMEngine]. """ _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine @@ -985,8 +987,9 @@ async def generate( from the LLMEngine to the caller. Args: - prompt: The prompt to the LLM. See {class}`~vllm.inputs.PromptType` - for more details about the format of each input. + prompt: The prompt to the LLM. See + [`PromptType`][vllm.inputs.PromptType] for more details about + the format of each input. sampling_params: The sampling parameters of the request. request_id: The unique id of the request. lora_request: LoRA request to use for generation, if any. @@ -1003,7 +1006,7 @@ async def generate( Details: - If the engine is not running, start the background loop, which iteratively invokes - {meth}`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step` + [`engine_step`][vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step] to process the waiting requests. - Add the request to the engine's `RequestTracker`. On the next background loop, this request will be sent to @@ -1075,8 +1078,9 @@ async def encode( from the LLMEngine to the caller. Args: - prompt: The prompt to the LLM. See {class}`~vllm.inputs.PromptType` - for more details about the format of each input. + prompt: The prompt to the LLM. 
See + [`PromptType`][vllm.inputs.PromptType] for more details about + the format of each input. pooling_params: The pooling parameters of the request. request_id: The unique id of the request. lora_request: LoRA request to use for generation, if any. @@ -1089,15 +1093,15 @@ async def encode( for the request. Details: - - If the engine is not running, start the background loop, - which iteratively invokes - {meth}`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step` - to process the waiting requests. - - Add the request to the engine's `RequestTracker`. - On the next background loop, this request will be sent to - the underlying engine. - Also, a corresponding `AsyncStream` will be created. - - Wait for the request outputs from `AsyncStream` and yield them. + - If the engine is not running, start the background loop, + which iteratively invokes + [`vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`][] + to process the waiting requests. + - Add the request to the engine's `RequestTracker`. + On the next background loop, this request will be sent to + the underlying engine. + Also, a corresponding `AsyncStream` will be created. + - Wait for the request outputs from `AsyncStream` and yield them. Example: ``` @@ -1232,6 +1236,9 @@ async def start_profile(self) -> None: async def stop_profile(self) -> None: self.engine.stop_profile() + async def reset_mm_cache(self) -> None: + self.engine.reset_mm_cache() + async def reset_prefix_cache(self, device: Optional[Device] = None) -> None: self.engine.reset_prefix_cache(device) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index bed696d3dc00..ff33d566ab68 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -130,26 +130,16 @@ class LLMEngine: iteration-level scheduling and efficient memory management to maximize the serving throughput. 
- The {class}`~vllm.LLM` class wraps this class for offline batched inference - and the {class}`AsyncLLMEngine` class wraps this class for online serving. + The [`LLM`][vllm.LLM] class wraps this class for offline batched inference + and the [`AsyncLLMEngine`][vllm.engine.async_llm_engine.AsyncLLMEngine] + class wraps this class for online serving. - The config arguments are derived from {class}`~vllm.EngineArgs`. (See - {ref}`engine-args`) + The config arguments are derived from [`EngineArgs`][vllm.EngineArgs]. Args: - model_config: The configuration related to the LLM model. - cache_config: The configuration related to the KV cache memory - management. - parallel_config: The configuration related to distributed execution. - scheduler_config: The configuration related to the request scheduler. - device_config: The configuration related to the device. - lora_config (Optional): The configuration related to serving multi-LoRA. - speculative_config (Optional): The configuration related to speculative - decoding. + vllm_config: The configuration for initializing and running vLLM. executor_class: The model executor class for managing distributed execution. - prompt_adapter_config (Optional): The configuration related to serving - prompt adapters. log_stats: Whether to log statistics. usage_context: Specified entry point, used for usage info collection. """ @@ -409,6 +399,9 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: # the next step without re-scheduling. self._skip_scheduling_next_step = False + # Don't keep the dummy data in memory + self.reset_mm_cache() + def _initialize_kv_caches(self) -> None: """Initialize the KV cache in the worker(s). @@ -692,11 +685,12 @@ def add_request( Args: request_id: The unique ID of the request. - prompt: The prompt to the LLM. See {class}`~vllm.inputs.PromptType` + prompt: The prompt to the LLM. See + [PromptType][vllm.inputs.PromptType] for more details about the format of each input. 
params: Parameters for sampling or pooling. - {class}`~vllm.SamplingParams` for text generation. - {class}`~vllm.PoolingParams` for pooling. + [SamplingParams][vllm.SamplingParams] for text generation. + [PoolingParams][vllm.PoolingParams] for pooling. arrival_time: The arrival time of the request. If None, we use the current monotonic time. lora_request: The LoRA request to add. @@ -708,10 +702,11 @@ def add_request( Details: - Set arrival_time to the current time if it is None. - Set prompt_token_ids to the encoded prompt if it is None. - - Create `n` number of {class}`~vllm.Sequence` objects. - - Create a {class}`~vllm.SequenceGroup` object - from the list of {class}`~vllm.Sequence`. - - Add the {class}`~vllm.SequenceGroup` object to the scheduler. + - Create `n` number of [Sequence][vllm.Sequence] objects. + - Create a [SequenceGroup][vllm.SequenceGroup] object + from the list of [Sequence][vllm.Sequence]. + - Add the [SequenceGroup][vllm.SequenceGroup] object to the + scheduler. Example: >>> # initialize engine @@ -858,9 +853,7 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: request_id: The ID(s) of the request to abort. Details: - - Refer to the - {meth}`~vllm.core.scheduler.Scheduler.abort_seq_group` - from class {class}`~vllm.core.scheduler.Scheduler`. + - Refer to [vllm.core.scheduler.Scheduler.abort_seq_group][]. 
Example: >>> # initialize engine and add a request with request_id @@ -913,6 +906,10 @@ def has_unfinished_requests_for_virtual_engine( """ return self.scheduler[virtual_engine].has_unfinished_seqs() + def reset_mm_cache(self) -> bool: + """Reset the multi-modal cache.""" + return self.input_preprocessor.mm_registry.reset_processor_cache() + def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: """Reset prefix cache for all devices.""" @@ -1256,12 +1253,10 @@ def _advance_to_next_step( def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. - :::{figure} https://i.imgur.com/sv2HssD.png - :alt: Overview of the step function - :align: center - - Overview of the step function. - ::: + <figure markdown="span"> + ![Overview of the step function](https://i.imgur.com/sv2HssD.png) + <figcaption>Overview of the step function</figcaption> + </figure> Details: - Step 1: Schedules the sequences to be executed in the next @@ -1655,6 +1650,20 @@ def _get_stats(self, gpu_prefix_cache_hit_rate = self.scheduler[ 0].get_prefix_cache_hit_rate(Device.GPU) + # Exchange the usage and cache hit stats between gpu and cpu when + # running on cpu because the cpu_worker.py intentionally reports the + # number of cpu blocks as gpu blocks in favor of cache management. + if self.device_config.device_type == "cpu": + num_total_gpu, num_total_cpu = num_total_cpu, num_total_gpu + gpu_cache_usage_sys, cpu_cache_usage_sys = ( + cpu_cache_usage_sys, + gpu_cache_usage_sys, + ) + gpu_prefix_cache_hit_rate, cpu_prefix_cache_hit_rate = ( + cpu_prefix_cache_hit_rate, + gpu_prefix_cache_hit_rate, + ) + # Iteration stats num_prompt_tokens_iter = 0 num_generation_tokens_iter = 0 diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 033551d07c39..34b48f83b643 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -29,7 +29,7 @@ # to extract the metrics definitions. 
-# begin-metrics-definitions +# --8<-- [start:metrics-definitions] class Metrics: """ vLLM uses a multiprocessing-based frontend for the OpenAI server. @@ -293,7 +293,7 @@ def __init__(self, labelnames: List[str], vllm_config: VllmConfig): labelnames=labelnames)) -# end-metrics-definitions +# --8<-- [end:metrics-definitions] def _unregister_vllm_metrics(self) -> None: for collector in list(prometheus_client.REGISTRY._collector_to_names): diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index cafd8150bc01..af72c8e6b776 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -123,6 +123,10 @@ class RPCUProfileRequest(Enum): STOP_PROFILE = 2 +class RPCResetMultiModalCacheRequest(Enum): + RESET = 1 + + @dataclass class RPCResetPrefixCacheRequest: device: Device @@ -164,6 +168,7 @@ class RPCAdapterLoadedResponse: RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest, RPCUProfileRequest, RPCLoadAdapterRequest, + RPCResetMultiModalCacheRequest, RPCResetPrefixCacheRequest, RPCSleepRequest, RPCWakeUpRequest, RPCIsSleepingRequest] diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 505d3d06b3ca..18b7c187bdff 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -31,6 +31,7 @@ RPCIsSleepingResponse, RPCLoadAdapterRequest, RPCProcessRequest, + RPCResetMultiModalCacheRequest, RPCResetPrefixCacheRequest, RPCSleepRequest, RPCStartupRequest, RPCStartupResponse, @@ -491,8 +492,9 @@ def generate( from the LLMEngine to the caller. Args: - prompt: The prompt to the LLM. See {class}`~vllm.inputs.PromptType` - for more details about the format of each input. + prompt: The prompt to the LLM. See + [`PromptType`][vllm.inputs.PromptType] for more details about + the format of each input. sampling_params: The sampling parameters of the request. request_id: The unique id of the request. 
lora_request: LoRA request to use for generation, if any. @@ -560,8 +562,9 @@ def encode( from the LLMEngine to the caller. Args: - prompt: The prompt to the LLM. See {class}`~vllm.inputs.PromptType` - for more details about the format of each input. + prompt: The prompt to the LLM. See + [`PromptType`][vllm.inputs.PromptType] for more details about + the format of each input. pooling_params: The pooling parameters of the request. request_id: The unique id of the request. lora_request: LoRA request to use for generation, if any. @@ -687,6 +690,13 @@ async def stop_profile(self) -> None: await self._send_one_way_rpc_request( request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket) + async def reset_mm_cache(self) -> None: + """Reset the multi-modal cache""" + + await self._send_one_way_rpc_request( + request=RPCResetMultiModalCacheRequest.RESET, + socket=self.input_socket) + async def reset_prefix_cache(self, device: Optional[Device] = None) -> None: """Reset the prefix cache""" diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index a5dcf9e2d945..434cb4985562 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -22,6 +22,7 @@ RPCIsSleepingResponse, RPCLoadAdapterRequest, RPCProcessRequest, + RPCResetMultiModalCacheRequest, RPCResetPrefixCacheRequest, RPCSleepRequest, RPCStartupRequest, RPCStartupResponse, @@ -41,19 +42,22 @@ class MQLLMEngine: - """A multiprocessing wrapper for {class}`LLMEngine`. + """A multiprocessing wrapper for + [`LLMEngine`][vllm.engine.llm_engine.LLMEngine]. - This class is used to wrap the {class}`LLMEngine` class to enable use + This class is used to wrap the + [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] class to enable use in concurrnet manner. It runs a background loop and uses zeromq to receive new requests and stream outputs incrementally via ipc. 
- The {class}`LLMEngine` generate or encode process is kicked off when a new - RPCProcessRequest is received by the input_socket. + The [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] generate or encode + process is kicked off when a new RPCProcessRequest is received by the + input_socket. The self.engine_loop checks the input_socket for new requests, adds them to the LLMEngine if there are any, calls the internal - {class}`LLMEngine.step()`, and sends the RequestOutputs back over - the output_socket. + [`LLMEngine.step()`][vllm.engine.llm_engine.LLMEngine.step], and sends + the RequestOutputs back over the output_socket. If use_async_sockets is set, the logic associated with reading new requests from the socket and sending data to the socket is passed @@ -64,8 +68,8 @@ class MQLLMEngine: ipc_path: Base path for zeromq interprocess messaging use_async_sockets: Whether to make send/recv async with GPU log_requests: Whether to log the requests. - *args: Arguments for {class}`LLMEngine`. - **kwargs: Arguments for {class}`LLMEngine`. + *args: Arguments for [`LLMEngine`][vllm.engine.llm_engine.LLMEngine]. + **kwargs: Arguments for [`LLMEngine`][vllm.engine.llm_engine.LLMEngine]. 
""" def __init__(self, @@ -269,6 +273,8 @@ def handle_new_input(self): self.stop_profile() elif isinstance(request, RPCLoadAdapterRequest): self._handle_load_adapter_request(request) + elif isinstance(request, RPCResetMultiModalCacheRequest): + self.reset_mm_cache() elif isinstance(request, RPCResetPrefixCacheRequest): self.reset_prefix_cache() elif isinstance(request, RPCSleepRequest): @@ -409,6 +415,9 @@ def start_profile(self) -> None: def stop_profile(self) -> None: self.engine.stop_profile() + def reset_mm_cache(self) -> bool: + return self.engine.reset_mm_cache() + def reset_prefix_cache(self) -> bool: return self.engine.reset_prefix_cache() diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 4cfb22c5a750..110f84a65efc 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -56,8 +56,11 @@ def process_prompt_logprob(self, seq_group: SequenceGroup, scheduled computation. Args: - seq_group: the outputs are associated with this {class}`SequenceGroup` - outputs: the {class}`SequenceGroupOutput`s for all scheduler steps + seq_group: the outputs are associated with this + [`SequenceGroup`][vllm.sequence.SequenceGroup] + outputs: the + [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput]s + for all scheduler steps """ for output in outputs: # Concatenate single-step prompt logprob processing results. @@ -67,7 +70,7 @@ def process_prompt_logprob(self, seq_group: SequenceGroup, @staticmethod @functools.lru_cache def _log_prompt_logprob_unsupported_warning_once(): - # Reminder: Please update docs/source/features/compatibility_matrix.md + # Reminder: Please update docs/features/compatibility_matrix.md # If the feature combo become valid logger.warning( "Prompt logprob is not supported by multi step workers. 
" diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index ea4b71a5b9cd..e88f119c8742 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -19,17 +19,21 @@ def single_step_process_prompt_logprob( sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup, output: CompletionSequenceGroupOutput) -> None: - """Process prompt logprobs associated with the {class}`SequenceGroupOutput` - for a given step. + """Process prompt logprobs associated with the + [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput] for a given step. Do nothing if the output has no prompt logprobs. Account for the fact that transformers do not compute first-token logprobs. Args: - sg_output_proc: {class}`SequenceGroupOutputProcessor` instance - seq_group: the output is associated with this {class}`SequenceGroup` - output: the {class}`SequenceGroupOutput` for a single scheduler step + sg_output_proc: + [`SequenceGroupOutputProcessor`][vllm.engine.output_processor.interfaces.SequenceGroupOutputProcessor] + instance + seq_group: the output is associated with this + [`SequenceGroup`][vllm.sequence.SequenceGroup] + output: the [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput] + for a single scheduler step """ prompt_logprobs = output.prompt_logprobs @@ -103,8 +107,11 @@ def process_prompt_logprob(self, seq_group: SequenceGroup, scheduled computation. Args: - seq_group: the output is associated with this {class}`SequenceGroup` - outputs: the {class}`SequenceGroupOutput` for a single scheduler step + seq_group: the output is associated with this + [`SequenceGroup`][vllm.sequence.SequenceGroup] + outputs: the + [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput] + for a single scheduler step """ assert len(outputs) == 1, "Single step should only have 1 output." 
output = outputs[0] diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index e9350612ee57..a837a2d288a9 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -278,6 +278,11 @@ async def stop_profile(self) -> None: """Start profiling the engine""" ... + @abstractmethod + async def reset_mm_cache(self) -> None: + """Reset the multi-modal cache""" + ... + @abstractmethod async def reset_prefix_cache(self, device: Optional[Device] = None) -> None: diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 38fe98572178..ec1b327da905 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -44,6 +44,7 @@ # yapf: enable from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.utils import deprecate_kwargs, random_uuid logger = init_logger(__name__) @@ -328,11 +329,17 @@ def resolve_mistral_chat_template( "so it will be ignored.") return None +@deprecate_kwargs( + "trust_remote_code", + additional_message="Please use `model_config.trust_remote_code` instead.", +) def resolve_hf_chat_template( - model_config: ModelConfig, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], chat_template: Optional[str], tools: Optional[list[dict[str, Any]]], + *, + model_config: ModelConfig, + trust_remote_code: Optional[bool] = None, ) -> Optional[str]: # 1st priority: The given chat template if chat_template is not None: @@ -348,11 +355,11 @@ def resolve_hf_chat_template( trust_remote_code=model_config.trust_remote_code, ) if isinstance(processor, ProcessorMixin) and \ + hasattr(processor, 'chat_template') and \ processor.chat_template is not None: return processor.chat_template except Exception: - logger.debug("Failed to load AutoProcessor chat template for %s", - tokenizer.name_or_path, exc_info=True) + logger.debug("Failed to load AutoProcessor chat template for %s", tokenizer.name_or_path, 
exc_info=True) # noqa: E501 # 3rd priority: AutoTokenizer chat template try: @@ -378,18 +385,18 @@ def resolve_hf_chat_template( def _resolve_chat_template_content_format( - model_config: ModelConfig, chat_template: Optional[str], tools: Optional[list[dict[str, Any]]], - given_format: ChatTemplateContentFormatOption, tokenizer: AnyTokenizer, + *, + model_config: ModelConfig, ) -> _ChatTemplateContentFormat: if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): hf_chat_template = resolve_hf_chat_template( - model_config, tokenizer, chat_template=chat_template, tools=tools, + model_config=model_config, ) else: hf_chat_template = None @@ -400,7 +407,7 @@ def _resolve_chat_template_content_format( detected_format = ("string" if jinja_text is None else _detect_content_format(jinja_text, default="string")) - return detected_format if given_format == "auto" else given_format + return detected_format @lru_cache @@ -427,19 +434,24 @@ def _log_chat_template_content_format( ) +@deprecate_kwargs( + "trust_remote_code", + additional_message="Please use `model_config.trust_remote_code` instead.", +) def resolve_chat_template_content_format( - model_config: ModelConfig, chat_template: Optional[str], tools: Optional[list[dict[str, Any]]], given_format: ChatTemplateContentFormatOption, tokenizer: AnyTokenizer, + *, + model_config: ModelConfig, + trust_remote_code: Optional[bool] = None, ) -> _ChatTemplateContentFormat: detected_format = _resolve_chat_template_content_format( - model_config, chat_template, tools, - given_format, tokenizer, + model_config=model_config, ) _log_chat_template_content_format( @@ -448,7 +460,8 @@ def resolve_chat_template_content_format( detected_format=detected_format, ) - return detected_format + return detected_format if given_format == "auto" else given_format + ModalityStr = Literal["image", "audio", "video", "image_embeds"] @@ -512,7 +525,7 @@ def _placeholder_str(self, modality: ModalityStr, hf_config.image_token_index) if 
model_type in ("aya_vision", "chameleon", "deepseek_vl_v2", - "internvl_chat", "ovis2", "skywork_chat", + "internvl_chat", "ovis", "skywork_chat", "NVLM_D", "h2ovl_chat", "idefics3", "smolvlm"): return "<image>" if model_type in ("mllama", "llama4"): @@ -543,6 +556,8 @@ def _placeholder_str(self, modality: ModalityStr, return "(<audio>./</audio>)" raise TypeError(f"Unknown model type: {model_type}") elif modality == "video": + if model_type == "internvl_chat": + return "<video>" if model_type in ("qwen2_vl", "qwen2_5_vl"): return "<|vision_start|><|video_pad|><|vision_end|>" if model_type == "qwen2_5_omni": @@ -1190,21 +1205,27 @@ def parse_chat_messages_futures( return conversation, mm_tracker.all_mm_data() +@deprecate_kwargs( + "trust_remote_code", + additional_message="Please use `model_config.trust_remote_code` instead.", +) def apply_hf_chat_template( - model_config: ModelConfig, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], conversation: list[ConversationMessage], chat_template: Optional[str], tools: Optional[list[dict[str, Any]]], *, + model_config: ModelConfig, tokenize: bool = False, # Different from HF's default + # Deprecated, explicitly capture here so it doesn't slip into kwargs. 
+ trust_remote_code: Optional[bool] = None, **kwargs: Any, ) -> str: hf_chat_template = resolve_hf_chat_template( - model_config, tokenizer, chat_template=chat_template, tools=tools, + model_config=model_config, ) if hf_chat_template is None: @@ -1272,3 +1293,6 @@ def apply_mistral_chat_template( "An error occurred in `mistral_common` while applying chat " "template") raise ValueError from e + +def random_tool_call_id() -> str: + return f"chatcmpl-tool-{random_uuid()}" diff --git a/vllm/entrypoints/cli/collect_env.py b/vllm/entrypoints/cli/collect_env.py index d5f9f7e729f0..810ecfdf71c3 100644 --- a/vllm/entrypoints/cli/collect_env.py +++ b/vllm/entrypoints/cli/collect_env.py @@ -4,12 +4,11 @@ from vllm.collect_env import main as collect_env_main from vllm.entrypoints.cli.types import CLISubcommand -from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.utils import FlexibleArgumentParser class CollectEnvSubcommand(CLISubcommand): - """The `serve` subcommand for the vLLM CLI. """ + """The `collect-env` subcommand for the vLLM CLI. 
""" def __init__(self): self.name = "collect-env" @@ -23,12 +22,12 @@ def cmd(args: argparse.Namespace) -> None: def subparser_init( self, subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: - serve_parser = subparsers.add_parser( + collect_env_parser = subparsers.add_parser( "collect-env", help="Start collecting environment information.", description="Start collecting environment information.", usage="vllm collect-env") - return make_arg_parser(serve_parser) + return collect_env_parser def cmd_init() -> list[CLISubcommand]: diff --git a/vllm/entrypoints/cli/main.py b/vllm/entrypoints/cli/main.py index b7c1afce7118..6676c294c81c 100644 --- a/vllm/entrypoints/cli/main.py +++ b/vllm/entrypoints/cli/main.py @@ -9,7 +9,7 @@ import vllm.entrypoints.cli.openai import vllm.entrypoints.cli.serve import vllm.version -from vllm.entrypoints.utils import cli_env_setup +from vllm.entrypoints.utils import VLLM_SERVE_PARSER_EPILOG, cli_env_setup from vllm.utils import FlexibleArgumentParser CMD_MODULES = [ @@ -32,7 +32,10 @@ def signal_handler(sig, frame): def main(): cli_env_setup() - parser = FlexibleArgumentParser(description="vLLM CLI") + parser = FlexibleArgumentParser( + description="vLLM CLI", + epilog=VLLM_SERVE_PARSER_EPILOG, + ) parser.add_argument('-v', '--version', action='version', diff --git a/vllm/entrypoints/cli/openai.py b/vllm/entrypoints/cli/openai.py index 1d1bba1d49ce..215fcf3c3e44 100644 --- a/vllm/entrypoints/cli/openai.py +++ b/vllm/entrypoints/cli/openai.py @@ -101,9 +101,18 @@ def cmd(args: argparse.Namespace) -> None: model_name, client = _interactive_cli(args) system_prompt = args.system_prompt conversation: list[ChatCompletionMessageParam] = [] + if system_prompt is not None: conversation.append({"role": "system", "content": system_prompt}) + if args.quick: + conversation.append({"role": "user", "content": args.quick}) + + chat_completion = client.chat.completions.create( + model=model_name, messages=conversation) + 
print(chat_completion.choices[0].message.content) + return + print("Please enter a message for the chat model:") while True: try: @@ -136,6 +145,12 @@ def subparser_init( default=None, help=("The system prompt to be added to the chat template, " "used for models that support system prompts.")) + chat_parser.add_argument("-q", + "--quick", + type=str, + metavar="MESSAGE", + help=("Send a single prompt as MESSAGE " + "and print the response, then exit.")) return chat_parser @@ -149,6 +164,13 @@ def __init__(self): @staticmethod def cmd(args: argparse.Namespace) -> None: model_name, client = _interactive_cli(args) + + if args.quick: + completion = client.completions.create(model=model_name, + prompt=args.quick) + print(completion.choices[0].text) + return + print("Please enter prompt to complete:") while True: input_prompt = input("> ") @@ -168,6 +190,13 @@ def subparser_init( "via the running API server."), usage="vllm complete [options]") _add_query_options(complete_parser) + complete_parser.add_argument( + "-q", + "--quick", + type=str, + metavar="PROMPT", + help= + "Send a single prompt and print the completion output, then exit.") return complete_parser diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 5c8781b50d2c..957fec290bf2 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -1,14 +1,26 @@ # SPDX-License-Identifier: Apache-2.0 import argparse +import signal import uvloop +import vllm.envs as envs +from vllm import AsyncEngineArgs from vllm.entrypoints.cli.types import CLISubcommand from vllm.entrypoints.openai.api_server import run_server from vllm.entrypoints.openai.cli_args import (make_arg_parser, validate_parsed_serve_args) -from vllm.utils import FlexibleArgumentParser +from vllm.entrypoints.utils import (VLLM_SERVE_PARSER_EPILOG, + show_filtered_argument_or_group_from_help) +from vllm.logger import init_logger +from vllm.usage.usage_lib import UsageContext +from vllm.utils import 
FlexibleArgumentParser, get_tcp_uri +from vllm.v1.engine.core import EngineCoreProc +from vllm.v1.engine.core_client import CoreEngineProcManager +from vllm.v1.executor.abstract import Executor + +logger = init_logger(__name__) class ServeSubcommand(CLISubcommand): @@ -24,7 +36,10 @@ def cmd(args: argparse.Namespace) -> None: if hasattr(args, 'model_tag') and args.model_tag is not None: args.model = args.model_tag - uvloop.run(run_server(args)) + if args.headless: + run_headless(args) + else: + uvloop.run(run_server(args)) def validate(self, args: argparse.Namespace) -> None: validate_parsed_serve_args(args) @@ -42,6 +57,18 @@ def subparser_init( nargs='?', help="The model tag to serve " "(optional if specified in config)") + serve_parser.add_argument( + "--headless", + action='store_true', + default=False, + help="Run in headless mode. See multi-node data parallel " + "documentation for more details.") + serve_parser.add_argument( + '--data-parallel-start-rank', + '-dpr', + type=int, + default=0, + help='Starting data parallel rank for secondary nodes.') serve_parser.add_argument( "--config", type=str, @@ -52,8 +79,63 @@ def subparser_init( "https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#cli-reference" ) - return make_arg_parser(serve_parser) + serve_parser = make_arg_parser(serve_parser) + show_filtered_argument_or_group_from_help(serve_parser) + serve_parser.epilog = VLLM_SERVE_PARSER_EPILOG + return serve_parser def cmd_init() -> list[CLISubcommand]: return [ServeSubcommand()] + + +def run_headless(args: argparse.Namespace): + + # Create the EngineConfig. 
+ engine_args = AsyncEngineArgs.from_cli_args(args) + usage_context = UsageContext.OPENAI_API_SERVER + vllm_config = engine_args.create_engine_config(usage_context=usage_context) + + if not envs.VLLM_USE_V1: + raise RuntimeError("Headless mode is only supported for V1") + + parallel_config = vllm_config.parallel_config + local_engine_count = parallel_config.data_parallel_size_local + host = parallel_config.data_parallel_master_ip + port = engine_args.data_parallel_rpc_port # add to config too + input_address = get_tcp_uri(host, port) + + if local_engine_count <= 0: + raise RuntimeError("data_parallel_size_local must be > 0 in " + "headless mode") + + # Catch SIGTERM and SIGINT to allow graceful shutdown. + def signal_handler(signum, frame): + logger.debug("Received %d signal.", signum) + raise SystemExit + + signal.signal(signal.SIGTERM, signal_handler) + signal.signal(signal.SIGINT, signal_handler) + + logger.info( + "Launching %d data parallel engine(s) in headless mode, " + "with head node address %s.", local_engine_count, input_address) + + # Create the engines. 
+ engine_manager = CoreEngineProcManager( + target_fn=EngineCoreProc.run_engine_core, + local_engine_count=local_engine_count, + start_index=args.data_parallel_start_rank, + local_start_index=0, + vllm_config=vllm_config, + on_head_node=False, + input_address=input_address, + executor_class=Executor.get_class(vllm_config), + log_stats=not engine_args.disable_log_stats, + ) + + try: + engine_manager.join_first() + finally: + logger.info("Shutting down.") + engine_manager.close() diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 72ad79bd2df2..59cc44eb0e18 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -4,7 +4,8 @@ import warnings from collections.abc import Sequence from contextlib import contextmanager -from typing import Any, Callable, ClassVar, Optional, Union, cast, overload +from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union, + cast, overload) import cloudpickle import torch.nn as nn @@ -13,7 +14,8 @@ from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, BeamSearchSequence, get_beam_search_score) -from vllm.config import CompilationConfig, ModelDType, TokenizerMode +from vllm.config import (CompilationConfig, ModelDType, TokenizerMode, + is_init_field) from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig, TaskOption) from vllm.engine.llm_engine import LLMEngine @@ -46,6 +48,9 @@ from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs, is_list_of) +if TYPE_CHECKING: + from vllm.v1.metrics.reader import Metric + logger = init_logger(__name__) _R = TypeVar("_R", default=Any) @@ -115,7 +120,8 @@ class LLM: to eager mode. Additionally for encoder-decoder models, if the sequence length of the encoder input is larger than this, we fall back to the eager mode. - disable_custom_all_reduce: See {class}`~vllm.config.ParallelConfig` + disable_custom_all_reduce: See + [ParallelConfig][vllm.config.ParallelConfig]. 
disable_async_output_proc: Disable async output processing. This may result in lower performance. hf_token: The token to use as HTTP bearer authorization for remote files @@ -127,13 +133,11 @@ class LLM: compilation_config: Either an integer or a dictionary. If it is an integer, it is used as the level of compilation optimization. If it is a dictionary, it can specify the full compilation configuration. - **kwargs: Arguments for {class}`~vllm.EngineArgs`. (See - {ref}`engine-args`) + **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs]. - :::{note} - This class is intended to be used for offline inference. For online - serving, use the {class}`~vllm.AsyncLLMEngine` class instead. - ::: + Note: + This class is intended to be used for offline inference. For online + serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead. """ DEPRECATE_LEGACY: ClassVar[bool] = True @@ -142,7 +146,7 @@ class LLM: DEPRECATE_INIT_POSARGS: ClassVar[bool] = True """ A flag to toggle whether to deprecate positional arguments in - {meth}`LLM.__init__`. + [LLM.__init__][]. """ @classmethod @@ -204,9 +208,13 @@ def __init__( kwargs["worker_cls"] = cloudpickle.dumps(worker_cls) if compilation_config is not None: - if isinstance(compilation_config, (int, dict)): - compilation_config_instance = CompilationConfig.from_cli( - str(compilation_config)) + if isinstance(compilation_config, int): + compilation_config_instance = CompilationConfig( + level=compilation_config) + elif isinstance(compilation_config, dict): + predicate = lambda x: is_init_field(CompilationConfig, x[0]) + compilation_config_instance = CompilationConfig( + **dict(filter(predicate, compilation_config.items()))) else: compilation_config_instance = compilation_config else: @@ -399,7 +407,7 @@ def generate( Args: prompts: The prompts to the LLM. You may pass a sequence of prompts - for batch inference. See {class}`~vllm.inputs.PromptType` + for batch inference. 
See [PromptType][vllm.inputs.PromptType] for more details about the format of each prompts. sampling_params: The sampling parameters for text generation. If None, we use the default sampling parameters. @@ -417,11 +425,10 @@ def generate( A list of `RequestOutput` objects containing the generated completions in the same order as the input prompts. - :::{note} - Using `prompts` and `prompt_token_ids` as keyword parameters is - considered legacy and may be deprecated in the future. You should - instead pass them via the `inputs` parameter. - ::: + Note: + Using `prompts` and `prompt_token_ids` as keyword parameters is + considered legacy and may be deprecated in the future. You should + instead pass them via the `inputs` parameter. """ runner_type = self.llm_engine.model_config.runner_type if runner_type not in ["generate", "transcription"]: @@ -490,17 +497,16 @@ def collective_rpc(self, `self` argument, in addition to the arguments passed in `args` and `kwargs`. The `self` argument will be the worker object. timeout: Maximum time in seconds to wait for execution. Raises a - {exc}`TimeoutError` on timeout. `None` means wait indefinitely. + [`TimeoutError`][] on timeout. `None` means wait indefinitely. args: Positional arguments to pass to the worker method. kwargs: Keyword arguments to pass to the worker method. Returns: A list containing the results from each worker. - :::{note} - It is recommended to use this API to only pass control messages, - and set up data-plane communication to pass data. - ::: + Note: + It is recommended to use this API to only pass control messages, + and set up data-plane communication to pass data. """ return self.llm_engine.collective_rpc(method, timeout, args, kwargs) @@ -667,7 +673,7 @@ def chat( Generate responses for a chat conversation. 
The chat conversation is converted into a text prompt using the - tokenizer and calls the {meth}`generate` method to generate the + tokenizer and calls the [generate][] method to generate the responses. Multi-modal inputs can be passed in the same way you would pass them @@ -676,8 +682,8 @@ def chat( Args: messages: A list of conversations or a single conversation. - - Each conversation is represented as a list of messages. - - Each message is a dictionary with 'role' and 'content' keys. + - Each conversation is represented as a list of messages. + - Each message is a dictionary with 'role' and 'content' keys. sampling_params: The sampling parameters for text generation. If None, we use the default sampling parameters. When it @@ -687,27 +693,27 @@ def chat( use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. chat_template: The template to use for structuring the chat. - If not provided, the model's default chat template will be used. + If not provided, the model's default chat template will be used. chat_template_content_format: The format to render message content. - - "string" will render the content as a string. - Example: ``"Who are you?"`` - - "openai" will render the content as a list of dictionaries, - similar to OpenAI schema. - Example: ``[{"type": "text", "text": "Who are you?"}]`` + - "string" will render the content as a string. + Example: `"Who are you?"` + - "openai" will render the content as a list of dictionaries, + similar to OpenAI schema. + Example: `[{"type": "text", "text": "Who are you?"}]` add_generation_prompt: If True, adds a generation template to each message. continue_final_message: If True, continues the final message in the conversation instead of starting a new one. Cannot be - ``True`` if ``add_generation_prompt`` is also ``True``. + `True` if `add_generation_prompt` is also `True`. chat_template_kwargs: Additional kwargs to pass to the chat template. 
mm_processor_kwargs: Multimodal processor kwarg overrides for this chat request. Only used for offline requests. Returns: - A list of ``RequestOutput`` objects containing the generated + A list of `RequestOutput` objects containing the generated responses in the same order as the input messages. """ list_of_messages: list[list[ChatCompletionMessageParam]] @@ -726,11 +732,11 @@ def chat( tokenizer = self.get_tokenizer(lora_request) model_config = self.llm_engine.get_model_config() resolved_content_format = resolve_chat_template_content_format( - model_config, chat_template, tools, chat_template_content_format, tokenizer, + model_config=model_config, ) _chat_template_kwargs: dict[str, Any] = dict( @@ -762,9 +768,9 @@ def chat( ) else: prompt_str = apply_hf_chat_template( - model_config, - tokenizer, + tokenizer=tokenizer, conversation=conversation, + model_config=model_config, **_chat_template_kwargs, ) # Special tokens are already included in chat templates so @@ -906,7 +912,7 @@ def encode( Args: prompts: The prompts to the LLM. You may pass a sequence of prompts - for batch inference. See {class}`~vllm.inputs.PromptType` + for batch inference. See [PromptType][vllm.inputs.PromptType] for more details about the format of each prompts. pooling_params: The pooling parameters for pooling. If None, we use the default pooling parameters. @@ -919,11 +925,10 @@ def encode( A list of `PoolingRequestOutput` objects containing the pooled hidden states in the same order as the input prompts. - :::{note} - Using `prompts` and `prompt_token_ids` as keyword parameters is - considered legacy and may be deprecated in the future. You should - instead pass them via the `inputs` parameter. - ::: + Note: + Using `prompts` and `prompt_token_ids` as keyword parameters is + considered legacy and may be deprecated in the future. You should + instead pass them via the `inputs` parameter. 
""" runner_type = self.llm_engine.model_config.runner_type if runner_type != "pooling": @@ -996,7 +1001,7 @@ def embed( Args: prompts: The prompts to the LLM. You may pass a sequence of prompts - for batch inference. See {class}`~vllm.inputs.PromptType` + for batch inference. See [PromptType][vllm.inputs.PromptType] for more details about the format of each prompts. pooling_params: The pooling parameters for pooling. If None, we use the default pooling parameters. @@ -1006,7 +1011,7 @@ def embed( generation, if any. Returns: - A list of ``EmbeddingRequestOutput`` objects containing the + A list of `EmbeddingRequestOutput` objects containing the embedding vectors in the same order as the input prompts. """ if self.llm_engine.model_config.task != "embed": @@ -1040,7 +1045,7 @@ def classify( Args: prompts: The prompts to the LLM. You may pass a sequence of prompts - for batch inference. See {class}`~vllm.inputs.PromptType` + for batch inference. See [PromptType][vllm.inputs.PromptType] for more details about the format of each prompts. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. @@ -1048,7 +1053,7 @@ def classify( generation, if any. Returns: - A list of ``ClassificationRequestOutput`` objects containing the + A list of `ClassificationRequestOutput` objects containing the embedding vectors in the same order as the input prompts. """ if self.llm_engine.model_config.task != "classify": @@ -1158,11 +1163,11 @@ def score( lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> list[ScoringRequestOutput]: - """Generate similarity scores for all pairs ``<text,text_pair>``. + """Generate similarity scores for all pairs `<text,text_pair>`. - The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``. - In the ``1 - N`` case the ``text_1`` sentence will be replicated ``N`` - times to pair with the ``text_2`` sentences. 
+ The inputs can be `1 -> 1`, `1 -> N` or `N -> N`. + In the `1 - N` case the `text_1` sentence will be replicated `N` + times to pair with the `text_2` sentences. The input pairs are used to build a list of prompts for the cross encoder model. This class automatically batches the prompts, considering the memory constraint. For the best performance, put all @@ -1170,9 +1175,9 @@ def score( Args: text_1: can be a single prompt or a list of prompts, in which - case it has to have the same length as the ``text_2`` list + case it has to have the same length as the `text_2` list text_2: The texts to pair with the query to form the input - to the LLM. See {class}`~vllm.inputs.PromptType` for + to the LLM. See [PromptType][vllm.inputs.PromptType] for more details about the format of each prompts. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. @@ -1180,7 +1185,7 @@ def score( generation, if any. Returns: - A list of ``ScoringRequestOutput`` objects containing the + A list of `ScoringRequestOutput` objects containing the generated scores in the same order as the input prompts. """ runner_type = self.llm_engine.model_config.runner_type @@ -1281,18 +1286,32 @@ def sleep(self, level: int = 1): def wake_up(self, tags: Optional[list[str]] = None): """ - Wake up the engine from sleep mode. See the {meth}`sleep` method + Wake up the engine from sleep mode. See the [sleep][] method for more details. Args: tags: An optional list of tags to reallocate the engine memory for specific memory allocations. Values must be in - ("weights", "kv_cache",). If None, all memory is reallocated. + `("weights", "kv_cache")`. If None, all memory is reallocated. wake_up should be called with all tags (or None) before the engine is used again. """ self.llm_engine.wake_up(tags) + def get_metrics(self) -> list["Metric"]: + """Return a snapshot of aggregated metrics from Prometheus. 
+ + Returns: + A ``MetricSnapshot`` instance capturing the current state + of all aggregated metrics from Prometheus. + + Note: + This method is only available with the V1 LLM engine. + """ + from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine + assert isinstance(self.llm_engine, V1LLMEngine) + return self.llm_engine.get_metrics() + # LEGACY def _convert_v1_inputs( self, @@ -1301,27 +1320,25 @@ def _convert_v1_inputs( ): # skip_tokenizer_init is now checked in engine + if prompts is None and prompt_token_ids is None: + raise ValueError( + "Either prompts or prompt_token_ids must be provided.") + if prompts is not None and prompt_token_ids is not None \ + and len(prompts) != len(prompt_token_ids): + raise ValueError( + "The lengths of prompts and prompt_token_ids must be the same." + ) + if prompts is not None: prompts = [p["content"] for p in parse_and_batch_prompt(prompts)] if prompt_token_ids is not None: prompt_token_ids = [ p["content"] for p in parse_and_batch_prompt(prompt_token_ids) ] - - num_requests = None if prompts is not None: num_requests = len(prompts) - if prompt_token_ids is not None: - if (num_requests is not None - and num_requests != len(prompt_token_ids)): - raise ValueError("The lengths of prompts and prompt_token_ids " - "must be the same.") - + elif prompt_token_ids is not None: num_requests = len(prompt_token_ids) - if num_requests is None: - raise ValueError("Either prompts or prompt_token_ids must be " - "provided.") - parsed_prompts: list[PromptType] = [] for i in range(num_requests): item: PromptType diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py index ea5759152a22..d4655dd5e6ab 100644 --- a/vllm/entrypoints/logger.py +++ b/vllm/entrypoints/logger.py @@ -2,6 +2,8 @@ from typing import Optional, Union +import torch + from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams @@ -23,6 +25,7 @@ def log_inputs( request_id: str, prompt: 
Optional[str], prompt_token_ids: Optional[list[int]], + prompt_embeds: Optional[torch.Tensor], params: Optional[Union[SamplingParams, PoolingParams, BeamSearchParams]], lora_request: Optional[LoRARequest], @@ -39,6 +42,8 @@ def log_inputs( logger.info( "Received request %s: prompt: %r, " "params: %s, prompt_token_ids: %s, " + "prompt_embeds shape: %s, " "lora_request: %s, prompt_adapter_request: %s.", request_id, - prompt, params, prompt_token_ids, lora_request, - prompt_adapter_request) + prompt, params, prompt_token_ids, + prompt_embeds.shape if prompt_embeds is not None else None, + lora_request, prompt_adapter_request) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index e034eacb24ef..2da89b4f5944 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -7,7 +7,6 @@ import inspect import multiprocessing import os -import re import signal import socket import tempfile @@ -17,8 +16,11 @@ from contextlib import asynccontextmanager from functools import partial from http import HTTPStatus +from json import JSONDecodeError from typing import Annotated, Optional, Union +import prometheus_client +import regex as re import uvloop from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request from fastapi.exceptions import RequestValidationError @@ -48,6 +50,8 @@ # yapf: disable from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ChatCompletionResponse, + ClassificationRequest, + ClassificationResponse, CompletionRequest, CompletionResponse, DetokenizeRequest, @@ -71,6 +75,8 @@ UnloadLoRAAdapterRequest) # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_classification import ( + ServingClassification) from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from 
vllm.entrypoints.openai.serving_engine import OpenAIServing @@ -181,6 +187,10 @@ async def build_async_engine_client_from_engine_args( usage_context=usage_context, disable_log_requests=engine_args.disable_log_requests, disable_log_stats=engine_args.disable_log_stats) + + # Don't keep the dummy data in memory + await async_llm.reset_mm_cache() + yield async_llm finally: if async_llm: @@ -297,15 +307,18 @@ async def validate_json_request(raw_request: Request): content_type = raw_request.headers.get("content-type", "").lower() media_type = content_type.split(";", maxsplit=1)[0] if media_type != "application/json": - raise HTTPException( - status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE, - detail="Unsupported Media Type: Only 'application/json' is allowed" - ) + raise RequestValidationError(errors=[ + "Unsupported Media Type: Only 'application/json' is allowed" + ]) router = APIRouter() +class PrometheusResponse(Response): + media_type = prometheus_client.CONTENT_TYPE_LATEST + + def mount_metrics(app: FastAPI): # Lazy import for prometheus multiprocessing. # We need to set PROMETHEUS_MULTIPROC_DIR environment variable @@ -324,6 +337,10 @@ def mount_metrics(app: FastAPI): registry = CollectorRegistry() multiprocess.MultiProcessCollector(registry) + # `response_class=PrometheusResponse` is needed to return an HTTP response + # with header "Content-Type: text/plain; version=0.0.4; charset=utf-8" + # instead of the default "application/json" which is incorrect. 
+ # See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/163#issue-1296092364 Instrumentator( excluded_handlers=[ "/metrics", @@ -334,7 +351,7 @@ def mount_metrics(app: FastAPI): "/server_info", ], registry=registry, - ).add().instrument(app).expose(app) + ).add().instrument(app).expose(app, response_class=PrometheusResponse) # Add prometheus asgi middleware to route /metrics requests metrics_route = Mount("/metrics", make_asgi_app(registry=registry)) @@ -373,6 +390,10 @@ def score(request: Request) -> Optional[ServingScores]: return request.app.state.openai_serving_scores +def classify(request: Request) -> Optional[ServingClassification]: + return request.app.state.openai_serving_classification + + def rerank(request: Request) -> Optional[ServingScores]: return request.app.state.openai_serving_scores @@ -389,7 +410,7 @@ def engine_client(request: Request) -> EngineClient: return request.app.state.engine_client -@router.get("/health") +@router.get("/health", response_class=Response) async def health(raw_request: Request) -> Response: """Health check.""" await engine_client(raw_request).check_health() @@ -405,6 +426,7 @@ async def get_server_load_metrics(request: Request): # - /v1/audio/transcriptions # - /v1/embeddings # - /pooling + # - /classify # - /score # - /v1/score # - /rerank @@ -414,18 +436,42 @@ async def get_server_load_metrics(request: Request): content={'server_load': request.app.state.server_load_metrics}) -@router.api_route("/ping", methods=["GET", "POST"]) +@router.get("/ping", response_class=Response) +@router.post("/ping", response_class=Response) async def ping(raw_request: Request) -> Response: """Ping check. 
Endpoint required for SageMaker""" return await health(raw_request) -@router.post("/tokenize", dependencies=[Depends(validate_json_request)]) +@router.post("/tokenize", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.NOT_FOUND.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + HTTPStatus.NOT_IMPLEMENTED.value: { + "model": ErrorResponse + }, + }) @with_cancellation async def tokenize(request: TokenizeRequest, raw_request: Request): handler = tokenization(raw_request) - generator = await handler.create_tokenize(request, raw_request) + try: + generator = await handler.create_tokenize(request, raw_request) + except NotImplementedError as e: + raise HTTPException(status_code=HTTPStatus.NOT_IMPLEMENTED.value, + detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e + if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) @@ -435,12 +481,31 @@ async def tokenize(request: TokenizeRequest, raw_request: Request): assert_never(generator) -@router.post("/detokenize", dependencies=[Depends(validate_json_request)]) +@router.post("/detokenize", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.NOT_FOUND.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }) @with_cancellation async def detokenize(request: DetokenizeRequest, raw_request: Request): handler = tokenization(raw_request) - generator = await handler.create_detokenize(request, raw_request) + try: + generator = await handler.create_detokenize(request, raw_request) + except OverflowError as e: + raise RequestValidationError(errors=[str(e)]) from e + except Exception 
as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e + if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) @@ -465,7 +530,23 @@ async def show_version(): @router.post("/v1/chat/completions", - dependencies=[Depends(validate_json_request)]) + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: { + "content": { + "text/event-stream": {} + } + }, + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.NOT_FOUND.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + } + }) @with_cancellation @load_aware_call async def create_chat_completion(request: ChatCompletionRequest, @@ -487,7 +568,24 @@ async def create_chat_completion(request: ChatCompletionRequest, return StreamingResponse(content=generator, media_type="text/event-stream") -@router.post("/v1/completions", dependencies=[Depends(validate_json_request)]) +@router.post("/v1/completions", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: { + "content": { + "text/event-stream": {} + } + }, + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.NOT_FOUND.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }) @with_cancellation @load_aware_call async def create_completion(request: CompletionRequest, raw_request: Request): @@ -496,7 +594,15 @@ async def create_completion(request: CompletionRequest, raw_request: Request): return base(raw_request).create_error_response( message="The model does not support Completions API") - generator = await handler.create_completion(request, raw_request) + try: + generator = await handler.create_completion(request, raw_request) + except OverflowError as e: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, + detail=str(e)) from e + 
except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e + if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) @@ -506,7 +612,16 @@ async def create_completion(request: CompletionRequest, raw_request: Request): return StreamingResponse(content=generator, media_type="text/event-stream") -@router.post("/v1/embeddings", dependencies=[Depends(validate_json_request)]) +@router.post("/v1/embeddings", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }) @with_cancellation @load_aware_call async def create_embedding(request: EmbeddingRequest, raw_request: Request): @@ -553,7 +668,16 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): assert_never(generator) -@router.post("/pooling", dependencies=[Depends(validate_json_request)]) +@router.post("/pooling", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }) @with_cancellation @load_aware_call async def create_pooling(request: PoolingRequest, raw_request: Request): @@ -572,7 +696,37 @@ async def create_pooling(request: PoolingRequest, raw_request: Request): assert_never(generator) -@router.post("/score", dependencies=[Depends(validate_json_request)]) +@router.post("/classify", dependencies=[Depends(validate_json_request)]) +@with_cancellation +@load_aware_call +async def create_classify(request: ClassificationRequest, + raw_request: Request): + handler = classify(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Classification API") + + generator = await handler.create_classify(request, 
raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + + elif isinstance(generator, ClassificationResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.post("/score", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }) @with_cancellation @load_aware_call async def create_score(request: ScoreRequest, raw_request: Request): @@ -591,7 +745,16 @@ async def create_score(request: ScoreRequest, raw_request: Request): assert_never(generator) -@router.post("/v1/score", dependencies=[Depends(validate_json_request)]) +@router.post("/v1/score", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }) @with_cancellation @load_aware_call async def create_score_v1(request: ScoreRequest, raw_request: Request): @@ -602,12 +765,28 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request): return await create_score(request, raw_request) -@router.post("/v1/audio/transcriptions") +@router.post("/v1/audio/transcriptions", + responses={ + HTTPStatus.OK.value: { + "content": { + "text/event-stream": {} + } + }, + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.UNPROCESSABLE_ENTITY.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }) @with_cancellation @load_aware_call -async def create_transcriptions(request: Annotated[TranscriptionRequest, - Form()], - raw_request: Request): +async def create_transcriptions(raw_request: Request, + request: Annotated[TranscriptionRequest, + Form()]): handler = transcription(raw_request) if handler is None: return 
base(raw_request).create_error_response( @@ -627,7 +806,16 @@ async def create_transcriptions(request: Annotated[TranscriptionRequest, return StreamingResponse(content=generator, media_type="text/event-stream") -@router.post("/rerank", dependencies=[Depends(validate_json_request)]) +@router.post("/rerank", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }) @with_cancellation @load_aware_call async def do_rerank(request: RerankRequest, raw_request: Request): @@ -645,7 +833,16 @@ async def do_rerank(request: RerankRequest, raw_request: Request): assert_never(generator) -@router.post("/v1/rerank", dependencies=[Depends(validate_json_request)]) +@router.post("/v1/rerank", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }) @with_cancellation async def do_rerank_v1(request: RerankRequest, raw_request: Request): logger.warning_once( @@ -656,7 +853,16 @@ async def do_rerank_v1(request: RerankRequest, raw_request: Request): return await do_rerank(request, raw_request) -@router.post("/v2/rerank", dependencies=[Depends(validate_json_request)]) +@router.post("/v2/rerank", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }) @with_cancellation async def do_rerank_v2(request: RerankRequest, raw_request: Request): return await do_rerank(request, raw_request) @@ -736,12 +942,29 @@ async def is_sleeping(raw_request: Request): return JSONResponse(content={"is_sleeping": is_sleeping}) -@router.post("/invocations", dependencies=[Depends(validate_json_request)]) +@router.post("/invocations", + 
dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }) async def invocations(raw_request: Request): """ For SageMaker, routes requests to other handlers based on model `task`. """ - body = await raw_request.json() + try: + body = await raw_request.json() + except JSONDecodeError as e: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, + detail=f"JSON decode error: {e}") from e + task = raw_request.app.state.task if task not in TASK_HANDLERS: @@ -832,10 +1055,26 @@ def build_app(args: Namespace) -> FastAPI: allow_headers=args.allowed_headers, ) + @app.exception_handler(HTTPException) + async def http_exception_handler(_: Request, exc: HTTPException): + err = ErrorResponse(message=exc.detail, + type=HTTPStatus(exc.status_code).phrase, + code=exc.status_code) + return JSONResponse(err.model_dump(), status_code=exc.status_code) + @app.exception_handler(RequestValidationError) - async def validation_exception_handler(_, exc): - err = ErrorResponse(message=str(exc), - type="BadRequestError", + async def validation_exception_handler(_: Request, + exc: RequestValidationError): + exc_str = str(exc) + errors_str = str(exc.errors()) + + if exc.errors() and errors_str and errors_str != exc_str: + message = f"{exc_str} {errors_str}" + else: + message = exc_str + + err = ErrorResponse(message=message, + type=HTTPStatus.BAD_REQUEST.phrase, code=HTTPStatus.BAD_REQUEST) return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST) @@ -937,10 +1176,10 @@ async def init_app_state( chat_template=resolved_chat_template) else: hf_chat_template = resolve_hf_chat_template( - vllm_config.model_config, - tokenizer, + tokenizer=tokenizer, chat_template=None, tools=None, + model_config=vllm_config.model_config, ) if hf_chat_template != 
resolved_chat_template: @@ -1001,6 +1240,12 @@ async def init_app_state( state.openai_serving_models, request_logger=request_logger) if model_config.task in ( "score", "embed", "pooling") else None + state.openai_serving_classification = ServingClassification( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + ) if model_config.task == "classify" else None state.jinaai_serving_reranking = ServingScores( engine_client, model_config, diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index d8cec2202134..d01af5e42266 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -286,6 +286,9 @@ def validate_parsed_serve_args(args: argparse.Namespace): if args.enable_auto_tool_choice and not args.tool_call_parser: raise TypeError("Error: --enable-auto-tool-choice requires " "--tool-call-parser") + if args.enable_prompt_embeds and args.enable_prompt_adapter: + raise ValueError( + "Cannot use prompt embeds and prompt adapter at the same time.") def log_non_default_args(args: argparse.Namespace): diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 40e477f03194..393cf381b16b 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -3,18 +3,20 @@ # Adapted from # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py import json -import re import time +from http import HTTPStatus from typing import Annotated, Any, ClassVar, Literal, Optional, Union +import regex as re import torch -from fastapi import UploadFile +from fastapi import HTTPException, UploadFile from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter, ValidationInfo, field_validator, model_validator) from typing_extensions import TypeAlias from vllm import envs -from vllm.entrypoints.chat_utils import ChatCompletionMessageParam +from 
vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, + random_tool_call_id) from vllm.logger import init_logger from vllm.pooling_params import PoolingParams from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, @@ -249,7 +251,7 @@ class ChatCompletionRequest(OpenAIBaseModel): parallel_tool_calls: Optional[bool] = False user: Optional[str] = None - # doc: begin-chat-completion-sampling-params + # --8<-- [start:chat-completion-sampling-params] best_of: Optional[int] = None use_beam_search: bool = False top_k: Optional[int] = None @@ -264,9 +266,9 @@ class ChatCompletionRequest(OpenAIBaseModel): spaces_between_special_tokens: bool = True truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None prompt_logprobs: Optional[int] = None - # doc: end-chat-completion-sampling-params + # --8<-- [end:chat-completion-sampling-params] - # doc: begin-chat-completion-extra-params + # --8<-- [start:chat-completion-extra-params] echo: bool = Field( default=False, description=( @@ -401,15 +403,18 @@ class ChatCompletionRequest(OpenAIBaseModel): "access by 3rd parties, and long enough to be " "unpredictable (e.g., 43 characters base64-encoded, corresponding " "to 256 bit). 
Not supported by vLLM engine V0.")) + kv_transfer_params: Optional[dict[str, Any]] = Field( + default=None, + description="KVTransfer parameters used for disaggregated serving.") - # doc: end-chat-completion-extra-params + # --8<-- [end:chat-completion-extra-params] # Default sampling parameters for chat completion requests _DEFAULT_SAMPLING_PARAMS: dict = { "repetition_penalty": 1.0, "temperature": 1.0, "top_p": 1.0, - "top_k": -1, + "top_k": 0, "min_p": 0.0, } @@ -538,7 +543,9 @@ def to_sampling_params( output_kind=RequestOutputKind.DELTA if self.stream \ else RequestOutputKind.FINAL_ONLY, guided_decoding=guided_decoding, - logit_bias=self.logit_bias) + logit_bias=self.logit_bias, + extra_args=({"kv_transfer_params": self.kv_transfer_params} + if self.kv_transfer_params else None)) def _get_guided_json_from_tool( self) -> Optional[Union[str, dict, BaseModel]]: @@ -738,7 +745,8 @@ class CompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/completions/create model: Optional[str] = None - prompt: Union[list[int], list[list[int]], str, list[str]] + prompt: Optional[Union[list[int], list[list[int]], str, list[str]]] = None + prompt_embeds: Optional[Union[bytes, list[bytes]]] = None best_of: Optional[int] = None echo: Optional[bool] = False frequency_penalty: Optional[float] = 0.0 @@ -756,7 +764,7 @@ class CompletionRequest(OpenAIBaseModel): top_p: Optional[float] = None user: Optional[str] = None - # doc: begin-completion-sampling-params + # --8<-- [start:completion-sampling-params] use_beam_search: bool = False top_k: Optional[int] = None min_p: Optional[float] = None @@ -771,9 +779,9 @@ class CompletionRequest(OpenAIBaseModel): truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None allowed_token_ids: Optional[list[int]] = None prompt_logprobs: Optional[int] = None - # doc: end-completion-sampling-params + # --8<-- [end:completion-sampling-params] - # doc: 
begin-completion-extra-params + # --8<-- [start:completion-extra-params] add_special_tokens: bool = Field( default=True, description=( @@ -846,14 +854,18 @@ class CompletionRequest(OpenAIBaseModel): " as strings of the form 'token_id:{token_id}' so that tokens " "that are not JSON-encodable can be identified.")) - # doc: end-completion-extra-params + kv_transfer_params: Optional[dict[str, Any]] = Field( + default=None, + description="KVTransfer parameters used for disaggregated serving.") + + # --8<-- [end:completion-extra-params] # Default sampling parameters for completion requests _DEFAULT_SAMPLING_PARAMS: dict = { "repetition_penalty": 1.0, "temperature": 1.0, "top_p": 1.0, - "top_k": -1, + "top_k": 0, "min_p": 0.0, } @@ -971,7 +983,9 @@ def to_sampling_params( else RequestOutputKind.FINAL_ONLY, guided_decoding=guided_decoding, logit_bias=self.logit_bias, - allowed_token_ids=self.allowed_token_ids) + allowed_token_ids=self.allowed_token_ids, + extra_args=({"kv_transfer_params": self.kv_transfer_params} + if self.kv_transfer_params else None)) @model_validator(mode="before") @classmethod @@ -1012,6 +1026,14 @@ def validate_stream_options(cls, data): return data + @model_validator(mode="before") + @classmethod + def validate_prompt_and_prompt_embeds(cls, data): + if data.get("prompt") is None and data.get("prompt_embeds") is None: + raise ValueError( + "At least one of `prompt` or `prompt_embeds` must be set.") + return data + class EmbeddingCompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation @@ -1023,11 +1045,11 @@ class EmbeddingCompletionRequest(OpenAIBaseModel): user: Optional[str] = None truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None - # doc: begin-embedding-pooling-params + # --8<-- [start:embedding-pooling-params] additional_data: Optional[Any] = None - # doc: end-embedding-pooling-params + # --8<-- [end:embedding-pooling-params] - # doc: begin-embedding-extra-params + # --8<-- 
[start:embedding-extra-params] add_special_tokens: bool = Field( default=True, description=( @@ -1042,7 +1064,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel): "if the served model does not use priority scheduling."), ) - # doc: end-embedding-extra-params + # --8<-- [end:embedding-extra-params] def to_pooling_params(self): return PoolingParams(dimensions=self.dimensions, @@ -1058,11 +1080,11 @@ class EmbeddingChatRequest(OpenAIBaseModel): user: Optional[str] = None truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None - # doc: begin-chat-embedding-pooling-params + # --8<-- [start:chat-embedding-pooling-params] additional_data: Optional[Any] = None - # doc: end-chat-embedding-pooling-params + # --8<-- [end:chat-embedding-pooling-params] - # doc: begin-chat-embedding-extra-params + # --8<-- [start:chat-embedding-extra-params] add_special_tokens: bool = Field( default=False, description=( @@ -1096,7 +1118,7 @@ class EmbeddingChatRequest(OpenAIBaseModel): "default: 0). 
Any priority other than 0 will raise an error " "if the served model does not use priority scheduling."), ) - # doc: end-chat-embedding-extra-params + # --8<-- [end:chat-embedding-extra-params] @model_validator(mode="before") @classmethod @@ -1125,11 +1147,11 @@ class ScoreRequest(OpenAIBaseModel): text_2: Union[list[str], str] truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None - # doc: begin-score-pooling-params + # --8<-- [start:score-pooling-params] additional_data: Optional[Any] = None - # doc: end-score-pooling-params + # --8<-- [end:score-pooling-params] - # doc: begin-score-extra-params + # --8<-- [start:score-extra-params] priority: int = Field( default=0, description=( @@ -1138,7 +1160,7 @@ class ScoreRequest(OpenAIBaseModel): "if the served model does not use priority scheduling."), ) - # doc: end-score-extra-params + # --8<-- [end:score-extra-params] def to_pooling_params(self): return PoolingParams(additional_data=self.additional_data) @@ -1151,11 +1173,11 @@ class RerankRequest(OpenAIBaseModel): top_n: int = Field(default_factory=lambda: 0) truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None - # doc: begin-rerank-pooling-params + # --8<-- [start:rerank-pooling-params] additional_data: Optional[Any] = None - # doc: end-rerank-pooling-params + # --8<-- [end:rerank-pooling-params] - # doc: begin-rerank-extra-params + # --8<-- [start:rerank-extra-params] priority: int = Field( default=0, description=( @@ -1164,7 +1186,7 @@ class RerankRequest(OpenAIBaseModel): "if the served model does not use priority scheduling."), ) - # doc: end-rerank-extra-params + # --8<-- [end:rerank-extra-params] def to_pooling_params(self): return PoolingParams(additional_data=self.additional_data) @@ -1221,6 +1243,8 @@ class CompletionResponse(OpenAIBaseModel): model: str choices: list[CompletionResponseChoice] usage: UsageInfo + kv_transfer_params: Optional[dict[str, Any]] = Field( + default=None, description="KVTransfer parameters.") class 
CompletionResponseStreamChoice(OpenAIBaseModel): @@ -1291,13 +1315,54 @@ class ScoreResponse(OpenAIBaseModel): usage: UsageInfo +class ClassificationRequest(OpenAIBaseModel): + model: Optional[str] = None + input: Union[list[str], str] + truncate_prompt_tokens: Optional[int] = None + user: Optional[str] = None + + # --8<-- [start:classification-pooling-params] + additional_data: Optional[Any] = None + # --8<-- [end:classification-pooling-params] + + # --8<-- [start:classification-extra-params] + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling."), + ) + + # --8<-- [end:classification-extra-params] + + def to_pooling_params(self): + return PoolingParams(additional_data=self.additional_data) + + +class ClassificationData(OpenAIBaseModel): + index: int + label: Optional[str] + probs: list[float] + num_classes: int + + +class ClassificationResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"classify-{random_uuid()}") + object: str = "list" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + data: list[ClassificationData] + usage: UsageInfo + + class FunctionCall(OpenAIBaseModel): name: str arguments: str class ToolCall(OpenAIBaseModel): - id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}") + id: str = Field(default_factory=random_tool_call_id) type: Literal["function"] = "function" function: FunctionCall @@ -1309,8 +1374,8 @@ class DeltaFunctionCall(BaseModel): # a tool call delta where everything is optional class DeltaToolCall(OpenAIBaseModel): - id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}") - type: Literal["function"] = "function" + id: Optional[str] = None + type: Optional[Literal["function"]] = None index: int function: Optional[DeltaFunctionCall] = None @@ -1369,6 +1434,8 @@ class 
ChatCompletionResponse(OpenAIBaseModel): choices: list[ChatCompletionResponseChoice] usage: UsageInfo prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None + kv_transfer_params: Optional[dict[str, Any]] = Field( + default=None, description="KVTransfer parameters.") class DeltaMessage(OpenAIBaseModel): @@ -1535,6 +1602,10 @@ class TokenizeChatRequest(OpenAIBaseModel): default=None, description=("Additional kwargs to pass to the HF processor."), ) + tools: Optional[list[ChatCompletionToolsParam]] = Field( + default=None, + description=("A list of tools the model may call."), + ) @model_validator(mode="before") @classmethod @@ -1627,7 +1698,7 @@ class TranscriptionRequest(OpenAIBaseModel): timestamps incurs additional latency. """ - # doc: begin-transcription-extra-params + # --8<-- [start:transcription-extra-params] stream: Optional[bool] = False """Custom field not present in the original OpenAI definition. When set, it will enable output to be streamed in a similar fashion as the Chat @@ -1636,9 +1707,9 @@ class TranscriptionRequest(OpenAIBaseModel): # Flattened stream option to simplify form data. stream_include_usage: Optional[bool] = False stream_continuous_usage_stats: Optional[bool] = False - # doc: end-transcription-extra-params + # --8<-- [end:transcription-extra-params] - # doc: begin-transcription-sampling-params + # --8<-- [start:transcription-sampling-params] temperature: float = Field(default=0.0) """The sampling temperature, between 0 and 1. @@ -1672,14 +1743,14 @@ class TranscriptionRequest(OpenAIBaseModel): presence_penalty: Optional[float] = 0.0 """The presence penalty to use for sampling.""" - # doc: end-transcription-sampling-params + # --8<-- [end:transcription-sampling-params] # Default sampling parameters for transcription requests. 
_DEFAULT_SAMPLING_PARAMS: dict = { "repetition_penalty": 1.0, "temperature": 1.0, "top_p": 1.0, - "top_k": -1, + "top_k": 0, "min_p": 0.0, } @@ -1727,7 +1798,13 @@ def to_sampling_params( @model_validator(mode="before") @classmethod - def validate_stream_options(cls, data): + def validate_transcription_request(cls, data): + if isinstance(data.get("file"), str): + raise HTTPException( + status_code=HTTPStatus.UNPROCESSABLE_ENTITY, + detail="Expected 'file' to be a file-like object, not 'str'.", + ) + stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"] stream = data.get("stream", False) if any(bool(data.get(so, False)) for so in stream_opts) and not stream: diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index fccf459f17dc..eae83c9a494a 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -365,8 +365,8 @@ async def main(args): # Determine the type of request and run it. if request.url == "/v1/chat/completions": - chat_handler_fn = (None if openai_serving_chat is None else - openai_serving_chat.create_chat_completion) + chat_handler_fn = openai_serving_chat.create_chat_completion if \ + openai_serving_chat is not None else None if chat_handler_fn is None: response_futures.append( make_async_error_request_output( @@ -380,8 +380,8 @@ async def main(args): run_request(chat_handler_fn, request, tracker)) tracker.submitted() elif request.url == "/v1/embeddings": - embed_handler_fn = (None if openai_serving_embedding is None else - openai_serving_embedding.create_embedding) + embed_handler_fn = openai_serving_embedding.create_embedding if \ + openai_serving_embedding is not None else None if embed_handler_fn is None: response_futures.append( make_async_error_request_output( @@ -394,8 +394,8 @@ async def main(args): run_request(embed_handler_fn, request, tracker)) tracker.submitted() elif request.url == "/v1/score": - score_handler_fn = (None if openai_serving_scores is 
None else - openai_serving_scores.create_score) + score_handler_fn = openai_serving_scores.create_score if \ + openai_serving_scores is not None else None if score_handler_fn is None: response_futures.append( make_async_error_request_output( diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 5c11836fbff4..bc11686d7be8 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -2,7 +2,6 @@ import asyncio import json -import re import time from collections.abc import AsyncGenerator, AsyncIterator from collections.abc import Sequence as GenericSequence @@ -10,13 +9,15 @@ import jinja2 import partial_json_parser +import regex as re from fastapi import Request from pydantic import TypeAdapter from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, - ConversationMessage) + ConversationMessage, + random_tool_call_id) from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( ChatCompletionLogProb, ChatCompletionLogProbs, @@ -196,7 +197,7 @@ async def create_chat_completion( except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e: logger.exception("Error in preprocessing prompt inputs") - return self.create_error_response(str(e)) + return self.create_error_response(f"{e} {e.__cause__}") request_id = "chatcmpl-" \ f"{self._base_request_id(raw_request, request.request_id)}" @@ -363,9 +364,10 @@ def extract_tool_call_required_streaming( function_name_returned = True delta_message = DeltaMessage(tool_calls=[ - DeltaToolCall(function=DeltaFunctionCall( - name=current_tool_call["name"], - arguments=arguments), + DeltaToolCall(id=random_tool_call_id(), + function=DeltaFunctionCall( + name=current_tool_call["name"], + arguments=arguments), index=len(obj) - 1, type="function") ]) @@ -382,8 +384,7 @@ def 
extract_tool_call_required_streaming( # instead of name every time name=None, arguments=delta_text), - index=len(obj) - 1, - type="function") + index=len(obj) - 1) ]) else: delta_message = None @@ -422,7 +423,7 @@ async def chat_completion_stream_generator( and self._should_stream_with_auto_tool_parsing(request)) all_previous_token_ids: Optional[list[list[int]]] - function_name_returned: Optional[list[bool]] = None + function_name_returned = [False] * num_choices # Only one of these will be used, thus previous_texts and # all_previous_token_ids will not be used twice in the same iteration. @@ -435,7 +436,6 @@ async def chat_completion_stream_generator( reasoning_end_arr = [False] * num_choices elif request.tool_choice == "required": previous_texts = [""] * num_choices - function_name_returned = [False] * num_choices all_previous_token_ids = None else: previous_texts, all_previous_token_ids = None, None @@ -623,16 +623,27 @@ async def chat_completion_stream_generator( delta_text = previous_text + delta_text current_text = "" + if function_name_returned[i]: + delta_tool_call = DeltaToolCall( + function=DeltaFunctionCall( + arguments=delta_text), + index=i) + else: + delta_tool_call = DeltaToolCall( + id=random_tool_call_id(), + type="function", + function=DeltaFunctionCall( + name=tool_choice_function_name, + arguments=delta_text), + index=i) + function_name_returned[i] = True + delta_message = DeltaMessage(tool_calls=[ - DeltaToolCall(function=DeltaFunctionCall( - name=tool_choice_function_name, - arguments=delta_text), - index=i) + delta_tool_call, ]) elif request.tool_choice == "required": assert previous_texts is not None - assert function_name_returned is not None previous_text = previous_texts[i] current_text = previous_text + delta_text fn_name_returned = function_name_returned[i] @@ -835,7 +846,7 @@ async def chat_completion_stream_generator( total_tokens=num_prompt_tokens + completion_tokens, ) - data = chunk.model_dump_json(exclude_unset=True) + data = 
chunk.model_dump_json(exclude_none=True) yield f"data: {data}\n\n" # once the final token is handled, if stream_options.include_usage @@ -1075,6 +1086,7 @@ async def chat_completion_full_generator( choices=choices, usage=usage, prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs), + kv_transfer_params=final_res.kv_transfer_params, ) return response diff --git a/vllm/entrypoints/openai/serving_classification.py b/vllm/entrypoints/openai/serving_classification.py new file mode 100644 index 000000000000..90cdd389d59f --- /dev/null +++ b/vllm/entrypoints/openai/serving_classification.py @@ -0,0 +1,159 @@ +# SPDX-License-Identifier: Apache-2.0 + +from http import HTTPStatus +from typing import Optional, Union, cast + +import numpy as np +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import (ClassificationData, + ClassificationRequest, + ClassificationResponse, + ErrorResponse, UsageInfo) +# yapf: enable +from vllm.entrypoints.openai.serving_engine import (ClassificationServeContext, + OpenAIServing, + ServeContext) +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.logger import init_logger +from vllm.outputs import ClassificationOutput, PoolingRequestOutput + +logger = init_logger(__name__) + + +class ClassificationMixin(OpenAIServing): + + async def _preprocess( + self, + ctx: ServeContext, + ) -> Optional[ErrorResponse]: + """ + Process classification inputs: tokenize text, resolve adapters, + and prepare model-specific inputs. 
+ """ + ctx = cast(ClassificationServeContext, ctx) + if isinstance(ctx.request.input, str) and not ctx.request.input: + return self.create_error_response( + "Input cannot be empty for classification", + status_code=HTTPStatus.BAD_REQUEST, + ) + + if isinstance(ctx.request.input, list) and len(ctx.request.input) == 0: + return None + + try: + ( + ctx.lora_request, + ctx.prompt_adapter_request, + ) = self._maybe_get_adapters(ctx.request) + + ctx.tokenizer = await self.engine_client.get_tokenizer( + ctx.lora_request) + + if ctx.prompt_adapter_request is not None: + raise NotImplementedError( + "Prompt adapter is not supported for classification models" + ) + + ( + ctx.request_prompts, + ctx.engine_prompts, + ) = await self._preprocess_completion( + ctx.request, + ctx.tokenizer, + ctx.request.input, + truncate_prompt_tokens=ctx.request.truncate_prompt_tokens, + ) + + return None + + except (ValueError, TypeError) as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + + def _build_response( + self, + ctx: ServeContext, + ) -> Union[ClassificationResponse, ErrorResponse]: + """ + Convert model outputs to a formatted classification response + with probabilities and labels. 
+ """ + ctx = cast(ClassificationServeContext, ctx) + items: list[ClassificationData] = [] + num_prompt_tokens = 0 + + final_res_batch_checked = cast(list[PoolingRequestOutput], + ctx.final_res_batch) + + for idx, final_res in enumerate(final_res_batch_checked): + classify_res = ClassificationOutput.from_base(final_res.outputs) + + probs = classify_res.probs + predicted_index = int(np.argmax(probs)) + label = getattr(self.model_config.hf_config, "id2label", + {}).get(predicted_index) + + item = ClassificationData( + index=idx, + label=label, + probs=probs, + num_classes=len(probs), + ) + + items.append(item) + prompt_token_ids = final_res.prompt_token_ids + num_prompt_tokens += len(prompt_token_ids) + + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + total_tokens=num_prompt_tokens, + ) + + return ClassificationResponse( + id=ctx.request_id, + created=ctx.created_time, + model=ctx.model_name, + data=items, + usage=usage, + ) + + +class ServingClassification(ClassificationMixin): + request_id_prefix = "classify" + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + ) -> None: + super().__init__( + engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + ) + + async def create_classify( + self, + request: ClassificationRequest, + raw_request: Request, + ) -> Union[ClassificationResponse, ErrorResponse]: + model_name = self._get_model_name(request.model) + request_id = (f"{self.request_id_prefix}-" + f"{self._base_request_id(raw_request)}") + + ctx = ClassificationServeContext( + request=request, + raw_request=raw_request, + model_name=model_name, + request_id=request_id, + ) + + return await super().handle(ctx) # type: ignore diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 1067f35ce240..7beaae287de9 100644 --- 
a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -8,6 +8,7 @@ import jinja2 from fastapi import Request +from typing_extensions import assert_never from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient @@ -25,8 +26,11 @@ UsageInfo) # yapf: enable from vllm.entrypoints.openai.serving_engine import (OpenAIServing, - clamp_prompt_logprobs) + clamp_prompt_logprobs, + is_text_tokens_prompt) from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt, + is_tokens_prompt) from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams, SamplingParams @@ -90,6 +94,10 @@ async def create_completion( return self.create_error_response( "suffix is not currently supported") + if request.echo and request.prompt_embeds is not None: + return self.create_error_response( + "Echo is unsupported with prompt embeds.") + request_id = f"cmpl-{self._base_request_id(raw_request)}" created_time = int(time.time()) @@ -130,8 +138,24 @@ async def create_completion( try: for i, engine_prompt in enumerate(engine_prompts): sampling_params: Union[SamplingParams, BeamSearchParams] - default_max_tokens = self.max_model_len - len( - engine_prompt["prompt_token_ids"]) + # Mypy does not infer that engine_prompt will have only one of + # "prompt_token_ids" or "prompt_embeds" defined, and both of + # these as Union[object, the expected type], where it infers + # object if engine_prompt is a subclass of one of the + # typeddicts that defines both keys. Worse, because of + # https://github.com/python/mypy/issues/8586, mypy does not + # infer the type of engine_prompt correctly because of the + # enumerate. So we need an unnecessary cast here. 
+ engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt], + engine_prompt) + if is_embeds_prompt(engine_prompt): + input_length = len(engine_prompt["prompt_embeds"]) + elif is_tokens_prompt(engine_prompt): + input_length = len(engine_prompt["prompt_token_ids"]) + else: + assert_never(engine_prompt) + default_max_tokens = self.max_model_len - input_length + if request.use_beam_search: sampling_params = request.to_beam_search_params( default_max_tokens, self.default_sampling_params) @@ -152,6 +176,11 @@ async def create_completion( trace_headers = (None if raw_request is None else await self._get_trace_headers(raw_request.headers)) + # Mypy inconsistently requires this second cast in different + # environments. It shouldn't be necessary (redundant from above) + # but pre-commit in CI fails without it. + engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt], + engine_prompt) if isinstance(sampling_params, BeamSearchParams): generator = self.engine_client.beam_search( prompt=engine_prompt, @@ -211,7 +240,11 @@ async def create_completion( # We did not pass it into vLLM engine to avoid being redundant # with the inputs token IDs if final_res.prompt is None: - final_res.prompt = request_prompts[i]["prompt"] + request_prompt = request_prompts[i] + if is_text_tokens_prompt(request_prompt): + final_res.prompt = request_prompt["prompt"] + else: + final_res.prompt = None final_res_batch_checked = cast(list[RequestOutput], final_res_batch) @@ -276,8 +309,8 @@ async def completion_stream_generator( prompt_text = res.prompt # Prompt details are excluded from later streamed outputs - if res.prompt_token_ids is not None: - num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids) + if prompt_token_ids is not None: + num_prompt_tokens[prompt_idx] = len(prompt_token_ids) delta_token_ids: GenericSequence[int] out_logprobs: Optional[GenericSequence[Optional[dict[ @@ -482,7 +515,7 @@ def request_output_to_completion_response( model=model_name, choices=choices, usage=usage, - ) + 
kv_transfer_params=final_res_batch[0].kv_transfer_params) def _create_completion_logprobs( self, diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 4b4d2d8b76f4..3785d2642f9d 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -1,14 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 -import asyncio import base64 -import time -from collections.abc import AsyncGenerator from typing import Final, Literal, Optional, Union, cast import numpy as np from fastapi import Request -from typing_extensions import assert_never +from typing_extensions import assert_never, override from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient @@ -19,13 +16,13 @@ EmbeddingResponse, EmbeddingResponseData, ErrorResponse, UsageInfo) -from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext, + OpenAIServing, + ServeContext) from vllm.entrypoints.openai.serving_models import OpenAIServingModels -from vllm.entrypoints.utils import _validate_truncation_size from vllm.logger import init_logger from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput, PoolingRequestOutput) -from vllm.utils import merge_async_iterators logger = init_logger(__name__) @@ -45,180 +42,77 @@ def _get_embedding( assert_never(encoding_format) -class OpenAIServingEmbedding(OpenAIServing): +class EmbeddingMixin(OpenAIServing): - def __init__( - self, - engine_client: EngineClient, - model_config: ModelConfig, - models: OpenAIServingModels, - *, - request_logger: Optional[RequestLogger], - chat_template: Optional[str], - chat_template_content_format: ChatTemplateContentFormatOption, - ) -> None: - super().__init__(engine_client=engine_client, - model_config=model_config, - models=models, - request_logger=request_logger) - - self.chat_template = chat_template - self.chat_template_content_format: 
Final = chat_template_content_format - - async def create_embedding( + async def _preprocess( self, - request: EmbeddingRequest, - raw_request: Optional[Request] = None, - ) -> Union[EmbeddingResponse, ErrorResponse]: - """ - Embedding API similar to OpenAI's API. - - See https://platform.openai.com/docs/api-reference/embeddings/create - for the API specification. This API mimics the OpenAI Embedding API. - """ - error_check_ret = await self._check_model(request) - if error_check_ret is not None: - return error_check_ret - - encoding_format = request.encoding_format - - model_name = self._get_model_name(request.model) - request_id = f"embd-{self._base_request_id(raw_request)}" - created_time = int(time.time()) - - truncate_prompt_tokens = request.truncate_prompt_tokens - - pooling_params = request.to_pooling_params() - - try: - pooling_params.verify(self.model_config) - except ValueError as e: - return self.create_error_response(str(e)) - + ctx: ServeContext, + ) -> Optional[ErrorResponse]: + ctx = cast(EmbeddingServeContext, ctx) try: - truncate_prompt_tokens = _validate_truncation_size( - self.max_model_len, truncate_prompt_tokens) ( - lora_request, - prompt_adapter_request, - ) = self._maybe_get_adapters(request) + ctx.lora_request, + ctx.prompt_adapter_request, + ) = self._maybe_get_adapters(ctx.request) - tokenizer = await self.engine_client.get_tokenizer(lora_request) + tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request + ) - if prompt_adapter_request is not None: + if ctx.prompt_adapter_request is not None: raise NotImplementedError("Prompt adapter is not supported " "for embedding models") - if isinstance(request, EmbeddingChatRequest): + if isinstance(ctx.request, EmbeddingChatRequest): ( _, - request_prompts, - engine_prompts, + ctx.request_prompts, + ctx.engine_prompts, ) = await self._preprocess_chat( - request, + ctx.request, tokenizer, - request.messages, - chat_template=request.chat_template or self.chat_template, - 
chat_template_content_format=self. + ctx.request.messages, + chat_template=ctx.request.chat_template + or ctx.chat_template, + chat_template_content_format=ctx. chat_template_content_format, # In embedding requests, we are not generating tokens, # so there is no need to append extra tokens to the input add_generation_prompt=False, continue_final_message=False, - truncate_prompt_tokens=truncate_prompt_tokens, - add_special_tokens=request.add_special_tokens, + truncate_prompt_tokens=ctx.truncate_prompt_tokens, + add_special_tokens=ctx.request.add_special_tokens, ) else: - (request_prompts, - engine_prompts) = await self._preprocess_completion( - request, + (ctx.request_prompts, + ctx.engine_prompts) = await self._preprocess_completion( + ctx.request, tokenizer, - request.input, - truncate_prompt_tokens=truncate_prompt_tokens, - add_special_tokens=request.add_special_tokens, + ctx.request.input, + truncate_prompt_tokens=ctx.truncate_prompt_tokens, + add_special_tokens=ctx.request.add_special_tokens, ) + return None except (ValueError, TypeError) as e: logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(str(e)) - # Schedule the request and get the result generator. 
- generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] - try: - for i, engine_prompt in enumerate(engine_prompts): - request_id_item = f"{request_id}-{i}" - - self._log_inputs(request_id_item, - request_prompts[i], - params=pooling_params, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) - - trace_headers = (None if raw_request is None else await - self._get_trace_headers(raw_request.headers)) - - generator = self.engine_client.encode( - engine_prompt, - pooling_params, - request_id_item, - lora_request=lora_request, - trace_headers=trace_headers, - priority=request.priority, - ) - - generators.append(generator) - except ValueError as e: - # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) - - result_generator = merge_async_iterators(*generators) - - num_prompts = len(engine_prompts) - - # Non-streaming response - final_res_batch: list[Optional[PoolingRequestOutput]] - final_res_batch = [None] * num_prompts - try: - async for i, res in result_generator: - final_res_batch[i] = res - - assert all(final_res is not None for final_res in final_res_batch) - - final_res_batch_checked = cast(list[PoolingRequestOutput], - final_res_batch) - - response = self.request_output_to_embedding_response( - final_res_batch_checked, - request_id, - created_time, - model_name, - encoding_format, - ) - except asyncio.CancelledError: - return self.create_error_response("Client disconnected") - except ValueError as e: - # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) - - return response - - def request_output_to_embedding_response( + def _build_response( self, - final_res_batch: list[PoolingRequestOutput], - request_id: str, - created_time: int, - model_name: str, - encoding_format: Literal["float", "base64"], - ) -> EmbeddingResponse: + ctx: ServeContext, + ) -> Union[EmbeddingResponse, ErrorResponse]: items: list[EmbeddingResponseData] = [] num_prompt_tokens = 0 - for 
idx, final_res in enumerate(final_res_batch): + final_res_batch_checked = cast(list[PoolingRequestOutput], + ctx.final_res_batch) + + for idx, final_res in enumerate(final_res_batch_checked): embedding_res = EmbeddingRequestOutput.from_base(final_res) item = EmbeddingResponseData( index=idx, embedding=_get_embedding(embedding_res.outputs, - encoding_format), + ctx.request.encoding_format), ) prompt_token_ids = final_res.prompt_token_ids @@ -231,9 +125,76 @@ def request_output_to_embedding_response( ) return EmbeddingResponse( - id=request_id, - created=created_time, - model=model_name, + id=ctx.request_id, + created=ctx.created_time, + model=ctx.model_name, data=items, usage=usage, ) + + +class OpenAIServingEmbedding(EmbeddingMixin): + request_id_prefix = "embd" + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + chat_template: Optional[str], + chat_template_content_format: ChatTemplateContentFormatOption, + ) -> None: + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger) + + self.chat_template = chat_template + self.chat_template_content_format: Final = chat_template_content_format + + async def create_embedding( + self, + request: EmbeddingRequest, + raw_request: Optional[Request] = None, + ) -> Union[EmbeddingResponse, ErrorResponse]: + """ + Embedding API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/embeddings/create + for the API specification. This API mimics the OpenAI Embedding API. 
+ """ + model_name = self._get_model_name(request.model) + request_id = (f"{self.request_id_prefix}-" + f"{self._base_request_id(raw_request)}") + + ctx = EmbeddingServeContext( + request=request, + raw_request=raw_request, + model_name=model_name, + request_id=request_id, + chat_template=self.chat_template, + chat_template_content_format=self.chat_template_content_format, + ) + + return await super().handle(ctx) # type: ignore + + @override + def _validate_request( + self, + ctx: ServeContext[EmbeddingRequest], + ) -> Optional[ErrorResponse]: + if error := super()._validate_request(ctx): + return error + + ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens + + pooling_params = ctx.request.to_pooling_params() + + try: + pooling_params.verify(self.model_config) + except ValueError as e: + return self.create_error_response(str(e)) + + return None diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index bb11650815ec..c73575b48d9c 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,14 +1,31 @@ # SPDX-License-Identifier: Apache-2.0 - +import base64 +import io import json -from collections.abc import Iterable, Iterator, Mapping, Sequence +import sys +import time +from collections.abc import (AsyncGenerator, Iterable, Iterator, Mapping, + Sequence) from concurrent.futures.thread import ThreadPoolExecutor from http import HTTPStatus -from typing import Annotated, Any, Callable, Optional, TypedDict, Union +from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional, + TypeVar, Union, cast, overload) +import torch from fastapi import Request -from pydantic import Field +from pydantic import BaseModel, ConfigDict, Field from starlette.datastructures import Headers +from typing_extensions import TypeIs + +if sys.version_info >= (3, 12): + from typing import TypedDict +else: + from typing_extensions import TypedDict + +if sys.version_info >= (3, 
12): + from typing import TypedDict +else: + from typing_extensions import TypedDict import vllm.envs as envs from vllm.config import ModelConfig @@ -24,22 +41,34 @@ resolve_chat_template_content_format) from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + ChatCompletionResponse, + ClassificationRequest, + ClassificationResponse, CompletionRequest, + CompletionResponse, DetokenizeRequest, EmbeddingChatRequest, EmbeddingCompletionRequest, - ErrorResponse, RerankRequest, - ScoreRequest, + EmbeddingRequest, + EmbeddingResponse, ErrorResponse, + PoolingResponse, RerankRequest, + ScoreRequest, ScoreResponse, TokenizeChatRequest, TokenizeCompletionRequest, - TranscriptionRequest) + TokenizeResponse, + TranscriptionRequest, + TranscriptionResponse) from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.tool_parsers import ToolParser # yapf: enable -from vllm.inputs import TokensPrompt +from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt +from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin + MultiModalDataDict) +from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import BeamSearchParams, SamplingParams @@ -47,13 +76,15 @@ from vllm.tracing import (contains_trace_headers, extract_trace_headers, log_tracing_disabled_warning) from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import is_list_of, make_async, random_uuid +from vllm.utils import (is_list_of, make_async, merge_async_iterators, + random_uuid) logger = 
init_logger(__name__) CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest, EmbeddingCompletionRequest, RerankRequest, - ScoreRequest, TokenizeCompletionRequest] + ClassificationRequest, ScoreRequest, + TokenizeCompletionRequest] ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest, TokenizeChatRequest] @@ -61,16 +92,114 @@ AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, TranscriptionRequest] +AnyResponse = Union[ + CompletionResponse, + ChatCompletionResponse, + EmbeddingResponse, + TranscriptionResponse, + TokenizeResponse, + PoolingResponse, + ClassificationResponse, + ScoreResponse, +] + class TextTokensPrompt(TypedDict): prompt: str prompt_token_ids: list[int] -RequestPrompt = Union[list[int], str, TextTokensPrompt] +class EmbedsPrompt(TypedDict): + prompt_embeds: torch.Tensor + + +RequestPrompt = Union[list[int], str, TextTokensPrompt, EmbedsPrompt] + + +def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]: + return (isinstance(prompt, dict) and "prompt_token_ids" in prompt + and "prompt_embeds" not in prompt) + + +def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]: + return (isinstance(prompt, dict) and "prompt_token_ids" not in prompt + and "prompt_embeds" in prompt) + + +RequestT = TypeVar("RequestT", bound=AnyRequest) + + +class RequestProcessingMixin(BaseModel): + """ + Mixin for request processing, + handling prompt preparation and engine input. + """ + request_prompts: Optional[Sequence[RequestPrompt]] = \ + Field(default_factory=list) + engine_prompts: Optional[Union[list[EngineTokensPrompt], + list[EngineEmbedsPrompt]]] = Field( + default_factory=list) + + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class ResponseGenerationMixin(BaseModel): + """ + Mixin for response generation, + managing result generators and final batch results. 
+ """ + result_generator: Optional[AsyncGenerator[tuple[int, Union[ + RequestOutput, PoolingRequestOutput]], None]] = None + final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field( + default_factory=list) + + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel, + Generic[RequestT]): + # Shared across all requests + request: RequestT + raw_request: Optional[Request] = None + model_name: str + request_id: str + created_time: int = Field(default_factory=lambda: int(time.time())) + lora_request: Optional[LoRARequest] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None + + # Shared across most requests + tokenizer: Optional[AnyTokenizer] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + + # `protected_namespaces` resolves Pydantic v2's warning + # on conflict with protected namespace "model_" + model_config = ConfigDict( + protected_namespaces=(), + arbitrary_types_allowed=True, + ) + + +ClassificationServeContext = ServeContext[ClassificationRequest] + + +class EmbeddingServeContext(ServeContext[EmbeddingRequest]): + chat_template: Optional[str] = None + chat_template_content_format: ChatTemplateContentFormatOption + + +# Used to resolve the Pydantic error related to +# forward reference of MultiModalDataDict in TokensPrompt +RequestProcessingMixin.model_rebuild() +ServeContext.model_rebuild() +ClassificationServeContext.model_rebuild() +EmbeddingServeContext.model_rebuild() class OpenAIServing: + request_id_prefix: ClassVar[str] = """ + A short string prepended to every requestโ€™s ID (e.g. 
"embd", "classify") + so you can easily tell โ€œthis ID came from Embedding vs Classification.โ€ + """ def __init__( self, @@ -100,6 +229,173 @@ def __init__( self._tokenize_prompt_input_or_inputs, executor=self._tokenizer_executor) + async def _preprocess( + self, + ctx: ServeContext, + ) -> Optional[ErrorResponse]: + """ + Default preprocessing hook. Subclasses may override + to prepare `ctx` (classification, embedding, etc.). + """ + return None + + def _build_response( + self, + ctx: ServeContext, + ) -> Union[AnyResponse, ErrorResponse]: + """ + Default response builder. Subclass may override this method + to return the appropriate response object. + """ + return self.create_error_response("unimplemented endpoint") + + async def handle( + self, + ctx: ServeContext, + ) -> Union[AnyResponse, ErrorResponse]: + generation: AsyncGenerator[Union[AnyResponse, ErrorResponse], None] + generation = self._pipeline(ctx) + + async for response in generation: + return response + + return self.create_error_response("No response yielded from pipeline") + + async def _pipeline( + self, + ctx: ServeContext, + ) -> AsyncGenerator[Union[AnyResponse, ErrorResponse], None]: + """Execute the request processing pipeline yielding responses.""" + if error := await self._check_model(ctx.request): + yield error + if error := self._validate_request(ctx): + yield error + + preprocess_ret = await self._preprocess(ctx) + if isinstance(preprocess_ret, ErrorResponse): + yield preprocess_ret + + generators_ret = await self._prepare_generators(ctx) + if isinstance(generators_ret, ErrorResponse): + yield generators_ret + + collect_ret = await self._collect_batch(ctx) + if isinstance(collect_ret, ErrorResponse): + yield collect_ret + + yield self._build_response(ctx) + + def _validate_request(self, ctx: ServeContext) -> Optional[ErrorResponse]: + truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", + None) + + if truncate_prompt_tokens is not None: + if truncate_prompt_tokens 
<= self.max_model_len: + ctx.truncate_prompt_tokens = truncate_prompt_tokens + else: + return self.create_error_response( + "truncate_prompt_tokens value is " + "greater than max_model_len." + " Please, select a smaller truncation size.") + return None + + async def _prepare_generators( + self, + ctx: ServeContext, + ) -> Optional[ErrorResponse]: + """Schedule the request and get the result generator.""" + generators: list[AsyncGenerator[Union[RequestOutput, + PoolingRequestOutput], + None]] = [] + + try: + trace_headers = (None if ctx.raw_request is None else await + self._get_trace_headers(ctx.raw_request.headers)) + + if not hasattr(ctx.request, "to_pooling_params"): + return self.create_error_response( + "Request type does not support pooling parameters") + + pooling_params = ctx.request.to_pooling_params() + + if ctx.engine_prompts is None: + return self.create_error_response( + "Engine prompts not available") + + for i, engine_prompt in enumerate(ctx.engine_prompts): + request_id_item = f"{ctx.request_id}-{i}" + + if ctx.request_prompts is None: + return self.create_error_response( + "Request prompts not available") + + self._log_inputs( + request_id_item, + ctx.request_prompts[i], + params=pooling_params, + lora_request=ctx.lora_request, + prompt_adapter_request=ctx.prompt_adapter_request) + + # Mypy has an existing bug related to inferring the variance of + # TypedDicts with `builtins.enumerate`: + # https://github.com/python/mypy/issues/8586#issuecomment-2867698435 + engine_prompt = cast( + Union[EngineTokensPrompt, EngineEmbedsPrompt], + engine_prompt) + generator = self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=ctx.lora_request, + trace_headers=trace_headers, + priority=getattr(ctx.request, "priority", 0), + ) + + generators.append(generator) + + ctx.result_generator = merge_async_iterators(*generators) + + return None + + except Exception as e: + # TODO: Use a vllm-specific Validation Error + return 
self.create_error_response(str(e)) + + async def _collect_batch( + self, + ctx: ServeContext, + ) -> Optional[ErrorResponse]: + """Collect batch results from the result generator.""" + try: + if ctx.engine_prompts is None: + return self.create_error_response( + "Engine prompts not available") + + num_prompts = len(ctx.engine_prompts) + final_res_batch: list[Optional[Union[RequestOutput, + PoolingRequestOutput]]] + final_res_batch = [None] * num_prompts + + if ctx.result_generator is None: + return self.create_error_response( + "Result generator not available") + + async for i, res in ctx.result_generator: + final_res_batch[i] = res + + if None in final_res_batch: + return self.create_error_response( + "Failed to generate results for all prompts") + + ctx.final_res_batch = [ + res for res in final_res_batch if res is not None + ] + + return None + + except Exception as e: + return self.create_error_response(str(e)) + def create_error_response( self, message: str, @@ -183,6 +479,12 @@ def _normalize_prompt_text_to_input( if truncate_prompt_tokens is None: encoded = tokenizer(prompt, add_special_tokens=add_special_tokens) + elif truncate_prompt_tokens < 0: + # Negative means we cap at the model's max length + encoded = tokenizer(prompt, + add_special_tokens=add_special_tokens, + truncation=True, + max_length=self.max_model_len) else: encoded = tokenizer(prompt, add_special_tokens=add_special_tokens, @@ -204,6 +506,8 @@ def _normalize_prompt_tokens_to_input( ) -> TextTokensPrompt: if truncate_prompt_tokens is None: input_ids = prompt_ids + elif truncate_prompt_tokens < 0: + input_ids = prompt_ids[-self.max_model_len:] else: input_ids = prompt_ids[-truncate_prompt_tokens:] @@ -219,13 +523,16 @@ def _validate_input( ) -> TextTokensPrompt: token_num = len(input_ids) - # Note: EmbeddingRequest and ScoreRequest doesn't have max_tokens + # Note: EmbeddingRequest, ClassificationRequest, + # and ScoreRequest doesn't have max_tokens if isinstance(request, (EmbeddingChatRequest, 
EmbeddingCompletionRequest, - ScoreRequest, RerankRequest)): + ScoreRequest, RerankRequest, ClassificationRequest)): + operation = { + ScoreRequest: "score", + ClassificationRequest: "classification" + }.get(type(request), "embedding generation") - operation = "score" if isinstance(request, ScoreRequest) \ - else "embedding generation" if token_num > self.max_model_len: raise ValueError( f"This model's maximum context length is " @@ -247,7 +554,7 @@ def _validate_input( # TODO(#9845): remove max_tokens when field dropped from OpenAI API max_tokens = request.max_completion_tokens or request.max_tokens else: - max_tokens = request.max_tokens + max_tokens = getattr(request, "max_tokens", None) if max_tokens is None: if token_num >= self.max_model_len: raise ValueError( @@ -275,7 +582,8 @@ def _tokenize_prompt_input( add_special_tokens: bool = True, ) -> TextTokensPrompt: """ - A simpler implementation of {meth}`_tokenize_prompt_input_or_inputs` + A simpler implementation of + [`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs] that assumes single input. """ return next( @@ -296,7 +604,8 @@ def _tokenize_prompt_inputs( add_special_tokens: bool = True, ) -> Iterator[TextTokensPrompt]: """ - A simpler implementation of {meth}`_tokenize_prompt_input_or_inputs` + A simpler implementation of + [`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs] that assumes multiple inputs. 
""" for text in prompt_inputs: @@ -320,10 +629,11 @@ def _tokenize_prompt_input_or_inputs( self, request: AnyRequest, tokenizer: AnyTokenizer, - input_or_inputs: Union[str, list[str], list[int], list[list[int]]], + input_or_inputs: Optional[Union[str, list[str], list[int], + list[list[int]]]], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, add_special_tokens: bool = True, - ) -> list[TextTokensPrompt]: + ) -> tuple[list[TextTokensPrompt], list[EmbedsPrompt]]: """ Tokenize/detokenize depending on the input format. @@ -331,11 +641,25 @@ def _tokenize_prompt_input_or_inputs( , each input can be a string or array of tokens. Note that each request can pass one or more inputs. """ + inputs_embeds = list[EmbedsPrompt]() + inputs_text = list[TextTokensPrompt]() + + if (isinstance(request, CompletionRequest) + and request.prompt_embeds is not None): + inputs_embeds.extend( + self._load_prompt_embeds(request.prompt_embeds, + truncate_prompt_tokens)) + + # Empty prompts are okay as long as there are prompt embeddings + if input_or_inputs is None or (inputs_embeds + and input_or_inputs == ""): + return [], inputs_embeds + # Although our type checking is based on mypy, # VSCode Pyright extension should still work properly - # "is True" is required for Pyright to perform type narrowing + # "is False" is required for Pyright to perform type narrowing # See: https://github.com/microsoft/pyright/issues/7672 - return [ + inputs_text.extend([ self._normalize_prompt_text_to_input( request, tokenizer, @@ -349,29 +673,88 @@ def _tokenize_prompt_input_or_inputs( prompt_ids=prompt_input["content"], truncate_prompt_tokens=truncate_prompt_tokens) for prompt_input in parse_and_batch_prompt(input_or_inputs) - ] + ]) + + return inputs_text, inputs_embeds + @overload async def _preprocess_completion( self, - request: CompletionLikeRequest, + request: Union[DetokenizeRequest, EmbeddingCompletionRequest, + RerankRequest, ClassificationRequest, ScoreRequest, + 
TokenizeCompletionRequest], tokenizer: AnyTokenizer, input_or_inputs: Union[str, list[str], list[int], list[list[int]]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ..., + add_special_tokens: bool = ..., + ) -> tuple[list[TextTokensPrompt], list[EngineTokensPrompt]]: + ... + + @overload + async def _preprocess_completion( + self, + request: CompletionRequest, + tokenizer: AnyTokenizer, + input_or_inputs: Optional[Union[str, list[str], list[int], + list[list[int]]]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ..., + add_special_tokens: bool = ..., + ) -> tuple[list[Union[TextTokensPrompt, EmbedsPrompt]], list[Union[ + EngineTokensPrompt, EngineEmbedsPrompt]]]: + ... + + async def _preprocess_completion( + self, + request: CompletionLikeRequest, + tokenizer: AnyTokenizer, + input_or_inputs: Optional[Union[str, list[str], list[int], + list[list[int]]]], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, add_special_tokens: bool = True, - ) -> tuple[list[TextTokensPrompt], list[TokensPrompt]]: - request_prompts = await self._tokenize_prompt_input_or_inputs_async( - request, - tokenizer, - input_or_inputs, - truncate_prompt_tokens=truncate_prompt_tokens, - add_special_tokens=add_special_tokens, - ) + ) -> tuple[Union[list[TextTokensPrompt], list[Union[ + TextTokensPrompt, EmbedsPrompt]]], Union[ + list[EngineTokensPrompt], list[Union[EngineTokensPrompt, + EngineEmbedsPrompt]]]]: + if not isinstance(request, + CompletionRequest) and input_or_inputs is None: + raise ValueError( + "Prompt embeds with non-completion requests is not" + " currently supported.") + + (request_prompts_text, request_prompts_embeds + ) = await self._tokenize_prompt_input_or_inputs_async( + request, + tokenizer, + input_or_inputs, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens, + ) + + engine_prompts_text = [ + EngineTokensPrompt( + prompt_token_ids=request_prompt_text["prompt_token_ids"]) + 
for request_prompt_text in request_prompts_text + ] - engine_prompts = [ - TokensPrompt(prompt_token_ids=request_prompt["prompt_token_ids"]) - for request_prompt in request_prompts + # This check is equivalent to simply checking if + # `request_prompts_embeds` is empty, but it's difficult to propagate + # overloads to the private helper functions to enable this check. + # This overload is needed because only TextPrompts are allowed for + # non-completion requests and if we don't add the overload here, + # everywhere this function is used outside of serving_completion will + # need logic asserting that only text prompts are in the request. + if not isinstance(request, + CompletionRequest) and input_or_inputs is not None: + return request_prompts_text, engine_prompts_text + + engine_prompts_embeds = [ + EngineEmbedsPrompt( + prompt_embeds=request_prompt_embeds["prompt_embeds"]) + for request_prompt_embeds in request_prompts_embeds ] + request_prompts = request_prompts_embeds + request_prompts_text + engine_prompts = engine_prompts_embeds + engine_prompts_text return request_prompts, engine_prompts async def _preprocess_chat( @@ -390,15 +773,15 @@ async def _preprocess_chat( truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, add_special_tokens: bool = False, ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt], - list[TokensPrompt]]: + list[EngineTokensPrompt]]: model_config = self.model_config resolved_content_format = resolve_chat_template_content_format( - model_config, chat_template, tool_dicts, chat_template_content_format, tokenizer, + model_config=model_config, ) conversation, mm_data_future = parse_chat_messages_futures( messages, @@ -425,9 +808,9 @@ async def _preprocess_chat( ) else: request_prompt = apply_hf_chat_template( - model_config, - tokenizer, + tokenizer=tokenizer, conversation=conversation, + model_config=model_config, **_chat_template_kwargs, ) @@ -463,7 +846,7 @@ async def _preprocess_chat( 
prompt=tokenizer.decode(request_prompt), prompt_token_ids=request_prompt) - engine_prompt = TokensPrompt( + engine_prompt = EngineTokensPrompt( prompt_token_ids=prompt_inputs["prompt_token_ids"]) if mm_data is not None: engine_prompt["multi_modal_data"] = mm_data @@ -475,6 +858,35 @@ async def _preprocess_chat( return conversation, [request_prompt], [engine_prompt] + def _load_prompt_embeds( + self, + prompt_embeds: Optional[Union[bytes, list[bytes]]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + ) -> list[EmbedsPrompt]: + + def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt: + tensor = torch.load(io.BytesIO(base64.b64decode(embed)), + weights_only=True) + assert isinstance( + tensor, + (torch.FloatTensor, torch.BFloat16Tensor, torch.HalfTensor)) + if tensor.dim() > 2: + tensor = tensor.squeeze(0) + assert tensor.dim() == 2 + if truncate_prompt_tokens is not None: + tensor = tensor[-truncate_prompt_tokens:] + return {"prompt_embeds": tensor} + + if prompt_embeds: + if isinstance(prompt_embeds, list): + return [ + _load_and_validate_embed(embed) for embed in prompt_embeds + ] + else: + return [_load_and_validate_embed(prompt_embeds)] + else: + return [] + def _log_inputs( self, request_id: str, @@ -486,13 +898,13 @@ def _log_inputs( ) -> None: if self.request_logger is None: return - + prompt, prompt_token_ids, prompt_embeds = None, None, None if isinstance(inputs, str): prompt = inputs - prompt_token_ids = None elif isinstance(inputs, list): - prompt = None prompt_token_ids = inputs + elif 'prompt_embeds' in inputs: + prompt_embeds = inputs.get("prompt_embeds") else: prompt = inputs["prompt"] prompt_token_ids = inputs["prompt_token_ids"] @@ -501,6 +913,7 @@ def _log_inputs( request_id, prompt, prompt_token_ids, + prompt_embeds, params=params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 
c642fc51005e..5ef1a486d86c 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -65,6 +65,8 @@ async def create_tokenize( tokenizer = await self.engine_client.get_tokenizer(lora_request) if isinstance(request, TokenizeChatRequest): + tool_dicts = (None if request.tools is None else + [tool.model_dump() for tool in request.tools]) ( _, request_prompts, @@ -73,6 +75,7 @@ async def create_tokenize( request, tokenizer, request.messages, + tool_dicts=tool_dicts, chat_template=request.chat_template or self.chat_template, chat_template_content_format=self. chat_template_content_format, @@ -91,7 +94,7 @@ async def create_tokenize( ) except (ValueError, TypeError, jinja2.TemplateError) as e: logger.exception("Error in preprocessing prompt inputs") - return self.create_error_response(str(e)) + return self.create_error_response(f"{e} {e.__cause__}") input_ids: list[int] = [] for i, engine_prompt in enumerate(engine_prompts): @@ -103,8 +106,9 @@ async def create_tokenize( # Silently ignore prompt adapter since it does not affect # tokenization (Unlike in Embeddings API where an error is raised) - - input_ids.extend(engine_prompt["prompt_token_ids"]) + if isinstance(engine_prompt, + dict) and "prompt_token_ids" in engine_prompt: + input_ids.extend(engine_prompt["prompt_token_ids"]) return TokenizeResponse(tokens=input_ids, count=len(input_ids), diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index b81dc4e7ad7b..054c0b006b2f 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -1,11 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 from .abstract_tool_parser import ToolParser, ToolParserManager +from .deepseekv3_tool_parser import DeepSeekV3ToolParser from .granite_20b_fc_tool_parser import Granite20bFCToolParser from .granite_tool_parser import GraniteToolParser from .hermes_tool_parser 
import Hermes2ProToolParser from .internlm2_tool_parser import Internlm2ToolParser from .jamba_tool_parser import JambaToolParser +from .llama4_pythonic_tool_parser import Llama4PythonicToolParser from .llama_tool_parser import Llama3JsonToolParser from .mistral_tool_parser import MistralToolParser from .phi4mini_tool_parser import Phi4MiniJsonToolParser @@ -15,5 +17,6 @@ "ToolParser", "ToolParserManager", "Granite20bFCToolParser", "GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser", - "PythonicToolParser", "Phi4MiniJsonToolParser" + "Llama4PythonicToolParser", "PythonicToolParser", "Phi4MiniJsonToolParser", + "DeepSeekV3ToolParser" ] diff --git a/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py new file mode 100644 index 000000000000..14e743e13a72 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py @@ -0,0 +1,369 @@ +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Sequence +from typing import Union + +import regex as re + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("deepseek_v3") +class DeepSeekV3ToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + self.current_tool_name_sent: bool = False + self.prev_tool_call_arr: list[dict] = [] + self.current_tool_id: int = -1 + self.streamed_args_for_tool: list[str] = ( + []) # map what has been streamed for each tool so far 
to a list + + self.tool_calls_start_token: str = "<｜tool▁calls▁begin｜>" + self.tool_calls_end_token: str = "<｜tool▁calls▁end｜>" + + self.tool_call_start_token: str = "<｜tool▁call▁begin｜>" + self.tool_call_end_token: str = "<｜tool▁call▁end｜>" + + self.tool_call_regex = re.compile( + r"<｜tool▁call▁begin｜>(?P<type>.*)<｜tool▁sep｜>(?P<function_name>.*)\n```json\n(?P<function_arguments>.*)\n```<｜tool▁call▁end｜>" + ) + + self.stream_tool_call_portion_regex = re.compile( + r"(?P<type>.*)<｜tool▁sep｜>(?P<function_name>.*)\n```json\n(?P<function_arguments>.*[^\n`])" + ) + + self.stream_tool_call_name_regex = re.compile( + r"(?P<type>.*)<｜tool▁sep｜>(?P<function_name>.*)\n") + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction.") + self.tool_calls_start_token_id = self.vocab.get( + self.tool_calls_start_token) + self.tool_calls_end_token_id = self.vocab.get( + self.tool_calls_end_token) + + self.tool_call_start_token_id = self.vocab.get( + self.tool_call_start_token) + self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) + + if (self.tool_calls_start_token_id is None + or self.tool_calls_end_token_id is None): + raise RuntimeError( + "DeepSeek-V3 Tool parser could not locate tool call start/end " + "tokens in the tokenizer!") + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + + # sanity check; avoid unnecessary processing + if self.tool_calls_start_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + else: + try: + # there are two possible captures - between tags, or between a + # tag and end-of-string so the result of + # findall is an array of tuples where one is a function call and + # the other is None + function_call_tuples = 
self.tool_call_regex.findall( + model_output) + + tool_calls = [] + for match in function_call_tuples: + tool_type, function_name, function_args = match + tool_calls.append( + ToolCall( + type=tool_type, + function=FunctionCall(name=function_name, + arguments=function_args), + )) + + content = model_output[:model_output. + find(self.tool_calls_start_token)] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if content else None, + ) + + except Exception: + logger.exception( + "Error in extracting tool call from response.") + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + logger.debug("delta_text: %s", delta_text) + logger.debug("delta_token_ids: %s", delta_token_ids) + # check to see if we should be streaming a tool call - is there a + if self.tool_calls_start_token_id not in current_token_ids: + logger.debug("No tool call tokens found!") + return DeltaMessage(content=delta_text) + delta_text = delta_text.replace(self.tool_calls_start_token, + "").replace(self.tool_calls_end_token, + "") + try: + + # figure out where we are in the parsing by counting tool call + # start & end tags + prev_tool_start_count = previous_token_ids.count( + self.tool_call_start_token_id) + prev_tool_end_count = previous_token_ids.count( + self.tool_call_end_token_id) + cur_tool_start_count = current_token_ids.count( + self.tool_call_start_token_id) + cur_tool_end_count = current_token_ids.count( + self.tool_call_end_token_id) + tool_call_portion = None + text_portion = None + + # case: if we're generating text, OR rounding out a tool call + if (cur_tool_start_count == cur_tool_end_count + and 
prev_tool_end_count == cur_tool_end_count + and self.tool_call_end_token not in delta_text): + logger.debug("Generating text content! skipping tool parsing.") + return DeltaMessage(content=delta_text) + + if self.tool_call_end_token in delta_text: + logger.debug("tool_call_end_token in delta_text") + full_text = current_text + delta_text + tool_call_portion = full_text.split( + self.tool_call_start_token)[-1].split( + self.tool_call_end_token)[0].rstrip() + delta_text = delta_text.split( + self.tool_call_end_token)[0].rstrip() + text_portion = delta_text.split( + self.tool_call_end_token)[-1].lstrip() + + # case -- we're starting a new tool call + if (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count): + if len(delta_token_ids) > 1: + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + else: + tool_call_portion = None + delta = None + + text_portion = None + + # set cursors and state appropriately + self.current_tool_id += 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("Starting on a new tool %s", self.current_tool_id) + + # case -- we're updating an existing tool call + elif (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count): + + # get the portion of the text that's the tool call + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + text_portion = None + + # case -- the current tool call is being closed. 
+ elif (cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count >= prev_tool_end_count): + if self.prev_tool_call_arr is None or len( + self.prev_tool_call_arr) == 0: + logger.debug( + "attempting to close tool call, but no tool call") + return None + diff = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments") + if diff: + diff = (diff.encode("utf-8").decode("unicode_escape") + if diff is str else diff) + if '"}' not in delta_text: + return None + end_loc = delta_text.rindex('"}') + diff = delta_text[:end_loc] + '"}' + logger.debug( + "Finishing tool and found diff that had not " + "been streamed yet: %s", + diff, + ) + self.streamed_args_for_tool[self.current_tool_id] += diff + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump(exclude_none=True), + ) + ]) + + # case -- otherwise we're just generating text + else: + text = delta_text.replace(self.tool_call_start_token, "") + text = text.replace(self.tool_call_end_token, "") + delta = DeltaMessage(tool_calls=[], content=text) + return delta + + current_tool_call = dict() + if tool_call_portion: + current_tool_call_matches = ( + self.stream_tool_call_portion_regex.match( + tool_call_portion)) + if current_tool_call_matches: + tool_type, tool_name, tool_args = ( + current_tool_call_matches.groups()) + current_tool_call["name"] = tool_name + current_tool_call["arguments"] = tool_args + else: + current_tool_call_name_matches = ( + self.stream_tool_call_name_regex.match( + tool_call_portion)) + if current_tool_call_name_matches: + tool_type, tool_name = ( + current_tool_call_name_matches.groups()) + current_tool_call["name"] = tool_name + current_tool_call["arguments"] = "" + else: + logger.debug("Not enough token") + return None + + # case - we haven't sent the tool name yet. If it's available, send + # it. otherwise, wait until it's available. 
+ if not self.current_tool_name_sent: + if current_tool_call is None: + return None + function_name: Union[str, None] = current_tool_call.get("name") + if function_name: + self.current_tool_name_sent = True + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True), + ) + ]) + else: + return None + + # case -- otherwise, send the tool call delta + + # if the tool call portion is None, send the delta as text + if tool_call_portion is None: + # if there's text but not tool calls, send that - + # otherwise None to skip chunk + delta = (DeltaMessage( + content=delta_text) if text_portion is not None else None) + return delta + + # now, the nitty-gritty of tool calls + # now we have the portion to parse as tool call. + + logger.debug("Trying to parse current tool call with ID %s", + self.current_tool_id) + + # if we're starting a new tool call, push an empty object in as + # a placeholder for the arguments + if len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + + # main logic for tool parsing here - compare prev. partially-parsed + # JSON to the current partially-parsed JSON + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments") + cur_arguments = current_tool_call.get("arguments") + + logger.debug("diffing old arguments: %s", prev_arguments) + logger.debug("against new ones: %s", cur_arguments) + + # case -- no arguments have been created yet. skip sending a delta. + if not cur_arguments and not prev_arguments: + logger.debug("Skipping text %s - no arguments", delta_text) + delta = None + + # case -- prev arguments are defined, but non are now. + # probably impossible, but not a fatal error - just keep going + elif not cur_arguments and prev_arguments: + logger.error("should be impossible to have arguments reset " + "mid-call. 
skipping streaming anything.") + delta = None + + # case -- we now have the first info about arguments available from + # autocompleting the JSON + elif cur_arguments and not prev_arguments: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=cur_arguments).model_dump( + exclude_none=True), + ) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] = cur_arguments + + # last case -- we have an update to existing arguments. + elif cur_arguments and prev_arguments: + if (isinstance(delta_text, str) + and cur_arguments != prev_arguments + and len(cur_arguments) > len(prev_arguments) + and cur_arguments.startswith(prev_arguments)): + delta_arguments = cur_arguments[len(prev_arguments):] + logger.debug("got diff %s", delta_text) + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=delta_arguments).model_dump( + exclude_none=True), + ) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] = cur_arguments + else: + delta = None + + # handle saving the state for the current tool into + # the "prev" list for use in diffing for the next iteration + if self.current_tool_id == len(self.prev_tool_call_arr) - 1: + self.prev_tool_call_arr[ + self.current_tool_id] = current_tool_call + else: + self.prev_tool_call_arr.append(current_tool_call) + + return delta + + except Exception: + logger.exception("Error trying to handle streaming tool call.") + return None # do not stream a delta. skip this token ID. 
diff --git a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py index 76da63c58008..383e0d44de99 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py @@ -1,14 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 import json -import re from collections.abc import Sequence from json import JSONDecoder from typing import Union import partial_json_parser +import regex as re from partial_json_parser.core.options import Allow +from vllm.entrypoints.chat_utils import random_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -22,7 +23,6 @@ partial_json_loads) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import random_uuid logger = init_logger(__name__) @@ -80,7 +80,8 @@ def extract_tool_calls( function=FunctionCall( name=function_call["name"], # function call args are JSON but as a string - arguments=json.dumps(function_call["arguments"]), + arguments=json.dumps(function_call["arguments"], + ensure_ascii=False), ), ) for function_call in raw_function_calls ] @@ -166,7 +167,8 @@ def extract_tool_calls_streaming( if self.current_tool_id >= 0: cur_arguments = current_tool_call.get("arguments") if cur_arguments: - cur_args_json = json.dumps(cur_arguments) + cur_args_json = json.dumps(cur_arguments, + ensure_ascii=False) sent = len( self.streamed_args_for_tool[self.current_tool_id]) argument_diff = cur_args_json[sent:] @@ -200,7 +202,7 @@ def extract_tool_calls_streaming( delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=f"chatcmpl-tool-{random_uuid()}", + id=random_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) @@ -218,7 +220,8 @@ def 
extract_tool_calls_streaming( if cur_arguments: sent = len( self.streamed_args_for_tool[self.current_tool_id]) - cur_args_json = json.dumps(cur_arguments) + cur_args_json = json.dumps(cur_arguments, + ensure_ascii=False) prev_arguments = self.prev_tool_call_arr[ self.current_tool_id].get("arguments") @@ -226,7 +229,8 @@ def extract_tool_calls_streaming( if is_complete[self.current_tool_id]: argument_diff = cur_args_json[sent:] elif prev_arguments: - prev_args_json = json.dumps(prev_arguments) + prev_args_json = json.dumps(prev_arguments, + ensure_ascii=False) if cur_args_json != prev_args_json: prefix = find_common_prefix( diff --git a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py index 91afc88ef3dd..b8bf142530ee 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py @@ -7,6 +7,7 @@ import partial_json_parser from partial_json_parser.core.options import Allow +from vllm.entrypoints.chat_utils import random_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -20,7 +21,6 @@ partial_json_loads) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import random_uuid logger = init_logger(__name__) @@ -67,7 +67,8 @@ def extract_tool_calls( function=FunctionCall( name=function_call["name"], # function call args are JSON but as a string - arguments=json.dumps(function_call["arguments"]), + arguments=json.dumps(function_call["arguments"], + ensure_ascii=False), ), ) for function_call in raw_function_calls ] @@ -151,7 +152,8 @@ def extract_tool_calls_streaming( if self.current_tool_id >= 0: cur_arguments = current_tool_call.get("arguments") if cur_arguments: - cur_args_json = json.dumps(cur_arguments) + cur_args_json = json.dumps(cur_arguments, + ensure_ascii=False) sent = len( 
self.streamed_args_for_tool[self.current_tool_id]) argument_diff = cur_args_json[sent:] @@ -182,7 +184,7 @@ def extract_tool_calls_streaming( delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=f"chatcmpl-tool-{random_uuid()}", + id=random_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) @@ -197,7 +199,8 @@ def extract_tool_calls_streaming( if cur_arguments: sent = len( self.streamed_args_for_tool[self.current_tool_id]) - cur_args_json = json.dumps(cur_arguments) + cur_args_json = json.dumps(cur_arguments, + ensure_ascii=False) prev_arguments = self.prev_tool_call_arr[ self.current_tool_id].get("arguments") @@ -205,7 +208,8 @@ def extract_tool_calls_streaming( if is_complete[self.current_tool_id]: argument_diff = cur_args_json[sent:] elif prev_arguments: - prev_args_json = json.dumps(prev_arguments) + prev_args_json = json.dumps(prev_arguments, + ensure_ascii=False) if cur_args_json != prev_args_json: prefix = find_common_prefix( prev_args_json, cur_args_json) diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index 4c39e9b0c61f..2b9f9852bcb3 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -1,13 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 import json -import re from collections.abc import Sequence from typing import Union import partial_json_parser +import regex as re from partial_json_parser.core.options import Allow +from vllm.entrypoints.chat_utils import random_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -17,7 +18,6 @@ ToolParser, ToolParserManager) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import random_uuid logger = 
init_logger(__name__) @@ -259,7 +259,7 @@ def extract_tool_calls_streaming( return DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=f"chatcmpl-tool-{random_uuid()}", + id=random_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py index 57d7c77c64f7..3f2799f8010a 100644 --- a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -7,6 +7,7 @@ import partial_json_parser from partial_json_parser.core.options import Allow +from vllm.entrypoints.chat_utils import random_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -18,7 +19,6 @@ extract_intermediate_diff) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import random_uuid logger = init_logger(__name__) @@ -106,7 +106,7 @@ def extract_tool_calls_streaming( delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=f"chatcmpl-tool-{random_uuid()}", + id=random_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) @@ -133,7 +133,8 @@ def extract_tool_calls_streaming( delta = None # first time to get parameters elif cur_arguments and not prev_arguments: - cur_arguments_json = json.dumps(cur_arguments) + cur_arguments_json = json.dumps(cur_arguments, + ensure_ascii=False) arguments_delta = cur_arguments_json[:cur_arguments_json. 
index(delta_text) + @@ -148,8 +149,10 @@ def extract_tool_calls_streaming( self.current_tool_id] += arguments_delta # both prev and cur parameters, send the increase parameters elif cur_arguments and prev_arguments: - cur_args_json = json.dumps(cur_arguments) - prev_args_json = json.dumps(prev_arguments) + cur_args_json = json.dumps(cur_arguments, + ensure_ascii=False) + prev_args_json = json.dumps(prev_arguments, + ensure_ascii=False) argument_diff = extract_intermediate_diff( cur_args_json, prev_args_json) @@ -190,7 +193,8 @@ def extract_tool_calls( action_dict = json.loads(action) name, parameters = action_dict['name'], json.dumps( action_dict.get('parameters', action_dict.get('arguments', - {}))) + {})), + ensure_ascii=False) if not tools or name not in [t.function.name for t in tools]: ExtractedToolCallInformation(tools_called=False, diff --git a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py index 8df106bf2718..2714a545f997 100644 --- a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py @@ -1,13 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 import json -import re from collections.abc import Sequence from typing import Union import partial_json_parser +import regex as re from partial_json_parser.core.options import Allow +from vllm.entrypoints.chat_utils import random_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -19,7 +20,6 @@ from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizers import MistralTokenizer -from vllm.utils import random_uuid logger = init_logger(__name__) @@ -96,8 +96,9 @@ def extract_tool_calls( function=FunctionCall( name=function_call["name"], # function call args are JSON but as a string - 
arguments=json.dumps(function_call["arguments"]))) - for function_call in raw_function_calls + arguments=json.dumps(function_call["arguments"], + ensure_ascii=False), + )) for function_call in raw_function_calls ] content = model_output[:model_output. @@ -187,7 +188,7 @@ def extract_tool_calls_streaming( diff: Union[str, None] = current_tool_call.get("arguments") if diff: - diff = json.dumps(diff).replace( + diff = json.dumps(diff, ensure_ascii=False).replace( self.streamed_args_for_tool[self.current_tool_id], "") delta = DeltaMessage(tool_calls=[ @@ -220,7 +221,7 @@ def extract_tool_calls_streaming( delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=f"chatcmpl-tool-{random_uuid()}", + id=random_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) @@ -248,7 +249,8 @@ def extract_tool_calls_streaming( "mid-arguments") delta = None elif cur_arguments and not prev_arguments: - cur_arguments_json = json.dumps(cur_arguments) + cur_arguments_json = json.dumps(cur_arguments, + ensure_ascii=False) logger.debug("finding %s in %s", new_text, cur_arguments_json) @@ -267,8 +269,10 @@ def extract_tool_calls_streaming( self.current_tool_id] += arguments_delta elif cur_arguments and prev_arguments: - cur_args_json = json.dumps(cur_arguments) - prev_args_json = json.dumps(prev_arguments) + cur_args_json = json.dumps(cur_arguments, + ensure_ascii=False) + prev_args_json = json.dumps(prev_arguments, + ensure_ascii=False) logger.debug("Searching for diff between \n%s\n%s", cur_args_json, prev_args_json) diff --git a/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py new file mode 100644 index 000000000000..858c8db99fd2 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py @@ -0,0 +1,302 @@ +# SPDX-License-Identifier: Apache-2.0 +import ast +import json +from collections.abc 
import Sequence +from typing import Any, Union + +import regex as re +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class _UnexpectedAstError(Exception): + pass + + +@ToolParserManager.register_module("llama4_pythonic") +class Llama4PythonicToolParser(ToolParser): + """ + Toolcall parser for Llama4 that produce tool calls in a pythonic style + Use --enable-auto-tool-choice --tool-call-parser llama4_pythonic + """ + # TODO(mdepinet): Possible future improvements: + # 1. Support text + tools separated by either <|python_tag|> or \n\n + # 2. Support tools outside of a list (or separated by a semicolon). + # This depends on item 1 for consistent streaming. + # Neither of these are necessary for e.g. ToolACE, but both would help make + # Llama3.2 models more reliable. + + TOOL_CALL_REGEX = re.compile( + r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]", + re.DOTALL) + + def __init__(self, tokenizer: PreTrainedTokenizerBase): + super().__init__(tokenizer) + + # Rename for readability. This is NOT a tool id. + @property + def current_tool_index(self) -> int: + return self.current_tool_id + + @current_tool_index.setter + def current_tool_index(self, value: int) -> None: + self.current_tool_id = value + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + """ + Extract the tool calls from a complete model response. 
+ """ + + # remove <|python_start|> and <|python_end|> + # as Llama 4 model sometime will output those tokens + if model_output.startswith("<|python_start|>"): + model_output = model_output[len("<|python_start|>"):] + model_output = model_output.replace("<|python_end|>", "") + if not (self.TOOL_CALL_REGEX.match(model_output)): + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + try: + module = ast.parse(model_output) + parsed = getattr(module.body[0], "value", None) + if isinstance(parsed, ast.List) and all( + isinstance(e, ast.Call) for e in parsed.elts): + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=[ + _handle_single_tool(e) # type: ignore + for e in parsed.elts + ], + content=None) + else: + raise _UnexpectedAstError( + "Tool output must be a list of function calls") + except Exception: + logger.exception("Error in extracting tool call from response.") + # Treat as regular text + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + if not current_text.startswith("[") and not current_text.startswith( + "<|python_start|>"): + return DeltaMessage(content=delta_text) + + try: + # remove <|python_start|> and <|python_end|> + if current_text.startswith("<|python_start|>"): + current_text = current_text[len("<|python_start|>"):] + if current_text.endswith("<|python_end|>"): + current_text = current_text[:current_text. 
+ rfind("<|python_end|>")] + valid_and_added_text = _make_valid_python(current_text) + if valid_and_added_text is None: + return None + valid_text, added_text = valid_and_added_text + + module = ast.parse(valid_text) + parsed = getattr(module.body[0], "value", None) + if not isinstance(parsed, ast.List) or not all( + isinstance(e, ast.Call) for e in parsed.elts): + raise _UnexpectedAstError( + "Tool output must be a list of function calls") + tool_calls = [ + _handle_single_tool(e) # type: ignore + for e in parsed.elts + ] + + tool_deltas = [] + for index, new_call in enumerate(tool_calls): + if index < self.current_tool_index: + continue + + self.current_tool_index = index + if len(self.streamed_args_for_tool) == index: + self.streamed_args_for_tool.append("") + + new_call_complete = index < len( + tool_calls) - 1 or ")]" not in added_text + if new_call_complete: + self.current_tool_index += 1 + + withheld_suffix = (added_text[:-2] + if not new_call_complete else "") + if not new_call_complete and added_text[-2] == ")": + # Function call is incomplete. Withhold the closing bracket. + withheld_suffix = withheld_suffix + "}" + # Strings get single quotes in the model-produced string. + # JSON requires double quotes. + withheld_suffix = withheld_suffix.replace("'", '"') + delta = _compute_tool_delta(self.streamed_args_for_tool[index], + new_call, index, withheld_suffix) + + if delta is not None: + tool_deltas.append(delta) + if (delta.function is not None + and delta.function.arguments is not None): + self.streamed_args_for_tool[ + index] += delta.function.arguments + + # HACK: serving_chat.py inspects the internal state of tool parsers + # when determining it's final streaming delta, automatically + # adding autocompleted JSON. + # These two lines avoid that nonsense while ensuring finish_reason + # is set to tool_calls when at least one tool is called. 
+ if tool_deltas and not self.prev_tool_call_arr: + self.prev_tool_call_arr = [{"arguments": {}}] + + if tool_deltas: + return DeltaMessage(tool_calls=tool_deltas) + elif not added_text and self.current_tool_id > 0: + # Return an empty DeltaMessage once the tool calls are all done + # so that finish_reason gets set. + return DeltaMessage(content='') + else: + return None + except Exception: + logger.exception("Error trying to handle streaming tool call.") + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None + + +def _get_parameter_value(val: ast.expr) -> Any: + if isinstance(val, ast.Constant): + return val.value + elif isinstance(val, ast.Dict): + if not all(isinstance(k, ast.Constant) for k in val.keys): + raise _UnexpectedAstError( + "Dict tool call arguments must have literal keys") + return { + k.value: _get_parameter_value(v) # type: ignore + for k, v in zip(val.keys, val.values) + } + elif isinstance(val, ast.List): + return [_get_parameter_value(v) for v in val.elts] + else: + raise _UnexpectedAstError("Tool call arguments must be literals") + + +def _handle_single_tool(call: ast.Call) -> ToolCall: + if not isinstance(call.func, ast.Name): + raise _UnexpectedAstError("Invalid tool call name") + function_name = call.func.id + arguments = {} + for keyword in call.keywords: + arguments[keyword.arg] = _get_parameter_value(keyword.value) + return ToolCall(type="function", + function=FunctionCall(name=function_name, + arguments=json.dumps(arguments))) + + +def _make_valid_python(text: str) -> Union[tuple[str, str], None]: + bracket_stack = [] + for index, char in enumerate(text): + if char in {"[", "(", "{"}: + bracket_stack.append(char) + elif char == "]": + if not bracket_stack or bracket_stack.pop() != "[": + raise _UnexpectedAstError("Mismatched square brackets") + elif char == ")": + if not bracket_stack or bracket_stack.pop() != "(": + raise _UnexpectedAstError("Mismatched parentheses") + elif char == "}": + 
if not bracket_stack or bracket_stack.pop() != "{": + raise _UnexpectedAstError("Mismatched curly braces") + elif char in {"'", '"'}: + if bracket_stack and bracket_stack[-1] == char: + if index > 0 and text[index - 1] == "\\": + # Treat an escaped quote as a regular character + pass + else: + bracket_stack.pop() + elif bracket_stack and bracket_stack[-1] in {"'", '"'}: + # Double quote within a single quote string or vice versa. + pass + else: + bracket_stack.append(char) + + text = text.rstrip() + if text.endswith("=") or text.endswith(":"): + # Since we have no type information for this property/parameter value, + # we can't fill in a valid value. + return None + if bracket_stack and bracket_stack[-1] == "{": + trailing_dict_text = text[:text.rfind("{")] + num_keys = trailing_dict_text.count(":") + num_values = trailing_dict_text.count(",") + if num_keys <= num_values: + return None # Incomplete property name within parameter value + if bracket_stack and bracket_stack[-1] == "(": + trailing_params_text = text[:text.rfind("(")] + num_full_param_names = trailing_params_text.count("=") + num_full_param_values = trailing_params_text.count(",") + if num_full_param_names <= num_full_param_values: + return None # Incomplete parameter name + if text.endswith(","): + text = text[:-1] + if bracket_stack and bracket_stack[-1] == "[" and not text.endswith( + "[") and not text.endswith(")"): + return None # Incomplete function name + + added_text = "" + for char in reversed(bracket_stack): + if char == "[": + added_text += "]" + elif char == "(": + added_text += ")" + elif char == "{": + added_text += "}" + elif char == "'": + added_text += "'" + elif char == '"': + added_text += '"' + + return text + added_text, added_text + + +def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall, + index: int, + withheld_suffix: str) -> Union[DeltaToolCall, None]: + new_call_args = new_call.function.arguments + if withheld_suffix: + assert 
new_call_args.endswith(withheld_suffix) + new_call_args = new_call_args[:-len(withheld_suffix)] + if not previously_sent_args: + return DeltaToolCall(id=new_call.id, + type="function", + index=index, + function=DeltaFunctionCall( + name=new_call.function.name, + arguments=new_call_args, + )) + + arg_diff = new_call_args[len(previously_sent_args):] + return DeltaToolCall( + id=None, index=index, function=DeltaFunctionCall( + arguments=arg_diff)) if arg_diff else None diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py index 5c181616aa01..4eda7044cbba 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py @@ -1,15 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 import json -import re from collections.abc import Sequence from json import JSONDecoder from typing import Union import partial_json_parser +import regex as re from partial_json_parser.core.options import Allow from transformers import PreTrainedTokenizerBase +from vllm.entrypoints.chat_utils import random_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -21,7 +22,6 @@ is_complete_json, partial_json_loads) from vllm.logger import init_logger -from vllm.utils import random_uuid logger = init_logger(__name__) @@ -88,7 +88,8 @@ def extract_tool_calls( # function call args are JSON but as a string arguments=json.dumps(raw_function_call["arguments"] \ if "arguments" in raw_function_call \ - else raw_function_call["parameters"]))) + else raw_function_call["parameters"], + ensure_ascii=False))) for raw_function_call in function_call_arr ] @@ -174,7 +175,8 @@ def extract_tool_calls_streaming( if self.current_tool_id >= 0: cur_arguments = current_tool_call.get("arguments") if cur_arguments: - cur_args_json = json.dumps(cur_arguments) + cur_args_json = json.dumps(cur_arguments, + 
ensure_ascii=False) sent = len( self.streamed_args_for_tool[self.current_tool_id]) argument_diff = cur_args_json[sent:] @@ -208,7 +210,7 @@ def extract_tool_calls_streaming( delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=f"chatcmpl-tool-{random_uuid()}", + id=random_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) @@ -226,7 +228,8 @@ def extract_tool_calls_streaming( if cur_arguments: sent = len( self.streamed_args_for_tool[self.current_tool_id]) - cur_args_json = json.dumps(cur_arguments) + cur_args_json = json.dumps(cur_arguments, + ensure_ascii=False) prev_arguments = self.prev_tool_call_arr[ self.current_tool_id].get("arguments") @@ -234,7 +237,8 @@ def extract_tool_calls_streaming( if is_complete[self.current_tool_id]: argument_diff = cur_args_json[sent:] elif prev_arguments: - prev_args_json = json.dumps(prev_arguments) + prev_args_json = json.dumps(prev_arguments, + ensure_ascii=False) if cur_args_json != prev_args_json: prefix = find_common_prefix( diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 9dbfe85ecc68..fecad7e653ab 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -1,13 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 import json -import re from collections.abc import Sequence from random import choices from string import ascii_letters, digits from typing import Union import partial_json_parser +import regex as re from partial_json_parser.core.options import Allow from pydantic import Field diff --git a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py index 668776a832e2..b403a146716d 100644 --- a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py +++ 
b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py @@ -1,12 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 import json -import re from collections.abc import Sequence from typing import Any, Optional +import regex as re from transformers import PreTrainedTokenizerBase +from vllm.entrypoints.chat_utils import random_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaMessage, ExtractedToolCallInformation, @@ -14,7 +15,6 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, ToolParserManager) from vllm.logger import init_logger -from vllm.utils import random_uuid logger = init_logger(__name__) @@ -73,16 +73,17 @@ def extract_tool_calls( tool_calls: list[ToolCall] = [ ToolCall( - id=f"chatcmpl-tool-{random_uuid()}", + id=random_tool_call_id(), type="function", function=FunctionCall( name=raw_function_call["name"], # function call args are JSON but as a string arguments=json.dumps( - raw_function_call["arguments"] if "arguments" in - raw_function_call else - raw_function_call["parameters"]))) - for raw_function_call in function_call_arr + raw_function_call["arguments"] + if "arguments" in raw_function_call else + raw_function_call["parameters"], + ensure_ascii=False), + )) for raw_function_call in function_call_arr ] # get any content before the tool call diff --git a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py index 9f141d6b334b..548ff39d1ca4 100644 --- a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py @@ -2,10 +2,10 @@ import ast import json -import re from collections.abc import Sequence from typing import Any, Union +import regex as re from transformers import PreTrainedTokenizerBase from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, @@ -200,9 +200,12 @@ def _handle_single_tool(call: ast.Call) -> ToolCall: 
arguments = {} for keyword in call.keywords: arguments[keyword.arg] = _get_parameter_value(keyword.value) - return ToolCall(type="function", - function=FunctionCall(name=function_name, - arguments=json.dumps(arguments))) + return ToolCall( + type="function", + function=FunctionCall(name=function_name, + arguments=json.dumps(arguments, + ensure_ascii=False)), + ) def _make_valid_python(text: str) -> Union[tuple[str, str], None]: @@ -280,6 +283,7 @@ def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall, new_call_args = new_call_args[:-len(withheld_suffix)] if not previously_sent_args: return DeltaToolCall(id=new_call.id, + type="function", index=index, function=DeltaFunctionCall( name=new_call.function.name, @@ -288,5 +292,5 @@ def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall, arg_diff = new_call_args[len(previously_sent_args):] return DeltaToolCall( - id="", index=index, function=DeltaFunctionCall( + id=None, index=index, function=DeltaFunctionCall( arguments=arg_diff)) if arg_diff else None diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index 2fe6e1a9e9c4..cc651a172b40 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -13,6 +13,13 @@ logger = init_logger(__name__) +VLLM_SERVE_PARSER_EPILOG = ( + "Tip: Use `vllm serve --help=<keyword>` to explore arguments from help.\n" + " - To view a argument group: --help=ModelConfig\n" + " - To view a single argument: --help=max-num-seqs\n" + " - To search by keyword: --help=max\n" + " - To list all groups: --help=listgroup") + async def listen_for_disconnect(request: Request) -> None: """Returns if a disconnect message is received""" @@ -158,3 +165,55 @@ def _validate_truncation_size( tokenization_kwargs["max_length"] = truncate_prompt_tokens return truncate_prompt_tokens + + +def show_filtered_argument_or_group_from_help(parser): + import sys + for arg in sys.argv: + if arg.startswith('--help='): + search_keyword = arg.split('=', 1)[1] + + # List 
available groups + if search_keyword == 'listgroup': + print("\nAvailable argument groups:") + for group in parser._action_groups: + if group.title and not group.title.startswith( + "positional arguments"): + print(f" - {group.title}") + if group.description: + print(" " + group.description.strip()) + print() + sys.exit(0) + + # For group search + formatter = parser._get_formatter() + for group in parser._action_groups: + if group.title and group.title.lower() == search_keyword.lower( + ): + formatter.start_section(group.title) + formatter.add_text(group.description) + formatter.add_arguments(group._group_actions) + formatter.end_section() + print(formatter.format_help()) + sys.exit(0) + + # For single arg + matched_actions = [] + + for group in parser._action_groups: + for action in group._group_actions: + # search option name + if any(search_keyword.lower() in opt.lower() + for opt in action.option_strings): + matched_actions.append(action) + + if matched_actions: + print(f"\nParameters matching '{search_keyword}':\n") + formatter = parser._get_formatter() + formatter.add_arguments(matched_actions) + print(formatter.format_help()) + sys.exit(0) + + print(f"\nNo group or parameter matching '{search_keyword}'") + print("Tip: use `--help=listgroup` to view all groups.") + sys.exit(1) diff --git a/vllm/envs.py b/vllm/envs.py index 134cdf9905fa..b007bf8c59b7 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -55,6 +55,7 @@ VLLM_IMAGE_FETCH_TIMEOUT: int = 5 VLLM_VIDEO_FETCH_TIMEOUT: int = 30 VLLM_AUDIO_FETCH_TIMEOUT: int = 10 + VLLM_VIDEO_LOADER_BACKEND: str = "opencv" VLLM_MM_INPUT_CACHE_GIB: int = 8 VLLM_TARGET_DEVICE: str = "cuda" MAX_JOBS: Optional[str] = None @@ -68,6 +69,7 @@ VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False VLLM_RPC_TIMEOUT: int = 10000 # ms VLLM_PLUGINS: Optional[list[str]] = None + VLLM_LORA_RESOLVER_CACHE_DIR: Optional[str] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None VLLM_USE_TRITON_AWQ: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = 
False @@ -112,6 +114,10 @@ VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False + VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost" + VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557 + VLLM_ALL2ALL_BACKEND: str = "naive" + VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840 def get_default_cache_root(): @@ -134,10 +140,43 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: return int(value) +def get_vllm_port() -> Optional[int]: + """Get the port from VLLM_PORT environment variable. + + Returns: + The port number as an integer if VLLM_PORT is set, None otherwise. + + Raises: + ValueError: If VLLM_PORT is a URI, suggest k8s service discovery issue. + """ + if 'VLLM_PORT' not in os.environ: + return None + + port = os.getenv('VLLM_PORT', '0') + + try: + return int(port) + except ValueError as err: + from urllib.parse import urlparse + try: + parsed = urlparse(port) + if parsed.scheme: + raise ValueError( + f"VLLM_PORT '{port}' appears to be a URI. " + "This may be caused by a Kubernetes service discovery issue" + "check the warning in: https://docs.vllm.ai/en/stable/usage/env_vars.html" + ) + except Exception: + pass + + raise ValueError( + f"VLLM_PORT '{port}' must be a valid integer") from err + + # The begin-* and end* here are used by the documentation generator # to extract the used env vars. -# begin-env-vars-definition +# --8<-- [start:env-vars-definition] environment_variables: dict[str, Callable[[], Any]] = { @@ -214,10 +253,8 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: # Note: if VLLM_PORT is set, and some code asks for multiple ports, the # VLLM_PORT will be used as the first port, and the rest will be generated # by incrementing the VLLM_PORT value. 
- # '0' is used to make mypy happy 'VLLM_PORT': - lambda: int(os.getenv('VLLM_PORT', '0')) - if 'VLLM_PORT' in os.environ else None, + get_vllm_port, # path used for ipc when the frontend api server is running in # multi-processing mode to communicate with the backend engine process. @@ -263,6 +300,10 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: lambda: bool( os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"), + # Internal flag to enable/disable Inductor standalone compile + "VLLM_TEST_STANDALONE_COMPILE": + lambda: os.environ.get("VLLM_TEST_STANDALONE_COMPILE", "0") != "0", + # local rank of the process in the distributed setting, used to determine # the GPU device id "LOCAL_RANK": @@ -438,6 +479,16 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: "VLLM_AUDIO_FETCH_TIMEOUT": lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")), + # Backend for Video IO + # - "opencv": Default backend that uses OpenCV stream buffered backend. + # + # Custom backend implementations can be registered + # via `@VIDEO_LOADER_REGISTRY.register("my_custom_video_loader")` and + # imported at runtime. + # If a non-existing backend is used, an AssertionError will be thrown. + "VLLM_VIDEO_LOADER_BACKEND": + lambda: os.getenv("VLLM_VIDEO_LOADER_BACKEND", "opencv"), + # Cache size (in GiB) for multimodal input cache # Default is 4 GiB "VLLM_MM_INPUT_CACHE_GIB": @@ -497,6 +548,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: lambda: None if "VLLM_PLUGINS" not in os.environ else os.environ[ "VLLM_PLUGINS"].split(","), + # a local directory to look in for unrecognized LoRA adapters. + # only works if plugins are enabled and + # VLLM_ALLOW_RUNTIME_LORA_UPDATING is enabled. + "VLLM_LORA_RESOLVER_CACHE_DIR": + lambda: os.getenv("VLLM_LORA_RESOLVER_CACHE_DIR", None), + # Enables torch profiler if set. Path to the directory where torch profiler # traces are saved. Note that it must be an absolute path. 
"VLLM_TORCH_PROFILER_DIR": @@ -743,9 +800,31 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: # insecure method and it is needed for some reason. "VLLM_ALLOW_INSECURE_SERIALIZATION": lambda: bool(int(os.getenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0"))), + + # IP address used for NIXL handshake between remote agents. + "VLLM_NIXL_SIDE_CHANNEL_HOST": + lambda: os.getenv("VLLM_NIXL_SIDE_CHANNEL_HOST", "localhost"), + + # Port used for NIXL handshake between remote agents. + "VLLM_NIXL_SIDE_CHANNEL_PORT": + lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")), + + # all2all backend for vllm's expert parallel communication + # Available options: + # - "naive": naive all2all implementation using all-reduce + # - "pplx": use pplx kernels + "VLLM_ALL2ALL_BACKEND": + lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"), + + # Control the maximum number of tokens per expert supported by the + # NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for + # the blockscale tensor of activations NVFP4 Quantization. + # This is used to prevent the kernel from running out of memory. + "VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE": + lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")), } -# end-env-vars-definition +# --8<-- [end:env-vars-definition] def __getattr__(name: str): @@ -805,6 +884,7 @@ def factorize(name: str): "VLLM_USE_TRITON_AWQ", "VLLM_DP_RANK", "VLLM_DP_SIZE", + "VLLM_TEST_STANDALONE_COMPILE", ] for key in environment_variables_to_hash: if key in environment_variables: diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 522bd940211f..40ca1d29939a 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -74,7 +74,7 @@ def collective_rpc(self, `self` argument, in addition to the arguments passed in `args` and `kwargs`. The `self` argument will be the worker object. timeout: Maximum time in seconds to wait for execution. Raises a - {exc}`TimeoutError` on timeout. 
`None` means wait indefinitely. + [`TimeoutError`][] on timeout. `None` means wait indefinitely. args: Positional arguments to pass to the worker method. kwargs: Keyword arguments to pass to the worker method. diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index 9b0b98731e03..8e67c7a41bb1 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -528,12 +528,12 @@ def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: ray.get(parallel_worker_tasks) def _check_ray_cgraph_installation(self): - import pkg_resources + import importlib.metadata + from packaging import version required_version = version.parse("2.43.0") - current_version = version.parse( - pkg_resources.get_distribution("ray").version) + current_version = version.parse(importlib.metadata.version("ray")) if current_version < required_version: raise ValueError(f"Ray version {required_version} is " f"required, but found {current_version}") diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 37cc07bfbb36..7bc98a16f041 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -87,9 +87,8 @@ def execute_model_spmd( # TODO(swang): This is needed right now because Ray Compiled Graph # executes on a background thread, so we need to reset torch's # current device. 
- import torch if not self.compiled_dag_cuda_device_set: - torch.cuda.set_device(self.worker.device) + current_platform.set_device(self.worker.device) self.compiled_dag_cuda_device_set = True output = self.worker._execute_model_spmd(execute_model_req, @@ -113,8 +112,7 @@ def setup_device_if_necessary(self): # Not needed pass else: - import torch - torch.cuda.set_device(self.worker.device) + current_platform.set_device(self.worker.device) self.compiled_dag_cuda_device_set = True diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index 2e4b47c1e24a..1d3a6e443a80 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -86,9 +86,6 @@ class ExecutorWithExternalLauncher(UniProcExecutor): def _init_executor(self) -> None: """Initialize the worker and load the model. """ - assert self.vllm_config.parallel_config.pipeline_parallel_size == 1, \ - ("ExecutorWithExternalLauncher does not " - "support pipeline parallelism.") assert self.vllm_config.scheduler_config.delay_factor == 0.0, \ ("ExecutorWithExternalLauncher needs deterministic " "execution, so it" diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 9ddc3d1f2c51..3c8083e3dd0d 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -11,10 +11,6 @@ import vllm.envs as envs from vllm.config import VllmConfig -from vllm.distributed.kv_transfer import (get_kv_transfer_group, - has_kv_transfer_group, - is_v1_kv_transfer_group) -from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.logger import init_logger if TYPE_CHECKING: @@ -31,6 +27,7 @@ @dataclass class DPMetadata: + max_tokens_across_dp_cpu: torch.Tensor cu_tokens_across_dp_cpu: torch.Tensor @@ -94,8 +91,10 @@ def set_forward_context(attn_metadata: Any, dtype=torch.int32) from vllm.distributed.parallel_state import get_dp_group dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group) + max_tokens_across_dp_cpu = 
torch.max(num_tokens_tensor) cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0) - dp_metadata = DPMetadata(cu_tokens_across_dp_cpu) + dp_metadata = DPMetadata(max_tokens_across_dp_cpu, + cu_tokens_across_dp_cpu) global _forward_context prev_context = _forward_context @@ -106,16 +105,6 @@ def set_forward_context(attn_metadata: Any, attn_metadata=attn_metadata, dp_metadata=dp_metadata) - # KVConnector: trigger (possibly async) load before forward. - # Each attn layer will block until the reading is complete. - trigger_kv_transfer = (attn_metadata is not None - and has_kv_transfer_group() - and is_v1_kv_transfer_group()) - if trigger_kv_transfer: - kv_connector = get_kv_transfer_group() - assert isinstance(kv_connector, KVConnectorBase_V1) - kv_connector.start_load_kv(_forward_context) - try: yield finally: @@ -131,7 +120,10 @@ def set_forward_context(attn_metadata: Any, # we use synchronous scheduling right now, # adding a sync point here should not affect # scheduling of the next batch - torch.cuda.synchronize() + from vllm.platforms import current_platform + synchronize = current_platform.synchronize + if synchronize is not None: + synchronize() now = time.perf_counter() # time measurement is in milliseconds batchsize_forward_time[batchsize].append( @@ -152,11 +144,4 @@ def set_forward_context(attn_metadata: Any, "(batchsize, count, median_time(ms)): %s"), forward_stats) - # KVConnector: each attn layer triggers (possibly async) save. - # Ensure all those operations complete before forward() is done. 
- if trigger_kv_transfer: - kv_connector = get_kv_transfer_group() - assert isinstance(kv_connector, KVConnectorBase_V1) - kv_connector.wait_for_save() - _forward_context = prev_context diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 0673aece9108..df4f844cd815 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -10,8 +10,9 @@ INPUT_REGISTRY = InputRegistry() """ -The global {class}`~InputRegistry` which is used by {class}`~vllm.LLMEngine` -to dispatch data processing according to the target model. +The global [`InputRegistry`][vllm.inputs.registry.InputRegistry] which is used +by [`LLMEngine`][vllm.LLMEngine] to dispatch data processing according to the +target model. """ __all__ = [ diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index c83ab73b614a..843c45bd6163 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast import torch -from typing_extensions import NotRequired, TypedDict, TypeVar +from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar if TYPE_CHECKING: from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs @@ -80,24 +80,37 @@ class EmbedsPrompt(TypedDict): """ Set of possible schemas for a single prompt: -- A text prompt ({class}`str` or {class}`TextPrompt`) -- A tokenized prompt ({class}`TokensPrompt`) -- An embeddings prompt ({class}`EmbedsPrompt`) +- A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt]) +- A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt]) +- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt]) Note that "singleton" is as opposed to a data structure which encapsulates multiple prompts, i.e. of the sort which may be utilized for encoder/decoder models when the user desires to express both the encoder & decoder -prompts explicitly, i.e. {class}`ExplicitEncoderDecoderPrompt` +prompts explicitly, i.e. 
+[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt] -A prompt of type {class}`SingletonPrompt` may be employed -as (1) input to a decoder-only model, (2) input to +A prompt of type [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] may be +employed as (1) input to a decoder-only model, (2) input to the encoder of an encoder/decoder model, in the scenario where the decoder-prompt is not specified explicitly, or (3) as a member of a larger data structure encapsulating -more than one prompt, i.e. {class}`ExplicitEncoderDecoderPrompt` +more than one prompt, i.e. +[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt] """ + +def is_tokens_prompt(prompt: SingletonPrompt) -> TypeIs[TokensPrompt]: + return (isinstance(prompt, dict) and "prompt_token_ids" in prompt + and "prompt_embeds" not in prompt) + + +def is_embeds_prompt(prompt: SingletonPrompt) -> TypeIs[EmbedsPrompt]: + return (isinstance(prompt, dict) and "prompt_token_ids" not in prompt + and "prompt_embeds" in prompt) + + _T1_co = TypeVar("_T1_co", bound=SingletonPrompt, default=SingletonPrompt, @@ -115,18 +128,20 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): comprising an explicit encoder prompt and a decoder prompt. The encoder and decoder prompts, respectively, may be formatted - according to any of the {class}`SingletonPrompt` schemas, + according to any of the + [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] schemas, and are not required to have the same schema. Only the encoder prompt may have multi-modal data. mm_processor_kwargs should be at the top-level, and should not be set in the encoder/decoder prompts, since they are agnostic to the encoder/decoder. 
- Note that an {class}`ExplicitEncoderDecoderPrompt` may not - be used as an input to a decoder-only model, + Note that an + [`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt] + may not be used as an input to a decoder-only model, and that the `encoder_prompt` and `decoder_prompt` fields of this data structure themselves must be - {class}`SingletonPrompt` instances. + [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] instances. """ encoder_prompt: _T1_co @@ -141,11 +156,11 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): Set of possible schemas for an LLM input, including both decoder-only and encoder/decoder input types: -- A text prompt ({class}`str` or {class}`TextPrompt`) -- A tokenized prompt ({class}`TokensPrompt`) -- An embeddings prompt ({class}`EmbedsPrompt`) +- A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt]) +- A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt]) +- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt]) - A single data structure containing both an encoder and a decoder prompt - ({class}`ExplicitEncoderDecoderPrompt`) + ([`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]) """ @@ -178,7 +193,8 @@ def token_inputs( prompt: Optional[str] = None, cache_salt: Optional[str] = None, ) -> TokenInputs: - """Construct {class}`TokenInputs` from optional values.""" + """Construct [`TokenInputs`][vllm.inputs.data.TokenInputs] from optional + values.""" inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids) if prompt is not None: @@ -210,7 +226,8 @@ def embeds_inputs( prompt_embeds: torch.Tensor, cache_salt: Optional[str] = None, ) -> EmbedsInputs: - """Construct :class:`EmbedsInputs` from optional values.""" + """Construct [`EmbedsInputs`][vllm.inputs.data.EmbedsInputs] from optional + values.""" inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds) if cache_salt is not None: @@ 
-221,7 +238,7 @@ def embeds_inputs( DecoderOnlyInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"] """ -The inputs in {class}`~vllm.LLMEngine` before they are +The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they are passed to the model executor. This specifies the data required for decoder-only models. """ @@ -229,11 +246,12 @@ def embeds_inputs( class EncoderDecoderInputs(TypedDict): """ - The inputs in {class}`~vllm.LLMEngine` before they are - passed to the model executor. + The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they + are passed to the model executor. This specifies the required data for encoder-decoder models. """ + encoder: Union[TokenInputs, "MultiModalInputs"] """The inputs for the encoder portion.""" @@ -243,13 +261,13 @@ class EncoderDecoderInputs(TypedDict): SingletonInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"] """ -A processed {class}`SingletonPrompt` which can be passed to -{class}`vllm.sequence.Sequence`. +A processed [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] which can be +passed to [`vllm.sequence.Sequence`][]. """ ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs] """ -The inputs to {data}`vllm.inputs.InputProcessor`. +The outputs from [`vllm.inputs.preprocess.InputPreprocessor`][]. """ _T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt) @@ -266,7 +284,8 @@ def build_explicit_enc_dec_prompt( return ExplicitEncoderDecoderPrompt( encoder_prompt=encoder_prompt, decoder_prompt=decoder_prompt, - mm_processor_kwargs=mm_processor_kwargs) + mm_processor_kwargs=mm_processor_kwargs, + ) def zip_enc_dec_prompts( @@ -277,7 +296,8 @@ def zip_enc_dec_prompts( ) -> list[ExplicitEncoderDecoderPrompt[_T1, _T2]]: """ Zip encoder and decoder prompts together into a list of - {class}`ExplicitEncoderDecoderPrompt` instances. + [`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt] + instances. 
``mm_processor_kwargs`` may also be provided; if a dict is passed, the same dictionary will be used for every encoder/decoder prompt. If an iterable is @@ -288,10 +308,11 @@ def zip_enc_dec_prompts( if isinstance(mm_processor_kwargs, dict): return [ build_explicit_enc_dec_prompt( - encoder_prompt, decoder_prompt, - cast(dict[str, Any], mm_processor_kwargs)) - for (encoder_prompt, - decoder_prompt) in zip(enc_prompts, dec_prompts) + encoder_prompt, + decoder_prompt, + cast(dict[str, Any], mm_processor_kwargs), + ) for (encoder_prompt, + decoder_prompt) in zip(enc_prompts, dec_prompts) ] return [ build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt, diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index d17122b48344..4c64a41ace31 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -23,13 +23,13 @@ class ParsedTokens(TypedDict): @overload def parse_and_batch_prompt( - prompt: Union[str, list[str]]) -> Sequence[ParsedText]: + prompt: Union[str, list[str]], ) -> Sequence[ParsedText]: ... @overload def parse_and_batch_prompt( - prompt: Union[list[int], list[list[int]]]) -> Sequence[ParsedTokens]: + prompt: Union[list[int], list[list[int]]], ) -> Sequence[ParsedTokens]: ... 
@@ -86,7 +86,7 @@ class ParsedTokensPrompt(TypedDict): class ParsedEmbedsPrompt(TypedDict): - type: Literal['embeds'] + type: Literal["embeds"] content: EmbedsPrompt @@ -133,7 +133,7 @@ def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt: def is_explicit_encoder_decoder_prompt( - prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]: + prompt: PromptType, ) -> TypeIs[ExplicitEncoderDecoderPrompt]: return isinstance(prompt, dict) and "encoder_prompt" in prompt diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 6e8effd60274..b9acabeabd8d 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -67,11 +67,11 @@ def get_eos_token_id(self, return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id def get_decoder_start_token_id(self) -> Optional[int]: - ''' + """ Obtain the decoder start token id employed by an encoder/decoder model. Returns None for non-encoder/decoder models or if the model config is unavailable. - ''' + """ if not self.model_config.is_encoder_decoder: logger.warning_once( @@ -79,14 +79,14 @@ def get_decoder_start_token_id(self) -> Optional[int]: "this is not an encoder/decoder model.") return None - if (self.model_config is None or self.model_config.hf_config is None): + if self.model_config is None or self.model_config.hf_config is None: logger.warning_once( "Using None for decoder start token id because " "model config is not available.") return None dec_start_token_id = getattr(self.model_config.hf_config, - 'decoder_start_token_id', None) + "decoder_start_token_id", None) if dec_start_token_id is None: logger.warning_once( "Falling back on <BOS> for decoder start token " @@ -97,7 +97,7 @@ def get_decoder_start_token_id(self) -> Optional[int]: return dec_start_token_id def _get_default_enc_dec_decoder_prompt(self) -> list[int]: - ''' + """ Specifically for encoder/decoder models: generate a default decoder prompt for when the user specifies only the encoder prompt. 
@@ -126,7 +126,7 @@ def _get_default_enc_dec_decoder_prompt(self) -> list[int]: Returns: * prompt_token_ids - ''' + """ bos_token_id = self.get_bos_token_id() assert bos_token_id is not None @@ -224,7 +224,10 @@ async def _tokenize_prompt_async( lora_request: Optional[LoRARequest], tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> list[int]: - """Async version of {meth}`_tokenize_prompt`.""" + """ + Async version of + [`_tokenize_prompt`][vllm.inputs.preprocess.InputPreprocessor._tokenize_prompt]. + """ tokenizer = self.get_tokenizer_group() tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs) @@ -287,7 +290,10 @@ async def _process_multimodal_async( lora_request: Optional[LoRARequest], return_mm_hashes: bool = False, ) -> MultiModalInputs: - """Async version of {meth}`_process_multimodal`.""" + """ + Async version of + [`_process_multimodal`][vllm.inputs.preprocess.InputPreprocessor._process_multimodal]. + """ tokenizer = await self._get_mm_tokenizer_async(lora_request) mm_processor = self.mm_registry.create_processor(self.model_config, @@ -472,7 +478,7 @@ def _prompt_to_llm_inputs( Returns: - * {class}`SingletonInputs` instance + * [`SingletonInputs`][vllm.inputs.data.SingletonInputs] instance """ parsed = parse_singleton_prompt(prompt) @@ -508,7 +514,10 @@ async def _prompt_to_llm_inputs_async( lora_request: Optional[LoRARequest] = None, return_mm_hashes: bool = False, ) -> SingletonInputs: - """Async version of {meth}`_prompt_to_llm_inputs`.""" + """ + Async version of + [`_prompt_to_llm_inputs`][vllm.inputs.preprocess.InputPreprocessor._prompt_to_llm_inputs]. + """ parsed = parse_singleton_prompt(prompt) if parsed["type"] == "embeds": @@ -644,7 +653,9 @@ def _process_encoder_decoder_prompt( ) -> EncoderDecoderInputs: """ For encoder/decoder models only: - Process an input prompt into an {class}`EncoderDecoderInputs` instance. + Process an input prompt into an + [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs] + instance. 
There are two types of input prompts: singleton prompts which carry only the @@ -670,7 +681,8 @@ def _process_encoder_decoder_prompt( Returns: - * {class}`EncoderDecoderInputs` instance + * [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs] + instance """ encoder_inputs: SingletonInputs decoder_inputs: Optional[SingletonInputs] @@ -710,7 +722,10 @@ async def _process_encoder_decoder_prompt_async( prompt: PromptType, tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> EncoderDecoderInputs: - """Async version of {meth}`_process_encoder_decoder_prompt`.""" + """ + Async version of + [`_process_encoder_decoder_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_encoder_decoder_prompt]. + """ encoder_inputs: SingletonInputs decoder_inputs: Optional[SingletonInputs] @@ -778,7 +793,8 @@ def _process_decoder_only_prompt( ) -> DecoderOnlyInputs: """ For decoder-only models: - Process an input prompt into an {class}`DecoderOnlyInputs` instance. + Process an input prompt into a + [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance. Arguments: @@ -789,7 +805,7 @@ def _process_decoder_only_prompt( Returns: - * {class}`DecoderOnlyInputs` instance + * [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance """ prompt_comps = self._prompt_to_llm_inputs( @@ -812,7 +828,10 @@ async def _process_decoder_only_prompt_async( prompt_adapter_request: Optional[PromptAdapterRequest] = None, return_mm_hashes: bool = False, ) -> DecoderOnlyInputs: - """Async version of {meth}`_process_decoder_only_prompt`.""" + """ + Async version of + [`_process_decoder_only_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_decoder_only_prompt]. 
+ """ prompt_comps = await self._prompt_to_llm_inputs_async( prompt, tokenization_kwargs=tokenization_kwargs, @@ -863,7 +882,10 @@ async def preprocess_async( prompt_adapter_request: Optional[PromptAdapterRequest] = None, return_mm_hashes: bool = False, ) -> ProcessorInputs: - """Async version of {meth}`preprocess`.""" + """ + Async version of + [`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess]. + """ if self.model_config.is_encoder_decoder: assert not return_mm_hashes, ( "Multimodal hashes for encoder-decoder models should not be ", diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index aecddbcd7515..f424a8f613ab 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -38,7 +38,7 @@ def get_hf_config( ) -> _C: """ Get the HuggingFace configuration - ({class}`transformers.PretrainedConfig`) of the model, + (`transformers.PretrainedConfig`) of the model, additionally checking its type. Raises: @@ -79,7 +79,7 @@ def get_hf_processor( ) -> _P: """ Get the HuggingFace processor - ({class}`transformers.ProcessorMixin`) of the model, + (`transformers.ProcessorMixin`) of the model, additionally checking its type. Raises: @@ -159,7 +159,7 @@ def call_hf_processor( msg = (f"Failed to apply {type(hf_processor).__name__} " f"on data={data} with kwargs={merged_kwargs}") - raise RuntimeError(msg) from exc + raise ValueError(msg) from exc class DummyData(NamedTuple): diff --git a/vllm/logger.py b/vllm/logger.py index cf32041c5b70..fd16dd95bb1b 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -68,22 +68,22 @@ class _VllmLogger(Logger): """ Note: This class is just to provide type information. - We actually patch the methods directly on the {class}`logging.Logger` + We actually patch the methods directly on the [`logging.Logger`][] instance to avoid conflicting with other libraries such as `intel_extension_for_pytorch.utils._logger`. 
""" def info_once(self, msg: str, *args: Hashable) -> None: """ - As {meth}`info`, but subsequent calls with the same message - are silently dropped. + As [`info`][logging.Logger.info], but subsequent calls with + the same message are silently dropped. """ _print_info_once(self, msg, *args) def warning_once(self, msg: str, *args: Hashable) -> None: """ - As {meth}`warning`, but subsequent calls with the same message - are silently dropped. + As [`warning`][logging.Logger.warning], but subsequent calls with + the same message are silently dropped. """ _print_warning_once(self, msg, *args) diff --git a/vllm/logging_utils/dump_input.py b/vllm/logging_utils/dump_input.py index 169e24794095..47ce0ab188bd 100644 --- a/vllm/logging_utils/dump_input.py +++ b/vllm/logging_utils/dump_input.py @@ -18,7 +18,7 @@ def prepare_object_to_dump(obj) -> str: if isinstance(obj, str): - return "'{obj}'" # Double quotes + return f"'{obj}'" # Double quotes elif isinstance(obj, dict): dict_str = ', '.join({f'{str(k)}: {prepare_object_to_dump(v)}' \ for k, v in obj.items()}) @@ -42,9 +42,9 @@ def prepare_object_to_dump(obj) -> str: return obj.anon_repr() elif hasattr(obj, '__dict__'): items = obj.__dict__.items() - dict_str = ','.join([f'{str(k)}={prepare_object_to_dump(v)}' \ + dict_str = ', '.join([f'{str(k)}={prepare_object_to_dump(v)}' \ for k, v in items]) - return (f"{type(obj).__name__}({dict_str})") + return f"{type(obj).__name__}({dict_str})" else: # Hacky way to make sure we can serialize the object in JSON format try: diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py index e195f8cf5e8e..b6b138a44051 100644 --- a/vllm/lora/fully_sharded_layers.py +++ b/vllm/lora/fully_sharded_layers.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # pylint: disable=unused-argument -from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Optional, Union, cast import torch import torch.nn as nn @@ -118,7 +118,7 
@@ def can_replace_layer( cls, source_layer: nn.Module, lora_config: LoRAConfig, - packed_modules_list: List, + packed_modules_list: list, model_config: Optional[PretrainedConfig], ) -> bool: # specifying kwargs so they can be easily accessed in decorator @@ -141,8 +141,8 @@ class MergedColumnParallelLinearWithShardedLoRA( """ def slice_lora_a( - self, lora_a: List[Union[torch.Tensor, None]] - ) -> List[Union[torch.Tensor, None]]: + self, lora_a: list[Union[torch.Tensor, None]] + ) -> list[Union[torch.Tensor, None]]: #NOTE: lora_a contains 2 subloras, and each sublora could be None. output_shard_size = self.lora_a_stacked[0].shape[2] output_start_idx = self.tp_rank * output_shard_size @@ -165,7 +165,7 @@ def can_replace_layer( cls, source_layer: nn.Module, lora_config: LoRAConfig, - packed_modules_list: List, + packed_modules_list: list, model_config: Optional[PretrainedConfig], ) -> bool: # specifying kwargs so they can be easily accessed in decorator @@ -201,7 +201,7 @@ def apply(self, @classmethod @_fully_sharded_can_replace def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: List, + lora_config: LoRAConfig, packed_modules_list: list, model_config: Optional[PretrainedConfig]) -> bool: # specifying kwargs so they can be easily accessed in decorator return super().can_replace_layer( @@ -222,8 +222,8 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA): """ def slice_lora_a( - self, lora_a: List[Union[torch.Tensor, None]] - ) -> List[Union[torch.Tensor, None]]: + self, lora_a: list[Union[torch.Tensor, None]] + ) -> list[Union[torch.Tensor, None]]: # NOTE: lora_a contains 3 subloras, and each sublora could be None. 
shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)] start_idx = [self.tp_rank * shard_size[i] for i in range(3)] @@ -248,7 +248,7 @@ def can_replace_layer( cls, source_layer: nn.Module, lora_config: LoRAConfig, - packed_modules_list: List, + packed_modules_list: list, model_config: Optional[PretrainedConfig], ) -> bool: # specifying kwargs so they can be easily accessed in decorator @@ -281,7 +281,7 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: if bias is None: return bias - self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], self.lora_bias_stacked) shard_size = self.lora_bias_stacked[0].shape[2] start_idx = self.tp_rank * shard_size @@ -341,7 +341,7 @@ def can_replace_layer( cls, source_layer: nn.Module, lora_config: LoRAConfig, - packed_modules_list: List, + packed_modules_list: list, model_config: Optional[PretrainedConfig], ) -> bool: # specifying kwargs so they can be easily accessed in decorator diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 6749ec16a097..023c8e9c9a86 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -3,7 +3,7 @@ # pylint: disable=unused-argument import math from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Optional, Union, cast import torch import torch.nn as nn @@ -82,14 +82,14 @@ class LoRAMapping(AdapterMapping): class BaseLayerWithLoRA(nn.Module): def slice_lora_a( - self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]] - ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]: + self, lora_a: Union[torch.Tensor, list[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]: """Slice lora a if splitting for tensor parallelism.""" ... 
def slice_lora_b( - self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]] - ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]: + self, lora_b: Union[torch.Tensor, list[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]: """Slice lora b if splitting with tensor parallelism.""" ... @@ -128,7 +128,7 @@ def can_replace_layer( cls, source_layer: nn.Module, lora_config: LoRAConfig, - packed_modules_list: List, + packed_modules_list: list, model_config: Optional[PretrainedConfig], ) -> bool: """Returns True if the layer can be replaced by this LoRA layer.""" @@ -140,7 +140,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): def __init__(self, base_layer: VocabParallelEmbedding) -> None: super().__init__() self.base_layer = base_layer - self.embeddings_slice: Optional[Tuple[int, int]] + self.embeddings_slice: Optional[tuple[int, int]] self.embeddings_weights: Optional[torch.Tensor] def create_lora_weights( @@ -279,7 +279,7 @@ def can_replace_layer( cls, source_layer: nn.Module, lora_config: LoRAConfig, - packed_modules_list: List, + packed_modules_list: list, model_config: Optional[PretrainedConfig], ) -> bool: return type(source_layer) is VocabParallelEmbedding @@ -296,9 +296,9 @@ def __init__(self, base_layer: LinearBase): self.base_layer = base_layer self.input_size = self.base_layer.input_size self.device = _get_lora_device(self.base_layer) - self.lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]] = None + self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None - self.output_slices: Tuple[int, ...] + self.output_slices: tuple[int, ...] 
self.tp_size: int self.output_size: int self.n_slices: int @@ -365,7 +365,7 @@ def reset_lora(self, index: int): self.lora_b_stacked[s_index][index] = 0 if self.lora_config.bias_enabled: # Make mypy happy - self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], self.lora_bias_stacked) self.lora_bias_stacked[s_index][index] = 0 @@ -399,7 +399,7 @@ def set_lora( lora_b.T, non_blocking=True) if lora_bias is not None: - self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], self.lora_bias_stacked) assert len(self.lora_bias_stacked) self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_( @@ -497,7 +497,7 @@ def can_replace_layer( cls, source_layer: nn.Module, lora_config: LoRAConfig, - packed_modules_list: List, + packed_modules_list: list, model_config: Optional[PretrainedConfig], ) -> bool: return type(source_layer) is ReplicatedLinear @@ -597,7 +597,7 @@ def can_replace_layer( cls, source_layer: nn.Module, lora_config: LoRAConfig, - packed_modules_list: List, + packed_modules_list: list, model_config: Optional[PretrainedConfig], ) -> bool: return type(source_layer) is ColumnParallelLinear or ( @@ -674,13 +674,13 @@ def create_lora_weights( ) for output_size in self.output_slices) def slice_lora_a( - self, lora_a: List[Union[torch.Tensor, None]] - ) -> List[Union[torch.Tensor, None]]: + self, lora_a: list[Union[torch.Tensor, None]] + ) -> list[Union[torch.Tensor, None]]: return lora_a def slice_lora_b( - self, lora_b: List[Union[torch.Tensor, None]] - ) -> List[Union[torch.Tensor, None]]: + self, lora_b: list[Union[torch.Tensor, None]] + ) -> list[Union[torch.Tensor, None]]: for i, (shard_id, shard_size) in enumerate( zip(self.output_ids, self.output_slices)): if (lora_b_i := lora_b[i]) is not None: @@ -689,8 +689,8 @@ def slice_lora_b( return lora_b def slice_bias( - self, bias: List[Union[torch.Tensor, - None]]) -> List[Union[torch.Tensor, 
None]]: + self, bias: list[Union[torch.Tensor, + None]]) -> list[Union[torch.Tensor, None]]: for i, (shard_id, shard_size) in enumerate( zip(self.output_ids, self.output_slices)): if (bias_i := bias[i]) is not None: @@ -725,7 +725,7 @@ def set_lora( lora_b_i.T, non_blocking=True) if lora_bias is not None: - self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], self.lora_bias_stacked) for i in range(self.n_slices): if (lora_bias_i := lora_bias[i]) is not None: @@ -740,7 +740,7 @@ def can_replace_layer( cls, source_layer: nn.Module, lora_config: LoRAConfig, - packed_modules_list: List, + packed_modules_list: list, model_config: Optional[PretrainedConfig], ) -> bool: return (type(source_layer) is MergedColumnParallelLinear @@ -809,7 +809,7 @@ def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: @classmethod @_not_fully_sharded_can_replace def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: List, + lora_config: LoRAConfig, packed_modules_list: list, model_config: Optional[PretrainedConfig]) -> bool: return type(source_layer) is QKVParallelLinear and len( packed_modules_list) == 1 @@ -869,7 +869,7 @@ def can_replace_layer( cls, source_layer: nn.Module, lora_config: LoRAConfig, - packed_modules_list: List, + packed_modules_list: list, model_config: Optional[PretrainedConfig], ) -> bool: return (type(source_layer) is QKVParallelLinear @@ -923,7 +923,7 @@ def forward( - output - bias """ - # Set up backprop all-reduce. + # set up backprop all-reduce. 
if self.base_layer.input_is_parallel: input_parallel = input_ else: @@ -958,7 +958,7 @@ def can_replace_layer( cls, source_layer: nn.Module, lora_config: LoRAConfig, - packed_modules_list: List, + packed_modules_list: list, model_config: Optional[PretrainedConfig], ) -> bool: return type(source_layer) is RowParallelLinear @@ -981,7 +981,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): def __init__(self, base_layer: LogitsProcessor, hidden_size: int, dtype: torch.dtype, device: torch.device, - sharded_to_full_mapping: Optional[List[int]]) -> None: + sharded_to_full_mapping: Optional[list[int]]) -> None: super().__init__() self.base_layer = base_layer self.hidden_size = hidden_size @@ -1189,7 +1189,7 @@ def can_replace_layer( cls, source_layer: nn.Module, lora_config: LoRAConfig, - packed_modules_list: List, + packed_modules_list: list, model_config: Optional[PretrainedConfig], ) -> bool: # Special handling for the LogitsProcessor. @@ -1256,7 +1256,7 @@ def forward( positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: return self.base_layer( positions, query, @@ -1265,7 +1265,7 @@ def forward( ) @property - def scaling_factor_to_offset(self) -> Dict[float, int]: + def scaling_factor_to_offset(self) -> dict[float, int]: return self.base_layer.scaling_factor_to_offset @classmethod @@ -1273,7 +1273,7 @@ def can_replace_layer( cls, source_layer: nn.Module, lora_config: LoRAConfig, - packed_modules_list: List, + packed_modules_list: list, model_config: Optional[PretrainedConfig], ) -> bool: """Returns True if the layer can be replaced by this LoRA layer.""" diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py index 00299bf6c2a8..294b49e0a899 100644 --- a/vllm/lora/lora.py +++ b/vllm/lora/lora.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional -from typing import Sequence as GenericSequence +from collections.abc import Sequence as 
GenericSequence +from typing import Optional import torch import torch.types @@ -125,11 +125,11 @@ def __init__( self, module_name: str, rank: int, - lora_alphas: List[Optional[int]], - lora_a: List[Optional[torch.Tensor]], - lora_b: List[Optional[torch.Tensor]], - bias: Optional[List[Optional[torch.Tensor]]] = None, - scaling: Optional[List[float]] = None, + lora_alphas: list[Optional[int]], + lora_a: list[Optional[torch.Tensor]], + lora_b: list[Optional[torch.Tensor]], + bias: Optional[list[Optional[torch.Tensor]]] = None, + scaling: Optional[list[float]] = None, ) -> None: super().__init__( module_name=module_name, diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 9f9d808679d7..af5cebdf2a8b 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -3,11 +3,11 @@ import copy import math import os -import re +from collections.abc import Sequence from dataclasses import dataclass, field -from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Type, - Union) +from typing import Any, Callable, Optional, Union +import regex as re import safetensors.torch import torch from torch import nn @@ -29,6 +29,7 @@ get_supported_lora_modules, is_regex_target_modules, parse_fine_tuned_lora_name, replace_submodule) +from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.models import SupportsLoRA, supports_multimodal from vllm.model_executor.models.interfaces import is_pooling_model from vllm.model_executor.models.module_mapping import MultiModelKeys @@ -44,12 +45,12 @@ class LongContextLoRAContext: """Context for lora adapters that support long context.""" # The scaling factors to support long context lora fine tuned models. - scaling_factors: List[float] + scaling_factors: list[float] # dimension to apply rotary embedding. rot_dim: int # offsets to the sin_cos_cache for each lora_id loaded. # This value is dynamically modified. 
- offsets_by_lora_id: Dict[int, int] = field(default_factory=dict) + offsets_by_lora_id: dict[int, int] = field(default_factory=dict) def get_lora_id(): @@ -65,7 +66,7 @@ def __init__( self, lora_model_id: int, rank: int, - loras: Dict[str, LoRALayerWeights], + loras: dict[str, LoRALayerWeights], scaling_factor: Optional[float] = None, ) -> None: """ @@ -84,7 +85,7 @@ def __init__( lora_model_id > 0), f"a valid lora id should be greater than 0, got {self.id}" self.rank = rank - self.loras: Dict[str, LoRALayerWeights] = loras + self.loras: dict[str, LoRALayerWeights] = loras def clone(self, lora_model_id: int) -> "LoRAModel": """Return a copy of the object with different ids. @@ -113,19 +114,19 @@ def check_lora_name(self, lora_name: str) -> bool: def from_lora_tensors( cls, lora_model_id: int, - tensors: Dict[str, torch.Tensor], + tensors: dict[str, torch.Tensor], peft_helper: PEFTHelper, device: str = "cuda", dtype: Optional[torch.dtype] = None, - embeddings: Optional[Dict[str, torch.Tensor]] = None, + embeddings: Optional[dict[str, torch.Tensor]] = None, target_embedding_padding: Optional[int] = None, - embedding_modules: Optional[Dict[str, str]] = None, - embedding_padding_modules: Optional[List[str]] = None, + embedding_modules: Optional[dict[str, str]] = None, + embedding_padding_modules: Optional[list[str]] = None, weights_mapper: Optional[WeightsMapper] = None, ) -> "LoRAModel": """Create a LoRAModel from a dictionary of tensors.""" pin_memory = str(device) == "cpu" and is_pin_memory_available() - loras: Dict[str, LoRALayerWeights] = {} + loras: dict[str, LoRALayerWeights] = {} for tensor_name, tensor in tensors.items(): module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name( tensor_name, weights_mapper) @@ -185,19 +186,19 @@ def from_lora_tensors( @classmethod def from_local_checkpoint( - cls, - lora_dir: str, - expected_lora_modules: List[str], - peft_helper: PEFTHelper, - *, - lora_model_id: Optional[int] = None, - device: str = "cuda", - dtype: 
Optional[torch.dtype] = None, - target_embedding_padding: Optional[int] = None, - embedding_modules: Optional[Dict[str, str]] = None, - embedding_padding_modules: Optional[List[str]] = None, - weights_mapper: Optional[WeightsMapper] = None, - ) -> "LoRAModel": + cls, + lora_dir: str, + expected_lora_modules: list[str], + peft_helper: PEFTHelper, + *, + lora_model_id: Optional[int] = None, + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + target_embedding_padding: Optional[int] = None, + embedding_modules: Optional[dict[str, str]] = None, + embedding_padding_modules: Optional[list[str]] = None, + weights_mapper: Optional[WeightsMapper] = None, + tensorizer_config_dict: Optional[dict] = None) -> "LoRAModel": """Create a LoRAModel from a local checkpoint. Args: @@ -219,10 +220,36 @@ def from_local_checkpoint( lora_dir, "new_embeddings.safetensors") new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin") + tensors: dict[str, torch.Tensor] = {} + unexpected_modules: list[Union[list[str], str]] = [] + + def check_unexpected_modules(modules: dict): + for lora_module in modules.keys(): # noqa + module_name, _, _ = parse_fine_tuned_lora_name( + lora_module, weights_mapper) + part_name = module_name.split(".")[-1] + if part_name not in expected_lora_modules: + unexpected_modules.append(module_name) + if unexpected_modules: + raise ValueError( + f"While loading {lora_dir}, expected" + f" target modules in {expected_lora_modules}" + f" but received {unexpected_modules}." 
+ f" Please verify that the loaded LoRA module is correct") - unexpected_modules: List[Union[list[str], str]] - if os.path.isfile(lora_tensor_path): - tensors: Dict[str, torch.Tensor] = {} + if tensorizer_config_dict: + from tensorizer import TensorDeserializer + + tensorizer_config = TensorizerConfig(**tensorizer_config_dict) + lora_tensor_path = os.path.join(tensorizer_config.tensorizer_dir, + "adapter_model.tensors") + tensorizer_args = tensorizer_config._construct_tensorizer_args() + tensors = TensorDeserializer(lora_tensor_path, + dtype=tensorizer_config.dtype, + **tensorizer_args.deserializer_params) + check_unexpected_modules(tensors) + + elif os.path.isfile(lora_tensor_path): # Find unexpected modules. # Use safetensor key as a source of truth to find expected modules. # in peft if you have target_modules A, B, C and C does not exist @@ -232,20 +259,8 @@ def from_local_checkpoint( unexpected_modules = [] with safetensors.safe_open(lora_tensor_path, framework="pt") as f: # type: ignore - for lora_module in f.keys(): # noqa - module_name, _, _ = parse_fine_tuned_lora_name( - lora_module, weights_mapper) - part_name = module_name.split(".")[-1] - if part_name not in expected_lora_modules: - unexpected_modules.append(module_name) - if unexpected_modules: - raise ValueError( - f"While loading {lora_dir}, expected" - f" target modules in {expected_lora_modules}" - f" but received {unexpected_modules}." - f" Please verify that the loaded LoRA module is correct" - ) # Load tensors if there are only expected modules. 
+ check_unexpected_modules(f) for module in f.keys(): # noqa tensors[module] = f.get_tensor(module) elif os.path.isfile(lora_bin_file_path): @@ -329,7 +344,7 @@ def __init__( self.max_num_seqs = max_num_seqs assert self.capacity >= self.lora_slots self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 - self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots + self.lora_index_to_id: list[Optional[int]] = [None] * self.lora_slots self.vocab_size = vocab_size self.long_lora_context: Optional[LongContextLoRAContext] = None self.punica_wrapper = get_punica_wrapper( @@ -339,7 +354,7 @@ def __init__( max_loras=self.lora_config.max_loras) # Scaling factor -> offset to the sin_cos_cache to it. # Used for long context lora. - self.scaling_factor_to_offset: Dict[float, int] = {} + self.scaling_factor_to_offset: dict[float, int] = {} super().__init__(model) self.supported_lora_modules = get_supported_lora_modules(self.model) @@ -358,9 +373,9 @@ def __init__( # text modules (e.g. ChatGLM) and hasattr(self.model, "get_mm_mapping")) self.is_pooling_model = is_pooling_model(self.model) - self.packed_modules: Dict[str, List[str]] = {} - self.modules: Dict[str, BaseLayerWithLoRA] = {} - # Dict instead of a Set for compatibility with LRUCache. + self.packed_modules: dict[str, list[str]] = {} + self.modules: dict[str, BaseLayerWithLoRA] = {} + # Dict instead of a set for compatibility with LRUCache. 
self._last_mapping: Optional[LoRAMapping] = None self._create_lora_modules() self.model.lora_manager = self @@ -530,7 +545,7 @@ def create_dummy_lora( lora_id: int, rank: int, scaling_factor: Optional[float], - embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel: + embedding_modules: Optional[dict[str, str]] = None) -> LoRAModel: """Create zero-initialized LoRAModel for warmup.""" model = LoRAModel(lora_id, rank, {}, scaling_factor) for module_name, module in self.model.named_modules(): @@ -578,7 +593,7 @@ def create_dummy_lora( else: parts = module_name.split(".") replacements = self.packed_modules_mapping[parts[-1]] - subloras: List[Optional[LoRALayerWeights]] = [] + subloras: list[Optional[LoRALayerWeights]] = [] for i, r in enumerate(replacements): lora = LoRALayerWeights.create_dummy_lora_weights( module_name + "." + r, @@ -630,8 +645,8 @@ def _register_packed_modules(self, module_full_name: str) -> None: def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: for module_name, new_module_names in self.packed_modules.items(): - replacement_loras: List[Optional[LoRALayerWeights]] = [] - replaced_module: Set[str] = set() + replacement_loras: list[Optional[LoRALayerWeights]] = [] + replaced_module: set[str] = set() has_replacement = False for r in new_module_names: lora = self._get_lora_layer_weights(lora_model, r) @@ -694,7 +709,7 @@ def remove_adapter(self, adapter_id: int) -> bool: return remove_adapter(adapter_id, self._registered_adapters, self.deactivate_adapter) - def list_adapters(self) -> Dict[int, Any]: + def list_adapters(self) -> dict[int, Any]: return list_adapters(self._registered_adapters) def get_adapter(self, adapter_id: int) -> Optional[Any]: @@ -721,7 +736,7 @@ def __init__(self, model: nn.Module, max_num_seqs: int, self._active_adapters: LoRALRUCache = LoRALRUCache( self.lora_slots, self._deactivate_adapter) - def list_adapters(self) -> Dict[int, LoRAModel]: + def list_adapters(self) -> dict[int, LoRAModel]: """List 
all registered LoRAModels.""" return dict(self._registered_adapters.cache) @@ -786,7 +801,7 @@ def create_lora_manager( vocab_size: int, lora_config: LoRAConfig, device: torch.device, - lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager, + lora_manager_cls: type[LoRAModelManager] = LoRAModelManager, **kwargs) -> LoRAModelManager: """Create a LoRA adapter for a given model.""" if not hasattr(model, "packed_modules_mapping"): diff --git a/vllm/lora/ops/triton_ops/lora_expand_op.py b/vllm/lora/ops/triton_ops/lora_expand_op.py index e41ae1d9594a..9feb9e462459 100644 --- a/vllm/lora/ops/triton_ops/lora_expand_op.py +++ b/vllm/lora/ops/triton_ops/lora_expand_op.py @@ -6,8 +6,6 @@ https://arxiv.org/abs/2310.18547 """ -from typing import List - import torch import triton import triton.language as tl @@ -127,7 +125,7 @@ def _lora_expand_kernel( @torch.inference_mode() def _lora_expand( inputs: torch.Tensor, # shape [num_slices, num_tokens, lora_rank] - lora_b_weights: List[ + lora_b_weights: list[ torch.Tensor], # shape [num_lora, hidden_size, lora_rank] output_tensor: torch. Tensor, # shape [num_tokens, hidden_size * num_slices] @@ -143,7 +141,7 @@ def _lora_expand( """ Args: inputs (torch.Tensor): input tensor - lora_b_weights (List[torch.Tensor]): lora'b weight + lora_b_weights (list[torch.Tensor]): lora'b weight output_tensor (torch.Tensor): output tensor token_lora_mapping (torch.Tensor): A tensor mapping each input token to the lora-id related to that token. A value of -1 indicates that @@ -155,7 +153,7 @@ def _lora_expand( lora_token_start_loc (torch.Tensor): A cumulative sum of num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that lora_token_start_loc[i], along with num_tokens_per_lora[i] - identifies the the region in token_indices_sorted_by_lora_ids that + identifies the region in token_indices_sorted_by_lora_ids that LoRA lora_ids[i] should process. lora_ids (torch.Tensor): LoRA ids to process. 
no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates @@ -264,7 +262,7 @@ def _lora_expand( def _lora_expand_fake( inputs: torch.Tensor, - lora_b_weights: List[torch.Tensor], + lora_b_weights: list[torch.Tensor], output_tensor: torch.Tensor, token_lora_mapping: torch.Tensor, token_indices_sorted_by_lora_ids: torch.Tensor, diff --git a/vllm/lora/ops/triton_ops/lora_kernel_metadata.py b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py index 055e78f406f3..ac459a83220c 100644 --- a/vllm/lora/ops/triton_ops/lora_kernel_metadata.py +++ b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py @@ -4,7 +4,7 @@ """ from dataclasses import dataclass -from typing import Tuple, Union +from typing import Union import torch @@ -125,7 +125,7 @@ def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None: def meta_args( self, token_nums: int - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ This function returns the kernel metadata required for the current diff --git a/vllm/lora/ops/triton_ops/lora_shrink_op.py b/vllm/lora/ops/triton_ops/lora_shrink_op.py index fb0422cf0b0e..c3871bd58ffa 100644 --- a/vllm/lora/ops/triton_ops/lora_shrink_op.py +++ b/vllm/lora/ops/triton_ops/lora_shrink_op.py @@ -6,8 +6,6 @@ https://arxiv.org/abs/2310.18547 """ -from typing import List - import torch import triton import triton.language as tl @@ -98,7 +96,7 @@ def _lora_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K, @torch.inference_mode() def _lora_shrink( inputs: torch.Tensor, # shape [num_tokens, hidden_size] - lora_a_weights: List[ + lora_a_weights: list[ torch.Tensor], # shape [num_loras, lora_rank, hidden_size] output_tensor: torch.Tensor, # shape [num_slices, num_tokens, lora_rank] token_lora_mapping: torch.Tensor, # shape [num_tokens] @@ -112,7 +110,7 @@ def _lora_shrink( """ Args: inputs (torch.Tensor): Input tensor - lora_a_weights 
(List[torch.Tensor]): LoRA weights + lora_a_weights (list[torch.Tensor]): LoRA weights output_tensor (torch.Tensor): output tensor token_lora_mapping (torch.Tensor): A tensor mapping each input token to the lora-id related to that token. A value of -1 indicates that @@ -219,7 +217,7 @@ def _lora_shrink( def _lora_shrink_fake( inputs: torch.Tensor, - lora_a_weights: List[torch.Tensor], + lora_a_weights: list[torch.Tensor], output_tensor: torch.Tensor, token_lora_mapping: torch.Tensor, token_indices_sorted_by_lora_ids: torch.Tensor, diff --git a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py index f779bbccd31a..6225635c2955 100644 --- a/vllm/lora/ops/triton_ops/utils.py +++ b/vllm/lora/ops/triton_ops/utils.py @@ -1,14 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, List, Tuple - import torch -_LORA_A_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {} -_LORA_B_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {} +_LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {} +_LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {} -def _get_lora_a_ptr(lora_a_weights: List[torch.Tensor], device: torch.device): +def _get_lora_a_ptr(lora_a_weights: list[torch.Tensor], device: torch.device): """ `_LORA_A_PTR_DICT` collects the required information during `profile_run`, After this, it remains constant and subsequent usage is through LUT. 
@@ -53,7 +51,7 @@ def _get_lora_a_ptr(lora_a_weights: List[torch.Tensor], device: torch.device): return _LORA_A_PTR_DICT.get(key) -def _get_lora_b_ptr(lora_weights: List[torch.Tensor], offset_start: int, +def _get_lora_b_ptr(lora_weights: list[torch.Tensor], offset_start: int, device: torch.device): """ `_LORA_B_PTR_DICT` collects the required information during `profile_run`, diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py index f6944368b36e..7d335e5f7fab 100644 --- a/vllm/lora/peft_helper.py +++ b/vllm/lora/peft_helper.py @@ -6,10 +6,11 @@ import math import os from dataclasses import MISSING, dataclass, field, fields -from typing import List, Literal, Optional, Union +from typing import Literal, Optional, Union from vllm.config import LoRAConfig from vllm.logger import init_logger +from vllm.model_executor.model_loader.tensorizer import TensorizerConfig logger = init_logger(__name__) @@ -40,7 +41,7 @@ class PEFTHelper: vllm_max_position_embeddings: Optional[int] = field(default=False) vllm_long_context_scaling_factor: Optional[float] = field(default=None) - def _validate_features(self) -> List[str]: + def _validate_features(self) -> list[str]: """ Check if there are any unsupported LoRA features. 
""" @@ -89,12 +90,31 @@ def from_dict(cls, config_dict: dict) -> "PEFTHelper": return cls(**filtered_dict) @classmethod - def from_local_dir(cls, lora_path: str, - max_position_embeddings: Optional[int]) -> "PEFTHelper": + def from_local_dir( + cls, + lora_path: str, + max_position_embeddings: Optional[int], + tensorizer_config_dict: Optional[dict] = None) -> "PEFTHelper": lora_config_path = os.path.join(lora_path, "adapter_config.json") - with open(lora_config_path) as f: - config = json.load(f) + if tensorizer_config_dict: + tensorizer_config = TensorizerConfig(**tensorizer_config_dict) + tensorizer_args = tensorizer_config._construct_tensorizer_args() + from tensorizer.stream_io import open_stream + lora_config_path = os.path.join(tensorizer_config.lora_dir, + "adapter_config.json") + with open_stream(lora_config_path, + mode="rb", + **tensorizer_args.stream_params) as f: + config = json.load(f) + + logger.info("Successfully deserialized LoRA config from %s", + tensorizer_config.lora_dir) + + else: + with open(lora_config_path) as f: + config = json.load(f) + config["vllm_max_position_embeddings"] = max_position_embeddings return cls.from_dict(config) diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index 78866c51895b..e03f7329021b 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -7,7 +7,7 @@ """ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Union import torch @@ -28,7 +28,7 @@ class PunicaWrapperABC(ABC): def update_metadata( self, mapping: "LoRAMapping", - lora_index_to_id: List[Optional[int]], + lora_index_to_id: list[Optional[int]], max_loras: int, vocab_size: int, extra_vocab_size: int, @@ -43,9 +43,9 @@ def update_metadata( @abstractmethod def add_shrink( self, - y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + y: Union[tuple[torch.Tensor, ...], 
torch.Tensor], x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, ...], + lora_a_stacked: tuple[torch.Tensor, ...], scale: float, **kwargs, ) -> Optional[torch.Tensor]: @@ -59,10 +59,10 @@ def add_shrink( def add_expand( self, y: torch.Tensor, - x: Union[Tuple[torch.Tensor, ...], torch.Tensor], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], - output_slices: Tuple[int, ...], + x: Union[tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + output_slices: tuple[int, ...], offset_start: int = 0, add_inputs=True, **kwargs, @@ -91,13 +91,13 @@ def add_lora_embedding( def add_lora_linear(self, y: torch.Tensor, x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, ...], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], scale: float, - output_slices: Tuple[int, ...], + output_slices: tuple[int, ...], *, - buffer: Optional[Tuple[torch.Tensor, ...]] = None, + buffer: Optional[tuple[torch.Tensor, ...]] = None, **kwargs) -> Optional[torch.Tensor]: """ Applicable to linear-related lora. @@ -150,7 +150,7 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, # 5 is the number of indices tensors. 
# base_indices, sampler_indices, sampler_indices_padded, # embeddings_indices,long_lora_indices - self.indices_len: List[Optional[int]] = [None] * 5 + self.indices_len: list[Optional[int]] = [None] * 5 # these attributes are the information required for sgmv kernel self._seq_start_locs = torch.empty(max_batches, dtype=torch.long, @@ -171,7 +171,7 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, def _update_base_metadata( self, mapping: "LoRAMapping", - lora_index_to_id: List[Optional[int]], + lora_index_to_id: list[Optional[int]], max_loras: int, vocab_size: int, extra_vocab_size: int, @@ -228,8 +228,8 @@ def _apply_bias( self, indices: torch.Tensor, output: torch.Tensor, - output_slices: Tuple[int, ...], - lora_bias_stacked: Tuple[Optional[torch.Tensor], ...], + output_slices: tuple[int, ...], + lora_bias_stacked: tuple[Optional[torch.Tensor], ...], ): """Applies bias to output @@ -259,7 +259,7 @@ def _apply_bias( @property def prefill_metadata( self - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]: """ This property provides a convenient way to access the necessary metadata for prefill-related kernel computations. @@ -323,7 +323,7 @@ def long_lora_indices(self) -> torch.Tensor: def update_metadata( self, mapping: "LoRAMapping", - lora_index_to_id: List[Optional[int]], + lora_index_to_id: list[Optional[int]], max_loras: int, vocab_size: int, extra_vocab_size: int, @@ -341,8 +341,8 @@ def update_metadata( self.is_prefill = False @abstractmethod - def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], - x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], scale: float, **kwargs) -> Optional[torch.Tensor]: """ Performs GEMM for multiple slices of lora_a. 
@@ -352,9 +352,9 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], y[i] += (x @ lora_a_stacked[i]) * scale Args: - y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors x (torch.Tensor): Input tensor - lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights scale (float): Scaling factor for the operation """ @@ -364,10 +364,10 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], @abstractmethod def add_expand(self, y: torch.Tensor, - x: Union[Tuple[torch.Tensor, ...], torch.Tensor], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], - output_slices: Tuple[int, ...], + x: Union[tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + output_slices: tuple[int, ...], offset_start: int = 0, add_inputs=True, **kwargs) -> Optional[torch.Tensor]: @@ -384,11 +384,11 @@ def add_expand(self, Args: y (torch.Tensor): Output tensor. - x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors - lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): bias's weight - output_slices (Tuple[int, ...]): Every slice's size + output_slices (tuple[int, ...]): Every slice's size offset_start (int): The starting position of y, defaults to 0 add_inputs (bool): Defaults to True. 
@@ -422,13 +422,13 @@ def add_lora_embedding(self, def add_lora_linear(self, y: torch.Tensor, x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, ...], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], scale: float, - output_slices: Tuple[int, ...], + output_slices: tuple[int, ...], *, - buffer: Optional[Tuple[torch.Tensor, ...]] = None, + buffer: Optional[tuple[torch.Tensor, ...]] = None, **kwargs) -> Optional[torch.Tensor]: """ Applicable to linear-related lora. @@ -445,12 +445,12 @@ def add_lora_linear(self, Args: y (torch.Tensor): Output tensor. Will be changed in-place. x (torch.Tensor): Input tensor - lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. - lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. - lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias. scale (float): Scaling factor. - output_slices (Tuple[int, ...]): Every slice's size. - buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. + output_slices (tuple[int, ...]): Every slice's size. + buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None. 
""" # TODO: implement it based on torch ops raise NotImplementedError diff --git a/vllm/lora/punica_wrapper/punica_cpu.py b/vllm/lora/punica_wrapper/punica_cpu.py index 29428f4cfff3..8118a72d696a 100644 --- a/vllm/lora/punica_wrapper/punica_cpu.py +++ b/vllm/lora/punica_wrapper/punica_cpu.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, Optional, Tuple, Union +from typing import Callable, Optional, Union import torch @@ -150,8 +150,8 @@ def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, shrink_fun(y, x, w_t_all, scale) y = y.view_as(y_org) - def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], - x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], scale: float, **kwargs): """ Performs GEMM for multiple slices of lora_a. @@ -165,9 +165,9 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], y[i] += (x @ lora_a_stacked[i]) * scale Args: - y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors x (torch.Tensor): Input tensor - lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights scale (float): Scaling factor for the operation """ @@ -179,10 +179,10 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], def add_expand(self, y: torch.Tensor, - x: Union[Tuple[torch.Tensor, ...], torch.Tensor], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], - output_slices: Tuple[int, ...], + x: Union[tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + output_slices: tuple[int, ...], offset_start: int = 0, add_inputs=True, **kwargs) -> None: @@ -198,11 +198,11 @@ 
def add_expand(self, Args: y (torch.Tensor): Output tensor. - x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors - lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): bias's weight - output_slices (Tuple[int, ...]): Every slice's size + output_slices (tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. """ y_org = y @@ -250,13 +250,13 @@ def add_lora_embedding(self, def add_lora_linear(self, y: torch.Tensor, x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, ...], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], scale: float, - output_slices: Tuple[int, ...], + output_slices: tuple[int, ...], *, - buffer: Optional[Tuple[torch.Tensor, ...]] = None, + buffer: Optional[tuple[torch.Tensor, ...]] = None, **kwargs) -> None: """ Applicable to linear-related lora. @@ -273,12 +273,12 @@ def add_lora_linear(self, Args: y (torch.Tensor): Output tensor. Will be changed in-place. x (torch.Tensor): Input tensor - lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. - lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. - lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias. scale (float): Scaling factor. - output_slices (Tuple[int, ...]): Every slice's size. - buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. + output_slices (tuple[int, ...]): Every slice's size. 
+ buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None. """ assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index bb6d2808e46a..224640ec7192 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -6,7 +6,7 @@ https://arxiv.org/abs/2310.18547 """ -from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final +from typing import TYPE_CHECKING, Optional, Union, final import torch @@ -57,7 +57,7 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, def update_metadata( self, mapping: LoRAMapping, - lora_index_to_id: List[Optional[int]], + lora_index_to_id: list[Optional[int]], max_loras: int, vocab_size: int, extra_vocab_size: int, @@ -74,7 +74,7 @@ def update_metadata( self.prompt_mapping_meta.prepare_tensors(self.sampler_indices) def add_shrink(self, y: torch.Tensor, x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], scale: float, **kwargs): """ Performs GEMM for multiple slices of lora_a. 
@@ -86,7 +86,7 @@ def add_shrink(self, y: torch.Tensor, x: torch.Tensor, Args: y (torch.Tensor): Output tensors x (torch.Tensor): Input tensor - lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights scale (float): Scaling factor for the operation """ @@ -102,9 +102,9 @@ def add_shrink(self, y: torch.Tensor, x: torch.Tensor, def add_expand(self, y: torch.Tensor, x: torch.Tensor, - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], - output_slices: Tuple[int, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + output_slices: tuple[int, ...], offset_start: int = 0, add_inputs=True, **kwargs) -> None: @@ -121,10 +121,10 @@ def add_expand(self, Args: y (torch.Tensor): Output tensor. x (torch.Tensor): Input tensors - lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): bias's weight - output_slices (Tuple[int, ...]): Every slice's size + output_slices (tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. """ y_org = y @@ -181,11 +181,11 @@ def add_lora_embedding(self, def add_lora_linear(self, y: torch.Tensor, x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, ...], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], scale: float, - output_slices: Tuple[int, ...], + output_slices: tuple[int, ...], *, buffer: Optional[torch.Tensor] = None, **kwargs) -> None: @@ -204,11 +204,11 @@ def add_lora_linear(self, Args: y (torch.Tensor): Output tensor. Will be changed in-place. 
x (torch.Tensor): Input tensor - lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. - lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. - lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias. scale (float): Scaling factor. - output_slices (Tuple[int, ...]): Every slice's size. + output_slices (tuple[int, ...]): Every slice's size. buffer (Optional[torch.Tensor]): Defaults to None. """ diff --git a/vllm/lora/punica_wrapper/punica_hpu.py b/vllm/lora/punica_wrapper/punica_hpu.py index 3661a7214648..416c23e73bf8 100644 --- a/vllm/lora/punica_wrapper/punica_hpu.py +++ b/vllm/lora/punica_wrapper/punica_hpu.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final +from typing import TYPE_CHECKING, Optional, Union, final import torch from vllm_hpu_extension.ops import (dispatch_bgmv_embedding, @@ -28,7 +28,7 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, def _update_base_metadata( self, mapping: "LoRAMapping", - lora_index_to_id: List[Optional[int]], + lora_index_to_id: list[Optional[int]], max_loras: int, vocab_size: int, extra_vocab_size: int, @@ -48,9 +48,9 @@ def _update_base_metadata( # graph accumulation. Hence HPU appends `lora_offset` to a list and # converts it to a tensor only after it is ready. 
if long_lora_context: - index_mapping_indices: List[int] = list( + index_mapping_indices: list[int] = list( mapping.index_mapping).copy() - long_lora_offsets: List[int] = [] + long_lora_offsets: list[int] = [] for i in range(len(index_mapping_indices)): lora_offset: int = long_lora_context.offsets_by_lora_id.get( index_mapping_indices[i], 0) @@ -85,13 +85,13 @@ def add_lora_embedding(self, def add_lora_linear(self, y: torch.Tensor, x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, ...], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], scale: float, - output_slices: Tuple[int, ...], + output_slices: tuple[int, ...], *, - buffer: Optional[Tuple[torch.Tensor, ...]] = None, + buffer: Optional[tuple[torch.Tensor, ...]] = None, **kwargs) -> None: y_org = y x = x.view(-1, x.shape[-1]) @@ -122,9 +122,9 @@ def add_lora_logits(self, def add_shrink( self, - y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + y: Union[tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, ...], + lora_a_stacked: tuple[torch.Tensor, ...], scale: float, **kwargs, ) -> None: @@ -133,10 +133,10 @@ def add_shrink( def add_expand( self, y: torch.Tensor, - x: Union[Tuple[torch.Tensor, ...], torch.Tensor], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], - output_slices: Tuple[int, ...], + x: Union[tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + output_slices: tuple[int, ...], offset_start: int = 0, add_inputs=True, **kwargs, diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 37544c755d90..f3153c6dab03 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ 
b/vllm/lora/punica_wrapper/punica_tpu.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Tuple, Union +from typing import Optional, Union import torch import torch.nn.functional as F @@ -77,8 +77,8 @@ def expand_slice(self, y: torch.Tensor, x: torch.Tensor, self._get_token_lora_indices(x), y_offset, y_slice_size, add_inputs) - def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], - x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], scale: float, **kwargs) -> Optional[torch.Tensor]: """ Performs GEMM for multiple slices of lora_a. @@ -88,9 +88,9 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], y[i] += (x @ lora_a_stacked[i]) * scale Args: - y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors x (torch.Tensor): Input tensor - lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights scale (float): Scaling factor for the operation """ @@ -106,10 +106,10 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], def add_expand(self, y: torch.Tensor, - x: Union[Tuple[torch.Tensor, ...], torch.Tensor], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], - output_slices: Tuple[int, ...], + x: Union[tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + output_slices: tuple[int, ...], offset_start: int = 0, add_inputs=True, **kwargs) -> torch.Tensor: @@ -125,11 +125,11 @@ def add_expand(self, Args: y (torch.Tensor): Output tensor. 
- x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors - lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): bias's weight - output_slices (Tuple[int, ...]): Every slice's size + output_slices (tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. """ y_org = y @@ -177,13 +177,13 @@ def add_lora_embedding(self, def add_lora_linear(self, y: torch.Tensor, x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, ...], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], scale: float, - output_slices: Tuple[int, ...], + output_slices: tuple[int, ...], *, - buffer: Optional[Tuple[torch.Tensor, ...]] = None, + buffer: Optional[tuple[torch.Tensor, ...]] = None, **kwargs) -> torch.Tensor: """ Applicable to linear-related lora. @@ -200,12 +200,12 @@ def add_lora_linear(self, Args: y (torch.Tensor): Output tensor. Will not be changed in-place. x (torch.Tensor): Input tensor (T, E) - lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. - lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. - lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias. scale (float): Scaling factor. - output_slices (Tuple[int, ...]): Every slice's size. - buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. + output_slices (tuple[int, ...]): Every slice's size. 
+ buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None. """ assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) @@ -284,8 +284,8 @@ def _apply_bias( self, indices: torch.Tensor, output: torch.Tensor, - output_slices: Tuple[int, ...], - lora_bias_stacked: Tuple[Optional[torch.Tensor], ...], + output_slices: tuple[int, ...], + lora_bias_stacked: tuple[Optional[torch.Tensor], ...], ): """Applies bias to output diff --git a/vllm/lora/punica_wrapper/utils.py b/vllm/lora/punica_wrapper/utils.py index f4e5542b177d..1adb40b4c284 100644 --- a/vllm/lora/punica_wrapper/utils.py +++ b/vllm/lora/punica_wrapper/utils.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Union import torch @@ -12,7 +12,7 @@ def compute_meta( token_lora_tensor: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]: """ Get the information required for the sgmv kernel. With the features: 1. If consecutive requests in the batch use the same LoRA, this function @@ -43,14 +43,14 @@ def compute_meta( # TODO see if this can be vectorized def convert_mapping( mapping: "LoRAMapping", - lora_index_to_id: List[Optional[int]], + lora_index_to_id: list[Optional[int]], max_loras: int, vocab_size: int, extra_vocab_size: int, device: torch.device, long_lora_context: Optional["LongContextLoRAContext"] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, - Optional[torch.Tensor], List[int]]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor], list[int]]: """Converts LoRAMapping to index tensors. Args: @@ -84,7 +84,7 @@ def convert_mapping( (base_indices, sampler_indices, sampler_indices_padded, embeddings_indices, long_lora_indices). 
""" - index_mapping_indices: List[int] = list(mapping.index_mapping).copy() + index_mapping_indices: list[int] = list(mapping.index_mapping).copy() embedding_indices = index_mapping_indices.copy() lora_indices = index_mapping_indices.copy() long_lora_offsets: Optional[torch.Tensor] = None @@ -92,7 +92,7 @@ def convert_mapping( long_lora_offsets = torch.zeros(len(index_mapping_indices), device=device, dtype=torch.long) - prompt_mapping: List[int] = [ + prompt_mapping: list[int] = [ lora_index_to_id.index(x) if x > 0 else -1 for x in mapping.prompt_mapping ] @@ -109,7 +109,7 @@ def convert_mapping( index_mapping_indices[i], 0) long_lora_offsets[i] = lora_offset - indices_list: List[Union[List[int], torch.Tensor]] = [ + indices_list: list[Union[list[int], torch.Tensor]] = [ index_mapping_indices, lora_indices, embedding_indices, diff --git a/vllm/lora/request.py b/vllm/lora/request.py index badfaa419377..616e94f8d678 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -31,6 +31,7 @@ class LoRARequest( lora_local_path: Optional[str] = msgspec.field(default=None) long_lora_max_len: Optional[int] = None base_model_name: Optional[str] = msgspec.field(default=None) + tensorizer_config_dict: Optional[dict] = None def __post_init__(self): if self.lora_local_path: diff --git a/vllm/lora/resolver.py b/vllm/lora/resolver.py index 6726ca9a903f..33f35322fe85 100644 --- a/vllm/lora/resolver.py +++ b/vllm/lora/resolver.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod +from collections.abc import Set from dataclasses import dataclass, field -from typing import AbstractSet, Dict, Optional +from typing import Optional from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -40,9 +41,9 @@ async def resolve_lora(self, base_model_name: str, @dataclass class _LoRAResolverRegistry: - resolvers: Dict[str, LoRAResolver] = field(default_factory=dict) + resolvers: dict[str, LoRAResolver] = 
field(default_factory=dict) - def get_supported_resolvers(self) -> AbstractSet[str]: + def get_supported_resolvers(self) -> Set[str]: """Get all registered resolver names.""" return self.resolvers.keys() diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 01064e5d007e..619dd3bdc40a 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 import os -import re -from typing import List, Optional, Set, Tuple, Type, Union +from typing import Optional, Union import huggingface_hub +import regex as re from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, HFValidationError, RepositoryNotFoundError) from torch import nn @@ -37,7 +37,7 @@ logger = init_logger(__name__) -_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = { +_all_lora_classes: set[type[BaseLayerWithLoRA]] = { VocabParallelEmbeddingWithLoRA, ColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA, @@ -58,7 +58,7 @@ def from_layer(layer: nn.Module, max_loras: int, lora_config: LoRAConfig, - packed_modules_list: List, + packed_modules_list: list, model_config: Optional[PretrainedConfig] = None) -> nn.Module: for lora_cls in _all_lora_classes: # specifying kwargs so they can be easily accessed in decorator @@ -99,7 +99,7 @@ def replace_submodule(model: nn.Module, module_name: str, def parse_fine_tuned_lora_name( name: str, weights_mapper: Optional[WeightsMapper] = None -) -> Tuple[str, bool, bool]: +) -> tuple[str, bool, bool]: """Parse the name of lora weights. args: @@ -108,7 +108,7 @@ def parse_fine_tuned_lora_name( weights_mapper: maps the name of weight, e.g. `model.` -> `language_model.model.`, return: - Tuple(module_name, is_lora_a): + tuple(module_name, is_lora_a): module_name: the name of the module, e.g. model.dense1, is_lora_a whether the tensor is lora_a or lora_b. is_bias whether the tensor is lora bias. 
@@ -147,8 +147,8 @@ def parse_fine_tuned_lora_name( raise ValueError(f"{name} is unsupported LoRA weight") -def is_regex_target_modules(load_modules: Union[str, List[str]], - expected_lora_modules: List[str]) -> bool: +def is_regex_target_modules(load_modules: Union[str, list[str]], + expected_lora_modules: list[str]) -> bool: """ PEFT supports passing `target_modules` in the form of regular expressions, such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to @@ -179,11 +179,11 @@ def is_subset(sub_list, full_list): return False -def get_supported_lora_modules(model: nn.Module) -> List[str]: +def get_supported_lora_modules(model: nn.Module) -> list[str]: """ In vLLM, all linear layers support LoRA. """ - supported_lora_modules: Set[str] = set() + supported_lora_modules: set[str] = set() # step1: traverse the model to get all the linear subfixes. for name, module in model.named_modules(): if isinstance(module, (LinearBase, )): diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 108beb34b244..afc8a8dc3b26 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from contextlib import contextmanager -from typing import Any, Dict, List, Literal, Optional, Set, Type, Union +from typing import Any, Literal, Optional, Union import torch @@ -27,7 +27,7 @@ class WorkerLoRAManager(AbstractWorkerManager): Every request, the requested LoRAs will be loaded (unless they are already loaded), and every other LoRA will be unloaded.""" - _manager_cls: Type[LoRAModelManager] = LoRAModelManager + _manager_cls: type[LoRAModelManager] = LoRAModelManager def __init__( self, @@ -36,9 +36,9 @@ def __init__( vocab_size: int, lora_config: LoRAConfig, device: torch.device, - embedding_modules: Dict[str, str], - embedding_padding_modules: List[str], - lora_model_cls: Type[LoRAModel] = LoRAModel, + embedding_modules: dict[str, str], + embedding_padding_modules: list[str], + 
lora_model_cls: type[LoRAModel] = LoRAModel, max_position_embeddings: Optional[int] = None, ): self._lora_model_cls = lora_model_cls @@ -88,7 +88,7 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: self._adapter_manager.supported_lora_modules) packed_modules_mapping = ( self._adapter_manager.packed_modules_mapping) - expected_lora_modules: List[str] = [] + expected_lora_modules: list[str] = [] for module in supported_lora_modules: if module in packed_modules_mapping: expected_lora_modules.extend( @@ -100,7 +100,8 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: lora_path = get_adapter_absolute_path(lora_request.lora_path) peft_helper = PEFTHelper.from_local_dir( - lora_path, self.max_position_embeddings) + lora_path, self.max_position_embeddings, + lora_request.tensorizer_config_dict) # Validates the LoRA configuration against requirements before # loading weights, throwing an exception if validation fails. @@ -125,6 +126,7 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: self.lora_config.lora_extra_vocab_size, embedding_modules=self.embedding_modules, embedding_padding_modules=self.embedding_padding_modules, + tensorizer_config_dict=lora_request.tensorizer_config_dict, weights_mapper=hf_to_vllm_mapper) except FileNotFoundError as e: @@ -162,12 +164,12 @@ def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: def pin_adapter(self, adapter_id: int) -> bool: return self._adapter_manager.pin_adapter(adapter_id) - def set_active_adapters(self, requests: Set[Any], + def set_active_adapters(self, requests: set[Any], mapping: Optional[Any]) -> None: set_active_adapters_worker(requests, mapping, self._apply_adapters, self._adapter_manager.set_adapter_mapping) - def _apply_adapters(self, adapter_requests: Set[Any]) -> None: + def _apply_adapters(self, adapter_requests: set[Any]) -> None: apply_adapters_worker(adapter_requests, self.list_adapters, self._adapter_manager.adapter_slots, 
self.remove_adapter, self.add_adapter) @@ -184,7 +186,7 @@ def remove_adapter(self, adapter_id: int) -> bool: def remove_all_adapters(self): self._adapter_manager.remove_all_adapters() - def list_adapters(self) -> Set[int]: + def list_adapters(self) -> set[int]: return list_adapters_worker(self._adapter_manager.list_adapters) @@ -195,7 +197,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager): (unless they are already loaded) and least recently used LoRAs will be unloaded if the cache is above capacity.""" - _manager_cls: Type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager + _manager_cls: type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager def create_lora_manager( self, @@ -213,7 +215,7 @@ def create_lora_manager( self._adapter_manager = lora_manager return lora_manager.model - def _apply_adapters(self, lora_requests: Set[LoRARequest]) -> None: + def _apply_adapters(self, lora_requests: set[LoRARequest]) -> None: loras_map = { lora_request.lora_int_id: lora_request for lora_request in lora_requests if lora_request diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index b0d00ee48187..acf7224675e4 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, Type - import torch.nn as nn from vllm.config import get_current_vllm_config @@ -138,7 +136,7 @@ def default_on() -> bool: # Examples: # - MyOp.enabled() # - op_registry["my_op"].enabled() - op_registry: Dict[str, Type['CustomOp']] = {} + op_registry: dict[str, type['CustomOp']] = {} # Decorator to register custom ops. 
@classmethod diff --git a/vllm/model_executor/guided_decoding/guidance_decoding.py b/vllm/model_executor/guided_decoding/guidance_decoding.py index 0b1f4762bc73..58adcc3caff9 100644 --- a/vllm/model_executor/guided_decoding/guidance_decoding.py +++ b/vllm/model_executor/guided_decoding/guidance_decoding.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 import json -from re import escape as regex_escape import llguidance +from regex import escape as regex_escape from transformers import PreTrainedTokenizerBase from vllm.model_executor.guided_decoding.guidance_logits_processors import ( diff --git a/vllm/model_executor/guided_decoding/guidance_logits_processors.py b/vllm/model_executor/guided_decoding/guidance_logits_processors.py index 26fcafe31c76..e17df68b4b4d 100644 --- a/vllm/model_executor/guided_decoding/guidance_logits_processors.py +++ b/vllm/model_executor/guided_decoding/guidance_logits_processors.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +import copy import os -from typing import Any, List +from typing import Any import llguidance import llguidance.hf @@ -34,9 +35,24 @@ def __init__( self.grammar = grammar self.tokenizer = tokenizer self.tokenizer_name = tokenizer.name_or_path + self.ll_tokenizer = None + self.ll_matcher = None + self.bitmask = None self.new_sampling = False self.initialized = False + def clone(self) -> "GuidanceLogitsProcessor": + cloned = copy.copy(self) + if self.initialized: + cloned.ll_matcher = llguidance.LLMatcher( + self.ll_tokenizer, # type: ignore[assignment] + self.grammar, + log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")), + ) + self.bitmask = llguidance.torch.allocate_token_bitmask( + 1, self.ll_tokenizer.vocab_size) # type: ignore[attr-defined] + return cloned + def _initialize(self): if self.initialized: return @@ -56,13 +72,13 @@ def _initialize(self): # create reusable bitmask self.bitmask = llguidance.torch.allocate_token_bitmask( - 1, self.ll_tokenizer.vocab_size) + 1, 
self.ll_tokenizer.vocab_size) # type: ignore[attr-defined] self.initialized = True def __call__( self, - input_ids: List[int], + input_ids: list[int], scores: torch.Tensor, ) -> torch.Tensor: # we initialize the guidance model here @@ -70,15 +86,17 @@ def __call__( self._initialize() if self.new_sampling and len(input_ids) > 0: - self.ll_matcher.consume_token(input_ids[-1]) - err = self.ll_matcher.get_error() + self.ll_matcher.consume_token( # type: ignore[attr-defined] + input_ids[-1]) + err = self.ll_matcher.get_error() # type: ignore[attr-defined] if err: logger.warning("Error in LLMatcher: %s", err) llguidance.torch.fill_next_token_bitmask(self.ll_matcher, self.bitmask, 0) llguidance.torch.apply_token_bitmask_inplace( - scores, self.bitmask.to(scores.device)) + scores, + self.bitmask.to(scores.device)) # type: ignore[attr-defined] self.new_sampling = True diff --git a/vllm/model_executor/guided_decoding/guided_fields.py b/vllm/model_executor/guided_decoding/guided_fields.py index 1593868a164a..085f37a5d516 100644 --- a/vllm/model_executor/guided_decoding/guided_fields.py +++ b/vllm/model_executor/guided_decoding/guided_fields.py @@ -1,16 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import Dict, List, Optional, TypedDict, Union +from typing import Optional, TypedDict, Union from pydantic import BaseModel # These classes are deprecated, see SamplingParams class LLMGuidedOptions(TypedDict, total=False): - guided_json: Union[Dict, BaseModel, str] + guided_json: Union[dict, BaseModel, str] guided_regex: str - guided_choice: List[str] + guided_choice: list[str] guided_grammar: str guided_decoding_backend: str guided_whitespace_pattern: str @@ -20,9 +20,9 @@ class LLMGuidedOptions(TypedDict, total=False): @dataclass class GuidedDecodingRequest: """One of the fields will be used to retrieve the logit processor.""" - guided_json: Optional[Union[Dict, BaseModel, str]] = None + guided_json: Optional[Union[dict, BaseModel, 
str]] = None guided_regex: Optional[str] = None - guided_choice: Optional[List[str]] = None + guided_choice: Optional[list[str]] = None guided_grammar: Optional[str] = None guided_decoding_backend: Optional[str] = None guided_whitespace_pattern: Optional[str] = None diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index 564f9277a83c..e41af4b360e4 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -5,9 +5,9 @@ import os from enum import Enum from json import dumps as json_dumps -from re import escape as regex_escape -from typing import Optional, Tuple, Union +from typing import Optional, Union +from regex import escape as regex_escape from transformers import PreTrainedTokenizerBase from vllm.model_executor.guided_decoding.outlines_logits_processors import ( @@ -111,7 +111,7 @@ def get_local_outlines_guided_decoding_logits_processor( def _get_guide_and_mode( guided_params: GuidedDecodingParams -) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]: +) -> Union[tuple[str, GuidedDecodingMode], tuple[None, None]]: if guided_params.json: if isinstance(guided_params.json, dict): # turn dict into hashable string diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index 936fd0f06867..6986b6554c23 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -19,7 +19,7 @@ import json from collections import defaultdict from functools import lru_cache -from typing import Callable, DefaultDict, Dict, List, Optional, Union +from typing import Callable, Optional, Union import numpy as np import torch @@ -53,10 +53,16 @@ def __init__(self, guide: Guide, reasoner: Optional[ReasoningParser]): self._guide: Guide = guide 
self._reasoner: Optional[ReasoningParser] = reasoner # CFGState is used for the FSM state for CFGGuide - self._fsm_state: DefaultDict[int, Union[int, + self._fsm_state: defaultdict[int, Union[int, CFGState]] = defaultdict(int) - def __call__(self, input_ids: List[int], + def clone(self) -> "BaseLogitsProcessor": + cloned = copy.copy(self) + cloned._guide = self._guide.copy() + cloned._fsm_state = copy.deepcopy(self._fsm_state) + return cloned + + def __call__(self, input_ids: list[int], scores: torch.Tensor) -> torch.Tensor: """Use the FSM to bias the logits before sampling the next token.""" @@ -160,7 +166,7 @@ def __init__( class JSONLogitsProcessor(RegexLogitsProcessor): - def __init__(self, schema: Union[str, Dict, BaseModel], + def __init__(self, schema: Union[str, dict, BaseModel], tokenizer: PreTrainedTokenizerBase, whitespace_pattern: Union[str, None], reasoner: Optional[ReasoningParser]): @@ -181,7 +187,7 @@ def __init__(self, schema: Union[str, Dict, BaseModel], """ if isinstance(schema, type(BaseModel)): schema_str = json.dumps(schema.model_json_schema()) - elif isinstance(schema, Dict): + elif isinstance(schema, dict): schema_str = json.dumps(schema) elif isinstance(schema, str): schema_str = schema @@ -218,6 +224,12 @@ def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase, reasoner) self._guide = self._guide.copy() + def clone(self) -> "CFGLogitsProcessor": + cloned = copy.copy(self) + cloned._fsm_state = copy.deepcopy(self._fsm_state) + cloned._guide = self._guide.copy() + return cloned + @lru_cache(maxsize=32) def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase): @@ -252,11 +264,11 @@ def convert_token_to_string(token: str) -> str: return string def change_decoder( - decoder: Callable[[List[int]], - str]) -> Callable[[List[int]], List[str]]: + decoder: Callable[[list[int]], + str]) -> Callable[[list[int]], list[str]]: """Sync vLLM's decoder with the outlines by returning list.""" - def new_decoder(inp_tokens: List[int]) -> List[str]: + 
def new_decoder(inp_tokens: list[int]) -> list[str]: if (isinstance(inp_tokens, list) and len(inp_tokens) == 1 and isinstance(inp_tokens[0], list)): inp_tokens = inp_tokens[0] diff --git a/vllm/model_executor/guided_decoding/utils.py b/vllm/model_executor/guided_decoding/utils.py index 1ad1ef8fbf16..3f77cf394d9a 100644 --- a/vllm/model_executor/guided_decoding/utils.py +++ b/vllm/model_executor/guided_decoding/utils.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -import re +import regex as re def has_xgrammar_unsupported_json_features(schema: dict) -> bool: diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index ac2d73626d78..d2e568609945 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -4,10 +4,10 @@ from __future__ import annotations import json -import re from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, List +from typing import TYPE_CHECKING, Any +import regex as re import torch import vllm.envs @@ -273,7 +273,7 @@ def escape_ebnf_string(s: str) -> str: return re.sub(r'(["\\])', r'\\\1', s) @staticmethod - def choice_as_grammar(choice: List[str] | None) -> str: + def choice_as_grammar(choice: list[str] | None) -> str: if choice is None: raise ValueError("Choice is not set") escaped_choices = (GrammarConfig.escape_ebnf_string(c) for c in choice) @@ -302,8 +302,9 @@ class XGrammarLogitsProcessor: prefilled: bool = field(default=False) def __post_init__(self): - self.tokenizer_info = self.config.tokenizer_info( - self.config.tokenizer_data) + if self.tokenizer_info is None: + self.tokenizer_info = self.config.tokenizer_info( + self.config.tokenizer_data) def __getstate__(self) -> dict[str, Any]: return {'config': self.config, 'reasoner': self.reasoner} @@ -400,7 +401,8 @@ def __call__(self, input_ids: list[int], def clone(self) -> XGrammarLogitsProcessor: 
"""Create a new instance with shared compiled grammar but separate state""" - new_processor = XGrammarLogitsProcessor(self.config, self.reasoner) + new_processor = XGrammarLogitsProcessor(self.config, self.reasoner, + None, self.tokenizer_info) # Share the compiled grammar context (immutable after compilation) new_processor.ctx = self.ctx diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index f082afb7e9c0..a32c26317a88 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -354,7 +354,7 @@ def get_act_fn(act_fn_name: str) -> nn.Module: _ACTIVATION_AND_MUL_REGISTRY = LazyDict({ "gelu": lambda: GeluAndMul(), "silu": lambda: SiluAndMul(), - "gelu_and_mul": lambda: GeluAndMul(), + "geglu": lambda: GeluAndMul(), }) diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 9829ccdb384f..5c262287f7dd 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,13 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 from contextlib import contextmanager -from typing import Any, Dict, Optional +from typing import Any, Optional from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.triton_utils import HAS_TRITON -_config: Optional[Dict[str, Any]] = None +_config: Optional[dict[str, Any]] = None @contextmanager @@ -19,7 +19,7 @@ def override_config(config): _config = old_config -def get_config() -> Optional[Dict[str, Any]]: +def get_config() -> Optional[dict[str, Any]]: return _config @@ -36,10 +36,10 @@ def get_config() -> Optional[Dict[str, Any]]: import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa import vllm.model_executor.layers.fused_moe.fused_moe # noqa from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp8) + cutlass_moe_fp4, cutlass_moe_fp8) 
from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_experts, fused_moe, fused_topk, get_config_file_name, - grouped_topk) + TritonExperts, fused_experts, fused_moe, fused_topk, + get_config_file_name, grouped_topk) __all__ += [ "fused_moe", @@ -48,4 +48,6 @@ def get_config() -> Optional[Dict[str, Any]]: "get_config_file_name", "grouped_topk", "cutlass_moe_fp8", + "cutlass_moe_fp4", + "TritonExperts", ] diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..3e0ad0d5a989 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + 
"num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 960c7f834857..26a433da2189 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -1,10 +1,177 @@ # SPDX-License-Identifier: Apache-2.0 -"""Fused MoE kernel.""" +""" CUTLASS based Fused MoE kernels.""" from typing import Optional import torch +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops +from 
vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP) +from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache +from vllm.scalar_type import scalar_types + + +class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__( + self, + ab_strides1: torch.Tensor, + c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides2: torch.Tensor, + out_dtype: torch.dtype, + ): + super().__init__() + self.ab_strides1 = ab_strides1 + self.c_strides1 = c_strides1 + self.ab_strides2 = ab_strides2 + self.c_strides2 = c_strides2 + self.out_dtype = out_dtype + + def workspace_shapes( + self, + a: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> tuple[int, int, torch.dtype]: + # Note that K, N are transposed + N, K = K, N + workspace1 = M * topk * max(2 * N, K) + workspace2 = M * topk * N + return (workspace1, workspace2, self.out_dtype) + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ) -> torch.Tensor: + a1q = hidden_states + + assert w1_scale is not None + assert w2_scale is not None + assert w1.dtype == torch.float8_e4m3fn + assert w2.dtype == torch.float8_e4m3fn + assert a1q.shape[1] == w1.shape[1], "Hidden size mismatch w1" + assert w1.shape[2] == w2.shape[1] * 2, "Hidden size mismatch w2" + assert w1.shape[0] == w2.shape[0], "Expert number mismatch" + assert a1q_scale is None or a1q_scale.dim( + ) == 0 or a1q_scale.shape[0] == 1 or a1q_scale.shape[0] == a1q.shape[ + 0], "Input scale shape mismatch" 
+ assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[ + 1] == w1.shape[2], "W1 scale shape mismatch" + assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[ + 1] == w2.shape[2], "W2 scale shape mismatch" + assert w1.shape[0] == w2.shape[0], "Weights expert number mismatch" + assert w1.shape[0] == w1_scale.shape[ + 0], "w1 scales expert number mismatch" + assert w1.shape[0] == w2_scale.shape[ + 0], "w2 scales expert number mismatch" + assert a2_scale is None or a1q_scale is None or a2_scale.shape == a1q_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 + assert self.ab_strides1.shape[0] == w1.shape[ + 0], "AB Strides 1 expert number mismatch" + assert self.c_strides1.shape[0] == w1.shape[ + 0], "C Strides 1 expert number mismatch" + assert self.ab_strides2.shape[0] == w2.shape[ + 0], "AB Strides 2 expert number mismatch" + assert self.c_strides2.shape[0] == w2.shape[ + 0], "C Strides 2 expert number mismatch" + assert self.out_dtype in [torch.half, + torch.bfloat16], "Invalid output dtype" + + M = a1q.shape[0] + _, N, K = w2.shape # because w1 + w2 are transposed + device = a1q.device + + assert w1.shape[1] == K + assert global_num_experts != -1 + assert a1q_scale is not None + + if expert_map is not None: + "Translate info from expert_map to topk_ids" + local_topk_ids = torch.where(expert_map[topk_ids] != -1, + expert_map[topk_ids], -1) + else: + local_topk_ids = topk_ids + + topk = local_topk_ids.shape[1] + + per_act_token = a1q_scale.numel() != 1 if a1q_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) + + expert_offsets = torch.empty((global_num_experts + 1), + dtype=torch.int32, + device=device) + problem_sizes1 = torch.empty((global_num_experts, 3), + dtype=torch.int32, + device=device) + problem_sizes2 = torch.empty((global_num_experts, 3), + dtype=torch.int32, + device=device) + + # With expert_map each Rank processes only a subset of experts. 
As + # a result not all of a_map and c2 tensors are filled. We fill it + # zeros for correctness. + if expert_map is not None: + a_map = torch.zeros((local_topk_ids.numel()), + dtype=torch.int32, + device=device) + else: + a_map = torch.empty((local_topk_ids.numel()), + dtype=torch.int32, + device=device) + + c_map = torch.empty((local_topk_ids.numel()), + dtype=torch.int32, + device=device) + + ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets, + problem_sizes1, problem_sizes2, a_map, + c_map, global_num_experts, N, K) + + a1q = _fp8_perm(a1q, a_map) + a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale + + c1 = _resize_cache(workspace13, (M * topk, N * 2)) + c2 = _resize_cache(workspace2, (M * topk, N)) + c3 = _resize_cache(workspace13, (M * topk, K)) + + ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, + expert_offsets[:-1], problem_sizes1, + self.ab_strides1, self.ab_strides1, self.c_strides1) + + self.activation(activation, c2, c1) + + a2q, a2q_scale = ops.scaled_fp8_quant( + c2, a2_scale, use_per_token_if_dynamic=per_act_token) + + if expert_map is not None: + c3.fill_(0) + + ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale, + expert_offsets[:-1], problem_sizes2, + self.ab_strides2, self.ab_strides2, self.c_strides2) + + c3 = c3[c_map] + + return c3 #TODO make the grouped gemm kernel consistent with scaled gemm kernel @@ -15,7 +182,7 @@ def cutlass_moe_fp8( w1_scale: torch.Tensor, w2_scale: torch.Tensor, topk_weights: torch.Tensor, - topk_ids_: torch.Tensor, + topk_ids: torch.Tensor, ab_strides1: torch.Tensor, c_strides1: torch.Tensor, ab_strides2: torch.Tensor, @@ -57,7 +224,7 @@ def cutlass_moe_fp8( - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize the intermediate result between the gemms. Shape: scalar or [M] - - out_dtype (torch.Tensor): The output tensor type. + - out_dtype (torch.dtype): The output tensor type. 
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel, every Rank is responsible for a subset of experts. expert_map is a mapping from global expert-id to local expert-id. When expert_map[i] @@ -69,112 +236,147 @@ def cutlass_moe_fp8( Returns: - torch.Tensor: The fp16 output tensor after applying the MoE layer. """ - - assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch" - assert w1_q.dtype == torch.float8_e4m3fn - assert w2_q.dtype == torch.float8_e4m3fn - assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1" - assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2" - assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch" - assert a1_scale is None or a1_scale.dim( - ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[0] == a.shape[ - 0], "Input scale shape mismatch" - assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[ - 1] == w1_q.shape[2], "W1 scale shape mismatch" - assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[ - 1] == w2_q.shape[2], "W2 scale shape mismatch" - assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch" - assert w1_q.shape[0] == w1_scale.shape[ - 0], "w1 scales expert number mismatch" - assert w1_q.shape[0] == w2_scale.shape[ - 0], "w2 scales expert number mismatch" - assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 - assert ab_strides1.shape[0] == w1_q.shape[ - 0], "AB Strides 1 expert number mismatch" - assert c_strides1.shape[0] == w1_q.shape[ - 0], "C Strides 1 expert number mismatch" - assert ab_strides2.shape[0] == w2_q.shape[ - 0], "AB Strides 2 expert number mismatch" - assert c_strides2.shape[0] == w2_q.shape[ - 0], "C Strides 2 expert number mismatch" - assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype" - - num_experts = w1_q.size(0) - m = a.size(0) - k = w1_q.size(1) - n = w2_q.size(1) - - local_topk_ids = topk_ids_ 
- if expert_map is not None: - "Translate info from expert_map to topk_ids" - local_topk_ids = torch.where(expert_map[topk_ids_] != -1, - expert_map[topk_ids_], -1) - - topk = local_topk_ids.size(1) - per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) - if apply_router_weight_on_input: - assert topk == 1, \ - "apply_router_weight_on_input is only implemented for topk=1" - # TODO: this only works for topK=1, will need to update for topK>1 - a = a * topk_weights.to(out_dtype) - - a_q, a1_scale = ops.scaled_fp8_quant( - a, a1_scale, use_per_token_if_dynamic=per_act_token) - device = a_q.device - - expert_offsets = torch.empty((num_experts + 1), - dtype=torch.int32, - device=device) - problem_sizes1 = torch.empty((num_experts, 3), - dtype=torch.int32, - device=device) - problem_sizes2 = torch.empty((num_experts, 3), - dtype=torch.int32, - device=device) - - a_map_initializer = torch.empty - c2_initializer = torch.empty - if expert_map is not None: - # With expert_map each Rank processes only a subset of experts. As - # a result not all of a_map and c2 tensors are filled. We fill it - # zeros for correctness. 
- a_map_initializer = torch.zeros - c2_initializer = torch.zeros - a_map = a_map_initializer((local_topk_ids.numel()), - dtype=torch.int32, - device=device) - c_map = torch.empty((local_topk_ids.numel()), - dtype=torch.int32, - device=device) + fn = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP( + per_channel_quant=per_act_token, + quant_dtype=torch.float8_e4m3fn, + ), + CutlassExpertsFp8( + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + out_dtype, + ), + ) + + return fn( + a, + w1_q, + w2_q, + topk_weights, + topk_ids, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + - ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets, problem_sizes1, - problem_sizes2, a_map, c_map, num_experts, n, - k) +FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max - rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype) - rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale - c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype) - c2 = c2_initializer((m * topk, k), device=device, dtype=out_dtype) +def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, + w1_fp4: torch.Tensor, w1_blockscale: torch.Tensor, + w1_alphas: torch.Tensor, a2_gscale: torch.Tensor, + w2_fp4: torch.Tensor, w2_blockscale: torch.Tensor, + w2_alphas: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, m: int, n: int, k: int, e: int, + device: torch.device): + """ + MoE implementation for FP4 Inputs + + # Gemm 1 + a: Input tensor: [m, k] (half/bfloat16) + a1_gscale: Activation scale per expert: [e] (float32) + w1(gate up) (not an argument to cutlass_moe_fp4): [e, 2 * n, k] + w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1) + (Note: `n` is the up projection output dim, `k` is the input dim in + full precision) + w1_blockscale: [e, 2 * n, k // 
block_size] (float8_e4m3) + (Block size = 16 for NVFP4) + + # Gemm 2 + a2_gscale: Activation scale per expert: [e] + w2(down projection) (not an argument to cutlass_moe_fp4): [e, k, n] + w2_fp4: [e, k, n // 2], dtype: torch.uint8 (stacked E2M1) + w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3 + + topk_weights: [m, topk] dtype: float8 + topk_ids: [m, topk] dtype: float8 + + m, n, k: Unquantized weight shapes, dtype: int + e: number of experts, dtype: int + + assumes that topk < k < n to satisfy - up/down projection expectations. + """ + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8" + assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8" + assert (w1_fp4.ndim == 3 and w2_fp4.ndim == 3 and w1_blockscale.ndim == 3 + and w2_blockscale.ndim + == 3), ("All Weights must be of rank 3 for cutlass_moe_fp4") + m_a, k_a = a.shape + e_w1, nx2_w1, half_k_w1 = w1_fp4.shape + e_w2, k_w2, half_n_w2 = w2_fp4.shape + + assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match", + " between weights.") + assert (k_a // 2 == half_k_w1 + and k == k_w2), ("Hidden size mismatch between a, w1 and w2") + assert (nx2_w1 == n * 2 and half_n_w2 == n // 2), ("mismatch in " + "expected `n`") + assert (m == m_a), "input shape mismatch" + assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1" + assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype" + assert (topk_weights.shape[0] == m and topk_ids.shape[0] + == m), ("topk must be provided for each row of a") + + out_dtype = a.dtype + num_topk = topk_ids.shape[1] + + expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device) + # Problem size: (num_experts, (m,2n,k)) + problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device) + # Problem size: (num_experts, (m,n,k)) + problem_sizes2 = torch.empty((e, 3), dtype=torch.int32, device=device) - ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, 
w1_scale, - expert_offsets[:-1], problem_sizes1, ab_strides1, - ab_strides1, c_strides1) + a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + + # problem shapes should have [m, n, k] + # Note that problem sizes are based on logical number of elements. + ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1, + problem_sizes2, a_map, c_map, e, n, k) + + tokens_per_expert = problem_sizes1[:, 0] + rounded_tokens_per_expert = (tokens_per_expert + (128 - 1)) // 128 * 128 + blockscale_offsets = torch.zeros(e + 1, dtype=torch.int32, device=device) + blockscale_offsets[1:] = torch.cumsum(rounded_tokens_per_expert, dim=0) + + rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant( + a, + a1_gscale, + expert_offsets, + blockscale_offsets, + num_topk, + expert_map=a_map) + + c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale, + w1_blockscale, w1_alphas, problem_sizes1, + expert_offsets[:-1], blockscale_offsets[:-1], + out_dtype, device) + del rep_a_fp4, rep_a_blockscale + # hidden size dimension is split in half, giving one half-sized tensor.
+ intermediate = torch.empty((m * num_topk, w1_fp4.shape[1] // 2), + device=device, + dtype=out_dtype) - intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype) torch.ops._C.silu_and_mul(intermediate, c1) - intemediate_q, a2_scale = ops.scaled_fp8_quant( - intermediate, a2_scale, use_per_token_if_dynamic=per_act_token) - - ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale, - expert_offsets[:-1], problem_sizes2, ab_strides2, - ab_strides2, c_strides2) - # Gather tokens - c2 = c2[c_map].view(m, topk, k) - if not apply_router_weight_on_input: - c2 = c2 * topk_weights.view(m, topk, 1).to(out_dtype) - return c2.sum(dim=1) + int_fp4, int_blockscale = ops.scaled_fp4_experts_quant( + intermediate, a2_gscale, expert_offsets, blockscale_offsets, num_topk) + + c2 = ops.cutlass_fp4_moe_mm(int_fp4, w2_fp4, int_blockscale, w2_blockscale, + w2_alphas, problem_sizes2, expert_offsets[:-1], + blockscale_offsets[:-1], out_dtype, device) + del int_fp4, int_blockscale + out = (c2[c_map].view(m, num_topk, k) * + topk_weights.view(m, num_topk, 1).half()).sum(dim=1) + return out.to(dtype=out_dtype) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 353c8cc9d59f..46a814e6ecc3 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,16 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 +import functools import importlib.util -from typing import Optional, Tuple +from typing import Optional import torch -import vllm.envs as envs -from vllm import _custom_ops as ops +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( - moe_align_block_size) -from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm, - _fp8_quantize, +from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import 
( + _moe_permute) +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP) +from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, _resize_cache) from vllm.utils import round_up @@ -19,6 +20,19 @@ has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None +@functools.cache +def deep_gemm_block_shape() -> list[int]: + # Lazy import to avoid CUDA initialization problems. + import deep_gemm as dg + block = dg.get_m_alignment_for_contiguous_layout() + return [block, block] + + +def _valid_deep_gemm_shape(M: int, N: int, K: int): + align = deep_gemm_block_shape()[0] + return align <= M and N % align == 0 and K % align == 0 + + def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -29,89 +43,112 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, aligned by `dg.get_m_alignment_for_contiguous_layout()`. """ if not has_deep_gemm: + logger.debug("DeepGemm disabled: deep_gemm not available.") return False - # Lazy import to avoid CUDA initialization problems. - import deep_gemm as dg - - # Expert maps not supported yet. if expert_map is not None: + logger.debug("DeepGemm disabled: expert map NYI.") return False - align = dg.get_m_alignment_for_contiguous_layout() - M = hidden_states.shape[0] - _, K, N = w2.shape - - # For now, disable DeepGemm for small N until better permute/unpermute - # ops are available. 
- if N <= 512: + M = hidden_states.size(0) + _, K, N = w2.size() + if not _valid_deep_gemm_shape(M, N, K): + logger.debug("DeepGemm disabled: unalinged problem size.") return False - if align > M or N % align != 0 or K % align != 0: + if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn): + logger.debug("DeepGemm disabled: invalid weight dtype(s).") return False - return (hidden_states.is_contiguous() and w1.is_contiguous() - and w2.is_contiguous()) - - -def _moe_permute( - curr_hidden_states: torch.Tensor, - a1q_scale: Optional[torch.Tensor], - curr_topk_ids: torch.Tensor, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - block_m: int, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, - Optional[torch.Tensor]]: - """ - Determine the sorted_token_ids, expert_ids for the given problem size. - Permute the hidden states and scales according to `sorted_token_ids`. - """ - top_k_num = curr_topk_ids.shape[1] - - tokens_in_chunk, _ = curr_hidden_states.shape + if (not hidden_states.is_contiguous() or not w1.is_contiguous() + or not w2.is_contiguous()): + logger.debug( + "DeepGemm disabled: weights or activations not contiguous.") + return False - sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, - block_m, - global_num_experts, - expert_map, - pad_sorted_ids=True)) + return True + + +class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__(self): + super().__init__() + self.block_shape = deep_gemm_block_shape() + + def workspace_shapes( + self, + a: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> tuple[int, int, torch.dtype]: + block_m = self.block_shape[0] + M_sum = (M * topk) + num_experts * (block_m - 1) + M_sum = round_up(M_sum, block_m) + workspace1 = M_sum * max(N * 2, K) + workspace2 = M_sum * N + return (workspace1, workspace2, a.dtype) + + def apply( + self, + hidden_states: torch.Tensor, + w1: 
torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ) -> torch.Tensor: + import deep_gemm as dg + + a1q = hidden_states + _, N, K = w1.size() + + assert global_num_experts != -1 + assert w2.size(1) == K + + a1q, a1q_scale, _, expert_ids, inv_perm = _moe_permute( + a1q, + a1q_scale, + topk_ids, + global_num_experts, + expert_map, + self.block_shape[0], + ) + + # Note: M_sum is different than the pre-permuted shape of a1q. + M_sum = a1q.size(0) + workspace1 = _resize_cache(workspace13, (M_sum, N)) + workspace2 = _resize_cache(workspace2, (M_sum, N // 2)) + workspace3 = _resize_cache(workspace13, (M_sum, K)) - inv_perm: Optional[torch.Tensor] = None + dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (a1q, a1q_scale), (w1, w1_scale), workspace1, expert_ids) - num_tokens = top_k_num * tokens_in_chunk - sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) - expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0) - inv_perm = torch.argsort(sorted_token_ids)[:num_tokens] + self.activation(activation, workspace2, workspace1.view(-1, N)) - # Permute according to sorted token ids. 
- curr_hidden_states = _fp8_perm(curr_hidden_states, - sorted_token_ids // top_k_num) + a2q_scale: Optional[torch.Tensor] = None - if a1q_scale is not None: - a1q_scale = a1q_scale[sorted_token_ids // top_k_num] + a2q, a2q_scale = _fp8_quantize(workspace2, a2_scale, False, + self.block_shape) - return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, - inv_perm) + dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (a2q, a2q_scale), (w2, w2_scale), workspace3, expert_ids) + workspace3 = workspace3[inv_perm, ...] -def _moe_unpermute_and_reduce( - out: torch.Tensor, - curr_hidden: torch.Tensor, - inv_perm: Optional[torch.Tensor], - topk_weight: torch.Tensor, -) -> None: - """ - Unpermute the final result and apply topk_weights, then perform the final - reduction on the hidden states. - """ - M, topk = topk_weight.shape - K = curr_hidden.shape[1] - curr_hidden = curr_hidden[inv_perm, ...] - curr_hidden = curr_hidden.view(-1, topk, K) - curr_hidden.mul_(topk_weight.view(M, -1, 1)) - ops.moe_sum(curr_hidden, out) + return workspace3 def deep_gemm_moe_fp8( @@ -128,6 +165,7 @@ def deep_gemm_moe_fp8( expert_map: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + apply_router_weight_on_input=False, ) -> torch.Tensor: """ This function computes a a8w8-quantized Mixture of Experts (MoE) layer @@ -166,129 +204,24 @@ def deep_gemm_moe_fp8( Returns: - torch.Tensor: The bfloat16 output tensor after applying the MoE layer. """ - # Lazy import to avoid CUDA initialization problems. 
- import deep_gemm as dg - - assert expert_map is None, "Expert maps not supported yet" - - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" - - assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.stride(-1) == 1, "Stride of last dimension must be 1" - assert w2.stride(-1) == 1, "Stride of last dimension must be 1" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] - assert w1.dtype == torch.float8_e4m3fn - assert w2.dtype == torch.float8_e4m3fn - assert w1.shape[0] == w2.shape[0], "Expert number mismatch" - assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch" - assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch" - assert a1_scale is None or a1_scale.dim( - ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[ - 0] == hidden_states.shape[0], "Input scale shape mismatch" - assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 - - num_tokens, _ = hidden_states.shape - E, N, _ = w1.shape - K = w2.shape[1] - if global_num_experts == -1: - global_num_experts = E - - # We execute the fused_moe kernel in chunks to circumvent this issue: - # https://github.com/vllm-project/vllm/issues/5938 - CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE - - assert _valid_deep_gemm(hidden_states, w1, w2, expert_map) - - if inplace: - out_hidden_states = hidden_states - else: - out_hidden_states = torch.empty_like(hidden_states) - - block_m = dg.get_m_alignment_for_contiguous_layout() - block_shape = [block_m, block_m] - - assert w1_scale is not None - assert w2_scale is not None - - # We attempt to transpose and align offline in Fp8MoEMethod, in which - # case these calls will be nops. Otherwise, they'll be performed every - # time the layer is executed. 
- w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() - w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() - - M_sum = topk_ids.numel() + global_num_experts * (block_m - 1) - M_sum = round_up(M_sum, block_m) - - num_chunks = (num_tokens // CHUNK_SIZE) + 1 - - # We can reuse the memory between cache1 and cache3 because by the time - # we need cache3, we're done with cache1 - workspace13 = torch.empty(M_sum * max(N, K), - device=hidden_states.device, - dtype=hidden_states.dtype) - - workspace1 = workspace13[:M_sum * N].view(M_sum, N) - workspace2 = torch.empty((M_sum, N // 2), - device=hidden_states.device, - dtype=hidden_states.dtype) - workspace3 = workspace13[:M_sum * K].view(M_sum, K) - - for chunk in range(num_chunks): - begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, - min((chunk + 1) * CHUNK_SIZE, - num_tokens)) - curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] - tokens_in_chunk, _ = curr_hidden_states.shape - - if tokens_in_chunk == 0: - break - - curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] - curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] - - a1q_scale: Optional[torch.Tensor] = None - - qcurr_hidden_states, a1q_scale = _fp8_quantize(curr_hidden_states, - a1_scale, block_shape) - - (qcurr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, - inv_perm) = _moe_permute(qcurr_hidden_states, a1q_scale, - curr_topk_ids, global_num_experts, - expert_map, block_m) - - # Adjust the intermediate cache size and config for the last chunk. - # Note that in most cases we only have one chunk so the cache size - # and config are already set correctly and do not need to be adjusted. 
- if tokens_in_chunk < CHUNK_SIZE and chunk > 0: - curr_M = sorted_token_ids.numel() - workspace1 = _resize_cache(workspace1, (curr_M, N)) - workspace2 = _resize_cache(workspace2, (curr_M, N // 2)) - workspace3 = _resize_cache(workspace3, (curr_M, K)) - - dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (qcurr_hidden_states, a1q_scale), (w1, w1_scale), workspace1, - expert_ids) - - if activation == "silu": - torch.ops._C.silu_and_mul(workspace2, workspace1.view(-1, N)) - elif activation == "gelu": - torch.ops._C.gelu_and_mul(workspace2, workspace1.view(-1, N)) - else: - raise ValueError(f"Unsupported FusedMoe activation: {activation}") - - a2q_scale: Optional[torch.Tensor] = None - - qworkspace2, a2q_scale = _fp8_quantize(workspace2, a2_scale, - block_shape) - - dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (qworkspace2, a2q_scale), (w2, w2_scale), workspace3, expert_ids) - - _moe_unpermute_and_reduce( - out_hidden_states[begin_chunk_idx:end_chunk_idx], - workspace3.view(*workspace3.shape), inv_perm, curr_topk_weights) - - return out_hidden_states + fn = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(quant_dtype=torch.float8_e4m3fn, + block_shape=deep_gemm_block_shape()), + DeepGemmExperts(), + ) + return fn( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace, + activation, + global_num_experts, + expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + apply_router_weight_on_input=apply_router_weight_on_input, + ) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py new file mode 100644 index 000000000000..c2db79365931 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -0,0 +1,755 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Fused batched MoE kernel.""" +from typing import Optional + +import torch +import triton +import triton.language as tl + +import 
vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.fused_moe import ( + get_config_dtype_str, try_get_optimal_moe_config) +from vllm.model_executor.layers.fused_moe.utils import _resize_cache + + +@triton.jit +def moe_mmk( + a_ptrs, + b_ptrs, + K, + expert_id, + a_scale_ptr, + b_scale_ptr, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_ak, + stride_bk, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Offsets and masks + offs_m, + offs_n, + mask_m, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, + # Meta-parameters + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + compute_type: tl.constexpr, + use_w8a8: tl.constexpr, + use_w8a16: tl.constexpr): + + offs_k = tl.arange(0, BLOCK_K) + + if use_w8a16: + b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_n[ + None, :] * stride_bsn + b_scale = tl.load(b_scale_ptrs) + + if use_w8a8: + # block-wise + if group_k > 0 and group_n > 0: + a_scale_ptrs = a_scale_ptr + offs_m * stride_asm + offs_bsn = offs_n // group_n + b_scale_ptrs = (b_scale_ptr + expert_id * stride_bse + + offs_bsn * stride_bsn) + # tensor-wise + else: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + expert_id) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. 
+ a = tl.load(a_ptrs, + mask=mask_m[:, None] & (offs_k[None, :] < K - k * BLOCK_K), + other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0) + # We accumulate along the K dimension. + if use_w8a16: + accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) + elif use_w8a8: + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_K + offs_ks = k_start // group_k + a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask, + mask=mask_m, + other=0.0) + b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + + accumulator += tl.dot(a, b) * a_scale[:, + None] * b_scale[None, :] + else: + if use_w8a8: + # acc used to enable fp8_fast_accum + accumulator = tl.dot(a, b, acc=accumulator) + else: + accumulator += tl.dot(a, b) + else: + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + + if use_w8a16: + accumulator = (accumulator * b_scale).to(compute_type) + elif use_w8a8: + if group_k > 0 and group_n > 0: + accumulator = accumulator.to(compute_type) + else: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) + else: + accumulator = accumulator.to(compute_type) + + return accumulator + + +@triton.jit +def expert_triton_kernel( + a_ptr, #[max_tokens, K] + b_ptr, #[K, N] + c_ptr, #[max_tokens, N] + expert_id, + compute_type: tl.constexpr, + # Dimensions + M, + N, + K, + # Quantization data + a_scale_ptr, + b_scale_ptr, + b_zp_ptr, + # strides + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Blockwise quantization data + group_n, + group_k, + # Quantization schemes + use_fp8_w8a8: tl.constexpr, + use_int8_w8a16: tl.constexpr, + # Kernel config + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr): + + offs_m = tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) % N + offs_k = tl.arange(0, BLOCK_K) + mask_m = offs_m < M + + 
a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn + + accumulator = moe_mmk( + a_ptrs, + b_ptrs, + K, + expert_id, + a_scale_ptr, + b_scale_ptr, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_ak, + stride_bk, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Offsets and masks + offs_m, + offs_n, + mask_m, + # Block size for block-wise quantization + group_n, + group_k, + # Meta-parameters + BLOCK_M, + BLOCK_N, + BLOCK_K, + compute_type, + use_fp8_w8a8, + use_int8_w8a16) + + # store in C + offs_cn = tl.arange(0, BLOCK_N) + c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_cn[None, :] * stride_cn + c_mask = mask_m[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +@triton.jit +def batched_triton_kernel( + a_ptr, # [E, max_num_tokens, K] + b_ptr, # [E, K, N] + c_ptr, # [E, max_num_tokens, N] + expert_num_tokens, # [E] + compute_type: tl.constexpr, + # Dimensions + max_num_tokens, + K, + N, + # Quantization data + a_scale_ptr, + b_scale_ptr, + b_zp_ptr, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). 
+ stride_ae, + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_ce, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Blockwise quantization data + group_n: tl.constexpr, + group_k: tl.constexpr, + # Quantization schemes + use_fp8_w8a8: tl.constexpr, + use_int8_w8a16: tl.constexpr, + # Kernel config + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr): + expert_id = tl.program_id(axis=0) + e_num_tokens = tl.load(expert_num_tokens + expert_id) + if e_num_tokens == 0: + # Early exit + return + + pid_mn = tl.program_id(axis=1) + #num_pid_m = tl.cdiv(max_num_tokens, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + pid_m = pid_mn // num_pid_n + pid_n = pid_mn % num_pid_n + + cta_m_start = pid_m * BLOCK_M + cta_n_start = pid_n * BLOCK_N + if cta_m_start >= e_num_tokens: + # Early exit + return + + cta_m_size = min(BLOCK_M, e_num_tokens - cta_m_start) + cta_n_size = min(BLOCK_N, N - cta_n_start) + + a_ptr = a_ptr + expert_id * stride_ae + cta_m_start * stride_am + b_ptr = b_ptr + expert_id * stride_be + cta_n_start * stride_bn + c_ptr = (c_ptr + expert_id * stride_ce + cta_m_start * stride_cm + + cta_n_start * stride_cn) + + expert_triton_kernel( + a_ptr, + b_ptr, + c_ptr, + expert_id, + compute_type, + cta_m_size, # M + cta_n_size, # N + K, # K + a_scale_ptr, + b_scale_ptr, + b_zp_ptr, + # Strides + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Blockwise quantization data + group_n, + group_k, + # Quantization schemes + use_fp8_w8a8, + use_int8_w8a16, + # Kernel config + BLOCK_M, + BLOCK_N, + BLOCK_K) + + +def invoke_moe_batched_triton_kernel( + A: torch.Tensor, # [E, max_tokens, K] + B: torch.Tensor, # [E, K, N] + C: torch.Tensor, # [E, max_tokens, N] + expert_num_tokens: torch.Tensor, # [E] + compute_type: tl.dtype, + # Quantization data + A_scale: torch.Tensor, + B_scale: 
torch.Tensor, + B_zp: torch.Tensor, + # Quantization schemes + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + config: dict[str, int], + block_shape: Optional[list[int]] = None): + + assert not use_int4_w4a16 + max_num_tokens = A.size(1) + K = A.size(2) + N = C.size(2) + + BLOCK_M = config['BLOCK_SIZE_M'] + BLOCK_N = config['BLOCK_SIZE_N'] + BLOCK_K = config['BLOCK_SIZE_K'] + assert (torch.compiler.is_compiling() + or torch.cuda.is_current_stream_capturing() + or max_num_tokens % BLOCK_M == 0) + + grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) * + triton.cdiv(B.size(1), BLOCK_N)) + + batched_triton_kernel[grid]( + A, + B, + C, + expert_num_tokens, + compute_type, + # Dimensions + max_num_tokens, + K, + N, + # Quantization data + A_scale, + B_scale, + B_zp, + # Strides + A.stride(0), + A.stride(1), + A.stride(2), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(0), + C.stride(1), + C.stride(2), + A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, + # Blockwise quantization data + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], + # Quantization schemes + use_fp8_w8a8, + use_int8_w8a16, + # Kernel config + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K) + + +class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): + """ + A reference prepare/finalize class that reorganizes the tokens into + expert batched format, i.e. E x max_num_tokens x K. This is the format + that the PPLX dispatch/combine kernels use. 
+ """ + + def __init__(self, max_num_tokens: Optional[int], world_size: int, + dp_size: int, rank: int): + super().__init__() + self.world_size = world_size + self.dp_size = dp_size + self.rank = rank + self.max_num_tokens = max_num_tokens + + def prepare( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + assert a1.dim() == 2 + assert topk_ids.dim() == 2 + assert topk_ids.size(0) == a1.size(0) + + if apply_router_weight_on_input: + topk = topk_ids.size(1) + # TODO: this only works for topK=1, will need to update for topK>1 + assert topk == 1, \ + "apply_router_weight_on_input is only implemented for topk=1" + a1.mul_(topk_weights.to(a1.dtype)) + + num_tokens, hidden_dim = a1.size() + topk = topk_ids.size(1) + + if self.max_num_tokens is None: + tokens_per_expert = torch.bincount(topk_ids.view(-1), + minlength=num_experts) + self.max_num_tokens = int(tokens_per_expert.max().item()) + else: + tokens_per_expert = torch.zeros(num_experts, + dtype=torch.int, + device=a1.device) + + assert num_experts % self.world_size == 0 + + num_local_experts = num_experts // self.world_size + + b_a1 = torch.zeros( + (num_local_experts, self.max_num_tokens, hidden_dim), + dtype=a1.dtype, + device=a1.device) + + first_expert = num_local_experts * self.rank + last_expert = first_expert + num_local_experts + + for expert_id in range(first_expert, last_expert): + topks = torch.any(topk_ids == expert_id, dim=1).flatten() + rows = torch.count_nonzero(topks.flatten()) + b_a1[expert_id - + first_expert, :rows, :] = a1[:topks.numel()][topks] + tokens_per_expert[expert_id - first_expert] = rows + + return b_a1, a1_scale, tokens_per_expert + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: 
torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> None: + num_tokens = topk_ids.size(0) + num_local_experts = fused_expert_output.size(0) + K = fused_expert_output.size(-1) + assert output.size(0) == num_tokens and output.size(1) == K + + output.fill_(0) + + first_expert = num_local_experts * self.rank + last_expert = first_expert + num_local_experts + + for expert_id in range(first_expert, last_expert): + matching_tokens = topk_ids == expert_id + topks = torch.any(matching_tokens, dim=1).flatten() + rows = torch.count_nonzero(topks) + rhs = fused_expert_output[expert_id - first_expert, :rows, :] + if not apply_router_weight_on_input: + rhs.mul_(topk_weights[matching_tokens].view(rhs.size(0), 1)) + output[topks] = output[topks] + rhs + + +class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): + """ + A reference MoE expert class that operates on expert batched format, + i.e. E x max_num_tokens x K. This is the format that the pplx + dispatch/combine kernels use. 
+ """ + + def __init__( + self, + world_size: int, + dp_size: int, + max_num_tokens: Optional[int] = None, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + block_shape: Optional[list[int]] = None, + block_m: Optional[int] = None, + ): + super().__init__() + assert block_shape is None + assert block_m is None + assert not use_fp8_w8a8, "NYI" + assert not use_int8_w8a8, "NYI" + assert not use_int8_w8a16, "NYI" + assert not use_int4_w4a16, "NYI" + self.max_num_tokens = max_num_tokens + self.world_size = world_size + self.dp_size = dp_size + + def workspace_shapes( + self, + a: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> tuple[int, int, torch.dtype]: + assert a.dim() == 2 + num_dp = self.world_size // self.dp_size + max_num_tokens = a.size( + 0) if self.max_num_tokens is None else self.max_num_tokens + #print(f"WORKSPACE {max_num_tokens} {num_dp}") + workspace13 = num_experts * max_num_tokens * num_dp * K + workspace2 = max_num_tokens * num_dp * N + return (workspace13, workspace2, a.dtype) + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ) -> torch.Tensor: + assert hidden_states.dim() == 3 + assert expert_num_tokens is not None + hidden_dim = hidden_states.size(-1) + + if self.max_num_tokens is None: + max_num_tokens = hidden_states.size(1) + else: + max_num_tokens = self.max_num_tokens + + num_dp = self.world_size // self.dp_size + num_experts = global_num_experts + out = 
_resize_cache(workspace13, + (num_experts, max_num_tokens * num_dp, hidden_dim)) + num_local_experts = w1.size(0) + assert num_local_experts == w1.size(0), ( + f"{num_local_experts} == {w1.size(0)}") + + N = w1.size(1) // 2 + + # Not cudagraph friendly + assert (torch.compiler.is_compiling() + or torch.cuda.is_current_stream_capturing() + or torch.all(expert_num_tokens <= max_num_tokens * num_dp)), ( + f"{expert_num_tokens} <= {max_num_tokens * num_dp}") + + for expert in range(num_local_experts): + # Indexing expert_num_tokens doesn't work w/cudagraphs or inductor + if (torch.compiler.is_compiling() + or torch.cuda.is_current_stream_capturing()): + num = max_num_tokens * num_dp + else: + num = int(expert_num_tokens[expert].item()) + tmp = _resize_cache(workspace2, (num, N)) + input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1) + self.activation(activation, tmp, input) + out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) + + return out + + +class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): + """ + A Triton based MoE expert class that operates on expert batched format, + i.e. E x max_num_tokens x K. This is the format that the pplx + dispatch/combine kernels use. 
+ """ + + def __init__( + self, + max_num_tokens: Optional[int] = None, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + block_shape: Optional[list[int]] = None, + world_size: int = 1, + dp_size: int = 1, + ): + super().__init__() + self.use_fp8_w8a8 = use_fp8_w8a8 + self.use_int8_w8a8 = use_int8_w8a8 + self.use_int4_w4a16 = use_int4_w4a16 + self.use_int8_w8a16 = use_int8_w8a16 + self.block_shape = block_shape + self.max_num_tokens = max_num_tokens + assert not use_int8_w8a8, "NYI" + assert not use_int4_w4a16, "NYI" + self.world_size = world_size + self.dp_size = dp_size + + def workspace_shapes( + self, + a: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> tuple[int, int, torch.dtype]: + assert a.dim() == 2 + num_dp = self.world_size // self.dp_size + max_num_tokens = a.size( + 0) if self.max_num_tokens is None else self.max_num_tokens + workspace13 = num_experts * max_num_tokens * num_dp * max(K, N) + workspace2 = num_experts * max_num_tokens * num_dp * (N // 2) + return (workspace13, workspace2, a.dtype) + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ) -> torch.Tensor: + # Check constraints. 
+ if self.use_int4_w4a16: + assert hidden_states.size(-1) // 2 == w1.size(2), ( + "Hidden size mismatch") + else: + assert hidden_states.size(-1) == w1.size(2), ( + f"Hidden size mismatch {hidden_states.size(-1)} " + f"!= {w1.size(2)}") + + assert hidden_states.is_contiguous( + ), "Hidden_states must be contiguous" + assert w1.stride(-1) == 1, "Stride of last dimension must be 1" + assert w2.stride(-1) == 1, "Stride of last dimension must be 1" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn + ] + + # TODO: num_tokens -> max_num_tokens? + E, num_tokens, N, K, top_k_num = mk._moe_problem_size( + hidden_states, w1, w2, topk_ids) + + assert w1.size(0) == E + assert w2.size(0) == E + + config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + dtype=hidden_states.dtype) + + config = try_get_optimal_moe_config( + w1.size(), + w2.size(), + top_k_num, + config_dtype, + num_tokens, + block_shape=self.block_shape, + ) + + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + elif hidden_states.dtype == torch.float8_e4m3fn: + compute_type = tl.bfloat16 + else: + raise ValueError( + f"Unsupported compute_type: {hidden_states.dtype}") + + #print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}") + # We can reuse the memory between these because by the time we need + # cache3, we're done with cache1 + intermediate_cache1 = _resize_cache(workspace13, (E, num_tokens, N)) + intermediate_cache2 = _resize_cache(workspace2, + (E, num_tokens, N // 2)) + intermediate_cache3 = _resize_cache(workspace13, (E, num_tokens, K)) + + # MM1 + invoke_moe_batched_triton_kernel(A=hidden_states, + B=w1, + C=intermediate_cache1, + expert_num_tokens=expert_num_tokens, + 
compute_type=compute_type, + A_scale=a1q_scale, + B_scale=w1_scale, + B_zp=w1_zp, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + config=config, + block_shape=self.block_shape) + + # TODO: would be nice to use expert_num_tokens here to reduce + # garbage compute + self.activation(activation, intermediate_cache2.view(-1, N // 2), + intermediate_cache1.view(-1, N)) + + #qintermediate_cache2 = intermediate_cache2 + a2q_scale = a2_scale + # TODO (varun) : support w8a8 + assert not self.use_fp8_w8a8 + #if self.use_fp8_w8a8: + # qintermediate_cache2, a2q_scale = _fp8_quantize( + # intermediate_cache2, a2_scale, self.block_shape) + + invoke_moe_batched_triton_kernel(A=intermediate_cache2, + B=w2, + C=intermediate_cache3, + expert_num_tokens=expert_num_tokens, + compute_type=compute_type, + A_scale=a2q_scale, + B_scale=w2_scale, + B_zp=w2_zp, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + config=config, + block_shape=self.block_shape) + + return intermediate_cache3 diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index b96d34ec2db3..4c84dd538332 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -25,6 +25,8 @@ def fused_marlin_moe(hidden_states: torch.Tensor, quant_type_id: int, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, + global_scale1: Optional[torch.Tensor] = None, + global_scale2: Optional[torch.Tensor] = None, g_idx1: Optional[torch.Tensor] = None, g_idx2: Optional[torch.Tensor] = None, sort_indices1: Optional[torch.Tensor] = None, @@ -64,11 +66,13 @@ def fused_marlin_moe(hidden_states: torch.Tensor, quant_type = ScalarType.from_id(quant_type_id) assert quant_type in [ scalar_types.uint4, scalar_types.uint8b128, scalar_types.uint4b8, - 
scalar_types.float8_e4m3fn + scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f ] - int4_scalar_types = [scalar_types.uint4, scalar_types.uint4b8] - num_bits = 4 if quant_type in int4_scalar_types else 8 + bit4_scalar_types = [ + scalar_types.uint4, scalar_types.uint4b8, scalar_types.float4_e2m1f + ] + num_bits = 4 if quant_type in bit4_scalar_types else 8 # Check constraints. assert hidden_states.shape[0] == gating_output.shape[ @@ -133,6 +137,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, intermediate_cache1, w1, w1_scale, + global_scale1, w1_zeros, g_idx1, sort_indices1, @@ -165,6 +170,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, intermediate_cache3, w2, w2_scale, + global_scale2, w2_zeros, g_idx2, sort_indices2, @@ -202,6 +208,8 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor, topk_ids: torch.Tensor, quant_type_id: int, global_num_experts: int = -1, + global_scale1: Optional[torch.Tensor] = None, + global_scale2: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None, g_idx1: Optional[torch.Tensor] = None, g_idx2: Optional[torch.Tensor] = None, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f6305822c2dc..78f8eb926dc8 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -3,21 +3,22 @@ import functools import json import os -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Optional import torch import vllm.envs as envs +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( _valid_deep_gemm, deep_gemm_moe_fp8) from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - 
per_token_group_quant_fp8) -from vllm.model_executor.layers.quantization.utils.int8_utils import ( - per_token_group_quant_int8, per_token_quant_int8) +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP) +from vllm.model_executor.layers.fused_moe.utils import ( + _resize_cache, moe_kernel_quantize_input) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op @@ -472,18 +473,32 @@ def invoke_fused_moe_kernel(A: torch.Tensor, num_tokens_post_padded: torch.Tensor, mul_routed_weight: bool, top_k: int, - config: Dict[str, Any], + config: dict[str, Any], compute_type: tl.dtype, use_fp8_w8a8: bool, use_int8_w8a8: bool, use_int8_w8a16: bool, use_int4_w4a16: bool, per_channel_quant: bool, - block_shape: Optional[List[int]] = None) -> None: + block_shape: Optional[list[int]] = None) -> None: assert topk_weights is not None or not mul_routed_weight assert topk_weights is None or topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 + if use_fp8_w8a8 or use_int8_w8a8: + assert B_scale is not None + assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0]) + == B_scale.shape[-2]) + assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1]) + == B_scale.shape[-1]) + + elif use_int8_w8a16 or use_int4_w4a16: + assert B_scale is not None + assert block_shape is None or block_shape[0] == 0 + else: + assert A_scale is None + assert B_scale is None + M = A.shape[0] num_tokens = M * top_k @@ -622,7 +637,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, def get_config_file_name(E: int, N: int, dtype: Optional[str], - block_shape: Optional[List[int]] = None) -> str: + block_shape: Optional[list[int]] = None) -> str: device_name = current_platform.get_device_name().replace(" ", "_") dtype_selector = "" if not dtype else f",dtype={dtype}" block_shape_selector = ("" if not block_shape or not all(block_shape) else @@ -638,7 +653,7 
@@ def get_moe_configs( dtype: Optional[str], block_n: Optional[int] = None, block_k: Optional[int] = None, -) -> Optional[Dict[int, Any]]: +) -> Optional[dict[int, Any]]: """ Return optimized configurations for the fused MoE kernel. @@ -670,7 +685,7 @@ def get_moe_configs( return None -def get_moe_wna16_block_config(config: Dict[str, +def get_moe_wna16_block_config(config: dict[str, int], use_moe_wna16_cuda: bool, num_valid_tokens: int, size_k: int, size_n: int, num_experts: int, group_size: int, @@ -742,8 +757,8 @@ def get_default_config( topk: int, dtype: Optional[str], is_marlin: bool, - block_shape: Optional[List[int]] = None, -) -> Dict[str, int]: + block_shape: Optional[list[int]] = None, +) -> dict[str, int]: if dtype == "fp8_w8a8" and block_shape is not None: # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] # BLOCK_SIZE_K must be divisible by block_shape[1] @@ -795,13 +810,13 @@ def get_default_config( def try_get_optimal_moe_config( - w1_shape: Tuple[int, ...], - w2_shape: Tuple[int, ...], + w1_shape: tuple[int, ...], + w2_shape: tuple[int, ...], top_k: int, dtype: Optional[str], M: int, is_marlin: bool = False, - block_shape: Optional[List[int]] = None, + block_shape: Optional[list[int]] = None, ): from vllm.model_executor.layers.fused_moe import get_config override_config = get_config() @@ -855,7 +870,8 @@ def fused_topk( gating_output: torch.Tensor, topk: int, renormalize: bool, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + indices_type: Optional[torch.dtype] = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: assert hidden_states.shape[0] == gating_output.shape[0], ( "Number of tokens mismatch") @@ -865,10 +881,11 @@ def fused_topk( topk, dtype=torch.float32, device=hidden_states.device) - topk_ids = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) + topk_ids = torch.empty( + M, + topk, + dtype=torch.int32 if indices_type is None else indices_type, + device=hidden_states.device) 
token_expert_indices = torch.empty(M, topk, dtype=torch.int32, @@ -895,7 +912,7 @@ def grouped_topk( topk_group: int = 0, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: assert hidden_states.shape[0] == gating_output.shape[0], ( "Number of tokens mismatch") @@ -962,6 +979,20 @@ def get_config_dtype_str( return None +# TODO (bnell): use scalar_type instead of bools? +def get_config_qtype( + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, +) -> Optional[torch.dtype]: + if use_fp8_w8a8: + return torch.float8_e4m3fn + elif use_int8_w8a8: + return torch.int8 + return None + + def inplace_fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -982,7 +1013,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor, w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None) -> None: + block_shape: Optional[list[int]] = None) -> None: fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, activation, apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, @@ -1012,7 +1043,7 @@ def inplace_fused_experts_fake( w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None) -> None: + block_shape: Optional[list[int]] = None) -> None: pass @@ -1046,7 +1077,7 @@ def outplace_fused_experts( w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None) -> torch.Tensor: + block_shape: Optional[list[int]] = None) -> torch.Tensor: return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, False, activation, 
apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, @@ -1076,7 +1107,7 @@ def outplace_fused_experts_fake( w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None) -> torch.Tensor: + block_shape: Optional[list[int]] = None) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -1100,9 +1131,6 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor: def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]: - if is_rocm_aiter_moe_enabled(): - from .rocm_aiter_fused_moe import rocm_aiter_fused_experts - return rocm_aiter_fused_experts if inplace: return torch_vllm_inplace_fused_experts return torch_vllm_outplace_fused_experts @@ -1129,9 +1157,12 @@ def fused_experts(hidden_states: torch.Tensor, w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None, + block_shape: Optional[list[int]] = None, allow_deep_gemm: bool = False) -> torch.Tensor: - if (allow_deep_gemm and use_fp8_w8a8 + # For now, disable DeepGemm for small N (<= 512) until better + # permute/unpermute ops are available. 
+ N = w1.shape[1] + if (allow_deep_gemm and use_fp8_w8a8 and N > 512 and _valid_deep_gemm(hidden_states, w1, w2, expert_map)): assert apply_router_weight_on_input is False return deep_gemm_moe_fp8( @@ -1148,6 +1179,7 @@ def fused_experts(hidden_states: torch.Tensor, w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale, + apply_router_weight_on_input=apply_router_weight_on_input, ) else: return dispatch_fused_experts_func(inplace)( @@ -1174,87 +1206,37 @@ def fused_experts(hidden_states: torch.Tensor, block_shape=block_shape) -def moe_kernel_prepare_input( - A: torch.Tensor, - B: torch.Tensor, - A_scale: Optional[torch.Tensor], - B_scale: Optional[torch.Tensor], - use_fp8_w8a8: bool, - use_int8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - per_channel_quant: bool, - block_shape: Optional[List[int]] = None, -) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - if use_fp8_w8a8: - assert B_scale is not None - if block_shape is None: - # If weights are per-channel (per_channel_quant=True), then - # activations apply per-token quantization. 
Otherwise, assume - # activation tensor-wise fp8 quantization, dynamic or static - A, A_scale = ops.scaled_fp8_quant( - A, A_scale, use_per_token_if_dynamic=per_channel_quant) - else: - # activation block-wise fp8 quantization - assert len(block_shape) == 2 - _, block_k = block_shape[0], block_shape[1] - A, A_scale = per_token_group_quant_fp8(A, block_k) - assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] - # assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] - # assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] - elif use_int8_w8a8: - assert B_scale is not None - if block_shape is None: - # activation channel-wise int8 quantization - assert (per_channel_quant - ), "int8 quantization only supports block or channel-wise" - A, A_scale = per_token_quant_int8(A) - else: - # activation block-wise int8 quantization - assert len(block_shape) == 2 - _, block_k = block_shape[0], block_shape[1] - A, A_scale = per_token_group_quant_int8(A, block_k) - assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] - # assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] - # assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] - elif use_int8_w8a16 or use_int4_w4a16: - assert B_scale is not None - assert block_shape is None or block_shape[0] == 0 - else: - assert A_scale is None - assert B_scale is None - - return A, A_scale - - -def fused_experts_impl(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool = False, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, 
- w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None): +def fused_experts_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, +) -> torch.Tensor: # Check constraints. if use_int4_w4a16: assert hidden_states.shape[1] // 2 == w1.shape[ 2], "Hidden size mismatch" else: - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert hidden_states.shape[1] == w1.shape[2], ( + f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}") assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" @@ -1264,7 +1246,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, torch.float32, torch.float16, torch.bfloat16 ] - num_tokens, _ = hidden_states.shape + num_tokens = hidden_states.shape[0] E, N, _ = w1.shape K = w2.shape[1] if global_num_experts == -1: @@ -1279,6 +1261,11 @@ def fused_experts_impl(hidden_states: torch.Tensor, use_int4_w4a16=use_int4_w4a16, dtype=hidden_states.dtype) + qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16) + get_config_func = 
functools.partial( try_get_optimal_moe_config, w1.shape, @@ -1341,15 +1328,10 @@ def fused_experts_impl(hidden_states: torch.Tensor, curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] - qcurr_hidden_states, qa1_scale = moe_kernel_prepare_input( + qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input( A=curr_hidden_states, - B=w1, A_scale=a1_scale, - B_scale=w1_scale, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, + qtype=qtype, per_channel_quant=per_channel_quant, block_shape=block_shape) @@ -1360,7 +1342,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, invoke_fused_moe_kernel(qcurr_hidden_states, w1, intermediate_cache1, - qa1_scale, + a1q_scale, w1_scale, w1_zp, curr_topk_weights, @@ -1387,22 +1369,17 @@ def fused_experts_impl(hidden_states: torch.Tensor, else: raise ValueError(f"Unsupported FusedMoe activation: {activation}") - qintermediate_cache2, qa2_scale = moe_kernel_prepare_input( + qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( A=intermediate_cache2, - B=w2, A_scale=a2_scale, - B_scale=w2_scale, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, + qtype=qtype, per_channel_quant=per_channel_quant, block_shape=block_shape) invoke_fused_moe_kernel(qintermediate_cache2, w2, intermediate_cache3, - qa2_scale, + a2q_scale, w2_scale, w2_zp, curr_topk_weights, @@ -1452,7 +1429,7 @@ def fused_moe( w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None, + block_shape: Optional[list[int]] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -1497,7 +1474,7 @@ def fused_moe( a1. - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2. 
- - block_shape: (Optional[List[int]]): Optional block size for block-wise + - block_shape: (Optional[list[int]]): Optional block size for block-wise quantization. Returns: @@ -1537,3 +1514,209 @@ def fused_moe( a1_scale=a1_scale, a2_scale=a2_scale, block_shape=block_shape) + + +class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__( + self, + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + per_channel_quant: bool, + block_shape: Optional[list[int]] = None, + block_m: Optional[int] = None, + ): + super().__init__() + self.use_fp8_w8a8 = use_fp8_w8a8 + self.use_int4_w4a16 = use_int4_w4a16 + self.use_int8_w8a8 = use_int8_w8a8 + self.use_int8_w8a16 = use_int8_w8a16 + self.block_shape = block_shape + self.block_m = block_m + self.qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16) + self.per_channel_quant = per_channel_quant + + def workspace_shapes( + self, + a: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> tuple[int, int, torch.dtype]: + factor = num_experts if a.dim() == 3 else 1 + workspace1 = M * topk * max(N * 2, K) * factor + workspace2 = M * topk * N * factor + return (workspace1, workspace2, a.dtype) + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ) -> torch.Tensor: + # Check constraints. 
+ if self.use_int4_w4a16: + assert hidden_states.size(-1) // 2 == w1.size(2), ( + "Hidden size mismatch") + else: + assert hidden_states.size(-1) == w1.size(2), \ + (f"Hidden size mismatch {hidden_states.size(-1)} " + f"!= {w1.size(2)}") + + assert hidden_states.is_contiguous( + ), "Hidden_states must be contiguous" + assert hidden_states.dim() == 2 + assert w1.stride(-1) == 1, "Stride of last dimension must be 1" + assert w2.stride(-1) == 1, "Stride of last dimension must be 1" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn + ] + + E, num_tokens, N, K, top_k_num = mk._moe_problem_size( + hidden_states, w1, w2, topk_ids) + + if global_num_experts == -1: + global_num_experts = E + + config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + dtype=hidden_states.dtype) + + config = try_get_optimal_moe_config( + w1.shape, + w2.shape, + top_k_num, + config_dtype, + num_tokens, + block_shape=self.block_shape, + ) + + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + elif hidden_states.dtype == torch.float8_e4m3fn: + compute_type = tl.bfloat16 + else: + raise ValueError( + f"Unsupported compute_type: {hidden_states.dtype}") + + # We can reuse the memory between these because by the time we need + # cache3, we're done with cache1 + intermediate_cache1 = _resize_cache(workspace13, + (num_tokens, top_k_num, N)) + intermediate_cache2 = _resize_cache(workspace2, + (num_tokens * top_k_num, N // 2)) + intermediate_cache3 = _resize_cache(workspace13, + (num_tokens, top_k_num, K)) + + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'], + global_num_experts, expert_map)) + + invoke_fused_moe_kernel(hidden_states, + 
w1, + intermediate_cache1, + a1q_scale, + w1_scale, + w1_zp, + None, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + top_k_num, + config, + compute_type=compute_type, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a8=self.use_int8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + per_channel_quant=self.per_channel_quant, + block_shape=self.block_shape) + + self.activation(activation, intermediate_cache2, + intermediate_cache1.view(-1, N)) + + a2q_scale: Optional[torch.Tensor] = None + + qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( + intermediate_cache2, a2_scale, self.qtype, self.per_channel_quant, + self.block_shape) + + invoke_fused_moe_kernel(qintermediate_cache2, + w2, + intermediate_cache3, + a2q_scale, + w2_scale, + w2_zp, + None, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + 1, + config, + compute_type=compute_type, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a8=self.use_int8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + per_channel_quant=self.per_channel_quant, + block_shape=self.block_shape) + + return intermediate_cache3 + + +def modular_triton_fused_moe( + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + per_channel_quant: bool, + block_shape: Optional[list[int]] = None, +) -> mk.FusedMoEModularKernel: + qtype = get_config_qtype( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + ) + return mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP( + quant_dtype=qtype, + per_channel_quant=per_channel_quant, + block_shape=block_shape, + ), + TritonExperts( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, + block_shape=block_shape, + ), + ) diff --git 
a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 35994c8ac6af..29b41e720852 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1,21 +1,26 @@ # SPDX-License-Identifier: Apache-2.0 +import importlib from abc import abstractmethod +from dataclasses import dataclass from enum import Enum -from typing import Callable, List, Optional, Tuple +from typing import Callable, Optional import torch import torch.nn.functional as F from torch.nn.parameter import UninitializedParameter import vllm.envs as envs -from vllm.config import get_current_vllm_config -from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank, +from vllm.config import ParallelConfig, get_current_vllm_config +from vllm.distributed import (get_dp_group, get_ep_group, + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + is_rocm_aiter_moe_enabled) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.utils import set_weight_attrs @@ -23,17 +28,208 @@ from vllm.platforms.interface import CpuArchEnum from vllm.utils import direct_register_custom_op +has_pplx = importlib.util.find_spec("pplx_kernels") is not None + if current_platform.is_cuda_alike(): - from .fused_moe import fused_experts + from .fused_batched_moe import (BatchedPrepareAndFinalize, + BatchedTritonExperts) + from .fused_moe import TritonExperts, fused_experts + from .modular_kernel import (FusedMoEModularKernel, + FusedMoEPermuteExpertsUnpermute, + FusedMoEPrepareAndFinalize) + if has_pplx: + from .pplx_prepare_finalize import PplxPrepareAndFinalize else: 
fused_experts = None # type: ignore + FusedMoEPermuteExpertsUnpermute = None # type: ignore + FusedMoEPrepareAndFinalize = None # type: ignore +if is_rocm_aiter_moe_enabled(): + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 + rocm_aiter_biased_group_topk as grouped_topk) +else: + from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk if current_platform.is_tpu(): - # the iterative moe implementation is used until the moe_pallas is fixed - from .moe_torch_iterative import fused_moe as fused_moe_pallas + from .moe_pallas import fused_moe as fused_moe_pallas else: fused_moe_pallas = None # type: ignore logger = init_logger(__name__) +# Note: this limit is somewhat arbitrary and might be changed later. +# The size of the activations will be E x MOE_DP_CHUNK_SIZE x hidden_dim. +MOE_DP_CHUNK_SIZE = 256 + + +@dataclass +class FusedMoEParallelConfig: + tp_size: int + dp_size: int + ep_size: int + tp_rank: int + dp_rank: int + ep_rank: int + + use_ep: bool # whether to use EP or not + + @property + def use_pplx_kernels(self): + return self.dp_size > 1 and self.use_ep and \ + envs.VLLM_ALL2ALL_BACKEND == "pplx" + + @staticmethod + def make(tp_size_: int, dp_size_: int, + vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": + """ + Determine MoE parallel configuration. Based on the input tp_size_, + dp_size_, ep_size_ and vllm's parallel config, determine what + level's of parallelism to use in the fused moe layer. + + Args: + tp_size_ (int): tp_size passed into the FusedMoE constructor. + dp_size_ (int): dp_size passed into the FusedMoE constructor. + ep_size_ (int): ep_size passed into the FusedMoE constructor. + vllm_parallel_config (ParallelConfig): vllm's parallel config + object. + + Examples: + When there is no parallelism requested, i.e. tp_size_ = dp_size_ = 1, + we simply return the sizes unaltered and the ranks set to 0. 
+ + Expert Parallelism is considered only when either dp_size_ or tp_size_ + is non-trivial. + + When TP = 2, DP = 1 and EP = False, the configuration on different + devices, + - device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} // + legend : {size, rank} + - device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0} + - Comment : Tensors are sharded across 2 devices. + + When TP = 1, DP = 2 and EP = False, the configuration on different + devices, + - device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0} + - device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0} + - Comment: There are 2 engine instances and the tensors are sharded + across 2 devices. + + When TP = 2, DP = 2 and EP = False, the configuration on different + devices, + - device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0} + - device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0} + - device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0} + - device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0} + - Comment: There are 2 engine instances and the tensors are sharded + across 4 devices. + + When TP = 2, DP = 1 and EP = True, the configuration on different + devices, + - device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0} + - device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1} + - Comment: The experts are split between the 2 devices. + + When TP = 1, DP = 2 and EP = True, the configuration on different + devices, + - device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0} + - device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1} + - Comment: There are 2 engine instances and the experts are split + between the 2 devices. + + When TP = 2, DP = 2 and EP = True, the configuration on different + devices, + - device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0} + - device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1} + - device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2} + - device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3} + - Comment: There are 2 engine instances and the experts are split + between the 4 devices. 
+ """ + + def flatten_tp_across_dp(dp_rank: int): + tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank() + # There are actually dp_size_ * tp_size_ devices. Update tp_size + # and tp_rank so we shard across all devices. + tp_size = dp_size_ * tp_size_ + tp_rank = dp_rank * tp_size_ + tp_rank + return tp_size, tp_rank + + use_ep = (dp_size_ * tp_size_ > 1 + and vllm_parallel_config.enable_expert_parallel) + + dp_size = dp_size_ + dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0 + tp_size, tp_rank = flatten_tp_across_dp(dp_rank) + + if not use_ep: + return FusedMoEParallelConfig(tp_size=tp_size, + tp_rank=tp_rank, + dp_size=dp_size, + dp_rank=dp_rank, + ep_size=1, + ep_rank=0, + use_ep=False) + # DP + EP / TP + EP / DP + TP + EP + assert use_ep + # In EP, each device owns a set of experts fully. There is no tensor + # parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that. + ep_size = tp_size + ep_rank = tp_rank + return FusedMoEParallelConfig(tp_size=1, + tp_rank=0, + dp_size=dp_size, + dp_rank=dp_rank, + ep_size=ep_size, + ep_rank=ep_rank, + use_ep=True) + + +# Adapted from pplx-kernels tests/all_to_all_utils.py +@dataclass +class MoEConfig: + num_experts: int + experts_per_token: int + hidden_dim: int + + num_local_experts: int + moe_parallel_config: FusedMoEParallelConfig + + in_dtype: torch.dtype # The activation type. + + # TODO: add more quantization params, blocked, per-token, etc. 
+ block_size: int = 128 + + max_num_tokens: int = MOE_DP_CHUNK_SIZE + + @property + def tp_size(self): + return self.moe_parallel_config.tp_size + + @property + def dp_size(self): + return self.moe_parallel_config.dp_size + + @property + def ep_size(self): + return self.moe_parallel_config.ep_size + + @property + def tp_rank(self): + return self.moe_parallel_config.tp_rank + + @property + def dp_rank(self): + return self.moe_parallel_config.dp_rank + + @property + def ep_rank(self): + return self.moe_parallel_config.ep_rank + + @property + def use_ep(self): + return self.moe_parallel_config.use_ep + + @property + def use_pplx_kernels(self): + return self.moe_parallel_config.use_pplx_kernels + class FusedMoeWeightScaleSupported(Enum): TENSOR = "tensor" @@ -50,6 +246,60 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, params_dtype: torch.dtype, **extra_weight_attrs): raise NotImplementedError + def init_prepare_finalize(self, moe: MoEConfig, + quant_config: Optional[QuantizationConfig]): + all2all_manager = get_ep_group().device_communicator.all2all_manager + assert all2all_manager is not None + + prepare_finalize = None + if moe.use_pplx_kernels: + all_to_all_args = dict( + max_num_tokens=moe.max_num_tokens, + num_experts=moe.num_experts, + experts_per_token=moe.experts_per_token, # topk + rank=all2all_manager.rank, + world_size=all2all_manager.world_size, + # dp_size actually means tp_size, bug in pplx kernels + dp_size=all2all_manager.tp_group.world_size, + hidden_dim=moe.hidden_dim, + hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, + # For blocked per token: set to + # ceil_div(hidden_dim, block_size) * sizeof(float32) + # For per-token: set to sizeof(float32) + hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else ( + (moe.hidden_dim + moe.block_size - 1) // moe.block_size * + torch.float32.itemsize)), + group_name=all2all_manager.cpu_group.group_name, + ) + + handle = all2all_manager.get_handle(all_to_all_args) + + 
prepare_finalize = PplxPrepareAndFinalize( + handle, + max_num_tokens=moe.max_num_tokens, + world_size=all2all_manager.world_size, + rank=all2all_manager.rank, + # dp_size actually means tp_size, bug in pplx kernels + dp_size=all2all_manager.tp_group.world_size, + quant_dtype=moe.in_dtype, + ) + + if prepare_finalize is not None: + experts = self.select_gemm_impl(prepare_finalize) + self.fused_experts = FusedMoEModularKernel( + prepare_finalize, + experts, + ) + + def select_gemm_impl( + self, prepare_finalize: Optional[FusedMoEPrepareAndFinalize] + ) -> FusedMoEPermuteExpertsUnpermute: + # based on the all2all implementation, select the appropriate + # gemm implementation + raise NotImplementedError( + "Subclass must select appropriate gemm implementation" + " based on the prepare_finalize") + @abstractmethod def apply( self, @@ -76,6 +326,54 @@ def apply( class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" + def __init__(self, moe: MoEConfig): + super().__init__() + self.fused_experts = fused_experts # type: ignore + self.moe = moe + + self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() + if self.rocm_aiter_moe_enabled: + from .rocm_aiter_fused_moe import rocm_aiter_fused_experts + self.rocm_aiter_fused_experts = rocm_aiter_fused_experts + else: + self.rocm_aiter_fused_experts = None # type: ignore + + def select_gemm_impl( + self, prepare_finalize: Optional[FusedMoEPrepareAndFinalize]): + + assert self.fused_experts == fused_experts + + all2all_manager = get_ep_group().device_communicator.all2all_manager + assert all2all_manager is not None + + experts: Optional[FusedMoEPermuteExpertsUnpermute] = None + + if isinstance(prepare_finalize, + (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)): + logger.debug("BatchedTritonExperts %s", self.moe) + experts = BatchedTritonExperts( + max_num_tokens=MOE_DP_CHUNK_SIZE, + world_size=all2all_manager.world_size, + # dp_size actually means tp_size, bug in pplx kernels 
+ dp_size=all2all_manager.tp_group.world_size, + use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + block_shape=None, + ) + else: + logger.debug("TritonExperts %s", self.moe) + experts = TritonExperts( + use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + block_shape=None, + per_channel_quant=False, + ) + return experts + def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): @@ -118,9 +416,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data) # Lazy import to avoid importing triton. from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled, shuffle_weights) - if is_rocm_aiter_moe_enabled(): - # reshaping weights is required for aiter moe kernel. + shuffle_weights) + + if self.rocm_aiter_moe_enabled: shuffled_w13, shuffled_w2 = shuffle_weights( layer.w13_weight.data, layer.w2_weight.data) @@ -201,19 +499,32 @@ def forward_cuda( num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) - - return fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map) + e_score_correction_bias=e_score_correction_bias, + indices_type=torch.uint32 if self.moe.use_pplx_kernels else None) + + if self.rocm_aiter_moe_enabled: + assert expert_map is None + return self.rocm_aiter_fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + 
activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input) + else: + return self.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + ) def forward_cpu( self, @@ -326,7 +637,7 @@ def forward_tpu( def determine_expert_map( ep_size: int, ep_rank: int, - global_num_experts: int) -> Tuple[int, Optional[torch.Tensor]]: + global_num_experts: int) -> tuple[int, Optional[torch.Tensor]]: """ Calculates how many experts should be assigned to each rank for EP and creates a mapping from global to local expert index. Experts are @@ -338,7 +649,7 @@ def determine_expert_map( global_num_experts (int): The total number of experts in the model. Returns: - Tuple[int, Optional[torch.Tensor]]: A tuple containing: + tuple[int, Optional[torch.Tensor]]: A tuple containing: - local_num_experts (int): The number of experts assigned to the current rank. - expert_map (Optional[torch.Tensor]): A tensor of shape @@ -419,21 +730,16 @@ def __init__( params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype - # Note: here we guard against accessing the TP and DP groups when - # uninitialized (this happens when testing) - self.tp_size = (tp_size if tp_size is not None else - get_tensor_model_parallel_world_size()) - tp_rank = 0 if self.tp_size == 1 else get_tensor_model_parallel_rank() - self.dp_size = (dp_size - if dp_size is not None else get_dp_group().world_size) - self.dp_rank = (0 - if self.dp_size == 1 else get_dp_group().rank_in_group) - self.global_num_experts = num_experts - - # Use expert parallelism instead of tensor parallelism? 
vllm_config = get_current_vllm_config() - use_ep = (vllm_config.parallel_config.enable_expert_parallel - and self.tp_size * self.dp_size > 1) + self.moe_parallel_config: FusedMoEParallelConfig = ( + FusedMoEParallelConfig.make( + tp_size_=(tp_size if tp_size is not None else + get_tensor_model_parallel_world_size()), + dp_size_=(dp_size if dp_size is not None else + get_dp_group().world_size), + vllm_parallel_config=vllm_config.parallel_config)) + + self.global_num_experts = num_experts # For smuggling this layer into the fused moe custom op self.use_direct_call = self.dp_size == 1 @@ -444,28 +750,17 @@ def __init__( compilation_config.static_forward_context[prefix] = self self.layer_name = prefix - if use_ep: - # Set TP size to 1 to adjust for EP and adjust EP size and rank - # for DP attention. - self.ep_rank = tp_rank + self.tp_size * self.dp_rank - self.tp_rank = 0 - self.ep_size = self.tp_size * self.dp_size - self.tp_size = 1 - + # Determine expert maps + if self.use_ep: self.local_num_experts, self.expert_map = determine_expert_map( ep_size=self.ep_size, ep_rank=self.ep_rank, global_num_experts=self.global_num_experts) else: - # Adjust TP size for DP attention - self.tp_rank = tp_rank + self.tp_size * self.dp_rank - self.ep_rank = 0 - self.tp_size = self.tp_size * self.dp_size - self.ep_size = 1 - self.local_num_experts = self.global_num_experts - self.expert_map = None + self.local_num_experts, self.expert_map = (self.global_num_experts, + None) + self.top_k = top_k - self.global_num_experts = num_experts assert intermediate_size % self.tp_size == 0 self.hidden_size = hidden_size @@ -480,6 +775,7 @@ def __init__( self.custom_routing_function = custom_routing_function self.scoring_func = scoring_func self.e_score_correction_bias = e_score_correction_bias + self.apply_router_weight_on_input = apply_router_weight_on_input self.activation = activation if self.scoring_func != "softmax" and not self.use_grouped_topk: @@ -489,16 +785,32 @@ def __init__( from 
vllm_hpu_extension.ops import DynamicFusedMOE self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts) + moe = MoEConfig( + num_experts=self.global_num_experts, + experts_per_token=top_k, + hidden_dim=hidden_size, + num_local_experts=self.local_num_experts, + moe_parallel_config=self.moe_parallel_config, + # TODO (bnell): this needs to be fixed for quantized types. + in_dtype=params_dtype, + max_num_tokens=MOE_DP_CHUNK_SIZE, + ) + self.moe_config = moe + self.quant_config = quant_config + # Note: get_quant_method will look at the layer's local_num_experts # for heuristic purposes, so it must be initialized first. + quant_method: Optional[QuantizeMethodBase] = None + if quant_config is None: - self.quant_method: Optional[QuantizeMethodBase] = ( - UnquantizedFusedMoEMethod()) + quant_method = UnquantizedFusedMoEMethod(moe) else: - self.quant_method = quant_config.get_quant_method(self, prefix) - assert self.quant_method is not None + quant_method = quant_config.get_quant_method(self, prefix) + + assert quant_method is not None + assert isinstance(quant_method, FusedMoEMethodBase) + self.quant_method = quant_method - self.apply_router_weight_on_input = apply_router_weight_on_input moe_quant_params = { "num_experts": self.local_num_experts, "hidden_size": hidden_size, @@ -516,6 +828,38 @@ def __init__( self.quant_method.create_weights(layer=self, **moe_quant_params) + @property + def tp_size(self): + return self.moe_parallel_config.tp_size + + @property + def dp_size(self): + return self.moe_parallel_config.dp_size + + @property + def ep_size(self): + return self.moe_parallel_config.ep_size + + @property + def tp_rank(self): + return self.moe_parallel_config.tp_rank + + @property + def dp_rank(self): + return self.moe_parallel_config.dp_rank + + @property + def ep_rank(self): + return self.moe_parallel_config.ep_rank + + @property + def use_ep(self): + return self.moe_parallel_config.use_ep + + @property + def use_pplx_kernels(self): + return 
self.moe_parallel_config.use_pplx_kernels + def _load_per_tensor_weight_scale(self, shard_id: str, param: torch.nn.Parameter, loaded_weight: torch.Tensor, @@ -643,7 +987,7 @@ def weight_loader(self, param: torch.nn.Parameter, expert_id = self._map_global_expert_id_to_local_expert_id(expert_id) if expert_id == -1: return - + quant_method_name = self.quant_method.__class__.__name__ # compressed-tensors checkpoints with packed weights are stored flipped # TODO (mgoin): check self.quant_method.quant_config.quant_format # against known CompressionFormat enum values that have this quality @@ -697,8 +1041,9 @@ def weight_loader(self, param: torch.nn.Parameter, # this is needed for compressed-tensors only loaded_weight = loaded_weight.to(param.data.device) - if param.data[expert_id] != 1 and (param.data[expert_id] - - loaded_weight).abs() > 1e-5: + if ("compressed" in quant_method_name.lower() + and param.data[expert_id] != 1 + and (param.data[expert_id] - loaded_weight).abs() > 1e-5): raise ValueError( "input_scales of w1 and w3 of a layer " f"must be equal. 
But got {param.data[expert_id]} " @@ -718,6 +1063,22 @@ def weight_loader(self, param: torch.nn.Parameter, tp_rank=self.tp_rank) return + if "ModelOpt" in quant_method_name: + if ('weight_scale_2' in weight_name + or 'input_scale' in weight_name): + self._load_per_tensor_weight_scale(shard_id=shard_id, + param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + elif "weight" in weight_name: + self._load_model_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=self.tp_rank) + return + # Case weight scales, zero_points and offset if ("scale" in weight_name or "zero" in weight_name or "offset" in weight_name): @@ -783,9 +1144,9 @@ def select_experts(hidden_states: torch.Tensor, num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None): - from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, grouped_topk) + e_score_correction_bias: Optional[torch.Tensor] = None, + indices_type: Optional[torch.dtype] = None): + from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk # DeekSeekv2 uses grouped_top_k if use_grouped_topk: @@ -800,38 +1161,51 @@ def select_experts(hidden_states: torch.Tensor, topk_group=topk_group, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) + if indices_type is not None: + topk_ids = topk_ids.to(dtype=indices_type) elif custom_routing_function is None: topk_weights, topk_ids, token_expert_indices = fused_topk( hidden_states=hidden_states, gating_output=router_logits, topk=top_k, - renormalize=renormalize) + renormalize=renormalize, + indices_type=indices_type, + ) else: topk_weights, topk_ids = custom_routing_function( hidden_states=hidden_states, gating_output=router_logits, topk=top_k, renormalize=renormalize) + if indices_type is not None: + topk_ids = 
topk_ids.to(dtype=indices_type) return topk_weights, topk_ids - def naive_multicast(self, x: torch.Tensor, - cu_tokens_across_dp_cpu: torch.Tensor): - assert (len(x.shape) == 2) - buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)), - device=x.device, - dtype=x.dtype) - - start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ - self.dp_rank - 1] - end = cu_tokens_across_dp_cpu[self.dp_rank] - buffer[start:end, :].copy_(x) - for idx in range(get_dp_group().world_size): - start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1] - end = cu_tokens_across_dp_cpu[idx] - get_dp_group().broadcast(buffer[start:end, :], idx) + def must_reduce_shared_expert_outputs(self) -> bool: + """ + The shared_experts are typically computed using the RowParallelLinear + layer. The result of this function is typically used as + the reduce_results argument to the module. + When just tensor-parallel is used, it is not required to reduce + the shared_experts results immediately. Instead we reduce at the + once at the end of the MoE op. (Refer to DeepSeekV2MoE module) + With EP and the pplx kernels - this is no longer viable as all + GPU ranks in DP, produce the complete set of hidden_states. + Therefore it is required that we reduce the shared_experts output + early. + """ + return self.use_pplx_kernels - return buffer + def maybe_all_reduce_tensor_model_parallel( + self, final_hidden_states: torch.Tensor): + """ + The pplx combine kernel reduces across GPU ranks by default. 
+ """ + if self.use_pplx_kernels: + return final_hidden_states + else: + return tensor_model_parallel_all_reduce(final_hidden_states) def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): @@ -841,19 +1215,66 @@ def forward(self, hidden_states: torch.Tensor, return torch.ops.vllm.moe_forward(hidden_states, router_logits, self.layer_name) + def forward_impl_chunked(self, full_hidden_states: torch.Tensor, + full_router_logits: torch.Tensor): + + full_final_hidden_states = torch.empty_like(full_hidden_states) + + def process_chunk(chunk_start, chunk_end, skip_result_store=False): + hidden_states = full_hidden_states[chunk_start:chunk_end, :] + router_logits = full_router_logits[chunk_start:chunk_end, :] + + # Matrix multiply. + final_hidden_states = self.quant_method.apply( + layer=self, + x=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + global_num_experts=self.global_num_experts, + expert_map=self.expert_map, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + custom_routing_function=self.custom_routing_function, + scoring_func=self.scoring_func, + e_score_correction_bias=self.e_score_correction_bias, + activation=self.activation, + ) + + if not skip_result_store: + full_final_hidden_states[chunk_start:chunk_end, :].copy_( + final_hidden_states) + + ctx = get_forward_context() + max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu + moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE + + num_tokens = full_hidden_states.size(0) + for chunk_start_ in range(0, max_tokens_across_dp, + moe_dp_chunk_size_per_rank): + chunk_start = chunk_start_ + chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank, + max_tokens_across_dp) + # clamp start and end + chunk_start = min(chunk_start, num_tokens - 1) + chunk_end = min(chunk_end, num_tokens) + + process_chunk(chunk_start, + chunk_end, + skip_result_store=chunk_start_ >= num_tokens) + 
+ return full_final_hidden_states + def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): assert self.quant_method is not None + if self.moe_parallel_config.use_pplx_kernels: + return self.forward_impl_chunked(hidden_states, router_logits) if self.dp_size > 1: - cu_tokens_across_dp_cpu = get_forward_context( - ).dp_metadata.cu_tokens_across_dp_cpu - - hidden_states = self.naive_multicast(hidden_states, - cu_tokens_across_dp_cpu) - router_logits = self.naive_multicast(router_logits, - cu_tokens_across_dp_cpu) - + hidden_states, router_logits = get_ep_group().dispatch( + hidden_states, router_logits) # Matrix multiply. final_hidden_states = self.quant_method.apply( layer=self, @@ -874,12 +1295,7 @@ def forward_impl(self, hidden_states: torch.Tensor, ) if self.dp_size > 1: - start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ - self.dp_rank - 1] - end = cu_tokens_across_dp_cpu[self.dp_rank] - - all_hidden_states = get_dp_group().all_reduce(final_hidden_states) - final_hidden_states = all_hidden_states[start:end, :] + final_hidden_states = get_ep_group().combine(final_hidden_states) if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) 
@@ -892,7 +1308,7 @@ def forward_impl(self, hidden_states: torch.Tensor, def make_expert_params_mapping( cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, ckpt_up_proj_name: str, - num_experts: int) -> List[Tuple[str, str, int, str]]: + num_experts: int) -> list[tuple[str, str, int, str]]: return [ # (param_name, weight_name, expert_id, shard_id) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py new file mode 100644 index 000000000000..7d3ddf8f14c4 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -0,0 +1,364 @@ +# SPDX-License-Identifier: Apache-2.0 +from abc import ABC, abstractmethod +from typing import Optional + +import torch + +# +# This file defines a set of base classes used to make MoE kernels more modular. +# The goal is to be able to utilize different communication mechanisms with +# any fused MoE kernel without needing to have combinatoric implementations. +# +# The fused moe kernels are broken down into the following components: +# +# [Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine] +# +# Each component will be independent of the others except for +# [Quantize-Dispatch] and [Combine] (see below). The components can then be +# mixed and matched so that DP+EP can be supported easily for multiple +# MoE kernel implementations. +# +# The following main classes are defined: +# * FusedMoEPrepareAndFinalize - an abstract base class for preparation of MoE +# inputs (e.g. quantization, distribution) and finalization of MoE outputs. +# The prepare method must take care of any needed quantization and the +# finalize method must apply weights and do the final reduction of the output. +# * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused +# MoE operation. One important feature to note is that this class does not +# apply topk weights or reduce the final output.
+# * FusedMoEModularKernel - an interface class that combines a +# FusedMoEPrepareAndFinalize and a FusedMoEPermuteExpertsUnpermute to +# provide the standard fused MoE kernel interface. +# +# [Quantize-Prepare] and [Finalize] functionality are bundled into a single +# class `FusedMoEPrepareAndFinalize` since they could use collective +# communication mechanisms that need to be consistent. +# + + +def _moe_problem_size( + a1: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, +) -> tuple[int, int, int, int, int]: + """ + Extract the MoE problem size from the given tensor arguments: + - a: The hidden states, input to the MoE layer. + - w1: The first set of expert weights. + - w2: The second set of expert weights. + - topk_ids: The topk ids. + + Note: extracting the problem shape from the weight and activation tensors is + not obvious. It needs to be done this way specifically due to subtle issues + with particular kernels, e.g. the int4 kernels divide the trailing dimension + by two, so it's not "correct" to extract N or K from the trailing dimension + of w1 or w2. Similarly, some kernels transpose the weights, so this needs + to be kept in mind. + """ + assert w1.dim() == 3 and w2.dim() == 3 + E, N, _ = w1.size() + K = w2.size(1) + + if a1.dim() == 2: + # Make sure we are using the correct a1 (pre-permute). + assert topk_ids.size(0) == a1.size(0), \ + f"{topk_ids.size(0)} != {a1.size(0)}" + M = a1.size(0) + else: + assert a1.dim() == 3 + assert a1.size(0) == E, f"{a1.size(0)} == {E}" + M = a1.size(1) # This is max_num_tokens + + assert topk_ids.dim() == 2 + topk = topk_ids.size(1) + + return E, M, N, K, topk + + +class FusedMoEPrepareAndFinalize(ABC): + """ + An abstract base class for the [Quantize-Prepare] and [Finalize] steps + described above. 
+ """ + + @abstractmethod + def prepare( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Perform any quantization (and/or) dispatching needed + for this kernel. + - a1: The (unquantized) input to the MoE layer. + - a1_scale: Optional scales for a1 + - a2_scale: Optional scales for the second MoE gemm. Required to make + sure the quantization is consistent for both gemms. + - topk_ids: The topk ids. + - topk_weights: The topk weights. + - num_experts: The total number of experts in the global expert space. + - expert_map: A tensor mapping expert indices from the global expert + space to the local expert space of the expert parallel shard. + - apply_router_weight_on_input: When True, apply the weights to the + activations, before quantization + dispatching. + + Returns a tuple of: + - quantized + dispatched a. + - quantized + dispatched a1_scales. + """ + raise NotImplementedError + + @abstractmethod + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> None: + """ + Perform any combine plus apply weights and perform a reduction on the + fused experts output. + - output: The output tensor, written in place. Must be (M, K) shape. + - fused_expert_output: The unweighted, unreduced output of the fused + experts, it will have (M, topk, K) shape. + - topk_weights: The weights to be applied to the fused_experts_output. + - topk_ids: The topk_ids. + - apply_router_weight_on_input: When False, apply the weights to + fused_expert_output. 
+ """ + raise NotImplementedError + + +class FusedMoEPermuteExpertsUnpermute(ABC): + """ + An abstract base class for the [Permute-Experts-Unpermute] step described + above. + """ + + @abstractmethod + def workspace_shapes( + self, + a: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> tuple[int, int, torch.dtype]: + """ + Compute the number of elements for the temporary outputs of the two + gemms and activation in the fused expert function. Since the + gemms are independent, the workspace for the first gemm can be shared + with the workspace for the last gemm. + + Returns a tuple of: + - Number of workspace13 elements: must be large enough to hold the + result of either expert gemm. + - Number of workspace2 elements: must be large enough to hold the + result of the activation function. + - Workspace type: The dtype to use for the workspace tensors. + """ + raise NotImplementedError + + def activation(self, activation: str, output: torch.Tensor, + input: torch.Tensor) -> None: + assert output.size(-1) * 2 == input.size(-1) + if activation == "silu": + torch.ops._C.silu_and_mul(output, input) + elif activation == "gelu": + torch.ops._C.gelu_and_mul(output, input) + else: + raise ValueError(f"Unsupported FusedMoe activation: {activation}") + + @abstractmethod + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ) -> torch.Tensor: + """ + This function computes the intermediate result of a Mixture of Experts + (MoE) layer using two sets of weights, w1 and w2. 
+ + Parameters: + - hidden_states: (torch.Tensor): The (quantized) input tensor to the MoE + layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - topk_ids (torch.Tensor): A map of row to expert id. + - activation (str): The activation function to apply after the first + MoE layer. + - global_num_experts (int): The total number of experts in the global + expert space. + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert + parallel shard. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. + - w1_zp (Optional[torch.Tensor]): Optional zero points to be used for + w1. + - w2_zp (Optional[torch.Tensor]): Optional zero points to be used for + w2. + - a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be + used for a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2. + - workspace13 (torch.Tensor): A scratch tensor used for gemm outputs + must be large enough to hold output of either MoE gemm. + - workspace2 (torch.Tensor): A scratch tensor used for the activation + function. + - expert_num_tokens: An optional tensor containing the number of tokens + assigned to each expert when using batched experts format input. + + Returns: + - torch.Tensor: The unweighted, unreduced output tensor + """ + raise NotImplementedError + + +class FusedMoEModularKernel(torch.nn.Module): + """ + This class combines a FusedMoEPrepareAndFinalize instance and + a FusedMoEPermuteExpertsUnpermute to provide an interface that + is compatible with the `fused_experts` function in fused_moe.py. + + It takes care of managing any required scratch space. + + Note: Instances of this class should only be used for a single model + layer due to any layer specific state that may be used by the component + objects. 
+ """ + + def __init__( + self, + prepare_finalize: FusedMoEPrepareAndFinalize, + fused_experts: FusedMoEPermuteExpertsUnpermute, + ): + super().__init__() + self.prepare_finalize = prepare_finalize + self.fused_experts = fused_experts + + def forward( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + ) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets + of weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states: (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - topk_weights (torch.Tensor): The topk weights applied at the end of + the layer. + - topk_ids (torch.Tensor): A map of row to expert id. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - activation (str): The activation function to apply after the first + MoE layer. + - global_num_experts (int): The total number of experts in the global + expert space. + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert + parallel shard. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. + - w1_zp (Optional[torch.Tensor]): Optional zero points to be used for + w1. 
+ - w2_zp (Optional[torch.Tensor]): Optional zero points to be used for + w2. + - a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2. + - apply_router_weight_on_input (bool): When true, the topk weights are + applied directly on the inputs. This is only applicable when topk is + 1. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + a1 = hidden_states + E, M, N, K, top_k = _moe_problem_size(a1, w1, w2, topk_ids) + + if global_num_experts == -1: + global_num_experts = E + + output = a1 if inplace else torch.zeros_like(a1) + + workspace13_shape, workspace2_shape, workspace_dtype = ( + self.fused_experts.workspace_shapes(a1, M, N, K, top_k, + global_num_experts)) + + # We can reuse the memory between cache1 and cache3 because by the time + # we need cache3, we're done with cache1 + workspace13 = torch.zeros(workspace13_shape, + device=a1.device, + dtype=workspace_dtype) + workspace2 = torch.zeros(workspace2_shape, + device=a1.device, + dtype=workspace_dtype) + + a1q, a1q_scale, expert_num_tokens = self.prepare_finalize.prepare( + a1, a1_scale, a2_scale, topk_weights, topk_ids, global_num_experts, + expert_map, apply_router_weight_on_input) + + fused_out = self.fused_experts.apply( + a1q, + w1, + w2, + topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1q_scale=a1q_scale, + a2_scale=a2_scale, + workspace13=workspace13, + workspace2=workspace2, + expert_num_tokens=expert_num_tokens, + ) + + self.prepare_finalize.finalize(output, fused_out, topk_weights, + topk_ids, apply_router_weight_on_input) + + return output diff --git a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py index b68e58efa884..d025f1257a9f 100644 --- 
a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py +++ b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Tuple +from typing import Optional import torch @@ -153,7 +153,7 @@ def moe_align_block_size( num_experts: int, expert_map: Optional[torch.Tensor] = None, pad_sorted_ids: bool = False -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Aligns the token distribution across experts to be compatible with block size for matrix multiplication. diff --git a/vllm/model_executor/layers/fused_moe/moe_pallas.py b/vllm/model_executor/layers/fused_moe/moe_pallas.py index 8f28b64ed487..babeb97308a9 100644 --- a/vllm/model_executor/layers/fused_moe/moe_pallas.py +++ b/vllm/model_executor/layers/fused_moe/moe_pallas.py @@ -2,7 +2,23 @@ import torch import torch.nn.functional as F -from torch_xla.experimental.custom_kernel import _histogram + + +def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor: + """ + Compute the histogram of a int32 tensor. The bin edges are defined by the + min and max values, with step = 1. + """ + assert input.dtype == torch.int32, "input must be of torch.int32 dtype." + assert min <= max, "min must be less than or equal to max." 
+ + def searchsorted(sorted_sequence: torch.Tensor, + values_to_search: torch.Tensor) -> torch.Tensor: + return (sorted_sequence.unsqueeze(1) == values_to_search).sum(dim=1) + + bin_edges = torch.linspace(min, max, max - min + 1, + dtype=input.dtype).to(input.device) + return searchsorted(bin_edges, input).to(torch.int32) def fused_moe( @@ -61,7 +77,7 @@ def fused_moe( x = torch.ops.xla.gmm(x, w2, group_sizes) x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size) - x = x * topk_weights.unsqueeze_(dim=-1) + x = x * topk_weights.unsqueeze(dim=-1) x = x.sum(dim=-2) x = x.reshape(orig_shape) return x diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py index cdf7e31c1436..cb396f26c96e 100644 --- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -1,8 +1,76 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Tuple +from typing import Optional import torch +from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( + moe_align_block_size) +from vllm.model_executor.layers.fused_moe.utils import _fp8_perm + + +def _moe_permute( + curr_hidden_states: torch.Tensor, + a1q_scale: Optional[torch.Tensor], + curr_topk_ids: torch.Tensor, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + block_m: int, +) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, + Optional[torch.Tensor]]: + """ + Determine the sorted_token_ids, expert_ids for the given problem size. + Permute the hidden states and scales according to `sorted_token_ids`. 
+ """ + top_k_num = curr_topk_ids.size(1) + + tokens_in_chunk = curr_hidden_states.size(0) + + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + moe_align_block_size(curr_topk_ids, + block_m, + global_num_experts, + expert_map, + pad_sorted_ids=True)) + + inv_perm: Optional[torch.Tensor] = None + + num_tokens = top_k_num * tokens_in_chunk + sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) + expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0) + inv_perm = torch.argsort(sorted_token_ids)[:num_tokens] + + # Permute according to sorted token ids. + curr_hidden_states = _fp8_perm(curr_hidden_states, + sorted_token_ids // top_k_num) + + if a1q_scale is not None: + a1q_scale = a1q_scale[sorted_token_ids // top_k_num] + + return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, + inv_perm) + + +def _moe_unpermute_and_reduce( + out: torch.Tensor, + curr_hidden: torch.Tensor, + inv_perm: Optional[torch.Tensor], + topk_weight: torch.Tensor, + apply_router_weight_on_input: bool, +) -> None: + """ + Unpermute the final result and apply topk_weights, then perform the final + reduction on the hidden states. + """ + M, topk = topk_weight.size() + K = curr_hidden.size(-1) + if inv_perm is not None: + curr_hidden = curr_hidden[inv_perm, ...] + curr_hidden = curr_hidden.view(-1, topk, K) + if not apply_router_weight_on_input: + curr_hidden.mul_(topk_weight.view(M, -1, 1)) + ops.moe_sum(curr_hidden, out) + def moe_permute( hidden_states: torch.Tensor, @@ -15,23 +83,23 @@ def moe_permute( expert_map: Optional[torch.Tensor] = None, align_block_size: Optional[int] = None, fill_invalid_expert: int = -1 -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ - This function expands and permutes activation to gather uncontinuous tokens + This function expands and permutes activation to gather uncontinuous tokens for each expert.
Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - hidden_states (torch.Tensor): The input tensor to the MoE layer. - topk_weights (torch.Tensor): topk expert route weight for each token. - topk_ids (torch.Tensor): topk expert route id for each token. - token_expert_indices (torch.Tensor): indice for expanded hidden. - topk (int): The number of top-k experts to select. - n_expert (int): The number of expert. - n_local_expert (int): The number of expert in current EP rank. - - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert parallel shard. - align_block_size (Optional[int]): align group gemm block size for deepgemm - - fill_invalid_expert(int): fill expert id in m_indices for invalid expert + - fill_invalid_expert(int): fill expert id in m_indices for invalid expert to workaround DeepGemm unsupported -1 in m_indices Returns: - permuted_hidden_states (torch.Tensor): permuted activation. @@ -39,10 +107,10 @@ def moe_permute( of each expert for standard grouped gemm. if enable 'align_block_size' expert_first_token_offset will align up to 'align_block_size'. - src_row_id2dst_row_id_map (torch.Tensor): idx map for moe_unpermute. 
- - m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records + - m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records the group which the j-th row of the LHS belong to.` """ - n_token, n_hidden = hidden_states.shape + n_token, n_hidden = hidden_states.size() assert (n_hidden * hidden_states.element_size() ) % 16 == 0, "permue kernel need hidden dim align to 16B" permuted_row_size = n_token * topk @@ -87,7 +155,7 @@ def moe_unpermute( n_local_expert: int, ) -> torch.Tensor: """ - This function expands and permutes activation to gathering uncontinuous + This function expands and permutes activation to gathering uncontinuous tokens for each expert. Parameters: - permuted_hidden_states (torch.Tensor): permuted activation. @@ -99,10 +167,10 @@ def moe_unpermute( - n_expert (int): The number of expert. - n_local_expert (int): The number of expert in current EP rank. Returns: - - hidden_states (torch.Tensor): The reduced and unpermuted activation - tensor. + - hidden_states (torch.Tensor): The reduced and unpermuted activation + tensor. 
""" - n_token, n_hidden = topk_weights.shape[0], permuted_hidden_states.shape[-1] + n_token, n_hidden = topk_weights.size(0), permuted_hidden_states.size(-1) assert (n_hidden * permuted_hidden_states.element_size() ) % 16 == 0, "unpermue kernel need hidden dim align to 16B" hidden_states = torch.empty((n_token, n_hidden), @@ -114,3 +182,7 @@ def moe_unpermute( expert_first_token_offset, n_expert, n_local_expert, topk, hidden_states) return hidden_states + + +def moe_permute_unpermute_supported(): + return torch.ops._moe_C.moe_permute_unpermute_supported() diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py new file mode 100644 index 000000000000..783ebebbfec9 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -0,0 +1,146 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import pplx_kernels as pplx +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.utils import ( + moe_kernel_quantize_input) + + +# The max_num_tokens, world_size and dp_size must be the same +# as the ones used to create the AllToAll. 
+class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): + + def __init__(self, + a2a: pplx.AllToAll, + max_num_tokens: int, + world_size: int, + rank: int, + dp_size: int, + quant_dtype: Optional[torch.dtype] = None, + block_shape: Optional[list[int]] = None): + super().__init__() + assert max_num_tokens > 0 + self.a2a = a2a + self.block_shape = block_shape + self.max_num_tokens = max_num_tokens + self.world_size = world_size + self.rank = rank + self.dp_size = dp_size + self.quant_dtype = quant_dtype + + def prepare( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + rank_topk_weights: torch.Tensor, + rank_topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + num_tokens = a1.size(0) # M + hidden_dim = a1.size(-1) # K + + assert rank_topk_ids.size(0) == num_tokens + # assert expert_map is None, "NYI" + + # Is this always going to be a1.device? + device = a1.device + + if apply_router_weight_on_input: + topk = rank_topk_ids.size(1) + # TODO: this only works for topK=1, will need to update for topK>1 + assert topk == 1, ( + "apply_router_weight_on_input is only implemented for topk=1") + a1 = a1 * rank_topk_weights.to(a1.dtype) + + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) + + a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale, + self.quant_dtype, + per_act_token, + self.block_shape) + + # rem_experts need to be 0 for pplx to work properly. 
+ rem_experts = num_experts % self.world_size + assert rem_experts == 0 + num_local_experts = ((num_experts // self.world_size) + + (1 if self.rank < rem_experts else 0)) + + expert_num_tokens = torch.empty( + num_local_experts, + dtype=torch.int32, + device=device, + ) + + num_dp = self.world_size // self.dp_size + expert_x = torch.empty( + (num_local_experts, self.max_num_tokens * num_dp, hidden_dim), + dtype=a1q.dtype, + device=device, + ) + + expert_x_scale: Optional[torch.Tensor] = None + if a1q.dtype.itemsize == 1: + float32_size = torch.float32.itemsize + block_size = (self.block_shape[0] if self.block_shape is not None + else 1) * float32_size + expert_x_scale = torch.empty( + ( + num_experts, + expert_x.size(1), + (expert_x.size(2) + block_size - 1) // block_size, + ), + dtype=torch.float32, + device=device, + ) + + # This argument is optional, defaults to indices.size(0) + # There's not much point setting this unless it is != indices.size(0) + bound_m: Optional[torch.Tensor] = None + + self.a2a.dispatch( + out_expert_num_tokens=expert_num_tokens, + out_expert_x=expert_x, + out_expert_x_scale=expert_x_scale, + dp_x=a1q, + dp_x_scale=a1q_scale, + indices=rank_topk_ids, + bound_m=bound_m, + ) + + return expert_x, expert_x_scale, expert_num_tokens + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> None: + num_tokens = output.size(0) # M + # This argument is optional + # There's not much point setting this unless it is != topk_ids.size(0) + bound_m: Optional[torch.Tensor] = None + + assert topk_ids.size(0) == num_tokens, ( + f"{topk_ids.size(0)} == {num_tokens}") + assert output.size(0) <= self.max_num_tokens, ( + f"{output.size(0)} <= {self.max_num_tokens}") + assert output.size(1) == fused_expert_output.size(-1) + + # Set weights to 1 if we did them in dispatch. This is hacky. 
+ if apply_router_weight_on_input: + topk_weights = torch.ones_like(topk_weights) + + self.a2a.combine(out_tokens=output, + indices=topk_ids, + weights=topk_weights, + expert_y=fused_expert_output, + bound_m=bound_m) diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py new file mode 100644 index 000000000000..98f98b3bd20b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( + _moe_unpermute_and_reduce) +from vllm.model_executor.layers.fused_moe.utils import ( + moe_kernel_quantize_input) + + +class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): + + def __init__( + self, + quant_dtype: Optional[torch.dtype] = None, + per_channel_quant: bool = False, + block_shape: Optional[list[int]] = None, + ): + super().__init__() + self.per_channel_quant = per_channel_quant + self.block_shape = block_shape + self.quant_dtype = quant_dtype + + def prepare( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool = False, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + if apply_router_weight_on_input: + topk = topk_ids.size(1) + # TODO: this only works for topK=1, will need to update for topK>1 + assert topk == 1, \ + "apply_router_weight_on_input is only implemented for topk=1" + a1.mul_(topk_weights.to(a1.dtype)) + + a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale, + self.quant_dtype, + self.per_channel_quant, + self.block_shape) + + return a1q, a1q_scale, None + + def finalize( + self, + output: 
torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> None: + _moe_unpermute_and_reduce(output, fused_expert_output, None, + topk_weights, apply_router_weight_on_input) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index acaa93f5a23e..10b61fcda176 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +from enum import IntEnum from functools import cache -from typing import List, Optional, Tuple +from typing import Optional import torch @@ -9,6 +10,28 @@ from vllm.utils import direct_register_custom_op +class QuantMethod(IntEnum): + # This allows interfacing with AITER QuantType Enum + # without importing the QuantType from AITER globally. + + # Note that these quantization methods are + # supported in AITER package. However, + # not all are used in this module. + + NO = 0 # a16w16 + PER_TENSOR = 1 # w8a8 (per_Tensor) + PER_TOKEN = 2 # w8a8/w8a4 (per_Token) + BLOCK_1X128 = 3 # block quantized w8a8 (per_1x128) + BLOCK_128x128 = 4 # block quantized w8a8 (per_128x128) + + +class ActivationMethod(IntEnum): + # This allows interfacing with AITER ActivationType enum + # without importing the ActivationType enum from AITER globally.
+ SILU = 0 + GELU = 1 + + @cache def is_rocm_aiter_moe_enabled() -> bool: return current_platform.is_rocm() \ @@ -20,7 +43,7 @@ def rocm_aiter_asm_moe_tkw1_impl( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - topk_weight: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, fc1_scale: Optional[torch.Tensor] = None, fc2_scale: Optional[torch.Tensor] = None, @@ -29,18 +52,17 @@ def rocm_aiter_asm_moe_tkw1_impl( a16: bool = False, per_tensor_quant_scale: Optional[torch.Tensor] = None, expert_mask: Optional[torch.Tensor] = None, - activation_str: str = "silu") -> torch.Tensor: + activation_method: int = ActivationMethod.SILU.value) -> torch.Tensor: from aiter import ActivationType from aiter.fused_moe_bf16_asm import asm_moe_tkw1 - activation = \ - ActivationType.Gelu if activation_str == "gelu" else ActivationType.Silu + activation = ActivationType(activation_method) return asm_moe_tkw1(hidden_states, w1, w2, - topk_weight, + topk_weights, topk_ids, fc1_scale=fc1_scale, fc2_scale=fc2_scale, @@ -56,7 +78,7 @@ def rocm_aiter_asm_moe_tkw1_fake( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - topk_weight: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, fc1_scale: Optional[torch.Tensor] = None, fc2_scale: Optional[torch.Tensor] = None, @@ -65,136 +87,7 @@ def rocm_aiter_asm_moe_tkw1_fake( a16: bool = False, per_tensor_quant_scale: Optional[torch.Tensor] = None, expert_mask: Optional[torch.Tensor] = None, - activation_str: str = "silu") -> torch.Tensor: - return torch.empty_like(hidden_states) - - -def rocm_aiter_ck_moe_impl(hidden_states: torch.Tensor, w1: torch.Tensor, - w2: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor) -> torch.Tensor: - from aiter import ck_moe - return ck_moe(hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids) - - -def rocm_aiter_ck_moe_fake(hidden_states: torch.Tensor, w1: torch.Tensor, - w2: torch.Tensor, 
topk_weights: torch.Tensor, - topk_ids: torch.Tensor) -> torch.Tensor: - return torch.empty_like(hidden_states) - - -def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl( - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - hidden_states_dtype: torch.dtype, - expert_mask: torch.Tensor, - a1: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - a1_scale: torch.Tensor, - block_shape: List[int], - smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor: - from aiter import fmoe_fp8_blockscale_g1u1 - from aiter.fused_moe_bf16_asm import moe_sorting_ck - - topk = topk_ids.shape[1] - model_dim = w1.shape[-1] - local_E = E = w1.shape[0] - if expert_mask is not None: - E = expert_mask.numel() - - ( - sorted_token_ids, - sorted_weight_buf, - sorted_expert_ids, - num_valid_ids, - out_asm, - ) = moe_sorting_ck(topk_ids, - topk_weights, - E, - model_dim, - hidden_states_dtype, - expert_mask=expert_mask) - - fmoe_fp8_blockscale_g1u1(out_asm, a1, w1, w2, sorted_token_ids, - sorted_weight_buf, sorted_expert_ids, - num_valid_ids, topk, w1_scale.view(local_E, -1), - w2_scale.view(local_E, -1), - a1_scale.t().contiguous(), *block_shape, - smooth_scale) - - return out_asm - - -def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake( - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - hidden_states_dtype: torch.dtype, - expert_mask: torch.Tensor, - a1: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - a1_scale: torch.Tensor, - block_shape: List[int], - smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor: - - return torch.empty_like(a1, dtype=torch.bf16) - - -def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weight: torch.Tensor, - topk_ids: torch.Tensor, - fc1_scale: Optional[torch.Tensor] = None, - fc2_scale: Optional[torch.Tensor] = None, - fc1_smooth_scale: Optional[torch.Tensor] = None, - fc2_smooth_scale: 
Optional[torch.Tensor] = None, - a16: bool = False, - activation: str = "silu") -> torch.Tensor: - import aiter.fused_moe_bf16_asm as rocm_aiter_asm_fmoe - from aiter import ActivationType - - assert activation in ["silu", "gelu"], "The given activation:" \ - f" {activation}" \ - " is not supported in" \ - " AITER." - if activation == "silu": - aiter_activation = ActivationType.Silu - else: - aiter_activation = ActivationType.Gelu - - return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_weight=topk_weight, - topk_ids=topk_ids, - fc1_scale=fc1_scale, - fc2_scale=fc2_scale, - fc1_smooth_scale=fc1_smooth_scale, - fc2_smooth_scale=fc2_smooth_scale, - a16=a16, - activation=aiter_activation) - - -def rocm_aiter_asm_moe_fake(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weight: torch.Tensor, - topk_ids: torch.Tensor, - fc1_scale: Optional[torch.Tensor] = None, - fc2_scale: Optional[torch.Tensor] = None, - fc1_smooth_scale: Optional[torch.Tensor] = None, - fc2_smooth_scale: Optional[torch.Tensor] = None, - a16: bool = False, - activation: str = "silu") -> torch.Tensor: + activation_method: int = ActivationMethod.SILU.value) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -216,6 +109,81 @@ def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor, pass +def rocm_aiter_biased_grouped_topk_impl( + gating_output: torch.Tensor, + correction_bias: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + routed_scaling_factor: float = 1.0 # mul to topk_weights +) -> None: + + from aiter import biased_grouped_topk + + biased_grouped_topk(gating_output, correction_bias, topk_weights, topk_ids, + num_expert_group, topk_group, need_renorm, + routed_scaling_factor) + + +def rocm_aiter_biased_grouped_topk_fake( + gating_output: torch.Tensor, + correction_bias: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, 
+ num_expert_group: int, + topk_group: int, + need_renorm: bool, + routed_scaling_factor: float = 1.0 # mul to topk_weights +) -> None: + pass + + +def rocm_aiter_fused_moe_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + expert_mask: Optional[torch.Tensor] = None, + activation_method: int = ActivationMethod.SILU.value, + quant_method: int = QuantMethod.NO.value, + doweight_stage1: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + from aiter import ActivationType, QuantType + from aiter.fused_moe import fused_moe + + activation = ActivationType(activation_method) + quant_type = QuantType(quant_method) + + return fused_moe(hidden_states, w1, w2, topk_weight, topk_ids, expert_mask, + activation, quant_type, doweight_stage1, w1_scale, + w2_scale, a1_scale, a2_scale) + + +def rocm_aiter_fused_moe_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + expert_mask: Optional[torch.Tensor] = None, + activation_method: int = ActivationMethod.SILU.value, + quant_method: int = QuantMethod.NO.value, + doweight_stage1: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + return torch.empty_like(hidden_states) + + if current_platform.is_rocm(): direct_register_custom_op( @@ -227,88 +195,86 @@ def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor, ) direct_register_custom_op( - op_name="rocm_aiter_ck_moe", - op_func=rocm_aiter_ck_moe_impl, + op_name="rocm_aiter_fused_moe", + op_func=rocm_aiter_fused_moe_impl, mutates_args=[], - fake_impl=rocm_aiter_ck_moe_fake, + fake_impl=rocm_aiter_fused_moe_fake, 
dispatch_key=current_platform.dispatch_key, ) direct_register_custom_op( - op_name="rocm_aiter_fmoe_fp8_blockscale_g1u1", - op_func=rocm_aiter_fmoe_fp8_blockscale_g1u1_impl, - mutates_args=[], - fake_impl=rocm_aiter_fmoe_fp8_blockscale_g1u1_fake, + op_name="rocm_aiter_topk_softmax", + op_func=rocm_aiter_topk_softmax_impl, + mutates_args=["topk_weights", "topk_indices", "token_expert_indices"], + fake_impl=rocm_aiter_topk_softmax_fake, dispatch_key=current_platform.dispatch_key, ) direct_register_custom_op( - op_name="rocm_aiter_asm_moe", - op_func=rocm_aiter_asm_moe_impl, - mutates_args=[], - fake_impl=rocm_aiter_asm_moe_fake, + op_name="rocm_aiter_biased_grouped_topk", + op_func=rocm_aiter_biased_grouped_topk_impl, + mutates_args=["topk_weights", "topk_ids"], + fake_impl=rocm_aiter_biased_grouped_topk_fake, dispatch_key=current_platform.dispatch_key, ) - direct_register_custom_op( - op_name="rocm_aiter_topk_softmax", - op_func=rocm_aiter_topk_softmax_impl, - mutates_args=["topk_weights", "topk_indices", "token_expert_indices"], - fake_impl=rocm_aiter_topk_softmax_fake, - dispatch_key=current_platform.dispatch_key, - ) +def rocm_aiter_biased_group_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, + scoring_func: str = "sigmoid", + e_score_correction_bias: Optional[torch.Tensor] = None +) -> tuple[torch.Tensor, torch.Tensor]: + assert scoring_func == "sigmoid", ( + "rocm_aiter_biased_group_topk only supports 'sigmoid' scoring_func.") + assert e_score_correction_bias is not None, ( + "'e_score_correction_bias' must not be None.") + token = hidden_states.shape[0] + device = hidden_states.device + topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device) + topk_weights = torch.empty((token, topk), + dtype=torch.float32, + device=device) + torch.ops.vllm.rocm_aiter_biased_grouped_topk( + gating_output, + e_score_correction_bias, + topk_weights, + 
topk_ids, + num_expert_group, + topk_group, + renormalize, + ) + return topk_weights, topk_ids -def rocm_aiter_fused_experts(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool = False, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None, - allow_deep_gemm: bool = False) -> torch.Tensor: - - from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) +def rocm_aiter_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + per_channel_quant: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None) -> torch.Tensor: + + activation_method = (ActivationMethod.SILU + if activation == "silu" else ActivationMethod.GELU) # All AITER Fused MoE kernels are expecting the following datatypes topk_weights = topk_weights.to(torch.float32) topk_ids = topk_ids.to(torch.int32) - # w8a8 block-scaled - if block_shape is not None and use_fp8_w8a8: - assert not apply_router_weight_on_input, ( - "apply_router_weight_on_input is not supported for block scaled moe" 
- ) - assert w1_scale is not None - assert w2_scale is not None - - # The default block sizes are 128 in AITER. - block_shape = [128, 128] if block_shape is None else block_shape - - a1, a1_scale = per_token_group_quant_fp8(hidden_states, block_shape[1]) - - return torch.ops.vllm.rocm_aiter_fmoe_fp8_blockscale_g1u1( - topk_ids, topk_weights, hidden_states.dtype, expert_map, a1, w1, - w2, w1_scale, w2_scale, a1_scale, block_shape, None) - # w8a8 per-channel quantization - elif per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8: + if per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8: # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input` # This applies topk_weights on the GEMM output of the first FC layer # rather than the second FC. @@ -330,87 +296,77 @@ def rocm_aiter_fused_experts(hidden_states: torch.Tensor, fc2_smooth_scale=None, a16=False, per_tensor_quant_scale=None, - expert_mask=expert_map, - activation_str=activation) - - # w8a8 per-tensor activation per-tensor weight - elif use_fp8_w8a8: - assert not apply_router_weight_on_input, ( - "apply_router_weight_on_input is not supported for fp8_w8a8") - return torch.ops.vllm.rocm_aiter_asm_moe(hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_weight=topk_weights, - topk_ids=topk_ids, - fc1_scale=w1_scale, - fc2_scale=w2_scale, - fc1_smooth_scale=None, - fc2_smooth_scale=None, - a16=False, - activation=activation) - if apply_router_weight_on_input: - assert (topk_weights.dim() == 2 - ), "`topk_weights` should be in shape (num_tokens, topk)" - _, topk = topk_weights.shape - assert ( - topk == 1 - ), "Only support topk=1 when `apply_router_weight_on_input` is True" - - hidden_states = hidden_states * topk_weights.to(hidden_states.dtype) - topk_ids = topk_ids.to(torch.int32) - topk_weights = torch.ones_like(topk_weights, dtype=torch.float32) + expert_mask=None, + activation_method=activation_method) - # w16a16 fallback to rocm_aiter_ck_moe w16a16 - return 
torch.ops.vllm.rocm_aiter_ck_moe(hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids) + else: + quant_method = QuantMethod.NO.value + + # w8a8 block-scaled + if block_shape is not None and use_fp8_w8a8: + assert not apply_router_weight_on_input, ( + "apply_router_weight_on_input is\ + not supported for block scaled moe") + assert w1_scale is not None + assert w2_scale is not None + quant_method = QuantMethod.BLOCK_128x128.value + elif use_fp8_w8a8: + # Currently only per tensor quantization method is enabled. + quant_method = QuantMethod.PER_TENSOR.value + + if apply_router_weight_on_input: + assert (topk_weights.dim() == 2 + ), "`topk_weights` should be in shape (num_tokens, topk)" + _, topk = topk_weights.shape + assert ( + topk == 1 + ), "Only support topk=1 when `apply_router_weight_on_input` is True" + + return torch.ops.vllm.rocm_aiter_fused_moe( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + quant_method=quant_method, + activation_method=activation_method, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + doweight_stage1=apply_router_weight_on_input) def rocm_aiter_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor, token_expert_indices: torch.Tensor, gating_output: torch.Tensor, - renormalize: bool) -> Tuple[torch.Tensor, ...]: + renormalize: bool) -> tuple[torch.Tensor, ...]: torch.ops.vllm.rocm_aiter_topk_softmax(topk_weights, topk_indices, token_expert_indices, gating_output, renormalize) return topk_weights, topk_indices -def shuffle_weights(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]: +def shuffle_weights( + *tensors: torch.Tensor, layout: tuple[int, int] = (16, 16) +) -> tuple[torch.Tensor, ...]: """ Applies shuffle_weight function from AITER to each input tensor and returns them. + + Rearranges (shuffles) the input tensor/s + into a specified block layout for optimized computation. 
Args: - *tensors: Variable number of torch.Tensor objects. + *tensors: Variable number of torch.Tensor objects. + layout: A pair of integers specifying the + block sizes used to divide the tensors during shuffling. + Default is (16, 16). Returns: A Tuple of shuffled tensors. """ from aiter.ops.shuffle import shuffle_weight - return tuple(shuffle_weight(tensor) for tensor in tensors) - - -def expand_weights(*tensors: torch.Tensor, - expansion_dims: list[int]) -> Tuple[torch.Tensor, ...]: - """ - Expands the dimensions of input tensors. - - Args: - *tensors: A variable number of torch.Tensor objects. - expansion_dims: A list of expansion dimensions - corresponding to each tensor. - - Returns: - A Tuple of tensors with expanded dimensions. - """ - - assert len(tensors) == len(expansion_dims), \ - "Number of tensors must match the number of expansion dimensions." - return tuple( - tensor.unsqueeze(-1).unsqueeze(-1).expand((-1, dim, -1)) - for tensor, dim in zip(tensors, expansion_dims)) + return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors) diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py new file mode 100644 index 000000000000..2cfe373140bb --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -0,0 +1,112 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( + DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape) +from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts + + +class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__(self, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + per_channel_quant: bool = False, + block_shape: 
Optional[list[int]] = None, + block_m: Optional[int] = None, + allow_deep_gemm: bool = False): + super().__init__() + self.triton_expert = TritonExperts(use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int4_w4a16=use_int4_w4a16, + use_int8_w8a16=use_int8_w8a16, + per_channel_quant=per_channel_quant, + block_shape=block_shape, + block_m=block_m) + self.deep_gemm_expert = DeepGemmExperts() + self.allow_deep_gemm = allow_deep_gemm + self.use_fp8_w8a8 = use_fp8_w8a8 + + def workspace_shapes( + self, + a: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> tuple[int, int, torch.dtype]: + # Note: the deep gemm workspaces are strictly larger than the triton + # workspaces so we can be pessimistic here and allocate for DeepGemm + # even if we fall back to triton later, e.g. if expert maps are set. + if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K): + return self.deep_gemm_expert.workspace_shapes( + a, M, N, K, topk, num_experts) + else: + return self.triton_expert.workspace_shapes(a, M, N, K, topk, + num_experts) + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ) -> torch.Tensor: + N = w1.size(1) + if (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512 + and _valid_deep_gemm(hidden_states, w1, w2, expert_map)): + return self.deep_gemm_expert.apply( + hidden_states, + w1, + w2, + topk_ids, + activation, + global_num_experts, + expert_map, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1q_scale, + a2_scale, + workspace13, + workspace2, + 
expert_num_tokens, + ) + else: + return self.triton_expert.apply( + hidden_states, + w1, + w2, + topk_ids, + activation, + global_num_experts, + expert_map, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1q_scale, + a2_scale, + workspace13, + workspace2, + expert_num_tokens, + ) diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index db31422f7275..d9d2520e18b3 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -1,48 +1,97 @@ # SPDX-License-Identifier: Apache-2.0 from math import prod -from typing import List, Optional, Tuple +from typing import Optional import torch from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8) +from vllm.model_executor.layers.quantization.utils.int8_utils import ( + per_token_group_quant_int8, per_token_quant_int8) from vllm.utils import cdiv -def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor: +def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor: """ Shrink the given tensor and apply the given view to it. This is used to resize the intermediate fused_moe caches. """ - assert prod(v) <= x.numel() + assert prod( + v) <= x.numel(), f"{prod(v)} <= {x.numel()}" # CUDAGRAPH unfriendly? return x.flatten()[:prod(v)].view(*v) def _fp8_quantize( A: torch.Tensor, A_scale: Optional[torch.Tensor], - block_shape: Optional[List[int]], -) -> Tuple[torch.Tensor, torch.Tensor]: + per_act_token: bool, + block_shape: Optional[list[int]] = None, +) -> tuple[torch.Tensor, torch.Tensor]: """ Perform fp8 quantization on the inputs. If a block_shape is provided, the output will be blocked. 
""" if block_shape is None: - A, A_scale = ops.scaled_fp8_quant(A, A_scale) + A, A_scale = ops.scaled_fp8_quant( + A, A_scale, use_per_token_if_dynamic=per_act_token) else: assert len(block_shape) == 2 _, block_k = block_shape[0], block_shape[1] A, A_scale = per_token_group_quant_fp8(A, block_k) - assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1] + assert cdiv(A.size(-1), block_k) == A_scale.size(-1) + + return A, A_scale + + +def _int8_quantize( + A: torch.Tensor, + A_scale: Optional[torch.Tensor], + per_act_token: bool, + block_shape: Optional[list[int]] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Perform int8 quantization on the inputs. If a block_shape + is provided, the output will be blocked. + """ + + # If weights are per-channel (per_channel_quant=True), then + # activations apply per-token quantization. Otherwise, assume + # activation tensor-wise fp8/int8 quantization, dynamic or static + if block_shape is None: + assert per_act_token, \ + "int8 quantization only supports block or channel-wise" + A, A_scale = per_token_quant_int8(A) + else: + assert len(block_shape) == 2 + _, block_k = block_shape[0], block_shape[1] + A, A_scale = per_token_group_quant_int8(A, block_k) + assert cdiv(A.size(-1), block_k) == A_scale.size(-1) + return A, A_scale +def moe_kernel_quantize_input( + A: torch.Tensor, + A_scale: Optional[torch.Tensor], + qtype: Optional[torch.dtype], + per_channel_quant: bool, + block_shape: Optional[list[int]] = None, +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + if qtype == torch.float8_e4m3fn: + return _fp8_quantize(A, A_scale, per_channel_quant, block_shape) + elif qtype == torch.int8: + return _int8_quantize(A, A_scale, per_channel_quant, block_shape) + else: + assert A_scale is None + return A, A_scale + + def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: """ A permutation routine that works on fp8 types. 
""" - if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8: + if torch.is_floating_point(m) and m.dtype.itemsize == 1: return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype) else: return m[idx, ...] diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 87d9b959e643..e8abd32ff6ba 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """Custom normalization layers.""" -from typing import Optional, Tuple, Union +from typing import Optional, Union import torch import torch.nn as nn @@ -31,7 +31,7 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor, def fused_add_rms_norm( x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, - variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]: + variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]: from vllm import _custom_ops as ops ops.fused_add_rms_norm( x, @@ -46,25 +46,32 @@ def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float) -> torch.Tensor: import aiter as rocm_aiter + if x.dim() > 2: + x_original_shape = x.shape + x = x.reshape(-1, x_original_shape[-1]) + x = rocm_aiter.rms_norm(x, weight, variance_epsilon) + return x.reshape(x_original_shape) + return rocm_aiter.rms_norm(x, weight, variance_epsilon) def rocm_aiter_fused_add_rms_norm( x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, - variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]: + variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]: import aiter as rocm_aiter - # Assuming the correct signature for rmsnorm2d_fwd_with_add + residual_out = torch.empty_like(residual) + output = torch.empty_like(x) rocm_aiter.rmsnorm2d_fwd_with_add( - x, # output + output, # output x, # input residual, # residual input - residual, # residual output + residual_out, # residual output weight, variance_epsilon, ) - return x, 
residual + return output, residual_out def dispatch_cuda_rmsnorm_func(add_residual: bool): @@ -112,7 +119,7 @@ def forward_native( self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """PyTorch-native implementation equivalent to forward().""" orig_dtype = x.dtype x = x.to(torch.float32) @@ -150,7 +157,7 @@ def forward_cuda( self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if self.variance_size_override is not None: return self.forward_native(x, residual) @@ -167,7 +174,7 @@ def forward_hpu( self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: from vllm_hpu_extension.kernels import rms_norm HPUFusedRMSNorm = rms_norm() if HPUFusedRMSNorm is None: @@ -187,7 +194,7 @@ def forward_xpu( self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if self.variance_size_override is not None: return self.forward_native(x, residual) @@ -237,7 +244,7 @@ def forward_static( variance_epsilon: float, x: torch.Tensor, residual: Optional[torch.Tensor], - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """PyTorch-native implementation equivalent to forward().""" orig_dtype = x.dtype if residual is not None: @@ -260,7 +267,7 @@ def forward_native( self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """PyTorch-native implementation 
equivalent to forward().""" return self.forward_static(self.weight.data, self.variance_epsilon, x, residual) @@ -269,7 +276,7 @@ def forward_cuda( self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if torch.compiler.is_compiling(): return self.forward_native(x, residual) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 794de4c383b0..269ac043d26c 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -261,6 +261,7 @@ class ReplicatedLinear(LinearBase): quant_config: Quantization configure. prefix: The name of the layer in the state dict, including all parents (e.g. model.layers.0.qkv_proj) + return_bias: If true, return bias together with outputs in forward pass. """ def __init__( @@ -523,6 +524,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): quant_config: Quantization configure. prefix: The name of the layer in the state dict, including all parents (e.g. model.layers.0.qkv_proj) + return_bias: If true, return bias together with outputs in forward pass. """ def __init__( @@ -585,8 +587,6 @@ def weight_loader(self, param.shard_id.append(loaded_shard_id) param.shard_id_map[loaded_shard_id] = len(param.data_container) param.data_container.append(loaded_weight) - if len(param.data_container) == 2: - self.qweight = param.materialize_nested() return param_data = param.data @@ -805,6 +805,7 @@ class QKVParallelLinear(ColumnParallelLinear): quant_config: Quantization configure. prefix: The name of the layer in the state dict, including all parents (e.g. model.layers.0.qkv_proj) + return_bias: If true, return bias together with outputs in forward pass. 
""" def __init__( @@ -979,8 +980,6 @@ def weight_loader(self, param.shard_id.append(loaded_shard_id) param.shard_id_map[loaded_shard_id] = len(param.data_container) param.data_container.append(loaded_weight) - if len(param.data_container) == 3: - self.qweight = param.materialize_nested() return param_data = param.data @@ -1155,7 +1154,13 @@ class RowParallelLinear(LinearBase): bias can be fused with other element-wise operations. We skip adding bias but instead return it. params_dtype: Data type for the parameters. + reduce_results: If true, call all-reduce on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is Y = X_iA_i quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.down_proj) + return_bias: If true, return bias together with outputs in forward pass. """ def __init__( @@ -1425,8 +1430,8 @@ def sync_weight_attrs( ): missing_attrs_dict = { k: getattr(src_param, k) - for k in (set(src_param.__dict__.keys()) - - set(tgt_param.__dict__.keys())) + for k in (set(vars(src_param).keys()) - + set(vars(tgt_param).keys())) } # TODO(Isotr0py): handle bitsandbytes 8bit use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit", diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 4a359725bad0..6b69a260826b 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -119,7 +119,7 @@ def _get_logits( def extra_repr(self) -> str: s = f"vocab_size={self.vocab_size}" - s += f", forg_vocab_size={self.org_vocab_size}" + s += f", org_vocab_size={self.org_vocab_size}" s += f", scale={self.scale}, logits_as_input={self.logits_as_input}" return s diff --git a/vllm/model_executor/layers/mamba/mamba2_metadata.py b/vllm/model_executor/layers/mamba/mamba2_metadata.py index e5b88de2fcc8..019f634a9ef4 100644 --- 
a/vllm/model_executor/layers/mamba/mamba2_metadata.py +++ b/vllm/model_executor/layers/mamba/mamba2_metadata.py @@ -5,10 +5,9 @@ import torch from vllm.attention.backends.abstract import AttentionMetadata -from vllm.attention.backends.flash_attn import FlashAttentionMetadata from vllm.attention.backends.placeholder_attn import ( PlaceholderAttentionMetadata) -from vllm.attention.backends.xformers import XFormersMetadata +from vllm.platforms import current_platform @dataclass @@ -23,6 +22,21 @@ class Mamba2Metadata: chunk_offsets: torch.Tensor +def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]: + """Returns the appropriate metadata classes for the current platform.""" + if current_platform.is_rocm(): + from vllm.attention.backends.rocm_flash_attn import ( + ROCmFlashAttentionMetadata) + return (ROCmFlashAttentionMetadata, PlaceholderAttentionMetadata) + elif current_platform.is_cuda(): + from vllm.attention.backends.flash_attn import FlashAttentionMetadata + from vllm.attention.backends.xformers import XFormersMetadata + return (FlashAttentionMetadata, XFormersMetadata, + PlaceholderAttentionMetadata) + raise ValueError( + f"Unsupported platform for Mamba2: {current_platform.device_type}") + + def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor, chunk_size: int, total_seqlens: int): @@ -78,9 +92,8 @@ def prepare_mamba2_metadata( # Compute seq_idx, chunk_indices and chunk_offsets for prefill only if num_prefills > 0: - if (isinstance(attn_metadata, - (FlashAttentionMetadata, XFormersMetadata, - PlaceholderAttentionMetadata)) + attn_metadata_instances = get_platform_metadata_classes() + if (isinstance(attn_metadata, attn_metadata_instances) and attn_metadata.context_lens_tensor is not None): has_initial_states = \ attn_metadata.context_lens_tensor[:num_prefills] > 0 #[batch,] diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 05b9d87ac0af..f94ab75f9a4f 
100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import torch from torch import nn @@ -34,7 +34,11 @@ @CustomOp.register("mixer2_gated_rms_norm") class Mixer2RMSNormGated(CustomOp): - def __init__(self, full_hidden_size, full_n_groups, eps=1e-6): + def __init__(self, + full_hidden_size: int, + full_n_groups: int, + use_rms_norm: bool = True, + eps: float = 1e-6): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tensor_model_parallel_rank() @@ -44,11 +48,17 @@ def __init__(self, full_hidden_size, full_n_groups, eps=1e-6): self.n_groups = full_hidden_size // self.group_size self.variance_epsilon = eps - self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size)) - set_weight_attrs(self.weight, - {"weight_loader": sharded_weight_loader(0)}) - assert self.full_hidden_size % self.tp_size== 0,\ - "Tensor parallel world size must divide hidden size." + self.use_rms_norm = use_rms_norm + if self.use_rms_norm: + # Register norm weight only if we're actually applying RMSNorm + self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size)) + set_weight_attrs(self.weight, + {"weight_loader": sharded_weight_loader(0)}) + else: + # Avoid checkpoint mismatch by skipping unused parameter + self.register_parameter("weight", None) + assert (self.full_hidden_size % self.tp_size == 0 + ), "Tensor parallel world size must divide hidden size." def forward_native( self, @@ -66,6 +76,8 @@ def forward_native( # the input and then redundantly compute the RMSNorm. 
input_dtype = x.dtype x = x * nn.functional.silu(gate.to(torch.float32)) + if not self.use_rms_norm: + return x.to(input_dtype) if self.n_groups == 1: if self.tp_size > 1: @@ -74,7 +86,7 @@ def forward_native( global_sums = tensor_model_parallel_all_reduce(local_sums) # Calculate the variance count = self.tp_size * x.shape[-1] - variance = (global_sums / count) + variance = global_sums / count else: variance = x.pow(2).mean(-1, keepdim=True) @@ -104,7 +116,12 @@ def forward_cuda( self, x: torch.Tensor, gate: torch.Tensor, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + input_dtype = x.dtype + if not self.use_rms_norm: + # Keep gate in float32 for numerical stability during silu + return x * nn.functional.silu(gate.to( + torch.float32)).to(input_dtype) if self.tp_size > 1 or self.n_groups != 1: return self.forward_native(x, gate) @@ -124,7 +141,7 @@ def forward_cuda( def extra_groups_for_head_shards(ngroups: int, tp_size: int): - """Compute the increase in group numbers to account for + """Compute the increase in group numbers to account for replication in order to accompany the head shards.""" # in the case ngoups % tp_size == 0, this will be zero @@ -136,13 +153,13 @@ def extra_groups_for_head_shards(ngroups: int, tp_size: int): def mamba_v2_sharded_weight_loader( - shard_spec: List[Tuple[int, int, float]], + shard_spec: list[tuple[int, int, float]], tp_size: int, tp_rank: int, ) -> LoaderFunction: """Create a weight loader for mamba v2. This ensures that the projections are correctly sharded so that they can be split into x, B, C. It also - ensures the the all the groups corresponding to a head shard is placed + ensures that all the groups corresponding to a head shard is placed together with it. """ @@ -182,13 +199,15 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: # seem to handle slices well. 
# https://github.com/python/mypy/issues/2410 param.data[ - boundary:(boundary + take), # type: ignore[misc] - ...] = loaded_weight[loaded_start_idx:( # type: ignore[misc] - loaded_start_idx + take)] # type: ignore[misc] + boundary:(boundary + take), + ... # type: ignore[misc] + ] = loaded_weight[loaded_start_idx:(loaded_start_idx + + take) # type: ignore[misc] + ] # type: ignore[misc] # move indexing boundaries boundary += shard_size - loaded_boundary += (full_dim - extra) + loaded_boundary += full_dim - extra return loader @@ -206,19 +225,22 @@ class MambaMixer2(CustomOp): **selective** state spaces) """ - def __init__(self, - hidden_size: int, - ssm_state_size: int, - conv_kernel_size: int, - intermediate_size: int, - use_conv_bias: bool, - use_bias: bool, - n_groups: int = 1, - num_heads: int = 128, - head_dim: int = 64, - rms_norm_eps: float = 1e-5, - activation="silu", - quant_config: Optional[QuantizationConfig] = None): + def __init__( + self, + hidden_size: int, + ssm_state_size: int, + conv_kernel_size: int, + intermediate_size: int, + use_conv_bias: bool, + use_bias: bool, + n_groups: int = 1, + num_heads: int = 128, + head_dim: int = 64, + rms_norm_eps: float = 1e-5, + activation: str = "silu", + use_rms_norm: bool = True, + quant_config: Optional[QuantizationConfig] = None, + ): super().__init__() # For TP, the sharding plan is as follows: @@ -238,17 +260,16 @@ def __init__(self, self.tp_size = get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() - assert num_heads % self.tp_size == 0, \ - "Tensor parallel world size must divide num heads." + assert (num_heads % self.tp_size == 0 + ), "Tensor parallel world size must divide num heads." - assert (n_groups % self.tp_size) == 0 or n_groups == 1, \ - ( - "If tensor parallel world size does not divide num_heads, " - "then num_groups must equal 1." 
- ) + assert (n_groups % self.tp_size) == 0 or n_groups == 1, ( + "If tensor parallel world size does not divide num_heads, " + "then num_groups must equal 1.") - assert self.tp_size == 1 or quant_config is None, \ - "Tensor parallel currently not supported for quantized models." + assert ( + self.tp_size == 1 or quant_config is None + ), "Tensor parallel currently not supported for quantized models." self.ssm_state_size = ssm_state_size self.activation = activation @@ -265,8 +286,7 @@ def __init__(self, self.n_groups = n_groups + extra_groups_for_head_shards( n_groups, self.tp_size) - self.conv_dim = (intermediate_size + - 2 * self.n_groups * ssm_state_size) + self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size self.conv1d = ColumnParallelLinear( input_size=conv_kernel_size, output_size=self.conv_dim, @@ -279,11 +299,12 @@ def __init__(self, # doesn't allow to override it self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - self.in_proj = ColumnParallelLinear(input_size=hidden_size, - output_size=intermediate_size + - self.conv_dim + self.num_heads, - bias=use_bias, - quant_config=quant_config) + self.in_proj = ColumnParallelLinear( + input_size=hidden_size, + output_size=intermediate_size + self.conv_dim + self.num_heads, + bias=use_bias, + quant_config=quant_config, + ) # - because in_proj is a concatenation of 3 weights, we # need to interleave them before sharding @@ -305,7 +326,8 @@ def __init__(self, # - ditto for the otther two weights below delattr(self.conv1d.bias, "weight_loader") set_weight_attrs( - self.conv1d.bias, { + self.conv1d.bias, + { "weight_loader": mamba_v2_sharded_weight_loader( [ @@ -316,18 +338,25 @@ def __init__(self, self.tp_size, tp_rank, ) - }) + }, + ) delattr(self.conv1d.weight, "weight_loader") set_weight_attrs( - self.conv1d.weight, { + self.conv1d.weight, + { "weight_loader": - mamba_v2_sharded_weight_loader([ - intermediate_settings, - group_shard_settings, - group_shard_settings, - ], self.tp_size, 
tp_rank) - }) + mamba_v2_sharded_weight_loader( + [ + intermediate_settings, + group_shard_settings, + group_shard_settings, + ], + self.tp_size, + tp_rank, + ) + }, + ) if quant_config is None: # - quant layers do not have a weight loader @@ -345,8 +374,10 @@ def __init__(self, head_setings, # for dt ], self.tp_size, - tp_rank) - }) + tp_rank, + ) + }, + ) # - these are TPed by heads to reduce the size of the # temporal shape @@ -357,6 +388,7 @@ def __init__(self, )) self.D = nn.Parameter(torch.ones(num_heads // self.tp_size)) self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size)) + self.use_rms_norm = use_rms_norm set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) a_weight_loader = composed_weight_loader( @@ -365,18 +397,25 @@ def __init__(self, set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) - self.out_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=use_bias, - input_is_parallel=True, - quant_config=quant_config) + self.out_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=use_bias, + input_is_parallel=True, + quant_config=quant_config, + ) self.norm = Mixer2RMSNormGated(intermediate_size, n_groups, + self.use_rms_norm, eps=rms_norm_eps) - def forward_native(self, hidden_states: torch.Tensor, - conv_state: torch.Tensor, ssm_state: torch.Tensor): + def forward_native( + self, + hidden_states: torch.Tensor, + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + ): pass def forward_cuda( @@ -384,6 +423,7 @@ def forward_cuda( hidden_states: torch.Tensor, mamba_cache_params: MambaCacheParams, mamba2_metadata: Mamba2Metadata, + mup_vector: Optional[torch.Tensor] = None, ): # mamba2_metadata contains metadata necessary for the mamba2 triton # kernels to operate in continuous batching and in chunked prefill @@ -401,6 +441,10 @@ def forward_cuda( # 1. 
Gated MLP's linear projection projected_states, _ = self.in_proj(hidden_states) + + if mup_vector is not None: + projected_states = projected_states * mup_vector + gate, hidden_states_B_C, dt = torch.split( projected_states, [ @@ -561,6 +605,9 @@ def forward_cuda( hidden_states = torch.vstack(ssd_output_list) # 4. gated MLP + # GatedRMSNorm internally applying SiLU to the gate + # SiLU is applied internally before normalization, unlike standard + # norm usage hidden_states = self.norm(hidden_states, gate) # 5. Final linear projection diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 3f6ab64e4fa9..6abbc90819a8 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from enum import IntEnum -from typing import List, Optional, Union +from typing import Optional, Union import torch import torch.nn as nn @@ -46,7 +46,7 @@ def from_pooling_type( normalize: bool, softmax: bool, step_tag_id: Optional[int] = None, - returned_token_ids: Optional[List[int]] = None, + returned_token_ids: Optional[list[int]] = None, ) -> "SimplePooler": if pooling_type == PoolingType.LAST: assert step_tag_id is None and returned_token_ids is None @@ -174,7 +174,7 @@ def __init__( normalize: bool, softmax: bool, step_tag_id: Optional[int] = None, - returned_token_ids: Optional[List[int]] = None, + returned_token_ids: Optional[list[int]] = None, ): super().__init__(normalize=normalize, softmax=softmax) @@ -242,9 +242,16 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], if self.softmax: if isinstance(pooled_data, list): - pooled_data = [F.softmax(data, dim=-1) for data in pooled_data] + pooled_data = [ + F.softmax(data, dim=-1) + if data.shape[-1] >= 2 else F.sigmoid(data) + for data in pooled_data + ] else: - pooled_data = F.softmax(pooled_data, dim=-1) + if pooled_data.shape[-1] >= 2: + pooled_data = F.softmax(pooled_data, dim=-1) + 
else: + pooled_data = F.sigmoid(pooled_data) return pooled_data @@ -259,7 +266,7 @@ def from_config_with_defaults( normalize: bool, softmax: bool, step_tag_id: Optional[int] = None, - returned_token_ids: Optional[List[int]] = None, + returned_token_ids: Optional[list[int]] = None, ) -> SimplePooler: return SimplePooler.from_pooling_type( pooling_type=PoolingType[pooler_config.pooling_type] diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 15e08220b7b5..407b9c72f41d 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Literal, Type, get_args +from typing import Literal, get_args from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) @@ -14,7 +14,7 @@ "ptpc_fp8", "fbgemm_fp8", "modelopt", - "nvfp4", + "modelopt_fp4", "marlin", "bitblas", "gguf", @@ -33,6 +33,7 @@ "quark", "moe_wna16", "torchao", + "auto-round", ] QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) @@ -76,7 +77,7 @@ def _wrapper(quant_config_cls): return _wrapper -def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: +def get_quantization_config(quantization: str) -> type[QuantizationConfig]: if quantization not in QUANTIZATION_METHODS: raise ValueError(f"Invalid quantization method: {quantization}") @@ -84,6 +85,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig from .aqlm import AQLMConfig + from .auto_round import AutoRoundConfig from .awq import AWQConfig from .awq_marlin import AWQMarlinConfig from .bitblas import BitBLASConfig @@ -110,7 +112,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: from .torchao import TorchAOConfig from .tpu_int8 import Int8TpuConfig - 
method_to_config: dict[str, Type[QuantizationConfig]] = { + method_to_config: dict[str, type[QuantizationConfig]] = { "aqlm": AQLMConfig, "awq": AWQConfig, "deepspeedfp": DeepSpeedFPConfig, @@ -118,7 +120,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: "fp8": Fp8Config, "fbgemm_fp8": FBGEMMFp8Config, "modelopt": ModelOptFp8Config, - "nvfp4": ModelOptNvFp4Config, + "modelopt_fp4": ModelOptNvFp4Config, "marlin": MarlinConfig, "bitblas": BitBLASConfig, "gguf": GGUFConfig, @@ -138,6 +140,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: "quark": QuarkConfig, "moe_wna16": MoeWNA16Config, "torchao": TorchAOConfig, + "auto-round": AutoRoundConfig, } # Update the `method_to_config` with customized quantization methods. method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py index 0b74e8faff9d..8bf0ca5c0448 100644 --- a/vllm/model_executor/layers/quantization/aqlm.py +++ b/vllm/model_executor/layers/quantization/aqlm.py @@ -4,7 +4,7 @@ # and https://arxiv.org/pdf/2401.06118.pdf import math -from typing import Any, Dict, List, Optional +from typing import Any, Optional import torch import torch.nn.functional as F @@ -98,7 +98,7 @@ def generic_dequantize_gemm( codebooks: torch. Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] scales: torch.Tensor, # [num_out_groups, 1, 1, 1] - output_partition_sizes: List[int], + output_partition_sizes: list[int], bias: Optional[torch.Tensor], ) -> torch.Tensor: output_shape = input.shape[:-1] + (scales.shape[0], ) @@ -136,7 +136,7 @@ def optimized_dequantize_gemm( codebooks: torch. 
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] scales: torch.Tensor, # [num_out_groups, 1, 1, 1] - output_partition_sizes: List[int], + output_partition_sizes: list[int], bias: Optional[torch.Tensor], ) -> torch.Tensor: weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) @@ -191,7 +191,7 @@ def get_name(cls) -> QuantizationMethods: return "aqlm" @classmethod - def get_supported_act_dtypes(cls) -> List[torch.dtype]: + def get_supported_act_dtypes(cls) -> list[torch.dtype]: return [torch.half] @classmethod @@ -199,11 +199,11 @@ def get_min_capability(cls) -> int: return 60 @classmethod - def get_config_filenames(cls) -> List[str]: + def get_config_filenames(cls) -> list[str]: return [] # no extra configs. @classmethod - def from_config(cls, config: Dict[str, Any]) -> "AQLMConfig": + def from_config(cls, config: dict[str, Any]) -> "AQLMConfig": in_group_size = cls.get_from_keys(config, ["in_group_size"]) nbits_per_codebook = cls.get_from_keys(config, ["nbits_per_codebook"]) num_code_books = cls.get_from_keys(config, ["num_codebooks"]) @@ -230,7 +230,7 @@ def __init__(self, quant_config: AQLMConfig): def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], input_size: int, + output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): del output_size # Unused. 
diff --git a/vllm/model_executor/layers/quantization/auto_round.py b/vllm/model_executor/layers/quantization/auto_round.py new file mode 100644 index 000000000000..2d9f5e52bd65 --- /dev/null +++ b/vllm/model_executor/layers/quantization/auto_round.py @@ -0,0 +1,308 @@ +# SPDX-License-Identifier: Apache-2.0 + +from fractions import Fraction +from typing import Any, Optional, Union + +import torch + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + + +class AutoRoundConfig(QuantizationConfig): + """Config class for AutoRound. + Reference: https://arxiv.org/pdf/2309.05516 + """ + + SUPPORTED_BITS = {2, 3, 4, 8} + SUPPORTED_DTYPES = {"int"} + SUPPORTED_FORMATS = {"auto_round:auto_gptq", "auto_round:auto_awq"} + SUPPORTED_BACKENDS = { + "auto", "gptq", "gptq:marlin", "awq", "awq:marlin", "marlin", "ipex" + } + + def __init__( + self, + weight_bits: int, + group_size: int, + sym: bool = True, + packing_format: str = "auto_round:auto_gptq", + block_name_to_quantize: Optional[Union[str, list[str]]] = None, + extra_config: Optional[dict[str, Any]] = None, + data_type: str = "int", + backend: str = "auto", + ) -> None: + super().__init__() + if weight_bits not in self.SUPPORTED_BITS: + raise ValueError(f"Unsupported weight_bits: {weight_bits}, " + f"currently only support {self.SUPPORTED_BITS}") + if data_type not in self.SUPPORTED_DTYPES: + raise ValueError( + f"Unsupported data_type: {data_type}," + f" currently only support {self.SUPPORTED_DTYPES}") + if packing_format not in self.SUPPORTED_FORMATS: + raise ValueError( + f"Unsupported 
packing_format: {packing_format}, " + f"currently only support {self.SUPPORTED_FORMATS}") + if backend not in self.SUPPORTED_BACKENDS: + raise ValueError( + f"Unsupported backend: {backend}, " + f"currently only support {self.SUPPORTED_BACKENDS}") + + self.weight_bits = weight_bits + self.group_size = group_size + self.sym = sym + self.packing_format = packing_format + self.block_name_to_quantize = (block_name_to_quantize.split(",") if + isinstance(block_name_to_quantize, str) + else block_name_to_quantize) + self.extra_config = extra_config + self.data_type = data_type + self.backend = backend + self.pack_factor = Fraction(32, weight_bits) + + def __repr__(self) -> str: + return (f"AutoRoundConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, sym={self.sym})") + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "auto-round" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 60 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return ["quantization_config.json"] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "AutoRoundConfig": + return cls( + weight_bits=cls.get_from_keys(config, ["bits"]), + group_size=cls.get_from_keys(config, ["group_size"]), + sym=cls.get_from_keys(config, ["sym"]), + packing_format=cls.get_from_keys_or(config, ["packing_format"], + "auto_round:auto_gptq"), + block_name_to_quantize=cls.get_from_keys_or( + config, ["block_name_to_quantize", "to_quant_block_names"], + None), + extra_config=cls.get_from_keys_or(config, ["extra_config"], None), + data_type=cls.get_from_keys_or(config, ["data_type"], "int"), + backend=cls.get_from_keys_or(config, ["backend", "vllm_backend"], + "auto"), + ) + + def get_layer_config(self, layer, layer_name: str): + # Priority: extra_config > block_name_to_quantize > type fallback + if self.extra_config and layer_name in 
self.extra_config: + cfg = self.extra_config[layer_name] + return cfg.get("bits", self.weight_bits), cfg.get( + "group_size", self.group_size), cfg.get("sym", self.sym) + + quantized = True + if self.block_name_to_quantize: + quantized = any(name in layer_name + for name in self.block_name_to_quantize) + elif isinstance(layer, ParallelLMHead): + quantized = False + + return (self.weight_bits, self.group_size, + self.sym) if quantized else (16, -1, True) + + def check_quantized(self, weight_bits: int) -> bool: + return weight_bits < 16 + + def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"): + from vllm.model_executor.layers.fused_moe import FusedMoE + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + check_marlin_supported, check_moe_marlin_supports_layer) + + weight_bits, group_size, sym = self.get_layer_config(layer, prefix) + if not self.check_quantized(weight_bits): + if isinstance(layer, (LinearBase, ParallelLMHead)): + return UnquantizedLinearMethod() + else: + return None + + logger.debug("[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s", + prefix, layer.__class__.__name__, weight_bits, group_size, + sym) + if backend == "auto" or "marlin" in backend: + AWQ_TYPE_MAP = { + 4: scalar_types.uint4, + 8: scalar_types.uint8, + } + use_marlin = (weight_bits + in AWQ_TYPE_MAP) and check_marlin_supported( + AWQ_TYPE_MAP[weight_bits], group_size, not sym) + + if isinstance(layer, FusedMoE): + use_marlin = use_marlin and check_moe_marlin_supports_layer( + layer, group_size) + + else: + use_marlin = False + if use_marlin: + from vllm.model_executor.layers.quantization.awq_marlin import ( + AWQMarlinConfig, AWQMarlinLinearMethod, AWQMoEMethod) + quant_args_marlin = AWQMarlinConfig(weight_bits=weight_bits, + group_size=group_size, + zero_point=not sym, + lm_head_quantized=False, + full_config={}, + modules_to_not_convert=[]) + else: + from vllm.model_executor.layers.quantization.awq import ( + AWQConfig, AWQLinearMethod) + 
quant_args = AWQConfig( + weight_bits=weight_bits, + group_size=group_size, + zero_point=not sym, + ) + + if isinstance(layer, FusedMoE): + if use_marlin: + return AWQMoEMethod(quant_args_marlin) + from vllm.model_executor.layers.quantization.moe_wna16 import ( + MoeWNA16Config) + config = { + "quant_method": "awq", + "bits": weight_bits, + "group_size": group_size, + "zero_point": not sym, + "lm_head": False, + } + return MoeWNA16Config.from_config(config).get_quant_method( + layer, prefix) + + if isinstance(layer, (LinearBase, ParallelLMHead)): + if use_marlin: + return AWQMarlinLinearMethod(quant_args_marlin) + else: + return AWQLinearMethod(quant_args) + return None + + def apply_gptq_quant_layer(self, + layer, + prefix: str, + backend: str = "auto"): + from vllm.model_executor.layers.fused_moe import FusedMoE + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + check_marlin_supported, check_moe_marlin_supports_layer) + weight_bits, group_size, sym = self.get_layer_config(layer, prefix) + if not self.check_quantized(weight_bits): + if isinstance(layer, (LinearBase, ParallelLMHead)): + return UnquantizedLinearMethod() + else: + return None + + logger.debug("[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s", + prefix, layer.__class__.__name__, weight_bits, group_size, + sym) + if backend == "auto" or "marlin" in backend: + GPTQ_TYPE_MAP = { + (4, True): scalar_types.uint4b8, + (8, True): scalar_types.uint8b128, + } + use_marlin = ((weight_bits, sym) in GPTQ_TYPE_MAP + and check_marlin_supported( + GPTQ_TYPE_MAP[(weight_bits, sym)], + group_size, + has_zp=not sym)) + if isinstance(layer, FusedMoE): + use_marlin = use_marlin and check_moe_marlin_supports_layer( + layer, group_size) + else: + use_marlin = False + if use_marlin: + from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig, GPTQMarlinLinearMethod, GPTQMarlinMoEMethod) + quant_args_marlin = GPTQMarlinConfig(weight_bits=weight_bits, + 
group_size=group_size, + is_sym=sym, + lm_head_quantized=False, + desc_act=False, + dynamic={}, + full_config={}) + else: + from vllm.model_executor.layers.quantization.gptq import ( + GPTQConfig, GPTQLinearMethod) + quant_args = GPTQConfig(weight_bits=weight_bits, + group_size=group_size, + lm_head_quantized=False, + desc_act=False, + dynamic={}) + + if isinstance(layer, FusedMoE): + if use_marlin: + from vllm.model_executor.layers.quantization.moe_wna16 import ( + MoeWNA16Config) + config = { + "quant_method": "gptq", + "bits": weight_bits, + "group_size": group_size, + "sym": sym, + "lm_head": False, + } + return MoeWNA16Config.from_config(config).get_quant_method( + layer, prefix) + return GPTQMarlinMoEMethod(quant_args_marlin) + + if isinstance(layer, (LinearBase, ParallelLMHead)): + if use_marlin: + return GPTQMarlinLinearMethod(quant_args_marlin) + else: + return GPTQLinearMethod(quant_args) + + return None + + def apply_ipex_quant_layer(self, layer, prefix: str): + weight_bits, group_size, sym = self.get_layer_config(layer, prefix) + if not self.check_quantized(weight_bits): + if isinstance(layer, (LinearBase, ParallelLMHead)): + return UnquantizedLinearMethod() + else: + return None + from vllm.model_executor.layers.quantization.ipex_quant import ( + IPEXAWQLinearMethod, IPEXConfig, IPEXGPTQLinearMethod) + if isinstance(layer, (LinearBase, ParallelLMHead)): + if "awq" in self.packing_format: + config = IPEXConfig(method="awq", + weight_bits=weight_bits, + group_size=group_size) + return IPEXAWQLinearMethod(config) + elif "gptq" in self.packing_format: + config = IPEXConfig(method="gptq", + weight_bits=weight_bits, + group_size=group_size) + return IPEXGPTQLinearMethod(config) + else: + raise ValueError( + f"ipex backend only supports awq " + f"and gtpq format,but got {self.packing_format}") + else: + return None + + def get_quant_method(self, layer: torch.nn.Module, prefix: str): + if (current_platform.is_cpu() or current_platform.is_xpu() + or 
self.backend == "ipex"): + return self.apply_ipex_quant_layer(layer, prefix) + if "gptq" in self.packing_format or "gptq" in self.backend: + return self.apply_gptq_quant_layer(layer, prefix) + if "awq" in self.packing_format or "awq" in self.backend: + return self.apply_awq_quant_layer(layer, prefix) diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index cfc31ae20549..4660c28c8de4 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional +from typing import Any, Optional import torch @@ -25,7 +25,7 @@ def __init__( weight_bits: int, group_size: int, zero_point: bool, - modules_to_not_convert: Optional[List[str]] = None, + modules_to_not_convert: Optional[list[str]] = None, ) -> None: super().__init__() self.weight_bits = weight_bits @@ -48,7 +48,7 @@ def __repr__(self) -> str: def get_name(self) -> QuantizationMethods: return "awq" - def get_supported_act_dtypes(self) -> List[torch.dtype]: + def get_supported_act_dtypes(self) -> list[torch.dtype]: return [torch.half] @classmethod @@ -57,7 +57,7 @@ def get_min_capability(cls) -> int: return 75 @staticmethod - def get_config_filenames() -> List[str]: + def get_config_filenames() -> list[str]: return [ "quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq @@ -65,7 +65,7 @@ def get_config_filenames() -> List[str]: ] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "AWQConfig": + def from_config(cls, config: dict[str, Any]) -> "AWQConfig": weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) group_size = cls.get_from_keys(config, ["q_group_size", "group_size"]) zero_point = cls.get_from_keys(config, ["zero_point"]) @@ -82,7 +82,7 @@ def get_quant_method(self, layer: torch.nn.Module, return None -def 
is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]): +def is_layer_skipped_awq(prefix: str, modules_to_not_convert: list[str]): return any(module_name in prefix for module_name in modules_to_not_convert) @@ -98,7 +98,7 @@ def __init__(self, quant_config: AWQConfig): def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], input_size: int, + output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): if input_size_per_partition % self.quant_config.group_size != 0: diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 556166f19f25..0c8d082bb428 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Optional import torch from torch.nn import Parameter @@ -46,8 +46,8 @@ class AWQMarlinConfig(QuantizationConfig): def __init__(self, weight_bits: int, group_size: int, zero_point: bool, lm_head_quantized: bool, - modules_to_not_convert: Optional[List[str]], - full_config: Dict[str, Any]) -> None: + modules_to_not_convert: Optional[list[str]], + full_config: dict[str, Any]) -> None: super().__init__() self.pack_factor = 32 // weight_bits # packed into int32 self.group_size = group_size @@ -79,7 +79,7 @@ def get_name(cls) -> QuantizationMethods: return "awq_marlin" @classmethod - def get_supported_act_dtypes(cls) -> List[torch.dtype]: + def get_supported_act_dtypes(cls) -> list[torch.dtype]: return [torch.half, torch.bfloat16] @classmethod @@ -87,11 +87,11 @@ def get_min_capability(cls) -> int: return 80 @classmethod - def get_config_filenames(cls) -> List[str]: + def get_config_filenames(cls) -> list[str]: return ["quantize_config.json"] @classmethod - 
def from_config(cls, config: Dict[str, Any]) -> "AWQMarlinConfig": + def from_config(cls, config: dict[str, Any]) -> "AWQMarlinConfig": weight_bits = cls.get_from_keys(config, ["bits"]) group_size = cls.get_from_keys(config, ["group_size"]) zero_point = cls.get_from_keys(config, ["zero_point"]) @@ -150,7 +150,7 @@ def get_quant_method(self, layer: torch.nn.Module, return None @classmethod - def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]): + def is_awq_marlin_compatible(cls, quant_config: dict[str, Any]): # Extract data from quant config. quant_method = quant_config.get("quant_method", "").lower() num_bits = quant_config.get("bits") @@ -189,7 +189,7 @@ def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], + output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index 8cf058b406fb..c9533da9d46e 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -2,7 +2,7 @@ import inspect from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type +from typing import TYPE_CHECKING, Any, Optional import torch from torch import nn @@ -48,7 +48,7 @@ def process_weights_after_loading(self, layer: nn.Module) -> None: def method_has_implemented_embedding( - method_class: Type[QuantizeMethodBase]) -> bool: + method_class: type[QuantizeMethodBase]) -> bool: """ Not all quant methods have embedding implemented, so we need to check that it exists for our given method. 
We check this by making sure the function @@ -68,7 +68,7 @@ class QuantizationConfig(ABC): def __init__(self): super().__init__() # mapping is updated by models as they initialize - self.packed_modules_mapping: Dict[str, List[str]] = dict() + self.packed_modules_mapping: dict[str, list[str]] = dict() @abstractmethod def get_name(self) -> QuantizationMethods: @@ -76,7 +76,7 @@ def get_name(self) -> QuantizationMethods: raise NotImplementedError @abstractmethod - def get_supported_act_dtypes(self) -> List[torch.dtype]: + def get_supported_act_dtypes(self) -> list[torch.dtype]: """List of supported activation dtypes.""" raise NotImplementedError @@ -93,13 +93,13 @@ def get_min_capability(cls) -> int: @staticmethod @abstractmethod - def get_config_filenames() -> List[str]: + def get_config_filenames() -> list[str]: """List of filenames to search for in the model directory.""" raise NotImplementedError @classmethod @abstractmethod - def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig": + def from_config(cls, config: dict[str, Any]) -> "QuantizationConfig": """Create a config class from the model's quantization config.""" raise NotImplementedError @@ -115,7 +115,7 @@ def override_quantization_method( return None @staticmethod - def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any: + def get_from_keys(config: dict[str, Any], keys: list[str]) -> Any: """Get a value from the model's quantization config.""" for key in keys: if key in config: @@ -124,7 +124,7 @@ def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any: "quantization config.") @staticmethod - def get_from_keys_or(config: Dict[str, Any], keys: List[str], + def get_from_keys_or(config: dict[str, Any], keys: list[str], default: Any) -> Any: """Get a optional value from the model's quantization config.""" try: diff --git a/vllm/model_executor/layers/quantization/bitblas.py b/vllm/model_executor/layers/quantization/bitblas.py index ab858d72034a..1cd12bb76317 100644 --- 
a/vllm/model_executor/layers/quantization/bitblas.py +++ b/vllm/model_executor/layers/quantization/bitblas.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional +from typing import Any, Optional import torch @@ -105,7 +105,7 @@ def get_name(cls) -> QuantizationMethods: return "bitblas" @classmethod - def get_supported_act_dtypes(cls) -> List[torch.dtype]: + def get_supported_act_dtypes(cls) -> list[torch.dtype]: return [torch.half, torch.bfloat16] @classmethod @@ -114,12 +114,12 @@ def get_min_capability(cls) -> int: return 70 @classmethod - def get_config_filenames(cls) -> List[str]: + def get_config_filenames(cls) -> list[str]: return ["quantize_config.json"] @staticmethod - def get_from_keys(config: Dict[str, Any], - keys: List[str], + def get_from_keys(config: dict[str, Any], + keys: list[str], default: Any = None) -> Any: """Get a value from the model's quantization config.""" for key in keys: @@ -128,7 +128,7 @@ def get_from_keys(config: Dict[str, Any], return default @classmethod - def from_config(cls, config: Dict[str, Any]) -> "BitBLASConfig": + def from_config(cls, config: dict[str, Any]) -> "BitBLASConfig": weight_bits = cls.get_from_keys(config, ["bits"]) group_size = cls.get_from_keys(config, ["group_size"], -1) desc_act = cls.get_from_keys(config, ["desc_act"], False) @@ -193,7 +193,7 @@ def create_weights_gptq( self, layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], + output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, @@ -329,7 +329,7 @@ def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], + output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index a472779d930b..049ce7a7191d 100644 
--- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional +from typing import Any, Optional import torch @@ -29,7 +29,7 @@ def __init__( bnb_4bit_use_double_quant: bool = False, llm_int8_enable_fp32_cpu_offload: bool = False, llm_int8_has_fp16_weight: bool = False, - llm_int8_skip_modules: Optional[List[str]] = None, + llm_int8_skip_modules: Optional[list[str]] = None, llm_int8_threshold: float = 6.0, ) -> None: super().__init__() @@ -61,7 +61,7 @@ def get_name(self) -> QuantizationMethods: return "bitsandbytes" @classmethod - def get_supported_act_dtypes(self) -> List[torch.dtype]: + def get_supported_act_dtypes(self) -> list[torch.dtype]: return [torch.float32, torch.float16, torch.bfloat16] @classmethod @@ -69,13 +69,13 @@ def get_min_capability(cls) -> int: return 70 @staticmethod - def get_config_filenames() -> List[str]: + def get_config_filenames() -> list[str]: return [ "adapter_config.json", ] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig": + def from_config(cls, config: dict[str, Any]) -> "BitsAndBytesConfig": def get_safe_value(config, keys, default_value=None): try: @@ -130,7 +130,7 @@ def get_quant_method(self, layer: torch.nn.Module, return None -def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]): +def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]): # Split the prefix into its dot-separated components components = prefix.split('.') @@ -169,7 +169,7 @@ def __init__(self, quant_config: BitsAndBytesConfig): def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], input_size: int, + output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): from bitsandbytes.nn import Int8Params diff --git 
a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 0585c09bd84b..27547f315fef 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from contextlib import suppress -from typing import Any, Dict, List, Literal, Optional, Tuple, cast +from typing import Any, Literal, Optional, cast import torch from compressed_tensors.config import (CompressionFormat, @@ -23,9 +23,10 @@ CompressedTensorsMoEMethod) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24, - CompressedTensorsScheme, CompressedTensorsW4A16Sparse24, - CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, - CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) + CompressedTensorsScheme, CompressedTensorsW4A16Fp4, + CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, + CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, + CompressedTensorsWNA16) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( find_matched_target, is_activation_quantization_format, should_ignore_layer) @@ -37,20 +38,20 @@ __all__ = ["CompressedTensorsLinearMethod"] SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config" -QUANTIZATION_SCHEME_MAP_TYPE = Dict[str, Optional[Dict[str, QuantizationArgs]]] +QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, QuantizationArgs]]] class CompressedTensorsConfig(QuantizationConfig): def __init__( self, - target_scheme_map: Dict[str, Any], - ignore: List[str], + target_scheme_map: dict[str, Any], + ignore: list[str], quant_format: str, - sparsity_scheme_map: Dict[str, SparsityCompressionConfig], - sparsity_ignore_list: List[str], - kv_cache_scheme: 
Optional[Dict[str, Any]] = None, - config: Optional[Dict[str, Any]] = None, + sparsity_scheme_map: dict[str, SparsityCompressionConfig], + sparsity_ignore_list: list[str], + kv_cache_scheme: Optional[dict[str, Any]] = None, + config: Optional[dict[str, Any]] = None, ): super().__init__() self.ignore = ignore @@ -65,7 +66,7 @@ def __init__( def get_linear_method(self) -> "CompressedTensorsLinearMethod": return CompressedTensorsLinearMethod(self) - def get_supported_act_dtypes(cls) -> List[torch.dtype]: + def get_supported_act_dtypes(cls) -> list[torch.dtype]: return [torch.float16, torch.bfloat16] @classmethod @@ -101,8 +102,8 @@ def get_quant_method( return None @classmethod - def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": - ignore: List[str] = cast(List[str], config.get("ignore", [])) + def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig": + ignore: list[str] = cast(list[str], config.get("ignore", [])) quant_format = cast(str, config.get("format")) target_scheme_map = cls._quantization_scheme_map_from_config( config=config) @@ -120,8 +121,8 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": @classmethod def _parse_sparsity_config( - cls, config: Dict[str, Any] - ) -> Tuple[Dict[str, SparsityCompressionConfig], List[str]]: + cls, config: dict[str, Any] + ) -> tuple[dict[str, SparsityCompressionConfig], list[str]]: """ :param config: The `quantization_config` dictionary from config.json :return: A tuple with two elements @@ -134,7 +135,7 @@ def _parse_sparsity_config( sparsity_config = SparsityCompressionConfig.model_validate( sparsity_config) - sparse_scheme_map: Dict[str, SparsityCompressionConfig] = { + sparse_scheme_map: dict[str, SparsityCompressionConfig] = { target: sparsity_config for target in sparsity_config.targets or list() } @@ -143,13 +144,13 @@ def _parse_sparsity_config( @classmethod def _quantization_scheme_map_from_config( - cls, config: Dict[str, Any]) -> 
QUANTIZATION_SCHEME_MAP_TYPE: + cls, config: dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE: """ :param config: The `quantization_config` dictionary from config.json :return: A dictionary mapping target layer names to their corresponding quantization_args for weights and input activations """ - target_scheme_map: Dict[str, Any] = dict() + target_scheme_map: dict[str, Any] = dict() quant_format = cast(str, config.get("format")) # The quant_config has multiple config_groups, each containing @@ -187,7 +188,7 @@ def _quantization_scheme_map_from_config( return target_scheme_map @classmethod - def get_config_filenames(cls) -> List[str]: + def get_config_filenames(cls) -> list[str]: return [] def _check_scheme_supported(self, @@ -216,6 +217,21 @@ def _check_scheme_supported(self, else: return False + def _is_fp4a16_nvfp4(self, weight_quant: BaseModel, + input_quant: BaseModel): + + is_weight_only = weight_quant is not None and input_quant is None + is_group_quant = ( + weight_quant.strategy == QuantizationStrategy.GROUP.value) + is_symmetric = weight_quant.symmetric + + is_group_size_16 = weight_quant.group_size == 16 + is_float_type = weight_quant.type == QuantizationType.FLOAT + is_4_bits = weight_quant.num_bits == 4 + + return (is_weight_only and is_group_quant and is_float_type + and is_4_bits and is_group_size_16 and is_symmetric) + def _is_static_tensor_w8a8(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 @@ -315,6 +331,9 @@ def _get_scheme_from_parts( input_quant: BaseModel) -> "CompressedTensorsScheme": # Detect If Mixed Precision + if self._is_fp4a16_nvfp4(weight_quant, input_quant): + return CompressedTensorsW4A16Fp4() + if self._is_wNa16_group_channel(weight_quant, input_quant): if (self.quant_format == CompressionFormat.marlin_24.value and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS): @@ -546,7 +565,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> 
None: def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], input_size: int, + output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): """ @@ -592,7 +611,7 @@ def __init__(self, quant_config: CompressedTensorsConfig): super().__init__(quant_config) @staticmethod - def validate_kv_cache_scheme(kv_cache_scheme: Optional[Dict[str, Any]]): + def validate_kv_cache_scheme(kv_cache_scheme: Optional[dict[str, Any]]): """ Validator for the kv cache scheme. Useful for controlling the kv cache quantization schemes, that are being supported in vLLM diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index ae16a20cfaab..9241ceeb4db2 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -2,25 +2,32 @@ import enum from enum import Enum -from typing import Callable, List, Optional +from typing import Callable, Optional import torch from compressed_tensors import CompressionFormat from compressed_tensors.quantization import (ActivationOrdering, QuantizationStrategy) +import vllm.envs as envs import vllm.model_executor.layers.fused_moe # noqa from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) -from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - WNA16_SUPPORTED_BITS) +from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa + WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP) from vllm.model_executor.layers.quantization.utils import replace_parameter +from 
vllm.model_executor.layers.quantization.utils.marlin_utils import ( + check_moe_marlin_supports_layer, marlin_make_workspace_new, + marlin_moe_permute_scales) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + prepare_moe_fp8_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -54,18 +61,20 @@ def get_moe_method( "input_activations") if quant_config._is_wNa16_group_channel(weight_quant, input_quant): - # Prefer to use the non-marlin kernel when: - # 1. Many experts (MarlinMoE gives poor performance when >= 16) - # 2. Non-FP16 dtype (MarlinMoE only supports FP16) - # 3. Actorder is not group/dynamic (g_idx is unsupported) - # 4. Scaled are grouped (channelwise is unsupported) - if ((layer.local_num_experts >= 16 - or layer.params_dtype != torch.float16) and - weight_quant.actorder not in (ActivationOrdering.GROUP, - ActivationOrdering.DYNAMIC) - and weight_quant.strategy in QuantizationStrategy.GROUP): + # group_size=None means channelwise + group_size = weight_quant.group_size or -1 + # Prefer to use the MarlinMoE kernel when it is supported. + if not check_moe_marlin_supports_layer(layer, group_size): + if (weight_quant.strategy in QuantizationStrategy.GROUP and + weight_quant.actorder in (ActivationOrdering.GROUP, + ActivationOrdering.DYNAMIC)): + raise ValueError( + "WNA16MoE is not supported with actorder=group/dynamic." 
+ ) + logger.info_once("Using CompressedTensorsWNA16MoEMethod") return CompressedTensorsWNA16MoEMethod(quant_config) else: + logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") return CompressedTensorsWNA16MarlinMoEMethod(quant_config) elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) and layer.activation == "silu"): @@ -109,10 +118,28 @@ def __init__( "For FP8 Fused MoE layer, we require either per tensor or " "channelwise, dynamic per token quantization.") + # For GPUs that lack FP8 hardware support, we can leverage the Marlin + # kernel for fast weight-only FP8 quantization + self.use_marlin = (not current_platform.has_device_capability(89) + or envs.VLLM_TEST_FORCE_FP8_MARLIN) + # Disable marlin for rocm + if current_platform.is_rocm(): + self.use_marlin = False + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + is_rocm_aiter_moe_enabled) + + self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() + def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): + layer.intermediate_size_per_partition = intermediate_size_per_partition + layer.hidden_size = hidden_size + layer.num_experts = num_experts + layer.orig_dtype = params_dtype + layer.weight_block_size = None + params_dtype = torch.float8_e4m3fn # WEIGHTS @@ -253,11 +280,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False) - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled) - # Property to determine if AITER is used - if is_rocm_aiter_moe_enabled(): + if self.rocm_aiter_moe_enabled: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501 rocm_aiter_fused_experts, shuffle_weights) @@ -270,11 +294,17 @@ def process_weights_after_loading(self, layer: 
torch.nn.Module) -> None: layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) - self.fused_experts_func = rocm_aiter_fused_experts + self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts else: from vllm.model_executor.layers.fused_moe import fused_experts self.fused_experts_func = fused_experts + if self.use_marlin: + prepare_moe_fp8_layer_for_marlin(layer, False) + # Activations not quantized for marlin. + del layer.w13_input_scale + del layer.w2_input_scale + def apply( self, layer: torch.nn.Module, @@ -306,6 +336,40 @@ def apply( scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) + if self.rocm_aiter_moe_enabled: + return self.rocm_aiter_fused_experts_func( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + use_fp8_w8a8=True, + per_channel_quant=self.weight_quant.strategy == + QuantizationStrategy.CHANNEL, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale) + if self.use_marlin: + assert activation == "silu", ( + f"{activation} not supported for Marlin MoE.") + assert not apply_router_weight_on_input, ( + "Apply router weight on input not supported for Marlin MoE.") + return torch.ops.vllm.fused_marlin_moe( + x, + layer.w13_weight, + layer.w2_weight, + layer.w13_weight_scale, + layer.w2_weight_scale, + router_logits, + topk_weights, + topk_ids, + quant_type_id=scalar_types.float8_e4m3fn.id, + global_num_experts=global_num_experts, + expert_map=expert_map) + return self.fused_experts_func( hidden_states=x, w1=layer.w13_weight, @@ -512,7 +576,8 @@ def apply( activation: str = "silu", ) -> torch.Tensor: - assert activation == "silu" + assert activation == "silu", ( + f"{activation} not supported for Cutlass MoE.") topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ 
-705,15 +770,12 @@ def __init__( f"{CompressionFormat.pack_quantized.value} ", "is supported for the following bits: ", f"{WNA16_SUPPORTED_BITS}") + self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits] def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): - assert params_dtype == torch.float16, ( - "float16 is required for MoE compressed models. Set dtype=torch.float16" # noqa: E501 - ) - intermediate_size_full = extra_weight_attrs.pop( "intermediate_size_full") @@ -837,50 +899,6 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, layer.marlin_state = GPTQMarlinState.REPACK def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - - def replace_tensor(name, new_t): - # It is important to use resize_() here since it ensures - # the same buffer is reused - getattr(layer, name).resize_(new_t.shape) - getattr(layer, name).copy_(new_t) - del new_t - - def get_scale_perms(num_bits: int): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, - group_size: int, num_bits: int): - scale_perm, scale_perm_single = get_scale_perms(num_bits) - if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(scale_perm)))[:, scale_perm] - else: - s = s.reshape((-1, len(scale_perm_single)))[:, - scale_perm_single] - s = s.reshape((-1, size_n)).contiguous() - return s - - def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, - size_n: int, group_size: int, - num_bits: int): - num_experts = s.shape[0] - output = torch.empty((num_experts, s.shape[1], s.shape[2]), - device=s.device, - dtype=s.dtype) - for e in 
range(num_experts): - output[e] = marlin_permute_scales(s[e], size_k, size_n, - group_size, num_bits) - return output - - size_k2 = layer.w2_weight_packed.shape[2] - size_k13 = layer.w13_weight_packed.shape[2] - num_experts = layer.w13_weight_g_idx.shape[0] device = layer.w13_weight_g_idx.device @@ -938,7 +956,7 @@ def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, layer.w13_weight_packed.shape[2], self.num_bits, ) - replace_tensor("w13_weight_packed", marlin_w13_qweight) + replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight) marlin_w2_qweight = ops.gptq_marlin_moe_repack( layer.w2_weight_packed, layer.w2_g_idx_sort_indices, @@ -946,25 +964,25 @@ def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, layer.w2_weight_packed.shape[2], self.num_bits, ) - replace_tensor("w2_weight_packed", marlin_w2_qweight) + replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight) # Repack scales marlin_w13_scales = marlin_moe_permute_scales( - layer.w13_weight_scale, - size_k13, - layer.w13_weight_scale.shape[2], - self.group_size, - self.num_bits, + s=layer.w13_weight_scale, + size_k=layer.w13_weight_packed.shape[2], + size_n=layer.w13_weight_scale.shape[2], + group_size=self.group_size, ) - replace_tensor("w13_weight_scale", marlin_w13_scales) + replace_parameter(layer, "w13_weight_scale", marlin_w13_scales) marlin_w2_scales = marlin_moe_permute_scales( - layer.w2_weight_scale, - layer.w2_weight_scale.shape[1] * + s=layer.w2_weight_scale, + size_k=layer.w2_weight_scale.shape[1] * (self.group_size if self.group_size != -1 else self.packed_factor), - size_k2, - self.group_size, - self.num_bits, + size_n=layer.w2_weight_scale.shape[2], + group_size=self.group_size, ) - replace_tensor("w2_weight_scale", marlin_w2_scales) + replace_parameter(layer, "w2_weight_scale", marlin_w2_scales) + + layer.workspace = marlin_make_workspace_new(device, 4) def apply( self, @@ -984,15 +1002,10 @@ def apply( apply_router_weight_on_input: bool = False, activation: str 
= "silu", ) -> torch.Tensor: - assert activation == "silu", "Only SiLU activation is supported." - if expert_map is not None: - raise NotImplementedError( - "Expert Parallelism is not supported for " - "fused Marlin MoE method.") - if apply_router_weight_on_input: - raise NotImplementedError( - "Apply router weight on input is not supported for " - "fused Marlin MoE method.") + assert activation == "silu", ( + f"{activation} not supported for Marlin MoE.") + assert not apply_router_weight_on_input, ( + "Apply router weight on input not supported for Marlin MoE.") topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -1015,11 +1028,14 @@ def apply( router_logits, topk_weights, topk_ids, + quant_type_id=self.quant_type.id, + global_num_experts=global_num_experts, + expert_map=expert_map, g_idx1=layer.w13_weight_g_idx, g_idx2=layer.w2_weight_g_idx, sort_indices1=layer.w13_g_idx_sort_indices, sort_indices2=layer.w2_g_idx_sort_indices, - num_bits=self.num_bits, + workspace=layer.workspace, is_k_full=self.is_k_full) @@ -1203,7 +1219,7 @@ def apply( activation: str = "silu", ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts - assert activation == "silu", "Only SiLU activation is supported." 
+ topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, @@ -1223,6 +1239,7 @@ def apply( topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, + activation=activation, use_int4_w4a16=self.num_bits == 4, use_int8_w8a16=self.num_bits == 8, global_num_experts=global_num_experts, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py index b26c74f2484b..79bf5c108ac2 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -3,6 +3,7 @@ from .compressed_tensors_scheme import CompressedTensorsScheme from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS, CompressedTensorsW4A16Sparse24) +from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8 from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8 from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8 @@ -16,5 +17,5 @@ "CompressedTensorsW8A16Fp8", "CompressedTensorsW4A16Sparse24", "CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8", "WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS", - "CompressedTensors24" + "CompressedTensors24", "CompressedTensorsW4A16Fp4" ] diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index ec805c934e4a..f010bc03418c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Callable, Dict, List, Optional, Tuple +from 
typing import Any, Callable, Optional import torch from compressed_tensors import CompressionFormat, ModelCompressor @@ -31,7 +31,7 @@ def __init__( quantized: bool = False, weight_quant: Optional[QuantizationArgs] = None, input_quant: Optional[QuantizationArgs] = None, - model_compression_config: Optional[Dict[str, Any]] = None, + model_compression_config: Optional[dict[str, Any]] = None, ): self.quantized = quantized self.weight_quant = weight_quant @@ -53,7 +53,7 @@ def create_weights( self, layer: torch.nn.Module, input_size: int, - output_partition_sizes: List[int], + output_partition_sizes: list[int], input_size_per_partition: int, params_dtype: torch.dtype, weight_loader: Callable, @@ -327,9 +327,9 @@ def _process_split( ) return sparsity_compressor.decompress_weight(weight_data) - split_weights: List[torch.Tensor] = [] - split_bitmask: List[torch.Tensor] = [] - split_shape: List[Tuple[int, int]] = [] + split_weights: list[torch.Tensor] = [] + split_bitmask: list[torch.Tensor] = [] + split_shape: list[tuple[int, int]] = [] if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)): split_weights = torch.split(compressed, layer.logical_widths) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py index 535ea6b32cfb..6ea31e50caa7 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, List, Optional +from typing import Callable, Optional import torch from torch.nn import Parameter @@ -58,7 +58,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.meta = Parameter(layer.meta.data, requires_grad=False) def create_weights(self, 
layer: torch.nn.Module, input_size: int, - output_partition_sizes: List[int], + output_partition_sizes: list[int], input_size_per_partition: int, params_dtype: torch.dtype, weight_loader: Callable, **kwargs): diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py new file mode 100644 index 000000000000..cf60b34ba78a --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Callable, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, prepare_fp4_layer_for_marlin) +from vllm.model_executor.parameter import (GroupQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) + +__all__ = ["CompressedTensorsW4A16Fp4"] + + +class CompressedTensorsW4A16Fp4(CompressedTensorsScheme): + + def __init__(self): + self.group_size = 16 + + @classmethod + def get_min_capability(cls) -> int: + # dont restrict as emulations + return 80 + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + + # Weight + weight = ModelWeightParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // 2, + dtype=torch.uint8), + input_dim=1, + output_dim=0, + 
weight_loader=weight_loader) + layer.register_parameter("weight_packed", weight) + + # Global Weight Scale + weight_global_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("weight_global_scale", weight_global_scale) + + # Per Group Weight Scale + weight_scale = GroupQuantScaleParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // self.group_size, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + + layer.register_parameter("weight_scale", weight_scale) + + def process_weights_after_loading(self, layer) -> None: + # Process parameters for marlin repacking + + # Rename weight_packed to weight that marlin expects + layer.weight = Parameter(layer.weight_packed.data, requires_grad=False) + del layer.weight_packed + # Rename weight_global_scale to weight_scale_2 that marlin expects + # Note: ct stores the inverse of what is expected by the marlin kernel + layer.weight_scale_2 = Parameter( + 1 / layer.weight_global_scale.max().to(torch.float32), + requires_grad=False) + del layer.weight_global_scale + + prepare_fp4_layer_for_marlin(layer) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return apply_fp4_marlin_linear(input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_scale_2=layer.weight_scale_2, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py index 1b54e154ecb0..61e4918ca47f 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +++ 
b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, List, Optional +from typing import Callable, Optional import torch from compressed_tensors.quantization import QuantizationStrategy @@ -58,7 +58,7 @@ def process_weights_after_loading(self, layer) -> None: prepare_fp8_layer_for_marlin(layer) def create_weights(self, layer: torch.nn.Module, input_size: int, - output_partition_sizes: List[int], + output_partition_sizes: list[int], input_size_per_partition: int, params_dtype: torch.dtype, weight_loader: Callable, **kwargs): diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index e99a452963f4..99bb73b71e9f 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, List, Optional +from typing import Callable, Optional import torch from compressed_tensors.quantization import QuantizationStrategy @@ -90,7 +90,7 @@ def process_weights_after_loading(self, layer) -> None: layer.input_scale = None def create_weights(self, layer: torch.nn.Module, - output_partition_sizes: List[int], + output_partition_sizes: list[int], input_size_per_partition: int, params_dtype: torch.dtype, weight_loader: Callable, **kwargs): diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py index 08d86a4e5ddd..7792ce86553c 100644 --- 
a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, List, Optional, Set +from typing import Callable, Optional import torch from compressed_tensors.quantization import QuantizationStrategy @@ -19,7 +19,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): - _kernel_backends_being_used: Set[str] = set() + _kernel_backends_being_used: set[str] = set() def __init__(self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool): @@ -33,7 +33,7 @@ def get_min_capability(cls) -> int: return 75 def create_weights(self, layer: torch.nn.Module, - output_partition_sizes: List[int], + output_partition_sizes: list[int], input_size_per_partition: int, params_dtype: torch.dtype, weight_loader: Callable, **kwargs): diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index 3535dd3f3f14..a33c58acb045 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, List, Optional, Set +from typing import Callable, Optional import torch from compressed_tensors.quantization import ActivationOrdering @@ -35,7 +35,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): - _kernel_backends_being_used: Set[str] = set() + _kernel_backends_being_used: set[str] = set() def __init__(self, strategy: str, @@ -70,7 +70,7 @@ def get_min_capability(cls) -> int: return 80 def create_weights(self, layer: torch.nn.Module, output_size: int, - input_size: int, 
output_partition_sizes: List[int], + input_size: int, output_partition_sizes: list[int], input_size_per_partition: int, params_dtype: torch.dtype, weight_loader: Callable, **kwargs): diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py b/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py index d5d98ee8ba4d..2380d35702c6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Type +from typing import Optional import torch @@ -126,7 +126,7 @@ def triton_scaled_mm(input: torch.Tensor, weight: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, - out_dtype: Type[torch.dtype], + out_dtype: type[torch.dtype], bias: Optional[torch.Tensor] = None, block_size_m: int = 32, block_size_n: int = 32, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index 85ae1d5cb787..75e81c4dd49d 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 -import re +from collections.abc import Iterable, Mapping from types import MappingProxyType -from typing import Iterable, List, Mapping, Optional +from typing import Optional +import regex as re from compressed_tensors import CompressionFormat from torch.nn import Module @@ -20,7 +21,7 @@ def is_activation_quantization_format(format: str) -> bool: def should_ignore_layer( layer_name: Optional[str], ignore: Iterable[str] = tuple(), - fused_mapping: Mapping[str, List[str]] = MappingProxyType({}) + fused_mapping: Mapping[str, list[str]] = MappingProxyType({}) ) -> bool: if layer_name is None: return False @@ -84,7 +85,7 @@ 
def find_matched_target( layer_name: Optional[str], module: Module, targets: Iterable[str], - fused_mapping: Mapping[str, List[str]] = MappingProxyType({}) + fused_mapping: Mapping[str, list[str]] = MappingProxyType({}) ) -> str: """ Helper function to look up which "target" in the compressed-tensors @@ -171,7 +172,7 @@ def _is_equal_or_regex_match(value: str, def _match_fused_layer( layer_name: str, target_layers: Iterable[str], - fused_mapping: Mapping[str, List[str]]) -> Optional[str]: + fused_mapping: Mapping[str, list[str]]) -> Optional[str]: """ Match a fused layer name to its corresponding individual layer in target_layers. Returns first value in fused_mapping which matches targets @@ -201,7 +202,7 @@ def _match_fused_layer( ] # for each unfused component, find a match in targets - unfused_matches: List[Optional[str]] = [] + unfused_matches: list[Optional[str]] = [] for unfused in unfused_paths: for target in target_layers: if _is_equal_or_regex_match(unfused, target): diff --git a/vllm/model_executor/layers/quantization/deepspeedfp.py b/vllm/model_executor/layers/quantization/deepspeedfp.py index df7ec3376b55..0c1eaff93e8b 100644 --- a/vllm/model_executor/layers/quantization/deepspeedfp.py +++ b/vllm/model_executor/layers/quantization/deepspeedfp.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional +from typing import Any, Optional import torch import torch.nn as nn @@ -46,7 +46,7 @@ def get_name(cls) -> QuantizationMethods: return "deepspeedfp" @classmethod - def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig": + def from_config(cls, config: dict[str, Any]) -> "DeepSpeedFPConfig": weight_bits = cls.get_from_keys(config, ["bits"]) group_size = cls.get_from_keys(config, ["group_size"]) return cls(weight_bits=weight_bits, group_size=group_size) @@ -55,7 +55,7 @@ def get_linear_method(self) -> "DeepSpeedFPLinearMethod": return DeepSpeedFPLinearMethod(self) @classmethod - def 
get_supported_act_dtypes(cls) -> List[torch.dtype]: + def get_supported_act_dtypes(cls) -> list[torch.dtype]: return [torch.half, torch.bfloat16] @classmethod @@ -64,7 +64,7 @@ def get_min_capability(cls) -> int: return 60 @staticmethod - def get_config_filenames() -> List[str]: + def get_config_filenames() -> list[str]: return [ "quant_config.json", "quantize_config.json", @@ -91,7 +91,7 @@ def __init__(self, quant_config: DeepSpeedFPConfig): def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], + output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index cce95941b714..3601d219df3b 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Optional import torch @@ -25,7 +25,7 @@ def get_name(cls) -> QuantizationMethods: return "experts_int8" @classmethod - def get_supported_act_dtypes(cls) -> List[torch.dtype]: + def get_supported_act_dtypes(cls) -> list[torch.dtype]: return [torch.bfloat16, torch.half] @classmethod @@ -33,11 +33,11 @@ def get_min_capability(cls) -> int: return 80 @classmethod - def get_config_filenames(cls) -> List[str]: + def get_config_filenames(cls) -> list[str]: return [] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "ExpertsInt8Config": + def from_config(cls, config: dict[str, Any]) -> "ExpertsInt8Config": return cls() def get_quant_method(self, layer: torch.nn.Module, diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index 1fa2b3a8eeea..223682ee9765 100644 --- 
a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional +from typing import Any, Optional import torch from torch.nn import Module @@ -28,7 +28,7 @@ class FBGEMMFp8Config(QuantizationConfig): """Config class for FBGEMM Fp8.""" - def __init__(self, ignore_list: List[str], input_scale_ub: float): + def __init__(self, ignore_list: list[str], input_scale_ub: float): super().__init__() self.ignore_list = ignore_list if ignore_list else [] self.input_scale_ub = input_scale_ub @@ -43,7 +43,7 @@ def get_name(cls) -> QuantizationMethods: return "fbgemm_fp8" @classmethod - def get_supported_act_dtypes(cls) -> List[torch.dtype]: + def get_supported_act_dtypes(cls) -> list[torch.dtype]: return [torch.bfloat16, torch.float16] @classmethod @@ -51,11 +51,11 @@ def get_min_capability(cls) -> int: return 80 @classmethod - def get_config_filenames(cls) -> List[str]: + def get_config_filenames(cls) -> list[str]: return [] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "FBGEMMFp8Config": + def from_config(cls, config: dict[str, Any]) -> "FBGEMMFp8Config": ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"]) input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"]) return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub) @@ -63,7 +63,9 @@ def from_config(cls, config: Dict[str, Any]) -> "FBGEMMFp8Config": def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: if isinstance(layer, LinearBase): - if is_layer_skipped(prefix, self.ignore_list): + if is_layer_skipped(prefix=prefix, + ignored_layers=self.ignore_list, + fused_mapping=self.packed_modules_mapping): return UnquantizedLinearMethod() return FBGEMMFp8LinearMethod(self) return None @@ -80,7 +82,7 @@ def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, - 
output_partition_sizes: List[int], + output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index f7056016fe8c..ac9b74945e0c 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 +import functools import importlib.util -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Optional import torch import torch.nn.functional as F @@ -57,14 +58,13 @@ def __init__( self, is_checkpoint_fp8_serialized: bool = False, activation_scheme: str = "dynamic", - ignored_layers: Optional[List[str]] = None, - weight_block_size: Optional[List[int]] = None, + ignored_layers: Optional[list[str]] = None, + weight_block_size: Optional[list[int]] = None, ) -> None: super().__init__() + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized - if is_checkpoint_fp8_serialized: - logger.warning("Detected fp8 checkpoint. 
Please note that the " - "format is experimental and subject to change.") + if activation_scheme not in ACTIVATION_SCHEMES: raise ValueError( f"Unsupported activation scheme {activation_scheme}") @@ -90,7 +90,7 @@ def get_name(cls) -> QuantizationMethods: return "fp8" @classmethod - def get_supported_act_dtypes(cls) -> List[torch.dtype]: + def get_supported_act_dtypes(cls) -> list[torch.dtype]: return [torch.bfloat16, torch.half] @classmethod @@ -98,11 +98,11 @@ def get_min_capability(cls) -> int: return 80 @classmethod - def get_config_filenames(cls) -> List[str]: + def get_config_filenames(cls) -> list[str]: return [] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": + def from_config(cls, config: dict[str, Any]) -> "Fp8Config": quant_method = cls.get_from_keys(config, ["quant_method"]) is_checkpoint_fp8_serialized = ("fp8" in quant_method) activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) @@ -182,6 +182,13 @@ def __init__(self, quant_config: Fp8Config): if current_platform.is_rocm(): self.use_marlin = False + # AITER is only supported on ROCm and only for FP8_FNUZ + # and at the moment are MI300 series + self.use_aiter_and_is_supported = (current_platform.is_rocm() + and envs.VLLM_ROCM_USE_AITER + and envs.VLLM_ROCM_USE_AITER_LINEAR + and current_platform.is_fp8_fnuz()) + self.block_quant = self.quant_config.weight_block_size is not None self.fp8_linear = Fp8LinearOp( # Default to using per_token quantization if cutlass is supported @@ -191,7 +198,7 @@ def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], + output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, @@ -402,6 +409,7 @@ def apply(self, input_scale=layer.input_scale, bias=bias, cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, + use_aiter_and_is_supported=self.use_aiter_and_is_supported, ) return self.fp8_linear.apply(input=x, @@ -426,6 
+434,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): """ def __init__(self, quant_config: Fp8Config): + from vllm.model_executor.layers.fused_moe import fused_experts self.quant_config = quant_config self.block_quant = self.quant_config.weight_block_size is not None @@ -450,6 +459,11 @@ def __init__(self, quant_config: Fp8Config): logger.warning_once( "DeepGemm not supported on the current platform.") + self.fused_experts = functools.partial( # type: ignore + fused_experts, + block_shape=self.quant_config.weight_block_size, + allow_deep_gemm=self.allow_deep_gemm) + def create_weights(self, layer: Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): @@ -581,7 +595,9 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, def process_weights_after_loading(self, layer: Module) -> None: # Lazy import to avoid importing triton too early. from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - expand_weights, is_rocm_aiter_moe_enabled, shuffle_weights) + is_rocm_aiter_moe_enabled, shuffle_weights) + + self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() # TODO (rob): refactor block quant into separate class. if self.block_quant: @@ -608,7 +624,7 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w2_weight = Parameter(w2_weight, requires_grad=False) layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv, requires_grad=False) - if is_rocm_aiter_moe_enabled(): + if self.rocm_aiter_moe_enabled: # reshaping weights is required for aiter moe kernel. shuffled_w13, shuffled_w2 = shuffle_weights( layer.w13_weight.data, layer.w2_weight.data) @@ -655,19 +671,8 @@ def process_weights_after_loading(self, layer: Module) -> None: requires_grad=False) layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) - if is_rocm_aiter_moe_enabled(): + if self.rocm_aiter_moe_enabled: # reshaping weights is required for aiter moe kernel. 
- w13_scales, w2_scales = expand_weights( - layer.w13_weight_scale.data, - layer.w2_weight_scale.data, - expansion_dims=[ - layer.w13_weight.shape[1], layer.w2_weight.shape[1] - ]) - layer.w13_weight_scale = torch.nn.Parameter( - w13_scales.contiguous(), requires_grad=False) - layer.w2_weight_scale = torch.nn.Parameter( - w2_scales.contiguous(), requires_grad=False) - shuffled_w13, shuffled_w2 = shuffle_weights( layer.w13_weight, layer.w2_weight) @@ -740,18 +745,7 @@ def process_weights_after_loading(self, layer: Module) -> None: dq_weight, max_w13_scales[expert_id]) start += shard_size - if is_rocm_aiter_moe_enabled(): - # reshaping weights is required for aiter moe kernel. - expansion_dims = [ - layer.w13_weight.shape[1], layer.w2_weight.shape[1] - ] - max_w13_scales, w2_scales = expand_weights( - max_w13_scales, - layer.w2_weight_scale.data, - expansion_dims=expansion_dims) - layer.w2_weight_scale = torch.nn.Parameter( - w2_scales.contiguous(), requires_grad=False) - + if self.rocm_aiter_moe_enabled: shuffled_w13, shuffled_w2 = shuffle_weights( layer.w13_weight, layer.w2_weight) @@ -769,6 +763,21 @@ def process_weights_after_loading(self, layer: Module) -> None: del layer.w13_input_scale del layer.w2_input_scale + def select_gemm_impl(self, prepare_finalize): + from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( + TritonOrDeepGemmExperts) + + assert not self.use_marlin and not self.rocm_aiter_moe_enabled, ( + "Marlin and ROCm AITER are not supported with all2all yet.") + + experts = TritonOrDeepGemmExperts( + use_fp8_w8a8=True, + block_shape=self.quant_config.weight_block_size, + allow_deep_gemm=self.allow_deep_gemm, + ) + + return experts + def apply( self, layer: torch.nn.Module, @@ -787,8 +796,6 @@ def apply( apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: - from vllm.model_executor.layers.fused_moe import fused_experts - topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, 
router_logits=router_logits, @@ -802,7 +809,30 @@ def apply( e_score_correction_bias=e_score_correction_bias, ) - if self.use_marlin: + if self.rocm_aiter_moe_enabled: + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 + rocm_aiter_fused_experts) + return rocm_aiter_fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + use_fp8_w8a8=True, + apply_router_weight_on_input=apply_router_weight_on_input, + w1_scale=(layer.w13_weight_scale_inv + if self.block_quant else layer.w13_weight_scale), + w2_scale=(layer.w2_weight_scale_inv + if self.block_quant else layer.w2_weight_scale), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + block_shape=self.quant_config.weight_block_size) + elif self.use_marlin: + assert activation == "silu", ( + f"{activation} not supported for Marlin MoE.") + assert not apply_router_weight_on_input, ( + "Apply router weight on input not supported for Marlin MoE.") return torch.ops.vllm.fused_marlin_moe( x, layer.w13_weight, @@ -815,28 +845,26 @@ def apply( quant_type_id=scalar_types.float8_e4m3fn.id, global_num_experts=global_num_experts, expert_map=expert_map) - - return fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - use_fp8_w8a8=True, - global_num_experts=global_num_experts, - apply_router_weight_on_input=apply_router_weight_on_input, - expert_map=expert_map, - w1_scale=(layer.w13_weight_scale_inv - if self.block_quant else layer.w13_weight_scale), - w2_scale=(layer.w2_weight_scale_inv - if self.block_quant else layer.w2_weight_scale), - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - block_shape=self.quant_config.weight_block_size, - allow_deep_gemm=self.allow_deep_gemm, - ) + else: + return self.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + 
topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + use_fp8_w8a8=True, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + w1_scale=(layer.w13_weight_scale_inv + if self.block_quant else layer.w13_weight_scale), + w2_scale=(layer.w2_weight_scale_inv + if self.block_quant else layer.w2_weight_scale), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + ) class Fp8KVCacheMethod(BaseKVCacheMethod): diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index c88152454941..1fcb6d7afc9b 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Optional import gguf import torch @@ -9,7 +9,6 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, FusedMoEMethodBase) from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase @@ -19,6 +18,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.utils import set_weight_attrs +from vllm.utils import direct_register_custom_op logger = init_logger(__name__) @@ -35,7 +35,7 @@ def __repr__(self) -> str: def get_name(self) -> QuantizationMethods: return "gguf" - def get_supported_act_dtypes(self) -> List[torch.dtype]: + def get_supported_act_dtypes(self) -> list[torch.dtype]: return [torch.half, torch.bfloat16, torch.float32] @classmethod @@ -43,11 +43,11 @@ def get_min_capability(cls) -> int: return 60 @classmethod - def get_config_filenames(cls) -> List[str]: + def get_config_filenames(cls) -> list[str]: return [] # no extra 
configs. @classmethod - def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig": + def from_config(cls, config: dict[str, Any]) -> "GGUFConfig": return cls() def get_quant_method(self, layer: torch.nn.Module, @@ -96,8 +96,8 @@ def get_quant_method(self, layer: torch.nn.Module, MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES -def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor, - qweight_type: int) -> torch.Tensor: +def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, + qweight_type: int) -> torch.Tensor: # HACK: when doing chunked prefill we don't generate output tokens # so input to logits generator is empty which causes invalid parameter if x.shape[0] == 0: @@ -130,6 +130,30 @@ def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor, return y +def _fused_mul_mat_gguf_fake( + x: torch.Tensor, + qweight: torch.Tensor, + qweight_type: int, +) -> torch.Tensor: + return torch.empty(x.shape[0], + qweight.shape[0], + dtype=x.dtype, + device=x.device) + + +try: + direct_register_custom_op( + op_name="_fused_mul_mat_gguf", + op_func=_fused_mul_mat_gguf, + mutates_args=[], + fake_impl=_fused_mul_mat_gguf_fake, + ) + fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf + +except AttributeError as error: + raise error + + def _fused_moe_gguf( x: torch.Tensor, w1: torch.Tensor, @@ -138,8 +162,21 @@ def _fused_moe_gguf( topk_ids: torch.Tensor, qweight_type: int, qweight_type2: int, - act, + activation: str, ) -> torch.Tensor: + + def act(x: torch.Tensor): + d = x.shape[-1] // 2 + output_shape = (x.shape[:-1] + (d, )) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + if activation == "silu": + torch.ops._C.silu_and_mul(out, x) + elif activation == "gelu": + torch.ops._C.gelu_and_mul(out, x) + else: + raise ValueError(f"Unsupported activation: {activation}") + return out + # lazy import to avoid triggering triton import in CPU backend from vllm.model_executor.layers.fused_moe.fused_moe import ( moe_align_block_size) @@ -189,12 
+226,12 @@ def _fused_moe_gguf( for ww, ii in zip(w, idx): expert_up = w1[ii] - out = _fuse_mul_mat(inp, expert_up, qweight_type) + out = fused_mul_mat_gguf(inp, expert_up, qweight_type) out = act(out) expert_down = w2[ii] - current_state = _fuse_mul_mat(out, expert_down, - qweight_type2).mul_(ww) + current_state = fused_mul_mat_gguf(out, expert_down, + qweight_type2).mul_(ww) if current_hidden_state is None: current_hidden_state = current_state else: @@ -203,6 +240,78 @@ def _fused_moe_gguf( return out_hidden_states +def _fused_moe_gguf_fake( + x: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + qweight_type: int, + qweight_type2: int, + activation: str, +) -> torch.Tensor: + return torch.empty_like(x) + + +try: + direct_register_custom_op( + op_name="_fused_moe_gguf", + op_func=_fused_moe_gguf, + mutates_args=[], + fake_impl=_fused_moe_gguf_fake, + ) + fused_moe_gguf = torch.ops.vllm._fused_moe_gguf + +except AttributeError as error: + raise error + + +def _apply_gguf_embedding( + x: torch.Tensor, + qweight: torch.Tensor, + qweight_type: int, + hidden_size: int, + dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + if qweight_type in UNQUANTIZED_TYPES: + return torch.embedding(qweight, x) + elif qweight_type in DEQUANT_TYPES: + block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] + x_flat = x.flatten() + assert (hidden_size == qweight.shape[1] // type_size * block_size) + quant = torch.index_select(qweight, dim=0, index=x_flat) + dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size, + x_flat.shape[0], dtype) + return dequant.view(*x.shape, hidden_size) + else: + qweight_type = WeightType(qweight_type) + raise NotImplementedError( + f"Unsupported GGUF quantization type: {qweight_type}") + + +def _apply_gguf_embedding_fake( + x: torch.Tensor, + qweight: torch.Tensor, + qweight_type: int, + hidden_size: int, + dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + return 
torch.empty(x.shape[0], hidden_size, dtype=dtype, device=x.device) + + +try: + direct_register_custom_op( + op_name="_apply_gguf_embedding", + op_func=_apply_gguf_embedding, + mutates_args=[], + fake_impl=_apply_gguf_embedding_fake, + ) + apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding + +except AttributeError as error: + raise error + + class GGUFLinearMethod(LinearMethodBase): """Linear method for GGUF. @@ -215,7 +324,7 @@ def __init__(self, quant_config: GGUFConfig): def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], input_size: int, + output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): self.params_dtype = params_dtype @@ -249,26 +358,76 @@ def create_weights(self, layer: torch.nn.Module, set_weight_attrs(qweight_type, extra_weight_attrs) layer.register_parameter("qweight_type", qweight_type) + def process_weights_after_loading(self, layer: torch.nn.Module): + qweight_type = layer.qweight_type.weight_type + if not (qweight_type in UNQUANTIZED_TYPES + or qweight_type in DEQUANT_TYPES): + qweight_type = WeightType(qweight_type) + raise ValueError( + f"Unsupported GGUF quantization type {qweight_type} in " + f"layer {layer}.") + # For MergedColumnParallelLinear and QKVParallelLinear, we need to + # materialize the padded weight parameter for CUDA Graph compatibility. 
+ self._create_padded_weight_param(layer) + + def _create_padded_weight_param(self, layer: torch.nn.Module): + """Create padded weight parameter for GGUF MergedLinear layer.""" + qweight = layer.qweight + shard_id_map = qweight.shard_id_map + shard_id = qweight.shard_id + if len(data_container := qweight.data_container) > 1: + dtype = {data.dtype for data in data_container} + assert len(dtype) == 1, ValueError( + f"Data container has mixed dtypes: {dtype}") + dtype = next(iter(dtype)) + # concat dim0 and pad dim1 + padded_side = max(x.size(1) for x in data_container) + concat_side = sum(x.size(0) for x in data_container) + # Pad the quantized weights to dense tensor, and create a map + # with the location of each shard in the padded tensor. + padded_data = torch.zeros((concat_side, padded_side), + dtype=dtype, + device=qweight.device) + # (dim0_start, dim0_end, dim1_size) + shard_offset_map = dict[str, tuple[int, int, int]]() + for idx in shard_id: + id_in_container = shard_id_map[idx] + start = sum( + x.size(0) for x in data_container[:id_in_container]) + end = start + data_container[id_in_container].size(0) + size = data_container[id_in_container].size(1) + padded_data[start:end, :size] = data_container[id_in_container] + shard_offset_map[idx] = (start, end, size) + qweight.data_container.clear() + padded_param = Parameter(padded_data, requires_grad=False) + set_weight_attrs(padded_param, vars(qweight)) + set_weight_attrs(padded_param, + {"shard_offset_map": shard_offset_map}) + layer.register_parameter("qweight", padded_param) + def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - shard_id = getattr(layer.qweight, "shard_id", None) + shard_id = layer.qweight.shard_id if shard_id: # dequantize shard weights respectively shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id - qweight = layer.qweight.unbind(0) + qweight = layer.qweight result = [] for idx in shard_id: - q_idx = 
layer.qweight.shard_id_map[idx] + start, end, offset = layer.qweight.shard_offset_map[idx] qweight_type = layer.qweight_type.shard_weight_type[idx] - result.append(_fuse_mul_mat(x, qweight[q_idx], qweight_type)) + result.append( + fused_mul_mat_gguf( + x, qweight[start:end, :offset].contiguous(), + qweight_type)) out = torch.cat(result, axis=1) else: qweight = layer.qweight qweight_type = layer.qweight_type.weight_type - out = _fuse_mul_mat(x, qweight, qweight_type) + out = fused_mul_mat_gguf(x, qweight, qweight_type) if bias is not None: out.add_(bias) return out @@ -338,7 +497,6 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, set_weight_attrs(w2_qweight_type, extra_weight_attrs) layer.register_parameter("w2_qweight_type", w2_qweight_type) - self.act = SiluAndMul() def apply( self, @@ -375,10 +533,10 @@ def apply( custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) - return _fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight, - topk_weights, topk_ids, - layer.w13_qweight_type.weight_type, - layer.w2_qweight_type.weight_type, self.act) + return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight, + topk_weights, topk_ids, + layer.w13_qweight_type.weight_type, + layer.w2_qweight_type.weight_type, activation) class GGUFEmbeddingMethod(GGUFLinearMethod): @@ -392,34 +550,15 @@ def embedding(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor: qweight = layer.qweight qweight_type = layer.qweight_type.weight_type + hidden_size = qweight.tensor_shape[1] - block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] - hidden_size = qweight.shape[1] // type_size * block_size - if qweight_type < 2: - return torch.embedding(qweight, x) - x_flat = x.flatten() - quant = torch.index_select(qweight, dim=0, index=x_flat) - dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size, - x_flat.shape[0], self.params_dtype) - return dequant.view(*x.shape, hidden_size) + 
return apply_gguf_embedding(x, + qweight, + qweight_type, + hidden_size, + dtype=self.params_dtype) class GGUFUninitializedParameter(UninitializedParameter): cls_to_become = Parameter - data_container: List[torch.Tensor] - - def materialize_nested(self) -> Parameter: - dtype = {data.dtype for data in self.data_container} - assert len(dtype) == 1, ValueError( - f"Data container has mixed dtypes: {dtype}") - dtype = next(iter(dtype)) - nested_data = torch.nested.nested_tensor(self.data_container, - device=self.device, - dtype=dtype) - self.data_container.clear() - param = torch.Tensor._make_subclass(self.cls_to_become, - nested_data, - require_grad=False) - for k, v in self.__dict__.items(): - setattr(param, k, v) - return param + data_container: list[torch.Tensor] diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 5059e0cdfd4a..436f1e3ccc1a 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -3,7 +3,7 @@ import enum from enum import Enum from fractions import Fraction -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union import torch from torch.nn.parameter import Parameter @@ -34,11 +34,11 @@ def __init__( group_size: int, desc_act: bool, lm_head_quantized: bool, - dynamic: Dict[str, Dict[str, Union[int, bool]]], + dynamic: dict[str, dict[str, Union[int, bool]]], ) -> None: # GPTQModel use `dynamic` config property to allow per module # quantization config so each module can be individually optimized. - # Format is Dict[str, Dict] where key is a regex string that can + # Format is dict[str, dict] where key is a regex string that can # perform both positive ("+:" prefixed) or negative ("-:" prefixed) # matching of a module. 
# Default to positive match, override base quant config mode, if no @@ -84,7 +84,7 @@ def get_name(cls) -> QuantizationMethods: return "gptq" @classmethod - def get_supported_act_dtypes(cls) -> List[torch.dtype]: + def get_supported_act_dtypes(cls) -> list[torch.dtype]: return [torch.half] @classmethod @@ -93,11 +93,11 @@ def get_min_capability(cls) -> int: return 60 @classmethod - def get_config_filenames(cls) -> List[str]: + def get_config_filenames(cls) -> list[str]: return ["quantize_config.json"] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig": + def from_config(cls, config: dict[str, Any]) -> "GPTQConfig": dynamic = cls.get_from_keys_or(config, ["dynamic"], default={}) dynamic = {} if dynamic is None else dynamic @@ -135,7 +135,7 @@ def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], + output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, diff --git a/vllm/model_executor/layers/quantization/gptq_bitblas.py b/vllm/model_executor/layers/quantization/gptq_bitblas.py index b06c9579d63d..be9510abdffb 100644 --- a/vllm/model_executor/layers/quantization/gptq_bitblas.py +++ b/vllm/model_executor/layers/quantization/gptq_bitblas.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional, Set +from typing import Any, Optional import torch from torch.nn.parameter import Parameter @@ -129,7 +129,7 @@ def get_name(cls) -> QuantizationMethods: return "gptq_bitblas" @classmethod - def get_supported_act_dtypes(cls) -> List[torch.dtype]: + def get_supported_act_dtypes(cls) -> list[torch.dtype]: return [torch.half, torch.bfloat16] @classmethod @@ -137,11 +137,11 @@ def get_min_capability(cls) -> int: return 80 @classmethod - def get_config_filenames(cls) -> List[str]: + def get_config_filenames(cls) -> list[str]: return ["quantize_config.json"] @classmethod - def from_config(cls, config: Dict[str, 
Any]) -> "GPTQBitBLASConfig": + def from_config(cls, config: dict[str, Any]) -> "GPTQBitBLASConfig": weight_bits = cls.get_from_keys(config, ["bits"]) group_size = cls.get_from_keys(config, ["group_size"]) desc_act = cls.get_from_keys(config, ["desc_act"]) @@ -185,7 +185,7 @@ def torch_storage_dtype(self) -> torch.dtype: return self.TORCH_BITBLAS_STORAGE_DTYPE @classmethod - def is_gptq_bitblas_compatible(cls, quant_config: Dict[str, Any]): + def is_gptq_bitblas_compatible(cls, quant_config: dict[str, Any]): # Extract data from quant config. num_bits = quant_config.get("bits") group_size = quant_config.get("group_size") @@ -224,7 +224,7 @@ class GPTQBitBLASLinearMethod(LinearMethodBase): """ kernel_type = BitBLASLinearKernel - _kernel_backends_being_used: Set[str] = set() + _kernel_backends_being_used: set[str] = set() def __init__(self, quant_config: GPTQBitBLASConfig) -> None: self.quant_config = quant_config @@ -236,7 +236,7 @@ def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], + output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 56aafca87e9e..cf012e145ee6 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Callable, Dict, List, Optional, Set, Union +from typing import Any, Callable, Optional, Union import torch @@ -45,8 +45,8 @@ class GPTQMarlinConfig(QuantizationConfig): def __init__(self, weight_bits: int, group_size: int, desc_act: bool, is_sym: bool, lm_head_quantized: bool, - dynamic: Dict[str, Dict[str, Union[int, bool]]], - full_config: Dict[str, Any]) -> None: + dynamic: dict[str, dict[str, Union[int, bool]]], + full_config: dict[str, Any]) -> None: 
super().__init__() if desc_act and group_size == -1: # In this case, act_order == True is the same as act_order == False @@ -55,7 +55,7 @@ def __init__(self, weight_bits: int, group_size: int, desc_act: bool, # GPTQModel use `dynamic` config property to allow per module # quantization config so each module can be individually optimized. - # Format is Dict[str, Dict] where key is a regex string that can + # Format is dict[str, dict] where key is a regex string that can # perform both positive ("+:" prefixed) or negative ("-:" prefixed) # matching of a module. # Default to positive match, override base quant config mode, if no @@ -105,7 +105,7 @@ def get_name(cls) -> QuantizationMethods: return "gptq_marlin" @classmethod - def get_supported_act_dtypes(cls) -> List[torch.dtype]: + def get_supported_act_dtypes(cls) -> list[torch.dtype]: return [torch.half, torch.bfloat16] @classmethod @@ -113,11 +113,11 @@ def get_min_capability(cls) -> int: return 80 @classmethod - def get_config_filenames(cls) -> List[str]: + def get_config_filenames(cls) -> list[str]: return ["quantize_config.json"] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig": + def from_config(cls, config: dict[str, Any]) -> "GPTQMarlinConfig": dynamic = cls.get_from_keys_or(config, ["dynamic"], default={}) dynamic = {} if dynamic is None else dynamic @@ -167,7 +167,7 @@ def get_quant_method(self, layer: torch.nn.Module, GPTQMarlinLinearMethod) @classmethod - def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]): + def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]): quant_method = quant_config.get("quant_method", "").lower() num_bits = quant_config.get("bits") group_size = quant_config.get("group_size") @@ -199,7 +199,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase): quant_config: The GPTQ Marlin quantization config. 
""" - _kernel_backends_being_used: Set[str] = set() + _kernel_backends_being_used: set[str] = set() def __init__(self, quant_config: GPTQMarlinConfig) -> None: self.quant_config = quant_config @@ -212,7 +212,7 @@ def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], + output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, @@ -610,9 +610,9 @@ def apply( activation: str = "silu", ) -> torch.Tensor: assert activation == "silu", "Only SiLU activation is supported." - if apply_router_weight_on_input is not None: + if apply_router_weight_on_input: raise NotImplementedError( - "Apply router weight on input is not supported for" + "Apply router weight on input is not supported for " "fused Marlin MoE method.") topk_weights, topk_ids = FusedMoE.select_experts( diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py index 1fe08e4b34fe..e90416f37791 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin_24.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional +from typing import Any, Optional import torch from torch.nn.parameter import Parameter @@ -90,7 +90,7 @@ def get_name(cls) -> QuantizationMethods: return "gptq_marlin_24" @classmethod - def get_supported_act_dtypes(cls) -> List[torch.dtype]: + def get_supported_act_dtypes(cls) -> list[torch.dtype]: return [torch.half] @classmethod @@ -99,11 +99,11 @@ def get_min_capability(cls) -> int: return 80 @classmethod - def get_config_filenames(cls) -> List[str]: + def get_config_filenames(cls) -> list[str]: return ["quantize_config.json"] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlin24Config": + def from_config(cls, config: dict[str, Any]) -> "GPTQMarlin24Config": weight_bits = 
cls.get_from_keys(config, ["bits"]) group_size = cls.get_from_keys(config, ["group_size"]) return cls(weight_bits, group_size) @@ -146,7 +146,7 @@ def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], + output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, diff --git a/vllm/model_executor/layers/quantization/hqq_marlin.py b/vllm/model_executor/layers/quantization/hqq_marlin.py index 7bd398137e02..a8faf97723cd 100644 --- a/vllm/model_executor/layers/quantization/hqq_marlin.py +++ b/vllm/model_executor/layers/quantization/hqq_marlin.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional +from typing import Any, Optional import torch @@ -32,7 +32,7 @@ def __init__( self, weight_bits: int, group_size: int, - skip_modules: Optional[List[str]] = None, + skip_modules: Optional[list[str]] = None, ) -> None: super().__init__() assert group_size == 64, ("The only supported HQQ group size is " @@ -55,7 +55,7 @@ def get_name(cls) -> QuantizationMethods: return "hqq" @classmethod - def get_supported_act_dtypes(cls) -> List[torch.dtype]: + def get_supported_act_dtypes(cls) -> list[torch.dtype]: return [torch.half, torch.bfloat16] @classmethod @@ -63,11 +63,11 @@ def get_min_capability(cls) -> int: return 80 @classmethod - def get_config_filenames(cls) -> List[str]: + def get_config_filenames(cls) -> list[str]: return ["quantize_config.json"] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "HQQMarlinConfig": + def from_config(cls, config: dict[str, Any]) -> "HQQMarlinConfig": wq_params = (config["quant_config"]["weight_quant_params"]) weight_bits = cls.get_from_keys(wq_params, ["nbits"]) group_size = cls.get_from_keys(wq_params, ["group_size"]) @@ -192,7 +192,7 @@ def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], + output_partition_sizes: 
list[int], input_size: int, output_size: int, params_dtype: torch.dtype, @@ -304,8 +304,10 @@ def apply( marlin_out = ops.gptq_marlin_gemm( x, + None, layer.marlin_qweight, scales, + None, zeros, layer.g_idx, layer.g_idx_sort_indices, @@ -315,7 +317,7 @@ def apply( self.output_size_per_partition, self.input_size_per_partition, True, # is_k_full - True, # has_zp + False, # use atomic add True, # use 32-bit reduce True, # use float zp ) diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py index 212af278ff81..8108c797637d 100644 --- a/vllm/model_executor/layers/quantization/ipex_quant.py +++ b/vllm/model_executor/layers/quantization/ipex_quant.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional +from typing import Any, Optional import torch @@ -14,7 +14,7 @@ from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod from vllm.platforms import current_platform -MIN_IPEX_VERSION = "2.5.0" +MIN_IPEX_VERSION = "2.7.0" class IPEXConfig(QuantizationConfig): @@ -32,7 +32,7 @@ def __init__( method: str, weight_bits: int, group_size: int, - modules_to_not_convert: Optional[List[str]] = None, + modules_to_not_convert: Optional[list[str]] = None, desc_act: Optional[bool] = None, lm_head_quantized: Optional[bool] = None, ) -> None: @@ -63,7 +63,7 @@ def get_name(cls) -> QuantizationMethods: return "ipex" @classmethod - def get_supported_act_dtypes(cls) -> List[torch.dtype]: + def get_supported_act_dtypes(cls) -> list[torch.dtype]: return [torch.bfloat16, torch.float16] @classmethod @@ -71,14 +71,14 @@ def get_min_capability(cls) -> int: return -1 @staticmethod - def get_config_filenames() -> List[str]: + def get_config_filenames() -> list[str]: return [ "quant_config.json", "quantize_config.json", ] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "IPEXConfig": + def from_config(cls, config: dict[str, Any]) -> "IPEXConfig": method = 
cls.get_from_keys(config, ["quant_method"]).lower() if method == "awq": weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) @@ -181,8 +181,6 @@ def apply(self, bias: Optional[torch.Tensor] = None) -> torch.Tensor: reshaped_x = x.reshape(-1, x.shape[-1]) out = layer.ipex_qlinear(reshaped_x) - if bias is not None: - out.add_(bias) return out.reshape(x.shape[:-1] + (layer.ipex_output_size, )) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py index c06befaf3b5a..55ad00b1cf46 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Callable, Optional, Tuple +from typing import Callable, Optional import torch @@ -12,8 +12,8 @@ @dataclass class MPLinearLayerConfig: - full_weight_shape: Tuple[int, int] # [in, out] - partition_weight_shape: Tuple[int, int] + full_weight_shape: tuple[int, int] # [in, out] + partition_weight_shape: tuple[int, int] weight_type: ScalarType act_type: torch.dtype group_size: int @@ -31,7 +31,7 @@ def get_min_capability(cls) -> int: @classmethod @abstractmethod def can_implement(cls, - c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: raise NotImplementedError def __init__(self, @@ -75,7 +75,7 @@ def _transform_param(self, layer: torch.nn.Module, name: Optional[str], torch.nn.Parameter(new_param.data, requires_grad=False)) def _get_weight_params( - self, layer: torch.nn.Module) -> Tuple[ + self, layer: torch.nn.Module) -> tuple[ torch.Tensor, # w_q torch.Tensor, # w_s Optional[torch.Tensor], # w_zp, diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py 
b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index d144bb436104..bb1dc40ad71a 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional, Type +from typing import Optional import vllm.envs as envs from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501 @@ -18,7 +18,7 @@ from vllm.platforms import current_platform # in priority/performance order (when available) -_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ +_POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [ MacheteLinearKernel, AllSparkLinearKernel, MarlinLinearKernel, @@ -29,7 +29,7 @@ def choose_mp_linear_kernel( config: MPLinearLayerConfig, - compute_capability: Optional[int] = None) -> Type[MPLinearKernel]: + compute_capability: Optional[int] = None) -> type[MPLinearKernel]: """ Choose an MPLinearKernel that can implement the given config for the given compute capability. Attempts to choose the best kernel in terms of @@ -46,7 +46,7 @@ def choose_mp_linear_kernel( ValueError: If no kernel can implement the given config. Returns: - Type[MPLinearKernel]: Chosen kernel. + type[MPLinearKernel]: Chosen kernel. 
""" if compute_capability is None: if current_platform is None: diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py index 56fdd6a18e0d..e07177dd675f 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Tuple +from typing import Optional import torch @@ -22,7 +22,7 @@ def get_min_capability(cls) -> int: @classmethod def can_implement(cls, - c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: if c.has_g_idx: return False, "Act reordering currently not supported by AllSpark" diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py index 21452d08b8a1..29e20699184c 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, List, Optional, Tuple +from typing import Optional import torch @@ -21,10 +21,10 @@ class BitBLASLinearKernel(MPLinearKernel): - OPT_FEATURES: List[int] = BITBLAS_OPTIMIZE_FEATURES + OPT_FEATURES: list[int] = BITBLAS_OPTIMIZE_FEATURES ENABLE_TUNING: bool = True MATMUL_LAYOUT: str = "nt" - BITBLAS_DTYPES: Dict[torch.dtype, str] = { + BITBLAS_DTYPES: dict[torch.dtype, str] = { torch.float32: "float32", torch.float16: "float16", torch.bfloat16: "bfloat16", @@ -103,7 +103,7 @@ def get_min_capability(cls) -> int: @classmethod def can_implement(cls, - c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: is_bitblas_installed = True diff 
--git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py index 2706fbb539ab..50d293cf415b 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Tuple +from typing import Optional import torch @@ -25,7 +25,7 @@ def get_min_capability(cls) -> int: @classmethod def can_implement(cls, - c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: if c.has_g_idx and\ c.partition_weight_shape[0] != c.full_weight_shape[0]: return False, "Act reordering currently not supported by Exllama, "\ diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py index b3ffeca4f100..855867fa4a00 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from functools import partial -from typing import Optional, Tuple +from typing import Optional import torch @@ -25,7 +25,7 @@ def get_min_capability(cls) -> int: @classmethod def can_implement(cls, - c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: if c.has_g_idx and\ c.partition_weight_shape[0] != c.full_weight_shape[0]: diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py index 97fcde1618c7..899011f00051 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py +++ 
b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Tuple +from typing import Optional import torch @@ -24,7 +24,7 @@ def get_min_capability(cls) -> int: @classmethod def can_implement(cls, - c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: quant_types = query_marlin_supported_quant_types(c.zero_points) if c.weight_type not in quant_types: diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py index 91e7654053f9..2d92af74bbf9 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional import torch @@ -24,7 +24,7 @@ def get_min_capability(cls) -> int: @classmethod @abstractmethod def can_implement( - cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: + cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: raise NotImplementedError def __init__(self, c: ScaledMMLinearLayerConfig, w_q_param_name: str, @@ -50,7 +50,7 @@ def apply_weights(self, raise NotImplementedError def _get_weight_params( - self, layer: torch.nn.Module) -> Tuple[ + self, layer: torch.nn.Module) -> tuple[ torch.Tensor, # weight torch.Tensor, # weight_scale Optional[torch.Tensor], # input_scale, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 014108e69506..5d58c0489a28 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ 
b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import os -from typing import Dict, List, Optional, Type +from typing import Optional from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import ( AiterScaledMMLinearKernel) @@ -16,7 +16,7 @@ from vllm.platforms import PlatformEnum, current_platform # in priority/performance order (when available) -_POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = { +_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { PlatformEnum.CPU: [CutlassScaledMMLinearKernel], PlatformEnum.CUDA: [CutlassScaledMMLinearKernel], PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel], @@ -27,7 +27,7 @@ def choose_scaled_mm_linear_kernel( config: ScaledMMLinearLayerConfig, compute_capability: Optional[int] = None -) -> Type[ScaledMMLinearKernel]: +) -> type[ScaledMMLinearKernel]: """ Choose an ScaledMMLinearKernel that can implement the given config for the given compute capability. Attempts to choose the best kernel in terms of @@ -44,7 +44,7 @@ def choose_scaled_mm_linear_kernel( ValueError: If no kernel can implement the given config. Returns: - Type[ScaledMMLinearKernel]: Chosen kernel. + type[ScaledMMLinearKernel]: Chosen kernel. 
""" if compute_capability is None: diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index 582b12f76562..6c2c464e6f1b 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Tuple +from typing import Optional import torch @@ -20,7 +20,7 @@ def get_min_capability(cls) -> int: @classmethod def can_implement( - cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: + cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: if not current_platform.is_rocm(): return ( False, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index 047724129522..98a0b30be1f6 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Tuple +from typing import Optional import torch @@ -22,7 +22,7 @@ def get_min_capability(cls) -> int: @classmethod def can_implement( - cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: + cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: if (not current_platform.is_cuda() and not current_platform.is_cpu()): return False, "CutlassScaledMM requires running on CUDA or CPU." 
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py index 5da5df8efaeb..c09ca83d01cb 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Tuple +from typing import Optional import torch @@ -18,7 +18,7 @@ def get_min_capability(cls) -> int: @classmethod def can_implement( - cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: + cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: if current_platform.is_cpu(): return ( False, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index 089314071d39..a97b53b9d7b9 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import warnings -from typing import Optional, Tuple +from typing import Optional import torch from functorch.experimental.control_flow import cond # noqa: F401 @@ -25,7 +25,7 @@ def get_min_capability(cls) -> int: @classmethod def can_implement( - cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: + cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: if not current_platform.is_tpu(): return False, "ScaledMMXLA requires running on TPU." 
diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index 5dff8b09693c..67723c7c91cc 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -124,11 +124,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # These are used in the final Attention.forward() layer._q_scale.copy_(q_scale) layer._prob_scale.copy_(prob_scale) - if q_scale == 1.0 or prob_scale == 1.0: + if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0 + or prob_scale == 1.0): logger.warning_once( - f"Using Q scale {q_scale} and prob scale {prob_scale} " - "with fp8 attention. This may cause accuracy issues. " - "Please make sure Q/prob scaling factors are " + f"Using uncalibrated q_scale {q_scale} and/or prob_scale " + f"{prob_scale} with fp8 attention. This may cause accuracy " + "issues. Please make sure q/prob scaling factors are " "available in the fp8 checkpoint.") del layer.k_scale diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 9ef71a7894d7..2437030c8771 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional +from typing import Any, Optional import torch from torch.nn.parameter import Parameter @@ -68,7 +68,7 @@ def get_name(cls) -> QuantizationMethods: return "marlin" @classmethod - def get_supported_act_dtypes(cls) -> List[torch.dtype]: + def get_supported_act_dtypes(cls) -> list[torch.dtype]: return [torch.half] @classmethod @@ -77,11 +77,11 @@ def get_min_capability(cls) -> int: return 80 @classmethod - def get_config_filenames(cls) -> List[str]: + def get_config_filenames(cls) -> list[str]: return ["quantize_config.json"] @classmethod - def from_config(cls, config: Dict[str, Any]) -> 
"MarlinConfig": + def from_config(cls, config: dict[str, Any]) -> "MarlinConfig": group_size = cls.get_from_keys(config, ["group_size"]) lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) @@ -128,7 +128,7 @@ def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], + output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 828447dd1019..1c5680f952ab 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Optional, Union import torch from torch.nn import Module @@ -9,12 +9,17 @@ from vllm._custom_ops import (cutlass_scaled_fp4_mm, cutlass_scaled_mm_supports_fp4, scaled_fp4_quant) from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, is_fp4_marlin_supported, + prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.quant_utils import ( is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -22,6 +27,7 @@ from vllm.model_executor.parameter import 
(ModelWeightParameter, PerTensorScaleParameter) from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -47,7 +53,7 @@ def get_name(cls) -> QuantizationMethods: return "modelopt" @classmethod - def get_supported_act_dtypes(cls) -> List[torch.dtype]: + def get_supported_act_dtypes(cls) -> list[torch.dtype]: return [torch.bfloat16, torch.half] @classmethod @@ -55,11 +61,11 @@ def get_min_capability(cls) -> int: return 89 @classmethod - def get_config_filenames(cls) -> List[str]: + def get_config_filenames(cls) -> list[str]: return ["hf_quant_config.json"] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config": + def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config": quant_config = cls.get_from_keys(config, ["quantization"]) quant_method = quant_config["quant_algo"] if quant_method not in QUANT_ALGOS: @@ -101,7 +107,7 @@ def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], + output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, @@ -171,7 +177,7 @@ def __init__( self, is_checkpoint_nvfp4_serialized: bool, kv_cache_quant_algo: str, - exclude_modules: List[str], + exclude_modules: list[str], group_size: int = 16, ) -> None: self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized @@ -186,22 +192,22 @@ def __init__( @classmethod def get_name(cls) -> QuantizationMethods: - return "nvfp4" + return "modelopt_fp4" @classmethod - def get_supported_act_dtypes(cls) -> List[torch.dtype]: + def get_supported_act_dtypes(cls) -> list[torch.dtype]: return [torch.bfloat16, torch.half, torch.float8_e4m3fn] @classmethod def get_min_capability(cls) -> int: - return 100 + return 80 @classmethod - def get_config_filenames(cls) -> List[str]: + def get_config_filenames(cls) -> list[str]: return ["hf_quant_config.json"] @classmethod - def from_config(cls, config: 
Dict[str, Any]) -> "ModelOptNvFp4Config": + def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config": quant_config = cls.get_from_keys(config, ["quantization"]) quant_method = quant_config["quant_algo"] if quant_method not in QUANT_ALGOS: @@ -210,25 +216,37 @@ def from_config(cls, config: Dict[str, Any]) -> "ModelOptNvFp4Config": "`hf_quant_config.json` file for your model's " "quant configuration.") is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method) - kv_cache_quant_algo = quant_config["kv_cache_quant_algo"] - group_size = quant_config["group_size"] - exclude_modules = quant_config["exclude_modules"] - if not (group_size and kv_cache_quant_algo and exclude_modules): + if ("group_size" and "kv_cache_quant_algo" + and "exclude_modules") not in quant_config: raise ValueError("NVFP4 quantization requires group size and " "kv_cache_quant_algo specified in " "hf_quant_config.json") + kv_cache_quant_algo = quant_config["kv_cache_quant_algo"] + group_size = quant_config["group_size"] + exclude_modules = quant_config["exclude_modules"] return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo, exclude_modules, group_size) + def is_layer_excluded(self, prefix: str, exclude_modules: list): + import regex as re + for pattern in exclude_modules: + regex_str = pattern.replace('.', r'\.').replace('*', r'.*') + if re.fullmatch(regex_str, prefix): + return True + return False + def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import if isinstance(layer, LinearBase): - if is_layer_skipped(prefix, self.exclude_modules): + if (is_layer_skipped(prefix, self.exclude_modules) + or self.is_layer_excluded(prefix, self.exclude_modules)): return UnquantizedLinearMethod() return ModelOptNvFp4LinearMethod(self) elif isinstance(layer, Attention): return ModelOptFp8KVCacheMethod(self) + elif isinstance(layer, FusedMoE): + return ModelOptNvFp4FusedMoE(self) 
return None @@ -264,15 +282,21 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): def __init__(self, quant_config: ModelOptNvFp4Config): self.quant_config = quant_config self.cutlass_nvfp4_supported = cutlass_fp4_supported() + self.use_marlin = False + if not self.cutlass_nvfp4_supported: - raise ValueError("Current platform does not support NVFP4" - " quantization. Please use Blackwell and above.") + if is_fp4_marlin_supported(): + self.use_marlin = True + else: + raise ValueError("Current platform does not support NVFP4" + " quantization. Please use Blackwell and" + " above.") def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], + output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, @@ -377,6 +401,13 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, requires_grad=False) + layer.weight = Parameter(layer.weight.data, requires_grad=False) + + if self.use_marlin: + prepare_fp4_layer_for_marlin(layer) + del layer.alpha + del layer.input_scale + del layer.weight_scale_swizzled def apply( self, @@ -384,12 +415,19 @@ def apply( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - output_dtype = x.dtype + if self.use_marlin: + return apply_fp4_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_scale_2=layer.weight_scale_2, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias) - # for input only the contracting dimension has a constraint. 
- x_m, _ = x.shape - w_n, _ = layer.weight.shape - output_shape = [x_m, w_n] + output_dtype = x.dtype + output_shape = [x.shape[0], layer.weight.shape[0]] # quantize BF16 or FP16 to (FP4 and interleaved block scale) s_quant = 1 / layer.input_scale @@ -409,3 +447,288 @@ def apply( if bias is not None: out = out + bias return out.view(*output_shape) + + +class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): + """ + MoE Method for FP4 Quantization. + Args: + quant_config: NVFP4 Quant Config + """ + + def __init__(self, quant_config: ModelOptNvFp4Config): + self.quant_config = quant_config + self.cutlass_nvfp4_supported = cutlass_fp4_supported() + self.use_marlin = False + + if not self.cutlass_nvfp4_supported: + if is_fp4_marlin_supported(): + self.use_marlin = True + else: + raise ValueError("Current platform does not support NVFP4" + " quantization. Please use Blackwell and" + " above.") + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + if not self.quant_config.is_checkpoint_nvfp4_serialized: + raise ValueError("NVFP4 quantization was selected, " + " dynamic quantization is not supported.") + + layer.num_experts = num_experts + layer.params_dtype = params_dtype + layer.quant_config = self.quant_config + weight_dtype = torch.uint8 + weight_scale_dtype = torch.float8_e4m3fn + weight_loader = extra_weight_attrs.get("weight_loader") + # GEMM 1 + w13_weight = ModelWeightParameter( + data=torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // 2, + dtype=weight_dtype), + input_dim=1, + output_dim=2, + weight_loader=weight_loader) + layer.register_parameter("w13_weight", w13_weight) + + # GEMM 2 + w2_weight = ModelWeightParameter( + data=torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // 2, + 
dtype=weight_dtype), + input_dim=1, + output_dim=2, + weight_loader=weight_loader) + layer.register_parameter("w2_weight", w2_weight) + + w13_weight_scale = ModelWeightParameter( + data=torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // self.quant_config.group_size, + dtype=weight_scale_dtype), + input_dim=1, + output_dim=2, + weight_loader=weight_loader) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = ModelWeightParameter( + data=torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // + self.quant_config.group_size, + dtype=weight_scale_dtype), + input_dim=1, + output_dim=2, + weight_loader=weight_loader) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}) + + w13_weight_scale_2 = PerTensorScaleParameter( + data=torch.empty(num_experts, 2, dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2) + + w2_weight_scale_2 = PerTensorScaleParameter( + data=torch.empty(num_experts, dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2) + + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + + w13_input_scale = PerTensorScaleParameter(data=torch.empty( + num_experts, 2, dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("w13_input_scale", w13_input_scale) + + w2_input_scale = PerTensorScaleParameter(data=torch.empty( + num_experts, dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("w2_input_scale", w2_input_scale) + + def swizzle_blockscale(self, scale: torch.tensor): + assert (scale.dtype == torch.float8_e4m3fn) + # Pad and blockwise interleave 
weight_scale + scale_ndim = scale.ndim + if scale.ndim == 2: + scale = scale.unsqueeze(0) + assert scale.ndim == 3 + B, M, K = scale.shape + round_up_multiple = lambda x, m: (x + m - 1) // m * m + M_padded = round_up_multiple(M, 128) + K_padded = round_up_multiple(K, 4) + padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype) + padded_scale[:B, :M, :K] = scale + batches, rows, cols = padded_scale.shape + assert rows % 128 == 0 + assert cols % 4 == 0 + padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32, + cols // 4, 4) + swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5)) + swizzled_scale = swizzled_scale.contiguous().cuda() + return (swizzled_scale.reshape(M, K) + if scale_ndim == 2 else swizzled_scale.reshape(B, M, K)) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + + # GEMM 1 + assert torch.allclose( + layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]), ( + "w1_weight_scale_2 must match w3_weight_scale_2") + + w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0] + layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, + requires_grad=False) + + w13_input_scale = layer.w13_input_scale.max(dim=1).values.to( + torch.float32) + layer.g1_alphas = Parameter( + (w13_input_scale * w13_weight_scale_2).to(torch.float32), + requires_grad=False) + + assert (layer.w13_weight_scale.shape[2] % 16 == 0), ( + "Expected weight_scale.dim(1) to be divisible by 16") + assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), ( + "Weight Blockscale must be represented as FP8-E4M3") + w13_blockscale_swizzled = self.swizzle_blockscale( + layer.w13_weight_scale) + + layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled, + requires_grad=False) + + # This is for quantization, so we need to invert it. 
+ layer.w13_input_scale_quant = Parameter( + (1 / w13_input_scale).to(torch.float32), requires_grad=False) + + layer.w13_weight = Parameter(layer.w13_weight.data, + requires_grad=False) + + # GEMM 2 + layer.g2_alphas = Parameter( + (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32), + requires_grad=False) + + # This is for quantization, so we need to invert it. + layer.w2_input_scale_quant = Parameter( + (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False) + + assert (layer.w2_weight_scale.shape[2] % 16 == 0), ( + "Expected weight_scale.dim(1) to be divisible by 16") + assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), ( + "Weight Blockscale must be represented as FP8-E4M3") + w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale) + + layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled, + requires_grad=False) + layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False) + + if self.use_marlin: + prepare_moe_fp4_layer_for_marlin(layer) + del layer.g1_alphas + del layer.g2_alphas + del layer.w13_input_scale_quant + del layer.w2_input_scale_quant + del layer.w13_blockscale_swizzled + del layer.w2_blockscale_swizzled + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ): + if self.use_marlin: + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + 
num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) + + return torch.ops.vllm.fused_marlin_moe( + x, + layer.w13_weight, + layer.w2_weight, + layer.w13_weight_scale, + layer.w2_weight_scale, + router_logits, + topk_weights, + topk_ids, + global_scale1=layer.w13_weight_scale_2, + global_scale2=layer.w2_weight_scale_2, + quant_type_id=scalar_types.float4_e2m1f.id, + global_num_experts=global_num_experts, + expert_map=expert_map) + + assert activation == "silu", "Only SiLU activation is supported." + assert not apply_router_weight_on_input, ( + "Router weight on input is not " + "supported for ModelOptNvFp4FusedMoE.") + assert expert_map is None, ("Expert Parallelism / expert_map " + "is currently not supported for " + "ModelOptNvFp4FusedMoE.") + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + cutlass_moe_fp4) + + # Cutlass moe takes in activations in BF16/Half precision + # and fp4 quantized weights loaded from the checkpoint + return cutlass_moe_fp4(a=x, + w1_fp4=layer.w13_weight, + w1_blockscale=layer.w13_blockscale_swizzled, + w1_alphas=layer.g1_alphas, + w2_fp4=layer.w2_weight, + w2_blockscale=layer.w2_blockscale_swizzled, + w2_alphas=layer.g2_alphas, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=x.shape[0], + n=layer.w2_weight.shape[2] * 2, + k=x.shape[1], + e=layer.w13_weight.shape[0], + a1_gscale=layer.w13_input_scale_quant, + a2_gscale=layer.w2_input_scale_quant, + device=x.device).to(x.dtype) diff --git 
a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index b8e3a4364379..74bd6dc13f84 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Optional import torch @@ -23,8 +23,8 @@ class MoeWNA16Config(QuantizationConfig): def __init__(self, linear_quant_method: str, weight_bits: int, group_size: int, has_zp: bool, lm_head_quantized: bool, - modules_to_not_convert: Optional[List[str]], - full_config: Dict[str, Any]) -> None: + modules_to_not_convert: Optional[list[str]], + full_config: dict[str, Any]) -> None: super().__init__() self.weight_bits = weight_bits self.group_size = group_size @@ -69,7 +69,7 @@ def get_name(cls) -> QuantizationMethods: return "moe_wna16" @classmethod - def get_supported_act_dtypes(cls) -> List[torch.dtype]: + def get_supported_act_dtypes(cls) -> list[torch.dtype]: return [torch.bfloat16, torch.half] @classmethod @@ -77,11 +77,11 @@ def get_min_capability(cls) -> int: return 70 @classmethod - def get_config_filenames(cls) -> List[str]: + def get_config_filenames(cls) -> list[str]: return ["quantize_config.json"] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "MoeWNA16Config": + def from_config(cls, config: dict[str, Any]) -> "MoeWNA16Config": linear_quant_method = cls.get_from_keys(config, ["quant_method"]) weight_bits = cls.get_from_keys(config, ["bits"]) group_size = cls.get_from_keys(config, ["group_size"]) @@ -109,7 +109,7 @@ def override_quantization_method( return None @classmethod - def is_moe_wna16_compatible(cls, quant_config: Dict[str, Any]): + def is_moe_wna16_compatible(cls, quant_config: dict[str, Any]): # Extract data from quant config. 
quant_method = quant_config.get("quant_method", "").lower() num_bits = quant_config.get("bits") @@ -163,7 +163,7 @@ def get_quant_method(self, layer: torch.nn.Module, return None -def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]): +def is_layer_skipped_quant(prefix: str, modules_to_not_convert: list[str]): return any(module_name in prefix for module_name in modules_to_not_convert) diff --git a/vllm/model_executor/layers/quantization/neuron_quant.py b/vllm/model_executor/layers/quantization/neuron_quant.py index 7933eab2a530..38b374feea81 100644 --- a/vllm/model_executor/layers/quantization/neuron_quant.py +++ b/vllm/model_executor/layers/quantization/neuron_quant.py @@ -2,7 +2,7 @@ import os from importlib.util import find_spec -from typing import Any, Dict, List, Optional +from typing import Any, Optional from torch.nn import Module @@ -34,7 +34,7 @@ def __init__( def get_name(self) -> QuantizationMethods: return "neuron_quant" - def get_supported_act_dtypes(self) -> List[str]: + def get_supported_act_dtypes(self) -> list[str]: return SUPPORTED_QUANT_DTYPE_LIST @classmethod @@ -43,11 +43,11 @@ def get_min_capability(cls) -> int: "This function should not be called with Neuron Backend") @staticmethod - def get_config_filenames() -> List[str]: + def get_config_filenames() -> list[str]: return [] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "NeuronQuantConfig": + def from_config(cls, config: dict[str, Any]) -> "NeuronQuantConfig": quantize_method = cls.get_from_keys(config, ["quantize_method"]) dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"]) return cls(dequant_dtype=dequant_dtype, diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py index 004d74e68b9a..9e4fb33639b2 100644 --- a/vllm/model_executor/layers/quantization/ptpc_fp8.py +++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from 
typing import Any, Dict, List, Optional +from typing import Any, Optional import torch from torch.nn.parameter import Parameter @@ -32,7 +32,7 @@ class PTPCFp8Config(Fp8Config): def __init__( self, activation_scheme: str = "dynamic", - ignored_layers: Optional[List[str]] = None, + ignored_layers: Optional[list[str]] = None, ) -> None: if not current_platform.is_rocm(): raise ValueError( @@ -55,7 +55,7 @@ def get_name(cls) -> QuantizationMethods: return "ptpc_fp8" @classmethod - def from_config(cls, config: Dict[str, Any]) -> "PTPCFp8Config": + def from_config(cls, config: dict[str, Any]) -> "PTPCFp8Config": activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) return cls(activation_scheme=activation_scheme, diff --git a/vllm/model_executor/layers/quantization/qqq.py b/vllm/model_executor/layers/quantization/qqq.py index 06ff6c71b913..6028b8a2ada3 100644 --- a/vllm/model_executor/layers/quantization/qqq.py +++ b/vllm/model_executor/layers/quantization/qqq.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional +from typing import Any, Optional import torch from torch.nn.parameter import Parameter @@ -89,7 +89,7 @@ def get_name(cls) -> QuantizationMethods: return "qqq" @classmethod - def get_supported_act_dtypes(cls) -> List[torch.dtype]: + def get_supported_act_dtypes(cls) -> list[torch.dtype]: return [torch.half] @classmethod @@ -97,7 +97,7 @@ def get_min_capability(cls) -> int: return 80 @classmethod - def get_config_filenames(cls) -> List[str]: + def get_config_filenames(cls) -> list[str]: """List of filenames to search for in the model directory.""" return [ "quant_config.json", @@ -105,7 +105,7 @@ def get_config_filenames(cls) -> List[str]: ] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "QQQConfig": + def from_config(cls, config: dict[str, Any]) -> "QQQConfig": weight_bits = cls.get_from_keys(config, ["wbits"]) 
group_size = cls.get_from_keys(config, ["group_size"]) return cls(weight_bits, group_size) @@ -131,7 +131,7 @@ def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], + output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index 66e677f56ffd..df4bfbbbcb4c 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import fnmatch -from typing import Any, Dict, List, Optional, cast +from typing import Any, Optional, cast import torch @@ -29,9 +29,9 @@ class QuarkConfig(QuantizationConfig): def __init__(self, - quant_config: Dict[str, Any], - kv_cache_group: Optional[List[str]] = None, - kv_cache_config: Optional[Dict[str, Any]] = None, + quant_config: dict[str, Any], + kv_cache_group: Optional[list[str]] = None, + kv_cache_config: Optional[dict[str, Any]] = None, pack_method: str = "reorder"): super().__init__() if kv_cache_group is None: @@ -44,7 +44,7 @@ def __init__(self, def get_linear_method(self) -> "QuarkLinearMethod": return QuarkLinearMethod(self) - def get_supported_act_dtypes(cls) -> List[torch.dtype]: + def get_supported_act_dtypes(cls) -> list[torch.dtype]: return [torch.float16, torch.bfloat16] @classmethod @@ -59,7 +59,7 @@ def get_quant_method(self, layer: torch.nn.Module, from vllm.attention.layer import Attention # Avoid circular import # Check if the layer is skipped for quantization. 
- exclude_layers = cast(List[str], self.quant_config.get("exclude")) + exclude_layers = cast(list[str], self.quant_config.get("exclude")) if should_ignore_layer(prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping): @@ -78,12 +78,12 @@ def get_quant_method(self, layer: torch.nn.Module, return None @classmethod - def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig": + def from_config(cls, config: dict[str, Any]) -> "QuarkConfig": export_config = config.get("export") if export_config is None: raise ValueError("The export key should be included in " "the configurations of Quark quantized model") - kv_cache_group = cast(List[str], export_config.get("kv_cache_group")) + kv_cache_group = cast(list[str], export_config.get("kv_cache_group")) pack_method = cast(str, export_config.get("pack_method")) # In the export model of quark, the quantization configuration @@ -95,7 +95,7 @@ def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig": kv_cache_config = None else: kv_cache_set = set(kv_cache_group) - layer_quant_config = cast(Dict[str, Any], + layer_quant_config = cast(dict[str, Any], config.get("layer_quant_config")) layer_quant_names = list(layer_quant_config.keys()) layer_quant_set = set(layer_quant_names) @@ -108,7 +108,7 @@ def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig": "configuration.") q_configs = [ - cast(Dict[str, Any], layer_quant_config.get(name)) + cast(dict[str, Any], layer_quant_config.get(name)) for name in kv_cache_group ] if not all( @@ -131,7 +131,7 @@ def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig": # In case q_proj output is also quantized, remove the configuration # to keep qkv consistency. 
- q_proj_q_config = cast(Dict[str, Any], + q_proj_q_config = cast(dict[str, Any], layer_quant_config.get("*q_proj")) if q_proj_q_config is not None: q_proj_q_config["output_tensors"] = None @@ -142,7 +142,7 @@ def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig": pack_method=pack_method) @classmethod - def get_config_filenames(cls) -> List[str]: + def get_config_filenames(cls) -> list[str]: return [] def _check_scheme_supported(self, @@ -162,8 +162,8 @@ def _check_scheme_supported(self, else: return False - def _is_fp8_w8a8(self, weight_quant: Optional[Dict[str, Any]], - input_quant: Optional[Dict[str, Any]]) -> bool: + def _is_fp8_w8a8(self, weight_quant: Optional[dict[str, Any]], + input_quant: Optional[dict[str, Any]]) -> bool: # Confirm weights and input quantized. if weight_quant is None or input_quant is None: return False @@ -187,8 +187,8 @@ def _is_fp8_w8a8(self, weight_quant: Optional[Dict[str, Any]], is_per_tensor_activation = (input_quant.get("qscheme") == "per_tensor") return is_per_tensor_activation - def _is_static_tensor_w8a8(self, weight_quant: Optional[Dict[str, Any]], - input_quant: Optional[Dict[str, Any]]) -> bool: + def _is_static_tensor_w8a8(self, weight_quant: Optional[dict[str, Any]], + input_quant: Optional[dict[str, Any]]) -> bool: # Confirm weights and input quantized. if weight_quant is None or input_quant is None: return False @@ -209,8 +209,8 @@ def _is_static_tensor_w8a8(self, weight_quant: Optional[Dict[str, Any]], # Only symmetric weight quantization supported. return is_int8_dtype and is_tensor and is_weight_symmetric and is_static - def _is_mx_fp4(self, weight_quant: Optional[Dict[str, Any]], - input_quant: Optional[Dict[str, Any]]) -> bool: + def _is_mx_fp4(self, weight_quant: Optional[dict[str, Any]], + input_quant: Optional[dict[str, Any]]) -> bool: # Confirm weights and input quantized. 
if weight_quant is None or input_quant is None: logger.debug("Quark model is not in MX-FP4 format: " @@ -258,7 +258,7 @@ def _is_mx_fp4(self, weight_quant: Optional[Dict[str, Any]], return True def _find_matched_config(self, layer_name: str, - module: torch.nn.Module) -> Dict[str, Any]: + module: torch.nn.Module) -> dict[str, Any]: proj_name = layer_name.split(".")[-1] if proj_name in self.packed_modules_mapping: @@ -283,29 +283,29 @@ def _find_matched_config(self, layer_name: str, return shard_configs[0] else: layer_quant_config = cast( - Dict[str, Any], self.quant_config.get("layer_quant_config")) + dict[str, Any], self.quant_config.get("layer_quant_config")) for name_pattern in layer_quant_config: if fnmatch.fnmatch(layer_name, name_pattern): return layer_quant_config[name_pattern] layer_type = cast(str, type(module)) layer_type_quant_config = cast( - Dict[str, Any], + dict[str, Any], self.quant_config.get("layer_type_quant_config")) if layer_type in layer_type_quant_config: return layer_type_quant_config[layer_type] global_quant_config = cast( - Dict[str, Any], self.quant_config.get("global_quant_config")) + dict[str, Any], self.quant_config.get("global_quant_config")) return global_quant_config - def _get_scheme_from_config(self, config: Dict[str, Any]) -> "QuarkScheme": + def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme": if config.get("output_tensors") or config.get("bias"): raise NotImplementedError( "Currently, Quark models with output_tensors " "and bias quantized are not supported") - weight_config = cast(Dict[str, Any], config.get("weight")) - input_config = cast(Dict[str, Any], config.get("input_tensors")) + weight_config = cast(dict[str, Any], config.get("weight")) + input_config = cast(dict[str, Any], config.get("input_tensors")) if self._is_fp8_w8a8(weight_config, input_config): is_fp8_w8a8_supported = self._check_scheme_supported( @@ -373,7 +373,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: 
def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], input_size: int, + output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): """ @@ -417,7 +417,7 @@ def __init__(self, quant_config: QuarkConfig): super().__init__(quant_config) @staticmethod - def validate_kv_cache_config(kv_cache_config: Optional[Dict[str, Any]]): + def validate_kv_cache_config(kv_cache_config: Optional[dict[str, Any]]): """ Validator for the kv cache configuration. Useful for controlling the kv cache quantization schemes, that are being supported in vLLM diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index d1146c0f039d..aa7d725433ea 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Optional import torch @@ -45,7 +45,7 @@ def get_moe_method( class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): - def __init__(self, weight_config: Dict[str, Any], input_config: Dict[str, + def __init__(self, weight_config: dict[str, Any], input_config: dict[str, Any]): self.weight_quant = weight_config self.input_quant = input_config diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py index 9da52a732fc4..34c077b29163 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Optional import torch import torch.nn.functional as F @@ 
-18,8 +18,8 @@ class QuarkW4A4MXFP4(QuarkScheme): - def __init__(self, weight_quant_spec: Dict[str, Any], - input_quant_spec: Dict[str, Any]): + def __init__(self, weight_quant_spec: dict[str, Any], + input_quant_spec: dict[str, Any]): self.out_dtype = torch.get_default_dtype() self.qscheme = "per_group" self.weight_quant_spec = weight_quant_spec @@ -74,7 +74,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: torch.cuda.empty_cache() def create_weights(self, layer: torch.nn.Module, - output_partition_sizes: List[int], + output_partition_sizes: list[int], input_size_per_partition: int, params_dtype: torch.dtype, weight_loader: Callable, **kwargs): diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index f8eb3611592e..149c9093797f 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, List, Optional +from typing import Callable, Optional import torch from torch.nn import Parameter @@ -88,7 +88,7 @@ def process_weights_after_loading(self, layer) -> None: layer.input_scale = None def create_weights(self, layer: torch.nn.Module, - output_partition_sizes: List[int], + output_partition_sizes: list[int], input_size_per_partition: int, params_dtype: torch.dtype, weight_loader: Callable, **kwargs): diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py index da8ed8c08506..94f9fcd56aca 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, List, Optional, Set 
+from typing import Callable, Optional import torch @@ -17,7 +17,7 @@ class QuarkW8A8Int8(QuarkScheme): - _kernel_backends_being_used: Set[str] = set() + _kernel_backends_being_used: set[str] = set() def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool], input_symmetric: Optional[bool]): @@ -31,7 +31,7 @@ def get_min_capability(cls) -> int: return 75 def create_weights(self, layer: torch.nn.Module, - output_partition_sizes: List[int], + output_partition_sizes: list[int], input_size_per_partition: int, params_dtype: torch.dtype, weight_loader: Callable, **kwargs): diff --git a/vllm/model_executor/layers/quantization/quark/utils.py b/vllm/model_executor/layers/quantization/quark/utils.py index 17e0df021085..5e56bcb7564c 100644 --- a/vllm/model_executor/layers/quantization/quark/utils.py +++ b/vllm/model_executor/layers/quantization/quark/utils.py @@ -1,8 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 -import re +from collections.abc import Iterable, Mapping from types import MappingProxyType -from typing import Any, Iterable, List, Mapping, Optional +from typing import Any, Optional + +import regex as re def deep_compare(dict1: Any, dict2: Any) -> bool: @@ -21,7 +23,7 @@ def deep_compare(dict1: Any, dict2: Any) -> bool: def should_ignore_layer( layer_name: Optional[str], ignore: Iterable[str], - fused_mapping: Mapping[str, List[str]] = MappingProxyType({}) + fused_mapping: Mapping[str, list[str]] = MappingProxyType({}) ) -> bool: if layer_name is None: return False diff --git a/vllm/model_executor/layers/quantization/schema.py b/vllm/model_executor/layers/quantization/schema.py index 026881f2dbaa..c0be40c16aff 100644 --- a/vllm/model_executor/layers/quantization/schema.py +++ b/vllm/model_executor/layers/quantization/schema.py @@ -12,7 +12,7 @@ scaling factors. 
""" -from typing import Dict, Optional +from typing import Optional from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator @@ -23,7 +23,7 @@ class KVCacheQuantSchema(BaseModel): # layer indices to their per-tensor KV cache scaling factor. # TODO: Consider pulling this and its validation methods out into its # own schema class (tricky as its members are variable) - scaling_factor: Dict[int, Dict[int, float]] + scaling_factor: dict[int, dict[int, float]] @model_validator(mode="after") def check_is_fp8(self) -> "KVCacheQuantSchema": diff --git a/vllm/model_executor/layers/quantization/torchao.py b/vllm/model_executor/layers/quantization/torchao.py index 751002fa0945..7f9f3e643bfa 100644 --- a/vllm/model_executor/layers/quantization/torchao.py +++ b/vllm/model_executor/layers/quantization/torchao.py @@ -1,14 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional +from typing import Any, Optional import torch import torch.nn.functional as F from torch.nn.parameter import Parameter -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) + QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.utils import set_weight_attrs @@ -24,7 +25,7 @@ def __repr__(self) -> str: def get_name(self) -> QuantizationMethods: return "torchao" - def get_supported_act_dtypes(self) -> List[torch.dtype]: + def get_supported_act_dtypes(self) -> list[torch.dtype]: return [torch.float32, torch.float16, torch.bfloat16] @classmethod @@ -32,11 +33,11 @@ def get_min_capability(cls) -> int: return 75 @staticmethod - def get_config_filenames() -> List[str]: + def get_config_filenames() -> list[str]: return ["config.json"] @classmethod - def 
from_config(cls, config: Dict[str, Any]) -> "TorchAOConfig": + def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig": """Create the quant config from an hf model config""" try: from torchao.core.config import config_from_dict @@ -55,12 +56,26 @@ def from_config(cls, config: Dict[str, Any]) -> "TorchAOConfig": return cls(ao_config) def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["TorchAOLinearMethod"]: - if isinstance(layer, LinearBase): - return TorchAOLinearMethod(self) - return None - - def get_scaled_act_names(self) -> List[str]: + prefix: str) -> Optional["QuantizeMethodBase"]: + if not isinstance(layer, LinearBase): + return None + + from torchao.quantization import AOPerModuleConfig + + module_fqn = prefix + if isinstance(self.torchao_config, AOPerModuleConfig): + module_fqn_to_config = self.torchao_config.module_fqn_to_config + c = module_fqn_to_config.get( + module_fqn) or module_fqn_to_config.get("_default", None) + if c is not None: + current_torchao_config = TorchAOConfig(c) + return TorchAOLinearMethod(current_torchao_config) + else: + return UnquantizedLinearMethod() + + return TorchAOLinearMethod(self) + + def get_scaled_act_names(self) -> list[str]: return [] @@ -75,7 +90,7 @@ def torchao_quantize_param_data(param: torch.Tensor, """ from torchao.core.config import AOBaseConfig from torchao.quantization import quantize_ - assert isinstance(torchao_config, AOBaseConfig) + assert isinstance(torchao_config, AOBaseConfig), f"{torchao_config}" dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False) dummy_linear.weight = param quantize_(dummy_linear, torchao_config) @@ -97,7 +112,7 @@ def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], + output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, diff --git a/vllm/model_executor/layers/quantization/tpu_int8.py 
b/vllm/model_executor/layers/quantization/tpu_int8.py index 8333c16ce6a1..7941ec9732fe 100644 --- a/vllm/model_executor/layers/quantization/tpu_int8.py +++ b/vllm/model_executor/layers/quantization/tpu_int8.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional import torch from torch.nn import Module @@ -31,7 +31,7 @@ def __init__( def get_name(self) -> QuantizationMethods: return "tpu_int8" - def get_supported_act_dtypes(self) -> List[torch.dtype]: + def get_supported_act_dtypes(self) -> list[torch.dtype]: return [torch.float16, torch.bfloat16] @classmethod @@ -40,11 +40,11 @@ def get_min_capability(cls) -> int: "This function should not be called with TPU Backend") @staticmethod - def get_config_filenames() -> List[str]: + def get_config_filenames() -> list[str]: return [] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "Int8TpuConfig": + def from_config(cls, config: dict[str, Any]) -> "Int8TpuConfig": activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) return cls(activation_scheme=activation_scheme) @@ -62,7 +62,7 @@ def __init__(self, quant_config: Int8TpuConfig): self.quant_config = quant_config def create_weights(self, layer: Module, input_size_per_partition: int, - output_partition_sizes: List[int], input_size: int, + output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): @@ -77,7 +77,7 @@ def create_weights(self, layer: Module, input_size_per_partition: int, layer.register_parameter("weight", weight) def _quantize_weight( - self, weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + self, weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: weight_dtype = weight.dtype weight = weight.cpu().to(torch.float32) n_bit = 8 diff --git a/vllm/model_executor/layers/quantization/utils/bitblas_utils.py b/vllm/model_executor/layers/quantization/utils/bitblas_utils.py 
index e26ac4ea3d4c..70d24cc897e1 100644 --- a/vllm/model_executor/layers/quantization/utils/bitblas_utils.py +++ b/vllm/model_executor/layers/quantization/utils/bitblas_utils.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Tuple +from typing import Optional import torch @@ -51,7 +51,7 @@ def _check_bitblas_supported( quant_type: ScalarType, group_size: Optional[int], has_zp: bool, - device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]: + device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]: if device_capability is None: capability_tuple = current_platform.get_device_capability() @@ -133,7 +133,7 @@ def verify_bitblas_supports_shape(output_size_per_partition: int, def check_bitblas_supports_shape(output_size_per_partition: int, input_size_per_partition: int, input_size: int, group_size: int) \ - -> Tuple[bool, Optional[str]]: + -> tuple[bool, Optional[str]]: try: verify_bitblas_supports_shape(output_size_per_partition, input_size_per_partition, input_size, @@ -166,7 +166,7 @@ def bitblas_make_empty_zp(device: torch.device) -> torch.Tensor: def bitblas_sort_g_idx( - g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) return g_idx[g_idx_sort_indices], g_idx_sort_indices diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 064cbb8cf52d..4c213f2c874e 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -4,7 +4,7 @@ import functools import json import os -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union import torch @@ -27,54 +27,140 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool: return x == torch.float8_e4m3fn or x == 
torch.float8_e4m3fnuz +def cutlass_scaled_mm( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + return ops.cutlass_scaled_mm(A, + B.T, + out_dtype=output_dtype, + scale_a=As, + scale_b=Bs.T) + + +def rocm_aiter_gemm_w8a8_blockscale_impl( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + import aiter as rocm_aiter + + return rocm_aiter.gemm_a8w8_blockscale_CK(A, B, As, Bs, dtype=output_dtype) + + +def rocm_aiter_gemm_w8a8_blockscale_fake( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + + m = A.shape[0] + n = B.shape[0] + Y = torch.empty(m, n, dtype=output_dtype, device=A.device) + return Y + + +if current_platform.is_rocm(): + direct_register_custom_op( + op_name="rocm_aiter_gemm_w8a8_blockscale", + op_func=rocm_aiter_gemm_w8a8_blockscale_impl, + mutates_args=[], + fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake, + dispatch_key=current_platform.dispatch_key, + ) + + +def dispatch_w8a8_blockscale_func( + use_cutlass: bool, use_aiter_and_is_supported: bool +) -> Callable[[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + list[int], + torch.dtype, +], torch.Tensor]: + if use_cutlass: + return cutlass_scaled_mm + if (use_aiter_and_is_supported): + return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale + return w8a8_block_fp8_matmul + + # TODO fix ROCm->Triton custom path: # https://github.com/vllm-project/vllm/issues/14397 def apply_w8a8_block_fp8_linear( input: torch.Tensor, weight: torch.Tensor, - block_size: List[int], + block_size: list[int], weight_scale: torch.Tensor, input_scale: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, cutlass_block_fp8_supported: bool = 
CUTLASS_BLOCK_FP8_SUPPORTED, + use_aiter_and_is_supported: bool = False, ) -> torch.Tensor: assert input_scale is None # View input as 2D matrix for fp8 methods input_2d = input.view(-1, input.shape[-1]) output_shape = [*input.shape[:-1], weight.shape[0]] - shape_supported_by_cutlass = (weight.shape[0] % 128 == 0 - and weight.shape[1] % 128 == 0) - if current_platform.is_rocm(): - # TODO this is never used, as cutlass_block_fp8_supported is False - scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) + - input_2d.shape[:-1])[::-1] - scale_b_shape = (weight_scale.view(-1, 1) - if weight_scale.dim() <= 1 else weight_scale.T).shape - ar, ac = scale_a_shape - br, bc = scale_b_shape - if (ac > 1 or bc > 1 or ar not in (1, input_2d.shape[0]) - or br not in (1, weight.shape[0])): - shape_supported_by_cutlass = False - if cutlass_block_fp8_supported and shape_supported_by_cutlass: - q_input, x_scale = per_token_group_quant_fp8(input_2d, - block_size[1], - column_major_scales=True) - output = ops.cutlass_scaled_mm(q_input, - weight.T, - out_dtype=input.dtype, - scale_a=x_scale, - scale_b=weight_scale.T) + if current_platform.is_cuda(): + if current_platform.has_device_capability(100): + + def ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + use_cutlass = cutlass_block_fp8_supported and ( + ceil_div(weight.shape[0], 128) == weight_scale.shape[0] + and ceil_div(weight.shape[1], 128) == weight_scale.shape[1]) + else: + # TODO: update this after switching to public sm90 block scale gemm + # as it also supports weight.shape % 128 != 0 + use_cutlass = cutlass_block_fp8_supported and ( + weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) + else: + use_cutlass = False + + w8a8_blockscale_func = dispatch_w8a8_blockscale_func( + use_cutlass, use_aiter_and_is_supported) + + if use_cutlass: + rows, cols = input_2d.shape + # Blackwell GPUs (SM100) require row dimensions to be multiple of 4 for + # optimal tensor core usage. 
Can be removed when targeting platforms + # without this constraint. + should_pad = current_platform.has_device_capability( + 100) and rows % 4 != 0 + if should_pad: + input_2d = torch.nn.functional.pad(input_2d, + (0, 0, 0, 4 - (rows % 4)), + value=0).contiguous() + + q_input, x_scale = per_token_group_quant_fp8( + input_2d, block_size[1], column_major_scales=use_cutlass) + + output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, + block_size, input.dtype) + if should_pad: + output = output[:rows, :] + else: - q_input, x_scale = per_token_group_quant_fp8(input_2d, - block_size[1], - column_major_scales=False) - output = w8a8_block_fp8_matmul(q_input, - weight, - x_scale, - weight_scale, - block_size, - output_dtype=input.dtype) + q_input, x_scale = per_token_group_quant_fp8( + input_2d, block_size[1], column_major_scales=use_cutlass) + + output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, + block_size, input.dtype) + if bias is not None: output = output + bias return output.to(dtype=input.dtype).view(*output_shape) @@ -83,9 +169,12 @@ def apply_w8a8_block_fp8_linear( def apply_w8a8_block_fp8_linear_fake( input: torch.Tensor, weight: torch.Tensor, - block_size: List[int], + block_size: list[int], weight_scale: torch.Tensor, input_scale: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, + use_aiter_and_is_supported: bool = False, ) -> torch.Tensor: output_shape = [*input.shape[:-1], weight.shape[0]] return torch.empty(output_shape, dtype=input.dtype, device=input.device) @@ -102,7 +191,7 @@ def apply_w8a8_block_fp8_linear_fake( def input_to_float8( x: torch.Tensor, dtype: Optional[torch.dtype] = None -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """This function quantizes input values to float8 values " "with tensor-wise quantization.""" dtype = current_platform.fp8_dtype() if dtype is None else dtype @@ -117,7 
+206,7 @@ def input_to_float8( def block_quant_to_tensor_quant( x_q_block: torch.Tensor, x_s: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """This function converts block-wise quantization to tensor-wise quantization. The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale and the block size. @@ -235,7 +324,7 @@ def per_token_group_quant_fp8( eps: float = 1e-10, dtype: Optional[torch.dtype] = None, column_major_scales: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """Function to perform per-token-group quantization on an input tensor `x`. It converts the tensor values into signed float8 values and returns the quantized tensor along with the scaling factor used for quantization. @@ -246,7 +335,7 @@ def per_token_group_quant_fp8( dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now. Returns: - Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the + tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization. """ dtype = current_platform.fp8_dtype() if dtype is None else dtype @@ -400,7 +489,7 @@ def _w8a8_block_fp8_matmul( @functools.lru_cache def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int, - block_k: int) -> Optional[Dict[int, Any]]: + block_k: int) -> Optional[dict[int, Any]]: """ Return optimized configurations for the w8a8 block fp8 kernel. 
The return value will be a dictionary that maps an irregular grid of @@ -440,7 +529,7 @@ def w8a8_block_fp8_matmul( B: torch.Tensor, As: torch.Tensor, Bs: torch.Tensor, - block_size: List[int], + block_size: list[int], output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: """This function performs matrix multiplication with block-wise diff --git a/vllm/model_executor/layers/quantization/utils/gptq_utils.py b/vllm/model_executor/layers/quantization/utils/gptq_utils.py index 5b0e6299f473..36161d13b24f 100644 --- a/vllm/model_executor/layers/quantization/utils/gptq_utils.py +++ b/vllm/model_executor/layers/quantization/utils/gptq_utils.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 -import re from copy import deepcopy -from typing import Dict, Optional, Union +from typing import Optional, Union +import regex as re import torch from vllm.config import QuantizationConfig @@ -52,7 +52,7 @@ def get_dynamic_override( layer_name: str, key: Optional[str] = None, default_value: Union[int, bool, - None] = None) -> Union[Dict, int, bool, None]: + None] = None) -> Union[dict, int, bool, None]: for pattern, pattern_dict in config.dynamic.items(): # Negative match: matched modules are excluded from quantized init if pattern.startswith("-:"): diff --git a/vllm/model_executor/layers/quantization/utils/int8_utils.py b/vllm/model_executor/layers/quantization/utils/int8_utils.py index 431f0cf73fad..72fff3fa1aed 100644 --- a/vllm/model_executor/layers/quantization/utils/int8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/int8_utils.py @@ -5,7 +5,7 @@ import json import logging import os -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional import torch @@ -18,7 +18,7 @@ def apply_w8a8_block_int8_linear( input: torch.Tensor, weight: torch.Tensor, - block_size: List[int], + block_size: list[int], weight_scale: torch.Tensor, input_scale: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, @@ -43,7 +43,7 @@ def 
apply_w8a8_block_int8_linear( def input_to_int8( x: torch.Tensor, - dtype: torch.dtype = torch.int8) -> Tuple[torch.Tensor, torch.Tensor]: + dtype: torch.dtype = torch.int8) -> tuple[torch.Tensor, torch.Tensor]: """This function quantizes input values to int8 values with tensor-wise quantization.""" iinfo = torch.iinfo(dtype) @@ -58,7 +58,7 @@ def input_to_int8( def block_dequant( x_q_block: torch.Tensor, x_s: torch.Tensor, - block_size: List[int], + block_size: list[int], ) -> torch.Tensor: """This function conducts block-wise dequantization. The inputs are block-wise quantization tensor `x_q_block`, @@ -211,7 +211,7 @@ def per_token_group_quant_int8( group_size: int, eps: float = 1e-10, dtype: torch.dtype = torch.int8, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """Function to perform per-token-group quantization on an input tensor `x`. It converts the tensor values into signed int8 values and returns the @@ -225,7 +225,7 @@ def per_token_group_quant_int8( is supported for now. Returns: - Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the + tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization. """ assert (x.shape[-1] % group_size == 0 @@ -358,7 +358,7 @@ def _w8a8_block_int8_matmul( @functools.lru_cache def get_w8a8_block_int8_configs(N: int, K: int, block_n: int, - block_k: int) -> Optional[Dict[int, Any]]: + block_k: int) -> Optional[dict[int, Any]]: """ Return optimized configurations for the w8a8 block fp8 kernel. 
@@ -399,7 +399,7 @@ def w8a8_block_int8_matmul( B: torch.Tensor, As: torch.Tensor, Bs: torch.Tensor, - block_size: List[int], + block_size: list[int], output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: """This function performs matrix multiplication with block-wise diff --git a/vllm/model_executor/layers/quantization/utils/machete_utils.py b/vllm/model_executor/layers/quantization/utils/machete_utils.py index cb7d49ed6f1c..6d840b568612 100644 --- a/vllm/model_executor/layers/quantization/utils/machete_utils.py +++ b/vllm/model_executor/layers/quantization/utils/machete_utils.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional, Tuple +from typing import Optional import torch @@ -10,19 +10,19 @@ MACHETE_PREPACKED_BLOCK_SHAPE = [64, 128] -def query_machete_supported_quant_types(zero_points: bool) -> List[ScalarType]: +def query_machete_supported_quant_types(zero_points: bool) -> list[ScalarType]: if zero_points: return [scalar_types.uint4, scalar_types.uint8] else: return [scalar_types.uint4b8, scalar_types.uint8b128] -def query_machete_supported_act_types(zero_points: bool) -> List[ScalarType]: +def query_machete_supported_act_types(zero_points: bool) -> list[ScalarType]: return [torch.float16, torch.bfloat16] def check_machete_supports_shape(in_features: int, out_featrues: int) \ - -> Tuple[bool, Optional[str]]: + -> tuple[bool, Optional[str]]: if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0: return False, "Input features size must be divisible by "\ f"{MACHETE_PREPACKED_BLOCK_SHAPE[0]}" diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index a2b1b7cb0e1d..e059a7ac3f92 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional, Tuple +from typing 
import Optional import numpy import torch @@ -33,7 +33,7 @@ # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. # TODO: we may want to move this into the C++ so its closer to the actual impl def query_marlin_supported_quant_types( - has_zp: bool, + has_zp: Optional[bool] = None, include_fp_type: bool = True, device_capability: Optional[int] = None, ): @@ -45,6 +45,16 @@ def query_marlin_supported_quant_types( if device_capability < 80: return [] + # - has_zp is True: return quant_types that have zero points + # - has_zp is False: return quant_types that have no zero points + # - has_zp is None: return both + if has_zp is None: + types0 = query_marlin_supported_quant_types(False, include_fp_type, + device_capability) + types1 = query_marlin_supported_quant_types(True, include_fp_type, + device_capability) + return types0 + types1 + if has_zp: # AWQ style, unsigned + runtime zero-point return [scalar_types.uint4] @@ -52,7 +62,7 @@ def query_marlin_supported_quant_types( # GPTQ style, unsigned + symmetric bias res = [scalar_types.uint4b8, scalar_types.uint8b128] if include_fp_type: - res += [scalar_types.float8_e4m3fn] + res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f] return res @@ -60,7 +70,7 @@ def _check_marlin_supported( quant_type: ScalarType, group_size: Optional[int], has_zp: bool, - device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]: + device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]: if device_capability is None: capability_tuple = current_platform.get_device_capability() @@ -133,7 +143,7 @@ def verify_marlin_supports_shape(output_size_per_partition: int, def check_marlin_supports_shape(output_size_per_partition: int, input_size_per_partition: int, input_size: int, group_size: int) \ - -> Tuple[bool, Optional[str]]: + -> tuple[bool, Optional[str]]: try: verify_marlin_supports_shape(output_size_per_partition, input_size_per_partition, input_size, @@ -161,13 +171,19 @@ def
check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \ -> bool: hidden_size = layer.hidden_size intermediate_size_per_partition = layer.intermediate_size_per_partition + # apply_router_weight_on_input is not supported for moe marlin + supports_router_weight = not layer.apply_router_weight_on_input + # moe marlin requires the activation to be silu + supports_activation = layer.activation == "silu" # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size) # down: (n, k) = (hidden_size, intermediate_size_per_partition) # moe marlin requires n % 128 == 0 and k % 64 == 0 - return hidden_size % 128 == 0 and \ - intermediate_size_per_partition % max(64, group_size) == 0 and \ - group_size in [-1, 32, 64, 128] + supports_shape = hidden_size % 128 == 0 and \ + intermediate_size_per_partition % max(64, group_size) == 0 + supports_group_size = group_size in [-1, 32, 64, 128] + return supports_shape and supports_group_size and \ + supports_router_weight and supports_activation def marlin_make_workspace(output_size_per_partition: int, @@ -215,16 +231,16 @@ def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: def marlin_sort_g_idx( - g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) return g_idx[g_idx_sort_indices], g_idx_sort_indices def get_scale_perms(): - scale_perm: List[int] = [] + scale_perm: list[int] = [] for i in range(8): scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] + scale_perm_single: list[int] = [] for i in range(4): scale_perm_single.extend( [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) @@ -394,6 +410,7 @@ def apply_gptq_marlin_linear( None, weight, weight_scale, + None, weight_zp, g_idx, g_idx_sort_indices, @@ -439,6 +456,7 @@ def apply_awq_marlin_linear( None, weight, weight_scale, + None, weight_zp, g_idx, g_idx_sort_indices, diff --git 
a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py new file mode 100644 index 000000000000..15177af58ae6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -0,0 +1,277 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import torch + +import vllm._custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales, + should_use_atomic_add_reduce) +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16] + +logger = init_logger(__name__) + + +def is_fp4_marlin_supported(): + return current_platform.has_device_capability(80) + + +def fp4_marlin_process_scales(marlin_scales): + assert (marlin_scales >= 0).all() + + # convert to half first, we would convert to fp8 later + marlin_scales = marlin_scales.to(torch.half) + + # 8 is the number of scale values used by one thread + marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) + marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape( + marlin_scales.size(0) * 2, -1) + + # fit the layout of fp8 dequantization + marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( + marlin_scales.size(0), -1) + + # We assume that weight_scale (FP8-S1E4M3) is always greater + # than or equal to 0. So we can convert + # (weight_scale * (2 ** 7)) to a special FP8-S0E5M3 format. + # After multiplying by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1 + # when weight_scale > 0. This allows us to have an exponent bias + # closer to zero after dequantization.
+ + marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1 + marlin_scales = marlin_scales.view(torch.float8_e4m3fn) + marlin_scales = marlin_scales[:, 1::2].contiguous() + + return marlin_scales + + +def fp4_marlin_process_global_scale(global_scale): + assert global_scale.dtype in [torch.half, torch.bfloat16] + fp4_exponent = 2 + if global_scale.dtype == torch.half: + target_exponent = 5 + elif global_scale.dtype == torch.bfloat16: + target_exponent = 8 + # exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14 + # exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126 + exponent_bias = 2**(target_exponent - 1) - 2**(fp4_exponent - 1) + return global_scale * (2.0**(exponent_bias - 7)) + + +def apply_fp4_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_scale_2: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + # For GPUs that lack FP4 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP4 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=size_n, + k=size_k, + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(a=reshaped_x, + c=None, + b_q_weight=weight, + b_scales=weight_scale, + global_scale=weight_scale_2, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float4_e2m1f, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being 
used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads.") + + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + param_dtype = layer.params_dtype + + assert layer.weight.shape == (part_size_n, part_size_k // 2) + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace_new(device) + + # WEIGHT + # Repack weights to marlin format + perm = torch.empty(0, dtype=torch.int, device=device) + qweight = layer.weight.view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, + size_k=part_size_k, + size_n=part_size_n, + num_bits=4) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + # Permute scales + weight_scale = layer.weight_scale.T.to(param_dtype) + weight_scale = marlin_permute_scales(s=weight_scale, + size_k=part_size_k, + size_n=part_size_n, + group_size=16) + weight_scale = fp4_marlin_process_scales(weight_scale) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + weight_scale_2 = layer.weight_scale_2.to(param_dtype) + weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2) + layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, + requires_grad=False) + + return + + +def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. 
This may degrade " + "performance for compute-heavy workloads.") + + e = layer.num_experts + k = layer.hidden_size + n = layer.intermediate_size_per_partition + + # WORKSPACE + device = layer.w13_weight.device + param_dtype = layer.params_dtype + layer.workspace = marlin_make_workspace_new(device, 4) + perm = torch.empty(0, dtype=torch.int, device=device) + + # WEIGHT + # Repack weights to marlin format + for name in ["w13_weight", "w2_weight"]: + weight = getattr(layer, name) + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + assert weight.shape == (e, size_n, size_k // 2) + + for i in range(e): + qweight = weight[i].view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, + size_k=size_k, + size_n=size_n, + num_bits=4) + tensor_list.append(marlin_qweight) + + weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + weight = torch.nn.Parameter(weight, requires_grad=False) + + setattr(layer, name, weight) + + # WEIGHT SCALES + # Permute scales + for name in ["w13", "w2"]: + scales = getattr(layer, name + "_weight_scale").to(param_dtype) + global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype) + + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + for i in range(e): + marlin_scales = marlin_permute_scales(s=scales[i].T, + size_k=size_k, + size_n=size_n, + group_size=16) + marlin_scales = fp4_marlin_process_scales(marlin_scales) + tensor_list.append(marlin_scales) + + scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + scales = torch.nn.Parameter(scales, requires_grad=False) + setattr(layer, name + "_weight_scale", scales) + + global_scale = fp4_marlin_process_global_scale(global_scale) + global_scale = torch.nn.Parameter(global_scale, requires_grad=False) + setattr(layer, name + "_weight_scale_2", global_scale) + + +def rand_marlin_weight_fp4_like(weight, group_size): + assert 
group_size > 0 + size_n, size_k = weight.shape + device = weight.device + + scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 6 + global_scale = scales.max() / 448 + scales = (scales / global_scale).to(torch.float8_e4m3fn) + + fp4_weight = torch.randint(0, + 256, (size_n, size_k // 2), + dtype=torch.uint8, + device=weight.device) + fp4_weight_part_1 = ((fp4_weight & 0b10000000) | + ((fp4_weight & 0b01110000) >> 2)) + fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn) + fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6) + + fp4_weight2 = fp4_weight << 4 + fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) | + ((fp4_weight2 & 0b01110000) >> 2)) + fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn) + fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6) + + weight_ref = torch.cat( + [fp4_weight_part_2.unsqueeze(2), + fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k) + weight_ref = weight_ref * global_scale.to(weight.dtype) * \ + scales.repeat_interleave(group_size, 1).to(weight.dtype) + + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=4, + ) + + marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype), + size_k=size_k, + size_n=size_n, + group_size=group_size) + marlin_scales = fp4_marlin_process_scales(marlin_scales) + + global_scale = fp4_marlin_process_global_scale(global_scale) + + return weight_ref.T, marlin_qweight, marlin_scales, global_scale diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index 1e0078e246be..1f6e74244c5d 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -19,6 +19,20 @@ def is_fp8_marlin_supported(): return 
current_platform.has_device_capability(80) +def fp8_fused_exponent_bias_into_scales(scales): + fp8_exponent = 4 + if scales.dtype == torch.half: + target_exponent = 5 + elif scales.dtype == torch.bfloat16: + target_exponent = 8 + # exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8 + # exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120 + exponent_bias = 2**(target_exponent - 1) - 2**(fp8_exponent - 1) + s = torch.ones_like(scales) * 2 + s = s**exponent_bias + return scales * s + + def apply_fp8_marlin_linear( input: torch.Tensor, weight: torch.Tensor, @@ -44,6 +58,7 @@ def apply_fp8_marlin_linear( c=None, b_q_weight=weight, b_scales=weight_scale, + global_scale=None, b_zeros=None, g_idx=None, perm=None, @@ -71,6 +86,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module, part_size_n = layer.output_size_per_partition part_size_k = layer.input_size_per_partition + weight_block_size = getattr(layer, "weight_block_size", None) if size_k_first: assert layer.weight.shape == (part_size_k, part_size_n) @@ -104,14 +120,11 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module, scales = layer.weight_scale_inv.to(layer.orig_dtype) del layer.weight_scale_inv - if layer.weight_block_size is None: - group_size = -1 - else: - group_size = layer.weight_block_size[1] + group_size = -1 if weight_block_size is None else weight_block_size[1] # marlin kernel only support channel-wise and group-wise quantization # we need to convert the scales - if layer.weight_block_size is None: + if weight_block_size is None: if scales.nelement() == 1: # tensor-wise quantization -> channel-wise quantization # (1, 1) =>(repeat)=> (1, size_n) @@ -132,8 +145,10 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module, # block-wise quantization -> group-wise quantization # (size_k // block_size[1], ceil(size_n / block_size[0])) # =>(repeat)=> (size_k // block_size[1], size_n) - block_n = layer.weight_block_size[0] - scales = scales.T.repeat_interleave(block_n, 1) + if not size_k_first: + scales = 
scales.T.contiguous() + block_n = weight_block_size[0] + scales = scales.repeat_interleave(block_n, 1) # size_n may not divisible by block_size[0] scales = scales[:, :part_size_n] @@ -141,6 +156,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module, size_k=part_size_k, size_n=part_size_n, group_size=group_size) + marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) @@ -155,6 +171,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module, e = layer.num_experts k = layer.hidden_size n = layer.intermediate_size_per_partition + weight_block_size = getattr(layer, "weight_block_size", None) # WORKSPACE device = layer.w13_weight.device @@ -195,10 +212,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module, # WEIGHT SCALES # Permute scales - if layer.weight_block_size is None: - group_size = -1 - else: - group_size = layer.weight_block_size[1] + group_size = -1 if weight_block_size is None else weight_block_size[1] for name in ["w13", "w2"]: if name + "_weight_scale" in dir(layer): @@ -218,7 +232,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module, # marlin kernel only support channel-wise and group-wise quantization # we need to convert the scales - if layer.weight_block_size is None: + if weight_block_size is None: if scales.nelement() == e: # tensor-wise quantization -> channel-wise quantization # (e, 1, 1) =>(repeat)=> (e, 1, size_n) @@ -239,8 +253,10 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module, # block-wise quantization -> group-wise quantization # (e, size_k // block_size[1], ceil(size_n / block_size[0])) # =>(repeat)=> (e, size_k // block_size[1], size_n) - block_n = layer.weight_block_size[0] - scales = scales.permute(0, 2, 1).repeat_interleave(block_n, 2) + if not size_k_first: + scales = scales.permute(0, 2, 1) + block_n = weight_block_size[0] + scales = scales.repeat_interleave(block_n, 2) # size_n may not divisible by 
block_size[0] scales = scales[..., :size_n].contiguous() @@ -252,6 +268,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module, tensor_list.append(marlin_scales) scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + scales = fp8_fused_exponent_bias_into_scales(scales) scales = torch.nn.Parameter(scales, requires_grad=False) setattr(layer, name + "_weight_scale", scales) @@ -302,4 +319,6 @@ def marlin_quant_fp8_torch(weight, group_size): size_n=size_n, group_size=group_size) + marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) + return weight_ref.T, marlin_qweight, marlin_scales diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py index fb557a31393c..81112b27f53a 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 """Utility functions used for tests and benchmarks""" -from typing import List, Optional +from typing import Optional import numpy as np import torch @@ -64,9 +64,9 @@ def marlin_weights(q_w, size_k, size_n, num_bits, perm): def get_weight_perm(num_bits: int): - perm_list: List[int] = [] + perm_list: list[int] = [] for i in range(32): - perm1: List[int] = [] + perm1: list[int] = [] col = i // 4 for block in [0, 1]: for row in [ diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py index 3654268e27af..73feb4264a8b 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py @@ -2,7 +2,6 @@ """Utility functions used for tests and benchmarks""" import random -from typing import List import numpy import torch @@ -373,19 +372,19 @@ def compress_quantized_24_weight(q_24, size_k, size_n, 
wtype: ScalarType): def get_scale_perms_24(): - scale_perm: List[int] = [] + scale_perm: list[int] = [] for i in range(8): scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) - scale_perm_single: List[int] = [] + scale_perm_single: list[int] = [] for i in range(8): scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) return scale_perm, scale_perm_single def get_weight_perm_24(num_bits: int): - perm_list: List[int] = [] + perm_list: list[int] = [] for i in range(32): - perm1: List[int] = [] + perm1: list[int] = [] col = i // 4 col_o = col // 2 for block in [0, 1]: diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py index 176b2947ab09..0123540fc5dd 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List - import numpy import torch @@ -34,10 +32,10 @@ def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): def get_qqq_scale_perms(): - scale_perm: List[int] = [] + scale_perm: list[int] = [] for i in range(8): scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] + scale_perm_single: list[int] = [] for i in range(4): scale_perm_single.extend( [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) @@ -46,9 +44,9 @@ def get_qqq_scale_perms(): # NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. 
# noqa: E501 def get_qqq_weight_perm(num_bits: int, quant_type: str): - perm_list: List[int] = [] + perm_list: list[int] = [] for i in range(32): - perm1: List[int] = [] + perm1: list[int] = [] col = i // 4 for block in [0, 1]: for row in [ diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index 6312c3934fd4..e7c95e38e9fd 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Tuple import torch @@ -9,7 +8,7 @@ def per_token_group_quant_mxfp4(x: torch.Tensor, block_k: int, scale_calculation_mode: str = "even" - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: try: from quark.torch.kernel.hw_emulation.hw_emulation_interface import ( fake_quantize_fp4_fp6_per_group_with_scale) diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py new file mode 100644 index 000000000000..f292208311e2 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +import torch + +__all__ = [ + "break_fp4_bytes", + "dequantize_to_dtype", +] + +kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.], + dtype=torch.float32) + + +def break_fp4_bytes(a, dtype): + assert a.dtype == torch.uint8 + m, n = a.shape + # Vectorized nibble processing + a_flat = a.flatten() + high = (a_flat & 0xF0) >> 4 # Upper nibbles + low = a_flat & 0x0F # Lower nibbles + # Combine nibbles for batch processing + combined = torch.stack((low, high), dim=1).flatten() + # Vectorized sign and magnitude extraction + signs = (combined & 0x08).to(torch.bool) # Sign bits + abs_vals = (combined & 0x07).to(torch.long) + # Device-aware lookup and sign 
application + kE2M1 = kE2M1ToFloat.to(device=a.device) + values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0) + # Reshape to final form + return values.reshape(m, n * 2).to(dtype=dtype) + + +def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): + m_tiles = (m + 128 - 1) // 128 + f = block_size * 4 + k_tiles = (k + f - 1) // f + tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4)) + tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5)) + out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size) + return out[0:m, 0:k] + + +def dequantize_to_dtype(tensor_fp4, + tensor_sf, + global_scale, + dtype, + device, + block_size=16): + """Dequantize the fp4 tensor back to high precision.""" + # Two fp4 values are packed into one uint8. + assert tensor_fp4.dtype == torch.uint8 + m, packed_k = tensor_fp4.shape + k = packed_k * 2 + tensor_f32 = break_fp4_bytes(tensor_fp4, torch.float32) + tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) + tensor_sf = tensor_sf.view(torch.float8_e4m3fn) + tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) + tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale + + # scale the tensor + out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k) + return out.to(dtype) diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index c7ce3a42c81f..6ba327f3db7a 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 """This file is used for /tests and /benchmarks""" +from collections.abc import Mapping from types import MappingProxyType -from typing import List, Mapping, Optional, Tuple +from typing import Optional import numpy import torch @@ -15,7 +16,7 @@ # Normalize the group_shape to the full extent for any dims that are -1 -def 
_normalize_quant_group_shape(x: torch.Tensor, group_shape: Tuple[int, +def _normalize_quant_group_shape(x: torch.Tensor, group_shape: tuple[int, int]): # -1 means full extent return (group_shape[0] if group_shape[0] > 0 else x.shape[-2], @@ -56,9 +57,9 @@ def group_broadcast(t, shape): # (i.e. per-token-per-group) def scaled_quantize( x: torch.Tensor, - group_shape: Tuple[int, int], + group_shape: tuple[int, int], quant_dtype: torch.dtype, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: group_shape = _normalize_quant_group_shape(x, group_shape) assert quant_dtype.is_floating_point, \ "currently `scaled_quantize` only supports floating point dtypes " \ @@ -97,9 +98,9 @@ def scaled_quantize( def scaled_dequantize( x_q: torch.Tensor, x_s: torch.Tensor, - group_shape: Optional[Tuple[int, int]] = None, + group_shape: Optional[tuple[int, int]] = None, out_dtype: torch.dtype = torch.float32, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: if group_shape is not None: group_shape = _normalize_quant_group_shape(x_q, group_shape) @@ -173,8 +174,8 @@ def unpack_quantized_values_into_int32(w_q: torch.Tensor, def is_layer_skipped( prefix: str, - ignored_layers: List[str], - fused_mapping: Mapping[str, List[str]] = MappingProxyType({}) + ignored_layers: list[str], + fused_mapping: Mapping[str, list[str]] = MappingProxyType({}) ) -> bool: # prefix: model.layers.0.self_attn.q_proj # proj_name: q_proj diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 8ab45d610053..4b041cff2ecc 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Optional, Union import torch @@ -81,7 +81,7 @@ def all_close_1d(x: 
torch.Tensor) -> bool: def convert_to_channelwise( weight_scale: torch.Tensor, - logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]: + logical_widths: list[int]) -> tuple[torch.Tensor, torch.Tensor]: # Create channelwise buffer weight_scale_channel = torch.empty((sum(logical_widths), 1), dtype=torch.float32, @@ -99,7 +99,7 @@ def convert_to_channelwise( def requantize_with_max_scale( weight: torch.Tensor, weight_scale: torch.Tensor, - logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]: + logical_widths: list[int]) -> tuple[torch.Tensor, torch.Tensor]: # Max scale to be used for requanitzation. max_w_scale = weight_scale.max() @@ -136,7 +136,7 @@ def maybe_create_device_identity(): def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype, scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, - output_shape: List, **kwargs) -> torch.Tensor: + output_shape: list, **kwargs) -> torch.Tensor: # Fused GEMM_DQ output = ops.cutlass_scaled_mm(qinput, @@ -154,7 +154,7 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, input_2d: torch.Tensor, - output_shape: List) -> torch.Tensor: + output_shape: list) -> torch.Tensor: from vllm.platforms.rocm import on_mi250_mi300 if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300( ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0: @@ -177,7 +177,7 @@ def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, input_2d: torch.Tensor, - output_shape: List) -> torch.Tensor: + output_shape: list) -> torch.Tensor: output = torch._scaled_mm(qinput, weight, out_dtype=out_dtype, @@ -198,7 +198,7 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, input_2d: torch.Tensor, - output_shape: List) -> torch.Tensor: + output_shape: list) -> torch.Tensor: 
# Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM # when using it. # For now it has only been validated on ROCm platform. @@ -228,7 +228,7 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, input_2d: torch.Tensor, - output_shape: List, + output_shape: list, **kwargs) -> torch.Tensor: # Use unfused DQ due to limitations with scaled_mm @@ -384,7 +384,7 @@ def normalize_e4m3fn_to_e4m3fnuz( weight: torch.Tensor, weight_scale: torch.Tensor, input_scale: Optional[torch.Tensor] = None -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: assert weight.dtype == torch.float8_e4m3fn # The bits pattern 10000000(-128) represents zero in e4m3fn # but NaN in e4m3fnuz. So here we set it to 0. diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index d1d3326ac3f2..3db73495827c 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -2,7 +2,7 @@ from functools import cached_property from importlib.util import find_spec -from typing import Dict, Optional, Tuple +from typing import Optional import torch import torch.jit @@ -65,7 +65,7 @@ def forward( bonus_token_ids: torch.Tensor, draft_probs: torch.Tensor, draft_token_ids: torch.Tensor, - seeded_seqs: Optional[Dict[int, torch.Generator]] = None, + seeded_seqs: Optional[dict[int, torch.Generator]] = None, ) -> torch.Tensor: """Sample token ids using rejection sampling. 
This accepts or rejects tokens proposed by the draft model using the probability of each token @@ -123,12 +123,13 @@ def forward( # for rejection sampling if self.use_flashinfer and chain_speculative_sampling is not None: batch_size, k, _ = draft_probs.shape - uniform_samples = self._create_uniform_samples( - seeded_seqs, batch_size, k, draft_probs.device) - output_token_ids, accepted_token_num, emitted_token_num \ - = chain_speculative_sampling( - draft_probs, draft_token_ids, uniform_samples, - target_with_bonus_probs) + + (output_token_ids, accepted_token_num, + emitted_token_num) = chain_speculative_sampling( + draft_probs, + draft_token_ids, + target_with_bonus_probs, + ) # num_emitted_tokens returned by flashinfer # does not include the bonus token @@ -161,8 +162,8 @@ def _batch_modified_rejection_sampling( target_probs: torch.Tensor, # [batch_size, k, vocab_size] draft_probs: torch.Tensor, # [batch_size, k, vocab_size] draft_token_ids: torch.Tensor, # [batch_size, k] - seeded_seqs: Optional[Dict[int, torch.Generator]], - ) -> Tuple[torch.Tensor, torch.Tensor]: + seeded_seqs: Optional[dict[int, torch.Generator]], + ) -> tuple[torch.Tensor, torch.Tensor]: """Perform modified rejection sampling on each sequence. Returns: @@ -194,7 +195,7 @@ def _batch_modified_rejection_sampling( return accepted, recovered_token_ids def _create_uniform_samples(self, - seeded_seqs: Optional[Dict[int, + seeded_seqs: Optional[dict[int, torch.Generator]], batch_size: int, k: int, device: torch.device) -> torch.Tensor: @@ -210,7 +211,7 @@ def _create_uniform_samples(self, a seed. Args: - seeded_seqs : Optional[Dict[int, torch.Generator]] + seeded_seqs : Optional[dict[int, torch.Generator]] A dictionary mapping indices in the batch to `torch.Generator` objects. If `None`, all samples are generated without a seed. 
@@ -255,22 +256,22 @@ def _get_accepted( target_probs: torch.Tensor, # [batch_size, k, vocab_size] draft_probs: torch.Tensor, # [batch_size, k, vocab_size] draft_token_ids: torch.Tensor, # [batch_size, k] - seeded_seqs: Optional[Dict[int, torch.Generator]], + seeded_seqs: Optional[dict[int, torch.Generator]], ) -> torch.Tensor: r"""Create bool matrix over the proposed draft tokens. If True, then a token can be accepted, else it should be rejected. - Given {math}`q(\hat{x}_{n+1}|x_1, \dots, x_n)`, the probability of - {math}`\hat{x}_{n+1}` given context {math}`x_1, \dots, x_n` according - to the target model, and {math}`p(\hat{x}_{n+1}|x_1, \dots, x_n)`, the + Given $q(\hat{x}_{n+1}|x_1, \dots, x_n)$, the probability of + $\hat{x}_{n+1}$ given context $x_1, \dots, x_n$ according + to the target model, and $p(\hat{x}_{n+1}|x_1, \dots, x_n)$, the same conditional probability according to the draft model, the token is accepted with probability: - :::{math} + $$ \min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)} {p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right) - ::: + $$ This implementation does not apply causality. When using the output, if a token is rejected, subsequent tokens should not be used. @@ -313,30 +314,31 @@ def _get_recovered_probs( target model is recovered (within hardware numerics). The probability distribution used in this rejection case is constructed - as follows. Given {math}`q(x|x_1, \dots, x_n)`, the probability of - {math}`x` given context {math}`x_1, \dots, x_n` according to the target - model and {math}`p(x|x_1, \dots, x_n)`, the same conditional probability + as follows. 
Given $q(x|x_1, \dots, x_n)$, the probability of + $x$ given context $x_1, \dots, x_n$ according to the target + model and $p(x|x_1, \dots, x_n)$, the same conditional probability according to the draft model: - :::{math} + $$ x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+ - ::: + $$ - where {math}`(f(x))_+` is defined as: + where $(f(x))_+$ is defined as: - :::{math} + $$ (f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))} - ::: + $$ See https://github.com/vllm-project/vllm/pull/2336 for a visualization of the draft, target, and recovered probability distributions. Returns a tensor of shape [batch_size, k, vocab_size]. - Note: This batches operations on GPU and thus constructs the recovered - distribution for all tokens, even if they are accepted. This causes - division-by-zero errors, so we use self._smallest_positive_value to - avoid that. This introduces some drift to the distribution. + Note: + This batches operations on GPU and thus constructs the recovered + distribution for all tokens, even if they are accepted. This causes + division-by-zero errors, so we use self._smallest_positive_value to + avoid that. This introduces some drift to the distribution. 
""" _, k, _ = draft_probs.shape @@ -379,7 +381,7 @@ def _multinomial( probs: torch.Tensor, num_samples: int, k: int, - seeded_seqs: Dict[int, torch.Generator], + seeded_seqs: dict[int, torch.Generator], ) -> torch.Tensor: if num_samples > 1: diff --git a/vllm/model_executor/layers/resampler.py b/vllm/model_executor/layers/resampler.py index 4c9860006c32..839688e313aa 100644 --- a/vllm/model_executor/layers/resampler.py +++ b/vllm/model_executor/layers/resampler.py @@ -33,7 +33,7 @@ """ import math from functools import partial -from typing import Callable, Optional, Tuple, Union +from typing import Callable, Optional, Union import numpy as np import torch @@ -69,7 +69,7 @@ def get_abs_pos(abs_pos: torch.Tensor, tgt_size: Union[torch.Tensor, # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 def get_1d_sincos_pos_embed_from_grid( embed_dim: int, pos: np.ndarray, - version: Tuple[int, int] = (2, 0)) -> torch.Tensor: + version: tuple[int, int] = (2, 0)) -> torch.Tensor: """ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) / (H, W) @@ -96,7 +96,7 @@ def get_1d_sincos_pos_embed_from_grid( def get_2d_sincos_pos_embed_from_grid( embed_dim: int, grid: np.ndarray, - version: Tuple[int, int] = (2, 0)) -> torch.Tensor: + version: tuple[int, int] = (2, 0)) -> torch.Tensor: assert embed_dim % 2 == 0 # use half of dimensions to encode grid_h @@ -114,9 +114,9 @@ def get_2d_sincos_pos_embed_from_grid( def get_2d_sincos_pos_embed( embed_dim: int, - grid_size: Union[int, Tuple[int, int]], + grid_size: Union[int, tuple[int, int]], cls_token: bool = False, - version: Tuple[int, int] = (2, 0), + version: tuple[int, int] = (2, 0), ) -> torch.Tensor: """ grid_size: int of the grid height and width diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 32c2a2859b49..70463ecd90ae 100644 --- 
a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -23,7 +23,7 @@ # limitations under the License. """Rotary Positional Embeddings.""" import math -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional, Union import torch import torch.nn as nn @@ -140,7 +140,7 @@ def forward_native( query: torch.Tensor, key: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """A PyTorch-native implementation of forward().""" if offsets is not None: positions = positions + offsets @@ -174,7 +174,7 @@ def forward_cuda( query: torch.Tensor, key: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: from vllm import _custom_ops as ops # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`) @@ -202,7 +202,7 @@ def forward_xpu( query: torch.Tensor, key: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: from vllm._ipex_ops import ipex_ops as ops self.cos_sin_cache = self.cos_sin_cache.to(positions.device, @@ -232,7 +232,7 @@ def forward_hpu( query: torch.Tensor, key: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: from habana_frameworks.torch.hpex.kernels import ( RotaryPosEmbeddingMode, apply_rotary_pos_emb) if offsets is not None: @@ -290,7 +290,7 @@ def forward_neuron( query: torch.Tensor, key: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: def _apply_rotary_emb_neuron( 
x: torch.Tensor, @@ -406,23 +406,23 @@ def __init__( max_position_embeddings: int, base: int, is_neox_style: bool, - scaling_factors: Union[List[float], float], + scaling_factors: Union[list[float], float], dtype: torch.dtype, ) -> None: if isinstance(scaling_factors, float): scaling_factors = [scaling_factors] - self.scaling_factors: List[float] = scaling_factors # noqa + self.scaling_factors: list[float] = scaling_factors # noqa super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype) # Lazy initialized. - self._scaling_factor_to_offset: Dict[float, int] + self._scaling_factor_to_offset: dict[float, int] def _compute_cos_sin_cache(self) -> torch.Tensor: inv_freq = self._compute_inv_freq(self.base) - cache_list: List[torch.Tensor] = [] + cache_list: list[torch.Tensor] = [] # offsets to the next cache in a tensor. # Each offset corresponds to the same index in scaling_factors. - offsets: List[int] = [] + offsets: list[int] = [] for scaling_factor in self.scaling_factors: # NOTE(woosuk): self.max_position_embeddings is the original # maximum length before applying the rope scaling. @@ -452,10 +452,44 @@ def _compute_cos_sin_cache(self) -> torch.Tensor: return torch.cat(cache_list, dim=0) @property - def scaling_factor_to_offset(self) -> Dict[float, int]: + def scaling_factor_to_offset(self) -> dict[float, int]: return self._scaling_factor_to_offset +class NTKScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with fixed and mixed NTK scaling. 
+ https://kexue.fm/archives/9706 """ + + def __init__(self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + mixed_b: Optional[float] = None) -> None: + self.scaling_factor = scaling_factor + self.mixed_b = mixed_b + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + base = self.base * (self.scaling_factor if self.mixed_b is None else 1) + inv_freq = super()._compute_inv_freq(base) + + if self.mixed_b is None: + inv_freq = inv_freq / self.scaling_factor**(2 / self.rotary_dim) + else: + a = torch.tensor(self.scaling_factor).log() / (self.rotary_dim / + 2)**self.mixed_b + lambda_1_m = (a * torch.arange( + 1, self.rotary_dim // 2 + 1).float()**self.mixed_b).exp() + inv_freq = inv_freq / lambda_1_m + + return inv_freq + + class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): """RotaryEmbedding extended with Dynamic NTK scaling. 
@@ -512,7 +546,7 @@ def _yarn_find_correction_range( high_rot: int, dim: int, base: float = 10000, - max_position_embeddings: int = 2048) -> Tuple[int, int]: + max_position_embeddings: int = 2048) -> tuple[int, int]: low = math.floor( _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) high = math.ceil( @@ -613,8 +647,8 @@ def __init__( base: int, is_neox_style: bool, dtype: torch.dtype, - short_factor: List[float], - long_factor: List[float], + short_factor: list[float], + long_factor: list[float], short_mscale: Optional[float] = None, long_mscale: Optional[float] = None, ): @@ -662,7 +696,7 @@ def __init__( long_short_cache, persistent=False) - def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor: + def _compute_inv_freq(self, rescale_factors: list[float]) -> torch.Tensor: rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32) inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange( 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))) @@ -671,7 +705,7 @@ def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor: def _compute_cos_sin_cache( self, max_position_embeddings: int, - rescale_factors: List[float], + rescale_factors: list[float], mscale: float, ) -> torch.Tensor: inv_freq = self._compute_inv_freq(rescale_factors) @@ -688,7 +722,7 @@ def forward( query: torch.Tensor, key: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: assert key is not None query = query.view(*query.shape[:-1], -1, self.head_size) key = key.view(*key.shape[:-1], -1, self.head_size) @@ -799,7 +833,7 @@ def forward( query: torch.Tensor, key: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """PyTorch-native implementation equivalent to forward().""" assert 
key is not None query_rot = query[..., :self.rotary_dim] @@ -808,8 +842,9 @@ def forward( query_pass = query[..., self.rotary_dim:] key_pass = key[..., self.rotary_dim:] - self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to( - positions.device) + if self.cos_sin_cache.device != positions.device: + self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to( + positions.device) cos_sin = self.cos_sin_cache[torch.add(positions, offsets) if offsets is not None else positions] cos, sin = cos_sin.chunk(2, dim=-1) @@ -929,7 +964,7 @@ def forward( self, query: torch.Tensor, key: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: assert key is not None self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device) query_ = torch.view_as_complex(query.float().reshape( @@ -957,7 +992,7 @@ def __init__( base: int, is_neox_style: bool, dtype: torch.dtype, - mrope_section: Optional[List[int]] = None, + mrope_section: Optional[list[int]] = None, ) -> None: # In Qwen2.5-VL, the maximum index value is related to the duration of # the input video. We enlarge max_position_embeddings to 4 times to get @@ -975,7 +1010,7 @@ def forward( positions: torch.Tensor, query: torch.Tensor, key: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """PyTorch-native implementation equivalent to forward(). 
Args: @@ -1023,16 +1058,16 @@ def forward( @classmethod def get_input_positions( cls, - input_tokens: List[int], + input_tokens: list[int], hf_config: PretrainedConfig, - image_grid_thw: Optional[Union[List[List[int]], torch.Tensor]], - video_grid_thw: Optional[Union[List[List[int]], torch.Tensor]], - second_per_grid_ts: Optional[List[float]], + image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]], + video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]], + second_per_grid_ts: Optional[list[float]], context_len: int = 0, seq_len: Optional[int] = None, audio_feature_lengths: Optional[torch.Tensor] = None, use_audio_in_video: bool = False, - ) -> Tuple[List[List[int]], int]: + ) -> tuple[list[list[int]], int]: """Get mrope input positions and delta value.""" image_grid_thw = [] if image_grid_thw is None else image_grid_thw @@ -1058,16 +1093,16 @@ def get_input_positions( @classmethod def get_input_positions_tensor( cls, - input_tokens: List[int], + input_tokens: list[int], hf_config: PretrainedConfig, - image_grid_thw: Union[List[List[int]], torch.Tensor], - video_grid_thw: Union[List[List[int]], torch.Tensor], - second_per_grid_ts: List[float], + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + second_per_grid_ts: list[float], context_len: int = 0, seq_len: Optional[int] = None, audio_feature_lengths: Optional[torch.Tensor] = None, use_audio_in_video: bool = False, - ) -> Tuple[torch.Tensor, int]: + ) -> tuple[torch.Tensor, int]: from vllm.transformers_utils.config import thinker_uses_mrope if thinker_uses_mrope(hf_config): return cls._omni_get_input_positions_tensor( @@ -1095,14 +1130,14 @@ def get_input_positions_tensor( @classmethod def _vl_get_input_positions_tensor( cls, - input_tokens: List[int], + input_tokens: list[int], hf_config: PretrainedConfig, - image_grid_thw: Union[List[List[int]], torch.Tensor], - video_grid_thw: Union[List[List[int]], torch.Tensor], - 
second_per_grid_ts: List[float], + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + second_per_grid_ts: list[float], context_len: int = 0, seq_len: Optional[int] = None, - ) -> Tuple[torch.Tensor, int]: + ) -> tuple[torch.Tensor, int]: """Get mrope input positions and delta value.""" image_token_id = hf_config.image_token_id @@ -1194,16 +1229,16 @@ def _vl_get_input_positions_tensor( @classmethod def _omni_get_input_positions_tensor( cls, - input_tokens: List[int], + input_tokens: list[int], hf_config: PretrainedConfig, - image_grid_thw: Union[List[List[int]], torch.Tensor], - video_grid_thw: Union[List[List[int]], torch.Tensor], - second_per_grid_ts: Optional[List[float]] = None, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + second_per_grid_ts: Optional[list[float]] = None, context_len: int = 0, seq_len: Optional[int] = None, audio_feature_lengths: Optional[torch.Tensor] = None, use_audio_in_video: bool = False, - ) -> Tuple[torch.Tensor, int]: + ) -> tuple[torch.Tensor, int]: """Get mrope input positions and delta value (Qwen2.5-Omni version). 
Differences from MRotaryEmbedding: @@ -1328,7 +1363,7 @@ def _omni_get_input_positions_tensor( place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2 pure_audio_len = place_num - 2 added_audio_len = 0 - audio_llm_pos_ids_list: List[torch.Tensor] = [] + audio_llm_pos_ids_list: list[torch.Tensor] = [] for t_chunk in t_index_split_chunk: vision_ntoken_per_chunk = len( t_chunk) * grid_h * grid_w // (spatial_merge_size**2) @@ -1381,7 +1416,7 @@ def _get_llm_pos_ids_for_vision( start_idx: int, vision_idx: int, spatial_merge_size: int, - t_index: List[int], + t_index: list[int], grid_hs: torch.Tensor, grid_ws: torch.Tensor, ) -> torch.Tensor: @@ -1401,8 +1436,8 @@ def _get_llm_pos_ids_for_vision( @staticmethod def _split_list_into_ranges(lst: torch.Tensor, - interval: int) -> List[List[int]]: - ranges: List[List[int]] = [[] + interval: int) -> list[list[int]]: + ranges: list[list[int]] = [[] for _ in range((max(lst) // interval) + 1)] for num in lst: index = num // interval @@ -1414,7 +1449,7 @@ def get_next_input_positions( mrope_position_delta: int, context_len: int, seq_len: int, - ) -> List[List[int]]: + ) -> list[list[int]]: return [ list( range(context_len + mrope_position_delta, @@ -1437,9 +1472,9 @@ def omni_get_updates_use_audio_in_video( cls, thinker_config: PretrainedConfig, audio_len: int, - video_grid_thw: Union[List[int], torch.Tensor], + video_grid_thw: Union[list[int], torch.Tensor], video_second_per_grid_t: float, - ) -> List[int]: + ) -> list[int]: """Get video prompt updates when `use_audio_in_video` is True. 
In this case, audio and vision update ids will be split into @@ -1485,7 +1520,185 @@ def omni_get_updates_use_audio_in_video( return updates -_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} +@CustomOp.register("dual_chunk_rotary_embedding") +class DualChunkRotaryEmbedding(CustomOp): + """Rotary positional embedding for Dual Chunk Attention.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + chunk_size: int, + local_size: int, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.chunk_size = chunk_size + self.local_size = local_size + self.dtype = dtype + self.device = torch.device(f"cuda:{torch.cuda.current_device()}") + (q_cache, qc_cache, k_cache, qc_no_clamp_cache, + q_inter_cache) = self._compute_cos_sin_cache() + + self.register_buffer("cos_sin_q_cache", q_cache, persistent=False) + self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False) + self.register_buffer("cos_sin_k_cache", k_cache, persistent=False) + self.register_buffer("cos_sin_qc_no_clamp_cache", + qc_no_clamp_cache, + persistent=False) + self.register_buffer("cos_sin_q_inter_cache", + q_inter_cache, + persistent=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`. + # However, we use `torch.arange(..., dtype=torch.float)` instead to + # avoid numerical issues with large base values (e.g., 10000000). + # This may cause a slight numerical difference between the HF + # implementation and ours. + # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. 
This may cause + # a slight numerical difference between the HF implementation and ours. + inv_freq = 1.0 / (base**(torch.arange( + 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + chunk_len = self.chunk_size - self.local_size + q_t = torch.arange(chunk_len, dtype=torch.float) + qc_t = (torch.arange(chunk_len, dtype=torch.float) + + chunk_len).clamp(max=self.chunk_size) + k_t = torch.arange(self.max_position_embeddings, + dtype=torch.float) % chunk_len + + # count from chunk_len, no clamp(self.chunk_size) restriction + qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len + # count from self.chunk_size for q_inter's rope + q_inter_t = torch.arange(chunk_len, + dtype=torch.float) + self.chunk_size + + q_freqs = torch.outer(q_t, inv_freq) + qc_freqs = torch.outer(qc_t, inv_freq) + k_freqs = torch.outer(k_t, inv_freq) + qc_no_clamp_freqs = torch.outer(qc_no_clamp_t, inv_freq) + q_inter_freqs = torch.outer(q_inter_t, inv_freq) + + q_cos = q_freqs.cos() + q_sin = q_freqs.sin() + qc_cos = qc_freqs.cos() + qc_sin = qc_freqs.sin() + k_cos = k_freqs.cos() + k_sin = k_freqs.sin() + + qc_no_clamp_cos = qc_no_clamp_freqs.cos() + qc_no_clamp_sin = qc_no_clamp_freqs.sin() + q_inter_cos = q_inter_freqs.cos() + q_inter_sin = q_inter_freqs.sin() + + q_cache = torch.cat((q_cos, q_sin), dim=-1).to(dtype=self.dtype, + device=self.device) + qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to(dtype=self.dtype, + device=self.device) + k_cache = torch.cat((k_cos, k_sin), dim=-1).to(dtype=self.dtype, + device=self.device) + qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin), + dim=-1).to(dtype=self.dtype, + device=self.device) + q_inter_cache = torch.cat((q_inter_cos, q_inter_sin), + dim=-1).to(dtype=self.dtype, + device=self.device) + return q_cache, qc_cache, k_cache, qc_no_clamp_cache, 
q_inter_cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + query = query.view(*query.shape[:-1], -1, self.head_size) + key = key.view(*key.shape[:-1], -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + key_rot = key[..., :self.rotary_dim] + if self.rotary_dim < self.head_size: + query_pass = query[..., self.rotary_dim:] + key_pass = key[..., self.rotary_dim:] + else: + query_pass = None + key_pass = None + + positions_with_offsets = (torch.add(positions, offsets) + if offsets is not None else positions) + key = self._apply_rotary_embedding( + self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass) + chunk_len = self.chunk_size - self.local_size + query = self._apply_rotary_embedding( + self.cos_sin_q_cache[positions_with_offsets % chunk_len], + query_rot, query_pass) + query_succ = self._apply_rotary_embedding( + self.cos_sin_qc_cache[positions_with_offsets % chunk_len], + query_rot, query_pass) + query_inter = self._apply_rotary_embedding( + self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1), + query_rot, query_pass) + query_succ_critical = self._apply_rotary_embedding( + self.cos_sin_qc_no_clamp_cache[positions_with_offsets % chunk_len], + query_rot, query_pass) + query_inter_critical = self._apply_rotary_embedding( + self.cos_sin_q_inter_cache[positions_with_offsets % chunk_len], + query_rot, query_pass) + + # merge query into one tensor to simplify the interfaces + query = torch.cat(( + query, + query_succ, + query_inter, + query_succ_critical, + query_inter_critical, + ), + dim=-1) + return query, key + + def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass): + cos, sin = cos_sin.chunk(2, dim=-1) + if self.is_neox_style: + # NOTE(woosuk): Here we assume that the positions tensor has the + # shape [batch_size, seq_len]. 
+ cos = cos.repeat(1, 1, 2).unsqueeze(-2) + sin = sin.repeat(1, 1, 2).unsqueeze(-2) + else: + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) + rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj + hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin + + if self.rotary_dim < self.head_size: + hidden = torch.cat((hidden_rot, hidden_pass), dim=-1) + else: + hidden = hidden_rot + return hidden.flatten(-2).squeeze(0) + + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + s += f", chunk_size={self.chunk_size}, local_size={self.local_size}" + return s + + +_ROPE_DICT: dict[tuple, RotaryEmbedding] = {} def get_rope( @@ -1494,9 +1707,10 @@ def get_rope( max_position: int, base: int, is_neox_style: bool = True, - rope_scaling: Optional[Dict[str, Any]] = None, + rope_scaling: Optional[dict[str, Any]] = None, dtype: Optional[torch.dtype] = None, partial_rotary_factor: float = 1.0, + dual_chunk_attention_config: Optional[dict[str, Any]] = None, ) -> RotaryEmbedding: if dtype is None: dtype = torch.get_default_dtype() @@ -1509,14 +1723,35 @@ def get_rope( rope_scaling_args = tuple(rope_scaling_tuple.items()) else: rope_scaling_args = None + + if dual_chunk_attention_config is not None: + dual_chunk_attention_tuple = { + k: tuple(v) if isinstance(v, list) else v + for k, v in dual_chunk_attention_config.items() + if k != "sparse_attention_config" + } + dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items()) + else: + dual_chunk_attention_args = None + if partial_rotary_factor < 1.0: rotary_dim = int(rotary_dim * partial_rotary_factor) key = (head_size, rotary_dim, max_position, base, is_neox_style, - rope_scaling_args, dtype) + rope_scaling_args, dual_chunk_attention_args, dtype) if key in _ROPE_DICT: return 
_ROPE_DICT[key] - if not rope_scaling: + if dual_chunk_attention_config is not None: + extra_kwargs = { + k: v + for k, v in dual_chunk_attention_config.items() + if k in ("chunk_size", "local_size") + } + rotary_emb = DualChunkRotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, dtype, + **extra_kwargs) + elif not rope_scaling: rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, is_neox_style, dtype) else: @@ -1564,6 +1799,14 @@ def get_rope( max_position, base, is_neox_style, scaling_factor, dtype) + elif scaling_type == "ntk": + scaling_factor = rope_scaling["factor"] + mixed_b = rope_scaling.get('mixed_b', None) + rotary_emb = NTKScalingRotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, + scaling_factor, dtype, + mixed_b) elif scaling_type == "dynamic": scaling_factor = rope_scaling["factor"] rotary_emb = DynamicNTKScalingRotaryEmbedding( diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 920c0f5a6ec9..32375db0c8f1 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 """A layer that samples the next tokens from the model's outputs.""" import itertools -import warnings +from collections.abc import Iterator from dataclasses import dataclass from importlib.util import find_spec from math import inf -from typing import Dict, Iterator, List, Optional, Tuple, Union +from typing import Optional, Union import msgspec import torch @@ -23,7 +23,6 @@ from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): - import flashinfer.sampling # yapf: disable from flashinfer.sampling import ( top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling) @@ -32,6 +31,10 @@ else: flashinfer_top_k_top_p_sampling = None +from vllm.logger import init_logger + +logger = init_logger(__name__) + def 
get_sampler() -> torch.nn.Module: if envs.VLLM_USE_V1: @@ -42,14 +45,14 @@ def get_sampler() -> torch.nn.Module: # (num_token_ids, num_parent_ids) per sequence group. -SampleResultType = List[Tuple[List[int], List[int]]] +SampleResultType = list[tuple[list[int], list[int]]] # Types of temporary data structures used for # computing sample_result -SampleMetadataType = Dict[SamplingType, Tuple[List[int], - List[SequenceGroupToSample]]] -MultinomialSamplesType = Dict[SamplingType, torch.Tensor] -SampleResultsDictType = Dict[int, Tuple[List[int], List[int]]] +SampleMetadataType = dict[SamplingType, tuple[list[int], + list[SequenceGroupToSample]]] +MultinomialSamplesType = dict[SamplingType, torch.Tensor] +SampleResultsDictType = dict[int, tuple[list[int], list[int]]] # Encapsulates temporary data structures for computing @@ -76,7 +79,7 @@ class SampleResultArgsType: MaybeDeferredSampleResultType = Union[SampleResultType, SampleResultArgsType] # Abbreviation of the _sample() return type -SampleReturnType = Tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]] +SampleReturnType = tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]] class SamplerOutput( @@ -90,7 +93,7 @@ class SamplerOutput( also has optional fields for device tensors. """ - outputs: List[CompletionSequenceGroupOutput] + outputs: list[CompletionSequenceGroupOutput] # On-device tensor containing probabilities of each token. 
sampled_token_probs: Optional[torch.Tensor] = None @@ -225,17 +228,19 @@ def forward( ) -> Optional[SamplerOutput]: """ Single-step scheduling: - * Perform GPU-side sampling computation & compute - GPU-side logprobs tensor - * Pythonize sampling result & logprobs tensor + * Perform GPU-side sampling computation & compute + GPU-side logprobs tensor + * Pythonize sampling result & logprobs tensor Multi-step scheduling: - * Perform GPU-side sampling computation & compute - GPU-side logprobs tensor - * Defer Pythonization of sampling result & logprobs - tensor - * Encapsulate arguments required for deferred Pythonization - in the {class}`SamplerOutput` structure + * Perform GPU-side sampling computation & compute + GPU-side logprobs tensor + * Defer Pythonization of sampling result & logprobs + tensor + * Encapsulate arguments required for deferred Pythonization + in the + [`SamplerOutput`][vllm.model_executor.layers.sampler.SamplerOutput] + structure Args: logits: (num_tokens, vocab_size). @@ -350,7 +355,7 @@ def _apply_min_tokens_penalty( have not been generated yet """ # list of indices in logits that will be set to -inf - logits_to_penalize: List[Tuple[int, int]] = [] + logits_to_penalize: list[tuple[int, int]] = [] logits_applied = 0 for seq_group in sampling_metadata.seq_groups: seq_ids = seq_group.seq_ids @@ -366,7 +371,7 @@ def _apply_min_tokens_penalty( min_tokens = sampling_params.min_tokens token_ids_to_penalize = sampling_params.all_stop_token_ids if min_tokens > 0 and token_ids_to_penalize: - seqs_to_penalize: List[int] = [] + seqs_to_penalize: list[int] = [] for j, seq_id in enumerate(seq_ids): seq_data = seq_group.seq_data[seq_id] if len(seq_data.output_token_ids_array) < min_tokens: @@ -436,7 +441,7 @@ def _apply_min_p( def _greedy_sample( - selected_seq_groups: List[SequenceGroupToSample], + selected_seq_groups: list[SequenceGroupToSample], samples: torch.Tensor, ) -> SampleResultType: """Run greedy sampling on a given samples. 
@@ -471,7 +476,7 @@ def _greedy_sample( def _random_sample( - selected_seq_groups: List[SequenceGroupToSample], + selected_seq_groups: list[SequenceGroupToSample], random_samples: torch.Tensor, ) -> SampleResultType: """Run random sampling on a given samples. @@ -522,7 +527,7 @@ def _random_sample( def _multinomial( probs: torch.Tensor, num_samples: int, - seq_groups: Optional[List[SequenceGroupToSample]] = None, + seq_groups: Optional[list[SequenceGroupToSample]] = None, ) -> torch.Tensor: if num_samples > 1: probs = probs.repeat_interleave(num_samples, dim=0) @@ -543,39 +548,16 @@ def _multinomial( def _top_k_top_p_multinomial_with_flashinfer( probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor, - num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]): - max_top_k_round = 32 + num_samples: int, seq_groups: Optional[list[SequenceGroupToSample]]): if num_samples > 1: probs = probs.repeat_interleave(num_samples, dim=0) top_ks = top_ks.repeat_interleave(num_samples) top_ps = top_ps.repeat_interleave(num_samples) - batch_size = probs.shape[0] - uniform_samples = torch.empty((max_top_k_round, batch_size), - device=probs.device) - if seq_groups is None: - uniform_samples.uniform_() - else: - sample_idx = 0 - for seq_group in seq_groups: - seq_ids = seq_group.seq_ids - stride = len(seq_ids) * num_samples - assert seq_group.generator is not None - uniform_samples[:, sample_idx:sample_idx + - stride].uniform_(generator=seq_group.generator) - sample_idx += stride - batch_next_token_ids, success = flashinfer_top_k_top_p_sampling( + batch_next_token_ids = flashinfer_top_k_top_p_sampling( probs, - uniform_samples, top_ks, top_ps, ) - if not success.all(): - warnings.warn("FlashInfer rejection sampling failed, fallback.", - stacklevel=1) - probs = flashinfer.sampling.top_k_renorm_prob(probs, top_ks) - probs = flashinfer.sampling.top_p_renorm_prob(probs, top_ps) - batch_next_token_ids = flashinfer.sampling.sampling_from_probs( - probs, 
uniform_samples[0]) return batch_next_token_ids.view(-1, num_samples) @@ -648,7 +630,7 @@ def _sample_with_torch( tensors required for Pythonization ''' - categorized_seq_group_ids: Dict[SamplingType, List[int]] = { + categorized_seq_group_ids: dict[SamplingType, list[int]] = { t: [] for t in SamplingType } @@ -711,19 +693,14 @@ def _sample_with_torch( seq_groups) if flashinfer_top_k_top_p_sampling is not None: - multinomial_samples[ - sampling_type] = _top_k_top_p_multinomial_with_flashinfer( - probs[long_sample_indices], - sampling_tensors.top_ks[long_sample_indices], - sampling_tensors.top_ps[long_sample_indices], - max_n_in_batch, - seq_groups_arg, - ) - else: - multinomial_samples[sampling_type] = _multinomial( - probs[long_sample_indices], - max_n_in_batch, - seq_groups=seq_groups_arg) + logger.warning("FlashInfer 0.2.3+ does not support " + "per-request generators. Falling back to " + "PyTorch-native implementation.") + + multinomial_samples[sampling_type] = _multinomial( + probs[long_sample_indices], + max_n_in_batch, + seq_groups=seq_groups_arg) if sampled_token_ids_tensor is not None: # Store sampled tokens in output tensor. @@ -812,7 +789,7 @@ def get_logprobs( logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, sample_results: SampleResultType, -) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]: +) -> tuple[list[Optional[PromptLogprobs]], list[SampleLogprobs]]: """Return sample logprobs and prompt logprobs. The logic consists of 3 parts. @@ -841,9 +818,9 @@ def get_logprobs( """ # The index of query token to calculate logprobs. It includes both # prompt and sample logprob indices. - query_indices: List[int] = [] + query_indices: list[int] = [] # The next token ids to get the logprob value from. - next_token_ids: List[int] = [] + next_token_ids: list[int] = [] # The largest requested number of logprobs. We find logprobs as many as the # largest num logprobs in this API. If every logprobs is None, it will be # set to -1. 
@@ -925,8 +902,8 @@ def get_logprobs( ranks = ranks.to('cpu') # Find prompt/sample logprobs. - prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = [] - sample_logprobs_per_seq_group: List[SampleLogprobs] = [] + prompt_logprobs_per_seq_group: list[Optional[PromptLogprobs]] = [] + sample_logprobs_per_seq_group: list[SampleLogprobs] = [] top_logprob_idx = 0 selected_logprobs_idx = 0 @@ -977,7 +954,7 @@ def _get_prompt_logprob_if_needed( for idx, token_id in enumerate(next_prompt_tokens): # Calculate the prompt logprob of the real prompt tokens. # {token_id: (logprob, rank_from_vocab)} - prompt_logprobs_dict: Dict[int, Tuple[float, int]] = { + prompt_logprobs_dict: dict[int, tuple[float, int]] = { token_id: (selected_logprob_items[idx], rank_items[idx]) } @@ -1009,7 +986,7 @@ def _get_prompt_logprob_if_needed( def _get_sampled_logprob_if_needed( seq_group: SequenceGroupToSample, - sample_result: Tuple[List[int], List[int]], + sample_result: tuple[list[int], list[int]], selected_logprobs: torch.Tensor, ranks: torch.Tensor, top_token_ids: torch.Tensor, @@ -1130,9 +1107,9 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, def _build_sampler_output( maybe_deferred_sample_results: MaybeDeferredSampleResultType, sampling_metadata: SamplingMetadata, - prompt_logprobs: Optional[List[Optional[PromptLogprobs]]], - sample_logprobs: Optional[List[SampleLogprobs]], - on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor, + prompt_logprobs: Optional[list[Optional[PromptLogprobs]]], + sample_logprobs: Optional[list[SampleLogprobs]], + on_device_tensors: Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]], skip_sampler_cpu_output: bool = False, ) -> SamplerOutput: @@ -1144,7 +1121,7 @@ def _build_sampler_output( allows post-processing without copies to CPU/serialization, e.g. in speculative decoding rejection sampling. 
""" - sampler_output: List[CompletionSequenceGroupOutput] = [] + sampler_output: list[CompletionSequenceGroupOutput] = [] if skip_sampler_cpu_output: assert isinstance(maybe_deferred_sample_results, SampleResultArgsType) @@ -1166,7 +1143,7 @@ def _build_sampler_output( prompt_logprobs, sample_logprobs): seq_ids = seq_group.seq_ids next_token_ids, parent_ids = sample_result - seq_outputs: List[SequenceOutput] = [] + seq_outputs: list[SequenceOutput] = [] for parent_id, next_token_id, logprobs in zip( parent_ids, next_token_ids, group_sample_logprobs): seq_outputs.append( diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py index 54fd43fc6592..969cd59b57cc 100644 --- a/vllm/model_executor/layers/spec_decode_base_sampler.py +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from abc import abstractmethod -from typing import Dict, Optional, Union +from typing import Optional, Union import torch import torch.jit @@ -253,6 +253,6 @@ def forward( bonus_token_ids: torch.Tensor, draft_probs: torch.Tensor, draft_token_ids: torch.Tensor, - seeded_seqs: Optional[Dict[int, torch.Generator]] = None, + seeded_seqs: Optional[dict[int, torch.Generator]] = None, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index 527a301cd8e2..a14c86148e73 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -93,29 +93,27 @@ def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): Evaluates and returns a mask of accepted tokens based on the posterior probabilities. 
- Parameters: - ---------- - target_probs : torch.Tensor - A tensor of shape (batch_size, k, vocab_size) representing - the probabilities of each token in the vocabulary for each - position in the proposed sequence. This is the distribution - generated by the target model. - draft_token_ids : torch.Tensor - A tensor of shape (batch_size, k) representing the proposed - token ids. + Args: + target_probs (torch.Tensor): A tensor of shape + (batch_size, k, vocab_size) representing the probabilities of + each token in the vocabulary for each position in the proposed + sequence. This is the distribution generated by the target + model. + draft_token_ids (torch.Tensor): A tensor of shape (batch_size, k) + representing the proposed token ids. A draft token_id x_{n+k} is accepted if it satisfies the following condition - :::{math} + $$ p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) > \min \left( \epsilon, \delta * \exp \left( -H(p_{\text{original}}( \cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right) - ::: + $$ - where {math}`p_{\text{original}}` corresponds to target_probs - and {math}`\epsilon` and {math}`\delta` correspond to hyperparameters + where $p_{\text{original}}$ corresponds to target_probs + and $\epsilon$ and $\delta$ correspond to hyperparameters specified using self._posterior_threshold and self._posterior_alpha This method computes the posterior probabilities for the given @@ -126,13 +124,10 @@ def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): returns a boolean mask indicating which tokens can be accepted. Returns: - ------- - torch.Tensor - A boolean tensor of shape (batch_size, k) where each element - indicates whether the corresponding draft token has been accepted - or rejected. True indicates acceptance and false indicates - rejection. - + torch.Tensor: A boolean tensor of shape (batch_size, k) where each + element indicates whether the corresponding draft token has + been accepted or rejected. 
True indicates acceptance and false + indicates rejection. """ device = target_probs.device candidates_prob = torch.gather( @@ -156,17 +151,14 @@ def _get_recovered_token_ids(self, target_probs): The recovered token ids will fill the first unmatched token by the target token. - Parameters - ---------- - target_probs : torch.Tensor - A tensor of shape (batch_size, k, vocab_size) containing - the target probability distribution - - Returns - ------- - torch.Tensor - A tensor of shape (batch_size, k) with the recovered token - ids which are selected from target probs. + Args: + target_probs (torch.Tensor): A tensor of shape + (batch_size, k, vocab_size) containing the target probability + distribution. + + Returns: + torch.Tensor: A tensor of shape (batch_size, k) with the recovered + token ids which are selected from target probs. """ max_indices = torch.argmax(target_probs, dim=-1) diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 751b86787c7b..18783d0d7785 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """Utility methods for model layers.""" -from typing import Callable, Optional, Tuple +from typing import Callable, Optional import torch @@ -13,7 +13,7 @@ def get_token_bin_counts_and_mask( tokens: torch.Tensor, vocab_size: int, num_seqs: int, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: # Compute the bin counts for the tokens. # vocab_size + 1 for padding. 
bin_counts = torch.zeros((num_seqs, vocab_size + 1), diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index d5eaeec1ae24..46d2075af99d 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 +from collections.abc import Sequence from dataclasses import dataclass -from typing import List, Optional, Sequence, Tuple +from typing import Optional import torch import torch.nn.functional as F @@ -25,7 +26,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], input_size: int, + output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): """Create weights for embedding layer.""" @@ -141,7 +142,7 @@ def get_masked_input_and_mask( input_: torch.Tensor, org_vocab_start_index: int, org_vocab_end_index: int, num_org_vocab_padding: int, added_vocab_start_index: int, - added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]: + added_vocab_end_index: int) -> tuple[torch.Tensor, torch.Tensor]: # torch.compile will fuse all of the pointwise ops below # into a single kernel, making it very fast org_vocab_mask = (input_ >= org_vocab_start_index) & ( @@ -298,7 +299,7 @@ def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int, org_vocab_start_index, org_vocab_end_index, added_vocab_start_index, added_vocab_end_index) - def get_sharded_to_full_mapping(self) -> Optional[List[int]]: + def get_sharded_to_full_mapping(self) -> Optional[list[int]]: """Get a mapping that can be used to reindex the gathered logits for sampling. 
@@ -312,9 +313,9 @@ def get_sharded_to_full_mapping(self) -> Optional[List[int]]: if self.tp_size < 2: return None - base_embeddings: List[int] = [] - added_embeddings: List[int] = [] - padding: List[int] = [] + base_embeddings: list[int] = [] + added_embeddings: list[int] = [] + padding: list[int] = [] for tp_rank in range(self.tp_size): shard_indices = self._get_indices(self.num_embeddings_padded, self.org_vocab_size_padded, diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py index 92a0b0923b6e..a443a652d8a3 100644 --- a/vllm/model_executor/model_loader/__init__.py +++ b/vllm/model_executor/model_loader/__init__.py @@ -1,8 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 +from typing import Optional + from torch import nn -from vllm.config import LoadConfig, LoadFormat, VllmConfig +from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.bitsandbytes_loader import ( BitsAndBytesModelLoader) @@ -47,9 +49,14 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: return DefaultModelLoader(load_config) -def get_model(*, vllm_config: VllmConfig) -> nn.Module: +def get_model(*, + vllm_config: VllmConfig, + model_config: Optional[ModelConfig] = None) -> nn.Module: loader = get_model_loader(vllm_config.load_config) - return loader.load_model(vllm_config=vllm_config) + if model_config is None: + model_config = vllm_config.model_config + return loader.load_model(vllm_config=vllm_config, + model_config=model_config) __all__ = [ diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py index f17cab05c25d..010dd515784a 100644 --- a/vllm/model_executor/model_loader/base_loader.py +++ b/vllm/model_executor/model_loader/base_loader.py @@ -18,6 +18,7 @@ def download_model(self, model_config: ModelConfig) -> None: raise NotImplementedError 
@abstractmethod - def load_model(self, *, vllm_config: VllmConfig) -> nn.Module: + def load_model(self, *, vllm_config: VllmConfig, + model_config: ModelConfig) -> nn.Module: """Load a model with the given configurations.""" raise NotImplementedError diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py index 57189bfafc06..0d83c8d53419 100644 --- a/vllm/model_executor/model_loader/bitsandbytes_loader.py +++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py @@ -6,7 +6,8 @@ import itertools import math import os -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple +from collections.abc import Generator +from typing import Any, Callable, Optional import numpy as np import torch @@ -34,6 +35,7 @@ download_safetensors_index_file_from_hf, download_weights_from_hf, filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, pt_weights_iterator, safetensors_weights_iterator) +from vllm.model_executor.models import is_pooling_model from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform @@ -49,21 +51,21 @@ def __init__(self, load_config: LoadConfig): super().__init__(load_config) # Save the module names without sharding. - self.unsharded_weights_modules: List[str] = [] + self.unsharded_weights_modules: list[str] = [] # Save the module names that are sharded by column. - self.column_sharded_weights_modules: List[str] = [] + self.column_sharded_weights_modules: list[str] = [] # Store all module names (from transformers) that support # BNB quantization. - self.target_modules: List[str] = [] + self.target_modules: list[str] = [] # mapping weight names from transformers to vllm. 
self.weight_mapper: Callable = lambda name: name def _get_weight_files( self, model_name_or_path: str, - allowed_patterns: List[str], + allowed_patterns: list[str], revision: Optional[str] = None, - ) -> Tuple[str, List[str], str]: + ) -> tuple[str, list[str], str]: """Retrieve weight files. Download the files if necessary. Return the weight files and the file pattern.""" @@ -95,7 +97,7 @@ def _get_weight_files( f"No model weights found in: `{model_name_or_path}`") def _prepare_weights(self, model_name_or_path: str, - revision: Optional[str]) -> Tuple[List[str], bool]: + revision: Optional[str]) -> tuple[list[str], bool]: """Prepare weight files for the model.""" allowed_patterns = ["*.safetensors", "*.bin", "*.pt"] @@ -132,6 +134,16 @@ def _prepare_weights(self, model_name_or_path: str, return hf_weights_files, use_safetensors def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool): + def _maybe_pool_model(module_name:str): + # For pool model, we need to add the prefix `model.` + # for the weight name if possible. + if self.is_pool_model and self.target_modules[0]. \ + startswith("model.") and not module_name.startswith( + "model."): + return "model."+module_name + + return module_name + if use_safetensors: iterator = safetensors_weights_iterator( hf_weights_files, @@ -147,6 +159,9 @@ def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool): # mapping weight names from transformers to vllm while preserving # original names. 
mapped_name = self.weight_mapper(org_name) + mapped_name=_maybe_pool_model(mapped_name) + + yield org_name, mapped_name, param def _get_quantized_weights_iterator( @@ -155,7 +170,7 @@ def _get_quantized_weights_iterator( revision: Optional[str], pre_quant: bool, load_8bit: bool, - ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str, + ) -> tuple[Generator[tuple[str, torch.Tensor], None, None], dict[str, Any]]: """Get an iterator to the model weights with bitsandbytes quantization, as well as the quantization state dictionary.""" @@ -175,7 +190,7 @@ def _get_quantized_weights_iterator( hf_weights_files, use_safetensors = self._prepare_weights( model_name_or_path, revision) - quant_state_dict: Dict[str, Any] = {} + quant_state_dict: dict[str, Any] = {} if pre_quant: if load_8bit: @@ -257,7 +272,7 @@ def _quantized_4bit_generator(self, hf_weights_files, use_safetensors, # Closure to parse quant_state for each prequant weight def _parse_quant_state(param_name: str, - temp_state_dict: Dict) -> QuantState: + temp_state_dict: dict) -> QuantState: quant_state = {} for k in temp_state_dict: if param_name + "." in k: @@ -404,7 +419,7 @@ def _load_weights(self, model_config: ModelConfig, raise AttributeError( f"Model {type(model).__name__} does not support BitsAndBytes " "quantization yet. 
No 'packed_modules_mapping' found.") - + self.is_pool_model=is_pooling_model(model) self.modules_mapping = ParamMapping( copy.deepcopy(model.packed_modules_mapping)) @@ -415,7 +430,7 @@ def _load_weights(self, model_config: ModelConfig, # Modules whose weights might have fused on disk # we need their output_sizes to make shard in flight correctly with TP - self.maybe_fused_weights_modules: Dict[str, List[int]] = {} + self.maybe_fused_weights_modules: dict[str, list[int]] = {} self._get_bnb_target_modules(model) for name, module in model.named_modules(): # Some modules like `ReplicatedLinear` should not have their weights @@ -480,7 +495,7 @@ def _load_weights(self, model_config: ModelConfig, torch.cuda.empty_cache() param_dict = dict(model.named_parameters()) - stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {} + stacked_quant_state_dict: dict[str, dict[int, Any]] = {} # TODO: Change this lazy import to normal import # after the checks are updated to run on a new version from vllm.model_executor.models.utils import is_pp_missing_parameter @@ -554,10 +569,9 @@ def _load_weights(self, model_config: ModelConfig, def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config.model, model_config.revision) - def load_model(self, vllm_config: VllmConfig) -> nn.Module: + def load_model(self, vllm_config: VllmConfig, + model_config: ModelConfig) -> nn.Module: device_config = vllm_config.device_config - model_config = vllm_config.model_config - with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index c8bc4aecaecf..29a6e0af4bc6 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -3,15 +3,16 @@ import glob import os import time -from typing import Generator, Iterable, List, Optional, Tuple, cast +from 
collections.abc import Generator, Iterable +from typing import Optional, cast import huggingface_hub import torch from torch import nn from transformers.utils import SAFE_WEIGHTS_INDEX_NAME +from vllm import envs from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig -from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.utils import ( @@ -63,7 +64,7 @@ def _maybe_download_from_modelscope( Returns the path to the downloaded model, or None if the model is not downloaded from ModelScope.""" - if VLLM_USE_MODELSCOPE: + if envs.VLLM_USE_MODELSCOPE: # download model from ModelScope hub, # lazy import so that modelscope is not required for normal use. # pylint: disable=C. @@ -92,7 +93,7 @@ def _prepare_weights( revision: Optional[str], fall_back_to_pt: bool, allow_patterns_overrides: Optional[list[str]], - ) -> Tuple[str, List[str], bool]: + ) -> tuple[str, list[str], bool]: """Prepare weights for the model. 
If the model is not local, it will be downloaded.""" @@ -138,7 +139,7 @@ def _prepare_weights( else: hf_folder = model_name_or_path - hf_weights_files: List[str] = [] + hf_weights_files: list[str] = [] for pattern in allow_patterns: hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) if len(hf_weights_files) > 0: @@ -173,7 +174,7 @@ def _prepare_weights( def _get_weights_iterator( self, source: "Source" - ) -> Generator[Tuple[str, torch.Tensor], None, None]: + ) -> Generator[tuple[str, torch.Tensor], None, None]: """Get an iterator for the model weights based on the load format.""" hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( source.model_or_path, source.revision, source.fall_back_to_pt, @@ -238,7 +239,7 @@ def get_all_weights( self, model_config: ModelConfig, model: nn.Module, - ) -> Generator[Tuple[str, torch.Tensor], None, None]: + ) -> Generator[tuple[str, torch.Tensor], None, None]: primary_weights = DefaultModelLoader.Source( model_config.model, model_config.revision, @@ -263,13 +264,14 @@ def download_model(self, model_config: ModelConfig) -> None: fall_back_to_pt=True, allow_patterns_overrides=None) - def load_model(self, vllm_config: VllmConfig) -> nn.Module: + def load_model(self, vllm_config: VllmConfig, + model_config: ModelConfig) -> nn.Module: device_config = vllm_config.device_config - model_config = vllm_config.model_config target_device = torch.device(device_config.device) with set_default_torch_dtype(model_config.dtype): with target_device: - model = initialize_model(vllm_config=vllm_config) + model = initialize_model(vllm_config=vllm_config, + model_config=model_config) weights_to_load = {name for name, _ in model.named_parameters()} loaded_weights = model.load_weights( diff --git a/vllm/model_executor/model_loader/dummy_loader.py b/vllm/model_executor/model_loader/dummy_loader.py index 5047a161f3f9..0e2f0be1ec26 100644 --- a/vllm/model_executor/model_loader/dummy_loader.py +++ 
b/vllm/model_executor/model_loader/dummy_loader.py @@ -22,9 +22,9 @@ def __init__(self, load_config: LoadConfig): def download_model(self, model_config: ModelConfig) -> None: pass # Nothing to download - def load_model(self, vllm_config: VllmConfig) -> nn.Module: + def load_model(self, vllm_config: VllmConfig, + model_config: ModelConfig) -> nn.Module: device_config = vllm_config.device_config - model_config = vllm_config.model_config target_device = torch.device(device_config.device) with set_default_torch_dtype(model_config.dtype): with target_device: diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py index ace1cd371286..806004bf9604 100644 --- a/vllm/model_executor/model_loader/gguf_loader.py +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import os -from typing import Dict, Generator, Tuple +from collections.abc import Generator import gguf import torch @@ -84,17 +84,17 @@ def _get_gguf_weights_map(self, model_config: ModelConfig): return gguf_to_hf_name_map def _get_weights_iterator( - self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str] - ) -> Generator[Tuple[str, torch.Tensor], None, None]: + self, model_name_or_path: str, gguf_to_hf_name_map: dict[str, str] + ) -> Generator[tuple[str, torch.Tensor], None, None]: return gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map) def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config.model) - def load_model(self, vllm_config: VllmConfig) -> nn.Module: + def load_model(self, vllm_config: VllmConfig, + model_config: ModelConfig) -> nn.Module: device_config = vllm_config.device_config - model_config = vllm_config.model_config local_model_path = self._prepare_weights(model_config.model) gguf_weights_map = self._get_gguf_weights_map(model_config) # we can only know if tie word embeddings after mapping weights diff --git 
a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py index e4a48483764a..e65d16cae76c 100644 --- a/vllm/model_executor/model_loader/neuron.py +++ b/vllm/model_executor/model_loader/neuron.py @@ -5,7 +5,7 @@ import copy import importlib import os -from typing import Dict, List, Optional, Tuple +from typing import Optional import torch import torch.nn as nn @@ -33,7 +33,7 @@ } # Models supported by Neuron. -_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = { +_NEURON_SUPPORTED_MODELS: dict[str, tuple[str, str, str]] = { "LlamaForCausalLM": ("transformers_neuronx.llama.model", "LlamaForSampling", "LlamaForCausalLM"), "MistralForCausalLM": ("transformers_neuronx.mistral.model", @@ -146,7 +146,7 @@ def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> Optional[List[SamplerOutput]]: + ) -> Optional[list[SamplerOutput]]: batch_size, num_steps = logits.shape seq_ids = [ seq_id for sg in sampling_metadata.seq_groups @@ -188,7 +188,7 @@ def _get_model_architecture(config: PretrainedConfig) -> str: f"{list(_NEURON_SUPPORTED_MODELS.keys())}") -def _get_buckets(env: str, default_value: List[int]) -> List[int]: +def _get_buckets(env: str, default_value: list[int]) -> list[int]: env_value = os.getenv(env) if env_value is None: return default_value @@ -464,7 +464,7 @@ def get_neuron_eagle_speculation_model(model_config: ModelConfig, draft_model.eval() - token_tree: Dict[int, List[int]] = ast.literal_eval( + token_tree: dict[int, list[int]] = ast.literal_eval( speculation_config.speculative_token_tree) speculation_model = EagleSpeculativeDecoder(draft_model.model, diff --git a/vllm/model_executor/model_loader/neuronx_distributed.py b/vllm/model_executor/model_loader/neuronx_distributed.py index f879c99ac2ef..557feea46a90 100644 --- a/vllm/model_executor/model_loader/neuronx_distributed.py +++ b/vllm/model_executor/model_loader/neuronx_distributed.py @@ -9,7 +9,7 @@ import multiprocessing import os import 
shutil -from typing import Dict, List, Optional, Tuple +from typing import Optional import torch import torch.nn as nn @@ -46,8 +46,11 @@ } # Models supported by Neuronx distributed for inference. -_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str]] = { +_NEURON_SUPPORTED_MODELS: dict[str, tuple[str, str]] = { "LlamaForCausalLM": + ("neuronx_distributed_inference.models.llama.modeling_llama", + "NeuronLlamaForCausalLM"), + "MistralForCausalLM": ("neuronx_distributed_inference.models.llama.modeling_llama", "NeuronLlamaForCausalLM"), "DbrxForCausalLM": @@ -84,16 +87,29 @@ def forward( input_block_ids: torch.Tensor, sampling_params: torch.Tensor, ) -> torch.Tensor: + # sort block ids sequentially for perf/neuron support reasons + sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids) + input_ids = torch.index_select(input_ids, 0, sorted_indices) + positions = torch.index_select(positions, 0, sorted_indices) + sampling_params = torch.index_select(sampling_params, 0, + sorted_indices) + output = self.model(input_ids, attention_mask=None, position_ids=positions, - seq_ids=input_block_ids, + seq_ids=sorted_input_block_ids, sampling_params=sampling_params) # on-device sampling if self.config.neuron_config.on_device_sampling_config: - return output.hidden_states + output = output.hidden_states else: - return output.logits[:, -1, :] + output = output.logits[:, -1, :] + + restored_indices = torch.argsort(sorted_indices) + if input_block_ids.shape[0] != 1: + output = torch.index_select(output, 0, restored_indices) + + return output def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: @@ -143,8 +159,8 @@ def load_weights(self, model_name_or_path: str, **kwargs): config = neuronx_model_cls.get_config_cls()( neuron_config, load_config=load_pretrained_config(model_name_or_path)) - hashed_config = hashlib.md5( - config.to_json_string().encode('utf-8')).hexdigest() + hashed_config = 
hashlib.md5(config.to_json_string().encode('utf-8'), + usedforsecurity=False).hexdigest() if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") elif os.path.exists(model_name_or_path): @@ -263,8 +279,8 @@ def load_weights(self, model_name_or_path: str, **kwargs): config = neuronx_model_cls.get_config_cls()( neuron_config, load_config=load_pretrained_config(model_name_or_path)) - hashed_config = hashlib.md5( - config.to_json_string().encode('utf-8')).hexdigest() + hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'), + usedforsecurity=False).hexdigest() if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") elif os.path.exists(model_name_or_path): @@ -337,14 +353,26 @@ def forward( input_block_ids: torch.Tensor, sampling_params: torch.Tensor, ) -> torch.Tensor: + # sort block ids sequentially for perf/neuron support reasons + sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids) + input_ids = torch.index_select(input_ids, 0, sorted_indices) + positions = torch.index_select(positions, 0, sorted_indices) + sampling_params = torch.index_select(sampling_params, 0, + sorted_indices) + output = self.model(input_ids, attention_mask=None, position_ids=positions, - seq_ids=input_block_ids, + seq_ids=sorted_input_block_ids, sampling_params=sampling_params) + restored_indices = torch.argsort(sorted_indices) + # CTX encoding if (positions[:, 0]).sum().item() == 0: - return output.fused_outputs[0][:, 0:1] + output = output.fused_outputs[0][:, 0:1] + if input_block_ids.shape[0] != 1: + output = torch.index_select(output, 0, restored_indices) + return output # Fused Spec (Generation) accepted_tokens_with_padding = output.fused_outputs[0] @@ -359,13 +387,17 @@ def forward( -1) >= generated_token_counts accepted_tokens_with_padding[mask] = -1 + if input_block_ids.shape[0] != 1: + accepted_tokens_with_padding = torch.index_select( + 
accepted_tokens_with_padding, 0, restored_indices) + return accepted_tokens_with_padding def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> Optional[List[SamplerOutput]]: + ) -> Optional[list[SamplerOutput]]: batch_size, num_steps = logits.shape seq_ids = [ seq_id for sg in sampling_metadata.seq_groups @@ -413,6 +445,10 @@ def load_weights(self, model_name_or_path: str, draft_neuron_config.speculation_length = 0 draft_neuron_config.trace_tokengen_model = True draft_neuron_config.enable_fused_speculation = False + if getattr(config.neuron_config, "draft_model_modules_to_not_convert", + None): + draft_neuron_config.modules_to_not_convert = ( + draft_neuron_config.draft_model_modules_to_not_convert) if config.neuron_config.enable_eagle_speculation: draft_neuron_config.is_eagle_draft = True draft_neuron_config.sequence_parallel_enabled = False @@ -426,8 +462,8 @@ def load_weights(self, model_name_or_path: str, config.fused_spec_config = fused_spec_config self.config.neuron_config = neuron_config - hashed_config = hashlib.md5( - config.to_json_string().encode('utf-8')).hexdigest() + hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'), + usedforsecurity=False).hexdigest() if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") elif os.path.exists(model_name_or_path): @@ -499,7 +535,7 @@ def _get_default_neuron_config(model_config: ModelConfig, max_context_length=scheduler_config.max_model_len, seq_len=scheduler_config.max_model_len, enable_bucketing=True, - is_continuous_batching=(batch_size > 1), + is_continuous_batching=True, quantized=False, torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], padding_side="right", @@ -517,6 +553,7 @@ def _get_default_speculation_config(model_config: ModelConfig, args.""" neuron_config = dict( tp_degree=parallel_config.tensor_parallel_size, + ctx_batch_size=1, batch_size=scheduler_config.max_num_seqs, 
max_context_length=scheduler_config.max_model_len, seq_len=scheduler_config.max_model_len, @@ -524,6 +561,7 @@ def _get_default_speculation_config(model_config: ModelConfig, trace_tokengen_model=False, enable_fused_speculation=True, enable_bucketing=True, + is_continuous_batching=True, quantized=False, torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], on_device_sampling_config=dict( diff --git a/vllm/model_executor/model_loader/runai_streamer_loader.py b/vllm/model_executor/model_loader/runai_streamer_loader.py index 1fbb5ca56644..9f1022c25925 100644 --- a/vllm/model_executor/model_loader/runai_streamer_loader.py +++ b/vllm/model_executor/model_loader/runai_streamer_loader.py @@ -2,7 +2,8 @@ # ruff: noqa: SIM117 import glob import os -from typing import Generator, List, Optional, Tuple +from collections.abc import Generator +from typing import Optional import torch from torch import nn @@ -48,7 +49,7 @@ def __init__(self, load_config: LoadConfig): os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url def _prepare_weights(self, model_name_or_path: str, - revision: Optional[str]) -> List[str]: + revision: Optional[str]) -> list[str]: """Prepare weights for the model. 
If the model is not local, it will be downloaded.""" @@ -87,7 +88,7 @@ def _prepare_weights(self, model_name_or_path: str, def _get_weights_iterator( self, model_or_path: str, - revision: str) -> Generator[Tuple[str, torch.Tensor], None, None]: + revision: str) -> Generator[tuple[str, torch.Tensor], None, None]: """Get an iterator for the model weights based on the load format.""" hf_weights_files = self._prepare_weights(model_or_path, revision) return runai_safetensors_weights_iterator( @@ -99,11 +100,10 @@ def download_model(self, model_config: ModelConfig) -> None: """Download model if necessary""" self._prepare_weights(model_config.model, model_config.revision) - def load_model(self, vllm_config: VllmConfig) -> nn.Module: + def load_model(self, vllm_config: VllmConfig, + model_config: ModelConfig) -> nn.Module: """Perform streaming of the model to destination""" device_config = vllm_config.device_config - model_config = vllm_config.model_config - target_device = torch.device(device_config.device) with set_default_torch_dtype(model_config.dtype): with target_device: diff --git a/vllm/model_executor/model_loader/sharded_state_loader.py b/vllm/model_executor/model_loader/sharded_state_loader.py index 152a3d699726..78bca89f0015 100644 --- a/vllm/model_executor/model_loader/sharded_state_loader.py +++ b/vllm/model_executor/model_loader/sharded_state_loader.py @@ -3,7 +3,8 @@ import collections import glob import os -from typing import Any, Dict, Generator, List, Optional, Tuple +from collections.abc import Generator +from typing import Any, Optional import torch from torch import nn @@ -48,12 +49,12 @@ def __init__(self, @staticmethod def _filter_subtensors( - tensors: Dict[str, torch.Tensor], ) -> Dict[str, torch.Tensor]: + tensors: dict[str, torch.Tensor], ) -> dict[str, torch.Tensor]: """ Filter out all tensors that share the same memory or a subset of the memory of another tensor. 
""" - same_storage_groups: Dict[Any, List[Tuple[str, torch.Tensor]]] = ( + same_storage_groups: dict[Any, list[tuple[str, torch.Tensor]]] = ( collections.defaultdict(list)) for key, tensor in tensors.items(): if tensor.numel(): @@ -63,7 +64,7 @@ def _filter_subtensors( def get_end_ptr(tensor: torch.Tensor) -> int: return tensor.view(-1)[-1].data_ptr() + tensor.element_size() - result: Dict[str, torch.Tensor] = {} + result: dict[str, torch.Tensor] = {} for group in same_storage_groups.values(): for k, t in group: a, b = t.data_ptr(), get_end_ptr(t) @@ -99,9 +100,9 @@ def _prepare_weights(self, model_name_or_path: str, def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config.model, model_config.revision) - def load_model(self, vllm_config: VllmConfig) -> nn.Module: + def load_model(self, vllm_config: VllmConfig, + model_config: ModelConfig) -> nn.Module: device_config = vllm_config.device_config - model_config = vllm_config.model_config target_device = torch.device(device_config.device) from vllm.distributed import get_tensor_model_parallel_rank @@ -160,7 +161,7 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module: return model.eval() def iterate_over_files( - self, paths) -> Generator[Tuple[str, torch.Tensor], None, None]: + self, paths) -> Generator[tuple[str, torch.Tensor], None, None]: if self.runai_model_streamer: yield from runai_safetensors_weights_iterator(paths, True) else: @@ -188,7 +189,7 @@ def save_model( part_idx = 0 total_size = 0 state_dict = ShardedStateLoader._filter_subtensors(model.state_dict()) - state_dict_part: Dict[str, torch.Tensor] = {} + state_dict_part: dict[str, torch.Tensor] = {} for key, tensor in state_dict.items(): param_size = tensor.nelement() * tensor.element_size() if max_size is not None and total_size + param_size > max_size: diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 117251ccf05f..4c4502284a6a 100644 --- 
a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -1,23 +1,28 @@ # SPDX-License-Identifier: Apache-2.0 import argparse +import contextlib +import contextvars import dataclasses import io +import json import os -import re +import threading import time +from collections.abc import Generator from dataclasses import dataclass from functools import partial -from typing import BinaryIO, Generator, Optional, Tuple, Type, Union +from typing import Any, BinaryIO, Optional, Union +import regex as re import torch from torch import nn +from torch.utils._python_dispatch import TorchDispatchMode from transformers import PretrainedConfig import vllm.envs as envs from vllm.config import ModelConfig, ParallelConfig, set_current_vllm_config from vllm.engine.arg_utils import EngineArgs -from vllm.engine.llm_engine import LLMEngine from vllm.logger import init_logger from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -57,9 +62,79 @@ logger = init_logger(__name__) +class MetaTensorMode(TorchDispatchMode): + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + + if func._schema.name == "aten::empty" and "device" not in kwargs: + kwargs["device"] = "meta" + + return func(*args, **kwargs) + + +def meta_tensor_mode(loading_code=None, ): + + if loading_code is None: + return _NoInitOrTensorImpl.context_manager() + elif callable(loading_code): + with _NoInitOrTensorImpl.context_manager(): + return loading_code() + else: + raise TypeError( + "expected a callable to evaluate," + " or None if being used as a context manager;" + f' got an object of type "{type(loading_code).__name__}" instead.') + + +class _NoInitOrTensorImpl: + _MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm) + _MODULE_ORIGINALS = tuple((m, m.reset_parameters) for m in _MODULES) + + is_active = contextvars.ContextVar("_NoInitOrTensorImpl.is_active", + default=False) + 
_count_active: int = 0 + _count_active_lock = threading.Lock() + + @classmethod + @contextlib.contextmanager + def context_manager(cls): + if cls.is_active.get(): + yield + return + + with cls._count_active_lock: + cls._count_active += 1 + if cls._count_active == 1: + for mod in cls._MODULES: + mod.reset_parameters = cls._disable(mod.reset_parameters) + + reset_token = cls.is_active.set(True) + + try: + with MetaTensorMode(): + yield + finally: + cls.is_active.reset(reset_token) + with cls._count_active_lock: + cls._count_active -= 1 + if cls._count_active == 0: + for mod, original in cls._MODULE_ORIGINALS: + mod.reset_parameters = original + + @staticmethod + def _disable(func): + + def wrapper(*args, **kwargs): + if not _NoInitOrTensorImpl.is_active.get(): + return func(*args, **kwargs) + + return wrapper + + @dataclass class TensorizerConfig: - tensorizer_uri: str + tensorizer_uri: Union[str, None] = None vllm_tensorized: Optional[bool] = False verify_hash: Optional[bool] = False num_readers: Optional[int] = None @@ -67,15 +142,32 @@ class TensorizerConfig: s3_access_key_id: Optional[str] = None s3_secret_access_key: Optional[str] = None s3_endpoint: Optional[str] = None - model_class: Optional[Type[torch.nn.Module]] = None + model_class: Optional[type[torch.nn.Module]] = None hf_config: Optional[PretrainedConfig] = None dtype: Optional[Union[str, torch.dtype]] = None + lora_dir: Optional[str] = None _is_sharded: bool = False def __post_init__(self): # check if the configuration is for a sharded vLLM model self._is_sharded = isinstance(self.tensorizer_uri, str) \ and re.search(r'%0\dd', self.tensorizer_uri) is not None + if not self.tensorizer_uri and not self.lora_dir: + raise ValueError("tensorizer_uri must be provided.") + if not self.tensorizer_uri and self.lora_dir: + self.tensorizer_uri = f"{self.lora_dir}/adapter_model.tensors" + assert self.tensorizer_uri is not None, ("tensorizer_uri must be " + "provided.") + self.tensorizer_dir = 
os.path.dirname(self.tensorizer_uri) + self.lora_dir = self.tensorizer_dir + + @classmethod + def as_dict(cls, *args, **kwargs) -> dict[str, Any]: + cfg = TensorizerConfig(*args, **kwargs) + return dataclasses.asdict(cfg) + + def to_dict(self) -> dict[str, Any]: + return dataclasses.asdict(self) def _construct_tensorizer_args(self) -> "TensorizerArgs": tensorizer_args = { @@ -139,7 +231,9 @@ class TensorizerArgs: Args: tensorizer_uri: Path to serialized model tensors. Can be a local file - path or a S3 URI. + path or a S3 URI. This is a required field unless lora_dir is + provided and the config is meant to be used for the + `tensorize_lora_adapter` function. vllm_tensorized: If True, indicates that the serialized model is a vLLM model. This is used to determine the behavior of the TensorDeserializer when loading tensors from a serialized model. @@ -157,7 +251,7 @@ class TensorizerArgs: encryption_keyfile: File path to a binary file containing a binary key to use for decryption. `None` (the default) means no decryption. See the example script in - examples/other/tensorize_vllm_model.py. + examples/others/tensorize_vllm_model.py. s3_access_key_id: The access key for the S3 bucket. Can also be set via the S3_ACCESS_KEY_ID environment variable. s3_secret_access_key: The secret access key for the S3 bucket. Can also @@ -213,6 +307,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: group.add_argument( "--tensorizer-uri", + type=str, help="Path to serialized model tensors. Can be a local file path," " or an HTTP(S) or S3 URI.", ) @@ -225,6 +320,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: ) group.add_argument( "--encryption-keyfile", + type=str, default=None, help="The file path to a binary file containing a binary key to " "use for decryption. Can be a file path or S3 network URI.") @@ -238,18 +334,21 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "and model size. 
This greatly increases performance.") group.add_argument( "--s3-access-key-id", + type=str, default=None, help="The access key for the S3 bucket. Can also be set via the " "S3_ACCESS_KEY_ID environment variable.", ) group.add_argument( "--s3-secret-access-key", + type=str, default=None, help="The secret access key for the S3 bucket. Can also be set via " "the S3_SECRET_ACCESS_KEY environment variable.", ) group.add_argument( "--s3-endpoint", + type=str, default=None, help="The endpoint for the S3 bucket. Can also be set via the " "S3_ENDPOINT_URL environment variable.", @@ -290,10 +389,10 @@ def _init_model(self): model_args.torch_dtype = self.tensorizer_config.dtype assert self.tensorizer_config.model_class is not None # TODO: Do we need to consider old-style model class? - with no_init_or_tensor(), set_current_vllm_config(self.vllm_config, - check_compile=True): + with meta_tensor_mode(), set_current_vllm_config(self.vllm_config, + check_compile=True): return self.tensorizer_config.model_class( - vllm_config=self.vllm_config, ) + vllm_config=self.vllm_config) def _resize_lora_embeddings(self): """Modify LoRA embedding layers to use bigger tensors @@ -365,12 +464,12 @@ def deserialize(self): def tensorizer_weights_iterator( tensorizer_args: "TensorizerArgs" -) -> Generator[Tuple[str, torch.Tensor], None, None]: +) -> Generator[tuple[str, torch.Tensor], None, None]: logger.warning("Deserializing HuggingFace models is not optimized for " "loading on vLLM, as tensorizer is forced to load to CPU. " "Consider deserializing a vLLM model instead for faster " "load times. 
See the " - "examples/other/tensorize_vllm_model.py example script " + "examples/others/tensorize_vllm_model.py example script " "for serializing vLLM models.") deserializer_args = tensorizer_args.deserializer_params @@ -461,8 +560,73 @@ def tensorize_vllm_model(engine_args: EngineArgs, ) as stream: stream.write(encryption_params.key) - engine = LLMEngine.from_engine_args(engine_args) - engine.model_executor.collective_rpc( - "save_tensorized_model", - kwargs=dict(tensorizer_config=tensorizer_config), - ) + from vllm import LLMEngine + from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine + + if not envs.VLLM_USE_V1: + engine = LLMEngine.from_engine_args(engine_args) + engine.model_executor.collective_rpc( + "save_tensorized_model", + kwargs=dict(tensorizer_config=tensorizer_config), + ) + else: + engine = V1LLMEngine.from_vllm_config(engine_config) + engine.collective_rpc( + "save_tensorized_model", + kwargs=dict(tensorizer_config=tensorizer_config), + ) + + +def tensorize_lora_adapter(lora_path: str, + tensorizer_config: TensorizerConfig): + """ + Uses tensorizer to serialize a LoRA adapter. Assumes that the files + needed to load a LoRA adapter are a safetensors-format file called + adapter_model.safetensors and a json config file called adapter_config.json. 
+ + Serializes the files in the tensorizer_config.lora_dir + """ + import safetensors + + from vllm.lora.utils import get_adapter_absolute_path + + lora_dir = get_adapter_absolute_path(lora_path) + + tensor_path = config_path = "" + + for file in os.listdir(lora_dir): + if file.startswith("adapter_model"): + tensor_path = lora_dir + "/" + file + if file.startswith("adapter_config"): + config_path = lora_dir + "/" + file + if tensor_path and config_path: + break + + if tensor_path.endswith(".safetensors"): + tensors = safetensors.torch.load_file(tensor_path) + elif tensor_path.endswith(".bin"): + tensors = torch.load(tensor_path) + else: + raise ValueError("Unsupported file: %s", tensor_path) + + with open(config_path) as f: + config = json.load(f) + + tensorizer_args = tensorizer_config._construct_tensorizer_args() + + with open_stream(f"{tensorizer_config.lora_dir}/adapter_config.json", + mode="wb+", + **tensorizer_args.stream_params) as f: + + f.write(json.dumps(config).encode("utf-8")) + + lora_uri = (f"{tensorizer_config.lora_dir}" + f"/adapter_model.tensors") + with open_stream(lora_uri, mode="wb+", + **tensorizer_args.stream_params) as f: + serializer = TensorSerializer(f) + serializer.write_state_dict(tensors) + serializer.close() + + logger.info("Successfully serialized LoRA files to %s", + str(tensorizer_config.lora_dir)) diff --git a/vllm/model_executor/model_loader/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer_loader.py index 7cf3940ab644..2afe2b59e2f9 100644 --- a/vllm/model_executor/model_loader/tensorizer_loader.py +++ b/vllm/model_executor/model_loader/tensorizer_loader.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # ruff: noqa: SIM117 import copy -from typing import Generator, Tuple +from collections.abc import Generator +from typing import Union import torch from torch import nn @@ -36,7 +37,7 @@ def _verify_config(self, model_config: ModelConfig, self.tensorizer_config.verify_with_parallel_config(parallel_config) 
def _get_weights_iterator( - self, ) -> Generator[Tuple[str, torch.Tensor], None, None]: + self, ) -> Generator[tuple[str, torch.Tensor], None, None]: tensorizer_args = self.tensorizer_config._construct_tensorizer_args() return tensorizer_weights_iterator(tensorizer_args) @@ -47,7 +48,7 @@ def _load_model_serialized_cpu( """Load a serialized model with tensorizer to the CPU. This is only necessary when the model isn't vLLM-tensorized (see - examples/other/tensorize_vllm_model.py) This should still + examples/others/tensorize_vllm_model.py) This should still be faster than default HuggingFace loading, but will be slower than loading a vLLM-tensorized model. """ @@ -67,7 +68,7 @@ def _load_model_serialized( """Load a serialized model with tensorizer. Expects a vLLM-tensorized model. See the - examples/other/tensorize_vllm_model.py example script + examples/others/tensorize_vllm_model.py example script for serializing vLLM models.""" device_config = vllm_config.device_config @@ -92,8 +93,8 @@ def download_model(self, model_config: ModelConfig) -> None: with self.tensorizer_config.open_stream(): pass - def load_model(self, vllm_config: VllmConfig) -> nn.Module: - model_config = vllm_config.model_config + def load_model(self, vllm_config: VllmConfig, + model_config: ModelConfig) -> nn.Module: parallel_config = vllm_config.parallel_config self._verify_config(model_config, parallel_config) @@ -111,8 +112,10 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module: @staticmethod def save_model( model: torch.nn.Module, - tensorizer_config: TensorizerConfig, + tensorizer_config: Union[TensorizerConfig, dict], ) -> None: + if isinstance(tensorizer_config, dict): + tensorizer_config = TensorizerConfig(**tensorizer_config) serialize_vllm_model( model=model, tensorizer_config=tensorizer_config, diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index ddc857aebdc8..9c8d647a24fe 100644 --- 
a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -5,7 +5,7 @@ import warnings from contextlib import contextmanager from dataclasses import dataclass, field -from typing import Dict, List, Optional, Tuple, Type +from typing import Optional import torch import transformers @@ -42,9 +42,11 @@ def initialize_model( *, prefix: str = "", model_class: Optional[type[nn.Module]] = None, + model_config: Optional[ModelConfig] = None, ) -> nn.Module: """Initialize a model with the given configurations.""" - model_config = vllm_config.model_config + if model_config is None: + model_config = vllm_config.model_config if model_class is None: model_class, _ = get_model_architecture(model_config) @@ -124,7 +126,7 @@ def device_loading_context(module: torch.nn.Module, yield module return - original_device_states: Dict[str, torch.device] = {} + original_device_states: dict[str, torch.device] = {} # Store original device states and move parameters to GPU if they're on CPU for name, p in module.named_parameters(): @@ -214,7 +216,7 @@ def resolve_transformers_arch(model_config: ModelConfig, def get_model_architecture( - model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: + model_config: ModelConfig) -> tuple[type[nn.Module], str]: architectures = getattr(model_config.hf_config, "architectures", []) # Special handling for quantized Mixtral. 
@@ -223,17 +225,16 @@ def get_model_architecture( "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin", "quark" ] - if (model_config.quantization is not None - and model_config.quantization not in mixtral_supported - and "MixtralForCausalLM" in architectures): - architectures = ["QuantMixtralForCausalLM"] - vllm_supported_archs = ModelRegistry.get_supported_archs() vllm_not_supported = not any(arch in vllm_supported_archs for arch in architectures) if (model_config.model_impl == ModelImpl.TRANSFORMERS or model_config.model_impl != ModelImpl.VLLM and vllm_not_supported): architectures = resolve_transformers_arch(model_config, architectures) + elif (model_config.quantization is not None + and model_config.quantization not in mixtral_supported + and "MixtralForCausalLM" in architectures): + architectures = ["QuantMixtralForCausalLM"] model_cls, arch = ModelRegistry.resolve_model_cls(architectures) if model_config.task == "embed": @@ -257,8 +258,8 @@ class ParamMapping: It creates a bidirectional mapping between packed parameters and their constituent parts. 
""" - packed_mapping: Dict[str, List[str]] - inverse_packed_mapping: Dict[str, Tuple[str, + packed_mapping: dict[str, list[str]] + inverse_packed_mapping: dict[str, tuple[str, int]] = field(default_factory=dict) def __post_init__(self): @@ -273,7 +274,7 @@ def __post_init__(self): ) def get_sub_modules(self, - module_name: str) -> Optional[Tuple[str, List[str]]]: + module_name: str) -> Optional[tuple[str, list[str]]]: for key, value in self.packed_mapping.items(): if module_name.endswith(key): return key, value @@ -281,7 +282,7 @@ def get_sub_modules(self, def configure_quant_config(quant_config: QuantizationConfig, - model_class: Type[nn.Module]): + model_class: type[nn.Module]): """ Pass packed_modules_mapping by reference to quant_config so that quant_config can properly match fused modules diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index beff33414ad7..f61956f4e8e0 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -8,8 +8,9 @@ import tempfile import time from collections import defaultdict +from collections.abc import Generator from pathlib import Path -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union import filelock import gguf @@ -217,12 +218,45 @@ def get_quant_config(model_config: ModelConfig, return quant_cls.from_config(config) +def get_sparse_attention_config( + model_config: ModelConfig, + load_config: LoadConfig, + sparse_attention_config_filename: str = "sparse_attention_config.json", +) -> dict[str, Any]: + model_name_or_path = model_config.model + is_local = os.path.isdir(model_name_or_path) + if not is_local: + # Download the config files. 
+ with get_lock(model_name_or_path, load_config.download_dir): + hf_folder = snapshot_download( + model_name_or_path, + revision=model_config.revision, + allow_patterns="*.json", + cache_dir=load_config.download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + tqdm_class=DisabledTqdm, + ) + else: + hf_folder = model_name_or_path + + config_file = os.path.join(hf_folder, sparse_attention_config_filename) + if not os.path.exists(config_file): + return {} + + # Load the sparse attention config. + with open(config_file) as f: + config = json.load(f) + logger.info("Loaded sparse attention config from %s", config_file) + + return config + + def download_weights_from_hf( model_name_or_path: str, cache_dir: Optional[str], - allow_patterns: List[str], + allow_patterns: list[str], revision: Optional[str] = None, - ignore_patterns: Optional[Union[str, List[str]]] = None, + ignore_patterns: Optional[Union[str, list[str]]] = None, ) -> str: """Download model weights from Hugging Face Hub. @@ -230,11 +264,11 @@ def download_weights_from_hf( model_name_or_path (str): The model name or path. cache_dir (Optional[str]): The cache directory to store the model weights. If None, will use HF defaults. - allow_patterns (List[str]): The allowed patterns for the + allow_patterns (list[str]): The allowed patterns for the weight files. Files matched by any of the patterns will be downloaded. revision (Optional[str]): The revision of the model. - ignore_patterns (Optional[Union[str, List[str]]]): The patterns to + ignore_patterns (Optional[Union[str, list[str]]]): The patterns to filter out the weight files. Files matched by any of the patterns will be ignored. @@ -285,6 +319,7 @@ def download_safetensors_index_file_from_hf( Args: model_name_or_path (str): The model name or path. + index_file (str): The safetensors index file name cache_dir (Optional[str]): The cache directory to store the model weights. If None, will use HF defaults. 
revision (Optional[str]): The revision of the model. @@ -303,10 +338,10 @@ def download_safetensors_index_file_from_hf( ) # If file not found on remote or locally, we should not fail since # only some models will have index_file. - except huggingface_hub.utils.EntryNotFoundError: - logger.info("No %s found in remote.", index_file) except huggingface_hub.utils.LocalEntryNotFoundError: logger.info("No %s found in local cache.", index_file) + except huggingface_hub.utils.EntryNotFoundError: + logger.info("No %s found in remote.", index_file) # For models like Mistral-7B-v0.3, there are both sharded @@ -314,9 +349,9 @@ def download_safetensors_index_file_from_hf( # Passing both of these to the weight loader functionality breaks. # So, we use the index_file to # look up which safetensors files should be used. -def filter_duplicate_safetensors_files(hf_weights_files: List[str], +def filter_duplicate_safetensors_files(hf_weights_files: list[str], hf_folder: str, - index_file: str) -> List[str]: + index_file: str) -> list[str]: # model.safetensors.index.json is a mapping from keys in the # torch state_dict to safetensors file holding that weight. index_file_name = os.path.join(hf_folder, index_file) @@ -339,7 +374,7 @@ def filter_duplicate_safetensors_files(hf_weights_files: List[str], def filter_files_not_needed_for_inference( - hf_weights_files: List[str]) -> List[str]: + hf_weights_files: list[str]) -> list[str]: """ Exclude files that are not needed for inference. @@ -375,9 +410,9 @@ def np_cache_weights_iterator( model_name_or_path: str, cache_dir: Optional[str], hf_folder: str, - hf_weights_files: List[str], + hf_weights_files: list[str], use_tqdm_on_load: bool, -) -> Generator[Tuple[str, torch.Tensor], None, None]: +) -> Generator[tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model np files. Will dump the model weights to numpy files if they are not already dumped. 
@@ -391,7 +426,7 @@ def np_cache_weights_iterator( # dumping the same model weights to numpy at the same time. with get_lock(model_name_or_path, cache_dir): if not os.path.exists(weight_names_file): - weight_names: List[str] = [] + weight_names: list[str] = [] for bin_file in tqdm( hf_weights_files, desc="Loading np_cache checkpoint shards", @@ -420,9 +455,9 @@ def np_cache_weights_iterator( def safetensors_weights_iterator( - hf_weights_files: List[str], + hf_weights_files: list[str], use_tqdm_on_load: bool, -) -> Generator[Tuple[str, torch.Tensor], None, None]: +) -> Generator[tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model safetensor files.""" for st_file in tqdm( hf_weights_files, @@ -437,9 +472,9 @@ def safetensors_weights_iterator( def runai_safetensors_weights_iterator( - hf_weights_files: List[str], + hf_weights_files: list[str], use_tqdm_on_load: bool, -) -> Generator[Tuple[str, torch.Tensor], None, None]: +) -> Generator[tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model safetensor files.""" with SafetensorsStreamer() as streamer: for st_file in tqdm( @@ -453,9 +488,9 @@ def runai_safetensors_weights_iterator( def fastsafetensors_weights_iterator( - hf_weights_files: List[str], + hf_weights_files: list[str], use_tqdm_on_load: bool, -) -> Generator[Tuple[str, torch.Tensor], None, None]: +) -> Generator[tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model safetensor files using fastsafetensor library.""" if torch.distributed.is_initialized(): @@ -492,10 +527,10 @@ def fastsafetensors_weights_iterator( def pt_weights_iterator( - hf_weights_files: List[str], + hf_weights_files: list[str], use_tqdm_on_load: bool, pt_load_map_location: Union[str, dict[str, str]] = "cpu", -) -> Generator[Tuple[str, torch.Tensor], None, None]: +) -> Generator[tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model bin/pt files.""" for bin_file in tqdm( hf_weights_files, 
@@ -511,7 +546,7 @@ def pt_weights_iterator( def get_gguf_extra_tensor_names( - gguf_file: str, gguf_to_hf_name_map: Dict[str, str]) -> List[str]: + gguf_file: str, gguf_to_hf_name_map: dict[str, str]) -> list[str]: reader = gguf.GGUFReader(gguf_file) expected_gguf_keys = set(gguf_to_hf_name_map.keys()) exact_gguf_keys = set([tensor.name for tensor in reader.tensors]) @@ -520,8 +555,8 @@ def get_gguf_extra_tensor_names( def gguf_quant_weights_iterator( - gguf_file: str, gguf_to_hf_name_map: Dict[str, str] -) -> Generator[Tuple[str, torch.Tensor], None, None]: + gguf_file: str, gguf_to_hf_name_map: dict[str, str] +) -> Generator[tuple[str, torch.Tensor], None, None]: """ Iterate over the quant weights in the model gguf files and convert them to torch tensors @@ -600,7 +635,7 @@ def row_parallel_weight_loader(param: torch.Tensor, return default_weight_loader(param, loaded_weight) -LoaderFunction = Callable[[torch.Tensor, torch.Tensor], torch.Tensor] +LoaderFunction = Callable[[torch.Tensor, torch.Tensor], None] def sharded_weight_loader(shard_axis: int) -> LoaderFunction: diff --git a/vllm/model_executor/models/aimv2.py b/vllm/model_executor/models/aimv2.py index 730e770dc3d6..aefd6c973755 100644 --- a/vllm/model_executor/models/aimv2.py +++ b/vllm/model_executor/models/aimv2.py @@ -5,129 +5,14 @@ from typing import Optional import torch -from torch import nn, softmax +import torch.nn as nn from torch.nn import functional as F -from torch.nn.functional import gumbel_softmax, pad from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.transformers_utils.configs.ovis2 import (AIMv2Config, - Aimv2VisualTokenizerConfig) - -IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, - -305] # kept for vocab prefixed tokens - - -def st_argmax(y_soft: torch.Tensor, dim: int): # straight-through softmax - index = y_soft.max(dim, 
keepdim=True)[1] - y_hard = torch.zeros_like( - y_soft, memory_format=torch.legacy_contiguous_format).scatter_( - dim, index, 1.0) - ret = y_hard - y_soft.detach() + y_soft - return ret - - -class Aimv2VisualTokenizer(torch.nn.Module): - - def __init__(self, - config: Aimv2VisualTokenizerConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - **kwargs): - super().__init__() - self.config = config - self.backbone = AIMv2Model( - config=config.backbone_config, # noqa - quant_config=quant_config, - prefix=f"{prefix}.visual_tokenizer") - # reserved tokens for IMAGE_INDICATORS - head_dim = config.vocab_size - len(IMAGE_INDICATOR_IDS) - self.head = torch.nn.Sequential( - ReplicatedLinear( - config.backbone_config.hidden_size * config.hidden_stride * - config.hidden_stride, - head_dim, - bias=False, - ), torch.nn.LayerNorm(head_dim)) - - @property - def dtype(self): - return self.backbone.dtype - - @property - def device(self): - return self.backbone.device - - def tokenize(self, logits): - if self.config.tokenize_function == 'softmax': - tokens = softmax(logits, dim=-1) - elif self.config.tokenize_function == 'gumbel_argmax': - tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True) - elif self.config.tokenize_function == 'st_argmax': - tokens = st_argmax(logits, dim=-1) - else: - raise ValueError( - 'Invalid `max_type`, expected softmax or gumbel_argmax ' - f'or st_argmax, but got {self.config.tokenize_function}') - return tokens - - def encode(self, pixel_values): - features = self.backbone(pixel_values) - if self.config.drop_cls_token: - features = features[:, 1:, :] - - # merge number of `hidden_stride * hidden_stride` hidden states together - # to reduce token sequence length - # e.g., for hidden_stride=2, this leads to a token length reduction: - # 1024 -> 256 for aimv2 - if self.config.hidden_stride > 1: - # this `d` maybe different from the above `d`` - n, L, d = features.shape - sqrt_l = int(L**0.5) - assert sqrt_l**2 == L, ( - 
"The token sequence length should be a perfect square.") - features = features.reshape(n, sqrt_l, sqrt_l, d) - pl = (self.config.hidden_stride - - (sqrt_l % - self.config.hidden_stride)) % self.config.hidden_stride - features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0) - sqrt_l += pl - features = features.reshape(n, sqrt_l // self.config.hidden_stride, - self.config.hidden_stride, - sqrt_l // self.config.hidden_stride, - self.config.hidden_stride, d) - # [n, sqrt_l/hs, sqrt_l/hs, hs, hs, d] - features = features.permute(0, 1, 3, 2, 4, 5) - # [n, sqrt_l/hs, sqrt_l/hs, hs*hs*d] - features = features.flatten(3) - # [n, sqrt_l/hs*sqrt_l/hs, hs*hs*d] - features = features.reshape( - n, -1, - self.config.hidden_stride * self.config.hidden_stride * d) - - return features - - def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: - """[BatchSize, ImageShape] -> [BatchSize, Token, VocabSize]""" - features = self.encode(pixel_values) - logits, _ = self.head[0]( - features) # we spllit the sequncial here for not throwing an error - logits = self.head[1](logits) - tokens = self.tokenize(logits) - # tokens' shape is [BatchSize, #Token, VocabSize-5], so padding with - # [BatchSize, #Token, 5], after which, tokens' shape should become - # [BatchSize, #Token, VocabSize] - batch_size, token_len, _ = tokens.shape - padding_tensor = torch.zeros(size=(batch_size, token_len, - len(IMAGE_INDICATOR_IDS)), - dtype=tokens.dtype, - device=tokens.device, - layout=tokens.layout, - requires_grad=False) - tokens = torch.cat((tokens, padding_tensor), dim=2) - return tokens +from vllm.transformers_utils.configs.ovis import AIMv2Config class AIMv2SwiGLUFFN(nn.Module): @@ -302,14 +187,6 @@ def __init__(self, quant_config=quant_config, prefix=f"{prefix}.trunk") - @property - def dtype(self): - return self.trunk.blocks[0].attn.qkv.weight.dtype - - @property - def device(self): - return self.trunk.blocks[0].attn.qkv.device - def forward( self, pixel_values: torch.Tensor, diff --git 
a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index c518efdb54f8..94a4328564bb 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 """Inference-only Snowflake Arctic model.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch from torch import nn @@ -458,8 +459,8 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -467,8 +468,8 @@ def load_weights(self, weights: Iterable[Tuple[str, ("qkv_proj", "v_proj", "v"), ] - mlp_params_mapping: List[Tuple[str, str, int]] = [] - expert_params_mapping: List[Tuple[str, str, int]] = [] + mlp_params_mapping: list[tuple[str, str, int]] = [] + expert_params_mapping: list[tuple[str, str, int]] = [] num_layers = self.config.num_hidden_layers for layer in range(num_layers): @@ -497,7 +498,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ("ws", f"experts.{expert_id}.w3.weight", expert_id)) params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() logger.info( "It will take ~10 minutes loading from the 16-bit weights. 
" diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 7c716efab8ef..f74e13888c48 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 from collections.abc import Iterable, Mapping, Sequence -from typing import List, Optional, Set, Tuple, TypedDict, Union +from typing import Optional, TypedDict, Union import torch import torch.nn as nn @@ -66,8 +66,8 @@ def __init__( # Identity layer self.post_layernorm = nn.Identity() - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -75,7 +75,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ("qkv_proj", "v_proj", "v"), ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: # NOTE: post_layernorm is not used in Aria @@ -326,8 +326,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Adapted from LlamaModel.load_weights with the modification of adding # the expert weights mapping to `stacked_params_mapping` - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -339,7 +339,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ("experts.w2_weight", "experts.fc2.weight", 'w2'), ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -528,7 +528,7 @@ def __init__( self.vocab_size, logit_scale) def _validate_image_sizes( - self, 
images: List[torch.Tensor]) -> List[torch.Tensor]: + self, images: list[torch.Tensor]) -> list[torch.Tensor]: if not all(img.shape == images[0].shape for img in images): raise ValueError("All images must be the same size") return images @@ -578,7 +578,7 @@ def _create_patch_attention_mask( def _process_image_input( self, image_input: AriaImagePixelInputs - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: assert self.vision_tower is not None pixel_values = image_input['pixel_values'] @@ -651,6 +651,6 @@ def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py index d152287e8fa3..08d49d71eca1 100644 --- a/vllm/model_executor/models/aya_vision.py +++ b/vllm/model_executor/models/aya_vision.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 Adapted from # https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision -from typing import (Iterable, Literal, Mapping, Optional, Sequence, Set, Tuple, - TypedDict, Union, cast) +from collections.abc import Iterable, Mapping, Sequence +from typing import Literal, Optional, TypedDict, Union, cast import torch from torch import nn @@ -315,8 +315,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def dtype(self): return next(self.parameters()).dtype - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 
444ed38d05c0..bcff6eb3fd31 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -20,7 +20,8 @@ # limitations under the License. """Inference-only BaiChuan model compatible with HuggingFace weights.""" import math -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch from torch import nn @@ -41,7 +42,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, row_parallel_weight_loader) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors @@ -230,7 +232,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: residual = hidden_states @@ -320,15 +322,15 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -383,7 +385,7 @@ def __init__( lora_config = vllm_config.lora_config self.config = config self.lora_config = lora_config - + self.tp_size = get_tensor_model_parallel_world_size() self.quant_config = 
quant_config self.model = BaiChuanModel(vllm_config=vllm_config, prefix=prefix, @@ -421,8 +423,8 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) @@ -437,8 +439,10 @@ def lm_head_weight_loader(self, param: nn.Parameter, is_baichuan2 = self.config.vocab_size == 125696 if is_baichuan2: loaded_weight = torch.nn.functional.normalize(loaded_weight) - - default_weight_loader(param, loaded_weight) + if self.tp_size > 1: + row_parallel_weight_loader(param, loaded_weight) + else: + default_weight_loader(param, loaded_weight) class BaichuanForCausalLM(BaiChuanBaseForCausalLM): diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 87e1e102efd8..d6a705fb1859 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 """Inference-only Bamba model.""" # Added by the IBM Team, 2024 -from typing import Iterable, Optional, Set, Tuple +from collections.abc import Iterable +from typing import Optional import torch from torch import nn @@ -355,8 +356,8 @@ def forward( hidden_states, _ = self.final_layernorm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -367,7 +368,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -495,7 +496,7 @@ def 
get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) def _get_mamba_cache_shape( - self) -> Tuple[Tuple[int, int], Tuple[int, int]]: + self) -> tuple[tuple[int, int], tuple[int, int]]: world_size = get_tensor_model_parallel_world_size() hidden_size = self.config.hidden_size @@ -535,7 +536,7 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index bcfbe92c3a11..92bbe1bb67a3 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -19,7 +19,8 @@ # limitations under the License. """PyTorch BART model.""" import math -from typing import Iterable, Optional, Tuple +from collections.abc import Iterable +from typing import Optional import torch from torch import nn @@ -859,14 +860,14 @@ def _rename_key(self, key: str): def _rename_stacked_param( self, name: str, - ) -> Tuple[str, Optional[str]]: + ) -> tuple[str, Optional[str]]: for key, mapping in self.stacked_params_mapping.items(): if key in name: name = name.replace(key, mapping["param_name"]) return name, mapping["shard_id"] return name, None - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): model_params_dict = dict(self.model.named_parameters()) top_params_dict = dict(self.named_parameters()) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 76a529c93343..0c6593bbe3a1 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Iterable, Optional, Set, 
Tuple +from collections.abc import Iterable +from typing import Optional import torch from torch import nn @@ -11,16 +12,13 @@ from vllm.config import CacheConfig, PoolerConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.forward_context import get_forward_context -from vllm.model_executor.layers.activation import (get_act_and_mul_fn, - get_act_fn) +from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler, PoolingType) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -41,24 +39,19 @@ def __init__(self, config: BertConfig): self.size = config.hidden_size self.word_embeddings = VocabParallelEmbedding(config.vocab_size, config.hidden_size) - + self.position_embeddings = VocabParallelEmbedding( + config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = VocabParallelEmbedding( config.type_vocab_size, config.hidden_size) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.position_ids = nn.Parameter( + torch.empty((1, config.max_position_embeddings)), ) self.position_embedding_type = config.position_embedding_type - if self.position_embedding_type == "absolute": - self.position_embeddings = VocabParallelEmbedding( - config.max_position_embeddings, config.hidden_size) - self.position_ids = nn.Parameter( - torch.empty((1, config.max_position_embeddings)), ) - elif self.position_embedding_type == "rotary": - self.position_embeddings = None - self.position_ids = None - else: - raise ValueError("Only 'absolute' 
and 'rotary' " + - "position_embedding_type is supported") + if self.position_embedding_type != "absolute": + raise ValueError("Only 'absolute' position_embedding_type" + + " is supported") def forward( self, @@ -72,6 +65,9 @@ def forward( # Input embeddings. inputs_embeds = self.word_embeddings(input_ids) + # Position embeddings. + position_embeddings = self.position_embeddings(position_ids) + if token_type_ids is None: token_type_ids = torch.zeros(input_shape, dtype=torch.long, @@ -79,12 +75,7 @@ def forward( token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings - + embeddings = inputs_embeds + token_type_embeddings + position_embeddings embeddings = self.LayerNorm(embeddings) return embeddings @@ -108,11 +99,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @support_torch_compile class BertEncoder(nn.Module): - def __init__(self, - vllm_config: VllmConfig, - bias: bool = True, - rotary_kwargs: Optional[dict] = None, - prefix: str = ""): + def __init__(self, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config @@ -121,19 +108,16 @@ def __init__(self, BertLayer(config=config, cache_config=cache_config, quant_config=quant_config, - bias=bias, - rotary_kwargs=rotary_kwargs, prefix=f"{prefix}.layer.{layer_idx}") for layer_idx in range(config.num_hidden_layers) ]) def forward( self, - positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: for layer in self.layer: - hidden_states = layer(positions, hidden_states) + hidden_states = layer(hidden_states) return hidden_states @@ -143,8 +127,6 @@ def __init__(self, config: BertConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - 
bias: bool = True, - rotary_kwargs: Optional[dict] = None, prefix: str = ""): super().__init__() @@ -154,36 +136,23 @@ def __init__(self, layer_norm_eps=config.layer_norm_eps, cache_config=cache_config, quant_config=quant_config, - bias=bias, - rotary_kwargs=rotary_kwargs, prefix=f"{prefix}.attention") - if config.hidden_act in ["silu", "gelu_and_mul"]: - self.intermediate = BertGatedIntermediate( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.intermediate") - else: - self.intermediate = BertIntermediate( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.intermediate") + self.intermediate = BertIntermediate( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.intermediate") self.output = BertOutput(hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, layer_norm_eps=config.layer_norm_eps, - bias=bias, quant_config=quant_config, prefix=f"{prefix}.output") - def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor): - attn_output = self.attention(positions, hidden_states) + def forward(self, hidden_states: torch.Tensor): + attn_output = self.attention(hidden_states) intermediate_output = self.intermediate(attn_output) output = self.output(intermediate_output, attn_output) return output @@ -198,8 +167,6 @@ def __init__( layer_norm_eps: float, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - bias: bool = True, - rotary_kwargs: Optional[dict] = None, prefix: str = "", ): super().__init__() @@ -208,22 +175,18 @@ def __init__( num_attention_heads=num_attention_heads, cache_config=cache_config, quant_config=quant_config, - bias=bias, - 
rotary_kwargs=rotary_kwargs, prefix=f"{prefix}.output") self.output = BertSelfOutput(hidden_size=hidden_size, layer_norm_eps=layer_norm_eps, - bias=bias, quant_config=quant_config, prefix=f"{prefix}.output") def forward( self, - positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: - self_output = self.self(positions, hidden_states) + self_output = self.self(hidden_states) return self.output(self_output, hidden_states) @@ -235,8 +198,6 @@ def __init__( num_attention_heads: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - bias: bool = True, - rotary_kwargs: Optional[dict] = None, prefix: str = "", ): super().__init__() @@ -261,15 +222,10 @@ def __init__( head_size=self.head_dim, total_num_heads=self.total_num_heads, total_num_kv_heads=self.total_num_kv_heads, - bias=bias, + bias=True, quant_config=quant_config, prefix=f"{prefix}.qkv_proj") - if rotary_kwargs: - self.rotary_emb = get_rope(**rotary_kwargs) - else: - self.rotary_emb = None - self.attn = Attention(num_heads=self.num_heads, head_size=self.head_dim, scale=self.scaling, @@ -281,15 +237,10 @@ def __init__( def forward( self, - positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - - if self.rotary_emb: - q, k = self.rotary_emb(positions, q, k) - output = self.attn(q, k, v) return output @@ -299,13 +250,12 @@ class BertSelfOutput(nn.Module): def __init__(self, hidden_size: int, layer_norm_eps: float, - bias: bool = True, quant_config: Optional[QuantizationConfig] = None, prefix: str = ""): super().__init__() self.dense = RowParallelLinear(input_size=hidden_size, output_size=hidden_size, - bias=bias, + bias=True, quant_config=quant_config, prefix=f"{prefix}.dense") self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) @@ -323,13 +273,12 @@ def __init__(self, hidden_size: int, intermediate_size: int, 
hidden_act: str, - bias: bool = True, quant_config: Optional[QuantizationConfig] = None, prefix: str = ""): super().__init__() self.dense = ColumnParallelLinear(input_size=hidden_size, output_size=intermediate_size, - bias=bias, + bias=True, quant_config=quant_config, prefix=f"{prefix}.dense") self.intermediate_act_fn = get_act_fn(hidden_act) @@ -340,46 +289,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class BertGatedIntermediate(nn.Module): - # for NomciBert and GteModel - - def __init__(self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - bias: bool = True, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): - super().__init__() - self.act_fn = get_act_and_mul_fn(hidden_act) - self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, - [intermediate_size] * 2, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj", - ) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - gate_up, _ = self.gate_up_proj(hidden_states) - hidden_states = self.act_fn(gate_up) - return hidden_states - - class BertOutput(nn.Module): def __init__(self, hidden_size: int, intermediate_size: int, layer_norm_eps: float, - bias: bool = True, quant_config: Optional[QuantizationConfig] = None, prefix: str = ""): super().__init__() self.dense = RowParallelLinear(input_size=intermediate_size, output_size=hidden_size, - bias=bias, + bias=True, quant_config=quant_config, prefix=f"{prefix}.dense") @@ -393,33 +315,18 @@ def forward(self, hidden_states: torch.Tensor, class BertModel(nn.Module, SupportsQuant): - packed_modules_mapping = { - "qkv_proj": ["query", "key", "value"], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } + packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]} def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", embedding_class: type = BertEmbedding, - bias: bool = True, - rotary_kwargs: Optional[dict] = None, 
add_pooling_layer: bool = False): super().__init__() - """ - For BertModel, all linear layers have bias. - For NomicBertModel, all linear layers do not have bias. - """ - config = vllm_config.model_config.hf_config self.embeddings = embedding_class(config) self.encoder = BertEncoder(vllm_config=vllm_config, - bias=bias, - rotary_kwargs=rotary_kwargs, prefix=f"{prefix}.encoder") self.pooler = BertPooler(config) if add_pooling_layer else None @@ -441,21 +348,19 @@ def forward( seq_lens=attn_metadata.seq_lens_tensor, position_ids=position_ids, token_type_ids=token_type_ids) - return self.encoder(position_ids, hidden_states) + return self.encoder(hidden_states) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "query", "q"), ("qkv_proj", "key", "k"), ("qkv_proj", "value", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if self.pooler is None and "pooler" in name: continue @@ -497,7 +402,6 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() pooler_config = vllm_config.model_config.pooler_config - self.config = vllm_config.model_config.hf_config self.model = self._build_model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) self._pooler = self._build_pooler(pooler_config) @@ -521,7 +425,7 @@ def pooler( ) -> Optional[PoolerOutput]: return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): weights = self.hf_to_vllm_mapper.apply(weights) weights = ((name, 
data) for name, data in weights if not name.startswith("lm_head.")) @@ -569,7 +473,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._pooler = CrossEncodingPooler(config, self.classifier, self.bert.pooler) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): self_weights = [] @@ -611,115 +515,3 @@ def forward( inputs_embeds=inputs_embeds, intermediate_tensors=intermediate_tensors, token_type_ids=token_type_ids) - - -class NomicBertEmbeddingModel(BertEmbeddingModel): - - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_substr={ - "emb_ln": "embeddings.LayerNorm", - "layers": "layer", - "attn.Wqkv": "attention.self.qkv_proj", - "attn.out_proj": "attention.output.dense", - 'norm1': "attention.output.LayerNorm", - 'mlp.fc11': "intermediate.up_proj", - 'mlp.fc12': "intermediate.gate_proj", - 'mlp.fc2': "output.dense", - 'norm2': "output.LayerNorm", - }) - - def _build_model(self, - vllm_config: VllmConfig, - prefix: str = "") -> BertModel: - config = vllm_config.model_config.hf_config - - assert config.__class__.__name__ == "NomicBertConfig" - assert config.activation_function == "swiglu" - - # Assume NomicBertModel all linear layers do not have bias - assert not config.mlp_fc1_bias - assert not config.mlp_fc2_bias - assert not config.qkv_proj_bias - - config.layer_norm_eps = config.layer_norm_epsilon - config.position_embedding_type = "rotary" - config.intermediate_size = config.n_inner - config.hidden_act = "silu" - config.hidden_size = config.n_embd - config.num_hidden_layers = config.n_layer - - head_dim = config.hidden_size // config.num_attention_heads - rotary_kwargs = { - "head_size": head_dim, - "rotary_dim": getattr(config, "rotary_emb_dim", head_dim), - "max_position": config.max_trained_positions, - "base": config.rotary_emb_base, - "rope_scaling": { - "rope_type": "dynamic", - "factor": config.rotary_scaling_factor - } - } - - return 
BertModel(vllm_config=vllm_config, - prefix=prefix, - bias=False, - rotary_kwargs=rotary_kwargs, - embedding_class=BertEmbedding) - - -class GteEmbeddingModel(BertEmbeddingModel): - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_substr={ - "attention.qkv_proj": "attention.self.qkv_proj", - "attention.o_proj": "attention.output.dense", - 'attn_ln': "attention.output.LayerNorm", - 'mlp.down_proj': "output.dense", - 'mlp_ln': "output.LayerNorm", - }) - - def _build_model(self, - vllm_config: VllmConfig, - prefix: str = "") -> BertModel: - config = vllm_config.model_config.hf_config - - assert config.__class__.__name__ == "GteConfig" - assert config.position_embedding_type == "rope" - assert config.hidden_act == "gelu" - - config.position_embedding_type = "rotary" - config.hidden_act = "gelu_and_mul" - - head_dim = config.hidden_size // config.num_attention_heads - rotary_kwargs = { - "head_size": head_dim, - "rotary_dim": getattr(config, "rotary_emb_dim", head_dim), - "max_position": config.max_position_embeddings, - "base": config.rope_theta, - } - - model = BertModel(vllm_config=vllm_config, - prefix=prefix, - rotary_kwargs=rotary_kwargs, - embedding_class=BertEmbedding) - - # GteModel only gate_up_proj does not have bias. 
- # Hack method learned from vllm/model_executor/models/glm.py - for layer in model.encoder.layer: - layer.intermediate.gate_up_proj.bias = None - layer.intermediate.skip_bias_add = True - return model - - def split_up_gate_proj(self, weights: Iterable[Tuple[str, torch.Tensor]]): - n = "mlp.up_gate_proj" - for name, weight in weights: - if n in name: - up, gate = weight.chunk(2, dim=0) - yield name.replace(n, "intermediate.up_proj"), up - yield name.replace(n, "intermediate.gate_proj"), gate - else: - yield name, weight - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - weights = self.hf_to_vllm_mapper.apply(weights) - weights = self.split_up_gate_proj(weights) - self.model.load_weights(weights) diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py new file mode 100644 index 000000000000..af6deb3bf072 --- /dev/null +++ b/vllm/model_executor/models/bert_with_rope.py @@ -0,0 +1,714 @@ +# SPDX-License-Identifier: Apache-2.0 +from collections.abc import Iterable +from typing import Optional + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention, AttentionType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import (get_act_and_mul_fn, + get_act_fn) +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models import 
SupportsV0Only +from vllm.model_executor.models.interfaces import SupportsQuant +from vllm.model_executor.models.utils import WeightsMapper +from vllm.sequence import IntermediateTensors + + +class BertWithRopeEmbedding(nn.Module): + + def __init__(self, config: PretrainedConfig): + + super().__init__() + if config.position_embedding_type not in ["rope", "rotary"]: + raise ValueError("Only 'rotary'('rope') position_embedding_type" + + " is supported") + + self.word_embeddings = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + if config.type_vocab_size > 0: + self.token_type_embeddings = VocabParallelEmbedding( + config.type_vocab_size, config.hidden_size) + else: + self.token_type_embeddings = None + + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + input_shape = input_ids.size() + inputs_embeds = self.word_embeddings(input_ids) + + embeddings = inputs_embeds + if self.token_type_embeddings is not None: + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, + dtype=torch.long, + device=inputs_embeds.device) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings += token_type_embeddings + + embeddings = self.LayerNorm(embeddings) + return embeddings + + +class BertWithRopeAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = True, + rotary_kwargs: Optional[dict] = None, + prefix: str = "", + ): + super().__init__() + + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + + self.total_num_heads = num_attention_heads + assert self.total_num_heads % tp_size == 0 + + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = self.total_num_heads + 
self.head_dim = self.hidden_size // self.total_num_heads + assert self.head_dim * self.total_num_heads == self.hidden_size + + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj") + + self.rotary_emb = get_rope(**rotary_kwargs) + + self.attn = Attention(num_heads=self.num_heads, + head_size=self.head_dim, + scale=self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=AttentionType.ENCODER_ONLY) + + self.out_proj = RowParallelLinear(input_size=hidden_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.dense") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.out_proj(attn_output) + return output + + +class BertWithRopeGatedMLP(nn.Module): + + def __init__(self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + bias: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.act_fn = get_act_and_mul_fn(hidden_act) + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear(input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + 
quant_config=quant_config, + prefix=f"{prefix}.down_proj") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(hidden_states) + hidden_states = self.act_fn(gate_up) + hidden_states, _ = self.down_proj(hidden_states) + return hidden_states + + +class BertWithRopeMLP(nn.Module): + + def __init__(self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + bias: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.act_fn = get_act_fn(hidden_act) + self.up_proj = ColumnParallelLinear(input_size=hidden_size, + output_size=intermediate_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.up_proj") + self.down_proj = RowParallelLinear(input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.up_proj(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states, _ = self.down_proj(hidden_states) + return hidden_states + + +class NomicRouter(nn.Module): + + def __init__(self, hidden_size: int, moe_num_experts: int, moe_top_k: int): + super().__init__() + self.moe_top_k = moe_top_k + self.layer = ReplicatedLinear(hidden_size, moe_num_experts, bias=False) + + def forward( + self, x: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.LongTensor]: + weights = self.layer(x.view(-1, x.shape[-1]))[0].softmax( + dim=-1, dtype=torch.float32) + top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1) + weights = weights.to(x.dtype) + top_weights = top_weights.to(x.dtype) + return weights, top_weights, top_experts # type: ignore + + +class NomicExpertMLP(nn.Module): + + def __init__(self, hidden_size: int, ffn_hidden_size: int, + moe_num_experts: int, ffn_act_fn: str): + super().__init__() + self.hidden_size = hidden_size + self.ffn_hidden_size = 
ffn_hidden_size + self.moe_num_experts = moe_num_experts + + self.w1 = nn.Parameter( + torch.empty(moe_num_experts * ffn_hidden_size, hidden_size)) + self.w2 = nn.Parameter( + torch.empty(moe_num_experts * ffn_hidden_size, hidden_size)) + self.activation_fn = get_act_fn(ffn_act_fn) + + def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor: + expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, + self.hidden_size)[expert_idx] + expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, + self.hidden_size)[expert_idx] + + x1 = x.matmul(expert_w1.t()) + act_out = self.activation_fn(x1) + x2 = act_out.matmul(expert_w2) + return x2 + + +class NomicExperts(nn.Module): + + def __init__(self, config, hidden_size: int, ffn_hidden_size: int, + moe_num_experts: int): + super().__init__() + self.moe_num_experts = moe_num_experts + + self.mlp = NomicExpertMLP(hidden_size=config.n_embd, + ffn_hidden_size=config.n_inner, + moe_num_experts=moe_num_experts, + ffn_act_fn=config.hidden_act) + self.bias = nn.Parameter(torch.zeros(config.n_embd)) + + def forward(self, x: torch.Tensor, weights: torch.Tensor, + top_weights: torch.Tensor, + top_experts: torch.LongTensor) -> torch.Tensor: + q_len, hidden_size = x.shape + x = x.view(-1, hidden_size) + out = torch.zeros_like(x) + + expert_mask = nn.functional.one_hot( + top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0) + for expert_idx in range(0, self.moe_num_experts): + topk_idx, token_idx = torch.where(expert_mask[expert_idx]) + if token_idx.shape[0] == 0: + continue + + token_list = token_idx.tolist() + topk_list = topk_idx.tolist() + + expert_tokens = x[None, token_list].reshape(-1, hidden_size) + expert_out = self.mlp( + expert_tokens, expert_idx) * top_weights[token_list, topk_list, + None] + + out.index_add_(0, token_idx, expert_out) + + out = out.reshape(q_len, hidden_size) + return out + self.bias + + +class NomicMoELayer(nn.Module): + + def __init__(self, config: 
PretrainedConfig): + super().__init__() + + self.router = NomicRouter( + config.n_embd, + moe_num_experts=config.num_experts, + moe_top_k=config.moe_top_k, + ) + + self.experts = NomicExperts( + config, + hidden_size=config.n_embd, + ffn_hidden_size=config.n_inner, + moe_num_experts=config.num_experts, + ) + + def forward(self, x: torch.Tensor): + weights, top_weights, top_experts = self.router(x) + out = self.experts(x, weights, top_weights, top_experts) + return out + + +class BertWithRopeBlock(nn.Module): + + def __init__(self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + moe: bool = False, + bias: bool = True, + rotary_kwargs: Optional[dict] = None, + prefix: str = ""): + super().__init__() + self.attn = BertWithRopeAttention( + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + cache_config=cache_config, + quant_config=quant_config, + bias=bias, + rotary_kwargs=rotary_kwargs, + prefix=f"{prefix}.attention") + + if moe: + self.mlp = NomicMoELayer(config=config, ) + else: + if config.hidden_act in ["silu", "geglu"]: + self.mlp = BertWithRopeGatedMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = BertWithRopeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + self.attn_ln = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.mlp_ln = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor): + attn_output = self.attn(positions, hidden_states) + hidden_states = self.attn_ln(hidden_states + attn_output) + mlp_out = self.mlp(hidden_states) + hidden_states = 
self.mlp_ln(hidden_states + mlp_out) + return hidden_states + + +@support_torch_compile +class BertWithRopeEncoder(nn.Module): + + def __init__(self, + vllm_config: VllmConfig, + bias: bool = True, + rotary_kwargs: Optional[dict] = None, + prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + every_n = getattr(config, "moe_every_n_layers", 0) + self.layers = nn.ModuleList([ + BertWithRopeBlock(config=config, + cache_config=cache_config, + quant_config=quant_config, + bias=bias, + moe=every_n > 0 and (layer_idx % every_n == 1), + rotary_kwargs=rotary_kwargs, + prefix=f"{prefix}.layer.{layer_idx}") + for layer_idx in range(config.num_hidden_layers) + ]) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + for layer in self.layers: + hidden_states = layer(positions, hidden_states) + return hidden_states + + +class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant): + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.vllm_config = vllm_config + self.config = self.config_verify(vllm_config) + self.embeddings = BertWithRopeEmbedding(self.config) + self.encoder = BertWithRopeEncoder( + vllm_config=vllm_config, + bias=getattr(self.config, "bias", True), + rotary_kwargs=self.config.rotary_kwargs, + prefix=f"{prefix}.encoder") + + def config_verify(self, vllm_config): + raise NotImplementedError + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embeddings(input_ids=input_ids, + 
token_type_ids=token_type_ids) + return self.encoder(positions, hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + weights = self.hf_to_vllm_mapper.apply(weights) + + if self.config.hidden_act in ["silu", "geglu"]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + else: + stacked_params_mapping = [] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "pooler" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class NomicBertModel(BertWithRope): + # for https://huggingface.co/nomic-ai/nomic-bert-2048 + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + "emb_ln": "embeddings.LayerNorm", + "attn.Wqkv": "attn.qkv_proj", + "norm1": "attn_ln", + "mlp.fc1.": "mlp.up_proj.", + "mlp.fc11": "mlp.up_proj", + "mlp.fc12": "mlp.gate_proj", + "mlp.fc2": "mlp.down_proj", + "norm2": "mlp_ln", + }) + + def config_verify(self, vllm_config): + config = vllm_config.model_config.hf_config + + assert config.__class__.__name__ == "NomicBertConfig" + assert config.activation_function in ["swiglu", "gelu"] + config.position_embedding_type = getattr(config, + "position_embedding_type", + "rope") + + if config.activation_function == "swiglu": + config.hidden_act = "silu" + else: + config.hidden_act = config.activation_function + + assert (config.mlp_fc1_bias == config.mlp_fc2_bias == + config.qkv_proj_bias) + config.bias = config.qkv_proj_bias + + assert config.rotary_emb_scale_base is None + assert not config.rotary_emb_interleaved + + config.layer_norm_eps = config.layer_norm_epsilon + config.intermediate_size = config.n_inner + config.hidden_size = config.n_embd + config.num_hidden_layers = config.n_layer + + head_dim = config.hidden_size // config.num_attention_heads + rotary_emb_dim = head_dim * config.rotary_emb_fraction + config.rotary_kwargs = { + "head_size": head_dim, + "rotary_dim": rotary_emb_dim, + "max_position": config.max_trained_positions, + "base": getattr(config, "rope_theta", config.rotary_emb_base), + "rope_scaling": getattr(config, "rope_scaling", None) + } + + # we ignore config.rotary_scaling_factor so that for datasets shorter + # than max_trained_positions 2048, the results are consistent + # with 
SentenceTransformer. + # The context extension uses vllm style rope_theta and rope_scaling. + # See #17785 + + return config + + +class GteNewModel(BertWithRope): + # for https://huggingface.co/Alibaba-NLP/new-impl + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + "new.": "", + "layer": "layers", + "attention.qkv_proj": "attn.qkv_proj", + "attention.o_proj": "attn.out_proj", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + # GteNewModel only gate_up_proj does not have bias. + # Hack method learned from vllm/model_executor/models/glm.py + for layer in self.encoder.layers: + layer.mlp.gate_up_proj.bias = None + layer.mlp.gate_up_proj.skip_bias_add = True + + def config_verify(self, vllm_config): + config = vllm_config.model_config.hf_config + + assert config.__class__.__name__ == "NewConfig" + assert config.hidden_act == "gelu" + + config.hidden_act = "geglu" + + head_dim = config.hidden_size // config.num_attention_heads + config.rotary_kwargs = { + "head_size": head_dim, + "rotary_dim": getattr(config, "rotary_emb_dim", head_dim), + "max_position": config.max_position_embeddings, + "base": config.rope_theta, + "rope_scaling": getattr(config, "rope_scaling", None) + } + return config + + def split_up_gate_proj(self, weights: Iterable[tuple[str, torch.Tensor]]): + n = "mlp.up_gate_proj" + for name, weight in weights: + if n in name: + up, gate = weight.chunk(2, dim=0) + yield name.replace(n, "mlp.up_proj"), up + yield name.replace(n, "mlp.gate_proj"), gate + else: + yield name, weight + + def ignore_unnecessary_layers(self, + weights: Iterable[tuple[str, torch.Tensor]]): + for name, weight in weights: + if name.startswith("classifier"): + continue + yield name, weight + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + weights = self.ignore_unnecessary_layers(weights) + weights = self.split_up_gate_proj(weights) + return 
super().load_weights(weights) + + +class SnowflakeGteNewModel(GteNewModel): + # for Snowflake/snowflake-arctic-embed-m-v2.0 + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + "layer": "layers", + "attention.qkv_proj": "attn.qkv_proj", + "attention.o_proj": "attn.out_proj", + }) + + def config_verify(self, vllm_config): + config = vllm_config.model_config.hf_config + + assert config.__class__.__name__ == "GteConfig" + assert config.hidden_act == "gelu" + + config.hidden_act = "geglu" + + head_dim = config.hidden_size // config.num_attention_heads + config.rotary_kwargs = { + "head_size": head_dim, + "rotary_dim": getattr(config, "rotary_emb_dim", head_dim), + "max_position": config.max_position_embeddings, + "base": config.rope_theta, + "rope_scaling": getattr(config, "rope_scaling", None) + } + return config + + +class JinaRobertaModel(BertWithRope): + # for https://huggingface.co/jinaai/jina-embeddings-v3 + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + "emb_ln": "embeddings.LayerNorm", + "mixer.Wqkv": "attn.qkv_proj", + "mixer.out_proj": "attn.out_proj", + "norm1": "attn_ln", + "mlp.fc1.": "mlp.up_proj.", + "mlp.fc2": "mlp.down_proj", + "norm2": "mlp_ln", + }) + + def config_verify(self, vllm_config): + config = vllm_config.model_config.hf_config + + assert config.__class__.__name__ == "XLMRobertaFlashConfig" + + head_dim = config.hidden_size // config.num_attention_heads + config.rotary_kwargs = { + "head_size": head_dim, + "rotary_dim": getattr(config, "rotary_emb_dim", head_dim), + "max_position": config.max_position_embeddings, + "base": getattr(config, "rope_theta", config.rotary_emb_base), + "rope_scaling": getattr(config, "rope_scaling", None) + } + return config + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return 
super().forward(input_ids=input_ids, + positions=position_ids, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + token_type_ids=token_type_ids) + + @torch.inference_mode() + def jina_merge_lora_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]): + # use for jina-embeddings-v3 + # Merge Lora weights into a single weight tensor. + # This is a temporary solution until we have a better way to handle + + scaling = self.config.lora_alpha / self.config.lora_rank + device = self.vllm_config.device_config.device + + weights = {name: weight for name, weight in weights} + + o = ".original" + a = ".0.lora_A" + b = ".0.lora_B" + + # text-matching + i = -1 + + for name in list(weights.keys()): + if o in name: + dtype = weights[name].dtype + shape = weights[name].shape + weight_name = name[:-len(o)] + + if "embeddings" in weight_name: + B = weights[weight_name + a][i].to(device).float() + A = weights[weight_name + b][i].to(device).float() + else: + B = weights[weight_name + b][i].to(device).float() + A = weights[weight_name + a][i].to(device).float() + + weight = (weights[weight_name + o].to(device) + + torch.matmul(B, A).view(shape) * scaling) + weight = weight.cpu().to(dtype) + + weights[weight_name.replace(".parametrizations", "")] = weight + + del weights[weight_name + o], weights[weight_name + + a], weights[weight_name + + b] + + return [(name, weight) for name, weight in weights.items()] + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + weights = self.jina_merge_lora_weights(weights) + return super().load_weights(weights) diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index f3d488926d09..acbc5d04d7e3 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 """Minimal implementation of BlipVisionModel intended to be only used within a vision language model.""" -from typing 
import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch import torch.nn as nn @@ -296,8 +297,8 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: return self.post_layernorm(hidden_states) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -305,7 +306,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ("qkv_proj", "v_proj", "v"), ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() layer_count = len(self.encoder.layers) for name, loaded_weight in weights: diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index f44565bd2e01..db0dd2051d52 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from collections.abc import Iterable, Mapping, Sequence -from typing import Literal, Optional, Set, Tuple, TypedDict, Union +from typing import Literal, Optional, TypedDict, Union import torch import torch.nn as nn @@ -186,7 +186,7 @@ def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.FloatTensor] = None, - ) -> Tuple[torch.Tensor]: + ) -> tuple[torch.Tensor]: self_output = self.attention( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -681,9 +681,8 @@ def forward( batch. pixel_values: The pixels in each input image. 
- :::{seealso} - {class}`Blip2ImageInputs` - ::: + Info: + [Blip2ImageInputs][] """ if intermediate_tensors is not None: @@ -712,7 +711,7 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 74d401b295ce..10424e218fbc 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -18,7 +18,8 @@ # limitations under the License. """Inference-only BLOOM model compatible with HuggingFace weights.""" import math -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch from torch import nn @@ -42,7 +43,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP, SupportsQuant, SupportsV0Only -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -228,6 +229,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config + self.config = config self.embed_dim = config.hidden_size @@ -277,6 +279,38 @@ def forward( hidden_states = self.ln_f(hidden_states) return hidden_states + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + + if "query_key_value" in name: 
+ # NOTE: BLOOM's fused QKV's output_dim has the shape of + # (num_heads * 3 * head_size), while the + # required shape is (3 * num_heads * head_size). + # Thus, we need weight conversion. + output_dim = getattr(param, "output_dim", None) + num_heads = self.config.num_attention_heads + if output_dim is not None: + loaded_weight_shape = loaded_weight.shape + loaded_weight = loaded_weight.view( + loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + + loaded_weight_shape[output_dim + 1:]) + loaded_weight = loaded_weight.transpose( + output_dim, output_dim + 1) + loaded_weight = loaded_weight.reshape(loaded_weight_shape) + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant): @@ -322,37 +356,17 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if name == "lm_head.weight": - continue - if not name.startswith("transformer."): - name = "transformer." + name - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self, skip_prefixes=["lm_head.weight"]) + weights = _add_transformer_prefix(weights) + return loader.load_weights(weights) - if "query_key_value" in name: - # NOTE: BLOOM's fused QKV's output_dim has the shape of - # (num_heads * 3 * head_size), while the - # required shape is (3 * num_heads * head_size). - # Thus, we need weight conversion. 
- output_dim = getattr(param, "output_dim", None) - num_heads = self.config.num_attention_heads - if output_dim is not None: - loaded_weight_shape = loaded_weight.shape - loaded_weight = loaded_weight.view( - loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + - loaded_weight_shape[output_dim + 1:]) - loaded_weight = loaded_weight.transpose( - output_dim, output_dim + 1) - loaded_weight = loaded_weight.reshape(loaded_weight_shape) - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params +def _add_transformer_prefix( + weights: Iterable[tuple[str, torch.Tensor]] +) -> Iterable[tuple[str, torch.Tensor]]: + for name, tensor in weights: + if not name.startswith('transformer.'): + name = 'transformer.' + name + yield name, tensor diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index ef8b033f3846..a4528ca26d01 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -2,7 +2,7 @@ from collections.abc import Iterable, Mapping, Sequence from functools import cached_property -from typing import Any, Dict, Literal, Optional, Set, Tuple, TypedDict, Union +from typing import Any, Literal, Optional, TypedDict, Union import torch import torch.nn as nn @@ -229,7 +229,7 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, + rope_scaling: Optional[dict[str, Any]] = None, max_position_embeddings: int = 4096, quant_config: Optional[QuantizationConfig] = None, bias: bool = False, @@ -292,7 +292,7 @@ def __init__( prefix=f"{prefix}.attn") def _apply_qk_norm(self, q: torch.Tensor, - k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: # reshape for layernorm q = q.reshape(-1, self.num_heads, self.head_dim) k = k.reshape(-1, self.num_kv_heads, self.head_dim) 
@@ -367,7 +367,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: if residual is None: residual = hidden_states @@ -438,7 +438,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: residual = hidden_states hidden_states = self.self_attn( @@ -773,7 +773,7 @@ def __init__(self, config: ChameleonVQVAEConfig): def encode( self, pixel_values: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: hidden_states = self.encoder(pixel_values) hidden_states = self.quant_conv(hidden_states) quant, emb_loss, indices = self.quantize(hidden_states) @@ -786,7 +786,7 @@ class ChameleonImageVocabularyMapping: A class for mapping discrete image tokens from VQGAN to BPE tokens. 
""" - def __init__(self, vocab_map: Dict[str, int]): + def __init__(self, vocab_map: dict[str, int]): self.vocab_map = vocab_map self.image_token_id = vocab_map.get("<image>") @@ -1052,8 +1052,8 @@ def compute_logits( return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -1063,7 +1063,7 @@ def load_weights(self, weights: Iterable[Tuple[str, (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 233e9ee0a258..4e95afe1a147 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -3,7 +3,8 @@ # https://github.com/THUDM/ChatGLM2-6B """Inference-only ChatGLM model compatible with THUDM weights.""" import json -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch from torch import nn @@ -358,15 +359,15 @@ def forward( return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("linear_proj.merged_proj", "linear_proj.gate_proj", 0), ("linear_proj.merged_proj", "linear_proj.dense_h_to_4h", 1), ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: for (param_name, weight_name, shard_id) in stacked_params_mapping: @@ -440,7 +441,7 @@ def compute_logits( sampling_metadata) 
return logits - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 153054e5c028..e8f3ae2156e0 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 """Minimal implementation of CLIPVisionModel intended to be only used within a vision language model.""" -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch import torch.nn as nn @@ -368,8 +369,8 @@ def device(self): # (TODO) Add prefix argument for filtering out weights to be loaded # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -377,7 +378,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ("qkv_proj", "v_proj", "v"), ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() layer_count = len(self.vision_model.encoder.layers) for name, loaded_weight in weights: diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 8f64e5d5c966..546b5f932877 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -21,7 +21,8 @@ # This file is based on the LLama model definition file in transformers """PyTorch Cohere model.""" -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import 
Optional, Union import torch from torch import nn @@ -259,7 +260,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention residual = hidden_states hidden_states, residual = self.input_layernorm(hidden_states, residual) @@ -404,8 +405,8 @@ def compute_logits( return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -415,7 +416,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: # Skip loading rotary embeddings since vLLM has its own diff --git a/vllm/model_executor/models/constant_size_cache.py b/vllm/model_executor/models/constant_size_cache.py index d073a7de6917..f1cc7e0f9e29 100644 --- a/vllm/model_executor/models/constant_size_cache.py +++ b/vllm/model_executor/models/constant_size_cache.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from typing import Any, Dict, List, Tuple +from typing import Any import torch @@ -16,7 +16,7 @@ class ConstantSizeCache(ABC): def __init__(self, max_batch_size: int): # Maps between the request id and a dict that maps between the seq_id # and its index inside the cache - self.cache_indices_mapping: Dict[str, Dict[int, int]] = {} + self.cache_indices_mapping: dict[str, dict[int, int]] = {} self.free_cache_indices = list(range(max_batch_size)) @property @@ -30,7 +30,7 @@ def _copy_cache(self, from_index: int, to_index: int): """Copy cache data from one index to another""" pass - def current_run_tensors(self, **kwargs) -> Tuple: + def 
current_run_tensors(self, **kwargs) -> tuple: """ Return the tensors for the current run's conv and ssm state. """ @@ -117,8 +117,8 @@ def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int, return self.cache_indices_mapping[cur_rid][seq_id] def _prepare_current_run_cache( - self, request_ids_to_seq_ids: Dict[str, list[int]], - finished_requests_ids: List[str]) -> List[int]: + self, request_ids_to_seq_ids: dict[str, list[int]], + finished_requests_ids: list[str]) -> list[int]: return [ self._assign_seq_id_to_cache_index(req_id, seq_id, finished_requests_ids) @@ -127,7 +127,7 @@ def _prepare_current_run_cache( ] def _release_finished_requests(self, - finished_seq_groups_req_ids: List[str]): + finished_seq_groups_req_ids: list[str]): for req_id in finished_seq_groups_req_ids: if req_id in self.cache_indices_mapping: for seq_id in self.cache_indices_mapping[req_id]: diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 9ec245cce189..f21887f71d85 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch import torch.nn as nn @@ -25,7 +26,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -79,7 +80,6 @@ def __init__( prefix=prefix, ) self.config = config - self.tp_size = get_tensor_model_parallel_world_size() self.d_model = config.d_model self.intermediate_size = (self.config.ffn_config.ffn_hidden_size // self.tp_size) @@ -319,6 +319,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): cache_config = vllm_config.cache_config quant_config = 
vllm_config.quant_config + self.quant_config = quant_config self.wte = VocabParallelEmbedding( config.vocab_size, config.d_model, @@ -364,6 +365,55 @@ def forward( hidden_states = self.norm_f(hidden_states) return hidden_states + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + expert_params_mapping = [( + "w13" if weight_name in ["w1", "v1"] else "w2", + f"mlp.{weight_name}", + ) for weight_name in ["w1", "v1", "w2"]] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + + if name.endswith(("w1", "w2", "v1")): + name = name + "_weight" + for param_name, weight_name in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, weight_name, name) + break + + else: + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. 
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class DbrxForCausalLM(nn.Module, SupportsPP): @@ -415,51 +465,7 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: - expert_params_mapping = [( - "w13" if weight_name in ["w1", "v1"] else "w2", - f"mlp.{weight_name}", - ) for weight_name in ["w1", "v1", "w2"]] - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() - - for name, loaded_weight in weights: - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue - - if name.endswith(("w1", "w2", "v1")): - name = name + "_weight" - for param_name, weight_name in expert_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, weight_name, name) - break - - else: - if is_pp_missing_parameter(name, self): - continue - # Remapping the name of FP8 kv-scale. 
- name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index c6421143dd68..88d1ca9f7b83 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -22,7 +22,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Deepseek model.""" -from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Any, Optional, Union import torch from torch import nn @@ -184,7 +185,7 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, + rope_scaling: Optional[dict[str, Any]] = None, max_position_embeddings: int = 8192, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, @@ -385,8 +386,8 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -397,7 +398,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -478,7 +479,7 @@ def compute_logits( 
sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) \ No newline at end of file diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index b50175cf764f..03ef7bed0edc 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Iterable, Optional, Set, Tuple +from collections.abc import Iterable +from typing import Optional import torch import torch.nn as nn @@ -18,6 +19,7 @@ from .deepseek_v2 import (DeepseekV2DecoderLayer, get_spec_layer_idx_from_weight_name) +from .interfaces import SupportsPP from .utils import maybe_prefix @@ -144,7 +146,7 @@ def compute_logits( return logits -class DeepSeekMTP(nn.Module): +class DeepSeekMTP(nn.Module, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -176,8 +178,8 @@ def compute_logits( return self.model.compute_logits(hidden_states, sampling_metadata, spec_step_idx) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), @@ -190,7 +192,7 @@ def load_weights(self, weights: Iterable[Tuple[str, num_experts=self.config.n_routed_experts) params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index ce86b9b2c4f0..b78c193c1345 100644 --- 
a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -22,7 +22,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only DeepseekV2/DeepseekV3 model.""" -from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Any, Optional, Union import torch from torch import nn @@ -31,9 +32,7 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig -from vllm.distributed import (get_pp_group, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -143,7 +142,8 @@ def __init__( intermediate_size=intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=False, + reduce_results=self.experts.must_reduce_shared_expert_outputs( + ), prefix=f"{prefix}.shared_experts", ) @@ -154,6 +154,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) + if hidden_states.dtype != torch.float16: final_hidden_states = self.experts( hidden_states=hidden_states, @@ -171,9 +172,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # See DeepseekV2DecoderLayer for more details. final_hidden_states = final_hidden_states + shared_output \ * (1. 
/ self.routed_scaling_factor) + if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = ( + self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states)) return final_hidden_states.view(num_tokens, hidden_dim) @@ -198,7 +201,7 @@ def __init__( q_lora_rank: int, kv_lora_rank: int, rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, + rope_scaling: Optional[dict[str, Any]] = None, max_position_embeddings: int = 8192, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, @@ -350,7 +353,7 @@ def __init__( q_lora_rank: Optional[int], kv_lora_rank: int, rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, + rope_scaling: Optional[dict[str, Any]] = None, max_position_embeddings: int = 8192, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, @@ -453,7 +456,6 @@ def __init__( qk_rope_head_dim=self.qk_rope_head_dim, qk_head_dim=self.qk_head_dim, v_head_dim=self.v_head_dim, - rotary_emb=self.rotary_emb, kv_b_proj=self.kv_b_proj, ) @@ -475,6 +477,13 @@ def forward( [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) + q = q.view(-1, self.num_local_heads, self.qk_head_dim) + # Add head dim of 1 to k_pe + k_pe = k_pe.unsqueeze(1) + + q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb( + positions, q[..., self.qk_nope_head_dim:], k_pe) + attn_out = self.mla_attn( q, kv_c_normed, @@ -728,8 +737,8 @@ def make_empty_intermediate_tensors( device=device), }) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), @@ -745,7 +754,7 @@ def load_weights(self, weights: Iterable[Tuple[str, 
num_experts=self.config.n_routed_experts) params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index 6d8f27530cee..5c8793f59ffb 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -4,7 +4,7 @@ """Inference-only Deepseek-VL2 model compatible with HuggingFace weights.""" import math from collections.abc import Iterable, Mapping, Sequence -from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union +from typing import Literal, Optional, TypedDict, Union import torch import torch.nn as nn @@ -45,7 +45,7 @@ class DeepseekVL2ImagePixelInputs(TypedDict): type: Literal["pixel_values"] - data: Union[torch.Tensor, List[torch.Tensor]] + data: Union[torch.Tensor, list[torch.Tensor]] """ Shape: `(batch_size * num_images, num_channels, height, width)` """ @@ -57,7 +57,7 @@ class DeepseekVL2ImagePixelInputs(TypedDict): class DeepseekVL2VImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] - data: Union[torch.Tensor, List[torch.Tensor]] + data: Union[torch.Tensor, list[torch.Tensor]] """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` `hidden_size` must match the hidden size of language model backbone. 
@@ -210,9 +210,7 @@ def _call_hf_processor( dict(prompt=prompt, **mm_data), mm_kwargs, ) - target_dtype = self.info.ctx.model_config.dtype - pixel_values = processed_outputs.pop("pixel_values").to( - target_dtype) + pixel_values = processed_outputs["pixel_values"] # split pixel values into patches corresponding to each image images_spatial_crop = processed_outputs["images_spatial_crop"] patches_per_image = [ @@ -394,8 +392,8 @@ def _init_vision_module( return model def _validate_pixel_values( - self, data: Union[torch.Tensor, List[torch.Tensor]] - ) -> Union[torch.Tensor, List[torch.Tensor]]: + self, data: Union[torch.Tensor, list[torch.Tensor]] + ) -> Union[torch.Tensor, list[torch.Tensor]]: h = w = self.vision_config.image_size expected_dims = (3, h, w) @@ -415,8 +413,8 @@ def _validate_shape(d: torch.Tensor): return data def _validate_images_spatial_crop( - self, data: Union[torch.Tensor, List[torch.Tensor]] - ) -> Union[torch.Tensor, List[torch.Tensor]]: + self, data: Union[torch.Tensor, list[torch.Tensor]] + ) -> Union[torch.Tensor, list[torch.Tensor]]: expected_dims = 2 def _validate_shape(d: torch.Tensor): @@ -640,8 +638,8 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) autoloaded_weights = loader.load_weights(weights, diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py index 4ff1e785494f..fb1675d29915 100644 --- a/vllm/model_executor/models/eagle.py +++ b/vllm/model_executor/models/eagle.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Iterable, Optional, Tuple +from collections.abc import Iterable +from typing import Optional import torch import torch.nn as nn @@ -145,6 +146,17 @@ def forward( if inputs_embeds is None: inputs_embeds = 
self.get_input_embeddings(input_ids) + # Handle both empty previous_hidden_states + # and mismatched batch size + batch_size = inputs_embeds.size(0) + if previous_hidden_states.size(0) == 0 or \ + previous_hidden_states.size(0) != batch_size: + hidden_dim = self.config.model.hidden_size + device = inputs_embeds.device + # Create zero tensor with matching batch size + previous_hidden_states = \ + torch.zeros(batch_size, hidden_dim, device=device) + if self.add_para_norm: inputs_embeds = torch.cat([ self.enorm(inputs_embeds), @@ -183,7 +195,7 @@ def compute_logits(self, hidden_states: torch.Tensor, return logits - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): # This implementation is incompitable with https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B # due to missing lm_head weights and its config being that of a # Llama model. Here's a compatible version with the same weights: diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 4a6490cd127a..838560692bcf 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -24,7 +24,8 @@ # limitations under the License. 
"""Inference-only Exaone model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Any, Optional, Union import torch from torch import nn @@ -102,7 +103,7 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, + rope_scaling: Optional[dict[str, Any]] = None, max_position_embeddings: int = 8192, quant_config: Optional[QuantizationConfig] = None, bias: bool = False, @@ -126,8 +127,9 @@ def __init__( assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) # MistralConfig has an optional head_dim introduced by Mistral-Nemo - self.head_dim = getattr(config, "head_dim", - self.hidden_size // self.total_num_heads) + self.head_dim = getattr(config, "head_dim", None) + if self.head_dim is None: + self.head_dim = self.hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -196,7 +198,7 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, + rope_scaling: Optional[dict[str, Any]] = None, max_position_embeddings: int = 8192, quant_config: Optional[QuantizationConfig] = None, bias: bool = False, @@ -282,7 +284,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: residual = hidden_states @@ -384,8 +386,8 @@ def forward( hidden_states, _ = self.ln_f(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ 
# (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -395,7 +397,7 @@ def load_weights(self, weights: Iterable[Tuple[str, (".gate_up_proj", ".c_fc_1", 1), ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -535,8 +537,8 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, # With tie_word_embeddings, we can skip lm_head.weight diff --git a/vllm/model_executor/models/fairseq2_llama.py b/vllm/model_executor/models/fairseq2_llama.py index 310aca999bc2..00dbbebb120e 100644 --- a/vllm/model_executor/models/fairseq2_llama.py +++ b/vllm/model_executor/models/fairseq2_llama.py @@ -16,7 +16,7 @@ # limitations under the License. """Llama model for fairseq2 weights.""" -from typing import Iterable, Set, Tuple +from collections.abc import Iterable import torch from torch.nn import Parameter @@ -44,8 +44,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): f"model.{self.tp_rank}.pt", ] - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: # fairseq2's serialization adds a wrapper to usual .pt state_dict's: # { "model_key": my_model_name, "my_model_name": state_dict } # which we first need to unpack @@ -102,7 +102,7 @@ def reshape_fairseq2_weights( name: str, loaded_weight: torch.Tensor, params: dict[str, Parameter], - ) -> Tuple[str, torch.Tensor]: + ) -> tuple[str, torch.Tensor]: """Reshape fairseq2's weights.""" def permute(w: torch.Tensor, n_heads: int) -> torch.Tensor: diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 
e7e03fc09972..376793594f8b 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -20,7 +20,8 @@ """PyTorch Falcon model.""" import math -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch from torch import nn @@ -394,8 +395,8 @@ def forward( hidden_states = self.ln_f(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: total_num_heads = self.config.num_attention_heads if self.config.new_decoder_architecture: total_num_kv_heads = self.config.num_kv_heads @@ -405,7 +406,7 @@ def load_weights(self, weights: Iterable[Tuple[str, total_num_kv_heads = total_num_heads num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: @@ -498,8 +499,8 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, skip_prefixes=(["lm_head."] diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py new file mode 100644 index 000000000000..1c0e3911fcce --- /dev/null +++ b/vllm/model_executor/models/falcon_h1.py @@ -0,0 +1,684 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Inference-only FalconH1 model.""" +from collections.abc import Iterable +from typing import Optional + +import torch +from torch import nn +from transformers import FalconH1Config + +from vllm.attention.layer import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import get_pp_group +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba2_metadata import ( + Mamba2Metadata, prepare_mamba2_metadata) +from vllm.model_executor.layers.mamba.mamba_mixer2 import ( + MambaMixer2, extra_groups_for_head_shards) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from 
vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, + SupportsV0Only) +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class FalconH1MLP(nn.Module): + + def __init__( + self, + config: FalconH1Config, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=config.hidden_size, + output_sizes=[config.intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + ) + self.down_proj = RowParallelLinear( + input_size=config.intermediate_size, + output_size=config.hidden_size, + bias=bias, + quant_config=quant_config, + ) + self.tp_size = get_tensor_model_parallel_world_size() + self.intermediate_size = config.intermediate_size + self.gate_multiplier, self.down_multiplier = config.mlp_multipliers + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + x, _ = self.gate_up_proj(x) + x[:, :self.intermediate_size // self.tp_size] *= self.gate_multiplier + x = self.act_fn(x) + x, _ = self.down_proj(x) + x = x * self.down_multiplier + return x + + +class FalconH1SSMDecoderLayer(nn.Module): + + def __init__( + self, + config: FalconH1Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.tp_size = get_tensor_model_parallel_world_size() + + self.d_ssm = (int(config.mamba_expand * config.hidden_size) + if config.mamba_d_ssm is None else config.mamba_d_ssm) + + self.mamba = MambaMixer2( + hidden_size=config.hidden_size, + ssm_state_size=config.mamba_d_state, + conv_kernel_size=config.mamba_d_conv, + intermediate_size=self.d_ssm, + use_conv_bias=config.mamba_conv_bias, + use_bias=config.mamba_proj_bias, + n_groups=config.mamba_n_groups, + num_heads=config.mamba_n_heads, + head_dim=config.mamba_d_head, + rms_norm_eps=config.rms_norm_eps, + activation=config.hidden_act, + quant_config=quant_config, + use_rms_norm=config.mamba_rms_norm, + ) + # n_groups is overridden later by `MambaMixer2` + self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state + self.zxbcdt_multipliers = config.ssm_multipliers + self._init_mup_vector() + + def _init_mup_vector(self): + """ + Non-learnable per-block scaling vector composed of element-wise + multipliers applied to each separate contiguous block of the output + of the linear projection (in_proj) before further processing + (gating, convolution, SSM): + + - Z block: [0 : d_ssm] → zxbcdt_multipliers[0] + - X block: [d_ssm : 2 * d_ssm] → zxbcdt_multipliers[1] + - B block: [2 * d_ssm : 2 * d_ssm + G * S] → zxbcdt_multipliers[2] + - C block: [2 * d_ssm + G * S : 2 * d_ssm + 2 * G * S] + → zxbcdt_multipliers[3] + - dt block: [2 * d_ssm + 2 * G * S : end] → 
zxbcdt_multipliers[4] + + where: + - d_ssm: Dimension of state-space model latent + - G: Number of groups (n_groups) + - S: SSM state size per group + - All indices are divided by tp_size to support tensor parallelism + """ + vector_shape = (2 * self.d_ssm + 2 * self.groups_time_state_size + + self.config.mamba_n_heads) // self.tp_size + mup_vector = torch.ones(1, vector_shape) + # Z vector 0 -> d_ssm + mup_vector[:, :self.d_ssm // + self.tp_size] *= self.zxbcdt_multipliers[0] + # X vector d_ssm -> 2 * d_ssm + mup_vector[:, + (self.d_ssm // + self.tp_size):(2 * self.d_ssm // + self.tp_size)] *= self.zxbcdt_multipliers[1] + # B vector 2 * d_ssm -> 2 * d_ssm + (n_group * d_state) + mup_vector[ + :, + (2 * self.d_ssm) // + self.tp_size:(2 * self.d_ssm + self.groups_time_state_size) // + self.tp_size, + ] *= self.zxbcdt_multipliers[2] + # C vector 2 * d_ssm + (n_group * d_state) + # -> 2 * d_ssm + 2 * (n_group * d_state) + mup_vector[ + :, + (2 * self.d_ssm + self.groups_time_state_size) // + self.tp_size:(2 * self.d_ssm + 2 * self.groups_time_state_size) // + self.tp_size, + ] *= self.zxbcdt_multipliers[3] + # dt vector 2 * d_ssm + 2 * (n_group * d_state) + # -> 2 * d_ssm + 2 * (n_group * d_state) + n_heads + mup_vector[ + :, + (2 * self.d_ssm + 2 * self.groups_time_state_size) // + self.tp_size:, + ] *= self.zxbcdt_multipliers[4] + + self.register_buffer("mup_vector", mup_vector, persistent=False) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + mamba_cache_params: MambaCacheParams, + mamba2_metadata: Mamba2Metadata, + **kwargs, + ): + hidden_states = self.mamba( + hidden_states, + mamba_cache_params, + mamba2_metadata=mamba2_metadata, + mup_vector=self.mup_vector, + ) + return hidden_states, residual + + +class FalconH1AttentionDecoderLayer(nn.Module): + + def __init__( + self, + config: FalconH1Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + 
) -> None: + super().__init__() + rope_theta = getattr(config, "rope_theta", 1e11) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = (config.hidden_size // self.total_num_heads if getattr( + config, "head_dim", None) is None else config.head_dim) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + if hasattr(config, "partial_rotary_factor"): + rotary_dim = self.head_dim * config.partial_rotary_factor + elif hasattr(config, "attn_rotary_emb"): + rotary_dim = config.attn_rotary_emb # for backward compatibility + else: + rotary_dim = self.head_dim # default + + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=rotary_dim, + max_position=max_position_embeddings, + rope_scaling=rope_scaling, + base=rope_theta, + is_neox_style=True, + dtype=None, # see impl of get_rope + ) + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + 
prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn", + ) + self.key_multiplier = config.key_multiplier + + def self_attention( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + k = k * self.key_multiplier + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + **kwargs, + ): + hidden_states = self.self_attention( + positions=positions, + hidden_states=hidden_states, + ) + return hidden_states, residual + + +class FalconH1ParallelHybrid(nn.Module): + """ + A hybrid decoder layer for FalconH1 where the input is processed + in parallel through both the self-attention branch and the SSM (Mamba) + branch. Their outputs are then summed to produce the final hidden state. + + This layer uses: + - FalconH1AttentionDecoderLayer for the multi-head self-attention branch. + - FalconH1SSMDecoderLayer for the state-space (Mamba) branch. 
+ """ + + def __init__( + self, + config: FalconH1Config, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + # Instantiate the attention branch + self.self_attn = FalconH1AttentionDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ) + # Instantiate the SSM branch + self.mamba = FalconH1SSMDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + ) + self.ssm_out_multiplier = config.ssm_out_multiplier + self.ssm_in_multiplier = config.ssm_in_multiplier + + self.attention_in_multiplier = config.attention_in_multiplier + self.attn_out_multiplier = config.attention_out_multiplier + + self.feed_forward = FalconH1MLP(config) + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + mamba_cache_params: MambaCacheParams, + mamba2_metadata: Mamba2Metadata, + **kwargs, + ): + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Process input through the attention branch. + # FalconH1AttentionDecoderLayer expects positions, hidden_states, + # kv_cache, attn_metadata, and residual. + attn_hidden, _ = self.self_attn( + positions=positions, + hidden_states=hidden_states * self.attention_in_multiplier, + residual=residual, + **kwargs, + ) + + # Process input through the SSM branch. + # FalconH1SSMDecoderLayer expects hidden_states, attn_metadata, + # residual, mamba_cache_params, and sequence_idx. + ssm_hidden, _ = self.mamba( + hidden_states=hidden_states * self.ssm_in_multiplier, + residual=residual, + mamba_cache_params=mamba_cache_params, + mamba2_metadata=mamba2_metadata, + **kwargs, + ) + # Sum the outputs from both branches. 
+ # We assume both branches produce outputs of the same + # dimensionality (config.hidden_size). + hidden_states = (attn_hidden * self.attn_out_multiplier) + ( + ssm_hidden * self.ssm_out_multiplier) + hidden_states = hidden_states + residual + + # feed-forward + residual = hidden_states + hidden_states = self.pre_ff_layernorm(hidden_states) + hidden_states = self.feed_forward(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class FalconH1Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config: FalconH1Config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + if get_pp_group().is_first_rank: + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.embedding_multiplier = config.embedding_multiplier + else: + self.embed_tokens = PPMissingLayer() + self.embedding_multiplier = 1.0 + + def get_layer(prefix: str): + layer_idx = int(prefix.rsplit(".", 1)[1]) + layer_class = FalconH1ParallelHybrid + return layer_class( + config, + layer_idx, + cache_config, + quant_config=quant_config, + prefix=prefix, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + if get_pp_group().is_last_rank: + self.final_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + else: + self.final_layernorm = PPMissingLayer() + + def 
get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + mamba_cache_params: MambaCacheParams, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + # pass a sequence index tensor, that is required for + # proper continuous batching computation including + # chunked prefill + attn_metadata = get_forward_context().attn_metadata + mamba2_metadata = prepare_mamba2_metadata( + chunk_size=self.config.mamba_chunk_size, + attn_metadata=attn_metadata, + ) + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds * self.embedding_multiplier + else: + hidden_states = (self.get_input_embeddings(input_ids) * + self.embedding_multiplier) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i) + hidden_states = layer( + positions=positions, + hidden_states=hidden_states, + mamba_cache_params=layer_mamba_cache_params, + mamba2_metadata=mamba2_metadata, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + }) + hidden_states = self.final_layernorm(hidden_states) + return hidden_states + + +class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, + IsHybrid, SupportsV0Only): + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + } + + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + self.vllm_config 
= vllm_config + self.model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + assert (not cache_config.enable_prefix_caching + ), "FalconH1 currently does not support prefix caching" + + self.quant_config = vllm_config.quant_config + + super().__init__() + self.config = config + self.scheduler_config = scheduler_config + self.model = FalconH1Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.tie_word_embeddings = config.tie_word_embeddings + self.unpadded_vocab_size = config.vocab_size + self.mamba_cache: Optional[MambaCacheManager] = None + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else + lora_config.lora_vocab_padding_size), + ) + self.lm_head_multiplier = config.lm_head_multiplier + if self.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights( + self.model.embed_tokens) + # Used to track and store by the Mamba cache between steps. 
+ + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, + config.vocab_size, + scale=config.lm_head_multiplier, + ) + else: + self.lm_head = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + if self.mamba_cache is None: + self.mamba_cache = MambaCacheManager( + self.vllm_config, + self.lm_head.weight.dtype + if hasattr(self.lm_head, 'weight') else torch.bfloat16, + self.config.num_hidden_layers, + *self._get_mamba_cache_shape(), + ) + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + hidden_states = self.model( + input_ids, + positions, + mamba_cache_params, + intermediate_tensors, + inputs_embeds, + ) + + return hidden_states + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def _get_mamba_cache_shape( + self) -> tuple[tuple[int, int], tuple[int, int]]: + world_size = get_tensor_model_parallel_world_size() + hidden_size = self.config.hidden_size + + conv_state_shape, temporal_state_shape = None, None + + intermediate_size = (int(self.config.mamba_expand * + hidden_size) if self.config.mamba_d_ssm + is None else self.config.mamba_d_ssm) + + # if n_groups is not divisible by world_size, need to extend the shards + # to ensure all groups needed by a head is sharded along with it + n_groups = self.config.mamba_n_groups + extra_groups_for_head_shards( + self.config.mamba_n_groups, world_size) + + # - 
heads and n_groups are TP-ed + conv_dim = intermediate_size + 2 * n_groups * self.config.mamba_d_state + conv_state_shape = ( + divide(conv_dim, world_size), + self.config.mamba_d_conv - 1, + ) + + # These are not TP-ed as they depend on A, dt_bias, D + # - they are typically small + # e.g., (h_heads, d_head, d_state) = (128, 64, 128) + temporal_state_shape = ( + divide(self.config.mamba_n_heads, world_size), + self.config.mamba_d_head, + self.config.mamba_d_state, + ) + return conv_state_shape, temporal_state_shape + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if "A_log" in name: + name = name.replace("A_log", "A") + + if "mamba" in name: + name = name.replace("mamba", "mamba.mamba") + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + if self.tie_word_embeddings and "lm_head" in name: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + if self.tie_word_embeddings: + loaded_params.add("lm_head.weight") + return loaded_params diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py index d1a36c3f481a..f8acc56706d2 100644 --- a/vllm/model_executor/models/florence2.py +++ b/vllm/model_executor/models/florence2.py @@ -3,7 +3,7 @@ import math from collections import OrderedDict from collections.abc import Iterable, Mapping, Sequence -from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union +from typing import Literal, Optional, TypedDict, Union import torch import torch.nn as nn @@ -713,8 +713,8 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -723,7 +723,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: @@ -922,8 +922,8 @@ def _build_image_projection_layers(self, config: PretrainedConfig): 'Florence2 only supports COSINE as temporal embedding.') def _validate_pixel_values( - self, data: Union[torch.Tensor, List[torch.Tensor]] - ) -> Union[torch.Tensor, List[torch.Tensor]]: + self, data: Union[torch.Tensor, list[torch.Tensor]] + ) -> Union[torch.Tensor, list[torch.Tensor]]: size = 
self.processor_config["size"] h, w = size["height"], size["width"] @@ -944,12 +944,12 @@ def _validate_shape(d: torch.Tensor): return data def _parse_and_validate_image_input(self, **kwargs: object): - pixel_values: Optional[Union[List[List[torch.Tensor]], - List[torch.Tensor], + pixel_values: Optional[Union[list[list[torch.Tensor]], + list[torch.Tensor], torch.Tensor]] = kwargs.pop( "pixel_values", None) - image_embeds: Optional[Union[List[List[torch.Tensor]], - List[torch.Tensor], + image_embeds: Optional[Union[list[list[torch.Tensor]], + list[torch.Tensor], torch.Tensor]] = kwargs.pop( "image_embeds", None) @@ -1096,7 +1096,7 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index d6bd6155a447..fbad7f56d0ba 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -18,7 +18,7 @@ """ PyTorch Fuyu model.""" import math from collections.abc import Iterable, Mapping, Sequence -from typing import Literal, Optional, Set, Tuple, TypedDict +from typing import Literal, Optional, TypedDict import torch import torch.nn as nn @@ -382,7 +382,7 @@ def compute_logits( self.language_model.lm_head, hidden_states, sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index c1cc0df11178..0f6d94e7518b 100644 --- a/vllm/model_executor/models/gemma.py +++ 
b/vllm/model_executor/models/gemma.py @@ -15,8 +15,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Gemma model compatible with HuggingFace weights.""" +from collections.abc import Iterable from functools import cache -from typing import Iterable, Optional, Set, Tuple, Union +from typing import Optional, Union import torch from torch import nn @@ -231,7 +232,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: residual = hidden_states @@ -318,8 +319,8 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -329,7 +330,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: for (param_name, shard_name, shard_id) in stacked_params_mapping: if shard_name not in name: @@ -413,8 +414,8 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, skip_prefixes=(["lm_head."] diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 7fb2e9948c06..b46716213c62 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -15,7 +15,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 
either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch from torch import nn @@ -218,7 +219,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -305,8 +306,8 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -316,7 +317,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if (self.quant_config is not None and (scale_name := self.quant_config.get_cache_scale(name))): @@ -413,8 +414,8 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, skip_prefixes=(["lm_head."] diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py index 4e0d4f84ca6b..3a88adcce0bd 100644 --- a/vllm/model_executor/models/gemma3.py +++ b/vllm/model_executor/models/gemma3.py @@ -14,7 +14,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch import torch.nn.functional as F @@ -320,7 +321,7 @@ def forward( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], **kwargs, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -412,8 +413,8 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -423,7 +424,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if (self.quant_config is not None and (scale_name := self.quant_config.get_cache_scale(name))): @@ -521,8 +522,8 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, skip_prefixes=(["lm_head."] diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 65c177f8c5ad..182cc86d3ca8 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import math from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Literal, Optional, Set, Tuple, 
TypedDict +from typing import Any, Literal, Optional, TypedDict import torch from torch import nn @@ -504,18 +504,12 @@ def dtype(self): return next(self.parameters()).dtype def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - - def _validate_shape(d: torch.Tensor): - if d.shape != expected_dims: - raise ValueError( - "The expected shape of pixel values per image per batch " - f"is {expected_dims}. You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - + image_size = self.config.vision_config.image_size + expected_dims = (3, image_size, image_size) + if data.shape[1:] != expected_dims: + raise ValueError( + "The expected shape of pixel values per image per batch is " + f"{expected_dims}. You supplied {tuple(data.shape)}.") return data def _parse_and_validate_image_input( @@ -549,9 +543,7 @@ def _image_pixels_to_features( vision_tower: SiglipVisionModel, pixel_values: torch.Tensor, ) -> torch.Tensor: - target_dtype = vision_tower.get_input_embeddings().weight.dtype - image_features = vision_tower(pixel_values.to(dtype=target_dtype)) - return image_features + return vision_tower(pixel_values) def _process_image_input( self, @@ -701,8 +693,8 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/glm4.py b/vllm/model_executor/models/glm4.py index 290be968cb54..f351ce5a0681 100644 --- a/vllm/model_executor/models/glm4.py +++ b/vllm/model_executor/models/glm4.py @@ -21,7 +21,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only GLM-4-0414 model compatible with HuggingFace weights.""" -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch from torch import nn @@ -60,7 +61,7 @@ def __init__(self, rope_theta: float = 10000, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - rope_scaling: Optional[Tuple] = None, + rope_scaling: Optional[tuple] = None, prefix: str = "", attn_type: str = AttentionType.DECODER) -> None: super().__init__() @@ -183,7 +184,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: residual = hidden_states @@ -293,8 +294,8 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, skip_prefixes=(["lm_head."] diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index e3219333915e..c2c310fca4d9 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -18,7 +18,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only GPT-2 model compatible with HuggingFace weights.""" -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch from torch import nn @@ -42,7 +43,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -234,6 +235,35 @@ def forward( hidden_states = self.ln_f(hidden_states) return hidden_states + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if ".attn.bias" in name or ".attn.masked_bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + # The HF's GPT-2 implementation uses Conv1D instead of Linear. + # Because of this, we need to transpose the weights. + # Note(zhuohan): the logic below might break quantized models. 
+ for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: + if conv1d_weight_name not in name: + continue + if not name.endswith(".weight"): + continue + loaded_weight = loaded_weight.t() + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class GPT2LMHeadModel(nn.Module, SupportsPP): @@ -280,34 +310,18 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if ".attn.bias" in name or ".attn.masked_bias" in name: - # Skip attention mask. - # NOTE: "c_attn.bias" should not be skipped. - continue - if not name.startswith("transformer.") and not name.startswith( - "lm_head"): - name = "transformer." + name - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - # The HF's GPT-2 implementation uses Conv1D instead of Linear. - # Because of this, we need to transpose the weights. - # Note(zhuohan): the logic below might break quantized models. 
- for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: - if conv1d_weight_name not in name: - continue - if not name.endswith(".weight"): - continue - loaded_weight = loaded_weight.t() - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + weights = _add_transformer_prefix(weights) + return loader.load_weights(weights) + + +def _add_transformer_prefix( + weights: Iterable[tuple[str, torch.Tensor]] +) -> Iterable[tuple[str, torch.Tensor]]: + for name, tensor in weights: + if not name.startswith('transformer.') and not name.startswith( + "lm_head"): + name = 'transformer.' + name + yield name, tensor diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index def6b1544d8c..c4ae4fc3c006 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -19,7 +19,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPTBigCode model compatible with HuggingFace weights.""" -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch from torch import nn @@ -243,10 +244,10 @@ def forward( hidden_states = self.ln_f(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if ".attn.bias" in name: # Skip attention mask. 
@@ -271,12 +272,6 @@ def load_weights(self, weights: Iterable[Tuple[str, class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = {"c_attn": ["c_attn"]} - # LoRA specific attributes - embedding_modules = { - "wte": "input_embeddings", - "lm_head": "output_embeddings", - } - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -327,10 +322,13 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + skip_prefixes = None + if self.config.tie_word_embeddings: + skip_prefixes = ["lm_head."] loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."]), + skip_prefixes=skip_prefixes, ) - return loader.load_weights(weights) \ No newline at end of file + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 3db96fb8e187..69fdd90cfbe8 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -17,7 +17,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only GPT-J model compatible with HuggingFace weights.""" -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch from torch import nn @@ -228,8 +229,8 @@ def forward( hidden_states = self.ln_f(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -239,7 +240,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if "attn.bias" in name or "attn.masked_bias" in name: continue @@ -331,7 +332,7 @@ def compute_logits( sampling_metadata, self.lm_head.bias) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) \ No newline at end of file diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 620ee66f57e7..401fa9f5cc8b 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -17,7 +17,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only GPT-NeoX model compatible with HuggingFace weights.""" -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch from torch import nn @@ -240,10 +241,10 @@ def forward( hidden_states = self.final_layer_norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if ("attention.bias" in name or "attention.masked_bias" in name or "rotary_emb.inv_freq" in name): @@ -324,7 +325,7 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index 0696a7245c22..3524d036db22 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -22,7 +22,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only IBM Granite model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Any, Optional, Union import torch from torch import nn @@ -97,7 +98,7 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, + rope_scaling: Optional[dict[str, Any]] = None, max_position_embeddings: int = 8192, quant_config: Optional[QuantizationConfig] = None, bias: bool = False, @@ -121,8 +122,9 @@ def __init__( assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) # MistralConfig has an optional head_dim introduced by Mistral-Nemo - self.head_dim = getattr(config, "head_dim", - self.hidden_size // self.total_num_heads) + self.head_dim = getattr(config, "head_dim", None) + if self.head_dim is None: + self.head_dim = self.hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = config.attention_multiplier @@ -230,7 +232,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -321,8 +323,8 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -332,7 +334,7 @@ def load_weights(self, weights: Iterable[Tuple[str, (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, 
loaded_weight in weights: if (self.quant_config is not None and (scale_name := self.quant_config.get_cache_scale(name))): @@ -475,20 +477,16 @@ def make_empty_intermediate_tensors( device=device), }) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: - skip_prefixes = [ - "rotary_emb.inv_freq", - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - "rotary_emb.cos_cached", - "rotary_emb.sin_cached", - ] + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: # With tie_word_embeddings, we can skip lm_head.weight # The weight might appear unnecessarily in the files if the model is # processed with quantization, LoRA, fine-tuning, etc. - if self.config.tie_word_embeddings: - skip_prefixes.append("lm_head.weight") + skip_prefixes = (["lm_head."] + if self.config.tie_word_embeddings else None) - loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) + loader = AutoWeightsLoader( + self, + skip_prefixes=skip_prefixes, + ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/granite_speech.py b/vllm/model_executor/models/granite_speech.py index b43b59da6d11..fd8fb48c50e3 100644 --- a/vllm/model_executor/models/granite_speech.py +++ b/vllm/model_executor/models/granite_speech.py @@ -21,9 +21,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Inference-only IBM Granite speeech model.""" +"""Inference-only IBM Granite speech model.""" import math -from typing import Iterable, Mapping, Optional, Set, Tuple, TypedDict, Union +from collections.abc import Iterable, Mapping +from typing import Optional, TypedDict, Union import torch import torch.nn.functional as F @@ -625,7 +626,7 @@ def _build_input_features_mask( audio_embed_sizes: torch.Tensor, ) -> torch.Tensor: """Calculate the input features mask, which will generally be used - to mask the the padded features for all entries in the batch except + to mask the padded features for all entries in the batch except for those with the most audio features. Args: @@ -763,8 +764,8 @@ def compute_logits( def load_weights( self, - weights: Iterable[Tuple[str, torch.Tensor]], - ) -> Set[str]: + weights: Iterable[tuple[str, torch.Tensor]], + ) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 7fff14cb9f12..f342dfff824f 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -22,7 +22,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only GraniteMoe model.""" -from typing import Iterable, Optional, Set, Tuple +from collections.abc import Iterable +from typing import Optional import torch from torch import nn @@ -305,8 +306,8 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: new_weights = {} for n, p in weights: if n.endswith('.block_sparse_moe.input_linear.weight'): @@ -425,8 +426,8 @@ def make_empty_intermediate_tensors( device=device), }) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, skip_prefixes=(["lm_head."] diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index 706e648f1b4f..443b102c9968 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 """Inference-only GraniteMoeHybrid model.""" # Added by the IBM Team, 2025 -from typing import Iterable, Optional, Set, Tuple +from collections.abc import Iterable +from typing import Optional import torch from torch import nn @@ -381,10 +382,10 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() def _load(n, p): param = params_dict[n] @@ -538,7 +539,7 @@ def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) def _get_mamba_cache_shape( - 
self) -> Tuple[Tuple[int, int], Tuple[int, int]]: + self) -> tuple[tuple[int, int], tuple[int, int]]: world_size = get_tensor_model_parallel_world_size() hidden_size = self.config.hidden_size @@ -578,7 +579,7 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/granitemoeshared.py b/vllm/model_executor/models/granitemoeshared.py index 4e660cbf667b..817e6091d276 100644 --- a/vllm/model_executor/models/granitemoeshared.py +++ b/vllm/model_executor/models/granitemoeshared.py @@ -4,7 +4,8 @@ The architecture is the same as granitemoe but with the addition of shared experts. """ -from typing import Iterable, Optional, Set, Tuple +from collections.abc import Iterable +from typing import Optional import torch from torch import nn @@ -208,8 +209,8 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: new_weights = {} for n, p in weights: if n.endswith('.block_sparse_moe.input_linear.weight'): @@ -329,8 +330,8 @@ def make_empty_intermediate_tensors( device=device), }) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, skip_prefixes=(["lm_head."] diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index e4692c458088..6a444e8d1068 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -1,22 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 from array import array -from typing import 
Optional, Union +from typing import Optional import torch import torch.nn as nn -from xformers.ops.fmha.attn_bias import BlockDiagonalMask -from vllm.attention.backends.xformers import XFormersImpl from vllm.config import ModelConfig, VllmConfig -from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.pooler import PoolerHead from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.pooling_metadata import (PoolingMetadata, PoolingTensors) -from vllm.sequence import (IntermediateTensors, PoolerOutput, - PoolingSequenceGroupOutput) +from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from .interfaces import SupportsV0Only @@ -204,38 +200,20 @@ def __init__( prefix: str = "", **kwargs, ) -> None: - super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) + # Use full attention for pooling + if vllm_config.model_config.runner_type == "pooling": + hf_config = vllm_config.model_config.hf_config + hf_config.is_causal = False - self.runner_type = vllm_config.model_config.runner_type + vllm_config.cache_config.sliding_window = None - self._pooler = GritLMPooler(vllm_config.model_config) + for attr in ("sliding_window", "interleaved_sliding_window"): + if hasattr(hf_config, attr): + delattr(hf_config, attr) - for layer in self.model.layers: - if self.runner_type == "pooling" and hasattr(layer, "self_attn"): - assert isinstance(layer.self_attn.attn.impl, XFormersImpl), ( - "GritLM embedding is only supported by XFormers backend, " - "which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS") + super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - **kwargs, - ) -> Union[torch.Tensor, IntermediateTensors]: - - # Change attention to non-causal for pooling tasks. 
- if self.runner_type == "pooling": - attn_metadata = get_forward_context().attn_metadata - assert attn_metadata.prefill_metadata.attn_bias is None - attn_metadata.prefill_metadata.attn_bias = [ - BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens) - ] - - return super().forward( - input_ids=input_ids, - positions=positions, - **kwargs, - ) + self._pooler = GritLMPooler(vllm_config.model_config) def pooler( self, diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py index 6f56eb2d5e38..bc9e9a3c0206 100644 --- a/vllm/model_executor/models/grok1.py +++ b/vllm/model_executor/models/grok1.py @@ -21,13 +21,14 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Grok1 model.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch import torch.nn.functional as F from torch import nn -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -181,25 +182,20 @@ def __init__( quant_config=quant_config, logits_soft_cap=attn_logits_soft_cap, prefix=f"{prefix}.attn") + self.attn_multiplier = getattr(self.config, "attn_output_multiplier", + 1.0) if self.config else 1.0 def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) - - # Apply attention output multiplier if specified in 
config - attn_multiplier = getattr(self.config, "attn_output_multiplier", - None) if self.config else None - if attn_multiplier is not None: - output = output * attn_multiplier + output *= self.attn_multiplier return output @@ -260,10 +256,8 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: residual = hidden_states @@ -275,8 +269,6 @@ def forward( hidden_states = self.attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Post attention normalization @@ -340,8 +332,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -358,9 +348,7 @@ def forward( for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - hidden_states, residual = layer(positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual) + hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -371,8 +359,8 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -390,7 +378,7 @@ def load_weights(self, weights: Iterable[Tuple[str, num_experts=num_experts) params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = 
set() for name, loaded_weight in weights: if (self.quant_config is not None and @@ -528,13 +516,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states @@ -547,12 +532,14 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: - skip_prefixes = ["rotary_emb.inv_freq"] + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: # Skip lm_head when tie_word_embeddings is True - if self.config.tie_word_embeddings: - skip_prefixes.append("lm_head") + skip_prefixes = (["lm_head"] + if self.config.tie_word_embeddings else None) - loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) + loader = AutoWeightsLoader( + self, + skip_prefixes=skip_prefixes, + ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py index 99c226439ecb..904f5330c653 100644 --- a/vllm/model_executor/models/h2ovl.py +++ b/vllm/model_executor/models/h2ovl.py @@ -25,9 +25,10 @@ from .intern_vit import InternVisionModel from .internvl import (IMG_CONTEXT, IMG_END, IMG_START, + BaseInternVLDummyInputsBuilder, + BaseInternVLMultiModalProcessor, BaseInternVLProcessingInfo, BaseInternVLProcessor, - InternVLChatModel, InternVLDummyInputsBuilder, - InternVLMultiModalProcessor, build_transform, + InternVLChatModel, build_transform, find_closest_aspect_ratio, get_internvl_target_ratios) @@ -430,8 +431,8 @@ def get_num_image_tokens( ) -class 
H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo] - ): +class H2OVLMultiModalProcessor( + BaseInternVLMultiModalProcessor[H2OVLProcessingInfo]): def _get_prompt_updates( self, @@ -514,7 +515,7 @@ def _cached_apply_hf_processor( @MULTIMODAL_REGISTRY.register_processor( H2OVLMultiModalProcessor, info=H2OVLProcessingInfo, - dummy_inputs=InternVLDummyInputsBuilder) + dummy_inputs=BaseInternVLDummyInputsBuilder) class H2OVLChatModel(InternVLChatModel): def _init_vision_model( diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py index cb0379c10f3a..b8bdc7aa32b2 100644 --- a/vllm/model_executor/models/idefics2_vision_model.py +++ b/vllm/model_executor/models/idefics2_vision_model.py @@ -17,7 +17,8 @@ # limitations under the License. """PyTorch Idefics2 model.""" -from typing import Iterable, Optional, Set, Tuple +from collections.abc import Iterable +from typing import Optional import torch from torch import nn @@ -342,8 +343,8 @@ def forward( last_hidden_state = self.post_layernorm(encoder_outputs) return last_hidden_state - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -351,7 +352,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ("qkv_proj", "v_proj", "v"), ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() layer_count = len(self.encoder.layers) for name, loaded_weight in weights: diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 961954c2b584..fdb128ef5b54 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -17,7 +17,7 @@ import math from collections.abc import Iterable, Mapping, Sequence -from typing 
import Dict, Literal, Optional, Set, Tuple, TypedDict, Union +from typing import Literal, Optional, TypedDict, Union import torch from torch import nn @@ -85,7 +85,7 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): def get_hf_processor( self, *, - size: Optional[Dict[str, int]] = None, + size: Optional[dict[str, int]] = None, **kwargs: object, ) -> Idefics3Processor: if size is not None: @@ -752,8 +752,8 @@ def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 7fea9647ead9..8be8841c1f6c 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional, - Protocol, Type, Union, overload, runtime_checkable) +from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol, + Union, overload, runtime_checkable) import torch from torch import Tensor @@ -102,7 +102,7 @@ class _SupportsMultiModalType(Protocol): @overload def supports_multimodal( - model: Type[object]) -> TypeIs[Type[SupportsMultiModal]]: + model: type[object]) -> TypeIs[type[SupportsMultiModal]]: ... 
@@ -112,8 +112,8 @@ def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]: def supports_multimodal( - model: Union[Type[object], object], -) -> Union[TypeIs[Type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]: + model: Union[type[object], object], +) -> Union[TypeIs[type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]: if isinstance(model, type): return isinstance(model, _SupportsMultiModalType) @@ -134,9 +134,9 @@ class SupportsLoRA(Protocol): """ # The `embedding_module` and `embedding_padding_modules` # are empty by default. - embedding_modules: ClassVar[Dict[str, str]] = {} - embedding_padding_modules: ClassVar[List[str]] = [] - packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {} + embedding_modules: ClassVar[dict[str, str]] = {} + embedding_padding_modules: ClassVar[list[str]] = [] + packed_modules_mapping: ClassVar[dict[str, list[str]]] = {} # We can't use runtime_checkable with ClassVar for issubclass checks @@ -145,13 +145,13 @@ class SupportsLoRA(Protocol): class _SupportsLoRAType(Protocol): supports_lora: Literal[True] - packed_modules_mapping: Dict[str, List[str]] - embedding_modules: Dict[str, str] - embedding_padding_modules: List[str] + packed_modules_mapping: dict[str, list[str]] + embedding_modules: dict[str, str] + embedding_padding_modules: list[str] @overload -def supports_lora(model: Type[object]) -> TypeIs[Type[SupportsLoRA]]: +def supports_lora(model: type[object]) -> TypeIs[type[SupportsLoRA]]: ... 
@@ -161,8 +161,8 @@ def supports_lora(model: object) -> TypeIs[SupportsLoRA]: def supports_lora( - model: Union[Type[object], object], -) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]: + model: Union[type[object], object], +) -> Union[TypeIs[type[SupportsLoRA]], TypeIs[SupportsLoRA]]: result = _supports_lora(model) if not result: @@ -191,7 +191,7 @@ def supports_lora( return result -def _supports_lora(model: Union[Type[object], object]) -> bool: +def _supports_lora(model: Union[type[object], object]) -> bool: if isinstance(model, type): return isinstance(model, _SupportsLoRAType) @@ -226,9 +226,11 @@ def forward( intermediate_tensors: Optional["IntermediateTensors"], ) -> Union[Tensor, "IntermediateTensors"]: """ - Accept {class}`IntermediateTensors` when PP rank > 0. + Accept [`IntermediateTensors`][vllm.sequence.IntermediateTensors] when + PP rank > 0. - Return {class}`IntermediateTensors` only for the last PP rank. + Return [`IntermediateTensors`][vllm.sequence.IntermediateTensors] only + for the last PP rank. """ ... @@ -256,7 +258,7 @@ def forward( @overload -def supports_pp(model: Type[object]) -> TypeIs[Type[SupportsPP]]: +def supports_pp(model: type[object]) -> TypeIs[type[SupportsPP]]: ... 
@@ -266,8 +268,8 @@ def supports_pp(model: object) -> TypeIs[SupportsPP]: def supports_pp( - model: Union[Type[object], object], -) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]: + model: Union[type[object], object], +) -> Union[bool, TypeIs[type[SupportsPP]], TypeIs[SupportsPP]]: supports_attributes = _supports_pp_attributes(model) supports_inspect = _supports_pp_inspect(model) @@ -298,14 +300,14 @@ def supports_pp( return supports_attributes and supports_inspect -def _supports_pp_attributes(model: Union[Type[object], object]) -> bool: +def _supports_pp_attributes(model: Union[type[object], object]) -> bool: if isinstance(model, type): return isinstance(model, _SupportsPPType) return isinstance(model, SupportsPP) -def _supports_pp_inspect(model: Union[Type[object], object]) -> bool: +def _supports_pp_inspect(model: Union[type[object], object]) -> bool: model_forward = getattr(model, "forward", None) if not callable(model_forward): return False @@ -336,13 +338,13 @@ def has_inner_state(model: object) -> TypeIs[HasInnerState]: @overload -def has_inner_state(model: Type[object]) -> TypeIs[Type[HasInnerState]]: +def has_inner_state(model: type[object]) -> TypeIs[type[HasInnerState]]: ... def has_inner_state( - model: Union[Type[object], object] -) -> Union[TypeIs[Type[HasInnerState]], TypeIs[HasInnerState]]: + model: Union[type[object], object] +) -> Union[TypeIs[type[HasInnerState]], TypeIs[HasInnerState]]: if isinstance(model, type): return isinstance(model, _HasInnerStateType) @@ -373,13 +375,13 @@ def is_attention_free(model: object) -> TypeIs[IsAttentionFree]: @overload -def is_attention_free(model: Type[object]) -> TypeIs[Type[IsAttentionFree]]: +def is_attention_free(model: type[object]) -> TypeIs[type[IsAttentionFree]]: ... 
def is_attention_free( - model: Union[Type[object], object] -) -> Union[TypeIs[Type[IsAttentionFree]], TypeIs[IsAttentionFree]]: + model: Union[type[object], object] +) -> Union[TypeIs[type[IsAttentionFree]], TypeIs[IsAttentionFree]]: if isinstance(model, type): return isinstance(model, _IsAttentionFreeType) @@ -410,13 +412,13 @@ def is_hybrid(model: object) -> TypeIs[IsHybrid]: @overload -def is_hybrid(model: Type[object]) -> TypeIs[Type[IsHybrid]]: +def is_hybrid(model: type[object]) -> TypeIs[type[IsHybrid]]: ... def is_hybrid( - model: Union[Type[object], object] -) -> Union[TypeIs[Type[IsHybrid]], TypeIs[IsHybrid]]: + model: Union[type[object], object] +) -> Union[TypeIs[type[IsHybrid]], TypeIs[IsHybrid]]: if isinstance(model, type): return isinstance(model, _IsHybridType) @@ -439,13 +441,13 @@ def has_noops(model: object) -> TypeIs[HasNoOps]: @overload -def has_noops(model: Type[object]) -> TypeIs[Type[HasNoOps]]: +def has_noops(model: type[object]) -> TypeIs[type[HasNoOps]]: ... def has_noops( - model: Union[Type[object], object] -) -> Union[TypeIs[Type[HasNoOps]], TypeIs[HasNoOps]]: + model: Union[type[object], object] +) -> Union[TypeIs[type[HasNoOps]], TypeIs[HasNoOps]]: if isinstance(model, type): return isinstance(model, _HasNoOpsType) @@ -461,7 +463,7 @@ class SupportsCrossEncoding(Protocol): @overload def supports_cross_encoding( - model: Type[object]) -> TypeIs[Type[SupportsCrossEncoding]]: + model: type[object]) -> TypeIs[type[SupportsCrossEncoding]]: ... 
@@ -471,8 +473,8 @@ def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]: def _supports_cross_encoding( - model: Union[Type[object], object], -) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: + model: Union[type[object], object], +) -> Union[TypeIs[type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: if isinstance(model, type): return isinstance(model, SupportsCrossEncoding) @@ -481,15 +483,15 @@ def _supports_cross_encoding( def supports_cross_encoding( - model: Union[Type[object], object], -) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: + model: Union[type[object], object], +) -> Union[TypeIs[type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: return is_pooling_model(model) and _supports_cross_encoding(model) class SupportsQuant: """The interface required for all models that support quantization.""" - packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {} + packed_modules_mapping: ClassVar[dict[str, list[str]]] = {} quant_config: Optional[QuantizationConfig] = None def __new__(cls, *args, **kwargs) -> Self: @@ -525,7 +527,7 @@ class SupportsTranscription(Protocol): @overload def supports_transcription( - model: Type[object]) -> TypeIs[Type[SupportsTranscription]]: + model: type[object]) -> TypeIs[type[SupportsTranscription]]: ... 
@@ -535,8 +537,8 @@ def supports_transcription(model: object) -> TypeIs[SupportsTranscription]: def supports_transcription( - model: Union[Type[object], object], -) -> Union[TypeIs[Type[SupportsTranscription]], TypeIs[SupportsTranscription]]: + model: Union[type[object], object], +) -> Union[TypeIs[type[SupportsTranscription]], TypeIs[SupportsTranscription]]: if isinstance(model, type): return isinstance(model, SupportsTranscription) @@ -551,7 +553,7 @@ class SupportsV0Only(Protocol): @overload -def supports_v0_only(model: Type[object]) -> TypeIs[Type[SupportsV0Only]]: +def supports_v0_only(model: type[object]) -> TypeIs[type[SupportsV0Only]]: ... @@ -561,8 +563,8 @@ def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]: def supports_v0_only( - model: Union[Type[object], object], -) -> Union[TypeIs[Type[SupportsV0Only]], TypeIs[SupportsV0Only]]: + model: Union[type[object], object], +) -> Union[TypeIs[type[SupportsV0Only]], TypeIs[SupportsV0Only]]: if isinstance(model, type): return isinstance(model, SupportsV0Only) diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index f141dcf3cd4f..d325a6b67132 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import (TYPE_CHECKING, Optional, Protocol, Type, Union, overload, +from typing import (TYPE_CHECKING, Optional, Protocol, Union, overload, runtime_checkable) import torch @@ -20,7 +20,7 @@ # The type of hidden states # Currently, T = torch.Tensor for all models except for Medusa -# which has T = List[torch.Tensor] +# which has T = list[torch.Tensor] T = TypeVar("T", default=torch.Tensor) T_co = TypeVar("T_co", default=torch.Tensor, covariant=True) @@ -48,12 +48,12 @@ def forward( ... 
-def _check_vllm_model_init(model: Union[Type[object], object]) -> bool: +def _check_vllm_model_init(model: Union[type[object], object]) -> bool: model_init = model.__init__ return supports_kw(model_init, "vllm_config") -def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool: +def _check_vllm_model_forward(model: Union[type[object], object]) -> bool: model_forward = getattr(model, "forward", None) if not callable(model_forward): return False @@ -75,7 +75,7 @@ def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool: @overload -def is_vllm_model(model: Type[object]) -> TypeIs[Type[VllmModel]]: +def is_vllm_model(model: type[object]) -> TypeIs[type[VllmModel]]: ... @@ -85,8 +85,8 @@ def is_vllm_model(model: object) -> TypeIs[VllmModel]: def is_vllm_model( - model: Union[Type[object], object], -) -> Union[TypeIs[Type[VllmModel]], TypeIs[VllmModel]]: + model: Union[type[object], object], +) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]: return _check_vllm_model_init(model) and _check_vllm_model_forward(model) @@ -105,7 +105,7 @@ def compute_logits( @overload def is_text_generation_model( - model: Type[object]) -> TypeIs[Type[VllmModelForTextGeneration]]: + model: type[object]) -> TypeIs[type[VllmModelForTextGeneration]]: ... @@ -116,8 +116,8 @@ def is_text_generation_model( def is_text_generation_model( - model: Union[Type[object], object], -) -> Union[TypeIs[Type[VllmModelForTextGeneration]], + model: Union[type[object], object], +) -> Union[TypeIs[type[VllmModelForTextGeneration]], TypeIs[VllmModelForTextGeneration]]: if not is_vllm_model(model): return False @@ -142,7 +142,7 @@ def pooler( @overload -def is_pooling_model(model: Type[object]) -> TypeIs[Type[VllmModelForPooling]]: +def is_pooling_model(model: type[object]) -> TypeIs[type[VllmModelForPooling]]: ... 
@@ -152,8 +152,8 @@ def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]: def is_pooling_model( - model: Union[Type[object], object], -) -> Union[TypeIs[Type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]: + model: Union[type[object], object], +) -> Union[TypeIs[type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]: if not is_vllm_model(model): return False diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index fdcef8b9be8d..d9d9002bd5ba 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -6,8 +6,9 @@ # Copyright (c) 2023 OpenGVLab # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- +from collections.abc import Iterable from functools import partial -from typing import Iterable, Optional, Set, Tuple +from typing import Optional import torch import torch.nn as nn @@ -461,10 +462,10 @@ def forward( return encoder_outputs - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: param = params_dict[name] weight_loader = getattr(param, "weight_loader", diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index c3d7cbfcddbb..3f3e3966e838 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 +from collections.abc import Iterable from functools import partial -from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union +from typing import Any, Optional, Union import torch from torch import nn @@ -81,7 +82,7 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 
10000, - rope_scaling: Optional[Dict[str, Any]] = None, + rope_scaling: Optional[dict[str, Any]] = None, max_position_embeddings: int = 8192, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, @@ -225,7 +226,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: residual = hidden_states @@ -252,7 +253,7 @@ def __init__( *, vllm_config: VllmConfig, prefix: str = "", - layer_type: Type[InternLMDecoderLayer] = InternLMDecoderLayer): + layer_type: type[InternLMDecoderLayer] = InternLMDecoderLayer): super().__init__() config = vllm_config.model_config.hf_config @@ -316,7 +317,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", - model_type: Type[InternLM2Model] = InternLM2Model): + model_type: type[InternLM2Model] = InternLM2Model): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config @@ -361,15 +362,15 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "w1", 0), ("gate_up_proj", "w3", 1), ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -407,7 +408,7 @@ def __init__( *, vllm_config: VllmConfig, prefix: str = "", - model_type: Type[InternLM2Model] = InternLM2Model, + model_type: type[InternLM2Model] = InternLM2Model, ): super().__init__(vllm_config=vllm_config, prefix=prefix, diff --git a/vllm/model_executor/models/internlm2_ve.py b/vllm/model_executor/models/internlm2_ve.py index 
69b0caab8f8e..6893d0239121 100644 --- a/vllm/model_executor/models/internlm2_ve.py +++ b/vllm/model_executor/models/internlm2_ve.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Tuple, Union +from typing import Optional, Union import torch from torch import nn @@ -66,7 +66,7 @@ def forward( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], visual_token_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: residual = hidden_states diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 23b92ad2bbf6..4612fc438741 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -8,8 +8,9 @@ # -------------------------------------------------------- from abc import ABC, abstractmethod from collections.abc import Iterable, Mapping, Sequence -from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union +from typing import Any, Literal, Optional, TypedDict, TypeVar, Union +import numpy.typing as npt import torch import torch.nn as nn import torchvision.transforms as T @@ -23,6 +24,7 @@ InternVisionPatchModel) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.image import convert_image_mode from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargs, NestedTensors) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, @@ -73,11 +75,38 @@ class InternVLImageEmbeddingInputs(TypedDict): InternVLImageEmbeddingInputs] +class InternVLVideoPixelInputs(TypedDict): + type: Literal["pixel_values_videos"] + pixel_values_flat: torch.Tensor + """ + Shape: + `(batch_size * num_video * num_frames, num_channels, height, width)` + """ + + num_patches: torch.Tensor + """Shape: `(batch_size * 
num_video)`""" + + +class InternVLVideoEmbeddingInputs(TypedDict): + type: Literal["video_embeds"] + data: Union[torch.Tensor, list[torch.Tensor]] + """ + A tensor of shape `(num_videos, total_video_feature_size, hidden_size)` + or a list of tensors of shape `(total_video_feature_size, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone. + """ + + +InternVLVideoInputs = Union[InternVLVideoPixelInputs, + InternVLVideoEmbeddingInputs] + + # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B def build_transform(input_size: int): MEAN, STD = IMAGENET_MEAN, IMAGENET_STD return T.Compose([ - T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), + T.Lambda(lambda img: convert_image_mode(img, 'RGB')), T.Resize((input_size, input_size), interpolation=T.InterpolationMode.BICUBIC), T.ToTensor(), @@ -230,6 +259,33 @@ def image_to_pixel_values_internvl( return pixel_values +# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B +def video_to_pixel_values_internvl( + video: npt.NDArray, + *, + input_size: int, + min_num: int, + max_num: int, + use_thumbnail: bool, +) -> torch.Tensor: + target_ratios = get_internvl_target_ratios(min_num, max_num) + + transform = build_transform(input_size=input_size) + frames_list = list[Image.Image]() + for frame in video: + pil_frame = dynamic_preprocess_internvl( + Image.fromarray(frame, mode="RGB"), + target_ratios=target_ratios, + image_size=input_size, + use_thumbnail=use_thumbnail, + ) + assert len(pil_frame) == 1 + frames_list.extend(pil_frame) + + pixel_values = torch.stack([transform(image) for image in frames_list]) + return pixel_values + + class BaseInternVLProcessor(ABC): """ This model doesn't define its own HF processor, @@ -374,24 +430,14 @@ def _images_to_pixel_values_lst( ) for image in images ] - def __call__( + def _preprocess_image( self, - text: Optional[Union[str, list[str]]] = None, - images: Optional[Union[Image.Image, list[Image.Image]]] = None, + text: 
list[str], + images: list[Image.Image], min_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, - return_tensors: Optional[Union[str, TensorType]] = None, - ) -> Mapping[str, NestedTensors]: - if text is None: - text = [] - if not isinstance(text, list): - text = [text] - if images is None: - images = [] - if not isinstance(images, list): - images = [images] - + ) -> tuple[list[str], dict[str, torch.Tensor]]: if len(images) == 0: image_inputs = {} else: @@ -414,6 +460,34 @@ def __call__( image_repl = self.get_image_repl(feature_size, num_patches) text = [t.replace('<image>', image_repl.full, 1) for t in text] + return text, image_inputs + + def _make_batch_input(self, + input_item: Optional[Union[Any, list[Any]]] = None): + if input_item is None: + input_item = [] + if not isinstance(input_item, list): + input_item = [input_item] + return input_item + + def __call__( + self, + text: Optional[Union[str, list[str]]] = None, + images: Optional[Union[Image.Image, list[Image.Image]]] = None, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + ) -> Mapping[str, NestedTensors]: + text, images = [self._make_batch_input(x) for x in (text, images)] + + text, image_inputs = self._preprocess_image( + text=text, + images=images, + min_dynamic_patch=min_dynamic_patch, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + ) text_inputs = self.tokenizer(text) @@ -424,11 +498,133 @@ def __call__( class InternVLProcessor(BaseInternVLProcessor): + """ + HF Processor for InternVLChatModel with extended video processing logic. 
+ + Code for video processing is adapted from video example: + https://huggingface.co/OpenGVLab/InternVL3-1B#inference-with-transformers + """ + + def __init__( + self, + config: PretrainedConfig, + tokenizer: AnyTokenizer, + *, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + video_token: Optional[str] = None, + ) -> None: + super().__init__( + config=config, + tokenizer=tokenizer, + min_dynamic_patch=min_dynamic_patch, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + ) + # add extra video token for video processing + self.video_token = video_token @property def image_token_id(self) -> int: return self.tokenizer.get_vocab()[IMG_CONTEXT] + @property + def video_token_id(self) -> Optional[int]: + if self.video_token is None: + return None + return self.tokenizer.get_vocab().get(self.video_token, None) + + @property + def supports_video(self) -> bool: + return self.video_token_id is not None + + def _videos_to_pixel_values_lst( + self, + videos: list[npt.NDArray], + dynamic_image_size: Optional[bool] = None, + ) -> list[torch.Tensor]: + min_num, max_num = self.resolve_min_max_num( + min_dynamic_patch=1, + max_dynamic_patch=1, + dynamic_image_size=dynamic_image_size, + use_thumbnail=False, # Applied in image_to_pixel_values + ) + + return [ + video_to_pixel_values_internvl( + video, + input_size=self.image_size, + min_num=min_num, + max_num=max_num, + use_thumbnail=False, + ) for video in videos + ] + + def _preprocess_video( + self, + text: list[str], + videos: list[npt.NDArray], + dynamic_image_size: Optional[bool] = None, + ): + if len(videos) == 0 or not self.supports_video: + video_inputs = {} + else: + pixel_values_lst_video = self._videos_to_pixel_values_lst( + videos, + dynamic_image_size=dynamic_image_size, + ) + video_inputs: dict[str, NestedTensors] = { + "pixel_values_flat_video": + torch.cat(pixel_values_lst_video), + "video_num_patches": + 
torch.tensor([len(item) for item in pixel_values_lst_video]), + } + + for pixel_values in pixel_values_lst_video: + num_patches = pixel_values.shape[0] + + video_repl = self.get_video_repl(self.num_image_token, + num_patches, self.video_token) + text = [t.replace('<video>', video_repl.full, 1) for t in text] + return text, video_inputs + + def __call__( + self, + text: Optional[Union[str, list[str]]] = None, + images: Optional[Union[Image.Image, list[Image.Image]]] = None, + videos: Optional[Union[npt.NDArray, list[npt.NDArray]]] = None, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + ) -> Mapping[str, NestedTensors]: + text, images, videos = [ + self._make_batch_input(x) for x in (text, images, videos) + ] + + text, image_inputs = self._preprocess_image( + text=text, + images=images, + min_dynamic_patch=min_dynamic_patch, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + ) + + text, video_inputs = self._preprocess_video( + text=text, + videos=videos, + dynamic_image_size=dynamic_image_size, + ) + + text_inputs = self.tokenizer(text) + + return { + **BatchEncoding(text_inputs, tensor_type=return_tensors), + **image_inputs, + **video_inputs, + } + def get_image_repl( self, feature_size: int, @@ -439,8 +635,24 @@ def get_image_repl( return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT) + def get_video_repl( + self, + feature_size: int, + num_patches: Optional[int] = None, + video_context_token: str = IMG_CONTEXT, + ) -> PromptUpdateDetails[str]: + repl_features = video_context_token * self.num_image_token + repl_features_with_sep = IMG_START + repl_features + IMG_END + # num_patches is equal to num_frames + repl_full = ''.join([ + f'Frame{i+1}: {repl_features_with_sep}' for i in range(num_patches) + ]) + + return PromptUpdateDetails.select_text(repl_full, video_context_token) + class 
BaseInternVLProcessingInfo(BaseProcessingInfo): + """Basic image-only ProcessingInfo for InternVL-style models.""" @abstractmethod def get_hf_processor( @@ -496,11 +708,22 @@ def get_image_size_with_most_features(self) -> ImageSize: return largest_feature_pinpoint + def get_max_image_tokens(self) -> int: + processor = self.get_hf_processor() + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + processor=processor, + ) + _I = TypeVar("_I", bound=BaseInternVLProcessingInfo) -class InternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]): +class BaseInternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]): + """Basic image-only DummyInputsBuilder for InternVL-style models.""" def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -524,7 +747,8 @@ def get_dummy_mm_data( } -class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]): +class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]): + """ Basic image-only MultiModalProcessor for InternVL-style models.""" def _call_hf_processor( self, @@ -613,6 +837,38 @@ def get_replacement_internvl(item_idx: int): class InternVLProcessingInfo(BaseInternVLProcessingInfo): + """InternVL ProcessingInfo extended for video processing""" + + @property + def supports_video(self): + return self.get_hf_processor().supports_video + + def get_supported_mm_limits(self): + video_limit = {"video": None} if self.supports_video else {} + return {**super().get_supported_mm_limits(), **video_limit} + + def get_video_token(self) -> Optional[str]: + text_model_type = self.get_hf_config().get_text_config().model_type + if text_model_type == "qwen2": + return "<|video_pad|>" + return None + + def get_num_frames_with_most_features( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + max_images = mm_counts.get("image", 0) + max_videos = mm_counts.get("video", 
0) + + processor = self.get_hf_processor() + + max_image_tokens = self.get_max_image_tokens() * max_images + max_total_frames = (seq_len - + max_image_tokens) // processor.num_image_token + max_frames_per_video = max_total_frames // max(max_videos, 1) + + return max(max_frames_per_video, 1) def get_hf_processor( self, @@ -629,6 +885,8 @@ def get_hf_processor( if dynamic_image_size is not None: kwargs["dynamic_image_size"] = dynamic_image_size + kwargs["video_token"] = self.get_video_token() + return self.ctx.init_processor( InternVLProcessor, config=self.get_hf_config(), @@ -637,6 +895,121 @@ def get_hf_processor( ) +class InternVLDummyInputsBuilder( + BaseInternVLDummyInputsBuilder[InternVLProcessingInfo]): + """InternVL DummyInputsBuilder extended for video support""" + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_videos = mm_counts.get("video", 0) + + return super().get_dummy_text(mm_counts) + "<video>" * num_videos + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + dummy_image = super().get_dummy_mm_data(seq_len=seq_len, + mm_counts=mm_counts) + if self.info.supports_video: + config = self.info.get_hf_config() + image_size: int = config.vision_config.image_size + target_num_frames = \ + self.info.get_num_frames_with_most_features(seq_len, mm_counts) + num_videos = mm_counts.get("video", 0) + dummy_video = { + "video": + self._get_dummy_videos(width=image_size, + height=image_size, + num_frames=target_num_frames, + num_videos=num_videos) + } + else: + dummy_video = {} + return {**dummy_image, **dummy_video} + + +class InternVLMultiModalProcessor( + BaseInternVLMultiModalProcessor[InternVLProcessingInfo]): + """InternVL MultiModalProcessor extended for video support""" + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> Mapping[str, NestedTensors]: + processed_outputs = super()._call_hf_processor(prompt, 
mm_data, + mm_kwargs) + + hf_processor = self.info.get_hf_processor(**mm_kwargs) + if self.info.supports_video and ( + video_token_id := hf_processor.video_token_id) is not None: + processed_outputs["video_token_id"] = torch.tensor(video_token_id) + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: Mapping[str, NestedTensors], + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + image_fields = super()._get_mm_fields_config(hf_inputs, + hf_processor_mm_kwargs) + if self.info.supports_video: + video_num_patches = hf_inputs.get("video_num_patches", + torch.empty(0)) + num_videos = len(video_num_patches) + video_fields = dict( + pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes( + "video", video_num_patches), + video_num_patches=MultiModalFieldConfig.batched("video"), + video_token_id=MultiModalFieldConfig.shared( + "video", num_videos), + ) + else: + video_fields = {} + + return image_fields | video_fields + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + prompt_repl: list[PromptUpdate] = super()._get_prompt_updates( + mm_items, hf_processor_mm_kwargs, out_mm_kwargs) + + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + if "video_num_patches" in out_mm_kwargs: + video_num_patches = out_mm_kwargs["video_num_patches"] + assert isinstance(video_num_patches, torch.Tensor) + video_num_patches = video_num_patches.tolist() + else: + video_num_patches = [] + + def get_video_replacement_internvl(item_idx: int): + feature_size = hf_processor.num_image_token + num_patches = video_num_patches[item_idx] + if num_patches is not None: + assert isinstance(num_patches, int) + + return hf_processor.get_video_repl( + feature_size, + num_patches, + video_context_token=hf_processor.video_token) + + if self.info.supports_video: + prompt_repl.append( + 
PromptReplacement( + modality="video", + target="<video>", + replacement=get_video_replacement_internvl, + )) + return prompt_repl + + @MULTIMODAL_REGISTRY.register_processor( InternVLMultiModalProcessor, info=InternVLProcessingInfo, @@ -680,6 +1053,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.mlp1 = self._init_mlp1(config) self.img_context_token_id = None + self.video_context_token_id = None + self.visual_token_mask = None self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) @@ -824,10 +1199,55 @@ def _parse_and_validate_image_input( raise AssertionError("This line should be unreachable.") + def _parse_and_validate_video_input( + self, **kwargs: object) -> Optional[InternVLVideoPixelInputs]: + pixel_values_flat_video = kwargs.pop("pixel_values_flat_video", None) + video_num_patches = kwargs.pop("video_num_patches", None) + video_embeds = kwargs.pop("image_embeds", None) + + if pixel_values_flat_video is None and video_embeds is None: + return None + + if video_embeds is not None: + if not isinstance(video_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of video embeddings. " + f"Got type: {type(video_embeds)}") + + return InternVLImageEmbeddingInputs( + type="video_embeds", + data=flatten_bn(video_embeds), + ) + + video_token_id = kwargs["video_token_id"] + assert isinstance(video_token_id, torch.Tensor) + self.video_context_token_id = video_token_id.flatten().unique().item() + + if pixel_values_flat_video is not None: + if not isinstance(pixel_values_flat_video, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values_flat_video)}") + + if not isinstance(video_num_patches, (torch.Tensor, list)): + raise ValueError("Incorrect type of image_num_patches. 
" + f"Got type: {type(video_num_patches)}") + + pixel_values_flat_video = flatten_bn(pixel_values_flat_video, + concat=True) + video_num_patches = flatten_bn(video_num_patches, concat=True) + + return InternVLVideoPixelInputs( + type="pixel_values_videos", + pixel_values_flat=self._validate_pixel_values( + pixel_values_flat_video), + num_patches=video_num_patches, + ) + + raise AssertionError("This line should be unreachable.") + def _process_image_input( self, - image_input: InternVLImageInputs, - ) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]: + image_input: Union[InternVLImageInputs, InternVLVideoPixelInputs], + ) -> tuple[torch.Tensor, ...]: if image_input["type"] == "image_embeds": return image_input["data"] @@ -839,8 +1259,8 @@ def _process_image_input( # Only one image in the current batch if len(num_patches) == 1: - return image_embeds.view( - -1, self.config.text_config.hidden_size).unsqueeze(0) + return (image_embeds.view(-1, + self.config.text_config.hidden_size), ) # NOTE: Image embeddings are split into separate tensors for each image # by the size of each embedding. @@ -852,8 +1272,26 @@ def _process_image_input( ] return image_embeds.split(image_feature_sizes) + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. 
+ for input_key in kwargs: + if input_key in ("pixel_values_flat", + "image_embeds") and "images" not in modalities: + modalities["images"] = self._parse_and_validate_image_input( + **kwargs) + if input_key in ("pixel_values_flat_video", + ) and "videos" not in modalities: + modalities["videos"] = self._parse_and_validate_video_input( + **kwargs) + + return modalities + def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None: if self.is_mono: + assert self.img_context_token_id is not None self.visual_token_mask = ( input_ids == self.img_context_token_id).reshape(-1, 1) else: @@ -864,11 +1302,28 @@ def get_language_model(self) -> torch.nn.Module: def get_multimodal_embeddings( self, **kwargs: object) -> Optional[MultiModalEmbeddings]: - image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is None: + + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: return None - return self._process_image_input(image_input) + # The result multimodal_embeddings is a tuple of tensors, with each + # tensor corresponding to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. 
+ for modality in modalities: + if modality == "images": + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += vision_embeddings + if modality == "videos": + video_input = modalities["videos"] + video_embeddings = self._process_image_input(video_input) + multimodal_embeddings += video_embeddings + + return multimodal_embeddings def get_input_embeddings( self, @@ -877,13 +1332,18 @@ def get_input_embeddings( ) -> torch.Tensor: inputs_embeds = self.language_model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: - assert self.img_context_token_id is not None + context_token_ids = [ + token_id for token_id in (self.img_context_token_id, + self.video_context_token_id) + if token_id is not None + ] + assert len(context_token_ids) >= 1 self._set_visual_token_mask(input_ids) inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, multimodal_embeddings, - self.img_context_token_id, + context_token_ids, ) return inputs_embeds @@ -932,8 +1392,8 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: # unused modules appear in OpenGVLab/InternVideo2_5_Chat_8B skip_prefixes = [ "action_embed", "temporal_embed", "track_embed", diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index e1e3f0f199c5..d6a1e0bb4845 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -21,7 +21,8 @@ """Inference-only Jais model compatible with HuggingFace weights.""" import math -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch from torch import nn @@ -333,10 +334,10 @@ def compute_logits( sampling_metadata) return logits - 
def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if "lm_head.weight" in name: # GPT-2 ties the weights of the embedding layer and the final diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 46335c2b3930..6f9fa60c9b05 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 """Inference-only Jamba model.""" -from typing import Iterable, Optional, Set, Tuple +from collections.abc import Iterable +from typing import Optional import torch from torch import nn @@ -442,7 +443,7 @@ def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) def _get_mamba_cache_shape( - self) -> Tuple[Tuple[int, int], Tuple[int, int]]: + self) -> tuple[tuple[int, int], tuple[int, int]]: world_size = get_tensor_model_parallel_world_size() hidden_size = self.config.hidden_size conv_state_shape = ( @@ -464,8 +465,8 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -482,7 +483,7 @@ def load_weights(self, weights: Iterable[Tuple[str, num_experts=self.config.num_experts) params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -583,7 +584,7 @@ def pooler( logits = self.score(hidden_states) return 
self._pooler(logits, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): # TODO: The reward weights themselves have float32 accuracy data, we # would like to load them in fp32 to get that extra precision. super().load_weights(weights) diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py index 0629266860fd..b575f44765a8 100644 --- a/vllm/model_executor/models/kimi_vl.py +++ b/vllm/model_executor/models/kimi_vl.py @@ -43,10 +43,9 @@ import copy import math -from collections.abc import Mapping +from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass -from typing import (Any, Iterable, List, Literal, Optional, Sequence, Tuple, - TypedDict, Union) +from typing import Any, Literal, Optional, TypedDict, Union import torch from torch import nn @@ -120,7 +119,7 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: class KimiVLImagePixelInputs(TypedDict): type: Literal["pixel_values"] - pixel_values: Union[torch.Tensor, List[torch.Tensor]] + pixel_values: Union[torch.Tensor, list[torch.Tensor]] """ Shape:`(num_patches, num_channels, patch_size, patch_size)` """ @@ -447,7 +446,7 @@ def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata, **kwargs) return logits - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): config = self.config.text_config _KEYS_TO_MODIFY_MAPPING = { "language_model.lm_head": "lm_head", diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 7a3ea7a68768..d36b6466c0bb 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -22,13 +22,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only LLaMA model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Any, Optional, Union import torch from torch import nn from transformers import LlamaConfig -from vllm.attention import Attention +from vllm.attention import Attention, AttentionType from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -96,19 +97,22 @@ def forward(self, x): class LlamaAttention(nn.Module): - def __init__(self, - config: LlamaConfig, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, - max_position_embeddings: int = 8192, - quant_config: Optional[QuantizationConfig] = None, - bias: bool = False, - bias_o_proj: bool = False, - cache_config: Optional[CacheConfig] = None, - prefix: str = "") -> None: + def __init__( + self, + config: LlamaConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + bias_o_proj: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + ) -> None: super().__init__() layer_idx = extract_layer_index(prefix) self.hidden_size = hidden_size @@ -158,20 +162,9 @@ def __init__(self, prefix=f"{prefix}.o_proj", ) - is_neox_style = True - is_gguf = quant_config and quant_config.get_name() == "gguf" - if is_gguf and config.model_type == "llama": - is_neox_style = False - - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - 
is_neox_style=is_neox_style, - partial_rotary_factor=self.partial_rotary_factor, - ) + self._init_rotary_emb(config, + rope_scaling=rope_scaling, + quant_config=quant_config) if hasattr(config, "interleaved_sliding_window"): interleaved_sliding_window = config.interleaved_sliding_window @@ -194,6 +187,7 @@ def __init__(self, cache_config=cache_config, quant_config=quant_config, per_layer_sliding_window=sliding_window, + attn_type=attn_type, prefix=f"{prefix}.attn", ) @@ -209,6 +203,24 @@ def forward( output, _ = self.o_proj(attn_output) return output + def _init_rotary_emb(self, config: LlamaConfig, + rope_scaling: Optional[dict[str, Any]], + quant_config: Optional[QuantizationConfig]) -> None: + is_neox_style = True + is_gguf = quant_config and quant_config.get_name() == "gguf" + if is_gguf and config.model_type == "llama": + is_neox_style = False + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + partial_rotary_factor=self.partial_rotary_factor, + ) + class LlamaDecoderLayer(nn.Module): @@ -238,6 +250,15 @@ def __init__( if hasattr(config, 'qkv_bias'): attention_bias = config.qkv_bias + # By default, Llama uses causal attention as it is a decoder-only model. + # You can override the HF config with `is_causal=False` to enable + # bidirectional attention, which is used in some embedding models + # (e.g. 
parasail-ai/GritLM-7B-vllm) + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + attn_type = AttentionType.ENCODER_ONLY + self.self_attn = LlamaAttention( config=config, hidden_size=self.hidden_size, @@ -252,6 +273,7 @@ def __init__( bias_o_proj=bias_o_proj, cache_config=cache_config, prefix=f"{prefix}.self_attn", + attn_type=attn_type, ) self.mlp = LlamaMLP( hidden_size=self.hidden_size, @@ -271,7 +293,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: residual = hidden_states @@ -380,8 +402,8 @@ def forward( return hidden_states, aux_hidden_states return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -391,7 +413,7 @@ def load_weights(self, weights: Iterable[Tuple[str, (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -568,8 +590,8 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, skip_prefixes=(["lm_head."] @@ -585,7 +607,7 @@ def maybe_remap_mistral( self, name: str, loaded_weight: torch.Tensor, - ) -> Tuple[str, torch.Tensor]: + ) -> tuple[str, torch.Tensor]: def permute(w: torch.Tensor, n_heads: int): attn_in = self.config.head_dim * n_heads diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 
0fdc30f36f9b..40fdd84d8fb0 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -16,7 +16,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only LLaMA model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple +from collections.abc import Iterable +from typing import Any, Optional import torch from torch import nn @@ -25,8 +26,7 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (QKVParallelLinear, @@ -49,7 +49,7 @@ def custom_routing_function( gating_output: torch.Tensor, topk: int, renormalize: bool, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: router_scores, router_indices = fast_topk(gating_output, topk, dim=-1) # psuedo-standard is that the router scores are floats router_scores = torch.sigmoid(router_scores.float()) @@ -89,7 +89,7 @@ def __init__(self, quant_config=quant_config, bias=False, prefix=f"{prefix}.shared_expert", - reduce_results=False, # We need to do scatter before reduce + reduce_results=self.experts.must_reduce_shared_expert_outputs(), ) def forward(self, hidden_states): @@ -102,7 +102,8 @@ def forward(self, hidden_states): experts_out = routed_out + shared_out if self.tp_size > 1: - experts_out = tensor_model_parallel_all_reduce(experts_out) + experts_out = self.experts.maybe_all_reduce_tensor_model_parallel( + experts_out) return experts_out @@ -115,7 +116,7 @@ def __init__(self, num_heads: int, num_kv_heads: int, 
rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, + rope_scaling: Optional[dict[str, Any]] = None, max_position_embeddings: int = 8192, quant_config: Optional[QuantizationConfig] = None, bias: bool = False, @@ -300,7 +301,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: residual = hidden_states @@ -335,9 +336,9 @@ def load_moe_expert_weights( self, name: str, loaded_weight: torch.Tensor, - params_dict: Dict[str, nn.Parameter], - loaded_params: Set[str], - expert_params_mapping: List[Tuple[str, str, int, str]], + params_dict: dict[str, nn.Parameter], + loaded_params: set[str], + expert_params_mapping: list[tuple[str, str, int, str]], fused: bool = True, ) -> bool: expert_param_loaded = False @@ -390,8 +391,8 @@ def load_moe_expert_weights( expert_param_loaded = True return expert_param_loaded - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -412,7 +413,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ckpt_up_proj_name="gate_up_proj", num_experts=1) params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if "experts.gate_up_proj" in name or "experts.down_proj" in name: fused_experts_params = True @@ -489,8 +490,8 @@ def _init_model(self, prefix=prefix, layer_type=layer_type) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, skip_prefixes=(["lm_head."] @@ -506,7 +507,7 @@ def 
permute_qk_weight_for_rotary( self, name: str, loaded_weight: torch.Tensor, - ) -> Tuple[str, torch.Tensor]: + ) -> tuple[str, torch.Tensor]: def permute(w: torch.Tensor, n_heads: int): attn_in = self.config.head_dim * n_heads diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index 76655bd71b15..172dc8b5ec06 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Iterable, Set, Tuple +from collections.abc import Iterable import torch import torch.nn as nn @@ -8,6 +8,7 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig +from vllm.distributed.parallel_state import get_pp_group from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -52,11 +53,15 @@ def __init__( self.config = vllm_config. 
\ speculative_config.draft_model_config.hf_config self.vocab_size = self.config.vocab_size - self.embed_tokens = VocabParallelEmbedding( - self.config.vocab_size, - self.config.hidden_size, - prefix=maybe_prefix(prefix, "embed_tokens"), - ) + + # if PP disabled then draft will share embed with target + if get_pp_group().world_size > 1: + self.embed_tokens = VocabParallelEmbedding( + self.config.vocab_size, + self.config.hidden_size, + prefix=maybe_prefix(prefix, "embed_tokens"), + ) + self.layers = nn.ModuleList([ LlamaDecoderLayer( self.config, @@ -87,8 +92,8 @@ def forward( hidden_states = hidden_states + residual return hidden_states, hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -98,7 +103,7 @@ def load_weights(self, weights: Iterable[Tuple[str, (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: @@ -109,6 +114,12 @@ def load_weights(self, weights: Iterable[Tuple[str, weight_loader(param, loaded_weight, shard_id) break else: + + # if PP disabled then draft will share embed with target + if get_pp_group().world_size == 1 and \ + "embed_tokens." in name: + continue + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) @@ -119,13 +130,15 @@ def load_weights(self, weights: Iterable[Tuple[str, class EagleLlamaForCausalLM(LlamaForCausalLM): - def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): nn.Module.__init__(self) self.config = vllm_config. 
\ speculative_config.draft_model_config.hf_config + target_layer_num = vllm_config.model_config.get_num_layers( + vllm_config.parallel_config) self.model = LlamaModel(vllm_config=vllm_config, prefix="model", - start_layer_id=start_layer_id) + start_layer_id=target_layer_num) logit_scale = getattr(self.config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.config.vocab_size, @@ -139,11 +152,10 @@ def forward( ) -> tuple[torch.Tensor, torch.Tensor]: return self.model(input_ids, positions, hidden_states) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=None, ) model_weights = {} @@ -151,5 +163,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if "lm_head" not in name: name = "model." + name model_weights[name] = loaded_weight - - loader.load_weights(model_weights.items()) + return loader.load_weights(model_weights.items()) diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 904ff3210943..f211bfe54a7d 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Iterable, Optional, Set, Tuple +from collections.abc import Iterable +from typing import Optional import torch import torch.nn as nn @@ -8,6 +9,7 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig +from vllm.distributed.parallel_state import get_pp_group from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import QKVParallelLinear @@ -55,7 +57,7 @@ def forward( embeds: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - ) -> 
Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: residual = hidden_states embeds = self.input_layernorm(embeds) @@ -91,11 +93,15 @@ def __init__( self.config = vllm_config. \ speculative_config.draft_model_config.hf_config self.vocab_size = self.config.vocab_size - self.embed_tokens = VocabParallelEmbedding( - self.config.vocab_size, - self.config.hidden_size, - prefix=maybe_prefix(prefix, "embed_tokens"), - ) + + # if PP disabled then draft will share embed with target + if get_pp_group().world_size > 1: + self.embed_tokens = VocabParallelEmbedding( + self.config.vocab_size, + self.config.hidden_size, + prefix=maybe_prefix(prefix, "embed_tokens"), + ) + self.layers = nn.ModuleList([ LlamaDecoderLayer( self.config, @@ -135,8 +141,8 @@ def forward( hidden_states, hidden_prenorm = self.norm(hidden_states, residual) return hidden_states, hidden_prenorm - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -146,7 +152,7 @@ def load_weights(self, weights: Iterable[Tuple[str, (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if 'midlayer.' in name: name = name.replace('midlayer.', 'layers.0.') @@ -169,13 +175,15 @@ def load_weights(self, weights: Iterable[Tuple[str, class Eagle3LlamaForCausalLM(LlamaForCausalLM): - def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): nn.Module.__init__(self) self.config = vllm_config. 
\ speculative_config.draft_model_config.hf_config + target_layer_num = vllm_config.model_config.get_num_layers( + vllm_config.parallel_config) self.model = LlamaModel(vllm_config=vllm_config, - start_layer_id=start_layer_id, - prefix="model") + prefix="model", + start_layer_id=target_layer_num) logit_scale = getattr(self.config, "logit_scale", 1.0) self.lm_head = ParallelLMHead( @@ -187,8 +195,7 @@ def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0): self.logits_processor = LogitsProcessor(self.config.draft_vocab_size, scale=logit_scale) self.draft_id_to_target_id = nn.Parameter( - torch.zeros((self.config.draft_vocab_size), - dtype=torch.long).type(torch.LongTensor), + torch.zeros(self.config.draft_vocab_size, dtype=torch.long), requires_grad=False, ) @@ -207,6 +214,9 @@ def compute_logits( ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) + if self.draft_id_to_target_id is None: + return logits + base = torch.arange(self.config.draft_vocab_size, device=logits.device) targets = base + self.draft_id_to_target_id logits_new = logits.new_full(( @@ -223,7 +233,7 @@ def combine_hidden_states( # combine multiple auxiliary hidden states returned by eagle3 return self.model.fc(hidden_states) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader( self, skip_prefixes=None, @@ -239,4 +249,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): name = "model." 
+ name model_weights[name] = loaded_weight - return loader.load_weights(model_weights.items()) + loaded_weights = loader.load_weights(model_weights.items()) + + if 'd2t' not in loaded_weights: + self.draft_id_to_target_id = None + + return loaded_weights diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 6287fdb3300c..ced71b6dcdeb 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -2,8 +2,8 @@ from abc import abstractmethod from collections.abc import Iterable, Mapping, Sequence -from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict, - TypeVar, Union, cast) +from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar, + Union, cast) import torch import torch.nn as nn @@ -721,9 +721,8 @@ def forward( batch. pixel_values: The pixels in each input image. - :::{seealso} - {class}`LlavaImageInputs` - ::: + Info: + [LlavaImageInputs][] """ if intermediate_tensors is not None: inputs_embeds = None @@ -751,8 +750,8 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index c7e8d6991b25..2fb79f57a67f 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 from abc import abstractmethod -from typing import (Final, Iterable, List, Literal, Mapping, Optional, - Protocol, Set, Tuple, TypedDict, TypeVar, Union) +from collections.abc import Iterable, Mapping +from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar, + Union) import torch import torch.nn as nn @@ 
-134,11 +135,13 @@ def _get_num_unpadded_features( current_aspect_ratio = current_width / current_height if aspect_ratio > current_aspect_ratio: - new_height = (original_height * current_width) // original_width + new_height = int( + round(original_height * (current_width / original_width), 7)) padding = (current_height - new_height) // 2 current_height = current_height - (2 * padding) else: - new_width = (original_width * current_height) // original_height + new_width = int( + round(original_width * (current_height / original_height), 7)) padding = (current_width - new_width) // 2 current_width = current_width - (2 * padding) @@ -266,8 +269,8 @@ def _validate_shape(d: torch.Tensor): return data def _validate_pixel_values( - self, data: Union[torch.Tensor, List[torch.Tensor]] - ) -> Union[torch.Tensor, List[torch.Tensor]]: + self, data: Union[torch.Tensor, list[torch.Tensor]] + ) -> Union[torch.Tensor, list[torch.Tensor]]: h = w = self.config.vision_config.image_size expected_dims = (3, h, w) @@ -450,7 +453,7 @@ def _process_image_pixels( def _process_image_input( self, image_input: LlavaNextImageInputs, - ) -> Union[torch.Tensor, List[torch.Tensor]]: + ) -> Union[torch.Tensor, list[torch.Tensor]]: if image_input["type"] == "image_embeds": return [image_input["data"]] @@ -537,7 +540,7 @@ def forward( Unlike in LLaVA-1.5, the number of image tokens inputted to the language model depends on the original size of the input image. Including the original image token in the input, the required number of image tokens - is given by {func}`get_llava_next_image_feature_size`. + is given by [get_llava_next_image_feature_size][]. This way, the `positions` and `attn_metadata` are consistent with the `input_ids`. @@ -548,9 +551,8 @@ def forward( pixel_values: The pixels in each grid patch for each input image. image_sizes: The original `(height, width)` for each input image. 
- :::{seealso} - {class}`LlavaNextImageInputs` - ::: + Info: + [LlavaNextImageInputs][] """ if intermediate_tensors is not None: inputs_embeds = None @@ -577,7 +579,7 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index a5ff189cfdb5..9303ea121727 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -2,7 +2,7 @@ import math from collections.abc import Iterable, Mapping, Sequence -from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union +from typing import Literal, Optional, TypedDict, Union import torch import torch.nn as nn @@ -35,7 +35,7 @@ class LlavaNextVideoPixelInputs(TypedDict): type: Literal["pixel_values_videos"] - data: Union[torch.Tensor, List[torch.Tensor]] + data: Union[torch.Tensor, list[torch.Tensor]] """ Shape: `(batch_size, num_frames, num_channels, height, width)` @@ -300,8 +300,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.language_model.model.make_empty_intermediate_tensors) def _validate_video_pixel_values( - self, data: Union[torch.Tensor, List[torch.Tensor]] - ) -> Union[torch.Tensor, List[torch.Tensor]]: + self, data: Union[torch.Tensor, list[torch.Tensor]] + ) -> Union[torch.Tensor, list[torch.Tensor]]: h = w = self.config.vision_config.image_size expected_dims = (3, h, w) @@ -326,7 +326,7 @@ def _parse_and_validate_video_input( A legal video input should have the following dimensions: { "pixel_values_videos" : - List[b, Tensor(nb_frames, nb_channels, height, width)] + list[b, Tensor(nb_frames, nb_channels, height, width)] } """ 
pixel_values_videos = kwargs.pop("pixel_values_videos", None) @@ -460,8 +460,8 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, # This model doesn't support images for now diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 5c2b388e403d..7ea759fd59b8 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -2,8 +2,7 @@ import math from collections.abc import Iterable, Mapping, Sequence -from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple, - TypedDict, Union) +from typing import Final, Literal, Optional, Protocol, TypedDict, Union import torch import torch.nn as nn @@ -117,11 +116,13 @@ def _get_num_unpadded_features( current_aspect_ratio = current_width / current_height if aspect_ratio > current_aspect_ratio: - new_height = (original_height * current_width) // original_width + new_height = int( + round(original_height * (current_width / original_width), 7)) padding = (current_height - new_height) // 2 current_height = current_height - (2 * padding) else: - new_width = (original_width * current_height) // original_height + new_width = int( + round(original_width * (current_height / original_height), 7)) padding = (current_width - new_width) // 2 current_width = current_width - (2 * padding) @@ -471,8 +472,8 @@ def _validate_shape(d: torch.Tensor): return data def _validate_image_pixel_values( - self, data: Union[torch.Tensor, List[torch.Tensor]] - ) -> Union[torch.Tensor, List[torch.Tensor]]: + self, data: Union[torch.Tensor, list[torch.Tensor]] + ) -> Union[torch.Tensor, list[torch.Tensor]]: h = w = self.config.vision_config.image_size expected_dims = (3, h, w) @@ -530,8 +531,8 
@@ def _parse_and_validate_image_input( raise AssertionError("This line should be unreachable.") def _validate_video_pixel_values( - self, data: Union[torch.Tensor, List[torch.Tensor]] - ) -> Union[torch.Tensor, List[torch.Tensor]]: + self, data: Union[torch.Tensor, list[torch.Tensor]] + ) -> Union[torch.Tensor, list[torch.Tensor]]: h = w = self.config.vision_config.image_size expected_dims = (3, h, w) @@ -557,7 +558,7 @@ def _parse_and_validate_video_input( A legal video input should have the following dimensions: { "pixel_values_videos" : - List[b, Tensor(nb_frames, nb_channels, height, width)] + list[b, Tensor(nb_frames, nb_channels, height, width)] } """ pixel_values_videos = kwargs.pop("pixel_values_videos", None) @@ -706,7 +707,7 @@ def _merge_image_patch_embeddings(self, def _process_image_pixels( self, inputs: LlavaOnevisionImagePixelInputs, - ) -> Union[torch.Tensor, List[torch.Tensor]]: + ) -> Union[torch.Tensor, list[torch.Tensor]]: assert self.vision_tower is not None pixel_values = inputs["pixel_values"] @@ -735,7 +736,7 @@ def _process_image_pixels( def _process_image_input( self, image_input: LlavaOnevisionImageInputs, - ) -> Union[torch.Tensor, List[torch.Tensor]]: + ) -> Union[torch.Tensor, list[torch.Tensor]]: if image_input["type"] == "image_embeds": return [image_input["data"]] @@ -948,7 +949,7 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index af78ece66bbe..ce76a76b6574 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 """PyTorch MAMBA model.""" -from typing import Iterable, 
Optional, Set, Tuple +from collections.abc import Iterable +from typing import Optional import torch from torch import nn @@ -30,7 +31,7 @@ make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) -KVCache = Tuple[torch.Tensor, torch.Tensor] +KVCache = tuple[torch.Tensor, torch.Tensor] class MambaDecoderLayer(nn.Module): @@ -153,10 +154,10 @@ def forward( return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if "A_log" in name: name = name.replace("A_log", "A") @@ -247,7 +248,7 @@ def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) def _get_mamba_cache_shape( - self) -> Tuple[Tuple[int, int], Tuple[int, int]]: + self) -> tuple[tuple[int, int], tuple[int, int]]: world_size = get_tensor_model_parallel_world_size() conv_state_shape = ( self.config.intermediate_size // world_size, @@ -265,7 +266,7 @@ def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index 72daf34c4412..858a1633befa 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 """PyTorch MAMBA2 model.""" -from typing import Iterable, Optional, Set, Tuple +from collections.abc import Iterable +from typing import Optional import torch from torch import nn @@ -35,7 +36,7 @@ 
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) -KVCache = Tuple[torch.Tensor, torch.Tensor] +KVCache = tuple[torch.Tensor, torch.Tensor] class Mamba2DecoderLayer(nn.Module): @@ -241,7 +242,7 @@ def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) def _get_mamba_cache_shape( - self) -> Tuple[Tuple[int, int], Tuple[int, int]]: + self) -> tuple[tuple[int, int], tuple[int, int]]: world_size = get_tensor_model_parallel_world_size() conv_state_shape, temporal_state_shape = None, None @@ -279,10 +280,10 @@ def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if "A_log" in name: name = name.replace("A_log", "A") diff --git a/vllm/model_executor/models/mamba_cache.py b/vllm/model_executor/models/mamba_cache.py index 25839727898f..47d0ef9cc6bb 100644 --- a/vllm/model_executor/models/mamba_cache.py +++ b/vllm/model_executor/models/mamba_cache.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import Tuple import torch @@ -25,8 +24,8 @@ def at_layer_idx(self, layer_idx): class MambaCacheManager(ConstantSizeCache): def __init__(self, vllm_config: VllmConfig, dtype: torch.dtype, - num_mamba_layers: int, conv_state_shape: Tuple[int, int], - temporal_state_shape: Tuple[int, int]): + num_mamba_layers: int, conv_state_shape: tuple[int, int], + temporal_state_shape: tuple[int, int]): # Determine max batch size to set size of MambaCache max_batch_size = vllm_config.scheduler_config.max_num_seqs diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py index 
a19d7da5654b..95ef1134b1bf 100644 --- a/vllm/model_executor/models/medusa.py +++ b/vllm/model_executor/models/medusa.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Iterable, List, Optional, Set, Tuple +from collections.abc import Iterable +from typing import Optional import torch import torch.nn as nn @@ -50,7 +51,7 @@ class Medusa(nn.Module): needs to have truncated_vocab_size (=k) as an attribute.""" def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: - config = vllm_config.model_config.hf_config + config = vllm_config.speculative_config.draft_model_config.hf_config super().__init__() self.config = config self.blocks = nn.ModuleList([ @@ -96,13 +97,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: # checkpoint file has token_map tensor. self.token_map = None - def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]: + def forward(self, hidden_states: torch.Tensor) -> list[torch.Tensor]: return [block(hidden_states) for block in self.blocks] def compute_logits( - self, hidden_states: List[torch.Tensor], - sampling_metadata: SamplingMetadata) -> List[torch.Tensor]: - logits_lst: List[torch.Tensor] = [] + self, hidden_states: list[torch.Tensor], + sampling_metadata: SamplingMetadata) -> list[torch.Tensor]: + logits_lst: list[torch.Tensor] = [] for hs, lm_head in zip(hidden_states, self.lm_heads): _logits = self.logits_processor(lm_head, hs, sampling_metadata) @@ -127,9 +128,9 @@ def compute_logits( def sample( self, - logits: List[torch.Tensor], + logits: list[torch.Tensor], sampling_metadata: SamplingMetadata, - ) -> List[SamplerOutput]: + ) -> list[SamplerOutput]: logits = torch.stack(logits, dim=0).float() logprobs = torch.log_softmax(logits, dim=-1) token_ids = logits.argmax(-1) # support only top-1 for now @@ -144,7 +145,7 @@ def sample( token_prob_list.append(probs[:, seq_group.sample_indices]) token_logprob_list.append(logprobs[:, seq_group.sample_indices]) - 
outputs: List[Optional[SamplerOutput]] = [] + outputs: list[Optional[SamplerOutput]] = [] for idx in range(len(sampling_metadata.seq_groups)): outputs.append( SamplerOutput( @@ -160,7 +161,14 @@ def generate_proposals( self, previous_hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> List[SamplerOutput]: + ) -> Optional[list[SamplerOutput]]: + # During preemption, we may receive an empty tensor (batch_size=0) + if previous_hidden_states.size(0) == 0: + # Return None to signal the Top1Proposer that no proposals + # were generated for this batch, allowing it to handle this + # special case appropriately + return None + return self.sample( logits=self.compute_logits( hidden_states=self.forward(previous_hidden_states), @@ -169,10 +177,10 @@ def generate_proposals( sampling_metadata=sampling_metadata, ) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() weights_map = {} diff --git a/vllm/model_executor/models/mimo.py b/vllm/model_executor/models/mimo.py new file mode 100644 index 000000000000..49ea64e029d6 --- /dev/null +++ b/vllm/model_executor/models/mimo.py @@ -0,0 +1,191 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py +# Copyright 2025 Xiaomi Corporation. +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only MiMo model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from typing import Optional, Union + +import torch +import torch.nn as nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM, Qwen2Model +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .utils import PPMissingLayer, is_pp_missing_parameter, maybe_prefix + +logger = init_logger(__name__) + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }) +class MiMoModel(Qwen2Model): + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if 
inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states = hidden_states + residual + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "mtp_layers" in name: + continue + if "rotary_emb.inv_freq" in name: + continue + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class MiMoForCausalLM(Qwen2ForCausalLM, nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + + self.model = MiMoModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + hidden_states = self.model.norm(hidden_states) + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits diff --git 
a/vllm/model_executor/models/mimo_mtp.py b/vllm/model_executor/models/mimo_mtp.py new file mode 100644 index 000000000000..cbca6a4c8f9d --- /dev/null +++ b/vllm/model_executor/models/mimo_mtp.py @@ -0,0 +1,284 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/deepseek_mtp.py +# Copyright 2025 Xiaomi Corporation. +# Copyright 2023 The vLLM team. +# Copyright 2024 DeepSeek-AI team. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only MiMo-MTP model.""" +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.qwen2 import Qwen2DecoderLayer +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .utils import maybe_prefix + + +class MiMoMultiTokenPredictorLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + + self.token_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.hidden_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.input_proj = nn.Linear(config.hidden_size * 2, + config.hidden_size, + bias=False) + self.mtp_block = Qwen2DecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix) + self.final_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + inputs_embeds: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + spec_step_index: int = 0, + ) -> torch.Tensor: + assert inputs_embeds is not None + # masking inputs at position 0, as not needed by MTP + inputs_embeds[positions == 0] = 0 + inputs_embeds = 
self.token_layernorm(inputs_embeds) + previous_hidden_states = self.hidden_layernorm(previous_hidden_states) + + hidden_states = self.input_proj( + torch.cat([previous_hidden_states, inputs_embeds], dim=-1)) + + hidden_states, residual = self.mtp_block(positions=positions, + hidden_states=hidden_states, + residual=None) + hidden_states = residual + hidden_states + return self.final_layernorm(hidden_states) + + +class MiMoMultiTokenPredictor(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = config.num_nextn_predict_layers + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + + self.mtp_layers = torch.nn.ModuleDict({ + str(idx): + MiMoMultiTokenPredictorLayer( + config, + f"{prefix}.layers.{idx}", + model_config=vllm_config.model_config, + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + ) + for idx in range(self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers) + }) + + self.logits_processor = LogitsProcessor(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + return self.mtp_layers[str(self.mtp_start_layer_idx + spec_step_idx)]( + inputs_embeds, + positions, + previous_hidden_states, + spec_step_idx, + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + lm_head: ParallelLMHead, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> torch.Tensor: + self.mtp_layers[str(self.mtp_start_layer_idx + spec_step_idx)] + logits = self.logits_processor(lm_head, hidden_states, + sampling_metadata) + return logits + + 
+class MiMoMTP(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.config = vllm_config.model_config.hf_config + self.model = MiMoMultiTokenPredictor(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + self.lm_head = ParallelLMHead(self.config.vocab_size, + self.config.hidden_size) + + self.sampler = get_sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + assert spec_step_idx == 0, "mimo_mtp only support predict one token now" + hidden_states = self.model(input_ids, positions, + previous_hidden_states, inputs_embeds, + spec_step_idx) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> Optional[torch.Tensor]: + return self.model.compute_logits(hidden_states, self.lm_head, + sampling_metadata, spec_step_idx) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + + if "rotary_emb.inv_freq" in name: + continue + name = self.map_model_name_to_mtp_param_name(name) + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). 
+ if weight_name not in name: + continue + if "mtp_layers" not in name: + break + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if "mtp_layers" not in name and ("embed_tokens" not in name + and "lm_head" not in name): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + def map_model_name_to_mtp_param_name(self, name: str) -> str: + import regex as re + name_without_prefix = [ + "token_layernorm", "hidden_layernorm", "input_proj", + "final_layernorm" + ] + for sub_name in name_without_prefix: + if sub_name in name: + return name + pattern = r"model.mtp_layers.(\d+)." + group = re.match(pattern, name) + if group is not None: + name = name.replace(group.group(), group.group() + "mtp_block.") + return name + + def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: + """ + Rewrite the weight name to match the format of the original model. 
+ Add .mtp_block for modules in transformer layer block for spec layer + """ + spec_layer_weight_names = [ + "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head" + ] + spec_layer_weight = False + for weight_name in spec_layer_weight_names: + if weight_name in name: + spec_layer_weight = True + break + if not spec_layer_weight: + # treat rest weights as weights for transformer layer block + name = name.replace(f"model.layers.{spec_layer}.", + f"model.layers.{spec_layer}.mtp_block.") + return name diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 866dc3f466e7..0397b552ce9f 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -23,7 +23,8 @@ # limitations under the License. """Inference-only MiniCPM model compatible with HuggingFace weights.""" import math -from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Any, Optional, Union import torch from torch import nn @@ -190,7 +191,7 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, + rope_scaling: Optional[dict[str, Any]] = None, max_position_embeddings: int = 8192, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, @@ -241,9 +242,6 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - # set rope as fp32 instead of bf16 - self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache( - ) self.attn = Attention(self.num_heads, self.head_dim, self.scaling, @@ -329,7 +327,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -428,8 +426,8 @@ def forward( hidden_states = 
self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -446,7 +444,7 @@ def load_weights(self, weights: Iterable[Tuple[str, for weight_name in ["w1", "w2", "w3"] ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -582,8 +580,8 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, skip_prefixes=(["lm_head."] diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py index 1b24c38cef1b..2a6867d12d99 100644 --- a/vllm/model_executor/models/minicpm3.py +++ b/vllm/model_executor/models/minicpm3.py @@ -23,7 +23,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only MiniCPM3 model compatible with HuggingFace weights.""" -from typing import Any, Dict, Optional +from typing import Any, Optional import torch from torch import nn @@ -58,7 +58,7 @@ def __init__( q_lora_rank: int, kv_lora_rank: int, rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, + rope_scaling: Optional[dict[str, Any]] = None, max_position_embeddings: int = 8192, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index f42d48e919cd..ae5df0f9273f 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -23,8 +23,7 @@ # limitations under the License. """Inference-only MiniCPM-O model compatible with HuggingFace weights.""" from collections.abc import Iterable, Mapping, Sequence -from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict, - Union) +from typing import Any, Callable, Literal, Optional, TypedDict, Union import torch from torch import nn @@ -559,8 +558,8 @@ def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""): self.audio_encoder_layer = -1 return model - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self, skip_prefixes=["tts"]) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 300360f785ae..04cc7e35e345 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -26,8 +26,7 @@ from collections import defaultdict from collections.abc import Iterable, Mapping, Sequence from functools import partial -from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict, - Union) +from typing import Any, Callable, Literal, 
Optional, TypedDict, Union import numpy as np import torch @@ -118,7 +117,7 @@ def __init__(self, num_heads: int, kv_dim: Optional[int] = None, norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, - max_size: Tuple[int, int] = (70, 70), + max_size: tuple[int, int] = (70, 70), quant_config: Optional[QuantizationConfig] = None, prefix: str = "") -> None: super().__init__(num_queries, @@ -133,7 +132,7 @@ def __init__(self, self._set_2d_pos_cache(self.max_size) def _set_2d_pos_cache(self, - max_size: Tuple[int, int], + max_size: tuple[int, int], device: torch.types.Device = "cpu") -> None: pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim, max_size, @@ -203,7 +202,7 @@ def forward(self, x: torch.Tensor, return x -def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]: +def get_version_by_config(config: PretrainedConfig) -> tuple[int, ...]: version_float = getattr(config, "version", None) # The old configs do not include version number @@ -938,8 +937,8 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.llm.compute_logits(hidden_states, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 951f4e2304a1..36bab9ee13b1 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -2,9 +2,10 @@ """Inference-only MiniMaxText01 model.""" import copy import math -import re -from typing import Dict, Iterable, List, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union +import regex as re import torch import torch.distributed import torch.nn.functional as F @@ -127,7 +128,7 @@ def forward( self, x: torch.Tensor, residual: 
Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert residual is None, "RMSNorm does not support residual connection." return self._forward(x) @@ -178,7 +179,7 @@ def forward( positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: from vllm import _custom_ops as ops self.cos_sin_cache = self.cos_sin_cache.to(positions.device) query_cast = query.to(self.cache_dtype) @@ -603,8 +604,9 @@ def __init__( rope_theta = getattr(config, "rope_theta", 10000) - head_dim = getattr(config, "head_dim", - config.hidden_size // config.num_attention_heads) + head_dim = getattr(config, "head_dim", None) + if head_dim is None: + head_dim = config.hidden_size // config.num_attention_heads if hasattr(config, "max_model_len") and isinstance( config.max_model_len, int): max_position_embeddings = min(config.max_position_embeddings, @@ -708,11 +710,11 @@ def __init__( def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, - kv_caches: Union[List[Dict], Optional[torch.Tensor]], + kv_caches: Union[list[dict], Optional[torch.Tensor]], attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], is_warmup: bool = False, - **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + **kwargs) -> tuple[torch.Tensor, torch.Tensor]: forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata @@ -860,8 +862,9 @@ def layer_fn(prefix): cache_shape=self.cache_shape) rope_theta = getattr(config, "rope_theta", 10000) - head_dim = getattr(config, "head_dim", - config.hidden_size // config.num_attention_heads) + head_dim = getattr(config, "head_dim", None) + if head_dim is None: + head_dim = config.hidden_size // config.num_attention_heads if hasattr(config, "max_model_len") and isinstance( config.max_model_len, int): max_position_embeddings = 
min(config.max_position_embeddings, @@ -1072,10 +1075,10 @@ def make_empty_intermediate_tensors( device=device), }) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() def which_layer(name: str) -> int: if "layers" in name: diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index 4ac60f97bb5f..14c1250ca3b4 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 from collections.abc import Iterable, Mapping -from typing import Literal, Optional, Set, Tuple, TypedDict, Union, cast +from typing import Literal, Optional, TypedDict, Union, cast import torch import torch.nn as nn @@ -357,7 +357,7 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index 42ec786f3a59..051a73120838 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -2,8 +2,8 @@ from abc import abstractmethod from collections.abc import Iterable, Mapping, Sequence -from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict, - TypeVar, Union) +from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar, + Union) import torch import torch.nn as nn @@ -559,9 +559,8 @@ def forward( batch. pixel_values: The pixels in each input image. 
- :::{seealso} - {class}`Mistral3ImagePixelInputs` - ::: + Info: + [Mistral3ImagePixelInputs][] """ if intermediate_tensors is not None: inputs_embeds = None @@ -589,8 +588,8 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 1513c8dad097..9bc7a16153e1 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -22,7 +22,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Mixtral model.""" -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch from torch import nn @@ -137,8 +138,9 @@ def __init__( assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) # MixtralConfig has an optional head_dim argument - self.head_dim = getattr(config, "head_dim", - self.hidden_size // self.total_num_heads) + self.head_dim = getattr(config, "head_dim", None) + if self.head_dim is None: + self.head_dim = self.hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -314,8 +316,8 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -332,7 +334,7 @@ 
def load_weights(self, weights: Iterable[Tuple[str, num_experts=self.config.num_local_experts) params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if (self.quant_config is not None and (scale_name := self.quant_config.get_cache_scale(name))): @@ -479,7 +481,7 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: - loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"]) + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 7c022a5b8f68..8220200d270c 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -22,7 +22,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Mixtral model.""" -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import numpy as np import torch @@ -50,7 +51,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -192,8 +193,9 @@ def __init__( assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) # MixtralConfig has an optional head_dim argument - self.head_dim = getattr(config, "head_dim", - self.hidden_size // self.total_num_heads) + self.head_dim = getattr(config, "head_dim", None) + if self.head_dim is None: + self.head_dim = self.hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -353,6 +355,53 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if name.endswith("scale"): + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip experts that are not assigned to this worker. + if ("block_sparse_moe.experts." in name + and name not in params_dict): + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class MixtralForCausalLM(nn.Module, SupportsPP): fall_back_to_pt_during_load = False @@ -397,51 +446,7 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - - params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if name.endswith("scale"): - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: - continue - # Skip experts that are not assigned to this worker. - if ("block_sparse_moe.experts." in name - and name not in params_dict): - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 0c1d61c01f91..713c9e8d203f 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -16,7 +16,7 @@ """PyTorch Mllama model.""" import math from collections.abc import Iterable, Mapping, Sequence -from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union +from typing import Literal, Optional, TypedDict, Union import numpy as np import torch @@ -224,7 +224,7 @@ def apply( return mm_inputs - def _get_num_image_in_last_group(self, prompt_token_ids: List[int]) -> int: + def _get_num_image_in_last_group(self, prompt_token_ids: list[int]) -> int: num_images = 0 for token_id in prompt_token_ids[::-1]: if token_id == self.info.get_hf_config().image_token_index: @@ -370,8 +370,8 @@ def __init__( self, in_channels: int, out_channels: int, - kernel_size: Union[int, Tuple[int, int]], - stride: Union[int, Tuple[int, int]], + kernel_size: Union[int, tuple[int, int]], + stride: Union[int, tuple[int, int]], bias: bool = False, ) -> None: super().__init__() @@ -603,7 +603,7 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - ) -> Union[Tuple, BaseModelOutput]: + ) -> Union[BaseModelOutput]: encoder_states = () for i, encoder_layer in enumerate(self.layers): @@ -878,7 +878,7 @@ def forward( 
self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], - kv_range_for_decode: Optional[List[Tuple[int, int]]], + kv_range_for_decode: Optional[list[tuple[int, int]]], cross_attention_states: Optional[torch.Tensor], ) -> torch.Tensor: q, k, v = self.qkv_proj(hidden_states, cross_attention_states) @@ -905,7 +905,7 @@ def _attention_with_mask( k: torch.Tensor, v: torch.Tensor, attention_mask: torch.Tensor, - kv_range_for_decode: List[Tuple[int, int]], + kv_range_for_decode: list[tuple[int, int]], ) -> torch.Tensor: kv_cache = self.attn.kv_cache[self.pipeline_parallel_rank] attn_metadata: AttentionMetadata = get_forward_context().attn_metadata @@ -1019,7 +1019,7 @@ def forward( hidden_states: torch.Tensor, cross_attention_states: torch.Tensor, cross_attention_mask: torch.Tensor, - kv_range_for_decode: Optional[List[Tuple[int, int]]], + kv_range_for_decode: Optional[list[tuple[int, int]]], full_text_row_masked_out_mask: torch.Tensor, ) -> torch.Tensor: residual = hidden_states @@ -1089,8 +1089,8 @@ def forward( positions: Optional[torch.LongTensor], cross_attention_states: Optional[torch.LongTensor], cross_attention_mask: Optional[torch.LongTensor], - kv_range_for_decode: Optional[List[Tuple[int, int]]], - full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, + kv_range_for_decode: Optional[list[tuple[int, int]]], + full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]], skip_cross_attention: bool, ) -> torch.Tensor: @@ -1150,8 +1150,8 @@ def forward( positions: Optional[torch.LongTensor], cross_attention_states: Optional[torch.LongTensor], cross_attention_mask: Optional[torch.LongTensor], - kv_range_for_decode: Optional[List[Tuple[int, int]]], - full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, + kv_range_for_decode: Optional[list[tuple[int, int]]], + full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]], skip_cross_attention: bool, ) -> torch.Tensor: @@ -1221,7 +1221,7 @@ def compute_logits( 
return logits def unpack_data(self, - image_data: Union[List[torch.Tensor], torch.Tensor], + image_data: Union[list[torch.Tensor], torch.Tensor], padding_value=0) -> torch.Tensor: if isinstance(image_data, torch.Tensor): # torch.Tensor @@ -1230,7 +1230,7 @@ def unpack_data(self, assert isinstance( image_data[0], torch.Tensor), "Image data is not properly batched." - # List[torch.Tensor] + # list[torch.Tensor] bsz = len(image_data) max_length = max(t.size(0) for t in image_data) trailing_dims = image_data[0].shape[1:] @@ -1248,24 +1248,24 @@ def unpack_data(self, def _parse_and_validate_image_input(self, **kwargs: object): # tensor with the same shape will be batched together by # MultiModalKwargs.batch, so pixel_values here can be: - # - List[torch.Tensor]: + # - list[torch.Tensor]: # with shape (num_image, num_tiles, 3, image_res, image_res) # - torch.Tensor: # with shape (bs, num_image, num_tiles, 3, image_res, image_res) - pixel_values: Optional[Union[List[List[torch.Tensor]], - List[torch.Tensor], + pixel_values: Optional[Union[list[list[torch.Tensor]], + list[torch.Tensor], torch.Tensor]] = kwargs.pop( "pixel_values", None) - image_embeds: Optional[Union[List[List[torch.Tensor]], - List[torch.Tensor], + image_embeds: Optional[Union[list[list[torch.Tensor]], + list[torch.Tensor], torch.Tensor]] = kwargs.pop( "image_embeds", None) - aspect_ratio_ids: Optional[Union[List[List[torch.Tensor]], - List[torch.Tensor], + aspect_ratio_ids: Optional[Union[list[list[torch.Tensor]], + list[torch.Tensor], torch.Tensor]] = kwargs.pop( "aspect_ratio_ids", None) - aspect_ratio_mask: Optional[Union[List[List[torch.Tensor]], - List[torch.Tensor], + aspect_ratio_mask: Optional[Union[list[list[torch.Tensor]], + list[torch.Tensor], torch.Tensor]] = kwargs.pop( "aspect_ratio_mask", None) @@ -1293,10 +1293,10 @@ def _parse_and_validate_image_input(self, **kwargs: object): def _get_and_validate_encoder_lens( self, - encoder_seq_lens: List[int], - num_tiles: List[List[int]], + 
encoder_seq_lens: list[int], + num_tiles: list[list[int]], num_tokens_per_tile: int, - ) -> List[int]: + ) -> list[int]: # Get the actual number of encoder tokens for each sample. # Because attn_metadata.encoder_seq_lens only counts the last # group of images for each sample, which is used to cheat the @@ -1318,7 +1318,7 @@ def _get_and_validate_encoder_lens( def flat_encoder_result(self, cross_attention_states: torch.Tensor, attn_metadata: AttentionMetadata, - actual_encoder_seq_lens: List[int]): + actual_encoder_seq_lens: list[int]): cross_attention_states_flat = torch.zeros( sum(actual_encoder_seq_lens), @@ -1342,8 +1342,8 @@ def get_cross_attention_states( self, image_inputs: MllamaImagePixelInputs, attn_metadata: AttentionMetadata, - actual_encoder_seq_lens: List[int], - ) -> Tuple[torch.Tensor]: + actual_encoder_seq_lens: list[int], + ) -> tuple[torch.Tensor]: # NOTE: llama's reference implementation runs vision model on CPU pixel_values = image_inputs['data'] aspect_ratio_ids = image_inputs['aspect_ratio_ids'] @@ -1367,10 +1367,10 @@ def get_cross_attention_mask( self, input_ids: torch.Tensor, attn_metadata: AttentionMetadata, - num_tiles: List[List[int]], + num_tiles: list[list[int]], num_tokens_per_tile: int, dtype: torch.dtype, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: token_ids = input_ids.tolist() start = 0 batch_token_ids = [] @@ -1422,7 +1422,7 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, **kwargs: object, - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> Union[CausalLMOutputWithPast]: attn_metadata = get_forward_context().attn_metadata if attn_metadata.num_prefill_tokens > 0 and \ attn_metadata.num_decode_tokens > 0: @@ -1476,8 +1476,8 @@ def forward( return outputs - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, 
shard_id) (".qkv_proj", ".q_proj", "q"), @@ -1487,7 +1487,7 @@ def load_weights(self, weights: Iterable[Tuple[str, (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) - updated_params: Set[str] = set() + updated_params: set[str] = set() for name, loaded_weight in weights: if 'patch_embedding.weight' in name: name = name.replace('patch_embedding.weight', @@ -1538,7 +1538,7 @@ def get_mm_mapping(self) -> MultiModelKeys: tower_model="vision_model") -def skip_attention_mask(sparse_mask: List[List[int]]) -> bool: +def skip_attention_mask(sparse_mask: list[list[int]]) -> bool: for mask in sparse_mask: # Skip text-only samples. if len(mask) == 0: @@ -1556,10 +1556,10 @@ def skip_attention_mask(sparse_mask: List[List[int]]) -> bool: def convert_sparse_cross_attention_mask_to_dense( - sparse_mask: List[List[List[int]]], - num_tiles: List[List[int]], - lengths: List[int], -) -> Tuple[np.ndarray, List[Tuple[int, int]]]: + sparse_mask: list[list[list[int]]], + num_tiles: list[list[int]], + lengths: list[int], +) -> tuple[np.ndarray, list[tuple[int, int]]]: total_length = sum(lengths) total_tiles = sum([sum(tiles) for tiles in num_tiles]) dense_mask = np.zeros(shape=(total_length, total_tiles), dtype=np.int64) diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 741b9837398c..8c98492c0bed 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -18,7 +18,7 @@ import math from collections.abc import Iterable, Mapping from itertools import tee -from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union +from typing import Literal, Optional, TypedDict, Union import torch from torch import nn @@ -582,7 +582,7 @@ def _get_prompt_updates( mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> List[PromptUpdate]: + ) -> list[PromptUpdate]: assert ( mm_items.get_count("image", strict=False) == 0 or 
"aspect_ratios" in out_mm_kwargs @@ -778,26 +778,26 @@ def compute_logits( def separate_weights( self, - weights: Iterable[Tuple[str, torch.Tensor]], + weights: Iterable[tuple[str, torch.Tensor]], prefix: str, - ) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[ + ) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[ str, torch.Tensor]]]: weights1, weights2 = tee(weights, 2) - def get_prefix_weights() -> Iterable[Tuple[str, torch.Tensor]]: + def get_prefix_weights() -> Iterable[tuple[str, torch.Tensor]]: for name, data in weights1: if name.startswith(prefix): yield (name, data) - def get_other_weights() -> Iterable[Tuple[str, torch.Tensor]]: + def get_other_weights() -> Iterable[tuple[str, torch.Tensor]]: for name, data in weights2: if not name.startswith(prefix): yield (name, data) return get_prefix_weights(), get_other_weights() - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -806,7 +806,7 @@ def load_weights(self, weights: Iterable[Tuple[str, (".self_attn.qkv_proj", ".self_attn.v_proj", "v"), ] params_dict = dict(self.named_parameters()) - updated_params: Set[str] = set() + updated_params: set[str] = set() # language_model is an Llama4ForCausalLM instance. We load it's # using llama4's load_weights routine. 
diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py index 2920427f94f7..a7d7aa7d44ef 100644 --- a/vllm/model_executor/models/mlp_speculator.py +++ b/vllm/model_executor/models/mlp_speculator.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import math -from typing import Iterable, List, Set, Tuple +from collections.abc import Iterable import torch import torch.nn as nn @@ -148,7 +148,7 @@ def generate_proposals( previous_hidden_states: torch.Tensor, num_predict_tokens: int, sampling_metadata: SamplingMetadata, - ) -> List[SamplerOutput]: + ) -> list[SamplerOutput]: if num_predict_tokens > self.max_speculative_tokens: raise ValueError(f"Max speculative tokens for model is " f"{self.max_speculative_tokens}, but " @@ -190,10 +190,10 @@ def generate_proposals( return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: name = name.replace("speculator.", "") param = params_dict.get(name) diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index 2190241f0ba3..86552aa05bf9 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Iterable, Optional, Set, Tuple +from collections.abc import Iterable +from typing import Optional import torch from torch import nn @@ -212,11 +213,11 @@ def __init__( eps=config.norm_eps, bias=config.norm_bias) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: weights = self.hf_to_vllm_mapper.apply(weights) params_dict = 
dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if name.endswith(".bias") and name not in params_dict: continue @@ -230,9 +231,12 @@ def load_weights(self, weights: Iterable[Tuple[str, def forward( self, input_ids: Optional[torch.LongTensor] = None, + positions: Optional[torch.Tensor] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, ) -> torch.Tensor: + position_ids = positions if positions is not None else position_ids if inputs_embeds is not None: hidden_states = inputs_embeds else: @@ -277,7 +281,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._pooler = CrossEncodingPooler(config, self.classifier, ModernBertPooler(config)) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): self_weights = [] diff --git a/vllm/model_executor/models/module_mapping.py b/vllm/model_executor/models/module_mapping.py index 23814e6322d2..25e6f594069e 100644 --- a/vllm/model_executor/models/module_mapping.py +++ b/vllm/model_executor/models/module_mapping.py @@ -4,7 +4,7 @@ # https://github.com/modelscope/ms-swift/blob/v2.4.2/swift/utils/module_mapping.py from dataclasses import dataclass, field -from typing import List, Union +from typing import Union @dataclass @@ -46,17 +46,17 @@ class ModelKeys: @dataclass class MultiModelKeys(ModelKeys): - language_model: List[str] = field(default_factory=list) - connector: List[str] = field(default_factory=list) + language_model: list[str] = field(default_factory=list) + connector: list[str] = field(default_factory=list) # vision tower and audio tower - tower_model: List[str] = field(default_factory=list) - generator: List[str] = field(default_factory=list) + tower_model: list[str] = field(default_factory=list) + generator: list[str] 
= field(default_factory=list) @staticmethod - def from_string_field(language_model: Union[str, List[str]] = None, - connector: Union[str, List[str]] = None, - tower_model: Union[str, List[str]] = None, - generator: Union[str, List[str]] = None, + def from_string_field(language_model: Union[str, list[str]] = None, + connector: Union[str, list[str]] = None, + tower_model: Union[str, list[str]] = None, + generator: Union[str, list[str]] = None, **kwargs) -> 'MultiModelKeys': def to_list(value): diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 42bbb77a22c0..640a2049a629 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -4,7 +4,7 @@ from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass from functools import cached_property, partial -from typing import List, Optional, Set, Tuple, TypedDict, Union +from typing import Optional, TypedDict, Union import numpy as np import torch @@ -90,7 +90,7 @@ class MolmoImageInputs(TypedDict): @dataclass class VisionBackboneConfig: - image_default_input_size: Tuple[int, int] = (336, 336) + image_default_input_size: tuple[int, int] = (336, 336) image_patch_size: int = 14 image_pos_patch_size: int = 14 image_emb_dim: int = 1024 @@ -267,7 +267,7 @@ def __init__( for _ in range(config.image_num_layers) ]) - def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: hidden_states = [] for r in self.resblocks: x = r(x) @@ -334,7 +334,7 @@ def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor: def forward(self, x: torch.Tensor, - patch_num: Optional[int] = None) -> List[torch.Tensor]: + patch_num: Optional[int] = None) -> list[torch.Tensor]: """ : param x: (batch_size, num_patch, n_pixels) """ @@ -434,7 +434,7 @@ def __init__( ) def _apply_qk_norm(self, q: torch.Tensor, - k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + k: torch.Tensor) -> 
tuple[torch.Tensor, torch.Tensor]: if self.tp_size > 1: q = tensor_model_parallel_all_gather(q.contiguous()) k = tensor_model_parallel_all_gather(k.contiguous()) @@ -570,7 +570,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]: # Self Attention if residual is None: residual = hidden_states @@ -596,7 +596,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]: # Self Attention residual = hidden_states hidden_states = self.self_attn( @@ -740,15 +740,15 @@ def forward( # image_features: (batch_size, num_image, num_patch, d_model) return image_features - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("merged_linear", "gate_proj", 0), ("merged_linear", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: for (param_name, weight_name, shard_id) in stacked_params_mapping: @@ -855,10 +855,10 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if name.endswith(".bias") and name not in params_dict: @@ -965,7 +965,7 @@ def select_tiling( class 
MolmoProcessorWrapper: """ - Wraps {class}`MolmoProcessor` so that it can be called directly. + Wraps `MolmoProcessor` so that it can be called directly. The original definition can be found here: https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/preprocessing_molmo.py @@ -1530,7 +1530,7 @@ def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) weights = _get_weights_with_merged_embedding(weights) @@ -1548,8 +1548,8 @@ def get_mm_mapping(self) -> MultiModelKeys: def _get_weights_with_merged_embedding( - weights: Iterable[Tuple[str, torch.Tensor]] -) -> Iterable[Tuple[str, torch.Tensor]]: + weights: Iterable[tuple[str, torch.Tensor]] +) -> Iterable[tuple[str, torch.Tensor]]: embedding_weights = {} for name, weight in weights: if "wte.embedding" in name: diff --git a/vllm/model_executor/models/moonvit.py b/vllm/model_executor/models/moonvit.py index c367d90f847b..9f11d4a42273 100644 --- a/vllm/model_executor/models/moonvit.py +++ b/vllm/model_executor/models/moonvit.py @@ -42,9 +42,10 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
import math +from collections.abc import Sequence from copy import deepcopy from functools import cached_property -from typing import List, Optional, Sequence, Tuple, Union +from typing import Optional, Union import torch import torch.nn as nn @@ -222,7 +223,7 @@ def __init__( self, out_dim: int, in_dim: int = 3, - patch_size: Union[int, Tuple[int, int]] = (14, 14), + patch_size: Union[int, tuple[int, int]] = (14, 14), pos_emb_height: int = 14, pos_emb_width: int = 14, ): @@ -526,7 +527,7 @@ def patch_merger( x: torch.Tensor, grid_hw: torch.Tensor, merge_kernel_size: list[int, int] = (2, 2), -) -> List[torch.Tensor]: +) -> list[torch.Tensor]: d_model = x.size(-1) outputs = [] diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 77bd794058cd..6c396d778ae7 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -2,7 +2,8 @@ # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main import math -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch import torch.nn as nn @@ -265,10 +266,10 @@ def forward( hidden_states = self.norm_f(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: @@ -323,7 +324,7 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index 5208c0796c8d..d0999e30e1ba 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -22,7 +22,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Nemotron model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Any, Optional, Union import torch from torch import nn @@ -47,7 +48,7 @@ from vllm.transformers_utils.configs import NemotronConfig from .interfaces import SupportsLoRA, SupportsPP -from .utils import (PPMissingLayer, is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -69,7 +70,7 @@ def _cast_if_autocast_enabled(*args): class NemotronLayerNorm1P(nn.LayerNorm): def __init__(self, - normalized_shape: Union[int, List[int], torch.Size], + normalized_shape: Union[int, list[int], torch.Size], eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True, @@ -133,7 +134,7 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, + rope_scaling: Optional[dict[str, Any]] = None, max_position_embeddings: int = 8192, quant_config: Optional[QuantizationConfig] = None, bias: bool = False, @@ -157,8 +158,9 @@ def __init__( assert tp_size % self.total_num_kv_heads == 0 
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) # MistralConfig has an optional head_dim introduced by Mistral-Nemo - self.head_dim = getattr(config, "head_dim", - self.hidden_size // self.total_num_heads) + self.head_dim = getattr(config, "head_dim", None) + if self.head_dim is None: + self.head_dim = self.hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -267,7 +269,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: residual = hidden_states @@ -299,6 +301,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): lora_config = vllm_config.lora_config self.config = config + self.quant_config = quant_config lora_vocab = (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) if lora_config else 0 self.vocab_size = config.vocab_size + lora_vocab @@ -361,6 +364,63 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for 
(param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { @@ -441,66 +501,7 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - ] - params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. 
- continue - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/nemotron_nas.py b/vllm/model_executor/models/nemotron_nas.py index 988b994b7689..9808fe05558e 100644 --- a/vllm/model_executor/models/nemotron_nas.py +++ b/vllm/model_executor/models/nemotron_nas.py @@ -22,18 +22,21 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only deci model compatible with HuggingFace weights.""" -from typing import Iterable, Optional, Set, Tuple, Type, Union +from collections.abc import Iterable +from typing import Any, Optional, Union import torch from torch import nn from transformers import LlamaConfig +from vllm.attention import AttentionType from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -61,6 +64,48 @@ def _find_multiple(n: int, k: int) -> int: return n + k - (n % k) +class DeciLMAttention(LlamaAttention): + + def __init__( + self, + config: LlamaConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + bias_o_proj: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + ) -> None: + super().__init__(config, hidden_size, num_heads, num_kv_heads, + rope_theta, rope_scaling, max_position_embeddings, + quant_config, bias, bias_o_proj, cache_config, prefix, + attn_type) + + def _init_rotary_emb(self, config, rope_scaling: Optional[dict[str, Any]], + quant_config: Optional[QuantizationConfig]) -> None: + # Enables YARN for Mistral and LLaMA4 derivatives. 
+ is_neox_style = True + if hasattr(config, "position_embedding_type"): + is_neox_style = config.position_embedding_type not in [ + "mistral_yarn", "rope_llama4" + ] + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + partial_rotary_factor=self.partial_rotary_factor) + + class DeciLMDecoderLayer(nn.Module): def __init__( @@ -97,7 +142,7 @@ def __init__( if not self._is_no_op_attention: num_kv_heads = (config.num_attention_heads // block_config.attention.n_heads_in_group) - self.self_attn = LlamaAttention( + self.self_attn = DeciLMAttention( config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -135,7 +180,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if self._is_no_op_attention: @@ -168,7 +213,7 @@ def __init__( *, vllm_config: VllmConfig, prefix: str = "", - layer_type: Type[DeciLMDecoderLayer] = DeciLMDecoderLayer, + layer_type: type[DeciLMDecoderLayer] = DeciLMDecoderLayer, ): super().__init__() @@ -260,8 +305,8 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -271,7 +316,7 @@ def load_weights(self, weights: Iterable[Tuple[str, (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -428,8 +473,8 @@ def compute_logits( sampling_metadata) return 
logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, skip_prefixes=(["lm_head."] diff --git a/vllm/model_executor/models/nvlm_d.py b/vllm/model_executor/models/nvlm_d.py index 62a7deab6a10..172434e66ae2 100644 --- a/vllm/model_executor/models/nvlm_d.py +++ b/vllm/model_executor/models/nvlm_d.py @@ -22,9 +22,10 @@ PromptUpdateDetails) from .intern_vit import InternVisionModel -from .internvl import (BaseInternVLProcessingInfo, BaseInternVLProcessor, - InternVLChatModel, InternVLDummyInputsBuilder, - InternVLMultiModalProcessor) +from .internvl import (BaseInternVLDummyInputsBuilder, + BaseInternVLMultiModalProcessor, + BaseInternVLProcessingInfo, BaseInternVLProcessor, + InternVLChatModel) IMG_PAD = "<|vision_pad|>" @@ -84,7 +85,8 @@ def get_hf_processor( ) -class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]): +class NVLMDummyInputsBuilder(BaseInternVLDummyInputsBuilder[NVLMProcessingInfo] + ): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -110,7 +112,8 @@ def get_dummy_mm_data( } -class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]): +class NVLMMultiModalProcessor( + BaseInternVLMultiModalProcessor[NVLMProcessingInfo]): def _get_prompt_updates( self, diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 0781ca168f84..fcb7c619a102 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -22,7 +22,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only OLMo model compatible with HuggingFace weights.""" -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch from torch import nn @@ -46,7 +47,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -209,7 +210,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]: # Attention block. residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -284,6 +285,45 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class OlmoForCausalLM(nn.Module, SupportsPP): """ @@ -338,53 +378,11 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue - # With tie_word_embeddings, we can skip lm_head.weight - # The weight might appear unnecessarily in the files if the model is - # processed with quantization, LoRA, fine-tuning, etc. - if self.config.tie_word_embeddings and "lm_head.weight" in name: - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head.weight"] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index 422b53d86f11..33adacdae5f5 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -23,8 +23,9 @@ # limitations under the License. """Inference-only OLMo2 model compatible with HuggingFace weights.""" +from collections.abc import Iterable from functools import partial -from typing import Iterable, Optional, Tuple, Union +from typing import Optional, Union import torch from torch import nn @@ -48,8 +49,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsPP from vllm.model_executor.models.utils import ( - is_pp_missing_parameter, make_empty_intermediate_tensors_factory, - make_layers, maybe_prefix) + AutoWeightsLoader, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors @@ -135,7 +136,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) def _apply_qk_norm(self, q: torch.Tensor, - k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: if self.tp_size > 1: q = tensor_model_parallel_all_gather(q.contiguous()) k = tensor_model_parallel_all_gather(k.contiguous()) @@ -313,6 
+314,44 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if is_pp_missing_parameter(name, self): + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader # type: ignore + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class Olmo2ForCausalLM(nn.Module, SupportsPP): """ @@ -365,48 +404,10 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - - params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue - if is_pp_missing_parameter(name, self): - continue - # With tie_word_embeddings, we can skip lm_head.weight - # The weight might appear unnecessarily in the files if the model is - # processed with quantization, LoRA, fine-tuning, etc. - if self.config.tie_word_embeddings and "lm_head.weight" in name: - continue - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader # type: ignore - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head.weight"] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index e6925e125690..6364b89fb837 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only OLMoE model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Any, Optional, Union import torch from torch import nn @@ -102,7 +103,7 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, + rope_scaling: Optional[dict[str, Any]] = None, max_position_embeddings: int = 4096, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, @@ -307,8 +308,8 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -327,7 +328,7 @@ def load_weights(self, weights: Iterable[Tuple[str, num_experts=self.config.num_experts) params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: for (param_name, 
weight_name, shard_id) in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). @@ -439,10 +440,7 @@ def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: - loader = AutoWeightsLoader( - self, - skip_prefixes=["rotary_emb.inv_freq"], - ) + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index d258eddae25d..8376d62410d4 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -18,7 +18,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only OPT model compatible with HuggingFace weights.""" -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch from torch import nn @@ -312,8 +313,8 @@ def forward( intermediate_tensors, inputs_embeds=inputs_embeds) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -321,7 +322,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ("qkv_proj", "v_proj", "v"), ] params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: @@ -400,8 +401,8 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def 
load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, skip_prefixes=(["lm_head.weight"] diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 8d9c000750d7..da2a194e6bdf 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -5,7 +5,8 @@ # Copyright (c) OrionStar Inc. # LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE """Inference-only Orion-14B model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Any, Optional, Union import torch from torch import nn @@ -72,7 +73,7 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, + rope_scaling: Optional[dict[str, Any]] = None, max_position_embeddings: int = 8192, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, @@ -186,7 +187,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -259,8 +260,8 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -270,7 +271,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: for (param_name, weight_name, shard_id) in 
stacked_params_mapping: if weight_name not in name: @@ -341,16 +342,7 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: - loader = AutoWeightsLoader( - self, - skip_prefixes=([ - "rotary_emb.inv_freq", - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - "rotary_emb.cos_cached", - "rotary_emb.sin_cached" - ]), - ) + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/ovis2.py b/vllm/model_executor/models/ovis.py similarity index 58% rename from vllm/model_executor/models/ovis2.py rename to vllm/model_executor/models/ovis.py index 67cc86e7fc82..e03705d48f3e 100644 --- a/vllm/model_executor/models/ovis2.py +++ b/vllm/model_executor/models/ovis.py @@ -15,17 +15,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-""" PyTorch Ovis2 model.""" -from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, - TypedDict, Union) +""" PyTorch Ovis model.""" +import math +from collections.abc import Iterable, Mapping +from typing import Literal, Optional, TypedDict, Union import torch import torch.nn as nn from torch import Tensor -from transformers import BatchFeature +from torch.nn.functional import gumbel_softmax, pad, softmax +from transformers import BaseImageProcessor, BatchFeature from vllm.config import VllmConfig -from vllm.model_executor.models.aimv2 import Aimv2VisualTokenizer +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.models.aimv2 import AIMv2Model +from vllm.model_executor.models.siglip import SiglipVisionModel from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix) @@ -38,19 +44,160 @@ BaseProcessingInfo, PromptReplacement) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.configs.ovis2 import OvisConfig -from vllm.transformers_utils.processors.ovis2 import OvisProcessor +from vllm.transformers_utils.configs.ovis import (BaseVisualTokenizerConfig, + OvisConfig) +from vllm.transformers_utils.processors.ovis import OvisProcessor from .interfaces import MultiModalEmbeddings, SupportsMultiModal from .utils import merge_multimodal_embeddings # Cannot find the following number from hf config. 
IMAGE_TOKEN = "<image>" -IMAGE_PAD_TOKEN_ID = 151655 -NUMBER_OF_TOKEN_TO_RESERVE_FOR_SEGMENT = 256 +IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305] +IMAGE_PAD_TOKEN_MAP = { + "gemma2": "<unused0>", + "llama": "<|reserved_special_token_0|>", + "qwen2": "<|image_pad|>", +} +IMAGE_PAD_TOKEN_ID_MAP = { + "gemma2": 7, + "llama": 128002, + "qwen2": 151655, +} -class Ovis2ImagePatchInputs(TypedDict): + +def st_argmax(y_soft: torch.Tensor, dim: int): # straight-through softmax + index = y_soft.argmax(dim, keepdim=True) + return torch.zeros_like( + y_soft, + memory_format=torch.legacy_contiguous_format, + ).scatter_(dim, index, 1.0) + + +class VisualTokenizer(torch.nn.Module): + + def __init__( + self, + config: BaseVisualTokenizerConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.backbone = self._init_backbone( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.backbone", + ) + # reserved tokens for IMAGE_INDICATORS + head_dim = config.vocab_size - len(IMAGE_INDICATOR_IDS) + self.head = torch.nn.Sequential( + ReplicatedLinear( + config.backbone_config.hidden_size * config.hidden_stride * + config.hidden_stride, + head_dim, + bias=False, + return_bias=False, + ), torch.nn.LayerNorm(head_dim)) + + def _init_backbone( + self, + config: BaseVisualTokenizerConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + model_type = config.backbone_config.model_type + if model_type == "aimv2": + return AIMv2Model( + config=config.backbone_config, + quant_config=quant_config, + prefix=prefix, + ) + elif model_type == "siglip_vision_model": + return SiglipVisionModel( + config=config.backbone_config, + quant_config=quant_config, + prefix=prefix, + ) + raise ValueError( + f"Unsupported visual tokenizer model_type: {model_type}") + + @property + def dtype(self): + return next(self.head.parameters()).dtype + + @property + def device(self): + return 
next(self.head.parameters()).device + + def tokenize(self, logits): + if self.config.tokenize_function == 'softmax': + tokens = softmax(logits, dim=-1) + elif self.config.tokenize_function == 'gumbel_argmax': + tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True) + elif self.config.tokenize_function == 'st_argmax': + tokens = st_argmax(logits, dim=-1) + else: + raise ValueError( + 'Invalid `max_type`, expected softmax or gumbel_argmax ' + f'or st_argmax, but got {self.config.tokenize_function}') + return tokens + + def encode(self, pixel_values): + features = self.backbone(pixel_values) + if self.config.drop_cls_token: + features = features[:, 1:, :] + + # merge number of `hidden_stride * hidden_stride` hidden states together + # to reduce token sequence length + # e.g., for hidden_stride=2, this leads to a token length reduction: + # 1024 -> 256 for aimv2 + if self.config.hidden_stride > 1: + # this `d` maybe different from the above `d`` + n, L, d = features.shape + sqrt_l = int(L**0.5) + assert sqrt_l**2 == L, ( + "The token sequence length should be a perfect square.") + features = features.reshape(n, sqrt_l, sqrt_l, d) + pl = (self.config.hidden_stride - + (sqrt_l % + self.config.hidden_stride)) % self.config.hidden_stride + features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0) + sqrt_l += pl + features = features.reshape(n, sqrt_l // self.config.hidden_stride, + self.config.hidden_stride, + sqrt_l // self.config.hidden_stride, + self.config.hidden_stride, d) + # [n, sqrt_l/hs, sqrt_l/hs, hs, hs, d] + features = features.permute(0, 1, 3, 2, 4, 5) + # [n, sqrt_l/hs, sqrt_l/hs, hs*hs*d] + features = features.flatten(3) + # [n, sqrt_l/hs*sqrt_l/hs, hs*hs*d] + features = features.reshape( + n, -1, + self.config.hidden_stride * self.config.hidden_stride * d) + + return features + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + """[BatchSize, ImageShape] -> [BatchSize, Token, VocabSize]""" + features = self.encode(pixel_values) 
+ logits = self.head(features) + tokens = self.tokenize(logits) + # tokens' shape is [BatchSize, #Token, VocabSize-5], so padding with + # [BatchSize, #Token, 5], after which, tokens' shape should become + # [BatchSize, #Token, VocabSize] + tokens = torch.nn.functional.pad( + tokens, + (0, len(IMAGE_INDICATOR_IDS)), + mode="constant", + value=0, + ) + return tokens + + +class OvisImagePatchInputs(TypedDict): type: Literal["image_patches"] flat_data: torch.Tensor """ @@ -64,7 +211,7 @@ class Ovis2ImagePatchInputs(TypedDict): `(batch_size * (num_patches + 1))` """ - patches_per_image: List[int] + patches_per_image: list[int] """ List of number of total patches for each image in the batch. This is used to restore the first two dimensions of `flat_data`. @@ -92,31 +239,50 @@ def dtype(self): return self.weight.dtype -class Ovis2ProcessingInfo(BaseProcessingInfo): +class OvisProcessingInfo(BaseProcessingInfo): def get_hf_config(self): return self.ctx.get_hf_config(OvisConfig) def get_hf_processor(self, **kwargs): - return self.ctx.get_hf_processor(OvisProcessor) + return self.ctx.get_hf_processor( + OvisProcessor, + image_pad_token=self.get_image_pad_token(), + image_segment_len=self.get_image_segment_len(), + ) - def get_image_processor(self) -> OvisProcessor: + def get_image_segment_len(self) -> int: + visual_tokenizer_config = self.get_hf_config().visual_tokenizer_config + image_size = visual_tokenizer_config.backbone_config.image_size + patch_size = visual_tokenizer_config.backbone_config.patch_size + hidden_stride = visual_tokenizer_config.hidden_stride + patch_grid_length = math.ceil(image_size / patch_size) + assert patch_grid_length % hidden_stride == 0, ( + f"patch_grid_length {patch_grid_length} is not divisible by " + f"hidden_stride {hidden_stride}") + # minus 1 for presented image token + return (patch_grid_length // hidden_stride)**2 - 1 + + def get_image_pad_token(self) -> str: + hf_text_config = self.get_hf_config().get_text_config() + text_model_type = 
hf_text_config.model_type + return IMAGE_PAD_TOKEN_MAP.get(text_model_type) + + def get_image_processor(self) -> BaseImageProcessor: return self.get_hf_processor().image_processor # type: ignore def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return { # 32k is model token limit at the moment - "image": - self.get_hf_config().multimodal_max_length // - ((9 + 1) * NUMBER_OF_TOKEN_TO_RESERVE_FOR_SEGMENT) - } + return {"image": None} def get_image_size_with_most_features(self) -> ImageSize: - image_processor = self.get_image_processor() - return ImageSize(width=image_processor.size['shortest_edge'] * 9 * 2, - height=image_processor.size['shortest_edge'] * 9 * 2) + height, width = self.get_hf_processor().get_image_size() + hs = self.get_hf_config().visual_tokenizer_config.hidden_stride + # NOTE(Isotr0py): 9 is `max_partion` hardcoded in original code + # https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/modeling_ovis.py#L96 + return ImageSize(width=width * hs * 9, height=height * hs * 9) -class Ovis2DummyInputsBuilder(BaseDummyInputsBuilder[Ovis2ProcessingInfo]): +class OvisDummyInputsBuilder(BaseDummyInputsBuilder[OvisProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -141,7 +307,7 @@ def get_dummy_mm_data( return mm_data -class Ovis2MultiModalProcessor(BaseMultiModalProcessor[Ovis2ProcessingInfo]): +class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]): def image_indicators_to_visual_tokens( self, @@ -165,9 +331,9 @@ def _call_hf_processor( mm_kwargs: Mapping[str, object], ) -> BatchFeature: if not mm_data: - # # Avoid warning from HF logger for text-only input - prompt_ids = self.info.get_tokenizer().encode(prompt) - # prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) nope + # Avoid warning from HF logger for text-only input + tokenizer = self.info.get_tokenizer() + prompt_ids = tokenizer.encode(prompt, add_special_tokens=False) return 
BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") processed_outputs = super()._call_hf_processor( @@ -226,10 +392,10 @@ def get_replacement_ovis(item_idx): ] -@MULTIMODAL_REGISTRY.register_processor(Ovis2MultiModalProcessor, - info=Ovis2ProcessingInfo, - dummy_inputs=Ovis2DummyInputsBuilder) -class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal): +@MULTIMODAL_REGISTRY.register_processor(OvisMultiModalProcessor, + info=OvisProcessingInfo, + dummy_inputs=OvisDummyInputsBuilder) +class Ovis(nn.Module, SupportsMultiModal): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -242,24 +408,25 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix(prefix, "llm"), ) - self.visual_tokenizer = Aimv2VisualTokenizer( + self.visual_tokenizer = VisualTokenizer( config=config.visual_tokenizer_config, quant_config=quant_config, prefix=f"{prefix}.visual_tokenizer", - image_processor_name_or_path=config.visual_tokenizer_config. - backbone_config.name_or_path, ) self.vte = VisualEmbedding( self.config.visual_tokenizer_config.vocab_size, self.config.hidden_size) + text_model_type = self.config.get_text_config().model_type + self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type] + # TODO(Isotr0py): PP support # self.make_empty_intermediate_tensors = ( # self.language_model.make_empty_intermediate_tensors) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Ovis2ImagePatchInputs]: + self, **kwargs: object) -> Optional[OvisImagePatchInputs]: pixel_values = kwargs.pop("pixel_values", None) indicator_tokens = kwargs.pop("indicator_tokens", None) @@ -275,7 +442,7 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of indicator_tokens. 
" f"Got type: {type(pixel_values)}") - return Ovis2ImagePatchInputs( + return OvisImagePatchInputs( type="image_patches", flat_data=flatten_bn(flatten_bn(pixel_values), concat=True), patches_per_image=[ @@ -288,7 +455,7 @@ def _parse_and_validate_image_input( raise AssertionError("This line should be unreachable.") def _process_image_input( - self, image_input: Ovis2ImagePatchInputs) -> MultiModalEmbeddings: + self, image_input: OvisImagePatchInputs) -> MultiModalEmbeddings: image_patches_flat = image_input["flat_data"] patches_per_image = image_input["patches_per_image"] indicator_tokens = image_input["indicator_tokens"] @@ -338,7 +505,7 @@ def get_input_embeddings( if multimodal_embeddings is not None: inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, multimodal_embeddings, - [IMAGE_PAD_TOKEN_ID]) + self.image_pad_token_id) return inputs_embeds def forward( @@ -375,12 +542,11 @@ def compute_logits( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.llm.logits_processor(self.llm.lm_head, hidden_states, - sampling_metadata) + logits = self.llm.compute_logits(hidden_states, sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 8699ae52622d..427005e9b704 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 from collections.abc import Iterable, Mapping, Sequence -from typing import Literal, Optional, Set, Tuple, TypedDict, Union +from typing import Literal, Optional, TypedDict, Union import torch from torch import nn @@ -391,7 +391,7 @@ def compute_logits( return 
self.language_model.compute_logits(hidden_states, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index eacf02433b57..d46b95fea5a8 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -21,7 +21,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only persimmon model compatible with HuggingFace weights.""" -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch from torch import nn @@ -260,10 +261,10 @@ def forward( hidden_states = self.final_layernorm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if is_pp_missing_parameter(name, self): continue @@ -336,7 +337,7 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index fc2b108bad97..330ad5c59448 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -36,7 +36,8 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF 
THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """Inference-only Phi-1.5 model compatible with HuggingFace weights.""" -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch from torch import nn @@ -248,8 +249,8 @@ def forward( return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -257,7 +258,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ("qkv_proj", "v_proj", "v") ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: @@ -348,7 +349,7 @@ def compute_logits( sampling_metadata, self.lm_head.bias) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index 338e87b4285f..d00d7d886d67 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 import math -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch from torch import nn @@ -230,8 +231,8 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor], + Optional[tuple[torch.Tensor]]]: qkv, _ = 
self.query_key_value(hidden_states) qkv = qkv.view(qkv.shape[:-1] + @@ -352,10 +353,10 @@ def forward( hidden_states = self.final_layernorm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if name.endswith(".bias") and name not in params_dict: continue @@ -454,8 +455,8 @@ def forward( output_hidden_states = output_hidden_states return output_hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, skip_prefixes=(["lm_head.weight"] diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index a1442251b992..b757e661d771 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -14,10 +14,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import re from collections.abc import Iterable, Mapping, Sequence -from typing import Any, List, Literal, Optional, Set, Tuple, TypedDict, Union +from typing import Any, Literal, Optional, TypedDict, Union +import regex as re import torch import torch.nn as nn from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig, @@ -94,7 +94,7 @@ def _init_img_processor(hf_config: PretrainedConfig, class Phi3VImagePixelInputs(TypedDict): type: Literal["pixel_values"] - data: Union[torch.Tensor, List[torch.Tensor]] + data: Union[torch.Tensor, list[torch.Tensor]] """ Shape: `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` @@ -113,7 +113,7 @@ class Phi3VImagePixelInputs(TypedDict): class Phi3VImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] - data: Union[torch.Tensor, List[torch.Tensor]] + data: Union[torch.Tensor, list[torch.Tensor]] """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` `hidden_size` must match the hidden size of language model backbone. 
@@ -571,8 +571,8 @@ def _validate_shape(d: torch.Tensor): return data def _validate_pixel_values( - self, data: Union[torch.Tensor, List[torch.Tensor]] - ) -> Union[torch.Tensor, List[torch.Tensor]]: + self, data: Union[torch.Tensor, list[torch.Tensor]] + ) -> Union[torch.Tensor, list[torch.Tensor]]: h = w = CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size expected_dims = (3, h, w) @@ -707,8 +707,8 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) autoloaded_weights = loader.load_weights(weights, diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index e5ff9ceddef7..418ff900ffd5 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import math from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union +from typing import Any, Literal, Optional, TypedDict, Union import numpy as np import torch @@ -392,7 +392,7 @@ def forward(self, pixel_values: torch.FloatTensor, class Phi4MMImagePixelInputs(TypedDict): type: Literal["pixel_values"] - data: Union[torch.Tensor, List[torch.Tensor]] + data: Union[torch.Tensor, list[torch.Tensor]] """ Shape: `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` @@ -415,18 +415,9 @@ class Phi4MMImagePixelInputs(TypedDict): """Shape: `(batch_size * num_images, H_mask, W_mask)`""" -class Phi4MMImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: Union[torch.Tensor, List[torch.Tensor]] - """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` - - `hidden_size` must match the hidden size of language model backbone. 
- """ - - class Phi4MMAudioFeatureInputs(TypedDict): type: Literal["audio_features"] - data: Union[torch.Tensor, List[torch.Tensor]] + data: Union[torch.Tensor, list[torch.Tensor]] """Shape: `(batch_size * num_audios, 80, M)""" @@ -436,7 +427,6 @@ class Phi4MMAudioEmbeddingInputs(TypedDict): """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" -Phi4MMImageInput = Union[Phi4MMImagePixelInputs, Phi4MMImageEmbeddingInputs] Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs] @@ -1031,7 +1021,7 @@ def _process_audio_input(self, audio_input: Phi4MMAudioInputs, return audio_embeds def _parse_and_validate_image_input(self, - **kwargs: object) -> Optional[Dict]: + **kwargs: object) -> Optional[dict]: input_image_embeds: NestedTensors = kwargs.get("input_image_embeds") if input_image_embeds is None: return None @@ -1112,15 +1102,13 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: def _process_image_input( self, image_input: Phi4MMImagePixelInputs) -> list[torch.Tensor]: - if image_input["type"] == "image_embeds": - image_embeds = image_input["image_embeds"].type(self.visual.dtype) - else: - dtype = next(self.vision_encoder.parameters()).dtype - pixel_values = image_input['data'].to(dtype) - image_sizes = image_input['image_sizes'] - image_attention_mask = image_input['image_attention_mask'] - image_embeds = self.vision_encoder(pixel_values, image_sizes, - image_attention_mask) + + dtype = next(self.vision_encoder.parameters()).dtype + pixel_values = image_input['data'].to(dtype) + image_sizes = image_input['image_sizes'] + image_attention_mask = image_input['image_attention_mask'] + image_embeds = self.vision_encoder(pixel_values, image_sizes, + image_attention_mask) return image_embeds def get_multimodal_embeddings( @@ -1238,11 +1226,9 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, + def load_weights(self, weights: Iterable[tuple[str, 
torch.Tensor]]) -> None: - weights = ((name, data) for name, data in weights - if "lora" not in name) - loader = AutoWeightsLoader(self) + loader = AutoWeightsLoader(self, skip_substrs=["lora"]) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) def get_mm_mapping(self) -> MultiModelKeys: diff --git a/vllm/model_executor/models/phi4mm_audio.py b/vllm/model_executor/models/phi4mm_audio.py index 34a7a73d057a..98cef75069ae 100644 --- a/vllm/model_executor/models/phi4mm_audio.py +++ b/vllm/model_executor/models/phi4mm_audio.py @@ -6,7 +6,7 @@ #!/usr/bin/env python3 import abc import math -from typing import List, Literal, Optional +from typing import Literal, Optional import numpy as np import torch @@ -91,9 +91,9 @@ class ConformerEncoderLayer(nn.Module): if set to True, use GLULinear module, otherwise, used GLUPointWiseConv module. default to False. - attention_innner_dim: int, optional + attention_inner_dim: int, optional if equal to -1, attention dim for linears k/q/v is - equal to d_model. otherwise attention_innner_dim is used. + equal to d_model. otherwise attention_inner_dim is used. default -1. 
attention_glu_type: str, optional activation function for glu used in the multihead attention, @@ -148,7 +148,7 @@ def __init__( conv_glu_type="sigmoid", bias_in_glu=True, linear_glu_in_convm=False, - attention_innner_dim=-1, + attention_inner_dim=-1, attention_glu_type="swish", activation_checkpointing="", export=False, @@ -169,7 +169,7 @@ def __init__( n_head, d_model, dropout_rate, - attention_innner_dim, + attention_inner_dim, attention_glu_type, bias_in_glu, use_pt_scaled_dot_product_attention= @@ -746,7 +746,7 @@ class ConformerEncoder(TransformerEncoderBase): attention_group_size = attenion_heads = Multi-Query Attention """ - extra_multi_layer_output_idxs: List[int] + extra_multi_layer_output_idxs: list[int] def __init__( # pylint: disable-all self, diff --git a/vllm/model_executor/models/phi4mm_utils.py b/vllm/model_executor/models/phi4mm_utils.py index 4051763cec8c..f468fdbd5417 100644 --- a/vllm/model_executor/models/phi4mm_utils.py +++ b/vllm/model_executor/models/phi4mm_utils.py @@ -5,7 +5,7 @@ # but implemented by the Phi-Speech team #!/usr/bin/env python3 import math -from typing import Optional, Tuple, Union +from typing import Optional, Union import torch import torch.nn.functional as F @@ -1586,7 +1586,7 @@ def forward( memory: Optional[Tensor] = None, pos_emb: Optional[Tensor] = None, att_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + ) -> tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: """AttModule forward Args: diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 2dc55e4c352e..d9917c26d1b1 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -22,7 +22,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only PhiMoE model.""" -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch from torch import nn @@ -505,8 +506,8 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -521,7 +522,7 @@ def load_weights(self, weights: Iterable[Tuple[str, num_experts=self.config.num_local_experts) params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if (self.quant_config is not None and (scale_name := self.quant_config.get_cache_scale(name))): @@ -657,10 +658,7 @@ def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: - loader = AutoWeightsLoader( - self, - skip_prefixes=(["rotary_emb.inv_freq"]), - ) + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index c0b492dbfcb9..9f28d4cef425 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -4,12 +4,14 @@ from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass, fields from functools import cached_property -from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union +from typing import Literal, Optional, TypedDict, Union import torch import torch.nn as nn import torch.nn.functional as F -from mistral_common.protocol.instruct.messages import ImageChunk +from 
mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk, + UserMessage) +from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.tokens.tokenizers.multimodal import ImageEncoder from PIL import Image from transformers import PixtralVisionConfig, TensorType @@ -39,7 +41,7 @@ BaseProcessingInfo, MultiModalHashes, PromptReplacement, PromptUpdate, PromptUpdateDetails) -from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import (MistralTokenizer, cached_tokenizer_from_config) @@ -65,14 +67,14 @@ class PixtralImagePixelInputs(TypedDict): """ Shape: `(batch_size * num_images, num_channels, image_width, image_height)` - The result of stacking {attr}`ImageEncoding.tokens` from each prompt. + The result of stacking `ImageEncoding.tokens` from each prompt. """ class PixtralProcessorAdapter: """ Provide a HF-compatible interface for - {class}`mistral_common.tokens.tokenizers.multimodal.ImageEncoder`. + `mistral_common.tokens.tokenizers.multimodal.ImageEncoder`. 
""" def __init__(self, tokenizer: MistralTokenizer) -> None: @@ -224,6 +226,28 @@ def get_dummy_mm_data( num_images=num_images) } + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + tokenizer = self.info.get_tokenizer() + + dummy_text = self.get_dummy_text(mm_counts) + dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts) + dummy_images = dummy_mm_data.get("image", []) + + request = ChatCompletionRequest(messages=[ + UserMessage(content=[ + TextChunk(text=dummy_text), + *(ImageChunk(image=image) for image in dummy_images), + ]), + ]) + res = tokenizer.mistral.encode_chat_completion(request) + dummy_tokens = res.tokens + + return ProcessorInputs(prompt=dummy_tokens, mm_data=dummy_mm_data) + class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] ): @@ -275,8 +299,12 @@ def _cached_apply_hf_processor( *, return_mm_hashes: bool, ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: - prompt_ids, mm_kwargs, mm_hashes, _ = super( - )._cached_apply_hf_processor( + ( + prompt_ids, + mm_kwargs, + mm_hashes, + _, + ) = super()._cached_apply_hf_processor( prompt=prompt, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, @@ -438,18 +466,18 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - def is_vision_encoder_weights(weight: Tuple[str, torch.Tensor]): + def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]): return weight[0].startswith("vision_encoder") - def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]): + def is_vision_lang_adapter_weights(weight: tuple[str, torch.Tensor]): return weight[0].startswith("vision_language_adapter") - def is_patch_merger(weight: Tuple[str, torch.Tensor]): + def is_patch_merger(weight: 
tuple[str, torch.Tensor]): return weight[0].startswith("patch_merger") - def is_pre_mm_projector_norm(weight: Tuple[str, torch.Tensor]): + def is_pre_mm_projector_norm(weight: tuple[str, torch.Tensor]): return weight[0].startswith("pre_mm_projector_norm") # Get references to parameters for direct loading @@ -566,7 +594,7 @@ def apply_rotary_emb_vit( xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) assert freqs_cis.dtype == torch.complex64 @@ -671,7 +699,7 @@ def forward( return x -def position_meshgrid(patch_embeds_list: List[torch.Tensor], ) -> torch.Tensor: +def position_meshgrid(patch_embeds_list: list[torch.Tensor], ) -> torch.Tensor: positions = torch.cat([ torch.stack( torch.meshgrid( @@ -733,7 +761,7 @@ def freqs_cis(self) -> torch.Tensor: def forward( self, - images: List[torch.Tensor], + images: list[torch.Tensor], ) -> torch.Tensor: """ Args: @@ -1023,7 +1051,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: torch.Tensor, position_embeddings: torch.Tensor, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: batch, patches, _ = hidden_states.size() qkv_states, _ = self.qkv_proj(hidden_states) @@ -1249,8 +1277,8 @@ def forward( # (TODO) Add prefix argument for filtering out weights to be loaded # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -1260,7 +1288,7 @@ def load_weights(self, weights: Iterable[Tuple[str, (".gate_up_proj", ".up_proj", 1), ] params_dict = 
dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() layer_count = len(self.transformer.layers) for name, loaded_weight in weights: diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 790c48ccd216..55a65f8078a4 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 """Inference-only PLaMo2 model.""" import math -from typing import Iterable, Optional, Tuple +from collections.abc import Iterable +from typing import Optional import torch from torch import nn @@ -659,7 +660,7 @@ def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) def _get_mamba_cache_shape( - self) -> Tuple[Tuple[int, int], Tuple[int, int]]: + self) -> tuple[tuple[int, int], tuple[int, int]]: world_size = get_tensor_model_parallel_world_size() hidden_size = (self.config.mamba_num_heads * self.config.hidden_size_per_head) @@ -682,7 +683,7 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index c10ef45440b1..40ac5e30a368 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -16,7 +16,7 @@ # limitations under the License. 
"""Inference-only IBM/NASA Prithvi Geospatial model.""" from collections.abc import Iterable, Mapping, Sequence -from typing import Optional, Set, Tuple, Union +from typing import Optional, Union import torch import torch.nn as nn @@ -154,7 +154,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): "by PrithviGeospatialMAE.") def _parse_and_validate_multimodal_data( - self, **kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]: pixel_values = kwargs.pop("pixel_values", None) if not isinstance(pixel_values, torch.Tensor): @@ -195,8 +195,8 @@ def pooler( ) -> Optional[PoolerOutput]: return PoolerOutput([PoolingSequenceGroupOutput(hidden_states)]) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: params_list = [] model_buffers = dict(self.named_buffers()) loaded_buffers = [] diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index e75294bc6cba..2fda87a4ff0f 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -6,7 +6,8 @@ # LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE """Inference-only QWen model compatible with HuggingFace weights.""" import json -from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Any, Optional, Union import torch from torch import nn @@ -76,7 +77,7 @@ def __init__( num_heads: int, max_position_embeddings: int, rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, + rope_scaling: Optional[dict[str, Any]] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -166,7 +167,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - ) -> 
Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: residual = hidden_states @@ -284,15 +285,15 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "w2", 0), ("gate_up_proj", "w1", 1), ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index f76f31c9fc8d..0d0d98c59dbc 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -23,7 +23,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Qwen2 model compatible with HuggingFace weights.""" -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Any, Optional, Union import torch from torch import nn @@ -53,7 +54,7 @@ from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, - is_pp_missing_parameter, + extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -99,17 +100,20 @@ def forward(self, x): class Qwen2Attention(nn.Module): - def __init__(self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - max_position: int = 4096 * 32, - rope_theta: float = 10000, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - rope_scaling: Optional[Tuple] = None, - prefix: str = "", - attn_type: str = AttentionType.DECODER) -> None: + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[tuple] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + dual_chunk_attention_config: Optional[dict[str, Any]] = None, + ) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -131,6 +135,7 @@ def __init__(self, self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta + self.dual_chunk_attention_config = dual_chunk_attention_config self.qkv_proj = QKVParallelLinear( hidden_size, @@ -155,15 +160,21 @@ def __init__(self, max_position=max_position, base=self.rope_theta, rope_scaling=rope_scaling, + dual_chunk_attention_config=dual_chunk_attention_config, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - 
num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - attn_type=attn_type) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + attn_type=attn_type, + prefix=f"{prefix}.attn", + **{ + "layer_idx": extract_layer_index(prefix), + "dual_chunk_attention_config": dual_chunk_attention_config, + } if dual_chunk_attention_config else {}) def forward( self, @@ -192,6 +203,9 @@ def __init__( # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 1000000) rope_scaling = getattr(config, "rope_scaling", None) + dual_chunk_attention_config = getattr(config, + "dual_chunk_attention_config", + None) # By default, Qwen2 uses causal attention as it is a decoder-only model. # You can override the HF config with `is_causal=False` to enable @@ -213,6 +227,7 @@ def __init__( rope_scaling=rope_scaling, prefix=f"{prefix}.self_attn", attn_type=attn_type, + dual_chunk_attention_config=dual_chunk_attention_config, ) self.mlp = Qwen2MLP( hidden_size=self.hidden_size, @@ -231,7 +246,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: residual = hidden_states @@ -276,14 +291,14 @@ def __init__(self, # TODO (@robertgshaw2): see if this can be moved out if (cache_config.sliding_window is not None and hasattr(config, "max_window_layers")): - raise ValueError("Sliding window for some but all layers is not " - "supported. This model uses sliding window " - "but `max_window_layers` = {} is less than " - "`num_hidden_layers` = {}. 
Please open an issue " - "to discuss this feature.".format( - config.max_window_layers, - config.num_hidden_layers, - )) + assert config.max_window_layers == config.num_hidden_layers, ( + "Sliding window for some but all layers is not supported. " + "This model uses sliding window but `max_window_layers` = {} " + "is less than `num_hidden_layers` = {}. Please open an issue " + "to discuss this feature.".format( + config.max_window_layers, + config.num_hidden_layers, + )) self.config = config self.quant_config = quant_config @@ -353,8 +368,8 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -364,7 +379,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -476,8 +491,8 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, skip_prefixes=(["lm_head."] @@ -545,7 +560,7 @@ def pooler( ) -> Optional[PoolerOutput]: return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): weights = self.hf_to_vllm_mapper.apply(weights) weights = ((name, data) for name, data in weights if not name.startswith("lm_head.")) diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py 
b/vllm/model_executor/models/qwen2_5_omni_thinker.py index d8e178f9cd47..d89b822dd873 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -21,10 +21,10 @@ # limitations under the License. """Inference-only Qwen2.5-Omni model (thinker part).""" +from collections.abc import Iterable, Mapping, Sequence from copy import copy from functools import partial -from typing import (Any, Dict, Iterable, List, Mapping, Optional, Sequence, - Set, Tuple, Union) +from typing import Any, Optional, Union import torch import torch.nn as nn @@ -138,7 +138,7 @@ def get_hf_processor( min_pixels: Optional[int] = None, max_pixels: Optional[int] = None, size: Optional[dict[str, int]] = None, - fps: Optional[Union[float, List[float]]] = None, + fps: Optional[Union[float, list[float]]] = None, **kwargs: object, ) -> Qwen2_5OmniProcessor: if fps is not None: @@ -550,7 +550,7 @@ def _parse_and_validate_audio_input( def _parse_and_validate_image_input( self, - **kwargs: Dict[str, Any], + **kwargs: dict[str, Any], ) -> Optional[Qwen2_5_VLImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) @@ -589,7 +589,7 @@ def _parse_and_validate_image_input( def _parse_and_validate_video_input( self, - **kwargs: Dict[str, Any], + **kwargs: dict[str, Any], ) -> Optional[Qwen2_5_VLVideoInputs]: pixel_values_videos = kwargs.pop("pixel_values_videos", None) video_embeds = kwargs.pop("video_embeds", None) @@ -627,7 +627,7 @@ def _parse_and_validate_video_input( def _process_audio_input( self, audio_input: Qwen2AudioInputs, - audio_hashes: List[str] = None, + audio_hashes: list[str] = None, cached_audio_features: torch.Tensor = None, ) -> torch.Tensor: @@ -676,7 +676,7 @@ def _process_image_input( def _process_video_input( self, video_input: Qwen2_5_VLVideoInputs, - video_hashes: List[str] = None, + video_hashes: list[str] = None, cached_video_embeds: torch.Tensor = None) -> 
torch.Tensor: if video_input["type"] == "video_embeds": return video_input["video_embeds"].type(self.visual.dtype) @@ -825,7 +825,7 @@ def get_multimodal_embeddings_v0( if audio_input is None and image_input is None and video_input is None: return None - multimodal_embeddings: List[Tuple[NestedTensors, str]] = [] + multimodal_embeddings: list[tuple[NestedTensors, str]] = [] if audio_input is not None: audio_embeds = self._process_audio_input(audio_input) @@ -891,8 +891,8 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, skip_prefixes=["talker.", "token2wav."], diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 8728de95134d..68dd07820189 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -24,9 +24,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2.5-VL model compatible with HuggingFace weights.""" -from functools import partial -from typing import (Callable, Iterable, List, Literal, Mapping, Optional, Set, - Tuple, TypedDict, Union) +from collections.abc import Iterable, Mapping +from functools import lru_cache, partial +from typing import Callable, Literal, Optional, TypedDict, Union import torch import torch.nn as nn @@ -91,7 +91,7 @@ class Qwen2_5_VLImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] image_embeds: torch.Tensor """Supported types: - - List[`torch.Tensor`]: A list of tensors holding all images' features. + - list[`torch.Tensor`]: A list of tensors holding all images' features. Each tensor holds an image's features. 
- `torch.Tensor`: A tensor holding all images' features (concatenation of all images' feature tensors). @@ -137,7 +137,7 @@ class Qwen2_5_VLVideoEmbeddingInputs(TypedDict): type: Literal["video_embeds"] video_embeds: torch.Tensor """Supported types: - - List[`torch.Tensor`]: A list of tensors holding all videos' features. + - list[`torch.Tensor`]: A list of tensors holding all videos' features. Each tensor holds an video's features. - `torch.Tensor`: A tensor holding all videos' features (concatenation of all videos' feature tensors). @@ -478,8 +478,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() self.dim = dim self.theta = theta - inv_freq = 1.0 / (theta - **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) + inv_freq = 1.0 / (theta**( + torch.arange(0, dim, 2, dtype=torch.float, device='cpu') / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) self._seq_len_cached = 0 self._freqs_cached = None @@ -520,7 +520,7 @@ def __init__( self.hidden_size = vision_config.hidden_size self.num_heads = vision_config.num_heads - # args for get_window_index + # args for get_window_index_thw self.window_size = vision_config.window_size self.patch_size = vision_config.patch_size self.spatial_merge_size = vision_config.spatial_merge_size @@ -567,65 +567,71 @@ def dtype(self) -> torch.dtype: def device(self) -> torch.device: return self.patch_embed.proj.weight.device - def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: - pos_ids = [] - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten() - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 
3).flatten() - pos_ids.append( - torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + def rotary_pos_emb_thw(self, t, h, w): + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + pos_ids = torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1) + max_size = max(h, w) + rotary_pos_emb_full = self.rotary_pos_emb(max_size) rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + rotary_pos_emb = rotary_pos_emb.reshape( + rotary_pos_emb.shape[0] // self.spatial_merge_unit, + self.spatial_merge_unit, -1) + return rotary_pos_emb - def get_window_index(self, grid_thw): - window_index: list = [] - cu_window_seqlens: list = [0] - window_index_id = 0 + def get_window_index_thw(self, grid_t, grid_h, grid_w): vit_merger_window_size = (self.window_size // self.spatial_merge_size // self.patch_size) - for grid_t, grid_h, grid_w in grid_thw: - llm_grid_h = grid_h // self.spatial_merge_size - llm_grid_w = grid_w // self.spatial_merge_size - index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( - grid_t, llm_grid_h, llm_grid_w) - pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size - pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size - num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size - num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size - index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100) - index_padded = index_padded.reshape(grid_t, 
num_windows_h, - vit_merger_window_size, - num_windows_w, - vit_merger_window_size) - index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( - grid_t, num_windows_h * num_windows_w, vit_merger_window_size, - vit_merger_window_size) - seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) - index_padded = index_padded.reshape(-1) - index_new = index_padded[index_padded != -100] - window_index.append(index_new + window_index_id) - cu_seqlens_tmp = seqlens.cumsum( - 0) * self.spatial_merge_unit + cu_window_seqlens[-1] - cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) - window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() - window_index = torch.cat(window_index, dim=0) - return window_index, cu_window_seqlens + llm_grid_h = grid_h // self.spatial_merge_size + llm_grid_w = grid_w // self.spatial_merge_size + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( + grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100) + index_padded = index_padded.reshape(grid_t, num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, num_windows_h * num_windows_w, vit_merger_window_size, + vit_merger_window_size) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_seqlens_tmp = cu_seqlens_tmp.to(dtype=torch.int32) + cu_seqlens_tmp = torch.unique_consecutive(cu_seqlens_tmp) + + return index_new, cu_seqlens_tmp + + @lru_cache(maxsize=1024) # noqa: B019 + def get_rope_by_thw(self, t, h, 
w): + window_index_thw, cu_seqlens_window_thw = self.get_window_index_thw( + t, h, w) + rotary_pos_emb_thw = self.rotary_pos_emb_thw(t, h, w) + rotary_pos_emb_thw = rotary_pos_emb_thw[window_index_thw, :, :] + rotary_pos_emb_thw = rotary_pos_emb_thw.flatten(start_dim=0, end_dim=1) + cu_seqlens_thw = torch.repeat_interleave( + torch.tensor([h * w], dtype=torch.int32), t) + return (rotary_pos_emb_thw, window_index_thw, cu_seqlens_window_thw, + cu_seqlens_thw) def compute_attn_mask_seqlen( self, @@ -641,45 +647,74 @@ def compute_attn_mask_seqlen( def forward( self, x: torch.Tensor, - grid_thw: torch.Tensor, + grid_thw: list[list[int]], ) -> torch.Tensor: # patchify + seq_len, _ = x.size() + rotary_pos_emb = [] + window_index: list = [] + cu_window_seqlens: list = [torch.tensor([0], dtype=torch.int32)] + cu_seqlens: list = [] + hidden_states = x.to(device=self.device, dtype=self.dtype) hidden_states = self.patch_embed(hidden_states) - # compute position embedding - rotary_pos_emb = self.rot_pos_emb(grid_thw) + window_index_id = 0 + cu_window_seqlens_last = 0 + for t, h, w in grid_thw: + t, h, w = int(t), int(h), int(w) + llm_h = h // self.spatial_merge_size + llm_w = w // self.spatial_merge_size + + ( + rotary_pos_emb_thw, + window_index_thw, + cu_seqlens_window_thw, + cu_seqlens_thw, + ) = self.get_rope_by_thw(t, h, w) + + window_index.append(window_index_thw + window_index_id) + window_index_id += (t * llm_h * llm_w) + + cu_seqlens_window_thw = (cu_seqlens_window_thw + + cu_window_seqlens_last) + cu_window_seqlens_last = cu_seqlens_window_thw[-1] + cu_window_seqlens.append(cu_seqlens_window_thw) - # windows attention - window_index, cu_window_seqlens = self.get_window_index(grid_thw) - cu_window_seqlens = torch.tensor( - cu_window_seqlens, - device=hidden_states.device, - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32) + rotary_pos_emb.append(rotary_pos_emb_thw) + + cu_seqlens.append(cu_seqlens_thw) + + rotary_pos_emb = torch.cat(rotary_pos_emb) + 
window_index = torch.cat(window_index) + cu_window_seqlens = torch.cat(cu_window_seqlens) cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) - seq_len, _ = hidden_states.size() - hidden_states = hidden_states.reshape( - seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) - hidden_states = hidden_states[window_index, :, :] - hidden_states = hidden_states.reshape(seq_len, -1) - rotary_pos_emb = rotary_pos_emb.reshape( - seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) - rotary_pos_emb = rotary_pos_emb[window_index, :, :] - rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) - # compute cu_seqlens - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], - grid_thw[:, 0]).cumsum( - dim=0, dtype=torch.int32) + cu_seqlens = torch.cat(cu_seqlens) + cu_seqlens = torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32) cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) # transformers - hidden_states = hidden_states.unsqueeze(1) - # pre-compute seqlens for window/full attn to reduce cuMemcpy operations max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen( cu_seqlens) max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen( cu_window_seqlens) + + cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True) + cu_window_seqlens = cu_window_seqlens.to(device=self.device, + non_blocking=True) + rotary_pos_emb = rotary_pos_emb.to(device=self.device, + non_blocking=True) + window_index = window_index.to(device=hidden_states.device, + non_blocking=True) + + hidden_states = hidden_states.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + + hidden_states = hidden_states.unsqueeze(1) + for layer_num, blk in enumerate(self.blocks): if layer_num in self.fullatt_block_indexes: cu_seqlens_now = cu_seqlens @@ -709,8 +744,8 @@ def forward( hidden_states = 
hidden_states[reverse_indices, :] return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("attn.qkv.", "attn.q.", "q"), @@ -718,7 +753,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ("attn.qkv.", "attn.v.", "v"), ] params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: for (param_name, weight_name, shard_id) in stacked_params_mapping: @@ -750,7 +785,7 @@ def get_hf_processor( min_pixels: Optional[int] = None, max_pixels: Optional[int] = None, size: Optional[dict[str, int]] = None, - fps: Optional[Union[float, List[float]]] = None, + fps: Optional[Union[float, list[float]]] = None, **kwargs: object, ) -> Qwen2_5_VLProcessor: if fps is not None: @@ -932,12 +967,13 @@ def _process_image_input( grid_thw = image_input["image_grid_thw"] assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() if image_input["type"] == "image_embeds": image_embeds = image_input["image_embeds"].type(self.visual.dtype) else: pixel_values = image_input["pixel_values"].type(self.visual.dtype) - image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list) # Split concatenated embeddings for each image item. 
merge_size = self.visual.spatial_merge_size @@ -951,13 +987,15 @@ def _process_video_input( grid_thw = video_input["video_grid_thw"] assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() if video_input["type"] == "video_embeds": video_embeds = video_input["video_embeds"].type(self.visual.dtype) else: pixel_values_videos = video_input["pixel_values_videos"].type( self.visual.dtype) - video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) + video_embeds = self.visual(pixel_values_videos, + grid_thw=grid_thw_list) # Split concatenated embeddings for each video item. merge_size = self.visual.spatial_merge_size @@ -1116,8 +1154,8 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index f30bf08ab18b..3182a7532578 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -22,7 +22,7 @@ # limitations under the License. 
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Optional, Set, Tuple, TypedDict, Union +from typing import Any, Optional, TypedDict, Union import torch import torch.nn as nn @@ -403,7 +403,7 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 47d90919ed8f..143b9f98b029 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -23,7 +23,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2MoE model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Any, Optional, Union import torch import torch.nn.functional as F @@ -33,9 +34,7 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE @@ -129,7 +128,8 @@ def __init__( intermediate_size=config.shared_expert_intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=False, + 
reduce_results=self.experts.must_reduce_shared_expert_outputs( + ), ) else: self.shared_expert = None @@ -156,7 +156,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if shared_output is not None: final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 final_hidden_states) return final_hidden_states.view(orig_shape) @@ -170,11 +170,12 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, + rope_scaling: Optional[dict[str, Any]] = None, max_position_embeddings: int = 8192, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + dual_chunk_attention_config: Optional[dict[str, Any]] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -198,6 +199,7 @@ def __init__( self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings + self.dual_chunk_attention_config = dual_chunk_attention_config self.qkv_proj = QKVParallelLinear( hidden_size, @@ -221,14 +223,20 @@ def __init__( max_position=max_position_embeddings, base=rope_theta, rope_scaling=rope_scaling, + dual_chunk_attention_config=dual_chunk_attention_config, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + **{ + "layer_idx": extract_layer_index(prefix), + "dual_chunk_attention_config": dual_chunk_attention_config, + } if dual_chunk_attention_config else {}) def forward( self, @@ -256,6 
+264,9 @@ def __init__( self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) + dual_chunk_attention_config = getattr(config, + "dual_chunk_attention_config", + None) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = Qwen2MoeAttention( @@ -268,6 +279,7 @@ def __init__( cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", + dual_chunk_attention_config=dual_chunk_attention_config, ) # Note: Qwen/Qwen2-57B-A14B-Instruct does not have @@ -378,8 +390,8 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -398,7 +410,7 @@ def load_weights(self, weights: Iterable[Tuple[str, num_experts=self.config.num_experts) params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: for (param_name, weight_name, shard_id) in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). 
@@ -521,10 +533,7 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: - loader = AutoWeightsLoader( - self, - skip_prefixes=(["rotary_emb.inv_freq"]), - ) + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index 90f799e6734e..81dc38988c9d 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -5,7 +5,8 @@ # Copyright 2024 The Qwen team. # Copyright 2023 The vLLM team. """Inference-only Qwen2-RM model compatible with HuggingFace weights.""" -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch from torch import nn @@ -95,8 +96,8 @@ def pooler( ) -> Optional[PoolerOutput]: return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self, ignore_unexpected_prefixes=["lm_head."]) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index ac0a6de523df..0ff0836b0897 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -25,8 +25,7 @@ """Inference-only Qwen2-VL model compatible with HuggingFace weights.""" from collections.abc import Iterable, Mapping, Sequence from functools import partial -from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict, - Union) +from typing import Any, Callable, Literal, Optional, TypedDict, Union import torch import torch.nn as nn @@ -102,7 +101,7 @@ class Qwen2VLImageEmbeddingInputs(TypedDict): type: 
Literal["image_embeds"] image_embeds: torch.Tensor """Supported types: - - List[`torch.Tensor`]: A list of tensors holding all images' features. + - list[`torch.Tensor`]: A list of tensors holding all images' features. Each tensor holds an image's features. - `torch.Tensor`: A tensor holding all images' features (concatenation of all images' feature tensors). @@ -142,7 +141,7 @@ class Qwen2VLVideoEmbeddingInputs(TypedDict): type: Literal["video_embeds"] video_embeds: torch.Tensor """Supported types: - - List[`torch.Tensor`]: A list of tensors holding all videos' features. + - list[`torch.Tensor`]: A list of tensors holding all videos' features. Each tensor holds an video's features. - `torch.Tensor`: A tensor holding all videos' features (concatenation of all videos' feature tensors). @@ -662,8 +661,8 @@ def forward( return x - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -671,7 +670,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ("qkv_proj", "v_proj", "v"), ] params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: for (param_name, weight_name, shard_id) in stacked_params_mapping: @@ -1394,8 +1393,8 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index 40e0ccc1bab6..dbe2be8a73d5 100644 --- a/vllm/model_executor/models/qwen3.py 
+++ b/vllm/model_executor/models/qwen3.py @@ -21,7 +21,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen3 model compatible with HuggingFace weights.""" -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch from torch import nn @@ -63,7 +64,7 @@ def __init__(self, rope_theta: float = 10000, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - rope_scaling: Optional[Tuple] = None, + rope_scaling: Optional[tuple] = None, prefix: str = "", attn_type: str = AttentionType.DECODER) -> None: super().__init__() @@ -201,7 +202,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: residual = hidden_states @@ -309,8 +310,8 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, skip_prefixes=(["lm_head."] diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index fe6b303ba0b5..8a4c2850dda3 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -21,7 +21,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Qwen3MoE model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Any, Optional, Union import torch from torch import nn @@ -30,9 +31,7 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE @@ -137,7 +136,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: router_logits=router_logits) final_hidden_states = final_hidden_states if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 final_hidden_states) return final_hidden_states.view(orig_shape) @@ -151,7 +150,7 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, + rope_scaling: Optional[dict[str, Any]] = None, max_position_embeddings: int = 8192, head_dim: Optional[int] = None, rms_norm_eps: float = 1e-06, @@ -375,8 +374,8 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -395,7 +394,7 @@ def load_weights(self, weights: Iterable[Tuple[str, num_experts=self.config.num_experts) params_dict = dict(self.named_parameters()) - loaded_params: 
Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: for (param_name, weight_name, shard_id) in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). @@ -475,6 +474,17 @@ def load_weights(self, weights: Iterable[Tuple[str, class Qwen3MoeForCausalLM(nn.Module, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } fall_back_to_pt_during_load = False @@ -518,10 +528,7 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: - loader = AutoWeightsLoader( - self, - skip_prefixes=(["rotary_emb.inv_freq"]), - ) + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index 199b885a5850..f5d242fdf1c2 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -7,13 +7,12 @@ import copy import math -import re import unicodedata -from collections.abc import Collection, Mapping, Sequence -from collections.abc import Set as AbstractSet +from collections.abc import Collection, Mapping, Sequence, Set from functools import lru_cache, partial -from typing import Callable, List, Literal, Optional, TypedDict, Union +from typing import Callable, Literal, Optional, TypedDict, Union +import regex as re import torch from torch import nn from torchvision import transforms @@ -383,7 +382,8 @@ def _get_tokenizer_without_image_pad( tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer: """ The logic of adding image pad tokens should only be applied in - {class}`QwenVLProcessor`, so they are patched out here. + [`QwenVLProcessor`][vllm.model_executor.models.qwen_vl.QwenVLProcessor], + so they are patched out here. 
The definition of the wrapped tokenizer can be found here: https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py @@ -395,7 +395,7 @@ class TokenizerWithoutImagePad(tokenizer.__class__): # type: ignore def tokenize( self, text: str, - allowed_special: Union[AbstractSet[str], str] = "all", + allowed_special: Union[Set[str], str] = "all", disallowed_special: Union[Collection[str], str] = (), **kwargs, ) -> list[Union[bytes, str]]: @@ -411,7 +411,7 @@ def tokenize( def _decode( self, - token_ids: Union[int, List[int]], + token_ids: Union[int, list[int]], skip_special_tokens: bool = False, errors: Optional[str] = None, **kwargs, diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 19153efd8e17..97ea12de6537 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -10,10 +10,10 @@ import sys import tempfile from abc import ABC, abstractmethod +from collections.abc import Set from dataclasses import dataclass, field from functools import lru_cache -from typing import (AbstractSet, Callable, Dict, List, Optional, Tuple, Type, - TypeVar, Union) +from typing import Callable, Optional, TypeVar, Union import cloudpickle import torch.nn as nn @@ -79,6 +79,7 @@ "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "MambaForCausalLM": ("mamba", "MambaForCausalLM"), "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"), + "FalconH1ForCausalLM":("falcon_h1", "FalconH1ForCausalLM"), "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"), "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"), @@ -88,6 +89,7 @@ # transformers's mpt class has lower case "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), + "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"), "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"), "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), "Olmo2ForCausalLM": ("olmo2", 
"Olmo2ForCausalLM"), @@ -126,7 +128,8 @@ "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"), "GritLM": ("gritlm", "GritLM"), - "GteModel": ("bert", "GteEmbeddingModel"), + "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"), + "GteNewModel": ("bert_with_rope", "GteNewModel"), "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"), "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501 "LlamaModel": ("llama", "LlamaForCausalLM"), @@ -136,7 +139,8 @@ if arch == "LlamaForCausalLM" }, "MistralModel": ("llama", "LlamaForCausalLM"), - "NomicBertModel": ("bert", "NomicBertEmbeddingModel"), + "ModernBertModel": ("modernbert", "ModernBertModel"), + "NomicBertModel": ("bert_with_rope", "NomicBertModel"), "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), @@ -195,7 +199,7 @@ "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"), # noqa: E501 "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"), "NVLM_D": ("nvlm_d", "NVLM_D_Model"), - "Ovis2ForConditionalGeneration": ("ovis2", "Ovis2ForConditionalGeneration"), + "Ovis": ("ovis", "Ovis"), "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501 @@ -204,6 +208,7 @@ "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), # noqa: E501 "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501 "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501 + "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501 "UltravoxModel": 
("ultravox", "UltravoxModel"), "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"), # [Encoder-decoder] @@ -215,6 +220,7 @@ } _SPECULATIVE_DECODING_MODELS = { + "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"), "EAGLEModel": ("eagle", "EAGLE"), "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"), "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), @@ -262,7 +268,7 @@ class _ModelInfo: supports_v0_only: bool @staticmethod - def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": + def from_model_cls(model: type[nn.Module]) -> "_ModelInfo": return _ModelInfo( architecture=model.__name__, is_text_generation_model=is_text_generation_model(model), @@ -286,7 +292,7 @@ def inspect_model_cls(self) -> _ModelInfo: raise NotImplementedError @abstractmethod - def load_model_cls(self) -> Type[nn.Module]: + def load_model_cls(self) -> type[nn.Module]: raise NotImplementedError @@ -297,10 +303,10 @@ class _RegisteredModel(_BaseRegisteredModel): """ interfaces: _ModelInfo - model_cls: Type[nn.Module] + model_cls: type[nn.Module] @staticmethod - def from_model_cls(model_cls: Type[nn.Module]): + def from_model_cls(model_cls: type[nn.Module]): return _RegisteredModel( interfaces=_ModelInfo.from_model_cls(model_cls), model_cls=model_cls, @@ -309,7 +315,7 @@ def from_model_cls(model_cls: Type[nn.Module]): def inspect_model_cls(self) -> _ModelInfo: return self.interfaces - def load_model_cls(self) -> Type[nn.Module]: + def load_model_cls(self) -> type[nn.Module]: return self.model_cls @@ -326,7 +332,7 @@ def inspect_model_cls(self) -> _ModelInfo: return _run_in_subprocess( lambda: _ModelInfo.from_model_cls(self.load_model_cls())) - def load_model_cls(self) -> Type[nn.Module]: + def load_model_cls(self) -> type[nn.Module]: mod = importlib.import_module(self.module_name) return getattr(mod, self.class_name) @@ -335,7 +341,7 @@ def load_model_cls(self) -> Type[nn.Module]: def _try_load_model_cls( model_arch: str, model: _BaseRegisteredModel, -) -> 
Optional[Type[nn.Module]]: +) -> Optional[type[nn.Module]]: from vllm.platforms import current_platform current_platform.verify_model_arch(model_arch) try: @@ -362,22 +368,22 @@ def _try_inspect_model_cls( @dataclass class _ModelRegistry: # Keyed by model_arch - models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict) + models: dict[str, _BaseRegisteredModel] = field(default_factory=dict) - def get_supported_archs(self) -> AbstractSet[str]: + def get_supported_archs(self) -> Set[str]: return self.models.keys() def register_model( self, model_arch: str, - model_cls: Union[Type[nn.Module], str], + model_cls: Union[type[nn.Module], str], ) -> None: """ Register an external model to be used in vLLM. `model_cls` can be either: - - A {class}`torch.nn.Module` class directly referencing the model. + - A [`torch.nn.Module`][] class directly referencing the model. - A string in the format `<module>:<class>` which can be used to lazily import the model. This is useful to avoid initializing CUDA when importing the model and thus the related error @@ -409,7 +415,7 @@ def register_model( self.models[model_arch] = model - def _raise_for_unsupported(self, architectures: List[str]): + def _raise_for_unsupported(self, architectures: list[str]): all_supported_archs = self.get_supported_archs() if any(arch in all_supported_archs for arch in architectures): @@ -422,7 +428,7 @@ def _raise_for_unsupported(self, architectures: List[str]): f"Supported architectures: {all_supported_archs}") def _try_load_model_cls(self, - model_arch: str) -> Optional[Type[nn.Module]]: + model_arch: str) -> Optional[type[nn.Module]]: if model_arch not in self.models: return None @@ -436,8 +442,8 @@ def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]: def _normalize_archs( self, - architectures: Union[str, List[str]], - ) -> List[str]: + architectures: Union[str, list[str]], + ) -> list[str]: if isinstance(architectures, str): architectures = [architectures] if not 
architectures: @@ -454,8 +460,8 @@ def _normalize_archs( def inspect_model_cls( self, - architectures: Union[str, List[str]], - ) -> Tuple[_ModelInfo, str]: + architectures: Union[str, list[str]], + ) -> tuple[_ModelInfo, str]: architectures = self._normalize_archs(architectures) for arch in architectures: @@ -467,8 +473,8 @@ def inspect_model_cls( def resolve_model_cls( self, - architectures: Union[str, List[str]], - ) -> Tuple[Type[nn.Module], str]: + architectures: Union[str, list[str]], + ) -> tuple[type[nn.Module], str]: architectures = self._normalize_archs(architectures) for arch in architectures: @@ -480,77 +486,77 @@ def resolve_model_cls( def is_text_generation_model( self, - architectures: Union[str, List[str]], + architectures: Union[str, list[str]], ) -> bool: model_cls, _ = self.inspect_model_cls(architectures) return model_cls.is_text_generation_model def is_pooling_model( self, - architectures: Union[str, List[str]], + architectures: Union[str, list[str]], ) -> bool: model_cls, _ = self.inspect_model_cls(architectures) return model_cls.is_pooling_model def is_cross_encoder_model( self, - architectures: Union[str, List[str]], + architectures: Union[str, list[str]], ) -> bool: model_cls, _ = self.inspect_model_cls(architectures) return model_cls.supports_cross_encoding def is_multimodal_model( self, - architectures: Union[str, List[str]], + architectures: Union[str, list[str]], ) -> bool: model_cls, _ = self.inspect_model_cls(architectures) return model_cls.supports_multimodal def is_pp_supported_model( self, - architectures: Union[str, List[str]], + architectures: Union[str, list[str]], ) -> bool: model_cls, _ = self.inspect_model_cls(architectures) return model_cls.supports_pp def model_has_inner_state( self, - architectures: Union[str, List[str]], + architectures: Union[str, list[str]], ) -> bool: model_cls, _ = self.inspect_model_cls(architectures) return model_cls.has_inner_state def is_attention_free_model( self, - architectures: Union[str, 
List[str]], + architectures: Union[str, list[str]], ) -> bool: model_cls, _ = self.inspect_model_cls(architectures) return model_cls.is_attention_free def is_hybrid_model( self, - architectures: Union[str, List[str]], + architectures: Union[str, list[str]], ) -> bool: model_cls, _ = self.inspect_model_cls(architectures) return model_cls.is_hybrid def is_noops_model( self, - architectures: Union[str, List[str]], + architectures: Union[str, list[str]], ) -> bool: model_cls, _ = self.inspect_model_cls(architectures) return model_cls.has_noops def is_transcription_model( self, - architectures: Union[str, List[str]], + architectures: Union[str, list[str]], ) -> bool: model_cls, _ = self.inspect_model_cls(architectures) return model_cls.supports_transcription def is_v1_compatible( self, - architectures: Union[str, List[str]], + architectures: Union[str, list[str]], ) -> bool: model_cls, _ = self.inspect_model_cls(architectures) return not model_cls.supports_v0_only diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 4c23d72a4195..9a4d0ab2dd4d 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 import itertools -from typing import Iterable, Optional, Tuple +from collections.abc import Iterable +from typing import Optional, Union import torch from torch import nn @@ -19,6 +20,7 @@ from vllm.transformers_utils.config import ( get_cross_encoder_activation_function) +from .bert_with_rope import BertWithRope, JinaRobertaModel from .interfaces import SupportsCrossEncoding, SupportsV0Only @@ -125,39 +127,20 @@ class RobertaEmbeddingModel(BertEmbeddingModel): def _build_model(self, vllm_config: VllmConfig, - prefix: str = "") -> BertModel: + prefix: str = "") -> Union[BertModel, BertWithRope]: if (vllm_config.model_config.hf_config.position_embedding_type == "rotary"): - config = vllm_config.model_config.hf_config - head_dim = 
config.hidden_size // config.num_attention_heads - - rotary_kwargs = { - "head_size": head_dim, - "rotary_dim": getattr(config, "rotary_emb_dim", head_dim), - "max_position": config.max_position_embeddings, - "base": config.rotary_emb_base, - "rope_scaling": getattr(config, "rope_scaling", None) - } - - return BertModel(vllm_config=vllm_config, - rotary_kwargs=rotary_kwargs, - prefix=prefix) + return JinaRobertaModel(vllm_config=vllm_config, prefix=prefix) else: return BertModel(vllm_config=vllm_config, prefix=prefix, embedding_class=RobertaEmbedding) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - if getattr(self.config, "lora_rank", 0) > 0: - scaling = self.config.lora_alpha / self.config.lora_rank - weights = jina_merge_lora_weights(weights, scaling) - + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): weights = self.hf_to_vllm_mapper.apply(weights) # Separate weights in "roberta"-prefixed and all else (not in memory). # For use with models like FacebookAI/roberta-base. bert_weights, task_weights = roberta_task_weights_filter(weights) - bert_weights = jina_to_vllm_mapper.apply(bert_weights) - loaded = self.model.load_weights(bert_weights) if not len(loaded): # Fix for models like `sentence-transformers/stsb-roberta-base-v2` @@ -178,6 +161,18 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, _pooler: An instance of Pooler used for pooling operations. 
""" + jina_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + 'emb_ln': "embeddings.LayerNorm", + 'layers': "layer", + 'mixer.Wqkv': "attention.self.qkv_proj", + 'mixer.out_proj': "attention.output.dense", + 'norm1': "attention.output.LayerNorm", + 'mlp.fc1': "intermediate.dense", + 'mlp.fc2': "output.dense", + 'norm2': "output.LayerNorm", + }) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -193,9 +188,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.classifier = RobertaClassificationHead(config) self._pooler = CrossEncodingPooler(config, self.classifier) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): bert_weights, task_weights = roberta_task_weights_filter(weights) - bert_weights = jina_to_vllm_mapper.apply(bert_weights) + bert_weights = self.jina_to_vllm_mapper.apply(bert_weights) self.roberta.load_weights(bert_weights) @@ -255,8 +250,8 @@ def create_position_ids_from_input_ids(input_ids, def roberta_task_weights_filter( - all_weights: Iterable[Tuple[str, torch.Tensor]] -) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[str, + all_weights: Iterable[tuple[str, torch.Tensor]] +) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str, torch.Tensor]]]: """ Separate task-specific weights that are applied on top @@ -276,57 +271,3 @@ def encoder_decoder_weights(): return encoder_decoder_weights(), ((n, w) for n, w in all_weights2 if not n.startswith("roberta.")) - - -jina_to_vllm_mapper = WeightsMapper( - orig_to_new_substr={ - 'emb_ln': "embeddings.LayerNorm", - 'layers': "layer", - 'mixer.Wqkv': "attention.self.qkv_proj", - 'mixer.out_proj': "attention.output.dense", - 'norm1': "attention.output.LayerNorm", - 'mlp.fc1': "intermediate.dense", - 'mlp.fc2': "output.dense", - 'norm2': "output.LayerNorm", - }) - - -@torch.inference_mode() 
-def jina_merge_lora_weights(weights: Iterable[Tuple[str, torch.Tensor]], - scaling: float = 1.0): - # use for jina-embeddings-v3 - # Merge Lora weights into a single weight tensor. - # This is a temporary solution until we have a better way to handle - - weights = {name: weight for name, weight in weights} - - o = ".original" - a = ".0.lora_A" - b = ".0.lora_B" - - # text-matching - i = -1 - - for name in list(weights.keys()): - if o in name: - dtype = weights[name].dtype - shape = weights[name].shape - weight_name = name[:-len(o)] - - if "embeddings" in weight_name: - B = weights[weight_name + a][i].cuda().float() - A = weights[weight_name + b][i].cuda().float() - else: - B = weights[weight_name + b][i].cuda().float() - A = weights[weight_name + a][i].cuda().float() - - weight = (weights[weight_name + o].cuda() + - torch.matmul(B, A).view(shape) * scaling) - weight = weight.cpu().to(dtype) - - weights[weight_name.replace(".parametrizations", "")] = weight - - del weights[weight_name + o], weights[weight_name + - a], weights[weight_name + b] - - return [(name, weight) for name, weight in weights.items()] diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 75fcf540b0b1..3b5334afa7af 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -3,7 +3,8 @@ within a vision language model.""" import math -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch from torch import nn @@ -265,7 +266,7 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - ) -> Tuple[torch.Tensor, None]: + ) -> tuple[torch.Tensor, None]: residual = hidden_states hidden_states = self.layer_norm1(hidden_states) @@ -480,8 +481,8 @@ def forward( feature_sample_layers=feature_sample_layers, ) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: 
Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -489,7 +490,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ("qkv_proj", "v_proj", "v"), ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() layer_count = len(self.vision_model.encoder.layers) for name, loaded_weight in weights: diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py index e78c37b65f87..eefadda918f6 100644 --- a/vllm/model_executor/models/skyworkr1v.py +++ b/vllm/model_executor/models/skyworkr1v.py @@ -8,7 +8,7 @@ # -------------------------------------------------------- from abc import ABC, abstractmethod from collections.abc import Iterable, Mapping, Sequence -from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union +from typing import Literal, Optional, TypedDict, TypeVar, Union import torch import torch.nn as nn @@ -24,6 +24,7 @@ InternVisionPatchModel) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.image import convert_image_mode from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargs, NestedTensors) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, @@ -78,7 +79,7 @@ class SkyworkR1VImageEmbeddingInputs(TypedDict): def build_transform(input_size: int): MEAN, STD = IMAGENET_MEAN, IMAGENET_STD return T.Compose([ - T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), + T.Lambda(lambda img: convert_image_mode(img, 'RGB')), T.Resize((input_size, input_size), interpolation=T.InterpolationMode.BICUBIC), T.ToTensor(), @@ -937,8 +938,8 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> 
Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: skip_prefixes = [ "action_embed", "temporal_embed", "track_embed", "track_embed_decoder", "box_token", "cg_criterion", "cg_model", diff --git a/vllm/model_executor/models/smolvlm.py b/vllm/model_executor/models/smolvlm.py index 17217dc9a247..31dec55026ba 100644 --- a/vllm/model_executor/models/smolvlm.py +++ b/vllm/model_executor/models/smolvlm.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, Optional +from typing import Optional from transformers import SmolVLMProcessor @@ -21,7 +21,7 @@ class SmolVLMProcessingInfo(Idefics3ProcessingInfo): def get_hf_processor( self, *, - max_image_size: Optional[Dict[str, int]] = None, + max_image_size: Optional[dict[str, int]] = None, **kwargs: object, ) -> SmolVLMProcessor: if max_image_size is not None: diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index f86aff7ba7ef..fcd17cc1c2ba 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -23,7 +23,8 @@ # limitations under the License. 
"""Inference-only Solar model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Any, Optional, Union import torch from torch import nn @@ -49,7 +50,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (PPMissingLayer, is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -101,7 +102,7 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, + rope_scaling: Optional[dict[str, Any]] = None, max_position_embeddings: int = 8192, quant_config: Optional[QuantizationConfig] = None, bias: bool = False, @@ -125,8 +126,9 @@ def __init__( assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) # MistralConfig has an optional head_dim introduced by Mistral-Nemo - self.head_dim = getattr(config, "head_dim", - self.hidden_size // self.total_num_heads) + self.head_dim = getattr(config, "head_dim", None) + if self.head_dim is None: + self.head_dim = self.hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -236,7 +238,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: residual = hidden_states @@ -268,6 +270,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): lora_config = vllm_config.lora_config self.config = config + self.quant_config = quant_config lora_vocab = ((lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) if lora_config else 0) 
self.vocab_size = config.vocab_size + lora_vocab @@ -359,6 +362,65 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. 
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { @@ -437,68 +499,7 @@ def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 1cbda7267e4c..86ce813ddf3d 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -20,7 +20,8 @@ # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json """Inference-only StabeLM (https://github.com/Stability-AI/StableLM) model compatible with HuggingFace weights.""" -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch from torch import nn @@ -180,7 +181,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -252,8 +253,8 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: 
Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -263,7 +264,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: @@ -335,15 +336,7 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: - loader = AutoWeightsLoader( - self, - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - skip_prefixes=[ - "rotary_emb.inv_freq", "rotary_emb.cos_cached", - "rotary_emb.sin_cached" - ], - ) + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 6eebe4c4d614..f4ba5a8030e5 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -19,7 +19,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
""" PyTorch Starcoder2 model.""" -from typing import Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch from torch import nn @@ -255,8 +256,8 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -265,7 +266,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ] params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: @@ -342,14 +343,13 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. - skip_prefixes=([ - "rotary_emb.inv_freq", "lm_head.weight" - ] if self.config.tie_word_embeddings else ["rotary_emb.inv_freq"]), + skip_prefixes=(["lm_head.weight"] + if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/telechat2.py b/vllm/model_executor/models/telechat2.py index 379e19e1beea..7d713d23c772 100644 --- a/vllm/model_executor/models/telechat2.py +++ b/vllm/model_executor/models/telechat2.py @@ -19,7 +19,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Iterable, Set, Tuple +from collections.abc import Iterable import torch import torch.nn as nn @@ -50,14 +50,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): layer.mlp.gate_up_proj.bias = None layer.mlp.gate_up_proj.skip_bias_add = True - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ ('gate_up_proj', 'gate_proj', 0), ('gate_up_proj', 'up_proj', 1), ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() total_num_heads = self.config.n_head head_dim = self.config.hidden_size // total_num_heads for name, loaded_weight in weights: @@ -128,8 +128,8 @@ def _init_model(self, layer_type: type[nn.Module] = LlamaDecoderLayer): return TeleChat2Model(vllm_config=vllm_config, prefix=prefix) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 7b946ad6aac7..b87a2ebf211a 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -14,9 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Wrapper around `transformers` models""" -import re -from typing import Iterable, Literal, Optional, Union +from collections.abc import Iterable +from contextlib import nullcontext +from typing import Literal, Optional, Union +import regex as re import torch from torch import nn from transformers import AutoModel, PretrainedConfig, PreTrainedModel @@ -109,6 +111,33 @@ def replace_linear_class( ) +class ConfigOverride: + """Context manager to temporarily override config attributes.""" + + def __init__(self, config: PretrainedConfig, **kwargs): + self.config = config + self.kwargs = kwargs + self.kwargs_original = {} + self.kwargs_delete = set() + + def __enter__(self): + """Override config attributes.""" + for key, value in self.kwargs.items(): + if not hasattr(self.config, key): + self.kwargs_delete.add(key) + self.kwargs_original[key] = getattr(self.config, key, None) + setattr(self.config, key, value) + return self.config + + def __exit__(self, exc_type, exc_value, traceback): + """Restore original config attributes.""" + for key, value in self.kwargs_original.items(): + if key in self.kwargs_delete: + delattr(self.config, key) + else: + setattr(self.config, key, value) + + class TransformersModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -134,8 +163,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.pp_rank = self.pp_group.rank_in_group self.tp_size = get_tensor_model_parallel_world_size() + # vLLM handles interleaved sliding window attention by creating a new + # interleaved_sliding_window attribute and deleting the sliding_window + # attribute. This breaks the constructors in Transformers so we + # temporarily add the attribute back to construct the model. 
+ config_override = nullcontext() + if hasattr(config, "interleaved_sliding_window"): + config_override = ConfigOverride( + config, sliding_window=config.interleaved_sliding_window) + # Use meta device to delay allocating GPU tensors - with torch.device("meta"): + with torch.device("meta"), config_override: # FIXME(Isotr0py): We need to refactor this part in the future to # avoid registering an extra model layer, otherwise we will need a # weights mapper to rename weights. @@ -261,9 +299,17 @@ def create_attention_instances(self) -> dict[int, Attention]: num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) start, end = get_pp_indices(self.config.num_hidden_layers, self.pp_rank, self.pp_size) - return { - i: - Attention( + + attention_instances = {} + for i in range(start, end): + # Handle interleaved sliding window attention + sliding_window = None + if (hasattr(self.config, "interleaved_sliding_window") + and hasattr(self.config, "sliding_window_pattern") + and ((i + 1) % self.config.sliding_window_pattern > 0)): + sliding_window = self.config.interleaved_sliding_window + + attention_instances[i] = Attention( num_heads=num_heads, head_size=head_size, # NOTE: We use Llama scale as default, if it's set by @@ -272,9 +318,9 @@ def create_attention_instances(self) -> dict[int, Attention]: num_kv_heads=num_kv_heads, cache_config=self.cache_config, quant_config=self.quant_config, + per_layer_sliding_window=sliding_window, prefix=f"{i}.attn") - for i in range(start, end) - } + return attention_instances def init_buffers(self, module: nn.Module): """ diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 0bc5d218f8d0..c1a4dc1b33d7 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -3,7 +3,7 @@ # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py """PyTorch Ultravox model.""" from 
collections.abc import Iterable, Mapping, Sequence -from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union +from typing import Any, Literal, Optional, TypedDict, Union import torch from torch import nn @@ -619,8 +619,8 @@ def compute_logits(self, hidden_states: torch.Tensor, return self.language_model.compute_logits(hidden_states, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self, ignore_unexpected_prefixes=["audio_tower."]) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 0458e3ce03b5..3d821d3dc6b5 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 import itertools +from collections.abc import Iterable, Mapping from dataclasses import dataclass, field -from typing import (Callable, Dict, Iterable, List, Literal, Mapping, Optional, - Protocol, Set, Tuple, Union, overload) +from typing import Callable, Literal, Optional, Protocol, Union, overload import torch import torch.nn as nn @@ -58,15 +58,15 @@ def _map_name(self, key: str) -> Optional[str]: return key def apply( - self, weights: Iterable[Tuple[str, torch.Tensor]] - ) -> Iterable[Tuple[str, torch.Tensor]]: + self, weights: Iterable[tuple[str, torch.Tensor]] + ) -> Iterable[tuple[str, torch.Tensor]]: return ((out_name, data) for name, data in weights if (out_name := self._map_name(name)) is not None) class AutoWeightsLoader: """ - Helper class to load weights into a {class}`torch.nn.Module`. It is able + Helper class to load weights into a [`torch.nn.Module`][]. It is able to automatically detect child modules and parameters while iterating over the weights only once. @@ -80,23 +80,35 @@ class AutoWeightsLoader: environment variable ``VLLM_LOGGING_LEVEL=DEBUG``. 
""" + # Models trained using early version ColossalAI + # may include these tensors in checkpoint. Skip them. + ROTARY_EMBEDS_UNUSED_WEIGHTS = [ + "rotary_emb.inv_freq", + "rotary_emb.cos_cached", + "rotary_emb.sin_cached", + ] + def __init__( self, module: nn.Module, *, - skip_prefixes: Optional[List[str]] = None, - ignore_unexpected_prefixes: Optional[List[str]] = None, + skip_prefixes: Optional[list[str]] = None, + skip_substrs: Optional[list[str]] = None, + ignore_unexpected_prefixes: Optional[list[str]] = None, ) -> None: super().__init__() self.module = module self.skip_prefixes = skip_prefixes or [] + self.skip_substrs = skip_substrs or [] self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or [] + # update default skip_substrs + self.skip_substrs += self.ROTARY_EMBEDS_UNUSED_WEIGHTS def _groupby_prefix( self, - weights: Iterable[Tuple[str, torch.Tensor]], - ) -> Iterable[Tuple[str, Iterable[Tuple[str, torch.Tensor]]]]: + weights: Iterable[tuple[str, torch.Tensor]], + ) -> Iterable[tuple[str, Iterable[tuple[str, torch.Tensor]]]]: weights_by_parts = ((weight_name.split(".", 1), weight_data) for weight_name, weight_data in weights) @@ -119,7 +131,8 @@ def _get_qualname(self, prefix: str, rest: str) -> str: return ".".join((prefix, rest)) def _can_skip(self, qualname: str) -> bool: - return any(qualname.startswith(p) for p in self.skip_prefixes) + return (any(qualname.startswith(p) for p in self.skip_prefixes) + or any(substr in qualname for substr in self.skip_substrs)) def _can_ignore_unexpected(self, qualname: str) -> bool: return any( @@ -129,7 +142,7 @@ def _load_param( self, base_prefix: str, param: nn.Parameter, - weights: Iterable[Tuple[str, torch.Tensor]], + weights: Iterable[tuple[str, torch.Tensor]], ) -> Iterable[str]: for weight_name, weight_data in weights: weight_qualname = self._get_qualname(base_prefix, weight_name) @@ -159,7 +172,7 @@ def _load_param( yield weight_qualname def _add_loadable_non_param_tensors(self, module: nn.Module, - 
child_params: Dict[str, torch.Tensor]): + child_params: dict[str, torch.Tensor]): """ Add tensor names that are not in the model params that may be in the safetensors, e.g., batch normalization stats. @@ -182,7 +195,7 @@ def _load_module( self, base_prefix: str, module: nn.Module, - weights: Iterable[Tuple[str, torch.Tensor]], + weights: Iterable[tuple[str, torch.Tensor]], ) -> Iterable[str]: if isinstance(module, PPMissingLayer): return @@ -251,12 +264,15 @@ def _load_module( def load_weights( self, - weights: Iterable[Tuple[str, torch.Tensor]], + weights: Iterable[tuple[str, torch.Tensor]], *, mapper: Optional[WeightsMapper] = None, - ) -> Set[str]: + ) -> set[str]: if mapper is not None: weights = mapper.apply(weights) + # filter out weights with first-prefix/substr to skip in name + weights = ((name, weight) for name, weight in weights + if not self._can_skip(name)) autoloaded_weights = set(self._load_module("", self.module, weights)) return autoloaded_weights @@ -292,13 +308,13 @@ def flatten_bn(x: torch.Tensor) -> torch.Tensor: @overload -def flatten_bn(x: List[torch.Tensor]) -> List[torch.Tensor]: +def flatten_bn(x: list[torch.Tensor]) -> list[torch.Tensor]: ... @overload def flatten_bn( - x: Union[List[torch.Tensor], torch.Tensor], + x: Union[list[torch.Tensor], torch.Tensor], *, concat: Literal[True], ) -> torch.Tensor: @@ -307,18 +323,18 @@ def flatten_bn( @overload def flatten_bn( - x: Union[List[torch.Tensor], torch.Tensor], + x: Union[list[torch.Tensor], torch.Tensor], *, concat: bool = False, -) -> Union[List[torch.Tensor], torch.Tensor]: +) -> Union[list[torch.Tensor], torch.Tensor]: ... def flatten_bn( - x: Union[List[torch.Tensor], torch.Tensor], + x: Union[list[torch.Tensor], torch.Tensor], *, concat: bool = False, -) -> Union[List[torch.Tensor], torch.Tensor]: +) -> Union[list[torch.Tensor], torch.Tensor]: """ Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs. 
@@ -442,7 +458,7 @@ def merge_multimodal_embeddings( input_ids: torch.Tensor, inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors, - placeholder_token_id: Union[int, List[int]], + placeholder_token_id: Union[int, list[int]], ) -> torch.Tensor: """ Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the @@ -596,7 +612,7 @@ def make_layers( num_hidden_layers: int, layer_fn: LayerFn, prefix: str, -) -> Tuple[int, int, torch.nn.ModuleList]: +) -> tuple[int, int, torch.nn.ModuleList]: """Make a list of layers with the given layer function, taking pipeline parallelism into account. """ @@ -614,10 +630,10 @@ def make_layers( # NOTE: don't use lru_cache here because it can prevent garbage collection -_model_to_pp_missing_layer_names: Dict[int, List[str]] = {} +_model_to_pp_missing_layer_names: dict[int, list[str]] = {} -def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]: +def get_pp_missing_layer_names(model: torch.nn.Module) -> list[str]: """Get the names of the missing layers in a pipeline parallel model.""" model_id = id(model) if model_id in _model_to_pp_missing_layer_names: @@ -645,7 +661,7 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: for missing_layer_name in get_pp_missing_layer_names(model)) -def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int): +def make_empty_intermediate_tensors_factory(keys: list[str], hidden_size: int): def make_empty_intermediate_tensors( batch_size: int, @@ -684,7 +700,7 @@ def extract_layer_index(layer_name: str) -> int: - "model.encoder.layers.0.sub.1" -> ValueError """ subnames = layer_name.split(".") - int_vals: List[int] = [] + int_vals: list[int] = [] for subname in subnames: try: int_vals.append(int(subname)) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 908cd7885aa8..c6e303d6024a 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -2,7 
+2,7 @@ import math from collections.abc import Iterable, Mapping, Sequence -from typing import List, Optional, Set, Tuple, TypedDict, Union +from typing import Optional, TypedDict, Union import torch from torch import nn @@ -382,7 +382,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.embed_positions.weight.copy_( sinusoids(*self.embed_positions.weight.shape)) - def forward(self, input_features: Union[torch.Tensor, List[torch.Tensor]]): + def forward(self, input_features: Union[torch.Tensor, list[torch.Tensor]]): hidden_states = [] for features in input_features: embeds = nn.functional.gelu(self.conv1(features)) @@ -460,7 +460,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def forward( self, - input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]], + input_features: Optional[Union[torch.Tensor, list[torch.Tensor]]], input_ids: Optional[torch.Tensor], positions: torch.Tensor, ) -> torch.Tensor: @@ -474,14 +474,14 @@ def forward( def get_encoder_outputs( self, - input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]], + input_features: Optional[Union[torch.Tensor, list[torch.Tensor]]], ) -> Optional[torch.Tensor]: if input_features is None: return None return self.encoder(input_features) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".self_attn.qkv_proj", ".self_attn.q_proj", "q"), @@ -491,7 +491,7 @@ def load_weights(self, weights: Iterable[Tuple[str, (".encoder_attn.kv_proj", ".encoder_attn.v_proj", "v"), ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: @@ -722,8 +722,8 @@ def compute_logits(self, hidden_states: 
torch.Tensor, sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."]) # add fake zeros bias for k_proj to state_dict @@ -732,8 +732,8 @@ def load_weights(self, weights: Iterable[Tuple[str, def _create_fake_bias_for_k_proj( - weights: Iterable[Tuple[str, torch.Tensor]] -) -> Iterable[Tuple[str, torch.Tensor]]: + weights: Iterable[tuple[str, torch.Tensor]] +) -> Iterable[tuple[str, torch.Tensor]]: """ Create full zeros bias for k_proj weight in self-attn and x-attn layers. So that the bias for k_proj in qkv_proj can be initialized with zeros. diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index eddccbba5a2d..48e254bdd85b 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -6,8 +6,9 @@ architectures in a hybrid model optimized for efficient sequence modeling. The model alternates between state space model layers and attention-based layers. """ +from collections.abc import Iterable from itertools import cycle -from typing import Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Optional, Union import torch from torch import nn @@ -54,7 +55,7 @@ def __init__( self, input_dim: int, rank: int, - output_dim: Union[int, List[int]], + output_dim: Union[int, list[int]], quant_config: Optional[QuantizationConfig] = None, ): """Initialize the attention layer. @@ -279,7 +280,7 @@ def __init__( self, config: Zamba2Config, bare_block_idx: int, - num_hybrid_layers: Dict[int, int], + num_hybrid_layers: dict[int, int], quant_config: Optional[QuantizationConfig] = None, ) -> None: """Initialize the MLP layer. 
@@ -769,8 +770,8 @@ def forward( hidden_states = self.final_layernorm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -779,7 +780,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ] params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for chkpt_weight_name, loaded_weight in weights: for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in chkpt_weight_name: @@ -914,9 +915,9 @@ def forward(self, return hidden_states - def copy_inputs_before_cuda_graphs(self, input_buffers: Dict[str, + def copy_inputs_before_cuda_graphs(self, input_buffers: dict[str, torch.Tensor], - **kwargs) -> Dict[str, torch.Tensor]: + **kwargs) -> dict[str, torch.Tensor]: """Copy inputs before CUDA graph capture. Args: @@ -930,7 +931,7 @@ def copy_inputs_before_cuda_graphs(self, input_buffers: Dict[str, input_buffers, **kwargs) def get_seqlen_agnostic_capture_inputs( - self, batch_size: int) -> Dict[str, torch.Tensor]: + self, batch_size: int) -> dict[str, torch.Tensor]: """Get inputs for sequence-length-agnostic graph capture. Args: @@ -941,7 +942,7 @@ def get_seqlen_agnostic_capture_inputs( return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) def _get_mamba_cache_shape( - self) -> Tuple[Tuple[int, int], Tuple[int, int]]: + self) -> tuple[tuple[int, int], tuple[int, int]]: """Calculate shapes for Mamba's convolutional and state caches. 
Returns: @@ -1001,7 +1002,7 @@ def compute_logits( sampling_metadata) return logits - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/pooling_metadata.py b/vllm/model_executor/pooling_metadata.py index dea8b0e9d471..4c5db7396c03 100644 --- a/vllm/model_executor/pooling_metadata.py +++ b/vllm/model_executor/pooling_metadata.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import Any, Dict, List, Tuple +from typing import Any import torch @@ -23,9 +23,9 @@ class PoolingMetadata: def __init__( self, - seq_groups: List[Tuple[List[int], PoolingParams]], - seq_data: Dict[int, Any], # Specific data related to sequences - prompt_lens: List[int], + seq_groups: list[tuple[list[int], PoolingParams]], + seq_data: dict[int, Any], # Specific data related to sequences + prompt_lens: list[int], ) -> None: self.seq_groups = seq_groups self.seq_data = seq_data diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index d76c75d9e6ce..6b83a59b5988 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -2,7 +2,7 @@ from array import array from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import Optional import torch @@ -25,10 +25,10 @@ class SequenceGroupToSample: # |-- query_len ---| # Sequence ids for the sequence group in a previous step. - seq_ids: List[int] + seq_ids: list[int] sampling_params: SamplingParams # seq_id -> sequence data. - seq_data: Dict[int, SequenceData] + seq_data: dict[int, SequenceData] # The length of the sequence (all tokens seen in the past + new token to # compute attention) of the sequence group. 
None if it is in a decode # stage. @@ -44,9 +44,9 @@ class SequenceGroupToSample: is_prompt: bool # Query token indices from logits. to compute prompt logprob. Empty if # prompt logprob is not required. - prompt_logprob_indices: List[int] + prompt_logprob_indices: list[int] # Sample token indices from logits. Empty if sampling is not required. - sample_indices: List[int] + sample_indices: list[int] @property def do_sample(self): @@ -78,7 +78,7 @@ class SamplingMetadataCache: """Used to cache SamplingMetadata objects between scheduler iterations""" def __init__(self): - self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {} + self._seq_group_to_sample_cache: dict[int, PyObjectCache] = {} def get_cached_seq_group_to_sample(self, num_seqs): if num_seqs not in self._seq_group_to_sample_cache: @@ -130,9 +130,9 @@ def sample(logits): def __init__( self, - seq_groups: List[SequenceGroupToSample], + seq_groups: list[SequenceGroupToSample], selected_token_indices: torch.Tensor, - categorized_sample_indices: Dict[SamplingType, torch.Tensor], + categorized_sample_indices: dict[SamplingType, torch.Tensor], num_prompts: int, skip_sampler_cpu_output: bool = False, reuse_sampling_tensors: bool = False, @@ -146,12 +146,12 @@ def __init__( @staticmethod def prepare( - seq_group_metadata_list: List[SequenceGroupMetadata], - seq_lens: List[int], - query_lens: List[int], + seq_group_metadata_list: list[SequenceGroupMetadata], + seq_lens: list[int], + query_lens: list[int], device: str, pin_memory: bool, - generators: Optional[Dict[str, torch.Generator]] = None, + generators: Optional[dict[str, torch.Generator]] = None, cache: Optional[SamplingMetadataCache] = None, ) -> "SamplingMetadata": ( @@ -195,16 +195,16 @@ def __repr__(self) -> str: def _prepare_seq_groups( - seq_group_metadata_list: List[SequenceGroupMetadata], - seq_lens: List[int], - query_lens: List[int], + seq_group_metadata_list: list[SequenceGroupMetadata], + seq_lens: list[int], + query_lens: list[int], device: 
str, - generators: Optional[Dict[str, torch.Generator]] = None, + generators: Optional[dict[str, torch.Generator]] = None, cache: Optional[SamplingMetadataCache] = None, -) -> Tuple[ - List[SequenceGroupToSample], - List[int], - Dict[SamplingType, List[int]], +) -> tuple[ + list[SequenceGroupToSample], + list[int], + dict[SamplingType, list[int]], int, ]: """Prepare sequence groups and indices for sampling. @@ -227,17 +227,17 @@ def _prepare_seq_groups( num_prompts: Total number of prompts from `seq_group_metadata_list`. """ # Batched sequence groups for the current model forward stsep. - seq_groups: List[SequenceGroupToSample] = [] + seq_groups: list[SequenceGroupToSample] = [] # A list of token indices to sample/compute logprob. It is used to # prune the outcome logits from the model for the performance. - selected_token_indices: List[int] = [] + selected_token_indices: list[int] = [] # Used for selected_token_indices. model_output_idx = 0 # Sampling type -> ( # indices to sample/prompt logprob within pruned output logits, # indices to sample within pruned logits) - categorized_sample_indices: Dict[SamplingType, List[int]] = { + categorized_sample_indices: dict[SamplingType, list[int]] = { t: [] for t in SamplingType } @@ -265,9 +265,9 @@ def _prepare_seq_groups( # If the current seq group is in decode stage, it is None. 
seq_len: Optional[int] = None query_len: Optional[int] = None - prompt_logprob_indices: List[int] = (sample_obj.prompt_logprob_indices + prompt_logprob_indices: list[int] = (sample_obj.prompt_logprob_indices if cache is not None else []) - sample_indices: List[int] = (sample_obj.sample_indices + sample_indices: list[int] = (sample_obj.sample_indices if cache is not None else []) do_sample = seq_group_metadata.do_sample @@ -389,16 +389,16 @@ def from_sampling_metadata( vocab_size: int, device: torch.device, dtype: torch.dtype, - ) -> Tuple["SamplingTensors", bool, bool, bool]: - prompt_tokens: List[array] = [] - output_tokens: List[array] = [] - top_ks: List[int] = [] - temperatures: List[float] = [] - top_ps: List[float] = [] - min_ps: List[float] = [] - presence_penalties: List[float] = [] - frequency_penalties: List[float] = [] - repetition_penalties: List[float] = [] + ) -> tuple["SamplingTensors", bool, bool, bool]: + prompt_tokens: list[array] = [] + output_tokens: list[array] = [] + top_ks: list[int] = [] + temperatures: list[float] = [] + top_ps: list[float] = [] + min_ps: list[float] = [] + presence_penalties: list[float] = [] + frequency_penalties: list[float] = [] + repetition_penalties: list[float] = [] do_penalties = False do_top_p_top_k = False do_min_p = False @@ -416,7 +416,7 @@ def from_sampling_metadata( # k should not be greater than the vocab size. top_k = min(sampling_params.top_k, vocab_size) - top_k = vocab_size if top_k == -1 else top_k + top_k = vocab_size if top_k < 1 else top_k if temperature < _SAMPLING_EPS: # NOTE: Zero temperature means deterministic sampling # (i.e., greedy sampling or beam search). 
@@ -496,15 +496,15 @@ def from_sampling_metadata( @classmethod def from_lists( cls, - temperatures: List[float], - top_ps: List[float], - top_ks: List[int], - min_ps: List[float], - presence_penalties: List[float], - frequency_penalties: List[float], - repetition_penalties: List[float], - prompt_tokens: List[array], - output_tokens: List[array], + temperatures: list[float], + top_ps: list[float], + top_ks: list[int], + min_ps: list[float], + presence_penalties: list[float], + frequency_penalties: list[float], + repetition_penalties: list[float], + prompt_tokens: list[array], + output_tokens: list[array], vocab_size: int, device: torch.device, dtype: torch.dtype, diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index 04f922dfd77a..f9d89e64bd9d 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """Utils for model executor.""" -from typing import Any, Dict, Optional +from typing import Any, Optional import torch @@ -12,7 +12,7 @@ def set_random_seed(seed: int) -> None: def set_weight_attrs( weight: torch.Tensor, - weight_attrs: Optional[Dict[str, Any]], + weight_attrs: Optional[dict[str, Any]], ): """Set attributes on a weight tensor. diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 756ea11311da..815e34d5ac5d 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -8,12 +8,12 @@ MULTIMODAL_REGISTRY = MultiModalRegistry() """ -The global {class}`~MultiModalRegistry` is used by model runners to -dispatch data processing according to the target model. +The global [`MultiModalRegistry`][vllm.multimodal.registry.MultiModalRegistry] +is used by model runners to dispatch data processing according to the target +model. 
-:::{seealso} -{ref}`mm-processing` -::: +Info: + [mm_processing](../../../design/mm_processing.html) """ __all__ = [ diff --git a/vllm/multimodal/hasher.py b/vllm/multimodal/hasher.py index 53e289370a9f..b4cd6a90834c 100644 --- a/vllm/multimodal/hasher.py +++ b/vllm/multimodal/hasher.py @@ -10,6 +10,7 @@ from PIL import Image from vllm.logger import init_logger +from vllm.multimodal.image import convert_image_mode if TYPE_CHECKING: from vllm.inputs import TokensPrompt @@ -35,7 +36,8 @@ def serialize_item(cls, obj: object) -> bytes: return np.array(obj).tobytes() if isinstance(obj, Image.Image): - return cls.item_to_bytes("image", np.array(obj.convert("RGBA"))) + return cls.item_to_bytes( + "image", np.asarray(convert_image_mode(obj, "RGBA"))) if isinstance(obj, torch.Tensor): return cls.item_to_bytes("tensor", obj.numpy()) if isinstance(obj, np.ndarray): @@ -43,7 +45,7 @@ def serialize_item(cls, obj: object) -> bytes: "ndarray", { "dtype": obj.dtype.str, "shape": obj.shape, - "data": obj.data.tobytes(), + "data": obj.tobytes(), }) logger.warning( diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 939928bbf108..a63ec0bd8ada 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -22,6 +22,25 @@ def rescale_image_size(image: Image.Image, return image +# TODO: Support customizable background color to fill in. 
+def rgba_to_rgb( + image: Image.Image, background_color=(255, 255, 255)) -> Image.Image: + """Convert an RGBA image to RGB with filled background color.""" + assert image.mode == "RGBA" + converted = Image.new("RGB", image.size, background_color) + converted.paste(image, mask=image.split()[3]) # 3 is the alpha channel + return converted + + +def convert_image_mode(image: Image.Image, to_mode: str): + if image.mode == to_mode: + return image + elif image.mode == "RGBA" and to_mode == "RGB": + return rgba_to_rgb(image) + else: + return image.convert(to_mode) + + class ImageMediaIO(MediaIO[Image.Image]): def __init__(self, *, image_mode: str = "RGB") -> None: @@ -32,7 +51,7 @@ def __init__(self, *, image_mode: str = "RGB") -> None: def load_bytes(self, data: bytes) -> Image.Image: image = Image.open(BytesIO(data)) image.load() - return image.convert(self.image_mode) + return convert_image_mode(image, self.image_mode) def load_base64(self, media_type: str, data: str) -> Image.Image: return self.load_bytes(base64.b64decode(data)) @@ -40,7 +59,7 @@ def load_base64(self, media_type: str, data: str) -> Image.Image: def load_file(self, filepath: Path) -> Image.Image: image = Image.open(filepath) image.load() - return image.convert(self.image_mode) + return convert_image_mode(image, self.image_mode) def encode_base64( self, @@ -51,7 +70,7 @@ def encode_base64( image = media with BytesIO() as buffer: - image = image.convert(self.image_mode) + image = convert_image_mode(image, self.image_mode) image.save(buffer, image_format) data = buffer.getvalue() diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 61d8eb62ffaf..600a34d39ef6 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -10,42 +10,45 @@ Union, cast, final) import numpy as np -import torch -import torch.types -from PIL.Image import Image -from transformers import BatchFeature from typing_extensions import NotRequired, TypeAlias from vllm.jsontree import JSONTree, 
json_map_leaves -from vllm.utils import full_groupby, is_list_of +from vllm.utils import LazyLoader, full_groupby, is_list_of if TYPE_CHECKING: + import torch + import torch.types + from PIL.Image import Image + from transformers.feature_extraction_utils import BatchFeature + from .hasher import MultiModalHashDict +else: + torch = LazyLoader("torch", globals(), "torch") _T = TypeVar("_T") -HfImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor] +HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"] """ -A {class}`transformers.image_utils.ImageInput` representing a single image +A `transformers.image_utils.ImageInput` representing a single image item, which can be passed to a HuggingFace `ImageProcessor`. """ -HfVideoItem: TypeAlias = Union[list[Image], np.ndarray, torch.Tensor, - list[np.ndarray], list[torch.Tensor]] +HfVideoItem: TypeAlias = Union[list["Image"], np.ndarray, "torch.Tensor", + list[np.ndarray], list["torch.Tensor"]] """ -A {class}`transformers.image_utils.VideoInput` representing a single video +A `transformers.image_utils.VideoInput` representing a single video item, which can be passed to a HuggingFace `VideoProcessor`. """ -HfAudioItem: TypeAlias = Union[list[float], np.ndarray, torch.Tensor] +HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"] """ Represents a single audio item, which can be passed to a HuggingFace `AudioProcessor`. """ -ImageItem: TypeAlias = Union[HfImageItem, torch.Tensor] +ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor"] """ -A {class}`transformers.image_utils.ImageInput` representing a single image +A `transformers.image_utils.ImageInput` representing a single image item, which can be passed to a HuggingFace `ImageProcessor`. Alternatively, a 3-D tensor or batch of 2-D tensors, @@ -53,9 +56,9 @@ these are directly passed to the model without HF processing. 
""" -VideoItem: TypeAlias = Union[HfVideoItem, torch.Tensor] +VideoItem: TypeAlias = Union[HfVideoItem, "torch.Tensor"] """ -A {class}`transformers.image_utils.VideoInput` representing a single video +A `transformers.image_utils.VideoInput` representing a single video item, which can be passed to a HuggingFace `VideoProcessor`. Alternatively, a 3-D tensor or batch of 2-D tensors, @@ -64,7 +67,7 @@ """ AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float], - torch.Tensor] + "torch.Tensor"] """ Represents a single audio item, which can be passed to a HuggingFace `AudioProcessor`. @@ -105,7 +108,8 @@ class MultiModalDataBuiltins(TypedDict, total=False): """ A dictionary containing an entry for each modality type to input. -The built-in modalities are defined by {class}`MultiModalDataBuiltins`. +The built-in modalities are defined by +[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins]. """ @@ -132,7 +136,7 @@ class PlaceholderRange: length: int """The length of the placeholder.""" - is_embed: Optional[torch.Tensor] = None + is_embed: Optional["torch.Tensor"] = None """ A boolean mask of shape `(length,)` indicating which positions between `offset` and `offset + length` to assign embeddings to. @@ -158,15 +162,16 @@ def __eq__(self, other: object) -> bool: return nested_tensors_equal(self.is_embed, other.is_embed) -NestedTensors = Union[list["NestedTensors"], list[torch.Tensor], torch.Tensor, - tuple[torch.Tensor, ...]] +NestedTensors: TypeAlias = Union[list["NestedTensors"], list["torch.Tensor"], + "torch.Tensor", tuple["torch.Tensor", ...]] """ Uses a list instead of a tensor if the dimensions of each element do not match. 
""" def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool: - """Equality check between {data}`NestedTensors` objects.""" + """Equality check between + [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects.""" if isinstance(a, torch.Tensor): return isinstance(b, torch.Tensor) and torch.equal(a, b) elif isinstance(b, torch.Tensor): @@ -186,7 +191,7 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool: BatchedTensorInputs: TypeAlias = Mapping[str, NestedTensors] """ A dictionary containing nested tensors which have been batched via -{meth}`MultiModalKwargs.batch`. +[`MultiModalKwargs.batch`][vllm.multimodal.inputs.MultiModalKwargs.batch]. """ @@ -194,7 +199,7 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool: class MultiModalFieldElem: """ Represents a keyword argument corresponding to a multi-modal item - in {class}`MultiModalKwargs`. + in [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs]. """ modality: str @@ -205,13 +210,15 @@ class MultiModalFieldElem: key: str """ - The key of this field in {class}`MultiModalKwargs`, + The key of this field in + [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs], i.e. the name of the keyword argument to be passed to the model. """ data: NestedTensors """ - The tensor data of this field in {class}`MultiModalKwargs`, + The tensor data of this field in + [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs], i.e. the value of the keyword argument to be passed to the model. """ @@ -234,7 +241,8 @@ def __eq__(self, other: object) -> bool: class BaseMultiModalField(ABC): """ Defines how to interpret tensor data belonging to a keyword argument in - {class}`MultiModalKwargs` for multiple multi-modal items, and vice versa. + [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs] for multiple + multi-modal items, and vice versa. 
""" def _field_factory(self, *, modality: str, key: str): @@ -259,10 +267,12 @@ def build_elems( data: NestedTensors, ) -> Sequence[MultiModalFieldElem]: """ - Construct {class}`MultiModalFieldElem` instances to represent - the provided data. - - This is the inverse of {meth}`reduce_data`. + Construct + [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem] + instances to represent the provided data. + + This is the inverse of + [`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data]. """ raise NotImplementedError @@ -272,9 +282,11 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: def reduce_data(self, elems: list[MultiModalFieldElem]) -> NestedTensors: """ - Merge the data from multiple instances of {class}`MultiModalFieldElem`. + Merge the data from multiple instances of + [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]. - This is the inverse of {meth}`build_elems`. + This is the inverse of + [`build_elems`][vllm.multimodal.inputs.BaseMultiModalField.build_elems]. 
""" field_types = [type(item.field) for item in elems] if len(set(field_types)) > 1: @@ -286,9 +298,8 @@ def reduce_data(self, elems: list[MultiModalFieldElem]) -> NestedTensors: @dataclass(frozen=True) class MultiModalBatchedField(BaseMultiModalField): """ - :::{seealso} - {func}`MultiModalFieldConfig.batched` - ::: + Info: + [`MultiModalFieldConfig.batched`][vllm.multimodal.inputs.MultiModalFieldConfig.batched] """ def build_elems( @@ -317,10 +328,9 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: @dataclass(frozen=True) class MultiModalFlatField(BaseMultiModalField): """ - :::{seealso} - {func}`MultiModalFieldConfig.flat` - {func}`MultiModalFieldConfig.flat_from_sizes` - ::: + Info: + [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat] + [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes] """ slices: Union[Sequence[slice], Sequence[Sequence[slice]]] dim: int = 0 @@ -360,9 +370,8 @@ def _expect_same_shape(tensor: torch.Tensor): @dataclass(frozen=True) class MultiModalSharedField(BaseMultiModalField): """ - :::{seealso} - {func}`MultiModalFieldConfig.shared` - ::: + Info: + [`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared] """ batch_size: int @@ -422,7 +431,7 @@ def flat(modality: str, modality: The modality of the multi-modal item that uses this keyword argument. slices: For each multi-modal item, a slice (dim=0) or a tuple of - slices (dim>0) that is used to extract the data corresponding + slices (dim>0) that is used to extract the data corresponding to it. dim: The dimension to extract data, default to 0. 
@@ -465,7 +474,7 @@ def flat(modality: str, @staticmethod def flat_from_sizes(modality: str, - size_per_item: torch.Tensor, + size_per_item: "torch.Tensor", dim: int = 0): """ Defines a field where an element in the batch is obtained by @@ -507,9 +516,8 @@ def flat_from_sizes(modality: str, Element 3: [[C],[C]] ``` - :::{seealso} - {func}`MultiModalFieldConfig.flat` - ::: + Info: + [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat] """ if size_per_item.ndim != 1: @@ -573,8 +581,10 @@ def build_elems( class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]): """ - A collection of {class}`MultiModalFieldElem` - corresponding to a data item in {class}`MultiModalDataItems`. + A collection of + [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem] + corresponding to a data item in + [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems]. """ @staticmethod @@ -593,16 +603,18 @@ def modality(self) -> str: class MultiModalKwargs(UserDict[str, NestedTensors]): """ A dictionary that represents the keyword arguments to - {meth}`~torch.nn.Module.forward`. + [`torch.nn.Module.forward`][]. The metadata `items` enables us to obtain the keyword arguments - corresponding to each data item in {class}`MultiModalDataItems`, via - {meth}`get_item` and {meth}`get_items`. + corresponding to each data item in + [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems], via + [`get_item`][vllm.multimodal.inputs.MultiModalKwargs.get_item] and + [`get_items`][vllm.multimodal.inputs.MultiModalKwargs.get_items]. 
""" @staticmethod def from_hf_inputs( - hf_inputs: BatchFeature, + hf_inputs: "BatchFeature", config_by_key: Mapping[str, MultiModalFieldConfig], ): # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key` @@ -636,7 +648,9 @@ def from_hf_inputs( @staticmethod def from_items(items: Sequence[MultiModalKwargsItem]): - """Construct a new {class}`MultiModalKwargs` from multiple items.""" + """Construct a new + [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs] + from multiple items.""" elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list) for item in items: for key, elem in item.items(): @@ -732,11 +746,17 @@ def as_kwargs( batched_inputs: BatchedTensorInputs, *, device: torch.types.Device, + dtype: Optional[torch.dtype] = None, ) -> BatchedTensorInputs: json_inputs = cast(JSONTree[torch.Tensor], batched_inputs) + def maybe_cast_dtype(x: torch.Tensor): + # This mimics the behavior of transformers.BatchFeature + return x.to(dtype=dtype) if x.is_floating_point() else x + json_mapped = json_map_leaves( - lambda x: x.to(device, non_blocking=True), + # NOTE: Cast the dtype before sending it to device + lambda x: maybe_cast_dtype(x).to(device=device, non_blocking=True), json_inputs, ) @@ -792,7 +812,7 @@ def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]: return self._items_by_modality[modality] -MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]] +MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]] """ A dictionary containing placeholder ranges for each modality. """ @@ -801,7 +821,7 @@ def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]: class MultiModalInputs(TypedDict): """ Represents the outputs of - {class}`vllm.multimodal.processing.BaseMultiModalProcessor`, + [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor], ready to be passed to vLLM internals. 
""" @@ -823,7 +843,7 @@ class MultiModalInputs(TypedDict): mm_hashes: Optional["MultiModalHashDict"] """The hashes of the multi-modal data.""" - mm_placeholders: MultiModalPlaceholderDict + mm_placeholders: "MultiModalPlaceholderDict" """ For each modality, information about the placeholder tokens in `prompt_token_ids`. @@ -837,7 +857,8 @@ class MultiModalInputs(TypedDict): class MultiModalEncDecInputs(MultiModalInputs): """ - Represents the outputs of {class}`vllm.multimodal.EncDecMultiModalProcessor` + Represents the outputs of + [`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor] ready to be passed to vLLM internals. """ diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py index f9588431c8ef..63af842747a5 100644 --- a/vllm/multimodal/parse.py +++ b/vllm/multimodal/parse.py @@ -8,11 +8,9 @@ import numpy as np import torch -from PIL.Image import Image -from transformers import BatchFeature from typing_extensions import TypeAlias, TypeGuard, assert_never -from vllm.utils import is_list_of +from vllm.utils import LazyLoader, is_list_of from .audio import AudioResampler from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem, @@ -22,10 +20,16 @@ _T = TypeVar("_T") _I = TypeVar("_I") +if TYPE_CHECKING: + import PIL.Image as PILImage +else: + PILImage = LazyLoader("PILImage", globals(), "PIL.Image") + class ModalityDataItems(ABC, Generic[_T, _I]): """ - Represents data items for a modality in {class}`MultiModalDataItems`. + Represents data items for a modality in + [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems]. 
""" def __init__(self, data: _T, modality: str) -> None: @@ -131,6 +135,8 @@ def __init__( Mapping[str, MultiModalFieldConfig], ], ) -> None: + from transformers.feature_extraction_utils import BatchFeature + super().__init__(data, modality) missing_required_data_keys = required_fields - data.keys() @@ -200,7 +206,7 @@ def __init__(self, data: Sequence[HfImageItem]) -> None: def get_image_size(self, item_idx: int) -> ImageSize: image = self.get(item_idx) - if isinstance(image, Image): + if isinstance(image, PILImage.Image): return ImageSize(*image.size) if isinstance(image, (np.ndarray, torch.Tensor)): _, h, w = image.shape @@ -226,7 +232,7 @@ def get_num_frames(self, item_idx: int) -> int: def get_frame_size(self, item_idx: int) -> ImageSize: image = self.get(item_idx)[0] # Assume that the video isn't empty - if isinstance(image, Image): + if isinstance(image, PILImage.Image): return ImageSize(*image.size) if isinstance(image, (np.ndarray, torch.Tensor)): _, h, w = image.shape @@ -246,15 +252,15 @@ def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None: class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]): """ - As {data}`~vllm.multimodal.inputs.MultiModalDataDict`, but normalized - such that each entry corresponds to a list. + As [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict], but + normalized such that each entry corresponds to a list. """ def get_count(self, modality: str, *, strict: bool = True) -> int: """ Get the number of data items belonging to a modality. - - If `strict=False`, return `0` instead of raising {exc}`KeyError` + + If `strict=False`, return `0` instead of raising [`KeyError`][] even if the modality is not found. """ if modality not in self: @@ -300,8 +306,8 @@ def get_items( class MultiModalDataParser: """ - Parses {data}`~vllm.multimodal.inputs.MultiModalDataDict` into - {class}`MultiModalDataItems`. 
+ Parses [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict] + into [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems]. Args: target_sr (float, optional): Enables automatic resampling of audio @@ -399,7 +405,7 @@ def _parse_image_data( if self._is_embeddings(data): return ImageEmbeddingItems(data) - if (isinstance(data, Image) + if (isinstance(data, PILImage.Image) or isinstance(data, (np.ndarray, torch.Tensor)) and data.ndim == 3): data_items = [data] @@ -420,7 +426,7 @@ def _parse_video_data( if self._is_embeddings(data): return VideoEmbeddingItems(data) - if (is_list_of(data, Image) + if (is_list_of(data, PILImage.Image) or isinstance(data, (np.ndarray, torch.Tensor)) and data.ndim == 4): data_items = [data] diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 27b059b3ee62..aa7914e40cbf 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 import json -import re import sys from abc import ABC, abstractmethod from collections import defaultdict @@ -12,8 +11,8 @@ from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol, TypeVar, Union, cast) +import regex as re import torch -from transformers import BatchFeature, PretrainedConfig, ProcessorMixin from typing_extensions import assert_never from vllm.inputs import InputProcessingContext @@ -31,6 +30,10 @@ MultiModalDataParser) if TYPE_CHECKING: + from transformers.configuration_utils import PretrainedConfig + from transformers.feature_extraction_utils import BatchFeature + from transformers.processing_utils import ProcessorMixin + from .profiling import BaseDummyInputsBuilder logger = init_logger(__name__) @@ -111,13 +114,14 @@ class PromptUpdateDetails(Generic[_S]): is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] = None """ - Given {attr}`full`, return a boolean mask of shape `(len(full),)` - indicating which positions of `full` to assign 
embeddings to. + Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full], + return a boolean mask of shape `(len(full),)` indicating which positions + of `full` to assign embeddings to. `None` (default) means to assign embeddings to all positions of `full`. The embeddings are obtained by calling - {class}`SupportsMultiModal.get_multimodal_embeddings`. + [`SupportsMultiModal.get_multimodal_embeddings`][vllm.model_executor.models.interfaces.SupportsMultiModal.get_multimodal_embeddings]. """ @staticmethod @@ -156,13 +160,15 @@ def select_token_id( The token sequence or text that are part of the update. If only part of the content corresponds to feature placeholders, you can -use {class}`PromptUpdateDetails` to specify which part. +use [`PromptUpdateDetails`][vllm.multimodal.processing.PromptUpdateDetails] to +specify which part. """ PromptUpdateContent = Union[Callable[[int], PromptUpdateInfo], PromptUpdateInfo] """ -Given the index of the processed item within {attr}`modality`, +Given the index of the processed item within +[`modality`][vllm.multimodal.processing.PromptUpdate.modality], output the corresponding token sequence (or text). For convenience, you can directly pass in the token sequence (or text) @@ -257,8 +263,10 @@ class PromptInsertion(PromptUpdate): insertion: PromptUpdateContent = field(repr=False) """ - Given the index of the processed item within {attr}`modality`, - output the token sequence (or text) to insert right after {attr}`target`. + Given the index of the processed item within + [`modality`][vllm.multimodal.processing.PromptUpdate.modality], + output the token sequence (or text) to insert right after + [`target`][vllm.multimodal.processing.PromptUpdate.target]. For convenience, you can directly pass in the token sequence (or text) instead of a function if it does not depend on the input. 
@@ -329,8 +337,10 @@ class PromptReplacement(PromptUpdate): replacement: PromptUpdateContent = field(repr=False) """ - Given the index of the processed item within {attr}`modality`, - output the token sequence (or text) to replace {attr}`target`. + Given the index of the processed item within + [`modality`][vllm.multimodal.processing.PromptUpdate.modality], + output the token sequence (or text) to replace + [`target`][vllm.multimodal.processing.PromptUpdate.target]. For convenience, you can directly pass in the token sequence (or text) instead of a function if it does not depend on the input. @@ -384,14 +394,16 @@ def modality(self) -> str: def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]: - """Convenience function to apply {func}`full_groupby` based on modality.""" + """Convenience function to apply [`full_groupby`][vllm.utils.full_groupby] + based on modality.""" return full_groupby(values, key=lambda x: x.modality) @dataclass class _BoundPromptSequence: """ - A {data}`_PromptSeq` bound to a tokenizer to automatically + A [`_PromptSeq`][vllm.multimodal.processing.PromptSeq] bound + to a tokenizer to automatically convert between token sequence and text representations. """ tokenizer: AnyTokenizer = field(repr=False) @@ -443,9 +455,11 @@ class _BoundPromptContent: @dataclass class BoundPromptUpdate: """ - A {class}`PromptUpdate` bound to a tokenizer to automatically convert - {attr}`target` and the result of {meth}`get_content` between - token sequence and text representations. + A [`PromptUpdate`][vllm.multimodal.processing.PromptUpdate] bound + to a tokenizer to automatically convert + [`target`][vllm.multimodal.processing.PromptUpdate.target] and the result of + [`get_content`][vllm.multimodal.processing.BoundPromptUpdate.get_content] + between token sequence and text representations. 
""" _origin: PromptUpdate tokenizer: AnyTokenizer = field(repr=False) @@ -479,7 +493,8 @@ def mode(self) -> UpdateMode: def get_content(self, item_idx: int) -> _BoundPromptContent: """ - Given the index of the processed item within {attr}`modality`, + Given the index of the processed item within + [`modality`][vllm.multimodal.processing.PromptUpdate.modality], output the token sequence (or text) to update. """ content = self.content @@ -1016,7 +1031,8 @@ def put( ) -> None: """ Put a processed multi-modal item into the cache - according to its dependencies (see {meth}`get`). + according to its dependencies + (see [`get`][vllm.multimodal.processing.ProcessingCache.get]). """ cache_key = MultiModalHasher.hash_kwargs(model_id=model_id, **{modality: input_item}, @@ -1026,6 +1042,11 @@ def put( def put_item(self, item: ProcessingCacheItem) -> None: self._cache[item.key] = item.value + def reset(self) -> bool: + self._cache.clear() + + return True + class BaseProcessingInfo: """Base class to provide the information necessary for data processing.""" @@ -1042,10 +1063,10 @@ def model_id(self) -> str: def get_tokenizer(self) -> AnyTokenizer: return self.ctx.tokenizer - def get_hf_config(self) -> PretrainedConfig: + def get_hf_config(self) -> "PretrainedConfig": return self.ctx.get_hf_config() - def get_hf_processor(self, **kwargs: object) -> ProcessorMixin: + def get_hf_processor(self, **kwargs: object) -> "ProcessorMixin": """ Subclasses can override this method to handle specific kwargs from model config or user inputs. @@ -1083,7 +1104,8 @@ def get_allowed_mm_limits(self) -> Mapping[str, int]: MultiModalHashes = dict[str, list[str]] """ -A collection of hashes with a similar structure as {class}`MultiModalKwargs`. +A collection of hashes with a similar structure as +[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs]. 
""" @@ -1091,7 +1113,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): """ Abstract base class to process multi-modal inputs to be used in vLLM. - Not to be confused with {class}`transformers.ProcessorMixin`. + Not to be confused with `transformers.ProcessorMixin`. """ def __init__(self, @@ -1118,10 +1140,12 @@ def __call__( def _get_data_parser(self) -> MultiModalDataParser: """ Construct a parser to preprocess multi-modal data items - before passing them to {meth}`_get_hf_mm_data`. + before passing them to + [`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data]. You can support additional modalities by creating a subclass - of {class}`MultiModalDataParser` that has additional subparsers. + of [`MultiModalDataParser`][vllm.multimodal.parse.MultiModalDataParser] + that has additional subparsers. """ return MultiModalDataParser() @@ -1130,8 +1154,11 @@ def _to_mm_items( mm_data: MultiModalDataDict, ) -> MultiModalDataItems: """ - Normalize {class}`MultiModalDataDict` to {class}`MultiModalDataItems` - before passing them to {meth}`_get_hf_mm_data`. + Normalize + [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict] + to [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems] + before passing them to + [`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data]. """ mm_items = self.data_parser.parse_mm_data(mm_data) supported_mm_limits = self.info.get_supported_mm_limits() @@ -1160,7 +1187,7 @@ def _to_mm_items( @abstractmethod def _get_mm_fields_config( self, - hf_inputs: BatchFeature, + hf_inputs: "BatchFeature", hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: """Given the HF-processed data, output the metadata of each field.""" @@ -1183,7 +1210,8 @@ def _get_prompt_updates( inputs. 
Moreover, this information is critical to determine the token positions - in order to construct {class}`~vllm-multimodal.input.PlaceholderRange` + in order to construct + [`PlaceholderRange`][vllm.multimodal.inputs.PlaceholderRange] for each multi-modal item. """ raise NotImplementedError @@ -1217,7 +1245,7 @@ def _call_hf_processor( # This refers to the data to be passed to HF processor. mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], - ) -> BatchFeature: + ) -> "BatchFeature": """ Call the HF processor on the prompt text and associated multi-modal data. @@ -1307,7 +1335,9 @@ def _apply_hf_processor_tokens_only( Most HF processors accept prompt text but not prompt tokens. If the HF processor adds or removes tokens that are not related to multi-modal data, you should override this method so it is consistent - with the output of {meth}`_apply_hf_processor_text_only` on the + with the output of + [`_apply_hf_processor_text_only`][vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_text_only] + on the corresponding text. """ return prompt_tokens @@ -1322,7 +1352,8 @@ def _apply_hf_processor_mm_only( Since HF processor requires that text and multi-modal items correspond to each other, we generate dummy text using - {class}`DummyInputsBuilder` to go along with the multi-modal data. + [`DummyInputsBuilder`][vllm.multimodal.profiling.BaseDummyInputsBuilder] + to go along with the multi-modal data. 
""" mm_counts = mm_items.get_all_counts() diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index b5875124c126..a85b13fb2387 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -3,7 +3,7 @@ from abc import ABC from collections.abc import Mapping from dataclasses import dataclass, field -from typing import Generic, NamedTuple, Optional, TypeVar, cast +from typing import Generic, NamedTuple, Optional, TypeVar, Union, cast import numpy as np import numpy.typing as npt @@ -25,9 +25,9 @@ class ProcessorInputs: """ Represents the keyword arguments to - {meth}`vllm.multimodal.processing.BaseMultiModalProcessor.apply`. + [`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][]. """ - prompt_text: str + prompt: Union[str, list[int]] mm_data: MultiModalDataDict hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict) @@ -75,7 +75,12 @@ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: "in an upcoming release.") seq_len = self.info.ctx.model_config.max_model_len - return self.get_dummy_processor_inputs(seq_len, mm_counts).prompt_text + + prompt = self.get_dummy_processor_inputs(seq_len, mm_counts).prompt + if not isinstance(prompt, str): + prompt = self.info.get_tokenizer().decode(prompt) + + return prompt # TODO: @abstractmethod after transition def get_dummy_mm_data( @@ -101,7 +106,7 @@ def get_dummy_processor_inputs( dummy_text = self.get_dummy_text(mm_counts) dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts) - return ProcessorInputs(prompt_text=dummy_text, mm_data=dummy_mm_data) + return ProcessorInputs(prompt=dummy_text, mm_data=dummy_mm_data) def _get_dummy_audios( self, @@ -177,7 +182,7 @@ def _get_dummy_mm_inputs( seq_len, mm_counts) return self.processor.apply( - prompt=processor_inputs.prompt_text, + prompt=processor_inputs.prompt, mm_data=processor_inputs.mm_data, hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, ) diff --git a/vllm/multimodal/registry.py 
b/vllm/multimodal/registry.py index 3e62f4c43e10..b9f5cee922a7 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -29,7 +29,11 @@ class ProcessingInfoFactory(Protocol[_I_co]): - """Constructs a {class}`MultiModalProcessor` instance from the context.""" + """ + Constructs a + [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor] + instance from the context. + """ def __call__( self, @@ -40,7 +44,9 @@ def __call__( class DummyInputsBuilderFactory(Protocol[_I]): """ - Constructs a {class}`BaseDummyInputsBuilder` instance from the context. + Constructs a + [`BaseDummyInputsBuilder`][vllm.multimodal.profiling.BaseDummyInputsBuilder] + instance from the context. """ def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]: @@ -48,7 +54,11 @@ def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]: class MultiModalProcessorFactory(Protocol[_I]): - """Constructs a {class}`MultiModalProcessor` instance from the context.""" + """ + Constructs a + [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor] + instance from the context. + """ def __call__( self, @@ -88,6 +98,12 @@ def __init__(self) -> None: self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_GIB) + def reset_processor_cache(self) -> bool: + """Reset the multi-modal processing cache.""" + self._processing_cache.reset() + + return True # Success + @deprecated("Legacy input processor/mapper pipeline has been removed. 
" "Please update your model runner to use " "`seq_group_metadata.multi_modal_data` directly without " @@ -106,7 +122,7 @@ def get_max_tokens_per_item_by_modality( if not model_config.is_multimodal_model: return {} - processor = self.create_processor(model_config, disable_cache=True) + processor = self.create_processor(model_config, disable_cache=False) profiler = MultiModalProfiler(processor) seq_len = model_config.max_model_len @@ -149,8 +165,6 @@ def get_max_tokens_by_modality( """ Get the maximum number of tokens from each modality for profiling the memory usage of a model. - - See {meth}`MultiModalPlugin.get_max_multimodal_tokens` for more details. """ mm_limits = self.get_mm_limits_per_prompt(model_config) @@ -164,8 +178,6 @@ def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: """ Get the maximum number of multi-modal tokens for profiling the memory usage of a model. - - See {meth}`MultiModalPlugin.get_max_multimodal_tokens` for more details. """ return sum(self.get_max_tokens_by_modality(model_config).values()) @@ -190,7 +202,7 @@ def get_mm_limits_per_prompt( if not model_config.is_multimodal_model: return {} - processor = self.create_processor(model_config, disable_cache=True) + processor = self.create_processor(model_config, disable_cache=False) profiler = MultiModalProfiler(processor) return profiler.get_mm_limits() @@ -207,10 +219,6 @@ def register_processor( When the model receives multi-modal data, the provided function is invoked to transform the data into a dictionary of model inputs. - - :::{seealso} - {ref}`mm-processing` - ::: """ def wrapper(model_cls: N) -> N: @@ -253,10 +261,6 @@ def create_processor( ) -> BaseMultiModalProcessor[BaseProcessingInfo]: """ Create a multi-modal processor for a specific model and tokenizer. 
- - :::{seealso} - {ref}`mm-processing` - ::: """ if not model_config.is_multimodal_model: raise ValueError(f"{model_config.model} is not a multimodal model") @@ -286,7 +290,7 @@ def get_decoder_dummy_data( The model is identified by ``model_config``. """ - processor = self.create_processor(model_config, disable_cache=True) + processor = self.create_processor(model_config, disable_cache=False) profiler = MultiModalProfiler(processor) dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts) @@ -310,7 +314,7 @@ def get_encoder_dummy_data( The model is identified by ``model_config``. """ - processor = self.create_processor(model_config, disable_cache=True) + processor = self.create_processor(model_config, disable_cache=False) profiler = MultiModalProfiler(processor) dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts) diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index aef5f669ac68..9ddba67bff70 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -259,7 +259,8 @@ def fetch_image_embedding( global_media_connector = MediaConnector() -"""The global {class}`MediaConnector` instance used by vLLM.""" +"""The global [`MediaConnector`][vllm.multimodal.utils.MediaConnector] +instance used by vLLM.""" fetch_audio = global_media_connector.fetch_audio fetch_image = global_media_connector.fetch_image diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index 72e9b65d763c..261d56abad9c 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import base64 +from abc import abstractmethod from functools import partial from io import BytesIO from pathlib import Path @@ -9,6 +10,8 @@ import numpy.typing as npt from PIL import Image +from vllm import envs + from .base import MediaIO from .image import ImageMediaIO @@ -48,10 +51,35 @@ def sample_frames_from_video(frames: npt.NDArray, class VideoLoader: @classmethod - def load_bytes(self, data: bytes, 
num_frames: int = -1) -> npt.NDArray: + @abstractmethod + def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray: raise NotImplementedError +class VideoLoaderRegistry: + + def __init__(self) -> None: + self.name2class: dict[str, type] = {} + + def register(self, name: str): + + def wrap(cls_to_register): + self.name2class[name] = cls_to_register + return cls_to_register + + return wrap + + @staticmethod + def load(cls_name: str) -> VideoLoader: + cls = VIDEO_LOADER_REGISTRY.name2class.get(cls_name) + assert cls is not None, f"VideoLoader class {cls_name} not found" + return cls() + + +VIDEO_LOADER_REGISTRY = VideoLoaderRegistry() + + +@VIDEO_LOADER_REGISTRY.register("opencv") class OpenCVVideoBackend(VideoLoader): def get_cv2_video_api(self): @@ -122,7 +150,8 @@ def __init__( self.image_io = image_io self.num_frames = num_frames - self.video_loader = OpenCVVideoBackend + video_loader_backend = envs.VLLM_VIDEO_LOADER_BACKEND + self.video_loader = VIDEO_LOADER_REGISTRY.load(video_loader_backend) def load_bytes(self, data: bytes) -> npt.NDArray: return self.video_loader.load_bytes(data, self.num_frames) @@ -135,7 +164,7 @@ def load_base64(self, media_type: str, data: str) -> npt.NDArray: ) return np.stack([ - np.array(load_frame(frame_data)) + np.asarray(load_frame(frame_data)) for frame_data in data.split(",") ]) diff --git a/vllm/outputs.py b/vllm/outputs.py index 65a6ed01451d..33cc50c872b6 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -4,17 +4,20 @@ from collections.abc import MutableSequence from collections.abc import Sequence as GenericSequence from dataclasses import dataclass -from typing import Generic, Optional, Union +from typing import Any, Generic, Optional, Union import torch from typing_extensions import TypeVar, deprecated +from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import MultiModalPlaceholderDict from vllm.sampling_params import RequestOutputKind from vllm.sequence 
import (PromptLogprobs, RequestMetrics, SampleLogprobs, SequenceGroup, SequenceGroupBase, SequenceStatus) +logger = init_logger(__name__) + @dataclass class CompletionOutput: @@ -103,6 +106,7 @@ class RequestOutput: encoder_prompt_token_ids: The token IDs of the encoder prompt. None if decoder-only. num_cached_tokens: The number of tokens with prefix cache hit. + kv_transfer_params: The params for remote K/V transfer. """ def __init__( @@ -120,7 +124,14 @@ def __init__( num_cached_tokens: Optional[int] = None, *, multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None, + kv_transfer_params: Optional[dict[str, Any]] = None, + # Forward compatibility, code that uses args added in new release can + # still run with older versions of vLLM without breaking. + **kwargs: Any, ) -> None: + if kwargs: + logger.warning_once("RequestOutput: Ignoring extra arguments: %s", + str(kwargs)) self.request_id = request_id self.prompt = prompt self.prompt_token_ids = prompt_token_ids @@ -133,11 +144,13 @@ def __init__( self.encoder_prompt = encoder_prompt self.encoder_prompt_token_ids = encoder_prompt_token_ids self.num_cached_tokens = num_cached_tokens + self.kv_transfer_params = kv_transfer_params def add(self, next_output: "RequestOutput", aggregate: bool) -> None: """Merge subsequent RequestOutput into this one""" self.finished |= next_output.finished + self.kv_transfer_params = next_output.kv_transfer_params for next_completion in next_output.outputs: for i, completion in enumerate(self.outputs): @@ -378,15 +391,6 @@ def from_seq_group(seq_group: SequenceGroup) -> "PoolingRequestOutput": prompt_token_ids, finished) def __repr__(self): - """ - Returns a string representation of an PoolingRequestOutput instance. - - The representation includes the request_id and the number of outputs, - providing a quick overview of the pooling request's results. - - Returns: - str: A string representation of the PoolingRequestOutput instance. 
- """ return (f"{type(self).__name__}(request_id={self.request_id!r}, " f"outputs={self.outputs!r}, " f"prompt_token_ids={self.prompt_token_ids}, " diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index b1df4fd1339b..00d00d05f47a 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -42,7 +42,6 @@ def tpu_platform_plugin() -> Optional[str]: logger.debug("Confirmed TPU platform is available.") except Exception as e: logger.debug("TPU platform is not available because: %s", str(e)) - pass return "vllm.platforms.tpu.TpuPlatform" if is_tpu else None @@ -112,7 +111,6 @@ def rocm_platform_plugin() -> Optional[str]: amdsmi.amdsmi_shut_down() except Exception as e: logger.debug("ROCm platform is not available because: %s", str(e)) - pass return "vllm.platforms.rocm.RocmPlatform" if is_rocm else None @@ -130,7 +128,6 @@ def hpu_platform_plugin() -> Optional[str]: "habana_frameworks is not found.") except Exception as e: logger.debug("HPU platform is not available because: %s", str(e)) - pass return "vllm.platforms.hpu.HpuPlatform" if is_hpu else None @@ -148,7 +145,6 @@ def xpu_platform_plugin() -> Optional[str]: logger.debug("Confirmed XPU platform is available.") except Exception as e: logger.debug("XPU platform is not available because: %s", str(e)) - pass return "vllm.platforms.xpu.XPUPlatform" if is_xpu else None @@ -170,7 +166,6 @@ def cpu_platform_plugin() -> Optional[str]: except Exception as e: logger.debug("CPU platform is not available because: %s", str(e)) - pass return "vllm.platforms.cpu.CpuPlatform" if is_cpu else None diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index e45522a4c407..c79c603c02eb 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -9,8 +9,9 @@ import torch from vllm.logger import init_logger +from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS -from .interface import Platform, PlatformEnum, _Backend +from .interface import CpuArchEnum, Platform, PlatformEnum, _Backend 
logger = init_logger(__name__) @@ -26,6 +27,20 @@ class CpuPlatform(Platform): device_type: str = "cpu" dispatch_key: str = "CPU" + @property + def supported_dtypes(self) -> list: + if self.get_cpu_architecture() == CpuArchEnum.POWERPC: + return [torch.bfloat16, torch.float32] + elif sys.platform.startswith( + "darwin") and self.get_cpu_architecture() == CpuArchEnum.ARM: + # TODO: change this condition to check if the platform support bf16 + # instead of checking the OS. For instance M2 shall supports bf16 + # already. But we need to modify `cpu_extension.cmake` to activate + # the feature in the build. + return [torch.float16, torch.float32] + # x86/aarch64 CPU has supported both bf16 and fp16 natively. + return [torch.bfloat16, torch.float16, torch.float32] + @classmethod def get_device_name(cls, device_id: int = 0) -> str: return "cpu" @@ -60,7 +75,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: import vllm.envs as envs from vllm.utils import GiB_bytes model_config = vllm_config.model_config - # Reminder: Please update docs/source/features/compatibility_matrix.md + # Reminder: Please update docs/features/compatibility_matrix.md # If the feature combo become valid if not model_config.enforce_eager: model_config.enforce_eager = True @@ -163,6 +178,16 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: " set VLLM_WORKER_MULTIPROC_METHOD to fork explicitly.") os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' + if vllm_config.model_config and vllm_config.model_config.use_mla: + logger.info( + "MLA is enabled on a non-GPU platform; forcing chunked " + "prefill and prefix caching to be disabled.") + vllm_config.scheduler_config.enable_chunked_prefill = False + vllm_config.scheduler_config.chunked_prefill_enabled = False + vllm_config.scheduler_config.max_num_batched_tokens = max( + vllm_config.scheduler_config.max_model_len, + DEFAULT_MAX_NUM_BATCHED_TOKENS) + @classmethod def is_pin_memory_available(cls) -> bool: 
logger.warning("Pin memory is not supported on CPU.") diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index ab03dece8c13..8bb3dfe7457a 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -5,8 +5,7 @@ import os from functools import wraps -from typing import (TYPE_CHECKING, Callable, List, Optional, Tuple, TypeVar, - Union) +from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union import torch from typing_extensions import ParamSpec @@ -34,24 +33,6 @@ torch.backends.cuda.enable_cudnn_sdp(False) -def device_id_to_physical_device_id(device_id: int) -> int: - if "CUDA_VISIBLE_DEVICES" in os.environ: - device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",") - if device_ids == [""]: - msg = ( - "CUDA_VISIBLE_DEVICES is set to empty string, which means" - " GPU support is disabled. If you are using ray, please unset" - " the environment variable `CUDA_VISIBLE_DEVICES` inside the" - " worker/actor. " - "Check https://github.com/vllm-project/vllm/issues/8402 for" - " more information.") - raise RuntimeError(msg) - physical_device_id = device_ids[device_id] - return int(physical_device_id) - else: - return device_id - - def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]: @wraps(fn) @@ -73,6 +54,19 @@ class CudaPlatformBase(Platform): ray_device_key: str = "GPU" device_control_env_var: str = "CUDA_VISIBLE_DEVICES" + @property + def supported_dtypes(self) -> list[torch.dtype]: + if self.has_device_capability(80): + # Ampere and Hopper or later NVIDIA GPUs. + return [torch.bfloat16, torch.float16, torch.float32] + elif (not self.has_device_capability(80) + ) and self.has_device_capability(60): + # Pascal, Volta and Turing NVIDIA GPUs, BF16 is not supported + return [torch.float16, torch.float32] + # Kepler and Maxwell NVIDIA GPUs, only FP32 is supported, + # though vLLM doesn't support these GPUs. 
+ return [torch.float32] + @classmethod def get_device_capability(cls, device_id: int = 0 @@ -98,7 +92,7 @@ def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: return True @classmethod - def is_fully_connected(cls, device_ids: List[int]) -> bool: + def is_fully_connected(cls, device_ids: list[int]) -> bool: raise NotImplementedError @classmethod @@ -164,6 +158,8 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: "currently not supported with CUDA Graphs.") vllm_config.model_config.enforce_eager = True compilation_config.use_cudagraph = False + # FIXME: inductor breaks cudagraph (from @bnell) + compilation_config.use_inductor = False @classmethod def get_current_memory_usage(cls, @@ -227,6 +223,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, elif selected_backend == _Backend.XFORMERS: logger.info("Using XFormers backend.") return "vllm.attention.backends.xformers.XFormersBackend" + elif selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN: + logger.info("Using DualChunkFlashAttention backend.") + return ("vllm.attention.backends.dual_chunk_flash_attn." 
+ "DualChunkFlashAttentionBackend") elif selected_backend == _Backend.FLASH_ATTN: pass elif selected_backend: @@ -312,6 +312,10 @@ def supports_v1(cls, model_config: "ModelConfig") -> bool: def use_custom_allreduce(cls) -> bool: return True + @classmethod + def get_piecewise_backend_cls(cls) -> str: + return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa + # NVML utils # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, @@ -325,7 +329,7 @@ def get_device_capability(cls, device_id: int = 0 ) -> Optional[DeviceCapability]: try: - physical_device_id = device_id_to_physical_device_id(device_id) + physical_device_id = cls.device_id_to_physical_device_id(device_id) handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id) major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle) return DeviceCapability(major=major, minor=minor) @@ -336,7 +340,7 @@ def get_device_capability(cls, @with_nvml_context def has_device_capability( cls, - capability: Union[Tuple[int, int], int], + capability: Union[tuple[int, int], int], device_id: int = 0, ) -> bool: try: @@ -347,26 +351,26 @@ def has_device_capability( @classmethod @with_nvml_context def get_device_name(cls, device_id: int = 0) -> str: - physical_device_id = device_id_to_physical_device_id(device_id) + physical_device_id = cls.device_id_to_physical_device_id(device_id) return cls._get_physical_device_name(physical_device_id) @classmethod @with_nvml_context def get_device_uuid(cls, device_id: int = 0) -> str: - physical_device_id = device_id_to_physical_device_id(device_id) + physical_device_id = cls.device_id_to_physical_device_id(device_id) handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id) return pynvml.nvmlDeviceGetUUID(handle) @classmethod @with_nvml_context def get_device_total_memory(cls, device_id: int = 0) -> int: - physical_device_id = device_id_to_physical_device_id(device_id) + physical_device_id = cls.device_id_to_physical_device_id(device_id) handle = 
pynvml.nvmlDeviceGetHandleByIndex(physical_device_id) return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total) @classmethod @with_nvml_context - def is_fully_connected(cls, physical_device_ids: List[int]) -> bool: + def is_fully_connected(cls, physical_device_ids: list[int]) -> bool: """ query if the set of gpus are fully connected by nvlink (1 hop) """ @@ -431,7 +435,7 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: return device_props.total_memory @classmethod - def is_fully_connected(cls, physical_device_ids: List[int]) -> bool: + def is_fully_connected(cls, physical_device_ids: list[int]) -> bool: logger.exception( "NVLink detection not possible, as context support was" " not found. Assuming no NVLink available.") diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py index 456b054b2b43..a8dd7df9f2e3 100644 --- a/vllm/platforms/hpu.py +++ b/vllm/platforms/hpu.py @@ -7,6 +7,7 @@ from vllm import envs from vllm.logger import init_logger +from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS from .interface import Platform, PlatformEnum, _Backend @@ -38,8 +39,8 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: return True - @staticmethod - def inference_mode(): + @classmethod + def inference_mode(cls): return torch.no_grad() @classmethod @@ -80,6 +81,16 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "VLLM_WORKER_MULTIPROC_METHOD=fork explicitly.") os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + if vllm_config.model_config and vllm_config.model_config.use_mla: + logger.info( + "MLA is enabled on a non-GPU platform; forcing chunked " + "prefill and prefix caching to be disabled.") + vllm_config.scheduler_config.enable_chunked_prefill = False + vllm_config.scheduler_config.chunked_prefill_enabled = False + vllm_config.scheduler_config.max_num_batched_tokens = max( + vllm_config.scheduler_config.max_model_len, + 
DEFAULT_MAX_NUM_BATCHED_TOKENS) + @classmethod def is_pin_memory_available(cls): logger.warning("Pin memory is not supported on HPU.") diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 531b13da0fa1..504c3b42a75d 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 import enum +import os import platform import random from platform import uname -from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union +from typing import TYPE_CHECKING, NamedTuple, Optional, Union import numpy as np import torch @@ -39,7 +40,8 @@ class _Backend(enum.Enum): TRITON_ATTN_VLLM_V1 = enum.auto() XFORMERS = enum.auto() ROCM_FLASH = enum.auto() - ROCM_AITER_MLA = enum.auto() + ROCM_AITER_MLA = enum.auto() # Supported by V1 + ROCM_AITER_MLA_VLLM_V1 = enum.auto() TORCH_SDPA = enum.auto() FLASHINFER = enum.auto() TRITON_MLA = enum.auto() # Supported by V1 @@ -49,6 +51,7 @@ class _Backend(enum.Enum): PALLAS_VLLM_V1 = enum.auto() IPEX = enum.auto() BLOCK_SPARSE_FLASH_ATTN = enum.auto() + DUAL_CHUNK_FLASH_ATTN = enum.auto() NO_ATTENTION = enum.auto() @@ -81,7 +84,7 @@ def as_version_str(self) -> str: def to_int(self) -> int: """ - Express device capability as an integer ``<major><minor>``. + Express device capability as an integer `<major><minor>`. It is assumed that the minor version is always a single digit. """ @@ -121,6 +124,14 @@ class Platform: additional_env_vars: list[str] = [] + @property + def supported_dtypes(self) -> list[torch.dtype]: + """Returns the supported dtypes for the current platform.""" + # Be careful with the order of the dtypes. The first dtype will + # be used as the default dtype fallback for the current platform, + # when encountering unsupported dtypes in "auto" dtype. 
+ return [torch.bfloat16, torch.float16, torch.float32] + def is_cuda(self) -> bool: return self._enum == PlatformEnum.CUDA @@ -146,12 +157,30 @@ def is_out_of_tree(self) -> bool: return self._enum == PlatformEnum.OOT def is_cuda_alike(self) -> bool: - """Stateless version of {func}`torch.cuda.is_available`.""" + """Stateless version of [torch.cuda.is_available][].""" return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM) def is_sleep_mode_available(self) -> bool: return self._enum == PlatformEnum.CUDA + @classmethod + def device_id_to_physical_device_id(cls, device_id: int): + if cls.device_control_env_var in os.environ: + device_ids = os.environ[cls.device_control_env_var].split(",") + if device_ids == [""]: + msg = (f"{cls.device_control_env_var} is set to empty string, " + "which means current platform support is disabled. If " + "you are using ray, please unset the environment " + f"variable `{cls.device_control_env_var}` inside the " + "worker/actor. Check " + "https://github.com/vllm-project/vllm/issues/8402 for " + "more information.") + raise RuntimeError(msg) + physical_device_id = device_ids[device_id] + return int(physical_device_id) + else: + return device_id + @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], @@ -165,22 +194,23 @@ def get_device_capability( cls, device_id: int = 0, ) -> Optional[DeviceCapability]: - """Stateless version of {func}`torch.cuda.get_device_capability`.""" + """Stateless version of [torch.cuda.get_device_capability][].""" return None @classmethod def has_device_capability( cls, - capability: Union[Tuple[int, int], int], + capability: Union[tuple[int, int], int], device_id: int = 0, ) -> bool: """ Test whether this platform is compatible with a device capability. - The ``capability`` argument can either be: + The `capability` argument can either be: - - A tuple ``(major, minor)``. - - An integer ``<major><minor>``. 
(See {meth}`DeviceCapability.to_int`) + - A tuple `(major, minor)`. + - An integer `<major><minor>`. (See + [`DeviceCapability.to_int`][vllm.platforms.interface.DeviceCapability.to_int]) """ current_capability = cls.get_device_capability(device_id=device_id) if current_capability is None: @@ -333,7 +363,7 @@ def get_punica_wrapper(cls) -> str: raise NotImplementedError @classmethod - def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]: + def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]: """ Return the platform specific values for (-inf, inf) """ @@ -449,6 +479,13 @@ def get_cu_count(cls, device_id: int = 0) -> int: """ raise NotImplementedError + @classmethod + def get_piecewise_backend_cls(cls) -> str: + """ + Get piecewise backend class for piecewise graph. + """ + return "vllm.compilation.base_piecewise_backend.AbstractPiecewiseBackend" # noqa + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py index 71f7c718cdf9..9cd49fd34804 100644 --- a/vllm/platforms/neuron.py +++ b/vllm/platforms/neuron.py @@ -6,6 +6,7 @@ from vllm import envs from vllm.logger import init_logger +from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS from .interface import Platform, PlatformEnum @@ -51,12 +52,21 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: assert (vllm_config.lora_config is None), "LoRA is not supported for Neuron backend." 
- cache_config = vllm_config.cache_config - if cache_config: + if vllm_config.cache_config and vllm_config.model_config: # neuron needs block_size = max_model_len vllm_config.cache_config.block_size = \ vllm_config.model_config.max_model_len # type: ignore + if vllm_config.model_config and vllm_config.model_config.use_mla: + logger.info( + "MLA is enabled on a non-GPU platform; forcing chunked " + "prefill and prefix caching to be disabled.") + vllm_config.scheduler_config.enable_chunked_prefill = False + vllm_config.scheduler_config.chunked_prefill_enabled = False + vllm_config.scheduler_config.max_num_batched_tokens = max( + vllm_config.scheduler_config.max_model_len, + DEFAULT_MAX_NUM_BATCHED_TOKENS) + @classmethod def is_pin_memory_available(cls) -> bool: logger.warning("Pin memory is not supported on Neuron.") diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 03b49e823535..e1dcd9870b6c 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -2,7 +2,7 @@ import os from functools import cache, lru_cache, wraps -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING, Optional import torch @@ -35,7 +35,7 @@ logger.warning("Failed to import from vllm._rocm_C with %r", e) # Models not supported by ROCm. -_ROCM_UNSUPPORTED_MODELS: List[str] = [] +_ROCM_UNSUPPORTED_MODELS: list[str] = [] # Models partially supported by ROCm. # Architecture -> Reason. @@ -43,7 +43,7 @@ "Triton flash attention. For half-precision SWA support, " "please use CK flash attention by setting " "`VLLM_USE_TRITON_FLASH_ATTN=0`") -_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = { +_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = { "Qwen2ForCausalLM": _ROCM_SWA_REASON, "MistralForCausalLM": @@ -58,7 +58,7 @@ "excessive use of shared memory. 
If this happens, disable Triton FA " "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`") } -_ROCM_DEVICE_ID_NAME_MAP: Dict[str, str] = { +_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = { "0x74a0": "AMD_Instinct_MI300A", "0x74a1": "AMD_Instinct_MI300X", "0x74b5": "AMD_Instinct_MI300X", # MI300X VF @@ -95,15 +95,6 @@ def wrapper(*args, **kwargs): return wrapper -def device_id_to_physical_device_id(device_id: int) -> int: - if "CUDA_VISIBLE_DEVICES" in os.environ: - device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",") - physical_device_id = device_ids[device_id] - return int(physical_device_id) - else: - return device_id - - @cache def on_mi250_mi300() -> bool: GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName @@ -111,26 +102,42 @@ def on_mi250_mi300() -> bool: @cache -def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, - block_size: int, gqa_ratio: int, - max_seq_len: int, - sliding_window: int) -> bool: +def use_rocm_custom_paged_attention( + qtype: torch.dtype, + head_size: int, + block_size: int, + gqa_ratio: int, + max_seq_len: int, + sliding_window: int, + kv_cache_dtype: str, + alibi_slopes: Optional[torch.Tensor] = None) -> bool: GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"]) + ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"]) - # rocm custom page attention not support on gfx1* # custom paged attn always supported on V0. On V1, requires sliding window # disabled due to observed numerical discrepancy. 
- return (ON_GFX9 and (not envs.VLLM_USE_V1 or sliding_window == 0 - or sliding_window == (-1, -1)) - and (qtype == torch.half or qtype == torch.bfloat16) - and (head_size == 64 or head_size == 128) - and (block_size == 16 or block_size == 32) - and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768 - and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) - and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN - and envs.VLLM_ROCM_USE_AITER)) + if ON_GFX9: + return ((not envs.VLLM_USE_V1 or sliding_window == 0 + or sliding_window == (-1, -1)) + and (qtype == torch.half or qtype == torch.bfloat16) + and (head_size == 64 or head_size == 128) + and (block_size == 16 or block_size == 32) + and (gqa_ratio >= 1 and gqa_ratio <= 16) + and max_seq_len <= 32768 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) + and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN + and envs.VLLM_ROCM_USE_AITER)) + + else: + return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0 + or sliding_window == (-1, -1)) + and (qtype == torch.half or qtype == torch.bfloat16) + and head_size == 128 and block_size == 16 + and (gqa_ratio >= 3 and gqa_ratio <= 16) + and max_seq_len <= 32768 and alibi_slopes is None + and kv_cache_dtype == "auto" + and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) class RocmPlatform(Platform): @@ -168,10 +175,15 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, raise ValueError( f" The selected backend, {selected_backend.name}," f"does not support block size {block_size}.") - elif selected_backend == _Backend.ROCM_AITER_MLA: + elif selected_backend == _Backend.ROCM_AITER_MLA \ + or selected_backend == _Backend.ROCM_AITER_MLA_VLLM_V1: if block_size == 1: - logger.info("Using AITER MLA backend.") - return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend" # noqa: E501 + if use_v1: + logger.info("Using AITER MLA backend on V1 engine.") + return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501 + else: + logger.info("Using AITER MLA backend") + return 
"vllm.attention.backends.rocm_aiter_mla.AiterMLABackend" # noqa: E501 else: raise ValueError( f" The selected backend, {selected_backend.name}," @@ -205,9 +217,9 @@ def get_device_capability(cls, major, minor = torch.cuda.get_device_capability(device_id) return DeviceCapability(major=major, minor=minor) - @staticmethod + @classmethod @with_amdsmi_context - def is_fully_connected(physical_device_ids: List[int]) -> bool: + def is_fully_connected(cls, physical_device_ids: list[int]) -> bool: """ Query if the set of gpus are fully connected by xgmi (1 hop) """ @@ -233,7 +245,7 @@ def is_fully_connected(physical_device_ids: List[int]) -> bool: @with_amdsmi_context @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: - physical_device_id = device_id_to_physical_device_id(device_id) + physical_device_id = cls.device_id_to_physical_device_id(device_id) handle = amdsmi_get_processor_handles()[physical_device_id] asic_info = amdsmi_get_gpu_asic_info(handle) device_name: str = asic_info["device_id"] @@ -366,3 +378,11 @@ def use_custom_allreduce(cls) -> bool: def get_cu_count(cls, device_id: int = 0) -> int: return torch.cuda.get_device_properties( device_id).multi_processor_count + + @classmethod + def is_navi(cls) -> bool: + return 'gfx1' in torch.cuda.get_device_properties(0).gcnArchName + + @classmethod + def get_piecewise_backend_cls(cls) -> str: + return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 2782a3866d76..0173b15697cf 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import TYPE_CHECKING, Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Union, cast import torch from tpu_info import device @@ -9,13 +9,15 @@ from vllm.inputs import ProcessorInputs, PromptType from vllm.logger import init_logger from vllm.sampling_params import SamplingParams, SamplingType 
+from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS from .interface import Platform, PlatformEnum, _Backend if TYPE_CHECKING: - from vllm.config import ModelConfig, VllmConfig + from vllm.config import BlockSize, ModelConfig, VllmConfig from vllm.pooling_params import PoolingParams else: + BlockSize = None ModelConfig = None VllmConfig = None PoolingParams = None @@ -72,7 +74,7 @@ def get_punica_wrapper(cls) -> str: return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU" @classmethod - def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]: + def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]: return torch.finfo(dtype).min, torch.finfo(dtype).max @classmethod @@ -94,7 +96,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: cache_config = vllm_config.cache_config # For v0, the default block size is 16. if cache_config and cache_config.block_size is None: - cache_config.block_size = 16 + cache_config.block_size = cast(BlockSize, 16) compilation_config = vllm_config.compilation_config # TPU only supports DYNAMO_ONCE compilation level @@ -118,7 +120,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: from vllm.v1.attention.backends.pallas import ( PallasAttentionBackend) cache_config.block_size = PallasAttentionBackend.get_page_size( - vllm_config) + vllm_config) # type: ignore[assignment] min_page_size = PallasAttentionBackend.get_min_page_size( vllm_config) if min_page_size > cache_config.block_size: @@ -128,7 +130,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: cache_config.block_size, min_page_size, ) - cache_config.block_size = min_page_size + cache_config.block_size = min_page_size # type: ignore[assignment] parallel_config = vllm_config.parallel_config scheduler_config = vllm_config.scheduler_config @@ -160,6 +162,16 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "Forcing --disable_chunked_mm_input.") 
scheduler_config.disable_chunked_mm_input = True + if vllm_config.model_config and vllm_config.model_config.use_mla: + logger.info( + "MLA is enabled on a non-GPU platform; forcing chunked " + "prefill and prefix caching to be disabled.") + vllm_config.scheduler_config.enable_chunked_prefill = False + vllm_config.scheduler_config.chunked_prefill_enabled = False + vllm_config.scheduler_config.max_num_batched_tokens = max( + vllm_config.scheduler_config.max_model_len, + DEFAULT_MAX_NUM_BATCHED_TOKENS) + @classmethod def is_pin_memory_available(cls): logger.warning("Pin memory is not supported on TPU.") @@ -193,3 +205,11 @@ def validate_request( if params.sampling_type == SamplingType.RANDOM_SEED: raise ValueError( "Torch XLA does not support per-request seed.") + + +try: + from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform + TpuPlatform = TpuCommonsPlatform # type: ignore +except ImportError: + logger.info("tpu_commons not found, using vLLM's TpuPlatform") + pass diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 225e756cd7ce..b2a6ad5d77db 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -5,6 +5,7 @@ import torch from vllm.logger import init_logger +from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS from .interface import DeviceCapability, Platform, PlatformEnum, _Backend @@ -36,15 +37,17 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, logger.info("Using IPEX attention backend.") return "vllm.attention.backends.ipex_attn.IpexAttnBackend" - @staticmethod + @classmethod def get_device_capability( - device_id: int = 0) -> Optional[DeviceCapability]: + cls, + device_id: int = 0, + ) -> Optional[DeviceCapability]: # capacity format differs from cuda's and will cause unexpected # failure, so use None directly return None - @staticmethod - def get_device_name(device_id: int = 0) -> str: + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: return 
torch.xpu.get_device_name(device_id) @classmethod @@ -56,8 +59,8 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: return True - @staticmethod - def inference_mode(): + @classmethod + def inference_mode(cls): return torch.no_grad() @classmethod @@ -113,6 +116,16 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: parallel_config.distributed_executor_backend) parallel_config.distributed_executor_backend = "ray" + if vllm_config.model_config and vllm_config.model_config.use_mla: + logger.info( + "MLA is enabled on a non-GPU platform; forcing chunked " + "prefill and prefix caching to be disabled.") + vllm_config.scheduler_config.enable_chunked_prefill = False + vllm_config.scheduler_config.chunked_prefill_enabled = False + vllm_config.scheduler_config.max_num_batched_tokens = max( + vllm_config.scheduler_config.max_model_len, + DEFAULT_MAX_NUM_BATCHED_TOKENS) + @classmethod def is_pin_memory_available(cls): logger.warning("Pin memory is not supported on XPU.") diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 389cb8728103..2884cb46fecd 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -2,7 +2,7 @@ import logging import os -from typing import Callable, Dict +from typing import Any, Callable import torch @@ -14,7 +14,7 @@ plugins_loaded = False -def load_plugins_by_group(group: str) -> Dict[str, Callable]: +def load_plugins_by_group(group: str) -> dict[str, Callable[[], Any]]: import sys if sys.version_info < (3, 10): from importlib_metadata import entry_points @@ -27,23 +27,27 @@ def load_plugins_by_group(group: str) -> Dict[str, Callable]: if len(discovered_plugins) == 0: logger.debug("No plugins for group %s found.", group) return {} + logger.info("Available plugins for group %s:", group) for plugin in discovered_plugins: - logger.info("name=%s, value=%s", plugin.name, plugin.value) + logger.info("- %s -> %s", 
plugin.name, plugin.value) + if allowed_plugins is None: - logger.info("all available plugins for group %s will be loaded.", - group) - logger.info("set environment variable VLLM_PLUGINS to control" - " which plugins to load.") - plugins = {} + logger.info("All plugins in this group will be loaded. " + "Set `VLLM_PLUGINS` to control which plugins to load.") + + plugins = dict[str, Callable[[], Any]]() for plugin in discovered_plugins: if allowed_plugins is None or plugin.name in allowed_plugins: + if allowed_plugins is not None: + logger.info("Loading plugin %s", plugin.name) + try: func = plugin.load() plugins[plugin.name] = func - logger.info("plugin %s loaded.", plugin.name) except Exception: logger.exception("Failed to load plugin %s", plugin.name) + return plugins diff --git a/vllm/plugins/lora_resolvers/README.md b/vllm/plugins/lora_resolvers/README.md new file mode 100644 index 000000000000..7e7c55f5c69c --- /dev/null +++ b/vllm/plugins/lora_resolvers/README.md @@ -0,0 +1,15 @@ +# LoRA Resolver Plugins + +This directory contains vLLM general plugins for dynamically discovering and loading LoRA adapters +via the LoRAResolver plugin framework. + +Note that `VLLM_ALLOW_RUNTIME_LORA_UPDATING` must be set to true to allow LoRA resolver plugins +to work, and `VLLM_PLUGINS` must be set to include the desired resolver plugins. + +# lora_filesystem_resolver +This LoRA Resolver is installed with vLLM by default. +To use, set `VLLM_PLUGIN_LORA_CACHE_DIR` to a local directory. When vLLM receives a request +for a LoRA adapter `foobar` it doesn't currently recognize, it will look in that local directory +for a subdirectory `foobar` containing a LoRA adapter. If such an adapter exists, it will +load that adapter, and then service the request as normal. That adapter will then be available +for future requests as normal. 
diff --git a/vllm/plugins/lora_resolvers/__init__.py b/vllm/plugins/lora_resolvers/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/plugins/lora_resolvers/filesystem_resolver.py b/vllm/plugins/lora_resolvers/filesystem_resolver.py new file mode 100644 index 000000000000..219231f77785 --- /dev/null +++ b/vllm/plugins/lora_resolvers/filesystem_resolver.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: Apache-2.0 +import json +import os +from typing import Optional + +import vllm.envs as envs +from vllm.lora.request import LoRARequest +from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry + + +class FilesystemResolver(LoRAResolver): + + def __init__(self, lora_cache_dir: str): + self.lora_cache_dir = lora_cache_dir + + async def resolve_lora(self, base_model_name: str, + lora_name: str) -> Optional[LoRARequest]: + lora_path = os.path.join(self.lora_cache_dir, lora_name) + if os.path.exists(lora_path): + adapter_config_path = os.path.join(self.lora_cache_dir, lora_name, + "adapter_config.json") + if os.path.exists(adapter_config_path): + with open(adapter_config_path) as file: + adapter_config = json.load(file) + if adapter_config["peft_type"] == "LORA" and adapter_config[ + "base_model_name_or_path"] == base_model_name: + lora_request = LoRARequest(lora_name=lora_name, + lora_int_id=abs( + hash(lora_name)), + lora_path=lora_path) + return lora_request + return None + + +def register_filesystem_resolver(): + """Register the filesystem LoRA Resolver with vLLM""" + + lora_cache_dir = envs.VLLM_LORA_RESOLVER_CACHE_DIR + if lora_cache_dir: + if not os.path.exists(lora_cache_dir) or not os.path.isdir( + lora_cache_dir): + raise ValueError( + "VLLM_LORA_RESOLVER_CACHE_DIR must be set to a valid directory \ + for Filesystem Resolver plugin to function") + fs_resolver = FilesystemResolver(lora_cache_dir) + LoRAResolverRegistry.register_resolver("Filesystem Resolver", + fs_resolver) + + return diff --git 
a/vllm/profiler/layerwise_profile.py b/vllm/profiler/layerwise_profile.py index 6351ef63da2b..6934d328a87e 100644 --- a/vllm/profiler/layerwise_profile.py +++ b/vllm/profiler/layerwise_profile.py @@ -3,7 +3,7 @@ import copy from collections import defaultdict from dataclasses import asdict, dataclass, field -from typing import Any, Callable, Dict, List, Optional, Tuple, TypeAlias, Union +from typing import Any, Callable, Optional, TypeAlias, Union import pandas as pd from torch._C._autograd import DeviceType, _KinetoEvent, _ProfilerResult @@ -20,7 +20,7 @@ class _ModuleTreeNode: event: _ProfilerEvent parent: Optional['_ModuleTreeNode'] = None - children: List['_ModuleTreeNode'] = field(default_factory=list) + children: list['_ModuleTreeNode'] = field(default_factory=list) trace: str = "" @property @@ -60,19 +60,19 @@ class ModelStatsEntry: @dataclass class _StatsTreeNode: entry: StatsEntry - children: List[StatsEntry] + children: list[StatsEntry] parent: Optional[StatsEntry] @dataclass class LayerwiseProfileResults(profile): _kineto_results: _ProfilerResult - _kineto_event_correlation_map: Dict[int, - List[_KinetoEvent]] = field(init=False) - _event_correlation_map: Dict[int, List[FunctionEvent]] = field(init=False) - _module_tree: List[_ModuleTreeNode] = field(init=False) - _model_stats_tree: List[_StatsTreeNode] = field(init=False) - _summary_stats_tree: List[_StatsTreeNode] = field(init=False) + _kineto_event_correlation_map: dict[int, + list[_KinetoEvent]] = field(init=False) + _event_correlation_map: dict[int, list[FunctionEvent]] = field(init=False) + _module_tree: list[_ModuleTreeNode] = field(init=False) + _model_stats_tree: list[_StatsTreeNode] = field(init=False) + _summary_stats_tree: list[_StatsTreeNode] = field(init=False) # profile metadata num_running_seqs: Optional[int] = None @@ -82,7 +82,7 @@ def __post_init__(self): self._build_module_tree() self._build_stats_trees() - def print_model_table(self, column_widths: Dict[str, int] = None): + def 
print_model_table(self, column_widths: dict[str, int] = None): _column_widths = dict(name=60, cpu_time_us=12, cuda_time_us=12, @@ -100,7 +100,7 @@ def print_model_table(self, column_widths: Dict[str, int] = None): filtered_model_table, indent_style=lambda indent: "|" + "-" * indent + " ")) - def print_summary_table(self, column_widths: Dict[str, int] = None): + def print_summary_table(self, column_widths: dict[str, int] = None): _column_widths = dict(name=80, cuda_time_us=12, pct_cuda_time=12, @@ -142,7 +142,7 @@ def convert_stats_to_dict(self) -> dict[str, Any]: } @staticmethod - def _indent_row_names_based_on_depth(depths_rows: List[Tuple[int, + def _indent_row_names_based_on_depth(depths_rows: list[tuple[int, StatsEntry]], indent_style: Union[Callable[[int], str], @@ -229,7 +229,7 @@ def _total_cuda_time(self): [self._cumulative_cuda_time(root) for root in self._module_tree]) def _build_stats_trees(self): - summary_dict: Dict[str, _StatsTreeNode] = {} + summary_dict: dict[str, _StatsTreeNode] = {} total_cuda_time = self._total_cuda_time() def pct_cuda_time(cuda_time_us): @@ -238,7 +238,7 @@ def pct_cuda_time(cuda_time_us): def build_summary_stats_tree_df( node: _ModuleTreeNode, parent: Optional[_StatsTreeNode] = None, - summary_trace: Tuple[str] = ()): + summary_trace: tuple[str] = ()): if event_has_module(node.event): name = event_module_repr(node.event) @@ -313,8 +313,8 @@ def build_model_stats_tree_df(node: _ModuleTreeNode, self._model_stats_tree.append(build_model_stats_tree_df(root)) def _flatten_stats_tree( - self, tree: List[_StatsTreeNode]) -> List[Tuple[int, StatsEntry]]: - entries: List[Tuple[int, StatsEntry]] = [] + self, tree: list[_StatsTreeNode]) -> list[tuple[int, StatsEntry]]: + entries: list[tuple[int, StatsEntry]] = [] def df_traversal(node: _StatsTreeNode, depth=0): entries.append((depth, node.entry)) @@ -327,10 +327,10 @@ def df_traversal(node: _StatsTreeNode, depth=0): return entries def _convert_stats_tree_to_dict(self, - tree: 
List[_StatsTreeNode]) -> List[Dict]: - root_dicts: List[Dict] = [] + tree: list[_StatsTreeNode]) -> list[dict]: + root_dicts: list[dict] = [] - def df_traversal(node: _StatsTreeNode, curr_json_list: List[Dict]): + def df_traversal(node: _StatsTreeNode, curr_json_list: list[dict]): curr_json_list.append({ "entry": asdict(node.entry), "children": [] diff --git a/vllm/profiler/utils.py b/vllm/profiler/utils.py index 62b39f510703..b26fd4dd8c07 100644 --- a/vllm/profiler/utils.py +++ b/vllm/profiler/utils.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import dataclasses -from typing import Callable, Dict, List, Type, Union +from typing import Callable, Union from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata @@ -30,14 +30,14 @@ def trim_string_back(string, width): class TablePrinter: - def __init__(self, row_cls: Type[dataclasses.dataclass], - column_widths: Dict[str, int]): + def __init__(self, row_cls: type[dataclasses.dataclass], + column_widths: dict[str, int]): self.row_cls = row_cls self.fieldnames = [x.name for x in dataclasses.fields(row_cls)] self.column_widths = column_widths assert set(self.column_widths.keys()) == set(self.fieldnames) - def print_table(self, rows: List[dataclasses.dataclass]): + def print_table(self, rows: list[dataclasses.dataclass]): self._print_header() self._print_line() for row in rows: diff --git a/vllm/reasoning/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py index 454167a0dc95..9dd5191da918 100644 --- a/vllm/reasoning/abs_reasoning_parsers.py +++ b/vllm/reasoning/abs_reasoning_parsers.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import os from abc import abstractmethod from collections.abc import Sequence @@ -33,7 +35,7 @@ def vocab(self) -> dict[str, int]: return self.model_tokenizer.get_vocab() @abstractmethod - def is_reasoning_end(self, input_ids: list[int]) -> bool: + def is_reasoning_end(self, input_ids: Sequence[int]) -> 
bool: """ Check if the reasoning content ends in the input_ids. @@ -106,7 +108,7 @@ class ReasoningParserManager: reasoning_parsers: dict[str, type] = {} @classmethod - def get_reasoning_parser(cls, name) -> type: + def get_reasoning_parser(cls, name: str | None) -> type[ReasoningParser]: """ Get reasoning parser by name which is registered by `register_module`. diff --git a/vllm/reasoning/granite_reasoning_parser.py b/vllm/reasoning/granite_reasoning_parser.py index 0dae02d33fec..07a63e294df4 100644 --- a/vllm/reasoning/granite_reasoning_parser.py +++ b/vllm/reasoning/granite_reasoning_parser.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 -import re from collections.abc import Sequence from typing import Optional, Union +import regex as re from transformers import PreTrainedTokenizerBase from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index affc5c64b941..dc38daa388ce 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -149,7 +149,7 @@ class SamplingParams( top_p: Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to 1 to consider all tokens. top_k: Integer that controls the number of top tokens to consider. Set - to -1 to consider all tokens. + to 0 (or -1) to consider all tokens. min_p: Float that represents the minimum probability for a token to be considered, relative to the probability of the most likely token. Must be in [0, 1]. Set to 0 to disable this. 
@@ -209,7 +209,7 @@ class SamplingParams( repetition_penalty: float = 1.0 temperature: float = 1.0 top_p: float = 1.0 - top_k: int = -1 + top_k: int = 0 min_p: float = 0.0 seed: Optional[int] = None stop: Optional[Union[str, list[str]]] = None @@ -256,7 +256,7 @@ def from_optional( repetition_penalty: Optional[float] = 1.0, temperature: Optional[float] = 1.0, top_p: Optional[float] = 1.0, - top_k: int = -1, + top_k: int = 0, min_p: float = 0.0, seed: Optional[int] = None, stop: Optional[Union[str, list[str]]] = None, @@ -376,7 +376,7 @@ def __post_init__(self) -> None: if self.temperature < _SAMPLING_EPS: # Zero temperature means greedy sampling. self.top_p = 1.0 - self.top_k = -1 + self.top_k = 0 self.min_p = 0.0 self._verify_greedy_sampling() @@ -404,8 +404,9 @@ def _verify_args(self) -> None: f"temperature must be non-negative, got {self.temperature}.") if not 0.0 < self.top_p <= 1.0: raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.") - if self.top_k < -1 or self.top_k == 0: - raise ValueError(f"top_k must be -1 (disable), or at least 1, " + # quietly accept -1 as disabled, but prefer 0 + if self.top_k < -1: + raise ValueError(f"top_k must be 0 (disable), or at least 1, " f"got {self.top_k}.") if not isinstance(self.top_k, int): raise TypeError( diff --git a/vllm/sequence.py b/vllm/sequence.py index 91f769d6dbd9..d359f897da25 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -27,7 +27,7 @@ def array_full(token_id: int, count: int): - """{class}`array` equivalent of {func}`numpy.full`.""" + """[`array`][] equivalent of [numpy.full][].""" return array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count @@ -112,12 +112,12 @@ class RequestMetrics: will include model forward, block/sync across workers, cpu-gpu sync time and sampling time. 
spec_token_acceptance_counts: number of accepted speculative tokens at - each position; the first token is from + each position; the first token is from the target model and is always accepted; - e.g., when it's [10, 8, 4, 2] for a req, + e.g., when it's [10, 8, 4, 2] for a req, it means there were 10 forward passes in - total, and there were 8, 4, 2 accepted - tokens at 1st, 2nd, 3rd speculation step. + total, and there were 8, 4, 2 accepted + tokens at 1st, 2nd, 3rd speculation step. """ arrival_time: float last_token_time: float @@ -192,8 +192,8 @@ class SequenceData(msgspec.Struct, def from_prompt_token_counts( *token_counts: tuple[int, int]) -> "SequenceData": """ - Construct a {class}`SequenceData` instance by concatenating - prompt token sequences. + Construct a [`SequenceData`][vllm.sequence.SequenceData] instance + by concatenating prompt token sequences. Each tuple represents one token sequence, expressed in the form `(token_id, count)`. @@ -216,8 +216,8 @@ def from_seqs( prompt_embeds: Optional[torch.Tensor] = None, ) -> "SequenceData": """ - Construct a {class}`SequenceData` instance from prompt and output - token sequences. + Construct a [`SequenceData`][vllm.sequence.SequenceData] instance + from prompt and output token sequences. """ prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids) @@ -452,9 +452,11 @@ def __repr__(self) -> str: class Sequence: """Stores the data, status, and block information of a sequence. - The sequence is constructed from the {data}`DecoderOnlyInputs` - (for decoder-only) or {data}`EncoderDecoderInputs` (for encoder-decoder) - instance passed in through the `inputs` constructor argument. + The sequence is constructed from the + [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] (for decoder-only) + or [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs] + (for encoder-decoder) instance passed in through the `inputs` + constructor argument. Args: seq_id: The ID of the sequence. 
@@ -714,9 +716,9 @@ class SequenceGroup: trace_headers: OpenTelemetry trace headers. prompt_adapter_request: Prompt Adapter request. priority: User-defined priority of the request. - draft_size: The number of speculative tokens plus one from the target + draft_size: The number of speculative tokens plus one from the target model; equal to max number of tokens a step can generate - for single-draft speculative decoding but larger than + for single-draft speculative decoding but larger than that for multi-draft SD (currently not supported). """ @@ -1123,7 +1125,7 @@ def __repr__(self) -> str: self.output_embed.shape if self.output_embed is not None else None return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " f"output_token={self.output_token}, " - f"output_embed.shape={output_embed_shape}" + f"output_embed.shape={output_embed_shape}, " f"logprobs={self.logprobs})") def __eq__(self, other: object) -> bool: @@ -1330,6 +1332,8 @@ def prune(self, # may be "paused" then "resumed" later. This should only prune sequences # which are confirmed to be aborted. seq_ids = get_all_seq_ids(seq_group_metadata_list) + # Only keep sequence IDs that exist in self._seq_ids + seq_ids = [seq_id for seq_id in seq_ids if seq_id in self._seq_ids] if seq_ids != self._seq_ids: # Batch contents changed - prune removed sequences. 
index = [self._seq_ids.index(seq_id) for seq_id in seq_ids] @@ -1492,7 +1496,7 @@ def add_request(request_id: str, engine, params, **kwargs): for i in range(original_params.n): request_id_i = f"{request_id}_parallel_sample_{i}" group.seq_id_to_index[request_id_i] = i - params = copy.deepcopy(original_params) + params = original_params.clone() params.n = 1 if params.seed is not None: params.seed += i diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index a6276c563394..991d2040a878 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -294,8 +294,11 @@ def execute_model( inputs_embeds=None, positions=model_input.input_positions, intermediate_tensors=intermediate_tensors, - **MultiModalKwargs.as_kwargs(multi_modal_kwargs, - device=self.device), + **MultiModalKwargs.as_kwargs( + multi_modal_kwargs, + dtype=self.model_runner.model_config.dtype, + device=self.device, + ), **model_execute_kwargs, ) diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py index 0bb8d602ec8f..4430da26c049 100644 --- a/vllm/spec_decode/metrics.py +++ b/vllm/spec_decode/metrics.py @@ -126,12 +126,12 @@ def _copy_rejsample_metrics_async(self) -> torch.cuda.Event: """Copy rejection/typical-acceptance sampling metrics (number of accepted tokens, etc) to CPU asynchronously. - Returns a CUDA event recording when the copy is complete. + Returns a device event recording when the copy is complete. 
""" assert self._copy_stream is not None - self._copy_stream.wait_stream(torch.cuda.current_stream()) + self._copy_stream.wait_stream(current_platform.current_stream()) - with torch.cuda.stream(self._copy_stream): + with current_platform.stream(self._copy_stream): self._aggregate_num_accepted_tokens.copy_( self.spec_decode_sampler.num_accepted_tokens, non_blocking=True) @@ -142,7 +142,7 @@ def _copy_rejsample_metrics_async(self) -> torch.cuda.Event: self._aggregate_num_draft_tokens = ( self.spec_decode_sampler.num_draft_tokens) - aggregate_metrics_ready = torch.cuda.Event() + aggregate_metrics_ready = current_platform.Event() aggregate_metrics_ready.record(self._copy_stream) return aggregate_metrics_ready diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 6ba5a51007b4..252c80957305 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -114,7 +114,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": return spec_decode_worker -# Reminder: Please update docs/source/features/compatibility_matrix.md +# Reminder: Please update docs/features/compatibility_matrix.md # If the feature combo become valid class SpecDecodeWorker(LoRANotSupportedWorkerBase): """Worker which implements speculative decoding. 
diff --git a/vllm/test_utils.py b/vllm/test_utils.py index 8611a25922bb..f8cec380f336 100644 --- a/vllm/test_utils.py +++ b/vllm/test_utils.py @@ -110,7 +110,7 @@ "royokong/e5-v", "sentence-transformers/all-roberta-large-v1", "sentence-transformers/stsb-roberta-base-v2", - "shanearora/OLMo-7B-1124-hf", + "allenai/OLMo-2-0425-1B", "shuyuej/Llama-3.2-1B-Instruct-GPTQ", "ssmits/Qwen2-7B-Instruct-embed-base", "stabilityai/stablelm-3b-4e1t", diff --git a/vllm/transformers_utils/__init__.py b/vllm/transformers_utils/__init__.py index 01d5bb4b5748..84bd7a747656 100644 --- a/vllm/transformers_utils/__init__.py +++ b/vllm/transformers_utils/__init__.py @@ -1,19 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 -from vllm.envs import VLLM_USE_MODELSCOPE +from vllm import envs -if VLLM_USE_MODELSCOPE: - # Patch here, before each import happens - import modelscope - from packaging import version +if envs.VLLM_USE_MODELSCOPE: + try: + # Patch here, before each import happens + import modelscope + from packaging import version - # patch_hub begins from modelscope>=1.18.1 - if version.parse(modelscope.__version__) <= version.parse('1.18.0'): - raise ImportError( - 'Using vLLM with ModelScope needs modelscope>=1.18.1, please ' - 'install by `pip install modelscope -U`') - - from modelscope.utils.hf_util import patch_hub + # patch_hub begins from modelscope>=1.18.1 + if version.parse(modelscope.__version__) <= version.parse('1.18.0'): + raise ImportError( + 'Using vLLM with ModelScope needs modelscope>=1.18.1, please ' + 'install by `pip install modelscope -U`') + from modelscope.utils.hf_util import patch_hub - # Patch hub to download models from modelscope to speed up. - patch_hub() + # Patch hub to download models from modelscope to speed up. 
+ patch_hub() + except ImportError as err: + raise ImportError( + "Please install modelscope>=1.18.1 via " + "`pip install modelscope>=1.18.1` to use ModelScope.") from err diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index f6c2b35535b6..69e7207cc350 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -6,7 +6,7 @@ import time from functools import cache from pathlib import Path -from typing import Any, Callable, Dict, Literal, Optional, Type, Union +from typing import Any, Callable, Literal, Optional, Union import huggingface_hub from huggingface_hub import hf_hub_download @@ -24,7 +24,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME -from vllm.envs import VLLM_USE_MODELSCOPE +from vllm import envs from vllm.logger import init_logger # yapf conflicts with isort for this block # yapf: disable @@ -45,21 +45,20 @@ from vllm.transformers_utils.utils import check_gguf_file from vllm.utils import resolve_obj_by_qualname -if VLLM_USE_MODELSCOPE: +if envs.VLLM_USE_MODELSCOPE: from modelscope import AutoConfig else: from transformers import AutoConfig MISTRAL_CONFIG_NAME = "params.json" -HF_TOKEN = os.getenv('HF_TOKEN', None) logger = init_logger(__name__) -_CONFIG_REGISTRY_OVERRIDE_HF: Dict[str, Type[PretrainedConfig]] = { +_CONFIG_REGISTRY_OVERRIDE_HF: dict[str, type[PretrainedConfig]] = { "mllama": MllamaConfig } -_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { +_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = { "chatglm": ChatGLMConfig, "cohere2": Cohere2Config, "dbrx": DbrxConfig, @@ -130,7 +129,7 @@ def lookup_files() -> list[str]: ] # if model is remote, use hf_hub api to list files try: - if VLLM_USE_MODELSCOPE: + if envs.VLLM_USE_MODELSCOPE: from vllm.transformers_utils.utils import ( modelscope_list_repo_files) return modelscope_list_repo_files(repo_id, @@ -185,7 +184,7 @@ def file_or_path_exists(model: Union[str, Path], 
config_name: str, return file_exists(str(model), config_name, revision=revision, - token=HF_TOKEN) + token=os.getenv('HF_TOKEN', None)) def patch_rope_scaling(config: PretrainedConfig) -> None: @@ -199,7 +198,7 @@ def patch_rope_scaling(config: PretrainedConfig) -> None: patch_rope_scaling_dict(rope_scaling) -def patch_rope_scaling_dict(rope_scaling: Dict[str, Any]) -> None: +def patch_rope_scaling_dict(rope_scaling: dict[str, Any]) -> None: if "rope_type" in rope_scaling and "type" in rope_scaling: rope_type = rope_scaling["rope_type"] rope_type_legacy = rope_scaling["type"] @@ -300,7 +299,10 @@ def get_config( " - For Hugging Face models: ensure the presence of a " "'config.json'.\n" " - For Mistral models: ensure the presence of a " - "'params.json'.\n").format(model=model) + "'params.json'.\n" + "3. For GGUF: pass the local path of the GGUF checkpoint.\n" + " Loading GGUF from a remote repo directly is not yet " + "supported.\n").format(model=model) raise ValueError(error_message) from e @@ -309,7 +311,7 @@ def get_config( model, revision=revision, code_revision=code_revision, - token=HF_TOKEN, + token=os.getenv('HF_TOKEN', None), **kwargs, ) @@ -321,7 +323,7 @@ def get_config( model, revision=revision, code_revision=code_revision, - token=HF_TOKEN, + token=os.getenv('HF_TOKEN', None), **kwargs, ) else: @@ -331,7 +333,7 @@ def get_config( trust_remote_code=trust_remote_code, revision=revision, code_revision=code_revision, - token=HF_TOKEN, + token=os.getenv('HF_TOKEN', None), **kwargs, ) except ValueError as e: @@ -349,7 +351,7 @@ def get_config( raise e elif config_format == ConfigFormat.MISTRAL: - config = load_params_config(model, revision, token=HF_TOKEN, **kwargs) + config = load_params_config(model, revision, **kwargs) else: supported_formats = [ fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO @@ -558,7 +560,7 @@ def get_sentence_transformer_tokenizer_config(model: str, # If model is on HuggingfaceHub, get the repo files repo_files = 
list_repo_files(model, revision=revision, - token=HF_TOKEN) + token=os.getenv('HF_TOKEN', None)) except Exception: repo_files = [] @@ -686,9 +688,24 @@ def recurse_elems(elem: Any): config_dict["hidden_act"] = config_dict.get("activation", "silu") config_dict["tie_word_embeddings"] = config_dict.get( "tie_embeddings", False) - config_dict["max_seq_len"] = config_dict.get("max_seq_len", 128_000) - config_dict["max_position_embeddings"] = config_dict.get( - "max_position_embeddings", 128_000) + + if config_dict.get("max_position_embeddings") is None: + max_position_embeddings = 128_000 + try: + trust_remote_code_val = kwargs.get("trust_remote_code", False) + hf_config = get_config(model=model, + trust_remote_code=trust_remote_code_val, + revision=revision, + config_format=ConfigFormat.HF) + if hf_value := hf_config.get_text_config().max_position_embeddings: + max_position_embeddings = hf_value + except Exception as e: + logger.warning( + "The params.json file is missing 'max_position_embeddings'" + " and could not get a value from the HF config." 
+ " Defaulting to 128000", + exc_info=e) + config_dict["max_position_embeddings"] = max_position_embeddings if config_dict.get("quantization") is not None: quantization = config_dict.get("quantization", {}) @@ -748,9 +765,9 @@ def get_hf_image_processor_config( hf_token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, **kwargs, -) -> Dict[str, Any]: +) -> dict[str, Any]: # ModelScope does not provide an interface for image_processor - if VLLM_USE_MODELSCOPE: + if envs.VLLM_USE_MODELSCOPE: return dict() # Separate model folder from file path for GGUF models if check_gguf_file(model): diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index db3efafeef96..ed10c22c84f0 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -23,7 +23,7 @@ from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.nemotron import NemotronConfig from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config -from vllm.transformers_utils.configs.ovis2 import OvisConfig +from vllm.transformers_utils.configs.ovis import OvisConfig from vllm.transformers_utils.configs.skyworkr1v import SkyworkR1VChatConfig from vllm.transformers_utils.configs.solar import SolarConfig from vllm.transformers_utils.configs.telechat2 import Telechat2Config diff --git a/vllm/transformers_utils/configs/arctic.py b/vllm/transformers_utils/configs/arctic.py index 5ab70c0e4136..2261f0a9e9aa 100644 --- a/vllm/transformers_utils/configs/arctic.py +++ b/vllm/transformers_utils/configs/arctic.py @@ -8,7 +8,7 @@ """ Arctic model configuration""" from dataclasses import asdict, dataclass -from typing import Any, Dict +from typing import Any from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging @@ -192,14 +192,14 @@ def __init__( ) @classmethod - def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> 
"ArcticConfig": + def from_dict(cls, config_dict: dict[str, Any], **kwargs) -> "ArcticConfig": result = super().from_dict(config_dict, **kwargs) config = result[0] if isinstance(result, tuple) else result if isinstance(config.quantization, dict): config.quantization = ArcticQuantizationConfig(**config.quantization) return result - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: ret = super().to_dict() if isinstance(ret["quantization"], ArcticQuantizationConfig): ret["quantization"] = asdict(ret["quantization"]) diff --git a/vllm/transformers_utils/configs/cohere2.py b/vllm/transformers_utils/configs/cohere2.py index e30409b3af5f..21328d7675b8 100644 --- a/vllm/transformers_utils/configs/cohere2.py +++ b/vllm/transformers_utils/configs/cohere2.py @@ -61,7 +61,7 @@ class Cohere2Config(PretrainedConfig): Whether to tie weight embeddings rope_theta (`float`, *optional*, defaults to 10000.0): The base period of the RoPE embeddings. - rope_scaling (`Dict`, *optional*): + rope_scaling (`dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value accordingly. @@ -86,11 +86,11 @@ class Cohere2Config(PretrainedConfig): `beta_slow` (`float`, *optional*): Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear ramp function. If unspecified, it defaults to 1. - `short_factor` (`List[float]`, *optional*): + `short_factor` (`list[float]`, *optional*): Only used with 'longrope'. The scaling factor to be applied to short contexts (< `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2 - `long_factor` (`List[float]`, *optional*): + `long_factor` (`list[float]`, *optional*): Only used with 'longrope'. 
The scaling factor to be applied to long contexts (< `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2 diff --git a/vllm/transformers_utils/configs/deepseek_vl2.py b/vllm/transformers_utils/configs/deepseek_vl2.py index 24d4052d8721..a54486fa41cd 100644 --- a/vllm/transformers_utils/configs/deepseek_vl2.py +++ b/vllm/transformers_utils/configs/deepseek_vl2.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py#L115-L268 -from typing import Tuple from transformers.configuration_utils import PretrainedConfig @@ -191,12 +190,12 @@ class DeepseekVLV2Config(PretrainedConfig): tile_tag: str = "2D" global_view_pos: str = "head" - candidate_resolutions: Tuple[Tuple[int, int]] = ((384, 384), ) + candidate_resolutions: tuple[tuple[int, int]] = ((384, 384), ) def __init__(self, tile_tag: str = "tile_tag", global_view_pos: str = "head", - candidate_resolutions: Tuple[Tuple[int, + candidate_resolutions: tuple[tuple[int, int]] = ((384, 384), ), **kwargs): super().__init__(**kwargs) diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py index 586d5c7f5e54..377523efefc3 100644 --- a/vllm/transformers_utils/configs/eagle.py +++ b/vllm/transformers_utils/configs/eagle.py @@ -52,13 +52,15 @@ def __init__(self, assert self.model is not None, \ "model should not be None when method is eagle" kwargs["architectures"] = [ - f"Eagle{arch}" for arch in self.model.architectures + f"Eagle{arch}" if not arch.startswith("Eagle") \ + else arch for arch in self.model.architectures ] elif method == "eagle3": assert self.model is not None, \ "model should not be None when method is eagle3" kwargs["architectures"] = [ - f"Eagle3{arch}" for arch in self.model.architectures + f"Eagle3{arch}" if not 
arch.startswith("Eagle3") \ + else arch for arch in self.model.architectures ] else: raise ValueError(f"Invalid method {method}. \ diff --git a/vllm/transformers_utils/configs/exaone.py b/vllm/transformers_utils/configs/exaone.py index 8181604191a1..25bafbb85d30 100644 --- a/vllm/transformers_utils/configs/exaone.py +++ b/vllm/transformers_utils/configs/exaone.py @@ -17,14 +17,12 @@ # limitations under the License. """Exaone model configuration""" -from typing import Dict - from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging logger = logging.get_logger(__name__) -EXAONE_PRETRAINED_CONFIG_ARCHIVE_MAP: Dict[str, str] = {} +EXAONE_PRETRAINED_CONFIG_ARCHIVE_MAP: dict[str, str] = {} class ExaoneConfig(PretrainedConfig): diff --git a/vllm/transformers_utils/configs/jais.py b/vllm/transformers_utils/configs/jais.py index be0f3b7e5e52..b947c6a9e2b4 100644 --- a/vllm/transformers_utils/configs/jais.py +++ b/vllm/transformers_utils/configs/jais.py @@ -98,7 +98,7 @@ class JAISConfig(PretrainedConfig): Scale attention weights by dividing by hidden_size instead of sqrt(hidden_size). Need to set scale_attn_weights to `True` as well. - alibi_scaling (`Dict`, *optional*): + alibi_scaling (`dict`, *optional*): Dictionary containing the scaling configuration for ALiBi embeddings. Currently only supports linear scaling strategy. Can specify either the scaling `factor` (must be @@ -108,7 +108,7 @@ class JAISConfig(PretrainedConfig): formats are `{"type": strategy name, "factor": scaling factor}` or `{"type": strategy name, "train_seq_len": training sequence length}`. - architectures (`List`, *optional*, defaults to ['JAISLMHeadModel']): + architectures (`list`, *optional*, defaults to ['JAISLMHeadModel']): architecture names for Jais. 
Example: diff --git a/vllm/transformers_utils/configs/mlp_speculator.py b/vllm/transformers_utils/configs/mlp_speculator.py index c761f659e5b2..70f60752905c 100644 --- a/vllm/transformers_utils/configs/mlp_speculator.py +++ b/vllm/transformers_utils/configs/mlp_speculator.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional +from typing import Optional from transformers import PretrainedConfig @@ -17,7 +17,7 @@ def __init__(self, emb_dim: int = 4096, inner_dim: int = 0, n_predict: int = 3, - top_k_tokens_per_head: Optional[List[int]] = None, + top_k_tokens_per_head: Optional[list[int]] = None, n_candidates: int = 5, tie_weights: bool = False, scale_input: bool = False, @@ -34,7 +34,7 @@ def __init__(self, the inner dimension of the model. If 0, will be the emb_dim. n_predict: int the number of lookaheads for the speculator - top_k_tokens_per_head: List[int] + top_k_tokens_per_head: list[int] Number of tokens to consider from each head when forming the candidate tree. 
For each candidate branch in the tree, head n produces topk[n] diff --git a/vllm/transformers_utils/configs/mpt.py b/vllm/transformers_utils/configs/mpt.py index 96356135f6b2..2d52658d3973 100644 --- a/vllm/transformers_utils/configs/mpt.py +++ b/vllm/transformers_utils/configs/mpt.py @@ -4,11 +4,11 @@ # https://huggingface.co/mosaicml/mpt-7b/blob/main/configuration_mpt.py """A HuggingFace-style model configuration.""" import warnings -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from transformers import PretrainedConfig -attn_config_defaults: Dict = { +attn_config_defaults: dict = { 'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', @@ -20,8 +20,8 @@ 'alibi': False, 'alibi_bias_max': 8 } -ffn_config_defaults: Dict = {'ffn_type': 'mptmlp'} -init_config_defaults: Dict = { +ffn_config_defaults: dict = {'ffn_type': 'mptmlp'} +init_config_defaults: dict = { 'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu', @@ -52,15 +52,15 @@ def __init__(self, resid_pdrop: float = 0.0, emb_pdrop: float = 0.0, learned_pos_emb: bool = True, - attn_config: Dict = attn_config_defaults, - ffn_config: Dict = ffn_config_defaults, + attn_config: dict = attn_config_defaults, + ffn_config: dict = ffn_config_defaults, init_device: str = 'cpu', logit_scale: Optional[Union[float, str]] = None, no_bias: bool = False, embedding_fraction: float = 1.0, norm_type: str = 'low_precision_layernorm', use_cache: bool = False, - init_config: Dict = init_config_defaults, + init_config: dict = init_config_defaults, fc_type: str = 'torch', verbose: Optional[int] = None, **kwargs: Any): @@ -102,8 +102,8 @@ def __init__(self, self._validate_config() def _set_config_defaults( - self, config: Dict[str, Any], - config_defaults: Dict[str, Any]) -> Dict[str, Any]: + self, config: dict[str, Any], + config_defaults: dict[str, Any]) -> dict[str, Any]: for (k, v) in config_defaults.items(): if k not in config: config[k] = v 
diff --git a/vllm/transformers_utils/configs/ovis2.py b/vllm/transformers_utils/configs/ovis.py similarity index 93% rename from vllm/transformers_utils/configs/ovis2.py rename to vllm/transformers_utils/configs/ovis.py index 437a16e778c2..0ec224214f06 100644 --- a/vllm/transformers_utils/configs/ovis2.py +++ b/vllm/transformers_utils/configs/ovis.py @@ -123,6 +123,19 @@ def __init__(self, **kwargs): self.backbone_kwargs['num_hidden_layers'] = self.depths[0] +class SiglipVisualTokenizerConfig(BaseVisualTokenizerConfig): + model_type = "siglip_visual_tokenizer" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if self.drop_cls_token: + self.drop_cls_token = False + if self.depths: + assert len(self.depths) == 1 + self.backbone_kwargs['num_hidden_layers'] = self.depths[0] + + +AutoConfig.register("siglip_visual_tokenizer", SiglipVisualTokenizerConfig) AutoConfig.register("aimv2_visual_tokenizer", Aimv2VisualTokenizerConfig) diff --git a/vllm/transformers_utils/configs/solar.py b/vllm/transformers_utils/configs/solar.py index 0d5db896b93d..6eaf699d17be 100644 --- a/vllm/transformers_utils/configs/solar.py +++ b/vllm/transformers_utils/configs/solar.py @@ -108,7 +108,7 @@ class SolarConfig(PretrainedConfig): Whether to tie weight embeddings rope_theta (`float`, *optional*, defaults to 10000.0): The base period of the RoPE embeddings. - rope_scaling (`Dict`, *optional*): + rope_scaling (`dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings. 
Currently supports two scaling diff --git a/vllm/transformers_utils/configs/ultravox.py b/vllm/transformers_utils/configs/ultravox.py index 6b2765db94e7..4c5072427263 100644 --- a/vllm/transformers_utils/configs/ultravox.py +++ b/vllm/transformers_utils/configs/ultravox.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_config.py -from typing import Any, Dict, Optional +from typing import Any, Optional import transformers @@ -48,8 +48,8 @@ class UltravoxConfig(transformers.PretrainedConfig): def __init__( self, - audio_config: Optional[Dict[str, Any]] = None, - text_config: Optional[Dict[str, Any]] = None, + audio_config: Optional[dict[str, Any]] = None, + text_config: Optional[dict[str, Any]] = None, audio_model_id: Optional[str] = None, text_model_id: Optional[str] = None, ignore_index: int = -100, @@ -58,8 +58,8 @@ def __init__( stack_factor: int = 8, norm_init: float = 0.4, projector_act: str = "swiglu", - text_model_lora_config: Optional[Dict[str, Any]] = None, - audio_model_lora_config: Optional[Dict[str, Any]] = None, + text_model_lora_config: Optional[dict[str, Any]] = None, + audio_model_lora_config: Optional[dict[str, Any]] = None, projector_ln_mid: bool = False, **kwargs, ): diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index 991d5631e64e..3adf2e32cca7 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, List, Optional +from typing import Optional from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams, Sequence, SequenceGroup) @@ -22,7 +22,7 @@ def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer: return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request) def decode_prompt_logprobs_inplace(self, seq_group: 
SequenceGroup, - prompt_logprobs: List[Optional[Dict[ + prompt_logprobs: list[Optional[dict[ int, Logprob]]], position_offset: int) -> None: """Decodes the logprobs for the prompt of a sequence group. @@ -49,7 +49,7 @@ def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup, read_offset = 0 next_iter_prefix_offset = 0 next_iter_read_offset = 0 - next_iter_tokens: List[str] = [] + next_iter_tokens: list[str] = [] prev_tokens = None for token_position_in_logprob, prompt_logprobs_for_token in enumerate( diff --git a/vllm/transformers_utils/detokenizer_utils.py b/vllm/transformers_utils/detokenizer_utils.py index a1fa27773fe5..7373fa0ede23 100644 --- a/vllm/transformers_utils/detokenizer_utils.py +++ b/vllm/transformers_utils/detokenizer_utils.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional, Tuple +from typing import Optional from .tokenizer import AnyTokenizer -def _replace_none_with_empty(tokens: List[Optional[str]]): +def _replace_none_with_empty(tokens: list[Optional[str]]): for i, token in enumerate(tokens): if token is None: tokens[i] = "" @@ -13,7 +13,7 @@ def _replace_none_with_empty(tokens: List[Optional[str]]): def _convert_tokens_to_string_with_added_encoders( tokenizer: AnyTokenizer, - output_tokens: List[str], + output_tokens: list[str], skip_special_tokens: bool, spaces_between_special_tokens: bool, ) -> str: @@ -22,8 +22,8 @@ def _convert_tokens_to_string_with_added_encoders( # NOTE(woosuk): The following code is slow because it runs a for loop over # the output_tokens. In Python, running a for loop over a list can be slow # even when the loop body is very simple. 
- sub_texts: List[str] = [] - current_sub_text: List[str] = [] + sub_texts: list[str] = [] + current_sub_text: list[str] = [] all_special_tokens = set(tokenizer.all_special_tokens) for token in output_tokens: if skip_special_tokens and token in all_special_tokens: @@ -52,9 +52,9 @@ def _convert_tokens_to_string_with_added_encoders( def convert_prompt_ids_to_tokens( tokenizer: AnyTokenizer, - prompt_ids: List[int], + prompt_ids: list[int], skip_special_tokens: bool = False, -) -> Tuple[List[str], int, int]: +) -> tuple[list[str], int, int]: """Converts the prompt ids to tokens and returns the tokens and offsets for incremental detokenization. @@ -76,8 +76,8 @@ def convert_prompt_ids_to_tokens( def convert_ids_list_to_tokens( tokenizer: AnyTokenizer, - token_ids: List[int], -) -> List[str]: + token_ids: list[int], +) -> list[str]: """Detokenize the input ids individually. Args: @@ -98,13 +98,13 @@ def convert_ids_list_to_tokens( # under Apache 2.0 license def detokenize_incrementally( tokenizer: AnyTokenizer, - all_input_ids: List[int], - prev_tokens: Optional[List[str]], + all_input_ids: list[int], + prev_tokens: Optional[list[str]], prefix_offset: int, read_offset: int, skip_special_tokens: bool = False, spaces_between_special_tokens: bool = True, -) -> Tuple[List[str], str, int, int]: +) -> tuple[list[str], str, int, int]: """Detokenizes the input ids incrementally and returns the new tokens and the new text. 
diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py index d27c26659b55..ce6427de432d 100644 --- a/vllm/transformers_utils/processor.py +++ b/vllm/transformers_utils/processor.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from functools import lru_cache -from typing import TYPE_CHECKING, Any, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union, cast from transformers.processing_utils import ProcessorMixin from typing_extensions import TypeVar @@ -54,6 +54,7 @@ def _merge_mm_kwargs(model_config: "ModelConfig", **kwargs): def get_processor( processor_name: str, *args: Any, + revision: Optional[str] = None, trust_remote_code: bool = False, processor_cls: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin, **kwargs: Any, @@ -70,6 +71,7 @@ def get_processor( processor = processor_factory.from_pretrained( processor_name, *args, + revision=revision, trust_remote_code=trust_remote_code, **kwargs, ) @@ -106,6 +108,7 @@ def cached_processor_from_config( ) -> _P: return cached_get_processor( model_config.model, + revision=model_config.revision, trust_remote_code=model_config.trust_remote_code, processor_cls=processor_cls, # type: ignore[arg-type] **_merge_mm_kwargs(model_config, **kwargs), @@ -115,6 +118,7 @@ def cached_processor_from_config( def get_feature_extractor( processor_name: str, *args: Any, + revision: Optional[str] = None, trust_remote_code: bool = False, **kwargs: Any, ): @@ -128,6 +132,7 @@ def get_feature_extractor( feature_extractor = AutoFeatureExtractor.from_pretrained( processor_name, *args, + revision=revision, trust_remote_code=trust_remote_code, **kwargs) except ValueError as e: @@ -156,6 +161,7 @@ def cached_feature_extractor_from_config( ): return cached_get_feature_extractor( model_config.model, + revision=model_config.revision, trust_remote_code=model_config.trust_remote_code, **_merge_mm_kwargs(model_config, **kwargs), ) @@ -164,6 +170,7 @@ def cached_feature_extractor_from_config( 
def get_image_processor( processor_name: str, *args: Any, + revision: Optional[str] = None, trust_remote_code: bool = False, **kwargs: Any, ): @@ -177,6 +184,7 @@ def get_image_processor( processor = AutoImageProcessor.from_pretrained( processor_name, *args, + revision=revision, trust_remote_code=trust_remote_code, **kwargs) except ValueError as e: @@ -206,6 +214,7 @@ def cached_image_processor_from_config( ): return cached_get_image_processor( model_config.model, + revision=model_config.revision, trust_remote_code=model_config.trust_remote_code, **_merge_mm_kwargs(model_config, **kwargs), ) diff --git a/vllm/transformers_utils/processors/__init__.py b/vllm/transformers_utils/processors/__init__.py index 2e9cf3e4d90b..2bd9ab1f099b 100644 --- a/vllm/transformers_utils/processors/__init__.py +++ b/vllm/transformers_utils/processors/__init__.py @@ -2,6 +2,6 @@ from vllm.transformers_utils.processors.deepseek_vl2 import ( DeepseekVLV2Processor) -from vllm.transformers_utils.processors.ovis2 import OvisProcessor +from vllm.transformers_utils.processors.ovis import OvisProcessor __all__ = ["DeepseekVLV2Processor", "OvisProcessor"] diff --git a/vllm/transformers_utils/processors/deepseek_vl2.py b/vllm/transformers_utils/processors/deepseek_vl2.py index 316281f2af4e..df960e9c7aa8 100644 --- a/vllm/transformers_utils/processors/deepseek_vl2.py +++ b/vllm/transformers_utils/processors/deepseek_vl2.py @@ -24,7 +24,6 @@ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
import math -from typing import List, Tuple import torch import torchvision.transforms as T @@ -36,8 +35,8 @@ class ImageTransform: def __init__(self, - mean: Tuple[float, float, float] = (0.5, 0.5, 0.5), - std: Tuple[float, float, float] = (0.5, 0.5, 0.5), + mean: tuple[float, float, float] = (0.5, 0.5, 0.5), + std: tuple[float, float, float] = (0.5, 0.5, 0.5), normalize: bool = True): self.mean = mean self.std = std @@ -62,11 +61,11 @@ class DeepseekVLV2Processor(ProcessorMixin): def __init__( self, tokenizer: LlamaTokenizerFast, - candidate_resolutions: Tuple[Tuple[int, int]], + candidate_resolutions: tuple[tuple[int, int]], patch_size: int, downsample_ratio: int, - image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5), - image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5), + image_mean: tuple[float, float, float] = (0.5, 0.5, 0.5), + image_std: tuple[float, float, float] = (0.5, 0.5, 0.5), normalize: bool = True, image_token: str = "<image>", pad_token: str = "<｜▁pad▁｜>", @@ -170,13 +169,13 @@ def encode(self, text: str, bos: bool = True, eos: bool = False): return t - def decode(self, t: List[int], **kwargs) -> str: + def decode(self, t: list[int], **kwargs) -> str: return self.tokenizer.decode(t, **kwargs) def process_one( self, prompt: str, - images: List[Image.Image], + images: list[Image.Image], inference_mode: bool = True, **kwargs, ): @@ -184,8 +183,8 @@ def process_one( Args: prompt (str): the formatted prompt; - conversations (List[Dict]): conversations with a list of messages; - images (List[ImageType]): the list of images; + conversations (list[dict]): conversations with a list of messages; + images (list[ImageType]): the list of images; inference_mode (bool): if True, then remove the last eos token; system_prompt (str): the system prompt; **kwargs: @@ -196,7 +195,7 @@ def process_one( - target_ids (torch.LongTensor): [N + image tokens] - pixel_values (torch.FloatTensor): [n_patches, 3, H, W] - image_id (int): the id of the image token
- - num_image_tokens (List[int]): the number of image tokens + - num_image_tokens (list[int]): the number of image tokens """ assert (prompt is not None and images is not None @@ -257,7 +256,7 @@ def __call__( self, *, prompt: str, - images: List[Image.Image], + images: list[Image.Image], inference_mode: bool = True, **kwargs, ): @@ -265,7 +264,7 @@ def __call__( Args: prompt (str): the formatted prompt; - images (List[ImageType]): the list of images; + images (list[ImageType]): the list of images; inference_mode (bool): if True, then remove the last eos token; **kwargs: @@ -274,7 +273,7 @@ def __call__( - input_ids (torch.LongTensor): [N + image tokens] - images (torch.FloatTensor): [n_images, 3, H, W] - image_id (int): the id of the image token - - num_image_tokens (List[int]): the number of image tokens + - num_image_tokens (list[int]): the number of image tokens """ prepare = self.process_one( @@ -288,7 +287,7 @@ def __call__( def tokenize_with_images( self, conversation: str, - images: List[Image.Image], + images: list[Image.Image], bos: bool = True, eos: bool = True, cropping: bool = True, diff --git a/vllm/transformers_utils/processors/ovis2.py b/vllm/transformers_utils/processors/ovis.py similarity index 91% rename from vllm/transformers_utils/processors/ovis2.py rename to vllm/transformers_utils/processors/ovis.py index a633256ec12c..f1c6407e1f3a 100644 --- a/vllm/transformers_utils/processors/ovis2.py +++ b/vllm/transformers_utils/processors/ovis.py @@ -22,7 +22,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List, Union +from functools import cached_property +from typing import Union import PIL import torch @@ -32,7 +33,9 @@ Unpack) from transformers.tokenization_utils_base import PreTokenizedInput, TextInput -__all__ = [ 'OvisProcessor'] +from vllm.multimodal.image import convert_image_mode + +__all__ = ['OvisProcessor'] IGNORE_ID = -100 class OvisProcessorKwargs(ProcessingKwargs, total=False): # type: ignore[call-arg] @@ -64,18 +67,29 @@ class OvisProcessor(ProcessorMixin): """ attributes = ["image_processor", "tokenizer"] - valid_kwargs = ["chat_template"] + valid_kwargs = ["chat_template", "image_pad_token", "image_segement_len"] image_processor_class = "AutoImageProcessor" - tokenizer_class = "Qwen2Tokenizer" + tokenizer_class = "AutoTokenizer" - def __init__(self, image_processor=None, tokenizer=None, chat_template=None, image_pad_token=None, **kwargs): + def __init__( + self, + image_processor=None, + tokenizer=None, + chat_template=None, + image_pad_token=None, + image_segment_len=255, + **kwargs, + ): self.image_token = "<image>" - self.image_pad_token = "<|image_pad|>" if image_pad_token is None else image_pad_token + self.image_pad_token = image_pad_token + self.image_segment_len = image_segment_len super().__init__(image_processor, tokenizer, chat_template=chat_template) - self.image_pad_token_id = self.tokenizer.get_vocab()[self.image_pad_token] - self.extra_special_tokens = { + @cached_property + def extra_special_tokens(self): + image_pad_token_id = self.tokenizer.get_vocab()[self.image_pad_token] + extra_special_tokens = { "image_token": -200, "image_atom": -300, "image_start": -301, @@ -83,13 +97,14 @@ def __init__(self, image_processor=None, tokenizer=None, chat_template=None, ima "image_col_sep": -303, "image_row_sep": -304, "image_end": -305, - 'image_pad': self.image_pad_token_id, + 'image_pad': image_pad_token_id, } + return extra_special_tokens def __call__( self, images: ImageInput = None, - text: Union[TextInput, 
PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, **kwargs: Unpack[OvisProcessorKwargs], ) -> BatchFeature: """ @@ -98,14 +113,14 @@ def __call__( the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`. Args: - images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`): The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch tensor. Both channels-first and channels-last formats are supported. - text (`str`, `List[str]`, `List[List[str]]`): + text (`str`, `list[str]`, `list[list[str]]`): The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). - videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): + videos (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`): The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. 
return_tensors (`str` or [`~utils.TensorType`], *optional*): @@ -224,8 +239,14 @@ def _tokenize_with_image_symbol(self, text_list: list[str]) -> torch.LongTensor: return torch.tensor(batch_token_ids, dtype=torch.long) def get_image_size(self): - height = self.image_processor.crop_size["height"] - width = self.image_processor.crop_size["width"] + size = self.image_processor.size + if 'shortest_edge' in size: + width = height = size['shortest_edge'] + elif "height" in size and "width" in size: + width = size['width'] + height = size['height'] + else: + raise ValueError( "Can't parse image size from image_processor config.") return height, width def get_token_value(self, tok): @@ -259,8 +280,7 @@ def construct_image_placeholders(self, grid): for token in image_placeholders: padded_placeholder_tokens.append(image_padding_token_id) if token == image_atom_token_id: - # Add 255 padding tokens after each image atom token - padded_placeholder_tokens.extend([image_padding_token_id] * 255) + padded_placeholder_tokens.extend([image_padding_token_id] * self.image_segment_len) return padded_placeholder_tokens def preprocess_image(self, image: PIL.Image.Image, max_partition, covering_threshold, convert_to_rgb, return_tensors): @@ -343,8 +363,8 @@ def _get_best_grid(img, side): # pick the partition with maximum covering_ratio and break the tie using #sub_images return sorted(all_grids, key=lambda x: (-x[1], x[0][0] * x[0][1]))[0][0] - if convert_to_rgb and image.mode != 'RGB': - image = image.convert('RGB') + if convert_to_rgb: + image = convert_image_mode(image, 'RGB') sides = self.get_image_size() @@ -382,7 +402,7 @@ def post_process_image_text_to_text(self, generated_outputs): The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` or `(sequence_length,)`. Returns: - `List[str]`: The decoded text. + `list[str]`: The decoded text. 
""" return self.tokenizer.batch_decode( generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index e31580ede57b..fa7a208c48ed 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -13,7 +13,7 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) -from vllm.envs import VLLM_USE_MODELSCOPE +from vllm import envs from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.transformers_utils.tokenizer_base import (TokenizerBase, @@ -168,7 +168,7 @@ def get_tokenizer( ) -> AnyTokenizer: """Gets a tokenizer for the given model name via HuggingFace or ModelScope. """ - if VLLM_USE_MODELSCOPE: + if envs.VLLM_USE_MODELSCOPE: # download model from ModelScope hub, # lazy import so that modelscope is not required for normal use. # pylint: disable=C. diff --git a/vllm/transformers_utils/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group.py index aff2d2eb1c35..8b9e4881ef88 100644 --- a/vllm/transformers_utils/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional +from typing import Optional from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig from vllm.lora.request import LoRARequest @@ -32,7 +32,7 @@ def get_max_input_len(self, return self.max_input_length def _raise_if_input_too_long(self, - encoded_tokens: List[int], + encoded_tokens: list[int], lora_request: Optional[LoRARequest] = None): input_length = len(encoded_tokens) if lora_request: @@ -48,7 +48,7 @@ def encode(self, max_length: Optional[int] = None, truncation: Optional[bool] = None, lora_request: Optional[LoRARequest] = None, - add_special_tokens: Optional[bool] = None) -> List[int]: + add_special_tokens: Optional[bool] = None) -> list[int]: tokenizer = 
self.get_lora_tokenizer(lora_request) ret = encode_tokens(tokenizer, @@ -65,7 +65,7 @@ async def encode_async( max_length: Optional[int] = None, truncation: Optional[bool] = None, lora_request: Optional[LoRARequest] = None, - add_special_tokens: Optional[bool] = None) -> List[int]: + add_special_tokens: Optional[bool] = None) -> list[int]: tokenizer = await self.get_lora_tokenizer_async(lora_request) ret = encode_tokens(tokenizer, prompt, diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 3db7a0a5c5c1..23b6f67f09df 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -1,12 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 import os -import re from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union, cast import huggingface_hub +import regex as re from huggingface_hub import HfApi, hf_hub_download from vllm.logger import init_logger @@ -28,7 +28,7 @@ @dataclass class Encoding: - input_ids: Union[List[int], List[List[int]]] + input_ids: Union[list[int], list[list[int]]] def maybe_serialize_tool_calls(request: "ChatCompletionRequest"): @@ -105,7 +105,7 @@ def validate_request_params(request: "ChatCompletionRequest"): "for Mistral tokenizers.") -def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]: +def list_local_repo_files(repo_id: str, revision: Optional[str]) -> list[str]: repo_cache = os.path.join( huggingface_hub.constants.HF_HUB_CACHE, huggingface_hub.constants.REPO_ID_SEPARATOR.join( @@ -125,7 +125,7 @@ def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]: return [] -def find_tokenizer_file(files: List[str]): +def find_tokenizer_file(files: list[str]): file_pattern = re.compile( r"^tokenizer\.model\.v.*$|^tekken\.json$|^tokenizer\.mm\.model\.v.*$") @@ -145,10 
+145,10 @@ def find_tokenizer_file(files: List[str]): def make_mistral_chat_completion_request( - messages: List["ChatCompletionMessageParam"], - tools: Optional[List[Dict[str, + messages: list["ChatCompletionMessageParam"], + tools: Optional[list[dict[str, Any]]] = None) -> "ChatCompletionRequest": - last_message = cast(Dict[str, Any], messages[-1]) + last_message = cast(dict[str, Any], messages[-1]) if last_message["role"] == "assistant": last_message["prefix"] = True @@ -156,7 +156,11 @@ def make_mistral_chat_completion_request( # # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80 for message in messages: - if message.get("role") == "assistant": + # Remove reasoning_content as unsupported by Mistral + _ = message.pop("reasoning_content", None) # type: ignore + + # Convert list text content to string + if message.get("role") in ("assistant", "tool"): content = message.get("content") if isinstance(content, list): content = "\n".join(chunk.get("text") for chunk in content) @@ -199,7 +203,7 @@ def __init__(self, tokenizer: "PublicMistralTokenizer") -> None: raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}") self._vocab = tokenizer_.vocab() - # Convert to a Dict[str, int] to match protocol, but this is a lossy + # Convert to a dict[str, int] to match protocol, but this is a lossy # conversion. 
There may be multiple token ids that decode to the same # string due to partial UTF-8 byte sequences being converted to ๏ฟฝ self._vocab_dict = { @@ -314,21 +318,21 @@ def __len__(self) -> int: def __call__( self, - text: Union[str, List[str], List[int]], + text: Union[str, list[str], list[int]], text_pair: Optional[str] = None, add_special_tokens: bool = False, truncation: bool = False, max_length: Optional[int] = None, ): - input_ids: Union[List[int], List[List[int]]] - # For List[str], original prompt text + input_ids: Union[list[int], list[list[int]]] + # For list[str], original prompt text if is_list_of(text, str): - input_ids_: List[List[int]] = [] + input_ids_: list[list[int]] = [] for p in text: each_input_ids = self.encode_one(p, truncation, max_length) input_ids_.append(each_input_ids) input_ids = input_ids_ - # For List[int], apply chat template output, already tokens. + # For list[int], apply chat template output, already tokens. elif is_list_of(text, int): input_ids = text # For str, single prompt text @@ -350,7 +354,7 @@ def encode_one( text: str, truncation: bool = False, max_length: Optional[int] = None, - ) -> List[int]: + ) -> list[int]: # Mistral Tokenizers should not add special tokens input_ids = self.encode(text) @@ -362,7 +366,7 @@ def encode(self, text: str, truncation: Optional[bool] = None, max_length: Optional[int] = None, - add_special_tokens: Optional[bool] = None) -> List[int]: + add_special_tokens: Optional[bool] = None) -> list[int]: # `encode` should only be used for prompt completion # it should never be used for chat_completion. 
# For chat completion use `apply_chat_template` @@ -374,9 +378,9 @@ def encode(self, return self.tokenizer.encode(text, bos=True, eos=False) def apply_chat_template(self, - messages: List["ChatCompletionMessageParam"], - tools: Optional[List[Dict[str, Any]]] = None, - **kwargs) -> List[int]: + messages: list["ChatCompletionMessageParam"], + tools: Optional[list[dict[str, Any]]] = None, + **kwargs) -> list[int]: request = make_mistral_chat_completion_request(messages, tools) encoded = self.mistral.encode_chat_completion(request) @@ -384,7 +388,7 @@ def apply_chat_template(self, # encode-decode to get clean prompt return encoded.tokens - def convert_tokens_to_string(self, tokens: List[str]) -> str: + def convert_tokens_to_string(self, tokens: list[str]) -> str: from mistral_common.tokens.tokenizers.base import SpecialTokens if self.is_tekken: tokens = [ @@ -417,7 +421,7 @@ def _token_to_id(t: str): # make sure certain special tokens like Tool calls are # not decoded special_tokens = {SpecialTokens.tool_calls} - regular_tokens: List[str] = [] + regular_tokens: list[str] = [] decoded_list = [] for token in tokens: @@ -442,7 +446,7 @@ def _token_to_id(t: str): # See: guided_decoding/outlines_logits_processors.py::_adapt_tokenizer # for more. 
def decode(self, - ids: Union[List[int], int], + ids: Union[list[int], int], skip_special_tokens: bool = True) -> str: assert ( skip_special_tokens @@ -454,9 +458,9 @@ def decode(self, def convert_ids_to_tokens( self, - ids: List[int], + ids: list[int], skip_special_tokens: bool = True, - ) -> List[str]: + ) -> list[str]: from mistral_common.tokens.tokenizers.base import SpecialTokens # TODO(Patrick) - potentially allow special tokens to not be skipped diff --git a/vllm/transformers_utils/utils.py b/vllm/transformers_utils/utils.py index 81eb4d9b6abc..8dff1b612fdb 100644 --- a/vllm/transformers_utils/utils.py +++ b/vllm/transformers_utils/utils.py @@ -4,7 +4,7 @@ from functools import cache from os import PathLike from pathlib import Path -from typing import List, Optional, Union +from typing import Optional, Union from vllm.envs import VLLM_MODEL_REDIRECT_PATH from vllm.logger import init_logger @@ -38,7 +38,7 @@ def modelscope_list_repo_files( repo_id: str, revision: Optional[str] = None, token: Union[str, bool, None] = None, -) -> List[str]: +) -> list[str]: """List files in a modelscope repo.""" from modelscope.hub.api import HubApi api = HubApi() diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index 67b834533b7d..90af0c63cc02 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -161,7 +161,7 @@ def _report_usage_worker(self, model_architecture: str, usage_context: UsageContext, extra_kvs: dict[str, Any]) -> None: self._report_usage_once(model_architecture, usage_context, extra_kvs) - self._report_continous_usage() + self._report_continuous_usage() def _report_usage_once(self, model_architecture: str, usage_context: UsageContext, @@ -219,7 +219,7 @@ def _report_usage_once(self, model_architecture: str, self._write_to_file(data) self._send_to_server(data) - def _report_continous_usage(self): + def _report_continuous_usage(self): """Report usage every 10 minutes. This helps us to collect more data points for uptime of vLLM usages. 
diff --git a/vllm/utils.py b/vllm/utils.py index 24535196ccde..846df7743736 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -15,10 +15,10 @@ import importlib.util import inspect import ipaddress +import json import multiprocessing import os import pickle -import re import signal import socket import subprocess @@ -33,7 +33,8 @@ import warnings import weakref from argparse import (Action, ArgumentDefaultsHelpFormatter, ArgumentParser, - ArgumentTypeError, _ArgumentGroup) + ArgumentTypeError, RawDescriptionHelpFormatter, + _ArgumentGroup) from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task from collections import UserDict, defaultdict from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable, @@ -53,6 +54,7 @@ import numpy as np import numpy.typing as npt import psutil +import regex as re import torch import torch.types import yaml @@ -76,9 +78,15 @@ logger = init_logger(__name__) +# This value is chosen to have a balance between ITL and TTFT. Note it is +# not optimized for throughput. 
+DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048 +POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 +MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 + # Exception strings for non-implemented encoder/decoder scenarios -# Reminder: Please update docs/source/features/compatibility_matrix.md +# Reminder: Please update docs/features/compatibility_matrix.md # If the feature combo become valid STR_NOT_IMPL_ENC_DEC_SWA = \ @@ -153,6 +161,7 @@ STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH" STR_XFORMERS_ATTN_VAL: str = "XFORMERS" STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" +STR_DUAL_CHUNK_FLASH_ATTN_VAL: str = "DUAL_CHUNK_FLASH_ATTN" STR_INVALID_VAL: str = "INVALID" GB_bytes = 1_000_000_000 @@ -612,6 +621,10 @@ def is_valid_ipv6_address(address: str) -> bool: def get_distributed_init_method(ip: str, port: int) -> str: + return get_tcp_uri(ip, port) + + +def get_tcp_uri(ip: str, port: int) -> str: # Brackets are not permitted in ipv4 addresses, # see https://github.com/python/cpython/issues/103848 return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}" @@ -746,16 +759,15 @@ def get_kv_cache_torch_dtype( model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype: if isinstance(cache_dtype, str): if cache_dtype == "auto": - if isinstance(model_dtype, str): + if isinstance(model_dtype, + str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE: torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype] elif isinstance(model_dtype, torch.dtype): torch_dtype = model_dtype else: raise ValueError(f"Invalid model dtype: {model_dtype}") - elif cache_dtype in ["half", "bfloat16", "float"]: + elif cache_dtype in STR_DTYPE_TO_TORCH_DTYPE: torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] - elif cache_dtype == "fp8": - torch_dtype = torch.uint8 else: raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") elif isinstance(cache_dtype, torch.dtype): @@ -992,7 +1004,7 @@ def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]: def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]): 
""" - Unlike {class}`itertools.groupby`, groups are not broken by + Unlike [`itertools.groupby`][], groups are not broken by non-contiguous data. """ groups = defaultdict[_K, list[_V]](list) @@ -1311,7 +1323,8 @@ def __call__(self, parser, namespace, values, option_string=None): "Expected 'true' or 'false'.") -class SortedHelpFormatter(ArgumentDefaultsHelpFormatter): +class SortedHelpFormatter(ArgumentDefaultsHelpFormatter, + RawDescriptionHelpFormatter): """SortedHelpFormatter that sorts arguments by their option strings.""" def _split_lines(self, text, width): @@ -1414,6 +1427,51 @@ def parse_args( # type: ignore[override] else: processed_args.append(arg) + def create_nested_dict(keys: list[str], value: str): + """Creates a nested dictionary from a list of keys and a value. + + For example, `keys = ["a", "b", "c"]` and `value = 1` will create: + `{"a": {"b": {"c": 1}}}` + """ + nested_dict: Any = value + for key in reversed(keys): + nested_dict = {key: nested_dict} + return nested_dict + + def recursive_dict_update(original: dict, update: dict): + """Recursively updates a dictionary with another dictionary.""" + for k, v in update.items(): + if isinstance(v, dict) and isinstance(original.get(k), dict): + recursive_dict_update(original[k], v) + else: + original[k] = v + + delete = set() + dict_args: dict[str, dict] = defaultdict(dict) + for i, processed_arg in enumerate(processed_args): + if processed_arg.startswith("--") and "." in processed_arg: + if "=" in processed_arg: + processed_arg, value = processed_arg.split("=", 1) + if "." not in processed_arg: + # False positive, . 
was only in the value + continue + else: + value = processed_args[i + 1] + delete.add(i + 1) + key, *keys = processed_arg.split(".") + # Merge all values with the same key into a single dict + arg_dict = create_nested_dict(keys, value) + recursive_dict_update(dict_args[key], arg_dict) + delete.add(i) + # Filter out the dict args we set to None + processed_args = [ + a for i, a in enumerate(processed_args) if i not in delete + ] + # Add the dict args back as if they were originally passed as JSON + for dict_arg, dict_value in dict_args.items(): + processed_args.append(dict_arg) + processed_args.append(json.dumps(dict_value)) + return super().parse_args(processed_args, namespace) def check_port(self, value): @@ -1820,6 +1878,14 @@ def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor: return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor) +def is_in_doc_build() -> bool: + try: + from sphinx.ext.autodoc.mock import _MockModule + return isinstance(zmq, _MockModule) + except ModuleNotFoundError: + return False + + def import_from_path(module_name: str, file_path: Union[str, os.PathLike]): """ Import a Python file according to its file path. @@ -1859,11 +1925,11 @@ class _PlaceholderBase: Disallows downstream usage of placeholder modules. We need to explicitly override each dunder method because - {meth}`__getattr__` is not called when they are accessed. + [`__getattr__`][vllm.utils._PlaceholderBase.__getattr__] + is not called when they are accessed. - :::{seealso} - [Special method lookup](https://docs.python.org/3/reference/datamodel.html#special-lookup) - ::: + Info: + [Special method lookup](https://docs.python.org/3/reference/datamodel.html#special-lookup) """ def __getattr__(self, key: str) -> Never: @@ -2337,6 +2403,24 @@ def split_zmq_path(path: str) -> Tuple[str, str, str]: return scheme, host, port +def make_zmq_path(scheme: str, host: str, port: Optional[int] = None) -> str: + """Make a ZMQ path from its parts. 
+ + Args: + scheme: The ZMQ transport scheme (e.g. tcp, ipc, inproc). + host: The host - can be an IPv4 address, IPv6 address, or hostname. + port: Optional port number, only used for TCP sockets. + + Returns: + A properly formatted ZMQ path string. + """ + if not port: + return f"{scheme}://{host}" + if is_valid_ipv6_address(host): + return f"{scheme}://[{host}]:{port}" + return f"{scheme}://{host}:{port}" + + # Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L783 # noqa: E501 def make_zmq_socket( ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined] @@ -2447,7 +2531,7 @@ def _maybe_force_spawn(): logger.warning( "We must use the `spawn` multiprocessing start method. " "Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. " - "See https://docs.vllm.ai/en/latest/getting_started/" + "See https://docs.vllm.ai/en/latest/usage/" "troubleshooting.html#python-multiprocessing " "for more information. Reason: %s", reason) os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" @@ -2712,14 +2796,17 @@ def wrapper(*args, **kwargs): # Only relevant for models using ALiBi (e.g, MPT) def check_use_alibi(model_config: ModelConfig) -> bool: - return (getattr(model_config.hf_text_config, "alibi", False) # Falcon + cfg = model_config.hf_text_config + return (getattr(cfg, "alibi", False) # Falcon or ("BloomForCausalLM" in getattr(model_config.hf_config, "architectures", [])) # Bloom - or getattr(model_config.hf_text_config, "position_encoding_type", - "") == "alibi" # codellm_1b_alibi - or - (hasattr(model_config.hf_text_config, "attn_config") # MPT - and model_config.hf_text_config.attn_config.get("alibi", False))) + or getattr(cfg, "position_encoding_type", "") == + "alibi" # codellm_1b_alibi + or (hasattr(cfg, "attn_config") # MPT + and ((isinstance(cfg.attn_config, dict) + and cfg.attn_config.get("alibi", False)) or + (not isinstance(cfg.attn_config, dict) + and getattr(cfg.attn_config, "alibi", False))))) def 
sha256(input) -> int: diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 605dff3749fb..9ed3dec7f269 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -19,6 +19,8 @@ from vllm.platforms import current_platform from vllm.utils import cdiv from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.block_table import BlockTable if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -167,7 +169,7 @@ def make_local_attention_virtual_batches( query_start_loc_np: np.ndarray, seq_lens_np: np.ndarray, block_table: torch.Tensor, - page_size: int = 0, + block_size: int = 0, ) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]: q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1] actual_batch_size = seq_lens_np.shape[0] @@ -238,14 +240,14 @@ def make_local_attention_virtual_batches( # For the example the local attention blocks start at: # _b0_ _____b1_____ _b2_ # k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8] - block_starts = k_seqstarts_absolute // page_size - assert attn_chunk_size % page_size == 0, \ + block_starts = k_seqstarts_absolute // block_size + assert attn_chunk_size % block_size == 0, \ f"attn_chunk_size {attn_chunk_size} is not " \ - f"divisible by page_size {page_size}" - pages_per_local_batch = attn_chunk_size // page_size + f"divisible by block_size {block_size}" + pages_per_local_batch = attn_chunk_size // block_size # Create a block_table for the local attention blocks - # For out example if we have a block-table like (assuming page_size=2): + # For out example if we have a block-table like (assuming block_size=2): # block_table = [ # [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0 # [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1 @@ -289,7 +291,8 @@ def _get_sliding_window_configs( class FlashAttentionMetadataBuilder: - def 
__init__(self, runner: "GPUModelRunner"): + def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, + block_table: BlockTable): model_config = runner.model_config compilation_config = runner.vllm_config.compilation_config @@ -299,7 +302,9 @@ def __init__(self, runner: "GPUModelRunner"): self.num_heads_kv = model_config.get_num_kv_heads( runner.parallel_config) self.headdim = model_config.get_head_size() - self.page_size = self.runner.block_size + self.block_size = kv_cache_spec.block_size + self.kv_cache_spec = kv_cache_spec + self.block_table = block_table if get_flash_attn_version() == 3: self.aot_schedule = not compilation_config.full_cuda_graph @@ -323,9 +328,17 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, max_seq_len = self.runner.seq_lens_np[:num_reqs].max() query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens - block_table = ( - self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) - slot_mapping = self.runner.slot_mapping[:num_actual_tokens] + block_table = self.block_table + block_table_tensor = block_table.get_device_tensor()[:num_reqs] + + block_table.slot_mapping[:num_actual_tokens].copy_( + block_table.slot_mapping_cpu[:num_actual_tokens], + non_blocking=True) + # Fill unused with -1. Needed for reshape_and_cache in full cuda graph + # mode. 
+ block_table.slot_mapping[num_actual_tokens:].fill_(-1) + + slot_mapping = block_table.slot_mapping[:num_actual_tokens] if self.aot_sliding_window is None: self.aot_sliding_window = (-1, -1) @@ -354,7 +367,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, num_heads_q=self.num_heads_q, num_heads_kv=self.num_heads_kv, headdim=self.headdim, - page_size=self.page_size, + page_size=self.block_size, cu_seqlens_q=cu_query_lens, causal=causal, window_size=self.aot_sliding_window, @@ -365,12 +378,12 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, local_attn_metadata = None if self.runner.attention_chunk_size is not None: seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \ - virt_block_table = make_local_attention_virtual_batches( + virt_block_table_tensor = make_local_attention_virtual_batches( self.runner.attention_chunk_size, self.runner.query_start_loc_np[:num_reqs + 1], self.runner.seq_lens_np[:num_reqs], - block_table, - self.runner.block_size, + block_table_tensor, + self.block_size, ) local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to( self.runner.device, non_blocking=True) @@ -389,7 +402,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata( local_query_start_loc=local_query_start_loc, local_seqused_k=local_seqused_k, - local_block_table=virt_block_table, + local_block_table=virt_block_table_tensor, local_max_query_len=local_max_query_len, local_max_seq_len=local_max_seq_len, local_scheduler_metadata=local_scheduler_metadata, @@ -440,7 +453,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, query_start_loc=query_start_loc, max_seq_len=max_seq_len, seq_lens=seq_lens, - block_table=block_table, + block_table=block_table_tensor, slot_mapping=slot_mapping, use_cascade=use_cascade, common_prefix_len=common_prefix_len, diff --git a/vllm/v1/attention/backends/flashinfer.py 
b/vllm/v1/attention/backends/flashinfer.py index 0852e15f9c19..1c4f7f62fa67 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -14,11 +14,12 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionType) from vllm.attention.layer import Attention -from vllm.config import (VllmConfig, get_current_vllm_config, - get_layers_from_vllm_config) +from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.v1.attention.backends.flash_attn import use_cascade_attention from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.block_table import BlockTable if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -202,7 +203,8 @@ def __post_init__(self): class FlashInferMetadataBuilder: - def __init__(self, runner: GPUModelRunner): + def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec, + block_table: BlockTable): self.runner = runner self._workspace_buffer = None self._prefill_wrapper = None # Wrapper for prefill/append @@ -212,7 +214,9 @@ def __init__(self, runner: GPUModelRunner): # Global hyperparameters shared by all attention layers self.global_hyperparameters: Optional[PerLayerParameters] = None - self.vllm_config = get_current_vllm_config() + self.vllm_config = runner.vllm_config + self.kv_cache_spec = kv_cache_spec + self.block_table = block_table def reorder_batch(self, input_batch: InputBatch, scheduler_output: SchedulerOutput) -> bool: @@ -400,13 +404,12 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, assert self._num_decodes + self._num_prefills == num_reqs assert (self._num_decode_tokens + self._num_prefill_tokens == num_actual_tokens) - page_size = self.runner.block_size + page_size = self.kv_cache_spec.block_size device = self.runner.device qo_indptr = 
common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens - block_table = ( - self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) - slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( + block_table_tensor = self.block_table.get_device_tensor()[:num_reqs] + slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to( self.runner.device, non_blocking=True).long() block_table_bounds = (seq_lens + page_size - 1) // page_size @@ -422,12 +425,13 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks], dtype=torch.int32, device=device) - shared_kv_page_indices = block_table[0, :num_common_kv_blocks] + shared_kv_page_indices = block_table_tensor[ + 0, :num_common_kv_blocks] shared_kv_last_page_len = torch.tensor([page_size], dtype=torch.int32, device=device) # Remove the blocks of the shared prefix from all requests. - block_table = block_table[:, num_common_kv_blocks:] + block_table_tensor = block_table_tensor[:, num_common_kv_blocks:] block_table_bounds -= num_common_kv_blocks else: shared_qo_indptr = None @@ -435,11 +439,11 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, shared_kv_page_indices = None shared_kv_last_page_len = None - mask = (torch.arange(block_table.size(1), - dtype=block_table.dtype, - device=block_table.device).unsqueeze(0) + mask = (torch.arange(block_table_tensor.size(1), + dtype=block_table_tensor.dtype, + device=block_table_tensor.device).unsqueeze(0) < block_table_bounds.unsqueeze(1)) - paged_kv_indices = block_table[mask] + paged_kv_indices = block_table_tensor[mask] paged_kv_indptr = torch.cat([ torch.zeros(1, @@ -459,10 +463,10 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, paged_kv_indices=paged_kv_indices, paged_kv_last_page_len=paged_kv_last_page_len, num_qo_heads=self.runner.num_query_heads, - num_kv_heads=self.runner.num_kv_heads, - 
head_dim=self.runner.head_size, + num_kv_heads=self.kv_cache_spec.num_kv_heads, + head_dim=self.kv_cache_spec.head_size, page_size=page_size, - data_type=self.runner.kv_cache_dtype, + data_type=self.kv_cache_spec.dtype, q_data_type=self.runner.dtype, slot_mapping=slot_mapping, num_decodes=self._num_decodes, @@ -481,7 +485,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, return attn_metadata def use_cascade_attention(self, *args, **kwargs) -> bool: - if self.runner.kv_cache_dtype != self.runner.model_config.dtype: + if self.kv_cache_spec.dtype != self.runner.model_config.dtype: # TODO: The cascade wrapper currently does not support setting # kv cache dtype to something different from query dtype. return False diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 0d18a5639c2a..83e181116577 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -204,10 +204,11 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearBase, UnquantizedLinearMethod) -from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.platforms import current_platform from vllm.utils import cdiv, round_down from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.block_table import BlockTable try: from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -269,9 +270,6 @@ class ChunkedContextMetadata: max_seq_lens: list[int] workspace: torch.Tensor - # Input positions for rotrary embeddings since for MLA the rotary - # position embeddings are applied inside the attention backend - input_positions: torch.Tensor block_table: torch.Tensor query_start_loc: torch.Tensor max_query_len: int @@ -280,9 +278,6 @@ class ChunkedContextMetadata: @dataclass class MLACommonDecodeMetadata: - # Input positions for rotrary embeddings since for MLA the rotary - 
# position embeddings are applied inside the attention backend - input_positions: torch.Tensor block_table: torch.Tensor seq_lens: torch.Tensor @@ -341,6 +336,8 @@ class MLACommonMetadataBuilder(Generic[M]): def __init__(self, runner: "GPUModelRunner", + kv_cache_spec: AttentionSpec, + block_table: BlockTable, metadata_cls: Optional[type[M]] = None): self.metadata_cls = metadata_cls \ if metadata_cls is not None else MLACommonMetadata @@ -353,10 +350,11 @@ def __init__(self, runner.parallel_config) self.mla_dims = get_mla_dims(model_config) self.aot_schedule = is_vllm_fa and (get_flash_attn_version() == 3) + self.kv_cache_spec = kv_cache_spec # Dont try to access the runner on AMD if self.aot_schedule: - self.page_size = self.runner.block_size + self.page_size = self.kv_cache_spec.block_size if self.chunked_prefill_enabled: self.chunked_prefill_workspace_size = min( @@ -382,6 +380,7 @@ def __init__(self, dtype=model_config.dtype, device=runner.device, ) + self.block_table = block_table def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: @@ -443,11 +442,10 @@ def reorder_batch(self, input_batch: "InputBatch", return modified_batch - def _build_decode(self, input_positions: torch.Tensor, - block_table: torch.Tensor, seq_lens: torch.Tensor): + def _build_decode(self, block_table_tensor: torch.Tensor, + seq_lens: torch.Tensor): return MLACommonDecodeMetadata( - input_positions=input_positions, - block_table=block_table, + block_table=block_table_tensor, seq_lens=seq_lens, ) @@ -460,11 +458,9 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, # function. We should avoid GPU -> CPU sync as much as possible because # it blocks on all previous kernels. 
device = self.runner.device - block_table = ( - self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) - slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( - device, non_blocking=True).long() - input_positions = self.runner.positions_cpu[:num_actual_tokens].to( + block_table = self.block_table + block_table_tensor = block_table.get_device_tensor()[:num_reqs] + slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].to( device, non_blocking=True).long() query_start_loc = common_attn_metadata.query_start_loc @@ -473,7 +469,6 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, prefill_metadata = None if self._num_prefills > 0: reqs_start = self._num_decodes # prefill_start - tokens_start = self._num_decode_tokens context_lens_cpu = self.runner.input_batch.\ num_computed_tokens_cpu_tensor[reqs_start:num_reqs] @@ -496,11 +491,12 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, max_context_chunk = (self.chunked_prefill_workspace_size // num_prefills_with_context_cpu) - # align max_context_chunk to page_size by rounding down, - # currently the `gather_cache` kernel cannot handle - # `context_chunk_starts` that are not aligned to page_size - max_context_chunk = round_down(max_context_chunk, - self.page_size) + if self.aot_schedule: + # align max_context_chunk to page_size by rounding down, + # currently the `gather_cache` kernel cannot handle + # `context_chunk_starts` that are not aligned to page_size + max_context_chunk = round_down(max_context_chunk, + self.page_size) assert max_context_chunk > 0 num_chunks = cdiv(max_context_len_cpu, max_context_chunk) @@ -541,8 +537,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, self.chunked_prefill_workspace_size prefill_metadata = MLACommonPrefillMetadata( - input_positions=input_positions[tokens_start:], - block_table=block_table[reqs_start:, ...], + block_table=block_table_tensor[reqs_start:, ...], 
query_start_loc=prefill_query_start_loc, max_query_len=max_query_len, chunked_context=chunked_context_metadata, @@ -551,8 +546,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, decode_metadata = None if self._num_decodes > 0: decode_metadata = self._build_decode( - input_positions=input_positions[:self._num_decode_tokens], - block_table=block_table[:self._num_decodes, ...], + block_table_tensor=block_table_tensor[:self._num_decodes, ...], seq_lens=seq_lens[:self._num_decodes], ) @@ -598,7 +592,6 @@ def __init__( qk_rope_head_dim: int, qk_head_dim: int, v_head_dim: int, - rotary_emb: RotaryEmbedding, kv_b_proj: ColumnParallelLinear, ) -> None: self.num_heads = num_heads @@ -613,15 +606,6 @@ def __init__( self.qk_rope_head_dim = qk_rope_head_dim self.qk_head_dim = qk_head_dim self.v_head_dim = v_head_dim - - # Hack for V1 for now to avoid torch library overhead (since we are - # already inside an attention custom op), pull out the forward - # method from the rotary embedding and call it directly - # TODO(lucas): we should probably find a cleaner way to do this - self.rotary_emb = rotary_emb.forward_native - if current_platform.is_cuda(): - self.rotary_emb = rotary_emb.forward_cuda - self.kv_b_proj = kv_b_proj self.vllm_flash_attn_version = get_flash_attn_version() @@ -881,8 +865,10 @@ def forward( assert output is not None, "Output tensor must be provided." if attn_metadata is None: - # Profiling run. - return output + # The zero fill is required when used with DP + EP + # to ensure all ranks within a DP group compute the + # same expert outputs. + return output.fill_(0) num_actual_toks = attn_metadata.num_actual_tokens @@ -893,9 +879,6 @@ def forward( k_c_normed = k_c_normed[:num_actual_toks, ...] k_pe = k_pe[:num_actual_toks, ...] 
- # Restore head dim (for rotary embedding) - k_pe = k_pe.unsqueeze(1) - assert attn_metadata.num_decodes is not None and \ attn_metadata.num_prefills is not None and \ attn_metadata.num_decode_tokens is not None @@ -904,35 +887,12 @@ def forward( has_prefill = attn_metadata.num_prefills > 0 num_decode_tokens = attn_metadata.num_decode_tokens - q = q.view(-1, self.num_heads, self.qk_head_dim) decode_q = q[:num_decode_tokens] - decode_k_pe = k_pe[:num_decode_tokens] prefill_q = q[num_decode_tokens:] prefill_k_pe = k_pe[num_decode_tokens:] prefill_k_c_normed = k_c_normed[num_decode_tokens:] - if has_decode: - assert attn_metadata.decode is not None - decode_q_nope, decode_q_pe = decode_q.split( - [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - # Convert from (B, N, P) to (N, B, P) - decode_q_nope = decode_q_nope.transpose(0, 1) - # Multiply (N, B, P) x (N, P, L) -> (N, B, L) - decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) - # Convert from (N, B, L) to (B, N, L) - decode_ql_nope = decode_ql_nope.transpose(0, 1) - decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( - attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe) - - if has_prefill: - assert attn_metadata.prefill is not None - prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] - - prefill_q_pe[...], prefill_k_pe[...] 
= self.rotary_emb( - attn_metadata.prefill.input_positions, prefill_q_pe, - prefill_k_pe) - # write the latent and rope to kv cache if kv_cache.numel() > 0: ops.concat_and_cache_mla( @@ -950,6 +910,16 @@ def forward( attn_metadata) if has_decode: + assert attn_metadata.decode is not None + decode_q_nope, decode_q_pe = decode_q.split( + [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + # Convert from (B, N, P) to (N, B, P) + decode_q_nope = decode_q_nope.transpose(0, 1) + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + decode_ql_nope = decode_ql_nope.transpose(0, 1) + output[:num_decode_tokens] = self._forward_decode( decode_ql_nope, decode_q_pe, kv_cache, attn_metadata) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index f18c9c8b6462..e6594c6b6fa8 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -16,6 +16,8 @@ MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder) +from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.block_table import BlockTable logger = init_logger(__name__) @@ -52,14 +54,14 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): - def __init__(self, runner): - super().__init__(runner) + def __init__(self, runner, kv_cache_spec: AttentionSpec, + block_table: BlockTable): + super().__init__(runner, kv_cache_spec, block_table) self.num_q_heads = self.runner.model_config.get_num_attention_heads( self.runner.parallel_config) - def _build_decode(self, input_positions: torch.Tensor, - block_table: torch.Tensor, + def _build_decode(self, block_table_tensor: torch.Tensor, seq_lens: torch.Tensor) -> FlashMLADecodeMetadata: tile_scheduler_metadata, num_splits = \ get_mla_metadata( @@ -69,8 +71,7 @@ def _build_decode(self, 
input_positions: torch.Tensor, ) return FlashMLADecodeMetadata( - input_positions=input_positions, - block_table=block_table, + block_table=block_table_tensor, seq_lens=seq_lens, tile_scheduler_metadata=tile_scheduler_metadata, num_splits=num_splits, diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py new file mode 100644 index 000000000000..31980e94a037 --- /dev/null +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -0,0 +1,221 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import Any, Optional + +import torch + +import vllm.envs as envs +from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd +# yapf conflicts with isort for this docstring +# yapf: disable +from vllm.v1.attention.backends.mla.common import (MLACommonBackend, + MLACommonDecodeMetadata, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder) +from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.block_table import BlockTable + +# yapf: enable + + +def is_aiter_mla_enabled() -> bool: + return envs.VLLM_ROCM_USE_AITER \ + and envs.VLLM_ROCM_USE_AITER_MLA + + +class AiterMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "ROCM_AITER_MLA_VLLM_V1" + + @staticmethod + def get_impl_cls() -> type["AiterMLAImpl"]: + return AiterMLAImpl + + @staticmethod + def get_metadata_cls() -> type["AiterMLAMetadata"]: + return AiterMLAMetadata + + @staticmethod + def get_builder_cls() -> type["AiterMLAMetadataBuilder"]: + return AiterMLAMetadataBuilder + + +@dataclass +class AiterMLADecodeMetadata(MLACommonDecodeMetadata): + # The indptr of the paged kv cache, shape: [batch_size + 1] + paged_kv_indptr: Optional[torch.Tensor] = None + # The page indices of the paged kv cache + paged_kv_indices: Optional[torch.Tensor] = None + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + 
paged_kv_last_page_len: Optional[torch.Tensor] = None + # The query indptr, shape : [num_decode + 1] + qo_indptr: Optional[torch.Tensor] = None + + +class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): + pass + + +class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): + + def __init__(self, runner, kv_cache_spec: AttentionSpec, + block_table: BlockTable): + super().__init__(runner, kv_cache_spec, block_table) + max_model_len = self.runner.model_config.max_model_len + assert max_model_len == 32768,\ + "AITER MLA requires max_model_len=32768" + assert self.kv_cache_spec.block_size == 1, "AITER MLA" \ + "only supports block size 1." + + def _get_paged_kv_tensors( + self, block_table: torch.Tensor, + seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]: + page_size = self.kv_cache_spec.block_size + block_table_bounds = (seq_lens + page_size - 1) // page_size + device = self.runner.device + + mask = (torch.arange(block_table.size(1), + dtype=block_table.dtype, + device=device).unsqueeze(0) + < block_table_bounds.unsqueeze(1)) + paged_kv_indices = block_table[mask] + + paged_kv_indptr = torch.cat([ + torch.zeros(1, dtype=block_table_bounds.dtype, device=device), + block_table_bounds.cumsum(dim=0, dtype=torch.int32) + ]) + + paged_kv_last_page_len = seq_lens % page_size + paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, + page_size, paged_kv_last_page_len) + qo_indptr = torch.arange(0, + self._num_decodes + 1, + step=1, + dtype=torch.int32, + device=device) + + return ( + paged_kv_indices, + paged_kv_indptr, + paged_kv_last_page_len, + qo_indptr, + ) + + def _build_decode(self, block_table_tensor: torch.Tensor, + seq_lens: torch.Tensor) -> AiterMLADecodeMetadata: + + ( + paged_kv_indices, + paged_kv_indptr, + paged_last_page_len, + qo_indptr, + ) = self._get_paged_kv_tensors(block_table_tensor, seq_lens) + + attn_metadata = AiterMLADecodeMetadata( + block_table=block_table_tensor, + seq_lens=seq_lens, + 
paged_kv_indptr=paged_kv_indptr, + paged_kv_indices=paged_kv_indices, + paged_kv_last_page_len=paged_last_page_len, + qo_indptr=qo_indptr) + + return attn_metadata + + +class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + **mla_args) + assert (num_heads == 16 or num_heads == 128), ( + f"Aiter MLA only supports 16 or 128 number of heads.\n" + f"Provided {num_heads} number of heads.\n" + "Try adjusting tensor_parallel_size value.") + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "Aiter MLA does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + from aiter import flash_attn_varlen_func + self.flash_attn_varlen_func = flash_attn_varlen_func + + def _flash_attn_varlen_diff_headdims(self, + q, + k, + v, + return_softmax_lse=False, + softmax_scale=None, + **kwargs): + output = self.flash_attn_varlen_func( + q=q, + k=k, + v=v, + softmax_scale=softmax_scale, + return_lse=return_softmax_lse, + **kwargs, + ) + + return output + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: AiterMLAMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + assert attn_metadata.decode is not None + + B = q_nope.shape[0] + + q = torch.cat([q_nope, q_pe], dim=-1) + o = torch.zeros(B, + self.num_heads, + self.kv_lora_rank, + 
dtype=q.dtype, + device=q.device) + + kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) + + if self.num_heads == 16: + # AITER MLA decode kernel only supports + # max_seqlen_q=1 when using 16 heads. + max_seqlen_qo = 1 + else: + # AITER MLA decode Kernel handles arbitrary + # max_seqlen_q values when using 128 heads. + assert attn_metadata.prefill is not None + max_seqlen_qo = attn_metadata.prefill.max_query_len + + aiter_mla_decode_fwd(q, kv_buffer, o, self.scale, + attn_metadata.decode.qo_indptr, max_seqlen_qo, + attn_metadata.decode.paged_kv_indptr, + attn_metadata.decode.paged_kv_indices, + attn_metadata.decode.paged_kv_last_page_len) + + return self._v_up_proj(o) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index bb700c8e2e7a..4000f93984d3 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -1,20 +1,37 @@ # SPDX-License-Identifier: Apache-2.0 """Attention layer with PagedAttention and Triton prefix prefill.""" -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional import torch from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) +from vllm.attention.ops.chunked_prefill_paged_decode import ( + chunked_prefill_paged_decode) +from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.logger import init_logger +from vllm.platforms import current_platform from vllm.v1.attention.backends.flash_attn import ( FlashAttentionMetadata, FlashAttentionMetadataBuilder) +from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.block_table import BlockTable + +if TYPE_CHECKING: + from vllm.v1.worker.gpu_model_runner import GPUModelRunner logger = init_logger(__name__) +class TritonAttentionMetadataBuilder(FlashAttentionMetadataBuilder): + + def 
__init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, + block_table: BlockTable): + super().__init__(runner, kv_cache_spec, block_table) + self.aot_schedule = False + + class TritonAttentionBackend(AttentionBackend): accept_output_buffer: bool = True @@ -51,8 +68,8 @@ def use_cascade_attention(*args, **kwargs) -> bool: return False @staticmethod - def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]: - return FlashAttentionMetadataBuilder + def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]: + return TritonAttentionMetadataBuilder class TritonAttentionImpl(AttentionImpl): @@ -108,6 +125,8 @@ def __init__( "are not implemented for " "TritonAttentionImpl") + self.fp8_dtype = current_platform.fp8_dtype() + def forward( self, layer: torch.nn.Module, @@ -146,30 +165,54 @@ def forward( # Whenever making a change in this method, please benchmark the # performance to make sure it does not introduce any overhead. + num_queries_per_kv = query.shape[1] // key.shape[1] + use_prefill_decode_attn = (num_queries_per_kv & + (num_queries_per_kv - 1)) != 0 + num_actual_tokens = attn_metadata.num_actual_tokens - key_cache, value_cache = kv_cache.unbind(0) - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) + if use_prefill_decode_attn: + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + # Reshape the input keys and values and store them in the cache. 
+ PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + else: + key_cache, value_cache = kv_cache.unbind(0) + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) if self.kv_cache_dtype.startswith("fp8"): - key_cache = key_cache.view(torch.float8_e4m3fn) - value_cache = value_cache.view(torch.float8_e4m3fn) + key_cache = key_cache.view(self.fp8_dtype) + value_cache = value_cache.view(self.fp8_dtype) num_tokens, num_heads, head_size = query.shape assert layer._q_scale == 1.0, \ "A non 1.0 q_scale is not currently supported." - query, _ = ops.scaled_fp8_quant( - query.reshape( - (num_tokens, num_heads * head_size)).contiguous(), - layer._q_scale) + if not current_platform.is_rocm(): + # Skip Q quantization on ROCm, since dequantizing back to + # f32 in the attention kernel is not supported. 
+ query, _ = ops.scaled_fp8_quant( + query.reshape( + (num_tokens, num_heads * head_size)).contiguous(), + layer._q_scale) query = query.reshape((num_tokens, num_heads, head_size)) use_local_attn = \ @@ -190,26 +233,47 @@ def forward( max_seqlen_k = attn_metadata.max_seq_len block_table = attn_metadata.block_table - descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) - - unified_attention( - q=query[:num_actual_tokens], - k=key_cache, - v=value_cache, - out=output[:num_actual_tokens], - cu_seqlens_q=cu_seqlens_q, - max_seqlen_q=max_seqlen_q, - seqused_k=seqused_k, - max_seqlen_k=max_seqlen_k, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - window_size=self.sliding_window, - block_table=block_table, - softcap=self.logits_soft_cap, - q_descale=None, # Not supported - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - ) + if use_prefill_decode_attn: + # Compute attention and update output up to `num_actual_tokens`. 
+ chunked_prefill_paged_decode(query=query[:num_actual_tokens], + key=key[:num_actual_tokens], + value=value[:num_actual_tokens], + output=output[:num_actual_tokens], + kv_cache_dtype=self.kv_cache_dtype, + key_cache=key_cache, + value_cache=value_cache, + block_table=block_table, + query_start_loc=cu_seqlens_q, + seq_lens=seqused_k, + max_seq_len=max_seqlen_k, + max_query_len=max_seqlen_q, + k_scale=layer._k_scale, + v_scale=layer._v_scale, + alibi_slopes=self.alibi_slopes, + sliding_window=self.sliding_window[0], + sm_scale=self.scale) + + else: + descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) + + unified_attention( + q=query[:num_actual_tokens], + k=key_cache, + v=value_cache, + out=output[:num_actual_tokens], + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + seqused_k=seqused_k, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=block_table, + softcap=self.logits_soft_cap, + q_descale=None, # Not supported + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) return output diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 9e172b6bdb00..0f6098d2b400 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -1,17 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 from collections import defaultdict -from collections.abc import Iterable from dataclasses import dataclass from typing import Optional from vllm.distributed.kv_events import KVCacheEvent from vllm.logger import init_logger -from vllm.utils import cdiv, sha256 +from vllm.utils import sha256 from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, hash_request_tokens) -from vllm.v1.core.specialized_manager import get_specialized_manager +from vllm.v1.core.single_type_kv_cache_manager import ( + 
get_manager_for_kv_cache_spec) from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request, RequestStatus @@ -32,9 +32,22 @@ def create_empty(cls) -> "KVCacheBlocks": """Creates a new KVCacheBlocks instance with no blocks.""" return cls([]) - def get_block_ids(self) -> list[int]: - """Converts the KVCacheBlocks instance to a list of block IDs.""" - return [block.block_id for block in self.blocks] + def get_block_ids(self) -> list[list[int]]: + """ + Converts the KVCacheBlocks instance to block_ids. + + Returns: + list[list[int]]: A two-level list where + * the outer list corresponds to KV cache groups (only 1 group now) + * each inner list contains the block_ids of the blocks in that group + """ + return [[block.block_id for block in self.blocks]] + + def get_unhashed_block_ids(self) -> list[int]: + """Get block_ids of unhashed blocks from KVCacheBlocks instance.""" + return [ + block.block_id for block in self.blocks if block.block_hash is None + ] class KVCacheManager: @@ -56,7 +69,6 @@ def __init__( self.block_size = kv_cache_spec.block_size self.num_gpu_blocks = kv_cache_config.num_blocks self.max_model_len = max_model_len - self.max_num_blocks_per_req = cdiv(max_model_len, self.block_size) self.enable_caching = enable_caching self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash @@ -68,30 +80,20 @@ def __init__( self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching, enable_kv_cache_events) - self.specialized_manager = get_specialized_manager( + self.single_type_manager = get_manager_for_kv_cache_spec( kv_cache_spec=kv_cache_spec, block_pool=self.block_pool, use_eagle=self.use_eagle, + num_kv_cache_groups=1, + caching_hash_fn=self.caching_hash_fn, ) - # Mapping from request ID to blocks to track the blocks allocated - # for each request, so that we can free the blocks when the request - # is finished. 
- self.req_to_blocks: defaultdict[str, - list[KVCacheBlock]] = defaultdict(list) - # Mapping from request ID to kv block hashes. # This is to avoid recomputing the block hashes for each call of # `get_computed_blocks` or `allocate_slots`. self.req_to_block_hashes: defaultdict[ str, list[BlockHashType]] = defaultdict(list) - # {req_id: The number of cached blocks for this given request} - # This is used to track the number of cached blocks for each request. - # This is only used to track the RUNNING requests, we do not track the - # data for reempted ones. - self.num_cached_block: dict[str, int] = {} - @property def usage(self) -> float: """Get the KV cache usage. @@ -126,8 +128,10 @@ def get_computed_blocks(self, - A list of blocks that are computed for the request. - The number of computed tokens. """ - if not self.enable_caching: - # Prefix caching is disabled. + # Prefix caching is disabled or + # When the request requires prompt logprobs, we skip prefix caching. + if (not self.enable_caching + or request.sampling_params.prompt_logprobs is not None): return KVCacheBlocks.create_empty(), 0 # The block hashes for the request may already be computed @@ -141,62 +145,56 @@ def get_computed_blocks(self, if self.log_stats: assert self.prefix_cache_stats is not None self.prefix_cache_stats.requests += 1 - # When the request requires prompt logprobs, we skip prefix caching. - if request.sampling_params.prompt_logprobs is not None: - return KVCacheBlocks.create_empty(), 0 - - if len(block_hashes) * self.block_size == request.num_tokens: - # When prompt length is divisible by the block size and all - # blocks are cached, we need to recompute the last token. This - # have to be achieved by re-computing an entire block because - # allocate_slots() assumes num_computed_tokens is always a - # multiple of the block size. 
To achieve this, remove the last - # block hash from the block_hashes for find_longest_cache_hit - # This limitation can potentially be removed in the future to - # slightly improve the performance. - last_block_hash = block_hashes.pop() - else: - last_block_hash = None - computed_blocks = ( - self.specialized_manager.find_longest_cache_hit(block_hashes)) - - if self.log_stats: - assert self.prefix_cache_stats is not None - self.prefix_cache_stats.queries += len(block_hashes) - self.prefix_cache_stats.hits += len(computed_blocks) - - if last_block_hash is not None: - # Add back the last block hash if it was removed. - # NOTE: Because block_hashes is cached in req_to_block_hashes, - # we shouldn't modify it directly. - block_hashes.append(last_block_hash) + # NOTE: When all tokens hit the cache, we must recompute the last token + # to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1. + # This can trigger recomputation of an entire block, rather than just + # the single last token, because allocate_slots() requires + # num_computed_tokens to be block-size aligned. Removing this limitation + # could slightly improve performance in the future. + max_cache_hit_length = request.num_tokens - 1 + computed_blocks = self.single_type_manager.find_longest_cache_hit( + block_hashes, max_cache_hit_length) # NOTE(woosuk): Since incomplete blocks are not eligible for # sharing, `num_computed_tokens` is always a multiple of # `block_size`. 
num_computed_tokens = len(computed_blocks) * self.block_size + + if self.log_stats: + assert self.prefix_cache_stats is not None + self.prefix_cache_stats.queries += request.num_tokens + self.prefix_cache_stats.hits += num_computed_tokens + return KVCacheBlocks(computed_blocks), num_computed_tokens def allocate_slots( self, request: Request, - num_tokens: int, + num_new_tokens: int, + num_new_computed_tokens: int = 0, new_computed_blocks: Optional[KVCacheBlocks] = None, + num_draft_tokens: int = 0, num_lookahead_tokens: int = 0, + delay_cache_blocks: bool = False, ) -> Optional[KVCacheBlocks]: """Add slots for a request with new tokens to append. Args: request: The request to allocate slots. - num_tokens: The number of tokens to allocate, including external + num_new_tokens: The number of tokens to allocate, including external tokens. Note that this does not include tokens that have already been computed locally (i.e. new_computed_blocks). - new_computed_blocks: The new computed blocks just hitting the - prefix caching. + num_new_computed_tokens: The number of new computed tokens just + hitting the prefix caching, excluding external tokens. + new_computed_blocks: The cached blocks for the above new computed + tokens. num_lookahead_tokens: The number of speculative tokens to allocate. This is used by spec decode proposers with kv-cache such as eagle. + delay_cache_blocks: Whether to skip caching the blocks. This is + used by P/D when allocating blocks used in a KV transfer + which will complete in a future step. Blocks layout: ``` @@ -215,44 +213,38 @@ def allocate_slots( Returns: A list of new allocated blocks. 
""" - if num_tokens == 0: - raise ValueError("num_tokens must be greater than 0") + if num_new_tokens == 0: + raise ValueError("num_new_tokens must be greater than 0") if new_computed_blocks is not None: new_computed_block_list = new_computed_blocks.blocks else: new_computed_block_list = [] - req_blocks = self.req_to_blocks[request.request_id] - # Free the blocks that are skipped during the attention computation # (e.g., tokens outside the sliding window). # We can do this even if we cannot schedule this request due to # insufficient free blocks. # Should call this function before allocating new blocks to reduce # the number of evicted blocks. - removed_blocks = self.specialized_manager.remove_skipped_blocks( - req_blocks, request.num_computed_tokens) - self.block_pool.free_blocks(removed_blocks) + self.single_type_manager.remove_skipped_blocks( + request.request_id, request.num_computed_tokens) # The number of computed tokens is the number of computed tokens plus # the new prefix caching hits num_computed_tokens = (request.num_computed_tokens + - len(new_computed_block_list) * self.block_size) - num_required_blocks = cdiv( - num_computed_tokens + num_tokens + num_lookahead_tokens, - self.block_size) - num_new_blocks = (num_required_blocks - len(req_blocks) - - len(new_computed_block_list)) - - # If a computed block of a request is an eviction candidate (in the - # free queue and ref_cnt == 0), it cannot be counted as a free block - # when allocating this request. 
- num_evictable_computed_blocks = sum(1 - for blk in new_computed_block_list - if blk.ref_cnt == 0) - if (num_new_blocks > self.block_pool.get_num_free_blocks() - - num_evictable_computed_blocks): + num_new_computed_tokens) + num_tokens_need_slot = min( + num_computed_tokens + num_new_tokens + num_lookahead_tokens, + self.max_model_len) + num_blocks_to_allocate = ( + self.single_type_manager.get_num_blocks_to_allocate( + request_id=request.request_id, + num_tokens=num_tokens_need_slot, + new_computed_blocks=new_computed_block_list, + )) + + if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): # Cannot allocate new blocks return None @@ -266,74 +258,35 @@ def allocate_slots( # Append the new computed blocks to the request blocks until now to # avoid the case where the new blocks cannot be allocated. - req_blocks.extend(new_computed_block_list) + self.single_type_manager.save_new_computed_blocks( + request.request_id, new_computed_block_list) - # Start to handle new blocks + new_blocks = self.single_type_manager.allocate_new_blocks( + request.request_id, num_tokens_need_slot) - if num_new_blocks <= 0: - # No new block is needed. - new_blocks = [] - else: - # Get new blocks from the free block pool. - num_new_blocks = min( - num_new_blocks, - self.block_pool.get_num_free_blocks(), - # Should not exceed the maximum number of blocks per request. - # This is especially because the block table has the shape - # [..., max_num_blocks_per_req]. - self.max_num_blocks_per_req - len(req_blocks), - ) - assert num_new_blocks > 0 - - # Concatenate the computed block IDs and the new block IDs. - new_blocks = self.block_pool.get_new_blocks(num_new_blocks) - req_blocks.extend(new_blocks) - - if not self.enable_caching: + # P/D: delay caching blocks if we have to recv from + # remote. Update state for locally cached blocks. 
+ if not self.enable_caching or delay_cache_blocks: return KVCacheBlocks(new_blocks) - # Use `new_computed_block_list` for a new request, and - # `num_cached_block` for a running request. - num_cached_blocks = self.num_cached_block.get( - request.request_id, len(new_computed_block_list)) # Speculated tokens might be rejected in the future, so we does # not cache any speculated tokens. We only cache blocks with # generated (accepted) tokens. - num_full_blocks_after_append = (num_computed_tokens + num_tokens - len( - request.spec_token_ids)) // self.block_size - - self.block_pool.cache_full_blocks( - request=request, - blocks=req_blocks, - block_hashes=self.req_to_block_hashes[request.request_id], - num_cached_blocks=num_cached_blocks, - num_full_blocks=num_full_blocks_after_append, - block_size=self.block_size, - hash_fn=self.caching_hash_fn, - ) + self.single_type_manager.cache_blocks( + request, self.req_to_block_hashes[request.request_id], + num_computed_tokens + num_new_tokens - num_draft_tokens) - self.num_cached_block[ - request.request_id] = num_full_blocks_after_append return KVCacheBlocks(new_blocks) def free(self, request: Request) -> None: """Free the blocks allocated for the request. - When caching is enabled, we free the blocks in reverse order so that - the tail blocks are evicted first. + We free the blocks in reverse order so that he tail blocks are evicted + first when caching is enabled. Args: request: The request to free the blocks. """ - # Default to [] in case a request is freed (aborted) before alloc. - blocks = self.req_to_blocks.pop(request.request_id, []) - ordered_blocks: Iterable[KVCacheBlock] = blocks - if self.enable_caching: - # Free blocks in reverse order so that the tail blocks are - # freed first. 
- ordered_blocks = reversed(blocks) - - self.block_pool.free_blocks(ordered_blocks) - self.num_cached_block.pop(request.request_id, None) + self.single_type_manager.free(request.request_id) def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF @@ -355,9 +308,9 @@ def get_num_common_prefix_blocks( self, request: Request, num_running_requests: int, - ) -> int: + ) -> list[int]: """Calculate the number of common prefix blocks shared by all requests - in the RUNNING state. + in the RUNNING state for each kv cache group. The function determines this by selecting any request and iterating through its blocks. A block is considered a common prefix block if its @@ -387,17 +340,14 @@ def get_num_common_prefix_blocks( requests in the current step. Returns: - int: The number of common prefix blocks. + list[int]: The number of common prefix blocks for each kv cache + group. """ assert request.status == RequestStatus.RUNNING - blocks = self.req_to_blocks[request.request_id] - num_common_blocks = 0 - for block in blocks: - if block.ref_cnt == num_running_requests: - num_common_blocks += 1 - else: - break - return num_common_blocks + return [ + self.single_type_manager.get_num_common_prefix_blocks( + request.request_id, num_running_requests) + ] def free_block_hashes(self, request: Request) -> None: """Discard the block hashes for the request. @@ -414,3 +364,9 @@ def take_events(self) -> list[KVCacheEvent]: A list of KV cache events. 
""" return self.block_pool.take_events() + + def get_block_ids(self, request_id: str) -> list[list[int]]: + """Get the block ids of a request.""" + assert request_id in self.single_type_manager.req_to_blocks + return KVCacheBlocks(self.single_type_manager.req_to_blocks[request_id] + ).get_block_ids() diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 27c515835087..403b5401be75 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -577,14 +577,12 @@ def create_kv_cache_group_specs( """ kv_cache_groups = [] for layer_names_one_group in grouped_layer_names: - layer_spec = kv_cache_spec[layer_names_one_group[0]] - assert all( - kv_cache_spec[layer_name] == layer_spec - for layer_name in layer_names_one_group[1:]), ( - "All layers in the same KV cache group must share the same " - "KVCacheSpec.") + layer_specs = [ + kv_cache_spec[layer_name] for layer_name in layer_names_one_group + ] + merged_layer_spec = layer_specs[0].merge(layer_specs) kv_cache_groups.append( - KVCacheGroupSpec(layer_names_one_group, layer_spec)) + KVCacheGroupSpec(layer_names_one_group, merged_layer_spec)) return kv_cache_groups @@ -683,6 +681,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): head_size=spec.head_size, dtype=spec.dtype, use_mla=spec.use_mla, + sliding_window=spec.sliding_window, ) diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index 0b328f510903..c17f80b6ae78 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Optional, Union if TYPE_CHECKING: + from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.engine import EngineCoreOutputs from vllm.v1.metrics.stats import SchedulerStats @@ -137,3 +138,6 @@ def make_stats(self) -> Optional["SchedulerStats"]: def shutdown(self) -> None: """Shutdown the 
scheduler.""" raise NotImplementedError + + def get_kv_connector(self) -> Optional["KVConnectorBase_V1"]: + return None diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 24032498e50b..257234430983 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -26,7 +26,7 @@ class NewRequestData: mm_hashes: list[str] mm_positions: list[PlaceholderRange] sampling_params: SamplingParams - block_ids: list[int] + block_ids: list[list[int]] num_computed_tokens: int lora_request: Optional[LoRARequest] @@ -34,7 +34,7 @@ class NewRequestData: def from_request( cls, request: Request, - block_ids: list[int], + block_ids: list[list[int]], ) -> NewRequestData: return cls( req_id=request.request_id, @@ -85,7 +85,7 @@ class CachedRequestData: # request's block IDs instead of appending to the existing block IDs. resumed_from_preemption: bool new_token_ids: list[int] - new_block_ids: list[int] + new_block_ids: list[list[int]] num_computed_tokens: int @classmethod @@ -94,7 +94,7 @@ def from_request( request: Request, resumed_from_preemption: bool, new_token_ids: list[int], - new_block_ids: list[int], + new_block_ids: list[list[int]], ) -> CachedRequestData: return cls( req_id=request.request_id, @@ -131,9 +131,9 @@ class SchedulerOutput: # E.g., if a request has [0, 1], it could mean the vision encoder needs # to process that the request's 0-th and 1-th images in the current step. scheduled_encoder_inputs: dict[str, list[int]] - # Number of common prefix blocks for all requests. + # Number of common prefix blocks for all requests in each KV cache group. # This can be used for cascade attention. - num_common_prefix_blocks: int + num_common_prefix_blocks: list[int] # Request IDs that are finished in between the previous and the current # steps. 
This is used to notify the workers about the finished requests diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 258e0d570e3e..4c6b3eea0cb7 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -5,18 +5,19 @@ import time from collections import defaultdict, deque from collections.abc import Iterable -from typing import Optional, Union +from typing import Any, Optional, Union from vllm.config import VllmConfig from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch from vllm.distributed.kv_transfer.kv_connector.factory import ( KVConnectorFactory) -from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole +from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, + KVConnectorRole) from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) -from vllm.v1.core.kv_cache_manager import KVCacheManager +from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) @@ -96,6 +97,9 @@ def __init__( # This is flushed at the end of each scheduling step. self.finished_req_ids: set[str] = set() + # P/D: requests in process of recving KV transfers + self.finished_recving_kv_req_ids: set[str] = set() + # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # them at each scheduling step. # Request id -> deque of CachedRequestData @@ -169,7 +173,7 @@ def schedule(self) -> SchedulerOutput: # uses structured decoding. 
structured_output_request_ids: dict[str, int] = {} - req_to_new_block_ids: dict[str, list[int]] = {} + req_to_new_block_ids: dict[str, list[list[int]]] = {} num_scheduled_tokens: dict[str, int] = {} token_budget = self.max_num_scheduled_tokens # Encoder-related. @@ -223,10 +227,15 @@ def schedule(self) -> SchedulerOutput: req_index += 1 continue + num_draft_tokens = max( + num_new_tokens + request.num_computed_tokens - + request.num_tokens, 0) + while True: new_blocks = self.kv_cache_manager.allocate_slots( request, num_new_tokens, + num_draft_tokens=num_draft_tokens, num_lookahead_tokens=self.num_lookahead_tokens) if new_blocks is None: # The request cannot be scheduled. @@ -307,6 +316,19 @@ def schedule(self) -> SchedulerOutput: request = self.waiting[0] + # KVTransfer: skip request if still waiting for remote kvs. + if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + is_ready = self._update_waiting_for_remote_kv(request) + if is_ready: + request.status = RequestStatus.WAITING + else: + logger.debug( + "%s is still in WAITING_FOR_REMOTE_KVS state.", + request.request_id) + self.waiting.popleft() + skipped_waiting_requests.appendleft(request) + continue + # Skip request if the structured output request is still waiting # for FSM compilation. if request.status == RequestStatus.WAITING_FOR_FSM: @@ -329,50 +351,70 @@ def schedule(self) -> SchedulerOutput: skipped_waiting_requests.appendleft(request) continue - # Get already-cached tokens. - computed_blocks, num_computed_tokens = \ - self.kv_cache_manager.get_computed_blocks( - request) + num_external_computed_tokens = 0 + load_kv_async = False - # Get externally-cached tokens if using a KVConnector. - num_external_tokens = ( - 0 if self.connector is None else - self.connector.get_num_new_matched_tokens( - request, num_computed_tokens)) + # Get already-cached tokens. + if request.num_computed_tokens == 0: + # Get locally-cached tokens. 
+ new_computed_blocks, num_new_local_computed_tokens = \ + self.kv_cache_manager.get_computed_blocks( + request) + + # Get externally-cached tokens if using a KVConnector. + if self.connector is not None: + num_external_computed_tokens, load_kv_async = ( + self.connector.get_num_new_matched_tokens( + request, num_new_local_computed_tokens)) + + # Total computed tokens (local + external). + num_computed_tokens = (num_new_local_computed_tokens + + num_external_computed_tokens) + # KVTransfer: WAITING reqs have num_computed_tokens > 0 + # after async KV recvs are completed. + else: + new_computed_blocks = KVCacheBlocks.create_empty() + num_new_local_computed_tokens = 0 + num_computed_tokens = request.num_computed_tokens - # Total computed tokens (local + external). - num_computed_tokens += num_external_tokens + encoder_inputs_to_schedule = None + new_encoder_budget = encoder_budget + # KVTransfer: loading remote KV, do not allocate for new work. + if load_kv_async: + assert num_external_computed_tokens > 0 + num_new_tokens = 0 # Number of tokens to be scheduled. - # We use `request.num_tokens` instead of - # `request.num_prompt_tokens` to consider the resumed requests, - # which have output tokens. - num_new_tokens = request.num_tokens - num_computed_tokens - if (0 < self.scheduler_config.long_prefill_token_threshold < - num_new_tokens): - num_new_tokens = ( - self.scheduler_config.long_prefill_token_threshold) - num_new_tokens = min(num_new_tokens, token_budget) - assert num_new_tokens > 0 - - # Schedule encoder inputs. - if request.has_encoder_inputs: - (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_budget) = self._try_schedule_encoder_inputs( - request, num_computed_tokens, num_new_tokens, - encoder_budget) - if num_new_tokens == 0: - # The request cannot be scheduled. 
- break else: - encoder_inputs_to_schedule = None - new_encoder_budget = encoder_budget + # We use `request.num_tokens` instead of + # `request.num_prompt_tokens` to consider the resumed + # requests, which have output tokens. + num_new_tokens = request.num_tokens - num_computed_tokens + if (0 < self.scheduler_config.long_prefill_token_threshold + < num_new_tokens): + num_new_tokens = ( + self.scheduler_config.long_prefill_token_threshold) + num_new_tokens = min(num_new_tokens, token_budget) + assert num_new_tokens > 0 + + # Schedule encoder inputs. + if request.has_encoder_inputs: + (encoder_inputs_to_schedule, num_new_tokens, + new_encoder_budget + ) = self._try_schedule_encoder_inputs( + request, num_computed_tokens, num_new_tokens, + encoder_budget) + if num_new_tokens == 0: + # The request cannot be scheduled. + break new_blocks = self.kv_cache_manager.allocate_slots( request, - num_new_tokens + num_external_tokens, - computed_blocks, + num_new_tokens + num_external_computed_tokens, + num_new_local_computed_tokens, + new_computed_blocks, num_lookahead_tokens=self.num_lookahead_tokens, + delay_cache_blocks=load_kv_async, ) if new_blocks is None: # The request cannot be scheduled. @@ -381,13 +423,22 @@ def schedule(self) -> SchedulerOutput: # KVConnector: update internal state after allocation. # This information is used to determine if a load is # needed for this request. - if self.connector is not None: + if num_external_computed_tokens: + assert self.connector is not None self.connector.update_state_after_alloc( request, - num_external_tokens, + new_computed_blocks + new_blocks, + num_external_computed_tokens, ) self.waiting.popleft() + if load_kv_async: + # If loading async, allocate memory and put request + # into the WAITING_FOR_REMOTE_KV state. 
+ skipped_waiting_requests.appendleft(request) + request.status = RequestStatus.WAITING_FOR_REMOTE_KVS + continue + if request.use_structured_output: structured_output_request_ids[ request.request_id] = req_index @@ -407,12 +458,14 @@ def schedule(self) -> SchedulerOutput: if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) req_to_new_block_ids[request.request_id] = ( - computed_blocks + new_blocks).get_block_ids() + self.kv_cache_manager.get_block_ids(request.request_id)) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING request.num_computed_tokens = num_computed_tokens - + # Count the number of prifix cached tokens. + if request.num_cached_tokens < 0: + request.num_cached_tokens = num_computed_tokens # Encoder-related. if encoder_inputs_to_schedule: scheduled_encoder_inputs[request.request_id] = ( @@ -439,7 +492,8 @@ def schedule(self) -> SchedulerOutput: # Get the longest common prefix among all requests in the running queue. # This can be potentially used for cascade attention. - num_common_prefix_blocks = 0 + num_common_prefix_blocks = [0] * len( + self.kv_cache_config.kv_cache_groups) if self.running: any_request = self.running[0] num_common_prefix_blocks = ( @@ -526,7 +580,7 @@ def _make_cached_request_data( request: Request, num_scheduled_tokens: int, num_scheduled_spec_tokens: int, - new_block_ids: list[int], + new_block_ids: list[list[int]], resumed_from_preemption: bool, ) -> CachedRequestData: # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating @@ -698,6 +752,7 @@ def update_from_output( stopped = False new_logprobs = None new_token_ids = generated_token_ids + kv_transfer_params = None # Append generated tokens and check for stop. Note that if # a request is still being prefilled, we expect the model runner @@ -709,7 +764,7 @@ def update_from_output( # This must be called before we make the EngineCoreOutput. 
stopped = check_stop(request, self.max_model_len) if stopped: - self._free_request(request) + kv_transfer_params = self._free_request(request) del new_token_ids[num_new:] # Trim new tokens if needed. break @@ -719,7 +774,8 @@ def update_from_output( # the outer lists can be of length > 1. new_logprobs = logprobs.slice(req_index, req_index + 1) - if new_token_ids and request.use_structured_output: + if new_token_ids and self.structured_output_manager.should_advance( + request): # NOTE: structured_output_request # should not be None if use_structured_output, we have # check above, so safe to ignore type warning @@ -728,18 +784,18 @@ def update_from_output( # Add newly generated spec token ids to the request. if spec_token_ids is not None: - if request.use_structured_output: + if self.structured_output_manager.should_advance(request): metadata = request.structured_output_request - assert metadata is not None and metadata.grammar is not None # Needs to happen after new_token_ids are accepted. - request.spec_token_ids = metadata.grammar.validate_tokens( + request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] spec_token_ids[req_index]) else: request.spec_token_ids = spec_token_ids[req_index] # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) - if new_token_ids: + if new_token_ids or kv_transfer_params: + # Add EngineCoreOutput for this Request. outputs.append( EngineCoreOutput( @@ -749,7 +805,11 @@ def update_from_output( new_logprobs=new_logprobs, new_prompt_logprobs_tensors=prompt_logprobs_tensors, stop_reason=request.stop_reason, - events=request.take_events())) + events=request.take_events(), + kv_transfer_params=kv_transfer_params, + num_cached_tokens=request.num_cached_tokens, + )) + else: # Invariant: EngineCore returns no partial prefill outputs. 
assert not prompt_logprobs_tensors @@ -757,6 +817,9 @@ def update_from_output( if not stopped: new_running.append(request) + # P/D: update state for finished KV Transfers. + self._update_from_kv_xfer_finished(model_runner_output) + # Return the cached request data to the queue so they can be reused. for req_data in scheduler_output.scheduled_cached_reqs: # NOTE(rob): since we free stopped reqs above, adding stopped reqs @@ -811,15 +874,27 @@ def finish_requests( request.status = finished_status self._free_request(request) - def _free_request(self, request: Request) -> None: + def _free_request(self, request: Request) -> Optional[dict[str, Any]]: + assert request.is_finished() - self.kv_cache_manager.free(request) - self.kv_cache_manager.free_block_hashes(request) + + delay_free_blocks, kv_xfer_params = self._connector_finished(request) self.encoder_cache_manager.free(request) self._cached_reqs_data.pop(request.request_id, None) - del self.requests[request.request_id] self.finished_req_ids.add(request.request_id) + if not delay_free_blocks: + self._free_blocks(request) + + return kv_xfer_params + + def _free_blocks(self, request: Request): + assert request.is_finished() + assert request.request_id not in self._cached_reqs_data + self.kv_cache_manager.free(request) + self.kv_cache_manager.free_block_hashes(request) + del self.requests[request.request_id] + def get_num_unfinished_requests(self) -> int: return len(self.waiting) + len(self.running) @@ -863,3 +938,78 @@ def make_spec_decoding_stats( def shutdown(self) -> None: if self.kv_event_publisher: self.kv_event_publisher.shutdown() + + ######################################################################## + # P/D Related Methods + ######################################################################## + + def get_kv_connector(self) -> Optional[KVConnectorBase_V1]: + return self.connector + + def _connector_finished( + self, request: Request) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Invoke the KV 
connector request_finished() method if applicable. + + Returns optional kv transfer parameters to be included with the + request outputs. + """ + if self.connector is None: + return False, None + assert len(self.kv_cache_config.kv_cache_groups + ) == 1, "KV connector only supports one KV cache group now" + block_ids = self.kv_cache_manager.get_block_ids(request.request_id)[0] + return self.connector.request_finished(request, block_ids) + + def _update_waiting_for_remote_kv(self, request: Request) -> bool: + """ + P/D: check if the request_id is finished_recving. + + The finished_recving_kv_req_ids list is populated + on the previous steps()'s update_from_output based + on the worker side connector. + + When the kv transfer is ready, we cache the blocks + and the request state will be moved back to WAITING from + WAITING_FOR_REMOTE_KV. + """ + if request.request_id not in self.finished_recving_kv_req_ids: + return False + assert len(self.kv_cache_config.kv_cache_groups + ) == 1, "KV connector only supports one KV cache group now" + # Now that the blocks are ready, actually cache them. + block_ids = self.kv_cache_manager.get_block_ids(request.request_id)[0] + num_computed_tokens = len(block_ids) * self.block_size + if num_computed_tokens == request.num_tokens: + num_computed_tokens -= 1 + self.kv_cache_manager.single_type_manager.cache_blocks( + request, + self.kv_cache_manager.req_to_block_hashes[request.request_id], + num_computed_tokens, + ) + + # Update the request state for scheduling. + request.num_computed_tokens = num_computed_tokens + + # Return that we are ready. + self.finished_recving_kv_req_ids.remove(request.request_id) + return True + + def _update_from_kv_xfer_finished(self, + model_runner_output: ModelRunnerOutput): + """ + P/D: update the scheduler state based on the output. + + The Worker side connectors add finished_recving and + finished_sending reqs to the output. 
+ * if finished_sending: free the blocks + # if finished_recving: add to state so we can + scheduler the request during the next step. + """ + # P/D: update recv and send status from last step. + for req_id in (model_runner_output.finished_recving or ()): + logger.debug("Finished recving KV transfer for request %s", req_id) + self.finished_recving_kv_req_ids.add(req_id) + for req_id in (model_runner_output.finished_sending or ()): + logger.debug("Finished sending KV transfer for request %s", req_id) + self._free_blocks(self.requests[req_id]) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py new file mode 100644 index 000000000000..0223c9ceec8d --- /dev/null +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -0,0 +1,358 @@ +# SPDX-License-Identifier: Apache-2.0 +from abc import ABC, abstractmethod +from collections import defaultdict +from typing import Callable + +from vllm.utils import cdiv +from vllm.v1.core.block_pool import BlockPool +from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec, + SlidingWindowSpec) +from vllm.v1.request import Request + + +class SingleTypeKVCacheManager(ABC): + """ + An abstract base class for a manager that handle the kv cache management + logic of one specific type of attention layer. + """ + + def __init__( + self, + kv_cache_spec: KVCacheSpec, + block_pool: BlockPool, + use_eagle: bool, + num_kv_cache_groups: int, + caching_hash_fn: Callable, + ) -> None: + """ + Initializes the SpecializedManager. + Args: + kv_cache_spec: The kv_cache_spec for this manager. + block_pool: The block pool. + use_eagle: Whether to use eagle. + num_kv_cache_groups: The number of kv cache groups managed by this + manager. + caching_hash_fn: The caching hash function. 
+ """ + + self.block_size = kv_cache_spec.block_size + self.kv_cache_spec = kv_cache_spec + self.block_pool = block_pool + + # Needs special handling for find_longest_cache_hit if eagle is enabled + self.use_eagle = use_eagle + + # Mapping from request ID to blocks to track the blocks allocated + # for each request, so that we can free the blocks when the request + # is finished. + self.req_to_blocks: defaultdict[str, + list[KVCacheBlock]] = defaultdict(list) + + # {req_id: The number of cached blocks for this given request} + # This is used to track the number of cached blocks for each request. + # This is only used to track the RUNNING requests, we do not track the + # data for reempted ones. + self.num_cached_block: dict[str, int] = {} + + self.num_kv_cache_groups = num_kv_cache_groups + self.caching_hash_fn = caching_hash_fn + + def get_num_blocks_to_allocate( + self, request_id: str, num_tokens: int, + new_computed_blocks: list[KVCacheBlock]) -> int: + """ + Get the number of blocks needed to be allocated for the request. + + Args: + request_id: The request ID. + num_tokens: The total number of tokens that need a slot (including + tokens that are already allocated). + new_computed_blocks: The new computed blocks just hitting the + prefix caching. + + Returns: + The number of blocks. + """ + + num_required_blocks = cdiv(num_tokens, self.block_size) + num_new_blocks = (num_required_blocks - len(new_computed_blocks) - + len(self.req_to_blocks[request_id])) + # If a computed block of a request is an eviction candidate (in the + # free queue and ref_cnt == 0), it will be changed from a free block + # to a computed block when the request is allocated, so we also count + # it as needed to be allocated. 
+ num_evictable_computed_blocks = sum(blk.ref_cnt == 0 + for blk in new_computed_blocks) + return ((num_new_blocks + num_evictable_computed_blocks) * + self.num_kv_cache_groups) + + def save_new_computed_blocks( + self, request_id: str, + new_computed_blocks: list[KVCacheBlock]) -> None: + """ + Add the new computed blocks to the request. + + Args: + request_id: The request ID. + new_computed_blocks: The new computed blocks just hitting the + prefix cache. + """ + if request_id not in self.num_cached_block: + # A new request. + req_blocks = self.req_to_blocks[request_id] + assert len(req_blocks) == 0 + req_blocks.extend(new_computed_blocks) + self.num_cached_block[request_id] = len(new_computed_blocks) + else: + # A running request. Should not have new computed blocks. + assert len(new_computed_blocks) == 0 + + def allocate_new_blocks(self, request_id: str, + num_tokens: int) -> list[KVCacheBlock]: + """ + Allocate new blocks for the request to give it at least `num_tokens` + token slots. + + Args: + request_id: The request ID. + num_tokens: The total number of tokens that need a slot (including + tokens that are already allocated). + + Returns: + The new allocated blocks. + """ + req_blocks = self.req_to_blocks[request_id] + num_required_blocks = cdiv(num_tokens, self.block_size) + num_new_blocks = num_required_blocks - len(req_blocks) + if num_new_blocks <= 0: + return [] + else: + new_blocks = self.block_pool.get_new_blocks( + num_new_blocks * self.num_kv_cache_groups) + req_blocks.extend(new_blocks) + return new_blocks + + def cache_blocks(self, request: Request, block_hashes: list[BlockHashType], + num_tokens: int) -> None: + """ + Cache the blocks for the request. + + Args: + request: The request. + block_hashes: The block hashes of the request. + num_tokens: The total number of tokens that need to be cached + (including tokens that are already cached). 
+ """ + num_cached_blocks = self.num_cached_block[request.request_id] + num_full_blocks = num_tokens // self.block_size + + self.block_pool.cache_full_blocks( + request=request, + blocks=self.req_to_blocks[request.request_id], + block_hashes=block_hashes, + num_cached_blocks=num_cached_blocks, + num_full_blocks=num_full_blocks, + block_size=self.block_size, + hash_fn=self.caching_hash_fn, + ) + + self.num_cached_block[request.request_id] = num_full_blocks + + def free(self, request_id: str) -> None: + # Default to [] in case a request is freed (aborted) before alloc. + req_blocks = self.req_to_blocks.pop(request_id, []) + + # Free blocks in reverse order so that the tail blocks are + # freed first. + ordered_blocks = reversed(req_blocks) + + self.block_pool.free_blocks(ordered_blocks) + self.num_cached_block.pop(request_id, None) + + @abstractmethod + def get_num_common_prefix_blocks(self, request_id: str, + num_running_requests: int) -> int: + """ + Get the number of common prefix blocks for a request. + + Args: + request_id: The request ID. + block_hashes: The block hashes of the request. + + Returns: + The number of common prefix blocks. + """ + + raise NotImplementedError + + @abstractmethod + def find_longest_cache_hit(self, block_hashes: list[BlockHashType], + max_length: int) -> list[KVCacheBlock]: + """ + Get the longest cache hit prefix of the blocks that is not longer than + `max_length`. If no cache hit is found, return an empty list. + If eagle is enabled, drop the last matched block to force recompute the + last block to get the required hidden states for eagle drafting head. + Need to be customized for each attention type. + + Args: + block_hashes: The block hashes of the request. + max_length: The maximum length of the cache hit prefix. + + Returns: + A list of cached blocks with skipped blocks replaced by null block. 
+ For example, sliding window manager should return a list like + [NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)] for block size 4 and + sliding window 8. + """ + + raise NotImplementedError + + @abstractmethod + def remove_skipped_blocks(self, request_id: str, + num_computed_tokens: int) -> None: + """ + Remove the blocks that are no longer needed from `blocks`. The removed + blocks should be replaced by null_block. Return the removed blocks in + eviction order, where the first returned block should be evicted first. + Don't free the removed blocks in this function. Need to be customized + for each attention type. + + Args: + request_id: The request ID. + num_computed_tokens: The number of tokens that have been computed. + """ + raise NotImplementedError + + +class FullAttentionManager(SingleTypeKVCacheManager): + + def find_longest_cache_hit(self, block_hashes: list[BlockHashType], + max_length: int) -> list[KVCacheBlock]: + computed_blocks: list[KVCacheBlock] = [] + max_num_blocks = max_length // self.block_size + for i in range(max_num_blocks): + block_hash = block_hashes[i] + # block_hashes is a chain of block hashes. If a block hash is not + # in the cached_block_hash_to_id, the following block hashes are + # not computed yet for sure. + if cached_block := self.block_pool.get_cached_block(block_hash): + computed_blocks.append(cached_block) + else: + break + if self.use_eagle and len(computed_blocks) > 0: + computed_blocks.pop() + return computed_blocks + + def remove_skipped_blocks(self, request_id: str, + num_computed_tokens: int) -> None: + # No need to remove blocks for full attention. 
+ pass + + def get_num_common_prefix_blocks(self, request_id: str, + num_running_requests: int) -> int: + blocks = self.req_to_blocks[request_id] + num_common_blocks = 0 + for block in blocks: + if block.ref_cnt == num_running_requests: + num_common_blocks += 1 + else: + break + return num_common_blocks + + +class SlidingWindowManager(SingleTypeKVCacheManager): + + def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, + use_eagle: bool, **kwargs) -> None: + super().__init__(kv_cache_spec, block_pool, use_eagle, **kwargs) + self.sliding_window = kv_cache_spec.sliding_window + # The number of contiguous blocks needed for prefix cache hit. + # -1 since the input token itself is also included in the window + self.sliding_window_contiguous_blocks = cdiv( + (kv_cache_spec.sliding_window - 1), self.block_size) + if self.use_eagle: + # Need to drop the last matched block if eagle is enabled. For + # sliding window layer, we achieve this by increasing the number of + # contiguous blocks needed for prefix cache hit by one and dropping + # the last matched block. + self.sliding_window_contiguous_blocks += 1 + self._null_block = block_pool.null_block + + def find_longest_cache_hit(self, block_hashes: list[BlockHashType], + max_length: int) -> list[KVCacheBlock]: + # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to + # optimize the time complexity from O(max_num_blocks) to + # O(max_num_blocks / sliding_window_contiguous_blocks + + # sliding_window_contiguous_blocks), + # which is good for low cache hit rate scenarios. + max_num_blocks = max_length // self.block_size + computed_blocks = [self._null_block] * max_num_blocks + num_contiguous_blocks = 0 + + match_found = False + # Search from right to left and early stop when a match is found. 
+ for i in range(max_num_blocks - 1, -1, -1): + if cached_block := self.block_pool.get_cached_block( + block_hashes[i]): + computed_blocks[i] = cached_block + num_contiguous_blocks += 1 + if (num_contiguous_blocks + >= self.sliding_window_contiguous_blocks): + # Trim the trailing blocks. + # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3] + # when sliding_window_contiguous_blocks=2. + del computed_blocks[i + num_contiguous_blocks:] + match_found = True + break + else: + num_contiguous_blocks = 0 + if not match_found: + # The first `num_contiguous_blocks` is a cache hit even if + # `num_contiguous_blocks < sliding_window_contiguous_blocks`. + del computed_blocks[num_contiguous_blocks:] + if self.use_eagle and len(computed_blocks) > 0: + computed_blocks.pop() + return computed_blocks + + def remove_skipped_blocks(self, request_id: str, + num_computed_tokens: int) -> None: + # Remove the blocks that are no longer be in the sliding window and + # skipped during the attention computation. + last_useful_token = num_computed_tokens - self.sliding_window + 1 + last_useful_block = last_useful_token // self.block_size + blocks = self.req_to_blocks[request_id] + removed_blocks: list[KVCacheBlock] = [] + for i in range(last_useful_block - 1, -1, -1): + if blocks[i] == self._null_block: + # If the block is already a null block, the blocks before it + # should also have been set to null blocks by the previous calls + # to this function. + break + removed_blocks.append(blocks[i]) + blocks[i] = self._null_block + self.block_pool.free_blocks(removed_blocks) + + def get_num_common_prefix_blocks(self, request_id: str, + num_running_requests: int) -> int: + """ + NOTE(Chen): The prefix blocks are null blocks for sliding window layers. + So it's not correct to count ref_cnt like FullAttentionManager. Return + 0 here for correctness. Need to support cascade attention + sliding + window in the future. 
+ """ + return 0 + + +spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { + FullAttentionSpec: FullAttentionManager, + SlidingWindowSpec: SlidingWindowManager, +} + + +def get_manager_for_kv_cache_spec(kv_cache_spec: KVCacheSpec, + **kwargs) -> SingleTypeKVCacheManager: + manager_class = spec_manager_map[type(kv_cache_spec)] + manager = manager_class(kv_cache_spec, **kwargs) + return manager diff --git a/vllm/v1/core/specialized_manager.py b/vllm/v1/core/specialized_manager.py deleted file mode 100644 index f04eedf42662..000000000000 --- a/vllm/v1/core/specialized_manager.py +++ /dev/null @@ -1,180 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -from abc import ABC, abstractmethod - -from vllm.utils import cdiv -from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec, - SlidingWindowSpec) - - -class SpecializedManager(ABC): - """ - An abstract base class for specialized managers that handle the kv - cache management logic of different attention layers. - """ - - def __init__( - self, - kv_cache_spec: KVCacheSpec, - block_pool: BlockPool, - use_eagle: bool, - ) -> None: - """ - Initializes the SpecializedManager. - Args: - kv_cache_spec: The kv_cache_spec for this manager. - block_pool: The block pool. - """ - - self.block_size = kv_cache_spec.block_size - self.kv_cache_spec = kv_cache_spec - self.block_pool = block_pool - - # Needs special handling for find_longest_cache_hit if eagle is enabled - self.use_eagle = use_eagle - - @abstractmethod - def find_longest_cache_hit( - self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]: - """ - Get the longest cache hit prefix of the blocks. If no cache hit is - found, return an empty list. if eagle is enabled, drop the last matched - block to force recompute the last block to get the required hidden - states for eagle drafting head. 
- - Args: - block_hashes: The block hashes of the request. - Returns: - A list of cached blocks with skipped blocks replaced by null block. - For example, sliding window manager should return a list like - [NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)] for block size 4 and - sliding window 8. - """ - - raise NotImplementedError - - @abstractmethod - def remove_skipped_blocks(self, blocks: list[KVCacheBlock], - num_computed_tokens: int) -> list[KVCacheBlock]: - """ - Remove the blocks that are no longer needed from `blocks`. The removed - blocks should be replaced by null_block. Return the removed blocks in - eviction order, where the first returned block should be evicted first. - Don't free the removed blocks in this function. - - Args: - blocks: The list of blocks to be updated. - num_computed_tokens: The number of tokens that have been computed. - Returns: - The removed blocks in eviction order. - """ - raise NotImplementedError - - -class FullAttentionManager(SpecializedManager): - - def find_longest_cache_hit( - self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]: - computed_blocks: list[KVCacheBlock] = [] - for block_hash in block_hashes: - # block_hashes is a chain of block hashes. If a block hash is not - # in the cached_block_hash_to_id, the following block hashes are - # not computed yet for sure. - if cached_block := self.block_pool.get_cached_block(block_hash): - computed_blocks.append(cached_block) - else: - break - if self.use_eagle and len(computed_blocks) > 0: - computed_blocks.pop() - return computed_blocks - - def remove_skipped_blocks(self, blocks: list[KVCacheBlock], - num_computed_tokens: int) -> list[KVCacheBlock]: - # No need to remove blocks for full attention. 
- return [] - - -class SlidingWindowManager(SpecializedManager): - - def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, - use_eagle: bool): - super().__init__(kv_cache_spec, block_pool, use_eagle) - self.sliding_window = kv_cache_spec.sliding_window - # The number of contiguous blocks needed for prefix cache hit. - # -1 since the input token itself is also included in the window - self.sliding_window_contiguous_blocks = cdiv( - (kv_cache_spec.sliding_window - 1), self.block_size) - if self.use_eagle: - # Need to drop the last matched block if eagle is enabled. For - # sliding window layer, we achieve this by increasing the number of - # contiguous blocks needed for prefix cache hit by one and dropping - # the last matched block. - self.sliding_window_contiguous_blocks += 1 - self._null_block = block_pool.null_block - - def find_longest_cache_hit( - self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]: - # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to - # optimize the time complexity from O(len(block_hashes)) to - # O(len(block_hashes) / sliding_window_contiguous_blocks + - # sliding_window_contiguous_blocks), - # which is good for low cache hit rate scenarios. - computed_blocks = [self._null_block] * len(block_hashes) - num_contiguous_blocks = 0 - - match_found = False - # Search from right to left and early stop when a match is found. - for i in range(len(block_hashes) - 1, -1, -1): - if cached_block := self.block_pool.get_cached_block( - block_hashes[i]): - computed_blocks[i] = cached_block - num_contiguous_blocks += 1 - if (num_contiguous_blocks - >= self.sliding_window_contiguous_blocks): - # Trim the trailing blocks. - # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3] - # when sliding_window_contiguous_blocks=2. 
- del computed_blocks[i + num_contiguous_blocks:] - match_found = True - break - else: - num_contiguous_blocks = 0 - if not match_found: - # The first `num_contiguous_blocks` is a cache hit even if - # `num_contiguous_blocks < sliding_window_contiguous_blocks`. - del computed_blocks[num_contiguous_blocks:] - if self.use_eagle and len(computed_blocks) > 0: - computed_blocks.pop() - return computed_blocks - - def remove_skipped_blocks(self, blocks: list[KVCacheBlock], - num_computed_tokens: int) -> list[KVCacheBlock]: - # Remove the blocks that are no longer be in the sliding window and - # skipped during the attention computation. - last_useful_token = num_computed_tokens - self.sliding_window + 1 - last_useful_block = last_useful_token // self.block_size - - removed_blocks: list[KVCacheBlock] = [] - for i in range(last_useful_block - 1, -1, -1): - if blocks[i] == self._null_block: - # If the block is already a null block, the blocks before it - # should also have been set to null blocks by the previous calls - # to this function. - break - removed_blocks.append(blocks[i]) - blocks[i] = self._null_block - return removed_blocks - - -spec_manager_map: dict[type[KVCacheSpec], type[SpecializedManager]] = { - FullAttentionSpec: FullAttentionManager, - SlidingWindowSpec: SlidingWindowManager, -} - - -def get_specialized_manager(kv_cache_spec: KVCacheSpec, - **kwargs) -> SpecializedManager: - manager_class = spec_manager_map[type(kv_cache_spec)] - manager = manager_class(kv_cache_spec, **kwargs) - return manager diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index e33d1a1e5dcd..41db99beaad5 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -105,6 +105,10 @@ class EngineCoreOutput( finish_reason: Optional[FinishReason] = None stop_reason: Union[int, str, None] = None events: Optional[list[EngineCoreEvent]] = None + kv_transfer_params: Optional[dict[str, Any]] = None + + # The number of tokens with prefix cache hits. 
+ num_cached_tokens: int = 0 @property def finished(self) -> bool: diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 00ceb7d3d0c4..74c2251c7521 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -20,6 +20,8 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.config import ( + maybe_register_config_serialize_by_value) from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.usage.usage_lib import UsageContext @@ -80,6 +82,9 @@ def __init__( "AsyncLLMEngine.from_vllm_config(...) or explicitly set " "VLLM_USE_V1=0 or 1 and report this issue on Github.") + # Ensure we can serialize custom transformer configs + maybe_register_config_serialize_by_value() + self.model_config = vllm_config.model_config self.vllm_config = vllm_config self.log_requests = log_requests @@ -476,6 +481,11 @@ async def start_profile(self) -> None: async def stop_profile(self) -> None: await self.engine_core.profile_async(False) + async def reset_mm_cache(self) -> None: + self.processor.mm_registry.reset_processor_cache() + self.processor.mm_input_cache_client.reset() + await self.engine_core.reset_mm_cache_async() + async def reset_prefix_cache(self, device: Optional[Device] = None) -> None: if device == Device.CPU: diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index d9dd4957cff2..740ba60fe231 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -import json import os import queue import signal @@ -23,7 +22,7 @@ from vllm.lora.request import LoRARequest from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) -from vllm.utils import resolve_obj_by_qualname, zmq_socket_ctx +from vllm.utils import 
make_zmq_socket, resolve_obj_by_qualname, zmq_socket_ctx from vllm.v1.core.kv_cache_utils import (get_kv_cache_config, unify_kv_cache_configs) from vllm.v1.core.sched.interface import SchedulerInterface @@ -43,6 +42,7 @@ logger = init_logger(__name__) POLLING_TIMEOUT_S = 2.5 +HANDSHAKE_TIMEOUT_MINS = 5 _R = TypeVar('_R') # Return type for collective_rpc @@ -57,6 +57,10 @@ def __init__(self, executor_fail_callback: Optional[Callable] = None): assert vllm_config.model_config.runner_type != "pooling" + # plugins need to be loaded at the engine/scheduler level too + from vllm.plugins import load_general_plugins + load_general_plugins() + self.vllm_config = vllm_config logger.info("Initializing a V1 LLM engine (v%s) with config: %s", VLLM_VERSION, vllm_config) @@ -182,6 +186,11 @@ def add_request(self, request: EngineCoreRequest): # Start grammar compilation asynchronously self.structured_output_manager.grammar_init(req) + if req.kv_transfer_params is not None and ( + not self.scheduler.get_kv_connector()): + logger.warning("Got kv_transfer_params, but no KVConnector found. 
" + "Disabling KVTransfer for this request.") + self.scheduler.add_request(req) def abort_requests(self, request_ids: list[str]): @@ -277,6 +286,15 @@ def shutdown(self): def profile(self, is_start: bool = True): self.model_executor.profile(is_start) + def reset_mm_cache(self): + # NOTE: Since this is mainly for debugging, we don't attempt to + # re-sync the internal caches (P0 processor, P0 mirror, P1 mirror) + if self.scheduler.has_unfinished_requests(): + logger.warning("Resetting the multi-modal cache when requests are " + "in progress may lead to desynced internal caches.") + + self.mm_input_cache_server.reset() + def reset_prefix_cache(self): self.scheduler.reset_prefix_cache() @@ -322,6 +340,13 @@ def collective_rpc(self, return self.model_executor.collective_rpc(method, timeout, args, kwargs) + def save_tensorized_model( + self, + tensorizer_config, + ) -> None: + self.model_executor.save_tensorized_model( + tensorizer_config=tensorizer_config, ) + class EngineCoreProc(EngineCore): """ZMQ-wrapper for running EngineCore in background process.""" @@ -330,9 +355,9 @@ class EngineCoreProc(EngineCore): def __init__( self, - input_path: str, - output_path: str, vllm_config: VllmConfig, + on_head_node: bool, + input_address: str, executor_class: type[Executor], log_stats: bool, engine_index: int = 0, @@ -342,28 +367,91 @@ def __init__( executor_fail_callback = lambda: input_queue.put_nowait( (EngineCoreRequestType.EXECUTOR_FAILED, b'')) - super().__init__(vllm_config, executor_class, log_stats, - executor_fail_callback) - - self.step_fn = (self.step if self.batch_queue is None else - self.step_with_batch_queue) - self.engines_running = False - - # Background Threads and Queues for IO. These enable us to - # overlap ZMQ socket IO with GPU since they release the GIL, - # and to overlap some serialization/deserialization with the - # model forward pass. - # Threads handle Socket <-> Queues and core_busy_loop uses Queue. 
- self.input_queue = input_queue - self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]() - threading.Thread(target=self.process_input_socket, - args=(input_path, engine_index), - daemon=True).start() - self.output_thread = threading.Thread( - target=self.process_output_socket, - args=(output_path, engine_index), - daemon=True) - self.output_thread.start() + # Create input socket. + input_ctx = zmq.Context() + identity = engine_index.to_bytes(length=2, byteorder="little") + input_socket = make_zmq_socket(input_ctx, + input_address, + zmq.DEALER, + identity=identity, + bind=False) + try: + # Register engine with front-end. + output_address = self.startup_handshake( + input_socket, on_head_node, vllm_config.parallel_config) + + # Update config which may have changed from the handshake. + vllm_config.__post_init__() + + # Set up data parallel environment. + self._init_data_parallel(vllm_config) + + # Initialize engine core and model. + super().__init__(vllm_config, executor_class, log_stats, + executor_fail_callback) + + self.step_fn = (self.step if self.batch_queue is None else + self.step_with_batch_queue) + self.engines_running = False + + # Send ready message. + num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks + input_socket.send( + msgspec.msgpack.encode({ + "status": "READY", + "local": on_head_node, + "num_gpu_blocks": num_gpu_blocks, + })) + + # Background Threads and Queues for IO. These enable us to + # overlap ZMQ socket IO with GPU since they release the GIL, + # and to overlap some serialization/deserialization with the + # model forward pass. + # Threads handle Socket <-> Queues and core_busy_loop uses Queue. 
+ self.input_queue = input_queue + self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]() + threading.Thread(target=self.process_input_socket, + args=(input_socket, ), + daemon=True).start() + input_socket = None + self.output_thread = threading.Thread( + target=self.process_output_socket, + args=(output_address, engine_index), + daemon=True) + self.output_thread.start() + finally: + if input_socket is not None: + input_socket.close(linger=0) + + @staticmethod + def startup_handshake(input_socket: zmq.Socket, on_head_node: bool, + parallel_config: ParallelConfig) -> str: + + # Send registration message. + input_socket.send( + msgspec.msgpack.encode({ + "status": "HELLO", + "local": on_head_node, + })) + + # Receive initialization message. + logger.info("Waiting for init message from front-end.") + if not input_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60 * 1000): + raise RuntimeError("Did not receive response from front-end " + f"process within {HANDSHAKE_TIMEOUT_MINS} " + f"minutes") + init_bytes = input_socket.recv() + init_message = msgspec.msgpack.decode(init_bytes) + logger.debug("Received init message: %s", init_message) + + output_socket_address = init_message["output_socket_address"] + #TBD(nick) maybe replace IP with configured head node address + + received_parallel_config = init_message["parallel_config"] + for key, value in received_parallel_config.items(): + setattr(parallel_config, key, value) + + return output_socket_address @staticmethod def run_engine_core(*args, @@ -394,7 +482,7 @@ def signal_handler(signum, frame): try: parallel_config: ParallelConfig = kwargs[ "vllm_config"].parallel_config - if parallel_config.data_parallel_size > 1: + if parallel_config.data_parallel_size > 1 or dp_rank > 0: # Set data parallel rank for this engine process. 
parallel_config.data_parallel_rank = dp_rank parallel_config.data_parallel_rank_local = local_dp_rank @@ -418,6 +506,9 @@ def signal_handler(signum, frame): if engine_core is not None: engine_core.shutdown() + def _init_data_parallel(self, vllm_config: VllmConfig): + pass + def run_busy_loop(self): """Core busy loop of the EngineCore.""" @@ -509,40 +600,25 @@ def _send_engine_dead(self): logger.fatal("vLLM shutdown signal from EngineCore failed " "to send. Please report this issue.") - def process_input_socket(self, input_path: str, engine_index: int): + def process_input_socket(self, input_socket: zmq.Socket): """Input socket IO thread.""" # Msgpack serialization decoding. add_request_decoder = MsgpackDecoder(EngineCoreRequest) generic_decoder = MsgpackDecoder() - identity = engine_index.to_bytes(length=2, byteorder="little") - - with zmq_socket_ctx(input_path, - zmq.DEALER, - identity=identity, - bind=False) as socket: - - # Send ready message to front-end once input socket is connected. - message_dict = { - 'type': 'READY', - 'num_gpu_blocks': self.vllm_config.cache_config.num_gpu_blocks, - } - message = json.dumps(message_dict).encode('utf-8') - socket.send(message) - while True: - # (RequestType, RequestData) - type_frame, *data_frames = socket.recv_multipart(copy=False) - request_type = EngineCoreRequestType(bytes(type_frame.buffer)) + while True: + # (RequestType, RequestData) + type_frame, *data_frames = input_socket.recv_multipart(copy=False) + request_type = EngineCoreRequestType(bytes(type_frame.buffer)) - # Deserialize the request data. - decoder = add_request_decoder if ( - request_type - == EngineCoreRequestType.ADD) else generic_decoder - request = decoder.decode(data_frames) + # Deserialize the request data. + decoder = add_request_decoder if ( + request_type == EngineCoreRequestType.ADD) else generic_decoder + request = decoder.decode(data_frames) - # Push to input queue for core busy loop. 
- self.input_queue.put_nowait((request_type, request)) + # Push to input queue for core busy loop. + self.input_queue.put_nowait((request_type, request)) def process_output_socket(self, output_path: str, engine_index: int): """Output socket IO thread.""" @@ -591,9 +667,9 @@ class DPEngineCoreProc(EngineCoreProc): def __init__( self, - input_path: str, - output_path: str, vllm_config: VllmConfig, + on_head_node: bool, + input_address: str, executor_class: type[Executor], log_stats: bool, ): @@ -605,34 +681,37 @@ def __init__( _add_prefix(sys.stdout, process_name, pid) _add_prefix(sys.stderr, process_name, pid) - dp_size = vllm_config.parallel_config.data_parallel_size + # Counts forward-passes of the model so that we can synchronize + # finished with DP peers every N steps. + self.counter = 0 + + # Initialize the engine. + dp_rank = vllm_config.parallel_config.data_parallel_rank + super().__init__(vllm_config, on_head_node, input_address, + executor_class, log_stats, dp_rank) + + def _init_data_parallel(self, vllm_config: VllmConfig): + + # Configure GPUs and stateless process group for data parallel. 
dp_rank = vllm_config.parallel_config.data_parallel_rank + dp_size = vllm_config.parallel_config.data_parallel_size local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local assert dp_size > 1 assert 0 <= local_dp_rank <= dp_rank < dp_size from vllm.platforms import current_platform - if current_platform.is_cuda_alike(): - from vllm.platforms.cuda import device_id_to_physical_device_id - tp_size = vllm_config.parallel_config.tensor_parallel_size - os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( - str(device_id_to_physical_device_id(i)) - for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) * - tp_size)) - - self.local_dp_rank = local_dp_rank + device_control_env_var = current_platform.device_control_env_var + world_size = vllm_config.parallel_config.world_size + os.environ[device_control_env_var] = ",".join( + str(current_platform.device_id_to_physical_device_id(i)) + for i in range(local_dp_rank * world_size, (local_dp_rank + 1) * + world_size)) + + self.dp_rank = dp_rank self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() self.current_wave = 0 - # Initialize the engine after setting up environment. - super().__init__(input_path, output_path, vllm_config, executor_class, - log_stats, dp_rank) - - # Counts forward-passes of the model so that we can synchronize - # finished with DP peers every N steps. - self.counter = 0 - def shutdown(self): super().shutdown() if dp_group := getattr(self, "dp_group", None): @@ -702,7 +781,7 @@ def run_busy_loop(self): local_unfinished_reqs) if not self.engines_running: - if self.local_dp_rank == 0: + if self.dp_rank == 0: # Notify client that we are pausing the loop. 
logger.debug("Wave %d finished, pausing engine loop.", self.current_wave) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 91a0a75a3081..0d52bc9a6814 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio import contextlib -import json import queue import uuid import weakref @@ -9,25 +8,27 @@ from collections import deque from collections.abc import Awaitable, Sequence from concurrent.futures import Future -from dataclasses import dataclass, field +from dataclasses import dataclass +from enum import Enum, auto from threading import Thread from typing import Any, Callable, Optional, TypeVar, Union +import msgspec import zmq import zmq.asyncio -from vllm.config import VllmConfig +from vllm.config import ParallelConfig, VllmConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.utils import (get_open_zmq_inproc_path, get_open_zmq_ipc_path, - make_zmq_socket) +from vllm.utils import (get_open_port, get_open_zmq_inproc_path, + get_open_zmq_ipc_path, get_tcp_uri, make_zmq_socket) from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, UtilityOutput) from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.executor.abstract import Executor from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr -from vllm.v1.utils import BackgroundProcHandle +from vllm.v1.utils import CoreEngineProcManager logger = init_logger(__name__) @@ -88,6 +89,9 @@ def add_request(self, request: EngineCoreRequest) -> None: def profile(self, is_start: bool = True) -> None: raise NotImplementedError + def reset_mm_cache(self) -> None: + raise NotImplementedError + def reset_prefix_cache(self) -> None: raise NotImplementedError @@ -143,6 +147,9 @@ async def add_request_async(self, request: EngineCoreRequest) -> None: async 
def profile_async(self, is_start: bool = True) -> None: raise NotImplementedError + async def reset_mm_cache_async(self) -> None: + raise NotImplementedError + async def reset_prefix_cache_async(self) -> None: raise NotImplementedError @@ -214,6 +221,9 @@ def shutdown(self) -> None: def profile(self, is_start: bool = True) -> None: self.engine_core.profile(is_start) + def reset_mm_cache(self) -> None: + self.engine_core.reset_mm_cache() + def reset_prefix_cache(self) -> None: self.engine_core.reset_prefix_cache() @@ -255,45 +265,22 @@ def collective_rpc(self, return self.engine_core.collective_rpc(method, timeout, args, kwargs) +class CoreEngineState(Enum): + NEW = auto() + CONNECTED = auto() + READY = auto() + + class CoreEngine: """One per data parallel rank.""" - def __init__( - self, - vllm_config: VllmConfig, - executor_class: type[Executor], - log_stats: bool, - input_path: str, - output_path: str, - index: int = 0, - local_dp_rank: int = 0, - ): + def __init__(self, index: int = 0, local: bool = True): + self.local = local self.index = index self.identity = index.to_bytes(length=2, byteorder="little") - try: - # Start EngineCore in background process. - self.proc_handle = BackgroundProcHandle( - input_path=input_path, - output_path=output_path, - process_name=f"EngineCore_{index}", - target_fn=EngineCoreProc.run_engine_core, - process_kwargs={ - "vllm_config": vllm_config, - "dp_rank": index, - "local_dp_rank": local_dp_rank, - "executor_class": executor_class, - "log_stats": log_stats, - }) - self.num_reqs_in_flight = 0 - finally: - if not hasattr(self, "num_reqs_in_flight"): - # Ensure socket is closed if process fails to start. 
- self.close() - - def close(self): - if proc_handle := getattr(self, "proc_handle", None): - proc_handle.shutdown() + self.state = CoreEngineState.NEW + self.num_reqs_in_flight = 0 @dataclass @@ -302,7 +289,7 @@ class BackgroundResources: circular reference back to the client object.""" ctx: Union[zmq.Context] - core_engines: list[CoreEngine] = field(default_factory=list) + local_engine_manager: Optional[CoreEngineProcManager] = None output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None output_queue_task: Optional[asyncio.Task] = None @@ -316,8 +303,8 @@ def __call__(self): """Clean up background resources.""" self.engine_dead = True - for core_engine in self.core_engines: - core_engine.close() + if self.local_engine_manager is not None: + self.local_engine_manager.close() if self.output_queue_task is not None: self.output_queue_task.cancel() @@ -379,25 +366,56 @@ def __init__( self._finalizer = weakref.finalize(self, self.resources) success = False try: - # Paths and sockets for IPC. - self.output_path = get_open_zmq_ipc_path() - input_path = get_open_zmq_ipc_path() - self.input_socket = make_zmq_socket(self.ctx, - input_path, - zmq.ROUTER, - bind=True) - self.resources.input_socket = self.input_socket - - new_core_engine = lambda index, local_dp_rank=None: CoreEngine( - vllm_config, executor_class, log_stats, input_path, self. - output_path, index, local_dp_rank) - - # Start engine core process(es). - self._init_core_engines(vllm_config, new_core_engine, - self.resources.core_engines) + parallel_config = vllm_config.parallel_config + local_engine_count = parallel_config.data_parallel_size_local + start_index = parallel_config.data_parallel_rank + local_start_index = parallel_config.data_parallel_rank_local + + # SPMD mode is where there is an LLM instance per DP rank and + # one core engine per LLM, see + # examples/offline_inference/data_parallel.py. 
+ spmd_mode = local_start_index is not None + if spmd_mode: + assert local_engine_count == 1 + self.core_engines = [ + CoreEngine(index=local_start_index, local=True) + ] + else: + assert start_index == 0 + local_start_index = 0 + self.core_engines = [ + CoreEngine(index=i, local=(i < local_engine_count)) + for i in range(parallel_config.data_parallel_size) + ] + + input_address, output_address = self._get_zmq_addresses( + parallel_config, spmd_mode) + + # Create input and output sockets. + self.input_socket = self.resources.input_socket = make_zmq_socket( + self.ctx, input_address, zmq.ROUTER, bind=True) + + self.resources.output_socket = make_zmq_socket( + self.ctx, output_address, zmq.constants.PULL) + # Start local engines. + if local_engine_count: + # In server mode, start_index and local_start_index will + # both be 0. + self.resources.local_engine_manager = CoreEngineProcManager( + EngineCoreProc.run_engine_core, + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=log_stats, + input_address=input_address, + on_head_node=True, + local_engine_count=local_engine_count, + start_index=start_index, + local_start_index=local_start_index) + + self.core_engine = self.core_engines[0] # Wait for engine core process(es) to start. 
- self._wait_for_engine_startup() + self._wait_for_engine_startup(output_address, parallel_config) self.utility_results: dict[int, AnyFuture] = {} @@ -411,56 +429,116 @@ def __init__( if not success: self._finalizer() - def _wait_for_engine_startup(self): + @staticmethod + def _get_zmq_addresses(parallel_config: ParallelConfig, + spmd_mode: bool) -> tuple[str, str]: + """Returns (input_address, output_address).""" + dp_size = parallel_config.data_parallel_size + local_engine_count = parallel_config.data_parallel_size_local + + if local_engine_count == dp_size or spmd_mode: + input_address = get_open_zmq_ipc_path() + output_address = get_open_zmq_ipc_path() + else: + host = parallel_config.data_parallel_master_ip + input_port = parallel_config.data_parallel_rpc_port + output_port = get_open_port() + input_address = get_tcp_uri(host, input_port) + output_address = get_tcp_uri(host, output_port) + + return input_address, output_address + + def _wait_for_engine_startup(self, output_address: str, + parallel_config: ParallelConfig): # Get a sync handle to the socket which can be sync or async. sync_input_socket = zmq.Socket.shadow(self.input_socket) # Wait for engine core process(es) to send ready messages. 
- identities = set(eng.index for eng in self.resources.core_engines) + local_count = parallel_config.data_parallel_size_local + remote_count = len(self.core_engines) - local_count + # [local, remote] counts + conn_pending, start_pending = [local_count, remote_count], [0, 0] + poller = zmq.Poller() poller.register(sync_input_socket, zmq.POLLIN) - for eng in self.resources.core_engines: - poller.register(eng.proc_handle, zmq.POLLIN) - while identities: + proc_manager = self.resources.local_engine_manager + if proc_manager is not None: + for sentinel in proc_manager.sentinels(): + poller.register(sentinel, zmq.POLLIN) + while any(conn_pending) or any(start_pending): events = poller.poll(STARTUP_POLL_PERIOD_MS) if not events: - logger.debug("Waiting for %d core engine proc(s) to start: %s", - len(identities), identities) + if any(conn_pending): + logger.debug( + "Waiting for %d local, %d remote core engine proc(s) " + "to connect.", *conn_pending) + if any(start_pending): + logger.debug( + "Waiting for %d local, %d remote core engine proc(s) " + "to start.", *start_pending) continue if len(events) > 1 or events[0][0] != sync_input_socket: - # One of the core processes exited. + # One of the local core processes exited. + finished = proc_manager.finished_procs( + ) if proc_manager else {} raise RuntimeError("Engine core initialization failed. " - "See root cause above.") - - eng_id_bytes, data = sync_input_socket.recv_multipart() - eng_id = int.from_bytes(eng_id_bytes, byteorder="little") - if eng_id not in identities: - raise RuntimeError(f"Unexpected or duplicate engine: {eng_id}") - message_dict = json.loads(data.decode('utf-8')) - if message_dict['type'] != 'READY': - raise RuntimeError(f"Engine {eng_id} failed: {data.decode()}") - logger.info("Core engine process %d ready.", eng_id) - identities.discard(eng_id) - # Setup KV cache config with initialization state from - # engine core process. Sum values from all engines in DP case. 
- num_gpu_blocks = self.vllm_config.cache_config.num_gpu_blocks or 0 - num_gpu_blocks += message_dict['num_gpu_blocks'] - self.vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks - - def _init_core_engines( - self, - vllm_config: VllmConfig, - new_core_engine: Callable[[int, Optional[int]], CoreEngine], - core_engines: list[CoreEngine], - ) -> None: - - # Default case - single core engine. - core_engine = new_core_engine( - vllm_config.parallel_config.data_parallel_rank, - vllm_config.parallel_config.data_parallel_rank_local, - ) - core_engines.append(core_engine) - self.core_engine = core_engine + "See root cause above. " + f"Failed core proc(s): {finished}") + + # Receive HELLO and READY messages from the input socket. + eng_identity, ready_msg_bytes = sync_input_socket.recv_multipart() + eng_index = int.from_bytes(eng_identity, byteorder="little") + engine = next( + (e for e in self.core_engines if e.identity == eng_identity), + None) + if engine is None: + raise RuntimeError(f"Message from engine with unexpected data " + f"parallel rank: {eng_index}") + msg = msgspec.msgpack.decode(ready_msg_bytes) + status, local = msg["status"], msg["local"] + if local != engine.local: + raise RuntimeError(f"{status} message from " + f"{'local' if local else 'remote'} " + f"engine {eng_index}, expected it to be " + f"{'local' if engine.local else 'remote'}") + + if status == "HELLO" and engine.state == CoreEngineState.NEW: + + # Send init message with DP config info. 
+ init_message = self.encoder.encode({ + "output_socket_address": output_address, + "parallel_config": { + "data_parallel_master_ip": + parallel_config.data_parallel_master_ip, + "data_parallel_master_port": + parallel_config.data_parallel_master_port, + "data_parallel_size": + parallel_config.data_parallel_size, + }, + }) + sync_input_socket.send_multipart((eng_identity, *init_message), + copy=False) + conn_pending[0 if local else 1] -= 1 + start_pending[0 if local else 1] += 1 + engine.state = CoreEngineState.CONNECTED + elif status == "READY" and (engine.state + == CoreEngineState.CONNECTED): + # Setup KV cache config with initialization state from + # engine core process. Sum values from all engines in DP case. + cache_config = self.vllm_config.cache_config + num_gpu_blocks = cache_config.num_gpu_blocks or 0 + num_gpu_blocks += msg['num_gpu_blocks'] + cache_config.num_gpu_blocks = num_gpu_blocks + + start_pending[0 if local else 1] -= 1 + engine.state = CoreEngineState.READY + else: + raise RuntimeError(f"Unexpected {status} message for " + f"{'local' if local else 'remote'} engine " + f"{eng_index} in {engine.state} state.") + + logger.debug("%s from %s core engine process %s.", status, + "local" if local else "remote", eng_index) def shutdown(self): # Terminate background resources. @@ -511,7 +589,8 @@ def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], # Ensure that the outputs socket processing thread does not have # a ref to the client which prevents gc. 
ctx = self.ctx - output_path = self.output_path + out_socket = self.resources.output_socket + assert out_socket is not None decoder = self.decoder utility_results = self.utility_results outputs_queue = self.outputs_queue @@ -522,7 +601,6 @@ def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], def process_outputs_socket(): shutdown_socket = ctx.socket(zmq.PAIR) - out_socket = make_zmq_socket(ctx, output_path, zmq.constants.PULL) try: shutdown_socket.bind(shutdown_path) poller = zmq.Poller() @@ -557,6 +635,9 @@ def process_outputs_socket(): daemon=True) self.output_queue_thread.start() + # The thread takes on responsibility for closing the socket. + self.resources.output_socket = None + def get_output(self) -> EngineCoreOutputs: # If an exception arises in process_outputs_socket task, # it is forwarded to the outputs_queue so we can raise it @@ -600,6 +681,9 @@ def abort_requests(self, request_ids: list[str]) -> None: def profile(self, is_start: bool = True) -> None: self.call_utility("profile", is_start) + def reset_mm_cache(self) -> None: + self.call_utility("reset_mm_cache") + def reset_prefix_cache(self) -> None: self.call_utility("reset_prefix_cache") @@ -681,10 +765,8 @@ def _ensure_output_queue_task(self): self.__class__, "process_engine_outputs", None) _self_ref = weakref.ref(self) if output_handler else None - output_path = self.output_path - output_socket = make_zmq_socket(self.ctx, output_path, - zmq.constants.PULL) - resources.output_socket = output_socket + output_socket = resources.output_socket + assert output_socket is not None async def process_outputs_socket(): try: @@ -787,6 +869,9 @@ async def abort_requests_async(self, request_ids: list[str]) -> None: async def profile_async(self, is_start: bool = True) -> None: await self.call_utility_async("profile", is_start) + async def reset_mm_cache_async(self) -> None: + await self.call_utility_async("reset_mm_cache") + async def reset_prefix_cache_async(self) -> None: await 
self.call_utility_async("reset_prefix_cache") @@ -846,21 +931,6 @@ def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], assert len(self.core_engines) > 1 - def _init_core_engines( - self, - vllm_config: VllmConfig, - new_core_engine: Callable[[int, Optional[int]], CoreEngine], - core_engines: list[CoreEngine], - ) -> None: - - # Launch a core engine for each data parallel rank. - dp_size = vllm_config.parallel_config.data_parallel_size - for i in range(dp_size): - # Multi-node not yet supported so local_dp_rank == dp_rank. - core_engines.append(new_core_engine(i, i)) - - self.core_engines = core_engines - async def call_utility_async(self, method: str, *args) -> Any: # Only the result from the first engine is returned. return (await asyncio.gather(*[ diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index b471b153657f..c856e2645a2c 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -27,7 +27,10 @@ from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor -from vllm.v1.metrics.loggers import StatLoggerFactory +from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase, + StatLoggerFactory) +from vllm.v1.metrics.reader import Metric, get_metrics_snapshot +from vllm.v1.metrics.stats import IterationStats logger = init_logger(__name__) @@ -64,6 +67,11 @@ def __init__( self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config + self.log_stats = log_stats + self.stat_logger: Optional[StatLoggerBase] = None + if self.log_stats: + self.stat_logger = PrometheusStatLogger(vllm_config) + # important: init dp group before init the engine_core # In the decoupled engine case this is handled in EngineCoreProc. parallel_config = vllm_config.parallel_config @@ -86,7 +94,7 @@ def __init__( # OutputProcessor (convert EngineCoreOutputs --> RequestOutput). 
self.output_processor = OutputProcessor(self.tokenizer, - log_stats=False) + log_stats=self.log_stats) # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs) self.engine_core = EngineCoreClient.make_client( @@ -94,13 +102,16 @@ def __init__( asyncio_mode=False, vllm_config=vllm_config, executor_class=executor_class, - log_stats=False, # FIXME: implement + log_stats=self.log_stats, ) if not multiprocess_mode: # for v0 compatibility self.model_executor = self.engine_core.engine_core.model_executor # type: ignore + # Don't keep the dummy data in memory + self.reset_mm_cache() + @classmethod def from_vllm_config( cls, @@ -220,12 +231,21 @@ def step(self) -> list[RequestOutput]: outputs = self.engine_core.get_output() # 2) Process EngineCoreOutputs. + iteration_stats = IterationStats() if self.log_stats else None processed_outputs = self.output_processor.process_outputs( - outputs.outputs) + outputs.outputs, + engine_core_timestamp=outputs.timestamp, + iteration_stats=iteration_stats) # 3) Abort any reqs that finished due to stop strings. 
self.engine_core.abort_requests(processed_outputs.reqs_to_abort) + # 4) Record stats + if self.stat_logger is not None: + assert outputs.scheduler_stats is not None + self.stat_logger.record(scheduler_stats=outputs.scheduler_stats, + iteration_stats=iteration_stats) + return processed_outputs.request_outputs def get_vllm_config(self): @@ -240,6 +260,11 @@ def start_profile(self): def stop_profile(self): self.engine_core.profile(False) + def reset_mm_cache(self): + self.processor.mm_registry.reset_processor_cache() + self.processor.mm_input_cache_client.reset() + self.engine_core.reset_mm_cache() + def reset_prefix_cache(self, device: Optional[Device] = None): self.engine_core.reset_prefix_cache() @@ -252,6 +277,10 @@ def wake_up(self, tags: Optional[list[str]] = None): def is_sleeping(self) -> bool: return self.engine_core.is_sleeping() + def get_metrics(self) -> list[Metric]: + assert self.log_stats, "Stat logging disabled" + return get_metrics_snapshot() + def get_tokenizer_group(self) -> TokenizerGroup: if self.tokenizer is None: raise ValueError("Unable to get tokenizer because " diff --git a/vllm/v1/engine/mm_input_cache.py b/vllm/v1/engine/mm_input_cache.py index 64ece840fc4c..fcb90bebdb62 100644 --- a/vllm/v1/engine/mm_input_cache.py +++ b/vllm/v1/engine/mm_input_cache.py @@ -83,3 +83,8 @@ def get_and_update_p1( full_mm_inputs.append(mm_input) return full_mm_inputs + + def reset(self) -> bool: + self.mm_cache.clear() + + return True diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 5f5ffe6e09db..293c291b4341 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -3,7 +3,7 @@ import asyncio from collections.abc import Iterable from dataclasses import dataclass -from typing import Optional, Union +from typing import Any, Optional, Union from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import RequestOutputKind @@ -146,6 +146,8 @@ def 
make_request_output( new_token_ids: list[int], finish_reason: Optional[FinishReason], stop_reason: Union[int, str, None], + kv_transfer_params: Optional[dict[str, Any]] = None, + num_cached_tokens: int = 0, ) -> Optional[RequestOutput]: finished = finish_reason is not None @@ -167,13 +169,16 @@ def make_request_output( if not outputs: return None - return self._new_request_output(request_id, outputs, finished) + return self._new_request_output(request_id, outputs, finished, + kv_transfer_params, num_cached_tokens) def _new_request_output( self, request_id: str, outputs: list[CompletionOutput], finished: bool, + kv_transfer_params: Optional[dict[str, Any]] = None, + num_cached_tokens: int = 0, ) -> RequestOutput: if self.output_kind == RequestOutputKind.DELTA: @@ -189,6 +194,8 @@ def _new_request_output( prompt_logprobs=prompt_logprobs, outputs=outputs, finished=finished, + kv_transfer_params=kv_transfer_params, + num_cached_tokens=num_cached_tokens, ) def _new_completion_output( @@ -335,7 +342,8 @@ def process_outputs( new_token_ids = engine_core_output.new_token_ids finish_reason = engine_core_output.finish_reason stop_reason = engine_core_output.stop_reason - + kv_transfer_params = engine_core_output.kv_transfer_params + num_cached_tokens = engine_core_output.num_cached_tokens req_state.is_prefilling = False # 2) Detokenize the token ids into text and perform stop checks. @@ -350,7 +358,8 @@ def process_outputs( # 4) Create and handle RequestOutput objects. if request_output := req_state.make_request_output( - new_token_ids, finish_reason, stop_reason): + new_token_ids, finish_reason, stop_reason, + kv_transfer_params, num_cached_tokens): if req_state.queue is not None: # AsyncLLM: put into queue for handling by generate(). 
req_state.queue.put(request_output) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 27d70a781471..64a756148780 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -54,6 +54,10 @@ def __init__( self.use_hash = self.mm_input_cache_client.use_cache or \ self.cache_config.enable_prefix_caching + @property + def mm_registry(self): + return self.input_preprocessor.mm_registry + def _validate_logprobs( self, params: SamplingParams, @@ -74,6 +78,7 @@ def _validate_logprobs( def _validate_sampling_params( self, params: SamplingParams, + lora_request: Optional[LoRARequest], ) -> None: self._validate_structured_output(params) self._validate_logit_bias(params) @@ -82,7 +87,8 @@ def _validate_sampling_params( return if not params.allowed_token_ids: raise ValueError("allowed_token_ids is not None and empty!") - vocab_size = self.model_config.get_vocab_size() + tokenizer = self.tokenizer.get_lora_tokenizer(lora_request) + vocab_size = len(tokenizer) if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids): raise ValueError( "allowed_token_ids contains out-of-vocab token id!") @@ -122,6 +128,7 @@ def _validate_supported_sampling_params( def _validate_params( self, params: Union[SamplingParams, PoolingParams], + lora_request: Optional[LoRARequest], ): """ Validate supported SamplingParam. 
@@ -132,7 +139,7 @@ def _validate_params( raise ValueError("V1 does not yet support Pooling models.") self._validate_logprobs(params) - self._validate_sampling_params(params) + self._validate_sampling_params(params, lora_request) self._validate_supported_sampling_params(params) def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None: @@ -185,8 +192,10 @@ def _validate_structured_output(self, params: SamplingParams) -> None: validate_xgrammar_grammar(params) params.guided_decoding.backend = "xgrammar" except ValueError: - # The request includes some jsonschema feature(s) that + # The request either failed validation + # or includes some jsonschema feature(s) that # are not supported in xgrammar. Fall back to guidance. + validate_guidance_grammar(params, tokenizer=None) params.guided_decoding.backend = "guidance" # Remember that this backend was set automatically params.guided_decoding.backend_was_auto = True @@ -207,7 +216,7 @@ def process_inputs( # TODO(woosuk): Support pooling models. # TODO(woosuk): Support encoder-decoder models. 
self._validate_lora(lora_request) - self._validate_params(params) + self._validate_params(params, lora_request) if priority != 0: raise ValueError("V1 does not support priority yet.") if trace_headers is not None: diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 74b226b45424..eb5f9d4bfe00 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -38,7 +38,7 @@ POLLING_TIMEOUT_MS = 5000 POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000 -EXECUTE_MODEL_TIMEOUT_S = 40 +EXECUTE_MODEL_TIMEOUT_S = 300 class MultiprocExecutor(Executor): @@ -50,6 +50,7 @@ def _init_executor(self) -> None: self.is_failed = False self.shutdown_event = threading.Event() self.failure_callback: Optional[FailureCallback] = None + self.io_thread_pool: Optional[ThreadPoolExecutor] = None self.world_size = self.parallel_config.world_size tensor_parallel_size = self.parallel_config.tensor_parallel_size @@ -107,7 +108,6 @@ def _init_executor(self) -> None: # For pipeline parallel, we use a thread pool for asynchronous # execute_model. - self.io_thread_pool: Optional[ThreadPoolExecutor] = None if self.max_concurrent_batches > 1: # Note: must use only 1 IO thread to keep dequeue sequence # from the response queue diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 4fc0844cd1f4..2747fc7fabd1 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -1,8 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 +import copy from dataclasses import dataclass +from typing import Optional import torch +from typing_extensions import Self from vllm.config import VllmConfig from vllm.logger import init_logger @@ -53,6 +56,16 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: """ raise NotImplementedError + @classmethod + def merge(cls, specs: list[Self]) -> Self: + """ + Merge a list of KVCacheSpec objects into a single KVCacheSpec object. 
+ """ + assert all(spec.type_id == specs[0].type_id for spec in specs[1:]), ( + "All layers in the same KV cache group must share the same " + "type_id.") + return copy.deepcopy(specs[0]) + @dataclass class AttentionSpec(KVCacheSpec): @@ -71,6 +84,16 @@ def page_size_bytes(self) -> int: @dataclass class FullAttentionSpec(AttentionSpec): + sliding_window: Optional[int] = None + """ + When hybrid allocator is disabled and the model contains both full + attention layers and sliding window attention layers, sliding + window attention are regarded as full attention in KV cache manager + (blocks are allocated for all tokens), while computed as sliding window + attention in model runner. + In this case, we use FullAttentionSpec and record the sliding window size. + Default to None for not using sliding window attention. + """ @property def type_id(self) -> str: @@ -80,6 +103,25 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: max_model_len = vllm_config.model_config.max_model_len return cdiv(max_model_len, self.block_size) * self.page_size_bytes + @classmethod + def merge(cls, specs: list[Self]) -> Self: + """ + Merge a list of FullAttentionSpec objects into a single + FullAttentionSpec object. 
+ """ + merged_spec = super().merge(specs) + sliding_window = set(spec.sliding_window for spec in specs + if spec.sliding_window is not None) + if len(sliding_window) == 0: + merged_spec.sliding_window = None + elif len(sliding_window) == 1: + merged_spec.sliding_window = sliding_window.pop() + else: + raise ValueError( + "All sliding window layers in the same KV cache group " + "must have the same window size.") + return merged_spec + @dataclass class SlidingWindowSpec(AttentionSpec): diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 9109bdcf42f2..3dc2f77444f6 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -128,9 +128,7 @@ def log(self): scheduler_stats.gpu_cache_usage * 100, self.prefix_caching_metrics.hit_rate * 100, ) - - if scheduler_stats.spec_decoding_stats is not None: - self.spec_decoding_logging.log(log_fn=log_fn) + self.spec_decoding_logging.log(log_fn=log_fn) def log_engine_initialized(self): logger.info( @@ -140,6 +138,10 @@ def log_engine_initialized(self): class PrometheusStatLogger(StatLoggerBase): + _gauge_cls = prometheus_client.Gauge + _counter_cls = prometheus_client.Counter + _histogram_cls = prometheus_client.Histogram + _spec_decoding_cls = SpecDecodingProm def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): self._unregister_vllm_metrics() @@ -158,18 +160,18 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): max_model_len = vllm_config.model_config.max_model_len - self.spec_decoding_prom = SpecDecodingProm( + self.spec_decoding_prom = self._spec_decoding_cls( vllm_config.speculative_config, labelnames, labelvalues) # # Scheduler state # - self.gauge_scheduler_running = prometheus_client.Gauge( + self.gauge_scheduler_running = self._gauge_cls( name="vllm:num_requests_running", documentation="Number of requests in model execution batches.", labelnames=labelnames).labels(*labelvalues) - self.gauge_scheduler_waiting = prometheus_client.Gauge( + 
self.gauge_scheduler_waiting = self._gauge_cls( name="vllm:num_requests_waiting", documentation="Number of requests waiting to be processed.", labelnames=labelnames).labels(*labelvalues) @@ -177,45 +179,45 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): # # GPU cache # - self.gauge_gpu_cache_usage = prometheus_client.Gauge( + self.gauge_gpu_cache_usage = self._gauge_cls( name="vllm:gpu_cache_usage_perc", documentation="GPU KV-cache usage. 1 means 100 percent usage.", labelnames=labelnames).labels(*labelvalues) - self.counter_gpu_prefix_cache_queries = prometheus_client.Counter( + self.counter_gpu_prefix_cache_queries = self._counter_cls( name="vllm:gpu_prefix_cache_queries", documentation= - "GPU prefix cache queries, in terms of number of queried blocks.", + "GPU prefix cache queries, in terms of number of queried tokens.", labelnames=labelnames).labels(*labelvalues) - self.counter_gpu_prefix_cache_hits = prometheus_client.Counter( + self.counter_gpu_prefix_cache_hits = self._counter_cls( name="vllm:gpu_prefix_cache_hits", documentation= - "GPU prefix cache hits, in terms of number of cached blocks.", + "GPU prefix cache hits, in terms of number of cached tokens.", labelnames=labelnames).labels(*labelvalues) # # Counters # - self.counter_num_preempted_reqs = prometheus_client.Counter( - name="vllm:num_preemptions_total", + self.counter_num_preempted_reqs = self._counter_cls( + name="vllm:num_preemptions", documentation="Cumulative number of preemption from the engine.", labelnames=labelnames).labels(*labelvalues) - self.counter_prompt_tokens = prometheus_client.Counter( - name="vllm:prompt_tokens_total", + self.counter_prompt_tokens = self._counter_cls( + name="vllm:prompt_tokens", documentation="Number of prefill tokens processed.", labelnames=labelnames).labels(*labelvalues) - self.counter_generation_tokens = prometheus_client.Counter( - name="vllm:generation_tokens_total", + self.counter_generation_tokens = self._counter_cls( + 
name="vllm:generation_tokens", documentation="Number of generation tokens processed.", labelnames=labelnames).labels(*labelvalues) self.counter_request_success: dict[FinishReason, prometheus_client.Counter] = {} - counter_request_success_base = prometheus_client.Counter( - name="vllm:request_success_total", + counter_request_success_base = self._counter_cls( + name="vllm:request_success", documentation="Count of successfully processed requests.", labelnames=labelnames + ["finished_reason"]) for reason in FinishReason: @@ -227,21 +229,21 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): # Histograms of counts # self.histogram_num_prompt_tokens_request = \ - prometheus_client.Histogram( + self._histogram_cls( name="vllm:request_prompt_tokens", documentation="Number of prefill tokens processed.", buckets=build_1_2_5_buckets(max_model_len), labelnames=labelnames).labels(*labelvalues) self.histogram_num_generation_tokens_request = \ - prometheus_client.Histogram( + self._histogram_cls( name="vllm:request_generation_tokens", documentation="Number of generation tokens processed.", buckets=build_1_2_5_buckets(max_model_len), labelnames=labelnames).labels(*labelvalues) self.histogram_iteration_tokens = \ - prometheus_client.Histogram( + self._histogram_cls( name="vllm:iteration_tokens_total", documentation="Histogram of number of tokens per engine_step.", buckets=[ @@ -251,7 +253,7 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): labelnames=labelnames).labels(*labelvalues) self.histogram_max_num_generation_tokens_request = \ - prometheus_client.Histogram( + self._histogram_cls( name="vllm:request_max_num_generation_tokens", documentation= "Histogram of maximum number of requested generation tokens.", @@ -259,14 +261,14 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): labelnames=labelnames).labels(*labelvalues) self.histogram_n_request = \ - prometheus_client.Histogram( + self._histogram_cls( 
name="vllm:request_params_n", documentation="Histogram of the n request parameter.", buckets=[1, 2, 5, 10, 20], labelnames=labelnames).labels(*labelvalues) self.histogram_max_tokens_request = \ - prometheus_client.Histogram( + self._histogram_cls( name="vllm:request_params_max_tokens", documentation="Histogram of the max_tokens request parameter.", buckets=build_1_2_5_buckets(max_model_len), @@ -276,7 +278,7 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): # Histogram of timing intervals # self.histogram_time_to_first_token = \ - prometheus_client.Histogram( + self._histogram_cls( name="vllm:time_to_first_token_seconds", documentation="Histogram of time to first token in seconds.", buckets=[ @@ -287,7 +289,7 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): labelnames=labelnames).labels(*labelvalues) self.histogram_time_per_output_token = \ - prometheus_client.Histogram( + self._histogram_cls( name="vllm:time_per_output_token_seconds", documentation="Histogram of time per output token in seconds.", buckets=[ @@ -301,34 +303,34 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): 40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0 ] self.histogram_e2e_time_request = \ - prometheus_client.Histogram( + self._histogram_cls( name="vllm:e2e_request_latency_seconds", documentation="Histogram of e2e request latency in seconds.", buckets=request_latency_buckets, labelnames=labelnames).labels(*labelvalues) self.histogram_queue_time_request = \ - prometheus_client.Histogram( + self._histogram_cls( name="vllm:request_queue_time_seconds", documentation= "Histogram of time spent in WAITING phase for request.", buckets=request_latency_buckets, labelnames=labelnames).labels(*labelvalues) self.histogram_inference_time_request = \ - prometheus_client.Histogram( + self._histogram_cls( name="vllm:request_inference_time_seconds", documentation= "Histogram of time spent in RUNNING phase for request.", 
buckets=request_latency_buckets, labelnames=labelnames).labels(*labelvalues) self.histogram_prefill_time_request = \ - prometheus_client.Histogram( + self._histogram_cls( name="vllm:request_prefill_time_seconds", documentation= "Histogram of time spent in PREFILL phase for request.", buckets=request_latency_buckets, labelnames=labelnames).labels(*labelvalues) self.histogram_decode_time_request = \ - prometheus_client.Histogram( + self._histogram_cls( name="vllm:request_decode_time_seconds", documentation= "Histogram of time spent in DECODE phase for request.", @@ -345,7 +347,7 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): self.labelname_running_lora_adapters = "running_lora_adapters" self.max_lora = vllm_config.lora_config.max_loras self.gauge_lora_info = \ - prometheus_client.Gauge( + self._gauge_cls( name="vllm:lora_requests_info", documentation="Running stats on lora requests.", labelnames=[ @@ -367,7 +369,7 @@ def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo): # Info type metrics are syntactic sugar for a gauge permanently set to 1 # Since prometheus multiprocessing mode does not support Info, emulate # info here with a gauge. 
- info_gauge = prometheus_client.Gauge( + info_gauge = self._gauge_cls( name=name, documentation=documentation, labelnames=metrics_info.keys()).labels(**metrics_info) diff --git a/vllm/v1/metrics/ray_wrappers.py b/vllm/v1/metrics/ray_wrappers.py new file mode 100644 index 000000000000..a51c3ed7f572 --- /dev/null +++ b/vllm/v1/metrics/ray_wrappers.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +import time +from typing import Optional, Union + +from vllm.config import VllmConfig +from vllm.v1.metrics.loggers import PrometheusStatLogger +from vllm.v1.spec_decode.metrics import SpecDecodingProm + +try: + from ray.util import metrics as ray_metrics + from ray.util.metrics import Metric +except ImportError: + ray_metrics = None + + +class RayPrometheusMetric: + + def __init__(self): + if ray_metrics is None: + raise ImportError( + "RayPrometheusMetric requires Ray to be installed.") + + self.metric: Metric = None + + def labels(self, *labels, **labelskwargs): + if labelskwargs: + for k, v in labelskwargs.items(): + if not isinstance(v, str): + labelskwargs[k] = str(v) + + self.metric.set_default_tags(labelskwargs) + + return self + + +class RayGaugeWrapper(RayPrometheusMetric): + """Wraps around ray.util.metrics.Gauge to provide same API as + prometheus_client.Gauge""" + + def __init__(self, + name: str, + documentation: Optional[str] = "", + labelnames: Optional[list[str]] = None): + labelnames_tuple = tuple(labelnames) if labelnames else None + self.metric = ray_metrics.Gauge(name=name, + description=documentation, + tag_keys=labelnames_tuple) + + def set(self, value: Union[int, float]): + return self.metric.set(value) + + def set_to_current_time(self): + # ray metrics doesn't have set_to_current time, https://docs.ray.io/en/latest/_modules/ray/util/metrics.html + return self.metric.set(time.time()) + + +class RayCounterWrapper(RayPrometheusMetric): + """Wraps around ray.util.metrics.Counter to provide same API as + prometheus_client.Counter""" + + def 
__init__(self, + name: str, + documentation: Optional[str] = "", + labelnames: Optional[list[str]] = None): + labelnames_tuple = tuple(labelnames) if labelnames else None + self.metric = ray_metrics.Counter(name=name, + description=documentation, + tag_keys=labelnames_tuple) + + def inc(self, value: Union[int, float] = 1.0): + if value == 0: + return + return self.metric.inc(value) + + +class RayHistogramWrapper(RayPrometheusMetric): + """Wraps around ray.util.metrics.Histogram to provide same API as + prometheus_client.Histogram""" + + def __init__(self, + name: str, + documentation: Optional[str] = "", + labelnames: Optional[list[str]] = None, + buckets: Optional[list[float]] = None): + labelnames_tuple = tuple(labelnames) if labelnames else None + boundaries = buckets if buckets else [] + self.metric = ray_metrics.Histogram(name=name, + description=documentation, + tag_keys=labelnames_tuple, + boundaries=boundaries) + + def observe(self, value: Union[int, float]): + return self.metric.observe(value) + + +class RaySpecDecodingProm(SpecDecodingProm): + """ + RaySpecDecodingProm is used by RayMetrics to log to Ray metrics. + Provides the same metrics as SpecDecodingProm but uses Ray's + util.metrics library. 
+ """ + + _counter_cls = RayCounterWrapper + + +class RayPrometheusStatLogger(PrometheusStatLogger): + """RayPrometheusStatLogger uses Ray metrics instead.""" + + _gauge_cls = RayGaugeWrapper + _counter_cls = RayCounterWrapper + _histogram_cls = RayHistogramWrapper + _spec_decoding_cls = RaySpecDecodingProm + + def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): + super().__init__(vllm_config, engine_index) + + @staticmethod + def _unregister_vllm_metrics(): + # No-op on purpose + pass diff --git a/vllm/v1/metrics/reader.py b/vllm/v1/metrics/reader.py new file mode 100644 index 000000000000..5ab78129a009 --- /dev/null +++ b/vllm/v1/metrics/reader.py @@ -0,0 +1,245 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import Optional + +from prometheus_client import REGISTRY +from prometheus_client import Metric as PromMetric +from prometheus_client.samples import Sample + + +@dataclass +class Metric: + """A base class for prometheus metrics. + + Each metric may be associated with key=value labels, and + in some cases a single vLLM instance may have multiple + metrics with the same name but different sets of labels. + """ + name: str + labels: dict[str, str] + + +@dataclass +class Counter(Metric): + """A monotonically increasing integer counter.""" + value: int + + +@dataclass +class Vector(Metric): + """An ordered array of integer counters. + + This type - which doesn't exist in Prometheus - models one very + specific metric, vllm:spec_decode_num_accepted_tokens_per_pos. + """ + values: list[int] + + +@dataclass +class Gauge(Metric): + """A numerical value that can go up or down.""" + value: float + + +@dataclass +class Histogram(Metric): + """Observations recorded in configurable buckets. + + Buckets are represented by a dictionary. The key is + the upper limit of the bucket, and the value is the + observed count in that bucket. A '+Inf' key always + exists. 
+ + The count property is the total count across all + buckets, identical to the count of the '+Inf' bucket. + + The sum property is the total sum of all observed + values. + """ + count: int + sum: float + buckets: dict[str, int] + + +def get_metrics_snapshot() -> list[Metric]: + """An API for accessing in-memory Prometheus metrics. + + Example: + >>> for metric in llm.get_metrics(): + ... if isinstance(metric, Counter): + ... print(f"{metric} = {metric.value}") + ... elif isinstance(metric, Gauge): + ... print(f"{metric} = {metric.value}") + ... elif isinstance(metric, Histogram): + ... print(f"{metric}") + ... print(f" sum = {metric.sum}") + ... print(f" count = {metric.count}") + ... for bucket_le, value in metrics.buckets.items(): + ... print(f" {bucket_le} = {value}") + """ + collected: list[Metric] = [] + for metric in REGISTRY.collect(): + if not metric.name.startswith("vllm:"): + continue + if metric.type == "gauge": + samples = _get_samples(metric) + for s in samples: + collected.append( + Gauge(name=metric.name, labels=s.labels, value=s.value)) + elif metric.type == "counter": + samples = _get_samples(metric, "_total") + if metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos": + # + # Ugly vllm:num_accepted_tokens_per_pos special case. + # + # This metric is a vector of counters - for each spec + # decoding token position, we observe the number of + # accepted tokens using a Counter labeled with 'position'. + # We convert these into a vector of integer values. + # + for labels, values in _digest_num_accepted_by_pos_samples( + samples): + collected.append( + Vector(name=metric.name, labels=labels, values=values)) + else: + for s in samples: + collected.append( + Counter(name=metric.name, + labels=s.labels, + value=int(s.value))) + + elif metric.type == "histogram": + # + # A histogram has a number of '_bucket' samples where + # the 'le' label represents the upper limit of the bucket. 
+ # We convert these bucketized values into a dict of values + # indexed by the value of the 'le' label. The 'le=+Inf' + # label is a special case, catching all values observed. + # + bucket_samples = _get_samples(metric, "_bucket") + count_samples = _get_samples(metric, "_count") + sum_samples = _get_samples(metric, "_sum") + for labels, buckets, count_value, sum_value in _digest_histogram( + bucket_samples, count_samples, sum_samples): + collected.append( + Histogram(name=metric.name, + labels=labels, + buckets=buckets, + count=count_value, + sum=sum_value)) + else: + raise AssertionError(f"Unknown metric type {metric.type}") + + return collected + + +def _get_samples(metric: PromMetric, + suffix: Optional[str] = None) -> list[Sample]: + name = (metric.name + suffix) if suffix is not None else metric.name + return [s for s in metric.samples if s.name == name] + + +def _strip_label(labels: dict[str, str], key_to_remove: str) -> dict[str, str]: + labels_copy = labels.copy() + labels_copy.pop(key_to_remove) + return labels_copy + + +def _digest_histogram( + bucket_samples: list[Sample], count_samples: list[Sample], + sum_samples: list[Sample] +) -> list[tuple[dict[str, str], dict[str, int], int, float]]: + # + # In the case of DP, we have an indigestable + # per-bucket-per-engine count as a list of labelled + # samples, along with total and sum samples + # + # bucket_samples (in): + # labels = {bucket: 100, idx: 0}, value = 2 + # labels = {bucket: 200, idx: 0}, value = 4 + # labels = {bucket: Inf, idx: 0}, value = 10 + # labels = {bucket: 100, idx: 1}, value = 1 + # labels = {bucket: 200, idx: 2}, value = 5 + # labels = {bucket: Inf, idx: 3}, value = 7 + # count_samples (in): + # labels = {idx: 0}, value = 10 + # labels = {idx: 1}, value = 7 + # sum_samples (in): + # labels = {idx: 0}, value = 2000 + # labels = {idx: 1}, value = 1200 + # + # output: [ + # {idx: 0}, {"100": 2, "200": 4, "Inf": 10}, 10, 2000 + # {idx: 1}, {"100": 1, "200": 5, "Inf": 7}, 7, 1200 + # ] 
+ buckets_by_labels: dict[frozenset[tuple[str, str]], dict[str, int]] = {} + for s in bucket_samples: + bucket = s.labels["le"] + labels_key = frozenset(_strip_label(s.labels, "le").items()) + if labels_key not in buckets_by_labels: + buckets_by_labels[labels_key] = {} + buckets_by_labels[labels_key][bucket] = int(s.value) + + counts_by_labels: dict[frozenset[tuple[str, str]], int] = {} + for s in count_samples: + labels_key = frozenset(s.labels.items()) + counts_by_labels[labels_key] = int(s.value) + + sums_by_labels: dict[frozenset[tuple[str, str]], float] = {} + for s in sum_samples: + labels_key = frozenset(s.labels.items()) + sums_by_labels[labels_key] = s.value + + assert set(buckets_by_labels.keys()) == set( + counts_by_labels.keys()) == set(sums_by_labels.keys()) + + output = [] + label_keys = list(buckets_by_labels.keys()) + for k in label_keys: + labels = dict(k) + output.append((labels, buckets_by_labels[k], counts_by_labels[k], + sums_by_labels[k])) + return output + + +def _digest_num_accepted_by_pos_samples( + samples: list[Sample]) -> list[tuple[dict[str, str], list[int]]]: + # + # In the case of DP, we have an indigestable + # per-position-per-engine count as a list of + # labelled samples + # + # samples (in): + # labels = {pos: 0, idx: 0}, value = 10 + # labels = {pos: 1, idx: 0}, value = 7 + # labels = {pos: 2, idx: 0}, value = 2 + # labels = {pos: 0, idx: 1}, value = 5 + # labels = {pos: 1, idx: 1}, value = 3 + # labels = {pos: 2, idx: 1}, value = 1 + # + # output: [ + # {idx: 0}, [10, 7, 2] + # {idx: 1}, [5, 3, 1] + # ] + # + max_pos = 0 + values_by_labels: dict[frozenset[tuple[str, str]], dict[int, int]] = {} + + for s in samples: + position = int(s.labels["position"]) + max_pos = max(max_pos, position) + + labels_key = frozenset(_strip_label(s.labels, "position").items()) + if labels_key not in values_by_labels: + values_by_labels[labels_key] = {} + values_by_labels[labels_key][position] = int(s.value) + + output = [] + for labels_key, 
values_by_position in values_by_labels.items(): + labels = dict(labels_key) + values = [0] * (max_pos + 1) + for pos, val in values_by_position.items(): + values[pos] = val + output.append((labels, values)) + return output diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index fd949264885b..8fe1630616a4 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -19,7 +19,7 @@ class PrefixCacheStats: # The number of requests in this update. requests: int = 0 # The number of queries in these requests. Note that "queries" here - # means the number of blocks that were queried from the cache. + # means the number of tokens that were queried from the cache. queries: int = 0 # The number of hits in these requests. hits: int = 0 diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 2732b933c28a..e8ce0df5ed8d 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -100,12 +100,16 @@ class ModelRunnerOutput: # [prompt_len] prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] - -EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput( - req_ids=[], - req_id_to_index={}, - sampled_token_ids=[], - spec_token_ids=None, - logprobs=None, - prompt_logprobs_dict={}, -) + # [req_ids] + finished_sending: Optional[set[str]] = None + finished_recving: Optional[set[str]] = None + + +EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[], + req_id_to_index={}, + sampled_token_ids=[], + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + finished_sending=None, + finished_recving=None) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index fde366d61c7d..b4c84507532a 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import enum -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams @@ -61,17 +61,26 @@ def 
__init__( self.num_encoder_inputs = len(self.mm_inputs) self.has_encoder_inputs = self.num_encoder_inputs > 0 + # P/D: Connector-specific KV transfer parameters. + kv_params = (None if sampling_params.extra_args is None else + sampling_params.extra_args.get("kv_transfer_params")) + self.kv_transfer_params: Optional[dict[str, Any]] = kv_params + # Sanity check assert len(self.mm_inputs) == len(self.mm_positions) if self.mm_hashes: assert len(self.mm_inputs) == len(self.mm_hashes) # Read-only views - # Prevent directly appending to the these lists since + # Prevent directly appending to these lists since # they should also be updated simultaneously. self.output_token_ids = ConstantList(self._output_token_ids) self.all_token_ids = ConstantList(self._all_token_ids) + # State + # The number of tokens with prefix cache hits. + self.num_cached_tokens = -1 + @classmethod def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": if request.mm_inputs is not None: @@ -150,6 +159,7 @@ class RequestStatus(enum.IntEnum): """Status of a request.""" WAITING = enum.auto() WAITING_FOR_FSM = enum.auto() + WAITING_FOR_REMOTE_KVS = enum.auto() RUNNING = enum.auto() PREEMPTED = enum.auto() # Note: anything after PREEMPTED will be considered diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 745b81ded3f1..4a5fbb10d408 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -31,21 +31,10 @@ def __init__(self): if current_platform.is_cuda(): if is_flashinfer_available: flashinfer_version = flashinfer.__version__ - if flashinfer_version >= "0.2.3": - # FIXME(DefTruth): Currently, we have errors when using - # FlashInfer>=v0.2.3 for top-p & top-k sampling. As a - # workaround, we disable FlashInfer for top-p & top-k - # sampling by default while FlashInfer>=v0.2.3. 
- # The sampling API removes the success return value - # of all sampling API, which is not compatible with - # earlier design. - # https://github.com/flashinfer-ai/flashinfer/releases/ - # tag/v0.2.3 - logger.info( - "Currently, FlashInfer top-p & top-k sampling sampler " - "is disabled because FlashInfer>=v0.2.3 is not " - "backward compatible. Falling back to the PyTorch-" - "native implementation of top-p & top-k sampling.") + if flashinfer_version < "0.2.3": + logger.warning( + "FlashInfer version >= 0.2.3 required. " + "Falling back to default sampling implementation.") self.forward = self.forward_native elif envs.VLLM_USE_FLASHINFER_SAMPLER is not False: # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for @@ -100,13 +89,18 @@ def forward_cuda( p: Optional[torch.Tensor], ) -> torch.Tensor: """More optimized implementation for top-k and top-p sampling.""" - probs = logits.softmax(dim=-1, dtype=torch.float32) if k is None and p is None: # We prefer `random_sample` over `flashinfer_sample` when sorting is # not needed. This is because `random_sample` does not require # CPU-GPU synchronization while `flashinfer_sample` does. + probs = logits.softmax(dim=-1, dtype=torch.float32) return random_sample(probs, generators) - return flashinfer_sample(probs, k, p, generators) + if generators: + logger.warning("FlashInfer 0.2.3+ does not support " + "per-request generators. Falling back to " + "PyTorch-native implementation.") + return self.forward_native(logits, generators, k, p) + return flashinfer_sample(logits, k, p, generators) def forward_tpu( self, @@ -260,17 +254,17 @@ def random_sample( def flashinfer_sample( - probs: torch.Tensor, + logits: torch.Tensor, k: Optional[torch.Tensor], p: Optional[torch.Tensor], generators: dict[int, torch.Generator], ) -> torch.Tensor: - """Sample from the probabilities using FlashInfer. + """Sample from the logits using FlashInfer. Statistically, this function is equivalent to the `random_sample` function. 
However, this function is faster because it avoids sorting the logits tensor via rejection sampling. - + NOTE: The outputs of this function do not necessarily match the outputs of the `random_sample` function. It only guarantees that the outputs are statistically equivalent. @@ -280,36 +274,19 @@ def flashinfer_sample( the synchronization overhead. """ assert not (k is None and p is None) - max_top_k_round = 32 - batch_size = probs.shape[0] - uniform_samples = torch.empty((max_top_k_round, batch_size), - device=probs.device) - if len(generators) != batch_size: - uniform_samples.uniform_() - if generators: - for i, generator in generators.items(): - uniform_samples[:, i].uniform_(generator=generator) - if k is None: # Top-p only. - next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs( - probs, uniform_samples, p, deterministic=True) + probs = logits.softmax(dim=-1, dtype=torch.float32) + next_token_ids = flashinfer.sampling.top_p_sampling_from_probs( + probs, p, deterministic=True) elif p is None: # Top-k only. - next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs( - probs, uniform_samples, k, deterministic=True) + probs = logits.softmax(dim=-1, dtype=torch.float32) + next_token_ids = flashinfer.sampling.top_k_sampling_from_probs( + probs, k, deterministic=True) else: # Both top-k and top-p. - next_token_ids, success = ( - flashinfer.sampling.top_k_top_p_sampling_from_probs( - probs, uniform_samples, k, p, deterministic=True)) - - # NOTE: CPU-GPU synchronization happens here. 
- if not success.all(): - if k is not None: - probs = flashinfer.sampling.top_k_renorm_prob(probs, k) - if p is not None: - probs = flashinfer.sampling.top_p_renorm_prob(probs, p) - next_token_ids = flashinfer.sampling.sampling_from_probs( - probs, uniform_samples[0], deterministic=True) + next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits( + logits, k, p, deterministic=True) + return next_token_ids.view(-1) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 13cfcc4bbb6e..971b06758c21 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -4,16 +4,17 @@ from vllm.attention.layer import Attention from vllm.config import (CompilationLevel, VllmConfig, - get_layers_from_vllm_config, set_current_vllm_config) + get_layers_from_vllm_config) +from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import set_forward_context from vllm.logger import init_logger -from vllm.model_executor.model_loader import get_model_loader -from vllm.model_executor.model_loader.utils import set_default_torch_dtype -from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.model_loader import get_model from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM -from vllm.triton_utils import tl, triton -from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata, + FlashAttentionMetadata) +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel logger = init_logger(__name__) @@ -26,12 +27,15 @@ def __init__( self, vllm_config: VllmConfig, device: torch.device, + runner=None, ): self.vllm_config = vllm_config self.speculative_config = vllm_config.speculative_config self.draft_model_config = self.speculative_config.draft_model_config self.method = 
self.speculative_config.method + self.runner = runner + self.dtype = vllm_config.model_config.dtype self.max_model_len = vllm_config.model_config.max_model_len self.block_size = vllm_config.cache_config.block_size @@ -107,24 +111,51 @@ def propose( # FA requires seq_len to have dtype int32. seq_lens = (target_positions[last_token_indices] + 1).int() - # FIXME(woosuk): The below two ops cause synchronization. Optimize. - max_seq_len = seq_lens.max().item() - max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item() - attn_metadata = FlashAttentionMetadata( - num_actual_tokens=num_tokens, - max_query_len=max_num_tokens, - query_start_loc=cu_num_tokens, - max_seq_len=max_seq_len, - seq_lens=seq_lens, - block_table=block_table, - slot_mapping=target_slot_mapping, - # TODO(woosuk): Support cascade attention. - use_cascade=False, - common_prefix_len=0, - cu_prefix_query_lens=None, - prefix_kv_lens=None, - suffix_kv_lens=None, - ) + if self.method in ["eagle", "eagle3"]: + # FIXME(woosuk): The below two ops cause synchronization. Optimize. + max_seq_len = seq_lens.max().item() + max_num_tokens = (cu_num_tokens[1:] - + cu_num_tokens[:-1]).max().item() + attn_metadata = FlashAttentionMetadata( + num_actual_tokens=num_tokens, + max_query_len=max_num_tokens, + query_start_loc=cu_num_tokens, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + block_table=block_table, + slot_mapping=target_slot_mapping, + # TODO(woosuk): Support cascade attention. 
+ use_cascade=False, + common_prefix_len=0, + cu_prefix_query_lens=None, + prefix_kv_lens=None, + suffix_kv_lens=None, + ) + elif self.method == "deepseek_mtp": + query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] + max_query_len = query_lens.max().item() + + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=cu_num_tokens, seq_lens=seq_lens) + + assert self.runner is not None + + # FIXME: need to consider multiple kv_cache_groups + attn_metadata = self.runner.attn_metadata_builder.build( + num_reqs=batch_size, + num_actual_tokens=num_tokens, + max_query_len=max_query_len, + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) + else: + raise ValueError(f"Unsupported method: {self.method}") + + # At this moment, we assume all eagle layers belong to the same KV + # cache group, thus using the same attention metadata. + per_layer_attn_metadata = {} + for layer_name in self.attn_layer_names: + per_layer_attn_metadata[layer_name] = attn_metadata if self.use_cuda_graph and \ num_tokens <= self.cudagraph_batch_sizes[-1]: num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) @@ -134,14 +165,18 @@ def propose( self.positions[:num_tokens] = target_positions self.hidden_states[:num_tokens] = target_hidden_states - with set_forward_context(attn_metadata, + with set_forward_context(per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens): - last_hidden_states, hidden_states = self.model( - input_ids=self.input_ids[:num_input_tokens], - positions=self.positions[:num_input_tokens], - hidden_states=self.hidden_states[:num_input_tokens], + ret_hidden_states = self.model( + self.input_ids[:num_input_tokens], + self.positions[:num_input_tokens], + self.hidden_states[:num_input_tokens], ) + if self.method == "deepseek_mtp": + last_hidden_states = ret_hidden_states + else: + last_hidden_states, hidden_states = ret_hidden_states sample_hidden_states = last_hidden_states[last_token_indices] logits = 
self.model.compute_logits(sample_hidden_states, None) draft_token_ids = logits.argmax(dim=-1) @@ -151,6 +186,10 @@ def propose( # [batch_size, 1] return draft_token_ids.view(-1, 1) + # TODO: Currently, MTP module released by deepseek only has + # one layer. Adapt this code to support multiple layers once + # there's a multi-layer MTP module. + # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] @@ -212,17 +251,19 @@ def propose( self.hidden_states[:batch_size] = hidden_states # Run the model. - with set_forward_context(attn_metadata, + with set_forward_context(per_layer_attn_metadata, self.vllm_config, num_tokens=input_batch_size): last_hidden_states, hidden_states = self.model( - input_ids=self.input_ids[:input_batch_size], - positions=self.positions[:input_batch_size], - hidden_states=self.hidden_states[:input_batch_size], + self.input_ids[:input_batch_size], + self.positions[:input_batch_size], + self.hidden_states[:input_batch_size], ) hidden_states = hidden_states[:batch_size] logits = self.model.compute_logits(last_hidden_states[:batch_size], None) + + # TODO(wenlong): get more than one token for tree attention draft_token_ids = logits.argmax(dim=-1) draft_token_ids_list.append(draft_token_ids) @@ -236,6 +277,7 @@ def prepare_inputs( cu_target_query_lens: torch.Tensor, # [batch_size] num_rejected_tokens: torch.Tensor, + num_tokens: int, ) -> tuple[torch.Tensor, torch.Tensor]: # cu_target_query_lens: [0, a, a + b, a + b + c] # num_rejected_tokens: [n1, n2, n3] @@ -251,21 +293,18 @@ def prepare_inputs( # [a, b, c] -> [a - n1, b - n2, c - n3] num_tokens_per_req = query_len_per_req - num_rejected_tokens - cu_num_tokens = torch.empty_like(cu_target_query_lens) + # [a - n1, b - n2, c - n3] -> + # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] + cu_num_tokens = torch.zeros_like(cu_target_query_lens) torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:]) - cu_num_tokens[0] = 0 - - # FIXME(woosuk): Avoid synchronization. 
- num_tokens = cu_num_tokens[-1].item() token_indices = torch.empty( num_tokens, dtype=torch.int32, - device=cu_num_tokens.device, + device=cu_target_query_lens.device, ) - batch_size = num_rejected_tokens.shape[0] BLOCK_SIZE = 1024 - prepare_input_kernel[(batch_size, )]( + prepare_eagle_input_kernel[(batch_size, )]( token_indices, cu_target_query_lens, cu_num_tokens, @@ -274,40 +313,38 @@ def prepare_inputs( return cu_num_tokens, token_indices def load_model(self, target_model: nn.Module) -> None: - loader = get_model_loader(self.vllm_config.load_config) - target_layer_num = self.vllm_config.model_config.get_num_layers( - self.vllm_config.parallel_config) + draft_model_config = \ + self.vllm_config.speculative_config.draft_model_config target_attn_layer_names = set( get_layers_from_vllm_config(self.vllm_config, Attention).keys()) - draft_model_config = \ - self.vllm_config.speculative_config.draft_model_config - # FIXME(lily): This does not handle with distributed inference. - target_device = self.vllm_config.device_config.device - # We need to set the vllm_config here to register attention - # layers in the forward context. 
- with set_default_torch_dtype( - draft_model_config.dtype), set_current_vllm_config( - self.vllm_config): - draft_model_cls, arch = ModelRegistry.resolve_model_cls( - draft_model_config.architectures) - self.model = draft_model_cls( - vllm_config=self.vllm_config, - start_layer_id=target_layer_num).to(target_device) + self.model = get_model(vllm_config=self.vllm_config, + model_config=draft_model_config) draft_attn_layer_names = ( get_layers_from_vllm_config(self.vllm_config, Attention).keys() - target_attn_layer_names) - assert len(draft_attn_layer_names) == 1 - self.attn_layer_name = next(iter(draft_attn_layer_names)) - loaded_weights = self.model.load_weights( - loader.get_all_weights(draft_model_config, self.model)) - if self.vllm_config.speculative_config.method == "eagle3": - if "model.embed_tokens.weight" not in loaded_weights: - logger.info( - "Loading EAGLE embedding weights from the target model.") - self.model.model.embed_tokens = target_model.model.embed_tokens + + self.attn_layer_names = list(draft_attn_layer_names) + + # share embed_tokens with the target model if needed + if get_pp_group().world_size == 1: + logger.info( + "The EAGLE head shares the same vocab embedding" \ + " with the target model." + ) + self.model.model.embed_tokens = target_model.model.embed_tokens else: + logger.info( + "Since PP > 1, the EAGLE head loaded its own vocab embedding" \ + " weights instead of sharing them with the target model." 
+ ) + + # share lm_head with the target model if needed + # some model definition do not define lm_head explicitly + # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM + if self.vllm_config.speculative_config.method != "eagle3" and \ + hasattr(target_model, "lm_head"): logger.info("Loading EAGLE LM head weights from the target model.") self.model.lm_head = target_model.lm_head @@ -319,11 +356,30 @@ def dummy_run( with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): self.model( - input_ids=self.input_ids[:num_tokens], - positions=self.positions[:num_tokens], - hidden_states=self.hidden_states[:num_tokens], + self.input_ids[:num_tokens], + self.positions[:num_tokens], + self.hidden_states[:num_tokens], ) + def validate_same_kv_cache_group(self, + kv_cache_config: KVCacheConfig) -> None: + """ + Validate that all eagle layers belong to the same KVCacheGroup. + Need this assumption to ensure all eagle layers can use the + same AttentionMetadata. + May extend to multiple AttentionMetadata in the future. + """ + kv_cache_groups: dict[str, int] = {} + for id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): + for layer_name in kv_cache_group.layer_names: + kv_cache_groups[layer_name] = id + assert len( + set([ + kv_cache_groups[layer_name] + for layer_name in self.attn_layer_names + ]) + ) == 1, "All eagle layers should belong to the same kv cache group" + # NOTE(woosuk): Currently, the below code is not used and we always use argmax # to sample the draft tokens. 
We will use this after we find a way to manage @@ -366,29 +422,3 @@ def compute_probs_and_sample_next_token( next_token_ids, ) return next_token_ids, probs - - -@triton.jit -def prepare_input_kernel( - out_ptr, - cu_query_lens_ptr, - cu_num_tokens_ptr, - BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(0) - - # [start_pos, end_pos) - start_pos = tl.load(cu_num_tokens_ptr + pid) - end_pos = tl.load(cu_num_tokens_ptr + pid + 1) - num_tokens = end_pos - start_pos - - index_start = tl.load(cu_query_lens_ptr + pid) - - num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE) - for i in tl.range(num_blocks): - offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - tl.store( - out_ptr + start_pos + offset, - index_start + offset, - mask=offset < num_tokens, - ) diff --git a/vllm/v1/spec_decode/medusa.py b/vllm/v1/spec_decode/medusa.py new file mode 100644 index 000000000000..fdac2ef64c3f --- /dev/null +++ b/vllm/v1/spec_decode/medusa.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.nn as nn + +from vllm.config import VllmConfig +from vllm.forward_context import set_forward_context +from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model +from vllm.v1.sample.metadata import SamplingMetadata + +# Initialize logger +logger = init_logger(__name__) + + +class MedusaProposer: + """ + Medusa proposer class for generating token sequences + """ + + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + ): + # Save config parameters + self.vllm_config = vllm_config + self.device = device + self.max_num_tokens = ( + vllm_config.scheduler_config.max_num_batched_tokens) + self.hidden_size = vllm_config.speculative_config.\ + draft_model_config.get_hidden_size( + ) + self.dtype = vllm_config.model_config.dtype + + def propose( + self, + target_hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + # Generate blocks and compute logits + blocks = 
self.model(target_hidden_states) + logits = self.model.compute_logits(blocks, None) + + # Get draft tokens and transpose the result + draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits] + return [list(row) for row in zip(*draft_tokens)] + + def load_model(self, target_model: nn.Module) -> None: + self.model = get_model(vllm_config=self.vllm_config, + model_config=self.vllm_config. + speculative_config.draft_model_config) + + @torch.inference_mode() + def dummy_run(self, num_tokens: int) -> None: + hidden_states = torch.zeros((self.max_num_tokens, self.hidden_size), + dtype=self.dtype, + device=self.device) + with set_forward_context(None, self.vllm_config, + num_tokens=num_tokens): + self.model(hidden_states) diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py index 33ce98284e20..36091bef2895 100644 --- a/vllm/v1/spec_decode/metrics.py +++ b/vllm/v1/spec_decode/metrics.py @@ -67,13 +67,17 @@ def observe(self, spec_decoding_stats: SpecDecodingStats): spec_decoding_stats.num_accepted_tokens_per_pos) def log(self, log_fn=logger.info): + if not self.num_drafts: + return num_drafts = np.sum(self.num_drafts) num_draft_tokens = np.sum(self.num_draft_tokens) num_accepted_tokens = np.sum(self.num_accepted_tokens) draft_acceptance_rate = (num_accepted_tokens / num_draft_tokens * 100 if num_draft_tokens > 0 else float("nan")) - mean_acceptance_length = (num_accepted_tokens / num_drafts) + + # Conventionally, mean acceptance length includes the bonus token + mean_acceptance_length = 1 + (num_accepted_tokens / num_drafts) pos_matrix = np.array(self.accepted_tokens_per_pos_lists) acceptance_rates = np.sum(pos_matrix, axis=0) / num_drafts @@ -103,10 +107,12 @@ class SpecDecodingProm: rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) / rate(vllm:spec_decode_num_draft_tokens_total[$interval]) - The mean acceptance length can be calculated using: + The mean acceptance length (conventionally including bonus tokens) + can be calculated 
using: + 1 + ( rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) / - rate(vllm:spec_decode_num_drafts[$interval]) + rate(vllm:spec_decode_num_drafts[$interval])) A per-position acceptance rate vector can be computed using @@ -114,25 +120,31 @@ class SpecDecodingProm: vllm:spec_decode_num_drafts[$interval] """ - def __init__(self, speculative_config: Optional[SpeculativeConfig], - labelnames: list[str], labelvalues: list[str]): + _counter_cls = prometheus_client.Counter + + def __init__( + self, + speculative_config: Optional[SpeculativeConfig], + labelnames: list[str], + labelvalues: list[str], + ): self.spec_decoding_enabled = speculative_config is not None if not self.spec_decoding_enabled: return self.counter_spec_decode_num_drafts = \ - prometheus_client.Counter( - name="vllm:spec_decode_num_drafts_total", + self._counter_cls( + name="vllm:spec_decode_num_drafts", documentation="Number of spec decoding drafts.", labelnames=labelnames).labels(*labelvalues) self.counter_spec_decode_num_draft_tokens = \ - prometheus_client.Counter( - name="vllm:spec_decode_num_draft_tokens_total", + self._counter_cls( + name="vllm:spec_decode_num_draft_tokens", documentation="Number of draft tokens.", - labelnames=labelnames).labels(*labelvalues) + labelnames=labelnames,).labels(*labelvalues) self.counter_spec_decode_num_accepted_tokens = \ - prometheus_client.Counter( - name="vllm:spec_decode_num_accepted_tokens_total", + self._counter_cls( + name="vllm:spec_decode_num_accepted_tokens", documentation="Number of accepted tokens.", labelnames=labelnames).labels(*labelvalues) @@ -140,12 +152,13 @@ def __init__(self, speculative_config: Optional[SpeculativeConfig], num_spec_tokens = (speculative_config.num_speculative_tokens if self.spec_decoding_enabled else 0) pos_labelnames = labelnames + ["position"] - base_counter = prometheus_client.Counter( + base_counter = self._counter_cls( name="vllm:spec_decode_num_accepted_tokens_per_pos", documentation="Accepted tokens per 
draft position.", - labelnames=pos_labelnames) - self.counter_spec_decode_num_accepted_tokens_per_pos: \ - list[prometheus_client.Counter] = [] + labelnames=pos_labelnames, + ) + self.counter_spec_decode_num_accepted_tokens_per_pos: list[ + prometheus_client.Counter] = [] for pos in range(num_spec_tokens): pos_labelvalues = labelvalues + [str(pos)] self.counter_spec_decode_num_accepted_tokens_per_pos.append( diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py index ce81a40ee3ae..334258e7f87a 100644 --- a/vllm/v1/spec_decode/utils.py +++ b/vllm/v1/spec_decode/utils.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +from vllm.triton_utils import tl, triton from vllm.v1.worker.gpu_input_batch import InputBatch @@ -16,3 +17,29 @@ def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool: return False return True + + +@triton.jit +def prepare_eagle_input_kernel( + out_ptr, + cu_query_lens_ptr, + cu_num_tokens_ptr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + + # [start_pos, end_pos) + start_pos = tl.load(cu_num_tokens_ptr + pid) + end_pos = tl.load(cu_num_tokens_ptr + pid + 1) + num_tokens = end_pos - start_pos + + index_start = tl.load(cu_query_lens_ptr + pid) + + num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE) + for i in tl.range(num_blocks): + offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + tl.store( + out_ptr + start_pos + offset, + index_start + offset, + mask=offset < num_tokens, + ) diff --git a/vllm/v1/stats/common.py b/vllm/v1/stats/common.py deleted file mode 100644 index 46818977dae5..000000000000 --- a/vllm/v1/stats/common.py +++ /dev/null @@ -1,453 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import time -from dataclasses import dataclass -from dataclasses import field as dataclass_field -from enum import IntEnum -from typing import ClassVar, Optional - -import msgspec -from msgspec import field as msgspec_field - -from vllm.sampling_params import SamplingParams - - -class RequestStatsUpdate( 
- msgspec.Struct, # type: ignore - array_like=True, - omit_defaults=True, - gc=False): - """ - An update to the request stats. - - This represents a stats update at a specific timestamp with metadata - associated with the update. - - NOTE: since there might be multiple processes generating updates at - different parts of the engine (e.g. input processor, scheduler, engine core, - etc.), we use the monotonic timestamp to record the update to compute any - intervals, and explicit wall-clock timestamp should be used for timestamps. - - WARNING: This assumes stats are generated in a single machine. If there are - potentially multiple machines, one should always generate the stats updates - on one single machine or use something else. - """ - - class Type(IntEnum): - """See `RequestStats` for the lifecycle of a request.""" - - # Request arrived at the engine frontend. - ARRIVED = 0 - # Input processed by the input processor. - INPUT_PROCESSED = 1 - # Queued on the engine core. - QUEUED = 2 - # Scheduled running prefill by the scheduler. - # A request could be running a new prefill on the prompt tokens or - # a resumed prefill on the original prefill tokens + generated output - # tokens before preemption. - PREFILLING = 3 - # Preempted by the scheduler. - PREEMPTED = 4 - # Output token is generated by the engine core. - DECODING = 5 - # Token detokenized by the detokenizer. - # We will record the timestamp for each output token, as well as the - # finish reason. - DETOKENIZED = 6 - # Request finishes (or aborts). 
- FINISHED = 7 - - """ - Valid state updates: - ARRIVED - โ”‚ - โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ–บ INPUT_PROCESSED โ”€โ”€โ”€โ”€โ”€โ”€โ–บ QUEUED โ”€โ”€โ”€โ”€โ”€โ”€โ–บ PREFILLING โ—„โ”€โ”€โ”€โ”€โ” - โ”‚ โ”‚ โ”‚ โ”‚ โ”‚ - โ”‚ โ”‚ โ”‚ โ–ผ โ”‚ - โ”‚ โ”‚ โ”‚ -โ”€โ”€โ–บ DECODING โ”‚ - โ”‚ โ”‚ โ”‚ | โ”‚ โ”‚ - โ”‚ โ”‚ โ”‚ | โ–ผ โ”‚ - โ”‚ โ”‚ โ”‚ โ””โ”€ DETOKENIZED โ”‚ - โ”‚ โ”‚ โ”‚ โ”‚ โ”‚ - โ”‚ โ”‚ โ”‚ โ–ผ โ”‚ - โ”‚ โ–ผ โ–ผ PREEMPTED โ—„โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ - โ”‚ โ”‚ โ”‚ โ”‚ - โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ด - โ”‚ - โ–ผ - FINISHED (All could go to FINISHED) - """ - _VALID_TRANSITIONS: ClassVar[dict[Type, set[Type]]] = { - Type.ARRIVED: { - Type.INPUT_PROCESSED, - Type.FINISHED, - }, - Type.INPUT_PROCESSED: { - Type.QUEUED, - Type.FINISHED, - }, - Type.QUEUED: { - Type.PREFILLING, - Type.FINISHED, - }, - Type.PREFILLING: { - Type.DECODING, - Type.PREEMPTED, - Type.FINISHED, - }, - Type.DECODING: { - Type.DETOKENIZED, - Type.FINISHED, - }, - Type.DETOKENIZED: { - Type.DECODING, - Type.PREEMPTED, - Type.FINISHED, - }, - Type.PREEMPTED: {Type.PREFILLING, Type.FINISHED}, - Type.FINISHED: set(), - } - - request_id: str - - type: Type - - # Timestamp when the update is recorded. This is used to record time - # intervals between events rather than wall clock time. - monotonic_ts_s: float = msgspec_field( - default_factory=lambda: time.monotonic()) - - ############################################################ - # Metadata associated with the update. - ############################################################ - # For input_processed. Metadata needed for stats logging. - num_prompt_tokens: Optional[int] = None - sampling_params: Optional[SamplingParams] = None - - # For running. - # Number of tokens computed when scheduled to run. - num_computed_tokens: Optional[int] = None - # Number of cached tokens when scheduled to run. 
- num_cached_tokens: Optional[int] = None - - # For decoded. - # The number of new output tokens generated. - num_new_tokens: Optional[int] = None - - # For both detokenized and decoded. - # Finished reason. - finish_reason: Optional[str] = None - - # Non-optional fields for each update type. - _REQUIRED_FIELDS: ClassVar[dict[Type, list[str]]] = { - Type.INPUT_PROCESSED: ["num_prompt_tokens", "sampling_params"], - Type.PREFILLING: ["num_computed_tokens", "num_cached_tokens"], - Type.DETOKENIZED: ["num_new_tokens"], - Type.FINISHED: ["finish_reason"], - } - - def __post_init__(self): - required_fields = self._REQUIRED_FIELDS.get(self.type, []) - for field in required_fields: - if getattr(self, field) is None: - raise ValueError( - f"Field {field} is required for update type {self.type}.") - - @staticmethod - def check_valid_update( - update: "RequestStatsUpdate", - last_update_type: Optional[Type], - last_updated_ts_s: Optional[float], - ): - if last_update_type is None: - assert update.type == RequestStatsUpdate.Type.ARRIVED - else: - valid_cur_update_types = RequestStatsUpdate._VALID_TRANSITIONS[ - last_update_type] - assert update.type in valid_cur_update_types, ( - f"Invalid update type: {update.type} for last_update_type: " - f"{last_update_type}.") - - if last_updated_ts_s is not None: - assert update.monotonic_ts_s >= last_updated_ts_s, ( - "Update timestamp must be monotonically increasing, but " - f"last_updated_ts_s={last_updated_ts_s} and " - f"update.monotonic_ts_s={update.monotonic_ts_s}.") - - -@dataclass -class RequestStats: - """Stats associated with a request (`Request`).""" - - ############################################################ - # Metadata - ############################################################ - request_id: str - sampling_params: Optional[SamplingParams] = None - num_prompt_tokens: Optional[int] = None - - ############################################################ - # Metrics and Stats - 
############################################################ - # Timestamp when the request was last updated. - last_updated_ts_s: Optional[float] = None - - # Last update stats type. - last_update_type: Optional[RequestStatsUpdate.Type] = None - - # Timestamp when the request arrived at the llm engine. - arrival_ts_s: Optional[float] = None - - # Number of tokens cached. When part of the request prefix is cached, - # this will be set. - num_cached_tokens: int = 0 - - # Number of tokens computed. - num_computed_tokens: int = 0 - - # The timestamp when the request become waiting in the queue. - queued_ts_s: Optional[float] = None - - # When the input processor is completed. - input_processor_end_ts_s: Optional[float] = None - - # A sorted list of timestamps when the request was scheduled to prefill. - # This could be when: - # 1. the request is newly scheduled, so it's a new prefill. - # 2. the request was preempted and resumed. It is equivalent to running - # a prefill of the original prefill tokens + generated output tokens - # before preemption. - prefill_start_ts_s_lst: list[float] = dataclass_field(default_factory=list) - - # A list of timestamps when a token is decoded by the engine core. - decoding_ts_s_lst: list[float] = dataclass_field(default_factory=list) - - # A sorted list of timestamps for each output token. - output_token_ts_s_lst: list[float] = dataclass_field(default_factory=list) - - # First token's timestamp. - first_token_ts_s: Optional[float] = None - - # TODO(rickyx): we need model runner to surface these. - model_forward_duration_s: float = 0.0 - # Includes model forward, block/sync across workers, cpu-gpu sync time - # and sampling time. - model_execute_duration_s: float = 0.0 - - # A sorted list of timestamps when the request was preempted at the - # scheduler. - # TODO(rickyx): right now, we don't actually have a good high-level - # metric to measure the impact of preemption other than observation of - # large P99 TPOT. 
Ideally we could quantify the impact of preemption by - # measuring the number of tokens re-computed due to preemption. - preempted_ts_s_lst: list[float] = dataclass_field(default_factory=list) - - # Timestamp when the request was finished at the engine core. - finished_ts_s: Optional[float] = None - - # Finish reason. - finish_reason: Optional[str] = None - - ############################################################ - # Derived properties. - ############################################################ - @property - def prefill_ts_s(self) -> Optional[float]: - """The timestamp when the request started prefilling. - Since a request could be preempted in decoding and later resumed - to prefill the decoded tokens, we use the first prefill start timestamp. - """ - return (self.prefill_start_ts_s_lst[0] - if self.prefill_start_ts_s_lst else None) - - @property - def e2e_latency_s(self) -> Optional[float]: - if self.finished_ts_s is None or self.arrival_ts_s is None: - return None - assert self.finished_ts_s >= self.arrival_ts_s - return self.finished_ts_s - self.arrival_ts_s - - @property - def queue_duration_s(self) -> Optional[float]: - """How long the request was waiting to run.""" - if self.queued_ts_s is None or self.prefill_ts_s is None: - # Either not queued or not running yet. 
- return None - assert self.queued_ts_s <= self.prefill_ts_s - return self.prefill_ts_s - self.queued_ts_s - - @property - def inference_latency_s(self) -> Optional[float]: - """How long the request was running inference - (prefill and decode).""" - if self.finished_ts_s is None or self.prefill_ts_s is None: - return None - assert self.finished_ts_s >= self.prefill_ts_s - return self.finished_ts_s - self.prefill_ts_s - - @property - def first_token_latency_s(self) -> Optional[float]: - if self.first_token_ts_s is None or self.arrival_ts_s is None: - return None - assert self.first_token_ts_s >= self.arrival_ts_s - return self.first_token_ts_s - self.arrival_ts_s - - @property - def prefill_latency_s(self) -> Optional[float]: - if self.first_token_ts_s is None or self.prefill_ts_s is None: - return None - assert self.first_token_ts_s >= self.prefill_ts_s - return self.first_token_ts_s - self.prefill_ts_s - - @property - def decode_latency_s(self) -> Optional[float]: - if self.e2e_latency_s is None or self.first_token_latency_s is None: - return None - assert self.e2e_latency_s >= self.first_token_latency_s - return self.e2e_latency_s - self.first_token_latency_s - - @property - def output_token_latency_s_lst(self) -> list[float]: - if len(self.output_token_ts_s_lst) == 0: - return [] - latency_s_lst = [] - for i in range(1, len(self.output_token_ts_s_lst)): - assert (self.output_token_ts_s_lst[i] - >= self.output_token_ts_s_lst[i - 1]) - latency_s = (self.output_token_ts_s_lst[i] - - self.output_token_ts_s_lst[i - 1]) - latency_s_lst.append(latency_s) - return latency_s_lst - - @property - def num_output_tokens(self) -> int: - return len(self.output_token_ts_s_lst) - - @property - def is_finished(self) -> bool: - return self.finished_ts_s is not None - - def update_from(self, update: "RequestStatsUpdate"): - RequestStatsUpdate.check_valid_update(update, self.last_update_type, - self.last_updated_ts_s) - ts = update.monotonic_ts_s - self.last_updated_ts_s = ts - 
self.last_update_type = update.type - if update.type == RequestStatsUpdate.Type.ARRIVED: - self.arrival_ts_s = ts - elif update.type == RequestStatsUpdate.Type.INPUT_PROCESSED: - self.input_processor_end_ts_s = ts - self.sampling_params = update.sampling_params - self.num_prompt_tokens = update.num_prompt_tokens - elif update.type == RequestStatsUpdate.Type.QUEUED: - self.queued_ts_s = ts - elif update.type == RequestStatsUpdate.Type.PREFILLING: - self.prefill_start_ts_s_lst.append(ts) - self.num_cached_tokens = update.num_cached_tokens or 0 - self.num_computed_tokens = update.num_computed_tokens or 0 - elif update.type == RequestStatsUpdate.Type.PREEMPTED: - self._reset_for_preemption(ts) - elif update.type == RequestStatsUpdate.Type.DECODING: - self.decoding_ts_s_lst.append(ts) - elif update.type == RequestStatsUpdate.Type.DETOKENIZED: - self._record_detokenized_output( - ts, - update.num_new_tokens or 0, - ) - elif update.type == RequestStatsUpdate.Type.FINISHED: - self.finished_ts_s = ts - self.finish_reason = update.finish_reason - else: - raise ValueError(f"Unknown update type: {update.type}") - - def _record_detokenized_output( - self, - ts_s: float, - num_new_tokens: int, - ): - # Update if first output token is generated. - if len(self.output_token_ts_s_lst) == 0: - self.first_token_ts_s = ts_s - assert ( - self.prefill_ts_s is not None - ), "Request must be running before generating output tokens." - - # Some X new tokens were generated at the ts. - self.output_token_ts_s_lst.extend([ts_s] * num_new_tokens) - - def _reset_for_preemption(self, ts_s: float): - self.preempted_ts_s_lst.append(ts_s) - # Reset the computed tokens since it might restart the prefill. - self.num_computed_tokens = 0 - # Cached token count might also change when resumed. - self.num_cached_tokens = 0 - # These stats don't change since they happen before request running. 
- # - arrival_ts_s - # - input_processor_end_ts_s - # - sampling_params - # - num_prompt_tokens - # - first_token_ts_s - # - # These stats are accumulated over preemptions: - # - output_token_ts_s_lst - # - prefill_start_ts_s_lst (after preemption, it will prefill the - # original prefill tokens and any output tokens generated before - # preemption.) - - -@dataclass -class KVCacheStats: - # KV Cache Usage in % - gpu_cache_usage_sys: float = 0.0 - gpu_prefix_cache_hit_rate: float = 0.0 - - -@dataclass -class SchedulerStats: - """Stats associated with the scheduler.""" - - # Number of requests currently running. - num_running_reqs: int = 0 - # Number of requests currently waiting. - num_waiting_reqs: int = 0 - - kv_cache_stats: KVCacheStats = dataclass_field( - default_factory=KVCacheStats) - - -@dataclass -class EngineCoreProcessStats: - """Stats associated with the engine core process.""" - - # Number of requests currently in the input queue. None if the engine core - # is not running in multiprocess mode. - input_queue_size: Optional[int] = None - # Number of outputs currently in the output queue. None if the engine core - # is not running in multiprocess mode. - output_queue_size: Optional[int] = None - - -class EngineCoreStatsSnapshot( - msgspec.Struct, # type: ignore - array_like=True, - omit_defaults=True, - gc=False): - """ - A snapshot of the EngineCore's current stats over a period of time. - """ - - # Snapshot of the scheduler stats. - scheduler_stats: SchedulerStats = msgspec_field( - default_factory=SchedulerStats) - - # Per request stats updates. - requests_stats_updates: list[RequestStatsUpdate] = msgspec_field( - default_factory=list) - - # Engine core's queue stats. - engine_core_process_stats: EngineCoreProcessStats = msgspec_field( - default_factory=EngineCoreProcessStats) - - # TODO(rickyx): Add other components' stats, - # e.g. model runner/worker and etc. 
diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 3183edb7c94e..c701ab1d35a5 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -7,16 +7,23 @@ from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.reasoning import ReasoningParserManager +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.utils import LazyLoader from vllm.v1.structured_output.backend_guidance import GuidanceBackend from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, StructuredOutputGrammar) +from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend if TYPE_CHECKING: import numpy as np import numpy.typing as npt import torch + from vllm.reasoning import ReasoningParser from vllm.v1.request import Request +else: + torch = LazyLoader("torch", globals(), "torch") logger = init_logger(__name__) @@ -26,9 +33,11 @@ class StructuredOutputManager: def __init__(self, vllm_config: VllmConfig): self.backend: Optional[StructuredOutputBackend] = None + self.reasoner: Optional[ReasoningParser] = None self.vllm_config = vllm_config self._grammar_bitmask: Optional[torch.Tensor] = None + self._full_mask = torch.tensor(-1, dtype=torch.int32) # The default max_workers if not specified is the number of CPUs * 5, # which is way too high since these tasks are CPU-bound, not I/O bound. @@ -36,24 +45,43 @@ def __init__(self, vllm_config: VllmConfig): # compilation, so we set it to half the number of CPUs. 
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) self.executor = ThreadPoolExecutor(max_workers=max_workers) + self.tokenizer = init_tokenizer_from_configs( + model_config=self.vllm_config.model_config, + scheduler_config=self.vllm_config.scheduler_config, + lora_config=self.vllm_config.lora_config, + ).get_lora_tokenizer(None) + reasoning_backend = vllm_config.decoding_config.reasoning_backend + if reasoning_backend: + reasoner_cls = ReasoningParserManager.get_reasoning_parser( + reasoning_backend) + self.reasoner = reasoner_cls(tokenizer=self.tokenizer) def grammar_init(self, request: Request) -> None: if request.structured_output_request is None: return + if TYPE_CHECKING: + assert request.sampling_params.guided_decoding is not None + # Initialize the backend the first time it is needed. # # NOTE: We only support a single backend. We do NOT support different # backends on a per-request basis in V1 (for now, anyway...). if self.backend is None: backend = request.sampling_params.guided_decoding.backend + vocab_size = self.vllm_config.model_config.get_vocab_size() if backend == "xgrammar": - from vllm.v1.structured_output.backend_xgrammar import ( - XgrammarBackend) - - self.backend = XgrammarBackend(self.vllm_config) + self.backend = XgrammarBackend( + self.vllm_config, + tokenizer=self.tokenizer, + vocab_size=vocab_size, + ) elif backend == "guidance": - self.backend = GuidanceBackend(self.vllm_config) + self.backend = GuidanceBackend( + self.vllm_config, + tokenizer=self.tokenizer, + vocab_size=vocab_size, + ) else: raise ValueError( f"Unsupported structured output backend: {backend}") @@ -87,14 +115,14 @@ def grammar_bitmask( if not structured_output_request_ids: return None + max_num_spec_tokens = 0 + if self.vllm_config.speculative_config is not None: + max_num_spec_tokens = \ + self.vllm_config.speculative_config.num_speculative_tokens + if self._grammar_bitmask is None: assert self.backend is not None max_batch_size = 
self.vllm_config.scheduler_config.max_num_seqs - if self.vllm_config.speculative_config is not None: - max_num_spec_tokens = self.vllm_config.\ - speculative_config.num_speculative_tokens - else: - max_num_spec_tokens = 0 # Allocate a bitmask for each token needing to be checked: # one for each speculative position, and one more for the @@ -103,6 +131,7 @@ def grammar_bitmask( self.backend.allocate_token_bitmask( max_batch_size * (1 + max_num_spec_tokens)) + bitmask_tensor = self._grammar_bitmask # Generate a batched bitmask for all structured output requests. # When speculative decoding is enabled, we need to include multiple # masks for each request, one for each possible bonus token position. @@ -110,16 +139,30 @@ def grammar_bitmask( cumulative_index = 0 ordered_seq = sorted(structured_output_request_ids.items(), key=lambda x: x[1]) + + # Note that for thinking support, we will need to + # reset the relevant part of the bitmask for consequent + # request here. + bitmask_tensor[:(len(ordered_seq) * (1 + max_num_spec_tokens))].fill_( + self._full_mask) + # NOTE: This outer loop can likely be parallelized to improve # performance of bitmask generation for large batches. 
for req_id, _ in ordered_seq: request = requests[req_id].structured_output_request - assert request is not None and request.grammar is not None + if TYPE_CHECKING: + assert request is not None + assert request.grammar is not None + + apply_bitmask = ( + request.reasoning_ended if self.reasoner is not None else True + ) # noqa: E501 + state_advancements = 0 req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None] for i, token in enumerate(req_tokens): - if not request.grammar.is_terminated(): - request.grammar.fill_bitmask(self._grammar_bitmask, + if apply_bitmask and not request.grammar.is_terminated(): + request.grammar.fill_bitmask(bitmask_tensor, cumulative_index) if token is not None: # In order to generate the correct bitmask for each @@ -132,15 +175,41 @@ def grammar_bitmask( if state_advancements > 0: request.grammar.rollback(state_advancements) - bitmask_tensor = self._grammar_bitmask - if cumulative_index < self._grammar_bitmask.shape[0]: - bitmask_tensor = self._grammar_bitmask[:cumulative_index] + if cumulative_index < bitmask_tensor.shape[0]: + bitmask_tensor = bitmask_tensor[:cumulative_index] # After finishing with the xgrammar operations, we convert to # np.ndarray, because that is much more efficient for serialization # and deserialization when sending this to the GPU workers. return bitmask_tensor.numpy() + def should_advance(self, request: Request) -> bool: + if not request.use_structured_output: + return False + + # To determine whether we can advance the FSM. + # Supports thinking usage where we skip the reasoning components. + if TYPE_CHECKING: + assert request.structured_output_request is not None + assert request.structured_output_request.grammar is not None + # by default, we should always advance + # for cases that doesn't uses thinking mode. 
+ if self.reasoner is not None: + structured_req = request.structured_output_request + + if structured_req.reasoning_ended: + return True + + # Check if reasoning ends in *this* step + if self.reasoner.is_reasoning_end(request.all_token_ids): + # Reasoning just ended, so we shouldn't advanced til + # next pass + structured_req.reasoning_ended = True + + return False + else: + return True + def clear_backend(self) -> None: if self.backend is not None: self.backend.destroy() diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py index 0ab175e781e7..55c5f609095d 100644 --- a/vllm/v1/structured_output/backend_guidance.py +++ b/vllm/v1/structured_output/backend_guidance.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import copy import json import os @@ -8,10 +10,8 @@ import torch -from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.utils import LazyLoader from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, StructuredOutputGrammar, @@ -54,25 +54,17 @@ def process_for_additional_properties( return guide_json_obj +@dataclass class GuidanceBackend(StructuredOutputBackend): - def __init__(self, vllm_config: VllmConfig): - self.vllm_config = vllm_config - tokenizer_group = init_tokenizer_from_configs( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - lora_config=vllm_config.lora_config) # type: ignore[arg-type] - self.vllm_config = vllm_config - self.vocab_size = vllm_config.model_config.get_vocab_size() - + def __post_init__(self): self.disable_any_whitespace = \ - vllm_config.decoding_config.disable_any_whitespace + self.vllm_config.decoding_config.disable_any_whitespace self.disable_additional_properties = \ - 
vllm_config.decoding_config.disable_additional_properties + self.vllm_config.decoding_config.disable_additional_properties - tokenizer = tokenizer_group.get_lora_tokenizer(None) self.ll_tokenizer = llguidance_hf.from_tokenizer( - tokenizer, self.vocab_size) + self.tokenizer, self.vocab_size) def compile_grammar(self, request_type: StructuredOutputOptions, grammar_spec: str) -> StructuredOutputGrammar: diff --git a/vllm/v1/structured_output/backend_types.py b/vllm/v1/structured_output/backend_types.py index 33ca9f8cf484..09f6cdf73337 100644 --- a/vllm/v1/structured_output/backend_types.py +++ b/vllm/v1/structured_output/backend_types.py @@ -1,9 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import enum from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import torch -import torch + from vllm.config import VllmConfig + from vllm.transformers_utils.tokenizer import AnyTokenizer class StructuredOutputOptions(enum.Enum): @@ -85,9 +93,14 @@ def reset(self): """ +@dataclass class StructuredOutputBackend(ABC): """Engine-level backend for structured output requests.""" + vllm_config: VllmConfig + tokenizer: AnyTokenizer + vocab_size: int + @abstractmethod def compile_grammar(self, request_type: StructuredOutputOptions, grammar_spec: str) -> StructuredOutputGrammar: @@ -104,7 +117,7 @@ def compile_grammar(self, request_type: StructuredOutputOptions, """ @abstractmethod - def allocate_token_bitmask(self, max_num_seqs: int): + def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor: """ Allocates a token bitmask for the specified maximum number of sequences. 
diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index c82a3cab2fa3..f2570221da25 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import json from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any @@ -7,10 +9,8 @@ import torch import vllm.envs -from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.utils import LazyLoader from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, @@ -28,61 +28,49 @@ logger = init_logger(__name__) +@dataclass class XgrammarBackend(StructuredOutputBackend): - def __init__(self, vllm_config: VllmConfig): - self.vllm_config = vllm_config - tokenizer_group = init_tokenizer_from_configs( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - lora_config=vllm_config.lora_config) # type: ignore[arg-type] - + def __post_init__(self): self.disable_any_whitespace = \ - vllm_config.decoding_config.disable_any_whitespace + self.vllm_config.decoding_config.disable_any_whitespace - self.num_speculative_tokens = 0 - if self.vllm_config.speculative_config is not None: - self.num_speculative_tokens = \ - self.vllm_config.speculative_config.num_speculative_tokens - - tokenizer = tokenizer_group.get_lora_tokenizer(None) - self.vocab_size = vllm_config.model_config.get_vocab_size() - if isinstance(tokenizer, MistralTokenizer): + if isinstance(self.tokenizer, MistralTokenizer): # NOTE: ideally, xgrammar should handle this accordingly. 
# refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 try: - if tokenizer.is_tekken: - encoded_vocab = tokenizer._vocab + if self.tokenizer.is_tekken: + encoded_vocab = self.tokenizer._vocab else: encoded_vocab = [ token for token, _ in sorted( - tokenizer.get_vocab().items(), + self.tokenizer.get_vocab().items(), key=lambda x: x[1], ) ] stop_token_ids = None - if hasattr( - tokenizer, + if (hasattr( + self.tokenizer, "eos_token_id", - ) and tokenizer.eos_token_id is not None: - stop_token_ids = [tokenizer.eos_token_id] + ) and self.tokenizer.eos_token_id is not None): + stop_token_ids = [self.tokenizer.eos_token_id] except AttributeError as e: raise ValueError( f"Cannot get the vocabulary of the tokenizer " - f"{type(tokenizer)}. The tokenizer should have a " + f"{type(self.tokenizer)}. The tokenizer should have a " "get_vocab method.") from e tokenizer_info = xgr.TokenizerInfo( # type: ignore encoded_vocab=encoded_vocab, # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501 vocab_type=xgr.VocabType.RAW - if tokenizer.is_tekken else xgr.VocabType.BYTE_FALLBACK, + if self.tokenizer.is_tekken else xgr.VocabType.BYTE_FALLBACK, vocab_size=self.vocab_size, stop_token_ids=stop_token_ids, add_prefix_space=True, ) else: tokenizer_info = xgr.TokenizerInfo.from_huggingface( - tokenizer, + self.tokenizer, vocab_size=self.vocab_size, ) self.compiler = xgr.GrammarCompiler( @@ -92,6 +80,11 @@ def __init__(self, vllm_config: VllmConfig): cache_limit_bytes=vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024, ) + self.num_speculative_tokens = 0 + if self.vllm_config.speculative_config is not None: + self.num_speculative_tokens = \ + self.vllm_config.speculative_config.num_speculative_tokens + def compile_grammar(self, request_type: StructuredOutputOptions, grammar_spec: str) -> StructuredOutputGrammar: if request_type == 
StructuredOutputOptions.JSON: @@ -215,9 +208,8 @@ def check_object(obj: dict[str, Any]) -> bool: # Check for array unsupported keywords if obj.get("type") == "array" and any( - key in obj - for key in ("uniqueItems", "contains", "minContains", - "maxContains", "minItems", "maxItems")): + key in obj for key in ("uniqueItems", "contains", + "minContains", "maxContains")): return True # Unsupported keywords for strings @@ -282,6 +274,12 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None: else: schema = gd_params.json + try: + xgr.Grammar.from_json_schema(schema) + except Exception as err: + raise ValueError("Failed to transform json schema into a grammar: " + f"{err}") from err + if has_xgrammar_unsupported_json_features(schema): raise ValueError("The provided JSON schema contains features not " "supported by xgrammar.") diff --git a/vllm/v1/structured_output/request.py b/vllm/v1/structured_output/request.py index 6ef472eb896c..c16320b9e74c 100644 --- a/vllm/v1/structured_output/request.py +++ b/vllm/v1/structured_output/request.py @@ -20,6 +20,7 @@ class StructuredOutputRequest: sampling_params: SamplingParams _grammar: Optional[Union[Future[StructuredOutputGrammar], StructuredOutputGrammar]] = None + reasoning_ended: bool = False def _check_grammar_completion(self) -> bool: # NOTE: We have to lazy import to gate circular imports diff --git a/vllm/v1/structured_output/utils.py b/vllm/v1/structured_output/utils.py index f33f4972e103..111e92dc0990 100644 --- a/vllm/v1/structured_output/utils.py +++ b/vllm/v1/structured_output/utils.py @@ -2,7 +2,7 @@ from __future__ import annotations -import re +import regex as re def grammar_is_likely_lark(grammar_str: str) -> bool: diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 9c238c3aad8e..0758747a83cc 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -1,20 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 import os +import time import weakref from collections import defaultdict from 
collections.abc import Sequence -from multiprocessing import Process -from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, - Union, overload) +from multiprocessing import Process, connection +from typing import (TYPE_CHECKING, Callable, Generic, Optional, TypeVar, Union, + overload) import torch +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.models.utils import extract_layer_index from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) from vllm.utils import get_mp_context, kill_process_tree +from vllm.v1.executor.abstract import Executor if TYPE_CHECKING: from vllm.attention.layer import Attention @@ -92,7 +95,7 @@ def __repr__(self): return f"ConstantList({self._x})" -class BackgroundProcHandle: +class CoreEngineProcManager: """ Utility class to handle creation, readiness, and shutdown of background processes used by the AsyncLLM and LLMEngine. @@ -100,49 +103,91 @@ class BackgroundProcHandle: def __init__( self, - input_path: str, - output_path: str, - process_name: str, target_fn: Callable, - process_kwargs: dict[Any, Any], + local_engine_count: int, + start_index: int, + local_start_index: int, + vllm_config: VllmConfig, + on_head_node: bool, + input_address: str, + executor_class: type[Executor], + log_stats: bool, ): context = get_mp_context() + common_kwargs = { + "vllm_config": vllm_config, + "on_head_node": on_head_node, + "input_address": input_address, + "executor_class": executor_class, + "log_stats": log_stats, + } + + self.processes: list[Process] = [] + for index in range(local_engine_count): + local_index = local_start_index + index + global_index = start_index + index + # Start EngineCore in background process. 
+ self.processes.append( + context.Process(target=target_fn, + name=f"EngineCore_{global_index}", + kwargs=common_kwargs | { + "dp_rank": global_index, + "local_dp_rank": local_index, + })) + + self._finalizer = weakref.finalize(self, shutdown, self.processes, + input_address) + try: + for proc in self.processes: + proc.start() + finally: + # Kill other procs if not all are running. + if self.finished_procs(): + self.close() + + def close(self): + """Shutdown all procs.""" + self._finalizer() - assert ("input_path" not in process_kwargs - and "output_path" not in process_kwargs) - process_kwargs["input_path"] = input_path - process_kwargs["output_path"] = output_path - - # Run busy loop in background process. - self.proc: Process = context.Process(target=target_fn, - kwargs=process_kwargs, - name=process_name) - self._finalizer = weakref.finalize(self, shutdown, self.proc, - input_path, output_path) - self.proc.start() + def join_first(self): + """Wait for any process to exit.""" + connection.wait(proc.sentinel for proc in self.processes) - def fileno(self): - return self.proc.sentinel + def sentinels(self) -> list: + return [proc.sentinel for proc in self.processes] - def shutdown(self): - self._finalizer() + def finished_procs(self) -> dict[str, int]: + """Returns dict of proc name -> exit code for any finished procs.""" + return { + proc.name: proc.exitcode + for proc in self.processes if proc.exitcode is not None + } # Note(rob): shutdown function cannot be a bound method, -# else the gc cannot collect the object. -def shutdown(proc: Process, input_path: str, output_path: str): +# else the gc cannot collect the objedecoupct. +def shutdown(procs: list[Process], input_address: str): # Shutdown the process. - if proc.is_alive(): - proc.terminate() - proc.join(5) - + for proc in procs: + if proc.is_alive(): + proc.terminate() + + # Allow 5 seconds for remaining procs to terminate. 
+ deadline = time.monotonic() + 5 + for proc in procs: + remaining = deadline - time.monotonic() + if remaining <= 0: + break + if proc.is_alive(): + proc.join(remaining) + + for proc in procs: if proc.is_alive() and (pid := proc.pid) is not None: kill_process_tree(pid) # Remove zmq ipc socket files. - ipc_sockets = [output_path, input_path] - for ipc_socket in ipc_sockets: - socket_file = ipc_socket.replace("ipc://", "") + if input_address.startswith("ipc://"): + socket_file = input_address[len("ipc://"):] if os and os.path.exists(socket_file): os.remove(socket_file) diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 7d4082b73992..576086ebeb7f 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -4,6 +4,7 @@ import torch from vllm.logger import init_logger +from vllm.utils import cdiv logger = init_logger(__name__) @@ -14,11 +15,13 @@ def __init__( self, max_num_reqs: int, max_num_blocks_per_req: int, + max_num_batched_tokens: int, pin_memory: bool, device: torch.device, ): self.max_num_reqs = max_num_reqs self.max_num_blocks_per_req = max_num_blocks_per_req + self.max_num_batched_tokens = max_num_batched_tokens self.pin_memory = pin_memory self.device = device @@ -36,6 +39,15 @@ def __init__( self.block_table_np = self.block_table_cpu.numpy() self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32) + self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens, + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory) + self.slot_mapping_np = self.slot_mapping_cpu.numpy() + self.slot_mapping = torch.zeros(self.max_num_batched_tokens, + dtype=torch.int64, + device=self.device) + def append_row( self, block_ids: list[int], @@ -85,3 +97,43 @@ def get_cpu_tensor(self) -> torch.Tensor: def get_numpy_array(self) -> np.ndarray: """Returns the numpy array of the block table.""" return self.block_table_np + + +class MultiGroupBlockTable: + """The BlockTables for each KV cache group.""" + + def 
__init__(self, max_num_reqs: int, max_model_len: int, + max_num_batched_tokens: int, pin_memory: bool, + device: torch.device, block_size: int) -> None: + self.block_tables = [ + BlockTable(max_num_reqs, cdiv(max_model_len, block_size), + max_num_batched_tokens, pin_memory, device) + ] + + def append_row(self, block_ids: list[list[int]], row_idx: int) -> None: + for i, block_table in enumerate(self.block_tables): + block_table.append_row(block_ids[i], row_idx) + + def add_row(self, block_ids: list[list[int]], row_idx: int) -> None: + for i, block_table in enumerate(self.block_tables): + block_table.add_row(block_ids[i], row_idx) + + def move_row(self, src: int, tgt: int) -> None: + for block_table in self.block_tables: + block_table.move_row(src, tgt) + + def swap_row(self, src: int, tgt: int) -> None: + for block_table in self.block_tables: + block_table.swap_row(src, tgt) + + def commit(self, num_reqs: int) -> None: + for block_table in self.block_tables: + block_table.commit(num_reqs) + + def clear(self) -> None: + for block_table in self.block_tables: + block_table.clear() + + def __getitem__(self, idx: int) -> "BlockTable": + """Returns the BlockTable for the i-th KV cache group.""" + return self.block_tables[idx] diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index c00424dfea73..b3e65917d3cc 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -14,7 +14,7 @@ from vllm.v1.outputs import LogprobsTensors from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.utils import copy_slice -from vllm.v1.worker.block_table import BlockTable +from vllm.v1.worker.block_table import MultiGroupBlockTable _SAMPLING_EPS = 1e-5 @@ -29,7 +29,7 @@ class CachedRequestState: sampling_params: SamplingParams generator: Optional[torch.Generator] - block_ids: list[int] + block_ids: list[list[int]] num_computed_tokens: int output_token_ids: list[int] @@ -58,14 +58,15 @@ def __init__( self, max_num_reqs: 
int, max_model_len: int, - max_num_blocks_per_req: int, + max_num_batched_tokens: int, device: torch.device, pin_memory: bool, vocab_size: int, + block_size: int, ): self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len - self.max_num_blocks_per_req = max_num_blocks_per_req + self.max_num_batched_tokens = max_num_batched_tokens self.device = device self.pin_memory = pin_memory self.vocab_size = vocab_size @@ -97,11 +98,13 @@ def __init__( self.num_computed_tokens_cpu_tensor.numpy() # Block table. - self.block_table = BlockTable( + self.block_table = MultiGroupBlockTable( max_num_reqs=max_num_reqs, - max_num_blocks_per_req=max_num_blocks_per_req, + max_model_len=max_model_len, + max_num_batched_tokens=max_num_batched_tokens, pin_memory=pin_memory, device=device, + block_size=block_size, ) # Sampling-related. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index bd8c87fd9efc..910c0e80bb31 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import copy import gc import time import weakref @@ -11,24 +12,29 @@ import torch.nn as nn from vllm.attention import AttentionType, get_attn_backend +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadataBuilder) from vllm.attention.layer import Attention from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.config import (CompilationLevel, VllmConfig, get_layers_from_vllm_config) from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) -from vllm.distributed.parallel_state import get_pp_group, graph_capture -from vllm.forward_context import set_forward_context +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 +from vllm.distributed.parallel_state import ( + get_pp_group, get_tp_group, graph_capture, + prepare_communication_buffer_for_model) +from vllm.forward_context import 
get_forward_context, set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding -from vllm.model_executor.model_loader import get_model +from vllm.model_executor.model_loader import TensorizerLoader, get_model from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - GiB_bytes, LayerBlockType, LazyLoader, cdiv, + GiB_bytes, LazyLoader, async_tensor_h2d, cdiv, check_use_alibi, is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata @@ -42,10 +48,12 @@ from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.sampler import Sampler from vllm.v1.spec_decode.eagle import EagleProposer +from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.utils import is_spec_decode_supported from vllm.v1.utils import bind_kv_cache +from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin @@ -55,6 +63,7 @@ if TYPE_CHECKING: import xgrammar as xgr + from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.v1.core.sched.output import SchedulerOutput else: xgr = LazyLoader("xgr", globals(), "xgrammar") @@ -97,61 +106,17 @@ def __init__( self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ cache_config.cache_dtype] - # NOTE(woosuk): sliding_window is None for models with interleaved - # attention. 
Use interleaved_sliding_window instead. - self.sliding_window = model_config.get_sliding_window() - self.interleaved_sliding_window = getattr( - model_config.hf_text_config, "interleaved_sliding_window", None) - self.window_size = (self.sliding_window - or self.interleaved_sliding_window) - self.is_multimodal_model = model_config.is_multimodal_model - self.block_size = cache_config.block_size self.max_model_len = model_config.max_model_len - self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs # Model-related. - self.num_attn_layers = model_config.get_num_layers_by_block_type( - parallel_config, LayerBlockType.attention) self.num_query_heads = model_config.get_num_attention_heads( parallel_config) - self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) - self.head_size = model_config.get_head_size() self.hidden_size = model_config.get_hidden_size() self.attention_chunk_size = model_config.attention_chunk_size - self.attn_backend = get_attn_backend( - self.head_size, - self.dtype, - self.kv_cache_dtype, - self.block_size, - self.model_config.is_attention_free, - use_mla=self.model_config.use_mla, - ) - if self.attn_backend is None: - error_msg = ( - f"Error with get_att_backend: {self.head_size=}, " - f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, " - f"{self.model_config.is_attention_free=}, " - f"{self.model_config.use_mla=}") - logger.error(error_msg) - raise NotImplementedError( - "Non-Attention backend is not supported by V1 GPUModelRunner.") - - if self.vllm_config.compilation_config.full_cuda_graph: - attn_backend_name = self.attn_backend.__name__ - flash_attn_version = get_flash_attn_version() - if attn_backend_name != "FlashAttentionBackend" or \ - flash_attn_version != 3: - raise ValueError( - f"full_cuda_graph is only supported with " - f"FA3. 
Current attention backend is {attn_backend_name}, " - f"FlashAttention version is {flash_attn_version}.") - - self.attn_metadata_builder = self.attn_backend.get_builder_cls()( - weakref.proxy(self)) self.cascade_attn_enabled = not self.model_config.disable_cascade_attn # Multi-modal data support @@ -173,7 +138,10 @@ def __init__( # self.model: nn.Module # Set after load_model # Initialize in initialize_kv_cache self.kv_caches: list[torch.Tensor] = [] + self.attn_metadata_builders: list[AttentionMetadataBuilder] = [] + self.attn_backends: list[type[AttentionBackend]] = [] # self.kv_cache_config: KVCacheConfig + # self.input_batch: InputBatch # Persistent batch. # req_id -> (input_id -> encoder_output) self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} @@ -183,14 +151,22 @@ def __init__( self.use_aux_hidden_state_outputs = False if self.speculative_config: self.use_spec_decode = True + + # NOTE(Jiayi): currently we put the entire draft model on + # the last PP rank. This is not ideal if there are many + # layers in the draft model. if get_pp_group().is_last_rank: if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) elif self.speculative_config.use_eagle(): - self.drafter = EagleProposer(self.vllm_config, - self.device) # type: ignore + self.drafter = EagleProposer(self.vllm_config, self.device, + self) # type: ignore if self.speculative_config.method == "eagle3": self.use_aux_hidden_state_outputs = True + elif self.speculative_config.method == "medusa": + self.drafter = MedusaProposer( + vllm_config=self.vllm_config, + device=self.device) # type: ignore else: raise ValueError("Unknown speculative decoding method: " f"{self.speculative_config.method}") @@ -198,14 +174,15 @@ def __init__( # Request states. self.requests: dict[str, CachedRequestState] = {} - # Persistent batch. 
+ self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, max_model_len=self.max_model_len, - max_num_blocks_per_req=self.max_num_blocks_per_req, + max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, - vocab_size=model_config.get_vocab_size(), + vocab_size=self.model_config.get_vocab_size(), + block_size=self.cache_config.block_size, ) self.use_cuda_graph = (self.vllm_config.compilation_config.level @@ -285,17 +262,11 @@ def __init__( dtype=torch.int32, device="cpu", pin_memory=self.pin_memory) - self.input_ids_np = self.input_ids_cpu.numpy() self.positions_cpu = torch.zeros(self.max_num_tokens, dtype=torch.int64, device="cpu", pin_memory=self.pin_memory) self.positions_np = self.positions_cpu.numpy() - self.slot_mapping_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device="cpu", - pin_memory=self.pin_memory) - self.slot_mapping_np = self.slot_mapping_cpu.numpy() self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, dtype=torch.int32, device="cpu", @@ -307,6 +278,31 @@ def __init__( pin_memory=self.pin_memory) self.seq_lens_np = self.seq_lens_cpu.numpy() + def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool: + """ + Update the order of requests in the batch based on the attention + backend's needs. For example, some attention backends (namely MLA) may + want to separate requests based on if the attention computation will be + compute-bound or memory-bound. + + Args: + scheduler_output: The scheduler output. + + Returns: + True if the batch was reordered, False otherwise. + """ + batch_reordered = self.attn_metadata_builders[0].reorder_batch( + self.input_batch, scheduler_output) + + # For models with multiple KV cache groups, the groups should agree on + # the same order of requests. We ensure this by only allowing the first + # group to reorder the batch and asserting that all other groups do not + # reorder the batch. 
+ for i in range(1, len(self.kv_cache_config.kv_cache_groups)): + assert not self.attn_metadata_builders[i].reorder_batch( + self.input_batch, scheduler_output) + return batch_reordered + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """Update the cached states and the persistent batch with the scheduler output. @@ -443,7 +439,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Update the block IDs. if not req_data.resumed_from_preemption: # Append the new blocks to the existing block IDs. - req_state.block_ids.extend(req_data.new_block_ids) + for i in range(len(self.kv_cache_config.kv_cache_groups)): + req_state.block_ids[i].extend(req_data.new_block_ids[i]) else: # The request is resumed from preemption. # Replace the existing block IDs with the new ones. @@ -501,11 +498,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: if removed_req_indices: self.input_batch.condense(removed_req_indices) - # Some attention backends (namely MLA) may want to separate requests - # based on if the attention computation will be compute-bound or - # memory-bound. This gives them a hook to do that. - batch_reordered = self.attn_metadata_builder.reorder_batch( - self.input_batch, scheduler_output) + batch_reordered = self._may_reorder_batch(scheduler_output) if batch_changed or batch_reordered: self.input_batch.refresh_sampling_metadata() @@ -573,20 +566,29 @@ def _prepare_inputs( torch.from_numpy(token_indices), out=self.input_ids_cpu[:total_num_scheduled_tokens]) - # Calculate the slot mapping. - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] - # where K is the max_num_blocks_per_req and the block size is 2. - # NOTE(woosuk): We can't simply use `token_indices // block_size` here - # because M (max_model_len) is not necessarily divisible by block_size. 
- block_table_indices = (req_indices * self.max_num_blocks_per_req + - positions_np // self.block_size) - block_table_cpu = self.input_batch.block_table.get_cpu_tensor() - block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() - block_offsets = positions_np % self.block_size - np.add(block_numbers * self.block_size, - block_offsets, - out=self.slot_mapping_np[:total_num_scheduled_tokens]) + # Calculate the slot mapping for each KV cache group. + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups): + block_size = kv_cache_group_spec.kv_cache_spec.block_size + block_table: BlockTable = self.input_batch.block_table[ + kv_cache_group_id] + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # where K is the max_num_blocks_per_req and the block size is 2. + # NOTE(woosuk): We can't simply use `token_indices // block_size` + # here because M (max_model_len) is not necessarily divisible by + # block_size. + block_table_indices = ( + req_indices * block_table.max_num_blocks_per_req + + positions_np // block_size) + block_table_cpu = block_table.get_cpu_tensor() + block_numbers = block_table_cpu.flatten( + )[block_table_indices].numpy() + block_offsets = positions_np % block_size + np.add( + block_numbers * block_size, + block_offsets, + out=block_table.slot_mapping_np[:total_num_scheduled_tokens]) # Prepare the attention metadata. self.query_start_loc_np[0] = 0 @@ -614,12 +616,8 @@ def _prepare_inputs( self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True) self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], non_blocking=True) - self.slot_mapping[:total_num_scheduled_tokens].copy_( - self.slot_mapping_cpu[:total_num_scheduled_tokens], - non_blocking=True) # Fill unused with -1. 
Needed for reshape_and_cache - self.slot_mapping[total_num_scheduled_tokens:].fill_(-1) self.seq_lens[num_reqs:].fill_(0) self.query_start_loc[num_reqs + 1:].fill_(-1) @@ -632,10 +630,6 @@ def _prepare_inputs( attn_metadata: dict[str, FlashAttentionMetadata] = {} # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. - # NOTE(Chen): there is exactly one KV cache group that contains all - # attetnion layers in the model for now, so the current logic for - # getting attn_metadata is not related to kv_cache_group information. - # Will extend this part to support multiple KV cache groups later. for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): @@ -644,15 +638,19 @@ def _prepare_inputs( if self.cascade_attn_enabled: common_prefix_len = self._compute_cascade_attn_prefix_len( num_scheduled_tokens, - scheduler_output.num_common_prefix_blocks, + scheduler_output. + num_common_prefix_blocks[kv_cache_group_id], + kv_cache_group_spec.kv_cache_spec, + self.attn_metadata_builders[kv_cache_group_id], ) - attn_metadata_i = self.attn_metadata_builder.build( - num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata) + attn_metadata_i = ( + self.attn_metadata_builders[kv_cache_group_id].build( + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata)) for layer_name in kv_cache_group_spec.layer_names: attn_metadata[layer_name] = attn_metadata_i @@ -690,6 +688,8 @@ def _compute_cascade_attn_prefix_len( self, num_scheduled_tokens: np.ndarray, num_common_prefix_blocks: int, + kv_cache_spec: KVCacheSpec, + attn_metadata_builder: AttentionMetadataBuilder, ) -> int: """Compute the length of the common prefix for 
cascade attention. @@ -708,7 +708,7 @@ def _compute_cascade_attn_prefix_len( Returns: int: Length of common prefix in tokens. """ - common_prefix_len = num_common_prefix_blocks * self.block_size + common_prefix_len = num_common_prefix_blocks * kv_cache_spec.block_size if common_prefix_len == 0: # Common case. return 0 @@ -757,15 +757,19 @@ def _compute_cascade_attn_prefix_len( common_prefix_len, self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) # common_prefix_len should be a multiple of the block size. - common_prefix_len = (common_prefix_len // self.block_size * - self.block_size) - use_cascade = self.attn_metadata_builder.use_cascade_attention( + common_prefix_len = (common_prefix_len // kv_cache_spec.block_size * + kv_cache_spec.block_size) + use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or + (isinstance(kv_cache_spec, FullAttentionSpec) + and kv_cache_spec.sliding_window is not None)) + assert isinstance(kv_cache_spec, AttentionSpec) + use_cascade = attn_metadata_builder.use_cascade_attention( common_prefix_len=common_prefix_len, query_lens=num_scheduled_tokens, num_query_heads=self.num_query_heads, - num_kv_heads=self.num_kv_heads, + num_kv_heads=kv_cache_spec.num_kv_heads, use_alibi=self.use_alibi, - use_sliding_window=self.window_size is not None, + use_sliding_window=use_sliding_window, num_sms=self.num_sms, ) return common_prefix_len if use_cascade else 0 @@ -925,8 +929,11 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): encoder_outputs = [] for grouped_mm_inputs in grouped_mm_inputs_list: batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs) - batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs, - device=self.device) + batched_mm_inputs = MultiModalKwargs.as_kwargs( + batched_mm_inputs, + dtype=self.model_config.dtype, + device=self.device, + ) # Run the encoder. 
# `curr_group_outputs` is either of the following: @@ -1067,21 +1074,54 @@ def apply_grammar_bitmask( indices=out_indices, ) + def sync_and_slice_intermediate_tensors( + self, num_tokens: int, intermediate_tensors: IntermediateTensors, + sync_self: bool) -> IntermediateTensors: + + assert self.intermediate_tensors is not None + + tp = self.vllm_config.parallel_config.tensor_parallel_size + enabled_sp = self.vllm_config.compilation_config.pass_config. \ + enable_sequence_parallelism + if enabled_sp: + # When sequence parallelism is enabled, we always pad num_tokens + # to be a multiple of tensor_parallel_size (tp) earlier + assert num_tokens % tp == 0 + is_residual_scattered = tp > 1 and enabled_sp \ + and num_tokens % tp == 0 + + # When sequence parallelism is enabled, the "residual" tensor is sharded + # across tensor parallel ranks, so each rank only needs its own slice. + if sync_self: + assert intermediate_tensors is not None + for k, v in intermediate_tensors.items(): + is_scattered = "residual" and is_residual_scattered + copy_len = num_tokens // tp if is_scattered else \ + num_tokens + self.intermediate_tensors[k][:copy_len].copy_( + v[:copy_len], non_blocking=True) + + return IntermediateTensors({ + k: + v[:num_tokens // tp] + if k == "residual" and is_residual_scattered else v[:num_tokens] + for k, v in self.intermediate_tensors.items() + }) + @torch.inference_mode() def execute_model( self, scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[ModelRunnerOutput, IntermediateTensors]: - # Update KVConnector with the KVConnector metadata forward(). - if has_kv_transfer_group(): - get_kv_transfer_group().bind_connector_metadata( - scheduler_output.kv_connector_metadata) self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: - # Return empty ModelRunnerOutput if there's no work to do. 
- return EMPTY_MODEL_RUNNER_OUTPUT + if not has_kv_transfer_group(): + # Return empty ModelRunnerOutput if there's no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT + + return self.kv_connector_no_forward(scheduler_output) # Prepare the decoder inputs. attn_metadata, logits_indices, spec_decode_metadata = ( @@ -1114,7 +1154,7 @@ def execute_model( else: mm_embeds = [] - if self.is_multimodal_model: + if self.is_multimodal_model and get_pp_group().is_first_rank: # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. @@ -1143,39 +1183,57 @@ def execute_model( if get_pp_group().is_first_rank: intermediate_tensors = None else: - assert intermediate_tensors is not None - assert self.intermediate_tensors is not None - for k, v in intermediate_tensors.items(): - self.intermediate_tensors[k][:num_input_tokens].copy_( - v[:num_input_tokens], non_blocking=True) - intermediate_tensors = IntermediateTensors({ - k: v[:num_input_tokens] - for k, v in self.intermediate_tensors.items() - }) + intermediate_tensors = self.sync_and_slice_intermediate_tensors( + num_input_tokens, intermediate_tensors, True) # Run the decoder. # Use persistent buffers for CUDA graphs. 
with set_forward_context(attn_metadata, self.vllm_config, num_tokens=num_input_tokens): - output = self.model( + self.maybe_setup_kv_connector(scheduler_output) + + model_output = self.model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) + self.maybe_wait_for_kv_save() + finished_sending, finished_recving = ( + self.get_finished_kv_transfers(scheduler_output)) + if self.use_aux_hidden_state_outputs: - hidden_states, aux_hidden_states = output + hidden_states, aux_hidden_states = model_output else: - hidden_states = output - + hidden_states = model_output + # Broadcast PP output for external_launcher (torchrun) + # to make sure we are synced across pp ranks + # TODO: Support overlapping mirco-batches + # https://github.com/vllm-project/vllm/issues/18019 + broadcast_pp_output = \ + self.parallel_config.distributed_executor_backend \ + == "external_launcher" and len(get_pp_group().ranks) > 0 if not get_pp_group().is_last_rank: # For mid-pipeline stages, return the hidden states. 
- return hidden_states - - sample_hidden_states = hidden_states[logits_indices] - logits = self.model.compute_logits(sample_hidden_states, None) + if not broadcast_pp_output: + return hidden_states + assert isinstance(hidden_states, IntermediateTensors) + get_pp_group().send_tensor_dict(hidden_states.tensors, + all_gather_group=get_tp_group()) + logits = None + else: + sample_hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(sample_hidden_states, None) + if broadcast_pp_output: + model_output_broadcast_data = { + "logits": logits.contiguous(), + } if logits is not None else {} + model_output_broadcast_data = get_pp_group().broadcast_tensor_dict( + model_output_broadcast_data, src=len(get_pp_group().ranks) - 1) + assert model_output_broadcast_data is not None + logits = model_output_broadcast_data["logits"] # Apply structured output bitmasks if present if scheduler_output.grammar_bitmask is not None: @@ -1193,6 +1251,7 @@ def execute_model( # creates a new tensor with separate storage from the original # logits tensor. This means any in-place operations on bonus_logits # won't affect the original logits tensor. 
+ assert logits is not None bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] sampler_output = self.sampler( logits=bonus_logits, @@ -1266,6 +1325,27 @@ def execute_model( assert isinstance(self.drafter, NgramProposer) spec_token_ids = self.generate_draft_token_ids( valid_sampled_token_ids, sampling_metadata) + elif self.speculative_config.method == "medusa": + assert isinstance(self.drafter, MedusaProposer) + if max_gen_len == 1: + hidden_states = sample_hidden_states + else: + indices = [] + offset = 0 + for num_draft, tokens in zip( + spec_decode_metadata.num_draft_tokens, + valid_sampled_token_ids): + indices.append(offset + len(tokens) - 1) + offset += num_draft + 1 + + indices = torch.tensor(indices, + device=sample_hidden_states.device) + hidden_states = sample_hidden_states[indices] + + spec_token_ids = self.drafter.propose( + target_hidden_states=hidden_states, + sampling_metadata=sampling_metadata, + ) elif self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) # TODO(woosuk): Refactor the loop. @@ -1286,7 +1366,16 @@ def execute_model( next_token_ids = torch.tensor(next_token_ids, dtype=torch.int32, device=self.device) - eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name] + # At this moment, we assume all eagle layers belong to the same KV + # cache group, thus using the same attention metadata. + eagle_attn_metadata = attn_metadata[ + self.drafter.attn_layer_names[0]] + + # NOTE: deepseek_mtp uses MLA which does not have `block_table` + if hasattr(eagle_attn_metadata, "block_table"): + block_table = eagle_attn_metadata.block_table + else: + block_table = None if spec_decode_metadata is None: # input_ids can be None for multimodal models. 
@@ -1307,14 +1396,16 @@ def execute_model( n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 for i, n in enumerate(num_draft_tokens) ] - num_rejected_tokens = torch.tensor( + num_rejected_tokens_tensor = async_tensor_h2d( num_rejected_tokens, dtype=torch.int32, - device=self.device, - ) + target_device=self.device, + pin_memory=True) + num_tokens = num_scheduled_tokens - sum(num_rejected_tokens) cu_num_tokens, token_indices = self.drafter.prepare_inputs( eagle_attn_metadata.query_start_loc, - num_rejected_tokens, + num_rejected_tokens_tensor, + num_tokens, ) target_token_ids = self.input_ids[token_indices] target_positions = positions[token_indices] @@ -1325,7 +1416,6 @@ def execute_model( target_hidden_states = hidden_states[token_indices] target_slot_mapping = eagle_attn_metadata.slot_mapping[ token_indices] - draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, target_positions=target_positions, @@ -1333,7 +1423,7 @@ def execute_model( target_slot_mapping=target_slot_mapping, next_token_ids=next_token_ids, cu_num_tokens=cu_num_tokens, - block_table=eagle_attn_metadata.block_table, + block_table=block_table, sampling_metadata=sampling_metadata, ) spec_token_ids = draft_token_ids.tolist() @@ -1349,8 +1439,56 @@ def execute_model( spec_token_ids=spec_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, + finished_sending=finished_sending, + finished_recving=finished_recving, ) + def kv_connector_no_forward( + self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput: + # KV send/recv even if no work to do. 
+ with set_forward_context(None, self.vllm_config): + self.maybe_setup_kv_connector(scheduler_output) + finished_sending, finished_recving = ( + self.get_finished_kv_transfers(scheduler_output)) + + if not finished_sending and not finished_recving: + return EMPTY_MODEL_RUNNER_OUTPUT + + output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.finished_sending = finished_sending + output.finished_recving = finished_recving + return output + + @staticmethod + def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"): + # Update KVConnector with the KVConnector metadata forward(). + if has_kv_transfer_group(): + kv_connector = get_kv_transfer_group() + assert isinstance(kv_connector, KVConnectorBase_V1) + assert scheduler_output.kv_connector_metadata is not None + kv_connector.bind_connector_metadata( + scheduler_output.kv_connector_metadata) + + # Background KV cache transfers happen here. + # These transfers are designed to be async and the requests + # involved may be disjoint from the running requests. + # Do this here to save a collective_rpc. 
+ kv_connector.start_load_kv(get_forward_context()) + + @staticmethod + def maybe_wait_for_kv_save() -> None: + if has_kv_transfer_group(): + get_kv_transfer_group().wait_for_save() + + @staticmethod + def get_finished_kv_transfers( + scheduler_output: "SchedulerOutput", + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + if has_kv_transfer_group(): + return get_kv_transfer_group().get_finished( + scheduler_output.finished_req_ids) + return None, None + def generate_draft_token_ids( self, sampled_token_ids: list[list[int]], @@ -1411,6 +1549,16 @@ def load_model(self) -> None: logger.info("Model loading took %.4f GiB and %.6f seconds", self.model_memory_usage / GiB_bytes, time_after_load - time_before_load) + prepare_communication_buffer_for_model(self.model) + + def save_tensorized_model( + self, + tensorizer_config: "TensorizerConfig", + ) -> None: + TensorizerLoader.save_model( + self.model, + tensorizer_config=tensorizer_config, + ) def _get_prompt_logprobs_dict( self, @@ -1529,7 +1677,7 @@ def _dummy_run( dtype=np.int32) if skip_attn: - attn_metadata = None + attn_metadata: Optional[dict[str, FlashAttentionMetadata]] = None else: query_start_loc = self.query_start_loc[:num_reqs + 1] seq_lens = self.seq_lens[:num_reqs] @@ -1537,13 +1685,19 @@ def _dummy_run( common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, seq_lens=seq_lens) - attn_metadata = self.attn_metadata_builder.build( - num_reqs=num_tokens, - num_actual_tokens=num_tokens, - max_query_len=num_tokens, - common_prefix_len=0, - common_attn_metadata=common_attn_metadata, - ) + attn_metadata = {} + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups): + attn_metadata_i = ( + self.attn_metadata_builders[kv_cache_group_id].build( + num_reqs=num_tokens, + num_actual_tokens=num_tokens, + max_query_len=num_tokens, + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + )) + for layer_name in 
kv_cache_group_spec.layer_names: + attn_metadata[layer_name] = attn_metadata_i with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): @@ -1568,10 +1722,9 @@ def _dummy_run( batch_size=self.max_num_tokens, dtype=self.model_config.dtype, device=self.device)) - intermediate_tensors = IntermediateTensors({ - k: v[:num_tokens] - for k, v in self.intermediate_tensors.items() - }) + + intermediate_tensors = self.sync_and_slice_intermediate_tensors( + num_tokens, None, False) with set_forward_context(attn_metadata, self.vllm_config, @@ -1587,8 +1740,7 @@ def _dummy_run( else: hidden_states = outputs - if self.use_spec_decode and \ - self.speculative_config.method in ('eagle', 'eagle3'): + if self.use_spec_decode and self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) self.drafter.dummy_run(num_tokens) @@ -1600,6 +1752,10 @@ def _dummy_sampler_run( self, hidden_states: torch.Tensor, ) -> torch.Tensor: + # The dummy hidden states may contain special values, + # like `inf` or `nan`. + # To avoid breaking the sampler, we use a random tensor here instead. + hidden_states = torch.rand_like(hidden_states) logits = self.model.compute_logits(hidden_states, None) num_reqs = logits.size(0) @@ -1721,7 +1877,10 @@ def profile_run(self) -> None: batched_dummy_mm_inputs = MultiModalKwargs.batch( [dummy_mm_kwargs] * max_num_mm_items) batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs( - batched_dummy_mm_inputs, device=self.device) + batched_dummy_mm_inputs, + dtype=self.model_config.dtype, + device=self.device, + ) # Run multimodal encoder. dummy_encoder_outputs = self.model.get_multimodal_embeddings( @@ -1774,6 +1933,56 @@ def capture_model(self) -> None: logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, cuda_graph_size / (1 << 30)) + def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: + """ + Initialize the attention backends and attention metadata builders. 
+ """ + assert len(self.attn_backends) == 0 and len( + self.attn_metadata_builders + ) == 0, "Attention backends are already initialized" + for i, kv_cache_group_spec in enumerate( + kv_cache_config.kv_cache_groups): + kv_cache_spec = kv_cache_group_spec.kv_cache_spec + if not isinstance(kv_cache_spec, AttentionSpec): + raise NotImplementedError( + "Only AttentionSpec is supported for now.") + attn_backend_i = get_attn_backend( + kv_cache_spec.head_size, + self.dtype, + kv_cache_spec.dtype, + kv_cache_spec.block_size, + self.model_config.is_attention_free, + use_mla=kv_cache_spec.use_mla, + ) + if attn_backend_i is None: + error_msg = ( + f"Error with get_attn_backend: {kv_cache_spec.head_size=}, " + f"{self.dtype=}, {kv_cache_spec.dtype=}, " + f"{kv_cache_spec.block_size=}, " + f"{self.model_config.is_attention_free=}, " + f"{kv_cache_spec.use_mla=}") + logger.error(error_msg) + raise NotImplementedError( + "Non-Attention backend is not supported by V1 " + "GPUModelRunner.") + + if self.vllm_config.compilation_config.full_cuda_graph: + attn_backend_name = attn_backend_i.__name__ + flash_attn_version = get_flash_attn_version() + if attn_backend_name != "FlashAttentionBackend" or \ + flash_attn_version != 3: + raise ValueError( + f"full_cuda_graph is only supported with " + f"FA3. Current attention backend is " + f"{attn_backend_name}, FlashAttention version is " + f"{flash_attn_version}.") + + block_table_i = self.input_batch.block_table[i] + attn_metadata_builder_i = attn_backend_i.get_builder_cls()( + weakref.proxy(self), kv_cache_spec, block_table_i) + self.attn_backends.append(attn_backend_i) + self.attn_metadata_builders.append(attn_metadata_builder_i) + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. 
@@ -1786,10 +1995,11 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: "Hybrid models with more than one KV cache type are not " "supported yet.") self.kv_cache_config = kv_cache_config + self.initialize_attn_backend(kv_cache_config) kv_caches: dict[str, torch.Tensor] = {} - for kv_cache_group in kv_cache_config.kv_cache_groups: + for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): kv_cache_spec = kv_cache_group.kv_cache_spec for layer_name in kv_cache_group.layer_names: tensor_config = kv_cache_config.tensors[layer_name] @@ -1804,7 +2014,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: # the min of all `num_blocks`. Verify it here. assert num_blocks >= kv_cache_config.num_blocks if isinstance(kv_cache_spec, AttentionSpec): - kv_cache_shape = self.attn_backend.get_kv_cache_shape( + kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) dtype = kv_cache_spec.dtype @@ -1816,11 +2026,20 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: # KV cache specs. 
raise ValueError("Unknown KV cache spec type.") + if self.speculative_config and self.speculative_config.use_eagle(): + assert isinstance(self.drafter, EagleProposer) + # validate all draft model layers belong to the same kv cache + # group + self.drafter.validate_same_kv_cache_group(kv_cache_config) + bind_kv_cache( kv_caches, self.vllm_config.compilation_config.static_forward_context, self.kv_caches) + if has_kv_transfer_group(): + get_kv_transfer_group().register_kv_caches(kv_caches) + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ Generates the KVCacheSpec by parsing the kv cache format from each diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 5352b1c5a37c..bce5cbb5f9d0 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -31,6 +31,7 @@ logger = init_logger(__name__) if TYPE_CHECKING: + from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.v1.core.sched.output import SchedulerOutput @@ -171,10 +172,9 @@ def determine_available_memory(self) -> int: Then, it calculate the free memory that can be used for KV cache in bytes. - :::{tip} - You may limit the usage of GPU memory - by adjusting the `gpu_memory_utilization` parameter. - ::: + Tip: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. 
""" torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() @@ -275,13 +275,13 @@ def execute_model( output = self.model_runner.execute_model(scheduler_output, intermediate_tensors) - - if not get_pp_group().is_last_rank: + parallel_config = self.vllm_config.parallel_config + if parallel_config.distributed_executor_backend != "external_launcher" \ + and not get_pp_group().is_last_rank: assert isinstance(output, IntermediateTensors) get_pp_group().send_tensor_dict(output.tensors, all_gather_group=get_tp_group()) return None - assert isinstance(output, ModelRunnerOutput) return output if self.is_driver_worker else None @@ -326,6 +326,13 @@ def save_sharded_state( max_size=max_size, ) + def save_tensorized_model( + self, + tensorizer_config: "TensorizerConfig", + ) -> None: + self.model_runner.save_tensorized_model( + tensorizer_config=tensorizer_config, ) + def init_worker_distributed_environment( vllm_config: VllmConfig, diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index be059c30435c..46bcf64ed0c3 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -171,18 +171,10 @@ def __init__( self.kv_caches: list[torch.Tensor] = [] # req_id -> (input_id -> encoder_output) self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} + # self.input_batch: InputBatch # Persistent batch. # Request states. self.requests: dict[str, CachedRequestState] = {} - # Persistent batch. - self.input_batch = InputBatch( - max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, - max_num_blocks_per_req=self.max_num_blocks_per_req, - device=self.device, - pin_memory=self.pin_memory, - vocab_size=self.vocab_size, - ) # Cached torch/numpy tensor # The pytorch tensor and numpy array share the same buffer. 
@@ -190,20 +182,15 @@ def __init__( self.input_ids_cpu = torch.zeros(self.max_num_tokens, dtype=torch.int32, device="cpu") - self.input_ids_np = self.input_ids_cpu.numpy() self.positions_cpu = torch.zeros(self.max_num_tokens, dtype=torch.int32, device="cpu") self.positions_np = self.positions_cpu.numpy() - self.slot_mapping_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device="cpu") - self.slot_mapping_np = self.slot_mapping_cpu.numpy() self.block_table_cpu = torch.zeros( (self.max_num_reqs, self.max_num_blocks_per_req), - dtype=self.input_batch.block_table.get_cpu_tensor().dtype, + dtype=torch.int32, device="cpu") self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1, @@ -528,12 +515,13 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. - block_table_cpu = self.input_batch.block_table.get_cpu_tensor() + block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor() block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() block_offsets = positions_np % self.block_size np.add(block_numbers * self.block_size, block_offsets, - out=self.slot_mapping_np[:total_num_scheduled_tokens]) + out=self.input_batch.block_table[0]. + slot_mapping_np[:total_num_scheduled_tokens]) # Prepare the attention metadata. self.query_start_loc_np[0] = 0 @@ -557,13 +545,15 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): self.position_ids = self.positions_cpu[: padded_total_num_scheduled_tokens].to( self.device) - self.slot_mapping_cpu[total_num_scheduled_tokens:] = _PAD_SLOT_ID - slot_mapping = self.slot_mapping_cpu[: - padded_total_num_scheduled_tokens].to( - self.device) + self.input_batch.block_table[0].slot_mapping_cpu[ + total_num_scheduled_tokens:] = _PAD_SLOT_ID + slot_mapping = ( + self.input_batch.block_table[0]. 
+ slot_mapping_cpu[:padded_total_num_scheduled_tokens].to( + self.device)) block_tables = self.block_table_cpu[:self.max_num_reqs] block_tables[:num_reqs, :self.max_num_blocks_per_req] = ( - self.input_batch.block_table.get_cpu_tensor()[:num_reqs]) + self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs]) block_tables = block_tables.to(self.device) query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to( self.device) @@ -662,8 +652,11 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): encoder_outputs = [] for grouped_mm_inputs in grouped_mm_inputs_list: batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs) - batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs, - device=self.device) + batched_mm_inputs = MultiModalKwargs.as_kwargs( + batched_mm_inputs, + dtype=self.model_config.dtype, + device=self.device, + ) # Run the encoder. # `curr_group_outputs` is either of the following: @@ -1264,6 +1257,19 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: "Hybrid models with more than one KV cache type are not " "supported yet.") + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_batched_tokens=self.max_num_tokens, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.model_config.get_vocab_size(), + block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec. 
+ block_size, + ) + assert self.block_table_cpu.dtype == self.input_batch.block_table[ + 0].get_cpu_tensor().dtype + kv_caches: dict[str, torch.Tensor] = {} for kv_cache_group in kv_cache_config.kv_cache_groups: @@ -1432,8 +1438,11 @@ def _get_mm_dummy_batch(self, modality: str, batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] * batch_size) - return MultiModalKwargs.as_kwargs(batched_dummy_mm_inputs, - device=self.device) + return MultiModalKwargs.as_kwargs( + batched_dummy_mm_inputs, + dtype=self.model_config.dtype, + device=self.device, + ) def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]: diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 9eea26d85249..fa4eb30ccd9a 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -266,3 +266,11 @@ def init_tpu_worker_distributed_environment( ) ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) + + +try: + from tpu_commons.worker import TPUWorker as TPUCommonsWorker + TPUWorker = TPUCommonsWorker # type: ignore +except ImportError: + logger.info("tpu_commons not found, using vLLM's TPUWorker.") + pass diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 267754036b31..91548a52cfc7 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -10,7 +10,7 @@ def sanity_check_mm_encoder_outputs( ) -> None: """ Perform sanity checks for the result of - {meth}`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`. + [`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`][]. """ assert isinstance(mm_embeddings, (list, tuple, torch.Tensor)), ( "Expected multimodal embeddings to be a list/tuple of 2D tensors, " @@ -39,7 +39,7 @@ def scatter_mm_placeholders( Scatter the multimodal embeddings into a contiguous tensor that represents the placeholder tokens. 
- {class}`vllm.multimodal.processing.PromptUpdateDetails.is_embed`. + [`vllm.multimodal.processing.PromptUpdateDetails.is_embed`][]. Args: embeds: The multimodal embeddings. @@ -66,7 +66,7 @@ def gather_mm_placeholders( """ Reconstructs the embeddings from the placeholder tokens. - This is the operation of {func}`scatter_mm_placeholders`. + This is the operation of [scatter_mm_placeholders][]. """ if is_embed is None: return placeholders diff --git a/vllm/worker/cpu_enc_dec_model_runner.py b/vllm/worker/cpu_enc_dec_model_runner.py index c2120c035175..82eeeb570d22 100644 --- a/vllm/worker/cpu_enc_dec_model_runner.py +++ b/vllm/worker/cpu_enc_dec_model_runner.py @@ -297,8 +297,11 @@ def execute_model( model_input.encoder_input_tokens, "encoder_positions": model_input.encoder_input_positions, - **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, - device=self.device), + **MultiModalKwargs.as_kwargs( + model_input.multi_modal_kwargs or {}, + dtype=self.model_config.dtype, + device=self.device, + ), "intermediate_tensors": intermediate_tensors, } diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 710ca1a13b0c..fb436a079f87 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -628,7 +628,10 @@ def execute_model( multimodal_kwargs = {} if model_input.multi_modal_kwargs is not None: multimodal_kwargs = MultiModalKwargs.as_kwargs( - model_input.multi_modal_kwargs, device=self.device) + model_input.multi_modal_kwargs, + dtype=self.model_config.dtype, + device=self.device, + ) execute_model_kwargs = {} if previous_hidden_states is not None: execute_model_kwargs.update( diff --git a/vllm/worker/cpu_pooling_model_runner.py b/vllm/worker/cpu_pooling_model_runner.py index 1ceb2557c6b3..2a60e51261ad 100644 --- a/vllm/worker/cpu_pooling_model_runner.py +++ b/vllm/worker/cpu_pooling_model_runner.py @@ -50,8 +50,11 @@ def execute_model( model_input.input_tokens, "positions": model_input.input_positions, 
- **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, - device=self.device), + **MultiModalKwargs.as_kwargs( + model_input.multi_modal_kwargs or {}, + dtype=self.model_config.dtype, + device=self.device, + ), **cross_enc_kwargs, "intermediate_tensors": intermediate_tensors, diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 4864163b0de2..3957e5608524 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -202,9 +202,13 @@ def execute_model( encoder_input_ids=model_input.encoder_input_tokens, encoder_positions=model_input.encoder_input_positions, intermediate_tensors=intermediate_tensors, - **MultiModalKwargs.as_kwargs(multi_modal_kwargs, - device=self.device), - **seqlen_agnostic_kwargs) + **MultiModalKwargs.as_kwargs( + multi_modal_kwargs, + dtype=self.model_config.dtype, + device=self.device, + ), + **seqlen_agnostic_kwargs, + ) logits = self.model.compute_logits(hidden_or_intermediate_states, model_input.sampling_metadata) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index a343e2fedb23..e2261cbb26b4 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -1554,10 +1554,8 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: 'Please update Gaudi Software Suite.') with compile_only_mode_context( ) if can_use_compile_only_mode else contextlib.nullcontext(): - print("aa") self.warmup_all_buckets(self.bucketing_ctx.prompt_buckets, True, kv_caches) - print("bb") self.warmup_all_buckets(self.bucketing_ctx.decode_buckets, False, kv_caches) diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py index 7898c645d66a..533fead0e669 100644 --- a/vllm/worker/hpu_worker.py +++ b/vllm/worker/hpu_worker.py @@ -201,10 +201,9 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: Then, it calculate the maximum possible number of GPU and CPU blocks that can be allocated with the remaining free 
memory. - :::{tip} - You may limit the usage of GPU memory - by adjusting the `gpu_memory_utilization` parameter. - ::: + Tip: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. """ # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory. diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index d96021cc688e..8c968faa7810 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -23,7 +23,7 @@ from vllm.attention.backends.utils import CommonAttentionState from vllm.config import CompilationLevel, VllmConfig from vllm.core.scheduler import SchedulerOutputs -from vllm.distributed import get_pp_group +from vllm.distributed import broadcast_tensor_dict, get_pp_group from vllm.distributed.kv_transfer import get_kv_transfer_group from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank, graph_capture) @@ -204,6 +204,7 @@ def simple_reinit(self): self.mrope_input_positions = None # type: ignore self.seq_lens[0] = 0 # type: ignore self.orig_seq_lens[0] = 0 # type: ignore + self.prompt_lens[0] = 0 # type: ignore self.query_lens[0] = 0 # type: ignore self.context_lens[0] = 0 # type: ignore self.curr_sliding_window_blocks[0] = 0 # type: ignore @@ -236,6 +237,8 @@ def __init__( # The original sequence length (before applying sliding window). # This is used to compute slot mapping. orig_seq_lens: Optional[List[int]] = None, + # This is used in the dual-chunk flash attention backend. + prompt_lens: Optional[List[int]] = None, # The query length. query_lens: Optional[List[int]] = None, # The number of tokens that are already computed. 
@@ -316,6 +319,12 @@ def __init__( for seq_id in range(len(self.seq_ids)): self.orig_seq_lens[seq_id] = 0 + if prompt_lens: + self.prompt_lens = prompt_lens + else: + for seq_id in range(len(self.seq_ids)): + self.prompt_lens[seq_id] = 0 + if query_lens: self.query_lens = query_lens else: @@ -370,6 +379,7 @@ def __init__( self.mrope_input_positions = mrope_input_positions or None self.seq_lens = seq_lens or [] self.orig_seq_lens = orig_seq_lens or [] + self.prompt_lens = prompt_lens or [] self.query_lens = query_lens or [] self.context_lens = context_lens or [] self.curr_sliding_window_blocks = \ @@ -403,6 +413,7 @@ def __post_init__(self): self.mrope_input_positions = None self.seq_lens = [0] * self.n_seqs self.orig_seq_lens = [0] * self.n_seqs + self.prompt_lens = [0] * self.n_seqs self.query_lens = [0] * self.n_seqs self.context_lens = [0] * self.n_seqs self.curr_sliding_window_blocks = [0] * self.n_seqs @@ -552,6 +563,7 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, inter_data.seq_lens[seq_idx] = seq_len inter_data.orig_seq_lens[seq_idx] = seq_len + inter_data.prompt_lens[seq_idx] = seq_data.get_prompt_len() inter_data.context_lens[seq_idx] = context_len inter_data.input_tokens[seq_idx].extend(tokens) inter_data.inputs_embeds = prompt_embeds @@ -717,7 +729,10 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( seq_group_metadata, range(positions[0], positions[0] + len(positions))) - if not mm_kwargs: + + # M-RoPE requires mrope_positions even for plain text; return early + # when mm_kwargs is empty only if inter_data.is_prompt is False. 
+ if not mm_kwargs and not inter_data.is_prompt: return inter_data.multi_modal_kwargs = mm_kwargs @@ -729,12 +744,6 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, video_grid_thw = mm_kwargs.get("video_grid_thw", None) audio_feature_lengths = mm_kwargs.get("audio_feature_lengths", None) - assert ( - image_grid_thw is not None or video_grid_thw is not None - or audio_feature_lengths is not None), ( - "mrope embedding type requires multi-modal input mapper " - "returns 'image_grid_thw' or 'video_grid_thw' or " - "'audio_feature_lengths'.") second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None) use_audio_in_video = mm_kwargs.get("use_audio_in_video", False) @@ -860,7 +869,7 @@ def build(self) -> ModelInputForGPU: """ # Combine and flatten intermediate data. input_tokens = list[int]() - inputs_embeds_lst = list[torch.Tensor]() + inputs_embeds_list = list[torch.Tensor]() token_types = list[int]() for inter_data in self.inter_data_list: for cur_input_tokens in inter_data.input_tokens: @@ -868,15 +877,15 @@ def build(self) -> ModelInputForGPU: for cur_token_types in inter_data.token_types: token_types.extend(cur_token_types) if inter_data.inputs_embeds is not None: - inputs_embeds_lst.append( + inputs_embeds_list.append( inter_data.inputs_embeds.to( dtype=self.runner.model_config.dtype, device=self.runner.device)) inputs_embeds: Optional[torch.Tensor] - if len(inputs_embeds_lst) == 0: + if len(inputs_embeds_list) == 0: inputs_embeds = None else: - inputs_embeds = torch.cat(inputs_embeds_lst, dim=0).to( + inputs_embeds = torch.cat(inputs_embeds_list, dim=0).to( dtype=self.runner.model_config.dtype, device=self.runner.device) assert len(inputs_embeds) == len(input_tokens) @@ -1836,8 +1845,11 @@ def execute_model( inputs_embeds=model_input.inputs_embeds, positions=model_input.input_positions, intermediate_tensors=intermediate_tensors, - **MultiModalKwargs.as_kwargs(multi_modal_kwargs, - device=self.device), + **MultiModalKwargs.as_kwargs( + 
multi_modal_kwargs, + dtype=self.model_config.dtype, + device=self.device, + ), **seqlen_agnostic_kwargs, **model_kwargs, ) @@ -1881,50 +1893,60 @@ def execute_model( logits = self.model.compute_logits(hidden_or_intermediate_states, model_input.sampling_metadata) - if not self.is_driver_worker: - return [] + if self.is_driver_worker: + if model_input.async_callback is not None: + model_input.async_callback() - if model_input.async_callback is not None: - model_input.async_callback() + # Sample the next token. + assert isinstance(self.sampler, Sampler) + orig_include_gpu_probs = self.sampler.include_gpu_probs_tensor + if model_input.inputs_embeds is not None: + self.sampler.include_gpu_probs_tensor = True - # Sample the next token. - assert isinstance(self.sampler, Sampler) - orig_include_gpu_probs_tensor = self.sampler.include_gpu_probs_tensor - if model_input.inputs_embeds is not None: - self.sampler.include_gpu_probs_tensor = True - - output: SamplerOutput = self.sampler( - logits=logits, - sampling_metadata=model_input.sampling_metadata, - ) - if (self.observability_config is not None - and self.observability_config.collect_model_forward_time - and output is not None): - model_forward_end.synchronize() - model_forward_time = model_forward_start.elapsed_time( - model_forward_end) - orig_model_forward_time = 0.0 - if intermediate_tensors is not None: - orig_model_forward_time = intermediate_tensors.tensors.get( - "model_forward_time", torch.tensor(0.0)).item() - # If there are multiple workers, we are still tracking the latency - # from the start time of the driver worker to the end time of the - # driver worker. The model forward time will then end up covering - # the communication time as well. 
- output.model_forward_time = (orig_model_forward_time + - model_forward_time) + output: SamplerOutput = self.sampler( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time + and output is not None): + model_forward_end.synchronize() + model_forward_time = model_forward_start.elapsed_time( + model_forward_end) + orig_model_forward_time = 0.0 + if intermediate_tensors is not None: + orig_model_forward_time = intermediate_tensors.tensors.get( + "model_forward_time", torch.tensor(0.0)).item() + # If there are multiple workers, we are still tracking the + # latency from the start time of the driver worker to the end + # time of the driver worker. The model forward time will then + # end up covering the communication time as well. + output.model_forward_time = (orig_model_forward_time + + model_forward_time) if model_input.inputs_embeds is not None: - self.sampler.include_gpu_probs_tensor = \ - orig_include_gpu_probs_tensor - if output.sampled_token_ids is not None: - output.sampled_token_embeds = self.model.get_input_embeddings( - output.sampled_token_ids.squeeze(1)) - - for token_embed, sequence_group_output in zip( - output.sampled_token_embeds, output.outputs): - assert len(sequence_group_output.samples) == 1 - sequence_group_output.samples[0].output_embed = token_embed + if self.is_driver_worker: + sampled = broadcast_tensor_dict( + {"token_ids": output.sampled_token_ids}) + else: + sampled = broadcast_tensor_dict() + if sampled["token_ids"] is not None: + sampled_token_embeds = self.model.get_input_embeddings( + sampled["token_ids"].squeeze(1)) + if self.is_driver_worker: + self.sampler.include_gpu_probs_tensor = \ + orig_include_gpu_probs + + output.sampled_token_embeds = sampled_token_embeds + + for token_embed, sequence_group_output in zip( + output.sampled_token_embeds, output.outputs): + assert len(sequence_group_output.samples) == 1 + 
sequence_group_output.samples[ + 0].output_embed = token_embed + + if not self.is_driver_worker: + return [] if self.return_hidden_states: # we only need to pass hidden states of most recent token diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 0825abbed143..f8d5acf586c5 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -733,12 +733,13 @@ def _pythonize_sampler_output( logprobs_tensor: Optional[torch.Tensor], cache: Optional[PythonizationCache], ) -> None: - """ This function is only called when the output tensors are ready. - See {class}`ModelOutput`. - - Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place, + """ This function is only called when the output tensors are ready. + See [`ModelOutput`][vllm.worker.multi_step_model_runner.ModelOutput]. + + Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place, adding a Pythonized output data structure - ({class}`CompletionSequenceGroupOutput`) for each {class}`SequenceGroup`. + ([`CompletionSequenceGroupOutput`][vllm.sequence.CompletionSequenceGroupOutput]) + for each [`SequenceGroup`][vllm.sequence.SequenceGroup]. 
Args: model_input @@ -824,7 +825,7 @@ def _pythonize_sampler_output( for sgdx, (seq_group, sample_result) in enumerate(zip(seq_groups, samples_list)): - # Reminder: Please update docs/source/features/compatibility_matrix.md + # Reminder: Please update docs/features/compatibility_matrix.md # If the feature combo become valid # (Check for Guided Decoding) if seq_group.sampling_params.logits_processors: diff --git a/vllm/worker/multi_step_neuron_model_runner.py b/vllm/worker/multi_step_neuron_model_runner.py index 9618a4b49ff8..aafb7ab7cfb8 100644 --- a/vllm/worker/multi_step_neuron_model_runner.py +++ b/vllm/worker/multi_step_neuron_model_runner.py @@ -70,8 +70,11 @@ def execute_model( input_ids=model_input.input_tokens, positions=model_input.input_positions, input_block_ids=model_input.input_block_ids, - **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, - device=self.device), + **MultiModalKwargs.as_kwargs( + model_input.multi_modal_kwargs or {}, + dtype=self.model_config.dtype, + device=self.device, + ), ) output = self.model.sample( diff --git a/vllm/worker/multi_step_neuronx_distributed_model_runner.py b/vllm/worker/multi_step_neuronx_distributed_model_runner.py index b6a3492a493b..3a9c0993e004 100644 --- a/vllm/worker/multi_step_neuronx_distributed_model_runner.py +++ b/vllm/worker/multi_step_neuronx_distributed_model_runner.py @@ -49,8 +49,11 @@ def execute_model( positions=model_input.input_positions, input_block_ids=model_input.input_block_ids, sampling_params=sampling_params, - **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, - device=self.device), + **MultiModalKwargs.as_kwargs( + model_input.multi_modal_kwargs or {}, + dtype=self.model_config.dtype, + device=self.device, + ), ) output = self.model.sample( diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index c80b69e78dc0..968596471a26 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -348,7 
+348,7 @@ def _convert_to_neuron_sampling_params( if temperature == 0.0: # Enable greedy sampling on zero temperature return (1, 1.0, 1.0) - if top_k < 0 or top_k > self._MAX_NEURON_SAMPLING_TOP_K: + if top_k < 1 or top_k > self._MAX_NEURON_SAMPLING_TOP_K: top_k = self._MAX_NEURON_SAMPLING_TOP_K return (top_k, top_p, temperature) @@ -378,9 +378,11 @@ def execute_model( positions=model_input.input_positions, input_block_ids=model_input.input_block_ids, sampling_params=sampling_params, - **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs - or {}, - device=self.device), + **MultiModalKwargs.as_kwargs( + model_input.multi_modal_kwargs or {}, + dtype=self.model_config.dtype, + device=self.device, + ), ) elif current_platform.use_transformers_neuronx(): # [TODO] validate on-device sampling @@ -389,9 +391,11 @@ def execute_model( input_ids=model_input.input_tokens, positions=model_input.input_positions, input_block_ids=model_input.input_block_ids, - **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs - or {}, - device=self.device), + **MultiModalKwargs.as_kwargs( + model_input.multi_modal_kwargs or {}, + dtype=self.model_config.dtype, + device=self.device, + ), ) # Compute the logits only if the on-device sampling is turned off as diff --git a/vllm/worker/pooling_model_runner.py b/vllm/worker/pooling_model_runner.py index fdb7353f2f9c..912e04c435f5 100644 --- a/vllm/worker/pooling_model_runner.py +++ b/vllm/worker/pooling_model_runner.py @@ -119,10 +119,14 @@ def execute_model( input_ids=model_input.input_tokens, positions=model_input.input_positions, intermediate_tensors=intermediate_tensors, - **MultiModalKwargs.as_kwargs(multi_modal_kwargs, - device=self.device), + **MultiModalKwargs.as_kwargs( + multi_modal_kwargs, + dtype=self.model_config.dtype, + device=self.device, + ), **cross_enc_kwargs, - **seqlen_agnostic_kwargs) + **seqlen_agnostic_kwargs, + ) if (self.observability_config is not None and self.observability_config.collect_model_forward_time): 
diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 53541a2579ed..e0cca9072745 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -525,7 +525,7 @@ def _prepare_sample( "Top-p sampling is currently disabled for the TPU backend " "due to performance issues.") p.append(sampling_params.top_p) - if sampling_params.top_k != -1: + if sampling_params.top_k > 0: raise NotImplementedError( "Top-k sampling is currently disabled for the TPU backend " "due to performance issues.") diff --git a/vllm/worker/utils.py b/vllm/worker/utils.py index d925f088357b..e2854bcb37ce 100644 --- a/vllm/worker/utils.py +++ b/vllm/worker/utils.py @@ -14,7 +14,7 @@ def assert_enc_dec_mr_supported_scenario( a supported scenario. ''' - # Reminder: Please update docs/source/features/compatibility_matrix.md + # Reminder: Please update docs/features/compatibility_matrix.md # If the feature combo become valid if enc_dec_mr.cache_config.enable_prefix_caching: diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 1a14919ddfb2..6e45b8423e5e 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -71,7 +71,11 @@ def __init__( or (speculative_config.draft_model_config.hf_config.model_type == model_config.hf_config.model_type) \ or (speculative_config.draft_model_config.hf_config.model_type - not in ("medusa", "mlp_speculator", "eagle", "deepseek_mtp")) \ + not in ("medusa", + "mlp_speculator", + "eagle", + "deepseek_mtp", + "mimo_mtp")) \ else {"return_hidden_states": True} ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner @@ -230,10 +234,9 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: Then, it calculate the maximum possible number of GPU and CPU blocks that can be allocated with the remaining free memory. - :::{tip} - You may limit the usage of GPU memory - by adjusting the `gpu_memory_utilization` parameter. 
- ::: + Tip: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. """ # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory. diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 7042b575aa78..79fa7d2c73e8 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -562,9 +562,12 @@ def execute_model( input_ids=model_input.input_tokens, positions=model_input.input_positions, intermediate_tensors=intermediate_tensors, - **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs - or {}, - device=self.device)) + **MultiModalKwargs.as_kwargs( + model_input.multi_modal_kwargs or {}, + dtype=self.model_config.dtype, + device=self.device, + ), + ) # Compute the logits in the last pipeline stage. if not get_pp_group().is_last_rank: return hidden_or_intermediate_states diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py index 17f533525171..a5109a982cbf 100644 --- a/vllm/worker/xpu_worker.py +++ b/vllm/worker/xpu_worker.py @@ -93,10 +93,9 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: Then, it calculate the maximum possible number of GPU and CPU blocks that can be allocated with the remaining free memory. - :::{tip} - You may limit the usage of GPU memory - by adjusting the `gpu_memory_utilization` parameter. - ::: + Tip: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. """ # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory.