Ggml/cuda snake fusion hardening#22912
Conversation
|
ADD/SUB/MUL/DIV in CUDA don't support BF16 in bin_bcast, but supports_op blindly returns true, so my PR crashes the CI when the fallback kicks in. Should I fix supports_op to tell the truth (mirroring Vulkan), or would you prefer extending bin_bcast to BF16 in a dedicated backend PR? |
I think you can just adds the supports_op fallback here, we can add bf16 support if you'd like later. |
bin_bcast only dispatches F32/F16 type triplets, mirror the vulkan filter so unsupported types fall back through cpy instead of aborting.
Does it make sense to add those checks as test cases to test-backend-ops? Seems like the same fusion pattern is implemented independently in multiple backends. |
Done: added rank-3/rank-4 shapes to test_snake_fuse. Mixed-precision is already covered by the existing F16/BF16 variants (x typed, a/inv_b in F32), which exercise the same types_ok rejection path. CUDA: 16/16 SNAKE_FUSE tests passed (F32 + F16, BF16 not supported) |
|
@ggml-org/maintainers Need a re-approval, please. |
* cuda: tighten snake fusion type checks for all operands (defensive, sync vulkan) * cuda: reject snake fusion when ne[2] or ne[3] > 1 (mirror vulkan PR review) * cuda: merge type_ok and types_ok into a single types_ok (address am17an review) * cuda: filter ADD/SUB/MUL/DIV in supports_op to F32/F16 bin_bcast only dispatches F32/F16 type triplets, mirror the vulkan filter so unsupported types fall back through cpy instead of aborting. * test-backend-ops: extend snake_fuse to rank-4 with ne[2]/ne[3] > 1 cases
* cuda: tighten snake fusion type checks for all operands (defensive, sync vulkan) * cuda: reject snake fusion when ne[2] or ne[3] > 1 (mirror vulkan PR review) * cuda: merge type_ok and types_ok into a single types_ok (address am17an review) * cuda: filter ADD/SUB/MUL/DIV in supports_op to F32/F16 bin_bcast only dispatches F32/F16 type triplets, mirror the vulkan filter so unsupported types fall back through cpy instead of aborting. * test-backend-ops: extend snake_fuse to rank-4 with ne[2]/ne[3] > 1 cases
* cuda: tighten snake fusion type checks for all operands (defensive, sync vulkan) * cuda: reject snake fusion when ne[2] or ne[3] > 1 (mirror vulkan PR review) * cuda: merge type_ok and types_ok into a single types_ok (address am17an review) * cuda: filter ADD/SUB/MUL/DIV in supports_op to F32/F16 bin_bcast only dispatches F32/F16 type triplets, mirror the vulkan filter so unsupported types fall back through cpy instead of aborting. * test-backend-ops: extend snake_fuse to rank-4 with ne[2]/ne[3] > 1 cases
* cuda: tighten snake fusion type checks for all operands (defensive, sync vulkan) * cuda: reject snake fusion when ne[2] or ne[3] > 1 (mirror vulkan PR review) * cuda: merge type_ok and types_ok into a single types_ok (address am17an review) * cuda: filter ADD/SUB/MUL/DIV in supports_op to F32/F16 bin_bcast only dispatches F32/F16 type triplets, mirror the vulkan filter so unsupported types fall back through cpy instead of aborting. * test-backend-ops: extend snake_fuse to rank-4 with ne[2]/ne[3] > 1 cases
* cuda: tighten snake fusion type checks for all operands (defensive, sync vulkan) * cuda: reject snake fusion when ne[2] or ne[3] > 1 (mirror vulkan PR review) * cuda: merge type_ok and types_ok into a single types_ok (address am17an review) * cuda: filter ADD/SUB/MUL/DIV in supports_op to F32/F16 bin_bcast only dispatches F32/F16 type triplets, mirror the vulkan filter so unsupported types fall back through cpy instead of aborting. * test-backend-ops: extend snake_fuse to rank-4 with ne[2]/ne[3] > 1 cases
* cuda: tighten snake fusion type checks for all operands (defensive, sync vulkan) * cuda: reject snake fusion when ne[2] or ne[3] > 1 (mirror vulkan PR review) * cuda: merge type_ok and types_ok into a single types_ok (address am17an review) * cuda: filter ADD/SUB/MUL/DIV in supports_op to F32/F16 bin_bcast only dispatches F32/F16 type triplets, mirror the vulkan filter so unsupported types fall back through cpy instead of aborting. * test-backend-ops: extend snake_fuse to rank-4 with ne[2]/ne[3] > 1 cases
Overview
Tightening of fusion pattern matching edge cases, mirroring the Vulkan PR. Thanks to @jeffbolznv for the review remarks.
Additional information
Vulkan counterpart: #22855
All Snake fusion operands and intermediates now share x's type, matching the kernel's single-T template and the float cast on a / inv_b. Mixed-precision chains cleanly fall back to the naive path. Mirrors the Vulkan fix.
Reject Snake fusion when ne[2] > 1 or ne[3] > 1. The kernel only iterates over the first two dimensions, so higher-rank tensors would silently produce garbage on the upper dims. The matcher now falls back to the naive chain, mirroring the Vulkan fix.
Requirements