import torch

import vllm.envs as envs
-from vllm import LLM, SamplingParams
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
-from vllm.compilation.fusion import FUSED_OPS, RMSNormQuantFusionPass
+from vllm.compilation.fusion import RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym)
+    GroupShape)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    Fp8LinearOp)
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.platforms import current_platform

from .backend import TestBackend

+FP8_DTYPE = current_platform.fp8_dtype()
OPS_IN_MODEL = [
    torch.ops._C.rotary_embedding.default,
    torch.ops._C.fused_add_rms_norm.default,
...
RMS_QUANT_OPS = {
    "static_fp8": [
        torch.ops._C.rms_norm_static_fp8_quant.default,
-        torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
+        torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,
    ],
}

...
]


-@pytest.mark.parametrize(
-    "model, quant_key",
-    [("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e", kFp8StaticTensorSym),
-     ("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8_DYNAMIC-e2e",
-      kFp8DynamicTokenSym)])
-@pytest.mark.parametrize("do_fusion", [True, False])
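+# SiluAndMul optionally followed by a static-FP8 linear. With fusion enabled,
+# ActivationQuantFusionPass is expected to rewrite the activation + quant
+# pattern into torch.ops._C.silu_and_mul_quant.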
+class TestSiluMul(torch.nn.Module):
+
+    def __init__(self, quant=True, hidden_size: int = 128):
+        super().__init__()
+        self.quant = quant
+        self.silu_and_mul = SiluAndMul()
+        self.wscale = torch.rand(1, dtype=torch.float32)
+        self.scale = torch.rand(1, dtype=torch.float32)
+
+        if self.quant:
+            self.w = torch.rand(hidden_size,
+                                hidden_size).to(dtype=FP8_DTYPE).t()
+            self.fp8_linear = Fp8LinearOp(
+                act_quant_static=True,
+                act_quant_group_shape=GroupShape.PER_TENSOR,
+            )
+
+    def forward(self, x):
+        y = self.silu_and_mul(x)
+        if self.quant:
+            x2 = self.fp8_linear.apply(y,
+                                       self.w,
+                                       self.wscale,
+                                       input_scale=self.wscale)
+            return x2
+        else:
+            return y
+
+    def example_inputs(self, num_tokens=32, hidden_size=128):
+        dtype = torch.float16 if self.quant else torch.float32
+        return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype), )
+
+    def ops_in_model(self):
+        return ([torch.ops._C.silu_and_mul_quant.default]
+                if self.quant else [torch.ops._C.silu_and_mul.default])
+
+    def ops_not_in_model(self):
+        return []
+
+
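+# Projection followed by RMSNorm with a residual add (fused_add_rms_norm),
+# optionally followed by a static-FP8 linear. With fusion enabled,
+# RMSNormQuantFusionPass is expected to produce
+# torch.ops._C.fused_add_rms_norm_static_fp8_quant.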
+class TestFusedAddRMSNorm(torch.nn.Module):
+
+    def __init__(self, quant=True, hidden_size=16, intermediate_size=32):
+        super().__init__()
+        self.quant = quant
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+
+        dtype = torch.float16 if self.quant else torch.float32
+
+        self.gate_proj = torch.nn.Parameter(
+            torch.empty((intermediate_size, hidden_size), dtype=dtype))
+        self.norm = RMSNorm(intermediate_size, 1e-05)
+        self.norm.weight = torch.nn.Parameter(
+            torch.ones(intermediate_size, dtype=dtype))
+
+        torch.nn.init.normal_(self.gate_proj, std=0.02)
+
+        if self.quant:
+            self.fp8_linear = Fp8LinearOp(act_quant_static=True)
+
+            self.scale = torch.rand(1, dtype=torch.float32)
+            self.w = torch.rand(hidden_size,
+                                intermediate_size).to(dtype=FP8_DTYPE).t()
+            self.wscale = torch.rand(1, dtype=torch.float32)
+
+    def forward(self, hidden_states, residual):
+        # Reshape input
+        view = hidden_states.reshape(-1, self.hidden_size)
+
+        # matrix multiplication
+        permute = self.gate_proj.permute(1, 0)
+        mm = torch.mm(view, permute)
+
+        # layer normalization
+        norm_output, residual_output = self.norm(mm, residual)
+
+        if self.quant:
+            # scaled_mm with static input quantization
+            fp8_linear_result = self.fp8_linear.apply(
+                norm_output,
+                self.w,
+                self.wscale,
+                input_scale=self.scale.to(norm_output.device),
+            )
+
+            return fp8_linear_result, residual_output
+
+        else:
+            return norm_output, residual_output
+
+    def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
+        dtype = torch.float16 if self.quant else torch.float32
+        hidden_states = torch.randn((batch_size * seq_len, hidden_size),
+                                    dtype=dtype)
+        # residual must match the post-projection width (intermediate_size)
+        residual = torch.randn((batch_size * seq_len, self.intermediate_size),
+                               dtype=dtype)
+        return (hidden_states, residual)
+
+    def ops_in_model(self):
+        return ([torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
+                if self.quant else [torch.ops._C.fused_add_rms_norm.default])
+
+    def ops_not_in_model(self):
+        return []
+
+
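+# Plain rotary embedding, exercising the torch.ops._C.rotary_embedding
+# custom op without any quantization.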
+class TestRotaryEmbedding(torch.nn.Module):
+
+    def __init__(
+            self,
+            quant=False,  # not used
+            head_dim=64,
+            rotary_dim=None,
+            max_position=2048,
+            base=10000):
+        super().__init__()
+        self.head_dim = head_dim
+        self.rotary_dim = rotary_dim or head_dim
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.rotary_dim,
+            max_position=max_position,
+            base=base,
+        )
+
+    def forward(self, positions, q, k):
+        q_rotated, k_rotated = self.rotary_emb(positions, q, k)
+        return q_rotated, k_rotated
+
+    def example_inputs(self, num_tokens=32, head_dim=64):
+        dtype = torch.float16
+        positions = torch.arange(num_tokens, dtype=torch.long)
+        q = torch.randn(num_tokens, head_dim, dtype=dtype)
+        k = torch.randn(num_tokens, head_dim, dtype=dtype)
+        return (positions, q, k)
+
+    def ops_in_model(self):
+        return [torch.ops._C.rotary_embedding.default]
+
+    def ops_not_in_model(self):
+        return []
+
+
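+# Rotary embedding applied to views of a fused QKV projection. The in-place
+# rotary update is expected to functionalize into aten.slice_scatter, which
+# should no longer appear in the final graph (see ops_not_in_model).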
+class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
+
+    def __init__(
+            self,
+            quant=False,  # not used
+            head_dim=64,
+            num_heads=4,
+            max_position=2048,
+            base=10000):
+        super().__init__()
+        self.head_dim = head_dim
+        self.num_heads = num_heads
+        self.hidden_size = head_dim * num_heads
+
+        self.qkv_proj = torch.nn.Linear(self.hidden_size,
+                                        self.hidden_size * 3,
+                                        bias=False,
+                                        dtype=torch.float16)
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position,
+            base=base,
+        )
+
+    def forward(self, positions, hidden_states):
+        # Simulate the pattern: mm -> split_with_sizes -> rotary_embedding
+        # -> slice_scatter -> split_with_sizes
+
+        qkv = self.qkv_proj(hidden_states)
+        split_sizes = [self.hidden_size, self.hidden_size, self.hidden_size]
+        q, k, v = torch.split(qkv, split_sizes, dim=-1)
+
+        q_rotated, k_rotated = self.rotary_emb(positions, q, k)
+
+        qkv_updated = torch.cat([q_rotated, k_rotated, v], dim=-1)
+        return qkv_updated
+
+    def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4):
+        dtype = torch.float16
+        hidden_size = head_dim * num_heads
+        positions = torch.arange(num_tokens, dtype=torch.long)
+        hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
+        return (positions, hidden_states)
+
+    def ops_in_model(self):
+        return [torch.ops._C.rotary_embedding.default]
+
+    def ops_not_in_model(self):
+        return [torch.ops.aten.slice_scatter.default]
+
+
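+# Each module above provides example_inputs(), ops_in_model() and
+# ops_not_in_model(), so test_fix_functionalization can treat them uniformly.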
+MODELS = [
+    TestSiluMul,
+    TestFusedAddRMSNorm,
+    TestRotaryEmbedding,
+    TestRotaryEmbeddingSliceScatter,
+]
+
+
+@pytest.mark.parametrize("model_class", MODELS)
+@pytest.mark.parametrize("quant", [True, False])
+@pytest.mark.parametrize("do_fusion", [True])  # , False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                    reason="Only test on CUDA")
-def test_fix_functionalization(model: str, quant_key: QuantKey,
+def test_fix_functionalization(model_class: torch.nn.Module, quant: bool,
                               do_fusion: bool):
    torch.set_default_device("cuda")

@@ -63,56 +272,31 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
    cleanup_pass = PostCleanupPass(vllm_config)
    act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)

-    passes = [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass
-              ] if do_fusion else [noop_pass, cleanup_pass]
+    passes = ([noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
+              if do_fusion else [noop_pass, cleanup_pass])
    func_pass = FixFunctionalizationPass(vllm_config)
+
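+    # Compile the module twice: once with FixFunctionalizationPass appended
+    # and once without, then compare the post-pass graphs of the two backends.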
    backend_func = TestBackend(*passes, func_pass)
    backend_no_func = TestBackend(*passes)

-    # instantiate a full engine and manually compile the model 2x
-    # (with and without FixFunctionalizationPass)
-    llm = LLM(model=model, enforce_eager=True)
-    model_runner = llm.llm_engine.model_executor.driver_worker.model_runner
-    orig_model = model_runner.model
-    # TODO mark inputs dynamic? (currently torch.compile is triggered 4x)
-    # Can only do that by using the decorator but then we'd have to instantiate
-    # 2 LLM instances.
-
-    sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
-    model_runner.model = torch.compile(orig_model,
-                                       fullgraph=True,
-                                       backend=backend_func)
-    gen_func = llm.generate(prompts, sampling_params)
-
-    model_runner.model = torch.compile(orig_model,
-                                       fullgraph=True,
-                                       backend=backend_no_func)
-
-    gen_no_func = llm.generate(prompts, sampling_params)
-
-    for output_func, output_no_func in zip(gen_func, gen_no_func):
-        assert output_func.outputs[0].text == output_no_func.outputs[0].text
-
-    # OPS_IN_MODEL always appear. RMS_OP is fused away if we run fusion,
-    # and replaced by fused quantized ops in RMS_QUANT_OPS.
-    rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)]
-               ] if do_fusion else [RMS_OP]
-    silu_mul_ops = [SILU_MUL_QUANT_OP] if do_fusion and \
-        quant_key == kFp8StaticTensorSym else [
-            SILU_MUL_OP
-        ]
-
-    ops = OPS_IN_MODEL + rms_ops + silu_mul_ops
-
-    for op in ops:
+    model = model_class(quant=quant)
+    torch.compile(model, backend=backend_func)(*model.example_inputs())
+    torch.compile(model, backend=backend_no_func)(*model.example_inputs())
+
+    # check if the functionalization pass is applied
+    for op in model.ops_in_model():
        find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
-        assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes,
-                                  op) is None  # noqa: E501
+        assert (find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op)
+                is None)  # noqa: E501

    # make sure the ops were all de-functionalized
    found = dict()
    for node in backend_func.graph_post_pass.nodes:
-        for op in ops:
+        for op in model.ops_in_model():
+            if is_func(node, op):
+                found[op] = True
+        for op in model.ops_not_in_model():
            if is_func(node, op):
                found[op] = True
-    assert all(found[op] for op in ops)
+    assert all(found[op] for op in model.ops_in_model())
+    assert all(not found.get(op) for op in model.ops_not_in_model())