39 changes: 39 additions & 0 deletions PR_DESCRIPTION.md
@@ -0,0 +1,39 @@
# Pull Request: Add Quasar Attention and Standalone Model Implementation

## Summary
This PR introduces **Quasar Attention**, a linear attention variant derived from Kimi Delta Attention (KDA) with significant architectural optimizations and kernel refinements. Quasar achieves higher throughput and lower memory usage than the KDA baseline, particularly at long context lengths.

This PR includes:
1. **Quasar Attention Triton Kernels**: Fused chunk-wise forward and backward kernels in `fla/ops/quasar`.
2. **QuasarAttention Layer**: A standalone attention layer in `fla/layers/quasar.py`.
3. **Quasar Model**: A complete HuggingFace-compatible model implementation in `fla/models/quasar`, including `QuasarConfig`, `QuasarModel`, and `QuasarForCausalLM`.
4. **Library Integration**: Full registration of Quasar components in the `fla` library root interfaces.
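As a hedged illustration of the registration above: the new components should be importable from the library root interfaces. The import paths follow the PR description; `QuasarConfig`'s constructor arguments are not shown in this diff, so defaults are assumed.

```python
# Hypothetical usage sketch based on the component list above. `fla` with this
# PR applied may not be installed, so the imports are guarded.
try:
    from fla.layers import QuasarAttention                   # standalone attention layer
    from fla.models import QuasarConfig, QuasarForCausalLM   # HF-compatible model
    HAVE_QUASAR = True
except ImportError:
    HAVE_QUASAR = False

if HAVE_QUASAR:
    # Field names/defaults of QuasarConfig are not shown in this diff.
    model = QuasarForCausalLM(QuasarConfig())
```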

## Benchmarks
Quasar demonstrates superior hardware efficiency compared to baseline linear attention architectures.

### High-Throughput Performance
**Setup**: 8x NVIDIA B200, 2B Model, 64k Context Length

| Architecture | Throughput (Tokens/sec) |
| :--- | :--- |
| **Quasar** | **478,559** |
| Kimi Delta Attention (KDA) | 456,163 |
| Gated Delta Attention | 447,784 |

### Scaling and Memory Efficiency
**Setup**: Single NVIDIA B200, 1B Model

| Context Length | Quasar Throughput | KDA Throughput | Speedup |
| :--- | :--- | :--- | :--- |
| 16k | 123,259 tok/s | 105,052 tok/s | **+17.3%** |
| 32k | 146,828 tok/s | 110,225 tok/s | **+33.2%** |
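
The speedup column follows directly from the raw throughput figures in the table; a quick arithmetic check:

```python
# Recompute the speedup column from the table's throughput numbers.
def speedup_pct(quasar_tps: float, baseline_tps: float) -> float:
    """Relative throughput gain of Quasar over the baseline, in percent."""
    return (quasar_tps / baseline_tps - 1.0) * 100.0

print(f"16k: +{speedup_pct(123259, 105052):.1f}%")  # +17.3%
print(f"32k: +{speedup_pct(146828, 110225):.1f}%")  # +33.2%
```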

## References
- **Quasar Attention Repository**: [https://github.com/SILX-LABS/quasar-attention](https://github.com/SILX-LABS/quasar-attention)
- **Official Release**: Quasar Attention improves on KDA by optimizing the gating mechanism and kernel fusion for modern GPU architectures such as Blackwell (B200).

## Implementation Details
- **Branding**: All components follow the `quasar` nomenclature to prevent symbol collisions with upstream KDA implementations.
- **Independence**: The Quasar module is self-contained, shipping its own kernels and configuration classes rather than depending on the upstream KDA implementation.
- **Compatibility**: Supports both standalone Quasar models and hybrid attention configurations within the FLA framework.
57 changes: 57 additions & 0 deletions fla/distributed_compat.py
@@ -0,0 +1,57 @@
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
"""
Centralized compatibility module for torch.distributed imports.
All distributed-related imports should go through here to handle environments
where distributed tensor APIs are not available.
"""

import torch

# DeviceMesh
try:
    from torch.distributed import DeviceMesh
except ImportError:
    try:
        from torch.distributed.device_mesh import DeviceMesh
    except ImportError:
        DeviceMesh = None

# DTensor
try:
    from torch.distributed.tensor import DTensor
except (ImportError, AttributeError):
    DTensor = None

# Replicate, Shard, distribute_module, Placement
try:
    from torch.distributed.tensor import Placement, Replicate, Shard, distribute_module
except (ImportError, AttributeError):
    Placement = Replicate = Shard = distribute_module = None

# ParallelStyle
try:
    from torch.distributed.tensor.parallel import ParallelStyle
except (ImportError, AttributeError):
    ParallelStyle = None

# Convenience flag: True only when the full distributed tensor API is available
HAS_DISTRIBUTED = all([
    DeviceMesh is not None,
    DTensor is not None,
    Placement is not None,
    Replicate is not None,
    Shard is not None,
    distribute_module is not None,
    ParallelStyle is not None,
])

__all__ = [
    'DeviceMesh',
    'DTensor',
    'Placement',
    'Replicate',
    'Shard',
    'distribute_module',
    'ParallelStyle',
    'HAS_DISTRIBUTED',
]
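
Downstream code is expected to branch on these exports rather than importing `torch.distributed` directly. A minimal consumer-side sketch, with stand-in values since neither `torch` nor `fla` is assumed here (real code would use `from fla.distributed_compat import HAS_DISTRIBUTED, DTensor`):

```python
# Stand-ins emulating the compat module's contract in a torch-free environment.
HAS_DISTRIBUTED = False  # would be fla.distributed_compat.HAS_DISTRIBUTED
DTensor = None           # would be fla.distributed_compat.DTensor

def to_local(tensor):
    """Unwrap a DTensor to its local shard when distributed APIs are present;
    pass plain tensors through unchanged."""
    if HAS_DISTRIBUTED and DTensor is not None and isinstance(tensor, DTensor):
        return tensor.to_local()
    return tensor

print(to_local(42))  # plain value passes through unchanged
```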
2 changes: 2 additions & 0 deletions fla/layers/__init__.py
@@ -22,6 +22,7 @@
from .mamba2 import Mamba2
from .mesa_net import MesaNet
from .mla import MultiheadLatentAttention
from .quasar import QuasarAttention
from .mom import MomAttention
from .multiscale_retention import MultiScaleRetention
from .nsa import NativeSparseAttention
@@ -56,6 +57,7 @@
'MultiheadLatentAttention',
'MultiScaleRetention',
'NativeSparseAttention',
'PaTHAttention',
'QuasarAttention',
'ReBasedLinearAttention',
'RodimusAttention',