diff --git a/vllm_gaudi/attention/backends/hpu_attn.py b/vllm_gaudi/attention/backends/hpu_attn.py index 724a83cae2..7f4d56f3b7 100644 --- a/vllm_gaudi/attention/backends/hpu_attn.py +++ b/vllm_gaudi/attention/backends/hpu_attn.py @@ -209,6 +209,7 @@ def __init__( qk_head_dim: int, v_head_dim: int, kv_b_proj: ColumnParallelLinear, + sinks: Optional[torch.Tensor] = None, **kwargs, ) -> None: torch.nn.Module.__init__(self) @@ -274,6 +275,11 @@ def __init__( "encoder/decoder cross-attention " "are not implemented for " "TritonMLAImpl") + self.sinks = sinks + if sinks is not None: + assert sinks.shape[0] == num_heads, ("Sinks must have the same number of heads as the number of " + f"heads in the layer. Sinks shape: {sinks.shape}, " + f"num_heads: {num_heads}.") def forward( self, @@ -450,6 +456,7 @@ def __init__( attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, use_irope: bool = False, + sinks: Optional[torch.Tensor] = None, ) -> None: super(AttentionImpl, self).__init__() self.kv_sharing_target_layer_name = kv_sharing_target_layer_name @@ -516,6 +523,11 @@ def __init__( raise NotImplementedError("Encoder self-attention " "is not implemented for " "HPUAttentionImpl") + self.sinks = sinks + if sinks is not None: + assert sinks.shape[0] == num_heads, ("Sinks must have the same number of heads as the number of " + f"heads in the layer. Sinks shape: {sinks.shape}, " + f"num_heads: {num_heads}.") self.is_chunked_attention = False @@ -597,6 +609,12 @@ def forward( if kv_cache is not None and isinstance(kv_cache, tuple): key_cache, value_cache, k_scales, v_scales = \ HPUPagedAttention.split_kv_cache(kv_cache, self.num_kv_heads, self.head_size) + if key.dtype == torch.float32 and key.dtype != key_cache.dtype: + key = key.to(key_cache.dtype) + if key.dtype == torch.float32 and value.dtype != value_cache.dtype: + value = value.to(value_cache.dtype) + if query.dtype != key.dtype: + query = query.to(key.dtype) if self.kv_sharing_target_layer_name is None: # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are @@ -735,6 +753,7 @@ def common_attention_args(self, 'key_cache': key_cache, 'value_cache': value_cache, 'block_size': block_size, + "sinks": self.sinks, 'k_scales': k_scales, 'v_scales': v_scales, } diff --git a/vllm_gaudi/extension/ops.py b/vllm_gaudi/extension/ops.py index 831d516487..efe50242d2 100644 --- a/vllm_gaudi/extension/ops.py +++ b/vllm_gaudi/extension/ops.py @@ -67,8 +67,8 @@ def matmul_shape(lhs, rhs): return result -def pipelined_pa(attn, value, block_bias, block_groups, block_mapping, batch_size, matmul_av_op, batch2block_matmul_op, - block2batch_matmul_op): +def pipelined_pa(attn, value, block_bias, block_groups, block_mapping, sink, batch_size, matmul_av_op, + batch2block_matmul_op, block2batch_matmul_op): # When fp32_softmax is enabled attn is left in fp32 after Q@K # We can return to native dtype after we renormalize and calculate the adjustments if block_bias is not None and attn.dtype != block_bias.dtype: @@ -82,11 +82,29 @@ def pipelined_pa(attn, value, block_bias, block_groups, block_mapping, batch_siz if block_bias is not None: attn.add_(block_bias) block_max = attn.amax(dim=-1, keepdim=True) + if sink is not None: + block_max = torch.maximum(block_max, sink) attn = attn.sub(block_max) attn = attn.exp() if attn.dtype == torch.float32: attn = attn.to(value.dtype) - block_sums = attn.sum(dim=-1, keepdim=True) + if sink is None: + block_sums = attn.sum(dim=-1, keepdim=True) + else: + attn_shape = attn.shape + block_sums = attn.view(-1, attn_shape[-1]).sum(dim=-1, keepdim=True) + attn_shape = list(attn_shape) + attn_shape[-1] = 1 + block_sums = block_sums.view(attn_shape) + attn_sink = sink.sub(block_max) + attn_sink = attn_sink.exp() + if attn_sink.dtype == torch.float32: + attn_sink = attn_sink.to(value.dtype) + #TODO: Removing this .sum and using attn_sink directly + #results in wrong output which does not make sense. + #Looks like a Synapse issue, need to investigate further. + block_sums_sink = attn_sink.sum(dim=-1, keepdim=True) + block_sums = block_sums + block_sums_sink attn = matmul_av_op(attn, value) if get_config().fused_block_softmax_adjustment: out_shape = list(attn.shape[:3]) + [1] * (attn.dim() - 3) @@ -166,6 +184,7 @@ def flat_pa_mla(query, key_cache, value_cache, block_list, block_mapping, block_ block_bias, block_groups, block_mapping, + None, batch_size=batch_size, matmul_av_op=matmul_av_op, batch2block_matmul_op=batch2block_matmul_op, @@ -179,7 +198,7 @@ def flat_pa_mla(query, key_cache, value_cache, block_list, block_mapping, block_ def flat_pa(query, key_cache, value_cache, block_list, block_mapping, block_bias, block_groups, block_size, scale, matmul_qk_op, position_bias, matmul_av_op, batch2block_matmul_op, block2batch_matmul_op, keys_fetch_func, - values_fetch_func, k_scales, v_scales, **ignored_args): + values_fetch_func, sinks, k_scales, v_scales, **ignored_args): batch_size, _, hidden_size = query.shape _, kv_heads, head_size = key_cache.shape q_heads = hidden_size // head_size @@ -197,6 +216,13 @@ def flat_pa(query, key_cache, value_cache, block_list, block_mapping, block_bias value = values_fetch_func(value_cache.unflatten(0, (-1, block_size)), **get_kv_fetch_extra_args(blocks=block_list, scales=v_scales_uf)).transpose(1, 2) block_bias = block_bias.view(key.size(0), 1, 1, -1) + sink = None + if sinks is not None: + sinks = sinks.reshape(sinks.shape[0], 1) + sink = sinks.reshape(1, sinks.shape[0], 1, sinks.shape[1]) + sink = sink.expand(query.shape[0], -1, query.shape[-2], -1) + if kv_heads != q_heads: + sink = sink.unflatten(1, (kv_heads, -1)) if kv_heads != q_heads: query = query.unflatten(1, (kv_heads, -1)) key = key.unflatten(1, (kv_heads, 1)) @@ -234,6 +260,7 @@ def flat_pa(query, key_cache, value_cache, block_list, block_mapping, block_bias block_bias, block_groups, block_mapping, + sink, batch_size=batch_size, matmul_av_op=matmul_av_op, batch2block_matmul_op=batch2block_matmul_op, @@ -292,6 +319,7 @@ def _naive_prompt_attention(query: torch.Tensor, matmul_qk_op=torch.matmul, softmax_op=torch.softmax, matmul_av_op=torch.matmul, + sinks: Optional[torch.Tensor] = None, **ignored_args) -> torch.Tensor: query = query.transpose(1, 2) key = key.transpose(1, 2) @@ -323,10 +351,19 @@ def _naive_prompt_attention(query: torch.Tensor, if attn_weights.dtype != attn_bias.dtype: attn_bias = attn_bias.to(dtype=attn_weights.dtype) attn_weights.add_(attn_bias) + if sinks is not None: + sink = sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) + if query_heads != kv_heads: + sink = sink.unflatten(1, (kv_heads, -1)) + combined_logits = torch.cat([attn_weights, sink], dim=-1) + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + attn_weights = combined_logits if get_config().fp32_softmax: attn_weights = torch.softmax(attn_weights, dim=-1) else: attn_weights = softmax_op(attn_weights, dim=-1) + if sinks is not None: + attn_weights = attn_weights[..., :-1] attn_weights = attn_weights.to(query.dtype) attn_weights = matmul_av_op(attn_weights, value) @@ -345,6 +382,7 @@ def _fsdpa_prompt_attention(query: torch.Tensor, attn_bias: Optional[torch.Tensor] = None, valid_seq_lengths: Optional[torch.Tensor] = None, window_size: Optional[int] = None, + sinks: Optional[torch.Tensor] = None, **ignored_args) -> torch.Tensor: query = query.transpose(1, 2) key = key.transpose(1, 2) @@ -366,10 +404,18 @@ def _fsdpa_prompt_attention(query: torch.Tensor, query, key, value, attn_bias, 0.0, is_causal, scale, softmax_mode, recompute_mode, valid_seq_lengths, padding_side ] - args += [window_size] if window_size else [] + if sinks is not None: + args += [window_size] if window_size else [None] + else: + args += [window_size] if window_size else [] + # use sinks in fsdpa + if sinks is not None: + args += [sinks] attn_weights = fsdpa_op(*args) attn_weights = attn_weights.transpose(1, 2) + if sinks is not None: + htcore.mark_step() return attn_weights @@ -498,6 +544,9 @@ def __init__(self): def set_weight(self, w): self.weight = w + def set_bias(self, b): + self.bias = b + def forward(self, state, expert_id, w): raise NotImplementedError() @@ -509,12 +558,14 @@ def __init__(self, num_total_experts: int, experts_min: int = 0, experts_max: int = 8, + bias=None, dispatch_fn: Callable[[torch.Tensor], torch.Tensor] = None): super().__init__() self.experts_min = experts_min self.experts_max = experts_max self.global_num_experts = global_num_experts self.num_experts = num_total_experts + self.bias = bias if MAX_EXPERTS_PER_SLICE > 0: max_expert_per_slice = MAX_EXPERTS_PER_SLICE @@ -566,8 +617,9 @@ def __init__(self, num_total_experts: int, experts_min: int = 0, experts_max: int = 8, + bias=None, dispatch_fn: Callable[[torch.Tensor], torch.Tensor] = None): - super().__init__(global_num_experts, num_total_experts, experts_min, experts_max, dispatch_fn) + super().__init__(global_num_experts, num_total_experts, experts_min, experts_max, bias, dispatch_fn) self.w13_list = torch.nn.ModuleList([MoeMatmul() for _ in range(num_total_experts)]) self.w2_list = torch.nn.ModuleList([MoeMatmul() for _ in range(num_total_experts)]) @@ -580,31 +632,62 @@ def forward(self, hidden_states, expert_routing_table, router_weights, permuted_ w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range] if self.moe_n_slice == 1: - return torch.ops.hpu.mixture_of_experts(hidden_states=hidden_states, - expert_routing_table=expert_routing_table, - router_weights=router_weights, - w12=w1_list, - w3=w2_list, - permuted_weights=permuted_weights, - activation=activation, - experts_min=self.experts_min, - experts_max=self.experts_max, - **kwargs) + if self.bias is not None: + w1_bias_list = [self.w13_list[i].bias.squeeze() for i in experts_range] + w2_bias_list = [self.w2_list[i].bias.squeeze() for i in experts_range] + return torch.ops.hpu.mixture_of_experts.bias_fused_weights(hidden_states=hidden_states, + expert_routing_table=expert_routing_table, + router_weights=router_weights, + w12=w1_list, + w3=w2_list, + w12_bias=w1_bias_list, + w3_bias=w2_bias_list, + permuted_weights=permuted_weights, + experts_min=self.experts_min, + experts_max=self.experts_max) + else: + return torch.ops.hpu.mixture_of_experts(hidden_states=hidden_states, + expert_routing_table=expert_routing_table, + router_weights=router_weights, + w12=w1_list, + w3=w2_list, + permuted_weights=permuted_weights, + activation=activation, + experts_min=self.experts_min, + experts_max=self.experts_max, + **kwargs) for i in range(self.moe_n_slice): w1_list_slice = w1_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group] w2_list_slice = w2_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group] min_expert = self.experts_min + i * self.num_expert_per_group max_expert = min_expert + self.num_expert_per_group - 1 - slice_final_hidden_states = torch.ops.hpu.mixture_of_experts(hidden_states=hidden_states, - expert_routing_table=expert_routing_table, - router_weights=router_weights, - w12=w1_list_slice, - w3=w2_list_slice, - permuted_weights=permuted_weights, - activation=activation, - experts_min=min_expert, - experts_max=max_expert, - **kwargs) + if self.bias is not None: + w1_bias_list = [self.w13_list[i].bias.squeeze() for i in experts_range] + w2_bias_list = [self.w2_list[i].bias.squeeze() for i in experts_range] + w1_bias_list_slice = w1_bias_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group] + w2_bias_list_slice = w2_bias_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group] + slice_final_hidden_states = torch.ops.hpu.mixture_of_experts.bias_fused_weights( + hidden_states=hidden_states, + expert_routing_table=expert_routing_table, + router_weights=router_weights, + w12=w1_list, + w3=w2_list, + w12_bias=w1_bias_list_slice, + w3_bias=w2_bias_list_slice, + permuted_weights=permuted_weights, + experts_min=self.experts_min, + experts_max=self.experts_max) + else: + slice_final_hidden_states = torch.ops.hpu.mixture_of_experts(hidden_states=hidden_states, + expert_routing_table=expert_routing_table, + router_weights=router_weights, + w12=w1_list_slice, + w3=w2_list_slice, + permuted_weights=permuted_weights, + activation=activation, + experts_min=min_expert, + experts_max=max_expert, + **kwargs) if i == 0: final_hidden_states = slice_final_hidden_states else: diff --git a/vllm_gaudi/extension/utils.py b/vllm_gaudi/extension/utils.py index 4ae9ce4248..47fe5deb49 100644 --- a/vllm_gaudi/extension/utils.py +++ b/vllm_gaudi/extension/utils.py @@ -170,14 +170,16 @@ def forward( valid_sequence_lengths, padding_side="left", window_size=None, + sinks=None, ): if window_size is not None: return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode, recompute_mode, valid_sequence_lengths, padding_side, False, False, - window_size) + window_size, sinks) else: return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode, - recompute_mode, valid_sequence_lengths, padding_side) + recompute_mode, valid_sequence_lengths, padding_side, False, False, + (-1, -1), sinks) class ModuleFP8FusedSDPA(torch.nn.Module): diff --git a/vllm_gaudi/ops/hpu_fused_moe.py b/vllm_gaudi/ops/hpu_fused_moe.py index 747e2bc187..985d18df8e 100755 --- a/vllm_gaudi/ops/hpu_fused_moe.py +++ b/vllm_gaudi/ops/hpu_fused_moe.py @@ -3,6 +3,7 @@ import torch import vllm +from vllm.config import get_current_vllm_config from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, UnquantizedFusedMoEMethod) from vllm_gaudi.extension.ops import (VllmMixtureOfExpertsOp) from vllm_gaudi.extension.runtime import get_config @@ -18,6 +19,11 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.use_dispatch_fn = get_config().use_dispatch_fn torch.hpu.synchronize() + vllm_config = get_current_vllm_config() + self.model_type = None + if vllm_config is not None and vllm_config.model_config is not None \ + and vllm_config.model_config.hf_config is not None: + self.model_type = vllm_config.model_config.hf_config.model_type @property def is_monolithic(self) -> bool: @@ -28,6 +34,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # custom handling for HPU num_experts = layer.local_num_experts ep_shift = layer.ep_rank * num_experts + has_bias = hasattr(layer, 'w13_bias') and hasattr(layer, 'w2_bias') experts_min, experts_max = ep_shift, num_experts + ep_shift - 1 @@ -36,17 +43,16 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: else: dispatch_fn = None - layer.moe_op = VllmMixtureOfExpertsOp( - layer.global_num_experts, - num_experts, - experts_min, - experts_max, - dispatch_fn, - ) + bias = has_bias if has_bias is True else None + layer.moe_op = VllmMixtureOfExpertsOp(layer.global_num_experts, num_experts, experts_min, experts_max, bias, + dispatch_fn) for expert_id in range(layer.local_num_experts): layer.moe_op.w13_list[expert_id].set_weight(layer.w13_weight.data[expert_id]) layer.moe_op.w2_list[expert_id].set_weight(layer.w2_weight.data[expert_id]) + if has_bias: + layer.moe_op.w13_list[expert_id].set_bias(layer.w13_bias.data[expert_id]) + layer.moe_op.w2_list[expert_id].set_bias(layer.w2_bias.data[expert_id]) def apply_monolithic( self, @@ -110,9 +116,13 @@ def forward_oot( topk_weights, topk_ids = layer.router.select_experts(hidden_states=x, router_logits=router_logits) else: import torch.nn.functional as F - topk_weights = F.softmax(router_logits, dim=1, dtype=torch.float32) - topk_weights, topk_ids = torch.topk(topk_weights, layer.top_k, dim=-1) - topk_weights /= topk_weights.sum(dim=-1, keepdim=True) + if self.model_type is not None and self.model_type in ["gpt_oss"]: + topk_weights, topk_ids = torch.topk(router_logits, layer.top_k, dim=-1) + topk_weights = F.softmax(topk_weights, dim=-1, dtype=torch.float32) + else: + topk_weights = F.softmax(router_logits, dim=1, dtype=torch.float32) + topk_weights, topk_ids = torch.topk(topk_weights, layer.top_k, dim=-1) + topk_weights /= topk_weights.sum(dim=-1, keepdim=True) topk_weights = topk_weights.to(x.dtype) if not layer.use_grouped_topk: @@ -134,6 +144,15 @@ def forward_oot( topk_ids = topk_ids.view(-1, topk_ids.shape[-1]) topk_weights = topk_weights.view(-1, topk_weights.shape[-1]) + if self.model_type in ["gpt_oss"]: + return layer.moe_op( + x, + topk_ids.to(torch.int64), + topk_weights.to(x.dtype), + permuted_weights=True, + activation=layer.activation, + ).view(*input_shape) + output = layer.moe_op( x, topk_ids, diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 592aef4f87..5fcafe777c 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -1550,6 +1550,19 @@ def is_decoder_only(self, req_id) -> bool: return bool(req_id in self.input_batch.req_type and \ self.input_batch.req_type[req_id] == "decode") + def _get_model_type(self) -> Optional[str]: + """ + Safely extract the model type from vllm_config. + + Returns: + The model type string if available, None otherwise. + """ + if (self.vllm_config is not None and self.vllm_config.model_config is not None + and self.vllm_config.model_config.hf_config is not None): + + return self.vllm_config.model_config.hf_config.model_type + return None + def _get_num_decodes(self) -> int: num_reqs = self.input_batch.num_reqs assert num_reqs > 0 @@ -2434,6 +2447,12 @@ def _create_decode_input_data(self, if self.interleaved_sliding_window: sliding_block_size = (self.sliding_window // self.block_size) + + # Adjust sliding block size for specific model types + model_type = self._get_model_type() + if model_type is not None and model_type in ["gpt_oss"]: + sliding_block_size += 1 + window_block_tables = [block_table[-sliding_block_size:] for block_table in block_tables_list] window_block_list, window_block_groups, window_block_usage = \ self.get_habana_paged_attn_buffers(