44 changes: 23 additions & 21 deletions generate.py
@@ -17,10 +17,11 @@

@torch.no_grad()
def generate(
model: torch.nn.Module,
model: LLaMA,
idx: torch.Tensor,
max_new_tokens: int,
max_seq_length: int,
*,
max_seq_length: Optional[int] = None,
temperature: float = 1.0,
top_k: Optional[int] = None,
eos_id: Optional[int] = None,
@@ -41,35 +42,43 @@ def generate(
# create an empty tensor of the expected final shape and fill in the current tokens
T = idx.size(0)
T_new = T + max_new_tokens
empty = torch.empty(T_new, dtype=idx.dtype, device=idx.device)
if max_seq_length is None:
max_seq_length = min(T_new, model.config.block_size)
# otherwise this would use more memory than necessary
assert max_seq_length <= T_new

device, dtype = idx.device, idx.dtype
# create an empty tensor of the expected final shape and fill in the current tokens
empty = torch.empty(T_new, dtype=dtype, device=device)
empty[:T] = idx
idx = empty
input_pos = torch.arange(0, T, device=device)

# generate max_new_tokens tokens
for t in range(T, T_new):
# ignore the not-filled-yet tokens
idx_cond = idx[:t]
# if the sequence context is growing too long we must crop it at max_seq_length
idx_cond = idx_cond if t <= max_seq_length else idx_cond[-max_seq_length:]
for _ in range(max_new_tokens):
x = idx.index_select(0, input_pos).view(1, -1)

# forward
logits = model(idx_cond.view(1, -1))
logits = model(x, max_seq_length, input_pos)
logits = logits[0, -1] / temperature

# optionally crop the logits to only the top k options
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[[-1]]] = -float("Inf")
logits = torch.where(logits < v[[-1]], -float("Inf"), logits)

probs = torch.nn.functional.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)

# advance
input_pos = input_pos[-1:] + 1

# concatenate the new generation
idx[t] = idx_next
idx = idx.index_copy(0, input_pos, idx_next)

# if <eos> token is triggered, return the output (stop generation)
if idx_next == eos_id:
return idx[:t + 1] # include the EOS token
return idx[:input_pos] # include the EOS token

return idx
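
The rewritten loop preallocates the output buffer once and drives decoding with an explicit input_pos index: the whole prompt is fed on the first step, and only the single newest token on every later step. Below is a minimal, self-contained sketch of that pattern; toy_logits is a hypothetical stand-in for the model call and is not part of this PR.

import torch

def toy_logits(x: torch.Tensor, vocab_size: int = 8) -> torch.Tensor:
    # stand-in for model(x, max_seq_length, input_pos); returns (1, T, vocab_size) logits
    return torch.randn(1, x.size(1), vocab_size)

prompt = torch.tensor([1, 2, 3])
max_new_tokens = 4
T = prompt.size(0)
buf = torch.empty(T + max_new_tokens, dtype=prompt.dtype)
buf[:T] = prompt
input_pos = torch.arange(0, T)                 # prefill: feed the whole prompt once
for _ in range(max_new_tokens):
    x = buf.index_select(0, input_pos).view(1, -1)
    probs = torch.softmax(toy_logits(x)[0, -1], dim=-1)
    nxt = torch.multinomial(probs, num_samples=1).to(dtype=buf.dtype)
    input_pos = input_pos[-1:] + 1             # afterwards: advance one position at a time
    buf = buf.index_copy(0, input_pos, nxt)
print(buf.tolist())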

@@ -129,14 +138,7 @@ def main(
L.seed_everything(1234)
for i in range(num_samples):
t0 = time.perf_counter()
y = generate(
model,
encoded,
max_new_tokens,
model.config.block_size, # type: ignore[union-attr,arg-type]
temperature=temperature,
top_k=top_k,
)
y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k)
t = time.perf_counter() - t0

print(tokenizer.decode(y))
142 changes: 108 additions & 34 deletions lit_llama/model.py
@@ -5,7 +5,7 @@
# mypy: ignore-errors
import math
from dataclasses import dataclass
from typing import Optional
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -15,6 +15,10 @@
from lit_llama.utils import find_multiple


RoPECache = Tuple[torch.Tensor, torch.Tensor]
KVCache = Tuple[torch.Tensor, torch.Tensor]


@dataclass
class LLaMAConfig:
block_size: int = 2048
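
The RoPECache and KVCache aliases introduced above are plain (tensor, tensor) pairs. For the KV cache, the model's forward (further down in this diff) allocates one pair per layer, each tensor shaped (B, n_head, max_seq_length, head_size). A tiny illustration with made-up sizes:

import torch

B, n_head, max_seq_length, head_size = 1, 4, 8, 16                # made-up sizes
cache_shape = (B, n_head, max_seq_length, head_size)
kv_cache = (torch.zeros(cache_shape), torch.zeros(cache_shape))   # (k, v) for one layer
print(kv_cache[0].shape)                                          # torch.Size([1, 4, 8, 16])
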
@@ -56,23 +60,59 @@ def __init__(self, config: LLaMAConfig) -> None:
)
)

self.rope_cache: Optional[RoPECache] = None
self.mask_cache: Optional[torch.Tensor] = None
self.kv_caches: List[KVCache] = []


def _init_weights(self, module: nn.Module) -> None:
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))

def forward(self, idx: torch.Tensor) -> torch.Tensor:
_, t = idx.size()
assert (
t <= self.config.block_size
), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
def forward(
self, idx: torch.Tensor, max_seq_length: Optional[int] = None, input_pos: Optional[torch.Tensor] = None
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[KVCache]]]:
B, T = idx.size()

block_size = self.config.block_size
if max_seq_length is None:
max_seq_length = block_size
assert T <= max_seq_length, f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}"
assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}"
assert T <= block_size, f"Cannot forward sequence of length {T}, block size is only {block_size}"

# forward the LLaMA model itself
if self.rope_cache is None:
self.rope_cache = self.build_rope_cache(idx)
if self.mask_cache is None:
self.mask_cache = self.build_mask_cache(idx)

if input_pos is not None:
rope = self.rope_cache.index_select(0, input_pos)
mask = self.mask_cache.index_select(2, input_pos)
mask = mask[:, :, :, :max_seq_length]
else:
rope = self.rope_cache[:T]
mask = self.mask_cache[:, :, :T, :T]

# forward the model itself
x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)

for block in self.transformer.h:
x = block(x)
if input_pos is None: # proxy for use_cache=False
for block in self.transformer.h:
x, _ = block(x, rope, mask, max_seq_length)
else:
if not self.kv_caches:
head_size = self.config.n_embd // self.config.n_head
cache_shape = (B, self.config.n_head, max_seq_length, head_size)
self.kv_caches = [
(torch.zeros(cache_shape, device=idx.device), torch.zeros(cache_shape, device=idx.device))
for _ in range(self.config.n_layer)
]
for i, block in enumerate(self.transformer.h):
x, self.kv_caches[i] = block(x, rope, mask, max_seq_length, input_pos, self.kv_caches[i])

x = self.transformer.ln_f(x)

logits = self.lm_head(x) # (b, t, vocab_size)
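
When input_pos is given, the precomputed rope and mask caches are sliced at exactly those positions instead of the leading [:T] slice, and the mask is cropped to the cache length. A small illustration of the mask path, with arbitrary sizes (block_size=6, max_seq_length=4):

import torch

block_size, max_seq_length = 6, 4
ones = torch.ones((block_size, block_size), dtype=torch.bool)
mask_cache = torch.tril(ones).unsqueeze(0).unsqueeze(0)    # (1, 1, 6, 6), causal triangle
input_pos = torch.tensor([3])                              # currently decoding position 3
mask = mask_cache.index_select(2, input_pos)               # (1, 1, 1, 6): row 3 of the triangle
mask = mask[:, :, :, :max_seq_length]                      # crop keys to the cache length
print(mask)                                                # tensor([[[[True, True, True, True]]]])
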
@@ -83,6 +123,18 @@ def forward(self, idx: torch.Tensor) -> torch.Tensor:
def from_name(cls, name: str) -> Self:
return cls(LLaMAConfig.from_name(name))

def build_rope_cache(self, idx: torch.Tensor) -> KVCache:
return build_rope_cache(
seq_len=self.config.block_size,
n_elem=self.config.n_embd // self.config.n_head,
dtype=idx.dtype,
device=idx.device,
)

def build_mask_cache(self, idx: torch.Tensor) -> torch.Tensor:
ones = torch.ones((self.config.block_size, self.config.block_size), device=idx.device, dtype=torch.bool)
return torch.tril(ones).unsqueeze(0).unsqueeze(0)


class Block(nn.Module):
def __init__(self, config: LLaMAConfig) -> None:
@@ -92,10 +144,19 @@ def __init__(self, config: LLaMAConfig) -> None:
self.rms_2 = RMSNorm(config.n_embd)
self.mlp = MLP(config)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.attn(self.rms_1(x))
def forward(
self,
x: torch.Tensor,
rope: RoPECache,
mask: torch.Tensor,
max_seq_length: int,
input_pos: Optional[torch.Tensor] = None,
kv_cache: Optional[KVCache] = None,
) -> Tuple[torch.Tensor, Optional[KVCache]]:
h, new_kv_cache = self.attn(self.rms_1(x), rope, mask, max_seq_length, input_pos, kv_cache)
x = x + h
x = x + self.mlp(self.rms_2(x))
return x
return x, new_kv_cache


class CausalSelfAttention(nn.Module):
@@ -111,30 +172,45 @@ def __init__(self, config: LLaMAConfig) -> None:
self.n_head = config.n_head
self.n_embd = config.n_embd
self.block_size = config.block_size
self.rope_cache: Optional[torch.Tensor] = None

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(
self,
x: torch.Tensor,
rope: RoPECache,
mask: torch.Tensor,
max_seq_length: int,
input_pos: Optional[torch.Tensor] = None,
kv_cache: Optional[KVCache] = None,
) -> Tuple[torch.Tensor, Optional[KVCache]]:
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

# calculate query, key, values for all heads in batch and move head forward to be the batch dim
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

head_size = C // self.n_head
k = k.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
q = q.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)

if self.rope_cache is None:
# cache for future forward calls
self.rope_cache = build_rope_cache(
seq_len=self.block_size,
n_elem=self.n_embd // self.n_head,
dtype=x.dtype,
device=x.device,
)

q = apply_rope(q, self.rope_cache)
k = apply_rope(k, self.rope_cache)
k = k.view(B, T, self.n_head, head_size)
q = q.view(B, T, self.n_head, head_size)
v = v.view(B, T, self.n_head, head_size)

q = apply_rope(q, rope)
k = apply_rope(k, rope)

k = k.transpose(1, 2) # (B, nh, T, hs)
q = q.transpose(1, 2) # (B, nh, T, hs)
v = v.transpose(1, 2) # (B, nh, T, hs)


if kv_cache is not None:
cache_k, cache_v = kv_cache
# check if reached token limit
if input_pos[-1] >= max_seq_length:
input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device)
# shift 1 position to the left
cache_k = torch.roll(cache_k, -1, dims=2)
cache_v = torch.roll(cache_v, -1, dims=2)
k = cache_k.index_copy(2, input_pos, k)
v = cache_v.index_copy(2, input_pos, v)
kv_cache = k, v

# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
# att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
@@ -143,14 +219,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)

# efficient attention using Flash Attention CUDA kernels
y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

# output projection
y = self.c_proj(y)

return y
return y, kv_cache
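
Two details of the forward above may be easier to see with toy numbers. When the preallocated cache is full, torch.roll shifts the oldest entry out and the new key/value is written into the last slot; and because the query now covers only the newest position(s), causality comes from the explicit boolean attn_mask rather than is_causal=True. The sketch below is illustrative only, with made-up shapes and values.

import torch
import torch.nn.functional as F

# sliding-window cache update: a length-4 cache is full and receives what would be position 4
cache_k = torch.arange(4.0).view(1, 1, 4, 1)        # cached keys for positions 0..3
input_pos = torch.tensor([4])
max_seq_length = 4
if input_pos[-1] >= max_seq_length:
    input_pos = torch.tensor([max_seq_length - 1])  # write into the last slot...
    cache_k = torch.roll(cache_k, -1, dims=2)       # ...after shifting out the oldest entry
k = cache_k.index_copy(2, input_pos, torch.full((1, 1, 1, 1), 4.0))
print(k.view(-1))                                   # tensor([1., 2., 3., 4.])

# attention over cached keys with an explicit causal mask (a single new query position)
B, n_head, head_size = 1, 2, 8
q = torch.randn(B, n_head, 1, head_size)
k = torch.randn(B, n_head, 4, head_size)
v = torch.randn(B, n_head, 4, head_size)
mask = torch.tensor([[[[True, True, True, False]]]])   # e.g. position 2 sees keys 0..2 only
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
print(y.shape)                                      # torch.Size([1, 2, 1, 8])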


class MLP(nn.Module):
@@ -218,8 +294,6 @@ def build_rope_cache(seq_len: int, n_elem: int, dtype: torch.dtype, device: torc


def apply_rope(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
x = x.transpose(1, 2)

# truncate to support variable sizes
T = x.size(1)
rope_cache = rope_cache[:T]
@@ -233,4 +307,4 @@ def apply_rope(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
], -1)

x_out2 = x_out2.flatten(3)
return x_out2.transpose(1, 2).type_as(x)
return x_out2.type_as(x)
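
With the transposes removed, apply_rope now operates on tensors laid out as (B, T, n_head, head_size), before the heads are moved next to the batch dimension, so the cache is truncated along dim 1 directly. Below is a generic rotary-embedding sketch over that layout; it is not the repo's apply_rope, and the even/odd pairing is only illustrative.

import torch

def rope_sketch(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (B, T, n_head, head_size); cos/sin: (T, head_size // 2)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos, sin = cos[:, None, :], sin[:, None, :]            # broadcast over batch and heads
    out = torch.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
    return out.flatten(-2).type_as(x)

B, T, n_head, head_size = 1, 5, 2, 8
x = torch.randn(B, T, n_head, head_size)
freqs = 1.0 / (10000 ** (torch.arange(0, head_size, 2).float() / head_size))
angles = torch.outer(torch.arange(T).float(), freqs)       # (T, head_size // 2)
print(rope_sketch(x, angles.cos(), angles.sin()).shape)    # torch.Size([1, 5, 2, 8])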