examples/offline_inference_with_prefix.py: 47 changes (37 additions & 10 deletions)
@@ -1,5 +1,8 @@
+from time import time
+
from vllm import LLM, SamplingParams

# Common prefix.
prefix = (
    "You are an expert school principal, skilled in effectively managing "
    "faculty and staff. Draft 10-15 questions for a potential first grade "
@@ -18,36 +21,60 @@
"The capital of France is",
"The future of AI is",
]

generating_prompts = [prefix + prompt for prompt in prompts]

# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0)

# Create an LLM.
llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)
regular_llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.4)

generating_prompts = [prefix + prompt for prompt in prompts]
prefix_cached_llm = LLM(model="facebook/opt-125m",
enable_prefix_caching=True,
gpu_memory_utilization=0.4)
print("Results without `enable_prefix_caching`")

# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(generating_prompts, sampling_params)
start_time_regular = time()
outputs = regular_llm.generate(generating_prompts, sampling_params)
duration_regular = time() - start_time_regular

regular_generated_texts = []
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
regular_generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

print("-" * 80)

# The llm.generate call will batch all prompts and send the batch at once
# if resources allow. The prefix will only be cached after the first batch
# is processed, so we need to call generate once to calculate the prefix
# and cache it.
outputs = llm.generate(generating_prompts[0], sampling_params)
# if resources allow.
start_time_cached = time()
outputs = prefix_cached_llm.generate(generating_prompts, sampling_params)
duration_cached = time() - start_time_cached

# Subsequent batches can leverage the cached prefix
outputs = llm.generate(generating_prompts, sampling_params)
print("Results with `enable_prefix_caching`")

# Print the outputs. You should see the same outputs as before
cached_generated_texts = []
# Print the outputs. You should see the same outputs as before.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
cached_generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

print("-" * 80)

# Compare the results and display the speedup
generated_same = all([
regular_generated_texts[i] == cached_generated_texts[i]
for i in range(len(prompts))
])
print(f"Generated answers are the same: {generated_same}")

speedup = round(duration_regular / duration_cached, 2)
print(f"Speed up of cached generation compared to the regular is: {speedup}")