11 | 11 | from dataclasses import dataclass |
12 | 12 | from typing import Any, Optional |
13 | 13 |
| 14 | +import pytest |
14 | 15 | import torch |
15 | 16 | from torch import nn |
16 | 17 | from torch.library import Library |
19 | 20 | from vllm.compilation.decorators import support_torch_compile |
20 | 21 | from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, |
21 | 22 | set_current_vllm_config) |
| 23 | +from vllm.forward_context import set_forward_context |
22 | 24 | from vllm.utils import direct_register_custom_op |
23 | 25 |
24 | 26 | # create a library to hold the custom op |
@@ -285,29 +287,32 @@ def run_model(llama_config, |
285 | 287 | vllm_config=vllm_config, |
286 | 288 | prefix="").eval().cuda() |
287 | 289 |
288 | | - B = 16 # max batch size |
289 | | - input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() |
290 | | - positions = torch.arange(B).cuda() |
| 290 | + with set_forward_context({}, vllm_config=vllm_config): |
| 291 | + B = 16 # max batch size |
| 292 | + input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() |
| 293 | + positions = torch.arange(B).cuda() |
291 | 294 |
292 | | - model(input_ids, positions) |
293 | | - model(input_ids[:2], positions[:2]) |
294 | | - model(input_ids[:1], positions[:1]) |
| 295 | + model(input_ids, positions) |
| 296 | + model(input_ids[:2], positions[:2]) |
| 297 | + model(input_ids[:1], positions[:1]) |
295 | 298 |
296 | | - input_ids[:2].zero_() |
297 | | - output = model(input_ids[:2], positions[:2]) |
| 299 | + input_ids[:2].zero_() |
| 300 | + output = model(input_ids[:2], positions[:2]) |
298 | 301 |
299 | | - output = output.cpu() |
| 302 | + output = output.cpu() |
300 | 303 |
301 | | - if llama_config.tractable_init: |
302 | | - expected_output = tractable_computation(input_ids[:2], positions[:2], |
303 | | - llama_config).cpu() |
| 304 | + if llama_config.tractable_init: |
| 305 | + expected_output = tractable_computation(input_ids[:2], |
| 306 | + positions[:2], |
| 307 | + llama_config).cpu() |
304 | 308 |
305 | | - assert torch.allclose(output, expected_output) |
306 | | - else: |
307 | | - return output.cpu() |
| 309 | + assert torch.allclose(output, expected_output) |
| 310 | + else: |
| 311 | + return output.cpu() |
308 | 312 |
309 | 313 |
310 | | -def _test_toy_llama(*, use_inductor): |
| 314 | +@pytest.mark.parametrize("use_inductor", [True, False]) |
| 315 | +def test_toy_llama(use_inductor: bool): |
311 | 316 | # compare output with and without piecewise compilation |
312 | 317 |
313 | 318 | llama_config = LlamaConfig(hidden_size=128, |
@@ -379,14 +384,6 @@ def _test_toy_llama(*, use_inductor): |
379 | 384 | assert torch.allclose(outputs[0], outputs[i]) |
380 | 385 |
381 | 386 |
382 | | -def test_toy_llama_inductor(): |
383 | | - _test_toy_llama(use_inductor=True) |
384 | | - |
385 | | - |
386 | | -def test_toy_no_inductor(): |
387 | | - _test_toy_llama(use_inductor=False) |
388 | | - |
389 | | - |
390 | 387 | @torch.inference_mode |
391 | 388 | def benchmark(): |
392 | 389 | from triton.testing import do_bench |
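For context, the shape the test takes after this change looks roughly like the sketch below: one pytest-parametrized test in place of the two wrapper functions, with the forward passes wrapped in a forward context. This is a minimal illustration, not the real test: the test name, the default-constructed VllmConfig, and the placeholder body stand in for the actual run_model/test_toy_llama logic in the file.

import pytest
import torch

from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context


@pytest.mark.parametrize("use_inductor", [True, False])
def test_toy_llama_sketch(use_inductor: bool):
    # Hypothetical test name; mirrors the parametrization added above so the
    # inductor and non-inductor paths share one test function instead of two
    # thin wrappers.
    vllm_config = VllmConfig()  # assumed default-constructible for this sketch

    # Run the forward passes inside set_forward_context, matching the change
    # to run_model; an empty attn_metadata dict is passed, as in the diff.
    with set_forward_context({}, vllm_config=vllm_config):
        B = 16  # max batch size, as in run_model
        input_ids = torch.randint(0, 128, (B, ))
        positions = torch.arange(B)
        # The real test builds the toy Llama model here and compares outputs
        # with and without piecewise compilation; use_inductor selects the
        # compilation backend.
        assert input_ids.shape == positions.shape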