@@ -2,7 +2,9 @@ package utils
2
2
3
3
import (
4
4
"chat/globals"
5
+ "fmt"
5
6
"strings"
7
+ "time"
6
8
)
7
9
8
10
type Charge interface {
@@ -28,7 +30,11 @@ type Buffer struct {
28
30
ToolCalls * globals.ToolCalls `json:"tool_calls"`
29
31
ToolCallsCursor int `json:"tool_calls_cursor"`
30
32
FunctionCall * globals.FunctionCall `json:"function_call"`
33
+ StartTime * time.Time `json:"-"`
34
+ Prompts string `json:"prompts"`
35
+ TokenName string `json:"-"`
31
36
Charge Charge `json:"-"`
37
+ VisionRecall bool `json:"-"`
32
38
}
33
39
34
40
func initInputToken (model string , history []globals.Message ) int {
@@ -71,6 +77,7 @@ func NewBuffer(model string, history []globals.Message, charge Charge) *Buffer {
71
77
FunctionCall : nil ,
72
78
ToolCalls : nil ,
73
79
ToolCallsCursor : 0 ,
80
+ StartTime : ToPtr (time .Now ()),
74
81
}
75
82
}
76
83
@@ -79,6 +86,11 @@ func (b *Buffer) GetCursor() int {
79
86
}
80
87
81
88
func (b * Buffer ) GetQuota () float32 {
89
+ return b .Quota + CountOutputToken (b .Charge , b .CountOutputToken (true ))
90
+ }
91
+
92
+ func (b * Buffer ) GetRecordQuota () float32 {
93
+ // end of the buffer, the output token is counted using the times
82
94
return b .Quota + CountOutputToken (b .Charge , b .CountOutputToken (false ))
83
95
}
84
96
@@ -106,15 +118,23 @@ func (b *Buffer) GetChunk() string {
106
118
return b .Latest
107
119
}
108
120
121
+ func (b * Buffer ) InitVisionRecall () {
122
+ // set the vision recall flag to true to prevent the buffer from adding more images of retrying the channel
123
+ b .VisionRecall = true
124
+ }
125
+
109
126
func (b * Buffer ) AddImage (image * Image ) {
110
- if image != nil {
111
- b . Images = append ( b . Images , * image )
127
+ if image == nil || b . VisionRecall {
128
+ return
112
129
}
113
130
131
+ b .Images = append (b .Images , * image )
132
+
133
+ tokens := image .CountTokens (b .Model )
134
+ b .InputTokens += tokens
135
+
114
136
if b .Charge .IsBillingType (globals .TokenBilling ) {
115
- if image != nil {
116
- b .Quota += float32 (image .CountTokens (b .Model )) * b .Charge .GetInput ()
117
- }
137
+ b .Quota += float32 (tokens ) / 1000 * b .Charge .GetInput ()
118
138
}
119
139
}
120
140
@@ -145,6 +165,13 @@ func hitTool(tool globals.ToolCall, tools globals.ToolCalls) (int, *globals.Tool
145
165
return 0 , nil
146
166
}
147
167
168
+ func appendTool (tool globals.ToolCall , chunk globals.ToolCall ) string {
169
+ from := ToString (tool .Function .Arguments )
170
+ to := ToString (chunk .Function .Arguments )
171
+
172
+ return from + to
173
+ }
174
+
148
175
func mixTools (source * globals.ToolCalls , target * globals.ToolCalls ) * globals.ToolCalls {
149
176
if source == nil {
150
177
return target
@@ -157,7 +184,7 @@ func mixTools(source *globals.ToolCalls, target *globals.ToolCalls) *globals.Too
157
184
idx , hit := hitTool (tool , tools )
158
185
159
186
if hit != nil {
160
- tools [idx ].Function .Arguments += tool . Function . Arguments
187
+ tools [idx ].Function .Arguments = appendTool ( tools [ idx ], tool )
161
188
} else {
162
189
tools = append (tools , tool )
163
190
}
@@ -209,6 +236,27 @@ func (b *Buffer) GetCharge() Charge {
209
236
return b .Charge
210
237
}
211
238
239
+ func (b * Buffer ) ToChargeInfo () string {
240
+ switch b .Charge .GetType () {
241
+ case globals .TokenBilling :
242
+ return fmt .Sprintf (
243
+ "input tokens: %0.4f quota / 1k tokens\n " +
244
+ "output tokens: %0.4f quota / 1k tokens\n " ,
245
+ b .Charge .GetInput (), b .Charge .GetOutput (),
246
+ )
247
+ case globals .TimesBilling :
248
+ return fmt .Sprintf ("%f quota per request\n " , b .Charge .GetLimit ())
249
+ case globals .NonBilling :
250
+ return "no cost"
251
+ }
252
+
253
+ return ""
254
+ }
255
+
256
+ func (b * Buffer ) SetPrompts (prompts interface {}) {
257
+ b .Prompts = ToString (prompts )
258
+ }
259
+
212
260
func (b * Buffer ) Read () string {
213
261
return b .Data
214
262
}
@@ -247,5 +295,49 @@ func (b *Buffer) CountOutputToken(running bool) int {
247
295
}
248
296
249
297
func (b * Buffer ) CountToken () int {
250
- return b .CountInputToken () + b .CountOutputToken (false )
298
+ return b .CountInputToken () + b .CountOutputToken (true )
299
+ }
300
+
301
+ func (b * Buffer ) GetDuration () float32 {
302
+ if b .StartTime == nil {
303
+ return 0
304
+ }
305
+
306
+ return float32 (time .Since (* b .StartTime ).Seconds ())
307
+ }
308
+
309
+ func (b * Buffer ) GetStartTime () * time.Time {
310
+ return b .StartTime
311
+ }
312
+
313
+ func (b * Buffer ) GetPrompts () string {
314
+ return b .Prompts
315
+ }
316
+
317
+ func (b * Buffer ) GetTokenName () string {
318
+ if len (b .TokenName ) == 0 {
319
+ return globals .WebTokenType
320
+ }
321
+
322
+ return b .TokenName
323
+ }
324
+
325
+ func (b * Buffer ) SetTokenName (tokenName string ) {
326
+ b .TokenName = tokenName
327
+ }
328
+
329
+ func (b * Buffer ) GetRecordPrompts () string {
330
+ if ! globals .AcceptPromptStore {
331
+ return ""
332
+ }
333
+
334
+ return b .GetPrompts ()
335
+ }
336
+
337
+ func (b * Buffer ) GetRecordResponsePrompts () string {
338
+ if ! globals .AcceptPromptStore {
339
+ return ""
340
+ }
341
+
342
+ return b .Read ()
251
343
}
0 commit comments