vLLM Model Provider implementation #44
Closed
Changes from 6 commits (15 commits in total):

- 770f8a8 Add files via upload (AhilanPonnusamy)
- ebf20fe Update README.md (AhilanPonnusamy)
- b79d449 Update pyproject.toml (AhilanPonnusamy)
- 6bab061 Add files via upload (AhilanPonnusamy)
- 2137064 Add files via upload (AhilanPonnusamy)
- 5a9c5d2 Update pyproject.toml (AhilanPonnusamy)
- c0e2639 Update vllm.py (AhilanPonnusamy)
- 344749a Update vllm.py (AhilanPonnusamy)
- 086e6f5 Update test_vllm.py (AhilanPonnusamy)
- 624e5cb Update test_model_vllm.py (AhilanPonnusamy)
- d31649e Update README.md (AhilanPonnusamy)
- 7e85e87 Update vllm.py (AhilanPonnusamy)
- 1e4d14c Update test_vllm.py (AhilanPonnusamy)
- 9a1f835 Update test_model_vllm.py (AhilanPonnusamy)
- 35c5d52 Update test_model_vllm.py (AhilanPonnusamy)
vllm.py (new file, +209 lines):

```python
"""vLLM model provider.

- Docs: https://github.com/vllm-project/vllm
"""

import json
import logging
from typing import Any, Iterable, Optional

import requests
from typing_extensions import TypedDict, Unpack, override

from ..types.content import Messages
from ..types.models import Model
from ..types.streaming import StreamEvent
from ..types.tools import ToolSpec

logger = logging.getLogger(__name__)


class VLLMModel(Model):
    """vLLM model provider implementation.

    Assumes an OpenAI-compatible vLLM server at `http://<host>/v1/completions`.

    The implementation handles vLLM-specific features such as:

    - Local model invocation
    - Streaming responses
    - Tool/function calling
    """

    class VLLMConfig(TypedDict, total=False):
        """Configuration parameters for vLLM models.

        Attributes:
            model_id: vLLM model ID (e.g., "meta-llama/Llama-3.2-3B" or "microsoft/Phi-3-mini-128k-instruct").
            temperature: Controls randomness in generation (higher = more random).
            top_p: Controls diversity via nucleus sampling (alternative to temperature).
            max_tokens: Maximum number of tokens to generate in the response.
            stop_sequences: Sequences at which generation stops.
            additional_args: Any additional arguments to include in the request.
        """

        model_id: str
        temperature: Optional[float]
        top_p: Optional[float]
        max_tokens: Optional[int]
        stop_sequences: Optional[list[str]]
        additional_args: Optional[dict[str, Any]]

    def __init__(self, host: str, **model_config: Unpack[VLLMConfig]) -> None:
        """Initialize provider instance.

        Args:
            host: The address of the vLLM server hosting the model.
            **model_config: Configuration options for the vLLM model.
        """
        self.config = VLLMModel.VLLMConfig(**model_config)
        self.host = host.rstrip("/")
        logger.debug("Initializing vLLM provider with config: %s", self.config)

    @override
    def update_config(self, **model_config: Unpack[VLLMConfig]) -> None:
        """Update the vLLM model configuration with the provided arguments.

        Args:
            **model_config: Configuration overrides.
        """
        self.config.update(model_config)

    @override
    def get_config(self) -> VLLMConfig:
        """Get the vLLM model configuration.

        Returns:
            The vLLM model configuration.
        """
        return self.config

    @override
    def format_request(
        self,
        messages: Messages,
        tool_specs: Optional[list[ToolSpec]] = None,
        system_prompt: Optional[str] = None,
    ) -> dict[str, Any]:
        """Format a vLLM completions request.

        Args:
            messages: List of message objects to be processed by the model.
            tool_specs: List of tool specifications to make available to the model.
            system_prompt: System prompt to provide context to the model.

        Returns:
            A vLLM completions request payload.
        """
        # Concatenate messages to form a prompt string
        prompt_parts = [
            f"{msg['role']}: {content['text']}" for msg in messages for content in msg["content"] if "text" in content
        ]
        if system_prompt:
            prompt_parts.insert(0, f"system: {system_prompt}")
        prompt = "\n".join(prompt_parts) + "\nassistant:"

        payload = {
            "model": self.config["model_id"],
            "prompt": prompt,
            "temperature": self.config.get("temperature", 0.7),
            "top_p": self.config.get("top_p", 1.0),
            "max_tokens": self.config.get("max_tokens", 128),
            "stop": self.config.get("stop_sequences"),
            "stream": False,  # Streaming is enabled later, in stream()
        }

        if self.config.get("additional_args"):
            payload.update(self.config["additional_args"])

        return payload

    @override
    def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
        """Format the vLLM response events into standardized message chunks.

        Args:
            event: A response event from the vLLM model.

        Returns:
            The formatted chunk.
        """
        choice = event.get("choices", [{}])[0]

        if "text" in choice:
            return {"contentBlockDelta": {"delta": {"text": choice["text"]}}}

        if "finish_reason" in choice:
            return {"messageStop": {"stopReason": choice["finish_reason"] or "end_turn"}}

        return {}

    @override
    def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
        """Send the request to the vLLM model and get the streaming response.

        This method calls the /v1/completions endpoint and returns the stream of response events.

        Args:
            request: The formatted request to send to the vLLM model.

        Returns:
            An iterable of response events from the vLLM model.
        """
        headers = {"Content-Type": "application/json"}
        url = f"{self.host}/v1/completions"
        request["stream"] = True  # Enable streaming

        full_output = ""

        try:
            with requests.post(url, headers=headers, data=json.dumps(request), stream=True) as response:
                if response.status_code != 200:
                    logger.error("vLLM server error: %d - %s", response.status_code, response.text)
                    raise Exception(f"Request failed: {response.status_code} - {response.text}")

                yield {"chunk_type": "message_start"}
                yield {"chunk_type": "content_start", "data_type": "text"}

                for line in response.iter_lines(decode_unicode=True):
                    if not line:
                        continue

                    if line.startswith("data: "):
                        line = line[len("data: ") :]

                    if line.strip() == "[DONE]":
                        break

                    try:
                        data = json.loads(line)
                        choice = data.get("choices", [{}])[0]
                        text = choice.get("text", "")
                        finish_reason = choice.get("finish_reason")

                        if text:
                            full_output += text
                            print(text, end="", flush=True)  # Stream to stdout without a newline
                            yield {
                                "chunk_type": "content_delta",
                                "data_type": "text",
                                "data": text,
                            }

                        if finish_reason:
                            yield {"chunk_type": "content_stop", "data_type": "text"}
                            yield {"chunk_type": "message_stop", "data": finish_reason}
                            break

                    except json.JSONDecodeError:
                        logger.warning("Failed to decode streamed line: %s", line)

                else:
                    yield {"chunk_type": "content_stop", "data_type": "text"}
                    yield {"chunk_type": "message_stop", "data": "end_turn"}

        except requests.RequestException as e:
            logger.error("Request to vLLM failed: %s", str(e))
            raise Exception("Failed to reach vLLM server") from e
```
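As a usage sketch (not part of the diff), the provider above can be driven directly without an `Agent`; the host, model ID, and message content below are placeholders, assuming an OpenAI-compatible vLLM server is already running:

```python
from strands.models.vllm import VLLMModel

model = VLLMModel(
    host="http://localhost:8000",
    model_id="meta-llama/Llama-3.2-3B",
    max_tokens=64,
    temperature=0.2,
)

messages = [{"role": "user", "content": [{"text": "Name one planet in our solar system."}]}]
request = model.format_request(messages, system_prompt="You are a concise assistant.")

# stream() yields framework-level chunk events; collect just the text deltas.
completion = "".join(
    event["data"] for event in model.stream(request) if event.get("chunk_type") == "content_delta"
)
print(completion)
```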
Integration test (new file, +38 lines):

```python
import pytest

import strands
from strands import Agent
from strands.models.vllm import VLLMModel


@pytest.fixture
def model():
    return VLLMModel(
        model_id="meta-llama/Llama-3.2-3B",  # or whatever your model ID is
        host="http://localhost:8000",  # adjust as needed
        max_tokens=128,
    )


@pytest.fixture
def tools():
    @strands.tool
    def tool_time() -> str:
        return "12:00"

    @strands.tool
    def tool_weather() -> str:
        return "cloudy"

    return [tool_time, tool_weather]


@pytest.fixture
def agent(model, tools):
    return Agent(model=model, tools=tools)


def test_agent(agent):
    result = agent("What is the time and weather in Melbourne, Australia?")
    text = result.message["content"][0]["text"].lower()

    assert all(string in text for string in ["12:00", "cloudy"])  # match the tool outputs
```
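Since the agent test above depends on a live server and non-deterministic model output, a deterministic companion test of `format_request` could run without any server. A sketch, not part of this PR, reusing the `VLLMModel` import from the test module above; field names mirror the implementation:

```python
def test_format_request_builds_completion_payload():
    model = VLLMModel(
        host="http://localhost:8000",
        model_id="meta-llama/Llama-3.2-3B",
        max_tokens=64,
    )
    messages = [{"role": "user", "content": [{"text": "hello"}]}]

    request = model.format_request(messages, system_prompt="be brief")

    assert request["model"] == "meta-llama/Llama-3.2-3B"
    assert request["prompt"] == "system: be brief\nuser: hello\nassistant:"
    assert request["max_tokens"] == 64
    assert request["stream"] is False
```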
pgrayy: Since vLLM assumes OpenAI compatibility (docs), I think it actually makes sense for us to generalize here and, instead of creating a `VLLMModel` provider, create an `OpenAIModel` provider. This is something the Strands Agents team is already discussing, as we have other OpenAI-compatible providers that could all share the `format_request` and `format_chunk` logic.

With that said, I would suggest keeping this PR open for the time being as we further work out the details. Thank you for your contribution and patience.
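For illustration only, one possible shape of that generalization, sketched with hypothetical names (this is not the Strands team's actual design, just an outline of how shared `format_request`/`format_chunk` logic could be factored):

```python
from typing import Any, Optional


class OpenAICompatibleModel:
    """Hypothetical shared base: request/chunk formatting written once against the
    OpenAI chat-completions schema; concrete providers differ only in connection details."""

    def __init__(self, base_url: str, model_id: str, api_key: Optional[str] = None) -> None:
        self.base_url = base_url.rstrip("/")
        self.model_id = model_id
        self.api_key = api_key

    def format_request(self, messages: list[dict[str, Any]]) -> dict[str, Any]:
        # Shared across every OpenAI-compatible backend.
        return {"model": self.model_id, "messages": messages, "stream": True}

    def format_chunk(self, event: dict[str, Any]) -> dict[str, Any]:
        # Shared translation of an OpenAI-style streaming chunk into a content delta.
        delta = event.get("choices", [{}])[0].get("delta", {})
        return {"contentBlockDelta": {"delta": {"text": delta.get("content", "")}}}


class VLLMOpenAIModel(OpenAICompatibleModel):
    """A vLLM flavor would then only pin the default endpoint of a local server."""

    def __init__(self, model_id: str, host: str = "http://localhost:8000") -> None:
        super().__init__(base_url=f"{host}/v1", model_id=model_id, api_key="EMPTY")
```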
AhilanPonnusamy: Got it @pgrayy, will wait for the updates. The one thing we would miss with this approach is support for the native vLLM APIs via the native vLLM endpoint, which is designed for direct use with the vLLM engine. Do you foresee just that aspect becoming a vLLM model provider by itself? I see this `VLLMModel` provider gradually being built out to cover all aspects of vLLM, including the OpenAI-compatible endpoints.
pgrayy: Here is the PR for the OpenAI model provider. This should work for connecting to models served with vLLM.

Regarding your question, could you elaborate on what you mean by the native vLLM endpoint? It should still be OpenAI compatible, correct? Based on what I am reading in the docs, you can query vLLM using the `openai` client, which is what this new Strands OpenAI model provider uses under the hood.
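For reference, querying a locally served vLLM model with the stock `openai` client looks roughly like the sketch below (the model name, port, and placeholder key are assumptions, not values from this PR):

```python
from openai import OpenAI

# vLLM's OpenAI-compatible server does not validate the key by default,
# but the openai client still requires a non-empty value.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="Qwen/Qwen2.5-7B-Instruct",
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
)
print(response.choices[0].message.content)
```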
AhilanPonnusamy: Thank you @pgrayy. Regarding vLLM endpoints, there are two endpoints provided by vLLM.

`openai.api_server` in vLLM provides an OpenAI-compatible REST API layer, mimicking the behavior and request/response format of OpenAI's API (`/v1/chat/completions`, etc.). This layer supports both standard and streaming completion modes.

`api_server`, on the other hand, is vLLM's native API server, offering endpoints like `/generate` or `/completion`, designed specifically for internal or custom integrations. While `api_server` is more flexible and may expose additional low-level features, `openai.api_server` ensures broader compatibility with the OpenAI ecosystem.

You can run either one by launching the server with the appropriate entry point, `vllm.entrypoints.openai.api_server` or `vllm.entrypoints.api_server`.
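For illustration only, a minimal sketch of how the two request shapes differ (endpoint paths and payload fields are assumed from vLLM's demo `api_server` and OpenAI-compatible server, not taken from this PR):

```python
import requests

prompt = "What is the capital of Australia?"

# OpenAI-compatible server (vllm.entrypoints.openai.api_server): /v1/completions
# expects a model name and returns OpenAI-style "choices".
openai_style = requests.post(
    "http://localhost:8000/v1/completions",
    json={"model": "meta-llama/Llama-3.2-3B", "prompt": prompt, "max_tokens": 32},
).json()
print(openai_style["choices"][0]["text"])

# Native demo server (vllm.entrypoints.api_server): /generate takes the prompt plus
# raw sampling parameters and returns the generated text directly.
native_style = requests.post(
    "http://localhost:8000/generate",
    json={"prompt": prompt, "max_tokens": 32, "temperature": 0.0},
).json()
print(native_style["text"][0])
```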
On another note, I tried to test the OpenAI model provider from the PR against my local vLLM instance. I couldn't get past the API key stage even after setting the environment variable:

```
openai.AuthenticationError: Error code: 401 - {'error': {'message': 'Incorrect API key provided: empty. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}
```

As you suggested, shall we keep this vLLM Model Provider PR open until we get the OpenAI one working, or would it be better to merge it for now and deprecate it later if it's deemed redundant?
pgrayy: Could you share more details on your testing? What happens when you try the following:
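(The snippet attached here did not survive the page capture. Based on the later note that `api_key` belongs in `client_args`, it presumably resembled the sketch below; the import path and constructor signature are assumptions, not a quote of the original.)

```python
from strands import Agent
from strands.models.openai import OpenAIModel  # import path assumed

model = OpenAIModel(
    client_args={
        "api_key": "EMPTY",                      # vLLM ignores the key, but the client requires one
        "base_url": "http://localhost:8000/v1",  # point the OpenAI client at the local vLLM server
    },
    model_id="Qwen/Qwen3-4B",
)

agent = Agent(model=model)
agent("What is 2 + 2?")
```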
`api_key` will need to be explicitly passed into the model provider unless you set the `OPENAI_API_KEY` environment variable.
AhilanPonnusamy: No luck @pgrayy, please find the details below.

The call:

```python
OpenAIModel(
    api_key="abc123",
    base_url="http://localhost:8000",
    model_id="Qwen/Qwen3-4B",  # Qwen/Qwen3-8B
)
```

The error:

```
openai.AuthenticationError: Error code: 401 - {'error': {'message': 'Incorrect API key provided: abc123. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}
```

Also, I think the way tool_calls are emitted will differ between model providers. I had to convert the tool_call to a function call with a different format to make it work for the vLLM Model Provider implementation, as you can see in my code.
pgrayy: You will need to pass in `api_key` to the `client_args`. Here are some steps I took for testing.

Set up a local vLLM server:

```
$ python -m vllm.entrypoints.openai.api_server \
    --model Qwen/Qwen2.5-7B-Instruct \
    --host 0.0.0.0 \
    --port 8000 \
    --enable-auto-tool-choice \
    --tool-call-parser hermes \
    --max-model-len 4096
```

I set `hermes` as the tool parser based on the instructions here. This seems to be the mechanism that makes tool calls OpenAI compatible, so we shouldn't need any special handling in our code.

Next, I set up my agent script, passing `api_key` in to the `client_args`.

Result:

```
$ python test_vllm.py

Tool #1: calculator
The result of 2 + 2 is 4.
```

For this to work, you'll also need to include the changes in #97 (already merged into main).

As an alternative, you should also be able to get this working with the `LiteLLMModel` provider. The LiteLLM docs have a dedicated page on vLLM (see here).
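For completeness, at the `litellm` library level (independent of the Strands `LiteLLMModel` wrapper, whose exact arguments are not shown here), routing a request to a local vLLM server looks roughly like the following; the model name, port, and placeholder key are assumptions:

```python
import litellm

# "openai/<model>" tells litellm to speak the OpenAI protocol to api_base.
response = litellm.completion(
    model="openai/Qwen/Qwen2.5-7B-Instruct",
    api_base="http://localhost:8000/v1",
    api_key="EMPTY",
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
)
print(response.choices[0].message.content)
```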