diff --git a/builder.py b/builder.py new file mode 100644 index 0000000000..72f3292bdf --- /dev/null +++ b/builder.py @@ -0,0 +1,193 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import itertools +import sys +import time +from pathlib import Path +from typing import Optional, Tuple + +import torch +import torch._dynamo.config +import torch._inductor.config + +from quantize import quantize_model, name_to_dtype, set_precision, get_precision +from cli import cli_args + +def device_sync(device): + if "cuda" in device: + torch.cuda.synchronize(device) + elif ("cpu" in device) or ("mps" in device): + pass + else: + print(f"device={ device } is not yet suppported") + + +torch._inductor.config.coordinate_descent_tuning = True +torch._inductor.config.triton.unique_kernel_names = True +torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future + + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from sentencepiece import SentencePieceProcessor + +from model import Transformer + +def _load_model( + checkpoint_path, + checkpoint_dir, + params_path, + params_table, + gguf_path, + device, + precision, + use_tp # =False +): + use_cuda = "cuda" in device + with torch.device("meta"): + if params_path: + model = Transformer.from_params(params_path) + elif params_table: + model = Transformer.from_table(params_path) + elif gguf_path: + model = Transformer.from_gguf(gguf_path) + else: + model = Transformer.from_name(checkpoint_path.parent.name) + + # checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + cps = [] + if checkpoint_dir is not None: + # Load multiple checkpoint; ignore the single path. + checkpoint_path = None + for i in range(4): + cp_name = f"consolidated.{i}.pth" + print(f"Loading {cp_name}") + cps.append( + torch.load( + os.path.join(checkpoint_dir, cp_name), + map_location=device, + mmap=True, + ) + ) + + checkpoint = {} + for key in cps[0].keys(): + if not torch.allclose(cps[0][key], cps[1][key]): + values = (cps[0][key], cps[1][key], cps[2][key], cps[3][key]) + if key.endswith("wo.weight") or key.endswith("w2.weight"): + checkpoint[key] = torch.cat(values, dim=1) + else: + checkpoint[key] = torch.cat(values, dim=0) + else: + checkpoint[key] = cps[0][key] + else: + checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True, weights_only=True) + + if "model" in checkpoint and "stories" in str(checkpoint_path): + checkpoint = checkpoint["model"] + + model.load_state_dict(checkpoint, assign=True) + + if use_tp: + from tp import apply_tp + + print("Applying tensor parallel to model ...") + apply_tp(model) + + model = model.to(device=device, dtype=precision) + return model.eval() + + +def _initialize_model( + checkpoint_path, + checkpoint_dir, + params_path, + params_table, + gguf_path, + dso_path, + pte_path, + quantize, + device, + precision, + setup_caches, + use_tp # =False +): + assert ( + (checkpoint_path and checkpoint_path.is_file()) or + (checkpoint_dir and checkpoint_path.is_dir()) or + (gguf_path and gguf_path.is_file()) or + (dso_path and Path(dso_path).is_file()) or + (pte_path and Path(pte_path).is_file()) + ), "need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path" + assert not (dso_path and pte_path), "specify either DSO path or PTE path, but not both" + + if (checkpoint_path and (dso_path or pte_path)): + print("Warning: checkpoint path ignored because an exported DSO or PTE path specified") + if (checkpoint_dir and (dso_path or pte_path)): + print("Warning: checkpoint dir ignored because an exported DSO or PTE path specified") + if (gguf_path and (dso_path or pte_path)): + print("Warning: GGUF path ignored because an exported DSO or PTE path specified") + + print("Loading model ...") + t0 = time.time() + model_ = _load_model( + checkpoint_path, + checkpoint_dir, + params_path, + params_table, + gguf_path, + device, + precision, + use_tp + ) + device_sync(device=device) # MKG + print(f"Time to load model: {time.time() - t0:.02f} seconds") + + if dso_path: + # make sure user did not try to set dtype + # assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export." + assert quantize is None or quantize == "{ }", f"quantize not valid for exported DSO model. Specify quantization during export." + try: + model = model_ + # Replace model forward with the AOT-compiled forward + # This is a hacky way to quickly demo AOTI's capability. + # model is still a Python object, and any mutation to its + # attributes will NOT be seen on by AOTI-compiled forward + # function, e.g. calling model.setup_cache will NOT touch + # AOTI compiled and maintained model buffers such as kv_cache. + model.forward = torch._export.aot_load(str(dso_path.absolute()), device) + except: + raise RuntimeError(f"Failed to load AOTI compiled {dso_path}") + elif pte_path: + # make sure user did not try to set dtype + # assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export." + assert quantize is None or quantize == "{ }", f"quantize not valid for exported PTE model. Specify quantization during export." + try: + from model_et import PTEModel + model = PTEModel(model_.config, pte_path) + except Exception as e: + raise RuntimeError(f"Failed to load ET compiled {pte_path}") + else: + model = model_ + + if quantize: + t0q = time.time() + quantize_model(model, quantize) + device_sync(device=device) # MKG + print(f"Time to quantize model: {time.time() - t0q:.02f} seconds") + + if setup_caches: + max_seq_length = 350 + with torch.device(device): + model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + + model.to(dtype=precision) + + return model + + diff --git a/eval.py b/eval.py index a28eda0ef0..a63cd5f2cf 100644 --- a/eval.py +++ b/eval.py @@ -31,7 +31,8 @@ except: lm_eval_available = False -from generate import _initialize_model, encode_tokens, model_forward +from builder import _initialize_model +from generate import encode_tokens, model_forward if lm_eval_available: try: # lm_eval version 0.4 diff --git a/export.py b/export.py index 9a783e2289..c5c05c9617 100644 --- a/export.py +++ b/export.py @@ -25,7 +25,8 @@ from export_aoti import export_model as export_model_aoti from model import Transformer -from generate import _load_model, decode_one_token, _initialize_model +from builder import _initialize_model +from generate import decode_one_token from quantize import quantize_model, name_to_dtype from torch._export import capture_pre_autograd_graph diff --git a/export_aoti.py b/export_aoti.py index 7a5306b5b6..c7d4d6d92d 100644 --- a/export_aoti.py +++ b/export_aoti.py @@ -14,7 +14,7 @@ import torch.nn as nn from torch.export import Dim, export -from generate import _load_model, decode_one_token +from generate import decode_one_token from quantize import quantize_model from model import Transformer diff --git a/export_et.py b/export_et.py index 6630ea86fa..bfdb8d0872 100644 --- a/export_et.py +++ b/export_et.py @@ -11,7 +11,7 @@ import torch.nn as nn from torch.export import Dim, export -from generate import _load_model, decode_one_token +from generate import decode_one_token from quantize import quantize_model from quantize import quantize_model, name_to_dtype, set_precision, get_precision @@ -28,8 +28,6 @@ from executorch.exir.passes.quant_fusion_pass import QuantFusionPass from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass -from generate import _load_model - from model import Transformer from torch._export import capture_pre_autograd_graph diff --git a/generate.py b/generate.py index 61d5185bbe..ac99d6542e 100644 --- a/generate.py +++ b/generate.py @@ -8,6 +8,7 @@ import time from pathlib import Path from typing import Optional, Tuple +from builder import _load_model, _initialize_model import torch import torch._dynamo.config @@ -274,161 +275,6 @@ def encode_tokens(tokenizer, string, bos=True, device="cuda"): return torch.tensor(tokens, dtype=torch.int, device=device) -def _load_model( - checkpoint_path, - checkpoint_dir, - params_path, - params_table, - gguf_path, - device, - precision, - use_tp # =False -): - use_cuda = "cuda" in device - with torch.device("meta"): - if params_path: - model = Transformer.from_params(params_path) - elif params_table: - model = Transformer.from_table(params_path) - elif gguf_path: - model = Transformer.from_gguf(gguf_path) - else: - model = Transformer.from_name(checkpoint_path.parent.name) - - # checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) - cps = [] - if checkpoint_dir is not None: - # Load multiple checkpoint; ignore the single path. - checkpoint_path = None - for i in range(4): - cp_name = f"consolidated.{i}.pth" - print(f"Loading {cp_name}") - cps.append( - torch.load( - os.path.join(checkpoint_dir, cp_name), - map_location=device, - mmap=True, - ) - ) - - checkpoint = {} - for key in cps[0].keys(): - if not torch.allclose(cps[0][key], cps[1][key]): - values = (cps[0][key], cps[1][key], cps[2][key], cps[3][key]) - if key.endswith("wo.weight") or key.endswith("w2.weight"): - checkpoint[key] = torch.cat(values, dim=1) - else: - checkpoint[key] = torch.cat(values, dim=0) - else: - checkpoint[key] = cps[0][key] - else: - checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True, weights_only=True) - - if "model" in checkpoint and "stories" in str(checkpoint_path): - checkpoint = checkpoint["model"] - - model.load_state_dict(checkpoint, assign=True) - - if use_tp: - from tp import apply_tp - - print("Applying tensor parallel to model ...") - apply_tp(model) - - model = model.to(device=device, dtype=precision) - return model.eval() - - -B_INST, E_INST = "[INST]", "[/INST]" - -def _initialize_model( - checkpoint_path, - checkpoint_dir, - params_path, - params_table, - gguf_path, - dso_path, - pte_path, - quantize, - device, - precision, - setup_caches, - use_tp # =False -): - assert ( - (checkpoint_path and checkpoint_path.is_file()) or - (checkpoint_dir and checkpoint_path.is_dir()) or - (gguf_path and gguf_path.is_file()) or - (dso_path and Path(dso_path).is_file()) or - (pte_path and Path(pte_path).is_file()) - ), "need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path" - assert not (dso_path and pte_path), "specify either DSO path or PTE path, but not both" - - if (checkpoint_path and (dso_path or pte_path)): - print("Warning: checkpoint path ignored because an exported DSO or PTE path specified") - if (checkpoint_dir and (dso_path or pte_path)): - print("Warning: checkpoint dir ignored because an exported DSO or PTE path specified") - if (gguf_path and (dso_path or pte_path)): - print("Warning: GGUF path ignored because an exported DSO or PTE path specified") - - print("Loading model ...") - t0 = time.time() - model_ = _load_model( - checkpoint_path, - checkpoint_dir, - params_path, - params_table, - gguf_path, - device, - precision, - use_tp - ) - device_sync(device=device) # MKG - print(f"Time to load model: {time.time() - t0:.02f} seconds") - - if dso_path: - # make sure user did not try to set dtype - # assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export." - assert quantize is None or quantize == "{ }", f"quantize not valid for exported DSO model. Specify quantization during export." - try: - model = model_ - # Replace model forward with the AOT-compiled forward - # This is a hacky way to quickly demo AOTI's capability. - # model is still a Python object, and any mutation to its - # attributes will NOT be seen on by AOTI-compiled forward - # function, e.g. calling model.setup_cache will NOT touch - # AOTI compiled and maintained model buffers such as kv_cache. - model.forward = torch._export.aot_load(str(dso_path.absolute()), device) - except: - raise RuntimeError(f"Failed to load AOTI compiled {dso_path}") - elif pte_path: - # make sure user did not try to set dtype - # assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export." - assert quantize is None or quantize == "{ }", f"quantize not valid for exported PTE model. Specify quantization during export." - try: - from model_et import PTEModel - model = PTEModel(model_.config, pte_path) - except Exception as e: - raise RuntimeError(f"Failed to load ET compiled {pte_path}") - else: - model = model_ - - if quantize: - t0q = time.time() - quantize_model(model, quantize) - device_sync(device=device) # MKG - print(f"Time to quantize model: {time.time() - t0q:.02f} seconds") - - if setup_caches: - max_seq_length = 350 - with torch.device(device): - model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) - - model.to(dtype=precision) - - return model - - def _main( prompt: str = "Hello, my name is", chat_mode: bool = False, @@ -494,6 +340,7 @@ def _main( # will add a version of _initialize_model in future # (need additional args) if is_speculative: + from builder import _load_model draft_model = _load_model( draft_checkpoint_path, None, # checkpoint_dir diff --git a/requirements.txt b/requirements.txt index 2f1a2a7cc1..0f59dc82c5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ torch sentencepiece numpy +gguf