Skip to content

Commit

Permalink
feat: supper whisper now (close lobehub#197)
Browse files Browse the repository at this point in the history
  • Loading branch information
songquanpeng committed Aug 27, 2023
1 parent 1c4409a commit d09d317
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 4 deletions.
2 changes: 1 addition & 1 deletion common/model-ratio.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ var ModelRatio = map[string]float64{
"text-davinci-003": 10,
"text-davinci-edit-001": 10,
"code-davinci-edit-001": 10,
"whisper-1": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"davinci": 10,
"curie": 10,
"babbage": 10,
Expand Down
9 changes: 9 additions & 0 deletions controller/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,15 @@ func init() {
Root: "dall-e",
Parent: nil,
},
{
Id: "whisper-1",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "whisper-1",
Parent: nil,
},
{
Id: "gpt-3.5-turbo",
Object: "model",
Expand Down
147 changes: 147 additions & 0 deletions controller/relay-audio.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package controller

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/model"

"github.com/gin-gonic/gin"
)

func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
audioModel := "whisper-1"

tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
userId := c.GetInt("id")
group := c.GetString("group")

preConsumedTokens := common.PreConsumedQuota
modelRatio := common.GetModelRatio(audioModel)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
userQuota, err := model.CacheGetUserQuota(userId)
if err != nil {
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
if err != nil {
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 100*preConsumedQuota {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
}
if preConsumedQuota > 0 {
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil {
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
}

// map model name
modelMapping := c.GetString("model_mapping")
if modelMapping != "" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[audioModel] != "" {
audioModel = modelMap[audioModel]
}
}

baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()

if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}

fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
requestBody := c.Request.Body

req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))

resp, err := httpClient.Do(req)
if err != nil {
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}

err = req.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
err = c.Request.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
var audioResponse AudioResponse

defer func() {
go func() {
quota := countTokenText(audioResponse.Text, audioModel)
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, 0, 0, audioModel, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}()
}()

responseBody, err := io.ReadAll(resp.Body)

if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
err = json.Unmarshal(responseBody, &audioResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}

resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))

for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)

_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
return nil
}
9 changes: 9 additions & 0 deletions controller/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const (
RelayModeModerations
RelayModeImagesGenerations
RelayModeEdits
RelayModeAudio
)

// https://platform.openai.com/docs/api-reference/chat
Expand Down Expand Up @@ -63,6 +64,10 @@ type ImageRequest struct {
Size string `json:"size"`
}

type AudioResponse struct {
Text string `json:"text,omitempty"`
}

type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
Expand Down Expand Up @@ -159,11 +164,15 @@ func Relay(c *gin.Context) {
relayMode = RelayModeImagesGenerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
relayMode = RelayModeEdits
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
relayMode = RelayModeAudio
}
var err *OpenAIErrorWithStatusCode
switch relayMode {
case RelayModeImagesGenerations:
err = relayImageHelper(c, relayMode)
case RelayModeAudio:
err = relayAudioHelper(c, relayMode)
default:
err = relayTextHelper(c, relayMode)
}
Expand Down
10 changes: 9 additions & 1 deletion middleware/distributor.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ func Distribute() func(c *gin.Context) {
} else {
// Select a channel for the user
var modelRequest ModelRequest
err := common.UnmarshalBodyReusable(c, &modelRequest)
var err error
if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
Expand All @@ -84,6 +87,11 @@ func Distribute() func(c *gin.Context) {
modelRequest.Model = "dall-e"
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
if modelRequest.Model == "" {
modelRequest.Model = "whisper-1"
}
}
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
Expand Down
4 changes: 2 additions & 2 deletions router/relay-router.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
relayV1Router.POST("/embeddings", controller.Relay)
relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
relayV1Router.POST("/audio/transcriptions", controller.RelayNotImplemented)
relayV1Router.POST("/audio/translations", controller.RelayNotImplemented)
relayV1Router.POST("/audio/transcriptions", controller.Relay)
relayV1Router.POST("/audio/translations", controller.Relay)
relayV1Router.GET("/files", controller.RelayNotImplemented)
relayV1Router.POST("/files", controller.RelayNotImplemented)
relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
Expand Down

0 comments on commit d09d317

Please sign in to comment.