Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(NODE-5939): Implement 6.x: cache the AWS credentials provider in the MONGODB-AWS auth logic #3991

Merged
merged 21 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
e50a7ea
feat(NODE-5616): cache the AWS credentials provider in the MONGODB-AW…
alenakhineika Feb 13, 2024
0f0143a
test: test aws instance per client
alenakhineika Feb 13, 2024
108967f
test: try with cashed provider
alenakhineika Feb 14, 2024
197e109
test: try to reset counter per test
alenakhineika Feb 14, 2024
c036433
test: check existing assertions
alenakhineika Feb 14, 2024
a9df63c
test: try without static
alenakhineika Feb 14, 2024
7d1dd7e
test: try to move client creation
alenakhineika Feb 15, 2024
3307cb1
test: move n check to a separate test
alenakhineika Feb 15, 2024
02e5b4b
Merge remote-tracking branch 'origin/main' into NODE-5616-cache-aws-c…
alenakhineika Feb 15, 2024
f67e257
refactor: store getAuthProvider in connection options
alenakhineika Feb 15, 2024
77d2066
Merge branch 'main' into NODE-5616-cache-aws-credentials
alenakhineika Feb 15, 2024
53c70a3
Merge branch 'main' into NODE-5616-cache-aws-credentials
alenakhineika Feb 16, 2024
480a49d
fix: import from mongodb
alenakhineika Feb 16, 2024
fbb5537
refactor: try with factory
alenakhineika Feb 16, 2024
40ea9a5
fix: save provider per client
alenakhineika Feb 16, 2024
74e9c80
refactor: clean up
alenakhineika Feb 16, 2024
3e5999c
refactor: move client auth providers to a separate module
alenakhineika Feb 19, 2024
d5e3140
Merge branch 'main' into NODE-5616-cache-aws-credentials
alenakhineika Feb 19, 2024
2cd3d60
docs: add comments
alenakhineika Feb 19, 2024
5c00cb6
Merge branch 'NODE-5616-cache-aws-credentials' of github.com:mongodb/…
alenakhineika Feb 19, 2024
deeacf5
Merge branch 'main' into NODE-5616-cache-aws-credentials
durran Feb 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/cmap/auth/auth_provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ export class AuthContext {
}
}

/**
* Provider used during authentication.
* @internal
*/
export abstract class AuthProvider {
/**
* Prepare the handshake document before the initial handshake.
Expand Down
115 changes: 60 additions & 55 deletions src/cmap/auth/mongodb_aws.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { promisify } from 'util';

import type { Binary, BSONSerializeOptions } from '../../bson';
import * as BSON from '../../bson';
import { aws4, getAwsCredentialProvider } from '../../deps';
import { aws4, type AWSCredentials, getAwsCredentialProvider } from '../../deps';
import {
MongoAWSError,
MongoCompatibilityError,
Expand Down Expand Up @@ -57,12 +57,42 @@ interface AWSSaslContinuePayload {
}

export class MongoDBAWS extends AuthProvider {
static credentialProvider: ReturnType<typeof getAwsCredentialProvider> | null = null;
static credentialProvider: ReturnType<typeof getAwsCredentialProvider>;
provider?: () => Promise<AWSCredentials>;
randomBytesAsync: (size: number) => Promise<Buffer>;

constructor() {
super();
this.randomBytesAsync = promisify(crypto.randomBytes);
MongoDBAWS.credentialProvider ??= getAwsCredentialProvider();

let { AWS_STS_REGIONAL_ENDPOINTS = '', AWS_REGION = '' } = process.env;
AWS_STS_REGIONAL_ENDPOINTS = AWS_STS_REGIONAL_ENDPOINTS.toLowerCase();
AWS_REGION = AWS_REGION.toLowerCase();

/** The option setting should work only for users who have explicit settings in their environment, the driver should not encode "defaults" */
const awsRegionSettingsExist =
AWS_REGION.length !== 0 && AWS_STS_REGIONAL_ENDPOINTS.length !== 0;

/**
* If AWS_STS_REGIONAL_ENDPOINTS is set to regional, users are opting into the new behavior of respecting the region settings
*
* If AWS_STS_REGIONAL_ENDPOINTS is set to legacy, then "old" regions need to keep using the global setting.
* Technically the SDK gets this wrong, it reaches out to 'sts.us-east-1.amazonaws.com' when it should be 'sts.amazonaws.com'.
* That is not our bug to fix here. We leave that up to the SDK.
*/
const useRegionalSts =
AWS_STS_REGIONAL_ENDPOINTS === 'regional' ||
(AWS_STS_REGIONAL_ENDPOINTS === 'legacy' && !LEGACY_REGIONS.has(AWS_REGION));

if ('fromNodeProviderChain' in MongoDBAWS.credentialProvider) {
this.provider =
awsRegionSettingsExist && useRegionalSts
? MongoDBAWS.credentialProvider.fromNodeProviderChain({
clientConfig: { region: AWS_REGION }
})
: MongoDBAWS.credentialProvider.fromNodeProviderChain();
}
}

override async auth(authContext: AuthContext): Promise<void> {
Expand All @@ -83,7 +113,7 @@ export class MongoDBAWS extends AuthProvider {
}

if (!authContext.credentials.username) {
authContext.credentials = await makeTempCredentials(authContext.credentials);
authContext.credentials = await makeTempCredentials(authContext.credentials, this.provider);
}

const { credentials } = authContext;
Expand Down Expand Up @@ -181,7 +211,10 @@ interface AWSTempCredentials {
Expiration?: Date;
}

async function makeTempCredentials(credentials: MongoCredentials): Promise<MongoCredentials> {
async function makeTempCredentials(
credentials: MongoCredentials,
provider?: () => Promise<AWSCredentials>
): Promise<MongoCredentials> {
function makeMongoCredentialsFromAWSTemp(creds: AWSTempCredentials) {
if (!creds.AccessKeyId || !creds.SecretAccessKey || !creds.Token) {
throw new MongoMissingCredentialsError('Could not obtain temporary MONGODB-AWS credentials');
Expand All @@ -198,11 +231,31 @@ async function makeTempCredentials(credentials: MongoCredentials): Promise<Mongo
});
}

MongoDBAWS.credentialProvider ??= getAwsCredentialProvider();

// Check if the AWS credential provider from the SDK is present. If not,
// use the old method.
if ('kModuleError' in MongoDBAWS.credentialProvider) {
if (provider && !('kModuleError' in MongoDBAWS.credentialProvider)) {
/*
* Creates a credential provider that will attempt to find credentials from the
* following sources (listed in order of precedence):
*
* - Environment variables exposed via process.env
* - SSO credentials from token cache
* - Web identity token credentials
* - Shared credentials and config ini files
* - The EC2/ECS Instance Metadata Service
*/
try {
const creds = await provider();
return makeMongoCredentialsFromAWSTemp({
AccessKeyId: creds.accessKeyId,
SecretAccessKey: creds.secretAccessKey,
Token: creds.sessionToken,
Expiration: creds.expiration
});
} catch (error) {
throw new MongoAWSError(error.message);
}
} else {
// If the environment variable AWS_CONTAINER_CREDENTIALS_RELATIVE_URI
// is set then drivers MUST assume that it was set by an AWS ECS agent
if (process.env.AWS_CONTAINER_CREDENTIALS_RELATIVE_URI) {
Expand Down Expand Up @@ -232,54 +285,6 @@ async function makeTempCredentials(credentials: MongoCredentials): Promise<Mongo
});

return makeMongoCredentialsFromAWSTemp(creds);
} else {
let { AWS_STS_REGIONAL_ENDPOINTS = '', AWS_REGION = '' } = process.env;
AWS_STS_REGIONAL_ENDPOINTS = AWS_STS_REGIONAL_ENDPOINTS.toLowerCase();
AWS_REGION = AWS_REGION.toLowerCase();

/** The option setting should work only for users who have explicit settings in their environment, the driver should not encode "defaults" */
const awsRegionSettingsExist =
AWS_REGION.length !== 0 && AWS_STS_REGIONAL_ENDPOINTS.length !== 0;

/**
* If AWS_STS_REGIONAL_ENDPOINTS is set to regional, users are opting into the new behavior of respecting the region settings
*
* If AWS_STS_REGIONAL_ENDPOINTS is set to legacy, then "old" regions need to keep using the global setting.
* Technically the SDK gets this wrong, it reaches out to 'sts.us-east-1.amazonaws.com' when it should be 'sts.amazonaws.com'.
* That is not our bug to fix here. We leave that up to the SDK.
*/
const useRegionalSts =
AWS_STS_REGIONAL_ENDPOINTS === 'regional' ||
(AWS_STS_REGIONAL_ENDPOINTS === 'legacy' && !LEGACY_REGIONS.has(AWS_REGION));

const provider =
awsRegionSettingsExist && useRegionalSts
? MongoDBAWS.credentialProvider.fromNodeProviderChain({
clientConfig: { region: AWS_REGION }
})
: MongoDBAWS.credentialProvider.fromNodeProviderChain();

/*
* Creates a credential provider that will attempt to find credentials from the
* following sources (listed in order of precedence):
*
* - Environment variables exposed via process.env
* - SSO credentials from token cache
* - Web identity token credentials
* - Shared credentials and config ini files
* - The EC2/ECS Instance Metadata Service
*/
try {
const creds = await provider();
return makeMongoCredentialsFromAWSTemp({
AccessKeyId: creds.accessKeyId,
SecretAccessKey: creds.secretAccessKey,
Token: creds.sessionToken,
Expiration: creds.expiration
});
} catch (error) {
throw new MongoAWSError(error.message);
}
}
}

Expand Down
39 changes: 13 additions & 26 deletions src/cmap/connect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,10 @@ import {
MongoRuntimeError,
needsRetryableWriteLabel
} from '../error';
import { type MongoClientAuthProviders } from '../mongo_client_auth_providers';
import { HostAddress, ns, promiseWithResolvers } from '../utils';
import { AuthContext, type AuthProvider } from './auth/auth_provider';
import { GSSAPI } from './auth/gssapi';
import { MongoCR } from './auth/mongocr';
import { MongoDBAWS } from './auth/mongodb_aws';
import { MongoDBOIDC } from './auth/mongodb_oidc';
import { Plain } from './auth/plain';
import { AuthContext } from './auth/auth_provider';
import { AuthMechanism } from './auth/providers';
import { ScramSHA1, ScramSHA256 } from './auth/scram';
import { X509 } from './auth/x509';
import {
type CommandOptions,
Connection,
Expand All @@ -40,18 +34,6 @@ import {
MIN_SUPPORTED_WIRE_VERSION
} from './wire_protocol/constants';

/** @internal */
export const AUTH_PROVIDERS = new Map<AuthMechanism | string, AuthProvider>([
[AuthMechanism.MONGODB_AWS, new MongoDBAWS()],
[AuthMechanism.MONGODB_CR, new MongoCR()],
[AuthMechanism.MONGODB_GSSAPI, new GSSAPI()],
[AuthMechanism.MONGODB_OIDC, new MongoDBOIDC()],
[AuthMechanism.MONGODB_PLAIN, new Plain()],
[AuthMechanism.MONGODB_SCRAM_SHA1, new ScramSHA1()],
[AuthMechanism.MONGODB_SCRAM_SHA256, new ScramSHA256()],
[AuthMechanism.MONGODB_X509, new X509()]
]);

/** @public */
export type Stream = Socket | TLSSocket;

Expand Down Expand Up @@ -111,7 +93,7 @@ export async function performInitialHandshake(
if (credentials) {
if (
!(credentials.mechanism === AuthMechanism.MONGODB_DEFAULT) &&
!AUTH_PROVIDERS.get(credentials.mechanism)
!options.authProviders.getOrCreateProvider(credentials.mechanism)
) {
throw new MongoInvalidArgumentError(`AuthMechanism '${credentials.mechanism}' not supported`);
}
Expand All @@ -120,7 +102,7 @@ export async function performInitialHandshake(
const authContext = new AuthContext(conn, credentials, options);
conn.authContext = authContext;

const handshakeDoc = await prepareHandshakeDocument(authContext);
const handshakeDoc = await prepareHandshakeDocument(authContext, options.authProviders);

// @ts-expect-error: TODO(NODE-5141): The options need to be filtered properly, Connection options differ from Command options
const handshakeOptions: CommandOptions = { ...options };
Expand Down Expand Up @@ -166,7 +148,7 @@ export async function performInitialHandshake(
authContext.response = response;

const resolvedCredentials = credentials.resolveAuthMechanism(response);
const provider = AUTH_PROVIDERS.get(resolvedCredentials.mechanism);
const provider = options.authProviders.getOrCreateProvider(resolvedCredentials.mechanism);
if (!provider) {
throw new MongoInvalidArgumentError(
`No AuthProvider for ${resolvedCredentials.mechanism} defined.`
Expand All @@ -191,6 +173,10 @@ export async function performInitialHandshake(
conn.established = true;
}

/**
* HandshakeDocument used during authentication.
* @internal
*/
export interface HandshakeDocument extends Document {
/**
* @deprecated Use hello instead
Expand All @@ -210,7 +196,8 @@ export interface HandshakeDocument extends Document {
* This function is only exposed for testing purposes.
*/
export async function prepareHandshakeDocument(
authContext: AuthContext
authContext: AuthContext,
authProviders: MongoClientAuthProviders
): Promise<HandshakeDocument> {
const options = authContext.options;
const compressors = options.compressors ? options.compressors : [];
Expand All @@ -232,7 +219,7 @@ export async function prepareHandshakeDocument(
if (credentials.mechanism === AuthMechanism.MONGODB_DEFAULT && credentials.username) {
handshakeDoc.saslSupportedMechs = `${credentials.source}.${credentials.username}`;

const provider = AUTH_PROVIDERS.get(AuthMechanism.MONGODB_SCRAM_SHA256);
const provider = authProviders.getOrCreateProvider(AuthMechanism.MONGODB_SCRAM_SHA256);
if (!provider) {
// This auth mechanism is always present.
throw new MongoInvalidArgumentError(
Expand All @@ -241,7 +228,7 @@ export async function prepareHandshakeDocument(
}
return provider.prepare(handshakeDoc, authContext);
}
const provider = AUTH_PROVIDERS.get(credentials.mechanism);
const provider = authProviders.getOrCreateProvider(credentials.mechanism);
if (!provider) {
throw new MongoInvalidArgumentError(`No AuthProvider for ${credentials.mechanism} defined.`);
}
Expand Down
3 changes: 3 additions & 0 deletions src/cmap/connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import {
MongoWriteConcernError
} from '../error';
import type { ServerApi, SupportedNodeConnectionOptions } from '../mongo_client';
import { type MongoClientAuthProviders } from '../mongo_client_auth_providers';
import { MongoLoggableComponent, type MongoLogger, SeverityLevel } from '../mongo_logger';
import { type CancellationToken, TypedEventEmitter } from '../mongo_types';
import type { ReadPreferenceLike } from '../read_preference';
Expand Down Expand Up @@ -109,6 +110,8 @@ export interface ConnectionOptions
/** @internal */
connectionType?: any;
credentials?: MongoCredentials;
/** @internal */
authProviders: MongoClientAuthProviders;
connectTimeoutMS?: number;
tls: boolean;
noDelay?: boolean;
Expand Down
9 changes: 6 additions & 3 deletions src/cmap/connection_pool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import {
import { CancellationToken, TypedEventEmitter } from '../mongo_types';
import type { Server } from '../sdam/server';
import { type Callback, eachAsync, List, makeCounter, TimeoutController } from '../utils';
import { AUTH_PROVIDERS, connect } from './connect';
import { connect } from './connect';
import { Connection, type ConnectionEvents, type ConnectionOptions } from './connection';
import {
ConnectionCheckedInEvent,
Expand Down Expand Up @@ -622,7 +622,9 @@ export class ConnectionPool extends TypedEventEmitter<ConnectionPoolEvents> {
);
}
const resolvedCredentials = credentials.resolveAuthMechanism(connection.hello);
const provider = AUTH_PROVIDERS.get(resolvedCredentials.mechanism);
const provider = this[kServer].topology.client.s.authProviders.getOrCreateProvider(
resolvedCredentials.mechanism
);
if (!provider) {
return callback(
new MongoMissingCredentialsError(
Expand Down Expand Up @@ -700,7 +702,8 @@ export class ConnectionPool extends TypedEventEmitter<ConnectionPoolEvents> {
id: this[kConnectionCounter].next().value,
generation: this[kGeneration],
cancellationToken: this[kCancellationToken],
mongoLogger: this.mongoLogger
mongoLogger: this.mongoLogger,
authProviders: this[kServer].topology.client.s.authProviders
};

this[kPending]++;
Expand Down
4 changes: 3 additions & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ export type {
CSFLEKMSTlsOptions,
StateMachineExecutable
} from './client-side-encryption/state_machine';
export type { AuthContext } from './cmap/auth/auth_provider';
export type { AuthContext, AuthProvider } from './cmap/auth/auth_provider';
export type {
AuthMechanismProperties,
MongoCredentials,
Expand All @@ -268,6 +268,7 @@ export type {
OpResponseOptions,
WriteProtocolMessageType
} from './cmap/commands';
export type { HandshakeDocument } from './cmap/connect';
export type { LEGAL_TCP_SOCKET_OPTIONS, LEGAL_TLS_SOCKET_OPTIONS, Stream } from './cmap/connect';
export type {
CommandOptions,
Expand Down Expand Up @@ -365,6 +366,7 @@ export type {
SupportedTLSSocketOptions,
WithSessionCallback
} from './mongo_client';
export { MongoClientAuthProviders } from './mongo_client_auth_providers';
export type {
Log,
LogComponentSeveritiesClientOptions,
Expand Down
8 changes: 6 additions & 2 deletions src/mongo_client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import { MONGO_CLIENT_EVENTS } from './constants';
import { Db, type DbOptions } from './db';
import type { Encrypter } from './encrypter';
import { MongoInvalidArgumentError } from './error';
import { MongoClientAuthProviders } from './mongo_client_auth_providers';
import {
type LogComponentSeveritiesClientOptions,
type MongoDBLogWritable,
Expand Down Expand Up @@ -297,6 +298,7 @@ export interface MongoClientPrivate {
bsonOptions: BSONSerializeOptions;
namespace: MongoDBNamespace;
hasBeenClosed: boolean;
authProviders: MongoClientAuthProviders;
/**
* We keep a reference to the sessions that are acquired from the pool.
* - used to track and close all sessions in client.close() (which is non-standard behavior)
Expand All @@ -319,6 +321,7 @@ export type MongoClientEvents = Pick<TopologyEvents, (typeof MONGO_CLIENT_EVENTS
};

/** @internal */

const kOptions = Symbol('options');

/**
Expand Down Expand Up @@ -379,6 +382,7 @@ export class MongoClient extends TypedEventEmitter<MongoClientEvents> {
hasBeenClosed: false,
sessionPool: new ServerSessionPool(this),
activeSessions: new Set(),
authProviders: new MongoClientAuthProviders(),

get options() {
return client[kOptions];
Expand Down Expand Up @@ -829,10 +833,10 @@ export interface MongoOptions
proxyUsername?: string;
proxyPassword?: string;
serverMonitoringMode: ServerMonitoringMode;

/** @internal */
connectionType?: typeof Connection;

/** @internal */
authProviders: MongoClientAuthProviders;
/** @internal */
encrypter: Encrypter;
/** @internal */
Expand Down
Loading
Loading