Commit 340c88b: batch mode

fs-eire committed Sep 13, 2022
1 parent b160840
Showing 6 changed files with 84 additions and 30 deletions.
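
In short: instead of creating a fresh GPUCommandEncoder and GPUComputePassEncoder for every kernel and calling queue.submit() immediately, the WebGPU backend now owns a single shared encoder pair, counts pending dispatches, and submits work in batches. ProgramManager.run() records into the shared compute pass and flushes once 16 dispatches have accumulated; the GPU-to-CPU download path ends the pass and flushes explicitly before mapping its staging buffer. Hedged usage sketches follow several of the diffs below.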
40 changes: 40 additions & 0 deletions js/web/lib/onnxjs/backends/backend-webgpu.ts
@@ -7,10 +7,19 @@ import {Backend, SessionHandler} from '../backend';
 import {Logger} from '../instrument';
 import {Session} from '../session';
 
+import {createGpuDataManager, GpuDataManager} from './webgpu/gpu-data-manager';
 import {WebGpuSessionHandler} from './webgpu/session-handler';
 
 export class WebGpuBackend implements Backend {
   device: GPUDevice;
+  gpuDataManager: GpuDataManager;
+
+  commandEncoder: GPUCommandEncoder|null = null;
+  computePassEncoder: GPUComputePassEncoder|null = null;
+  pendingDispatchNumber = 0;
+
+  // #region interface Backend
+
   async initialize(): Promise<boolean> {
     try {
       if (!navigator.gpu) {
@@ -25,6 +34,7 @@ export class WebGpuBackend implements Backend {
         return false;
       }
       this.device = await adapter.requestDevice();
+      this.gpuDataManager = createGpuDataManager(this);
 
       // TODO: set up flags
 
@@ -52,4 +62,34 @@ export class WebGpuBackend implements Backend {
     // TODO: uninitialization
     // this.glContext.dispose();
   }
+
+  // #endregion interface Backend
+
+  getCommandEncoder(): GPUCommandEncoder {
+    if (!this.commandEncoder) {
+      this.commandEncoder = this.device.createCommandEncoder();
+    }
+    return this.commandEncoder;
+  }
+
+  getComputePassEncoder(): GPUComputePassEncoder {
+    if (!this.computePassEncoder) {
+      this.computePassEncoder = this.getCommandEncoder().beginComputePass();
+    }
+    return this.computePassEncoder;
+  }
+
+  endComputePass(): void {
+    if (this.computePassEncoder) {
+      this.computePassEncoder.end();
+      this.computePassEncoder = null;
+    }
+  }
+
+  flush(): void {
+    this.endComputePass();
+    this.device.queue.submit([this.commandEncoder!.finish()]);
+    this.commandEncoder = null;
+    this.pendingDispatchNumber = 0;
+  }
 }
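
To make the new lifecycle concrete, here is a minimal sketch of how a caller is expected to use these accessors. The dispatchKernel helper and its arguments are assumptions for illustration, not part of this commit:

// A hypothetical kernel dispatch built on the accessors above. Everything
// except the WebGpuBackend methods is assumed for illustration.
function dispatchKernel(backend: WebGpuBackend, pipeline: GPUComputePipeline, bindGroup: GPUBindGroup): void {
  const pass = backend.getComputePassEncoder();  // reuses the open pass, or begins one lazily
  pass.setPipeline(pipeline);
  pass.setBindGroup(0, bindGroup);
  pass.dispatch(1);  // single workgroup; illustration only
  backend.pendingDispatchNumber++;
}

// Nothing reaches the GPU queue until flush() ends the pass and submits the
// shared command encoder in a single queue.submit():
//   dispatchKernel(backend, pipelineA, bindGroupA);
//   dispatchKernel(backend, pipelineB, bindGroupB);
//   backend.flush();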
20 changes: 11 additions & 9 deletions js/web/lib/onnxjs/backends/webgpu/gpu-data-manager.ts
@@ -7,6 +7,7 @@ import {Logger} from '../../instrument';
 
 import {sizeof, Tensor} from '../../tensor';
 import {ShapeUtil} from '../../util';
+import {WebGpuBackend} from '../backend-webgpu';
 import {GpuData, GpuDataId, GpuDataType} from './types';
 
 /**
@@ -57,7 +58,7 @@ class GpuDataManagerImpl implements GpuDataManager {
   // GPU Data ID => GPU Data ( read buffer )
   downloadCache: Map<GpuDataId, DownloadCacheValue>;
 
-  constructor(private device: GPUDevice) {
+  constructor(private backend: WebGpuBackend /* , private reuseBuffer: boolean */) {
     this.storageCache = new Map();
     this.downloadCache = new Map();
   }
@@ -75,7 +76,7 @@ class GpuDataManagerImpl implements GpuDataManager {
     const size = calcNormalizedBufferSize(srcLength);
 
     // create gpu buffer
-    const gpuBuffer = this.device.createBuffer({mappedAtCreation: true, size, usage: GPUBufferUsage.STORAGE});
+    const gpuBuffer = this.backend.device.createBuffer({mappedAtCreation: true, size, usage: GPUBufferUsage.STORAGE});
 
     // copy (upload) data
     const arrayBuffer = gpuBuffer.getMappedRange();
@@ -104,7 +105,7 @@ class GpuDataManagerImpl implements GpuDataManager {
     // create gpu buffer
     const gpuBuffer =
         // eslint-disable-next-line no-bitwise
-        this.device.createBuffer({size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC});
+        this.backend.device.createBuffer({size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC});
 
     const gpuData = {id: Guid.create(), type: GpuDataType.default, buffer: gpuBuffer};
     this.storageCache.set(gpuData.id, {gpuData, size: bufferLength});
@@ -146,20 +147,21 @@ class GpuDataManagerImpl implements GpuDataManager {
 
     Logger.verbose('GpuData', `Downloading data from GPU: {${id}}`);
 
-    const commandEncoder = this.device.createCommandEncoder();
-    const gpuReadBuffer =
+    const commandEncoder = this.backend.getCommandEncoder();
+    this.backend.endComputePass();
+    const gpuReadBuffer = this.backend.device.createBuffer(
         // eslint-disable-next-line no-bitwise
-        this.device.createBuffer({size: cachedData.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ});
+        {size: cachedData.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ});
     commandEncoder.copyBufferToBuffer(
         cachedData.gpuData.buffer /* source buffer */, 0 /* source offset */, gpuReadBuffer /* destination buffer */,
         0 /* destination offset */, cachedData.size /* size */
     );
-    const gpuCommands = commandEncoder.finish();
-    this.device.queue.submit([gpuCommands]);
+    this.backend.flush();
 
     await gpuReadBuffer.mapAsync(GPUMapMode.READ);
     return gpuReadBuffer.getMappedRange();
   }
 }
 
-export const createGpuDataManager = (device: GPUDevice): GpuDataManager => new GpuDataManagerImpl(device);
+export const createGpuDataManager = (...args: ConstructorParameters<typeof GpuDataManagerImpl>): GpuDataManager =>
+    new GpuDataManagerImpl(...args);
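
A note on the reworked download path: WebGPU forbids recording other commands on a command encoder while a pass is open, so endComputePass() must run before copyBufferToBuffer() is encoded, and mapAsync() can only return the copied bytes once the copy has actually been submitted, which is what the flush() guarantees. A sketch of the intended call site, assuming a download method on the manager (the method name and id are placeholders, since the interface itself is collapsed in this diff):

// Hypothetical usage: `manager` is the GpuDataManager created above and `id`
// refers to data previously uploaded to a storage buffer.
const raw: ArrayBuffer = await manager.download(id);  // method name assumed for illustration
const values = new Float32Array(raw);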
6 changes: 3 additions & 3 deletions js/web/lib/onnxjs/backends/webgpu/inference-handler.ts
@@ -26,7 +26,7 @@ export class WebGpuInferenceHandler implements InferenceHandler {
   dataManager: TensorDataManager;
 
   constructor(public session: WebGpuSessionHandler) {
-    this.dataManager = createTensorDataManager(session.backend.device);
+    this.dataManager = createTensorDataManager(session.backend.gpuDataManager);
   }
 
   private async uploadGpuData(tensor: Tensor, textureType: GpuDataType): Promise<GpuData> {
@@ -46,7 +46,7 @@ export class WebGpuInferenceHandler implements InferenceHandler {
       throw new Error(`Input size must be equal to ${program.inputTypes.length}.`);
     }
 
-    // create info for input
+    // create info for inputs
     const inputDatas: GpuData[] = [];
     for (let i = 0; i < program.inputTypes.length; ++i) {
       inputDatas[i] = await this.uploadGpuData(inputs[i], program.inputTypes[i]);
@@ -59,7 +59,7 @@ export class WebGpuInferenceHandler implements InferenceHandler {
         (typeof (program as ProgramInfoLoader).get === 'function' ? (program as ProgramInfoLoader).get() :
                                                                     (program as ProgramInfo));
 
-    // create texture info for outputs
+    // create info for outputs
     const outputDatas: GpuData[] = [];
     const outputTensors: Tensor[] = [];
     for (let i = 0; i < programInfo.outputs.length; ++i) {
24 changes: 12 additions & 12 deletions js/web/lib/onnxjs/backends/webgpu/program-manager.ts
@@ -4,6 +4,7 @@
 import {env} from 'onnxruntime-common';
 
 import {Logger, Profiler} from '../../instrument';
+import {WebGpuBackend} from '../backend-webgpu';
 
 import {Artifact, GpuData, ProgramInfo} from './types';
 
@@ -20,7 +21,7 @@ export class ProgramManager {
   repo: Map<unknown, Artifact>;  // this should be per-session object
   attributesBound: boolean;
 
-  constructor(private device: GPUDevice, public profiler: Readonly<Profiler>) {
+  constructor(private backend: WebGpuBackend, public profiler: Readonly<Profiler>) {
     this.repo = new Map();
     this.attributesBound = false;
   }
@@ -32,14 +33,11 @@ }
   }
   run(buildArtifact: Artifact, inputs: GpuData[], outputs: GpuData[],
       dispatchGroup: {x: number; y?: number; z?: number}): void {
-    const device = this.device;
+    const device = this.backend.device;
 
-    // TODO: should we create command encoder every time?
+    const computePassEncoder = this.backend.getComputePassEncoder();
 
-    const commandEncoder = device.createCommandEncoder();
-
-    const passEncoder = commandEncoder.beginComputePass();
-    passEncoder.setPipeline(buildArtifact.computePipeline);
+    computePassEncoder.setPipeline(buildArtifact.computePipeline);
     const entries = [];
     for (const input of inputs) {
       entries.push({binding: entries.length, resource: {buffer: input.buffer}});
@@ -48,20 +46,22 @@ export class ProgramManager {
       entries.push({binding: entries.length, resource: {buffer: output.buffer}});
     }
     const bindGroup = device.createBindGroup({layout: buildArtifact.computePipeline.getBindGroupLayout(0), entries});
-    passEncoder.setBindGroup(0, bindGroup);
+    computePassEncoder.setBindGroup(0, bindGroup);
 
     const {x, y, z} = dispatchGroup;
-    passEncoder.dispatch(x, y, z);
+    computePassEncoder.dispatch(x, y, z);
 
-    passEncoder.endPass();
+    this.backend.pendingDispatchNumber++;
 
-    device.queue.submit([commandEncoder.finish()]);
+    if (this.backend.pendingDispatchNumber >= 16) {
+      this.backend.flush();
+    }
   }
   dispose(): void {
     // this.repo.forEach(a => this.glContext.deleteProgram(a.program));
   }
   build(programInfo: ProgramInfo): Artifact {
-    const device = this.device;
+    const device = this.backend.device;
 
     const shaderModule = device.createShaderModule({code: programInfo.shaderSource});
     if (env.debug) {
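
The constant 16 is a batching heuristic: larger batches amortize the cost of queue.submit(), while the periodic flush keeps the GPU fed and bounds how much recorded-but-unsubmitted work can accumulate. Roughly, under that assumption (the loop, the names, and the final flush call are hypothetical):

// With the >= 16 threshold, 40 back-to-back run() calls submit twice
// mid-stream (after dispatches 16 and 32); the remaining 8 dispatches stay
// recorded until something forces a flush, e.g. downloading an output.
for (let i = 0; i < 40; ++i) {
  programManager.run(artifacts[i], inputDatas[i], outputDatas[i], {x: 1});
}
backend.flush();  // hypothetical end-of-inference flush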
6 changes: 3 additions & 3 deletions js/web/lib/onnxjs/backends/webgpu/session-handler.ts
@@ -17,11 +17,11 @@ import {createTensorDataManager, TensorDataManager} from './tensor-data-manager';
 export class WebGpuSessionHandler implements SessionHandler {
   private initializers: Set<Tensor.Id>;
   readonly dataManager: TensorDataManager;
-  programManager: ProgramManager;
+  readonly programManager: ProgramManager;
 
   constructor(public readonly backend: WebGpuBackend, public readonly context: Session.Context) {
-    this.dataManager = createTensorDataManager(this.backend.device);
-    this.programManager = new ProgramManager(this.backend.device, this.context.profiler);
+    this.dataManager = createTensorDataManager(this.backend.gpuDataManager);
+    this.programManager = new ProgramManager(this.backend, this.context.profiler);
   }
 
   createInferenceHandler() {
18 changes: 15 additions & 3 deletions js/web/lib/onnxjs/backends/webgpu/tensor-data-manager.ts
@@ -3,11 +3,23 @@
 
 import {createView, Tensor} from '../../tensor';
 
-import {createGpuDataManager, GpuDataManager} from './gpu-data-manager';
+import {GpuDataManager} from './gpu-data-manager';
 import {GpuData, GpuDataId, GpuDataType} from './types';
 
 /**
  * manages Tensor ID -> GPU Data ID
+ *
+ * A tensor ID is a unique ID representing a value (tensor), which is a graph node's input or output.
+ * A GPU Data ID is a unique ID representing an abstract piece of data in GPU memory. Specifically, in the current
+ * WebGPU scenario, a GPU Data is a storage buffer, and a GPU Data ID is a handle to that storage buffer.
+ *
+ * - A value is different from a graph edge: if a node's output is consumed by 2 downstream nodes, there are
+ *   2 edges but only one value.
+ *
+ * - A tensor ID maps to 0 or 1 GPU Data IDs, depending on whether the data is available on the GPU.
+ *
+ * - A GPU Data ID maps to 1 or more tensor IDs.
+ *
  */
 export interface TensorDataManager {
   /**
@@ -124,5 +136,5 @@
   }
 }
 
-export const createTensorDataManager = (device: GPUDevice): TensorDataManager =>
-    new TensorDataManagerImpl(createGpuDataManager(device));
+export const createTensorDataManager = (gpuDataManager: GpuDataManager): TensorDataManager =>
+    new TensorDataManagerImpl(gpuDataManager);
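
The doc comment above pins down the cardinalities; a minimal sketch of the bookkeeping it implies (not the actual implementation, whose body is collapsed in this diff; all names below are assumptions for illustration):

// Many tensor IDs may share one GPU data ID; a tensor with no GPU-resident
// data simply has no entry in the map.
const tensorToGpuData = new Map<Tensor.Id, GpuDataId>();

function hasGpuData(tensorId: Tensor.Id): boolean {
  return tensorToGpuData.has(tensorId);
}

// Two tensors that share the same underlying data end up with one GPU data
// ID and two tensor IDs; no second buffer is allocated.
function shareGpuData(from: Tensor.Id, to: Tensor.Id): void {
  const gpuDataId = tensorToGpuData.get(from);
  if (gpuDataId !== undefined) {
    tensorToGpuData.set(to, gpuDataId);
  }
}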
