Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions litellm/llms/custom_httpx/http_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,21 @@
except Exception:
version = "0.0.0"

headers = {
"User-Agent": f"litellm/{version}",
}
def get_default_headers() -> dict:
"""
Get default headers for HTTP requests.

- Default: `User-Agent: litellm/{version}`
- Override: set `LITELLM_USER_AGENT` to fully override the header value.
"""
user_agent = os.environ.get("LITELLM_USER_AGENT")
if user_agent is not None:
return {"User-Agent": user_agent}

return {"User-Agent": f"litellm/{version}"}

# Initialize headers (User-Agent)
headers = get_default_headers()

# https://www.python-httpx.org/advanced/timeouts
_DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)
Expand Down Expand Up @@ -371,13 +383,16 @@ def create_client(
shared_session=shared_session,
)

# Get default headers (User-Agent, overridable via LITELLM_USER_AGENT)
default_headers = get_default_headers()

return httpx.AsyncClient(
transport=transport,
event_hooks=event_hooks,
timeout=timeout,
verify=ssl_config,
cert=cert,
headers=headers,
headers=default_headers,
follow_redirects=True,
)

Expand Down Expand Up @@ -899,6 +914,9 @@ def __init__(
# /path/to/client.pem
cert = os.getenv("SSL_CERTIFICATE", litellm.ssl_certificate)

# Get default headers (User-Agent, overridable via LITELLM_USER_AGENT)
default_headers = get_default_headers() if not disable_default_headers else None

if client is None:
transport = self._create_sync_transport()

Expand All @@ -908,7 +926,7 @@ def __init__(
timeout=timeout,
verify=ssl_config,
cert=cert,
headers=headers if not disable_default_headers else None,
headers=default_headers,
follow_redirects=True,
)
else:
Expand Down
16 changes: 13 additions & 3 deletions litellm/llms/custom_httpx/httpx_handler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import Optional, Union

import httpx
Expand All @@ -7,13 +8,22 @@
except Exception:
version = "0.0.0"

headers = {
"User-Agent": f"litellm/{version}",
}
def get_default_headers() -> dict:
"""
Get default headers for HTTP requests.

- Default: `User-Agent: litellm/{version}`
- Override: set `LITELLM_USER_AGENT` to fully override the header value.
"""
user_agent = os.environ.get("LITELLM_USER_AGENT")
if user_agent is not None:
return {"User-Agent": user_agent}

return {"User-Agent": f"litellm/{version}"}

class HTTPHandler:
def __init__(self, concurrent_limit=1000):
headers = get_default_headers()
# Create a client with a connection pool
self.client = httpx.AsyncClient(
limits=httpx.Limits(
Expand Down
84 changes: 84 additions & 0 deletions tests/test_litellm/llms/custom_httpx/test_http_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,3 +471,87 @@ def test_ssl_ecdh_curve(env_curve, litellm_curve, expected_curve, should_call, m
assert isinstance(ssl_context, ssl.SSLContext)
finally:
litellm.ssl_ecdh_curve = original_value


def test_default_user_agent_is_litellm_version(monkeypatch):
from litellm._version import version
from litellm.llms.custom_httpx.http_handler import get_default_headers

monkeypatch.delenv("LITELLM_USER_AGENT", raising=False)

assert get_default_headers()["User-Agent"] == f"litellm/{version}"


def test_user_agent_can_be_overridden_via_env_var(monkeypatch):
from litellm.llms.custom_httpx.http_handler import get_default_headers

monkeypatch.setenv("LITELLM_USER_AGENT", "Claude Code")

assert get_default_headers()["User-Agent"] == "Claude Code"


def test_user_agent_env_var_can_be_empty_string(monkeypatch):
from litellm.llms.custom_httpx.http_handler import get_default_headers

monkeypatch.setenv("LITELLM_USER_AGENT", "")

assert get_default_headers()["User-Agent"] == ""


def test_user_agent_override_is_not_appended_to_default(monkeypatch):
from litellm.llms.custom_httpx.http_handler import HTTPHandler

monkeypatch.delenv("LITELLM_USER_AGENT", raising=False)

handler = HTTPHandler()
try:
req = handler.client.build_request(
"GET",
"https://example.com",
headers={"user-agent": "Claude Code"},
)

assert req.headers.get_list("User-Agent") == ["Claude Code"]
finally:
handler.close()


def test_sync_http_handler_uses_env_user_agent(monkeypatch):
from litellm.llms.custom_httpx.http_handler import HTTPHandler

monkeypatch.setenv("LITELLM_USER_AGENT", "Claude Code")

handler = HTTPHandler()
try:
req = handler.client.build_request("GET", "https://example.com")
assert req.headers.get("User-Agent") == "Claude Code"
finally:
handler.close()


@pytest.mark.asyncio
async def test_async_http_handler_uses_env_user_agent(monkeypatch):
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

monkeypatch.setenv("LITELLM_USER_AGENT", "Claude Code")

handler = AsyncHTTPHandler()
try:
req = handler.client.build_request("GET", "https://example.com")
assert req.headers.get("User-Agent") == "Claude Code"
finally:
await handler.close()


@pytest.mark.asyncio
async def test_httpx_handler_uses_env_user_agent(monkeypatch):
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler

monkeypatch.setenv("LITELLM_USER_AGENT", "Claude Code")

handler = HTTPHandler()
try:
req = handler.client.build_request("GET", "https://example.com")
assert req.headers.get("User-Agent") == "Claude Code"
finally:
await handler.close()
Loading