Commit 309798d

swolchok authored and malfet committed
zero temp sampling (pytorch#277)
Add a special case for zero-temperature sampling. For stories15M on my devserver, this seems to improve tokens/sec as follows: before: 189, 180, 166; after: 264, 285, 285.
1 parent 68ee0b0 commit 309798d
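
For context on why the fast path yields the same tokens: as the temperature goes to zero, the softmax distribution collapses onto the argmax of the logits, so a single topk(k=1) can replace the softmax-plus-multinomial draw whenever the caller does not need the probability vector. Below is a minimal standalone sketch of that equivalence (my own illustration, not code from this commit; the vocab size and the near-zero temperature value are arbitrary).

import torch

# Illustrative logits shaped [batch, seq, vocab], as generate.py's sample() expects.
logits = torch.randn(1, 1, 32000)

# Greedy fast path: take the highest-scoring token straight from the logits.
_, greedy = torch.topk(logits, k=1, dim=-1)
greedy = greedy.squeeze()  # -> scalar token id

# Near-zero-temperature sampling path: scaled softmax, then a multinomial draw.
temperature = 1e-5
probs = torch.softmax(logits[0, -1] / temperature, dim=-1)  # numerically one-hot
sampled = torch.multinomial(probs, num_samples=1)

# Barring exact ties in the logits, both paths pick the same token.
assert greedy.item() == sampled.item()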

File tree

1 file changed, +15 -7 lines

generate.py (+15 -7)
@@ -101,7 +101,11 @@ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = Non
     return probs
 
 
-def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
+def sample(logits, need_probs: bool, temperature: float = 1.0,top_k: Optional[int] = None):
+    if temperature == 0 and not need_probs:
+        _, idx_next = torch.topk(logits, k=1, dim=-1)
+        idx_next = idx_next.squeeze(dim=(0, 1))
+        return (idx_next, None)
     probs = logits_to_probs(logits[0, -1], temperature, top_k)
     idx_next = multinomial_sample_one_no_sync(probs)
     return idx_next, probs
@@ -129,23 +133,24 @@ def prefill(
     # input_pos: [B, S]
     logits = model(x, input_pos)
 
-    return sample(logits, **sampling_kwargs)[0]
+    return sample(logits, need_probs=False, **sampling_kwargs)[0]
 
 
 def decode_one_token(
-    model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
-) -> Tuple[torch.Tensor, torch.Tensor]:
+    model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, need_probs: bool, **sampling_kwargs
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     # input_pos: [B, 1]
     assert input_pos.shape[-1] == 1
     logits = model(x, input_pos)
-    return sample(logits, **sampling_kwargs)
+    return sample(logits, need_probs=need_probs, **sampling_kwargs)
 
 
 def decode_n_tokens(
     model: Transformer,
     cur_token: torch.Tensor,
     input_pos: torch.Tensor,
     num_new_tokens: int,
+    need_probs: bool,
     callback=lambda _: _,
     **sampling_kwargs,
 ):
@@ -154,12 +159,13 @@ def decode_n_tokens(
         # Actually better for Inductor to codegen attention here
         with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
             next_token, next_prob = decode_one_token(
-                model, cur_token, input_pos, **sampling_kwargs
+                model, cur_token, input_pos, need_probs=need_probs, **sampling_kwargs
             )
             input_pos += 1
             new_tokens.append(next_token.clone())
             callback(new_tokens[-1])
-            new_probs.append(next_prob.clone())
+            if need_probs:
+                new_probs.append(next_prob.clone())
             cur_token = next_token.view(1, -1)
 
     return new_tokens, new_probs
@@ -187,6 +193,7 @@ def speculative_decode(
         cur_token.view(1, -1),
         orig_input_pos.clone(),
         speculate_k,
+        need_probs=True,
         **sampling_kwargs,
     )
 
@@ -301,6 +308,7 @@ def generate(
         input_pos,
         max_new_tokens - 1,
         callback=callback,
+        need_probs = False,
         **sampling_kwargs,
     )
     seq[T + 1 :] = torch.cat(generated_tokens)
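
To show what the new need_probs flag buys downstream, here is a toy decode loop in the spirit of the decode_n_tokens change (my own sketch with illustrative names, not the repository's code): when probabilities are not requested, each step is a single topk over the logits and nothing is accumulated in new_probs, which is the point of the fast path. speculative_decode still passes need_probs=True because its accept/reject step consumes the draft probabilities.

import torch
from typing import List

def toy_decode_loop(logits_stream: List[torch.Tensor], need_probs: bool):
    # Probabilities are computed and stored only when the caller asks for them.
    new_tokens: List[torch.Tensor] = []
    new_probs: List[torch.Tensor] = []
    for logits in logits_stream:
        if need_probs:
            probs = torch.softmax(logits, dim=-1)            # full distribution kept
            next_token = torch.multinomial(probs, num_samples=1)
            new_probs.append(probs.clone())
        else:
            _, next_token = torch.topk(logits, k=1, dim=-1)  # greedy fast path, no softmax
        new_tokens.append(next_token.clone())
    return new_tokens, new_probs

tokens, kept_probs = toy_decode_loop([torch.randn(32000) for _ in range(3)], need_probs=False)
assert len(tokens) == 3 and kept_probs == []  # nothing stored on the greedy path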
