diff --git a/src/flag_gems/ops/groupnorm.py b/src/flag_gems/ops/groupnorm.py index 7b8219fe4..682d92e6f 100644 --- a/src/flag_gems/ops/groupnorm.py +++ b/src/flag_gems/ops/groupnorm.py @@ -27,33 +27,47 @@ def group_norm_kernel( num_groups, eps, BLOCK_GROUP_SIZE: tl.constexpr, - BLOCK_HW_SIZE: tl.constexpr, + BLOCK_HW_SIZE: tl.constexpr = 1024, ): pid = tle.program_id(0) group = pid % num_groups num_elements = group_size * HW group_offset = tl.arange(0, BLOCK_GROUP_SIZE) - hw_offset = tl.arange(0, BLOCK_HW_SIZE) wb_offset = group * group_size + group_offset wb_mask = wb_offset < C - xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :] - xy_mask = wb_offset[:, None] < C and hw_offset[None, :] < HW + mean_num = tl.zeros([BLOCK_GROUP_SIZE, BLOCK_HW_SIZE], dtype=tl.float32) - Mean_ptr = Mean + pid - Rstd_ptr = Rstd + pid + for off in range(0, HW, BLOCK_HW_SIZE): + hw_offset = off + tl.arange(0, BLOCK_HW_SIZE) + hw_mask = hw_offset < HW - X_ptr = X + xy_offset - Y_ptr = Y + xy_offset + xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :] + xy_mask = wb_mask[:, None] & hw_mask[None, :] - X_val = tl.load(X_ptr, mask=xy_mask, other=0.0).to(tl.float32) - mean = tl.sum(X_val) / num_elements - x = tl.where(xy_mask, X_val - mean, 0.0) + X_val = tl.load(X + xy_offset, mask=xy_mask, other=0.0).to(tl.float32) + mean_num += X_val - var = tl.sum(x * x) / num_elements + mean_sum = tl.sum(mean_num) + mean = mean_sum / num_elements + + var_num = tl.zeros([BLOCK_GROUP_SIZE, BLOCK_HW_SIZE], dtype=tl.float32) + + for off in range(0, HW, BLOCK_HW_SIZE): + hw_offset = off + tl.arange(0, BLOCK_HW_SIZE) + hw_mask = hw_offset < HW + xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :] + xy_mask = wb_mask[:, None] & hw_mask[None, :] + + X_val = tl.load(X + xy_offset, mask=xy_mask, other=0.0).to(tl.float32) + x = tl.where(xy_mask, X_val - mean, 0.0) + + var_num += x * x + + var_sum = tl.sum(var_num) + var = var_sum / num_elements rstd = rsqrt(var + eps) - x_hat = x * rstd if W is None: weight = 1 @@ -63,9 +77,21 @@ def group_norm_kernel( bias = 0 else: bias = tl.load(B + wb_offset, mask=wb_mask, other=0.0)[:, None] - Y_val = x_hat * weight + bias + for off in range(0, HW, BLOCK_HW_SIZE): + hw_offset = off + tl.arange(0, BLOCK_HW_SIZE) + hw_mask = hw_offset < HW + xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :] + xy_mask = wb_mask[:, None] & hw_mask[None, :] + + X_val = tl.load(X + xy_offset, mask=xy_mask, other=0.0).to(tl.float32) + x = tl.where(xy_mask, X_val - mean, 0.0) + x_hat = x * rstd - tl.store(Y_ptr, Y_val, mask=xy_mask) + Y_val = x_hat * weight + bias + + tl.store(Y + xy_offset, Y_val, mask=xy_mask) + Mean_ptr = Mean + pid + Rstd_ptr = Rstd + pid tl.store(Mean_ptr, mean) tl.store(Rstd_ptr, rstd) @@ -208,7 +234,7 @@ def group_norm(input, weight, bias, N, C, HxW, group, eps=1e-05): group, eps, BLOCK_GROUP_SIZE=triton.next_power_of_2(group_size), - BLOCK_HW_SIZE=triton.next_power_of_2(HxW), + BLOCK_HW_SIZE=1024, ) return y, mean, rstd