diff --git a/packages/ai/src/evals/eval.ts b/packages/ai/src/evals/eval.ts index bf253472..15647703 100644 --- a/packages/ai/src/evals/eval.ts +++ b/packages/ai/src/evals/eval.ts @@ -1,4 +1,5 @@ import { afterAll, beforeAll, describe, inject, it } from 'vitest'; +import type { RunnerTestFile, RunnerTestSuite } from 'vitest'; import { context, SpanStatusCode, trace, type Context } from '@opentelemetry/api'; import { customAlphabet } from 'nanoid'; import { withEvalContext, getEvalContext, getConfigScope } from './context/storage'; @@ -57,8 +58,27 @@ type RunTaskFailureDetails = { overrides?: Record; }; +type EvalHookSuite = (RunnerTestSuite | RunnerTestFile) & { + meta: Record; + tasks: Array<{ meta: { case?: EvalCaseReport } }>; +}; + +type CompatibleSuiteHook = { + (suite: RunnerTestSuite | RunnerTestFile): Promise | void; + (context: unknown, suite: RunnerTestSuite | RunnerTestFile): Promise | void; +}; + const RUN_TASK_FAILURE_DETAILS = Symbol.for('axiom.eval.runTaskFailureDetails'); +function withCompatibleSuiteHook( + fn: (suite: EvalHookSuite) => Promise | void, +): CompatibleSuiteHook { + return async function ({}: any, maybeSuite?: RunnerTestSuite | RunnerTestFile): Promise { + const suite = (maybeSuite ?? arguments[0]) as EvalHookSuite; + await fn(suite); + } as CompatibleSuiteHook; +} + function attachRunTaskFailureDetails( error: unknown, details: RunTaskFailureDetails, @@ -238,7 +258,7 @@ async function registerEval< | { flags: Record; pickedFlags?: string[]; overrides?: Record } | undefined; - beforeAll(async (suite) => { + const handleBeforeAll = async (suite: EvalHookSuite) => { // Ensure worker process knows CLI overrides if (injectedOverrides && Object.keys(injectedOverrides).length > 0) { try { @@ -353,9 +373,9 @@ async function registerEval< }; suiteStart = performance.now(); - }); + }; - afterAll(async (suite) => { + const handleAfterAll = async (suite: EvalHookSuite) => { if (instrumentationError) { throw instrumentationError; } @@ -417,10 +437,10 @@ async function registerEval< const durationMs = Math.round(performance.now() - suiteStart); const successCases = suite.tasks.filter( - (task) => task.meta.case.status === 'success', + (task) => task.meta.case?.status === 'success', ).length; const erroredCases = suite.tasks.filter( - (task) => task.meta.case.status === 'fail' || task.meta.case.status === 'pending', + (task) => task.meta.case?.status === 'fail' || task.meta.case?.status === 'pending', ).length; // signal Axiom that evaluation finished to kick of summary calculations @@ -434,7 +454,10 @@ async function registerEval< durationMs, }); } - }); + }; + + beforeAll(withCompatibleSuiteHook(handleBeforeAll)); + afterAll(withCompatibleSuiteHook(handleAfterAll)); type CollectionRecordWithIndex = { index: number } & CollectionRecord;