Changes from all commits (24 commits)
- 408fa19: Basic Bamba Rewrite + Numerical Precision (#133) (lucaslie, Sep 18, 2025)
- 928d3a7: torch ssm and causal conv support (#134) (lucaslie, Sep 21, 2025)
- 73a8db9: Fix the bamba unit test (#136) (nvchenghaoz, Sep 23, 2025)
- 0ffb68a: Add custom op for ssm_transform and causal_conv (#137) (nvchenghaoz, Sep 24, 2025)
- 507c988: [https://nvbugs/5527956][fix] AutoDeploy: fix metadata to device copi… (lucaslie, Sep 25, 2025)
- 0c7a7ec: small bamba test for parallel execution (#142) (lucaslie, Sep 25, 2025)
- 631d051: fix overflow expression for number of pages in sequence interface (#144) (lucaslie, Sep 26, 2025)
- 0f8f08a: Fix the sampler and update the triton/cuda kernels (#146) (nvchenghaoz, Sep 26, 2025)
- f89d496: NVIDIA-Nemotron-Nano-12B-v2 support (#147) (lucaslie, Sep 27, 2025)
- de2ce2d: waive mamba tests (#149) (lucaslie, Sep 29, 2025)
- e44be92: Fix the causal conv error (nvchenghaoz, Sep 29, 2025)
- e7c44df: Fix the triton kernel test error (nvchenghaoz, Sep 29, 2025)
- 3e6ed8a: Add Nemotron-h acc test (nvchenghaoz, Oct 1, 2025)
- e26846f: Add gsm8k testing (nvchenghaoz, Oct 1, 2025)
- dabd3d7: resolve rebase issue (nvchenghaoz, Oct 1, 2025)
- fdee78a: Pass the torch-compile and torch-simple (nvchenghaoz, Oct 2, 2025)
- ffc646d: Disable the test as the golden data is wrong (nvchenghaoz, Oct 2, 2025)
- cd6ed8b: fill seq info data with valid dummy data (lucaslie, Sep 30, 2025)
- ac1670c: Revert "fill seq info data with valid dummy data" (nvchenghaoz, Oct 2, 2025)
- 8ad51fe: fill seq info data with valid dummy data (lucaslie, Sep 30, 2025)
- af04130: Update the test setup (nvchenghaoz, Oct 2, 2025)
- fe22f1c: Fix the redundant type conversion (nvchenghaoz, Oct 3, 2025)
- 5644504: Resolve the rebase issue (nvchenghaoz, Oct 3, 2025)
- e719b2c: Merge branch 'main' into chenghao/fix-causal-conv (nvchenghaoz, Oct 7, 2025)
Diff view
@@ -186,15 +186,15 @@ def _cuda_cached_causal_conv1d(
     # Use true start offsets for decode tokens (tail after prefills)
     decode_idx = seq_start[num_prefill:].to(torch.long)
     x_decode = inp_flat.index_select(0, decode_idx)  # [num_decode, C_in]
-
+    slot_idx_decode = slot_idx[num_prefill:].to(torch.int32)
     y_dec = causal_conv1d_update(
         x_decode,  # [batch, dim]
         conv_state_cache,
         w2d,
         bias,
         activation=None,
         cache_seqlens=None,
-        conv_state_indices=slot_idx[num_prefill:].to(torch.int32),
+        conv_state_indices=slot_idx_decode,
         pad_slot_id=PAD_SLOT_ID,
     )
 
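To make the slot-mapped decode path concrete, here is a minimal pure-PyTorch sketch (not the actual kernel) of the kind of update `causal_conv1d_update` performs with `conv_state_indices`; the function name, shapes, and layout are illustrative assumptions, not the library's implementation:

```python
import torch

def conv1d_update_sketch(x, conv_state, weight, bias, conv_state_indices):
    """Illustrative decode-step update for a cached depthwise causal conv.

    x:                  [num_decode, dim]        one new token per sequence
    conv_state:         [num_slots, dim, width]  rolling window per cache slot
    weight:             [dim, width]             depthwise filter per channel
    conv_state_indices: [num_decode]             cache slot of each sequence
    """
    # Gather each decode sequence's rolling window by its cache slot.
    state = conv_state[conv_state_indices]                  # [num_decode, dim, width]
    # Shift the window left by one position and append the new token.
    state = torch.cat([state[..., 1:], x.unsqueeze(-1)], dim=-1)
    # Write the updated windows back into their slots.
    conv_state[conv_state_indices] = state
    # Depthwise causal conv: dot each channel's window with its filter.
    out = (state * weight.unsqueeze(0)).sum(dim=-1)         # [num_decode, dim]
    return out if bias is None else out + bias
```

Note that the hoisting of `slot_idx_decode` in the diff above is a pure readability refactor: the indices passed as `conv_state_indices` are unchanged.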
@@ -124,7 +124,12 @@ def _triton_cached_ssm_transform(

     # Decode: batch single-token updates via selective_state_update
     if num_decode > 0:
-        decode_idx = seq_start[num_prefill:].to(torch.long)
+        # In generate-only (s == 1), each batch element has one token and seq_start entries
+        # are typically zeros. Use arange over the flattened batch to index tokens correctly.
+        if s == 1:
+            decode_idx = torch.arange(bs, device=device, dtype=torch.long)
+        else:
+            decode_idx = seq_start[num_prefill:].to(torch.long)
         slot_idx_decode = slot_idx[num_prefill:].to(torch.long)
 
         x_decode = hs_flat.index_select(0, decode_idx)  # [nd, H, D]
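The bug this guards against is easy to reproduce in isolation. A toy example, assuming (as the comment in the diff states) that `seq_start` is all zeros in generate-only mode:

```python
import torch

bs, s = 4, 1                                   # generate-only: one token per sequence
hs_flat = torch.arange(bs * s)                 # stand-in for the flattened hidden states
seq_start = torch.zeros(bs, dtype=torch.long)  # per-sequence start offsets, all zero

# Indexing with seq_start selects row 0 for every sequence:
hs_flat.index_select(0, seq_start)             # tensor([0, 0, 0, 0])

# Each batch element owns exactly one row of the flattened tensor,
# so arange recovers the right token per sequence:
hs_flat.index_select(0, torch.arange(bs, dtype=torch.long))  # tensor([0, 1, 2, 3])
```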
@@ -237,7 +242,8 @@ def get_cache_initializers(
     ssm_state_size = max(1, B_fake.shape[-1])
 
     def _get_ssm_cache(si: SequenceInfo):
-        return torch.empty(
+        # Initialize to zeros so brand-new sequences start from a clean state.
+        return torch.zeros(
    [Review comment from a Collaborator on lines +245 to +246: "nit: not needed when we correctly index caches"]
             si.max_batch_size,
             num_heads,
             head_dim,
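The distinction matters because `torch.empty` returns uninitialized memory, so a freshly allocated cache slot can carry arbitrary values into the first recurrence step of a new sequence. A minimal illustration, with a scalar SSM-style update assumed purely for demonstration:

```python
import torch

x, a, b = torch.ones(4), 0.9, 0.1

h = torch.empty(4)   # uninitialized: whatever bytes the allocator returns
h = a * h + b * x    # garbage state leaks into the first token's output

h = torch.zeros(4)   # clean initial state for a brand-new sequence
h = a * h + b * x    # tensor([0.1000, 0.1000, 0.1000, 0.1000])
```

As the reviewer notes, zero init is redundant once cache slots are always indexed correctly (a live slot is overwritten before it is read); it is presumably kept here as a defensive default.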
@@ -28,7 +28,6 @@ def mamba_env():
     return {"device": device, "dtype": dtype, "atol": atol, "rtol": rtol}
 
 
-@pytest.mark.skip(reason="https://nvbugspro.nvidia.com/bug/5548861")
 def test_triton_generate_only_with_slot_mapping(mamba_env):
     device = mamba_env["device"]
     dtype = mamba_env["dtype"]
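With the generate-only indexing fix in place, the previously waived test is re-enabled (the removed skip referenced https://nvbugspro.nvidia.com/bug/5548861). It can be run in isolation with, e.g., `pytest -k test_triton_generate_only_with_slot_mapping`; the exact test-file path is not shown in this diff.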