
[Feature][ROCM] add online int4_fp8_moe quant feature #6238

Draft
DehuaTang wants to merge 1 commit into sgl-project:main from DehuaTang:quark_int4_fp8_moe

Conversation

@DehuaTang

Quark int4_fp8_moe on-the-fly quant.

In this PR, we support on-the-fly quantization of int4_fp8_moe using Quark.

python bench_one_batch.py --model-path /model/mistralai/Mixtral-8x7B-Instruct-v0.1 --correct --quark-config int4fp8_moe
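For illustration only (hypothetical helper names, not code from this PR), the two-level scheme can be sketched in plain Python: each weight row gets an int4 absmax scale, which is then folded into an fp8 tensorwise scale, so dequantization applies both factors:

```python
# Hypothetical sketch of the two-level int4-fp8 scheme. Names are illustrative.
FP8_MAX = 448.0  # largest finite float8_e4m3fn value
S4_MAX = 7       # symmetric int4 range

def quantize_row(row):
    absmax = max(abs(v) for v in row)
    fp8_scale = absmax / FP8_MAX
    int4_scale = absmax / S4_MAX
    q = [max(-S4_MAX, min(S4_MAX, round(v / int4_scale))) for v in row]
    # Fold the fp8 scale into the int4 scale (cf. `int4_scale /= fp8_scale`).
    return q, int4_scale / fp8_scale, fp8_scale

def dequantize_row(q, folded_int4_scale, fp8_scale):
    # Applying both scales approximately recovers the original values.
    return [v * folded_int4_scale * fp8_scale for v in q]

row = [0.5, -2.0, 1.25, 0.0]
q, folded, fp8 = quantize_row(row)
restored = dequantize_row(q, folded, fp8)
```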

"compressed-tensors",
"fbgemm_fp8",
"w8a8_fp8",
"quark_int4fp8_moe",
Contributor


How is this PR related to quark? At the moment, there is no utility from quark used in this PR.

Author


The quantized weights obtained in online_quant should actually be implemented using Quark's realquantizer. However, due to time constraints, a simple function was used.

Contributor


got it. Should we have a code structure like

quark/
    schemes/
        quark_scheme.py
        quark_int4fp8.py
    quark.py

similar to vllm? Maybe easier to extend with mxfp4 then

Author


No. sglang does not support quark format yet.

Contributor


Okay, it can be done in another PR.

)


class QuarkInt4Fp8MoEMethod:
Contributor


#4152 applied changes to the Fp8MoEMethod class to be able to load int4 weights, using the environment variable SGLANG_INT4_WEIGHT.

I am confused regarding the need to introduce this new class QuarkInt4Fp8MoEMethod. Once weights are quantized on the fly to int4-fp8, the execution path should be shared with what was done in #4152.

Alternatively, Fp8MoEMethod should be cleaned up of int4-fp8 related code, and int4-fp8 should always be handled with QuarkInt4Fp8MoEMethod, no matter whether weight quantization is online or offline.

It seems a lot of code is duplicated here

Author


Yes, the current code does overlap with the PR you mentioned. But we don't want to reuse the previous approach because it relies on environment variables for control. Avoiding that is our improvement.

return []


class QuarkInt4Fp8LinearMethod(LinearMethodBase):
Contributor


#4152 did not touch Fp8LinearMethod it seems

Author


In this PR, only w1/w2/w3 use int4-fp8, and the linear layers in attention are still fp8, so it does touch this function.

logger = logging.getLogger(__name__)


def apply_quark_quant_config_to_model(model_config, quark_config):
Contributor


type hints

Comment on lines +34 to +99
def quantize_fp8_scale_tensorwise(w):
    FP8_MAX = 448.0
    scale = w.abs().amax().float() / FP8_MAX
    scaled = (w / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return scaled, scale


def quantize_int4_scale_columnwise(w):
    S4_MAX = 7
    w_flat = w.reshape(-1, w.shape[-1]).float()
    scale = w_flat.abs().amax(dim=-1) / S4_MAX
    scaled = torch.round(w_flat / scale[:, None]).to(torch.int8).clamp(-S4_MAX, S4_MAX)
    return scaled.reshape(w.shape), scale.reshape(w.shape[:-1])


def pack(to_pack: torch.Tensor, reorder: bool = True) -> torch.Tensor:
    if to_pack.ndim > 2:
        raise ValueError("Pack: Only supports tensors with dimensions not greater than 2.")

    if reorder:
        order_map = [0, 2, 4, 6, 1, 3, 5, 7]
    else:
        order_map = [0, 1, 2, 3, 4, 5, 6, 7]
    pack_num = 8
    if to_pack.ndim == 2:
        packed = torch.zeros(to_pack.shape[0], to_pack.shape[1] // pack_num, dtype=torch.int32, device=to_pack.device)
        new_c = to_pack.shape[1] // pack_num
        for c in range(new_c):
            for i in range(pack_num):
                # Take -3 as an example: its high bits are all 1s (sign extension),
                # which would corrupt bitwise_or, so we mask to the low nibble first
                # instead of using int4 values directly.
                packed_col = to_pack[:, c * pack_num + order_map[i]].to(torch.int32)
                packed_col = packed_col & 0x0F
                packed[:, c] = torch.bitwise_or(packed[:, c], torch.bitwise_left_shift(packed_col, i * 4))
    elif to_pack.ndim == 0:
        packed = to_pack.to(torch.int32)
    else:
        packed = torch.zeros(to_pack.shape[0] // pack_num, dtype=torch.int32, device=to_pack.device)
        new_c = to_pack.shape[0] // pack_num
        for c in range(new_c):
            for i in range(pack_num):
                # Same masking as above: sign-extended negative values would
                # corrupt bitwise_or, so keep only the low nibble.
                packed_col = to_pack[c * pack_num + order_map[i]].to(torch.int32)
                packed_col = packed_col & 0x0F
                packed[c] = torch.bitwise_or(packed[c], torch.bitwise_left_shift(packed_col, i * 4))

    return packed


def quark_quant_weights(weights_dict):
    for name, loaded_weight in tqdm(weights_dict, desc="Quark Online Quantizing"):
        if "w1.weight" in name or "w2.weight" in name or "w3.weight" in name:
            fp8_w, fp8_scale = quantize_fp8_scale_tensorwise(loaded_weight)
            int4_w, int4_scale = quantize_int4_scale_columnwise(loaded_weight)

            int4_w = pack(int4_w)
            int4_scale /= fp8_scale

            yield name, int4_w
            yield name + "_scale", fp8_scale
            yield name + "_scale1", int4_scale
        elif "proj.weight" in name:
            fp8_w, fp8_scale = quantize_fp8_scale_tensorwise(loaded_weight)

            yield name, fp8_w
            yield name + "_scale", fp8_scale
        else:
            yield name, loaded_weight
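As a side note for readers, the masking step can be sketched in plain Python (hypothetical `pack_word`/`unpack_word` helpers, not from the PR): negative int4 values are sign-extended, so the low-nibble mask is what keeps an OR from corrupting neighboring nibbles:

```python
# Hypothetical round-trip for the 8-per-word nibble packing used by pack().
ORDER_MAP = [0, 2, 4, 6, 1, 3, 5, 7]  # the reorder map used by pack()

def pack_word(values):
    """Pack 8 signed int4 values into one 32-bit word."""
    assert len(values) == 8
    word = 0
    for i in range(8):
        # Mask to the low nibble: e.g. -3 is ...11111101 in two's complement,
        # and OR-ing the sign-extended value would overwrite higher nibbles.
        word |= (values[ORDER_MAP[i]] & 0x0F) << (i * 4)
    return word

def unpack_word(word):
    """Recover the 8 signed int4 values from a packed word."""
    values = [0] * 8
    for i in range(8):
        nibble = (word >> (i * 4)) & 0x0F
        # Sign-extend the 4-bit value back to a signed int.
        values[ORDER_MAP[i]] = nibble - 16 if nibble >= 8 else nibble
    return values

row = [-3, 7, 0, -7, 1, -1, 5, -5]
```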
Contributor


I'd move these functions out

Author


Yes. If you have time, it would be best to use quark's functions to implement it.

Contributor


I meant move them out of apply_quark_quant_config_to_model - but yeah eventually we might use quark utilities here


@staticmethod
def load_weights_and_postprocess(model, weights, target_device):
def load_weights_and_postprocess(model, model_config, weights, target_device):
Contributor


DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device) is called in model_runner.py but was not modified

Author


Sorry for that

Comment on lines +91 to +92
yield name + "_scale", fp8_scale
yield name + "_scale1", int4_scale
Contributor


I'd give more explicit names

return []


class QuarkInt4Fp8LinearMethod(LinearMethodBase):
Contributor


Isn't it just simply FP8 static quantization for weights, dynamic quant for activations? The class should probably not be named QuarkInt4Fp8LinearMethod as int4 is not involved here.
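The distinction drawn here can be sketched in plain Python (hypothetical names, not code from this PR): static weight quantization computes one absmax scale at load time, while dynamic activation quantization recomputes the scale on every forward call:

```python
# Hypothetical sketch: static weight scale vs dynamic activation scale.
FP8_MAX = 448.0  # largest finite float8_e4m3fn value

def absmax_scale(values):
    return max(abs(v) for v in values) / FP8_MAX

weight = [10.0, -448.0, 224.0]
w_scale = absmax_scale(weight)  # static: computed once, stored with the layer

act_scales = []
for activation in ([0.5, -1.0], [100.0, -50.0]):
    act_scales.append(absmax_scale(activation))  # dynamic: fresh per call
```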

@staticmethod
def load_weights_and_postprocess(model, weights, target_device):
def load_weights_and_postprocess(model, model_config, weights, target_device):
weights = online_quant(model_config, weights)
Contributor


It looks a bit weird to me to call a quark/int4-fp8 specific method here.

Couldn't we rather make it so that the weight_loader method handles on the fly quantization? WDYT?

Like this in QuarkInt4Fp8MoEMethod:

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # ...
        original_weight_loader = extra_weight_attrs.get("weight_loader")

        def weight_loader(
            self,
            param: torch.nn.Parameter,
            loaded_weight: torch.Tensor,
            weight_name: str,
            shard_id: str,
            expert_id: int,
        ) -> None:
            _, fp8_scale = quantize_fp8_scale_tensorwise(loaded_weight)
            int4_w, int4_scale = quantize_int4_scale_columnwise(loaded_weight)
            
            original_weight_loader(
                param,
                int4_w,
                weight_name,
                shard_id=shard_id,
                expert_id=expert_id
            )

            # maybe need to care about TP>1
            param.fp8_scale = fp8_scale
            param.int4_scale = int4_scale

        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size,
                hidden_size // 8,
                dtype=params_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        set_weight_attrs(w13_weight, extra_weight_attrs)

        # ...
    
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # properly register int4 and fp8 scales here.

Author


Maybe you are right. This method is simple to implement and is compatible with the case where TP>1.

@fxmarty-amd
Contributor

@merrymercy @zhyncs @HaiShaw do you have comments?

@BowenBao
Collaborator

Replaced by #7392, feel free to close this.
