diff --git a/packages/orm/package.json b/packages/orm/package.json index 9fd31800f..173c16fd1 100644 --- a/packages/orm/package.json +++ b/packages/orm/package.json @@ -90,6 +90,7 @@ "json-stable-stringify": "^1.3.0", "kysely": "catalog:", "nanoid": "^5.0.9", + "postgres-array": "^3.0.4", "toposort": "^2.0.2", "ts-pattern": "catalog:", "ulid": "^3.0.0", diff --git a/packages/orm/src/client/crud/dialects/base-dialect.ts b/packages/orm/src/client/crud/dialects/base-dialect.ts index 30e330a40..26da88769 100644 --- a/packages/orm/src/client/crud/dialects/base-dialect.ts +++ b/packages/orm/src/client/crud/dialects/base-dialect.ts @@ -48,7 +48,7 @@ export abstract class BaseCrudDialect { return value; } - transformOutput(value: unknown, _type: BuiltinType) { + transformOutput(value: unknown, _type: BuiltinType, _array: boolean) { return value; } diff --git a/packages/orm/src/client/crud/dialects/postgresql.ts b/packages/orm/src/client/crud/dialects/postgresql.ts index 92e570fe8..bae049077 100644 --- a/packages/orm/src/client/crud/dialects/postgresql.ts +++ b/packages/orm/src/client/crud/dialects/postgresql.ts @@ -9,6 +9,7 @@ import { type SelectQueryBuilder, type SqlBool, } from 'kysely'; +import { parse as parsePostgresArray } from 'postgres-array'; import { match } from 'ts-pattern'; import z from 'zod'; import { AnyNullClass, DbNullClass, JsonNullClass } from '../../../common-types'; @@ -20,7 +21,9 @@ import type { ClientOptions } from '../../options'; import { buildJoinPairs, getDelegateDescendantModels, + getEnum, getManyToManyRelation, + isEnum, isRelationField, isTypeDef, requireField, @@ -28,6 +31,7 @@ import { requireModel, } from '../../query-utils'; import { BaseCrudDialect } from './base-dialect'; + export class PostgresCrudDialect extends BaseCrudDialect { private isoDateSchema = z.iso.datetime({ local: true, offset: true }); @@ -70,6 +74,16 @@ export class PostgresCrudDialect extends BaseCrudDiale if (type === 'Json' && !forArrayField) { // scalar `Json` fields need their input stringified return JSON.stringify(value); + } + if (isEnum(this.schema, type)) { + // cast to enum array `CAST(ARRAY[...] AS "enum_type"[])` + return this.eb.cast( + sql`ARRAY[${sql.join( + value.map((v) => this.transformPrimitive(v, type, false)), + sql.raw(','), + )}]`, + this.createSchemaQualifiedEnumType(type, true), + ); } else { // `Json[]` fields need their input as array (not stringified) return value.map((v) => this.transformPrimitive(v, type, false)); @@ -96,7 +110,33 @@ export class PostgresCrudDialect extends BaseCrudDiale } } - override transformOutput(value: unknown, type: BuiltinType) { + private createSchemaQualifiedEnumType(type: string, array: boolean) { + // determines the postgres schema name for the enum type, and returns the + // qualified name + + let qualified = type; + + const enumDef = getEnum(this.schema, type); + if (enumDef) { + // check if the enum has a custom "@@schema" attribute + const schemaAttr = enumDef.attributes?.find((attr) => attr.name === '@@schema'); + if (schemaAttr) { + const mapArg = schemaAttr.args?.find((arg) => arg.name === 'map'); + if (mapArg && mapArg.value.kind === 'literal') { + const schemaName = mapArg.value.value as string; + qualified = `"${schemaName}"."${type}"`; + } + } else { + // no custom schema, use default from datasource or 'public' + const defaultSchema = this.schema.provider.defaultSchema ?? 'public'; + qualified = `"${defaultSchema}"."${type}"`; + } + } + + return array ? sql.raw(`${qualified}[]`) : sql.raw(qualified); + } + + override transformOutput(value: unknown, type: BuiltinType, array: boolean) { if (value === null || value === undefined) { return value; } @@ -105,7 +145,11 @@ export class PostgresCrudDialect extends BaseCrudDiale .with('Bytes', () => this.transformOutputBytes(value)) .with('BigInt', () => this.transformOutputBigInt(value)) .with('Decimal', () => this.transformDecimal(value)) - .otherwise(() => super.transformOutput(value, type)); + .when( + (type) => isEnum(this.schema, type), + () => this.transformOutputEnum(value, array), + ) + .otherwise(() => super.transformOutput(value, type, array)); } private transformOutputBigInt(value: unknown) { @@ -162,6 +206,19 @@ export class PostgresCrudDialect extends BaseCrudDiale : value; } + private transformOutputEnum(value: unknown, array: boolean) { + if (array && typeof value === 'string') { + try { + // postgres returns enum arrays as `{"val 1",val2}` strings, parse them back + // to string arrays here + return parsePostgresArray(value); + } catch { + // fall through - return as-is if parsing fails + } + } + return value; + } + override buildRelationSelection( query: SelectQueryBuilder, model: string, diff --git a/packages/orm/src/client/crud/dialects/sqlite.ts b/packages/orm/src/client/crud/dialects/sqlite.ts index 3f0ae1dc6..32bb4e4ed 100644 --- a/packages/orm/src/client/crud/dialects/sqlite.ts +++ b/packages/orm/src/client/crud/dialects/sqlite.ts @@ -67,7 +67,7 @@ export class SqliteCrudDialect extends BaseCrudDialect } } - override transformOutput(value: unknown, type: BuiltinType) { + override transformOutput(value: unknown, type: BuiltinType, array: boolean) { if (value === null || value === undefined) { return value; } else if (this.schema.typeDefs && type in this.schema.typeDefs) { @@ -81,7 +81,7 @@ export class SqliteCrudDialect extends BaseCrudDialect .with('Decimal', () => this.transformOutputDecimal(value)) .with('BigInt', () => this.transformOutputBigInt(value)) .with('Json', () => this.transformOutputJson(value)) - .otherwise(() => super.transformOutput(value, type)); + .otherwise(() => super.transformOutput(value, type, array)); } } diff --git a/packages/orm/src/client/crud/validator/index.ts b/packages/orm/src/client/crud/validator/index.ts index 76214af91..50245c605 100644 --- a/packages/orm/src/client/crud/validator/index.ts +++ b/packages/orm/src/client/crud/validator/index.ts @@ -9,9 +9,9 @@ import { type BuiltinType, type EnumDef, type FieldDef, - type ProcedureDef, type GetModels, type ModelDef, + type ProcedureDef, type SchemaDef, } from '../../../schema'; import { extractFields } from '../../../utils/object-utils'; @@ -199,10 +199,7 @@ export class InputValidator { >(model, 'find', options, (model, options) => this.makeFindSchema(model, options), args); } - validateExistsArgs( - model: GetModels, - args: unknown, - ): ExistsArgs> | undefined { + validateExistsArgs(model: GetModels, args: unknown): ExistsArgs> | undefined { return this.validate>>( model, 'exists', @@ -429,9 +426,11 @@ export class InputValidator { } private makeExistsSchema(model: string) { - return z.strictObject({ - where: this.makeWhereSchema(model, false).optional(), - }).optional(); + return z + .strictObject({ + where: this.makeWhereSchema(model, false).optional(), + }) + .optional(); } private makeScalarSchema(type: string, attributes?: readonly AttributeApplication[]) { @@ -577,7 +576,12 @@ export class InputValidator { if (enumDef) { // enum if (Object.keys(enumDef.values).length > 0) { - fieldSchema = this.makeEnumFilterSchema(enumDef, !!fieldDef.optional, withAggregations); + fieldSchema = this.makeEnumFilterSchema( + enumDef, + !!fieldDef.optional, + withAggregations, + !!fieldDef.array, + ); } } else if (fieldDef.array) { // array field @@ -614,7 +618,12 @@ export class InputValidator { if (enumDef) { // enum if (Object.keys(enumDef.values).length > 0) { - fieldSchema = this.makeEnumFilterSchema(enumDef, !!def.optional, false); + fieldSchema = this.makeEnumFilterSchema( + enumDef, + !!def.optional, + false, + false, + ); } else { fieldSchema = z.never(); } @@ -696,24 +705,23 @@ export class InputValidator { !!fieldDef.array, ).optional(); } else { - // array, enum, primitives - if (fieldDef.array) { + // enum, array, primitives + const enumDef = getEnum(this.schema, fieldDef.type); + if (enumDef) { + fieldSchemas[fieldName] = this.makeEnumFilterSchema( + enumDef, + !!fieldDef.optional, + false, + !!fieldDef.array, + ).optional(); + } else if (fieldDef.array) { fieldSchemas[fieldName] = this.makeArrayFilterSchema(fieldDef.type as BuiltinType).optional(); } else { - const enumDef = getEnum(this.schema, fieldDef.type); - if (enumDef) { - fieldSchemas[fieldName] = this.makeEnumFilterSchema( - enumDef, - !!fieldDef.optional, - false, - ).optional(); - } else { - fieldSchemas[fieldName] = this.makePrimitiveFilterSchema( - fieldDef.type as BuiltinType, - !!fieldDef.optional, - false, - ).optional(); - } + fieldSchemas[fieldName] = this.makePrimitiveFilterSchema( + fieldDef.type as BuiltinType, + !!fieldDef.optional, + false, + ).optional(); } } } @@ -757,12 +765,15 @@ export class InputValidator { return this.schema.typeDefs && type in this.schema.typeDefs; } - private makeEnumFilterSchema(enumDef: EnumDef, optional: boolean, withAggregations: boolean) { + private makeEnumFilterSchema(enumDef: EnumDef, optional: boolean, withAggregations: boolean, array: boolean) { const baseSchema = z.enum(Object.keys(enumDef.values) as [string, ...string[]]); + if (array) { + return this.internalMakeArrayFilterSchema(baseSchema); + } const components = this.makeCommonPrimitiveFilterComponents( baseSchema, optional, - () => z.lazy(() => this.makeEnumFilterSchema(enumDef, optional, withAggregations)), + () => z.lazy(() => this.makeEnumFilterSchema(enumDef, optional, withAggregations, array)), ['equals', 'in', 'notIn', 'not'], withAggregations ? ['_count', '_min', '_max'] : undefined, ); @@ -770,11 +781,15 @@ export class InputValidator { } private makeArrayFilterSchema(type: BuiltinType) { + return this.internalMakeArrayFilterSchema(this.makeScalarSchema(type)); + } + + private internalMakeArrayFilterSchema(elementSchema: ZodType) { return z.strictObject({ - equals: this.makeScalarSchema(type).array().optional(), - has: this.makeScalarSchema(type).optional(), - hasEvery: this.makeScalarSchema(type).array().optional(), - hasSome: this.makeScalarSchema(type).array().optional(), + equals: elementSchema.array().optional(), + has: elementSchema.optional(), + hasEvery: elementSchema.array().optional(), + hasSome: elementSchema.array().optional(), isEmpty: z.boolean().optional(), }); } diff --git a/packages/orm/src/client/executor/name-mapper.ts b/packages/orm/src/client/executor/name-mapper.ts index c93ec5936..379188756 100644 --- a/packages/orm/src/client/executor/name-mapper.ts +++ b/packages/orm/src/client/executor/name-mapper.ts @@ -530,9 +530,9 @@ export class QueryNameMapper extends OperationNodeTransformer { let schema = this.schema.provider.defaultSchema ?? 'public'; const schemaAttr = this.schema.models[model]?.attributes?.find((attr) => attr.name === '@@schema'); if (schemaAttr) { - const nameArg = schemaAttr.args?.find((arg) => arg.name === 'map'); - if (nameArg && nameArg.value.kind === 'literal') { - schema = nameArg.value.value as string; + const mapArg = schemaAttr.args?.find((arg) => arg.name === 'map'); + if (mapArg && mapArg.value.kind === 'literal') { + schema = mapArg.value.value as string; } } return schema; diff --git a/packages/orm/src/client/result-processor.ts b/packages/orm/src/client/result-processor.ts index a7870babc..fc8ae1938 100644 --- a/packages/orm/src/client/result-processor.ts +++ b/packages/orm/src/client/result-processor.ts @@ -49,7 +49,7 @@ export class ResultProcessor { // merge delegate descendant fields if (value) { // descendant fields are packed as JSON - const subRow = this.dialect.transformOutput(value, 'Json'); + const subRow = this.dialect.transformOutput(value, 'Json', false); // process the sub-row const subModel = key.slice(DELEGATE_JOINED_FIELD_PREFIX.length) as GetModels; @@ -93,10 +93,10 @@ export class ResultProcessor { private processFieldValue(value: unknown, fieldDef: FieldDef) { const type = fieldDef.type as BuiltinType; if (Array.isArray(value)) { - value.forEach((v, i) => (value[i] = this.dialect.transformOutput(v, type))); + value.forEach((v, i) => (value[i] = this.dialect.transformOutput(v, type, false))); return value; } else { - return this.dialect.transformOutput(value, type); + return this.dialect.transformOutput(value, type, !!fieldDef.array); } } diff --git a/packages/testtools/src/client.ts b/packages/testtools/src/client.ts index cd6be7472..696596618 100644 --- a/packages/testtools/src/client.ts +++ b/packages/testtools/src/client.ts @@ -205,7 +205,7 @@ export async function createTestClient( fs.writeFileSync(path.resolve(workDir!, 'schema.prisma'), prismaSchemaText); execSync('npx prisma db push --schema ./schema.prisma --skip-generate --force-reset', { cwd: workDir, - stdio: 'ignore', + stdio: options.debug ? 'inherit' : 'ignore', }); } else { if (provider === 'postgresql') { diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index df889c6bd..8b686eb02 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -69,9 +69,6 @@ catalogs: react-dom: specifier: 19.2.0 version: 19.2.0 - sql.js: - specifier: ^1.13.0 - version: 1.13.0 svelte: specifier: 5.45.6 version: 5.45.6 @@ -498,6 +495,9 @@ importers: pg: specifier: 'catalog:' version: 8.16.3 + postgres-array: + specifier: ^3.0.4 + version: 3.0.4 sql.js: specifier: 'catalog:' version: 1.13.0 @@ -6889,8 +6889,8 @@ packages: resolution: {integrity: sha512-VpZrUqU5A69eQyW2c5CA1jtLecCsN2U/bD6VilrFDWq5+5UIEVO7nazS3TEcHf1zuPYO/sqGvUvW62g86RXZuA==} engines: {node: '>=4'} - postgres-array@3.0.2: - resolution: {integrity: sha512-6faShkdFugNQCLwucjPcY5ARoW1SlbnrZjmGl0IrrqewpvxvhSLHimCVzqeuULCbG0fQv7Dtk1yDbG3xv7Veog==} + postgres-array@3.0.4: + resolution: {integrity: sha512-nAUSGfSDGOaOAEGwqsRY27GPOea7CNipJPOA7lPbdEpx5Kg3qzdP0AaWC5MlhTWV9s4hFX39nomVZ+C4tnGOJQ==} engines: {node: '>=12'} postgres-bytea@1.0.0: @@ -12438,7 +12438,7 @@ snapshots: eslint: 9.29.0(jiti@2.6.1) eslint-import-resolver-node: 0.3.9 eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint@9.29.0(jiti@2.6.1)))(eslint@9.29.0(jiti@2.6.1)) - eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint@9.29.0(jiti@2.6.1)))(eslint@9.29.0(jiti@2.6.1)))(eslint@9.29.0(jiti@2.6.1)) + eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@9.29.0(jiti@2.6.1)) eslint-plugin-jsx-a11y: 6.10.2(eslint@9.29.0(jiti@2.6.1)) eslint-plugin-react: 7.37.5(eslint@9.29.0(jiti@2.6.1)) eslint-plugin-react-hooks: 7.0.1(eslint@9.29.0(jiti@2.6.1)) @@ -12471,7 +12471,7 @@ snapshots: tinyglobby: 0.2.15 unrs-resolver: 1.11.1 optionalDependencies: - eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint@9.29.0(jiti@2.6.1)))(eslint@9.29.0(jiti@2.6.1)))(eslint@9.29.0(jiti@2.6.1)) + eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@9.29.0(jiti@2.6.1)) transitivePeerDependencies: - supports-color @@ -12486,7 +12486,7 @@ snapshots: transitivePeerDependencies: - supports-color - eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint@9.29.0(jiti@2.6.1)))(eslint@9.29.0(jiti@2.6.1)))(eslint@9.29.0(jiti@2.6.1)): + eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@9.29.0(jiti@2.6.1)): dependencies: '@rtsao/scc': 1.1.0 array-includes: 3.1.9 @@ -14510,7 +14510,7 @@ snapshots: dependencies: pg-int8: 1.0.1 pg-numeric: 1.0.2 - postgres-array: 3.0.2 + postgres-array: 3.0.4 postgres-bytea: 3.0.0 postgres-date: 2.1.0 postgres-interval: 3.0.0 @@ -14759,7 +14759,7 @@ snapshots: postgres-array@2.0.0: {} - postgres-array@3.0.2: {} + postgres-array@3.0.4: {} postgres-bytea@1.0.0: {} diff --git a/tests/e2e/orm/client-api/pg-custom-schema.test.ts b/tests/e2e/orm/client-api/pg-custom-schema.test.ts index dfe6fe895..6fae284ba 100644 --- a/tests/e2e/orm/client-api/pg-custom-schema.test.ts +++ b/tests/e2e/orm/client-api/pg-custom-schema.test.ts @@ -159,7 +159,7 @@ model Foo { enum BarRole { ADMIN USER - @@schema('public') + @@schema('mySchema') } model Bar { diff --git a/tests/regression/test/issue-576.test.ts b/tests/regression/test/issue-576.test.ts new file mode 100644 index 000000000..3997cf007 --- /dev/null +++ b/tests/regression/test/issue-576.test.ts @@ -0,0 +1,200 @@ +import { createTestClient } from '@zenstackhq/testtools'; +import { describe, expect, it } from 'vitest'; + +describe('regression test for issue 576', async () => { + it('should support enum array fields', async () => { + const db = await createTestClient( + ` +enum Tag { + TAG1 + TAG2 + TAG3 +} + +model Foo { + id Int @id + tags Tag[] + bar Bar? +} + +model Bar { + id Int @id + fooId Int @unique + foo Foo @relation(fields: [fooId], references: [id]) +} +`, + { provider: 'postgresql', usePrismaPush: true }, + ); + + await expect( + db.foo.create({ + data: { + id: 1, + tags: ['TAG1', 'TAG2'], + }, + }), + ).resolves.toMatchObject({ id: 1, tags: ['TAG1', 'TAG2'] }); + await expect( + db.foo.update({ + where: { id: 1 }, + data: { + tags: { set: ['TAG2', 'TAG3'] }, + }, + }), + ).resolves.toMatchObject({ id: 1, tags: ['TAG2', 'TAG3'] }); + + await expect(db.foo.findFirst()).resolves.toMatchObject({ tags: ['TAG2', 'TAG3'] }); + await expect(db.foo.findFirst({ where: { tags: { equals: ['TAG2', 'TAG3'] } } })).resolves.toMatchObject({ + tags: ['TAG2', 'TAG3'], + }); + await expect(db.foo.findFirst({ where: { tags: { has: 'TAG1' } } })).toResolveNull(); + + // nested create + await expect( + db.bar.create({ + data: { id: 1, foo: { create: { id: 2, tags: ['TAG1'] } } }, + include: { foo: true }, + }), + ).resolves.toMatchObject({ foo: expect.objectContaining({ tags: ['TAG1'] }) }); + + // nested find + await expect( + db.bar.findFirst({ + where: { foo: { tags: { has: 'TAG1' } } }, + include: { foo: true }, + }), + ).resolves.toMatchObject({ foo: expect.objectContaining({ tags: ['TAG1'] }) }); + + await expect( + db.bar.findFirst({ + where: { foo: { tags: { equals: ['TAG2'] } } }, + }), + ).toResolveNull(); + }); + + it('should support enum array stored in JSON field', async () => { + const db = await createTestClient( + ` +enum Tag { + TAG1 + TAG2 + TAG3 +} + +model Foo { + id Int @id + tags Json +} +`, + { provider: 'postgresql', usePrismaPush: true }, + ); + + await expect( + db.foo.create({ + data: { + id: 1, + tags: ['TAG1', 'TAG2'], + }, + }), + ).resolves.toMatchObject({ id: 1, tags: ['TAG1', 'TAG2'] }); + await expect(db.foo.findFirst()).resolves.toMatchObject({ tags: ['TAG1', 'TAG2'] }); + }); + + it('should support enum array stored in JSON array field', async () => { + const db = await createTestClient( + ` +enum Tag { + TAG1 + TAG2 + TAG3 +} + +model Foo { + id Int @id + tags Json[] +} +`, + { provider: 'postgresql', usePrismaPush: true }, + ); + + await expect( + db.foo.create({ + data: { + id: 1, + tags: ['TAG1', 'TAG2'], + }, + }), + ).resolves.toMatchObject({ id: 1, tags: ['TAG1', 'TAG2'] }); + await expect(db.foo.findFirst()).resolves.toMatchObject({ tags: ['TAG1', 'TAG2'] }); + }); + + it('should support enum with datasource defined default pg schema', async () => { + const db = await createTestClient( + ` +datasource db { + provider = 'postgresql' + schemas = ['public', 'mySchema'] + url = '$DB_URL' + defaultSchema = 'mySchema' +} + +enum Tag { + TAG1 + TAG2 + TAG3 +} + +model Foo { + id Int @id + tags Tag[] +} +`, + { provider: 'postgresql', usePrismaPush: true }, + ); + + await expect( + db.foo.create({ + data: { + id: 1, + tags: ['TAG1', 'TAG2'], + }, + }), + ).resolves.toMatchObject({ id: 1, tags: ['TAG1', 'TAG2'] }); + await expect(db.foo.findFirst()).resolves.toMatchObject({ tags: ['TAG1', 'TAG2'] }); + }); + + it('should support enum with custom pg schema', async () => { + const db = await createTestClient( + ` +datasource db { + provider = 'postgresql' + schemas = ['public', 'mySchema'] + url = '$DB_URL' +} + +enum Tag { + TAG1 + TAG2 + TAG3 + @@schema('mySchema') +} + +model Foo { + id Int @id + tags Tag[] +} +`, + { provider: 'postgresql', usePrismaPush: true }, + ); + + await expect( + db.foo.create({ + data: { + id: 1, + tags: ['TAG1', 'TAG2'], + }, + }), + ).resolves.toMatchObject({ id: 1, tags: ['TAG1', 'TAG2'] }); + await expect(db.foo.findFirst()).resolves.toMatchObject({ tags: ['TAG1', 'TAG2'] }); + }); +});