
Commit 93e28e6

momo609 and wangxiaoxin-sherie authored
add weight transpose check. (#2756)
### What this PR does / why we need it?
In reinforcement learning scenarios, weight updates are required, but the current inference path applies a transpose to the MoE weights, altering their shape. This causes a shape mismatch with the training weights and triggers an error during weight updates.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main: vllm-project/vllm@6fb2788

Signed-off-by: wangxiaoxin-sherie <[email protected]>
Co-authored-by: wangxiaoxin-sherie <[email protected]>
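For context, a minimal sketch of the failure mode described above; the shapes are illustrative only and not taken from any particular model config:

```python
import torch

# Illustrative fused w13 weight: 4 experts, fused intermediate size 128, hidden size 8.
train_w13 = torch.randn(4, 128, 8)                  # layout kept on the training side
infer_w13 = train_w13.transpose(1, 2).contiguous()  # inference path transposes in place

try:
    # RL weight sync without an orientation check: shapes no longer line up.
    infer_w13.copy_(train_w13)
except RuntimeError as err:
    print(f"weight update failed: {err}")

# With an orientation check, the incoming training weight is transposed before copying.
infer_w13.copy_(train_w13.transpose(1, 2))
```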
1 parent e13c4dd commit 93e28e6

File tree

2 files changed: +109, -9 lines changed


tests/ut/ops/test_common_fused_moe.py

Lines changed: 37 additions & 1 deletion

```diff
@@ -17,7 +17,7 @@
 import torch
 
 from tests.ut.base import TestBase
-from vllm_ascend.ops.common_fused_moe import fused_experts_moge
+from vllm_ascend.ops.common_fused_moe import AscendFusedMoE, fused_experts_moge
 
 
 class TestFusedExpertsMoGE(TestBase):
@@ -67,3 +67,39 @@ def test_fused_experts_moge(self):
         )
 
         self.assertEqual(output.shape, (4, 128))
+
+
+class TestLoadWeight(TestBase):
+
+    def test_load_w13_transpose(self):
+        with patch.object(AscendFusedMoE, "__init__",
+                          lambda self, *args, **kwargs: None):
+            moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
+            moe.hidden_size = 8
+            expert_data = torch.randn(128, 8)
+            loaded_weight = torch.randn(128, 4)
+            moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
+
+            expert_data = torch.randn(8, 128)
+            loaded_weight = torch.randn(128, 4)
+            moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
+
+            expert_data = torch.randn(128, 8)
+            loaded_weight = torch.randn(128, 4)
+            moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
+
+            expert_data = torch.randn(8, 128)
+            loaded_weight = torch.randn(128, 4)
+            moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
+
+    def test_load_w2_transpose(self):
+        with patch.object(AscendFusedMoE, "__init__",
+                          lambda self, *args, **kwargs: None):
+            moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
+            expert_data = torch.randn(128, 4)
+            loaded_weight = torch.randn(128, 8)
+            moe._load_w2(expert_data, 1, loaded_weight, 0)
+
+            expert_data = torch.randn(4, 128)
+            loaded_weight = torch.randn(128, 8)
+            moe._load_w2(expert_data, 1, loaded_weight, 0)
```
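These tests only assert that the loader calls complete for both weight orientations. As a small illustrative trace (not part of the test file), here is what the first `_load_w13` call effectively does with those shapes, based on the loader added in the second file below:

```python
import torch

# Mirrors the first test call: expert_data (128, 8), shard_dim = 1, shard_id = "w1", tp_rank = 0.
expert_data = torch.zeros(128, 8)    # fused gate/up buffer for one expert
loaded_weight = torch.randn(128, 4)  # incoming shard, same orientation -> no transpose needed

shard_size = expert_data.shape[1] // 2                      # 4 columns per logical weight
expert_data.narrow(1, 0, shard_size).copy_(loaded_weight)   # "w1" (gate_proj) -> first half
assert torch.equal(expert_data[:, :shard_size], loaded_weight)
assert torch.all(expert_data[:, shard_size:] == 0)          # "w3" (up_proj) half still untouched
```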

vllm_ascend/ops/common_fused_moe.py

Lines changed: 72 additions & 8 deletions

```diff
@@ -134,6 +134,7 @@ def unquantized_fused_moe_init_func(self, *args, **kwargs):
     self.use_aclgraph = (vllm_config.compilation_config.level
                          == CompilationLevel.PIECEWISE
                          and not vllm_config.model_config.enforce_eager)
+    self.transpose = True
 
 
 def forward_oot_v01011(
@@ -261,13 +262,22 @@ def forward_oot(
 
 def process_weights_after_loading(self, layer):
     super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer)
-    w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
-        1, 2).contiguous()
-    layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
+    if self.transpose:
+        w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
+            1, 2).contiguous()
+        layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
 
-    w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
-        1, 2).contiguous()
-    layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
+        w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
+            1, 2).contiguous()
+        layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
+
+        self.transpose = False
+    else:
+        w13_data = self._maybe_pad_weight(layer.w13_weight.data)
+        layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
+
+        w2_data = self._maybe_pad_weight(layer.w2_weight.data)
+        layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
 
     if not is_310p():
         layer.w13_weight.data = torch_npu.npu_format_cast(
@@ -358,12 +368,11 @@ def __init__(
             num_redundant_experts,
             has_bias,
         )
-
        setup_token_dispatchers(self.moe_config.ep_size,
                                top_k=self.top_k,
                                num_experts=self.global_num_experts,
                                num_local_experts=self.local_num_experts)
-
+        self.hidden_size = hidden_size
         self.moe_config.tp_group = get_tp_group()
         self.moe_config.dp_group = get_dp_group()
         self.moe_config.ep_group = get_ep_group()
@@ -430,6 +439,61 @@ def forward_impl(self, hidden_states: torch.Tensor,
 
         return final_hidden_states
 
+    def transpose_weight(self, loaded_weight, expert_data, shard_dim):
+        # Ensure training and inference weight shapes match during RL weight updates
+        if (
+            loaded_weight.shape[1] != expert_data.shape[1] and \
+            loaded_weight.shape[0] != expert_data.shape[0]
+        ):
+            shard_dim = int(not shard_dim)
+            loaded_weight = loaded_weight.transpose(0, 1).contiguous()
+        return loaded_weight, shard_dim
+
+    def _load_w13(self,
+                  expert_data: torch.Tensor,
+                  shard_dim: int,
+                  shard_id: str,
+                  loaded_weight: torch.Tensor,
+                  tp_rank: int,
+                  load_full: bool = False):
+        # Index the loaded weight for tp sharding.
+        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
+        loaded_weight, shard_dim = self.transpose_weight(
+            loaded_weight, expert_data, shard_dim)
+        shard_size = expert_data.shape[shard_dim] // 2
+        if not load_full:
+            loaded_weight = loaded_weight.narrow(shard_dim,
+                                                 shard_size * tp_rank,
+                                                 shard_size)
+        # Narrow parameter and load.
+        # w1, gate_proj: Load into first logical weight of w13.
+        if shard_id == "w1":
+            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
+        # w3, up_proj: Load into second logical weight of w13.
+        else:
+            assert shard_id == "w3"
+            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
+        expert_data.copy_(loaded_weight)
+
+    def _load_w2(self,
+                 expert_data: torch.Tensor,
+                 shard_dim: int,
+                 loaded_weight: torch.Tensor,
+                 tp_rank: int,
+                 load_full: bool = False):
+        # Index the loaded weight for tp sharding.
+        # down_proj: "RowParallel" so tp sharding on input_dim
+        # Narrow parameter and load.
+        loaded_weight, shard_dim = self.transpose_weight(
+            loaded_weight, expert_data, shard_dim)
+        shard_size = expert_data.shape[shard_dim]
+        if not load_full:
+            loaded_weight = loaded_weight.narrow(shard_dim,
+                                                 shard_size * tp_rank,
+                                                 shard_size)
+        # w2, down_proj: Load into only logical weight of w2.
+        expert_data.copy_(loaded_weight)
+
 
 
 class AscendSharedFusedMoE(AscendFusedMoE):
```