
Commit fb523e3

target model

Signed-off-by: Enwei Zhu <[email protected]>

1 parent f08286c · commit fb523e3

File tree (5 files changed: +124 -47 lines)

tensorrt_llm/_torch/pyexecutor/grammar_matcher.py
tensorrt_llm/_torch/pyexecutor/guided_decoder.py
tensorrt_llm/_torch/pyexecutor/llm_request.py
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
tensorrt_llm/evaluate/json_mode_eval.py

tensorrt_llm/_torch/pyexecutor/grammar_matcher.py

Lines changed: 28 additions & 3 deletions

```diff
@@ -16,11 +16,19 @@ class GrammarMatcher(ABC):
     def accept_token(self, token_id: int) -> bool:
         pass
 
+    @abstractmethod
+    def rollback(self, num_tokens: int) -> None:
+        pass
+
     @abstractmethod
     def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor,
                                 index: int) -> None:
         pass
 
+    @abstractmethod
+    def is_terminated(self) -> bool:
+        pass
+
 
 class GrammarMatcherFactory(ABC):
@@ -39,15 +47,23 @@ def __init__(self, matcher: xgrammar.GrammarMatcher):
     def accept_token(self, token_id: int) -> bool:
         return self._matcher.accept_token(token_id)
 
+    def rollback(self, num_tokens: int) -> None:
+        self._matcher.rollback(num_tokens)
+
     def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor,
                                 index: int) -> None:
         self._matcher.fill_next_token_bitmask(next_token_bitmask, index)
 
+    def is_terminated(self) -> bool:
+        return self._matcher.is_terminated()
+
 
 class XGrammarMatcherFactory(GrammarMatcherFactory):
 
-    def __init__(self, guided_decoding_config: GuidedDecodingConfig,
-                 vocab_size_padded: int):
+    def __init__(self,
+                 guided_decoding_config: GuidedDecodingConfig,
+                 vocab_size_padded: int,
+                 max_num_draft_tokens: int = 0):
         super().__init__()
         if guided_decoding_config.tokenizer_str is not None:
             metadata = xgrammar.TokenizerInfo._detect_metadata_from_hf(
@@ -72,6 +88,7 @@ def __init__(self, guided_decoding_config: GuidedDecodingConfig,
             cache_enabled=True,
             cache_limit_bytes=cache_limit_bytes,
         )
+        self.max_num_draft_tokens = max_num_draft_tokens
 
     def create(self,
               guided_decoding_params: GuidedDecodingParams) -> XGrammarMatcher:
@@ -106,7 +123,8 @@ def create(self,
             case _:
                 raise ValueError(f"Unsupported guide type: {guide_type}.")
 
-        matcher = xgrammar.GrammarMatcher(compiled_grammar)
+        matcher = xgrammar.GrammarMatcher(
+            compiled_grammar, max_rollback_tokens=self.max_num_draft_tokens)
         return XGrammarMatcher(matcher)
@@ -121,12 +139,19 @@ def accept_token(self, token_id: int) -> bool:
         self._check_err()
         return result
 
+    def rollback(self, num_tokens: int) -> None:
+        self._matcher.rollback(num_tokens)
+        self._check_err()
+
     def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor,
                                 index: int) -> None:
         llguidance.torch.fill_next_token_bitmask(self._matcher,
                                                  next_token_bitmask, index)
         self._check_err()
 
+    def is_terminated(self) -> bool:
+        return self._matcher.is_stopped()
+
     def _check_err(self) -> None:
         if self._matcher.is_error():
             raise ValueError(
```
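
The two new abstract methods define the contract that speculative decoding relies on: `rollback(n)` undoes the last `n` accepted tokens, and `is_terminated()` reports whether the grammar has reached an accepting end state (the LLGuidance backend maps this onto its `is_stopped()`). Below is a minimal, self-contained sketch of that contract; `ToyMatcher` and its fixed target sequence are made up for illustration and are not TensorRT-LLM code.

```python
# Illustrative only: a toy matcher that accepts one fixed token sequence,
# showing how rollback() and is_terminated() are expected to behave.
from typing import List


class ToyMatcher:

    def __init__(self, target: List[int]):
        self._target = target
        self._pos = 0  # number of tokens accepted so far

    def accept_token(self, token_id: int) -> bool:
        if self._pos < len(self._target) and self._target[self._pos] == token_id:
            self._pos += 1
            return True
        return False

    def rollback(self, num_tokens: int) -> None:
        # Undo the last num_tokens accepted tokens; xgrammar additionally
        # bounds this by the max_rollback_tokens passed at construction.
        assert 0 <= num_tokens <= self._pos
        self._pos -= num_tokens

    def is_terminated(self) -> bool:
        return self._pos == len(self._target)


matcher = ToyMatcher(target=[5, 7, 9])
assert matcher.accept_token(5)   # token sampled by the target model
assert matcher.accept_token(7)   # draft token, accepted speculatively
matcher.rollback(1)              # verification rejected the draft: undo it
assert matcher.accept_token(7) and matcher.accept_token(9)
assert matcher.is_terminated()   # grammar reached its end state
```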

tensorrt_llm/_torch/pyexecutor/guided_decoder.py

Lines changed: 74 additions & 33 deletions

```diff
@@ -13,19 +13,25 @@
 class GuidedDecoder:
     bitmask_dtype = torch.int32
 
-    def __init__(self, guided_decoding_config: GuidedDecodingConfig,
-                 max_num_sequences: int, vocab_size_padded: int):
+    def __init__(self,
+                 guided_decoding_config: GuidedDecodingConfig,
+                 max_num_sequences: int,
+                 vocab_size_padded: int,
+                 max_num_draft_tokens: int = 0):
         self.guided_decoding_backend = guided_decoding_config.backend
         self.max_num_sequences = max_num_sequences
         self.vocab_size_padded = vocab_size_padded
+        self.max_num_draft_tokens = max_num_draft_tokens
 
         self.grammar_matcher_factory: Optional[GrammarMatcherFactory] = None
         self.grammar_matchers: List[
             Optional[GrammarMatcher]] = [None] * self.max_num_sequences
 
         if self.guided_decoding_backend == GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR:
             self.grammar_matcher_factory = XGrammarMatcherFactory(
-                guided_decoding_config, vocab_size_padded)
+                guided_decoding_config,
+                vocab_size_padded,
+                max_num_draft_tokens=max_num_draft_tokens)
         elif self.guided_decoding_backend == GuidedDecodingConfig.GuidedDecodingBackend.LLGUIDANCE:
             self.grammar_matcher_factory = LLGuidanceMatcherFactory(
                 guided_decoding_config, vocab_size_padded)
@@ -35,14 +41,16 @@ def __init__(self, guided_decoding_config: GuidedDecodingConfig,
             )
 
         self.bitmask = torch.empty(self.max_num_sequences,
+                                   self.max_num_draft_tokens + 1,
                                    self.bitmask_size,
                                    dtype=self.bitmask_dtype,
                                    device='cuda')
         self.bitmask_host = torch.empty(self.max_num_sequences,
+                                        self.max_num_draft_tokens + 1,
                                         self.bitmask_size,
                                         dtype=self.bitmask_dtype,
                                         pin_memory=True)
-
+        self.num_guided_tokens: List[int] = [0] * self.max_num_sequences
         self._stream = torch.cuda.Stream()
 
     @property
```
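
The bitmask buffers gain a middle dimension: one row per candidate position, that is, the newly sampled token plus up to `max_num_draft_tokens` draft tokens. Here is a shape sketch with illustrative numbers; it assumes `bitmask_size` packs 32 vocabulary entries per `int32` element, which is how xgrammar-style token bitmasks are sized.

```python
# Shape sketch for the widened bitmask buffers (numbers are illustrative).
import torch

max_num_sequences = 8
max_num_draft_tokens = 3
vocab_size_padded = 128_000
bitmask_size = -(-vocab_size_padded // 32)  # ceil(vocab / 32) int32 words

# One bitmask row per guided position: new token + up to 3 draft tokens.
bitmask_host = torch.empty(max_num_sequences,
                           max_num_draft_tokens + 1,
                           bitmask_size,
                           dtype=torch.int32)
print(bitmask_host.shape)  # torch.Size([8, 4, 4000])
```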
```diff
@@ -52,44 +60,77 @@ def bitmask_size(self) -> int:
     @nvtx_range("GuidedDecoder.build")
     def build(self, scheduled_requests: ScheduledRequests) -> None:
         for llm_req in scheduled_requests.all_requests():
-            if llm_req.guided_decoding_params is None:
-                continue
-            slot = llm_req.py_seq_slot
-            if llm_req.is_context_init_state and llm_req.context_current_position == llm_req.prepopulated_prompt_len:
-                self.grammar_matchers[
-                    slot] = self.grammar_matcher_factory.create(
-                        llm_req.guided_decoding_params)
+            slot: int = llm_req.py_seq_slot
+            require_guided: bool = True
 
-            elif llm_req.is_generation_in_progress_state:
-                # The request is in a generation forward step.
-                # Currently, guided decoding does not support with beam search.
-                self.grammar_matchers[slot].accept_token(
-                    llm_req.get_last_tokens(0))
+            if llm_req.guided_decoding_params is None:
+                require_guided = False
             else:
-                continue
-
-            # Fill the bitmask on host and asynchorously copy to device.
-            self.grammar_matchers[slot].fill_next_token_bitmask(
-                self.bitmask_host, slot)
-            with torch.cuda.stream(self._stream):
-                self.bitmask[slot].copy_(self.bitmask_host[slot],
-                                         non_blocking=True)
+                if llm_req.is_context_init_state and llm_req.is_last_context_chunk:
+                    # The request is in the last chunk of a context forward step.
+                    matcher = self.grammar_matcher_factory.create(
+                        llm_req.guided_decoding_params)
+                    self.grammar_matchers[slot] = matcher
+                elif llm_req.is_generation_in_progress_state:
+                    # The request is in a generation forward step.
+                    matcher = self.grammar_matchers[slot]
+                    # Rollback the grammar matcher to the last accepted token.
+                    num_rollback_tokens = self.num_guided_tokens[slot] - (
+                        1 + llm_req.py_num_accepted_draft_tokens)
+                    assert num_rollback_tokens >= 0
+                    matcher.rollback(num_rollback_tokens)
+
+                    # Currently, guided decoding does not support with beam search.
+                    accepted = matcher.accept_token(llm_req.get_last_tokens(0))
+                    # TODO: Make this an error response.
+                    if not accepted:
+                        raise ValueError(
+                            f"Failed to accept new token: {llm_req.get_last_tokens(0)}."
+                        )
+                else:
+                    require_guided = False
+
+            num_guided_tokens: int = 0
+            if require_guided:
+                if not matcher.is_terminated():
+                    matcher.fill_next_token_bitmask(self.bitmask_host[slot], 0)
+                    num_guided_tokens += 1
+                    # Process draft tokens
+                    for i, tid in enumerate(llm_req.py_draft_tokens, 1):
+                        accepted = matcher.accept_token(tid)
+                        if matcher.is_terminated():
+                            matcher.rollback(1)
+                            accepted = False
+                        if accepted:
+                            matcher.fill_next_token_bitmask(self.bitmask_host[slot],
+                                                            i)
+                            num_guided_tokens += 1
+                        else:
+                            break
+
+            self.num_guided_tokens[slot] = num_guided_tokens
+            if num_guided_tokens > 0:
+                with torch.cuda.stream(self._stream):
+                    self.bitmask[slot, :num_guided_tokens].copy_(
+                        self.bitmask_host[slot, :num_guided_tokens],
+                        non_blocking=True)
```

78118
@nvtx_range("GuidedDecoder.execute")
79119
def execute(self, scheduled_requests: ScheduledRequests,
80120
logits: torch.Tensor) -> None:
81-
assert logits.size(0) == len(scheduled_requests.context_requests) + len(
82-
scheduled_requests.generation_requests)
83121
torch.cuda.current_stream().wait_stream(self._stream)
84122

85123
batched_logits, batched_bitmask = [], []
86-
for i, llm_req in enumerate(scheduled_requests.all_requests()):
87-
if llm_req.guided_decoding_params is None:
88-
continue
89-
if llm_req.is_context_init_state and not llm_req.is_last_context_chunk:
90-
continue
91-
batched_logits.append(logits[i])
92-
batched_bitmask.append(self.bitmask[llm_req.py_seq_slot])
124+
offset = 0
125+
for llm_req in scheduled_requests.all_requests():
126+
slot: int = llm_req.py_seq_slot
127+
num_guided_tokens: int = self.num_guided_tokens[slot]
128+
for i in range(num_guided_tokens):
129+
batched_logits.append(logits[offset + i])
130+
batched_bitmask.append(self.bitmask[slot, i])
131+
offset += len(llm_req.py_draft_tokens) + 1
132+
133+
assert offset == logits.size(0)
93134

94135
if len(batched_logits) > 0:
95136
torch.ops.trtllm.logits_bitmask(batched_logits, batched_bitmask)
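
execute() assumes the flattened logits hold, for each request in order, one row for the new token followed by one row per draft token; only the first `num_guided_tokens` rows of each request get a bitmask. A small bookkeeping sketch (values are made up):

```python
# Sketch of the logits-row bookkeeping in execute(), with 3 requests that
# each carry 3 draft tokens, i.e. 4 logits rows per request (12 total).
requests = [
    {"num_draft_tokens": 3, "num_guided_tokens": 4},  # fully guided
    {"num_draft_tokens": 3, "num_guided_tokens": 0},  # unguided request
    {"num_draft_tokens": 3, "num_guided_tokens": 2},  # grammar stopped early
]

offset = 0
masked_rows = []
for req in requests:
    # Only the first num_guided_tokens rows of this request are masked.
    masked_rows.extend(range(offset, offset + req["num_guided_tokens"]))
    offset += req["num_draft_tokens"] + 1  # every request still owns its rows

print(masked_rows)   # [0, 1, 2, 3, 8, 9]
assert offset == 12  # matches logits.size(0)
```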

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -301,6 +301,7 @@ def __init__(
         self.py_orig_prompt_len = self.orig_prompt_len
         self.py_max_new_tokens = self.max_new_tokens
         self.py_batch_idx = None
+        self.py_draft_pages_allocated = 0
         self.py_rewind_len = 0
         self.py_draft_tokens = [] if self.draft_tokens is None else self.draft_tokens
         self.py_last_draft_tokens = None
```

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 5 additions & 4 deletions

```diff
@@ -332,14 +332,15 @@ def create_py_executor(
 
     guided_decoder: Optional[GuidedDecoder] = None
     if executor_config.guided_decoding_config is not None:
-        if spec_config is not None:
-            raise ValueError(
-                "Guided decoding is not supported with speculative decoding.")
         if mapping.is_last_pp_rank():
+            max_num_draft_tokens = 0
+            if spec_config is not None:
+                max_num_draft_tokens = spec_config.max_draft_len
             guided_decoder = GuidedDecoder(
                 executor_config.guided_decoding_config,
                 executor_config.max_batch_size,
-                model_engine.model.vocab_size_padded)
+                model_engine.model.vocab_size_padded,
+                max_num_draft_tokens=max_num_draft_tokens)
 
     resources = {}
     estimating_kv_cache = False
```
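
With the guard removed, guided decoding and speculative decoding can now be configured together. A hypothetical end-to-end sketch through the PyTorch LLM API follows; `GuidedDecodingParams(json=...)` appears in this commit, but the `LLM` constructor arguments, `EagleDecodingConfig`, and the model paths below are assumptions about the surrounding API and may differ in your version.

```python
# Hypothetical usage sketch: guided decoding combined with draft tokens.
# The speculative config class and all names/paths marked as placeholders
# are assumptions, not taken from this commit.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import EagleDecodingConfig, GuidedDecodingParams

schema = '{"type": "object", "properties": {"answer": {"type": "string"}}}'

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model name
    guided_decoding_backend="xgrammar",
    speculative_config=EagleDecodingConfig(
        max_draft_len=3,                        # becomes max_num_draft_tokens
        speculative_model_dir="<draft model>",  # placeholder path
    ),
)
output = llm.generate(
    "Reply in JSON.",
    SamplingParams(guided_decoding=GuidedDecodingParams(json=schema)),
)
print(output.outputs[0].text)
```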

tensorrt_llm/evaluate/json_mode_eval.py

Lines changed: 16 additions & 7 deletions

```diff
@@ -18,6 +18,7 @@
 
 import click
 import datasets
+import jsonschema
 import numpy as np
 
 from .. import LLM as PyTorchLLM
@@ -65,23 +66,31 @@ def generate_samples(self) -> Iterable[tuple]:
             sampling_args = {
                 "guided_decoding": GuidedDecodingParams(json=schema)
             }
-            yield sample["prompt"], sampling_args, sample["completion"]
+            yield sample["prompt"], sampling_args, sample["completion"], sample[
+                "schema"]
 
-    def compute_score(self, outputs: List[RequestOutput],
-                      references: List[str]) -> float:
-        all_corrections = []
-        for output, ref in zip(outputs, references):
+    def compute_score(self, outputs: List[RequestOutput], references: List[str],
+                      schemas: List[str]) -> float:
+        all_corrections, all_grammar_corrections = [], []
+        for output, ref, schema in zip(outputs, references, schemas):
             try:
                 output_json = json.loads(output.outputs[0].text)
-            except json.JSONDecodeError:
+                jsonschema.validate(output_json, json.loads(schema))
+            except (json.JSONDecodeError, jsonschema.ValidationError):
                 all_corrections.append(False)
-                continue
+                all_grammar_corrections.append(False)
+            else:
+                all_grammar_corrections.append(True)
             ref_json = json.loads(ref)
             all_corrections.append(output_json == ref_json)
 
         acc = np.mean(all_corrections) * 100
         logger.info(
             f"JSON Mode Eval accuracy: {acc:.2f} ({len(all_corrections)})")
+        grammar_acc = np.mean(all_grammar_corrections) * 100
+        logger.info(
+            f"JSON Mode Eval grammar accuracy: {grammar_acc:.2f} ({len(all_grammar_corrections)})"
+        )
         return acc
 
 @click.command("json_mode_eval")
```
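
The new grammar-accuracy metric counts an output as grammar-correct when it parses as JSON and validates against the sample's schema. A quick standalone illustration of that check with the jsonschema package (the schema and values here are made up):

```python
import json

import jsonschema

schema = {
    "type": "object",
    "properties": {"age": {"type": "integer"}},
    "required": ["age"],
}

jsonschema.validate(json.loads('{"age": 30}'), schema)  # grammar-correct

try:
    jsonschema.validate(json.loads('{"age": "thirty"}'), schema)
except jsonschema.ValidationError as e:
    print(f"grammar-incorrect output: {e.message}")
```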
