Skip to content

Commit

Permalink
[JS/WebGPU] Fix Split and Where to handle corner cases. (#19613)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
1. Fix Where operator to handle Boolean input less than 4 bytes.
2. Fix JSEP test harness to use tensor names consistently.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
satyajandhyala authored Feb 23, 2024
1 parent 5e432a3 commit ae3d73c
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 3 deletions.
3 changes: 2 additions & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/where.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ const createWhereOpProgramShader =
const expressionA = `a_data[index_a${x}][component_a${x}]`;
const expressionB = `b_data[index_b${x}][component_b${x}]`;
// eslint-disable-next-line no-bitwise
const expressionC = `bool(c_data[index_c${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`;
const expressionC = `bool(c_data[index_c${x}] & (0xffu << (component_c${x} * 8)))`;
return `
let output_indices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
let offset_a${x} = ${a.broadcastedIndicesToOffset(`output_indices${x}`, output)};
Expand All @@ -38,6 +38,7 @@ const createWhereOpProgramShader =
let index_c${x} = offset_c${x} / 4u;
let component_a${x} = offset_a${x} % 4u;
let component_b${x} = offset_b${x} % 4u;
let component_c${x} = offset_c${x} % 4u;
${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)});
`;
};
Expand Down
34 changes: 34 additions & 0 deletions js/web/test/data/ops/where.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -168,5 +168,39 @@
]
}
]
},
{
"name": "Where with no attributes",
"operator": "Where",
"attributes": [],
"cases": [
{
"name": "T[1 1 2 1] T[1 4] T[1 1 2 4] float32 broadcast 1",
"inputs": [
{
"data": [true, false],
"dims": [1, 1, 2, 1],
"type": "bool"
},
{
"data": [1, 2, 3, 4],
"dims": [1, 4],
"type": "float32"
},
{
"data": [5, 6, 7, 8, 9, 10, 11, 12],
"dims": [1, 1, 2, 4],
"type": "float32"
}
],
"outputs": [
{
"data": [1, 2, 3, 4, 9, 10, 11, 12],
"dims": [1, 1, 2, 4],
"type": "float32"
}
]
}
]
}
]
4 changes: 2 additions & 2 deletions js/web/test/test-runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -627,8 +627,8 @@ export async function runModelTestSet(
try {
const feeds: Record<string, ort.Tensor> = {};
const outputsMetaInfo: Record<string, ort.Tensor> = {};
testCase.inputs!.forEach((tensor, i) => feeds[context.session.inputNames[i]] = tensor);
testCase.outputs!.forEach((tensor, i) => outputsMetaInfo[context.session.outputNames[i]] = tensor);
testCase.inputs!.forEach((tensor) => feeds[tensor.name] = tensor);
testCase.outputs!.forEach((tensor) => outputsMetaInfo[tensor.name] = tensor);
const [start, end, outputs] =
await sessionRun({session: context.session, feeds, outputsMetaInfo, ioBinding: context.ioBinding});
if (context.perfData.count === 0) {
Expand Down

0 comments on commit ae3d73c

Please sign in to comment.