| 
 | 1 | +# -*- coding: utf-8 -*-  | 
 | 2 | + | 
 | 3 | +import os  | 
 | 4 | + | 
 | 5 | +import pytest  | 
 | 6 | +import torch  | 
 | 7 | +import triton  | 
 | 8 | +import torch_npu  | 
 | 9 | +from fla.ops.common.utils import prepare_token_indices  | 
 | 10 | +from native_sparse_attention.ops.naive import naive_nsa  | 
 | 11 | +from native_sparse_attention.ops.parallel import parallel_nsa  | 
 | 12 | + | 
 | 13 | + | 
 | 14 | +def get_abs_err(x, y):  | 
 | 15 | +    return (x-y).flatten().abs().max().item()  | 
 | 16 | + | 
 | 17 | + | 
 | 18 | +def get_err_ratio(x, y):  | 
 | 19 | +    err = (x-y).flatten().square().mean().sqrt().item()  | 
 | 20 | +    base = (x).flatten().square().mean().sqrt().item()  | 
 | 21 | +    return err / base  | 
 | 22 | + | 
 | 23 | + | 
 | 24 | +def assert_close(prefix, ref, tri, ratio):  | 
 | 25 | +    msg = f"{prefix} diff: {get_abs_err(ref, tri):.6f} ratio: {get_err_ratio(ref, tri):.6f}"  | 
 | 26 | +    print(msg)  | 
 | 27 | +    assert get_err_ratio(ref, tri) < ratio, msg  | 
 | 28 | + | 
 | 29 | + | 
 | 30 | +@pytest.mark.parametrize("B", [1])  | 
 | 31 | +@pytest.mark.parametrize("T", [256, 1024, 2000])  | 
 | 32 | +@pytest.mark.parametrize("H", [4])  | 
 | 33 | +@pytest.mark.parametrize("HQ", [64])  | 
 | 34 | +@pytest.mark.parametrize("D", [100, 64])  | 
 | 35 | +@pytest.mark.parametrize("S", [16])  | 
 | 36 | +@pytest.mark.parametrize("block_size", [32])  | 
 | 37 | +@pytest.mark.parametrize("window_size", [0, 32])  | 
 | 38 | +@pytest.mark.parametrize("dtype", [torch.bfloat16])  | 
 | 39 | +@pytest.mark.parametrize("scale", [0.1])  | 
 | 40 | +def test_parallel(  | 
 | 41 | +    B: int,  | 
 | 42 | +    H: int,  | 
 | 43 | +    HQ: int,  | 
 | 44 | +    T: int,  | 
 | 45 | +    D: int,  | 
 | 46 | +    S: int,  | 
 | 47 | +    block_size: int,  | 
 | 48 | +    window_size: int,  | 
 | 49 | +    dtype: torch.dtype,  | 
 | 50 | +    scale: float  | 
 | 51 | +):  | 
 | 52 | +    torch.manual_seed(42)  | 
 | 53 | +    os.environ['TRITON_F32_DEFAULT'] = 'ieee'  | 
 | 54 | + | 
 | 55 | +    perm_q = torch.randperm(T, device='npu')  | 
 | 56 | +    perm_k = torch.randperm(T, device='npu')  | 
 | 57 | +    perm_v = torch.randperm(T, device='npu')  | 
 | 58 | +    q = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_q].view(1, T, 1, 1).expand(B, T, HQ, D).clone().requires_grad_(True)  | 
 | 59 | +    k = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_k].view(1, T, 1, 1).expand(B, T, H, D).clone().requires_grad_(True)  | 
 | 60 | +    v = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_v].view(1, T, 1, 1).expand(B, T, H, D).clone().requires_grad_(True)  | 
 | 61 | +    g_cmp = torch.rand((B, T, HQ), dtype=dtype, device='npu').requires_grad_(True)  | 
 | 62 | +    g_slc = torch.rand((B, T, HQ), dtype=dtype, device='npu').requires_grad_(True)  | 
 | 63 | +    g_swa = torch.rand((B, T, HQ), dtype=dtype, device='npu').requires_grad_(True)  | 
 | 64 | +    do = torch.randn((B, T, HQ, D), dtype=dtype, device='npu')  | 
 | 65 | + | 
 | 66 | +    block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device='npu')  | 
 | 67 | +    for b in range(B):  | 
 | 68 | +        for t in range(T):  | 
 | 69 | +            for h in range(H):  | 
 | 70 | +                i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S]  | 
 | 71 | +                block_indices[b, t, h, :len(i_i)] = i_i  | 
 | 72 | +    block_indices = block_indices.sort(-1)[0]  | 
 | 73 | +    block_counts = torch.randint(1, S + 1, (B, T, H), dtype=torch.long, device='npu')  | 
 | 74 | + | 
 | 75 | +    ref = naive_nsa(  | 
 | 76 | +        q=q,  | 
 | 77 | +        k=k,  | 
 | 78 | +        v=v,  | 
 | 79 | +        g_slc=g_slc,  | 
 | 80 | +        g_swa=g_swa,  | 
 | 81 | +        block_indices=block_indices,  | 
 | 82 | +        block_counts=block_counts,  | 
 | 83 | +        block_size=block_size,  | 
 | 84 | +        window_size=window_size,  | 
 | 85 | +        scale=scale  | 
 | 86 | +    )  | 
 | 87 | +    ref.backward(do)  | 
 | 88 | +    ref_dq, q.grad = q.grad.clone(), None  | 
 | 89 | +    ref_dk, k.grad = k.grad.clone(), None  | 
 | 90 | +    ref_dv, v.grad = v.grad.clone(), None  | 
 | 91 | +    ref_dg_slc, g_slc.grad = g_slc.grad.clone(), None  | 
 | 92 | +    if window_size > 0:  | 
 | 93 | +        ref_dg_swa, g_swa.grad = g_swa.grad.clone(), None  | 
 | 94 | + | 
 | 95 | +    tri = parallel_nsa(  | 
 | 96 | +        q=q,  | 
 | 97 | +        k=k,  | 
 | 98 | +        v=v,  | 
 | 99 | +        g_cmp=g_cmp,  | 
 | 100 | +        g_slc=g_slc,  | 
 | 101 | +        g_swa=g_swa,  | 
 | 102 | +        block_indices=block_indices,  | 
 | 103 | +        block_counts=block_counts,  | 
 | 104 | +        block_size=block_size,  | 
 | 105 | +        window_size=window_size,  | 
 | 106 | +        scale=scale  | 
 | 107 | +    )  | 
 | 108 | +    tri.backward(do)  | 
 | 109 | +    tri_dq, q.grad = q.grad.clone(), None  | 
 | 110 | +    tri_dk, k.grad = k.grad.clone(), None  | 
 | 111 | +    tri_dv, v.grad = v.grad.clone(), None  | 
 | 112 | +    tri_dg_slc, g_slc.grad = g_slc.grad.clone(), None  | 
 | 113 | +    if window_size > 0:  | 
 | 114 | +        tri_dg_swa, g_swa.grad = g_swa.grad.clone(), None  | 
 | 115 | + | 
 | 116 | +    assert_close(" o", ref, tri, 0.005)  | 
 | 117 | +    assert_close("dq", ref_dq, tri_dq, 0.005)  | 
 | 118 | +    assert_close("dk", ref_dk, tri_dk, 0.005)  | 
 | 119 | +    assert_close("dv", ref_dv, tri_dv, 0.005)  | 
 | 120 | +    assert_close("dg_slc", ref_dg_slc, tri_dg_slc, 0.005)  | 
 | 121 | +    if window_size > 0:  | 
 | 122 | +        assert_close("dg_swa", ref_dg_swa, tri_dg_swa, 0.005)  | 
 | 123 | + | 
 | 124 | + | 
 | 125 | +@pytest.mark.parametrize("N", [4])  | 
 | 126 | +@pytest.mark.parametrize("T", [64, 128, 200, 250, 256, 300, 400, 512, 1000, 2048])  | 
 | 127 | +@pytest.mark.parametrize("H", [4])  | 
 | 128 | +@pytest.mark.parametrize("HQ", [64])  | 
 | 129 | +@pytest.mark.parametrize("D", [100, 64])  | 
 | 130 | +@pytest.mark.parametrize("S", [16])  | 
 | 131 | +@pytest.mark.parametrize("block_size", [32])  | 
 | 132 | +@pytest.mark.parametrize("window_size", [0, 32])  | 
 | 133 | +@pytest.mark.parametrize("dtype", [torch.bfloat16])  | 
 | 134 | +def test_parallel_varlen(  | 
 | 135 | +    N: int,  | 
 | 136 | +    T: int,  | 
 | 137 | +    H: int,  | 
 | 138 | +    HQ: int,  | 
 | 139 | +    D: int,  | 
 | 140 | +    S: int,  | 
 | 141 | +    block_size: int,  | 
 | 142 | +    window_size: int,  | 
 | 143 | +    dtype: torch.dtype,  | 
 | 144 | +):  | 
 | 145 | +    torch.manual_seed(42)  | 
 | 146 | +    os.environ['TRITON_F32_DEFAULT'] = 'ieee'  | 
 | 147 | + | 
 | 148 | +    # randomly split the sequence into N segments  | 
 | 149 | +    offsets = torch.cat([  | 
 | 150 | +        torch.tensor([0], dtype=torch.long),  | 
 | 151 | +        torch.arange(16, T)[torch.randperm(T - 1)[:N-1]],  | 
 | 152 | +        torch.tensor([T], dtype=torch.long)  | 
 | 153 | +    ], 0).npu().sort()[0]  | 
 | 154 | +    # seq-first required for inputs with variable lengths  | 
 | 155 | +    perm_q = torch.randperm(T, device='npu')  | 
 | 156 | +    perm_k = torch.randperm(T, device='npu')  | 
 | 157 | +    perm_v = torch.randperm(T, device='npu')  | 
 | 158 | +    q = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_q].view(1, T, 1, 1).expand(1, T, HQ, D).clone().requires_grad_(True)  | 
 | 159 | +    k = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_k].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True)  | 
 | 160 | +    v = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_v].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True)  | 
 | 161 | +    g_cmp = torch.rand((1, T, HQ), dtype=dtype, device='npu').requires_grad_(True)  | 
 | 162 | +    g_slc = torch.rand((1, T, HQ), dtype=dtype, device='npu').requires_grad_(True)  | 
 | 163 | +    g_swa = torch.rand((1, T, HQ), dtype=dtype, device='npu').requires_grad_(True)  | 
 | 164 | +    do = torch.randn((1, T, HQ, D), dtype=dtype, device='npu')  | 
 | 165 | + | 
 | 166 | +    token_indices = prepare_token_indices(offsets).tolist()  | 
 | 167 | +    block_indices = torch.full((1, T, H, S), T, dtype=torch.long, device='npu')  | 
 | 168 | +    for i in range(T):  | 
 | 169 | +        _, t = token_indices[i]  | 
 | 170 | +        for h in range(H):  | 
 | 171 | +            i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S]  | 
 | 172 | +            block_indices[0, i, h, :len(i_i)] = i_i  | 
 | 173 | +    block_indices = block_indices.sort(-1)[0]  | 
 | 174 | +    block_counts = torch.randint(1, S + 1, (1, T, H), device='npu')  | 
 | 175 | + | 
 | 176 | +    ref = naive_nsa(  | 
 | 177 | +        q=q,  | 
 | 178 | +        k=k,  | 
 | 179 | +        v=v,  | 
 | 180 | +        g_slc=g_slc,  | 
 | 181 | +        g_swa=g_swa,  | 
 | 182 | +        block_indices=block_indices,  | 
 | 183 | +        block_counts=block_counts,  | 
 | 184 | +        block_size=block_size,  | 
 | 185 | +        window_size=window_size,  | 
 | 186 | +        cu_seqlens=offsets  | 
 | 187 | +    )  | 
 | 188 | +    ref.backward(do)  | 
 | 189 | +    ref_dq, q.grad = q.grad.clone(), None  | 
 | 190 | +    ref_dk, k.grad = k.grad.clone(), None  | 
 | 191 | +    ref_dv, v.grad = v.grad.clone(), None  | 
 | 192 | +    ref_dg_slc, g_slc.grad = g_slc.grad.clone(), None  | 
 | 193 | +    if window_size > 0:  | 
 | 194 | +        ref_dg_swa, g_swa.grad = g_swa.grad.clone(), None  | 
 | 195 | + | 
 | 196 | +    tri = parallel_nsa(  | 
 | 197 | +        q=q,  | 
 | 198 | +        k=k,  | 
 | 199 | +        v=v,  | 
 | 200 | +        g_cmp=g_cmp,  | 
 | 201 | +        g_slc=g_slc,  | 
 | 202 | +        g_swa=g_swa,  | 
 | 203 | +        block_indices=block_indices,  | 
 | 204 | +        block_counts=block_counts,  | 
 | 205 | +        block_size=block_size,  | 
 | 206 | +        window_size=window_size,  | 
 | 207 | +        cu_seqlens=offsets  | 
 | 208 | +    )  | 
 | 209 | +    tri.backward(do)  | 
 | 210 | +    tri_dq, q.grad = q.grad.clone(), None  | 
 | 211 | +    tri_dk, k.grad = k.grad.clone(), None  | 
 | 212 | +    tri_dv, v.grad = v.grad.clone(), None  | 
 | 213 | +    tri_dg_slc, g_slc.grad = g_slc.grad.clone(), None  | 
 | 214 | +    if window_size > 0:  | 
 | 215 | +        tri_dg_swa, g_swa.grad = g_swa.grad.clone(), None  | 
 | 216 | + | 
 | 217 | +    assert_close(" o", ref, tri, 0.004)  | 
 | 218 | +    assert_close("dq", ref_dq, tri_dq, 0.005)  | 
 | 219 | +    assert_close("dk", ref_dk, tri_dk, 0.005)  | 
 | 220 | +    assert_close("dv", ref_dv, tri_dv, 0.005)  | 
 | 221 | +    assert_close("dg_slc", ref_dg_slc, tri_dg_slc, 0.005)  | 
 | 222 | +    if window_size > 0:  | 
 | 223 | +        assert_close("dg_swa", ref_dg_swa, tri_dg_swa, 0.005)  | 
0 commit comments