Commit: reshape

fs-eire committed Sep 13, 2022
1 parent dfbf6f3 commit b6e7fba
Showing 11 changed files with 275 additions and 67 deletions.
80 changes: 53 additions & 27 deletions js/web/lib/onnxjs/backends/webgpu/gpu-data-manager.ts
@@ -10,13 +10,29 @@ import {GpuData, GpuDataId, GpuDataType} from './types';
  * manages GpuDataId -> GpuBuffer
  */
 export interface GpuDataManager {
-  uploadData(tensor: Tensor, gpuDataType: GpuDataType): GpuData;
-  createData(type: Tensor.DataType, dims: readonly number[], gpuDataType: GpuDataType): GpuData;
-  releaseData(tensorId: Tensor.Id): void;
-  downloadData(tensorId: Tensor.Id): Promise<ArrayBufferLike>;
+  /**
+   * upload data to GPU. if the ID already exists in cache, returns the cached value without uploading anything.
+   */
+  upload(id: GpuDataId, data: Tensor.NumberType, gpuDataType: GpuDataType): Promise<GpuData>;
+  /**
+   * create new data on GPU.
+   */
+  create(type: Tensor.DataType, dims: readonly number[], gpuDataType: GpuDataType): GpuData;
+  /**
+   * get GPU data by ID.
+   */
+  get(id: GpuDataId): GpuData|undefined;
+  /**
+   * release the data on GPU by ID.
+   */
+  release(id: GpuDataId): void;
+  /**
+   * download the data from GPU.
+   */
+  download(id: GpuDataId): Promise<ArrayBufferLike>;
 }
 
-interface DefaultCacheValue {
+interface StorageCacheValue {
   gpuData: GpuData;
   size: number;
 }
@@ -27,27 +43,25 @@ interface DownloadCacheValue {
 }
 
 class GpuDataManagerImpl implements GpuDataManager {
-  defaultCache: Map<GpuDataId, DefaultCacheValue>;
+  // GPU Data ID => GPU Data ( storage buffer )
+  storageCache: Map<GpuDataId, StorageCacheValue>;
+
+  // GPU Data ID => GPU Data ( read buffer )
   downloadCache: Map<GpuDataId, DownloadCacheValue>;
 
   constructor(private device: GPUDevice) {
-    this.defaultCache = new Map();
+    this.storageCache = new Map();
     this.downloadCache = new Map();
   }
 
-  uploadData(tensor: Tensor, gpuDataType: GpuDataType): GpuData {
+  async upload(id: GpuDataId, data: Tensor.NumberType, gpuDataType: GpuDataType): Promise<GpuData> {
     if (gpuDataType !== GpuDataType.default) {
       throw new Error('we only support default GPU data type now');
     }
 
-    const cachedData = this.defaultCache.get(tensor.dataId);
-    if (cachedData) {
-      return cachedData.gpuData;
-    }
-
-    const src = tensor.numberData;
-    const srcArrayBuffer = src.buffer;
-    const srcOffset = src.byteOffset;
-    const srcLength = src.byteLength;
+    const srcArrayBuffer = data.buffer;
+    const srcOffset = data.byteOffset;
+    const srcLength = data.byteLength;
 
     // create gpu buffer
     const gpuBuffer =
@@ -58,12 +72,12 @@ class GpuDataManagerImpl implements GpuDataManager {
     new Uint8Array(arrayBuffer).set(new Uint8Array(srcArrayBuffer, srcOffset, srcLength));
     gpuBuffer.unmap();
 
-    const gpuData = {id: tensor.dataId, type: GpuDataType.default, buffer: gpuBuffer};
-    this.defaultCache.set(gpuData.id, {gpuData, size: srcLength});
+    const gpuData = {id, type: GpuDataType.default, buffer: gpuBuffer};
+    this.storageCache.set(id, {gpuData, size: srcLength});
     return gpuData;
   }
 
-  createData(type: Tensor.DataType, dims: readonly number[], gpuDataType: GpuDataType): GpuData {
+  create(type: Tensor.DataType, dims: readonly number[], gpuDataType: GpuDataType): GpuData {
     if (gpuDataType !== GpuDataType.default) {
       throw new Error('we only support default GPU data type now');
     }
@@ -82,27 +96,39 @@ class GpuDataManagerImpl implements GpuDataManager {
         this.device.createBuffer({size: bufferLength, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC});
 
     const gpuData = {id: Guid.create(), type: GpuDataType.default, buffer: gpuBuffer};
-    this.defaultCache.set(gpuData.id, {gpuData, size: bufferLength});
+    this.storageCache.set(gpuData.id, {gpuData, size: bufferLength});
     return gpuData;
   }
 
-  releaseData(tensorId: Tensor.Id): void {
-    const cachedData = this.defaultCache.get(tensorId);
+  get(id: GpuDataId): GpuData|undefined {
+    return this.storageCache.get(id)?.gpuData;
+  }
+
+  release(id: GpuDataId): void {
+    const cachedData = this.storageCache.get(id);
     if (!cachedData) {
       throw new Error('releasing data does not exist');
     }
 
-    this.defaultCache.delete(tensorId);
+    this.storageCache.delete(id);
     cachedData.gpuData.buffer.destroy();
+
+    const downloadingData = this.downloadCache.get(id);
+    if (downloadingData) {
+      void downloadingData.data.then(() => {
+        downloadingData.gpuData.buffer.destroy();
+      });
+      this.downloadCache.delete(id);
+    }
   }
 
-  async downloadData(tensorId: Tensor.Id): Promise<ArrayBufferLike> {
-    const downloadData = this.downloadCache.get(tensorId);
+  async download(id: GpuDataId): Promise<ArrayBufferLike> {
+    const downloadData = this.downloadCache.get(id);
     if (downloadData) {
       return downloadData.data;
     }
 
-    const cachedData = this.defaultCache.get(tensorId);
+    const cachedData = this.storageCache.get(id);
     if (!cachedData) {
       throw new Error('data does not exist');
     }
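The reworked manager keys storage buffers by GpuDataId instead of by Tensor, and upload is now async. A minimal usage sketch of the new surface, not code from this commit: it assumes a `device: GPUDevice` and a `tensor: Tensor` in scope, plus the `Guid` and `GpuDataType` imports this file already uses.

    const manager = createGpuDataManager(device);

    // upload CPU-side data under a fresh id; per the doc comment above,
    // uploading an id already in the cache returns the cached buffer
    const id = Guid.create();
    const input = await manager.upload(id, tensor.numberData, GpuDataType.default);

    // allocate an uninitialized storage buffer for a kernel output
    const output = manager.create('float32', [2, 3, 4], GpuDataType.default);

    // read the result back, then free both buffers
    const raw = await manager.download(output.id);
    manager.release(output.id);
    manager.release(id);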
45 changes: 22 additions & 23 deletions js/web/lib/onnxjs/backends/webgpu/inference-handler.ts
@@ -2,10 +2,10 @@
 // Licensed under the MIT License.
 
 import {InferenceHandler} from '../../backend';
-import {createView, Tensor} from '../../tensor';
+import {Tensor} from '../../tensor';
 
-import {createGpuDataManager, GpuDataManager} from './gpu-data-manager';
 import {WebGpuSessionHandler} from './session-handler';
+import {createTensorDataManager, TensorDataManager} from './tensor-data-manager';
 import {GpuData, GpuDataType, ProgramInfo, ProgramInfoLoader} from './types';
 
 const getProgramInfoUniqueKey = (programInfo: ProgramInfo|ProgramInfoLoader, inputGpuDatas: GpuData[]): string => {
@@ -19,24 +19,26 @@ const getProgramInfoUniqueKey = (programInfo: ProgramInfo|ProgramInfoLoader, inp
 };
 
 export class WebGpuInferenceHandler implements InferenceHandler {
-  dataManager: GpuDataManager;
+  // per inference context
+  dataManager: TensorDataManager;
 
   constructor(public session: WebGpuSessionHandler) {
-    this.dataManager = createGpuDataManager(session.backend.device);
+    this.dataManager = createTensorDataManager(session.backend.device);
   }
 
   private uploadGpuData(tensor: Tensor, textureType: GpuDataType): GpuData {
     if (this.session.isInitializer(tensor.dataId)) {
-      return this.session.dataManager.uploadData(tensor, textureType);
+      return this.session.dataManager.uploadTensorToGpu(tensor, textureType);
     }
 
-    return this.dataManager.uploadData(tensor, textureType);
+    return this.dataManager.uploadTensorToGpu(tensor, textureType);
   }
 
-  private createGpuData(type: Tensor.DataType, dims: readonly number[], gpuDataType: GpuDataType): GpuData {
-    return this.dataManager.createData(type, dims, gpuDataType);
+  private createGpuData(type: Tensor.DataType, dims: readonly number[], gpuDataType: GpuDataType): [Tensor, GpuData] {
+    return this.dataManager.createGpuTensor(type, dims, gpuDataType);
   }
 
-  run(program: ProgramInfoLoader|ProgramInfo, inputs: readonly Tensor[]): Tensor[] {
+  async run(program: ProgramInfoLoader|ProgramInfo, inputs: readonly Tensor[]): Promise<Tensor[]> {
     if (inputs.length !== program.inputTypes.length) {
       throw new Error(`Input size must be equal to ${program.inputTypes.length}.`);
     }
@@ -56,9 +58,12 @@ export class WebGpuInferenceHandler implements InferenceHandler {
 
     // create texture info for outputs
     const outputDatas: GpuData[] = [];
+    const outputTensors: Tensor[] = [];
     for (let i = 0; i < programInfo.outputs.length; ++i) {
-      outputDatas.push(this.createGpuData(
-          programInfo.outputs[i].type, programInfo.outputs[i].dims, programInfo.outputs[i].gpuDataType));
+      const [tensor, gpuData] = this.createGpuData(
+          programInfo.outputs[i].type, programInfo.outputs[i].dims, programInfo.outputs[i].gpuDataType);
+      outputTensors.push(tensor);
+      outputDatas.push(gpuData);
     }
 
     if (!artifact) {
@@ -68,20 +73,14 @@ export class WebGpuInferenceHandler implements InferenceHandler {
 
     this.session.programManager.run(artifact, inputDatas, outputDatas, artifact.programInfo.dispatchGroup(inputs));
 
-    const outputTensors: Tensor[] = [];
-    for (let i = 0; i < outputDatas.length; i++) {
-      const outputTensorInfo = artifact.programInfo.outputs[i];
-      const dims = outputTensorInfo.dims;
-      const type = outputTensorInfo.type;
-      const outputData = outputDatas[i];
-      const tensor = new Tensor(dims, type, undefined, async () => {
-        const data = await this.dataManager.downloadData(outputData.id);
-        return createView(data, type);
-      }, undefined, outputData.id);
-      outputTensors.push(tensor);
-    }
     return outputTensors;
   }
 
+  reshape(input: Tensor, reshapedDims: readonly number[]): Tensor {
+    return this.dataManager.hasGpuData(input.dataId) ?
+        this.dataManager.createGpuRef(input.dataId, input.type, reshapedDims)[0] :
+        new Tensor(reshapedDims, input.type, undefined, undefined, input.data);
+  }
+
   dispose(): void {}
 }
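The new reshape path runs no kernel at all: when the input already has GPU data, createGpuRef returns a tensor whose id resolves to the same storage buffer under the new dims; otherwise a CPU tensor sharing the input's data is built. A standalone sketch of the aliasing idea, with hypothetical names (the real bookkeeping lives in the tensor-data-manager.ts this commit adds, not shown above):

    type TensorId = number;
    type GpuBufferId = number;

    class AliasMap {
      private tensorToBuffer = new Map<TensorId, GpuBufferId>();

      bind(tensor: TensorId, buffer: GpuBufferId): void {
        this.tensorToBuffer.set(tensor, buffer);
      }

      // a "GPU ref" is a fresh tensor id resolving to the same buffer, so a
      // reshape rewrites shape metadata only and moves no bytes
      createRef(existing: TensorId, fresh: TensorId): void {
        const buffer = this.tensorToBuffer.get(existing);
        if (buffer === undefined) {
          throw new Error('no GPU data bound to tensor');
        }
        this.tensorToBuffer.set(fresh, buffer);
      }
    }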
7 changes: 4 additions & 3 deletions js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts
@@ -3,7 +3,9 @@
 
 import {OpSet} from '../../opset';
 
+import {reshape} from './ops/reshape';
 import * as unaryOps from './ops/unary-op';
+import {parseUnsqueezeAttributes, unsqueeze, unsqueezeV13} from './ops/unsqueeze';
 
 export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [
   ['Abs', '', '6+', unaryOps.abs], ['Acos', '', '7+', unaryOps.acos],
@@ -59,7 +61,7 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [
   // ['ReduceSum', '', '1-12', reduceSum, parseReduceAttributes],
   // ['ReduceSumSquare', '', '1+', reduceLogSumSquare, parseReduceAttributes],
   // ['Relu', '', '6+', unaryOps.relu],
-  // ['Reshape', '', '5+', reshape],
+  ['Reshape', '', '5+', reshape],
   // ['Resize', '', '10', resize, parseResizeAttributesV10],
   // ['Resize', '', '11+', resize, parseResizeAttributesV11],
   // ['Shape', '', '1+', shape],
@@ -86,7 +88,6 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [
   // ['Transpose', '', '1+', transpose, parseTransposeAttributes],
   // ['Upsample', '', '7-8', upsample, parseUpsampleAttributesV7],
   // ['Upsample', '', '9', upsample, parseUpsampleAttributesV9],
-  // ['Unsqueeze', '', '1-12', unsqueeze, parseUnsqueezeAttributes],
-  // ['Unsqueeze', '', '13+', unsqueezeV13],
+  ['Unsqueeze', '', '1-12', unsqueeze, parseUnsqueezeAttributes], ['Unsqueeze', '', '13+', unsqueezeV13],
   // ['Xor', '', '7+', binaryOps.xor],
 ];
11 changes: 11 additions & 0 deletions js/web/lib/onnxjs/backends/webgpu/ops/reshape.ts
@@ -0,0 +1,11 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+import {Tensor} from '../../../tensor';
+import {ShapeUtil} from '../../../util';
+import {WebGpuInferenceHandler} from '../inference-handler';
+
+export const reshape = (handler: WebGpuInferenceHandler, inputs: Tensor[]): Tensor[] => {
+  const reshapedDims = ShapeUtil.calculateReshapedDims(inputs[0].dims, inputs[1].integerData);
+  return [handler.reshape(inputs[0], reshapedDims)];
+};
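The second input carries the requested shape, which follows the ONNX Reshape conventions: 0 keeps the input dim at that position, and a single -1 is inferred from the remaining element count. A standalone sketch of that resolution (not the repo's ShapeUtil.calculateReshapedDims, which also validates its input):

    // resolve an ONNX Reshape shape input against concrete input dims
    const calculateReshapedDims = (inputDims: readonly number[], shape: ArrayLike<number>): number[] => {
      const size = inputDims.reduce((a, b) => a * b, 1);
      const out: number[] = [];
      let inferredIndex = -1;
      for (let i = 0; i < shape.length; i++) {
        if (shape[i] === -1) {
          inferredIndex = i;  // at most one -1 is allowed
          out.push(1);
        } else {
          out.push(shape[i] === 0 ? inputDims[i] : shape[i]);
        }
      }
      if (inferredIndex >= 0) {
        out[inferredIndex] = size / out.reduce((a, b) => a * b, 1);
      }
      return out;
    };

    // e.g. dims [2, 3, 4] with shape [0, -1] resolves to [2, 12]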
43 changes: 43 additions & 0 deletions js/web/lib/onnxjs/backends/webgpu/ops/unsqueeze.ts
@@ -0,0 +1,43 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+import {Graph} from '../../../graph';
+import {OperatorInitialization} from '../../../operators';
+import {Tensor} from '../../../tensor';
+import {ShapeUtil} from '../../../util';
+import {WebGpuInferenceHandler} from '../inference-handler';
+
+export const unsqueeze = (inferenceHandler: WebGpuInferenceHandler, inputs: Tensor[], axes: number[]): Tensor[] => {
+  validateInputs(inputs);
+  const outputShape = ShapeUtil.unsqueezeShape(inputs[0].dims, axes);
+  const output = inferenceHandler.reshape(inputs[0], outputShape);
+  return [output];
+};
+
+export const unsqueezeV13 = (inferenceHandler: WebGpuInferenceHandler, inputs: Tensor[]): Tensor[] => {
+  validateInputsV13(inputs);
+  return unsqueeze(inferenceHandler, [inputs[0]], Array.from(inputs[1].integerData));
+};
+
+export const parseUnsqueezeAttributes: OperatorInitialization<number[]> = (node: Graph.Node): number[] =>
+    node.attributes.getInts('axes');
+
+const validateInputs = (inputs: Tensor[]): void => {
+  if (!inputs || inputs.length !== 1) {
+    throw new Error('Unsqueeze requires 1 input.');
+  }
+
+  if (inputs[0].type === 'string') {
+    throw new Error('invalid input tensor types.');
+  }
+};
+
+const validateInputsV13 = (inputs: Tensor[]): void => {
+  if (!inputs || inputs.length !== 2) {
+    throw new Error('Unsqueeze requires 2 inputs.');
+  }
+
+  if (inputs[1].type !== 'int32') {
+    throw new Error('Invalid input type.');
+  }
+};
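Unsqueeze reduces to the same zero-copy reshape: only the output shape changes, with a 1 inserted at each requested axis, where axes index into the output rank. A standalone sketch of that shape rule (not the repo's ShapeUtil.unsqueezeShape; assumes axes are already normalized, in-range, and duplicate-free):

    const unsqueezeShape = (dims: readonly number[], axes: readonly number[]): number[] => {
      const outRank = dims.length + axes.length;
      const out = new Array<number>(outRank).fill(0);
      for (const axis of axes) {
        out[axis] = 1;  // mark the inserted dims
      }
      let inputIdx = 0;
      for (let i = 0; i < outRank; i++) {
        if (out[i] === 0) {
          out[i] = dims[inputIdx++];  // carry over the original dims in order
        }
      }
      return out;
    };

    // e.g. unsqueezeShape([3, 4], [0, 3]) yields [1, 3, 4, 1]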
6 changes: 3 additions & 3 deletions js/web/lib/onnxjs/backends/webgpu/session-handler.ts
@@ -9,18 +9,18 @@ import {Session} from '../../session';
 import {Tensor} from '../../tensor';
 import {WebGpuBackend} from '../backend-webgpu';
 
-import {createGpuDataManager, GpuDataManager} from './gpu-data-manager';
 import {WebGpuInferenceHandler} from './inference-handler';
 import {WEBGPU_OP_RESOLVE_RULES} from './op-resolve-rules';
 import {ProgramManager} from './program-manager';
+import {createTensorDataManager, TensorDataManager} from './tensor-data-manager';
 
 export class WebGpuSessionHandler implements SessionHandler {
   private initializers: Set<Tensor.Id>;
-  readonly dataManager: GpuDataManager;
+  readonly dataManager: TensorDataManager;
   programManager: ProgramManager;
 
   constructor(public readonly backend: WebGpuBackend, public readonly context: Session.Context) {
-    this.dataManager = createGpuDataManager(this.backend.device);
+    this.dataManager = createTensorDataManager(this.backend.device);
     this.programManager = new ProgramManager(this.backend.device, this.context.profiler);
   }
 