example_xla.py (2 changes: 1 addition & 1 deletion)
@@ -98,7 +98,7 @@ def main(
         ckpt_dir, tokenizer_path, rank, world_size, max_seq_len, max_batch_size, dim, n_layers, n_heads
     )
 
-    prompts = [generator.tokenizer.decode([8]*prompt_len)]
+    prompts = [generator.tokenizer.decode([8]*prompt_len) for _ in range(max_batch_size)]
     # prompts = [
     # For these prompts, the expected answer is the natural continuation of the prompt
     # "I believe the meaning of life is",
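With this change every entry in prompts is the same dummy string, so all prompt token lists have equal length and min_prompt_size equals max_prompt_size in generate() below. A minimal standalone sketch of that effect, using hypothetical values for prompt_len and max_batch_size and plain strings standing in for the tokenizer:

prompt_len = 6        # hypothetical; the real value comes from the script's arguments
max_batch_size = 4    # hypothetical
prompts = ["token " * prompt_len for _ in range(max_batch_size)]  # stands in for tokenizer.decode([8] * prompt_len)
prompt_tokens = [p.split() for p in prompts]                      # stands in for tokenizer.encode(...)
assert len(prompts) == max_batch_size
assert min(len(t) for t in prompt_tokens) == max(len(t) for t in prompt_tokens) == prompt_len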
llama/generation.py (40 changes: 31 additions & 9 deletions)
@@ -58,35 +58,57 @@ def generate(

         prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
 
         min_prompt_size = min([len(t) for t in prompt_tokens])
         max_prompt_size = max([len(t) for t in prompt_tokens])
+        assert min_prompt_size >= 1 and max_prompt_size < params.max_seq_len
 
         total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
 
-        tokens = torch.full((params.max_batch_size, total_len), self.tokenizer.pad_id).long()
+        tokens = torch.full((params.max_batch_size, params.max_seq_len), self.tokenizer.pad_id).long()
         for k, t in enumerate(prompt_tokens):
             tokens[k, : len(t)] = torch.tensor(t).long()
         device = xm.xla_device()
         tokens = tokens.to(device)
         input_text_mask = tokens != self.tokenizer.pad_id
 
-        start_pos = 1
-        cur_pos_tensor = torch.tensor(start_pos).to(device)
-        input_pos_tensor = torch.arange(0, start_pos).to(device)
-        output_pos_tensor = cur_pos_tensor - 1
-        input_tokens = tokens.index_select(1, input_pos_tensor)
         cache_kvs = self.model.cache_kvs
-        xm.mark_step(wait=True)
+        xm.mark_step()
 
         decoding_start_time = time.time()
-        for _ in range(start_pos, total_len):
+        prev_pos = 0
+        scale_factor = 8
+        while prev_pos < min_prompt_size:
+            section_len = 1
+            while prev_pos + section_len * scale_factor <= min_prompt_size:
+                section_len *= scale_factor
+            cur_pos = prev_pos + section_len
+            print(f"Processing prompt pos [{prev_pos}, {cur_pos}), section length {section_len}")
+            cur_pos_tensor = torch.tensor(cur_pos).to(device)
+            input_pos_tensor = torch.arange(prev_pos, cur_pos).to(device)
+            output_pos_tensor = cur_pos_tensor - 1
+            input_tokens = tokens.index_select(1, input_pos_tensor)
+            xm.mark_step()
+
+            tokens, input_tokens, cur_pos_tensor, input_pos_tensor, output_pos_tensor, cache_kvs \
+                = self._generate_one_token_fn(
+                    tokens, input_tokens, input_text_mask, cur_pos_tensor,
+                    input_pos_tensor, output_pos_tensor, cache_kvs, temperature, top_p
+                )
+            xm.mark_step()
+
+            prev_pos = cur_pos
+
+        assert cur_pos_tensor.item() == prev_pos + 1
+        for _ in range(prev_pos + 1, total_len):
             tokens, input_tokens, cur_pos_tensor, input_pos_tensor, output_pos_tensor, cache_kvs \
                 = self._generate_one_token_fn(
                    tokens, input_tokens, input_text_mask, cur_pos_tensor,
                    input_pos_tensor, output_pos_tensor, cache_kvs, temperature, top_p
                )
            xm.mark_step()
         self.model.cache_kvs = cache_kvs
-        print(f"Decoded {total_len-1} tokens in {time.time() - decoding_start_time:.5f} seconds")
+        print(f"Processed prompts with {min_prompt_size} to {max_prompt_size} tokens, and generated {total_len - max_prompt_size} tokens")
+        print(f"Totally decoded {total_len - 1} tokens in {time.time() - decoding_start_time:.5f} seconds")
 
         decoded = []
         for i, t in enumerate(tokens.tolist()):
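The core of the generation.py change is the new while loop, which feeds the prompt to the model in sections whose lengths are always powers of scale_factor, presumably so that XLA only sees a small number of distinct input widths instead of processing the prompt one token at a time. Below is a minimal standalone sketch of just that scheduling logic; prompt_sections is a hypothetical helper and the 100-token prompt is made up, neither comes from the PR:

# Standalone sketch (hypothetical helper, not in the PR) of the section scheduling
# performed by the new while loop: section lengths are always powers of scale_factor.
def prompt_sections(min_prompt_size, scale_factor=8):
    sections = []
    prev_pos = 0
    while prev_pos < min_prompt_size:
        section_len = 1
        while prev_pos + section_len * scale_factor <= min_prompt_size:
            section_len *= scale_factor
        sections.append((prev_pos, prev_pos + section_len))
        prev_pos += section_len
    return sections

# For a made-up 100-token prompt the section widths are 64, 8, 8, 8, 8, 1, 1, 1, 1:
print(prompt_sections(100))
# [(0, 64), (64, 72), (72, 80), (80, 88), (88, 96), (96, 97), (97, 98), (98, 99), (99, 100)]

After the while loop, the diff falls back to the existing one-token-at-a-time decode loop from prev_pos + 1 up to total_len.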