-
Notifications
You must be signed in to change notification settings - Fork 730
Arm backend: Improve dtype validation #15871
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Arm backend: Improve dtype validation #15871
Conversation
Improve dtype validaiton in NodeVisitors. Signed-off-by: Oscar Andersson <[email protected]> Change-Id: Ieb9ced1ae8d2db916e6c8bc0b45773a640d330db
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/15871
Note: Links to docs will display an error until the docs builds have been completed. ❌ 4 New Failures, 3 Unrelated FailuresAs of commit 4399aa8 with merge base 9952aef ( NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR improves dtype validation in the ARM backend's node visitors by making dtype support conditional on TOSA specifications and extensions, adding bool dtype support to several operators, and refining dtype lists for better consistency and correctness.
- Enhanced dtype validation to conditionally support INT16 and INT48 based on TOSA 1.0 "int16" extension availability
- Added BOOL dtype support to operators like permute, repeat, expand, slice, and cat
- Removed INT8/INT16 from comparison operators (eq, lt, le, gt, ge) to align with TOSA spec constraints
- Added test cases with bool tensors and xfail markers for U55 hardware which doesn't support bool
Reviewed Changes
Copilot reviewed 24 out of 24 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
| backends/arm/test/ops/test_repeat.py | Added bool test case with xfail for U55 |
| backends/arm/test/ops/test_permute.py | Added bool test case with xfail for U55, renamed int16 quantization test functions for clarity (int16→16a8w) |
| backends/arm/test/ops/test_expand.py | Added bool test case with xfail for U55 |
| backends/arm/operators/ops_identity.py | Added conditional dtype validation based on TOSA spec and int16 extension support |
| backends/arm/operators/op_where.py | Deduplicated BOOL from supported dtypes list (moved to base list) |
| backends/arm/operators/op_tosa_transpose.py | Reordered dtype list for consistency (BOOL first, then INT types, then FP types) |
| backends/arm/operators/op_tosa_table.py | Refactored dtype validation to conditionally support INT16 input/INT32 output based on int16 extension |
| backends/arm/operators/op_tosa_resize.py | Improved dtype validation with conditional support for INT16/INT48 based on int16 extension |
| backends/arm/operators/op_tosa_matmul.py | Added conditional support for INT16 input and INT48 output based on int16 extension |
| backends/arm/operators/op_sum.py | Added explicit dtype validation for INT32 and FP32 |
| backends/arm/operators/op_slice.py | Added BOOL dtype support |
| backends/arm/operators/op_repeat.py | Added BOOL dtype support |
| backends/arm/operators/op_permute.py | Added BOOL dtype support |
| backends/arm/operators/op_mul.py | Added INT8 and INT16 dtype support |
| backends/arm/operators/op_max_pool2d.py | Added conditional INT16 support based on int16 extension |
| backends/arm/operators/op_lt.py | Removed INT8 and INT16 from supported input dtypes |
| backends/arm/operators/op_le.py | Removed INT8 and INT16 from supported input dtypes |
| backends/arm/operators/op_index_select.py | Added validation utilities and renamed unused variable from 'index' to '_' |
| backends/arm/operators/op_gt.py | Removed INT8 and INT16 from supported input dtypes |
| backends/arm/operators/op_ge.py | Removed INT8 and INT16 from supported input dtypes |
| backends/arm/operators/op_eq.py | Removed INT8 and INT16 from supported input dtypes |
| backends/arm/operators/op_clamp.py | Added conditional INT16 support based on int16 extension |
| backends/arm/operators/op_cat.py | Added BOOL dtype support and conditional INT16 support, improved validation with proper TosaArg conversion |
| backends/arm/operators/op_avg_pool2d.py | Added conditional INT16 support based on int16 extension |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| class Expand(torch.nn.Module): | ||
| # (input tensor, multiples) | ||
| test_parameters = { | ||
| "randbool_1d": lambda: (torch.randint(0, 1, (1,), dtype=torch.bool), (5,)), |
Copilot
AI
Nov 18, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test case name "randbool_1d" uses torch.randint(0, 1, ...) which only generates 0s since randint excludes the upper bound. This should be torch.randint(0, 2, ...) to generate both 0 and 1 values for proper bool testing.
| "randbool_1d": lambda: (torch.randint(0, 1, (1,), dtype=torch.bool), (5,)), | |
| "randbool_1d": lambda: (torch.randint(0, 2, (1,), dtype=torch.bool), (5,)), |
|
This seem to get fails in the Arm test :( |
Improve dtype validation in node-visitors.
cc @freddan80 @per @zingo @digantdesai