Skip to content

Commit

Permalink
concat
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Sep 13, 2022
1 parent 25c9d2a commit 79dd539
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 12 deletions.
4 changes: 2 additions & 2 deletions js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import {OpSet} from '../../opset';

import * as binaryOps from './ops/binary-op';
import {concat, parseConcatAttributes} from './ops/concat';
import {gather, parseGatherAttributes} from './ops/gather';
import {reshape} from './ops/reshape';
import * as unaryOps from './ops/unary-op';
Expand All @@ -18,8 +19,7 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [
// ['BatchNormalization', '', '7+', batchNormalization, parseBatchNormalizationAttributes],
// ['Cast', '', '6+', cast, parseCastAttributes],
['Ceil', '', '6+', unaryOps.ceil], ['Clip', '', '6-10', unaryOps.clip, unaryOps.parseClipAttributes],
['Clip', '', '11+', unaryOps.clipV11],
// ['Concat', '', '4+', concat, parseConcatAttributes],
['Clip', '', '11+', unaryOps.clipV11], ['Concat', '', '4+', concat, parseConcatAttributes],
// ['Conv', '', '1+', conv, parseConvAttributes],
['Cos', '', '7+', unaryOps.cos], ['Div', '', '7+', binaryOps.div],
// ['Dropout', '', '7+', unaryOps.identity],
Expand Down
13 changes: 10 additions & 3 deletions js/web/lib/onnxjs/backends/webgpu/ops/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,18 @@ export interface IndicesHelper {
i2oImpl: string;
/**
* WGSL code of function implementation for indices-to-offset
*
* @param isPtr - whether the variable is a pointer. default is false.
*/
i2oExpression: (varIndices: string) => string;
i2oExpression: (varIndices: string, isPtr?: boolean) => string;
/**
* WGSL code of indices variable declaration
*/
indicesVariableDeclaration: (v: string) => string;
/**
* data type of indices
*/
iType: string;
}

export const createIndicesHelper = (name: string, shape: readonly number[]) => {
Expand Down Expand Up @@ -72,9 +78,10 @@ export const createIndicesHelper = (name: string, shape: readonly number[]) => {
return ${offsets.length > 0 ? offsets.join('+') : '0u'};
}`;

const i2oExpression = (varIndices: string) => shape.length < 2 ? varIndices : `ih_i2o_${name}(&${varIndices})`;
const i2oExpression = (varIndices: string, isPtr?: boolean) =>
shape.length < 2 ? `(${isPtr ? '*' : ''}${varIndices})` : `ih_i2o_${name}(${isPtr ? '' : '&'}${varIndices})`;

const indicesVariableDeclaration = (v: string) => `var ${v}:${iType};`;

return {o2iImpl, o2iCall, i2oImpl, i2oExpression, indicesVariableDeclaration};
return {o2iImpl, o2iCall, i2oImpl, i2oExpression, indicesVariableDeclaration, iType};
};
173 changes: 173 additions & 0 deletions js/web/lib/onnxjs/backends/webgpu/ops/concat.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attribute-with-cache-key';
import {Graph} from '../../../graph';
import {OperatorInitialization} from '../../../operators';
import {Tensor} from '../../../tensor';
import {ShapeUtil} from '../../../util';
import {WebGpuInferenceHandler} from '../inference-handler';
import {GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
import {createIndicesHelper, IndicesHelper, WORKGROUP_SIZE} from './common';

export interface ConcatAttributes extends AttributeWithCacheKey {
readonly axis: number;
}

export const concat = async(
inferenceHandler: WebGpuInferenceHandler, inputs: Tensor[], attributes: ConcatAttributes): Promise<Tensor[]> => {
validateInputs(inputs);
return inferenceHandler.run(createConcatProgramInfoLoader(inputs, attributes), inputs);
};

const createConcatProgramMetadata = (inputCount: number, cacheHint: string) =>
({name: 'Concat', inputTypes: Array(inputCount).fill(GpuDataType.default), cacheHint});

const createConcatProgramInfo =
(metadata: ProgramMetadata, inputs: Tensor[], axis: number, dataType = 'f32'): ProgramInfo => {
const inputShape = inputs[0].dims.slice();
if (axis >= inputShape.length || axis < (-1 * inputShape.length)) {
throw new Error('axis specified for concat doesn\'t match input dimensionality');
}
if (axis < 0) {
axis = inputShape.length + axis;
}
// ensure all of the non-concatenated axes match each other
// calculate the shape of the output tensor while we do that
const outputShape = inputShape.slice(0);
for (let i = 1; i < inputs.length; i++) {
const dataNShape = inputs[i].dims.slice();
for (let axisIndex = 0; axisIndex < inputShape.length; axisIndex++) {
// add to the placeholder for computing output shape
if (axisIndex === axis) {
outputShape[axis] += dataNShape[axisIndex];
}
// ensure all non-cancatenated axes match each other
else if (inputShape[axisIndex] !== dataNShape[axisIndex]) {
throw new Error('non concat dimensions must match');
}
}
}

const outputSize = ShapeUtil.size(outputShape);
const rank = outputShape.length;

const sizeInConcatAxis = new Array<number>(inputs.length);
const inputStorageBuffersDeclarations = new Array<string>(inputs.length);
const inputIndicesHelpers = new Array<IndicesHelper>(inputs.length);

let previousSum = 0;
for (let i = 0; i < inputs.length; ++i) {
previousSum += inputs[i].dims[axis];
sizeInConcatAxis[i] = previousSum;

inputStorageBuffersDeclarations[i] =
`@group(0) @binding(${i}) var<storage, read> input${i} : array<${dataType}>;`;

inputIndicesHelpers[i] = createIndicesHelper(`input${i}`, inputs[i].dims);
}

const outputIndicesHelper = createIndicesHelper('output', outputShape);

const indicesAxis = rank < 2 ? 'indices' : `indices[${axis}]`;
const shaderSource = `
let WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u;
${inputStorageBuffersDeclarations.join('\n')}
@group(0) @binding(${inputs.length}) var<storage, write> output : array<${dataType}>;
${inputIndicesHelpers.map(i => i.i2oImpl).join('\n')}
${outputIndicesHelper.o2iImpl}
let sizeInConcatAxis = array<u32, ${sizeInConcatAxis.length}>(${sizeInConcatAxis.map(i => `${i}u`).join(',')});
${calculateInputIndexImpl(sizeInConcatAxis.length)}
${readBufferDataImpl(inputIndicesHelpers, rank, dataType)}
@stage(compute) @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
// Guard against out-of-bounds work group sizes
if (global_id.x >= ${outputSize}u) {
return;
}
${outputIndicesHelper.indicesVariableDeclaration('indices')}
${outputIndicesHelper.o2iCall('global_id.x', 'indices')}
let textureIndex = calculateInputIndex(${indicesAxis});
if (textureIndex != 0u) {
${indicesAxis} -= sizeInConcatAxis[textureIndex - 1u];
}
output[global_id.x] = readBufferData(textureIndex, &indices);
}`;
return {
...metadata,
outputs: [{dims: outputShape, type: inputs[0].type, gpuDataType: GpuDataType.default}],
shaderSource,
dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)})
};
};

const createConcatProgramInfoLoader = (inputs: Tensor[], attributes: ConcatAttributes): ProgramInfoLoader => {
const metadata = createConcatProgramMetadata(inputs.length, attributes.cacheKey);
return {...metadata, get: () => createConcatProgramInfo(metadata, inputs, attributes.axis)};
};

const calculateInputIndexImpl = (numberOfTensors: number): string => `
fn calculateInputIndex(index: u32) -> u32 {
for (var i: u32 = 0u; i < ${numberOfTensors}u; i += 1u ) {
if (index < sizeInConcatAxis[i]) {
return i;
}
}
return ${numberOfTensors}u;
}`;

const readBufferDataImpl = (indicesHelper: readonly IndicesHelper[], tensorRank: number, dataType: string) => {
const numberOfTensors = indicesHelper.length;
const codeLines: string[] = [];
for (let i = 0; i < numberOfTensors; ++i) {
const returnSnippet = `return input${i}[${indicesHelper[i].i2oExpression('indices', true)}];`;
if (i === 0) {
codeLines.push(`if (textureIndex == ${i}u) { ${returnSnippet} }`);
} else if (i === numberOfTensors - 1) {
codeLines.push(`else { ${returnSnippet} }`);
} else {
codeLines.push(`else if (textureIndex == ${i}) { ${returnSnippet} }`);
}
}
return `
fn readBufferData(textureIndex: u32, indices: ptr<function, ${indicesHelper[0].iType}>) -> ${dataType} {
${codeLines.join('\n')}
}`;
};

export const parseConcatAttributes: OperatorInitialization<ConcatAttributes> = (node: Graph.Node): ConcatAttributes =>
createAttributeWithCacheKey({axis: node.attributes.getInt('axis')});

const validateInputs = (inputs: Tensor[]): void => {
if (!inputs || inputs.length < 1) {
throw new Error('too few inputs');
}

const inputType = inputs[0].type;
const inputDimensionality = inputs[0].dims.length;

// TODO: Support string concat
if (inputType === 'string') {
throw new Error('string tensor is not supported yet');
}

for (const input of inputs) {
// make sure types of all inputs match
if (input.type !== inputType) {
throw new Error('input tensors should be one type');
}

// make sure the dimensionality of all inputs are the same
if (input.dims.length !== inputDimensionality) {
throw new Error('input tensors should have the same shape');
}
}
};
14 changes: 7 additions & 7 deletions js/web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -326,12 +326,12 @@
"v{7,8,9,10}/test_clip_default_max",
"v{7,8,9,10}/test_clip_default_inbounds",
"v{7,8,9,10}/test_clip",
// "test_concat_1d_axis_0",
// "test_concat_2d_axis_0",
// "test_concat_2d_axis_1",
// "test_concat_3d_axis_0",
// "test_concat_3d_axis_1",
// "test_concat_3d_axis_2",
"test_concat_1d_axis_0",
"test_concat_2d_axis_0",
"test_concat_2d_axis_1",
"test_concat_3d_axis_0",
"test_concat_3d_axis_1",
"test_concat_3d_axis_2",
// "test_conv_with_strides_and_asymmetric_padding",
// "test_conv_with_strides_no_padding",
// "test_conv_with_strides_padding",
Expand Down Expand Up @@ -514,7 +514,7 @@
//"and.jsonc",
"asin.jsonc",
"ceil.jsonc",
//"concat.jsonc",
"concat.jsonc",
//"conv.jsonc",
"cos.jsonc",
"div.jsonc",
Expand Down

0 comments on commit 79dd539

Please sign in to comment.