# Prompt Embedding Inputs

This page teaches you how to pass prompt embedding inputs to vLLM.

## What are prompt embeddings?

The traditional flow of text data for a Large Language Model goes from text to token ids (via a tokenizer), and then from token ids to prompt embeddings. For a traditional decoder-only model (such as meta-llama/Llama-3.1-8B-Instruct), this step of converting token ids to prompt embeddings happens via a look-up from a learned embedding matrix, but the model is not limited to processing only the embeddings corresponding to its token vocabulary.
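
For concreteness, here is a minimal sketch of that look-up using Hugging Face Transformers (the model name is only an example; any decoder-only causal LM exposes the same `get_input_embeddings()` accessor):

```python
import transformers

model_name = "meta-llama/Llama-3.2-1B-Instruct"  # example model
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForCausalLM.from_pretrained(model_name)

# The learned embedding matrix: token id -> embedding vector
embedding_layer = model.get_input_embeddings()

token_ids = tokenizer("Paris is the capital of", return_tensors="pt").input_ids
prompt_embeds = embedding_layer(token_ids)  # shape: (1, sequence_length, hidden_size)
```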

:::{note}
Prompt embeddings are currently only supported in the v0 engine.
:::

## Offline Inference

To input prompt embeddings, follow this schema in {class}`vllm.inputs.EmbedsPrompt`:

- `prompt_embeds`: A torch tensor representing a sequence of prompt/token embeddings. It has the shape `(sequence_length, hidden_size)`, where `sequence_length` is the number of token embeddings and `hidden_size` is the hidden size (embedding size) of the model. A minimal sketch of this format follows.

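The following sketch shows the expected format, assuming a hidden size of 2048 (the hidden size of `meta-llama/Llama-3.2-1B-Instruct`); in practice the embeddings would come from a model's input embedding layer, as in the examples below, rather than from random values:

```python
import torch

# Hypothetical dimensions for illustration only.
sequence_length, hidden_size = 16, 2048

# Any float tensor of shape (sequence_length, hidden_size) fits the schema;
# it should use the same dtype as the model (e.g. torch.bfloat16).
prompt_embeds = torch.randn(sequence_length, hidden_size, dtype=torch.bfloat16)

# This dictionary follows the EmbedsPrompt schema and can be passed to LLM.generate.
# Random embeddings will of course produce meaningless text.
embeds_prompt = {"prompt_embeds": prompt_embeds}
```
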
### Hugging Face Transformers Inputs

You can pass prompt embeddings from Hugging Face Transformers models to the `'prompt_embeds'` field of the prompt embedding dictionary, as shown in the following examples:

```python
from vllm import LLM
import transformers

model_name = "meta-llama/Llama-3.2-1B-Instruct"

# Transformers
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
transformers_model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
# The learned input embedding layer used to turn token ids into embeddings
embedding_layer = transformers_model.get_input_embeddings()

llm = LLM(model=model_name, enable_prompt_embeds=True)

# Refer to the HuggingFace repo for the correct format to use
chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
token_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors='pt')

prompt_embeds = embedding_layer(token_ids).squeeze(0)

# Single prompt inference
outputs = llm.generate({
    "prompt_embeds": prompt_embeds,
})

for o in outputs:
    generated_text = o.outputs[0].text
    print(generated_text)

# Batch inference
chats = [
    [{"role": "user", "content": "Please tell me about the capital of France."}],
    [{"role": "user", "content": "When is the day longest during the year?"}],
    [{"role": "user", "content": "Which is bigger, the moon or the sun?"}],
]

token_ids_list = [
    tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors='pt') for chat in chats
]
prompt_embeds_list = [embedding_layer(token_ids).squeeze(0) for token_ids in token_ids_list]

outputs = llm.generate(
    [
        {
            "prompt_embeds": prompt_embeds,
        } for prompt_embeds in prompt_embeds_list
    ]
)

for o in outputs:
    generated_text = o.outputs[0].text
    print(generated_text)
```

## Online Serving

Our OpenAI-compatible server accepts prompt embedding inputs via the [Completions API](https://platform.openai.com/docs/api-reference/completions). Prompt embedding inputs are added via a new `'prompt_embeds'` key in the JSON request body.

When a mixture of `'prompt_embeds'` and `'prompt'` inputs is provided in a single request, the outputs for the prompt embeds are always returned first.

Prompt embeddings are passed in as base64-encoded torch tensors (serialized with `torch.save`).

### Transformers Inputs via OpenAI Client

First, launch the OpenAI-compatible server:

```bash
vllm serve meta-llama/Llama-3.2-1B-Instruct --task generate \
  --max-model-len 4096 --enable-prompt-embeds
```

Then, you can use the OpenAI client as follows:

```python
import base64
import io

from openai import OpenAI
import transformers
import torch

openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

model_name = "meta-llama/Llama-3.2-1B-Instruct"

# Transformers
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
transformers_model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
# The learned input embedding layer used to turn token ids into embeddings
embedding_layer = transformers_model.get_input_embeddings()

# Refer to the HuggingFace repo for the correct format to use
chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
token_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors='pt')

prompt_embeds = embedding_layer(token_ids).squeeze(0)

# Serialize the prompt embeddings and base64-encode them
buffer = io.BytesIO()
torch.save(prompt_embeds, buffer)
buffer.seek(0)
binary_data = buffer.read()
encoded_embeds = base64.b64encode(binary_data).decode('utf-8')

completion = client.completions.create(
    model=model_name,
    # NOTE: The OpenAI client does not allow `None` as an input to
    # `prompt`. Use an empty string if you have no text prompts.
    prompt="",
    max_tokens=5,
    temperature=0.0,
    # NOTE: The OpenAI client allows passing in extra JSON body via the
    # `extra_body` argument.
    extra_body={"prompt_embeds": encoded_embeds},
)

print(completion.choices[0].text)
```
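
If you prefer not to use the OpenAI client, the same request can be sent as raw JSON. The sketch below assumes the server launched above is reachable at `localhost:8000` and reuses `encoded_embeds` from the previous example; the `prompt_embeds` key simply rides alongside the usual Completions fields:

```python
import requests

# A minimal raw-HTTP sketch of the same request (illustrative only).
response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "meta-llama/Llama-3.2-1B-Instruct",
        # Empty text prompt alongside the embeddings, mirroring the client example above.
        "prompt": "",
        "prompt_embeds": encoded_embeds,
        "max_tokens": 5,
        "temperature": 0.0,
    },
)
print(response.json()["choices"][0]["text"])
```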