247 changes: 141 additions & 106 deletions tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -648,6 +648,13 @@ class BeamHistory:
cum_logprobs: torch.Tensor | None = None


class SamplingRequestsMetadata(NamedTuple):
req_num_generated_tokens: torch.Tensor
req_num_beams: torch.Tensor
req_num_steps: torch.Tensor
req_offsets: torch.Tensor
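
For intuition, here is a minimal standalone sketch (illustrative, not part of the patch) of how these four fields relate, assuming a toy batch of one context request plus one generation request carrying a single draft token; the exclusive prefix sum mirrors what calculate_request_offsets is assumed to compute further below:

import torch
from typing import NamedTuple

class SamplingRequestsMetadata(NamedTuple):
    req_num_generated_tokens: torch.Tensor
    req_num_beams: torch.Tensor
    req_num_steps: torch.Tensor
    req_offsets: torch.Tensor

# Toy batch: one context request (1 input beam), one generation request
# with 1 draft token (hence 2 steps: draft + target).
req_num_steps = torch.tensor([1, 2], dtype=torch.int32)
req_num_beams = torch.tensor([1, 1], dtype=torch.int32)
req_num_generated_tokens = req_num_steps * req_num_beams  # -> [1, 2]
# Exclusive prefix sum: each request's start offset on the packed token axis.
req_offsets = torch.cumsum(req_num_generated_tokens, dim=0) - req_num_generated_tokens
meta = SamplingRequestsMetadata(req_num_generated_tokens, req_num_beams,
                                req_num_steps, req_offsets)
assert meta.req_offsets.tolist() == [0, 1]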


@dataclass(kw_only=True)
class SampleStateTensorsHostTorch(SampleStateTensors):
finish_reasons: torch.Tensor
@@ -1672,8 +1679,8 @@ def update_requests(
self.handle_logprobs(req, count=processed)
req.py_decoding_iter += 1

def return_log_probs(self, scheduled_requests: ScheduledRequests) -> bool:
return any(req.py_return_log_probs for req in scheduled_requests.all_requests())
def _return_log_probs(self, requests: list[LlmRequest]) -> bool:
return any(req.py_return_log_probs for req in requests)

@override
@torch.inference_mode()
@@ -1693,7 +1700,6 @@ def sample_async(
self.setup_sampler_step(scheduled_requests)
requests = scheduled_requests.all_requests()
new_tokens = self.store.new_tokens
return_log_probs = self.return_log_probs(scheduled_requests)
seq_slots_host = torch.tensor(
[r.py_seq_slot for r in requests],
dtype=torch.int64, # for index_fill_
@@ -1713,9 +1719,7 @@
new_tokens,
num_context_logits_prefix_sum,
seq_slots=seq_slots_host,
return_log_probs=return_log_probs,
seq_lens=seq_lens_host,
resource_manager=resource_manager,
)

finish_reasons = self.store.finish_reasons
@@ -1850,6 +1854,65 @@ def provision_bias_index() -> int:
# sharing).
logits[logits_bias_mask_cuda] += biases_tensor_cuda
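
The masked in-place add above is the standard PyTorch boolean-indexing pattern; a tiny self-contained sketch with made-up shapes and hypothetical names (not from the patch):

import torch

# Toy shapes: 4 logits rows, vocab size 6; rows 0 and 2 request a bias.
logits = torch.zeros(4, 6)
logits_bias_mask = torch.tensor([True, False, True, False])
biases_tensor = torch.stack([torch.full((6,), 1.0), torch.full((6,), 2.0)])
# Boolean indexing selects the masked rows in order; '+=' adds one bias row to each.
logits[logits_bias_mask] += biases_tensor
assert logits[0, 0] == 1.0 and logits[2, 0] == 2.0 and logits[1].sum() == 0.0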

def _handle_log_probs(
self,
requests: list[LlmRequest],
logits_cuda: torch.Tensor,
*,
logits_cuda_indexer: _PackedStepIndexer,
req_num_generated_tokens: torch.Tensor,
) -> None:
"""Handle top-k logprobs.

This is done outside the sampling loop because the returned logprobs are specified
not to reflect temperature scaling, top-k/top-p masking, etc.
"""
if self._return_log_probs(requests):
assert logits_cuda.dim() == 2, "logits should be 2D"

logprobs_req_indices = [
req_id for req_id, req in enumerate(requests) if req.py_num_logprobs
]
logprobs_logit_indices = logits_cuda_indexer[logprobs_req_indices]
logprobs_logit_indices_cuda = logprobs_logit_indices.to(
device=logits_cuda.device, non_blocking=True
)
logprobs_cuda = F.log_softmax(
logits_cuda[logprobs_logit_indices_cuda].to(dtype=torch.float32, non_blocking=True),
dim=-1,
)
topk_vals_cuda, topk_indices_cuda = torch.topk(
logprobs_cuda, k=max(req.py_num_logprobs for req in requests), dim=-1
)
# Use a single D2H copy to reduce overheads
topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True)
topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=True)
topk_vals.copy_(topk_vals_cuda, non_blocking=True)
topk_indices.copy_(topk_indices_cuda, non_blocking=True)
current_offset = 0
for req_id, steps in zip(
logprobs_req_indices, req_num_generated_tokens[logprobs_req_indices].tolist()
):
req = requests[req_id]
next_offset = current_offset + steps
# NB: Assigning views on memory which is being filled asynchronously
req.py_topk_logprobs_vals = topk_vals[
current_offset:next_offset, : req.py_num_logprobs
]
req.py_topk_logprobs_indices = topk_indices[
current_offset:next_offset, : req.py_num_logprobs
]

# context requests do not have multiple input beams, but they need multiple output beams
if req.is_context_init_state:
req.py_topk_logprobs_vals = req.py_topk_logprobs_vals.expand(
req.sampling_config.beam_width, -1
)
req.py_topk_logprobs_indices = req.py_topk_logprobs_indices.expand(
req.sampling_config.beam_width, -1
)
current_offset = next_offset
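
The single-copy D2H strategy above can be isolated into a short sketch (toy shapes, hypothetical names; requires a CUDA device): stage all top-k results in pinned host buffers via non-blocking copies, then hand out per-request views that become valid once the copy stream catches up.

import torch
import torch.nn.functional as F

logits_cuda = torch.randn(3, 1000, device="cuda")  # 3 packed steps, toy vocab
logprobs_cuda = F.log_softmax(logits_cuda.float(), dim=-1)
topk_vals_cuda, topk_indices_cuda = torch.topk(logprobs_cuda, k=8, dim=-1)

# One pinned host buffer per tensor and a single non-blocking copy each,
# instead of many small synchronous per-request transfers.
topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True)
topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=True)
topk_vals.copy_(topk_vals_cuda, non_blocking=True)
topk_indices.copy_(topk_indices_cuda, non_blocking=True)

# Views may be taken immediately, but reading them is only safe after the
# copies complete (forced here with a synchronize for demonstration).
per_req_view = topk_vals[0:2, :4]
torch.cuda.current_stream().synchronize()
print(per_req_view.shape)  # torch.Size([2, 4])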

@nvtx_range("sample_batched_by_strategy")
@torch.inference_mode()
def _sample_batched_by_strategy(
@@ -2071,11 +2134,46 @@ def _select_generated_logits(
scheduled_requests: ScheduledRequests,
raw_logits_cuda: torch.Tensor,
*,
req_num_generation_steps: torch.Tensor,
num_context_logits_prefix_sum: list[int],
generation_requests_total_steps: int,
num_logits_to_keep: int,
) -> torch.Tensor:
) -> tuple[SamplingRequestsMetadata, torch.Tensor]:
requests = scheduled_requests.all_requests()

req_num_generation_steps_list = [1 + get_draft_token_length(req) for req in requests]
req_num_generation_steps = torch.tensor(
req_num_generation_steps_list, dtype=torch.int32, pin_memory=True
)

# context requests do not have multiple beams yet, so beam width may differ in mixed batches
req_num_beams_list = [
req.sampling_config.beam_width if not req.is_context_init_state else 1
for req in requests
]
req_num_beams = torch.tensor(req_num_beams_list, dtype=torch.int32, pin_memory=True)

req_num_generated_tokens = req_num_generation_steps * req_num_beams
# NB: These offsets consider generated tokens _only_ (draft and target, but not context);
# the context-token logits are filtered out below.
req_offsets, sum_num_generated_tokens = _PackedStepIndexer.calculate_request_offsets(
req_num_generated_tokens, pin_memory=True
)

generation_requests_total_steps = (
# NB: requests == scheduled_requests.context_requests + scheduled_requests.generation_requests
sum_num_generated_tokens
- cast(int, req_offsets[len(scheduled_requests.context_requests)].item())
if scheduled_requests.generation_requests
else 0
)

sampling_requests_metadata = SamplingRequestsMetadata(
req_num_generated_tokens=req_num_generated_tokens,
req_num_beams=req_num_beams,
req_num_steps=req_num_generation_steps,
req_offsets=req_offsets,
)

num_logits_to_keep = sum_num_generated_tokens

# Downstream code expects raw_logits to contain only the generated logits. If returning
# context logits was requested, the context-token logits are also present here, so select
# only the generated ones.
#
@@ -2085,22 +2183,22 @@
assert (
len(num_context_logits_prefix_sum) == len(scheduled_requests.context_requests) + 1
)
req_num_generation_steps_cuda = req_num_generation_steps.to(
req_num_generated_tokens_cuda = req_num_generated_tokens.to(
raw_logits_cuda.device, non_blocking=True
)
context_req_offsets_cuda = torch.tensor(
num_context_logits_prefix_sum, dtype=torch.int32, pin_memory=True
).to(device=raw_logits_cuda.device, non_blocking=True)

# Since the goal is to keep the req_num_steps[i] last tokens for each requests[i],
# only end-offsets of the token storage locations matter.
next_context_req_offsets_cuda = context_req_offsets_cuda.roll(
-1
) # trailing '0' is overwritten below
# Since logits for generation requests are densely packed, cover them all by a single
# fictitious entry in 'context_req_offsets_cuda'.
if scheduled_requests.generation_requests:
req_num_steps_fictitious_cuda = req_num_generation_steps_cuda[
# Since the goal is to keep the req_num_steps[i] last tokens for each requests[i],
# only end-offsets of the token storage locations matter.
next_context_req_offsets_cuda = context_req_offsets_cuda.roll(
-1
) # trailing '0' is overwritten below
# Since logits for generation requests are densely packed, cover them all by a single
# fictitious entry in 'context_req_offsets_cuda'.
req_num_steps_fictitious_cuda = req_num_generated_tokens_cuda[
: (len(scheduled_requests.context_requests) + 1)
].clone()
req_num_steps_fictitious_cuda[-1].fill_(generation_requests_total_steps)
@@ -2109,10 +2207,12 @@
non_blocking=True,
)
else:
req_num_steps_fictitious_cuda = req_num_generation_steps_cuda[
req_num_steps_fictitious_cuda = req_num_generated_tokens_cuda[
: len(scheduled_requests.context_requests)
]
next_context_req_offsets_cuda = next_context_req_offsets_cuda[:-1]
# Since the goal is to keep the req_num_steps[i] last tokens for each requests[i],
# only end-offsets of the token storage locations matter.
next_context_req_offsets_cuda = context_req_offsets_cuda[1:]

# Now, the generated tokens for context request i are at indices
# range(next_context_req_offsets_cuda[i] - req_num_steps_fictitious_cuda[i],
Expand All @@ -2126,7 +2226,10 @@ def _select_generated_logits(
)

raw_logits_cuda = raw_logits_cuda[indices_to_keep_cuda]
return raw_logits_cuda

logits_cuda = raw_logits_cuda[:num_logits_to_keep]

return sampling_requests_metadata, logits_cuda
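
To make the roll-based index construction concrete, here is a standalone numeric sketch (toy values; a plain Python loop stands in for the vectorized index construction the real code presumably uses): two context requests with 5 and 3 context logits keep only their last generated logit each, and a fictitious trailing entry covers the 4 densely packed generation logits.

import torch

num_context_logits_prefix_sum = [0, 5, 8]  # context requests A (5 logits) and B (3)
raw_logits = torch.arange(12, dtype=torch.float32).unsqueeze(-1)  # vocab size 1
req_num_generated_tokens = torch.tensor([1, 1], dtype=torch.int32)
generation_requests_total_steps = 4

context_req_offsets = torch.tensor(num_context_logits_prefix_sum, dtype=torch.int32)
next_offsets = context_req_offsets.roll(-1)  # trailing '0' is overwritten below
next_offsets[-1] = context_req_offsets[-1] + generation_requests_total_steps
steps = torch.cat([req_num_generated_tokens,
                   torch.tensor([generation_requests_total_steps], dtype=torch.int32)])

# Keep range(end - steps[i], end) for every (fictitious) request i.
indices_to_keep = torch.cat([torch.arange(int(e) - int(s), int(e))
                             for e, s in zip(next_offsets, steps)])
assert indices_to_keep.tolist() == [4, 7, 8, 9, 10, 11]
kept_logits = raw_logits[indices_to_keep]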

@staticmethod
def _longest_stop_word_len(requests: Iterable[LlmRequest]) -> int:
@@ -2395,113 +2498,45 @@ def _process_requests(
*,
seq_slots: torch.Tensor,
seq_lens: torch.Tensor | None = None,
return_log_probs: bool,
resource_manager: Optional[ResourceManager] = None,
) -> torch.Tensor:
seq_slots = seq_slots.to(dtype=torch.int32) # int32 suffices here

raw_logits_cuda = model_outputs["logits"]

requests = scheduled_requests.all_requests()
cuda_device = raw_logits_cuda.device
req_num_steps_list = [1 + get_draft_token_length(req) for req in requests]
req_num_steps = torch.tensor(req_num_steps_list, dtype=torch.int32, pin_memory=True)

# context requests do not have multiple beams yet, so beam width may differ in mixed batches
req_num_beams_list = [
req.sampling_config.beam_width if not req.is_context_init_state else 1
for req in requests
]
req_num_beams = torch.tensor(req_num_beams_list, dtype=torch.int32, pin_memory=True)

req_num_generated_tokens = req_num_steps * req_num_beams
# NB: These offsets consider generated tokens _only_ (draft and target, but not context)
# and are thus only correct after _select_generated_logits() below.
#
req_offsets, sum_num_generated_tokens = _PackedStepIndexer.calculate_request_offsets(
req_num_generated_tokens, pin_memory=True
)

raw_logits_cuda = self._select_generated_logits(
sampling_requests_metadata, logits_cuda = self._select_generated_logits(
scheduled_requests,
raw_logits_cuda,
req_num_generation_steps=req_num_generated_tokens,
num_context_logits_prefix_sum=num_context_logits_prefix_sum,
generation_requests_total_steps=(
# NB: requests == scheduled_requests.context_requests + scheduled_requests.generation_requests
sum_num_generated_tokens
- cast(int, req_offsets[len(scheduled_requests.context_requests)].item())
if scheduled_requests.generation_requests
else 0
),
num_logits_to_keep=sum_num_generated_tokens,
)

# Handle embedding bias
logits_cuda = raw_logits_cuda[:sum_num_generated_tokens]
self._apply_embedding_bias(logits_cuda, requests, req_num_steps)
self._apply_embedding_bias(logits_cuda, requests, sampling_requests_metadata.req_num_steps)

logits_cuda = self._apply_min_length_penalty(
logits_cuda, requests, req_num_steps, req_num_beams
logits_cuda,
requests,
sampling_requests_metadata.req_num_steps,
sampling_requests_metadata.req_num_beams,
)

# Indexer for accessing tokens in 'logits_cuda', corresponding to the
# requests in 'requests'.
steps_dim_size = new_tokens_cuda.size(0)
logits_cuda_indexer = _PackedStepIndexer(
num_steps=req_num_generated_tokens,
num_steps=sampling_requests_metadata.req_num_generated_tokens,
max_steps=steps_dim_size * self.max_beam_width,
req_offsets=req_offsets,
req_offsets=sampling_requests_metadata.req_offsets,
)
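
As used here, _PackedStepIndexer's contract can be approximated by this hypothetical simplification (not the real class): indexing with a list of request ids yields the flat positions of those requests' rows in the packed logits tensor.

import torch

num_steps = torch.tensor([1, 2, 1], dtype=torch.int32)    # rows per request
req_offsets = torch.tensor([0, 1, 3], dtype=torch.int32)  # exclusive prefix sum

def packed_rows(req_ids):
    # Flat row positions of the selected requests, in the order given.
    return torch.cat([torch.arange(int(req_offsets[r]),
                                   int(req_offsets[r]) + int(num_steps[r]))
                      for r in req_ids])

assert packed_rows([1, 2]).tolist() == [1, 2, 3]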

# Handle top-k logprobs. This is done outside the sampling loop,
# because the returned logprobs are specified to not reflect temperature scaling,
# top-k/top-p masking, etc.
if return_log_probs:
assert logits_cuda.dim() == 2, "logits should be 2D"

logprobs_req_indices = [
req_id for req_id, req in enumerate(requests) if req.py_num_logprobs
]
logprobs_logit_indices = logits_cuda_indexer[logprobs_req_indices]
logprobs_logit_indices_cuda = logprobs_logit_indices.to(
device=logits_cuda.device, non_blocking=True
)
logprobs_cuda = F.log_softmax(
logits_cuda[logprobs_logit_indices_cuda].to(dtype=torch.float32, non_blocking=True),
dim=-1,
)
topk_vals_cuda, topk_indices_cuda = torch.topk(
logprobs_cuda, k=max(req.py_num_logprobs for req in requests), dim=-1
)
# Use a single D2H copy to reduce overheads
topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True)
topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=True)
topk_vals.copy_(topk_vals_cuda, non_blocking=True)
topk_indices.copy_(topk_indices_cuda, non_blocking=True)
current_offset = 0
for req_id, steps in zip(
logprobs_req_indices, req_num_generated_tokens[logprobs_req_indices].tolist()
):
req = requests[req_id]
next_offset = current_offset + steps
# NB: Assigning views on memory which is being filled asynchronously
req.py_topk_logprobs_vals = topk_vals[
current_offset:next_offset, : req.py_num_logprobs
]
req.py_topk_logprobs_indices = topk_indices[
current_offset:next_offset, : req.py_num_logprobs
]

# context requests do not have multiple input beams, but they need multiple output beams
if req.is_context_init_state:
req.py_topk_logprobs_vals = req.py_topk_logprobs_vals.expand(
req.sampling_config.beam_width, -1
)
req.py_topk_logprobs_indices = req.py_topk_logprobs_indices.expand(
req.sampling_config.beam_width, -1
)
current_offset = next_offset
self._handle_log_probs(
requests,
logits_cuda,
logits_cuda_indexer=logits_cuda_indexer,
req_num_generated_tokens=sampling_requests_metadata.req_num_generated_tokens,
)

# Perform sampling in batches
batched_sampling_result = self._sample_batched_by_strategy(
@@ -2510,18 +2545,18 @@
model_outputs,
cuda_device=cuda_device,
logits_cuda_indexer=logits_cuda_indexer,
req_offsets=req_offsets,
req_offsets=sampling_requests_metadata.req_offsets,
seq_slots=seq_slots,
seq_lens=seq_lens,
req_num_generated_tokens=req_num_generated_tokens,
req_num_generated_tokens=sampling_requests_metadata.req_num_generated_tokens,
token_dtype=new_tokens_cuda.dtype,
)

# Fill results into output buffers
new_tokens_host = self._unbatch_sampling_results(
batched_sampling_result,
new_tokens_cuda=new_tokens_cuda,
req_num_generated_tokens=req_num_generated_tokens,
req_num_generated_tokens=sampling_requests_metadata.req_num_generated_tokens,
seq_slots=seq_slots,
)
