Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 0 additions & 59 deletions tests/test_vllm_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,65 +152,6 @@ def test_reset_prefix_cache(self):
# Test resetting the prefix cache
self.client.reset_prefix_cache()

def test_chat_completions_endpoint(self):
data = self.client.chat_completions(
messages=[{"role": "user", "content": "Say hello"}],
max_tokens=32,
)

assert "id" in data
assert "choices" in data
assert "usage" in data
assert len(data["choices"]) > 0
assert data["choices"][0]["message"]["role"] == "assistant"
assert data["choices"][0]["finish_reason"] in ["stop", "length", "tool_calls"]

def test_chat_completions_with_tools(self):
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather information for a location",
"parameters": {"type": "object", "properties": {"location": {"type": "string"}}},
},
}
]
data = self.client.chat_completions(
messages=[{"role": "user", "content": "What's the weather in San Francisco?"}],
tools=tools,
max_tokens=100,
)

assert "choices" in data
assert len(data["choices"]) > 0
assert "message" in data["choices"][0]

def test_chat_completions_with_params(self):
data = self.client.chat_completions(
messages=[{"role": "user", "content": "Tell me a joke"}],
n=2,
temperature=0.8,
top_p=0.9,
max_tokens=32,
)

assert len(data["choices"]) == 2

for i, choice in enumerate(data["choices"]):
assert choice["index"] == i, f"Expected choice at position {i} to have index {i}, got {choice['index']}"
assert "message" in choice
assert choice["message"]["role"] == "assistant"

def test_tokenize_endpoint(self):
data = self.client.tokenize(messages=[{"role": "user", "content": "Hello, how are you?"}])

assert "tokens" in data
assert "model" in data
assert isinstance(data["tokens"], list)
assert len(data["tokens"]) > 0
assert all(isinstance(tok, int) for tok in data["tokens"])

@pytest.mark.xfail(reason="Importing `bitsandbytes` causes issues, see vllm-project/vllm#32793")
def test_logprobs_match_with_non_default_sampling(self):
prompts = ["Hello, AI!", "Tell me a joke"]
Expand Down
76 changes: 0 additions & 76 deletions trl/generation/vllm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,82 +514,6 @@ def reset_prefix_cache(self):
if response.status_code != 200:
raise Exception(f"Request failed: {response.status_code}, {response.text}")

def chat_completions(
self,
messages: list[dict],
model: str | None = None,
temperature: float = 1.0,
top_p: float = 1.0,
max_tokens: int | None = None,
n: int = 1,
tools: list[dict] | None = None,
**kwargs,
) -> dict:
"""
OpenAI-compatible chat completions endpoint.

Args:
messages (`list[dict]`):
List of messages in OpenAI format with "role" and "content" keys.
model (`str`, *optional*):
Model name to use.
temperature (`float`, *optional*, defaults to `1.0`):
Temperature for sampling.
top_p (`float`, *optional*, defaults to `1.0`):
Top-p sampling parameter.
max_tokens (`int`, *optional*):
Maximum number of tokens to generate.
n (`int`, *optional*, defaults to `1`):
Number of completions to generate.
tools (`list[dict]`, *optional*):
List of tool definitions for tool calling.
**kwargs:
Additional parameters to pass to the endpoint.

Returns:
`dict`:
OpenAI-compatible response with "choices", "usage", etc.
"""
url = f"{self.base_url}/v1/chat/completions"
response = self.session.post(
url,
json={
"messages": messages,
"model": model,
"temperature": temperature,
"top_p": top_p,
"max_tokens": max_tokens,
"n": n,
"tools": tools,
**kwargs,
},
)
if response.status_code == 200:
return response.json()
else:
raise Exception(f"Request failed: {response.status_code}, {response.text}")

def tokenize(self, messages: list[dict], tools: list[dict] | None = None) -> dict:
"""
Tokenize messages to get token IDs.

Args:
messages (`list[dict]`):
List of messages to tokenize.
tools (`list[dict]`, *optional*):
List of tool definitions.

Returns:
`dict`:
Dictionary with "tokens" (list of token IDs) and "model" keys.
"""
url = f"{self.base_url}/tokenize"
response = self.session.post(url, json={"messages": messages, "tools": tools})
if response.status_code == 200:
return response.json()
else:
raise Exception(f"Request failed: {response.status_code}, {response.text}")

def close_communicator(self):
"""
Closes the weight update group and cleans up the communication group.
Expand Down
Loading
Loading