Closed

138 commits
0123240
init orangina
Jun 25, 2025
1b6d426
trivial modification
jychen21 Jun 25, 2025
b28605c
debugging weight loading
jychen21 Jun 27, 2025
966bab4
fix o_proj linear error
jychen21 Jun 27, 2025
8c545d9
fix config compatiblity & RMSNorm failed with no kernel is available
jychen21 Jun 29, 2025
918e3c0
rename
jychen21 Jun 29, 2025
0bba216
add sliding window attn every other layer & fix attn sinks weight loa…
jychen21 Jun 29, 2025
87d81a6
update
jychen21 Jun 29, 2025
81bb3bf
use silu as a WA
jychen21 Jun 29, 2025
8febdd6
moe bias support (not finished)
jychen21 Jun 30, 2025
606961e
fix moe intermediate size configuration to match intermediate size
xutizhou Jul 2, 2025
d1b159a
fix the problem that mlp weight is not really loaded, and mlp bias su…
linhu-nv Jul 3, 2025
fab80db
Merge branch 'linhu/dev' into 'feat/orangina'
linhu-nv Jul 3, 2025
884666a
add structure
xutizhou Jul 3, 2025
fc37426
to display
xutizhou Jul 3, 2025
2355609
rm model sstructure
xutizhou Jul 3, 2025
c49f8ee
Update OpenAIMoeAttention to set sinks parameter dtype to bfloat16
xutizhou Jul 3, 2025
c426b23
Refactor OpenAIMoeAttention by removing RMSNorm normalization for que…
xutizhou Jul 3, 2025
4eb4680
Add a TODO comment to ensure correct sliding window size for flashinf…
xutizhou Jul 3, 2025
f4ab9cb
mark attn o_proj all reduce
xutizhou Jul 3, 2025
ee3c851
Add bias support to FusedMoE and related quantization methods
xutizhou Jul 3, 2025
7958fa3
Add TODO comment to indicate future replacement of gate with router i…
xutizhou Jul 3, 2025
5bb606f
fix rope yarn mismatch issue
jychen21 Jul 3, 2025
76c206d
Merge remote-tracking branch 'refs/remotes/origin/feat/orangina' into…
xutizhou Jul 4, 2025
c9fdd37
Merge branch 'feat/orangina' into 'feat/orangina'
xutizhou Jul 4, 2025
51f9a92
Fix naming mismatch for gate and router parameters in OpenAIMoeForCau…
xutizhou Jul 4, 2025
4abddbc
Update comment to clarify naming convention for gate and router param…
xutizhou Jul 4, 2025
9b9320d
Merge branch 'feat/orangina' into 'feat/orangina'
xutizhou Jul 4, 2025
314bf1a
Bug fix: FusedMoE expert weight loader can not load weights
jychen21 Jul 4, 2025
3f0bbd4
bf16 fusedmoe integration
zhuofan1123 Jul 4, 2025
00ae587
Merge branch 'moe' into 'feat/orangina'
zhuofan1123 Jul 4, 2025
d129e5b
add accuracy test
jychen21 Jul 6, 2025
ca3061f
trivial code changes
jychen21 Jul 7, 2025
63af97e
add one prompt test
jychen21 Jul 7, 2025
e9fcf71
loop import issue solved even we dont change the import codes
linhu-nv Jul 9, 2025
209247a
Merge branch 'linhu/dev' into 'feat/orangina'
linhu-nv Jul 9, 2025
38b2d2e
add mxfp4 triton api
zhuofan1123 Jul 8, 2025
da5fc63
Merge branch 'moe' into 'feat/orangina'
zhuofan1123 Jul 9, 2025
358347c
Enhance FlashInfer backend with attention sink support and add relate…
xutizhou Jul 9, 2025
ccfa992
Merge remote-tracking branch 'upstream/feat/orangina' into feat/atten…
xutizhou Jul 9, 2025
1322f26
Update OpenAIMoeAttention to incorporate attention sink parameter in …
xutizhou Jul 9, 2025
b5c0eb4
Refactor OpenAIMoeAttention to use plural 'sinks' for clarity in atte…
xutizhou Jul 9, 2025
1a61c44
Add enable_attention_sink parameter to OpenAIMoeAttention initializat…
xutizhou Jul 9, 2025
c528f22
Add sink attention mechanism to OpenAIMoe and Qwen3Moe models, introd…
xutizhou Jul 9, 2025
5c3f293
support mxfp4 moe
zhuofan1123 Jul 9, 2025
b4ac555
pad weight for Hopper
zhuofan1123 Jul 9, 2025
6bb7fa2
add args for mxfp4
zhuofan1123 Jul 10, 2025
842f3e6
Merge branch 'moe' into 'feat/orangina'
zhuofan1123 Jul 10, 2025
f4faaae
make continuous after transpose
zhuofan1123 Jul 10, 2025
4b0d768
Update d_rcp calculation in FlashInfer backend to incorporate exponen…
xutizhou Jul 10, 2025
43355f8
Refactor OpenAIMoeAttention to utilize precomputed sink values, simpl…
xutizhou Jul 10, 2025
4a57691
Merge remote-tracking branch 'upstream/feat/orangina' into feat/atten…
xutizhou Jul 10, 2025
b558084
Add debug prints for prompt and generated output in throughput test
xutizhou Jul 10, 2025
3a78899
Remove unused sink_softmax and sink_attention_ref functions from Qwen…
xutizhou Jul 10, 2025
ecca06e
Remove unused imports from openai_moe.py to clean up the codebase and…
xutizhou Jul 10, 2025
452d567
Accuracy fix: Attn using reference sdpa impl as a WA
jychen21 Jul 10, 2025
c9e075a
Add a TODO comment to OpenAIMoeAttention regarding potential exponent…
xutizhou Jul 11, 2025
a2c3376
Add tensor logging functionality in FlashInfer backend to track input…
xutizhou Jul 11, 2025
70c28c7
Refactor OpenAIMoeAttention to implement sink attention mechanism, en…
xutizhou Jul 11, 2025
5b8e09d
Refactor debug logging in OpenAIMoeAttention to always print layer_id…
xutizhou Jul 11, 2025
f1ba131
Update OpenAIMoeAttention to conditionally log tensor shapes based on…
xutizhou Jul 11, 2025
ce0ec2a
Refactor FlashInfer backend to simplify window size calculation by re…
xutizhou Jul 11, 2025
32fa9ee
Update tolerance levels in OpenAIMoeAttention tensor comparison to im…
xutizhou Jul 11, 2025
27a562f
Merge remote-tracking branch 'upstream/feat/orangina' into feat/atten…
xutizhou Jul 11, 2025
cd9a5b1
Add sdpa function to openai_moe.py for enhanced attention mechanism w…
xutizhou Jul 14, 2025
281b18e
Add new openai_moe.py file and implement QK attention calculation in …
xutizhou Jul 14, 2025
1d00b7e
Update sliding_window_size parameter in OpenAIMoeAttention to ensure …
xutizhou Jul 14, 2025
300f14f
Add flashinfer_attention_ref method to OpenAIMoeAttention for improve…
xutizhou Jul 14, 2025
b322437
Refactor flashinfer_attention_ref method in OpenAIMoeAttention to acc…
xutizhou Jul 14, 2025
55840e6
Update OpenAIMoeAttention to check if sinks need fp32 before exp and …
xutizhou Jul 14, 2025
399feed
debugging acc & bug fix
jychen21 Jul 14, 2025
dc9187e
Refactor attention output calculation in OpenAIMoeAttention by renami…
xutizhou Jul 14, 2025
fdff13c
Merge branch 'feat/attention_sink_final' into 'feat/orangina'
xutizhou Jul 14, 2025
9c08480
Update sliding_window handling in OpenAIMoeAttention to default to -1…
xutizhou Jul 15, 2025
d29a62c
Merge branch 'feat/attention_sink_final' into 'feat/orangina'
xutizhou Jul 15, 2025
467ca00
Update sliding_window handling in OpenAIMoeAttention to default to -1…
xutizhou Jul 15, 2025
6d3c212
Merge branch 'feat/attention_sink_final' into 'feat/orangina'
xutizhou Jul 15, 2025
2401d39
Add key tokens into promt for wrong detokenizing
jychen21 Jul 15, 2025
d847171
Implement torch native attention version supporting both sink and sli…
jychen21 Jul 16, 2025
7ddc192
remove sdpa ref cause decode phase can not simply use this, use 'torc…
jychen21 Jul 16, 2025
8f87d2e
remove sdpa ref cause decode phase can not simply use this, use 'torc…
jychen21 Jul 16, 2025
4f87f27
disable shuffle for pre-final weights
zhuofan1123 Jul 16, 2025
2e99f16
First e2e accuracy test, verified on gsm8k(0.735) and mmlu(0.828)
jychen21 Jul 16, 2025
92ebb1b
uncomment mmlu acc target assertion
jychen21 Jul 16, 2025
1d507a4
fix mxfp4 for tp
zhuofan1123 Jul 16, 2025
72eea4f
renaming model to gpt-oss as recommended
jychen21 Jul 17, 2025
b3bcf23
Refactor weight processing in UnquantizedFusedMoEMethodOpenAI to remo…
xutizhou Jul 18, 2025
2a33809
Refactor weight and bias parameter handling in FusedMoE to streamline…
xutizhou Jul 18, 2025
eef6505
Add two modes for SwiGLU act (chunk / pairwise)
jychen21 Jul 20, 2025
2c8e56d
remove WA in layernorm forward_cuda, just call layernorm forward_nati…
jychen21 Jul 20, 2025
930b974
Enhance FlashInfer attention mechanism by adjusting window handling a…
xutizhou Jul 21, 2025
d7c40d7
Merge remote-tracking branch 'upstream/feat/orangina' into feat/orangina
xutizhou Jul 21, 2025
07375b3
sliding_window remove -1 for torch native impl
jychen21 Jul 21, 2025
3e38c7c
Refactor attention sink handling in OpenAIMoeAttention to conditional…
xutizhou Jul 21, 2025
40c3220
Merge remote-tracking branch 'upstream/feat/orangina' into feat/orangina
xutizhou Jul 21, 2025
ede293e
load weight for pair-wise act
zhuofan1123 Jul 21, 2025
519ad18
Enhance SwiGLU implementation by adding a pair_wise option for tensor…
xutizhou Jul 21, 2025
0f110f0
Refactor FlashInfer attention backend to utilize layer.attention_sink…
xutizhou Jul 21, 2025
e41fda4
Refactor sink_attention_ref function to improve handling of query and…
xutizhou Jul 21, 2025
9b2fd12
fix torch native backend bug
xutizhou Jul 21, 2025
2ca2c16
Remove workaround for orangina in RMSNorm and add pair_wise_act param…
xutizhou Jul 21, 2025
d66f0f3
Refactor get_attention_sliding_window_size function to simplify logic…
xutizhou Jul 21, 2025
1513244
update acc test serving args
jychen21 Jul 22, 2025
b74b149
add clamp limit
zhuofan1123 Jul 24, 2025
a4f7c2e
fix flashinfer accuracy bug
xutizhou Jul 25, 2025
010093a
Merge branch 'main' into feat/orangina
jychen21 Jul 25, 2025
ce0ab79
Merge branch 'feat/orangina-backup' into feat/orangina
jychen21 Jul 25, 2025
eba1e69
fix code format
jychen21 Jul 25, 2025
e270628
Tune accuracy test params: temperature1.0 top_p1.0 top_k0.0, add chat…
jychen21 Jul 27, 2025
9b5b80b
Refactor version assertion logic and enhance SWAChunkCache eviction h…
xutizhou Jul 31, 2025
939a2b8
reduce mxfp4 memory usage
zhuofan1123 Jul 31, 2025
4a56bef
Update SWAChunkCache to require attention_chunk_size as an int and in…
xutizhou Jul 31, 2025
5b7169b
Merge remote-tracking branch 'upstream/feat/orangina' into feat/orangina
xutizhou Jul 31, 2025
f2a7796
Set default attention_chunk_size to 128 in model_config.py and remove…
xutizhou Jul 31, 2025
c407e58
Adapt simple eval to support orangina reasoning mode, system_message,…
jychen21 Aug 1, 2025
3e76fad
gpqa bug fix
jychen21 Aug 1, 2025
c9ad3cf
Fix: Make anser comparison case-insenstive (GPQA)
jychen21 Aug 2, 2025
6e7da63
Update default attention_chunk_size to None in model_config.py
xutizhou Aug 4, 2025
7268b8a
Update default enable_attention_sink to False in FlashInferAttnBacken…
xutizhou Aug 4, 2025
1f3aeff
support mxfp4 quant config
zhuofan1123 Aug 1, 2025
c322662
fix scale loading issue when tp_size>2
zhuofan1123 Aug 3, 2025
4288d18
make gemm2_output contiguous for hopper
zhuofan1123 Aug 4, 2025
3feb217
remove original moe impl
zhuofan1123 Aug 4, 2025
48f7864
add args for fp8 activation
zhuofan1123 Aug 4, 2025
7eba750
rename arg
zhuofan1123 Aug 4, 2025
99d0f75
remove checkpoint_weights_transposed
zhuofan1123 Aug 4, 2025
022b646
Merge branch 'quant' into 'feat/orangina'
zhuofan1123 Aug 4, 2025
2c56ec8
rename to gpt_oss
zhuofan1123 Aug 4, 2025
8205914
Add Triton kernel to set tensor to zero and update scale initializati…
xutizhou Aug 4, 2025
46653db
fix flashinfer + cuda graph
PerkzZheng Aug 4, 2025
0835c26
remove file
zhuofan1123 Aug 4, 2025
6661e8c
Update sink parameter type to float32 and remove unused flashinfer code
xutizhou Aug 4, 2025
d16cdd4
Merge remote-tracking branch 'upstream/feat/orangina' into feat/orangina
xutizhou Aug 4, 2025
9982849
recover QUERY_TEMPLATE_MULTICHOICE and ANSWER_PATTERN_MULTICHOICE
jychen21 Aug 4, 2025
e4db1ba
Merge remote-tracking branch 'github/main' into final_rebase
xutizhou Aug 5, 2025
1972bb1
[refactor] Update imports and enhance deepep mode handling in GptOssM…
xutizhou Aug 5, 2025
8ac1978
Update SamplerResponse
jychen21 Aug 5, 2025
6640f5c
Clean up
Aug 5, 2025
4 changes: 4 additions & 0 deletions .gitmodules
@@ -0,0 +1,4 @@
[submodule "3rdparty/triton"]
path = 3rdparty/triton
url = https://github.com/dongfengy/triton.git
branch = fused_moe_triton_0613
Comment on lines +1 to +4
high

The submodule for Triton points to a personal fork (dongfengy/triton) and a specific branch (fused_moe_triton_0613). This introduces a dependency on a personal repository, which can be a maintenance and security risk.

It's highly recommended to use an official repository or a fork under the project's organization to ensure stability and long-term maintenance. If this is a temporary measure for development, it should be replaced before merging into the main branch.
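The fix the reviewer suggests amounts to editing the submodule's `url` in `.gitmodules` and re-syncing. A minimal sketch (the `example-org` fork is a placeholder, not a real repository; it reproduces the `.gitmodules` entry from the diff in a throwaway repo and rewrites the URL with `git config -f`):

```shell
# Sketch: repoint the triton submodule at an org-owned fork instead of a
# personal one. "example-org" is a hypothetical placeholder.
cd "$(mktemp -d)" && git init -q .
# Recreate the .gitmodules entry from the diff above:
printf '[submodule "3rdparty/triton"]\n\tpath = 3rdparty/triton\n\turl = https://github.com/dongfengy/triton.git\n' > .gitmodules
# Rewrite the URL in place, then verify:
git config -f .gitmodules submodule."3rdparty/triton".url https://github.com/example-org/triton.git
grep url .gitmodules
```

In a real checkout, `git submodule sync -- 3rdparty/triton` would then propagate the new URL to the cloned submodule's config.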

1 change: 1 addition & 0 deletions 3rdparty/triton
Submodule triton added at aa9743
2 changes: 2 additions & 0 deletions python/sglang/bench_offline_throughput.py
@@ -234,6 +234,8 @@ def throughput_test_once(

st = time.perf_counter()
gen_out = backend.generate(prompt=prompt, sampling_params=sampling_params)
print(f"prompt: {prompt}")
print(f"gen_out: {gen_out}")
Comment on lines +237 to +238
medium

These print statements appear to be for debugging purposes. They should be removed from the benchmark script to avoid polluting the output.
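If the prompt/output dumps are still useful during bring-up, a common alternative to deleting them outright is gating them behind an environment variable so normal benchmark runs stay clean. A sketch (the `SGLANG_BENCH_DEBUG` flag name is hypothetical, not an existing sglang option):

```python
import io
import os
from contextlib import redirect_stdout

def debug_print(label, value):
    # Print only when explicitly enabled via a (hypothetical) env var,
    # keeping benchmark output clean by default.
    if os.environ.get("SGLANG_BENCH_DEBUG") == "1":
        print(f"{label}: {value}")

buf = io.StringIO()
with redirect_stdout(buf):
    debug_print("prompt", "hello")  # silent: flag is unset

os.environ["SGLANG_BENCH_DEBUG"] = "1"
with redirect_stdout(buf):
    debug_print("prompt", "hello")  # now emitted

assert buf.getvalue() == "prompt: hello\n"
```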

latency = time.perf_counter() - st

if profile:
15 changes: 15 additions & 0 deletions python/sglang/srt/layers/activation.py
@@ -138,6 +138,21 @@ def forward_hip(self, x: torch.Tensor) -> torch.Tensor:
return out


class SwiGLU(CustomOp):
def forward_native(self, x: torch.Tensor, alpha: float = 1.702, pair_wise: bool = True) -> torch.Tensor:
# reference implementation
if not pair_wise:
x_glu, x_linear = torch.chunk(x, 2, dim=-1)
else:
x_glu, x_linear = x[..., ::2], x[..., 1::2]
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
return out_glu * (x_linear + 1)  # note: an extra bias of 1 is added to the linear term

def forward_cuda(self, x: torch.Tensor, alpha: float = 1.702, pair_wise: bool = True) -> torch.Tensor:
# TODO: Implement the CUDA kernel for SwiGLU in sgl-kernel
return self.forward_native(x, alpha, pair_wise)


class ScaledActivation(nn.Module):
"""An activation function with post-scale parameters.

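The two SwiGLU modes added above differ only in channel layout: `pair_wise=True` reads gate/linear channels interleaved (`g0, l0, g1, l1, …`), while chunk mode splits the projection into two contiguous halves (`[g | l]`). A standalone sketch of the reference math, independent of the `CustomOp` machinery, showing the two layouts agree when the input is arranged correspondingly:

```python
import torch

def swiglu(x: torch.Tensor, alpha: float = 1.702, pair_wise: bool = True) -> torch.Tensor:
    # Mirrors SwiGLU.forward_native above: pairwise reads interleaved
    # gate/linear channels, chunk mode splits into two contiguous halves.
    if not pair_wise:
        x_glu, x_linear = torch.chunk(x, 2, dim=-1)
    else:
        x_glu, x_linear = x[..., ::2], x[..., 1::2]
    # SiLU-style gate times the linear half, with the extra +1 bias.
    return x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + 1)

torch.manual_seed(0)
g = torch.randn(4, 8)  # gate channels
l = torch.randn(4, 8)  # linear channels
chunked = torch.cat([g, l], dim=-1)                    # [g | l] layout
interleaved = torch.stack([g, l], dim=-1).flatten(-2)  # g0, l0, g1, l1, ... layout
assert torch.allclose(swiglu(chunked, pair_wise=False),
                      swiglu(interleaved, pair_wise=True))
```

This is why the MoE weight loader needs a matching "load weight for pair-wise act" change (see the commit list): the checkpoint's channel ordering must match whichever mode the activation uses.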
399 changes: 335 additions & 64 deletions python/sglang/srt/layers/attention/flashinfer_backend.py

Large diffs are not rendered by default.

321 changes: 321 additions & 0 deletions python/sglang/srt/layers/attention/torch_native_backend.py
@@ -268,3 +268,324 @@ def forward_decode(

def support_triton(self):
return False


class TorchNativeAttnSinkBackend(TorchNativeAttnBackend):
def __init__(self, model_runner: ModelRunner):
super().__init__(model_runner)
self.forward_metadata = None
self.device = model_runner.device

@staticmethod
def _scaled_dot_product_attention(Q, K, V, S, scaling, sliding_window):
# sliding_window <= 0 means no sliding window
# Q: [n_tokens_q, n_heads, q_mult, d_head]
# K: [n_tokens_kv, n_heads, d_head]
# V: [n_tokens_kv, n_heads, d_head]
n_tokens_q, n_heads, q_mult, d_head = Q.shape
n_tokens_kv = K.shape[0]

assert K.shape == (n_tokens_kv, n_heads, d_head)
assert V.shape == (n_tokens_kv, n_heads, d_head)

K = K[:, :, None, :].expand(-1, -1, q_mult, -1)
V = V[:, :, None, :].expand(-1, -1, q_mult, -1)
S = S.reshape(n_heads, q_mult, 1, 1).expand(-1, -1, n_tokens_q, -1)

if n_tokens_q == n_tokens_kv: # Prefill
mask = torch.triu(
Q.new_full((n_tokens_q, n_tokens_kv), -float("inf")), diagonal=1
)
else: # Decode
mask = Q.new_zeros((n_tokens_q, n_tokens_kv))

if sliding_window is not None and sliding_window > 0:
mask += torch.tril(
mask.new_full((n_tokens_q, n_tokens_kv), -float("inf")),
diagonal=n_tokens_kv - n_tokens_q - sliding_window,
)

QK = torch.einsum("qhmd,khmd->hmqk", Q, K)
QK *= scaling
QK += mask[None, None, :, :]
QK = torch.cat([QK, S], dim=-1)

W = torch.softmax(QK, dim=-1)
W = W[..., :-1]

attn = torch.einsum("hmqk,khmd->qhmd", W, V)

return attn.reshape(n_tokens_q, -1)
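The sink mechanism above is the key difference from a plain SDPA: the sink logits `S` are concatenated as an extra "virtual key" column before the softmax, and the resulting sink column is discarded afterwards. The sink therefore absorbs probability mass without contributing a value vector, so the retained key weights sum to less than one. A minimal sketch of just that step (shapes simplified relative to the `q_mult` layout used in the backend):

```python
import torch

torch.manual_seed(0)
n_heads, n_q, n_kv = 2, 3, 5
logits = torch.randn(n_heads, n_q, n_kv)               # scaled QK^T scores
sinks = torch.randn(n_heads, 1, 1).expand(-1, n_q, 1)  # one learned sink logit per head

# Softmax over [keys, sink], then drop the sink column -- as in
# `QK = torch.cat([QK, S], dim=-1)` / `W = W[..., :-1]` above.
w = torch.softmax(torch.cat([logits, sinks], dim=-1), dim=-1)[..., :-1]

plain = torch.softmax(logits, dim=-1)
assert torch.all(w.sum(dim=-1) < 1)   # sink absorbed some mass
# Relative weights among real keys are unchanged: w is plain softmax
# rescaled by a per-row constant Z / (Z + exp(sink)).
ratio = w / plain
assert torch.allclose(ratio, ratio[..., :1].expand_as(ratio))
```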

def _run_sdpa_forward_extend(
self,
query: torch.Tensor,
output: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
req_to_token: torch.Tensor,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
extend_prefix_lens: torch.Tensor,
extend_seq_lens: torch.Tensor,
num_kv_heads: int,
q_mult: int,
scaling=None,
sliding_window=None,
attention_sinks=None,
enable_gqa=False,
causal=False,
):
"""Run the extend forward by using custom sdpa op.

Args:
query: [num_tokens, num_q_heads, head_size]
output: [num_tokens, num_q_heads, head_size]
k_cache: [max_total_num_tokens, num_kv_heads, head_size]
v_cache: [max_total_num_tokens, num_kv_heads, head_size]
req_to_token: [max_num_reqs, max_context_len]
req_pool_indices: [num_seqs]
seq_lens: [num_seqs]
extend_prefix_lens: [num_seqs]
extend_seq_lens: [num_seqs]
num_kv_heads: int
q_mult: int
scaling: float or None
sliding_window: int or None
attention_sinks: torch.Tensor or None
enable_gqa: bool
causal: bool

Returns:
output: [num_tokens, num_q_heads, head_size]
"""

assert seq_lens.shape[0] == extend_prefix_lens.shape[0]
assert seq_lens.shape[0] == extend_seq_lens.shape[0]

# [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
query = query.movedim(0, query.dim() - 2)

start_q, start_kv = 0, 0
for seq_idx in range(seq_lens.shape[0]):
# TODO: this loop processes one sequence per iteration, which is
# inefficient; optimize the performance later.

extend_seq_len_q = extend_seq_lens[seq_idx]
prefill_seq_len_q = extend_prefix_lens[seq_idx]

seq_len_kv = seq_lens[seq_idx]
end_q = start_q + extend_seq_len_q
end_kv = start_kv + seq_len_kv

per_req_query = query[:, start_q:end_q, :]
per_req_query_redudant = torch.empty(
(per_req_query.shape[0], seq_len_kv, per_req_query.shape[2]),
dtype=per_req_query.dtype,
device=per_req_query.device,
)

per_req_query_redudant[:, prefill_seq_len_q:, :] = per_req_query

# get key and value from cache. per_req_tokens contains the kv cache
# index for each token in the sequence.
req_pool_idx = req_pool_indices[seq_idx]
per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)

per_req_query_redudant = per_req_query_redudant.permute(1, 0, 2).reshape(
seq_len_kv, num_kv_heads, q_mult, per_req_query_redudant.shape[-1]
)
per_req_key = per_req_key.permute(1, 0, 2)
per_req_value = per_req_value.permute(1, 0, 2)

per_req_out_redudant = TorchNativeAttnSinkBackend._scaled_dot_product_attention(
per_req_query_redudant,
per_req_key,
per_req_value,
attention_sinks,
scaling=scaling,
sliding_window=sliding_window,
).reshape(seq_len_kv, -1, per_req_value.shape[-1])
output[start_q:end_q, :, :] = per_req_out_redudant[prefill_seq_len_q:, :, :]
start_q, start_kv = end_q, end_kv
return output

def _run_sdpa_forward_decode(
self,
query: torch.Tensor,
output: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
req_to_token: torch.Tensor,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
num_kv_heads: int,
q_mult: int,
scaling=None,
sliding_window=None,
attention_sinks=None,
enable_gqa=False,
causal=False,
):
"""Run the decode forward by using custom sdpa op.

Args:
query: [num_tokens, num_q_heads, head_size]
output: [num_tokens, num_q_heads, head_size]
k_cache: [max_total_num_tokens, num_kv_heads, head_size]
v_cache: [max_total_num_tokens, num_kv_heads, head_size]
req_to_token: [max_num_reqs, max_context_len]
req_pool_indices: [num_seqs]
seq_lens: [num_seqs]
num_kv_heads: int
q_mult: int
scaling: float or None
sliding_window: int or None
attention_sinks: torch.Tensor or None
enable_gqa: bool
causal: bool

Returns:
output: [num_tokens, num_q_heads, head_size]
"""

# [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
query = query.movedim(0, query.dim() - 2)

start_q, start_kv = 0, 0
for seq_idx in range(seq_lens.shape[0]):
# TODO: this loop processes one sequence per iteration, which is
# inefficient; optimize the performance later.

seq_len_q = 1
seq_len_kv = seq_lens[seq_idx]
end_q = start_q + seq_len_q
end_kv = start_kv + seq_len_kv

per_req_query = query[:, start_q:end_q, :]

# get key and value from cache. per_req_tokens contains the kv cache
# index for each token in the sequence.
req_pool_idx = req_pool_indices[seq_idx]
per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]

per_req_query = per_req_query.permute(1, 0, 2).reshape(
seq_len_q, num_kv_heads, q_mult, per_req_query.shape[-1]
)
per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
per_req_key = per_req_key.permute(1, 0, 2)
per_req_value = per_req_value.permute(1, 0, 2)

per_req_out = (
TorchNativeAttnSinkBackend._scaled_dot_product_attention(
per_req_query,
per_req_key,
per_req_value,
attention_sinks,
scaling=scaling,
sliding_window=sliding_window,
)
.reshape(seq_len_q, -1, per_req_value.shape[-1])
)
output[start_q:end_q, :, :] = per_req_out
start_q, start_kv = end_q, end_kv

return output

def forward_extend(
self,
q,
k,
v,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
else:
o = torch.empty_like(q)

if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)

use_gqa = layer.tp_q_head_num != layer.tp_k_head_num

q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)

causal = True
if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
causal = False

self._run_sdpa_forward_extend(
q_,
o_,
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
forward_batch.req_to_token_pool.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.extend_prefix_lens,
forward_batch.extend_seq_lens,
layer.tp_k_head_num,
layer.tp_q_head_num // layer.tp_k_head_num,
scaling=layer.scaling,
sliding_window=layer.sliding_window_size + 1, # torch native attn sink uses sliding window without -1
attention_sinks=layer.attention_sinks,
enable_gqa=use_gqa,
causal=causal,
)
return o

def forward_decode(
self,
q,
k,
v,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
# During torch.compile, there is a bug in rotary_emb that causes the
# output value to have a 3D tensor shape. This reshapes the output correctly.
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
else:
o = torch.empty_like(q)

if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)

use_gqa = layer.tp_q_head_num != layer.tp_k_head_num

q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)

self._run_sdpa_forward_decode(
q_,
o_,
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
forward_batch.req_to_token_pool.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
layer.tp_k_head_num,
layer.tp_q_head_num // layer.tp_k_head_num,
scaling=layer.scaling,
sliding_window=layer.sliding_window_size + 1, # torch native attn sink uses sliding window without -1
attention_sinks=layer.attention_sinks,
enable_gqa=use_gqa,
causal=False,
)

return o
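The `sliding_window_size + 1` passed in both forward paths above reflects an off-by-one convention: the mask in `_scaled_dot_product_attention` keeps exactly `sliding_window` keys per query row (the current token plus `sliding_window - 1` previous ones), whereas `layer.sliding_window_size` counts previous tokens excluding the current one. A small sketch reproducing just the mask logic to make the convention visible:

```python
import torch

def visible_keys(n_q: int, n_kv: int, sliding_window: int) -> torch.Tensor:
    # Reproduces the mask construction from _scaled_dot_product_attention:
    # causal mask for prefill (n_q == n_kv), no causal mask for decode,
    # plus a lower-triangular sliding-window mask when window > 0.
    if n_q == n_kv:  # prefill
        mask = torch.triu(torch.full((n_q, n_kv), float("-inf")), diagonal=1)
    else:  # decode
        mask = torch.zeros((n_q, n_kv))
    if sliding_window is not None and sliding_window > 0:
        mask += torch.tril(
            torch.full((n_q, n_kv), float("-inf")),
            diagonal=n_kv - n_q - sliding_window,
        )
    # Count the unmasked keys each query row can attend to.
    return (mask == 0).sum(dim=-1)

# With window=4, each row sees at most 4 keys (itself + 3 previous),
# i.e. a layer sliding_window_size of 3, plus one.
print(visible_keys(6, 6, 4))  # → tensor([1, 2, 3, 4, 4, 4])
```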
@@ -14,7 +14,6 @@

_config: Optional[Dict[str, Any]] = None


@contextmanager
def override_config(config):
global _config
@@ -30,6 +29,7 @@ def get_config() -> Optional[Dict[str, Any]]:

__all__ = [
"FusedMoE",
"FusedMoEMethodBase",
"FusedMoeWeightScaleSupported",
"override_config",
"get_config",