1 parent ddb7a5e commit 8ab9f48
llama/generation.py
@@ -72,6 +72,8 @@ def generate(
        tokens = tokens.to(device)
        input_text_mask = tokens != self.tokenizer.pad_id

+       # Passing tensors instead of floats into self._generate_one_token_fn,
+       # so that different values would not trigger compilations of new graphs
        temperature_tensor = torch.tensor(temperature).to(device)
        top_p_tensor = torch.tensor(top_p).to(device)
        with_temp = temperature > 0
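The two added comment lines summarize the motivation: a backend that traces and compiles the per-token generation step specializes the compiled graph on Python scalars, so each new temperature or top_p value would be baked into the trace as a constant and force a fresh compilation, whereas values held in 0-dim tensors are ordinary graph inputs and the same compiled graph can be reused. The sketch below only illustrates that pattern; it is not the repository's _generate_one_token_fn, and the function name sample_next_token as well as the use of torch.compile are assumptions made for the example.

import torch

# Hypothetical sketch of the pattern behind the commit: scalar sampling
# parameters are passed as 0-dim tensors, so the compiled step function is
# traced once and reused for any parameter value. torch.compile is used here
# only for illustration; the same reasoning applies to other graph-compiling
# backends that would otherwise specialize on Python floats.
@torch.compile
def sample_next_token(logits, temperature_tensor, top_p_tensor):
    # Dividing by the tensor (not a Python float) keeps the temperature a
    # graph input rather than a constant baked into the trace.
    probs = torch.softmax(logits / temperature_tensor, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Mask out tokens beyond the nucleus (top-p) cutoff, then renormalize.
    outside_nucleus = cumulative - sorted_probs > top_p_tensor
    sorted_probs = sorted_probs.masked_fill(outside_nucleus, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return torch.gather(sorted_idx, -1, choice)

device = "cpu"
logits = torch.randn(1, 32000, device=device)
# Wrap the scalars once; later changing 0.8 to another value reuses the
# already-compiled graph instead of triggering a new compilation.
temperature_tensor = torch.tensor(0.8).to(device)
top_p_tensor = torch.tensor(0.95).to(device)
next_token = sample_next_token(logits, temperature_tensor, top_p_tensor)

In the actual change, the wrapped temperature_tensor and top_p_tensor are passed into self._generate_one_token_fn in place of the former float arguments, which is what the new comment in the diff describes.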