Skip to content

Commit

Permalink
introduce KMSCredentialProvider abstraction
Browse files Browse the repository at this point in the history
  • Loading branch information
baileympearson committed Mar 26, 2024
1 parent 46e0648 commit 00a4755
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 31 deletions.
16 changes: 9 additions & 7 deletions src/client-side-encryption/auto_encrypter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import { MongoDBCollectionNamespace } from '../utils';
import * as cryptoCallbacks from './crypto_callbacks';
import { MongoCryptInvalidArgumentError } from './errors';
import { MongocryptdManager } from './mongocryptd_manager';
import { type KMSProviders, refreshKMSCredentials } from './providers';
import { KMSCredentialProvider, type KMSProviders } from './providers';
import { type CSFLEKMSTlsOptions, StateMachine } from './state_machine';

/** @public */
Expand Down Expand Up @@ -233,7 +233,6 @@ export class AutoEncrypter {
_metaDataClient: MongoClient;
_proxyOptions: ProxyOptions;
_tlsOptions: CSFLEKMSTlsOptions;
_kmsProviders: KMSProviders;
_bypassMongocryptdAndCryptShared: boolean;
_contextCounter: number;

Expand All @@ -252,6 +251,7 @@ export class AutoEncrypter {
* fields were decrypted.
*/
[kDecorateResult] = false;
_credentialProvider: KMSCredentialProvider;

/** @internal */
static getMongoCrypt(): MongoCryptConstructor {
Expand Down Expand Up @@ -319,7 +319,7 @@ export class AutoEncrypter {
this._metaDataClient = options.metadataClient || client;
this._proxyOptions = options.proxyOptions || {};
this._tlsOptions = options.tlsOptions || {};
this._kmsProviders = options.kmsProviders || {};
const kmsProviders = options.kmsProviders || {};

const mongoCryptOptions: MongoCryptOptions = {
cryptoCallbacks
Expand All @@ -336,9 +336,9 @@ export class AutoEncrypter {
: (serialize(options.encryptedFieldsMap) as Buffer);
}

mongoCryptOptions.kmsProviders = !Buffer.isBuffer(this._kmsProviders)
? (serialize(this._kmsProviders) as Buffer)
: this._kmsProviders;
mongoCryptOptions.kmsProviders = !Buffer.isBuffer(kmsProviders)
? (serialize(kmsProviders) as Buffer)
: kmsProviders;

if (options.options?.logger) {
mongoCryptOptions.logger = options.options.logger;
Expand Down Expand Up @@ -389,6 +389,8 @@ export class AutoEncrypter {

this._mongocryptdClient = new MongoClient(this._mongocryptdManager.uri, clientOptions);
}

this._credentialProvider = new KMSCredentialProvider(kmsProviders);
}

/**
Expand Down Expand Up @@ -502,7 +504,7 @@ export class AutoEncrypter {
* the original ones.
*/
async askForKMSCredentials(): Promise<KMSProviders> {
return refreshKMSCredentials(this._kmsProviders);
return this._credentialProvider.refreshCredentials();
}

/**
Expand Down
18 changes: 9 additions & 9 deletions src/client-side-encryption/client_encryption.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ import {
} from './errors';
import {
type ClientEncryptionDataKeyProvider,
type KMSProviders,
refreshKMSCredentials
KMSCredentialProvider,
type KMSProviders
} from './providers/index';
import { type CSFLEKMSTlsOptions, StateMachine } from './state_machine';

Expand Down Expand Up @@ -61,8 +61,7 @@ export class ClientEncryption {
/** @internal */
_tlsOptions: CSFLEKMSTlsOptions;
/** @internal */
_kmsProviders: KMSProviders;

_credentialProvider: KMSCredentialProvider;
/** @internal */
_mongoCrypt: MongoCrypt;

Expand Down Expand Up @@ -107,7 +106,7 @@ export class ClientEncryption {
this._client = client;
this._proxyOptions = options.proxyOptions ?? {};
this._tlsOptions = options.tlsOptions ?? {};
this._kmsProviders = options.kmsProviders || {};
const kmsProviders = options.kmsProviders || {};

if (options.keyVaultNamespace == null) {
throw new MongoCryptInvalidArgumentError('Missing required option `keyVaultNamespace`');
Expand All @@ -116,15 +115,16 @@ export class ClientEncryption {
const mongoCryptOptions: MongoCryptOptions = {
...options,
cryptoCallbacks,
kmsProviders: !Buffer.isBuffer(this._kmsProviders)
? (serialize(this._kmsProviders) as Buffer)
: this._kmsProviders
kmsProviders: !Buffer.isBuffer(kmsProviders)
? (serialize(kmsProviders) as Buffer)
: kmsProviders
};

this._keyVaultNamespace = options.keyVaultNamespace;
this._keyVaultClient = options.keyVaultClient || client;
const MongoCrypt = ClientEncryption.getMongoCrypt();
this._mongoCrypt = new MongoCrypt(mongoCryptOptions);
this._credentialProvider = new KMSCredentialProvider(kmsProviders);
}

/**
Expand Down Expand Up @@ -654,7 +654,7 @@ export class ClientEncryption {
* the original ones.
*/
async askForKMSCredentials(): Promise<KMSProviders> {
return refreshKMSCredentials(this._kmsProviders);
return this._credentialProvider.refreshCredentials();
}

static get libmongocryptVersion() {
Expand Down
35 changes: 20 additions & 15 deletions src/client-side-encryption/providers/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -144,25 +144,30 @@ export function isEmptyCredentials(
}

/**
* Load cloud provider credentials for the user provided KMS providers.
* Credentials will only attempt to get loaded if they do not exist
* and no existing credentials will get overwritten.
*
* @internal
*/
export async function refreshKMSCredentials(kmsProviders: KMSProviders): Promise<KMSProviders> {
let finalKMSProviders = kmsProviders;
export class KMSCredentialProvider {
constructor(private readonly kmsProviders: KMSProviders) {}

if (isEmptyCredentials('aws', kmsProviders)) {
finalKMSProviders = await loadAWSCredentials(finalKMSProviders);
}
/**
* Load cloud provider credentials for the user provided KMS providers.
* Credentials will only attempt to get loaded if they do not exist
* and no existing credentials will get overwritten.
*/
async refreshCredentials() {
let finalKMSProviders = this.kmsProviders;

if (isEmptyCredentials('gcp', kmsProviders)) {
finalKMSProviders = await loadGCPCredentials(finalKMSProviders);
}
if (isEmptyCredentials('aws', this.kmsProviders)) {
finalKMSProviders = await loadAWSCredentials(finalKMSProviders);
}

if (isEmptyCredentials('gcp', this.kmsProviders)) {
finalKMSProviders = await loadGCPCredentials(finalKMSProviders);
}

if (isEmptyCredentials('azure', kmsProviders)) {
finalKMSProviders = await loadAzureCredentials(finalKMSProviders);
if (isEmptyCredentials('azure', this.kmsProviders)) {
finalKMSProviders = await loadAzureCredentials(finalKMSProviders);
}
return finalKMSProviders;
}
return finalKMSProviders;
}

0 comments on commit 00a4755

Please sign in to comment.