-
Notifications
You must be signed in to change notification settings - Fork 439
[Refactor] Phaseout legacy alloc_local statement in examples and introduce processing for floating fragment buffers
#1495
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d23c882
c471c62
b2d281c
4240371
c45188f
a377fa5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -173,31 +173,25 @@ def combine( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with T.Kernel(heads, batch, threads=128) as (by, bz): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| po_local = T.alloc_fragment([dim], dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| o_accum_local = T.alloc_fragment([dim], accum_dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lse_local_split = T.alloc_local([1], accum_dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lse_logsum_local = T.alloc_local([1], accum_dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lse_max_local = T.alloc_local([1], accum_dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scale_local = T.alloc_local([1], accum_dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T.annotate_layout( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lse_local_split = T.alloc_var(accum_dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lse_logsum_local = T.alloc_var(accum_dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lse_max_local = T.alloc_var(accum_dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scale_local = T.alloc_var(accum_dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T.clear(lse_logsum_local) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T.clear(o_accum_local) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lse_max_local[0] = -T.infinity(accum_dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lse_max_local = -T.infinity(accum_dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for k in T.serial(num_split): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lse_max_local = T.max(lse_max_local, glse[bz, by, k]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for k in T.Pipelined(num_split, num_stages=1): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lse_local_split[0] = glse[bz, by, k] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lse_local_split = glse[bz, by, k] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lse_logsum_local += T.exp2(lse_local_split - lse_max_local) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for k in T.serial(num_split): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for i in T.Parallel(dim): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| po_local[i] = Output_partial[bz, by, k, i] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lse_local_split[0] = glse[bz, by, k] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lse_local_split = glse[bz, by, k] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scale_local = T.exp2(lse_local_split - lse_logsum_local) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for i in T.Parallel(dim): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| o_accum_local[i] += po_local[i] * scale_local[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+176
to
196
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bug:
🔎 Proposed fix for i in T.Parallel(dim):
- o_accum_local[i] += po_local[i] * scale_local[0]
+ o_accum_local[i] += po_local[i] * scale_local📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for i in T.Parallel(dim): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Critical:
max_splitused uninitialized when all lse values are zero.At line 150,
max_splitis only assigned inside the conditionalif lse_local_split != 0. If all log-sum-exp values are zero (which can occur when no valid blocks are processed in any split),max_splitis never initialized but is subsequently used in conditionals at lines 154 and 159.In the previous array-based implementation,
max_split[0]would have been zero-initialized, providing defined behavior. The scalar refactor introduces undefined behavior in this edge case.🔎 Proposed fix: Initialize max_split before the loop
T.clear(lse_logsum_local) T.clear(o_accum_local) lse_max_local = -T.infinity(accum_dtype) +max_split = -1 for k in T.serial(num_split): lse_local_split = glse[bz, by, k] if lse_local_split != 0: max_split = k lse_max_local = T.max(lse_max_local, glse[bz, by, k])Initializing to
-1ensures that if no valid splits are found, the conditionsk <= max_splitat lines 154 and 159 will be false for all k ≥ 0, correctly skipping the accumulation loops.🤖 Prompt for AI Agents