Skip to content

Commit 0c9f4f4

Browse files
committed
fix: old router is now simple classifier
1 parent 6886416 commit 0c9f4f4

File tree

4 files changed

+52
-35
lines changed

4 files changed

+52
-35
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,7 @@ OPENAI_APIKEY=openai_key npm run tsx ./src/examples/marketing.ts
632632
| balancer.ts | Balance between various llm's based on cost, etc |
633633
| docker.ts | Use the docker sandbox to find files by description |
634634
| prime.ts | Using field processors to process fields in a prompt |
635+
| simple-classify.ts | Use a simple classifier to classify stuff |
635636

636637
## Our Goal
637638

src/ax/dsp/router.ts

+15-13
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ import { ColorLog } from '../util/log.js'
44

55
const colorLog = new ColorLog()
66

7-
export interface AxRouterForwardOptions {
7+
export interface AxSimpleClassifierForwardOptions {
88
cutoff?: number
99
}
1010

11-
export class AxRoute {
11+
export class AxSimpleClassifierClass {
1212
private readonly name: string
1313
private readonly context: readonly string[]
1414

@@ -26,7 +26,7 @@ export class AxRoute {
2626
}
2727
}
2828

29-
export class AxRouter {
29+
export class AxSimpleClassifier {
3030
private readonly ai: AxAIService
3131

3232
private db: AxDBMemory
@@ -45,25 +45,27 @@ export class AxRouter {
4545
this.db.setDB(state)
4646
}
4747

48-
public setRoutes = async (routes: readonly AxRoute[]): Promise<void> => {
49-
for (const ro of routes) {
50-
const ret = await this.ai.embed({ texts: ro.getContext() })
48+
public setClasses = async (
49+
classes: readonly AxSimpleClassifierClass[]
50+
): Promise<void> => {
51+
for (const c of classes) {
52+
const ret = await this.ai.embed({ texts: c.getContext() })
5153
await this.db.upsert({
52-
id: ro.getName(),
53-
table: 'routes',
54+
id: c.getName(),
55+
table: 'classes',
5456
values: ret.embeddings[0],
5557
})
5658
}
5759
}
5860

5961
public async forward(
6062
text: string,
61-
options?: Readonly<AxRouterForwardOptions>
63+
options?: Readonly<AxSimpleClassifierForwardOptions>
6264
): Promise<string> {
6365
const { embeddings } = await this.ai.embed({ texts: [text] })
6466

6567
const matches = await this.db.query({
66-
table: 'routes',
68+
table: 'classes',
6769
values: embeddings[0],
6870
})
6971

@@ -83,12 +85,12 @@ export class AxRouter {
8385
)
8486
}
8587

86-
const route = m.at(0)
87-
if (!route) {
88+
const matchedClass = m.at(0)
89+
if (!matchedClass) {
8890
return ''
8991
}
9092

91-
return route.id
93+
return matchedClass.id
9294
}
9395

9496
public setOptions(options: Readonly<{ debug?: boolean }>): void {

src/ax/index.ts

+24-10
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,10 @@ import {
254254
AxLLMRequestTypeValues,
255255
AxSpanKindValues
256256
} from './trace/trace.js';
257+
import {
258+
AxMockAIService,
259+
type AxMockAIServiceConfig
260+
} from './ai/mock/api.js';
257261
import {
258262
AxProgram,
259263
AxProgramWithSignature,
@@ -280,21 +284,22 @@ import {
280284
AxRateLimiterTokenUsage,
281285
type AxRateLimiterTokenUsageOptions
282286
} from './util/rate-limit.js';
283-
import {
284-
AxRoute,
285-
AxRouter,
286-
type AxRouterForwardOptions
287-
} from './dsp/router.js';
288287
import {
289288
AxSignature,
290289
type AxField,
291290
type AxIField
292291
} from './dsp/sig.js';
292+
import {
293+
AxSimpleClassifier,
294+
AxSimpleClassifierClass,
295+
type AxSimpleClassifierForwardOptions
296+
} from './dsp/router.js';
293297
import {
294298
AxTestPrompt,
295299
type AxEvaluateArgs
296300
} from './dsp/evaluate.js';
297301
import {
302+
type AxAIInputModelList,
298303
type AxAIModelList,
299304
type AxAIPromptConfig,
300305
type AxAIService,
@@ -326,6 +331,11 @@ import {
326331
type AxDBUpsertRequest,
327332
type AxDBUpsertResponse
328333
} from './db/types.js';
334+
import {
335+
type AxFieldProcessor,
336+
type AxFieldProcessorProcess,
337+
type AxStreamingFieldProcessorProcess
338+
} from './dsp/fieldProcessor.js';
329339
import {AxAIDeepSeekModel} from './ai/deepseek/types.js';
330340
import {AxAIGroqModel} from './ai/groq/types.js';
331341
import {AxChainOfThought} from './prompts/cot.js';
@@ -334,10 +344,9 @@ import {AxDefaultResultReranker} from './docs/reranker.js';
334344
import {AxEmbeddingAdapter} from './funcs/embed.js';
335345
import {AxInstanceRegistry} from './dsp/registry.js';
336346
import {AxMemory} from './mem/memory.js';
337-
import {AxMockAIService} from './ai/mock/api.js';
347+
import {AxMultiServiceRouter} from './ai/multiservice.js';
338348
import {AxRAG} from './prompts/rag.js';
339349
import {type AxAIMemory} from './mem/types.js';
340-
import {type AxFieldProcessor} from './dsp/fieldProcessor.js';
341350

342351
// Value exports
343352
export { AxAI };
@@ -405,14 +414,15 @@ export { AxJSInterpreterPermission };
405414
export { AxLLMRequestTypeValues };
406415
export { AxMemory };
407416
export { AxMockAIService };
417+
export { AxMultiServiceRouter };
408418
export { AxProgram };
409419
export { AxProgramWithSignature };
410420
export { AxPromptTemplate };
411421
export { AxRAG };
412422
export { AxRateLimiterTokenUsage };
413-
export { AxRoute };
414-
export { AxRouter };
415423
export { AxSignature };
424+
export { AxSimpleClassifier };
425+
export { AxSimpleClassifierClass };
416426
export { AxSpanKindValues };
417427
export { AxTestPrompt };
418428

@@ -469,6 +479,7 @@ export type { AxAIHuggingFaceArgs };
469479
export type { AxAIHuggingFaceConfig };
470480
export type { AxAIHuggingFaceRequest };
471481
export type { AxAIHuggingFaceResponse };
482+
export type { AxAIInputModelList };
472483
export type { AxAIMemory };
473484
export type { AxAIMistralArgs };
474485
export type { AxAIModelList };
@@ -542,6 +553,7 @@ export type { AxEvaluateArgs };
542553
export type { AxExample };
543554
export type { AxField };
544555
export type { AxFieldProcessor };
556+
export type { AxFieldProcessorProcess };
545557
export type { AxFieldTemplateFn };
546558
export type { AxFieldValue };
547559
export type { AxFunction };
@@ -559,6 +571,7 @@ export type { AxInternalChatRequest };
559571
export type { AxInternalEmbedRequest };
560572
export type { AxMetricFn };
561573
export type { AxMetricFnArgs };
574+
export type { AxMockAIServiceConfig };
562575
export type { AxModelConfig };
563576
export type { AxModelInfo };
564577
export type { AxModelInfoWithProvider };
@@ -577,9 +590,10 @@ export type { AxRerankerOut };
577590
export type { AxResponseHandlerArgs };
578591
export type { AxRewriteIn };
579592
export type { AxRewriteOut };
580-
export type { AxRouterForwardOptions };
593+
export type { AxSimpleClassifierForwardOptions };
581594
export type { AxStreamingAssertion };
582595
export type { AxStreamingEvent };
596+
export type { AxStreamingFieldProcessorProcess };
583597
export type { AxTokenUsage };
584598
export type { AxTunable };
585599
export type { AxUsable };

src/examples/routing.ts src/examples/simple-classify.ts

+12-12
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,30 @@
1-
import { AxAI, AxRoute, AxRouter } from '@ax-llm/ax'
1+
import { AxAI, AxSimpleClassifier, AxSimpleClassifierClass } from '@ax-llm/ax'
22

3-
const customerSupport = new AxRoute('customerSupport', [
3+
const customerSupport = new AxSimpleClassifierClass('customerSupport', [
44
'how can I return a product?',
55
'where is my order?',
66
'can you help me with a refund?',
77
'I need to update my shipping address',
88
'my product arrived damaged, what should I do?',
99
])
1010

11-
const employeeHR = new AxRoute('employeeHR', [
11+
const employeeHR = new AxSimpleClassifierClass('employeeHR', [
1212
'how do I request time off?',
1313
'where can I find the employee handbook?',
1414
'who do I contact for IT support?',
1515
'I have a question about my benefits',
1616
'how do I log my work hours?',
1717
])
1818

19-
const salesInquiries = new AxRoute('salesInquiries', [
19+
const salesInquiries = new AxSimpleClassifierClass('salesInquiries', [
2020
'I want to buy your products',
2121
'can you provide a quote?',
2222
'what are the payment options?',
2323
'how do I get a discount?',
2424
'who can I speak with for a bulk order?',
2525
])
2626

27-
const technicalSupport = new AxRoute('technicalSupport', [
27+
const technicalSupport = new AxSimpleClassifierClass('technicalSupport', [
2828
'how do I install your software?',
2929
'I’m having trouble logging in',
3030
'can you help me configure my settings?',
@@ -37,20 +37,20 @@ const ai = new AxAI({
3737
apiKey: process.env.OPENAI_APIKEY as string,
3838
})
3939

40-
const router = new AxRouter(ai)
40+
const classifier = new AxSimpleClassifier(ai)
4141

42-
await router.setRoutes([
42+
await classifier.setClasses([
4343
customerSupport,
4444
employeeHR,
4545
salesInquiries,
4646
technicalSupport,
4747
])
4848

49-
const r1 = await router.forward('I need help with my order')
50-
const r2 = await router.forward('I want to know more about the company')
51-
const r3 = await router.forward('I need help installing your software')
52-
const r4 = await router.forward('I did not receive my order on time')
53-
const r5 = await router.forward('Where can I find info about our 401k')
49+
const r1 = await classifier.forward('I need help with my order')
50+
const r2 = await classifier.forward('I want to know more about the company')
51+
const r3 = await classifier.forward('I need help installing your software')
52+
const r4 = await classifier.forward('I did not receive my order on time')
53+
const r5 = await classifier.forward('Where can I find info about our 401k')
5454

5555
console.log(r1 === 'salesInquiries' ? 'PASS' : 'FAIL: ' + r1)
5656
console.log(r2 === 'salesInquiries' ? 'PASS' : 'FAIL: ' + r2)

0 commit comments

Comments
 (0)