diff --git a/src/client-side-encryption/state_machine.ts b/src/client-side-encryption/state_machine.ts index af3ea4c215..f47ee191b5 100644 --- a/src/client-side-encryption/state_machine.ts +++ b/src/client-side-encryption/state_machine.ts @@ -12,7 +12,9 @@ import { } from '../bson'; import { type ProxyOptions } from '../cmap/connection'; import { getSocks, type SocksLib } from '../deps'; +import { MongoOperationTimeoutError } from '../error'; import { type MongoClient, type MongoClientOptions } from '../mongo_client'; +import { Timeout, type TimeoutContext, TimeoutError } from '../timeout'; import { BufferPool, MongoDBCollectionNamespace, promiseWithResolvers } from '../utils'; import { autoSelectSocketOptions, type DataKey } from './client_encryption'; import { MongoCryptError } from './errors'; @@ -173,6 +175,7 @@ export type StateMachineOptions = { * An internal class that executes across a MongoCryptContext until either * a finishing state or an error is reached. Do not instantiate directly. */ +// TODO(DRIVERS-2671): clarify CSOT behavior for FLE APIs export class StateMachine { constructor( private options: StateMachineOptions, @@ -182,7 +185,11 @@ export class StateMachine { /** * Executes the state machine according to the specification */ - async execute(executor: StateMachineExecutable, context: MongoCryptContext): Promise { + async execute( + executor: StateMachineExecutable, + context: MongoCryptContext, + timeoutContext?: TimeoutContext + ): Promise { const keyVaultNamespace = executor._keyVaultNamespace; const keyVaultClient = executor._keyVaultClient; const metaDataClient = executor._metaDataClient; @@ -201,8 +208,13 @@ export class StateMachine { 'unreachable state machine state: entered MONGOCRYPT_CTX_NEED_MONGO_COLLINFO but metadata client is undefined' ); } - const collInfo = await this.fetchCollectionInfo(metaDataClient, context.ns, filter); + const collInfo = await this.fetchCollectionInfo( + metaDataClient, + context.ns, + filter, + timeoutContext + ); if (collInfo) { context.addMongoOperationResponse(collInfo); } @@ -222,9 +234,9 @@ export class StateMachine { // When we are using the shared library, we don't have a mongocryptd manager. const markedCommand: Uint8Array = mongocryptdManager ? await mongocryptdManager.withRespawn( - this.markCommand.bind(this, mongocryptdClient, context.ns, command) + this.markCommand.bind(this, mongocryptdClient, context.ns, command, timeoutContext) ) - : await this.markCommand(mongocryptdClient, context.ns, command); + : await this.markCommand(mongocryptdClient, context.ns, command, timeoutContext); context.addMongoOperationResponse(markedCommand); context.finishMongoOperation(); @@ -233,7 +245,12 @@ export class StateMachine { case MONGOCRYPT_CTX_NEED_MONGO_KEYS: { const filter = context.nextMongoOperation(); - const keys = await this.fetchKeys(keyVaultClient, keyVaultNamespace, filter); + const keys = await this.fetchKeys( + keyVaultClient, + keyVaultNamespace, + filter, + timeoutContext + ); if (keys.length === 0) { // See docs on EMPTY_V @@ -255,9 +272,7 @@ export class StateMachine { } case MONGOCRYPT_CTX_NEED_KMS: { - const requests = Array.from(this.requests(context)); - await Promise.all(requests); - + await Promise.all(this.requests(context, timeoutContext)); context.finishKMSRequests(); break; } @@ -299,7 +314,7 @@ export class StateMachine { * @param kmsContext - A C++ KMS context returned from the bindings * @returns A promise that resolves when the KMS reply has be fully parsed */ - async kmsRequest(request: MongoCryptKMSRequest): Promise { + async kmsRequest(request: MongoCryptKMSRequest, timeoutContext?: TimeoutContext): Promise { const parsedUrl = request.endpoint.split(':'); const port = parsedUrl[1] != null ? Number.parseInt(parsedUrl[1], 10) : HTTPS_PORT; const socketOptions = autoSelectSocketOptions(this.options.socketOptions || {}); @@ -329,10 +344,6 @@ export class StateMachine { } } - function ontimeout() { - return new MongoCryptError('KMS request timed out'); - } - function onerror(cause: Error) { return new MongoCryptError('KMS request failed', { cause }); } @@ -364,7 +375,6 @@ export class StateMachine { resolve: resolveOnNetSocketConnect } = promiseWithResolvers(); netSocket - .once('timeout', () => rejectOnNetSocketError(ontimeout())) .once('error', err => rejectOnNetSocketError(onerror(err))) .once('close', () => rejectOnNetSocketError(onclose())) .once('connect', () => resolveOnNetSocketConnect()); @@ -410,8 +420,8 @@ export class StateMachine { reject: rejectOnTlsSocketError, resolve } = promiseWithResolvers(); + socket - .once('timeout', () => rejectOnTlsSocketError(ontimeout())) .once('error', err => rejectOnTlsSocketError(onerror(err))) .once('close', () => rejectOnTlsSocketError(onclose())) .on('data', data => { @@ -425,20 +435,26 @@ export class StateMachine { resolve(); } }); - await willResolveKmsRequest; + await (timeoutContext?.csotEnabled() + ? Promise.all([willResolveKmsRequest, Timeout.expires(timeoutContext?.remainingTimeMS)]) + : willResolveKmsRequest); + } catch (error) { + if (error instanceof TimeoutError) + throw new MongoOperationTimeoutError('KMS request timed out'); + throw error; } finally { // There's no need for any more activity on this socket at this point. destroySockets(); } } - *requests(context: MongoCryptContext) { + *requests(context: MongoCryptContext, timeoutContext?: TimeoutContext) { for ( let request = context.nextKMSRequest(); request != null; request = context.nextKMSRequest() ) { - yield this.kmsRequest(request); + yield this.kmsRequest(request, timeoutContext); } } @@ -498,7 +514,8 @@ export class StateMachine { async fetchCollectionInfo( client: MongoClient, ns: string, - filter: Document + filter: Document, + timeoutContext?: TimeoutContext ): Promise { const { db } = MongoDBCollectionNamespace.fromString(ns); @@ -506,7 +523,10 @@ export class StateMachine { .db(db) .listCollections(filter, { promoteLongs: false, - promoteValues: false + promoteValues: false, + ...(timeoutContext?.csotEnabled() + ? { timeoutMS: timeoutContext?.remainingTimeMS, timeoutMode: 'cursorLifetime' } + : {}) }) .toArray(); @@ -522,12 +542,22 @@ export class StateMachine { * @param command - The command to execute. * @param callback - Invoked with the serialized and marked bson command, or with an error */ - async markCommand(client: MongoClient, ns: string, command: Uint8Array): Promise { - const options = { promoteLongs: false, promoteValues: false }; + async markCommand( + client: MongoClient, + ns: string, + command: Uint8Array, + timeoutContext?: TimeoutContext + ): Promise { const { db } = MongoDBCollectionNamespace.fromString(ns); - const rawCommand = deserialize(command, options); + const bsonOptions = { promoteLongs: false, promoteValues: false }; + const rawCommand = deserialize(command, bsonOptions); - const response = await client.db(db).command(rawCommand, options); + const response = await client.db(db).command(rawCommand, { + ...bsonOptions, + ...(timeoutContext?.csotEnabled() + ? { timeoutMS: timeoutContext?.remainingTimeMS } + : undefined) + }); return serialize(response, this.bsonOptions); } @@ -543,7 +573,8 @@ export class StateMachine { fetchKeys( client: MongoClient, keyVaultNamespace: string, - filter: Uint8Array + filter: Uint8Array, + timeoutContext?: TimeoutContext ): Promise> { const { db: dbName, collection: collectionName } = MongoDBCollectionNamespace.fromString(keyVaultNamespace); @@ -551,7 +582,12 @@ export class StateMachine { return client .db(dbName) .collection(collectionName, { readConcern: { level: 'majority' } }) - .find(deserialize(filter)) + .find( + deserialize(filter), + timeoutContext?.csotEnabled() + ? { timeoutMS: timeoutContext?.remainingTimeMS, timeoutMode: 'cursorLifetime' } + : {} + ) .toArray(); } } diff --git a/src/sdam/server.ts b/src/sdam/server.ts index 08325086d5..7ab2d9a043 100644 --- a/src/sdam/server.ts +++ b/src/sdam/server.ts @@ -311,6 +311,10 @@ export class Server extends TypedEventEmitter { delete finalOptions.readPreference; } + if (this.description.iscryptd) { + finalOptions.omitMaxTimeMS = true; + } + const session = finalOptions.session; let conn = session?.pinnedConnection; diff --git a/test/integration/client-side-operations-timeout/client_side_operations_timeout.prose.test.ts b/test/integration/client-side-operations-timeout/client_side_operations_timeout.prose.test.ts index 09b95d6dff..80da92e10a 100644 --- a/test/integration/client-side-operations-timeout/client_side_operations_timeout.prose.test.ts +++ b/test/integration/client-side-operations-timeout/client_side_operations_timeout.prose.test.ts @@ -1,5 +1,7 @@ /* Specification prose tests */ +import { type ChildProcess, spawn } from 'node:child_process'; + import { expect } from 'chai'; import * as semver from 'semver'; import * as sinon from 'sinon'; @@ -16,7 +18,8 @@ import { MongoServerSelectionError, now, ObjectId, - promiseWithResolvers + promiseWithResolvers, + squashError } from '../../mongodb'; import { type FailPoint } from '../../tools/utils'; @@ -103,17 +106,55 @@ describe('CSOT spec prose tests', function () { }); }); - context.skip('2. maxTimeMS is not set for commands sent to mongocryptd', () => { - /** - * This test MUST only be run against enterprise server versions 4.2 and higher. - * - * 1. Launch a mongocryptd process on 23000. - * 1. Create a MongoClient (referred to as `client`) using the URI `mongodb://localhost:23000/?timeoutMS=1000`. - * 1. Using `client`, execute the `{ ping: 1 }` command against the `admin` database. - * 1. Verify via command monitoring that the `ping` command sent did not contain a `maxTimeMS` field. - */ - }); + context( + '2. maxTimeMS is not set for commands sent to mongocryptd', + { requires: { mongodb: '>=4.2' } }, + () => { + /** + * This test MUST only be run against enterprise server versions 4.2 and higher. + * + * 1. Launch a mongocryptd process on 23000. + * 1. Create a MongoClient (referred to as `client`) using the URI `mongodb://localhost:23000/?timeoutMS=1000`. + * 1. Using `client`, execute the `{ ping: 1 }` command against the `admin` database. + * 1. Verify via command monitoring that the `ping` command sent did not contain a `maxTimeMS` field. + */ + + let client: MongoClient; + const mongocryptdTestPort = '23000'; + let childProcess: ChildProcess; + + beforeEach(async function () { + childProcess = spawn('mongocryptd', ['--port', mongocryptdTestPort, '--ipv6'], { + stdio: 'ignore', + detached: true + }); + + childProcess.on('error', error => console.warn(this.currentTest?.fullTitle(), error)); + client = new MongoClient(`mongodb://localhost:${mongocryptdTestPort}/?timeoutMS=1000`, { + monitorCommands: true + }); + }); + + afterEach(async function () { + await client.close(); + childProcess.kill('SIGKILL'); + sinon.restore(); + }); + + it('maxTimeMS is not set', async function () { + const commandStarted = []; + client.on('commandStarted', ev => commandStarted.push(ev)); + await client + .db('admin') + .command({ ping: 1 }) + .catch(e => squashError(e)); + expect(commandStarted).to.have.lengthOf(1); + expect(commandStarted[0].command).to.not.have.property('maxTimeMS'); + }); + } + ); + // TODO(NODE-6391): Add timeoutMS support to Explicit Encryption context.skip('3. ClientEncryption', () => { /** * Each test under this category MUST only be run against server versions 4.4 and higher. In these tests, @@ -720,6 +761,30 @@ describe('CSOT spec prose tests', function () { 'TODO(NODE-6223): Auto connect performs extra server selection. Explicit connect throws on invalid host name'; }); + it.skip("timeoutMS honored for server selection if it's lower than serverSelectionTimeoutMS", async function () { + /** + * 1. Create a MongoClient (referred to as `client`) with URI `mongodb://invalid/?timeoutMS=10&serverSelectionTimeoutMS=20`. + * 1. Using `client`, run the command `{ ping: 1 }` against the `admin` database. + * - Expect this to fail with a server selection timeout error after no more than 15ms. + */ + client = new MongoClient('mongodb://invalid/?timeoutMS=10&serverSelectionTimeoutMS=20'); + const start = now(); + + const maybeError = await client + .db('test') + .admin() + .ping() + .then( + () => null, + e => e + ); + const end = now(); + + expect(maybeError).to.be.instanceof(MongoOperationTimeoutError); + expect(end - start).to.be.lte(15); + }).skipReason = + 'TODO(NODE-6223): Auto connect performs extra server selection. Explicit connect throws on invalid host name'; + it.skip("timeoutMS honored for server selection if it's lower than serverSelectionTimeoutMS", async function () { /** * 1. Create a MongoClient (referred to as `client`) with URI `mongodb://invalid/?timeoutMS=10&serverSelectionTimeoutMS=20`. diff --git a/test/integration/client-side-operations-timeout/client_side_operations_timeout.unit.test.ts b/test/integration/client-side-operations-timeout/client_side_operations_timeout.unit.test.ts index 944d9b9604..7387099a7f 100644 --- a/test/integration/client-side-operations-timeout/client_side_operations_timeout.unit.test.ts +++ b/test/integration/client-side-operations-timeout/client_side_operations_timeout.unit.test.ts @@ -6,8 +6,22 @@ import { expect } from 'chai'; import * as sinon from 'sinon'; - -import { ConnectionPool, type MongoClient, Timeout, TimeoutContext, Topology } from '../../mongodb'; +import { setTimeout } from 'timers'; +import { TLSSocket } from 'tls'; +import { promisify } from 'util'; + +// eslint-disable-next-line @typescript-eslint/no-restricted-imports +import { StateMachine } from '../../../src/client-side-encryption/state_machine'; +import { + ConnectionPool, + CSOTTimeoutContext, + type MongoClient, + MongoOperationTimeoutError, + Timeout, + TimeoutContext, + Topology +} from '../../mongodb'; +import { createTimerSandbox } from '../../unit/timer_sandbox'; // TODO(NODE-5824): Implement CSOT prose tests describe('CSOT spec unit tests', function () { @@ -93,17 +107,83 @@ describe('CSOT spec unit tests', function () { }).skipReason = 'TODO(NODE-5682): Add CSOT support for socket read/write at the connection layer for CRUD APIs'; - context.skip('Client side encryption', function () { - context( - 'The remaining timeoutMS value should apply to HTTP requests against KMS servers for CSFLE.', - () => {} - ); + describe('Client side encryption', function () { + describe('KMS requests', function () { + const stateMachine = new StateMachine({} as any); + const request = { + addResponse: _response => {}, + status: { + type: 1, + code: 1, + message: 'notARealStatus' + }, + bytesNeeded: 500, + kmsProvider: 'notRealAgain', + endpoint: 'fake', + message: Buffer.from('foobar') + }; + + context('when StateMachine.kmsRequest() is passed a `CSOTimeoutContext`', function () { + beforeEach(async function () { + sinon.stub(TLSSocket.prototype, 'connect').callsFake(function (..._args) {}); + }); + + afterEach(async function () { + sinon.restore(); + }); + + it('the kms request times out through remainingTimeMS', async function () { + const timeoutContext = new CSOTTimeoutContext({ + timeoutMS: 500, + serverSelectionTimeoutMS: 30000 + }); + const err = await stateMachine.kmsRequest(request, timeoutContext).catch(e => e); + expect(err).to.be.instanceOf(MongoOperationTimeoutError); + expect(err.errmsg).to.equal('KMS request timed out'); + }); + }); + + context('when StateMachine.kmsRequest() is not passed a `CSOTimeoutContext`', function () { + let clock: sinon.SinonFakeTimers; + let timerSandbox: sinon.SinonSandbox; + + let sleep; + + beforeEach(async function () { + sinon.stub(TLSSocket.prototype, 'connect').callsFake(function (..._args) { + clock.tick(30000); + }); + timerSandbox = createTimerSandbox(); + clock = sinon.useFakeTimers(); + sleep = promisify(setTimeout); + }); + + afterEach(async function () { + if (clock) { + timerSandbox.restore(); + clock.restore(); + clock = undefined; + } + sinon.restore(); + }); + + it('the kms request does not timeout within 30 seconds', async function () { + const sleepingFn = async () => { + await sleep(30000); + throw Error('Slept for 30s'); + }; + + const err$ = Promise.all([stateMachine.kmsRequest(request), sleepingFn()]).catch(e => e); + clock.tick(30000); + const err = await err$; + expect(err.message).to.equal('Slept for 30s'); + }); + }); + }); - context( - 'The remaining timeoutMS value should apply to commands sent to mongocryptd as part of automatic encryption.', - () => {} - ); - }).skipReason = 'TODO(NODE-5686): Add CSOT support to client side encryption'; + // TODO(NODE-6390): Add timeoutMS support to Auto Encryption + it.skip('The remaining timeoutMS value should apply to commands sent to mongocryptd as part of automatic encryption.', () => {}); + }); context.skip('Background Connection Pooling', function () { context( diff --git a/test/unit/client-side-encryption/state_machine.test.ts b/test/unit/client-side-encryption/state_machine.test.ts index 77f3cf3a82..95bb605635 100644 --- a/test/unit/client-side-encryption/state_machine.test.ts +++ b/test/unit/client-side-encryption/state_machine.test.ts @@ -12,9 +12,17 @@ import * as tls from 'tls'; import { StateMachine } from '../../../src/client-side-encryption/state_machine'; // eslint-disable-next-line @typescript-eslint/no-restricted-imports import { Db } from '../../../src/db'; -// eslint-disable-next-line @typescript-eslint/no-restricted-imports -import { MongoClient } from '../../../src/mongo_client'; -import { Int32, Long, serialize } from '../../mongodb'; +import { + BSON, + Collection, + CSOTTimeoutContext, + Int32, + Long, + MongoClient, + serialize, + squashError +} from '../../mongodb'; +import { sleep } from '../../tools/utils'; describe('StateMachine', function () { class MockRequest implements MongoCryptKMSRequest { @@ -74,12 +82,10 @@ describe('StateMachine', function () { const options = { promoteLongs: false, promoteValues: false }; const serializedCommand = serialize(command); const stateMachine = new StateMachine({} as any); - // eslint-disable-next-line @typescript-eslint/no-empty-function - const callback = () => {}; context('when executing the command', function () { it('does not promote values', function () { - stateMachine.markCommand(clientStub, 'test.coll', serializedCommand, callback); + stateMachine.markCommand(clientStub, 'test.coll', serializedCommand); expect(runCommandStub.calledWith(command, options)).to.be.true; }); }); @@ -461,4 +467,129 @@ describe('StateMachine', function () { expect.fail('missed exception'); }); }); + + describe('CSOT', function () { + describe('#fetchKeys', function () { + const stateMachine = new StateMachine({} as any); + const client = new MongoClient('mongodb://localhost:27017'); + let findSpy; + + beforeEach(async function () { + findSpy = sinon.spy(Collection.prototype, 'find'); + }); + + afterEach(async function () { + sinon.restore(); + await client.close(); + }); + + context('when StateMachine.fetchKeys() is passed a `CSOTimeoutContext`', function () { + it('collection.find runs with its timeoutMS property set to remainingTimeMS', async function () { + const timeoutContext = new CSOTTimeoutContext({ + timeoutMS: 500, + serverSelectionTimeoutMS: 30000 + }); + await sleep(300); + await stateMachine + .fetchKeys(client, 'keyVault', BSON.serialize({ a: 1 }), timeoutContext) + .catch(e => squashError(e)); + expect(findSpy.getCalls()[0].args[1].timeoutMS).to.not.be.undefined; + expect(findSpy.getCalls()[0].args[1].timeoutMS).to.be.lessThanOrEqual(205); + }); + }); + + context('when StateMachine.fetchKeys() is not passed a `CSOTimeoutContext`', function () { + it('collection.find runs with an undefined timeoutMS property', async function () { + await stateMachine + .fetchKeys(client, 'keyVault', BSON.serialize({ a: 1 })) + .catch(e => squashError(e)); + expect(findSpy.getCalls()[0].args[1].timeoutMS).to.be.undefined; + }); + }); + }); + + describe('#markCommand', function () { + const stateMachine = new StateMachine({} as any); + const client = new MongoClient('mongodb://localhost:27017'); + let dbCommandSpy; + + beforeEach(async function () { + dbCommandSpy = sinon.spy(Db.prototype, 'command'); + }); + + afterEach(async function () { + sinon.restore(); + await client.close(); + }); + + context('when StateMachine.markCommand() is passed a `CSOTimeoutContext`', function () { + it('db.command runs with its timeoutMS property set to remainingTimeMS', async function () { + const timeoutContext = new CSOTTimeoutContext({ + timeoutMS: 500, + serverSelectionTimeoutMS: 30000 + }); + await sleep(300); + await stateMachine + .markCommand(client, 'keyVault', BSON.serialize({ a: 1 }), timeoutContext) + .catch(e => squashError(e)); + expect(dbCommandSpy.getCalls()[0].args[1].timeoutMS).to.not.be.undefined; + expect(dbCommandSpy.getCalls()[0].args[1].timeoutMS).to.be.lessThanOrEqual(205); + }); + }); + + context('when StateMachine.markCommand() is not passed a `CSOTimeoutContext`', function () { + it('db.command runs with an undefined timeoutMS property', async function () { + await stateMachine + .markCommand(client, 'keyVault', BSON.serialize({ a: 1 })) + .catch(e => squashError(e)); + expect(dbCommandSpy.getCalls()[0].args[1].timeoutMS).to.be.undefined; + }); + }); + }); + + describe('#fetchCollectionInfo', function () { + const stateMachine = new StateMachine({} as any); + const client = new MongoClient('mongodb://localhost:27017'); + let listCollectionsSpy; + + beforeEach(async function () { + listCollectionsSpy = sinon.spy(Db.prototype, 'listCollections'); + }); + + afterEach(async function () { + sinon.restore(); + await client.close(); + }); + + context( + 'when StateMachine.fetchCollectionInfo() is passed a `CSOTimeoutContext`', + function () { + it('listCollections runs with its timeoutMS property set to remainingTimeMS', async function () { + const timeoutContext = new CSOTTimeoutContext({ + timeoutMS: 500, + serverSelectionTimeoutMS: 30000 + }); + await sleep(300); + await stateMachine + .fetchCollectionInfo(client, 'keyVault', BSON.serialize({ a: 1 }), timeoutContext) + .catch(e => squashError(e)); + expect(listCollectionsSpy.getCalls()[0].args[1].timeoutMS).to.not.be.undefined; + expect(listCollectionsSpy.getCalls()[0].args[1].timeoutMS).to.be.lessThanOrEqual(205); + }); + } + ); + + context( + 'when StateMachine.fetchCollectionInfo() is not passed a `CSOTimeoutContext`', + function () { + it('listCollections runs with an undefined timeoutMS property', async function () { + await stateMachine + .fetchCollectionInfo(client, 'keyVault', BSON.serialize({ a: 1 })) + .catch(e => squashError(e)); + expect(listCollectionsSpy.getCalls()[0].args[1].timeoutMS).to.be.undefined; + }); + } + ); + }); + }); });