2 changes: 1 addition & 1 deletion deepspeed/ops/sparse_attention/__init__.py
@@ -1,4 +1,4 @@
-from .sparsity_config import SparsityConfig, DenseSparsityConfig, FixedSparsityConfig, VariableSparsityConfig, BigBirdSparsityConfig, BSLongformerSparsityConfig
+from .sparsity_config import SparsityConfig, DenseSparsityConfig, FixedSparsityConfig, VariableSparsityConfig, BigBirdSparsityConfig, BSLongformerSparsityConfig, LocalSlidingWindowSparsityConfig
from .sparse_self_attention import SparseSelfAttention
from .bert_sparse_self_attention import BertSparseSelfAttention
from .sparse_attention_utils import SparseAttentionUtils
60 changes: 60 additions & 0 deletions deepspeed/ops/sparse_attention/sparsity_config.py
@@ -681,3 +681,63 @@ def make_layout(self, seq_len):

layout = self.check_and_propagate_first_head_layout(layout)
return layout


class LocalSlidingWindowSparsityConfig(SparsityConfig):
    """Configuration class to store `Local Sliding Window` sparsity configuration - a purely local sliding-window attention pattern.
    This class extends the parent class `SparsityConfig` and customizes it for `Local` sparsity.
    """
    def __init__(self,
                 num_heads,
                 block=16,
                 num_sliding_window_blocks=3,
                 attention='unidirectional'):
        """Initialize the Local Sliding Window Sparsity Pattern Config.
        For a usage example please see: TODO DeepSpeed Sparse Transformer Tutorial
        Arguments:
             num_heads: required: an integer determining the number of attention heads of the layer.
             block: optional: an integer determining the block size. The current implementation of sparse self-attention is based on blocked sparse matrices, in which this parameter defines the size of such blocks, `Block X Block`.
             num_sliding_window_blocks: optional: an integer determining the number of blocks in the sliding local attention window.
             attention: optional: a string determining the attention type. Attention can be `unidirectional`, as in autoregressive models, in which tokens attend only to tokens that appear before them in the context; in that case the upper triangular part of the attention matrix is empty. Or it can be `bidirectional`, as in BERT, in which tokens can attend to any other tokens before or after them; then the upper triangular part of the attention matrix mirrors the lower triangular part.
        """

        super().__init__(num_heads, block)
        self.num_sliding_window_blocks = num_sliding_window_blocks
        self.attention = attention

    def set_sliding_window_layout(self, h, layout):
        """Sets the sliding local attention layout used by the given head in the sparse attention.
        Arguments:
             h: required: an integer determining the head index
             layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity layout of all heads; may not be completely set at this step
        Return:
             layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity layout of all heads in which the local sliding window layout is set
        """

        num_blocks = layout.shape[1]
        if (num_blocks < self.num_sliding_window_blocks):
            raise ValueError(
                f'Number of sliding window blocks, {self.num_sliding_window_blocks}, must be smaller than overall number of blocks in a row, {num_blocks}!'
            )

        w = self.num_sliding_window_blocks // 2
        for row in range(0, num_blocks):
            start = max(0, row - w)
            end = min(row + w + 1,
                      num_blocks) if self.attention == "bidirectional" else row + 1
            layout[h, row, start:end] = 1
        return layout

    def make_layout(self, seq_len):
        """Generates the `Local Sliding Window` sparsity layout used by each head in the sparse attention.
        Arguments:
             seq_len: required: an integer determining the sequence length of the model.
        Return:
             layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the `Local Sliding Window` sparsity layout of all heads
        """

        layout = self.setup_layout(seq_len)
        for h in range(0, self.num_layout_heads):
            layout = self.set_sliding_window_layout(h, layout)
        layout = self.check_and_propagate_first_head_layout(layout)
        return layout
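
For reference, a minimal sketch of how the new config class might be exercised once this branch is installed. The head count, block size, and 64-token sequence length below are illustrative assumptions, not part of the PR:

from deepspeed.ops.sparse_attention import LocalSlidingWindowSparsityConfig

# 2 heads, 16-token blocks, a 3-block sliding window, causal (unidirectional) masking.
config = LocalSlidingWindowSparsityConfig(num_heads=2,
                                          block=16,
                                          num_sliding_window_blocks=3,
                                          attention='unidirectional')

# A 64-token sequence gives 64 // 16 = 4 blocks per row, so the layout has shape (2, 4, 4).
layout = config.make_layout(seq_len=64)
print(layout[0])
# With a 3-block window and unidirectional attention, each block row attends to itself
# and to one block on its left:
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [0, 1, 1, 0],
#         [0, 0, 1, 1]])

With attention='bidirectional', the same window also extends one block to the right of the diagonal, mirroring the lower-triangular band shown above.
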
10 changes: 10 additions & 0 deletions tests/unit/test_sparse_attention.py
@@ -103,6 +103,16 @@ def test_bslongformersparsityconfig_module_availability():
return True


def test_localwindowsparsityconfig_module_availability():
    # NOTE: the early return below skips the import check, mirroring the other
    # availability tests in this file.
    return True
    try:
        from deepspeed.ops.sparse_attention import LocalSlidingWindowSparsityConfig
    except ImportError:
        print("LocalSlidingWindowSparsityConfig Module is not installed!")
        return False
    return True


def test_sparseselfattention_module_availability():
return True
try: