Skip to content

Commit

Permalink
Support loading models with weights above 2GB on Chrome (#7609)
Browse files Browse the repository at this point in the history
Chrome throws allocation errors when creating an ArrayBuffer larger than 2GB. This makes it impossible to load TFJS models above this size in Chrome (even with weight sharding) because model loading involves concatenating all the weights into a single ArrayBuffer.

This PR avoids this concatenation. Instead of slicing the weight tensors out of a single concatenated ArrayBuffer, it keeps the weight buffers in their original shards and slices them using the CompositeArrayBuffer class created in #7598.
  • Loading branch information
mattsoulanille authored May 4, 2023
1 parent dcd3b43 commit 086e9d8
Show file tree
Hide file tree
Showing 22 changed files with 205 additions and 136 deletions.
9 changes: 6 additions & 3 deletions tfjs-converter/src/executor/graph_model_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,8 @@ describe('Model', () => {
expect(handler.savedArtifacts.modelTopology).toEqual(CUSTOM_OP_MODEL);
expect(handler.savedArtifacts.weightSpecs).toEqual(weightsManifest);
tfc.test_util.expectArraysClose(
new Int32Array(handler.savedArtifacts.weightData), bias.dataSync());
new Int32Array(io.CompositeArrayBuffer.join(
handler.savedArtifacts.weightData)), bias.dataSync());
});
});
});
Expand Down Expand Up @@ -616,7 +617,8 @@ describe('Model', () => {
});
expect(handler.savedArtifacts.weightSpecs).toEqual(weightsManifest);
tfc.test_util.expectArraysClose(
new Int32Array(handler.savedArtifacts.weightData), bias.dataSync());
new Int32Array(io.CompositeArrayBuffer.join(
handler.savedArtifacts.weightData)), bias.dataSync());
});
});

Expand Down Expand Up @@ -904,7 +906,8 @@ describe('Model', () => {
});
expect(handler.savedArtifacts.weightSpecs).toEqual(weightsManifest);
tfc.test_util.expectArraysClose(
new Int32Array(handler.savedArtifacts.weightData), bias.dataSync());
new Int32Array(io.CompositeArrayBuffer.join(handler.savedArtifacts
.weightData)), bias.dataSync());
});
});

Expand Down
14 changes: 13 additions & 1 deletion tfjs-converter/src/operations/executors/spy_ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,23 @@
* =============================================================================
*/

// Removes from the union T every member that is assignable to U, i.e. the
// opposite of Extract<T, U>. This is the built-in Exclude utility type; the
// local alias is kept so existing references to `Without` keep working.
type Without<T, U> = Exclude<T, U>;

// Do not spy on CompositeArrayBuffer because it is a class constructor.
type NotSpiedOn = 'CompositeArrayBuffer';

export type RecursiveSpy<T> =
T extends Function ? jasmine.Spy : {[K in keyof T]: RecursiveSpy<T[K]>};
T extends Function ? jasmine.Spy :
{[K in Without<keyof T, NotSpiedOn>]: RecursiveSpy<T[K]>} &
{[K in Extract<keyof T, NotSpiedOn>]: T[K]};

export function spyOnAllFunctions<T>(obj: T): RecursiveSpy<T> {
return Object.fromEntries(Object.entries(obj).map(([key, val]) => {
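// TODO(mattsoulanille): Do not hard-code this.
// NOTE(review): the function branch below returns a `[key, spy]` entry for
// Object.fromEntries, while this branch returns the bare `val`; confirm
// whether `[key, val]` was intended here.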
if (key === 'CompositeArrayBuffer') {
return val;
}
if (val instanceof Function) {
return [key, jasmine.createSpy(`${key} spy`, val).and.callThrough()];
} else if (val instanceof Array) {
Expand Down
16 changes: 11 additions & 5 deletions tfjs-core/src/io/browser_files.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@
import '../flags';
import {env} from '../environment';

import {basename, concatenateArrayBuffers, getModelArtifactsForJSON, getModelArtifactsInfoForJSON, getModelJSONForModelArtifacts} from './io_utils';
import {basename, getModelArtifactsForJSON, getModelArtifactsInfoForJSON, getModelJSONForModelArtifacts} from './io_utils';
import {IORouter, IORouterRegistry} from './router_registry';
import {IOHandler, ModelArtifacts, ModelJSON, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';
import {IOHandler, ModelArtifacts, ModelJSON, SaveResult, WeightData, WeightsManifestConfig, WeightsManifestEntry} from './types';
import {CompositeArrayBuffer} from './composite_array_buffer';

const DEFAULT_FILE_NAME_PREFIX = 'model';
const DEFAULT_JSON_EXTENSION_NAME = '.json';
Expand Down Expand Up @@ -70,8 +71,13 @@ export class BrowserDownloads implements IOHandler {
'Browser downloads are not supported in ' +
'this environment since `document` is not present');
}

// TODO(mattsoulanille): Support saving models over 2GB that exceed
// Chrome's ArrayBuffer size limit.
const weightBuffer = CompositeArrayBuffer.join(modelArtifacts.weightData);

const weightsURL = window.URL.createObjectURL(new Blob(
[modelArtifacts.weightData], {type: 'application/octet-stream'}));
[weightBuffer], {type: 'application/octet-stream'}));

if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
throw new Error(
Expand Down Expand Up @@ -169,7 +175,7 @@ class BrowserFiles implements IOHandler {
}

private loadWeights(weightsManifest: WeightsManifestConfig): Promise<[
/* weightSpecs */ WeightsManifestEntry[], /* weightData */ ArrayBuffer
/* weightSpecs */ WeightsManifestEntry[], WeightData,
]> {
const weightSpecs: WeightsManifestEntry[] = [];
const paths: string[] = [];
Expand All @@ -185,7 +191,7 @@ class BrowserFiles implements IOHandler {
paths.map(path => this.loadWeightsFile(path, pathToFile[path]));

return Promise.all(promises).then(
buffers => [weightSpecs, concatenateArrayBuffers(buffers)]);
buffers => [weightSpecs, buffers]);
}

private loadWeightsFile(path: string, file: File): Promise<ArrayBuffer> {
Expand Down
24 changes: 14 additions & 10 deletions tfjs-core/src/io/browser_files_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import * as tf from '../index';
import {BROWSER_ENVS, describeWithFlags} from '../jasmine_util';
import {browserDownloads, BrowserDownloads, browserDownloadsRouter} from './browser_files';
import {WeightsManifestConfig, WeightsManifestEntry} from './types';
import {CompositeArrayBuffer} from './composite_array_buffer';

const modelTopology1: {} = {
'class_name': 'Sequential',
Expand Down Expand Up @@ -310,7 +311,7 @@ describeWithFlags('browserFiles', BROWSER_ENVS, () => {
expect(modelArtifacts.modelInitializer).toEqual({});
expect(modelArtifacts.trainingConfig).toEqual(trainingConfig1);

expect(new Uint8Array(modelArtifacts.weightData))
expect(new Uint8Array(CompositeArrayBuffer.join(modelArtifacts.weightData)))
.toEqual(new Uint8Array(weightData1));
});

Expand Down Expand Up @@ -351,9 +352,10 @@ describeWithFlags('browserFiles', BROWSER_ENVS, () => {
const modelArtifacts = await filesHandler.load();
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
expect(modelArtifacts.weightSpecs).toEqual(weightSpecs);
expect(new Uint8Array(modelArtifacts.weightData)).toEqual(new Uint8Array([
1, 2, 3, 4, 10, 20, 30, 40
]));
expect(new Uint8Array(CompositeArrayBuffer.join(modelArtifacts.weightData)))
.toEqual(new Uint8Array([
1, 2, 3, 4, 10, 20, 30, 40
]));
});

it(`Two groups, four paths, reverseOrder=false`, async () => {
Expand Down Expand Up @@ -418,9 +420,10 @@ describeWithFlags('browserFiles', BROWSER_ENVS, () => {
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
expect(modelArtifacts.weightSpecs)
.toEqual(weightSpecs1.concat(weightSpecs2));
expect(new Uint8Array(modelArtifacts.weightData)).toEqual(new Uint8Array([
1, 3, 5, 7, 10, 30, 50, 70, 2, 4, 6, 8, 20, 40, 60, 80
]));
expect(new Uint8Array(CompositeArrayBuffer.join(modelArtifacts.weightData)))
.toEqual(new Uint8Array([
1, 3, 5, 7, 10, 30, 50, 70, 2, 4, 6, 8, 20, 40, 60, 80
]));
});

it(`Two groups, four paths, reverseOrder=true`, async () => {
Expand Down Expand Up @@ -485,9 +488,10 @@ describeWithFlags('browserFiles', BROWSER_ENVS, () => {
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
expect(modelArtifacts.weightSpecs)
.toEqual(weightSpecs1.concat(weightSpecs2));
expect(new Uint8Array(modelArtifacts.weightData)).toEqual(new Uint8Array([
1, 3, 5, 7, 10, 30, 50, 70, 2, 4, 6, 8, 20, 40, 60, 80
]));
expect(new Uint8Array(CompositeArrayBuffer.join(modelArtifacts.weightData)))
.toEqual(new Uint8Array([
1, 3, 5, 7, 10, 30, 50, 70, 2, 4, 6, 8, 20, 40, 60, 80
]));
});

it('Upload model topology only', async () => {
Expand Down
26 changes: 23 additions & 3 deletions tfjs-core/src/io/composite_array_buffer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,22 @@ export class CompositeArrayBuffer {
private bufferUniformSize?: number;
public readonly byteLength: number;

constructor(buffers: ArrayBuffer | ArrayBuffer[] | TypedArray |
/**
 * Concatenate a number of ArrayBuffers into one.
 *
 * The result is a single contiguous ArrayBuffer, so it remains subject to
 * the engine's ArrayBuffer size limit (for example, Chrome's 2GB limit).
 *
 * @param buffers An array of ArrayBuffers to concatenate, or a single
 *     ArrayBuffer. When omitted (or null), an empty ArrayBuffer is returned.
 * @returns Result of concatenating `buffers` in order.
 */
static join(buffers?: ArrayBuffer[] | ArrayBuffer): ArrayBuffer {
  return new CompositeArrayBuffer(buffers).slice();
}

constructor(buffers?: ArrayBuffer | ArrayBuffer[] | TypedArray |
TypedArray[]) {
if (buffers == null) {
return;
}
// Normalize the `buffers` input to be `ArrayBuffer[]`.
if (!(buffers instanceof Array)) {
buffers = [buffers];
Expand Down Expand Up @@ -85,6 +99,12 @@ export class CompositeArrayBuffer {
}

slice(start = 0, end = this.byteLength): ArrayBuffer {
// If there are no shards, then the CompositeArrayBuffer was initialized
// with no data.
if (this.shards.length === 0) {
return new ArrayBuffer(0);
}

// NaN is treated as zero for slicing. This matches ArrayBuffer's behavior.
start = isNaN(Number(start)) ? 0 : start;
end = isNaN(Number(end)) ? 0 : end;
Expand Down Expand Up @@ -117,8 +137,8 @@ export class CompositeArrayBuffer {
const globalEnd = Math.min(end, shard.end);
const localEnd = globalEnd - shard.start;

const outputSlice = new Uint8Array(shard.buffer.slice(localStart,
localEnd));
const outputSlice = new Uint8Array(shard.buffer, localStart,
localEnd - localStart);
outputArray.set(outputSlice, outputStart);
sliced += outputSlice.length;

Expand Down
14 changes: 13 additions & 1 deletion tfjs-core/src/io/composite_array_buffer_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ describe('CompositeArrayBuffer', () => {
});
}

it('can be passed an empty arraybuffer', () => {
it('can be created from an empty arraybuffer', () => {
const array = new Uint8Array([]);
const singleComposite = new CompositeArrayBuffer(array.buffer);
expectArraysEqual(new Uint8Array(singleComposite.slice()), []);
Expand All @@ -92,6 +92,18 @@ describe('CompositeArrayBuffer', () => {
expectArraysEqual(new Uint8Array(singleComposite.slice()), array);
});

it('can be created from zero arrays', () => {
const singleComposite = new CompositeArrayBuffer([]);
expectArraysEqual(new Uint8Array(singleComposite.slice()),
new Uint8Array());
});

it('can be created from undefined input', () => {
const singleComposite = new CompositeArrayBuffer();
expectArraysEqual(new Uint8Array(singleComposite.slice()),
new Uint8Array());
});

it('treats NaN as zero when passed as the start of slice', () => {
const array = new Uint8Array([1,2,3]);
const composite = new CompositeArrayBuffer(array.buffer);
Expand Down
15 changes: 10 additions & 5 deletions tfjs-core/src/io/http.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@
import {env} from '../environment';

import {assert} from '../util';
import {concatenateArrayBuffers, getModelArtifactsForJSON, getModelArtifactsInfoForJSON, getModelJSONForModelArtifacts, getWeightSpecs} from './io_utils';
import {getModelArtifactsForJSON, getModelArtifactsInfoForJSON, getModelJSONForModelArtifacts, getWeightSpecs} from './io_utils';
import {CompositeArrayBuffer} from './composite_array_buffer';
import {IORouter, IORouterRegistry} from './router_registry';
import {IOHandler, LoadOptions, ModelArtifacts, ModelJSON, OnProgressCallback, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';
import {IOHandler, LoadOptions, ModelArtifacts, ModelJSON, OnProgressCallback, SaveResult, WeightData, WeightsManifestConfig, WeightsManifestEntry} from './types';
import {loadWeightsAsArrayBuffer} from './weights_loader';

const OCTET_STREAM_MIME_TYPE = 'application/octet-stream';
Expand Down Expand Up @@ -110,9 +111,13 @@ export class HTTPRequest implements IOHandler {
'model.json');

if (modelArtifacts.weightData != null) {
// TODO(mattsoulanille): Support saving models over 2GB that exceed
// Chrome's ArrayBuffer size limit.
const weightBuffer = CompositeArrayBuffer.join(modelArtifacts.weightData);

init.body.append(
'model.weights.bin',
new Blob([modelArtifacts.weightData], {type: OCTET_STREAM_MIME_TYPE}),
new Blob([weightBuffer], {type: OCTET_STREAM_MIME_TYPE}),
'model.weights.bin');
}

Expand Down Expand Up @@ -182,7 +187,7 @@ export class HTTPRequest implements IOHandler {
}

private async loadWeights(weightsManifest: WeightsManifestConfig):
Promise<[WeightsManifestEntry[], ArrayBuffer]> {
Promise<[WeightsManifestEntry[], WeightData]> {
const weightPath = Array.isArray(this.path) ? this.path[1] : this.path;
const [prefix, suffix] = parseUrl(weightPath);
const pathPrefix = this.weightPathPrefix || prefix;
Expand Down Expand Up @@ -210,7 +215,7 @@ export class HTTPRequest implements IOHandler {
fetchFunc: this.fetch,
onProgress: this.onProgress
});
return [weightSpecs, concatenateArrayBuffers(buffers)];
return [weightSpecs, buffers];
}
}

Expand Down
42 changes: 25 additions & 17 deletions tfjs-core/src/io/http_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import * as tf from '../index';
import {BROWSER_ENVS, CHROME_ENVS, describeWithFlags, NODE_ENVS} from '../jasmine_util';
import {HTTPRequest, httpRouter, parseUrl} from './http';
import {CompositeArrayBuffer} from './composite_array_buffer';

// Test data.
const modelTopology1: {} = {
Expand Down Expand Up @@ -161,7 +162,8 @@ describeWithFlags('http-load fetch', NODE_ENVS, () => {
expect(modelArtifacts.generatedBy).toEqual('1.15');
expect(modelArtifacts.convertedBy).toEqual('1.3.1');
expect(modelArtifacts.userDefinedMetadata).toEqual({});
expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData);
expect(new Float32Array(CompositeArrayBuffer.join(
modelArtifacts.weightData))).toEqual(floatData);
});

it('throw exception if no fetch polyfill', () => {
Expand Down Expand Up @@ -507,7 +509,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
expect(modelArtifacts.userDefinedMetadata).toEqual({});
expect(modelArtifacts.modelInitializer).toEqual({});

expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData);
expect(new Float32Array(CompositeArrayBuffer.join(modelArtifacts
.weightData))).toEqual(floatData);
expect(Object.keys(requestInits).length).toEqual(2);
// Assert that fetch is invoked with `window` as the context.
expect(fetchSpy.calls.mostRecent().object).toEqual(window);
Expand Down Expand Up @@ -550,7 +553,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
const modelArtifacts = await handler.load();
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);
expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData);
expect(new Float32Array(CompositeArrayBuffer.join(modelArtifacts
.weightData))).toEqual(floatData);
expect(Object.keys(requestInits).length).toEqual(2);
expect(Object.keys(requestInits).length).toEqual(2);
expect(requestInits['./model.json'].headers['header_key_1'])
Expand Down Expand Up @@ -599,8 +603,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
const modelArtifacts = await handler.load();
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);
expect(new Float32Array(modelArtifacts.weightData))
.toEqual(new Float32Array([1, 3, 3, 7, 4]));
expect(new Float32Array(CompositeArrayBuffer.join(modelArtifacts
.weightData))).toEqual(new Float32Array([1, 3, 3, 7, 4]));
});

it('2 groups, 2 weight, 2 paths', async () => {
Expand Down Expand Up @@ -644,8 +648,9 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
expect(modelArtifacts.weightSpecs)
.toEqual(
weightsManifest[0].weights.concat(weightsManifest[1].weights));
expect(new Float32Array(modelArtifacts.weightData))
.toEqual(new Float32Array([1, 3, 3, 7, 4]));
expect(new Float32Array(CompositeArrayBuffer.join(
modelArtifacts.weightData)))
.toEqual(new Float32Array([1, 3, 3, 7, 4]));
});

it('2 groups, 2 weight, 2 paths, Int32 and Uint8 Data', async () => {
Expand Down Expand Up @@ -689,10 +694,10 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
expect(modelArtifacts.weightSpecs)
.toEqual(
weightsManifest[0].weights.concat(weightsManifest[1].weights));
expect(new Int32Array(modelArtifacts.weightData.slice(0, 12)))
.toEqual(new Int32Array([1, 3, 3]));
expect(new Uint8Array(modelArtifacts.weightData.slice(12, 14)))
.toEqual(new Uint8Array([7, 4]));
expect(new Int32Array(CompositeArrayBuffer.join(modelArtifacts.weightData)
.slice(0, 12))).toEqual(new Int32Array([1, 3, 3]));
expect(new Uint8Array(CompositeArrayBuffer.join(modelArtifacts.weightData)
.slice(12, 14))).toEqual(new Uint8Array([7, 4]));
});

it('topology only', async () => {
Expand Down Expand Up @@ -752,10 +757,11 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
expect(modelArtifacts.weightSpecs)
.toEqual(
weightsManifest[0].weights.concat(weightsManifest[1].weights));
expect(new Int32Array(modelArtifacts.weightData.slice(0, 12)))
.toEqual(new Int32Array([1, 3, 3]));
expect(new Float32Array(modelArtifacts.weightData.slice(12, 20)))
.toEqual(new Float32Array([-7, -4]));
expect(new Int32Array(CompositeArrayBuffer.join(modelArtifacts.weightData)
.slice(0, 12))).toEqual(new Int32Array([1, 3, 3]));
expect(new Float32Array(CompositeArrayBuffer
.join(modelArtifacts.weightData)
.slice(12, 20))).toEqual(new Float32Array([-7, -4]));
});

it('Missing modelTopology and weightsManifest leads to error', async () => {
Expand Down Expand Up @@ -840,7 +846,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
const modelArtifacts = await handler.load();
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);
expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData);
expect(new Float32Array(CompositeArrayBuffer.join(
modelArtifacts.weightData))).toEqual(floatData);
expect(Object.keys(requestInits).length).toEqual(2);
expect(Object.keys(requestInits).length).toEqual(2);
expect(requestInits['./model.json'].headers['header_key_1'])
Expand Down Expand Up @@ -902,7 +909,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
expect(modelArtifacts.trainingConfig).toEqual(trainingConfig1);
expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);
expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData);
expect(new Float32Array(CompositeArrayBuffer
.join(modelArtifacts.weightData))).toEqual(floatData);

expect(fetchInputs).toEqual(['./model.json', './weightfile0']);
expect(fetchInits.length).toEqual(2);
Expand Down
Loading

0 comments on commit 086e9d8

Please sign in to comment.