64 changes: 64 additions & 0 deletions 65691.pbs111.OU
@@ -0,0 +1,64 @@
/var/spool/pbs/mom_priv/jobs/65691.pbs111.SC: line 10: deactivate: command not found
wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: runjiachen (runjiachen-nus). Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.18.7
wandb: Run data is saved locally in /home/users/nus/e1113744/native-sparse-attention-pytorch/wandb/run-20250707_181358-5mxyov4r
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run summer-butterfly-2
wandb: ⭐️ View project at https://wandb.ai/runjiachen-nus/native-sparse-attention
wandb: 🚀 View run at https://wandb.ai/runjiachen-nus/native-sparse-attention/runs/5mxyov4r
wandb: WARNING Calling wandb.run.save without any arguments is deprecated.Changes to attributes are automatically persisted.
training: 0%| | 0/100000 [00:00<?, ?it/s]training: 0%| | 0/100000 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/home/users/nus/e1113744/native-sparse-attention-pytorch/train.py", line 160, in <module>
loss = model(data, return_loss = True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/users/nus/e1113744/llm-foundry/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/users/nus/e1113744/llm-foundry/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/users/nus/e1113744/native-sparse-attention-pytorch/native_sparse_attention_pytorch/transformer.py", line 308, in forward
attn_out, layer_cache = attn(
^^^^^
File "/home/users/nus/e1113744/llm-foundry/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/users/nus/e1113744/llm-foundry/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Attention.forward() got an unexpected keyword argument 'cache'

real 0m28.958s
user 0m12.499s
sys 0m5.281s
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Resource Usage on 2025-07-07 18:14:13.967119:
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
JobId: 65691.pbs111
Project: 71001002
Exit Status: 1
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
NCPUs: Requested(14), Used(14)
CPU Time Used: 00:00:17
Memory: Requested(235gb), Used(1829988kb)
Vmem Used: 522802812kb
Walltime: Requested(12:00:00), Used(00:00:30)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Execution Nodes Used: (a2ap-dgx034:ngpus=1:ncpus=14:mem=246415360kb)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
GPU Duration: 34.13secs
GPU Power Consumed: 60.97W
GPU Max GPU Memory Used: 704.0MB
Memory Throughput Rate (Average): a2ap-dgx034:(gpu7:0%)
Memory Throughput Rate (Max): a2ap-dgx034:(gpu7:0%)
Memory Throughput Rate (Min): a2ap-dgx034:(gpu7:0%)
GPU SM Utilization (Average): a2ap-dgx034:(gpu7:0%)
GPU SM Utilization (Max): a2ap-dgx034:(gpu7:0%)
GPU SM Utilization (Min): a2ap-dgx034:(gpu7:0%)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Warning: All GPUs have a percentage of 0 utilisation.
GPU application profile: Idle
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

64 changes: 64 additions & 0 deletions 65693.pbs111.OU
@@ -0,0 +1,64 @@
/var/spool/pbs/mom_priv/jobs/65693.pbs111.SC: line 10: deactivate: command not found
wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: runjiachen (runjiachen-nus). Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.18.7
wandb: Run data is saved locally in /home/users/nus/e1113744/native-sparse-attention-pytorch/wandb/run-20250707_181437-cjdzttf7
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run morning-paper-3
wandb: ⭐️ View project at https://wandb.ai/runjiachen-nus/native-sparse-attention
wandb: 🚀 View run at https://wandb.ai/runjiachen-nus/native-sparse-attention/runs/cjdzttf7
wandb: WARNING Calling wandb.run.save without any arguments is deprecated.Changes to attributes are automatically persisted.
training: 0%| | 0/100000 [00:00<?, ?it/s]training: 0%| | 0/100000 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/home/users/nus/e1113744/native-sparse-attention-pytorch/train.py", line 160, in <module>
loss = model(data, return_loss = True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/users/nus/e1113744/llm-foundry/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/users/nus/e1113744/llm-foundry/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/users/nus/e1113744/native-sparse-attention-pytorch/native_sparse_attention_pytorch/transformer.py", line 308, in forward
attn_out, layer_cache = attn(
^^^^^
File "/home/users/nus/e1113744/llm-foundry/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/users/nus/e1113744/llm-foundry/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Attention.forward() got an unexpected keyword argument 'cache'

real 0m14.447s
user 0m12.093s
sys 0m3.989s
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Resource Usage on 2025-07-07 18:14:45.053803:
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
JobId: 65693.pbs111
Project: 71001002
Exit Status: 1
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
NCPUs: Requested(14), Used(14)
CPU Time Used: 00:00:16
Memory: Requested(235gb), Used(1141256kb)
Vmem Used: 521964248kb
Walltime: Requested(12:00:00), Used(00:00:16)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Execution Nodes Used: (a2ap-dgx034:ngpus=1:ncpus=14:mem=246415360kb)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
GPU Duration: 17.99secs
GPU Power Consumed: 56.66W
GPU Max GPU Memory Used: 704.0MB
Memory Throughput Rate (Average): a2ap-dgx034:(gpu7:0%)
Memory Throughput Rate (Max): a2ap-dgx034:(gpu7:0%)
Memory Throughput Rate (Min): a2ap-dgx034:(gpu7:0%)
GPU SM Utilization (Average): a2ap-dgx034:(gpu7:0%)
GPU SM Utilization (Max): a2ap-dgx034:(gpu7:0%)
GPU SM Utilization (Min): a2ap-dgx034:(gpu7:0%)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Warning: All GPUs have a percentage of 0 utilisation.
GPU application profile: Idle
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

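Both job logs above fail at the same point: `train.py` line 160 calls the model, and inside `transformer.py` line 308 the attention layer is invoked with a `cache` keyword argument (`attn_out, layer_cache = attn(...)`), which the `Attention.forward()` picked up in this environment does not accept. Training therefore aborts before the first step, which is why PBS reports 0% GPU utilisation and an idle GPU profile. The snippet below is a minimal, self-contained reproduction of that kind of signature mismatch; the class names are hypothetical stand-ins, not the repository's real modules:

import torch
from torch import nn

# Hypothetical stand-ins illustrating the TypeError in the logs above.

class OldAttention(nn.Module):
    # forward() has no `cache` parameter, like the Attention class the jobs resolved
    def forward(self, x):
        return x, None

class CacheAwareAttention(nn.Module):
    # forward() accepts (and returns) a cache, matching what the caller at
    # transformer.py:308 expects to unpack as `attn_out, layer_cache`
    def forward(self, x, cache = None):
        return x, cache

x = torch.randn(2, 8, 16)

try:
    OldAttention()(x, cache = None)
except TypeError as e:
    print(e)  # forward() got an unexpected keyword argument 'cache'

attn_out, layer_cache = CacheAwareAttention()(x, cache = None)  # unpacks cleanly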
109 changes: 109 additions & 0 deletions native_sparse_attention_pytorch/benchmark.py
@@ -0,0 +1,109 @@
import time
import torch
import numpy as np

# Import the sparse and standard attention implementations. These sibling-module
# imports resolve when benchmark.py is executed as a script, since its directory
# is then placed on sys.path.
from native_sparse_attention import SparseAttention
from transformer import Attention as StandardAttention


def benchmark_module(module, x, runs=50, warmups=5):
    """
    Benchmark forward and backward pass of a module.
    Returns forward and backward times as numpy arrays (in seconds).
    """
    # Warm-up to stabilize JIT/CUDA and trigger Triton compilation
    for _ in range(warmups):
        out = module(x)
        loss = out.sum()
        loss.backward()
        module.zero_grad()

    # Forward timing
    fwd_times = []
    for _ in range(runs):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        out = module(x)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        fwd_times.append(time.perf_counter() - t0)

    # Backward timing (the forward pass is re-run untimed to build a fresh graph)
    bwd_times = []
    for _ in range(runs):
        out = module(x)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        out.sum().backward()
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        bwd_times.append(time.perf_counter() - t0)
        module.zero_grad()

    return np.array(fwd_times), np.array(bwd_times)


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Hyperparameters
    batch_size = 128
    seq_len = 1024
    d_model = 512
    n_heads = 8
    d_head = d_model // n_heads

    # NSA-specific hyperparameters
    sliding_window_size = 32
    compress_block_size = 4
    compress_block_sliding_stride = 4
    selection_block_size = 16
    num_selected_blocks = 4

    # Create input tensor
    x = torch.randn(batch_size, seq_len, d_model, device=device)

    # Instantiate modules with identical hyperparameters
    std_attn = StandardAttention(
        dim=d_model,
        dim_head=d_head,
        heads=n_heads,
        causal=True
    ).to(device)

    nsa_attn = SparseAttention(
        d_model,
        d_head,
        n_heads,
        sliding_window_size,
        compress_block_size,
        compress_block_sliding_stride,
        selection_block_size,
        num_selected_blocks,
        use_triton_kernel=True
    ).to(device)

    # Run benchmarks
    runs = 5000
    warmups = 500
    std_fwd, std_bwd = benchmark_module(std_attn, x, runs, warmups)
    nsa_fwd, nsa_bwd = benchmark_module(nsa_attn, x, runs, warmups)

    # Report results
    print(f"{'Module':<25}{'Fwd Mean (ms)':>15}{'Fwd Std (ms)':>15}{'Bwd Mean (ms)':>15}{'Bwd Std (ms)':>15}")
    for name, fwd, bwd in [
        ("StandardAttention", std_fwd, std_bwd),
        ("SparseAttention (NSA)", nsa_fwd, nsa_bwd),
    ]:
        print(f"{name:<25}{fwd.mean()*1000:>15.3f}{fwd.std()*1000:>15.3f}{bwd.mean()*1000:>15.3f}{bwd.std()*1000:>15.3f}")

    print(nsa_attn.timer)


if __name__ == "__main__":
    main()
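As a sanity check of the timing harness itself, independent of either attention implementation, `benchmark_module` can be pointed at any module whose forward pass takes a single tensor. A minimal sketch, assuming `benchmark_module` from the file above is already in scope:

import torch
from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A trivial module stands in for the attention layers; it uses the same
# call convention (one tensor in, one tensor out).
toy = nn.Sequential(nn.Linear(512, 512), nn.GELU(), nn.Linear(512, 512)).to(device)
x = torch.randn(8, 1024, 512, device = device)

fwd, bwd = benchmark_module(toy, x, runs = 20, warmups = 5)
print(f"forward: {fwd.mean() * 1000:.3f} ms, backward: {bwd.mean() * 1000:.3f} ms")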