58 changes: 42 additions & 16 deletions src/flag_gems/ops/groupnorm.py
@@ -27,33 +27,47 @@ def group_norm_kernel(
     num_groups,
     eps,
     BLOCK_GROUP_SIZE: tl.constexpr,
-    BLOCK_HW_SIZE: tl.constexpr,
+    BLOCK_HW_SIZE: tl.constexpr = 1024,
 ):
     pid = tle.program_id(0)
     group = pid % num_groups
     num_elements = group_size * HW
     group_offset = tl.arange(0, BLOCK_GROUP_SIZE)
-    hw_offset = tl.arange(0, BLOCK_HW_SIZE)

     wb_offset = group * group_size + group_offset
     wb_mask = wb_offset < C

-    xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :]
-    xy_mask = wb_offset[:, None] < C and hw_offset[None, :] < HW
+    mean_num = tl.zeros([BLOCK_GROUP_SIZE, BLOCK_HW_SIZE], dtype=tl.float32)

-    Mean_ptr = Mean + pid
-    Rstd_ptr = Rstd + pid
+    for off in range(0, HW, BLOCK_HW_SIZE):
+        hw_offset = off + tl.arange(0, BLOCK_HW_SIZE)
+        hw_mask = hw_offset < HW

-    X_ptr = X + xy_offset
-    Y_ptr = Y + xy_offset
+        xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :]
+        xy_mask = wb_mask[:, None] & hw_mask[None, :]

-    X_val = tl.load(X_ptr, mask=xy_mask, other=0.0).to(tl.float32)
-    mean = tl.sum(X_val) / num_elements
-    x = tl.where(xy_mask, X_val - mean, 0.0)
+        X_val = tl.load(X + xy_offset, mask=xy_mask, other=0.0).to(tl.float32)
+        mean_num += X_val

-    var = tl.sum(x * x) / num_elements
+    mean_sum = tl.sum(mean_num)
+    mean = mean_sum / num_elements
+
+    var_num = tl.zeros([BLOCK_GROUP_SIZE, BLOCK_HW_SIZE], dtype=tl.float32)
+
+    for off in range(0, HW, BLOCK_HW_SIZE):
+        hw_offset = off + tl.arange(0, BLOCK_HW_SIZE)
+        hw_mask = hw_offset < HW
+        xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :]
+        xy_mask = wb_mask[:, None] & hw_mask[None, :]
+
+        X_val = tl.load(X + xy_offset, mask=xy_mask, other=0.0).to(tl.float32)
+        x = tl.where(xy_mask, X_val - mean, 0.0)
+
+        var_num += x * x
+
+    var_sum = tl.sum(var_num)
+    var = var_sum / num_elements
     rstd = rsqrt(var + eps)
-    x_hat = x * rstd

     if W is None:
         weight = 1
@@ -63,9 +77,21 @@ def group_norm_kernel(
         bias = 0
     else:
         bias = tl.load(B + wb_offset, mask=wb_mask, other=0.0)[:, None]
-    Y_val = x_hat * weight + bias
+    for off in range(0, HW, BLOCK_HW_SIZE):
+        hw_offset = off + tl.arange(0, BLOCK_HW_SIZE)
+        hw_mask = hw_offset < HW
+        xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :]
+        xy_mask = wb_mask[:, None] & hw_mask[None, :]
+
+        X_val = tl.load(X + xy_offset, mask=xy_mask, other=0.0).to(tl.float32)
+        x = tl.where(xy_mask, X_val - mean, 0.0)
+        x_hat = x * rstd

-    tl.store(Y_ptr, Y_val, mask=xy_mask)
+        Y_val = x_hat * weight + bias
+
+        tl.store(Y + xy_offset, Y_val, mask=xy_mask)
+    Mean_ptr = Mean + pid
+    Rstd_ptr = Rstd + pid
     tl.store(Mean_ptr, mean)
     tl.store(Rstd_ptr, rstd)
@@ -208,7 +234,7 @@ def group_norm(input, weight, bias, N, C, HxW, group, eps=1e-05):
             group,
             eps,
             BLOCK_GROUP_SIZE=triton.next_power_of_2(group_size),
-            BLOCK_HW_SIZE=triton.next_power_of_2(HxW),
+            BLOCK_HW_SIZE=1024,

Collaborator:
Why use fixed BLOCK_HW_SIZE?

Contributor Author:
The previous bug was caused by BLOCK_HW_SIZE=triton.next_power_of_2(HxW), which made the index tensor xy_offset too large. Similar to the backward logic, we can fix BLOCK_HW_SIZE at a constant and introduce loop-blocking to handle a large HW dimension (see the sketch after the diff below). This line can be deleted, as the default value is already used in group_norm_kernel.

         )
     return y, mean, rstd
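To make the loop-blocking idea from the comment above concrete, here is a minimal NumPy sketch (not the Triton kernel itself) of how per-block partial sums over one (batch, group) slice reproduce the full-slice mean and variance before a final blocked normalization pass. The sizes, the BLOCK_HW value, the eps, the assert tolerance, and the omission of weight/bias are illustrative assumptions, not taken from the PR.

```python
import numpy as np

# Illustrative sizes (assumptions): one (batch, group) slice of shape
# (group_size, HW), processed along HW in chunks of BLOCK_HW, mirroring the
# kernel's `for off in range(0, HW, BLOCK_HW_SIZE)` loops.
group_size, HW, BLOCK_HW, eps = 4, 5000, 1024, 1e-5
x = np.random.randn(group_size, HW).astype(np.float32)
num_elements = group_size * HW

# Pass 1: accumulate the sum block by block, then derive the mean.
total = np.float32(0.0)
for off in range(0, HW, BLOCK_HW):
    total += x[:, off:off + BLOCK_HW].sum(dtype=np.float32)
mean = total / num_elements

# Pass 2: accumulate squared deviations block by block, then derive rstd.
sq_sum = np.float32(0.0)
for off in range(0, HW, BLOCK_HW):
    d = x[:, off:off + BLOCK_HW] - mean
    sq_sum += (d * d).sum(dtype=np.float32)
var = sq_sum / num_elements
rstd = 1.0 / np.sqrt(var + eps)

# Pass 3: normalize block by block (weight and bias omitted for brevity).
y = np.empty_like(x)
for off in range(0, HW, BLOCK_HW):
    y[:, off:off + BLOCK_HW] = (x[:, off:off + BLOCK_HW] - mean) * rstd

# The blocked statistics match the unblocked ones up to float32 rounding.
assert np.isclose(mean, x.mean(), atol=1e-4)
assert np.isclose(var, x.var(), atol=1e-4)
```

The trade-off visible in the diff is that X is read once per pass (three times in total) instead of once, in exchange for offset and mask tensors whose size depends only on the fixed block size rather than on HxW.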
