[CPU] Enable Granite 4 / Mamba models on CPU backend #39157

Draft

Akashcodes732 wants to merge 20 commits into vllm-project:main from Akashcodes732:feat/granite4_enablement

Conversation

@Akashcodes732
Contributor

@Akashcodes732 Akashcodes732 commented Apr 7, 2026

Purpose

Fixes: #27971

This PR enables Granite 4 / Mamba architecture models on the CPU backend; these previously crashed due to tightly coupled Triton kernel dependencies in the underlying layer implementations.

Prior to this PR, Mamba SSM blocks and short convolutions lacked a standard CPU fallback path: module construction assumed a GPU/Triton configuration, causing fatal NameError and ImportError crashes on CPU deployments.
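The fallback routing idea can be sketched in simplified form: resolve the kernel once, at construction time, instead of importing Triton unconditionally. All names here (`mamba_ssm_cpu`, `triton_ssm_kernels`, `mamba_ssm_triton`) are hypothetical placeholders for illustration, not vLLM's actual CustomOp API:

```python
def mamba_ssm_cpu(*args):
    """Torch-native fallback (stub for illustration)."""
    return "cpu-impl"

def resolve_mamba_ssm(device: str):
    """Pick an implementation for the target device, falling back to the
    native path when the Triton kernel is unavailable."""
    if device == "cuda":
        try:
            # Import only on the GPU path, so CPU deployments never hit a
            # NameError/ImportError from a missing Triton installation.
            from triton_ssm_kernels import mamba_ssm_triton  # hypothetical
            return mamba_ssm_triton
        except ImportError:
            pass
    return mamba_ssm_cpu

op = resolve_mamba_ssm("cpu")
print(op())  # cpu-impl
```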

Test Plan

vllm bench serve --model ibm-granite/granite-4.0-tiny-preview --port 8001 --dataset-name random --random-input-len 1024 --random-output-len 128 --num-prompts 5

Test Result

============ Serving Benchmark Result ============
Successful requests:                     5         
Failed requests:                         0         
Benchmark duration (s):                  167.68    
Total input tokens:                      5120      
Total generated tokens:                  640       
Request throughput (req/s):              0.03      
Output token throughput (tok/s):         3.82      
Peak output token throughput (tok/s):    15.00    
Peak concurrent requests:                5.00      
Total token throughput (tok/s):          34.35     
---------------Time to First Token----------------
Mean TTFT (ms):                          80673.47  
Median TTFT (ms):                        105822.21 
P99 TTFT (ms):                           106362.19 
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          682.34    
Median TPOT (ms):                        485.81    
P99 TPOT (ms):                           1129.68   
---------------Inter-token Latency----------------
Mean ITL (ms):                           682.34    
Median ITL (ms):                         466.58    
P99 ITL (ms):                            857.72    

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify Bot added the cpu Related to CPU backends label Apr 7, 2026
Akash Kaothalkar and others added 6 commits April 7, 2026 12:07
Signed-off-by: Akash Kaothalkar <akashkaothalkar@akashs-mbp.bl1-in.ibm.com>
Signed-off-by: Akash Kaothalkar <akashkaothalkar@dhcp-9-123-5-76.bl1-in.ibm.com>
Signed-off-by: Akash Kaothalkar <akashkaothalkar@akashs-mbp.bl1-in.ibm.com>
Signed-off-by: Akash Kaothalkar <akashkaothalkar@Akashs-MBP.lan>
Signed-off-by: Akash Kaothalkar <akashkaothalkar@akashs-mbp.bl1-in.ibm.com>
Signed-off-by: Akash Kaothalkar <akashkaothalkar@akashs-mbp.bl1-in.ibm.com>
Signed-off-by: Akash Kaothalkar <akashkaothalkar@akashs-mbp.bl1-in.ibm.com>
Signed-off-by: Akash Kaothalkar <akashkaothalkar@akashs-mbp.bl1-in.ibm.com>
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces CPU fallbacks for Mamba-related operations, including causal convolution and selective state updates, by implementing pure PyTorch versions and registering them through the CustomOp framework. The review feedback identifies several critical bugs in these new implementations:
  • _causal_conv1d_fn_cpu double-updates the convolution state because it performs in-place modifications on a tensor view,
  • _causal_conv1d_update_cpu fails to mask out padding slots, which can lead to memory corruption, and
  • _selective_state_update_cpu lacks the per-token state updates and correct indexing required for speculative decoding.

        x_seq = x[:, seq_start:seq_end]  # (dim, seq_len)

        if has_initial_state is not None and has_initial_state[b]:
            state = conv_states[cache_idx]  # (dim, state_len)
Contributor


high

The state tensor is obtained via basic indexing, which returns a view of conv_states. The subsequent in-place updates to state (lines 69-70) directly modify the global conv_states cache during the convolution loop. However, there is a second state update logic at the end of the sequence (lines 72-78) which also modifies conv_states. This results in a double-update/double-shift of the convolution state for sequences shorter than the kernel width, leading to incorrect results. Cloning the state at the beginning of the sequence processing ensures the loop works on a local copy, while the final state is correctly committed at the end of the function.

        if has_initial_state is not None and has_initial_state[b]:
            state = conv_states[cache_idx].clone()  # (dim, state_len)
        else:
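The view-versus-copy pitfall behind this fix can be reproduced with plain Python lists, which alias on indexing just as torch basic indexing returns a view. This is a minimal analogue, not the vLLM code itself:

```python
# Analogue of the view-aliasing bug: indexing a nested list returns a
# reference, so in-place edits leak into the "cache" mid-loop.
cache = [[1, 2, 3], [4, 5, 6]]

state_view = cache[0]          # aliases cache[0], like a torch view
state_view[0] = 99             # mutates the cache immediately
leaked = cache[0][0]           # 99 -- the cache was modified mid-loop

cache = [[1, 2, 3], [4, 5, 6]]
state_copy = list(cache[0])    # like .clone(): a private working copy
state_copy[0] = 99             # local change only
intact = cache[0][0]           # still 1 -- cache untouched so far
cache[0] = state_copy          # single, final state commit

print(leaked, intact, cache[0])  # 99 1 [99, 2, 3]
```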

        new_state = torch.cat(
            [states[:, :, 1:], x_t.unsqueeze(-1)], dim=-1
        )
        conv_state[cache_idxs] = new_state
Contributor


high

The state update does not account for pad_slot_id. If cache_idxs contains pad_slot_id (which is typically -1 in vLLM), conv_state[cache_idxs] = new_state will incorrectly update the last element of the conv_state tensor. The update should be restricted to valid indices using the valid_mask.

Suggested change
conv_state[cache_idxs] = new_state
conv_state[cache_idxs[valid_mask]] = new_state[valid_mask]
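The hazard is easy to demonstrate in isolation: with a pad slot id of -1, an unmasked scatter silently writes into the last cache slot. Python list indexing behaves the same way as torch advanced indexing does here; this is a minimal analogue, not the vLLM code:

```python
# Analogue of the pad_slot_id bug: index -1 addresses the LAST slot,
# so padded entries silently clobber a real sequence's state.
PAD_SLOT_ID = -1
conv_state = [[0, 0], [0, 0], [7, 7]]   # slot 2 holds live state
cache_idxs = [0, PAD_SLOT_ID]           # second entry is padding
new_state = [[1, 1], [9, 9]]

# Unmasked update: the padding row lands in conv_state[-1] (slot 2).
for idx, row in zip(cache_idxs, new_state):
    conv_state[idx] = row
clobbered = conv_state[2]               # [9, 9] -- live state corrupted

# Masked update: skip padded slots entirely.
conv_state = [[0, 0], [0, 0], [7, 7]]
for idx, row in zip(cache_idxs, new_state):
    if idx != PAD_SLOT_ID:
        conv_state[idx] = row
print(clobbered, conv_state[2])         # [9, 9] [7, 7]
```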

Comment on lines +257 to +297
        if dst_state_batch_indices is not None:
            dst_idx = dst_state_batch_indices[seq_idx, 0].item()
        else:
            dst_idx = state_idx

        s = state[state_idx].float()

        for t in range(seq_len):
            token_idx = bos + t

            x_val = x[token_idx].float()
            dt_val = dt[token_idx].float()

            if dt_bias is not None:
                dt_val = dt_val + dt_bias.float()
            if dt_softplus:
                dt_val = torch.nn.functional.softplus(dt_val)

            A_val = A.float()

            B_val = B[token_idx].float()
            B_expanded = B_val.repeat_interleave(nheads_ngroups_ratio, dim=0)
            C_val = C[token_idx].float()
            C_expanded = C_val.repeat_interleave(nheads_ngroups_ratio, dim=0)

            dA = torch.exp(A_val * dt_val.unsqueeze(-1))
            dBx = (B_expanded.unsqueeze(1) * (x_val * dt_val).unsqueeze(-1))
            s = s * dA + dBx

            out_val = (s * C_expanded.unsqueeze(1)).sum(dim=-1)

            if D is not None:
                out_val = out_val + x_val * D.float()

            if z is not None:
                z_val = z[token_idx].float()
                out_val = out_val * z_val * torch.sigmoid(z_val)

            out[token_idx] = out_val.to(out.dtype)

        state[dst_idx] = s.to(state.dtype)
Contributor


high

The CPU fallback for selective_state_update has two significant issues regarding speculative decoding (when num_accepted_tokens is provided):

  1. Missing Intermediate Updates: In speculative decoding mode, the SSM state must be updated for every token in the sequence to allow the verification step to choose the correct state. The current implementation only updates the state once at the end of the sequence (line 297), whereas the Triton kernel (and the Mamba architecture requirements) updates it per-token.
  2. Incorrect Indexing: It uses a fixed index dst_state_batch_indices[seq_idx, 0] for the update, which is incorrect for sequences with multiple tokens where each token might correspond to a different cache block.

The suggested change aligns the CPU fallback with the Triton kernel logic by performing per-token state updates when num_accepted_tokens is present.

        if num_accepted_tokens is None:
            if dst_state_batch_indices is not None:
                dst_idx = dst_state_batch_indices[seq_idx, 0].item()
            else:
                dst_idx = state_idx

        s = state[state_idx].float()

        for t in range(seq_len):
            token_idx = bos + t

            x_val = x[token_idx].float()
            dt_val = dt[token_idx].float()

            if dt_bias is not None:
                dt_val = dt_val + dt_bias.float()
            if dt_softplus:
                dt_val = torch.nn.functional.softplus(dt_val)

            A_val = A.float()

            B_val = B[token_idx].float()
            B_expanded = B_val.repeat_interleave(nheads_ngroups_ratio, dim=0)
            C_val = C[token_idx].float()
            C_expanded = C_val.repeat_interleave(nheads_ngroups_ratio, dim=0)

            dA = torch.exp(A_val * dt_val.unsqueeze(-1))
            dBx = (B_expanded.unsqueeze(1) * (x_val * dt_val).unsqueeze(-1))
            s = s * dA + dBx

            if num_accepted_tokens is not None:
                token_dst_idx = dst_state_batch_indices[seq_idx, t].item()
                if token_dst_idx != null_block_id:
                    state[token_dst_idx] = s.to(state.dtype)

            out_val = (s * C_expanded.unsqueeze(1)).sum(dim=-1)

            if D is not None:
                out_val = out_val + x_val * D.float()

            if z is not None:
                z_val = z[token_idx].float()
                out_val = out_val * z_val * torch.sigmoid(z_val)

            out[token_idx] = out_val.to(out.dtype)

        if num_accepted_tokens is None:
            if dst_idx != null_block_id:
                state[dst_idx] = s.to(state.dtype)
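The per-token recurrence under discussion, s ← s·exp(A·dt) + B·x·dt with output y = C·s, can be checked in isolation with scalars. This is a pure-Python analogue of the single-head, single-dimension case (not the vLLM kernel), written to commit the state after every token, mirroring what speculative decoding needs:

```python
import math

def selective_scan_scalar(xs, dts, A, B, C, s0=0.0):
    """Scalar analogue of the selective-state-update recurrence:
    per token, s = s * exp(A*dt) + B*x*dt; output y = C*s.
    Returns the outputs AND every intermediate state, mirroring the
    per-token commits needed for speculative decoding."""
    s = s0
    outs, states = [], []
    for x, dt in zip(xs, dts):
        s = s * math.exp(A * dt) + B * x * dt
        states.append(s)          # committed per token, not once at the end
        outs.append(C * s)
    return outs, states

outs, states = selective_scan_scalar(
    xs=[1.0, 2.0], dts=[1.0, 1.0], A=0.0, B=1.0, C=2.0
)
# With A=0, exp(A*dt)=1, so the state is a running sum: 1.0, then 3.0.
print(states)   # [1.0, 3.0]
print(outs)     # [2.0, 6.0]
```

A caller that only commits the final state would lose the intermediate entries of `states`, which is exactly the gap the review points out.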

@Akashcodes732 Akashcodes732 force-pushed the feat/granite4_enablement branch from 7602444 to 4e9a3a2 Compare April 7, 2026 06:38
Signed-off-by: Akash Kaothalkar <akashkaothalkar@akashs-mbp.bl1-in.ibm.com>
@mergify
Contributor

mergify Bot commented Apr 7, 2026

Hi @Akashcodes732, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

Signed-off-by: Akash Kaothalkar <akashkaothalkar@akashs-mbp.bl1-in.ibm.com>
@mergify
Contributor

mergify Bot commented Apr 7, 2026

Hi @Akashcodes732, the pre-commit checks have failed again. Please run the pre-commit steps above, then commit the changes and push to your branch.

Signed-off-by: Akash kaothalkar <akash.kaothalkar@ibm.com>
@Akashcodes732
Contributor Author

Hi @bigPYJ1151,

Can you please take a look at this PR?

@Akashcodes732
Contributor Author

Hi @bigPYJ1151,

Can you please take a look at this PR?

@bigPYJ1151 bigPYJ1151 self-assigned this Apr 10, 2026
@Akashcodes732
Contributor Author

Hi @bigPYJ1151, can you please take a look at the changes in this PR?

@Akashcodes732
Contributor Author

Hi @bigPYJ1151 @tdoublep @tomeras91,

Can anyone please take a look at this PR?

@bigPYJ1151
Member

Hi @Akashcodes732, thanks for the PR! I'm a bit concerned about the performance of the torch-native impl, because SSM and causal_conv1d are hotspot operations in Mamba models. In my opinion, if we can't provide optimized performance, it's better to simply state that we don't support it.

Is it possible to let a code agent generate simple fused vector kernels for SSM and causal_conv1d? 😂 I'm inspired by #39445; it looks like CC can also produce CPU vec code.

Also, please replace the Triton kernels on CPU via the pattern in #32662.
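For reference, the causal_conv1d semantics being discussed as a hotspot (a left-padded, per-channel depthwise convolution where output t depends only on inputs ≤ t) are small enough to state in a few lines. This is a pure-Python sketch of the math only, not an optimized kernel and not the code in this PR:

```python
def causal_conv1d_ref(x, w, bias=None):
    """Reference causal depthwise conv. x is [dim][seq_len] lists,
    w is [dim][kernel] lists; output t depends only on inputs <= t."""
    out = []
    for d, (row, wk) in enumerate(zip(x, w)):
        k = len(wk)
        padded = [0.0] * (k - 1) + row      # left padding => causality
        conv = [
            sum(padded[t + j] * wk[j] for j in range(k))
            for t in range(len(row))
        ]
        if bias is not None:
            conv = [v + bias[d] for v in conv]
        out.append(conv)
    return out

# One channel, kernel [0.5, 0.5]: a causal moving average.
print(causal_conv1d_ref([[1.0, 2.0, 3.0]], [[0.5, 0.5]]))
# [[0.5, 1.5, 2.5]]
```

An optimized CPU kernel would vectorize the inner dot product, but any replacement should stay numerically faithful to this reference.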

Signed-off-by: Akash Kaothalkar <akashkaothalkar@Akashs-MBP.lan>
@mergify mergify Bot added the ci/build label Apr 25, 2026
@mergify
Contributor

mergify Bot commented Apr 25, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Akashcodes732.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 25, 2026
@mergify mergify Bot removed the needs-rebase label Apr 25, 2026
@mergify
Contributor

mergify Bot commented Apr 25, 2026

Hi @Akashcodes732, the pre-commit checks have failed again. Please run the pre-commit steps above, then commit the changes and push to your branch.

Member

@bigPYJ1151 bigPYJ1151 left a comment


Thanks @Akashcodes732, overall this looks good. Can you follow the pattern in #32662 to replace the Triton kernels on CPU? This will help avoid adding too many branches.

Comment thread vllm/model_executor/layers/utils.py Outdated
Comment on lines +234 to +236
layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
    x, weight, bias
)
Member


Suggested change
layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
    x, weight, bias
)
layer.cpu_linear = torch.nn.functional.linear
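The suggestion works because a lambda that merely forwards all of its arguments is equivalent to binding the function object directly. A quick stand-alone check, using a scalar stand-in for the linear function rather than torch itself:

```python
def linear(x, weight, bias):
    """Stand-in for a linear op on scalars: x * weight + bias."""
    return x * weight + bias

wrapped = lambda x, weight, bias: linear(x, weight, bias)
direct = linear                     # the suggested, simpler binding

print(wrapped(2.0, 3.0, 1.0), direct(2.0, 3.0, 1.0))  # 7.0 7.0
```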

@mergify
Contributor

mergify Bot commented May 5, 2026

Hi @Akashcodes732, the pre-commit checks have failed again. Please run the pre-commit steps above, then commit the changes and push to your branch.

Signed-off-by: Akash kaothalkar <akash.kaothalkar@ibm.com>
@mergify
Contributor

mergify Bot commented May 5, 2026

Hi @Akashcodes732, the pre-commit checks have failed again. Please run the pre-commit steps above, then commit the changes and push to your branch.

fix
Signed-off-by: Akash kaothalkar <akash.kaothalkar@ibm.com>
@mergify
Contributor

mergify Bot commented May 5, 2026

Hi @Akashcodes732, the pre-commit checks have failed again. Please run the pre-commit steps above, then commit the changes and push to your branch.

@Akashcodes732 Akashcodes732 marked this pull request as draft May 5, 2026 10:09
Akash kaothalkar added 4 commits May 5, 2026 15:44
Signed-off-by: Akash kaothalkar <akash.kaothalkar@ibm.com>
Signed-off-by: Akash kaothalkar <akash.kaothalkar@ibm.com>
Signed-off-by: Akash kaothalkar <akash.kaothalkar@ibm.com>
Signed-off-by: Akash kaothalkar <akash.kaothalkar@ibm.com>
@mergify
Contributor

mergify Bot commented May 6, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Akashcodes732.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 6, 2026
Akash kaothalkar added 2 commits May 6, 2026 15:53
Signed-off-by: Akash kaothalkar <akash.kaothalkar@ibm.com>
Signed-off-by: Akash kaothalkar <akash.kaothalkar@ibm.com>

Labels

ci/build, cpu (Related to CPU backends), needs-rebase

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug]: ibm-granite/granite-4.0-h-tiny model fails for CPU on vLLM

2 participants