From 14929a49a48cce9b2e72df9264c703dedc9875bc Mon Sep 17 00:00:00 2001
From: rzou
Date: Fri, 18 Apr 2025 08:29:05 -0700
Subject: [PATCH] Log how much time loading a compiled artifact takes

When people say they don't like the "torch.compile startup time", they
mean two things:

1) the cold start time
2) the warm start time (when the vLLM disk cache has already been
   populated)

We had logging for (1) but not for (2); this PR adds logging for (2).

Test Plan:
I ran `VLLM_USE_V1=1 python benchmark_latency.py --model meta-llama/Meta-Llama-3-8B --batch-size 1 -O '{"level": 3, "compile_sizes": {1, 2}}'`
and observed the following logs:

```
INFO 04-18 08:26:11 [backends.py:431] Dynamo bytecode transform time: 5.03 s
INFO 04-18 08:26:15 [backends.py:120] Directly load the compiled graph(s) for shape None from the cache, took 4.190 s
INFO 04-18 08:26:18 [kv_cache_utils.py:634] GPU KV cache size: 532,032 tokens
```

Side note: it is probably not a good sign that loading from the cache
takes 4 seconds.

Signed-off-by: rzou
---
 vllm/compilation/backends.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index 45988c2e9b0d..c493a764f56d 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -110,10 +110,14 @@ def compile(self,
         compiled_graph = self.load(graph, example_inputs, graph_index,
                                    runtime_shape)
         if compiled_graph is not None:
-            if graph_index == 0:
-                # adds some info logging for the first graph
-                logger.info("Directly load the compiled graph for shape %s "
-                            "from the cache", str(runtime_shape))  # noqa
+            if graph_index == num_graphs - 1:
+                # After loading the last graph for this shape, record the time.
+                # There can be multiple graphs due to piecewise compilation.
+                now = time.time()
+                elapsed = now - compilation_start_time
+                logger.info(
+                    "Directly load the compiled graph(s) for shape %s "
+                    "from the cache, took %.3f s", str(runtime_shape), elapsed)
             return compiled_graph
 
         # no compiler cached the graph, or the cache is disabled,
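
For context, here is a minimal standalone sketch of the timing pattern the
patch implements: record a start time once per runtime shape, and emit the
log line only after the last piecewise graph has been loaded, so the elapsed
time covers the entire warm-start load. `load_graph` and `load_all_graphs`
are hypothetical stand-ins for illustration, not vLLM's actual API:

```
import logging
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def load_graph(graph_index: int) -> object:
    # Hypothetical stand-in for deserializing one piecewise-compiled
    # graph from the disk cache.
    time.sleep(0.01)
    return object()


def load_all_graphs(num_graphs: int, runtime_shape=None) -> list:
    start_time = time.time()
    graphs = []
    for graph_index in range(num_graphs):
        graphs.append(load_graph(graph_index))
        if graph_index == num_graphs - 1:
            # Log once, after the last graph for this shape, so the
            # elapsed time covers loading every piecewise graph.
            elapsed = time.time() - start_time
            logger.info(
                "Directly load the compiled graph(s) for shape %s "
                "from the cache, took %.3f s", str(runtime_shape), elapsed)
    return graphs


if __name__ == "__main__":
    load_all_graphs(num_graphs=3)
```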