|
| 1 | +package chatgpt |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "errors" |
| 6 | + "strings" |
| 7 | + "time" |
| 8 | + |
| 9 | + "github.com/PullRequestInc/go-gpt3" |
| 10 | + |
| 11 | + "github.com/yqchilde/wxbot/engine/pkg/log" |
| 12 | +) |
| 13 | + |
| 14 | +var ( |
| 15 | + gptClient gpt3.Client |
| 16 | + gptModel *GptModel |
| 17 | +) |
| 18 | + |
| 19 | +func AskChatGpt(question string, delay ...time.Duration) (answer string, err error) { |
| 20 | + // 获取客户端 |
| 21 | + if gptClient == nil { |
| 22 | + gptClient, err = getGptClient() |
| 23 | + if err != nil { |
| 24 | + return "", err |
| 25 | + } |
| 26 | + } |
| 27 | + |
| 28 | + // 获取模型 |
| 29 | + if gptModel == nil { |
| 30 | + gptModel, err = getGptModel() |
| 31 | + if err != nil { |
| 32 | + return "", err |
| 33 | + } |
| 34 | + } |
| 35 | + |
| 36 | + // 延迟请求 |
| 37 | + if len(delay) > 0 { |
| 38 | + time.Sleep(delay[0]) |
| 39 | + } |
| 40 | + |
| 41 | + // 请求gpt3 |
| 42 | + resp, err := gptClient.CompletionWithEngine(context.Background(), gptModel.Model, gpt3.CompletionRequest{ |
| 43 | + Prompt: []string{question}, |
| 44 | + MaxTokens: gpt3.IntPtr(gptModel.MaxTokens), |
| 45 | + Temperature: gpt3.Float32Ptr(float32(gptModel.Temperature)), |
| 46 | + TopP: gpt3.Float32Ptr(float32(gptModel.TopP)), |
| 47 | + Echo: false, |
| 48 | + PresencePenalty: float32(gptModel.PresencePenalty), |
| 49 | + FrequencyPenalty: float32(gptModel.FrequencyPenalty), |
| 50 | + }) |
| 51 | + |
| 52 | + // 处理响应回来的错误 |
| 53 | + if err != nil { |
| 54 | + if strings.Contains(err.Error(), "You exceeded your current quota") { |
| 55 | + log.Printf("当前apiKey[%s]配额已用完, 将删除并切换到下一个", apiKeys[0].Key) |
| 56 | + db.Orm.Table("apikey").Where("key = ?", apiKeys[0].Key).Delete(&ApiKey{}) |
| 57 | + if len(apiKeys) == 1 { |
| 58 | + return "", errors.New("OpenAi配额已用完,请联系管理员") |
| 59 | + } |
| 60 | + apiKeys = apiKeys[1:] |
| 61 | + gptClient = gpt3.NewClient(apiKeys[0].Key, gpt3.WithTimeout(time.Minute)) |
| 62 | + return AskChatGpt(question) |
| 63 | + } |
| 64 | + if strings.Contains(err.Error(), "The server had an error while processing your request") { |
| 65 | + log.Println("OpenAi服务出现问题,将重试") |
| 66 | + return AskChatGpt(question) |
| 67 | + } |
| 68 | + if strings.Contains(err.Error(), "Client.Timeout exceeded while awaiting headers") { |
| 69 | + log.Println("OpenAi服务请求超时,将重试") |
| 70 | + return AskChatGpt(question) |
| 71 | + } |
| 72 | + if strings.Contains(err.Error(), "Please reduce your prompt") { |
| 73 | + return "", errors.New("OpenAi免费上下文长度限制为4097个词组,您的上下文长度已超出限制,请发送\"清空会话\"以清空上下文") |
| 74 | + } |
| 75 | + return "", err |
| 76 | + } |
| 77 | + return resp.Choices[0].Text, nil |
| 78 | +} |
| 79 | + |
| 80 | +// filterAnswer 过滤答案,处理一些符号问题 |
| 81 | +// return 新的答案,是否需要重试 |
| 82 | +func filterAnswer(answer string) (newAnswer string, isNeedRetry bool) { |
| 83 | + punctuation := ",,!!??" |
| 84 | + answer = strings.TrimSpace(answer) |
| 85 | + if len(answer) == 1 { |
| 86 | + return answer, true |
| 87 | + } |
| 88 | + if len(answer) == 3 && strings.ContainsAny(answer, punctuation) { |
| 89 | + return answer, true |
| 90 | + } |
| 91 | + answer = strings.TrimLeftFunc(answer, func(r rune) bool { |
| 92 | + if strings.ContainsAny(string(r), punctuation) { |
| 93 | + return true |
| 94 | + } |
| 95 | + return false |
| 96 | + }) |
| 97 | + return answer, false |
| 98 | +} |
0 commit comments