@@ -719,6 +719,8 @@ def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
         self.quant_dtype = torch.float8_e4m3fn
         self.quant_fp8 = QuantFP8(static=True,
                                   group_shape=GroupShape.PER_TENSOR)
+        # TODO HACK
+        self.quant_fp8._forward_method = self.quant_fp8.forward_native
 
     def register(self, pm_pass: PatternMatcherPass):
 
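The override above sidesteps the usual CustomOp dispatch: QuantFP8's forward() routes through a platform-selected _forward_method (e.g. a CUDA path), which traces as one opaque op that the Inductor pattern matcher cannot decompose. Pinning _forward_method to forward_native makes tracing see plain torch ops instead. A minimal sketch of that dispatch, with the class name and fp8 math invented for illustration rather than copied from vLLM:

import torch

class CustomOpLike(torch.nn.Module):
    # Illustrative stand-in for a CustomOp-style class; names are assumptions.
    def __init__(self):
        super().__init__()
        # Normally picked per platform at construction time.
        self._forward_method = self.forward_cuda

    def forward(self, x: torch.Tensor, scale: torch.Tensor):
        return self._forward_method(x, scale)

    def forward_native(self, x: torch.Tensor, scale: torch.Tensor):
        # Pure-torch path: saturate to the fp8 range, then cast.
        finfo = torch.finfo(torch.float8_e4m3fn)
        out = (x.float() / scale).clamp(finfo.min, finfo.max)
        return out.to(torch.float8_e4m3fn), scale

    def forward_cuda(self, x: torch.Tensor, scale: torch.Tensor):
        # Stands in for a fused kernel that traces as a single opaque op.
        return self.forward_native(x, scale)

op = CustomOpLike()
op._forward_method = op.forward_native  # the same override as in the diff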
@@ -729,9 +731,9 @@ def get_inputs():
             rmsnorm_result = torch.empty([1, 8, 4],
                                          device=self.device,
                                          dtype=self.dtype)
-            quant_result = torch.empty([1, 8, 4],
-                                       device=self.device,
-                                       dtype=self.quant_dtype)
+            # quant_result = torch.empty([1, 8, 4],
+            #                            device=self.device,
+            #                            dtype=self.quant_dtype)
             weight = torch.empty([4], device=self.device, dtype=self.dtype)
             scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
             return [
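Commenting out quant_result (rather than deleting it) tracks a signature change: register_replacement traces the pattern with exactly the tensors get_inputs() returns, and the native quant path returns a freshly allocated fp8 tensor instead of writing into a caller-provided buffer, so the preallocated output drops out of the example inputs. A short sketch of that allocation behavior, stated as an assumption about the native path rather than a copy of it:

import torch

def static_per_tensor_fp8_quant(x: torch.Tensor, scale: torch.Tensor):
    # Assumed shape of the native path: no out= buffer, result is returned.
    finfo = torch.finfo(torch.float8_e4m3fn)
    q = (x.float() / scale).clamp(finfo.min, finfo.max)
    return q.to(torch.float8_e4m3fn), scale

x = torch.randn(1, 8, 4)
scale = torch.tensor(1.0)
quant_out, _ = static_per_tensor_fp8_quant(x, scale)  # allocated inside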
@@ -807,6 +809,8 @@ def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
         self.quant_dtype = torch.float8_e4m3fn
         self.quant_fp8 = QuantFP8(static=True,
                                   group_shape=GroupShape.PER_TENSOR)
+        # TODO HACK
+        self.quant_fp8._forward_method = self.quant_fp8.forward_native
 
     def register(self, pm_pass: PatternMatcherPass):
 
@@ -817,9 +821,9 @@ def get_inputs():
                                 device=self.device,
                                 dtype=self.dtype)
             weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
-            quant_result = torch.empty([4, 4],
-                                       device=self.device,
-                                       dtype=self.quant_dtype)
+            # quant_result = torch.empty([4, 4],
+            #                            device=self.device,
+            #                            dtype=self.quant_dtype)
             scale = torch.empty([1, 1],
                                 device=self.device,
                                 dtype=torch.float32)
@@ -1166,6 +1170,9 @@ def __init__(self, config: VllmConfig):
         # and allow multiple values of epsilon.
         torch._inductor.pattern_matcher._seen_patterns.clear()
 
+        if path := config.compilation_config.debug_dump_path:
+            with open(f"{path}/patterns.txt", 'w') as f:
+                print(self.patterns.patterns, file=f)
         self.disabled = False
 
     def __call__(self, graph: fx.Graph):
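The new dump piggybacks on the existing debug_dump_path knob on the compilation config: when it is set, the pass writes its registered pattern list to patterns.txt at construction time. A hedged usage sketch; the import path and pass class name are taken from context and may differ, and the dump directory must already exist since open() will not create it:

import os
from vllm.config import CompilationConfig, VllmConfig  # import path assumed

os.makedirs("/tmp/vllm_dump", exist_ok=True)  # open() needs the directory
config = VllmConfig(
    compilation_config=CompilationConfig(debug_dump_path="/tmp/vllm_dump"))
fusion_pass = FusionPass(config)  # the pass whose __init__ is patched above
# /tmp/vllm_dump/patterns.txt now lists every registered pattern.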