Skip to content

Commit fe81c7b

Browse files
feat: Add function config (milvus-io#40534)
milvus-io#35856 1. Add function-related configuration in milvus.yaml 2. Add null and empty value check to TextEmbeddingFunction Signed-off-by: junjie.jiang <[email protected]>
1 parent 16efcda commit fe81c7b

27 files changed

+498
-176
lines changed

cmd/tools/config/generate.go

+5
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,11 @@ func WriteYaml(w io.Writer) {
351351
header: `
352352
# Any configuration related to the knowhere vector search engine`,
353353
},
354+
{
355+
name: "function",
356+
header: `
357+
# Any configuration related to functions`,
358+
},
354359
}
355360
marshller := YamlMarshaller{w, groups, result}
356361
marshller.writeYamlRecursive(lo.Filter(result, func(d DocContent, _ int) bool {

configs/milvus.yaml

+30
Original file line numberDiff line numberDiff line change
@@ -1157,3 +1157,33 @@ knowhere:
11571157
search_list_size: 100 # Size of the candidate list during building graph
11581158
search:
11591159
beam_width_ratio: 4 # Ratio between the maximum number of IO requests per search iteration and CPU number
1160+
1161+
# Any configuration related to functions
1162+
function:
1163+
textEmbedding:
1164+
enableVerifiInfoInParams: true # Controls whether to allow configuration of apikey and model service url on function parameters
1165+
providers:
1166+
azure_openai:
1167+
api_key: # Your azure openai api key
1168+
resource_name: # Your azure openai resource name
1169+
url: # Your azure openai embedding url, Default is the official embedding url
1170+
bedrock:
1171+
aws_access_key_id: # Your aws_access_key_id
1172+
aws_secret_access_key: # Your aws_secret_access_key
1173+
cohere:
1174+
api_key: # Your cohere api key
1175+
url: # Your cohere embedding url, Default is the official embedding url
1176+
dashscope:
1177+
api_key: # Your dashscope api key
1178+
url: # Your dashscope embedding url, Default is the official embedding url
1179+
openai:
1180+
api_key: # Your openai api key
1181+
url: # Your openai embedding url, Default is the official embedding url
1182+
siliconflow:
1183+
api_key: # Your siliconflow api key
1184+
url: # Your siliconflow embedding url, Default is the official embedding url
1185+
tei:
1186+
enable: true # Whether to enable TEI model service
1187+
vertexai:
1188+
credentials_file_path: # Path to your google application credentials, change the file path to refresh the configuration
1189+
url: # Your VertexAI embedding url

internal/util/function/ali_embedding_provider.go

+2-6
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ package function
2020

2121
import (
2222
"fmt"
23-
"os"
2423
"strings"
2524

2625
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
@@ -41,9 +40,6 @@ type AliEmbeddingProvider struct {
4140
}
4241

4342
func createAliEmbeddingClient(apiKey string, url string) (*ali.AliDashScopeEmbedding, error) {
44-
if apiKey == "" {
45-
apiKey = os.Getenv(dashscopeAKEnvStr)
46-
}
4743
if apiKey == "" {
4844
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", dashscopeAKEnvStr)
4945
}
@@ -55,12 +51,12 @@ func createAliEmbeddingClient(apiKey string, url string) (*ali.AliDashScopeEmbed
5551
return c, nil
5652
}
5753

58-
func NewAliDashScopeEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*AliEmbeddingProvider, error) {
54+
func NewAliDashScopeEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string) (*AliEmbeddingProvider, error) {
5955
fieldDim, err := typeutil.GetDim(fieldSchema)
6056
if err != nil {
6157
return nil, err
6258
}
63-
apiKey, url := parseAKAndURL(functionSchema.Params)
59+
apiKey, url := parseAKAndURL(functionSchema.Params, params, dashscopeAKEnvStr)
6460
var modelName string
6561
var dim int64
6662

internal/util/function/alitext_embedding_provider_test.go

+2-8
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import (
2323
"fmt"
2424
"net/http"
2525
"net/http/httptest"
26-
"os"
2726
"testing"
2827

2928
"github.com/stretchr/testify/suite"
@@ -77,7 +76,7 @@ func createAliProvider(url string, schema *schemapb.FieldSchema, providerName st
7776
}
7877
switch providerName {
7978
case aliDashScopeProvider:
80-
return NewAliDashScopeEmbeddingProvider(schema, functionSchema)
79+
return NewAliDashScopeEmbeddingProvider(schema, functionSchema, map[string]string{})
8180
default:
8281
return nil, fmt.Errorf("Unknow provider")
8382
}
@@ -170,11 +169,6 @@ func (s *AliTextEmbeddingProviderSuite) TestEmbeddingNumberNotMatch() {
170169
func (s *AliTextEmbeddingProviderSuite) TestCreateAliEmbeddingClient() {
171170
_, err := createAliEmbeddingClient("", "")
172171
s.Error(err)
173-
174-
os.Setenv(dashscopeAKEnvStr, "mock_key")
175-
defer os.Unsetenv(dashscopeAKEnvStr)
176-
_, err = createAliEmbeddingClient("", "")
177-
s.NoError(err)
178172
}
179173

180174
func (s *AliTextEmbeddingProviderSuite) TestNewAliDashScopeEmbeddingProvider() {
@@ -193,6 +187,6 @@ func (s *AliTextEmbeddingProviderSuite) TestNewAliDashScopeEmbeddingProvider() {
193187
}
194188
// invalid dim
195189
functionSchema.Params[2] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalid"}
196-
_, err := NewAliDashScopeEmbeddingProvider(s.schema.Fields[2], functionSchema)
190+
_, err := NewAliDashScopeEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{})
197191
s.Error(err)
198192
}

internal/util/function/bedrock_embedding_provider.go

+40-17
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import (
3030
"github.com/aws/aws-sdk-go-v2/credentials"
3131
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
3232

33+
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
3334
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
3435
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
3536
)
@@ -51,16 +52,9 @@ type BedrockEmbeddingProvider struct {
5152
}
5253

5354
func createBedRockEmbeddingClient(awsAccessKeyId string, awsSecretAccessKey string, region string) (*bedrockruntime.Client, error) {
54-
if awsAccessKeyId == "" {
55-
awsAccessKeyId = os.Getenv(bedrockAccessKeyId)
56-
}
5755
if awsAccessKeyId == "" {
5856
return nil, fmt.Errorf("Missing credentials. Please pass `aws_access_key_id`, or configure the %s environment variable in the Milvus service.", bedrockAccessKeyId)
5957
}
60-
61-
if awsSecretAccessKey == "" {
62-
awsSecretAccessKey = os.Getenv(bedrockSAKEnvStr)
63-
}
6458
if awsSecretAccessKey == "" {
6559
return nil, fmt.Errorf("Missing credentials. Please pass `aws_secret_access_key`, or configure the %s environment variable in the Milvus service.", bedrockSAKEnvStr)
6660
}
@@ -79,12 +73,47 @@ func createBedRockEmbeddingClient(awsAccessKeyId string, awsSecretAccessKey stri
7973
return bedrockruntime.NewFromConfig(cfg), nil
8074
}
8175

82-
func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, c BedrockClient) (*BedrockEmbeddingProvider, error) {
76+
// parseAccessInfo resolves the AWS access key ID and secret access key used
// for the Bedrock client. Each credential is resolved independently: the
// first source that yields a non-empty value wins, and later sources are only
// consulted while the value is still empty.
//
// params are the function-schema key/value pairs; confParams is the
// provider configuration map. It returns (awsAccessKeyId, awsSecretAccessKey);
// either may be "" when no source supplies it.
func parseAccessInfo(params []*commonpb.KeyValuePair, confParams map[string]string) (string, string) {
77+
// Precedence as implemented below: function param > yaml (confParams) > env.
// NOTE(review): this comment previously read "function param > env > yaml",
// which does not match the code order — confirm which order is intended.
78+
var awsAccessKeyId, awsSecretAccessKey string
79+
80+
// From function params, only when isEnableVerifiInfoInParamsKey(confParams)
// allows it. Keys are lowercased before matching; if a key appears more
// than once, the last occurrence wins.
81+
if isEnableVerifiInfoInParamsKey(confParams) {
82+
for _, param := range params {
83+
switch strings.ToLower(param.Key) {
84+
case awsAKIdParamKey:
85+
awsAccessKeyId = param.Value
86+
case awsSAKParamKey:
87+
awsSecretAccessKey = param.Value
88+
}
89+
}
90+
}
91+
92+
// From milvus.yaml (confParams): a missing key yields "" and falls through.
93+
if awsAccessKeyId == "" {
94+
awsAccessKeyId = confParams[awsAKIdParamKey]
95+
}
96+
if awsSecretAccessKey == "" {
97+
awsSecretAccessKey = confParams[awsSAKParamKey]
98+
}
99+
100+
// From env: last resort, read from the variables named by
// bedrockAccessKeyId and bedrockSAKEnvStr.
101+
if awsAccessKeyId == "" {
102+
awsAccessKeyId = os.Getenv(bedrockAccessKeyId)
103+
}
104+
if awsSecretAccessKey == "" {
105+
awsSecretAccessKey = os.Getenv(bedrockSAKEnvStr)
106+
}
107+
108+
return awsAccessKeyId, awsSecretAccessKey
109+
}
110+
111+
func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, c BedrockClient, params map[string]string) (*BedrockEmbeddingProvider, error) {
83112
fieldDim, err := typeutil.GetDim(fieldSchema)
84113
if err != nil {
85114
return nil, err
86115
}
87-
var awsAccessKeyId, awsSecretAccessKey, region, modelName string
116+
var region, modelName string
88117
var dim int64
89118
normalize := true
90119

@@ -97,14 +126,6 @@ func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSche
97126
if err != nil {
98127
return nil, err
99128
}
100-
case awsAKIdParamKey:
101-
if strings.ToLower(os.Getenv(enableConfigAKAndURL)) != "false" {
102-
awsAccessKeyId = param.Value
103-
}
104-
case awsSAKParamKey:
105-
if strings.ToLower(os.Getenv(enableConfigAKAndURL)) != "false" {
106-
awsSecretAccessKey = param.Value
107-
}
108129
case regionParamKey:
109130
region = param.Value
110131
case normalizeParamKey:
@@ -120,6 +141,8 @@ func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSche
120141
}
121142
}
122143

144+
awsAccessKeyId, awsSecretAccessKey := parseAccessInfo(functionSchema.Params, params)
145+
123146
var client BedrockClient
124147
if c == nil {
125148
client, err = createBedRockEmbeddingClient(awsAccessKeyId, awsSecretAccessKey, region)

internal/util/function/bedrock_text_embedding_provider_test.go

+8-5
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ func createBedrockProvider(schema *schemapb.FieldSchema, providerName string, di
7171
}
7272
switch providerName {
7373
case bedrockProvider:
74-
return NewBedrockEmbeddingProvider(schema, functionSchema, &MockBedrockClient{dim: dim})
74+
return NewBedrockEmbeddingProvider(schema, functionSchema, &MockBedrockClient{dim: dim}, map[string]string{})
7575
default:
7676
return nil, fmt.Errorf("Unknow provider")
7777
}
@@ -151,22 +151,25 @@ func (s *BedrockTextEmbeddingProviderSuite) TestNewBedrockEmbeddingProvider() {
151151
{Key: normalizeParamKey, Value: "false"},
152152
},
153153
}
154-
provider, err := NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil)
154+
provider, err := NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil, map[string]string{})
155155
s.NoError(err)
156156
s.True(provider.MaxBatch() > 0)
157157
s.Equal(provider.FieldDim(), int64(4))
158158

159+
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil, map[string]string{awsAKIdParamKey: "mock", awsSAKParamKey: "mock"})
160+
s.NoError(err)
161+
159162
functionSchema.Params[5] = &commonpb.KeyValuePair{Key: normalizeParamKey, Value: "true"}
160-
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil)
163+
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil, map[string]string{})
161164
s.NoError(err)
162165

163166
functionSchema.Params[5] = &commonpb.KeyValuePair{Key: normalizeParamKey, Value: "invalid"}
164-
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil)
167+
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil, map[string]string{})
165168
s.Error(err)
166169

167170
// invalid dim
168171
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: TestModel}
169172
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalid"}
170-
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil)
173+
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil, map[string]string{})
171174
s.Error(err)
172175
}

internal/util/function/cohere_embedding_provider.go

+2-6
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ package function
2020

2121
import (
2222
"fmt"
23-
"os"
2423
"strings"
2524

2625
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
@@ -42,9 +41,6 @@ type CohereEmbeddingProvider struct {
4241
}
4342

4443
func createCohereEmbeddingClient(apiKey string, url string) (*cohere.CohereEmbedding, error) {
45-
if apiKey == "" {
46-
apiKey = os.Getenv(cohereAIAKEnvStr)
47-
}
4844
if apiKey == "" {
4945
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", cohereAIAKEnvStr)
5046
}
@@ -57,12 +53,12 @@ func createCohereEmbeddingClient(apiKey string, url string) (*cohere.CohereEmbed
5753
return c, nil
5854
}
5955

60-
func NewCohereEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*CohereEmbeddingProvider, error) {
56+
func NewCohereEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string) (*CohereEmbeddingProvider, error) {
6157
fieldDim, err := typeutil.GetDim(fieldSchema)
6258
if err != nil {
6359
return nil, err
6460
}
65-
apiKey, url := parseAKAndURL(functionSchema.Params)
61+
apiKey, url := parseAKAndURL(functionSchema.Params, params, cohereAIAKEnvStr)
6662
var modelName string
6763
truncate := "END"
6864
for _, param := range functionSchema.Params {

internal/util/function/cohere_embedding_provider_test.go

+6-13
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import (
2323
"fmt"
2424
"net/http"
2525
"net/http/httptest"
26-
"os"
2726
"testing"
2827

2928
"github.com/stretchr/testify/suite"
@@ -76,7 +75,7 @@ func createCohereProvider(url string, schema *schemapb.FieldSchema, providerName
7675
}
7776
switch providerName {
7877
case cohereProvider:
79-
return NewCohereEmbeddingProvider(schema, functionSchema)
78+
return NewCohereEmbeddingProvider(schema, functionSchema, map[string]string{})
8079
default:
8180
return nil, fmt.Errorf("Unknow provider")
8281
}
@@ -264,18 +263,18 @@ func (s *CohereTextEmbeddingProviderSuite) TestNewCohereProvider() {
264263
},
265264
}
266265

267-
provider, err := NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema)
266+
provider, err := NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{})
268267
s.NoError(err)
269268
s.Equal(provider.truncate, "END")
270269

271270
functionSchema.Params = append(functionSchema.Params, &commonpb.KeyValuePair{Key: truncateParamKey, Value: "START"})
272-
provider, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema)
271+
provider, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{})
273272
s.NoError(err)
274273
s.Equal(provider.truncate, "START")
275274

276275
// Invalid truncateParam
277276
functionSchema.Params[2].Value = "Unknow"
278-
_, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema)
277+
_, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{})
279278
s.Error(err)
280279
}
281280

@@ -293,13 +292,13 @@ func (s *CohereTextEmbeddingProviderSuite) TestGetInputType() {
293292
},
294293
}
295294

296-
provider, err := NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema)
295+
provider, err := NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{})
297296
s.NoError(err)
298297
s.Equal(provider.getInputType(InsertMode), "")
299298
s.Equal(provider.getInputType(SearchMode), "")
300299

301300
functionSchema.Params[0].Value = "model-v3.0"
302-
provider, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema)
301+
provider, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{})
303302
s.NoError(err)
304303
s.Equal(provider.getInputType(InsertMode), "search_document")
305304
s.Equal(provider.getInputType(SearchMode), "search_query")
@@ -308,12 +307,6 @@ func (s *CohereTextEmbeddingProviderSuite) TestGetInputType() {
308307
func (s *CohereTextEmbeddingProviderSuite) TestCreateCohereEmbeddingClient() {
309308
_, err := createCohereEmbeddingClient("", "")
310309
s.Error(err)
311-
312-
os.Setenv(cohereAIAKEnvStr, "mockKey")
313-
defer os.Unsetenv(openaiAKEnvStr)
314-
315-
_, err = createCohereEmbeddingClient("", "")
316-
s.NoError(err)
317310
}
318311

319312
func (s *CohereTextEmbeddingProviderSuite) TestRuntimeDimNotMatch() {

0 commit comments

Comments
 (0)