Skip to content

Commit

Permalink
slice (scalar)
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Sep 13, 2022
1 parent 75c7941 commit e871138
Showing 1 changed file with 30 additions and 5 deletions.
35 changes: 30 additions & 5 deletions js/web/lib/onnxjs/backends/webgpu/ops/slice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,29 @@ export const parseSliceAttributes: OperatorInitialization<SliceAttributes> = (no
return createAttributeWithCacheKey({starts, ends, axes});
};

const offsetToIndices = (offset: string, strides: readonly number[], indicesPrefix: string): string => {
const outputLines: string[] = [];

for (let i = 0; i < strides.length - 1; i++) {
outputLines.push(`var ${indicesPrefix}${i}=${offset}/${strides[i]}u;`);
outputLines.push(`${offset}%=${strides[i]}u;`);
}
outputLines.push(`var ${indicesPrefix}${strides.length - 1}=${offset};`);

return outputLines.join('\n');
};

const indicesToOffset = (indicesPrefix: string, strides: readonly number[], offset: string): string => {
const outputLines: string[] = [];

for (let i = 0; i < strides.length - 1; i++) {
outputLines.push(`${offset}+=${indicesPrefix}${i} * ${strides[i]}u;`);
}
outputLines.push(`${offset}+=${indicesPrefix}${strides.length - 1};`);

return outputLines.join('\n');
};

const createSliceProgramInfo = (input: Tensor, attributes: SliceAttributes, dataType = 'f32'): ProgramInfo => {
const axes = (attributes.axes.length === 0) ? input.dims.slice(0).map((val, i) => i) : attributes.axes;
const normalizedAxes = ShapeUtil.normalizeAxes(axes, input.dims.length);
Expand All @@ -59,12 +82,11 @@ const createSliceProgramInfo = (input: Tensor, attributes: SliceAttributes, data

const outputShape = input.dims.slice();

const sliceOps: Array<[number, number]> = [];
const sliceOps: string[] = [];
for (let i = 0; i < normalizedAxes.length; i++) {
outputShape[normalizedAxes[i]] = ends[i] - starts[i];
if (starts[i] > 0) {
// sliceOps.push(`outputIdx[${normalizedAxes[i]}] += ${starts[i]};`);
sliceOps.push([normalizedAxes[i], starts[i]]);
sliceOps.push(`idx_${normalizedAxes[i]} += ${starts[i]}u;`);
} // else { sliceOps.push(`outputIdx[${normalizedAxes[i]}] += 0;`); }
}

Expand All @@ -84,8 +106,11 @@ const createSliceProgramInfo = (input: Tensor, attributes: SliceAttributes, data
}
var offset = global_id.x;
${sliceOps.map(i => `offset += ${i[1]}u * ${outputStrides[i[0]]}u;`).join('')}
output[global_id.x] = input[offset];
${offsetToIndices('offset', outputStrides, 'idx_')}
${sliceOps.join('')}
var offsetInput = 0u;
${indicesToOffset('idx_', ShapeUtil.computeStrides(input.dims), 'offsetInput')}
output[global_id.x] = input[offsetInput];
}`;
return {
...sliceProgramMetadata,
Expand Down

0 comments on commit e871138

Please sign in to comment.