Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/modules/dockermodelrunner.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ The Docker Model Runner container exposes the following methods:

- Since testcontainers-go <a href="https://github.com/testcontainers/testcontainers-go/releases/tag/v0.37.0"><span class="tc-version">:material-tag: v0.37.0</span></a>

Use the `PullModel` method to pull a model from the Docker Model Runner.
Use the `PullModel` method to pull a model from the Docker Model Runner. Make sure the passed context is not done before the pull operation completes; otherwise, the pull operation is cancelled.

<!--codeinclude-->
[Pulling a model at runtime](../../modules/dockermodelrunner/examples_test.go) inside_block:runPullModel
Expand Down
8 changes: 8 additions & 0 deletions modules/dockermodelrunner/dockermodelrunner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package dockermodelrunner_test
import (
"context"
"testing"
"time"

"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -62,6 +63,13 @@ func TestRun_client(t *testing.T) {
err := ctr.PullModel(ctx, testNonExistentFQMN)
require.Error(t, err)
})

t.Run("failure/timeout", func(t *testing.T) {
ctx, cancel := context.WithTimeout(ctx, 1*time.Millisecond)
defer cancel()
err := ctr.PullModel(ctx, testModelFQMN)
require.Error(t, err)
})
})

t.Run("model-inspect", func(t *testing.T) {
Expand Down
4 changes: 4 additions & 0 deletions modules/dockermodelrunner/examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"log"
"slices"
"strings"
"time"

"github.com/openai/openai-go"
"github.com/openai/openai-go/option"
Expand Down Expand Up @@ -105,6 +106,9 @@ func ExampleRun_pullModel() {
fqModelName = modelNamespace + "/" + modelName + ":" + modelTag
)

ctx, cancel := context.WithTimeout(ctx, 60*time.Second)
defer cancel()

err = dmrCtr.PullModel(ctx, fqModelName)
if err != nil {
log.Printf("failed to pull model: %s", err)
Expand Down
2 changes: 1 addition & 1 deletion modules/dockermodelrunner/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module github.com/testcontainers/testcontainers-go/modules/dockermodelrunner
go 1.23.0

require (
github.com/cenkalti/backoff/v4 v4.2.1
github.com/openai/openai-go v0.1.0-beta.9
github.com/stretchr/testify v1.10.0
github.com/testcontainers/testcontainers-go v0.37.0
Expand All @@ -14,7 +15,6 @@ require (
dario.cat/mergo v1.0.1 // indirect
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/cenkalti/backoff/v4 v4.2.1 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/containerd/platforms v0.2.1 // indirect
github.com/cpuguy83/dockercfg v0.3.2 // indirect
Expand Down
36 changes: 35 additions & 1 deletion modules/dockermodelrunner/internal/sdk/client/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@ package client

import (
"context"
"errors"
"fmt"
"log"
"net/http"
"strings"
"time"

"github.com/cenkalti/backoff/v4"

"github.com/testcontainers/testcontainers-go/log"
)

// PullModel creates a model in the Docker Model Runner, by pulling the model from Docker Hub.
Expand All @@ -31,6 +36,35 @@ func (c *Client) PullModel(ctx context.Context, fullyQualifiedModelName string)
return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}

namespace := strings.Split(fullyQualifiedModelName, "/")[0]
modelName := strings.Split(fullyQualifiedModelName, "/")[1]

// Verify that the model is pulled successfully, honoring the parent context.
// This is needed because the pull is cancelled when the connection is closed.
err = backoff.RetryNotify(
func() error {
if ctx.Err() != nil {
return backoff.Permanent(ctx.Err())
}

model, err := c.InspectModel(ctx, namespace, modelName)
if err != nil {
return err
}
if model == nil {
return errors.New("model not found")
}
return nil
},
backoff.WithContext(backoff.NewExponentialBackOff(), ctx),
func(err error, _ time.Duration) {
log.Default().Printf("🙏 Pulling model %s. Please be patient, no progress bar yet! %w", fullyQualifiedModelName, err)
},
)
if err != nil {
return fmt.Errorf("pull model: %w", err)
}

log.Default().Printf("✅ Model %s pulled successfully!", fullyQualifiedModelName)

return nil
Expand Down
Loading