We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 36f17d8 commit 7220c02Copy full SHA for 7220c02
llama/generation.py
@@ -74,6 +74,7 @@ def generate(
74
75
temperature_tensor = torch.tensor(temperature).to(device)
76
top_p_tensor = torch.tensor(top_p).to(device)
77
+ with_temp = temperature > 0
78
79
cache_kvs = self.model.cache_kvs
80
xm.mark_step()
@@ -97,7 +98,7 @@ def generate(
97
98
= self._generate_one_token_fn(
99
tokens, input_tokens, input_text_mask, cur_pos_tensor,
100
input_pos_tensor, output_pos_tensor, cache_kvs,
- temperature_tensor, top_p_tensor, temperature > 0
101
+ temperature_tensor, top_p_tensor, with_temp
102
)
103
104
@@ -109,7 +110,7 @@ def generate(
109
110
111
112
113
114
115
116
self.model.cache_kvs = cache_kvs
0 commit comments