|
51 | 51 | from vllm.model_executor.utils import set_weight_attrs |
52 | 52 | from vllm.sequence import SamplerOutput |
53 | 53 | from vllm.utils import print_warning_once |
54 | | - |
| 54 | +import intel_extension_for_pytorch as ipex |
| 55 | +from intel_extension_for_pytorch.cpu._auto_kernel_selection import ( |
| 56 | + _enable_tpp, |
| 57 | + _disable_tpp, |
| 58 | +) |
| 59 | +class _IPEXlinearMOECPU(nn.Module): |
| 60 | + def __init__(self, W13, W2, W3=None, tpp=False, woq=False): |
| 61 | + super().__init__() |
| 62 | + self.tpp = tpp |
| 63 | + self.woq = woq |
| 64 | +        self.num_experts = W2.shape[0]  # W2: [num_experts, hidden_size, intermediate_size] |
| 65 | + self.hidden_size = W2.shape[1] |
| 66 | + self.intermediate_size = W2.shape[2] |
| 67 | + |
| 68 | + linear_list = [] |
| 69 | + for i in range(W2.shape[0]): |
| 70 | + if W3 is not None: |
| 71 | +                _W1, _W3 = W13[i], W3[i] |
| 72 | + else: |
| 73 | + _W1 = W13[i][0 : self.intermediate_size, :] |
| 74 | + _W3 = W13[i][self.intermediate_size : 2 * self.intermediate_size, :] |
| 75 | +            linear1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) |
| 76 | +            linear1.weight = nn.Parameter(_W1) |
| 77 | +            linear2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) |
| 78 | +            linear2.weight = nn.Parameter(W2[i]) |
| 79 | +            linear3 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) |
| 80 | +            linear3.weight = nn.Parameter(_W3) |
| 81 | + linear_per_expert = nn.ModuleList([linear1, linear2, linear3]) |
| 82 | + linear_list.append(linear_per_expert) |
| 83 | +        self.linear_module_list = nn.ModuleList(linear_list) |
| 84 | + |
| 85 | + def forward(self, hidden_states, score, topk): |
| 86 | +        batch_size, hidden_dim = hidden_states.shape |
| 87 | + routing_weights = torch.nn.functional.softmax(score, dim=1, dtype=torch.float32) |
| 88 | + routing_weights, selected_experts = torch.topk(routing_weights, topk, dim=-1) |
| 89 | +        routing_weights = (routing_weights / routing_weights.sum(dim=-1, keepdim=True)).to(hidden_states.dtype)  # renormalize (fused_moe used renormalize=True) |
| 90 | + final_hidden_states = torch.zeros( |
| 91 | +            (batch_size, hidden_dim), |
| 92 | + dtype=hidden_states.dtype, |
| 93 | + device=hidden_states.device, |
| 94 | + ) |
| 95 | + expert_mask = torch.nn.functional.one_hot( |
| 96 | + selected_experts, num_classes=self.num_experts |
| 97 | + ).permute(2, 1, 0) |
| 98 | + for expert_idx in range(self.num_experts): |
| 99 | + idx, top_x = torch.where(expert_mask[expert_idx]) |
| 100 | + if ( |
| 101 | + hasattr(self.linear_module_list[expert_idx][0], "use_dnnl") |
| 102 | + and self.linear_module_list[expert_idx][0].use_dnnl |
| 103 | + ): |
| 104 | + final_hidden_states = torch.ops.torch_ipex.mixtral_moe( |
| 105 | + hidden_states, |
| 106 | + top_x, |
| 107 | + idx, |
| 108 | + self.linear_module_list[expert_idx][0]._get_forward_weight(), |
| 109 | + self.linear_module_list[expert_idx][0].ctx.get_data_handle(), |
| 110 | + self.linear_module_list[expert_idx][2]._get_forward_weight(), |
| 111 | + self.linear_module_list[expert_idx][2].ctx.get_data_handle(), |
| 112 | + self.linear_module_list[expert_idx][1]._get_forward_weight(), |
| 113 | + self.linear_module_list[expert_idx][1].ctx.get_data_handle(), |
| 114 | +                    # use_dnnl is known to be True in this branch |
| 115 | +                    True, |
| 116 | + routing_weights, |
| 117 | + final_hidden_states, |
| 118 | + False, |
| 119 | + ) |
| 120 | + else: |
| 121 | + final_hidden_states = torch.ops.torch_ipex.mixtral_moe_tpp( |
| 122 | + hidden_states, |
| 123 | + top_x, |
| 124 | + idx, |
| 125 | + self.linear_module_list[expert_idx][0].weight.detach(), |
| 126 | + self.linear_module_list[expert_idx][2].weight.detach(), |
| 127 | + self.linear_module_list[expert_idx][1].weight.detach(), |
| 128 | + ( |
| 129 | + self.linear_module_list[expert_idx][0].tpp_fallback |
| 130 | + if hasattr( |
| 131 | + self.linear_module_list[expert_idx][0], "tpp_fallback" |
| 132 | + ) |
| 133 | + else True |
| 134 | + ), |
| 135 | + routing_weights, |
| 136 | + final_hidden_states, |
| 137 | + False, |
| 138 | + ) |
| 139 | + |
| 140 | +        return final_hidden_states.view(-1, hidden_dim) |
55 | 141 |
|
56 | 142 | class MixtralMoE(nn.Module): |
57 | 143 | """A tensor-parallel MoE implementation for Mixtral that shards each expert |
@@ -108,14 +194,12 @@ def __init__( |
108 | 194 | self.hidden_size, |
109 | 195 | self.intermediate_size, |
110 | 196 | dtype=params_dtype)) |
111 | | - |
112 | 197 | set_weight_attrs(self.w13_weight, { |
113 | 198 | "weight_loader": self.weight_loader, |
114 | 199 | }) |
115 | 200 | set_weight_attrs(self.w2_weight, { |
116 | 201 | "weight_loader": self.weight_loader, |
117 | 202 | }) |
118 | | - |
119 | 203 | # Used for fp8. |
120 | 204 | self.w13_scale = None |
121 | 205 | self.w2_scale = None |
@@ -221,22 +305,17 @@ def process_weights_after_loading(self): |
221 | 305 |
|
222 | 306 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
223 | 307 | num_tokens, hidden_size = hidden_states.shape |
| 308 | + |
224 | 309 | hidden_states = hidden_states.view(-1, self.hidden_size) |
225 | 310 | # router_logits: (num_tokens, n_experts) |
226 | 311 | router_logits, _ = self.gate(hidden_states) |
227 | | - final_hidden_states = fused_moe(hidden_states, |
228 | | - self.w13_weight, |
229 | | - self.w2_weight, |
230 | | - router_logits, |
231 | | - self.top_k, |
232 | | - renormalize=True, |
233 | | - inplace=True, |
234 | | - use_fp8=self.use_fp8, |
235 | | - w1_scale=self.w13_scale, |
236 | | - w2_scale=self.w2_scale, |
237 | | - a1_scale=self.a13_scale, |
238 | | - a2_scale=self.a2_scale) |
239 | | - |
| 312 | +        if not hasattr(self, "ipex_moe"):  # build the IPEX MoE module once, on first call |
| 313 | + self.ipex_moe = _IPEXlinearMOECPU(self.w13_weight, self.w2_weight) |
| 314 | + _disable_tpp() |
| 315 | +            if hidden_states.dtype == torch.bfloat16:  # enable TPP kernels only for bf16 |
| 316 | + _enable_tpp() |
| 317 | + self.ipex_moe = ipex.optimize(self.ipex_moe.eval(), dtype=hidden_states.dtype, inplace=True) |
| 318 | + final_hidden_states = self.ipex_moe(hidden_states, router_logits, self.top_k) |
240 | 319 | if self.tp_size > 1: |
241 | 320 | final_hidden_states = tensor_model_parallel_all_reduce( |
242 | 321 | final_hidden_states) |
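
For readers unfamiliar with the routing bookkeeping above, here is a minimal plain-PyTorch sketch of the per-expert computation that `_IPEXlinearMOECPU` hands off to the fused `torch_ipex.mixtral_moe` / `mixtral_moe_tpp` kernels. The function name, the random shapes, and the SiLU activation are assumptions drawn from the standard Mixtral MoE block rather than code in this PR; it only shows what the one-hot mask, `torch.where` gather, and routing-weight scaling compute.

```python
# Reference (unfused) MoE forward, assuming Mixtral-style SiLU gating.
import torch
import torch.nn.functional as F


def reference_moe(hidden_states, router_logits, w1, w2, w3, top_k):
    """hidden_states: [T, H]; w1, w3: [E, I, H]; w2: [E, H, I]."""
    num_experts = w1.shape[0]
    routing = F.softmax(router_logits, dim=-1, dtype=torch.float32)
    routing, selected = torch.topk(routing, top_k, dim=-1)
    routing = routing / routing.sum(dim=-1, keepdim=True)  # renormalize top-k
    routing = routing.to(hidden_states.dtype)
    out = torch.zeros_like(hidden_states)
    # [E, top_k, T] mask, same layout as the permute(2, 1, 0) in the PR
    mask = F.one_hot(selected, num_classes=num_experts).permute(2, 1, 0)
    for e in range(num_experts):
        k_idx, tok_idx = torch.where(mask[e])
        if tok_idx.numel() == 0:
            continue  # no tokens routed to this expert
        x = hidden_states[tok_idx]
        # SwiGLU-style expert: silu(x @ W1^T) * (x @ W3^T), then W2^T
        y = (F.silu(x @ w1[e].t()) * (x @ w3[e].t())) @ w2[e].t()
        out.index_add_(0, tok_idx, y * routing[tok_idx, k_idx].unsqueeze(1))
    return out


# Usage sketch with random weights: 4 tokens, 8 experts, top-2 routing.
T, H, I, E, K = 4, 16, 32, 8, 2
out = reference_moe(torch.randn(T, H), torch.randn(T, E),
                    torch.randn(E, I, H), torch.randn(E, H, I),
                    torch.randn(E, I, H), K)
print(out.shape)  # torch.Size([4, 16])
```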
|