Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions BREAKINGCHANGES.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
1. `auth()` cannot be directly compared with a relation anymore
2.
9 changes: 5 additions & 4 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
- [x] Relation connection
- [x] Create many
- [x] ID generation
- [x] CreateManyAndReturn
- [ ] Find
- [x] Input validation
- [ ] Field selection
- [x] Field selection
- [x] Omit
- [x] Counting relation
- [x] Pagination
Expand All @@ -42,6 +43,7 @@
- [x] Nested to-one
- [ ] Delta update for numeric fields
- [ ] Array update
- [ ] Upsert
- [x] Delete
- [ ] Aggregation
- [x] Count
Expand All @@ -52,15 +54,14 @@
- [x] Computed fields
- [?] Prisma client extension
- [ ] Misc
- [ ] Rename AST Model to Schema
- [ ] Compound ID
- [ ] Cross field comparison
- [ ] Many-to-many relation
- [ ] Cache validation schemas
- [?] Logging
- [ ] Error system
- [?] Custom table name
- [ ] Custom field name
- [x] Custom table name
- [x] Custom field name
- [ ] Access Policy
- [ ] Polymorphism
- [x] Migration
Expand Down
16 changes: 14 additions & 2 deletions packages/runtime/src/client/client-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {
} from 'kysely';
import { match } from 'ts-pattern';
import type { GetModels, ProcedureDef, SchemaDef } from '../schema';
import type { AuthType } from '../schema/schema';
import type { ClientConstructor, ClientContract } from './contract';
import type { ModelOperations } from './crud-types';
import { AggregateOperationHandler } from './crud/operations/aggregate';
Expand All @@ -30,7 +31,6 @@ import type { RuntimePlugin } from './plugin';
import { createDeferredPromise } from './promise';
import type { ToKysely } from './query-builder';
import { ResultProcessor } from './result-processor';
import type { AuthType } from '../schema/schema';

/**
* Creates a new ZenStack client instance.
Expand Down Expand Up @@ -201,7 +201,10 @@ export class ClientImpl<Schema extends SchemaDef> {
return new ClientImpl<Schema>(this.schema, newOptions, this);
}

$setAuth(auth: AuthType<Schema>) {
$setAuth(auth: AuthType<Schema> | undefined) {
if (auth !== undefined && typeof auth !== 'object') {
throw new Error('Invalid auth object');
}
const newClient = new ClientImpl<Schema>(
this.schema,
this.$options,
Expand Down Expand Up @@ -364,6 +367,15 @@ function createModelCrudHandler<
);
},

createManyAndReturn: (args: unknown) => {
return createPromise(
'createManyAndReturn',
args,
new CreateOperationHandler(client, model, inputValidator),
true
);
},

update: (args: unknown) => {
return createPromise(
'update',
Expand Down
20 changes: 17 additions & 3 deletions packages/runtime/src/client/crud-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -299,16 +299,18 @@ export type WhereUnique<
Extract<keyof GetModel<Schema, Model>['uniqueFields'], string>
>;

type OmitFields<Schema extends SchemaDef, Model extends GetModels<Schema>> = {
[Key in ScalarFields<Schema, Model>]?: true;
};

export type SelectInclude<
Schema extends SchemaDef,
Model extends GetModels<Schema>,
AllowCount extends boolean
> = {
select?: Select<Schema, Model, AllowCount>;
include?: Include<Schema, Model>;
omit?: {
[Key in ScalarFields<Schema, Model>]?: true;
};
omit?: OmitFields<Schema, Model>;
};

type Select<
Expand Down Expand Up @@ -545,6 +547,12 @@ export type CreateManyArgs<
Model extends GetModels<Schema>
> = CreateManyPayload<Schema, Model>;

export type CreateManyAndReturnArgs<
Schema extends SchemaDef,
Model extends GetModels<Schema>
> = CreateManyPayload<Schema, Model> &
Omit<SelectInclude<Schema, Model, false>, 'include'>;

type OptionalWrap<
Schema extends SchemaDef,
Model extends GetModels<Schema>,
Expand Down Expand Up @@ -1074,6 +1082,12 @@ export type ModelOperations<

createMany(args?: CreateManyPayload<Schema, Model>): Promise<BatchResult>;

createManyAndReturn(
args?: CreateManyAndReturnArgs<Schema, Model>
): Promise<
ModelResult<Schema, Model, CreateManyAndReturnArgs<Schema, Model>>[]
>;

update<T extends UpdateArgs<Schema, Model>>(
args: SelectSubset<T, UpdateArgs<Schema, Model>>
): Promise<ModelResult<Schema, Model, T>>;
Expand Down
123 changes: 100 additions & 23 deletions packages/runtime/src/client/crud/operations/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,19 @@ import { match } from 'ts-pattern';
import { ulid } from 'ulid';
import * as uuid from 'uuid';
import type { ClientContract } from '../..';
import type { GetModels, ModelDef, SchemaDef } from '../../../schema';
import type {
BuiltinType,
FieldDef,
FieldDefaultProvider,
} from '../../../schema/schema';
import {
Expression,
type GetModels,
type ModelDef,
type SchemaDef,
} from '../../../schema';
import type { BuiltinType, FieldDef } from '../../../schema/schema';
import { clone } from '../../../utils/clone';
import { enumerate } from '../../../utils/enumerate';
import {
extractFields,
fieldsToSelectObject,
} from '../../../utils/object-utils';
import type { FindArgs, SelectInclude, Where } from '../../crud-types';
import { InternalError, NotFoundError, QueryError } from '../../errors';
import type { ToKysely } from '../../query-builder';
Expand Down Expand Up @@ -49,6 +54,7 @@ export type CrudOperation =
| 'findFirst'
| 'create'
| 'createMany'
| 'createManyAndReturn'
| 'update'
| 'updateMany'
| 'delete'
Expand Down Expand Up @@ -547,8 +553,42 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
}

case 'connect': {
// directly return the payload as foreign key values
result = subPayload;
const referencedPkFields =
relationField.relation!.references!;
invariant(
referencedPkFields,
'relation must have fields info'
);
const extractedFks = extractFields(
subPayload,
referencedPkFields
);
if (
Object.keys(extractedFks).length ===
referencedPkFields.length
) {
// payload contains all referenced pk fields, we can
// directly use it to connect the relation
result = extractedFks;
} else {
// read the relation entity and fetch the referenced pk fields
const relationEntity = await this.readUnique(
kysely,
relationModel,
{
where: subPayload,
select: fieldsToSelectObject(
referencedPkFields
) as any,
}
);
if (!relationEntity) {
throw new NotFoundError(
`Could not find the entity for connect action`
);
}
result = relationEntity;
}
break;
}

Expand Down Expand Up @@ -674,12 +714,16 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
return Promise.all(tasks);
}

protected async createMany(
protected async createMany<
ReturnData extends boolean,
Result = ReturnData extends true ? unknown[] : { count: number }
>(
kysely: ToKysely<Schema>,
model: GetModels<Schema>,
input: { data: any; skipDuplicates?: boolean },
returnData: ReturnData,
fromRelation?: FromRelationContext<Schema>
) {
): Promise<Result> {
const modelDef = this.requireModel(model);

let relationKeyPairs: { fk: string; pk: string }[] = [];
Expand Down Expand Up @@ -713,8 +757,15 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
.$if(!!input.skipDuplicates, (qb) =>
qb.onConflict((oc) => oc.doNothing())
);
const result = await query.executeTakeFirstOrThrow();
return { count: Number(result.numInsertedOrUpdatedRows) };

if (!returnData) {
const result = await query.executeTakeFirstOrThrow();
return { count: Number(result.numInsertedOrUpdatedRows) } as Result;
} else {
const idFields = getIdFields(this.schema, model);
const result = await query.returning(idFields as any).execute();
return result as Result;
}
}

private fillGeneratedValues(modelDef: ModelDef, data: object) {
Expand All @@ -724,10 +775,10 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
if (!(field in data)) {
if (
typeof fields[field]?.default === 'object' &&
'call' in fields[field].default
'kind' in fields[field].default
) {
const generated = this.evalGenerator(fields[field].default);
if (generated) {
if (generated !== undefined) {
values[field] = generated;
}
} else if (fields[field]?.updatedAt) {
Expand All @@ -738,15 +789,40 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
return values;
}

private evalGenerator(defaultProvider: FieldDefaultProvider) {
return match(defaultProvider.call)
.with('cuid', () => createId())
.with('uuid', () =>
defaultProvider.args?.[0] === 7 ? uuid.v7() : uuid.v4()
)
.with('nanoid', () => nanoid(defaultProvider.args?.[0]))
.with('ulid', () => ulid())
.otherwise(() => undefined);
private evalGenerator(defaultValue: Expression) {
if (Expression.isCall(defaultValue)) {
return match(defaultValue.function)
.with('cuid', () => createId())
.with('uuid', () =>
defaultValue.args?.[0] &&
Expression.isLiteral(defaultValue.args?.[0]) &&
defaultValue.args[0].value === 7
? uuid.v7()
: uuid.v4()
)
.with('nanoid', () =>
defaultValue.args?.[0] &&
Expression.isLiteral(defaultValue.args[0]) &&
typeof defaultValue.args[0].value === 'number'
? nanoid(defaultValue.args[0].value)
: nanoid()
)
.with('ulid', () => ulid())
.otherwise(() => undefined);
} else if (
Expression.isMember(defaultValue) &&
Expression.isCall(defaultValue.receiver) &&
defaultValue.receiver.function === 'auth'
) {
// `auth()` member access
let val: any = this.client.$auth;
for (const member of defaultValue.members) {
val = val?.[member];
}
return val ?? null;
} else {
return undefined;
}
}

protected async update(
Expand Down Expand Up @@ -1023,6 +1099,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
kysely,
fieldModel,
value as { data: any; skipDuplicates: boolean },
false,
fromRelationContext
)
);
Expand Down
46 changes: 43 additions & 3 deletions packages/runtime/src/client/crud/operations/create.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import { match } from 'ts-pattern';
import { RejectedByPolicyError } from '../../../plugins/policy/errors';
import type { GetModels, SchemaDef } from '../../../schema';
import type { CreateArgs, CreateManyArgs } from '../../crud-types';
import type {
CreateArgs,
CreateManyAndReturnArgs,
CreateManyArgs,
} from '../../crud-types';
import { getIdValues } from '../../query-utils';
import { BaseOperationHandler } from './base';

export class CreateOperationHandler<
Schema extends SchemaDef
> extends BaseOperationHandler<Schema> {
async handle(
operation: 'create' | 'createMany',
operation: 'create' | 'createMany' | 'createManyAndReturn',
args: unknown | undefined
) {
return match(operation)
Expand All @@ -23,6 +27,14 @@ export class CreateOperationHandler<
this.inputValidator.validateCreateManyArgs(this.model, args)
);
})
.with('createManyAndReturn', () => {
return this.runCreateManyAndReturn(
this.inputValidator.validateCreateManyAndReturnArgs(
this.model,
args
)
);
})
.exhaustive();
}

Expand Down Expand Up @@ -50,6 +62,34 @@ export class CreateOperationHandler<
if (args === undefined) {
return { count: 0 };
}
return this.createMany(this.kysely, this.model, args);
return this.createMany(this.kysely, this.model, args, false);
}

private async runCreateManyAndReturn(
args?: CreateManyAndReturnArgs<Schema, GetModels<Schema>>
) {
if (args === undefined) {
return [];
}

// TODO: avoid using transaction for simple create
return this.safeTransaction(async (tx) => {
const createResult = await this.createMany(
tx,
this.model,
args,
true
);
return this.read(tx, this.model, {
select: args.select,
omit: args.omit,
where: {
OR: createResult.map(
(item) =>
getIdValues(this.schema, this.model, item) as any
),
} as any, // TODO: fix type
});
});
}
}
Loading