@@ -101,6 +101,131 @@ def func(
101101 )
102102
103103
104+ @torch .inference_mode ()
105+ def row_linear_residual_norm_fusion_forward_legacy (
106+ x : torch .Tensor ,
107+ residual : torch .Tensor ,
108+ norm_weight : torch .Tensor ,
109+ eps : float ,
110+ hidden_size : int ,
111+ dtype : torch .dtype ,
112+ mapping : Mapping ,
113+ fusion : bool ,
114+ reference_output : tuple [torch .Tensor , ...],
115+ multicast_ptr : int ,
116+ buffer_ptrs_dev : int ,
117+ unicast_ptr : int ,
118+ max_num_elements_mnnvl : int ,
119+ buffer_flags_mnnvl : torch .Tensor ,
120+ ):
121+ tensor_parallel_size = mapping .tp_size
122+ tensor_parallel_rank = mapping .tp_rank
123+ MPI .COMM_WORLD .barrier ()
124+
125+ def func (
126+ input ,
127+ residual ,
128+ norm_weight ,
129+ eps ,
130+ enable_fusion ,
131+ multicast_ptr ,
132+ buffer_ptrs_dev ,
133+ unicast_ptr ,
134+ max_num_elements_mnnvl ,
135+ ):
136+ # For both fused and unfused cases:
137+ shape = input .shape
138+ input = input .view (- 1 , shape [- 1 ])
139+ buffer_M = max_num_elements_mnnvl // hidden_size
140+
141+ if enable_fusion :
142+ use_pdl = True
143+
144+ prenorm_output = torch .empty_like (residual )
145+ normed_output = torch .empty_like (residual )
146+
147+ trtllm_mnnvl_ar .mpi_barrier ()
148+
149+ trtllm_mnnvl_ar .trtllm_mnnvl_fused_allreduce_rmsnorm (
150+ prenorm_output ,
151+ normed_output ,
152+ input ,
153+ multicast_ptr ,
154+ buffer_ptrs_dev ,
155+ unicast_ptr ,
156+ buffer_M ,
157+ buffer_flags_mnnvl ,
158+ tensor_parallel_size ,
159+ tensor_parallel_rank ,
160+ norm_weight ,
161+ eps ,
162+ residual ,
163+ use_pdl ,
164+ )
165+
166+ return normed_output .view (shape ), prenorm_output .view (shape )
167+
168+ else :
169+ output = torch .empty_like (input )
170+
171+ trtllm_mnnvl_ar .trtllm_mnnvl_all_reduce (
172+ input ,
173+ multicast_ptr ,
174+ buffer_ptrs_dev ,
175+ buffer_M ,
176+ buffer_flags_mnnvl ,
177+ tensor_parallel_size ,
178+ tensor_parallel_rank ,
179+ True , # wait_for_results
180+ False , # launch_with_pdl
181+ output , # Need to provide output tensor since we are writing them out.
182+ )
183+ return (output .view (shape ),)
184+
185+ output = func (
186+ x .clone (),
187+ residual .clone (),
188+ norm_weight ,
189+ eps ,
190+ fusion ,
191+ multicast_ptr ,
192+ buffer_ptrs_dev ,
193+ unicast_ptr ,
194+ max_num_elements_mnnvl ,
195+ )
196+
197+ assert output [0 ].shape == reference_output [0 ].shape
198+
199+ if tensor_parallel_rank == 0 :
200+ print ("output[0] (first 10 values):" , output [0 ].flatten ()[:10 ])
201+ print (
202+ "reference_output[0] (first 10 values):" ,
203+ reference_output [0 ].flatten ()[:10 ],
204+ )
205+
206+ if fusion :
207+ print ("output[1] (first 10 values):" , output [1 ].flatten ()[:10 ])
208+ print (
209+ "reference_output[1] (first 10 values):" ,
210+ reference_output [1 ].flatten ()[:10 ],
211+ )
212+
213+ torch .testing .assert_close (
214+ output [0 ],
215+ reference_output [0 ],
216+ rtol = 0.05 ,
217+ atol = 0.15 ,
218+ )
219+
220+ if fusion :
221+ torch .testing .assert_close (
222+ output [1 ],
223+ reference_output [1 ],
224+ rtol = 0.05 ,
225+ atol = 0.15 ,
226+ )
227+
228+
104229"""Helper function to run the core MNNVL AllReduce test logic"""
105230
106231
@@ -146,7 +271,13 @@ def prepare_test_data(seq_len: int, hidden_size: int, dtype: torch.dtype, fusion
146271
147272
148273def run_mnnvl_ar_full (
149- monkeypatch , seq_lens : list [int ], fusion : bool , dtype : torch .dtype , hidden_size : int
274+ monkeypatch ,
275+ seq_lens : list [int ],
276+ fusion : bool ,
277+ dtype : torch .dtype ,
278+ hidden_size : int ,
279+ legacy_explicit_workspace_bytes : int = None ,
280+ legacy_api : bool = False ,
150281):
151282 """Core test logic for MNNVL AllReduce operations.
152283
@@ -195,16 +326,30 @@ def run_mnnvl_ar_full(
195326 failure_message = ""
196327
197328 try :
198- required_workspace_bytes = trtllm_mnnvl_ar .MNNVLAllreduceFusionWorkspace .get_required_buffer_size_bytes (
199- mapping .tp_size ,
200- max (seq_lens ),
201- hidden_size ,
202- dtype ,
203- trtllm_mnnvl_ar .MNNVLAllreduceFusionStrategy .AUTO ,
204- )
205- workspace = trtllm_mnnvl_ar .MNNVLAllreduceFusionWorkspace (
206- mapping , required_workspace_bytes
207- )
329+ if legacy_api :
330+ mcast_buffer_mnnvl , buffer_flags_mnnvl , max_num_elements_mnnvl = (
331+ trtllm_mnnvl_ar .get_allreduce_mnnvl_workspace (
332+ mapping , dtype , buffer_size_in_bytes = legacy_explicit_workspace_bytes
333+ )
334+ )
335+
336+ multicast_ptr = mcast_buffer_mnnvl .get_multicast_ptr ()
337+ buffer_ptrs_dev = mcast_buffer_mnnvl .get_buffer_ptrs_dev ()
338+ unicast_ptr = mcast_buffer_mnnvl .mcast_device_memory .get_unicast_ptr (
339+ mapping .tp_rank
340+ )
341+
342+ else :
343+ required_workspace_bytes = trtllm_mnnvl_ar .MNNVLAllreduceFusionWorkspace .get_required_buffer_size_bytes (
344+ mapping .tp_size ,
345+ max (seq_lens ),
346+ hidden_size ,
347+ dtype ,
348+ trtllm_mnnvl_ar .MNNVLAllreduceFusionStrategy .AUTO ,
349+ )
350+ workspace = trtllm_mnnvl_ar .MNNVLAllreduceFusionWorkspace (
351+ mapping , required_workspace_bytes
352+ )
208353
209354 test_data = []
210355 for seq_len in seq_lens :
@@ -221,18 +366,34 @@ def run_mnnvl_ar_full(
221366 print (
222367 f"Testing seq_len={ seq_len } , hidden_size={ hidden_size } , fusion={ fusion } , dtype={ dtype } "
223368 )
224-
225- # Run the test with the same workspace
226- row_linear_residual_norm_fusion_forward (
227- x ,
228- residual ,
229- norm_weight ,
230- eps ,
231- mapping ,
232- fusion ,
233- reference_output ,
234- workspace ,
235- )
369+ if legacy_api :
370+ row_linear_residual_norm_fusion_forward_legacy (
371+ x ,
372+ residual ,
373+ norm_weight ,
374+ eps ,
375+ hidden_size ,
376+ dtype ,
377+ mapping ,
378+ fusion ,
379+ reference_output ,
380+ multicast_ptr ,
381+ buffer_ptrs_dev ,
382+ unicast_ptr ,
383+ max_num_elements_mnnvl ,
384+ buffer_flags_mnnvl ,
385+ )
386+ else :
387+ row_linear_residual_norm_fusion_forward (
388+ x ,
389+ residual ,
390+ norm_weight ,
391+ eps ,
392+ mapping ,
393+ fusion ,
394+ reference_output ,
395+ workspace ,
396+ )
236397
237398 # Synchronize before next test
238399 trtllm_mnnvl_ar .mpi_barrier ()
@@ -283,8 +444,23 @@ def run_mnnvl_ar_full(
283444@pytest .mark .parametrize ("fusion" , [False , True ])
284445@pytest .mark .parametrize ("dtype" , [torch .float16 , torch .bfloat16 ])
285446@pytest .mark .parametrize ("hidden_size" , [2880 , 5120 , 7168 , 8192 ])
286- def test_mnnvl_allreduce_default_workspace (
447+ def test_mnnvl_allreduce_refactored (
448+ monkeypatch , seq_lens : list [int ], fusion : bool , dtype : torch .dtype , hidden_size : int
449+ ):
450+ """Test MNNVL AllReduce with refactored API."""
451+ run_mnnvl_ar_full (
452+ monkeypatch , seq_lens , fusion , dtype , hidden_size , legacy_api = False
453+ )
454+
455+
456+ @pytest .mark .parametrize ("seq_lens" , [[1 ], [4 ], [15 ], [27 , 11 , 24 ], [127 ]])
457+ @pytest .mark .parametrize ("fusion" , [False , True ])
458+ @pytest .mark .parametrize ("dtype" , [torch .float16 , torch .bfloat16 ])
459+ @pytest .mark .parametrize ("hidden_size" , [2048 , 4096 , 5120 , 7168 , 8192 ])
460+ def test_mnnvl_allreduce_legacy (
287461 monkeypatch , seq_lens : list [int ], fusion : bool , dtype : torch .dtype , hidden_size : int
288462):
289- """Test MNNVL AllReduce with default workspace size."""
290- run_mnnvl_ar_full (monkeypatch , seq_lens , fusion , dtype , hidden_size )
463+ """Test MNNVL AllReduce with legacy API."""
464+ run_mnnvl_ar_full (
465+ monkeypatch , seq_lens , fusion , dtype , hidden_size , legacy_api = True
466+ )
0 commit comments