|
13 | 13 | # limitations under the License. |
14 | 14 | import json |
15 | 15 | import pathlib |
| 16 | +from dataclasses import dataclass |
| 17 | +from pathlib import Path |
| 18 | +from typing import Optional |
| 19 | +from typing import Tuple |
16 | 20 |
|
17 | 21 | import torch |
| 22 | +from torch import nn |
18 | 23 |
|
19 | 24 | from keras_nlp.models import MistralBackbone |
20 | 25 |
|
21 | | -from .scripts.mistral_torch import ModelArgs |
22 | | -from .scripts.mistral_torch import Transformer as TorchTransformer |
23 | | - |
24 | 26 | MODEL_PATH = pathlib.Path("mistral-7B-v0.1") |
25 | 27 |
|
| 28 | +# Torch model taken from: |
| 29 | +# https://github.com/mistralai/mistral-src/blob/147c4e68279b90eb61b19bdea44e16f5539d5a5d/one_file_ref.py |
| 30 | + |
| 31 | + |
| 32 | +@dataclass |
| 33 | +class ModelArgs: |
| 34 | + dim: int |
| 35 | + n_layers: int |
| 36 | + head_dim: int |
| 37 | + hidden_dim: int |
| 38 | + n_heads: int |
| 39 | + n_kv_heads: int |
| 40 | + sliding_window: int |
| 41 | + norm_eps: float |
| 42 | + vocab_size: int |
| 43 | + |
| 44 | + max_batch_size: int = 0 |
| 45 | + |
| 46 | + |
| 47 | +def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int): |
| 48 | + keys = torch.repeat_interleave(keys, repeats=repeats, dim=2) |
| 49 | + values = torch.repeat_interleave(values, repeats=repeats, dim=2) |
| 50 | + return keys, values |
| 51 | + |
| 52 | + |
| 53 | +def _reshape_for_broadcast( |
| 54 | + freqs_cis: torch.Tensor, x: torch.Tensor |
| 55 | +) -> torch.Tensor: |
| 56 | + """ |
| 57 | + freqs_cis: complex - (seq_len, head_dim / 2) |
| 58 | + x: complex - (bsz, seq_len, head_dim / 2) |
| 59 | + """ |
| 60 | + ndim = x.ndim |
| 61 | + assert 1 < ndim |
| 62 | + assert freqs_cis.shape == (x.shape[1], x.shape[-1]), ( |
| 63 | + freqs_cis.shape, |
| 64 | + (x.shape[1], x.shape[-1]), |
| 65 | + ) |
| 66 | + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] |
| 67 | + return freqs_cis.view(*shape) |
| 68 | + |
| 69 | + |
| 70 | +def apply_rotary_emb( |
| 71 | + xq: torch.Tensor, |
| 72 | + xk: torch.Tensor, |
| 73 | + freqs_cis: torch.Tensor, |
| 74 | +) -> Tuple[torch.Tensor, torch.Tensor]: |
| 75 | + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) |
| 76 | + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) |
| 77 | + freqs_cis = _reshape_for_broadcast(freqs_cis, xq_) |
| 78 | + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) |
| 79 | + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) |
| 80 | + return xq_out.type_as(xq), xk_out.type_as(xk) |
| 81 | + |
| 82 | + |
| 83 | +class Attention(nn.Module): |
| 84 | + def __init__(self, args: ModelArgs): |
| 85 | + super().__init__() |
| 86 | + self.args = args |
| 87 | + |
| 88 | + self.n_heads: int = args.n_heads |
| 89 | + self.n_kv_heads: int = args.n_kv_heads |
| 90 | + |
| 91 | + self.repeats = self.n_heads // self.n_kv_heads |
| 92 | + self.sliding_window = self.args.sliding_window |
| 93 | + |
| 94 | + self.scale = self.args.head_dim**-0.5 |
| 95 | + |
| 96 | + self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False) |
| 97 | + self.wk = nn.Linear( |
| 98 | + args.dim, args.n_kv_heads * args.head_dim, bias=False |
| 99 | + ) |
| 100 | + self.wv = nn.Linear( |
| 101 | + args.dim, args.n_kv_heads * args.head_dim, bias=False |
| 102 | + ) |
| 103 | + self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False) |
| 104 | + self.cache_k = torch.empty( |
| 105 | + ( |
| 106 | + args.max_batch_size, |
| 107 | + args.sliding_window, |
| 108 | + self.n_kv_heads, |
| 109 | + self.args.head_dim, |
| 110 | + ), |
| 111 | + dtype=torch.float16, |
| 112 | + ) |
| 113 | + self.cache_v = torch.empty( |
| 114 | + ( |
| 115 | + args.max_batch_size, |
| 116 | + args.sliding_window, |
| 117 | + self.n_kv_heads, |
| 118 | + self.args.head_dim, |
| 119 | + ), |
| 120 | + dtype=torch.float16, |
| 121 | + ) |
| 122 | + |
| 123 | + def forward( |
| 124 | + self, |
| 125 | + x: torch.Tensor, |
| 126 | + freqs_cis: torch.Tensor, |
| 127 | + positions: torch.Tensor, |
| 128 | + mask: Optional[torch.Tensor], |
| 129 | + ) -> torch.Tensor: |
| 130 | + bsz, seqlen, _ = x.shape |
| 131 | + |
| 132 | + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) |
| 133 | + xq = xq.view(bsz, seqlen, self.n_heads, self.args.head_dim) |
| 134 | + xk = xk.view(bsz, seqlen, self.n_kv_heads, self.args.head_dim) |
| 135 | + xv = xv.view(bsz, seqlen, self.n_kv_heads, self.args.head_dim) |
| 136 | + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) |
| 137 | + |
| 138 | + # The cache is a rotating buffer |
| 139 | + scatter_pos = (positions[-self.sliding_window :] % self.sliding_window)[ |
| 140 | + None, :, None, None |
| 141 | + ] |
| 142 | + scatter_pos = scatter_pos.repeat( |
| 143 | + bsz, 1, self.n_kv_heads, self.args.head_dim |
| 144 | + ) |
| 145 | + self.cache_k[:bsz].scatter_( |
| 146 | + dim=1, |
| 147 | + index=scatter_pos, |
| 148 | + src=xk[:, -self.sliding_window :].to(self.cache_k.dtype), |
| 149 | + ) |
| 150 | + self.cache_v[:bsz].scatter_( |
| 151 | + dim=1, |
| 152 | + index=scatter_pos, |
| 153 | + src=xv[:, -self.sliding_window :].to(self.cache_v.dtype), |
| 154 | + ) |
| 155 | + |
| 156 | + if positions.shape[0] > 1: |
| 157 | + # prefill |
| 158 | + key, value = repeat_kv(xk, xv, self.repeats) |
| 159 | + else: |
| 160 | + cur_pos = positions[-1].item() + 1 |
| 161 | + key, value = repeat_kv( |
| 162 | + self.cache_k[:bsz, :cur_pos, ...].to(xk.dtype), |
| 163 | + self.cache_v[:bsz, :cur_pos, ...].to(xv.dtype), |
| 164 | + self.repeats, |
| 165 | + ) |
| 166 | + |
| 167 | + query = xq.transpose(1, 2) |
| 168 | + key = key.transpose(1, 2) |
| 169 | + value = value.transpose(1, 2) |
| 170 | + # scores : [bsz, n_heads, seqlen | 1, seqlen] |
| 171 | + scores = torch.matmul(query, key.transpose(2, 3)) * self.scale |
| 172 | + |
| 173 | + if mask is not None: |
| 174 | + scores += mask[None, None, ...] |
| 175 | + |
| 176 | + scores = scores.float() |
| 177 | + scores = nn.functional.softmax(scores, dim=-1).type_as(query) |
| 178 | + output = torch.matmul( |
| 179 | + scores, value |
| 180 | + ) # (bs, n_local_heads, slen, head_dim) |
| 181 | + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) |
| 182 | + return self.wo(output) |
| 183 | + |
| 184 | + |
| 185 | +class FeedForward(nn.Module): |
| 186 | + def __init__(self, args: ModelArgs): |
| 187 | + super().__init__() |
| 188 | + |
| 189 | + self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False) |
| 190 | + self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False) |
| 191 | + self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False) |
| 192 | + |
| 193 | + def forward(self, x) -> torch.Tensor: |
| 194 | + return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) |
| 195 | + |
| 196 | + |
| 197 | +class RMSNorm(torch.nn.Module): |
| 198 | + def __init__(self, dim: int, eps: float = 1e-6): |
| 199 | + super().__init__() |
| 200 | + self.eps = eps |
| 201 | + self.weight = nn.Parameter(torch.ones(dim)) |
| 202 | + |
| 203 | + def _norm(self, x): |
| 204 | + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) |
| 205 | + |
| 206 | + def forward(self, x): |
| 207 | + output = self._norm(x.float()).type_as(x) |
| 208 | + return output * self.weight |
| 209 | + |
| 210 | + |
| 211 | +class TransformerBlock(nn.Module): |
| 212 | + def __init__(self, args: ModelArgs): |
| 213 | + super().__init__() |
| 214 | + self.n_heads = args.n_heads |
| 215 | + self.dim = args.dim |
| 216 | + self.attention = Attention(args) |
| 217 | + self.feed_forward = FeedForward(args=args) |
| 218 | + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) |
| 219 | + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) |
| 220 | + self.args = args |
| 221 | + |
| 222 | + def forward( |
| 223 | + self, |
| 224 | + x: torch.Tensor, |
| 225 | + freqs_cis: torch.Tensor, |
| 226 | + positions: torch.Tensor, |
| 227 | + mask: Optional[torch.Tensor], |
| 228 | + ) -> torch.Tensor: |
| 229 | + r = self.attention.forward( |
| 230 | + self.attention_norm(x), freqs_cis, positions, mask |
| 231 | + ) |
| 232 | + h = x + r |
| 233 | + r = self.feed_forward.forward(self.ffn_norm(h)) |
| 234 | + out = h + r |
| 235 | + return out |
| 236 | + |
| 237 | + |
| 238 | +def precompute_freqs_cis( |
| 239 | + dim: int, end: int, theta: float = 10000.0 |
| 240 | +) -> torch.Tensor: |
| 241 | + freqs = 1.0 / ( |
| 242 | + theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) |
| 243 | + ) |
| 244 | + t = torch.arange(end, device=freqs.device) # type: ignore |
| 245 | + freqs = torch.outer(t, freqs).float() # type: ignore |
| 246 | + return torch.polar(torch.ones_like(freqs), freqs) # complex64 |
| 247 | + |
| 248 | + |
| 249 | +class TorchTransformer(nn.Module): |
| 250 | + def __init__(self, args: ModelArgs): |
| 251 | + super().__init__() |
| 252 | + self.args = args |
| 253 | + self.vocab_size = args.vocab_size |
| 254 | + self.n_layers = args.n_layers |
| 255 | + assert self.vocab_size > 0 |
| 256 | + |
| 257 | + self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) |
| 258 | + |
| 259 | + self.layers = torch.nn.ModuleList( |
| 260 | + [TransformerBlock(args=args) for _ in range(args.n_layers)] |
| 261 | + ) |
| 262 | + |
| 263 | + self.norm = RMSNorm(args.dim, eps=args.norm_eps) |
| 264 | + |
| 265 | + self.output = nn.Linear(args.dim, args.vocab_size, bias=False) |
| 266 | + |
| 267 | + self.freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000) |
| 268 | + |
| 269 | + def forward( |
| 270 | + self, |
| 271 | + input_ids: torch.Tensor, |
| 272 | + positions: torch.Tensor, |
| 273 | + ): |
| 274 | + h = self.tok_embeddings(input_ids) |
| 275 | + freqs_cis = self.freqs_cis[positions] |
| 276 | + |
| 277 | + mask: Optional[torch.Tensor] = None |
| 278 | + if input_ids.shape[1] > 1: |
| 279 | + seqlen = input_ids.shape[1] |
| 280 | + tensor = torch.full( |
| 281 | + (seqlen, seqlen), |
| 282 | + dtype=h.dtype, |
| 283 | + fill_value=1, |
| 284 | + device=h.device, |
| 285 | + ) |
| 286 | + mask = torch.tril(tensor, diagonal=0).to(h.dtype) |
| 287 | + # make the mask banded to account for sliding window |
| 288 | + mask = torch.triu(mask, diagonal=-self.args.sliding_window) |
| 289 | + mask = torch.log(mask) |
| 290 | + |
| 291 | + for layer in self.layers: |
| 292 | + h = layer(h, freqs_cis, positions, mask) |
| 293 | + |
| 294 | + return self.output(self.norm(h)).float() |
| 295 | + |
| 296 | + @staticmethod |
| 297 | + def from_folder( |
| 298 | + folder: Path, max_batch_size: int = 1, device="cpu", dtype=torch.float16 |
| 299 | + ): |
| 300 | + with open(folder / "params.json", "r") as f: |
| 301 | + model_args = ModelArgs(**json.loads(f.read())) |
| 302 | + model_args.max_batch_size = max_batch_size |
| 303 | + model = TorchTransformer(model_args).to(device=device, dtype=dtype) |
| 304 | + loaded = torch.load(folder / "consolidated.00.pth") |
| 305 | + model.load_state_dict(loaded) |
| 306 | + return model |
| 307 | + |
26 | 308 |
|
27 | 309 | def port_weights( |
28 | 310 | model_k3: MistralBackbone, model_torch: TorchTransformer, params: ModelArgs |
|
0 commit comments