diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 053ae9ff..632042c5 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -31,10 +31,23 @@ jobs: ports: - 5432:5432 + mysql: + image: mysql:8.4 + env: + MYSQL_ROOT_PASSWORD: mysql + ports: + - 3306:3306 + # Set health checks to wait until mysql has started + options: >- + --health-cmd="mysqladmin ping --silent" + --health-interval=10s + --health-timeout=5s + --health-retries=3 + strategy: matrix: node-version: [22.x] - provider: [sqlite, postgresql] + provider: [sqlite, postgresql, mysql] steps: - name: Checkout @@ -81,5 +94,9 @@ jobs: - name: Lint run: pnpm run lint + - name: Set MySQL max_connections + run: | + mysql -h 127.0.0.1 -uroot -pmysql -e "SET GLOBAL max_connections=500;" + - name: Test run: TEST_DB_PROVIDER=${{ matrix.provider }} pnpm run test diff --git a/TODO.md b/TODO.md index 7c7a767c..d36f9699 100644 --- a/TODO.md +++ b/TODO.md @@ -18,7 +18,6 @@ - [ ] ZModel - [x] Import - [ ] View support - - [ ] Datasource provider-scoped attributes - [ ] ORM - [x] Create - [x] Input validation @@ -72,7 +71,7 @@ - [x] Query builder API - [x] Computed fields - [x] Plugin - - [ ] Custom procedures + - [x] Custom procedures - [ ] Misc - [x] JSDoc for CRUD methods - [x] Cache validation schemas @@ -110,4 +109,4 @@ - [x] SQLite - [x] PostgreSQL - [x] Multi-schema - - [ ] MySQL + - [x] MySQL diff --git a/package.json b/package.json index 639198d0..ba1c7479 100644 --- a/package.json +++ b/package.json @@ -9,8 +9,9 @@ "watch": "turbo run watch build", "lint": "turbo run lint", "test": "turbo run test", - "test:all": "pnpm run test:sqlite && pnpm run test:pg", + "test:all": "pnpm run test:sqlite && pnpm run test:pg && pnpm run test:mysql", "test:pg": "TEST_DB_PROVIDER=postgresql turbo run test", + "test:mysql": "TEST_DB_PROVIDER=mysql turbo run test", "test:sqlite": "TEST_DB_PROVIDER=sqlite turbo run test", "test:coverage": "vitest run --coverage", "format": "prettier --write \"**/*.{ts,tsx,md}\"", diff --git a/packages/language/src/constants.ts b/packages/language/src/constants.ts index a39b4913..9fd79bf7 100644 --- a/packages/language/src/constants.ts +++ b/packages/language/src/constants.ts @@ -1,14 +1,7 @@ /** * Supported db providers */ -export const SUPPORTED_PROVIDERS = [ - 'sqlite', - 'postgresql', - // TODO: other providers - // 'mysql', - // 'sqlserver', - // 'cockroachdb', -]; +export const SUPPORTED_PROVIDERS = ['sqlite', 'postgresql', 'mysql']; /** * All scalar types @@ -41,3 +34,8 @@ export enum ExpressionContext { ValidationRule = 'ValidationRule', Index = 'Index', } + +/** + * Database providers that support list field types. + */ +export const DB_PROVIDERS_SUPPORTING_LIST_TYPE = ['postgresql']; diff --git a/packages/language/src/document.ts b/packages/language/src/document.ts index 17146f85..9642e61d 100644 --- a/packages/language/src/document.ts +++ b/packages/language/src/document.ts @@ -1,3 +1,4 @@ +import { invariant } from '@zenstackhq/common-helpers'; import { isAstNode, TextDocument, @@ -10,10 +11,18 @@ import { import fs from 'node:fs'; import path from 'node:path'; import { fileURLToPath } from 'node:url'; -import { isDataSource, type Model } from './ast'; -import { STD_LIB_MODULE_NAME } from './constants'; +import { isDataModel, isDataSource, type Model } from './ast'; +import { DB_PROVIDERS_SUPPORTING_LIST_TYPE, STD_LIB_MODULE_NAME } from './constants'; import { createZModelServices } from './module'; -import { getDataModelAndTypeDefs, getDocument, hasAttribute, resolveImport, resolveTransitiveImports } from './utils'; +import { + getAllFields, + getDataModelAndTypeDefs, + getDocument, + getLiteral, + hasAttribute, + resolveImport, + resolveTransitiveImports, +} from './utils'; import type { ZModelFormatter } from './zmodel-formatter'; /** @@ -207,6 +216,24 @@ function validationAfterImportMerge(model: Model) { if (authDecls.length > 1) { errors.push('Validation error: Multiple `@@auth` declarations are not allowed'); } + + // check for usages incompatible with the datasource provider + const provider = getDataSourceProvider(model); + invariant(provider !== undefined, 'Datasource provider should be defined at this point'); + + for (const decl of model.declarations.filter(isDataModel)) { + const fields = getAllFields(decl, true); + for (const field of fields) { + if (field.type.array && !isDataModel(field.type.reference?.ref)) { + if (!DB_PROVIDERS_SUPPORTING_LIST_TYPE.includes(provider)) { + errors.push( + `Validation error: List type is not supported for "${provider}" provider (model: "${decl.name}", field: "${field.name}")`, + ); + } + } + } + } + return errors; } @@ -226,3 +253,15 @@ export async function formatDocument(content: string) { const edits = await formatter.formatDocument(document, { options, textDocument: identifier }); return TextDocument.applyEdits(document.textDocument, edits); } + +function getDataSourceProvider(model: Model) { + const dataSource = model.declarations.find(isDataSource); + if (!dataSource) { + return undefined; + } + const provider = dataSource?.fields.find((f) => f.name === 'provider'); + if (!provider) { + return undefined; + } + return getLiteral(provider.value); +} diff --git a/packages/language/src/validators/datamodel-validator.ts b/packages/language/src/validators/datamodel-validator.ts index a1caba5d..6c5d18ff 100644 --- a/packages/language/src/validators/datamodel-validator.ts +++ b/packages/language/src/validators/datamodel-validator.ts @@ -5,20 +5,16 @@ import { ArrayExpr, DataField, DataModel, - Model, ReferenceExpr, TypeDef, isDataModel, - isDataSource, isEnum, - isModel, isStringLiteral, isTypeDef, } from '../generated/ast'; import { getAllAttributes, getAllFields, - getLiteral, getModelIdFields, getModelUniqueFields, getUniqueFields, @@ -105,13 +101,6 @@ export default class DataModelValidator implements AstValidator { accept('error', 'Unsupported type argument must be a string literal', { node: field.type.unsupported }); } - if (field.type.array && !isDataModel(field.type.reference?.ref)) { - const provider = this.getDataSourceProvider(AstUtils.getContainerOfType(field, isModel)!); - if (provider === 'sqlite') { - accept('error', `List type is not supported for "${provider}" provider.`, { node: field.type }); - } - } - field.attributes.forEach((attr) => validateAttributeApplication(attr, accept)); if (isTypeDef(field.type.reference?.ref)) { @@ -121,18 +110,6 @@ export default class DataModelValidator implements AstValidator { } } - private getDataSourceProvider(model: Model) { - const dataSource = model.declarations.find(isDataSource); - if (!dataSource) { - return undefined; - } - const provider = dataSource?.fields.find((f) => f.name === 'provider'); - if (!provider) { - return undefined; - } - return getLiteral(provider.value); - } - private validateAttributes(dm: DataModel, accept: ValidationAcceptor) { getAllAttributes(dm).forEach((attr) => validateAttributeApplication(attr, accept, dm)); } diff --git a/packages/orm/package.json b/packages/orm/package.json index ccd94a9c..eb5ee668 100644 --- a/packages/orm/package.json +++ b/packages/orm/package.json @@ -56,6 +56,16 @@ "default": "./dist/dialects/postgres.cjs" } }, + "./dialects/mysql": { + "import": { + "types": "./dist/dialects/mysql.d.ts", + "default": "./dist/dialects/mysql.js" + }, + "require": { + "types": "./dist/dialects/mysql.d.cts", + "default": "./dist/dialects/mysql.cjs" + } + }, "./dialects/sql.js": { "import": { "types": "./dist/dialects/sql.js.d.ts", @@ -100,6 +110,7 @@ "peerDependencies": { "better-sqlite3": "catalog:", "pg": "catalog:", + "mysql2": "catalog:", "sql.js": "catalog:", "zod": "catalog:" }, @@ -110,6 +121,9 @@ "pg": { "optional": true }, + "mysql2": { + "optional": true + }, "sql.js": { "optional": true } diff --git a/packages/orm/src/client/client-impl.ts b/packages/orm/src/client/client-impl.ts index 719f90f3..fc8f92c7 100644 --- a/packages/orm/src/client/client-impl.ts +++ b/packages/orm/src/client/client-impl.ts @@ -31,7 +31,7 @@ import { FindOperationHandler } from './crud/operations/find'; import { GroupByOperationHandler } from './crud/operations/group-by'; import { UpdateOperationHandler } from './crud/operations/update'; import { InputValidator } from './crud/validator'; -import { createConfigError, createNotFoundError } from './errors'; +import { createConfigError, createNotFoundError, createNotSupportedError } from './errors'; import { ZenStackDriver } from './executor/zenstack-driver'; import { ZenStackQueryExecutor } from './executor/zenstack-query-executor'; import * as BuiltinFunctions from './functions'; @@ -564,6 +564,11 @@ function createModelCrudHandler( }, createManyAndReturn: (args: unknown) => { + if (client.$schema.provider.type === 'mysql') { + throw createNotSupportedError( + '"createManyAndReturn" is not supported by "mysql" provider. Use "createMany" or multiple "create" calls instead.', + ); + } return createPromise( 'createManyAndReturn', 'createManyAndReturn', @@ -594,6 +599,11 @@ function createModelCrudHandler( }, updateManyAndReturn: (args: unknown) => { + if (client.$schema.provider.type === 'mysql') { + throw createNotSupportedError( + '"updateManyAndReturn" is not supported by "mysql" provider. Use "updateMany" or multiple "update" calls instead.', + ); + } return createPromise( 'updateManyAndReturn', 'updateManyAndReturn', diff --git a/packages/orm/src/client/contract.ts b/packages/orm/src/client/contract.ts index 945f3645..7492c02b 100644 --- a/packages/orm/src/client/contract.ts +++ b/packages/orm/src/client/contract.ts @@ -306,6 +306,77 @@ export type AllModelOperations< Model extends GetModels, Options extends QueryOptions, ExtQueryArgs, +> = CommonModelOperations & + // provider-specific operations + (Schema['provider']['type'] extends 'mysql' + ? {} + : { + /** + * Creates multiple entities and returns them. + * @param args - create args. See {@link createMany} for input. Use + * `select` and `omit` to control the fields returned. + * @returns the created entities + * + * @example + * ```ts + * // create multiple entities and return selected fields + * await db.user.createManyAndReturn({ + * data: [ + * { name: 'Alex', email: 'alex@zenstack.dev' }, + * { name: 'John', email: 'john@zenstack.dev' } + * ], + * select: { id: true, email: true } + * }); + * ``` + */ + createManyAndReturn< + T extends CreateManyAndReturnArgs & + ExtractExtQueryArgs, + >( + args?: SelectSubset< + T, + CreateManyAndReturnArgs & ExtractExtQueryArgs + >, + ): ZenStackPromise[]>; + + /** + * Updates multiple entities and returns them. + * @param args - update args. Only scalar fields are allowed for data. + * @returns the updated entities + * + * @example + * ```ts + * // update many entities and return selected fields + * await db.user.updateManyAndReturn({ + * where: { email: { endsWith: '@zenstack.dev' } }, + * data: { role: 'ADMIN' }, + * select: { id: true, email: true } + * }); // result: `Array<{ id: string; email: string }>` + * + * // limit the number of updated entities + * await db.user.updateManyAndReturn({ + * where: { email: { endsWith: '@zenstack.dev' } }, + * data: { role: 'ADMIN' }, + * limit: 10 + * }); + * ``` + */ + updateManyAndReturn< + T extends UpdateManyAndReturnArgs & + ExtractExtQueryArgs, + >( + args: Subset< + T, + UpdateManyAndReturnArgs & ExtractExtQueryArgs + >, + ): ZenStackPromise[]>; + }); + +type CommonModelOperations< + Schema extends SchemaDef, + Model extends GetModels, + Options extends QueryOptions, + ExtQueryArgs, > = { /** * Returns a list of entities. @@ -517,33 +588,6 @@ export type AllModelOperations< args?: SelectSubset & ExtractExtQueryArgs>, ): ZenStackPromise; - /** - * Creates multiple entities and returns them. - * @param args - create args. See {@link createMany} for input. Use - * `select` and `omit` to control the fields returned. - * @returns the created entities - * - * @example - * ```ts - * // create multiple entities and return selected fields - * await db.user.createManyAndReturn({ - * data: [ - * { name: 'Alex', email: 'alex@zenstack.dev' }, - * { name: 'John', email: 'john@zenstack.dev' } - * ], - * select: { id: true, email: true } - * }); - * ``` - */ - createManyAndReturn< - T extends CreateManyAndReturnArgs & ExtractExtQueryArgs, - >( - args?: SelectSubset< - T, - CreateManyAndReturnArgs & ExtractExtQueryArgs - >, - ): ZenStackPromise[]>; - /** * Updates a uniquely identified entity. * @param args - update args. See {@link findMany} for how to control @@ -689,37 +733,6 @@ export type AllModelOperations< args: Subset & ExtractExtQueryArgs>, ): ZenStackPromise; - /** - * Updates multiple entities and returns them. - * @param args - update args. Only scalar fields are allowed for data. - * @returns the updated entities - * - * @example - * ```ts - * // update many entities and return selected fields - * await db.user.updateManyAndReturn({ - * where: { email: { endsWith: '@zenstack.dev' } }, - * data: { role: 'ADMIN' }, - * select: { id: true, email: true } - * }); // result: `Array<{ id: string; email: string }>` - * - * // limit the number of updated entities - * await db.user.updateManyAndReturn({ - * where: { email: { endsWith: '@zenstack.dev' } }, - * data: { role: 'ADMIN' }, - * limit: 10 - * }); - * ``` - */ - updateManyAndReturn< - T extends UpdateManyAndReturnArgs & ExtractExtQueryArgs, - >( - args: Subset< - T, - UpdateManyAndReturnArgs & ExtractExtQueryArgs - >, - ): ZenStackPromise[]>; - /** * Creates or updates an entity. * @param args - upsert args diff --git a/packages/orm/src/client/crud-types.ts b/packages/orm/src/client/crud-types.ts index 93f770c6..bfee7ff1 100644 --- a/packages/orm/src/client/crud-types.ts +++ b/packages/orm/src/client/crud-types.ts @@ -1079,12 +1079,15 @@ export type FindArgs< Collection extends boolean, AllowFilter extends boolean = true, > = (Collection extends true - ? SortAndTakeArgs & { - /** - * Distinct fields - */ - distinct?: OrArray>; - } + ? SortAndTakeArgs & + (ProviderSupportsDistinct extends true + ? { + /** + * Distinct fields. Only supported by providers that natively support SQL "DISTINCT ON". + */ + distinct?: OrArray>; + } + : {}) : {}) & (AllowFilter extends true ? FilterArgs : {}) & SelectIncludeOmit; @@ -2108,8 +2111,8 @@ type MapType = T extends keyof TypeM ? EnumValue : unknown; -// type ProviderSupportsDistinct = Schema['provider']['type'] extends 'postgresql' -// ? true -// : false; +type ProviderSupportsDistinct = Schema['provider']['type'] extends 'postgresql' + ? true + : false; // #endregion diff --git a/packages/orm/src/client/crud/dialects/base-dialect.ts b/packages/orm/src/client/crud/dialects/base-dialect.ts index f4bf56ff..515878d9 100644 --- a/packages/orm/src/client/crud/dialects/base-dialect.ts +++ b/packages/orm/src/client/crud/dialects/base-dialect.ts @@ -1,5 +1,5 @@ import { enumerate, invariant, isPlainObject } from '@zenstackhq/common-helpers'; -import type { Expression, ExpressionBuilder, ExpressionWrapper, SqlBool, ValueNode } from 'kysely'; +import type { AliasableExpression, Expression, ExpressionBuilder, ExpressionWrapper, SqlBool, ValueNode } from 'kysely'; import { expressionBuilder, sql, type SelectQueryBuilder } from 'kysely'; import { match, P } from 'ts-pattern'; import { AnyNullClass, DbNullClass, JsonNullClass } from '../../../common-types'; @@ -44,14 +44,63 @@ export abstract class BaseCrudDialect { protected readonly options: ClientOptions, ) {} - transformPrimitive(value: unknown, _type: BuiltinType, _forArrayField: boolean) { + // #region capability flags + + /** + * Whether the dialect supports updating with a limit on the number of updated rows. + */ + abstract get supportsUpdateWithLimit(): boolean; + + /** + * Whether the dialect supports deleting with a limit on the number of deleted rows. + */ + abstract get supportsDeleteWithLimit(): boolean; + + /** + * Whether the dialect supports DISTINCT ON. + */ + abstract get supportsDistinctOn(): boolean; + + /** + * Whether the dialect support inserting with `DEFAULT` as field value. + */ + abstract get supportsDefaultAsFieldValue(): boolean; + + /** + * Whether the dialect supports the RETURNING clause in INSERT/UPDATE/DELETE statements. + */ + abstract get supportsReturning(): boolean; + + /** + * Whether the dialect supports `INSERT INTO ... DEFAULT VALUES` syntax. + */ + abstract get supportsInsertDefaultValues(): boolean; + + /** + * How to perform insert ignore operation. + */ + abstract get insertIgnoreMethod(): 'onConflict' | 'ignore'; + + // #endregion + + // #region value transformation + + /** + * Transforms input value before sending to database. + */ + transformInput(value: unknown, _type: BuiltinType, _forArrayField: boolean) { return value; } + /** + * Transforms output value received from database. + */ transformOutput(value: unknown, _type: BuiltinType, _array: boolean) { return value; } + // #endregion + // #region common query builders buildSelectModel(model: string, modelAlias: string) { @@ -90,7 +139,7 @@ export abstract class BaseCrudDialect { result = this.buildSkipTake(result, skip, take); // orderBy - result = this.buildOrderBy(result, model, modelAlias, args.orderBy, negateOrderBy); + result = this.buildOrderBy(result, model, modelAlias, args.orderBy, negateOrderBy, take); // distinct if ('distinct' in args && (args as any).distinct) { @@ -160,8 +209,8 @@ export abstract class BaseCrudDialect { private buildCursorFilter( model: string, query: SelectQueryBuilder, - cursor: FindArgs, true>['cursor'], - orderBy: FindArgs, true>['orderBy'], + cursor: object, + orderBy: OrArray> | undefined, negateOrderBy: boolean, modelAlias: string, ) { @@ -394,51 +443,35 @@ export abstract class BaseCrudDialect { continue; } - switch (key) { - case 'some': { - result = this.and( - result, - this.eb( + const countSelect = (negate: boolean) => { + const filter = this.buildFilter(relationModel, relationFilterSelectAlias, subPayload); + return ( + this.eb + // the outer select is needed to avoid mysql's scope issue + .selectFrom( this.buildSelectModel(relationModel, relationFilterSelectAlias) .select(() => this.eb.fn.count(this.eb.lit(1)).as('$count')) .where(buildPkFkWhereRefs(this.eb)) - .where(() => this.buildFilter(relationModel, relationFilterSelectAlias, subPayload)), - '>', - 0, - ), - ); + .where(() => (negate ? this.eb.not(filter) : filter)) + .as('$sub'), + ) + .select('$count') + ); + }; + + switch (key) { + case 'some': { + result = this.and(result, this.eb(countSelect(false), '>', 0)); break; } case 'every': { - result = this.and( - result, - this.eb( - this.buildSelectModel(relationModel, relationFilterSelectAlias) - .select((eb1) => eb1.fn.count(eb1.lit(1)).as('$count')) - .where(buildPkFkWhereRefs(this.eb)) - .where(() => - this.eb.not(this.buildFilter(relationModel, relationFilterSelectAlias, subPayload)), - ), - '=', - 0, - ), - ); + result = this.and(result, this.eb(countSelect(true), '=', 0)); break; } case 'none': { - result = this.and( - result, - this.eb( - this.buildSelectModel(relationModel, relationFilterSelectAlias) - .select(() => this.eb.fn.count(this.eb.lit(1)).as('$count')) - .where(buildPkFkWhereRefs(this.eb)) - .where(() => this.buildFilter(relationModel, relationFilterSelectAlias, subPayload)), - '=', - 0, - ), - ); + result = this.and(result, this.eb(countSelect(false), '=', 0)); break; } } @@ -456,7 +489,7 @@ export abstract class BaseCrudDialect { continue; } - const value = this.transformPrimitive(_value, fieldType, !!fieldDef.array); + const value = this.transformInput(_value, fieldType, !!fieldDef.array); switch (key) { case 'equals': { @@ -550,7 +583,7 @@ export abstract class BaseCrudDialect { const path = filter.path; const jsonReceiver = this.buildJsonPathSelection(receiver, path); - const stringReceiver = this.eb.cast(jsonReceiver, 'text'); + const stringReceiver = this.castText(jsonReceiver); const mode = filter.mode ?? 'default'; invariant(mode === 'default' || mode === 'insensitive', 'Invalid JSON filter mode'); @@ -658,7 +691,7 @@ export abstract class BaseCrudDialect { const clauses: Expression[] = []; if (filter === null) { - return this.eb(receiver, '=', 'null'); + return this.eb(receiver, '=', this.transformInput(null, 'Json', false)); } invariant(filter && typeof filter === 'object', 'Typed JSON filter payload must be an object'); @@ -688,7 +721,7 @@ export abstract class BaseCrudDialect { let _receiver = fieldReceiver; if (fieldDef.type === 'String') { // trim quotes for string fields - _receiver = this.eb.fn('trim', [this.eb.cast(fieldReceiver, 'text'), sql.lit('"')]); + _receiver = this.trimTextQuotes(this.castText(fieldReceiver)); } clauses.push(this.buildPrimitiveFilter(_receiver, fieldDef, value)); } @@ -702,17 +735,24 @@ export abstract class BaseCrudDialect { if (value instanceof DbNullClass) { return this.eb(lhs, 'is', null); } else if (value instanceof JsonNullClass) { - return this.eb.and([this.eb(lhs, '=', 'null'), this.eb(lhs, 'is not', null)]); + return this.eb.and([ + this.eb(lhs, '=', this.transformInput(null, 'Json', false)), + this.eb(lhs, 'is not', null), + ]); } else if (value instanceof AnyNullClass) { // AnyNull matches both DB NULL and JSON null - return this.eb.or([this.eb(lhs, 'is', null), this.eb(lhs, '=', 'null')]); + return this.eb.or([this.eb(lhs, 'is', null), this.eb(lhs, '=', this.transformInput(null, 'Json', false))]); } else { - return this.buildLiteralFilter(lhs, 'Json', value); + return this.buildJsonEqualityFilter(lhs, value); } } + protected buildJsonEqualityFilter(lhs: Expression, rhs: unknown) { + return this.buildLiteralFilter(lhs, 'Json', rhs); + } + private buildLiteralFilter(lhs: Expression, type: BuiltinType, rhs: unknown) { - return this.eb(lhs, '=', rhs !== null && rhs !== undefined ? this.transformPrimitive(rhs, type, false) : rhs); + return this.eb(lhs, '=', rhs !== null && rhs !== undefined ? this.transformInput(rhs, type, false) : rhs); } private buildStandardFilter( @@ -869,7 +909,7 @@ export abstract class BaseCrudDialect { private buildStringLike(receiver: Expression, pattern: string, insensitive: boolean) { const { supportsILike } = this.getStringCasingBehavior(); const op = insensitive && supportsILike ? 'ilike' : 'like'; - return sql`${receiver} ${sql.raw(op)} ${sql.val(pattern)} escape '\\'`; + return sql`${receiver} ${sql.raw(op)} ${sql.val(pattern)} escape ${sql.val('\\')}`; } private prepStringCasing( @@ -895,7 +935,7 @@ export abstract class BaseCrudDialect { type, payload, fieldRef, - (value) => this.transformPrimitive(value, type, false), + (value) => this.transformInput(value, type, false), (value) => this.buildNumberFilter(fieldRef, type, value), ); return this.and(...conditions); @@ -906,7 +946,7 @@ export abstract class BaseCrudDialect { 'Boolean', payload, fieldRef, - (value) => this.transformPrimitive(value, 'Boolean', false), + (value) => this.transformInput(value, 'Boolean', false), (value) => this.buildBooleanFilter(fieldRef, value as BooleanFilter), true, ['equals', 'not'], @@ -919,7 +959,7 @@ export abstract class BaseCrudDialect { 'DateTime', payload, fieldRef, - (value) => this.transformPrimitive(value, 'DateTime', false), + (value) => this.transformInput(value, 'DateTime', false), (value) => this.buildDateTimeFilter(fieldRef, value as DateTimeFilter), true, ); @@ -931,7 +971,7 @@ export abstract class BaseCrudDialect { 'Bytes', payload, fieldRef, - (value) => this.transformPrimitive(value, 'Bytes', false), + (value) => this.transformInput(value, 'Bytes', false), (value) => this.buildBytesFilter(fieldRef, value as BytesFilter), true, ['equals', 'in', 'notIn', 'not'], @@ -958,6 +998,7 @@ export abstract class BaseCrudDialect { modelAlias: string, orderBy: OrArray, boolean, boolean>> | undefined, negated: boolean, + take: number | undefined, ) { if (!orderBy) { return query; @@ -980,7 +1021,7 @@ export abstract class BaseCrudDialect { // aggregations if (['_count', '_avg', '_sum', '_min', '_max'].includes(field)) { - invariant(value && typeof value === 'object', `invalid orderBy value for field "${field}"`); + invariant(typeof value === 'object', `invalid orderBy value for field "${field}"`); for (const [k, v] of Object.entries(value)) { invariant(v === 'asc' || v === 'desc', `invalid orderBy value for field "${field}"`); result = result.orderBy( @@ -991,22 +1032,6 @@ export abstract class BaseCrudDialect { continue; } - switch (field) { - case '_count': { - invariant(value && typeof value === 'object', 'invalid orderBy value for field "_count"'); - for (const [k, v] of Object.entries(value)) { - invariant(v === 'asc' || v === 'desc', `invalid orderBy value for field "${field}"`); - result = result.orderBy( - (eb) => eb.fn.count(buildFieldRef(model, k, modelAlias)), - this.negateSort(v, negated), - ); - } - continue; - } - default: - break; - } - const fieldDef = requireField(this.schema, model, field); if (!fieldDef.relation) { @@ -1014,19 +1039,18 @@ export abstract class BaseCrudDialect { if (value === 'asc' || value === 'desc') { result = result.orderBy(fieldRef, this.negateSort(value, negated)); } else if ( - value && typeof value === 'object' && 'nulls' in value && 'sort' in value && (value.sort === 'asc' || value.sort === 'desc') && (value.nulls === 'first' || value.nulls === 'last') ) { - result = result.orderBy(fieldRef, (ob) => { - const dir = this.negateSort(value.sort, negated); - ob = dir === 'asc' ? ob.asc() : ob.desc(); - ob = value.nulls === 'first' ? ob.nullsFirst() : ob.nullsLast(); - return ob; - }); + result = this.buildOrderByField( + result, + fieldRef, + this.negateSort(value.sort, negated), + value.nulls, + ); } } else { // order by relation @@ -1069,7 +1093,7 @@ export abstract class BaseCrudDialect { ), ); }); - result = this.buildOrderBy(result, relationModel, joinAlias, value, negated); + result = this.buildOrderBy(result, relationModel, joinAlias, value, negated, take); } } } @@ -1253,16 +1277,16 @@ export abstract class BaseCrudDialect { // #region utils - private negateSort(sort: SortOrder, negated: boolean) { + protected negateSort(sort: SortOrder, negated: boolean) { return negated ? (sort === 'asc' ? 'desc' : 'asc') : sort; } public true(): Expression { - return this.eb.lit(this.transformPrimitive(true, 'Boolean', false) as boolean); + return this.eb.lit(this.transformInput(true, 'Boolean', false) as boolean); } public false(): Expression { - return this.eb.lit(this.transformPrimitive(false, 'Boolean', false) as boolean); + return this.eb.lit(this.transformInput(false, 'Boolean', false) as boolean); } public isTrue(expression: Expression) { @@ -1388,37 +1412,32 @@ export abstract class BaseCrudDialect { /** * Builds an Kysely expression that returns a JSON object for the given key-value pairs. */ - abstract buildJsonObject(value: Record>): ExpressionWrapper; + abstract buildJsonObject(value: Record>): AliasableExpression; /** * Builds an Kysely expression that returns the length of an array. */ - abstract buildArrayLength(array: Expression): ExpressionWrapper; + abstract buildArrayLength(array: Expression): AliasableExpression; /** * Builds an array literal SQL string for the given values. */ - abstract buildArrayLiteralSQL(values: unknown[]): string; + abstract buildArrayLiteralSQL(values: unknown[]): AliasableExpression; /** - * Whether the dialect supports updating with a limit on the number of updated rows. - */ - abstract get supportsUpdateWithLimit(): boolean; - - /** - * Whether the dialect supports deleting with a limit on the number of deleted rows. + * Casts the given expression to an integer type. */ - abstract get supportsDeleteWithLimit(): boolean; + abstract castInt>(expression: T): T; /** - * Whether the dialect supports DISTINCT ON. + * Casts the given expression to a text type. */ - abstract get supportsDistinctOn(): boolean; + abstract castText>(expression: T): T; /** - * Whether the dialect support inserting with `DEFAULT` as field value. + * Trims double quotes from the start and end of a text expression. */ - abstract get supportInsertWithDefault(): boolean; + abstract trimTextQuotes>(expression: T): T; /** * Gets the SQL column type for the given field definition. @@ -1430,6 +1449,11 @@ export abstract class BaseCrudDialect { */ abstract getStringCasingBehavior(): { supportsILike: boolean; likeCaseSensitive: boolean }; + /** + * Builds a VALUES table and select all fields from it. + */ + abstract buildValuesTableSelect(fields: FieldDef[], rows: unknown[][]): SelectQueryBuilder; + /** * Builds a JSON path selection expression. */ @@ -1452,5 +1476,15 @@ export abstract class BaseCrudDialect { buildFilter: (elem: Expression) => Expression, ): Expression; + /** + * Builds an ORDER BY clause for a field with NULLS FIRST/LAST support. + */ + protected abstract buildOrderByField( + query: SelectQueryBuilder, + field: Expression, + sort: SortOrder, + nulls: 'first' | 'last', + ): SelectQueryBuilder; + // #endregion } diff --git a/packages/orm/src/client/crud/dialects/index.ts b/packages/orm/src/client/crud/dialects/index.ts index ede19cdd..fb9a7379 100644 --- a/packages/orm/src/client/crud/dialects/index.ts +++ b/packages/orm/src/client/crud/dialects/index.ts @@ -2,6 +2,7 @@ import { match } from 'ts-pattern'; import type { SchemaDef } from '../../../schema'; import type { ClientOptions } from '../../options'; import type { BaseCrudDialect } from './base-dialect'; +import { MySqlCrudDialect } from './mysql'; import { PostgresCrudDialect } from './postgresql'; import { SqliteCrudDialect } from './sqlite'; @@ -12,5 +13,6 @@ export function getCrudDialect( return match(schema.provider.type) .with('sqlite', () => new SqliteCrudDialect(schema, options)) .with('postgresql', () => new PostgresCrudDialect(schema, options)) + .with('mysql', () => new MySqlCrudDialect(schema, options)) .exhaustive(); } diff --git a/packages/orm/src/client/crud/dialects/lateral-join-dialect-base.ts b/packages/orm/src/client/crud/dialects/lateral-join-dialect-base.ts new file mode 100644 index 00000000..6bc1f887 --- /dev/null +++ b/packages/orm/src/client/crud/dialects/lateral-join-dialect-base.ts @@ -0,0 +1,291 @@ +import { invariant } from '@zenstackhq/common-helpers'; +import { type AliasableExpression, type Expression, type ExpressionBuilder, type SelectQueryBuilder } from 'kysely'; +import type { FieldDef, GetModels, SchemaDef } from '../../../schema'; +import { DELEGATE_JOINED_FIELD_PREFIX } from '../../constants'; +import type { FindArgs } from '../../crud-types'; +import { + buildJoinPairs, + getDelegateDescendantModels, + getManyToManyRelation, + isRelationField, + requireField, + requireIdFields, + requireModel, +} from '../../query-utils'; +import { BaseCrudDialect } from './base-dialect'; + +/** + * Base class for dialects that support lateral joins (MySQL and PostgreSQL). + * Contains common logic for building relation selections using lateral joins and JSON aggregation. + */ +export abstract class LateralJoinDialectBase extends BaseCrudDialect { + /** + * Builds an array aggregation expression. + */ + protected abstract buildArrayAgg(arg: Expression): AliasableExpression; + + override buildRelationSelection( + query: SelectQueryBuilder, + model: string, + relationField: string, + parentAlias: string, + payload: true | FindArgs, true>, + ): SelectQueryBuilder { + const relationResultName = `${parentAlias}$${relationField}`; + const joinedQuery = this.buildRelationJSON( + model, + query, + relationField, + parentAlias, + payload, + relationResultName, + ); + return joinedQuery.select(`${relationResultName}.$data as ${relationField}`); + } + + private buildRelationJSON( + model: string, + qb: SelectQueryBuilder, + relationField: string, + parentAlias: string, + payload: true | FindArgs, true>, + resultName: string, + ) { + const relationFieldDef = requireField(this.schema, model, relationField); + const relationModel = relationFieldDef.type as GetModels; + + return qb.leftJoinLateral( + (eb) => { + const relationSelectName = `${resultName}$sub`; + const relationModelDef = requireModel(this.schema, relationModel); + + let tbl: SelectQueryBuilder; + + if (this.canJoinWithoutNestedSelect(relationModelDef, payload)) { + // build join directly + tbl = this.buildModelSelect(relationModel, relationSelectName, payload, false); + + // parent join filter + tbl = this.buildRelationJoinFilter( + tbl, + model, + relationField, + relationModel, + relationSelectName, + parentAlias, + ); + } else { + // join with a nested query + tbl = eb.selectFrom(() => { + let subQuery = this.buildModelSelect(relationModel, `${relationSelectName}$t`, payload, true); + + // parent join filter + subQuery = this.buildRelationJoinFilter( + subQuery, + model, + relationField, + relationModel, + `${relationSelectName}$t`, + parentAlias, + ); + + if (typeof payload !== 'object' || payload.take === undefined) { + // force adding a limit otherwise the ordering is ignored by some databases + // during JSON array aggregation + subQuery = subQuery.limit(Number.MAX_SAFE_INTEGER); + } + + return subQuery.as(relationSelectName); + }); + } + + // select relation result + tbl = this.buildRelationObjectSelect( + relationModel, + relationSelectName, + relationFieldDef, + tbl, + payload, + resultName, + ); + + // add nested joins for each relation + tbl = this.buildRelationJoins(tbl, relationModel, relationSelectName, payload, resultName); + + // alias the join table + return tbl.as(resultName); + }, + (join) => join.onTrue(), + ); + } + + private buildRelationJoinFilter( + query: SelectQueryBuilder, + model: string, + relationField: string, + relationModel: GetModels, + relationModelAlias: string, + parentAlias: string, + ) { + const m2m = getManyToManyRelation(this.schema, model, relationField); + if (m2m) { + // many-to-many relation + const parentIds = requireIdFields(this.schema, model); + const relationIds = requireIdFields(this.schema, relationModel); + invariant(parentIds.length === 1, 'many-to-many relation must have exactly one id field'); + invariant(relationIds.length === 1, 'many-to-many relation must have exactly one id field'); + query = query.where((eb) => + eb( + eb.ref(`${relationModelAlias}.${relationIds[0]}`), + 'in', + eb + .selectFrom(m2m.joinTable) + .select(`${m2m.joinTable}.${m2m.otherFkName}`) + .whereRef(`${parentAlias}.${parentIds[0]}`, '=', `${m2m.joinTable}.${m2m.parentFkName}`), + ), + ); + } else { + const joinPairs = buildJoinPairs(this.schema, model, parentAlias, relationField, relationModelAlias); + query = query.where((eb) => + this.and(...joinPairs.map(([left, right]) => eb(this.eb.ref(left), '=', this.eb.ref(right)))), + ); + } + return query; + } + + private buildRelationObjectSelect( + relationModel: string, + relationModelAlias: string, + relationFieldDef: FieldDef, + qb: SelectQueryBuilder, + payload: true | FindArgs, true>, + parentResultName: string, + ) { + qb = qb.select((eb) => { + const objArgs = this.buildRelationObjectArgs( + relationModel, + relationModelAlias, + eb, + payload, + parentResultName, + ); + + if (relationFieldDef.array) { + return this.buildArrayAgg(this.buildJsonObject(objArgs)).as('$data'); + } else { + return this.buildJsonObject(objArgs).as('$data'); + } + }); + + return qb; + } + + private buildRelationObjectArgs( + relationModel: string, + relationModelAlias: string, + eb: ExpressionBuilder, + payload: true | FindArgs, true>, + parentResultName: string, + ) { + const relationModelDef = requireModel(this.schema, relationModel); + const objArgs: Record> = {}; + + const descendantModels = getDelegateDescendantModels(this.schema, relationModel); + if (descendantModels.length > 0) { + // select all JSONs built from delegate descendants + Object.assign( + objArgs, + ...descendantModels.map((subModel) => ({ + [`${DELEGATE_JOINED_FIELD_PREFIX}${subModel.name}`]: eb.ref( + `${DELEGATE_JOINED_FIELD_PREFIX}${subModel.name}`, + ), + })), + ); + } + + if (payload === true || !payload.select) { + // select all scalar fields except for omitted + const omit = typeof payload === 'object' ? payload.omit : undefined; + + Object.assign( + objArgs, + ...Object.entries(relationModelDef.fields) + .filter(([, value]) => !value.relation) + .filter(([name]) => !this.shouldOmitField(omit, relationModel, name)) + .map(([field]) => ({ + [field]: this.fieldRef(relationModel, field, relationModelAlias, false), + })), + ); + } else if (payload.select) { + // select specific fields + Object.assign( + objArgs, + ...Object.entries(payload.select) + .filter(([, value]) => value) + .map(([field, value]) => { + if (field === '_count') { + const subJson = this.buildCountJson( + relationModel as GetModels, + eb, + relationModelAlias, + value, + ); + return { [field]: subJson }; + } else { + const fieldDef = requireField(this.schema, relationModel, field); + const fieldValue = fieldDef.relation + ? // reference the synthesized JSON field + eb.ref(`${parentResultName}$${field}.$data`) + : // reference a plain field + this.fieldRef(relationModel, field, relationModelAlias, false); + return { [field]: fieldValue }; + } + }), + ); + } + + if (typeof payload === 'object' && payload.include && typeof payload.include === 'object') { + // include relation fields + + Object.assign( + objArgs, + ...Object.entries(payload.include) + .filter(([, value]) => value) + .map(([field]) => ({ + [field]: eb.ref(`${parentResultName}$${field}.$data`), + })), + ); + } + + return objArgs; + } + + private buildRelationJoins( + query: SelectQueryBuilder, + relationModel: string, + relationModelAlias: string, + payload: true | FindArgs, true>, + parentResultName: string, + ) { + let result = query; + if (typeof payload === 'object') { + const selectInclude = payload.include ?? payload.select; + if (selectInclude && typeof selectInclude === 'object') { + Object.entries(selectInclude) + .filter(([, value]) => value) + .filter(([field]) => isRelationField(this.schema, relationModel, field)) + .forEach(([field, value]) => { + result = this.buildRelationJSON( + relationModel, + result, + field, + relationModelAlias, + value, + `${parentResultName}$${field}`, + ); + }); + } + } + return result; + } +} diff --git a/packages/orm/src/client/crud/dialects/mysql.ts b/packages/orm/src/client/crud/dialects/mysql.ts new file mode 100644 index 00000000..e196323a --- /dev/null +++ b/packages/orm/src/client/crud/dialects/mysql.ts @@ -0,0 +1,379 @@ +import { invariant } from '@zenstackhq/common-helpers'; +import Decimal from 'decimal.js'; +import type { AliasableExpression, TableExpression } from 'kysely'; +import { + expressionBuilder, + sql, + type Expression, + type ExpressionWrapper, + type SelectQueryBuilder, + type SqlBool, +} from 'kysely'; +import { match } from 'ts-pattern'; +import { AnyNullClass, DbNullClass, JsonNullClass } from '../../../common-types'; +import type { BuiltinType, FieldDef, SchemaDef } from '../../../schema'; +import type { SortOrder } from '../../crud-types'; +import { createInternalError, createInvalidInputError, createNotSupportedError } from '../../errors'; +import type { ClientOptions } from '../../options'; +import { isTypeDef } from '../../query-utils'; +import { LateralJoinDialectBase } from './lateral-join-dialect-base'; + +export class MySqlCrudDialect extends LateralJoinDialectBase { + constructor(schema: Schema, options: ClientOptions) { + super(schema, options); + } + + override get provider() { + return 'mysql' as const; + } + + // #region capabilities + + override get supportsUpdateWithLimit(): boolean { + return true; + } + + override get supportsDeleteWithLimit(): boolean { + return true; + } + + override get supportsDistinctOn(): boolean { + return false; + } + + override get supportsReturning(): boolean { + return false; + } + + override get supportsInsertDefaultValues(): boolean { + return false; + } + + override get supportsDefaultAsFieldValue() { + return true; + } + + override get insertIgnoreMethod() { + return 'ignore' as const; + } + + // #endregion + + // #region value transformation + + override transformInput(value: unknown, type: BuiltinType, forArrayField: boolean): unknown { + if (value === undefined) { + return value; + } + + // Handle special null classes for JSON fields + if (value instanceof JsonNullClass) { + return this.eb.cast(sql.lit('null'), 'json'); + } else if (value instanceof DbNullClass) { + return null; + } else if (value instanceof AnyNullClass) { + invariant(false, 'should not reach here: AnyNull is not a valid input value'); + } + + if (isTypeDef(this.schema, type)) { + // type-def fields (regardless array or scalar) are stored as scalar `Json` and + // their input values need to be stringified if not already (i.e., provided in + // default values) + if (typeof value !== 'string') { + return this.transformInput(value, 'Json', forArrayField); + } else { + return value; + } + } else if (Array.isArray(value)) { + if (type === 'Json') { + // type-def arrays reach here + return JSON.stringify(value); + } else { + throw createNotSupportedError(`MySQL does not support array literals`); + } + } else { + return match(type) + .with('Boolean', () => (value ? 1 : 0)) // MySQL uses 1/0 for boolean like SQLite + .with('DateTime', () => { + // MySQL DATETIME format: 'YYYY-MM-DD HH:MM:SS.mmm' + if (value instanceof Date) { + // force UTC + return value.toISOString().replace('Z', '+00:00'); + } else if (typeof value === 'string') { + // parse and force UTC + return new Date(value).toISOString().replace('Z', '+00:00'); + } else { + return value; + } + }) + .with('Decimal', () => (value !== null ? value.toString() : value)) + .with('Json', () => { + return this.eb.cast(this.eb.val(JSON.stringify(value)), 'json'); + }) + .with('Bytes', () => + Buffer.isBuffer(value) ? value : value instanceof Uint8Array ? Buffer.from(value) : value, + ) + .otherwise(() => value); + } + } + + override transformOutput(value: unknown, type: BuiltinType, array: boolean) { + if (value === null || value === undefined) { + return value; + } + return match(type) + .with('Boolean', () => this.transformOutputBoolean(value)) + .with('DateTime', () => this.transformOutputDate(value)) + .with('Bytes', () => this.transformOutputBytes(value)) + .with('BigInt', () => this.transformOutputBigInt(value)) + .with('Decimal', () => this.transformDecimal(value)) + .otherwise(() => super.transformOutput(value, type, array)); + } + + private transformOutputBoolean(value: unknown) { + return !!value; + } + + private transformOutputBigInt(value: unknown) { + if (typeof value === 'bigint') { + return value; + } + invariant( + typeof value === 'string' || typeof value === 'number', + `Expected string or number, got ${typeof value}`, + ); + return BigInt(value); + } + + private transformDecimal(value: unknown) { + if (value instanceof Decimal) { + return value; + } + invariant( + typeof value === 'string' || typeof value === 'number' || value instanceof Decimal, + `Expected string, number or Decimal, got ${typeof value}`, + ); + return new Decimal(value); + } + + private transformOutputDate(value: unknown) { + if (typeof value === 'string') { + // MySQL DateTime columns are returned as strings (non-ISO but parsable as JS Date), + // convert to ISO Date by appending 'Z' if not present + return new Date(!value.endsWith('Z') ? value + 'Z' : value); + } else if (value instanceof Date) { + return value; + } else { + return value; + } + } + + private transformOutputBytes(value: unknown) { + return Buffer.isBuffer(value) ? Uint8Array.from(value) : value; + } + + // #endregion + + // #region other overrides + + protected buildArrayAgg(arg: Expression): AliasableExpression { + return this.eb.fn.coalesce(sql`JSON_ARRAYAGG(${arg})`, sql`JSON_ARRAY()`); + } + + override buildSkipTake( + query: SelectQueryBuilder, + skip: number | undefined, + take: number | undefined, + ) { + if (take !== undefined) { + query = query.limit(take); + } + if (skip !== undefined) { + query = query.offset(skip); + if (take === undefined) { + // MySQL requires offset to be used with limit + query = query.limit(Number.MAX_SAFE_INTEGER); + } + } + return query; + } + + override buildJsonObject(value: Record>) { + return this.eb.fn( + 'JSON_OBJECT', + Object.entries(value).flatMap(([key, value]) => [sql.lit(key), value]), + ); + } + + override castInt>(expression: T): T { + return this.eb.cast(expression, sql.raw('unsigned')) as unknown as T; + } + + override castText>(expression: T): T { + // Use utf8mb4 character set collation to match MySQL 8.0+ default and avoid + // collation conflicts when comparing with VALUES ROW columns + return sql`CAST(${expression} AS CHAR CHARACTER SET utf8mb4)` as unknown as T; + } + + override trimTextQuotes>(expression: T): T { + return sql`TRIM(BOTH ${sql.lit('"')} FROM ${expression})` as unknown as T; + } + + override buildArrayLength(array: Expression): AliasableExpression { + return this.eb.fn('JSON_LENGTH', [array]); + } + + override buildArrayLiteralSQL(_values: unknown[]): AliasableExpression { + throw new Error('MySQL does not support array literals'); + } + + protected override buildJsonEqualityFilter( + lhs: Expression, + rhs: unknown, + ): ExpressionWrapper { + // MySQL's JSON equality comparison is key-order sensitive, use bi-directional JSON_CONTAINS + // instead to achieve key-order insensitive comparison + return this.eb.and([ + this.eb.fn('JSON_CONTAINS', [lhs, this.eb.val(JSON.stringify(rhs))]), + this.eb.fn('JSON_CONTAINS', [this.eb.val(JSON.stringify(rhs)), lhs]), + ]); + } + + protected override buildJsonPathSelection(receiver: Expression, path: string | undefined) { + if (path) { + return this.eb.fn('JSON_EXTRACT', [receiver, this.eb.val(path)]); + } else { + return receiver; + } + } + + protected override buildJsonArrayFilter( + lhs: Expression, + operation: 'array_contains' | 'array_starts_with' | 'array_ends_with', + value: unknown, + ) { + return match(operation) + .with('array_contains', () => { + const v = Array.isArray(value) ? value : [value]; + return sql`JSON_CONTAINS(${lhs}, ${sql.val(JSON.stringify(v))})`; + }) + .with('array_starts_with', () => + this.eb( + this.eb.fn('JSON_EXTRACT', [lhs, this.eb.val('$[0]')]), + '=', + this.transformInput(value, 'Json', false), + ), + ) + .with('array_ends_with', () => + this.eb( + sql`JSON_EXTRACT(${lhs}, CONCAT('$[', JSON_LENGTH(${lhs}) - 1, ']'))`, + '=', + this.transformInput(value, 'Json', false), + ), + ) + .exhaustive(); + } + + protected override buildJsonArrayExistsPredicate( + receiver: Expression, + buildFilter: (elem: Expression) => Expression, + ) { + // MySQL doesn't have a direct json_array_elements, we need to use JSON_TABLE or a different approach + // For simplicity, we'll use EXISTS with a subquery that unnests the JSON array + return this.eb.exists( + this.eb + .selectFrom(sql`JSON_TABLE(${receiver}, '$[*]' COLUMNS(value JSON PATH '$'))`.as('$items')) + .select(this.eb.lit(1).as('$t')) + .where(buildFilter(this.eb.ref('$items.value'))), + ); + } + + override getFieldSqlType(fieldDef: FieldDef) { + // TODO: respect `@db.x` attributes + if (fieldDef.relation) { + throw createInternalError('Cannot get SQL type of a relation field'); + } + + let result: string; + + if (this.schema.enums?.[fieldDef.type]) { + // enums are treated as text/varchar + result = 'varchar(255)'; + } else { + result = match(fieldDef.type) + .with('String', () => 'varchar(255)') + .with('Boolean', () => 'tinyint(1)') // MySQL uses tinyint(1) for boolean + .with('Int', () => 'int') + .with('BigInt', () => 'bigint') + .with('Float', () => 'double') + .with('Decimal', () => 'decimal') + .with('DateTime', () => 'datetime') + .with('Bytes', () => 'blob') + .with('Json', () => 'json') + // fallback to text + .otherwise(() => 'text'); + } + + if (fieldDef.array) { + // MySQL stores arrays as JSON + result = 'json'; + } + + return result; + } + + override getStringCasingBehavior() { + // MySQL LIKE is case-insensitive by default (depends on collation), no ILIKE support + return { supportsILike: false, likeCaseSensitive: false }; + } + + override buildValuesTableSelect(fields: FieldDef[], rows: unknown[][]) { + const cols = rows[0]?.length ?? 0; + + if (fields.length !== cols) { + throw createInvalidInputError('Number of fields must match number of columns in each row'); + } + + // check all rows have the same length + for (const row of rows) { + if (row.length !== cols) { + throw createInvalidInputError('All rows must have the same number of columns'); + } + } + + // build final alias name as `$values(f1, f2, ...)` + const aliasWithColumns = `$values(${fields.map((f) => f.name).join(', ')})`; + + const eb = expressionBuilder(); + + return eb + .selectFrom( + sql`(VALUES ${sql.join( + rows.map((row) => sql`ROW(${sql.join(row.map((v) => sql.val(v)))})`), + sql.raw(', '), + )}) as ${sql.raw(aliasWithColumns)}` as unknown as TableExpression, + ) + .selectAll(); + } + + protected override buildOrderByField( + query: SelectQueryBuilder, + field: Expression, + sort: SortOrder, + nulls: 'first' | 'last', + ) { + let result = query; + if (nulls === 'first') { + // NULLS FIRST: order by IS NULL DESC (nulls=1 first), then the actual field + result = result.orderBy(sql`${field} IS NULL`, 'desc'); + result = result.orderBy(field, sort); + } else { + // NULLS LAST: order by IS NULL ASC (nulls=0 last), then the actual field + result = result.orderBy(sql`${field} IS NULL`, 'asc'); + result = result.orderBy(field, sort); + } + return result; + } + + // #endregion +} diff --git a/packages/orm/src/client/crud/dialects/postgresql.ts b/packages/orm/src/client/crud/dialects/postgresql.ts index bae04907..02e9435f 100644 --- a/packages/orm/src/client/crud/dialects/postgresql.ts +++ b/packages/orm/src/client/crud/dialects/postgresql.ts @@ -1,11 +1,10 @@ import { invariant } from '@zenstackhq/common-helpers'; import Decimal from 'decimal.js'; import { + expressionBuilder, sql, + type AliasableExpression, type Expression, - type ExpressionBuilder, - type ExpressionWrapper, - type RawBuilder, type SelectQueryBuilder, type SqlBool, } from 'kysely'; @@ -13,26 +12,14 @@ import { parse as parsePostgresArray } from 'postgres-array'; import { match } from 'ts-pattern'; import z from 'zod'; import { AnyNullClass, DbNullClass, JsonNullClass } from '../../../common-types'; -import type { BuiltinType, FieldDef, GetModels, SchemaDef } from '../../../schema'; -import { DELEGATE_JOINED_FIELD_PREFIX } from '../../constants'; -import type { FindArgs } from '../../crud-types'; -import { createInternalError } from '../../errors'; +import type { BuiltinType, FieldDef, SchemaDef } from '../../../schema'; +import type { SortOrder } from '../../crud-types'; +import { createInternalError, createInvalidInputError } from '../../errors'; import type { ClientOptions } from '../../options'; -import { - buildJoinPairs, - getDelegateDescendantModels, - getEnum, - getManyToManyRelation, - isEnum, - isRelationField, - isTypeDef, - requireField, - requireIdFields, - requireModel, -} from '../../query-utils'; -import { BaseCrudDialect } from './base-dialect'; - -export class PostgresCrudDialect extends BaseCrudDialect { +import { getEnum, isEnum, isTypeDef } from '../../query-utils'; +import { LateralJoinDialectBase } from './lateral-join-dialect-base'; + +export class PostgresCrudDialect extends LateralJoinDialectBase { private isoDateSchema = z.iso.datetime({ local: true, offset: true }); constructor(schema: Schema, options: ClientOptions) { @@ -43,7 +30,41 @@ export class PostgresCrudDialect extends BaseCrudDiale return 'postgresql' as const; } - override transformPrimitive(value: unknown, type: BuiltinType, forArrayField: boolean): unknown { + // #region capabilities + + override get supportsUpdateWithLimit(): boolean { + return false; + } + + override get supportsDeleteWithLimit(): boolean { + return false; + } + + override get supportsDistinctOn(): boolean { + return true; + } + + override get supportsReturning(): boolean { + return true; + } + + override get supportsDefaultAsFieldValue() { + return true; + } + + override get supportsInsertDefaultValues(): boolean { + return true; + } + + override get insertIgnoreMethod() { + return 'onConflict' as const; + } + + // #endregion + + // #region value transformation + + override transformInput(value: unknown, type: BuiltinType, forArrayField: boolean): unknown { if (value === undefined) { return value; } @@ -79,14 +100,14 @@ export class PostgresCrudDialect extends BaseCrudDiale // cast to enum array `CAST(ARRAY[...] AS "enum_type"[])` return this.eb.cast( sql`ARRAY[${sql.join( - value.map((v) => this.transformPrimitive(v, type, false)), + value.map((v) => this.transformInput(v, type, false)), sql.raw(','), )}]`, this.createSchemaQualifiedEnumType(type, true), ); } else { // `Json[]` fields need their input as array (not stringified) - return value.map((v) => this.transformPrimitive(v, type, false)); + return value.map((v) => this.transformInput(v, type, false)); } } else { return match(type) @@ -99,7 +120,12 @@ export class PostgresCrudDialect extends BaseCrudDiale ) .with('Decimal', () => (value !== null ? value.toString() : value)) .with('Json', () => { - if (typeof value === 'string' || typeof value === 'number' || typeof value === 'boolean') { + if ( + value === null || + typeof value === 'string' || + typeof value === 'number' || + typeof value === 'boolean' + ) { // postgres requires simple JSON values to be stringified return JSON.stringify(value); } else { @@ -219,264 +245,12 @@ export class PostgresCrudDialect extends BaseCrudDiale return value; } - override buildRelationSelection( - query: SelectQueryBuilder, - model: string, - relationField: string, - parentAlias: string, - payload: true | FindArgs, true>, - ): SelectQueryBuilder { - const relationResultName = `${parentAlias}$${relationField}`; - const joinedQuery = this.buildRelationJSON( - model, - query, - relationField, - parentAlias, - payload, - relationResultName, - ); - return joinedQuery.select(`${relationResultName}.$data as ${relationField}`); - } + // #endregion - private buildRelationJSON( - model: string, - qb: SelectQueryBuilder, - relationField: string, - parentAlias: string, - payload: true | FindArgs, true>, - resultName: string, - ) { - const relationFieldDef = requireField(this.schema, model, relationField); - const relationModel = relationFieldDef.type as GetModels; - - return qb.leftJoinLateral( - (eb) => { - const relationSelectName = `${resultName}$sub`; - const relationModelDef = requireModel(this.schema, relationModel); - - let tbl: SelectQueryBuilder; - - if (this.canJoinWithoutNestedSelect(relationModelDef, payload)) { - // build join directly - tbl = this.buildModelSelect(relationModel, relationSelectName, payload, false); - - // parent join filter - tbl = this.buildRelationJoinFilter( - tbl, - model, - relationField, - relationModel, - relationSelectName, - parentAlias, - ); - } else { - // join with a nested query - tbl = eb.selectFrom(() => { - let subQuery = this.buildModelSelect(relationModel, `${relationSelectName}$t`, payload, true); - - // parent join filter - subQuery = this.buildRelationJoinFilter( - subQuery, - model, - relationField, - relationModel, - `${relationSelectName}$t`, - parentAlias, - ); - - return subQuery.as(relationSelectName); - }); - } - - // select relation result - tbl = this.buildRelationObjectSelect( - relationModel, - relationSelectName, - relationFieldDef, - tbl, - payload, - resultName, - ); - - // add nested joins for each relation - tbl = this.buildRelationJoins(tbl, relationModel, relationSelectName, payload, resultName); - - // alias the join table - return tbl.as(resultName); - }, - (join) => join.onTrue(), - ); - } + // #region other overrides - private buildRelationJoinFilter( - query: SelectQueryBuilder, - model: string, - relationField: string, - relationModel: GetModels, - relationModelAlias: string, - parentAlias: string, - ) { - const m2m = getManyToManyRelation(this.schema, model, relationField); - if (m2m) { - // many-to-many relation - const parentIds = requireIdFields(this.schema, model); - const relationIds = requireIdFields(this.schema, relationModel); - invariant(parentIds.length === 1, 'many-to-many relation must have exactly one id field'); - invariant(relationIds.length === 1, 'many-to-many relation must have exactly one id field'); - query = query.where((eb) => - eb( - eb.ref(`${relationModelAlias}.${relationIds[0]}`), - 'in', - eb - .selectFrom(m2m.joinTable) - .select(`${m2m.joinTable}.${m2m.otherFkName}`) - .whereRef(`${parentAlias}.${parentIds[0]}`, '=', `${m2m.joinTable}.${m2m.parentFkName}`), - ), - ); - } else { - const joinPairs = buildJoinPairs(this.schema, model, parentAlias, relationField, relationModelAlias); - query = query.where((eb) => - this.and(...joinPairs.map(([left, right]) => eb(this.eb.ref(left), '=', this.eb.ref(right)))), - ); - } - return query; - } - - private buildRelationObjectSelect( - relationModel: string, - relationModelAlias: string, - relationFieldDef: FieldDef, - qb: SelectQueryBuilder, - payload: true | FindArgs, true>, - parentResultName: string, - ) { - qb = qb.select((eb) => { - const objArgs = this.buildRelationObjectArgs( - relationModel, - relationModelAlias, - eb, - payload, - parentResultName, - ); - - if (relationFieldDef.array) { - return eb.fn - .coalesce(sql`jsonb_agg(jsonb_build_object(${sql.join(objArgs)}))`, sql`'[]'::jsonb`) - .as('$data'); - } else { - return sql`jsonb_build_object(${sql.join(objArgs)})`.as('$data'); - } - }); - - return qb; - } - - private buildRelationObjectArgs( - relationModel: string, - relationModelAlias: string, - eb: ExpressionBuilder, - payload: true | FindArgs, true>, - parentResultName: string, - ) { - const relationModelDef = requireModel(this.schema, relationModel); - const objArgs: Array< - string | ExpressionWrapper | SelectQueryBuilder | RawBuilder - > = []; - - const descendantModels = getDelegateDescendantModels(this.schema, relationModel); - if (descendantModels.length > 0) { - // select all JSONs built from delegate descendants - objArgs.push( - ...descendantModels - .map((subModel) => [ - sql.lit(`${DELEGATE_JOINED_FIELD_PREFIX}${subModel.name}`), - eb.ref(`${DELEGATE_JOINED_FIELD_PREFIX}${subModel.name}`), - ]) - .flatMap((v) => v), - ); - } - - if (payload === true || !payload.select) { - // select all scalar fields except for omitted - const omit = typeof payload === 'object' ? payload.omit : undefined; - objArgs.push( - ...Object.entries(relationModelDef.fields) - .filter(([, value]) => !value.relation) - .filter(([name]) => !this.shouldOmitField(omit, relationModel, name)) - .map(([field]) => [sql.lit(field), this.fieldRef(relationModel, field, relationModelAlias, false)]) - .flatMap((v) => v), - ); - } else if (payload.select) { - // select specific fields - objArgs.push( - ...Object.entries(payload.select) - .filter(([, value]) => value) - .map(([field, value]) => { - if (field === '_count') { - const subJson = this.buildCountJson( - relationModel as GetModels, - eb, - relationModelAlias, - value, - ); - return [sql.lit(field), subJson]; - } else { - const fieldDef = requireField(this.schema, relationModel, field); - const fieldValue = fieldDef.relation - ? // reference the synthesized JSON field - eb.ref(`${parentResultName}$${field}.$data`) - : // reference a plain field - this.fieldRef(relationModel, field, relationModelAlias, false); - return [sql.lit(field), fieldValue]; - } - }) - .flatMap((v) => v), - ); - } - - if (typeof payload === 'object' && payload.include && typeof payload.include === 'object') { - // include relation fields - objArgs.push( - ...Object.entries(payload.include) - .filter(([, value]) => value) - .map(([field]) => [ - sql.lit(field), - // reference the synthesized JSON field - eb.ref(`${parentResultName}$${field}.$data`), - ]) - .flatMap((v) => v), - ); - } - return objArgs; - } - - private buildRelationJoins( - query: SelectQueryBuilder, - relationModel: string, - relationModelAlias: string, - payload: true | FindArgs, true>, - parentResultName: string, - ) { - let result = query; - if (typeof payload === 'object') { - const selectInclude = payload.include ?? payload.select; - if (selectInclude && typeof selectInclude === 'object') { - Object.entries(selectInclude) - .filter(([, value]) => value) - .filter(([field]) => isRelationField(this.schema, relationModel, field)) - .forEach(([field, value]) => { - result = this.buildRelationJSON( - relationModel, - result, - field, - relationModelAlias, - value, - `${parentResultName}$${field}`, - ); - }); - } - } - return result; + protected buildArrayAgg(arg: Expression) { + return this.eb.fn.coalesce(sql`jsonb_agg(${arg})`, sql`'[]'::jsonb`); } override buildSkipTake( @@ -500,27 +274,30 @@ export class PostgresCrudDialect extends BaseCrudDiale ); } - override get supportsUpdateWithLimit(): boolean { - return false; + override castInt>(expression: T): T { + return this.eb.cast(expression, 'integer') as unknown as T; } - override get supportsDeleteWithLimit(): boolean { - return false; + override castText>(expression: T): T { + return this.eb.cast(expression, 'text') as unknown as T; } - override get supportsDistinctOn(): boolean { - return true; + override trimTextQuotes>(expression: T): T { + return this.eb.fn('trim', [expression, sql.lit('"')]) as unknown as T; } - override buildArrayLength(array: Expression): ExpressionWrapper { + override buildArrayLength(array: Expression): AliasableExpression { return this.eb.fn('array_length', [array]); } - override buildArrayLiteralSQL(values: unknown[]): string { + override buildArrayLiteralSQL(values: unknown[]): AliasableExpression { if (values.length === 0) { - return '{}'; + return sql`{}`; } else { - return `ARRAY[${values.map((v) => (typeof v === 'string' ? `'${v}'` : v))}]`; + return sql`ARRAY[${sql.join( + values.map((v) => sql.val(v)), + sql.raw(','), + )}]`; } } @@ -546,14 +323,14 @@ export class PostgresCrudDialect extends BaseCrudDiale this.eb( this.eb.fn('jsonb_extract_path', [lhs, this.eb.val('0')]), '=', - this.transformPrimitive(value, 'Json', false), + this.transformInput(value, 'Json', false), ), ) .with('array_ends_with', () => this.eb( this.eb.fn('jsonb_extract_path', [lhs, sql`(jsonb_array_length(${lhs}) - 1)::text`]), '=', - this.transformPrimitive(value, 'Json', false), + this.transformInput(value, 'Json', false), ), ) .exhaustive(); @@ -571,10 +348,6 @@ export class PostgresCrudDialect extends BaseCrudDiale ); } - override get supportInsertWithDefault() { - return true; - } - override getFieldSqlType(fieldDef: FieldDef) { // TODO: respect `@db.x` attributes if (fieldDef.relation) { @@ -612,4 +385,53 @@ export class PostgresCrudDialect extends BaseCrudDiale // Postgres `LIKE` is case-sensitive, `ILIKE` is case-insensitive return { supportsILike: true, likeCaseSensitive: true }; } + + override buildValuesTableSelect(fields: FieldDef[], rows: unknown[][]) { + if (rows.length === 0) { + throw createInvalidInputError('At least one row is required to build values table'); + } + + // check all rows have the same length + const rowLength = rows[0]!.length; + + if (fields.length !== rowLength) { + throw createInvalidInputError('Number of fields must match number of columns in each row'); + } + + for (const row of rows) { + if (row.length !== rowLength) { + throw createInvalidInputError('All rows must have the same number of columns'); + } + } + + const eb = expressionBuilder(); + + return eb + .selectFrom( + sql`(VALUES ${sql.join( + rows.map((row) => sql`(${sql.join(row.map((v) => sql.val(v)))})`), + sql.raw(', '), + )})`.as('$values'), + ) + .select( + fields.map((f, i) => + sql`CAST(${sql.ref(`$values.column${i + 1}`)} AS ${sql.raw(this.getFieldSqlType(f))})`.as(f.name), + ), + ); + } + + protected override buildOrderByField( + query: SelectQueryBuilder, + field: Expression, + sort: SortOrder, + nulls: 'first' | 'last', + ) { + return query.orderBy(field, (ob) => { + ob = sort === 'asc' ? ob.asc() : ob.desc(); + ob = nulls === 'first' ? ob.nullsFirst() : ob.nullsLast(); + return ob; + }); + } + + // #endregion } diff --git a/packages/orm/src/client/crud/dialects/sqlite.ts b/packages/orm/src/client/crud/dialects/sqlite.ts index 32bb4e4e..4ef87a46 100644 --- a/packages/orm/src/client/crud/dialects/sqlite.ts +++ b/packages/orm/src/client/crud/dialects/sqlite.ts @@ -1,8 +1,9 @@ import { invariant } from '@zenstackhq/common-helpers'; import Decimal from 'decimal.js'; import { - ExpressionWrapper, + expressionBuilder, sql, + type AliasableExpression, type Expression, type ExpressionBuilder, type RawBuilder, @@ -13,8 +14,8 @@ import { match } from 'ts-pattern'; import { AnyNullClass, DbNullClass, JsonNullClass } from '../../../common-types'; import type { BuiltinType, FieldDef, GetModels, SchemaDef } from '../../../schema'; import { DELEGATE_JOINED_FIELD_PREFIX } from '../../constants'; -import type { FindArgs } from '../../crud-types'; -import { createInternalError, createNotSupportedError } from '../../errors'; +import type { FindArgs, SortOrder } from '../../crud-types'; +import { createInternalError, createInvalidInputError, createNotSupportedError } from '../../errors'; import { getDelegateDescendantModels, getManyToManyRelation, @@ -30,7 +31,41 @@ export class SqliteCrudDialect extends BaseCrudDialect return 'sqlite' as const; } - override transformPrimitive(value: unknown, type: BuiltinType, _forArrayField: boolean): unknown { + // #region capabilities + + override get supportsUpdateWithLimit() { + return false; + } + + override get supportsDeleteWithLimit() { + return false; + } + + override get supportsDistinctOn() { + return false; + } + + override get supportsReturning() { + return true; + } + + override get supportsDefaultAsFieldValue() { + return false; + } + + override get supportsInsertDefaultValues(): boolean { + return true; + } + + override get insertIgnoreMethod() { + return 'onConflict' as const; + } + + // #endregion + + // #region value transformation + + override transformInput(value: unknown, type: BuiltinType, _forArrayField: boolean): unknown { if (value === undefined) { return value; } @@ -50,7 +85,7 @@ export class SqliteCrudDialect extends BaseCrudDialect } if (Array.isArray(value)) { - return value.map((v) => this.transformPrimitive(v, type, false)); + return value.map((v) => this.transformInput(v, type, false)); } else { return match(type) .with('Boolean', () => (value ? 1 : 0)) @@ -137,6 +172,10 @@ export class SqliteCrudDialect extends BaseCrudDialect return value; } + // #endregion + + // #region other overrides + override buildRelationSelection( query: SelectQueryBuilder, model: string, @@ -404,28 +443,24 @@ export class SqliteCrudDialect extends BaseCrudDialect ); } - override get supportsUpdateWithLimit() { - return false; - } - - override get supportsDeleteWithLimit() { - return false; + override buildArrayLength(array: Expression): AliasableExpression { + return this.eb.fn('json_array_length', [array]); } - override get supportsDistinctOn() { - return false; + override buildArrayLiteralSQL(_values: unknown[]): AliasableExpression { + throw new Error('SQLite does not support array literals'); } - override buildArrayLength(array: Expression): ExpressionWrapper { - return this.eb.fn('json_array_length', [array]); + override castInt>(expression: T): T { + return expression; } - override buildArrayLiteralSQL(_values: unknown[]): string { - throw new Error('SQLite does not support array literals'); + override castText>(expression: T): T { + return this.eb.cast(expression, 'text') as unknown as T; } - override get supportInsertWithDefault() { - return false; + override trimTextQuotes>(expression: T): T { + return this.eb.fn('trim', [expression, sql.lit('"')]) as unknown as T; } override getFieldSqlType(fieldDef: FieldDef) { @@ -462,4 +497,48 @@ export class SqliteCrudDialect extends BaseCrudDialect // SQLite `LIKE` is case-insensitive, and there is no `ILIKE` return { supportsILike: false, likeCaseSensitive: false }; } + + override buildValuesTableSelect(fields: FieldDef[], rows: unknown[][]) { + if (rows.length === 0) { + throw createInvalidInputError('At least one row is required to build values table'); + } + + // check all rows have the same length + const rowLength = rows[0]!.length; + + if (fields.length !== rowLength) { + throw createInvalidInputError('Number of fields must match number of columns in each row'); + } + + for (const row of rows) { + if (row.length !== rowLength) { + throw createInvalidInputError('All rows must have the same number of columns'); + } + } + + const eb = expressionBuilder(); + + return eb + .selectFrom( + sql`(VALUES ${sql.join( + rows.map((row) => sql`(${sql.join(row.map((v) => sql.val(v)))})`), + sql.raw(', '), + )})`.as('$values'), + ) + .select(fields.map((f, i) => eb.ref(`$values.column${i + 1}`).as(f.name))); + } + + protected override buildOrderByField( + query: SelectQueryBuilder, + field: Expression, + sort: SortOrder, + nulls: 'first' | 'last', + ) { + return query.orderBy(field, (ob) => { + ob = sort === 'asc' ? ob.asc() : ob.desc(); + ob = nulls === 'first' ? ob.nullsFirst() : ob.nullsLast(); + return ob; + }); + } + // #endregion } diff --git a/packages/orm/src/client/crud/operations/aggregate.ts b/packages/orm/src/client/crud/operations/aggregate.ts index f92a8518..da7af7b1 100644 --- a/packages/orm/src/client/crud/operations/aggregate.ts +++ b/packages/orm/src/client/crud/operations/aggregate.ts @@ -52,7 +52,14 @@ export class AggregateOperationHandler extends BaseOpe subQuery = this.dialect.buildSkipTake(subQuery, skip, take); // orderBy - subQuery = this.dialect.buildOrderBy(subQuery, this.model, this.model, parsedArgs.orderBy, negateOrderBy); + subQuery = this.dialect.buildOrderBy( + subQuery, + this.model, + this.model, + parsedArgs.orderBy, + negateOrderBy, + take, + ); return subQuery.as('$sub'); }); @@ -62,18 +69,18 @@ export class AggregateOperationHandler extends BaseOpe switch (key) { case '_count': { if (value === true) { - query = query.select((eb) => eb.cast(eb.fn.countAll(), 'integer').as('_count')); + query = query.select((eb) => this.dialect.castInt(eb.fn.countAll()).as('_count')); } else { Object.entries(value).forEach(([field, val]) => { if (val === true) { if (field === '_all') { query = query.select((eb) => - eb.cast(eb.fn.countAll(), 'integer').as(`_count._all`), + this.dialect.castInt(eb.fn.countAll()).as(`_count._all`), ); } else { query = query.select((eb) => - eb - .cast(eb.fn.count(eb.ref(`$sub.${field}` as any)), 'integer') + this.dialect + .castInt(eb.fn.count(eb.ref(`$sub.${field}`))) .as(`${key}.${field}`), ); } @@ -96,7 +103,7 @@ export class AggregateOperationHandler extends BaseOpe .with('_max', () => eb.fn.max) .with('_min', () => eb.fn.min) .exhaustive(); - return fn(eb.ref(`$sub.${field}` as any)).as(`${key}.${field}`); + return fn(eb.ref(`$sub.${field}`)).as(`${key}.${field}`); }); } }); diff --git a/packages/orm/src/client/crud/operations/base.ts b/packages/orm/src/client/crud/operations/base.ts index 5eb4b2d1..33cc1a64 100644 --- a/packages/orm/src/client/crud/operations/base.ts +++ b/packages/orm/src/client/crud/operations/base.ts @@ -8,6 +8,7 @@ import { sql, UpdateResult, type Compilable, + type ExpressionBuilder, type IsolationLevel, type QueryResult, type SelectQueryBuilder, @@ -430,13 +431,9 @@ export abstract class BaseOperationHandler { Array.isArray(value.set) ) { // deal with nested "set" for scalar lists - createFields[field] = this.dialect.transformPrimitive( - value.set, - fieldDef.type as BuiltinType, - true, - ); + createFields[field] = this.dialect.transformInput(value.set, fieldDef.type as BuiltinType, true); } else { - createFields[field] = this.dialect.transformPrimitive( + createFields[field] = this.dialect.transformInput( value, fieldDef.type as BuiltinType, !!fieldDef.array, @@ -469,19 +466,77 @@ export abstract class BaseOperationHandler { // return id fields if no returnFields specified returnFields = returnFields ?? requireIdFields(this.schema, model); - const query = kysely - .insertInto(model) - .$if(Object.keys(updatedData).length === 0, (qb) => qb.defaultValues()) - .$if(Object.keys(updatedData).length > 0, (qb) => qb.values(updatedData)) - .returning(returnFields as any) - .modifyEnd( - this.makeContextComment({ - model, - operation: 'create', - }), - ); + let createdEntity: any; + + if (this.dialect.supportsReturning) { + const query = kysely + .insertInto(model) + .$if(Object.keys(updatedData).length === 0, (qb) => + qb + // case for `INSERT INTO ... DEFAULT VALUES` syntax + .$if(this.dialect.supportsInsertDefaultValues, () => qb.defaultValues()) + // case for `INSERT INTO ... VALUES ({})` syntax + .$if(!this.dialect.supportsInsertDefaultValues, () => qb.values({})), + ) + .$if(Object.keys(updatedData).length > 0, (qb) => qb.values(updatedData)) + .returning(returnFields as any) + .modifyEnd( + this.makeContextComment({ + model, + operation: 'create', + }), + ); + + createdEntity = await this.executeQueryTakeFirst(kysely, query, 'create'); + } else { + // Fallback for databases that don't support RETURNING (e.g., MySQL) + const insertQuery = kysely + .insertInto(model) + .$if(Object.keys(updatedData).length === 0, (qb) => + qb + // case for `INSERT INTO ... DEFAULT VALUES` syntax + .$if(this.dialect.supportsInsertDefaultValues, () => qb.defaultValues()) + // case for `INSERT INTO ... VALUES ({})` syntax + .$if(!this.dialect.supportsInsertDefaultValues, () => qb.values({})), + ) + .$if(Object.keys(updatedData).length > 0, (qb) => qb.values(updatedData)) + .modifyEnd( + this.makeContextComment({ + model, + operation: 'create', + }), + ); + + const insertResult = await this.executeQuery(kysely, insertQuery, 'create'); + + // Build WHERE clause to find the inserted record + const idFields = requireIdFields(this.schema, model); + const idValues: Record = {}; - const createdEntity = await this.executeQueryTakeFirst(kysely, query, 'create'); + for (const idField of idFields) { + if (insertResult.insertId !== undefined && insertResult.insertId !== null) { + const fieldDef = this.requireField(model, idField); + if (this.isAutoIncrementField(fieldDef)) { + // auto-generated id value + idValues[idField] = insertResult.insertId; + continue; + } + } + + if (updatedData[idField] !== undefined) { + // ID was provided in the insert + idValues[idField] = updatedData[idField]; + } else { + throw createInternalError( + `Cannot determine ID field "${idField}" value for created model "${model}"`, + ); + } + } + + // for dialects that don't support RETURNING, the outside logic will always + // read back the created record, we just return the id fields here + createdEntity = idValues; + } if (Object.keys(postCreateRelations).length > 0) { // process nested creates that need to happen after the current entity is created @@ -513,6 +568,14 @@ export abstract class BaseOperationHandler { return createdEntity; } + private isAutoIncrementField(fieldDef: FieldDef) { + return ( + fieldDef.default && + ExpressionUtils.isCall(fieldDef.default) && + fieldDef.default.function === 'autoincrement' + ); + } + private async processBaseModelCreate(kysely: ToKysely, model: string, createFields: any, forModel: string) { const thisCreateFields: any = {}; const remainingFields: any = {}; @@ -618,7 +681,12 @@ export abstract class BaseOperationHandler { A: sortedRecords[0]!.entity[firstIds[0]!], B: sortedRecords[1]!.entity[secondIds[0]!], } as any) - .onConflict((oc) => oc.columns(['A', 'B'] as any).doNothing()) + // case for `INSERT IGNORE` or `ON CONFLICT DO NOTHING` syntax + .$if(this.dialect.insertIgnoreMethod === 'onConflict', (qb) => + qb.onConflict((oc) => oc.columns(['A', 'B'] as any).doNothing()), + ) + // case for `INSERT IGNORE` syntax + .$if(this.dialect.insertIgnoreMethod === 'ignore', (qb) => qb.ignore()) .execute(); return result[0] as any; } else { @@ -794,7 +862,7 @@ export abstract class BaseOperationHandler { const modelDef = this.requireModel(model); - let relationKeyPairs: { fk: string; pk: string }[] = []; + const relationKeyPairs: { fk: string; pk: string }[] = []; if (fromRelation) { const { ownedByModel, keyPairs } = getRelationForeignKeyFieldPairs( this.schema, @@ -804,7 +872,7 @@ export abstract class BaseOperationHandler { if (ownedByModel) { throw createInvalidInputError('incorrect relation hierarchy for createMany', model); } - relationKeyPairs = keyPairs; + relationKeyPairs.push(...keyPairs); } let createData = enumerate(input.data).map((item) => { @@ -812,7 +880,7 @@ export abstract class BaseOperationHandler { for (const [name, value] of Object.entries(item)) { const fieldDef = this.requireField(model, name); invariant(!fieldDef.relation, 'createMany does not support relations'); - newItem[name] = this.dialect.transformPrimitive(value, fieldDef.type as BuiltinType, !!fieldDef.array); + newItem[name] = this.dialect.transformInput(value, fieldDef.type as BuiltinType, !!fieldDef.array); } if (fromRelation) { for (const { fk, pk } of relationKeyPairs) { @@ -822,7 +890,7 @@ export abstract class BaseOperationHandler { return this.fillGeneratedAndDefaultValues(modelDef, newItem); }); - if (!this.dialect.supportInsertWithDefault) { + if (!this.dialect.supportsDefaultAsFieldValue) { // if the dialect doesn't support `DEFAULT` as insert field values, // we need to double check if data rows have mismatching fields, and // if so, make sure all fields have default value filled if not provided @@ -846,7 +914,7 @@ export abstract class BaseOperationHandler { fieldDef.default !== null && typeof fieldDef.default !== 'object' ) { - item[field] = this.dialect.transformPrimitive( + item[field] = this.dialect.transformInput( fieldDef.default, fieldDef.type as BuiltinType, !!fieldDef.array, @@ -876,7 +944,13 @@ export abstract class BaseOperationHandler { const query = kysely .insertInto(model) .values(createData) - .$if(!!input.skipDuplicates, (qb) => qb.onConflict((oc) => oc.doNothing())) + .$if(!!input.skipDuplicates, (qb) => + qb + // case for `INSERT ... ON CONFLICT DO NOTHING` syntax + .$if(this.dialect.insertIgnoreMethod === 'onConflict', () => qb.onConflict((oc) => oc.doNothing())) + // case for `INSERT IGNORE` syntax + .$if(this.dialect.insertIgnoreMethod === 'ignore', () => qb.ignore()), + ) .modifyEnd( this.makeContextComment({ model, @@ -889,8 +963,21 @@ export abstract class BaseOperationHandler { return { count: Number(result.numAffectedRows) } as Result; } else { fieldsToReturn = fieldsToReturn ?? requireIdFields(this.schema, model); - const result = await query.returning(fieldsToReturn as any).execute(); - return result as Result; + + if (this.dialect.supportsReturning) { + const result = await query.returning(fieldsToReturn as any).execute(); + return result as Result; + } else { + // Fallback for databases that don't support RETURNING (e.g., MySQL) + // For createMany without RETURNING, we can't reliably get all inserted records + // especially with auto-increment IDs. The best we can do is return the count. + // If users need the created records, they should use multiple create() calls + // or the application should query after insertion. + throw createNotSupportedError( + `\`createManyAndReturn\` is not supported for ${this.dialect.provider}. ` + + `Use multiple \`create\` calls or query the records after insertion.`, + ); + } } } @@ -923,12 +1010,21 @@ export abstract class BaseOperationHandler { } // create base model entity - const baseEntities = await this.createMany( - kysely, - model as GetModels, - { data: thisCreateRows, skipDuplicates }, - true, - ); + let baseEntities: unknown[]; + if (this.dialect.supportsReturning) { + baseEntities = await this.createMany( + kysely, + model as GetModels, + { data: thisCreateRows, skipDuplicates }, + true, + ); + } else { + // fall back to multiple creates if RETURNING is not supported + baseEntities = []; + for (const row of thisCreateRows) { + baseEntities.push(await this.create(kysely, model, row, undefined, true)); + } + } // copy over id fields from base model for (let i = 0; i < baseEntities.length; i++) { @@ -950,7 +1046,7 @@ export abstract class BaseOperationHandler { if (typeof fieldDef?.default === 'object' && 'kind' in fieldDef.default) { const generated = this.evalGenerator(fieldDef.default); if (generated !== undefined) { - values[field] = this.dialect.transformPrimitive( + values[field] = this.dialect.transformInput( generated, fieldDef.type as BuiltinType, !!fieldDef.array, @@ -958,7 +1054,7 @@ export abstract class BaseOperationHandler { } } else if (fieldDef?.updatedAt) { // TODO: should this work at kysely level instead? - values[field] = this.dialect.transformPrimitive(new Date(), 'DateTime', false); + values[field] = this.dialect.transformInput(new Date(), 'DateTime', false); } else if (fieldDef?.default !== undefined) { let value = fieldDef.default; if (fieldDef.type === 'Json') { @@ -969,11 +1065,7 @@ export abstract class BaseOperationHandler { value = JSON.parse(value); } } - values[field] = this.dialect.transformPrimitive( - value, - fieldDef.type as BuiltinType, - !!fieldDef.array, - ); + values[field] = this.dialect.transformInput(value, fieldDef.type as BuiltinType, !!fieldDef.array); } } } @@ -1043,39 +1135,7 @@ export abstract class BaseOperationHandler { throw createInvalidInputError('data must be an object'); } - const parentWhere: any = {}; - let m2m: ReturnType = undefined; - - if (fromRelation) { - m2m = getManyToManyRelation(this.schema, fromRelation.model, fromRelation.field); - if (!m2m) { - // merge foreign key conditions from the relation - const { ownedByModel, keyPairs } = getRelationForeignKeyFieldPairs( - this.schema, - fromRelation.model, - fromRelation.field, - ); - if (ownedByModel) { - const fromEntity = await this.readUnique(kysely, fromRelation.model as GetModels, { - where: fromRelation.ids, - }); - for (const { fk, pk } of keyPairs) { - parentWhere[pk] = fromEntity[fk]; - } - } else { - for (const { fk, pk } of keyPairs) { - parentWhere[fk] = fromRelation.ids[pk]; - } - } - } else { - // many-to-many relation, filter for parent with "some" - const fromRelationFieldDef = this.requireField(fromRelation.model, fromRelation.field); - invariant(fromRelationFieldDef.relation?.opposite); - parentWhere[fromRelationFieldDef.relation.opposite] = { - some: fromRelation.ids, - }; - } - } + const parentWhere = await this.buildUpdateParentRelationFilter(kysely, fromRelation); let combinedWhere: WhereInput, false> = where ?? {}; if (Object.keys(parentWhere).length > 0) { @@ -1088,11 +1148,11 @@ export abstract class BaseOperationHandler { // fill in automatically updated fields const autoUpdatedFields: string[] = []; for (const [fieldName, fieldDef] of Object.entries(modelDef.fields)) { - if (fieldDef.updatedAt) { + if (fieldDef.updatedAt && finalData[fieldName] === undefined) { if (finalData === data) { finalData = clone(data); } - finalData[fieldName] = this.dialect.transformPrimitive(new Date(), 'DateTime', false); + finalData[fieldName] = this.dialect.transformInput(new Date(), 'DateTime', false); autoUpdatedFields.push(fieldName); } } @@ -1114,11 +1174,18 @@ export abstract class BaseOperationHandler { } let needIdRead = false; - if (modelDef.baseModel && !this.isIdFilter(model, combinedWhere)) { - // when updating a model with delegate base, base fields may be referenced in the filter, - // so we read the id out if the filter is not ready an id filter, and and use it as the - // update filter instead - needIdRead = true; + if (!this.isIdFilter(model, combinedWhere)) { + if (modelDef.baseModel) { + // when updating a model with delegate base, base fields may be referenced in the filter, + // so we read the id out if the filter is not ready an id filter, and and use it as the + // update filter instead + needIdRead = true; + } + if (!this.dialect.supportsReturning) { + // for dialects that don't support RETURNING, we need to read the id fields + // to identify the updated entity + needIdRead = true; + } } if (needIdRead) { @@ -1192,19 +1259,77 @@ export abstract class BaseOperationHandler { return thisEntity; } else { fieldsToReturn = fieldsToReturn ?? requireIdFields(this.schema, model); - const query = kysely - .updateTable(model) - .where(() => this.dialect.buildFilter(model, model, combinedWhere)) - .set(updateFields) - .returning(fieldsToReturn as any) - .modifyEnd( - this.makeContextComment({ - model, - operation: 'update', - }), - ); - const updatedEntity = await this.executeQueryTakeFirst(kysely, query, 'update'); + let updatedEntity: any; + + if (this.dialect.supportsReturning) { + const query = kysely + .updateTable(model) + .where(() => this.dialect.buildFilter(model, model, combinedWhere)) + .set(updateFields) + .returning(fieldsToReturn as any) + .modifyEnd( + this.makeContextComment({ + model, + operation: 'update', + }), + ); + + updatedEntity = await this.executeQueryTakeFirst(kysely, query, 'update'); + } else { + // Fallback for databases that don't support RETURNING (e.g., MySQL) + const updateQuery = kysely + .updateTable(model) + .where(() => this.dialect.buildFilter(model, model, combinedWhere)) + .set(updateFields) + .modifyEnd( + this.makeContextComment({ + model, + operation: 'update', + }), + ); + + const updateResult = await this.executeQuery(kysely, updateQuery, 'update'); + if (!updateResult.numAffectedRows) { + // no rows updated + updatedEntity = null; + } else { + // collect id field/values from the original filter + const idFields = requireIdFields(this.schema, model); + const filterIdValues: any = {}; + for (const key of idFields) { + if (combinedWhere[key] !== undefined && typeof combinedWhere[key] !== 'object') { + filterIdValues[key] = combinedWhere[key]; + } + } + + // check if we are updating any id fields + const updatingIdFields = idFields.some((idField) => idField in updateFields); + + if (Object.keys(filterIdValues).length === idFields.length && !updatingIdFields) { + // if we have all id fields in the original filter and ids are not being updated, + // we can simply return the id values as the update result + updatedEntity = filterIdValues; + } else { + // otherwise we need to re-query the updated entity + + // replace id fields in the filter with updated values if they are being updated + const readFilter: any = { ...combinedWhere }; + for (const idField of idFields) { + if (idField in updateFields && updateFields[idField] !== undefined) { + // if id fields are being updated, use the new values + readFilter[idField] = updateFields[idField]; + } + } + const selectQuery = kysely + .selectFrom(model) + .select(fieldsToReturn as any) + .where(() => this.dialect.buildFilter(model, model, readFilter)); + updatedEntity = await this.executeQueryTakeFirst(kysely, selectQuery, 'update'); + } + } + } + if (!updatedEntity) { if (throwIfNotFound) { throw createNotFoundError(model); @@ -1217,6 +1342,42 @@ export abstract class BaseOperationHandler { } } + private async buildUpdateParentRelationFilter(kysely: AnyKysely, fromRelation: FromRelationContext | undefined) { + const parentWhere: any = {}; + let m2m: ReturnType = undefined; + if (fromRelation) { + m2m = getManyToManyRelation(this.schema, fromRelation.model, fromRelation.field); + if (!m2m) { + // merge foreign key conditions from the relation + const { ownedByModel, keyPairs } = getRelationForeignKeyFieldPairs( + this.schema, + fromRelation.model, + fromRelation.field, + ); + if (ownedByModel) { + const fromEntity = await this.readUnique(kysely, fromRelation.model, { + where: fromRelation.ids, + }); + for (const { fk, pk } of keyPairs) { + parentWhere[pk] = fromEntity[fk]; + } + } else { + for (const { fk, pk } of keyPairs) { + parentWhere[fk] = fromRelation.ids[pk]; + } + } + } else { + // many-to-many relation, filter for parent with "some" + const fromRelationFieldDef = this.requireField(fromRelation.model, fromRelation.field); + invariant(fromRelationFieldDef.relation?.opposite); + parentWhere[fromRelationFieldDef.relation.opposite] = { + some: fromRelation.ids, + }; + } + } + return parentWhere; + } + private processScalarFieldUpdateData(model: string, field: string, data: any): any { const fieldDef = this.requireField(model, field); if (this.isNumericIncrementalUpdate(fieldDef, data[field])) { @@ -1229,7 +1390,7 @@ export abstract class BaseOperationHandler { return this.transformScalarListUpdate(model, field, fieldDef, data[field]); } - return this.dialect.transformPrimitive(data[field], fieldDef.type as BuiltinType, !!fieldDef.array); + return this.dialect.transformInput(data[field], fieldDef.type as BuiltinType, !!fieldDef.array); } private isNumericIncrementalUpdate(fieldDef: FieldDef, value: any) { @@ -1294,7 +1455,7 @@ export abstract class BaseOperationHandler { ); const key = Object.keys(payload)[0]; - const value = this.dialect.transformPrimitive(payload[key!], fieldDef.type as BuiltinType, false); + const value = this.dialect.transformInput(payload[key!], fieldDef.type as BuiltinType, false); const eb = expressionBuilder(); const fieldRef = this.dialect.fieldRef(model, field); @@ -1317,7 +1478,7 @@ export abstract class BaseOperationHandler { ) { invariant(Object.keys(payload).length === 1, 'Only one of "set", "push" can be provided'); const key = Object.keys(payload)[0]; - const value = this.dialect.transformPrimitive(payload[key!], fieldDef.type as BuiltinType, true); + const value = this.dialect.transformInput(payload[key!], fieldDef.type as BuiltinType, true); const eb = expressionBuilder(); const fieldRef = this.dialect.fieldRef(model, field); @@ -1351,6 +1512,7 @@ export abstract class BaseOperationHandler { limit: number | undefined, returnData: ReturnData, filterModel?: string, + fromRelation?: FromRelationContext, fieldsToReturn?: readonly string[], ): Promise { if (typeof data !== 'object') { @@ -1366,6 +1528,12 @@ export abstract class BaseOperationHandler { throw createNotSupportedError('Updating with a limit is not supported for polymorphic models'); } + const parentWhere = await this.buildUpdateParentRelationFilter(kysely, fromRelation); + let combinedWhere: WhereInput, false> = where ?? {}; + if (Object.keys(parentWhere).length > 0) { + combinedWhere = Object.keys(combinedWhere).length > 0 ? { AND: [parentWhere, combinedWhere] } : parentWhere; + } + filterModel ??= model; let updateFields: any = {}; @@ -1376,27 +1544,12 @@ export abstract class BaseOperationHandler { updateFields[field] = this.processScalarFieldUpdateData(model, field, data); } - let shouldFallbackToIdFilter = false; - - if (limit !== undefined && !this.dialect.supportsUpdateWithLimit) { - // if the dialect doesn't support update with limit natively, we'll - // simulate it by filtering by id with a limit - shouldFallbackToIdFilter = true; - } - - if (modelDef.isDelegate || modelDef.baseModel) { - // if the model is in a delegate hierarchy, we'll need to filter by - // id because the filter may involve fields in different models in - // the hierarchy - shouldFallbackToIdFilter = true; - } - let resultFromBaseModel: any = undefined; if (modelDef.baseModel) { const baseResult = await this.processBaseModelUpdateMany( kysely, modelDef.baseModel, - where, + combinedWhere, updateFields, filterModel, ); @@ -1410,12 +1563,27 @@ export abstract class BaseOperationHandler { return resultFromBaseModel ?? ((returnData ? [] : { count: 0 }) as Result); } + let shouldFallbackToIdFilter = false; + + if (limit !== undefined && !this.dialect.supportsUpdateWithLimit) { + // if the dialect doesn't support update with limit natively, we'll + // simulate it by filtering by id with a limit + shouldFallbackToIdFilter = true; + } + + if (modelDef.isDelegate || modelDef.baseModel) { + // if the model is in a delegate hierarchy, we'll need to filter by + // id because the filter may involve fields in different models in + // the hierarchy + shouldFallbackToIdFilter = true; + } + let query = kysely.updateTable(model).set(updateFields); if (!shouldFallbackToIdFilter) { // simple filter query = query - .where(() => this.dialect.buildFilter(model, model, where)) + .where(() => this.dialect.buildFilter(model, model, combinedWhere)) .$if(limit !== undefined, (qb) => qb.limit(limit!)); } else { query = query.where((eb) => @@ -1425,11 +1593,17 @@ export abstract class BaseOperationHandler { ...this.buildIdFieldRefs(kysely, model), ), 'in', - this.dialect - .buildSelectModel(filterModel, filterModel) - .where(this.dialect.buildFilter(filterModel, filterModel, where)) - .select(this.buildIdFieldRefs(kysely, filterModel)) - .$if(limit !== undefined, (qb) => qb.limit(limit!)), + // the outer "select *" is needed to isolate the sub query (as needed for dialects like mysql) + eb + .selectFrom( + this.dialect + .buildSelectModel(filterModel, filterModel) + .where(this.dialect.buildFilter(filterModel, filterModel, combinedWhere)) + .select(this.buildIdFieldRefs(kysely, filterModel)) + .$if(limit !== undefined, (qb) => qb.limit(limit!)) + .as('$sub'), + ) + .selectAll(), ), ); } @@ -1441,9 +1615,71 @@ export abstract class BaseOperationHandler { return { count: Number(result.numAffectedRows) } as Result; } else { fieldsToReturn = fieldsToReturn ?? requireIdFields(this.schema, model); - const finalQuery = query.returning(fieldsToReturn as any); - const result = await this.executeQuery(kysely, finalQuery, 'update'); - return result.rows as Result; + + if (this.dialect.supportsReturning) { + const finalQuery = query.returning(fieldsToReturn as any); + const result = await this.executeQuery(kysely, finalQuery, 'update'); + return result.rows as Result; + } else { + // Fallback for databases that don't support RETURNING (e.g., MySQL) + // First, select the records to be updated + let selectQuery = kysely.selectFrom(model).selectAll(); + + if (!shouldFallbackToIdFilter) { + selectQuery = selectQuery + .where(() => this.dialect.buildFilter(model, model, combinedWhere)) + .$if(limit !== undefined, (qb) => qb.limit(limit!)); + } else { + selectQuery = selectQuery.where((eb) => + eb( + eb.refTuple( + // @ts-expect-error + ...this.buildIdFieldRefs(kysely, model), + ), + 'in', + this.dialect + .buildSelectModel(filterModel, filterModel) + .where(this.dialect.buildFilter(filterModel, filterModel, combinedWhere)) + .select(this.buildIdFieldRefs(kysely, filterModel)) + .$if(limit !== undefined, (qb) => qb.limit(limit!)), + ), + ); + } + + const recordsToUpdate = await this.executeQuery(kysely, selectQuery, 'update'); + + // Execute the update + await this.executeQuery(kysely, query, 'update'); + + // Return the IDs of updated records, then query them back with updated values + if (recordsToUpdate.rows.length === 0) { + return [] as Result; + } + + const idFields = requireIdFields(this.schema, model); + const updatedIds = recordsToUpdate.rows.map((row: any) => { + const id: Record = {}; + for (const idField of idFields) { + id[idField] = row[idField]; + } + return id; + }); + + // Query back the updated records + const resultQuery = kysely + .selectFrom(model) + .selectAll() + .where((eb) => { + const conditions = updatedIds.map((id) => { + const idConditions = Object.entries(id).map(([field, value]) => eb.eb(field, '=', value)); + return eb.and(idConditions); + }); + return eb.or(conditions); + }); + + const result = await this.executeQuery(kysely, resultQuery, 'update'); + return result.rows as Result; + } } } @@ -1593,8 +1829,17 @@ export abstract class BaseOperationHandler { case 'updateMany': { for (const _item of enumerate(value)) { - const item = _item as { where: any; data: any }; - await this.update(kysely, fieldModel, item.where, item.data, fromRelationContext, false, false); + const item = _item as { where: any; data: any; limit: number | undefined }; + await this.updateMany( + kysely, + fieldModel, + item.where, + item.data, + item.limit, + false, + fieldModel, + fromRelationContext, + ); } break; } @@ -1680,9 +1925,7 @@ export abstract class BaseOperationHandler { if (!relationFieldDef.array) { const query = kysely .updateTable(model) - .where((eb) => - eb.and(keyPairs.map(({ fk, pk }) => eb(eb.ref(fk as any), '=', fromRelation.ids[pk]))), - ) + .where((eb) => eb.and(keyPairs.map(({ fk, pk }) => eb(eb.ref(fk), '=', fromRelation.ids[pk])))) .set(keyPairs.reduce((acc, { fk }) => ({ ...acc, [fk]: null }), {} as any)) .modifyEnd( this.makeContextComment({ @@ -1987,7 +2230,7 @@ export abstract class BaseOperationHandler { expectedDeleteCount = deleteConditions.length; } - let deleteResult: QueryResult; + let deleteResult: Awaited>; let deleteFromModel: string; const m2m = getManyToManyRelation(this.schema, fromRelation.model, fromRelation.field); @@ -2052,7 +2295,7 @@ export abstract class BaseOperationHandler { } // validate result - if (throwForNotFound && expectedDeleteCount > deleteResult.rows.length) { + if (throwForNotFound && expectedDeleteCount > (deleteResult.numAffectedRows ?? 0)) { // some entities were not deleted throw createNotFoundError(deleteFromModel); } @@ -2085,7 +2328,6 @@ export abstract class BaseOperationHandler { } fieldsToReturn = fieldsToReturn ?? requireIdFields(this.schema, model); - let query = kysely.deleteFrom(model).returning(fieldsToReturn as any); let needIdFilter = false; @@ -2102,32 +2344,42 @@ export abstract class BaseOperationHandler { needIdFilter = true; } - if (!needIdFilter) { - query = query.where(() => this.dialect.buildFilter(model, model, where)); - } else { - query = query.where((eb) => - eb( - eb.refTuple( - // @ts-expect-error - ...this.buildIdFieldRefs(kysely, model), - ), - 'in', - this.dialect - .buildSelectModel(filterModel, filterModel) - .where(() => this.dialect.buildFilter(filterModel, filterModel, where)) - .select(this.buildIdFieldRefs(kysely, filterModel)) - .$if(limit !== undefined, (qb) => qb.limit(limit!)), - ), - ); - } + const deleteFilter = needIdFilter + ? (eb: ExpressionBuilder) => + eb( + eb.refTuple( + // @ts-expect-error + ...this.buildIdFieldRefs(kysely, model), + ), + 'in', + // the outer "select *" is needed to isolate the sub query (as needed for dialects like mysql) + eb + .selectFrom( + this.dialect + .buildSelectModel(filterModel, filterModel) + .where(() => this.dialect.buildFilter(filterModel, filterModel, where)) + .select(this.buildIdFieldRefs(kysely, filterModel)) + .$if(limit !== undefined, (qb) => qb.limit(limit!)) + .as('$sub'), + ) + .selectAll(), + ) + : () => this.dialect.buildFilter(model, model, where); // if the model being deleted has a relation to a model that extends a delegate model, and if that // relation is set to trigger a cascade delete from this model, the deletion will not automatically // clean up the base hierarchy of the relation side (because polymorphic model's cascade deletion // works downward not upward). We need to take care of the base deletions manually here. + await this.processDelegateRelationDelete(kysely, modelDef, where, limit); - query = query.modifyEnd(this.makeContextComment({ model, operation: 'delete' })); + const query = kysely + .deleteFrom(model) + .where(deleteFilter) + .$if(this.dialect.supportsReturning, (qb) => qb.returning(fieldsToReturn)) + .$if(limit !== undefined && this.dialect.supportsDeleteWithLimit, (qb) => qb.limit(limit!)) + .modifyEnd(this.makeContextComment({ model, operation: 'delete' })); + return this.executeQuery(kysely, query, 'delete'); } @@ -2269,6 +2521,11 @@ export abstract class BaseOperationHandler { return { needReadBack: true, selectedFields: undefined }; } + if (!this.dialect.supportsReturning) { + // if the dialect doesn't support RETURNING, we always need read back + return { needReadBack: true, selectedFields: undefined }; + } + if (args.include && typeof args.include === 'object' && Object.keys(args.include).length > 0) { // includes present, need read back to fetch relations return { needReadBack: true, selectedFields: undefined }; diff --git a/packages/orm/src/client/crud/operations/count.ts b/packages/orm/src/client/crud/operations/count.ts index 0b31d795..fd986175 100644 --- a/packages/orm/src/client/crud/operations/count.ts +++ b/packages/orm/src/client/crud/operations/count.ts @@ -38,15 +38,15 @@ export class CountOperationHandler extends BaseOperati query = query.select((eb) => Object.keys(parsedArgs.select!).map((key) => key === '_all' - ? eb.cast(eb.fn.countAll(), 'integer').as('_all') - : eb.cast(eb.fn.count(eb.ref(`${subQueryName}.${key}` as any)), 'integer').as(key), + ? this.dialect.castInt(eb.fn.countAll()).as('_all') + : this.dialect.castInt(eb.fn.count(eb.ref(`${subQueryName}.${key}`))).as(key), ), ); const result = await this.executeQuery(this.kysely, query, 'count'); return result.rows[0]; } else { // simple count all - query = query.select((eb) => eb.cast(eb.fn.countAll(), 'integer').as('count')); + query = query.select((eb) => this.dialect.castInt(eb.fn.countAll()).as('count')); const result = await this.executeQuery(this.kysely, query, 'count'); return (result.rows[0] as any).count as number; } diff --git a/packages/orm/src/client/crud/operations/create.ts b/packages/orm/src/client/crud/operations/create.ts index b58871c0..eeb0802d 100644 --- a/packages/orm/src/client/crud/operations/create.ts +++ b/packages/orm/src/client/crud/operations/create.ts @@ -62,7 +62,8 @@ export class CreateOperationHandler extends BaseOperat if (args === undefined) { return { count: 0 }; } - return this.createMany(this.kysely, this.model, args, false); + + return this.safeTransaction((tx) => this.createMany(tx, this.model, args, false)); } private async runCreateManyAndReturn(args?: CreateManyAndReturnArgs>) { diff --git a/packages/orm/src/client/crud/operations/delete.ts b/packages/orm/src/client/crud/operations/delete.ts index af9942a9..e0c3875b 100644 --- a/packages/orm/src/client/crud/operations/delete.ts +++ b/packages/orm/src/client/crud/operations/delete.ts @@ -33,9 +33,10 @@ export class DeleteOperationHandler extends BaseOperat }); } const deleteResult = await this.delete(tx, this.model, args.where, undefined, undefined, selectedFields); - if (deleteResult.rows.length === 0) { + if (!deleteResult.numAffectedRows) { throw createNotFoundError(this.model); } + return needReadBack ? preDeleteRead : deleteResult.rows[0]; }); @@ -53,7 +54,7 @@ export class DeleteOperationHandler extends BaseOperat async runDeleteMany(args: DeleteManyArgs> | undefined) { return await this.safeTransaction(async (tx) => { const result = await this.delete(tx, this.model, args?.where, args?.limit); - return { count: result.rows.length }; + return { count: Number(result.numAffectedRows ?? 0) }; }); } } diff --git a/packages/orm/src/client/crud/operations/group-by.ts b/packages/orm/src/client/crud/operations/group-by.ts index cae9f65f..ad429cf3 100644 --- a/packages/orm/src/client/crud/operations/group-by.ts +++ b/packages/orm/src/client/crud/operations/group-by.ts @@ -32,7 +32,7 @@ export class GroupByOperationHandler extends BaseOpera query = this.dialect.buildSkipTake(query, skip, take); // orderBy - query = this.dialect.buildOrderBy(query, this.model, this.model, parsedArgs.orderBy, negateOrderBy); + query = this.dialect.buildOrderBy(query, this.model, this.model, parsedArgs.orderBy, negateOrderBy, take); // having if (parsedArgs.having) { @@ -49,17 +49,17 @@ export class GroupByOperationHandler extends BaseOpera switch (key) { case '_count': { if (value === true) { - query = query.select((eb) => eb.cast(eb.fn.countAll(), 'integer').as('_count')); + query = query.select((eb) => this.dialect.castInt(eb.fn.countAll()).as('_count')); } else { Object.entries(value).forEach(([field, val]) => { if (val === true) { if (field === '_all') { query = query.select((eb) => - eb.cast(eb.fn.countAll(), 'integer').as(`_count._all`), + this.dialect.castInt(eb.fn.countAll()).as(`_count._all`), ); } else { query = query.select((eb) => - eb.cast(eb.fn.count(fieldRef(field)), 'integer').as(`${key}.${field}`), + this.dialect.castInt(eb.fn.count(fieldRef(field))).as(`${key}.${field}`), ); } } diff --git a/packages/orm/src/client/crud/operations/update.ts b/packages/orm/src/client/crud/operations/update.ts index 477f03f5..d8bd57b5 100644 --- a/packages/orm/src/client/crud/operations/update.ts +++ b/packages/orm/src/client/crud/operations/update.ts @@ -101,6 +101,7 @@ export class UpdateOperationHandler extends BaseOperat args.limit, true, undefined, + undefined, selectedFields, ); @@ -189,6 +190,11 @@ export class UpdateOperationHandler extends BaseOperat return baseResult; } + if (!this.dialect.supportsReturning) { + // if dialect doesn't support "returning", we always need to read back + return { needReadBack: true, selectedFields: undefined }; + } + // further check if we're not updating any non-relation fields, because if so, // SQL "returning" is not effective, we need to always read back diff --git a/packages/orm/src/client/executor/name-mapper.ts b/packages/orm/src/client/executor/name-mapper.ts index 37918875..ad6f1832 100644 --- a/packages/orm/src/client/executor/name-mapper.ts +++ b/packages/orm/src/client/executor/name-mapper.ts @@ -27,6 +27,9 @@ import { ValuesNode, } from 'kysely'; import type { EnumDef, EnumField, FieldDef, ModelDef, SchemaDef } from '../../schema'; +import type { ClientContract } from '../contract'; +import { getCrudDialect } from '../crud/dialects'; +import type { BaseCrudDialect } from '../crud/dialects/base-dialect'; import { extractFieldName, extractModelName, @@ -50,10 +53,12 @@ export class QueryNameMapper extends OperationNodeTransformer { private readonly modelToTableMap = new Map(); private readonly fieldToColumnMap = new Map(); private readonly scopes: Scope[] = []; + private readonly dialect: BaseCrudDialect; - constructor(private readonly schema: SchemaDef) { + constructor(private readonly client: ClientContract) { super(); - for (const [modelName, modelDef] of Object.entries(schema.models)) { + this.dialect = getCrudDialect(client.$schema, client.$options); + for (const [modelName, modelDef] of Object.entries(client.$schema.models)) { const mappedName = this.getMappedName(modelDef); if (mappedName) { this.modelToTableMap.set(modelName, mappedName); @@ -68,6 +73,10 @@ export class QueryNameMapper extends OperationNodeTransformer { } } + private get schema() { + return this.client.$schema; + } + // #region overrides protected override transformSelectQuery(node: SelectQueryNode) { @@ -761,7 +770,7 @@ export class QueryNameMapper extends OperationNodeTransformer { } // the explicit cast to "text" is needed to address postgres's case-when type inference issue - const finalExpr = caseWhen!.else(eb.cast(new ExpressionWrapper(node), 'text')).end(); + const finalExpr = caseWhen!.else(this.dialect.castText(new ExpressionWrapper(node))).end(); if (aliasName) { return finalExpr.as(aliasName).toOperationNode() as SelectionNodeChild; } else { diff --git a/packages/orm/src/client/executor/zenstack-query-executor.ts b/packages/orm/src/client/executor/zenstack-query-executor.ts index 9012cff8..4fa8b189 100644 --- a/packages/orm/src/client/executor/zenstack-query-executor.ts +++ b/packages/orm/src/client/executor/zenstack-query-executor.ts @@ -2,17 +2,23 @@ import { invariant } from '@zenstackhq/common-helpers'; import type { QueryId } from 'kysely'; import { AndNode, + ColumnNode, + ColumnUpdateNode, CompiledQuery, createQueryId, DefaultQueryExecutor, DeleteQueryNode, + expressionBuilder, InsertQueryNode, + PrimitiveValueListNode, ReturningNode, SelectionNode, SelectQueryNode, SingleConnectionProvider, TableNode, UpdateQueryNode, + ValueNode, + ValuesNode, WhereNode, type ConnectionProvider, type DatabaseConnection, @@ -27,9 +33,11 @@ import { match } from 'ts-pattern'; import type { ModelDef, SchemaDef, TypeDefDef } from '../../schema'; import { type ClientImpl } from '../client-impl'; import { TransactionIsolationLevel, type ClientContract } from '../contract'; +import { getCrudDialect } from '../crud/dialects'; +import type { BaseCrudDialect } from '../crud/dialects/base-dialect'; import { createDBQueryError, createInternalError, ORMError } from '../errors'; import type { AfterEntityMutationCallback, OnKyselyQueryCallback } from '../plugin'; -import { stripAlias } from '../query-utils'; +import { requireIdFields, stripAlias } from '../query-utils'; import { QueryNameMapper } from './name-mapper'; import type { ZenStackDriver } from './zenstack-driver'; @@ -41,8 +49,31 @@ type MutationInfo = { where: WhereNode | undefined; }; +type CallBeforeMutationHooksArgs = { + queryNode: OperationNode; + mutationInfo: MutationInfo; + loadBeforeMutationEntities: () => Promise[] | undefined>; + client: ClientContract; + queryId: QueryId; +}; + +type CallAfterMutationHooksArgs = { + queryResult: QueryResult; + queryNode: OperationNode; + mutationInfo: MutationInfo; + client: ClientContract; + filterFor: 'inTx' | 'outTx' | 'all'; + connection: DatabaseConnection; + queryId: QueryId; + beforeMutationEntities?: Record[]; + afterMutationEntities?: Record[]; +}; + export class ZenStackQueryExecutor extends DefaultQueryExecutor { + // #region constructor, fields and props + private readonly nameMapper: QueryNameMapper | undefined; + private readonly dialect: BaseCrudDialect; constructor( private client: ClientImpl, @@ -59,8 +90,10 @@ export class ZenStackQueryExecutor extends DefaultQueryExecutor { client.$schema.provider.type === 'postgresql' || // postgres queries need to be schema-qualified this.schemaHasMappedNames(client.$schema) ) { - this.nameMapper = new QueryNameMapper(client.$schema); + this.nameMapper = new QueryNameMapper(client as unknown as ClientContract); } + + this.dialect = getCrudDialect(client.$schema, client.$options); } private schemaHasMappedNames(schema: SchemaDef) { @@ -82,12 +115,24 @@ export class ZenStackQueryExecutor extends DefaultQueryExecutor { return this.client.$options; } - override executeQuery(compiledQuery: CompiledQuery) { + private get hasEntityMutationPlugins() { + return (this.client.$options.plugins ?? []).some((plugin) => plugin.onEntityMutation); + } + + private get hasEntityMutationPluginsWithAfterMutationHooks() { + return (this.client.$options.plugins ?? []).some((plugin) => plugin.onEntityMutation?.afterEntityMutation); + } + + // #endregion + + // #region main entry point + + override async executeQuery(compiledQuery: CompiledQuery) { // proceed with the query with kysely interceptors // if the query is a raw query, we need to carry over the parameters const queryParams = (compiledQuery as any).$raw ? compiledQuery.parameters : undefined; - return this.provideConnection(async (connection) => { + const result = await this.provideConnection(async (connection) => { let startedTx = false; try { // mutations are wrapped in tx if not already in one @@ -124,6 +169,8 @@ export class ZenStackQueryExecutor extends DefaultQueryExecutor { } } }); + + return this.ensureProperQueryResult(compiledQuery.query, result); } private async proceedQueryWithKyselyInterceptors( @@ -161,63 +208,39 @@ export class ZenStackQueryExecutor extends DefaultQueryExecutor { return result; } - private getMutationInfo(queryNode: MutationQueryNode): MutationInfo { - const model = this.getMutationModel(queryNode); - const { action, where } = match(queryNode) - .when(InsertQueryNode.is, () => ({ - action: 'create' as const, - where: undefined, - })) - .when(UpdateQueryNode.is, (node) => ({ - action: 'update' as const, - where: node.where, - })) - .when(DeleteQueryNode.is, (node) => ({ - action: 'delete' as const, - where: node.where, - })) - .exhaustive(); - - return { model, action, where }; - } - private async proceedQuery( connection: DatabaseConnection, query: RootOperationNode, parameters: readonly unknown[] | undefined, queryId: QueryId, ) { - let compiled: CompiledQuery | undefined; - if (this.suppressMutationHooks || !this.isMutationNode(query) || !this.hasEntityMutationPlugins) { // no need to handle mutation hooks, just proceed - const finalQuery = this.processNameMapping(query); - - // inherit the original queryId - compiled = this.compileQuery(finalQuery, queryId); - if (parameters) { - compiled = { ...compiled, parameters }; - } - return this.internalExecuteQuery(connection, compiled); + return this.internalExecuteQuery(query, connection, queryId, parameters); } - if ( + let preUpdateIds: Record[] | undefined; + const mutationModel = this.getMutationModel(query); + const needLoadAfterMutationEntities = (InsertQueryNode.is(query) || UpdateQueryNode.is(query)) && - this.hasEntityMutationPluginsWithAfterMutationHooks - ) { - // need to make sure the query node has "returnAll" for insert and update queries - // so that after-mutation hooks can get the mutated entities with all fields - query = { - ...query, - returning: ReturningNode.create([SelectionNode.createSelectAll()]), - }; - } - const finalQuery = this.processNameMapping(query); - - // inherit the original queryId - compiled = this.compileQuery(finalQuery, queryId); - if (parameters) { - compiled = { ...compiled, parameters }; + this.hasEntityMutationPluginsWithAfterMutationHooks; + + if (needLoadAfterMutationEntities) { + if (this.dialect.supportsReturning) { + // need to make sure the query node has "returnAll" for insert and update queries + // so that after-mutation hooks can get the mutated entities with all fields + query = { + ...query, + returning: ReturningNode.create([SelectionNode.createSelectAll()]), + }; + } else { + if (UpdateQueryNode.is(query)) { + // if we're updating and the dialect doesn't support RETURNING, need to load + // entity IDs before the update in so we can use them to load the entities + // after the update + preUpdateIds = await this.getPreUpdateIds(mutationModel, query, connection); + } + } } // the client passed to hooks needs to be in sync with current in-transaction @@ -226,163 +249,89 @@ export class ZenStackQueryExecutor extends DefaultQueryExecutor { const connectionClient = this.createClientForConnection(connection, currentlyInTx); - const mutationInfo = this.getMutationInfo(finalQuery); + const mutationInfo = this.getMutationInfo(query); // cache already loaded before-mutation entities let beforeMutationEntities: Record[] | undefined; const loadBeforeMutationEntities = async () => { if (beforeMutationEntities === undefined && (UpdateQueryNode.is(query) || DeleteQueryNode.is(query))) { - beforeMutationEntities = await this.loadEntities(mutationInfo.model, mutationInfo.where, connection); + beforeMutationEntities = await this.loadEntities( + mutationInfo.model, + mutationInfo.where, + connection, + undefined, + ); } return beforeMutationEntities; }; // call before mutation hooks - await this.callBeforeMutationHooks( - finalQuery, + await this.callBeforeMutationHooks({ + queryNode: query, mutationInfo, loadBeforeMutationEntities, - connectionClient, + client: connectionClient, queryId, - ); + }); + + // execute the final query + const result = await this.internalExecuteQuery(query, connection, queryId, parameters); + + let afterMutationEntities: Record[] | undefined; + if (needLoadAfterMutationEntities) { + afterMutationEntities = await this.loadAfterMutationEntities( + mutationInfo, + query, + result, + connection, + preUpdateIds, + ); + } - const result = await this.internalExecuteQuery(connection, compiled); + const baseArgs: CallAfterMutationHooksArgs = { + queryResult: result, + queryNode: query, + mutationInfo, + filterFor: 'all', + client: connectionClient, + connection, + queryId, + beforeMutationEntities, + afterMutationEntities, + }; if (!this.driver.isTransactionConnection(connection)) { // not in a transaction, just call all after-mutation hooks - await this.callAfterMutationHooks(result, finalQuery, mutationInfo, connectionClient, 'all', queryId); + await this.callAfterMutationHooks({ + ...baseArgs, + filterFor: 'all', + }); } else { // run after-mutation hooks that are requested to be run inside tx - await this.callAfterMutationHooks(result, finalQuery, mutationInfo, connectionClient, 'inTx', queryId); + await this.callAfterMutationHooks({ + ...baseArgs, + filterFor: 'inTx', + }); // register other after-mutation hooks to be run after the tx is committed this.driver.registerTransactionCommitCallback(connection, () => - this.callAfterMutationHooks(result, finalQuery, mutationInfo, connectionClient, 'outTx', queryId), + this.callAfterMutationHooks({ + ...baseArgs, + filterFor: 'outTx', + }), ); } return result; } - private processNameMapping(query: Node): Node { - return this.nameMapper?.transformNode(query) ?? query; - } - - private createClientForConnection(connection: DatabaseConnection, inTx: boolean) { - const innerExecutor = this.withConnectionProvider(new SingleConnectionProvider(connection)); - innerExecutor.suppressMutationHooks = true; - const innerClient = this.client.withExecutor(innerExecutor); - if (inTx) { - innerClient.forceTransaction(); - } - return innerClient as unknown as ClientContract; - } - - private get hasEntityMutationPlugins() { - return (this.client.$options.plugins ?? []).some((plugin) => plugin.onEntityMutation); - } - - private get hasEntityMutationPluginsWithAfterMutationHooks() { - return (this.client.$options.plugins ?? []).some((plugin) => plugin.onEntityMutation?.afterEntityMutation); - } - - private isMutationNode(queryNode: RootOperationNode): queryNode is MutationQueryNode { - return InsertQueryNode.is(queryNode) || UpdateQueryNode.is(queryNode) || DeleteQueryNode.is(queryNode); - } - - override withPlugin(plugin: KyselyPlugin) { - return new ZenStackQueryExecutor( - this.client, - this.driver, - this.compiler, - this.adapter, - this.connectionProvider, - [...this.plugins, plugin], - this.suppressMutationHooks, - ); - } - - override withPlugins(plugins: ReadonlyArray) { - return new ZenStackQueryExecutor( - this.client, - this.driver, - this.compiler, - this.adapter, - this.connectionProvider, - [...this.plugins, ...plugins], - this.suppressMutationHooks, - ); - } + // #endregion - override withPluginAtFront(plugin: KyselyPlugin) { - return new ZenStackQueryExecutor( - this.client, - this.driver, - this.compiler, - this.adapter, - this.connectionProvider, - [plugin, ...this.plugins], - this.suppressMutationHooks, - ); - } - - override withoutPlugins() { - return new ZenStackQueryExecutor( - this.client, - this.driver, - this.compiler, - this.adapter, - this.connectionProvider, - [], - this.suppressMutationHooks, - ); - } - - override withConnectionProvider(connectionProvider: ConnectionProvider) { - const newExecutor = new ZenStackQueryExecutor( - this.client, - this.driver, - this.compiler, - this.adapter, - connectionProvider, - this.plugins as KyselyPlugin[], - this.suppressMutationHooks, - ); - // replace client with a new one associated with the new executor - newExecutor.client = this.client.withExecutor(newExecutor); - return newExecutor; - } + // #region before and after mutation hooks - private getMutationModel(queryNode: OperationNode): string { - return match(queryNode) - .when(InsertQueryNode.is, (node) => { - invariant(node.into, 'InsertQueryNode must have an into clause'); - return node.into.table.identifier.name; - }) - .when(UpdateQueryNode.is, (node) => { - invariant(node.table, 'UpdateQueryNode must have a table'); - const { node: tableNode } = stripAlias(node.table); - invariant(TableNode.is(tableNode), 'UpdateQueryNode must use a TableNode'); - return tableNode.table.identifier.name; - }) - .when(DeleteQueryNode.is, (node) => { - invariant(node.from.froms.length === 1, 'Delete query must have exactly one from table'); - const { node: tableNode } = stripAlias(node.from.froms[0]!); - invariant(TableNode.is(tableNode), 'DeleteQueryNode must use a TableNode'); - return tableNode.table.identifier.name; - }) - .otherwise((node) => { - throw createInternalError(`Invalid query node: ${node}`); - }) as string; - } + private async callBeforeMutationHooks(args: CallBeforeMutationHooksArgs) { + const { queryNode, mutationInfo, loadBeforeMutationEntities, client, queryId } = args; - private async callBeforeMutationHooks( - queryNode: OperationNode, - mutationInfo: MutationInfo, - loadBeforeMutationEntities: () => Promise[] | undefined>, - client: ClientContract, - queryId: QueryId, - ) { if (this.options.plugins) { for (const plugin of this.options.plugins) { const onEntityMutation = plugin.onEntityMutation; @@ -402,14 +351,10 @@ export class ZenStackQueryExecutor extends DefaultQueryExecutor { } } - private async callAfterMutationHooks( - queryResult: QueryResult, - queryNode: OperationNode, - mutationInfo: MutationInfo, - client: ClientContract, - filterFor: 'inTx' | 'outTx' | 'all', - queryId: QueryId, - ) { + private async callAfterMutationHooks(args: CallAfterMutationHooksArgs) { + const { queryNode, mutationInfo, client, filterFor, queryId, beforeMutationEntities, afterMutationEntities } = + args; + const hooks: AfterEntityMutationCallback[] = []; // tsc perf @@ -434,46 +379,261 @@ export class ZenStackQueryExecutor extends DefaultQueryExecutor { return; } - const mutationModel = this.getMutationModel(queryNode); - - const loadAfterMutationEntities = async () => { - if (mutationInfo.action === 'delete') { - return undefined; - } else { - return queryResult.rows as Record[]; - } - }; - for (const hook of hooks) { await hook({ - model: mutationModel, + model: mutationInfo.model, action: mutationInfo.action, queryNode, - loadAfterMutationEntities, + loadAfterMutationEntities: () => Promise.resolve(afterMutationEntities), + beforeMutationEntities, client, queryId, }); } } + private async loadAfterMutationEntities( + mutationInfo: MutationInfo, + queryNode: OperationNode, + queryResult: QueryResult, + connection: DatabaseConnection, + preUpdateIds: Record[] | undefined, + ): Promise[] | undefined> { + if (mutationInfo.action === 'delete') { + return undefined; + } + + if (this.dialect.supportsReturning) { + // entities are returned in the query result + return queryResult.rows as Record[]; + } else { + const mutatedIds = InsertQueryNode.is(queryNode) + ? this.getInsertIds(mutationInfo.model, queryNode, queryResult) + : preUpdateIds; + + if (mutatedIds) { + const idFields = requireIdFields(this.client.$schema, mutationInfo.model); + const eb = expressionBuilder(); + const filter = eb( + // @ts-ignore + eb.refTuple(...idFields), + 'in', + mutatedIds.map((idObj) => + eb.tuple( + // @ts-ignore + ...idFields.map((idField) => eb.val(idObj[idField] as any)), + ), + ), + ); + const entities = await this.loadEntities( + mutationInfo.model, + WhereNode.create(filter.toOperationNode()), + connection, + undefined, + ); + return entities; + } else { + console.warn( + `Unable to load after-mutation entities for hooks: model "${mutationInfo.model}", operation "${mutationInfo.action}". +This happens when the following conditions are met: + +1. The database does not support RETURNING clause for INSERT/UPDATE, e.g., MySQL. +2. The mutation creates or updates multiple entities at once. +3. For create: the model does not have all ID fields explicitly set in the mutation data. +4. For update: the mutation modifies ID fields. + +In such cases, ZenStack cannot reliably determine the IDs of the mutated entities to reload them. +`, + ); + return []; + } + } + } + + private async getPreUpdateIds(mutationModel: string, query: UpdateQueryNode, connection: DatabaseConnection) { + // Get the ID fields for this model + const idFields = requireIdFields(this.client.$schema, mutationModel); + + // Check if the update modifies any ID fields + if (query.updates) { + for (const update of query.updates) { + if (ColumnUpdateNode.is(update)) { + // Extract the column name from the update + const columnNode = update.column; + if (ColumnNode.is(columnNode)) { + const columnName = columnNode.column.name; + if (idFields.includes(columnName)) { + // ID field is being updated, return undefined + return undefined; + } + } + } + } + } + + // No ID fields are being updated, load the entities + return await this.loadEntities(this.getMutationModel(query), query.where, connection, idFields); + } + + private getInsertIds( + mutationModel: string, + query: InsertQueryNode, + queryResult: QueryResult, + ): Record[] | undefined { + const idFields = requireIdFields(this.client.$schema, mutationModel); + + if ( + InsertQueryNode.is(query) && + queryResult.numAffectedRows === 1n && + queryResult.insertId && + idFields.length === 1 + ) { + // single row creation, return the insertId directly + return [ + { + [idFields[0]!]: queryResult.insertId, + }, + ]; + } + + const columns = query.columns; + if (!columns) { + return undefined; + } + + const values = query.values; + if (!values || !ValuesNode.is(values)) { + return undefined; + } + + // Extract ID values for each row + const allIds: Record[] = []; + for (const valuesItem of values.values) { + const rowIds: Record = {}; + + if (PrimitiveValueListNode.is(valuesItem)) { + // PrimitiveValueListNode case + invariant(valuesItem.values.length === columns.length, 'Values count must match columns count'); + for (const idField of idFields) { + const colIndex = columns.findIndex((col) => col.column.name === idField); + if (colIndex === -1) { + // ID field not included in insert columns + return undefined; + } + rowIds[idField] = valuesItem.values[colIndex]; + } + } else { + // ValueListNode case + invariant(valuesItem.values.length === columns.length, 'Values count must match columns count'); + for (const idField of idFields) { + const colIndex = columns.findIndex((col) => col.column.name === idField); + if (colIndex === -1) { + // ID field not included in insert columns + return undefined; + } + const valueNode = valuesItem.values[colIndex]; + if (!valueNode || !ValueNode.is(valueNode)) { + // not a literal value + return undefined; + } + rowIds[idField] = valueNode.value; + } + } + + allIds.push(rowIds); + } + + return allIds; + } + private async loadEntities( model: string, where: WhereNode | undefined, connection: DatabaseConnection, + fieldsToLoad: readonly string[] | undefined, ): Promise[]> { - const selectQuery = this.kysely.selectFrom(model).selectAll(); + let selectQuery = this.kysely.selectFrom(model); + if (fieldsToLoad) { + selectQuery = selectQuery.select(fieldsToLoad); + } else { + selectQuery = selectQuery.selectAll(); + } let selectQueryNode = selectQuery.toOperationNode() as SelectQueryNode; selectQueryNode = { ...selectQueryNode, where: this.andNodes(selectQueryNode.where, where), }; - const compiled = this.compileQuery(selectQueryNode, createQueryId()); // execute the query directly with the given connection to avoid triggering // any other side effects - const result = await this.internalExecuteQuery(connection, compiled); + const result = await this.internalExecuteQuery(selectQueryNode, connection); return result.rows as Record[]; } + // #endregion + + // #region utilities + + private getMutationInfo(queryNode: MutationQueryNode): MutationInfo { + const model = this.getMutationModel(queryNode); + const { action, where } = match(queryNode) + .when(InsertQueryNode.is, () => ({ + action: 'create' as const, + where: undefined, + })) + .when(UpdateQueryNode.is, (node) => ({ + action: 'update' as const, + where: node.where, + })) + .when(DeleteQueryNode.is, (node) => ({ + action: 'delete' as const, + where: node.where, + })) + .exhaustive(); + + return { model, action, where }; + } + + private isMutationNode(queryNode: RootOperationNode): queryNode is MutationQueryNode { + return InsertQueryNode.is(queryNode) || UpdateQueryNode.is(queryNode) || DeleteQueryNode.is(queryNode); + } + + private getMutationModel(queryNode: OperationNode): string { + return match(queryNode) + .when(InsertQueryNode.is, (node) => { + invariant(node.into, 'InsertQueryNode must have an into clause'); + return node.into.table.identifier.name; + }) + .when(UpdateQueryNode.is, (node) => { + invariant(node.table, 'UpdateQueryNode must have a table'); + const { node: tableNode } = stripAlias(node.table); + invariant(TableNode.is(tableNode), 'UpdateQueryNode must use a TableNode'); + return tableNode.table.identifier.name; + }) + .when(DeleteQueryNode.is, (node) => { + invariant(node.from.froms.length === 1, 'Delete query must have exactly one from table'); + const { node: tableNode } = stripAlias(node.from.froms[0]!); + invariant(TableNode.is(tableNode), 'DeleteQueryNode must use a TableNode'); + return tableNode.table.identifier.name; + }) + .otherwise((node) => { + throw createInternalError(`Invalid query node: ${node}`); + }) as string; + } + + private processNameMapping(query: Node): Node { + return this.nameMapper?.transformNode(query) ?? query; + } + + private createClientForConnection(connection: DatabaseConnection, inTx: boolean) { + const innerExecutor = this.withConnectionProvider(new SingleConnectionProvider(connection)); + innerExecutor.suppressMutationHooks = true; + const innerClient = this.client.withExecutor(innerExecutor); + if (inTx) { + innerClient.forceTransaction(); + } + return innerClient as unknown as ClientContract; + } + private andNodes(condition1: WhereNode | undefined, condition2: WhereNode | undefined) { if (condition1 && condition2) { return WhereNode.create(AndNode.create(condition1, condition2)); @@ -484,9 +644,24 @@ export class ZenStackQueryExecutor extends DefaultQueryExecutor { } } - private async internalExecuteQuery(connection: DatabaseConnection, compiledQuery: CompiledQuery) { + private async internalExecuteQuery( + query: RootOperationNode, + connection: DatabaseConnection, + queryId?: QueryId, + parameters?: readonly unknown[], + ) { + // no need to handle mutation hooks, just proceed + const finalQuery = this.processNameMapping(query); + + // inherit the original queryId + let compiledQuery = this.compileQuery(finalQuery, queryId ?? createQueryId()); + if (parameters) { + compiledQuery = { ...compiledQuery, parameters: parameters }; + } + try { - return await connection.executeQuery(compiledQuery); + const result = await connection.executeQuery(compiledQuery); + return this.ensureProperQueryResult(compiledQuery.query, result); } catch (err) { throw createDBQueryError( `Failed to execute query: ${err}`, @@ -496,4 +671,88 @@ export class ZenStackQueryExecutor extends DefaultQueryExecutor { ); } } + + private ensureProperQueryResult(query: RootOperationNode, result: QueryResult) { + let finalResult = result; + + if (this.isMutationNode(query)) { + // Kysely dialects don't consistently set numAffectedRows, so we fix it here + // to simplify the consumer's code + finalResult = { + ...result, + numAffectedRows: result.numAffectedRows ?? BigInt(result.rows.length), + }; + } + + return finalResult; + } + + // #endregion + + // #region other overrides + + override withPlugin(plugin: KyselyPlugin) { + return new ZenStackQueryExecutor( + this.client, + this.driver, + this.compiler, + this.adapter, + this.connectionProvider, + [...this.plugins, plugin], + this.suppressMutationHooks, + ); + } + + override withPlugins(plugins: ReadonlyArray) { + return new ZenStackQueryExecutor( + this.client, + this.driver, + this.compiler, + this.adapter, + this.connectionProvider, + [...this.plugins, ...plugins], + this.suppressMutationHooks, + ); + } + + override withPluginAtFront(plugin: KyselyPlugin) { + return new ZenStackQueryExecutor( + this.client, + this.driver, + this.compiler, + this.adapter, + this.connectionProvider, + [plugin, ...this.plugins], + this.suppressMutationHooks, + ); + } + + override withoutPlugins() { + return new ZenStackQueryExecutor( + this.client, + this.driver, + this.compiler, + this.adapter, + this.connectionProvider, + [], + this.suppressMutationHooks, + ); + } + + override withConnectionProvider(connectionProvider: ConnectionProvider) { + const newExecutor = new ZenStackQueryExecutor( + this.client, + this.driver, + this.compiler, + this.adapter, + connectionProvider, + this.plugins as KyselyPlugin[], + this.suppressMutationHooks, + ); + // replace client with a new one associated with the new executor + newExecutor.client = this.client.withExecutor(newExecutor); + return newExecutor; + } + + // #endregion } diff --git a/packages/orm/src/client/functions.ts b/packages/orm/src/client/functions.ts index 5690a1e4..1f3ff6ce 100644 --- a/packages/orm/src/client/functions.ts +++ b/packages/orm/src/client/functions.ts @@ -53,8 +53,11 @@ const textMatch = ( op = 'like'; } + // coalesce to empty string to consistently handle nulls across databases + searchExpr = eb.fn.coalesce(searchExpr, sql.lit('')); + // escape special characters in search string - const escapedSearch = sql`REPLACE(REPLACE(REPLACE(CAST(${searchExpr} as text), '\\', '\\\\'), '%', '\\%'), '_', '\\_')`; + const escapedSearch = sql`REPLACE(REPLACE(REPLACE(${dialect.castText(searchExpr)}, ${sql.val('\\')}, ${sql.val('\\\\')}), ${sql.val('%')}, ${sql.val('\\%')}), ${sql.val('_')}, ${sql.val('\\_')})`; searchExpr = match(method) .with('contains', () => eb.fn('CONCAT', [sql.lit('%'), escapedSearch, sql.lit('%')])) @@ -62,7 +65,7 @@ const textMatch = ( .with('endsWith', () => eb.fn('CONCAT', [sql.lit('%'), escapedSearch])) .exhaustive(); - return sql`${fieldExpr} ${sql.raw(op)} ${searchExpr} escape '\\'`; + return sql`${fieldExpr} ${sql.raw(op)} ${searchExpr} escape ${sql.val('\\')}`; }; export const has: ZModelFunction = (eb, args) => { diff --git a/packages/orm/src/client/helpers/schema-db-pusher.ts b/packages/orm/src/client/helpers/schema-db-pusher.ts index 01b265c4..38df33ce 100644 --- a/packages/orm/src/client/helpers/schema-db-pusher.ts +++ b/packages/orm/src/client/helpers/schema-db-pusher.ts @@ -1,5 +1,5 @@ import { invariant } from '@zenstackhq/common-helpers'; -import { CreateTableBuilder, sql, type ColumnDataType, type OnModifyForeignAction } from 'kysely'; +import { CreateTableBuilder, sql, type ColumnDataType, type OnModifyForeignAction, type RawBuilder } from 'kysely'; import toposort from 'toposort'; import { match } from 'ts-pattern'; import { @@ -13,6 +13,11 @@ import { import type { ToKysely } from '../query-builder'; import { requireModel } from '../query-utils'; +/** + * This class is for testing purposes only. It should never be used in production. + * + * @private + */ export class SchemaDbPusher { constructor( private readonly schema: Schema, @@ -21,7 +26,7 @@ export class SchemaDbPusher { async push() { await this.kysely.transaction().execute(async (tx) => { - if (this.schema.enums && this.schema.provider.type === 'postgresql') { + if (this.schema.enums && this.providerSupportsNativeEnum) { for (const [name, enumDef] of Object.entries(this.schema.enums)) { let enumValues: string[]; if (enumDef.fields) { @@ -57,6 +62,10 @@ export class SchemaDbPusher { }); } + private get providerSupportsNativeEnum() { + return ['postgresql'].includes(this.schema.provider.type); + } + private sortModels(models: ModelDef[]): ModelDef[] { const graph: [ModelDef, ModelDef | undefined][] = []; @@ -114,7 +123,7 @@ export class SchemaDbPusher { // create fk constraint const baseModelDef = requireModel(this.schema, modelDef.baseModel); table = table.addForeignKeyConstraint( - `fk_${modelDef.baseModel}_delegate`, + `fk_${modelDef.baseModel}_${modelDef.name}_delegate`, baseModelDef.idFields as string[], modelDef.baseModel, baseModelDef.idFields as string[], @@ -213,13 +222,25 @@ export class SchemaDbPusher { } // @default - if (fieldDef.default !== undefined) { + if (fieldDef.default !== undefined && this.isDefaultValueSupportedForType(fieldDef.type)) { if (typeof fieldDef.default === 'object' && 'kind' in fieldDef.default) { if (ExpressionUtils.isCall(fieldDef.default) && fieldDef.default.function === 'now') { - col = col.defaultTo(sql`CURRENT_TIMESTAMP`); + col = + this.schema.provider.type === 'mysql' + ? col.defaultTo(sql`CURRENT_TIMESTAMP(3)`) + : col.defaultTo(sql`CURRENT_TIMESTAMP`); } } else { - col = col.defaultTo(fieldDef.default); + if ( + this.schema.provider.type === 'mysql' && + fieldDef.type === 'DateTime' && + typeof fieldDef.default === 'string' + ) { + const defaultValue = new Date(fieldDef.default).toISOString().replace('Z', '+00:00'); + col = col.defaultTo(defaultValue); + } else { + col = col.defaultTo(fieldDef.default); + } } } @@ -233,7 +254,7 @@ export class SchemaDbPusher { col = col.notNull(); } - if (this.isAutoIncrement(fieldDef) && this.schema.provider.type === 'sqlite') { + if (this.isAutoIncrement(fieldDef) && this.columnSupportsAutoIncrement()) { col = col.autoIncrement(); } @@ -241,9 +262,43 @@ export class SchemaDbPusher { }); } + private isDefaultValueSupportedForType(type: string) { + return match(this.schema.provider.type) + .with('postgresql', () => true) + .with('sqlite', () => true) + .with('mysql', () => !['Json', 'Bytes'].includes(type)) + .exhaustive(); + } + private mapFieldType(fieldDef: FieldDef) { if (this.schema.enums?.[fieldDef.type]) { - return this.schema.provider.type === 'postgresql' ? sql.ref(fieldDef.type) : 'text'; + if (this.schema.provider.type === 'postgresql') { + return sql.ref(fieldDef.type); + } else if (this.schema.provider.type === 'mysql') { + // MySQL requires inline ENUM definition + const enumDef = this.schema.enums[fieldDef.type]!; + let enumValues: string[]; + if (enumDef.fields) { + enumValues = Object.values(enumDef.fields).map((f) => { + const mapAttr = f.attributes?.find((a) => a.name === '@map'); + if (!mapAttr || !mapAttr.args?.[0]) { + return f.name; + } else { + const mappedName = ExpressionUtils.getLiteralValue(mapAttr.args[0].value); + invariant( + mappedName && typeof mappedName === 'string', + `Invalid @map attribute for enum field ${f.name}`, + ); + return mappedName; + } + }); + } else { + enumValues = Object.values(enumDef.values); + } + return sql.raw(`enum(${enumValues.map((v) => `'${v}'`).join(', ')})`); + } else { + return 'text'; + } } if (this.isAutoIncrement(fieldDef) && this.schema.provider.type === 'postgresql') { @@ -251,20 +306,20 @@ export class SchemaDbPusher { } if (this.isCustomType(fieldDef.type)) { - return 'jsonb'; + return this.jsonType; } const type = fieldDef.type as BuiltinType; - const result = match(type) - .with('String', () => 'text') - .with('Boolean', () => 'boolean') - .with('Int', () => 'integer') - .with('Float', () => 'real') - .with('BigInt', () => 'bigint') - .with('Decimal', () => 'decimal') - .with('DateTime', () => 'timestamp') - .with('Bytes', () => (this.schema.provider.type === 'postgresql' ? 'bytea' : 'blob')) - .with('Json', () => 'jsonb') + const result = match>(type) + .with('String', () => this.stringType) + .with('Boolean', () => this.booleanType) + .with('Int', () => this.intType) + .with('Float', () => this.floatType) + .with('BigInt', () => this.bigIntType) + .with('Decimal', () => this.decimalType) + .with('DateTime', () => this.dateTimeType) + .with('Bytes', () => this.bytesType) + .with('Json', () => this.jsonType) .otherwise(() => { throw new Error(`Unsupported field type: ${type}`); }); @@ -339,4 +394,63 @@ export class SchemaDbPusher { .with('SetDefault', () => 'set default') .exhaustive(); } + + // #region Type mappings and capabilities + + private get jsonType(): ColumnDataType { + return match(this.schema.provider.type) + .with('mysql', () => 'json') + .otherwise(() => 'jsonb'); + } + + private get bytesType(): ColumnDataType { + return match(this.schema.provider.type) + .with('postgresql', () => 'bytea') + .with('mysql', () => 'blob') + .otherwise(() => 'blob'); + } + + private get stringType() { + return match>(this.schema.provider.type) + .with('mysql', () => sql.raw('varchar(255)')) + .otherwise(() => 'text'); + } + + private get booleanType() { + return match>(this.schema.provider.type) + .with('mysql', () => sql.raw('tinyint(1)')) + .otherwise(() => 'boolean'); + } + + private get intType(): ColumnDataType { + return 'integer'; + } + + private get floatType() { + return match>(this.schema.provider.type) + .with('mysql', () => sql.raw('double')) + .otherwise(() => 'real'); + } + + private get bigIntType(): ColumnDataType { + return 'bigint'; + } + + private get decimalType() { + return match>(this.schema.provider.type) + .with('mysql', () => sql.raw('decimal(65, 30)')) + .otherwise(() => 'decimal'); + } + + private get dateTimeType() { + return match>(this.schema.provider.type) + .with('mysql', () => sql.raw('datetime(3)')) + .otherwise(() => 'timestamp'); + } + + private columnSupportsAutoIncrement() { + return ['sqlite', 'mysql'].includes(this.schema.provider.type); + } + + // #endregion } diff --git a/packages/orm/src/client/plugin.ts b/packages/orm/src/client/plugin.ts index 81dff0ec..e531242e 100644 --- a/packages/orm/src/client/plugin.ts +++ b/packages/orm/src/client/plugin.ts @@ -250,6 +250,12 @@ export type PluginAfterEntityMutationArgs = MutationHo */ loadAfterMutationEntities(): Promise[] | undefined>; + /** + * The entities before mutation. Only available if `beforeEntityMutation` hook is provided and + * the `loadBeforeMutationEntities` function is called in that hook. + */ + beforeMutationEntities?: Record[]; + /** * The ZenStack client you can use to perform additional operations. * See {@link EntityMutationHooksDef.runAfterMutationWithinTransaction} for detailed transaction behavior. diff --git a/packages/orm/src/client/query-utils.ts b/packages/orm/src/client/query-utils.ts index 51096722..66e41c40 100644 --- a/packages/orm/src/client/query-utils.ts +++ b/packages/orm/src/client/query-utils.ts @@ -12,7 +12,6 @@ import { match } from 'ts-pattern'; import { ExpressionUtils, type FieldDef, type GetModels, type ModelDef, type SchemaDef } from '../schema'; import { extractFields } from '../utils/object-utils'; import type { AGGREGATE_OPERATORS } from './constants'; -import type { OrderBy } from './crud-types'; import { createInternalError } from './errors'; export function hasModel(schema: SchemaDef, model: string) { @@ -70,6 +69,29 @@ export function requireField(schema: SchemaDef, modelOrType: string, field: stri throw createInternalError(`Model or type "${modelOrType}" not found in schema`, modelOrType); } +/** + * Gets all model fields, by default non-relation, non-computed, non-inherited fields only. + */ +export function getModelFields( + schema: SchemaDef, + model: string, + options?: { relations?: boolean; computed?: boolean; inherited?: boolean }, +) { + const modelDef = requireModel(schema, model); + return Object.values(modelDef.fields).filter((f) => { + if (f.relation && !options?.relations) { + return false; + } + if (f.computed && !options?.computed) { + return false; + } + if (f.originModel && !options?.inherited) { + return false; + } + return true; + }); +} + export function getIdFields(schema: SchemaDef, model: GetModels) { const modelDef = getModel(schema, model); return modelDef?.idFields; @@ -222,9 +244,9 @@ export function buildJoinPairs( }); } -export function makeDefaultOrderBy(schema: SchemaDef, model: string) { +export function makeDefaultOrderBy(schema: SchemaDef, model: string) { const idFields = requireIdFields(schema, model); - return idFields.map((f) => ({ [f]: 'asc' }) as OrderBy, true, false>); + return idFields.map((f) => ({ [f]: 'asc' }) as const); } export function getManyToManyRelation(schema: SchemaDef, model: string, field: string) { diff --git a/packages/orm/src/dialects/mysql.ts b/packages/orm/src/dialects/mysql.ts new file mode 100644 index 00000000..d641eb03 --- /dev/null +++ b/packages/orm/src/dialects/mysql.ts @@ -0,0 +1 @@ +export { MysqlDialect, type MysqlDialectConfig } from 'kysely'; diff --git a/packages/orm/tsup.config.ts b/packages/orm/tsup.config.ts index 14ebdca2..8318a049 100644 --- a/packages/orm/tsup.config.ts +++ b/packages/orm/tsup.config.ts @@ -7,6 +7,7 @@ export default defineConfig({ helpers: 'src/helpers.ts', 'dialects/sqlite': 'src/dialects/sqlite.ts', 'dialects/postgres': 'src/dialects/postgres.ts', + 'dialects/mysql': 'src/dialects/mysql.ts', 'dialects/sql.js': 'src/dialects/sql.js/index.ts', }, outDir: 'dist', diff --git a/packages/plugins/policy/src/expression-transformer.ts b/packages/plugins/policy/src/expression-transformer.ts index 7977ccb2..f33de0ea 100644 --- a/packages/plugins/policy/src/expression-transformer.ts +++ b/packages/plugins/policy/src/expression-transformer.ts @@ -545,7 +545,7 @@ export class ExpressionTransformer { } else if (value === false) { return falseNode(this.dialect); } else { - const transformed = this.dialect.transformPrimitive(value, type, false) ?? null; + const transformed = this.dialect.transformInput(value, type, false) ?? null; if (!Array.isArray(transformed)) { // simple primitives can be immediate values return ValueNode.createImmediate(transformed); diff --git a/packages/plugins/policy/src/functions.ts b/packages/plugins/policy/src/functions.ts index a42de65b..1cb8a95c 100644 --- a/packages/plugins/policy/src/functions.ts +++ b/packages/plugins/policy/src/functions.ts @@ -94,9 +94,14 @@ export const check: ZModelFunction = ( // build the final nested select that evaluates the policy condition const result = eb - .selectFrom(relationModel) - .where(joinCondition) - .select(new ExpressionWrapper(policyCondition).as('$condition')); + .selectFrom( + eb + .selectFrom(relationModel) + .where(joinCondition) + .select(new ExpressionWrapper(policyCondition).as('$condition')) + .as('$sub'), + ) + .selectAll(); return result; }; diff --git a/packages/plugins/policy/src/policy-handler.ts b/packages/plugins/policy/src/policy-handler.ts index 54169064..ec12d1a7 100644 --- a/packages/plugins/policy/src/policy-handler.ts +++ b/packages/plugins/policy/src/policy-handler.ts @@ -16,7 +16,6 @@ import { expressionBuilder, ExpressionWrapper, FromNode, - FunctionNode, IdentifierNode, InsertQueryNode, JoinNode, @@ -24,7 +23,6 @@ import { OperatorNode, ParensNode, PrimitiveValueListNode, - RawNode, ReferenceNode, ReturningNode, SelectAllNode, @@ -33,10 +31,10 @@ import { sql, TableNode, UpdateQueryNode, - ValueListNode, ValueNode, ValuesNode, WhereNode, + type Expression as KyselyExpression, type OperationNode, type QueryResult, type RootOperationNode, @@ -67,6 +65,7 @@ type FieldLevelPolicyOperations = Exclude; export class PolicyHandler extends OperationNodeTransformer { private readonly dialect: BaseCrudDialect; + private readonly eb = expressionBuilder(); constructor(private readonly client: ClientContract) { super(); @@ -112,11 +111,17 @@ export class PolicyHandler extends OperationNodeTransf } // post-update: load before-update entities if needed - const hasPostUpdatePolicies = UpdateQueryNode.is(node) && this.hasPostUpdatePolicies(mutationModel); - + const needsPostUpdateCheck = UpdateQueryNode.is(node) && this.hasPostUpdatePolicies(mutationModel); let beforeUpdateInfo: Awaited> | undefined; - if (hasPostUpdatePolicies) { - beforeUpdateInfo = await this.loadBeforeUpdateEntities(mutationModel, node.where, proceed); + if (needsPostUpdateCheck) { + beforeUpdateInfo = await this.loadBeforeUpdateEntities( + mutationModel, + node.where, + proceed, + // force load pre-update entities if dialect doesn't support returning, + // so we can rely on pre-update ids to read back updated entities + !this.dialect.supportsReturning, + ); } // #endregion @@ -129,86 +134,12 @@ export class PolicyHandler extends OperationNodeTransf // #region Post mutation work - if (hasPostUpdatePolicies && result.rows.length > 0) { - // verify if before-update rows and post-update rows still id-match - if (beforeUpdateInfo) { - invariant(beforeUpdateInfo.rows.length === result.rows.length); - const idFields = QueryUtils.requireIdFields(this.client.$schema, mutationModel); - for (const postRow of result.rows) { - const beforeRow = beforeUpdateInfo.rows.find((r) => idFields.every((f) => r[f] === postRow[f])); - if (!beforeRow) { - throw createRejectedByPolicyError( - mutationModel, - RejectedByPolicyReason.OTHER, - 'Before-update and after-update rows do not match by id. If you have post-update policies on a model, updating id fields is not supported.', - ); - } - } - } - - // entities updated filter - const idConditions = this.buildIdConditions(mutationModel, result.rows); - - // post-update policy filter - const postUpdateFilter = this.buildPolicyFilter(mutationModel, undefined, 'post-update'); - - // read the post-update row with filter applied - - const eb = expressionBuilder(); - - // create a `SELECT column1 as field1, column2 as field2, ... FROM (VALUES (...))` table for before-update rows - const beforeUpdateTable: SelectQueryNode | undefined = beforeUpdateInfo - ? { - kind: 'SelectQueryNode', - from: FromNode.create([ - ParensNode.create( - ValuesNode.create( - beforeUpdateInfo!.rows.map((r) => - PrimitiveValueListNode.create(beforeUpdateInfo!.fields.map((f) => r[f])), - ), - ), - ), - ]), - selections: beforeUpdateInfo.fields.map((name, index) => { - const def = QueryUtils.requireField(this.client.$schema, mutationModel, name); - const castedColumnRef = - sql`CAST(${eb.ref(`column${index + 1}`)} as ${sql.raw(this.dialect.getFieldSqlType(def))})`.as( - name, - ); - return SelectionNode.create(castedColumnRef.toOperationNode()); - }), - } - : undefined; - - const postUpdateQuery = eb - .selectFrom(mutationModel) - .select(() => [eb(eb.fn('COUNT', [eb.lit(1)]), '=', result.rows.length).as('$condition')]) - .where(() => new ExpressionWrapper(conjunction(this.dialect, [idConditions, postUpdateFilter]))) - .$if(!!beforeUpdateInfo, (qb) => - qb.leftJoin( - () => new ExpressionWrapper(beforeUpdateTable!).as('$before'), - (join) => { - const idFields = QueryUtils.requireIdFields(this.client.$schema, mutationModel); - return idFields.reduce( - (acc, f) => acc.onRef(`${mutationModel}.${f}`, '=', `$before.${f}`), - join, - ); - }, - ), - ); - - const postUpdateResult = await proceed(postUpdateQuery.toOperationNode()); - if (!postUpdateResult.rows[0]?.$condition) { - throw createRejectedByPolicyError( - mutationModel, - RejectedByPolicyReason.NO_ACCESS, - 'some or all updated rows failed to pass post-update policy check', - ); - } - - // #endregion + if ((result.numAffectedRows ?? 0) > 0 && needsPostUpdateCheck) { + await this.postUpdateCheck(mutationModel, beforeUpdateInfo, result, proceed); } + // #endregion + // #region Read back if (!node.returning || this.onlyReturningId(node)) { @@ -272,13 +203,13 @@ export class PolicyHandler extends OperationNodeTransf // build a query to count rows that will be rejected by field-level policies // `SELECT COALESCE(SUM((not ) as integer), 0) AS $filteredCount WHERE AND ` - const preUpdateCheckQuery = expressionBuilder() + const preUpdateCheckQuery = this.eb .selectFrom(mutationModel) .select((eb) => eb.fn .coalesce( eb.fn.sum( - eb.cast(new ExpressionWrapper(logicalNot(this.dialect, fieldLevelFilter)), 'integer'), + this.dialect.castInt(new ExpressionWrapper(logicalNot(this.dialect, fieldLevelFilter))), ), eb.lit(0), ) @@ -296,6 +227,105 @@ export class PolicyHandler extends OperationNodeTransf } } + private async postUpdateCheck( + model: string, + beforeUpdateInfo: Awaited>, + updateResult: QueryResult, + proceed: ProceedKyselyQueryFunction, + ) { + let postUpdateRows: Record[]; + if (this.dialect.supportsReturning) { + // if dialect supports returning, use returned rows directly + postUpdateRows = updateResult.rows; + } else { + // otherwise, need to read back updated rows using pre-update ids + + invariant(beforeUpdateInfo, 'beforeUpdateInfo must be defined for dialects not supporting returning'); + + const idConditions = this.buildIdConditions(model, beforeUpdateInfo!.rows); + const idFields = QueryUtils.requireIdFields(this.client.$schema, model); + const postUpdateQuery: SelectQueryNode = { + kind: 'SelectQueryNode', + from: FromNode.create([TableNode.create(model)]), + where: WhereNode.create(idConditions), + selections: idFields.map((field) => SelectionNode.create(ColumnNode.create(field))), + }; + const postUpdateQueryResult = await proceed(postUpdateQuery); + postUpdateRows = postUpdateQueryResult.rows; + } + + if (beforeUpdateInfo) { + // verify if before-update rows and post-update rows still id-match + if (beforeUpdateInfo.rows.length !== postUpdateRows.length) { + throw createRejectedByPolicyError( + model, + RejectedByPolicyReason.OTHER, + 'Before-update and after-update rows do not match. If you have post-update policies on a model, updating id fields is not supported.', + ); + } + const idFields = QueryUtils.requireIdFields(this.client.$schema, model); + for (const postRow of postUpdateRows) { + const beforeRow = beforeUpdateInfo.rows.find((r) => idFields.every((f) => r[f] === postRow[f])); + if (!beforeRow) { + throw createRejectedByPolicyError( + model, + RejectedByPolicyReason.OTHER, + 'Before-update and after-update rows do not match. If you have post-update policies on a model, updating id fields is not supported.', + ); + } + } + } + + // entities updated filter + const idConditions = this.buildIdConditions(model, postUpdateRows); + + // post-update policy filter + const postUpdateFilter = this.buildPolicyFilter(model, undefined, 'post-update'); + + // read the post-update row with filter applied + + const eb = expressionBuilder(); + + // before update table is joined if fields from `before()` are used in post-update policies + const needsBeforeUpdateJoin = !!beforeUpdateInfo?.fields; + + let beforeUpdateTable: SelectQueryNode | undefined = undefined; + + if (needsBeforeUpdateJoin) { + // create a `SELECT column1 as field1, column2 as field2, ... FROM (VALUES (...))` table for before-update rows + const fieldDefs = beforeUpdateInfo.fields!.map((name) => + QueryUtils.requireField(this.client.$schema, model, name), + ); + const rows = beforeUpdateInfo.rows.map((r) => beforeUpdateInfo!.fields!.map((f) => r[f])); + beforeUpdateTable = this.dialect.buildValuesTableSelect(fieldDefs, rows).toOperationNode(); + } + + const postUpdateQuery = eb + .selectFrom(model) + .select(() => [ + eb(eb.fn('COUNT', [eb.lit(1)]), '=', Number(updateResult.numAffectedRows ?? 0)).as('$condition'), + ]) + .where(() => new ExpressionWrapper(conjunction(this.dialect, [idConditions, postUpdateFilter]))) + .$if(needsBeforeUpdateJoin, (qb) => + qb.leftJoin( + () => new ExpressionWrapper(beforeUpdateTable!).as('$before'), + (join) => { + const idFields = QueryUtils.requireIdFields(this.client.$schema, model); + return idFields.reduce((acc, f) => acc.onRef(`${model}.${f}`, '=', `$before.${f}`), join); + }, + ), + ); + + const postUpdateResult = await proceed(postUpdateQuery.toOperationNode()); + if (!postUpdateResult.rows[0]?.$condition) { + throw createRejectedByPolicyError( + model, + RejectedByPolicyReason.NO_ACCESS, + 'some or all updated rows failed to pass post-update policy check', + ); + } + } + // #endregion // #region Transformations @@ -395,8 +425,9 @@ export class PolicyHandler extends OperationNodeTransf protected override transformInsertQuery(node: InsertQueryNode) { // pre-insert check is done in `handle()` - let onConflict = node.onConflict; + let processedNode = node; + let onConflict = node.onConflict; if (onConflict?.updates) { // for "on conflict do update", we need to apply policy filter to the "where" clause const { mutationModel, alias } = this.getMutationModel(node); @@ -412,10 +443,50 @@ export class PolicyHandler extends OperationNodeTransf updateWhere: WhereNode.create(filter), }; } + processedNode = { ...node, onConflict }; } - // merge updated onConflict - const processedNode = onConflict ? { ...node, onConflict } : node; + let onDuplicateKey = node.onDuplicateKey; + if (onDuplicateKey?.updates) { + // for "on duplicate key update", we need to wrap updates in IF(filter, newValue, oldValue) + // so that updates only happen when the policy filter is satisfied + const { mutationModel } = this.getMutationModel(node); + + // Build the filter without alias, but will still contain model name as table reference + const filterWithTableRef = this.buildPolicyFilter(mutationModel, undefined, 'update'); + + // Strip table references from the filter since ON DUPLICATE KEY UPDATE doesn't support them + const filter = this.stripTableReferences(filterWithTableRef, mutationModel); + + // transform each update to: IF(filter, newValue, oldValue) + const wrappedUpdates = onDuplicateKey.updates.map((update) => { + // For each column update, wrap it with IF condition + // IF(filter, newValue, columnName) - columnName references the existing row value + const columnName = ColumnNode.is(update.column) ? update.column.column.name : undefined; + if (!columnName) { + // keep original update if we can't extract column name + return update; + } + + // Create the wrapped value: IF(filter, newValue, columnName) + // In MySQL's ON DUPLICATE KEY UPDATE context: + // - VALUES(col) = the value from the INSERT statement + // - col = the existing row value before update + const wrappedValue = + sql`IF(${new ExpressionWrapper(filter)}, ${new ExpressionWrapper(update.value)}, ${sql.ref(columnName)})`.toOperationNode(); + + return { + ...update, + value: wrappedValue, + }; + }); + + onDuplicateKey = { + ...onDuplicateKey, + updates: wrappedUpdates, + }; + processedNode = { ...processedNode, onDuplicateKey }; + } const result = super.transformInsertQuery(processedNode); @@ -458,7 +529,7 @@ export class PolicyHandler extends OperationNodeTransf // 2. if there are post-update policies, we need to make sure id fields are selected for joining with // before-update rows - if (returning || this.hasPostUpdatePolicies(mutationModel)) { + if (this.dialect.supportsReturning && (returning || this.hasPostUpdatePolicies(mutationModel))) { const idFields = QueryUtils.requireIdFields(this.client.$schema, mutationModel); returning = ReturningNode.create(idFields.map((f) => SelectionNode.create(ColumnNode.create(f)))); } @@ -500,21 +571,23 @@ export class PolicyHandler extends OperationNodeTransf model: string, where: WhereNode | undefined, proceed: ProceedKyselyQueryFunction, + forceLoad: boolean = false, ) { const beforeUpdateAccessFields = this.getFieldsAccessForBeforeUpdatePolicies(model); - if (!beforeUpdateAccessFields || beforeUpdateAccessFields.length === 0) { + if (!forceLoad && (!beforeUpdateAccessFields || beforeUpdateAccessFields.length === 0)) { return undefined; } // combine update's where with policy filter const policyFilter = this.buildPolicyFilter(model, model, 'update'); const combinedFilter = where ? conjunction(this.dialect, [where.where, policyFilter]) : policyFilter; + const selections = beforeUpdateAccessFields ?? QueryUtils.requireIdFields(this.client.$schema, model); const query: SelectQueryNode = { kind: 'SelectQueryNode', from: FromNode.create([TableNode.create(model)]), where: WhereNode.create(combinedFilter), - selections: [...beforeUpdateAccessFields.map((f) => SelectionNode.create(ColumnNode.create(f)))], + selections: selections.map((f) => SelectionNode.create(ColumnNode.create(f))), }; const result = await proceed(query); return { fields: beforeUpdateAccessFields, rows: result.rows }; @@ -792,62 +865,30 @@ export class PolicyHandler extends OperationNodeTransf values: OperationNode[], proceed: ProceedKyselyQueryFunction, ) { - const allFields = Object.entries(QueryUtils.requireModel(this.client.$schema, model).fields).filter( - ([, def]) => !def.relation, - ); - const allValues: OperationNode[] = []; + const allFields = QueryUtils.getModelFields(this.client.$schema, model, { inherited: true }); + const allValues: KyselyExpression[] = []; - for (const [name, _def] of allFields) { - const index = fields.indexOf(name); + for (const def of allFields) { + const index = fields.indexOf(def.name); if (index >= 0) { - allValues.push(values[index]!); + allValues.push(new ExpressionWrapper(values[index]!)); } else { // set non-provided fields to null - allValues.push(ValueNode.createImmediate(null)); + allValues.push(this.eb.lit(null)); } } // create a `SELECT column1 as field1, column2 as field2, ... FROM (VALUES (...))` table for policy evaluation - const eb = expressionBuilder(); - - const constTable: SelectQueryNode = { - kind: 'SelectQueryNode', - from: FromNode.create([ - AliasNode.create( - ParensNode.create(ValuesNode.create([ValueListNode.create(allValues)])), - IdentifierNode.create('$t'), - ), - ]), - selections: allFields.map(([name, def], index) => { - const castedColumnRef = - sql`CAST(${eb.ref(`column${index + 1}`)} as ${sql.raw(this.dialect.getFieldSqlType(def))})`.as( - name, - ); - return SelectionNode.create(castedColumnRef.toOperationNode()); - }), - }; + const valuesTable = this.dialect.buildValuesTableSelect(allFields, [allValues]); const filter = this.buildPolicyFilter(model, undefined, 'create'); - const preCreateCheck: SelectQueryNode = { - kind: 'SelectQueryNode', - from: FromNode.create([AliasNode.create(constTable, IdentifierNode.create(model))]), - selections: [ - SelectionNode.create( - AliasNode.create( - BinaryOperationNode.create( - FunctionNode.create('COUNT', [ValueNode.createImmediate(1)]), - OperatorNode.create('>'), - ValueNode.createImmediate(0), - ), - IdentifierNode.create('$condition'), - ), - ), - ], - where: WhereNode.create(filter), - }; + const preCreateCheck = this.eb + .selectFrom(valuesTable.as(model)) + .select(this.eb(this.eb.fn.count(this.eb.lit(1)), '>', 0).as('$condition')) + .where(() => new ExpressionWrapper(filter)); - const result = await proceed(preCreateCheck); + const result = await proceed(preCreateCheck.toOperationNode()); if (!result.rows[0]?.$condition) { throw createRejectedByPolicyError(model, RejectedByPolicyReason.NO_ACCESS); } @@ -883,7 +924,7 @@ export class PolicyHandler extends OperationNodeTransf invariant(item.kind === 'ValueNode', 'expecting a ValueNode'); result.push({ node: ValueNode.create( - this.dialect.transformPrimitive( + this.dialect.transformInput( (item as ValueNode).value, fieldDef.type as BuiltinType, !!fieldDef.array, @@ -899,11 +940,13 @@ export class PolicyHandler extends OperationNodeTransf // are all foreign keys if (!isImplicitManyToManyJoinTable) { const fieldDef = QueryUtils.requireField(this.client.$schema, model, fields[i]!); - value = this.dialect.transformPrimitive(item, fieldDef.type as BuiltinType, !!fieldDef.array); + value = this.dialect.transformInput(item, fieldDef.type as BuiltinType, !!fieldDef.array); } + + // handle the case for list column if (Array.isArray(value)) { result.push({ - node: RawNode.createWithSql(this.dialect.buildArrayLiteralSQL(value)), + node: this.dialect.buildArrayLiteralSQL(value).toOperationNode(), raw: value, }); } else { @@ -1244,10 +1287,9 @@ export class PolicyHandler extends OperationNodeTransf // - mutation: requires both sides to be updatable const checkForOperation = operation === 'read' ? 'read' : 'update'; - const eb = expressionBuilder(); const joinTable = alias ?? tableName; - const aQuery = eb + const aQuery = this.eb .selectFrom(m2m.firstModel) .whereRef(`${m2m.firstModel}.${m2m.firstIdField}`, '=', `${joinTable}.A`) .select(() => @@ -1256,7 +1298,7 @@ export class PolicyHandler extends OperationNodeTransf ), ); - const bQuery = eb + const bQuery = this.eb .selectFrom(m2m.secondModel) .whereRef(`${m2m.secondModel}.${m2m.secondIdField}`, '=', `${joinTable}.B`) .select(() => @@ -1265,7 +1307,7 @@ export class PolicyHandler extends OperationNodeTransf ), ); - return eb.and([aQuery, bQuery]).toOperationNode(); + return this.eb.and([aQuery, bQuery]).toOperationNode(); } private tryRejectNonexistentModel(model: string) { @@ -1297,5 +1339,27 @@ export class PolicyHandler extends OperationNodeTransf } } + // strips table references from an OperationNode + private stripTableReferences(node: OperationNode, modelName: string): OperationNode { + return new TableReferenceStripper().strip(node, modelName); + } +} + +class TableReferenceStripper extends OperationNodeTransformer { + private tableName: string = ''; + + strip(node: OperationNode, tableName: string) { + this.tableName = tableName; + return this.transformNode(node); + } + + protected override transformReference(node: ReferenceNode) { + if (ColumnNode.is(node.column) && node.table?.table.identifier.name === this.tableName) { + // strip the table part + return ReferenceNode.create(this.transformNode(node.column)); + } + return super.transformReference(node); + } + // #endregion } diff --git a/packages/plugins/policy/src/utils.ts b/packages/plugins/policy/src/utils.ts index f42370cc..5deef04b 100644 --- a/packages/plugins/policy/src/utils.ts +++ b/packages/plugins/policy/src/utils.ts @@ -20,14 +20,14 @@ import { * Creates a `true` value node. */ export function trueNode(dialect: BaseCrudDialect) { - return ValueNode.createImmediate(dialect.transformPrimitive(true, 'Boolean', false)); + return ValueNode.createImmediate(dialect.transformInput(true, 'Boolean', false)); } /** * Creates a `false` value node. */ export function falseNode(dialect: BaseCrudDialect) { - return ValueNode.createImmediate(dialect.transformPrimitive(false, 'Boolean', false)); + return ValueNode.createImmediate(dialect.transformInput(false, 'Boolean', false)); } /** diff --git a/packages/schema/src/schema.ts b/packages/schema/src/schema.ts index 40f7d8bd..58fc1bc5 100644 --- a/packages/schema/src/schema.ts +++ b/packages/schema/src/schema.ts @@ -1,7 +1,7 @@ import type Decimal from 'decimal.js'; import type { Expression } from './expression'; -export type DataSourceProviderType = 'sqlite' | 'postgresql'; +export type DataSourceProviderType = 'sqlite' | 'postgresql' | 'mysql'; export type DataSourceProvider = { type: DataSourceProviderType; diff --git a/packages/server/test/api/rest.test.ts b/packages/server/test/api/rest.test.ts index 5cdd6f3a..96b469ba 100644 --- a/packages/server/test/api/rest.test.ts +++ b/packages/server/test/api/rest.test.ts @@ -1416,8 +1416,8 @@ describe('REST server tests', () => { email: `user1@abc.com`, posts: { create: [...Array(10).keys()].map((i) => ({ - id: i, - title: `Post${i}`, + id: i + 1, + title: `Post${i + 1}`, })), }, }, @@ -1478,8 +1478,8 @@ describe('REST server tests', () => { email: `user1@abc.com`, posts: { create: [...Array(10).keys()].map((i) => ({ - id: i, - title: `Post${i}`, + id: i + 1, + title: `Post${i + 1}`, })), }, }, @@ -3352,7 +3352,8 @@ mutation procedure sum(a: Int, b: Int): Int const b = args?.b as number | undefined; return (a ?? 0) + (b ?? 0); }, - sumIds: async ({ args }: ProcCtx) => (args.ids as number[]).reduce((acc, x) => acc + x, 0), + sumIds: async ({ args }: ProcCtx) => + (args.ids as number[]).reduce((acc, x) => acc + x, 0), echoRole: async ({ args }: ProcCtx) => args.r, echoOverview: async ({ args }: ProcCtx) => args.o, sum: async ({ args }: ProcCtx) => args.a + args.b, @@ -3373,7 +3374,7 @@ mutation procedure sum(a: Int, b: Int): Int const r = await handler({ method: 'get', path: '/$procs/echoDecimal', - query: { ...json as object, meta: { serialization: meta } } as any, + query: { ...(json as object), meta: { serialization: meta } } as any, client, }); @@ -3486,7 +3487,7 @@ mutation procedure sum(a: Int, b: Int): Int const r = await handler({ method: 'post', path: '/$procs/sum', - requestBody: { ...json as object, meta: { serialization: meta } } as any, + requestBody: { ...(json as object), meta: { serialization: meta } } as any, client, }); diff --git a/packages/server/test/api/rpc.test.ts b/packages/server/test/api/rpc.test.ts index 9493ed2a..30818ae8 100644 --- a/packages/server/test/api/rpc.test.ts +++ b/packages/server/test/api/rpc.test.ts @@ -30,7 +30,7 @@ describe('RPC API Handler Tests', () => { r = await handleRequest({ method: 'get', path: '/user/exists', - query: { q: JSON.stringify({ where: { id: 'user1' }})}, + query: { q: JSON.stringify({ where: { id: 'user1' } }) }, client: rawClient, }); expect(r.status).toBe(200); @@ -69,7 +69,7 @@ describe('RPC API Handler Tests', () => { r = await handleRequest({ method: 'get', path: '/user/exists', - query: { q: JSON.stringify({ where: { id: 'user1' }})}, + query: { q: JSON.stringify({ where: { id: 'user1' } }) }, client: rawClient, }); expect(r.status).toBe(200); @@ -167,7 +167,7 @@ procedure getUndefined(): Undefined getFalse: async () => false, getUndefined: async () => undefined, }, - }); + } as any); const handler = new RPCApiHandler({ schema: procClient.$schema }); const handleProcRequest = async (args: any) => { @@ -267,7 +267,7 @@ procedure echoOverview(o: Overview): Overview echoRole: async ({ args }: any) => args.r, echoOverview: async ({ args }: any) => args.o, }, - }); + } as any); const handler = new RPCApiHandler({ schema: procClient.$schema }); const handleProcRequest = async (args: any) => { diff --git a/packages/testtools/package.json b/packages/testtools/package.json index ed5f9bc9..7ec0135d 100644 --- a/packages/testtools/package.json +++ b/packages/testtools/package.json @@ -39,6 +39,7 @@ "@zenstackhq/plugin-policy": "workspace:*", "glob": "^11.1.0", "kysely": "catalog:", + "mysql2": "catalog:", "prisma": "catalog:", "tmp": "catalog:", "ts-pattern": "catalog:" diff --git a/packages/testtools/src/client.ts b/packages/testtools/src/client.ts index 69659661..c82a6b78 100644 --- a/packages/testtools/src/client.ts +++ b/packages/testtools/src/client.ts @@ -1,17 +1,19 @@ import { invariant } from '@zenstackhq/common-helpers'; import type { Model } from '@zenstackhq/language/ast'; import { ZenStackClient, type ClientContract, type ClientOptions } from '@zenstackhq/orm'; -import type { SchemaDef } from '@zenstackhq/orm/schema'; +import type { DataSourceProviderType, SchemaDef } from '@zenstackhq/orm/schema'; import { PolicyPlugin } from '@zenstackhq/plugin-policy'; import { PrismaSchemaGenerator } from '@zenstackhq/sdk'; import SQLite from 'better-sqlite3'; import { glob } from 'glob'; -import { PostgresDialect, SqliteDialect, type LogEvent } from 'kysely'; +import { MysqlDialect, PostgresDialect, SqliteDialect, type LogEvent } from 'kysely'; +import { createPool as createMysqlPool } from 'mysql2'; import { execSync } from 'node:child_process'; import { createHash } from 'node:crypto'; import fs from 'node:fs'; import path from 'node:path'; import { Client as PGClient, Pool } from 'pg'; +import { match } from 'ts-pattern'; import { expect } from 'vitest'; import { createTestProject } from './project'; import { generateTsSchema } from './schema'; @@ -19,10 +21,10 @@ import { loadDocumentWithPlugins } from './utils'; export function getTestDbProvider() { const val = process.env['TEST_DB_PROVIDER'] ?? 'sqlite'; - if (!['sqlite', 'postgresql'].includes(val!)) { + if (!['sqlite', 'postgresql', 'mysql'].includes(val!)) { throw new Error(`Invalid TEST_DB_PROVIDER value: ${val}`); } - return val as 'sqlite' | 'postgresql'; + return val as 'sqlite' | 'postgresql' | 'mysql'; } export const TEST_PG_CONFIG = { @@ -34,11 +36,21 @@ export const TEST_PG_CONFIG = { export const TEST_PG_URL = `postgres://${TEST_PG_CONFIG.user}:${TEST_PG_CONFIG.password}@${TEST_PG_CONFIG.host}:${TEST_PG_CONFIG.port}`; +export const TEST_MYSQL_CONFIG = { + host: process.env['TEST_MYSQL_HOST'] ?? 'localhost', + port: process.env['TEST_MYSQL_PORT'] ? parseInt(process.env['TEST_MYSQL_PORT']) : 3306, + user: process.env['TEST_MYSQL_USER'] ?? 'root', + password: process.env['TEST_MYSQL_PASSWORD'] ?? 'mysql', + timezone: 'Z', +}; + +export const TEST_MYSQL_URL = `mysql://${TEST_MYSQL_CONFIG.user}:${TEST_MYSQL_CONFIG.password}@${TEST_MYSQL_CONFIG.host}:${TEST_MYSQL_CONFIG.port}`; + type ExtraTestClientOptions = { /** * Database provider */ - provider?: 'sqlite' | 'postgresql'; + provider?: 'sqlite' | 'postgresql' | 'mysql'; /** * The main ZModel file. Only used when `usePrismaPush` is true and `schema` is an object. @@ -106,8 +118,11 @@ export async function createTestClient( let _schema: SchemaDef; const provider = options?.provider ?? getTestDbProvider() ?? 'sqlite'; const dbName = options?.dbName ?? getTestDbName(provider); - const dbUrl = provider === 'sqlite' ? `file:${dbName}` : `${TEST_PG_URL}/${dbName}`; - + const dbUrl = match(provider) + .with('sqlite', () => `file:${dbName}`) + .with('mysql', () => `${TEST_MYSQL_URL}/${dbName}`) + .with('postgresql', () => `${TEST_PG_URL}/${dbName}`) + .exhaustive(); let model: Model | undefined; if (typeof schema === 'string') { @@ -158,7 +173,7 @@ export async function createTestClient( if (options?.debug) { console.log(`Work directory: ${workDir}`); console.log(`Database name: ${dbName}`); - _options.log = testLogger; + _options.log ??= testLogger; } // copy db file to workDir if specified @@ -208,29 +223,12 @@ export async function createTestClient( stdio: options.debug ? 'inherit' : 'ignore', }); } else { - if (provider === 'postgresql') { - invariant(dbName, 'dbName is required'); - const pgClient = new PGClient(TEST_PG_CONFIG); - await pgClient.connect(); - await pgClient.query(`DROP DATABASE IF EXISTS "${dbName}"`); - await pgClient.query(`CREATE DATABASE "${dbName}"`); - await pgClient.end(); - } + await prepareDatabase(provider, dbName); } } - if (provider === 'postgresql') { - _options.dialect = new PostgresDialect({ - pool: new Pool({ - ...TEST_PG_CONFIG, - database: dbName, - }), - }); - } else { - _options.dialect = new SqliteDialect({ - database: new SQLite(path.join(workDir!, dbName)), - }); - } + // create Kysely dialect + _options.dialect = createDialect(provider, dbName, workDir); let client = new ZenStackClient(_schema, _options); @@ -238,6 +236,7 @@ export async function createTestClient( await client.$pushSchema(); } + // install plugins if (plugins) { for (const plugin of plugins) { client = client.$use(plugin); @@ -247,6 +246,55 @@ export async function createTestClient( return client; } +function createDialect(provider: DataSourceProviderType, dbName: string, workDir: string) { + return match(provider) + .with( + 'postgresql', + () => + new PostgresDialect({ + pool: new Pool({ + ...TEST_PG_CONFIG, + database: dbName, + }), + }), + ) + .with( + 'mysql', + () => + new MysqlDialect({ + pool: createMysqlPool({ + ...TEST_MYSQL_CONFIG, + database: dbName, + }), + }), + ) + .with( + 'sqlite', + () => + new SqliteDialect({ + database: new SQLite(path.join(workDir!, dbName)), + }), + ) + .exhaustive(); +} + +async function prepareDatabase(provider: string, dbName: string) { + if (provider === 'postgresql') { + invariant(dbName, 'dbName is required'); + const pgClient = new PGClient(TEST_PG_CONFIG); + await pgClient.connect(); + await pgClient.query(`DROP DATABASE IF EXISTS "${dbName}"`); + await pgClient.query(`CREATE DATABASE "${dbName}"`); + await pgClient.end(); + } else if (provider === 'mysql') { + invariant(dbName, 'dbName is required'); + const mysqlPool = createMysqlPool(TEST_MYSQL_CONFIG); + await mysqlPool.promise().query(`DROP DATABASE IF EXISTS \`${dbName}\``); + await mysqlPool.promise().query(`CREATE DATABASE \`${dbName}\``); + await mysqlPool.promise().end(); + } +} + export async function createPolicyTestClient( schema: Schema, options?: CreateTestClientOptions, diff --git a/packages/testtools/src/schema.ts b/packages/testtools/src/schema.ts index 1ecb015c..44f4483c 100644 --- a/packages/testtools/src/schema.ts +++ b/packages/testtools/src/schema.ts @@ -1,5 +1,5 @@ import { invariant } from '@zenstackhq/common-helpers'; -import type { SchemaDef } from '@zenstackhq/schema'; +import type { DataSourceProviderType, SchemaDef } from '@zenstackhq/schema'; import { TsSchemaGenerator } from '@zenstackhq/sdk'; import { execSync } from 'node:child_process'; import crypto from 'node:crypto'; @@ -11,7 +11,7 @@ import { expect } from 'vitest'; import { createTestProject } from './project'; import { loadDocumentWithPlugins } from './utils'; -function makePrelude(provider: 'sqlite' | 'postgresql', dbUrl?: string) { +function makePrelude(provider: DataSourceProviderType, dbUrl?: string) { return match(provider) .with('sqlite', () => { return ` @@ -27,19 +27,37 @@ datasource db { provider = 'postgresql' url = '${dbUrl ?? 'postgres://postgres:postgres@localhost:5432/db'}' } +`; + }) + .with('mysql', () => { + return ` +datasource db { + provider = 'mysql' + url = '${dbUrl ?? 'mysql://root:mysql@localhost:3306/db'}' +} `; }) .exhaustive(); } -function replacePlaceholders(schemaText: string, provider: 'sqlite' | 'postgresql', dbUrl: string | undefined) { - const url = dbUrl ?? (provider === 'sqlite' ? 'file:./test.db' : 'postgres://postgres:postgres@localhost:5432/db'); +function replacePlaceholders( + schemaText: string, + provider: 'sqlite' | 'postgresql' | 'mysql', + dbUrl: string | undefined, +) { + const url = + dbUrl ?? + (provider === 'sqlite' + ? 'file:./test.db' + : provider === 'mysql' + ? 'mysql://root:mysql@localhost:3306/db' + : 'postgres://postgres:postgres@localhost:5432/db'); return schemaText.replace(/\$DB_URL/g, url).replace(/\$PROVIDER/g, provider); } export async function generateTsSchema( schemaText: string, - provider: 'sqlite' | 'postgresql' = 'sqlite', + provider: DataSourceProviderType = 'sqlite', dbUrl?: string, extraSourceFiles?: Record, withLiteSchema?: boolean, diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index a60befaa..ab856812 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -51,6 +51,9 @@ catalogs: langium-cli: specifier: 3.5.0 version: 3.5.0 + mysql2: + specifier: ^3.16.1 + version: 3.16.1 next: specifier: 16.0.10 version: 16.0.10 @@ -495,6 +498,9 @@ importers: kysely: specifier: 'catalog:' version: 0.28.8 + mysql2: + specifier: 'catalog:' + version: 3.16.1 nanoid: specifier: ^5.0.9 version: 5.0.9 @@ -703,7 +709,7 @@ importers: version: 16.0.10(@babel/core@7.28.5)(react-dom@19.2.0(react@19.2.0))(react@19.2.0) nuxt: specifier: 'catalog:' - version: 4.2.2(@parcel/watcher@2.5.1)(@types/node@20.19.24)(@vue/compiler-sfc@3.5.26)(better-sqlite3@12.5.0)(cac@6.7.14)(db0@0.3.4(better-sqlite3@12.5.0))(eslint@9.29.0(jiti@2.6.1))(ioredis@5.8.2)(lightningcss@1.30.2)(magicast@0.5.1)(optionator@0.9.4)(rollup@4.52.5)(terser@5.44.0)(tsx@4.20.3)(typescript@5.9.3)(vite@7.3.0(@types/node@20.19.24)(jiti@2.6.1)(lightningcss@1.30.2)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.2))(yaml@2.8.2) + version: 4.2.2(@parcel/watcher@2.5.1)(@types/node@20.19.24)(@vue/compiler-sfc@3.5.26)(better-sqlite3@12.5.0)(cac@6.7.14)(db0@0.3.4(better-sqlite3@12.5.0)(mysql2@3.16.1))(eslint@9.29.0(jiti@2.6.1))(ioredis@5.8.2)(lightningcss@1.30.2)(magicast@0.5.1)(mysql2@3.16.1)(optionator@0.9.4)(rollup@4.52.5)(terser@5.44.0)(tsx@4.20.3)(typescript@5.9.3)(vite@7.3.0(@types/node@20.19.24)(jiti@2.6.1)(lightningcss@1.30.2)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.2))(yaml@2.8.2) supertest: specifier: ^7.1.4 version: 7.1.4 @@ -740,6 +746,9 @@ importers: kysely: specifier: 'catalog:' version: 0.28.8 + mysql2: + specifier: 'catalog:' + version: 3.16.1 pg: specifier: 'catalog:' version: 8.16.3 @@ -883,7 +892,7 @@ importers: version: 2.0.8 nuxt: specifier: 'catalog:' - version: 4.2.2(@parcel/watcher@2.5.1)(@types/node@20.19.24)(@vue/compiler-sfc@3.5.26)(better-sqlite3@12.5.0)(cac@6.7.14)(db0@0.3.4(better-sqlite3@12.5.0))(eslint@9.29.0(jiti@2.6.1))(ioredis@5.8.2)(lightningcss@1.30.2)(magicast@0.5.1)(optionator@0.9.4)(rollup@4.52.5)(terser@5.44.0)(tsx@4.20.3)(typescript@5.9.3)(vite@7.3.0(@types/node@20.19.24)(jiti@2.6.1)(lightningcss@1.30.2)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.2))(yaml@2.8.2) + version: 4.2.2(@parcel/watcher@2.5.1)(@types/node@20.19.24)(@vue/compiler-sfc@3.5.26)(better-sqlite3@12.5.0)(cac@6.7.14)(db0@0.3.4(better-sqlite3@12.5.0)(mysql2@3.16.1))(eslint@9.29.0(jiti@2.6.1))(ioredis@5.8.2)(lightningcss@1.30.2)(magicast@0.5.1)(mysql2@3.16.1)(optionator@0.9.4)(rollup@4.52.5)(terser@5.44.0)(tsx@4.20.3)(typescript@5.9.3)(vite@7.3.0(@types/node@20.19.24)(jiti@2.6.1)(lightningcss@1.30.2)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.2))(yaml@2.8.2) tailwindcss: specifier: ^4.1.18 version: 4.1.18 @@ -1037,6 +1046,9 @@ importers: kysely: specifier: 'catalog:' version: 0.28.8 + ts-pattern: + specifier: 'catalog:' + version: 5.7.1 ulid: specifier: ^3.0.0 version: 3.0.0 @@ -4142,6 +4154,10 @@ packages: avvio@9.1.0: resolution: {integrity: sha512-fYASnYi600CsH/j9EQov7lECAniYiBFiiAtBNuZYLA2leLe9qOvZzqYHFjtIj6gD2VMoMLP14834LFWvr4IfDw==} + aws-ssl-profiles@1.1.2: + resolution: {integrity: sha512-NZKeq9AfyQvEeNlN0zSYAaWrmBffJh3IELMZfRpJVWgrpEbtEpnjvzqBPf+mxoI287JohRDoa+/nsfqqiZmF6g==} + engines: {node: '>= 6.0.0'} + axe-core@4.11.0: resolution: {integrity: sha512-ilYanEU8vxxBexpJd8cWM4ElSQq4QctCLKih0TSfjIfCQTeyH/6zVrmIJfLPrKTKJRbiG+cfnZbQIjAlJmF1jQ==} engines: {node: '>=4'} @@ -5291,6 +5307,9 @@ packages: resolution: {integrity: sha512-trLf4SzuuUxfusZADLINj+dE8clK1frKdmqiJNb1Es75fmI5oY6X2mxLVUciLLjxqw/xr72Dhy+lER6dGd02FQ==} engines: {node: '>=10'} + generate-function@2.3.1: + resolution: {integrity: sha512-eeB5GfMNeevm/GRYq20ShmsaGcmI81kIX2K9XQx5miC8KdHaC6Jm0qQ8ZNeGOi7wYB8OsdxKs+Y2oVuTFuVwKQ==} + generator-function@2.0.1: resolution: {integrity: sha512-SFdFmIJi+ybC0vjlHN0ZGVGHc3lgE0DxPAT0djjVg+kjOnSqclqmj0KQ7ykTOLP6YxoqOvuAODGdcHJn+43q3g==} engines: {node: '>= 0.4'} @@ -5668,6 +5687,9 @@ packages: is-promise@4.0.0: resolution: {integrity: sha512-hvpoI6korhJMnej285dSg6nu1+e6uxs7zG3BYAm5byqDsgJNWwxzM6z6iZiAgQR4TJ30JmBTOwqZUw3WlyH3AQ==} + is-property@1.0.2: + resolution: {integrity: sha512-Ks/IoX00TtClbGQr4TWXemAnktAQvYB7HzcCxDGqEZU6oCmb2INHuOoKxbtR+HFkmYWBKv/dOZtGRiAjDhj92g==} + is-reference@1.2.1: resolution: {integrity: sha512-U82MsXXiFIrjCK4otLT+o2NA2Cd2g5MLoOVXUZjIOhLurrRxpEXzI8O0KZHr3IjLvlAH1kTPYSuqer5T9ZVBKQ==} @@ -6060,6 +6082,9 @@ packages: resolution: {integrity: sha512-8XPvpAA8uyhfteu8pIvQxpJZ7SYYdpUivZpGy6sFsBuKRY/7rQGavedeB8aK+Zkyq6upMFVL/9AW6vOYzfRyLg==} engines: {node: '>=10'} + long@5.3.2: + resolution: {integrity: sha512-mNAgZ1GmyNhD7AuqnTG3/VQ26o760+ZYBPKjPvugO8+nLbYfX6TVpJPseBvopbdY+qpZ/lKUnmEc1LeZYS3QAA==} + loose-envify@1.4.0: resolution: {integrity: sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==} hasBin: true @@ -6079,6 +6104,10 @@ packages: lru-cache@5.1.1: resolution: {integrity: sha512-KpNARQA3Iwv+jTA0utUVVbrh+Jlrr1Fv0e56GGzAFOXN7dk/FviaDW8LHmK52DlcH4WP2n6gI8vN1aesBFgo9w==} + lru.min@1.1.3: + resolution: {integrity: sha512-Lkk/vx6ak3rYkRR0Nhu4lFUT2VDnQSxBe8Hbl7f36358p6ow8Bnvr8lrLt98H8J1aGxfhbX4Fs5tYg2+FTwr5Q==} + engines: {bun: '>=1.0.0', deno: '>=1.30.0', node: '>=8.0.0'} + lz-string@1.5.0: resolution: {integrity: sha512-h5bgJWpxJNswbU7qCrV0tIKQCaS3blPDrqKWx+QxzuzL1zGUzij9XCWLrSLsJPu5t+eWA/ycetzYAO5IOMcWAQ==} hasBin: true @@ -6248,9 +6277,17 @@ packages: muggle-string@0.4.1: resolution: {integrity: sha512-VNTrAak/KhO2i8dqqnqnAHOa3cYBwXEZe9h+D5h/1ZqFSTEFHdM65lR7RoIqq3tBBYavsOXV84NoHXZ0AkPyqQ==} + mysql2@3.16.1: + resolution: {integrity: sha512-b75qsDB3ieYEzMsT1uRGsztM/sy6vWPY40uPZlVVl8eefAotFCoS7jaDB5DxDNtlW5kdVGd9jptSpkvujNxI2A==} + engines: {node: '>= 8.0'} + mz@2.7.0: resolution: {integrity: sha512-z81GNO7nnYMEhrGh9LeymoE4+Yr0Wn5McHIZMK5cfQCl+NDX08sCZgUc9/6MHni9IWuFLm1Z3HTCXu2z9fN62Q==} + named-placeholders@1.1.6: + resolution: {integrity: sha512-Tz09sEL2EEuv5fFowm419c1+a/jSMiBjI9gHxVLrVdbUkkNUUfjsVYs9pVZu5oCon/kmRh9TfLEObFtkVxmY0w==} + engines: {node: '>=8.0.0'} + nanoid@3.3.11: resolution: {integrity: sha512-N8SpfPUnUp1bK+PMYW8qSWdl9U+wwNWI4QKxOYDy9JAro3WMX7p2OeVRF9v+347pnakNevPmiHhNmZ2HbFA76w==} engines: {node: ^10 || ^12 || ^13.7 || ^14 || >=15.0.1} @@ -7333,6 +7370,9 @@ packages: resolution: {integrity: sha512-uaW0WwXKpL9blXE2o0bRhoL2EGXIrZxQ2ZQ4mgcfoBxdFmQold+qWsD2jLrfZ0trjKL6vOw0j//eAwcALFjKSw==} engines: {node: '>= 18'} + seq-queue@0.0.5: + resolution: {integrity: sha512-hr3Wtp/GZIc/6DAGPDcV4/9WoZhjrkXsi5B/07QgX8tsdc6ilr7BFM6PM6rbdAX1kFSDYeZGLipIZZKyQP0O5Q==} + serialize-javascript@6.0.2: resolution: {integrity: sha512-Saa1xPByTTq2gdeFZYLLo+RFE35NHZkAbqZeWNd3BpzppeVisAqpDjcp8dyf6uIvEqJRd46jemmyA4iFIeVk8g==} @@ -7484,6 +7524,10 @@ packages: sql.js@1.13.0: resolution: {integrity: sha512-RJbVP1HRDlUUXahJ7VMTcu9Rm1Nzw+EBpoPr94vnbD4LwR715F3CcxE2G2k45PewcaZ57pjetYa+LoSJLAASgA==} + sqlstring@2.3.3: + resolution: {integrity: sha512-qC9iz2FlN7DQl3+wjwn3802RTyjCx7sDvfQEXchwa6CWOx07/WVfh91gBmQ9fahw8snwGEWU3xGzOt4tFyHLxg==} + engines: {node: '>= 0.6'} + srvx@0.9.8: resolution: {integrity: sha512-RZaxTKJEE/14HYn8COLuUOJAt0U55N9l1Xf6jj+T0GoA01EUH1Xz5JtSUOI+EHn+AEgPCVn7gk6jHJffrr06fQ==} engines: {node: '>=20.16.0'} @@ -9566,7 +9610,7 @@ snapshots: transitivePeerDependencies: - magicast - '@nuxt/nitro-server@4.2.2(better-sqlite3@12.5.0)(db0@0.3.4(better-sqlite3@12.5.0))(ioredis@5.8.2)(magicast@0.5.1)(nuxt@4.2.2(@parcel/watcher@2.5.1)(@types/node@20.19.24)(@vue/compiler-sfc@3.5.26)(better-sqlite3@12.5.0)(cac@6.7.14)(db0@0.3.4(better-sqlite3@12.5.0))(eslint@9.29.0(jiti@2.6.1))(ioredis@5.8.2)(lightningcss@1.30.2)(magicast@0.5.1)(optionator@0.9.4)(rollup@4.52.5)(terser@5.44.0)(tsx@4.20.3)(typescript@5.9.3)(vite@7.3.0(@types/node@20.19.24)(jiti@2.6.1)(lightningcss@1.30.2)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.2))(yaml@2.8.2))(typescript@5.9.3)': + '@nuxt/nitro-server@4.2.2(better-sqlite3@12.5.0)(db0@0.3.4(better-sqlite3@12.5.0)(mysql2@3.16.1))(ioredis@5.8.2)(magicast@0.5.1)(mysql2@3.16.1)(nuxt@4.2.2(@parcel/watcher@2.5.1)(@types/node@20.19.24)(@vue/compiler-sfc@3.5.26)(better-sqlite3@12.5.0)(cac@6.7.14)(db0@0.3.4(better-sqlite3@12.5.0)(mysql2@3.16.1))(eslint@9.29.0(jiti@2.6.1))(ioredis@5.8.2)(lightningcss@1.30.2)(magicast@0.5.1)(mysql2@3.16.1)(optionator@0.9.4)(rollup@4.52.5)(terser@5.44.0)(tsx@4.20.3)(typescript@5.9.3)(vite@7.3.0(@types/node@20.19.24)(jiti@2.6.1)(lightningcss@1.30.2)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.2))(yaml@2.8.2))(typescript@5.9.3)': dependencies: '@nuxt/devalue': 2.0.2 '@nuxt/kit': 4.2.2(magicast@0.5.1) @@ -9583,15 +9627,15 @@ snapshots: impound: 1.0.0 klona: 2.0.6 mocked-exports: 0.1.1 - nitropack: 2.12.9(better-sqlite3@12.5.0) - nuxt: 4.2.2(@parcel/watcher@2.5.1)(@types/node@20.19.24)(@vue/compiler-sfc@3.5.26)(better-sqlite3@12.5.0)(cac@6.7.14)(db0@0.3.4(better-sqlite3@12.5.0))(eslint@9.29.0(jiti@2.6.1))(ioredis@5.8.2)(lightningcss@1.30.2)(magicast@0.5.1)(optionator@0.9.4)(rollup@4.52.5)(terser@5.44.0)(tsx@4.20.3)(typescript@5.9.3)(vite@7.3.0(@types/node@20.19.24)(jiti@2.6.1)(lightningcss@1.30.2)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.2))(yaml@2.8.2) + nitropack: 2.12.9(better-sqlite3@12.5.0)(mysql2@3.16.1) + nuxt: 4.2.2(@parcel/watcher@2.5.1)(@types/node@20.19.24)(@vue/compiler-sfc@3.5.26)(better-sqlite3@12.5.0)(cac@6.7.14)(db0@0.3.4(better-sqlite3@12.5.0)(mysql2@3.16.1))(eslint@9.29.0(jiti@2.6.1))(ioredis@5.8.2)(lightningcss@1.30.2)(magicast@0.5.1)(mysql2@3.16.1)(optionator@0.9.4)(rollup@4.52.5)(terser@5.44.0)(tsx@4.20.3)(typescript@5.9.3)(vite@7.3.0(@types/node@20.19.24)(jiti@2.6.1)(lightningcss@1.30.2)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.2))(yaml@2.8.2) pathe: 2.0.3 pkg-types: 2.3.0 radix3: 1.1.2 std-env: 3.10.0 ufo: 1.6.1 unctx: 2.4.1 - unstorage: 1.17.3(db0@0.3.4(better-sqlite3@12.5.0))(ioredis@5.8.2) + unstorage: 1.17.3(db0@0.3.4(better-sqlite3@12.5.0)(mysql2@3.16.1))(ioredis@5.8.2) vue: 3.5.26(typescript@5.9.3) vue-bundle-renderer: 2.2.0 vue-devtools-stub: 0.1.0 @@ -9655,7 +9699,7 @@ snapshots: transitivePeerDependencies: - magicast - '@nuxt/vite-builder@4.2.2(@types/node@20.19.24)(eslint@9.29.0(jiti@2.6.1))(lightningcss@1.30.2)(magicast@0.5.1)(nuxt@4.2.2(@parcel/watcher@2.5.1)(@types/node@20.19.24)(@vue/compiler-sfc@3.5.26)(better-sqlite3@12.5.0)(cac@6.7.14)(db0@0.3.4(better-sqlite3@12.5.0))(eslint@9.29.0(jiti@2.6.1))(ioredis@5.8.2)(lightningcss@1.30.2)(magicast@0.5.1)(optionator@0.9.4)(rollup@4.52.5)(terser@5.44.0)(tsx@4.20.3)(typescript@5.9.3)(vite@7.3.0(@types/node@20.19.24)(jiti@2.6.1)(lightningcss@1.30.2)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.2))(yaml@2.8.2))(optionator@0.9.4)(rollup@4.52.5)(terser@5.44.0)(tsx@4.20.3)(typescript@5.9.3)(vue@3.5.26(typescript@5.9.3))(yaml@2.8.2)': + '@nuxt/vite-builder@4.2.2(@types/node@20.19.24)(eslint@9.29.0(jiti@2.6.1))(lightningcss@1.30.2)(magicast@0.5.1)(nuxt@4.2.2(@parcel/watcher@2.5.1)(@types/node@20.19.24)(@vue/compiler-sfc@3.5.26)(better-sqlite3@12.5.0)(cac@6.7.14)(db0@0.3.4(better-sqlite3@12.5.0)(mysql2@3.16.1))(eslint@9.29.0(jiti@2.6.1))(ioredis@5.8.2)(lightningcss@1.30.2)(magicast@0.5.1)(mysql2@3.16.1)(optionator@0.9.4)(rollup@4.52.5)(terser@5.44.0)(tsx@4.20.3)(typescript@5.9.3)(vite@7.3.0(@types/node@20.19.24)(jiti@2.6.1)(lightningcss@1.30.2)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.2))(yaml@2.8.2))(optionator@0.9.4)(rollup@4.52.5)(terser@5.44.0)(tsx@4.20.3)(typescript@5.9.3)(vue@3.5.26(typescript@5.9.3))(yaml@2.8.2)': dependencies: '@nuxt/kit': 4.2.2(magicast@0.5.1) '@rollup/plugin-replace': 6.0.3(rollup@4.52.5) @@ -9675,7 +9719,7 @@ snapshots: magic-string: 0.30.21 mlly: 1.8.0 mocked-exports: 0.1.1 - nuxt: 4.2.2(@parcel/watcher@2.5.1)(@types/node@20.19.24)(@vue/compiler-sfc@3.5.26)(better-sqlite3@12.5.0)(cac@6.7.14)(db0@0.3.4(better-sqlite3@12.5.0))(eslint@9.29.0(jiti@2.6.1))(ioredis@5.8.2)(lightningcss@1.30.2)(magicast@0.5.1)(optionator@0.9.4)(rollup@4.52.5)(terser@5.44.0)(tsx@4.20.3)(typescript@5.9.3)(vite@7.3.0(@types/node@20.19.24)(jiti@2.6.1)(lightningcss@1.30.2)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.2))(yaml@2.8.2) + nuxt: 4.2.2(@parcel/watcher@2.5.1)(@types/node@20.19.24)(@vue/compiler-sfc@3.5.26)(better-sqlite3@12.5.0)(cac@6.7.14)(db0@0.3.4(better-sqlite3@12.5.0)(mysql2@3.16.1))(eslint@9.29.0(jiti@2.6.1))(ioredis@5.8.2)(lightningcss@1.30.2)(magicast@0.5.1)(mysql2@3.16.1)(optionator@0.9.4)(rollup@4.52.5)(terser@5.44.0)(tsx@4.20.3)(typescript@5.9.3)(vite@7.3.0(@types/node@20.19.24)(jiti@2.6.1)(lightningcss@1.30.2)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.2))(yaml@2.8.2) pathe: 2.0.3 pkg-types: 2.3.0 postcss: 8.5.6 @@ -11562,6 +11606,8 @@ snapshots: '@fastify/error': 4.2.0 fastq: 1.19.1 + aws-ssl-profiles@1.1.2: {} + axe-core@4.11.0: {} axobject-query@4.1.0: {} @@ -12062,9 +12108,10 @@ snapshots: es-errors: 1.3.0 is-data-view: 1.0.2 - db0@0.3.4(better-sqlite3@12.5.0): + db0@0.3.4(better-sqlite3@12.5.0)(mysql2@3.16.1): optionalDependencies: better-sqlite3: 12.5.0 + mysql2: 3.16.1 debug@3.2.7: dependencies: @@ -12939,6 +12986,10 @@ snapshots: fuse.js@7.1.0: {} + generate-function@2.3.1: + dependencies: + is-property: 1.0.2 + generator-function@2.0.1: {} gensync@1.0.0-beta.2: {} @@ -13331,6 +13382,8 @@ snapshots: is-promise@4.0.0: {} + is-property@1.0.2: {} + is-reference@1.2.1: dependencies: '@types/estree': 1.0.8 @@ -13729,6 +13782,8 @@ snapshots: chalk: 4.1.2 is-unicode-supported: 0.1.0 + long@5.3.2: {} + loose-envify@1.4.0: dependencies: js-tokens: 4.0.0 @@ -13745,6 +13800,8 @@ snapshots: dependencies: yallist: 3.1.1 + lru.min@1.1.3: {} + lz-string@1.5.0: {} magic-regexp@0.10.0: @@ -13893,12 +13950,28 @@ snapshots: muggle-string@0.4.1: {} + mysql2@3.16.1: + dependencies: + aws-ssl-profiles: 1.1.2 + denque: 2.1.0 + generate-function: 2.3.1 + iconv-lite: 0.7.0 + long: 5.3.2 + lru.min: 1.1.3 + named-placeholders: 1.1.6 + seq-queue: 0.0.5 + sqlstring: 2.3.3 + mz@2.7.0: dependencies: any-promise: 1.3.0 object-assign: 4.1.1 thenify-all: 1.6.0 + named-placeholders@1.1.6: + dependencies: + lru.min: 1.1.3 + nanoid@3.3.11: {} nanoid@5.0.9: {} @@ -13942,7 +14015,7 @@ snapshots: nice-try@1.0.5: {} - nitropack@2.12.9(better-sqlite3@12.5.0): + nitropack@2.12.9(better-sqlite3@12.5.0)(mysql2@3.16.1): dependencies: '@cloudflare/kv-asset-handler': 0.4.0 '@rollup/plugin-alias': 5.1.1(rollup@4.52.5) @@ -13963,7 +14036,7 @@ snapshots: cookie-es: 2.0.0 croner: 9.1.0 crossws: 0.3.5 - db0: 0.3.4(better-sqlite3@12.5.0) + db0: 0.3.4(better-sqlite3@12.5.0)(mysql2@3.16.1) defu: 6.1.4 destr: 2.0.5 dot-prop: 10.1.0 @@ -14009,7 +14082,7 @@ snapshots: unenv: 2.0.0-rc.24 unimport: 5.5.0 unplugin-utils: 0.3.1 - unstorage: 1.17.3(db0@0.3.4(better-sqlite3@12.5.0))(ioredis@5.8.2) + unstorage: 1.17.3(db0@0.3.4(better-sqlite3@12.5.0)(mysql2@3.16.1))(ioredis@5.8.2) untyped: 2.0.0 unwasm: 0.3.11 youch: 4.1.0-beta.13 @@ -14108,16 +14181,16 @@ snapshots: dependencies: boolbase: 1.0.0 - nuxt@4.2.2(@parcel/watcher@2.5.1)(@types/node@20.19.24)(@vue/compiler-sfc@3.5.26)(better-sqlite3@12.5.0)(cac@6.7.14)(db0@0.3.4(better-sqlite3@12.5.0))(eslint@9.29.0(jiti@2.6.1))(ioredis@5.8.2)(lightningcss@1.30.2)(magicast@0.5.1)(optionator@0.9.4)(rollup@4.52.5)(terser@5.44.0)(tsx@4.20.3)(typescript@5.9.3)(vite@7.3.0(@types/node@20.19.24)(jiti@2.6.1)(lightningcss@1.30.2)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.2))(yaml@2.8.2): + nuxt@4.2.2(@parcel/watcher@2.5.1)(@types/node@20.19.24)(@vue/compiler-sfc@3.5.26)(better-sqlite3@12.5.0)(cac@6.7.14)(db0@0.3.4(better-sqlite3@12.5.0)(mysql2@3.16.1))(eslint@9.29.0(jiti@2.6.1))(ioredis@5.8.2)(lightningcss@1.30.2)(magicast@0.5.1)(mysql2@3.16.1)(optionator@0.9.4)(rollup@4.52.5)(terser@5.44.0)(tsx@4.20.3)(typescript@5.9.3)(vite@7.3.0(@types/node@20.19.24)(jiti@2.6.1)(lightningcss@1.30.2)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.2))(yaml@2.8.2): dependencies: '@dxup/nuxt': 0.2.2(magicast@0.5.1) '@nuxt/cli': 3.31.3(cac@6.7.14)(magicast@0.5.1) '@nuxt/devtools': 3.1.1(vite@7.3.0(@types/node@20.19.24)(jiti@2.6.1)(lightningcss@1.30.2)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.2))(vue@3.5.26(typescript@5.9.3)) '@nuxt/kit': 4.2.2(magicast@0.5.1) - '@nuxt/nitro-server': 4.2.2(better-sqlite3@12.5.0)(db0@0.3.4(better-sqlite3@12.5.0))(ioredis@5.8.2)(magicast@0.5.1)(nuxt@4.2.2(@parcel/watcher@2.5.1)(@types/node@20.19.24)(@vue/compiler-sfc@3.5.26)(better-sqlite3@12.5.0)(cac@6.7.14)(db0@0.3.4(better-sqlite3@12.5.0))(eslint@9.29.0(jiti@2.6.1))(ioredis@5.8.2)(lightningcss@1.30.2)(magicast@0.5.1)(optionator@0.9.4)(rollup@4.52.5)(terser@5.44.0)(tsx@4.20.3)(typescript@5.9.3)(vite@7.3.0(@types/node@20.19.24)(jiti@2.6.1)(lightningcss@1.30.2)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.2))(yaml@2.8.2))(typescript@5.9.3) + '@nuxt/nitro-server': 4.2.2(better-sqlite3@12.5.0)(db0@0.3.4(better-sqlite3@12.5.0)(mysql2@3.16.1))(ioredis@5.8.2)(magicast@0.5.1)(mysql2@3.16.1)(nuxt@4.2.2(@parcel/watcher@2.5.1)(@types/node@20.19.24)(@vue/compiler-sfc@3.5.26)(better-sqlite3@12.5.0)(cac@6.7.14)(db0@0.3.4(better-sqlite3@12.5.0)(mysql2@3.16.1))(eslint@9.29.0(jiti@2.6.1))(ioredis@5.8.2)(lightningcss@1.30.2)(magicast@0.5.1)(mysql2@3.16.1)(optionator@0.9.4)(rollup@4.52.5)(terser@5.44.0)(tsx@4.20.3)(typescript@5.9.3)(vite@7.3.0(@types/node@20.19.24)(jiti@2.6.1)(lightningcss@1.30.2)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.2))(yaml@2.8.2))(typescript@5.9.3) '@nuxt/schema': 4.2.2 '@nuxt/telemetry': 2.6.6(magicast@0.5.1) - '@nuxt/vite-builder': 4.2.2(@types/node@20.19.24)(eslint@9.29.0(jiti@2.6.1))(lightningcss@1.30.2)(magicast@0.5.1)(nuxt@4.2.2(@parcel/watcher@2.5.1)(@types/node@20.19.24)(@vue/compiler-sfc@3.5.26)(better-sqlite3@12.5.0)(cac@6.7.14)(db0@0.3.4(better-sqlite3@12.5.0))(eslint@9.29.0(jiti@2.6.1))(ioredis@5.8.2)(lightningcss@1.30.2)(magicast@0.5.1)(optionator@0.9.4)(rollup@4.52.5)(terser@5.44.0)(tsx@4.20.3)(typescript@5.9.3)(vite@7.3.0(@types/node@20.19.24)(jiti@2.6.1)(lightningcss@1.30.2)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.2))(yaml@2.8.2))(optionator@0.9.4)(rollup@4.52.5)(terser@5.44.0)(tsx@4.20.3)(typescript@5.9.3)(vue@3.5.26(typescript@5.9.3))(yaml@2.8.2) + '@nuxt/vite-builder': 4.2.2(@types/node@20.19.24)(eslint@9.29.0(jiti@2.6.1))(lightningcss@1.30.2)(magicast@0.5.1)(nuxt@4.2.2(@parcel/watcher@2.5.1)(@types/node@20.19.24)(@vue/compiler-sfc@3.5.26)(better-sqlite3@12.5.0)(cac@6.7.14)(db0@0.3.4(better-sqlite3@12.5.0)(mysql2@3.16.1))(eslint@9.29.0(jiti@2.6.1))(ioredis@5.8.2)(lightningcss@1.30.2)(magicast@0.5.1)(mysql2@3.16.1)(optionator@0.9.4)(rollup@4.52.5)(terser@5.44.0)(tsx@4.20.3)(typescript@5.9.3)(vite@7.3.0(@types/node@20.19.24)(jiti@2.6.1)(lightningcss@1.30.2)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.2))(yaml@2.8.2))(optionator@0.9.4)(rollup@4.52.5)(terser@5.44.0)(tsx@4.20.3)(typescript@5.9.3)(vue@3.5.26(typescript@5.9.3))(yaml@2.8.2) '@unhead/vue': 2.0.19(vue@3.5.26(typescript@5.9.3)) '@vue/shared': 3.5.26 c12: 3.3.3(magicast@0.5.1) @@ -15190,6 +15263,8 @@ snapshots: transitivePeerDependencies: - supports-color + seq-queue@0.0.5: {} + serialize-javascript@6.0.2: dependencies: randombytes: 2.1.0 @@ -15380,6 +15455,8 @@ snapshots: sql.js@1.13.0: {} + sqlstring@2.3.3: {} + srvx@0.9.8: {} stable-hash@0.0.5: {} @@ -16065,7 +16142,7 @@ snapshots: '@unrs/resolver-binding-win32-ia32-msvc': 1.11.1 '@unrs/resolver-binding-win32-x64-msvc': 1.11.1 - unstorage@1.17.3(db0@0.3.4(better-sqlite3@12.5.0))(ioredis@5.8.2): + unstorage@1.17.3(db0@0.3.4(better-sqlite3@12.5.0)(mysql2@3.16.1))(ioredis@5.8.2): dependencies: anymatch: 3.1.3 chokidar: 4.0.3 @@ -16076,7 +16153,7 @@ snapshots: ofetch: 1.5.1 ufo: 1.6.1 optionalDependencies: - db0: 0.3.4(better-sqlite3@12.5.0) + db0: 0.3.4(better-sqlite3@12.5.0)(mysql2@3.16.1) ioredis: 5.8.2 untun@0.1.3: diff --git a/pnpm-workspace.yaml b/pnpm-workspace.yaml index f5987bd8..ca0942a8 100644 --- a/pnpm-workspace.yaml +++ b/pnpm-workspace.yaml @@ -22,6 +22,7 @@ catalog: nuxt: 4.2.2 '@sveltejs/kit': 2.49.1 pg: ^8.13.1 + mysql2: ^3.16.1 prisma: ^6.19.0 react: 19.2.0 react-dom: 19.2.0 diff --git a/tests/e2e/orm/client-api/computed-fields.test.ts b/tests/e2e/orm/client-api/computed-fields.test.ts index 2476b67d..c6470a72 100644 --- a/tests/e2e/orm/client-api/computed-fields.test.ts +++ b/tests/e2e/orm/client-api/computed-fields.test.ts @@ -1,4 +1,5 @@ import { createTestClient } from '@zenstackhq/testtools'; +import { sql } from 'kysely'; import { describe, expect, it } from 'vitest'; describe('Computed fields tests', () => { @@ -7,14 +8,15 @@ describe('Computed fields tests', () => { ` model User { id Int @id @default(autoincrement()) - name String - upperName String @computed + firstName String + lastName String + fullName String @computed } `, { computedFields: { User: { - upperName: (eb: any) => eb.fn('upper', ['name']), + fullName: (eb: any) => eb.fn('concat', [eb.ref('firstName'), sql.lit(' '), eb.ref('lastName')]), }, }, } as any, @@ -22,70 +24,70 @@ model User { await expect( db.user.create({ - data: { id: 1, name: 'Alex' }, + data: { id: 1, firstName: 'Alex', lastName: 'Smith' }, }), ).resolves.toMatchObject({ - upperName: 'ALEX', + fullName: 'Alex Smith', }); await expect( db.user.findUnique({ where: { id: 1 }, - select: { upperName: true }, + select: { fullName: true }, }), ).resolves.toMatchObject({ - upperName: 'ALEX', + fullName: 'Alex Smith', }); await expect( db.user.findFirst({ - where: { upperName: 'ALEX' }, + where: { fullName: 'Alex Smith' }, }), ).resolves.toMatchObject({ - upperName: 'ALEX', + fullName: 'Alex Smith', }); await expect( db.user.findFirst({ - where: { upperName: 'Alex' }, + where: { fullName: 'Alex' }, }), ).toResolveNull(); await expect( db.user.findFirst({ - orderBy: { upperName: 'desc' }, + orderBy: { fullName: 'desc' }, }), ).resolves.toMatchObject({ - upperName: 'ALEX', + fullName: 'Alex Smith', }); await expect( db.user.findFirst({ - orderBy: { upperName: 'desc' }, + orderBy: { fullName: 'desc' }, take: 1, }), ).resolves.toMatchObject({ - upperName: 'ALEX', + fullName: 'Alex Smith', }); await expect( db.user.aggregate({ - _count: { upperName: true }, + _count: { fullName: true }, }), ).resolves.toMatchObject({ - _count: { upperName: 1 }, + _count: { fullName: 1 }, }); await expect( db.user.groupBy({ - by: ['upperName'], - _count: { upperName: true }, - _max: { upperName: true }, + by: ['fullName'], + _count: { fullName: true }, + _max: { fullName: true }, }), ).resolves.toEqual([ expect.objectContaining({ - _count: { upperName: 1 }, - _max: { upperName: 'ALEX' }, + _count: { fullName: 1 }, + _max: { fullName: 'Alex Smith' }, }), ]); }); @@ -259,17 +261,19 @@ model Post extends Content { } as any, ); - const posts = await db.post.createManyAndReturn({ - data: [ - { id: 1, title: 'latest news', body: 'some news content' }, - { id: 2, title: 'random post', body: 'some other content' }, - ], - }); - expect(posts).toEqual( - expect.arrayContaining([ - expect.objectContaining({ id: 1, isNews: true }), - expect.objectContaining({ id: 2, isNews: false }), - ]), - ); + if (db.$schema.provider.type !== 'mysql') { + const posts = await db.post.createManyAndReturn({ + data: [ + { id: 1, title: 'latest news', body: 'some news content' }, + { id: 2, title: 'random post', body: 'some other content' }, + ], + }); + expect(posts).toEqual( + expect.arrayContaining([ + expect.objectContaining({ id: 1, isNews: true }), + expect.objectContaining({ id: 2, isNews: false }), + ]), + ); + } }); }); diff --git a/tests/e2e/orm/client-api/create-many-and-return.test.ts b/tests/e2e/orm/client-api/create-many-and-return.test.ts index 8a8697b1..663c8fa7 100644 --- a/tests/e2e/orm/client-api/create-many-and-return.test.ts +++ b/tests/e2e/orm/client-api/create-many-and-return.test.ts @@ -15,6 +15,11 @@ describe('Client createManyAndReturn tests', () => { }); it('works with toplevel createManyAndReturn', async () => { + if (client.$schema.provider.type === ('mysql' as any)) { + // mysql doesn't support createManyAndReturn + return; + } + // empty await expect(client.user.createManyAndReturn()).toResolveWithLength(0); @@ -59,6 +64,11 @@ describe('Client createManyAndReturn tests', () => { }); it('works with select and omit', async () => { + if (client.$schema.provider.type === ('mysql' as any)) { + // mysql doesn't support createManyAndReturn + return; + } + const r = await client.user.createManyAndReturn({ data: [{ email: 'u1@test.com', name: 'name' }], select: { email: true }, diff --git a/tests/e2e/orm/client-api/delegate.test.ts b/tests/e2e/orm/client-api/delegate.test.ts index c5f59c70..86010b4e 100644 --- a/tests/e2e/orm/client-api/delegate.test.ts +++ b/tests/e2e/orm/client-api/delegate.test.ts @@ -135,6 +135,10 @@ describe('Delegate model tests ', () => { }); it('works with createManyAndReturn', async () => { + if (client.$schema.provider.type === ('mysql' as any)) { + return; + } + await expect( client.ratedVideo.createManyAndReturn({ data: [ @@ -179,7 +183,7 @@ describe('Delegate model tests ', () => { rating: 3, }, }), - ).rejects.toSatisfy((e) => e.cause.message.toLowerCase().includes('constraint')); + ).rejects.toSatisfy((e) => e.cause.message.toLowerCase().match(/(constraint)|(duplicate)/i)); await expect(client.ratedVideo.findMany()).toResolveWithLength(1); await expect(client.video.findMany()).toResolveWithLength(1); @@ -790,6 +794,10 @@ describe('Delegate model tests ', () => { }); it('works with updateManyAndReturn', async () => { + if (client.$schema.provider.type === ('mysql' as any)) { + return; + } + await client.ratedVideo.create({ data: { id: 2, viewCount: 1, duration: 200, url: 'abc', rating: 5 }, }); diff --git a/tests/e2e/orm/client-api/error-handling.test.ts b/tests/e2e/orm/client-api/error-handling.test.ts index 1227e525..7be2002a 100644 --- a/tests/e2e/orm/client-api/error-handling.test.ts +++ b/tests/e2e/orm/client-api/error-handling.test.ts @@ -1,5 +1,6 @@ import { ORMError, ORMErrorReason, RejectedByPolicyReason } from '@zenstackhq/orm'; import { createPolicyTestClient, createTestClient } from '@zenstackhq/testtools'; +import { match } from 'ts-pattern'; import { describe, expect, it } from 'vitest'; describe('Error handling tests', () => { @@ -39,14 +40,20 @@ model User { await db.user.create({ data: { email: 'user1@example.com' } }); const provider = db.$schema.provider.type; - const expectedCode = provider === 'sqlite' ? 'SQLITE_CONSTRAINT_UNIQUE' : '23505'; + const expectedCode = match(provider) + .with('sqlite', () => 'SQLITE_CONSTRAINT_UNIQUE') + .with('postgresql', () => '23505') + .with('mysql', () => 'ER_DUP_ENTRY') + .otherwise(() => { + throw new Error(`Unsupported provider: ${provider}`); + }); await expect(db.user.create({ data: { email: 'user1@example.com' } })).rejects.toSatisfy( (e) => e instanceof ORMError && e.reason === ORMErrorReason.DB_QUERY_ERROR && e.dbErrorCode === expectedCode && - !!e.dbErrorMessage?.includes('constraint'), + !!e.dbErrorMessage?.match(/(constraint)|(duplicate)/i), ); }); }); diff --git a/tests/e2e/orm/client-api/find.test.ts b/tests/e2e/orm/client-api/find.test.ts index 2eddd214..5b6ab563 100644 --- a/tests/e2e/orm/client-api/find.test.ts +++ b/tests/e2e/orm/client-api/find.test.ts @@ -262,6 +262,11 @@ describe('Client find tests ', () => { }); it('works with distinct', async () => { + if (['sqlite', 'mysql'].includes(client.$schema.provider.type)) { + await expect(client.user.findMany({ distinct: ['role'] } as any)).rejects.toThrow('not supported'); + return; + } + const user1 = await createUser(client, 'u1@test.com', { name: 'Admin1', role: 'ADMIN', @@ -282,11 +287,6 @@ describe('Client find tests ', () => { role: 'USER', }); - if (client.$schema.provider.type === 'sqlite') { - await expect(client.user.findMany({ distinct: ['role'] } as any)).rejects.toThrow('not supported'); - return; - } - // single field distinct let r: any = await client.user.findMany({ distinct: ['role'] } as any); expect(r).toHaveLength(2); diff --git a/tests/e2e/orm/client-api/json-filter.test.ts b/tests/e2e/orm/client-api/json-filter.test.ts index 4260d439..6b127a81 100644 --- a/tests/e2e/orm/client-api/json-filter.test.ts +++ b/tests/e2e/orm/client-api/json-filter.test.ts @@ -7,7 +7,9 @@ import { schema as typedJsonSchema } from '../schemas/typed-json/schema'; describe('Json filter tests', () => { it('works with simple equality filter', async () => { const db = await createTestClient(schema); - await db.foo.create({ data: { data: { hello: 'world' } } }); + await expect(db.foo.create({ data: { data: { hello: 'world' } } })).resolves.toMatchObject({ + data: { hello: 'world' }, + }); await expect(db.foo.findFirst({ where: { data: { equals: { hello: 'world' } } } })).resolves.toMatchObject({ data: { hello: 'world' }, diff --git a/tests/e2e/orm/client-api/mixin.test.ts b/tests/e2e/orm/client-api/mixin.test.ts index e373b8fe..4dc1001b 100644 --- a/tests/e2e/orm/client-api/mixin.test.ts +++ b/tests/e2e/orm/client-api/mixin.test.ts @@ -75,7 +75,7 @@ model Bar with CommonFields { description: 'Bar', }, }), - ).rejects.toSatisfy((e) => e.cause.message.toLowerCase().includes('constraint')); + ).rejects.toSatisfy((e) => e.cause.message.toLowerCase().match(/(constraint)|(duplicate)/i)); }); it('supports multiple-level mixins', async () => { diff --git a/tests/e2e/orm/client-api/name-mapping.test.ts b/tests/e2e/orm/client-api/name-mapping.test.ts index 2b70cd46..f45bddee 100644 --- a/tests/e2e/orm/client-api/name-mapping.test.ts +++ b/tests/e2e/orm/client-api/name-mapping.test.ts @@ -85,20 +85,34 @@ describe('Name mapping tests', () => { user_role: 'MODERATOR', }); - await expect( - db.$qb + const mysql = db.$schema.provider.type === ('mysql' as any); + + if (!mysql) { + await expect( + db.$qb + .insertInto('User') + .values({ + email: 'u2@test.com', + role: 'ADMIN', + }) + .returning(['id', 'email', 'role']) + .executeTakeFirst(), + ).resolves.toMatchObject({ + id: expect.any(Number), + email: 'u2@test.com', + role: 'ADMIN', + }); + } else { + // mysql doesn't support returning, simply insert + await db.$qb .insertInto('User') .values({ email: 'u2@test.com', role: 'ADMIN', }) - .returning(['id', 'email', 'role']) - .executeTakeFirst(), - ).resolves.toMatchObject({ - id: expect.any(Number), - email: 'u2@test.com', - role: 'ADMIN', - }); + .executeTakeFirst(); + } + rawRead = await db.$qbRaw .selectFrom('users') .where('user_email', '=', 'u2@test.com') @@ -108,32 +122,34 @@ describe('Name mapping tests', () => { user_role: 'role_admin', }); - await expect( - db.$qb - .insertInto('User') - .values({ - email: 'u3@test.com', - }) - .returning(['User.id', 'User.email']) - .executeTakeFirst(), - ).resolves.toMatchObject({ - id: expect.any(Number), - email: 'u3@test.com', - }); - - await expect( - db.$qb - .insertInto('User') - .values({ - email: 'u4@test.com', - }) - .returningAll() - .executeTakeFirst(), - ).resolves.toMatchObject({ - id: expect.any(Number), - email: 'u4@test.com', - role: 'USER', - }); + if (!mysql) { + await expect( + db.$qb + .insertInto('User') + .values({ + email: 'u3@test.com', + }) + .returning(['User.id', 'User.email']) + .executeTakeFirst(), + ).resolves.toMatchObject({ + id: expect.any(Number), + email: 'u3@test.com', + }); + + await expect( + db.$qb + .insertInto('User') + .values({ + email: 'u4@test.com', + }) + .returningAll() + .executeTakeFirst(), + ).resolves.toMatchObject({ + id: expect.any(Number), + email: 'u4@test.com', + role: 'USER', + }); + } }); it('works with find', async () => { @@ -379,18 +395,20 @@ describe('Name mapping tests', () => { posts: [expect.objectContaining({ title: 'Post2' })], }); - await expect( - db.$qb - .updateTable('User') - .set({ email: (eb) => eb.fn('upper', [eb.ref('email')]), role: 'USER' }) - .where('email', '=', 'u2@test.com') - .returning(['email', 'role']) - .executeTakeFirst(), - ).resolves.toMatchObject({ email: 'U2@TEST.COM', role: 'USER' }); - - await expect( - db.$qb.updateTable('User as u').set({ email: 'u3@test.com' }).returningAll().executeTakeFirst(), - ).resolves.toMatchObject({ id: expect.any(Number), email: 'u3@test.com', role: 'USER' }); + if (db.$schema.provider.type !== ('mysql' as any)) { + await expect( + db.$qb + .updateTable('User') + .set({ email: (eb) => eb.fn('upper', [eb.ref('email')]), role: 'USER' }) + .where('email', '=', 'u2@test.com') + .returning(['email', 'role']) + .executeTakeFirst(), + ).resolves.toMatchObject({ email: 'U2@TEST.COM', role: 'USER' }); + + await expect( + db.$qb.updateTable('User as u').set({ email: 'u3@test.com' }).returningAll().executeTakeFirst(), + ).resolves.toMatchObject({ id: expect.any(Number), email: 'u3@test.com', role: 'USER' }); + } }); it('works with delete', async () => { @@ -406,12 +424,17 @@ describe('Name mapping tests', () => { }, }); - await expect( - db.$qb.deleteFrom('Post').where('title', '=', 'Post1').returning(['id', 'title']).executeTakeFirst(), - ).resolves.toMatchObject({ - id: user.id, - title: 'Post1', - }); + if (db.$schema.provider.type !== ('mysql' as any)) { + await expect( + db.$qb.deleteFrom('Post').where('title', '=', 'Post1').returning(['id', 'title']).executeTakeFirst(), + ).resolves.toMatchObject({ + id: user.id, + title: 'Post1', + }); + } else { + // mysql doesn't support returning, simply delete + await db.$qb.deleteFrom('Post').where('title', '=', 'Post1').executeTakeFirst(); + } await expect( db.user.delete({ diff --git a/tests/e2e/orm/client-api/raw-query.test.ts b/tests/e2e/orm/client-api/raw-query.test.ts index 7fe0ecbc..1e6e889e 100644 --- a/tests/e2e/orm/client-api/raw-query.test.ts +++ b/tests/e2e/orm/client-api/raw-query.test.ts @@ -2,6 +2,9 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import type { ClientContract } from '@zenstackhq/orm'; import { schema } from '../schemas/basic'; import { createTestClient } from '@zenstackhq/testtools'; +import { sql } from 'kysely'; +import { match } from 'ts-pattern'; +import type { DataSourceProviderType } from '@zenstackhq/schema'; describe('Client raw query tests', () => { let client: ClientContract; @@ -14,6 +17,10 @@ describe('Client raw query tests', () => { await client?.$disconnect(); }); + function ref(col: string) { + return client.$schema.provider.type === ('mysql' as any) ? sql.raw(`\`${col}\``) : sql.raw(`"${col}"`); + } + it('works with executeRaw', async () => { await client.user.create({ data: { @@ -23,7 +30,7 @@ describe('Client raw query tests', () => { }); await expect( - client.$executeRaw`UPDATE "User" SET "email" = ${'u2@test.com'} WHERE "id" = ${'1'}`, + client.$executeRaw`UPDATE ${ref('User')} SET ${ref('email')} = ${'u2@test.com'} WHERE ${ref('id')} = ${'1'}`, ).resolves.toBe(1); await expect(client.user.findFirst()).resolves.toMatchObject({ email: 'u2@test.com' }); }); @@ -36,11 +43,11 @@ describe('Client raw query tests', () => { }, }); - const sql = - // @ts-ignore - client.$schema.provider.type === 'postgresql' - ? `UPDATE "User" SET "email" = $1 WHERE "id" = $2` - : `UPDATE "User" SET "email" = ? WHERE "id" = ?`; + const sql = match(client.$schema.provider.type as DataSourceProviderType) + .with('postgresql', () => `UPDATE "User" SET "email" = $1 WHERE "id" = $2`) + .with('mysql', () => 'UPDATE `User` SET `email` = ? WHERE `id` = ?') + .with('sqlite', () => 'UPDATE "User" SET "email" = ? WHERE "id" = ?') + .exhaustive(); await expect(client.$executeRawUnsafe(sql, 'u2@test.com', '1')).resolves.toBe(1); await expect(client.user.findFirst()).resolves.toMatchObject({ email: 'u2@test.com' }); }); @@ -56,7 +63,7 @@ describe('Client raw query tests', () => { const uid = '1'; const users = await client.$queryRaw< { id: string; email: string }[] - >`SELECT "User"."id", "User"."email" FROM "User" WHERE "User"."id" = ${uid}`; + >`SELECT ${ref('User')}.${ref('id')}, ${ref('User')}.${ref('email')} FROM ${ref('User')} WHERE ${ref('User')}.${ref('id')} = ${uid}`; expect(users).toEqual([{ id: '1', email: 'u1@test.com' }]); }); @@ -68,11 +75,12 @@ describe('Client raw query tests', () => { }, }); - const sql = - // @ts-ignore - client.$schema.provider.type === 'postgresql' - ? `SELECT "User"."id", "User"."email" FROM "User" WHERE "User"."id" = $1` - : `SELECT "User"."id", "User"."email" FROM "User" WHERE "User"."id" = ?`; + const sql = match(client.$schema.provider.type as DataSourceProviderType) + .with('postgresql', () => `SELECT "User"."id", "User"."email" FROM "User" WHERE "User"."id" = $1`) + .with('mysql', () => 'SELECT `User`.`id`, `User`.`email` FROM `User` WHERE `User`.`id` = ?') + .with('sqlite', () => 'SELECT "User"."id", "User"."email" FROM "User" WHERE "User"."id" = ?') + .exhaustive(); + const users = await client.$queryRawUnsafe<{ id: string; email: string }[]>(sql, '1'); expect(users).toEqual([{ id: '1', email: 'u1@test.com' }]); }); diff --git a/tests/e2e/orm/client-api/relation/self-relation.test.ts b/tests/e2e/orm/client-api/relation/self-relation.test.ts index 9e70ca00..8dc7e5b5 100644 --- a/tests/e2e/orm/client-api/relation/self-relation.test.ts +++ b/tests/e2e/orm/client-api/relation/self-relation.test.ts @@ -1,5 +1,5 @@ -import { afterEach, describe, expect, it } from 'vitest'; import { createTestClient } from '@zenstackhq/testtools'; +import { afterEach, describe, expect, it } from 'vitest'; describe('Self relation tests', () => { let client: any; diff --git a/tests/e2e/orm/client-api/type-coverage.test.ts b/tests/e2e/orm/client-api/type-coverage.test.ts index a0c24880..bbe87277 100644 --- a/tests/e2e/orm/client-api/type-coverage.test.ts +++ b/tests/e2e/orm/client-api/type-coverage.test.ts @@ -35,6 +35,7 @@ describe('Zmodel type coverage tests', () => { Json Json } `, + { usePrismaPush: true }, ); await db.foo.create({ data }); @@ -80,7 +81,7 @@ describe('Zmodel type coverage tests', () => { }); it('supports all types - array', async () => { - if (getTestDbProvider() === 'sqlite') { + if (getTestDbProvider() !== 'postgresql') { return; } diff --git a/tests/e2e/orm/client-api/update-many.test.ts b/tests/e2e/orm/client-api/update-many.test.ts index 61776e3e..dfb598d9 100644 --- a/tests/e2e/orm/client-api/update-many.test.ts +++ b/tests/e2e/orm/client-api/update-many.test.ts @@ -80,6 +80,11 @@ describe('Client updateMany tests', () => { }); it('works with updateManyAndReturn', async () => { + if (client.$schema.provider.type === ('mysql' as any)) { + // skip for mysql as it does not support returning + return; + } + await client.user.create({ data: { id: '1', email: 'u1@test.com', name: 'User1' }, }); diff --git a/tests/e2e/orm/client-api/update.test.ts b/tests/e2e/orm/client-api/update.test.ts index c79396d7..08377e51 100644 --- a/tests/e2e/orm/client-api/update.test.ts +++ b/tests/e2e/orm/client-api/update.test.ts @@ -115,6 +115,30 @@ describe('Client update tests', () => { ).resolves.toMatchObject({ id: 'user2' }); }); + it('works with update with unchanged data or no data', async () => { + const user = await createUser(client, 'u1@test.com'); + await expect( + client.user.update({ + where: { id: user.id }, + data: { + email: user.email, + // force a no-op update + updatedAt: user.updatedAt, + }, + }), + ).resolves.toEqual(user); + + await expect( + client.user.update({ + where: { id: user.id }, + data: {}, + }), + ).resolves.toEqual(user); + + const plain = await client.plain.create({ data: { value: 42 } }); + await expect(client.plain.update({ where: { id: plain.id }, data: { value: 42 } })).resolves.toEqual(plain); + }); + it('does not update updatedAt if no other scalar fields are updated', async () => { const user = await createUser(client, 'u1@test.com'); const originalUpdatedAt = user.updatedAt; @@ -1050,7 +1074,7 @@ describe('Client update tests', () => { }, }, }), - ).rejects.toSatisfy((e) => e.cause.message.toLowerCase().includes('constraint')); + ).rejects.toSatisfy((e) => e.cause.message.toLowerCase().match(/(constraint)|(duplicate)/i)); // transaction fails as a whole await expect(client.comment.findUnique({ where: { id: '3' } })).resolves.toMatchObject({ content: 'Comment3', diff --git a/tests/e2e/orm/plugin-infra/entity-mutation-hooks.test.ts b/tests/e2e/orm/plugin-infra/entity-mutation-hooks.test.ts index 4c97d1a9..eb8a25f1 100644 --- a/tests/e2e/orm/plugin-infra/entity-mutation-hooks.test.ts +++ b/tests/e2e/orm/plugin-infra/entity-mutation-hooks.test.ts @@ -8,7 +8,7 @@ describe('Entity mutation hooks tests', () => { let _client: ClientContract; beforeEach(async () => { - _client = (await createTestClient(schema, {})) as any; + _client = await createTestClient(schema); }); afterEach(async () => { @@ -67,6 +67,8 @@ describe('Entity mutation hooks tests', () => { delete: { before: '', after: '' }, }; + let beforeMutationEntitiesInAfterHooks: Record[] | undefined; + const client = _client.$use({ id: 'test', onEntityMutation: { @@ -84,6 +86,10 @@ describe('Entity mutation hooks tests', () => { if (args.action === 'update' || args.action === 'delete') { queryIds[args.action].after = args.queryId.queryId; } + + if (args.action === 'update') { + beforeMutationEntitiesInAfterHooks = args.beforeMutationEntities; + } }, }, }); @@ -94,10 +100,14 @@ describe('Entity mutation hooks tests', () => { await client.user.create({ data: { email: 'u2@test.com' }, }); + await client.user.update({ where: { id: user.id }, data: { email: 'u3@test.com' }, }); + // beforeMutationEntities in after hooks is available because we called loadBeforeMutationEntities in before hook + expect(beforeMutationEntitiesInAfterHooks).toEqual([expect.objectContaining({ email: 'u1@test.com' })]); + await client.user.delete({ where: { id: user.id } }); expect(queryIds.update.before).toBeTruthy(); diff --git a/tests/e2e/orm/plugin-infra/ext-query-args.test.ts b/tests/e2e/orm/plugin-infra/ext-query-args.test.ts index 95794626..3a15321e 100644 --- a/tests/e2e/orm/plugin-infra/ext-query-args.test.ts +++ b/tests/e2e/orm/plugin-infra/ext-query-args.test.ts @@ -122,9 +122,14 @@ describe('Plugin extended query args', () => { await expect( extDb.user.createMany({ data: [{ name: 'Charlie' }], ...cacheBustOption }), ).resolves.toHaveProperty('count'); - await expect( - extDb.user.createManyAndReturn({ data: [{ name: 'David' }], ...cacheBustOption }), - ).toResolveWithLength(1); + + const isMySql = db.$schema.provider.type === ('mysql' as any); + + if (!isMySql) { + await expect( + extDb.user.createManyAndReturn({ data: [{ name: 'David' }], ...cacheBustOption }), + ).toResolveWithLength(1); + } // update operations await expect( @@ -133,13 +138,17 @@ describe('Plugin extended query args', () => { await expect( extDb.user.updateMany({ where: { name: 'Bob' }, data: { name: 'Bob Updated' }, ...cacheBustOption }), ).resolves.toHaveProperty('count'); - await expect( - extDb.user.updateManyAndReturn({ - where: { name: 'Charlie' }, - data: { name: 'Charlie Updated' }, - ...cacheBustOption, - }), - ).toResolveTruthy(); + + if (!isMySql) { + await expect( + extDb.user.updateManyAndReturn({ + where: { name: 'Charlie' }, + data: { name: 'Charlie Updated' }, + ...cacheBustOption, + }), + ).toResolveTruthy(); + } + await expect( extDb.user.upsert({ where: { id: 999 }, diff --git a/tests/e2e/orm/plugin-infra/on-kysely-query.test.ts b/tests/e2e/orm/plugin-infra/on-kysely-query.test.ts index 68602613..d6e17c04 100644 --- a/tests/e2e/orm/plugin-infra/on-kysely-query.test.ts +++ b/tests/e2e/orm/plugin-infra/on-kysely-query.test.ts @@ -1,6 +1,6 @@ import { type ClientContract } from '@zenstackhq/orm'; import { createTestClient } from '@zenstackhq/testtools'; -import { InsertQueryNode, Kysely, PrimitiveValueListNode, ValuesNode, type QueryResult } from 'kysely'; +import { InsertQueryNode, PrimitiveValueListNode, ValuesNode } from 'kysely'; import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import { schema } from '../schemas/basic'; @@ -93,7 +93,14 @@ describe('On kysely query tests', () => { const result = await proceed(query); // create a post for the user - await proceed(createPost(client.$qb, result)); + const now = new Date().toISOString().replace('Z', '+00:00'); // for mysql compatibility + const createPost = client.$qb.insertInto('Post').values({ + id: '1', + title: 'Post1', + authorId: '1', + updatedAt: now, + }); + await proceed(createPost.toOperationNode()); return result; }, @@ -218,14 +225,3 @@ describe('On kysely query tests', () => { await expect(client.user.findFirst()).toResolveNull(); }); }); - -function createPost(kysely: Kysely, userRows: QueryResult) { - const now = new Date().toISOString(); - const createPost = kysely.insertInto('Post').values({ - id: '1', - title: 'Post1', - authorId: (userRows.rows[0] as any).id, - updatedAt: now, - }); - return createPost.toOperationNode(); -} diff --git a/tests/e2e/orm/policy/crud/create.test.ts b/tests/e2e/orm/policy/crud/create.test.ts index 6aecba29..124a23eb 100644 --- a/tests/e2e/orm/policy/crud/create.test.ts +++ b/tests/e2e/orm/policy/crud/create.test.ts @@ -16,12 +16,11 @@ model Foo { await expect(db.foo.create({ data: { x: 0 } })).toBeRejectedByPolicy(); await expect(db.foo.create({ data: { x: 1 } })).resolves.toMatchObject({ x: 1 }); - await expect( - db.$qb.insertInto('Foo').values({ x: 0 }).returningAll().executeTakeFirst(), - ).toBeRejectedByPolicy(); - await expect( - db.$qb.insertInto('Foo').values({ x: 1 }).returningAll().executeTakeFirst(), - ).resolves.toMatchObject({ x: 1 }); + await expect(db.$qb.insertInto('Foo').values({ x: 0 }).executeTakeFirst()).toBeRejectedByPolicy(); + + await expect(db.$qb.insertInto('Foo').values({ x: 1 }).executeTakeFirst()).toResolveTruthy(); + + await expect(db.foo.findMany({ where: { x: 1 } })).resolves.toHaveLength(2); }); it('works with this scalar member check', async () => { diff --git a/tests/e2e/orm/policy/crud/read.test.ts b/tests/e2e/orm/policy/crud/read.test.ts index fd767b51..febe3503 100644 --- a/tests/e2e/orm/policy/crud/read.test.ts +++ b/tests/e2e/orm/policy/crud/read.test.ts @@ -632,8 +632,14 @@ model Bar { @@allow('read', y > 0) } `, + { provider: 'postgresql' }, ); + if (db.$schema.provider.type !== 'postgresql') { + // skip for non-postgresql as from is only supported there + return; + } + await db.$unuseAll().foo.create({ data: { id: 1, x: 1 } }); await db.$unuseAll().bar.create({ data: { id: 1, y: 0 } }); diff --git a/tests/e2e/orm/policy/crud/update.test.ts b/tests/e2e/orm/policy/crud/update.test.ts index 95b17306..d7f3a8a2 100644 --- a/tests/e2e/orm/policy/crud/update.test.ts +++ b/tests/e2e/orm/policy/crud/update.test.ts @@ -23,9 +23,12 @@ model Foo { await expect( db.$qb.updateTable('Foo').set({ x: 1 }).where('id', '=', 1).executeTakeFirst(), ).resolves.toMatchObject({ numUpdatedRows: 0n }); - await expect( - db.$qb.updateTable('Foo').set({ x: 3 }).where('id', '=', 2).returningAll().execute(), - ).resolves.toMatchObject([{ id: 2, x: 3 }]); + + if (db.$schema.provider.type !== 'mysql') { + await expect( + db.$qb.updateTable('Foo').set({ x: 3 }).where('id', '=', 2).returningAll().execute(), + ).resolves.toMatchObject([{ id: 2, x: 3 }]); + } }); it('works with this scalar member check', async () => { @@ -758,7 +761,7 @@ model Post { }, }, }), - ).rejects.toSatisfy((e) => e.cause.message.toLowerCase().includes('constraint')); + ).rejects.toSatisfy((e) => e.cause.message.toLowerCase().match(/(constraint)|(duplicate)/)); await db.$unuseAll().post.update({ where: { id: 1 }, data: { title: 'Bar Post' } }); // can update await expect( @@ -1124,7 +1127,7 @@ model Foo { // can't update, but create violates unique constraint await expect( db.foo.upsert({ where: { id: 1 }, create: { id: 1, x: 1 }, update: { x: 1 } }), - ).rejects.toSatisfy((e) => e.cause.message.toLowerCase().includes('constraint')); + ).rejects.toSatisfy((e) => e.cause.message.toLowerCase().match(/(constraint)|(duplicate)/)); await db.$unuseAll().foo.update({ where: { id: 1 }, data: { x: 2 } }); // can update now await expect( @@ -1226,38 +1229,51 @@ model Foo { ], }); + const mysql = db.$schema.provider.type === 'mysql'; + // #1 not updatable - await expect( - db.$qb - .insertInto('Foo') - .values({ id: 1, x: 5 }) - .onConflict((oc: any) => oc.column('id').doUpdateSet({ x: 5 })) - .executeTakeFirst(), - ).resolves.toMatchObject({ numInsertedOrUpdatedRows: 0n }); - await expect(db.foo.count()).resolves.toBe(3); + const r = await db.$qb + .insertInto('Foo') + .values({ id: 1, x: 5 }) + .$if(mysql, (qb) => qb.onDuplicateKeyUpdate({ x: 5 })) + .$if(!mysql, (qb) => qb.onConflict((oc: any) => oc.column('id').doUpdateSet({ x: 5 }))) + .executeTakeFirst(); + + if (!mysql) { + expect(r).toMatchObject({ numInsertedOrUpdatedRows: 0n }); + } else { + // mysql's on duplicate key update returns rows affected even if no values are changed + expect(r).toMatchObject({ numInsertedOrUpdatedRows: 1n }); + } + // verify not updated await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 1 }); - // with where, #1 not updatable - await expect( - db.$qb - .insertInto('Foo') - .values({ id: 1, x: 5 }) - .onConflict((oc: any) => oc.column('id').doUpdateSet({ x: 5 }).where('Foo.id', '=', 1)) - .executeTakeFirst(), - ).resolves.toMatchObject({ numInsertedOrUpdatedRows: 0n }); await expect(db.foo.count()).resolves.toBe(3); await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 1 }); - // with where, #2 updatable - await expect( - db.$qb - .insertInto('Foo') - .values({ id: 2, x: 5 }) - .onConflict((oc: any) => oc.column('id').doUpdateSet({ x: 6 }).where('Foo.id', '=', 2)) - .executeTakeFirst(), - ).resolves.toMatchObject({ numInsertedOrUpdatedRows: 1n }); - await expect(db.foo.count()).resolves.toBe(3); - await expect(db.foo.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ x: 6 }); + if (!mysql) { + // with where, #1 not updatable + await expect( + db.$qb + .insertInto('Foo') + .values({ id: 1, x: 5 }) + .onConflict((oc: any) => oc.column('id').doUpdateSet({ x: 5 }).where('Foo.id', '=', 1)) + .executeTakeFirst(), + ).resolves.toMatchObject({ numInsertedOrUpdatedRows: 0n }); + await expect(db.foo.count()).resolves.toBe(3); + await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 1 }); + + // with where, #2 updatable + await expect( + db.$qb + .insertInto('Foo') + .values({ id: 2, x: 5 }) + .onConflict((oc: any) => oc.column('id').doUpdateSet({ x: 6 }).where('Foo.id', '=', 2)) + .executeTakeFirst(), + ).resolves.toMatchObject({ numInsertedOrUpdatedRows: 1n }); + await expect(db.foo.count()).resolves.toBe(3); + await expect(db.foo.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ x: 6 }); + } }); }); }); diff --git a/tests/e2e/orm/policy/migrated/auth.test.ts b/tests/e2e/orm/policy/migrated/auth.test.ts index b3e49980..c612e120 100644 --- a/tests/e2e/orm/policy/migrated/auth.test.ts +++ b/tests/e2e/orm/policy/migrated/auth.test.ts @@ -325,15 +325,18 @@ model Post { await expect(userDb.post.count({ where: { authorName: 'user1', score: 10 } })).resolves.toBe(1); await expect(userDb.post.createMany({ data: [{ title: 'def' }] })).resolves.toMatchObject({ count: 1 }); - const r = await userDb.post.createManyAndReturn({ - data: [{ title: 'xxx' }, { title: 'yyy' }], - }); - expect(r).toEqual( - expect.arrayContaining([ - expect.objectContaining({ title: 'xxx', score: 10 }), - expect.objectContaining({ title: 'yyy', score: 10 }), - ]), - ); + + if (userDb.$schema.provider.type !== 'mysql') { + const r = await userDb.post.createManyAndReturn({ + data: [{ title: 'xxx' }, { title: 'yyy' }], + }); + expect(r).toEqual( + expect.arrayContaining([ + expect.objectContaining({ title: 'xxx', score: 10 }), + expect.objectContaining({ title: 'yyy', score: 10 }), + ]), + ); + } }); it('respects explicitly passed field values even when default is set', async () => { @@ -363,10 +366,12 @@ model Post { await expect(userDb.post.createMany({ data: [{ authorName: overrideName }] })).toResolveTruthy(); await expect(userDb.post.count({ where: { authorName: overrideName } })).resolves.toBe(2); - const r = await userDb.post.createManyAndReturn({ - data: [{ authorName: overrideName }], - }); - expect(r[0]).toMatchObject({ authorName: overrideName }); + if (userDb.$schema.provider.type !== 'mysql') { + const r = await userDb.post.createManyAndReturn({ + data: [{ authorName: overrideName }], + }); + expect(r[0]).toMatchObject({ authorName: overrideName }); + } }); it('works with using auth value as default for foreign key', async () => { @@ -439,11 +444,13 @@ model Post { const r = await db.post.findFirst({ where: { title: 'post5' } }); expect(r).toMatchObject({ authorId: 'userId-1' }); - // default auth effective for createManyAndReturn - const r1 = await db.post.createManyAndReturn({ - data: { title: 'post6' }, - }); - expect(r1[0]).toMatchObject({ authorId: 'userId-1' }); + if (db.$schema.provider.type !== 'mysql') { + // default auth effective for createManyAndReturn + const r1 = await db.post.createManyAndReturn({ + data: { title: 'post6' }, + }); + expect(r1[0]).toMatchObject({ authorId: 'userId-1' }); + } }); it('works with using nested auth value as default', async () => { @@ -537,7 +544,7 @@ model Post { await expect(db.user.create({ data: { id: 'userId-1' } })).toResolveTruthy(); await expect(db.post.create({ data: { title: 'title' } })).rejects.toSatisfy((e) => - e.cause.message.toLowerCase().includes('constraint'), + e.cause.message.toLowerCase().match(/(constraint)|(cannot be null)/), ); await expect(db.post.findMany({})).toResolveTruthy(); }); @@ -591,10 +598,13 @@ model Post { }); await db.stats.create({ data: { id: 'stats-3', viewCount: 10 } }); - const r = await db.post.createManyAndReturn({ - data: [{ title: 'title3', statsId: 'stats-3' }], - }); - expect(r[0]).toMatchObject({ statsId: 'stats-3' }); + + if (db.$schema.provider.type !== 'mysql') { + const r = await db.post.createManyAndReturn({ + data: [{ title: 'title3', statsId: 'stats-3' }], + }); + expect(r[0]).toMatchObject({ statsId: 'stats-3' }); + } // checked context await db.stats.create({ data: { id: 'stats-4', viewCount: 10 } }); diff --git a/tests/e2e/orm/policy/migrated/create-many-and-return.test.ts b/tests/e2e/orm/policy/migrated/create-many-and-return.test.ts index 5191853c..d11723d7 100644 --- a/tests/e2e/orm/policy/migrated/create-many-and-return.test.ts +++ b/tests/e2e/orm/policy/migrated/create-many-and-return.test.ts @@ -25,6 +25,12 @@ describe('createManyAndReturn tests', () => { } `, ); + + if (db.$schema.provider.type === 'mysql') { + // MySQL does not support createManyAndReturn + return; + } + const rawDb = db.$unuseAll(); await rawDb.user.createMany({ @@ -60,8 +66,7 @@ describe('createManyAndReturn tests', () => { await expect(rawDb.post.findMany()).resolves.toHaveLength(3); }); - // TODO: field-level policies support - it.skip('field-level policies', async () => { + it('field-level policies', async () => { const db = await createPolicyTestClient( ` model Post { @@ -73,6 +78,12 @@ describe('createManyAndReturn tests', () => { } `, ); + + if (db.$schema.provider.type === 'mysql') { + // MySQL does not support createManyAndReturn + return; + } + const rawDb = db.$unuseAll(); // create should succeed but one result's title field can't be read back const r = await db.post.createManyAndReturn({ @@ -84,7 +95,7 @@ describe('createManyAndReturn tests', () => { expect(r.length).toBe(2); expect(r[0].title).toBeTruthy(); - expect(r[1].title).toBeUndefined(); + expect(r[1].title).toBeNull(); // check posts are created await expect(rawDb.post.findMany()).resolves.toHaveLength(2); diff --git a/tests/e2e/orm/policy/migrated/current-model.test.ts b/tests/e2e/orm/policy/migrated/current-model.test.ts index 0ceb96c3..7272e74a 100644 --- a/tests/e2e/orm/policy/migrated/current-model.test.ts +++ b/tests/e2e/orm/policy/migrated/current-model.test.ts @@ -41,7 +41,13 @@ describe('currentModel tests', () => { ); await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); - await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + + if (db.$schema.provider.type !== 'mysql') { + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + } else { + // mysql string comparison is case insensitive by default + await expect(db.post.create({ data: { id: 1 } })).toResolveTruthy(); + } }); it('works with lower case', async () => { @@ -62,7 +68,13 @@ describe('currentModel tests', () => { ); await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); - await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + + if (db.$schema.provider.type !== 'mysql') { + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + } else { + // mysql string comparison is case insensitive by default + await expect(db.post.create({ data: { id: 1 } })).toResolveTruthy(); + } }); it('works with capitalization', async () => { @@ -83,7 +95,13 @@ describe('currentModel tests', () => { ); await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); - await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + + if (db.$schema.provider.type !== 'mysql') { + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + } else { + // mysql string comparison is case insensitive by default + await expect(db.post.create({ data: { id: 1 } })).toResolveTruthy(); + } }); it('works with uncapitalization', async () => { @@ -104,7 +122,13 @@ describe('currentModel tests', () => { ); await expect(db.USER.create({ data: { id: 1 } })).toResolveTruthy(); - await expect(db.POST.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + + if (db.$schema.provider.type !== 'mysql') { + await expect(db.POST.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + } else { + // mysql string comparison is case insensitive by default + await expect(db.POST.create({ data: { id: 1 } })).toResolveTruthy(); + } }); it('works when inherited from abstract base', async () => { diff --git a/tests/e2e/orm/policy/migrated/current-operation.test.ts b/tests/e2e/orm/policy/migrated/current-operation.test.ts index 3cbae4ca..374c7cf9 100644 --- a/tests/e2e/orm/policy/migrated/current-operation.test.ts +++ b/tests/e2e/orm/policy/migrated/current-operation.test.ts @@ -99,7 +99,13 @@ describe('currentOperation tests', () => { ); await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); - await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + + if (db.$schema.provider.type !== 'mysql') { + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + } else { + // mysql string comparison is case insensitive by default + await expect(db.post.create({ data: { id: 1 } })).toResolveTruthy(); + } }); it('works with uncapitalization', async () => { diff --git a/tests/e2e/orm/policy/migrated/deep-nested.test.ts b/tests/e2e/orm/policy/migrated/deep-nested.test.ts index f8bcea93..a118757c 100644 --- a/tests/e2e/orm/policy/migrated/deep-nested.test.ts +++ b/tests/e2e/orm/policy/migrated/deep-nested.test.ts @@ -482,7 +482,7 @@ describe('deep nested operations tests', () => { }, }, }), - ).rejects.toSatisfy((e) => e.cause.message.toLowerCase().includes('constraint')); + ).rejects.toSatisfy((e) => e.cause.message.toLowerCase().match(/(constraint)|(duplicate)/)); // createMany skip duplicate await db.m1.update({ diff --git a/tests/e2e/orm/policy/migrated/multi-id-fields.test.ts b/tests/e2e/orm/policy/migrated/multi-id-fields.test.ts index 0c3bdf2a..e056a315 100644 --- a/tests/e2e/orm/policy/migrated/multi-id-fields.test.ts +++ b/tests/e2e/orm/policy/migrated/multi-id-fields.test.ts @@ -96,13 +96,25 @@ describe('Policy tests multiple id fields', () => { db.a.update({ where: { x_y: { x: '1', y: 2 } }, data: { x: '2', y: 3, value: 0 } }), ).toBeRejectedByPolicy(); - await expect( - db.a.update({ where: { x_y: { x: '1', y: 2 } }, data: { x: '2', y: 3, value: 2 } }), - ).resolves.toMatchObject({ - x: '2', - y: 3, - value: 2, - }); + const mysql = db.$schema.provider.type === 'mysql'; + + if (!mysql) { + await expect( + db.a.update({ where: { x_y: { x: '1', y: 2 } }, data: { x: '2', y: 3, value: 2 } }), + ).resolves.toMatchObject({ + x: '2', + y: 3, + value: 2, + }); + } else { + // mysql doesn't support post-update policies with id updates + await expect( + db.a.update({ where: { x_y: { x: '1', y: 2 } }, data: { x: '2', y: 3, value: 2 } }), + ).toBeRejectedByPolicy(); + + // force update + await db.$unuseAll().a.update({ where: { x_y: { x: '1', y: 2 } }, data: { x: '2', y: 3, value: 2 } }); + } await expect( db.a.upsert({ @@ -112,17 +124,28 @@ describe('Policy tests multiple id fields', () => { }), ).toBeRejectedByPolicy(); - await expect( - db.a.upsert({ - where: { x_y: { x: '2', y: 3 } }, - update: { x: '3', y: 4, value: 3 }, - create: { x: '4', y: 5, value: 5 }, - }), - ).resolves.toMatchObject({ - x: '3', - y: 4, - value: 3, - }); + if (!mysql) { + await expect( + db.a.upsert({ + where: { x_y: { x: '2', y: 3 } }, + update: { x: '3', y: 4, value: 3 }, + create: { x: '4', y: 5, value: 5 }, + }), + ).resolves.toMatchObject({ + x: '3', + y: 4, + value: 3, + }); + } else { + // mysql doesn't support post-update policies with id updates + await expect( + db.a.upsert({ + where: { x_y: { x: '2', y: 3 } }, + update: { x: '3', y: 4, value: 3 }, + create: { x: '4', y: 5, value: 5 }, + }), + ).toBeRejectedByPolicy(); + } }); it('multi-id auth', async () => { @@ -353,13 +376,34 @@ describe('Policy tests multiple id fields', () => { }), ).toBeRejectedByPolicy(); - await expect( - db.b.update({ + const mysql = db.$schema.provider.type === 'mysql'; + + if (!mysql) { + await expect( + db.b.update({ + where: { id: 1 }, + data: { a: { update: { where: { x_y: { x: '1', y: 1 } }, data: { x: '2', y: 2, value: 2 } } } }, + include: { a: true }, + }), + ).resolves.toMatchObject({ + a: expect.arrayContaining([expect.objectContaining({ x: '2', y: 2, value: 2 })]), + }); + } else { + // mysql doesn't support post-update policies with id updates + await expect( + db.b.update({ + where: { id: 1 }, + data: { a: { update: { where: { x_y: { x: '1', y: 1 } }, data: { x: '2', y: 2, value: 2 } } } }, + include: { a: true }, + }), + ).toBeRejectedByPolicy(); + + // force update + await db.$unuseAll().b.update({ where: { id: 1 }, data: { a: { update: { where: { x_y: { x: '1', y: 1 } }, data: { x: '2', y: 2, value: 2 } } } }, - include: { a: true }, - }), - ).resolves.toMatchObject({ a: expect.arrayContaining([expect.objectContaining({ x: '2', y: 2, value: 2 })]) }); + }); + } await expect( db.b.update({ @@ -376,20 +420,24 @@ describe('Policy tests multiple id fields', () => { }), ).toBeRejectedByPolicy(); - await expect( - db.b.update({ - where: { id: 1 }, - data: { - a: { - upsert: { - where: { x_y: { x: '2', y: 2 } }, - update: { x: '3', y: 3, value: 3 }, - create: { x: '4', y: 4, value: 4 }, + if (!mysql) { + await expect( + db.b.update({ + where: { id: 1 }, + data: { + a: { + upsert: { + where: { x_y: { x: '2', y: 2 } }, + update: { x: '3', y: 3, value: 3 }, + create: { x: '4', y: 4, value: 4 }, + }, }, }, - }, - include: { a: true }, - }), - ).resolves.toMatchObject({ a: expect.arrayContaining([expect.objectContaining({ x: '3', y: 3, value: 3 })]) }); + include: { a: true }, + }), + ).resolves.toMatchObject({ + a: expect.arrayContaining([expect.objectContaining({ x: '3', y: 3, value: 3 })]), + }); + } }); }); diff --git a/tests/e2e/orm/policy/migrated/nested-to-one.test.ts b/tests/e2e/orm/policy/migrated/nested-to-one.test.ts index f839336c..12481ae0 100644 --- a/tests/e2e/orm/policy/migrated/nested-to-one.test.ts +++ b/tests/e2e/orm/policy/migrated/nested-to-one.test.ts @@ -243,17 +243,32 @@ describe('With Policy:nested to-one', () => { }), ).toBeRejectedByPolicy(); - await expect( - db.m1.update({ - where: { id: '1' }, - data: { - m2: { - update: { id: '2', value: 3 }, + if (db.$schema.provider.type !== 'mysql') { + await expect( + db.m1.update({ + where: { id: '1' }, + data: { + m2: { + update: { id: '2', value: 3 }, + }, }, - }, - include: { m2: true }, - }), - ).resolves.toMatchObject({ m2: expect.objectContaining({ id: '2', value: 3 }) }); + include: { m2: true }, + }), + ).resolves.toMatchObject({ m2: expect.objectContaining({ id: '2', value: 3 }) }); + } else { + // mysql does not support post-update with id updates + await expect( + db.m1.update({ + where: { id: '1' }, + data: { + m2: { + update: { id: '2', value: 3 }, + }, + }, + include: { m2: true }, + }), + ).toBeRejectedByPolicy(); + } }); it('nested create', async () => { diff --git a/tests/e2e/orm/policy/migrated/toplevel-operations.test.ts b/tests/e2e/orm/policy/migrated/toplevel-operations.test.ts index 086efaf0..564b89b1 100644 --- a/tests/e2e/orm/policy/migrated/toplevel-operations.test.ts +++ b/tests/e2e/orm/policy/migrated/toplevel-operations.test.ts @@ -166,16 +166,38 @@ describe('Policy toplevel operations tests', () => { }), ).toBeRejectedByPolicy(); - // update success - await expect( - db.model.update({ + if (db.$schema.provider.type !== 'mysql') { + // update success + await expect( + db.model.update({ + where: { id: '1' }, + data: { + id: '2', + value: 3, + }, + }), + ).resolves.toMatchObject({ id: '2', value: 3 }); + } else { + // mysql doesn't support post-update with id updates + await expect( + db.model.update({ + where: { id: '1' }, + data: { + id: '2', + value: 3, + }, + }), + ).toBeRejectedByPolicy(); + + // force update + await db.$unuseAll().model.update({ where: { id: '1' }, data: { id: '2', value: 3, }, - }), - ).resolves.toMatchObject({ id: '2', value: 3 }); + }); + } // upsert denied await expect( @@ -192,20 +214,22 @@ describe('Policy toplevel operations tests', () => { }), ).toBeRejectedByPolicy(); - // upsert success - await expect( - db.model.upsert({ - where: { id: '2' }, - update: { - id: '3', - value: 4, - }, - create: { - id: '4', - value: 5, - }, - }), - ).resolves.toMatchObject({ id: '3', value: 4 }); + if (db.$schema.provider.type !== 'mysql') { + // upsert success + await expect( + db.model.upsert({ + where: { id: '2' }, + update: { + id: '3', + value: 4, + }, + create: { + id: '4', + value: 5, + }, + }), + ).resolves.toMatchObject({ id: '3', value: 4 }); + } }); it('delete tests', async () => { diff --git a/tests/e2e/orm/policy/migrated/update-many-and-return.test.ts b/tests/e2e/orm/policy/migrated/update-many-and-return.test.ts index 32367c35..38f33690 100644 --- a/tests/e2e/orm/policy/migrated/update-many-and-return.test.ts +++ b/tests/e2e/orm/policy/migrated/update-many-and-return.test.ts @@ -26,6 +26,11 @@ describe('Policy updateManyAndReturn tests', () => { `, ); + if (db.$schema.provider.type === 'mysql') { + // skip mysql as it doesn't support updateManyAndReturn + return; + } + const rawDb = db.$unuseAll(); await rawDb.user.createMany({ @@ -82,8 +87,7 @@ describe('Policy updateManyAndReturn tests', () => { await expect(db.$unuseAll().post.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ published: false }); }); - // TODO: field-level policy support - it.skip('field-level policies', async () => { + it('field-level policies', async () => { const db = await createPolicyTestClient( ` model Post { @@ -96,6 +100,11 @@ describe('Policy updateManyAndReturn tests', () => { `, ); + if (db.$schema.provider.type === 'mysql') { + // skip mysql as it doesn't support updateManyAndReturn + return; + } + const rawDb = db.$unuseAll(); // update should succeed but one result's title field can't be read back @@ -112,7 +121,7 @@ describe('Policy updateManyAndReturn tests', () => { expect(r.length).toBe(2); expect(r[0].title).toBeTruthy(); - expect(r[1].title).toBeUndefined(); + expect(r[1].title).toBeNull(); // check posts are updated await expect(rawDb.post.findMany({ where: { title: 'foo' } })).resolves.toHaveLength(2); diff --git a/tests/e2e/orm/policy/migrated/view.test.ts b/tests/e2e/orm/policy/migrated/view.test.ts index 7a8afe28..6251b37f 100644 --- a/tests/e2e/orm/policy/migrated/view.test.ts +++ b/tests/e2e/orm/policy/migrated/view.test.ts @@ -36,7 +36,13 @@ describe('View Policy Test', () => { const rawDb = db.$unuseAll(); - await rawDb.$executeRaw`CREATE VIEW "UserInfo" as select "User"."id", "User"."name", "User"."email", "User"."id" as "userId", count("Post"."id") as "postCount" from "User" left join "Post" on "User"."id" = "Post"."authorId" group by "User"."id";`; + if (['postgresql', 'sqlite'].includes(rawDb.$schema.provider.type)) { + await rawDb.$executeRaw`CREATE VIEW "UserInfo" as select "User"."id", "User"."name", "User"."email", "User"."id" as "userId", count("Post"."id") as "postCount" from "User" left join "Post" on "User"."id" = "Post"."authorId" group by "User"."id";`; + } else if (rawDb.$schema.provider.type === 'mysql') { + await rawDb.$executeRaw`CREATE VIEW UserInfo as select User.id, User.name, User.email, User.id as userId, count(Post.id) as postCount from User left join Post on User.id = Post.authorId group by User.id;`; + } else { + throw new Error(`Unsupported provider: ${rawDb.$schema.provider.type}`); + } await rawDb.user.create({ data: { diff --git a/tests/e2e/orm/policy/nonexistent-models.test.ts b/tests/e2e/orm/policy/nonexistent-models.test.ts index 70fd0ecb..6cde1054 100644 --- a/tests/e2e/orm/policy/nonexistent-models.test.ts +++ b/tests/e2e/orm/policy/nonexistent-models.test.ts @@ -15,9 +15,15 @@ describe('Policy tests for nonexistent models and fields', () => { const dbRaw = db.$unuseAll(); // create a Bar table - await dbRaw.$executeRawUnsafe( - `CREATE TABLE "Bar" ("id" TEXT PRIMARY KEY, "string" TEXT, "fooId" TEXT, FOREIGN KEY ("fooId") REFERENCES "Foo" ("id"));`, - ); + if (['postgresql', 'sqlite'].includes(dbRaw.$schema.provider.type)) { + await dbRaw.$executeRawUnsafe( + `CREATE TABLE "Bar" ("id" TEXT PRIMARY KEY, "string" TEXT, "fooId" TEXT, FOREIGN KEY ("fooId") REFERENCES "Foo" ("id"));`, + ); + } else { + await dbRaw.$executeRawUnsafe( + `CREATE TABLE Bar (id VARCHAR(191) PRIMARY KEY, string VARCHAR(191), fooId VARCHAR(191), FOREIGN KEY (fooId) REFERENCES Foo (id));`, + ); + } await dbRaw.$qb.insertInto('Foo').values({ id: '1', string: 'test' }).execute(); await dbRaw.$qb.insertInto('Bar').values({ id: '1', string: 'test', fooId: '1' }).execute(); diff --git a/tests/e2e/orm/policy/policy-functions.test.ts b/tests/e2e/orm/policy/policy-functions.test.ts index b30ea5ce..9d12defb 100644 --- a/tests/e2e/orm/policy/policy-functions.test.ts +++ b/tests/e2e/orm/policy/policy-functions.test.ts @@ -14,8 +14,8 @@ describe('policy functions tests', () => { ); await expect(db.foo.create({ data: { string: 'bcd' } })).toBeRejectedByPolicy(); - if (db.$schema.provider.type === 'sqlite') { - // sqlite is always case-insensitive + if (['sqlite', 'mysql'].includes(db.$schema.provider.type)) { + // sqlite and mysql are always case-insensitive await expect(db.foo.create({ data: { string: 'Acd' } })).toResolveTruthy(); } else { await expect(db.foo.create({ data: { string: 'Acd' } })).toBeRejectedByPolicy(); @@ -51,8 +51,8 @@ describe('policy functions tests', () => { ); await expect(db.foo.create({ data: { string: 'bcd' } })).toBeRejectedByPolicy(); - if (db.$schema.provider.type === 'sqlite') { - // sqlite is always case-insensitive + if (['sqlite', 'mysql'].includes(db.$schema.provider.type)) { + // sqlite and mysql are always case-insensitive await expect(db.foo.create({ data: { string: 'Acd' } })).toResolveTruthy(); } else { await expect(db.foo.create({ data: { string: 'Acd' } })).toBeRejectedByPolicy(); @@ -93,8 +93,8 @@ describe('policy functions tests', () => { await expect(db.foo.create({ data: {} })).toBeRejectedByPolicy(); await expect(db.$setAuth({ id: 'user1', name: 'bcd' }).foo.create({ data: {} })).toBeRejectedByPolicy(); await expect(db.$setAuth({ id: 'user1', name: 'bac' }).foo.create({ data: {} })).toResolveTruthy(); - if (db.$schema.provider.type === 'sqlite') { - // sqlite is always case-insensitive + if (['sqlite', 'mysql'].includes(db.$schema.provider.type)) { + // sqlite and mysql are always case-insensitive await expect(db.$setAuth({ id: 'user1', name: 'Abc' }).foo.create({ data: {} })).toResolveTruthy(); } else { await expect(db.$setAuth({ id: 'user1', name: 'Abc' }).foo.create({ data: {} })).toBeRejectedByPolicy(); diff --git a/tests/e2e/orm/query-builder/query-builder.test.ts b/tests/e2e/orm/query-builder/query-builder.test.ts index 563118a4..ea4aded2 100644 --- a/tests/e2e/orm/query-builder/query-builder.test.ts +++ b/tests/e2e/orm/query-builder/query-builder.test.ts @@ -17,7 +17,7 @@ describe('Client API tests', () => { .values({ id: uid, email: 'a@b.com', - updatedAt: new Date().toISOString(), + updatedAt: new Date().toISOString().replace('Z', '+00:00'), }) .execute(); @@ -31,7 +31,7 @@ describe('Client API tests', () => { authorId: uid, title: 'Post1', content: 'My post', - updatedAt: new Date().toISOString(), + updatedAt: new Date().toISOString().replace('Z', '+00:00'), }) .execute(); diff --git a/tests/e2e/orm/schemas/basic/input.ts b/tests/e2e/orm/schemas/basic/input.ts index 4bbec22c..90babcce 100644 --- a/tests/e2e/orm/schemas/basic/input.ts +++ b/tests/e2e/orm/schemas/basic/input.ts @@ -92,3 +92,24 @@ export type ProfileSelect = $SelectInput<$Schema, "Profile">; export type ProfileInclude = $IncludeInput<$Schema, "Profile">; export type ProfileOmit = $OmitInput<$Schema, "Profile">; export type ProfileGetPayload, Options extends $QueryOptions<$Schema> = $QueryOptions<$Schema>> = $Result<$Schema, "Profile", Args, Options>; +export type PlainFindManyArgs = $FindManyArgs<$Schema, "Plain">; +export type PlainFindUniqueArgs = $FindUniqueArgs<$Schema, "Plain">; +export type PlainFindFirstArgs = $FindFirstArgs<$Schema, "Plain">; +export type PlainExistsArgs = $ExistsArgs<$Schema, "Plain">; +export type PlainCreateArgs = $CreateArgs<$Schema, "Plain">; +export type PlainCreateManyArgs = $CreateManyArgs<$Schema, "Plain">; +export type PlainCreateManyAndReturnArgs = $CreateManyAndReturnArgs<$Schema, "Plain">; +export type PlainUpdateArgs = $UpdateArgs<$Schema, "Plain">; +export type PlainUpdateManyArgs = $UpdateManyArgs<$Schema, "Plain">; +export type PlainUpdateManyAndReturnArgs = $UpdateManyAndReturnArgs<$Schema, "Plain">; +export type PlainUpsertArgs = $UpsertArgs<$Schema, "Plain">; +export type PlainDeleteArgs = $DeleteArgs<$Schema, "Plain">; +export type PlainDeleteManyArgs = $DeleteManyArgs<$Schema, "Plain">; +export type PlainCountArgs = $CountArgs<$Schema, "Plain">; +export type PlainAggregateArgs = $AggregateArgs<$Schema, "Plain">; +export type PlainGroupByArgs = $GroupByArgs<$Schema, "Plain">; +export type PlainWhereInput = $WhereInput<$Schema, "Plain">; +export type PlainSelect = $SelectInput<$Schema, "Plain">; +export type PlainInclude = $IncludeInput<$Schema, "Plain">; +export type PlainOmit = $OmitInput<$Schema, "Plain">; +export type PlainGetPayload, Options extends $QueryOptions<$Schema> = $QueryOptions<$Schema>> = $Result<$Schema, "Plain", Args, Options>; diff --git a/tests/e2e/orm/schemas/basic/models.ts b/tests/e2e/orm/schemas/basic/models.ts index be197879..733e7df6 100644 --- a/tests/e2e/orm/schemas/basic/models.ts +++ b/tests/e2e/orm/schemas/basic/models.ts @@ -11,6 +11,7 @@ export type User = $ModelResult<$Schema, "User">; export type Post = $ModelResult<$Schema, "Post">; export type Comment = $ModelResult<$Schema, "Comment">; export type Profile = $ModelResult<$Schema, "Profile">; +export type Plain = $ModelResult<$Schema, "Plain">; export type CommonFields = $TypeDefResult<$Schema, "CommonFields">; export const Role = $schema.enums.Role.values; export type Role = (typeof Role)[keyof typeof Role]; diff --git a/tests/e2e/orm/schemas/basic/schema.ts b/tests/e2e/orm/schemas/basic/schema.ts index 5f067685..1e559cc7 100644 --- a/tests/e2e/orm/schemas/basic/schema.ts +++ b/tests/e2e/orm/schemas/basic/schema.ts @@ -246,6 +246,26 @@ export class SchemaType implements SchemaDef { id: { type: "String" }, userId: { type: "String" } } + }, + Plain: { + name: "Plain", + fields: { + id: { + name: "id", + type: "Int", + id: true, + attributes: [{ name: "@id" }, { name: "@default", args: [{ name: "value", value: ExpressionUtils.call("autoincrement") }] }], + default: ExpressionUtils.call("autoincrement") + }, + value: { + name: "value", + type: "Int" + } + }, + idFields: ["id"], + uniqueFields: { + id: { type: "Int" } + } } } as const; typeDefs = { diff --git a/tests/e2e/orm/schemas/basic/schema.zmodel b/tests/e2e/orm/schemas/basic/schema.zmodel index a831b827..9d3a7d91 100644 --- a/tests/e2e/orm/schemas/basic/schema.zmodel +++ b/tests/e2e/orm/schemas/basic/schema.zmodel @@ -64,3 +64,7 @@ model Foo { @@ignore } +model Plain { + id Int @id @default(autoincrement()) + value Int +} diff --git a/tests/e2e/package.json b/tests/e2e/package.json index 5022cf2e..f37d8b29 100644 --- a/tests/e2e/package.json +++ b/tests/e2e/package.json @@ -9,7 +9,8 @@ "test:typecheck": "tsc --noEmit", "test": "vitest run", "test:sqlite": "TEST_DB_PROVIDER=sqlite vitest run", - "test:postgresql": "TEST_DB_PROVIDER=postgresql vitest run" + "test:postgresql": "TEST_DB_PROVIDER=postgresql vitest run", + "test:mysql": "TEST_DB_PROVIDER=mysql vitest run" }, "dependencies": { "@paralleldrive/cuid2": "^2.2.2", @@ -26,7 +27,8 @@ "ulid": "^3.0.0", "uuid": "^11.0.5", "cuid": "^3.0.0", - "zod": "catalog:" + "zod": "catalog:", + "ts-pattern": "catalog:" }, "devDependencies": { "@zenstackhq/cli": "workspace:*", diff --git a/tests/regression/package.json b/tests/regression/package.json index 1e66dafb..d6b14dac 100644 --- a/tests/regression/package.json +++ b/tests/regression/package.json @@ -6,7 +6,10 @@ "scripts": { "build": "pnpm run test:generate", "test:generate": "tsx ../../scripts/test-generate.ts ./test", - "test": "pnpm test:generate && tsc && vitest run" + "test": "pnpm test:generate && tsc && vitest run", + "test:sqlite": "TEST_DB_PROVIDER=sqlite vitest run", + "test:postgresql": "TEST_DB_PROVIDER=postgresql vitest run", + "test:mysql": "TEST_DB_PROVIDER=mysql vitest run" }, "dependencies": { "@zenstackhq/testtools": "workspace:*", diff --git a/tests/regression/test/issue-493.test.ts b/tests/regression/test/issue-493.test.ts index 269a68f3..fb30ef3d 100644 --- a/tests/regression/test/issue-493.test.ts +++ b/tests/regression/test/issue-493.test.ts @@ -40,7 +40,7 @@ model Foo { } `; - const db = await createTestClient(schema, { provider: 'postgresql', debug: true }); + const db = await createTestClient(schema, { provider: 'postgresql' }); // plain JSON non-array await expect( diff --git a/tests/regression/test/v2-migrated/issue-1576.test.ts b/tests/regression/test/v2-migrated/issue-1576.test.ts index 91870b3e..3c62af01 100644 --- a/tests/regression/test/v2-migrated/issue-1576.test.ts +++ b/tests/regression/test/v2-migrated/issue-1576.test.ts @@ -40,24 +40,48 @@ describe('Regression for issue #1576', () => { }, }); - await expect( - db.goldItem.createManyAndReturn({ - data: [ - { - profileId: profile.id, - inventory: true, - }, - { - profileId: profile.id, - inventory: true, - }, - ], - }), - ).resolves.toEqual( - expect.arrayContaining([ - expect.objectContaining({ profileId: profile.id, type: 'GoldItem', inventory: true }), - expect.objectContaining({ profileId: profile.id, type: 'GoldItem', inventory: true }), - ]), - ); + if (db.$schema.provider.type !== 'mysql') { + await expect( + db.goldItem.createManyAndReturn({ + data: [ + { + profileId: profile.id, + inventory: true, + }, + { + profileId: profile.id, + inventory: true, + }, + ], + }), + ).resolves.toEqual( + expect.arrayContaining([ + expect.objectContaining({ profileId: profile.id, type: 'GoldItem', inventory: true }), + expect.objectContaining({ profileId: profile.id, type: 'GoldItem', inventory: true }), + ]), + ); + } else { + // mysql doesn't support createManyAndReturn + await expect( + db.goldItem.createMany({ + data: [ + { + profileId: profile.id, + inventory: true, + }, + { + profileId: profile.id, + inventory: true, + }, + ], + }), + ).toResolveTruthy(); + await expect(db.goldItem.findMany()).resolves.toEqual( + expect.arrayContaining([ + expect.objectContaining({ profileId: profile.id, type: 'GoldItem', inventory: true }), + expect.objectContaining({ profileId: profile.id, type: 'GoldItem', inventory: true }), + ]), + ); + } }); }); diff --git a/tests/regression/test/v2-migrated/issue-1681.test.ts b/tests/regression/test/v2-migrated/issue-1681.test.ts index 0483a1c3..ba867dbf 100644 --- a/tests/regression/test/v2-migrated/issue-1681.test.ts +++ b/tests/regression/test/v2-migrated/issue-1681.test.ts @@ -25,7 +25,9 @@ describe('Regression for issue #1681', () => { const user = await db.user.create({ data: {} }); await expect(authDb.post.createMany({ data: [{ title: 'Post1' }] })).resolves.toMatchObject({ count: 1 }); - const r = await authDb.post.createManyAndReturn({ data: [{ title: 'Post2' }] }); - expect(r[0].authorId).toBe(user.id); + if (db.$schema.provider.type !== 'mysql') { + const r = await authDb.post.createManyAndReturn({ data: [{ title: 'Post2' }] }); + expect(r[0].authorId).toBe(user.id); + } }); }); diff --git a/tests/regression/test/v2-migrated/issue-1894.test.ts b/tests/regression/test/v2-migrated/issue-1894.test.ts index 8d745851..76d3d61f 100644 --- a/tests/regression/test/v2-migrated/issue-1894.test.ts +++ b/tests/regression/test/v2-migrated/issue-1894.test.ts @@ -42,7 +42,7 @@ describe('Regression for issue #1894', () => { }, ); - await db.a.create({ data: { id: 0 } }); - await expect(db.c.create({ data: { a: { connect: { id: 0 } } } })).toResolveTruthy(); + const r = await db.a.create({ data: { id: 0 } }); + await expect(db.c.create({ data: { a: { connect: { id: r.id } } } })).toResolveTruthy(); }); });