diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts
index 9f61eea3c7..9e4f94ebf7 100755
--- a/js/ai/src/generate.ts
+++ b/js/ai/src/generate.ts
@@ -21,7 +21,6 @@ import {
   runWithStreamingCallback,
   StreamingCallback,
 } from '@genkit-ai/core';
-import { lookupAction } from '@genkit-ai/core/registry';
 import { toJsonSchema, validateSchema } from '@genkit-ai/core/schema';
 import { z } from 'zod';
 import { DocumentData } from './document.js';
@@ -36,8 +35,8 @@ import {
   MessageData,
   ModelAction,
   ModelArgument,
-  ModelReference,
   Part,
+  resolveModel as resolveModelFromRegistry,
   Role,
   ToolRequestPart,
   ToolResponsePart,
@@ -538,14 +537,7 @@ async function resolveModel(options: GenerateOptions): Promise<ModelAction> {
       throw new Error('Unable to resolve model.');
     }
   }
-  if (typeof model === 'string') {
-    return (await lookupAction(`/model/${model}`)) as ModelAction;
-  } else if (model.hasOwnProperty('info')) {
-    const ref = model as ModelReference<any>;
-    return (await lookupAction(`/model/${ref.name}`)) as ModelAction;
-  } else {
-    return model as ModelAction;
-  }
+  return resolveModelFromRegistry(model);
 }
 
 export class NoValidCandidatesError extends GenkitError {
diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts
index 45cf8be059..09e263ea10 100644
--- a/js/ai/src/model.ts
+++ b/js/ai/src/model.ts
@@ -16,11 +16,13 @@
 
 import {
   Action,
+  actionWithMiddleware,
   defineAction,
   getStreamingCallback,
   Middleware,
   StreamingCallback,
 } from '@genkit-ai/core';
+import { lookupAction } from '@genkit-ai/core/registry';
 import { toJsonSchema } from '@genkit-ai/core/schema';
 import { performance } from 'node:perf_hooks';
 import { z } from 'zod';
@@ -327,6 +329,84 @@ export function defineModel<
   return act as ModelAction<CustomOptionsSchema>;
 }
 
+/**
+ * Defines a model that wraps another model with the given middleware.
+ *
+ * Model info is inherited from the wrapped model (when it is a reference or an
+ * action) and shallowly overridden by `info`. `configSchema` defaults to the
+ * wrapped model's schema. The wrapped model is resolved and the middleware is
+ * attached on every request rather than at definition time.
+ *
+ * @returns the model action defined under `name`
+ * @throws Error at request time if the wrapped model cannot be resolved
+ */
+export function defineWrappedModel<
+  CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny,
+>({
+  name,
+  model,
+  info,
+  configSchema,
+  use,
+}: {
+  name: string;
+  info?: ModelInfo;
+  model: ModelArgument<CustomOptionsSchema>;
+  configSchema?: CustomOptionsSchema;
+  use: ModelMiddleware[];
+}): ModelAction<CustomOptionsSchema> {
+  // A plain string `model` has neither an own `info` nor an `__action`, so it
+  // falls through to the second branch and leaves both values undefined.
+  let originalInfo: ModelInfo | undefined;
+  let originalConfigSchema: CustomOptionsSchema | undefined;
+  if (model.hasOwnProperty('info')) {
+    const ref = model as ModelReference<CustomOptionsSchema>;
+    originalInfo = ref.info;
+    originalConfigSchema = ref.configSchema;
+  } else {
+    const ma = model as ModelAction<CustomOptionsSchema>;
+    originalInfo = ma.__action?.metadata?.model as ModelInfo;
+    originalConfigSchema = ma.__configSchema as CustomOptionsSchema;
+  }
+
+  return defineModel(
+    {
+      ...originalInfo,
+      ...info,
+      configSchema: configSchema ?? originalConfigSchema,
+      name,
+    },
+    async (request) => {
+      const resolvedModel = await resolveModel(model);
+      if (!resolvedModel) {
+        throw new Error(`Unable to resolve model wrapped by '${name}'.`);
+      }
+      const wrapped = actionWithMiddleware(resolvedModel, use);
+      return wrapped(request);
+    }
+  ) as ModelAction<CustomOptionsSchema>;
+}
+
+/**
+ * Resolves a model argument to a model action.
+ *
+ * A string, or an object with its own `info` property (a reference), is looked
+ * up in the registry under `/model/<name>`; any other value is returned as-is.
+ * The lookup result is cast without validation.
+ */
+export async function resolveModel<C extends z.ZodTypeAny = z.ZodTypeAny>(
+  model: ModelArgument<C>
+): Promise<ModelAction<C>> {
+  if (typeof model === 'string') {
+    return (await lookupAction(`/model/${model}`)) as ModelAction<C>;
+  } else if (model.hasOwnProperty('info')) {
+    const ref = model as ModelReference<C>;
+    return (await lookupAction(`/model/${ref.name}`)) as ModelAction<C>;
+  } else {
+    return model as ModelAction<C>;
+  }
+}
+
 export interface ModelReference<CustomOptions extends z.ZodTypeAny> {
   name: string;
   configSchema?: CustomOptions;
diff --git a/js/ai/tests/model/model_test.ts b/js/ai/tests/model/model_test.ts
new file mode 100644
index 0000000000..c35f09d768
--- /dev/null
+++ b/js/ai/tests/model/model_test.ts
@@ -0,0 +1,155 @@
+/**
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import { __hardResetRegistryForTesting } from '@genkit-ai/core/registry';
+import assert from 'node:assert';
+import { beforeEach, describe, it } from 'node:test';
+import { z } from 'zod';
+import { generate } from '../../src/generate.js';
+import {
+  defineModel,
+  defineWrappedModel,
+  ModelAction,
+  ModelMiddleware,
+} from '../../src/model.js';
+
+const wrapRequest: ModelMiddleware = async (req, next) => {
+  return next({
+    ...req,
+    messages: [
+      {
+        role: 'user',
+        content: [
+          {
+            text:
+              '(' +
+              req.messages
+                .map((m) => m.content.map((c) => c.text).join())
+                .join() +
+              ')',
+          },
+        ],
+      },
+    ],
+  });
+};
+const wrapResponse: ModelMiddleware = async (req, next) => {
+  const res = await next(req);
+  return {
+    candidates: [
+      {
+        index: 0,
+        finishReason: 'stop',
+        message: {
+          role: 'model',
+          content: [
+            {
+              text:
+                '[' +
+                res.candidates[0].message.content.map((c) => c.text).join() +
+                ']',
+            },
+          ],
+        },
+      },
+    ],
+  };
+};
+
+describe('defineWrappedModel', () => {
+  beforeEach(__hardResetRegistryForTesting);
+
+  let echoModel: ModelAction;
+  let wrappedEchoModel: ModelAction;
+
+  beforeEach(() => {
+    echoModel = defineModel(
+      {
+        name: 'echoModel',
+        label: 'echo-echo-echo-echo-echo',
+        supports: {
+          multiturn: true,
+        },
+        configSchema: z.object({
+          customField: z.string(),
+        }),
+      },
+      async (request) => {
+        return {
+          candidates: [
+            {
+              index: 0,
+              finishReason: 'stop',
+              message: {
+                role: 'model',
+                content: [
+                  {
+                    text:
+                      'Echo: ' +
+                      request.messages
+                        .map((m) => m.content.map((c) => c.text).join())
+                        .join(),
+                  },
+                ],
+              },
+            },
+          ],
+        };
+      }
+    );
+    wrappedEchoModel = defineWrappedModel({
+      name: 'wrappedModel',
+      model: echoModel,
+      info: {
+        label: 'Wrapped Echo',
+      },
+      use: [wrapRequest, wrapResponse],
+    });
+  });
+
+  it('copies/overwrites metadata', async () => {
+    assert.deepStrictEqual(wrappedEchoModel.__action.metadata, {
+      model: {
+        customOptions: {
+          $schema: 'http://json-schema.org/draft-07/schema#',
+          additionalProperties: true,
+          properties: {
+            customField: {
+              type: 'string',
+            },
+          },
+          required: ['customField'],
+          type: 'object',
+        },
+        label: 'Wrapped Echo',
+        versions: undefined,
+        supports: {
+          multiturn: true,
+        },
+      },
+    });
+  });
+
+  it('applies middleware', async () => {
+    const response = await generate({
+      prompt: 'banana',
+      model: wrappedEchoModel,
+    });
+
+    const want = '[Echo: (banana)]';
+    assert.deepStrictEqual(response.text(), want);
+  });
+});
diff --git a/js/testapps/flow-simple-ai/src/index.ts b/js/testapps/flow-simple-ai/src/index.ts
index 71f0bcb2f8..29595d8c65 100644
--- a/js/testapps/flow-simple-ai/src/index.ts
+++ b/js/testapps/flow-simple-ai/src/index.ts
@@ -15,6 +15,7 @@
  */
 
 import { generate, generateStream, retrieve } from '@genkit-ai/ai';
+import { ModelMiddleware, defineWrappedModel } from '@genkit-ai/ai/model';
 import { defineTool } from '@genkit-ai/ai/tool';
 import { configureGenkit } from '@genkit-ai/core';
 import { dotprompt, prompt } from '@genkit-ai/dotprompt';
@@ -47,7 +48,6 @@ configureGenkit({
           '@opentelemetry/instrumentation-dns': { enabled: false },
           '@opentelemetry/instrumentation-net': { enabled: false },
         },
-        metricExportIntervalMillis: 5_000,
       },
     }),
     dotprompt(),
@@ -331,3 +331,55 @@ export const dotpromptContext = defineFlow(
     return result.output() as any;
   }
 );
+
+const wrapRequest: ModelMiddleware = async (req, next) => {
+  return next({
+    ...req,
+    messages: [
+      {
+        role: 'user',
+        content: [
+          {
+            text:
+              '(' +
+              req.messages
+                .map((m) => m.content.map((c) => c.text).join())
+                .join() +
+              ')',
+          },
+        ],
+      },
+    ],
+  });
+};
+const wrapResponse: ModelMiddleware = async (req, next) => {
+  const res = await next(req);
+  return {
+    candidates: [
+      {
+        index: 0,
+        finishReason: 'stop',
+        message: {
+          role: 'model',
+          content: [
+            {
+              text:
+                '[' +
+                res.candidates[0].message.content.map((c) => c.text).join() +
+                ']',
+            },
+          ],
+        },
+      },
+    ],
+  };
+};
+
+defineWrappedModel({
+  name: 'wrappedGemini15Flash',
+  model: geminiPro,
+  info: {
+    label: 'wrappedGemini15Flash',
+  },
+  use: [wrapRequest, wrapResponse],
+});