-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[POC] __blank ( npm test -- -b=webgpu )
- Loading branch information
Showing
8 changed files
with
199 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
import {env} from 'onnxruntime-common'; | ||
import {Backend, SessionHandler} from '../backend'; | ||
import {Logger} from '../instrument'; | ||
import {Session} from '../session'; | ||
|
||
import {WebGpuSessionHandler} from './webgpu/session-handler'; | ||
|
||
export class WebGpuBackend implements Backend { | ||
initialize(): boolean { | ||
try { | ||
// STEP.1 TODO: set up context (one time initialization) | ||
|
||
// STEP.2 TODO: set up flags | ||
|
||
Logger.setWithEnv(env); | ||
|
||
Logger.verbose('WebGpuBackend', 'Initialized successfully.'); | ||
return true; | ||
} catch (e) { | ||
Logger.warning('WebGpuBackend', `Unable to initialize WebGLBackend. ${e}`); | ||
return false; | ||
} | ||
} | ||
createSessionHandler(context: Session.Context): SessionHandler { | ||
return new WebGpuSessionHandler(this, context); | ||
} | ||
dispose(): void { | ||
// TODO: uninitialization | ||
// this.glContext.dispose(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
import {InferenceHandler} from '../../backend'; | ||
|
||
import {WebGpuSessionHandler} from './session-handler'; | ||
|
||
export class WebGpuInferenceHandler implements InferenceHandler { | ||
constructor(public session: WebGpuSessionHandler) { | ||
// TODO: | ||
} | ||
|
||
dispose(): void {} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
import {OpSet} from '../../opset'; | ||
|
||
export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [ | ||
// ['Abs', '', '6+', unaryOps.abs], | ||
// ['Acos', '', '7+', unaryOps.acos], | ||
// ['Add', '', '7+', binaryOps.add], | ||
// ['And', '', '7+', binaryOps.and], | ||
// ['Asin', '', '7+', unaryOps.asin], | ||
// ['Atan', '', '7+', unaryOps.atan], | ||
// // TODO: support new attributes for AveragePool-10 | ||
// ['AveragePool', '', '7+', averagePool, parseAveragePoolAttributes], | ||
// ['BatchNormalization', '', '7+', batchNormalization, parseBatchNormalizationAttributes], | ||
// ['Cast', '', '6+', cast, parseCastAttributes], | ||
// ['Ceil', '', '6+', unaryOps.ceil], | ||
// ['Clip', '', '6-10', unaryOps.clip, unaryOps.parseClipAttributes], | ||
// ['Clip', '', '11+', unaryOps.clipV11], | ||
// ['Concat', '', '4+', concat, parseConcatAttributes], | ||
// ['Conv', '', '1+', conv, parseConvAttributes], | ||
// ['Cos', '', '7+', unaryOps.cos], | ||
// ['Div', '', '7+', binaryOps.div], | ||
// ['Dropout', '', '7+', unaryOps.identity], | ||
// ['DepthToSpace', '', '1+', depthToSpace, parseDepthToSpaceAttributes], | ||
// ['Equal', '', '7+', binaryOps.equal], | ||
// ['Elu', '', '6+', unaryOps.elu, unaryOps.parseEluAttributes], | ||
// ['Exp', '', '6+', unaryOps.exp], | ||
// ['Flatten', '', '1+', flatten, parseFlattenAttributes], | ||
// ['Floor', '', '6+', unaryOps.floor], | ||
// ['FusedConv', 'com.microsoft', '1+', conv, parseConvAttributes], | ||
// ['Gather', '', '1+', gather, parseGatherAttributes], | ||
// ['Gemm', '', '7-10', gemm, parseGemmAttributesV7], | ||
// ['Gemm', '', '11+', gemm, parseGemmAttributesV11], | ||
// ['GlobalAveragePool', '', '1+', globalAveragePool, parseGlobalAveragePoolAttributes], | ||
// ['GlobalMaxPool', '', '1+', globalMaxPool], | ||
// ['Greater', '', '7+', binaryOps.greater], | ||
// ['Identity', '', '1+', unaryOps.identity], | ||
// ['ImageScaler', '', '1+', imageScaler, parseImageScalerAttributes], | ||
// ['InstanceNormalization', '', '6+', instanceNormalization, parseInstanceNormalizationAttributes], | ||
// ['LeakyRelu', '', '6+', unaryOps.leakyRelu, unaryOps.parseLeakyReluAttributes], | ||
// ['Less', '', '7+', binaryOps.less], | ||
// ['Log', '', '6+', unaryOps.log], | ||
// ['MatMul', '', '1+', matMul, parseMatMulAttributes], | ||
// // TODO: support new attributes for MaxPool-8 and MaxPool-10 | ||
// ['MaxPool', '', '1+', maxPool, parseMaxPoolAttributes], | ||
// ['Mul', '', '7+', binaryOps.mul], | ||
// ['Neg', '', '6+', unaryOps.neg], | ||
// ['Not', '', '1+', unaryOps.not], | ||
// ['Or', '', '7+', binaryOps.or], | ||
// ['Pad', '', '2-10', padV2, parsePadAttributesV2], | ||
// ['Pad', '', '11+', padV11, parsePadAttributesV11], | ||
// ['Pow', '', '7+', binaryOps.pow], | ||
// ['PRelu', '', '7+', binaryOps.pRelu], | ||
// ['ReduceLogSum', '', '1+', reduceLogSum, parseReduceAttributes], | ||
// ['ReduceMax', '', '1+', reduceMax, parseReduceAttributes], | ||
// ['ReduceMean', '', '1+', reduceMean, parseReduceAttributes], | ||
// ['ReduceMin', '', '1+', reduceMin, parseReduceAttributes], | ||
// ['ReduceProd', '', '1+', reduceProd, parseReduceAttributes], | ||
// ['ReduceSum', '', '1-12', reduceSum, parseReduceAttributes], | ||
// ['ReduceSumSquare', '', '1+', reduceLogSumSquare, parseReduceAttributes], | ||
// ['Relu', '', '6+', unaryOps.relu], | ||
// ['Reshape', '', '5+', reshape], | ||
// ['Resize', '', '10', resize, parseResizeAttributesV10], | ||
// ['Resize', '', '11+', resize, parseResizeAttributesV11], | ||
// ['Shape', '', '1+', shape], | ||
// ['Sigmoid', '', '6+', unaryOps.sigmoid], | ||
// ['Sin', '', '7+', unaryOps.sin], | ||
// ['Slice', '', '10+', sliceV10], // TODO: support 'steps' for Slice-10 | ||
// ['Slice', '', '1-9', slice, parseSliceAttributes], | ||
// // The "semantic" meaning of axis has changed in opset-13. | ||
// ['Softmax', '', '1-12', softmax, parseSoftmaxAttributes], | ||
// ['Softmax', '', '13+', softmaxV13, parseSoftmaxAttributesV13], | ||
// // 'Split' operator has an optional attribute 'split' | ||
// // this attribute determines how the specified axis of input data is split. | ||
// // When the attribute is missing, we need the count of number of outputs | ||
// // so that we can determine the 'split' attribute from the runtime input to the Operator | ||
// ['Split', '', '2-12', split, parseSplitAttributes], | ||
// ['Sqrt', '', '6+', unaryOps.sqrt], | ||
// ['Squeeze', '', '1-12', squeeze, parseSqueezeAttributes], | ||
// ['Squeeze', '', '13+', squeezeV13], | ||
// ['Sub', '', '7+', binaryOps.sub], | ||
// ['Sum', '', '6+', sum], | ||
// ['Tan', '', '7+', unaryOps.tan], | ||
// ['Tanh', '', '6+', unaryOps.tanh], | ||
// ['Tile', '', '6+', tile], | ||
// ['Transpose', '', '1+', transpose, parseTransposeAttributes], | ||
// ['Upsample', '', '7-8', upsample, parseUpsampleAttributesV7], | ||
// ['Upsample', '', '9', upsample, parseUpsampleAttributesV9], | ||
// ['Unsqueeze', '', '1-12', unsqueeze, parseUnsqueezeAttributes], | ||
// ['Unsqueeze', '', '13+', unsqueezeV13], | ||
// ['Xor', '', '7+', binaryOps.xor], | ||
]; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
import {SessionHandler} from '../../backend'; | ||
import {Graph} from '../../graph'; | ||
import {Operator} from '../../operators'; | ||
import {OpSet, resolveOperator} from '../../opset'; | ||
import {Session} from '../../session'; | ||
import {Tensor} from '../../tensor'; | ||
import {WebGpuBackend} from '../backend-webgpu'; | ||
import {WebGpuInferenceHandler} from './inference-handler'; | ||
|
||
import {WEBGPU_OP_RESOLVE_RULES} from './op-resolve-rules'; | ||
|
||
export class WebGpuSessionHandler implements SessionHandler { | ||
private initializers: Set<Tensor.Id>; | ||
|
||
constructor(public readonly backend: WebGpuBackend, public readonly context: Session.Context) { | ||
// TODO | ||
} | ||
|
||
createInferenceHandler() { | ||
return new WebGpuInferenceHandler(this); | ||
} | ||
onGraphInitialized(graph: Graph): void { | ||
const initializers = graph.getValues().filter(v => v.from === -1 && v.tensor).map(v => v.tensor!.dataId); | ||
this.initializers = new Set(initializers); | ||
} | ||
isInitializer(tensorId: Tensor.Id): boolean { | ||
return this.initializers ? this.initializers.has(tensorId) : false; | ||
} | ||
addInitializer(tensorId: Tensor.Id): void { | ||
this.initializers.add(tensorId); | ||
} | ||
dispose(): void { | ||
// TODO | ||
} | ||
resolve(node: Graph.Node, opsets: readonly OpSet[], graph: Graph): Operator { | ||
const op = resolveOperator(node, opsets, WEBGPU_OP_RESOLVE_RULES); | ||
return {impl: op.opImpl, context: op.opInit ? op.opInit(node, graph) : node}; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters