diff --git a/generate.py b/generate.py index 2b19bb70..72940b23 100644 --- a/generate.py +++ b/generate.py @@ -17,10 +17,11 @@ @torch.no_grad() def generate( - model: torch.nn.Module, + model: LLaMA, idx: torch.Tensor, max_new_tokens: int, - max_seq_length: int, + *, + max_seq_length: Optional[int] = None, temperature: float = 1.0, top_k: Optional[int] = None, eos_id: Optional[int] = None, @@ -41,9 +42,15 @@ def generate( # create an empty tensor of the expected final shape and fill in the current tokens T = idx.size(0) T_new = T + max_new_tokens - empty = torch.empty(T_new, dtype=idx.dtype, device=idx.device) + if max_seq_length is None: + max_seq_length = min(T_new, model.config.block_size) + + device, dtype = idx.device, idx.dtype + # create an empty tensor of the expected final shape and fill in the current tokens + empty = torch.empty(T_new, dtype=dtype, device=device) empty[:T] = idx idx = empty + input_pos = torch.arange(0, T, device=device) if idx.device.type == "xla": import torch_xla.core.xla_model as xm @@ -51,34 +58,33 @@ def generate( xm.mark_step() # generate max_new_tokens tokens - for t in range(T, T_new): - # ignore the not-filled-yet tokens - idx_cond = idx[:t] - # if the sequence context is growing too long we must crop it at max_seq_length - idx_cond = idx_cond if t <= max_seq_length else idx_cond[-max_seq_length:] + for _ in range(max_new_tokens): + x = idx.index_select(0, input_pos).view(1, -1) # forward - logits = model(idx_cond.view(1, -1)) + logits = model(x, max_seq_length, input_pos) logits = logits[0, -1] / temperature # optionally crop the logits to only the top k options if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) - logits[logits < v[[-1]]] = -float("Inf") + logits = torch.where(logits < v[[-1]], -float("Inf"), logits) probs = torch.nn.functional.softmax(logits, dim=-1) - idx_next = torch.multinomial(probs, num_samples=1) + idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype) + + # advance + input_pos = input_pos[-1:] + 1 if idx.device.type == "xla": xm.mark_step() # concatenate the new generation - # https://github.com/pytorch/pytorch/issues/101936 - idx[t] = idx_next.item() if idx.device.type == "mps" else idx_next + idx = idx.index_copy(0, input_pos, idx_next) # if token is triggered, return the output (stop generation) if idx_next == eos_id: - return idx[:t + 1] # include the EOS token + return idx[:input_pos] # include the EOS token return idx @@ -138,16 +144,10 @@ def main( L.seed_everything(1234) for i in range(num_samples): t0 = time.perf_counter() - y = generate( - model, - encoded, - max_new_tokens, - model.config.block_size, # type: ignore[union-attr,arg-type] - temperature=temperature, - top_k=top_k, - ) + y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k) t = time.perf_counter() - t0 + model.reset_cache() print(tokenizer.decode(y)) tokens_generated = y.size(0) - prompt_length print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr) diff --git a/generate/adapter.py b/generate/adapter.py index 355852a4..e0a9a5f4 100644 --- a/generate/adapter.py +++ b/generate/adapter.py @@ -82,17 +82,10 @@ def main( prompt_length = encoded.size(0) t0 = time.perf_counter() - y = generate( - model, - idx=encoded, - max_seq_length=max_new_tokens, - max_new_tokens=max_new_tokens, - temperature=temperature, - top_k=top_k, - eos_id=tokenizer.eos_id - ) + y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k, 
eos_id=tokenizer.eos_id) t = time.perf_counter() - t0 + model.reset_cache() output = tokenizer.decode(y) output = output.split("### Response:")[1].strip() print(output) diff --git a/generate/adapter_v2.py b/generate/adapter_v2.py index e307bab0..95ae38e3 100644 --- a/generate/adapter_v2.py +++ b/generate/adapter_v2.py @@ -6,7 +6,6 @@ import lightning as L import torch -import torch.nn as nn # support running without installing as a package wd = Path(__file__).parent.parent.resolve() @@ -85,17 +84,10 @@ def main( prompt_length = encoded.size(0) t0 = time.perf_counter() - y = generate( - model, - idx=encoded, - max_seq_length=max_new_tokens, - max_new_tokens=max_new_tokens, - temperature=temperature, - top_k=top_k, - eos_id=tokenizer.eos_id - ) + y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id) t = time.perf_counter() - t0 + model.reset_cache() output = tokenizer.decode(y) output = output.split("### Response:")[1].strip() print(output) diff --git a/generate/full.py b/generate/full.py index 3c1c252e..fda53b14 100644 --- a/generate/full.py +++ b/generate/full.py @@ -7,68 +7,14 @@ import lightning as L import torch +# support running without installing as a package +wd = Path(__file__).absolute().parent.parent +sys.path.append(str(wd)) + from lit_llama import LLaMA, Tokenizer from lit_llama.utils import EmptyInitOnDevice from scripts.prepare_alpaca import generate_prompt - -@torch.no_grad() -def generate( - model: torch.nn.Module, - idx: torch.Tensor, - max_new_tokens: int, - max_seq_length: int, - temperature: float = 1.0, - top_k: Optional[int] = None, - eos_id: Optional[int] = None, -) -> torch.Tensor: - """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. - - The implementation of this function is modified from A. Karpathy's nanoGPT. - - Args: - model: The model to use. - idx: Tensor of shape (T) with indices of the prompt sequence. - max_new_tokens: The number of new tokens to generate. - max_seq_length: The maximum sequence length allowed. 
- temperature: Scales the predicted logits by 1 / temperature - top_k: If specified, only sample among the tokens with the k highest probabilities - eos_id: If specified, stop generating any more token once the token is triggered - """ - # create an empty tensor of the expected final shape and fill in the current tokens - T = idx.size(0) - T_new = T + max_new_tokens - empty = torch.empty(T_new, dtype=idx.dtype, device=idx.device) - empty[:T] = idx - idx = empty - - # generate max_new_tokens tokens - for t in range(T, T_new): - # ignore the not-filled-yet tokens - idx_cond = idx[:t] - # if the sequence context is growing too long we must crop it at max_seq_length - idx_cond = idx_cond if t <= max_seq_length else idx_cond[-max_seq_length:] - - # forward - logits = model(idx_cond.view(1, -1)) - logits = logits[0, -1] / temperature - - # optionally crop the logits to only the top k options - if top_k is not None: - v, _ = torch.topk(logits, min(top_k, logits.size(-1))) - logits[logits < v[[-1]]] = -float("Inf") - - probs = torch.nn.functional.softmax(logits, dim=-1) - idx_next = torch.multinomial(probs, num_samples=1) - - # concatenate the new generation - # https://github.com/pytorch/pytorch/issues/101936 - idx[t] = idx_next.item() if idx.device.type == "mps" else idx_next - - # if token is triggered, return the output (stop generation) - if idx_next == eos_id: - return idx[:t + 1] # include the EOS token - - return idx +from generate import generate def main( @@ -130,16 +76,10 @@ def main( L.seed_everything(1234) for i in range(num_samples): t0 = time.perf_counter() - y = generate( - model, - encoded, - max_new_tokens, - model.config.block_size, # type: ignore[union-attr,arg-type] - temperature=temperature, - top_k=top_k, - ) + y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k) t = time.perf_counter() - t0 + model.reset_cache() print(tokenizer.decode(y)) tokens_generated = y.size(0) - prompt_length print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr) diff --git a/generate/lora.py b/generate/lora.py index d9f2e3df..5c9b44a4 100644 --- a/generate/lora.py +++ b/generate/lora.py @@ -98,7 +98,6 @@ def main( output = generate( model, idx=encoded, - max_seq_length=max_new_tokens, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k, diff --git a/lit_llama/adapter.py b/lit_llama/adapter.py index f743c194..7a7f95d3 100644 --- a/lit_llama/adapter.py +++ b/lit_llama/adapter.py @@ -4,14 +4,15 @@ https://arxiv.org/abs/2303.16199 """ # mypy: ignore-errors -import math from dataclasses import dataclass +from typing import Optional, Tuple, List, Union import torch import torch.nn as nn from torch.nn import functional as F + import lit_llama.model as llama -from lit_llama.model import build_rope_cache, apply_rope, RMSNorm, MLP +from lit_llama.model import build_rope_cache, apply_rope, RMSNorm, MLP, KVCache, RoPECache @dataclass @@ -32,7 +33,7 @@ def __init__(self, config: LLaMAConfig, block_idx: int) -> None: self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False) # output projection self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) - + if block_idx >= config.adapter_start_layer: # adapter embedding layer self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd) @@ -45,47 +46,59 @@ def __init__(self, config: LLaMAConfig, block_idx: int) -> None: self.block_idx = block_idx self.adapter_prompt_length = config.adapter_prompt_length self.adapter_start_layer 
= config.adapter_start_layer - self.rope_cache = None - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward( + self, + x: torch.Tensor, + rope: RoPECache, + mask: torch.Tensor, + max_seq_length: int, + input_pos: Optional[torch.Tensor] = None, + kv_cache: Optional[KVCache] = None, + adapter_kv_cache: Optional[KVCache] = None, + ) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]: B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) # calculate query, key, values for all heads in batch and move head forward to be the batch dim q, k, v = self.c_attn(x).split(self.n_embd, dim=2) head_size = C // self.n_head - k = k.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs) - q = q.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs) - v = v.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs) - - if self.rope_cache is None: - # cache for future forward calls - self.rope_cache = build_rope_cache( - seq_len=self.block_size, - n_elem=self.n_embd // self.n_head, - dtype=x.dtype, - device=x.device, - ) - - q = apply_rope(q, self.rope_cache) - k = apply_rope(k, self.rope_cache) - - # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) - # att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - # att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) - # att = F.softmax(att, dim=-1) - # y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + k = k.view(B, T, self.n_head, head_size) + q = q.view(B, T, self.n_head, head_size) + v = v.view(B, T, self.n_head, head_size) + + q = apply_rope(q, rope) + k = apply_rope(k, rope) + + k = k.transpose(1, 2) # (B, nh, T, hs) + q = q.transpose(1, 2) # (B, nh, T, hs) + v = v.transpose(1, 2) # (B, nh, T, hs) + + if kv_cache is not None: + cache_k, cache_v = kv_cache + # check if reached token limit + if input_pos[-1] >= max_seq_length: + input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device) + # shift 1 position to the left + cache_k = torch.roll(cache_k, -1, dims=2) + cache_v = torch.roll(cache_v, -1, dims=2) + k = cache_k.index_copy(2, input_pos, k) + v = cache_v.index_copy(2, input_pos, v) + kv_cache = k, v # efficient attention using Flash Attention CUDA kernels - y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) if self.block_idx >= self.adapter_start_layer: - prefix = self.adapter_wte.weight.reshape(1, self.adapter_prompt_length, self.n_embd) - - aT = prefix.size(1) - _, ak, av = self.c_attn(prefix).split(self.n_embd, dim=2) - ak = ak.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2) - av = av.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2) + if adapter_kv_cache is not None: + ak, av = adapter_kv_cache + else: + prefix = self.adapter_wte.weight.reshape(1, self.adapter_prompt_length, self.n_embd) + aT = prefix.size(1) + _, ak, av = self.c_attn(prefix).split(self.n_embd, dim=2) + ak = ak.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2) + av = av.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2) + adapter_kv_cache = (ak, av) amask = torch.ones(q.shape[-2], ak.shape[-2], dtype=torch.bool, device=x.device) ay = F.scaled_dot_product_attention(q, ak, av, attn_mask=amask, dropout_p=0.0, is_causal=False) @@ -96,7 +109,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # output projection y = 
self.c_proj(y) - return y + return y, kv_cache, adapter_kv_cache class Block(nn.Module): @@ -110,10 +123,22 @@ def __init__(self, config: LLaMAConfig, block_idx: int) -> None: self.rms_2 = RMSNorm(config.n_embd) self.mlp = MLP(config) - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x + self.attn(self.rms_1(x)) + def forward( + self, + x: torch.Tensor, + rope: RoPECache, + mask: torch.Tensor, + max_seq_length: int, + input_pos: Optional[torch.Tensor] = None, + kv_cache: Optional[KVCache] = None, + adapter_kv_cache: Optional[KVCache] = None, + ) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]: + h, new_kv_cache, new_adapter_kv_cache = self.attn( + self.rms_1(x), rope, mask, max_seq_length, input_pos, kv_cache, adapter_kv_cache + ) + x = x + h x = x + self.mlp(self.rms_2(x)) - return x + return x, new_kv_cache, new_adapter_kv_cache class LLaMA(llama.LLaMA): @@ -130,15 +155,76 @@ def __init__(self, config: LLaMAConfig) -> None: self.transformer = nn.ModuleDict( dict( wte=nn.Embedding(config.vocab_size, config.n_embd), - h=nn.ModuleList([Block(config, i) for i in range(config.n_layer)]), + h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)), ln_f=RMSNorm(config.n_embd), ) ) + self.rope_cache: Optional[RoPECache] = None + self.mask_cache: Optional[torch.Tensor] = None + self.kv_caches: List[KVCache] = [] + self.adapter_kv_caches: List[KVCache] = [] + @classmethod def from_name(cls, name: str): return cls(LLaMAConfig.from_name(name)) + def reset_cache(self) -> None: + self.kv_caches.clear() + self.adapter_kv_caches.clear() + + def forward( + self, idx: torch.Tensor, max_seq_length: Optional[int] = None, input_pos: Optional[torch.Tensor] = None + ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[KVCache]]]: + B, T = idx.size() + + block_size = self.config.block_size + if max_seq_length is None: + max_seq_length = block_size + assert T <= max_seq_length, f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}" + assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}" + assert T <= block_size, f"Cannot forward sequence of length {T}, block size is only {block_size}" + + if self.rope_cache is None: + self.rope_cache = self.build_rope_cache(idx) + if self.mask_cache is None: + self.mask_cache = self.build_mask_cache(idx) + + if input_pos is not None: + rope = self.rope_cache.index_select(0, input_pos) + mask = self.mask_cache.index_select(2, input_pos) + mask = mask[:, :, :, :max_seq_length] + else: + rope = self.rope_cache[:T] + mask = self.mask_cache[:, :, :T, :T] + + # forward the model itself + x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + + if input_pos is None: # proxy for use_cache=False + for block in self.transformer.h: + x, *_ = block(x, rope, mask, max_seq_length) + else: + if not self.kv_caches: + head_size = self.config.n_embd // self.config.n_head + cache_shape = (B, self.config.n_head, max_seq_length, head_size) + self.kv_caches = [ + (torch.zeros(cache_shape, device=x.device, dtype=x.dtype), torch.zeros(cache_shape, device=x.device, dtype=x.dtype)) + for _ in range(self.config.n_layer) + ] + if not self.adapter_kv_caches: + self.adapter_kv_caches = [None for _ in range(self.config.n_layer)] + for i, block in enumerate(self.transformer.h): + x, self.kv_caches[i], self.adapter_kv_caches[i] = block( + x, rope, mask, max_seq_length, input_pos, self.kv_caches[i], self.adapter_kv_caches[i] + ) + + x = self.transformer.ln_f(x) + + logits = 
self.lm_head(x) # (b, t, vocab_size) + + return logits + def mark_only_adapter_as_trainable(model: LLaMA) -> None: """Sets `requires_grad=False` for all non-adapter weights.""" diff --git a/lit_llama/model.py b/lit_llama/model.py index 9ec2433b..1cf5ff04 100644 --- a/lit_llama/model.py +++ b/lit_llama/model.py @@ -5,7 +5,7 @@ # mypy: ignore-errors import math from dataclasses import dataclass -from typing import Optional +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -15,6 +15,11 @@ from lit_llama.utils import find_multiple +MaskCache = torch.Tensor +RoPECache = torch.Tensor +KVCache = Tuple[torch.Tensor, torch.Tensor] + + @dataclass class LLaMAConfig: block_size: int = 2048 @@ -51,28 +56,63 @@ def __init__(self, config: LLaMAConfig) -> None: self.transformer = nn.ModuleDict( dict( wte=nn.Embedding(config.padded_vocab_size, config.n_embd), - h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]), + h=nn.ModuleList(Block(config) for _ in range(config.n_layer)), ln_f=RMSNorm(config.n_embd), ) ) + self.rope_cache: Optional[RoPECache] = None + self.mask_cache: Optional[MaskCache] = None + self.kv_caches: List[KVCache] = [] + def _init_weights(self, module: nn.Module) -> None: if isinstance(module, nn.Linear): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer)) elif isinstance(module, nn.Embedding): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer)) - def forward(self, idx: torch.Tensor) -> torch.Tensor: - _, t = idx.size() - assert ( - t <= self.config.block_size - ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" + def forward( + self, idx: torch.Tensor, max_seq_length: Optional[int] = None, input_pos: Optional[torch.Tensor] = None + ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[KVCache]]]: + B, T = idx.size() - # forward the LLaMA model itself + block_size = self.config.block_size + if max_seq_length is None: + max_seq_length = block_size + assert T <= max_seq_length, f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}" + assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}" + assert T <= block_size, f"Cannot forward sequence of length {T}, block size is only {block_size}" + + if self.rope_cache is None: + self.rope_cache = self.build_rope_cache(idx) + if self.mask_cache is None: + self.mask_cache = self.build_mask_cache(idx) + + if input_pos is not None: + rope = self.rope_cache.index_select(0, input_pos) + mask = self.mask_cache.index_select(2, input_pos) + mask = mask[:, :, :, :max_seq_length] + else: + rope = self.rope_cache[:T] + mask = self.mask_cache[:, :, :T, :T] + + # forward the model itself x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) - for block in self.transformer.h: - x = block(x) + if input_pos is None: # proxy for use_cache=False + for block in self.transformer.h: + x, _ = block(x, rope, mask, max_seq_length) + else: + if not self.kv_caches: + head_size = self.config.n_embd // self.config.n_head + cache_shape = (B, self.config.n_head, max_seq_length, head_size) + self.kv_caches = [ + (torch.zeros(cache_shape, device=x.device, dtype=x.dtype), torch.zeros(cache_shape, device=x.device, dtype=x.dtype)) + for _ in range(self.config.n_layer) + ] + for i, block in enumerate(self.transformer.h): + x, self.kv_caches[i] = block(x, rope, mask, max_seq_length, input_pos, self.kv_caches[i]) + x = 
self.transformer.ln_f(x) logits = self.lm_head(x) # (b, t, vocab_size) @@ -83,6 +123,21 @@ def forward(self, idx: torch.Tensor) -> torch.Tensor: def from_name(cls, name: str) -> Self: return cls(LLaMAConfig.from_name(name)) + def build_rope_cache(self, idx: torch.Tensor) -> RoPECache: + return build_rope_cache( + seq_len=self.config.block_size, + n_elem=self.config.n_embd // self.config.n_head, + dtype=idx.dtype, + device=idx.device, + ) + + def build_mask_cache(self, idx: torch.Tensor) -> MaskCache: + ones = torch.ones((self.config.block_size, self.config.block_size), device=idx.device, dtype=torch.bool) + return torch.tril(ones).unsqueeze(0).unsqueeze(0) + + def reset_cache(self) -> None: + self.kv_caches.clear() + class Block(nn.Module): def __init__(self, config: LLaMAConfig) -> None: @@ -92,10 +147,19 @@ def __init__(self, config: LLaMAConfig) -> None: self.rms_2 = RMSNorm(config.n_embd) self.mlp = MLP(config) - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x + self.attn(self.rms_1(x)) + def forward( + self, + x: torch.Tensor, + rope: RoPECache, + mask: MaskCache, + max_seq_length: int, + input_pos: Optional[torch.Tensor] = None, + kv_cache: Optional[KVCache] = None, + ) -> Tuple[torch.Tensor, Optional[KVCache]]: + h, new_kv_cache = self.attn(self.rms_1(x), rope, mask, max_seq_length, input_pos, kv_cache) + x = x + h x = x + self.mlp(self.rms_2(x)) - return x + return x, new_kv_cache class CausalSelfAttention(nn.Module): @@ -111,30 +175,44 @@ def __init__(self, config: LLaMAConfig) -> None: self.n_head = config.n_head self.n_embd = config.n_embd self.block_size = config.block_size - self.rope_cache: Optional[torch.Tensor] = None - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward( + self, + x: torch.Tensor, + rope: RoPECache, + mask: MaskCache, + max_seq_length: int, + input_pos: Optional[torch.Tensor] = None, + kv_cache: Optional[KVCache] = None, + ) -> Tuple[torch.Tensor, Optional[KVCache]]: B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) # calculate query, key, values for all heads in batch and move head forward to be the batch dim q, k, v = self.c_attn(x).split(self.n_embd, dim=2) head_size = C // self.n_head - k = k.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs) - q = q.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs) - v = v.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs) - - if self.rope_cache is None: - # cache for future forward calls - self.rope_cache = build_rope_cache( - seq_len=self.block_size, - n_elem=self.n_embd // self.n_head, - dtype=x.dtype, - device=x.device, - ) - - q = apply_rope(q, self.rope_cache) - k = apply_rope(k, self.rope_cache) + k = k.view(B, T, self.n_head, head_size) + q = q.view(B, T, self.n_head, head_size) + v = v.view(B, T, self.n_head, head_size) + + q = apply_rope(q, rope) + k = apply_rope(k, rope) + + k = k.transpose(1, 2) # (B, nh, T, hs) + q = q.transpose(1, 2) # (B, nh, T, hs) + v = v.transpose(1, 2) # (B, nh, T, hs) + + if kv_cache is not None: + cache_k, cache_v = kv_cache + # check if reached token limit + if input_pos[-1] >= max_seq_length: + input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device) + # shift 1 position to the left + cache_k = torch.roll(cache_k, -1, dims=2) + cache_v = torch.roll(cache_v, -1, dims=2) + k = cache_k.index_copy(2, input_pos, k) + v = cache_v.index_copy(2, input_pos, v) + kv_cache = k, v # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) # 
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) @@ -143,14 +221,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) # efficient attention using Flash Attention CUDA kernels - y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side # output projection y = self.c_proj(y) - return y + return y, kv_cache class MLP(nn.Module): @@ -193,7 +271,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.scale * x_normed -def build_rope_cache(seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000) -> torch.Tensor: +def build_rope_cache( + seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000 +) -> RoPECache: """Enhanced Transformer with Rotary Position Embedding. Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ @@ -217,9 +297,7 @@ def build_rope_cache(seq_len: int, n_elem: int, dtype: torch.dtype, device: torc return cache -def apply_rope(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: - x = x.transpose(1, 2) - +def apply_rope(x: torch.Tensor, rope_cache: RoPECache) -> torch.Tensor: # truncate to support variable sizes T = x.size(1) rope_cache = rope_cache[:T] @@ -228,9 +306,12 @@ def apply_rope(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: xshaped = x.float().reshape(*x.shape[:-1], -1, 2) rope_cache = rope_cache.view(1, xshaped.size(1), 1, xshaped.size(3), 2) x_out2 = torch.stack( - [xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], - xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], - ], -1) + [ + xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) x_out2 = x_out2.flatten(3) - return x_out2.transpose(1, 2).type_as(x) + return x_out2.type_as(x) diff --git a/tests/test_generate.py b/tests/test_generate.py index 5acb7a00..4f7f307a 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -101,7 +101,7 @@ def device(self): tokenizer_mock.assert_called_once_with(tokenizer_path) assert len(tokenizer_mock.return_value.decode.mock_calls) == num_samples assert torch.allclose(tokenizer_mock.return_value.decode.call_args[0][0], generate_mock.return_value) - assert generate_mock.mock_calls == [call(ANY, ANY, 50, ANY, temperature=2.0, top_k=2)] * num_samples + assert generate_mock.mock_calls == [call(ANY, ANY, 50, temperature=2.0, top_k=2)] * num_samples # only the generated result is printed to stdout assert out.getvalue() == "foo bar baz\n" * num_samples diff --git a/tests/test_model.py b/tests/test_model.py index 9d85e368..3abc4843 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -33,25 +33,30 @@ def copy_weights(llama_model, orig_llama_model) -> None: @torch.no_grad() -def test_to_orig_llama(lit_llama, orig_llama) -> None: +@pytest.mark.parametrize("kv_cache", (False, True)) +def test_to_orig_llama(lit_llama, orig_llama, kv_cache) -> None: block_size = 64 vocab_size = 32000 n_layer = 16 n_head = 16 n_embd = 32 + batch_size = 3 llama_config = lit_llama.LLaMAConfig( block_size=block_size, vocab_size=vocab_size, n_layer=n_layer, n_head=n_head, n_embd=n_embd ) 
orig_llama_config = orig_llama.ModelArgs( - dim=n_embd, n_layers=n_layer, n_heads=n_head, vocab_size=vocab_size, norm_eps=1e-5, max_seq_len=block_size + dim=n_embd, + n_layers=n_layer, + n_heads=n_head, + vocab_size=vocab_size, + norm_eps=1e-5, + max_seq_len=block_size, + max_batch_size=batch_size, ) - batch_size = 3 - - token_sample = torch.randint( - 0, orig_llama_config.vocab_size, size=(batch_size, orig_llama_config.max_seq_len), dtype=torch.int64 - ) + seq_len = orig_llama_config.max_seq_len + token_sample = torch.randint(0, orig_llama_config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64) llama_model = lit_llama.LLaMA(llama_config) llama_model.apply(llama_model._init_weights) @@ -63,11 +68,31 @@ def test_to_orig_llama(lit_llama, orig_llama) -> None: llama_embed = llama_model.transformer.wte(token_sample) assert torch.allclose(orig_llama_embed, llama_embed) - seq_len = token_sample.shape[1] - mask = torch.full((1, 1, seq_len, seq_len), float("-inf")) - mask = torch.triu(mask, diagonal=1) - orig_llama_block_out = orig_llama_model.layers[0](orig_llama_embed, 0, orig_llama_model.freqs_cis[:seq_len], mask) - llama_block_out = llama_model.transformer.h[0](llama_embed) + llama_rope = llama_model.build_rope_cache(token_sample) + llama_mask = llama_model.build_mask_cache(token_sample) + orig_llama_mask = torch.full((1, 1, seq_len, seq_len), float("-inf")) + orig_llama_mask = torch.triu(orig_llama_mask, diagonal=1) + if kv_cache: + orig_llama_block_out = orig_llama_model.layers[0]( + orig_llama_embed, 0, orig_llama_model.freqs_cis[:seq_len], orig_llama_mask + ) + theirs_k_cache = orig_llama_model.layers[0].attention.cache_k + theirs_v_cache = orig_llama_model.layers[0].attention.cache_v + head_size = n_embd // n_head + kv_cache_shape = (batch_size, n_head, block_size, head_size) + ours_kv_cache = torch.zeros(kv_cache_shape), torch.zeros(kv_cache_shape) + (llama_block_out, ours_kv_cache) = llama_model.transformer.h[0]( + llama_embed, llama_rope, llama_mask, seq_len, torch.arange(block_size), ours_kv_cache + ) + ours_k_cache = ours_kv_cache[0].permute(0, 2, 1, 3) + ours_v_cache = ours_kv_cache[1].permute(0, 2, 1, 3) + torch.testing.assert_close(ours_k_cache, theirs_k_cache) + torch.testing.assert_close(ours_v_cache, theirs_v_cache) + else: + orig_llama_block_out = orig_llama_model.layers[0]( + orig_llama_embed, 0, orig_llama_model.freqs_cis[:seq_len], orig_llama_mask + ) + (llama_block_out, _) = llama_model.transformer.h[0](llama_embed, llama_rope, llama_mask, seq_len) assert torch.allclose(orig_llama_block_out, llama_block_out) expected = orig_llama_model(token_sample, 0) @@ -79,6 +104,7 @@ def test_to_orig_llama(lit_llama, orig_llama) -> None: @torch.no_grad() def test_bfloat16_llama_init(lit_llama, orig_llama) -> None: from lit_llama.utils import EmptyInitOnDevice + block_size = 64 vocab_size = 32000 n_layer = 16 @@ -86,20 +112,14 @@ def test_bfloat16_llama_init(lit_llama, orig_llama) -> None: n_embd = 32 llama_config = lit_llama.LLaMAConfig( - block_size=block_size, - vocab_size=vocab_size, - n_layer=n_layer, - n_head=n_head, - n_embd=n_embd, + block_size=block_size, vocab_size=vocab_size, n_layer=n_layer, n_head=n_head, n_embd=n_embd ) llama_model = lit_llama.LLaMA(llama_config) llama_model.apply(llama_model._init_weights) batch_size = 3 - token_sample = torch.randint( - 0, vocab_size, size=(batch_size, block_size), dtype=torch.int64 - ) + token_sample = torch.randint(0, vocab_size, size=(batch_size, block_size), dtype=torch.int64) expected = llama_model(token_sample) @@ -140,8 
+160,9 @@ def enable_gate(model): def test_adapter_parity(orig_llama_adapter): """Test parity between our implementation of LLaMA-Adapter and the reference code.""" import lit_llama.adapter as lit_llama + orig_llama = orig_llama_adapter - + block_size = 32 vocab_size = 100 n_layer = 2 @@ -151,12 +172,23 @@ def test_adapter_parity(orig_llama_adapter): adapter_start_layer: int = 0 llama_config = lit_llama.LLaMAConfig( - block_size=block_size, vocab_size=vocab_size, n_layer=n_layer, n_head=n_head, n_embd=n_embd, - adapter_prompt_length=adapter_prompt_length, adapter_start_layer=adapter_start_layer, + block_size=block_size, + vocab_size=vocab_size, + n_layer=n_layer, + n_head=n_head, + n_embd=n_embd, + adapter_prompt_length=adapter_prompt_length, + adapter_start_layer=adapter_start_layer, ) orig_llama_config = orig_llama.ModelArgs( - dim=n_embd, n_layers=n_layer, n_heads=n_head, vocab_size=vocab_size, norm_eps=1e-5, max_seq_len=block_size, - adapter_len=adapter_prompt_length, adapter_layer=(n_layer - adapter_start_layer), + dim=n_embd, + n_layers=n_layer, + n_heads=n_head, + vocab_size=vocab_size, + norm_eps=1e-5, + max_seq_len=block_size, + adapter_len=adapter_prompt_length, + adapter_layer=(n_layer - adapter_start_layer), ) batch_size = 3 @@ -183,13 +215,7 @@ def test_adapter_parity(orig_llama_adapter): @pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="torch.compile not supported on this platform") def test_model_compile(lit_llama): - llama_config = lit_llama.LLaMAConfig( - block_size=8, - vocab_size=8, - n_layer=2, - n_head=2, - n_embd=4, - ) + llama_config = lit_llama.LLaMAConfig(block_size=8, vocab_size=8, n_layer=2, n_head=2, n_embd=4) model = lit_llama.LLaMA(llama_config) model.apply(model._init_weights) diff --git a/tests/test_rope.py b/tests/test_rope.py index 8a42dcab..37e993ab 100644 --- a/tests/test_rope.py +++ b/tests/test_rope.py @@ -1,5 +1,6 @@ import torch + @torch.no_grad() def test_rope(lit_llama, orig_llama) -> None: torch.manual_seed(1) @@ -11,6 +12,6 @@ def test_rope(lit_llama, orig_llama) -> None: llama_rope_cache = lit_llama.build_rope_cache(seq_len, n_embed // n_head, dtype=x.dtype, device=x.device) torch.testing.assert_close(freqs_cis, torch.view_as_complex(llama_rope_cache)) - llama_x_rope = lit_llama.apply_rope(x.transpose(1, 2), llama_rope_cache).transpose(1, 2) + llama_x_rope = lit_llama.apply_rope(x, llama_rope_cache) orig_llama_x_rope, _ = orig_llama.apply_rotary_emb(x, x, freqs_cis) torch.testing.assert_close(llama_x_rope, orig_llama_x_rope)
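
A few standalone sketches of the mechanisms this patch introduces follow; any names not present in the diff are illustrative assumptions, not part of the patch.

How the reworked `generate()` drives the cached forward pass: the prompt is prefilled in one call with `input_pos = arange(T)`, every later step feeds only the newest token plus its single position, and `reset_cache()` clears the per-layer buffers before the next prompt. The sketch below is a simplified greedy variant; `run_decode_sketch` and its arguments are hypothetical, and the model is assumed to expose the `forward(idx, max_seq_length, input_pos)` signature added here.

```python
import torch


@torch.no_grad()
def run_decode_sketch(model, prompt: torch.Tensor, max_new_tokens: int, max_seq_length: int) -> torch.Tensor:
    # prompt: 1D tensor of token ids, shape (T,)
    device, dtype = prompt.device, prompt.dtype
    T = prompt.size(0)
    tokens = torch.empty(T + max_new_tokens, dtype=dtype, device=device)
    tokens[:T] = prompt

    # prefill: one call over the whole prompt populates the per-layer KV caches
    input_pos = torch.arange(0, T, device=device)
    for _ in range(max_new_tokens):
        x = tokens.index_select(0, input_pos).view(1, -1)
        logits = model(x, max_seq_length, input_pos)  # cached forward
        next_token = logits[0, -1].argmax(dim=-1, keepdim=True).to(dtype=dtype)

        # decode: from here on only the newest token and its position are fed
        input_pos = input_pos[-1:] + 1
        tokens = tokens.index_copy(0, input_pos, next_token)

    model.reset_cache()  # drop stale KV buffers before generating from another prompt
    return tokens
```

In the patch itself `max_seq_length` defaults to `min(T + max_new_tokens, model.config.block_size)`; if it is smaller than the final length, the roll-based cache update shown next turns the cache into a sliding window.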
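
The per-layer cache update in `CausalSelfAttention.forward` can be exercised on its own: keys and values for the current positions are written into pre-allocated `(B, n_head, max_seq_length, head_size)` buffers with `index_copy`, and once `input_pos` reaches the end of the buffer the oldest entry is dropped by rolling the cache one slot to the left. A minimal sketch with arbitrary shapes; the helper name `update_kv_cache` is an assumption (the patch inlines this logic).

```python
import torch

B, n_head, max_seq_length, head_size = 1, 4, 8, 16


def update_kv_cache(cache_k, cache_v, k, v, input_pos):
    # k, v: (B, n_head, T, head_size) for the tokens at positions `input_pos`
    if input_pos[-1] >= max_seq_length:
        # buffer full: slide the window left and write into the last slot
        input_pos = torch.tensor([max_seq_length - 1], device=input_pos.device)
        cache_k = torch.roll(cache_k, -1, dims=2)
        cache_v = torch.roll(cache_v, -1, dims=2)
    k = cache_k.index_copy(2, input_pos, k)
    v = cache_v.index_copy(2, input_pos, v)
    # the updated buffers double as the full K/V that attention sees
    return k, v


cache_k = torch.zeros(B, n_head, max_seq_length, head_size)
cache_v = torch.zeros(B, n_head, max_seq_length, head_size)

# write the k/v of a single decoded token at position 3
k_new = torch.randn(B, n_head, 1, head_size)
v_new = torch.randn(B, n_head, 1, head_size)
cache_k, cache_v = update_kv_cache(cache_k, cache_v, k_new, v_new, torch.tensor([3]))
```

Attention always runs over the full buffer, so it is the boolean mask (next sketch) rather than `is_causal` that hides unwritten and future positions.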
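
With single-token decoding, `is_causal=True` no longer does the right thing (a lone query would be allowed to see only itself), which is why the patch precomputes a lower-triangular boolean mask of shape `(1, 1, block_size, block_size)` and slices out the rows for the current `input_pos`. A small sketch of that selection:

```python
import torch

block_size, max_seq_length = 8, 8

# same construction as LLaMA.build_mask_cache: True means "may attend"
ones = torch.ones(block_size, block_size, dtype=torch.bool)
mask_cache = torch.tril(ones).unsqueeze(0).unsqueeze(0)  # (1, 1, block_size, block_size)

# decoding the token at absolute position 5
input_pos = torch.tensor([5])
mask = mask_cache.index_select(2, input_pos)  # (1, 1, 1, block_size)
mask = mask[:, :, :, :max_seq_length]         # crop to the KV buffer length
print(mask.int())  # positions 0..5 visible, 6..7 (not written yet) masked out
```

This mask is what gets handed to `F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)`, replacing the previous `is_causal=True` path.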
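
The top-k filter now builds a new logits tensor with `torch.where` instead of writing `-inf` in place, keeping the sampling step free of in-place mutation. A tiny sketch with made-up logits:

```python
import torch

logits = torch.tensor([2.0, 0.5, 1.5, -1.0, 0.0])
top_k = 2

v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
# keep logits >= the k-th largest value, push everything else to -inf
filtered = torch.where(logits < v[[-1]], torch.tensor(float("-inf")), logits)
probs = torch.softmax(filtered, dim=-1)
print(filtered)  # tensor([2.0, -inf, 1.5, -inf, -inf])
print(probs)     # only the two largest logits keep non-zero probability
```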
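
`apply_rope` no longer transposes internally: it now expects `(B, T, n_head, head_size)` (it is applied before the heads are moved to dim 1), and during decoding the cache row for the current position is picked with `index_select`. A consistency check, assuming the post-patch `lit_llama` package is importable:

```python
import torch
from lit_llama.model import apply_rope, build_rope_cache

B, T, n_head, head_size = 1, 6, 4, 8
x = torch.randn(B, T, n_head, head_size)

rope = build_rope_cache(seq_len=16, n_elem=head_size, dtype=torch.float32, device=x.device)

# full-sequence path: the first T cache rows
y_full = apply_rope(x, rope[:T])

# decode path: a single token at absolute position 3 uses just that cache row
pos = torch.tensor([3])
y_single = apply_rope(x[:, 3:4], rope.index_select(0, pos))

torch.testing.assert_close(y_single, y_full[:, 3:4])
```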
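
In the adapter attention, the learned prefix is input-independent, so its keys and values are now computed once and kept in `adapter_kv_cache` instead of being re-projected on every step. A standalone sketch of that precomputation; the bare `nn.Embedding`/`nn.Linear` pair stands in for the block's `adapter_wte` and `c_attn`, and the gating and output projection that follow in the real block are omitted.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

B, n_head, head_size, adapter_len = 2, 4, 8, 3
n_embd = n_head * head_size

adapter_wte = nn.Embedding(adapter_len, n_embd)      # learned adapter prompt
c_attn = nn.Linear(n_embd, 3 * n_embd, bias=False)   # shared qkv projection

# computed once, reused for every subsequent forward call
prefix = adapter_wte.weight.reshape(1, adapter_len, n_embd)
_, ak, av = c_attn(prefix).split(n_embd, dim=2)
ak = ak.view(1, adapter_len, n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)
av = av.view(1, adapter_len, n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)
adapter_kv_cache = (ak, av)

# a single decode-step query attends to every prefix position (no causal mask here)
q = torch.randn(B, n_head, 1, head_size)
amask = torch.ones(q.shape[-2], ak.shape[-2], dtype=torch.bool)
ay = F.scaled_dot_product_attention(q, ak, av, attn_mask=amask, dropout_p=0.0, is_causal=False)
print(ay.shape)  # torch.Size([2, 4, 1, 8])
```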