11 | 11 | import weakref |
12 | 12 | from collections import deque, namedtuple |
13 | 13 | from contextlib import contextmanager |
14 | | -from typing import Dict, List, Optional, Tuple, Union |
| 14 | +from typing import Dict, List, Optional, Union |
15 | 15 |
16 | 16 | import torch |
17 | 17 |
30 | 30 | from tensorrt_llm.logger import logger |
31 | 31 |
32 | 32 | from ..distributed import Distributed |
33 | | -from ..speculative.drafter import Drafter, create_drafter |
| 33 | +from ..speculative.drafter import Drafter |
34 | 34 | from .kv_cache_transceiver import KvCacheTransceiver |
35 | 35 | from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState, |
36 | 36 | LlmResponse, executor_request_to_llm_request) |
@@ -305,7 +305,7 @@ def __init__(self, |
305 | 305 | if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"): |
306 | 306 | self.event_loop = trace_func(self.event_loop) |
307 | 307 |
308 | | - if self.draft_model_engine is not None: |
| 308 | + if self.drafter is not None: |
309 | 309 | if self.event_loop.__name__ != self._executor_loop.__name__: |
310 | 310 | raise NotImplementedError( |
311 | 311 | "Drafting is not supported for selected executor loop. " |
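The two hunks above replace the `draft_model_engine` check with `self.drafter` and drop `create_drafter` from the import list, so the executor is now handed a ready-made `Drafter` rather than deciding later whether to build one. As a point of orientation, the sketch below shows the interface shape this implies; it is an assumption inferred from the call sites in this diff, not a copy of the real class in `..speculative.drafter`, and `DrafterSketch`/`NoOpDrafter` are illustrative names.

    # Assumed shape of the Drafter contract, inferred from this diff only.
    from abc import ABC, abstractmethod


    class DrafterSketch(ABC):
        """Produces draft tokens for a scheduled batch ahead of target decoding."""

        @abstractmethod
        def prepare_draft_tokens(self, scheduled_requests,
                                 resource_manager=None) -> None:
            """Attach py_draft_tokens to the batch's generation requests in place."""


    class NoOpDrafter(DrafterSketch):
        """Trivial implementation, included only to make the contract concrete."""

        def prepare_draft_tokens(self, scheduled_requests,
                                 resource_manager=None) -> None:
            for request in getattr(scheduled_requests, "generation_requests", []):
                request.py_draft_tokens = []  # no speculation: empty draft list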
@@ -918,8 +918,7 @@ def _executor_loop(self): |
918 | 918 |
919 | 919 | self._pad_attention_dp_dummy_request() |
920 | 920 |
921 | | - if self.draft_model_engine is not None or hasattr( |
922 | | - self, 'drafter') and self.drafter is not None: |
| 921 | + if self.drafter is not None: |
923 | 922 | self._prepare_draft_requests(self.active_requests) |
924 | 923 |
925 | 924 | scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule( |
@@ -959,23 +958,9 @@ def _executor_loop(self): |
959 | 958 | scheduled_batch) |
960 | 959 |
961 | 960 | self.resource_manager.prepare_resources(scheduled_batch) |
962 | | - if self.draft_model_engine is not None and self.drafter is None: |
963 | | - spec_resource_manager = self.resource_manager.get_resource_manager( |
964 | | - ResourceManagerType.SPEC_RESOURCE_MANAGER) |
965 | | - self.drafter = create_drafter( |
966 | | - spec_decoding_mode=self.model_engine.spec_config. |
967 | | - spec_dec_mode, |
968 | | - spec_config=self.model_engine.spec_config, |
969 | | - draft_model_engine=self.draft_model_engine, |
970 | | - max_draft_tokens=self.max_draft_tokens, |
971 | | - draft_seq_slot_manager=self.draft_seq_slot_manager, |
972 | | - sampler=self.sampler, |
973 | | - resource_manager=self.resource_manager, |
974 | | - spec_resource_manager=spec_resource_manager, |
975 | | - ) |
976 | | - |
977 | 961 | if self.drafter is not None: |
978 | | - self.drafter.prepare_draft_tokens(scheduled_batch) |
| 962 | + self.drafter.prepare_draft_tokens( |
| 963 | + scheduled_batch, self.resource_manager) |
979 | 964 |
980 | 965 | if self.kv_cache_transceiver: |
981 | 966 | # For generation requests which have completed KV cache transfer |
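The removed block above used to build the drafter lazily, on the first pass through `_executor_loop`, by calling `create_drafter(...)` with the draft model engine, sampler and resource managers; the surviving code only checks `self.drafter` and now passes `self.resource_manager` into `prepare_draft_tokens` explicitly. That means the drafter must exist before the loop starts. The sketch below restates that control flow with stand-in classes (`ExecutorSketch`, `ResourceManagerStub`, `loop_step`); none of these names come from the codebase.

    # Post-change control flow, restated with stand-ins: the drafter is injected
    # at construction time and the loop only asks whether one is present.
    class ResourceManagerStub:
        def prepare_resources(self, scheduled_batch) -> None:
            pass  # the real resource manager reserves KV-cache pages, slots, etc.


    class ExecutorSketch:
        def __init__(self, drafter=None, resource_manager=None) -> None:
            self.drafter = drafter  # built once, before the loop ever runs
            self.resource_manager = resource_manager or ResourceManagerStub()

        def loop_step(self, scheduled_batch) -> None:
            self.resource_manager.prepare_resources(scheduled_batch)
            if self.drafter is not None:
                # The resource manager is handed over per call, so the drafter
                # no longer needs to be constructed against it inside the loop.
                self.drafter.prepare_draft_tokens(scheduled_batch,
                                                  self.resource_manager)

Constructing `ExecutorSketch(drafter=NoOpDrafter())` from the earlier sketch and calling `loop_step` exercises the same guard-and-call path the diff now takes.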
@@ -1780,188 +1765,6 @@ def _update_requests(self, sample_state: SampleState): |
1780 | 1765 | logger.error(f"Encountered an error in sampling: {error_msg}") |
1781 | 1766 | self._handle_errors(error_msg) |
1782 | 1767 |
1783 | | - @nvtx_range("_prepare_draft_batch") |
1784 | | - def _prepare_draft_batch( |
1785 | | - self, scheduled_requests: ScheduledRequests |
1786 | | - ) -> Tuple[ScheduledRequests, Dict[int, LlmRequest]]: |
1787 | | - """ |
1788 | | - Prepares a batch for the draft model engine. Draft tokens are only produced |
1789 | | - for generation requests. |
1790 | | -
1791 | | - The requests are prepared as follows: |
1792 | | - 1. The first time the draft engine sees a request, it's a context request. |
1793 | | - 2. Otherwise, if draft tokens were accepted on the last target model decoding |
1794 | | - step, it's a chunked context request (we process all the accepted tokens together). |
1795 | | - 3. Otherwise, it's a generation request. |
1796 | | - """ |
1797 | | - try: |
1798 | | - draft_batch = ScheduledRequests() |
1799 | | - |
1800 | | - for request in scheduled_requests.generation_requests: |
1801 | | - if request.py_draft_pages_allocated == 0: |
1802 | | - # No space for draft tokens. |
1803 | | - continue |
1804 | | - |
1805 | | - # Stop drafting when we hit the max seqlen. We still need dummy draft |
1806 | | - # tokens attached to the requests to make sure everything works properly |
1807 | | - # with CUDA graph. These dummy tokens are already added by |
1808 | | - # _prepare_draft_requests to make the KV cache/scheduler aware of the fact |
1809 | | - # that we want to do spec decoding, so no need to do anything else here. |
1810 | | - # This makes the perf for this case suboptimal, but that's OK - this is |
1811 | | - # a corner case for weird models like the llama 3.1 8b EAGLE3 implementation. |
1812 | | - if request.max_beam_num_tokens - 1 >= self.draft_model_engine.max_seq_len: |
1813 | | - continue |
1814 | | - |
1815 | | - num_draft_tokens = len( |
1816 | | - request.py_last_draft_tokens |
1817 | | - ) if request.py_last_draft_tokens is not None else 0 |
1818 | | - request.py_draft_tokens = [] |
1819 | | - |
1820 | | - num_accepted_tokens = request.py_num_accepted_draft_tokens |
1821 | | - num_rejected_tokens = num_draft_tokens - num_accepted_tokens |
1822 | | - assert num_rejected_tokens >= 0 |
1823 | | - |
1824 | | - spec_config = self.model_engine.spec_config |
1825 | | - beam_idx = 0 |
1826 | | - input_tokens = spec_config.get_draft_model_prompt( |
1827 | | - request.get_tokens()[beam_idx]) |
1828 | | - |
1829 | | - def create_new_request(input_tokens): |
1830 | | - return LlmRequest( |
1831 | | - request_id=request.py_request_id, |
1832 | | - max_new_tokens=request.py_max_new_tokens, |
1833 | | - input_tokens=input_tokens, |
1834 | | - sampling_config=request.sampling_config, |
1835 | | - return_perf_metrics=request.return_perf_metrics, |
1836 | | - is_streaming=False, |
1837 | | - is_draft=True) |
1838 | | - |
1839 | | - if request.max_beam_num_tokens - 1 == request.py_prompt_len: |
1840 | | - # This is the first time the draft model is seeing this request. |
1841 | | - # Prepare a context request. We discard the first token and take |
1842 | | - # the newly decoded one - this is the convention for EAGLE 2 and 3. |
1843 | | - new_request = create_new_request(input_tokens) |
1844 | | - draft_batch.context_requests.append(new_request) |
1845 | | - elif num_accepted_tokens == 0: |
1846 | | - new_request = create_new_request(input_tokens[:-1]) |
1847 | | - # Explicitly add the last token so get_last_tokens() returns |
1848 | | - # the right value |
1849 | | - new_request.add_new_token(input_tokens[-1], beam_idx) |
1850 | | - new_request.state = LlmRequestState.GENERATION_IN_PROGRESS |
1851 | | - draft_batch.generation_requests.append(new_request) |
1852 | | - else: |
1853 | | - new_request = create_new_request(input_tokens) |
1854 | | - new_request.context_chunk_size = num_accepted_tokens + 1 |
1855 | | - new_request.context_current_position = len( |
1856 | | - input_tokens) - num_accepted_tokens - 1 |
1857 | | - new_request.context_chunk_size = num_accepted_tokens + 1 |
1858 | | - new_request.context_current_position = len( |
1859 | | - input_tokens) - num_accepted_tokens - 1 |
1860 | | - |
1861 | | - draft_batch.context_requests.append(new_request) |
1862 | | - |
1863 | | - new_request.py_stop_words_list = request.py_stop_words_list |
1864 | | - |
1865 | | - return draft_batch |
1866 | | - |
1867 | | - except Exception as e: |
1868 | | - traceback.print_exc() |
1869 | | - error_msg = str(e) |
1870 | | - logger.error(f"Encountered an error in decode: {error_msg}") |
1871 | | - self._handle_errors(error_msg) |
1872 | | - |
1873 | | - @nvtx_range("_prepare_draft_tokens") |
1874 | | - def _prepare_draft_tokens(self, scheduled_requests: ScheduledRequests): |
1875 | | - if not self.draft_model_engine: |
1876 | | - raise ValueError("Draft model engine is not set") |
1877 | | - |
1878 | | - try: |
1879 | | - draft_batch = self._prepare_draft_batch(scheduled_requests) |
1880 | | - |
1881 | | - if draft_batch.batch_size == 0: |
1882 | | - return |
1883 | | - self.draft_seq_slot_manager.prepare_resources(draft_batch) |
1884 | | - |
1885 | | - req_id_to_old_request = { |
1886 | | - req.py_request_id: req |
1887 | | - for req in scheduled_requests.all_requests() |
1888 | | - } |
1889 | | - |
1890 | | - # Disable cuda graph for the 1st draft model forward |
1891 | | - if self.model_engine.spec_config.spec_dec_mode.needs_kv_cache_recompute( |
1892 | | - ): |
1893 | | - with self.draft_model_engine.no_cuda_graph(): |
1894 | | - outputs = self.draft_model_engine.forward( |
1895 | | - draft_batch, self.resource_manager) |
1896 | | - else: |
1897 | | - outputs = self.draft_model_engine.forward( |
1898 | | - draft_batch, self.resource_manager) |
1899 | | - if hasattr(self.draft_model_engine.model.model, 'd2t'): |
1900 | | - outputs['d2t'] = self.draft_model_engine.model.model.d2t.data |
1901 | | - |
1902 | | - sample_state = self._sample_async(draft_batch, outputs) |
1903 | | - previous_batch = sample_state |
1904 | | - |
1905 | | - self._update_request_states(draft_batch) |
1906 | | - |
1907 | | - def _process_decoded_tokens(draft_batch): |
1908 | | - new_requests = [] |
1909 | | - for req in draft_batch.all_requests(): |
1910 | | - target_model_req = req_id_to_old_request[req.py_request_id] |
1911 | | - target_model_req.py_draft_tokens.append( |
1912 | | - req.get_last_tokens(0)) |
1913 | | - if req.state != LlmRequestState.GENERATION_COMPLETE and len( |
1914 | | - target_model_req.py_draft_tokens |
1915 | | - ) < target_model_req.py_draft_pages_allocated: |
1916 | | - new_requests.append(req) |
1917 | | - else: |
1918 | | - self.draft_seq_slot_manager.free_resources(req) |
1919 | | - |
1920 | | - return new_requests |
1921 | | - |
1922 | | - # The TRTLLM attention kernels cannot handle generation requests with |
1923 | | - # different seqlens. No issues with flashinfer, should we look into removing |
1924 | | - # this? Just needs proper kernel support. |
1925 | | - def _pad_to_max_draft_tokens(): |
1926 | | - for req in scheduled_requests.generation_requests: |
1927 | | - max_draft_len = self.max_draft_len |
1928 | | - num_draft_tokens = len(req.py_draft_tokens) |
1929 | | - req.py_draft_tokens.extend( |
1930 | | - 0 for _ in range(max_draft_len - num_draft_tokens)) |
1931 | | - |
1932 | | - draft_batch.generation_requests = draft_batch.context_requests + draft_batch.generation_requests |
1933 | | - draft_batch.context_requests = [] |
1934 | | - |
1935 | | - for i in range(self.max_draft_len - 1): |
1936 | | - if len(draft_batch.generation_requests) == 0: |
1937 | | - break |
1938 | | - |
1939 | | - outputs = self.draft_model_engine.forward( |
1940 | | - draft_batch, |
1941 | | - self.resource_manager, |
1942 | | - new_tensors_device=previous_batch.device) |
1943 | | - |
1944 | | - if hasattr(self.draft_model_engine.model.model, 'd2t'): |
1945 | | - outputs[ |
1946 | | - 'd2t'] = self.draft_model_engine.model.model.d2t.data |
1947 | | - sample_state = self._sample_async(draft_batch, outputs) |
1948 | | - self._update_request_states(draft_batch) |
1949 | | - self._update_requests(previous_batch) |
1950 | | - new_requests = _process_decoded_tokens( |
1951 | | - previous_batch.scheduled_requests) |
1952 | | - draft_batch.generation_requests = new_requests |
1953 | | - previous_batch = sample_state |
1954 | | - self._update_requests(previous_batch) |
1955 | | - new_requests = _process_decoded_tokens( |
1956 | | - previous_batch.scheduled_requests) |
1957 | | - _pad_to_max_draft_tokens() |
1958 | | - |
1959 | | - except Exception as e: |
1960 | | - traceback.print_exc() |
1961 | | - error_msg = str(e) |
1962 | | - logger.error(f"Encountered an error in decode: {error_msg}") |
1963 | | - self._handle_errors(error_msg) |
1964 | | - |
1965 | 1768 | def _handle_errors(self, error_msg: Optional[str] = None): |
1966 | 1769 | error_responses = {} |
1967 | 1770 | error_msg = error_msg or "error" |
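For reference, the two methods deleted in the hunk above implemented the draft-token preparation that presumably now lives behind `Drafter.prepare_draft_tokens`: classify each generation request as a fresh context request (first time the draft engine sees it), a chunked context request covering the tokens accepted on the previous target step, or a plain generation request; run the draft engine step by step, appending one draft token per request per step; and finally pad every request's draft list to `max_draft_len`, since the TRTLLM attention kernels require generation requests to share one sequence length. The condensed sketch below restates that control flow with stand-in types (`DraftRequest`, `draft_step`); it is a summary of the removed code, not the actual drafter implementation.

    # Condensed restatement of the removed _prepare_draft_batch /
    # _prepare_draft_tokens logic. DraftRequest stands in for LlmRequest and
    # draft_step for one draft-engine forward plus sampling.
    from dataclasses import dataclass, field
    from typing import Callable, List


    @dataclass
    class DraftRequest:
        tokens: List[int]                  # full token history, beam 0
        prompt_len: int
        num_accepted: int = 0              # draft tokens accepted last step
        draft_tokens: List[int] = field(default_factory=list)


    def classify(req: DraftRequest) -> str:
        """Mirror of the three cases in the removed _prepare_draft_batch."""
        if len(req.tokens) - 1 == req.prompt_len:
            return "context"           # draft engine sees this request first
        if req.num_accepted == 0:
            return "generation"        # nothing accepted: ordinary decode step
        return "chunked_context"       # reprocess the accepted tokens together


    def prepare_draft_tokens(requests: List[DraftRequest], max_draft_len: int,
                             draft_step: Callable[[DraftRequest], int]) -> None:
        """Collect up to max_draft_len draft tokens per request, then pad with
        zeros so every request ends up with the same draft length."""
        active = list(requests)
        for _ in range(max_draft_len):
            still_active = []
            for req in active:
                req.draft_tokens.append(draft_step(req))  # one token per step
                if len(req.draft_tokens) < max_draft_len:
                    still_active.append(req)
            active = still_active
            if not active:
                break
        for req in requests:                              # pad to uniform length
            req.draft_tokens.extend(
                0 for _ in range(max_draft_len - len(req.draft_tokens)))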