[Feature][ROCM] add online int4_fp8_moe quant feature #6238
DehuaTang wants to merge 1 commit into sgl-project:main
Conversation
```python
    "compressed-tensors",
    "fbgemm_fp8",
    "w8a8_fp8",
    "quark_int4fp8_moe",
```
How is this PR related to quark? At the moment, there is no utility from quark used in this PR.
The quantized weights produced in online_quant should really be generated with Quark's RealQuantizer. Due to time constraints, a simple standalone function was used instead.
Got it. Should we have a code structure like

```
quark/
  schemes/
    quark_scheme.py
    quark_int4fp8.py
  quark.py
```

similar to vllm? It might be easier to extend with mxfp4 then.
No. sglang does not support the quark format yet.
Okay, it can be done in another PR.
```python
)


class QuarkInt4Fp8MoEMethod:
```
#4152 applied changes to the Fp8MoEMethod class so it can load int4 weights, gated by the environment variable SGLANG_INT4_WEIGHT.

I am confused about the need to introduce this new class QuarkInt4Fp8MoEMethod. Once weights are quantized on the fly to int4-fp8, the execution path should be shared with what was done in #4152.

Alternatively, Fp8MoEMethod should be cleaned of its int4-fp8 related code, and int4-fp8 should always be handled by QuarkInt4Fp8MoEMethod, no matter whether the weight quantization is online or offline.
It seems a lot of code is duplicated here
Yes, the current code does duplicate some of the PR you mentioned. But we don't want to reuse the previous path because it relies on environment variables for control; avoiding that is our improvement.
```python
    return []


class QuarkInt4Fp8LinearMethod(LinearMethodBase):
```
In this PR, only w1/w2/w3 use int4-fp8; the linear layers in attention are still fp8, which is why this function is touched.
```python
logger = logging.getLogger(__name__)


def apply_quark_quant_config_to_model(model_config, quark_config):
```
```python
def quantize_fp8_scale_tensorwise(w):
    FP8_MAX = 448.0
    scale = w.abs().amax().float() / FP8_MAX
    scaled = (w / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return scaled, scale
```
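For intuition, here is a minimal pure-Python sketch of the tensorwise scaling above (the fp8 cast itself is omitted, and `scale_tensorwise` is a hypothetical name, not from the PR): dividing by `scale = amax / FP8_MAX` maps the largest-magnitude weight exactly onto the edge of the fp8 range.

```python
# Hedged sketch of tensorwise scaling; no float8 dtype involved.
FP8_MAX = 448.0  # largest magnitude representable in float8_e4m3fn

def scale_tensorwise(values):
    # One scale for the whole tensor, derived from its absolute maximum.
    amax = max(abs(v) for v in values)
    scale = amax / FP8_MAX
    scaled = [max(-FP8_MAX, min(FP8_MAX, v / scale)) for v in values]
    return scaled, scale

scaled, scale = scale_tensorwise([-1.75, 0.875])
# the amax element (-1.75) lands exactly on -FP8_MAX
```

The clamp is a no-op for the amax element but guards the rest against rounding past the fp8 range.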
```python
def quantize_int4_scale_columnwise(w):
    S4_MAX = 7
    w_flat = w.reshape(-1, w.shape[-1]).float()
    scale = w_flat.abs().amax(axis=-1) / S4_MAX
    scaled = torch.round(w_flat / scale[:, None]).to(torch.int8).clamp(-S4_MAX, S4_MAX)
    return scaled.reshape(w.shape), scale.reshape(w.shape[:-1])
```
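The same idea in a one-row pure-Python sketch (`quantize_row_int4` is a hypothetical helper mirroring the per-row logic above): each row gets its own scale, the amax element maps to ±7, and dequantization is simply `q[i] * scale`.

```python
# Hedged sketch of symmetric per-row int4 quantization.
S4_MAX = 7  # signed int4 range used here is [-7, 7]

def quantize_row_int4(row):
    # Row-wise scale so the largest-magnitude entry maps to +/-S4_MAX.
    scale = max(abs(v) for v in row) / S4_MAX
    q = [max(-S4_MAX, min(S4_MAX, round(v / scale))) for v in row]
    return q, scale

q, scale = quantize_row_int4([0.7, -0.3, 0.1])
# q is [7, -3, 1]; dequantized values are q[i] * scale
```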
```python
def pack(to_pack: torch.Tensor, reorder: bool = True) -> torch.Tensor:
    if to_pack.ndim > 2:
        raise ValueError("Pack: only supports tensors with at most 2 dimensions.")

    if reorder:
        order_map = [0, 2, 4, 6, 1, 3, 5, 7]
    else:
        order_map = [0, 1, 2, 3, 4, 5, 6, 7]
    pack_num = 8
    if to_pack.ndim == 2:
        new_c = to_pack.shape[1] // pack_num
        packed = torch.zeros(to_pack.shape[0], new_c, dtype=torch.int32, device=to_pack.device)
        for c in range(new_c):
            for i in range(pack_num):
                # Mask to the low nibble before shifting: e.g. -3 is sign-extended
                # to ...11111101, and without & 0x0F the high bits would corrupt
                # the bitwise_or, so we cannot use the int4 values directly.
                packed_col = to_pack[:, c * pack_num + order_map[i]].to(torch.int32)
                packed_col = packed_col & 0x0F
                packed[:, c] = torch.bitwise_or(packed[:, c], torch.bitwise_left_shift(packed_col, i * 4))
    elif to_pack.ndim == 0:
        packed = to_pack.to(torch.int32)
    else:
        new_c = to_pack.shape[0] // pack_num
        packed = torch.zeros(new_c, dtype=torch.int32, device=to_pack.device)
        for c in range(new_c):
            for i in range(pack_num):
                # Cast to int32 first: shifting an int8 by up to 28 bits overflows.
                packed_col = to_pack[c * pack_num + order_map[i]].to(torch.int32)
                packed_col = packed_col & 0x0F
                packed[c] = torch.bitwise_or(packed[c], torch.bitwise_left_shift(packed_col, i * 4))

    return packed
```
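A quick round-trip check of the nibble packing for one group of 8 values (`pack8`/`unpack8` are hypothetical pure-Python analogues of `pack()`, not part of the PR) makes the `& 0x0F` masking and the reorder map concrete:

```python
# Hedged sketch: pack 8 signed int4 values into one 32-bit word and back.
ORDER_MAP = [0, 2, 4, 6, 1, 3, 5, 7]

def pack8(vals):
    # Mask each signed value to its low nibble (two's complement) before OR-ing.
    word = 0
    for i in range(8):
        word |= (vals[ORDER_MAP[i]] & 0x0F) << (i * 4)
    return word

def unpack8(word):
    vals = [0] * 8
    for i in range(8):
        nib = (word >> (i * 4)) & 0x0F
        # Sign-extend the nibble back to a signed int4 value.
        vals[ORDER_MAP[i]] = nib - 16 if nib >= 8 else nib
    return vals

vals = [-3, 1, 0, 7, -7, 2, -1, 5]
assert unpack8(pack8(vals)) == vals
```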
```python
def quark_quant_weights(weights_dict):
    for name, loaded_weight in tqdm(weights_dict, desc="Quark Online Quantizing"):
        if "w1.weight" in name or "w2.weight" in name or "w3.weight" in name:
            fp8_w, fp8_scale = quantize_fp8_scale_tensorwise(loaded_weight)
            int4_w, int4_scale = quantize_int4_scale_columnwise(loaded_weight)

            int4_w = pack(int4_w)
            int4_scale /= fp8_scale

            yield name, int4_w
            yield name + "_scale", fp8_scale
            yield name + "_scale1", int4_scale
        elif "proj.weight" in name:
            fp8_w, fp8_scale = quantize_fp8_scale_tensorwise(loaded_weight)

            yield name, fp8_w
            yield name + "_scale", fp8_scale
        else:
            yield name, loaded_weight
```
I'd move these functions out.
Yes. If you have time, it would be best to implement this with Quark's own functions.
I meant move them out of apply_quark_quant_config_to_model. But yeah, eventually we might use quark utilities here.
```diff
     @staticmethod
-    def load_weights_and_postprocess(model, weights, target_device):
+    def load_weights_and_postprocess(model, model_config, weights, target_device):
```
DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device) is called in model_runner.py but was not updated to the new signature.
```python
            yield name + "_scale", fp8_scale
            yield name + "_scale1", int4_scale
```
I'd give more explicit names
```python
    return []


class QuarkInt4Fp8LinearMethod(LinearMethodBase):
```
Isn't this just FP8 static quantization for weights with dynamic quantization for activations? The class should probably not be named QuarkInt4Fp8LinearMethod, as int4 is not involved here.
```diff
     @staticmethod
-    def load_weights_and_postprocess(model, weights, target_device):
+    def load_weights_and_postprocess(model, model_config, weights, target_device):
+        weights = online_quant(model_config, weights)
```
It looks a bit odd to me to call a quark/int4-fp8 specific method here.

Couldn't we instead make the weight_loader method handle on-the-fly quantization? WDYT?
Like this in QuarkInt4Fp8MoEMethod:

```python
def create_weights(
    self,
    layer: torch.nn.Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    # ...
    original_weight_loader = extra_weight_attrs.get("weight_loader")

    def weight_loader(
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        shard_id: str,
        expert_id: int,
    ) -> None:
        _, fp8_scale = quantize_fp8_scale_tensorwise(loaded_weight)
        int4_w, int4_scale = quantize_int4_scale_columnwise(loaded_weight)
        original_weight_loader(
            param,
            int4_w,
            weight_name,
            shard_id=shard_id,
            expert_id=expert_id,
        )
        # maybe need to care about TP > 1
        param.fp8_scale = fp8_scale
        param.int4_scale = int4_scale

    w13_weight = ModelWeightParameter(
        data=torch.empty(
            num_experts,
            2 * intermediate_size,
            hidden_size // 8,
            dtype=params_dtype,
        ),
        input_dim=2,
        output_dim=1,
        weight_loader=weight_loader,
    )
    layer.register_parameter("w13_weight", w13_weight)
    set_weight_attrs(w13_weight, extra_weight_attrs)
    # ...

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    # properly register int4 and fp8 scales here
    ...
```
Maybe you are right. That approach is simple to implement and compatible with the TP > 1 case.
@merrymercy @zhyncs @HaiShaw do you have comments?
Replaced by #7392, feel free to close this.
Quark int4_fp8_moe on-the-fly quantization.

In this PR, we support on-the-fly quantization of int4_fp8_moe using Quark.

```shell
python bench_one_batch.py --model-path /model/mistralai/Mixtral-8x7B-Instruct-v0.1 --correct --quark-config int4fp8_moe
```