Skip to content

Commit 516351f

Browse files
committed
minor update
1 parent 8ab9f48 commit 516351f

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

llama/generation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ def generate(
7474

7575
# Passing tensors instead of floats into self._generate_one_token_fn,
7676
# so that different values would not trigger compilations of new graphs
77-
temperature_tensor = torch.tensor(temperature).to(device)
78-
top_p_tensor = torch.tensor(top_p).to(device)
77+
temperature_tensor = torch.tensor(float(temperature)).to(device)
78+
top_p_tensor = torch.tensor(float(top_p)).to(device)
7979
with_temp = temperature > 0
8080

8181
cache_kvs = self.model.cache_kvs

0 commit comments

Comments
 (0)