Skip to content

Commit 576213d

Browse files
committed
fix: nil pointer dereference error when carrying an image to a conversation (coaidev#221)
1 parent 9cb9580 commit 576213d

File tree

5 files changed

+137
-13
lines changed

5 files changed

+137
-13
lines changed

channel/system.go

+9-6
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,13 @@ type SearchState struct {
7373
}
7474

7575
type commonState struct {
76-
Article []string `json:"article" mapstructure:"article"`
77-
Generation []string `json:"generation" mapstructure:"generation"`
78-
Cache []string `json:"cache" mapstructure:"cache"`
79-
Expire int64 `json:"expire" mapstructure:"expire"`
80-
Size int64 `json:"size" mapstructure:"size"`
81-
ImageStore bool `json:"image_store" mapstructure:"imagestore"`
76+
Article []string `json:"article" mapstructure:"article"`
77+
Generation []string `json:"generation" mapstructure:"generation"`
78+
Cache []string `json:"cache" mapstructure:"cache"`
79+
Expire int64 `json:"expire" mapstructure:"expire"`
80+
Size int64 `json:"size" mapstructure:"size"`
81+
ImageStore bool `json:"image_store" mapstructure:"imagestore"`
82+
PromptStore bool `json:"prompt_store" mapstructure:"promptstore"`
8283
}
8384

8485
type SystemConfig struct {
@@ -114,6 +115,8 @@ func (c *SystemConfig) Load() {
114115
globals.CacheAcceptedSize = c.GetCacheAcceptedSize()
115116
globals.AcceptImageStore = c.AcceptImageStore()
116117

118+
globals.AcceptPromptStore = c.Common.PromptStore
119+
117120
if c.General.PWAManifest == "" {
118121
c.General.PWAManifest = utils.ReadPWAManifest()
119122
}

globals/constant.go

+6
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,9 @@ const (
4848
HttpsProxyType
4949
Socks5ProxyType
5050
)
51+
52+
const (
53+
WebTokenType = "web"
54+
ApiTokenType = "api"
55+
SystemToken = "system"
56+
)

globals/variables.go

+1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ var CacheAcceptedModels []string
2323
var CacheAcceptedExpire int64
2424
var CacheAcceptedSize int64
2525
var AcceptImageStore bool
26+
var AcceptPromptStore bool
2627
var CloseRegistration bool
2728
var CloseRelay bool
2829

utils/buffer.go

+99-7
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ package utils
22

33
import (
44
"chat/globals"
5+
"fmt"
56
"strings"
7+
"time"
68
)
79

810
type Charge interface {
@@ -28,7 +30,11 @@ type Buffer struct {
2830
ToolCalls *globals.ToolCalls `json:"tool_calls"`
2931
ToolCallsCursor int `json:"tool_calls_cursor"`
3032
FunctionCall *globals.FunctionCall `json:"function_call"`
33+
StartTime *time.Time `json:"-"`
34+
Prompts string `json:"prompts"`
35+
TokenName string `json:"-"`
3136
Charge Charge `json:"-"`
37+
VisionRecall bool `json:"-"`
3238
}
3339

3440
func initInputToken(model string, history []globals.Message) int {
@@ -71,6 +77,7 @@ func NewBuffer(model string, history []globals.Message, charge Charge) *Buffer {
7177
FunctionCall: nil,
7278
ToolCalls: nil,
7379
ToolCallsCursor: 0,
80+
StartTime: ToPtr(time.Now()),
7481
}
7582
}
7683

@@ -79,6 +86,11 @@ func (b *Buffer) GetCursor() int {
7986
}
8087

8188
func (b *Buffer) GetQuota() float32 {
89+
return b.Quota + CountOutputToken(b.Charge, b.CountOutputToken(true))
90+
}
91+
92+
func (b *Buffer) GetRecordQuota() float32 {
93+
// end of the buffer, the output token is counted using the times
8294
return b.Quota + CountOutputToken(b.Charge, b.CountOutputToken(false))
8395
}
8496

@@ -106,15 +118,23 @@ func (b *Buffer) GetChunk() string {
106118
return b.Latest
107119
}
108120

121+
func (b *Buffer) InitVisionRecall() {
122+
// set the vision recall flag to true to prevent the buffer from adding more images of retrying the channel
123+
b.VisionRecall = true
124+
}
125+
109126
func (b *Buffer) AddImage(image *Image) {
110-
if image != nil {
111-
b.Images = append(b.Images, *image)
127+
if image == nil || b.VisionRecall {
128+
return
112129
}
113130

131+
b.Images = append(b.Images, *image)
132+
133+
tokens := image.CountTokens(b.Model)
134+
b.InputTokens += tokens
135+
114136
if b.Charge.IsBillingType(globals.TokenBilling) {
115-
if image != nil {
116-
b.Quota += float32(image.CountTokens(b.Model)) * b.Charge.GetInput()
117-
}
137+
b.Quota += float32(tokens) / 1000 * b.Charge.GetInput()
118138
}
119139
}
120140

@@ -145,6 +165,13 @@ func hitTool(tool globals.ToolCall, tools globals.ToolCalls) (int, *globals.Tool
145165
return 0, nil
146166
}
147167

168+
func appendTool(tool globals.ToolCall, chunk globals.ToolCall) string {
169+
from := ToString(tool.Function.Arguments)
170+
to := ToString(chunk.Function.Arguments)
171+
172+
return from + to
173+
}
174+
148175
func mixTools(source *globals.ToolCalls, target *globals.ToolCalls) *globals.ToolCalls {
149176
if source == nil {
150177
return target
@@ -157,7 +184,7 @@ func mixTools(source *globals.ToolCalls, target *globals.ToolCalls) *globals.Too
157184
idx, hit := hitTool(tool, tools)
158185

159186
if hit != nil {
160-
tools[idx].Function.Arguments += tool.Function.Arguments
187+
tools[idx].Function.Arguments = appendTool(tools[idx], tool)
161188
} else {
162189
tools = append(tools, tool)
163190
}
@@ -209,6 +236,27 @@ func (b *Buffer) GetCharge() Charge {
209236
return b.Charge
210237
}
211238

239+
func (b *Buffer) ToChargeInfo() string {
240+
switch b.Charge.GetType() {
241+
case globals.TokenBilling:
242+
return fmt.Sprintf(
243+
"input tokens: %0.4f quota / 1k tokens\n"+
244+
"output tokens: %0.4f quota / 1k tokens\n",
245+
b.Charge.GetInput(), b.Charge.GetOutput(),
246+
)
247+
case globals.TimesBilling:
248+
return fmt.Sprintf("%f quota per request\n", b.Charge.GetLimit())
249+
case globals.NonBilling:
250+
return "no cost"
251+
}
252+
253+
return ""
254+
}
255+
256+
func (b *Buffer) SetPrompts(prompts interface{}) {
257+
b.Prompts = ToString(prompts)
258+
}
259+
212260
func (b *Buffer) Read() string {
213261
return b.Data
214262
}
@@ -247,5 +295,49 @@ func (b *Buffer) CountOutputToken(running bool) int {
247295
}
248296

249297
func (b *Buffer) CountToken() int {
250-
return b.CountInputToken() + b.CountOutputToken(false)
298+
return b.CountInputToken() + b.CountOutputToken(true)
299+
}
300+
301+
func (b *Buffer) GetDuration() float32 {
302+
if b.StartTime == nil {
303+
return 0
304+
}
305+
306+
return float32(time.Since(*b.StartTime).Seconds())
307+
}
308+
309+
func (b *Buffer) GetStartTime() *time.Time {
310+
return b.StartTime
311+
}
312+
313+
func (b *Buffer) GetPrompts() string {
314+
return b.Prompts
315+
}
316+
317+
func (b *Buffer) GetTokenName() string {
318+
if len(b.TokenName) == 0 {
319+
return globals.WebTokenType
320+
}
321+
322+
return b.TokenName
323+
}
324+
325+
func (b *Buffer) SetTokenName(tokenName string) {
326+
b.TokenName = tokenName
327+
}
328+
329+
func (b *Buffer) GetRecordPrompts() string {
330+
if !globals.AcceptPromptStore {
331+
return ""
332+
}
333+
334+
return b.GetPrompts()
335+
}
336+
337+
func (b *Buffer) GetRecordResponsePrompts() string {
338+
if !globals.AcceptPromptStore {
339+
return ""
340+
}
341+
342+
return b.Read()
251343
}

utils/net.go

+22
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,28 @@ func Post(uri string, headers map[string]string, body interface{}, config ...glo
183183
return data, err
184184
}
185185

186+
func ToString(data interface{}) string {
187+
switch v := data.(type) {
188+
case string:
189+
return v
190+
case int, int8, int16, int32, int64:
191+
return fmt.Sprintf("%d", v)
192+
case uint, uint8, uint16, uint32, uint64:
193+
return fmt.Sprintf("%d", v)
194+
case float32, float64:
195+
return fmt.Sprintf("%f", v)
196+
case bool:
197+
return fmt.Sprintf("%t", v)
198+
default:
199+
data := Marshal(data)
200+
if len(data) > 0 {
201+
return data
202+
}
203+
204+
return fmt.Sprintf("%v", data)
205+
}
206+
}
207+
186208
func PostRaw(uri string, headers map[string]string, body interface{}, config ...globals.ProxyConfig) (data string, err error) {
187209
buffer, err := HttpRaw(uri, http.MethodPost, headers, ConvertBody(body), config)
188210
if err != nil {

0 commit comments

Comments
 (0)