diff --git a/js/web/lib/wasm/jsep/webgpu/ops/scatter-nd.ts b/js/web/lib/wasm/jsep/webgpu/ops/scatter-nd.ts index 7fd0c2ef42aff..ec1d23e4887d5 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/scatter-nd.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/scatter-nd.ts @@ -78,6 +78,36 @@ const atomicReductionSnippet = (reduction: string, ptr: string, v: string, type: } }; +/** Builds the WGSL fragment that folds one index component into data_offset. It binds element_count_dim and dim_value from uniforms.output_strides and uniforms.output_shape (the bare uniforms when dataRank === 1; otherwise subscripted by i, or by i - indices_start when parallel, with uniforms.last_index_dimension added to the shape subscript), clamps index to dim_value - 1 when it is too large, sets it to 0 when it is below -dim_value, otherwise adds dim_value to a negative index, and finally adds u32(index) * element_count_dim to data_offset. The fragment itself declares none of index, i, data_offset or indices_start, so the enclosing shader text must provide them. */ const calcDataOffsetSnippet = (dataRank: number, parallel: boolean) => + `${ + dataRank === 1 + ? ` + let element_count_dim = uniforms.output_strides; + let dim_value = uniforms.output_shape;` + : ` + let element_count_dim = uniforms.output_strides[${parallel ? 'i - indices_start' : 'i'}]; + let dim_value = uniforms.output_shape[${parallel ? 'i - indices_start' : 'i'} + uniforms.last_index_dimension];` + } + + if (index >= 0) { + if (index >= i32(dim_value)) { + index = i32(dim_value - 1); + } + } else { + if (index < -i32(dim_value)) { + index = 0; + } else { + index += i32(dim_value); + } + } + data_offset += u32((u32(index) * element_count_dim));`; + +/** Builds the WGSL loop that runs i over uniforms.num_updates_elements, reads updates[uniforms.num_updates_elements * base + i] (base is global_idx when parallel, idx otherwise) and applies it to output[data_offset + i] through atomicReductionSnippet with attributes.reduction and outputTypeValue. */ const updateElementsSnippet = (attributes: ScatterNDAttributes, outputTypeValue: ReductionType, parallel: boolean) => + `for (var i = 0u; i < uniforms.num_updates_elements; i++) { + let value = updates[uniforms.num_updates_elements * ${parallel ? 
'global_idx' : 'idx'} + i]; + ${atomicReductionSnippet(attributes.reduction, 'output[data_offset + i]', 'value', outputTypeValue)} + }`; + const createScatterNDProgramInfo = (inputs: readonly TensorView[], attributes: ScatterNDAttributes): ProgramInfo => { const inputShape = inputs[0].dims; const indicesShape = inputs[1].dims; @@ -87,6 +117,7 @@ const createScatterNDProgramInfo = (inputs: readonly TensorView[], attributes: S const outputSize = Math.ceil(ShapeUtil.size(indicesShape) / components); const lastIndexDimension = indicesShape[indicesShape.length - 1]; const numUpdatesElements = ShapeUtil.sizeFromDimension(inputShape, lastIndexDimension); + /* Presumably the total element count of indicesShape divided by lastIndexDimension - verify against ShapeUtil.sizeFromDimension. NOTE(review): the duplicate scan below bounds i and j by this value yet still compares single entries indices[i].x and indices[j].x - confirm that is the intended comparison when lastIndexDimension > 1. */ const numIndicesElements = ShapeUtil.sizeFromDimension(indicesShape, 0) / lastIndexDimension; const programUniforms: ProgramUniform[] = [ { type: DataType.uint32, data: outputSize }, @@ -113,9 +144,8 @@ const createScatterNDProgramInfo = (inputs: readonly TensorView[], attributes: S ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} var hasDuplicates = false; if (${attributes.reduction === 'none'}) { - let n = ${ShapeUtil.size(indicesShape)}; - for (var i = 0; i < n; i = i + 1) { - for (var j = i + 1; j < n; j = j + 1) { + for (var i = 0; i < ${numIndicesElements}; i = i + 1) { + for (var j = i + 1; j < ${numIndicesElements}; j = j + 1) { var index_i = i32(indices[i].x); var index_j = i32(indices[j].x); if (index_i == index_j) { @@ -129,51 +159,31 @@ const createScatterNDProgramInfo = (inputs: readonly TensorView[], attributes: S } } - var data_offset = 0u; - var indices_start = uniforms.last_index_dimension * global_idx; if (${attributes.reduction === 'none'} && hasDuplicates) { if (global_idx != 0u) { return; } - indices_start = 0u; - } - let indices_end = indices_start + uniforms.last_index_dimension; - for (var i = indices_start; i < indices_end; i++) { - var index = i32(indices[i].x); - ${ - inputs[0].dims.length === 1 - ? 
` - let element_count_dim = uniforms.output_strides; - let dim_value = uniforms.output_shape;` - : ` - let element_count_dim = uniforms.output_strides[i - indices_start]; - let dim_value = uniforms.output_shape[i - indices_start + uniforms.last_index_dimension];` - } - if (index >= 0) { - if (index >= i32(dim_value)) { - index = i32(dim_value - 1); - } - } else { - if (index < -i32(dim_value)) { - index = 0; - } else { - index += i32(dim_value); + // Process each index-update pair individually when duplicates exist + for (var idx = 0u; idx < ${numIndicesElements}u; idx++) { + var data_offset = 0u; + for (var i = 0u; i < uniforms.last_index_dimension; i++) { + var index = i32(indices[idx * uniforms.last_index_dimension + i].x); + ${calcDataOffsetSnippet(inputShape.length, false)} } + ${updateElementsSnippet(attributes, output.type.value as ReductionType, false)} } - data_offset += u32((u32(index) * element_count_dim)); + return; } - for (var i = 0u; i < uniforms.num_updates_elements; i++) { - let value = updates[uniforms.num_updates_elements * global_idx + i]; - ${atomicReductionSnippet( - attributes.reduction, - 'output[data_offset + i]', - 'value', - output.type.value as ReductionType, - )} + var data_offset = 0u; + var indices_start = uniforms.last_index_dimension * global_idx; + var indices_end = indices_start + uniforms.last_index_dimension; + for (var i = indices_start; i < indices_end; i++) { + var index = i32(indices[i].x); + ${calcDataOffsetSnippet(inputShape.length, true)} } - - }`; + ${updateElementsSnippet(attributes, output.type.value as ReductionType, true)} + }`; }; return { name: 'ScatterND',