feat: aiproxy baidu ernie v2 #5287

Merged · 2 commits · Dec 13, 2024
48 changes: 1 addition & 47 deletions service/aiproxy/relay/adaptor/baidu/adaptor.go
@@ -20,40 +20,9 @@ import (
type Adaptor struct{}

const (
baseURL = "https://aip.baidubce.com"
baseURLV2 = "https://qianfan.baidubce.com"
baseURL = "https://aip.baidubce.com"
)

// func IsV2(modelName string) bool {
// return strings.HasPrefix(strings.ToLower(modelName), "ernie-")
// }

// func (a *Adaptor) getRequestURLV2(_ *meta.Meta) string {
// return baseURLV2 + "/v2/chat/completions"
// }

// var v2ModelMap = map[string]string{
// "ERNIE-4.0-8K-Latest": "ernie-4.0-8k-latest",
// "ERNIE-4.0-8K-Preview": "ernie-4.0-8k-preview",
// "ERNIE-4.0-8K": "ernie-4.0-8k",
// "ERNIE-4.0-Turbo-8K-Latest": "ernie-4.0-turbo-8k-latest",
// "ERNIE-4.0-Turbo-8K-Preview": "ernie-4.0-turbo-8k-preview",
// "ERNIE-4.0-Turbo-8K": "ernie-4.0-turbo-8k",
// "ERNIE-4.0-Turbo-128K": "ernie-4.0-turbo-128k",
// "ERNIE-3.5-8K-Preview": "ernie-3.5-8k-preview",
// "ERNIE-3.5-8K": "ernie-3.5-8k",
// "ERNIE-3.5-128K": "ernie-3.5-128k",
// "ERNIE-Speed-8K": "ernie-speed-8k",
// "ERNIE-Speed-128K": "ernie-speed-128k",
// "ERNIE-Speed-Pro-128K": "ernie-speed-pro-128k",
// "ERNIE-Lite-8K": "ernie-lite-8k",
// "ERNIE-Lite-Pro-128K": "ernie-lite-pro-128k",
// "ERNIE-Tiny-8K": "ernie-tiny-8k",
// "ERNIE-Character-8K": "ernie-char-8k",
// "ERNIE-Character-Fiction-8K": "ernie-char-fiction-8k",
// "ERNIE-Novel-8K": "ernie-novel-8k",
// }

// Get model-specific endpoint using map
var modelEndpointMap = map[string]string{
"ERNIE-4.0-8K": "completions_pro",
@@ -108,14 +77,6 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
}

func (a *Adaptor) SetupRequestHeader(meta *meta.Meta, _ *gin.Context, req *http.Request) error {
// if IsV2(meta.ActualModelName) {
// token, err := GetBearerToken(meta.APIKey)
// if err != nil {
// return err
// }
// req.Header.Set("Authorization", "Bearer "+token.Token)
// return nil
// }
req.Header.Set("Authorization", "Bearer "+meta.Channel.Key)
accessToken, err := GetAccessToken(context.Background(), meta.Channel.Key)
if err != nil {
@@ -135,9 +96,6 @@ func (a *Adaptor) ConvertRequest(meta *meta.Meta, req *http.Request) (http.Heade
case relaymode.ImagesGenerations:
return openai.ConvertRequest(meta, req)
default:
// if IsV2(meta.ActualModelName) {
// return openai.ConvertRequest(meta, req)
// }
return ConvertRequest(meta, req)
}
}
@@ -155,10 +113,6 @@ func (a *Adaptor) DoResponse(meta *meta.Meta, c *gin.Context, resp *http.Respons
case relaymode.ImagesGenerations:
usage, err = ImageHandler(meta, c, resp)
default:
// if IsV2(meta.ActualModelName) {
// usage, err = openai.DoResponse(meta, c, resp)
// return
// }
if utils.IsStreamResponse(resp) {
err, usage = StreamHandler(meta, c, resp)
} else {
60 changes: 0 additions & 60 deletions service/aiproxy/relay/adaptor/baidu/constants.go
@@ -6,66 +6,6 @@ import (
)

var ModelList = []*model.ModelConfig{
{
Model: "ERNIE-4.0-8K",
Type: relaymode.ChatCompletions,
Owner: model.ModelOwnerBaidu,
InputPrice: 0.03,
OutputPrice: 0.09,
Config: map[model.ModelConfigKey]any{
model.ModelConfigMaxContextTokensKey: 5120,
model.ModelConfigMaxInputTokensKey: 5120,
model.ModelConfigMaxOutputTokensKey: 1024,
},
},
{
Model: "ERNIE-3.5-8K",
Type: relaymode.ChatCompletions,
Owner: model.ModelOwnerBaidu,
InputPrice: 0.0008,
OutputPrice: 0.002,
Config: map[model.ModelConfigKey]any{
model.ModelConfigMaxContextTokensKey: 5120,
model.ModelConfigMaxInputTokensKey: 5120,
model.ModelConfigMaxOutputTokensKey: 1024,
},
},
{
Model: "ERNIE-Tiny-8K",
Type: relaymode.ChatCompletions,
Owner: model.ModelOwnerBaidu,
InputPrice: 0.0001,
OutputPrice: 0.0001,
Config: map[model.ModelConfigKey]any{
model.ModelConfigMaxContextTokensKey: 6144,
model.ModelConfigMaxInputTokensKey: 6144,
model.ModelConfigMaxOutputTokensKey: 1024,
},
},
{
Model: "ERNIE-Speed-8K",
Type: relaymode.ChatCompletions,
Owner: model.ModelOwnerBaidu,
InputPrice: 0.0001,
OutputPrice: 0.0001,
Config: map[model.ModelConfigKey]any{
model.ModelConfigMaxContextTokensKey: 6144,
model.ModelConfigMaxInputTokensKey: 6144,
model.ModelConfigMaxOutputTokensKey: 1024,
},
},
{
Model: "ERNIE-Speed-128K",
Type: relaymode.ChatCompletions,
Owner: model.ModelOwnerBaidu,
InputPrice: 0.0001,
OutputPrice: 0.0001,
Config: map[model.ModelConfigKey]any{
model.ModelConfigMaxContextTokensKey: 126976,
model.ModelConfigMaxInputTokensKey: 126976,
model.ModelConfigMaxOutputTokensKey: 4096,
},
},
{
Model: "BLOOMZ-7B",
Type: relaymode.ChatCompletions,
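
The ERNIE entries deleted from this ModelList are presumably re-registered under the new baiduv2 package's own ModelList (its GetModelList is shown further down, but the new constants file is not part of this excerpt). A hypothetical sketch of one such entry, reusing the ModelConfig shape visible in the removed lines above; the price and token-limit values are copied from the deleted ERNIE-4.0-8K block, not taken from the actual new file:

// Hypothetical sketch of a baiduv2 constants entry, not taken from this PR.
package baiduv2

import (
	"github.com/labring/sealos/service/aiproxy/model"
	"github.com/labring/sealos/service/aiproxy/relay/relaymode"
)

// ModelList mirrors the shape of the entries removed from the v1 package above.
var ModelList = []*model.ModelConfig{
	{
		Model:       "ERNIE-4.0-8K",
		Type:        relaymode.ChatCompletions,
		Owner:       model.ModelOwnerBaidu,
		InputPrice:  0.03,
		OutputPrice: 0.09,
		Config: map[model.ModelConfigKey]any{
			model.ModelConfigMaxContextTokensKey: 5120,
			model.ModelConfigMaxInputTokensKey:   5120,
			model.ModelConfigMaxOutputTokensKey:  1024,
		},
	},
}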
82 changes: 6 additions & 76 deletions service/aiproxy/relay/adaptor/baidu/token.go
@@ -2,19 +2,16 @@ package baidu

import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"time"

json "github.com/json-iterator/go"
"github.com/labring/sealos/service/aiproxy/common/client"
log "github.com/sirupsen/logrus"
)

type AccessToken struct {
@@ -25,11 +22,6 @@ type AccessToken struct {
ExpiresIn int64 `json:"expires_in,omitempty"`
}

type TokenResponse struct {
ExpireTime time.Time `json:"expireTime"`
Token string `json:"token"`
}

var baiduTokenStore sync.Map

func GetAccessToken(ctx context.Context, apiKey string) (string, error) {
@@ -39,14 +31,18 @@ func GetAccessToken(ctx context.Context, apiKey string) (string, error) {
// soon this will expire
if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
go func() {
_, _ = getBaiduAccessTokenHelper(context.Background(), apiKey)
_, err := getBaiduAccessTokenHelper(context.Background(), apiKey)
if err != nil {
log.Errorf("get baidu access token failed: %v", err)
}
}()
}
return accessToken.AccessToken, nil
}
}
accessToken, err := getBaiduAccessTokenHelper(ctx, apiKey)
if err != nil {
log.Errorf("get baidu access token failed: %v", err)
return "", errors.New("get baidu access token failed")
}
if accessToken == nil {
@@ -91,69 +87,3 @@ func getBaiduAccessTokenHelper(ctx context.Context, apiKey string) (*AccessToken
baiduTokenStore.Store(apiKey, accessToken)
return &accessToken, nil
}

func GetBearerToken(ctx context.Context, apiKey string) (*TokenResponse, error) {
parts := strings.Split(apiKey, "|")
if len(parts) != 2 {
return nil, errors.New("invalid baidu apikey")
}
if val, ok := baiduTokenStore.Load("bearer|" + apiKey); ok {
var tokenResponse TokenResponse
if tokenResponse, ok = val.(TokenResponse); ok {
if time.Now().Add(time.Hour).After(tokenResponse.ExpireTime) {
go func() {
_, _ = GetBearerToken(context.Background(), apiKey)
}()
}
return &tokenResponse, nil
}
}
authorization := generateAuthorizationString(parts[0], parts[1])
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://iam.bj.baidubce.com/v1/BCE-BEARER/token", nil)
if err != nil {
return nil, err
}
query := url.Values{}
query.Add("expireInSeconds", "86400")
req.URL.RawQuery = query.Encode()
req.Header.Add("Authorization", authorization)
res, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
defer res.Body.Close()
var tokenResponse TokenResponse
err = json.NewDecoder(res.Body).Decode(&tokenResponse)
if err != nil {
return nil, err
}
baiduTokenStore.Store("bearer|"+apiKey, tokenResponse)
return &tokenResponse, nil
}

func generateAuthorizationString(ak, sk string) string {
httpMethod := http.MethodGet
uri := "/v1/BCE-BEARER/token"
queryString := "expireInSeconds=86400"
hostHeader := "iam.bj.baidubce.com"
canonicalRequest := fmt.Sprintf("%s\n%s\n%s\nhost:%s", httpMethod, uri, queryString, hostHeader)

timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
expirationPeriodInSeconds := 1800
authStringPrefix := fmt.Sprintf("bce-auth-v1/%s/%s/%d", ak, timestamp, expirationPeriodInSeconds)

signingKey := hmacSHA256(sk, authStringPrefix)

signature := hmacSHA256(signingKey, canonicalRequest)

signedHeaders := "host"
authorization := fmt.Sprintf("%s/%s/%s", authStringPrefix, signedHeaders, signature)

return authorization
}

func hmacSHA256(key, data string) string {
h := hmac.New(sha256.New, []byte(key))
h.Write([]byte(data))
return hex.EncodeToString(h.Sum(nil))
}
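
GetBearerToken and the bce-auth-v1 signing helpers removed here are what the new baiduv2 adaptor below relies on for authentication, so they have presumably been relocated into that package rather than dropped. A minimal usage sketch under that assumption; the "AK|SK" credential string is a placeholder, and only the Token field (which the new adaptor reads) is assumed to survive the move:

// Hypothetical usage sketch, not part of this PR: exchanging a Baidu
// "AccessKey|SecretKey" channel key for a short-lived IAM bearer token.
package main

import (
	"context"
	"fmt"

	"github.com/labring/sealos/service/aiproxy/relay/adaptor/baiduv2"
)

func main() {
	// Placeholder credentials; the real value is the per-channel key.
	token, err := baiduv2.GetBearerToken(context.Background(), "AK|SK")
	if err != nil {
		fmt.Println("get bearer token failed:", err)
		return
	}
	// The baiduv2 adaptor sends this as "Authorization: Bearer <token>".
	fmt.Println("token:", token.Token)
}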
114 changes: 114 additions & 0 deletions service/aiproxy/relay/adaptor/baiduv2/adaptor.go
@@ -0,0 +1,114 @@
package baiduv2

import (
"context"
"fmt"
"io"
"net/http"
"strings"

"github.com/labring/sealos/service/aiproxy/model"
"github.com/labring/sealos/service/aiproxy/relay/adaptor/openai"
"github.com/labring/sealos/service/aiproxy/relay/meta"
"github.com/labring/sealos/service/aiproxy/relay/relaymode"
"github.com/labring/sealos/service/aiproxy/relay/utils"

"github.com/gin-gonic/gin"
relaymodel "github.com/labring/sealos/service/aiproxy/relay/model"
)

type Adaptor struct{}

const (
baseURL = "https://qianfan.baidubce.com"
)

// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Fm2vrveyu
var v2ModelMap = map[string]string{
"ERNIE-4.0-8K-Latest": "ernie-4.0-8k-latest",
"ERNIE-4.0-8K-Preview": "ernie-4.0-8k-preview",
"ERNIE-4.0-8K": "ernie-4.0-8k",
"ERNIE-4.0-Turbo-8K-Latest": "ernie-4.0-turbo-8k-latest",
"ERNIE-4.0-Turbo-8K-Preview": "ernie-4.0-turbo-8k-preview",
"ERNIE-4.0-Turbo-8K": "ernie-4.0-turbo-8k",
"ERNIE-4.0-Turbo-128K": "ernie-4.0-turbo-128k",
"ERNIE-3.5-8K-Preview": "ernie-3.5-8k-preview",
"ERNIE-3.5-8K": "ernie-3.5-8k",
"ERNIE-3.5-128K": "ernie-3.5-128k",
"ERNIE-Speed-8K": "ernie-speed-8k",
"ERNIE-Speed-128K": "ernie-speed-128k",
"ERNIE-Speed-Pro-128K": "ernie-speed-pro-128k",
"ERNIE-Lite-8K": "ernie-lite-8k",
"ERNIE-Lite-Pro-128K": "ernie-lite-pro-128k",
"ERNIE-Tiny-8K": "ernie-tiny-8k",
"ERNIE-Character-8K": "ernie-char-8k",
"ERNIE-Character-Fiction-8K": "ernie-char-fiction-8k",
"ERNIE-Novel-8K": "ernie-novel-8k",
}

func toV2ModelName(modelName string) string {
if v2Model, ok := v2ModelMap[modelName]; ok {
return v2Model
}
return strings.ToLower(modelName)
}

func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
if meta.Channel.BaseURL == "" {
meta.Channel.BaseURL = baseURL
}

switch meta.Mode {
case relaymode.ChatCompletions:
return meta.Channel.BaseURL + "/v2/chat/completions", nil
default:
return "", fmt.Errorf("unsupported mode: %d", meta.Mode)
}
}

func (a *Adaptor) SetupRequestHeader(meta *meta.Meta, _ *gin.Context, req *http.Request) error {
token, err := GetBearerToken(context.Background(), meta.Channel.Key)
if err != nil {
return err
}
req.Header.Set("Authorization", "Bearer "+token.Token)
return nil
}

func (a *Adaptor) ConvertRequest(meta *meta.Meta, req *http.Request) (http.Header, io.Reader, error) {
switch meta.Mode {
case relaymode.ChatCompletions:
actModel := meta.ActualModelName
v2Model := toV2ModelName(actModel)
meta.ActualModelName = v2Model
defer func() { meta.ActualModelName = actModel }()
return openai.ConvertRequest(meta, req)
default:
return nil, nil, fmt.Errorf("unsupported mode: %d", meta.Mode)
}
}

func (a *Adaptor) DoRequest(_ *meta.Meta, _ *gin.Context, req *http.Request) (*http.Response, error) {
return utils.DoRequest(req)
}

func (a *Adaptor) DoResponse(meta *meta.Meta, c *gin.Context, resp *http.Response) (usage *relaymodel.Usage, err *relaymodel.ErrorWithStatusCode) {
switch meta.Mode {
case relaymode.ChatCompletions:
return openai.DoResponse(meta, c, resp)
default:
return nil, openai.ErrorWrapperWithMessage(
fmt.Sprintf("unsupported mode: %d", meta.Mode),
nil,
http.StatusBadRequest,
)
}
}

func (a *Adaptor) GetModelList() []*model.ModelConfig {
return ModelList
}

func (a *Adaptor) GetChannelName() string {
return "baidu v2"
}
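
A small illustrative test for toV2ModelName as defined above (not included in this PR): explicitly mapped names come from v2ModelMap, while anything else falls back to plain lower-casing.

// Illustrative test sketch for toV2ModelName, not part of this PR.
package baiduv2

import "testing"

func TestToV2ModelName(t *testing.T) {
	cases := map[string]string{
		"ERNIE-Character-8K": "ernie-char-8k",      // mapped entry differs from simple lower-casing
		"ERNIE-3.5-128K":     "ernie-3.5-128k",     // mapped entry
		"Some-Unknown-Model": "some-unknown-model", // falls back to strings.ToLower
	}
	for in, want := range cases {
		if got := toV2ModelName(in); got != want {
			t.Errorf("toV2ModelName(%q) = %q, want %q", in, got, want)
		}
	}
}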