diff --git a/TODO.md b/TODO.md index 84a13f6a7..7a277af8d 100644 --- a/TODO.md +++ b/TODO.md @@ -23,7 +23,7 @@ - [x] Counting relation - [x] Pagination - [x] Skip and limit - - [ ] Cursor + - [x] Cursor - [x] Filtering - [x] Unique fields - [x] Scalar fields diff --git a/packages/runtime/src/client/client-impl.ts b/packages/runtime/src/client/client-impl.ts index 8de6ff217..8509487df 100644 --- a/packages/runtime/src/client/client-impl.ts +++ b/packages/runtime/src/client/client-impl.ts @@ -281,7 +281,7 @@ function createModelCrudHandler< } let result: unknown; if (r && postProcess) { - result = resultProcessor.processResult(r, model); + result = resultProcessor.processResult(r, model, args); } else { result = r ?? null; } diff --git a/packages/runtime/src/client/crud-types.ts b/packages/runtime/src/client/crud-types.ts index 4a6a48a88..d1a33c8f6 100644 --- a/packages/runtime/src/client/crud-types.ts +++ b/packages/runtime/src/client/crud-types.ts @@ -362,6 +362,10 @@ type Distinct> = { distinct?: OrArray>; }; +type Cursor> = { + cursor?: WhereUnique; +}; + type Select< Schema extends SchemaDef, Model extends GetModels, @@ -570,7 +574,8 @@ export type FindArgs< } : {}) & SelectIncludeOmit & - Distinct; + Distinct & + Cursor; export type FindUniqueArgs< Schema extends SchemaDef, diff --git a/packages/runtime/src/client/crud/dialects/postgresql.ts b/packages/runtime/src/client/crud/dialects/postgresql.ts index 8523c86d1..28114760d 100644 --- a/packages/runtime/src/client/crud/dialects/postgresql.ts +++ b/packages/runtime/src/client/crud/dialects/postgresql.ts @@ -53,7 +53,7 @@ export class PostgresCrudDialect< ); return joinedQuery.select( - `${parentAlias}$${relationField}.data as ${relationField}` + `${parentAlias}$${relationField}.$j as ${relationField}` ); } @@ -82,12 +82,12 @@ export class PostgresCrudDialect< // however if there're filter/orderBy/take/skip, // we need to build a subquery to handle them before aggregation - if (payload && typeof payload === 'object') { - result = eb.selectFrom(() => { - let subQuery = eb - .selectFrom(`${relationModel}`) - .selectAll(); + result = eb.selectFrom(() => { + let subQuery = eb + .selectFrom(`${relationModel}`) + .selectAll(); + if (payload && typeof payload === 'object') { if (payload.where) { subQuery = subQuery.where((eb) => this.buildFilter( @@ -118,9 +118,27 @@ export class PostgresCrudDialect< skip !== undefined || take !== undefined, negateOrderBy ); - return subQuery.as(joinTableName); - }); - } + } + + // add join conditions + const joinPairs = buildJoinPairs( + this.schema, + model, + parentName, + relationField, + relationModel + ); + subQuery = subQuery.where((eb) => + this.and( + eb, + ...joinPairs.map(([left, right]) => + eb(sql.ref(left), '=', sql.ref(right)) + ) + ) + ); + + return subQuery.as(joinTableName); + }); result = this.buildRelationObjectSelect( relationModel, @@ -131,23 +149,6 @@ export class PostgresCrudDialect< parentName ); - // add join conditions - const joinPairs = buildJoinPairs( - this.schema, - model, - parentName, - relationField, - joinTableName - ); - result = result.where((eb) => - this.and( - eb, - ...joinPairs.map(([left, right]) => - eb(sql.ref(left), '=', sql.ref(right)) - ) - ) - ); - // add nested joins for each relation result = this.buildRelationJoins( relationModel, @@ -189,9 +190,9 @@ export class PostgresCrudDialect< )}))`, sql`'[]'::jsonb` ) - .as('data'); + .as('$j'); } else { - return sql`jsonb_build_object(${sql.join(objArgs)})`.as('data'); + return sql`jsonb_build_object(${sql.join(objArgs)})`.as('$j'); } }); @@ -267,7 +268,7 @@ export class PostgresCrudDialect< .filter(([, value]) => value) .map(([field]) => [ sql.lit(field), - eb.ref(`${parentName}$${relationField}$${field}.data`), + eb.ref(`${parentName}$${relationField}$${field}.$j`), ]) .flatMap((v) => v) ); diff --git a/packages/runtime/src/client/crud/dialects/sqlite.ts b/packages/runtime/src/client/crud/dialects/sqlite.ts index ff50d9cf0..5b836e2bb 100644 --- a/packages/runtime/src/client/crud/dialects/sqlite.ts +++ b/packages/runtime/src/client/crud/dialects/sqlite.ts @@ -74,25 +74,12 @@ export class SqliteCrudDialect< const relationModel = relationFieldDef.type as GetModels; const relationModelDef = requireModel(this.schema, relationModel); - const { keyPairs, ownedByModel } = getRelationForeignKeyFieldPairs( - this.schema, - model, - relationField - ); - const subQueryName = `${parentName}$${relationField}`; - // simple select by default - let tbl: SelectQueryBuilder = eb.selectFrom( - `${relationModel} as ${subQueryName}` - ); - - // however if there're filter/orderBy/take/skip, - // we need to build a subquery to handle them before aggregation - if (payload && typeof payload === 'object') { - tbl = eb.selectFrom(() => { - let subQuery = eb.selectFrom(relationModel).selectAll(); + let tbl = eb.selectFrom(() => { + let subQuery = eb.selectFrom(relationModel).selectAll(); + if (payload && typeof payload === 'object') { if (payload.where) { subQuery = subQuery.where((eb) => this.buildFilter( @@ -123,10 +110,33 @@ export class SqliteCrudDialect< skip !== undefined || take !== undefined, negateOrderBy ); + } - return subQuery.as(subQueryName); + // join conditions + const { keyPairs, ownedByModel } = getRelationForeignKeyFieldPairs( + this.schema, + model, + relationField + ); + keyPairs.forEach(({ fk, pk }) => { + if (ownedByModel) { + // the parent model owns the fk + subQuery = subQuery.whereRef( + `${relationModel}.${pk}`, + '=', + `${parentName}.${fk}` + ); + } else { + // the relation side owns the fk + subQuery = subQuery.whereRef( + `${relationModel}.${fk}`, + '=', + `${parentName}.${pk}` + ); + } }); - } + return subQuery.as(subQueryName); + }); tbl = tbl.select(() => { type ArgsType = @@ -227,30 +237,12 @@ export class SqliteCrudDialect< )}))`, sql`json_array()` ) - .as('data'); + .as('$j'); } else { return sql`json_object(${sql.join(objArgs)})`.as('data'); } }); - // join conditions - keyPairs.forEach(({ fk, pk }) => { - if (ownedByModel) { - // the parent model owns the fk - tbl = tbl.whereRef( - `${parentName}$${relationField}.${pk}`, - '=', - `${parentName}.${fk}` - ); - } else { - // the relation side owns the fk - tbl = tbl.whereRef( - `${parentName}$${relationField}.${fk}`, - '=', - `${parentName}.${pk}` - ); - } - }); return tbl; } diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index 6ac1ac778..f9d7a11eb 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -2,6 +2,7 @@ import { createId } from '@paralleldrive/cuid2'; import { DeleteResult, expressionBuilder, + ExpressionWrapper, sql, UpdateResult, type ExpressionBuilder, @@ -29,7 +30,12 @@ import { } from '../../../utils/object-utils'; import { CONTEXT_COMMENT_PREFIX } from '../../constants'; import type { CRUD } from '../../contract'; -import type { FindArgs, SelectIncludeOmit, WhereInput } from '../../crud-types'; +import type { + FindArgs, + SelectIncludeOmit, + SortOrder, + WhereInput, +} from '../../crud-types'; import { InternalError, NotFoundError, QueryError } from '../../errors'; import type { ToKysely } from '../../query-builder'; import { @@ -44,8 +50,10 @@ import { isForeignKeyField, isRelationField, isScalarField, + makeDefaultOrderBy, requireField, requireModel, + safeJSONStringify, } from '../../query-utils'; import { getCrudDialect } from '../dialects'; import type { BaseCrudDialect } from '../dialects/base'; @@ -205,33 +213,46 @@ export abstract class BaseOperationHandler { ); } + if (args?.cursor) { + query = this.buildCursorFilter( + model, + query, + args.cursor, + args.orderBy, + negateOrderBy + ); + } + query = query.modifyEnd( this.makeContextComment({ model, operation: 'read' }) ); + let result: any[] = []; try { - let result = await query.execute(); - if (inMemoryDistinct) { - const distinctResult: Record[] = []; - const seen = new Set(); - for (const r of result as any[]) { - const key = JSON.stringify( - inMemoryDistinct.map((f) => r[f]) - )!; - if (!seen.has(key)) { - distinctResult.push(r); - seen.add(key); - } - } - result = distinctResult; - } - return result; + result = await query.execute(); } catch (err) { const { sql, parameters } = query.compile(); throw new QueryError( `Failed to execute query: ${err}, sql: ${sql}, parameters: ${parameters}` ); } + + if (inMemoryDistinct) { + const distinctResult: Record[] = []; + const seen = new Set(); + for (const r of result as any[]) { + const key = safeJSONStringify( + inMemoryDistinct.map((f) => r[f]) + )!; + if (!seen.has(key)) { + distinctResult.push(r); + seen.add(key); + } + } + result = distinctResult; + } + + return result; } protected async readUnique( @@ -408,6 +429,58 @@ export abstract class BaseOperationHandler { } } + private buildCursorFilter( + model: string, + query: SelectQueryBuilder, + cursor: FindArgs, true>['cursor'], + orderBy: FindArgs, true>['orderBy'], + negateOrderBy: boolean + ) { + if (!orderBy) { + orderBy = makeDefaultOrderBy(this.schema, model); + } + + const orderByItems = ensureArray(orderBy).flatMap((obj) => + Object.entries(obj) + ); + + const eb = expressionBuilder(); + const cursorFilter = this.dialect.buildFilter(eb, model, model, cursor); + + let result = query; + let filters: ExpressionWrapper[] = []; + + for (let i = orderByItems.length - 1; i >= 0; i--) { + const andFilters: ExpressionWrapper[] = []; + + for (let j = 0; j <= i; j++) { + const [field, order] = orderByItems[j]!; + const _order = negateOrderBy + ? order === 'asc' + ? 'desc' + : 'asc' + : order; + const op = j === i ? (_order === 'asc' ? '>=' : '<=') : '='; + andFilters.push( + eb( + eb.ref(`${model}.${field}`), + op, + eb + .selectFrom(model) + .select(`${model}.${field}`) + .where(cursorFilter) + ) + ); + } + + filters.push(eb.and(andFilters)); + } + + result = result.where((eb) => eb.or(filters)); + + return result; + } + protected async create( kysely: ToKysely, model: GetModels, diff --git a/packages/runtime/src/client/crud/validator.ts b/packages/runtime/src/client/crud/validator.ts index 5c6ca91ce..f463666d1 100644 --- a/packages/runtime/src/client/crud/validator.ts +++ b/packages/runtime/src/client/crud/validator.ts @@ -158,6 +158,7 @@ export class InputValidator { fields['include'] = this.makeIncludeSchema(model).optional(); fields['omit'] = this.makeOmitSchema(model).optional(); fields['distinct'] = this.makeDistinctSchema(model).optional(); + fields['cursor'] = this.makeCursorSchema(model).optional(); if (collection) { fields['skip'] = z.number().int().nonnegative().optional(); @@ -626,6 +627,10 @@ export class InputValidator { return this.orArray(z.enum(nonRelationFields as any), true); } + private makeCursorSchema(model: string) { + return this.makeWhereSchema(model, true, true).optional(); + } + // #endregion // #region Create diff --git a/packages/runtime/src/client/query-utils.ts b/packages/runtime/src/client/query-utils.ts index 2fd37755a..15ad1676d 100644 --- a/packages/runtime/src/client/query-utils.ts +++ b/packages/runtime/src/client/query-utils.ts @@ -269,3 +269,13 @@ export function ensureArray(value: T | T[]): T[] { return [value]; } } + +export function safeJSONStringify(value: unknown) { + return JSON.stringify(value, (_, v) => { + if (typeof v === 'bigint') { + return v.toString(); + } else { + return v; + } + }); +} diff --git a/packages/runtime/src/client/result-processor.ts b/packages/runtime/src/client/result-processor.ts index e737e818e..88a7ba9f1 100644 --- a/packages/runtime/src/client/result-processor.ts +++ b/packages/runtime/src/client/result-processor.ts @@ -3,12 +3,19 @@ import invariant from 'tiny-invariant'; import { match } from 'ts-pattern'; import type { FieldDef, GetModels, SchemaDef } from '../schema'; import type { BuiltinType } from '../schema/schema'; -import { getField } from './query-utils'; +import { ensureArray, getField } from './query-utils'; export class ResultProcessor { constructor(private readonly schema: Schema) {} - processResult(data: any, model: GetModels) { + processResult(data: any, model: GetModels, args?: any) { + const result = this.doProcessResult(data, model); + // deal with correcting the reversed order due to negative take + this.fixReversedResult(result, model, args); + return result; + } + + private doProcessResult(data: any, model: GetModels) { if (Array.isArray(data)) { data.forEach((row, i) => (data[i] = this.processRow(row, model))); return data; @@ -65,7 +72,7 @@ export class ResultProcessor { return value; } } - return this.processResult( + return this.doProcessResult( relationData, fieldDef.type as GetModels ); @@ -122,4 +129,38 @@ export class ResultProcessor { private transformBytes(value: unknown) { return Buffer.isBuffer(value) ? Uint8Array.from(value) : value; } + + private fixReversedResult(data: any, model: GetModels, args: any) { + if ( + Array.isArray(data) && + typeof args === 'object' && + args && + args.take !== undefined && + args.take < 0 + ) { + data.reverse(); + } + + const selectInclude = args?.include ?? args?.select; + if (!selectInclude) { + return; + } + + for (const row of ensureArray(data)) { + for (const [field, value] of Object.entries(selectInclude)) { + if (typeof value !== 'object' || !value) { + continue; + } + const fieldDef = getField(this.schema, model, field); + if (!fieldDef?.relation) { + continue; + } + this.fixReversedResult( + row[field], + fieldDef.type as GetModels, + value + ); + } + } + } } diff --git a/packages/runtime/test/client-api/find.test.ts b/packages/runtime/test/client-api/find.test.ts index f841b1b2a..2c5a89a12 100644 --- a/packages/runtime/test/client-api/find.test.ts +++ b/packages/runtime/test/client-api/find.test.ts @@ -102,29 +102,29 @@ describe.each(createClientSpecs(PG_DB_NAME))( // negative take, default sort is negated await expect( - client.user.findMany({ take: -1 }) - ).toResolveWithLength(1); - await expect(client.user.findMany({ take: -1 })).resolves.toEqual( - expect.arrayContaining([expect.objectContaining({ id: '3' })]) + client.user.findMany({ take: -2 }) + ).toResolveWithLength(2); + await expect(client.user.findMany({ take: -2 })).resolves.toEqual( + expect.arrayContaining([ + expect.objectContaining({ id: '3' }), + expect.objectContaining({ id: '2' }), + ]) ); await expect( client.user.findMany({ skip: 1, take: -1 }) - ).resolves.toEqual( - expect.arrayContaining([expect.objectContaining({ id: '2' })]) - ); + ).resolves.toEqual([expect.objectContaining({ id: '2' })]); // negative take, explicit sort is negated await expect( client.user.findMany({ - skip: 2, - take: -1, + skip: 1, + take: -2, orderBy: { email: 'asc' }, }) - ).resolves.toEqual( - expect.arrayContaining([ - expect.objectContaining({ email: 'u02@test.com' }), - ]) - ); + ).resolves.toEqual([ + expect.objectContaining({ email: 'u02@test.com' }), + expect.objectContaining({ email: 'u1@test.com' }), + ]); }); it('works with orderBy', async () => { @@ -201,6 +201,79 @@ describe.each(createClientSpecs(PG_DB_NAME))( ).resolves.toMatchObject(user2); }); + it('works with cursor', async () => { + const user1 = await createUser(client, 'u1@test.com', { + id: '1', + role: 'ADMIN', + }); + const user2 = await createUser(client, 'u2@test.com', { + id: '2', + role: 'USER', + }); + const user3 = await createUser(client, 'u3@test.com', { + id: '3', + role: 'ADMIN', + }); + + // cursor is inclusive + await expect( + client.user.findMany({ + cursor: { id: user2.id }, + }) + ).resolves.toEqual([user2, user3]); + + // skip cursor + await expect( + client.user.findMany({ + skip: 1, + cursor: { id: user1.id }, + }) + ).resolves.toEqual([user2, user3]); + + // custom orderBy + await expect( + client.user.findMany({ + skip: 1, + cursor: { id: user2.id }, + orderBy: { email: 'desc' }, + }) + ).resolves.toEqual([user1]); + + // multiple orderBy + await expect( + client.user.findMany({ + skip: 1, + cursor: { id: user1.id }, + orderBy: [{ role: 'desc' }, { id: 'asc' }], + }) + ).resolves.toEqual([user3]); + + // multiple cursor + await expect( + client.user.findMany({ + skip: 1, + cursor: { id: user1.id, role: 'ADMIN' }, + }) + ).resolves.toEqual([user2, user3]); + + // non-existing cursor + await expect( + client.user.findMany({ + skip: 1, + cursor: { id: 'none' }, + }) + ).resolves.toEqual([]); + + // backward from cursor + await expect( + client.user.findMany({ + skip: 1, + take: -2, + cursor: { id: user3.id }, + }) + ).resolves.toEqual([user1, user2]); + }); + it('works with distinct', async () => { await createUser(client, 'u1@test.com', { name: 'Admin1', @@ -243,6 +316,48 @@ describe.each(createClientSpecs(PG_DB_NAME))( ); }); + it('works with nested skip, take, orderBy', async () => { + await createUser(client, 'u1@test.com', { + posts: { + create: [ + { id: '1', title: 'Post1' }, + { id: '2', title: 'Post2' }, + { id: '3', title: 'Post3' }, + ], + }, + }); + + await expect( + client.user.findFirst({ + include: { + posts: { orderBy: { title: 'desc' }, skip: 2, take: 1 }, + }, + }) + ).resolves.toEqual( + expect.objectContaining({ + posts: [expect.objectContaining({ id: '1' })], + }) + ); + + await expect( + client.user.findFirst({ + include: { + posts: { + skip: 1, + take: -2, + }, + }, + }) + ).resolves.toEqual( + expect.objectContaining({ + posts: [ + expect.objectContaining({ id: '1' }), + expect.objectContaining({ id: '2' }), + ], + }) + ); + }); + it('works with unique finds', async () => { let r = await client.user.findUnique({ where: { id: 'none' } }); expect(r).toBeNull(); diff --git a/packages/runtime/test/client-api/group-by.test.ts b/packages/runtime/test/client-api/group-by.test.ts index 64db3dcb8..24e7a40dd 100644 --- a/packages/runtime/test/client-api/group-by.test.ts +++ b/packages/runtime/test/client-api/group-by.test.ts @@ -60,6 +60,18 @@ describe.each(createClientSpecs(__filename))( }) ).resolves.toEqual([{ email: 'u1@test.com' }]); + await expect( + client.user.groupBy({ + by: ['email'], + skip: 1, + take: -2, + orderBy: { email: 'desc' }, + }) + ).resolves.toEqual([ + { email: 'u2@test.com' }, + { email: 'u1@test.com' }, + ]); + await expect( client.user.groupBy({ by: ['name'],