forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 1
Fix the unit test errors / enable accuracy tests #150
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
nvchenghaoz
wants to merge
24
commits into
main
Choose a base branch
from
chenghao/fix-causal-conv
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
24 commits
Select commit
Hold shift + click to select a range
408fa19
Basic Bamba Rewrite + Numerical Precision (#133)
lucaslie 928d3a7
torch ssm and causal conv support (#134)
lucaslie 73a8db9
Fix the bamba unit test (#136)
nvchenghaoz 0ffb68a
Add custom op for ssm_transform and causal_conv (#137)
nvchenghaoz 507c988
[https://nvbugs/5527956][fix] AutoDeploy: fix metadata to device copi…
lucaslie 0c7a7ec
small bamba test for parallel execution (#142)
lucaslie 631d051
fix overflow expression for number of pages in sequence interface (#144)
lucaslie 0f8f08a
Fix the sampler and update the triton/cuda kernels (#146)
nvchenghaoz f89d496
NVIDIA-Nemotron-Nano-12B-v2 support (#147)
lucaslie de2ce2d
waive mamba tests (#149)
lucaslie e44be92
Fix the causal conv error
nvchenghaoz e7c44df
Fix the triton kernel test error
nvchenghaoz 3e6ed8a
Add Nemotron-h acc test
nvchenghaoz e26846f
Add gsm8k testing
nvchenghaoz dabd3d7
resolve rebase issue
nvchenghaoz fdee78a
Pass the torch-compile and torch-simple
nvchenghaoz ffc646d
Disable the test as the golden data is wrong
nvchenghaoz cd6ed8b
fill seq info data with valid dummy data
lucaslie ac1670c
Revert "fill seq info data with valid dummy data"
nvchenghaoz 8ad51fe
fill seq info data with valid dummy data
lucaslie af04130
Update the test setup
nvchenghaoz fe22f1c
Fix the redundant type conversion
nvchenghaoz 5644504
Resolve the rebase issue
nvchenghaoz e719b2c
Merge branch 'main' into chenghao/fix-causal-conv
nvchenghaoz File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -124,7 +124,12 @@ def _triton_cached_ssm_transform( | |
|
|
||
| # Decode: batch single-token updates via selective_state_update | ||
| if num_decode > 0: | ||
| decode_idx = seq_start[num_prefill:].to(torch.long) | ||
| # In generate-only (s == 1), each batch element has one token and seq_start entries | ||
| # are typically zeros. Use arange over the flattened batch to index tokens correctly. | ||
| if s == 1: | ||
| decode_idx = torch.arange(bs, device=device, dtype=torch.long) | ||
| else: | ||
| decode_idx = seq_start[num_prefill:].to(torch.long) | ||
| slot_idx_decode = slot_idx[num_prefill:].to(torch.long) | ||
|
|
||
| x_decode = hs_flat.index_select(0, decode_idx) # [nd, H, D] | ||
|
|
@@ -237,7 +242,8 @@ def get_cache_initializers( | |
| ssm_state_size = max(1, B_fake.shape[-1]) | ||
|
|
||
| def _get_ssm_cache(si: SequenceInfo): | ||
| return torch.empty( | ||
| # Initialize to zeros so brand-new sequences start from a clean state. | ||
| return torch.zeros( | ||
|
Comment on lines
+245
to
+246
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: not needed when we correctly index caches |
||
| si.max_batch_size, | ||
| num_heads, | ||
| head_dim, | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.