Merged (changes from 14 commits)
generate.py: 44 changes (22 additions, 22 deletions)
@@ -17,10 +17,11 @@

@torch.no_grad()
def generate(
    model: torch.nn.Module,
    model: LLaMA,
    idx: torch.Tensor,
    max_new_tokens: int,
    max_seq_length: int,
    *,
    max_seq_length: Optional[int] = None,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    eos_id: Optional[int] = None,
@@ -41,44 +42,49 @@ def generate(
    # create an empty tensor of the expected final shape and fill in the current tokens
    T = idx.size(0)
    T_new = T + max_new_tokens
    empty = torch.empty(T_new, dtype=idx.dtype, device=idx.device)
    if max_seq_length is None:
        max_seq_length = min(T_new, model.config.block_size)

    device, dtype = idx.device, idx.dtype
    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(T_new, dtype=dtype, device=device)
    empty[:T] = idx
    idx = empty
    input_pos = torch.arange(0, T, device=device)

    if idx.device.type == "xla":
        import torch_xla.core.xla_model as xm

        xm.mark_step()

    # generate max_new_tokens tokens
    for t in range(T, T_new):
        # ignore the not-filled-yet tokens
        idx_cond = idx[:t]
        # if the sequence context is growing too long we must crop it at max_seq_length
        idx_cond = idx_cond if t <= max_seq_length else idx_cond[-max_seq_length:]
    for _ in range(max_new_tokens):
        x = idx.index_select(0, input_pos).view(1, -1)

        # forward
        logits = model(idx_cond.view(1, -1))
        logits = model(x, max_seq_length, input_pos)
        logits = logits[0, -1] / temperature

        # optionally crop the logits to only the top k options
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[[-1]]] = -float("Inf")
            logits = torch.where(logits < v[[-1]], -float("Inf"), logits)

        probs = torch.nn.functional.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)

        # advance
        input_pos = input_pos[-1:] + 1

        if idx.device.type == "xla":
            xm.mark_step()

        # concatenate the new generation
        # https://github.com/pytorch/pytorch/issues/101936
        idx[t] = idx_next.item() if idx.device.type == "mps" else idx_next
        idx = idx.index_copy(0, input_pos, idx_next)

        # if <eos> token is triggered, return the output (stop generation)
        if idx_next == eos_id:
            return idx[:t + 1]  # include the EOS token
            return idx[:input_pos]  # include the EOS token

    return idx
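As an aside for readers of this hunk, the sketch below (not part of the PR) traces only the position bookkeeping the new loop relies on: the first iteration feeds the whole prompt, and every later iteration selects a single position with `index_select` and writes the sampled token back with `index_copy`. The prompt length, token ids, and the `next_token` stand-in are invented for illustration.

```python
import torch

# Toy trace of the decode-loop bookkeeping above; values are illustrative only.
prompt_len, max_new_tokens = 3, 2
idx = torch.zeros(prompt_len + max_new_tokens, dtype=torch.long)
idx[:prompt_len] = torch.tensor([11, 22, 33])        # pretend prompt token ids
input_pos = torch.arange(0, prompt_len)              # prefill positions [0, 1, 2]

for step in range(max_new_tokens):
    x = idx.index_select(0, input_pos).view(1, -1)   # whole prompt once, then one token
    print(f"step {step}: input_pos={input_pos.tolist()}, x.shape={tuple(x.shape)}")
    next_token = torch.tensor([99 + step])           # stand-in for the sampled token
    input_pos = input_pos[-1:] + 1                   # advance to the next position
    idx = idx.index_copy(0, input_pos, next_token)   # out-of-place write into the buffer

print(idx.tolist())  # [11, 22, 33, 99, 100]
```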

@@ -138,16 +144,10 @@ def main(
    L.seed_everything(1234)
    for i in range(num_samples):
        t0 = time.perf_counter()
        y = generate(
            model,
            encoded,
            max_new_tokens,
            model.config.block_size,  # type: ignore[union-attr,arg-type]
            temperature=temperature,
            top_k=top_k,
        )
        y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k)
        t = time.perf_counter() - t0

        model.reset_cache()
        print(tokenizer.decode(y))
        tokens_generated = y.size(0) - prompt_length
        print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
generate/adapter.py: 11 changes (2 additions, 9 deletions)
@@ -82,17 +82,10 @@ def main(
    prompt_length = encoded.size(0)

    t0 = time.perf_counter()
    y = generate(
        model,
        idx=encoded,
        max_seq_length=max_new_tokens,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        eos_id=tokenizer.eos_id
    )
    y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)
    t = time.perf_counter() - t0

    model.reset_cache()
    output = tokenizer.decode(y)
    output = output.split("### Response:")[1].strip()
    print(output)
generate/adapter_v2.py: 12 changes (2 additions, 10 deletions)
@@ -6,7 +6,6 @@

import lightning as L
import torch
import torch.nn as nn

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
@@ -85,17 +84,10 @@ def main(
    prompt_length = encoded.size(0)

    t0 = time.perf_counter()
    y = generate(
        model,
        idx=encoded,
        max_seq_length=max_new_tokens,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        eos_id=tokenizer.eos_id
    )
    y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)
    t = time.perf_counter() - t0

    model.reset_cache()
    output = tokenizer.decode(y)
    output = output.split("### Response:")[1].strip()
    print(output)
generate/full.py: 69 changes (8 additions, 61 deletions)
@@ -7,68 +7,14 @@
import lightning as L
import torch

# support running without installing as a package
wd = Path(__file__).absolute().parent.parent
sys.path.append(str(wd))

from lit_llama import LLaMA, Tokenizer
from lit_llama.utils import EmptyInitOnDevice
from scripts.prepare_alpaca import generate_prompt

@torch.no_grad()
def generate(
    model: torch.nn.Module,
    idx: torch.Tensor,
    max_new_tokens: int,
    max_seq_length: int,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    eos_id: Optional[int] = None,
) -> torch.Tensor:
    """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    The implementation of this function is modified from A. Karpathy's nanoGPT.

    Args:
        model: The model to use.
        idx: Tensor of shape (T) with indices of the prompt sequence.
        max_new_tokens: The number of new tokens to generate.
        max_seq_length: The maximum sequence length allowed.
        temperature: Scales the predicted logits by 1 / temperature
        top_k: If specified, only sample among the tokens with the k highest probabilities
        eos_id: If specified, stop generating any more token once the <eos> token is triggered
    """
    # create an empty tensor of the expected final shape and fill in the current tokens
    T = idx.size(0)
    T_new = T + max_new_tokens
    empty = torch.empty(T_new, dtype=idx.dtype, device=idx.device)
    empty[:T] = idx
    idx = empty

    # generate max_new_tokens tokens
    for t in range(T, T_new):
        # ignore the not-filled-yet tokens
        idx_cond = idx[:t]
        # if the sequence context is growing too long we must crop it at max_seq_length
        idx_cond = idx_cond if t <= max_seq_length else idx_cond[-max_seq_length:]

        # forward
        logits = model(idx_cond.view(1, -1))
        logits = logits[0, -1] / temperature

        # optionally crop the logits to only the top k options
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[[-1]]] = -float("Inf")

        probs = torch.nn.functional.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)

        # concatenate the new generation
        # https://github.com/pytorch/pytorch/issues/101936
        idx[t] = idx_next.item() if idx.device.type == "mps" else idx_next

        # if <eos> token is triggered, return the output (stop generation)
        if idx_next == eos_id:
            return idx[:t + 1]  # include the EOS token

    return idx
from generate import generate


def main(
@@ -79,7 +25,7 @@ def main(
    top_k: int = 200,
    temperature: float = 0.8,
    checkpoint_path: Optional[Path] = None,
    tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
    tokenizer_path: Path = Path("../checkpoints/lit-llama/tokenizer.model"),
    model_size: str = "7B",
    quantize: Optional[str] = None,
) -> None:
Expand All @@ -100,7 +46,7 @@ def main(
``"gptq.int4"``: GPTQ 4-bit mode.
"""
if not checkpoint_path:
checkpoint_path = Path(f"checkpoints/lit-llama/{model_size}/lit-llama.pth")
checkpoint_path = Path(f"../checkpoints/lit-llama/{model_size}/lit-llama.pth")
assert checkpoint_path.is_file(), checkpoint_path
assert tokenizer_path.is_file(), tokenizer_path

@@ -140,6 +86,7 @@
        )
        t = time.perf_counter() - t0

        model.reset_cache()
        print(tokenizer.decode(y))
        tokens_generated = y.size(0) - prompt_length
        print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)