Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 45 additions & 23 deletions flash_attn/cute/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,59 +64,81 @@ def online_softmax(
# Change acc_S to M,N layout view.
acc_S_mn = utils.make_acc_tensor_mn_view(acc_S)
row_scale = cute.make_fragment_like(self.row_max, Float32)

row_max = self.row_max
row_sum = self.row_sum
scale_log2 = self.scale_log2
arch = self.arch

# Each iteration processes one row of acc_S
for r in cutlass.range_constexpr(cute.size(self.row_max)):
for r in cutlass.range(cute.size(row_max), unroll_full=True):
acc_S_row = acc_S_mn[r, None].load() # (n_block_size)
row_max_cur = self._compute_row_max(

row_max_cur = utils.fmax_reduce(
acc_S_row,
init_val=self.row_max[r] if cutlass.const_expr(not is_first) else None,
init_val=row_max[r] if cutlass.const_expr(not is_first) else None,
arch=arch
)

row_max_cur = utils.warp_reduce(row_max_cur, cute.arch.fmax, width=4)
if cutlass.const_expr(check_inf):
row_max_cur = 0.0 if row_max_cur == -Float32.inf else row_max_cur

if cutlass.const_expr(is_first):
row_max_cur_scaled = row_max_cur * self.scale_log2
acc_S_row_exp = utils.exp2f(acc_S_row * self.scale_log2 - row_max_cur_scaled)
acc_S_row_sum = self._compute_row_sum(acc_S_row_exp)
row_max_cur_scaled = row_max_cur * scale_log2
acc_S_row_exp = utils.exp2f(acc_S_row * scale_log2 - row_max_cur_scaled)

acc_S_row_sum = utils.fadd_reduce(acc_S_row_exp, init_val=None, arch=arch)
row_scale[r] = 1.0
else:
row_max_prev = self.row_max[r]
row_max_cur_scaled = row_max_cur * self.scale_log2
acc_S_row_exp = utils.exp2f(acc_S_row * self.scale_log2 - row_max_cur_scaled)
row_max_prev = row_max[r]
row_max_cur_scaled = row_max_cur * scale_log2
acc_S_row_exp = utils.exp2f(acc_S_row * scale_log2 - row_max_cur_scaled)
# row_scale[r] = utils.exp2f(row_max_prev * self.scale_log2 - row_max_cur_scaled)
row_scale[r] = utils.exp2f((row_max_prev - row_max_cur) * self.scale_log2)
acc_S_row_sum = (
self._compute_row_sum(acc_S_row_exp, init_val=self.row_sum[r] * row_scale[r])
row_scale[r] = utils.exp2f((row_max_prev - row_max_cur) * scale_log2)

acc_S_row_sum = utils.fadd_reduce(
acc_S_row_exp,
init_val=row_sum[r] * row_scale[r],
arch=arch
)
self.row_max[r] = row_max_cur
self.row_sum[r] = acc_S_row_sum

row_max[r] = row_max_cur
row_sum[r] = acc_S_row_sum
acc_S_mn[r, None].store(acc_S_row_exp)

return row_scale

@cute.jit
def finalize(self, final_scale: Float32 = 1.0, sink_val: Float32 | cute.Tensor | None = None) -> cute.Tensor:
"""Finalize the online softmax by computing the scale and logsumexp."""
if cutlass.const_expr(sink_val is not None and isinstance(sink_val, cute.Tensor)):
assert cute.size(sink_val) == cute.size(self.row_sum)
row_sum = self.row_sum
row_max = self.row_max
scale_log2 = self.scale_log2

# quad reduction for row_sum as we didn't do it during each iteration of online softmax
self.row_sum.store(utils.warp_reduce(self.row_sum.load(), operator.add, width=4))
row_scale = cute.make_fragment_like(self.row_max, Float32)
for r in cutlass.range_constexpr(cute.size(self.row_sum)):
row_sum.store(utils.warp_reduce(row_sum.load(), operator.add, width=4))
row_scale = cute.make_fragment_like(row_max, Float32)

for r in cutlass.range(cute.size(row_sum), unroll_full=True):
if cutlass.const_expr(sink_val is not None):
sink_val_cur = sink_val if not isinstance(sink_val, cute.Tensor) else sink_val[r]
LOG2_E = math.log2(math.e)
self.row_sum[r] += utils.exp2f(sink_val_cur * LOG2_E - self.row_max[r] * self.scale_log2)
row_sum[r] += utils.exp2f(sink_val_cur * LOG2_E - row_max[r] * scale_log2)

# if row_sum is zero or NaN, substitute 1.0 for it before taking the reciprocal (avoids inf/NaN in row_scale)
acc_O_mn_row_is_zero_or_nan = (
self.row_sum[r] == 0.0 or self.row_sum[r] != self.row_sum[r]
row_sum[r] == 0.0 or row_sum[r] != row_sum[r]
)
row_scale[r] = (
cute.arch.rcp_approx(self.row_sum[r] if not acc_O_mn_row_is_zero_or_nan else 1.0)
cute.arch.rcp_approx(row_sum[r] if not acc_O_mn_row_is_zero_or_nan else 1.0)
) * final_scale
row_sum_cur = self.row_sum[r]
row_sum_cur = row_sum[r]
LN2 = math.log(2.0)
self.row_sum[r] = (
(self.row_max[r] * self.scale_log2 + utils.log2f(row_sum_cur)) * LN2
row_sum[r] = (
(row_max[r] * scale_log2 + utils.log2f(row_sum_cur)) * LN2
if not acc_O_mn_row_is_zero_or_nan
else -Float32.inf
)
Expand Down