diff --git a/packages/orm/src/client/client-impl.ts b/packages/orm/src/client/client-impl.ts index ce2d4d428..4af8b0310 100644 --- a/packages/orm/src/client/client-impl.ts +++ b/packages/orm/src/client/client-impl.ts @@ -21,13 +21,13 @@ import type { TransactionIsolationLevel, } from './contract'; import { AggregateOperationHandler } from './crud/operations/aggregate'; -import type { AllCrudOperation, CoreCrudOperation } from './crud/operations/base'; +import type { AllCrudOperations, CoreCrudOperations } from './crud/operations/base'; import { BaseOperationHandler } from './crud/operations/base'; import { CountOperationHandler } from './crud/operations/count'; import { CreateOperationHandler } from './crud/operations/create'; import { DeleteOperationHandler } from './crud/operations/delete'; -import { FindOperationHandler } from './crud/operations/find'; import { ExistsOperationHandler } from './crud/operations/exists'; +import { FindOperationHandler } from './crud/operations/find'; import { GroupByOperationHandler } from './crud/operations/group-by'; import { UpdateOperationHandler } from './crud/operations/update'; import { InputValidator } from './crud/validator'; @@ -59,6 +59,7 @@ export class ClientImpl { public readonly $schema: SchemaDef; readonly kyselyProps: KyselyProps; private auth: AuthType | undefined; + inputValidator: InputValidator; constructor( private readonly schema: SchemaDef, @@ -114,6 +115,7 @@ export class ClientImpl { } this.kysely = new Kysely(this.kyselyProps); + this.inputValidator = baseClient?.inputValidator ?? new InputValidator(this as any); return createClientProxy(this); } @@ -242,8 +244,7 @@ export class ClientImpl { } // Validate inputs using the same validator infrastructure as CRUD operations. - const inputValidator = new InputValidator(this as any); - const validatedInput = inputValidator.validateProcedureInput(name, input); + const validatedInput = this.inputValidator.validateProcedureInput(name, input); const handler = procOptions[name] as Function; @@ -292,19 +293,22 @@ export class ClientImpl { await new SchemaDbPusher(this.schema, this.kysely).push(); } - $use(plugin: RuntimePlugin) { - // tsc perf - const newPlugins: RuntimePlugin[] = [...(this.$options.plugins ?? []), plugin]; + $use(plugin: RuntimePlugin) { + const newPlugins: RuntimePlugin[] = [...(this.$options.plugins ?? []), plugin]; const newOptions: ClientOptions = { ...this.options, plugins: newPlugins, }; - return new ClientImpl(this.schema, newOptions, this); + const newClient = new ClientImpl(this.schema, newOptions, this); + // create a new validator to have a fresh schema cache, because plugins may extend the + // query args schemas + newClient.inputValidator = new InputValidator(newClient as any); + return newClient; } $unuse(pluginId: string) { // tsc perf - const newPlugins: RuntimePlugin[] = []; + const newPlugins: RuntimePlugin[] = []; for (const plugin of this.options.plugins ?? []) { if (plugin.id !== pluginId) { newPlugins.push(plugin); @@ -314,16 +318,24 @@ export class ClientImpl { ...this.options, plugins: newPlugins, }; - return new ClientImpl(this.schema, newOptions, this); + const newClient = new ClientImpl(this.schema, newOptions, this); + // create a new validator to have a fresh schema cache, because plugins may + // extend the query args schemas + newClient.inputValidator = new InputValidator(newClient as any); + return newClient; } $unuseAll() { // tsc perf const newOptions: ClientOptions = { ...this.options, - plugins: [] as RuntimePlugin[], + plugins: [] as RuntimePlugin[], }; - return new ClientImpl(this.schema, newOptions, this); + const newClient = new ClientImpl(this.schema, newOptions, this); + // create a new validator to have a fresh schema cache, because plugins may + // extend the query args schemas + newClient.inputValidator = new InputValidator(newClient as any); + return newClient; } $setAuth(auth: AuthType | undefined) { @@ -340,10 +352,10 @@ export class ClientImpl { } $setOptions>(options: Options): ClientContract { - return new ClientImpl(this.schema, options as ClientOptions, this) as unknown as ClientContract< - SchemaDef, - Options - >; + const newClient = new ClientImpl(this.schema, options as ClientOptions, this); + // create a new validator to have a fresh schema cache, because options may change validation settings + newClient.inputValidator = new InputValidator(newClient as any); + return newClient as unknown as ClientContract; } $setInputValidation(enable: boolean) { @@ -351,7 +363,7 @@ export class ClientImpl { ...this.options, validateInput: enable, }; - return new ClientImpl(this.schema, newOptions, this); + return this.$setOptions(newOptions); } $executeRaw(query: TemplateStringsArray, ...values: any[]) { @@ -391,7 +403,6 @@ export class ClientImpl { } function createClientProxy(client: ClientImpl): ClientImpl { - const inputValidator = new InputValidator(client as any); const resultProcessor = new ResultProcessor(client.$schema, client.$options); return new Proxy(client, { @@ -403,7 +414,7 @@ function createClientProxy(client: ClientImpl): ClientImpl { if (typeof prop === 'string') { const model = Object.keys(client.$schema.models).find((m) => m.toLowerCase() === prop.toLowerCase()); if (model) { - return createModelCrudHandler(client as any, model, inputValidator, resultProcessor); + return createModelCrudHandler(client as any, model, client.inputValidator, resultProcessor); } } @@ -419,8 +430,8 @@ function createModelCrudHandler( resultProcessor: ResultProcessor, ): ModelOperations { const createPromise = ( - operation: CoreCrudOperation, - nominalOperation: AllCrudOperation, + operation: CoreCrudOperations, + nominalOperation: AllCrudOperations, args: unknown, handler: BaseOperationHandler, postProcess = false, @@ -448,8 +459,8 @@ function createModelCrudHandler( const onQuery = plugin.onQuery; if (onQuery) { const _proceed = proceed; - proceed = (_args: unknown) => - onQuery({ + proceed = (_args: unknown) => { + const ctx: any = { client, model, operation: nominalOperation, @@ -457,7 +468,9 @@ function createModelCrudHandler( args: _args, // ensure inner overrides are propagated to the previous proceed proceed: (nextArgs: unknown) => _proceed(nextArgs), - }) as Promise; + }; + return (onQuery as (ctx: any) => Promise)(ctx); + }; } } @@ -516,6 +529,7 @@ function createModelCrudHandler( args, new FindOperationHandler(client, model, inputValidator), true, + false, ); }, diff --git a/packages/orm/src/client/contract.ts b/packages/orm/src/client/contract.ts index af5e51389..0006cd2ab 100644 --- a/packages/orm/src/client/contract.ts +++ b/packages/orm/src/client/contract.ts @@ -40,13 +40,23 @@ import type { UpdateManyArgs, UpsertArgs, } from './crud-types'; +import type { CoreCrudOperations } from './crud/operations/base'; import type { ClientOptions, QueryOptions, ToQueryOptions } from './options'; -import type { RuntimePlugin } from './plugin'; +import type { ExtQueryArgsBase, RuntimePlugin } from './plugin'; import type { ZenStackPromise } from './promise'; import type { ToKysely } from './query-builder'; type TransactionUnsupportedMethods = (typeof TRANSACTION_UNSUPPORTED_METHODS)[number]; +/** + * Extracts extended query args for a specific operation. + */ +type ExtractExtQueryArgs = Operation extends keyof ExtQueryArgs + ? NonNullable + : 'all' extends keyof ExtQueryArgs + ? NonNullable + : {}; + /** * Transaction isolation levels. */ @@ -61,7 +71,11 @@ export enum TransactionIsolationLevel { /** * ZenStack client interface. */ -export type ClientContract = ClientOptions> = { +export type ClientContract< + Schema extends SchemaDef, + Options extends ClientOptions = ClientOptions, + ExtQueryArgs extends ExtQueryArgsBase = {}, +> = { /** * The schema definition. */ @@ -118,7 +132,7 @@ export type ClientContract | undefined): ClientContract; + $setAuth(auth: AuthType | undefined): ClientContract; /** * Returns a new client with new options applied. @@ -127,15 +141,15 @@ export type ClientContract>(options: Options): ClientContract; + $setOptions>(options: Options): ClientContract; /** * Returns a new client enabling/disabling input validations expressed with attributes like * `@email`, `@regex`, `@@validate`, etc. * - * @deprecated Use `$setOptions` instead. + * @deprecated Use {@link $setOptions} instead. */ - $setInputValidation(enable: boolean): ClientContract; + $setInputValidation(enable: boolean): ClientContract; /** * The Kysely query builder instance. @@ -151,7 +165,7 @@ export type ClientContract( - callback: (tx: Omit, TransactionUnsupportedMethods>) => Promise, + callback: (tx: TransactionClientContract) => Promise, options?: { isolationLevel?: TransactionIsolationLevel }, ): Promise; @@ -166,12 +180,14 @@ export type ClientContract): ClientContract; + $use( + plugin: RuntimePlugin, + ): ClientContract; /** * Returns a new client with the specified plugin removed. */ - $unuse(pluginId: string): ClientContract; + $unuse(pluginId: string): ClientContract; /** * Returns a new client with all plugins removed. @@ -194,16 +210,22 @@ export type ClientContract; } & { - [Key in GetModels as Uncapitalize]: ModelOperations>; + [Key in GetModels as Uncapitalize]: ModelOperations< + Schema, + Key, + ToQueryOptions, + ExtQueryArgs + >; } & ProcedureOperations; /** * The contract for a client in a transaction. */ -export type TransactionClientContract> = Omit< - ClientContract, - TransactionUnsupportedMethods ->; +export type TransactionClientContract< + Schema extends SchemaDef, + Options extends ClientOptions, + ExtQueryArgs extends ExtQueryArgsBase, +> = Omit, TransactionUnsupportedMethods>; export type ProcedureOperations = Schema['procedures'] extends Record @@ -253,6 +275,7 @@ export type AllModelOperations< Schema extends SchemaDef, Model extends GetModels, Options extends QueryOptions, + ExtQueryArgs, > = { /** * Returns a list of entities. @@ -335,8 +358,8 @@ export type AllModelOperations< * }); // result: `{ _count: { posts: number } }` * ``` */ - findMany>( - args?: SelectSubset>, + findMany & ExtractExtQueryArgs>( + args?: SelectSubset & ExtractExtQueryArgs>, ): ZenStackPromise[]>; /** @@ -345,8 +368,8 @@ export type AllModelOperations< * @returns a single entity or null if not found * @see {@link findMany} */ - findUnique>( - args: SelectSubset>, + findUnique & ExtractExtQueryArgs>( + args: SelectSubset & ExtractExtQueryArgs>, ): ZenStackPromise | null>; /** @@ -355,8 +378,8 @@ export type AllModelOperations< * @returns a single entity * @see {@link findMany} */ - findUniqueOrThrow>( - args: SelectSubset>, + findUniqueOrThrow & ExtractExtQueryArgs>( + args: SelectSubset & ExtractExtQueryArgs>, ): ZenStackPromise>; /** @@ -365,8 +388,8 @@ export type AllModelOperations< * @returns a single entity or null if not found * @see {@link findMany} */ - findFirst>( - args?: SelectSubset>, + findFirst & ExtractExtQueryArgs>( + args?: SelectSubset & ExtractExtQueryArgs>, ): ZenStackPromise | null>; /** @@ -375,8 +398,8 @@ export type AllModelOperations< * @returns a single entity * @see {@link findMany} */ - findFirstOrThrow>( - args?: SelectSubset>, + findFirstOrThrow & ExtractExtQueryArgs>( + args?: SelectSubset & ExtractExtQueryArgs>, ): ZenStackPromise>; /** @@ -431,8 +454,8 @@ export type AllModelOperations< * }); * ``` */ - create>( - args: SelectSubset>, + create & ExtractExtQueryArgs>( + args: SelectSubset & ExtractExtQueryArgs>, ): ZenStackPromise>; /** @@ -460,8 +483,8 @@ export type AllModelOperations< * }); * ``` */ - createMany>( - args?: SelectSubset>, + createMany & ExtractExtQueryArgs>( + args?: SelectSubset & ExtractExtQueryArgs>, ): ZenStackPromise; /** @@ -482,8 +505,13 @@ export type AllModelOperations< * }); * ``` */ - createManyAndReturn>( - args?: SelectSubset>, + createManyAndReturn< + T extends CreateManyAndReturnArgs & ExtractExtQueryArgs, + >( + args?: SelectSubset< + T, + CreateManyAndReturnArgs & ExtractExtQueryArgs + >, ): ZenStackPromise[]>; /** @@ -603,8 +631,8 @@ export type AllModelOperations< * }); * ``` */ - update>( - args: SelectSubset>, + update & ExtractExtQueryArgs>( + args: SelectSubset & ExtractExtQueryArgs>, ): ZenStackPromise>; /** @@ -627,8 +655,8 @@ export type AllModelOperations< * limit: 10 * }); */ - updateMany>( - args: Subset>, + updateMany & ExtractExtQueryArgs>( + args: Subset & ExtractExtQueryArgs>, ): ZenStackPromise; /** @@ -653,8 +681,13 @@ export type AllModelOperations< * }); * ``` */ - updateManyAndReturn>( - args: Subset>, + updateManyAndReturn< + T extends UpdateManyAndReturnArgs & ExtractExtQueryArgs, + >( + args: Subset< + T, + UpdateManyAndReturnArgs & ExtractExtQueryArgs + >, ): ZenStackPromise[]>; /** @@ -677,8 +710,8 @@ export type AllModelOperations< * }); * ``` */ - upsert>( - args: SelectSubset>, + upsert & ExtractExtQueryArgs>( + args: SelectSubset & ExtractExtQueryArgs>, ): ZenStackPromise>; /** @@ -700,8 +733,8 @@ export type AllModelOperations< * }); // result: `{ id: string; email: string }` * ``` */ - delete>( - args: SelectSubset>, + delete & ExtractExtQueryArgs>( + args: SelectSubset & ExtractExtQueryArgs>, ): ZenStackPromise>; /** @@ -723,8 +756,8 @@ export type AllModelOperations< * }); * ``` */ - deleteMany>( - args?: Subset>, + deleteMany & ExtractExtQueryArgs>( + args?: Subset & ExtractExtQueryArgs>, ): ZenStackPromise; /** @@ -745,8 +778,8 @@ export type AllModelOperations< * select: { _all: true, email: true } * }); // result: `{ _all: number, email: number }` */ - count>( - args?: Subset>, + count & ExtractExtQueryArgs>( + args?: Subset & ExtractExtQueryArgs>, ): ZenStackPromise>>; /** @@ -766,8 +799,8 @@ export type AllModelOperations< * _max: { age: true } * }); // result: `{ _count: number, _avg: { age: number }, ... }` */ - aggregate>( - args: Subset>, + aggregate & ExtractExtQueryArgs>( + args: Subset & ExtractExtQueryArgs>, ): ZenStackPromise>>; /** @@ -803,8 +836,8 @@ export type AllModelOperations< * having: { country: 'US', age: { _avg: { gte: 18 } } } * }); */ - groupBy>( - args: Subset>, + groupBy & ExtractExtQueryArgs>( + args: Subset & ExtractExtQueryArgs>, ): ZenStackPromise>>; /** @@ -818,14 +851,14 @@ export type AllModelOperations< * await db.user.exists({ * where: { id: 1 }, * }); // result: `boolean` - * + * * // check with a relation * await db.user.exists({ * where: { posts: { some: { published: true } } }, * }); // result: `boolean` */ - exists>( - args?: Subset>, + exists & ExtractExtQueryArgs>( + args?: Subset & ExtractExtQueryArgs>, ): ZenStackPromise; }; @@ -835,8 +868,9 @@ export type ModelOperations< Schema extends SchemaDef, Model extends GetModels, Options extends QueryOptions = QueryOptions, + ExtQueryArgs = {}, > = Omit< - AllModelOperations, + AllModelOperations, // exclude operations not applicable to delegate models IsDelegateModel extends true ? OperationsIneligibleForDelegateModels : never >; diff --git a/packages/orm/src/client/crud-types.ts b/packages/orm/src/client/crud-types.ts index 1b6f3d3c4..5626931e7 100644 --- a/packages/orm/src/client/crud-types.ts +++ b/packages/orm/src/client/crud-types.ts @@ -215,10 +215,10 @@ export type ModelResult< FieldIsArray >; } & ('_count' extends keyof I - ? I['_count'] extends false | undefined - ? {} - : { _count: SelectCountResult } - : {}) + ? I['_count'] extends false | undefined + ? {} + : { _count: SelectCountResult } + : {}) : Args extends { omit: infer O } & Record ? DefaultModelResult : DefaultModelResult, diff --git a/packages/orm/src/client/crud/operations/base.ts b/packages/orm/src/client/crud/operations/base.ts index 957b94ab8..c9c85121c 100644 --- a/packages/orm/src/client/crud/operations/base.ts +++ b/packages/orm/src/client/crud/operations/base.ts @@ -54,25 +54,90 @@ import { getCrudDialect } from '../dialects'; import type { BaseCrudDialect } from '../dialects/base-dialect'; import { InputValidator } from '../validator'; -export type CoreCrudOperation = - | 'findMany' - | 'findUnique' - | 'findFirst' - | 'create' - | 'createMany' - | 'createManyAndReturn' - | 'update' - | 'updateMany' - | 'updateManyAndReturn' - | 'upsert' - | 'delete' - | 'deleteMany' - | 'count' - | 'aggregate' - | 'groupBy' - | 'exists'; - -export type AllCrudOperation = CoreCrudOperation | 'findUniqueOrThrow' | 'findFirstOrThrow'; +/** + * List of core CRUD operations. It excludes the 'orThrow' variants. + */ +export const CoreCrudOperations = [ + 'findMany', + 'findUnique', + 'findFirst', + 'create', + 'createMany', + 'createManyAndReturn', + 'update', + 'updateMany', + 'updateManyAndReturn', + 'upsert', + 'delete', + 'deleteMany', + 'count', + 'aggregate', + 'groupBy', + 'exists', +] as const; + +/** + * List of core CRUD operations. It excludes the 'orThrow' variants. + */ +export type CoreCrudOperations = (typeof CoreCrudOperations)[number]; + +/** + * List of core read operations. It excludes the 'orThrow' variants. + */ +export const CoreReadOperations = [ + 'findMany', + 'findUnique', + 'findFirst', + 'count', + 'aggregate', + 'groupBy', + 'exists', +] as const; + +/** + * List of core read operations. It excludes the 'orThrow' variants. + */ +export type CoreReadOperations = (typeof CoreReadOperations)[number]; + +/** + * List of core write operations. + */ +export const CoreWriteOperations = [ + 'create', + 'createMany', + 'createManyAndReturn', + 'update', + 'updateMany', + 'updateManyAndReturn', + 'upsert', + 'delete', + 'deleteMany', +] as const; + +/** + * List of core write operations. + */ +export type CoreWriteOperations = (typeof CoreWriteOperations)[number]; + +/** + * List of all CRUD operations, including 'orThrow' variants. + */ +export const AllCrudOperations = [...CoreCrudOperations, 'findUniqueOrThrow', 'findFirstOrThrow'] as const; + +/** + * List of all CRUD operations, including 'orThrow' variants. + */ +export type AllCrudOperations = (typeof AllCrudOperations)[number]; + +/** + * List of all read operations, including 'orThrow' variants. + */ +export const AllReadOperations = [...CoreReadOperations, 'findUniqueOrThrow', 'findFirstOrThrow'] as const; + +/** + * List of all read operations, including 'orThrow' variants. + */ +export type AllReadOperations = (typeof AllReadOperations)[number]; // context for nested relation operations export type FromRelationContext = { @@ -109,7 +174,7 @@ export abstract class BaseOperationHandler { return this.client.$qb; } - abstract handle(operation: CoreCrudOperation, args: any): Promise; + abstract handle(operation: CoreCrudOperations, args: any): Promise; withClient(client: ClientContract) { return new (this.constructor as new (...args: any[]) => this)(client, this.model, this.inputValidator); diff --git a/packages/orm/src/client/crud/operations/find.ts b/packages/orm/src/client/crud/operations/find.ts index 49938c8c3..db087a3b5 100644 --- a/packages/orm/src/client/crud/operations/find.ts +++ b/packages/orm/src/client/crud/operations/find.ts @@ -1,9 +1,9 @@ import type { GetModels, SchemaDef } from '../../../schema'; import type { FindArgs } from '../../crud-types'; -import { BaseOperationHandler, type CoreCrudOperation } from './base'; +import { BaseOperationHandler, type CoreCrudOperations } from './base'; export class FindOperationHandler extends BaseOperationHandler { - async handle(operation: CoreCrudOperation, args: unknown, validateArgs = true): Promise { + async handle(operation: CoreCrudOperations, args: unknown, validateArgs = true): Promise { // normalize args to strip `undefined` fields const normalizedArgs = this.normalizeArgs(args); @@ -11,10 +11,7 @@ export class FindOperationHandler extends BaseOperatio // parse args let parsedArgs = validateArgs - ? this.inputValidator.validateFindArgs(this.model, normalizedArgs, { - unique: operation === 'findUnique', - findOne, - }) + ? this.inputValidator.validateFindArgs(this.model, normalizedArgs, operation) : (normalizedArgs as FindArgs, true> | undefined); if (findOne) { diff --git a/packages/orm/src/client/crud/validator/index.ts b/packages/orm/src/client/crud/validator/index.ts index 50245c605..3556c6046 100644 --- a/packages/orm/src/client/crud/validator/index.ts +++ b/packages/orm/src/client/crud/validator/index.ts @@ -2,7 +2,7 @@ import { enumerate, invariant } from '@zenstackhq/common-helpers'; import Decimal from 'decimal.js'; import stableStringify from 'json-stable-stringify'; import { match, P } from 'ts-pattern'; -import { z, ZodType } from 'zod'; +import { z, ZodObject, ZodType } from 'zod'; import { AnyNullClass, DbNullClass, JsonNullClass } from '../../../common-types'; import { type AttributeApplication, @@ -46,6 +46,7 @@ import { requireField, requireModel, } from '../../query-utils'; +import type { CoreCrudOperations } from '../operations/base'; import { addBigIntValidation, addCustomValidation, @@ -55,11 +56,11 @@ import { addStringValidation, } from './utils'; -const schemaCache = new WeakMap>(); - -type GetSchemaFunc = (model: GetModels, options: Options) => ZodType; +type GetSchemaFunc = (model: GetModels) => ZodType; export class InputValidator { + private readonly schemaCache = new Map(); + constructor(private readonly client: ClientContract) {} private get schema() { @@ -191,19 +192,20 @@ export class InputValidator { validateFindArgs( model: GetModels, args: unknown, - options: { unique: boolean; findOne: boolean }, + operation: CoreCrudOperations, ): FindArgs, true> | undefined { - return this.validate< - FindArgs, true> | undefined, - Parameters[1] - >(model, 'find', options, (model, options) => this.makeFindSchema(model, options), args); + return this.validate, true> | undefined>( + model, + operation, + (model) => this.makeFindSchema(model, operation), + args, + ); } validateExistsArgs(model: GetModels, args: unknown): ExistsArgs> | undefined { return this.validate>>( model, 'exists', - undefined, (model) => this.makeExistsSchema(model), args, ); @@ -213,7 +215,6 @@ export class InputValidator { return this.validate>>( model, 'create', - undefined, (model) => this.makeCreateSchema(model), args, ); @@ -223,7 +224,6 @@ export class InputValidator { return this.validate>>( model, 'createMany', - undefined, (model) => this.makeCreateManySchema(model), args, ); @@ -236,7 +236,6 @@ export class InputValidator { return this.validate> | undefined>( model, 'createManyAndReturn', - undefined, (model) => this.makeCreateManyAndReturnSchema(model), args, ); @@ -246,7 +245,6 @@ export class InputValidator { return this.validate>>( model, 'update', - undefined, (model) => this.makeUpdateSchema(model), args, ); @@ -256,7 +254,6 @@ export class InputValidator { return this.validate>>( model, 'updateMany', - undefined, (model) => this.makeUpdateManySchema(model), args, ); @@ -269,7 +266,6 @@ export class InputValidator { return this.validate>>( model, 'updateManyAndReturn', - undefined, (model) => this.makeUpdateManyAndReturnSchema(model), args, ); @@ -279,7 +275,6 @@ export class InputValidator { return this.validate>>( model, 'upsert', - undefined, (model) => this.makeUpsertSchema(model), args, ); @@ -289,7 +284,6 @@ export class InputValidator { return this.validate>>( model, 'delete', - undefined, (model) => this.makeDeleteSchema(model), args, ); @@ -302,7 +296,6 @@ export class InputValidator { return this.validate> | undefined>( model, 'deleteMany', - undefined, (model) => this.makeDeleteManySchema(model), args, ); @@ -312,7 +305,6 @@ export class InputValidator { return this.validate> | undefined>( model, 'count', - undefined, (model) => this.makeCountSchema(model), args, ); @@ -322,7 +314,6 @@ export class InputValidator { return this.validate>>( model, 'aggregate', - undefined, (model) => this.makeAggregateSchema(model), args, ); @@ -332,49 +323,32 @@ export class InputValidator { return this.validate>>( model, 'groupBy', - undefined, (model) => this.makeGroupBySchema(model), args, ); } private getSchemaCache(cacheKey: string) { - let thisCache = schemaCache.get(this.schema); - if (!thisCache) { - thisCache = new Map(); - schemaCache.set(this.schema, thisCache); - } - return thisCache.get(cacheKey); + return this.schemaCache.get(cacheKey); } private setSchemaCache(cacheKey: string, schema: ZodType) { - let thisCache = schemaCache.get(this.schema); - if (!thisCache) { - thisCache = new Map(); - schemaCache.set(this.schema, thisCache); - } - return thisCache.set(cacheKey, schema); + return this.schemaCache.set(cacheKey, schema); } - private validate( - model: GetModels, - operation: string, - options: Options, - getSchema: GetSchemaFunc, - args: unknown, - ) { + private validate(model: GetModels, operation: string, getSchema: GetSchemaFunc, args: unknown) { const cacheKey = stableStringify({ type: 'model', model, operation, - options, extraValidationsEnabled: this.extraValidationsEnabled, }); let schema = this.getSchemaCache(cacheKey!); if (!schema) { - schema = getSchema(model, options); + schema = getSchema(model); this.setSchemaCache(cacheKey!, schema); } + const { error, data } = schema.safeParse(args); if (error) { throw createInvalidInputError( @@ -388,12 +362,27 @@ export class InputValidator { return data as T; } + private mergePluginArgsSchema(schema: ZodObject, operation: CoreCrudOperations) { + let result = schema; + for (const plugin of this.options.plugins ?? []) { + if (plugin.extQueryArgs) { + const pluginSchema = plugin.extQueryArgs.getValidationSchema(operation); + if (pluginSchema) { + result = result.extend(pluginSchema.shape); + } + } + } + return result.strict(); + } + // #region Find - private makeFindSchema(model: string, options: { unique: boolean; findOne: boolean }) { + private makeFindSchema(model: string, operation: CoreCrudOperations) { const fields: Record = {}; - const where = this.makeWhereSchema(model, options.unique); - if (options.unique) { + const unique = operation === 'findUnique'; + const findOne = operation === 'findUnique' || operation === 'findFirst'; + const where = this.makeWhereSchema(model, unique); + if (unique) { fields['where'] = where; } else { fields['where'] = where.optional(); @@ -403,9 +392,9 @@ export class InputValidator { fields['include'] = this.makeIncludeSchema(model).optional().nullable(); fields['omit'] = this.makeOmitSchema(model).optional().nullable(); - if (!options.unique) { + if (!unique) { fields['skip'] = this.makeSkipSchema().optional(); - if (options.findOne) { + if (findOne) { fields['take'] = z.literal(1).optional(); } else { fields['take'] = this.makeTakeSchema().optional(); @@ -415,22 +404,22 @@ export class InputValidator { fields['distinct'] = this.makeDistinctSchema(model).optional(); } - let result: ZodType = z.strictObject(fields); + const baseSchema = z.strictObject(fields); + let result: ZodType = this.mergePluginArgsSchema(baseSchema, operation); result = this.refineForSelectIncludeMutuallyExclusive(result); result = this.refineForSelectOmitMutuallyExclusive(result); - if (!options.unique) { + if (!unique) { result = result.optional(); } return result; } private makeExistsSchema(model: string) { - return z - .strictObject({ - where: this.makeWhereSchema(model, false).optional(), - }) - .optional(); + const baseSchema = z.strictObject({ + where: this.makeWhereSchema(model, false).optional(), + }); + return this.mergePluginArgsSchema(baseSchema, 'exists').optional(); } private makeScalarSchema(type: string, attributes?: readonly AttributeApplication[]) { @@ -1158,27 +1147,29 @@ export class InputValidator { private makeCreateSchema(model: string) { const dataSchema = this.makeCreateDataSchema(model, false); - let schema: ZodType = z.strictObject({ + const baseSchema = z.strictObject({ data: dataSchema, select: this.makeSelectSchema(model).optional().nullable(), include: this.makeIncludeSchema(model).optional().nullable(), omit: this.makeOmitSchema(model).optional().nullable(), }); + let schema: ZodType = this.mergePluginArgsSchema(baseSchema, 'create'); schema = this.refineForSelectIncludeMutuallyExclusive(schema); schema = this.refineForSelectOmitMutuallyExclusive(schema); return schema; } private makeCreateManySchema(model: string) { - return this.makeCreateManyDataSchema(model, []).optional(); + return this.mergePluginArgsSchema(this.makeCreateManyDataSchema(model, []), 'createMany').optional(); } private makeCreateManyAndReturnSchema(model: string) { const base = this.makeCreateManyDataSchema(model, []); - const result = base.extend({ + let result: ZodObject = base.extend({ select: this.makeSelectSchema(model).optional().nullable(), omit: this.makeOmitSchema(model).optional().nullable(), }); + result = this.mergePluginArgsSchema(result, 'createManyAndReturn'); return this.refineForSelectOmitMutuallyExclusive(result).optional(); } @@ -1440,29 +1431,34 @@ export class InputValidator { // #region Update private makeUpdateSchema(model: string) { - let schema: ZodType = z.strictObject({ + const baseSchema = z.strictObject({ where: this.makeWhereSchema(model, true), data: this.makeUpdateDataSchema(model), select: this.makeSelectSchema(model).optional().nullable(), include: this.makeIncludeSchema(model).optional().nullable(), omit: this.makeOmitSchema(model).optional().nullable(), }); + let schema: ZodType = this.mergePluginArgsSchema(baseSchema, 'update'); schema = this.refineForSelectIncludeMutuallyExclusive(schema); schema = this.refineForSelectOmitMutuallyExclusive(schema); return schema; } private makeUpdateManySchema(model: string) { - return z.strictObject({ - where: this.makeWhereSchema(model, false).optional(), - data: this.makeUpdateDataSchema(model, [], true), - limit: z.number().int().nonnegative().optional(), - }); + return this.mergePluginArgsSchema( + z.strictObject({ + where: this.makeWhereSchema(model, false).optional(), + data: this.makeUpdateDataSchema(model, [], true), + limit: z.number().int().nonnegative().optional(), + }), + 'updateMany', + ); } private makeUpdateManyAndReturnSchema(model: string) { - const base = this.makeUpdateManySchema(model); - let schema: ZodType = base.extend({ + // plugin extended args schema is merged in `makeUpdateManySchema` + const baseSchema: ZodObject = this.makeUpdateManySchema(model); + let schema: ZodType = baseSchema.extend({ select: this.makeSelectSchema(model).optional().nullable(), omit: this.makeOmitSchema(model).optional().nullable(), }); @@ -1471,7 +1467,7 @@ export class InputValidator { } private makeUpsertSchema(model: string) { - let schema: ZodType = z.strictObject({ + const baseSchema = z.strictObject({ where: this.makeWhereSchema(model, true), create: this.makeCreateDataSchema(model, false), update: this.makeUpdateDataSchema(model), @@ -1479,6 +1475,7 @@ export class InputValidator { include: this.makeIncludeSchema(model).optional().nullable(), omit: this.makeOmitSchema(model).optional().nullable(), }); + let schema: ZodType = this.mergePluginArgsSchema(baseSchema, 'upsert'); schema = this.refineForSelectIncludeMutuallyExclusive(schema); schema = this.refineForSelectOmitMutuallyExclusive(schema); return schema; @@ -1595,25 +1592,26 @@ export class InputValidator { // #region Delete private makeDeleteSchema(model: GetModels) { - let schema: ZodType = z.strictObject({ + const baseSchema = z.strictObject({ where: this.makeWhereSchema(model, true), select: this.makeSelectSchema(model).optional().nullable(), include: this.makeIncludeSchema(model).optional().nullable(), omit: this.makeOmitSchema(model).optional().nullable(), }); + let schema: ZodType = this.mergePluginArgsSchema(baseSchema, 'delete'); schema = this.refineForSelectIncludeMutuallyExclusive(schema); schema = this.refineForSelectOmitMutuallyExclusive(schema); return schema; } private makeDeleteManySchema(model: GetModels) { - return z - .object({ + return this.mergePluginArgsSchema( + z.strictObject({ where: this.makeWhereSchema(model, false).optional(), limit: z.number().int().nonnegative().optional(), - }) - - .optional(); + }), + 'deleteMany', + ).optional(); } // #endregion @@ -1621,16 +1619,16 @@ export class InputValidator { // #region Count makeCountSchema(model: GetModels) { - return z - .object({ + return this.mergePluginArgsSchema( + z.strictObject({ where: this.makeWhereSchema(model, false).optional(), skip: this.makeSkipSchema().optional(), take: this.makeTakeSchema().optional(), orderBy: this.orArray(this.makeOrderBySchema(model, true, false), true).optional(), select: this.makeCountAggregateInputSchema(model).optional(), - }) - - .optional(); + }), + 'count', + ).optional(); } private makeCountAggregateInputSchema(model: GetModels) { @@ -1655,8 +1653,8 @@ export class InputValidator { // #region Aggregate makeAggregateSchema(model: GetModels) { - return z - .object({ + return this.mergePluginArgsSchema( + z.strictObject({ where: this.makeWhereSchema(model, false).optional(), skip: this.makeSkipSchema().optional(), take: this.makeTakeSchema().optional(), @@ -1666,9 +1664,9 @@ export class InputValidator { _sum: this.makeSumAvgInputSchema(model).optional(), _min: this.makeMinMaxInputSchema(model).optional(), _max: this.makeMinMaxInputSchema(model).optional(), - }) - - .optional(); + }), + 'aggregate', + ).optional(); } makeSumAvgInputSchema(model: GetModels) { @@ -1711,7 +1709,7 @@ export class InputValidator { ? this.orArray(z.enum(nonRelationFields as [string, ...string[]]), true) : z.never(); - let schema: z.ZodSchema = z.strictObject({ + const baseSchema = z.strictObject({ where: this.makeWhereSchema(model, false).optional(), orderBy: this.orArray(this.makeOrderBySchema(model, false, true), true).optional(), by: bySchema, @@ -1725,6 +1723,8 @@ export class InputValidator { _max: this.makeMinMaxInputSchema(model).optional(), }); + let schema: ZodType = this.mergePluginArgsSchema(baseSchema, 'groupBy'); + // fields used in `having` must be either in the `by` list, or aggregations schema = schema.refine((value: any) => { const bys = typeof value.by === 'string' ? [value.by] : value.by; diff --git a/packages/orm/src/client/index.ts b/packages/orm/src/client/index.ts index e69e41802..bf17a9e6e 100644 --- a/packages/orm/src/client/index.ts +++ b/packages/orm/src/client/index.ts @@ -3,6 +3,13 @@ export * from './contract'; export type * from './crud-types'; export { getCrudDialect } from './crud/dialects'; export { BaseCrudDialect } from './crud/dialects/base-dialect'; +export { + AllCrudOperations, + AllReadOperations, + CoreCrudOperations, + CoreReadOperations, + CoreWriteOperations, +} from './crud/operations/base'; export { InputValidator } from './crud/validator'; export { ORMError, ORMErrorReason, RejectedByPolicyReason } from './errors'; export * from './options'; diff --git a/packages/orm/src/client/options.ts b/packages/orm/src/client/options.ts index c5e6c94d3..d1fa23ed7 100644 --- a/packages/orm/src/client/options.ts +++ b/packages/orm/src/client/options.ts @@ -59,7 +59,7 @@ export type ClientOptions = { /** * Plugins. */ - plugins?: RuntimePlugin[]; + plugins?: RuntimePlugin[]; /** * Logging configuration. @@ -85,7 +85,7 @@ export type ClientOptions = { /** * Options for omitting fields in ORM query results. */ - omit?: OmitOptions; + omit?: OmitConfig; /** * Whether to allow overriding omit settings at query time. Defaults to `true`. When set to @@ -111,9 +111,9 @@ export type ClientOptions = { : {}); /** - * Options for omitting fields in ORM query results. + * Config for omitting fields in ORM query results. */ -export type OmitOptions = { +export type OmitConfig = { [Model in GetModels]?: { [Field in GetModelFields as Field extends ScalarFields ? Field : never]?: boolean; }; diff --git a/packages/orm/src/client/plugin.ts b/packages/orm/src/client/plugin.ts index cd092f4ac..ee024f1e8 100644 --- a/packages/orm/src/client/plugin.ts +++ b/packages/orm/src/client/plugin.ts @@ -1,14 +1,19 @@ import type { OperationNode, QueryId, QueryResult, RootOperationNode, UnknownRow } from 'kysely'; -import type { ClientContract } from '.'; +import type { ZodObject } from 'zod'; +import type { ClientContract, ZModelFunction } from '.'; import type { GetModels, SchemaDef } from '../schema'; import type { MaybePromise } from '../utils/type-utils'; -import type { AllCrudOperation } from './crud/operations/base'; -import type { ZModelFunction } from './options'; +import type { AllCrudOperations, CoreCrudOperations } from './crud/operations/base'; + +/** + * Base shape of plugin-extended query args. + */ +export type ExtQueryArgsBase = { [K in CoreCrudOperations | 'all']?: object }; /** * ZenStack runtime plugin. */ -export interface RuntimePlugin { +export interface RuntimePlugin { /** * Plugin ID. */ @@ -50,17 +55,26 @@ export interface RuntimePlugin { * Intercepts a Kysely query. */ onKyselyQuery?: OnKyselyQueryCallback; -} + /** + * Extended query args configuration. + */ + extQueryArgs?: { + /** + * Callback for getting a Zod schema to validate the extended query args for the given operation. + */ + getValidationSchema: (operation: CoreCrudOperations) => ZodObject | undefined; + }; +} /** * Defines a ZenStack runtime plugin. */ -export function definePlugin(plugin: RuntimePlugin) { +export function definePlugin( + plugin: RuntimePlugin, +): RuntimePlugin { return plugin; } -export { type CoreCrudOperation as CrudOperation } from './crud/operations/base'; - // #region OnProcedure hooks type OnProcedureCallback = (ctx: OnProcedureHookContext) => Promise; @@ -110,12 +124,12 @@ type OnQueryHookContext = { /** * The operation that is being performed. */ - operation: AllCrudOperation; + operation: AllCrudOperations; /** * The query arguments. */ - args: unknown; + args: Record | undefined; /** * The function to proceed with the original query. @@ -123,7 +137,7 @@ type OnQueryHookContext = { * * @param args The query arguments. */ - proceed: (args: unknown) => Promise; + proceed: (args: Record | undefined) => Promise; /** * The ZenStack client that is performing the operation. diff --git a/packages/orm/src/utils/type-utils.ts b/packages/orm/src/utils/type-utils.ts index f1ad3d35c..85152e328 100644 --- a/packages/orm/src/utils/type-utils.ts +++ b/packages/orm/src/utils/type-utils.ts @@ -88,3 +88,5 @@ export type OrUndefinedIf = Condition extends true export type UnwrapTuplePromises = { [K in keyof T]: Awaited; }; + +export type Exact = T extends Shape ? (Exclude extends never ? T : never) : never; diff --git a/packages/plugins/policy/src/expression-transformer.ts b/packages/plugins/policy/src/expression-transformer.ts index 0ea84a97a..18320c69d 100644 --- a/packages/plugins/policy/src/expression-transformer.ts +++ b/packages/plugins/policy/src/expression-transformer.ts @@ -6,6 +6,7 @@ import { type BaseCrudDialect, type ClientContract, type CRUD_EXT, + type ZModelFunction, } from '@zenstackhq/orm'; import type { BinaryExpression, @@ -560,7 +561,7 @@ export class ExpressionTransformer { // check plugins for (const plugin of this.clientOptions.plugins ?? []) { if (plugin.functions?.[functionName]) { - func = plugin.functions[functionName]; + func = plugin.functions[functionName] as unknown as ZModelFunction; break; } } diff --git a/packages/plugins/policy/src/plugin.ts b/packages/plugins/policy/src/plugin.ts index b45f30bf9..e27b0cf54 100644 --- a/packages/plugins/policy/src/plugin.ts +++ b/packages/plugins/policy/src/plugin.ts @@ -3,9 +3,9 @@ import type { SchemaDef } from '@zenstackhq/orm/schema'; import { check } from './functions'; import { PolicyHandler } from './policy-handler'; -export class PolicyPlugin implements RuntimePlugin { +export class PolicyPlugin implements RuntimePlugin { get id() { - return 'policy'; + return 'policy' as const; } get name() { @@ -22,8 +22,8 @@ export class PolicyPlugin implements RuntimePlugin) { - const handler = new PolicyHandler(client); + onKyselyQuery({ query, client, proceed }: OnKyselyQueryArgs) { + const handler = new PolicyHandler(client); return handler.handle(query, proceed); } } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 8b686eb02..a60befaa5 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -69,6 +69,9 @@ catalogs: react-dom: specifier: 19.2.0 version: 19.2.0 + sql.js: + specifier: ^1.13.0 + version: 1.13.0 svelte: specifier: 5.45.6 version: 5.45.6 @@ -84,6 +87,9 @@ catalogs: vue: specifier: 3.5.22 version: 3.5.22 + zod: + specifier: ^4.0.0 + version: 4.1.12 zod-validation-error: specifier: ^4.0.1 version: 4.0.1 @@ -909,6 +915,9 @@ importers: kysely: specifier: 'catalog:' version: 0.28.8 + zod: + specifier: 'catalog:' + version: 4.1.12 devDependencies: '@types/better-sqlite3': specifier: 'catalog:' @@ -1034,6 +1043,9 @@ importers: uuid: specifier: ^11.0.5 version: 11.0.5 + zod: + specifier: 'catalog:' + version: 4.1.12 devDependencies: '@zenstackhq/typescript-config': specifier: workspace:* @@ -12438,7 +12450,7 @@ snapshots: eslint: 9.29.0(jiti@2.6.1) eslint-import-resolver-node: 0.3.9 eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint@9.29.0(jiti@2.6.1)))(eslint@9.29.0(jiti@2.6.1)) - eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@9.29.0(jiti@2.6.1)) + eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint@9.29.0(jiti@2.6.1)))(eslint@9.29.0(jiti@2.6.1)))(eslint@9.29.0(jiti@2.6.1)) eslint-plugin-jsx-a11y: 6.10.2(eslint@9.29.0(jiti@2.6.1)) eslint-plugin-react: 7.37.5(eslint@9.29.0(jiti@2.6.1)) eslint-plugin-react-hooks: 7.0.1(eslint@9.29.0(jiti@2.6.1)) @@ -12471,7 +12483,7 @@ snapshots: tinyglobby: 0.2.15 unrs-resolver: 1.11.1 optionalDependencies: - eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@9.29.0(jiti@2.6.1)) + eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint@9.29.0(jiti@2.6.1)))(eslint@9.29.0(jiti@2.6.1)))(eslint@9.29.0(jiti@2.6.1)) transitivePeerDependencies: - supports-color @@ -12486,7 +12498,7 @@ snapshots: transitivePeerDependencies: - supports-color - eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@9.29.0(jiti@2.6.1)): + eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint@9.29.0(jiti@2.6.1)))(eslint@9.29.0(jiti@2.6.1)))(eslint@9.29.0(jiti@2.6.1)): dependencies: '@rtsao/scc': 1.1.0 array-includes: 3.1.9 diff --git a/samples/orm/main.ts b/samples/orm/main.ts index 4b7f82ef1..3b9dd1294 100644 --- a/samples/orm/main.ts +++ b/samples/orm/main.ts @@ -21,18 +21,13 @@ async function main() { client.user.create({ data: { ...args }, }), - listPublicPosts: ({ client }) => - client.post.findMany({ - where: { - published: true, - }, - }), + listPublicPosts: ({}) => [], }, }).$use({ id: 'cost-logger', onQuery: async ({ model, operation, args, proceed }) => { const start = Date.now(); - const result = await proceed(args); + const result = await proceed(args as any); console.log(`[cost] ${model} ${operation} took ${Date.now() - start}ms`); return result; }, @@ -43,6 +38,8 @@ async function main() { await db.profile.deleteMany(); await db.user.deleteMany(); + db.user.findMany({ where: { id: '1' } }); + // create users and some posts const user1 = await db.user.create({ data: { diff --git a/samples/orm/package.json b/samples/orm/package.json index fd43fa6c4..b052aa9ea 100644 --- a/samples/orm/package.json +++ b/samples/orm/package.json @@ -18,7 +18,8 @@ "@zenstackhq/orm": "workspace:*", "@zenstackhq/plugin-policy": "workspace:*", "better-sqlite3": "catalog:", - "kysely": "catalog:" + "kysely": "catalog:", + "zod": "catalog:" }, "devDependencies": { "@types/better-sqlite3": "catalog:", diff --git a/tests/e2e/orm/client-api/find.test.ts b/tests/e2e/orm/client-api/find.test.ts index 0a881288e..2eddd2146 100644 --- a/tests/e2e/orm/client-api/find.test.ts +++ b/tests/e2e/orm/client-api/find.test.ts @@ -907,7 +907,7 @@ describe('Client find tests ', () => { // @ts-expect-error include: { author: { where: { email: user.email } } }, }), - ).rejects.toThrow(`Invalid find args`); + ).rejects.toThrow(`Invalid findFirst args`); // sorting let u = await client.user.findUniqueOrThrow({ diff --git a/tests/e2e/orm/plugin-infra/ext-query-args.test.ts b/tests/e2e/orm/plugin-infra/ext-query-args.test.ts new file mode 100644 index 000000000..7a5630d1a --- /dev/null +++ b/tests/e2e/orm/plugin-infra/ext-query-args.test.ts @@ -0,0 +1,273 @@ +import { CoreReadOperations, CoreWriteOperations, definePlugin, type ClientContract } from '@zenstackhq/orm'; +import { createTestClient } from '@zenstackhq/testtools'; +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; +import z from 'zod'; +import { schema } from './ext-query-args/schema'; + +describe('Plugin extended query args', () => { + let db: ClientContract; + + const cacheSchema = z.object({ + cache: z + .strictObject({ + ttl: z.number().min(0).optional(), + }) + .optional(), + }); + + const cacheBustSchema = z.object({ + cache: z.strictObject({ + bust: z.boolean().optional(), + }), + }); + + type CacheOptions = z.infer; + type CacheBustOptions = z.infer; + + beforeEach(async () => { + db = await createTestClient(schema); + await db.user.deleteMany(); + }); + + afterEach(async () => { + await db?.$disconnect(); + }); + + it('should allow extending all operations', async () => { + let gotTTL: number | undefined = undefined; + + const extDb = db.$use( + definePlugin< + typeof schema, + { + all: CacheOptions; + } + >({ + id: 'cache', + extQueryArgs: { + getValidationSchema: () => cacheSchema, + }, + + onQuery: async ({ args, proceed }) => { + if (args && 'cache' in args) { + gotTTL = (args as CacheOptions).cache?.ttl; + } + return proceed(args); + }, + }), + ); + + // cache is optional + const alice = await extDb.user.create({ data: { name: 'Alice' } }); + + // ttl is optional + const bob = await extDb.user.create({ data: { name: 'Bob' }, cache: {} }); + + gotTTL = undefined; + await expect(extDb.user.findMany({ cache: { ttl: 5000 } })).toResolveWithLength(2); + expect(gotTTL).toBe(5000); + + await expect(extDb.user.findMany({ cache: { ttl: -1 } })).rejects.toThrow('Too small'); + + // reject unrecognized keys in extended args + // @ts-expect-error + await expect(extDb.user.findMany({ cache: { x: 1 } })).rejects.toThrow('Unrecognized key'); + + // still reject invalid original args + // @ts-expect-error + await expect(extDb.user.findMany({ where: { foo: 'bar' } })).rejects.toThrow('Unrecognized key'); + // @ts-expect-error + await expect(extDb.user.findMany({ foo: 'bar' })).rejects.toThrow('Unrecognized key'); + // @ts-expect-error + await expect(extDb.user.findMany({ where: { id: 'abc' } })).rejects.toThrow('expected number'); + + // validate all other operations + + const cacheOption = { cache: { ttl: 1000 } } as const; + + // read operations + await expect(extDb.user.findUnique({ where: { id: 1 }, ...cacheOption })).toResolveTruthy(); + await expect(extDb.user.findUniqueOrThrow({ where: { id: 1 }, ...cacheOption })).toResolveTruthy(); + await expect(extDb.user.findFirst(cacheOption)).toResolveTruthy(); + await expect(extDb.user.findFirstOrThrow(cacheOption)).toResolveTruthy(); + await expect(extDb.user.count(cacheOption)).resolves.toBe(2); + await expect(extDb.user.exists(cacheOption)).resolves.toBe(true); + await expect( + extDb.user.aggregate({ + _count: true, + ...cacheOption, + }), + ).resolves.toHaveProperty('_count'); + await expect( + extDb.user.groupBy({ + by: ['id'], + _count: { + id: true, + }, + ...cacheOption, + }), + ).resolves.toHaveLength(2); + + // create operations + await expect(extDb.user.createMany({ data: [{ name: 'Charlie' }], ...cacheOption })).resolves.toHaveProperty( + 'count', + ); + await expect(extDb.user.createManyAndReturn({ data: [{ name: 'David' }], ...cacheOption })).toResolveWithLength( + 1, + ); + + // update operations + await expect( + extDb.user.update({ where: { id: alice.id }, data: { name: 'Alice Updated' }, ...cacheOption }), + ).toResolveTruthy(); + await expect( + extDb.user.updateMany({ where: { name: 'Bob' }, data: { name: 'Bob Updated' }, ...cacheOption }), + ).resolves.toHaveProperty('count'); + await expect( + extDb.user.updateManyAndReturn({ + where: { name: 'Charlie' }, + data: { name: 'Charlie Updated' }, + ...cacheOption, + }), + ).toResolveTruthy(); + await expect( + extDb.user.upsert({ + where: { id: 999 }, + create: { name: 'Eve' }, + update: { name: 'Eve Updated' }, + ...cacheOption, + }), + ).resolves.toMatchObject({ name: 'Eve' }); + + // delete operations + await expect(extDb.user.delete({ where: { id: bob.id }, ...cacheOption })).toResolveTruthy(); + await expect(extDb.user.deleteMany({ where: { name: 'David' }, ...cacheOption })).resolves.toHaveProperty( + 'count', + ); + + // validate transaction + await extDb.$transaction(async (tx) => { + await expect(tx.user.findMany(cacheOption)).toResolveTruthy(); + }); + + // validate $use + await expect(extDb.$use({ id: 'foo' }).user.findMany(cacheOption)).toResolveTruthy(); + + // validate $setOptions + await expect( + extDb.$setOptions({ ...extDb.$options, validateInput: false }).user.findMany(cacheOption), + ).toResolveTruthy(); + + // validate $setAuth + await expect(extDb.$setAuth({ id: 1 }).user.findMany(cacheOption)).toResolveTruthy(); + }); + + it('should allow extending specific operations', async () => { + const extDb = db.$use( + definePlugin< + typeof schema, + { + [Op in CoreReadOperations]: CacheOptions; + } + >({ + id: 'cache', + extQueryArgs: { + getValidationSchema: (operation) => { + if (!(CoreReadOperations as readonly string[]).includes(operation)) { + return undefined; + } + return cacheSchema; + }, + }, + }), + ); + + // "create" is not extended + // @ts-expect-error + await expect(extDb.user.create({ data: { name: 'Bob' }, cache: {} })).rejects.toThrow('Unrecognized key'); + + await extDb.user.create({ data: { name: 'Alice' } }); + + await expect(extDb.user.findMany({ cache: { ttl: 100 } })).toResolveWithLength(1); + await expect(extDb.user.count({ where: { name: 'Alice' }, cache: { ttl: 200 } })).resolves.toBe(1); + }); + + it('should allow different extensions for different operations', async () => { + let gotTTL: number | undefined = undefined; + let gotBust: boolean | undefined = undefined; + + const extDb = db.$use( + definePlugin< + typeof schema, + { + [Op in CoreReadOperations]: CacheOptions; + } & { + [Op in CoreWriteOperations]: CacheBustOptions; + } + >({ + id: 'cache', + extQueryArgs: { + getValidationSchema: (operation) => { + if ((CoreReadOperations as readonly string[]).includes(operation)) { + return cacheSchema; + } else if ((CoreWriteOperations as readonly string[]).includes(operation)) { + return cacheBustSchema; + } + return undefined; + }, + }, + + onQuery: async ({ args, proceed }) => { + if (args && 'cache' in args) { + gotTTL = (args as CacheOptions).cache?.ttl; + gotBust = (args as CacheBustOptions).cache?.bust; + } + return proceed(args); + }, + }), + ); + + gotBust = undefined; + await extDb.user.create({ data: { name: 'Alice' }, cache: { bust: true } }); + expect(gotBust).toBe(true); + + // ttl extension is not applied to "create" + // @ts-expect-error + await expect(extDb.user.create({ data: { name: 'Bob' }, cache: { ttl: 100 } })).rejects.toThrow( + 'Unrecognized key', + ); + + gotTTL = undefined; + await expect(extDb.user.findMany({ cache: { ttl: 5000 } })).toResolveWithLength(1); + expect(gotTTL).toBe(5000); + + // bust extension is not applied to "findMany" + // @ts-expect-error + await expect(extDb.user.findMany({ cache: { bust: true } })).rejects.toThrow('Unrecognized key'); + }); + + it('should isolate validation schemas between clients', async () => { + const extDb = db.$use( + definePlugin< + typeof schema, + { + all: CacheOptions; + } + >({ + id: 'cache', + extQueryArgs: { + getValidationSchema: () => cacheSchema, + }, + }), + ); + + // @ts-expect-error + await expect(db.user.findMany({ cache: { ttl: 1000 } })).rejects.toThrow('Unrecognized key'); + await expect(extDb.user.findMany({ cache: { ttl: 1000 } })).toResolveWithLength(0); + + // do it again to make sure cache is not shared + // @ts-expect-error + await expect(db.user.findMany({ cache: { ttl: 2000 } })).rejects.toThrow('Unrecognized key'); + await expect(extDb.user.findMany({ cache: { ttl: 2000 } })).toResolveWithLength(0); + }); +}); diff --git a/tests/e2e/orm/plugin-infra/ext-query-args/input.ts b/tests/e2e/orm/plugin-infra/ext-query-args/input.ts new file mode 100644 index 000000000..22bdbfa73 --- /dev/null +++ b/tests/e2e/orm/plugin-infra/ext-query-args/input.ts @@ -0,0 +1,31 @@ +////////////////////////////////////////////////////////////////////////////////////////////// +// DO NOT MODIFY THIS FILE // +// This file is automatically generated by ZenStack CLI and should not be manually updated. // +////////////////////////////////////////////////////////////////////////////////////////////// + +/* eslint-disable */ + +import { type SchemaType as $Schema } from "./schema"; +import type { FindManyArgs as $FindManyArgs, FindUniqueArgs as $FindUniqueArgs, FindFirstArgs as $FindFirstArgs, ExistsArgs as $ExistsArgs, CreateArgs as $CreateArgs, CreateManyArgs as $CreateManyArgs, CreateManyAndReturnArgs as $CreateManyAndReturnArgs, UpdateArgs as $UpdateArgs, UpdateManyArgs as $UpdateManyArgs, UpdateManyAndReturnArgs as $UpdateManyAndReturnArgs, UpsertArgs as $UpsertArgs, DeleteArgs as $DeleteArgs, DeleteManyArgs as $DeleteManyArgs, CountArgs as $CountArgs, AggregateArgs as $AggregateArgs, GroupByArgs as $GroupByArgs, WhereInput as $WhereInput, SelectInput as $SelectInput, IncludeInput as $IncludeInput, OmitInput as $OmitInput, QueryOptions as $QueryOptions } from "@zenstackhq/orm"; +import type { SimplifiedPlainResult as $Result, SelectIncludeOmit as $SelectIncludeOmit } from "@zenstackhq/orm"; +export type UserFindManyArgs = $FindManyArgs<$Schema, "User">; +export type UserFindUniqueArgs = $FindUniqueArgs<$Schema, "User">; +export type UserFindFirstArgs = $FindFirstArgs<$Schema, "User">; +export type UserExistsArgs = $ExistsArgs<$Schema, "User">; +export type UserCreateArgs = $CreateArgs<$Schema, "User">; +export type UserCreateManyArgs = $CreateManyArgs<$Schema, "User">; +export type UserCreateManyAndReturnArgs = $CreateManyAndReturnArgs<$Schema, "User">; +export type UserUpdateArgs = $UpdateArgs<$Schema, "User">; +export type UserUpdateManyArgs = $UpdateManyArgs<$Schema, "User">; +export type UserUpdateManyAndReturnArgs = $UpdateManyAndReturnArgs<$Schema, "User">; +export type UserUpsertArgs = $UpsertArgs<$Schema, "User">; +export type UserDeleteArgs = $DeleteArgs<$Schema, "User">; +export type UserDeleteManyArgs = $DeleteManyArgs<$Schema, "User">; +export type UserCountArgs = $CountArgs<$Schema, "User">; +export type UserAggregateArgs = $AggregateArgs<$Schema, "User">; +export type UserGroupByArgs = $GroupByArgs<$Schema, "User">; +export type UserWhereInput = $WhereInput<$Schema, "User">; +export type UserSelect = $SelectInput<$Schema, "User">; +export type UserInclude = $IncludeInput<$Schema, "User">; +export type UserOmit = $OmitInput<$Schema, "User">; +export type UserGetPayload, Options extends $QueryOptions<$Schema> = $QueryOptions<$Schema>> = $Result<$Schema, "User", Args, Options>; diff --git a/tests/e2e/orm/plugin-infra/ext-query-args/models.ts b/tests/e2e/orm/plugin-infra/ext-query-args/models.ts new file mode 100644 index 000000000..7a605bdbc --- /dev/null +++ b/tests/e2e/orm/plugin-infra/ext-query-args/models.ts @@ -0,0 +1,10 @@ +////////////////////////////////////////////////////////////////////////////////////////////// +// DO NOT MODIFY THIS FILE // +// This file is automatically generated by ZenStack CLI and should not be manually updated. // +////////////////////////////////////////////////////////////////////////////////////////////// + +/* eslint-disable */ + +import { type SchemaType as $Schema } from "./schema"; +import { type ModelResult as $ModelResult } from "@zenstackhq/orm"; +export type User = $ModelResult<$Schema, "User">; diff --git a/tests/e2e/orm/plugin-infra/ext-query-args/schema.ts b/tests/e2e/orm/plugin-infra/ext-query-args/schema.ts new file mode 100644 index 000000000..a8f0ffb86 --- /dev/null +++ b/tests/e2e/orm/plugin-infra/ext-query-args/schema.ts @@ -0,0 +1,38 @@ +////////////////////////////////////////////////////////////////////////////////////////////// +// DO NOT MODIFY THIS FILE // +// This file is automatically generated by ZenStack CLI and should not be manually updated. // +////////////////////////////////////////////////////////////////////////////////////////////// + +/* eslint-disable */ + +import { type SchemaDef, ExpressionUtils } from "@zenstackhq/orm/schema"; +export class SchemaType implements SchemaDef { + provider = { + type: "sqlite" + } as const; + models = { + User: { + name: "User", + fields: { + id: { + name: "id", + type: "Int", + id: true, + attributes: [{ name: "@id" }, { name: "@default", args: [{ name: "value", value: ExpressionUtils.call("autoincrement") }] }], + default: ExpressionUtils.call("autoincrement") + }, + name: { + name: "name", + type: "String" + } + }, + idFields: ["id"], + uniqueFields: { + id: { type: "Int" } + } + } + } as const; + authType = "User" as const; + plugins = {}; +} +export const schema = new SchemaType(); diff --git a/tests/e2e/orm/plugin-infra/ext-query-args/schema.zmodel b/tests/e2e/orm/plugin-infra/ext-query-args/schema.zmodel new file mode 100644 index 000000000..2e4b0dc03 --- /dev/null +++ b/tests/e2e/orm/plugin-infra/ext-query-args/schema.zmodel @@ -0,0 +1,8 @@ +datasource db { + provider = "sqlite" +} + +model User { + id Int @id @default(autoincrement()) + name String +} diff --git a/tests/e2e/orm/policy/basic-schema-read.test.ts b/tests/e2e/orm/policy/basic-schema-read.test.ts index 7464a38ce..d1bba66d6 100644 --- a/tests/e2e/orm/policy/basic-schema-read.test.ts +++ b/tests/e2e/orm/policy/basic-schema-read.test.ts @@ -23,7 +23,7 @@ describe('Read policy tests', () => { }); // anonymous auth context by default - const anonClient = client.$use(new PolicyPlugin()); + const anonClient = client.$use(new PolicyPlugin()); await expect(anonClient.user.findFirst()).toResolveNull(); const authClient = anonClient.$setAuth({ diff --git a/tests/e2e/package.json b/tests/e2e/package.json index c24650ba1..5022cf2e8 100644 --- a/tests/e2e/package.json +++ b/tests/e2e/package.json @@ -25,7 +25,8 @@ "kysely": "catalog:", "ulid": "^3.0.0", "uuid": "^11.0.5", - "cuid": "^3.0.0" + "cuid": "^3.0.0", + "zod": "catalog:" }, "devDependencies": { "@zenstackhq/cli": "workspace:*",