Commit 4e92ff8

Fix tests: add forward context
Signed-off-by: luka <[email protected]>
1 parent: 93949f8 | commit: 4e92ff8

File tree

2 files changed: +26, -35 lines

tests/compile/piecewise/test_simple.py

Lines changed: 5 additions & 11 deletions
@@ -4,7 +4,7 @@
 Test the piecewise compilation with a simple model so that we
 can exactly calculate the expected output and side effects.
 """
-
+import pytest
 import torch
 from torch import nn
 from torch.library import Library
@@ -14,6 +14,7 @@
 from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
                          set_current_vllm_config)
 from vllm.envs import VLLM_USE_V1
+from vllm.forward_context import set_forward_context
 from vllm.utils import direct_register_custom_op

 global_counter = 0
@@ -76,7 +77,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x


-def _test_simple_piecewise_compile(*, use_inductor):
+@pytest.mark.parametrize("use_inductor", [True, False])
+def test_simple_piecewise_compile(use_inductor):
     assert VLLM_USE_V1

     vllm_config = VllmConfig(compilation_config=CompilationConfig(
@@ -99,7 +101,7 @@ def _test_simple_piecewise_compile(*, use_inductor):
             num_backend_compilations=3,  # num_piecewise_capturable_graphs_seen
             num_cudagraph_captured=
             6,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
-    ):
+    ), set_forward_context({}, vllm_config=vllm_config):

         model(inputs)

@@ -112,11 +114,3 @@ def _test_simple_piecewise_compile(*, use_inductor):
         output = model(input)
         assert global_counter == 2
         assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
-
-
-def test_simple_piecewise_compile_inductor():
-    _test_simple_piecewise_compile(use_inductor=True)
-
-
-def test_simple_piecewise_compile_no_inductor():
-    _test_simple_piecewise_compile(use_inductor=False)
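
Taken together, the two changes give the test the following shape: one pytest-parametrized function whose forward passes all run inside vLLM's forward context. Below is a minimal sketch of that pattern only; the stand-in names (test_forward_context_pattern, the default-constructed VllmConfig, the nn.Linear model) are illustrative and not part of the commit, which uses SillyModel, CUDA tensors, and a piecewise CompilationConfig.

    import pytest
    import torch
    from torch import nn

    from vllm.config import VllmConfig
    from vllm.forward_context import set_forward_context


    @pytest.mark.parametrize("use_inductor", [True, False])
    def test_forward_context_pattern(use_inductor: bool):
        # use_inductor only mirrors the new parametrization here; the real
        # test threads it into its compilation config.
        # Stand-in config: the real test builds a piecewise CompilationConfig
        # with cudagraphs enabled; defaults are enough to show the call shape.
        vllm_config = VllmConfig()
        # Stand-in model: the real test compiles SillyModel and runs on CUDA.
        model = nn.Linear(2, 2)

        # Same call shape as the diff: an empty attention-metadata mapping
        # plus the config, wrapped around every forward pass.
        with set_forward_context({}, vllm_config=vllm_config):
            model(torch.zeros(2))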

tests/compile/piecewise/test_toy_llama.py

Lines changed: 21 additions & 24 deletions
@@ -11,6 +11,7 @@
 from dataclasses import dataclass
 from typing import Any, Optional

+import pytest
 import torch
 from torch import nn
 from torch.library import Library
@@ -19,6 +20,7 @@
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
                          set_current_vllm_config)
+from vllm.forward_context import set_forward_context
 from vllm.utils import direct_register_custom_op

 # create a library to hold the custom op
@@ -285,29 +287,32 @@ def run_model(llama_config,
                            vllm_config=vllm_config,
                            prefix="").eval().cuda()

-    B = 16  # max batch size
-    input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
-    positions = torch.arange(B).cuda()
+    with set_forward_context({}, vllm_config=vllm_config):
+        B = 16  # max batch size
+        input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
+        positions = torch.arange(B).cuda()

-    model(input_ids, positions)
-    model(input_ids[:2], positions[:2])
-    model(input_ids[:1], positions[:1])
+        model(input_ids, positions)
+        model(input_ids[:2], positions[:2])
+        model(input_ids[:1], positions[:1])

-    input_ids[:2].zero_()
-    output = model(input_ids[:2], positions[:2])
+        input_ids[:2].zero_()
+        output = model(input_ids[:2], positions[:2])

-    output = output.cpu()
+        output = output.cpu()

-    if llama_config.tractable_init:
-        expected_output = tractable_computation(input_ids[:2], positions[:2],
-                                                llama_config).cpu()
+        if llama_config.tractable_init:
+            expected_output = tractable_computation(input_ids[:2],
+                                                    positions[:2],
+                                                    llama_config).cpu()

-        assert torch.allclose(output, expected_output)
-    else:
-        return output.cpu()
+            assert torch.allclose(output, expected_output)
+        else:
+            return output.cpu()


-def _test_toy_llama(*, use_inductor):
+@pytest.mark.parametrize("use_inductor", [True, False])
+def test_toy_llama(use_inductor: bool):
     # compare output with and without piecewise compilation

     llama_config = LlamaConfig(hidden_size=128,
@@ -379,14 +384,6 @@ def _test_toy_llama(*, use_inductor):
         assert torch.allclose(outputs[0], outputs[i])


-def test_toy_llama_inductor():
-    _test_toy_llama(use_inductor=True)
-
-
-def test_toy_no_inductor():
-    _test_toy_llama(use_inductor=False)
-
-
 @torch.inference_mode
 def benchmark():
     from triton.testing import do_bench
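
The same wrapper-to-parametrize substitution closes out this file: the deleted test_toy_llama_inductor / test_toy_no_inductor pair is covered by the single parametrized test_toy_llama. A small sketch of just that mechanism, with an illustrative name and a trivial body standing in for the unchanged test logic:

    import pytest


    # pytest expands the single definition into two collected tests,
    # test_toy_llama_sketch[True] and test_toy_llama_sketch[False], matching
    # the cases the deleted wrapper functions covered.
    @pytest.mark.parametrize("use_inductor", [True, False])
    def test_toy_llama_sketch(use_inductor: bool):
        # The real test_toy_llama body (LlamaConfig setup, compiled-vs-eager
        # comparison) is unchanged by this commit and elided here.
        assert isinstance(use_inductor, bool)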
