@@ -46,35 +46,67 @@ def get_autotune_config():
46
46
47
47
@triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
@triton.jit
def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, eps,
               BLOCK_SIZE: tl.constexpr):
    """RMS-normalize one row: out = x * rsqrt(mean(x^2) + eps) * g.

    Each program instance handles exactly one row (tl.program_id(0) is the
    row index), walking the row in BLOCK_SIZE-wide chunks so that n_cols may
    exceed BLOCK_SIZE. n_rows is not read in the body; it participates only
    as an autotune key. NOTE(review): assumes the launch grid has exactly
    n_rows programs along axis 0 — confirm against the caller.
    """
    row_idx = tl.program_id(0)
    row_ptr = input_ptr + row_idx * input_row_stride

    # ---- Pass 1: accumulate the mean of squares over the row, block by block.
    # The final (possibly partial) block is peeled off and handled with a mask.
    mean_sq = 0.0
    full_blks = tl.cdiv(n_cols, BLOCK_SIZE) - 1
    for blk in tl.range(0, full_blks):
        cols = blk * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        x = tl.load(row_ptr + cols, cache_modifier=".cg")
        mean_sq += tl.sum(x * x, axis=-1) / n_cols

    # Tail block: masked load; other=0.0 contributes nothing to the sum.
    cols = full_blks * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    tail_mask = cols < n_cols
    x = tl.load(row_ptr + cols, mask=tail_mask, other=0.0, cache_modifier=".cg")
    mean_sq += tl.sum(x * x, axis=-1) / n_cols

    # Reciprocal RMS shared by every element of the row.
    inv_rms = tl.rsqrt(mean_sq + eps)

    # ---- Pass 2: reload each block, scale by inv_rms and the gain g, store.
    out_row_ptr = output_ptr + row_idx * output_row_stride
    for blk in tl.range(0, full_blks):
        cols = blk * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        x = tl.load(row_ptr + cols, cache_modifier=".cg")
        g = tl.load(g_ptr + cols, cache_modifier=".cg")
        tl.store(out_row_ptr + cols, x * inv_rms * g)

    # Masked tail for the normalization pass as well.
    cols = full_blks * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    tail_mask = cols < n_cols
    x = tl.load(row_ptr + cols, mask=tail_mask, other=0.0, cache_modifier=".cg")
    g = tl.load(g_ptr + cols, mask=tail_mask, other=0.0, cache_modifier=".cg")
    tl.store(out_row_ptr + cols, x * inv_rms * g, mask=tail_mask)
73
102
74
103
75
104
def triton_rmsnorm (x , g , epsilon = 1e-6 ):
76
105
n_rows , n_cols = x .shape
77
- BLOCK_SIZE = triton .next_power_of_2 (n_cols )
106
+ #Restricting BLOCK_SIZE to 64Kb is an important optimization. Otherwise,
107
+ #performance can drop significantly for larger n_cols.
108
+ MAX_FUSED_SIZE = 65536 // x .element_size ()
109
+ BLOCK_SIZE = min (MAX_FUSED_SIZE , triton .next_power_of_2 (n_cols ))
78
110
79
111
y = torch .empty_like (x , device = 'cuda' )
80
112
@@ -84,7 +116,6 @@ def triton_rmsnorm(x, g, epsilon=1e-6):
84
116
85
117
return y
86
118
87
-
88
119
def torch_rmsnorm (x , g ):
89
120
M , N = x .shape
90
121
if hasattr (torch .nn , 'RMSNorm' ):
@@ -95,15 +126,17 @@ def torch_rmsnorm(x, g):
95
126
rms_norm = torch .div (x , rms .unsqueeze (1 ).repeat (1 , N )) * g
96
127
return rms_norm
97
128
98
-
129
+ # yapf: disable
99
130
@pytest .mark .parametrize ('M, N' , [
100
- (1 , 4 ),
101
- (2 , 10 ),
102
- (8192 , 4096 ),
103
- (4096 , 8192 ),
104
- (1 , 8192 ),
105
- (873 , 1245 ),
106
- ])
131
+ (1 , 4 ),
132
+ (2 , 10 ),
133
+ (8192 , 4096 ),
134
+ (4096 , 8192 ),
135
+ (1 , 8192 ),
136
+ (873 , 1245 ),
137
+ (1 , 98304 )
138
+ ])
139
+ # yapf: enable
107
140
def test_rmsnorm (M , N ):
108
141
torch .manual_seed (0 )
109
142
x = torch .randn (M , N , device = 'cuda' )
@@ -114,7 +147,6 @@ def test_rmsnorm(M, N):
114
147
115
148
assert torch .allclose (y_triton , y_torch ), (y_triton , y_torch )
116
149
117
-
118
150
#Benchmark
119
151
arg_to_torch_dtype = {'fp16' : torch .float16 , 'bf16' : torch .bfloat16 , 'fp32' : torch .float32 }
120
152
0 commit comments