fix(sdk): Disable concurrency on rewrap (#388)
* fix(sdk): Disable concurrency on rewrap

- Adds a new `concurrencyLimit` decrypt param, which sets up a promise pool (a thread pool, kinda)
- Defaults to 1

* remove debug logs

* simplify nested ternary
dmihalcik-virtru authored Nov 15, 2024
1 parent 8b1de24 commit beb3c06
Showing 5 changed files with 202 additions and 61 deletions.
89 changes: 89 additions & 0 deletions lib/src/concurrency.ts
@@ -0,0 +1,89 @@
type LabelledSuccess<T> = { lid: string; value: Promise<T> };
type LabelledFailure = { lid: string; e: any };

async function labelPromise<T>(label: string, promise: Promise<T>): Promise<LabelledSuccess<T>> {
try {
const value = await promise;
return { lid: label, value: Promise.resolve(value) };
} catch (e) {
throw { lid: label, e };
}
}

// Pooled variant of Promise.all; implements most of the logic of the real all,
// but with a pool size of n. Rejects on first reject, or returns a list
// of all successful responses. Operates with at most n 'active' promises at a time.
// For tracking purposes, all promises must have a unique identifier.
export async function allPool<T>(n: number, p: Record<string, Promise<T>>): Promise<Awaited<T>[]> {
const pool: Record<string, Promise<LabelledSuccess<T>>> = {};
const resolved: Awaited<T>[] = [];
for (const [id, job] of Object.entries(p)) {
// while the size of jobs to do is greater than n,
// let n jobs run and take the first one to finish out of the pool
pool[id] = labelPromise(id, job);
if (Object.keys(pool).length > n - 1) {
const promises = Object.values(pool);
try {
const { lid, value } = await Promise.race(promises);
resolved.push(await value);
delete pool[lid];
} catch (err) {
const { e } = err as LabelledFailure;
throw e;
}
}
}
try {
for (const labelled of await Promise.all(Object.values(pool))) {
resolved.push(await labelled.value);
}
} catch (err) {
if ('lid' in err && 'e' in err) {
throw err.e;
} else {
throw err;
}
}
return resolved;
}

// Pooled variant of Promise.any; implements most of the logic of the real any,
// but with a pool size of n, and returns the first successful promise,
// operating with at most n 'active' promises at a time.
export async function anyPool<T>(n: number, p: Record<string, Promise<T>>): Promise<Awaited<T>> {
const pool: Record<string, Promise<LabelledSuccess<T>>> = {};
const rejections = [];
for (const [id, job] of Object.entries(p)) {
// while the size of jobs to do is greater than n,
// let n jobs run and take the first one to finish out of the pool
pool[id] = labelPromise(id, job);
if (Object.keys(pool).length > n - 1) {
const promises = Object.values(pool);
try {
const { value } = await Promise.race(promises);
return await value;
} catch (error) {
const { lid, e } = error;
rejections.push(e);
delete pool[lid];
}
}
}
try {
const { value } = await Promise.any(Object.values(pool));
return await value;
} catch (errors) {
if (errors instanceof AggregateError) {
for (const error of errors.errors) {
if ('lid' in error && 'e' in error) {
rejections.push(error.e);
} else {
rejections.push(error);
}
}
} else {
rejections.push(errors);
}
}
throw new AggregateError(rejections);
}
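
For orientation, a small usage sketch (not part of the commit; the labels, delays, and import path are illustrative) showing how the two helpers behave. Every job needs a unique key, which the pool uses to track which promise settled.

import { allPool, anyPool } from './lib/src/concurrency.js'; // adjust the path to where the module lives in your build

// Sketch helper: resolve `value` after `ms` milliseconds.
const delayed = <T>(value: T, ms: number): Promise<T> =>
  new Promise((resolve) => setTimeout(() => resolve(value), ms));

async function demo() {
  // Like Promise.all, but at most 2 results are awaited at a time;
  // rejects on the first failure, otherwise collects results in completion order.
  const all = await allPool(2, {
    a: delayed(1, 30),
    b: delayed(2, 10),
    c: delayed(3, 20),
  }); // [2, 3, 1]

  // Like Promise.any, but pooled: resolves with the first success and only
  // throws an AggregateError once every job has rejected.
  const first = await anyPool(2, {
    x: Promise.reject(new Error('KAS unreachable')),
    y: delayed(42, 10),
  }); // 42

  console.log(all, first);
}

demo().catch(console.error);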
6 changes: 6 additions & 0 deletions lib/tdf3/src/client/builders.ts
@@ -519,6 +519,7 @@ export type DecryptParams = {
keyMiddleware?: DecryptKeyMiddleware;
streamMiddleware?: DecryptStreamMiddleware;
assertionVerificationKeys?: AssertionVerificationKeys;
concurrencyLimit?: number;
noVerifyAssertions?: boolean;
};

@@ -685,6 +686,11 @@ class DecryptParamsBuilder {
return freeze({ ..._params });
}

withConcurrencyLimit(limit: number): DecryptParamsBuilder {
this._params.concurrencyLimit = limit;
return this;
}

/**
* Generate a parameters object in the form expected by <code>{@link Client#decrypt|decrypt}</code>.
* <br/><br/>
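
A hedged sketch of how the new builder method might be called from application code. Only withConcurrencyLimit() and build() appear in this diff; the import paths, the withArrayBufferSource() source setter, and the Client usage are assumptions about the surrounding SDK API.

import { DecryptParamsBuilder } from './lib/tdf3/src/client/builders.js'; // path as laid out in this repo; adjust for your install
import { Client } from './lib/tdf3/src/client/index.js';

// Hypothetical helper: decrypt a payload while allowing up to 4 KAS rewrap requests in flight at once.
async function decryptWithLimit(client: Client, ciphertext: ArrayBuffer) {
  const params = await new DecryptParamsBuilder()
    .withArrayBufferSource(ciphertext) // assumed setter; use whichever source method the builder exposes
    .withConcurrencyLimit(4) // defaults to 1 when omitted
    .build(); // await is harmless if build() happens to be synchronous
  return client.decrypt(params);
}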
2 changes: 2 additions & 0 deletions lib/tdf3/src/client/index.ts
@@ -562,6 +562,7 @@ export class Client {
streamMiddleware = async (stream: DecoratedReadableStream) => stream,
assertionVerificationKeys,
noVerifyAssertions,
concurrencyLimit = 1,
}: DecryptParams): Promise<DecoratedReadableStream> {
const dpopKeys = await this.dpopKeys;
let entityObject;
@@ -587,6 +588,7 @@
allowList: this.allowedKases,
authProvider: this.authProvider,
chunker,
concurrencyLimit,
cryptoService: this.cryptoService,
dpopKeys,
entity: entityObject,
101 changes: 40 additions & 61 deletions lib/tdf3/src/tdf.ts
@@ -65,6 +65,7 @@ import PolicyObject from '../../src/tdf/PolicyObject.js';
import { type CryptoService, type DecryptResult } from './crypto/declarations.js';
import { CentralDirectory } from './utils/zip-reader.js';
import { SymmetricCipher } from './ciphers/symmetric-cipher-base.js';
import { allPool, anyPool } from '../../src/concurrency.js';

// TODO: input validation on manifest JSON
const DEFAULT_SEGMENT_SIZE = 1024 * 1024;
@@ -163,6 +164,7 @@ export type DecryptConfiguration = {
fileStreamServiceWorker?: string;
assertionVerificationKeys?: AssertionVerificationKeys;
noVerifyAssertions?: boolean;
concurrencyLimit?: number;
};

export type UpsertConfiguration = {
Expand Down Expand Up @@ -904,17 +906,24 @@ export function splitLookupTableFactory(
return splitPotentials;
}

type RewrapResponseData = {
key: Uint8Array;
metadata: Record<string, unknown>;
};

async function unwrapKey({
manifest,
allowedKases,
authProvider,
dpopKeys,
concurrencyLimit,
entity,
cryptoService,
}: {
manifest: Manifest;
allowedKases: OriginAllowList;
authProvider: AuthProvider | AppIdAuthProvider;
concurrencyLimit?: number;
dpopKeys: CryptoKeyPair;
entity: EntityObject | undefined;
cryptoService: CryptoService;
@@ -928,7 +937,7 @@ async function unwrapKey({
const splitPotentials = splitLookupTableFactory(keyAccess, allowedKases);
const isAppIdProvider = authProvider && isAppIdProviderCheck(authProvider);

async function tryKasRewrap(keySplitInfo: KeyAccessObject) {
async function tryKasRewrap(keySplitInfo: KeyAccessObject): Promise<RewrapResponseData> {
const url = `${keySplitInfo.url}/${isAppIdProvider ? '' : 'v2/'}rewrap`;
const ephemeralEncryptionKeys = await cryptoService.cryptoToPemPair(
await cryptoService.generateKeyPair()
@@ -982,77 +991,47 @@ async function unwrapKey({
};
}

// Get unique split IDs to determine if we have an OR or AND condition
const splitIds = new Set(Object.keys(splitPotentials));

// If we have only one split ID, it's an OR condition
if (splitIds.size === 1) {
const [splitId] = splitIds;
let poolSize = 1;
if (concurrencyLimit !== undefined && concurrencyLimit > 1) {
poolSize = concurrencyLimit;
}
const splitPromises: Record<string, Promise<RewrapResponseData>> = {};
for (const splitId of Object.keys(splitPotentials)) {
const potentials = splitPotentials[splitId];

try {
// OR condition: Try all KAS servers for this split, take first success
const result = await Promise.any(
Object.values(potentials).map(async (keySplitInfo) => {
try {
return await tryKasRewrap(keySplitInfo);
} catch (e) {
// Rethrow with more context
throw handleRewrapError(e as Error | AxiosError);
}
})
if (!potentials || !Object.keys(potentials).length) {
throw new UnsafeUrlError(
`Unreconstructable key - no valid KAS found for split ${JSON.stringify(splitId)}`,
''
);

const reconstructedKey = keyMerge([result.key]);
return {
reconstructedKeyBinary: Binary.fromArrayBuffer(reconstructedKey),
metadata: result.metadata,
};
} catch (error) {
if (error instanceof AggregateError) {
// All KAS servers failed
throw error.errors[0]; // Throw the first error since we've already wrapped them
}
throw error;
}
} else {
// AND condition: We need successful results from all different splits
const splitResults = await Promise.all(
Object.entries(splitPotentials).map(async ([splitId, potentials]) => {
if (!potentials || !Object.keys(potentials).length) {
throw new UnsafeUrlError(
`Unreconstructable key - no valid KAS found for split ${JSON.stringify(splitId)}`,
''
);
}

const anyPromises: Record<string, Promise<RewrapResponseData>> = {};
for (const [kas, keySplitInfo] of Object.entries(potentials)) {
anyPromises[kas] = (async () => {
try {
// For each split, try all potential KAS servers until one succeeds
return await Promise.any(
Object.values(potentials).map(async (keySplitInfo) => {
try {
return await tryKasRewrap(keySplitInfo);
} catch (e) {
throw handleRewrapError(e as Error | AxiosError);
}
})
);
} catch (error) {
if (error instanceof AggregateError) {
// All KAS servers for this split failed
throw error.errors[0]; // Throw the first error since we've already wrapped them
}
throw error;
return await tryKasRewrap(keySplitInfo);
} catch (e) {
throw handleRewrapError(e as Error | AxiosError);
}
})
);

})();
}
splitPromises[splitId] = anyPool(poolSize, anyPromises);
}
try {
const splitResults = await allPool(poolSize, splitPromises);
// Merge all the split keys
const reconstructedKey = keyMerge(splitResults.map((r) => r.key));
return {
reconstructedKeyBinary: Binary.fromArrayBuffer(reconstructedKey),
metadata: splitResults[0].metadata, // Use metadata from first split
};
} catch (e) {
if (e instanceof AggregateError) {
const errors = e.errors;
if (errors.length === 1) {
throw errors[0];
}
}
throw e;
}
}

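
The net effect of the tdf.ts change: the old one-split vs. multi-split branching collapses into a single "any KAS per split, all splits required" composition of the two pool helpers. Below is a minimal, self-contained sketch of that shape (hypothetical types and a stubbed rewrap call, not the SDK's real tryKasRewrap or key-merge logic); poolSize mirrors the new concurrencyLimit parameter, which the client defaults to 1.

import { allPool, anyPool } from './lib/src/concurrency.js';

type Rewrap = { key: Uint8Array };

// Stand-in for tryKasRewrap in this sketch: pretend every KAS returns a key share.
const rewrapAt = async (kasUrl: string): Promise<Rewrap> => ({
  key: new TextEncoder().encode(kasUrl),
});

// splits maps splitId -> candidate KAS URLs for that split.
async function unwrapSketch(
  splits: Record<string, string[]>,
  poolSize: number
): Promise<Uint8Array[]> {
  const splitPromises: Record<string, Promise<Rewrap>> = {};
  for (const [splitId, kasUrls] of Object.entries(splits)) {
    // OR within a split: the first KAS that answers successfully supplies the share.
    const candidates: Record<string, Promise<Rewrap>> = {};
    for (const url of kasUrls) {
      candidates[url] = rewrapAt(url);
    }
    splitPromises[splitId] = anyPool(poolSize, candidates);
  }
  // AND across splits: every split must yield a share before the key can be reassembled.
  const shares = await allPool(poolSize, splitPromises);
  return shares.map((share) => share.key);
}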
65 changes: 65 additions & 0 deletions lib/tests/mocha/unit/concurrency.spec.ts
@@ -0,0 +1,65 @@
import { allPool, anyPool } from '../../../src/concurrency.js';
import { expect } from 'chai';

describe('concurrency', () => {
for (const n of [1, 2, 3, 4]) {
describe(`allPool(${n})`, () => {
it(`should resolve all promises with a pool size of ${n}`, async () => {
const promises = {
a: Promise.resolve(1),
b: Promise.resolve(2),
c: Promise.resolve(3),
};
const result = await allPool(n, promises);
expect(result).to.have.members([1, 2, 3]);
});
it(`should reject if any promise rejects, n=${n}`, async () => {
const promises = {
a: Promise.resolve(1),
b: Promise.reject(new Error('failure')),
c: Promise.resolve(3),
};
try {
await allPool(n, promises);
} catch (e) {
expect(e).to.contain({ message: 'failure' });
}
});
});
describe(`anyPool(${n})`, () => {
it('should resolve with the first resolved promise', async () => {
const startTime = Date.now();
const promises = {
a: new Promise((resolve) => setTimeout(() => resolve(1), 500)),
b: new Promise((resolve) => setTimeout(() => resolve(2), 50)),
c: new Promise((resolve) => setTimeout(() => resolve(3), 1500)),
};
const result = await anyPool(n, promises);
const endTime = Date.now();
const elapsed = endTime - startTime;
if (n > 1) {
expect(elapsed).to.be.lessThan(500);
expect(result).to.equal(2);
} else {
expect(elapsed).to.be.greaterThan(50);
expect(elapsed).to.be.lessThan(1000);
expect(result).to.equal(1);
}
});

it('should reject if all promises reject', async () => {
const promises = {
a: Promise.reject(new Error('failure1')),
b: Promise.reject(new Error('failure2')),
c: Promise.reject(new Error('failure3')),
};
try {
await anyPool(n, promises);
} catch (e) {
expect(e).to.be.instanceOf(AggregateError);
expect(e.errors).to.have.lengthOf(3);
}
});
});
}
});
