diff --git a/build/builder.py b/build/builder.py index d8ba6e019b..635cca1523 100644 --- a/build/builder.py +++ b/build/builder.py @@ -12,19 +12,23 @@ from typing import Any, Dict, Optional, Tuple, Union import torch -import torch.nn as nn -from torch.distributed.device_mesh import DeviceMesh import torch._dynamo.config import torch._inductor.config +import torch.nn as nn from config.model_config import resolve_model_config -from distributed import init_distributed, ParallelDims, parallelize_llama +from distributed import ( + init_distributed, + launch_distributed, + ParallelDims, + parallelize_llama, +) from quantization.quantize import quantize_model +from torch.distributed.device_mesh import DeviceMesh from utils.measure_time import measure_time -from build.model import Transformer +from build.model import Model from build.utils import device_sync, is_cpu_device, is_cuda_or_cpu_device, name_to_dtype -from distributed import launch_distributed @dataclass @@ -210,7 +214,7 @@ def __post_init__(self): def validate_model( self, - model: Transformer, + model: Model, model_description: str = "model", ) -> None: if model is None: @@ -221,7 +225,7 @@ def validate_model( is_tiktoken = self.is_tiktoken is_sentencepiece = self.is_sentencepiece - use_tiktoken = model.config.use_tiktoken + use_tiktoken = model.config.text_transformer_args.use_tiktoken if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken): raise RuntimeError( @@ -298,11 +302,11 @@ def _unset_gguf_kwargs(builder_args): def _init_model_on_meta_device(builder_args): with torch.device("meta"): if builder_args.params_path: - return Transformer.from_params(builder_args.params_path) + return Model.from_params(builder_args.params_path) elif builder_args.params_table: - return Transformer.from_table(builder_args.params_table) + return Model.from_table(builder_args.params_table) else: - return Transformer.from_name(builder_args.checkpoint_path.parent.name) + return Model.from_name(builder_args.checkpoint_path.parent.name) def _load_model_gguf(builder_args, only_config=False): @@ -311,7 +315,7 @@ def _load_model_gguf(builder_args, only_config=False): kwargs = {} else: kwargs = builder_args.gguf_kwargs - model = Transformer.from_gguf(builder_args.gguf_path, **kwargs) + model = Model.from_gguf(builder_args.gguf_path, **kwargs) return model @@ -334,7 +338,6 @@ def _load_model_default(builder_args, only_config=False): mmap=True, ) ) - checkpoint = {} for key in cps[0].keys(): if not torch.allclose(cps[0][key], cps[1][key]): @@ -355,9 +358,10 @@ def _load_model_default(builder_args, only_config=False): if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path): checkpoint = checkpoint["model"] + + checkpoint = {"text_transformer." 
+ k: v for k, v in checkpoint.items()} - model.load_state_dict(checkpoint, assign=True, strict=False) - + model.load_state_dict(checkpoint, assign=True, strict=True) return model @@ -380,11 +384,13 @@ def _maybe_init_distributed( """ if not builder_args.use_distributed: return None, None - dist_config = 'llama3_8B.toml' # TODO - integrate with chat cmd line + dist_config = "llama3_8B.toml" # TODO - integrate with chat cmd line world_mesh, parallel_dims = launch_distributed(dist_config) - assert world_mesh is not None and parallel_dims is not None, f"failed to launch distributed using {dist_config}" + assert ( + world_mesh is not None and parallel_dims is not None + ), f"failed to launch distributed using {dist_config}" return world_mesh, parallel_dims @@ -523,7 +529,7 @@ def _initialize_model( if builder_args.setup_caches: with torch.device(builder_args.device): model.setup_caches( - max_batch_size=1, max_seq_length=max_seq_length or model.config.max_seq_length + max_batch_size=1, max_seq_length=max_seq_length or model.config.text_transformer_args.max_seq_length ) model.to(dtype=builder_args.precision) diff --git a/build/gguf_loader.py b/build/gguf_loader.py index 897a5a1700..986db69d8c 100644 --- a/build/gguf_loader.py +++ b/build/gguf_loader.py @@ -13,11 +13,12 @@ import torch +from build.gguf_util import Q4_0, to_float +from build.model import Model, ModelArgs, TransformerArgs + from gguf import GGUFValueType from quantization.qops import LinearInt4 as WeightOnlyInt4Linear from quantization.quantize import pack_scales_and_zeros -from build.gguf_util import Q4_0, to_float -from build.model import TransformerArgs, Transformer logger: logging.Logger = logging.getLogger(__name__) @@ -41,6 +42,7 @@ def _convert_gguf_tensor_name_to_llama_nn(gguf_name: str) -> str: result = copy.deepcopy(gguf_name) for gguf_string, replacement in _name_replacements: result = result.replace(gguf_string, replacement) + result = "text_transformer." + result return result @@ -107,14 +109,16 @@ def load_model(gguf_file: str) -> torch.nn.Module: arch = metadata["general.architecture"] assert arch == "llama", "Only LLaMa models are supported by this converter." 
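
Both the eager checkpoint path in build/builder.py and the GGUF tensor-name conversion above re-key weights under the new `text_transformer` submodule, which is what lets `load_state_dict` run with `strict=True`. A minimal sketch of that re-keying, using toy stand-in classes rather than the real `Transformer`/`Model` modules:

import torch
import torch.nn as nn

class TinyTransformer(nn.Module):          # stand-in for build.model.Transformer
    def __init__(self) -> None:
        super().__init__()
        self.tok_embeddings = nn.Embedding(8, 4)

class TinyModel(nn.Module):                # stand-in for the new Model wrapper
    def __init__(self) -> None:
        super().__init__()
        self.text_transformer = TinyTransformer()

# A legacy checkpoint addresses the transformer directly ...
flat_checkpoint = {"tok_embeddings.weight": torch.zeros(8, 4)}
# ... so its keys are prefixed to match the wrapped module hierarchy.
wrapped = {"text_transformer." + k: v for k, v in flat_checkpoint.items()}
TinyModel().load_state_dict(wrapped, assign=True, strict=True)
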
- model_args = TransformerArgs( - dim=metadata[f"{arch}.embedding_length"], - n_layers=metadata[f"{arch}.block_count"], - n_heads=metadata[f"{arch}.attention.head_count"], - n_local_heads=metadata[f"{arch}.attention.head_count_kv"], - vocab_size=len(metadata["tokenizer.ggml.tokens"]), - norm_eps=metadata[f"{arch}.attention.layer_norm_rms_epsilon"], - hidden_dim=metadata[f"{arch}.feed_forward_length"], + model_args = ModelArgs( + TransformerArgs( + dim=metadata[f"{arch}.embedding_length"], + n_layers=metadata[f"{arch}.block_count"], + n_heads=metadata[f"{arch}.attention.head_count"], + n_local_heads=metadata[f"{arch}.attention.head_count_kv"], + vocab_size=len(metadata["tokenizer.ggml.tokens"]), + norm_eps=metadata[f"{arch}.attention.layer_norm_rms_epsilon"], + hidden_dim=metadata[f"{arch}.feed_forward_length"], + ) ) # TODO: what to do with rope args like @@ -122,7 +126,7 @@ def load_model(gguf_file: str) -> torch.nn.Module: # metadata.get(f"{arch}.rope.dimension_count", None) with torch.device("meta"): - model = Transformer(model_args) + model = Model(model_args) return model diff --git a/build/model.py b/build/model.py index 27d1500bbf..94174542a0 100644 --- a/build/model.py +++ b/build/model.py @@ -59,21 +59,46 @@ def __post_init__(self): self.use_tiktoken = self.use_tiktoken == "True" @classmethod - def from_params(cls, params_path): + def from_params(cls, params): replace = [("rope_theta", "rope_base"), ("n_kv_heads", "n_local_heads")] - with open(params_path, "r") as f: - params = json.loads(f.read()) - # Patch for llama3 - for _from, _to in replace: - if _from in params: - params[_to] = params.pop(_from) + for _from, _to in replace: + if _from in params: + params[_to] = params.pop(_from) return cls(**params) +@dataclass +class ModelArgs: + text_transformer_args: TransformerArgs + + def __post_init__(self): + assert self.text_transformer_args is not None + assert type(self.text_transformer_args) == TransformerArgs + + @classmethod + def from_params(cls, params_path): + with open(params_path, "r") as f: + loaded_params = json.loads(f.read()) + + try: + # try to interpret as a single transformer config + text_transformer_args = TransformerArgs.from_params( + loaded_params + ) + except TypeError: + # try to interpret as a dict of transformer configs + for name, params in loaded_params.items(): + if name == "text": + text_transformer_args = TransformerArgs.from_params(params) + else: + raise ValueError(f"Unknown transformer name {name}") + + return cls(text_transformer_args) + @classmethod def from_table(cls, name: str): json_path = config_path / f"{name}.json" if json_path.is_file(): - return TransformerArgs.from_params(json_path) + return ModelArgs.from_params(json_path) else: known_model_params = [ config.replace(".json", "") for config in os.listdir(config_path) @@ -86,7 +111,7 @@ def from_table(cls, name: str): def from_name(cls, name: str): json_path = config_path / f"{name}.json" if Path(json_path).is_file(): - return TransformerArgs.from_params(json_path) + return ModelArgs.from_params(json_path) known_model_params = [ config.replace(".json", "") for config in os.listdir(config_path) @@ -113,7 +138,7 @@ def from_name(cls, name: str): f"Unknown model directory name {name}. Must be one of {known_model_params}." 
) - return TransformerArgs.from_params(config_path / f"{config[0]}.json") + return ModelArgs.from_params(config_path / f"{config[0]}.json") class KVCache(nn.Module): @@ -144,6 +169,40 @@ def update(self, input_pos, k_val, v_val): return k_out, v_out +class Model(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.config = config + self.text_transformer = Transformer(config.text_transformer_args) + + def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + return self.text_transformer(idx, input_pos) + + def setup_caches(self, max_batch_size, max_seq_length): + self.text_transformer.setup_caches(max_batch_size, max_seq_length) + + @classmethod + def from_name(cls, name: str): + return cls(ModelArgs.from_name(name)) + + @classmethod + def from_table(cls, name: str): + return cls(ModelArgs.from_table(name)) + + @classmethod + def from_params(cls, params_path: str): + return cls(ModelArgs.from_params(params_path)) + + @classmethod + def from_gguf(cls, gguf_path: str, **kwargs): + from build.gguf_loader import load_model_and_state_dict + + model, state_dict = load_model_and_state_dict(gguf_path, **kwargs) + if state_dict != {}: + model.load_state_dict(state_dict, assign=True) + return model + + class Transformer(nn.Module): def __init__(self, config: TransformerArgs) -> None: super().__init__() @@ -180,7 +239,7 @@ def setup_caches(self, max_batch_size, max_seq_length): self.config.dim // self.config.n_heads, self.config.block_size * 2, self.config.rope_base, - use_scaled = self.config.use_scaled_rope, + use_scaled=self.config.use_scaled_rope, ) self.register_buffer("freqs_cis", freqs_cis, persistent=True) causal_mask = torch.tril( @@ -201,27 +260,6 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: # print(f"logits shape: {logits.shape}") return logits - @classmethod - def from_name(cls, name: str): - return cls(TransformerArgs.from_name(name)) - - @classmethod - def from_table(cls, name: str): - return cls(TransformerArgs.from_table(name)) - - @classmethod - def from_params(cls, params_path: str): - return cls(TransformerArgs.from_params(params_path)) - - @classmethod - def from_gguf(cls, gguf_path: str, **kwargs): - from build.gguf_loader import load_model_and_state_dict - - model, state_dict = load_model_and_state_dict(gguf_path, **kwargs) - if state_dict != {}: - model.load_state_dict(state_dict, assign=True) - return model - class TransformerBlock(nn.Module): def __init__(self, config: TransformerArgs) -> None: @@ -388,6 +426,7 @@ def apply_scaling(freqs: torch.Tensor): new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) + def precompute_freqs_cis( n_elem: int, seq_len: int, base: int = 10000, dtype=None, use_scaled: bool = False ) -> Tensor: diff --git a/distributed/parallelize_llama.py b/distributed/parallelize_llama.py index cbcb29b725..f0d12d7696 100644 --- a/distributed/parallelize_llama.py +++ b/distributed/parallelize_llama.py @@ -59,7 +59,7 @@ def apply_tp( # after we apply TP to the model. Because we don't want to change model code # when applying TP. We need to have change to ensure KVCache has the correct # size as k and v. 
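
The `Model` wrapper defined in build/model.py above is the single entry point that the tensor-parallel, eval, export, generate, and server code below now programs against, with per-transformer settings reached through `config.text_transformer_args`. A hedged sketch of the new call pattern (the `stories15M` table entry is borrowed from the docs example; this is illustrative, not part of the patch):

import torch
from build.model import Model

# Construct from a known-config table entry on the meta device (no real weights).
with torch.device("meta"):
    model = Model.from_table("stories15M")

# TransformerArgs now lives one level down, under text_transformer_args.
args = model.config.text_transformer_args
model.setup_caches(max_batch_size=1, max_seq_length=args.max_seq_length)

# Forward calls simply delegate to the wrapped transformer:
#   model(idx, input_pos) == model.text_transformer(idx, input_pos)
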
- model.config.n_local_heads = model.config.n_local_heads // tp_mesh.size() + model.config.text_transformer_args.n_local_heads = model.config.text_transformer_args.n_local_heads // tp_mesh.size() # Apply tensor parallelism to every transformer block for transformer_block in model.layers: diff --git a/docs/ADVANCED-USERS.md b/docs/ADVANCED-USERS.md index 12ad5f229d..5a8d41db04 100644 --- a/docs/ADVANCED-USERS.md +++ b/docs/ADVANCED-USERS.md @@ -112,21 +112,21 @@ architecture, provided you have the model weights in llama format, the model parameters and the tokenizer model used by your language model. Some common models are recognized by torchchat based on their filename -through `Transformer.from_name()` to perform a fuzzy match against a +through `Model.from_name()` to perform a fuzzy match against a table of known model architectures. Alternatively, you can specify the index into that table with the option `--params-table ${INDEX}` where the index is the lookup key key in the [the list of known pconfigurations](https://github.com/pytorch/torchchat/tree/main/build/known_model_params) For example, for the stories15M model, this would be expressed as `--params-table stories15M`. (We use the model constructor -`Transformer.from_table()`) +`Model.from_table()`) For models using a configuration not in the list of known configurations, you can construct the model by initializing the `TransformerArgs` dataclass that controls model construction from a parameter json using the `params-path ${PARAMS_PATH}` containing the -appropriate model parameters to initialize the `TransformerArgs` for the -model. (We use the model constructor `Transformer.from_params()`). +appropriate model parameters to initialize the `ModelArgs` for the +model. (We use the model constructor `Model.from_params()`). The parameter file should be in JSON format specifying these parameters. You can find the `TransformerArgs` data class in diff --git a/eval.py b/eval.py index 30880567d6..b79757ea03 100644 --- a/eval.py +++ b/eval.py @@ -16,7 +16,7 @@ TokenizerArgs, ) -from build.model import Transformer +from build.model import Model from build.utils import set_precision from cli import add_arguments_for_verb, arg_init from utils.measure_time import measure_time @@ -35,7 +35,7 @@ def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( - model: Transformer, + model: Model, prompt: torch.Tensor, max_new_tokens: int, max_seq_length: Optional[int] = None, @@ -58,7 +58,7 @@ def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( T = prompt.size(0) T_new = T + max_new_tokens if max_seq_length is None: - max_seq_length = min(T_new, model.config.block_size) + max_seq_length = min(T_new, model.config.text_transformer_args.block_size) device, dtype = prompt.device, prompt.dtype # create an empty tensor of the expected final shape and @@ -81,7 +81,7 @@ class GPTFastEvalWrapper(eval_wrapper): def __init__( self, - model: Transformer, + model: Model, tokenizer, model_forward: Optional[Callable] = None, max_seq_length: Optional[int] = None, @@ -169,7 +169,7 @@ def _model_generate(self, context, max_length, eos_token_id): @torch.no_grad() def eval( - model: Transformer, + model: Model, model_forward: Callable, tokenizer, tasks: Optional[list] = None, @@ -182,7 +182,7 @@ def eval( Evaluates a language model on a specified task using the lm-evaluation-harness library. Args: - model (Transformer): The pre-trained language model to evaluate. + model (Model): The pre-trained language model to evaluate. 
tokenizer: The tokenizer to use for encoding/decoding text. tasks (Optional[list]): The names of the evaluation tasks to perform. limit (Optional[int]): The maximum number of samples to evaluate (None for all available). diff --git a/export.py b/export.py index b82f863d03..2b85fbb11e 100644 --- a/export.py +++ b/export.py @@ -56,7 +56,7 @@ def export_for_server( torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device), ) - seq = Dim("seq", min=1, max=model.config.max_seq_length) + seq = Dim("seq", min=1, max=model.config.text_transformer_args.max_seq_length) # Specify that the first dimension of each input is that batch size dynamic_shapes = {"idx": {1: seq}, "input_pos": {0: seq}} else: diff --git a/generate.py b/generate.py index cb0048d108..ee69e574f4 100644 --- a/generate.py +++ b/generate.py @@ -25,7 +25,7 @@ BuilderArgs, TokenizerArgs, ) -from build.model import Transformer +from build.model import Model from build.utils import device_sync, set_precision from cli import add_arguments_for_verb, arg_init, check_args from utils.device_info import get_device_info @@ -303,7 +303,7 @@ def sample( def prefill( self, - model: Transformer, + model: Model, x: torch.Tensor, input_pos: torch.Tensor, *, @@ -329,7 +329,7 @@ def prefill( def decode_one_token( self, - model: Transformer, + model: Model, x: torch.Tensor, input_pos: torch.Tensor, need_probs: bool, @@ -349,7 +349,7 @@ def decode_one_token( def decode_n_tokens( self, - model: Transformer, + model: Model, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, @@ -418,8 +418,8 @@ def model_forward(self, model, x, input_pos): def speculative_decode( self, - model: Transformer, - draft_model: Transformer, + model: Model, + draft_model: Model, cur_token: torch.Tensor, input_pos: int, speculate_k: int, @@ -483,13 +483,13 @@ def speculative_decode( @torch.no_grad() def generate( self, - model: Transformer, + model: Model, prompt: torch.Tensor, max_new_tokens: int, *, chat_mode: bool, start_pos: int = 0, - draft_model: Transformer, + draft_model: Model, speculate_k: Optional[int] = 8, sequential_prefill=True, callback=lambda x: x, @@ -676,7 +676,7 @@ def chat( self.system_prompt = None # Set up our max_seq_length if generator_args.chat_mode: - max_seq_length = self.model.config.max_seq_length + max_seq_length = self.model.config.text_transformer_args.max_seq_length print( f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye" ) @@ -689,7 +689,7 @@ def chat( else: max_seq_length = min( encoded.size(0) + generator_args.max_new_tokens, - self.model.config.block_size, + self.model.config.text_transformer_args.block_size, ) max_seq_length = ( diff --git a/torchchat/usages/openai_api.py b/torchchat/usages/openai_api.py index 3e059693ac..ab3e15e0bd 100644 --- a/torchchat/usages/openai_api.py +++ b/torchchat/usages/openai_api.py @@ -232,11 +232,11 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.max_seq_length = ( - self.model.config.max_seq_length + self.model.config.text_transformer_args.max_seq_length + self.speculative_builder_args.speculate_k + 1 if self.draft_model is not None - else self.model.config.max_seq_length + else self.model.config.text_transformer_args.max_seq_length ) # The System fingerprint is a unique identifier for the model and its configuration. self.system_fingerprint = (
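
Taken together, the generate.py and openai_api.py hunks above move all sequence-length bookkeeping onto `config.text_transformer_args`. A hedged sketch of that arithmetic with a stand-in dataclass (the field values are made up; only the formulas mirror the patch):

from dataclasses import dataclass

@dataclass
class _TextArgs:                 # stand-in for model.config.text_transformer_args
    max_seq_length: int = 2048
    block_size: int = 2048

args = _TextArgs()
prompt_len, max_new_tokens, speculate_k = 64, 256, 8

chat_limit = args.max_seq_length                                   # chat mode
oneshot_limit = min(prompt_len + max_new_tokens, args.block_size)  # single generation
server_limit = args.max_seq_length + speculate_k + 1               # draft model present

print(chat_limit, oneshot_limit, server_limit)                     # 2048 320 2057
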