Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/webgpu] Support capture and replay for jsep #18989

Merged
merged 33 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
ccc89af
[js/webgpu] Add record/replay support
qjia7 Dec 27, 2023
c675f71
Add record/replay support in c++
qjia7 Dec 28, 2023
547e005
support record/replay in js
qjia7 Dec 29, 2023
62d64b8
remove unused codes
qjia7 Jan 3, 2024
be58e40
add EP option graphCaptureEnabled
qjia7 Jan 5, 2024
012066b
add sessionId to capture/replay methods
qjia7 Jan 8, 2024
db02615
Add releaseSession interface
qjia7 Jan 8, 2024
2d0c878
Create an internal buffer for each external buffer
qjia7 Jan 8, 2024
c00b29b
Revert "Create an internal buffer for each external buffer"
qjia7 Jan 8, 2024
e1a4bc4
throw errrors when not supported
qjia7 Jan 9, 2024
387ff44
only bind input/output once for IOBinding when graphCaptureEnabled =
qjia7 Jan 9, 2024
79f392c
nits
qjia7 Jan 10, 2024
030d347
update name and annotation
qjia7 Jan 10, 2024
c4cfde0
fix format issues
qjia7 Jan 10, 2024
d105c52
fix lint/format errors
qjia7 Jan 11, 2024
c5137c6
Merge branch 'main' into record_and_replay
qjia7 Jan 15, 2024
a5adf02
Merge branch 'main' into record_and_replay
qjia7 Jan 15, 2024
cc2ff91
nits
qjia7 Jan 15, 2024
e630dbf
enable timestamp query
qjia7 Jan 18, 2024
4c313ad
address Yulong's comments
qjia7 Jan 19, 2024
3f3c6df
reuse the storage buffer
qjia7 Jan 19, 2024
b992f6c
Revert "reuse the storage buffer"
qjia7 Jan 22, 2024
2f13fcd
Merge branch 'main' into record_and_replay
qjia7 Jan 23, 2024
2172984
integrate setQueryType changes
qjia7 Jan 23, 2024
6e0ef20
flush the left commands before status changed
qjia7 Jan 23, 2024
b785a05
address comments
qjia7 Jan 25, 2024
3a80d5c
rename to enableGraphCapture and move to SessionOptions
qjia7 Jan 26, 2024
b6f5d95
Merge branch 'main' into record_and_replay
qjia7 Jan 26, 2024
dc8cc2b
nits
qjia7 Jan 26, 2024
b0de471
Merge branch 'main' into record_and_replay
qjia7 Jan 27, 2024
dff25fa
Address Yulong's comments
qjia7 Jan 29, 2024
1beb3f1
further simplify if (!enableGraphCapture || !inputOutputBound)
qjia7 Jan 29, 2024
59f5f92
call OrtClearBoundOutputs when release session
qjia7 Jan 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions js/common/lib/inference-session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ export declare namespace InferenceSession {
export interface WebGpuExecutionProviderOption extends ExecutionProviderOption {
readonly name: 'webgpu';
preferredLayout?: 'NCHW'|'NHWC';
graphCaptureEnabled?: boolean;
fs-eire marked this conversation as resolved.
Show resolved Hide resolved
}
export interface WebNNExecutionProviderOption extends ExecutionProviderOption {
readonly name: 'webnn';
Expand Down
18 changes: 11 additions & 7 deletions js/web/lib/wasm/binding/ort-wasm.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ export declare namespace JSEP {
type ReleaseKernelFunction = (kernel: number) => void;
type RunFunction =
(kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array<Promise<string|null>>) => number;
type CaptureBeginFunction = (sessionHandle: number) => void;
qjia7 marked this conversation as resolved.
Show resolved Hide resolved
type CaptureEndFunction = (sessionHandle: number) => void;
type ReplayFunction = (sessionHandle: number) => void;
}

export interface OrtWasmModule extends EmscriptenModule {
Expand Down Expand Up @@ -128,7 +131,8 @@ export interface OrtWasmModule extends EmscriptenModule {
jsepInit?
(backend: JSEP.BackendType, alloc: JSEP.AllocFunction, free: JSEP.FreeFunction, upload: JSEP.UploadFunction,
download: JSEP.DownloadFunction, createKernel: JSEP.CreateKernelFunction,
releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction): void;
releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction, captureBegin: JSEP.CaptureBeginFunction,
captureEnd: JSEP.CaptureEndFunction, replay: JSEP.ReplayFunction): void;

/**
* [exported from wasm] Specify a kernel's output when running OpKernel::Compute().
Expand Down Expand Up @@ -158,12 +162,6 @@ export interface OrtWasmModule extends EmscriptenModule {
* @returns the GPU data ID for the registered GPU buffer.
*/
jsepRegisterBuffer: (sessionId: number, index: number, buffer: GPUBuffer, size: number) => number;
/**
* [exported from js_internal_api.js] Unregister all user GPU buffers for a session.
*
* @param sessionId - specify the session ID.
*/
jsepUnregisterBuffers?: (sessionId: number) => void;
/**
* [exported from js_internal_api.js] Get the GPU buffer by GPU data ID.
*
Expand All @@ -186,6 +184,12 @@ export interface OrtWasmModule extends EmscriptenModule {
* [exported from js_internal_api.js] Called when InferenceSession.run started.
*/
jsepOnRunStart: () => void;
/**
* [exported from js_internal_api.js] Release a session.
fs-eire marked this conversation as resolved.
Show resolved Hide resolved
* @param sessionId - specify the session ID.
* @returns
*/
jsepOnReleaseSession: (sessionId: number) => void;
// #endregion
}

Expand Down
98 changes: 95 additions & 3 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@ import {createView, TensorView} from './tensor-view';
import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager';
import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules';
import {ProgramManager} from './webgpu/program-manager';
import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, TimestampQuery} from './webgpu/types';
import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, StatusType, TimestampQuery} from './webgpu/types';

interface CommandInfo {
readonly kernelId: number;
readonly computePipeline: GPUComputePipeline;
readonly bindGroup: GPUBindGroup;
readonly dispatchGroup: [number, number, number];
}
fs-eire marked this conversation as resolved.
Show resolved Hide resolved

interface KernelInfo {
readonly kernelType: string;
Expand Down Expand Up @@ -103,6 +110,13 @@ export class WebGpuBackend {
*/
programManager: ProgramManager;

/**
* representing the session ID of which is currently being captured/replay.
* `null` means no session is being captured.
* only valid when captureGraphEnabled = true.
*/
currentSessionId: number|null = null;
fs-eire marked this conversation as resolved.
Show resolved Hide resolved

/**
* representing the kernel ID of which is currently being computed (CPU code perspective).
* `null` means no kernel is being computed.
Expand Down Expand Up @@ -155,6 +169,16 @@ export class WebGpuBackend {
queryType: TimestampQuery;

env: Env;
status: StatusType = StatusType.default;
/**
fs-eire marked this conversation as resolved.
Show resolved Hide resolved
* a SessionID -> CommandInfo[] mapping.
*/
capturedCommandList: Map<number, CommandInfo[]> = new Map();

/**
* a SessionID -> PendingKernelInfo[] mapping for profiling.
*/
private capturedPendingKernels: Map<number, PendingKernelInfo[]> = new Map();

/**
* a SessionID -> a Map of (InputOutputIndex -> [ID, GPUBuffer]) mapping.
Expand Down Expand Up @@ -238,6 +262,7 @@ export class WebGpuBackend {

getComputePassEncoder(): GPUComputePassEncoder {
if (!this.computePassEncoder) {
const commandEncoder = this.getCommandEncoder();
const computePassDescriptor: GPUComputePassDescriptor = {};

if (this.queryType === 'at-passes') {
Expand All @@ -248,7 +273,7 @@ export class WebGpuBackend {
};
}

this.computePassEncoder = this.getCommandEncoder().beginComputePass(computePassDescriptor);
this.computePassEncoder = commandEncoder.beginComputePass(computePassDescriptor);
}
return this.computePassEncoder;
}
Expand Down Expand Up @@ -488,14 +513,17 @@ export class WebGpuBackend {
() => `[ProgramManager] run "${program.name}" (key=${key}) with ${normalizedDispatchGroup[0]}x${
normalizedDispatchGroup[1]}x${normalizedDispatchGroup[2]}`);

if (this.queryType !== 'none') {
if (this.queryType !== 'none' || this.status === StatusType.capture) {
const pendingKernelInfo: PendingKernelInfo = {
kernelId: this.currentKernelId!,
programName: artifact.programInfo.name,
inputTensorViews,
outputTensorViews,
};
this.pendingKernels.push(pendingKernelInfo);

const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!);
sessionPendingKernels!.push(pendingKernelInfo);
}

this.programManager.run(artifact, inputDatas, outputDatas, normalizedDispatchGroup, uniformBufferBinding);
Expand Down Expand Up @@ -656,6 +684,70 @@ export class WebGpuBackend {
}
}
}

captureBegin(sessionHandle: number): void {
LOG_DEBUG('info', () => `captureBegin ${sessionHandle}`);
this.currentSessionId = sessionHandle;
let sessionCommandList = this.capturedCommandList.get(sessionHandle);
let sessionPendingKernels = this.capturedPendingKernels.get(sessionHandle);
if (!sessionCommandList) {
sessionCommandList = [];
this.capturedCommandList.set(sessionHandle, sessionCommandList);
sessionPendingKernels = [];
this.capturedPendingKernels.set(sessionHandle, sessionPendingKernels);
}
this.status = StatusType.capture;
}
captureEnd(sessionHandle: number): void {
LOG_DEBUG('info', () => `captureEnd ${sessionHandle}`);
// flush the left commands before we change the status.
this.flush();
this.currentSessionId = null;
this.status = StatusType.default;
}
replay(sessionHandle: number): void {
LOG_DEBUG('info', () => `replay ${sessionHandle}`);
this.currentSessionId = sessionHandle;
fs-eire marked this conversation as resolved.
Show resolved Hide resolved
this.status = StatusType.replay;
const sessionCommandList = this.capturedCommandList.get(sessionHandle);
const sessionPendingKernels = this.capturedPendingKernels.get(sessionHandle);
const length = sessionCommandList!.length;
this.pendingKernels = [];
for (let i = 0; i < length; i++) {
const computePassEncoder = this.getComputePassEncoder();
const command = sessionCommandList![i];
this.writeTimestamp(this.pendingDispatchNumber * 2);
computePassEncoder.setPipeline(command.computePipeline);
computePassEncoder.setBindGroup(0, command.bindGroup);
computePassEncoder.dispatchWorkgroups(...command.dispatchGroup);
this.writeTimestamp(this.pendingDispatchNumber * 2 + 1);
this.pendingDispatchNumber++;
if (this.queryType !== 'none') {
this.pendingKernels.push(sessionPendingKernels![i]);
}
if (this.pendingDispatchNumber >= this.maxDispatchNumber || this.queryType === 'at-passes') {
this.endComputePass();
}
if (this.pendingDispatchNumber >= this.maxDispatchNumber) {
this.flush();
}
}
// flush the left commands before we change the status.
this.flush();
this.status = StatusType.default;
}

onReleaseSession(sessionId: number): void {
this.unregisterBuffers(sessionId);
if (this.capturedCommandList.has(sessionId)) {
this.capturedCommandList.delete(sessionId);
}
if (this.capturedPendingKernels.has(sessionId)) {
this.capturedPendingKernels.delete(sessionId);
}
this.gpuDataManager.onReleaseSession(sessionId);
}

onRunStart(): void {
this.setQueryType();
}
Expand Down
8 changes: 7 additions & 1 deletion js/web/lib/wasm/jsep/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -201,5 +201,11 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte
contextDataOffset}`);
const context = new ComputeContextImpl(module, backend, contextDataOffset);
return backend.computeKernel(kernel, context, errors);
});
},
// jsepCaptureBegin
(sessionHandle: number) => backend.captureBegin(sessionHandle),
// jsepCaptureEnd
(sessionHandle: number) => backend.captureEnd(sessionHandle),
// jsepReplay
(sessionHandle: number) => backend.replay(sessionHandle));
};
76 changes: 63 additions & 13 deletions js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import {WebGpuBackend} from '../backend-webgpu';
import {LOG_DEBUG} from '../log';

import {GpuData, GpuDataId, GpuDataType} from './types';
import {GpuData, GpuDataId, GpuDataType, StatusType} from './types';

/**
* manages GpuDataId -> GpuBuffer
Expand Down Expand Up @@ -60,9 +60,15 @@ export interface GpuDataManager {
unregisterExternalBuffer(buffer: GPUBuffer): void;

/**
* destroy all gpu buffers. Call this when the session.release is called.
* destroy all gpu buffers.
*/
dispose(): void;

/**
* release session related data.
* @param sessionId - specify the session ID.
*/
onReleaseSession(sessionId: number): void;
}

interface StorageCacheValue {
Expand Down Expand Up @@ -139,13 +145,18 @@ class GpuDataManagerImpl implements GpuDataManager {
// The external buffers registered users for IO Binding.
private externalBuffers: Map<GPUBuffer, GpuDataId>;

// The pendingBuffers for capture graph.
// a SessionID -> GPUBuffer[] mapping.
private capturedPendingBuffers: Map<number, GPUBuffer[]>;

constructor(private backend: WebGpuBackend) {
this.storageCache = new Map();
this.freeBuffers = new Map();
this.freeUniformBuffers = new Map();
this.buffersForUploadingPending = [];
this.buffersPending = [];
this.externalBuffers = new Map();
this.capturedPendingBuffers = new Map();
}

upload(id: GpuDataId, data: Uint8Array): void {
Expand Down Expand Up @@ -220,6 +231,9 @@ class GpuDataManagerImpl implements GpuDataManager {
() => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${
id}, buffer is the same, skip.`);
return id;
} else if (this.backend.capturedCommandList.has(this.backend.currentSessionId!)) {
throw new Error(`Registering a different external buffer under graph capture mode is not supported yet.
Please use the previous external buffer!`);
}
this.externalBuffers.delete(previousBuffer);
} else {
Expand Down Expand Up @@ -312,20 +326,39 @@ class GpuDataManagerImpl implements GpuDataManager {
buffer.destroy();
}
this.buffersForUploadingPending = [];
for (const buffer of this.buffersPending) {
// eslint-disable-next-line no-bitwise
if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) {
// Put the pending buffer to freeBuffers list instead of really destroying it for buffer reusing.
this.freeBuffers.get(buffer.size)!.push(buffer);

if (this.buffersPending.length === 0) {
return;
}

if (this.backend.status === StatusType.default) {
for (const buffer of this.buffersPending) {
// eslint-disable-next-line no-bitwise
} else if ((buffer.usage & GPUBufferUsage.UNIFORM) === GPUBufferUsage.UNIFORM) {
// Put the pending buffer to freeUniformBuffers list instead of really destroying it for buffer reusing.
this.freeUniformBuffers.get(buffer.size)!.push(buffer);
} else {
buffer.destroy();
if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) {
// Put the pending buffer to freeBuffers list instead of really destroying it for buffer reusing.
this.freeBuffers.get(buffer.size)!.push(buffer);
// eslint-disable-next-line no-bitwise
} else if ((buffer.usage & GPUBufferUsage.UNIFORM) === GPUBufferUsage.UNIFORM) {
// Put the pending buffer to freeUniformBuffers list instead of really destroying it for buffer reusing.
this.freeUniformBuffers.get(buffer.size)!.push(buffer);
} else {
buffer.destroy();
}
}
this.buffersPending = [];
} else {
// Don't release intermediate tensors in non-default mode.
// TODO: reuse the storage buffers in non-default mode.
let capturedBuffers = this.capturedPendingBuffers.get(this.backend.currentSessionId!);
if (!capturedBuffers) {
capturedBuffers = [];
this.capturedPendingBuffers.set(this.backend.currentSessionId!, capturedBuffers);
}
for (const buffer of this.buffersPending) {
capturedBuffers.push(buffer);
}
this.buffersPending = [];
}
this.buffersPending = [];
}

dispose() {
Expand All @@ -344,9 +377,26 @@ class GpuDataManagerImpl implements GpuDataManager {
storage.gpuData.buffer.destroy();
});

this.capturedPendingBuffers.forEach((buffers) => {
buffers.forEach(buffer => {
buffer.destroy();
});
});
this.storageCache = new Map();
this.freeBuffers = new Map();
this.freeUniformBuffers = new Map();
this.capturedPendingBuffers = new Map();
}

onReleaseSession(sessionId: number) {
// release the captured pending buffers.
const pendingBuffers = this.capturedPendingBuffers.get(sessionId);
if (pendingBuffers) {
pendingBuffers.forEach(buffer => {
buffer.destroy();
});
this.capturedPendingBuffers.delete(sessionId);
}
}
}

Expand Down
17 changes: 14 additions & 3 deletions js/web/lib/wasm/jsep/webgpu/program-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {WebGpuBackend} from '../backend-webgpu';
import {LOG_DEBUG} from '../log';

import {createShaderHelper} from './ops/common';
import {Artifact, GpuData, ProgramInfo} from './types';
import {Artifact, GpuData, ProgramInfo, StatusType} from './types';

/**
* ProgramManager is the main class behind running computations
Expand Down Expand Up @@ -38,7 +38,6 @@ export class ProgramManager {
const device = this.backend.device;
const computePassEncoder = this.backend.getComputePassEncoder();
this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2);
computePassEncoder.setPipeline(buildArtifact.computePipeline);
const entries = [];
for (const input of inputs) {
entries.push({binding: entries.length, resource: {buffer: input.buffer}});
Expand All @@ -51,8 +50,20 @@ export class ProgramManager {
}
const bindGroup = device.createBindGroup(
{layout: buildArtifact.computePipeline.getBindGroupLayout(0), entries, label: buildArtifact.programInfo.name});
computePassEncoder.setBindGroup(0, bindGroup);

if (this.backend.status === StatusType.capture) {
const commandInfo = {
kernelId: this.backend.currentKernelId!,
computePipeline: buildArtifact.computePipeline,
bindGroup,
dispatchGroup
};
const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!);
sessionCommandList!.push(commandInfo);
}

computePassEncoder.setPipeline(buildArtifact.computePipeline);
computePassEncoder.setBindGroup(0, bindGroup);
computePassEncoder.dispatchWorkgroups(...dispatchGroup);
this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1);
this.backend.pendingDispatchNumber++;
Expand Down
Loading