From e9abb4de4d6140f00a3ecc66d26193bdc2dc2d53 Mon Sep 17 00:00:00 2001
From: Lucy Qiu
Date: Thu, 11 Apr 2024 13:53:02 -0700
Subject: [PATCH] add export configs (#2965)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/2965

Differential Revision: D55953027
---
 examples/models/llama2/builder.py           |  8 +++--
 examples/models/llama2/export_llama_lib.py  | 13 ++++++-
 examples/models/llama2/llama_transformer.py | 40 +++++++++++++--------
 examples/models/llama2/model.py             | 38 +++++++++++++++++++-
 4 files changed, 81 insertions(+), 18 deletions(-)

diff --git a/examples/models/llama2/builder.py b/examples/models/llama2/builder.py
index 3473391b641..35577ad3ec7 100644
--- a/examples/models/llama2/builder.py
+++ b/examples/models/llama2/builder.py
@@ -62,7 +62,8 @@ def to_torch_dtype(self) -> torch.dtype:
 
 def load_llama_model(
     *,
-    checkpoint: str,
+    checkpoint: Optional[str] = None,
+    checkpoint_dir: Optional[str] = None,
     params_path: str,
     use_kv_cache: bool = False,
     use_sdpa_with_kv_cache: bool = False,
@@ -76,7 +77,9 @@ def load_llama_model(
     Returns:
         An instance of LlamaEdgeManager which contains the eager mode model.
     """
-    assert checkpoint and params_path, "Both checkpoint and params can't be empty"
+    assert (
+        checkpoint or checkpoint_dir
+    ) and params_path, "Both checkpoint/checkpoint_dir and params can't be empty"
     logging.info(
         f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}"
     )
@@ -84,6 +87,7 @@ def load_llama_model(
         "llama2",
         "Llama2Model",
         checkpoint=checkpoint,
+        checkpoint_dir=checkpoint_dir,
         params=params_path,
         use_kv_cache=use_kv_cache,
         use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index 9dedbf47795..76cfd00f3b7 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -242,6 +242,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         default=f"{ckpt_dir}/params/demo_rand_params.pth",
         help="checkpoint path",
     )
+
+    parser.add_argument(
+        "--checkpoint_dir",
+        default=None,
+        help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
+    )
+
     parser.add_argument(
         "--calibration_tasks",
         nargs="+",
@@ -417,7 +424,10 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
     """
 
     # load model from checkpoint and params.json
-    checkpoint_path = canonical_path(args.checkpoint)
+    checkpoint_path = canonical_path(args.checkpoint) if args.checkpoint else None
+    checkpoint_dir = (
+        canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None
+    )
     params_path = canonical_path(args.params)
     output_dir_path = canonical_path(args.output_dir, dir=True)
     modelname = "llama2"
@@ -485,6 +495,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
     return (
         load_llama_model(
             checkpoint=checkpoint_path,
+            checkpoint_dir=checkpoint_dir,
            params_path=params_path,
             use_kv_cache=args.use_kv_cache,
             use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
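
The new --checkpoint_dir flag is intended for sharded checkpoints and, per its help text, takes precedence over --checkpoint when both are set. The sketch below only restates that precedence and the assert added to load_llama_model; the resolve_checkpoint helper and the example paths are hypothetical, not part of the patch.

from typing import Optional, Tuple


def resolve_checkpoint(
    checkpoint: Optional[str], checkpoint_dir: Optional[str]
) -> Tuple[Optional[str], Optional[str]]:
    # Mirrors the assert in load_llama_model: at least one source must be given.
    assert checkpoint or checkpoint_dir, "Both checkpoint/checkpoint_dir can't be empty"
    if checkpoint_dir is not None:
        # Sharded case: the single-file path is ignored, as model.py does.
        return None, checkpoint_dir
    return checkpoint, None


# checkpoint_dir wins when both are given.
print(resolve_checkpoint("demo_rand_params.pth", "path/to/sharded_ckpt_dir"))
# -> (None, 'path/to/sharded_ckpt_dir')
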
diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py
index e9650f81814..66fc47b17f0 100644
--- a/examples/models/llama2/llama_transformer.py
+++ b/examples/models/llama2/llama_transformer.py
@@ -62,6 +62,12 @@ def forward(self, x):
         return output * self.weight
 
 
+def find_multiple(n: int, k: int) -> int:
+    if n % k == 0:
+        return n
+    return n + k - (n % k)
+
+
 @dataclass
 class ModelArgs:
     dim: int = 4096
@@ -96,6 +102,16 @@ def __post_init__(self):
         if self.use_sdpa_with_kv_cache_op:
             assert self.use_kv_cache, "use_sdpa_with_kv_cache_op requires use_kv_cache"
 
+        if self.hidden_dim is None:
+            # If hidden_dim is not explicitly set in the ModelArgs,
+            # then calculate implicitly based on dim and also multiple of `args.multiple_of`
+            multiple_of = self.multiple_of
+            hidden_dim = 4 * self.dim
+            hidden_dim = int(2 * hidden_dim / 3)
+            if self.ffn_dim_multiplier is not None:
+                hidden_dim = int(self.ffn_dim_multiplier * hidden_dim)
+            self.hidden_dim = find_multiple(hidden_dim, multiple_of)
+
 
 def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
     """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
@@ -316,19 +332,11 @@ def forward(
 class FeedForward(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
-        dim = args.dim
-        hidden_dim = args.hidden_dim
-        if hidden_dim is None:
-            # If hidden_dim is not explicitly set in the ModelArgs,
-            # then calculate implicitly based on dim and also multiple of `args.multiple_of`
-            multiple_of = args.multiple_of
-            hidden_dim = 4 * dim
-            hidden_dim = int(2 * hidden_dim / 3)
-            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
-
-        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
-        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
-        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+        assert args.hidden_dim is not None
+        hidden_dim: int = args.hidden_dim
+        self.w1 = nn.Linear(args.dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, args.dim, bias=False)
+        self.w3 = nn.Linear(args.dim, hidden_dim, bias=False)
 
     def forward(self, x):
         return self.w2(F.silu(self.w1(x)) * self.w3(x))
@@ -425,7 +433,11 @@ def __init__(self, params: ModelArgs):
 
         freqs_cos, freqs_sin = precompute_freqs_cis(
             params.dim // params.n_heads,
-            params.max_seq_len,
+            (
+                params.max_seq_len  # Normal llama2.
+                if params.ffn_dim_multiplier is None
+                else params.max_seq_len * 2  # Sharded checkpoint.
+            ),
             params.rope_freq_base,
         )
         self.register_buffer("freqs_cos", freqs_cos, persistent=False)
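
With hidden_dim now resolved once in ModelArgs.__post_init__, FeedForward can simply assert that it is set. As a sanity check (the numbers below are illustrative, not taken from the patch), the rule reproduces the familiar 11008 FFN width for Llama-2 7B defaults, i.e. dim=4096, multiple_of=256, and no ffn_dim_multiplier:

def find_multiple(n: int, k: int) -> int:
    # Round n up to the next multiple of k (same helper the patch adds).
    if n % k == 0:
        return n
    return n + k - (n % k)


dim, multiple_of, ffn_dim_multiplier = 4096, 256, None
hidden_dim = int(2 * (4 * dim) / 3)  # 10922
if ffn_dim_multiplier is not None:
    hidden_dim = int(ffn_dim_multiplier * hidden_dim)
print(find_multiple(hidden_dim, multiple_of))  # 11008

Separately, the precompute_freqs_cis change doubles max_seq_len whenever ffn_dim_multiplier is set, using that field as a proxy for the sharded-checkpoint configuration.
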
diff --git a/examples/models/llama2/model.py b/examples/models/llama2/model.py
index 68882433679..5428e34b74f 100644
--- a/examples/models/llama2/model.py
+++ b/examples/models/llama2/model.py
@@ -4,7 +4,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+
 import json
+import os
 from pathlib import Path
 
 import torch
@@ -48,6 +50,12 @@ def __init__(self, **kwargs):
         # The 1st way
         ckpt_dir = Path(__file__).absolute().parent / "params"
 
+        # Check if checkpoint_dir was provided for a sharded checkpoint.
+        checkpoint_dir = (
+            kwargs["checkpoint_dir"] if "checkpoint_dir" in kwargs else None
+        )
+
+        # Use single checkpoint file.
         checkpoint_path = (
             kwargs["checkpoint"]
             if "checkpoint" in kwargs
@@ -72,7 +80,35 @@ def __init__(self, **kwargs):
         # Follow the instruction in https://github.com/facebookresearch/llama to download the model
         device = "cpu"
         # flake8: noqa: TOR102
-        checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
+        cps = []
+        if checkpoint_dir is not None:
+            # Load multiple checkpoint; ignore the single path.
+            checkpoint_path = None
+            for i in range(4):
+                cp_name = f"consolidated.{i}.pth"
+                print(f"Loading {cp_name}")
+                cps.append(
+                    torch.load(
+                        os.path.join(checkpoint_dir, cp_name),
+                        map_location=device,
+                        mmap=True,
+                    )
+                )
+            checkpoint = {}
+            for key in cps[0].keys():
+                if not torch.allclose(cps[0][key], cps[1][key]):
+                    values = (cps[0][key], cps[1][key], cps[2][key], cps[3][key])
+                    if "wo" in key or "w2" in key:
+                        # Concat on dim=1 for "wo" and "w2".
+                        checkpoint[key] = torch.cat(values, dim=1)
+                    else:
+                        # Concat on dim=0 for everything else.
+                        checkpoint[key] = torch.cat(values, dim=0)
+                else:
+                    # Do not duplicate layers shared between each checkpoint.
+                    checkpoint[key] = cps[0][key]
+        else:
+            checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
         fairseq2_checkpoint = kwargs.get("fairseq2", False)
         if fairseq2_checkpoint:
             print("Using fairseq2 checkpoint")
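
The sharded path in model.py loads four shards named consolidated.{i}.pth and merges them key by key: tensors that differ across shards are concatenated (dim=1 for "wo"/"w2" keys, dim=0 for everything else), while tensors that are identical in every shard are kept once. Below is a toy, self-contained illustration of that merge rule; the key names and shapes are invented for the demo and do not come from a real checkpoint.

import torch

# Four fake shards: one row-sharded weight, one column-sharded weight,
# and one replicated tensor.
shards = [
    {
        "layers.0.attention.wq.weight": torch.randn(2, 8),  # differs per shard -> cat dim=0
        "layers.0.attention.wo.weight": torch.randn(8, 2),  # "wo" in key -> cat dim=1
        "norm.weight": torch.ones(8),                        # identical everywhere -> keep one copy
    }
    for _ in range(4)
]

merged = {}
for key in shards[0]:
    if not torch.allclose(shards[0][key], shards[1][key]):
        values = tuple(cp[key] for cp in shards)
        dim = 1 if ("wo" in key or "w2" in key) else 0
        merged[key] = torch.cat(values, dim=dim)
    else:
        merged[key] = shards[0][key]

print({k: tuple(v.shape) for k, v in merged.items()})
# expected: wq -> (8, 8), wo -> (8, 8), norm.weight -> (8,)

Note that the patch hardcodes the shard count to four and relies on torch.allclose to detect replicated tensors, so it assumes exactly four consolidated files whose shared tensors are numerically identical across shards.
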