Commit
Showing 5 changed files with 285 additions and 7 deletions.
@@ -0,0 +1,127 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {Logger} from '../../../instrument';
import {Tensor} from '../../../tensor';
import {ShapeUtil} from '../../../util';
import {WebGpuInferenceHandler} from '../inference-handler';
import {GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';

import {createIndicesHelper, WORKGROUP_SIZE} from './common';
import {calculateOutputShape, ConvAttributes} from './conv';
import {getActicationSnippet} from './fuse-utils';

const createGroupedConvProgramMetadata = (hasBias: boolean, cacheHint: string): ProgramMetadata => ({
  name: 'GroupedConv',
  inputTypes: hasBias ? [GpuDataType.default, GpuDataType.default, GpuDataType.default] :
                        [GpuDataType.default, GpuDataType.default],
  cacheHint
});

const createGroupedConvProgramInfo =
    (inferenceHandler: WebGpuInferenceHandler, inputs: readonly Tensor[], metadata: ProgramMetadata,
     attributes: ConvAttributes): ProgramInfo => {
      const hasBias = inputs.length > 2;
      const processBias = hasBias ? 'value += b[output_channel];' : '';
      const xShape = inputs[0].dims;
      const wShape = inputs[1].dims;
      const outputChannelsPerGroup = wShape[0] / attributes.group;

      const dataType = 'f32';  // TODO: support other data type
      const {activationFunction, applyActivation} = getActicationSnippet(attributes);
      const inputStorageBuffersDeclarations = [
        `@group(0) @binding(0) var<storage, read> x : array<${dataType}>;`,
        `@group(0) @binding(1) var<storage, read> w : array<${dataType}>;`
      ];
      if (hasBias) {
        inputStorageBuffersDeclarations.push(`@group(0) @binding(2) var<storage, read> b : array<${dataType}>;`);
      }

      Logger.verbose(
          'GroupedConv',
          `autoPad:${attributes.autoPad}, dilations:${attributes.dilations}, group:${attributes.group}, kernelShape:${
              attributes.kernelShape}, pads:${attributes.pads}, strides:${attributes.strides}`);
      const outputShape =
          calculateOutputShape(xShape, wShape, attributes.dilations, attributes.pads, attributes.strides);
      const outputSize = ShapeUtil.size(outputShape);
      const outputIndicesHelper = createIndicesHelper('output', outputShape);
      const xIndicesHelper = createIndicesHelper('x', xShape);
      const wIndicesHelper = createIndicesHelper('w', wShape);

      const shaderSource = `
  let WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u;
  let strides: vec2<u32> = vec2(${attributes.strides[0]}u, ${attributes.strides[1]}u);
  let pads: vec2<u32> = vec2(${attributes.pads[0]}u, ${attributes.pads[1]}u);
  ${inputStorageBuffersDeclarations.join('\n')}
  @group(0) @binding(${inputStorageBuffersDeclarations.length}) var<storage, write> output : array<${dataType}>;
  ${activationFunction}
  ${outputIndicesHelper.o2iImpl}
  ${xIndicesHelper.i2oImpl}
  ${wIndicesHelper.i2oImpl}
  @stage(compute) @workgroup_size(WORKGROUP_SIZE)
  fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
    // Guard against out-of-bounds work group sizes
    if (global_id.x >= ${outputSize}u) {
      return;
    }
    ${outputIndicesHelper.indicesVariableDeclaration('outputIndices')}
    ${outputIndicesHelper.o2iCall('global_id.x', 'outputIndices')}
    let batch: u32 = outputIndices[0];
    let output_channel: u32 = outputIndices[1];
    let xRCCorner: vec2<u32> = vec2<u32>(outputIndices[2], outputIndices[3]) * strides - pads;
    let group_id: u32 = output_channel / ${outputChannelsPerGroup}u;
    var value: ${dataType} = ${dataType}(0);
    for (var wInChannel: u32 = 0u; wInChannel < ${wShape[1]}u; wInChannel++) {
      let input_channel = group_id * ${wShape[1]}u + wInChannel;
      for (var wHeight: u32 = 0u; wHeight < ${wShape[2]}u; wHeight++) {
        let xHeight = xRCCorner.x + wHeight * ${attributes.dilations[0]}u;
        if (xHeight < 0u || xHeight >= ${xShape[2]}u) {
          continue;
        }
        for (var wWidth: u32 = 0u; wWidth < ${wShape[3]}u; wWidth++) {
          let xWidth = xRCCorner.y + wWidth * ${attributes.dilations[1]}u;
          if (xWidth < 0u || xWidth >= ${xShape[3]}u) {
            continue;
          }
          ${
          xIndicesHelper.indicesVariableDeclaration(
              'xIndices',
              [
                'batch', 'input_channel', 'xHeight', 'xWidth'
              ])}
          let xVal = x[${xIndicesHelper.i2oExpression('xIndices')}];
          ${
          wIndicesHelper.indicesVariableDeclaration('wIndices', [
            'output_channel', 'wInChannel', 'wHeight', 'wWidth'
          ])}
          let wVal = w[${wIndicesHelper.i2oExpression('wIndices')}];
          value += xVal * wVal;
        }
      }
    }
    ${processBias}
    ${applyActivation}
    output[global_id.x] = value;
  }`;
      return {
        ...metadata,
        outputs: [{dims: outputShape, type: inputs[0].type, gpuDataType: GpuDataType.default}],
        shaderSource,
        dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)})
      };
    };

export const createGroupedConvProgramInfoLoader =
    (inferenceHandler: WebGpuInferenceHandler, inputs: readonly Tensor[], attributes: ConvAttributes):
        ProgramInfoLoader => {
          const metadata = createGroupedConvProgramMetadata(inputs.length > 2, attributes.cacheKey);
          return {...metadata, get: () => createGroupedConvProgramInfo(inferenceHandler, inputs, metadata, attributes)};
        };
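For context, a minimal sketch of how this loader is consumed; the pattern mirrors the conv2d path in the second file of this commit, while the tensor handles and attribute object (x, w, adjustedAttributes) are placeholders chosen only to illustrate the grouped layout:

// Illustrative shapes (assumed): X = [1, 4, 8, 8], W = [8, 2, 3, 3], group = 2.
// outputChannelsPerGroup = 8 / 2 = 4, so output channels 0-3 read input channels 0-1
// and output channels 4-7 read input channels 2-3 (input_channel = group_id * 2 + wInChannel).
const loader = createGroupedConvProgramInfoLoader(inferenceHandler, [x, w], adjustedAttributes);
const [y] = await inferenceHandler.run(loader, [x, w]);  // y: [1, 8, 6, 6] for 3x3 kernels, stride 1, no padding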
@@ -0,0 +1,150 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attribute-with-cache-key';
import {InferenceHandler} from '../../../backend';
import {Graph} from '../../../graph';
import {OperatorAsyncImplementation, OperatorInitialization} from '../../../operators';
import {Tensor} from '../../../tensor';
import {PoolConvUtil} from '../../../util';
import {WebGpuInferenceHandler} from '../inference-handler';

import {createGroupedConvProgramInfoLoader} from './conv-grouped';
// import {createDotProductProgramInfoLoader} from './dot-product';
import {InternalActivationAttributes, parseInternalActivationAttributes} from './fuse-utils';

// import {createIm2ColProgramInfoLoader} from './im2col';
// import {createMatmulProgramInfoLoader} from './matmul';

export const calculateOutputShape =
    (inputShape: readonly number[], kernelShape: readonly number[], dilations: readonly number[],
     adjustPads: readonly number[], strides: readonly number[]): number[] => {
      const batchSize = inputShape[0];
      const inputSpatialShape = inputShape.slice(2);
      const spatialRank = inputSpatialShape.length;
      const outChannels = kernelShape[0];
      const kernelSpatialShape = kernelShape.slice(2);
      const dilatedKernelShape = kernelSpatialShape.map((v, i) => v + (v - 1) * (dilations[i] - 1));
      const inputSpatialShapeWithPad = inputSpatialShape.map((v, i) => v + adjustPads[i] + adjustPads[i + spatialRank]);
      const outputSpatialShape =
          inputSpatialShapeWithPad.map((v, i) => Math.floor((v - dilatedKernelShape[i] + strides[i]) / strides[i]));
      const outputShape = [batchSize, outChannels].concat(...outputSpatialShape);
      return outputShape;
    };
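As a sanity check on the arithmetic above, here is a worked example with a common 7x7 stride-2 configuration; the shapes are illustrative and not taken from the commit:

// inputShape = [1, 3, 224, 224], kernelShape = [64, 3, 7, 7],
// dilations = [1, 1], adjustPads = [3, 3, 3, 3], strides = [2, 2]
// dilated kernel:  7 + (7 - 1) * (1 - 1) = 7
// padded input:    224 + 3 + 3 = 230
// spatial output:  floor((230 - 7 + 2) / 2) = 112
const outShape = calculateOutputShape([1, 3, 224, 224], [64, 3, 7, 7], [1, 1], [3, 3, 3, 3], [2, 2]);
// outShape: [1, 64, 112, 112]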

export interface ConvAttributes extends InternalActivationAttributes, AttributeWithCacheKey {
  readonly autoPad: string;
  readonly dilations: readonly number[];
  readonly group: number;
  readonly kernelShape: readonly number[];
  readonly pads: readonly number[];
  readonly strides: readonly number[];
}

export const conv: OperatorAsyncImplementation<ConvAttributes> =
    async(inferenceHandler: InferenceHandler, inputs: Tensor[], attributes: ConvAttributes): Promise<Tensor[]> => {
      validateInputs(inputs, attributes);  // currently will fail if not conv2D
      return conv2d(inferenceHandler, inputs, attributes);
    };

const conv2d: OperatorAsyncImplementation<ConvAttributes> = async(
    inferenceHandler: WebGpuInferenceHandler, inputs: Tensor[], attributes: ConvAttributes): Promise<Tensor[]> => {
  const adjustedAttributes = getAdjustedConvAttributes(attributes, inputs);
  // const isPointwise = adjustedAttributes.kernelShape[0] === 1 && adjustedAttributes.kernelShape[1] === 1;
  // if (adjustedAttributes.group > 1) {
  return inferenceHandler.run(createGroupedConvProgramInfoLoader(inferenceHandler, inputs, adjustedAttributes), inputs);
  // } else if (isPointwise) {
  //   return conv2DPointwise(inferenceHandler, inputs, adjustedAttributes);
  // } else {
  //   return conv2D(inferenceHandler, inputs, adjustedAttributes);
  // }
};

const getAdjustedConvAttributes = <T extends ConvAttributes>(attributes: T, inputs: Tensor[]): T => {
  const kernelShape = attributes.kernelShape.slice();
  // if kernelShape is not specified in the attributes of this op, infer it from the weight tensor dims
  if (attributes.kernelShape.length === 0) {
    for (let i = 2; i < inputs[1].dims.length; ++i) {
      kernelShape.push(inputs[1].dims[i]);
    }
  }
  const pads = attributes.pads.slice();
  PoolConvUtil.adjustPadsBasedOnAutoPad(
      inputs[0].dims, attributes.strides, attributes.dilations, kernelShape, pads, attributes.autoPad);

  // always return a new object so that the original attributes are not modified
  const newAttributes: T = Object.assign({}, attributes);
  Object.assign(newAttributes, {kernelShape, pads, cacheKey: attributes.cacheKey});
  return newAttributes;
};

export const parseConvAttributes: OperatorInitialization<ConvAttributes> = (node: Graph.Node): ConvAttributes => {
  const attributes = node.attributes;
  const activationAttributes = parseInternalActivationAttributes(attributes);
  // TODO : Make this generic enough to compute default attributes for multi-dimensional conv
  const autoPad = attributes.getString('auto_pad', 'NOTSET');
  const dilations = attributes.getInts('dilations', [1, 1]);
  const group = attributes.getInt('group', 1);
  const kernelShape = attributes.getInts('kernel_shape', []);
  const pads = attributes.getInts('pads', [0, 0, 0, 0]);
  const strides = attributes.getInts('strides', [1, 1]);

  return createAttributeWithCacheKey({autoPad, dilations, group, kernelShape, pads, strides, ...activationAttributes});
};

const validateInputs = (inputs: Tensor[], attributes: ConvAttributes): void => {
  // Refer to the below link for all input checks
  // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Conv
  if (!inputs || (inputs.length !== 2 && inputs.length !== 3)) {
    throw new Error('Conv requires 2 or 3 inputs');
  }

  // TODO : Need to add support for multi-dimensional conv
  if (inputs[0].dims.length !== 4 || inputs[1].dims.length !== 4) {
    throw new Error('currently only support 2-dimensional conv');
  }

  // FILTER_IN_CHANNEL should be equal to DATA_CHANNEL
  const dataChannel = inputs[0].dims[1];
  const filterInChannel = inputs[1].dims[1] * attributes.group;
  if (dataChannel !== filterInChannel) {
    throw new Error('FILTER_IN_CHANNEL should be equal to DATA_CHANNEL');
  }

  // if bias is provided it should be 1D and the number of elements should be equal to the number of feature maps
  if (inputs.length === 3 && (inputs[2].dims.length !== 1 || inputs[1].dims[0] !== inputs[2].dims[0])) {
    throw new Error('invalid bias');
  }

  const spatialRank = inputs[0].dims.length - 2;
  // wrong dilations dimension
  if (attributes.dilations.length !== spatialRank) {
    throw new Error(`dilations should be ${spatialRank}D`);
  }

  // wrong strides dimension
  if (attributes.strides.length !== spatialRank) {
    throw new Error(`strides should be ${spatialRank}D`);
  }

  // wrong pads dimension
  if (attributes.pads.length !== spatialRank * 2) {
    throw new Error(`pads should be ${spatialRank * 2}D`);
  }

  // if kernelShape is specified, its data length must be 2 less than the dims length of the weights tensor
  // (the first 2 dims are batch_size and channels)
  if (attributes.kernelShape.length !== 0 && attributes.kernelShape.length !== inputs[1].dims.length - 2) {
    throw new Error('invalid kernel shape');
  }

  // TODO : Need to add support for float64
  if (inputs[0].type !== 'float32' || inputs[1].type !== 'float32') {
    throw new Error('Conv input(X,W) should be float tensor');
  }

  if (inputs.length === 3 && inputs[2].type !== 'float32') {
    throw new Error('Conv input(bias) should be float tensor');
  }
};
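Putting the checks together, a combination that passes validation for the 2-D path looks like the following; the tensors and attribute object are hypothetical and shown only for their shapes and types:

// X: float32 [1, 4, 8, 8], W: float32 [8, 2, 3, 3], B: float32 [8], attributes.group = 2
// dataChannel = 4 and filterInChannel = 2 * 2 = 4, so the channel check passes;
// bias length 8 equals W.dims[0]; dilations/strides have 2 entries and pads has 4.
validateInputs([x, w, b], attributes);  // does not throw for the shapes above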