@@ -101,7 +101,11 @@ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None
     return probs


-def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
+def sample(logits, need_probs: bool, temperature: float = 1.0, top_k: Optional[int] = None):
+    if temperature == 0 and not need_probs:
+        _, idx_next = torch.topk(logits, k=1, dim=-1)
+        idx_next = idx_next.squeeze(dim=(0, 1))
+        return (idx_next, None)
     probs = logits_to_probs(logits[0, -1], temperature, top_k)
     idx_next = multinomial_sample_one_no_sync(probs)
     return idx_next, probs
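Why the new `temperature == 0` branch is safe to take: `torch.topk` with `k=1` selects the same index as an argmax over the last position, without ever materializing a probability distribution. A minimal sketch, with a made-up vocab size and batch/sequence dims of 1 as in single-token decoding (note that `Tensor.squeeze` accepts a tuple of dims only on PyTorch 2.0+):

```python
import torch

logits = torch.randn(1, 1, 32000)  # [B, S, vocab] with B = S = 1

# Greedy fast path from the diff: no softmax, no multinomial, probs stay None.
_, idx_fast = torch.topk(logits, k=1, dim=-1)
idx_fast = idx_fast.squeeze(dim=(0, 1))

# Reference: plain argmax over the last position, which the sampling
# path converges to as temperature -> 0.
idx_ref = torch.argmax(logits[0, -1], dim=-1, keepdim=True)

assert torch.equal(idx_fast, idx_ref)
```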
@@ -129,23 +133,24 @@ def prefill(
     # input_pos: [B, S]
     logits = model(x, input_pos)

-    return sample(logits, **sampling_kwargs)[0]
+    return sample(logits, need_probs=False, **sampling_kwargs)[0]


 def decode_one_token(
-    model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
-) -> Tuple[torch.Tensor, torch.Tensor]:
+    model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, need_probs: bool, **sampling_kwargs
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     # input_pos: [B, 1]
     assert input_pos.shape[-1] == 1
     logits = model(x, input_pos)
-    return sample(logits, **sampling_kwargs)
+    return sample(logits, need_probs=need_probs, **sampling_kwargs)


 def decode_n_tokens(
     model: Transformer,
     cur_token: torch.Tensor,
     input_pos: torch.Tensor,
     num_new_tokens: int,
+    need_probs: bool,
     callback=lambda _: _,
     **sampling_kwargs,
 ):
@@ -154,12 +159,13 @@ def decode_n_tokens(
         # Actually better for Inductor to codegen attention here
         with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
             next_token, next_prob = decode_one_token(
-                model, cur_token, input_pos, **sampling_kwargs
+                model, cur_token, input_pos, need_probs=need_probs, **sampling_kwargs
             )
             input_pos += 1
             new_tokens.append(next_token.clone())
             callback(new_tokens[-1])
-            new_probs.append(next_prob.clone())
+            if need_probs:
+                new_probs.append(next_prob.clone())
             cur_token = next_token.view(1, -1)

     return new_tokens, new_probs
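The new guard matters because `decode_one_token` can now return `None` in place of the probabilities, so the old unconditional `next_prob.clone()` would raise on the greedy fast path. A tiny sketch of the contract, with `fake_decode_one_token` as a hypothetical stand-in for the real call:

```python
from typing import Optional, Tuple
import torch

def fake_decode_one_token(need_probs: bool) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # Stand-in: the fast path returns a token but no distribution.
    token = torch.tensor([42])
    return (token, torch.full((8,), 0.125)) if need_probs else (token, None)

new_probs = []
for need_probs in (True, False):
    next_token, next_prob = fake_decode_one_token(need_probs)
    if need_probs:  # mirrors the diff; next_prob is None when False
        new_probs.append(next_prob.clone())
```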
@@ -187,6 +193,7 @@ def speculative_decode(
         cur_token.view(1, -1),
         orig_input_pos.clone(),
         speculate_k,
+        need_probs=True,
         **sampling_kwargs,
     )

@@ -301,6 +308,7 @@ def generate(
         input_pos,
         max_new_tokens - 1,
         callback=callback,
+        need_probs=False,
         **sampling_kwargs,
     )
     seq[T + 1:] = torch.cat(generated_tokens)
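Taken together, the call sites split cleanly: `generate` does plain decoding and passes `need_probs=False`, so greedy runs never pay for a softmax, while `speculative_decode` passes `need_probs=True` because the draft distribution feeds the accept/reject test. A self-contained sketch of the two modes, with `toy_sample` as a simplified stand-in for the patched `sample` (the real one also applies temperature scaling and top-k filtering via `logits_to_probs`):

```python
import torch

def toy_sample(logits, need_probs: bool, temperature: float = 1.0):
    if temperature == 0 and not need_probs:
        # Greedy fast path: skip the distribution entirely.
        _, idx_next = torch.topk(logits, k=1, dim=-1)
        return idx_next.squeeze(dim=(0, 1)), None
    # Sampling path; the clamp keeps temperature == 0 well-defined here.
    probs = torch.softmax(logits[0, -1] / max(temperature, 1e-5), dim=-1)
    return torch.multinomial(probs, num_samples=1), probs

logits = torch.randn(1, 1, 32000)

# generate() path: greedy decoding, probabilities not needed.
token, probs = toy_sample(logits, need_probs=False, temperature=0)
assert probs is None

# speculative_decode() path: probabilities required for accept/reject.
token, probs = toy_sample(logits, need_probs=True, temperature=0)
assert probs is not None
```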