Skip to content

Commit

Permalink
[core][torch.compile] discard the compile for profiling (vllm-project…
Browse files Browse the repository at this point in the history
  • Loading branch information
youkaichao authored Aug 27, 2024
1 parent d13cb0c commit 257b7bb
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 2 deletions.
3 changes: 1 addition & 2 deletions .buildkite/run-tpu-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,4 @@ remove_docker_container
# For HF_TOKEN.
source /etc/environment
# Run a simple end-to-end example.
docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu \
python3 /workspace/vllm/examples/offline_inference_tpu.py
docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py"
34 changes: 34 additions & 0 deletions tests/tpu/test_compilation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import glob
import os
import runpy
import tempfile

import depyf

temp_dir = tempfile.mkdtemp()
with depyf.prepare_debug(temp_dir):
cur_dir = os.path.dirname(__file__)
parent_dir = os.path.dirname(cur_dir)
root_dir = os.path.dirname(parent_dir)
example_file = os.path.join(root_dir, "examples",
"offline_inference_tpu.py")
runpy.run_path(example_file)

compiled_code = sorted(
glob.glob(os.path.join(temp_dir, "__transformed_code*.py")))
full_code = glob.glob(os.path.join(temp_dir, "full_code*.py"))[0]
# we should only trigger Dynamo compilation three times:
# one for the profiling phase (and the compiled artifact will be discarded)
# one for the prefill phase with symbolic shapes
# one for the decode phase with symbolic shapes
# and later calls should not trigger Dynamo compilation again.
# NOTE: it might still trigger XLA compilation.

# check we have three compiled code
assert len(compiled_code) == 3

# check the first compilation is discarded
with open(full_code) as f:
full_code_content = f.read()
profile_function = compiled_code[0].split(".")[0]
assert profile_function not in full_code_content
4 changes: 4 additions & 0 deletions vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1097,6 +1097,10 @@ def profile_run(self) -> None:
device=self.device)
self.execute_model(model_input, kv_caches, intermediate_tensors)
torch.cuda.synchronize()

# reset and discard the guard and compiled bytecode for profiling runs
torch._dynamo.reset()

return

def remove_all_loras(self):
Expand Down
4 changes: 4 additions & 0 deletions vllm/worker/tpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
block_size_bytes)
num_cpu_blocks = (num_cpu_blocks // 8) * 8 # Round down to 8.

# reset and discard the guard and compiled bytecode for profiling runs
torch._dynamo.reset()

return num_tpu_blocks, num_cpu_blocks

def initialize_cache(
Expand Down

0 comments on commit 257b7bb

Please sign in to comment.