tests/compile/test_compile_ranges.py (82 additions, 0 deletions)
@@ -127,6 +127,88 @@ def test_compile_config_get_compile_ranges():
]


class PostGradStaticShapeChecker(InductorPass):
"""Asserts that compile_sizes entries produce graphs with fully concrete
(non-symbolic) shapes, and compile_ranges entries have symbolic shapes."""

def __init__(self):
self.num_static_calls = 0
self.num_dynamic_calls = 0

def __call__(self, graph: fx.Graph):
from torch.fx.experimental.symbolic_shapes import is_symbolic

compile_range = get_pass_context().compile_range
is_single = compile_range.is_single_size()

for node in graph.nodes:
val = node.meta.get("val")
if val is None:
val = node.meta.get("example_value")
if isinstance(val, torch.Tensor):
has_symbolic = any(is_symbolic(d) for d in val.shape)
if is_single:
assert not has_symbolic, (
f"compile_sizes entry {compile_range}: "
f"node '{node.name}' has symbolic shape "
f"{val.shape}"
)
else:
# compile_ranges should have at least some
# symbolic shapes (the batch dimension)
if has_symbolic:
self.num_dynamic_calls += 1
return

if is_single:
self.num_static_calls += 1

def uuid(self) -> str:
state: dict[str, Any] = {}
return InductorPass.hash_dict(state)


def test_compile_sizes_produce_static_shapes(use_fresh_inductor_cache):
"""Verify that compile_sizes entries are compiled with fully concrete
shapes (no SymInts), while compile_ranges entries retain dynamic shapes."""
checker = PostGradStaticShapeChecker()
torch.set_default_device("cuda")
vllm_config = VllmConfig(
scheduler_config=SchedulerConfig(
max_num_batched_tokens=8192,
max_model_len=8192,
is_encoder_decoder=False,
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
compile_ranges_endpoints=[8],
compile_sizes=[16],
inductor_compile_config={
"post_grad_custom_post_pass": checker,
},
),
)

with set_current_vllm_config(vllm_config):
model = TestModel(vllm_config=vllm_config, prefix="").eval()
# 3 compilations: Range(1,8), Range(9,8192), single-size 16
with compilation_counter.expect(
num_graphs_seen=1,
num_piecewise_graphs_seen=1,
num_backend_compilations=3,
):
run_model(vllm_config, model, [1, 16, 64])

# compile_sizes=16 should produce static shapes
assert checker.num_static_calls == 1, (
f"Expected 1 static compilation, got {checker.num_static_calls}"
)
# compile_ranges should produce dynamic shapes
assert checker.num_dynamic_calls == 2, (
f"Expected 2 dynamic compilations, got {checker.num_dynamic_calls}"
)


def test_inductor_cache_compile_ranges(monkeypatch, use_fresh_inductor_cache):
# To force multiple compilations, we disable the compile cache
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
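The test above expects three backend compilations for compile_ranges_endpoints=[8], compile_sizes=[16], and max_num_batched_tokens=8192. A minimal sketch (not part of this diff) of how endpoints split the token range into the two ranges named in the test comment; ranges_from_endpoints is an illustrative helper, not vLLM's actual implementation:

# Hedged sketch: how compile_ranges_endpoints=[8] with max_num_batched_tokens=8192
# yields the Range(1, 8) and Range(9, 8192) entries mentioned in the test comment.
def ranges_from_endpoints(endpoints: list[int], max_tokens: int) -> list[tuple[int, int]]:
    bounds = [1] + [e + 1 for e in sorted(endpoints)] + [max_tokens + 1]
    return [(lo, hi - 1) for lo, hi in zip(bounds, bounds[1:])]

assert ranges_from_endpoints([8], 8192) == [(1, 8), (9, 8192)]
# compile_sizes=[16] adds the third, single-size compilation on top of these two.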
vllm/compilation/compiler_interface.py (29 additions, 3 deletions)
@@ -348,13 +348,39 @@ def compile(
# Can remove this after the following issue gets fixed
# https://github.com/pytorch/pytorch/issues/174502
if envs.VLLM_ENABLE_PREGRAD_PASSES:
ctx: Any = contextlib.nullcontext()
pregrad_ctx: Any = contextlib.nullcontext()
else:
ctx = patch(
pregrad_ctx = patch(
"torch._inductor.compile_fx._recursive_pre_grad_passes",
lambda gm, _: gm,
)
with ctx, _patch_constrain_to_fx_strides():

# When inputs are FakeTensors (from create_concrete_args),
# standalone_compile("from_example_inputs") would normally create
# a fresh FakeTensorMode, causing a mode mismatch assertion.
# Patch FakeTensorMode in standalone_compile so it reuses the
# mode already attached to our FakeTensors. This gives us both
# ignore_shape_env=True (from "from_example_inputs") and mode
# consistency (from reusing our mode).
# Can remove this after the following issue gets fixed:
# https://github.com/pytorch/pytorch/issues/176562
from torch._subclasses.fake_tensor import FakeTensor

input_fake_mode = None
for x in example_inputs:
if isinstance(x, FakeTensor):
input_fake_mode = x.fake_mode
break

if input_fake_mode is not None:
fake_mode_ctx: Any = patch(
"torch._inductor.standalone_compile.FakeTensorMode",
lambda *a, **kw: input_fake_mode,
)
else:
fake_mode_ctx = contextlib.nullcontext()

with pregrad_ctx, fake_mode_ctx, _patch_constrain_to_fx_strides():
compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs)

if use_aot:
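The comment added in this hunk describes a mode-mismatch assertion that occurs when standalone_compile builds a fresh FakeTensorMode while the example inputs already belong to another mode. A minimal standalone sketch (not part of this diff) of that failure mode; the exact exception type and message depend on the PyTorch version:

# Hedged sketch: FakeTensors are tied to the FakeTensorMode that created them;
# operating on them under a different, freshly-created mode is the situation the
# patch above avoids by reusing the inputs' own mode.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

mode_a = FakeTensorMode()
with mode_a:
    x = torch.empty(4, 8)      # FakeTensor owned by mode_a

mode_b = FakeTensorMode()      # a fresh mode, like the one standalone_compile creates
try:
    with mode_b:
        _ = x + 1              # mixing tensors across modes is expected to raise here
except Exception as exc:       # demonstrating the failure, not handling it
    print(f"mixing FakeTensorModes failed as expected: {type(exc).__name__}")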
vllm/compilation/piecewise_backend.py (26 additions, 22 deletions)
@@ -34,13 +34,14 @@ def get_fake_args_from_graph(graph: fx.GraphModule) -> list[Any]:


def create_concrete_args(graph: fx.GraphModule, size: int) -> list[Any]:
"""Create example inputs with symbolic dims replaced by a concrete size.
"""Create Fake example inputs with symbolic dims replaced by a concrete size.

Used for single-size eager compilation where we need concrete-shaped
inputs but don't have real runtime tensors yet.
Used for single-size compilation where we need concrete-shaped inputs.
The Dynamo-captured graph gives us example inputs with SymInts in them.
"""
from torch._prims_common import compute_required_storage_length
from torch.fx.experimental.symbolic_shapes import is_symbolic
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv, is_symbolic

def concretize(sym_val: Any) -> int:
"""Replace all symbolic variables in a SymInt expression with size."""
@@ -49,25 +50,28 @@ def concretize(sym_val: Any) -> int:
expr = sym_val.node.expr
return int(expr.subs({s: size for s in expr.free_symbols}))

fake_mode = FakeTensorMode(shape_env=ShapeEnv())

args: list[Any] = []
for node in graph.graph.nodes:
if node.op != "placeholder":
break
val = node.meta["example_value"]
if isinstance(val, torch.SymInt):
args.append(concretize(val))
elif isinstance(val, torch.Tensor):
new_shape = tuple(concretize(d) for d in val.shape)
new_strides = tuple(concretize(s) for s in val.stride())
new_storage_offset = concretize(val.storage_offset())
needed_size = compute_required_storage_length(
new_shape, new_strides, new_storage_offset
)
t = torch.empty(needed_size, dtype=val.dtype, device=val.device)
t = t.as_strided(new_shape, new_strides, new_storage_offset)
args.append(t)
else:
args.append(val)
with fake_mode:
for node in graph.graph.nodes:
if node.op != "placeholder":
break
val = node.meta["example_value"]
if isinstance(val, torch.SymInt):
args.append(concretize(val))
elif isinstance(val, torch.Tensor):
new_shape = tuple(concretize(d) for d in val.shape)
new_strides = tuple(concretize(s) for s in val.stride())
new_storage_offset = concretize(val.storage_offset())
needed_size = compute_required_storage_length(
new_shape, new_strides, new_storage_offset
)
t = torch.empty(needed_size, dtype=val.dtype, device=val.device)
t = t.as_strided(new_shape, new_strides, new_storage_offset)
args.append(t)
else:
args.append(val)
return args


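concretize() above substitutes a concrete size for every free symbol in a SymInt's underlying sympy expression. A minimal sketch (not part of this diff) of that substitution step in plain sympy, using a hand-built expression standing in for a real Dynamo-captured SymInt's node.expr:

# Hedged sketch: the same expr.subs(...) pattern used by concretize(), applied to
# an illustrative expression (e.g. a stride written in terms of the batch symbol).
import sympy

s0 = sympy.Symbol("s0", integer=True, positive=True)
expr = 2 * s0 + 1
size = 16
concrete = int(expr.subs({s: size for s in expr.free_symbols}))
assert concrete == 33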