Skip to content

Commit

Permalink
naive conv
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Sep 13, 2022
1 parent 7c5e446 commit 4ed1bfb
Show file tree
Hide file tree
Showing 5 changed files with 285 additions and 7 deletions.
4 changes: 2 additions & 2 deletions js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import {OpSet} from '../../opset';

import * as binaryOps from './ops/binary-op';
import {concat, parseConcatAttributes} from './ops/concat';
import {conv, parseConvAttributes} from './ops/conv';
import {gather, parseGatherAttributes} from './ops/gather';
import {gemm, parseGemmAttributesV11, parseGemmAttributesV7} from './ops/gemm';
import {matMul, parseMatMulAttributes} from './ops/matmul';
Expand All @@ -23,8 +24,7 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [
// ['Cast', '', '6+', cast, parseCastAttributes],
['Ceil', '', '6+', unaryOps.ceil], ['Clip', '', '6-10', unaryOps.clip, unaryOps.parseClipAttributes],
['Clip', '', '11+', unaryOps.clipV11], ['Concat', '', '4+', concat, parseConcatAttributes],
// ['Conv', '', '1+', conv, parseConvAttributes],
['Cos', '', '7+', unaryOps.cos], ['Div', '', '7+', binaryOps.div],
['Conv', '', '1+', conv, parseConvAttributes], ['Cos', '', '7+', unaryOps.cos], ['Div', '', '7+', binaryOps.div],
// ['Dropout', '', '7+', unaryOps.identity],
// ['DepthToSpace', '', '1+', depthToSpace, parseDepthToSpaceAttributes],
// ['Equal', '', '7+', binaryOps.equal],
Expand Down
3 changes: 2 additions & 1 deletion js/web/lib/onnxjs/backends/webgpu/ops/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ export const createIndicesHelper = (name: string, shape: readonly number[]) => {
const i2oExpression = (varIndices: string, isPtr?: boolean) =>
shape.length < 2 ? `(${isPtr ? '*' : ''}${varIndices})` : `ih_i2o_${name}(${isPtr ? '' : '&'}${varIndices})`;

const indicesVariableDeclaration = (v: string) => `var ${v}:${iType};`;
// Emits a WGSL declaration `var <v>:<iType>;` for an indices variable, optionally
// initialized from the given per-dimension component expressions.
const indicesVariableDeclaration = (v: string, init?: string[]) => {
  const initializer = init ? `=${iType}(${init.join(',')})` : '';
  return `var ${v}:${iType}${initializer};`;
};

return {o2iImpl, o2iCall, i2oImpl, i2oExpression, indicesVariableDeclaration, iType};
};
127 changes: 127 additions & 0 deletions js/web/lib/onnxjs/backends/webgpu/ops/conv-grouped.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {Logger} from '../../../instrument';
import {Tensor} from '../../../tensor';
import {ShapeUtil} from '../../../util';
import {WebGpuInferenceHandler} from '../inference-handler';
import {GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';

import {createIndicesHelper, WORKGROUP_SIZE} from './common';
import {calculateOutputShape, ConvAttributes} from './conv';
import {getActicationSnippet} from './fuse-utils';

// Program metadata for the grouped-conv kernel: x and w are always inputs; the
// bias buffer b is bound only when the Conv node has a third input.
const createGroupedConvProgramMetadata = (hasBias: boolean, cacheHint: string): ProgramMetadata => {
  const inputTypes = [GpuDataType.default, GpuDataType.default];
  if (hasBias) {
    inputTypes.push(GpuDataType.default);
  }
  return {name: 'GroupedConv', inputTypes, cacheHint};
};

/**
 * Builds the ProgramInfo (WGSL source, output layout, dispatch size) for a naive grouped
 * convolution: each shader invocation computes exactly one element of the output tensor
 * by iterating over the kernel window of its group's input channels.
 *
 * Inputs are [x, w] or [x, w, b]. The index order used below treats x/output as NCHW and
 * w as [outChannels, inChannelsPerGroup, kH, kW] — NOTE(review): confirm layout against callers.
 */
const createGroupedConvProgramInfo =
    (inferenceHandler: WebGpuInferenceHandler, inputs: readonly Tensor[], metadata: ProgramMetadata,
     attributes: ConvAttributes): ProgramInfo => {
      const hasBias = inputs.length > 2;
      const processBias = hasBias ? 'value += b[output_channel];' : '';
      const xShape = inputs[0].dims;
      const wShape = inputs[1].dims;
      // number of output channels each group produces (w dim 0 is total output channels)
      const outputChannelsPerGroup = wShape[0] / attributes.group;

      const dataType = 'f32';  // TODO: support other data type
      const {activationFunction, applyActivation} = getActicationSnippet(attributes);
      const inputStorageBuffersDeclarations = [
        `@group(0) @binding(0) var<storage, read> x : array<${dataType}>;`,
        `@group(0) @binding(1) var<storage, read> w : array<${dataType}>;`
      ];
      if (hasBias) {
        inputStorageBuffersDeclarations.push(`@group(0) @binding(2) var<storage, read> b : array<${dataType}>;`);
      }

      // BUGFIX: log key was misspelled 'autpPad'
      Logger.verbose(
          'GroupedConv',
          `autoPad:${attributes.autoPad}, dilations:${attributes.dilations}, group:${attributes.group}, kernelShape:${
              attributes.kernelShape}, pads:${attributes.pads}, strides:${attributes.strides}`);
      const outputShape =
          calculateOutputShape(xShape, wShape, attributes.dilations, attributes.pads, attributes.strides);
      const outputSize = ShapeUtil.size(outputShape);
      const outputIndicesHelper = createIndicesHelper('output', outputShape);
      const xIndicesHelper = createIndicesHelper('x', xShape);
      const wIndicesHelper = createIndicesHelper('w', wShape);

      // NOTE: xRCCorner is vec2<u32>, so `* strides - pads` wraps around on underflow. The
      // `xHeight/xWidth >= shape` guards below also reject the wrapped (huge) values, which
      // makes the `< 0u` comparisons (always false for u32) redundant but harmless.
      const shaderSource = `
  let WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u;
  let strides: vec2<u32> = vec2(${attributes.strides[0]}u, ${attributes.strides[1]}u);
  let pads: vec2<u32> = vec2(${attributes.pads[0]}u, ${attributes.pads[1]}u);

  ${inputStorageBuffersDeclarations.join('\n')}
  @group(0) @binding(${inputStorageBuffersDeclarations.length}) var<storage, write> output : array<${dataType}>;

  ${activationFunction}
  ${outputIndicesHelper.o2iImpl}
  ${xIndicesHelper.i2oImpl}
  ${wIndicesHelper.i2oImpl}

  @stage(compute) @workgroup_size(WORKGROUP_SIZE)
  fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
    // Guard against out-of-bounds work group sizes
    if (global_id.x >= ${outputSize}u) {
      return;
    }

    ${outputIndicesHelper.indicesVariableDeclaration('outputIndices')}
    ${outputIndicesHelper.o2iCall('global_id.x', 'outputIndices')}
    let batch: u32 = outputIndices[0];
    let output_channel: u32 = outputIndices[1];
    let xRCCorner: vec2<u32> = vec2<u32>(outputIndices[2], outputIndices[3]) * strides - pads;
    let group_id: u32 = output_channel / ${outputChannelsPerGroup}u;

    var value: ${dataType} = ${dataType}(0);
    for (var wInChannel: u32 = 0u; wInChannel < ${wShape[1]}u; wInChannel++) {
      let input_channel = group_id * ${wShape[1]}u + wInChannel;
      for (var wHeight: u32 = 0u; wHeight < ${wShape[2]}u; wHeight++) {
        let xHeight = xRCCorner.x + wHeight * ${attributes.dilations[0]}u;
        if (xHeight < 0u || xHeight >= ${xShape[2]}u) {
          continue;
        }
        for (var wWidth: u32 = 0u; wWidth < ${wShape[3]}u; wWidth++) {
          let xWidth = xRCCorner.y + wWidth * ${attributes.dilations[1]}u;
          if (xWidth < 0u || xWidth >= ${xShape[3]}u) {
            continue;
          }

          ${xIndicesHelper.indicesVariableDeclaration('xIndices', ['batch', 'input_channel', 'xHeight', 'xWidth'])}
          let xVal = x[${xIndicesHelper.i2oExpression('xIndices')}];
          ${wIndicesHelper.indicesVariableDeclaration('wIndices', ['output_channel', 'wInChannel', 'wHeight', 'wWidth'])}
          let wVal = w[${wIndicesHelper.i2oExpression('wIndices')}];
          value += xVal*wVal;
        }
      }
    }
    ${processBias}
    ${applyActivation}
    output[global_id.x] = value;
  }`;
      return {
        ...metadata,
        outputs: [{dims: outputShape, type: inputs[0].type, gpuDataType: GpuDataType.default}],
        shaderSource,
        // BUGFIX: the divisor must match the shader's workgroup size; it was hard-coded to 64
        // and would silently under/over-dispatch if WORKGROUP_SIZE ever changed.
        dispatchGroup: () => ({x: Math.ceil(outputSize / WORKGROUP_SIZE)})
      };
    };

/**
 * Returns a ProgramInfoLoader for the grouped-conv kernel. The ProgramInfo itself is
 * generated lazily via `get()`; the metadata (including the attribute cache key) is
 * available immediately for program caching.
 */
export const createGroupedConvProgramInfoLoader =
    (inferenceHandler: WebGpuInferenceHandler, inputs: readonly Tensor[], attributes: ConvAttributes):
        ProgramInfoLoader => {
          const hasBias = inputs.length > 2;
          const metadata = createGroupedConvProgramMetadata(hasBias, attributes.cacheKey);
          return {
            ...metadata,
            get: () => createGroupedConvProgramInfo(inferenceHandler, inputs, metadata, attributes)
          };
        };
150 changes: 150 additions & 0 deletions js/web/lib/onnxjs/backends/webgpu/ops/conv.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attribute-with-cache-key';
import {InferenceHandler} from '../../../backend';
import {Graph} from '../../../graph';
import {OperatorAsyncImplementation, OperatorInitialization} from '../../../operators';
import {Tensor} from '../../../tensor';
import {PoolConvUtil} from '../../../util';
import {WebGpuInferenceHandler} from '../inference-handler';

import {createGroupedConvProgramInfoLoader} from './conv-grouped';
// import {createDotProductProgramInfoLoader} from './dot-product';
import {InternalActivationAttributes, parseInternalActivationAttributes} from './fuse-utils';

// import {createIm2ColProgramInfoLoader} from './im2col';
// import {createMatmulProgramInfoLoader} from './matmul';


/**
 * Computes the convolution output shape `[batch, outChannels, ...spatial]` from the
 * input shape, kernel (weight) shape, dilations, resolved pads and strides.
 *
 * @param inputShape input tensor dims, `[batch, channels, ...spatial]`
 * @param kernelShape weight tensor dims, `[outChannels, inChannelsPerGroup, ...kernelSpatial]`
 * @param adjustPads per-axis begin pads followed by per-axis end pads (length 2 * spatialRank)
 */
export const calculateOutputShape =
    (inputShape: readonly number[], kernelShape: readonly number[], dilations: readonly number[],
     adjustPads: readonly number[], strides: readonly number[]): number[] => {
      const spatialRank = inputShape.length - 2;
      const outputSpatialShape: number[] = [];
      for (let axis = 0; axis < spatialRank; axis++) {
        // effective (dilated) kernel extent along this axis
        const kernel = kernelShape[axis + 2];
        const dilatedKernel = kernel + (kernel - 1) * (dilations[axis] - 1);
        // input extent including begin + end padding
        const padded = inputShape[axis + 2] + adjustPads[axis] + adjustPads[axis + spatialRank];
        outputSpatialShape.push(Math.floor((padded - dilatedKernel + strides[axis]) / strides[axis]));
      }
      return [inputShape[0], kernelShape[0], ...outputSpatialShape];
    };

/**
 * Attributes of a Conv node, combined with fused-activation attributes and a cache key.
 * Produced by parseConvAttributes() and normalized (kernelShape inference, autoPad
 * resolution) by getAdjustedConvAttributes().
 */
export interface ConvAttributes extends InternalActivationAttributes, AttributeWithCacheKey {
  // ONNX auto_pad string; defaults to 'NOTSET' when absent (see parseConvAttributes)
  readonly autoPad: string;
  // per-spatial-axis dilation factors
  readonly dilations: readonly number[];
  // number of groups the input channels are split into (1 = ordinary conv)
  readonly group: number;
  // kernel spatial dims; may be empty, in which case they are inferred from the weight tensor
  readonly kernelShape: readonly number[];
  // begin pads for each spatial axis followed by end pads (length 2 * spatialRank)
  readonly pads: readonly number[];
  // per-spatial-axis strides
  readonly strides: readonly number[];
}

/**
 * `Conv` operator implementation for the WebGPU backend. Only 2-D convolution is
 * supported at the moment; validateInputs() throws for anything else.
 */
export const conv: OperatorAsyncImplementation<ConvAttributes> = async(
    inferenceHandler: InferenceHandler, inputs: Tensor[], attributes: ConvAttributes): Promise<Tensor[]> => {
  validateInputs(inputs, attributes);  // currently will fail if not conv2D
  return conv2d(inferenceHandler, inputs, attributes);
};

// 2-D convolution. Every case is currently routed through the naive grouped-conv program
// (group == 1 degenerates to an ordinary conv); the specialized pointwise / im2col+matmul
// paths are not implemented yet for WebGPU.
const conv2d: OperatorAsyncImplementation<ConvAttributes> =
    async(inferenceHandler: WebGpuInferenceHandler, inputs: Tensor[], attributes: ConvAttributes):
        Promise<Tensor[]> => {
          const adjustedAttributes = getAdjustedConvAttributes(attributes, inputs);
          const loader = createGroupedConvProgramInfoLoader(inferenceHandler, inputs, adjustedAttributes);
          return inferenceHandler.run(loader, inputs);
        };

/**
 * Returns a copy of `attributes` with `kernelShape` inferred from the weight tensor when
 * the attribute was omitted, and `pads` resolved according to `autoPad`. The original
 * attributes object is never mutated.
 */
const getAdjustedConvAttributes = <T extends ConvAttributes>(attributes: T, inputs: Tensor[]): T => {
  // if kernelShape is not specified in the attributes of this op, infer it from the
  // weight tensor dims (dims 0 and 1 are outChannels / inChannelsPerGroup)
  const kernelShape =
      attributes.kernelShape.length > 0 ? attributes.kernelShape.slice() : inputs[1].dims.slice(2);

  // adjustPadsBasedOnAutoPad mutates `pads` in place, so work on a copy
  const pads = attributes.pads.slice();
  PoolConvUtil.adjustPadsBasedOnAutoPad(
      inputs[0].dims, attributes.strides, attributes.dilations, kernelShape, pads, attributes.autoPad);

  // always return a new object so the caller's attributes are left untouched
  return Object.assign({}, attributes, {kernelShape, pads, cacheKey: attributes.cacheKey});
};

/**
 * Reads the Conv node's attributes (with 2-D defaults) and wraps them, together with any
 * fused-activation attributes, in an object carrying a cache key.
 */
export const parseConvAttributes: OperatorInitialization<ConvAttributes> = (node: Graph.Node): ConvAttributes => {
  const attributes = node.attributes;
  const activationAttributes = parseInternalActivationAttributes(attributes);
  // TODO : Make this generic enough to compute default attributes for multi-dimensional conv
  return createAttributeWithCacheKey({
    autoPad: attributes.getString('auto_pad', 'NOTSET'),
    dilations: attributes.getInts('dilations', [1, 1]),
    group: attributes.getInt('group', 1),
    kernelShape: attributes.getInts('kernel_shape', []),
    pads: attributes.getInts('pads', [0, 0, 0, 0]),
    strides: attributes.getInts('strides', [1, 1]),
    ...activationAttributes
  });
};

/**
 * Throws unless the inputs/attributes describe a 2-D float32 convolution this backend can run.
 * Refer to https://github.com/onnx/onnx/blob/master/docs/Operators.md#Conv for all input checks.
 */
const validateInputs = (inputs: Tensor[], attributes: ConvAttributes): void => {
  if (!inputs || (inputs.length !== 2 && inputs.length !== 3)) {
    throw new Error('Conv requires 2 or 3 inputs');
  }

  const [x, w] = inputs;

  // TODO : Need to add support for multi-dimensional conv
  if (x.dims.length !== 4 || w.dims.length !== 4) {
    throw new Error('currently only support 2-dimensional conv');
  }

  // FILTER_IN_CHANNEL should be equal to DATA_CHANNEL
  const dataChannel = x.dims[1];
  const filterInChannel = w.dims[1] * attributes.group;
  if (dataChannel !== filterInChannel) {
    throw new Error('FILTER_IN_CHANNEL should be equal to DATA_CHANNEL');
  }

  // if bias is provided it should be 1D and the number of elements should be equal to the number of feature maps
  if (inputs.length === 3 && (inputs[2].dims.length !== 1 || w.dims[0] !== inputs[2].dims[0])) {
    throw new Error('invalid bias');
  }

  const spatialRank = x.dims.length - 2;
  if (attributes.dilations.length !== spatialRank) {
    throw new Error(`dilations should be ${spatialRank}D`);
  }
  if (attributes.strides.length !== spatialRank) {
    throw new Error(`strides should be ${spatialRank}D`);
  }
  if (attributes.pads.length !== spatialRank * 2) {
    throw new Error(`pads should be ${spatialRank * 2}D`);
  }

  // if kernelShape is specified, its length must be 2 less than the weight tensor's rank
  // (the first 2 dims are batch_size and channels)
  if (attributes.kernelShape.length !== 0 && attributes.kernelShape.length !== w.dims.length - 2) {
    throw new Error('invalid kernel shape');
  }

  // TODO : Need to add support for float64
  if (x.type !== 'float32' || w.type !== 'float32') {
    throw new Error('Conv input(X,W) should be float tensor');
  }
  if (inputs.length === 3 && inputs[2].type !== 'float32') {
    throw new Error('Conv input(bias) should be float tensor');
  }
};
8 changes: 4 additions & 4 deletions js/web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -332,9 +332,9 @@
"test_concat_3d_axis_0",
"test_concat_3d_axis_1",
"test_concat_3d_axis_2",
// "test_conv_with_strides_and_asymmetric_padding",
// "test_conv_with_strides_no_padding",
// "test_conv_with_strides_padding",
"test_conv_with_strides_and_asymmetric_padding",
"test_conv_with_strides_no_padding",
"test_conv_with_strides_padding",
"test_constant",
"test_cos_example",
"test_cos",
Expand Down Expand Up @@ -515,7 +515,7 @@
"asin.jsonc",
"ceil.jsonc",
"concat.jsonc",
//"conv.jsonc",
"conv.jsonc",
"cos.jsonc",
"div.jsonc",
//"depth-to-space.jsonc",
Expand Down

0 comments on commit 4ed1bfb

Please sign in to comment.