Skip to content

Commit 9222169

Browse files
authored
Merge pull request #23 from pytorch-tpu/liyanglu/tensorfy_temp_top_p
minor update
2 parents d2fb888 + 516351f commit 9222169

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

llama/generation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ def generate(
7474

7575
# Passing tensors instead of floats into self._generate_one_token_fn,
7676
# so that different values would not trigger compilations of new graphs
77-
temperature_tensor = torch.tensor(temperature).to(device)
78-
top_p_tensor = torch.tensor(top_p).to(device)
77+
temperature_tensor = torch.tensor(float(temperature)).to(device)
78+
top_p_tensor = torch.tensor(float(top_p)).to(device)
7979
with_temp = temperature > 0
8080

8181
cache_kvs = self.model.cache_kvs

0 commit comments

Comments
 (0)