92 changes: 39 additions & 53 deletions tensorrt_llm/_torch/attention_backend/sparse/dsa.py
@@ -674,6 +674,11 @@ def _scale(weights: torch.Tensor, q_scale: torch.Tensor,
return weights * q_scale.squeeze(-1) * s


@maybe_compile(dynamic=True)
def _to_float(hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states.float()


class Indexer(nn.Module):

def __init__(self,
@@ -715,7 +720,7 @@ def __init__(self,
self.hidden_size,
self.n_heads,
bias=False,
dtype=dtype,
dtype=torch.float32,
quant_config=None,
skip_create_weights_in_init=skip_create_weights_in_init,
use_custom_cublas_mm=True)
@@ -1234,82 +1239,63 @@ def sparse_attn_indexer(
dtype=torch.int32)
return topk_indices_buffer

def weight_scale(self, hidden_states: torch.Tensor,
indexer_weights: Optional[torch.Tensor],
q_scale: torch.Tensor) -> torch.Tensor:
weights = indexer_weights if indexer_weights is not None else self.weights_proj(
hidden_states)
def _weight_scale(self, weights: torch.Tensor,
q_scale: torch.Tensor) -> torch.Tensor:
weights = _scale(weights, q_scale, self.weight_scale_factor)
return weights

def _qk_projection_and_rope(self, qr: torch.Tensor, indexer_k: torch.Tensor,
position_ids: torch.Tensor):
"""Project Q/K and apply RoPE"""
q = self.wq_b(qr)
k = self.k_norm(indexer_k)
q = q.view(-1, self.n_heads, self.head_dim)
q_pe, q_nope = q.split([self.rope_dim, self.head_dim - self.rope_dim],
dim=-1)
k_pe, k_nope = k.split([self.rope_dim, self.head_dim - self.rope_dim],
dim=-1)
q_pe, k_pe = self.rotary_emb(position_ids, [q_pe, k_pe.unsqueeze(1)])
k_pe = k_pe[:, 0, :]
return q_pe, q_nope, k_pe, k_nope

def _prep_q_or_k(self, qk_pe: torch.Tensor, qk_nope: torch.Tensor):
"""Concatenate, rotate, and FP8 quantize for Q or K"""
q_or_k = torch.cat([qk_pe, qk_nope], dim=-1)
q_or_k = rotate_activation(q_or_k)
q_or_k = q_or_k.view(-1, self.head_dim)
q_or_k = fp8_utils.fp8_quantize_1x128_sf_transpose(
q_or_k, use_ue8m0=self.scale_fmt == "ue8m0")
return q_or_k

@torch.inference_mode()
def forward(self, qr: torch.Tensor, hidden_states: torch.Tensor,
metadata: DSAtrtllmAttentionMetadata,
position_ids: torch.Tensor, indexer_k: Optional[torch.Tensor],
indexer_weights: Optional[torch.Tensor]):
position_ids: torch.Tensor, indexer_k: torch.Tensor):
quant_block_size = metadata.kv_cache_manager.quant_block_size
assert quant_block_size == 128, "Only support quant_block_size = 128 for now"

if indexer_k is not None:
q, k = maybe_execute_in_parallel(
lambda: self.wq_b(
qr), # TODO: fuse wq_b and move this outside of the indexer
lambda: self.k_norm(indexer_k),
self.ln_events[0],
self.ln_events[1],
self.aux_stream,
)
else:
q, k = maybe_execute_in_parallel(
lambda: self.wq_b(qr),
lambda: self.k_norm(self.wk(hidden_states)),
self.ln_events[0],
self.ln_events[1],
self.aux_stream,
)

# q/k rope + possible fast_hadamard_transform
q = q.view(-1, self.n_heads, self.head_dim)

q, k = maybe_execute_in_parallel(
lambda: torch.split(
q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1),
lambda: torch.split(
k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1),
q_and_k, weights = maybe_execute_in_parallel(
lambda: self._qk_projection_and_rope(qr, indexer_k, position_ids),
lambda: self.weights_proj(_to_float(hidden_states)),
self.ln_events[0],
self.ln_events[1],
self.aux_stream,
)

q_pe, q_nope = q
k_pe, k_nope = k
q_pe, k_pe = self.rotary_emb(position_ids, [q_pe, k_pe.unsqueeze(1)])

k_pe = k_pe[:, 0, :]

def _prep_q_or_k(qk_pe, qk_nope):
q_or_k = torch.cat([qk_pe, qk_nope], dim=-1)
q_or_k = rotate_activation(q_or_k)
q_or_k = q_or_k.view(-1, self.head_dim)
q_or_k = fp8_utils.fp8_quantize_1x128_sf_transpose(
q_or_k, use_ue8m0=self.scale_fmt == "ue8m0")
return q_or_k

q_pe, q_nope, k_pe, k_nope = q_and_k
q, k = maybe_execute_in_parallel(
lambda: _prep_q_or_k(q_pe, q_nope),
lambda: _prep_q_or_k(k_pe, k_nope),
lambda: self._prep_q_or_k(q_pe, q_nope),
lambda: self._prep_q_or_k(k_pe, k_nope),
self.ln_events[0],
self.ln_events[1],
self.aux_stream,
)

q_fp8, q_scale = q
k_fp8, k_scale = k
q_fp8 = q_fp8.view(-1, self.n_heads, self.head_dim)
q_scale = q_scale.view(-1, self.n_heads, 1)

weights, _ = maybe_execute_in_parallel(
lambda: self.weight_scale(hidden_states, indexer_weights, q_scale),
lambda: self._weight_scale(weights, q_scale),
lambda: self._update_k_cache(
k_fp8, k_scale, metadata), # store k_fp8 and k_scale in k cache
self.ln_events[0],
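The rewritten Indexer.forward above overlaps two independent chunks of work with maybe_execute_in_parallel: the Q/K projection plus RoPE path (_qk_projection_and_rope) on one stream and the float32 weights_proj GEMM on another, and later the Q-side and K-side FP8 preparation (_prep_q_or_k). The sketch below shows, under stated assumptions, what such a two-stream overlap helper can look like; the helper name run_in_parallel and its wiring are illustrative only, not the actual TensorRT-LLM maybe_execute_in_parallel implementation.

import torch

def run_in_parallel(fn_a, fn_b, event_a, event_b, aux_stream):
    # Run fn_a on the current (main) stream and fn_b on aux_stream, then make the
    # main stream wait for aux_stream before either result is consumed.
    main_stream = torch.cuda.current_stream()
    event_a.record(main_stream)          # aux work may begin once inputs are ready
    result_a = fn_a()                    # executes on the main stream
    with torch.cuda.stream(aux_stream):
        aux_stream.wait_event(event_a)
        result_b = fn_b()                # executes concurrently on the aux stream
        event_b.record(aux_stream)
    main_stream.wait_event(event_b)      # safe to consume result_b afterwards
    return result_a, result_b

# Example wiring (CUDA required; the lambdas stand in for the projections above):
#   aux = torch.cuda.Stream()
#   e0, e1 = torch.cuda.Event(), torch.cuda.Event()
#   q_and_k, weights = run_in_parallel(lambda: qk_proj_and_rope(), lambda: weights_proj(x.float()), e0, e1, aux)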
80 changes: 42 additions & 38 deletions tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -362,9 +362,10 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
fused_a_scale = torch.cat(
[q_a_proj_scale, fused_a_scale], dim=0)

module.weight_scale.data.copy_(fused_a_scale)
# For DeepseekV32 with fuse_a_indexer_k_weight=True: kv_a_proj_with_mqa is oversized
# to include indexer weights, which is filled in post_load_weights.
module.weight_scale.data[0:fused_a_scale.
shape[0]].copy_(fused_a_scale)
# For DeepseekV32: kv_a_proj_with_mqa is oversized
# to include indexer k weights, which is filled in post_load_weights.
module.weight.data[0:fused_a.shape[0]].copy_(fused_a)
elif names[-1] in params_map:
module_weights = []
@@ -556,13 +557,6 @@ def __init__(
config = model_config.pretrained_config
predicted_tokens_per_seq = model_config.spec_config.max_total_draft_tokens + 1 if model_config.spec_config is not None else 1

# DSV3.2 nvfp4 ckpt has kv_a_proj_with_mqa module in bfloat16
# TODO: check it more directly/robustly, e.g., indexer_weight_quant == fuseA_quant == indexer_quant
if model_config.get_quant_config().quant_algo == QuantAlgo.NVFP4:
self.fuse_a_indexer_k_weight = True
else:
self.fuse_a_indexer_k_weight = False

super().__init__(hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
@@ -586,36 +580,46 @@

self.indexer = self.mqa.indexer

if self.fuse_a_indexer_k_weight:
# For DeepseekV32, the kv_a_proj_with_mqa includes:
# q_a_proj + kv_a_proj_with_mqa + indexer.wk + indexer.weights_proj
self.kv_a_proj_with_mqa = DeepseekV3Linear(
config.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim + self.q_lora_rank +
self.indexer.head_dim + self.indexer.n_heads,
bias=False,
dtype=config.torch_dtype,
quant_config=model_config.get_quant_config(),
skip_create_weights_in_init=model_config.
skip_create_weights_in_init,
use_custom_cublas_mm=True)
# For DeepseekV32, the kv_a_proj_with_mqa includes:
# q_a_proj + kv_a_proj_with_mqa + indexer.wk
self.kv_a_proj_with_mqa = DeepseekV3Linear(
config.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim + self.q_lora_rank +
self.indexer.head_dim,
bias=False,
dtype=config.torch_dtype,
quant_config=model_config.get_quant_config(),
skip_create_weights_in_init=model_config.
skip_create_weights_in_init,
use_custom_cublas_mm=True)

def post_load_weights(self):
if self.fuse_a_indexer_k_weight:
assert self.kv_a_proj_with_mqa.weight.data.dtype == self.indexer.wk.weight.data.dtype == self.indexer.weights_proj.weight.data.dtype, "all weights in kv_a_proj_with_mqa module must have matching dtype"
# Copy indexer weights into the fused kv_a_proj_with_mqa module
indexer_wk_weight = self.indexer.wk.weight.data
offset = self.kv_lora_rank + self.qk_rope_head_dim + self.q_lora_rank
self.kv_a_proj_with_mqa.weight.data[offset:offset +
self.indexer.head_dim].copy_(
indexer_wk_weight)
offset += self.indexer.head_dim
indexer_weights_proj_weight = self.indexer.weights_proj.weight.data
self.kv_a_proj_with_mqa.weight.data[offset:offset +
self.indexer.n_heads].copy_(
indexer_weights_proj_weight)
self.indexer.wk = None
self.indexer.weights_proj = None
"""
Concatenate the indexer.wk weights onto kv_a_proj_with_mqa's output dimension, fusing the indexer.wk projection into the kv_a_proj_with_mqa GEMM.
"""
assert self.kv_a_proj_with_mqa.weight.data.dtype == self.indexer.wk.weight.data.dtype, "all weights in kv_a_proj_with_mqa module must have matching dtype"
# Copy indexer weights into the fused kv_a_proj_with_mqa module
indexer_wk_weight = self.indexer.wk.weight.data
offset = self.kv_lora_rank + self.qk_rope_head_dim + self.q_lora_rank
self.kv_a_proj_with_mqa.weight.data[offset:offset +
self.indexer.head_dim].copy_(
indexer_wk_weight)

# Copy indexer scale data if it exists
if hasattr(self.indexer.wk,
'weight_scale') and self.indexer.wk.weight_scale is not None:
indexer_wk_scale = self.indexer.wk.weight_scale.data
assert self.kv_a_proj_with_mqa.weight_scale.dim(
) == 2, "weight_scale must be a 2D tensor"
group_size = self.kv_a_proj_with_mqa.weight.shape[
0] // self.kv_a_proj_with_mqa.weight_scale.shape[0]
scale_offset = offset // group_size
scale_size = indexer_wk_scale.shape[0]
# Copy indexer scale to the corresponding position in the fused module
self.kv_a_proj_with_mqa.weight_scale.data[
scale_offset:scale_offset + scale_size].copy_(indexer_wk_scale)

self.indexer.wk = None


class Deepseekv3RoutingImpl():
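The post_load_weights change above copies indexer.wk's rows into the oversized fused kv_a_proj_with_mqa weight at row offset kv_lora_rank + qk_rope_head_dim + q_lora_rank, and its block scales at offset // group_size. The following sketch reproduces that bookkeeping with made-up sizes chosen so the block arithmetic divides evenly; these are not the real DeepSeek-V3.2 dimensions.

import torch

hidden_size, block = 512, 128
q_lora_rank, kv_lora_rank, qk_rope_head_dim, indexer_head_dim = 256, 192, 64, 128

out_rows = kv_lora_rank + qk_rope_head_dim + q_lora_rank + indexer_head_dim   # 640
fused_weight = torch.zeros(out_rows, hidden_size)
fused_scale = torch.zeros(out_rows // block, hidden_size // block)            # shape (5, 4)

indexer_wk_weight = torch.randn(indexer_head_dim, hidden_size)
indexer_wk_scale = torch.randn(indexer_head_dim // block, hidden_size // block)

# Rows that precede the indexer.wk block inside the fused projection.
offset = kv_lora_rank + qk_rope_head_dim + q_lora_rank                        # 512
fused_weight[offset:offset + indexer_head_dim].copy_(indexer_wk_weight)

# Each scale row covers group_size weight rows, so the scale offset is offset // group_size.
group_size = fused_weight.shape[0] // fused_scale.shape[0]                    # 128
scale_offset = offset // group_size                                           # 4
fused_scale[scale_offset:scale_offset + indexer_wk_scale.shape[0]].copy_(indexer_wk_scale)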
19 changes: 5 additions & 14 deletions tensorrt_llm/_torch/modules/attention.py
@@ -1221,19 +1221,11 @@ def forward_impl_with_dsa(self, position_ids: Optional[torch.Tensor],
if position_ids is not None:
position_ids = position_ids[..., :num_tokens]

if self.fuse_a_indexer_k_weight:
q, compressed_kv, k_pe, indexer_k, indexer_weights = self.kv_a_proj_with_mqa(
hidden_states).split([
self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim,
self.indexer.head_dim, self.indexer.n_heads
], -1)
else:
q, compressed_kv, k_pe = self.kv_a_proj_with_mqa(
hidden_states).split([
self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim
], -1)
indexer_k = None
indexer_weights = None
q, compressed_kv, k_pe, indexer_k = self.kv_a_proj_with_mqa(
hidden_states).split([
self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim,
self.indexer.head_dim
], -1)

# TODO: possibly overlap/fuse q_a_rmsnorm + kv_a_rmsnorm + indexer.k_layernorm?
q, compressed_kv = maybe_execute_in_parallel(
@@ -1255,7 +1247,6 @@ def forward_impl_with_dsa(self, position_ids: Optional[torch.Tensor],
attn_metadata,
position_ids,
indexer_k=indexer_k, # indexer K proj
indexer_weights=indexer_weights, # indexer weights proj
)

assert q.shape[
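On the consumer side (diff above), forward_impl_with_dsa now always runs a single fused GEMM and splits its output into q, compressed_kv, k_pe and indexer_k. A minimal sketch of that split with placeholder sizes (not the real model config) follows.

import torch

num_tokens, hidden_size = 4, 512
q_lora_rank, kv_lora_rank, qk_rope_head_dim, indexer_head_dim = 256, 192, 64, 128

# Stand-in for the fused kv_a_proj_with_mqa projection (indexer.wk rows appended last).
fused_proj = torch.nn.Linear(
    hidden_size,
    q_lora_rank + kv_lora_rank + qk_rope_head_dim + indexer_head_dim,
    bias=False)

hidden_states = torch.randn(num_tokens, hidden_size)
q, compressed_kv, k_pe, indexer_k = fused_proj(hidden_states).split(
    [q_lora_rank, kv_lora_rank, qk_rope_head_dim, indexer_head_dim], dim=-1)
# Resulting shapes: q (4, 256), compressed_kv (4, 192), k_pe (4, 64), indexer_k (4, 128)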
@@ -681,8 +681,7 @@ def yarn_get_mscale(scale=1, mscale=1):
hidden_states,
attn_metadata,
position_ids,
None, # indexer_k
None, # indexer_weights
indexer_k=mla.mqa.indexer.wk(hidden_states), # indexer_k
)

# Validate indexer output against expected causal indices (since seq_len < topk=2048)
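The truncated test hunk above validates the indexer output against "expected causal indices (since seq_len < topk=2048)". As a hedged illustration of what that expectation means (this is not the actual test code), when the sequence is shorter than topk a causal sparse indexer should select every preceding position, so the expected indices for token i are simply 0..i, with the remainder padded (padding value chosen arbitrarily here).

import torch

def expected_causal_indices(seq_len: int, topk: int, pad: int = -1) -> torch.Tensor:
    # One row per token; row i lists the positions a causal indexer may attend to.
    width = min(topk, seq_len)
    idx = torch.full((seq_len, width), pad, dtype=torch.int32)
    for i in range(seq_len):
        idx[i, :i + 1] = torch.arange(i + 1, dtype=torch.int32)
    return idx

# expected_causal_indices(4, 2048) ->
# [[0, -1, -1, -1],
#  [0,  1, -1, -1],
#  [0,  1,  2, -1],
#  [0,  1,  2,  3]]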