Skip to content

Commit

Permalink
Support multiple databases (#1)
Browse files Browse the repository at this point in the history
  • Loading branch information
romeerez committed Sep 15, 2024
1 parent 280f2f6 commit 5c9d063
Show file tree
Hide file tree
Showing 11 changed files with 4,556 additions and 83 deletions.
2 changes: 1 addition & 1 deletion .env.example
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
DATABASE_URL=postgres://postgres:@localhost:5432/pg-transactional-tests
DATABASE_URLS=postgres://postgres:@localhost:5432/pg-transactional-tests,postgres://postgres:@localhost:5432/pg-transactional-tests-2
DATABASE_CAMEL_CASE=true
MIGRATIONS_PATH=tests/migrations
11 changes: 8 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ I have a repo [ORMs overview](https://github.com/romeerez/orms-overview) where I

This **does not** work only with Prisma because its implementation is very different.

If a test doesn't perform any query, it won't start a transaction in vain.

Supports testing multiple databases in parallel. Transaction state is tracked by connection parameters.
If there are different connection parameters, it will run different transactions.

## Get started

Install:
Expand Down Expand Up @@ -55,13 +60,13 @@ import db from './path-to-your-db'
// so every instance of `pg` in your app becomes patched
patchPgForTransactions();

// start transaction before all tests:
// start transaction before all tests (only when there are queries):
beforeAll(startTransaction)

// start transaction before each test:
// start transaction before each test (only when there are queries):
beforeEach(startTransaction);

// rollback transaction after each test:
// rollback transaction after each test (if transaction started):
afterEach(rollbackTransaction);

// rollback transaction after all and stop the db connection:
Expand Down
4,380 changes: 4,380 additions & 0 deletions package-lock.json

Large diffs are not rendered by default.

11 changes: 7 additions & 4 deletions package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "pg-transactional-tests",
"version": "1.0.9",
"version": "1.1.0",
"description": "Wraps each test in transaction for `pg` package",
"repository": "https://github.com/romeerez/pg-transactional-tests",
"main": "dist/index.js",
Expand All @@ -9,13 +9,15 @@
"test": "jest --setupFiles dotenv/config --watch",
"check": "jest --setupFiles dotenv/config",
"build": "tsc",
"db": "rake-db",
"db": "tsx tests/dbScript.ts",
"prepublish": "tsc"
},
"jest": {
"verbose": false,
"transform": {
"^.+\\.tsx?$": ["@swc/jest"]
"^.+\\.tsx?$": [
"@swc/jest"
]
}
},
"keywords": [
Expand All @@ -39,7 +41,8 @@
"jest": "^28.1.1",
"pg": "^8.11.3",
"prettier": "^2.7.1",
"rake-db": "^1.3.1",
"rake-db": "2.22.41",
"tsx": "^4.19.1",
"typescript": "^4.7.4"
},
"peerDependencies": {
Expand Down
108 changes: 73 additions & 35 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,37 @@ import {
QueryConfig,
} from 'pg';

let transactionId = 0;
let client: Client | undefined;
let connectPromise: Promise<void> | undefined;
let prependStartTransaction = false;

const { connect, query } = Client.prototype;
const { connect: poolConnect, query: poolQuery } = Pool.prototype;

interface ConnectionParameters {
user: string;
database: string;
port: number;
host: string;
}

interface ClientWithNeededTypes extends Client {
connectionParameters: ConnectionParameters;
}

const getClientId = (client: Client) => {
const { connectionParameters: p } = client as ClientWithNeededTypes;
return `${p.host} ${p.port} ${p.user} ${p.database}`;
};

let prependStartTransaction = false;

let clientStates: Record<
string,
{
transactionId: number;
client: Client;
connectPromise?: Promise<void>;
prependStartTransaction?: boolean;
}
> = {};

export const patchPgForTransactions = () => {
Client.prototype.connect = async function (
this: Client,
Expand All @@ -25,27 +48,35 @@ export const patchPgForTransactions = () => {
err: Error | undefined,
connection?: Client,
) => void;
if (!client) client = this;

if (connectPromise) {
await connectPromise;
cb?.(undefined, client);
const thisId = getClientId(this);

let state = clientStates[thisId];
if (!state) {
clientStates[thisId] = state = {
client: this,
transactionId: 0,
prependStartTransaction,
};
}

if (state.connectPromise) {
await state.connectPromise;
cb?.(undefined, state.client);
return;
}

connectPromise = new Promise((resolve, reject) => {
connect.call(client, (err) => {
return (state.connectPromise = new Promise((resolve, reject) => {
connect.call(state.client, (err) => {
if (err) {
cb?.(err);
reject(err);
} else {
cb?.(undefined, client);
cb?.(undefined, state.client);
resolve();
}
});
});

return connectPromise;
}));
};

Pool.prototype.connect = function (
Expand All @@ -69,6 +100,8 @@ export const patchPgForTransactions = () => {
inputArg: string | QueryConfig | QueryArrayConfig,
...args: any[]
) {
const state = clientStates[getClientId(this)];

let input = inputArg;
const sql = (typeof input === 'string' ? input : input.text)
.trim()
Expand All @@ -78,36 +111,36 @@ export const patchPgForTransactions = () => {
if (!sql.startsWith('SELECT')) {
let replacingSql: string | undefined;

if (prependStartTransaction) {
prependStartTransaction = false;
if (state.prependStartTransaction) {
state.prependStartTransaction = false;
await this.query('BEGIN');
}

if (sql.startsWith('START TRANSACTION') || sql.startsWith('BEGIN')) {
if (transactionId > 0) {
replacingSql = `SAVEPOINT "${transactionId++}"`;
if (state.transactionId > 0) {
replacingSql = `SAVEPOINT "${state.transactionId++}"`;
} else {
transactionId = 1;
state.transactionId = 1;
}
} else {
const isCommit = sql.startsWith('COMMIT');
const isRollback = !isCommit && sql.startsWith('ROLLBACK');
if (isCommit || isRollback) {
if (transactionId === 0) {
if (state.transactionId === 0) {
throw new Error(
`Trying to ${
isCommit ? 'COMMIT' : 'ROLLBACK'
} outside of transaction`,
);
}

if (transactionId > 1) {
const savePoint = --transactionId;
if (state.transactionId > 1) {
const savePoint = --state.transactionId;
replacingSql = `${
isCommit ? 'RELEASE' : 'ROLLBACK TO'
} SAVEPOINT "${savePoint}"`;
} else {
transactionId = 0;
state.transactionId = 0;
}
}
}
Expand All @@ -124,7 +157,7 @@ export const patchPgForTransactions = () => {
await (Client.prototype.connect as () => Promise<void>).call(this);

// eslint-disable-next-line @typescript-eslint/no-explicit-any
return (query as any).call(client, input, ...args);
return (query as any).call(state.client, input, ...args);
};

// eslint-disable-next-line @typescript-eslint/no-explicit-any
Expand All @@ -140,24 +173,29 @@ export const patchPgForTransactions = () => {
};

export const unpatchPgForTransactions = () => {
transactionId = 0;
client = undefined;
connectPromise = undefined;

clientStates = {};
Client.prototype.connect = connect;
Client.prototype.query = query;
Pool.prototype.connect = poolConnect;
Pool.prototype.query = poolQuery;
};

export const startTransaction = async () => {
export const startTransaction = () => {
prependStartTransaction = true;
for (const state of Object.values(clientStates)) {
state.prependStartTransaction = true;
}
};

export const rollbackTransaction = async () => {
if (transactionId > 0) {
await client?.query('ROLLBACK');
}
export const rollbackTransaction = () => {
return Promise.all(
Object.values(clientStates).map(async (state) => {
if (state.transactionId > 0) {
await state.client?.query('ROLLBACK');
}
}),
);
};

export const close = () => client?.end();
export const close = () =>
Promise.all(Object.values(clientStates).map((state) => state.client.end()));
6 changes: 6 additions & 0 deletions tests/config.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import 'dotenv/config';

const dbUrlsString = process.env.DATABASE_URLS;
if (!dbUrlsString) throw new Error(`Missing DATABASE_URLS env var`);

export const dbUrls = dbUrlsString.split(',');
10 changes: 10 additions & 0 deletions tests/dbScript.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import { rakeDb } from 'rake-db';
import { dbUrls } from './config';

export const change = rakeDb(
dbUrls.map((url) => ({ databaseURL: url })),
{
migrationsPath: './migrations',
import: (path) => import(path),
},
);
Loading

0 comments on commit 5c9d063

Please sign in to comment.