Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tests/compile/fusions_e2e/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

# Get the compile ranges split points after vllm config post init
# Get the compile ranges endpoints after vllm config post init
# in order to compute compile ranges correctly
compilation_config.compile_ranges_split_points = (
llm.llm_engine.vllm_config.compilation_config.compile_ranges_split_points
compilation_config.compile_ranges_endpoints = (
llm.llm_engine.vllm_config.compilation_config.compile_ranges_endpoints
)


Expand Down
6 changes: 3 additions & 3 deletions tests/compile/test_compile_ranges.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_compile_ranges(use_fresh_inductor_cache):
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
compile_ranges_split_points=[8, 32],
compile_ranges_endpoints=[8, 32],
compile_sizes=[16, 64, 128],
inductor_compile_config={
"post_grad_custom_post_pass": post_grad_range_checker,
Expand All @@ -110,7 +110,7 @@ def test_compile_ranges(use_fresh_inductor_cache):

def test_compile_config_get_compile_ranges():
compilation_config = CompilationConfig(
compile_ranges_split_points=[8, 32],
compile_ranges_endpoints=[8, 32],
)
VllmConfig(
scheduler_config=SchedulerConfig(
Expand Down Expand Up @@ -149,7 +149,7 @@ def create_vllm_config():
scheduler_config=scheduler_config,
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
compile_ranges_split_points=[8],
compile_ranges_endpoints=[8],
inductor_compile_config={
"post_grad_custom_post_pass": post_grad_range_checker,
},
Expand Down
4 changes: 2 additions & 2 deletions vllm/compilation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,8 +830,8 @@ def list_to_str(lst: list | None) -> str:
"splitting_ops": list_to_str(cc.splitting_ops),
"cudagraph_mode": str(cc.cudagraph_mode),
"compile_sizes": list_to_str(cc.compile_sizes),
"compile_ranges_split_points": list_to_str(
cc.compile_ranges_split_points
"compile_ranges_endpoints": list_to_str(
cc.compile_ranges_endpoints
),
"use_inductor_graph_partition": cc.use_inductor_graph_partition,
"inductor_passes": list_to_str(list(cc.inductor_passes.keys())),
Expand Down
21 changes: 10 additions & 11 deletions vllm/config/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,8 +382,8 @@ class CompilationConfig:
[vllm.config.CompilationConfig.cudagraph_copy_inputs]
- Inductor compilation:
- [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes]
- [`compile_ranges_split_points`]
[vllm.config.CompilationConfig.compile_ranges_split_points]
- [`compile_ranges_endpoints`]
[vllm.config.CompilationConfig.compile_ranges_endpoints]
- [`inductor_compile_config`]
[vllm.config.CompilationConfig.inductor_compile_config]
- [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes]
Expand Down Expand Up @@ -492,12 +492,12 @@ class CompilationConfig:
to integers, it also supports "cudagraph_capture_sizes" to
specify the sizes for cudagraph capture."""

compile_ranges_split_points: list[int] | None = None
"""Split points that represent compile ranges for inductor.
compile_ranges_endpoints: list[int] | None = None
"""Endpoints for Inductor compile ranges.
The compile ranges are
[1, split_points[0]],
[split_points[0] + 1, split_points[1]], ...,
[split_points[-1] + 1, max_num_batched_tokens].
[1, endpoints[0]],
[endpoints[0] + 1, endpoints[1]], ...,
[endpoints[-1] + 1, max_num_batched_tokens].
Compile sizes are also used as single-element ranges;
each such range is represented as [compile_sizes[i], compile_sizes[i]].

Expand Down Expand Up @@ -1246,10 +1246,9 @@ def adjust_cudagraph_sizes_for_mamba_cache(

def get_compile_ranges(self) -> list[Range]:
"""Get the compile ranges for the compilation config."""
if self.compile_ranges_split_points is None:
if self.compile_ranges_endpoints is None:
return []
split_points = sorted(set(self.compile_ranges_split_points))
endpoints = sorted(set(self.compile_ranges_endpoints))
return [
Range(start=s + 1, end=e)
for s, e in zip([0] + split_points[:-1], split_points)
Range(start=s + 1, end=e) for s, e in zip([0] + endpoints[:-1], endpoints)
]
24 changes: 12 additions & 12 deletions vllm/config/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1496,12 +1496,12 @@ def _set_compile_ranges(self):
Set the compile ranges for the compilation config.
"""
compilation_config = self.compilation_config
computed_compile_ranges_split_points = []
computed_compile_ranges_endpoints = []

# The upper bound of the compile ranges is the max_num_batched_tokens.
compile_range_end = self.scheduler_config.max_num_batched_tokens
if compile_range_end is not None:
computed_compile_ranges_split_points.append(compile_range_end)
computed_compile_ranges_endpoints.append(compile_range_end)

# Add the compile ranges for flashinfer
if compilation_config.pass_config.fuse_allreduce_rms:
Expand All @@ -1513,7 +1513,7 @@ def _set_compile_ranges(self):
* self.model_config.dtype.itemsize
)
if compile_range_end is not None and max_token_num < compile_range_end:
computed_compile_ranges_split_points.append(max_token_num)
computed_compile_ranges_endpoints.append(max_token_num)
else:
logger.debug(
"Max num batched tokens below allreduce-rms fusion threshold, "
Expand Down Expand Up @@ -1545,33 +1545,33 @@ def _set_compile_ranges(self):
and min_token_num < max_num_batched_tokens
and min_token_num > 1
):
# Add split point at min_token_num - 1 to ensure SP applies
# Add endpoint at min_token_num - 1 to ensure SP applies
# starting from min_token_num
# This creates ranges: [1, min-1] (no SP), [min, max] (SP applies)
computed_compile_ranges_split_points.append(min_token_num - 1)
computed_compile_ranges_endpoints.append(min_token_num - 1)

if compilation_config.pass_config.fuse_rope_kvcache:
max_token_num = (
compilation_config.pass_config.rope_kvcache_fusion_max_token_num
)
if max_token_num is not None:
if compile_range_end is not None and max_token_num < compile_range_end:
computed_compile_ranges_split_points.append(max_token_num)
computed_compile_ranges_endpoints.append(max_token_num)
else:
logger.debug(
"Max num batched tokens below rope+kvcache fusion threshold, "
"rope+kvcache fusion enabled for num_tokens <= %d.",
compile_range_end,
)

if compilation_config.compile_ranges_split_points is not None:
for x in compilation_config.compile_ranges_split_points:
if compilation_config.compile_ranges_endpoints is not None:
for x in compilation_config.compile_ranges_endpoints:
assert isinstance(x, int)
assert x > 0, f"Invalid compile range split point: {x}"
assert x > 0, f"Invalid compile range endpoint: {x}"
if compile_range_end is not None and x < compile_range_end and x > 1:
computed_compile_ranges_split_points.append(x)
compilation_config.compile_ranges_split_points = sorted(
computed_compile_ranges_split_points
computed_compile_ranges_endpoints.append(x)
compilation_config.compile_ranges_endpoints = sorted(
computed_compile_ranges_endpoints
)

def try_verify_and_update_config(self):
Expand Down