
Fix chat completion url for OpenAI compatibility #2418

Merged (1 commit, Jul 29, 2024)

This PR normalizes how `chat_completion` builds the request URL: base URLs passed with or without a trailing `/v1` (and with or without a trailing slash) now all resolve to `<base>/v1/chat/completions`, matching what OpenAI clients send. Fixes https://github.com/huggingface/huggingface_hub/issues/2414.
7 changes: 5 additions & 2 deletions src/huggingface_hub/inference/_client.py
```diff
@@ -816,13 +816,16 @@ def chat_completion(
         # First, resolve the model chat completions URL
         if model == self.base_url:
             # base_url passed => add server route
-            model_url = model + "/v1/chat/completions"
+            model_url = model.rstrip("/")
+            if not model_url.endswith("/v1"):
+                model_url += "/v1"
+            model_url += "/chat/completions"
         elif is_url:
             # model is a URL => use it directly
             model_url = model
         else:
             # model is a model ID => resolve it + add server route
-            model_url = self._resolve_url(model) + "/v1/chat/completions"
+            model_url = self._resolve_url(model).rstrip("/") + "/v1/chat/completions"
 
         # `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
         # If it's a ID on the Hub => use it. Otherwise, we use a random string.
```
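To make the new behavior concrete, here is a minimal standalone sketch of the normalization this hunk introduces. The helper name `build_chat_completions_url` is hypothetical and exists only for illustration; in the library the logic is inlined in `chat_completion`.

```python
# Sketch of the URL normalization above (hypothetical helper name; in
# huggingface_hub the logic is inlined in `chat_completion`).
def build_chat_completions_url(base_url: str) -> str:
    url = base_url.rstrip("/")  # drop any trailing slash
    if not url.endswith("/v1"):  # tolerate URLs that already include /v1
        url += "/v1"
    return url + "/chat/completions"


# All four spellings resolve to the same route:
for base in [
    "http://0.0.0.0:8080",
    "http://0.0.0.0:8080/",
    "http://0.0.0.0:8080/v1",
    "http://0.0.0.0:8080/v1/",
]:
    assert build_chat_completions_url(base) == "http://0.0.0.0:8080/v1/chat/completions"
```

The same change is mirrored in the generated async client below.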
7 changes: 5 additions & 2 deletions src/huggingface_hub/inference/_generated/_async_client.py
```diff
@@ -822,13 +822,16 @@ async def chat_completion(
         # First, resolve the model chat completions URL
         if model == self.base_url:
             # base_url passed => add server route
-            model_url = model + "/v1/chat/completions"
+            model_url = model.rstrip("/")
+            if not model_url.endswith("/v1"):
+                model_url += "/v1"
+            model_url += "/chat/completions"
         elif is_url:
             # model is a URL => use it directly
             model_url = model
         else:
             # model is a model ID => resolve it + add server route
-            model_url = self._resolve_url(model) + "/v1/chat/completions"
+            model_url = self._resolve_url(model).rstrip("/") + "/v1/chat/completions"
 
         # `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
         # If it's a ID on the Hub => use it. Otherwise, we use a random string.
```
21 changes: 21 additions & 0 deletions tests/test_inference_client.py
```diff
@@ -916,6 +916,27 @@ def test_model_and_base_url_mutually_exclusive(self):
         InferenceClient(model="meta-llama/Meta-Llama-3-8B-Instruct", base_url="http://127.0.0.1:8000")
 
 
+@pytest.mark.parametrize(
+    "base_url",
+    [
+        "http://0.0.0.0:8080/v1",  # expected from OpenAI client
+        "http://0.0.0.0:8080",  # but not mandatory
+        "http://0.0.0.0:8080/v1/",  # ok with trailing '/' as well
+        "http://0.0.0.0:8080/",  # ok with trailing '/' as well
+    ],
+)
+def test_chat_completion_base_url_works_with_v1(base_url: str):
+    """Test that `/v1/chat/completions` is correctly appended to the base URL.
+
+    This is a regression test for https://github.com/huggingface/huggingface_hub/issues/2414
+    """
+    with patch("huggingface_hub.inference._client.InferenceClient.post") as post_mock:
+        client = InferenceClient(base_url=base_url)
+        post_mock.return_value = "{}"
+        client.chat_completion(messages=CHAT_COMPLETION_MESSAGES, stream=False)
+        assert post_mock.call_args_list[0].kwargs["model"] == "http://0.0.0.0:8080/v1/chat/completions"
+
+
 def test_stream_text_generation_response():
     data = [
         b'data: {"index":1,"token":{"id":4560,"text":" trying","logprob":-2.078125,"special":false},"generated_text":null,"details":null}',
```
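In practice, the regression test above corresponds to client usage like the following sketch. It assumes an OpenAI-compatible server (for example, text-generation-inference) listening on `http://0.0.0.0:8080`; the prompt and `max_tokens` value are illustrative.

```python
from huggingface_hub import InferenceClient

# With this fix, both spellings below build the same request URL; OpenAI-style
# base URLs that already end in /v1 no longer yield ".../v1/v1/chat/completions".
client = InferenceClient(base_url="http://0.0.0.0:8080/v1")
# client = InferenceClient(base_url="http://0.0.0.0:8080")  # equivalent

response = client.chat_completion(
    messages=[{"role": "user", "content": "Hello!"}],
    max_tokens=64,
    stream=False,
)
print(response.choices[0].message.content)
```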