diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index 212978cbae..57ba6f1d8d 100644 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -157,6 +157,10 @@ def fused_moe_fake( num_local_tokens: Optional[torch.Tensor] = None, moe_sorting_dispatch_policy: bool = 0, dtype: Optional[torch.dtype] = None, + hidden_pad: int = 0, + intermediate_pad: int = 0, + bias1: Optional[torch.Tensor] = None, + bias2: Optional[torch.Tensor] = None, ) -> torch.Tensor: device = topk_ids.device M, topk = topk_ids.shape diff --git a/aiter/rotary_embedding.py b/aiter/rotary_embedding.py index a540b16121..1b328b1c3c 100644 --- a/aiter/rotary_embedding.py +++ b/aiter/rotary_embedding.py @@ -667,10 +667,10 @@ def _compute_cos_sin_cache(self) -> torch.Tensor: self.max_position_embeddings * self.scaling_factor, dtype=dtypes.fp32 ) freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = freqs.cos() * self.mscale - sin = freqs.sin() * self.mscale - cache = torch.cat((cos, sin), dim=-1) - return cache + cos = freqs.cos().unsqueeze(-2).unsqueeze(-2) * self.mscale + sin = freqs.sin().unsqueeze(-2).unsqueeze(-2) * self.mscale + # NOTE: concatenated cache removed; cos and sin are now returned separately with broadcast dims + return cos, sin class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):