Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(NODE-6051): only provide expected allowed keys to libmongocrypt after fetching aws kms credentials #4057

Merged
merged 9 commits into from
Apr 4, 2024
27 changes: 17 additions & 10 deletions src/client-side-encryption/providers/aws.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,27 @@
import { getAwsCredentialProvider } from '../../deps';
import { AWSSDKCredentialProvider } from '../../cmap/auth/aws_temporary_credentials';
import { type KMSProviders } from '.';

/**
* @internal
*/
export async function loadAWSCredentials(kmsProviders: KMSProviders): Promise<KMSProviders> {
const credentialProvider = getAwsCredentialProvider();
const credentialProvider = new AWSSDKCredentialProvider();

if ('kModuleError' in credentialProvider) {
return kmsProviders;
}
// We shouldn't ever receive a response from the AWS SDK that doesn't have a `SecretAccessKey`
// or `AccessKeyId`. However, TS says these fields are optional. We provide empty strings
// and let libmongocrypt error if we're unable to fetch the required keys.
const {
SecretAccessKey = '',
AccessKeyId = '',
Token
} = await credentialProvider.getCredentials();
const aws: NonNullable<KMSProviders['aws']> = {
secretAccessKey: SecretAccessKey,
accessKeyId: AccessKeyId
};
// the AWS session token is only required for temporary credentials so only attach it to the
// result if it's present in the response from the aws sdk
Token != null && (aws.sessionToken = Token);

const { fromNodeProviderChain } = credentialProvider;
const provider = fromNodeProviderChain();
// The state machine is the only place calling this so it will
// catch if there is a rejection here.
const aws = await provider();
return { ...kmsProviders, aws };
}
73 changes: 60 additions & 13 deletions test/integration/auth/mongodb_aws.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,26 @@ import * as http from 'http';
import { performance } from 'perf_hooks';
import * as sinon from 'sinon';

// eslint-disable-next-line @typescript-eslint/no-restricted-imports
import { refreshKMSCredentials } from '../../../src/client-side-encryption/providers';
import {
AWSTemporaryCredentialProvider,
MongoAWSError,
type MongoClient,
MongoDBAWS,
MongoMissingCredentialsError,
MongoServerError
MongoServerError,
setDifference
} from '../../mongodb';

function awsSdk() {
try {
return require('@aws-sdk/credential-providers');
} catch {
return null;
}
}
const isMongoDBAWSAuthEnvironment = (process.env.MONGODB_URI ?? '').includes('MONGODB-AWS');

describe('MONGODB-AWS', function () {
let awsSdkPresent;
let client: MongoClient;

beforeEach(function () {
const MONGODB_URI = process.env.MONGODB_URI;
if (!MONGODB_URI || MONGODB_URI.indexOf('MONGODB-AWS') === -1) {
if (!isMongoDBAWSAuthEnvironment) {
this.currentTest.skipReason = 'requires MONGODB_URI to contain MONGODB-AWS auth mechanism';
return this.skip();
}
Expand All @@ -39,7 +35,7 @@ describe('MONGODB-AWS', function () {
`Always inform the AWS tests if they run with or without the SDK (MONGODB_AWS_SDK=${MONGODB_AWS_SDK})`
).to.include(MONGODB_AWS_SDK);

awsSdkPresent = !!awsSdk();
awsSdkPresent = AWSTemporaryCredentialProvider.isAWSSDKInstalled;
expect(
awsSdkPresent,
MONGODB_AWS_SDK === 'true'
Expand Down Expand Up @@ -244,8 +240,10 @@ describe('MONGODB-AWS', function () {

const envCheck = () => {
const { AWS_WEB_IDENTITY_TOKEN_FILE = '' } = process.env;
credentialProvider = awsSdk();
return AWS_WEB_IDENTITY_TOKEN_FILE.length === 0 || credentialProvider == null;
return (
AWS_WEB_IDENTITY_TOKEN_FILE.length === 0 ||
!AWSTemporaryCredentialProvider.isAWSSDKInstalled
);
};

beforeEach(function () {
Expand All @@ -255,6 +253,9 @@ describe('MONGODB-AWS', function () {
return this.skip();
}

// @ts-expect-error We intentionally access a protected variable.
credentialProvider = AWSTemporaryCredentialProvider.awsSDK;

storedEnv = process.env;
if (test.env.AWS_STS_REGIONAL_ENDPOINTS === undefined) {
delete process.env.AWS_STS_REGIONAL_ENDPOINTS;
Expand Down Expand Up @@ -324,3 +325,49 @@ describe('MONGODB-AWS', function () {
}
});
});

describe('AWS KMS Credential Fetching', function () {
context('when the AWS SDK is not installed', function () {
beforeEach(function () {
this.currentTest.skipReason = !isMongoDBAWSAuthEnvironment
? 'Test must run in an AWS auth testing environment'
: AWSTemporaryCredentialProvider.isAWSSDKInstalled
? 'This test must run in an environment where the AWS SDK is not installed.'
: undefined;
this.currentTest?.skipReason && this.skip();
});
it('fetching AWS KMS credentials throws an error', async function () {
const error = await refreshKMSCredentials({ aws: {} }).catch(e => e);
expect(error).to.be.instanceOf(MongoAWSError);
});
});

context('when the AWS SDK is installed', function () {
beforeEach(function () {
this.currentTest.skipReason = !isMongoDBAWSAuthEnvironment
? 'Test must run in an AWS auth testing environment'
: !AWSTemporaryCredentialProvider.isAWSSDKInstalled
? 'This test must run in an environment where the AWS SDK is installed.'
: undefined;
this.currentTest?.skipReason && this.skip();
});
it('KMS credentials are successfully fetched.', async function () {
const { aws } = await refreshKMSCredentials({ aws: {} });

expect(aws).to.have.property('accessKeyId');
expect(aws).to.have.property('secretAccessKey');
});

it('does not return any extra keys for the `aws` credential provider', async function () {
const { aws } = await refreshKMSCredentials({ aws: {} });

const keys = new Set(Object.keys(aws ?? {}));
const allowedKeys = ['accessKeyId', 'secretAccessKey', 'sessionToken'];

expect(
Array.from(setDifference(keys, allowedKeys)),
'received an unexpected key in the response refreshing KMS credentials'
).to.deep.equal([]);
});
});
});
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { expect } from 'chai';
import * as dns from 'dns';
import { once } from 'events';
import { coerce } from 'semver';
import { satisfies } from 'semver';
import * as sinon from 'sinon';

import {
Expand Down Expand Up @@ -51,11 +51,9 @@ describe('Polling Srv Records for Mongos Discovery', () => {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const test = this.currentTest!;

const { major } = coerce(process.version);
test.skipReason =
major === 18 || major === 20
? 'TODO(NODE-5666): fix failing unit tests on Node18'
: undefined;
test.skipReason = satisfies(process.version, '>=18.0.0')
? `TODO(NODE-5666): fix failing unit tests on Node18 (Running with Nodejs ${process.version})`
: undefined;

if (test.skipReason) this.skip();
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import {
} from '../../../../src/client-side-encryption/providers/azure';
// eslint-disable-next-line @typescript-eslint/no-restricted-imports
import * as utils from '../../../../src/client-side-encryption/providers/utils';
// eslint-disable-next-line @typescript-eslint/no-restricted-imports
import { AWSSDKCredentialProvider } from '../../../../src/cmap/auth/aws_temporary_credentials';
import * as requirements from '../requirements.helper';

const originalAccessKeyId = process.env.AWS_ACCESS_KEY_ID;
Expand Down Expand Up @@ -154,25 +156,25 @@ describe('#refreshKMSCredentials', function () {
});
});

context('when the sdk is not installed', function () {
W-A-James marked this conversation as resolved.
Show resolved Hide resolved
const kmsProviders = {
local: {
key: Buffer.alloc(96)
},
aws: {}
};

before(function () {
if (requirements.credentialProvidersInstalled.aws && this.currentTest) {
this.currentTest.skipReason = 'Credentials will be loaded when sdk present';
this.currentTest.skip();
return;
}
context('when the AWS SDK returns unknown fields', function () {
beforeEach(() => {
sinon.stub(AWSSDKCredentialProvider.prototype, 'getCredentials').resolves({
Token: 'example',
SecretAccessKey: 'example',
AccessKeyId: 'example',
Expiration: new Date()
});
});

it('does not refresh credentials', async function () {
const providers = await refreshKMSCredentials(kmsProviders);
expect(providers).to.deep.equal(kmsProviders);
afterEach(() => sinon.restore());
it('only returns fields libmongocrypt expects', async function () {
const credentials = await refreshKMSCredentials({ aws: {} });
expect(credentials).to.deep.equal({
aws: {
accessKeyId: accessKey,
secretAccessKey: secretKey,
sessionToken: sessionToken
}
});
});
});
});
Expand Down
7 changes: 3 additions & 4 deletions test/unit/connection_string.spec.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { coerce } from 'semver';
import { satisfies } from 'semver';

import { loadSpecTests } from '../spec';
import { executeUriValidationTest } from '../tools/uri_spec_runner';
Expand All @@ -15,14 +15,13 @@ describe('Connection String spec tests', function () {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const test = this.currentTest!;

const { major } = coerce(process.version);
const skippedTests = [
'Invalid port (zero) with IP literal',
'Invalid port (zero) with hostname'
];
test.skipReason =
major === 20 && skippedTests.includes(test.title)
? 'TODO(NODE-5666): fix failing unit tests on Node18'
satisfies(process.version, '>=20.0.0') && skippedTests.includes(test.title)
? 'TODO(NODE-5666): fix failing unit tests on Node20+'
: undefined;

if (test.skipReason) this.skip();
Expand Down
5 changes: 2 additions & 3 deletions test/unit/sdam/monitor.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { once } from 'node:events';
import * as net from 'node:net';

import { expect } from 'chai';
import { coerce } from 'semver';
import { satisfies } from 'semver';
import * as sinon from 'sinon';
import { setTimeout } from 'timers';
import { setTimeout as setTimeoutPromise } from 'timers/promises';
Expand Down Expand Up @@ -57,7 +57,6 @@ describe('monitoring', function () {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const test = this.currentTest!;

const { major } = coerce(process.version);
const failingTests = [
'should connect and issue an initial server check',
'should ignore attempts to connect when not already closed',
Expand All @@ -67,7 +66,7 @@ describe('monitoring', function () {
'correctly returns the mean of the heartbeat durations'
];
test.skipReason =
(major === 18 || major === 20) && failingTests.includes(test.title)
satisfies(process.version, '>=18.0.0') && failingTests.includes(test.title)
? 'TODO(NODE-5666): fix failing unit tests on Node18'
: undefined;

Expand Down
10 changes: 4 additions & 6 deletions test/unit/sdam/topology.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { expect } from 'chai';
import { once } from 'events';
import * as net from 'net';
import { type AddressInfo } from 'net';
import { coerce, type SemVer } from 'semver';
import { satisfies } from 'semver';
import * as sinon from 'sinon';
import { clearTimeout } from 'timers';

Expand Down Expand Up @@ -284,11 +284,9 @@ describe('Topology (unit)', function () {
it('should encounter a server selection timeout on garbled server responses', function () {
const test = this.test;

const { major } = coerce(process.version) as SemVer;
test.skipReason =
major === 18 || major === 20
? 'TODO(NODE-5666): fix failing unit tests on Node18'
: undefined;
test.skipReason = satisfies(process.version, '>=18.0.0')
? 'TODO(NODE-5666): fix failing unit tests on Node18'
: undefined;

if (test.skipReason) this.skip();

Expand Down
Loading