diff --git a/.gitignore b/.gitignore index 5deace34ddd..4daf6b49d28 100644 --- a/.gitignore +++ b/.gitignore @@ -20,6 +20,7 @@ .vscode/clang* .vscode/cpp* .zig-cache +.ccls-cache .bake-debug *.a *.bc diff --git a/.zed/settings.json b/.zed/settings.json new file mode 100644 index 00000000000..e1ecc3c5515 --- /dev/null +++ b/.zed/settings.json @@ -0,0 +1,9 @@ +{ + "lsp": { + "zls": { + "settings": { + "build_on_save_args": ["-Dgenerated-code=./build/debug/codegen", "--watch", "-fincremental"], + }, + }, + }, +} diff --git a/packages/bun-types/sql.d.ts b/packages/bun-types/sql.d.ts index 59681350ffb..cc47cfe77c9 100644 --- a/packages/bun-types/sql.d.ts +++ b/packages/bun-types/sql.d.ts @@ -10,6 +10,72 @@ declare module "bun" { * Releases the client back to the connection pool */ release(): void; + + /** + * Register callback when server replies with CopyInResponse/CopyOutResponse + */ + onCopyStart(handler: () => void): void; + + /** + * Send COPY data chunk (for COPY FROM STDIN) + */ + copySendData(data: string | Uint8Array | ArrayBuffer): void; + + /** + * Signal end of COPY FROM STDIN operation + */ + copyDone(): void; + + /** + * Abort COPY operation with optional error message + */ + copyFail(message?: string): void; + + /** + * Enable or disable streaming mode for COPY TO + * When enabled, data is not accumulated in memory and chunks are emitted via onCopyChunk + */ + setCopyStreamingMode(enable: boolean): void; + + /** + * Set COPY operation timeout in milliseconds (0 to disable) + */ + setCopyTimeout(ms: number): void; + + /** + * Set maximum buffer size for COPY operations in bytes + */ + setMaxCopyBufferSize(bytes: number): void; + + /** + * Register callback for streaming COPY TO data chunks + * + * @returns true if the handler was registered (adapter supports streaming), otherwise false. + */ + onCopyChunk(handler: (chunk: string | ArrayBuffer | Uint8Array) => void): boolean; + + /** + * Register callback when COPY TO completes + * + * @returns true if the handler was registered (adapter supports streaming), otherwise false. + */ + onCopyEnd(handler: () => void): boolean; + + /** + * Get current COPY operation defaults + */ + getCopyDefaults(): { + from?: { maxChunkSize?: number; maxBytes?: number; timeout?: number }; + to?: { stream?: boolean; maxBytes?: number; timeout?: number }; + }; + + /** + * Set COPY operation defaults + */ + setCopyDefaults(defaults: { + from?: { maxChunkSize?: number; maxBytes?: number; timeout?: number }; + to?: { stream?: boolean; maxBytes?: number; timeout?: number }; + }): this; } type ArrayType = @@ -18,7 +84,6 @@ declare module "bun" { | "CHAR" | "NAME" | "TEXT" - | "CHAR" | "VARCHAR" | "SMALLINT" | "INT2VECTOR" @@ -60,6 +125,40 @@ declare module "bun" { | "PG_DATABASE" | (string & {}); + /** + * PostgreSQL COPY binary format base types + */ + type CopyBinaryBaseType = + | "bool" + | "int2" + | "int4" + | "int8" + | "float4" + | "float8" + | "text" + | "varchar" + | "bpchar" + | "bytea" + | "date" + | "time" + | "timestamp" + | "timestamptz" + | "uuid" + | "json" + | "jsonb" + | "numeric" + | "interval"; + + /** + * PostgreSQL COPY binary format array types + */ + type CopyBinaryArrayType = `${CopyBinaryBaseType}[]`; + + /** + * PostgreSQL COPY binary format type tokens + */ + type CopyBinaryType = CopyBinaryBaseType | CopyBinaryArrayType; + /** * Represents a SQL array parameter */ @@ -101,6 +200,48 @@ declare module "bun" { constructor(message: string); } + /** + * COPY FROM STDIN options (PostgreSQL COPY protocol) + */ + type CopyFromOptions = { + format?: "text" | "csv" | "binary"; + delimiter?: string; + null?: string; + sanitizeNUL?: boolean; + replaceInvalid?: string; + signal?: AbortSignal; + onProgress?: (info: { bytesSent: number; chunksSent: number }) => void; + batchSize?: number; + /** + * When format is "binary" and passing row arrays, provide per-column type tokens + * (e.g. "int4","text","uuid","int4[]") + */ + binaryTypes?: readonly CopyBinaryType[]; + /** Maximum number of bytes to send per chunk (defaults to 256 KiB) */ + maxChunkSize?: number; + /** Maximum total number of bytes to send (0 = unlimited) */ + maxBytes?: number; + /** COPY operation timeout in milliseconds (0 = no timeout) */ + timeout?: number; + }; + + /** + * COPY TO STDOUT options (PostgreSQL COPY protocol) + */ + type CopyToOptions = { + table: string; + columns?: string[]; + format?: "text" | "csv" | "binary"; + signal?: AbortSignal; + onProgress?: (info: { bytesReceived: number; chunksReceived: number }) => void; + /** Maximum total number of bytes to receive (0 = unlimited) */ + maxBytes?: number; + /** Enable streaming mode to avoid buffering (defaults to true) */ + stream?: boolean; + /** COPY operation timeout in milliseconds (0 = no timeout) */ + timeout?: number; + }; + class PostgresError extends SQLError { public readonly code: string; public readonly errno?: string | undefined; @@ -537,6 +678,53 @@ declare module "bun" { * ``` */ (value: T): SQL.Helper; + + /** COPY FROM STDIN - bulk import helper (PostgreSQL COPY protocol) */ + copyFrom( + table: string, + columns: string[], + data: + | string + | unknown[] + | Iterable + | AsyncIterable + | AsyncIterable + | (() => Iterable), + options?: SQL.CopyFromOptions, + ): Promise<{ command: string | null; count: number | null }>; + + /** COPY TO STDOUT - streaming export helper (PostgreSQL COPY protocol) */ + copyTo(queryOrOptions: string | SQL.CopyToOptions): AsyncIterable; + + /** COPY TO STDOUT piping helper - pipe stream directly to a sink */ + copyToPipeTo( + queryOrOptions: string | SQL.CopyToOptions, + writable: + | WritableStream + | { + write: (chunk: string | ArrayBuffer | Uint8Array) => unknown | Promise; + close?: () => unknown | Promise; + end?: () => unknown | Promise; + }, + ): Promise; + + /** + * Get current COPY operation defaults + */ + getCopyDefaults(): { + from?: { maxChunkSize?: number; maxBytes?: number; timeout?: number }; + to?: { stream?: boolean; maxBytes?: number; timeout?: number }; + }; + + /** + * Set COPY operation defaults + * + * Returns the SQL instance for chaining. + */ + setCopyDefaults(defaults: { + from?: { maxChunkSize?: number; maxBytes?: number; timeout?: number }; + to?: { stream?: boolean; maxBytes?: number; timeout?: number }; + }): this; } /** @@ -866,6 +1054,53 @@ declare module "bun" { * const result = await sql.file("query.sql", [1, 2, 3]); */ file(filename: string, values?: any[]): SQL.Query; + + /** COPY FROM STDIN - bulk import helper (PostgreSQL COPY protocol) */ + copyFrom( + table: string, + columns: string[], + data: + | string + | unknown[] + | Iterable + | AsyncIterable + | AsyncIterable + | (() => Iterable), + options?: SQL.CopyFromOptions, + ): Promise<{ command: string | null; count: number | null }>; + + /** COPY TO STDOUT - streaming export helper (PostgreSQL COPY protocol) */ + copyTo(queryOrOptions: string | SQL.CopyToOptions): AsyncIterable; + + /** COPY TO STDOUT piping helper - pipe stream directly to a sink */ + copyToPipeTo( + queryOrOptions: string | SQL.CopyToOptions, + writable: + | WritableStream + | { + write: (chunk: string | ArrayBuffer | Uint8Array) => unknown | Promise; + close?: () => unknown | Promise; + end?: () => unknown | Promise; + }, + ): Promise; + + /** + * Get current COPY operation defaults + */ + getCopyDefaults(): { + from?: { maxChunkSize?: number; maxBytes?: number; timeout?: number }; + to?: { stream?: boolean; maxBytes?: number; timeout?: number }; + }; + + /** + * Set COPY operation defaults + * + * Returns the SQL instance for chaining. + */ + setCopyDefaults(defaults: { + from?: { maxChunkSize?: number; maxBytes?: number; timeout?: number }; + to?: { stream?: boolean; maxBytes?: number; timeout?: number }; + }): this; } /** diff --git a/src/bun.js/api/sql.classes.ts b/src/bun.js/api/sql.classes.ts index ee1405ca47c..26804f8ce63 100644 --- a/src/bun.js/api/sql.classes.ts +++ b/src/bun.js/api/sql.classes.ts @@ -3,6 +3,54 @@ import { ClassDefinition, define } from "../../codegen/class-definitions"; const types = ["PostgresSQL", "MySQL"]; const classes: ClassDefinition[] = []; for (const type of types) { + const proto: any = { + close: { + fn: "doClose", + }, + connected: { + getter: "getConnected", + }, + ref: { + fn: "doRef", + }, + unref: { + fn: "doUnref", + }, + flush: { + fn: "doFlush", + }, + queries: { + getter: "getQueries", + this: true, + }, + onconnect: { + getter: "getOnConnect", + setter: "setOnConnect", + this: true, + }, + onclose: { + getter: "getOnClose", + setter: "setOnClose", + this: true, + }, + }; + + // Add COPY methods only for PostgreSQL + if (type === "PostgresSQL") { + proto.sendCopyData = { + fn: "sendCopyData", + length: 1, + }; + proto.sendCopyDone = { + fn: "sendCopyDone", + length: 0, + }; + proto.sendCopyFail = { + fn: "sendCopyFail", + length: 1, + }; + } + classes.push( define({ name: `${type}Connection`, @@ -19,37 +67,7 @@ for (const type of types) { // }, }, JSType: "0b11101110", - proto: { - close: { - fn: "doClose", - }, - connected: { - getter: "getConnected", - }, - ref: { - fn: "doRef", - }, - unref: { - fn: "doUnref", - }, - flush: { - fn: "doFlush", - }, - queries: { - getter: "getQueries", - this: true, - }, - onconnect: { - getter: "getOnConnect", - setter: "setOnConnect", - this: true, - }, - onclose: { - getter: "getOnClose", - setter: "setOnClose", - this: true, - }, - }, + proto, values: ["onconnect", "onclose", "queries"], }), ); diff --git a/src/bun.js/bindings/ErrorCode.ts b/src/bun.js/bindings/ErrorCode.ts index 047aa1eab1b..1c54ebea386 100644 --- a/src/bun.js/bindings/ErrorCode.ts +++ b/src/bun.js/bindings/ErrorCode.ts @@ -173,6 +173,11 @@ const errors: ErrorCodeMapping = [ ["ERR_POSTGRES_AUTHENTICATION_FAILED_PBKDF2", Error, "PostgresError"], ["ERR_POSTGRES_CONNECTION_CLOSED", Error, "PostgresError"], ["ERR_POSTGRES_CONNECTION_TIMEOUT", Error, "PostgresError"], + ["ERR_POSTGRES_COPY_BOTH_NOT_IMPLEMENTED", Error, "PostgresError"], + ["ERR_POSTGRES_COPY_BUFFER_TOO_LARGE", RangeError, "PostgresError"], + ["ERR_POSTGRES_COPY_CHUNK_TOO_LARGE", RangeError, "PostgresError"], + ["ERR_POSTGRES_COPY_TIMEOUT", Error, "PostgresError"], + ["ERR_POSTGRES_COPY_WRITE_FAILED", Error, "PostgresError"], ["ERR_POSTGRES_EXPECTED_REQUEST", Error, "PostgresError"], ["ERR_POSTGRES_EXPECTED_STATEMENT", Error, "PostgresError"], ["ERR_POSTGRES_IDLE_TIMEOUT", Error, "PostgresError"], @@ -199,6 +204,7 @@ const errors: ErrorCodeMapping = [ ["ERR_POSTGRES_SYNTAX_ERROR", SyntaxError, "PostgresError"], ["ERR_POSTGRES_TLS_NOT_AVAILABLE", Error, "PostgresError"], ["ERR_POSTGRES_TLS_UPGRADE_FAILED", Error, "PostgresError"], + ["ERR_POSTGRES_UNEXPECTED_COPY_DATA", Error, "PostgresError"], ["ERR_POSTGRES_UNEXPECTED_MESSAGE", Error, "PostgresError"], ["ERR_POSTGRES_UNKNOWN_AUTHENTICATION_METHOD", Error, "PostgresError"], ["ERR_POSTGRES_UNKNOWN_FORMAT_CODE", Error, "PostgresError"], diff --git a/src/codegen/class-definitions.ts b/src/codegen/class-definitions.ts index 7af59d45b0b..849bf8eda18 100644 --- a/src/codegen/class-definitions.ts +++ b/src/codegen/class-definitions.ts @@ -291,7 +291,7 @@ export function define( Object.entries(klass) .sort(([a], [b]) => a.localeCompare(b)) .map(([k, v]) => { - v["DOMJIT"] = undefined; + v.DOMJIT = undefined; return [k, v]; }), ), @@ -299,7 +299,7 @@ export function define( Object.entries(proto) .sort(([a], [b]) => a.localeCompare(b)) .map(([k, v]) => { - v["DOMJIT"] = undefined; + v.DOMJIT = undefined; return [k, v]; }), ), diff --git a/src/js/builtins.d.ts b/src/js/builtins.d.ts index 570922df8c4..1280337ce3c 100644 --- a/src/js/builtins.d.ts +++ b/src/js/builtins.d.ts @@ -2,6 +2,156 @@ /// /// /// +// Bun.SQL COPY streaming helpers (copyFrom / copyTo) +declare namespace Bun { + interface SQL { + /** + * COPY FROM STDIN - High-level helper for bulk data import. + * + * Efficiently inserts large amounts of data into PostgreSQL using the COPY protocol. + * Much faster than individual INSERT statements for bulk operations. + * + * @param table - Target table name (will be properly escaped) + * @param columns - Array of column names to insert into + * @param data - Data to insert, supporting multiple formats: + * - `string`: Raw text data (tab-delimited by default, or CSV if format="csv") + * - `any[][]`: Array of row arrays (will be serialized based on format) + * - `Iterable`: Generator or iterable of row arrays + * - `AsyncIterable`: Async iterable of row arrays + * - `AsyncIterable`: Raw data chunks (for streaming) + * - `() => Iterable`: Function returning an iterable + * + * @param options - Configuration options: + * - `format`: "text" (default), "csv", or "binary" + * - `delimiter`: Custom delimiter (default: tab for text, comma for csv) + * - `null`: Custom NULL representation (default: \N for text, empty for csv) + * - `sanitizeNUL`: Strip NUL bytes (0x00) from data (default: false) + * - `replaceInvalid`: Replacement string for NUL bytes (default: "") + * - `signal`: AbortSignal for cancellation + * - `onProgress`: Callback for progress tracking (receives {bytesSent, chunksSent}) + * + * @returns Promise with command tag and row count + * + * @throws Error if connection is not available or COPY operation fails + * + * @example + * ```typescript + * // Array of rows + * await sql.copyFrom("users", ["id", "name"], [ + * [1, "Alice"], + * [2, "Bob"] + * ]); + * + * // Generator for memory efficiency + * async function* generateRows() { + * for (let i = 0; i < 1000000; i++) { + * yield [i, `User ${i}`]; + * } + * } + * await sql.copyFrom("users", ["id", "name"], generateRows()); + * + * // CSV format with progress + * await sql.copyFrom("users", ["id", "name"], csvData, { + * format: "csv", + * onProgress: ({ bytesSent }) => console.log(`Sent ${bytesSent} bytes`) + * }); + * ``` + */ + copyFrom( + table: string, + columns: string[], + data: + | string + | any[] + | Iterable + | AsyncIterable + | AsyncIterable + | (() => Iterable), + options?: { + /** Data format: "text" (default), "csv", or "binary" */ + format?: "text" | "csv" | "binary"; + /** Field delimiter (default: tab for text, comma for csv) */ + delimiter?: string; + /** NULL representation (default: \N for text, empty for csv) */ + null?: string; + /** Strip NUL (0x00) bytes from strings and data (default: false) */ + sanitizeNUL?: boolean; + /** Replacement for NUL bytes when sanitizeNUL is true (default: "") */ + replaceInvalid?: string; + /** AbortSignal for cancellation */ + signal?: AbortSignal; + /** Progress callback receiving {bytesSent, chunksSent} */ + onProgress?: (info: { bytesSent: number; chunksSent: number }) => void; + }, + ): Promise; + + /** + * COPY TO STDOUT - Streaming helper for bulk data export. + * + * Efficiently exports data from PostgreSQL using the COPY protocol. + * Returns an async iterable that streams data chunks as they arrive. + * Much faster than fetching individual rows for large datasets. + * + * @param queryOrOptions - Either: + * - A string SQL query: `"COPY table_name TO STDOUT"` or `"COPY (SELECT ...) TO STDOUT"` + * - An options object with table, columns, and format + * + * @returns AsyncIterable - Stream of data chunks + * - For "text" or "csv" format: yields `string` chunks + * - For "binary" format: yields `ArrayBuffer` chunks + * + * @throws Error if connection is not available or COPY operation fails + * + * @example + * ```typescript + * // Query string form + * for await (const chunk of sql.copyTo("COPY users TO STDOUT")) { + * console.log(chunk); // string chunk + * } + * + * // Options form with CSV + * for await (const chunk of sql.copyTo({ + * table: "users", + * columns: ["id", "name"], + * format: "csv" + * })) { + * process.stdout.write(chunk); + * } + * + * // With progress tracking and cancellation + * const controller = new AbortController(); + * for await (const chunk of sql.copyTo({ + * table: "large_table", + * format: "binary", + * signal: controller.signal, + * onProgress: ({ bytesReceived }) => { + * if (bytesReceived > 1000000) controller.abort(); + * } + * })) { + * // Process binary chunk + * } + * ``` + */ + copyTo( + queryOrOptions: + | string + | { + /** Table name to export from */ + table: string; + /** Column names to export (omit for all columns) */ + columns?: string[]; + /** Data format: "text" (default), "csv", or "binary" */ + format?: "text" | "csv" | "binary"; + /** AbortSignal for cancellation */ + signal?: AbortSignal; + /** Progress callback receiving {bytesReceived, chunksReceived} */ + onProgress?: (info: { bytesReceived: number; chunksReceived: number }) => void; + }, + ): AsyncIterable; + } +} + + // Typedefs for JSC intrinsics. Instead of @, we use $ type TODO = any; diff --git a/src/js/bun/sql.ts b/src/js/bun/sql.ts index dc436d367fe..b3159d89799 100644 --- a/src/js/bun/sql.ts +++ b/src/js/bun/sql.ts @@ -9,11 +9,171 @@ const { MySQLAdapter } = require("internal/sql/mysql"); const { SQLiteAdapter } = require("internal/sql/sqlite"); const { SQLHelper, parseOptions } = require("internal/sql/shared"); +const { validateOneOf, validateObject } = require("internal/validators"); + const { SQLError, PostgresError, SQLiteError, MySQLError } = require("internal/sql/errors"); +const ensurePostgresAdapter = (adapter: any, methodName: string) => { + if (adapter !== "postgres") { + throw $ERR_INVALID_ARG_VALUE( + "options.adapter", + adapter, + `${methodName} is only supported for the postgres adapter`, + ); + } +}; + +type CopyStreamLikeSink = { + write: (chunk: string | ArrayBuffer | Uint8Array) => unknown | Promise; + close?: () => unknown | Promise; + end?: () => unknown | Promise; +}; + +const isWritableStream = (value: unknown): value is WritableStream => { + return ( + !!value && + typeof value === "object" && + "getWriter" in value && + typeof (value as { getWriter: unknown }).getWriter === "function" + ); +}; + +const isWritableSink = (value: unknown): value is CopyStreamLikeSink => { + return ( + !!value && + typeof value === "object" && + "write" in value && + typeof (value as { write: unknown }).write === "function" + ); +}; + +const isIterable = (value: unknown): value is Iterable => { + return ( + !!value && + typeof value === "object" && + Symbol.iterator in value && + typeof (value as { [Symbol.iterator]: unknown })[Symbol.iterator] === "function" + ); +}; + +const isAsyncIterable = (value: unknown): value is AsyncIterable => { + return ( + !!value && + typeof value === "object" && + Symbol.asyncIterator in value && + typeof (value as { [Symbol.asyncIterator]: unknown })[Symbol.asyncIterator] === "function" + ); +}; + +const hasByteLength = (value: unknown): value is { byteLength: number } => { + return ( + !!value && + (typeof value === "object" || typeof value === "function") && + "byteLength" in (value as object) && + typeof (value as { byteLength: unknown }).byteLength === "number" + ); +}; + +const toUint8ArrayView = (value: unknown): Uint8Array | null => { + if (value instanceof Uint8Array) return value; + if (value instanceof ArrayBuffer) return new Uint8Array(value); + if (ArrayBuffer.isView(value)) { + const view = value as ArrayBufferView; + return new Uint8Array(view.buffer, view.byteOffset, view.byteLength); + } + return null; +}; + +// Import shared PostgreSQL encoding utilities (types only via import type, runtime via require) +import type { CopyBinaryType, CopyBinaryBaseType } from "internal/sql/postgres-encoding"; + +const { + encodeBinaryValue, + encodeBinaryRow, + encodeArray1D, + createBinaryCopyHeader, + createBinaryCopyTrailer, + copyTextEscape, + csvQuote: pgCsvQuote, + needsCsvQuote, +} = require("internal/sql/postgres-encoding"); + +const { + TYPE_OID, + TYPE_ARRAY_OID, + isSupportedBaseType, + isSupportedArrayType, + getSupportedBaseTypes, + getSupportedArrayTypes, +} = require("internal/sql/postgres-types"); + const defineProperties = Object.defineProperties; -type TransactionCallback = (sql: (strings: string, ...values: any[]) => Query) => Promise; +// Default COPY protocol constants +const DEFAULT_COPY_BATCH_SIZE = 64 * 1024; // 64 KiB - batch accumulation threshold +const DEFAULT_COPY_MAX_CHUNK_SIZE = 256 * 1024; // 256 KiB - max bytes per chunk + +// Re-export types for convenience +export type { CopyBinaryType, CopyBinaryBaseType }; + +interface CopyFromOptionsBase { + format?: "text" | "csv" | "binary"; + delimiter?: string; + null?: string; + sanitizeNUL?: boolean; + replaceInvalid?: string; + signal?: AbortSignal; + onProgress?: (info: { bytesSent: number; chunksSent: number }) => void; + batchSize?: number; + /** + * Maximum number of bytes to send per chunk. Defaults to 256 KiB when not set. + */ + maxChunkSize?: number; + /** + * Maximum total number of bytes to send for this COPY FROM operation. + * When exceeded, the operation is aborted with CopyFail. + */ + maxBytes?: number; + /** + * COPY operation timeout in milliseconds (0 = no timeout). + */ + timeout?: number; +} + +interface CopyFromBinaryOptions extends CopyFromOptionsBase { + format: "binary"; + binaryTypes: CopyBinaryType[]; +} + +type CopyFromOptions = CopyFromOptionsBase | CopyFromBinaryOptions; + +const isCopyFromBinaryOptions = (options: CopyFromOptions | undefined): options is CopyFromBinaryOptions => { + return !!options && options.format === "binary" && "binaryTypes" in options && $isArray(options.binaryTypes); +}; + +interface CopyToOptions { + table: string; + columns?: string[]; + format?: "text" | "csv" | "binary"; + signal?: AbortSignal; + onProgress?: (info: { bytesReceived: number; chunksReceived: number }) => void; + /** + * Maximum total number of bytes to receive for this COPY TO operation. + * When exceeded, the stream stops early with an error. + */ + maxBytes?: number; + /** + * Enable streaming mode to avoid buffering in Zig. Defaults to true. + */ + stream?: boolean; + /** + * COPY operation timeout in milliseconds (0 = no timeout). + */ + timeout?: number; +} + +type SQLTemplateFn = (strings: TemplateStringsArray | string, ...values: unknown[]) => Query; +type TransactionCallback = (sql: SQLTemplateFn) => Promise; enum ReservedConnectionState { acceptQueries = 1 << 0, @@ -41,6 +201,166 @@ function adapterFromOptions(options: Bun.SQL.__internal.DefinedOptions) { } } +// Helper types and functions for COPY protocol +type __CopyDefaults__ = { + from: { maxChunkSize: number; maxBytes: number }; + to: { stream: boolean; maxBytes: number }; +}; + +/** + * Resolves copyFrom limits (maxBytes, maxChunkSize) from options and pool defaults + */ +function resolveCopyFromLimits(options: any, pool: any): { maxBytes: number; maxChunkSize: number } { + const __defaults__: __CopyDefaults__ | undefined = + "getCopyDefaults" in pool + ? (pool as unknown as { getCopyDefaults: () => __CopyDefaults__ }).getCopyDefaults() + : undefined; + const __fromDefaults__ = (__defaults__ && __defaults__.from) || { + maxChunkSize: DEFAULT_COPY_MAX_CHUNK_SIZE, + maxBytes: 0, + }; + + const maxBytes = + options && typeof options.maxBytes === "number" && options.maxBytes > 0 + ? Number(options.maxBytes) + : Math.max(0, Math.trunc(Number(__fromDefaults__.maxBytes) || 0)); + + const maxChunkSize = + options && typeof options.maxChunkSize === "number" && options.maxChunkSize > 0 + ? Math.max(1, Math.trunc(Number(options.maxChunkSize))) + : Math.max(1, Math.trunc(Number(__fromDefaults__.maxChunkSize) || DEFAULT_COPY_MAX_CHUNK_SIZE)); + + return { maxBytes, maxChunkSize }; +} + +/** + * Resolves copyTo maxBytes from query options and pool defaults + */ +function resolveCopyToMaxBytes(queryOrOptions: any, pool: any): number { + const __defaults__: __CopyDefaults__ | undefined = + "getCopyDefaults" in pool + ? (pool as unknown as { getCopyDefaults: () => __CopyDefaults__ }).getCopyDefaults() + : undefined; + const __toDefaults__ = (__defaults__ && __defaults__.to) || { stream: true, maxBytes: 0 }; + + return typeof queryOrOptions === "string" + ? Math.max(0, Math.trunc(Number(__toDefaults__.maxBytes) || 0)) + : typeof queryOrOptions?.maxBytes === "number" && queryOrOptions.maxBytes > 0 + ? Number(queryOrOptions.maxBytes) + : Math.max(0, Math.trunc(Number(__toDefaults__.maxBytes) || 0)); +} + +function getByteLength(value: string | { byteLength: number } | Uint8Array | ArrayBuffer): number { + if (typeof value === "string") { + return Buffer?.byteLength ? Buffer.byteLength(value, "utf8") : new TextEncoder().encode(value).byteLength; + } + const length = value?.byteLength; + return Number.isFinite(length) ? Math.max(0, Math.trunc(length)) : 0; +} + +/** + * Await socket writability with a microtask fallback to prevent hanging. + * Used throughout COPY protocol to handle backpressure. + */ +async function awaitWritableWithFallback(reserved: any, pool: any): Promise { + if (reserved && typeof reserved.awaitWritable === "function") { + const promise = reserved.awaitWritable(); + if (promise && typeof promise.then === "function") { + await promise; + return; + } + } + if (pool && typeof pool.awaitWritableFor === "function") { + const promise = pool.awaitWritableFor(reserved); + if (promise && typeof promise.then === "function") { + await promise; + return; + } + } + await new Promise(queueMicrotask); +} + +/** + * Sends data in chunks with backpressure handling + */ +async function sendChunkedData( + data: Uint8Array | string, + reserved: any, + pool: any, + limits: { maxBytes: number; maxChunkSize: number }, + counters: { bytesSent: number; chunksSent: number }, + notifyProgress: () => void, +): Promise { + const { maxBytes, maxChunkSize } = limits; + + // Convert string to Uint8Array to ensure chunking by bytes, not characters + // This prevents splitting multi-byte UTF-8 characters + const bytes: Uint8Array = typeof data === "string" ? new TextEncoder().encode(data) : data; + const dataLength = bytes.byteLength; + + if (dataLength <= maxChunkSize) { + if (maxBytes && counters.bytesSent + dataLength > maxBytes) { + throw new Error("copyFrom: maxBytes exceeded"); + } + reserved.copySendData(bytes); + counters.bytesSent += dataLength; + counters.chunksSent += 1; + notifyProgress(); + await awaitWritableWithFallback(reserved, pool); + } else { + for (let i = 0; i < dataLength; i += maxChunkSize) { + const part = bytes.subarray(i, Math.min(dataLength, i + maxChunkSize)); + const partLength = part.byteLength; + + if (maxBytes && counters.bytesSent + partLength > maxBytes) { + throw new Error("copyFrom: maxBytes exceeded"); + } + reserved.copySendData(part); + counters.bytesSent += partLength; + counters.chunksSent += 1; + notifyProgress(); + await awaitWritableWithFallback(reserved, pool); + } + } +} + +/** + * Validates binary types for COPY FROM binary format + */ +function validateBinaryTypes(options: any, columns: string[] | undefined): string[] { + const types = options?.binaryTypes as string[] | undefined; + if (!types || !$isArray(types)) { + throw new Error( + "Binary COPY format requires raw bytes or provide options.binaryTypes to enable automatic binary row encoding.", + ); + } + if (types.length !== (columns?.length ?? types.length)) { + throw new Error("binaryTypes length must match number of columns for COPY FROM."); + } + + // Validate that each provided token is a supported base or array type + // Uses helper functions from shared postgres-types module + const isSupportedToken = (token: string): boolean => { + if (typeof token !== "string" || token.length === 0) return false; + return token.endsWith("[]") ? isSupportedArrayType(token) : isSupportedBaseType(token); + }; + + for (let i = 0; i < types.length; i++) { + const token = types[i]; + if (!isSupportedToken(token)) { + throw new Error( + `Unsupported COPY binaryTypes token at index ${i}: "${token}".` + + " Supported base types include: " + + getSupportedBaseTypes().join(", ") + + "; supported array types include: " + + getSupportedArrayTypes().join(", "), + ); + } + } + + return types; +} + const SQL: typeof Bun.SQL = function SQL( stringOrUrlOrOptions: Bun.SQL.Options | string | undefined = undefined, definitelyOptionsButMaybeEmpty: Bun.SQL.Options = {}, @@ -109,8 +429,12 @@ const SQL: typeof Bun.SQL = function SQL( } function queryFromPool( - strings: string | TemplateStringsArray | import("internal/sql/shared.ts").SQLHelper | Query, - values: any[], + strings: + | string + | TemplateStringsArray + | import("internal/sql/shared.ts").SQLHelper + | Query, + values: unknown[], ) { try { return new Query( @@ -126,8 +450,12 @@ const SQL: typeof Bun.SQL = function SQL( } function unsafeQuery( - strings: string | TemplateStringsArray | import("internal/sql/shared.ts").SQLHelper | Query, - values: any[], + strings: + | string + | TemplateStringsArray + | import("internal/sql/shared.ts").SQLHelper + | Query, + values: unknown[], ) { try { let flags = connectionInfo.bigint ? SQLQueryFlags.bigint | SQLQueryFlags.unsafe : SQLQueryFlags.unsafe; @@ -173,8 +501,12 @@ const SQL: typeof Bun.SQL = function SQL( } function queryFromTransaction( - strings: string | TemplateStringsArray | import("internal/sql/shared.ts").SQLHelper | Query, - values: any[], + strings: + | string + | TemplateStringsArray + | import("internal/sql/shared.ts").SQLHelper + | Query, + values: unknown[], pooledConnection: PooledPostgresConnection, transactionQueries: Set>, ) { @@ -197,8 +529,12 @@ const SQL: typeof Bun.SQL = function SQL( } function unsafeQueryFromTransaction( - strings: string | TemplateStringsArray | import("internal/sql/shared.ts").SQLHelper | Query, - values: any[], + strings: + | string + | TemplateStringsArray + | import("internal/sql/shared.ts").SQLHelper + | Query, + values: unknown[], pooledConnection: PooledPostgresConnection, transactionQueries: Set>, ) { @@ -237,7 +573,7 @@ const SQL: typeof Bun.SQL = function SQL( } } - function onReserveConnected(this: Query, err: Error | null, pooledConnection) { + function onReserveConnected(this: Query, err: Error | null, pooledConnection) { const { resolve, reject } = this; if (err) { @@ -253,12 +589,21 @@ const SQL: typeof Bun.SQL = function SQL( queries: new Set(), }; + const clampUint32 = (value: number) => { + const n = Number(value); + if (!Number.isFinite(n) || n <= 0) return 0; + return Math.min(0xffffffff, Math.trunc(n)); + }; + const onClose = onTransactionDisconnected.bind(state); if (pooledConnection.onClose) { pooledConnection.onClose(onClose); } - function reserved_sql(strings: string | TemplateStringsArray | SQLHelper | Query, ...values: any[]) { + function reserved_sql( + strings: string | TemplateStringsArray | SQLHelper | Query, + ...values: unknown[] + ) { if ( state.connectionState & ReservedConnectionState.closed || !(state.connectionState & ReservedConnectionState.acceptQueries) @@ -298,7 +643,7 @@ const SQL: typeof Bun.SQL = function SQL( reserved_sql.commitDistributed = async function (name: string) { if (!pool.getCommitDistributedSQL) { - throw Error(`This adapter doesn't support distributed transactions.`); + throw new Error(`This adapter doesn't support distributed transactions.`); } const sql = pool.getCommitDistributedSQL(name); @@ -306,7 +651,7 @@ const SQL: typeof Bun.SQL = function SQL( }; reserved_sql.rollbackDistributed = async function (name: string) { if (!pool.getRollbackDistributedSQL) { - throw Error(`This adapter doesn't support distributed transactions.`); + throw new Error(`This adapter doesn't support distributed transactions.`); } const sql = pool.getRollbackDistributedSQL(name); @@ -317,6 +662,149 @@ const SQL: typeof Bun.SQL = function SQL( // this matchs the behavior of the postgres package reserved_sql.reserve = () => sql.reserve(); reserved_sql.array = sql.array; + + // COPY FROM STDIN low-level helpers (Phase 2) + // These delegate to adapter instance methods bound to this reserved connection + reserved_sql.onCopyStart = (handler: () => void) => { + ensurePostgresAdapter(connectionInfo.adapter, "onCopyStart"); + // register one-shot callback when server replies with CopyInResponse/CopyOutResponse + pool.onCopyStartFor(pooledConnection, handler); + }; + reserved_sql.copySendData = (data: string | Uint8Array) => { + ensurePostgresAdapter(connectionInfo.adapter, "copySendData"); + pool.copySendDataFor(pooledConnection, data); + }; + reserved_sql.copyDone = () => { + ensurePostgresAdapter(connectionInfo.adapter, "copyDone"); + pool.copyDoneFor(pooledConnection); + }; + reserved_sql.copyFail = (message?: string) => { + ensurePostgresAdapter(connectionInfo.adapter, "copyFail"); + pool.copyFailFor(pooledConnection, message); + }; + /** + * Enable or disable streaming mode for COPY TO. + * When enabled, the connection will not accumulate COPY TO data in memory + * and will emit chunks via onCopyChunk instead. + */ + /** @type {(enable: boolean) => void} */ + reserved_sql.setCopyStreamingMode = (enable: boolean) => { + ensurePostgresAdapter(connectionInfo.adapter, "setCopyStreamingMode"); + const copyPool = pool as unknown as { + setCopyStreamingModeFor?: (connection: any, enable: boolean) => void; + getConnectionForQuery?: (connection: any) => any; + }; + if (typeof copyPool.setCopyStreamingModeFor === "function") { + copyPool.setCopyStreamingModeFor(pooledConnection, !!enable); + return; + } + + const underlying = copyPool.getConnectionForQuery + ? copyPool.getConnectionForQuery(pooledConnection) + : pooledConnection?.connection; + + const adapter = PostgresAdapter as unknown as { + setCopyStreamingMode?: (connection: any, enable: boolean) => void; + }; + if (underlying && typeof adapter.setCopyStreamingMode === "function") { + adapter.setCopyStreamingMode(underlying, !!enable); + } + }; + /** @type {(ms: number) => void} */ + reserved_sql.setCopyTimeout = (ms: number) => { + ensurePostgresAdapter(connectionInfo.adapter, "setCopyTimeout"); + const copyPool = pool as unknown as { + setCopyTimeoutFor?: (connection: any, ms: number) => void; + getConnectionForQuery?: (connection: any) => any; + }; + const clamped = clampUint32(ms); + + if (typeof copyPool.setCopyTimeoutFor === "function") { + copyPool.setCopyTimeoutFor(pooledConnection, clamped); + return; + } + + const underlying = copyPool.getConnectionForQuery + ? copyPool.getConnectionForQuery(pooledConnection) + : pooledConnection?.connection; + + const adapter = PostgresAdapter as unknown as { setCopyTimeout?: (connection: any, ms: number) => void }; + if (underlying && typeof adapter.setCopyTimeout === "function") { + adapter.setCopyTimeout(underlying, clamped); + } + }; + /** @type {(bytes: number) => void} */ + reserved_sql.setMaxCopyBufferSize = (bytes: number) => { + ensurePostgresAdapter(connectionInfo.adapter, "setMaxCopyBufferSize"); + const copyPool = pool as unknown as { + setMaxCopyBufferSizeFor?: (connection: any, bytes: number) => void; + getConnectionForQuery?: (connection: any) => any; + }; + const clamped = clampUint32(bytes); + + if (typeof copyPool.setMaxCopyBufferSizeFor === "function") { + copyPool.setMaxCopyBufferSizeFor(pooledConnection, clamped); + return; + } + + const underlying = copyPool.getConnectionForQuery + ? copyPool.getConnectionForQuery(pooledConnection) + : pooledConnection?.connection; + + const adapter = PostgresAdapter as unknown as { setMaxCopyBufferSize?: (connection: any, bytes: number) => void }; + if (underlying && typeof adapter.setMaxCopyBufferSize === "function") { + // Delegate to adapter binding so native-side safety caps are applied consistently. + adapter.setMaxCopyBufferSize(underlying, clamped); + } + }; + // Expose adapter-level COPY defaults on reserved connections + reserved_sql.getCopyDefaults = () => { + ensurePostgresAdapter(connectionInfo.adapter, "getCopyDefaults"); + return pool.getCopyDefaults(); + }; + reserved_sql.setCopyDefaults = (defaults: { + from?: { maxChunkSize?: number; maxBytes?: number; timeout?: number }; + to?: { stream?: boolean; maxBytes?: number; timeout?: number }; + }) => { + ensurePostgresAdapter(connectionInfo.adapter, "setCopyDefaults"); + pool.setCopyDefaultsFor(pooledConnection, defaults); + return reserved_sql; + }; + + // Streaming COPY TO STDOUT helpers (Phase 4) + reserved_sql.onCopyChunk = (handler: (chunk: string | ArrayBuffer | Uint8Array) => void) => { + ensurePostgresAdapter(connectionInfo.adapter, "onCopyChunk"); + const copyPool = pool as unknown as { getConnectionForQuery?: (connection: any) => any }; + const underlying = copyPool.getConnectionForQuery + ? copyPool.getConnectionForQuery(pooledConnection) + : pooledConnection?.connection; + + const adapter = PostgresAdapter as unknown as { + onCopyChunk?: (connection: any, handler: (chunk: any) => void) => void; + }; + if (underlying && typeof adapter.onCopyChunk === "function") { + adapter.onCopyChunk(underlying, handler as unknown as (chunk: any) => void); + return true; + } + + return false; + }; + reserved_sql.onCopyEnd = (handler: () => void) => { + ensurePostgresAdapter(connectionInfo.adapter, "onCopyEnd"); + const copyPool = pool as unknown as { getConnectionForQuery?: (connection: any) => any }; + const underlying = copyPool.getConnectionForQuery + ? copyPool.getConnectionForQuery(pooledConnection) + : pooledConnection?.connection; + + const adapter = PostgresAdapter as unknown as { onCopyEnd?: (connection: any, handler: () => void) => void }; + if (underlying && typeof adapter.onCopyEnd === "function") { + adapter.onCopyEnd(underlying, handler); + return true; + } + + return false; + }; + function onTransactionFinished(transaction_promise: Promise) { reservedTransaction.delete(transaction_promise); } @@ -558,8 +1046,12 @@ const SQL: typeof Bun.SQL = function SQL( return unsafeQueryFromTransaction(string, [], pooledConnection, state.queries); } function transaction_sql( - strings: string | TemplateStringsArray | import("internal/sql/shared.ts").SQLHelper | Query, - ...values: any[] + strings: + | string + | TemplateStringsArray + | import("internal/sql/shared.ts").SQLHelper + | Query, + ...values: unknown[] ) { if ( state.connectionState & ReservedConnectionState.closed || @@ -602,7 +1094,7 @@ const SQL: typeof Bun.SQL = function SQL( }; transaction_sql.commitDistributed = async function (name: string) { if (!pool.getCommitDistributedSQL) { - throw Error(`This adapter doesn't support distributed transactions.`); + throw new Error(`This adapter doesn't support distributed transactions.`); } const sql = pool.getCommitDistributedSQL(name); @@ -610,7 +1102,7 @@ const SQL: typeof Bun.SQL = function SQL( }; transaction_sql.rollbackDistributed = async function (name: string) { if (!pool.getRollbackDistributedSQL) { - throw Error(`This adapter doesn't support distributed transactions.`); + throw new Error(`This adapter doesn't support distributed transactions.`); } const sql = pool.getRollbackDistributedSQL(name); @@ -835,13 +1327,1029 @@ const SQL: typeof Bun.SQL = function SQL( sql.array = (values: any[], typeNameOrID: number | string | undefined = undefined) => { return pool.array(values, typeNameOrID); }; + + type CopyReservedConnection = { + unsafe: (sqlText: string, values?: unknown[]) => Promise; + release: () => Promise; + onCopyStart?: (handler: () => void) => void; + onCopyChunk?: (handler: (chunk: string | ArrayBuffer | Uint8Array) => void) => boolean | void; + onCopyEnd?: (handler: () => void) => boolean | void; + copySendData: (data: string | Uint8Array) => void; + copyDone: () => void; + copyFail?: (message?: string) => void; + setCopyTimeout?: (ms: number) => void; + setCopyStreamingMode?: (enable: boolean) => void; + getCopyDefaults?: () => { + from?: { maxChunkSize?: number; maxBytes?: number; timeout?: number }; + to?: { stream?: boolean; maxBytes?: number; timeout?: number }; + }; + }; + + // High-level COPY FROM STDIN helper + // Usage: await sql.copyFrom("table", ["col1","col2"], data, { + // format: "text"|"csv"|"binary", + // delimiter?: string, + // null?: string, + // sanitizeNUL?: boolean, // strip NUL (0x00) from strings and raw bytes + // replaceInvalid?: string, // replacement for NUL in strings (default: "") + // signal?: AbortSignal, // optional cancellation + // onProgress?: (info: { bytesSent: number; chunksSent: number }) => void, // optional progress + // }) + // - data can be: string, any[][], generator/iterator, AsyncIterable, or AsyncIterable + sql.copyFrom = async function ( + this: Bun.SQL, + table: string, + columns: string[], + data: + | string + | unknown[] + | Iterable + | AsyncIterable + | AsyncIterable + | (() => Iterable), + options?: CopyFromOptions, + ) { + ensurePostgresAdapter(connectionInfo.adapter, "COPY"); + + if (typeof table !== "string" || table.length === 0) { + throw $ERR_INVALID_ARG_VALUE("table", table, "must be a non-empty string"); + } + + if (!$isArray(columns)) { + throw $ERR_INVALID_ARG_VALUE("columns", columns, "must be an array of strings"); + } + for (let i = 0; i < columns.length; i++) { + const column = columns[i]; + if (typeof column !== "string") { + throw $ERR_INVALID_ARG_VALUE(`columns[${i}]`, column, "must be a string"); + } + } + + if (options !== undefined) { + validateObject(options, "options"); + } + + // Reserve a dedicated connection for COPY + const reserved = (await sql.reserve()) as CopyReservedConnection; + const closeReserved = async () => { + try { + await reserved.release(); + } catch {} + }; + + // Helpers + const escapeIdentifier = + pool.escapeIdentifier && typeof pool.escapeIdentifier === "function" + ? (s: string) => pool.escapeIdentifier(s) + : (s: string) => '"' + String(s).replaceAll('"', '""').replaceAll(".", '"."') + '"'; + + if (options?.format !== undefined) { + validateOneOf(options.format, "options.format", ["text", "csv", "binary"]); + } + + const fmt = options?.format === "csv" ? "csv" : options?.format === "binary" ? "binary" : "text"; + + let delimiter = options?.delimiter ?? (fmt === "csv" ? "," : "\t"); + if (delimiter !== undefined) { + delimiter = String(delimiter); + if (delimiter.length !== 1) { + throw $ERR_INVALID_ARG_VALUE("options.delimiter", delimiter, "must be exactly one character"); + } + } + + const nullToken = options?.null ?? (fmt === "csv" ? "" : "\\N"); + + const stripNul = options?.sanitizeNUL === true; + const replaceInvalid = options?.replaceInvalid ?? ""; + + const sanitizeString = (s: string) => (stripNul ? s.replaceAll("\u0000", replaceInvalid) : s); + const sanitizeBytes = (u8: Uint8Array) => { + if (!stripNul) return u8; + let keep = 0; + for (let i = 0; i < u8.length; i++) if (u8[i] !== 0) keep++; + if (keep === u8.length) return u8; + const out = new Uint8Array(keep); + let j = 0; + for (let i = 0; i < u8.length; i++) if (u8[i] !== 0) out[j++] = u8[i]; + return out; + }; + + // Abort handling and progress + const signal: AbortSignal | undefined = options?.signal; + let aborted = false; + const counters = { bytesSent: 0, chunksSent: 0 }; + const notifyProgress = () => { + try { + options?.onProgress?.({ bytesSent: counters.bytesSent, chunksSent: counters.chunksSent }); + } catch {} + }; + const onAbort = () => { + aborted = true; + }; + if (signal) { + if (signal.aborted) onAbort(); + signal.addEventListener("abort", onAbort, { once: true }); + } + + const serializeValue = (v: any): string => { + if (v === null || v === undefined) return nullToken; + if (v instanceof Date) return v.toISOString(); + if (typeof v === "boolean") return fmt === "csv" ? (v ? "true" : "false") : v ? "t" : "f"; + if (typeof v === "number" || typeof v === "bigint") return String(v); + if (typeof v === "string") return sanitizeString(v); + if (ArrayBuffer.isView(v) && !globalThis.Buffer?.isBuffer?.(v)) { + // Typed array -> string + return String(v); + } + // Fallback stringify + try { + const json = JSON.stringify(v); + if (json === undefined) { + return sanitizeString(String(v)); + } + return sanitizeString(json); + } catch { + return sanitizeString(String(v)); + } + }; + + // Use shared needsCsvQuote(s, delimiter) from encoding utilities + + const serializeRow = (row: any[]): string => { + if (fmt === "csv") { + const parts = row.map(v => { + // Check for actual null/undefined before serializing + if (v === null || v === undefined) { + // Emit caller-provided NULL literal; quote if needed for CSV + return needsCsvQuote(nullToken, delimiter) ? pgCsvQuote(nullToken) : nullToken; + } + const s = serializeValue(v); + // Empty string should be quoted to distinguish from NULL + if (s === "") { + return pgCsvQuote(""); + } + return needsCsvQuote(s, delimiter) ? pgCsvQuote(s) : s; + }); + return parts.join(delimiter) + "\n"; + } else { + // text format: escape backslash, tab, LF, CR; null => \N + const parts = row.map(v => { + if (v === null || v === undefined) return nullToken; + const serialized = serializeValue(v); + return copyTextEscape(serialized); + }); + return parts.join(delimiter) + "\n"; + } + }; + + // TYPE_OID and TYPE_ARRAY_OID are now imported from postgres-encoding + + const feedData = async () => { + // Batch size for accumulating small chunks (configurable) + const BATCH_SIZE = + options && typeof options.batchSize === "number" && options.batchSize > 0 + ? (options.batchSize as number) + : DEFAULT_COPY_BATCH_SIZE; + let batch = ""; + + // Resolve limits once at start (avoid repeated option resolution inside loops) + const resolvedLimits = resolveCopyFromLimits(options, pool); + const resolvedMaxBytes = resolvedLimits.maxBytes; + + // Binary COPY support using shared encoding utilities + let binaryHeaderSent = false; + const sendBinaryHeader = () => { + if (binaryHeaderSent) return; + reserved.copySendData(createBinaryCopyHeader()); + binaryHeaderSent = true; + }; + const sendBinaryTrailer = () => { + // Only emit the binary envelope for the automatic encoder path. + // For raw binary chunk streams, the caller is responsible for providing a correct envelope. + if (!shouldAutoEmitBinaryEnvelope) return; + + // Always emit a valid trailer. If the header was never sent (e.g. empty iterable), + // send it now so the stream is still a valid PostgreSQL COPY BINARY payload. + if (!binaryHeaderSent) { + sendBinaryHeader(); + } + reserved.copySendData(createBinaryCopyTrailer()); + }; + + const shouldAutoEmitBinaryEnvelope = isCopyFromBinaryOptions(options); + if (shouldAutoEmitBinaryEnvelope) { + sendBinaryHeader(); + } + + const flushBatch = async () => { + if (batch.length > 0) { + // Enforce maxBytes and update progress before sending this batch + const bLen = getByteLength(batch); + + if (resolvedMaxBytes && counters.bytesSent + bLen > resolvedMaxBytes) { + throw new Error("copyFrom: maxBytes exceeded"); + } + + await sendChunkedData(batch, reserved, pool, resolvedLimits, counters, notifyProgress); + batch = ""; + } + }; + + const addToBatch = async (chunk: string) => { + batch += chunk; + if (batch.length >= BATCH_SIZE) { + await flushBatch(); + } + }; + + // Send data depending on type + if (typeof data === "string") { + if (fmt === "binary") { + throw new Error( + 'copyFrom: string payloads are not allowed when format is "binary". Provide row arrays with options.binaryTypes for automatic encoding, or provide raw byte chunks (Uint8Array/ArrayBuffer) that already include the COPY BINARY envelope.', + ); + } + if (aborted) throw new Error("AbortError"); + const payload = sanitizeString(data); + await sendChunkedData(payload, reserved, pool, resolvedLimits, counters, notifyProgress); + sendBinaryTrailer(); + reserved.copyDone(); + return; + } + + const maybeIter = typeof data === "function" ? data() : data; + + let cachedBinaryTypes: ReturnType | undefined = undefined; + const getBinaryTypes = () => { + if (cachedBinaryTypes !== undefined) return cachedBinaryTypes; + cachedBinaryTypes = validateBinaryTypes(options, columns); + return cachedBinaryTypes; + }; + + // Async iterable (rows or raw string/Uint8Array chunks) + if (isAsyncIterable(maybeIter)) { + for await (const item of maybeIter as AsyncIterable) { + if (aborted) throw new Error("AbortError"); + if ($isArray(item)) { + if (fmt === "binary") { + const types = getBinaryTypes(); + await flushBatch(); + const payload = encodeBinaryRow(item, types); + await sendChunkedData(payload, reserved, pool, resolvedLimits, counters, notifyProgress); + } else { + // text/csv: treat as row[] + await addToBatch(serializeRow(item)); + } + } else if (typeof item === "string") { + if (fmt === "binary") { + throw $ERR_INVALID_ARG_VALUE( + "data", + item, + 'must be an array row or a byte source when format is "binary"', + ); + } + // raw string chunk + await addToBatch(sanitizeString(item)); + } else if (hasByteLength(item)) { + // raw bytes (Uint8Array or ArrayBuffer) - flush and send directly + await flushBatch(); + const view = toUint8ArrayView(item); + if (!view) { + throw $ERR_INVALID_ARG_VALUE("data", item, "must be a string, an array row, or a byte source"); + } + // For binary format, send raw bytes as-is; for text/csv, sanitize NUL bytes if requested + const src = fmt === "binary" ? view : sanitizeBytes(view); + await sendChunkedData(src, reserved, pool, resolvedLimits, counters, notifyProgress); + } else { + if (fmt === "binary") { + throw $ERR_INVALID_ARG_VALUE( + "data", + item, + 'must be an array row or a byte source when format is "binary"', + ); + } + // fallback: attempt to serialize as a row + await addToBatch(serializeRow(item)); + } + } + await flushBatch(); + sendBinaryTrailer(); + reserved.copyDone(); + return; + } + + // Raw byte buffers (Uint8Array/Buffer/ArrayBuffer) are iterable, so handle them before the generic iterable branch. + if (hasByteLength(maybeIter)) { + if (aborted) throw new Error("AbortError"); + await flushBatch(); + if (aborted) throw new Error("AbortError"); + const view = toUint8ArrayView(maybeIter); + if (!view) { + throw $ERR_INVALID_ARG_VALUE("data", maybeIter, "must be a string, an array row, or a byte source"); + } + const src = fmt === "binary" ? view : sanitizeBytes(view); + await sendChunkedData(src, reserved, pool, resolvedLimits, counters, notifyProgress); + sendBinaryTrailer(); + reserved.copyDone(); + return; + } + + // Sync iterable (rows or raw string/Uint8Array chunks) + if (isIterable(maybeIter)) { + for (const item of maybeIter as Iterable) { + if (aborted) throw new Error("AbortError"); + if ($isArray(item)) { + if (fmt === "binary") { + const types = getBinaryTypes(); + if (aborted) throw new Error("AbortError"); + await flushBatch(); + if (aborted) throw new Error("AbortError"); + const payload = encodeBinaryRow(item, types); + await sendChunkedData(payload, reserved, pool, resolvedLimits, counters, notifyProgress); + } else { + if (aborted) throw new Error("AbortError"); + await addToBatch(serializeRow(item)); + } + } else if (typeof item === "string") { + if (aborted) throw new Error("AbortError"); + if (fmt === "binary") { + throw $ERR_INVALID_ARG_VALUE( + "data", + item, + 'must be an array row or a byte source when format is "binary"', + ); + } + await addToBatch(sanitizeString(item)); + } else if (hasByteLength(item)) { + // raw bytes (Uint8Array or ArrayBuffer) - flush and send directly + if (aborted) throw new Error("AbortError"); + await flushBatch(); + if (aborted) throw new Error("AbortError"); + const view = toUint8ArrayView(item); + if (!view) { + throw $ERR_INVALID_ARG_VALUE("data", item, "must be a string, an array row, or a byte source"); + } + const src = fmt === "binary" ? view : sanitizeBytes(view); + await sendChunkedData(src, reserved, pool, resolvedLimits, counters, notifyProgress); + } else { + if (aborted) throw new Error("AbortError"); + if (fmt === "binary") { + throw $ERR_INVALID_ARG_VALUE( + "data", + item, + 'must be an array row or a byte source when format is "binary"', + ); + } + await addToBatch(serializeRow(item)); + } + } + await flushBatch(); + sendBinaryTrailer(); + reserved.copyDone(); + return; + } + + // Array of arrays + if ($isArray(data)) { + if (fmt === "binary") { + if (!isCopyFromBinaryOptions(options)) { + throw $ERR_INVALID_ARG_VALUE( + "options.binaryTypes", + undefined, + 'must be provided when format is "binary" and data is an array of rows', + ); + } + + const types = validateBinaryTypes(options, columns); + await flushBatch(); + + for (let i = 0; i < data.length; i++) { + const row = data[i]; + if (aborted) throw new Error("AbortError"); + if (!$isArray(row)) { + throw $ERR_INVALID_ARG_VALUE(`data[${i}]`, row, "must be an array"); + } + const payload = encodeBinaryRow(row, types); + await sendChunkedData(payload, reserved, pool, resolvedLimits, counters, notifyProgress); + } + + await flushBatch(); + sendBinaryTrailer(); + reserved.copyDone(); + return; + } + + for (const row of data) { + if (aborted) throw new Error("AbortError"); + await addToBatch(serializeRow(row)); + } + await flushBatch(); + reserved.copyDone(); + return; + } + + // Fallback: treat as string + if (aborted) throw new Error("AbortError"); + const fallback = sanitizeString(String(data ?? "")); + await sendChunkedData(fallback, reserved, pool, resolvedLimits, counters, notifyProgress); + sendBinaryTrailer(); + reserved.copyDone(); + }; + + try { + // Register one-shot onCopyStart to feed rows + if (typeof reserved.onCopyStart === "function") { + reserved.onCopyStart(() => { + // Properly handle errors during data feeding + feedData().catch(feedErr => { + try { + // Send CopyFail to server to abort the COPY operation + if (typeof reserved.copyFail === "function") { + reserved.copyFail(String(feedErr?.message || feedErr || "Error feeding data")); + } + } catch {} + }); + }); + } + + // Build and run COPY ... FROM STDIN + const cols = (columns ?? []).map(c => escapeIdentifier(String(c))).join(", "); + const tableName = escapeIdentifier(String(table)); + // If automatic binary encoding is requested, validate column OIDs match expected types + if (fmt === "binary" && isCopyFromBinaryOptions(options)) { + const typeTokens = options.binaryTypes; + if (typeTokens.length !== (columns?.length ?? typeTokens.length)) { + throw new Error("binaryTypes length must match number of columns for COPY FROM."); + } + // Fetch column OIDs in the provided order using array_position for stable ordering + const colNames = columns ?? []; + // Determine schema and relation name (unquoted) for OID validation + const rawTable = String(table).replaceAll('"', ""); + let schemaName: string | null = null; + let relName = rawTable; + const dotIndex = rawTable.indexOf("."); + if (dotIndex !== -1) { + schemaName = rawTable.slice(0, dotIndex); + relName = rawTable.slice(dotIndex + 1); + } + + // Fetch all columns and validate in JS according to the provided columns[] order + const q = ` + SELECT a.attname::text AS name, a.atttypid::oid AS oid + FROM pg_catalog.pg_attribute a + JOIN pg_catalog.pg_class c ON c.oid = a.attrelid + JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE c.relname = $1 + AND ($2::text IS NULL OR n.nspname = $2) + AND a.attnum > 0 AND NOT a.attisdropped + ORDER BY a.attnum + `; + const rows = await reserved.unsafe(q, [relName, schemaName]); + // Build expected OIDs for provided type tokens + const expectedOids: number[] = typeTokens.map(tok => { + if (tok.endsWith("[]")) { + const arrOid = TYPE_ARRAY_OID[tok]; + if (!arrOid) throw new Error(`Unsupported array type for validation: ${tok}`); + return arrOid; + } + // map varchar/bpchar to their OIDs, otherwise base TYPE_OID + const base = + TYPE_OID[tok] ?? (tok === "varchar" ? TYPE_OID.varchar : tok === "bpchar" ? TYPE_OID.bpchar : undefined); + if (!base && base !== 0) throw new Error(`Unsupported type for validation: ${tok}`); + // Column OID must be the base type OID when not array + return base!; + }); + if (!$isArray(rows) || rows.length === 0) { + throw new Error("Could not resolve column OIDs for validation."); + } + if ((colNames?.length ?? 0) > 0) { + const oidByName = new Map(); + for (const r of rows) { + if (typeof r?.name === "string" && typeof r?.oid === "number") { + oidByName.set(r.name, r.oid); + } + } + for (let i = 0; i < expectedOids.length; i++) { + const colName = String(colNames[i]); + const got = oidByName.get(colName); + const want = expectedOids[i]; + if (typeof got !== "number" || got !== want) { + throw new Error( + `COPY binaryTypes validation failed for column "${colName}": expected OID ${want}, got ${got}`, + ); + } + } + } else { + if (rows.length < expectedOids.length) { + throw new Error("Could not resolve column OIDs for validation."); + } + for (let i = 0; i < expectedOids.length; i++) { + const got = rows[i]?.oid; + const want = expectedOids[i]; + if (typeof got !== "number" || got !== want) { + throw new Error( + `COPY binaryTypes validation failed for column #${i + 1}: expected OID ${want}, got ${got}`, + ); + } + } + } + } + let sqlText = cols ? `COPY ${tableName} (${cols}) FROM STDIN` : `COPY ${tableName} FROM STDIN`; + if (fmt === "csv" || fmt === "text") { + const delimiterOption = + delimiter && String(delimiter).length === 1 ? `, DELIMITER '${String(delimiter).replaceAll("'", "''")}'` : ""; + const nullOption = options?.null != null ? `, NULL '${String(nullToken).replaceAll("'", "''")}'` : ""; + const formatOption = fmt === "csv" ? "CSV" : "TEXT"; + sqlText += ` (FORMAT ${formatOption}${delimiterOption}${nullOption})`; + } else if (fmt === "binary") { + sqlText += ` (FORMAT BINARY)`; + } + + // Handle AbortSignal: if aborted before issuing query + if (aborted) throw new Error("AbortError"); + + // Apply COPY FROM timeout default (if provided) before issuing the command + try { + const __defaults__ = + (reserved && typeof reserved.getCopyDefaults === "function" ? reserved.getCopyDefaults() : undefined) || + (pool && typeof pool.getCopyDefaults === "function" ? pool.getCopyDefaults() : undefined) || + undefined; + + const __fromDefaults__ = (__defaults__ && __defaults__.from) || { + maxChunkSize: DEFAULT_COPY_MAX_CHUNK_SIZE, + maxBytes: 0, + timeout: 0, + }; + const timeout = + options && typeof options.timeout === "number" && options.timeout >= 0 + ? Math.max(0, Math.trunc(options.timeout)) + : Math.max(0, Math.trunc(__fromDefaults__.timeout ?? 0)); + if (typeof reserved.setCopyTimeout === "function") { + try { + reserved.setCopyTimeout(timeout); + } catch {} + } + } catch {} + + const result = await reserved.unsafe(sqlText); + await closeReserved(); + return result; + } catch (err) { + // Ensure we send CopyFail if we haven't already + try { + if (typeof reserved.copyFail === "function") { + reserved.copyFail(String(err?.message || err || "COPY operation failed")); + } + } catch {} + await closeReserved(); + throw err; + } finally { + // detach abort listener + if (options?.signal) { + options.signal.removeEventListener("abort", onAbort); + } + } + }; + + // Streaming COPY TO STDOUT helper: + // Usage: + // for await (const chunk of sql.copyTo(`COPY (SELECT ...) TO STDOUT`)) { + // // chunk is string for text format, ArrayBuffer for binary + // } + // or pass table/columns/options: + // for await (const chunk of sql.copyTo({ + // table: "t", + // columns: ["a","b"], + // format: "csv", + // signal?: AbortSignal, + // onProgress?: (info: { bytesReceived: number; chunksReceived: number }) => void, + // })) { ... } + sql.copyTo = function (this: Bun.SQL, queryOrOptions: string | CopyToOptions): AsyncIterable { + ensurePostgresAdapter(connectionInfo.adapter, "COPY"); + + const self = this; + const makeQuery = () => { + if (typeof queryOrOptions === "string") { + return queryOrOptions; + } + + validateObject(queryOrOptions, "queryOrOptions"); + + const table = queryOrOptions.table; + if ((typeof table !== "string" && typeof table !== "symbol") || String(table).length === 0) { + throw $ERR_INVALID_ARG_VALUE("queryOrOptions.table", table, "must be a non-empty string or symbol"); + } + + const format = queryOrOptions.format; + if (format !== undefined) { + validateOneOf(format, "queryOrOptions.format", ["text", "csv", "binary"]); + } + + const columns = queryOrOptions.columns; + if (columns !== undefined && !$isArray(columns)) { + throw $ERR_INVALID_ARG_VALUE("queryOrOptions.columns", columns, "must be an array of strings"); + } + if ($isArray(columns)) { + for (let i = 0; i < columns.length; i++) { + const column = columns[i]; + if (typeof column !== "string" || column.length === 0) { + throw $ERR_INVALID_ARG_VALUE(`queryOrOptions.columns[${i}]`, column, "must be a non-empty string"); + } + } + } + + // Use adapter's escapeIdentifier to handle schema-qualified names correctly + const escapeIdentifier = pool.escapeIdentifier + ? pool.escapeIdentifier.bind(pool) + : (str: string) => '"' + String(str).replaceAll('"', '""').replaceAll(".", '"."') + '"'; + + const tableName = escapeIdentifier(String(table)); + const list = $isArray(columns) ? columns.map(c => escapeIdentifier(String(c))).join(", ") : ""; + const fmt = format === "csv" ? " (FORMAT CSV)" : format === "binary" ? " (FORMAT BINARY)" : ""; + return `COPY ${tableName}${list ? ` (${list})` : ""} TO STDOUT${fmt}`; + }; + + return { + async *[Symbol.asyncIterator](): AsyncIterator { + const reserved = await self.reserve(); + const chunks: any[] = []; + let done = false; + let rejectErr: any = null; + + // Progress and abort state + let bytesReceived = 0; + let chunksReceived = 0; + const notifyProgress = () => { + try { + if (typeof queryOrOptions !== "string") { + queryOrOptions.onProgress?.({ bytesReceived, chunksReceived }); + } + } catch {} + }; + let aborted = false; + const signal = typeof queryOrOptions === "string" ? undefined : queryOrOptions.signal; + const onAbort = () => { + aborted = true; + if (chunkResolve) { + chunkResolve(); + chunkResolve = null; + } + }; + if (signal) { + if (signal.aborted) onAbort(); + signal.addEventListener("abort", onAbort, { once: true }); + } + + let chunkResolve: (() => void) | null = null; + + const waitForChunk = () => + new Promise(r => { + chunkResolve = r; + }); + + let hasCopyChunkHandler = false; + let hasCopyEndHandler = false; + + // Register streaming handlers (wrapper functions always exist; detect real registration via return value) + if (typeof reserved.onCopyChunk === "function") { + hasCopyChunkHandler = + reserved.onCopyChunk((chunk: any) => { + chunks.push(chunk); + if (chunkResolve) { + chunkResolve(); + chunkResolve = null; + } + try { + // Update progress + if (chunk instanceof ArrayBuffer) { + bytesReceived += chunk.byteLength; + } else if (typeof chunk === "string") { + bytesReceived += Buffer?.byteLength + ? Buffer.byteLength(chunk, "utf8") + : new TextEncoder().encode(chunk).byteLength; + } else if (chunk?.byteLength != null) { + bytesReceived += chunk.byteLength; + } + chunksReceived += 1; + notifyProgress(); + // Guardrail: maxBytes + const toMax = resolveCopyToMaxBytes(queryOrOptions, pool); + if (toMax > 0 && bytesReceived > toMax) { + rejectErr = new Error("copyTo: maxBytes exceeded"); + done = true; + // Immediately release connection to halt incoming data + try { + reserved.release(); + } catch {} + } + } catch {} + }) === true; + } + + if (typeof reserved.onCopyEnd === "function") { + hasCopyEndHandler = + reserved.onCopyEnd(() => { + done = true; + if (chunkResolve) { + chunkResolve(); + chunkResolve = null; + } + }) === true; + } + + const toUint8Array = (value: unknown): Uint8Array | null => { + if (value instanceof Uint8Array) return value; + if (value instanceof ArrayBuffer) return new Uint8Array(value); + if (ArrayBuffer.isView(value)) { + const view = value as ArrayBufferView; + return new Uint8Array(view.buffer, view.byteOffset, view.byteLength); + } + return null; + }; + + const toRealArrayBuffer = (u8: Uint8Array): ArrayBuffer => { + const buffer = u8.buffer; + if (buffer instanceof ArrayBuffer && u8.byteOffset === 0 && u8.byteLength === buffer.byteLength) { + return buffer; + } + return u8.slice().buffer; + }; + + const joinUint8Arrays = (parts: Uint8Array[]): ArrayBuffer => { + let total = 0; + for (let i = 0; i < parts.length; i++) total += parts[i].byteLength; + const out = new Uint8Array(total); + let offset = 0; + for (let i = 0; i < parts.length; i++) { + out.set(parts[i], offset); + offset += parts[i].byteLength; + } + return toRealArrayBuffer(out); + }; + + const yieldAccumulated = async function* ( + accumulated: unknown, + isBinary: boolean, + format: string | undefined, + ): AsyncGenerator { + if (isBinary) { + if ($isArray(accumulated)) { + const parts: Uint8Array[] = []; + for (let i = 0; i < accumulated.length; i++) { + const u8 = toUint8Array(accumulated[i]); + if (!u8) { + throw $ERR_INVALID_ARG_VALUE( + "format", + format, + 'COPY TO returned non-binary data while format is "binary"', + ); + } + parts.push(u8); + } + yield joinUint8Arrays(parts); + return; + } + + const value = accumulated ?? null; + const u8 = toUint8Array(value); + if (!u8) { + throw $ERR_INVALID_ARG_VALUE( + "format", + format, + 'COPY TO returned non-binary data while format is "binary"', + ); + } + yield toRealArrayBuffer(u8); + return; + } + + if ($isArray(accumulated)) { + yield accumulated.map(x => String(x ?? "")).join(""); + return; + } + + yield String(accumulated ?? ""); + }; + + try { + if (aborted) throw new Error("AbortError"); + + // Determine whether streaming was requested for COPY TO. + const __defaults__ = + (reserved && typeof reserved.getCopyDefaults === "function" ? reserved.getCopyDefaults() : undefined) || + (pool && typeof pool.getCopyDefaults === "function" ? pool.getCopyDefaults() : undefined) || + undefined; + + const __toDefaults__ = (__defaults__ && __defaults__.to) || { stream: true, maxBytes: 0, timeout: 0 }; + const desiredStream = + typeof queryOrOptions === "string" + ? __toDefaults__.stream + : queryOrOptions.stream !== undefined + ? !!queryOrOptions.stream + : __toDefaults__.stream; + + const timeout = + typeof queryOrOptions === "string" + ? (__toDefaults__.timeout ?? 0) + : queryOrOptions.timeout !== undefined + ? (() => { + const value = queryOrOptions.timeout; + if (typeof value !== "number" || !Number.isFinite(value) || value < 0) { + throw $ERR_INVALID_ARG_VALUE( + "queryOrOptions.timeout", + value, + "must be a finite non-negative number", + ); + } + return Math.trunc(value); + })() + : (__toDefaults__.timeout ?? 0); + + // Tightened semantics: + // - If streaming is requested but we do not have an onCopyChunk handler, we force accumulation + // and yield exactly one chunk (the accumulated payload). + if (desiredStream && !hasCopyChunkHandler) { + if (typeof reserved.setCopyTimeout === "function") { + try { + reserved.setCopyTimeout(timeout); + } catch {} + } + + if (typeof reserved.setCopyStreamingMode === "function") { + try { + reserved.setCopyStreamingMode(false); + } catch {} + } + + const q = makeQuery(); + const accumulated = await reserved.unsafe(q); + + const format = typeof queryOrOptions === "string" ? undefined : queryOrOptions.format; + const isBinary = format === "binary"; + + for await (const part of yieldAccumulated(accumulated, isBinary, format)) { + yield part; + } + + done = true; + } else if (!desiredStream) { + if (typeof reserved.setCopyTimeout === "function") { + try { + reserved.setCopyTimeout(timeout); + } catch {} + } + + if (typeof reserved.setCopyStreamingMode === "function") { + try { + reserved.setCopyStreamingMode(false); + } catch {} + } + + const q = makeQuery(); + const accumulated = await reserved.unsafe(q); + + const format = typeof queryOrOptions === "string" ? undefined : queryOrOptions.format; + const isBinary = format === "binary"; + + for await (const part of yieldAccumulated(accumulated, isBinary, format)) { + yield part; + } + + done = true; + } else { + // Enable streaming mode to avoid accumulation in Zig during COPY TO. + if (typeof reserved.setCopyStreamingMode === "function") { + try { + if (typeof reserved.setCopyTimeout === "function") { + try { + reserved.setCopyTimeout(timeout); + } catch {} + } + + reserved.setCopyStreamingMode(!!desiredStream); + } catch {} + } + + // Start COPY TO STDOUT + const q = makeQuery(); + await reserved.unsafe(q); + + // Drain chunks as they arrive; finish when done flag is set + while (!done || chunks.length > 0) { + if (aborted) { + // Stop consumption early; close the reserved connection to abort server-side + rejectErr = new Error("AbortError"); + break; + } + if (chunks.length === 0) { + // yield to event loop + await waitForChunk(); + continue; + } + const next = chunks.shift(); + if (next instanceof Uint8Array) { + // Normalize Uint8Array view to an ArrayBuffer containing only the view's bytes + const buffer = next.buffer.slice(next.byteOffset, next.byteOffset + next.byteLength); + yield buffer; + } else { + yield next; + } + } + } + } catch (e) { + rejectErr = e; + } finally { + try { + if (typeof reserved.setCopyStreamingMode === "function") { + try { + reserved.setCopyStreamingMode(false); + } catch {} + } + await reserved.release(); + } catch {} + if (signal) { + signal.removeEventListener("abort", onAbort); + } + } + + if (rejectErr) { + throw rejectErr; + } + }, + }; + }; + + // Helper to pipe COPY TO stream directly into a WritableStream or stream-like sink + // Usage: + // await sql.copyToPipeTo({ table: "t", format: "binary" }, writable) + // Where writable is a Web WritableStream or an object with write(), close()/end() + sql.copyToPipeTo = async function ( + this: Bun.SQL, + queryOrOptions: string | CopyToOptions, + writable: + | WritableStream + | { + write: (chunk: string | ArrayBuffer | Uint8Array) => unknown | Promise; + close?: () => unknown | Promise; + end?: () => unknown | Promise; + }, + ) { + ensurePostgresAdapter(connectionInfo.adapter, "COPY"); + + const isWritable = isWritableStream(writable); + const isStreamLike = isWritableSink(writable); + + if (!isWritable && !isStreamLike) { + throw $ERR_INVALID_ARG_VALUE("writable", writable, "must be a WritableStream or an object with a write() method"); + } + + const iterable = this.copyTo(queryOrOptions); + + // Web WritableStream path + if (isWritable) { + const writer = writable.getWriter(); + try { + for await (const chunk of iterable) { + // Normalize ArrayBuffer to Uint8Array for WritableStream + if (chunk instanceof ArrayBuffer) { + await writer.write(new Uint8Array(chunk)); + } else { + await writer.write(chunk); + } + } + await writer.close(); + } catch (e) { + try { + await writer.close(); + } catch {} + throw e; + } + return; + } + + // Generic stream-like sink with write()/close() or end() + if (isStreamLike) { + for await (const chunk of iterable) { + await writable.write(chunk); + } + if (typeof writable.close === "function") { + await writable.close(); + } else if (typeof writable.end === "function") { + await writable.end(); + } + return; + } + + throw $ERR_INVALID_ARG_VALUE("writable", writable, "must be a WritableStream or an object with a write() method"); + }; + sql.rollbackDistributed = async function (name: string) { if (pool.closed) { throw pool.connectionClosedError(); } if (!pool.getRollbackDistributedSQL) { - throw Error(`This adapter doesn't support distributed transactions.`); + throw new Error(`This adapter doesn't support distributed transactions.`); } const sqlQuery = pool.getRollbackDistributedSQL(name); @@ -854,7 +2362,7 @@ const SQL: typeof Bun.SQL = function SQL( } if (!pool.getCommitDistributedSQL) { - throw Error(`This adapter doesn't support distributed transactions.`); + throw new Error(`This adapter doesn't support distributed transactions.`); } const sqlQuery = pool.getCommitDistributedSQL(name); @@ -936,6 +2444,18 @@ const SQL: typeof Bun.SQL = function SQL( sql.transaction = sql.begin; sql.distributed = sql.beginDistributed; sql.end = sql.close; + // Expose adapter-level COPY defaults on SQL instance (only when supported by adapter) + if (pool && typeof pool.getCopyDefaults === "function" && typeof pool.setCopyDefaults === "function") { + sql.getCopyDefaults = () => pool.getCopyDefaults(); + sql.setCopyDefaults = (defaults: { + from?: { maxChunkSize?: number; maxBytes?: number; timeout?: number }; + to?: { stream?: boolean; maxBytes?: number; timeout?: number }; + }) => { + pool.setCopyDefaults(defaults); + return sql; + }; + } + return sql; }; diff --git a/src/js/internal/sql/postgres-encoding.ts b/src/js/internal/sql/postgres-encoding.ts new file mode 100644 index 00000000000..c98cf100342 --- /dev/null +++ b/src/js/internal/sql/postgres-encoding.ts @@ -0,0 +1,466 @@ +/** + * Shared PostgreSQL encoding utilities for binary COPY and array serialization + */ + +// Import shared type constants from centralized module +const { TYPE_OID, TYPE_ARRAY_OID } = require("./postgres-types"); + +// Re-export for consumers +export { TYPE_OID, TYPE_ARRAY_OID }; + +// Binary encoding helpers +const encText = new TextEncoder(); + +export function be16(n: number): Uint8Array { + const b = new Uint8Array(2); + new DataView(b.buffer).setInt16(0, n, false); + return b; +} + +export function be32(n: number): Uint8Array { + const b = new Uint8Array(4); + new DataView(b.buffer).setInt32(0, n, false); + return b; +} + +export function be64(big: bigint): Uint8Array { + const b = new Uint8Array(8); + const dv = new DataView(b.buffer); + dv.setInt32(0, Number((big >> 32n) & 0xffffffffn), false); + dv.setUint32(4, Number(big & 0xffffffffn), false); + return b; +} + +// Escape functions for PostgreSQL text format +export function copyTextEscape(s: string): string { + // COPY text format escaping: backslash, tab, newline, carriage return + return s.replaceAll("\\", "\\\\").replaceAll("\t", "\\t").replaceAll("\n", "\\n").replaceAll("\r", "\\r"); +} + +export function arrayEscape(value: string): string { + // Array element escaping: backslash and double quotes + return value.replace(/\\/g, "\\\\").replace(/"/g, '\\"'); +} + +export function csvQuote(s: string): string { + return `"${s.replaceAll('"', '""')}"`; +} + +/** + * Determine if a CSV field requires quoting based on RFC-like rules: + * quote when it contains a double-quote, newline, carriage return, or the delimiter. + */ +export function needsCsvQuote(s: string, delimiter: string = ","): boolean { + return s.includes('"') || s.includes("\n") || s.includes("\r") || s.includes(delimiter); +} + +// Numeric encoding for PostgreSQL binary format +function expandExponent(s: string): string { + const m = s.match(/^(-?)(\d+)(?:\.(\d+))?[eE]([+-]?\d+)$/); + if (!m) return s; + const sign = m[1] === "-" ? "-" : ""; + let intPart = m[2] || "0"; + let fracPart = m[3] || ""; + const exp = Math.trunc(Number(m[4])); + if (exp > 0) { + const needed = exp - fracPart.length; + if (needed >= 0) { + intPart = intPart + fracPart + "0".repeat(needed); + fracPart = ""; + } else { + intPart = intPart + fracPart.slice(0, exp); + fracPart = fracPart.slice(exp); + } + } else if (exp < 0) { + const zeros = "0".repeat(Math.max(0, -exp - intPart.length)); + const all = zeros ? zeros + intPart : intPart; + const idx = all.length + exp; + fracPart = all.slice(idx) + fracPart; + intPart = all.slice(0, idx) || "0"; + } + intPart = intPart.replace(/^0+/, "") || "0"; + return fracPart ? `${sign}${intPart}.${fracPart}` : `${sign}${intPart}`; +} + +export function encodeNumericBinary(val: any): Uint8Array { + let s = typeof val === "bigint" ? val.toString() : typeof val === "number" ? val.toString() : String(val); + s = s.trim(); + if (!/^-?(\d+)(\.\d+)?([eE][+-]?\d+)?$/.test(s)) { + throw new Error("numeric: value must be a plain decimal string/number"); + } + if (/[eE]/.test(s)) s = expandExponent(s); + let sign = 0x0000; + if (s.startsWith("-")) { + sign = 0x4000; + s = s.slice(1); + } else if (s.startsWith("+")) { + s = s.slice(1); + } + let intPart = s; + let fracPart = ""; + const dot = s.indexOf("."); + if (dot !== -1) { + intPart = s.slice(0, dot); + fracPart = s.slice(dot + 1); + } + intPart = intPart.replace(/^0+/, "") || "0"; + const padLeft = (4 - (intPart.length % 4)) % 4; + const intPadded = "0".repeat(padLeft) + intPart; + const intGroups: number[] = []; + for (let i = 0; i < intPadded.length; i += 4) { + intGroups.push(parseInt(intPadded.slice(i, i + 4), 10) || 0); + } + const dscale = fracPart.length; + const padRight = (4 - (fracPart.length % 4)) % 4; + const fracPadded = fracPart + "0".repeat(padRight); + const fracGroups: number[] = []; + for (let i = 0; i < fracPadded.length; i += 4) { + if (i < fracPart.length || padRight > 0) { + const g = fracPadded.slice(i, i + 4); + fracGroups.push(parseInt(g, 10) || 0); + } + } + while (intGroups.length > 0 && intGroups[0] === 0) intGroups.shift(); + let weight = intGroups.length - 1; + let digits = intGroups.concat(fracGroups); + while (digits.length > 0 && digits[digits.length - 1] === 0) digits.pop(); + if (digits.length === 0) { + const out = new Uint8Array(8); + const dv = new DataView(out.buffer); + dv.setInt16(0, 0, false); + dv.setInt16(2, 0, false); + dv.setInt16(4, 0x0000, false); + dv.setInt16(6, dscale | 0, false); + return out; + } + const ndigits = digits.length; + const out = new Uint8Array(8 + ndigits * 2); + const dv = new DataView(out.buffer); + dv.setInt16(0, ndigits, false); + dv.setInt16(2, weight, false); + dv.setInt16(4, sign, false); + dv.setInt16(6, dscale | 0, false); + let o = 8; + for (let i = 0; i < ndigits; i++) { + dv.setInt16(o, digits[i], false); + o += 2; + } + return out; +} + +export function encodeIntervalBinary(val: any): Uint8Array { + let months = 0, + days = 0; + let micros = 0n; + if (val && typeof val === "object") { + if ("months" in val) months = Number((val as any).months) | 0; + if ("days" in val) days = Number((val as any).days) | 0; + if ("micros" in val) micros = BigInt((val as any).micros); + else if ("ms" in val) micros = BigInt(Math.trunc((val as any).ms)) * 1000n; + else if ("seconds" in val) micros = BigInt(Math.trunc((val as any).seconds)) * 1_000_000n; + } else if (typeof val === "string") { + const m = val.match(/^(\d{1,2}):(\d{2}):(\d{2})(?:\.(\d{1,6}))?$/); + if (m) { + const hh = Number(m[1]) | 0, + mm = Number(m[2]) | 0, + ss = Number(m[3]) | 0; + const frac = (m[4] || "").padEnd(6, "0").slice(0, 6); + const us = Number(frac) | 0; + micros = BigInt((hh * 3600 + mm * 60 + ss) * 1_000_000 + us); + } else { + micros = 0n; + } + } else if (typeof val === "number") { + micros = BigInt(Math.trunc(val)) * 1000n; + } + const out = new Uint8Array(16); + const dv = new DataView(out.buffer); + dv.setInt32(0, Number((micros >> 32n) & 0xffffffffn), false); + dv.setUint32(4, Number(micros & 0xffffffffn), false); + dv.setInt32(8, days, false); + dv.setInt32(12, months, false); + return out; +} + +export type CopyBinaryBaseType = + | "bool" + | "int2" + | "int4" + | "int8" + | "float4" + | "float8" + | "text" + | "varchar" + | "bpchar" + | "bytea" + | "date" + | "time" + | "timestamp" + | "timestamptz" + | "uuid" + | "json" + | "jsonb" + | "numeric" + | "interval"; + +export type CopyBinaryArrayType = `${CopyBinaryBaseType}[]`; +export type CopyBinaryType = CopyBinaryBaseType | CopyBinaryArrayType; + +/** + * Encode a single value in PostgreSQL binary format for COPY + */ +export function encodeBinaryValue(v: unknown, t: CopyBinaryType): Uint8Array { + // Handle arrays like "int4[]" + if (t.endsWith("[]")) { + const base = t.slice(0, -2) as CopyBinaryBaseType; + if (!$isArray(v)) throw new Error("binary array expects a JavaScript array value"); + return encodeArray1D(v, base); + } + switch (t) { + case "bool": { + const out = new Uint8Array(1); + out[0] = v ? 1 : 0; + return out; + } + case "int2": { + const b = new Uint8Array(2); + new DataView(b.buffer).setInt16(0, Number(v) | 0, false); + return b; + } + case "int4": { + const b = new Uint8Array(4); + new DataView(b.buffer).setInt32(0, Number(v) | 0, false); + return b; + } + case "int8": { + const b = new Uint8Array(8); + const dv = new DataView(b.buffer); + const big = BigInt(v as string | number | bigint | boolean); + dv.setInt32(0, Number((big >> 32n) & 0xffffffffn), false); + dv.setUint32(4, Number(big & 0xffffffffn), false); + return b; + } + case "float4": { + const b = new Uint8Array(4); + new DataView(b.buffer).setFloat32(0, Number(v), false); + return b; + } + case "float8": { + const b = new Uint8Array(8); + new DataView(b.buffer).setFloat64(0, Number(v), false); + return b; + } + case "bytea": { + if (v instanceof Uint8Array) return v; + if (v && (v as any).byteLength !== undefined) return new Uint8Array(v as ArrayBuffer); + const s = typeof v === "string" ? v : v == null ? "" : String(v); + return encText.encode(s); + } + case "date": { + // int32 days since 2000-01-01 + const epoch2000 = Date.UTC(2000, 0, 1); + let ms: number; + if (v instanceof Date) ms = v.getTime(); + else if (typeof v === "number") ms = v; + else ms = new Date(String(v)).getTime(); + const days = Math.floor((ms - epoch2000) / 86400000); + const b = new Uint8Array(4); + new DataView(b.buffer).setInt32(0, days, false); + return b; + } + case "time": { + // int64 microseconds since midnight + const toMicros = (val: any): bigint => { + if (typeof val === "number") return BigInt(Math.floor(val)); + if (val instanceof Date) { + const h = val.getUTCHours(); + const m = val.getUTCMinutes(); + const s = val.getUTCSeconds(); + const ms = val.getUTCMilliseconds(); + return BigInt(((h * 3600 + m * 60 + s) * 1000 + ms) * 1000); + } + const str = String(val); + const m = str.match(/^(\d{1,2}):(\d{2}):(\d{2})(?:\.(\d{1,6}))?$/); + if (!m) return 0n; + const hh = Number(m[1]) | 0; + const mm = Number(m[2]) | 0; + const ss = Number(m[3]) | 0; + const frac = (m[4] || "").padEnd(6, "0").slice(0, 6); + const us = Number(frac) | 0; + return BigInt((hh * 3600 + mm * 60 + ss) * 1_000_000 + us); + }; + const micros = toMicros(v); + const b = new Uint8Array(8); + const dv = new DataView(b.buffer); + dv.setInt32(0, Number((micros >> 32n) & 0xffffffffn), false); + dv.setUint32(4, Number(micros & 0xffffffffn), false); + return b; + } + case "timestamp": + case "timestamptz": { + // int64 microseconds since 2000-01-01 UTC + const epoch2000 = Date.UTC(2000, 0, 1); + let ms: number; + if (v instanceof Date) ms = v.getTime(); + else if (typeof v === "number") ms = v; + else ms = new Date(String(v)).getTime(); + const micros = BigInt(Math.round((ms - epoch2000) * 1000)); + const b = new Uint8Array(8); + const dv = new DataView(b.buffer); + dv.setInt32(0, Number((micros >> 32n) & 0xffffffffn), false); + dv.setUint32(4, Number(micros & 0xffffffffn), false); + return b; + } + case "uuid": { + // 16 bytes + const s = String(v).toLowerCase(); + const hex = s.replace(/-/g, ""); + const out = new Uint8Array(16); + for (let i = 0; i < 16; i++) { + const byte = hex.slice(i * 2, i * 2 + 2); + out[i] = parseInt(byte, 16) || 0; + } + return out; + } + case "json": { + const s = typeof v === "string" ? v : JSON.stringify(v ?? null); + return encText.encode(s); + } + case "jsonb": { + const s = typeof v === "string" ? v : JSON.stringify(v ?? null); + const txt = encText.encode(s); + // version 1 + textual json + const out = new Uint8Array(1 + txt.length); + out[0] = 1; + out.set(txt, 1); + return out; + } + case "numeric": { + return encodeNumericBinary(v); + } + case "interval": { + return encodeIntervalBinary(v); + } + case "varchar": + case "bpchar": + case "text": + default: { + // default to text encoding for unknown types + const s = typeof v === "string" ? v : v == null ? "" : String(v); + return encText.encode(s); + } + } +} + +/** + * Encode a 1-dimensional array in PostgreSQL binary format + */ +export function encodeArray1D(arr: unknown[], elemType: CopyBinaryBaseType): Uint8Array { + const oid = TYPE_OID[elemType]; + if (!oid) throw new Error(`Unsupported array base type for binary encoding: ${elemType}`); + const n = arr.length; + let hasNull = 0; + const elems: Uint8Array[] = new Array(n); + for (let i = 0; i < n; i++) { + const v = arr[i]; + if (v === null || v === undefined) { + elems[i] = new Uint8Array(0); + hasNull = 1; + } else { + elems[i] = encodeBinaryValue(v, elemType); + } + } + let size = 4 * 3 + 8; // ndim, hasnull, oid, dim length + lbound + for (let i = 0; i < n; i++) { + size += 4 + (elems[i].length || 0); + } + const out = new Uint8Array(size); + const dv = new DataView(out.buffer); + let o = 0; + dv.setInt32(o, 1, false); // ndim + o += 4; + dv.setInt32(o, hasNull, false); + o += 4; + dv.setInt32(o, oid, false); + o += 4; + dv.setInt32(o, n, false); // length + o += 4; + dv.setInt32(o, 1, false); // lbound + o += 4; + for (let i = 0; i < n; i++) { + if (arr[i] === null || arr[i] === undefined) { + dv.setInt32(o, -1, false); + o += 4; + } else { + const b = elems[i]; + dv.setInt32(o, b.length, false); + o += 4; + out.set(b, o); + o += b.length; + } + } + return out; +} + +/** + * Encode a binary COPY row with the given types + */ +export function encodeBinaryRow(row: any[], types: CopyBinaryType[]): Uint8Array { + const fieldCount = types.length; + // First pass: compute total size + let size = 2; // int16 field count + const vals: Uint8Array[] = new Array(fieldCount); + for (let i = 0; i < fieldCount; i++) { + const val = row[i]; + if (val === null || val === undefined) { + size += 4; // -1 length + vals[i] = new Uint8Array(0); + continue; + } + const t = types[i]; + const bytes = encodeBinaryValue(val, t); + vals[i] = bytes; + size += 4 + bytes.length; + } + const out = new Uint8Array(size); + const dv = new DataView(out.buffer); + let o = 0; + dv.setInt16(o, fieldCount, false); + o += 2; + for (let i = 0; i < fieldCount; i++) { + const v = row[i]; + if (v === null || v === undefined) { + dv.setInt32(o, -1, false); + o += 4; + continue; + } + const bytes = vals[i]; + dv.setInt32(o, bytes.length, false); + o += 4; + out.set(bytes, o); + o += bytes.length; + } + return out; +} + +/** + * Create binary COPY header + */ +export function createBinaryCopyHeader(): Uint8Array { + const sig = new Uint8Array([0x50, 0x47, 0x43, 0x4f, 0x50, 0x59, 0x0a, 0xff, 0x0d, 0x0a, 0x00]); + const flags = new Uint8Array(4); // 0 + const extlen = new Uint8Array(4); // 0 + const out = new Uint8Array(sig.length + flags.length + extlen.length); + out.set(sig, 0); + out.set(flags, sig.length); + out.set(extlen, sig.length + flags.length); + return out; +} + +/** + * Create binary COPY trailer + */ +export function createBinaryCopyTrailer(): Uint8Array { + // int16 -1 (0xFFFF) big-endian + return new Uint8Array([0xff, 0xff]); +} diff --git a/src/js/internal/sql/postgres-types.ts b/src/js/internal/sql/postgres-types.ts new file mode 100644 index 00000000000..8d5b15d64e5 --- /dev/null +++ b/src/js/internal/sql/postgres-types.ts @@ -0,0 +1,240 @@ +/** + * Shared PostgreSQL type constants and utilities. + * + * This module provides a single source of truth for PostgreSQL type OIDs, + * used by both regular SQL array serialization and COPY binary protocol. + */ + +/** + * PostgreSQL base type OIDs (type name -> OID) + * Used for binary COPY format encoding + */ +export const BASE_TYPE_OID: Record = { + bool: 16, + int2: 21, + int4: 23, + int8: 20, + float4: 700, + float8: 701, + text: 25, + varchar: 1043, + bpchar: 1042, + bytea: 17, + date: 1082, + time: 1083, + timestamp: 1114, + timestamptz: 1184, + uuid: 2950, + json: 114, + jsonb: 3802, + numeric: 1700, + interval: 1186, +}; + +/** + * PostgreSQL array type OIDs (array type token -> OID) + * Used for binary COPY format array encoding + */ +export const ARRAY_TYPE_OID: Record = { + // Boolean + "bool[]": 1000, + + // Binary + "bytea[]": 1001, + + // Character types + "char[]": 1002, + "name[]": 1003, + "text[]": 1009, + "bpchar[]": 1014, + "varchar[]": 1015, + + // Numeric types + "int2[]": 1005, + "int4[]": 1007, + "int8[]": 1016, + "float4[]": 1021, + "float8[]": 1022, + "numeric[]": 1231, + + // Date/Time types + "date[]": 1182, + "time[]": 1183, + "timestamp[]": 1115, + "timestamptz[]": 1185, + "interval[]": 1187, + + // Other types + "uuid[]": 2951, + "json[]": 199, + "jsonb[]": 3807, +}; + +/** + * PostgreSQL array OID to type name mapping (OID -> type name) + * Used for decoding array types from PostgreSQL responses + */ +export const ARRAY_OID_TO_TYPE: Record = { + // Boolean + 1000: "BOOLEAN", + + // Binary + 1001: "BYTEA", + + // Character types + 1002: "CHAR", + 1003: "NAME", + 1009: "TEXT", + 1014: "CHAR", + 1015: "VARCHAR", + + // Numeric types + 1005: "SMALLINT", + 1006: "INT2VECTOR", + 1007: "INTEGER", + 1016: "BIGINT", + 1021: "REAL", + 1022: "DOUBLE PRECISION", + 1231: "NUMERIC", + 791: "MONEY", + + // OID types + 1028: "OID", + 1010: "TID", + 1011: "XID", + 1012: "CID", + + // JSON types + 199: "JSON", + 3802: "JSONB", + 3807: "JSONB", + 4072: "JSONPATH", + 4073: "JSONPATH", + + // XML + 143: "XML", + + // Geometric types + 1017: "POINT", + 1018: "LSEG", + 1019: "PATH", + 1020: "BOX", + 1027: "POLYGON", + 629: "LINE", + 719: "CIRCLE", + + // Network types + 651: "CIDR", + 1040: "MACADDR", + 1041: "INET", + 775: "MACADDR8", + 2951: "UUID", + + // Date/Time types + 1182: "DATE", + 1183: "TIME", + 1115: "TIMESTAMP", + 1185: "TIMESTAMPTZ", + 1187: "INTERVAL", + 1270: "TIMETZ", + + // Bit string types + 1561: "BIT", + 1563: "VARBIT", + + // ACL + 1034: "ACLITEM", + + // System catalog types + 12052: "PG_DATABASE", + 10052: "PG_DATABASE", +}; + +/** + * Check if a PostgreSQL type name is a numeric type + */ +export function isNumericType(type: string): boolean { + switch (type) { + case "BIT": + case "VARBIT": + case "SMALLINT": + case "INT2VECTOR": + case "INTEGER": + case "INT": + case "BIGINT": + case "REAL": + case "DOUBLE PRECISION": + case "NUMERIC": + case "MONEY": + return true; + default: + return false; + } +} + +/** + * Check if a PostgreSQL type name is a JSON type + */ +export function isJsonType(type: string): boolean { + switch (type) { + case "JSON": + case "JSONB": + return true; + default: + return false; + } +} + +/** + * Get array type name from OID, returns null if not found + */ +export function getArrayTypeName(oid: number): string | null { + return ARRAY_OID_TO_TYPE[oid] ?? null; +} + +/** + * Get base type OID from type name, returns undefined if not found + */ +export function getBaseTypeOid(typeName: string): number | undefined { + return BASE_TYPE_OID[typeName]; +} + +/** + * Get array type OID from array type token (e.g., "int4[]"), returns undefined if not found + */ +export function getArrayTypeOid(typeToken: string): number | undefined { + return ARRAY_TYPE_OID[typeToken]; +} + +/** + * Check if a type token is a supported base type for binary encoding + */ +export function isSupportedBaseType(token: string): boolean { + return Object.hasOwn(BASE_TYPE_OID, token); +} + +/** + * Check if a type token is a supported array type for binary encoding + */ +export function isSupportedArrayType(token: string): boolean { + return Object.hasOwn(ARRAY_TYPE_OID, token); +} + +/** + * Get list of supported base type names + */ +export function getSupportedBaseTypes(): string[] { + return Object.keys(BASE_TYPE_OID).sort(); +} + +/** + * Get list of supported array type tokens + */ +export function getSupportedArrayTypes(): string[] { + return Object.keys(ARRAY_TYPE_OID).sort(); +} + +// Type aliases for backwards compatibility +export const TYPE_OID = BASE_TYPE_OID; +export const TYPE_ARRAY_OID = ARRAY_TYPE_OID; +export const POSTGRES_ARRAY_TYPES = ARRAY_OID_TO_TYPE; diff --git a/src/js/internal/sql/postgres.ts b/src/js/internal/sql/postgres.ts index af4502cd9ad..bb29a40db82 100644 --- a/src/js/internal/sql/postgres.ts +++ b/src/js/internal/sql/postgres.ts @@ -20,123 +20,36 @@ function isTypedArray(value: any) { } const { PostgresError } = require("internal/sql/errors"); +const { arrayEscape } = require("internal/sql/postgres-encoding"); +const { + POSTGRES_ARRAY_TYPES, + isNumericType: isPostgresNumericType, + isJsonType: isPostgresJsonType, +} = require("internal/sql/postgres-types"); const { createConnection: createPostgresConnection, createQuery: createPostgresQuery, init: initPostgres, + sendCopyData, + sendCopyDone, + sendCopyFail, + awaitWritable, + setCopyStreamingMode, + setCopyChunkHandlerRegistered, + setCopyTimeout, + setMaxCopyBufferSize, + setMaxCopyBufferSizeUnsafe, } = $zig("postgres.zig", "createBinding") as PostgresDotZig; -const cmds = ["", "INSERT", "DELETE", "UPDATE", "MERGE", "SELECT", "MOVE", "FETCH", "COPY"]; - -const escapeBackslash = /\\/g; -const escapeQuote = /"/g; +const copyStartHandlers = new WeakMap<$ZigGeneratedClasses.PostgresSQLConnection, () => void>(); +const copyChunkHandlers = new WeakMap<$ZigGeneratedClasses.PostgresSQLConnection, (chunk: any) => void>(); +const copyEndHandlers = new WeakMap<$ZigGeneratedClasses.PostgresSQLConnection, () => void>(); +const writableHandlers = new WeakMap<$ZigGeneratedClasses.PostgresSQLConnection, () => void>(); -function arrayEscape(value: string) { - return value.replace(escapeBackslash, "\\\\").replace(escapeQuote, '\\"'); -} -const POSTGRES_ARRAY_TYPES = { - // Boolean - 1000: "BOOLEAN", // bool_array - - // Binary - 1001: "BYTEA", // bytea_array - - // Character types - 1002: "CHAR", // char_array - 1003: "NAME", // name_array - 1009: "TEXT", // text_array - 1014: "CHAR", // bpchar_array - 1015: "VARCHAR", // varchar_array - - // Numeric types - 1005: "SMALLINT", // int2_array - 1006: "INT2VECTOR", // int2vector_array - 1007: "INTEGER", // int4_array - 1016: "BIGINT", // int8_array - 1021: "REAL", // float4_array - 1022: "DOUBLE PRECISION", // float8_array - 1231: "NUMERIC", // numeric_array - 791: "MONEY", // money_array - - // OID types - 1028: "OID", // oid_array - 1010: "TID", // tid_array - 1011: "XID", // xid_array - 1012: "CID", // cid_array - - // JSON types - 199: "JSON", // json_array - 3802: "JSONB", // jsonb (not array) - 3807: "JSONB", // jsonb_array - 4072: "JSONPATH", // jsonpath - 4073: "JSONPATH", // jsonpath_array - - // XML - 143: "XML", // xml_array - - // Geometric types - 1017: "POINT", // point_array - 1018: "LSEG", // lseg_array - 1019: "PATH", // path_array - 1020: "BOX", // box_array - 1027: "POLYGON", // polygon_array - 629: "LINE", // line_array - 719: "CIRCLE", // circle_array - - // Network types - 651: "CIDR", // cidr_array - 1040: "MACADDR", // macaddr_array - 1041: "INET", // inet_array - 775: "MACADDR8", // macaddr8_array - - // Date/Time types - 1182: "DATE", // date_array - 1183: "TIME", // time_array - 1115: "TIMESTAMP", // timestamp_array - 1185: "TIMESTAMPTZ", // timestamptz_array - 1187: "INTERVAL", // interval_array - 1270: "TIMETZ", // timetz_array - - // Bit string types - 1561: "BIT", // bit_array - 1563: "VARBIT", // varbit_array - - // ACL - 1034: "ACLITEM", // aclitem_array - - // System catalog types - 12052: "PG_DATABASE", // pg_database_array - 10052: "PG_DATABASE", // pg_database_array2 -}; +const cmds = ["", "INSERT", "DELETE", "UPDATE", "MERGE", "SELECT", "MOVE", "FETCH", "COPY"]; -function isPostgresNumericType(type: string) { - switch (type) { - case "BIT": // bit_array - case "VARBIT": // varbit_array - case "SMALLINT": // int2_array - case "INT2VECTOR": // int2vector_array - case "INTEGER": // int4_array - case "INT": // int4_array - case "BIGINT": // int8_array - case "REAL": // float4_array - case "DOUBLE PRECISION": // float8_array - case "NUMERIC": // numeric_array - case "MONEY": // money_array - return true; - default: - return false; - } -} -function isPostgresJsonType(type: string) { - switch (type) { - case "JSON": - case "JSONB": - return true; - default: - return false; - } -} +// POSTGRES_ARRAY_TYPES, isPostgresNumericType, isPostgresJsonType imported from postgres-types function getPostgresArrayType(typeId: number) { return POSTGRES_ARRAY_TYPES[typeId] || null; } @@ -181,16 +94,22 @@ function arrayValueSerializer(type: ArrayType, is_numeric: boolean, is_json: boo // fallback to string return value === true ? '"true"' : '"false"'; } - default: - if (value instanceof Date) { - const isoValue = value.toISOString(); + case "object": { + // Type assertion needed because TypeScript's control flow analysis + // incorrectly infers 'never' after the typeof switch + const objectValue = value as object | null; + if (objectValue === null) { + return "null"; + } + if (objectValue instanceof Date) { + const isoValue = objectValue.toISOString(); if (is_json) { return `"${arrayEscape(JSON.stringify(isoValue))}"`; } return `"${arrayEscape(isoValue)}"`; } - if (Buffer.isBuffer(value)) { - const hexValue = value.toString("hex"); + if (Buffer.isBuffer(objectValue)) { + const hexValue = objectValue.toString("hex"); // bytea array if (type === "BYTEA") { return `"\\x${arrayEscape(hexValue)}"`; @@ -201,6 +120,10 @@ function arrayValueSerializer(type: ArrayType, is_numeric: boolean, is_json: boo return `"${arrayEscape(hexValue)}"`; } // fallback to JSON.stringify + return `"${arrayEscape(JSON.stringify(objectValue))}"`; + } + default: + // function, symbol - fallback to JSON.stringify return `"${arrayEscape(JSON.stringify(value))}"`; } } @@ -230,7 +153,14 @@ function wrapPostgresError(error: Error | PostgresErrorOptions) { if (Error.isError(error)) { return error; } - return new PostgresError(error.message, error); + + let message = "PostgreSQL error"; + + if ("message" in error) { + message = error.message as string; + } + + return new PostgresError(message, error); } initPostgres( @@ -314,6 +244,48 @@ initPostgres( query.reject(reject as Error); } catch {} }, + function onCopyStart(this: $ZigGeneratedClasses.PostgresSQLConnection) { + const handler = copyStartHandlers.get(this); + if (handler) { + copyStartHandlers.delete(this); + try { + handler(); + } catch {} + } + }, + function onCopyChunk(this: $ZigGeneratedClasses.PostgresSQLConnection, chunk: any) { + const handler = copyChunkHandlers.get(this); + if (handler) { + try { + handler(chunk); + } catch {} + } + }, + function onCopyEnd(this: $ZigGeneratedClasses.PostgresSQLConnection) { + const handler = copyEndHandlers.get(this); + if (handler) { + try { + handler(); + } catch {} + } + // Always clear COPY handlers on end (even if no explicit end handler was registered), + // to avoid retaining a connection object through WeakMap entries. + copyChunkHandlers.delete(this); + copyEndHandlers.delete(this); + copyStartHandlers.delete(this); + try { + setCopyChunkHandlerRegistered(this, false); + } catch {} + }, + function onWritable(this: $ZigGeneratedClasses.PostgresSQLConnection) { + const handler = writableHandlers.get(this); + if (handler) { + writableHandlers.delete(this); + try { + handler(); + } catch {} + } + }, ); export interface PostgresDotZig { @@ -327,6 +299,10 @@ export interface PostgresDotZig { is_last: boolean, ) => void, onRejectQuery: (query: Query, err: Error, queries) => void, + onCopyStart: (this: $ZigGeneratedClasses.PostgresSQLConnection) => void, + onCopyChunk: (this: $ZigGeneratedClasses.PostgresSQLConnection, chunk: any) => void, + onCopyEnd: (this: $ZigGeneratedClasses.PostgresSQLConnection) => void, + onWritable: (this: $ZigGeneratedClasses.PostgresSQLConnection) => void, ) => void; createConnection: ( hostname: string | undefined, @@ -353,9 +329,35 @@ export interface PostgresDotZig { bigint: boolean, simple: boolean, ) => $ZigGeneratedClasses.PostgresSQLQuery; + + // Low-level COPY helpers (call with explicit thisArg as first parameter) + sendCopyData: (connection: $ZigGeneratedClasses.PostgresSQLConnection, data: string | Uint8Array) => void; + sendCopyDone: (connection: $ZigGeneratedClasses.PostgresSQLConnection) => void; + sendCopyFail: (connection: $ZigGeneratedClasses.PostgresSQLConnection, message?: string) => void; + awaitWritable: (connection: $ZigGeneratedClasses.PostgresSQLConnection) => Promise; + setCopyStreamingMode: (connection: $ZigGeneratedClasses.PostgresSQLConnection, enable: boolean) => void; + setCopyChunkHandlerRegistered: (connection: $ZigGeneratedClasses.PostgresSQLConnection, registered: boolean) => void; + setCopyTimeout: (connection: $ZigGeneratedClasses.PostgresSQLConnection, ms: number) => void; + setMaxCopyBufferSize: (connection: $ZigGeneratedClasses.PostgresSQLConnection, bytes: number) => void; + setMaxCopyBufferSizeUnsafe: (connection: $ZigGeneratedClasses.PostgresSQLConnection, bytes: number) => void; } -const enum SQLCommand { +function onCopyStart(connection: $ZigGeneratedClasses.PostgresSQLConnection, handler: () => void) { + // kept for internal use; prefer PostgresAdapter.onCopyStart + copyStartHandlers.set(connection, handler); +} +function copySendData(connection: $ZigGeneratedClasses.PostgresSQLConnection, data: string | Uint8Array) { + // delegate to Zig binding (expects explicit thisArg as first parameter) + // Zig side accepts string and ArrayBuffer/TypedArray payloads via PostgresSQLConnection.copySendDataFromJSValue. + sendCopyData(connection, data); +} +function copyDone(connection: $ZigGeneratedClasses.PostgresSQLConnection) { + sendCopyDone(connection); +} +function copyFail(connection: $ZigGeneratedClasses.PostgresSQLConnection, message?: string) { + sendCopyFail(connection, message ?? ""); +} +enum SQLCommand { insert = 0, update = 1, updateSet = 2, @@ -363,7 +365,6 @@ const enum SQLCommand { in = 4, none = -1, } -export type { SQLCommand }; function commandToString(command: SQLCommand): string { switch (command) { @@ -579,14 +580,33 @@ class PooledPostgresConnection { if (connectionInfo?.onclose) { connectionInfo.onclose(err); } + + const underlyingConnection = this.connection; + this.state = PooledConnectionState.closed; this.connection = null; this.storedError = err; + if (underlyingConnection) { + // Clear any COPY/writable handlers to avoid retaining the underlying connection object. + copyStartHandlers.delete(underlyingConnection); + copyChunkHandlers.delete(underlyingConnection); + copyEndHandlers.delete(underlyingConnection); + writableHandlers.delete(underlyingConnection); + try { + setCopyChunkHandlerRegistered(underlyingConnection, false); + } catch {} + } + // remove from ready connections if its there this.adapter.readyConnections?.delete(this); const queries = new Set(this.queries); this.queries?.clear?.(); + // Decrement totalQueries by current queryCount before zeroing + // This ensures the adapter's pending query count stays accurate + if (this.queryCount > 0) { + this.adapter.totalQueries -= this.queryCount; + } this.queryCount = 0; this.flags &= ~PooledConnectionFlags.reserved; @@ -699,10 +719,41 @@ class PostgresAdapter public totalQueries: number = 0; public onAllQueriesFinished: (() => void) | null = null; + // Default COPY behavior and guardrails for this adapter instance + public copyDefaults: { + from: { maxChunkSize: number; maxBytes: number; timeout: number }; + to: { stream: boolean; maxBytes: number; timeout: number }; + } = { + from: { maxChunkSize: 256 * 1024, maxBytes: 0, timeout: 0 }, // 0 = unlimited + to: { stream: true, maxBytes: 0, timeout: 0 }, // 0 = unlimited + }; + + // Global defaults for new adapters (can be overridden via setGlobalCopyDefaults) + static globalCopyDefaults: { + from: { maxChunkSize: number; maxBytes: number; timeout: number }; + to: { stream: boolean; maxBytes: number; timeout: number }; + } = { + from: { maxChunkSize: 256 * 1024, maxBytes: 0, timeout: 0 }, + to: { stream: true, maxBytes: 0, timeout: 0 }, + }; + constructor(connectionInfo: Bun.SQL.__internal.DefinedPostgresOrMySQLOptions) { this.connectionInfo = connectionInfo; this.connections = new Array(connectionInfo.max); this.readyConnections = new Set(); + // Clone global defaults into this instance + this.copyDefaults = { + from: { + maxChunkSize: PostgresAdapter.globalCopyDefaults.from.maxChunkSize, + maxBytes: PostgresAdapter.globalCopyDefaults.from.maxBytes, + timeout: PostgresAdapter.globalCopyDefaults.from.timeout, + }, + to: { + stream: PostgresAdapter.globalCopyDefaults.to.stream, + maxBytes: PostgresAdapter.globalCopyDefaults.to.maxBytes, + timeout: PostgresAdapter.globalCopyDefaults.to.timeout, + }, + }; } escapeIdentifier(str: string) { @@ -733,6 +784,196 @@ class PostgresAdapter return true; } + // Global setter to change defaults for subsequently constructed adapters + static setGlobalCopyDefaults( + newDefaults: Partial<{ + from: Partial<{ maxChunkSize: number; maxBytes: number; timeout: number }>; + to: Partial<{ stream: boolean; maxBytes: number; timeout: number }>; + }>, + ) { + if (!newDefaults) return; + if (newDefaults.from) { + if (typeof newDefaults.from.maxChunkSize === "number" && newDefaults.from.maxChunkSize > 0) { + PostgresAdapter.globalCopyDefaults.from.maxChunkSize = Math.floor(newDefaults.from.maxChunkSize); + } + if (typeof newDefaults.from.maxBytes === "number" && newDefaults.from.maxBytes >= 0) { + PostgresAdapter.globalCopyDefaults.from.maxBytes = Math.floor(newDefaults.from.maxBytes); + } + if (typeof newDefaults.from.timeout === "number" && newDefaults.from.timeout >= 0) { + PostgresAdapter.globalCopyDefaults.from.timeout = Math.floor(newDefaults.from.timeout); + } + } + if (newDefaults.to) { + if (typeof newDefaults.to.stream === "boolean") { + PostgresAdapter.globalCopyDefaults.to.stream = newDefaults.to.stream; + } + if (typeof newDefaults.to.maxBytes === "number" && newDefaults.to.maxBytes >= 0) { + PostgresAdapter.globalCopyDefaults.to.maxBytes = Math.floor(newDefaults.to.maxBytes); + } + if (typeof newDefaults.to.timeout === "number" && newDefaults.to.timeout >= 0) { + PostgresAdapter.globalCopyDefaults.to.timeout = Math.floor(newDefaults.to.timeout); + } + } + } + + // Instance getter to read current defaults (for sql.ts to merge with per-call options) + getCopyDefaults() { + return this.copyDefaults; + } + + // Instance setter to change defaults for this adapter instance + setCopyDefaults( + newDefaults: Partial<{ + from: Partial<{ maxChunkSize: number; maxBytes: number; timeout: number }>; + to: Partial<{ stream: boolean; maxBytes: number; timeout: number }>; + }>, + ) { + if (!newDefaults) return; + if (newDefaults.from) { + if (typeof newDefaults.from.maxChunkSize === "number" && newDefaults.from.maxChunkSize > 0) { + this.copyDefaults.from.maxChunkSize = Math.floor(newDefaults.from.maxChunkSize); + } + if (typeof newDefaults.from.maxBytes === "number" && newDefaults.from.maxBytes >= 0) { + this.copyDefaults.from.maxBytes = Math.floor(newDefaults.from.maxBytes); + } + if (typeof newDefaults.from.timeout === "number" && newDefaults.from.timeout >= 0) { + this.copyDefaults.from.timeout = Math.floor(newDefaults.from.timeout); + } + } + if (newDefaults.to) { + if (typeof newDefaults.to.stream === "boolean") { + this.copyDefaults.to.stream = newDefaults.to.stream; + } + if (typeof newDefaults.to.maxBytes === "number" && newDefaults.to.maxBytes >= 0) { + this.copyDefaults.to.maxBytes = Math.floor(newDefaults.to.maxBytes); + } + if (typeof newDefaults.to.timeout === "number" && newDefaults.to.timeout >= 0) { + this.copyDefaults.to.timeout = Math.floor(newDefaults.to.timeout); + } + } + } + + // Reserved connection helper. Note: the `connection` parameter is intentionally ignored. + // This forwards to global adapter-level defaults via `setCopyDefaults()`. + // If per-connection defaults are desired, callers should configure them on the connection object (when supported). + setCopyDefaultsFor( + connection: PooledPostgresConnection, + newDefaults: Partial<{ + from: Partial<{ maxChunkSize: number; maxBytes: number }>; + to: Partial<{ stream: boolean; maxBytes: number }>; + }>, + ) { + this.setCopyDefaults(newDefaults); + } + + // COPY protocol low-level helpers exposed as static methods for internal use + static onCopyStart(connection: $ZigGeneratedClasses.PostgresSQLConnection, handler: () => void) { + copyStartHandlers.set(connection, handler); + } + static onCopyChunk(connection: $ZigGeneratedClasses.PostgresSQLConnection, handler: (chunk: any) => void) { + copyChunkHandlers.set(connection, handler); + try { + setCopyChunkHandlerRegistered(connection, true); + } catch {} + } + static onCopyEnd(connection: $ZigGeneratedClasses.PostgresSQLConnection, handler: () => void) { + copyEndHandlers.set(connection, handler); + } + static copySendData(connection: $ZigGeneratedClasses.PostgresSQLConnection, data: string | Uint8Array) { + // delegate to Zig binding (expects explicit thisArg as first parameter) + sendCopyData(connection, data); + } + static copyDone(connection: $ZigGeneratedClasses.PostgresSQLConnection) { + sendCopyDone(connection); + } + static copyFail(connection: $ZigGeneratedClasses.PostgresSQLConnection, message?: string) { + sendCopyFail(connection, message ?? ""); + } + static setCopyStreamingMode(connection: $ZigGeneratedClasses.PostgresSQLConnection, enable: boolean) { + setCopyStreamingMode(connection, !!enable); + } + static setCopyTimeout(connection: $ZigGeneratedClasses.PostgresSQLConnection, ms: number) { + const n = Math.min(0xffffffff, Math.max(0, Math.trunc(Number(ms) || 0))); + setCopyTimeout(connection, n); + } + static setMaxCopyBufferSize(connection: $ZigGeneratedClasses.PostgresSQLConnection, bytes: number) { + // Normalize to a non-negative integer. Zig enforces the safety cap (and treats 0 as disabled). + const n = Math.max(0, Math.trunc(Number(bytes) || 0)); + setMaxCopyBufferSize(connection, n); + } + + static setMaxCopyBufferSizeUnsafe(connection: $ZigGeneratedClasses.PostgresSQLConnection, bytes: number) { + // Normalize to a non-negative integer. Zig enforces the hard cap (and treats 0 as disabled). + const n = Math.max(0, Math.trunc(Number(bytes) || 0)); + setMaxCopyBufferSizeUnsafe(connection, n); + } + static onWritable(connection: $ZigGeneratedClasses.PostgresSQLConnection, handler: () => void) { + writableHandlers.set(connection, handler); + } + static awaitWritable(connection: $ZigGeneratedClasses.PostgresSQLConnection, handler?: () => void) { + if (handler) { + writableHandlers.set(connection, handler); + } + // Use the connection as thisArg; the Zig binding returns a Promise that resolves when the socket becomes writable. + return awaitWritable(connection); + } + + // Instance helpers to control COPY using a pooled connection handle + onCopyStartFor(connection: PooledPostgresConnection, handler: () => void) { + const underlying = this.getConnectionForQuery(connection); + if (underlying) { + PostgresAdapter.onCopyStart(underlying, handler); + } + } + copySendDataFor(connection: PooledPostgresConnection, data: string | Uint8Array) { + const underlying = this.getConnectionForQuery(connection); + if (underlying) { + PostgresAdapter.copySendData(underlying, data); + } + } + copyDoneFor(connection: PooledPostgresConnection) { + const underlying = this.getConnectionForQuery(connection); + if (underlying) { + PostgresAdapter.copyDone(underlying); + } + } + copyFailFor(connection: PooledPostgresConnection, message?: string) { + const underlying = this.getConnectionForQuery(connection); + if (underlying) { + PostgresAdapter.copyFail(underlying, message); + } + } + onWritableFor(connection: PooledPostgresConnection, handler: () => void) { + const underlying = this.getConnectionForQuery(connection); + if (underlying) { + PostgresAdapter.onWritable(underlying, handler); + } + } + awaitWritableFor(connection: PooledPostgresConnection, handler?: () => void) { + const underlying = this.getConnectionForQuery(connection); + if (underlying) { + return PostgresAdapter.awaitWritable(underlying, handler); + } + } + setCopyStreamingModeFor(connection: PooledPostgresConnection, enable: boolean) { + const underlying = this.getConnectionForQuery(connection); + if (underlying) { + PostgresAdapter.setCopyStreamingMode(underlying, enable); + } + } + setCopyTimeoutFor(connection: PooledPostgresConnection, ms: number) { + const underlying = this.getConnectionForQuery(connection); + if (underlying) { + PostgresAdapter.setCopyTimeout(underlying, ms); + } + } + setMaxCopyBufferSizeFor(connection: PooledPostgresConnection, bytes: number) { + const underlying = this.getConnectionForQuery(connection); + if (underlying) { + PostgresAdapter.setMaxCopyBufferSize(underlying, bytes); + } + } + getConnectionForQuery(pooledConnection: PooledPostgresConnection) { return pooledConnection.connection; } @@ -1404,4 +1645,7 @@ export default { SQLCommand, commandToString, detectCommand, + arrayValueSerializer, + getArrayType, + serializeArray, }; diff --git a/src/sql/postgres.zig b/src/sql/postgres.zig index 29de836e543..110f350d9c0 100644 --- a/src/sql/postgres.zig +++ b/src/sql/postgres.zig @@ -14,17 +14,167 @@ pub fn createBinding(globalObject: *jsc.JSGlobalObject) JSValue { jsc.JSFunction.create(globalObject, "createConnection", PostgresSQLConnection.call, 2, .{}), ); + binding.put(globalObject, ZigString.static("sendCopyData"), jsc.JSFunction.create(globalObject, "sendCopyData", __pg_sendCopyData, 2, .{})); + binding.put(globalObject, ZigString.static("sendCopyDone"), jsc.JSFunction.create(globalObject, "sendCopyDone", __pg_sendCopyDone, 1, .{})); + binding.put(globalObject, ZigString.static("sendCopyFail"), jsc.JSFunction.create(globalObject, "sendCopyFail", __pg_sendCopyFail, 2, .{})); + binding.put(globalObject, ZigString.static("awaitWritable"), jsc.JSFunction.create(globalObject, "awaitWritable", __pg_awaitWritable, 2, .{})); + binding.put(globalObject, ZigString.static("setCopyStreamingMode"), jsc.JSFunction.create(globalObject, "setCopyStreamingMode", __pg_setCopyStreamingMode, 2, .{})); + binding.put(globalObject, ZigString.static("setCopyChunkHandlerRegistered"), jsc.JSFunction.create(globalObject, "setCopyChunkHandlerRegistered", __pg_setCopyChunkHandlerRegistered, 2, .{})); + binding.put(globalObject, ZigString.static("setCopyTimeout"), jsc.JSFunction.create(globalObject, "setCopyTimeout", __pg_setCopyTimeout, 2, .{})); + binding.put(globalObject, ZigString.static("setMaxCopyBufferSize"), jsc.JSFunction.create(globalObject, "setMaxCopyBufferSize", __pg_setMaxCopyBufferSize, 2, .{})); + binding.put(globalObject, ZigString.static("setMaxCopyBufferSizeUnsafe"), jsc.JSFunction.create(globalObject, "setMaxCopyBufferSizeUnsafe", __pg_setMaxCopyBufferSizeUnsafe, 2, .{})); + return binding; } -pub const PostgresSQLConnection = @import("./postgres/PostgresSQLConnection.zig"); -pub const PostgresSQLContext = @import("./postgres/PostgresSQLContext.zig"); -pub const PostgresSQLQuery = @import("./postgres/PostgresSQLQuery.zig"); -pub const protocol = @import("./postgres/PostgresProtocol.zig"); -pub const types = @import("./postgres/PostgresTypes.zig"); +// Low-level COPY helper wrappers (call with .call(connection, ...)) +fn __pg_sendCopyData(globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + // Arg0: PostgresSQLConnection, Arg1: data + const connection_value = callframe.argument(0); + const connection: *PostgresSQLConnection = connection_value.as(PostgresSQLConnection) orelse { + return globalObject.throw("sendCopyData first argument must be a PostgresSQLConnection", .{}); + }; -const bun = @import("bun"); + const data_value = callframe.argument(1); + if (data_value == .zero) { + return globalObject.throwNotEnoughArguments("sendCopyData", 2, 1); + } + + try connection.copySendDataFromJSValue(globalObject, data_value); + return .js_undefined; +} +fn __pg_sendCopyDone(globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + // Arg0: PostgresSQLConnection + const connection_value = callframe.argument(0); + const connection: *PostgresSQLConnection = connection_value.as(PostgresSQLConnection) orelse { + return globalObject.throw("sendCopyDone first argument must be a PostgresSQLConnection", .{}); + }; + return connection.sendCopyDone(globalObject, callframe); +} +fn __pg_sendCopyFail(globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + // Arg0: PostgresSQLConnection, Arg1: message? + const connection_value = callframe.argument(0); + const connection: *PostgresSQLConnection = connection_value.as(PostgresSQLConnection) orelse { + return globalObject.throw("sendCopyFail first argument must be a PostgresSQLConnection", .{}); + }; + + const args = callframe.arguments(); + const message_value: jsc.JSValue = if (args.len > 1) args[1] else .js_undefined; + + try connection.copySendFailFromJSValue(globalObject, message_value); + return .js_undefined; +} +fn __pg_setCopyStreamingMode(globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + // Arg0: PostgresSQLConnection, Arg1: enable (boolean) + // Returns: undefined. + const connection_value = callframe.argument(0); + const connection: *PostgresSQLConnection = connection_value.as(PostgresSQLConnection) orelse { + return globalObject.throw("setCopyStreamingMode first argument must be a PostgresSQLConnection", .{}); + }; + + const enable_arg = callframe.argument(1); + const enable = enable_arg.toBoolean(); + + // Apply the requested mode, but never enable streaming unless a per-connection chunk handler is registered. + // Otherwise, COPY TO streaming could silently drop data. + connection.copy_streaming_mode = enable and connection.copy_chunk_handler_registered; + + return .js_undefined; +} + +fn __pg_setCopyChunkHandlerRegistered(globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + // Arg0: PostgresSQLConnection, Arg1: registered (boolean) + const connection_value = callframe.argument(0); + const connection: *PostgresSQLConnection = connection_value.as(PostgresSQLConnection) orelse { + return globalObject.throw("setCopyChunkHandlerRegistered first argument must be a PostgresSQLConnection", .{}); + }; + + const registered_arg = callframe.argument(1); + const registered = registered_arg.toBoolean(); + + connection.copy_chunk_handler_registered = registered; + + return .js_undefined; +} + +fn __pg_setCopyTimeout(globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + // Arg0: PostgresSQLConnection, Arg1: timeout in ms (number; 0 disables COPY timeout) + const connection_value = callframe.argument(0); + const connection: *PostgresSQLConnection = connection_value.as(PostgresSQLConnection) orelse { + return globalObject.throw("setCopyTimeout first argument must be a PostgresSQLConnection", .{}); + }; + const ms_value = callframe.argument(1); + if (ms_value == .zero) { + return globalObject.throwNotEnoughArguments("setCopyTimeout", 2, 1); + } + + const ms_num = try ms_value.toNumber(globalObject); + + // 0 means disabled. Clamp to u32 max. + var ms_u32: u32 = 0; + if (std.math.isFinite(ms_num) and ms_num > 0) { + const max_u32_f64: f64 = @floatFromInt(std.math.maxInt(u32)); + const clamped_f64: f64 = @min(ms_num, max_u32_f64); + const ms_u64: u64 = @intFromFloat(clamped_f64); + ms_u32 = @intCast(ms_u64); + } + + connection.copy_timeout_ms = ms_u32; + + return .js_undefined; +} + +fn __pg_setMaxCopyBufferSize(globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + // Arg0: PostgresSQLConnection, Arg1: size in bytes (number; 0 disables limit) + const connection_value = callframe.argument(0); + const connection: *PostgresSQLConnection = connection_value.as(PostgresSQLConnection) orelse { + return globalObject.throw("setMaxCopyBufferSize first argument must be a PostgresSQLConnection", .{}); + }; + + const bytes_value = callframe.argument(1); + if (bytes_value == .zero) { + return globalObject.throwNotEnoughArguments("setMaxCopyBufferSize", 2, 1); + } + + // Delegate to the connection method to apply the safety cap. + return connection.setMaxCopyBufferSize(globalObject, callframe); +} + +fn __pg_setMaxCopyBufferSizeUnsafe(globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + // Arg0: PostgresSQLConnection, Arg1: size in bytes (number; 0 disables limit) + const connection_value = callframe.argument(0); + const connection: *PostgresSQLConnection = connection_value.as(PostgresSQLConnection) orelse { + return globalObject.throw("setMaxCopyBufferSizeUnsafe first argument must be a PostgresSQLConnection", .{}); + }; + + const bytes_value = callframe.argument(1); + if (bytes_value == .zero) { + return globalObject.throwNotEnoughArguments("setMaxCopyBufferSizeUnsafe", 2, 1); + } + + // Delegate to the connection method to apply the hard cap. + return connection.setMaxCopyBufferSizeUnsafe(globalObject, callframe); +} +fn __pg_awaitWritable(globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + // Arg0: PostgresSQLConnection + const connection_value = callframe.argument(0); + const connection: *PostgresSQLConnection = connection_value.as(PostgresSQLConnection) orelse { + return globalObject.throw("awaitWritable first argument must be a PostgresSQLConnection", .{}); + }; + + // Delegate to the connection method, which returns a Promise that resolves when the socket becomes writable. + return connection.awaitWritable(globalObject, callframe); +} + +const std = @import("std"); +const bun = @import("bun"); const jsc = bun.jsc; const JSValue = jsc.JSValue; const ZigString = jsc.ZigString; + +pub const protocol = @import("./postgres/PostgresProtocol.zig"); +pub const PostgresSQLConnection = @import("./postgres/PostgresSQLConnection.zig"); +pub const PostgresSQLContext = @import("./postgres/PostgresSQLContext.zig"); +pub const PostgresSQLQuery = @import("./postgres/PostgresSQLQuery.zig"); +pub const types = @import("./postgres/PostgresTypes.zig"); diff --git a/src/sql/postgres/AnyPostgresError.zig b/src/sql/postgres/AnyPostgresError.zig index 6e93fa39c45..cf713e0a4aa 100644 --- a/src/sql/postgres/AnyPostgresError.zig +++ b/src/sql/postgres/AnyPostgresError.zig @@ -1,5 +1,9 @@ pub const AnyPostgresError = error{ ConnectionClosed, + CopyBothNotImplemented, + CopyBufferTooLarge, + CopyChunkTooLarge, + CopyWriteFailed, ExpectedRequest, ExpectedStatement, InvalidBackendKeyData, @@ -19,12 +23,14 @@ pub const AnyPostgresError = error{ NullsInArrayNotSupportedYet, OutOfMemory, Overflow, + CopyTimeout, PBKDFD2, SASL_SIGNATURE_MISMATCH, SASL_SIGNATURE_INVALID_BASE64, ShortRead, TLSNotAvailable, TLSUpgradeFailed, + UnexpectedCopyData, UnexpectedMessage, UNKNOWN_AUTHENTICATION_METHOD, UNSUPPORTED_AUTHENTICATION_METHOD, @@ -80,6 +86,10 @@ pub fn createPostgresError( pub fn postgresErrorToJS(globalObject: *jsc.JSGlobalObject, message: ?[]const u8, err: AnyPostgresError) JSValue { const code = switch (err) { error.ConnectionClosed => "ERR_POSTGRES_CONNECTION_CLOSED", + error.CopyBothNotImplemented => "ERR_POSTGRES_COPY_BOTH_NOT_IMPLEMENTED", + error.CopyBufferTooLarge => "ERR_POSTGRES_COPY_BUFFER_TOO_LARGE", + error.CopyChunkTooLarge => "ERR_POSTGRES_COPY_CHUNK_TOO_LARGE", + error.CopyWriteFailed => "ERR_POSTGRES_COPY_WRITE_FAILED", error.ExpectedRequest => "ERR_POSTGRES_EXPECTED_REQUEST", error.ExpectedStatement => "ERR_POSTGRES_EXPECTED_STATEMENT", error.InvalidBackendKeyData => "ERR_POSTGRES_INVALID_BACKEND_KEY_DATA", @@ -96,11 +106,13 @@ pub fn postgresErrorToJS(globalObject: *jsc.JSGlobalObject, message: ?[]const u8 error.MultidimensionalArrayNotSupportedYet => "ERR_POSTGRES_MULTIDIMENSIONAL_ARRAY_NOT_SUPPORTED_YET", error.NullsInArrayNotSupportedYet => "ERR_POSTGRES_NULLS_IN_ARRAY_NOT_SUPPORTED_YET", error.Overflow => "ERR_POSTGRES_OVERFLOW", + error.CopyTimeout => "ERR_POSTGRES_COPY_TIMEOUT", error.PBKDFD2 => "ERR_POSTGRES_AUTHENTICATION_FAILED_PBKDF2", error.SASL_SIGNATURE_MISMATCH => "ERR_POSTGRES_SASL_SIGNATURE_MISMATCH", error.SASL_SIGNATURE_INVALID_BASE64 => "ERR_POSTGRES_SASL_SIGNATURE_INVALID_BASE64", error.TLSNotAvailable => "ERR_POSTGRES_TLS_NOT_AVAILABLE", error.TLSUpgradeFailed => "ERR_POSTGRES_TLS_UPGRADE_FAILED", + error.UnexpectedCopyData => "ERR_POSTGRES_UNEXPECTED_COPY_DATA", error.UnexpectedMessage => "ERR_POSTGRES_UNEXPECTED_MESSAGE", error.UNKNOWN_AUTHENTICATION_METHOD => "ERR_POSTGRES_UNKNOWN_AUTHENTICATION_METHOD", error.UNSUPPORTED_AUTHENTICATION_METHOD => "ERR_POSTGRES_UNSUPPORTED_AUTHENTICATION_METHOD", diff --git a/src/sql/postgres/PostgresProtocol.zig b/src/sql/postgres/PostgresProtocol.zig index 20e6cd21909..bb91918f2f8 100644 --- a/src/sql/postgres/PostgresProtocol.zig +++ b/src/sql/postgres/PostgresProtocol.zig @@ -26,6 +26,10 @@ pub const BackendKeyData = @import("./protocol/BackendKeyData.zig"); pub const CommandComplete = @import("./protocol/CommandComplete.zig"); pub const CopyData = @import("./protocol/CopyData.zig"); pub const CopyFail = @import("./protocol/CopyFail.zig"); +pub const CopyResponse = @import("./protocol/CopyResponse.zig"); +pub const CopyBothResponse = @import("./protocol/CopyBothResponse.zig").CopyBothResponse; +pub const CopyInResponse = @import("./protocol/CopyInResponse.zig").CopyInResponse; +pub const CopyOutResponse = @import("./protocol/CopyOutResponse.zig").CopyOutResponse; pub const DataRow = @import("./protocol/DataRow.zig"); pub const Describe = @import("./protocol/Describe.zig"); pub const ErrorResponse = @import("./protocol/ErrorResponse.zig"); diff --git a/src/sql/postgres/PostgresSQLConnection.zig b/src/sql/postgres/PostgresSQLConnection.zig index 7cbceae6290..afe5ef9d6c9 100644 --- a/src/sql/postgres/PostgresSQLConnection.zig +++ b/src/sql/postgres/PostgresSQLConnection.zig @@ -1,5 +1,24 @@ const PostgresSQLConnection = @This(); const RefCount = bun.ptr.RefCount(@This(), "ref_count", deinit, .{}); + +/// Maximum buffer size for COPY data accumulation (256MB) +const MAX_COPY_BUFFER_SIZE: usize = 256 * 1024 * 1024; + +/// Hard upper bound for COPY buffer size to avoid unbounded memory growth. +/// This is intentionally conservative; raising it should require explicit opt-in. +const MAX_COPY_BUFFER_SIZE_HARD_CAP: usize = 1024 * 1024 * 1024; // 1 GiB + +/// Threshold for shrinking the COPY buffer after operation completes (64MB) +/// If buffer capacity exceeds this after COPY, we shrink it to avoid wasting memory +const COPY_BUFFER_SHRINK_THRESHOLD: usize = 64 * 1024 * 1024; + +/// Default COPY operation timeout in milliseconds. +/// 0 means no COPY timeout. +const DEFAULT_COPY_TIMEOUT_MS: u32 = 0; + +/// PostgreSQL binary COPY format signature: "PGCOPY\n\xff\r\n\0" +const COPY_BINARY_SIGNATURE = [_]u8{ 'P', 'G', 'C', 'O', 'P', 'Y', '\n', 0xff, '\r', '\n', 0 }; + socket: Socket, status: Status = Status.connecting, ref_count: RefCount = RefCount.init(), @@ -43,6 +62,10 @@ connection_timeout_ms: u32 = 0, flags: ConnectionFlags = .{}, +/// Promise used by `awaitWritable()` to await socket writability. +/// Stored strongly on the connection so it is kept alive until resolved/rejected. +await_writable_promise: jsc.Strong.Optional = .empty, + /// Before being connected, this is a connection timeout timer. /// After being connected, this is an idle timeout timer. timer: bun.api.Timer.EventLoopTimer = .{ @@ -60,9 +83,174 @@ max_lifetime_timer: bun.api.Timer.EventLoopTimer = .{ }, auto_flusher: AutoFlusher = .{}, +/// COPY protocol state tracking +copy_state: enum { + none, + copy_in_progress, // COPY FROM STDIN + copy_out_progress, // COPY TO STDOUT +} = .none, +copy_format: u8 = 0, // 0=text, 1=binary +copy_column_formats: []u16 = &.{}, +copy_data_buffer: std.array_list.Managed(u8) = std.array_list.Managed(u8).init(bun.default_allocator), +max_copy_buffer_size: usize = MAX_COPY_BUFFER_SIZE, + +/// The query that owns the currently active COPY operation. +/// This is set when COPY starts (CopyInResponse / CopyOutResponse) and is used to +/// deterministically reject the correct request on COPY failures. +copy_owner: ?*PostgresSQLQuery = null, + +/// COPY progress tracking +copy_bytes_transferred: u64 = 0, +copy_chunks_processed: u64 = 0, +/// If true, do not accumulate COPY TO data in memory; only emit streaming chunks to JS +copy_streaming_mode: bool = false, +/// Whether JavaScript has registered an onCopyChunk handler for this connection. +/// This is used to prevent enabling streaming mode when there is nowhere to deliver chunks. +copy_chunk_handler_registered: bool = false, +/// Track if we're currently processing a streaming callback to prevent reentrant calls +copy_callback_in_progress: bool = false, +/// COPY-specific timeout in milliseconds. +/// 0 means no COPY timeout. +copy_timeout_ms: u32 = 0, +/// Timestamp when COPY operation started (for timeout tracking) +copy_start_timestamp_ms: u64 = 0, +/// Track if we've validated the binary COPY header +copy_binary_header_validated: bool = false, + pub const ref = RefCount.ref; pub const deref = RefCount.deref; +fn effectiveMaxCopyBufferSize(this: *const PostgresSQLConnection) usize { + return if (this.max_copy_buffer_size == 0) std.math.maxInt(usize) else this.max_copy_buffer_size; +} + +fn resolveAwaitWritable(this: *PostgresSQLConnection, globalObject: *jsc.JSGlobalObject) void { + const promise_value = this.await_writable_promise.swap(); + if (promise_value == .zero) return; + + const promise = promise_value.asInternalPromise() orelse return; + promise.resolve(globalObject, .js_undefined) catch {}; +} + +fn rejectAwaitWritable(this: *PostgresSQLConnection, globalObject: *jsc.JSGlobalObject, message: []const u8) void { + const promise_value = this.await_writable_promise.swap(); + if (promise_value == .zero) return; + + const promise = promise_value.asInternalPromise() orelse return; + const err = globalObject.createErrorInstance("{s}", .{message}); + promise.rejectAsHandled(globalObject, err); +} + +/// JS: PostgresSQLConnection.setCopyStreamingMode(enable: boolean) +pub fn setCopyStreamingMode(this: *PostgresSQLConnection, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!jsc.JSValue { + _ = globalObject; + const args = callframe.arguments(); + const enable = if (args.len > 0) args[0].toBoolean() else true; + this.copy_streaming_mode = enable; + return .js_undefined; +} + +/// JS: PostgresSQLConnection.setCopyTimeout(ms: number) +pub fn setCopyTimeout(this: *PostgresSQLConnection, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!jsc.JSValue { + const args = callframe.arguments(); + if (args.len < 1) { + return globalObject.throwNotEnoughArguments("setCopyTimeout", 1, args.len); + } + const n = try args[0].toNumber(globalObject); + var ms: u32 = 0; + if (std.math.isFinite(n) and n > 0) { + const cap_f64: f64 = @floatFromInt(std.math.maxInt(u32)); + const clamped: f64 = @min(n, cap_f64); + const n_u64: u64 = @intFromFloat(clamped); + ms = @intCast(n_u64); + } + this.copy_timeout_ms = ms; + return .js_undefined; +} + +/// JS: PostgresSQLConnection.setMaxCopyBufferSize(bytes: number) +/// +/// This method is intentionally capped to `MAX_COPY_BUFFER_SIZE` (256MB) to avoid +/// unbounded memory growth when callers pass very large values (for example, via +/// JS clamping to 0xffffffff). +/// +/// To opt into larger limits (up to `MAX_COPY_BUFFER_SIZE_HARD_CAP`), use +/// `setMaxCopyBufferSizeUnsafe()`. +pub fn setMaxCopyBufferSize(this: *PostgresSQLConnection, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!jsc.JSValue { + const args = callframe.arguments(); + if (args.len < 1) { + return globalObject.throwNotEnoughArguments("setMaxCopyBufferSize", 1, args.len); + } + const n = try args[0].toNumber(globalObject); + + // Default to the safe cap (256MB). Non-finite and <= 0 values disable limits (0), + // matching the documented semantics used elsewhere for COPY limits. + var bytes: usize = MAX_COPY_BUFFER_SIZE; + if (!std.math.isFinite(n) or n <= 0) { + bytes = 0; + } else { + const cap_f64: f64 = @floatFromInt(@as(u64, MAX_COPY_BUFFER_SIZE)); + const clamped: f64 = @min(n, cap_f64); + const n_u64: u64 = @intFromFloat(clamped); + bytes = @intCast(@min(n_u64, @as(u64, MAX_COPY_BUFFER_SIZE))); + } + + this.max_copy_buffer_size = bytes; + return .js_undefined; +} + +/// JS: PostgresSQLConnection.setMaxCopyBufferSizeUnsafe(bytes: number) +/// +/// Explicit opt-in to larger COPY buffer sizes, capped to `MAX_COPY_BUFFER_SIZE_HARD_CAP`. +pub fn setMaxCopyBufferSizeUnsafe(this: *PostgresSQLConnection, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!jsc.JSValue { + const args = callframe.arguments(); + if (args.len < 1) { + return globalObject.throwNotEnoughArguments("setMaxCopyBufferSizeUnsafe", 1, args.len); + } + const n = try args[0].toNumber(globalObject); + + var bytes: usize = 0; + if (std.math.isFinite(n) and n > 0) { + const cap_f64: f64 = @floatFromInt(@as(u64, MAX_COPY_BUFFER_SIZE_HARD_CAP)); + const clamped: f64 = @min(n, cap_f64); + const n_u64: u64 = @intFromFloat(clamped); + bytes = @intCast(@min(n_u64, @as(u64, MAX_COPY_BUFFER_SIZE_HARD_CAP))); + } + + this.max_copy_buffer_size = bytes; + return .js_undefined; +} + +/// JS: PostgresSQLConnection.awaitWritable() +/// Returns a Promise that resolves once the socket becomes writable. +/// If there is no backpressure, it resolves immediately. +pub fn awaitWritable(this: *PostgresSQLConnection, globalObject: *jsc.JSGlobalObject, _: *jsc.CallFrame) bun.JSError!jsc.JSValue { + // If the connection is not connected, fail immediately. + if (this.status != .connected) { + return globalObject.throw("Cannot await writable: connection is {s}. The connection must be open.", .{@tagName(this.status)}); + } + + // Fast path: if there is no backpressure, resolve immediately. + if (!this.flags.has_backpressure) { + return jsc.JSInternalPromise.resolvedPromise(globalObject, .js_undefined).asValue(); + } + + // Reuse an existing pending promise if present. + if (this.await_writable_promise.get()) |existing| { + return existing; + } + + const promise_value = jsc.JSValue.createInternalPromise(globalObject); + if (promise_value.asInternalPromise() == null) { + return globalObject.throw("Failed to create internal promise for awaitWritable", .{}); + } + + // Store strongly on the connection so it is kept alive until resolved/rejected. + this.await_writable_promise.set(globalObject, promise_value); + + return promise_value; +} + pub fn onAutoFlush(this: *@This()) bool { if (this.flags.has_backpressure) { debug("onAutoFlush: has backpressure", .{}); @@ -347,8 +535,15 @@ pub fn fail(this: *PostgresSQLConnection, message: []const u8, err: AnyPostgresE } pub fn onClose(this: *PostgresSQLConnection) void { + // Reject any pending awaitWritable promise first so callers do not hang on remote close. + // Do this early to avoid races with cleanup and socket teardown. + this.rejectAwaitWritable(this.globalObject, "Connection closed"); + this.unregisterAutoFlusher(); + // Clean up COPY state if connection closes during COPY operation + this.cleanupCopyState(); + if (this.vm.isShuttingDown()) { defer this.updateHasPendingActivity(); this.stopTimers(); @@ -459,6 +654,18 @@ pub fn onTimeout(this: *PostgresSQLConnection) void { pub fn onDrain(this: *PostgresSQLConnection) void { debug("onDrain", .{}); this.flags.has_backpressure = false; + + // Resolve any pending awaitWritable promise first. + this.resolveAwaitWritable(this.globalObject); + + // Notify any pending awaitWritable callback (use connection as thisArg) + var vm = jsc.VirtualMachine.get(); + if (vm.rareData().postgresql_context.onWritableFn.get()) |callback_writable| { + const event_loop = vm.eventLoop(); + // Pass the PostgresSQLConnection JS wrapper as 'this' so JS can dispatch per-connection + event_loop.runCallback(callback_writable, this.globalObject, this.js_value, &.{}); + } + // Don't send any other messages while we're waiting for TLS. if (this.tls_status == .message_sent) { if (this.tls_status.message_sent < 8) { @@ -861,6 +1068,8 @@ pub fn doFlush(this: *PostgresSQLConnection, _: *jsc.JSGlobalObject, _: *jsc.Cal } fn close(this: *@This()) void { + // Reject any pending awaitWritable promise before tearing down the socket. + this.rejectAwaitWritable(this.globalObject, "Connection closed"); this.disconnect(); this.unregisterAutoFlusher(); this.write_buffer.clearAndFree(bun.default_allocator); @@ -872,6 +1081,403 @@ pub fn doClose(this: *@This(), globalObject: *jsc.JSGlobalObject, _: *jsc.CallFr return .js_undefined; } +/// Helper: send COPY data from a JSValue (string or ArrayBuffer/TypedArray) +pub fn copySendDataFromJSValue(this: *PostgresSQLConnection, globalObject: *jsc.JSGlobalObject, data_value: jsc.JSValue) bun.JSError!void { + // Validate connection state + if (this.status != .connected) { + return globalObject.throw("Cannot send COPY data: connection is {s}. Ensure the connection is open before sending COPY data.", .{@tagName(this.status)}); + } + if (this.copy_state != .copy_in_progress) { + return globalObject.throw("Cannot send COPY data: not in COPY FROM STDIN mode (current state: {s}). You must execute a 'COPY ... FROM STDIN' command first.", .{@tagName(this.copy_state)}); + } + + // Enforce COPY timeout for COPY FROM as well (0 disables COPY timeout). + if (this.copy_timeout_ms > 0 and this.copy_start_timestamp_ms > 0) { + const now = std.time.milliTimestamp(); + const elapsed = @as(u64, @intCast(now)) -| this.copy_start_timestamp_ms; + const timeout_u64: u64 = @intCast(this.copy_timeout_ms); + if (elapsed > timeout_u64) { + this.abortCopyAndFailConnection(error.CopyTimeout, "COPY aborted: timeout"); + return globalObject.throw("COPY aborted: timeout", .{}); + } + } + + const max_copy_buffer_size = this.effectiveMaxCopyBufferSize(); + + // Extract payload as bytes (ArrayBuffer/TypedArray) or UTF-8 from string. + // IMPORTANT: When converting a string to UTF-8, the resulting slice is only valid + // while the UTF-8 buffer is alive. Write CopyData inside the same scope. + if (data_value.asArrayBuffer(globalObject)) |buf| { + const slice = buf.byteSlice(); + + // Guard against excessively large chunks (0 disables limit) + if (slice.len > max_copy_buffer_size) { + return globalObject.throw("COPY data chunk too large: {d} bytes exceeds maximum of {d} bytes. Consider sending smaller chunks.", .{ slice.len, max_copy_buffer_size }); + } + + // Write CopyData + var copy_data = protocol.CopyData{ + .data = .{ .temporary = slice }, + }; + copy_data.writeInternal(PostgresSQLConnection.Writer, this.writer()) catch |err| { + this.abortCopyAndFailConnection(error.CopyWriteFailed, "COPY aborted: write failed"); + return globalObject.throw("Failed to send COPY data ({d} bytes): {s}. The connection may have been closed or the socket buffer may be full.", .{ slice.len, @errorName(err) }); + }; + this.flushData(); + + // Progress tracking (saturating add) + this.copy_bytes_transferred = this.copy_bytes_transferred +| @as(u64, @intCast(slice.len)); + this.copy_chunks_processed = this.copy_chunks_processed +| 1; + } else { + const data_str = try data_value.toBunString(globalObject); + defer data_str.deref(); + + var data_utf8 = data_str.toUTF8(bun.default_allocator); + defer data_utf8.deinit(); + + const slice = data_utf8.slice(); + + // Guard against excessively large chunks (0 disables limit) + if (slice.len > max_copy_buffer_size) { + return globalObject.throw("COPY data chunk too large: {d} bytes exceeds maximum of {d} bytes. Consider sending smaller chunks.", .{ slice.len, max_copy_buffer_size }); + } + + // Write CopyData while `data_utf8` is still alive. + var copy_data = protocol.CopyData{ + .data = .{ .temporary = slice }, + }; + copy_data.writeInternal(PostgresSQLConnection.Writer, this.writer()) catch |err| { + this.abortCopyAndFailConnection(error.CopyWriteFailed, "COPY aborted: write failed"); + return globalObject.throw("Failed to send COPY data ({d} bytes): {s}. The connection may have been closed or the socket buffer may be full.", .{ slice.len, @errorName(err) }); + }; + this.flushData(); + + // Progress tracking (saturating add) + this.copy_bytes_transferred = this.copy_bytes_transferred +| @as(u64, @intCast(slice.len)); + this.copy_chunks_processed = this.copy_chunks_processed +| 1; + } +} + +/// Helper: send COPY done (validates state) +fn copySendDone(this: *PostgresSQLConnection, globalObject: *jsc.JSGlobalObject) bun.JSError!void { + // Validate connection state + if (this.status != .connected) { + return globalObject.throw("Cannot send COPY done: connection is {s}. The connection must be open to complete the COPY operation.", .{@tagName(this.status)}); + } + if (this.copy_state != .copy_in_progress) { + return globalObject.throw("Cannot send COPY done: not in COPY FROM STDIN mode (current state: {s}). You must be in an active COPY FROM STDIN operation.", .{@tagName(this.copy_state)}); + } + + this.writer().write(&protocol.CopyDone) catch |err| { + this.abortCopyAndFailConnection(error.CopyWriteFailed, "COPY aborted: write failed"); + return globalObject.throw("Failed to send COPY done signal: {s}. This may indicate a network error or closed connection.", .{@errorName(err)}); + }; + this.flushData(); +} + +/// Helper: send COPY fail with a message from a JSValue +pub fn copySendFailFromJSValue(this: *PostgresSQLConnection, globalObject: *jsc.JSGlobalObject, message_value: jsc.JSValue) bun.JSError!void { + // Validate connection state + if (this.status != .connected) { + return globalObject.throw("Cannot send COPY fail: connection is {s}. The connection must be open to abort the COPY operation.", .{@tagName(this.status)}); + } + if (this.copy_state != .copy_in_progress) { + return globalObject.throw("Cannot send COPY fail: not in COPY FROM STDIN mode (current state: {s}). You must be in an active COPY FROM STDIN operation to abort it.", .{@tagName(this.copy_state)}); + } + + if (!message_value.isEmptyOrUndefinedOrNull()) { + const message_string = try message_value.toBunString(globalObject); + defer message_string.deref(); + + var message = message_string.toUTF8(bun.default_allocator); + defer message.deinit(); + + var fail_message = protocol.CopyFail{ + .message = .{ .temporary = message.slice() }, + }; + fail_message.writeInternal(PostgresSQLConnection.Writer, this.writer()) catch |err| { + this.abortCopyAndFailConnection(error.CopyWriteFailed, "COPY aborted: write failed"); + return globalObject.throw("Failed to send COPY fail message to server: {s}. The COPY operation may have already ended or the connection may be closed.", .{@errorName(err)}); + }; + this.flushData(); + } else { + var fail_message = protocol.CopyFail{ + .message = .{ .temporary = "" }, + }; + fail_message.writeInternal(PostgresSQLConnection.Writer, this.writer()) catch |err| { + this.abortCopyAndFailConnection(error.CopyWriteFailed, "COPY aborted: write failed"); + return globalObject.throw("Failed to send COPY fail message to server: {s}. The COPY operation may have already ended or the connection may be closed.", .{@errorName(err)}); + }; + this.flushData(); + } + + // Clean up all COPY state + this.cleanupCopyState(); +} + +/// Public: PostgresSQLConnection.sendCopyData(data) +pub fn sendCopyData(this: *PostgresSQLConnection, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!jsc.JSValue { + const args = callframe.arguments(); + if (args.len < 1) { + return globalObject.throwNotEnoughArguments("sendCopyData", 1, args.len); + } + try this.copySendDataFromJSValue(globalObject, args[0]); + return .js_undefined; +} + +/// Public: PostgresSQLConnection.sendCopyDone() +pub fn sendCopyDone(this: *PostgresSQLConnection, globalObject: *jsc.JSGlobalObject, _: *jsc.CallFrame) bun.JSError!jsc.JSValue { + try this.copySendDone(globalObject); + return .js_undefined; +} + +/// Public: PostgresSQLConnection.sendCopyFail(message?) +pub fn sendCopyFail(this: *PostgresSQLConnection, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!jsc.JSValue { + const args = callframe.arguments(); + const message_value: jsc.JSValue = if (args.len > 0) args[0] else .js_undefined; + try this.copySendFailFromJSValue(globalObject, message_value); + return .js_undefined; +} + +/// Clean up all COPY protocol state +/// +/// This function is called in the following scenarios: +/// - Normal completion: After CommandComplete is received and data is returned to JS +/// - Error during COPY: When ErrorResponse is received during an active COPY operation +/// - Connection failure: When the connection is closed or fails during COPY +/// - Write failure: When sending CopyData, CopyDone, or CopyFail fails +/// - State validation failure: When concurrent COPY operations are detected +/// +/// This function is idempotent and safe to call multiple times. +fn abortCopyAndFailConnection(this: *PostgresSQLConnection, err: AnyPostgresError, comptime message: [:0]const u8) void { + // Reject the query that owns this COPY operation (preferred) or fall back to the current request. + if (this.copy_owner) |request| { + this.finishRequest(request); + request.onError(.{ .postgres_error = err }, this.globalObject); + } else if (this.current()) |request| { + this.finishRequest(request); + request.onError(.{ .postgres_error = err }, this.globalObject); + } + + this.cleanupCopyState(); + this.fail(message, err); +} + +fn cleanupCopyState(this: *PostgresSQLConnection) void { + // Early exit if already cleaned up + if (this.copy_state == .none and + this.copy_column_formats.len == 0 and + this.copy_data_buffer.items.len == 0 and + this.copy_owner == null) + { + return; + } + + debug("cleanupCopyState: state={s} bytes={} chunks={}", .{ + @tagName(this.copy_state), + this.copy_bytes_transferred, + this.copy_chunks_processed, + }); + + // Reset state flags + this.copy_state = .none; + this.copy_format = 0; + + // Clear COPY owner + this.copy_owner = null; + + // Free column formats array if allocated + if (this.copy_column_formats.len > 0) { + bun.default_allocator.free(this.copy_column_formats); + this.copy_column_formats = &.{}; + } + + // Clear data buffer and shrink if it grew too large + const buffer_capacity = this.copy_data_buffer.capacity; + this.copy_data_buffer.clearRetainingCapacity(); + + // If buffer capacity exceeds threshold, shrink it to save memory + if (buffer_capacity > COPY_BUFFER_SHRINK_THRESHOLD) { + debug("cleanupCopyState: shrinking buffer from {} to 0", .{buffer_capacity}); + this.copy_data_buffer.clearAndFree(); + } + + // Reset progress counters + this.copy_bytes_transferred = 0; + this.copy_chunks_processed = 0; + // Reset streaming mode and callback flags + this.copy_streaming_mode = false; + this.copy_callback_in_progress = false; + + // Reset timeout tracking + this.copy_start_timestamp_ms = 0; + + // Reset binary validation flag + this.copy_binary_header_validated = false; +} + +/// Helper to initialize COPY state for COPY FROM (is_out=false) or COPY TO (is_out=true) +fn startCopy(this: *PostgresSQLConnection, overall_format: u8, column_format_codes: []const u16, is_out: bool) AnyPostgresError!void { + // Prevent concurrent COPY operations + if (this.copy_state != .none) { + this.cleanupCopyState(); + return error.UnexpectedMessage; + } + + // Record which query owns this COPY operation (deterministic rejection on failure) + this.copy_owner = this.current() orelse return error.ExpectedRequest; + + // Duplicate column formats up-front + const new_column_formats = bun.default_allocator.dupe(u16, column_format_codes) catch |err| { + return err; + }; + + // Update state + this.copy_state = if (is_out) .copy_out_progress else .copy_in_progress; + this.copy_format = overall_format; + this.copy_start_timestamp_ms = @intCast(std.time.milliTimestamp()); + + // Replace column formats + if (this.copy_column_formats.len > 0) { + bun.default_allocator.free(this.copy_column_formats); + } + this.copy_column_formats = new_column_formats; + + // Reset binary header validation; clear buffer for COPY TO + this.copy_binary_header_validated = false; + if (is_out) { + this.copy_data_buffer.clearRetainingCapacity(); + } + + // Fire onCopyStart callback if registered + var vm = jsc.VirtualMachine.get(); + if (vm.rareData().postgresql_context.onCopyStartFn.get()) |callback| { + const event_loop = vm.eventLoop(); + event_loop.runCallback(callback, this.globalObject, this.js_value, &.{}); + + if (this.globalObject.hasException()) { + this.cleanupCopyState(); + this.fail("onCopyStart callback threw an exception", error.JSError); + this.globalObject.reportActiveExceptionAsUnhandled(error.JSError); + return error.JSError; + } + } + + // Tightened semantics: + // For COPY TO, if streaming mode was requested but JavaScript did not register a per-connection + // chunk handler, fall back to accumulation to avoid silently discarding data. + if (is_out and this.copy_streaming_mode and !this.copy_chunk_handler_registered) { + this.copy_streaming_mode = false; + } +} + +fn emitChunkToJS(this: *PostgresSQLConnection, data: []const u8) AnyPostgresError!void { + var vm = jsc.VirtualMachine.get(); + if (vm.rareData().postgresql_context.onCopyChunkFn.get()) |callback| { + this.copy_callback_in_progress = true; + defer this.copy_callback_in_progress = false; + + const loop = vm.eventLoop(); + var js_chunk: jsc.JSValue = .zero; + + if (this.copy_format == 0) { + js_chunk = bun.String.createUTF8ForJS(this.globalObject, data) catch |e| { + this.cleanupCopyState(); + this.globalObject.reportActiveExceptionAsUnhandled(e); + this.fail("Failed to create chunk data for COPY callback", error.OutOfMemory); + return error.OutOfMemory; + }; + } else { + js_chunk = jsc.ArrayBuffer.create(this.globalObject, data, .ArrayBuffer) catch |e| { + this.cleanupCopyState(); + this.globalObject.reportActiveExceptionAsUnhandled(e); + this.fail("Failed to create chunk data for COPY callback", error.OutOfMemory); + return error.OutOfMemory; + }; + } + + loop.runCallback(callback, this.globalObject, this.js_value, &.{js_chunk}); + + if (this.globalObject.hasException()) { + this.cleanupCopyState(); + this.fail("COPY chunk callback threw an exception", error.JSError); + this.globalObject.reportActiveExceptionAsUnhandled(error.JSError); + return error.JSError; + } + } +} + +fn flushBufferedChunkToJS(this: *PostgresSQLConnection) AnyPostgresError!void { + if (this.copy_data_buffer.items.len == 0) return; + try this.emitChunkToJS(this.copy_data_buffer.items); + this.copy_data_buffer.clearRetainingCapacity(); +} + +fn finishCopy(this: *PostgresSQLConnection, request: *PostgresSQLQuery, command_tag_str: []const u8) AnyPostgresError!void { + debug("finishCopy: state={s} bytes={}", .{ @tagName(this.copy_state), this.copy_data_buffer.items.len }); + + // For COPY TO (copy_out_progress), emit any pending buffered data (streaming mode) and onCopyEnd callback. + if (this.copy_state == .copy_out_progress) { + // Streaming-mode binary guard: ensure header was validated before any flush/callback/early-return. + if (this.copy_streaming_mode and this.copy_format == 1 and !this.copy_binary_header_validated) { + debug("finishCopy: streaming binary COPY completed without validated header", .{}); + this.cleanupCopyState(); + this.fail("Binary COPY operation completed without valid header signature", error.InvalidBinaryData); + return error.InvalidBinaryData; + } + + // Late flush of any pending buffered data (streaming mode) + if (this.copy_streaming_mode and this.copy_data_buffer.items.len > 0) { + try this.flushBufferedChunkToJS(); + } + + // Emit streaming end callback if registered + var vm = jsc.VirtualMachine.get(); + if (vm.rareData().postgresql_context.onCopyEndFn.get()) |callback_end| { + const loop = vm.eventLoop(); + loop.runCallback(callback_end, this.globalObject, this.js_value, &.{}); + } + + if (this.copy_streaming_mode) { + // In streaming mode, do not return accumulated buffer (we did not accumulate). + this.cleanupCopyState(); + request.onResult(command_tag_str, this.globalObject, this.js_value, false); + return; + } + + // Binary COPY header validation guard: ensure header was validated before returning data + if (this.copy_format == 1 and !this.copy_binary_header_validated) { + debug("finishCopy: binary COPY completed without validated header", .{}); + this.cleanupCopyState(); + this.fail("Binary COPY operation completed without valid header signature", error.InvalidBinaryData); + return error.InvalidBinaryData; + } + + const max_copy_buffer_size = this.effectiveMaxCopyBufferSize(); + + // Non-streaming: pass COPY TO accumulated data to JavaScript (even if empty), with safety guard + if (this.copy_data_buffer.items.len > max_copy_buffer_size) { + const err_msg = std.fmt.allocPrint( + bun.default_allocator, + "COPY buffer exceeded limit at completion: {d} bytes (limit: {d} bytes)", + .{ this.copy_data_buffer.items.len, max_copy_buffer_size }, + ) catch "COPY buffer too large"; + defer if (err_msg.ptr != "COPY buffer too large".ptr) bun.default_allocator.free(err_msg); + this.cleanupCopyState(); + this.fail(err_msg, error.CopyBufferTooLarge); + return error.CopyBufferTooLarge; + } + + // onCopyResult will convert buffer -> JS value, cleanup copy state, and call onResult + this.onCopyResult(request, command_tag_str); + return; + } + + // For COPY FROM (copy_in_progress) or unknown/none, cleanup and complete + this.cleanupCopyState(); + request.onResult(command_tag_str, this.globalObject, this.js_value, false); +} + pub fn stopTimers(this: *PostgresSQLConnection) void { if (this.timer.state == .ACTIVE) { this.vm.timer.remove(&this.timer); @@ -894,6 +1500,12 @@ pub fn deinit(this: *@This()) void { this.read_buffer.deinit(bun.default_allocator); this.backend_parameters.deinit(); + // Clean up COPY state + if (this.copy_column_formats.len > 0) { + bun.default_allocator.free(this.copy_column_formats); + } + this.copy_data_buffer.deinit(); + bun.freeSensitive(bun.default_allocator, this.options_buf); this.tls_config.deinit(); @@ -901,6 +1513,9 @@ pub fn deinit(this: *@This()) void { } fn cleanUpRequests(this: *@This(), js_reason: ?jsc.JSValue) void { + // Ensure COPY state is cleaned up when clearing all requests + this.cleanupCopyState(); + while (this.current()) |request| { switch (request.status) { // pending we will fail the request and the stmt will be marked as error ConnectionClosed too @@ -952,6 +1567,7 @@ fn refAndClose(this: *@This(), js_reason: ?jsc.JSValue) void { } pub fn disconnect(this: *@This()) void { + this.rejectAwaitWritable(this.globalObject, "Connection disconnected"); this.stopTimers(); this.unregisterAutoFlusher(); if (this.status == .connected) { @@ -960,6 +1576,77 @@ pub fn disconnect(this: *@This()) void { } } +fn onCopyResult(this: *PostgresSQLConnection, request: *PostgresSQLQuery, command_tag_str: []const u8) void { + // Validate we're in a valid COPY state before proceeding + if (this.copy_state == .none) { + debug("onCopyResult called but copy_state is none - this shouldn't happen", .{}); + request.onError(.{ .postgres_error = AnyPostgresError.UnexpectedMessage }, this.globalObject); + return; + } + + // Only process for copy_out_progress (COPY TO STDOUT) + if (this.copy_state != .copy_out_progress) { + debug("onCopyResult called but not in copy_out_progress state: {s}", .{@tagName(this.copy_state)}); + this.cleanupCopyState(); + request.onError(.{ .postgres_error = AnyPostgresError.UnexpectedMessage }, this.globalObject); + return; + } + + // Create a JSValue from the copy data buffer + const copy_data = this.copy_data_buffer.items; + + // For text format COPY, return as a string + // For binary format, return as ArrayBuffer + const result_value = if (this.copy_format == 0) blk: { + // Text format - return as string + break :blk bun.String.createUTF8ForJS(this.globalObject, copy_data) catch |err| { + this.cleanupCopyState(); + request.onJSError(this.globalObject.takeException(err), this.globalObject); + return; + }; + } else blk: { + // Binary format - return as ArrayBuffer + const array_buffer = jsc.ArrayBuffer.create(this.globalObject, copy_data, .ArrayBuffer) catch |err| { + this.cleanupCopyState(); + request.onJSError(this.globalObject.takeException(err), this.globalObject); + return; + }; + break :blk array_buffer; + }; + + // Get the existing pending value (SQLResultArray) and push the COPY data into it + const thisValue = request.thisValue.tryGet() orelse return; + const pending_value = PostgresSQLQuery.js.pendingValueGetCached(thisValue) orelse .zero; + + if (pending_value != .zero) { + // Push the COPY data as the first (and only) element in the result array + pending_value.push(this.globalObject, result_value) catch |err| { + this.cleanupCopyState(); + request.onJSError(this.globalObject.takeException(err), this.globalObject); + return; + }; + } else { + // No pending array yet: create a new SQLResultArray, push the result, and cache it + const new_array = jsc.JSValue.createEmptyArray(this.globalObject, 0) catch |err| { + this.cleanupCopyState(); + request.onJSError(this.globalObject.takeException(err), this.globalObject); + return; + }; + new_array.push(this.globalObject, result_value) catch |err| { + this.cleanupCopyState(); + request.onJSError(this.globalObject.takeException(err), this.globalObject); + return; + }; + PostgresSQLQuery.js.pendingValueSetCached(thisValue, this.globalObject, new_array); + } + + // Clear COPY state before completing the request + this.cleanupCopyState(); + + // Call onResult to complete the query + request.onResult(command_tag_str, this.globalObject, this.js_value, false); +} + fn current(this: *PostgresSQLConnection) ?*PostgresSQLQuery { if (this.requests.readableLength() == 0) { return null; @@ -1432,7 +2119,182 @@ pub fn on(this: *PostgresSQLConnection, comptime MessageType: @Type(.enum_litera .CopyData => { var copy_data: protocol.CopyData = undefined; try copy_data.decodeInternal(Context, reader); - copy_data.data.deinit(); + defer copy_data.data.deinit(); + + if (this.copy_state == .copy_out_progress) { + // COPY TO STDOUT + const data_slice = copy_data.data.slice(); + debug("CopyData: received {} bytes", .{data_slice.len}); + + // Check COPY operation timeout + if (this.copy_timeout_ms > 0 and this.copy_start_timestamp_ms > 0) { + const now = std.time.milliTimestamp(); + const elapsed = @as(u64, @intCast(now)) -| this.copy_start_timestamp_ms; + const timeout_u64: u64 = @intCast(this.copy_timeout_ms); + if (elapsed > timeout_u64) { + debug("CopyData: timeout after {}ms (limit: {}ms)", .{ elapsed, timeout_u64 }); + this.abortCopyAndFailConnection(error.CopyTimeout, "COPY aborted: timeout"); + return error.CopyTimeout; + } + } + + // Validate/accumulate binary COPY header (supports fragmented first chunks) + if (this.copy_format == 1 and !this.copy_binary_header_validated) { + if (this.copy_streaming_mode) { + const max_copy_buffer_size = this.effectiveMaxCopyBufferSize(); + + // In streaming mode, buffer until we have at least the signature, then validate and emit buffered bytes + if (data_slice.len > max_copy_buffer_size) { + this.abortCopyAndFailConnection(error.CopyBufferTooLarge, "COPY aborted: buffer limit exceeded"); + return error.CopyBufferTooLarge; + } + const new_total_stream = this.copy_data_buffer.items.len + data_slice.len; + if (new_total_stream > max_copy_buffer_size) { + this.abortCopyAndFailConnection(error.CopyBufferTooLarge, "COPY aborted: buffer limit exceeded"); + return error.CopyBufferTooLarge; + } + this.copy_data_buffer.appendSlice(data_slice) catch |err| { + this.cleanupCopyState(); + return err; + }; + + if (this.copy_data_buffer.items.len < COPY_BINARY_SIGNATURE.len) { + // Not enough bytes yet; just track progress and wait for more + this.copy_bytes_transferred = this.copy_bytes_transferred +| @as(u64, @intCast(data_slice.len)); + this.copy_chunks_processed = this.copy_chunks_processed +| 1; + return; + } + + const has_valid_signature = std.mem.eql(u8, this.copy_data_buffer.items[0..COPY_BINARY_SIGNATURE.len], ©_BINARY_SIGNATURE); + if (!has_valid_signature) { + debug("CopyData: invalid binary COPY signature", .{}); + this.abortCopyAndFailConnection(error.InvalidBinaryData, "COPY aborted: invalid binary format"); + return error.InvalidBinaryData; + } + this.copy_binary_header_validated = true; + + // If a chunk callback is running, keep buffering and return + if (this.copy_callback_in_progress) { + this.copy_bytes_transferred = this.copy_bytes_transferred +| @as(u64, @intCast(data_slice.len)); + this.copy_chunks_processed = this.copy_chunks_processed +| 1; + return; + } + + // Emit the buffered header+data as a single chunk + try this.emitChunkToJS(this.copy_data_buffer.items); + + // Clear buffered header/data after emission and update progress + this.copy_data_buffer.clearRetainingCapacity(); + this.copy_bytes_transferred = this.copy_bytes_transferred +| @as(u64, @intCast(data_slice.len)); + this.copy_chunks_processed = this.copy_chunks_processed +| 1; + return; + } else { + // Non-streaming: allow fragmented first chunk; validate once we have enough bytes across buffer + incoming chunk + const sig_len: usize = COPY_BINARY_SIGNATURE.len; + const buffered_len: usize = this.copy_data_buffer.items.len; + if (buffered_len == 0) { + // Fast-path: entire signature is in this chunk + if (data_slice.len >= sig_len) { + const has_valid_signature = std.mem.eql(u8, data_slice[0..sig_len], ©_BINARY_SIGNATURE); + if (!has_valid_signature) { + debug("CopyData: invalid binary COPY signature", .{}); + this.abortCopyAndFailConnection(error.InvalidBinaryData, "COPY aborted: invalid binary format"); + return error.InvalidBinaryData; + } + this.copy_binary_header_validated = true; + } + } else if (buffered_len < sig_len and buffered_len + data_slice.len >= sig_len) { + // Signature split across previous buffer and this chunk; stitch minimal prefix into scratch and validate + var scratch: [COPY_BINARY_SIGNATURE.len]u8 = undefined; + // Copy already-buffered prefix + @memcpy(scratch[0..buffered_len], this.copy_data_buffer.items[0..buffered_len]); + // Copy needed bytes from the head of the new chunk + const need: usize = sig_len - buffered_len; + @memcpy(scratch[buffered_len .. buffered_len + need], data_slice[0..need]); + const has_valid_signature = std.mem.eql(u8, scratch[0..sig_len], ©_BINARY_SIGNATURE); + if (!has_valid_signature) { + debug("CopyData: invalid binary COPY signature (split across frames)", .{}); + this.abortCopyAndFailConnection(error.InvalidBinaryData, "COPY aborted: invalid binary format"); + return error.InvalidBinaryData; + } + this.copy_binary_header_validated = true; + } + // Otherwise, wait for next chunk to accumulate enough bytes (handled by normal buffering) + } + } + + // If a previous callback is still in progress, buffer and return safely (streaming mode) + if (this.copy_streaming_mode and this.copy_callback_in_progress) { + const max_copy_buffer_size = this.effectiveMaxCopyBufferSize(); + const new_total_pending = this.copy_data_buffer.items.len + data_slice.len; + if (new_total_pending > max_copy_buffer_size) { + this.abortCopyAndFailConnection(error.CopyBufferTooLarge, "COPY aborted: buffer limit exceeded"); + return error.CopyBufferTooLarge; + } + this.copy_data_buffer.appendSlice(data_slice) catch |err| { + this.cleanupCopyState(); + return err; + }; + this.copy_bytes_transferred = this.copy_bytes_transferred +| @as(u64, @intCast(data_slice.len)); + this.copy_chunks_processed = this.copy_chunks_processed +| 1; + return; + } + + // In streaming mode, enforce a per-chunk limit to avoid allocating huge ArrayBuffers + if (this.copy_streaming_mode) { + const effective_limit: usize = if (this.max_copy_buffer_size == 0) std.math.maxInt(usize) else this.max_copy_buffer_size; + const per_chunk_limit: usize = @min(effective_limit, 64 * 1024 * 1024); + + if (data_slice.len > per_chunk_limit) { + this.abortCopyAndFailConnection(error.CopyChunkTooLarge, "COPY aborted: chunk too large"); + return error.CopyChunkTooLarge; + } + } + + if (!this.copy_streaming_mode) { + const max_copy_buffer_size = this.effectiveMaxCopyBufferSize(); + + // Validate individual chunk size (0 disables limit) + if (data_slice.len > max_copy_buffer_size) { + this.abortCopyAndFailConnection(error.CopyBufferTooLarge, "COPY aborted: buffer limit exceeded"); + return error.CopyBufferTooLarge; + } + + // Check buffer size limit to prevent excessive memory usage + const new_total = this.copy_data_buffer.items.len + data_slice.len; + if (new_total > max_copy_buffer_size) { + this.abortCopyAndFailConnection(error.CopyBufferTooLarge, "COPY aborted: buffer limit exceeded"); + return error.CopyBufferTooLarge; + } + + this.copy_data_buffer.appendSlice(data_slice) catch |err| { + // Allocation failed - abort COPY, reject the owner, and fail the connection + this.abortCopyAndFailConnection(error.OutOfMemory, "COPY aborted: out of memory"); + return err; + }; + } + + // Track progress (with overflow protection) + this.copy_bytes_transferred = this.copy_bytes_transferred +| @as(u64, @intCast(data_slice.len)); + this.copy_chunks_processed = this.copy_chunks_processed +| @as(u64, 1); + + // Streaming mode: flush any pending buffered bytes then emit this chunk. + // - For text COPY (copy_format == 0), we can flush immediately. + // - For binary COPY (copy_format == 1), only flush once the header has been validated. + if (this.copy_streaming_mode) { + if (this.copy_data_buffer.items.len > 0 and !this.copy_callback_in_progress and (this.copy_format == 0 or this.copy_binary_header_validated)) { + try this.flushBufferedChunkToJS(); + } + try this.emitChunkToJS(data_slice); + } + } else if (this.copy_state == .copy_in_progress) { + // For COPY FROM STDIN, we shouldn't receive CopyData from server + debug("CopyData: unexpected in copy_in_progress state", .{}); + this.abortCopyAndFailConnection(error.UnexpectedCopyData, "COPY aborted: unexpected server data"); + return error.UnexpectedCopyData; + } else { + debug("CopyData: received outside COPY operation", .{}); + } }, .ParameterStatus => { var parameter_status: protocol.ParameterStatus = undefined; @@ -1474,7 +2336,13 @@ pub fn on(this: *PostgresSQLConnection, comptime MessageType: @Type(.enum_litera debug("-> {s}", .{cmd.command_tag.slice()}); defer this.updateRef(); - request.onResult(cmd.command_tag.slice(), this.globalObject, this.js_value, false); + // Check if this is completing a COPY operation + if (this.copy_state != .none) { + debug("CommandComplete: COPY operation completed with {} bytes", .{this.copy_data_buffer.items.len}); + try this.finishCopy(request, cmd.command_tag.slice()); + } else { + request.onResult(cmd.command_tag.slice(), this.globalObject, this.js_value, false); + } }, .BindComplete => { try reader.eatMessage(protocol.BindComplete); @@ -1724,16 +2592,20 @@ pub fn on(this: *PostgresSQLConnection, comptime MessageType: @Type(.enum_litera return; } - var request = this.current() orelse { + // During an active COPY operation, prefer rejecting the COPY-owning query. + // This ensures deterministic attribution of server errors during COPY. + var request = (if (this.copy_owner) |owner| owner else this.current()) orelse { debug("ErrorResponse: {f}", .{err}); return error.ExpectedRequest; }; + var is_error_owned = true; defer { if (is_error_owned) { err.deinit(); } } + if (request.statement) |stmt| { if (stmt.status == PostgresSQLStatement.Status.parsing) { stmt.status = PostgresSQLStatement.Status.failed; @@ -1748,6 +2620,13 @@ pub fn on(this: *PostgresSQLConnection, comptime MessageType: @Type(.enum_litera this.finishRequest(request); this.updateRef(); request.onError(.{ .protocol = err }, this.globalObject); + + // Clean up COPY state if we were in the middle of a COPY operation. + // This is done after routing the error so `copy_owner` is still available. + if (this.copy_state != .none) { + debug("ErrorResponse during COPY operation - cleaning up state", .{}); + this.cleanupCopyState(); + } }, .PortalSuspended => { // try reader.eatMessage(&protocol.PortalSuspended); @@ -1762,7 +2641,14 @@ pub fn on(this: *PostgresSQLConnection, comptime MessageType: @Type(.enum_litera request.onResult("CLOSECOMPLETE", this.globalObject, this.js_value, false); }, .CopyInResponse => { - debug("TODO CopyInResponse", .{}); + var resp: protocol.CopyInResponse = undefined; + try resp.decodeInternal(Context, reader); + defer resp.deinit(); + + debug("CopyInResponse: format={} columns={}", .{ resp.overall_format(), resp.column_format_codes().len }); + // Initialize COPY FROM state + try this.startCopy(resp.overall_format(), resp.column_format_codes(), false); + debug("CopyInResponse: ready to accept COPY data", .{}); }, .NoticeResponse => { debug("UNSUPPORTED NoticeResponse", .{}); @@ -1778,13 +2664,61 @@ pub fn on(this: *PostgresSQLConnection, comptime MessageType: @Type(.enum_litera request.onResult("", this.globalObject, this.js_value, false); }, .CopyOutResponse => { - debug("TODO CopyOutResponse", .{}); + var resp: protocol.CopyOutResponse = undefined; + try resp.decodeInternal(Context, reader); + defer resp.deinit(); + + debug("CopyOutResponse: format={} columns={}", .{ resp.overall_format(), resp.column_format_codes().len }); + // Initialize COPY TO state + try this.startCopy(resp.overall_format(), resp.column_format_codes(), true); + debug("CopyOutResponse: ready to stream COPY data", .{}); }, .CopyDone => { - debug("TODO CopyDone", .{}); + try reader.eatMessage(protocol.CopyDone); + + debug("CopyDone: received {} bytes total", .{this.copy_data_buffer.items.len}); + + const max_copy_buffer_size = this.effectiveMaxCopyBufferSize(); + + // Safety guard: if not streaming and accumulated buffer somehow exceeds limit, abort + if (!this.copy_streaming_mode and this.copy_data_buffer.items.len > max_copy_buffer_size) { + const err_msg = std.fmt.allocPrint( + bun.default_allocator, + "COPY buffer exceeded limit at end: {d} bytes (limit: {d} bytes)", + .{ this.copy_data_buffer.items.len, max_copy_buffer_size }, + ) catch "COPY buffer too large"; + defer if (err_msg.ptr != "COPY buffer too large".ptr) bun.default_allocator.free(err_msg); + this.cleanupCopyState(); + this.fail(err_msg, error.CopyBufferTooLarge); + return error.CopyBufferTooLarge; + } + + // Validate we're in the correct state + if (this.copy_state != .copy_out_progress) { + debug("CopyDone: unexpected - not in copy_out_progress state (current: {s})", .{@tagName(this.copy_state)}); + this.cleanupCopyState(); + this.fail("Received CopyDone from server but not in COPY TO STDOUT operation", error.UnexpectedMessage); + return error.UnexpectedMessage; + } + + _ = this.current() orelse return error.ExpectedRequest; + + // Keep copy_state active - it will be cleared in CommandComplete + // The accumulated data will be returned when CommandComplete arrives + debug("CopyDone: waiting for CommandComplete", .{}); }, .CopyBothResponse => { - debug("TODO CopyBothResponse", .{}); + var resp: protocol.CopyBothResponse = undefined; + try resp.decodeInternal(Context, reader); + defer resp.deinit(); + + debug("CopyBothResponse: format={} columns={} (streaming replication)", .{ resp.overall_format(), resp.column_format_codes().len }); + + // CopyBothResponse is used for streaming replication + // Not implemented yet + this.cleanupCopyState(); + this.fail("CopyBoth (streaming replication) is not implemented", error.CopyBothNotImplemented); + return error.CopyBothNotImplemented; }, else => @compileError("Unknown message type: " ++ @tagName(MessageType)), } diff --git a/src/sql/postgres/PostgresSQLContext.zig b/src/sql/postgres/PostgresSQLContext.zig index 8982f17d6ee..e93a09770be 100644 --- a/src/sql/postgres/PostgresSQLContext.zig +++ b/src/sql/postgres/PostgresSQLContext.zig @@ -2,11 +2,19 @@ tcp: ?*uws.SocketContext = null, onQueryResolveFn: jsc.Strong.Optional = .empty, onQueryRejectFn: jsc.Strong.Optional = .empty, +onCopyStartFn: jsc.Strong.Optional = .empty, +onCopyChunkFn: jsc.Strong.Optional = .empty, +onCopyEndFn: jsc.Strong.Optional = .empty, +onWritableFn: jsc.Strong.Optional = .empty, pub fn init(globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!jsc.JSValue { var ctx = &globalObject.bunVM().rareData().postgresql_context; ctx.onQueryResolveFn.set(globalObject, callframe.argument(0)); ctx.onQueryRejectFn.set(globalObject, callframe.argument(1)); + ctx.onCopyStartFn.set(globalObject, callframe.argument(2)); + ctx.onCopyChunkFn.set(globalObject, callframe.argument(3)); + ctx.onCopyEndFn.set(globalObject, callframe.argument(4)); + ctx.onWritableFn.set(globalObject, callframe.argument(5)); return .js_undefined; } diff --git a/src/sql/postgres/protocol/CopyBothResponse.zig b/src/sql/postgres/protocol/CopyBothResponse.zig new file mode 100644 index 00000000000..c553d063fba --- /dev/null +++ b/src/sql/postgres/protocol/CopyBothResponse.zig @@ -0,0 +1,3 @@ +/// PostgreSQL COPY BOTH response message (used for replication). +/// Uses shared CopyResponse implementation. +pub const CopyBothResponse = @import("./CopyResponse.zig"); diff --git a/src/sql/postgres/protocol/CopyData.zig b/src/sql/postgres/protocol/CopyData.zig index ca26782a8d8..3259467e825 100644 --- a/src/sql/postgres/protocol/CopyData.zig +++ b/src/sql/postgres/protocol/CopyData.zig @@ -5,7 +5,7 @@ data: Data = .{ .empty = {} }, pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { const length = try reader.length(); - const data = try reader.read(@intCast(length -| 5)); + const data = try reader.read(@intCast(length -| 4)); this.* = .{ .data = data, }; @@ -19,12 +19,12 @@ pub fn writeInternal( writer: NewWriter(Context), ) !void { const data = this.data.slice(); - const count: u32 = @sizeOf((u32)) + data.len + 1; + const count: u32 = @sizeOf(u32) + @as(u32, @intCast(data.len)); const header = [_]u8{ 'd', } ++ toBytes(Int32(count)); try writer.write(&header); - try writer.string(data); + try writer.write(data); } pub const write = WriteWrap(@This(), writeInternal).write; diff --git a/src/sql/postgres/protocol/CopyFail.zig b/src/sql/postgres/protocol/CopyFail.zig index 4904346662a..071f301a946 100644 --- a/src/sql/postgres/protocol/CopyFail.zig +++ b/src/sql/postgres/protocol/CopyFail.zig @@ -19,7 +19,7 @@ pub fn writeInternal( writer: NewWriter(Context), ) !void { const message = this.message.slice(); - const count: u32 = @sizeOf((u32)) + message.len + 1; + const count: u32 = @sizeOf(u32) + @as(u32, @intCast(message.len)) + 1; const header = [_]u8{ 'f', } ++ toBytes(Int32(count)); diff --git a/src/sql/postgres/protocol/CopyInResponse.zig b/src/sql/postgres/protocol/CopyInResponse.zig index 9654855d8e7..924d0d4c670 100644 --- a/src/sql/postgres/protocol/CopyInResponse.zig +++ b/src/sql/postgres/protocol/CopyInResponse.zig @@ -1,13 +1,3 @@ -const CopyInResponse = @This(); - -pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - _ = reader; - _ = this; - bun.Output.panic("TODO: not implemented {s}", .{bun.meta.typeBaseName(@typeName(@This()))}); -} - -pub const decode = DecoderWrap(CopyInResponse, decodeInternal).decode; - -const bun = @import("bun"); -const DecoderWrap = @import("./DecoderWrap.zig").DecoderWrap; -const NewReader = @import("./NewReader.zig").NewReader; +/// PostgreSQL COPY IN response message (COPY FROM STDIN). +/// Uses shared CopyResponse implementation. +pub const CopyInResponse = @import("./CopyResponse.zig"); diff --git a/src/sql/postgres/protocol/CopyOutResponse.zig b/src/sql/postgres/protocol/CopyOutResponse.zig index dac843fff70..5f4a214b8a4 100644 --- a/src/sql/postgres/protocol/CopyOutResponse.zig +++ b/src/sql/postgres/protocol/CopyOutResponse.zig @@ -1,13 +1,3 @@ -const CopyOutResponse = @This(); - -pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - _ = reader; - _ = this; - bun.Output.panic("TODO: not implemented {s}", .{bun.meta.typeBaseName(@typeName(@This()))}); -} - -pub const decode = DecoderWrap(CopyOutResponse, decodeInternal).decode; - -const bun = @import("bun"); -const DecoderWrap = @import("./DecoderWrap.zig").DecoderWrap; -const NewReader = @import("./NewReader.zig").NewReader; +/// PostgreSQL COPY OUT response message (COPY TO STDOUT). +/// Uses shared CopyResponse implementation. +pub const CopyOutResponse = @import("./CopyResponse.zig"); diff --git a/src/sql/postgres/protocol/CopyResponse.zig b/src/sql/postgres/protocol/CopyResponse.zig new file mode 100644 index 00000000000..0cbada25987 --- /dev/null +++ b/src/sql/postgres/protocol/CopyResponse.zig @@ -0,0 +1,62 @@ +/// Shared implementation for PostgreSQL COPY response messages. +/// Used by CopyInResponse, CopyOutResponse, and CopyBothResponse which +/// share identical structure and decoding logic. +const CopyResponse = @This(); + +#overall_format: u8 = 0, +#column_format_codes: []u16 = &[_]u16{}, + +/// Returns the overall format code (0 = text, 1 = binary) +pub fn overall_format(this: *const CopyResponse) u8 { + return this.#overall_format; +} + +/// Returns the per-column format codes +pub fn column_format_codes(this: *const CopyResponse) []const u16 { + return this.#column_format_codes; +} + +pub fn deinit(this: *CopyResponse) void { + if (this.#column_format_codes.len > 0) { + bun.default_allocator.free(this.#column_format_codes); + this.#column_format_codes = &[_]u16{}; + } +} + +pub fn decodeInternal(this: *CopyResponse, comptime Container: type, reader: NewReader(Container)) !void { + this.* = .{ + .#overall_format = 0, + .#column_format_codes = &[_]u16{}, + }; + + const length_value = try reader.length(); + const payload_len: usize = if (length_value > 4) @intCast(length_value - 4) else 0; + const min_header: usize = 1 + 2; // overall_format (u8) + column_count (i16) + if (payload_len < min_header) return error.InvalidMessage; + + const format_value = try reader.int(u8); + const raw_column_count = try reader.short(); + const column_count: usize = @intCast(@max(raw_column_count, 0)); + + const max_columns = (payload_len - min_header) / 2; // each format code is int16 + if (column_count > max_columns) return error.InvalidMessage; + + const format_codes = try bun.default_allocator.alloc(u16, column_count); + errdefer bun.default_allocator.free(format_codes); + + for (format_codes) |*format_code| { + const raw = try reader.short(); + format_code.* = if (raw < 0) 0 else @intCast(raw); + } + + this.* = .{ + .#overall_format = format_value, + .#column_format_codes = format_codes, + }; +} + +pub const decode = DecoderWrap(CopyResponse, decodeInternal).decode; + +const bun = @import("bun"); +const DecoderWrap = @import("./DecoderWrap.zig").DecoderWrap; +const NewReader = @import("./NewReader.zig").NewReader; diff --git a/test/js/sql/sql-postgres-copy.test.ts b/test/js/sql/sql-postgres-copy.test.ts new file mode 100644 index 00000000000..046bb486a3c --- /dev/null +++ b/test/js/sql/sql-postgres-copy.test.ts @@ -0,0 +1,1437 @@ +import { SQL, type CopyBinaryType } from "bun"; +import { describe, test, expect, afterAll, beforeAll } from "bun:test"; +import { isDockerEnabled } from "harness"; +import * as dockerCompose from "../../docker/index.ts"; + +if (isDockerEnabled()) { + describe("PostgreSQL COPY protocol", async () => { + const info = await dockerCompose.ensure("postgres_plain"); + + const connect = () => + new SQL({ + hostname: info.host, + port: info.ports[5432], + database: "bun_sql_test", + username: "bun_sql_test", + tls: false, + max: 1, + }); + + afterAll(async () => { + if (!process.env.BUN_KEEP_DOCKER) { + await dockerCompose.down(); + } + }); + + // Phase 0: Regression tests + + test("Regression: COPY maxBytes=0 disables the limit", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_unlimited", []); + await sql.unsafe("CREATE TABLE copy_unlimited (id INT, name TEXT)", []); + + const rowCount = 2500; + const name = "x".repeat(80); + + async function* genRows() { + for (let i = 0; i < rowCount; i++) { + yield `${i}\t${name}\n`; + } + } + + const copyRes = await sql.copyFrom("copy_unlimited", ["id", "name"], genRows(), { + format: "text", + maxBytes: 0, + }); + + expect(copyRes?.command).toBe("COPY"); + expect(copyRes?.count).toBe(rowCount); + + const verify = await sql`SELECT COUNT(*)::int AS count FROM copy_unlimited`; + expect(verify[0]?.count).toBe(rowCount); + }); + + test("Regression: ErrorResponse during COPY rejects the correct COPY query even with another query queued", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_error_owner", []); + await sql.unsafe("CREATE TABLE copy_error_owner (id INT NOT NULL, name TEXT)", []); + + async function* invalidRows() { + yield "1\tok\n"; + // Invalid for NOT NULL id, will trigger an ErrorResponse during COPY + yield "\\N\tbad\n"; + } + + const copyPromise = sql.copyFrom("copy_error_owner", ["id", "name"], invalidRows(), { format: "text" }); + + // Queue another query while COPY is in progress so ErrorResponse routing must prefer copy_owner. + const otherQueryPromise = sql`SELECT 123::int AS v`; + + let copyFailed = false; + try { + await copyPromise; + } catch { + copyFailed = true; + } + expect(copyFailed).toBe(true); + + // The non-COPY query should still succeed, proving the error was attributed to the COPY request. + const otherQueryResult = await otherQueryPromise; + expect(otherQueryResult[0]?.v).toBe(123); + + // Ensure COPY did not partially insert. + const verify = await sql`SELECT COUNT(*)::int AS count FROM copy_error_owner`; + expect(verify[0]?.count).toBe(0); + }); + + test("Regression: copyTo falls back to accumulation when streaming is requested but no chunk handler is registered", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_stream_fallback", []); + await sql.unsafe("CREATE TABLE copy_stream_fallback (id INT, name TEXT)", []); + await sql.unsafe("INSERT INTO copy_stream_fallback (id, name) VALUES (1, 'Alpha'), (2, 'Beta')", []); + + // Force the "no chunk handler" path by temporarily disabling the handler registration methods. + const originalOnCopyChunk = (sql as any).onCopyChunk; + const originalOnCopyEnd = (sql as any).onCopyEnd; + (sql as any).onCopyChunk = undefined; + (sql as any).onCopyEnd = undefined; + + try { + // Request streaming, but the implementation should fall back to accumulation. + // Accumulation may still yield more than one chunk depending on internal buffering, + // so we validate that concatenation produces a single correct payload. + const iterable = await sql.copyTo({ + table: "copy_stream_fallback", + columns: ["id", "name"], + format: "csv", + stream: true, + }); + + const chunks: string[] = []; + for await (const chunk of iterable) { + chunks.push(typeof chunk === "string" ? chunk : new TextDecoder().decode(chunk)); + } + + expect(chunks.length).toBeGreaterThan(0); + const payload = chunks.join(""); + expect(payload.includes("Alpha")).toBe(true); + expect(payload.includes("Beta")).toBe(true); + expect(payload.length).toBeGreaterThan(0); + } finally { + (sql as any).onCopyChunk = originalOnCopyChunk; + (sql as any).onCopyEnd = originalOnCopyEnd; + } + }); + + // Phase 1: COPY TO STDOUT (Data Export) + + test("COPY TO STDOUT (text) returns a single string payload", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_users", []); + await sql.unsafe("CREATE TABLE copy_users (id INT, name TEXT)", []); + await sql.unsafe("INSERT INTO copy_users (id, name) VALUES (1, 'Alex'), (2, 'Bea')", []); + + const result = await sql`COPY copy_users TO STDOUT`; + expect(Array.isArray(result)).toBe(true); + expect(typeof result[0]).toBe("string"); + const payload = String(result[0]); + expect(payload.includes("Alex")).toBe(true); + expect(payload.includes("Bea")).toBe(true); + expect(result.command).toBe("COPY"); + expect(result.count).toBe(2); + }); + + test("COPY TO STDOUT with subquery", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_sub", []); + await sql.unsafe("CREATE TABLE copy_sub (id INT, name TEXT)", []); + await sql.unsafe("INSERT INTO copy_sub (id, name) VALUES (1, 'A'), (2, 'B')", []); + + const result = await sql`COPY (SELECT name FROM copy_sub ORDER BY id LIMIT 1) TO STDOUT`; + expect(Array.isArray(result)).toBe(true); + expect(typeof result[0]).toBe("string"); + expect(String(result[0]).trim()).toBe("A"); + }); + + test("COPY TO STDOUT (csv) returns a single string payload", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_csv", []); + await sql.unsafe("CREATE TABLE copy_csv (id INT, name TEXT)", []); + await sql.unsafe("INSERT INTO copy_csv (id, name) VALUES (10, 'Hello'), (11, 'World')", []); + + const result = await sql`COPY copy_csv TO STDOUT (FORMAT CSV)`; + expect(Array.isArray(result)).toBe(true); + expect(typeof result[0]).toBe("string"); + const payload = String(result[0]); + expect(payload.includes("10,Hello")).toBe(true); + expect(payload.includes("11,World")).toBe(true); + expect(result.command).toBe("COPY"); + expect(result.count).toBe(2); + }); + + test("COPY TO STDOUT with empty result", async () => { + await using sql = connect(); + + const result = await sql`COPY (SELECT * FROM (VALUES (1)) t(i) WHERE i = -1) TO STDOUT`; + expect(Array.isArray(result)).toBe(true); + expect(String(result[0] ?? "")).toBe(""); + }); + + // Phase 2: COPY FROM STDIN (High-level API) + + test("COPY FROM STDIN (text) with array rows", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_from_text", []); + await sql.unsafe("CREATE TABLE copy_from_text (id INT, name TEXT)", []); + + const rows: Array<[number, string]> = [ + [1, "One"], + [2, "Two"], + [3, "Three"], + ]; + const copyRes = await sql.copyFrom("copy_from_text", ["id", "name"], rows, { format: "text" }); + expect(copyRes?.command).toBe("COPY"); + expect(copyRes?.count).toBe(rows.length); + + const verify = await sql`SELECT COUNT(*)::int AS count FROM copy_from_text`; + expect(verify[0]?.count).toBe(rows.length); + }); + + test("COPY FROM STDIN (text) with raw TSV string payload", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_from_text_string", []); + await sql.unsafe("CREATE TABLE copy_from_text_string (id INT, name TEXT)", []); + const tsv = "3\tTSV User\n4\tTSV Two\n"; + const copyRes = await sql.copyFrom("copy_from_text_string", ["id", "name"], tsv, { format: "text" }); + expect(copyRes?.command).toBe("COPY"); + expect(copyRes?.count).toBe(2); + + const verify = await sql`SELECT COUNT(*)::int AS count FROM copy_from_text_string`; + expect(verify[0]?.count).toBe(2); + }); + + test("COPY FROM STDIN (text) with generator of rows", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_from_text_gen", []); + await sql.unsafe("CREATE TABLE copy_from_text_gen (id INT, name TEXT)", []); + + function* genRows() { + for (let i = 5; i <= 7; i++) { + yield [i, `Gen ${i}`] as [number, string]; + } + } + const copyRes = await sql.copyFrom("copy_from_text_gen", ["id", "name"], genRows(), { format: "text" }); + expect(copyRes?.command).toBe("COPY"); + expect(copyRes?.count).toBe(3); + + const verify = await sql`SELECT COUNT(*)::int AS count FROM copy_from_text_gen`; + expect(verify[0]?.count).toBe(3); + }); + + test("COPY FROM STDIN (text) with async iterable of rows", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_from_text_async", []); + await sql.unsafe("CREATE TABLE copy_from_text_async (id INT, name TEXT)", []); + + async function* genAsyncRows() { + for (let i = 8; i <= 10; i++) { + await Promise.resolve(); + yield [i, `Async ${i}`] as [number, string]; + } + } + const copyRes = await sql.copyFrom("copy_from_text_async", ["id", "name"], genAsyncRows(), { format: "text" }); + expect(copyRes?.command).toBe("COPY"); + expect(copyRes?.count).toBe(3); + + const verify = await sql`SELECT COUNT(*)::int AS count FROM copy_from_text_async`; + expect(verify[0]?.count).toBe(3); + }); + + test("COPY FROM STDIN (text) with async iterable of raw string chunks", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_from_chunks", []); + await sql.unsafe("CREATE TABLE copy_from_chunks (id INT, name TEXT)", []); + + async function* genRawStrings() { + yield "21\tRawOne\n"; + yield "22\tRawTwo\n"; + } + const copyRes = await sql.copyFrom("copy_from_chunks", ["id", "name"], genRawStrings(), { format: "text" }); + expect(copyRes?.command).toBe("COPY"); + expect(copyRes?.count).toBe(2); + + const verify = await sql`SELECT COUNT(*)::int AS count FROM copy_from_chunks`; + expect(verify[0]?.count).toBe(2); + }); + + test("COPY FROM STDIN (csv) with async iterable of raw Uint8Array chunks", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_from_chunks_bin", []); + await sql.unsafe("CREATE TABLE copy_from_chunks_bin (id INT, name TEXT)", []); + const enc = new TextEncoder(); + async function* genRawUint8() { + yield enc.encode("31,RawCSVOne\n"); + yield enc.encode("32,RawCSVTwo\n"); + } + const copyRes = await sql.copyFrom("copy_from_chunks_bin", ["id", "name"], genRawUint8(), { format: "csv" }); + expect(copyRes?.command).toBe("COPY"); + expect(copyRes?.count).toBe(2); + + const verify = await sql`SELECT COUNT(*)::int AS count FROM copy_from_chunks_bin`; + expect(verify[0]?.count).toBe(2); + }); + + // Phase 3: COPY TO STDOUT (Streaming API) + + test("copyTo (query form) streams chunks", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_stream_q", []); + await sql.unsafe("CREATE TABLE copy_stream_q (id INT, name TEXT)", []); + await sql.unsafe("INSERT INTO copy_stream_q (id, name) VALUES (1, 'Hello'), (2, 'World')", []); + let count = 0; + let totalLen = 0; + for await (const chunk of sql.copyTo(`COPY (SELECT id, name FROM copy_stream_q ORDER BY id) TO STDOUT`)) { + const s = typeof chunk === "string" ? chunk : new TextDecoder().decode(chunk as ArrayBuffer); + totalLen += s.length; + count++; + } + expect(count).toBeGreaterThan(0); + expect(totalLen).toBeGreaterThan(0); + }); + + test("copyTo (options, csv) streams string chunks", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_stream_opts", []); + await sql.unsafe("CREATE TABLE copy_stream_opts (id INT, name TEXT)", []); + await sql.unsafe("INSERT INTO copy_stream_opts (id, name) VALUES (1, 'Hello')", []); + let count = 0; + for await (const chunk of sql.copyTo({ + table: "copy_stream_opts", + columns: ["id", "name"], + format: "csv", + })) { + expect(typeof chunk).toBe("string"); + count++; + } + expect(count).toBeGreaterThan(0); + }); + + // Phase 3.5: Abort and Progress demos + + test("copyTo supports progress + abort", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_to_abort", []); + await sql.unsafe("CREATE TABLE copy_to_abort (id INT, name TEXT)", []); + await sql.unsafe("INSERT INTO copy_to_abort (id, name) VALUES (1, 'A'), (2, 'B'), (3, 'C')", []); + + const ac = new AbortController(); + let progressCalled = 0; + const stream = sql.copyTo({ + table: "copy_to_abort", + columns: ["id", "name"], + format: "csv", + signal: ac.signal, + onProgress: ({ bytesReceived, chunksReceived }: { bytesReceived: number; chunksReceived: number }) => { + progressCalled++; + if (chunksReceived >= 1) ac.abort(); + expect(bytesReceived).toBeGreaterThan(0); + }, + }); + + let threw = false; + try { + for await (const _ of stream) { + // consume first chunk only + break; + } + } catch { + threw = true; + } + expect(progressCalled).toBeGreaterThan(0); + expect(threw).toBe(true); + }); + + test("copyFrom backpressure waits for awaitWritable Promise", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_backpressure", []); + await sql.unsafe("CREATE TABLE copy_backpressure (id INT, val TEXT)", []); + + // Simulate backpressure by setting a very small maxBytes and sending one large chunk + const enc = new TextEncoder(); + const largeChunk = enc.encode("1\tHello World\n2\tMore Data\n".repeat(1000)); // ~18KB + + let progressCalled = 0; + let bytesSent = 0; + const res = await sql.copyFrom("copy_backpressure", ["id", "val"], [largeChunk], { + format: "text", + maxBytes: largeChunk.byteLength + 10, // allow only slightly more than one chunk + onProgress: info => { + progressCalled++; + bytesSent = info.bytesSent; + }, + }); + + expect(res?.command).toBe("COPY"); + expect(res?.count).toBeGreaterThan(0); + expect(progressCalled).toBeGreaterThan(0); + }); + + test("copyFrom supports progress + abort", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_from_abort", []); + await sql.unsafe("CREATE TABLE copy_from_abort (id INT, name TEXT)", []); + + const ac = new AbortController(); + const enc = new TextEncoder(); + async function* genManyRows() { + for (let i = 0; i < 200; i++) { + yield enc.encode(`${i},Name ${i}\n`); + } + } + let progressCalled = 0; + let threw = false; + try { + await sql.copyFrom("copy_from_abort", ["id", "name"], genManyRows(), { + format: "csv", + signal: ac.signal, + onProgress: ({ bytesSent, chunksSent }: { bytesSent: number; chunksSent: number }) => { + progressCalled++; + if (chunksSent >= 2) ac.abort(); + expect(bytesSent).toBeGreaterThan(0); + }, + }); + } catch { + threw = true; + } + expect(progressCalled).toBeGreaterThan(0); + expect(threw).toBe(true); + }); + + // Phase 4: Binary COPY + + test("binary COPY TO (non-streaming) returns single ArrayBuffer-like result", async () => { + await using sql = connect(); + + const result = await sql`COPY (SELECT 1::int) TO STDOUT (FORMAT BINARY)`; + const binChunk = result?.[0]; + expect(binChunk).toBeDefined(); + // It should be ArrayBuffer in Bun + expect(binChunk.byteLength ?? 0).toBeGreaterThan(0); + expect(result.command).toBe("COPY"); + }); + + test("binary COPY TO (streaming) yields ArrayBuffer chunks", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_bin2", []); + await sql.unsafe("CREATE TABLE copy_bin2 (id INT, name TEXT)", []); + await sql.unsafe("INSERT INTO copy_bin2 (id, name) VALUES (1, 'One'), (2, 'Two')", []); + let sawArrayBuffer = false; + let total = 0; + for await (const chunk of sql.copyTo({ + table: "copy_bin2", + columns: ["id", "name"], + format: "binary", + })) { + if (chunk instanceof ArrayBuffer) { + sawArrayBuffer = true; + total += chunk.byteLength; + } + } + expect(sawArrayBuffer).toBe(true); + expect(total).toBeGreaterThan(0); + }); + + test("binary COPY FROM (zero-byte attempt) should fail on server", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_binary_zero", []); + await sql.unsafe("CREATE TABLE copy_binary_zero (id INT, name TEXT)", []); + let failed = false; + async function* emptyBinary() {} + try { + await sql.copyFrom("copy_binary_zero", ["id", "name"], emptyBinary(), { format: "binary" }); + } catch { + failed = true; + } + expect(failed).toBe(true); + }); + + test("COPY FROM STDIN (binary) with valid header and two rows", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_binary_data", []); + await sql.unsafe("CREATE TABLE copy_binary_data (id INT, name TEXT)", []); + + function be16(n: number) { + const b = new Uint8Array(2); + new DataView(b.buffer).setInt16(0, n, false); + return b; + } + function be32(n: number) { + const b = new Uint8Array(4); + new DataView(b.buffer).setInt32(0, n, false); + return b; + } + function beInt32(n: number) { + const b = new Uint8Array(4); + new DataView(b.buffer).setInt32(0, n, false); + return b; + } + function concat(...parts: Uint8Array[]) { + let len = 0; + for (const p of parts) len += p.length; + const out = new Uint8Array(len); + let o = 0; + for (const p of parts) { + out.set(p, o); + o += p.length; + } + return out; + } + function buildBinaryRow(id: number, name: string) { + const idBytes = beInt32(id); + const nameBytes = new TextEncoder().encode(name); + const fieldCount = be16(2); + const idLen = be32(4); + const nameLen = be32(nameBytes.length); + return concat(fieldCount, idLen, idBytes, nameLen, nameBytes); + } + + async function* genProperBinary() { + const sig = new Uint8Array([0x50, 0x47, 0x43, 0x4f, 0x50, 0x59, 0x0a, 0xff, 0x0d, 0x0a, 0x00]); + const flags = be32(0); + const extlen = be32(0); + yield concat(sig, flags, extlen); + yield buildBinaryRow(200, "Bin A"); + yield buildBinaryRow(201, "Bin B"); + yield be16(-1); + } + + const copyRes = await sql.copyFrom("copy_binary_data", ["id", "name"], genProperBinary(), { format: "binary" }); + expect(copyRes?.command).toBe("COPY"); + expect(copyRes?.count).toBe(2); + + const verify = await sql`SELECT COUNT(*)::int AS count FROM copy_binary_data`; + expect(verify[0]?.count).toBe(2); + + let sawArrayBuffer = false; + for await (const chunk of sql.copyTo({ + table: "copy_binary_data", + columns: ["id", "name"], + format: "binary", + })) { + if (chunk instanceof ArrayBuffer) { + sawArrayBuffer = true; + break; + } + } + expect(sawArrayBuffer).toBe(true); + }); + + // Phase 5: CSV options (default delimiter and null token) + + test("copyFrom with CSV default delimiter and null token", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_csv_opts", []); + await sql.unsafe("CREATE TABLE copy_csv_opts (id INT, name TEXT, note TEXT)", []); + async function* genCsvDefaultCsv() { + yield "41,CSVOne,note A\n"; + yield "42,,note B\n"; + } + const copyCsvRes = await sql.copyFrom("copy_csv_opts", ["id", "name", "note"], genCsvDefaultCsv(), { + format: "csv", + }); + expect(copyCsvRes?.command).toBe("COPY"); + expect(copyCsvRes?.count).toBe(2); + + const verify = await sql`SELECT COUNT(*)::int AS count FROM copy_csv_opts`; + expect(verify[0]?.count).toBe(2); + }); + + // Phase 6: Binary COPY FROM with automatic encoder (extended types + batch) + + test("Binary copyFrom automatic encoder with extended types", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_binary_ext", []); + await sql.unsafe( + ` + CREATE TABLE copy_binary_ext ( + did int2, + i4 int4, + i8 int8, + f4 float4, + f8 float8, + ok boolean, + b bytea, + d date, + t time, + ts timestamp, + tz timestamptz, + u uuid, + j json, + jb jsonb, + txt text, + num numeric, + iv interval, + i4s int4[], + texts text[], + uuids uuid[] + ) + `, + [], + ); + + const now = new Date(Date.UTC(2024, 0, 2, 3, 4, 5, 6)); + const binRows: any[] = [ + [ + 1, + 123, + 1234567890123n, + 3.5, + 6.25, + true, + new Uint8Array([1, 2, 3, 4]), + "2024-01-01", + "12:34:56.789", + now, + now, + "550e8400-e29b-41d4-a716-446655440000", + { k: 1 }, + { jb: "x" }, + "hello\\world\tline\nend", + "12345.6789", + { days: 1, ms: 3600000 }, + [10, 20, 30], + ["x", "y"], + ["550e8400-e29b-41d4-a716-446655440000", "550e8400-e29b-41d4-a716-446655440001"], + ], + [ + 2, + -456, + -1234567890123n, + -1.5, + -2.25, + false, + new Uint8Array([9, 8, 7]), + "2024-01-02", + "23:59:59.123456", + new Date(Date.UTC(2024, 0, 3, 10, 20, 30)), + new Date(Date.UTC(2024, 0, 4, 11, 22, 33)), + "550e8400-e29b-41d4-a716-446655440001", + { k: 2 }, + { jb: "y" }, + "goodbye", + "-9876.54321", + { months: 2, days: 3, ms: 0 }, + [100, 200], + ["alpha", "beta"], + ["550e8400-e29b-41d4-a716-446655440001", "550e8400-e29b-41d4-a716-446655440000"], + ], + ]; + const binaryTypes: CopyBinaryType[] = [ + "int2", + "int4", + "int8", + "float4", + "float8", + "bool", + "bytea", + "date", + "time", + "timestamp", + "timestamptz", + "uuid", + "json", + "jsonb", + "text", + "numeric", + "interval", + "int4[]", + "text[]", + "uuid[]", + ]; + + const copyRes = await sql.copyFrom( + "copy_binary_ext", + [ + "did", + "i4", + "i8", + "f4", + "f8", + "ok", + "b", + "d", + "t", + "ts", + "tz", + "u", + "j", + "jb", + "txt", + "num", + "iv", + "i4s", + "texts", + "uuids", + ], + binRows, + { format: "binary", binaryTypes, batchSize: 64 * 1024 }, + ); + expect(copyRes?.command).toBe("COPY"); + expect(copyRes?.count).toBe(2); + + const verify = await sql`SELECT COUNT(*)::int AS count FROM copy_binary_ext`; + expect(verify[0]?.count).toBe(2); + }); + + // Phase 7: copyToPipeTo already covered earlier + + // Phase 8: COPY FROM (text) with custom batchSize + + test("COPY FROM (text) with custom batchSize using async rows", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_batch_test", []); + await sql.unsafe("CREATE TABLE copy_batch_test (id INT, name TEXT)", []); + async function* manyTextRows(count: number) { + for (let i = 0; i < count; i++) { + yield [i, `Name ${i} with \\ and \t and \n`] as [number, string]; + } + } + const count = 300; + const copyRes = await sql.copyFrom("copy_batch_test", ["id", "name"], manyTextRows(count), { + format: "text", + batchSize: 32 * 1024, + }); + expect(copyRes?.command).toBe("COPY"); + expect(copyRes?.count).toBe(count); + + const verify = await sql`SELECT COUNT(*)::int AS count FROM copy_batch_test`; + expect(verify[0]?.count).toBe(count); + }); + + // Progress verification for batched text COPY FROM + test("copyFrom (text) progress bytes/chunks match server output", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_progress", []); + await sql.unsafe("CREATE TABLE copy_progress (id INT, name TEXT)", []); + + const total = 200; + let expected = ""; + for (let i = 0; i < total; i++) { + expected += `${i}\tName ${i}\n`; + } + + let bytesSent = 0; + let chunksSent = 0; + + async function* genRows() { + for (let i = 0; i < total; i++) { + // Ensure we exercise the row-batching path (flushBatch will send aggregated chunks) + yield [i, `Name ${i}`] as [number, string]; + } + } + + const res = await sql.copyFrom("copy_progress", ["id", "name"], genRows(), { + format: "text", + onProgress: ({ bytesSent: b, chunksSent: c }: { bytesSent: number; chunksSent: number }) => { + bytesSent = b; + chunksSent = c; + }, + }); + expect(res?.command).toBe("COPY"); + expect(res?.count).toBe(total); + + // At least one batch should have been sent + expect(chunksSent).toBeGreaterThan(0); + + // Progress bytes should equal the serialized payload length we generated + expect(bytesSent).toBe(expected.length); + + // Dump back from server in a deterministic order and compare to expected payload + const out = await sql`COPY (SELECT id, name FROM copy_progress ORDER BY id) TO STDOUT`; + const outStr = String(out[0] ?? ""); + expect(outStr.length).toBe(bytesSent); + expect(outStr).toBe(expected); + }); + + // Phase 9: COPY guardrails (timeout) + + test("copyTo timeout triggers when too small", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_timeout", []); + await sql.unsafe("CREATE TABLE copy_timeout (id INT, data TEXT)", []); + // Insert enough data to make copying take longer than the timeout + await sql.unsafe("INSERT INTO copy_timeout SELECT i, repeat('x', 1000) FROM generate_series(1, 10000) i", []); + + let didTimeout = false; + let errorMessage = ""; + try { + for await (const _ of sql.copyTo({ + table: "copy_timeout", + columns: ["id", "data"], + format: "text", + timeout: 50, // Very small timeout (50ms) to force timeout during large data copy + })) { + // Should timeout before getting all chunks + } + } catch (e) { + didTimeout = true; + errorMessage = String((e as any)?.message ?? e).toLowerCase(); + } + + // The timeout should actually fire + expect(didTimeout).toBe(true); + expect(errorMessage).toMatch(/timeout/); + }); + + test("Regression: copyTo has no timeout by default and timeout=0 disables timeout", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_timeout_disabled", []); + await sql.unsafe("CREATE TABLE copy_timeout_disabled (id INT, data TEXT)", []); + await sql.unsafe( + "INSERT INTO copy_timeout_disabled SELECT i, repeat('x', 1000) FROM generate_series(1, 10000) i", + [], + ); + + const readAll = async (iterable: AsyncIterable) => { + let count = 0; + for await (const _ of iterable) { + count++; + } + return count; + }; + + let succeededDefault = false; + try { + const n = await readAll( + sql.copyTo({ + table: "copy_timeout_disabled", + columns: ["id", "data"], + format: "text", + }), + ); + expect(n).toBeGreaterThan(0); + succeededDefault = true; + } catch (e) { + const message = String((e as any)?.message ?? e).toLowerCase(); + expect(message).not.toMatch(/timeout/); + } + expect(succeededDefault).toBe(true); + + let succeededZero = false; + try { + const n = await readAll( + sql.copyTo({ + table: "copy_timeout_disabled", + columns: ["id", "data"], + format: "text", + timeout: 0, + }), + ); + expect(n).toBeGreaterThan(0); + succeededZero = true; + } catch (e) { + const message = String((e as any)?.message ?? e).toLowerCase(); + expect(message).not.toMatch(/timeout/); + } + expect(succeededZero).toBe(true); + }); + + // pgx-inspired tests + + test("pgx: small typed rows with nulls", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS pgx_small", []); + await sql.unsafe( + `CREATE TABLE pgx_small( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g timestamptz + )`, + [], + ); + + const tzed = new Date(); + const rows: any[][] = [ + [0, 1, 2n, "abc", "efg", "2000-01-01", tzed], + [null, null, null, null, null, null, null], + ]; + + const res = await sql.copyFrom("pgx_small", ["a", "b", "c", "d", "e", "f", "g"], rows, { format: "text" }); + expect(res?.command).toBe("COPY"); + expect(res?.count).toBe(rows.length); + + const out = await sql`SELECT COUNT(*)::int AS count FROM pgx_small`; + expect(out[0]?.count).toBe(rows.length); + }); + + test("pgx: large rows with bytea", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS pgx_large", []); + await sql.unsafe( + `CREATE TABLE pgx_large( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g timestamptz, + h bytea + )`, + [], + ); + + const tzed = new Date(); + const bytes = new Uint8Array([111, 111, 111, 111]); + const rows: any[][] = []; + for (let i = 0; i < 1000; i++) { + rows.push([0, 1, 2n, "abc", "efg", "2000-01-01", tzed, bytes]); + } + const res = await sql.copyFrom("pgx_large", ["a", "b", "c", "d", "e", "f", "g", "h"], rows, { + format: "binary", + binaryTypes: ["int2", "int4", "int8", "varchar", "text", "date", "timestamptz", "bytea"], + }); + expect(res?.command).toBe("COPY"); + expect(res?.count).toBe(rows.length); + + const out = await sql`SELECT COUNT(*)::int AS count FROM pgx_large`; + expect(out[0]?.count).toBe(rows.length); + }); + + test("pgx: enum types with copyFrom", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS pgx_enum_tbl", []); + await sql.unsafe( + "DO $$ BEGIN IF EXISTS (SELECT 1 FROM pg_type WHERE typname = 'color') THEN DROP TYPE color; END IF; END $$;", + [], + ); + await sql.unsafe( + "DO $$ BEGIN IF EXISTS (SELECT 1 FROM pg_type WHERE typname = 'fruit') THEN DROP TYPE fruit; END IF; END $$;", + [], + ); + await sql.unsafe(`CREATE TYPE color AS ENUM ('blue', 'green', 'orange')`, []); + await sql.unsafe(`CREATE TYPE fruit AS ENUM ('apple', 'orange', 'grape')`, []); + await sql.unsafe( + `CREATE TABLE pgx_enum_tbl( + a text, + b color, + c fruit, + d color, + e fruit, + f text + )`, + [], + ); + + const rows: any[][] = [ + ["abc", "blue", "grape", "orange", "orange", "def"], + [null, null, null, null, null, null], + ]; + const res = await sql.copyFrom("pgx_enum_tbl", ["a", "b", "c", "d", "e", "f"], rows, { format: "text" }); + expect(res?.command).toBe("COPY"); + expect(res?.count).toBe(rows.length); + + const out = await sql`SELECT COUNT(*)::int AS count FROM pgx_enum_tbl`; + expect(out[0]?.count).toBe(rows.length); + }); + + test("pgx: server failure mid-copy (NOT NULL violation) yields 0 inserted", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS pgx_fail_mid", []); + await sql.unsafe(`CREATE TABLE pgx_fail_mid(a int4, b varchar NOT NULL)`, []); + const rows: any[][] = [ + [1, "abc"], + [2, null], // should trigger server-side failure + [3, "def"], + ]; + let failed = false; + try { + await sql.copyFrom("pgx_fail_mid", ["a", "b"], rows, { format: "text" }); + } catch { + failed = true; + } + expect(failed).toBe(true); + + const out = await sql`SELECT COUNT(*)::int AS count FROM pgx_fail_mid`; + expect(out[0]?.count).toBe(0); + }); + + test("pgx: client generator error midway", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS pgx_client_err", []); + await sql.unsafe(`CREATE TABLE pgx_client_err(a bytea NOT NULL)`, []); + async function* errGen() { + let count = 0; + while (true) { + count++; + if (count === 3) throw new Error("client error"); + yield new Uint8Array(1000); + if (count >= 100) break; + } + } + let failed = false; + try { + await sql.copyFrom("pgx_client_err", ["a"], errGen(), { format: "binary" }); + } catch { + failed = true; + } + expect(failed).toBe(true); + + const out = await sql`SELECT COUNT(*)::int AS count FROM pgx_client_err`; + expect(out[0]?.count).toBe(0); + }); + + test("pgx: automatic string conversion for int8 and numeric[]", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS pgx_auto_str", []); + await sql.unsafe("CREATE TABLE pgx_auto_str(a int8)", []); + const rows1: any[][] = [["42"], ["7"], [8]]; + const res1 = await sql.copyFrom("pgx_auto_str", ["a"], rows1, { format: "text" }); + expect(res1?.count).toBe(rows1.length); + + const nums = await sql`SELECT a::bigint AS a FROM pgx_auto_str ORDER BY a`; + expect(nums.map(n => Number(n.a))).toEqual([7, 8, 42]); + + await sql.unsafe("DROP TABLE IF EXISTS pgx_auto_arr", []); + await sql.unsafe("CREATE TABLE pgx_auto_arr(a numeric[])", []); + const rows2: any[][] = [[[42]], [[7]], [[8, 9]]]; + const res2 = await sql.copyFrom("pgx_auto_arr", ["a"], rows2, { format: "binary", binaryTypes: ["numeric[]"] }); + expect(res2?.count).toBe(rows2.length); + + const arr = await sql`SELECT a FROM pgx_auto_arr`; + // Flatten to verify values are present + expect(arr.length).toBe(rows2.length); + }); + + test("pgx: function-style generator copy", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS pgx_func", []); + await sql.unsafe("CREATE TABLE pgx_func(a int)", []); + const channelItems = 10; + + async function* gen() { + for (let i = 0; i < channelItems; i++) { + yield [i]; + } + } + + const ok = await sql.copyFrom("pgx_func", ["a"], gen(), { format: "text" }); + expect(ok?.count).toBe(channelItems); + + const rows = await sql`SELECT a::int AS a FROM pgx_func ORDER BY a`; + expect(rows.map((r: any) => r.a)).toEqual([...Array(channelItems)].map((_, i) => i)); + + // Simulate a failure on the producer side + async function* genFail() { + let x = 9; + while (true) { + x++; + if (x > 100) throw new Error("simulated error"); + yield [x]; + } + } + + let failed = false; + try { + await sql.copyFrom("pgx_func", ["a"], genFail(), { format: "text" }); + } catch { + failed = true; + } + expect(failed).toBe(true); + }); + + test("unique constraint violation during COPY FROM yields zero inserted", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_unique", []); + await sql.unsafe("CREATE TABLE copy_unique (id INT PRIMARY KEY, name TEXT)", []); + const rows = [ + [1, "A"], + [1, "B"], + ]; + let failed = false; + try { + await sql.copyFrom("copy_unique", ["id", "name"], rows, { format: "text" }); + } catch { + failed = true; + } + expect(failed).toBe(true); + const verify = await sql`SELECT COUNT(*)::int AS count FROM copy_unique`; + expect(verify[0]?.count).toBe(0); + }); + + test("type cast error during COPY FROM yields zero inserted", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_cast_err", []); + await sql.unsafe("CREATE TABLE copy_cast_err (id INT NOT NULL)", []); + const badRows = [["abc"]]; // invalid int + let failed = false; + try { + await sql.copyFrom("copy_cast_err", ["id"], badRows, { format: "text" }); + } catch { + failed = true; + } + expect(failed).toBe(true); + const verify = await sql`SELECT COUNT(*)::int AS count FROM copy_cast_err`; + expect(verify[0]?.count).toBe(0); + }); + + test("CSV quoted fields and embedded quotes", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_csv_quotes", []); + await sql.unsafe('CREATE TABLE copy_csv_quotes (id INT, "full" TEXT, "quote" TEXT)', []); + async function* gen() { + yield '1,"Last, First","He said ""Hi"""\n'; + yield '2,"Simple","Plain"\n'; + } + const res = await sql.copyFrom("copy_csv_quotes", ["id", "full", "quote"], gen(), { format: "csv" }); + expect(res?.command).toBe("COPY"); + expect(res?.count).toBe(2); + + const rows = await sql`SELECT id::int AS id, "full", "quote" FROM copy_csv_quotes ORDER BY id`; + expect(rows[0].full).toBe("Last, First"); + expect(rows[0].quote).toBe('He said "Hi"'); + expect(rows[1].full).toBe("Simple"); + expect(rows[1].quote).toBe("Plain"); + }); + + test("copyToPipeTo streams CSV to sink", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS copy_pipe_csv", []); + await sql.unsafe("CREATE TABLE copy_pipe_csv (id INT, name TEXT)", []); + await sql.unsafe("INSERT INTO copy_pipe_csv (id, name) VALUES (1,'A'),(2,'B')", []); + + const sinkChunks: Array = []; + const sink = { + async write(chunk: string | ArrayBuffer | Uint8Array) { + sinkChunks.push(chunk); + }, + async end() {}, + }; + + await sql.copyToPipeTo( + { + table: "copy_pipe_csv", + columns: ["id", "name"], + format: "csv", + }, + sink, + ); + + expect(sinkChunks.length).toBeGreaterThan(0); + const stringChunks = sinkChunks.filter(c => typeof c === "string"); + expect(stringChunks.length).toBeGreaterThan(0); + }); + + test("Audit fix: Binary COPY header validation - incomplete header should fail", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS audit_binary_test", []); + await sql.unsafe("CREATE TABLE audit_binary_test (id INT, name TEXT)", []); + + // Try to send incomplete/invalid binary data (missing proper header) + let failed = false; + async function* invalidBinaryData() { + // Send incomplete header (less than 11 bytes required for signature) + yield new Uint8Array([0x50, 0x47, 0x43]); // Only "PGC" - incomplete signature + // Send trailer immediately to trigger completion + const trailer = new Uint8Array(2); + new DataView(trailer.buffer).setInt16(0, -1, false); + yield trailer; + } + + try { + await sql.copyFrom("audit_binary_test", ["id", "name"], invalidBinaryData(), { + format: "binary", + }); + } catch (e) { + failed = true; + expect(e).toBeDefined(); + } + expect(failed).toBe(true); + }); + + test("Audit fix: Empty columns list - COPY should work without columns specified", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS audit_empty_cols", []); + await sql.unsafe("CREATE TABLE audit_empty_cols (id INT, name TEXT)", []); + + // Insert with empty columns array - should copy all columns + const data = "1\tAlice\n2\tBob\n"; + const result = await sql.copyFrom("audit_empty_cols", [], data, { format: "text" }); + + expect(result.command).toBe("COPY"); + expect(result.count).toBe(2); + + const verify = await sql`SELECT COUNT(*)::int AS count FROM audit_empty_cols`; + expect(verify[0]?.count).toBe(2); + }); + + test("Audit fix: Large maxBytes values should not overflow to negative", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS audit_large_bytes", []); + await sql.unsafe("CREATE TABLE audit_large_bytes (id INT, data TEXT)", []); + await sql.unsafe("INSERT INTO audit_large_bytes VALUES (1, 'test')", []); + + let bytesReceived = 0; + const largeLimit = 5_000_000_000; // 5GB - larger than 32-bit signed int max + + // This should not fail due to negative comparison + let chunks = 0; + for await (const chunk of sql.copyTo({ + table: "audit_large_bytes", + columns: ["id", "data"], + format: "text", + maxBytes: largeLimit, // Large value that would overflow with bitwise ops + onProgress: info => { + bytesReceived = info.bytesReceived; + // Should be positive + expect(bytesReceived).toBeGreaterThanOrEqual(0); + }, + })) { + chunks++; + expect(chunk).toBeDefined(); + } + + expect(chunks).toBeGreaterThan(0); + expect(bytesReceived).toBeGreaterThan(0); + expect(bytesReceived).toBeLessThan(largeLimit); // Should not exceed limit + }); + + test("Audit fix: UTF-8 byte length calculation - progress should count UTF-8 bytes", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS audit_utf8_test", []); + await sql.unsafe("CREATE TABLE audit_utf8_test (id INT, emoji TEXT)", []); + await sql.unsafe("INSERT INTO audit_utf8_test VALUES (1, '👍'), (2, '🎉'), (3, '😀')", []); + + let bytesReceived = 0; + let lastBytes = 0; + + for await (const chunk of sql.copyTo({ + table: "audit_utf8_test", + columns: ["id", "emoji"], + format: "text", + onProgress: info => { + bytesReceived = info.bytesReceived; + }, + })) { + if (typeof chunk === "string") { + // Manual UTF-8 byte calculation for verification + const utf8Bytes = new TextEncoder().encode(chunk).byteLength; + const utf16Length = chunk.length; + + // UTF-8 emoji bytes should be more than UTF-16 code units for emojis + // Each emoji is typically 4 UTF-8 bytes but 2 UTF-16 code units + if (chunk.includes("👍") || chunk.includes("🎉") || chunk.includes("😀")) { + expect(utf8Bytes).toBeGreaterThan(utf16Length); + } + + // Progress should accumulate UTF-8 bytes + const bytesDelta = bytesReceived - lastBytes; + lastBytes = bytesReceived; + + // The delta should be close to UTF-8 byte length (allow for some variance due to buffering) + if (bytesDelta > 0) { + expect(bytesDelta).toBeGreaterThanOrEqual(chunk.length); + } + } + } + + expect(bytesReceived).toBeGreaterThan(0); + }); + + test("Audit fix: Binary COPY with valid header should succeed", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS audit_valid_binary", []); + await sql.unsafe("CREATE TABLE audit_valid_binary (id INT, name TEXT)", []); + + function be16(n: number) { + const b = new Uint8Array(2); + new DataView(b.buffer).setInt16(0, n, false); + return b; + } + function be32(n: number) { + const b = new Uint8Array(4); + new DataView(b.buffer).setInt32(0, n, false); + return b; + } + function concat(...parts: Uint8Array[]) { + let len = 0; + for (const p of parts) len += p.length; + const out = new Uint8Array(len); + let o = 0; + for (const p of parts) { + out.set(p, o); + o += p.length; + } + return out; + } + + async function* validBinaryData() { + // Valid signature + const sig = new Uint8Array([0x50, 0x47, 0x43, 0x4f, 0x50, 0x59, 0x0a, 0xff, 0x0d, 0x0a, 0x00]); + const flags = be32(0); + const extlen = be32(0); + yield concat(sig, flags, extlen); + + // One row: field count (2), id length (4), id value (100), name length (5), name value + const fieldCount = be16(2); + const idLen = be32(4); + const idVal = be32(100); + const nameBytes = new TextEncoder().encode("Test"); + const nameLen = be32(nameBytes.length); + yield concat(fieldCount, idLen, idVal, nameLen, nameBytes); + + // Trailer + yield be16(-1); + } + + const result = await sql.copyFrom("audit_valid_binary", ["id", "name"], validBinaryData(), { + format: "binary", + }); + + expect(result.command).toBe("COPY"); + expect(result.count).toBe(1); + + const verify = await sql`SELECT * FROM audit_valid_binary`; + expect(verify[0]?.id).toBe(100); + expect(verify[0]?.name).toBe("Test"); + }); + + test("Audit fix: CSV empty string vs NULL - empty strings should be quoted", async () => { + await using sql = connect(); + + await sql.unsafe("DROP TABLE IF EXISTS audit_csv_null_test", []); + await sql.unsafe("CREATE TABLE audit_csv_null_test (id INT, val TEXT)", []); + + // Test data: [1, null], [2, ""], [3, "text"] + const rows = [ + [1, null], // Should emit: 1, + [2, ""], // Should emit: 2,"" + [3, "text"], // Should emit: 3,text + ]; + + const result = await sql.copyFrom("audit_csv_null_test", ["id", "val"], rows, { + format: "csv", + }); + + expect(result.command).toBe("COPY"); + expect(result.count).toBe(3); + + const verify = await sql`SELECT id::int AS id, val FROM audit_csv_null_test ORDER BY id`; + expect(verify[0]?.id).toBe(1); + expect(verify[0]?.val).toBe(null); // NULL value + expect(verify[1]?.id).toBe(2); + expect(verify[1]?.val).toBe(""); // Empty string + expect(verify[2]?.id).toBe(3); + expect(verify[2]?.val).toBe("text"); + }); + + test("Audit fix: uint32 clamping - large timeout/buffer values should not wrap", async () => { + // This test is intentionally pure and does not reserve a real connection. + // It validates the clamping logic used by reserved connection wrappers. + + const clampUint32 = (value: number) => { + const n = Number(value); + if (!Number.isFinite(n) || n <= 0) return 0; + return Math.min(0xffffffff, Math.trunc(n)); + }; + + // Values larger than 32-bit signed int max (2^31 - 1 = 2147483647) + const largeTimeout = 3_000_000_000; // 3 billion ms + const largeBufferSize = 5_000_000_000; // 5 billion bytes + + // Should clamp to max uint32 without wrapping to 0 or negative + expect(clampUint32(largeTimeout)).toBe(3_000_000_000); + expect(clampUint32(largeBufferSize)).toBe(0xffffffff); + + // Negative and non-finite values should clamp to 0 + expect(clampUint32(-1000)).toBe(0); + expect(clampUint32(-5000)).toBe(0); + expect(clampUint32(Number.NaN)).toBe(0); + expect(clampUint32(Number.POSITIVE_INFINITY)).toBe(0); + }); + + test("Audit fix: escapeIdentifier for schema-qualified names in copyTo", async () => { + await using sql = connect(); + + // Create a schema and table with schema-qualified name + await sql.unsafe("DROP SCHEMA IF EXISTS audit_schema CASCADE", []); + await sql.unsafe("CREATE SCHEMA audit_schema", []); + await sql.unsafe("CREATE TABLE audit_schema.qualified_table (id INT, data TEXT)", []); + await sql.unsafe("INSERT INTO audit_schema.qualified_table VALUES (1, 'test')", []); + + let chunks = 0; + let succeeded = false; + try { + for await (const chunk of sql.copyTo({ + table: "audit_schema.qualified_table", + columns: ["id", "data"], + format: "text", + })) { + chunks++; + expect(chunk).toBeDefined(); + } + succeeded = true; + } catch (e) { + // Should not throw + } + + expect(succeeded).toBe(true); + expect(chunks).toBeGreaterThan(0); + + // Cleanup + await sql.unsafe("DROP SCHEMA audit_schema CASCADE", []); + }); + }); +} else { + describe("PostgreSQL COPY protocol", () => { + test("skipped - docker not enabled", () => { + expect(true).toBe(true); + }); + }); +}