You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
# expert_mask is of size (self.num_experts_per_partition + 1),
1137
+
# the extra 1 is for the invalid rank_id (in the original deepep the invalid rank_id is -1, but aiter does not allow -1, so we use a mask to make those ids invalid)
1138
+
# for instance, if we have 4 experts on this rank, we would have an expert_mask like:
1139
+
# self.expert_mask = [1, 1, 1, 1, 0]
1140
+
# indices 0-3 are valid and will be processed, while idx == 4 will be masked out
1141
+
self.expert_mask=torch.zeros(
1142
+
(self.num_experts_per_partition+1),
1143
+
device=torch.cuda.current_device(),
1144
+
dtype=torch.int,
1145
+
)
1146
+
# the last one is for the invalid rank_id
1147
+
self.expert_mask[:-1] =1
1148
+
else:
1149
+
self.w13_weight_fp8= (
1150
+
self.w13_weight,
1151
+
(
1152
+
self.w13_weight_scale_inv
1153
+
ifself.use_block_quant
1154
+
elseself.w13_weight_scale
1155
+
),
1156
+
)
1157
+
self.w2_weight_fp8= (
1158
+
self.w2_weight,
1159
+
(
1160
+
self.w2_weight_scale_inv
1161
+
ifself.use_block_quant
1162
+
elseself.w2_weight_scale
1163
+
),
1164
+
)
1132
1165
1133
1166
defforward(
1134
1167
self,
@@ -1142,6 +1175,9 @@ def forward(
1142
1175
num_recv_tokens_per_expert: List[int],
1143
1176
forward_mode: ForwardMode,
1144
1177
):
1178
+
if_use_aiter:
1179
+
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside the aiter kernel
0 commit comments