diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts index cfee07a9239d7..a6375847fc42f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/where.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts @@ -27,7 +27,7 @@ const createWhereOpProgramShader = const expressionA = `a_data[index_a${x}][component_a${x}]`; const expressionB = `b_data[index_b${x}][component_b${x}]`; // eslint-disable-next-line no-bitwise - const expressionC = `bool(c_data[index_c${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`; + const expressionC = `bool(c_data[index_c${x}] & (0xffu << (component_c${x} * 8)))`; return ` let output_indices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)}; let offset_a${x} = ${a.broadcastedIndicesToOffset(`output_indices${x}`, output)}; @@ -38,6 +38,7 @@ const createWhereOpProgramShader = let index_c${x} = offset_c${x} / 4u; let component_a${x} = offset_a${x} % 4u; let component_b${x} = offset_b${x} % 4u; + let component_c${x} = offset_c${x} % 4u; ${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)}); `; }; diff --git a/js/web/test/data/ops/where.jsonc b/js/web/test/data/ops/where.jsonc index 047fd6fd7511b..990120dd3708e 100644 --- a/js/web/test/data/ops/where.jsonc +++ b/js/web/test/data/ops/where.jsonc @@ -168,5 +168,39 @@ ] } ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[1 1 2 1] T[1 4] T[1 1 2 4] float32 broadcast 1", + "inputs": [ + { + "data": [true, false], + "dims": [1, 1, 2, 1], + "type": "bool" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 4], + "type": "float32" + }, + { + "data": [5, 6, 7, 8, 9, 10, 11, 12], + "dims": [1, 1, 2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 4, 9, 10, 11, 12], + "dims": [1, 1, 2, 4], + "type": "float32" + } + ] + } + ] } ] diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index ecc7d4b4a09a5..a4adf5c4ce144 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -627,8 +627,8 @@ export async function runModelTestSet( try { const feeds: Record = {}; const outputsMetaInfo: Record = {}; - testCase.inputs!.forEach((tensor, i) => feeds[context.session.inputNames[i]] = tensor); - testCase.outputs!.forEach((tensor, i) => outputsMetaInfo[context.session.outputNames[i]] = tensor); + testCase.inputs!.forEach((tensor) => feeds[tensor.name] = tensor); + testCase.outputs!.forEach((tensor) => outputsMetaInfo[tensor.name] = tensor); const [start, end, outputs] = await sessionRun({session: context.session, feeds, outputsMetaInfo, ioBinding: context.ioBinding}); if (context.perfData.count === 0) {