Skip to content

Commit

Permalink
Refactor chat stream handling and add per-word stream limit function
Browse files Browse the repository at this point in the history
  • Loading branch information
swuecho committed Sep 11, 2024
1 parent 3e5a176 commit f232c20
Showing 1 changed file with 40 additions and 15 deletions.
55 changes: 40 additions & 15 deletions api/chat_main_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ func (h *ChatHandler) CheckModelAccess(w http.ResponseWriter, chatSessionUuid st
// no rows
if errors.Is(err, sql.ErrNoRows) {
RespondWithError(w, http.StatusInternalServerError, "error.fail_to_get_rate_limit", err)
return true
return true
}
RespondWithError(w, http.StatusInternalServerError, "error.fail_to_get_rate_limit", err)
return true
Expand All @@ -456,6 +456,21 @@ func (h *ChatHandler) CheckModelAccess(w http.ResponseWriter, chatSessionUuid st
return false
}

// getPerWordStreamLimit returns the streaming threshold configured through the
// PER_WORD_STREAM_LIMIT environment variable.
//
// When the variable is unset or empty the limit defaults to 200. When it is set
// to something that is not a valid integer, the function returns 0 and an error
// that wraps the underlying strconv error, so callers can inspect it with
// errors.Is / errors.As while the rendered message stays the same.
func getPerWordStreamLimit() (int, error) {
	const defaultPerWordStreamLimit = 200

	raw := os.Getenv("PER_WORD_STREAM_LIMIT")
	if raw == "" {
		// Nothing configured: use the built-in default without parsing it.
		return defaultPerWordStreamLimit, nil
	}

	limit, err := strconv.Atoi(raw)
	if err != nil {
		// %w (not %v) keeps the *strconv.NumError reachable in the error chain.
		return 0, fmt.Errorf("per word stream limit error: %w", err)
	}

	return limit, nil
}

func (h *ChatHandler) chatStream(w http.ResponseWriter, chatSession sqlc_queries.ChatSession, chat_compeletion_messages []models.Message, chatUuid string, regenerate bool, streamOutput bool) (string, string, bool) {
// check per chat_model limit

Expand Down Expand Up @@ -516,7 +531,12 @@ func (h *ChatHandler) chatStream(w http.ResponseWriter, chatSession sqlc_queries
}
defer stream.Close()

setSSEHeader(w)
// setSSEHeader(w)
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")

flusher, ok := w.(http.Flusher)
if !ok {
Expand All @@ -531,6 +551,11 @@ func (h *ChatHandler) chatStream(w http.ResponseWriter, chatSession sqlc_queries
answer_id = chatUuid
}

initial_resp := constructChatCompletionStreamReponse(answer_id, "!!!!")
data, _ := json.Marshal(initial_resp)
fmt.Fprintf(w, "data: %v\n\n", string(data))
flusher.Flush()

for {
response, err := stream.Recv()
if errors.Is(err, io.EOF) {
Expand Down Expand Up @@ -559,29 +584,23 @@ func (h *ChatHandler) chatStream(w http.ResponseWriter, chatSession sqlc_queries
textIdx := response.Choices[0].Index
delta := response.Choices[0].Delta.Content
textBuffer.appendByIndex(textIdx, delta)
// log.Println(delta)

if chatSession.Debug {
log.Printf("%s", delta)
}
answer = textBuffer.String("\n")
if answer_id == "" {
answer_id = strings.TrimPrefix(response.ID, "chatcmpl-")
}
perWordStreamLimitStr := os.Getenv("PER_WORD_STREAM_LIMIT")

if perWordStreamLimitStr == "" {
perWordStreamLimitStr = "200"
}

perWordStreamLimit, err := strconv.Atoi(perWordStreamLimitStr)
perWordStreamLimit, err := getPerWordStreamLimit()
if err != nil {
RespondWithError(w, http.StatusInternalServerError, fmt.Sprintf("per word stream limit error: %v", err), nil)
RespondWithError(w, http.StatusInternalServerError, err.Error(), nil)
return "", "", true
}

if strings.HasSuffix(delta, "\n") || len(answer) < perWordStreamLimit {
response.Choices[0].Delta.Content = answer
data, _ := json.Marshal(response)
resp := constructChatCompletionStreamReponse(answer_id, answer)
data, _ := json.Marshal(resp)
fmt.Fprintf(w, "data: %v\n\n", string(data))
flusher.Flush()
}
Expand Down Expand Up @@ -691,7 +710,13 @@ func (h *ChatHandler) CompletionStream(w http.ResponseWriter, chatSession sqlc_q
// concatenate all string builders into a single string
answer = textBuffer.String("\n\n")

if strings.HasSuffix(delta, "\n") || len(answer) < 200 {
perWordStreamLimit, err := getPerWordStreamLimit()
if err != nil {
RespondWithError(w, http.StatusInternalServerError, err.Error(), nil)
return "", "", true
}

if strings.HasSuffix(delta, "\n") || len(answer) < perWordStreamLimit {
response := constructChatCompletionStreamReponse(answer_id, answer)
data, _ := json.Marshal(response)
fmt.Fprintf(w, "data: %v\n\n", string(data))
Expand Down Expand Up @@ -1390,7 +1415,7 @@ func (h *ChatHandler) chatStreamGemini(w http.ResponseWriter, chatSession sqlc_q
if stream {
url = fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:streamGenerateContent?alt=sse&key=$GEMINI_API_KEY", chatSession.Model)
}

url = os.ExpandEnv(url)

req, err := http.NewRequest("POST", url, bytes.NewBuffer(payloadBytes))
Expand Down

0 comments on commit f232c20

Please sign in to comment.