class MRotaryEmbeddingQKNormFused(nn.Module):
    """Rotary Embedding with Multimodal Sections (M-RoPE) fused with QK RMS-norm.

    Applies 3D multimodal rotary position embedding and RMS normalization of
    the Q and K projections in a single fused kernel (``fused_mrope_3d_rms``)
    over a packed, contiguous QKV tensor.
    """

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Compute the inverse frequency for each rotary dimension pair."""
        # NOTE(woosuk): To exactly match the HF implementation, we need to
        # use CPU to compute the cache and then move it to GPU. However, we
        # create the cache on GPU for faster initialization. This may cause
        # a slight numerical difference between the HF implementation and ours.
        inv_freq = 1.0 / (
            base
            ** (
                torch.arange(0, self.rotary_dim, 2, dtype=dtypes.fp32) / self.rotary_dim
            )
        )
        return inv_freq

    def _compute_cos_sin_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the (cos, sin) tables for all positions.

        Returns:
            Tuple of two ``(max_position_embeddings, rotary_dim // 2)`` tensors.
        """
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=dtypes.fp32)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        return freqs.cos(), freqs.sin()

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
        mrope_section: Optional[List[int]] = None,
        mrope_interleaved: bool = False,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        # The fused kernel rotates the whole head; partial rotary dims are
        # not supported here.
        assert self.head_size == self.rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype
        self.mrope_interleaved = mrope_interleaved

        cos, sin = self._compute_cos_sin_cache()
        cos = cos.to(dtype)
        sin = sin.to(dtype)
        # Cache layout consumed by the fused kernel: [cos | sin] along dim -1.
        cache = torch.cat((cos, sin), dim=-1)
        self.cos_sin_cache: torch.Tensor
        self.register_buffer("cos_sin_cache", cache, persistent=False)

        self.mrope_section = mrope_section
        if self.mrope_section:
            # The per-axis section widths must sum to rotary_dim // 2.
            self.mrope_section = self._normalize_mrope_section(
                list(self.mrope_section), rotary_dim // 2
            )

    def _normalize_mrope_section(
        self, mrope_section: List[int], expected_sum: int
    ) -> List[int]:
        """Return ``mrope_section`` corrected so its sum equals ``expected_sum``.

        If the provided sections do not sum to ``rotary_dim // 2``, they are
        scaled proportionally (or evenly redistributed if all zero), with a
        console warning, so the fused kernel always sees a consistent layout.
        """
        actual_sum = sum(mrope_section)
        if actual_sum == expected_sum:
            return mrope_section
        print(
            f"MRoPE section sum mismatch: expected {expected_sum}, got {actual_sum}. "
            f"Adjusting mrope_section to match rotary_dim // 2 = {expected_sum}"
        )
        if actual_sum > 0:
            # Auto-correct by scaling the mrope_section proportionally.
            scale_factor = expected_sum / actual_sum
            mrope_section = [
                max(1, int(section * scale_factor)) for section in mrope_section
            ]
            # Ensure the sum exactly matches by adjusting the last element.
            current_sum = sum(mrope_section)
            if current_sum != expected_sum:
                mrope_section[-1] += expected_sum - current_sum
        else:
            # If all sections are 0, create a default even distribution.
            mrope_section = [expected_sum // len(mrope_section)] * len(mrope_section)
            # Hand out the remainder to the leading sections.
            remainder = expected_sum % len(mrope_section)
            for i in range(remainder):
                mrope_section[i] += 1
        print(
            f"Corrected mrope_section: {mrope_section} (sum={sum(mrope_section)})"
        )
        return mrope_section

    def forward(
        self,
        qkv: torch.Tensor,
        q_weight: torch.Tensor,
        k_weight: torch.Tensor,
        positions: torch.Tensor,
        num_heads: int,
        num_kv_heads: int,
        eps: float,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Apply fused M-RoPE + QK RMS-norm in place on ``qkv``.

        Args:
            qkv: packed projection output, viewed as
                ``(num_tokens, (num_heads + 2 * num_kv_heads) * head_size)``.
            q_weight / k_weight: RMS-norm weights for Q and K.
            positions: 1-D ``(num_tokens,)`` or 2-D ``(3, num_tokens)`` for
                the multimodal (interleaved) layout.
            eps: RMS-norm epsilon.

        Returns:
            ``(q, k, v)`` views split out of the (mutated) ``qkv`` tensor.
        """
        assert positions.ndim in (1, 2)
        num_tokens = positions.shape[-1]
        num_heads_q = num_heads
        num_heads_k = num_kv_heads
        num_heads_v = num_kv_heads
        # 2-D positions together with a configured mrope_section means the
        # interleaved multimodal layout is active; this must agree with the
        # configuration given at construction time.
        is_interleaved = positions.ndim == 2 and self.mrope_section is not None
        assert is_interleaved == self.mrope_interleaved
        fused_mrope_3d_rms(
            qkv, q_weight, k_weight, self.cos_sin_cache, positions, num_tokens,
            num_heads_q, num_heads_k, num_heads_v, self.head_size,
            self.is_neox_style, self.mrope_section, is_interleaved, eps,
        )
        q_size = num_heads_q * self.head_size
        k_size = num_heads_k * self.head_size
        v_size = num_heads_v * self.head_size

        qkv = qkv.view(num_tokens, q_size + k_size + v_size)
        q, k, v = qkv.split([q_size, k_size, v_size], dim=-1)
        return q, k, v
class DualChunkRotaryEmbedding(nn.Module):
    """Rotary positional embedding for Dual Chunk Attention.

    Precomputes five cos/sin caches used by the dual-chunk query/key paths:
    intra-chunk queries (``q``), successive-chunk queries (``qc``), keys
    (``k``), and the unclamped successive/inter-chunk variants
    (``qc_no_clamp``, ``q_inter``) used for the "critical" tokens.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
        chunk_size: int,
        local_size: int,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.chunk_size = chunk_size
        self.local_size = local_size
        self.dtype = dtype
        # NOTE(review): assumes a CUDA device is available at construction.
        self.device = torch.device(f"cuda:{torch.cuda.current_device()}")
        (q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache) = (
            self._compute_cos_sin_cache()
        )

        self.register_buffer("cos_sin_q_cache", q_cache, persistent=False)
        self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False)
        self.register_buffer("cos_sin_k_cache", k_cache, persistent=False)
        self.register_buffer(
            "cos_sin_qc_no_clamp_cache", qc_no_clamp_cache, persistent=False
        )
        self.register_buffer("cos_sin_q_inter_cache", q_inter_cache, persistent=False)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Compute the inverse frequency for each rotary dimension pair."""
        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
        # However, we use `torch.arange(..., dtype=torch.float)` instead to
        # avoid numerical issues with large base values (e.g., 10000000).
        # This may cause a slight numerical difference between the HF
        # implementation and ours.
        # NOTE(woosuk): To exactly match the HF implementation, we need to
        # use CPU to compute the cache and then move it to GPU. However, we
        # create the cache on GPU for faster initialization. This may cause
        # a slight numerical difference between the HF implementation and ours.
        inv_freq = 1.0 / (
            base
            ** (
                torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
            )
        )
        return inv_freq

    def _compute_cos_sin_cache(
        self,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the five cos/sin caches.

        Returns:
            ``(q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache)``,
            each of shape ``(positions, rotary_dim)`` with cos in the first
            half of the last dim and sin in the second half.
        """
        inv_freq = self._compute_inv_freq(self.base)
        chunk_len = self.chunk_size - self.local_size
        # Intra-chunk query positions: 0 .. chunk_len-1.
        q_t = torch.arange(chunk_len, dtype=torch.float)
        # Successive-chunk query positions, clamped at chunk_size.
        qc_t = (torch.arange(chunk_len, dtype=torch.float) + chunk_len).clamp(
            max=self.chunk_size
        )
        # Key positions wrap within a chunk.
        k_t = torch.arange(self.max_position_embeddings, dtype=torch.float) % chunk_len

        # count from chunk_len, no clamp(self.chunk_size) restriction
        qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len
        # count from self.chunk_size for q_inter's rope
        q_inter_t = torch.arange(chunk_len, dtype=torch.float) + self.chunk_size

        q_freqs = torch.outer(q_t, inv_freq)
        qc_freqs = torch.outer(qc_t, inv_freq)
        k_freqs = torch.outer(k_t, inv_freq)
        qc_no_clamp_freqs = torch.outer(qc_no_clamp_t, inv_freq)
        q_inter_freqs = torch.outer(q_inter_t, inv_freq)

        q_cos = q_freqs.cos()
        q_sin = q_freqs.sin()
        qc_cos = qc_freqs.cos()
        qc_sin = qc_freqs.sin()
        k_cos = k_freqs.cos()
        k_sin = k_freqs.sin()

        qc_no_clamp_cos = qc_no_clamp_freqs.cos()
        qc_no_clamp_sin = qc_no_clamp_freqs.sin()
        q_inter_cos = q_inter_freqs.cos()
        q_inter_sin = q_inter_freqs.sin()

        q_cache = torch.cat((q_cos, q_sin), dim=-1).to(
            dtype=self.dtype, device=self.device
        )
        qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to(
            dtype=self.dtype, device=self.device
        )
        k_cache = torch.cat((k_cos, k_sin), dim=-1).to(
            dtype=self.dtype, device=self.device
        )
        qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin), dim=-1).to(
            dtype=self.dtype, device=self.device
        )
        q_inter_cache = torch.cat((q_inter_cos, q_inter_sin), dim=-1).to(
            dtype=self.dtype, device=self.device
        )
        return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply dual-chunk RoPE.

        Returns:
            ``(query, key)`` where ``query`` is the concatenation (last dim)
            of the five rotated query variants (intra, succ, inter,
            succ-critical, inter-critical) to simplify the interfaces.
        """
        query = query.view(*query.shape[:-1], -1, self.head_size)
        key = key.view(*key.shape[:-1], -1, self.head_size)
        query_rot = query[..., : self.rotary_dim]
        key_rot = key[..., : self.rotary_dim]
        if self.rotary_dim < self.head_size:
            query_pass = query[..., self.rotary_dim :]
            key_pass = key[..., self.rotary_dim :]
        else:
            query_pass = None
            key_pass = None

        positions_with_offsets = (
            torch.add(positions, offsets) if offsets is not None else positions
        )
        key = self._apply_rotary_embedding(
            self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass
        )
        chunk_len = self.chunk_size - self.local_size
        query = self._apply_rotary_embedding(
            self.cos_sin_q_cache[positions_with_offsets % chunk_len],
            query_rot,
            query_pass,
        )
        query_succ = self._apply_rotary_embedding(
            self.cos_sin_qc_cache[positions_with_offsets % chunk_len],
            query_rot,
            query_pass,
        )
        # Inter-chunk queries all use the last (clamped) position of qc.
        query_inter = self._apply_rotary_embedding(
            self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1),
            query_rot,
            query_pass,
        )
        query_succ_critical = self._apply_rotary_embedding(
            self.cos_sin_qc_no_clamp_cache[positions_with_offsets % chunk_len],
            query_rot,
            query_pass,
        )
        query_inter_critical = self._apply_rotary_embedding(
            self.cos_sin_q_inter_cache[positions_with_offsets % chunk_len],
            query_rot,
            query_pass,
        )

        # merge query into one tensor to simplify the interfaces
        query = torch.cat(
            (
                query,
                query_succ,
                query_inter,
                query_succ_critical,
                query_inter_critical,
            ),
            dim=-1,
        )
        return query, key

    def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass):
        """Rotate ``hidden_rot`` by the given cos/sin table and re-append
        ``hidden_pass`` (the non-rotated tail) when rotary_dim < head_size."""
        cos, sin = cos_sin.chunk(2, dim=-1)
        if self.is_neox_style:
            # NOTE(woosuk): Here we assume that the positions tensor has the
            # shape [batch_size, seq_len].
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
        hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin

        if self.rotary_dim < self.head_size:
            hidden = torch.cat((hidden_rot, hidden_pass), dim=-1)
        else:
            hidden = hidden_rot
        return hidden.flatten(-2).squeeze(0)

    def extra_repr(self) -> str:
        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
        s += f", max_position_embeddings={self.max_position_embeddings}"
        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
        s += f", chunk_size={self.chunk_size}, local_size={self.local_size}"
        return s
dual_chunk_attention_config is not None: + extra_kwargs = { + k: v + for k, v in dual_chunk_attention_config.items() + if k in ("chunk_size", "local_size") + } + rotary_emb = DualChunkRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + **extra_kwargs, + ) + elif rope_scaling is None: rotary_emb = RotaryEmbedding( head_size, rotary_dim, max_position, base, is_neox_style, dtype ) @@ -1211,6 +1548,27 @@ def get_rope( high_freq_factor, original_max_position, ) + elif scaling_type == "default": + if "mrope_section" in rope_scaling and "aiter_rope_fused_qknorm" in rope_scaling: + rotary_emb = MRotaryEmbeddingQKNormFused( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + mrope_section=rope_scaling["mrope_section"], + mrope_interleaved=rope_scaling["mrope_interleaved"] if "mrope_interleaved" in rope_scaling else False + ) + else: + rotary_emb = RotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + ) elif scaling_type == "linear": rotary_emb = LinearScalingRotaryEmbedding( head_size,