Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ bin/
coverage.out
covdatafiles/
.DS_Store
vendor
30 changes: 30 additions & 0 deletions pkg/bridge/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ func Convert(ctx context.Context, dockerCli command.Cli, project *types.Project,
if err != nil {
return err
}
// Set default model_var and endpoint_var if missing
setDefaultModelVariablesIfMissing(project)
// for user to rely on compose.yaml attribute names, not go struct ones, we marshall back into YAML
raw, err := project.MarshalYAML(types.WithSecretContent)
// Marshall to YAML
Expand Down Expand Up @@ -222,3 +224,31 @@ func inspectWithPull(ctx context.Context, dockerCli command.Cli, imageName strin
}
return inspect, err
}

// setDefaultModelVariablesIfMissing sets default model_var and endpoint_var for services that use models
// but don't have these variables explicitly defined.
func setDefaultModelVariablesIfMissing(project *types.Project) {
for serviceName, service := range project.Services {
if len(service.Models) == 0 {
continue
}
for modelRef, modelConfig := range service.Models {
if modelConfig == nil {
modelConfig = &types.ServiceModelConfig{}
service.Models[modelRef] = modelConfig
}

if modelConfig.ModelVariable == "" || modelConfig.EndpointVariable == "" {
defaultModelVar, defaultEndpointVar := utils.GetModelVariables(modelRef)

if modelConfig.ModelVariable == "" {
modelConfig.ModelVariable = defaultModelVar
}
if modelConfig.EndpointVariable == "" {
modelConfig.EndpointVariable = defaultEndpointVar
}
}
}
project.Services[serviceName] = service
}
}
9 changes: 5 additions & 4 deletions pkg/compose/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ import (
"os/exec"
"slices"
"strconv"
"strings"

"github.com/compose-spec/compose-go/v2/types"
"github.com/containerd/errdefs"
"github.com/docker/cli/cli-plugins/manager"
"github.com/docker/compose/v2/pkg/progress"
"github.com/docker/compose/v2/pkg/utils"
"github.com/spf13/cobra"
"golang.org/x/sync/errgroup"
)
Expand Down Expand Up @@ -200,19 +200,20 @@ func (m *modelAPI) SetModelVariables(ctx context.Context, project *types.Project
for _, service := range project.Services {
for ref, modelConfig := range service.Models {
model := project.Models[ref]
varPrefix := strings.ReplaceAll(strings.ToUpper(ref), "-", "_")
defaultModelVar, defaultEndpointVar := utils.GetModelVariables(ref)

var variable string
if modelConfig != nil && modelConfig.ModelVariable != "" {
variable = modelConfig.ModelVariable
} else {
variable = varPrefix + "_MODEL"
variable = defaultModelVar
}
service.Environment[variable] = &model.Model

if modelConfig != nil && modelConfig.EndpointVariable != "" {
variable = modelConfig.EndpointVariable
} else {
variable = varPrefix + "_URL"
variable = defaultEndpointVar
}
service.Environment[variable] = &status.Endpoint
}
Expand Down
27 changes: 27 additions & 0 deletions pkg/utils/modelvar.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
Copyright 2020 Docker Compose CLI authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package utils

import "strings"

// GetModelVariables generates default model and endpoint variable names from a model reference.
// It converts the model reference to uppercase and replaces hyphens with underscores.
// Returns modelVariable (e.g., "AI_RUNNER_MODEL") and endpointVariable (e.g., "AI_RUNNER_URL").
func GetModelVariables(modelRef string) (modelVariable, endpointVariable string) {
prefix := strings.ReplaceAll(strings.ToUpper(modelRef), "-", "_")
return prefix + "_MODEL", prefix + "_URL"
}
91 changes: 91 additions & 0 deletions pkg/utils/modelvar_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
Copyright 2020 Docker Compose CLI authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package utils

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestGetModelVariables(t *testing.T) {
tests := []struct {
name string
modelRef string
expectedModelVar string
expectedEndpointVar string
}{
{
name: "simple name with underscore",
modelRef: "ai_runner",
expectedModelVar: "AI_RUNNER_MODEL",
expectedEndpointVar: "AI_RUNNER_URL",
},
{
name: "name with hyphens",
modelRef: "ai-runner",
expectedModelVar: "AI_RUNNER_MODEL",
expectedEndpointVar: "AI_RUNNER_URL",
},
{
name: "complex name with multiple hyphens",
modelRef: "my-llm-engine",
expectedModelVar: "MY_LLM_ENGINE_MODEL",
expectedEndpointVar: "MY_LLM_ENGINE_URL",
},
{
name: "single word",
modelRef: "model",
expectedModelVar: "MODEL_MODEL",
expectedEndpointVar: "MODEL_URL",
},
{
name: "mixed case",
modelRef: "AiRunner",
expectedModelVar: "AIRUNNER_MODEL",
expectedEndpointVar: "AIRUNNER_URL",
},
{
name: "mixed case with hyphens",
modelRef: "Ai-Runner",
expectedModelVar: "AI_RUNNER_MODEL",
expectedEndpointVar: "AI_RUNNER_URL",
},
{
name: "already uppercase with underscores",
modelRef: "AI_RUNNER",
expectedModelVar: "AI_RUNNER_MODEL",
expectedEndpointVar: "AI_RUNNER_URL",
},
{
name: "lowercase simple",
modelRef: "airunner",
expectedModelVar: "AIRUNNER_MODEL",
expectedEndpointVar: "AIRUNNER_URL",
},
}

for _, tt := range tests {
t.Run(
tt.name, func(t *testing.T) {
modelVar, endpointVar := GetModelVariables(tt.modelRef)
assert.Equal(t, tt.expectedModelVar, modelVar, "modelVariable mismatch")
assert.Equal(t, tt.expectedEndpointVar, endpointVar, "endpointVariable mismatch")
},
)
}
}