community[minor]: passthrough auth parameter on requests to Ollama-LLMs (langchain-ai#24068)

Thank you for contributing to LangChain!

**Description:**
This PR allows users of `langchain_community.llms.ollama.Ollama` to
specify the `auth` parameter, which is then forwarded to all internal
calls of `requests.request`. It works in the same way as the existing
`headers` parameter. The `auth` parameter makes it possible to use this
class with Ollama instances that are secured by more complex
authentication mechanisms which do not rely solely on static headers. One
example is an AWS API Gateway secured by an IAM authorizer, which expects
signatures calculated dynamically for each specific HTTP request.
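
As a rough illustration of the intended usage, the sketch below constructs the LLM with a static Basic-Auth tuple; the host name, model name and credentials are placeholder assumptions, and a callable is equally accepted:

```python
# Minimal usage sketch (not part of this PR's diff); host, model and
# credentials are hypothetical placeholders.
from langchain_community.llms.ollama import Ollama

llm = Ollama(
    base_url="https://my-ollama-host:8000",
    model="llama3",
    auth=("my-user", "my-password"),  # forwarded unchanged to requests.post(..., auth=...)
    timeout=300,
)

print(llm.invoke("Test prompt"))
```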

**Issue:**

Integrating a remote LLM served through Ollama via
`langchain_community.llms.ollama.Ollama` currently only allows setting
static HTTP headers with the `headers` parameter. This does not work if
the Ollama instance is secured by an authentication mechanism that relies
on dynamically created HTTP headers, which may, for example, depend on
the content of a given request.
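
For illustration, a request-dependent mechanism of this kind can be expressed with the `auth` parameter but not with static `headers`. The sketch below uses a purely hypothetical HMAC signing scheme and header name; it is not taken from any real gateway:

```python
# Hypothetical example of request-dependent authentication; the header name,
# signing scheme and secret are made up for illustration.
import hashlib
import hmac

from requests.auth import AuthBase


class BodySigningAuth(AuthBase):
    """Adds a signature header computed from each request body."""

    def __init__(self, secret: bytes) -> None:
        self.secret = secret

    def __call__(self, request):
        # requests calls this for every prepared request, so the header can
        # depend on the concrete payload being sent.
        body = request.body or b""
        if isinstance(body, str):
            body = body.encode("utf-8")
        request.headers["X-Signature"] = hmac.new(
            self.secret, body, hashlib.sha256
        ).hexdigest()
        return request


# Could then be passed as: Ollama(..., auth=BodySigningAuth(b"my-secret"))
```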

**Dependencies:**

None

**Twitter handle:**

None

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
2 people authored and olgamurraft committed Aug 16, 2024
1 parent 6d7bf78 commit 21de5df
Showing 2 changed files with 50 additions and 6 deletions.
21 changes: 20 additions & 1 deletion libs/community/langchain_community/llms/ollama.py
@@ -1,5 +1,18 @@
 from __future__ import annotations

 import json
-from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, Union
+from typing import (
+    Any,
+    AsyncIterator,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Union,
+)

 import aiohttp
 import requests
@@ -132,6 +145,10 @@ class _OllamaCommon(BaseLanguageModel):
     tokens for authentication.
     """

+    auth: Union[Callable, Tuple, None] = None
+    """Additional auth tuple or callable to enable Basic/Digest/Custom HTTP Auth.
+    Expects the same format, type and values as requests.request auth parameter."""
+
     @property
     def _default_params(self) -> Dict[str, Any]:
         """Get the default parameters for calling Ollama."""
@@ -237,6 +254,7 @@ def _create_stream(
                 "Content-Type": "application/json",
                 **(self.headers if isinstance(self.headers, dict) else {}),
             },
+            auth=self.auth,
             json=request_payload,
             stream=True,
             timeout=self.timeout,
@@ -300,6 +318,7 @@ async def _acreate_stream(
                     "Content-Type": "application/json",
                     **(self.headers if isinstance(self.headers, dict) else {}),
                 },
+                auth=self.auth,
                 json=request_payload,
                 timeout=self.timeout,
             ) as response:
35 changes: 30 additions & 5 deletions libs/community/tests/unit_tests/llms/test_ollama.py
@@ -31,7 +31,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None:
         timeout=300,
     )

-    def mock_post(url, headers, json, stream, timeout):  # type: ignore[no-untyped-def]
+    def mock_post(url, headers, json, stream, timeout, auth):  # type: ignore[no-untyped-def]
         assert url == "https://ollama-hostname:8000/api/generate"
         assert headers == {
             "Content-Type": "application/json",
@@ -49,10 +49,35 @@ def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-d
     llm.invoke("Test prompt")


+def test_pass_auth_if_provided(monkeypatch: MonkeyPatch) -> None:
+    llm = Ollama(
+        base_url="https://ollama-hostname:8000",
+        model="foo",
+        auth=("Test-User", "Test-Password"),
+        timeout=300,
+    )
+
+    def mock_post(url, headers, json, stream, timeout, auth):  # type: ignore[no-untyped-def]
+        assert url == "https://ollama-hostname:8000/api/generate"
+        assert headers == {
+            "Content-Type": "application/json",
+        }
+        assert json is not None
+        assert stream is True
+        assert timeout == 300
+        assert auth == ("Test-User", "Test-Password")
+
+        return mock_response_stream()
+
+    monkeypatch.setattr(requests, "post", mock_post)
+
+    llm.invoke("Test prompt")
+
+
 def test_handle_if_headers_not_provided(monkeypatch: MonkeyPatch) -> None:
     llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)

-    def mock_post(url, headers, json, stream, timeout):  # type: ignore[no-untyped-def]
+    def mock_post(url, headers, json, stream, timeout, auth):  # type: ignore[no-untyped-def]
         assert url == "https://ollama-hostname:8000/api/generate"
         assert headers == {
             "Content-Type": "application/json",
@@ -72,7 +97,7 @@ def test_handle_kwargs_top_level_parameters(monkeypatch: MonkeyPatch) -> None:
     """Test that top level params are sent to the endpoint as top level params"""
     llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)

-    def mock_post(url, headers, json, stream, timeout):  # type: ignore[no-untyped-def]
+    def mock_post(url, headers, json, stream, timeout, auth):  # type: ignore[no-untyped-def]
         assert url == "https://ollama-hostname:8000/api/generate"
         assert headers == {
             "Content-Type": "application/json",
@@ -120,7 +145,7 @@ def test_handle_kwargs_with_unknown_param(monkeypatch: MonkeyPatch) -> None:
     """
     llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)

-    def mock_post(url, headers, json, stream, timeout):  # type: ignore[no-untyped-def]
+    def mock_post(url, headers, json, stream, timeout, auth):  # type: ignore[no-untyped-def]
         assert url == "https://ollama-hostname:8000/api/generate"
         assert headers == {
             "Content-Type": "application/json",
@@ -169,7 +194,7 @@ def test_handle_kwargs_with_options(monkeypatch: MonkeyPatch) -> None:
     """
     llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)

-    def mock_post(url, headers, json, stream, timeout):  # type: ignore[no-untyped-def]
+    def mock_post(url, headers, json, stream, timeout, auth):  # type: ignore[no-untyped-def]
         assert url == "https://ollama-hostname:8000/api/generate"
         assert headers == {
             "Content-Type": "application/json",
