diff --git a/packages/bson/src/bson-parser.ts b/packages/bson/src/bson-parser.ts index 108d6bbc9..8ff9026e6 100644 --- a/packages/bson/src/bson-parser.ts +++ b/packages/bson/src/bson-parser.ts @@ -15,7 +15,7 @@ import { digitByteSize, TWO_PWR_32_DBL_N, } from './utils.js'; -import { decodeUTF8 } from './strings.js'; +import { decodeUTF8 } from '@deepkit/core'; import { nodeBufferToArrayBuffer, ReflectionKind, SerializationError, Type } from '@deepkit/type'; import { hexTable } from './model.js'; @@ -316,7 +316,7 @@ export class BaseParser { } /** - * Size includes the \0. If not existend, increase by 1. + * Size includes the \0. If not existent, increase by 1. */ eatString(size: number): string { this.offset += size; diff --git a/packages/bson/src/strings.ts b/packages/bson/src/strings.ts deleted file mode 100644 index 811e4e7fb..000000000 --- a/packages/bson/src/strings.ts +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Deepkit Framework - * Copyright (C) 2021 Deepkit UG, Marc J. Schmidt - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the MIT License. - * - * You should have received a copy of the MIT License along with this program. - */ - -import { BSONError } from './model.js'; - -const decoder = new TextDecoder('utf-8'); - -export function decodeUTF8(buffer: Uint8Array, off: number = 0, end: number) { - if (end - off > 512) { - return decoder.decode(buffer.slice(off, end)); - } else { - return decodeUTF8Short(buffer, off, end); - } -} - -export function decodeUTF8Short(buffer: Uint8Array, off: number = 0, end: number) { - let s = ''; - while (off < end) { - let c = buffer[off++]; - - if (c > 127) { - if (c > 191 && c < 224) { - if (off >= end) - throw new BSONError('UTF-8 decode: incomplete 2-byte sequence'); - c = (c & 31) << 6 | buffer[off++] & 63; - } else if (c > 223 && c < 240) { - if (off + 1 >= end) - throw new BSONError('UTF-8 decode: incomplete 3-byte sequence'); - c = (c & 15) << 12 | (buffer[off++] & 63) << 6 | buffer[off++] & 63; - } else if (c > 239 && c < 248) { - if (off + 2 >= end) - throw new BSONError('UTF-8 decode: incomplete 4-byte sequence'); - c = (c & 7) << 18 | (buffer[off++] & 63) << 12 | (buffer[off++] & 63) << 6 | buffer[off++] & 63; - } else throw new BSONError('UTF-8 decode: unknown multibyte start 0x' + c.toString(16) + ' at index ' + (off - 1)); - if (c <= 0xffff) { - s += String.fromCharCode(c); - } else if (c <= 0x10ffff) { - c -= 0x10000; - s += String.fromCharCode(c >> 10 | 0xd800, c & 0x3FF | 0xdc00); - } else throw new BSONError('UTF-8 decode: code point 0x' + c.toString(16) + ' exceeds UTF-16 reach'); - } else { - if (c === 0) { - return s; - } - - s += String.fromCharCode(c); - } - } - return s; -} diff --git a/packages/core/src/string.ts b/packages/core/src/string.ts index cb6f55695..3296d2ebf 100644 --- a/packages/core/src/string.ts +++ b/packages/core/src/string.ts @@ -17,3 +17,57 @@ export function indent(indentation: number, prefix: string = '') { export function capitalize(string: string): string { return string.charAt(0).toUpperCase() + string.slice(1) } + +const decoder = new TextDecoder('utf-8'); + +const decodeUTF8Fast = 'undefined' !== typeof Buffer ? Buffer.prototype.utf8Slice : undefined; + +const decodeUTF8Big = decodeUTF8Fast ? (buffer: Uint8Array, off: number, end: number) => { + return decodeUTF8Fast.call(buffer, off, end); +} : (buffer: Uint8Array, off: number, end: number) => { + return decoder.decode(buffer.slice(off, end)); +} + +export function decodeUTF8(buffer: Uint8Array, off: number = 0, end: number) { + if (end - off > 64) { + return decodeUTF8Big(buffer, off, end); + } else { + return decodeUTF8Short(buffer, off, end); + } +} + +export function decodeUTF8Short(buffer: Uint8Array, off: number = 0, end: number) { + let s = ''; + while (off < end) { + let c = buffer[off++]; + + if (c > 127) { + if (c > 191 && c < 224) { + if (off >= end) + throw new Error('UTF-8 decode: incomplete 2-byte sequence'); + c = (c & 31) << 6 | buffer[off++] & 63; + } else if (c > 223 && c < 240) { + if (off + 1 >= end) + throw new Error('UTF-8 decode: incomplete 3-byte sequence'); + c = (c & 15) << 12 | (buffer[off++] & 63) << 6 | buffer[off++] & 63; + } else if (c > 239 && c < 248) { + if (off + 2 >= end) + throw new Error('UTF-8 decode: incomplete 4-byte sequence'); + c = (c & 7) << 18 | (buffer[off++] & 63) << 12 | (buffer[off++] & 63) << 6 | buffer[off++] & 63; + } else throw new Error('UTF-8 decode: unknown multibyte start 0x' + c.toString(16) + ' at index ' + (off - 1)); + if (c <= 0xffff) { + s += String.fromCharCode(c); + } else if (c <= 0x10ffff) { + c -= 0x10000; + s += String.fromCharCode(c >> 10 | 0xd800, c & 0x3FF | 0xdc00); + } else throw new Error('UTF-8 decode: code point 0x' + c.toString(16) + ' exceeds UTF-16 reach'); + } else { + if (c === 0) { + return s; + } + + s += String.fromCharCode(c); + } + } + return s; +} diff --git a/packages/core/tests/core.spec.ts b/packages/core/tests/core.spec.ts index 97be6018d..2d7ae5224 100644 --- a/packages/core/tests/core.spec.ts +++ b/packages/core/tests/core.spec.ts @@ -620,3 +620,20 @@ test('isGlobalClass', () => { expect(isGlobalClass(Uint8Array)).toBe(true); }); + +test('typed array offset', () => { + const a = new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + expect(a.byteOffset).toBe(0); + expect(a.length).toBe(10); + expect(a[0]).toBe(1); + + const b= new Uint8Array(a.buffer, 1); + expect(b.byteOffset).toBe(1); + expect(b.length).toBe(9); + expect(b[0]).toBe(2); + + const c = new Uint8Array(b.buffer, b.byteOffset + 1); + expect(c.byteOffset).toBe(2); + expect(c.length).toBe(8); + expect(c[0]).toBe(3); +}); diff --git a/packages/mongo/src/client/client.ts b/packages/mongo/src/client/client.ts index 7320ee966..089baff84 100644 --- a/packages/mongo/src/client/client.ts +++ b/packages/mongo/src/client/client.ts @@ -8,7 +8,13 @@ * You should have received a copy of the MIT License along with this program. */ -import { ConnectionRequest, MongoConnection, MongoConnectionPool, MongoDatabaseTransaction, MongoStats } from './connection.js'; +import { + ConnectionRequest, + MongoConnection, + MongoConnectionPool, + MongoDatabaseTransaction, + MongoStats, +} from './connection.js'; import { isErrorRetryableRead, isErrorRetryableWrite, MongoError } from './error.js'; import { sleep } from '@deepkit/core'; import { Command } from './command/command.js'; @@ -27,9 +33,7 @@ export class MongoClient { protected serializer: BSONBinarySerializer = mongoBinarySerializer; - constructor( - connectionString: string - ) { + constructor(connectionString: string) { this.config = new MongoClientConfig(connectionString); this.connectionPool = new MongoConnectionPool(this.config, this.serializer, this.stats); } diff --git a/packages/mongo/src/client/connection.ts b/packages/mongo/src/client/connection.ts index 2f9bb7d5f..2e9f3379c 100644 --- a/packages/mongo/src/client/connection.ts +++ b/packages/mongo/src/client/connection.ts @@ -63,8 +63,6 @@ export class MongoConnectionPool { protected queue: { resolve: (connection: MongoConnection) => void, request: ConnectionRequest }[] = []; - protected nextConnectionClose: Promise = Promise.resolve(true); - protected lastError?: Error; constructor( @@ -172,7 +170,7 @@ export class MongoConnectionPool { waiter.resolve(connection); //we don't set reserved/set cleanupTimeout, //since the connection is already reserved and the timeout - //is only set when the connection actually starting idling. + //is only set when the connection actually is starting idling. return; } @@ -435,7 +433,7 @@ export class MongoConnection { /** * Puts a command on the queue and executes it when queue is empty. - * A promises is return that is resolved with the when executed successfully, or rejected + * A promise is returned that is resolved when executed successfully, or rejected * when timed out, parser error, or any other error. */ public async execute>(command: T): Promise> { @@ -497,9 +495,9 @@ export class MongoConnection { writer.writeInt32(messageLength); //detect backPressure + this.socket.write(buffer); this.bytesSent += buffer.byteLength; this.onSent(buffer.byteLength); - this.socket.write(buffer); } catch (error) { console.log('failed sending message', message, 'for type', stringifyType(type)); throw error; diff --git a/packages/mongo/src/client/options.ts b/packages/mongo/src/client/options.ts index b372a98f4..10533753e 100644 --- a/packages/mongo/src/client/options.ts +++ b/packages/mongo/src/client/options.ts @@ -27,7 +27,7 @@ export class ConnectionOptions { journal?: string; appName?: string; - retryWrites: boolean = true; + retryWrites: boolean = false; retryReads: boolean = true; readConcernLevel: 'local' | 'majority' | 'linearizable' | 'available' = 'majority'; diff --git a/packages/mongo/src/query.resolver.ts b/packages/mongo/src/query.resolver.ts index cd2930034..532bd7944 100644 --- a/packages/mongo/src/query.resolver.ts +++ b/packages/mongo/src/query.resolver.ts @@ -8,7 +8,17 @@ * You should have received a copy of the MIT License along with this program. */ -import { DatabaseAdapter, DatabaseDeleteError, DatabasePatchError, DatabaseSession, DeleteResult, Formatter, GenericQueryResolver, OrmEntity, PatchResult } from '@deepkit/orm'; +import { + DatabaseAdapter, + DatabaseDeleteError, + DatabasePatchError, + DatabaseSession, + DeleteResult, + Formatter, + GenericQueryResolver, + OrmEntity, + PatchResult, +} from '@deepkit/orm'; import { Changes, getPartialSerializeFunction, @@ -314,7 +324,7 @@ export class MongoQueryResolver extends GenericQueryResolve } public async find(model: MongoQueryModel): Promise { - const formatter = this.createFormatter(model.withIdentityMap); + const formatter9 = this.createFormatter(model.withIdentityMap); const connection = await this.client.getConnection(undefined, this.session.assignedTransaction); try { diff --git a/packages/orm/src/database-session.ts b/packages/orm/src/database-session.ts index 4ea3d2226..3672a7e35 100644 --- a/packages/orm/src/database-session.ts +++ b/packages/orm/src/database-session.ts @@ -292,6 +292,7 @@ export abstract class DatabaseTransaction { export class DatabaseSession { public readonly id = SESSION_IDS++; public withIdentityMap = true; + public withChangeDetection = true; /** * When this session belongs to a transaction, then this is set. diff --git a/packages/orm/src/formatter.ts b/packages/orm/src/formatter.ts index 596bb41ea..0499ffc63 100644 --- a/packages/orm/src/formatter.ts +++ b/packages/orm/src/formatter.ts @@ -63,9 +63,12 @@ export class Formatter { protected serializer: Serializer, protected hydrator?: HydratorFn, protected identityMap?: IdentityMap, + protected withChangeDetection: boolean = true, ) { this.deserialize = getSerializeFunction(rootClassSchema.type, serializer.deserializeRegistry); this.partialDeserialize = getPartialSerializeFunction(rootClassSchema.type, serializer.deserializeRegistry); + if (identityMap) throw new Error('nope'); + if (withChangeDetection) throw new Error('nope'); } protected getInstancePoolForClass(classType: ClassType): Map { @@ -274,7 +277,7 @@ export class Formatter { const converted = this.createObject(model, classState, classSchema, dbRecord); if (!partial) { - if (model.withChangeDetection !== false) getInstanceState(classState, converted).markAsPersisted(); + if (this.withChangeDetection) getInstanceState(classState, converted).markAsPersisted(); if (pool) pool.set(pkHash, converted); if (this.identityMap) this.identityMap.store(classSchema, converted); } @@ -336,7 +339,7 @@ export class Formatter { : (partial ? getPartialSerializeFunction(classSchema.type, this.serializer.deserializeRegistry)(dbRecord) : getSerializeFunction(classSchema.type, this.serializer.deserializeRegistry)(dbRecord)); if (!partial) { - if (model.withChangeDetection !== false) getInstanceState(classState, converted).markAsFromDatabase(); + if (this.withChangeDetection) getInstanceState(classState, converted).markAsFromDatabase(); } // if (!partial && model.lazyLoad.size) { diff --git a/packages/orm/src/select.ts b/packages/orm/src/select.ts index ac5a56cfe..100b32ea8 100644 --- a/packages/orm/src/select.ts +++ b/packages/orm/src/select.ts @@ -268,7 +268,12 @@ export type OpExpression = { export type Op = ((...args: any[]) => OpExpression) & { id: symbol }; export function getStateCacheId(state: SelectorState): string { - const cacheId = state.schema.type.id + '_' + state.where?.[treeTag].id + '_' + state.orderBy?.map(v => v.a[treeTag].id).join(':'); + const cacheId = state.schema.type.id + + '_' + state.where?.[treeTag].id + + '_' + state.limit + + '_' + state.offset + + '_' + state.select.map(v => v[treeTag].id).join(':') + + '_' + state.orderBy?.map(v => v.a[treeTag].id).join(':'); //todo select also // todo join also return cacheId; @@ -545,6 +550,20 @@ export class Query2 { return this; } + /** + * When receiving full objects the change-detector is enabled by default + * to be able to calculate change sets for database.persist()/session.commit(). + * + * If disabled, it is impossible to send updates via database.persist()/session.commit(), + * and patchOne/patchMany has to be used. + * + * This is disabled per default for partial results. + */ + disableChangeDetection(): this { + this.state.withChangeDetection = false; + return this; + } + protected async callOnFetchEvent(query: Query2): Promise { const hasEvents = this.session.eventDispatcher.hasListeners(onFind); if (!hasEvents) return; @@ -607,24 +626,24 @@ export class Query2 { * @throws DatabaseError */ public async find(): Promise { - const frame = this.session - .stopwatch?.start('Find:' + this.classSchema.getClassName(), FrameCategory.database); + // const frame = this.session + // .stopwatch?.start('Find:' + this.classSchema.getClassName(), FrameCategory.database); try { - frame?.data({ - collection: this.classSchema.getCollectionName(), - className: this.classSchema.getClassName(), - }); - const eventFrame = this.session.stopwatch?.start('Events'); - await this.callOnFetchEvent(this); - this.onQueryResolve(this); - eventFrame?.end(); + // frame?.data({ + // collection: this.classSchema.getCollectionName(), + // className: this.classSchema.getClassName(), + // }); + // const eventFrame = this.session.stopwatch?.start('Events'); + // await this.callOnFetchEvent(this); + // this.onQueryResolve(this); + // eventFrame?.end(); return await this.resolver.find(this.state) as T[]; } catch (error: any) { await this.session.eventDispatcher.dispatch(onDatabaseError, new DatabaseErrorEvent(error, this.session, this.state.schema, this)); throw error; } finally { - frame?.end(); + // frame?.end(); } } diff --git a/packages/postgres/src/client.ts b/packages/postgres/src/client.ts new file mode 100644 index 000000000..a49cc6a8e --- /dev/null +++ b/packages/postgres/src/client.ts @@ -0,0 +1,1388 @@ +import { connect, createConnection, Socket } from 'net'; +import { arrayRemoveItem, asyncOperation, CompilerContext, decodeUTF8, formatError } from '@deepkit/core'; +import { DatabaseError, DatabaseTransaction, SelectorState } from '@deepkit/orm'; +import { + getSerializeFunction, + getTypeJitContainer, + isPropertyType, + ReceiveType, + ReflectionKind, + resolveReceiveType, + Type, + TypeClass, + TypeObjectLiteral, +} from '@deepkit/type'; +import { connect as createTLSConnection, TLSSocket } from 'tls'; +import { Host, PostgresClientConfig } from './config.js'; +import { DefaultPlatform, PreparedEntity } from '@deepkit/sql'; + +export class PostgresError extends DatabaseError { + +} + +export class PostgresConnectionError extends PostgresError { + +} + +export interface Result { + rows: any[]; + rowCount: number; +} + +function readUint32BE(data: Uint8Array, offset: number = 0): number { + return data[offset + 3] + (data[offset + 2] * 2 ** 8) + (data[offset + 1] * 2 ** 16) + (data[offset] * 2 ** 24); +} + +function readUint16BE(data: Uint8Array, offset: number = 0): number { + return data[offset + 1] + (data[offset] * 2 ** 8); +} + +function buildDeserializerForType(type: Type): (message: Uint8Array) => any { + if (type.kind !== ReflectionKind.class && type.kind !== ReflectionKind.objectLiteral) { + throw new Error('Invalid type for deserialization'); + } + + const context = new CompilerContext(); + const lines: string[] = []; + const props: string[] = []; + context.set({ + DataView, + decodeUTF8, + readUint32BE, + parseJson: JSON.parse, + }); + + for (const property of type.types) { + const varName = context.reserveVariable(); + if (!isPropertyType(property)) continue; + const field = property.type; + + if (field.kind === ReflectionKind.number) { + lines.push(` + length = readUint32BE(data, offset); + ${varName} = length === 4 ? view.getFloat32(offset + 4) : view.getFloat64(offset + 4); + offset += 4 + length; + `); + } + if (field.kind === ReflectionKind.boolean) { + lines.push(` + ${varName} = data[offset + 4] === 1; + offset += 4 + 1; + `); + } + if (field.kind === ReflectionKind.string) { + lines.push(` + length = readUint32BE(data, offset); + ${varName} = decodeUTF8(data, offset + 4, offset + 4 + length); + offset += 4 + length; + `); + } + if (field.kind === ReflectionKind.class || field.kind === ReflectionKind.union + || field.kind === ReflectionKind.array || field.kind === ReflectionKind.objectLiteral) { + lines.push(` + length = readUint32BE(data, offset); + ${varName} = parseJson(decodeUTF8(data, offset + 4 + 1, offset + 4 + length)); + offset += 4 + length; + `); + } + props.push(`${String(property.name)}: ${varName},`); + } + + const code = ` + const view = new DataView(data.buffer, data.byteOffset, data.byteLength); + let offset = 1 + 4 + 2; // Skip type, length, and field count + let length = 0; + ${lines.join('\n')} + + result.push({ + ${props.join('\n')} + }); + `; + return context.build(code, 'data', 'result'); +} + +function buildDeserializer(selector: SelectorState): (message: Uint8Array) => any { + const context = new CompilerContext(); + const lines: string[] = []; + const props: string[] = []; + context.set({ + DataView, + decodeUTF8, + readUint32BE, + parseJson: JSON.parse, + }); + + for (const field of selector.schema.getProperties()) { + const varName = context.reserveVariable(); + + if (field.type.kind === ReflectionKind.number) { + lines.push(` + length = readUint32BE(data, offset); + ${varName} = length === 4 ? view.getFloat32(offset + 4) : view.getFloat64(offset + 4); + offset += 4 + length; + `); + } + if (field.type.kind === ReflectionKind.boolean) { + lines.push(` + ${varName} = data[offset + 4] === 1; + offset += 4 + 1; + `); + } + if (field.type.kind === ReflectionKind.string) { + lines.push(` + length = readUint32BE(data, offset); + // ${varName} = decodeUTF8(data, offset + 4, offset + 4 + length); + ${varName} = ''; + offset += 4 + length; + `); + } + if (field.type.kind === ReflectionKind.class || field.type.kind === ReflectionKind.union + || field.type.kind === ReflectionKind.array || field.type.kind === ReflectionKind.objectLiteral) { + lines.push(` + length = readUint32BE(data, offset); + ${varName} = parseJson(decodeUTF8(data, offset + 4 + 1, offset + 4 + length)); + offset += 4 + length; + `); + } + props.push(`${field.name}: ${varName},`); + } + + const code = ` + const view = new DataView(data.buffer, data.byteOffset, data.byteLength); + let offset = 1 + 4 + 2; // Skip type, length, and field count + let length = 0; + ${lines.join('\n')} + + result.push({ + ${props.join('\n')} + }); + `; + return context.build(code, 'data', 'result'); +} + +export class PostgresClientPrepared { + created = false; + + cache?: Buffer; + deserialize: (message: Uint8Array) => any; + + constructor( + public client: PostgresClientConnection, + public sql: string, + public selector: SelectorState, + private statementName: string, + ) { + const serializer = this.selector.schema && this.selector.select.length === 0; + this.deserialize = serializer + ? buildDeserializer(this.selector) + : () => { + }; + } + + execute(params: any[]): Promise { + // console.log('execute statement', this.statementName, this.sql.slice(0, 100), params.slice(0, 10)) ; + return asyncOperation((resolve, reject) => { + if (!this.created) { + const oids = params.map(param => { + if (typeof param === 'number') return 23; + if (typeof param === 'boolean') return 16; + return 0; // Text + }); + + const message = sendParse(this.statementName, this.sql, oids); + this.client.write(message); + this.created = true; + } + + if (!this.cache) { + this.cache = Buffer.concat([ + sendBind('ab', this.statementName, params), + sendExecute('ab', 0), + syncMessage, + ]); + } + + // this.client.write(this.cache); + // + // const rows: any[] = []; + // this.deserialize(empty); //reset state machine + // this.client.deserialize = (message) => { + // try { + // const row = this.deserialize(message); + // rows.push(row); + // } catch (e) { + // console.log('e', e); + // reject(e); + // } + // }; + // this.client.ready = () => { + // resolve({ rows, rowCount: rows.length }); + // }; + // this.client.error = reject; + }); + } +} + +type Client = ReturnType; + +function int32Buffer(value) { + const buffer = Buffer.alloc(4); + buffer.writeUint32BE(value); + return buffer; +} + +function int16Buffer(value) { + const buffer = Buffer.alloc(2); + buffer.writeUint16BE(value); + return buffer; +} + +function sendParse(statementName: string, query: string, paramTypeOids: any[]) { + const nameBuffer = Buffer.from(statementName + '\0', 'utf8'); + const queryBuffer = Buffer.from(query + '\0', 'utf8'); + const typeCountBuffer = int16Buffer(paramTypeOids.length); + const typesBuffer = Buffer.concat(paramTypeOids.map(oid => int32Buffer(oid))); + + const message = Buffer.concat([ + Buffer.from('P'), // Parse message + Buffer.from([0, 0, 0, 1]), + nameBuffer, + queryBuffer, + typeCountBuffer, + typesBuffer, + ]); + message.writeUint32BE(message.length - 1, 1); + + return message; +} + +function sendBind(portalName: string, statementName: string, parameters: any[]) { + const portalBuffer = Buffer.from(portalName + '\0', 'utf8'); + const statementBuffer = Buffer.from(statementName + '\0', 'utf8'); + + const paramCountBuffer = int16Buffer(parameters.length); + + const paramBuffers = Buffer.concat(parameters.map(param => { + const paramBuffer = Buffer.from(param + '', 'utf8'); + const paramLengthBuffer = int32Buffer(paramBuffer.length); + return Buffer.concat([paramLengthBuffer, paramBuffer]); + })); + + const message = Buffer.concat([ + Buffer.from('B'), // Bind message + Buffer.from([0, 0, 0, 1]), + portalBuffer, + statementBuffer, + + Buffer.from([0, 1, 0, 0]), // (1, 1) = (amount, text) + // Buffer.from([0, 1, 0, 1]), // (1, 1) = (amount, binary), Int16 + (Int16[C]) + + paramCountBuffer, //Int16 + paramBuffers, //(Int32, Bytes)[] + Buffer.from([0, 1, 0, 1]), // Result column format codes (1, 1) = (amount, binary), Int16 + (Int16[C]) + ]); + message.writeUint32BE(message.length - 1, 1); + + return message; +} + +function getSimpleQueryBuffer(query: string) { + const queryBuffer = Buffer.from(query + '\0', 'utf8'); + const message = Buffer.concat([ + Buffer.from('Q'), // Query message + int32Buffer(queryBuffer.length + 4), + queryBuffer, + ]); + return message; +} + +function sendExecute(portalName: string, maxRows: number) { + const portalBuffer = Buffer.from(portalName + '\0', 'utf8'); + const maxRowsBuffer = int32Buffer(maxRows); + + const message = Buffer.concat([ + Buffer.from('E'), // Execute message + Buffer.from([0, 0, 0, 1]), + portalBuffer, + maxRowsBuffer, + ]); + message.writeUint32BE(message.length - 1, 1); + return message; +} + +const syncMessage = Buffer.from('S\0\0\0\x04'); + +export function readMessageSize(buffer: Uint8Array | ArrayBuffer, offset: number = 0): number { + return 1 + (buffer[offset + 4] + (buffer[offset + 3] * 2 ** 8) + (buffer[offset + 2] * 2 ** 16) + (buffer[offset + 1] * 2 ** 24)); +} + +export class ResponseParser { + protected currentMessage?: Uint8Array; + protected currentMessageSize: number = 0; + + constructor( + protected readonly onMessage: (response: Uint8Array) => void, + ) { + } + + public feed(data: Uint8Array, bytes?: number) { + if (!data.byteLength) return; + if (!bytes) bytes = data.byteLength; + + // console.log('got chunk', data.length); + if (!this.currentMessage) { + if (data.byteLength < 5) { + //not enough data to read the header. Wait for next onData + return; + } + this.currentMessage = data.byteLength === bytes ? data : data.subarray(0, bytes); + this.currentMessageSize = readMessageSize(data); + } else { + this.currentMessage = Buffer.concat([this.currentMessage, data.byteLength === bytes ? data : data.subarray(0, bytes)]); + if (!this.currentMessageSize) { + if (this.currentMessage.byteLength < 5) { + //not enough data to read the header. Wait for next onData + return; + } + this.currentMessageSize = readMessageSize(this.currentMessage); + } + } + + let currentSize = this.currentMessageSize; + let currentBuffer = this.currentMessage; + + while (currentBuffer) { + if (currentSize > currentBuffer.byteLength) { + //important to copy, since the incoming might change its data + this.currentMessage = currentBuffer; + // this.currentMessage = currentBuffer; + this.currentMessageSize = currentSize; + //message not completely loaded, wait for next onData + return; + } + + if (currentSize === currentBuffer.byteLength) { + //current buffer is exactly the message length + this.currentMessageSize = 0; + this.currentMessage = undefined; + this.onMessage(currentBuffer); + return; + } + + if (currentSize < currentBuffer.byteLength) { + //we have more messages in this buffer. read what is necessary and hop to next loop iteration + // const message = currentBuffer.subarray(0, currentSize); + // console.log('onMessage', currentSize, message.length) + this.onMessage(currentBuffer); + + currentBuffer = currentBuffer.subarray(currentSize); + // currentBuffer = new Uint8Array(currentBuffer.buffer, currentBuffer.byteOffset + currentSize); + if (currentBuffer.byteLength < 5) { + //not enough data to read the header. Wait for next onData + this.currentMessage = currentBuffer; + this.currentMessageSize = 0; + return; + } + + const nextCurrentSize = readMessageSize(currentBuffer); + if (nextCurrentSize <= 0) throw new Error('message size wrong'); + currentSize = nextCurrentSize; + //buffer and size has been set. consume this message in the next loop iteration + } + } + } +} + +// /** +// * @reflection never +// */ +// export class PostgresClient { +// connectionPromise?: Promise; +// client?: ReturnType; +// +// deserialize: (message: Uint8Array) => void = () => { +// }; +// ready?: (value: any) => void; +// error?: (value: any) => void; +// +// statementId = 0; +// +// responseParser: ResponseParser; +// +// constructor(public config: PostgresConfig) { +// this.responseParser = new ResponseParser(this.onResponse.bind(this)); +// } +// +// onResponse(data: Uint8Array) { +// switch (data[0]) { +// case 0x45: // E +// const length = readUint32BE(data, 1); +// const error = Buffer.from(data).toString('utf8', 5, length); +// if (this.error) this.error(error); +// break; +// case 0x44: // D +// this.deserialize(data); +// break; +// case 0x5A: // Z +// if (this.ready) this.ready({ rows: [], rowCount: 0 }); +// break; +// } +// } +// +// // onResponse(data: Uint8Array) { +// // // const view = new DataView(data.buffer, data.byteOffset, data.byteLength); +// // // const length = view.getUint32(offset + 1); +// // // const length = readUint32BE(data, 1); +// // +// // // console.log(`response: ${String.fromCharCode(data[0])}, buffer: ${data.length}, length: ${length}`); +// // +// // const type = String.fromCharCode(data[0]); +// // if (type === 'R') { +// // // const authType = view.getUint32(offset + 5); +// // // console.log(`Authentication Type: ${authType}`); +// // } else if (type === 'S') { +// // // const content = data.subarray(offset + 5, offset + length); +// // // const [param, value] = content.toString().split('\0'); +// // // console.log(`Parameter: ${param}, Value: ${value}`); +// // } else if (type === 'K') { +// // // const pid = view.getUint32(offset + 5); +// // // const secret = view.getUint32(offset + 9); +// // // console.log(`PID: ${pid}, Secret: ${secret}`); +// // } else if (type === 'Z') { +// // // console.log('Ready for query'); +// // if (this.ready) this.ready({ rows: [], rowCount: 0 }); +// // } else if (type === 'E') { +// // const length = readUint32BE(data, 1); +// // const error = Buffer.from(data).toString('utf8', 5, length); +// // // console.log('Error:', error); +// // if (this.error) this.error(error); +// // } else if (type === 'C') { +// // // const commandTag = Buffer.from(data).toString('utf8', offset + 5, offset + length); +// // // console.log('Command Complete:', commandTag); +// // } else if (type === '2') { +// // // console.log('Bind Complete'); +// // } else if (type === '1') { +// // // console.log('Parse Complete'); +// // } else if (type === 'D') { +// // // const columnValues = view.getUint16(offset + 5); +// // // console.log('Data Row:', columnValues); +// // // for (let i = 0; i < columnValues; i++) { +// // // const length = view.getUint32(offset + 7); +// // // const value = Buffer.from(data).toString('utf8', offset + 11, offset + 11 + length); +// // // console.log(`Column ${i} (${length}): ${value}`, Buffer.from(data).toString('hex', offset + 11, offset + 11 + length)); +// // // offset += 4 + length; +// // // } +// // this.deserialize(data); +// // +// // // const endOfMessage = offset + length; +// // // let fieldOffset = offset + 5; // Skip type and length +// // +// // // while (fieldOffset < endOfMessage) { +// // // const fieldLength = data.readInt32BE(fieldOffset); +// // // const fieldValue = data.toString('utf8', fieldOffset + 4, fieldOffset + 4 + fieldLength); +// // // console.log(`Field: ${fieldValue}`); +// // // fieldOffset += 4 + fieldLength; +// // // } +// // } else if (type === 'N') { +// // // const endOfMessage = offset + length; +// // // let fieldOffset = offset + 5; // Skip type and length +// // // console.log('Notice Response:'); +// // // while (fieldOffset < endOfMessage) { +// // // const fieldType = String.fromCharCode(data[fieldOffset]); +// // // const fieldValue = Buffer.from(data).toString('utf8', fieldOffset + 1, endOfMessage).split('\0')[0]; +// // // console.log(`${fieldType}: ${fieldValue}`); +// // // fieldOffset += 1 + Buffer.byteLength(fieldValue) + 1; // Move past the field type, value, and null terminator +// // // } +// // } else { +// // console.log('Unknown message type:', type); +// // } +// // } +// +// protected async doConnect(): Promise { +// const client = this.client = connect({ +// port: this.config.port || 5432, +// host: this.config.host || 'localhost', +// }, () => { +// const parameters = 'user\0postgres\0database\0postgres\0\0'; +// const protocolVersion = Buffer.from([0x00, 0x03, 0x00, 0x00]); // Version 3.0 +// const totalLength = Buffer.alloc(4); +// totalLength.writeInt32BE(Buffer.byteLength(parameters) + protocolVersion.length + 4); +// +// const startupMessage = Buffer.concat([totalLength, protocolVersion, Buffer.from(parameters)]); +// client.write(startupMessage); +// }); +// +// client.setNoDelay(true); +// // client.setKeepAlive(true, 0); +// +// return asyncOperation((resolve, reject) => { +// this.ready = resolve; +// this.error = reject; +// client.on('data', (data) => { +// this.responseParser.feed(data); +// }); +// +// client.on('end', () => { +// console.log('Disconnected from server'); +// }); +// }); +// } +// +// connect() { +// if (this.client) return; +// if (!this.connectionPromise) { +// this.connectionPromise = this.doConnect(); +// } +// +// return this.connectionPromise; +// } +// +// async query(sql: string, params: any[]): Promise { +// await this.connect(); +// return asyncOperation((resolve, reject) => { +// // const queryString = `${sql}\0`; +// // const length = Buffer.alloc(4); +// // length.writeInt32BE(Buffer.byteLength(queryString) + 4); // Length + size of length field +// // const queryMessage = Buffer.concat([Buffer.from([0x51]), length, Buffer.from(queryString)]); +// // console.log('Sending query:', sql, params); +// // this.client!.write(queryMessage); +// +// const oids = params.map(param => { +// if (typeof param === 'number') return 23; +// if (typeof param === 'boolean') return 16; +// return 0; // Text +// }); +// +// this.deserialize = (message) => { +// }; +// this.ready = resolve; +// this.error = reject; +// +// // console.log('query', sql.slice(0, 100), params.slice(0, 10)); +// const statement = `s${this.statementId++}`; +// const message = Buffer.concat([ +// sendParse(this.client!, statement, sql, oids), +// sendBind(this.client!, 'ab', statement, params), +// sendExecute(this.client!, 'ab', 0), +// syncMessage, +// ]); +// this.client!.write(message); +// +// // resolve({ rows: [], rowCount: 0 }); +// }); +// } +// +// async prepare(sql: string, selector: SelectorState): Promise { +// await this.connect(); +// const statement = `s${this.statementId++}`; +// return new PostgresClientPrepared(this, sql, selector, statement); +// } +// } + +export enum ConnectionStatus { + pending = 'pending', + connecting = 'connecting', + connected = 'connected', + disconnected = 'disconnected', +} + +export interface ConnectionRequest { + readonly: boolean; + nearest: boolean; +} + +export const enum HostType { + primary, + secondary, +} + +export type TransactionTypes = 'REPEATABLE READ' | 'READ COMMITTED' | 'SERIALIZABLE'; + +export class PostgresDatabaseTransaction extends DatabaseTransaction { + connection?: PostgresClientConnection; + + setTransaction?: TransactionTypes; + + /** + * This is the default for mysql databases. + */ + repeatableRead(): this { + this.setTransaction = 'REPEATABLE READ'; + return this; + } + + readCommitted(): this { + this.setTransaction = 'READ COMMITTED'; + return this; + } + + serializable(): this { + this.setTransaction = 'SERIALIZABLE'; + return this; + } + + async begin() { + if (!this.connection) return; + const set = this.setTransaction ? 'SET TRANSACTION ISOLATION LEVEL ' + this.setTransaction + ';' : ''; + // await this.connection.run(set + 'START TRANSACTION'); + } + + async commit() { + if (!this.connection) return; + if (this.ended) throw new Error('Transaction ended already'); + + // await this.connection.run('COMMIT'); + this.ended = true; + this.connection.release(); + } + + async rollback() { + if (!this.connection) return; + + if (this.ended) throw new Error('Transaction ended already'); + // await this.connection.run('ROLLBACK'); + this.ended = true; + this.connection.release(); + } +} + +type Deserializer = (message: Uint8Array) => any; + +function noopDeserializer(message: Uint8Array) { +} + +export abstract class Command { + current?: { resolve: Function, reject: Function, deserializer: Deserializer }; + + public write: (buffer: Uint8Array) => void = () => void 0; + + public sendAndWait(message: Uint8Array, deserializer: Deserializer = noopDeserializer): Promise { + this.write(message); + return asyncOperation((resolve, reject) => { + this.current = { + resolve, + reject, + deserializer, + }; + }); + } + + public async query(sql: string, type?: ReceiveType) { + let deserializer: any = () => { + }; + + if (type) { + type = resolveReceiveType(type); + const jit = getTypeJitContainer(type); + deserializer = jit.sqlQueryDeserializer; + if (!deserializer) { + deserializer = jit.sqlQueryDeserializer = buildDeserializerForType(type); + } + } + const buffer = getSimpleQueryBuffer(sql); + + const rows: any[] = []; + await this.sendAndWait(buffer, (buffer) => { + deserializer(buffer, rows); + }); + return rows[0]; + } + + abstract execute(host: Host, connection: PostgresClientConnection, transaction?: PostgresDatabaseTransaction): Promise; + + needsWritableHost(): boolean { + return false; + } + + handleResponse(response: Uint8Array): void { + switch (response[0]) { + case 0x45: { // E + const length = readUint32BE(response, 1); + const error = decodeUTF8(response, 5, length); + this.current!.reject(error); + break; + } + case 0x44: { // D + this.current!.deserializer(response); + // const columns = readUint16BE(response, 5); + // let offset = 1 + 4 + 2; + // console.log('columns', columns); + // for (let i = 0; i < columns; i++) { + // const length = readUint32BE(response, offset); + // console.log('column', i, length, Buffer.from(response).toString('hex', offset, offset + length)); + // offset += 4 + length; + // } + break; + } + case 0x5A: {// Z + this.current!.resolve(); + break; + } + default: { + // console.log('unhandled', String.fromCharCode(response[0]), response); + } + } + } +} + +export class InsertCommand extends Command { + constructor( + public platform: DefaultPlatform, + public prepared: PreparedEntity, + public items: any[], + ) { + super(); + } + + async execute(host: Host, client: PostgresClientConnection, transaction?: PostgresDatabaseTransaction) { + // build SQL + // command: parse + // command: bind + // command: execute + // command: sync + + const names: string[] = []; + const placeholder = new this.platform.placeholderStrategy; + const scopeSerializer = getSerializeFunction(this.prepared.type, this.platform.serializer.serializeRegistry); + + for (const property of this.prepared.fields) { + if (property.autoIncrement) continue; + names.push(property.columnNameEscaped); + } + + const placeholders: string[] = []; + const params: any[] = []; + for (const item of this.items) { + const converted = scopeSerializer(item); + const values: string[] = []; + for (const property of this.prepared.fields) { + if (property.autoIncrement) continue; + values.push(placeholder.getPlaceholder()); + params.push(converted[property.name]); + } + placeholders.push(values.join(', ')); + } + + const sql = `INSERT INTO ${this.prepared.tableNameEscaped} (${names.join(', ')}) + VALUES (${placeholders.join('), (')})`; + + const statement = ''; + + const message = Buffer.concat([ + sendParse(statement, sql, []), + sendBind('', statement, params), + sendExecute('', 0), + syncMessage, + ]); + await this.sendAndWait(message); + + // parse data rows: auto-incremented columns + + return 0; + } +} + +export class FindCommand extends Command { + cache?: Buffer; + deserialize: (message: Uint8Array, rows: any[]) => any; + created = false; + + constructor( + public sql: string, + public prepared: PreparedEntity, + public selector: SelectorState, + private statementName: string, + ) { + super(); + const serializer = this.selector.select.length === 0; + this.deserialize = serializer + ? buildDeserializer(this.selector) + : () => { + //todo more complex deserializer + }; + console.log('new FindCommand'); + } + + setParameters(params: any[]) { + if (!params.length) return; + this.cache = Buffer.concat([ + sendBind('ab', this.statementName, params), + sendExecute('ab', 0), + syncMessage, + ]); + } + + async execute(host: Host, connection: PostgresClientConnection, transaction?: PostgresDatabaseTransaction) { + if (!this.created) { + const message = sendParse(this.statementName, this.sql, []); + this.write(message); + this.created = true; + } + + if (!this.cache) { + this.cache = Buffer.concat([ + sendBind('ab', this.statementName, this.selector.params), + sendExecute('ab', 0), + syncMessage, + ]); + } + + const rows: any[] = []; + await this.sendAndWait(this.cache, (data) => this.deserialize(data, rows)); + return rows; + } +} + +type InRecorveryResponse = { inRecovery: boolean }; + +class SelectCommand extends Command { + constructor( + private sql: string, + private type?: TypeClass | TypeObjectLiteral, + ) { + super(); + } + + async execute(host: Host, connection: PostgresClientConnection, transaction?: PostgresDatabaseTransaction) { + const res = await this.query(this.sql, this.type); + return true; + } +} + +class HandshakeCommand extends Command { + host?: Host; + + async execute(host: Host, connection: PostgresClientConnection, transaction?: PostgresDatabaseTransaction) { + this.host = host; + const parameters = 'user\0postgres\0database\0postgres\0\0'; + const protocolVersion = Buffer.from([0x00, 0x03, 0x00, 0x00]); // Version 3.0 + const totalLength = Buffer.alloc(4); + totalLength.writeInt32BE(Buffer.byteLength(parameters) + protocolVersion.length + 4); + + const startupMessage = Buffer.concat([totalLength, protocolVersion, Buffer.from(parameters)]); + await this.sendAndWait(startupMessage); + return true; + } + + handleResponse(response: Uint8Array): void { + switch (response[0]) { + // S + case 0x53: { + if (!this.host) break; + const length = readUint32BE(response, 1); + const firstNull = response.indexOf(0, 5); + const name = decodeUTF8(response, 5, firstNull); + const value = decodeUTF8(response, firstNull + 1, length); + this.host.parameters[name] = value; + if (name === 'in_hot_standby' && value === 'on') { + this.host.setType(HostType.secondary); + } + return; + } + } + + super.handleResponse(response); + } +} + +export class PostgresStats { + /** + * How many connections have been created. + */ + connectionsCreated: number = 0; + + /** + * How many connections have been reused. + */ + connectionsReused: number = 0; + + /** + * How many connection requests were queued because pool was full. + */ + connectionsQueued: number = 0; + + bytesReceived: number = 0; + bytesSent: number = 0; +} + +export class PostgresConnectionPool { + protected connectionId: number = 0; + /** + * Connections, might be in any state, not necessarily connected. + */ + public connections: PostgresClientConnection[] = []; + + protected cacheHints: { [key: string]: PostgresClientConnection } = {}; + + protected queue: { resolve: (connection: PostgresClientConnection) => void, request: ConnectionRequest }[] = []; + + protected lastError?: Error; + + constructor( + protected config: PostgresClientConfig, + protected stats: PostgresStats, + ) { + } + + protected async waitForAllConnectionsToConnect(throws: boolean = false): Promise { + const promises: Promise[] = []; + for (const connection of this.connections) { + if (connection.connectingPromise) { + promises.push(connection.connectingPromise); + } + } + + if (!promises.length) return; + // try { + if (throws) { + await Promise.all(promises); + } else { + await Promise.allSettled(promises); + } + // } catch (error: any) { + // throw new PostgresConnectionError(`Failed to connect: ${formatError(error)}`, { cause: error }); + // } + } + + public async connect() { + await this.ensureHostsConnected(true); + } + + public close() { + //import to work on the copy, since Connection.onClose modifies this.connections. + const connections = this.connections.slice(0); + for (const connection of connections) { + connection.close(); + } + } + + protected ensureHostsConnectedPromise?: Promise; + + public async ensureHostsConnected(throws: boolean = false) { + if (this.ensureHostsConnectedPromise) return this.ensureHostsConnectedPromise; + //make sure each host has at least one connection + const hosts = await this.config.getHosts(); + for (const host of hosts) { + if (host.connections.length > 0) continue; + this.newConnection(host); + } + + return this.ensureHostsConnectedPromise = asyncOperation(async (resolve) => { + await this.waitForAllConnectionsToConnect(throws); + resolve(undefined); + }).then(() => { + this.ensureHostsConnectedPromise = undefined; + }); + } + + protected findHostForRequest(hosts: Host[], request: ConnectionRequest): Host { + //todo, handle request.nearest + for (const host of hosts) { + if (!request.readonly && host.isWritable()) return host; + if (request.readonly && host.isReadable()) return host; + } + + throw new PostgresConnectionError(`Could not find host for connection request. (readonly=${request.readonly}, hosts=${hosts.length}). Last Error: ${this.lastError}`); + } + + protected createAdditionalConnectionForRequest(request: ConnectionRequest): PostgresClientConnection { + const hosts = this.config.hosts; + const host = this.findHostForRequest(hosts, request); + + return this.newConnection(host); + } + + protected newConnection(host: Host): PostgresClientConnection { + this.stats.connectionsCreated++; + const connection = new PostgresClientConnection(this.connectionId++, host, this.config, (connection) => { + arrayRemoveItem(host.connections, connection); + arrayRemoveItem(this.connections, connection); + //onClose does not automatically reconnect. Only new commands re-establish connections. + }, (connection) => { + this.release(connection); + }, (bytesSent) => { + this.stats.bytesSent += bytesSent; + }, (bytesReceived) => { + this.stats.bytesReceived += bytesReceived; + }); + host.connections.push(connection); + this.connections.push(connection); + return connection; + } + + protected release(connection: PostgresClientConnection) { + for (let i = 0; i < this.queue.length; i++) { + const waiter = this.queue[i]; + if (!this.matchRequest(connection, waiter.request)) continue; + + this.stats.connectionsReused++; + this.queue.splice(i, 1); + waiter.resolve(connection); + //we don't set reserved/set cleanupTimeout, + //since the connection is already reserved and the timeout + //is only set when the connection actually is starting idling. + return; + } + + connection.reserved = false; + connection.cleanupTimeout = setTimeout(() => { + if (this.connections.length <= this.config.options.minPoolSize) { + return; + } + + connection.close(); + }, this.config.options.maxIdleTimeMS); + } + + protected matchRequest(connection: PostgresClientConnection, request: ConnectionRequest): boolean { + if (!request.readonly && !connection.host.isWritable()) return false; + + if (!request.readonly) { + if (connection.host.isSecondary() && !this.config.options.secondaryReadAllowed) return false; + if (!connection.host.isReadable()) return false; + } + + return true; + } + + /** + * Returns an existing or new connection, that needs to be released once done using it. + */ + async getConnection( + request: Partial = {}, + cacheHint?: string, + ): Promise { + const r = Object.assign({ readonly: false, nearest: false }, request) as ConnectionRequest; + + if (cacheHint) { + const connection = this.cacheHints[cacheHint]; + if (connection && connection.isConnected() && !connection.reserved) { + this.stats.connectionsReused++; + connection.reserved = true; + if (connection.cleanupTimeout) { + clearTimeout(connection.cleanupTimeout); + connection.cleanupTimeout = undefined; + } + + return connection; + } + } + + await this.ensureHostsConnected(true); + + for (const connection of this.connections) { + if (!connection.isConnected()) continue; + if (connection.reserved) continue; + + if (request.nearest) throw new PostgresConnectionError('Nearest not implemented yet'); + + if (!this.matchRequest(connection, r)) continue; + + this.stats.connectionsReused++; + connection.reserved = true; + if (connection.cleanupTimeout) { + clearTimeout(connection.cleanupTimeout); + connection.cleanupTimeout = undefined; + } + + if (cacheHint) this.cacheHints[cacheHint] = connection; + return connection; + } + + if (this.connections.length < this.config.options.maxPoolSize) { + const connection = this.createAdditionalConnectionForRequest(r); + connection.reserved = true; + if (cacheHint) this.cacheHints[cacheHint] = connection; + return connection; + } + + return asyncOperation((resolve) => { + this.stats.connectionsQueued++; + this.queue.push({ + resolve: (connection) => { + if (cacheHint) this.cacheHints[cacheHint] = connection; + resolve(connection); + }, request: r, + }); + }); + } +} + +export class PostgresClientConnection { + protected messageId: number = 0; + status: ConnectionStatus = ConnectionStatus.pending; + public bufferSize: number = 2.5 * 1024 * 1024; + + cache: { [key: string]: Command } = {}; + + released: boolean = false; + + public connectingPromise?: Promise; + public lastCommand?: { command: Command, promise?: Promise }; + + public activeCommands: number = 0; + public executedCommands: number = 0; + public activeTransaction: boolean = false; + public reserved: boolean = false; + public cleanupTimeout: any; + + protected socket: Socket | TLSSocket; + + public transaction?: PostgresDatabaseTransaction; + + responseParser: ResponseParser; + error?: Error; + + bytesReceived: number = 0; + bytesSent: number = 0; + + statementId: number = 0; + + protected boundWrite = this.write.bind(this); + + constructor( + public id: number, + public readonly host: Host, + protected config: PostgresClientConfig, + protected onClose: (connection: PostgresClientConnection) => void, + protected onRelease: (connection: PostgresClientConnection) => void, + protected onSent: (bytes: number) => void, + protected onReceived: (bytes: number) => void, + ) { + this.responseParser = new ResponseParser(this.onResponse.bind(this)); + + if (this.config.options.ssl === true) { + const options: { [name: string]: any } = { + host: host.hostname, + port: host.port, + timeout: config.options.connectTimeoutMS, + servername: host.hostname, + }; + const optional = { + ca: config.options.tlsCAFile, + key: config.options.tlsCertificateKeyFile || config.options.tlsCertificateFile, + cert: config.options.tlsCertificateFile, + passphrase: config.options.tlsCertificateKeyFilePassword, + + rejectUnauthorized: config.options.rejectUnauthorized, + crl: config.options.tlsCRLFile, + checkServerIdentity: config.options.checkServerIdentity ? undefined : () => undefined, + }; + for (const i in optional) { + if (optional[i]) options[i] = optional[i]; + } + + this.socket = createTLSConnection(options); + this.socket.on('data', (data) => { + this.bytesReceived += data.byteLength; + this.onReceived(data.byteLength); + this.responseParser.feed(data); + }); + } else { + this.socket = createConnection({ + host: host.hostname, + port: host.port, + timeout: config.options.connectTimeoutMS, + }); + + this.socket.on('data', (data) => { + this.bytesReceived += data.byteLength; + this.onReceived(data.byteLength); + this.responseParser.feed(data); + }); + + // const socket = this.socket = turbo.connect(host.port, host.hostname); + // // this.socket.setNoDelay(true); + // const buffer = Buffer.allocUnsafe(this.bufferSize); + // + // function read() { + // socket.read(buffer, onRead); + // } + // + // function onRead(err: any, buf: Buffer, bytes: number) { + // if (!bytes) return; + // responseParser.feed(buf, bytes); + // read(); + // } + // + // read(); + } + + //important to catch it, so it doesn't bubble up + this.connect().catch((error) => { + this.error = error; + this.socket.end(); + onClose(this); + }); + } + + getCache(key: string): Command | undefined { + return this.cache[key]; + } + + setCache(key: string, command: Command) { + this.cache[key] = command; + } + + isConnected() { + return this.status === ConnectionStatus.connected; + } + + isConnecting() { + return this.status === ConnectionStatus.connecting; + } + + close() { + this.status = ConnectionStatus.disconnected; + this.socket.end(); + } + + run(sql: string): Promise { + return this.execute(new SelectCommand(sql)); + } + + query(sql: string, params): any { + //todo this is not possible since we don't know the parameters types + } + + prepare(sql: string, selector: SelectorState): PostgresClientPrepared { + return new PostgresClientPrepared(this, sql, selector, 's' + this.statementId++); + } + + public release() { + //connections attached to a transaction are not automatically released. + //only with commit/rollback actions + if (this.transaction && !this.transaction.ended) return; + + if (this.transaction) this.transaction = undefined; + this.released = true; + this.onRelease(this); + } + + /** + * When a full message from the server was received. + */ + protected onResponse(message: Uint8Array) { + if (!this.lastCommand) throw new PostgresError(`Got a server response without active command`); + this.lastCommand.command.handleResponse(message); + } + + /** + * Puts a command on the queue and executes it when queue is empty. + * A promise is return that is resolved when executed successfully, or rejected + * when timed out, parser error, or any other error. + */ + public async execute>(command: T): Promise> { + if (this.status === ConnectionStatus.pending) await this.connect(); + if (this.status === ConnectionStatus.disconnected) throw new PostgresError('Disconnected'); + + if (this.lastCommand && this.lastCommand.promise) { + await this.lastCommand.promise; + } + + this.lastCommand = { command }; + this.activeCommands++; + this.executedCommands++; + command.write = this.boundWrite; + try { + this.lastCommand.promise = command.execute(this.host, this, this.transaction); + return await this.lastCommand.promise; + } finally { + this.lastCommand = undefined; + this.activeCommands--; + } + } + + write(message: Uint8Array) { + this.socket.write(message); + this.bytesSent += message.byteLength; + this.onSent(message.byteLength); + } + + async connect(): Promise { + if (this.status === ConnectionStatus.disconnected) throw new PostgresConnectionError('Connection disconnected'); + if (this.status !== ConnectionStatus.pending) return; + + this.status = ConnectionStatus.connecting; + + this.connectingPromise = asyncOperation(async (resolve, reject) => { + this.socket.once('close', (hadErrors) => { + this.onClose(this); + if (this.status !== ConnectionStatus.connecting) return; + this.status = ConnectionStatus.disconnected; + reject(new PostgresConnectionError('Connection closed while connecting')); + }); + + this.socket.once('lookup', (error) => { + if (this.status !== ConnectionStatus.connecting) return; + + if (error) { + this.connectingPromise = undefined; + this.status = ConnectionStatus.disconnected; + reject(new PostgresConnectionError(formatError(error), { cause: error })); + } + }); + + this.socket.once('error', (error) => { + if (this.status !== ConnectionStatus.connecting) return; + this.connectingPromise = undefined; + this.status = ConnectionStatus.disconnected; + reject(new PostgresConnectionError(formatError(error), { cause: error })); + }); + + if (this.socket.destroyed) { + this.status = ConnectionStatus.disconnected; + this.connectingPromise = undefined; + resolve(); + } + + this.socket.once('ready', async () => { + if (await this.execute(new HandshakeCommand())) { + this.status = ConnectionStatus.connected; + this.socket.setTimeout(this.config.options.socketTimeoutMS); + this.connectingPromise = undefined; + resolve(); + } else { + this.status = ConnectionStatus.disconnected; + this.connectingPromise = undefined; + reject(new PostgresConnectionError('Connection error: Could not complete handshake 🤷‍️')); + } + }); + }); + + return this.connectingPromise; + } +} + +export class PostgresClient { + pool: PostgresConnectionPool; + stats: PostgresStats = new PostgresStats(); + config: PostgresClientConfig; + + constructor(options: PostgresClientConfig | string) { + this.config = 'string' === typeof options ? new PostgresClientConfig(options) : options; + this.pool = new PostgresConnectionPool(this.config, this.stats); + } + + /** + * Returns an existing or new connection, that needs to be released once done using it. + */ + async getConnection(request: Partial = {}, transaction?: PostgresDatabaseTransaction): Promise { + if (transaction && transaction.connection) return transaction.connection; + const connection = await this.pool.getConnection(request); + + //todo check if this is correct and align with mongo + if (transaction) { + transaction.connection = connection; + connection.transaction = transaction; + try { + await transaction.begin(); + } catch (error) { + transaction.ended = true; + connection.release(); + throw new Error('Could not start transaction: ' + error); + } + } + return connection; + } +} diff --git a/packages/postgres/src/config.ts b/packages/postgres/src/config.ts index 47c0534cb..4a81baaec 100644 --- a/packages/postgres/src/config.ts +++ b/packages/postgres/src/config.ts @@ -1,43 +1,213 @@ -import { ConnectionOptions } from 'tls'; -import { PoolConfig } from 'pg'; +/* + * Deepkit Framework + * Copyright (C) 2021 Deepkit UG, Marc J. Schmidt + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the MIT License. + * + * You should have received a copy of the MIT License along with this program. + */ + +import { parse as parseUrl } from 'url'; +import { parse as parseQueryString } from 'querystring'; import { cast } from '@deepkit/type'; +import { HostType } from './client.js'; + +interface ConnectionInterface { + close(): void; +} + +export class Host { + protected type: HostType = HostType.primary; + + parameters: {[name: string]: string} = {}; + + protected typeSetAt?: Date; + + /** + * Round Trip Times of the `ismaster` command, for `nearest` + */ + protected rrt?: number; + + public readonly connections: ConnectionInterface[] = []; + + constructor( + public readonly config: PostgresClientConfig, + public readonly hostname: string, + public readonly port: number = 27017, + ) { + } + + get id() { + return `${this.hostname}:${this.port}`; + } + + isWritable(): boolean { + return this.type === HostType.primary; + } + + isSecondary(): boolean { + return this.type === HostType.secondary; + } + + isReadable(): boolean { + return this.type === HostType.primary || this.type === HostType.secondary; + } + + setType(type: HostType) { + if (this.type !== type) { + //type changed. Should we do anything special? + } + this.type = type; + this.typeSetAt = new Date; + } + + getType() { + return this.type; + } +} + +export class ConnectionOptions { + // replicaSet?: string; + + connectTimeoutMS: number = 10000; + socketTimeoutMS: number = 36000; + + // w?: string; + // wtimeoutMS?: number; + // journal?: string; -interface AdapterClientConfig { - user?: string; - database?: string; - password?: string; - port?: number; - host?: string; - connectionString?: string; - keepAlive?: boolean; - statement_timeout?: false | number; - parseInputDatesAsUTC?: boolean; - ssl?: boolean | ConnectionOptions; - query_timeout?: number; - keepAliveInitialDelayMillis?: number; - idle_in_transaction_session_timeout?: number; - application_name?: string; - connectionTimeoutMillis?: number; - - max?: number; - min?: number; - idleTimeoutMillis?: number; + // appName?: string; + + retryWrites: boolean = false; + retryReads: boolean = true; + + // readConcernLevel: 'local' | 'majority' | 'linearizable' | 'available' = 'majority'; + + //unknown is there to prevent Typescript generating wrong options.d.ts + readPreference: 'primary' | 'primaryPreferred' | 'secondary' | 'secondaryPreferred' | 'nearest' | 'unknown' = 'primary'; + + // maxStalenessSeconds?: number; + // readPreferenceTags?: string; //e.g. "dc:ny,rack:1" + // hedge?: boolean; + + // compressors?: 'snappy' | 'zlib' | 'zstd'; + // zlibCompressionLevel?: number; + + authSource?: string; + // authMechanism?: 'SCRAM-SHA-1' | 'SCRAM-SHA-256' | 'MONGODB-X509' | 'GSSAPI' | 'PLAIN'; + // authMechanismProperties?: string; + // gssapiServiceName?: string; + + //todo check what postgres needs + ssl?: boolean; + tlsCertificateFile?: string; + tlsCertificateKeyFile?: string; + tlsCertificateKeyFilePassword?: string; + tlsCAFile?: string; + tlsCRLFile?: string; + tlsAllowInvalidCertificates?: boolean; + tlsAllowInvalidHostnames?: boolean; + tlsInsecure?: boolean; + + // queue stuff + maxPoolSize: number = 20; + minPoolSize: number = 1; + maxIdleTimeMS: number = 100; + waitQueueTimeoutMS: number = 0; + + get rejectUnauthorized() { + return this.tlsInsecure || this.tlsAllowInvalidCertificates; + } + + get checkServerIdentity() { + return !this.tlsAllowInvalidHostnames && !this.tlsInsecure; + } + + get secondaryReadAllowed() { + return this.readPreference === 'secondary' || this.readPreference === 'secondaryPreferred'; + } } -export function parseConnectionString(url: string): PoolConfig { - const parsed = new URL(url); +/** + * Default URL: + * mongodb://mongodb0.example.com:27017 + * + * ReplicaSet UR: + * mongodb://mongodb0.example.com:27017,mongodb1.example.com:27017,mongodb2.example.com:27017/?replicaSet=myRepl + * + * Shared URL: + * mongodb://mongos0.example.com:27017,mongos1.example.com:27017,mongos2.example.com:27017 + * + * SVR URL: + * mongodb+srv://server.example.com/ + */ +export class PostgresClientConfig { + defaultDb?: string; + authUser?: string; + authPassword?: string; + public readonly hosts: Host[] = []; + + options: ConnectionOptions = new ConnectionOptions; + + constructor(connectionString: string) { + this.parseConnectionString(connectionString); + } - const options: {[name: string]: any} = {}; - for (const [key, value] of parsed.searchParams.entries()) { - options[key] = value; + async getHosts(): Promise { + return this.hosts; } - return cast({ - host: parsed.hostname, - port: parsed.port, - database: parsed.pathname.slice(1), - user: parsed.username, - password: parsed.password, - ...options, - }); + protected parseConnectionString(url: string) { + //we replace only first `,` with `/,` so we get additional host names in parsed.path + url = url.replace(',', '/,'); + + const parsed = parseUrl(url); + //e.g. for `database://peter:asd@localhost,127.0.0.1,yetanother/asd` + //parsed.pathname contains now /,127.0.0.1,yetanother/asd + //and parsed.hostname localhost. Thus we merge those, when we detect `/,` in path + const hostnames: string[] = []; + let defaultDb = parsed.pathname ? parsed.pathname.substr(1) : ''; + + if (!parsed.hostname) throw new Error('No hostname found in connection string'); + hostnames.push(`${parsed.hostname}:${parsed.port || 5432}`); + + if (parsed.pathname && parsed.pathname.startsWith('/,')) { + //we got multiple host names + const lastSlash = parsed.pathname.lastIndexOf('/'); + if (lastSlash === 0) { + //no database name provided, so whole path contains host names + //offset `2` because `/,` + hostnames.push(...parsed.pathname.substr(2).split(',')); + defaultDb = ''; + } else { + hostnames.push(...parsed.pathname.substr(2, lastSlash).split(',')); + defaultDb = parsed.pathname.substr(lastSlash + 1); + } + } + this.hosts.splice(0, this.hosts.length); + for (const hostname of hostnames) { + const [host, port] = hostname.split(':'); + this.hosts.push(new Host(this, host, port ? parseInt(port, 10) : 27017)); + } + + this.defaultDb = defaultDb; + + if (parsed.auth) { + const firstColon = parsed.auth.indexOf(':'); + if (firstColon === -1) { + this.authUser = parsed.auth; + } else { + this.authUser = parsed.auth.substr(0, firstColon); + this.authPassword = parsed.auth.substr(firstColon + 1); + } + } + + const options = parsed.query ? parseQueryString(parsed.query) : {}; + this.options = cast(options); + } + + getAuthSource(): string { + return this.options.authSource || this.defaultDb || 'admin'; + } } diff --git a/packages/postgres/src/postgres-adapter.ts b/packages/postgres/src/postgres-adapter.ts index c7232533d..c4fc0ceff 100644 --- a/packages/postgres/src/postgres-adapter.ts +++ b/packages/postgres/src/postgres-adapter.ts @@ -9,56 +9,49 @@ */ import { - asAliasName, + createTables, DefaultPlatform, - getDeepTypeCaster, getPreparedEntity, prepareBatchUpdate, + PreparedAdapter, PreparedEntity, - splitDotPath, SqlBuilder, - SQLConnection, - SQLConnectionPool, - SQLDatabaseAdapter, - SQLDatabaseQuery, - SQLDatabaseQueryFactory, - SQLPersistence, - SQLQueryModel, - SQLQueryResolver, + SqlBuilderRegistry, SQLStatement, } from '@deepkit/sql'; import { - DatabaseDeleteError, + DatabaseAdapter, + DatabaseEntityRegistry, DatabaseError, DatabaseLogger, - DatabasePatchError, + DatabasePersistence, DatabasePersistenceChangeSet, DatabaseSession, - DatabaseTransaction, DatabaseUpdateError, DeleteResult, ensureDatabaseError, + getStateCacheId, + MigrateOptions, OrmEntity, PatchResult, - primaryKeyObjectConverter, + SelectorResolver, + SelectorState, UniqueConstraintFailure, } from '@deepkit/orm'; import { PostgresPlatform } from './postgres-platform.js'; -import type { Pool, PoolClient, PoolConfig } from 'pg'; -import pg from 'pg'; -import { AbstractClassType, asyncOperation, ClassType, empty } from '@deepkit/core'; +import { empty } from '@deepkit/core'; import { FrameCategory, Stopwatch } from '@deepkit/stopwatch'; +import { Changes, ReflectionClass } from '@deepkit/type'; +import { PostgresClientConfig } from './config.js'; import { - Changes, - getPatchSerializeFunction, - getSerializeFunction, - ReceiveType, - ReflectionClass, - ReflectionKind, - ReflectionProperty, - resolvePath, -} from '@deepkit/type'; -import { parseConnectionString } from './config.js'; + FindCommand, + InsertCommand, + PostgresClient, + PostgresClientConnection, + PostgresClientPrepared, + PostgresConnectionPool, + PostgresDatabaseTransaction, +} from './client.js'; /** * Converts a specific database error to a more specific error, if possible. @@ -80,11 +73,10 @@ function handleSpecificError(session: DatabaseSession, error: DatabaseError): Er return error; } - export class PostgresStatement extends SQLStatement { protected released = false; - constructor(protected logger: DatabaseLogger, protected sql: string, protected client: PoolClient, protected stopwatch?: Stopwatch) { + constructor(protected logger: DatabaseLogger, protected sql: string, protected prepared: PostgresClientPrepared, protected stopwatch?: Stopwatch) { super(); } @@ -95,9 +87,7 @@ export class PostgresStatement extends SQLStatement { this.logger.logQuery(this.sql, params); //postgres driver does not maintain error.stack when they throw errors, so //we have to manually convert it using asyncOperation. - const res = await asyncOperation((resolve, reject) => { - this.client.query(this.sql, params).then(resolve).catch(reject); - }); + const res = await this.prepared.execute(params); return res.rows[0]; } catch (error: any) { error = ensureDatabaseError(error); @@ -115,9 +105,7 @@ export class PostgresStatement extends SQLStatement { this.logger.logQuery(this.sql, params); //postgres driver does not maintain error.stack when they throw errors, so //we have to manually convert it using asyncOperation. - const res = await asyncOperation((resolve, reject) => { - this.client.query(this.sql, params).then(resolve).catch(reject); - }); + const res = await this.prepared.execute(params); return res.rows; } catch (error: any) { error = ensureDatabaseError(error, `Query: ${this.sql}\nParams: ${params}`); @@ -132,152 +120,113 @@ export class PostgresStatement extends SQLStatement { } } -export class PostgresConnection extends SQLConnection { - protected changes: number = 0; - public lastReturningRows: any[] = []; - - constructor( - connectionPool: PostgresConnectionPool, - public connection: PoolClient, - logger?: DatabaseLogger, - transaction?: DatabaseTransaction, - stopwatch?: Stopwatch, - ) { - super(connectionPool, logger, transaction, stopwatch); - } +// function typeSafeDefaultValue(property: ReflectionProperty): any { +// if (property.type.kind === ReflectionKind.string) return ''; +// if (property.type.kind === ReflectionKind.number) return 0; +// if (property.type.kind === ReflectionKind.boolean) return false; +// if (property.type.kind === ReflectionKind.class && property.type.classType === Date) return false; +// +// return null; +// } - async prepare(sql: string) { - return new PostgresStatement(this.logger, sql, this.connection, this.stopwatch); - } +export class PostgresPersistence extends DatabasePersistence { + protected connection?: PostgresClientConnection; - async run(sql: string, params: any[] = []) { - const frame = this.stopwatch ? this.stopwatch.start('Query', FrameCategory.databaseQuery) : undefined; - try { - if (frame) frame.data({ sql, sqlParams: params }); - //postgres driver does not maintain error.stack when they throw errors, so - //we have to manually convert it using asyncOperation. - const res = await asyncOperation((resolve, reject) => { - this.connection.query(sql, params).then(resolve).catch(reject); - }); - this.logger.logQuery(sql, params); - this.lastReturningRows = res.rows; - this.changes = res.rowCount; - } catch (error: any) { - error = ensureDatabaseError(error); - this.logger.failedQuery(error, sql, params); - throw error; - } finally { - if (frame) frame.end(); - } - } - - async getChanges(): Promise { - return this.changes; - } -} - -export type TransactionTypes = 'REPEATABLE READ' | 'READ COMMITTED' | 'SERIALIZABLE'; - -export class PostgresDatabaseTransaction extends DatabaseTransaction { - connection?: PostgresConnection; - - setTransaction?: TransactionTypes; - - /** - * This is the default for mysql databases. - */ - repeatableRead(): this { - this.setTransaction = 'REPEATABLE READ'; - return this; - } - - readCommitted(): this { - this.setTransaction = 'READ COMMITTED'; - return this; - } - - serializable(): this { - this.setTransaction = 'SERIALIZABLE'; - return this; - } - - async begin() { - if (!this.connection) return; - const set = this.setTransaction ? 'SET TRANSACTION ISOLATION LEVEL ' + this.setTransaction + ';' : ''; - await this.connection.run(set + 'START TRANSACTION'); - } - - async commit() { - if (!this.connection) return; - if (this.ended) throw new Error('Transaction ended already'); - - await this.connection.run('COMMIT'); - this.ended = true; - this.connection.release(); - } - - async rollback() { - if (!this.connection) return; - - if (this.ended) throw new Error('Transaction ended already'); - await this.connection.run('ROLLBACK'); - this.ended = true; - this.connection.release(); - } -} - -export class PostgresConnectionPool extends SQLConnectionPool { - constructor(protected pool: Pool) { + constructor(protected platform: DefaultPlatform, public pool: PostgresConnectionPool, public session: DatabaseSession) { super(); } - async getConnection(logger?: DatabaseLogger, transaction?: PostgresDatabaseTransaction, stopwatch?: Stopwatch): Promise { - //when a transaction object is given, it means we make the connection sticky exclusively to that transaction - //and only release the connection when the transaction is commit/rollback is executed. + async getInsertBatchSize(schema: ReflectionClass): Promise { + return Math.floor(30000 / schema.getProperties().length); + } - if (transaction && transaction.connection) return transaction.connection; + async insert(classSchema: ReflectionClass, items: T[]): Promise { + const batchSize = await this.getInsertBatchSize(classSchema); + const prepared = getPreparedEntity(this.session.adapter, classSchema); - const poolClient = await this.pool.connect(); - this.activeConnections++; - const connection = new PostgresConnection(this, poolClient, logger, transaction, stopwatch); - if (transaction) { - transaction.connection = connection; - try { - await transaction.begin(); - } catch (error) { - transaction.ended = true; - connection.release(); - throw new Error('Could not start transaction: ' + error); + if (batchSize > items.length) { + await this.batchInsert(prepared, items); + await this.populateAutoIncrementFields(prepared, items); + } else { + for (let i = 0; i < items.length; i += batchSize) { + const batched = items.slice(i, i + batchSize); + await this.batchInsert(prepared, batched); + await this.populateAutoIncrementFields(prepared, batched); } } - return connection; } - release(connection: PostgresConnection) { - //connections attached to a transaction are not automatically released. - //only with commit/rollback actions - if (connection.transaction && !connection.transaction.ended) return; + protected async batchInsert(prepared: PreparedEntity, items: T[]) { + const connection = await this.getConnection(); + try { + await connection.execute(new InsertCommand(this.platform, prepared, items)); + } finally { + connection.release(); + } - super.release(connection); - connection.connection.release(); - } -} + // const scopeSerializer = getSerializeFunction(classSchema.type, this.platform.serializer.serializeRegistry); + // const placeholder = new this.platform.placeholderStrategy; + // + // const insert: string[] = []; + // const params: any[] = []; + // const names: string[] = []; + // const prepared = getPreparedEntity(this.session.adapter, classSchema); + // + // for (const property of prepared.fields) { + // if (property.autoIncrement) continue; + // names.push(property.columnNameEscaped); + // } + // + // for (const item of items) { + // const converted = scopeSerializer(item); + // const row: string[] = []; + // + // for (const property of prepared.fields) { + // if (property.autoIncrement) continue; + // + // const v = converted[property.name]; + // params.push(v === undefined ? null : v); + // row.push(property.sqlTypeCast(placeholder.getPlaceholder())); + // } + // + // insert.push(row.join(', ')); + // } + + // const sql = this.getInsertSQL(classSchema, names, insert); + // try { + // await (await this.getConnection()).run(sql, params); + // } catch (error: any) { + // error = new DatabaseInsertError( + // classSchema, + // items as OrmEntity[], + // `Could not insert ${classSchema.getClassName()} into database: ${formatError(error)}`, + // { cause: error }, + // ); + // throw this.handleSpecificError(error); + // } + } + + release(): void { + this.connection?.release(); + } + + async getConnection(): Promise { + if (!this.connection) { + return this.connection = await this.pool.getConnection(); + } -function typeSafeDefaultValue(property: ReflectionProperty): any { - if (property.type.kind === ReflectionKind.string) return ''; - if (property.type.kind === ReflectionKind.number) return 0; - if (property.type.kind === ReflectionKind.boolean) return false; - if (property.type.kind === ReflectionKind.class && property.type.classType === Date) return false; + return this.connection; + } - return null; -} + remove(classSchema: ReflectionClass, items: T[]): Promise { + return Promise.resolve(undefined); + } -export class PostgresPersistence extends SQLPersistence { - constructor(protected platform: DefaultPlatform, public connectionPool: PostgresConnectionPool, session: DatabaseSession) { - super(platform, connectionPool, session); + update(classSchema: ReflectionClass, changeSets: DatabasePersistenceChangeSet[]): Promise { + return Promise.resolve(undefined); } - override handleSpecificError(error: Error): Error { + handleSpecificError(error: Error): Error { return handleSpecificError(this.session, error); } @@ -359,16 +308,16 @@ export class PostgresPersistence extends SQLPersistence { `; try { - const connection = await this.getConnection(); //will automatically be released in SQLPersistence - const result = await connection.execAndReturnAll(sql, params); - for (const returning of result) { - const r = prepared.assignReturning[returning[prepared.pkName]]; - if (!r) continue; - - for (const name of r.names) { - r.item[name] = returning[name]; - } - } + // const connection = await this.getConnection(); //will automatically be released in SQLPersistence + // const result = await connection.execAndReturnAll(sql, params); + // for (const returning of result) { + // const r = prepared.assignReturning[returning[prepared.pkName]]; + // if (!r) continue; + // + // for (const name of r.names) { + // r.item[name] = returning[name]; + // } + // } } catch (error: any) { const reflection = ReflectionClass.from(entity.type); error = new DatabaseUpdateError( @@ -381,20 +330,20 @@ export class PostgresPersistence extends SQLPersistence { } } - protected async populateAutoIncrementFields(classSchema: ReflectionClass, items: T[]) { - const autoIncrement = classSchema.getAutoIncrement(); - if (!autoIncrement) return; - const connection = await this.getConnection(); //will automatically be released in SQLPersistence - - //We adjusted the INSERT SQL with additional RETURNING which returns all generated - //auto-increment values. We read the result and simply assign the value. - const name = autoIncrement.name; - const insertedRows = connection.lastReturningRows; - if (!insertedRows.length) return; - - for (let i = 0; i < items.length; i++) { - items[i][name] = insertedRows[i][name]; - } + protected async populateAutoIncrementFields(prepared: PreparedEntity, items: T[]) { + // const autoIncrement = classSchema.getAutoIncrement(); + // if (!autoIncrement) return; + // const connection = await this.getConnection(); //will automatically be released in SQLPersistence + // + // //We adjusted the INSERT SQL with additional RETURNING which returns all generated + // //auto-increment values. We read the result and simply assign the value. + // const name = autoIncrement.name; + // const insertedRows = connection.lastReturningRows; + // if (!insertedRows.length) return; + // + // for (let i = 0; i < items.length; i++) { + // items[i][name] = insertedRows[i][name]; + // } } protected getInsertSQL(classSchema: ReflectionClass, fields: string[], values: string[]): string { @@ -411,183 +360,227 @@ export class PostgresPersistence extends SQLPersistence { } } -export class PostgresSQLQueryResolver extends SQLQueryResolver { - async delete(model: SQLQueryModel, deleteResult: DeleteResult): Promise { - const primaryKey = this.classSchema.getPrimary(); - const pkField = this.platform.quoteIdentifier(primaryKey.name); - const primaryKeyConverted = primaryKeyObjectConverter(this.classSchema, this.platform.serializer.deserializeRegistry); - - const sqlBuilder = new SqlBuilder(this.adapter); - const tableName = this.platform.getTableIdentifier(this.classSchema); - const select = sqlBuilder.select(this.classSchema, model, { select: [`${tableName}.${pkField}`] }); - - const connection = await this.connectionPool.getConnection(this.session.logger, this.session.assignedTransaction, this.session.stopwatch); - try { - const sql = ` - WITH _ AS (${select.sql}) - DELETE - FROM ${tableName} USING _ - WHERE ${tableName}.${pkField} = _.${pkField} - RETURNING ${tableName}.${pkField} - `; - - const rows = await connection.execAndReturnAll(sql, select.params); - deleteResult.modified = rows.length; - for (const row of rows) { - deleteResult.primaryKeys.push(primaryKeyConverted(row[primaryKey.name])); - } - } catch (error: any) { - error = new DatabaseDeleteError(this.classSchema, 'Could not delete in database', { cause: error }); - error.query = model; - throw this.handleSpecificError(error); - } finally { - connection.release(); - } +export class PostgresSelectorResolver extends SelectorResolver { + constructor( + public connectionPool: PostgresConnectionPool, + public platform: DefaultPlatform, + public adapter: PreparedAdapter, + public session: DatabaseSession, + ) { + super(session); } - override handleSpecificError(error: Error): Error { - return handleSpecificError(this.session, error); + count(model: SelectorState): Promise { + return Promise.resolve(0); } - async patch(model: SQLQueryModel, changes: Changes, patchResult: PatchResult): Promise { - const select: string[] = []; - const selectParams: any[] = []; - const entity = getPreparedEntity(this.session.adapter as SQLDatabaseAdapter, this.classSchema); - const tableName = entity.tableNameEscaped; - const primaryKey = this.classSchema.getPrimary(); - const primaryKeyConverted = primaryKeyObjectConverter(this.classSchema, this.platform.serializer.deserializeRegistry); - - const fieldsSet: { [name: string]: 1 } = {}; - const aggregateFields: { [name: string]: { converted: (v: any) => any } } = {}; - - const patchSerialize = getPatchSerializeFunction(this.classSchema.type, this.platform.serializer.serializeRegistry); - const $set = changes.$set ? patchSerialize(changes.$set, undefined) : undefined; - const set: string[] = []; - - if ($set) for (const i in $set) { - if (!$set.hasOwnProperty(i)) continue; - if ($set[i] === undefined || $set[i] === null) { - set.push(`${this.platform.quoteIdentifier(i)} = NULL`); - } else { - fieldsSet[i] = 1; - select.push(`$${selectParams.length + 1} as ${this.platform.quoteIdentifier(asAliasName(i))}`); - selectParams.push($set[i]); - } - } - - if (changes.$unset) for (const i in changes.$unset) { - if (!changes.$unset.hasOwnProperty(i)) continue; - fieldsSet[i] = 1; - select.push(`NULL as ${this.platform.quoteIdentifier(i)}`); - } - - for (const i of model.returning) { - aggregateFields[i] = { converted: getSerializeFunction(resolvePath(i, this.classSchema.type), this.platform.serializer.deserializeRegistry) }; - select.push(`(${this.platform.quoteIdentifier(i)} ) as ${this.platform.quoteIdentifier(i)}`); - } - - if (changes.$inc) for (const i in changes.$inc) { - if (!changes.$inc.hasOwnProperty(i)) continue; - fieldsSet[i] = 1; - aggregateFields[i] = { converted: getSerializeFunction(resolvePath(i, this.classSchema.type), this.platform.serializer.serializeRegistry) }; - const sqlTypeCast = getDeepTypeCaster(entity, i); - select.push(`(${sqlTypeCast('(' + this.platform.getColumnAccessor('', i) + ')')} + ${this.platform.quoteValue(changes.$inc[i])}) as ${this.platform.quoteIdentifier(asAliasName(i))}`); - } - - for (const i in fieldsSet) { - if (i.includes('.')) { - let [firstPart, secondPart] = splitDotPath(i); - const path = '{' + secondPart.replace(/\./g, ',').replace(/[\]\[]/g, '') + '}'; - set.push(`${this.platform.quoteIdentifier(firstPart)} = jsonb_set(${this.platform.quoteIdentifier(firstPart)}, '${path}', to_jsonb(_b.${this.platform.quoteIdentifier(asAliasName(i))}))`); - } else { - const property = entity.fieldMap[i]; - const ref = '_b.' + this.platform.quoteIdentifier(asAliasName(i)); - set.push(`${this.platform.quoteIdentifier(i)} = ${property.sqlTypeCast(ref)}`); - } - } - let bPrimaryKey = primaryKey.name; - //we need a different name because primaryKeys could be updated as well - if (fieldsSet[primaryKey.name]) { - select.unshift(this.platform.quoteIdentifier(primaryKey.name) + ' as __' + primaryKey.name); - bPrimaryKey = '__' + primaryKey.name; - } else { - select.unshift(this.platform.quoteIdentifier(primaryKey.name)); - } - - const returningSelect: string[] = []; - returningSelect.push(tableName + '.' + this.platform.quoteIdentifier(primaryKey.name)); - - if (!empty(aggregateFields)) { - for (const i in aggregateFields) { - returningSelect.push(this.platform.getColumnAccessor(tableName, i)); - } - } - - const sqlBuilder = new SqlBuilder(this.adapter, selectParams.length); - const selectSQL = sqlBuilder.select(this.classSchema, model, { select }); - - const sql = ` - WITH _b AS (${selectSQL.sql}) - UPDATE - ${tableName} - SET ${set.join(', ')} - FROM _b - WHERE ${tableName}.${this.platform.quoteIdentifier(primaryKey.name)} = _b.${this.platform.quoteIdentifier(bPrimaryKey)} - RETURNING ${returningSelect.join(', ')} - `; + async find(model: SelectorState): Promise { + const cacheId = getStateCacheId(model); + const connection = await this.connectionPool.getConnection({}, cacheId); - const connection = await this.connectionPool.getConnection(this.session.logger, this.session.assignedTransaction, this.session.stopwatch); try { - const result = await connection.execAndReturnAll(sql, selectSQL.params); - - patchResult.modified = result.length; - for (const i in aggregateFields) { - patchResult.returning[i] = []; + let findCommand = connection.getCache(cacheId) as FindCommand | undefined; + if (!findCommand) { + const sqlBuilder = new SqlBuilder(this.adapter); + const sql = sqlBuilder.select(model); + const prepared = getPreparedEntity(this.adapter, model.schema); + findCommand = new FindCommand(sql.sql, prepared, model, ''); + connection.setCache(cacheId, findCommand); } - for (const returning of result) { - patchResult.primaryKeys.push(primaryKeyConverted(returning[primaryKey.name])); - for (const i in aggregateFields) { - patchResult.returning[i].push(aggregateFields[i].converted(returning[i])); - } - } - } catch (error: any) { - error = new DatabasePatchError(this.classSchema, model, changes, `Could not patch ${this.classSchema.getClassName()} in database`, { cause: error }); - throw this.handleSpecificError(error); + findCommand.setParameters(model.params); + + //todo identity map + stuff that the Formatter did + return await connection.execute(findCommand); } finally { connection.release(); } } -} -export class PostgresSQLDatabaseQuery extends SQLDatabaseQuery { -} - -export class PostgresSQLDatabaseQueryFactory extends SQLDatabaseQueryFactory { - createQuery(type?: ReceiveType | ClassType | AbstractClassType | ReflectionClass): PostgresSQLDatabaseQuery { - return new PostgresSQLDatabaseQuery(ReflectionClass.from(type), this.databaseSession, - new PostgresSQLQueryResolver(this.connectionPool, this.platform, ReflectionClass.from(type), this.databaseSession.adapter, this.databaseSession), - ); + async findOneOrUndefined(model: SelectorState): Promise { + model = { ...model, limit: 1 }; + const rows = await this.find(model); + return rows[0]; + } + + async delete(model: SelectorState, deleteResult: DeleteResult): Promise { + // const primaryKey = model.schema.getPrimary(); + // const pkField = this.platform.quoteIdentifier(primaryKey.name); + // const primaryKeyConverted = primaryKeyObjectConverter(model.schema, this.platform.serializer.deserializeRegistry); + // + // const sqlBuilder = new SqlBuilder(this.adapter); + // const tableName = this.platform.getTableIdentifier(model.schema); + // const select = sqlBuilder.select(model, { select: [`${tableName}.${pkField}`] }); + // + // const connection = await this.connectionPool.getConnection(this.session.logger, this.session.assignedTransaction, this.session.stopwatch); + // try { + // const sql = ` + // WITH _ AS (${select.sql}) + // DELETE + // FROM ${tableName} USING _ + // WHERE ${tableName}.${pkField} = _.${pkField} + // RETURNING ${tableName}.${pkField} + // `; + // + // const rows = await connection.execAndReturnAll(sql, select.params); + // deleteResult.modified = rows.length; + // // for (const row of rows) { + // // deleteResult.primaryKeys.push(primaryKeyConverted(row[primaryKey.name])); + // // } + // } catch (error: any) { + // error = new DatabaseDeleteError(model.schema, 'Could not delete in database', { cause: error }); + // error.query = model; + // throw this.handleSpecificError(error); + // } finally { + // connection.release(); + // } + } + + handleSpecificError(error: Error): Error { + // return handleSpecificError(this.session, error); + return error; + } + + async patch(model: SelectorState, changes: Changes, patchResult: PatchResult): Promise { + const select: string[] = []; + const selectParams: any[] = []; + const entity = getPreparedEntity(this.adapter, model.schema); + const tableName = entity.tableNameEscaped; + // const primaryKey = model.schema.getPrimary(); + // const primaryKeyConverted = primaryKeyObjectConverter(model.schema, this.platform.serializer.deserializeRegistry); + // + // const fieldsSet: { [name: string]: 1 } = {}; + // const aggregateFields: { [name: string]: { converted: (v: any) => any } } = {}; + // + // const patchSerialize = getPatchSerializeFunction(model.schema.type, this.platform.serializer.serializeRegistry); + // const $set = changes.$set ? patchSerialize(changes.$set, undefined) : undefined; + // const set: string[] = []; + // + // if ($set) for (const i in $set) { + // if (!$set.hasOwnProperty(i)) continue; + // if ($set[i] === undefined || $set[i] === null) { + // set.push(`${this.platform.quoteIdentifier(i)} = NULL`); + // } else { + // fieldsSet[i] = 1; + // select.push(`$${selectParams.length + 1} as ${this.platform.quoteIdentifier(asAliasName(i))}`); + // selectParams.push($set[i]); + // } + // } + // + // if (changes.$unset) for (const i in changes.$unset) { + // if (!changes.$unset.hasOwnProperty(i)) continue; + // fieldsSet[i] = 1; + // select.push(`NULL as ${this.platform.quoteIdentifier(i)}`); + // } + // + // // todo readd + // // for (const i of model.returning) { + // // aggregateFields[i] = { converted: getSerializeFunction(resolvePath(i, model.schema.type), this.platform.serializer.deserializeRegistry) }; + // // select.push(`(${this.platform.quoteIdentifier(i)} ) as ${this.platform.quoteIdentifier(i)}`); + // // } + // + // if (changes.$inc) for (const i in changes.$inc) { + // if (!changes.$inc.hasOwnProperty(i)) continue; + // fieldsSet[i] = 1; + // aggregateFields[i] = { converted: getSerializeFunction(resolvePath(i, model.schema.type), this.platform.serializer.serializeRegistry) }; + // const sqlTypeCast = getDeepTypeCaster(entity, i); + // select.push(`(${sqlTypeCast('(' + this.platform.getColumnAccessor('', i) + ')')} + ${this.platform.quoteValue(changes.$inc[i])}) as ${this.platform.quoteIdentifier(asAliasName(i))}`); + // } + // + // for (const i in fieldsSet) { + // if (i.includes('.')) { + // let [firstPart, secondPart] = splitDotPath(i); + // const path = '{' + secondPart.replace(/\./g, ',').replace(/[\]\[]/g, '') + '}'; + // set.push(`${this.platform.quoteIdentifier(firstPart)} = jsonb_set(${this.platform.quoteIdentifier(firstPart)}, '${path}', to_jsonb(_b.${this.platform.quoteIdentifier(asAliasName(i))}))`); + // } else { + // const property = entity.fieldMap[i]; + // const ref = '_b.' + this.platform.quoteIdentifier(asAliasName(i)); + // set.push(`${this.platform.quoteIdentifier(i)} = ${property.sqlTypeCast(ref)}`); + // } + // } + // let bPrimaryKey = primaryKey.name; + // //we need a different name because primaryKeys could be updated as well + // if (fieldsSet[primaryKey.name]) { + // select.unshift(this.platform.quoteIdentifier(primaryKey.name) + ' as __' + primaryKey.name); + // bPrimaryKey = '__' + primaryKey.name; + // } else { + // select.unshift(this.platform.quoteIdentifier(primaryKey.name)); + // } + // + // const returningSelect: string[] = []; + // returningSelect.push(tableName + '.' + this.platform.quoteIdentifier(primaryKey.name)); + // + // if (!empty(aggregateFields)) { + // for (const i in aggregateFields) { + // returningSelect.push(this.platform.getColumnAccessor(tableName, i)); + // } + // } + // + // const sqlBuilder = new SqlBuilder(this.adapter, selectParams.slice()); + // const selectSQL = sqlBuilder.select(model, { select }); + // + // const sql = ` + // WITH _b AS (${selectSQL.sql}) + // UPDATE + // ${tableName} + // SET ${set.join(', ')} + // FROM _b + // WHERE ${tableName}.${this.platform.quoteIdentifier(primaryKey.name)} = _b.${this.platform.quoteIdentifier(bPrimaryKey)} + // RETURNING ${returningSelect.join(', ')} + // `; + // + // const connection = await this.connectionPool.getConnection(this.session.logger, this.session.assignedTransaction, this.session.stopwatch); + // try { + // const result = await connection.execAndReturnAll(sql, selectSQL.params); + // + // patchResult.modified = result.length; + // for (const i in aggregateFields) { + // patchResult.returning[i] = []; + // } + // + // for (const returning of result) { + // patchResult.primaryKeys.push(primaryKeyConverted(returning[primaryKey.name])); + // for (const i in aggregateFields) { + // patchResult.returning[i].push(aggregateFields[i].converted(returning[i])); + // } + // } + // } catch (error: any) { + // error = new DatabasePatchError(model.schema, model, changes, `Could not patch ${model.schema.getClassName()} in database`, { cause: error }); + // throw this.handleSpecificError(error); + // } finally { + // connection.release(); + // } } } -export class PostgresDatabaseAdapter extends SQLDatabaseAdapter { - protected options: PoolConfig; - protected pool: pg.Pool; - public connectionPool : PostgresConnectionPool; +export class PostgresDatabaseAdapter extends DatabaseAdapter implements PreparedAdapter { + public client: PostgresClient; public platform = new PostgresPlatform(); closed = false; - constructor(options: PoolConfig | string, additional: Partial = {}) { + builderRegistry = new SqlBuilderRegistry; + cache = {}; + preparedEntities = new Map, PreparedEntity>; + + isNativeForeignKeyConstraintSupported(): boolean { + return false; + } + + migrate(options: MigrateOptions, entityRegistry: DatabaseEntityRegistry): Promise { + return Promise.resolve(undefined); + } + + constructor(options: PostgresClientConfig | string) { super(); - const defaults: PoolConfig = {}; - options = 'string' === typeof options ? parseConnectionString(options) : options; - this.options = Object.assign(defaults, options, additional); - this.pool = new pg.Pool(this.options); - this.connectionPool = new PostgresConnectionPool(this.pool); + this.client = new PostgresClient(options); + } - pg.types.setTypeParser(1700, parseFloat); - pg.types.setTypeParser(20, parseInt); + public async createTables(entityRegistry: DatabaseEntityRegistry): Promise { + await createTables(entityRegistry, this.client.pool, this.platform, this); + } + + createSelectorResolver(session: DatabaseSession): SelectorResolver { + return new PostgresSelectorResolver(this.client.pool, this.platform, this, session); } getName(): string { @@ -599,21 +592,17 @@ export class PostgresDatabaseAdapter extends SQLDatabaseAdapter { return ''; } - createPersistence(session: DatabaseSession): SQLPersistence { - return new PostgresPersistence(this.platform, this.connectionPool, session); + createPersistence(session: DatabaseSession): PostgresPersistence { + return new PostgresPersistence(this.platform, this.client.pool, session); } createTransaction(session: DatabaseSession): PostgresDatabaseTransaction { return new PostgresDatabaseTransaction; } - queryFactory(session: DatabaseSession): SQLDatabaseQueryFactory { - return new PostgresSQLDatabaseQueryFactory(this.connectionPool, this.platform, session); - } - disconnect(force?: boolean): void { if (this.closed) return; this.closed = true; - this.pool.end().catch(console.error); + this.client.pool.close(); } } diff --git a/packages/postgres/src/postgres-platform.ts b/packages/postgres/src/postgres-platform.ts index 2bfc51d32..f0a47d99e 100644 --- a/packages/postgres/src/postgres-platform.ts +++ b/packages/postgres/src/postgres-platform.ts @@ -14,7 +14,6 @@ import { DefaultPlatform, IndexModel, isSet, - PreparedAdapter, SqlPlaceholderStrategy, Table, typeResolvesToBigInt, @@ -39,7 +38,6 @@ import { TypeNumberBrand, } from '@deepkit/type'; import { PostgresSchemaParser } from './postgres-schema-parser.js'; -import { PostgreSQLFilterBuilder } from './sql-filter-builder.js'; import { isArray, isObject } from '@deepkit/core'; import sqlstring from 'sqlstring'; @@ -152,10 +150,6 @@ export class PostgresPlatform extends DefaultPlatform { return super.getAggregateSelect(tableName, property, func); } - override createSqlFilterBuilder(adapter: PreparedAdapter, schema: ReflectionClass, tableName: string): PostgreSQLFilterBuilder { - return new PostgreSQLFilterBuilder(adapter, schema, tableName, this.serializer, new this.placeholderStrategy); - } - override getDeepColumnAccessor(table: string, column: string, path: string) { return `${table ? table + '.' : ''}${this.quoteIdentifier(column)}->${this.quoteValue(path)}`; } diff --git a/packages/postgres/src/sql-filter-builder.ts b/packages/postgres/src/sql-filter-builder.ts deleted file mode 100644 index 5d1e807dd..000000000 --- a/packages/postgres/src/sql-filter-builder.ts +++ /dev/null @@ -1,18 +0,0 @@ -/* - * Deepkit Framework - * Copyright (C) 2021 Deepkit UG, Marc J. Schmidt - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the MIT License. - * - * You should have received a copy of the MIT License along with this program. - */ - -import { SQLFilterBuilder } from '@deepkit/sql'; - -export class PostgreSQLFilterBuilder extends SQLFilterBuilder { - regexpComparator(lvalue: string, value: RegExp) { - if (value.flags.includes('i')) return `${lvalue} ~* ${this.bindParam(value.source)}`; - return `${lvalue} ~ ${this.bindParam(value.source)}`; - } -} diff --git a/packages/postgres/tests/factory.ts b/packages/postgres/tests/factory.ts index 8e940f3cd..ea4032985 100644 --- a/packages/postgres/tests/factory.ts +++ b/packages/postgres/tests/factory.ts @@ -4,7 +4,7 @@ import { PostgresDatabaseAdapter } from '../src/postgres-adapter.js'; import { formatError } from '@deepkit/core'; export const databaseFactory: DatabaseFactory = async (entities, plugins): Promise> => { - const adapter = new PostgresDatabaseAdapter({host: 'localhost', database: 'postgres', user: 'postgres'}); + const adapter = new PostgresDatabaseAdapter('postgres://postgres@localhost:5432/postgres'); try { const database = new Database(adapter); diff --git a/packages/postgres/tests/postgres.spec.ts b/packages/postgres/tests/postgres.spec.ts index 83b4004be..73ca71387 100644 --- a/packages/postgres/tests/postgres.spec.ts +++ b/packages/postgres/tests/postgres.spec.ts @@ -1,9 +1,7 @@ -import { AutoIncrement, cast, entity, getVectorTypeOptions, PrimaryKey, ReflectionClass, Unique, Vector } from '@deepkit/type'; +import { AutoIncrement, entity, PrimaryKey, Reference } from '@deepkit/type'; import { expect, test } from '@jest/globals'; import pg from 'pg'; import { databaseFactory } from './factory.js'; -import { DatabaseError, DatabaseInsertError, UniqueConstraintFailure } from '@deepkit/orm'; -import { assertInstanceOf } from '@deepkit/core'; test('count', async () => { const pool = new pg.Pool({ @@ -40,7 +38,9 @@ test('bool and json', async () => { doc: { flag: boolean } = { flag: false }; } - const database = await databaseFactory([Model]); + const database = await databaseFactory([]); + database.registerEntity(Model); + await database.adapter.createTables(database.entityRegistry); { const m = new Model; @@ -49,282 +49,320 @@ test('bool and json', async () => { await database.persist(m); } - const m = await database.query(Model).findOne(); + const m = await database.singleQuery(Model).findOne(); expect(m).toMatchObject({ flag: true, doc: { flag: true } }); }); -test('change different fields of multiple entities', async () => { - @entity.name('model2') - class Model { - firstName: string = ''; - lastName: string = ''; - - constructor(public id: number & PrimaryKey) { - } - } - - const database = await databaseFactory([Model]); - - { - const m1 = new Model(1); - m1.firstName = 'Peter'; - await database.persist(m1); - const m2 = new Model(2); - m2.lastName = 'Smith'; - await database.persist(m2); - } - - { - const m1 = await database.query(Model).filter({ id: 1 }).findOne(); - const m2 = await database.query(Model).filter({ id: 2 }).findOne(); - - m1.firstName = 'Peter2'; - m2.lastName = 'Smith2'; - await database.persist(m1, m2); - } - - { - const m1 = await database.query(Model).filter({ id: 1 }).findOne(); - const m2 = await database.query(Model).filter({ id: 2 }).findOne(); - - expect(m1).toMatchObject({ id: 1, firstName: 'Peter2', lastName: '' }); - expect(m2).toMatchObject({ id: 2, firstName: '', lastName: 'Smith2' }); - } -}); - -test('change pk', async () => { - @entity.name('model3') - class Model { - firstName: string = ''; - - constructor(public id: number & PrimaryKey) { - } - } - - const database = await databaseFactory([Model]); - - { - const m1 = new Model(1); - m1.firstName = 'Peter'; - await database.persist(m1); - } - - { - const m1 = await database.query(Model).filter({ id: 1 }).findOne(); - m1.id = 2; - await database.persist(m1); - } - - { - const m1 = await database.query(Model).filter({ id: 2 }).findOne(); - expect(m1).toMatchObject({ id: 2, firstName: 'Peter' }); - } - - { - const m1 = await database.query(Model).filter({ id: 2 }).findOne(); - m1.id = 3; - m1.firstName = 'Peter2'; - await database.persist(m1); - } - - { - const m1 = await database.query(Model).filter({ id: 3 }).findOne(); - expect(m1).toMatchObject({ id: 3, firstName: 'Peter2' }); - } -}); - -test('for update/share', async () => { - @entity.name('model4') - class Model { - firstName: string = ''; - - constructor(public id: number & PrimaryKey) { - } - } - - const database = await databaseFactory([Model]); - await database.persist(new Model(1), new Model(2)); - - { - const query = database.query(Model).forUpdate(); - const sql = database.adapter.createSelectSql(query); - expect(sql.sql).toContain(' FOR UPDATE'); - } - - { - const query = database.query(Model).forShare(); - const sql = database.adapter.createSelectSql(query); - expect(sql.sql).toContain(' FOR SHARE'); - } - - const items = await database.query(Model).forUpdate().find(); - expect(items).toHaveLength(2); -}); - -test('json field and query', async () => { - @entity.name('product').collection('products') - class Product { +test('join', async () => { + class User { id: number & PrimaryKey & AutoIncrement = 0; - raw?: { [key: string]: any }; - } - const database = await databaseFactory([Product]); - - await database.persist(cast({ raw: { productId: 1, name: 'first' } })); - await database.persist(cast({ raw: { productId: 2, name: 'second' } })); - - { - const res = await database.query(Product).filter({ 'raw.productId': 1 }).find(); - expect(res).toMatchObject([{ id: 1, raw: { productId: 1, name: 'first' } }]); - } + constructor(public name: string) { + } - { - const res = await database.query(Product).filter({ 'raw.productId': 2 }).find(); - expect(res).toMatchObject([{ id: 2, raw: { productId: 2, name: 'second' } }]); + group?: Group & Reference; } -}); -test('unique constraint 1', async () => { - class Model { + class Group { id: number & PrimaryKey & AutoIncrement = 0; - constructor(public username: string & Unique = '') { - } - } - - const database = await databaseFactory([Model]); - - await database.persist(new Model('peter')); - await database.persist(new Model('paul')); - - { - const m1 = new Model('peter'); - await expect(database.persist(m1)).rejects.toThrow('Key (username)=(peter) already exists'); - await expect(database.persist(m1)).rejects.toBeInstanceOf(UniqueConstraintFailure); - - try { - await database.persist(m1); - } catch (error: any) { - assertInstanceOf(error, UniqueConstraintFailure); - assertInstanceOf(error.cause, DatabaseInsertError); - assertInstanceOf(error.cause.cause, DatabaseError); - // error.cause.cause.cause is from the driver - expect(error.cause.cause.cause.table).toBe('Model'); + constructor(public name: string) { } } - { - const m1 = new Model('marie'); - const m2 = new Model('marie'); - await expect(database.persist(m1, m2)).rejects.toThrow('Key (username)=(marie) already exists'); - await expect(database.persist(m1, m2)).rejects.toBeInstanceOf(UniqueConstraintFailure); + const database = await databaseFactory([User, Group]); + const groupAdmin = new Group('admin'); + const groupUser = new Group('user'); + const groups = [groupAdmin, groupUser]; + + const count = 1000; + const users: User[] = []; + for (let i = 0; i < count; i++) { + const group = groups[i % groups.length]; + const user = new User('User ' + i); + user.group = group; + users.push(user); } - { - const m = await database.query(Model).filter({ username: 'paul' }).findOne(); - m.username = 'peter'; - await expect(database.persist(m)).rejects.toThrow('Key (username)=(peter) already exists'); - await expect(database.persist(m)).rejects.toBeInstanceOf(UniqueConstraintFailure); - } + await database.persist(...groups, ...users); { - const p = database.query(Model).filter({ username: 'paul' }).patchOne({ username: 'peter' }); - await expect(p).rejects.toThrow('Key (username)=(peter) already exists'); - await expect(p).rejects.toBeInstanceOf(UniqueConstraintFailure); + const users = await database.query(User).find(); } }); -test('vector embeddings', async () => { - @entity.name('vector_sentences') - class Sentences { - id: number & PrimaryKey & AutoIncrement = 0; - sentence: string = ''; - embedding: Vector<3> = []; - } - - const reflection = ReflectionClass.from(Sentences); - const embedding = reflection.getProperty('embedding'); - console.log(getVectorTypeOptions(embedding.type)); - - const database = await databaseFactory([Sentences]); - - const s1 = new Sentences; - s1.sentence = 'hello'; - s1.embedding = [0, 0.5, 2]; - await database.persist(s1); - - const query = database.query(Sentences); - - type ModelQuery = { - [P in keyof T]?: string; - }; - - const q: ModelQuery = {} as any; - - // count(q.sentence); - // sum(q.sentence); - // - // sort(l2Distance(q.embedding, [1, 2, 3]), 'asc'); - - const eq = (a: any, b: any): any => {}; - const lt = (a: any, b: any): any => {}; - const gt = (a: any, b: any): any => {}; - const where = (a: any): any => {}; - const select = (a: (m: ModelQuery) => any): any => {}; - const groupBy = (...a: any[]): any => {}; - const orderBy = (...a: any[]): any => {}; - const count = (a: any): any => {}; - const l2Distance = (a: any, b: any): any => {}; - - const join = (a: any, b: any): any => {}; - - select(m => { - where(eq(m.sentence, 'hello')); - return [m.id, m.sentence, count(m.sentence)]; - }); - - const sentenceQuery = [123, 123, 123]; - select(m => { - where(`${l2Distance(m.embedding, sentenceQuery)} > 0.5`); - - // WHERE - // embedding <=> ${sentenceQuery} > 0.5 AND group = 'abc' - // ORDER BY embedding <=> ${sentenceQuery}${asd ? ',' + asd : ''} - - // where(lt(l2Distance(m.embedding, [1, 2, 3]), 0.5)); - orderBy(l2Distance(m.embedding, sentenceQuery)); - return [m.id, m.sentence, l2Distance(m.embedding, sentenceQuery)]; - }); - - select(m => { - groupBy(m.sentence); - return [count(m.id)]; - }); - - interface Group { - id: number; - name: string; - } - - interface User { - id: number; - name: string; - groups: Group[]; - } - - select(m => { - join(m.groups, g => { - return [g.id, g.name]; - }); - return [m.id, m.name, m.groups]; - }); - - const rows = await database.query(Sentences) - .select(count(Sentences)) - .filter({ embedding: { $l2Distance: { query: [2, 3, 4], filter: { $eq: 3.774917217635375 } } } }) - .orderBy() - .find(); - - expect(rows).toHaveLength(1); - console.log(rows); -}); +// test('change different fields of multiple entities', async () => { +// @entity.name('model2') +// class Model { +// firstName: string = ''; +// lastName: string = ''; +// +// constructor(public id: number & PrimaryKey) { +// } +// } +// +// const database = await databaseFactory([Model]); +// +// { +// const m1 = new Model(1); +// m1.firstName = 'Peter'; +// await database.persist(m1); +// const m2 = new Model(2); +// m2.lastName = 'Smith'; +// await database.persist(m2); +// } +// +// { +// const m1 = await database.query(Model).filter({ id: 1 }).findOne(); +// const m2 = await database.query(Model).filter({ id: 2 }).findOne(); +// +// m1.firstName = 'Peter2'; +// m2.lastName = 'Smith2'; +// await database.persist(m1, m2); +// } +// +// { +// const m1 = await database.query(Model).filter({ id: 1 }).findOne(); +// const m2 = await database.query(Model).filter({ id: 2 }).findOne(); +// +// expect(m1).toMatchObject({ id: 1, firstName: 'Peter2', lastName: '' }); +// expect(m2).toMatchObject({ id: 2, firstName: '', lastName: 'Smith2' }); +// } +// }); +// +// test('change pk', async () => { +// @entity.name('model3') +// class Model { +// firstName: string = ''; +// +// constructor(public id: number & PrimaryKey) { +// } +// } +// +// const database = await databaseFactory([Model]); +// +// { +// const m1 = new Model(1); +// m1.firstName = 'Peter'; +// await database.persist(m1); +// } +// +// { +// const m1 = await database.query(Model).filter({ id: 1 }).findOne(); +// m1.id = 2; +// await database.persist(m1); +// } +// +// { +// const m1 = await database.query(Model).filter({ id: 2 }).findOne(); +// expect(m1).toMatchObject({ id: 2, firstName: 'Peter' }); +// } +// +// { +// const m1 = await database.query(Model).filter({ id: 2 }).findOne(); +// m1.id = 3; +// m1.firstName = 'Peter2'; +// await database.persist(m1); +// } +// +// { +// const m1 = await database.query(Model).filter({ id: 3 }).findOne(); +// expect(m1).toMatchObject({ id: 3, firstName: 'Peter2' }); +// } +// }); +// +// test('for update/share', async () => { +// @entity.name('model4') +// class Model { +// firstName: string = ''; +// +// constructor(public id: number & PrimaryKey) { +// } +// } +// +// const database = await databaseFactory([Model]); +// await database.persist(new Model(1), new Model(2)); +// +// { +// const query = database.query(Model).forUpdate(); +// const sql = database.adapter.createSelectSql(query); +// expect(sql.sql).toContain(' FOR UPDATE'); +// } +// +// { +// const query = database.query(Model).forShare(); +// const sql = database.adapter.createSelectSql(query); +// expect(sql.sql).toContain(' FOR SHARE'); +// } +// +// const items = await database.query(Model).forUpdate().find(); +// expect(items).toHaveLength(2); +// }); +// +// test('json field and query', async () => { +// @entity.name('product').collection('products') +// class Product { +// id: number & PrimaryKey & AutoIncrement = 0; +// raw?: { [key: string]: any }; +// } +// +// const database = await databaseFactory([Product]); +// +// await database.persist(cast({ raw: { productId: 1, name: 'first' } })); +// await database.persist(cast({ raw: { productId: 2, name: 'second' } })); +// +// { +// const res = await database.query(Product).filter({ 'raw.productId': 1 }).find(); +// expect(res).toMatchObject([{ id: 1, raw: { productId: 1, name: 'first' } }]); +// } +// +// { +// const res = await database.query(Product).filter({ 'raw.productId': 2 }).find(); +// expect(res).toMatchObject([{ id: 2, raw: { productId: 2, name: 'second' } }]); +// } +// }); +// +// test('unique constraint 1', async () => { +// class Model { +// id: number & PrimaryKey & AutoIncrement = 0; +// +// constructor(public username: string & Unique = '') { +// } +// } +// +// const database = await databaseFactory([Model]); +// +// await database.persist(new Model('peter')); +// await database.persist(new Model('paul')); +// +// { +// const m1 = new Model('peter'); +// await expect(database.persist(m1)).rejects.toThrow('Key (username)=(peter) already exists'); +// await expect(database.persist(m1)).rejects.toBeInstanceOf(UniqueConstraintFailure); +// +// try { +// await database.persist(m1); +// } catch (error: any) { +// assertInstanceOf(error, UniqueConstraintFailure); +// assertInstanceOf(error.cause, DatabaseInsertError); +// assertInstanceOf(error.cause.cause, DatabaseError); +// // error.cause.cause.cause is from the driver +// expect(error.cause.cause.cause.table).toBe('Model'); +// } +// } +// +// { +// const m1 = new Model('marie'); +// const m2 = new Model('marie'); +// await expect(database.persist(m1, m2)).rejects.toThrow('Key (username)=(marie) already exists'); +// await expect(database.persist(m1, m2)).rejects.toBeInstanceOf(UniqueConstraintFailure); +// } +// +// { +// const m = await database.query(Model).filter({ username: 'paul' }).findOne(); +// m.username = 'peter'; +// await expect(database.persist(m)).rejects.toThrow('Key (username)=(peter) already exists'); +// await expect(database.persist(m)).rejects.toBeInstanceOf(UniqueConstraintFailure); +// } +// +// { +// const p = database.query(Model).filter({ username: 'paul' }).patchOne({ username: 'peter' }); +// await expect(p).rejects.toThrow('Key (username)=(peter) already exists'); +// await expect(p).rejects.toBeInstanceOf(UniqueConstraintFailure); +// } +// }); +// +// test('vector embeddings', async () => { +// @entity.name('vector_sentences') +// class Sentences { +// id: number & PrimaryKey & AutoIncrement = 0; +// sentence: string = ''; +// embedding: Vector<3> = []; +// } +// +// const reflection = ReflectionClass.from(Sentences); +// const embedding = reflection.getProperty('embedding'); +// console.log(getVectorTypeOptions(embedding.type)); +// +// const database = await databaseFactory([Sentences]); +// +// const s1 = new Sentences; +// s1.sentence = 'hello'; +// s1.embedding = [0, 0.5, 2]; +// await database.persist(s1); +// +// const query = database.query(Sentences); +// +// type ModelQuery = { +// [P in keyof T]?: string; +// }; +// +// const q: ModelQuery = {} as any; +// +// // count(q.sentence); +// // sum(q.sentence); +// // +// // sort(l2Distance(q.embedding, [1, 2, 3]), 'asc'); +// +// const eq = (a: any, b: any): any => {}; +// const lt = (a: any, b: any): any => {}; +// const gt = (a: any, b: any): any => {}; +// const where = (a: any): any => {}; +// const select = (a: (m: ModelQuery) => any): any => {}; +// const groupBy = (...a: any[]): any => {}; +// const orderBy = (...a: any[]): any => {}; +// const count = (a: any): any => {}; +// const l2Distance = (a: any, b: any): any => {}; +// +// const join = (a: any, b: any): any => {}; +// +// select(m => { +// where(eq(m.sentence, 'hello')); +// return [m.id, m.sentence, count(m.sentence)]; +// }); +// +// const sentenceQuery = [123, 123, 123]; +// select(m => { +// where(`${l2Distance(m.embedding, sentenceQuery)} > 0.5`); +// +// // WHERE +// // embedding <=> ${sentenceQuery} > 0.5 AND group = 'abc' +// // ORDER BY embedding <=> ${sentenceQuery}${asd ? ',' + asd : ''} +// +// // where(lt(l2Distance(m.embedding, [1, 2, 3]), 0.5)); +// orderBy(l2Distance(m.embedding, sentenceQuery)); +// return [m.id, m.sentence, l2Distance(m.embedding, sentenceQuery)]; +// }); +// +// select(m => { +// groupBy(m.sentence); +// return [count(m.id)]; +// }); +// +// interface Group { +// id: number; +// name: string; +// } +// +// interface User { +// id: number; +// name: string; +// groups: Group[]; +// } +// +// select(m => { +// join(m.groups, g => { +// return [g.id, g.name]; +// }); +// return [m.id, m.name, m.groups]; +// }); +// +// const rows = await database.query(Sentences) +// .select(count(Sentences)) +// .filter({ embedding: { $l2Distance: { query: [2, 3, 4], filter: { $eq: 3.774917217635375 } } } }) +// .orderBy() +// .find(); +// +// expect(rows).toHaveLength(1); +// console.log(rows); +// }); diff --git a/packages/postgres/tsconfig.json b/packages/postgres/tsconfig.json index 84bb177bf..8bc193c34 100644 --- a/packages/postgres/tsconfig.json +++ b/packages/postgres/tsconfig.json @@ -22,6 +22,7 @@ ] }, "reflection": [ + "src/client.ts", "src/config.ts", "tests/**/*.ts" ], diff --git a/packages/sql/index.ts b/packages/sql/index.ts index e19b367fa..ca9ce343a 100644 --- a/packages/sql/index.ts +++ b/packages/sql/index.ts @@ -16,6 +16,9 @@ export * from './src/migration/migration-provider.js'; export * from './src/select.js'; +export * from './src/sql-builder-registry.js'; + +export * from './src/migration.js'; export * from './src/test.js'; export * from './src/schema/table.js'; export * from './src/reverse/schema-parser.js'; diff --git a/packages/sql/src/migration.ts b/packages/sql/src/migration.ts new file mode 100644 index 000000000..f04ca6314 --- /dev/null +++ b/packages/sql/src/migration.ts @@ -0,0 +1,33 @@ +/** + * Creates (and re-creates already existing) tables in the database. + * This is only for testing purposes useful. + * + * WARNING: THIS DELETES ALL AFFECTED TABLES AND ITS CONTENT. + */ +import { DatabaseEntityRegistry, DatabaseError } from '@deepkit/orm'; +import { DatabaseModel } from './schema/table.js'; +import { DefaultPlatform } from './platform/default-platform.js'; + +export async function createTables( + entityRegistry: DatabaseEntityRegistry, + pool: { getConnection(): Promise<{ run(sql: string): Promise; release(): void }> }, + platform: DefaultPlatform, + adapter: { getName(): string, getSchemaName(): string }, +): Promise { + const connection = await pool.getConnection(); + try { + const database = new DatabaseModel([], adapter.getName()); + database.schemaName = adapter.getSchemaName(); + platform.createTables(entityRegistry, database); + const DDLs = platform.getAddTablesDDL(database); + for (const sql of DDLs) { + try { + await connection.run(sql); + } catch (error) { + throw new DatabaseError(`Could not create table: ${error}\n${sql}`, { cause: error }); + } + } + } finally { + connection.release(); + } +} diff --git a/packages/sql/src/sql-adapter.ts b/packages/sql/src/sql-adapter.ts index 8c2191df5..0576e4b1d 100644 --- a/packages/sql/src/sql-adapter.ts +++ b/packages/sql/src/sql-adapter.ts @@ -23,6 +23,7 @@ import { DatabaseUpdateError, DeleteResult, filter, + getStateCacheId, MigrateOptions, orderBy, OrmEntity, @@ -31,7 +32,7 @@ import { SelectorResolver, SelectorState, } from '@deepkit/orm'; -import { isClass } from '@deepkit/core'; +import { formatError, isClass } from '@deepkit/core'; import { Changes, entity, @@ -47,6 +48,7 @@ import { DatabaseComparator, DatabaseModel } from './schema/table.js'; import { Stopwatch } from '@deepkit/stopwatch'; import { getPreparedEntity, PreparedAdapter, PreparedEntity, PreparedField } from './prepare.js'; import { SqlBuilderRegistry } from './sql-builder-registry.js'; +import { createTables } from './migration.js'; /** * user.address[0].street => [user, address[0].street] @@ -75,6 +77,8 @@ export abstract class SQLStatement { export abstract class SQLConnection { released: boolean = false; + protected cache: { [cacheId: string]: SQLStatement } = {}; + constructor( protected connectionPool: SQLConnectionPool, public logger: DatabaseLogger = new DatabaseLogger, @@ -83,11 +87,19 @@ export abstract class SQLConnection { ) { } + getCache(cacheId: string): SQLStatement | undefined { + return this.cache[cacheId]; + } + + setCache(cacheId: string, statement: SQLStatement) { + this.cache[cacheId] = statement; + } + release() { this.connectionPool.release(this); } - abstract prepare(sql: string): Promise; + abstract prepare(sql: string, selector: SelectorState): Promise; /** * Runs a single SQL query. @@ -97,7 +109,7 @@ export abstract class SQLConnection { abstract getChanges(): Promise; async execAndReturnSingle(sql: string, params?: any[]): Promise { - const stmt = await this.prepare(sql); + const stmt = await this.prepare(sql, {} as any/*todo*/); try { return await stmt.get(params); } finally { @@ -106,7 +118,7 @@ export abstract class SQLConnection { } async execAndReturnAll(sql: string, params?: any[]): Promise { - const stmt = await this.prepare(sql); + const stmt = await this.prepare(sql, {} as any/*todo*/); try { return await stmt.all(params); } finally { @@ -122,7 +134,7 @@ export abstract class SQLConnectionPool { * Reserves an existing or new connection. It's important to call `.release()` on it when * done. When release is not called a resource leak occurs and server crashes. */ - abstract getConnection(logger?: DatabaseLogger, transaction?: DatabaseTransaction, stopwatch?: Stopwatch): Promise; + abstract getConnection(logger?: DatabaseLogger, transaction?: DatabaseTransaction, stopwatch?: Stopwatch, cacheId?: string): Promise; public getActiveConnections() { return this.activeConnections; @@ -163,11 +175,7 @@ function buildSetFromChanges(platform: DefaultPlatform, classSchema: ReflectionC return set; } -export class SQLQueryResolver extends SelectorResolver { - protected tableId = this.platform.getTableIdentifier.bind(this.platform); - protected quoteIdentifier = this.platform.quoteIdentifier.bind(this.platform); - protected quote = this.platform.quoteValue.bind(this.platform); - +export class SQLSelectorResolver extends SelectorResolver { constructor( protected connectionPool: SQLConnectionPool, protected platform: DefaultPlatform, @@ -182,14 +190,11 @@ export class SQLQueryResolver extends SelectorResolver { state.schema, this.platform.serializer, this.session.getHydrator(), - withIdentityMap ? this.session.identityMap : undefined, + withIdentityMap && this.session.withIdentityMap ? this.session.identityMap : undefined, + this.session.withChangeDetection && state.withChangeDetection !== false, ); } - protected getTableIdentifier(schema: ReflectionClass) { - return this.platform.getTableIdentifier(schema); - } - /** * If possible, this method should handle specific SQL errors and convert * them to more specific error classes with more information, e.g. unique constraint. @@ -245,32 +250,31 @@ export class SQLQueryResolver extends SelectorResolver { } } - protected lastPreparedStatement?: SQLStatement; - async find(model: SelectorState): Promise { - const sqlBuilderFrame = this.session.stopwatch ? this.session.stopwatch.start('SQL Builder') : undefined; - const sqlBuilder = new SqlBuilder(this.adapter); - const sql = sqlBuilder.select(model); - if (sqlBuilderFrame) sqlBuilderFrame.end(); + const cacheId = getStateCacheId(model); const connectionFrame = this.session.stopwatch ? this.session.stopwatch.start('Connection acquisition') : undefined; - const connection = await this.connectionPool.getConnection(this.session.logger, this.session.assignedTransaction, this.session.stopwatch); + const connection = await this.connectionPool.getConnection(this.session.logger, this.session.assignedTransaction, this.session.stopwatch, cacheId); if (connectionFrame) connectionFrame.end(); let rows: any[] = []; try { // todo: find a way to cache prepared statements. this is just a test for best case scenario: - let stmt = this.adapter.cache.lastPreparedStatement; + + let stmt = connection.getCache(cacheId); if (!stmt) { - this.adapter.cache.lastPreparedStatement = stmt = await connection.prepare(sql.sql); + const sqlBuilderFrame = this.session.stopwatch ? this.session.stopwatch.start('SQL Builder') : undefined; + const sqlBuilder = new SqlBuilder(this.adapter); + const sql = sqlBuilder.select(model); + stmt = await connection.prepare(sql.sql, model); + connection.setCache(cacheId, stmt); + if (sqlBuilderFrame) sqlBuilderFrame.end(); } - rows = await stmt.all(sql.params); - // rows = await connection.execAndReturnAll(sql.sql, sql.params); + rows = await stmt.all(model.params); } catch (error: any) { - // error = this.handleSpecificError(error); - // console.log(sql.sql, sql.params); - // throw new DatabaseError(`Could not query ${model.schema.getClassName()} due to SQL error ${error.message}`, { cause: error }); + error = this.handleSpecificError(error); + throw new DatabaseError(`Could not query ${model.schema.getClassName()} due to SQL error ${error.message}`, { cause: error }); } finally { connection.release(); } @@ -290,6 +294,7 @@ export class SQLQueryResolver extends SelectorResolver { // } else { // for (const row of rows) results.push(formatter.hydrate(model, row)); // } + //todo the resolver itself does formatting directly from binary. for (const row of rows) results.push(row); if (formatterFrame) formatterFrame.end(); @@ -299,6 +304,7 @@ export class SQLQueryResolver extends SelectorResolver { async findOneOrUndefined(model: SelectorState): Promise { //when joins are used, it's important to fetch all rows + model.limit = 1; const items = await this.find(model); return items[0]; } @@ -469,7 +475,7 @@ export abstract class SQLDatabaseAdapter extends DatabaseAdapter implements Prep public preparedEntities = new Map, PreparedEntity>(); public builderRegistry: SqlBuilderRegistry = new SqlBuilderRegistry; - public cache: {[name: string]: any} = {}; + public cache: { [name: string]: any } = {}; abstract createPersistence(databaseSession: DatabaseSession): SQLPersistence; @@ -494,22 +500,7 @@ export abstract class SQLDatabaseAdapter extends DatabaseAdapter implements Prep * WARNING: THIS DELETES ALL AFFECTED TABLES AND ITS CONTENT. */ public async createTables(entityRegistry: DatabaseEntityRegistry): Promise { - const connection = await this.connectionPool.getConnection(); - try { - const database = new DatabaseModel([], this.getName()); - database.schemaName = this.getSchemaName(); - this.platform.createTables(entityRegistry, database); - const DDLs = this.platform.getAddTablesDDL(database); - for (const sql of DDLs) { - try { - await connection.run(sql); - } catch (error) { - throw new DatabaseError(`Could not create table: ${error}\n${sql}`, { cause: error }); - } - } - } finally { - connection.release(); - } + await createTables(entityRegistry, this.connectionPool, this.platform, this); } public async getMigrations(options: MigrateOptions, entityRegistry: DatabaseEntityRegistry): Promise<{ @@ -718,7 +709,7 @@ export class SQLPersistence extends DatabasePersistence { error = new DatabaseInsertError( classSchema, items as OrmEntity[], - `Could not insert ${classSchema.getClassName()} into database`, + `Could not insert ${classSchema.getClassName()} into database: ${formatError(error)}`, { cause: error }, ); throw this.handleSpecificError(error); diff --git a/packages/sql/src/sql-builder.ts b/packages/sql/src/sql-builder.ts index 773d17381..b65386157 100644 --- a/packages/sql/src/sql-builder.ts +++ b/packages/sql/src/sql-builder.ts @@ -16,16 +16,7 @@ import { ReflectionClass, ReflectionProperty, } from '@deepkit/type'; -import { - getStateCacheId, - isOp, - isProperty, - OpExpression, - opTag, - propertyTag, - SelectorProperty, - SelectorState, -} from '@deepkit/orm'; +import { isOp, isProperty, OpExpression, opTag, propertyTag, SelectorProperty, SelectorState } from '@deepkit/orm'; import { PreparedAdapter } from './prepare.js'; import { SqlBuilderState } from './sql-builder-registry.js'; @@ -574,10 +565,6 @@ export class SqlBuilder implements SqlBuilderState { model: SelectorState, options: { select?: string[] } = {}, ): Sql { - const cacheId = getStateCacheId(model); - let sql = this.adapter.cache[cacheId]; - if (sql) return sql; - const manualSelect = options.select && options.select.length ? options.select : undefined; this.params = model.params.slice(); @@ -591,7 +578,7 @@ export class SqlBuilder implements SqlBuilderState { // } } - sql = this.buildSql(model, 'SELECT ' + (manualSelect || this.sqlSelect).join(', ')); + const sql = this.buildSql(model, 'SELECT ' + (manualSelect || this.sqlSelect).join(', ')); if (this.platform.supportsSelectFor()) { switch (model.for) { @@ -606,6 +593,6 @@ export class SqlBuilder implements SqlBuilderState { } } - return this.adapter.cache[cacheId] = sql; + return sql; } } diff --git a/packages/sql/tests/my-platform.ts b/packages/sql/tests/my-platform.ts index aaa69d42d..35739e5b1 100644 --- a/packages/sql/tests/my-platform.ts +++ b/packages/sql/tests/my-platform.ts @@ -9,7 +9,7 @@ import { SQLConnectionPool, SQLDatabaseAdapter, SQLPersistence, - SQLQueryResolver, + SQLSelectorResolver, } from '../src/sql-adapter.js'; import { DatabaseLogger, DatabaseSession, DatabaseTransaction, SelectorState } from '@deepkit/orm'; import { Stopwatch } from '@deepkit/stopwatch'; @@ -44,7 +44,7 @@ export class MyAdapter extends SQLDatabaseAdapter { platform: DefaultPlatform = new MyPlatform(); createSelectorResolver(session: DatabaseSession): any { - return new SQLQueryResolver(this.connectionPool, this.platform, this, session); + return new SQLSelectorResolver(this.connectionPool, this.platform, this, session); } createPersistence(databaseSession: DatabaseSession): SQLPersistence { diff --git a/packages/sqlite/src/sqlite-adapter.ts b/packages/sqlite/src/sqlite-adapter.ts index 6cb1ee867..9fca23bbd 100644 --- a/packages/sqlite/src/sqlite-adapter.ts +++ b/packages/sqlite/src/sqlite-adapter.ts @@ -37,7 +37,7 @@ import { SQLConnectionPool, SQLDatabaseAdapter, SQLPersistence, - SQLQueryResolver, + SQLSelectorResolver, SQLStatement, } from '@deepkit/sql'; import { Changes, getPatchSerializeFunction, getSerializeFunction, ReflectionClass, resolvePath } from '@deepkit/type'; @@ -87,9 +87,9 @@ export class SQLiteStatement extends SQLStatement { async all(params: any[] = []): Promise { const frame = this.stopwatch ? this.stopwatch.start('Query', FrameCategory.databaseQuery) : undefined; try { - if (frame) frame.data({ sql: this.sql, sqlParams: params }); + // if (frame) frame.data({ sql: this.sql, sqlParams: params }); const res = this.stmt.all(...params); - this.logger.logQuery(this.sql, params); + // this.logger.logQuery(this.sql, params); return res; } catch (error: any) { error = ensureDatabaseError(error); @@ -206,6 +206,8 @@ export class SQLiteConnectionPool extends SQLConnectionPool { //we keep the first connection alive protected firstConnection?: SQLiteConnection; + protected connectionHints: { [cacheId: string]: SQLiteConnection } = {}; + constructor(protected dbPath: string | ':memory:') { super(); //memory databases can not have more than one connection @@ -220,7 +222,12 @@ export class SQLiteConnectionPool extends SQLConnectionPool { return new SQLiteConnection(this, this.dbPath, logger, transaction, stopwatch); } - async getConnection(logger?: DatabaseLogger, transaction?: SQLiteDatabaseTransaction, stopwatch?: Stopwatch): Promise { + async getConnection( + logger?: DatabaseLogger, + transaction?: SQLiteDatabaseTransaction, + stopwatch?: Stopwatch, + cacheHint?: string, + ): Promise { //when a transaction object is given, it means we make the connection sticky exclusively to that transaction //and only release the connection when the transaction is commit/rollback is executed. @@ -229,18 +236,26 @@ export class SQLiteConnectionPool extends SQLConnectionPool { return transaction.connection; } - const connection = this.firstConnection && this.firstConnection.released ? this.firstConnection : - this.activeConnections >= this.maxConnections - //we wait for the next query to be released and reuse it - ? await asyncOperation((resolve) => { - this.queue.push(resolve); - }) - : this.createConnection(logger, transaction, stopwatch); + let connection = cacheHint ? this.connectionHints[cacheHint] : undefined; + if (!connection) { + connection = this.firstConnection && this.firstConnection.released ? this.firstConnection : + this.activeConnections >= this.maxConnections + //we wait for the next query to be released and reuse it + ? await asyncOperation((resolve) => { + this.queue.push(resolve); + }) + : this.createConnection(logger, transaction, stopwatch); + } if (!this.firstConnection) this.firstConnection = connection; connection.released = false; connection.stopwatch = stopwatch; + if (cacheHint) { + // todo add interval to clear stale cache entries + this.connectionHints[cacheHint] = connection; + } + //first connection is always reused, so we update the logger if (logger) connection.logger = logger; @@ -423,7 +438,7 @@ export class SQLitePersistence extends SQLPersistence { } } -export class SQLiteQueryResolver extends SQLQueryResolver { +export class SQLiteSelectorResolver extends SQLSelectorResolver { constructor( protected connectionPool: SQLiteConnectionPool, protected platform: DefaultPlatform, @@ -452,6 +467,7 @@ export class SQLiteQueryResolver extends SQLQueryResolver): SQLQueryResolver { - return new SQLiteQueryResolver(this.connectionPool, this.platform, session); + createSelectorResolver(session: DatabaseSession): SQLSelectorResolver { + return new SQLiteSelectorResolver(this.connectionPool, this.platform, session); } async getInsertBatchSize(schema: ReflectionClass): Promise {