Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/modules/dockermodelrunner.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ The Docker Model Runner container exposes the following methods:

- Since testcontainers-go <a href="https://github.com/testcontainers/testcontainers-go/releases/tag/v0.37.0"><span class="tc-version">:material-tag: v0.37.0</span></a>

Use the `PullModel` method to pull a model from the Docker Model Runner.
Use the `PullModel` method to pull a model from the Docker Model Runner. Make sure the passed context is not done before the pull operation is completed, so that the pull operation is cancelled.

<!--codeinclude-->
[Pulling a model at runtime](../../modules/dockermodelrunner/examples_test.go) inside_block:runPullModel
Expand Down
8 changes: 8 additions & 0 deletions modules/dockermodelrunner/dockermodelrunner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package dockermodelrunner_test
import (
"context"
"testing"
"time"

"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -62,6 +63,13 @@ func TestRun_client(t *testing.T) {
err := ctr.PullModel(ctx, testNonExistentFQMN)
require.Error(t, err)
})

t.Run("failure/timeout", func(t *testing.T) {
ctx, cancel := context.WithTimeout(ctx, 1*time.Millisecond)
defer cancel()
err := ctr.PullModel(ctx, testModelFQMN)
require.Error(t, err)
})
})

t.Run("model-inspect", func(t *testing.T) {
Expand Down
4 changes: 4 additions & 0 deletions modules/dockermodelrunner/examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"log"
"slices"
"strings"
"time"

"github.com/openai/openai-go"
"github.com/openai/openai-go/option"
Expand Down Expand Up @@ -105,6 +106,9 @@ func ExampleRun_pullModel() {
fqModelName = modelNamespace + "/" + modelName + ":" + modelTag
)

ctx, cancel := context.WithTimeout(ctx, 60*time.Second)
defer cancel()

err = dmrCtr.PullModel(ctx, fqModelName)
if err != nil {
log.Printf("failed to pull model: %s", err)
Expand Down
35 changes: 33 additions & 2 deletions modules/dockermodelrunner/internal/sdk/client/pull.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package client

import (
"bufio"
"context"
"fmt"
"log"
"net/http"
"strings"

"github.com/testcontainers/testcontainers-go/log"
)

// PullModel creates a model in the Docker Model Runner, by pulling the model from Docker Hub.
Expand All @@ -18,19 +20,48 @@ func (c *Client) PullModel(ctx context.Context, fullyQualifiedModelName string)
return fmt.Errorf("new post request (%s): %w", reqURL, err)
}

log.Default().Printf("🙏 Pulling model %s. Please be patient, no progress bar yet!", fullyQualifiedModelName)
log.Default().Printf("🙏 Pulling model %s. Please be patient", fullyQualifiedModelName)

// Check context before making request
if err := ctx.Err(); err != nil {
return fmt.Errorf("context done before request: %w", err)
}

resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("http post: %w", err)
}
defer resp.Body.Close()

// Check context after getting response
if err := ctx.Err(); err != nil {
return fmt.Errorf("context done after response: %w", err)
}

// The Docker Model Runner returns a 200 status code for a successful pulls
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}

scanner := bufio.NewScanner(resp.Body)
done := make(chan error, 1)

go func() {
for scanner.Scan() {
log.Default().Printf(scanner.Text())
}
done <- scanner.Err()
}()

select {
case <-ctx.Done():
return fmt.Errorf("context done: %w", ctx.Err())
case err := <-done:
if err != nil {
return fmt.Errorf("scan error: %w", err)
}
}

log.Default().Printf("✅ Model %s pulled successfully!", fullyQualifiedModelName)

return nil
Expand Down
Loading