164 commits (changes shown from 128 commits)
21d7d67
Functionalized patterns in prep for utility
ProExpertProg Sep 6, 2025
f3b4cf1
TEMP Mostly working
ProExpertProg Sep 9, 2025
cdad3c0
TEMP: fixed rmsnorm issue (TODO assert dtypes in fused norm_quant ker…
ProExpertProg Sep 12, 2025
8e4a56f
rms works fully now, had to remove more conversions (and add them in …
ProExpertProg Sep 16, 2025
e151e6d
quant works except (torch,torch)
ProExpertProg Sep 16, 2025
14fdc8b
quant with fix for pure torch, broke others
ProExpertProg Sep 18, 2025
05a65f3
ALL WORKS
ProExpertProg Sep 18, 2025
e6b394e
Add TODO
ProExpertProg Sep 20, 2025
d96913a
Cleanup test_fusion.py, added extra layer of rms/quant
ProExpertProg Sep 25, 2025
b172747
Functionalize attn+quant patterns
ProExpertProg Sep 25, 2025
1ae80c6
Move global vllm_config to pass manager
ProExpertProg Sep 25, 2025
77835fd
Attention fusion works with custom ops
ProExpertProg Sep 25, 2025
1277999
Remove V0 attn fusion test
ProExpertProg Sep 25, 2025
d843a67
Add triton attn test to attn+quant fusion
ProExpertProg Sep 26, 2025
cdd1529
Flat product for better test names/visibility
ProExpertProg Sep 26, 2025
141a37e
Fix rmsnorm
ProExpertProg Sep 26, 2025
c6d6c3b
Refactor E2E attn fusion test
ProExpertProg Sep 26, 2025
490ac86
Add TP=2 test (untested)
ProExpertProg Sep 26, 2025
d0b1b56
improve tests by adding more cases
ProExpertProg Sep 26, 2025
47b4688
TEMP working on caplog
ProExpertProg Sep 27, 2025
ae7f56f
Temp MP workaround P2
ProExpertProg Sep 30, 2025
eb899a4
Temp MP workaround P3
ProExpertProg Sep 30, 2025
a2aa978
Test for caplog utils
ProExpertProg Oct 1, 2025
21a9f9f
Fixed tests, passing with 2.8, 2.9 tbd
ProExpertProg Oct 2, 2025
66a35a9
Update tests/compile/backend.py
ProExpertProg Oct 2, 2025
7eb1364
Update csrc/layernorm_kernels.cu
ProExpertProg Oct 2, 2025
5fef180
clean up fullgraph tests
ProExpertProg Oct 2, 2025
db479ae
TEMP allreduce fusion
ProExpertProg Oct 2, 2025
54189a9
allreduce fusion working (custom ops on)
ProExpertProg Oct 3, 2025
b7f52bf
allreduce fusion working with/without custom ops (except fp4)
ProExpertProg Oct 3, 2025
d09a278
allreduce fusion working with/without custom ops (with fp4)
ProExpertProg Oct 3, 2025
c8675ff
log depyf folder, fix context for TestBackend, fix pattern dump
ProExpertProg Oct 3, 2025
d3f95fe
fullgraph allreduce test update requirements
ProExpertProg Oct 3, 2025
4dbfcf7
Move e2e tests to new file, add to test pipeline
ProExpertProg Oct 3, 2025
31d0127
Add e2e fusions to fullgraph test (should work with Triton backend), …
ProExpertProg Oct 3, 2025
c653d24
Fix spelling, precommit
ProExpertProg Oct 4, 2025
1756f67
add back fp4
ProExpertProg Oct 4, 2025
5619bc3
clean up e2e tests
ProExpertProg Oct 10, 2025
32989d8
add pattern for final allreduce in model
ProExpertProg Oct 10, 2025
46ee626
add more comprehensive testing for quantfp8 (-rmsnorm+-quant still fa…
ProExpertProg Oct 10, 2025
a1c7fdb
add more comprehensive testing for allreduce-rmsnorm, fix fp4 (-rmsno…
ProExpertProg Oct 10, 2025
c3264d8
Fix partial match rmsnorm+quant, fix allreduce+rmsnorm match
ProExpertProg Oct 10, 2025
095277c
Simplify matcher utils by using RMSNorm.forward_static
ProExpertProg Oct 10, 2025
52f78ce
Add allreduce test to 2-gpu test
ProExpertProg Oct 11, 2025
1b1a63e
Fix e2e allreduce fusion test
ProExpertProg Oct 11, 2025
0d6e550
fix func test
ProExpertProg Oct 12, 2025
26892df
fix pass manager test
ProExpertProg Oct 12, 2025
3547b87
fix sequence parallelism test
ProExpertProg Oct 12, 2025
af1ffa7
PR review
ProExpertProg Oct 15, 2025
97b3ff2
Merge remote-tracking branch 'upstream/main' into luka/custom-op-matc…
ProExpertProg Oct 15, 2025
b5f89e5
Cleanup test_full_graph.py
ProExpertProg Oct 15, 2025
f6429e4
Cleanup test_fusion_attn.py
ProExpertProg Oct 15, 2025
8a363d3
Slight improvement for E2E fusion
ProExpertProg Oct 15, 2025
12a7c6d
Tests & docs for flat_product
ProExpertProg Oct 15, 2025
db16ee1
Merge branch 'main' into luka/custom-op-matching-2
ProExpertProg Oct 15, 2025
8ffb474
Remove/fix TODOs
ProExpertProg Oct 15, 2025
2a6299c
Fix e2e test patterns
ProExpertProg Oct 15, 2025
465ce58
Update tests/compile/test_fusion.py
ProExpertProg Oct 15, 2025
bb0254a
Merge branch 'main' into luka/custom-op-matching-2
ProExpertProg Oct 15, 2025
bcd95b5
Fix func test
ProExpertProg Oct 15, 2025
db2b1c7
Smaller model for e2e fusion test
ProExpertProg Oct 15, 2025
a3ebf0a
fix fp8 quant tests
ProExpertProg Oct 15, 2025
3943257
Restore original torch.Parameter behavior in RMSNorm
ProExpertProg Oct 15, 2025
532cbcf
Add comment to test_logger
ProExpertProg Oct 15, 2025
7e6f5b3
add flat_product example
ProExpertProg Oct 15, 2025
24f1298
PR comments: cleanup fusion passes, & matching
ProExpertProg Oct 15, 2025
de7405b
PR comments: add _custom_op suffix
ProExpertProg Oct 15, 2025
6253d5b
Add e2e to L40 distributed, move tests to start of B200 distributed
ProExpertProg Oct 15, 2025
876ef22
Fix tests, PR feedback
ProExpertProg Oct 15, 2025
e99a759
Break up B200 tests, move allreduce to H200
ProExpertProg Oct 15, 2025
a226864
Merge branch 'main' into luka/custom-op-matching-2
ProExpertProg Oct 16, 2025
ae581e1
Fix attention fusion test numerics
ProExpertProg Oct 16, 2025
c03b29b
Remove inductor graph partition from unit test (included in e2e tests)
ProExpertProg Oct 16, 2025
d2e0489
Relax tolerance for L40 fusion test
ProExpertProg Oct 16, 2025
65ef5fd
Merge branch 'main' into luka/custom-op-matching-2
ProExpertProg Oct 16, 2025
d4fe977
Fix NamedTuple
ProExpertProg Oct 16, 2025
6319e39
Update test durations
ProExpertProg Oct 16, 2025
e34d36d
More tweaking of precision
ProExpertProg Oct 16, 2025
f72ee43
Split original pr
ilmarkov Sep 4, 2025
c4c0215
Update bench
ilmarkov Sep 5, 2025
309d79e
Update threshold configuration
ilmarkov Sep 8, 2025
afcfd73
Move all_reduce from custom op in fused_moe
ilmarkov Sep 8, 2025
0248dcd
Linter fixes
ilmarkov Oct 16, 2025
18e4771
Upd
ilmarkov Oct 16, 2025
1debd8e
Merge branch 'main' into imarkov/fused_allreduce_torch_native
ilmarkov Oct 21, 2025
9516d2b
Upd after review
ilmarkov Oct 21, 2025
b789044
Update fused_moe
ilmarkov Oct 27, 2025
4001935
Merge branch 'main' into imarkov/fused_allreduce_torch_native
ilmarkov Oct 27, 2025
6077616
Address comments
ilmarkov Nov 2, 2025
afc8af8
Remove bench_compile
ilmarkov Nov 2, 2025
c3af2af
Split PR. Second part. Compile ranges
ilmarkov Sep 4, 2025
0cbb065
Remove general shape graph
ilmarkov Sep 4, 2025
d5392f5
Add test to test pipeline
ilmarkov Sep 5, 2025
027c9eb
Fix pre-commit
ilmarkov Sep 9, 2025
b2992d3
Upd
ilmarkov Oct 16, 2025
3499384
Upd config
ilmarkov Oct 16, 2025
5336ee6
Fix
ilmarkov Oct 16, 2025
4958474
Priotitize compile_sizes
ilmarkov Oct 17, 2025
04306ed
Fix inductor config
ilmarkov Oct 28, 2025
9dc4eea
Laith's fix
ilmarkov Nov 3, 2025
2c63f0b
Upd
ilmarkov Nov 4, 2025
67f7ae1
Update config
ilmarkov Nov 4, 2025
8b8d01d
Merge branch 'imarkov/fused_allreduce_torch_native' into imarkov/cond…
ilmarkov Nov 4, 2025
fcebc21
Add caching
ilmarkov Nov 4, 2025
65151bc
Address comments
ilmarkov Nov 5, 2025
1f7afdb
Add debug log
ilmarkov Nov 5, 2025
8da1585
Merge branch 'main' into imarkov/fused_allreduce_torch_native
ilmarkov Nov 5, 2025
df22202
Update benchmark
ilmarkov Nov 5, 2025
a21de2b
Fix
ilmarkov Nov 5, 2025
45f4093
Update bench and constants
ilmarkov Nov 5, 2025
c26e056
Rename in benchmark
ilmarkov Nov 5, 2025
1bee5a6
Merge branch 'main' into imarkov/fused_allreduce_torch_native
ilmarkov Nov 5, 2025
bcc0cc0
Add max_token_num to object
ilmarkov Nov 5, 2025
43b163c
Add test
ilmarkov Nov 5, 2025
71c6b72
Update comments
ilmarkov Nov 6, 2025
ada24e6
Merge branch 'imarkov/fused_allreduce_torch_native' into imarkov/cond…
ilmarkov Nov 6, 2025
6766e4f
Update fakify for compile sizes
ilmarkov Nov 5, 2025
af87d7a
Linter fix
ilmarkov Nov 6, 2025
56273da
Merge branch 'main' into imarkov/fused_allreduce_torch_native
ilmarkov Nov 6, 2025
459f71c
Merge branch 'imarkov/fused_allreduce_torch_native' into imarkov/cond…
ilmarkov Nov 6, 2025
2785e4d
Minor updates
ilmarkov Nov 7, 2025
1f83a66
Merge branch 'main' into imarkov/fused_allreduce_torch_native
ProExpertProg Nov 7, 2025
b4c1b1d
Address the review
ilmarkov Nov 10, 2025
ab33605
Merge branch 'main' into imarkov/fused_allreduce_torch_native
robertgshaw2-redhat Nov 10, 2025
3fac39b
Merge branch 'main' into imarkov/fused_allreduce_torch_native
ilmarkov Nov 10, 2025
b0a3884
Fix SP
ilmarkov Nov 10, 2025
a3e7bdc
Merge branch 'imarkov/fused_allreduce_torch_native' into imarkov/cond…
ilmarkov Nov 10, 2025
a810969
Merge branch 'main' into imarkov/conditional_compilation_ranges
ilmarkov Nov 11, 2025
319abd5
Remove dynamic shape
ilmarkov Nov 12, 2025
d168de0
Make ranges inclusive-inclusive
ilmarkov Nov 13, 2025
b65e752
Merge branch 'main' into imarkov/conditional_compilation_ranges
ilmarkov Nov 14, 2025
af10400
Merge branch 'main' into imarkov/conditional_compilation_ranges
ilmarkov Nov 18, 2025
6c05919
Add test for inductor cache hits
ilmarkov Nov 19, 2025
03637e7
Merge branch 'main' into imarkov/conditional_compilation_ranges
ilmarkov Nov 19, 2025
3f72483
Address comments
ilmarkov Nov 19, 2025
9b00ebc
Address comments
ilmarkov Nov 20, 2025
8a40ac6
Update test
ilmarkov Nov 20, 2025
ef05682
Address comments
ilmarkov Nov 20, 2025
63af962
Merge branch 'main' into imarkov/conditional_compilation_ranges
ilmarkov Nov 20, 2025
7647089
Merge branch 'main' into imarkov/conditional_compilation_ranges
ilmarkov Nov 21, 2025
ee89388
Update test utils
ilmarkov Nov 21, 2025
925e87d
Fix pre-commit after merge
ilmarkov Nov 21, 2025
809e170
Fix tests
ilmarkov Nov 21, 2025
e07c939
Add fixture instead of decorator
ilmarkov Nov 21, 2025
f4db45c
Fix re-used compilation config
ilmarkov Nov 23, 2025
97a8d58
Merge branch 'main' into imarkov/conditional_compilation_ranges
ilmarkov Nov 23, 2025
4f280ce
Fix e2e
ilmarkov Nov 23, 2025
f714957
Merge branch 'main' into imarkov/conditional_compilation_ranges
ilmarkov Nov 23, 2025
b27f89d
Fix e2e adapt to number of compile ranges
ilmarkov Nov 24, 2025
eedc70e
Merge branch 'main' into imarkov/conditional_compilation_ranges
ilmarkov Nov 25, 2025
cc8f2f8
Slight fix of test
ilmarkov Nov 25, 2025
d1dd4db
Fix tests after refactor
ilmarkov Nov 25, 2025
a2b67a4
Simplify
ilmarkov Nov 25, 2025
0776364
Address comments
ilmarkov Nov 26, 2025
42bf355
Merge branch 'main' into imarkov/conditional_compilation_ranges
ilmarkov Nov 26, 2025
ca832fc
Merge remote-tracking branch 'upstream/main' into imarkov/conditional…
ProExpertProg Dec 2, 2025
ba90b9e
Only warm up model if mode=VLLM_COMPILE
ProExpertProg Dec 2, 2025
771203f
Fix capture-sizes
ProExpertProg Dec 2, 2025
0e0eab9
Fix doc range
ProExpertProg Dec 2, 2025
3d2c36b
pre-commit
ProExpertProg Dec 2, 2025
18ff16e
Fix types for precommit
ProExpertProg Dec 3, 2025
6bc8258
Update vllm/v1/worker/gpu_worker.py
ProExpertProg Dec 3, 2025
c43458b
Merge remote-tracking branch 'upstream/main' into imarkov/conditional…
ProExpertProg Dec 3, 2025
f4c0ae7
Check that the pass was skipped in other range
ProExpertProg Dec 3, 2025
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -450,6 +450,7 @@ steps:
- pytest -v -s compile/test_decorator.py
- pytest -v -s compile/test_noop_elimination.py
- pytest -v -s compile/test_aot_compile.py
- pytest -v -s compile/test_compile_ranges.py

- label: PyTorch Fullgraph Smoke Test # 15min
timeout_in_minutes: 30
104 changes: 104 additions & 0 deletions tests/compile/test_compile_ranges.py
@@ -0,0 +1,104 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch import fx as fx
from torch import nn

# This import automatically registers `torch.ops.silly.attention`
import tests.compile.silly_attention  # noqa
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.inductor_pass import (
    InductorPass,
    get_pass_context,
)
from vllm.config import (
    VllmConfig,
    set_current_vllm_config,
)
from vllm.config.compilation import CompilationConfig, CompilationMode
from vllm.config.scheduler import SchedulerConfig
from vllm.config.utils import Range
from vllm.forward_context import set_forward_context

BATCH_SIZE = 64
MLP_SIZE = 128


@support_torch_compile
class TestModel(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + x
        attn_output = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, attn_output)
        x = attn_output
        x = x * 3
        return x


@torch.inference_mode
def run_model(vllm_config: VllmConfig, model: nn.Module, batch_sizes: list[int]):
    with set_forward_context({}, vllm_config=vllm_config):
        model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
        for batch_size in batch_sizes:
            model(torch.randn(batch_size, MLP_SIZE).cuda())


class PostGradPassManagerCheckRanges(InductorPass):
    def __init__(self, ranges: list[Range]):
        self.ranges = ranges
        self.num_calls = 0

    def __call__(self, graph: fx.Graph):
        compile_range = get_pass_context().compile_range
        assert compile_range in self.ranges, (
            f"Compile range {compile_range} not in {self.ranges}"
        )
        self.num_calls += 1

    def uuid(self) -> str:
        state = {
            "ranges": [str(range) for range in self.ranges],
            "current_compile_range": str(get_pass_context().compile_range),
        }
        return InductorPass.hash_dict(state)


def test_compile_ranges():
    post_grad_pass_manager = PostGradPassManagerCheckRanges(
        [
            Range(start=1, end=8),
            Range(start=8, end=32),
            Range(start=32, end=8193),
        ]
    )
    vllm_config = VllmConfig(
        scheduler_config=SchedulerConfig(
            max_num_batched_tokens=8192,
        ),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            compile_ranges_split_points=[8, 32],
            inductor_compile_config={
                "post_grad_custom_post_pass": post_grad_pass_manager,
                # Disable the Inductor cache so the pass-call count is accurate
                "force_disable_caches": True,
            },
        ),
    )

    with set_current_vllm_config(vllm_config):
        model = TestModel(vllm_config=vllm_config, prefix="").eval().cuda()
    batch_sizes = [1, 4, 16, 24, 48, 64]
    # TestModel is decorated with @support_torch_compile
    with compilation_counter.expect(
        num_graphs_seen=1,
        num_piecewise_graphs_seen=1,
        num_backend_compilations=3,
        # one backend compilation per configured compile range
    ):
        run_model(vllm_config, model, batch_sizes)
    assert post_grad_pass_manager.num_calls == 3
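
For orientation, the three asserted ranges follow directly from the config above: compile_ranges_split_points=[8, 32] with max_num_batched_tokens=8192 cuts the token-count axis into Range(1, 8), Range(8, 32), and Range(32, 8193), one backend compilation per range. A minimal sketch of that derivation; the helper name is hypothetical, and the boundary convention (start at 1, last bound at max_num_batched_tokens + 1) is inferred from the asserted values, not from vLLM's Range implementation:

# Hypothetical helper, not part of vLLM: reconstructs the compile ranges
# asserted in the test above from the configured split points.
# Boundary convention inferred from the test: ranges start at 1 and the
# final bound is max_num_batched_tokens + 1.
def ranges_from_split_points(
    split_points: list[int], max_num_batched_tokens: int
) -> list[tuple[int, int]]:
    bounds = [1, *split_points, max_num_batched_tokens + 1]
    return list(zip(bounds, bounds[1:]))


assert ranges_from_split_points([8, 32], 8192) == [(1, 8), (8, 32), (32, 8193)]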
86 changes: 40 additions & 46 deletions vllm/compilation/backends.py
@@ -22,6 +22,7 @@
should_split,
)
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.utils import Range
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.import_utils import resolve_obj_by_qualname
@@ -83,7 +84,7 @@ class CompilerManager:
"""

def __init__(self, compilation_config: CompilationConfig):
self.cache: dict[tuple[int | None, int, str], Any] = dict()
self.cache: dict[tuple[Range | None, int, str], Any] = dict()
self.is_cache_updated = False
self.compilation_config = compilation_config
self.compiler = make_compiler(compilation_config)
@@ -92,11 +93,11 @@ def compute_hash(self, vllm_config: VllmConfig) -> str:
return self.compiler.compute_hash(vllm_config)

@contextmanager
def compile_context(self, runtime_shape: int | None = None):
def compile_context(self, compile_range: Range | None = None):
"""Provide compilation context for the duration of compilation to set
any torch global properties we want to scope to a single Inductor
compilation (e.g. partition rules, pass context)."""
with pass_context(runtime_shape):
with pass_context(compile_range):
if self.compilation_config.use_inductor_graph_partition:
with inductor_partition_rule_context(
self.compilation_config.splitting_ops
@@ -152,26 +153,28 @@ def load(
graph: fx.GraphModule,
example_inputs: list[Any],
graph_index: int,
runtime_shape: int | None = None,
compile_range: Range | None = None,
) -> Callable | None:
if (runtime_shape, graph_index, self.compiler.name) not in self.cache:
if (compile_range, graph_index, self.compiler.name) not in self.cache:
return None
handle = self.cache[(runtime_shape, graph_index, self.compiler.name)]
handle = self.cache[(compile_range, graph_index, self.compiler.name)]
compiled_graph = self.compiler.load(
handle, graph, example_inputs, graph_index, runtime_shape
handle, graph, example_inputs, graph_index, compile_range
)
if runtime_shape is None:
if compile_range is None:
logger.debug(
"Directly load the %s-th graph for dynamic shape from %s via handle %s",
"Directly load the %s-th graph for dynamic compile range"
"from %s via handle %s",
graph_index,
self.compiler.name,
handle,
)
else:
logger.debug(
"Directly load the %s-th graph for shape %s from %s via handle %s",
"Directly load the %s-th graph for compile range %s"
"from %s via handle %s",
graph_index,
str(runtime_shape),
str(compile_range),
self.compiler.name,
handle,
)
@@ -185,7 +188,7 @@ def compile(
compilation_config: CompilationConfig,
graph_index: int = 0,
num_graphs: int = 1,
runtime_shape: int | None = None,
compile_range: Range | None = None,
) -> Any:
if graph_index == 0:
# before compiling the first graph, record the start time
@@ -197,25 +200,25 @@
compiled_graph = None

# try to load from the cache
compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape)
compiled_graph = self.load(graph, example_inputs, graph_index, compile_range)
if compiled_graph is not None:
if graph_index == num_graphs - 1:
# after loading the last graph for this shape, record the time.
# there can be multiple graphs due to piecewise compilation.
now = time.time()
elapsed = now - compilation_start_time
compilation_config.compilation_time += elapsed
if runtime_shape is None:
if compile_range is None:
logger.info(
"Directly load the compiled graph(s) for dynamic shape "
"from the cache, took %.3f s",
elapsed,
)
else:
logger.info(
"Directly load the compiled graph(s) for shape %s "
"Directly load the compiled graph(s) for compile range %s "
"from the cache, took %.3f s",
str(runtime_shape),
str(compile_range),
elapsed,
)
return compiled_graph
@@ -226,48 +229,52 @@ def compile(
# Let compile_fx generate a key for us
maybe_key = None
else:
maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"

with self.compile_context(runtime_shape):
maybe_key = "artifact_compile_range_"
if compile_range is None:
maybe_key += "dynamic_shape"
else:
maybe_key += f"{compile_range.start}_{compile_range.end}"
maybe_key += f"_subgraph_{graph_index}"
with self.compile_context(compile_range):
compiled_graph, handle = self.compiler.compile(
graph,
example_inputs,
additional_inductor_config,
runtime_shape,
compile_range,
maybe_key,
)

assert compiled_graph is not None, "Failed to compile the graph"

# store the artifact in the cache
if is_compile_cache_enabled(additional_inductor_config) and handle is not None:
self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
self.cache[(compile_range, graph_index, self.compiler.name)] = handle
compilation_counter.num_cache_entries_updated += 1
self.is_cache_updated = True
if graph_index == 0:
# adds some info logging for the first graph
if runtime_shape is None:
if compile_range is None:
logger.info_once(
"Cache the graph for dynamic shape for later use", scope="local"
)
else:
logger.info_once(
"Cache the graph of shape %s for later use",
str(runtime_shape),
scope="local",
"Cache the graph of compile range %s for later use",
str(compile_range),
)
if runtime_shape is None:
if compile_range is None:
logger.debug(
"Store the %s-th graph for dynamic shape from %s via handle %s",
"Store the %s-th graph for dynamic compile range"
"from %s via handle %s",
graph_index,
self.compiler.name,
handle,
)
else:
logger.debug(
"Store the %s-th graph for shape %s from %s via handle %s",
"Store the %s-th graph for compile range%s from %s via handle %s",
graph_index,
str(runtime_shape),
str(compile_range),
self.compiler.name,
handle,
)
Expand All @@ -277,16 +284,16 @@ def compile(
now = time.time()
elapsed = now - compilation_start_time
compilation_config.compilation_time += elapsed
if runtime_shape is None:
if compile_range is None:
logger.info_once(
"Compiling a graph for dynamic shape takes %.2f s",
"Compiling a graph for dynamic compile range takes %.2f s",
elapsed,
scope="local",
)
else:
logger.info_once(
"Compiling a graph for shape %s takes %.2f s",
runtime_shape,
"Compiling a graph for compile range %s takes %.2f s",
str(compile_range),
elapsed,
scope="local",
)
@@ -405,19 +412,7 @@ def call_module(
sym_shape_indices = [
i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
]
global compilation_start_time

compiled_graph_for_dynamic_shape = (
self.vllm_backend.compiler_manager.compile(
submod,
args,
self.compilation_config.inductor_compile_config,
self.compilation_config,
graph_index=index,
num_graphs=len(self.compile_submod_names),
runtime_shape=None,
)
)
# Lazy import here to avoid circular import
from .piecewise_backend import PiecewiseBackend

@@ -427,7 +422,6 @@
index,
len(self.compile_submod_names),
sym_shape_indices,
compiled_graph_for_dynamic_shape,
self.vllm_backend,
)

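
The net effect of the backends.py changes: the compiler cache is keyed on (compile_range, graph_index, compiler_name) instead of (runtime_shape, graph_index, compiler_name), and cached Inductor artifacts are named per compile range. A standalone sketch of the artifact-key scheme as it appears in the compile() hunk above, for illustration only; the real logic lives inline in CompilerManager.compile, and compile_range here stands in for vllm.config.utils.Range:

# Standalone sketch of the artifact-key scheme from the diff above.
# `compile_range` mimics vllm.config.utils.Range (has .start and .end);
# None denotes the single dynamic-shape compilation.
def artifact_key(compile_range, graph_index: int) -> str:
    key = "artifact_compile_range_"
    if compile_range is None:
        key += "dynamic_shape"
    else:
        key += f"{compile_range.start}_{compile_range.end}"
    return key + f"_subgraph_{graph_index}"


# e.g. artifact_key(None, 0) -> "artifact_compile_range_dynamic_shape_subgraph_0"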