Skip to content

Commit

Permalink
slice (...)
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Sep 13, 2022
1 parent 40b15e4 commit 75c7941
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 21 deletions.
5 changes: 3 additions & 2 deletions js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {concat, parseConcatAttributes} from './ops/concat';
import {gather, parseGatherAttributes} from './ops/gather';
import {gemm, parseGemmAttributesV11, parseGemmAttributesV7} from './ops/gemm';
import {reshape} from './ops/reshape';
import {parseSliceAttributes, slice, sliceV10} from './ops/slice';
import * as unaryOps from './ops/unary-op';
import {parseUnsqueezeAttributes, unsqueeze, unsqueezeV13} from './ops/unsqueeze';

Expand Down Expand Up @@ -63,8 +64,8 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [
// ['Resize', '', '11+', resize, parseResizeAttributesV11],
// ['Shape', '', '1+', shape],
['Sigmoid', '', '6+', unaryOps.sigmoid], ['Sin', '', '7+', unaryOps.sin],
// ['Slice', '', '10+', sliceV10], // TODO: support 'steps' for Slice-10
// ['Slice', '', '1-9', slice, parseSliceAttributes],
['Slice', '', '10+', sliceV10], // TODO: support 'steps' for Slice-10
['Slice', '', '1-9', slice, parseSliceAttributes],
// // The "semantic" meaning of axis has changed in opset-13.
// ['Softmax', '', '1-12', softmax, parseSoftmaxAttributes],
// ['Softmax', '', '13+', softmaxV13, parseSoftmaxAttributesV13],
Expand Down
45 changes: 30 additions & 15 deletions js/web/lib/onnxjs/backends/webgpu/ops/slice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ import {NUMBER_TYPES, OperatorAsyncImplementation, OperatorInitialization} from
import {Tensor} from '../../../tensor';
import {ShapeUtil} from '../../../util';
import {WebGpuInferenceHandler} from '../inference-handler';
import {ProgramInfo, TextureType} from '../types';
import {GpuDataType, ProgramInfo} from '../types';

import {WORKGROUP_SIZE} from './common';

export interface SliceAttributes extends AttributeWithCacheKey {
readonly axes: number[];
Expand All @@ -17,8 +19,7 @@ export interface SliceAttributes extends AttributeWithCacheKey {

const sliceProgramMetadata = {
name: 'Slice',
inputNames: ['A'],
inputTypes: [TextureType.unpacked]
inputTypes: [GpuDataType.default]
};

export const slice: OperatorAsyncImplementation<SliceAttributes> = async(
Expand All @@ -40,7 +41,7 @@ export const parseSliceAttributes: OperatorInitialization<SliceAttributes> = (no
return createAttributeWithCacheKey({starts, ends, axes});
};

const createSliceProgramInfo = (input: Tensor, attributes: SliceAttributes): ProgramInfo => {
const createSliceProgramInfo = (input: Tensor, attributes: SliceAttributes, dataType = 'f32'): ProgramInfo => {
const axes = (attributes.axes.length === 0) ? input.dims.slice(0).map((val, i) => i) : attributes.axes;
const normalizedAxes = ShapeUtil.normalizeAxes(axes, input.dims.length);
const starts = attributes.starts.map((start, i) => {
Expand All @@ -58,24 +59,39 @@ const createSliceProgramInfo = (input: Tensor, attributes: SliceAttributes): Pro

const outputShape = input.dims.slice();

const sliceOps: string[] = [];
const sliceOps: Array<[number, number]> = [];
for (let i = 0; i < normalizedAxes.length; i++) {
outputShape[normalizedAxes[i]] = ends[i] - starts[i];
if (starts[i] > 0) {
sliceOps.push(`outputIdx[${normalizedAxes[i]}] += ${starts[i]};`);
// sliceOps.push(`outputIdx[${normalizedAxes[i]}] += ${starts[i]};`);
sliceOps.push([normalizedAxes[i], starts[i]]);
} // else { sliceOps.push(`outputIdx[${normalizedAxes[i]}] += 0;`); }
}

const rank = outputShape.length;
const outputSize = ShapeUtil.size(outputShape);
const outputStrides = ShapeUtil.computeStrides(outputShape);
const shaderSource = `
float process(int outputIdx[${rank}]) {
${sliceOps.join('\n ')}
return _A(outputIdx);
}`;
let WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u;
@group(0) @binding(0) var<storage, read> input : array<${dataType}>;
@group(0) @binding(1) var<storage, write> output : array<${dataType}>;
@stage(compute) @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
// Guard against out-of-bounds work group sizes
if (global_id.x >= ${outputSize}u) {
return;
}
var offset = global_id.x;
${sliceOps.map(i => `offset += ${i[1]}u * ${outputStrides[i[0]]}u;`).join('')}
output[global_id.x] = input[offset];
}`;
return {
...sliceProgramMetadata,
output: {dims: outputShape, type: input.type, textureType: TextureType.unpacked},
shaderSource
outputs: [{dims: outputShape, type: input.type, gpuDataType: GpuDataType.default}],
shaderSource,
dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)})
};
};

Expand All @@ -88,8 +104,7 @@ const validateInputs = (inputs: Tensor[]): void => {
}
};

export const sliceV10: OperatorAsyncImplementation<SliceAttributes> =
async(inferenceHandler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> => {
export const sliceV10 = async(inferenceHandler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> => {
validateInputsV10(inputs);
const attributes = generateSliceAttributesFromInputs(inferenceHandler, inputs);
return inferenceHandler.run(
Expand Down
8 changes: 4 additions & 4 deletions js/web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -478,10 +478,10 @@
// "v{7,8,9,10,11,12}/test_split_equal_parts_default_axis",
// "v{7,8,9,10,11,12}/test_split_equal_parts_1d",
// "v{7,8,9,10,11,12}/test_split_equal_parts_2d",
// "v{7,8,9}/test_slice",
// "v{7,8,9}/test_slice_default_axes",
// "v{7,8,9}/test_slice_end_out_of_bounds",
// "v{7,8,9}/test_slice_neg",
"v{7,8,9}/test_slice",
"v{7,8,9}/test_slice_default_axes",
"v{7,8,9}/test_slice_end_out_of_bounds",
"v{7,8,9}/test_slice_neg",
// "test_slice_start_out_of_bounds", // tensor shape of 0
// "test_squeeze",
"test_tan_example",
Expand Down

0 comments on commit 75c7941

Please sign in to comment.