Skip to content

Commit

Permalink
[POC] __blank ( npm test -- -b=webgpu )
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Sep 13, 2022
1 parent 9edc946 commit ed35262
Show file tree
Hide file tree
Showing 8 changed files with 199 additions and 3 deletions.
11 changes: 10 additions & 1 deletion js/web/lib/backend-onnxjs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,16 @@ class OnnxjsBackend implements Backend {
// onnxruntime-common).
// In future we should remove Session.Config and use InferenceSession.SessionOptions.
// Currently we allow this to happen to make test runner work.
const session = new Session(options as unknown as Session.Config);
const onnxjsOptions = {...options as unknown as Session.Config};
if (!onnxjsOptions.backendHint && options?.executionProviders && options?.executionProviders[0]) {
const ep = options?.executionProviders[0];
if (typeof ep === 'string') {
onnxjsOptions.backendHint = ep;
} else {
onnxjsOptions.backendHint = ep.name;
}
}
const session = new Session(onnxjsOptions);

// typescript cannot merge method override correctly (so far in 4.2.3). need if-else to call the method.
if (typeof pathOrBuffer === 'string') {
Expand Down
1 change: 1 addition & 0 deletions js/web/lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import {registerBackend} from 'onnxruntime-common';
if (!BUILD_DEFS.DISABLE_WEBGL) {
const onnxjsBackend = require('./backend-onnxjs').onnxjsBackend;
registerBackend('webgl', onnxjsBackend, -1);
registerBackend('webgpu', onnxjsBackend, 999); // set to 999 as the highest priority
}
if (!BUILD_DEFS.DISABLE_WASM) {
const wasmBackend = require('./backend-wasm').wasmBackend;
Expand Down
2 changes: 2 additions & 0 deletions js/web/lib/onnxjs/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

import {WebGLBackend} from './backends/backend-webgl';
import {WebGpuBackend} from './backends/backend-webgpu';
import {Graph} from './graph';
import {Operator} from './operators';
import {OpSet} from './opset';
Expand Down Expand Up @@ -79,6 +80,7 @@ const backendsCache: Map<string, Backend> = new Map();

export const backend: {[name: string]: Backend} = {
webgl: new WebGLBackend(),
webgpu: new WebGpuBackend()
};

/**
Expand Down
34 changes: 34 additions & 0 deletions js/web/lib/onnxjs/backends/backend-webgpu.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {env} from 'onnxruntime-common';
import {Backend, SessionHandler} from '../backend';
import {Logger} from '../instrument';
import {Session} from '../session';

import {WebGpuSessionHandler} from './webgpu/session-handler';

export class WebGpuBackend implements Backend {
initialize(): boolean {
try {
// STEP.1 TODO: set up context (one time initialization)

// STEP.2 TODO: set up flags

Logger.setWithEnv(env);

Logger.verbose('WebGpuBackend', 'Initialized successfully.');
return true;
} catch (e) {
Logger.warning('WebGpuBackend', `Unable to initialize WebGLBackend. ${e}`);
return false;
}
}
createSessionHandler(context: Session.Context): SessionHandler {
return new WebGpuSessionHandler(this, context);
}
dispose(): void {
// TODO: uninitialization
// this.glContext.dispose();
}
}
14 changes: 14 additions & 0 deletions js/web/lib/onnxjs/backends/webgpu/inference-handler.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {InferenceHandler} from '../../backend';

import {WebGpuSessionHandler} from './session-handler';

export class WebGpuInferenceHandler implements InferenceHandler {
constructor(public session: WebGpuSessionHandler) {
// TODO:
}

dispose(): void {}
}
93 changes: 93 additions & 0 deletions js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {OpSet} from '../../opset';

export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [
// ['Abs', '', '6+', unaryOps.abs],
// ['Acos', '', '7+', unaryOps.acos],
// ['Add', '', '7+', binaryOps.add],
// ['And', '', '7+', binaryOps.and],
// ['Asin', '', '7+', unaryOps.asin],
// ['Atan', '', '7+', unaryOps.atan],
// // TODO: support new attributes for AveragePool-10
// ['AveragePool', '', '7+', averagePool, parseAveragePoolAttributes],
// ['BatchNormalization', '', '7+', batchNormalization, parseBatchNormalizationAttributes],
// ['Cast', '', '6+', cast, parseCastAttributes],
// ['Ceil', '', '6+', unaryOps.ceil],
// ['Clip', '', '6-10', unaryOps.clip, unaryOps.parseClipAttributes],
// ['Clip', '', '11+', unaryOps.clipV11],
// ['Concat', '', '4+', concat, parseConcatAttributes],
// ['Conv', '', '1+', conv, parseConvAttributes],
// ['Cos', '', '7+', unaryOps.cos],
// ['Div', '', '7+', binaryOps.div],
// ['Dropout', '', '7+', unaryOps.identity],
// ['DepthToSpace', '', '1+', depthToSpace, parseDepthToSpaceAttributes],
// ['Equal', '', '7+', binaryOps.equal],
// ['Elu', '', '6+', unaryOps.elu, unaryOps.parseEluAttributes],
// ['Exp', '', '6+', unaryOps.exp],
// ['Flatten', '', '1+', flatten, parseFlattenAttributes],
// ['Floor', '', '6+', unaryOps.floor],
// ['FusedConv', 'com.microsoft', '1+', conv, parseConvAttributes],
// ['Gather', '', '1+', gather, parseGatherAttributes],
// ['Gemm', '', '7-10', gemm, parseGemmAttributesV7],
// ['Gemm', '', '11+', gemm, parseGemmAttributesV11],
// ['GlobalAveragePool', '', '1+', globalAveragePool, parseGlobalAveragePoolAttributes],
// ['GlobalMaxPool', '', '1+', globalMaxPool],
// ['Greater', '', '7+', binaryOps.greater],
// ['Identity', '', '1+', unaryOps.identity],
// ['ImageScaler', '', '1+', imageScaler, parseImageScalerAttributes],
// ['InstanceNormalization', '', '6+', instanceNormalization, parseInstanceNormalizationAttributes],
// ['LeakyRelu', '', '6+', unaryOps.leakyRelu, unaryOps.parseLeakyReluAttributes],
// ['Less', '', '7+', binaryOps.less],
// ['Log', '', '6+', unaryOps.log],
// ['MatMul', '', '1+', matMul, parseMatMulAttributes],
// // TODO: support new attributes for MaxPool-8 and MaxPool-10
// ['MaxPool', '', '1+', maxPool, parseMaxPoolAttributes],
// ['Mul', '', '7+', binaryOps.mul],
// ['Neg', '', '6+', unaryOps.neg],
// ['Not', '', '1+', unaryOps.not],
// ['Or', '', '7+', binaryOps.or],
// ['Pad', '', '2-10', padV2, parsePadAttributesV2],
// ['Pad', '', '11+', padV11, parsePadAttributesV11],
// ['Pow', '', '7+', binaryOps.pow],
// ['PRelu', '', '7+', binaryOps.pRelu],
// ['ReduceLogSum', '', '1+', reduceLogSum, parseReduceAttributes],
// ['ReduceMax', '', '1+', reduceMax, parseReduceAttributes],
// ['ReduceMean', '', '1+', reduceMean, parseReduceAttributes],
// ['ReduceMin', '', '1+', reduceMin, parseReduceAttributes],
// ['ReduceProd', '', '1+', reduceProd, parseReduceAttributes],
// ['ReduceSum', '', '1-12', reduceSum, parseReduceAttributes],
// ['ReduceSumSquare', '', '1+', reduceLogSumSquare, parseReduceAttributes],
// ['Relu', '', '6+', unaryOps.relu],
// ['Reshape', '', '5+', reshape],
// ['Resize', '', '10', resize, parseResizeAttributesV10],
// ['Resize', '', '11+', resize, parseResizeAttributesV11],
// ['Shape', '', '1+', shape],
// ['Sigmoid', '', '6+', unaryOps.sigmoid],
// ['Sin', '', '7+', unaryOps.sin],
// ['Slice', '', '10+', sliceV10], // TODO: support 'steps' for Slice-10
// ['Slice', '', '1-9', slice, parseSliceAttributes],
// // The "semantic" meaning of axis has changed in opset-13.
// ['Softmax', '', '1-12', softmax, parseSoftmaxAttributes],
// ['Softmax', '', '13+', softmaxV13, parseSoftmaxAttributesV13],
// // 'Split' operator has an optional attribute 'split'
// // this attribute determines how the specified axis of input data is split.
// // When the attribute is missing, we need the count of number of outputs
// // so that we can determine the 'split' attribute from the runtime input to the Operator
// ['Split', '', '2-12', split, parseSplitAttributes],
// ['Sqrt', '', '6+', unaryOps.sqrt],
// ['Squeeze', '', '1-12', squeeze, parseSqueezeAttributes],
// ['Squeeze', '', '13+', squeezeV13],
// ['Sub', '', '7+', binaryOps.sub],
// ['Sum', '', '6+', sum],
// ['Tan', '', '7+', unaryOps.tan],
// ['Tanh', '', '6+', unaryOps.tanh],
// ['Tile', '', '6+', tile],
// ['Transpose', '', '1+', transpose, parseTransposeAttributes],
// ['Upsample', '', '7-8', upsample, parseUpsampleAttributesV7],
// ['Upsample', '', '9', upsample, parseUpsampleAttributesV9],
// ['Unsqueeze', '', '1-12', unsqueeze, parseUnsqueezeAttributes],
// ['Unsqueeze', '', '13+', unsqueezeV13],
// ['Xor', '', '7+', binaryOps.xor],
];
42 changes: 42 additions & 0 deletions js/web/lib/onnxjs/backends/webgpu/session-handler.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {SessionHandler} from '../../backend';
import {Graph} from '../../graph';
import {Operator} from '../../operators';
import {OpSet, resolveOperator} from '../../opset';
import {Session} from '../../session';
import {Tensor} from '../../tensor';
import {WebGpuBackend} from '../backend-webgpu';
import {WebGpuInferenceHandler} from './inference-handler';

import {WEBGPU_OP_RESOLVE_RULES} from './op-resolve-rules';

export class WebGpuSessionHandler implements SessionHandler {
private initializers: Set<Tensor.Id>;

constructor(public readonly backend: WebGpuBackend, public readonly context: Session.Context) {
// TODO
}

createInferenceHandler() {
return new WebGpuInferenceHandler(this);
}
onGraphInitialized(graph: Graph): void {
const initializers = graph.getValues().filter(v => v.from === -1 && v.tensor).map(v => v.tensor!.dataId);
this.initializers = new Set(initializers);
}
isInitializer(tensorId: Tensor.Id): boolean {
return this.initializers ? this.initializers.has(tensorId) : false;
}
addInitializer(tensorId: Tensor.Id): void {
this.initializers.add(tensorId);
}
dispose(): void {
// TODO
}
resolve(node: Graph.Node, opsets: readonly OpSet[], graph: Graph): Operator {
const op = resolveOperator(node, opsets, WEBGPU_OP_RESOLVE_RULES);
return {impl: op.opImpl, context: op.opInit ? op.opInit(node, graph) : node};
}
}
5 changes: 3 additions & 2 deletions js/web/script/test-runner-cli-args.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ Options:
-b=<...>, --backend=<...> Specify one or more backend(s) to run the test upon.
Backends can be one or more of the following, splitted by comma:
webgl
webgpu
wasm
-e=<...>, --env=<...> Specify the environment to run the test. Should be one of the following:
chrome (default)
Expand Down Expand Up @@ -97,7 +98,7 @@ Examples:

export declare namespace TestRunnerCliArgs {
type Mode = 'suite0'|'suite1'|'model'|'unittest'|'op';
type Backend = 'cpu'|'webgl'|'wasm'|'onnxruntime';
type Backend = 'cpu'|'webgl'|'webgpu'|'wasm'|'onnxruntime';
type Environment = 'chrome'|'edge'|'firefox'|'electron'|'safari'|'node'|'bs';
type BundleMode = 'prod'|'dev'|'perf';
}
Expand Down Expand Up @@ -333,7 +334,7 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs
}

// Option: -b=<...>, --backend=<...>
const browserBackends = ['webgl', 'wasm'];
const browserBackends = ['webgl', 'webgpu', 'wasm'];
const nodejsBackends = ['cpu', 'wasm'];
const backendArgs = args.backend || args.b;
const backend =
Expand Down

0 comments on commit ed35262

Please sign in to comment.