Skip to content

Commit 7220c02

Browse files
committed
update
1 parent 36f17d8 commit 7220c02

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

llama/generation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def generate(
7474

7575
temperature_tensor = torch.tensor(temperature).to(device)
7676
top_p_tensor = torch.tensor(top_p).to(device)
77+
with_temp = temperature > 0
7778

7879
cache_kvs = self.model.cache_kvs
7980
xm.mark_step()
@@ -97,7 +98,7 @@ def generate(
9798
= self._generate_one_token_fn(
9899
tokens, input_tokens, input_text_mask, cur_pos_tensor,
99100
input_pos_tensor, output_pos_tensor, cache_kvs,
100-
temperature_tensor, top_p_tensor, temperature > 0
101+
temperature_tensor, top_p_tensor, with_temp
101102
)
102103
xm.mark_step()
103104

@@ -109,7 +110,7 @@ def generate(
109110
= self._generate_one_token_fn(
110111
tokens, input_tokens, input_text_mask, cur_pos_tensor,
111112
input_pos_tensor, output_pos_tensor, cache_kvs,
112-
temperature_tensor, top_p_tensor, temperature > 0
113+
temperature_tensor, top_p_tensor, with_temp
113114
)
114115
xm.mark_step()
115116
self.model.cache_kvs = cache_kvs

0 commit comments

Comments
 (0)