Skip to content

Commit 2d411c2

Browse files
committed
feat(core): Add support for ollama module
- Added a new class OllamaContainer with few methods to handle the Ollama container. - The `_check_and_add_gpu_capabilities` method checks if the host has GPUs and adds the necessary capabilities to the container. - The `commit_to_image` allows to save somehow the state of a container into an image so that we can reuse it, especially for the ones having some models pulled. - Added tests to check the functionality of the new class.
1 parent ead0f79 commit 2d411c2

File tree

5 files changed

+129
-1
lines changed

5 files changed

+129
-1
lines changed

modules/ollama/README.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
.. autoclass:: testcontainers.ollama.OllamaContainer
2+
.. title:: testcontainers.ollama.OllamaContainer
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
#
2+
# Licensed under the Apache License, Version 2.0 (the "License"); you may
3+
# not use this file except in compliance with the License. You may obtain
4+
# a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
10+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
11+
# License for the specific language governing permissions and limitations
12+
# under the License.
13+
14+
import docker
15+
from docker.types.containers import DeviceRequest
16+
17+
from testcontainers.core.generic import ServerContainer
18+
from testcontainers.core.waiting_utils import wait_container_is_ready
19+
20+
21+
class OllamaContainer(ServerContainer):
22+
def __init__(self, image="ollama/ollama:latest"):
23+
self.port = 11434
24+
super().__init__(port=self.port, image=image)
25+
self.with_exposed_ports(self.port)
26+
self._check_and_add_gpu_capabilities()
27+
28+
def _check_and_add_gpu_capabilities(self):
29+
info = docker.APIClient().info()
30+
if "nvidia" in info["Runtimes"]:
31+
self.with_kwargs(device_requests=DeviceRequest(count=-1, capabilities=[["gpu"]]))
32+
33+
def start(self) -> "OllamaContainer":
34+
"""
35+
Start the Ollama server
36+
"""
37+
super().start()
38+
wait_container_is_ready(self, "Ollama started successfully")
39+
self._connect()
40+
41+
return self
42+
43+
def get_endpoint(self):
44+
"""
45+
Return the endpoint of the Ollama server
46+
"""
47+
return self._create_connection_url()
48+
49+
@property
50+
def id(self) -> str:
51+
"""
52+
Return the container object
53+
"""
54+
return self._container.id
55+
56+
def pull_model(self, model_name: str) -> None:
57+
"""
58+
Pull a model from the Ollama server
59+
60+
Args:
61+
model_name (str): Name of the model
62+
"""
63+
self.exec(f"ollama pull {model_name}")
64+
65+
def commit_to_image(self, image_name: str) -> None:
66+
"""
67+
Commit the current container to a new image
68+
69+
Args:
70+
image_name (str): Name of the new image
71+
"""
72+
docker_client = self.get_docker_client()
73+
existing_images = docker_client.client.images.list(name=image_name)
74+
if not existing_images and self.id:
75+
docker_client.client.containers.get(self.id).commit(repository=image_name, conf={"Labels": {"org.testcontainers.session-id": ""}})
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import requests
2+
from testcontainers.ollama import OllamaContainer
3+
import random
4+
import string
5+
6+
7+
def random_string(length=6):
8+
return "".join(random.choices(string.ascii_lowercase, k=length))
9+
10+
11+
def test_ollama_container():
12+
with OllamaContainer() as ollama:
13+
url = ollama.get_endpoint()
14+
response = requests.get(url)
15+
assert response.status_code == 200
16+
assert response.text == "Ollama is running"
17+
18+
19+
def test_with_default_config():
20+
with OllamaContainer("ollama/ollama:0.1.26") as ollama:
21+
ollama.start()
22+
response = requests.get(f"{ollama.get_endpoint()}/api/version")
23+
version = response.json().get("version")
24+
assert version == "0.1.26"
25+
26+
27+
def test_download_model_and_commit_to_image():
28+
new_image_name = f"tc-ollama-allminilm-{random_string(length=4).lower()}"
29+
with OllamaContainer("ollama/ollama:0.1.26") as ollama:
30+
ollama.start()
31+
# Pull the model
32+
ollama.pull_model("all-minilm")
33+
34+
response = requests.get(f"{ollama.get_endpoint()}/api/tags")
35+
model_name = response.json().get("models", [])[0].get("name")
36+
assert "all-minilm" in model_name
37+
38+
# Commit the container state to a new image
39+
ollama.commit_to_image(new_image_name)
40+
41+
# Verify the new image
42+
with OllamaContainer(new_image_name) as ollama:
43+
ollama.start()
44+
response = requests.get(f"{ollama.get_endpoint()}/api/tags")
45+
model_name = response.json().get("models", [])[0].get("name")
46+
assert "all-minilm" in model_name

poetry.lock

Lines changed: 3 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ packages = [
5151
{ include = "testcontainers", from = "modules/nats" },
5252
{ include = "testcontainers", from = "modules/neo4j" },
5353
{ include = "testcontainers", from = "modules/nginx" },
54+
{ include = "testcontainers", from = "modules/ollama" },
5455
{ include = "testcontainers", from = "modules/opensearch" },
5556
{ include = "testcontainers", from = "modules/oracle-free" },
5657
{ include = "testcontainers", from = "modules/postgres" },
@@ -127,6 +128,7 @@ nats = ["nats-py"]
127128
neo4j = ["neo4j"]
128129
nginx = []
129130
opensearch = ["opensearch-py"]
131+
ollama = []
130132
oracle = ["sqlalchemy", "oracledb"]
131133
oracle-free = ["sqlalchemy", "oracledb"]
132134
postgres = []
@@ -272,6 +274,7 @@ mypy_path = [
272274
# "modules/mysql",
273275
# "modules/neo4j",
274276
# "modules/nginx",
277+
# "modules/ollama",
275278
# "modules/opensearch",
276279
# "modules/oracle",
277280
# "modules/postgres",

0 commit comments

Comments
 (0)