@@ -47,29 +47,29 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor = None, eps: float = 1e-6):
4747def run_single_rank (
4848 tensor_parallel_size ,
4949 single_rank_forward_func ,
50- input ,
51- residual ,
50+ input_list ,
51+ residual_list ,
5252 norm_weight ,
5353 eps ,
5454 hidden_size ,
5555 dtype ,
5656 fused_add_norm ,
57- reference_output ,
57+ reference_output_list ,
5858):
5959 rank = tensorrt_llm .mpi_rank ()
6060 torch .cuda .set_device (rank )
6161 try :
6262 single_rank_forward_func (
63- input ,
64- residual ,
63+ input_list ,
64+ residual_list ,
6565 norm_weight ,
6666 eps ,
6767 hidden_size ,
6868 dtype ,
6969 tensor_parallel_size ,
7070 rank ,
7171 fused_add_norm ,
72- reference_output ,
72+ reference_output_list ,
7373 )
7474 except Exception :
7575 traceback .print_exc ()
@@ -79,25 +79,30 @@ def run_single_rank(
7979
8080@torch .inference_mode ()
8181def row_linear_residual_norm_fusion_forward (
82- x : torch .Tensor ,
83- residual : torch .Tensor ,
82+ x_list : list [ torch .Tensor ] ,
83+ residual_list : list [ torch .Tensor ] ,
8484 norm_weight : torch .Tensor ,
8585 eps : float ,
8686 hidden_size : int ,
8787 dtype : torch .dtype ,
8888 tensor_parallel_size : int ,
8989 tensor_parallel_rank : int ,
9090 fusion : bool ,
91- reference_output : tuple [torch .Tensor , ...],
91+ reference_output_list : list [ tuple [torch .Tensor , ...] ],
9292):
9393
94- x = x .cuda ()
95- residual = residual .cuda ()
94+ # Move all tensors to GPU
95+ x_list = [x .cuda () for x in x_list ]
96+ residual_list = [residual .cuda () for residual in residual_list ]
9697 norm_weight = norm_weight .cuda ()
97- reference_output = tuple (t .cuda () for t in reference_output )
98+ reference_output_list = [
99+ tuple (t .cuda () for t in ref_output )
100+ for ref_output in reference_output_list
101+ ]
98102
99103 MPI .COMM_WORLD .barrier ()
100104
105+ # Create a single AllReduce instance to be reused for all sequence lengths
101106 allreduce = AllReduce (
102107 mapping = Mapping (
103108 world_size = tensor_parallel_size ,
@@ -119,72 +124,116 @@ def func(input, residual, norm_weight, eps, enable_fusion):
119124 residual = residual ,
120125 norm_weight = norm_weight ,
121126 eps = eps ,
122- ))
127+ ),
128+ )
123129 return (output , residual )
124130 else :
125131 output = allreduce (input )
126132 return (output , )
127133
128- output = func (x .clone (), residual .clone (), norm_weight , eps , fusion )
134+ # Process each sequence length using the same AllReduce instance
135+ for i , (x , residual , reference_output ) in enumerate (
136+ zip (x_list , residual_list , reference_output_list )):
137+ output = func (x .clone (), residual .clone (), norm_weight , eps , fusion )
129138
130- torch .testing .assert_close (
131- output [0 ],
132- reference_output [0 ],
133- rtol = 0.05 ,
134- atol = 0.15 ,
135- )
136-
137- if fusion :
138139 torch .testing .assert_close (
139- output [1 ],
140- reference_output [1 ],
140+ output [0 ],
141+ reference_output [0 ],
141142 rtol = 0.05 ,
142143 atol = 0.15 ,
143144 )
144145
146+ if fusion :
147+ torch .testing .assert_close (
148+ output [1 ],
149+ reference_output [1 ],
150+ rtol = 0.05 ,
151+ atol = 0.15 ,
152+ )
153+
145154
146155@skip_pre_blackwell
147156@pytest .mark .skipif (torch .cuda .device_count () < 2 ,
148157 reason = "needs 2 GPUs to run this test" )
149- @pytest .mark .parametrize ("seq_len" , [1 , 4 , 32 , 128 ],
150- ids = lambda x : f"seqlen:{ x } " )
158+ @pytest .mark .parametrize (
159+ "seq_len" ,
160+ [
161+ [
162+ 1 ,
163+ ],
164+ [
165+ 4 ,
166+ ],
167+ [
168+ 15 ,
169+ ],
170+ [
171+ 32 ,
172+ ],
173+ [
174+ 128 ,
175+ ],
176+ [31 , 11 , 27 , 4 ],
177+ ],
178+ ids = lambda x : f"seqlen:{ x } " ,
179+ )
151180@pytest .mark .parametrize ("hidden_size" , [7168 ], ids = lambda x : f"hidden:{ x } " )
181+ @pytest .mark .parametrize ("dtype" ,
182+ [torch .float16 , torch .bfloat16 , torch .float32 ],
183+ ids = lambda x : f"dtype:{ torch .finfo (x ).dtype } " )
152184@pytest .mark .parametrize (
153185 "fusion" ,
154186 [True , False ],
155187 ids = ["fusion" , "no_fusion" ],
156188)
157- def test_row_linear_residual_norm_fusion (seq_len , hidden_size , fusion ):
189+ def test_row_linear_residual_norm_fusion (seq_len , hidden_size , dtype , fusion ):
158190
159191 torch .manual_seed (42 )
160- dtype = torch .bfloat16
161192 tensor_parallel_size = 2
162193
163- x = torch .randn ((tensor_parallel_size , seq_len , hidden_size ), dtype = dtype )
164- residual = torch .randn ((seq_len , hidden_size ), dtype = dtype )
194+ # Create norm_weight once (same for all sequence lengths)
165195 norm_weight = torch .randn ((hidden_size , ), dtype = dtype )
166196 eps = 1e-5
167- reference_output = (torch .sum (x , dim = 0 ), )
168- if fusion :
169- residual_out = reference_output [0 ] + residual
170- reference_output = (rms_norm (residual_out .to (torch .float32 ),
171- norm_weight , eps ).to (dtype ), residual_out )
197+
198+ # Create lists of tensors for each sequence length
199+ x_list = []
200+ residual_list = []
201+ reference_output_list = []
202+
203+ for seq_len_val in seq_len :
204+ x = torch .randn ((tensor_parallel_size , seq_len_val , hidden_size ),
205+ dtype = dtype )
206+ residual = torch .randn ((seq_len_val , hidden_size ), dtype = dtype )
207+ reference_output = (torch .sum (x , dim = 0 ), )
208+ if fusion :
209+ residual_out = reference_output [0 ] + residual
210+ reference_output = (rms_norm (residual_out .to (torch .float32 ),
211+ norm_weight ,
212+ eps ).to (dtype ), residual_out )
213+
214+ x_list .append (x )
215+ residual_list .append (residual )
216+ reference_output_list .append (reference_output )
172217
173218 with MPIPoolExecutor (max_workers = tensor_parallel_size ) as executor :
174219 results = executor .map (
175220 run_single_rank ,
176- * zip (* [(
177- tensor_parallel_size ,
178- row_linear_residual_norm_fusion_forward ,
179- x [i , :, :],
180- residual ,
181- norm_weight ,
182- eps ,
183- hidden_size ,
184- dtype ,
185- fusion ,
186- reference_output ,
187- ) for i in range (tensor_parallel_size )]),
221+ * zip (* [
222+ (
223+ tensor_parallel_size ,
224+ row_linear_residual_norm_fusion_forward ,
225+ [
226+ x [i , :, :] for x in x_list
227+ ], # Extract the i-th rank's data from each sequence length
228+ residual_list ,
229+ norm_weight ,
230+ eps ,
231+ hidden_size ,
232+ dtype ,
233+ fusion ,
234+ reference_output_list ,
235+ ) for i in range (tensor_parallel_size )
236+ ]),
188237 )
189238 for r in results :
190239 assert r is True
0 commit comments