@@ -302,18 +302,8 @@ def load_weights(weights, model_runner):
         param_scale = torch.squeeze(param_scale, dim=-1)
         weights_quantized.append([k, param_lp])
         weights_quantized.append([k + "_scale_inv", param_scale])
-    # Monkey patch the param class to their subclass, as certain models
-    # will check the param type to call the proper weightloader
-    for name, param in model.named_parameters():
-        if hasattr(param, "subclass_type"):
-            param.orig_type = param.__class__
-            param.__class__ = param.subclass_type
     # Finally load the weights into vllm
     model.load_weights(weights_quantized)
-    # Undo the type change above to the original type
-    for name, param in model.named_parameters():
-        if hasattr(param, "subclass_type"):
-            param.__class__ = param.orig_type


 def cast_tensor_to_fp8_blockwise(
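For context on the block removed above: several vLLM weight loaders dispatch on the parameter's class, so the old flow temporarily rebound param.__class__ to the recorded subclass so that isinstance checks would pass, then restored the original type after model.load_weights. A minimal standalone sketch of that pattern; the SubclassParam class and toy_weight_loader function are made up for illustration and do not come from the repo:

import torch

class SubclassParam(torch.nn.Parameter):
    # Stand-in for a vLLM parameter subclass that a weight loader dispatches on.
    pass

def toy_weight_loader(param, loaded_tensor):
    # Mimics a loader that checks the parameter type before copying data in.
    if isinstance(param, SubclassParam):
        param.data.copy_(loaded_tensor)
    else:
        raise TypeError("unexpected parameter type")

p = torch.nn.Parameter(torch.zeros(2, 2), requires_grad=False)
p.subclass_type = SubclassParam  # attached ahead of time, as the old flow expected

orig_type = p.__class__
p.__class__ = p.subclass_type    # temporary swap so the isinstance check passes
toy_weight_loader(p, torch.ones(2, 2))
p.__class__ = orig_type          # restore, mirroring the removed cleanup loop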
@@ -324,12 +314,25 @@ def cast_tensor_to_fp8_blockwise(

     block_size1 = weight_block_size[1]
     block_size0 = weight_block_size[0]
-    assert data_hp.shape[1] % block_size1 == 0, (
-        f"data_hp.shape[1] {data_hp.shape[1]} must be a multiple of block_size1: {block_size1}."
-    )
-    assert data_hp.shape[0] % block_size0 == 0, (
-        f"data_hp.shape[0] {data_hp.shape[0]} must be a multiple of block_size0: {block_size0}."
-    )
+    shape_before_padding = data_hp.shape
+    # pad data_hp to make its shape a multiple of weight_block_size with the last element of data_hp
+    if data_hp.shape[1] % block_size1 != 0 or data_hp.shape[0] % block_size0 != 0:
+        pad1 = (
+            0
+            if data_hp.shape[1] % block_size1 == 0
+            else block_size1 - data_hp.shape[1] % block_size1
+        )
+        pad0 = (
+            0
+            if data_hp.shape[0] % block_size0 == 0
+            else block_size0 - data_hp.shape[0] % block_size0
+        )
+        print(
+            f"Padding data_hp from {data_hp.shape} to {(data_hp.shape[0] + pad0, data_hp.shape[1] + pad1)}"
+        )
+        data_hp = torch.nn.functional.pad(
+            data_hp, (0, pad1, 0, pad0), mode="constant", value=data_hp[-1, -1]
+        )

     # FP8
     max_dtype = torch.finfo(torch.float8_e4m3fn).max
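The padding added above is simple modular arithmetic; here is a minimal, self-contained check of it (the tensor shape and block sizes are invented for illustration):

import torch
import torch.nn.functional as F

block_size0, block_size1 = 128, 128
data_hp = torch.randn(300, 200)  # neither dimension is a multiple of 128

# Same amounts the patch computes with its conditional expressions
pad0 = (block_size0 - data_hp.shape[0] % block_size0) % block_size0  # 84
pad1 = (block_size1 - data_hp.shape[1] % block_size1) % block_size1  # 56

# Pad the bottom/right edges with the value of the last element, as in the patch
padded = F.pad(data_hp, (0, pad1, 0, pad0), mode="constant", value=float(data_hp[-1, -1]))
assert padded.shape == (384, 256)  # both dimensions now divide evenly into 128x128 blocks

# After quantization, the later hunk crops back to the original shape
cropped = padded[: data_hp.shape[0], : data_hp.shape[1]]
assert cropped.shape == data_hp.shape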
@@ -385,57 +388,35 @@ def cast_tensor_to_fp8_blockwise(
         .reshape(original_shape)
     )

+    # remove the padding
+    if data_hp.shape != shape_before_padding:
+        fp_data = fp_data[: shape_before_padding[0], : shape_before_padding[1]]
+
     # Convert to target format, but still in original precision container
     return fp_data, descale_fp


 def process_weights_after_loading(self, layer) -> None:
-    from torch.nn import Parameter
-    from vllm.model_executor.parameter import (
-        BlockQuantScaleParameter,
-        ModelWeightParameter,
+    from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+        maybe_post_process_fp8_weight_block,
+        process_fp8_weight_block_strategy,
     )

     assert self.block_quant and self.quant_config.is_checkpoint_fp8_serialized
     assert self.quant_config.activation_scheme == "dynamic"

-    def _create_param_from_subclass_attributes(custom_param):
-        param = Parameter(custom_param.data, requires_grad=False)
-        base_param_dir = dir(torch.nn.Parameter)
-        custom_param_dir = dir(custom_param)
-        # Find the attributes that are unique to the custom parameter
-        custom_attributes = [
-            attr
-            for attr in custom_param_dir
-            if attr not in base_param_dir and not attr.startswith("__")
-        ]
-        # Set the custom attributes into the base parameter object
-        for attr in custom_attributes:
-            setattr(param, attr, getattr(custom_param, attr))
-
-        param.subclass_type = type(custom_param)
-        return param
-
-    weight = layer.weight.data
-    weight_scale_inv = layer.weight_scale_inv.data
-    weight = self._maybe_pad_weight(weight)
-
-    layer.weight = _create_param_from_subclass_attributes(
-        ModelWeightParameter(
-            data=weight,
-            output_dim=0,
-            input_dim=1,
-            weight_loader=layer.weight.weight_loader,
-        )
-    )
-    layer.weight_scale_inv = _create_param_from_subclass_attributes(
-        BlockQuantScaleParameter(
-            data=weight_scale_inv,
-            output_dim=0,
-            input_dim=1,
-            weight_loader=layer.weight_scale_inv.weight_loader,
-        )
-    )
+    weight_scale = layer.weight_scale_inv
+    weight, weight_scale = process_fp8_weight_block_strategy(layer.weight, weight_scale)
+    layer.weight.data = weight.data
+    if hasattr(layer, "weight_scale"):
+        # Not the first time to call this function, just need to update the data
+        layer.weight_scale.data = weight_scale.data
+    else:
+        # The first time to call this function, create a new parameter and update the tp status
+        layer.weight_scale = torch.nn.Parameter(weight_scale.data, requires_grad=False)
+        layer.update_param_tp_status()
+
+    maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)


 @triton.jit
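Not part of the patch, but a quick way to sanity-check the cast/pad/crop round trip is to dequantize blockwise and compare against the original tensor. The sketch below assumes cast_tensor_to_fp8_blockwise takes the tensor plus weight_block_size and returns (fp_data, descale_fp), with descale_fp holding one multiplicative factor per (block_size0, block_size1) tile in row-major layout; that reading matches the hunks above, but the exact signature and scale layout are assumptions, not confirmed by this diff.

import torch

def blockwise_dequant(fp_data, descale_fp, weight_block_size):
    # Expand the per-block scale grid to per-element scales, then rescale the fp8 values.
    b0, b1 = weight_block_size
    scales = descale_fp.repeat_interleave(b0, dim=0).repeat_interleave(b1, dim=1)
    # Drop scale rows/columns that only covered padding (when the shape was not block-aligned)
    scales = scales[: fp_data.shape[0], : fp_data.shape[1]]
    return fp_data.float() * scales.float()

# Hypothetical usage, under the assumptions above:
# fp_data, descale_fp = cast_tensor_to_fp8_blockwise(weight_bf16, weight_block_size=[128, 128])
# recovered = blockwise_dequant(fp_data, descale_fp, [128, 128])
# print((recovered - weight_bf16.float()).abs().max())  # should be small, on the order of fp8 rounding error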