Commit d2fb888

Merge pull request #22 from pytorch-tpu/liyanglu/tensorfy_temp_top_p
Turn `temperature` and `top_p` into tensors
2 parents f61383e + 8ab9f48 commit d2fb888

llama/generation.py

Lines changed: 16 additions & 7 deletions
@@ -20,11 +20,12 @@ def __init__(self, model: Transformer, tokenizer: Tokenizer):
                                           backend="torchxla_trace_once", fullgraph=True)
 
     def _generate_one_token(self, tokens, input_tokens, input_text_mask, cur_pos_tensor,
-                            input_pos_tensor, output_pos_tensor, cache_kvs, temperature, top_p):
+                            input_pos_tensor, output_pos_tensor, cache_kvs,
+                            temperature_tensor, top_p_tensor, with_temp):
         logits, cache_kvs = self.model(input_tokens, input_pos_tensor, output_pos_tensor, cache_kvs)
-        if temperature > 0:
-            probs = torch.softmax(logits / temperature, dim=-1)
-            next_token = sample_top_p(probs, top_p)
+        if with_temp:
+            probs = torch.softmax(logits / temperature_tensor, dim=-1)
+            next_token = sample_top_p(probs, top_p_tensor)
         else:
             next_token = torch.argmax(logits, dim=-1)
         next_token = next_token.reshape(-1)
@@ -71,6 +72,12 @@ def generate(
         tokens = tokens.to(device)
         input_text_mask = tokens != self.tokenizer.pad_id
 
+        # Passing tensors instead of floats into self._generate_one_token_fn,
+        # so that different values would not trigger compilations of new graphs
+        temperature_tensor = torch.tensor(temperature).to(device)
+        top_p_tensor = torch.tensor(top_p).to(device)
+        with_temp = temperature > 0
+
         cache_kvs = self.model.cache_kvs
         xm.mark_step()
 
@@ -92,7 +99,8 @@ def generate(
             tokens, input_tokens, cur_pos_tensor, input_pos_tensor, output_pos_tensor, cache_kvs \
                 = self._generate_one_token_fn(
                     tokens, input_tokens, input_text_mask, cur_pos_tensor,
-                    input_pos_tensor, output_pos_tensor, cache_kvs, temperature, top_p
+                    input_pos_tensor, output_pos_tensor, cache_kvs,
+                    temperature_tensor, top_p_tensor, with_temp
                 )
             xm.mark_step()
 
@@ -103,7 +111,8 @@ def generate(
             tokens, input_tokens, cur_pos_tensor, input_pos_tensor, output_pos_tensor, cache_kvs \
                 = self._generate_one_token_fn(
                     tokens, input_tokens, input_text_mask, cur_pos_tensor,
-                    input_pos_tensor, output_pos_tensor, cache_kvs, temperature, top_p
+                    input_pos_tensor, output_pos_tensor, cache_kvs,
+                    temperature_tensor, top_p_tensor, with_temp
                 )
             xm.mark_step()
         self.model.cache_kvs = cache_kvs
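
The in-code comment added in the second hunk carries the rationale for this commit: scalar arguments get baked into the traced graph, so each new `temperature` or `top_p` value would force a new compilation, while 0-dim tensors keep one reusable graph. Below is a rough, self-contained sketch of that idea, not from this repository: it uses plain torch.compile on CPU instead of the torchxla_trace_once backend, and a made-up function name (pick_token). The `with_temp` flag stays a plain Python bool because that branch has to be resolved at trace time, so it yields at most two graphs.

import torch

# Hedged sketch of the commit's rationale (illustrative names, plain torch.compile):
# a Python float argument tends to be specialized into the traced graph, so each new
# temperature value can trigger a recompile; a 0-dim tensor keeps one reusable graph.
@torch.compile(fullgraph=True)
def pick_token(logits, temperature_tensor, top_p_tensor, with_temp):
    if with_temp:
        probs = torch.softmax(logits / temperature_tensor, dim=-1)
        # the real code calls sample_top_p(probs, top_p_tensor) here;
        # argmax keeps the sketch self-contained
        return torch.argmax(probs, dim=-1)
    return torch.argmax(logits, dim=-1)

logits = torch.randn(2, 32000)
top_p_tensor = torch.tensor(0.9)
for t in (0.6, 0.8, 1.0):                      # the same compiled graph is reused
    temperature_tensor = torch.tensor(t)
    next_token = pick_token(logits, temperature_tensor, top_p_tensor, True)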
@@ -133,7 +142,7 @@ def generate(
 def sample_top_p(probs, p):
     probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
     probs_sum = torch.cumsum(probs_sort, dim=-1)
-    mask = probs_sum - probs_sort > p
+    mask = (probs_sum - probs_sort) > p
     probs_sort = torch.where(mask, 0.0, probs_sort)
     probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
     next_token = torch.multinomial(probs_sort, num_samples=1)
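
The last hunk only adds parentheses around `probs_sum - probs_sort`; since `-` binds more tightly than `>`, behaviour is unchanged and the edit is purely for readability. For reference, here is a small walk-through of the top-p masking rule with made-up numbers (not part of the commit):

import torch

# Tokens whose cumulative probability *before* them already exceeds p are masked
# out, then the survivors are renormalized and sampled from.
probs_sort = torch.tensor([[0.5, 0.3, 0.2]])             # already sorted, descending
p = 0.7
probs_sum = torch.cumsum(probs_sort, dim=-1)             # [[0.5, 0.8, 1.0]]
mask = (probs_sum - probs_sort) > p                      # [[False, False, True]]
probs_sort = torch.where(mask, 0.0, probs_sort)          # [[0.5, 0.3, 0.0]]
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))    # [[0.625, 0.375, 0.0]]
next_token = torch.multinomial(probs_sort, num_samples=1)  # samples index 0 or 1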
