-
-
Notifications
You must be signed in to change notification settings - Fork 17.7k
Revert "[Model Runner V2] Bug fix: logprob dtype int64/int32 issue" (#41761) #42418
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -124,13 +124,9 @@ def compute_topk_logprobs( | |||||
| # tokens where applicable. | ||||||
| assert logprob_token_ids_state is not None | ||||||
| assert expanded_idx_mapping is not None | ||||||
|
|
||||||
| topk_indices = None | ||||||
| if num_logprobs > 0: | ||||||
| topk_token_ids = torch.topk(logits, num_logprobs, dim=-1).indices | ||||||
| topk_token_ids = topk_token_ids.to(torch.int32) | ||||||
| else: | ||||||
| # This tensor just used as an int32 pointer, data not accessed. | ||||||
| topk_token_ids = logprob_token_ids_state.token_ids.gpu | ||||||
| topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reverting the cast to
Suggested change
|
||||||
|
|
||||||
| num_cols = max(num_logprobs, max_per_req_token_ids) | ||||||
| logprob_token_ids = sampled_token_ids.new_zeros((batch_size, 1 + num_cols)) | ||||||
|
|
@@ -141,8 +137,8 @@ def compute_topk_logprobs( | |||||
| valid_mask, | ||||||
| valid_mask.stride(0), | ||||||
| sampled_token_ids, | ||||||
| topk_token_ids, | ||||||
| topk_token_ids.stride(0), | ||||||
| topk_indices if topk_indices is not None else logprob_token_ids, | ||||||
| topk_indices.stride(0) if topk_indices is not None else 0, | ||||||
| expanded_idx_mapping, | ||||||
| logprob_token_ids_state.num_token_ids.gpu, | ||||||
| logprob_token_ids_state.token_ids.gpu, | ||||||
|
|
@@ -206,12 +202,14 @@ def _fill_logprob_token_ids_kernel( | |||||
| # Override topk with per-request custom tokens. | ||||||
| src = per_req_token_ids_ptr + req_state_idx * per_req_token_ids_stride | ||||||
| valid = col < num_custom | ||||||
| # per_req_token_ids is int32; output is int64. | ||||||
| tokens = tl.load(src + col, mask=valid, other=0).to(tl.int64) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The explicit cast to
Suggested change
|
||||||
| else: | ||||||
| # Fill with topk indices (no-op when NUM_TOPK == 0). | ||||||
| src = topk_indices_ptr + batch_idx * topk_indices_stride | ||||||
| valid = col < NUM_TOPK | ||||||
| tokens = tl.load(src + col, mask=valid, other=0) | ||||||
|
|
||||||
| tokens = tl.load(src + col, mask=valid, other=0).to(tl.int64) | ||||||
| tl.store(tid_base + col, tokens, mask=valid) | ||||||
| tl.store(mask_base + col, tl.full([PADDED_COLS], 1, tl.int1), mask=valid) | ||||||
|
|
||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The removal of the validation check between
logprobsandlogprob_token_idsre-introduces a state where inconsistent parameters are accepted but not correctly handled by the sampler. Since the V1 sampler uses anif/elselogic that prioritizes custom tokens, providing both with different lengths will lead to unexpected output formats or partially filled rows. This validation should be restored to maintain API integrity.