Skip to content

Commit

Permalink
pool
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Sep 13, 2022
1 parent 59b10fb commit a2197f0
Show file tree
Hide file tree
Showing 4 changed files with 415 additions and 35 deletions.
15 changes: 8 additions & 7 deletions js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {conv, parseConvAttributes} from './ops/conv';
import {gather, parseGatherAttributes} from './ops/gather';
import {gemm, parseGemmAttributesV11, parseGemmAttributesV7} from './ops/gemm';
import {matMul, parseMatMulAttributes} from './ops/matmul';
import {averagePool, globalAveragePool, globalMaxPool, maxPool, parseAveragePoolAttributes, parseGlobalAveragePoolAttributes, parseMaxPoolAttributes} from './ops/pool';
import {reshape} from './ops/reshape';
import {parseSliceAttributes, slice, sliceV10} from './ops/slice';
import * as unaryOps from './ops/unary-op';
Expand All @@ -18,8 +19,8 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [
['Abs', '', '6+', unaryOps.abs], ['Acos', '', '7+', unaryOps.acos], ['Add', '', '7+', binaryOps.add],
// ['And', '', '7+', binaryOps.and],
['Asin', '', '7+', unaryOps.asin], ['Atan', '', '7+', unaryOps.atan],
// // TODO: support new attributes for AveragePool-10
// ['AveragePool', '', '7+', averagePool, parseAveragePoolAttributes],
// TODO: support new attributes for AveragePool-10
['AveragePool', '', '7+', averagePool, parseAveragePoolAttributes],
// ['BatchNormalization', '', '7+', batchNormalization, parseBatchNormalizationAttributes],
// ['Cast', '', '6+', cast, parseCastAttributes],
['Ceil', '', '6+', unaryOps.ceil], ['Clip', '', '6-10', unaryOps.clip, unaryOps.parseClipAttributes],
Expand All @@ -34,18 +35,18 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [
// ['FusedConv', 'com.microsoft', '1+', conv, parseConvAttributes],
['Gather', '', '1+', gather, parseGatherAttributes], ['Gemm', '', '7-10', gemm, parseGemmAttributesV7],
['Gemm', '', '11+', gemm, parseGemmAttributesV11],
// ['GlobalAveragePool', '', '1+', globalAveragePool, parseGlobalAveragePoolAttributes],
// ['GlobalMaxPool', '', '1+', globalMaxPool],
['GlobalAveragePool', '', '1+', globalAveragePool, parseGlobalAveragePoolAttributes],
['GlobalMaxPool', '', '1+', globalMaxPool],
// ['Greater', '', '7+', binaryOps.greater],
// ['Identity', '', '1+', unaryOps.identity],
// ['ImageScaler', '', '1+', imageScaler, parseImageScalerAttributes],
// ['InstanceNormalization', '', '6+', instanceNormalization, parseInstanceNormalizationAttributes],
['LeakyRelu', '', '6+', unaryOps.leakyRelu, unaryOps.parseLeakyReluAttributes],
// ['Less', '', '7+', binaryOps.less],
['Log', '', '6+', unaryOps.log], ['MatMul', '', '1+', matMul, parseMatMulAttributes],
// // TODO: support new attributes for MaxPool-8 and MaxPool-10
// ['MaxPool', '', '1+', maxPool, parseMaxPoolAttributes],
['Mul', '', '7+', binaryOps.mul], ['Neg', '', '6+', unaryOps.neg],
// TODO: support new attributes for MaxPool-8 and MaxPool-10
['MaxPool', '', '1+', maxPool, parseMaxPoolAttributes], ['Mul', '', '7+', binaryOps.mul],
['Neg', '', '6+', unaryOps.neg],
// ['Not', '', '1+', unaryOps.not],
// ['Or', '', '7+', binaryOps.or],
// ['Pad', '', '2-10', padV2, parsePadAttributesV2],
Expand Down
5 changes: 4 additions & 1 deletion js/web/lib/onnxjs/backends/webgpu/ops/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,11 @@ export interface IndicesHelper {
i2oExpression: (varIndices: string, isPtr?: boolean) => string;
/**
* WGSL code of indices variable declaration
*
* @param v - variable name.
* @param init - initial value.
*/
indicesVariableDeclaration: (v: string) => string;
indicesVariableDeclaration: (v: string, init?: string[]) => string;
/**
* data type of indices
*/
Expand Down
Loading

0 comments on commit a2197f0

Please sign in to comment.