diff --git a/docs/modules/dockermodelrunner.md b/docs/modules/dockermodelrunner.md index 6a5379034f..5587fb14a9 100644 --- a/docs/modules/dockermodelrunner.md +++ b/docs/modules/dockermodelrunner.md @@ -76,7 +76,7 @@ The Docker Model Runner container exposes the following methods: - Since testcontainers-go :material-tag: v0.37.0 -Use the `PullModel` method to pull a model from the Docker Model Runner. +Use the `PullModel` method to pull a model from the Docker Model Runner. Make sure the passed context is not done before the pull operation is completed, so that the pull operation is cancelled. [Pulling a model at runtime](../../modules/dockermodelrunner/examples_test.go) inside_block:runPullModel diff --git a/modules/dockermodelrunner/dockermodelrunner_test.go b/modules/dockermodelrunner/dockermodelrunner_test.go index b2bcaef234..577df45996 100644 --- a/modules/dockermodelrunner/dockermodelrunner_test.go +++ b/modules/dockermodelrunner/dockermodelrunner_test.go @@ -3,6 +3,7 @@ package dockermodelrunner_test import ( "context" "testing" + "time" "github.com/stretchr/testify/require" @@ -62,6 +63,13 @@ func TestRun_client(t *testing.T) { err := ctr.PullModel(ctx, testNonExistentFQMN) require.Error(t, err) }) + + t.Run("failure/timeout", func(t *testing.T) { + ctx, cancel := context.WithTimeout(ctx, 1*time.Millisecond) + defer cancel() + err := ctr.PullModel(ctx, testModelFQMN) + require.Error(t, err) + }) }) t.Run("model-inspect", func(t *testing.T) { diff --git a/modules/dockermodelrunner/examples_test.go b/modules/dockermodelrunner/examples_test.go index 5d717afdb4..b91e078b94 100644 --- a/modules/dockermodelrunner/examples_test.go +++ b/modules/dockermodelrunner/examples_test.go @@ -6,6 +6,7 @@ import ( "log" "slices" "strings" + "time" "github.com/openai/openai-go" "github.com/openai/openai-go/option" @@ -105,6 +106,9 @@ func ExampleRun_pullModel() { fqModelName = modelNamespace + "/" + modelName + ":" + modelTag ) + ctx, cancel := context.WithTimeout(ctx, 60*time.Second) + defer cancel() + err = dmrCtr.PullModel(ctx, fqModelName) if err != nil { log.Printf("failed to pull model: %s", err) diff --git a/modules/dockermodelrunner/internal/sdk/client/pull.go b/modules/dockermodelrunner/internal/sdk/client/pull.go index 057ffbdc63..26e7cf035c 100644 --- a/modules/dockermodelrunner/internal/sdk/client/pull.go +++ b/modules/dockermodelrunner/internal/sdk/client/pull.go @@ -1,11 +1,13 @@ package client import ( + "bufio" "context" "fmt" - "log" "net/http" "strings" + + "github.com/testcontainers/testcontainers-go/log" ) // PullModel creates a model in the Docker Model Runner, by pulling the model from Docker Hub. @@ -18,7 +20,9 @@ func (c *Client) PullModel(ctx context.Context, fullyQualifiedModelName string) return fmt.Errorf("new post request (%s): %w", reqURL, err) } - log.Default().Printf("🙏 Pulling model %s. Please be patient, no progress bar yet!", fullyQualifiedModelName) + logger := log.Default() + + logger.Printf("🙏 Pulling model %s. Please be patient", fullyQualifiedModelName) resp, err := http.DefaultClient.Do(req) if err != nil { @@ -31,7 +35,17 @@ func (c *Client) PullModel(ctx context.Context, fullyQualifiedModelName string) return fmt.Errorf("unexpected status code: %d", resp.StatusCode) } - log.Default().Printf("✅ Model %s pulled successfully!", fullyQualifiedModelName) + scanner := bufio.NewScanner(resp.Body) + // TODO: use a progressbar instead of multiple line output. + for scanner.Scan() { + logger.Printf(scanner.Text()) + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("scan error: %w", err) + } + + logger.Printf("✅ Model %s pulled successfully!", fullyQualifiedModelName) return nil }