fix: use LSE accum strides from params instead of hardcoded ones #2388
Merged
tridao merged 1 commit into Dao-AILab:main on Mar 25, 2026
tridao (Member) approved these changes on Mar 25, 2026:
Thanks!
In the Split-KV path, the forward kernel computes LSE accumulator addresses using hardcoded strides instead of the stride values provided in the params structure. The combine kernel already uses the explicit strides from params, so this creates an inconsistency between the two kernels.
As a result, when the caller supplies an LSE accumulator layout that differs from the layout assumed by the forward kernel, the forward pass writes to incorrect locations and produces wrong output.
This change updates the forward kernel to use the LSE accumulator strides from params, matching the behavior of the combine kernel and ensuring correct results for arbitrary accumulator layouts.
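The mismatch described above can be illustrated with plain offset arithmetic. The following is a minimal sketch, not the actual kernel code: `LseAccumParams`, its field names, the `(split, batch, head, row)` index order, and the specific hardcoded formula are all assumptions made for illustration. It only shows why a formula that assumes one layout disagrees with stride-based addressing once the caller picks another layout.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LseAccumParams:
    """Hypothetical stand-in for the kernel params; field names are illustrative."""
    num_splits: int
    batch: int
    heads: int
    seqlen_q: int
    # Element strides of the LSE accumulator, as supplied by the caller.
    split_stride: int
    batch_stride: int
    head_stride: int
    row_stride: int = 1


def offset_hardcoded(p, split, b, h, row):
    """Assumes a contiguous (split, batch, head, row) layout and ignores p's strides."""
    return ((split * p.batch + b) * p.heads + h) * p.seqlen_q + row


def offset_from_params(p, split, b, h, row):
    """Uses the caller-supplied strides, so it is valid for any layout."""
    return (split * p.split_stride + b * p.batch_stride
            + h * p.head_stride + row * p.row_stride)


def make_params(num_splits, batch, heads, seqlen_q, order):
    """Build contiguous strides for a given dimension order (outermost first)."""
    sizes = {"split": num_splits, "batch": batch, "head": heads, "row": seqlen_q}
    strides, running = {}, 1
    for dim in reversed(order):
        strides[dim] = running
        running *= sizes[dim]
    return LseAccumParams(num_splits, batch, heads, seqlen_q,
                          strides["split"], strides["batch"],
                          strides["head"], strides["row"])


# Layout the hardcoded formula happens to assume: both formulas agree here.
assumed = make_params(2, 3, 4, 8, ("split", "batch", "head", "row"))
# A different caller-chosen layout: split is no longer the outermost dimension.
other = make_params(2, 3, 4, 8, ("batch", "head", "split", "row"))
```

Under `assumed`, both functions return the same offset for every index. Under `other`, the index `(split=1, b=0, h=0, row=0)` maps to offset 8 via `offset_from_params` but to offset 96 via `offset_hardcoded`, so a writer using the hardcoded formula and a reader using the strides would touch different elements.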