-
Notifications
You must be signed in to change notification settings - Fork 5.1k
Add Mistral Large 3 support. #14213
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Mistral Large 3 support. #14213
Changes from all commits
90098a5
0cef9be
b92e641
4bfe175
94f923f
4fe81ec
6942b91
efc8ffa
896aa66
f1fa66f
1a3b88d
0353f52
69186b3
7aac216
ee68d59
bc4ed51
74cda59
7f0c021
15d6201
f4ad69c
0a48bd0
90f589a
c7be4b3
5100fa9
f0f5064
50fdab2
1ef9948
b9d5f64
dc500a0
77319c4
623b33a
73abf6b
d7a7b6c
1b1c84b
fcc2c6d
5d077e1
5d52732
13ccef9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,6 +11,7 @@ | |
| from compressed_tensors import CompressionFormat | ||
| from compressed_tensors.quantization import QuantizationStrategy | ||
|
|
||
| from sglang.srt.distributed import get_tensor_model_parallel_world_size | ||
| from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig | ||
| from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo | ||
| from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase | ||
|
|
@@ -81,8 +82,8 @@ def get_moe_method( | |
|
|
||
| weight_quant = quant_config.target_scheme_map["Linear"].get("weights") | ||
| input_quant = quant_config.target_scheme_map["Linear"].get("input_activations") | ||
| if quant_config._is_wNa16_group_channel(weight_quant, input_quant): | ||
|
|
||
| if quant_config._is_wNa16_group_channel(weight_quant, input_quant): | ||
| logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") | ||
| return CompressedTensorsWNA16MoEMethod(quant_config) | ||
| elif quant_config._is_fp8_w8a8(weight_quant, input_quant): | ||
|
|
@@ -102,7 +103,28 @@ def __init__(self, quant_config: CompressedTensorsConfig): | |
| "input_activations" | ||
| ) | ||
|
|
||
| per_tensor = ( | ||
| self.weight_quant.strategy == QuantizationStrategy.TENSOR | ||
| and self.input_quant.strategy == QuantizationStrategy.TENSOR | ||
| ) | ||
| per_channel = ( | ||
| self.weight_quant.strategy == QuantizationStrategy.CHANNEL | ||
| and self.input_quant.strategy == QuantizationStrategy.TOKEN | ||
| ) | ||
| if not (per_tensor or per_channel): | ||
| assert self.weight_quant.strategy == QuantizationStrategy.BLOCK | ||
| self.weight_block_size = self.weight_quant.block_structure | ||
| assert self.weight_quant.dynamic is not None | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. self.weight_quant.dynamic is false, so can it be None?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This part is mostly from vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py#L667-L681 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
like this recipe https://github.com/vllm-project/llm-compressor/blob/aa504491afd28a0d5f66d3e38088352dcb4e63ff/src/llmcompressor/modifiers/quantization/gptq/base.py#L57, there is no attribute of dynamic |
||
| else: | ||
| self.weight_block_size = None | ||
| self.block_quant = self.weight_block_size is not None | ||
|
|
||
| self.static_input_scales = not self.input_quant.dynamic | ||
| if self.static_input_scales and per_channel: | ||
| raise ValueError( | ||
| "For FP8 Fused MoE layer, we require either per tensor or " | ||
| "channelwise, dynamic per token quantization." | ||
| ) | ||
|
|
||
| def create_weights( | ||
| self, | ||
|
|
@@ -117,6 +139,32 @@ def create_weights( | |
|
|
||
| params_dtype = torch.float8_e4m3fn | ||
|
|
||
| if self.block_quant: | ||
| assert self.weight_block_size is not None | ||
| layer.weight_block_size = self.weight_block_size | ||
| tp_size = get_tensor_model_parallel_world_size() | ||
| block_n, block_k = ( | ||
| self.weight_block_size[0], | ||
| self.weight_block_size[1], | ||
| ) | ||
| # NOTE: To ensure proper alignment of the block-wise quantization | ||
| # scales, the output_size of the weights for both the gate and up | ||
| # layers must be divisible by block_n. | ||
| # Required by column parallel or enabling merged weights | ||
| if intermediate_size_per_partition % block_n != 0: | ||
| raise ValueError( | ||
| f"The output_size of gate's and up's weight = " | ||
| f"{intermediate_size_per_partition} is not divisible by " | ||
| f"weight quantization block_n = {block_n}." | ||
| ) | ||
| if tp_size > 1 and intermediate_size_per_partition % block_k != 0: | ||
| # Required by row parallel | ||
| raise ValueError( | ||
| f"The input_size of down's weight = " | ||
| f"{intermediate_size_per_partition} is not divisible by " | ||
| f"weight quantization block_k = {block_k}." | ||
| ) | ||
|
|
||
| # WEIGHTS | ||
| w13_weight = torch.nn.Parameter( | ||
| torch.empty( | ||
|
|
@@ -169,6 +217,26 @@ def create_weights( | |
| requires_grad=False, | ||
| ) | ||
| weight_quant_method = FusedMoeWeightScaleSupported.CHANNEL.value | ||
| elif self.weight_quant.strategy == QuantizationStrategy.BLOCK: | ||
| w13_weight_scale = torch.nn.Parameter( | ||
| torch.ones( | ||
| num_experts, | ||
| 2 * ((intermediate_size_per_partition + block_n - 1) // block_n), | ||
| (hidden_size + block_k - 1) // block_k, | ||
| dtype=torch.float32, | ||
| ), | ||
| requires_grad=False, | ||
| ) | ||
| w2_weight_scale = torch.nn.Parameter( | ||
| torch.ones( | ||
| num_experts, | ||
| (hidden_size + block_n - 1) // block_n, | ||
| (intermediate_size_per_partition + block_k - 1) // block_k, | ||
| dtype=torch.float32, | ||
| ), | ||
| requires_grad=False, | ||
| ) | ||
| weight_quant_method = FusedMoeWeightScaleSupported.BLOCK.value | ||
| else: | ||
| raise ValueError( | ||
| f"Unsupported weight quantization strategy: {self.weight_quant.strategy}" | ||
|
|
@@ -343,6 +411,18 @@ def apply( | |
| a2_scale=layer.w2_input_scale, | ||
| ) | ||
| return StandardCombineInput(hidden_states=output) | ||
| elif self.weight_quant.strategy == QuantizationStrategy.BLOCK: | ||
| quant_info = TritonMoeQuantInfo( | ||
| w13_weight=layer.w13_weight, | ||
| w2_weight=layer.w2_weight, | ||
| use_fp8_w8a8=True, | ||
| w13_scale=layer.w13_weight_scale, | ||
| w2_scale=layer.w2_weight_scale, | ||
| a13_scale=layer.w13_input_scale, | ||
| a2_scale=layer.w2_input_scale, | ||
| block_shape=self.weight_block_size, | ||
| ) | ||
| return self.runner.run(dispatch_output, quant_info) | ||
| else: | ||
| quant_info = TritonMoeQuantInfo( | ||
| w13_weight=layer.w13_weight, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if we add this logic here, we will have to do the same thing for a ton of attn backends and sync them. thus one way maybe just put logic in the models folder (deepseek_v2.py or mistral py etc), since looks like it is just scaling the q tensor before entering the core attention logic
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From Mistral's implementation in vLLM, the scaling is applied between RoPE and attention, so it is implemented inside the MLA layer: https://github.com/vllm-project/vllm/pull/29757/files#diff-6ffcb4f51daf85df32c7d35433c3393f1602663960f677ae61f55af1ed3ab524
In SGLang, for trtllm mla backend, the RoPE is fused into the attention backend. This is quite different from vLLM. The scaling is needed to be passed into the backend to apply correctly.