diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 3b8e42d2320..ce5a19be881 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -648,6 +648,13 @@ class BeamHistory: cum_logprobs: torch.Tensor | None = None +class SamplingRequestsMetadata(NamedTuple): + req_num_generated_tokens: torch.Tensor + req_num_beams: torch.Tensor + req_num_steps: torch.Tensor + req_offsets: torch.Tensor + + @dataclass(kw_only=True) class SampleStateTensorsHostTorch(SampleStateTensors): finish_reasons: torch.Tensor @@ -1672,8 +1679,8 @@ def update_requests( self.handle_logprobs(req, count=processed) req.py_decoding_iter += 1 - def return_log_probs(self, scheduled_requests: ScheduledRequests) -> bool: - return any(req.py_return_log_probs for req in scheduled_requests.all_requests()) + def _return_log_probs(self, requests: list[LlmRequest]) -> bool: + return any(req.py_return_log_probs for req in requests) @override @torch.inference_mode() @@ -1693,7 +1700,6 @@ def sample_async( self.setup_sampler_step(scheduled_requests) requests = scheduled_requests.all_requests() new_tokens = self.store.new_tokens - return_log_probs = self.return_log_probs(scheduled_requests) seq_slots_host = torch.tensor( [r.py_seq_slot for r in requests], dtype=torch.int64, # for index_fill_ @@ -1713,9 +1719,7 @@ def sample_async( new_tokens, num_context_logits_prefix_sum, seq_slots=seq_slots_host, - return_log_probs=return_log_probs, seq_lens=seq_lens_host, - resource_manager=resource_manager, ) finish_reasons = self.store.finish_reasons @@ -1850,6 +1854,65 @@ def provision_bias_index() -> int: # sharing). logits[logits_bias_mask_cuda] += biases_tensor_cuda + def _handle_log_probs( + self, + requests: list[LlmRequest], + logits_cuda: torch.Tensor, + *, + logits_cuda_indexer: _PackedStepIndexer, + req_num_generated_tokens: torch.Tensor, + ) -> None: + """Handle top-k logprobs. + + This is done outside the sampling loop, because the returned logprobs are specified to not reflect + temperature scaling, top-k/top-p masking, etc. 
+ """ + if self._return_log_probs(requests): + assert logits_cuda.dim() == 2, "logits should be 2D" + + logprobs_req_indices = [ + req_id for req_id, req in enumerate(requests) if req.py_num_logprobs + ] + logprobs_logit_indices = logits_cuda_indexer[logprobs_req_indices] + logprobs_logit_indices_cuda = logprobs_logit_indices.to( + device=logits_cuda.device, non_blocking=True + ) + logprobs_cuda = F.log_softmax( + logits_cuda[logprobs_logit_indices_cuda].to(dtype=torch.float32, non_blocking=True), + dim=-1, + ) + topk_vals_cuda, topk_indices_cuda = torch.topk( + logprobs_cuda, k=max(req.py_num_logprobs for req in requests), dim=-1 + ) + # Use a single D2H copy to reduce overheads + topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True) + topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=True) + topk_vals.copy_(topk_vals_cuda, non_blocking=True) + topk_indices.copy_(topk_indices_cuda, non_blocking=True) + current_offset = 0 + for req_id, steps in zip( + logprobs_req_indices, req_num_generated_tokens[logprobs_req_indices].tolist() + ): + req = requests[req_id] + next_offset = current_offset + steps + # NB: Assigning views on memory which is being filled asynchronously + req.py_topk_logprobs_vals = topk_vals[ + current_offset:next_offset, : req.py_num_logprobs + ] + req.py_topk_logprobs_indices = topk_indices[ + current_offset:next_offset, : req.py_num_logprobs + ] + + # context requests do not have multiple input beams, but they need multiple output beams + if req.is_context_init_state: + req.py_topk_logprobs_vals = req.py_topk_logprobs_vals.expand( + req.sampling_config.beam_width, -1 + ) + req.py_topk_logprobs_indices = req.py_topk_logprobs_indices.expand( + req.sampling_config.beam_width, -1 + ) + current_offset = next_offset + @nvtx_range("sample_batched_by_strategy") @torch.inference_mode() def _sample_batched_by_strategy( @@ -2071,11 +2134,46 @@ def _select_generated_logits( scheduled_requests: ScheduledRequests, raw_logits_cuda: torch.Tensor, *, - req_num_generation_steps: torch.Tensor, num_context_logits_prefix_sum: list[int], - generation_requests_total_steps: int, - num_logits_to_keep: int, - ) -> torch.Tensor: + ) -> tuple[SamplingRequestsMetadata, torch.Tensor]: + requests = scheduled_requests.all_requests() + + req_num_generation_steps_list = [1 + get_draft_token_length(req) for req in requests] + req_num_generation_steps = torch.tensor( + req_num_generation_steps_list, dtype=torch.int32, pin_memory=True + ) + + # context requests do not have multiple beams yet, so beam width may differ in mixed batches + req_num_beams_list = [ + req.sampling_config.beam_width if not req.is_context_init_state else 1 + for req in requests + ] + req_num_beams = torch.tensor(req_num_beams_list, dtype=torch.int32, pin_memory=True) + + req_num_generated_tokens = req_num_generation_steps * req_num_beams + # NB: These offsets consider generated tokens _only_ (draft and target, but not context). + # Filter out the context tokens below. 
+ req_offsets, sum_num_generated_tokens = _PackedStepIndexer.calculate_request_offsets( + req_num_generated_tokens, pin_memory=True + ) + + generation_requests_total_steps = ( + # NB: requests == scheduled_requests.context_requests + scheduled_requests.generation_requests + sum_num_generated_tokens + - cast(int, req_offsets[len(scheduled_requests.context_requests)].item()) + if scheduled_requests.generation_requests + else 0 + ) + + sampling_requests_metadata = SamplingRequestsMetadata( + req_num_generated_tokens=req_num_generated_tokens, + req_num_beams=req_num_beams, + req_num_steps=req_num_generation_steps, + req_offsets=req_offsets, + ) + + num_logits_to_keep = sum_num_generated_tokens + # raw_logits should contain only the generated logits. # If return context logits is requested, select only the generated logits. # @@ -2085,22 +2183,22 @@ def _select_generated_logits( assert ( len(num_context_logits_prefix_sum) == len(scheduled_requests.context_requests) + 1 ) - req_num_generation_steps_cuda = req_num_generation_steps.to( + req_num_generated_tokens_cuda = req_num_generated_tokens.to( raw_logits_cuda.device, non_blocking=True ) context_req_offsets_cuda = torch.tensor( num_context_logits_prefix_sum, dtype=torch.int32, pin_memory=True ).to(device=raw_logits_cuda.device, non_blocking=True) - # Since the goal is to keep the req_num_steps[i] last tokens for each requests[i], - # only end-offsets of the token storage locations matter. - next_context_req_offsets_cuda = context_req_offsets_cuda.roll( - -1 - ) # trailing '0' is overwritten below - # Since logits for generation requests are densely packed, cover them all by a single - # fictituous entry in 'context_req_offsets_cuda'. if scheduled_requests.generation_requests: - req_num_steps_fictitious_cuda = req_num_generation_steps_cuda[ + # Since the goal is to keep the req_num_steps[i] last tokens for each requests[i], + # only end-offsets of the token storage locations matter. + next_context_req_offsets_cuda = context_req_offsets_cuda.roll( + -1 + ) # trailing '0' is overwritten below + # Since logits for generation requests are densely packed, cover them all by a single + # fictituous entry in 'context_req_offsets_cuda'. + req_num_steps_fictitious_cuda = req_num_generated_tokens_cuda[ : (len(scheduled_requests.context_requests) + 1) ].clone() req_num_steps_fictitious_cuda[-1].fill_(generation_requests_total_steps) @@ -2109,10 +2207,12 @@ def _select_generated_logits( non_blocking=True, ) else: - req_num_steps_fictitious_cuda = req_num_generation_steps_cuda[ + req_num_steps_fictitious_cuda = req_num_generated_tokens_cuda[ : len(scheduled_requests.context_requests) ] - next_context_req_offsets_cuda = next_context_req_offsets_cuda[:-1] + # Since the goal is to keep the req_num_steps[i] last tokens for each requests[i], + # only end-offsets of the token storage locations matter. 
+ next_context_req_offsets_cuda = context_req_offsets_cuda[1:] # Now, the generated tokens for context request i are at indices # range(next_context_req_offsets_cuda[i] - req_num_steps_fictitious_cuda[i], @@ -2126,7 +2226,10 @@ def _select_generated_logits( ) raw_logits_cuda = raw_logits_cuda[indices_to_keep_cuda] - return raw_logits_cuda + + logits_cuda = raw_logits_cuda[:num_logits_to_keep] + + return sampling_requests_metadata, logits_cuda @staticmethod def _longest_stop_word_len(requests: Iterable[LlmRequest]) -> int: @@ -2395,8 +2498,6 @@ def _process_requests( *, seq_slots: torch.Tensor, seq_lens: torch.Tensor | None = None, - return_log_probs: bool, - resource_manager: Optional[ResourceManager] = None, ) -> torch.Tensor: seq_slots = seq_slots.to(dtype=torch.int32) # int32 suffices here @@ -2404,104 +2505,38 @@ def _process_requests( requests = scheduled_requests.all_requests() cuda_device = raw_logits_cuda.device - req_num_steps_list = [1 + get_draft_token_length(req) for req in requests] - req_num_steps = torch.tensor(req_num_steps_list, dtype=torch.int32, pin_memory=True) - - # context requests do not have multiple beams yet, so beam width may differ in mixed batches - req_num_beams_list = [ - req.sampling_config.beam_width if not req.is_context_init_state else 1 - for req in requests - ] - req_num_beams = torch.tensor(req_num_beams_list, dtype=torch.int32, pin_memory=True) - - req_num_generated_tokens = req_num_steps * req_num_beams - # NB: These offsets consider generated tokens _only_ (draft and target, but not context) - # and are thus only correct after _select_generated_logits() below. - # - req_offsets, sum_num_generated_tokens = _PackedStepIndexer.calculate_request_offsets( - req_num_generated_tokens, pin_memory=True - ) - raw_logits_cuda = self._select_generated_logits( + sampling_requests_metadata, logits_cuda = self._select_generated_logits( scheduled_requests, raw_logits_cuda, - req_num_generation_steps=req_num_generated_tokens, num_context_logits_prefix_sum=num_context_logits_prefix_sum, - generation_requests_total_steps=( - # NB: requests == scheduled_requests.context_requests + scheduled_requests.generation_requests - sum_num_generated_tokens - - cast(int, req_offsets[len(scheduled_requests.context_requests)].item()) - if scheduled_requests.generation_requests - else 0 - ), - num_logits_to_keep=sum_num_generated_tokens, ) # Handle embedding bias - logits_cuda = raw_logits_cuda[:sum_num_generated_tokens] - self._apply_embedding_bias(logits_cuda, requests, req_num_steps) + self._apply_embedding_bias(logits_cuda, requests, sampling_requests_metadata.req_num_steps) logits_cuda = self._apply_min_length_penalty( - logits_cuda, requests, req_num_steps, req_num_beams + logits_cuda, + requests, + sampling_requests_metadata.req_num_steps, + sampling_requests_metadata.req_num_beams, ) # Indexer for accessing tokens in 'logits_cuda', corresponding to the # requests in 'requests'. steps_dim_size = new_tokens_cuda.size(0) logits_cuda_indexer = _PackedStepIndexer( - num_steps=req_num_generated_tokens, + num_steps=sampling_requests_metadata.req_num_generated_tokens, max_steps=steps_dim_size * self.max_beam_width, - req_offsets=req_offsets, + req_offsets=sampling_requests_metadata.req_offsets, ) - # Handle top-k logprobs. This is done outside the sampling loop, - # because the returned logprobs are specified to not reflect temperature scaling, - # top-k/top-p masking, etc. 
- if return_log_probs: - assert logits_cuda.dim() == 2, "logits should be 2D" - - logprobs_req_indices = [ - req_id for req_id, req in enumerate(requests) if req.py_num_logprobs - ] - logprobs_logit_indices = logits_cuda_indexer[logprobs_req_indices] - logprobs_logit_indices_cuda = logprobs_logit_indices.to( - device=logits_cuda.device, non_blocking=True - ) - logprobs_cuda = F.log_softmax( - logits_cuda[logprobs_logit_indices_cuda].to(dtype=torch.float32, non_blocking=True), - dim=-1, - ) - topk_vals_cuda, topk_indices_cuda = torch.topk( - logprobs_cuda, k=max(req.py_num_logprobs for req in requests), dim=-1 - ) - # Use a single D2H copy to reduce overheads - topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True) - topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=True) - topk_vals.copy_(topk_vals_cuda, non_blocking=True) - topk_indices.copy_(topk_indices_cuda, non_blocking=True) - current_offset = 0 - for req_id, steps in zip( - logprobs_req_indices, req_num_generated_tokens[logprobs_req_indices].tolist() - ): - req = requests[req_id] - next_offset = current_offset + steps - # NB: Assigning views on memory which is being filled asynchronously - req.py_topk_logprobs_vals = topk_vals[ - current_offset:next_offset, : req.py_num_logprobs - ] - req.py_topk_logprobs_indices = topk_indices[ - current_offset:next_offset, : req.py_num_logprobs - ] - - # context requests do not have multiple input beams, but they need multiple output beams - if req.is_context_init_state: - req.py_topk_logprobs_vals = req.py_topk_logprobs_vals.expand( - req.sampling_config.beam_width, -1 - ) - req.py_topk_logprobs_indices = req.py_topk_logprobs_indices.expand( - req.sampling_config.beam_width, -1 - ) - current_offset = next_offset + self._handle_log_probs( + requests, + logits_cuda, + logits_cuda_indexer=logits_cuda_indexer, + req_num_generated_tokens=sampling_requests_metadata.req_num_generated_tokens, + ) # Perform sampling in batches batched_sampling_result = self._sample_batched_by_strategy( @@ -2510,10 +2545,10 @@ def _process_requests( model_outputs, cuda_device=cuda_device, logits_cuda_indexer=logits_cuda_indexer, - req_offsets=req_offsets, + req_offsets=sampling_requests_metadata.req_offsets, seq_slots=seq_slots, seq_lens=seq_lens, - req_num_generated_tokens=req_num_generated_tokens, + req_num_generated_tokens=sampling_requests_metadata.req_num_generated_tokens, token_dtype=new_tokens_cuda.dtype, ) @@ -2521,7 +2556,7 @@ def _process_requests( new_tokens_host = self._unbatch_sampling_results( batched_sampling_result, new_tokens_cuda=new_tokens_cuda, - req_num_generated_tokens=req_num_generated_tokens, + req_num_generated_tokens=sampling_requests_metadata.req_num_generated_tokens, seq_slots=seq_slots, ) diff --git a/tests/unittest/_torch/executor/test_overlap_scheduler.py b/tests/unittest/_torch/executor/test_overlap_scheduler.py index b825907c615..bb687f2fa37 100644 --- a/tests/unittest/_torch/executor/test_overlap_scheduler.py +++ b/tests/unittest/_torch/executor/test_overlap_scheduler.py @@ -60,28 +60,26 @@ def test_overlap_scheduler_consistency(model_path, test_case, sampler_type): use_beam_search=True) # Test with overlap scheduler enabled - llm = create_llm(model_path, - disable_overlap_scheduler=False, - sampler_type=sampler_type) - outputs_with_overlap = llm.generate(prompts, - sampling_params=sampling_config, - use_tqdm=True) - texts_with_overlap = [[ - completion.text for completion in request_output.outputs - ] for request_output in 
outputs_with_overlap] - llm.shutdown() + with create_llm(model_path, + disable_overlap_scheduler=False, + sampler_type=sampler_type) as llm: + outputs_with_overlap = llm.generate(prompts, + sampling_params=sampling_config, + use_tqdm=True) + texts_with_overlap = [[ + completion.text for completion in request_output.outputs + ] for request_output in outputs_with_overlap] # Test with overlap scheduler disabled - llm = create_llm(model_path, - disable_overlap_scheduler=True, - sampler_type=sampler_type) - outputs_without_overlap = llm.generate(prompts, - sampling_params=sampling_config, - use_tqdm=True) - texts_without_overlap = [[ - completion.text for completion in request_output.outputs - ] for request_output in outputs_without_overlap] - llm.shutdown() + with create_llm(model_path, + disable_overlap_scheduler=True, + sampler_type=sampler_type) as llm: + outputs_without_overlap = llm.generate(prompts, + sampling_params=sampling_config, + use_tqdm=True) + texts_without_overlap = [[ + completion.text for completion in request_output.outputs + ] for request_output in outputs_without_overlap] # Verify outputs are consistent for with_overlap, without_overlap in zip(texts_with_overlap, diff --git a/tests/unittest/_torch/sampler/test_torch_sampler.py b/tests/unittest/_torch/sampler/test_torch_sampler.py index e6530125dee..2daa357b54d 100644 --- a/tests/unittest/_torch/sampler/test_torch_sampler.py +++ b/tests/unittest/_torch/sampler/test_torch_sampler.py @@ -430,8 +430,15 @@ def test_select_generated_logits(draft_len: int, with_ctx: bool, with_gen: bool) @contextmanager def _test_runner(is_warmup: bool) -> Generator[Callable[[], None], None, None]: + draft_len_req1 = draft_len + draft_len_req2 = draft_len + 1 # test with different draft lens + class ContextRequestMock: - def __init__(self, return_context_logits: bool): + def __init__(self, is_last_context_chunk: bool, return_context_logits: bool): + self.is_context_init_state = True + self.is_last_context_chunk = is_last_context_chunk + self.py_draft_tokens = torch.tensor([], dtype=torch.int32, device=device) + self.sampling_config = SamplingConfig(beam_width=1) self._return_context_logits = return_context_logits @property @@ -439,7 +446,10 @@ def py_return_context_logits(self) -> bool: return self._return_context_logits class GenRequestMock: - pass + def __init__(self, draft_len: int): + self.is_context_init_state = False + self.py_draft_tokens = torch.empty(draft_len, dtype=torch.int32, device=device) + self.sampling_config = SamplingConfig(beam_width=1) class ScheduledRequestsMock: @property @@ -448,9 +458,24 @@ def context_requests(self) -> list[LlmRequest]: [ # NB: One request with py_return_context_logits is enough # to trigger tested code. - cast(LlmRequest, ContextRequestMock(True)), - cast(LlmRequest, ContextRequestMock(False)), - cast(LlmRequest, ContextRequestMock(True)), + cast( + LlmRequest, + ContextRequestMock( + is_last_context_chunk=True, return_context_logits=True + ), + ), + cast( + LlmRequest, + ContextRequestMock( + is_last_context_chunk=True, return_context_logits=False + ), + ), + cast( + LlmRequest, + ContextRequestMock( + is_last_context_chunk=True, return_context_logits=True + ), + ), ] if with_ctx else [] @@ -462,14 +487,18 @@ def generation_requests(self) -> list[LlmRequest]: # is not empty. 
return ( [ - cast(LlmRequest, GenRequestMock()), - cast(LlmRequest, GenRequestMock()), + cast(LlmRequest, GenRequestMock(draft_len=draft_len_req1)), + cast(LlmRequest, GenRequestMock(draft_len=draft_len_req2)), ] if with_gen else [] ) - vocab_size = 12 + def all_requests(self) -> list[LlmRequest]: + return self.context_requests + self.generation_requests + + expected_num_requests = with_ctx * 3 + with_gen * 2 + expected_req_num_beams = torch.tensor([1] * expected_num_requests, dtype=torch.int32) num_context_logits_prefix_sum = [ 0, @@ -483,9 +512,7 @@ def generation_requests(self) -> list[LlmRequest]: else [] ), ] - draft_len_req1 = draft_len - draft_len_req2 = draft_len + 1 # test with different draft lens - req_num_generation_steps = [ + expected_req_num_generation_steps = [ *( [ 1, # context req. 1 @@ -504,12 +531,20 @@ def generation_requests(self) -> list[LlmRequest]: else [] ), ] - req_num_generation_steps_tensor = torch.tensor(req_num_generation_steps, dtype=torch.int32) - num_logits_to_keep = cast(int, req_num_generation_steps_tensor.sum().item()) + expected_req_num_generation_steps_tensor = torch.tensor( + expected_req_num_generation_steps, dtype=torch.int32 + ) + + expected_req_offsets = torch.cumsum(expected_req_num_generation_steps_tensor, dim=0).roll(1) + expected_req_offsets[0] = 0 + + # num_logits_to_keep = cast(int, req_num_generation_steps_tensor.sum().item()) generation_requests_total_steps = (draft_len_req1 + 1) + ( draft_len_req2 + 1 ) # cf. req_num_generation_steps + vocab_size = 12 + num_total_steps = num_context_logits_prefix_sum[-1] + generation_requests_total_steps all_logits = torch.empty((num_total_steps, vocab_size)) @@ -537,8 +572,14 @@ def generation_requests(self) -> list[LlmRequest]: ), # gen logits from gen. req. 2 ] + expected_logits = all_logits[expected_logit_indices] + @dataclass class UutResult: + req_num_generated_tokens: torch.Tensor + req_num_beams: torch.Tensor + req_num_steps: torch.Tensor + req_offsets: torch.Tensor selected_logits: torch.Tensor @dataclass @@ -548,22 +589,36 @@ class UutResultWrapper: res = UutResultWrapper() def _uut(res=res): - selected_logits = TorchSampler._select_generated_logits( + ( + sampling_requests_metadata, + selected_logits, + ) = TorchSampler._select_generated_logits( cast(ScheduledRequests, ScheduledRequestsMock()), all_logits_cuda, - req_num_generation_steps=req_num_generation_steps_tensor, num_context_logits_prefix_sum=num_context_logits_prefix_sum, - generation_requests_total_steps=generation_requests_total_steps, - num_logits_to_keep=num_logits_to_keep, ) - res.result = UutResult(selected_logits=selected_logits) + res.result = UutResult( + req_num_generated_tokens=sampling_requests_metadata.req_num_generated_tokens, + req_num_beams=sampling_requests_metadata.req_num_beams, + req_num_steps=sampling_requests_metadata.req_num_steps, + req_offsets=sampling_requests_metadata.req_offsets, + selected_logits=selected_logits, + ) yield _uut - # Check logits + # Check results assert res.result is not None - selected_logits = res.result.selected_logits - torch.testing.assert_close(selected_logits.to("cpu"), all_logits[expected_logit_indices]) + + torch.testing.assert_close( + res.result.req_num_generated_tokens.to("cpu"), expected_req_num_generation_steps_tensor + ) + torch.testing.assert_close(res.result.req_num_beams.to("cpu"), expected_req_num_beams) + torch.testing.assert_close( + res.result.req_num_steps.to("cpu"), expected_req_num_generation_steps_tensor + ) + 
+ torch.testing.assert_close(res.result.req_offsets.to("cpu"), expected_req_offsets)
+ torch.testing.assert_close(res.result.selected_logits.to("cpu"), expected_logits)

 _run_test_with_warmup(_test_runner, max_sync_s=0.3)
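
Note on the packed offsets carried in SamplingRequestsMetadata: req_offsets is an exclusive prefix sum of req_num_generated_tokens (steps times beams), so request i's rows start at req_offsets[i] in the packed logits and new-token tensors, exactly as the test reconstructs with cumsum(...).roll(1). The following is a minimal standalone sketch of that bookkeeping, assuming beam width 1; calculate_request_offsets below is an illustrative stand-in, not the actual _PackedStepIndexer.calculate_request_offsets implementation (which also pins the host tensor).

import torch

def calculate_request_offsets(req_num_generated_tokens: torch.Tensor) -> tuple[torch.Tensor, int]:
    # Exclusive prefix sum: offsets[i] is the first packed row belonging to request i.
    offsets = torch.cumsum(req_num_generated_tokens, dim=0).roll(1)
    offsets[0] = 0
    return offsets, int(req_num_generated_tokens.sum().item())

# Example: two requests generating 1 and 3 tokens each (1 + draft_len steps, beam width 1).
req_num_steps = torch.tensor([1, 3], dtype=torch.int32)
req_num_beams = torch.tensor([1, 1], dtype=torch.int32)
req_offsets, sum_num_generated_tokens = calculate_request_offsets(req_num_steps * req_num_beams)
assert req_offsets.tolist() == [0, 1] and sum_num_generated_tokens == 4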
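
Note on _handle_log_probs: the returned logprobs are computed from the unmodified logits (no temperature scaling or top-k/top-p masking), one topk is taken over all selected rows with k equal to the batch maximum of py_num_logprobs, and the results are staged through pinned host buffers with a single non-blocking device-to-host copy per tensor before being sliced into per-request views. A simplified sketch of that pattern, assuming a CUDA device; the shapes, k, and the explicit stream synchronize are illustrative, not the sampler's actual code path.

import torch
import torch.nn.functional as F

logits_cuda = torch.randn(5, 32000, device="cuda")  # packed generated-token logits (rows x vocab)
k = 4                                               # stands in for max(req.py_num_logprobs)

# Log-probabilities and top-k on GPU, in float32 for numerical stability.
logprobs_cuda = F.log_softmax(logits_cuda.float(), dim=-1)
topk_vals_cuda, topk_indices_cuda = torch.topk(logprobs_cuda, k=k, dim=-1)

# One D2H copy per tensor instead of one per request; per-request results are
# later exposed as slices (views) into these pinned host buffers.
topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True)
topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=True)
topk_vals.copy_(topk_vals_cuda, non_blocking=True)
topk_indices.copy_(topk_indices_cuda, non_blocking=True)
torch.cuda.current_stream().synchronize()  # wait for the async copies before reading on host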