 except ImportError:
     float4_sf_dtype = None
 
+# TODO: put the ENUMs in the same place and import them
+FORMAT_FP8 = 0
+FORMAT_NVFP4 = 1
+
 
 def modelopt_fp4_scale_to_cutlass_fp4_scale(modelopt_scale: torch.Tensor) -> torch.Tensor:
     """Converts the modelopt FP4 per-block weight scale to the cutlass format (padded and swizzled)."""
@@ -160,6 +164,18 @@ def shard_load_hook(
     def fuse_linear_weights(weights, **kwargs) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
         pass
 
+    @staticmethod
+    def custom_op():
+        """Unified custom kernel entry point for quantized linear."""
+        return torch.ops.auto_deploy.custom_quant_linear
+
+    @staticmethod
+    def build_custom_kwargs_for_linear(
+        scale_getattrs: Dict[str, Node],
+    ) -> Dict[str, object]:
+        """Default: no extra kwargs. Each impl overrides to pass the right inputs/scales/zps/format."""
+        return {}
+
 
 class FP8QuantizationImpl(QuantizationImpl):
     @staticmethod
@@ -180,6 +196,20 @@ def scale_names() -> List[str]:
     def default_scales(original_weight_shape: Tuple) -> Dict[str, torch.Tensor]:
         return {"input_scale": torch.tensor(1.0), "weight_scale": torch.tensor(1.0)}
 
+    @staticmethod
+    def build_custom_kwargs_for_linear(
+        scale_getattrs: Dict[str, Node],
+    ) -> Dict[str, object]:
+        # FP8 custom op contract:
+        #   input_scale=[tensor], weight_scale=[tensor], input_zp=[], weight_zp=[], format_type=FORMAT_FP8
+        return dict(
+            input_scale=[scale_getattrs["input_scale"]],
+            weight_scale=[scale_getattrs["weight_scale"]],
+            input_zp=[],
+            weight_zp=[],
+            format_type=FORMAT_FP8,
+        )
+
     @staticmethod
     def load_hook(state_dict, prefix, *args, weight_name):
         if weight_name in state_dict:
@@ -264,6 +294,29 @@ def default_scales(original_weight_shape: Tuple) -> Dict[str, torch.Tensor]:
264294 "alpha" : torch .tensor (1.0 / 6.0 ),
265295 }
266296
297+ @staticmethod
298+ def build_custom_kwargs_for_linear (
299+ scale_getattrs : Dict [str , Node ],
300+ ) -> Dict [str , object ]:
301+ """
302+ Contract:
303+ custom_quant_linear(
304+ x, Wq, bias,
305+ input_scale=[s_in2],
306+ weight_scale=[weight_scale_cutlass_uint8, alpha_fused],
307+ input_zp=[],
308+ weight_zp=[],
309+ format_type=FORMAT_NVFP4
310+ )
311+ """
312+ return dict (
313+ input_scale = [scale_getattrs ["input_scale" ]],
314+ weight_scale = [scale_getattrs ["weight_scale" ], scale_getattrs ["alpha" ]],
315+ input_zp = [],
316+ weight_zp = [],
317+ format_type = FORMAT_NVFP4 ,
318+ )
319+
267320 @staticmethod
268321 def load_hook (state_dict , prefix , * args , weight_name ):
269322 if weight_name in state_dict :
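
Below the diff, a minimal sketch of how the kwargs produced by the two build_custom_kwargs_for_linear overrides line up with the unified custom-op contract above. The plain tensors and the commented-out call are placeholder assumptions, not part of the patch: in the real graph transform the scale values are get_attr Nodes, and torch.ops.auto_deploy.custom_quant_linear only exists once the repo registers it.

import torch

# Enum stand-ins mirroring the values added in this diff.
FORMAT_FP8 = 0
FORMAT_NVFP4 = 1

# Placeholder scales; the transform would pass get_attr Nodes for the
# registered scale buffers rather than plain tensors.
scale_getattrs = {
    "input_scale": torch.tensor(1.0),
    "weight_scale": torch.tensor(1.0),
    "alpha": torch.tensor(1.0 / 6.0),
}

# FP8 passes one input scale and one weight scale.
fp8_kwargs = dict(
    input_scale=[scale_getattrs["input_scale"]],
    weight_scale=[scale_getattrs["weight_scale"]],
    input_zp=[],
    weight_zp=[],
    format_type=FORMAT_FP8,
)

# NVFP4 additionally passes the fused alpha as a second weight-scale entry.
nvfp4_kwargs = dict(
    input_scale=[scale_getattrs["input_scale"]],
    weight_scale=[scale_getattrs["weight_scale"], scale_getattrs["alpha"]],
    input_zp=[],
    weight_zp=[],
    format_type=FORMAT_NVFP4,
)

# The transform would then rewrite each quantized linear into a call such as
#   torch.ops.auto_deploy.custom_quant_linear(x, Wq, bias, **fp8_kwargs)
# (left as a comment here, since the op is only registered inside the repo).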