Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 49 additions & 39 deletions js/web/lib/wasm/jsep/webgpu/ops/scatter-nd.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,36 @@ const atomicReductionSnippet = (reduction: string, ptr: string, v: string, type:
}
};

/**
 * Emits the WGSL statements that fold one index component into `data_offset`.
 *
 * The generated code binds `element_count_dim` and `dim_value`, normalizes `index`
 * (values at or above `dim_value` become `dim_value - 1`, values below `-dim_value`
 * become 0, other negative values are shifted up by `dim_value`), and then adds
 * `index * element_count_dim` to `data_offset`.
 *
 * The snippet refers to `index`, `data_offset`, `i` and, when `parallel` is true,
 * `indices_start`; the shader text it is spliced into has to declare them.
 *
 * @param dataRank When 1, `uniforms.output_strides` / `uniforms.output_shape` are read
 *   without a subscript; for any other value they are subscripted.
 * @param parallel Selects the subscript expression: `i - indices_start` when true, `i` otherwise.
 * @returns The WGSL source fragment.
 */
const calcDataOffsetSnippet = (dataRank: number, parallel: boolean) => {
  const position = parallel ? 'i - indices_start' : 'i';

  // NOTE(review): dim_value is read at `position + last_index_dimension` whereas the stride is
  // read at `position` alone - kept exactly as found; confirm the extra offset is intended.
  const bindings =
    dataRank === 1
      ? ['let element_count_dim = uniforms.output_strides;', 'let dim_value = uniforms.output_shape;']
      : [
          `let element_count_dim = uniforms.output_strides[${position}];`,
          `let dim_value = uniforms.output_shape[${position} + uniforms.last_index_dimension];`,
        ];

  const normalizeAndAccumulate = [
    'if (index >= 0) {',
    'if (index >= i32(dim_value)) {',
    'index = i32(dim_value - 1);',
    '}',
    '} else {',
    'if (index < -i32(dim_value)) {',
    'index = 0;',
    '} else {',
    'index += i32(dim_value);',
    '}',
    '}',
    'data_offset += u32((u32(index) * element_count_dim));',
  ];

  // The empty entries reproduce the leading newline and the blank separator line of the fragment.
  return ['', ...bindings, '', ...normalizeAndAccumulate].join('\n');
};

/**
 * Emits the WGSL loop that applies `uniforms.num_updates_elements` consecutive values from
 * `updates` to `output`, starting at `data_offset`.
 *
 * Each value is read from `updates[uniforms.num_updates_elements * <row> + i]` and handed to
 * the statement produced by `atomicReductionSnippet` for `output[data_offset + i]`.
 *
 * @param attributes Operator attributes; only `attributes.reduction` is read here.
 * @param outputTypeValue Forwarded unchanged to `atomicReductionSnippet` as its type argument.
 * @param parallel Selects the row variable: `global_idx` when true, `idx` otherwise.
 * @returns The WGSL source fragment.
 */
const updateElementsSnippet = (attributes: ScatterNDAttributes, outputTypeValue: ReductionType, parallel: boolean) => {
  const updateRow = parallel ? 'global_idx' : 'idx';
  const reduceStatement = atomicReductionSnippet(
    attributes.reduction,
    'output[data_offset + i]',
    'value',
    outputTypeValue,
  );

  const loopHeader = 'for (var i = 0u; i < uniforms.num_updates_elements; i++) {';
  const loadValue = `let value = updates[uniforms.num_updates_elements * ${updateRow} + i];`;

  // Interpolating through a template keeps the helper result's stringification unchanged.
  return `${loopHeader}\n${loadValue}\n${reduceStatement}\n}`;
};

const createScatterNDProgramInfo = (inputs: readonly TensorView[], attributes: ScatterNDAttributes): ProgramInfo => {
const inputShape = inputs[0].dims;
const indicesShape = inputs[1].dims;
Expand All @@ -87,6 +117,7 @@ const createScatterNDProgramInfo = (inputs: readonly TensorView[], attributes: S
const outputSize = Math.ceil(ShapeUtil.size(indicesShape) / components);
const lastIndexDimension = indicesShape[indicesShape.length - 1];
const numUpdatesElements = ShapeUtil.sizeFromDimension(inputShape, lastIndexDimension);
const numIndicesElements = ShapeUtil.sizeFromDimension(indicesShape, 0) / lastIndexDimension;

const programUniforms: ProgramUniform[] = [
{ type: DataType.uint32, data: outputSize },
Expand All @@ -113,9 +144,8 @@ const createScatterNDProgramInfo = (inputs: readonly TensorView[], attributes: S
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
var hasDuplicates = false;
if (${attributes.reduction === 'none'}) {
let n = ${ShapeUtil.size(indicesShape)};
for (var i = 0; i < n; i = i + 1) {
for (var j = i + 1; j < n; j = j + 1) {
for (var i = 0; i < ${numIndicesElements}; i = i + 1) {
for (var j = i + 1; j < ${numIndicesElements}; j = j + 1) {
var index_i = i32(indices[i].x);
var index_j = i32(indices[j].x);
if (index_i == index_j) {
Expand All @@ -129,51 +159,31 @@ const createScatterNDProgramInfo = (inputs: readonly TensorView[], attributes: S
}
}

var data_offset = 0u;
var indices_start = uniforms.last_index_dimension * global_idx;
if (${attributes.reduction === 'none'} && hasDuplicates) {
if (global_idx != 0u) {
return;
}
indices_start = 0u;
}
let indices_end = indices_start + uniforms.last_index_dimension;
for (var i = indices_start; i < indices_end; i++) {
var index = i32(indices[i].x);
${
inputs[0].dims.length === 1
? `
let element_count_dim = uniforms.output_strides;
let dim_value = uniforms.output_shape;`
: `
let element_count_dim = uniforms.output_strides[i - indices_start];
let dim_value = uniforms.output_shape[i - indices_start + uniforms.last_index_dimension];`
}
if (index >= 0) {
if (index >= i32(dim_value)) {
index = i32(dim_value - 1);
}
} else {
if (index < -i32(dim_value)) {
index = 0;
} else {
index += i32(dim_value);
// Process each index-update pair individually when duplicates exist
for (var idx = 0u; idx < ${numIndicesElements}u; idx++) {
var data_offset = 0u;
for (var i = 0u; i < uniforms.last_index_dimension; i++) {
var index = i32(indices[idx * uniforms.last_index_dimension + i].x);
${calcDataOffsetSnippet(inputShape.length, false)}
}
${updateElementsSnippet(attributes, output.type.value as ReductionType, false)}
}
data_offset += u32((u32(index) * element_count_dim));
return;
}

for (var i = 0u; i < uniforms.num_updates_elements; i++) {
let value = updates[uniforms.num_updates_elements * global_idx + i];
${atomicReductionSnippet(
attributes.reduction,
'output[data_offset + i]',
'value',
output.type.value as ReductionType,
)}
var data_offset = 0u;
var indices_start = uniforms.last_index_dimension * global_idx;
var indices_end = indices_start + uniforms.last_index_dimension;
for (var i = indices_start; i < indices_end; i++) {
var index = i32(indices[i].x);
${calcDataOffsetSnippet(inputShape.length, true)}
}

}`;
${updateElementsSnippet(attributes, output.type.value as ReductionType, true)}
}`;
};
return {
name: 'ScatterND',
Expand Down
Loading