diff --git a/core/src/main/java/org/testcontainers/containers/DockerModelRunnerContainer.java b/core/src/main/java/org/testcontainers/containers/DockerModelRunnerContainer.java new file mode 100644 index 00000000000..8ce2adbc319 --- /dev/null +++ b/core/src/main/java/org/testcontainers/containers/DockerModelRunnerContainer.java @@ -0,0 +1,35 @@ +package org.testcontainers.containers; + +import org.testcontainers.containers.wait.strategy.Wait; +import org.testcontainers.utility.DockerImageName; + +/** + * Testcontainers proxy container for the Docker Model Runner service + * provided by Docker Desktop. + *

+ * Supported images: {@code alpine/socat} + *

+ * Exposed ports: 80 + */ +public class DockerModelRunnerContainer extends SocatContainer { + + private static final String MODEL_RUNNER_ENDPOINT = "model-runner.docker.internal"; + + public DockerModelRunnerContainer(String image) { + this(DockerImageName.parse(image)); + } + + public DockerModelRunnerContainer(DockerImageName image) { + super(image); + withTarget(80, MODEL_RUNNER_ENDPOINT); + waitingFor(Wait.forHttp("/").forResponsePredicate(res -> res.contains("The service is running"))); + } + + public String getBaseEndpoint() { + return "http://" + getHost() + ":" + getMappedPort(80); + } + + public String getOpenAIEndpoint() { + return getBaseEndpoint() + "/engines"; + } +} diff --git a/core/src/test/java/org/testcontainers/containers/DockerModelRunnerContainerTest.java b/core/src/test/java/org/testcontainers/containers/DockerModelRunnerContainerTest.java new file mode 100644 index 00000000000..dfe9df9e905 --- /dev/null +++ b/core/src/test/java/org/testcontainers/containers/DockerModelRunnerContainerTest.java @@ -0,0 +1,41 @@ +package org.testcontainers.containers; + +import io.restassured.RestAssured; +import io.restassured.response.Response; +import org.junit.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assumptions.assumeThat; + +public class DockerModelRunnerContainerTest { + + @Test + public void pullsModelAndExposesInference() { + assumeThat(System.getenv("CI")).isNull(); + + String modelName = "ai/smollm2:360M-Q4_K_M"; + + try ( + // container { + DockerModelRunnerContainer dmr = new DockerModelRunnerContainer("alpine/socat:1.7.4.3-r0") + // } + ) { + dmr.start(); + + // pullModel { + RestAssured + .given() + .body(String.format("{\"from\":\"%s\"}", modelName)) + .post(dmr.getBaseEndpoint() + "/models/create") + .then() + .statusCode(200); + // } + + Response modelResponse = RestAssured.get(dmr.getBaseEndpoint() + "/models").thenReturn(); + assertThat(modelResponse.body().jsonPath().getList("tags.flatten()")).contains(modelName); + + Response openAiResponse = RestAssured.get(dmr.getOpenAIEndpoint() + "/v1/models").prettyPeek().thenReturn(); + assertThat(openAiResponse.body().jsonPath().getList("data.id")).contains(modelName); + } + } +} diff --git a/docs/modules/docker_model_runner.md b/docs/modules/docker_model_runner.md new file mode 100644 index 00000000000..b610279e93b --- /dev/null +++ b/docs/modules/docker_model_runner.md @@ -0,0 +1,41 @@ +# Docker Model Runner + +This module helps connect to [Docker Model Runner](https://docs.docker.com/desktop/features/model-runner/) +provided by Docker Desktop 4.40.0. + +## DockerModelRunner's usage examples + +You can start a Docker Model Runner proxy container instance from any Java application by using: + + +[Create a DockerModelRunnerContainer](../../core/src/test/java/org/testcontainers/containers/DockerModelRunnerContainerTest.java) inside_block:container + + +### Pulling the model + +Pulling the model is as simple as: + + +[Pull model](../../core/src/test/java/org/testcontainers/containers/DockerModelRunnerContainerTest.java) inside_block:pullModel + + +## Adding this module to your project dependencies + +*Docker Model Runner support is part of the core Testcontainers library.* + +Add the following dependency to your `pom.xml`/`build.gradle` file: + +=== "Gradle" + ```groovy + testImplementation "org.testcontainers:testcontainers:{{latest_version}}" + ``` +=== "Maven" + ```xml + + org.testcontainers + testcontainers + {{latest_version}} + test + + ``` + diff --git a/mkdocs.yml b/mkdocs.yml index 536f73a0603..88aa92efbe8 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -82,6 +82,7 @@ nav: - modules/chromadb.md - modules/consul.md - modules/docker_compose.md + - modules/docker_model_runner.md - modules/elasticsearch.md - modules/gcloud.md - modules/grafana.md