tensorrt_llm/_torch/attention_backend/sparse/dsa.py (17 additions, 7 deletions)
@@ -17,6 +17,7 @@
maybe_execute_in_parallel
from tensorrt_llm._torch.modules.rotary_embedding import RotaryEmbedding
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm._torch.utils import maybe_compile
from tensorrt_llm._utils import get_size_in_bytes
from tensorrt_llm.bindings import DataType
from tensorrt_llm.bindings.executor import KvCacheConfig
@@ -572,6 +573,12 @@ def update_for_spec_dec(self):
self.on_update_kv_lens()


@maybe_compile(dynamic=True)
def _scale(weights: torch.Tensor, q_scale: torch.Tensor,
s: float) -> torch.Tensor:
return weights * q_scale.squeeze(-1) * s


class Indexer(nn.Module):

def __init__(self,
@@ -964,9 +971,6 @@ def sparse_attn_indexer(
if not use_custom_topk:
topk_indices_buffer[:hidden_states.shape[0]] = -1

# Store k_fp8 and k_scale into indexer k cache
self._update_k_cache(k_fp8, k_scale, metadata)

if has_prefill:
# Use chunked prefill to reduce memory footprint
if metadata.indexer_prefill_chunks is not None:
@@ -1121,9 +1125,7 @@ def weight_scale(self, hidden_states: torch.Tensor,
q_scale: torch.Tensor) -> torch.Tensor:
weights = indexer_weights if indexer_weights is not None else self.weights_proj(
hidden_states)
weights = weights.unsqueeze(-1) * q_scale * self.weight_scale_factor
# output weights is guaranteed to be float32 due to type promotion from q_scale (float32)
weights = weights.squeeze(-1)
weights = _scale(weights, q_scale, self.weight_scale_factor)
return weights

@torch.inference_mode()
@@ -1192,7 +1194,15 @@ def _prep_q_or_k(qk_pe, qk_nope):
q_fp8 = q_fp8.view(-1, self.n_heads, self.head_dim)
q_scale = q_scale.view(-1, self.n_heads, 1)

weights = self.weight_scale(hidden_states, indexer_weights, q_scale)
weights, _ = maybe_execute_in_parallel(
lambda: self.weight_scale(hidden_states, indexer_weights, q_scale),
lambda: self._update_k_cache(
k_fp8, k_scale, metadata), # store k_fp8 and k_scale in k cache
self.ln_events[0],
self.ln_events[1],
self.aux_stream,
)

# Return topk indices buffer for sparse attention [num_tokens, index_topk]
return self.sparse_attn_indexer(metadata, hidden_states, q_fp8, k_fp8,
k_scale, weights)
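
Note on the dsa.py changes above: the three-line weight scaling in weight_scale is replaced by the single fused helper _scale, which is wrapped in maybe_compile(dynamic=True) so it is torch.compile'd outside of piecewise execution, and the indexer k-cache write (previously done inside sparse_attn_indexer) is now issued alongside weight_scale through maybe_execute_in_parallel on the auxiliary stream. The two scaling formulations are mathematically identical; a standalone sketch with hypothetical shapes that checks the equivalence:

import torch

def _scale_old(weights: torch.Tensor, q_scale: torch.Tensor, s: float) -> torch.Tensor:
    # Previous formulation: unsqueeze to broadcast against q_scale, then squeeze back.
    return (weights.unsqueeze(-1) * q_scale * s).squeeze(-1)

def _scale_new(weights: torch.Tensor, q_scale: torch.Tensor, s: float) -> torch.Tensor:
    # New formulation from this diff: same math, one fewer reshape pair,
    # with float32 output via type promotion from q_scale.
    return weights * q_scale.squeeze(-1) * s

w = torch.randn(4, 8)       # [num_tokens, n_heads] (hypothetical sizes)
qs = torch.rand(4, 8, 1)    # [num_tokens, n_heads, 1], float32
torch.testing.assert_close(_scale_old(w, qs, 0.5), _scale_new(w, qs, 0.5))

Because the cache update moved out of sparse_attn_indexer, callers that invoke sparse_attn_indexer directly must populate the indexer k cache first; see the test changes at the bottom of this diff.
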
tensorrt_llm/_torch/modules/attention.py (1 addition, 12 deletions)
@@ -23,7 +23,7 @@
from ..model_config import ModelConfig
from ..peft.lora.layer import LoraLayer, LoraModuleType
from ..utils import (Fp4QuantizedTensor, get_model_extra_attrs,
is_piecewise_running, is_torch_compiling)
is_torch_compiling, maybe_compile)
from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
from .multi_stream_utils import maybe_execute_in_parallel
from .rms_norm import RMSNorm
@@ -76,17 +76,6 @@ def extract_extra_attrs(layer_idx: str, attn_type: str):
return metadata, attn_layer


def maybe_compile(func):

def wrapper(*args, **kwargs):
if is_piecewise_running():
# When piecewise running, we don't need to compile the function to avoid host overhead in attention op.
return func(*args, **kwargs)
return torch.compile(func)(*args, **kwargs)

return wrapper


@maybe_compile
def maybe_compiled_copy_(dst, src):
dst.copy_(src)
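
With the local wrapper removed, attention.py now imports the shared maybe_compile from tensorrt_llm._torch.utils, and the existing bare-decorator call site (maybe_compiled_copy_) keeps working unchanged. A minimal sketch of that call site in use, with dst and src as illustrative tensors:

import torch
from tensorrt_llm._torch.modules.attention import maybe_compiled_copy_

dst = torch.empty(8)
src = torch.arange(8, dtype=torch.float32)
maybe_compiled_copy_(dst, src)  # in-place copy; compiled unless piecewise running
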
tensorrt_llm/_torch/utils.py (23 additions, 0 deletions)
@@ -325,3 +325,26 @@ def get_device_uuid(device_idx: int) -> str:
property = torch.cuda.get_device_properties(device_idx)
uuid = "GPU-" + str(property.uuid)
return uuid


def maybe_compile(func=None, **compile_kwargs):
"""
Conditionally compile a function with torch.compile.
    If is_piecewise_running() is True, the function is not compiled, to avoid host overhead in the attention op.
Args:
func: The function to decorate (optional, for direct decoration).
**compile_kwargs: Keyword arguments for torch.compile.
Returns:
        The conditionally compiled function.
"""

def decorator(f):

def wrapper(*args, **kwargs):
if is_piecewise_running():
return f(*args, **kwargs)
return torch.compile(f, **compile_kwargs)(*args, **kwargs)

return wrapper

return decorator(func) if func else decorator
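
A usage sketch of the new decorator (illustrative function names; torch.compile kwargs are forwarded as described in the docstring):

import torch
from tensorrt_llm._torch.utils import maybe_compile

@maybe_compile                   # bare form: the function itself is passed as func
def add_one(x: torch.Tensor) -> torch.Tensor:
    return x + 1

@maybe_compile(dynamic=True)     # parameterized form, as used by _scale in dsa.py
def scale(x: torch.Tensor, s: float) -> torch.Tensor:
    return x * s

# Both fall back to eager execution when is_piecewise_running() is True,
# and otherwise run through torch.compile.
y = add_one(torch.ones(4))
z = scale(torch.ones(4), 0.5)

This keeps the bare @maybe_compile form used by maybe_compiled_copy_ in attention.py working while letting new call sites pass torch.compile options such as dynamic=True.
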
tests/unittest/_torch/attention/sparse/test_dsa_indexer.py (2 additions, 0 deletions)
@@ -1175,6 +1175,7 @@ def test_indexer_chunked_prefill(chunk_size, seq_lens_list, chunking_type):
f" Chunk {i}: Q[{chunk.token_start}:{chunk.token_end}] ({num_q} tokens), "
f"K[{chunk.k_token_start}:{chunk.k_token_end}] ({num_k} tokens)")

indexer._update_k_cache(k_fp8, k_scale, metadata_chunked)
topk_indices_chunked = indexer.sparse_attn_indexer(metadata_chunked,
hidden_states, q_fp8,
k_fp8, k_scale, weights)
@@ -1206,6 +1207,7 @@ def test_indexer_chunked_prefill(chunk_size, seq_lens_list, chunking_type):
f"✓ Created {num_baseline_chunks} chunk(s) (effectively non-chunked)"
)

indexer._update_k_cache(k_fp8, k_scale, metadata_baseline)
topk_indices_baseline = indexer.sparse_attn_indexer(metadata_baseline,
hidden_states, q_fp8,
k_fp8, k_scale, weights)
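
These two test additions follow from the k-cache change in dsa.py: sparse_attn_indexer no longer writes k_fp8/k_scale into the indexer k cache itself, so tests that call it directly now populate the cache first. The call pattern used in both the chunked and baseline paths (metadata stands in for metadata_chunked or metadata_baseline) is:

# Populate the indexer k cache (previously done inside sparse_attn_indexer),
# then run top-k index selection against it.
indexer._update_k_cache(k_fp8, k_scale, metadata)
topk_indices = indexer.sparse_attn_indexer(metadata, hidden_states, q_fp8,
                                           k_fp8, k_scale, weights)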