[core][torch.compile] not compile for profiling #7796

Merged (17 commits) on Aug 27, 2024
2 changes: 1 addition & 1 deletion .buildkite/run-tpu-test.sh
@@ -13,4 +13,4 @@ remove_docker_container
source /etc/environment
# Run a simple end-to-end example.
docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu \
-python3 /workspace/vllm/examples/offline_inference_tpu.py
+python3 /workspace/vllm/tests/tpu/test_compilation.py
3 changes: 3 additions & 0 deletions Dockerfile.tpu
@@ -8,6 +8,9 @@ WORKDIR /workspace
RUN python3 -m pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
RUN python3 -m pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

+# for testing torch.compile
+RUN python3 -m pip install git+https://github.com/thuml/depyf.git

# Build vLLM.
COPY . /workspace/vllm
ENV VLLM_TARGET_DEVICE="tpu"
18 changes: 18 additions & 0 deletions tests/tpu/test_compilation.py
@@ -0,0 +1,18 @@
import glob
import os
import runpy
import tempfile

import depyf

temp_dir = tempfile.mkdtemp()
with depyf.prepare_debug(temp_dir):
    cur_dir = os.path.dirname(__file__)
    parent_dir = os.path.dirname(cur_dir)
    root_dir = os.path.dirname(parent_dir)
    example_file = os.path.join(root_dir, "examples",
                                "offline_inference_tpu.py")
    runpy.run_path(example_file)

compiled_code = glob.glob(os.path.join(temp_dir, "__transformed_code*.py"))
assert len(compiled_code) == 2
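The new test runs the stock TPU example under depyf so that every piece of Dynamo-transformed code is dumped into a temporary directory, then asserts that exactly two __transformed_code*.py files were produced. The same pattern works on any torch.compile workload; the toy function below is only an illustration, not part of the PR, and the number of dumped files depends on how many distinct graphs the workload compiles.

# Hedged sketch, not PR code: dump and count Dynamo-transformed code with depyf,
# assuming torch >= 2.0 and depyf are installed.
import glob
import os
import tempfile

import depyf
import torch


def toy(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) + 1


compiled_toy = torch.compile(toy, backend="eager")

dump_dir = tempfile.mkdtemp()
with depyf.prepare_debug(dump_dir):
    # Each distinct compiled graph ends up as one __transformed_code*.py file.
    compiled_toy(torch.randn(8))

print(glob.glob(os.path.join(dump_dir, "__transformed_code*.py")))

In the vLLM test, the expected count of two presumably reflects the two graphs the TPU example compiles; that number is a property of the example, not of depyf.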
12 changes: 10 additions & 2 deletions vllm/worker/model_runner.py
@@ -946,6 +946,7 @@ def load_model(self) -> None:
"This may lead to less accurate results!")

if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE:
+self._not_compiled_model = self.model
self.model = torch.compile(self.model,
fullgraph=True,
backend="eager")
@@ -1094,7 +1095,11 @@ def profile_run(self) -> None:
batch_size=batch_size,
dtype=self.model_config.dtype,
device=self.device)
-self.execute_model(model_input, kv_caches, intermediate_tensors)
+# no compilation is needed for profiling memory
+self.execute_model(model_input,
+                   kv_caches,
+                   intermediate_tensors,
+                   compiled=False)
torch.cuda.synchronize()
return
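The comment in this hunk carries the rationale: the dummy forward pass in profile_run only exists to measure peak memory, so tracing it through torch.compile adds compilation time without changing the measurement. A self-contained analogue of an eager profiling pass, using a placeholder model rather than vLLM's runner, looks roughly like this:

# Minimal sketch, not vLLM code: measure peak CUDA memory of an eager dummy run.
import torch
import torch.nn as nn

if torch.cuda.is_available():
    model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(),
                          nn.Linear(4096, 1024)).cuda()
    torch.cuda.reset_peak_memory_stats()
    with torch.inference_mode():
        model(torch.randn(256, 1024, device="cuda"))  # eager, no torch.compile
    torch.cuda.synchronize()
    peak_bytes = torch.cuda.max_memory_allocated()
    print(f"peak memory during profiling run: {peak_bytes / 2**30:.2f} GiB")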

@@ -1367,6 +1372,7 @@ def execute_model(
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
+compiled: bool = True,
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
if num_steps > 1:
raise ValueError("num_steps > 1 is not supported in ModelRunner")
@@ -1398,8 +1404,10 @@
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[virtual_engine][
graph_batch_size]
-else:
+elif compiled:
model_executable = self.model
+else:
+    model_executable = self.not_compiled_model

multi_modal_kwargs = model_input.multi_modal_kwargs or {}
seqlen_agnostic_kwargs = {
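Taken together, execute_model now chooses between three executables: the captured CUDA-graph runner for eligible decode batches, the torch.compile'd module by default, and the stashed eager module when the caller passes compiled=False. A simplified, hypothetical rendering of that selection (names abbreviated, not the actual vLLM signature):

# Hypothetical sketch of the three-way dispatch; the parameter names are
# illustrative and do not match vLLM's internals exactly.
from typing import Any, Dict, List


def select_executable(use_cuda_graph: bool, compiled: bool,
                      graph_runners: List[Dict[int, Any]],
                      virtual_engine: int, batch_size: int,
                      model: Any, not_compiled_model: Any) -> Any:
    if use_cuda_graph:
        return graph_runners[virtual_engine][batch_size]
    if compiled:
        return model
    return not_compiled_model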
11 changes: 11 additions & 0 deletions vllm/worker/model_runner_base.py
@@ -202,6 +202,17 @@ def execute_model(
"""
raise NotImplementedError

+@property
+def not_compiled_model(self):
+    """
+    Return the model that is not compiled.
+    """
+    if hasattr(self, "_not_compiled_model"):
+        return self._not_compiled_model
+    if hasattr(self, "model"):
+        return self.model
+    raise RuntimeError("No model found")

def get_generators(self, finished_request_ids: Optional[List[str]] = None):
"""
Return dict of per-request generators used for random sampling.
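The fallback order of the new property is the contract the runners rely on: a runner that wraps its model with torch.compile stashes the eager module in _not_compiled_model first, while a runner that never compiles simply gets self.model back. A toy, self-contained sketch of that contract (not the vLLM base class), assuming torch >= 2.0:

import torch
import torch.nn as nn


class RunnerSketch:
    def __init__(self, model: nn.Module, compile_model: bool) -> None:
        if compile_model:
            # Keep a handle to the eager module before wrapping it.
            self._not_compiled_model = model
            self.model = torch.compile(model, backend="eager")
        else:
            self.model = model

    @property
    def not_compiled_model(self) -> nn.Module:
        if hasattr(self, "_not_compiled_model"):
            return self._not_compiled_model
        if hasattr(self, "model"):
            return self.model
        raise RuntimeError("No model found")


runner = RunnerSketch(nn.Linear(4, 4), compile_model=True)
assert runner.not_compiled_model is not runner.model  # eager module preserved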
9 changes: 7 additions & 2 deletions vllm/worker/tpu_model_runner.py
@@ -145,6 +145,7 @@ def load_model(self) -> None:
model = model.eval()
xm.wait_device_ops()
model = ModelWrapper(model)
+self._not_compiled_model = model
self.model = torch.compile(model,
backend="openxla",
fullgraph=True,
@@ -156,6 +157,7 @@ def _dummy_run(
seq_len: int,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
is_prompt: bool,
+compiled: bool = True,
) -> None:
if is_prompt:
seq_len = (seq_len + 15) // 16 * 16
@@ -235,7 +237,8 @@ def _dummy_run(
torch._dynamo.mark_dynamic(t, 0)
torch._dynamo.mark_dynamic(p, 0)
# Dummy run.
-self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
+executable = self.model if compiled else self.not_compiled_model
+executable(token_ids, position_ids, attn_metadata, input_lens, t, p,
num_samples, kv_caches)

def warmup_model(
@@ -510,6 +513,7 @@ def execute_model(
kv_caches: Optional[List[Any]],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
+compiled: bool = True,
) -> List[SamplerOutput]:
assert intermediate_tensors is None
if num_steps > 1:
@@ -530,7 +534,8 @@ def _execute_model(*args):
if getattr(arg, "context_lens", None) is not None:
arg.context_lens = arg.context_lens.to(self.device)
new_args.append(arg)
-return self.model(*new_args)
+executable = self.model if compiled else self.not_compiled_model
+return executable(*new_args)

num_prefills = model_input.attn_metadata.num_prefills
is_prompt = num_prefills > 0
1 change: 1 addition & 0 deletions vllm/worker/tpu_worker.py
@@ -120,6 +120,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
seq_len=self.scheduler_config.max_num_batched_tokens,
kv_caches=kv_caches,
is_prompt=True,
+compiled=False,
)
# Synchronize before measuring the memory usage.
xm.wait_device_ops()