Skip to content

Commit b7a6a15

Browse files
committed
fix: streaming parser overrun
1 parent f460c57 commit b7a6a15

File tree

12 files changed

+138
-74
lines changed

12 files changed

+138
-74
lines changed

src/ax/ai/base.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ export class AxBaseAI<
456456
}
457457

458458
if (options?.debug ?? this.debug) {
459-
logChatRequest(req.chatPrompt)
459+
logChatRequest(req.chatPrompt, options?.debugHideSystemPrompt)
460460
}
461461

462462
const rt = options?.rateLimiter ?? this.rt

src/ax/ai/debug.ts

+13-5
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,14 @@ const colorLog = new ColorLog()
66

77
const formatChatMessage = (
88
msg: AxChatRequest['chatPrompt'][number],
9-
hideContent?: boolean
9+
hideContent?: boolean,
10+
hideSystemPrompt?: boolean
1011
) => {
1112
switch (msg.role) {
1213
case 'system':
14+
if (hideSystemPrompt) {
15+
return ''
16+
}
1317
return `\n${colorLog.blueBright('System:')}\n${colorLog.whiteBright(msg.content)}`
1418
case 'function':
1519
return `\n${colorLog.blueBright('Function Result:')}\n${colorLog.whiteBright(msg.result)}`
@@ -48,16 +52,20 @@ const formatChatMessage = (
4852
}
4953

5054
export const logChatRequestMessage = (
51-
msg: AxChatRequest['chatPrompt'][number]
55+
msg: AxChatRequest['chatPrompt'][number],
56+
hideSystemPrompt?: boolean
5257
) => {
53-
process.stdout.write(`${formatChatMessage(msg)}\n`)
58+
process.stdout.write(`${formatChatMessage(msg, hideSystemPrompt)}\n`)
5459
process.stdout.write(colorLog.blueBright('\nAssistant:\n'))
5560
}
5661

5762
export const logChatRequest = (
58-
chatPrompt: Readonly<AxChatRequest['chatPrompt']>
63+
chatPrompt: Readonly<AxChatRequest['chatPrompt']>,
64+
hideSystemPrompt?: boolean
5965
) => {
60-
const items = chatPrompt?.map((msg) => formatChatMessage(msg))
66+
const items = chatPrompt?.map((msg) =>
67+
formatChatMessage(msg, hideSystemPrompt)
68+
)
6169

6270
if (items) {
6371
process.stdout.write(items.join('\n'))

src/ax/ai/types.ts

+1
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ export type AxAIServiceActionOptions<
240240
traceId?: string
241241
rateLimiter?: AxRateLimiterFunction
242242
debug?: boolean
243+
debugHideSystemPrompt?: boolean
243244
}
244245

245246
export interface AxAIService<TModel = unknown, TEmbedModel = unknown> {

src/ax/dsp/datetime.test.ts

+9-9
Original file line numberDiff line numberDiff line change
@@ -11,34 +11,34 @@ const field: AxField = {
1111
describe('datetime parsing', () => {
1212
it('should parse datetime with timezone abbreviation', () => {
1313
const dt = parseLLMFriendlyDateTime(field, '2022-01-01 12:00 EST')
14-
expect(dt.toUTCString()).toBe('Sat, 01 Jan 2022 17:00:00 GMT')
14+
expect(dt?.toUTCString()).toBe('Sat, 01 Jan 2022 17:00:00 GMT')
1515
})
1616

1717
it('should parse datetime with seconds and timezone abbreviation', () => {
1818
const dt = parseLLMFriendlyDateTime(field, '2022-01-01 12:00:10 EST')
19-
expect(dt.toUTCString()).toBe('Sat, 01 Jan 2022 17:00:10 GMT')
19+
expect(dt?.toUTCString()).toBe('Sat, 01 Jan 2022 17:00:10 GMT')
2020
})
2121

2222
it('should parse datetime with full timezone', () => {
2323
const dt = parseLLMFriendlyDateTime(
2424
field,
2525
'2022-01-01 12:00 America/New_York'
2626
)
27-
expect(dt.toUTCString()).toBe('Sat, 01 Jan 2022 17:00:00 GMT')
27+
expect(dt?.toUTCString()).toBe('Sat, 01 Jan 2022 17:00:00 GMT')
2828
})
2929

3030
it('should parse datetime with another full timezone', () => {
3131
const dt = parseLLMFriendlyDateTime(
3232
field,
3333
'2022-01-01 12:00 America/Los_Angeles'
3434
)
35-
expect(dt.toUTCString()).toBe('Sat, 01 Jan 2022 20:00:00 GMT')
35+
expect(dt?.toUTCString()).toBe('Sat, 01 Jan 2022 20:00:00 GMT')
3636
})
3737

3838
it('should parse datetime across DST boundary', () => {
3939
const summerDt = parseLLMFriendlyDateTime(field, '2022-07-01 12:00 EST')
4040
const winterDt = parseLLMFriendlyDateTime(field, '2022-01-01 12:00 EST')
41-
expect(summerDt.getUTCHours()).toBe(winterDt.getUTCHours())
41+
expect(summerDt?.getUTCHours()).toBe(winterDt?.getUTCHours())
4242
})
4343

4444
it('should throw error for invalid datetime value', () => {
@@ -55,22 +55,22 @@ describe('datetime parsing', () => {
5555
describe('date parsing', () => {
5656
it('should parse valid date', () => {
5757
const dt = parseLLMFriendlyDate(field, '2022-01-01')
58-
expect(dt.toUTCString()).toBe('Sat, 01 Jan 2022 00:00:00 GMT')
58+
expect(dt?.toUTCString()).toBe('Sat, 01 Jan 2022 00:00:00 GMT')
5959
})
6060

6161
it('should parse date with leading zeros', () => {
6262
const dt = parseLLMFriendlyDate(field, '2022-02-05')
63-
expect(dt.toUTCString()).toBe('Sat, 05 Feb 2022 00:00:00 GMT')
63+
expect(dt?.toUTCString()).toBe('Sat, 05 Feb 2022 00:00:00 GMT')
6464
})
6565

6666
it('should parse date at year boundary', () => {
6767
const dt = parseLLMFriendlyDate(field, '2022-12-31')
68-
expect(dt.toUTCString()).toBe('Sat, 31 Dec 2022 00:00:00 GMT')
68+
expect(dt?.toUTCString()).toBe('Sat, 31 Dec 2022 00:00:00 GMT')
6969
})
7070

7171
it('should parse date in leap year', () => {
7272
const dt = parseLLMFriendlyDate(field, '2024-02-29')
73-
expect(dt.toUTCString()).toBe('Thu, 29 Feb 2024 00:00:00 GMT')
73+
expect(dt?.toUTCString()).toBe('Thu, 29 Feb 2024 00:00:00 GMT')
7474
})
7575

7676
it('should throw error for invalid date value', () => {

src/ax/dsp/datetime.ts

+10-2
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,15 @@ import { ValidationError } from './validate.js'
55

66
export function parseLLMFriendlyDate(
77
field: Readonly<AxField>,
8-
dateStr: string
8+
dateStr: string,
9+
required: boolean = false
910
) {
1011
try {
1112
return _parseLLMFriendlyDate(dateStr)
1213
} catch (err) {
14+
if (field.isOptional && !required) {
15+
return
16+
}
1317
const message = (err as Error).message
1418
throw new ValidationError({ fields: [field], message, value: dateStr })
1519
}
@@ -31,11 +35,15 @@ function _parseLLMFriendlyDate(dateStr: string) {
3135

3236
export function parseLLMFriendlyDateTime(
3337
field: Readonly<AxField>,
34-
dateStr: string
38+
dateStr: string,
39+
required: boolean = false
3540
) {
3641
try {
3742
return _parseLLMFriendlyDateTime(dateStr)
3843
} catch (err) {
44+
if (field.isOptional && !required) {
45+
return
46+
}
3947
const message = (err as Error).message
4048
throw new ValidationError({ fields: [field], message, value: dateStr })
4149
}

src/ax/dsp/extract.ts

+29-11
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ export const extractValues = (
1616
}
1717

1818
export interface extractionState {
19-
prevField?: { field: AxField; s: number; e: number }
19+
prevFields?: { field: AxField; s: number; e: number }[]
2020
currField?: AxField
2121
currFieldIndex?: number
2222
extractedFields: AxField[]
@@ -64,8 +64,9 @@ export const streamingExtractValues = (
6464
continue
6565
}
6666

67-
const prefix = field.title + ':'
68-
let e = matchesContent(content, prefix, xstate.s + 1)
67+
const isFirst = xstate.extractedFields.length === 0
68+
const prefix = (isFirst ? '' : '\n') + field.title + ':'
69+
let e = matchesContent(content, prefix, xstate.s === 0 ? 0 : xstate.s + 1)
6970

7071
switch (e) {
7172
case -1:
@@ -95,7 +96,11 @@ export const streamingExtractValues = (
9596
if (parsedValue !== undefined) {
9697
values[xstate.currField.name] = parsedValue
9798
}
98-
xstate.prevField = { field: xstate.currField, s: xstate.s, e }
99+
if (xstate.prevFields) {
100+
xstate.prevFields?.push({ field: xstate.currField, s: xstate.s, e })
101+
} else {
102+
xstate.prevFields = [{ field: xstate.currField, s: xstate.s, e }]
103+
}
99104
}
100105

101106
checkMissingRequiredFields(xstate, values, index)
@@ -136,7 +141,11 @@ export const streamingExtractFinalValue = (
136141
checkMissingRequiredFields(xstate, values, sigFields.length)
137142
}
138143

139-
const convertValueToType = (field: Readonly<AxField>, val: string) => {
144+
const convertValueToType = (
145+
field: Readonly<AxField>,
146+
val: string,
147+
required: boolean = false
148+
) => {
140149
switch (field.type?.name) {
141150
case 'code':
142151
return extractBlock(val)
@@ -147,6 +156,9 @@ const convertValueToType = (field: Readonly<AxField>, val: string) => {
147156
case 'number': {
148157
const v = Number(val)
149158
if (Number.isNaN(v)) {
159+
if (field.isOptional && !required) {
160+
return
161+
}
150162
throw new Error('Invalid number')
151163
}
152164
return v
@@ -162,18 +174,24 @@ const convertValueToType = (field: Readonly<AxField>, val: string) => {
162174
} else if (v === 'false') {
163175
return false
164176
} else {
177+
if (field.isOptional && !required) {
178+
return
179+
}
165180
throw new Error('Invalid boolean')
166181
}
167182
}
168183
case 'date':
169-
return parseLLMFriendlyDate(field, val)
184+
return parseLLMFriendlyDate(field, val, required)
170185

171186
case 'datetime':
172-
return parseLLMFriendlyDateTime(field, val)
187+
return parseLLMFriendlyDateTime(field, val, required)
173188

174189
case 'class':
175190
const className = val
176191
if (field.type.classes && !field.type.classes.includes(className)) {
192+
if (field.isOptional) {
193+
return
194+
}
177195
throw new Error(
178196
`Invalid class '${val}', expected one of the following: ${field.type.classes.join(', ')}`
179197
)
@@ -249,11 +267,11 @@ export function* streamValues<OUT>(
249267
// eslint-disable-next-line functional/prefer-immutable-types
250268
xstate: extractionState
251269
) {
252-
if (xstate.prevField && !xstate.prevField.field.isInternal) {
253-
const { field, s, e } = xstate.prevField
270+
for (const prevField of xstate.prevFields ?? []) {
271+
const { field, s, e } = prevField
254272
yield* yieldDelta<OUT>(content, field, s, e, xstate)
255-
xstate.prevField = undefined
256273
}
274+
xstate.prevFields = undefined
257275

258276
if (!xstate.currField || xstate.currField.isInternal) {
259277
return
@@ -354,7 +372,7 @@ function validateAndParseFieldValue(
354372
for (const [index, item] of value.entries()) {
355373
if (item !== undefined) {
356374
const v = typeof item === 'string' ? item.trim() : item
357-
value[index] = convertValueToType(field, v)
375+
value[index] = convertValueToType(field, v, true)
358376
}
359377
}
360378
} else {

src/ax/dsp/generate.ts

+6-1
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,12 @@ export class AxGen<
517517
const maxRetries = options.maxRetries ?? this.options?.maxRetries ?? 10
518518
const maxSteps = options.maxSteps ?? this.options?.maxSteps ?? 10
519519
const debug = options.debug ?? ai.getOptions().debug
520-
const mem = options.mem ?? this.options?.mem ?? new AxMemory(10000, debug)
520+
const memOptions = {
521+
debug: options.debug,
522+
debugHideSystemPrompt: options.debugHideSystemPrompt,
523+
}
524+
const mem =
525+
options.mem ?? this.options?.mem ?? new AxMemory(10000, memOptions)
521526

522527
let err: ValidationError | AxAssertionError | undefined
523528

src/ax/dsp/program.ts

+1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ export type AxProgramForwardOptions = {
6262
stopFunction?: string
6363
fastFail?: boolean
6464
debug?: boolean
65+
debugHideSystemPrompt?: boolean
6566
}
6667

6768
export type AxProgramStreamingForwardOptions = Omit<

src/ax/dsp/sig.test.ts

+1-7
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,7 @@ describe('extract values with signatures', () => {
9696
extractValues(
9797
sig,
9898
v1,
99-
`
100-
Title: Coastal Ecosystem Restoration
101-
102-
Key Points: Coastal regions prone to natural disasters, Selection criteria based on vulnerability indices and population density, Climate risk assessments conducted for sea-level rise and extreme weather events, Targeted ecosystems include mangrove forests, coral reefs, wetlands
103-
104-
Description: The project focuses on coastal regions vulnerable to natural disasters like hurricanes and flooding. Selection criteria included vulnerability indices, population density, and proximity to critical infrastructure. Climate risk assessments identified risks related to sea-level rise, storm surges, and extreme weather events. Targeted ecosystems encompass mangrove forests, coral reefs, and wetlands that provide coastal protection, biodiversity support, and livelihood opportunities for local communities.
105-
`
99+
`Title: Coastal Ecosystem Restoration\nKey Points: Coastal regions prone to natural disasters, Selection criteria based on vulnerability indices and population density, Climate risk assessments conducted for sea-level rise and extreme weather events, Targeted ecosystems include mangrove forests, coral reefs, wetlands\nDescription: The project focuses on coastal regions vulnerable to natural disasters like hurricanes and flooding. Selection criteria included vulnerability indices, population density, and proximity to critical infrastructure. Climate risk assessments identified risks related to sea-level rise, storm surges, and extreme weather events. Targeted ecosystems encompass mangrove forests, coral reefs, and wetlands that provide coastal protection, biodiversity support, and livelihood opportunities for local communities.`
106100
)
107101

108102
expect(v1).toEqual({

src/ax/dsp/util.ts

+1
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ export function matchesContent(
308308

309309
// First check if the complete prefix exists anywhere after startIndex
310310
const exactMatchIndex = content.indexOf(prefix, startIndex)
311+
311312
if (exactMatchIndex !== -1) {
312313
return exactMatchIndex
313314
}

0 commit comments

Comments
 (0)