9 changes: 9 additions & 0 deletions README.md
@@ -107,6 +107,7 @@ from strands import Agent
from strands.models import BedrockModel
from strands.models.ollama import OllamaModel
from strands.models.llamaapi import LlamaAPIModel
from strands.models.vllm import VLLMModel

# Bedrock
bedrock_model = BedrockModel(
@@ -130,6 +131,14 @@ llama_model = LlamaAPIModel(
)
agent = Agent(model=llama_model)
response = agent("Tell me about Agentic AI")

# vLLM
vllm_model = VLLMModel(
    host="http://localhost:8000",
    model_id="meta-llama/Llama-3.2-3B"
)
agent_vllm = Agent(model=vllm_model)
agent_vllm("Tell me about Agentic AI")
```

Built-in providers:
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -75,9 +75,12 @@ ollama = [
llamaapi = [
"llama-api-client>=0.1.0,<1.0.0",
]
vllm = [
"vllm>=0.8.5",
]

[tool.hatch.envs.hatch-static-analysis]
features = ["anthropic", "litellm", "llamaapi", "ollama"]
features = ["anthropic", "litellm", "llamaapi", "ollama", "vllm"]
dependencies = [
"mypy>=1.15.0,<2.0.0",
"ruff>=0.11.6,<0.12.0",
209 changes: 209 additions & 0 deletions src/strands/models/vllm.py
@@ -0,0 +1,209 @@
"""vLLM model provider.

- Docs: https://github.com/vllm-project/vllm
"""

import json
import logging
from typing import Any, Iterable, Optional

import requests
from typing_extensions import TypedDict, Unpack, override

from ..types.content import Messages
from ..types.models import Model
from ..types.streaming import StreamEvent
from ..types.tools import ToolSpec

logger = logging.getLogger(__name__)


class VLLMModel(Model):
"""vLLM model provider implementation.

Assumes OpenAI-compatible vLLM server at `http://<host>/v1/completions`.
pgrayy (Member) commented on May 19, 2025:
Since vLLM assumes OpenAI compatibility (docs), I think it actually makes sense for us to generalize here: instead of creating a VLLMModel provider, we create an OpenAIModel provider. This is something the Strands Agents team is already discussing, as we have other OpenAI-compatible providers that could all share the format_request and format_chunk logic.

With that said, I would suggest keeping this PR open for the time being as we further work out the details. Thank you for your contribution and patience.

AhilanPonnusamy (Author) commented:

Got it @pgrayy, will wait for the updates. The one thing we will miss out on with this approach is support for the native vLLM API via the native vLLM endpoint, designed for direct use with the vLLM engine. Do you foresee just that aspect becoming a vLLM model provider by itself? I see this VLLMModel provider gradually built out to cover all aspects of vLLM, including the OpenAI-compatible endpoints.

pgrayy (Member) commented:
Here is the PR for the OpenAI model provider. This should work to connect to models served with vLLM.

Regarding your question, could you elaborate on what you mean by the native vLLM endpoint? It should still be OpenAI compatible, correct? Based on what I am reading in the docs, you can query vLLM using the openai client, which is what this new Strands OpenAI model provider uses under the hood.
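[Editor's note] As a quick illustration of that point, here is a minimal sketch of querying a local vLLM server with the plain openai client; the base URL, model ID, and placeholder API key are assumptions for a local setup and are not part of this PR:

from openai import OpenAI

# A vLLM server started without --api-key accepts any placeholder key
client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")
response = client.chat.completions.create(
    model="Qwen/Qwen2.5-7B-Instruct",
    messages=[{"role": "user", "content": "Tell me about Agentic AI"}],
)
print(response.choices[0].message.content)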

AhilanPonnusamy (Author) commented on May 22, 2025:
Thank you @pgrayy. Regarding vLLM endpoints, there are two provided by vLLM:

  1. openai.api_server provides an OpenAI-compatible REST API layer, mimicking the behavior and request/response format of OpenAI's API (/v1/chat/completions, etc.). This layer supports both standard and streaming completion modes.

  2. api_server, on the other hand, is vLLM's native API server, offering endpoints like /generate or /completion, designed specifically for internal or custom integrations. While api_server is more flexible and may expose additional low-level features, openai.api_server ensures broader compatibility with the OpenAI ecosystem.

You can run either one by launching the corresponding entrypoint (vllm.entrypoints.openai.api_server or vllm.entrypoints.api_server) when starting the server.
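[Editor's note] For reference, a sketch of the two launch commands implied above; the model ID and port are placeholders, and the flags mirror the standard vLLM CLI options used later in this thread:

$ python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.2-3B --host 0.0.0.0 --port 8000
$ python -m vllm.entrypoints.api_server --model meta-llama/Llama-3.2-3B --host 0.0.0.0 --port 8000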

On another note, I tried to test the OpenAI model provider from the PR against my local vLLM instance. I couldn't get past the API key stage even after setting the env variable: "openai.AuthenticationError: Error code: 401 - {'error': {'message': 'Incorrect API key provided: empty. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}".

As you suggested, shall we keep this vLLM Model Provider PR open until we get the OpenAI one working, or would it be better to merge it for now and deprecate it later if it's deemed redundant?

pgrayy (Member) commented on May 22, 2025:

Could you share more details on your testing? What happens when you try the following:

from strands.models.openai import OpenAIModel

openai_model = OpenAIModel({"api_key": "<YOUR_API_KEY>", "base_url": "<YOUR_MODEL_ENDPOINT>"}, model_id="<YOUR_MODEL_ID>")

api_key will need to be explicitly passed into the model provider unless you set the OPENAI_API_KEY environment variable.

AhilanPonnusamy (Author) commented on May 22, 2025:

No luck @pgrayy, please find the details below.

Call:

OpenAIModel(
    api_key="abc123",
    base_url="http://localhost:8000",
    model_id="Qwen/Qwen3-4B",  # Qwen/Qwen3-8B
)

Error:

openai.AuthenticationError: Error code: 401 - {'error': {'message': 'Incorrect API key provided: abc123. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}

Also, I think the way the tool_calls are emitted will be different for different model providers. I had to convert the tool_call to a function call with a different format to make it work for the vLLM model provider implementation, as you can see in my code.

pgrayy (Member) commented on May 23, 2025:

You will need to pass in api_key to the client_args. Here are some steps I took for testing:

Set up a local vLLM server:

$ python -m vllm.entrypoints.openai.api_server \
    --model Qwen/Qwen2.5-7B-Instruct \
    --host 0.0.0.0 \
    --port 8000 \
    --enable-auto-tool-choice \
    --tool-call-parser hermes \
    --max-model-len 4096

I set hermes as a tool parser based on the instructions here. This seems to be the mechanism to make tool calls OpenAI-compatible, so we shouldn't need any special handling in our code.

Next, I set up my agent script, passing api_key in the client_args.

from strands import Agent
from strands.models.openai import OpenAIModel
from strands_tools import calculator

model = OpenAIModel(
    # can also pass dict as first argument
    client_args={"api_key": "abc123", "base_url": "http://localhost:8000/v1"},
    # everything from this point is a kwarg though
    model_id="Qwen/Qwen2.5-7B",
)

agent = Agent(model=model, tools=[calculator])
agent("What is 2+2?")

Result:

$ python test_vllm.py
Tool #1: calculator
The result of 2 + 2 is 4.

For this to work, you'll also need to include changes in #97 (already merged into main).

As an alternative, you should also be able to get this working with the LiteLLMModel provider. The LiteLLM docs have a dedicated page on vllm (see here).
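[Editor's note] For completeness, a rough LiteLLM-based sanity check against the same local server (an untested sketch: the hosted_vllm/ model prefix and api_base argument follow LiteLLM's vLLM docs; wiring this through the Strands LiteLLMModel provider would then follow the same pattern as the OpenAIModel example above):

from litellm import completion

response = completion(
    model="hosted_vllm/Qwen/Qwen2.5-7B-Instruct",  # LiteLLM's prefix for a self-hosted vLLM endpoint
    messages=[{"role": "user", "content": "What is 2+2?"}],
    api_base="http://localhost:8000/v1",
)
print(response.choices[0].message.content)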


    The implementation handles vLLM-specific features such as:

    - Local model invocation
    - Streaming responses
    - Tool/function calling
    """

    class VLLMConfig(TypedDict, total=False):
        """Configuration parameters for vLLM models.

        Attributes:
            additional_args: Any additional arguments to include in the request.
            max_tokens: Maximum number of tokens to generate in the response.
            model_id: vLLM model ID (e.g., "meta-llama/Llama-3.2-3B", "microsoft/Phi-3-mini-128k-instruct").
            stop_sequences: List of sequences that stop generation when produced.
            temperature: Controls randomness in generation (higher = more random).
            top_p: Controls diversity via nucleus sampling (alternative to temperature).
        """

        model_id: str
        temperature: Optional[float]
        top_p: Optional[float]
        max_tokens: Optional[int]
        stop_sequences: Optional[list[str]]
        additional_args: Optional[dict[str, Any]]

    def __init__(self, host: str, **model_config: Unpack[VLLMConfig]) -> None:
        """Initialize provider instance.

        Args:
            host: The address of the vLLM server hosting the model.
            **model_config: Configuration options for the vLLM model.
        """
        self.config = VLLMModel.VLLMConfig(**model_config)
        self.host = host.rstrip("/")
        logger.debug("Initializing vLLM provider with config: %s", self.config)

    @override
    def update_config(self, **model_config: Unpack[VLLMConfig]) -> None:
        """Update the vLLM model configuration with the provided arguments.

        Args:
            **model_config: Configuration overrides.
        """
        self.config.update(model_config)

    @override
    def get_config(self) -> VLLMConfig:
        """Get the vLLM model configuration.

        Returns:
            The vLLM model configuration.
        """
        return self.config

    @override
    def format_request(
        self,
        messages: Messages,
        tool_specs: Optional[list[ToolSpec]] = None,
        system_prompt: Optional[str] = None,
    ) -> dict[str, Any]:
        """Format a vLLM completion request.

        Args:
            messages: List of message objects to be processed by the model.
            tool_specs: List of tool specifications to make available to the model.
            system_prompt: System prompt to provide context to the model.

        Returns:
            A vLLM /v1/completions request payload.
        """
        # Concatenate messages to form a prompt string
        prompt_parts = [
            f"{msg['role']}: {content['text']}" for msg in messages for content in msg["content"] if "text" in content
        ]
        if system_prompt:
            prompt_parts.insert(0, f"system: {system_prompt}")
        prompt = "\n".join(prompt_parts) + "\nassistant:"

        payload = {
            "model": self.config["model_id"],
            "prompt": prompt,
            "temperature": self.config.get("temperature", 0.7),
            "top_p": self.config.get("top_p", 1.0),
            "max_tokens": self.config.get("max_tokens", 128),
            "stop": self.config.get("stop_sequences"),
            "stream": False,  # Streaming is enabled later in stream()
        }

        if self.config.get("additional_args"):
            payload.update(self.config["additional_args"])

        return payload

    @override
    def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
        """Format the vLLM response events into standardized message chunks.

        Args:
            event: A response event from the vLLM model.

        Returns:
            The formatted chunk.
        """
        choice = event.get("choices", [{}])[0]

        if "text" in choice:
            return {"contentBlockDelta": {"delta": {"text": choice["text"]}}}

        if "finish_reason" in choice:
            return {"messageStop": {"stopReason": choice["finish_reason"] or "end_turn"}}

        return {}

    @override
    def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
        """Send the request to the vLLM model and get the streaming response.

        This method calls the /v1/completions endpoint and returns the stream of response events.

        Args:
            request: The formatted request to send to the vLLM model.

        Returns:
            An iterable of response events from the vLLM model.
        """
        headers = {"Content-Type": "application/json"}
        url = f"{self.host}/v1/completions"
        request["stream"] = True  # Enable streaming

        full_output = ""

        try:
            with requests.post(url, headers=headers, data=json.dumps(request), stream=True) as response:
                if response.status_code != 200:
                    logger.error("vLLM server error: %d - %s", response.status_code, response.text)
                    raise Exception(f"Request failed: {response.status_code} - {response.text}")

                yield {"chunk_type": "message_start"}
                yield {"chunk_type": "content_start", "data_type": "text"}

                for line in response.iter_lines(decode_unicode=True):
                    if not line:
                        continue

                    if line.startswith("data: "):
                        line = line[len("data: ") :]

                    if line.strip() == "[DONE]":
                        break

                    try:
                        data = json.loads(line)
                        choice = data.get("choices", [{}])[0]
                        text = choice.get("text", "")
                        finish_reason = choice.get("finish_reason")

                        if text:
                            full_output += text
                            print(text, end="", flush=True)  # Stream to stdout without newline
                            yield {
                                "chunk_type": "content_delta",
                                "data_type": "text",
                                "data": text,
                            }

                        if finish_reason:
                            yield {"chunk_type": "content_stop", "data_type": "text"}
                            yield {"chunk_type": "message_stop", "data": finish_reason}
                            break

                    except json.JSONDecodeError:
                        logger.warning("Failed to decode streamed line: %s", line)

                else:
                    yield {"chunk_type": "content_stop", "data_type": "text"}
                    yield {"chunk_type": "message_stop", "data": "end_turn"}

        except requests.RequestException as e:
            logger.error("Request to vLLM failed: %s", str(e))
            raise Exception("Failed to reach vLLM server") from e
38 changes: 38 additions & 0 deletions tests-integ/test_model_vllm.py
@@ -0,0 +1,38 @@
import pytest
import strands
from strands import Agent
from strands.models.vllm import VLLMModel


@pytest.fixture
def model():
    return VLLMModel(
        model_id="meta-llama/Llama-3.2-3B",  # or whatever your model ID is
        host="http://localhost:8000",  # adjust as needed
        max_tokens=128,
    )


@pytest.fixture
def tools():
    @strands.tool
    def tool_time() -> str:
        return "12:00"

    @strands.tool
    def tool_weather() -> str:
        return "cloudy"

    return [tool_time, tool_weather]


@pytest.fixture
def agent(model, tools):
    return Agent(model=model, tools=tools)


def test_agent(agent):
    result = agent("What is the time and weather in Melbourne, Australia?")
    text = result.message["content"][0]["text"].lower()

    assert all(string in text for string in ["12:00", "cloudy"])