Skip to content

Commit ddb7a5e

Browse files
committed
recover tmp changes
1 parent 7220c02 commit ddb7a5e

File tree

1 file changed

+3
-11
lines changed

1 file changed

+3
-11
lines changed

example_xla.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def main(
9898
ckpt_dir, tokenizer_path, rank, world_size, max_seq_len, max_batch_size, dim, n_layers, n_heads
9999
)
100100

101+
prompts = [generator.tokenizer.decode([8]*prompt_len) for _ in range(max_batch_size)]
101102
# prompts = [
102103
# For these prompts, the expected answer is the natural continuation of the prompt
103104
# "I believe the meaning of life is",
@@ -125,19 +126,10 @@ def main(
125126
#
126127
#cheese =>""",
127128
# ]
128-
129-
pairs = []
130-
for l in [1500]:
131-
for t in [0.1, 0.5, 0]:
132-
for p in [0.8, 0.9]:
133-
pairs.append([l, t, p])
134-
135-
for l, t, p in pairs:
136-
print(l, t, p)
137-
prompts = [generator.tokenizer.decode([8]*l) for _ in range(max_batch_size)]
129+
for _ in range(2):
138130
with torch.no_grad():
139131
results = generator.generate(
140-
prompts, max_gen_len=max_gen_len, temperature=t, top_p=p
132+
prompts, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p
141133
)
142134

143135
for result in results:

0 commit comments

Comments
 (0)