Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/orm/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
"json-stable-stringify": "^1.3.0",
"kysely": "catalog:",
"nanoid": "^5.0.9",
"postgres-array": "^3.0.4",
"toposort": "^2.0.2",
"ts-pattern": "catalog:",
"ulid": "^3.0.0",
Expand Down
2 changes: 1 addition & 1 deletion packages/orm/src/client/crud/dialects/base-dialect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
return value;
}

transformOutput(value: unknown, _type: BuiltinType) {
transformOutput(value: unknown, _type: BuiltinType, _array: boolean) {
return value;
}

Expand Down
61 changes: 59 additions & 2 deletions packages/orm/src/client/crud/dialects/postgresql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
type SelectQueryBuilder,
type SqlBool,
} from 'kysely';
import { parse as parsePostgresArray } from 'postgres-array';
import { match } from 'ts-pattern';
import z from 'zod';
import { AnyNullClass, DbNullClass, JsonNullClass } from '../../../common-types';
Expand All @@ -20,14 +21,17 @@ import type { ClientOptions } from '../../options';
import {
buildJoinPairs,
getDelegateDescendantModels,
getEnum,
getManyToManyRelation,
isEnum,
isRelationField,
isTypeDef,
requireField,
requireIdFields,
requireModel,
} from '../../query-utils';
import { BaseCrudDialect } from './base-dialect';

export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect<Schema> {
private isoDateSchema = z.iso.datetime({ local: true, offset: true });

Expand Down Expand Up @@ -70,6 +74,16 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
if (type === 'Json' && !forArrayField) {
// scalar `Json` fields need their input stringified
return JSON.stringify(value);
}
if (isEnum(this.schema, type)) {
// cast to enum array `CAST(ARRAY[...] AS "enum_type"[])`
return this.eb.cast(
sql`ARRAY[${sql.join(
value.map((v) => this.transformPrimitive(v, type, false)),
sql.raw(','),
)}]`,
this.createSchemaQualifiedEnumType(type, true),
);
} else {
// `Json[]` fields need their input as array (not stringified)
return value.map((v) => this.transformPrimitive(v, type, false));
Expand All @@ -96,7 +110,33 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
}
}

override transformOutput(value: unknown, type: BuiltinType) {
private createSchemaQualifiedEnumType(type: string, array: boolean) {
// determines the postgres schema name for the enum type, and returns the
// qualified name

let qualified = type;

const enumDef = getEnum(this.schema, type);
if (enumDef) {
// check if the enum has a custom "@@schema" attribute
const schemaAttr = enumDef.attributes?.find((attr) => attr.name === '@@schema');
if (schemaAttr) {
const mapArg = schemaAttr.args?.find((arg) => arg.name === 'map');
if (mapArg && mapArg.value.kind === 'literal') {
const schemaName = mapArg.value.value as string;
qualified = `"${schemaName}"."${type}"`;
}
} else {
// no custom schema, use default from datasource or 'public'
const defaultSchema = this.schema.provider.defaultSchema ?? 'public';
qualified = `"${defaultSchema}"."${type}"`;
}
}

return array ? sql.raw(`${qualified}[]`) : sql.raw(qualified);
}

override transformOutput(value: unknown, type: BuiltinType, array: boolean) {
if (value === null || value === undefined) {
return value;
}
Expand All @@ -105,7 +145,11 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
.with('Bytes', () => this.transformOutputBytes(value))
.with('BigInt', () => this.transformOutputBigInt(value))
.with('Decimal', () => this.transformDecimal(value))
.otherwise(() => super.transformOutput(value, type));
.when(
(type) => isEnum(this.schema, type),
() => this.transformOutputEnum(value, array),
)
.otherwise(() => super.transformOutput(value, type, array));
}

private transformOutputBigInt(value: unknown) {
Expand Down Expand Up @@ -162,6 +206,19 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
: value;
}

private transformOutputEnum(value: unknown, array: boolean) {
if (array && typeof value === 'string') {
try {
// postgres returns enum arrays as `{"val 1",val2}` strings, parse them back
// to string arrays here
return parsePostgresArray(value);
} catch {
// fall through - return as-is if parsing fails
}
}
return value;
}

override buildRelationSelection(
query: SelectQueryBuilder<any, any, any>,
model: string,
Expand Down
4 changes: 2 additions & 2 deletions packages/orm/src/client/crud/dialects/sqlite.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ export class SqliteCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect
}
}

override transformOutput(value: unknown, type: BuiltinType) {
override transformOutput(value: unknown, type: BuiltinType, array: boolean) {
if (value === null || value === undefined) {
return value;
} else if (this.schema.typeDefs && type in this.schema.typeDefs) {
Expand All @@ -81,7 +81,7 @@ export class SqliteCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect
.with('Decimal', () => this.transformOutputDecimal(value))
.with('BigInt', () => this.transformOutputBigInt(value))
.with('Json', () => this.transformOutputJson(value))
.otherwise(() => super.transformOutput(value, type));
.otherwise(() => super.transformOutput(value, type, array));
}
}

Expand Down
79 changes: 47 additions & 32 deletions packages/orm/src/client/crud/validator/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ import {
type BuiltinType,
type EnumDef,
type FieldDef,
type ProcedureDef,
type GetModels,
type ModelDef,
type ProcedureDef,
type SchemaDef,
} from '../../../schema';
import { extractFields } from '../../../utils/object-utils';
Expand Down Expand Up @@ -199,10 +199,7 @@ export class InputValidator<Schema extends SchemaDef> {
>(model, 'find', options, (model, options) => this.makeFindSchema(model, options), args);
}

validateExistsArgs(
model: GetModels<Schema>,
args: unknown,
): ExistsArgs<Schema, GetModels<Schema>> | undefined {
validateExistsArgs(model: GetModels<Schema>, args: unknown): ExistsArgs<Schema, GetModels<Schema>> | undefined {
return this.validate<ExistsArgs<Schema, GetModels<Schema>>>(
model,
'exists',
Expand Down Expand Up @@ -429,9 +426,11 @@ export class InputValidator<Schema extends SchemaDef> {
}

private makeExistsSchema(model: string) {
return z.strictObject({
where: this.makeWhereSchema(model, false).optional(),
}).optional();
return z
.strictObject({
where: this.makeWhereSchema(model, false).optional(),
})
.optional();
}

private makeScalarSchema(type: string, attributes?: readonly AttributeApplication[]) {
Expand Down Expand Up @@ -577,7 +576,12 @@ export class InputValidator<Schema extends SchemaDef> {
if (enumDef) {
// enum
if (Object.keys(enumDef.values).length > 0) {
fieldSchema = this.makeEnumFilterSchema(enumDef, !!fieldDef.optional, withAggregations);
fieldSchema = this.makeEnumFilterSchema(
enumDef,
!!fieldDef.optional,
withAggregations,
!!fieldDef.array,
);
}
} else if (fieldDef.array) {
// array field
Expand Down Expand Up @@ -614,7 +618,12 @@ export class InputValidator<Schema extends SchemaDef> {
if (enumDef) {
// enum
if (Object.keys(enumDef.values).length > 0) {
fieldSchema = this.makeEnumFilterSchema(enumDef, !!def.optional, false);
fieldSchema = this.makeEnumFilterSchema(
enumDef,
!!def.optional,
false,
false,
);
} else {
fieldSchema = z.never();
}
Expand Down Expand Up @@ -696,24 +705,23 @@ export class InputValidator<Schema extends SchemaDef> {
!!fieldDef.array,
).optional();
} else {
// array, enum, primitives
if (fieldDef.array) {
// enum, array, primitives
const enumDef = getEnum(this.schema, fieldDef.type);
if (enumDef) {
fieldSchemas[fieldName] = this.makeEnumFilterSchema(
enumDef,
!!fieldDef.optional,
false,
!!fieldDef.array,
).optional();
} else if (fieldDef.array) {
fieldSchemas[fieldName] = this.makeArrayFilterSchema(fieldDef.type as BuiltinType).optional();
} else {
const enumDef = getEnum(this.schema, fieldDef.type);
if (enumDef) {
fieldSchemas[fieldName] = this.makeEnumFilterSchema(
enumDef,
!!fieldDef.optional,
false,
).optional();
} else {
fieldSchemas[fieldName] = this.makePrimitiveFilterSchema(
fieldDef.type as BuiltinType,
!!fieldDef.optional,
false,
).optional();
}
fieldSchemas[fieldName] = this.makePrimitiveFilterSchema(
fieldDef.type as BuiltinType,
!!fieldDef.optional,
false,
).optional();
}
}
}
Expand Down Expand Up @@ -757,24 +765,31 @@ export class InputValidator<Schema extends SchemaDef> {
return this.schema.typeDefs && type in this.schema.typeDefs;
}

private makeEnumFilterSchema(enumDef: EnumDef, optional: boolean, withAggregations: boolean) {
private makeEnumFilterSchema(enumDef: EnumDef, optional: boolean, withAggregations: boolean, array: boolean) {
const baseSchema = z.enum(Object.keys(enumDef.values) as [string, ...string[]]);
if (array) {
return this.internalMakeArrayFilterSchema(baseSchema);
}
const components = this.makeCommonPrimitiveFilterComponents(
baseSchema,
optional,
() => z.lazy(() => this.makeEnumFilterSchema(enumDef, optional, withAggregations)),
() => z.lazy(() => this.makeEnumFilterSchema(enumDef, optional, withAggregations, array)),
['equals', 'in', 'notIn', 'not'],
withAggregations ? ['_count', '_min', '_max'] : undefined,
);
return z.union([this.nullableIf(baseSchema, optional), z.strictObject(components)]);
}

private makeArrayFilterSchema(type: BuiltinType) {
return this.internalMakeArrayFilterSchema(this.makeScalarSchema(type));
}

private internalMakeArrayFilterSchema(elementSchema: ZodType) {
return z.strictObject({
equals: this.makeScalarSchema(type).array().optional(),
has: this.makeScalarSchema(type).optional(),
hasEvery: this.makeScalarSchema(type).array().optional(),
hasSome: this.makeScalarSchema(type).array().optional(),
equals: elementSchema.array().optional(),
has: elementSchema.optional(),
hasEvery: elementSchema.array().optional(),
hasSome: elementSchema.array().optional(),
isEmpty: z.boolean().optional(),
});
}
Expand Down
6 changes: 3 additions & 3 deletions packages/orm/src/client/executor/name-mapper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -530,9 +530,9 @@ export class QueryNameMapper extends OperationNodeTransformer {
let schema = this.schema.provider.defaultSchema ?? 'public';
const schemaAttr = this.schema.models[model]?.attributes?.find((attr) => attr.name === '@@schema');
if (schemaAttr) {
const nameArg = schemaAttr.args?.find((arg) => arg.name === 'map');
if (nameArg && nameArg.value.kind === 'literal') {
schema = nameArg.value.value as string;
const mapArg = schemaAttr.args?.find((arg) => arg.name === 'map');
if (mapArg && mapArg.value.kind === 'literal') {
schema = mapArg.value.value as string;
}
}
return schema;
Expand Down
6 changes: 3 additions & 3 deletions packages/orm/src/client/result-processor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ export class ResultProcessor<Schema extends SchemaDef> {
// merge delegate descendant fields
if (value) {
// descendant fields are packed as JSON
const subRow = this.dialect.transformOutput(value, 'Json');
const subRow = this.dialect.transformOutput(value, 'Json', false);

// process the sub-row
const subModel = key.slice(DELEGATE_JOINED_FIELD_PREFIX.length) as GetModels<Schema>;
Expand Down Expand Up @@ -93,10 +93,10 @@ export class ResultProcessor<Schema extends SchemaDef> {
private processFieldValue(value: unknown, fieldDef: FieldDef) {
const type = fieldDef.type as BuiltinType;
if (Array.isArray(value)) {
value.forEach((v, i) => (value[i] = this.dialect.transformOutput(v, type)));
value.forEach((v, i) => (value[i] = this.dialect.transformOutput(v, type, false)));
return value;
} else {
return this.dialect.transformOutput(value, type);
return this.dialect.transformOutput(value, type, !!fieldDef.array);
}
}

Expand Down
2 changes: 1 addition & 1 deletion packages/testtools/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ export async function createTestClient(
fs.writeFileSync(path.resolve(workDir!, 'schema.prisma'), prismaSchemaText);
execSync('npx prisma db push --schema ./schema.prisma --skip-generate --force-reset', {
cwd: workDir,
stdio: 'ignore',
stdio: options.debug ? 'inherit' : 'ignore',
});
} else {
if (provider === 'postgresql') {
Expand Down
Loading