Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions packages/orm/src/client/client-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import { ZenStackQueryExecutor } from './executor/zenstack-query-executor';
import * as BuiltinFunctions from './functions';
import { SchemaDbPusher } from './helpers/schema-db-pusher';
import type { ClientOptions, ProceduresOptions } from './options';
import type { RuntimePlugin } from './plugin';
import type { AnyPlugin } from './plugin';
import { createZenStackPromise, type ZenStackPromise } from './promise';
import { ResultProcessor } from './result-processor';

Expand Down Expand Up @@ -293,8 +293,8 @@ export class ClientImpl {
await new SchemaDbPusher(this.schema, this.kysely).push();
}

$use(plugin: RuntimePlugin<any, any>) {
const newPlugins: RuntimePlugin<any, any>[] = [...(this.$options.plugins ?? []), plugin];
$use(plugin: AnyPlugin) {
const newPlugins: AnyPlugin[] = [...(this.$options.plugins ?? []), plugin];
const newOptions: ClientOptions<SchemaDef> = {
...this.options,
plugins: newPlugins,
Expand All @@ -308,7 +308,7 @@ export class ClientImpl {

$unuse(pluginId: string) {
// tsc perf
const newPlugins: RuntimePlugin<any, any>[] = [];
const newPlugins: AnyPlugin[] = [];
for (const plugin of this.options.plugins ?? []) {
if (plugin.id !== pluginId) {
newPlugins.push(plugin);
Expand All @@ -329,7 +329,7 @@ export class ClientImpl {
// tsc perf
const newOptions: ClientOptions<SchemaDef> = {
...this.options,
plugins: [] as RuntimePlugin<any, any>[],
plugins: [] as AnyPlugin[],
};
const newClient = new ClientImpl(this.schema, newOptions, this);
// create a new validator to have a fresh schema cache, because plugins may
Expand Down Expand Up @@ -408,6 +408,16 @@ function createClientProxy(client: ClientImpl): ClientImpl {
return new Proxy(client, {
get: (target, prop, receiver) => {
if (typeof prop === 'string' && prop.startsWith('$')) {
// Check for plugin-provided members (search in reverse order so later plugins win)
const plugins = target.$options.plugins ?? [];
for (let i = plugins.length - 1; i >= 0; i--) {
const plugin = plugins[i];
const clientMembers = plugin?.client as Record<string, unknown> | undefined;
if (clientMembers && prop in clientMembers) {
return clientMembers[prop];
}
}
// Fall through to built-in $ methods
return Reflect.get(target, prop, receiver);
}

Expand Down
64 changes: 47 additions & 17 deletions packages/orm/src/client/contract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,15 @@ import type {
UpdateManyArgs,
UpsertArgs,
} from './crud-types';
import type { CoreCrudOperations } from './crud/operations/base';
import type {
CoreCreateOperations,
CoreCrudOperations,
CoreDeleteOperations,
CoreReadOperations,
CoreUpdateOperations,
} from './crud/operations/base';
import type { ClientOptions, QueryOptions, ToQueryOptions } from './options';
import type { ExtQueryArgsBase, RuntimePlugin } from './plugin';
import type { ExtClientMembersBase, ExtQueryArgsBase, RuntimePlugin } from './plugin';
import type { ZenStackPromise } from './promise';
import type { ToKysely } from './query-builder';

Expand All @@ -51,11 +57,26 @@ type TransactionUnsupportedMethods = (typeof TRANSACTION_UNSUPPORTED_METHODS)[nu
/**
* Extracts extended query args for a specific operation.
*/
type ExtractExtQueryArgs<ExtQueryArgs, Operation extends CoreCrudOperations> = Operation extends keyof ExtQueryArgs
? NonNullable<ExtQueryArgs[Operation]>
: 'all' extends keyof ExtQueryArgs
? NonNullable<ExtQueryArgs['all']>
: {};
type ExtractExtQueryArgs<ExtQueryArgs, Operation extends CoreCrudOperations> = (Operation extends keyof ExtQueryArgs
? ExtQueryArgs[Operation]
: {}) &
('$create' extends keyof ExtQueryArgs
? Operation extends CoreCreateOperations
? ExtQueryArgs['$create']
: {}
: {}) &
('$read' extends keyof ExtQueryArgs ? (Operation extends CoreReadOperations ? ExtQueryArgs['$read'] : {}) : {}) &
('$update' extends keyof ExtQueryArgs
? Operation extends CoreUpdateOperations
? ExtQueryArgs['$update']
: {}
: {}) &
('$delete' extends keyof ExtQueryArgs
? Operation extends CoreDeleteOperations
? ExtQueryArgs['$delete']
: {}
: {}) &
('$all' extends keyof ExtQueryArgs ? ExtQueryArgs['$all'] : {});

/**
* Transaction isolation levels.
Expand All @@ -75,6 +96,7 @@ export type ClientContract<
Schema extends SchemaDef,
Options extends ClientOptions<Schema> = ClientOptions<Schema>,
ExtQueryArgs extends ExtQueryArgsBase = {},
ExtClientMembers extends ExtClientMembersBase = {},
> = {
/**
* The schema definition.
Expand Down Expand Up @@ -132,7 +154,7 @@ export type ClientContract<
/**
* Sets the current user identity.
*/
$setAuth(auth: AuthType<Schema> | undefined): ClientContract<Schema, Options, ExtQueryArgs>;
$setAuth(auth: AuthType<Schema> | undefined): ClientContract<Schema, Options, ExtQueryArgs, ExtClientMembers>;

/**
* Returns a new client with new options applied.
Expand All @@ -141,15 +163,17 @@ export type ClientContract<
* const dbNoValidation = db.$setOptions({ ...db.$options, validateInput: false });
* ```
*/
$setOptions<Options extends ClientOptions<Schema>>(options: Options): ClientContract<Schema, Options, ExtQueryArgs>;
$setOptions<NewOptions extends ClientOptions<Schema>>(
options: NewOptions,
): ClientContract<Schema, NewOptions, ExtQueryArgs, ExtClientMembers>;

/**
* Returns a new client enabling/disabling input validations expressed with attributes like
* `@email`, `@regex`, `@@validate`, etc.
*
* @deprecated Use {@link $setOptions} instead.
*/
$setInputValidation(enable: boolean): ClientContract<Schema, Options, ExtQueryArgs>;
$setInputValidation(enable: boolean): ClientContract<Schema, Options, ExtQueryArgs, ExtClientMembers>;

/**
* The Kysely query builder instance.
Expand All @@ -165,7 +189,7 @@ export type ClientContract<
* Starts an interactive transaction.
*/
$transaction<T>(
callback: (tx: TransactionClientContract<Schema, Options, ExtQueryArgs>) => Promise<T>,
callback: (tx: TransactionClientContract<Schema, Options, ExtQueryArgs, ExtClientMembers>) => Promise<T>,
options?: { isolationLevel?: TransactionIsolationLevel },
): Promise<T>;

Expand All @@ -180,14 +204,18 @@ export type ClientContract<
/**
* Returns a new client with the specified plugin installed.
*/
$use<PluginSchema extends SchemaDef = Schema, PluginExtQueryArgs extends ExtQueryArgsBase = {}>(
plugin: RuntimePlugin<PluginSchema, PluginExtQueryArgs>,
): ClientContract<Schema, Options, ExtQueryArgs & PluginExtQueryArgs>;
$use<
PluginSchema extends SchemaDef = Schema,
PluginExtQueryArgs extends ExtQueryArgsBase = {},
PluginExtClientMembers extends ExtClientMembersBase = {},
>(
plugin: RuntimePlugin<PluginSchema, PluginExtQueryArgs, PluginExtClientMembers>,
): ClientContract<Schema, Options, ExtQueryArgs & PluginExtQueryArgs, ExtClientMembers & PluginExtClientMembers>;

/**
* Returns a new client with the specified plugin removed.
*/
$unuse(pluginId: string): ClientContract<Schema, Options, ExtQueryArgs>;
$unuse(pluginId: string): ClientContract<Schema, Options, ExtQueryArgs, ExtClientMembers>;

/**
* Returns a new client with all plugins removed.
Expand Down Expand Up @@ -216,7 +244,8 @@ export type ClientContract<
ToQueryOptions<Options>,
ExtQueryArgs
>;
} & ProcedureOperations<Schema>;
} & ProcedureOperations<Schema> &
ExtClientMembers;

/**
* The contract for a client in a transaction.
Expand All @@ -225,7 +254,8 @@ export type TransactionClientContract<
Schema extends SchemaDef,
Options extends ClientOptions<Schema>,
ExtQueryArgs extends ExtQueryArgsBase,
> = Omit<ClientContract<Schema, Options, ExtQueryArgs>, TransactionUnsupportedMethods>;
ExtClientMembers extends ExtClientMembersBase,
> = Omit<ClientContract<Schema, Options, ExtQueryArgs, ExtClientMembers>, TransactionUnsupportedMethods>;

export type ProcedureOperations<Schema extends SchemaDef> =
Schema['procedures'] extends Record<string, ProcedureDef>
Expand Down
30 changes: 30 additions & 0 deletions packages/orm/src/client/crud/operations/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,36 @@ export const CoreWriteOperations = [
*/
export type CoreWriteOperations = (typeof CoreWriteOperations)[number];

/**
* List of core create operations.
*/
export const CoreCreateOperations = ['create', 'createMany', 'createManyAndReturn', 'upsert'] as const;

/**
* List of core create operations.
*/
export type CoreCreateOperations = (typeof CoreCreateOperations)[number];

/**
* List of core update operations.
*/
export const CoreUpdateOperations = ['update', 'updateMany', 'updateManyAndReturn', 'upsert'] as const;

/**
* List of core update operations.
*/
export type CoreUpdateOperations = (typeof CoreUpdateOperations)[number];

/**
* List of core delete operations.
*/
export const CoreDeleteOperations = ['delete', 'deleteMany'] as const;

/**
* List of core delete operations.
*/
export type CoreDeleteOperations = (typeof CoreDeleteOperations)[number];

/**
* List of all CRUD operations, including 'orThrow' variants.
*/
Expand Down
84 changes: 81 additions & 3 deletions packages/orm/src/client/crud/validator/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import {
type UpsertArgs,
} from '../../crud-types';
import { createInternalError, createInvalidInputError } from '../../errors';
import type { AnyPlugin } from '../../plugin';
import {
fieldHasDefaultValue,
getDiscriminatorField,
Expand All @@ -46,7 +47,13 @@ import {
requireField,
requireModel,
} from '../../query-utils';
import type { CoreCrudOperations } from '../operations/base';
import {
CoreCreateOperations,
CoreDeleteOperations,
CoreReadOperations,
CoreUpdateOperations,
type CoreCrudOperations,
} from '../operations/base';
import {
addBigIntValidation,
addCustomValidation,
Expand Down Expand Up @@ -365,8 +372,8 @@ export class InputValidator<Schema extends SchemaDef> {
private mergePluginArgsSchema(schema: ZodObject, operation: CoreCrudOperations) {
let result = schema;
for (const plugin of this.options.plugins ?? []) {
if (plugin.extQueryArgs) {
const pluginSchema = plugin.extQueryArgs.getValidationSchema(operation);
if (plugin.queryArgs) {
const pluginSchema = this.getPluginExtQueryArgsSchema(plugin, operation);
if (pluginSchema) {
result = result.extend(pluginSchema.shape);
}
Expand All @@ -375,6 +382,77 @@ export class InputValidator<Schema extends SchemaDef> {
return result.strict();
}

private getPluginExtQueryArgsSchema(plugin: AnyPlugin, operation: string): ZodObject | undefined {
if (!plugin.queryArgs) {
return undefined;
}

let result: ZodType | undefined;

if (operation in plugin.queryArgs && plugin.queryArgs[operation]) {
// most specific operation takes highest precedence
result = plugin.queryArgs[operation];
} else if (operation === 'upsert') {
// upsert is special: it's in both CoreCreateOperations and CoreUpdateOperations
// so we need to merge both $create and $update schemas to match the type system
const createSchema =
'$create' in plugin.queryArgs && plugin.queryArgs['$create'] ? plugin.queryArgs['$create'] : undefined;
const updateSchema =
'$update' in plugin.queryArgs && plugin.queryArgs['$update'] ? plugin.queryArgs['$update'] : undefined;

if (createSchema && updateSchema) {
invariant(
createSchema instanceof z.ZodObject,
'Plugin extended query args schema must be a Zod object',
);
invariant(
updateSchema instanceof z.ZodObject,
'Plugin extended query args schema must be a Zod object',
);
// merge both schemas (combines their properties)
result = createSchema.extend(updateSchema.shape);
} else if (createSchema) {
result = createSchema;
} else if (updateSchema) {
result = updateSchema;
}
} else if (
// then comes grouped operations: $create, $read, $update, $delete
CoreCreateOperations.includes(operation as CoreCreateOperations) &&
'$create' in plugin.queryArgs &&
plugin.queryArgs['$create']
) {
result = plugin.queryArgs['$create'];
} else if (
CoreReadOperations.includes(operation as CoreReadOperations) &&
'$read' in plugin.queryArgs &&
plugin.queryArgs['$read']
) {
result = plugin.queryArgs['$read'];
} else if (
CoreUpdateOperations.includes(operation as CoreUpdateOperations) &&
'$update' in plugin.queryArgs &&
plugin.queryArgs['$update']
) {
result = plugin.queryArgs['$update'];
} else if (
CoreDeleteOperations.includes(operation as CoreDeleteOperations) &&
'$delete' in plugin.queryArgs &&
plugin.queryArgs['$delete']
) {
result = plugin.queryArgs['$delete'];
} else if ('$all' in plugin.queryArgs && plugin.queryArgs['$all']) {
// finally comes $all
result = plugin.queryArgs['$all'];
}

invariant(
result === undefined || result instanceof z.ZodObject,
'Plugin extended query args schema must be a Zod object',
);
return result;
}

// #region Find

private makeFindSchema(model: string, operation: CoreCrudOperations) {
Expand Down
3 changes: 3 additions & 0 deletions packages/orm/src/client/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@ export { BaseCrudDialect } from './crud/dialects/base-dialect';
export {
AllCrudOperations,
AllReadOperations,
CoreCreateOperations,
CoreCrudOperations,
CoreDeleteOperations,
CoreReadOperations,
CoreUpdateOperations,
CoreWriteOperations,
} from './crud/operations/base';
export { InputValidator } from './crud/validator';
Expand Down
4 changes: 2 additions & 2 deletions packages/orm/src/client/options.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import type { PrependParameter } from '../utils/type-utils';
import type { ClientContract, CRUD_EXT } from './contract';
import type { GetProcedureNames, ProcedureHandlerFunc } from './crud-types';
import type { BaseCrudDialect } from './crud/dialects/base-dialect';
import type { RuntimePlugin } from './plugin';
import type { AnyPlugin } from './plugin';
import type { ToKyselySchema } from './query-builder';

export type ZModelFunctionContext<Schema extends SchemaDef> = {
Expand Down Expand Up @@ -59,7 +59,7 @@ export type ClientOptions<Schema extends SchemaDef> = {
/**
* Plugins.
*/
plugins?: RuntimePlugin<any, any>[];
plugins?: AnyPlugin[];

/**
* Logging configuration.
Expand Down
Loading