diff --git a/packages/orm/src/client/client-impl.ts b/packages/orm/src/client/client-impl.ts index 4af8b031..719f90f3 100644 --- a/packages/orm/src/client/client-impl.ts +++ b/packages/orm/src/client/client-impl.ts @@ -37,7 +37,7 @@ import { ZenStackQueryExecutor } from './executor/zenstack-query-executor'; import * as BuiltinFunctions from './functions'; import { SchemaDbPusher } from './helpers/schema-db-pusher'; import type { ClientOptions, ProceduresOptions } from './options'; -import type { RuntimePlugin } from './plugin'; +import type { AnyPlugin } from './plugin'; import { createZenStackPromise, type ZenStackPromise } from './promise'; import { ResultProcessor } from './result-processor'; @@ -293,8 +293,8 @@ export class ClientImpl { await new SchemaDbPusher(this.schema, this.kysely).push(); } - $use(plugin: RuntimePlugin) { - const newPlugins: RuntimePlugin[] = [...(this.$options.plugins ?? []), plugin]; + $use(plugin: AnyPlugin) { + const newPlugins: AnyPlugin[] = [...(this.$options.plugins ?? []), plugin]; const newOptions: ClientOptions = { ...this.options, plugins: newPlugins, @@ -308,7 +308,7 @@ export class ClientImpl { $unuse(pluginId: string) { // tsc perf - const newPlugins: RuntimePlugin[] = []; + const newPlugins: AnyPlugin[] = []; for (const plugin of this.options.plugins ?? []) { if (plugin.id !== pluginId) { newPlugins.push(plugin); @@ -329,7 +329,7 @@ export class ClientImpl { // tsc perf const newOptions: ClientOptions = { ...this.options, - plugins: [] as RuntimePlugin[], + plugins: [] as AnyPlugin[], }; const newClient = new ClientImpl(this.schema, newOptions, this); // create a new validator to have a fresh schema cache, because plugins may @@ -408,6 +408,16 @@ function createClientProxy(client: ClientImpl): ClientImpl { return new Proxy(client, { get: (target, prop, receiver) => { if (typeof prop === 'string' && prop.startsWith('$')) { + // Check for plugin-provided members (search in reverse order so later plugins win) + const plugins = target.$options.plugins ?? []; + for (let i = plugins.length - 1; i >= 0; i--) { + const plugin = plugins[i]; + const clientMembers = plugin?.client as Record | undefined; + if (clientMembers && prop in clientMembers) { + return clientMembers[prop]; + } + } + // Fall through to built-in $ methods return Reflect.get(target, prop, receiver); } diff --git a/packages/orm/src/client/contract.ts b/packages/orm/src/client/contract.ts index 0006cd2a..945f3645 100644 --- a/packages/orm/src/client/contract.ts +++ b/packages/orm/src/client/contract.ts @@ -40,9 +40,15 @@ import type { UpdateManyArgs, UpsertArgs, } from './crud-types'; -import type { CoreCrudOperations } from './crud/operations/base'; +import type { + CoreCreateOperations, + CoreCrudOperations, + CoreDeleteOperations, + CoreReadOperations, + CoreUpdateOperations, +} from './crud/operations/base'; import type { ClientOptions, QueryOptions, ToQueryOptions } from './options'; -import type { ExtQueryArgsBase, RuntimePlugin } from './plugin'; +import type { ExtClientMembersBase, ExtQueryArgsBase, RuntimePlugin } from './plugin'; import type { ZenStackPromise } from './promise'; import type { ToKysely } from './query-builder'; @@ -51,11 +57,26 @@ type TransactionUnsupportedMethods = (typeof TRANSACTION_UNSUPPORTED_METHODS)[nu /** * Extracts extended query args for a specific operation. */ -type ExtractExtQueryArgs = Operation extends keyof ExtQueryArgs - ? NonNullable - : 'all' extends keyof ExtQueryArgs - ? NonNullable - : {}; +type ExtractExtQueryArgs = (Operation extends keyof ExtQueryArgs + ? ExtQueryArgs[Operation] + : {}) & + ('$create' extends keyof ExtQueryArgs + ? Operation extends CoreCreateOperations + ? ExtQueryArgs['$create'] + : {} + : {}) & + ('$read' extends keyof ExtQueryArgs ? (Operation extends CoreReadOperations ? ExtQueryArgs['$read'] : {}) : {}) & + ('$update' extends keyof ExtQueryArgs + ? Operation extends CoreUpdateOperations + ? ExtQueryArgs['$update'] + : {} + : {}) & + ('$delete' extends keyof ExtQueryArgs + ? Operation extends CoreDeleteOperations + ? ExtQueryArgs['$delete'] + : {} + : {}) & + ('$all' extends keyof ExtQueryArgs ? ExtQueryArgs['$all'] : {}); /** * Transaction isolation levels. @@ -75,6 +96,7 @@ export type ClientContract< Schema extends SchemaDef, Options extends ClientOptions = ClientOptions, ExtQueryArgs extends ExtQueryArgsBase = {}, + ExtClientMembers extends ExtClientMembersBase = {}, > = { /** * The schema definition. @@ -132,7 +154,7 @@ export type ClientContract< /** * Sets the current user identity. */ - $setAuth(auth: AuthType | undefined): ClientContract; + $setAuth(auth: AuthType | undefined): ClientContract; /** * Returns a new client with new options applied. @@ -141,7 +163,9 @@ export type ClientContract< * const dbNoValidation = db.$setOptions({ ...db.$options, validateInput: false }); * ``` */ - $setOptions>(options: Options): ClientContract; + $setOptions>( + options: NewOptions, + ): ClientContract; /** * Returns a new client enabling/disabling input validations expressed with attributes like @@ -149,7 +173,7 @@ export type ClientContract< * * @deprecated Use {@link $setOptions} instead. */ - $setInputValidation(enable: boolean): ClientContract; + $setInputValidation(enable: boolean): ClientContract; /** * The Kysely query builder instance. @@ -165,7 +189,7 @@ export type ClientContract< * Starts an interactive transaction. */ $transaction( - callback: (tx: TransactionClientContract) => Promise, + callback: (tx: TransactionClientContract) => Promise, options?: { isolationLevel?: TransactionIsolationLevel }, ): Promise; @@ -180,14 +204,18 @@ export type ClientContract< /** * Returns a new client with the specified plugin installed. */ - $use( - plugin: RuntimePlugin, - ): ClientContract; + $use< + PluginSchema extends SchemaDef = Schema, + PluginExtQueryArgs extends ExtQueryArgsBase = {}, + PluginExtClientMembers extends ExtClientMembersBase = {}, + >( + plugin: RuntimePlugin, + ): ClientContract; /** * Returns a new client with the specified plugin removed. */ - $unuse(pluginId: string): ClientContract; + $unuse(pluginId: string): ClientContract; /** * Returns a new client with all plugins removed. @@ -216,7 +244,8 @@ export type ClientContract< ToQueryOptions, ExtQueryArgs >; -} & ProcedureOperations; +} & ProcedureOperations & + ExtClientMembers; /** * The contract for a client in a transaction. @@ -225,7 +254,8 @@ export type TransactionClientContract< Schema extends SchemaDef, Options extends ClientOptions, ExtQueryArgs extends ExtQueryArgsBase, -> = Omit, TransactionUnsupportedMethods>; + ExtClientMembers extends ExtClientMembersBase, +> = Omit, TransactionUnsupportedMethods>; export type ProcedureOperations = Schema['procedures'] extends Record diff --git a/packages/orm/src/client/crud/operations/base.ts b/packages/orm/src/client/crud/operations/base.ts index c9c85121..5eb4b2d1 100644 --- a/packages/orm/src/client/crud/operations/base.ts +++ b/packages/orm/src/client/crud/operations/base.ts @@ -119,6 +119,36 @@ export const CoreWriteOperations = [ */ export type CoreWriteOperations = (typeof CoreWriteOperations)[number]; +/** + * List of core create operations. + */ +export const CoreCreateOperations = ['create', 'createMany', 'createManyAndReturn', 'upsert'] as const; + +/** + * List of core create operations. + */ +export type CoreCreateOperations = (typeof CoreCreateOperations)[number]; + +/** + * List of core update operations. + */ +export const CoreUpdateOperations = ['update', 'updateMany', 'updateManyAndReturn', 'upsert'] as const; + +/** + * List of core update operations. + */ +export type CoreUpdateOperations = (typeof CoreUpdateOperations)[number]; + +/** + * List of core delete operations. + */ +export const CoreDeleteOperations = ['delete', 'deleteMany'] as const; + +/** + * List of core delete operations. + */ +export type CoreDeleteOperations = (typeof CoreDeleteOperations)[number]; + /** * List of all CRUD operations, including 'orThrow' variants. */ diff --git a/packages/orm/src/client/crud/validator/index.ts b/packages/orm/src/client/crud/validator/index.ts index 3556c604..9390c69c 100644 --- a/packages/orm/src/client/crud/validator/index.ts +++ b/packages/orm/src/client/crud/validator/index.ts @@ -35,6 +35,7 @@ import { type UpsertArgs, } from '../../crud-types'; import { createInternalError, createInvalidInputError } from '../../errors'; +import type { AnyPlugin } from '../../plugin'; import { fieldHasDefaultValue, getDiscriminatorField, @@ -46,7 +47,13 @@ import { requireField, requireModel, } from '../../query-utils'; -import type { CoreCrudOperations } from '../operations/base'; +import { + CoreCreateOperations, + CoreDeleteOperations, + CoreReadOperations, + CoreUpdateOperations, + type CoreCrudOperations, +} from '../operations/base'; import { addBigIntValidation, addCustomValidation, @@ -365,8 +372,8 @@ export class InputValidator { private mergePluginArgsSchema(schema: ZodObject, operation: CoreCrudOperations) { let result = schema; for (const plugin of this.options.plugins ?? []) { - if (plugin.extQueryArgs) { - const pluginSchema = plugin.extQueryArgs.getValidationSchema(operation); + if (plugin.queryArgs) { + const pluginSchema = this.getPluginExtQueryArgsSchema(plugin, operation); if (pluginSchema) { result = result.extend(pluginSchema.shape); } @@ -375,6 +382,77 @@ export class InputValidator { return result.strict(); } + private getPluginExtQueryArgsSchema(plugin: AnyPlugin, operation: string): ZodObject | undefined { + if (!plugin.queryArgs) { + return undefined; + } + + let result: ZodType | undefined; + + if (operation in plugin.queryArgs && plugin.queryArgs[operation]) { + // most specific operation takes highest precedence + result = plugin.queryArgs[operation]; + } else if (operation === 'upsert') { + // upsert is special: it's in both CoreCreateOperations and CoreUpdateOperations + // so we need to merge both $create and $update schemas to match the type system + const createSchema = + '$create' in plugin.queryArgs && plugin.queryArgs['$create'] ? plugin.queryArgs['$create'] : undefined; + const updateSchema = + '$update' in plugin.queryArgs && plugin.queryArgs['$update'] ? plugin.queryArgs['$update'] : undefined; + + if (createSchema && updateSchema) { + invariant( + createSchema instanceof z.ZodObject, + 'Plugin extended query args schema must be a Zod object', + ); + invariant( + updateSchema instanceof z.ZodObject, + 'Plugin extended query args schema must be a Zod object', + ); + // merge both schemas (combines their properties) + result = createSchema.extend(updateSchema.shape); + } else if (createSchema) { + result = createSchema; + } else if (updateSchema) { + result = updateSchema; + } + } else if ( + // then comes grouped operations: $create, $read, $update, $delete + CoreCreateOperations.includes(operation as CoreCreateOperations) && + '$create' in plugin.queryArgs && + plugin.queryArgs['$create'] + ) { + result = plugin.queryArgs['$create']; + } else if ( + CoreReadOperations.includes(operation as CoreReadOperations) && + '$read' in plugin.queryArgs && + plugin.queryArgs['$read'] + ) { + result = plugin.queryArgs['$read']; + } else if ( + CoreUpdateOperations.includes(operation as CoreUpdateOperations) && + '$update' in plugin.queryArgs && + plugin.queryArgs['$update'] + ) { + result = plugin.queryArgs['$update']; + } else if ( + CoreDeleteOperations.includes(operation as CoreDeleteOperations) && + '$delete' in plugin.queryArgs && + plugin.queryArgs['$delete'] + ) { + result = plugin.queryArgs['$delete']; + } else if ('$all' in plugin.queryArgs && plugin.queryArgs['$all']) { + // finally comes $all + result = plugin.queryArgs['$all']; + } + + invariant( + result === undefined || result instanceof z.ZodObject, + 'Plugin extended query args schema must be a Zod object', + ); + return result; + } + // #region Find private makeFindSchema(model: string, operation: CoreCrudOperations) { diff --git a/packages/orm/src/client/index.ts b/packages/orm/src/client/index.ts index bf17a9e6..00bbf1b6 100644 --- a/packages/orm/src/client/index.ts +++ b/packages/orm/src/client/index.ts @@ -6,8 +6,11 @@ export { BaseCrudDialect } from './crud/dialects/base-dialect'; export { AllCrudOperations, AllReadOperations, + CoreCreateOperations, CoreCrudOperations, + CoreDeleteOperations, CoreReadOperations, + CoreUpdateOperations, CoreWriteOperations, } from './crud/operations/base'; export { InputValidator } from './crud/validator'; diff --git a/packages/orm/src/client/options.ts b/packages/orm/src/client/options.ts index d1fa23ed..6439e399 100644 --- a/packages/orm/src/client/options.ts +++ b/packages/orm/src/client/options.ts @@ -4,7 +4,7 @@ import type { PrependParameter } from '../utils/type-utils'; import type { ClientContract, CRUD_EXT } from './contract'; import type { GetProcedureNames, ProcedureHandlerFunc } from './crud-types'; import type { BaseCrudDialect } from './crud/dialects/base-dialect'; -import type { RuntimePlugin } from './plugin'; +import type { AnyPlugin } from './plugin'; import type { ToKyselySchema } from './query-builder'; export type ZModelFunctionContext = { @@ -59,7 +59,7 @@ export type ClientOptions = { /** * Plugins. */ - plugins?: RuntimePlugin[]; + plugins?: AnyPlugin[]; /** * Logging configuration. diff --git a/packages/orm/src/client/plugin.ts b/packages/orm/src/client/plugin.ts index ee024f1e..81dff0ec 100644 --- a/packages/orm/src/client/plugin.ts +++ b/packages/orm/src/client/plugin.ts @@ -1,19 +1,33 @@ import type { OperationNode, QueryId, QueryResult, RootOperationNode, UnknownRow } from 'kysely'; -import type { ZodObject } from 'zod'; +import type { ZodType } from 'zod'; import type { ClientContract, ZModelFunction } from '.'; import type { GetModels, SchemaDef } from '../schema'; import type { MaybePromise } from '../utils/type-utils'; import type { AllCrudOperations, CoreCrudOperations } from './crud/operations/base'; +type AllowedExtQueryArgKeys = CoreCrudOperations | '$create' | '$read' | '$update' | '$delete' | '$all'; + /** * Base shape of plugin-extended query args. */ -export type ExtQueryArgsBase = { [K in CoreCrudOperations | 'all']?: object }; +export type ExtQueryArgsBase = { + [K in AllowedExtQueryArgKeys]?: object; +}; + +/** + * Base type for plugin-extended client members (methods and properties). + * Member names should start with '$' to avoid model name conflicts. + */ +export type ExtClientMembersBase = Record; /** * ZenStack runtime plugin. */ -export interface RuntimePlugin { +export interface RuntimePlugin< + Schema extends SchemaDef, + ExtQueryArgs extends ExtQueryArgsBase, + ExtClientMembers extends Record, +> { /** * Plugin ID. */ @@ -59,19 +73,26 @@ export interface RuntimePlugin ZodObject | undefined; + queryArgs?: { + [K in keyof ExtQueryArgs]: ZodType; }; + + /** + * Extended client members (methods and properties). + */ + client?: ExtClientMembers; } + +export type AnyPlugin = RuntimePlugin; + /** * Defines a ZenStack runtime plugin. */ -export function definePlugin( - plugin: RuntimePlugin, -): RuntimePlugin { +export function definePlugin< + Schema extends SchemaDef, + const ExtQueryArgs extends ExtQueryArgsBase = {}, + const ExtClientMembers extends Record = {}, +>(plugin: RuntimePlugin): RuntimePlugin { return plugin; } diff --git a/packages/plugins/policy/src/plugin.ts b/packages/plugins/policy/src/plugin.ts index e27b0cf5..3f67309f 100644 --- a/packages/plugins/policy/src/plugin.ts +++ b/packages/plugins/policy/src/plugin.ts @@ -3,7 +3,7 @@ import type { SchemaDef } from '@zenstackhq/orm/schema'; import { check } from './functions'; import { PolicyHandler } from './policy-handler'; -export class PolicyPlugin implements RuntimePlugin { +export class PolicyPlugin implements RuntimePlugin { get id() { return 'policy' as const; } diff --git a/tests/e2e/orm/plugin-infra/client-members.test.ts b/tests/e2e/orm/plugin-infra/client-members.test.ts new file mode 100644 index 00000000..cf9e6f9b --- /dev/null +++ b/tests/e2e/orm/plugin-infra/client-members.test.ts @@ -0,0 +1,239 @@ +import { definePlugin, type ClientContract } from '@zenstackhq/orm'; +import { createTestClient } from '@zenstackhq/testtools'; +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; +import z from 'zod'; +import { schema } from './ext-query-args/schema'; + +describe('Plugin client members', () => { + let db: ClientContract; + + beforeEach(async () => { + db = await createTestClient(schema); + await db.user.deleteMany(); + }); + + afterEach(async () => { + await db?.$disconnect(); + }); + + it('should allow adding methods and props to client', async () => { + let methodCalled = false; + + const extDb = db.$use( + definePlugin({ + id: 'test-plugin', + client: { + // method + $invalidateCache(model?: string) { + methodCalled = true; + return model ?? 'hello'; + }, + + // dynamic property + get $cacheStats() { + return { hits: 10, misses: 5 }; + }, + + // constant property + $cacheStats1: { + hits: 20, + misses: 10, + }, + }, + }), + ); + + const result = extDb.$invalidateCache(); + expect(result).toBe('hello'); + expect(methodCalled).toBe(true); + + expect(extDb.$invalidateCache('user')).toBe('user'); + + // @ts-expect-error + extDb.$invalidateCache(1); + + expect(extDb.$cacheStats.hits).toBe(10); + expect(extDb.$cacheStats.misses).toBe(5); + + expect(extDb.$cacheStats1.hits).toBe(20); + expect(extDb.$cacheStats1.misses).toBe(10); + }); + + it('should support multiple plugins with different members', async () => { + const plugin1 = definePlugin({ + id: 'plugin1', + client: { + $method1: () => 'from-plugin1', + }, + }); + + const plugin2 = definePlugin({ + id: 'plugin2', + client: { + $method2: () => 'from-plugin2', + }, + }); + + const extDb = db.$use(plugin1).$use(plugin2); + + expect(extDb.$method1()).toBe('from-plugin1'); + expect(extDb.$method2()).toBe('from-plugin2'); + }); + + it('should make later plugin win for conflicting members', async () => { + const plugin1 = definePlugin({ + id: 'plugin1', + client: { + $conflicting: () => 'from-plugin1', + }, + }); + + const plugin2 = definePlugin({ + id: 'plugin2', + client: { + $conflicting: () => 'from-plugin2', + }, + }); + + const extDb = db.$use(plugin1).$use(plugin2); + + // Later plugin wins + expect(extDb.$conflicting()).toBe('from-plugin2'); + }); + + it('should make members available in transactions', async () => { + const extDb = db.$use( + definePlugin({ + id: 'test-plugin', + client: { + $txHelper: () => 'in-transaction', + }, + }), + ); + + await extDb.$transaction(async (tx) => { + expect(tx.$txHelper()).toBe('in-transaction'); + await tx.user.create({ data: { name: 'Bob' } }); + }); + }); + + it('should remove members when plugin is removed via $unuse', async () => { + const extDb = db.$use( + definePlugin({ + id: 'removable-plugin', + client: { + $toBeRemoved: () => 'exists', + }, + }), + ); + + expect(extDb.$toBeRemoved()).toBe('exists'); + + const removedDb = extDb.$unuse('removable-plugin'); + + // After $unuse, the method should not be available + // TypeScript would complain, but at runtime it should be undefined + expect(removedDb.$toBeRemoved).toBeUndefined(); + }); + + it('should remove all members when $unuseAll is called', async () => { + const extDb = db + .$use( + definePlugin({ + id: 'p1', + client: { $m1: () => 'a' }, + }), + ) + .$use( + definePlugin({ + id: 'p2', + client: { $m2: () => 'b' }, + }), + ); + + expect(extDb.$m1()).toBe('a'); + expect(extDb.$m2()).toBe('b'); + + const cleanDb = extDb.$unuseAll(); + + expect((cleanDb as any).$m1).toBeUndefined(); + expect((cleanDb as any).$m2).toBeUndefined(); + }); + + it('should isolate members between client instances', async () => { + const extDb = db.$use( + definePlugin({ + id: 'isolated-plugin', + client: { + $isolated: () => 'only-on-extDb', + }, + }), + ); + + expect(extDb.$isolated()).toBe('only-on-extDb'); + + // Original db should not have the method + expect((db as any).$isolated).toBeUndefined(); + }); + + it('should preserve members through $setAuth', async () => { + const extDb = db.$use( + definePlugin({ + id: 'test-plugin', + client: { + $preserved: () => 'still-here', + }, + }), + ); + + const authDb = extDb.$setAuth({ id: 1 }); + + expect(authDb.$preserved()).toBe('still-here'); + }); + + it('should preserve members through $setOptions', async () => { + const extDb = db.$use( + definePlugin({ + id: 'test-plugin', + client: { + $preserved: () => 'still-here', + }, + }), + ); + + const newOptionsDb = extDb.$setOptions({ ...extDb.$options, validateInput: false }); + + expect(newOptionsDb.$preserved()).toBe('still-here'); + }); + + it('should work with both extQueryArgs and client members', async () => { + let gotTTL: number | undefined; + + const extDb = db.$use( + definePlugin({ + id: 'cache-plugin', + queryArgs: { + $all: z.object({ + cache: z + .object({ + ttl: z.number().optional(), + }) + .optional(), + }), + }, + onQuery: async ({ args, proceed }) => { + if (args && 'cache' in args) { + gotTTL = (args as any).cache?.ttl; + } + return proceed(args); + }, + client: { + $getCachedTTL: () => gotTTL, + }, + }), + ); + + await extDb.user.create({ data: { name: 'Test' }, cache: { ttl: 1000 } }); + expect(extDb.$getCachedTTL()).toBe(1000); + }); +}); diff --git a/tests/e2e/orm/plugin-infra/ext-query-args.test.ts b/tests/e2e/orm/plugin-infra/ext-query-args.test.ts index 7a5630d1..95794626 100644 --- a/tests/e2e/orm/plugin-infra/ext-query-args.test.ts +++ b/tests/e2e/orm/plugin-infra/ext-query-args.test.ts @@ -1,4 +1,4 @@ -import { CoreReadOperations, CoreWriteOperations, definePlugin, type ClientContract } from '@zenstackhq/orm'; +import { definePlugin, type ClientContract } from '@zenstackhq/orm'; import { createTestClient } from '@zenstackhq/testtools'; import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import z from 'zod'; @@ -16,13 +16,14 @@ describe('Plugin extended query args', () => { }); const cacheBustSchema = z.object({ - cache: z.strictObject({ - bust: z.boolean().optional(), - }), + cache: z + .strictObject({ + bust: z.boolean().optional(), + }) + .optional(), }); type CacheOptions = z.infer; - type CacheBustOptions = z.infer; beforeEach(async () => { db = await createTestClient(schema); @@ -33,34 +34,32 @@ describe('Plugin extended query args', () => { await db?.$disconnect(); }); - it('should allow extending all operations', async () => { + it('should allow extending grouped operations', async () => { let gotTTL: number | undefined = undefined; - const extDb = db.$use( - definePlugin< - typeof schema, - { - all: CacheOptions; + const cachePlugin = definePlugin({ + id: 'cache', + queryArgs: { + $read: cacheSchema, + $create: cacheBustSchema, + $update: cacheBustSchema, + $delete: cacheBustSchema, + }, + + onQuery: async ({ args, proceed }) => { + if (args && 'cache' in args) { + gotTTL = (args as CacheOptions).cache?.ttl; } - >({ - id: 'cache', - extQueryArgs: { - getValidationSchema: () => cacheSchema, - }, + return proceed(args); + }, + }); - onQuery: async ({ args, proceed }) => { - if (args && 'cache' in args) { - gotTTL = (args as CacheOptions).cache?.ttl; - } - return proceed(args); - }, - }), - ); + const extDb = db.$use(cachePlugin); // cache is optional const alice = await extDb.user.create({ data: { name: 'Alice' } }); - // ttl is optional + // bust is optional const bob = await extDb.user.create({ data: { name: 'Bob' }, cache: {} }); gotTTL = undefined; @@ -81,9 +80,20 @@ describe('Plugin extended query args', () => { // @ts-expect-error await expect(extDb.user.findMany({ where: { id: 'abc' } })).rejects.toThrow('expected number'); + // read args are not allowed in create + // @ts-expect-error + await expect(extDb.user.create({ data: { name: 'Charlie' }, cache: { ttl: 1000 } })).rejects.toThrow( + 'Unrecognized key', + ); + + // create args are not allowed in read + // @ts-expect-error + await expect(extDb.user.findMany({ cache: { bust: true } })).rejects.toThrow('Unrecognized key'); + // validate all other operations const cacheOption = { cache: { ttl: 1000 } } as const; + const cacheBustOption = { cache: { bust: true } } as const; // read operations await expect(extDb.user.findUnique({ where: { id: 1 }, ...cacheOption })).toResolveTruthy(); @@ -109,25 +119,25 @@ describe('Plugin extended query args', () => { ).resolves.toHaveLength(2); // create operations - await expect(extDb.user.createMany({ data: [{ name: 'Charlie' }], ...cacheOption })).resolves.toHaveProperty( - 'count', - ); - await expect(extDb.user.createManyAndReturn({ data: [{ name: 'David' }], ...cacheOption })).toResolveWithLength( - 1, - ); + await expect( + extDb.user.createMany({ data: [{ name: 'Charlie' }], ...cacheBustOption }), + ).resolves.toHaveProperty('count'); + await expect( + extDb.user.createManyAndReturn({ data: [{ name: 'David' }], ...cacheBustOption }), + ).toResolveWithLength(1); // update operations await expect( - extDb.user.update({ where: { id: alice.id }, data: { name: 'Alice Updated' }, ...cacheOption }), + extDb.user.update({ where: { id: alice.id }, data: { name: 'Alice Updated' }, ...cacheBustOption }), ).toResolveTruthy(); await expect( - extDb.user.updateMany({ where: { name: 'Bob' }, data: { name: 'Bob Updated' }, ...cacheOption }), + extDb.user.updateMany({ where: { name: 'Bob' }, data: { name: 'Bob Updated' }, ...cacheBustOption }), ).resolves.toHaveProperty('count'); await expect( extDb.user.updateManyAndReturn({ where: { name: 'Charlie' }, data: { name: 'Charlie Updated' }, - ...cacheOption, + ...cacheBustOption, }), ).toResolveTruthy(); await expect( @@ -135,13 +145,13 @@ describe('Plugin extended query args', () => { where: { id: 999 }, create: { name: 'Eve' }, update: { name: 'Eve Updated' }, - ...cacheOption, + ...cacheBustOption, }), ).resolves.toMatchObject({ name: 'Eve' }); // delete operations - await expect(extDb.user.delete({ where: { id: bob.id }, ...cacheOption })).toResolveTruthy(); - await expect(extDb.user.deleteMany({ where: { name: 'David' }, ...cacheOption })).resolves.toHaveProperty( + await expect(extDb.user.delete({ where: { id: bob.id }, ...cacheBustOption })).toResolveTruthy(); + await expect(extDb.user.deleteMany({ where: { name: 'David' }, ...cacheBustOption })).resolves.toHaveProperty( 'count', ); @@ -162,101 +172,51 @@ describe('Plugin extended query args', () => { await expect(extDb.$setAuth({ id: 1 }).user.findMany(cacheOption)).toResolveTruthy(); }); - it('should allow extending specific operations', async () => { + it('should allow extending all operations', async () => { const extDb = db.$use( - definePlugin< - typeof schema, - { - [Op in CoreReadOperations]: CacheOptions; - } - >({ + definePlugin({ id: 'cache', - extQueryArgs: { - getValidationSchema: (operation) => { - if (!(CoreReadOperations as readonly string[]).includes(operation)) { - return undefined; - } - return cacheSchema; - }, + queryArgs: { + $all: cacheSchema, }, }), ); - // "create" is not extended - // @ts-expect-error - await expect(extDb.user.create({ data: { name: 'Bob' }, cache: {} })).rejects.toThrow('Unrecognized key'); - - await extDb.user.create({ data: { name: 'Alice' } }); - + const alice = await extDb.user.create({ data: { name: 'Alice' }, cache: {} }); await expect(extDb.user.findMany({ cache: { ttl: 100 } })).toResolveWithLength(1); await expect(extDb.user.count({ where: { name: 'Alice' }, cache: { ttl: 200 } })).resolves.toBe(1); + await expect( + extDb.user.update({ where: { id: alice.id }, data: { name: 'Alice Updated' }, cache: { ttl: 300 } }), + ).toResolveTruthy(); + await expect(extDb.user.delete({ where: { id: alice.id }, cache: { ttl: 400 } })).toResolveTruthy(); }); - it('should allow different extensions for different operations', async () => { - let gotTTL: number | undefined = undefined; - let gotBust: boolean | undefined = undefined; - + it('should allow extending specific operations', async () => { const extDb = db.$use( - definePlugin< - typeof schema, - { - [Op in CoreReadOperations]: CacheOptions; - } & { - [Op in CoreWriteOperations]: CacheBustOptions; - } - >({ + definePlugin({ id: 'cache', - extQueryArgs: { - getValidationSchema: (operation) => { - if ((CoreReadOperations as readonly string[]).includes(operation)) { - return cacheSchema; - } else if ((CoreWriteOperations as readonly string[]).includes(operation)) { - return cacheBustSchema; - } - return undefined; - }, - }, - - onQuery: async ({ args, proceed }) => { - if (args && 'cache' in args) { - gotTTL = (args as CacheOptions).cache?.ttl; - gotBust = (args as CacheBustOptions).cache?.bust; - } - return proceed(args); + queryArgs: { + $read: cacheSchema, }, }), ); - gotBust = undefined; - await extDb.user.create({ data: { name: 'Alice' }, cache: { bust: true } }); - expect(gotBust).toBe(true); - - // ttl extension is not applied to "create" + // "create" is not extended // @ts-expect-error - await expect(extDb.user.create({ data: { name: 'Bob' }, cache: { ttl: 100 } })).rejects.toThrow( - 'Unrecognized key', - ); + await expect(extDb.user.create({ data: { name: 'Bob' }, cache: {} })).rejects.toThrow('Unrecognized key'); - gotTTL = undefined; - await expect(extDb.user.findMany({ cache: { ttl: 5000 } })).toResolveWithLength(1); - expect(gotTTL).toBe(5000); + await extDb.user.create({ data: { name: 'Alice' } }); - // bust extension is not applied to "findMany" - // @ts-expect-error - await expect(extDb.user.findMany({ cache: { bust: true } })).rejects.toThrow('Unrecognized key'); + await expect(extDb.user.findMany({ cache: { ttl: 100 } })).toResolveWithLength(1); + await expect(extDb.user.count({ where: { name: 'Alice' }, cache: { ttl: 200 } })).resolves.toBe(1); }); it('should isolate validation schemas between clients', async () => { const extDb = db.$use( - definePlugin< - typeof schema, - { - all: CacheOptions; - } - >({ + definePlugin({ id: 'cache', - extQueryArgs: { - getValidationSchema: () => cacheSchema, + queryArgs: { + $all: cacheSchema, }, }), ); @@ -270,4 +230,103 @@ describe('Plugin extended query args', () => { await expect(db.user.findMany({ cache: { ttl: 2000 } })).rejects.toThrow('Unrecognized key'); await expect(extDb.user.findMany({ cache: { ttl: 2000 } })).toResolveWithLength(0); }); + + it('should merge $create and $update schemas for upsert operation', async () => { + // Define different schemas for $create and $update + const createOnlySchema = z.object({ + tracking: z + .strictObject({ + source: z.string().optional(), + }) + .optional(), + }); + + const updateOnlySchema = z.object({ + audit: z + .strictObject({ + reason: z.string().optional(), + }) + .optional(), + }); + + const extDb = db.$use( + definePlugin({ + id: 'test', + queryArgs: { + $create: createOnlySchema, + $update: updateOnlySchema, + }, + }), + ); + + // upsert should accept both tracking (from $create) and audit (from $update) + await expect( + extDb.user.upsert({ + where: { id: 999 }, + create: { name: 'Alice' }, + update: { name: 'Alice Updated' }, + tracking: { source: 'test' }, + audit: { reason: 'testing merge' }, + }), + ).resolves.toMatchObject({ name: 'Alice' }); + + // upsert should reject tracking-only in update operations + await expect( + extDb.user.update({ + where: { id: 1 }, + data: { name: 'Test' }, + // @ts-expect-error - tracking is only for $create + tracking: { source: 'test' }, + }), + ).rejects.toThrow('Unrecognized key'); + + // upsert should reject audit-only in create operations + await expect( + extDb.user.create({ + data: { name: 'Bob' }, + // @ts-expect-error - audit is only for $update + audit: { reason: 'test' }, + }), + ).rejects.toThrow('Unrecognized key'); + + // verify that upsert without both is fine + await expect( + extDb.user.upsert({ + where: { id: 888 }, + create: { name: 'Charlie' }, + update: { name: 'Charlie Updated' }, + }), + ).resolves.toMatchObject({ name: 'Charlie' }); + + // verify that upsert with only tracking is fine + await expect( + extDb.user.upsert({ + where: { id: 777 }, + create: { name: 'David' }, + update: { name: 'David Updated' }, + tracking: { source: 'test' }, + }), + ).resolves.toMatchObject({ name: 'David' }); + + // verify that upsert with only audit is fine + await expect( + extDb.user.upsert({ + where: { id: 666 }, + create: { name: 'Eve' }, + update: { name: 'Eve Updated' }, + audit: { reason: 'testing' }, + }), + ).resolves.toMatchObject({ name: 'Eve' }); + + // verify that upsert with both is fine + await expect( + extDb.user.upsert({ + where: { id: 555 }, + create: { name: 'Frank' }, + update: { name: 'Frank Updated' }, + tracking: { source: 'test' }, + audit: { reason: 'testing both' }, + }), + ).resolves.toMatchObject({ name: 'Frank' }); + }); }); diff --git a/tests/e2e/orm/plugin-infra/on-query-hooks.test.ts b/tests/e2e/orm/plugin-infra/on-query-hooks.test.ts index 0b74e76f..46d0eb41 100644 --- a/tests/e2e/orm/plugin-infra/on-query-hooks.test.ts +++ b/tests/e2e/orm/plugin-infra/on-query-hooks.test.ts @@ -190,7 +190,7 @@ describe('On query hooks tests', () => { let findHookCalled = false; - const plugin = definePlugin({ + const plugin = definePlugin({ id: 'test-plugin', onQuery: (ctx) => { findHookCalled = true;