diff --git a/BREAKINGCHANGES.md b/BREAKINGCHANGES.md new file mode 100644 index 000000000..afae50ec0 --- /dev/null +++ b/BREAKINGCHANGES.md @@ -0,0 +1,2 @@ +1. `auth()` cannot be directly compared with a relation anymore +2. diff --git a/TODO.md b/TODO.md index 79a220045..116d7dbd0 100644 --- a/TODO.md +++ b/TODO.md @@ -15,9 +15,10 @@ - [x] Relation connection - [x] Create many - [x] ID generation + - [x] CreateManyAndReturn - [ ] Find - [x] Input validation - - [ ] Field selection + - [x] Field selection - [x] Omit - [x] Counting relation - [x] Pagination @@ -42,6 +43,7 @@ - [x] Nested to-one - [ ] Delta update for numeric fields - [ ] Array update + - [ ] Upsert - [x] Delete - [ ] Aggregation - [x] Count @@ -52,15 +54,14 @@ - [x] Computed fields - [?] Prisma client extension - [ ] Misc - - [ ] Rename AST Model to Schema - [ ] Compound ID - [ ] Cross field comparison - [ ] Many-to-many relation - [ ] Cache validation schemas - [?] Logging - [ ] Error system - - [?] Custom table name - - [ ] Custom field name + - [x] Custom table name + - [x] Custom field name - [ ] Access Policy - [ ] Polymorphism - [x] Migration diff --git a/packages/runtime/src/client/client-impl.ts b/packages/runtime/src/client/client-impl.ts index b94703fca..3bf20f074 100644 --- a/packages/runtime/src/client/client-impl.ts +++ b/packages/runtime/src/client/client-impl.ts @@ -10,6 +10,7 @@ import { } from 'kysely'; import { match } from 'ts-pattern'; import type { GetModels, ProcedureDef, SchemaDef } from '../schema'; +import type { AuthType } from '../schema/schema'; import type { ClientConstructor, ClientContract } from './contract'; import type { ModelOperations } from './crud-types'; import { AggregateOperationHandler } from './crud/operations/aggregate'; @@ -30,7 +31,6 @@ import type { RuntimePlugin } from './plugin'; import { createDeferredPromise } from './promise'; import type { ToKysely } from './query-builder'; import { ResultProcessor } from './result-processor'; -import type { AuthType } from '../schema/schema'; /** * Creates a new ZenStack client instance. @@ -201,7 +201,10 @@ export class ClientImpl { return new ClientImpl(this.schema, newOptions, this); } - $setAuth(auth: AuthType) { + $setAuth(auth: AuthType | undefined) { + if (auth !== undefined && typeof auth !== 'object') { + throw new Error('Invalid auth object'); + } const newClient = new ClientImpl( this.schema, this.$options, @@ -364,6 +367,15 @@ function createModelCrudHandler< ); }, + createManyAndReturn: (args: unknown) => { + return createPromise( + 'createManyAndReturn', + args, + new CreateOperationHandler(client, model, inputValidator), + true + ); + }, + update: (args: unknown) => { return createPromise( 'update', diff --git a/packages/runtime/src/client/crud-types.ts b/packages/runtime/src/client/crud-types.ts index e5a4b25ae..57489d85e 100644 --- a/packages/runtime/src/client/crud-types.ts +++ b/packages/runtime/src/client/crud-types.ts @@ -299,6 +299,10 @@ export type WhereUnique< Extract['uniqueFields'], string> >; +type OmitFields> = { + [Key in ScalarFields]?: true; +}; + export type SelectInclude< Schema extends SchemaDef, Model extends GetModels, @@ -306,9 +310,7 @@ export type SelectInclude< > = { select?: Select; include?: Include; - omit?: { - [Key in ScalarFields]?: true; - }; + omit?: OmitFields; }; type Select< @@ -545,6 +547,12 @@ export type CreateManyArgs< Model extends GetModels > = CreateManyPayload; +export type CreateManyAndReturnArgs< + Schema extends SchemaDef, + Model extends GetModels +> = CreateManyPayload & + Omit, 'include'>; + type OptionalWrap< Schema extends SchemaDef, Model extends GetModels, @@ -1074,6 +1082,12 @@ export type ModelOperations< createMany(args?: CreateManyPayload): Promise; + createManyAndReturn( + args?: CreateManyAndReturnArgs + ): Promise< + ModelResult>[] + >; + update>( args: SelectSubset> ): Promise>; diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index dd081e093..418b5cbaa 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -14,14 +14,19 @@ import { match } from 'ts-pattern'; import { ulid } from 'ulid'; import * as uuid from 'uuid'; import type { ClientContract } from '../..'; -import type { GetModels, ModelDef, SchemaDef } from '../../../schema'; -import type { - BuiltinType, - FieldDef, - FieldDefaultProvider, -} from '../../../schema/schema'; +import { + Expression, + type GetModels, + type ModelDef, + type SchemaDef, +} from '../../../schema'; +import type { BuiltinType, FieldDef } from '../../../schema/schema'; import { clone } from '../../../utils/clone'; import { enumerate } from '../../../utils/enumerate'; +import { + extractFields, + fieldsToSelectObject, +} from '../../../utils/object-utils'; import type { FindArgs, SelectInclude, Where } from '../../crud-types'; import { InternalError, NotFoundError, QueryError } from '../../errors'; import type { ToKysely } from '../../query-builder'; @@ -49,6 +54,7 @@ export type CrudOperation = | 'findFirst' | 'create' | 'createMany' + | 'createManyAndReturn' | 'update' | 'updateMany' | 'delete' @@ -547,8 +553,42 @@ export abstract class BaseOperationHandler { } case 'connect': { - // directly return the payload as foreign key values - result = subPayload; + const referencedPkFields = + relationField.relation!.references!; + invariant( + referencedPkFields, + 'relation must have fields info' + ); + const extractedFks = extractFields( + subPayload, + referencedPkFields + ); + if ( + Object.keys(extractedFks).length === + referencedPkFields.length + ) { + // payload contains all referenced pk fields, we can + // directly use it to connect the relation + result = extractedFks; + } else { + // read the relation entity and fetch the referenced pk fields + const relationEntity = await this.readUnique( + kysely, + relationModel, + { + where: subPayload, + select: fieldsToSelectObject( + referencedPkFields + ) as any, + } + ); + if (!relationEntity) { + throw new NotFoundError( + `Could not find the entity for connect action` + ); + } + result = relationEntity; + } break; } @@ -674,12 +714,16 @@ export abstract class BaseOperationHandler { return Promise.all(tasks); } - protected async createMany( + protected async createMany< + ReturnData extends boolean, + Result = ReturnData extends true ? unknown[] : { count: number } + >( kysely: ToKysely, model: GetModels, input: { data: any; skipDuplicates?: boolean }, + returnData: ReturnData, fromRelation?: FromRelationContext - ) { + ): Promise { const modelDef = this.requireModel(model); let relationKeyPairs: { fk: string; pk: string }[] = []; @@ -713,8 +757,15 @@ export abstract class BaseOperationHandler { .$if(!!input.skipDuplicates, (qb) => qb.onConflict((oc) => oc.doNothing()) ); - const result = await query.executeTakeFirstOrThrow(); - return { count: Number(result.numInsertedOrUpdatedRows) }; + + if (!returnData) { + const result = await query.executeTakeFirstOrThrow(); + return { count: Number(result.numInsertedOrUpdatedRows) } as Result; + } else { + const idFields = getIdFields(this.schema, model); + const result = await query.returning(idFields as any).execute(); + return result as Result; + } } private fillGeneratedValues(modelDef: ModelDef, data: object) { @@ -724,10 +775,10 @@ export abstract class BaseOperationHandler { if (!(field in data)) { if ( typeof fields[field]?.default === 'object' && - 'call' in fields[field].default + 'kind' in fields[field].default ) { const generated = this.evalGenerator(fields[field].default); - if (generated) { + if (generated !== undefined) { values[field] = generated; } } else if (fields[field]?.updatedAt) { @@ -738,15 +789,40 @@ export abstract class BaseOperationHandler { return values; } - private evalGenerator(defaultProvider: FieldDefaultProvider) { - return match(defaultProvider.call) - .with('cuid', () => createId()) - .with('uuid', () => - defaultProvider.args?.[0] === 7 ? uuid.v7() : uuid.v4() - ) - .with('nanoid', () => nanoid(defaultProvider.args?.[0])) - .with('ulid', () => ulid()) - .otherwise(() => undefined); + private evalGenerator(defaultValue: Expression) { + if (Expression.isCall(defaultValue)) { + return match(defaultValue.function) + .with('cuid', () => createId()) + .with('uuid', () => + defaultValue.args?.[0] && + Expression.isLiteral(defaultValue.args?.[0]) && + defaultValue.args[0].value === 7 + ? uuid.v7() + : uuid.v4() + ) + .with('nanoid', () => + defaultValue.args?.[0] && + Expression.isLiteral(defaultValue.args[0]) && + typeof defaultValue.args[0].value === 'number' + ? nanoid(defaultValue.args[0].value) + : nanoid() + ) + .with('ulid', () => ulid()) + .otherwise(() => undefined); + } else if ( + Expression.isMember(defaultValue) && + Expression.isCall(defaultValue.receiver) && + defaultValue.receiver.function === 'auth' + ) { + // `auth()` member access + let val: any = this.client.$auth; + for (const member of defaultValue.members) { + val = val?.[member]; + } + return val ?? null; + } else { + return undefined; + } } protected async update( @@ -1023,6 +1099,7 @@ export abstract class BaseOperationHandler { kysely, fieldModel, value as { data: any; skipDuplicates: boolean }, + false, fromRelationContext ) ); diff --git a/packages/runtime/src/client/crud/operations/create.ts b/packages/runtime/src/client/crud/operations/create.ts index 033bf529a..03dd84493 100644 --- a/packages/runtime/src/client/crud/operations/create.ts +++ b/packages/runtime/src/client/crud/operations/create.ts @@ -1,7 +1,11 @@ import { match } from 'ts-pattern'; import { RejectedByPolicyError } from '../../../plugins/policy/errors'; import type { GetModels, SchemaDef } from '../../../schema'; -import type { CreateArgs, CreateManyArgs } from '../../crud-types'; +import type { + CreateArgs, + CreateManyAndReturnArgs, + CreateManyArgs, +} from '../../crud-types'; import { getIdValues } from '../../query-utils'; import { BaseOperationHandler } from './base'; @@ -9,7 +13,7 @@ export class CreateOperationHandler< Schema extends SchemaDef > extends BaseOperationHandler { async handle( - operation: 'create' | 'createMany', + operation: 'create' | 'createMany' | 'createManyAndReturn', args: unknown | undefined ) { return match(operation) @@ -23,6 +27,14 @@ export class CreateOperationHandler< this.inputValidator.validateCreateManyArgs(this.model, args) ); }) + .with('createManyAndReturn', () => { + return this.runCreateManyAndReturn( + this.inputValidator.validateCreateManyAndReturnArgs( + this.model, + args + ) + ); + }) .exhaustive(); } @@ -50,6 +62,34 @@ export class CreateOperationHandler< if (args === undefined) { return { count: 0 }; } - return this.createMany(this.kysely, this.model, args); + return this.createMany(this.kysely, this.model, args, false); + } + + private async runCreateManyAndReturn( + args?: CreateManyAndReturnArgs> + ) { + if (args === undefined) { + return []; + } + + // TODO: avoid using transaction for simple create + return this.safeTransaction(async (tx) => { + const createResult = await this.createMany( + tx, + this.model, + args, + true + ); + return this.read(tx, this.model, { + select: args.select, + omit: args.omit, + where: { + OR: createResult.map( + (item) => + getIdValues(this.schema, this.model, item) as any + ), + } as any, // TODO: fix type + }); + }); } } diff --git a/packages/runtime/src/client/crud/validator.ts b/packages/runtime/src/client/crud/validator.ts index 137e73d5a..20f543bdc 100644 --- a/packages/runtime/src/client/crud/validator.ts +++ b/packages/runtime/src/client/crud/validator.ts @@ -11,6 +11,7 @@ import { type AggregateArgs, type CountArgs, type CreateArgs, + type CreateManyAndReturnArgs, type CreateManyArgs, type DeleteArgs, type DeleteManyArgs, @@ -55,6 +56,16 @@ export class InputValidator { >(this.makeCreateManySchema(model), 'createMany', args); } + validateCreateManyAndReturnArgs(model: GetModels, args: unknown) { + return this.validate< + CreateManyAndReturnArgs> | undefined + >( + this.makeCreateManyAndReturnSchema(model), + 'createManyAndReturn', + args + ); + } + validateUpdateArgs(model: GetModels, args: unknown) { return this.validate>>( this.makeUpdateSchema(model), @@ -550,6 +561,18 @@ export class InputValidator { return this.makeCreateManyDataSchema(model, []).optional(); } + private makeCreateManyAndReturnSchema(model: string) { + const base = this.makeCreateManyDataSchema(model, []); + return base + .merge( + z.object({ + select: this.makeSelectSchema(model).optional(), + omit: this.makeOmitSchema(model).optional(), + }) + ) + .optional(); + } + private makeCreateDataSchema( model: string, canBeArray: boolean, @@ -598,9 +621,23 @@ export class InputValidator { ) ); - // optional or array relations are optional if (fieldDef.optional || fieldDef.array) { + // optional or array relations are optional fieldSchema = fieldSchema.optional(); + } else { + // if all fk fields are optional, the relation is optional + let allFksOptional = false; + if (fieldDef.relation.fields) { + allFksOptional = fieldDef.relation.fields.every((f) => { + const fkDef = requireField(this.schema, model, f); + return ( + fkDef.optional || fieldHasDefaultValue(fkDef) + ); + }); + } + if (allFksOptional) { + fieldSchema = fieldSchema.optional(); + } } // optional to-one relation can be null @@ -612,6 +649,7 @@ export class InputValidator { let fieldSchema: ZodSchema = this.makePrimitiveSchema( fieldDef.type ); + if (fieldDef.optional || fieldHasDefaultValue(fieldDef)) { fieldSchema = fieldSchema.optional(); } diff --git a/packages/runtime/src/client/helpers/schema-db-pusher.ts b/packages/runtime/src/client/helpers/schema-db-pusher.ts index b1240ae48..565f24a14 100644 --- a/packages/runtime/src/client/helpers/schema-db-pusher.ts +++ b/packages/runtime/src/client/helpers/schema-db-pusher.ts @@ -6,7 +6,12 @@ import { } from 'kysely'; import invariant from 'tiny-invariant'; import { match } from 'ts-pattern'; -import type { FieldDef, ModelDef, SchemaDef } from '../../schema'; +import { + Expression, + type FieldDef, + type ModelDef, + type SchemaDef, +} from '../../schema'; import type { BuiltinType, CascadeAction, @@ -137,9 +142,12 @@ export class SchemaDbPusher { if (fieldDef.default !== undefined) { if ( typeof fieldDef.default === 'object' && - 'call' in fieldDef.default + 'kind' in fieldDef.default ) { - if (fieldDef.default.call === 'now') { + if ( + Expression.isCall(fieldDef.default) && + fieldDef.default.function === 'now' + ) { col = col.defaultTo(sql`CURRENT_TIMESTAMP`); } } else { diff --git a/packages/runtime/src/plugins/policy/expression-evaluator.ts b/packages/runtime/src/plugins/policy/expression-evaluator.ts new file mode 100644 index 000000000..cc1087632 --- /dev/null +++ b/packages/runtime/src/plugins/policy/expression-evaluator.ts @@ -0,0 +1,163 @@ +import { match } from 'ts-pattern'; +import { + Expression, + type ArrayExpression, + type BinaryExpression, + type CallExpression, + type FieldExpression, + type LiteralExpression, + type MemberExpression, + type UnaryExpression, +} from '../../schema'; +import invariant from 'tiny-invariant'; + +type ExpressionEvaluatorContext = { + auth?: any; + thisValue?: any; +}; + +/** + * Evaluate a schema expression into a JavaScript value. + */ +export class ExpressionEvaluator { + evaluate(expression: Expression, context: ExpressionEvaluatorContext): any { + const result = match(expression) + .when(Expression.isArray, (expr) => + this.evaluateArray(expr, context) + ) + .when(Expression.isBinary, (expr) => + this.evaluateBinary(expr, context) + ) + .when(Expression.isField, (expr) => + this.evaluateField(expr, context) + ) + .when(Expression.isLiteral, (expr) => this.evaluateLiteral(expr)) + .when(Expression.isMember, (expr) => + this.evaluateMember(expr, context) + ) + .when(Expression.isUnary, (expr) => + this.evaluateUnary(expr, context) + ) + .when(Expression.isCall, (expr) => this.evaluateCall(expr, context)) + .when(Expression.isThis, () => context.thisValue) + .when(Expression.isNull, () => null) + .exhaustive(); + + return result ?? null; + } + + private evaluateCall( + expr: CallExpression, + context: ExpressionEvaluatorContext + ): any { + if (expr.function === 'auth') { + return context.auth; + } else { + throw new Error( + `Unsupported call expression function: ${expr.function}` + ); + } + } + + private evaluateUnary( + expr: UnaryExpression, + context: ExpressionEvaluatorContext + ) { + return match(expr.op) + .with('!', () => !this.evaluate(expr.operand, context)) + .exhaustive(); + } + + private evaluateMember( + expr: MemberExpression, + context: ExpressionEvaluatorContext + ) { + let val = this.evaluate(expr.receiver, context); + for (const member of expr.members) { + val = val?.[member]; + } + return val; + } + + private evaluateLiteral(expr: LiteralExpression): any { + return expr.value; + } + + private evaluateField( + expr: FieldExpression, + context: ExpressionEvaluatorContext + ): any { + return context.thisValue?.[expr.field]; + } + + private evaluateArray( + expr: ArrayExpression, + context: ExpressionEvaluatorContext + ) { + return expr.items.map((item) => this.evaluate(item, context)); + } + + private evaluateBinary( + expr: BinaryExpression, + context: ExpressionEvaluatorContext + ) { + if (expr.op === '?' || expr.op === '!' || expr.op === '^') { + return this.evaluateCollectionPredicate(expr, context); + } + + const left = this.evaluate(expr.left, context); + const right = this.evaluate(expr.right, context); + + return match(expr.op) + .with('==', () => left === right) + .with('!=', () => left !== right) + .with('>', () => left > right) + .with('>=', () => left >= right) + .with('<', () => left < right) + .with('<=', () => left <= right) + .with('&&', () => left && right) + .with('||', () => left || right) + .exhaustive(); + } + + private evaluateCollectionPredicate( + expr: BinaryExpression, + context: ExpressionEvaluatorContext + ) { + const op = expr.op; + invariant( + op === '?' || op === '!' || op === '^', + 'expected "?" or "!" or "^" operator' + ); + + const left = this.evaluate(expr.left, context); + if (!left) { + return false; + } + + invariant(Array.isArray(left), 'expected array'); + + return match(op) + .with('?', () => + left.some((item: any) => + this.evaluate(expr.right, { ...context, thisValue: item }) + ) + ) + .with('!', () => + left.every((item: any) => + this.evaluate(expr.right, { ...context, thisValue: item }) + ) + ) + .with( + '^', + () => + !left.some((item: any) => + this.evaluate(expr.right, { + ...context, + thisValue: item, + }) + ) + ) + .exhaustive(); + } +} diff --git a/packages/runtime/src/plugins/policy/expression-transformer.ts b/packages/runtime/src/plugins/policy/expression-transformer.ts index cf36e0923..6809b2cbc 100644 --- a/packages/runtime/src/plugins/policy/expression-transformer.ts +++ b/packages/runtime/src/plugins/policy/expression-transformer.ts @@ -35,6 +35,7 @@ import { type UnaryExpression, } from '../../schema/expression'; import type { BuiltinType, GetModels } from '../../schema/schema'; +import { ExpressionEvaluator } from './expression-evaluator'; import { conjunction, disjunction, logicalNot, trueNode } from './utils'; export type ExpressionTransformerContext = { @@ -217,16 +218,22 @@ export class ExpressionTransformer { 'expected "?" or "!" or "^" operator' ); + if (this.isAuthCall(expr.left) || this.isAuthMember(expr.left)) { + const value = new ExpressionEvaluator().evaluate(expr, { + auth: this.auth, + }); + return this.transformValue(value, 'Boolean'); + } + const left = this.transform(expr.left, context); invariant( - Expression.isFieldExpr(expr.left) || - Expression.isMemberExpr(expr.left), + Expression.isField(expr.left) || Expression.isMember(expr.left), 'left operand must be field or member access' ); let newContextModel: string; - if (Expression.isFieldExpr(expr.left)) { + if (Expression.isField(expr.left)) { const fieldDef = requireField( this.schema, context.model, @@ -234,7 +241,7 @@ export class ExpressionTransformer { ); newContextModel = fieldDef.type; } else { - invariant(Expression.isFieldExpr(expr.left.receiver)); + invariant(Expression.isField(expr.left.receiver)); const fieldDef = requireField( this.schema, context.model, @@ -403,6 +410,10 @@ export class ExpressionTransformer { return Expression.isCall(value) && value.function === 'auth'; } + private isAuthMember(expr: Expression): boolean { + return Expression.isMember(expr) && this.isAuthCall(expr.receiver); + } + @expr('member') // @ts-ignore private _member( @@ -415,7 +426,7 @@ export class ExpressionTransformer { } invariant( - Expression.isFieldExpr(expr.receiver), + Expression.isField(expr.receiver), 'expect receiver to be field expression' ); diff --git a/packages/runtime/src/plugins/policy/policy-handler.ts b/packages/runtime/src/plugins/policy/policy-handler.ts index 6dcce1b76..50cc9edca 100644 --- a/packages/runtime/src/plugins/policy/policy-handler.ts +++ b/packages/runtime/src/plugins/policy/policy-handler.ts @@ -397,12 +397,14 @@ export class PolicyHandler< node.from?.froms.forEach((from) => { let modelName = this.extractTableName(from); - const filter = this.buildPolicyFilter(modelName, 'read'); - whereNode = WhereNode.create( - whereNode?.where - ? conjunction(this.dialect, [whereNode.where, filter]) - : filter - ); + if (modelName) { + const filter = this.buildPolicyFilter(modelName, 'read'); + whereNode = WhereNode.create( + whereNode?.where + ? conjunction(this.dialect, [whereNode.where, filter]) + : filter + ); + } }); const baseResult = super.transformSelectQuery({ @@ -468,14 +470,18 @@ export class PolicyHandler< }; } - private extractTableName(from: OperationNode): GetModels { + private extractTableName( + from: OperationNode + ): GetModels | undefined { if (TableNode.is(from)) { return from.table.identifier.name as GetModels; } if (AliasNode.is(from)) { return this.extractTableName(from.node); } else { - throw new Error(`Unexpected "from" node kind: ${from.kind}`); + // this can happen for subqueries, which will be handled when nested + // transformation happens + return undefined; } } diff --git a/packages/runtime/src/schema/expression.ts b/packages/runtime/src/schema/expression.ts index d8236864e..ce6f91ffb 100644 --- a/packages/runtime/src/schema/expression.ts +++ b/packages/runtime/src/schema/expression.ts @@ -180,21 +180,15 @@ export const Expression = { isThis: (value: unknown): value is ThisExpression => Expression.is(value, 'this'), - isUnaryExpr: (value: unknown): value is UnaryExpression => + isUnary: (value: unknown): value is UnaryExpression => Expression.is(value, 'unary'), - isBinaryExpr: (value: unknown): value is BinaryExpression => + isBinary: (value: unknown): value is BinaryExpression => Expression.is(value, 'binary'), - isFieldExpr: (value: unknown): value is FieldExpression => + isField: (value: unknown): value is FieldExpression => Expression.is(value, 'field'), - isMemberExpr: (value: unknown): value is MemberExpression => + isMember: (value: unknown): value is MemberExpression => Expression.is(value, 'member'), - - isCallExpr: (value: unknown): value is CallExpression => - Expression.is(value, 'call'), - - isThisExpr: (value: unknown): value is ThisExpression => - Expression.is(value, 'this'), }; diff --git a/packages/runtime/src/schema/schema.ts b/packages/runtime/src/schema/schema.ts index d740563cb..3dd1d330a 100644 --- a/packages/runtime/src/schema/schema.ts +++ b/packages/runtime/src/schema/schema.ts @@ -57,8 +57,6 @@ export type RelationInfo = { onUpdate?: CascadeAction; }; -export type FieldDefaultProvider = { call: string; args?: any[] }; - export type FieldDef = { type: string; id?: boolean; @@ -67,7 +65,7 @@ export type FieldDef = { unique?: boolean; updatedAt?: boolean; attributes?: AttributeApplication[]; - default?: MappedBuiltinType | FieldDefaultProvider; + default?: MappedBuiltinType | Expression; relation?: RelationInfo; foreignKeyFor?: string[]; computed?: boolean; diff --git a/packages/runtime/src/utils/object-utils.ts b/packages/runtime/src/utils/object-utils.ts new file mode 100644 index 000000000..59a2346ae --- /dev/null +++ b/packages/runtime/src/utils/object-utils.ts @@ -0,0 +1,17 @@ +/** + * Extract fields from an object. + */ +export function extractFields(obj: any, fields: string[]) { + return Object.fromEntries( + Object.entries(obj).filter(([key]) => fields.includes(key)) + ); +} + +/** + * Create an object with fields as keys and true values. + */ +export function fieldsToSelectObject( + fields: string[] +): Record { + return Object.fromEntries(fields.map((f) => [f, true])); +} diff --git a/packages/runtime/test/client-api/create-many-and-return.test.ts b/packages/runtime/test/client-api/create-many-and-return.test.ts new file mode 100644 index 000000000..4781ef0fb --- /dev/null +++ b/packages/runtime/test/client-api/create-many-and-return.test.ts @@ -0,0 +1,94 @@ +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; +import type { ClientContract } from '../../src/client'; +import { schema } from '../test-schema'; +import { createClientSpecs } from './client-specs'; + +const PG_DB_NAME = 'client-api-create-many-and-return-tests'; + +describe.each(createClientSpecs(PG_DB_NAME))( + 'Client createManyAndReturn tests', + ({ createClient }) => { + let client: ClientContract; + + beforeEach(async () => { + client = await createClient(); + await client.$pushSchema(); + }); + + afterEach(async () => { + await client?.$disconnect(); + }); + + it('works with toplevel createManyAndReturn', async () => { + // empty + await expect(client.user.createManyAndReturn()).toResolveWithLength( + 0 + ); + + // single + await expect( + client.user.createManyAndReturn({ + data: { + email: 'u1@test.com', + name: 'name', + }, + }) + ).resolves.toEqual([ + expect.objectContaining({ email: 'u1@test.com', name: 'name' }), + ]); + + // multiple + let r = await client.user.createManyAndReturn({ + data: [{ email: 'u2@test.com' }, { email: 'u3@test.com' }], + }); + expect(r).toHaveLength(2); + expect(r).toEqual( + expect.arrayContaining([ + expect.objectContaining({ email: 'u2@test.com' }), + expect.objectContaining({ email: 'u3@test.com' }), + ]) + ); + + // conflict + await expect( + client.user.createManyAndReturn({ + data: [{ email: 'u3@test.com' }, { email: 'u4@test.com' }], + }) + ).rejects.toThrow(); + await expect( + client.user.findUnique({ where: { email: 'u4@test.com' } }) + ).toResolveNull(); + + // skip duplicates + r = await client.user.createManyAndReturn({ + data: [{ email: 'u3@test.com' }, { email: 'u4@test.com' }], + skipDuplicates: true, + }); + expect(r).toHaveLength(1); + expect(r).toEqual( + expect.arrayContaining([ + expect.objectContaining({ email: 'u4@test.com' }), + ]) + ); + await expect( + client.user.findUnique({ where: { email: 'u4@test.com' } }) + ).toResolveTruthy(); + }); + + it('works with select and omit', async () => { + let r = await client.user.createManyAndReturn({ + data: [{ email: 'u1@test.com', name: 'name' }], + select: { email: true }, + }); + expect(r[0]!.email).toBe('u1@test.com'); + expect(r[0]!.name).toBeUndefined(); + + r = await client.user.createManyAndReturn({ + data: [{ email: 'u2@test.com', name: 'name' }], + omit: { name: true }, + }); + expect(r[0]!.email).toBe('u2@test.com'); + expect(r[0]!.name).toBeUndefined(); + }); + } +); diff --git a/packages/runtime/test/policy/auth.test.ts b/packages/runtime/test/policy/auth.test.ts new file mode 100644 index 000000000..28cdd63e0 --- /dev/null +++ b/packages/runtime/test/policy/auth.test.ts @@ -0,0 +1,663 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from './utils'; + +describe('auth() tests', () => { + it('works with string id non-null test', async () => { + const db = await createPolicyTestClient( + ` + model User { + id String @id @default(uuid()) +} + +model Post { + id String @id @default(uuid()) + title String + + @@allow('read', true) + @@allow('create', auth() != null) +} +` + ); + + await expect( + db.post.create({ data: { title: 'abc' } }) + ).toBeRejectedByPolicy(); + + const authDb = db.$setAuth({ id: 'user1' }); + await expect( + authDb.post.create({ data: { title: 'abc' } }) + ).toResolveTruthy(); + }); + + it('works with string id id test', async () => { + const db = await createPolicyTestClient( + ` + model User { + id String @id @default(uuid()) + } + + model Post { + id String @id @default(uuid()) + title String + + @@allow('read', true) + @@allow('create', auth().id != null) + } + ` + ); + + await expect( + db.post.create({ data: { title: 'abc' } }) + ).toBeRejectedByPolicy(); + + const authDb = db.$setAuth({ id: 'user1' }); + await expect( + authDb.post.create({ data: { title: 'abc' } }) + ).toResolveTruthy(); + }); + + it('works with int id', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + } + + model Post { + id String @id @default(uuid()) + title String + + @@allow('read', true) + @@allow('create', auth() != null) + } + ` + ); + + await expect( + db.post.create({ data: { title: 'abc' } }) + ).toBeRejectedByPolicy(); + + const authDb = db.$setAuth({ id: 'user1' }); + await expect( + authDb.post.create({ data: { title: 'abc' } }) + ).toResolveTruthy(); + }); + + it('works with field comparison', async () => { + const db = await createPolicyTestClient( + ` + model User { + id String @id @default(uuid()) + posts Post[] + + @@allow('all', true) + } + + model Post { + id String @id @default(uuid()) + title String + author User @relation(fields: [authorId], references: [id]) + authorId String + + @@allow('create,read', true) + @@allow('update', auth().id == author.id) + } + ` + ); + + await expect( + db.user.create({ data: { id: 'user1' } }) + ).toResolveTruthy(); + await expect( + db.post.create({ + data: { id: '1', title: 'abc', authorId: 'user1' }, + }) + ).toResolveTruthy(); + + await expect( + db.post.update({ where: { id: '1' }, data: { title: 'bcd' } }) + ).toBeRejectedNotFound(); + + const authDb2 = db.$setAuth({ id: 'user1' }); + await expect( + authDb2.post.update({ where: { id: '1' }, data: { title: 'bcd' } }) + ).toResolveTruthy(); + }); + + it('works with undefined user non-id field', async () => { + const db = await createPolicyTestClient( + ` + model User { + id String @id @default(uuid()) + posts Post[] + role String + + @@allow('all', true) + } + + model Post { + id String @id @default(uuid()) + title String + author User @relation(fields: [authorId], references: [id]) + authorId String + + @@allow('create,read', true) + @@allow('update', auth().role == 'ADMIN') + } + ` + ); + + await expect( + db.user.create({ data: { id: 'user1', role: 'USER' } }) + ).toResolveTruthy(); + await expect( + db.post.create({ + data: { id: '1', title: 'abc', authorId: 'user1' }, + }) + ).toResolveTruthy(); + await expect( + db.post.update({ where: { id: '1' }, data: { title: 'bcd' } }) + ).toBeRejectedNotFound(); + + const authDb = db.$setAuth({ id: 'user1', role: 'USER' }); + await expect( + authDb.post.update({ where: { id: '1' }, data: { title: 'bcd' } }) + ).toBeRejectedNotFound(); + + const authDb1 = db.$setAuth({ id: 'user2', role: 'ADMIN' }); + await expect( + authDb1.post.update({ where: { id: '1' }, data: { title: 'bcd' } }) + ).toResolveTruthy(); + }); + + it('works with non User auth model', async () => { + const db = await createPolicyTestClient( + ` + model Foo { + id String @id @default(uuid()) + role String + + @@auth() + } + + model Post { + id String @id @default(uuid()) + title String + + @@allow('read', true) + @@allow('create', auth().role == 'ADMIN') + } + ` + ); + + const userDb = db.$setAuth({ id: 'user1', role: 'USER' }); + await expect( + userDb.post.create({ data: { title: 'abc' } }) + ).toBeRejectedByPolicy(); + + const adminDb = db.$setAuth({ id: 'user1', role: 'ADMIN' }); + await expect( + adminDb.post.create({ data: { title: 'abc' } }) + ).toResolveTruthy(); + }); + + it('works with collection predicate', async () => { + const db = await createPolicyTestClient( + ` + model User { + id String @id @default(uuid()) + posts Post[] + + @@allow('all', true) + } + + model Post { + id String @id @default(uuid()) + title String + published Boolean @default(false) + author User @relation(fields: [authorId], references: [id]) + authorId String + comments Comment[] + + @@allow('read', true) + @@allow('create', auth().posts?[published && comments![published]]) + } + + model Comment { + id String @id @default(uuid()) + published Boolean @default(false) + post Post @relation(fields: [postId], references: [id]) + postId String + + @@allow('all', true) + } + `, + { log: ['query'] } + ); + + const rawDb = db.$unuseAll(); + + const user = await rawDb.user.create({ data: {} }); + + const createPayload = { + data: { title: 'Post 1', author: { connect: { id: user.id } } }, + }; + + // no post + await expect( + db.$setAuth({ id: '1' }).post.create(createPayload) + ).toBeRejectedByPolicy(); + + // post not published + await expect( + db + .$setAuth({ + id: '1', + posts: [{ id: '1', published: false }], + }) + .post.create(createPayload) + ).toBeRejectedByPolicy(); + + // no comments + await expect( + db + .$setAuth({ + id: '1', + posts: [{ id: '1', published: true }], + }) + .post.create(createPayload) + ).toBeRejectedByPolicy(); + + // not all comments published + await expect( + db + .$setAuth({ + id: '1', + posts: [ + { + id: '1', + published: true, + comments: [ + { id: '1', published: true }, + { id: '2', published: false }, + ], + }, + ], + }) + .post.create(createPayload) + ).toBeRejectedByPolicy(); + + // comments published but parent post is not + await expect( + db + .$setAuth({ + id: '1', + posts: [ + { + id: '1', + published: false, + comments: [{ id: '1', published: true }], + }, + { id: '2', published: true }, + ], + }) + .post.create(createPayload) + ).toBeRejectedByPolicy(); + + await expect( + db + .$setAuth({ + id: '1', + posts: [ + { + id: '1', + published: true, + comments: [{ id: '1', published: true }], + }, + { id: '2', published: false }, + ], + }) + .post.create(createPayload) + ).toResolveTruthy(); + + // no comments ("every" evaluates to true in this case) + await expect( + db + .$setAuth({ + id: '1', + posts: [{ id: '1', published: true, comments: [] }], + }) + .post.create(createPayload) + ).toResolveTruthy(); + }); + + it('works with using auth value as default for literal fields', async () => { + const db = await createPolicyTestClient( + ` + model User { + id String @id + name String + score Int + + } + + model Post { + id String @id @default(uuid()) + title String + score Int? @default(auth().score) + authorName String? @default(auth().name) + + @@allow('all', true) + } + ` + ); + + const userDb = db.$setAuth({ id: '1', name: 'user1', score: 10 }); + await expect( + userDb.post.create({ data: { title: 'abc' } }) + ).toResolveTruthy(); + await expect(userDb.post.findMany()).resolves.toHaveLength(1); + await expect( + userDb.post.count({ where: { authorName: 'user1', score: 10 } }) + ).resolves.toBe(1); + + await expect( + userDb.post.createMany({ data: [{ title: 'def' }] }) + ).resolves.toMatchObject({ count: 1 }); + const r = await userDb.post.createManyAndReturn({ + data: [{ title: 'xxx' }, { title: 'yyy' }], + }); + expect(r[0]).toMatchObject({ title: 'xxx', score: 10 }); + expect(r[1]).toMatchObject({ title: 'yyy', score: 10 }); + }); + + it('respects explicitly passed field values even when default is set', async () => { + const db = await createPolicyTestClient( + ` + model User { + id String @id + name String + + } + + model Post { + id String @id @default(uuid()) + authorName String? @default(auth().name) + + @@allow('all', true) + } + ` + ); + + const userContextName = 'user1'; + const overrideName = 'no-default-auth-name'; + const userDb = db.$setAuth({ id: '1', name: userContextName }); + await expect( + userDb.post.create({ data: { authorName: overrideName } }) + ).toResolveTruthy(); + await expect( + userDb.post.count({ where: { authorName: overrideName } }) + ).resolves.toBe(1); + + await expect( + userDb.post.createMany({ data: [{ authorName: overrideName }] }) + ).toResolveTruthy(); + await expect( + userDb.post.count({ where: { authorName: overrideName } }) + ).resolves.toBe(2); + + const r = await userDb.post.createManyAndReturn({ + data: [{ authorName: overrideName }], + }); + expect(r[0]).toMatchObject({ authorName: overrideName }); + }); + + it('works with using auth value as default for foreign key', async () => { + const anonDb = await createPolicyTestClient( + ` + model User { + id String @id + email String @unique + posts Post[] + + @@allow('all', true) + + } + + model Post { + id String @id @default(uuid()) + title String + author User @relation(fields: [authorId], references: [id]) + authorId String @default(auth().id) + + @@allow('all', true) + } + ` + ); + + const rawDb = anonDb.$unuseAll(); + await rawDb.user.create({ + data: { id: 'userId-1', email: 'user1@abc.com' }, + }); + await rawDb.user.create({ + data: { id: 'userId-2', email: 'user2@abc.com' }, + }); + + const db = anonDb.$setAuth({ id: 'userId-1' }); + + // default auth effective + await expect( + db.post.create({ data: { title: 'post1' } }) + ).resolves.toMatchObject({ authorId: 'userId-1' }); + + // default auth ineffective due to explicit connect + await expect( + db.post.create({ + data: { + title: 'post2', + author: { connect: { email: 'user1@abc.com' } }, + }, + }) + ).resolves.toMatchObject({ authorId: 'userId-1' }); + + // default auth ineffective due to explicit connect + await expect( + db.post.create({ + data: { + title: 'post3', + author: { connect: { email: 'user2@abc.com' } }, + }, + }) + ).resolves.toMatchObject({ authorId: 'userId-2' }); + + // TODO: upsert + // await expect( + // db.post.upsert({ + // where: { id: 'post4' }, + // create: { id: 'post4', title: 'post4' }, + // update: { title: 'post4' }, + // }) + // ).resolves.toMatchObject({ authorId: 'userId-1' }); + + // default auth effective for createMany + await expect( + db.post.createMany({ data: { title: 'post5' } }) + ).resolves.toMatchObject({ count: 1 }); + const r = await db.post.findFirst({ where: { title: 'post5' } }); + expect(r).toMatchObject({ authorId: 'userId-1' }); + + // default auth effective for createManyAndReturn + const r1 = await db.post.createManyAndReturn({ + data: { title: 'post6' }, + }); + expect(r1[0]).toMatchObject({ authorId: 'userId-1' }); + }); + + it('works with using nested auth value as default', async () => { + const anonDb = await createPolicyTestClient( + ` + model User { + id String @id + profile Profile? + posts Post[] + + @@allow('all', true) + } + + model Profile { + id String @id @default(uuid()) + image Image? + user User @relation(fields: [userId], references: [id]) + userId String @unique + } + + model Image { + id String @id @default(uuid()) + url String + profile Profile @relation(fields: [profileId], references: [id]) + profileId String @unique + } + + model Post { + id String @id @default(uuid()) + title String + defaultImageUrl String @default(auth().profile.image.url) + author User @relation(fields: [authorId], references: [id]) + authorId String + + @@allow('all', true) + } + ` + ); + const url = 'https://zenstack.dev'; + const db = anonDb.$setAuth({ + id: 'userId-1', + profile: { image: { url } }, + }); + + // top-level create + await expect( + db.user.create({ data: { id: 'userId-1' } }) + ).toResolveTruthy(); + await expect( + db.post.create({ + data: { title: 'abc', author: { connect: { id: 'userId-1' } } }, + }) + ).resolves.toMatchObject({ defaultImageUrl: url }); + + // nested create + let result = await db.user.create({ + data: { + id: 'userId-2', + posts: { + create: [{ title: 'p1' }, { title: 'p2' }], + }, + }, + include: { posts: true }, + }); + expect(result.posts).toEqual( + expect.arrayContaining([ + expect.objectContaining({ title: 'p1', defaultImageUrl: url }), + expect.objectContaining({ title: 'p2', defaultImageUrl: url }), + ]) + ); + }); + + it('works with using auth value as default with anonymous context', async () => { + const db = await createPolicyTestClient( + ` + model User { + id String @id + posts Post[] + + @@allow('all', true) + } + + model Post { + id String @id @default(uuid()) + title String + author User @relation(fields: [authorId], references: [id]) + authorId String @default(auth().id) + + @@allow('all', true) + } + ` + ); + + await expect( + db.user.create({ data: { id: 'userId-1' } }) + ).toResolveTruthy(); + await expect( + db.post.create({ data: { title: 'title' } }) + ).rejects.toThrow('constraint failed'); + await expect(db.post.findMany({})).toResolveTruthy(); + }); + + it('works with using auth value as default in mixed checked and unchecked context', async () => { + const anonDb = await createPolicyTestClient( + ` + model User { + id String @id + posts Post[] + + @@allow('all', true) + } + + model Post { + id String @id @default(uuid()) + title String + author User @relation(fields: [authorId], references: [id]) + authorId String @default(auth().id) + + stats Stats @relation(fields: [statsId], references: [id]) + statsId String @unique + + @@allow('all', true) + } + + model Stats { + id String @id @default(uuid()) + viewCount Int @default(0) + post Post? + + @@allow('all', true) + } + ` + ); + + const db = anonDb.$setAuth({ id: 'userId-1' }); + await db.user.create({ data: { id: 'userId-1' } }); + + // unchecked context + await db.stats.create({ data: { id: 'stats-1', viewCount: 10 } }); + await expect( + db.post.create({ data: { title: 'title1', statsId: 'stats-1' } }) + ).toResolveTruthy(); + + await db.stats.create({ data: { id: 'stats-2', viewCount: 10 } }); + await expect( + db.post.createMany({ + data: [{ title: 'title2', statsId: 'stats-2' }], + }) + ).resolves.toMatchObject({ + count: 1, + }); + + await db.stats.create({ data: { id: 'stats-3', viewCount: 10 } }); + const r = await db.post.createManyAndReturn({ + data: [{ title: 'title3', statsId: 'stats-3' }], + }); + expect(r[0]).toMatchObject({ statsId: 'stats-3' }); + + // checked context + await db.stats.create({ data: { id: 'stats-4', viewCount: 10 } }); + await expect( + db.post.create({ + data: { + title: 'title4', + stats: { connect: { id: 'stats-4' } }, + }, + }) + ).toResolveTruthy(); + }); +}); diff --git a/packages/runtime/test/policy/todo-sample.test.ts b/packages/runtime/test/policy/todo-sample.test.ts index e5f7b2f40..0159ed94d 100644 --- a/packages/runtime/test/policy/todo-sample.test.ts +++ b/packages/runtime/test/policy/todo-sample.test.ts @@ -26,7 +26,7 @@ describe('Todo sample', () => { name: 'User 2', }; - const client: any = new ZenStackClient(schema, { log: ['query'] }); + const client: any = new ZenStackClient(schema); await client.$pushSchema(); const anonDb: any = client.$use(new PolicyPlugin()); diff --git a/packages/runtime/test/policy/utils.ts b/packages/runtime/test/policy/utils.ts new file mode 100644 index 000000000..c5e1607f5 --- /dev/null +++ b/packages/runtime/test/policy/utils.ts @@ -0,0 +1,14 @@ +import type { ClientOptions } from '../../src/client/options'; +import { PolicyPlugin } from '../../src/plugins/policy'; +import type { SchemaDef } from '../../src/schema'; +import { createTestClient } from '../utils'; + +export function createPolicyTestClient( + schema: string | SchemaDef, + options?: ClientOptions +) { + return createTestClient(schema as any, { + ...options, + plugins: [new PolicyPlugin()], + }); +} diff --git a/packages/runtime/test/test-schema.ts b/packages/runtime/test/test-schema.ts index 446890bf2..1b0550717 100644 --- a/packages/runtime/test/test-schema.ts +++ b/packages/runtime/test/test-schema.ts @@ -16,7 +16,7 @@ export const schema = { id: { type: 'String', id: true, - default: { call: 'cuid' }, + default: Expression.call('cuid'), attributes: [ { name: '@id' }, { @@ -47,7 +47,7 @@ export const schema = { }, createdAt: { type: 'DateTime', - default: { call: 'now' }, + default: Expression.call('now'), attributes: [ { name: '@default', @@ -141,11 +141,11 @@ export const schema = { id: { type: 'String', id: true, - default: { call: 'cuid' }, + default: Expression.call('cuid'), }, createdAt: { type: 'DateTime', - default: { call: 'now' }, + default: Expression.call('now'), }, updatedAt: { type: 'DateTime', @@ -248,11 +248,11 @@ export const schema = { id: { type: 'String', id: true, - default: { call: 'cuid' }, + default: Expression.call('cuid'), }, createdAt: { type: 'DateTime', - default: { call: 'now' }, + default: Expression.call('now'), }, updatedAt: { type: 'DateTime', @@ -288,7 +288,7 @@ export const schema = { id: { type: 'String', id: true, - default: { call: 'cuid' }, + default: Expression.call('cuid'), }, bio: { type: 'String' }, age: { type: 'Int', optional: true }, diff --git a/packages/runtime/test/utils.ts b/packages/runtime/test/utils.ts index adbf2480e..9adc82e9b 100644 --- a/packages/runtime/test/utils.ts +++ b/packages/runtime/test/utils.ts @@ -1,3 +1,4 @@ +import { generateTsSchema } from '@zenstackhq/testtools'; import Sqlite from 'better-sqlite3'; import { Client as PGClient, Pool } from 'pg'; import { ZenStackClient } from '../src/client'; @@ -44,3 +45,36 @@ export async function makePostgresClient( }, } as unknown as ClientOptions); } + +type CreateTestClientOptions = ClientOptions; + +export async function createTestClient( + schema: Schema, + options?: CreateTestClientOptions +): Promise; +export async function createTestClient( + schema: string, + options?: CreateTestClientOptions +): Promise; +export async function createTestClient( + schema: Schema | string, + options?: CreateTestClientOptions +): Promise { + let _schema = + typeof schema === 'string' + ? ((await generateTsSchema(schema)) as Schema) + : schema; + + const { plugins, ...rest } = options ?? {}; + + let client = new ZenStackClient(_schema, rest as ClientOptions); + await client.$pushSchema(); + + if (options?.plugins) { + for (const plugin of options.plugins) { + client = client.$use(plugin); + } + } + + return client; +} diff --git a/packages/sdk/src/ts-schema-generator.ts b/packages/sdk/src/ts-schema-generator.ts index ae601c0fa..afd166c27 100644 --- a/packages/sdk/src/ts-schema-generator.ts +++ b/packages/sdk/src/ts-schema-generator.ts @@ -37,6 +37,7 @@ import path from 'node:path'; import invariant from 'tiny-invariant'; import { match } from 'ts-pattern'; import * as ts from 'typescript'; +import { ModelUtils } from '.'; import { getAttribute, getAuthDecl, @@ -498,32 +499,59 @@ export class TsSchemaGenerator { const defaultValue = this.getMappedDefault(field); if (defaultValue !== undefined) { - if (typeof defaultValue === 'object' && 'call' in defaultValue) { - objectFields.push( - ts.factory.createPropertyAssignment( - 'default', - ts.factory.createObjectLiteralExpression([ - ts.factory.createPropertyAssignment( - 'call', - ts.factory.createStringLiteral( - defaultValue.call - ) - ), - ...(defaultValue.args.length > 0 - ? [ - ts.factory.createPropertyAssignment( - 'args', - ts.factory.createArrayLiteralExpression( - defaultValue.args.map((arg) => - this.createLiteralNode(arg) - ) - ) - ), - ] - : []), - ]) - ) - ); + if (typeof defaultValue === 'object') { + if ('call' in defaultValue) { + objectFields.push( + ts.factory.createPropertyAssignment( + 'default', + + ts.factory.createCallExpression( + ts.factory.createIdentifier('Expression.call'), + undefined, + [ + ts.factory.createStringLiteral( + defaultValue.call + ), + ts.factory.createArrayLiteralExpression( + defaultValue.args.map((arg) => + this.createLiteralNode(arg) + ) + ), + ] + ) + ) + ); + } else if ('authMember' in defaultValue) { + objectFields.push( + ts.factory.createPropertyAssignment( + 'default', + ts.factory.createCallExpression( + ts.factory.createIdentifier( + 'Expression.member' + ), + undefined, + [ + ts.factory.createCallExpression( + ts.factory.createIdentifier( + 'Expression.call' + ), + undefined, + [ts.factory.createStringLiteral('auth')] + ), + ts.factory.createArrayLiteralExpression( + defaultValue.authMember.map((m) => + ts.factory.createStringLiteral(m) + ) + ), + ] + ) + ) + ); + } else { + throw new Error( + `Unsupported default value type for field ${field.name}` + ); + } } else { objectFields.push( ts.factory.createPropertyAssignment( @@ -611,13 +639,23 @@ export class TsSchemaGenerator { return { type, url }; } - private getMappedDefault(field: DataModelField) { + private getMappedDefault( + field: DataModelField + ): + | string + | number + | boolean + | { call: string; args: any[] } + | { authMember: string[] } + | undefined { const defaultAttr = getAttribute(field, '@default'); if (!defaultAttr) { return undefined; } const defaultValue = defaultAttr.args[0]?.value; + invariant(defaultValue, 'Expected a default value'); + if (isLiteralExpr(defaultValue)) { const lit = (defaultValue as LiteralExpr).value; return field.type.type === 'Boolean' @@ -639,6 +677,10 @@ export class TsSchemaGenerator { this.getLiteral(arg.value) ), }; + } else if (this.isAuthMemberAccess(defaultValue)) { + return { + authMember: this.getMemberAccessChain(defaultValue), + }; } else { throw new Error( `Unsupported default value type for field ${field.name}` @@ -646,6 +688,36 @@ export class TsSchemaGenerator { } } + private getMemberAccessChain(expr: MemberAccessExpr): string[] { + if (!isMemberAccessExpr(expr.operand)) { + return [expr.member.$refText]; + } else { + return [ + ...this.getMemberAccessChain(expr.operand), + expr.member.$refText, + ]; + } + } + + private isAuthMemberAccess(expr: Expression): expr is MemberAccessExpr { + if (isMemberAccessExpr(expr)) { + return ( + this.isAuthInvocation(expr.operand) || + this.isAuthMemberAccess(expr.operand) + ); + } else { + return false; + } + } + + private isAuthInvocation(expr: Expression) { + return ( + isInvocationExpr(expr) && + expr.function.$refText === 'auth' && + ModelUtils.isFromStdlib(expr.function.ref!) + ); + } + private createRelationObject(field: DataModelField) { const relationFields: ts.PropertyAssignment[] = []; @@ -1191,23 +1263,6 @@ export class TsSchemaGenerator { ), ]; - // if (isDataModel(expr.$resolvedType?.decl)) { - // const operandModel = expr.operand.$resolvedType?.decl! as DataModel; - // const relationModel = expr.$resolvedType.decl; - // args.push( - // ts.factory.createObjectLiteralExpression([ - // ts.factory.createPropertyAssignment( - // 'fromModel', - // ts.factory.createStringLiteral(operandModel.name) - // ), - // ts.factory.createPropertyAssignment( - // 'relationModel', - // ts.factory.createStringLiteral(relationModel.name) - // ), - // ]) - // ); - // } - return ts.factory.createCallExpression( ts.factory.createIdentifier('Expression.member'), undefined,