Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions benchmark/bench_linear_attention/bench_cutedsl_kda_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,11 @@ def make_inputs(
b = torch.randn(B, HV, device=device, dtype=dtype)

# prefill params for chunk_kda must keep batch dim = 1
prefill_g = torch.randn(1, B, HV, K, device=device, dtype=dtype)
prefill_beta = torch.sigmoid(torch.randn(1, B, HV, device=device, dtype=dtype))
# chunk_kda requires g, beta, v to have the same head count as k (H),
# matching the real KimiLinear model where num_heads == num_kv_heads.
prefill_v = torch.randn(1, B, H, V, device=device, dtype=dtype)
prefill_g = torch.randn(1, B, H, K, device=device, dtype=dtype)
prefill_beta = torch.sigmoid(torch.randn(1, B, H, device=device, dtype=dtype))

cu_seqlens = torch.arange(B + 1, device=device, dtype=torch.int32)

Expand All @@ -66,8 +69,11 @@ def make_inputs(
b = torch.randn(B, 1, HV, device=device, dtype=dtype)

# prefill params for chunk_kda dense path
prefill_g = torch.randn(B, 1, HV, K, device=device, dtype=dtype)
prefill_beta = torch.sigmoid(torch.randn(B, 1, HV, device=device, dtype=dtype))
# chunk_kda requires g, beta, v to have the same head count as k (H),
# matching the real KimiLinear model where num_heads == num_kv_heads.
prefill_v = torch.randn(B, 1, H, V, device=device, dtype=dtype)
prefill_g = torch.randn(B, 1, H, K, device=device, dtype=dtype)
prefill_beta = torch.sigmoid(torch.randn(B, 1, H, device=device, dtype=dtype))

cu_seqlens = torch.arange(B + 1, device=device, dtype=torch.int32)
else:
Expand All @@ -94,6 +100,7 @@ def make_inputs(
v=v,
a=a,
b=b,
prefill_v=prefill_v,
prefill_g=prefill_g,
prefill_beta=prefill_beta,
A_log=A_log,
Expand Down Expand Up @@ -147,12 +154,13 @@ def run_cutedsl(inp):

def run_prefill_then_decode_baseline(inp):
ssm_states = inp["ssm_states"].clone()
prefill_v_clone = inp["prefill_v"].clone()
v_clone = inp["v"].clone()

_ = chunk_kda(
q=inp["q"],
k=inp["k"],
v=v_clone,
v=prefill_v_clone,
g=inp["prefill_g"],
beta=inp["prefill_beta"],
initial_state=ssm_states,
Expand Down Expand Up @@ -182,12 +190,13 @@ def run_prefill_then_decode_baseline(inp):

def run_prefill_then_decode_cutedsl(inp):
ssm_states = inp["ssm_states"].clone()
prefill_v_clone = inp["prefill_v"].clone()
v_clone = inp["v"].clone()

_ = chunk_kda(
q=inp["q"],
k=inp["k"],
v=v_clone,
v=prefill_v_clone,
g=inp["prefill_g"],
beta=inp["prefill_beta"],
initial_state=ssm_states,
Expand Down
Loading
Loading