[LoRA] Initial EP support for LoRA #40867
base: main
Changes from all commits
```diff
@@ -7,10 +7,6 @@
 from vllm import envs
 from vllm.config.lora import LoRAConfig
-from vllm.distributed.parallel_state import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 from vllm.distributed.utils import divide
 from vllm.lora.layers.base import BaseLayerWithLoRA
 from vllm.model_executor.layers.fused_moe import FusedMoE
```
```diff
@@ -30,15 +26,12 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
     def __init__(self, base_layer: FusedMoE) -> None:
         super().__init__()
         self.base_layer = base_layer
-        assert not self.base_layer.use_ep, (
-            "EP support for Fused MoE LoRA is not implemented yet."
-        )
-        assert not self.base_layer.quant_method.is_monolithic, (
-            "Monolithic kernels are not supported for Fused MoE LoRA."
-        )
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.tp_rank = get_tensor_model_parallel_rank()
+        self._ep_check()
+        # Use the MoE-aware TP rank/size: when EP is active, FusedMoE collapses
+        # moe_parallel_config.tp_size to 1 (experts are sharded across the
+        # TP group instead).
+        self.tp_size = self.base_layer.tp_size
+        self.tp_rank = self.base_layer.tp_rank
         self.device = _get_lora_device(base_layer)
         # For non-gated MoE (is_act_and_mul=False), only 1 slice is needed
         # since there's only up_proj (w1), not gate_proj + up_proj (w1 + w3)
```
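The collapse described in that comment is easy to state in code. Below is a rough, hypothetical sketch of the rank bookkeeping (ignoring data parallelism for simplicity); the real logic lives in vLLM's `FusedMoEParallelConfig`, and these names are illustrative, not the actual API:

```python
def moe_tp_ep(world_tp_size: int, world_tp_rank: int, use_ep: bool) -> dict:
    """Illustrative only: how FusedMoE-style code derives its local view."""
    if use_ep:
        # Experts are sharded across the TP group: each rank holds full-width
        # copies of its own experts, so the MoE layer sees tp_size == 1.
        return dict(tp_size=1, tp_rank=0,
                    ep_size=world_tp_size, ep_rank=world_tp_rank)
    # Without EP, experts are replicated and their weights are TP-sharded.
    return dict(tp_size=world_tp_size, tp_rank=world_tp_rank,
                ep_size=1, ep_rank=0)
```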
```diff
@@ -65,7 +58,7 @@ def __init__(self, base_layer: FusedMoE) -> None:
             "For quantized MoE, mix LoRAExpertsMixin into the experts class "
             "and consume self._lora_context in apply()."
         )
-        self._fused_experts = moe_kernel.fused_experts
+        self._moe_kernel = moe_kernel
         self.base_layer._replace_quant_method(
             FusedMoEModularMethod(self.base_layer.quant_method, moe_kernel)
         )
```
```diff
@@ -150,13 +143,35 @@ def _create_lora_b_weights(self, max_loras: int, lora_config: LoRAConfig):
             ),
         )

+    def _ep_check(self):
+        if self.base_layer.use_ep:
+            moe_config = self.base_layer.moe_config
+            all2all_backend = moe_config.moe_parallel_config.all2all_backend
+            assert all2all_backend == "allgather_reducescatter", (
+                "Fused MoE LoRA with EP currently only supports "
+                f"all2all_backend='allgather_reducescatter', got '{all2all_backend}'."
+            )
+            assert not moe_config.moe_parallel_config.is_sequence_parallel
+
+    def _verify_ep_fs(self, lora_config: LoRAConfig):
+        # EP and fully_sharded LoRA both partition along the same TP group —
+        # EP on the expert dim, fully_sharded on the LoRA rank dim — with
+        # mutually contradictory assumptions about which rank holds which
+        # expert's rank-shard.
+        assert not (self.base_layer.use_ep and lora_config.fully_sharded_loras), (
+            "Fused MoE LoRA does not support enable_expert_parallel=True "
+            "together with fully_sharded_loras=True. Disable one of them."
+        )
+
     def create_lora_weights(
         self,
         max_loras: int,
         lora_config: LoRAConfig,
         model_config: PretrainedConfig | None = None,
     ) -> None:
         """Initializes lora matrices."""

+        self._verify_ep_fs(lora_config)
         self.max_loras = lora_config.max_loras
         self.fully_sharded = lora_config.fully_sharded_loras
```

Contributor
Out of curiosity, do you know anyone using this

Collaborator (Author)
@HollowMan6 I know your team tried

Contributor
Yes. It generally works okay, except this bug #35077 (comment). But once LoRA + EP is supported, I don't think we need to have support for it to be enabled at the same time.
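For context, a configuration that should pass both guards looks roughly like the following. This is a hedged sketch: the model name and adapter path are placeholders, and the kwargs are the usual `EngineArgs` passthroughs rather than anything added by this PR:

```python
import os

from vllm import LLM
from vllm.lora.request import LoRARequest

# _ep_check only accepts this all2all backend, so pin it explicitly.
os.environ["VLLM_ALL2ALL_BACKEND"] = "allgather_reducescatter"

llm = LLM(
    model="org/some-moe-model",       # placeholder
    enable_lora=True,
    enable_expert_parallel=True,      # the EP path this PR enables
    fully_sharded_loras=False,        # must stay off while EP is on
    tensor_parallel_size=4,
)
outputs = llm.generate(
    "Hello",
    lora_request=LoRARequest("moe-adapter", 1, "/path/to/adapter"),  # placeholder path
)
```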
```diff
@@ -282,6 +297,24 @@ def set_lora(

         w1_lora_a, w2_lora_a, w3_lora_a = lora_a
         w1_lora_b, w2_lora_b, w3_lora_b = lora_b

+        # Under EP the adapter tensors carry all global experts; slice this
+        # rank's owned range so downstream shapes line up with local buffers.
+        global_num_experts = self.base_layer.global_num_experts
+        ep_rank = self.base_layer.ep_rank
+        if (
+            w1_lora_a.shape[0] == global_num_experts
+            and num_experts != global_num_experts
+        ):
+            expert_start = ep_rank * num_experts
+            expert_end = expert_start + num_experts
+            w1_lora_a = w1_lora_a[expert_start:expert_end]
+            w2_lora_a = w2_lora_a[expert_start:expert_end]
+            w3_lora_a = w3_lora_a[expert_start:expert_end]
+            w1_lora_b = w1_lora_b[expert_start:expert_end]
+            w2_lora_b = w2_lora_b[expert_start:expert_end]
+            w3_lora_b = w3_lora_b[expert_start:expert_end]
+
         assert (
             num_experts
             == w1_lora_a.shape[0]
```

Comment on lines +301 to +316

Contributor
Should this slicing be moved to load instead? If it's here in the

Collaborator (Author)
Makes sense

jeejeelee marked this conversation as resolved.
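To make the indexing concrete, here is a minimal standalone sketch of the slice performed above, with made-up shapes; it assumes the contiguous expert placement where EP rank r owns experts [r*n, (r+1)*n):

```python
import torch

global_num_experts, num_experts, ep_rank = 8, 2, 3   # made-up sizes
max_lora_rank, hidden_size = 16, 1024

# Adapter tensor as loaded: one LoRA-A matrix per *global* expert.
w1_lora_a = torch.randn(global_num_experts, max_lora_rank, hidden_size)

# Keep only the contiguous block of experts this EP rank owns.
expert_start = ep_rank * num_experts
expert_end = expert_start + num_experts
w1_lora_a = w1_lora_a[expert_start:expert_end]

assert w1_lora_a.shape == (num_experts, max_lora_rank, hidden_size)
```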
```diff
@@ -326,7 +359,11 @@ def set_lora(

     def set_mapping(self, punica_wrapper):
         super().set_mapping(punica_wrapper)
-        self._fused_experts.set_lora_context(self._build_lora_context())
+        lora_context = self._build_lora_context()
+        self._moe_kernel.fused_experts.set_lora_context(lora_context)
+        prepare_finalize = self._moe_kernel.prepare_finalize
+        if hasattr(prepare_finalize, "set_lora_context"):
+            prepare_finalize.set_lora_context(lora_context)

     def forward(self, *args, **kwargs):
         return self.base_layer.forward(*args, **kwargs)
```
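The `hasattr` guard is plain duck typing: only prepare/finalize implementations that know about LoRA receive the context, and everything else is left alone. A minimal sketch of the pattern (classes here are illustrative, not vLLM's):

```python
class LoRAAwarePrepareFinalize:
    def set_lora_context(self, ctx) -> None:
        # Stash the context for use during prepare/finalize.
        self._lora_context = ctx

class PlainPrepareFinalize:
    pass  # no set_lora_context, so the hasattr() check skips it

def propagate_lora_context(lora_context, *components) -> None:
    for component in components:
        if hasattr(component, "set_lora_context"):
            component.set_lora_context(lora_context)

propagate_lora_context(object(), LoRAAwarePrepareFinalize(), PlainPrepareFinalize())
```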
```diff
@@ -396,6 +433,7 @@ def create_lora_weights(
         """Initializes lora matrices."""

         assert isinstance(model_config, PretrainedConfig)
+        self._verify_ep_fs(lora_config)
         self._base_model = model_config.architectures[0]
         self.max_loras = lora_config.max_loras
         self.fully_sharded = lora_config.fully_sharded_loras
```
```diff
@@ -562,6 +562,10 @@ def create_dummy_lora(
         else:
             parts = module_name.split(".")
             replacements = self.packed_modules_mapping[parts[-1]]
+            if module.__class__.__name__ == "FusedMoEWithLoRA":
+                replacements = replacements[
+                    : len(module.lora_a_stacked) // self.lora_slots
+                ]
             subloras: list[LoRALayerWeights | None] = []
             for i, r in enumerate(replacements):
                 lora = LoRALayerWeights.create_dummy_lora_weights(
```

Comment on lines +565 to +568

Contributor
I'm actually kind of lost as to what is happening here 😓 will read in detail later. But just a quick question out of curiosity. Why do we do this packing at

Contributor
One benefit of moving it is that it makes the loading more efficient. We don't need to allocate all the small 2D MoE tensors at load time then pack them into 3D at add time. We can instead just allocate in 3D and load the 2D slices into it with local expert subsetting!
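What that last comment proposes, sketched with hypothetical helpers: allocate the stacked 3D buffer once and copy each owned expert's 2D slice straight into it, instead of materializing per-expert 2D tensors and packing them later:

```python
import torch

num_local_experts, max_lora_rank, hidden_size = 2, 16, 1024  # made-up sizes
ep_rank = 3
expert_start = ep_rank * num_local_experts

def read_expert_lora_a(global_expert_idx: int) -> torch.Tensor:
    # Stand-in for reading one expert's 2D LoRA-A matrix from a checkpoint.
    return torch.randn(max_lora_rank, hidden_size)

# Allocate the 3D destination once, then fill it expert by expert,
# touching only the experts this EP rank owns.
lora_a_stacked = torch.empty(num_local_experts, max_lora_rank, hidden_size)
for local_idx in range(num_local_experts):
    lora_a_stacked[local_idx].copy_(read_expert_lora_a(expert_start + local_idx))
```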
```diff
@@ -762,23 +766,33 @@ def _stack_moe_lora_weights(
         assert gate_up_proj_lora is not None
         assert down_proj_lora is not None
         if self._is_3d_moe_model:
-            num_experts = module.w13_lora_a_stacked[0].shape[1]
+            local_num_experts = module.w13_lora_a_stacked[0].shape[1]
+            # The checkpoint holds weights for all global experts, but
+            # each EP rank owns only local_num_experts. Reshape against
+            # the adapter's actual expert count, then slice this rank's
+            # owned expert range before it gets copied into the local
+            # stacked buffer. For non-EP (local == global) this is a
+            # no-op slice.
+            global_num_experts = module.base_layer.global_num_experts
+            ep_rank = module.base_layer.ep_rank
+            expert_start = ep_rank * local_num_experts
+            expert_end = expert_start + local_num_experts

             # (num_experts,rank,input_size)
             gate_up_proj_lora.lora_a = gate_up_proj_lora.lora_a.reshape(
-                num_experts, -1, gate_up_proj_lora.lora_a.shape[-1]
-            )
+                global_num_experts, -1, gate_up_proj_lora.lora_a.shape[-1]
+            )[expert_start:expert_end].contiguous()
             down_proj_lora.lora_a = down_proj_lora.lora_a.reshape(
-                num_experts, -1, down_proj_lora.lora_a.shape[-1]
-            )
+                global_num_experts, -1, down_proj_lora.lora_a.shape[-1]
+            )[expert_start:expert_end].contiguous()

             # (output_size,rank,num_experts)
             gate_up_proj_lora.lora_b = gate_up_proj_lora.lora_b.reshape(
-                gate_up_proj_lora.lora_b.shape[0], -1, num_experts
-            )
+                gate_up_proj_lora.lora_b.shape[0], -1, global_num_experts
+            )[..., expert_start:expert_end]
             down_proj_lora.lora_b = down_proj_lora.lora_b.reshape(
-                down_proj_lora.lora_b.shape[0], -1, num_experts
-            )
+                down_proj_lora.lora_b.shape[0], -1, global_num_experts
+            )[..., expert_start:expert_end]

             # (num_experts,output_size,rank)
             gate_up_proj_lora.lora_b = gate_up_proj_lora.lora_b.permute(
```
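The layout gymnastics above, replayed on dummy tensors: all shapes are made up, and the flat checkpoint layouts are inferred from what the reshapes in the diff imply (expert-major for lora_a, expert index trailing for lora_b):

```python
import torch

global_num_experts, local_num_experts, ep_rank = 8, 2, 1  # made-up sizes
r, in_size, out_size = 16, 512, 256
expert_start = ep_rank * local_num_experts
expert_end = expert_start + local_num_experts

# lora_a arrives flat as (global_experts * r, in_size):
# reshape to 3D, then keep this rank's experts.
lora_a = torch.randn(global_num_experts * r, in_size)
lora_a = lora_a.reshape(global_num_experts, -1, in_size)[
    expert_start:expert_end
].contiguous()  # (local_experts, r, in_size)

# lora_b arrives flat as (out_size, r * global_experts) with the expert
# index varying fastest, so the expert dim lands last after the reshape.
lora_b = torch.randn(out_size, r * global_num_experts)
lora_b = lora_b.reshape(out_size, -1, global_num_experts)[
    ..., expert_start:expert_end
]
lora_b = lora_b.permute(2, 0, 1)  # (local_experts, out_size, r)

assert lora_a.shape == (local_num_experts, r, in_size)
assert lora_b.shape == (local_num_experts, out_size, r)
```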