[CPU] Enable Granite 4 / Mamba models on CPU backend #39157
Akashcodes732 wants to merge 20 commits into vllm-project:main
Conversation
Signed-off-by: Akash Kaothalkar <akashkaothalkar@akashs-mbp.bl1-in.ibm.com>
Signed-off-by: Akash Kaothalkar <akashkaothalkar@dhcp-9-123-5-76.bl1-in.ibm.com>
Signed-off-by: Akash Kaothalkar <akashkaothalkar@akashs-mbp.bl1-in.ibm.com>
Signed-off-by: Akash Kaothalkar <akashkaothalkar@Akashs-MBP.lan>
Signed-off-by: Akash Kaothalkar <akashkaothalkar@akashs-mbp.bl1-in.ibm.com>
Signed-off-by: Akash Kaothalkar <akashkaothalkar@akashs-mbp.bl1-in.ibm.com>
Code Review
This pull request introduces CPU fallbacks for Mamba-related operations, including causal convolution and selective state updates, by implementing pure PyTorch versions and registering them through the CustomOp framework. The review feedback identifies several critical bugs in these new implementations: _causal_conv1d_fn_cpu incorrectly performs double-updates on the convolution state due to in-place modifications on a tensor view, _causal_conv1d_update_cpu fails to mask out padding slots which can lead to memory corruption, and _selective_state_update_cpu lacks the per-token state updates and correct indexing required for speculative decoding.
x_seq = x[:, seq_start:seq_end]  # (dim, seq_len)

if has_initial_state is not None and has_initial_state[b]:
    state = conv_states[cache_idx]  # (dim, state_len)
The state tensor is obtained via basic indexing, which returns a view of conv_states. The subsequent in-place updates to state (lines 69-70) directly modify the global conv_states cache during the convolution loop. However, there is a second state update logic at the end of the sequence (lines 72-78) which also modifies conv_states. This results in a double-update/double-shift of the convolution state for sequences shorter than the kernel width, leading to incorrect results. Cloning the state at the beginning of the sequence processing ensures the loop works on a local copy, while the final state is correctly committed at the end of the function.
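The hazard described here is generic PyTorch behavior: basic indexing returns a view, so in-place writes alias the underlying cache. A minimal standalone sketch (toy shapes, not the vLLM tensors):

```python
import torch

cache = torch.zeros(2, 4)           # stands in for conv_states: (num_slots, state_len)
view = cache[0]                     # basic indexing returns a view into the cache
view[-1] = 1.0                      # in-place edit mutates the cache immediately
assert cache[0, -1].item() == 1.0   # the "local" update already leaked into the cache

cache = torch.zeros(2, 4)
local = cache[0].clone()            # clone() detaches the working copy
local[-1] = 1.0
assert cache[0, -1].item() == 0.0   # cache stays untouched while the loop runs
cache[0] = local                    # single, explicit commit at the end
```

With the clone, the convolution loop mutates only its local copy, and the cache is written exactly once when the final state is committed.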
Suggested change:
    if has_initial_state is not None and has_initial_state[b]:
        state = conv_states[cache_idx].clone()  # (dim, state_len)
    else:

new_state = torch.cat(
    [states[:, :, 1:], x_t.unsqueeze(-1)], dim=-1
)
conv_state[cache_idxs] = new_state
The state update does not account for pad_slot_id. If cache_idxs contains pad_slot_id (which is typically -1 in vLLM), conv_state[cache_idxs] = new_state will incorrectly update the last element of the conv_state tensor. The update should be restricted to valid indices using the valid_mask.
Suggested change:
-    conv_state[cache_idxs] = new_state
+    conv_state[cache_idxs[valid_mask]] = new_state[valid_mask]
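The failure mode comes down to ordinary advanced indexing with -1. A minimal standalone sketch (toy shapes and names, not the vLLM tensors):

```python
import torch

PAD_SLOT_ID = -1
conv_state = torch.zeros(4, 3)                 # (num_cache_slots, state_len)
cache_idxs = torch.tensor([0, PAD_SLOT_ID])    # second request maps to a padding slot
new_state = torch.ones(2, 3)

buggy = conv_state.clone()
buggy[cache_idxs] = new_state                  # -1 silently writes slot 3 (the last row)
assert buggy[3].sum().item() == 3.0            # a real slot was clobbered by padding

fixed = conv_state.clone()
valid_mask = cache_idxs != PAD_SLOT_ID
fixed[cache_idxs[valid_mask]] = new_state[valid_mask]
assert fixed[3].sum().item() == 0.0            # padding slot correctly skipped
```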
if dst_state_batch_indices is not None:
    dst_idx = dst_state_batch_indices[seq_idx, 0].item()
else:
    dst_idx = state_idx

s = state[state_idx].float()

for t in range(seq_len):
    token_idx = bos + t

    x_val = x[token_idx].float()
    dt_val = dt[token_idx].float()

    if dt_bias is not None:
        dt_val = dt_val + dt_bias.float()
    if dt_softplus:
        dt_val = torch.nn.functional.softplus(dt_val)

    A_val = A.float()

    B_val = B[token_idx].float()
    B_expanded = B_val.repeat_interleave(nheads_ngroups_ratio, dim=0)
    C_val = C[token_idx].float()
    C_expanded = C_val.repeat_interleave(nheads_ngroups_ratio, dim=0)

    dA = torch.exp(A_val * dt_val.unsqueeze(-1))
    dBx = (B_expanded.unsqueeze(1) * (x_val * dt_val).unsqueeze(-1))
    s = s * dA + dBx

    out_val = (s * C_expanded.unsqueeze(1)).sum(dim=-1)

    if D is not None:
        out_val = out_val + x_val * D.float()

    if z is not None:
        z_val = z[token_idx].float()
        out_val = out_val * z_val * torch.sigmoid(z_val)

    out[token_idx] = out_val.to(out.dtype)

state[dst_idx] = s.to(state.dtype)
The CPU fallback for selective_state_update has two significant issues regarding speculative decoding (when num_accepted_tokens is provided):
- Missing Intermediate Updates: In speculative decoding mode, the SSM state must be updated for every token in the sequence to allow the verification step to choose the correct state. The current implementation only updates the state once at the end of the sequence (line 297), whereas the Triton kernel (and the Mamba architecture requirements) updates it per-token.
- Incorrect Indexing: It uses a fixed index dst_state_batch_indices[seq_idx, 0] for the update, which is incorrect for sequences with multiple tokens where each token might correspond to a different cache block.
The suggested change aligns the CPU fallback with the Triton kernel logic by performing per-token state updates when num_accepted_tokens is present.
if num_accepted_tokens is None:
    if dst_state_batch_indices is not None:
        dst_idx = dst_state_batch_indices[seq_idx, 0].item()
    else:
        dst_idx = state_idx

s = state[state_idx].float()

for t in range(seq_len):
    token_idx = bos + t
    x_val = x[token_idx].float()
    dt_val = dt[token_idx].float()
    if dt_bias is not None:
        dt_val = dt_val + dt_bias.float()
    if dt_softplus:
        dt_val = torch.nn.functional.softplus(dt_val)
    A_val = A.float()
    B_val = B[token_idx].float()
    B_expanded = B_val.repeat_interleave(nheads_ngroups_ratio, dim=0)
    C_val = C[token_idx].float()
    C_expanded = C_val.repeat_interleave(nheads_ngroups_ratio, dim=0)
    dA = torch.exp(A_val * dt_val.unsqueeze(-1))
    dBx = (B_expanded.unsqueeze(1) * (x_val * dt_val).unsqueeze(-1))
    s = s * dA + dBx
    if num_accepted_tokens is not None:
        token_dst_idx = dst_state_batch_indices[seq_idx, t].item()
        if token_dst_idx != null_block_id:
            state[token_dst_idx] = s.to(state.dtype)
    out_val = (s * C_expanded.unsqueeze(1)).sum(dim=-1)
    if D is not None:
        out_val = out_val + x_val * D.float()
    if z is not None:
        z_val = z[token_idx].float()
        out_val = out_val * z_val * torch.sigmoid(z_val)
    out[token_idx] = out_val.to(out.dtype)

if num_accepted_tokens is None:
    if dst_idx != null_block_id:
        state[dst_idx] = s.to(state.dtype)

Force-pushed from 7602444 to 4e9a3a2 Compare
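For context, the core of both the Triton kernel and the CPU fallback is the discretized selective-SSM recurrence s_t = s_{t-1} * exp(A * dt_t) + (dt_t * x_t) * B_t with readout y_t = (s_t * C_t).sum(-1). A toy single-head sketch of that loop (hypothetical shapes, not the vLLM signature):

```python
import torch

torch.manual_seed(0)
head_dim, state_dim, seq_len = 2, 3, 4
A = -torch.rand(head_dim, state_dim)   # decay rates (negative for stability)
s = torch.zeros(head_dim, state_dim)   # running SSM state, updated every token

ys = []
for t in range(seq_len):
    x_t = torch.randn(head_dim)                                # input token features
    dt_t = torch.nn.functional.softplus(torch.randn(head_dim)) # positive step sizes
    B_t = torch.randn(state_dim)                               # input projection
    C_t = torch.randn(state_dim)                               # output projection

    dA = torch.exp(A * dt_t.unsqueeze(-1))                 # (head_dim, state_dim)
    dBx = B_t.unsqueeze(0) * (x_t * dt_t).unsqueeze(-1)    # discretized input term
    s = s * dA + dBx                                       # per-token state update
    ys.append((s * C_t.unsqueeze(0)).sum(dim=-1))          # readout y_t = C . s_t

y = torch.stack(ys)
assert y.shape == (seq_len, head_dim)
```

The per-token mutation of s inside the loop is exactly why, under speculative decoding, each token's state must be persisted to its own cache slot rather than once at sequence end.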
Signed-off-by: Akash Kaothalkar <akashkaothalkar@akashs-mbp.bl1-in.ibm.com>
Hi @Akashcodes732, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.
Signed-off-by: Akash Kaothalkar <akashkaothalkar@akashs-mbp.bl1-in.ibm.com>
Signed-off-by: Akash kaothalkar <akash.kaothalkar@ibm.com>
Hi @bigPYJ1151, can you please take a look at this PR?
Hi @bigPYJ1151 @tdoublep @tomeras91, can anyone please take a look at this PR?
Hi @Akashcodes732, thanks for the PR! I'm a bit concerned about the performance of the torch-native impl, because SSM and causal_conv1d are hotspot operations in Mamba models. In my opinion, if we can't provide optimized performance, it's better to simply state that we don't support it. Is it possible to let a code agent generate simple fused vector kernels for SSM and causal_conv1d? 😂 I'm inspired by #39445; looks like CC can also produce CPU vec code. And please replace the Triton kernels for CPU via the pattern in #32662.
Signed-off-by: Akash Kaothalkar <akashkaothalkar@Akashs-MBP.lan>
This pull request has merge conflicts that must be resolved before it can be merged.
# Conflicts:
#	csrc/cpu/torch_bindings.cpp
bigPYJ1151
left a comment
Thanks @Akashcodes732, overall looks good. Can you follow the pattern in #32662 to replace Triton kernels on CPU? This will help to avoid adding too many branches.
layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
    x, weight, bias
)
Suggested change:
-    layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
-        x, weight, bias
-    )
+    layer.cpu_linear = torch.nn.functional.linear
Signed-off-by: Akash kaothalkar <akash.kaothalkar@ibm.com>
Signed-off-by: Akash kaothalkar <akash.kaothalkar@ibm.com>
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: Akash kaothalkar <akash.kaothalkar@ibm.com>
Purpose
Fixes: #27971
This PR enables Granite 4 / Mamba-architecture models on the CPU backend, which previously crashed due to tightly coupled Triton kernel dependencies in the underlying layer implementations.
Prior to this PR, Mamba SSM blocks and short convolutions lacked a standard CPU fallback path: module construction assumed GPU/Triton configurations, causing fatal NameError and ImportError crashes on CPU deployments.
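For background on one of the ported operations: a causal depthwise conv1d, as used by Mamba's short convolution, can be expressed in pure PyTorch by left-padding the sequence by kernel_width - 1 and running a grouped convolution. A minimal reference sketch (illustrative only, not the code added in this PR):

```python
import torch
import torch.nn.functional as F

def causal_conv1d_ref(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """x: (dim, seq_len); weight: (dim, kernel_width) depthwise filters."""
    dim, width = weight.shape
    x = F.pad(x.unsqueeze(0), (width - 1, 0))          # left-pad so output is causal
    y = F.conv1d(x, weight.unsqueeze(1), groups=dim)   # one filter per channel
    return y.squeeze(0)                                # (dim, seq_len)

x = torch.randn(4, 7)
w = torch.randn(4, 3)
y = causal_conv1d_ref(x, w)
assert y.shape == (4, 7)
# Causality check: the first output depends only on the first input token.
assert torch.allclose(y[:, 0], w[:, -1] * x[:, 0], atol=1e-5)
```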
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.