enable multi-transformer definition (pytorch#966)
* add llama 3.1 8b support

* make Model and ModelArgs the entry points for model definition

* make model definition support multiple transformers

* make input arg static in Model to support export

* fix GGUF and ET bugs in the new model definition architecture

* retrieve text transformer args from ModelArgs

* add set_cache function to Model to work around the PTEModel issue

* make torchchat rely on torchtune

* remove export_util

* extra torchtune dependency
Gasoonjia authored Aug 27, 2024
1 parent 0922e65 commit d0d1105
Showing 9 changed files with 133 additions and 84 deletions.
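In rough terms, the PR makes a Model wrapper (holding one or more transformers) the single construction entry point, with ModelArgs wrapping a per-transformer TransformerArgs. A minimal sketch of the new hierarchy, assuming the post-change build.model module, illustrative hyperparameters, and TransformerArgs defaults for every field not set here:

from build.model import Model, ModelArgs, TransformerArgs

# Illustrative hyperparameters; real configs come from JSON files under
# build/known_model_params.
args = ModelArgs(TransformerArgs(dim=64, n_layers=2, n_heads=4, vocab_size=256))
model = Model(args)

# The text transformer is now a named submodule, so config lookups and
# checkpoint keys gain a ".text_transformer_args" / "text_transformer." hop.
print(model.config.text_transformer_args.dim)   # 64
print(list(model.state_dict())[0])               # a key prefixed with "text_transformer."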
40 changes: 23 additions & 17 deletions build/builder.py
@@ -12,19 +12,23 @@
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
import torch._dynamo.config
import torch._inductor.config
import torch.nn as nn

from config.model_config import resolve_model_config
from distributed import init_distributed, ParallelDims, parallelize_llama
from distributed import (
init_distributed,
launch_distributed,
ParallelDims,
parallelize_llama,
)
from quantization.quantize import quantize_model
from torch.distributed.device_mesh import DeviceMesh
from utils.measure_time import measure_time

from build.model import Transformer
from build.model import Model
from build.utils import device_sync, is_cpu_device, is_cuda_or_cpu_device, name_to_dtype
from distributed import launch_distributed


@dataclass
@@ -210,7 +214,7 @@ def __post_init__(self):

def validate_model(
self,
model: Transformer,
model: Model,
model_description: str = "model",
) -> None:
if model is None:
@@ -221,7 +225,7 @@

is_tiktoken = self.is_tiktoken
is_sentencepiece = self.is_sentencepiece
use_tiktoken = model.config.use_tiktoken
use_tiktoken = model.config.text_transformer_args.use_tiktoken

if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken):
raise RuntimeError(
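For readers untangling the double negative above, an equivalent, slightly more direct formulation as a standalone sketch (not the torchchat implementation):

def check_tokenizer(model, is_tiktoken: bool, is_sentencepiece: bool) -> None:
    # The config now reports its tokenizer preference via the nested
    # text_transformer_args rather than the top-level config.
    use_tiktoken = model.config.text_transformer_args.use_tiktoken
    if is_tiktoken != use_tiktoken or is_sentencepiece == use_tiktoken:
        raise RuntimeError("model config and loaded tokenizer disagree on tokenizer type")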
@@ -298,11 +302,11 @@ def _unset_gguf_kwargs(builder_args):
def _init_model_on_meta_device(builder_args):
with torch.device("meta"):
if builder_args.params_path:
return Transformer.from_params(builder_args.params_path)
return Model.from_params(builder_args.params_path)
elif builder_args.params_table:
return Transformer.from_table(builder_args.params_table)
return Model.from_table(builder_args.params_table)
else:
return Transformer.from_name(builder_args.checkpoint_path.parent.name)
return Model.from_name(builder_args.checkpoint_path.parent.name)
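All three lookup paths now return the wrapper rather than a bare Transformer; a hedged sketch of exercising them directly (the stories15M name and the params path are illustrative):

import torch
from build.model import Model

with torch.device("meta"):  # build the module tree without allocating real weights
    by_name = Model.from_name("stories15M")    # fuzzy match on checkpoint dir name
    by_table = Model.from_table("stories15M")  # exact key in known_model_params
    by_params = Model.from_params("build/known_model_params/stories15M.json")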


def _load_model_gguf(builder_args, only_config=False):
@@ -311,7 +315,7 @@ def _load_model_gguf(builder_args, only_config=False):
kwargs = {}
else:
kwargs = builder_args.gguf_kwargs
model = Transformer.from_gguf(builder_args.gguf_path, **kwargs)
model = Model.from_gguf(builder_args.gguf_path, **kwargs)
return model
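Callers keep a one-line loading path, just on Model instead of Transformer; a sketch with an illustrative file path (extra kwargs are forwarded to the GGUF loader, mirroring builder_args.gguf_kwargs):

from build.model import Model

model = Model.from_gguf("checkpoints/llama-2-7b.Q4_0.gguf")  # path is illustrative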


@@ -334,7 +338,6 @@ def _load_model_default(builder_args, only_config=False):
mmap=True,
)
)

checkpoint = {}
for key in cps[0].keys():
if not torch.allclose(cps[0][key], cps[1][key]):
@@ -355,9 +358,10 @@

if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path):
checkpoint = checkpoint["model"]

checkpoint = {"text_transformer." + k: v for k, v in checkpoint.items()}

model.load_state_dict(checkpoint, assign=True, strict=False)

model.load_state_dict(checkpoint, assign=True, strict=True)
return model
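The important detail here is the key remap: checkpoints saved for a bare Transformer gain the text_transformer. prefix so they can be loaded strictly into the wrapper. A hedged sketch, assuming a flat single-transformer checkpoint at an illustrative path:

import torch
from build.model import Model

with torch.device("meta"):
    model = Model.from_name("stories15M")  # name is illustrative

checkpoint = torch.load("checkpoints/stories15M/model.pth", mmap=True, weights_only=True)
# Remap flat Transformer keys (e.g. "layers.0....") onto the wrapper's namespace.
checkpoint = {"text_transformer." + k: v for k, v in checkpoint.items()}
model.load_state_dict(checkpoint, assign=True, strict=True)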


@@ -380,11 +384,13 @@ def _maybe_init_distributed(
"""
if not builder_args.use_distributed:
return None, None
dist_config = 'llama3_8B.toml' # TODO - integrate with chat cmd line
dist_config = "llama3_8B.toml" # TODO - integrate with chat cmd line

world_mesh, parallel_dims = launch_distributed(dist_config)

assert world_mesh is not None and parallel_dims is not None, f"failed to launch distributed using {dist_config}"
assert (
world_mesh is not None and parallel_dims is not None
), f"failed to launch distributed using {dist_config}"

return world_mesh, parallel_dims

@@ -523,7 +529,7 @@ def _initialize_model(
if builder_args.setup_caches:
with torch.device(builder_args.device):
model.setup_caches(
max_batch_size=1, max_seq_length=max_seq_length or model.config.max_seq_length
max_batch_size=1, max_seq_length=max_seq_length or model.config.text_transformer_args.max_seq_length
)

model.to(dtype=builder_args.precision)
26 changes: 15 additions & 11 deletions build/gguf_loader.py
@@ -13,11 +13,12 @@

import torch

from build.gguf_util import Q4_0, to_float
from build.model import Model, ModelArgs, TransformerArgs

from gguf import GGUFValueType
from quantization.qops import LinearInt4 as WeightOnlyInt4Linear
from quantization.quantize import pack_scales_and_zeros
from build.gguf_util import Q4_0, to_float
from build.model import TransformerArgs, Transformer

logger: logging.Logger = logging.getLogger(__name__)

@@ -41,6 +42,7 @@ def _convert_gguf_tensor_name_to_llama_nn(gguf_name: str) -> str:
result = copy.deepcopy(gguf_name)
for gguf_string, replacement in _name_replacements:
result = result.replace(gguf_string, replacement)
result = "text_transformer." + result
return result
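The tensor-name conversion itself is unchanged apart from the final prefix, so every converted GGUF weight now targets the wrapped submodule. A small sketch (the input name is illustrative; the replacement table is elided in this hunk):

from build.gguf_loader import _convert_gguf_tensor_name_to_llama_nn

converted = _convert_gguf_tensor_name_to_llama_nn("output_norm.weight")
assert converted.startswith("text_transformer.")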


@@ -107,22 +109,24 @@ def load_model(gguf_file: str) -> torch.nn.Module:
arch = metadata["general.architecture"]
assert arch == "llama", "Only LLaMa models are supported by this converter."

model_args = TransformerArgs(
dim=metadata[f"{arch}.embedding_length"],
n_layers=metadata[f"{arch}.block_count"],
n_heads=metadata[f"{arch}.attention.head_count"],
n_local_heads=metadata[f"{arch}.attention.head_count_kv"],
vocab_size=len(metadata["tokenizer.ggml.tokens"]),
norm_eps=metadata[f"{arch}.attention.layer_norm_rms_epsilon"],
hidden_dim=metadata[f"{arch}.feed_forward_length"],
model_args = ModelArgs(
TransformerArgs(
dim=metadata[f"{arch}.embedding_length"],
n_layers=metadata[f"{arch}.block_count"],
n_heads=metadata[f"{arch}.attention.head_count"],
n_local_heads=metadata[f"{arch}.attention.head_count_kv"],
vocab_size=len(metadata["tokenizer.ggml.tokens"]),
norm_eps=metadata[f"{arch}.attention.layer_norm_rms_epsilon"],
hidden_dim=metadata[f"{arch}.feed_forward_length"],
)
)

# TODO: what to do with rope args like
# metadata.get(f"{arch}.rope.freq_base", None)
# metadata.get(f"{arch}.rope.dimension_count", None)

with torch.device("meta"):
model = Transformer(model_args)
model = Model(model_args)
return model


103 changes: 71 additions & 32 deletions build/model.py
@@ -59,21 +59,46 @@ def __post_init__(self):
self.use_tiktoken = self.use_tiktoken == "True"

@classmethod
def from_params(cls, params_path):
def from_params(cls, params):
replace = [("rope_theta", "rope_base"), ("n_kv_heads", "n_local_heads")]
with open(params_path, "r") as f:
params = json.loads(f.read())
# Patch for llama3
for _from, _to in replace:
if _from in params:
params[_to] = params.pop(_from)
for _from, _to in replace:
if _from in params:
params[_to] = params.pop(_from)
return cls(**params)
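TransformerArgs.from_params now receives the already-parsed dict (file I/O moved up into ModelArgs.from_params) while still normalizing llama-style key names. A sketch with illustrative values:

from build.model import TransformerArgs

params = {"dim": 64, "n_layers": 2, "n_heads": 4, "n_kv_heads": 2, "rope_theta": 500000.0}
args = TransformerArgs.from_params(params)  # renames n_kv_heads -> n_local_heads, rope_theta -> rope_base
assert args.n_local_heads == 2 and args.rope_base == 500000.0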

@dataclass
class ModelArgs:
text_transformer_args: TransformerArgs

def __post_init__(self):
assert self.text_transformer_args is not None
assert type(self.text_transformer_args) == TransformerArgs

@classmethod
def from_params(cls, params_path):
with open(params_path, "r") as f:
loaded_params = json.loads(f.read())

try:
# try to interpret as a single transformer config
text_transformer_args = TransformerArgs.from_params(
loaded_params
)
except TypeError:
# try to interpret as a dict of transformer configs
for name, params in loaded_params.items():
if name == "text":
text_transformer_args = TransformerArgs.from_params(params)
else:
raise ValueError(f"Unknown transformer name {name}")

return cls(text_transformer_args)
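The new loader accepts either a flat single-transformer JSON or a dict keyed by transformer name (currently only "text"). A hedged sketch of both accepted shapes, written through temporary files since from_params takes a path:

import json
import tempfile

from build.model import ModelArgs

flat = {"dim": 64, "n_layers": 2, "n_heads": 4}              # single transformer config
nested = {"text": {"dim": 64, "n_layers": 2, "n_heads": 4}}  # dict keyed by transformer name

for params in (flat, nested):
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump(params, f)
    args = ModelArgs.from_params(f.name)
    assert args.text_transformer_args.dim == 64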

@classmethod
def from_table(cls, name: str):
json_path = config_path / f"{name}.json"
if json_path.is_file():
return TransformerArgs.from_params(json_path)
return ModelArgs.from_params(json_path)
else:
known_model_params = [
config.replace(".json", "") for config in os.listdir(config_path)
@@ -86,7 +111,7 @@ def from_table(cls, name: str):
def from_name(cls, name: str):
json_path = config_path / f"{name}.json"
if Path(json_path).is_file():
return TransformerArgs.from_params(json_path)
return ModelArgs.from_params(json_path)

known_model_params = [
config.replace(".json", "") for config in os.listdir(config_path)
@@ -113,7 +138,7 @@ def from_name(cls, name: str):
f"Unknown model directory name {name}. Must be one of {known_model_params}."
)

return TransformerArgs.from_params(config_path / f"{config[0]}.json")
return ModelArgs.from_params(config_path / f"{config[0]}.json")


class KVCache(nn.Module):
@@ -144,6 +169,40 @@ def update(self, input_pos, k_val, v_val):
return k_out, v_out


class Model(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.config = config
self.text_transformer = Transformer(config.text_transformer_args)

def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
return self.text_transformer(idx, input_pos)

def setup_caches(self, max_batch_size, max_seq_length):
self.text_transformer.setup_caches(max_batch_size, max_seq_length)

@classmethod
def from_name(cls, name: str):
return cls(ModelArgs.from_name(name))

@classmethod
def from_table(cls, name: str):
return cls(ModelArgs.from_table(name))

@classmethod
def from_params(cls, params_path: str):
return cls(ModelArgs.from_params(params_path))

@classmethod
def from_gguf(cls, gguf_path: str, **kwargs):
from build.gguf_loader import load_model_and_state_dict

model, state_dict = load_model_and_state_dict(gguf_path, **kwargs)
if state_dict != {}:
model.load_state_dict(state_dict, assign=True)
return model
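Since Model.forward just delegates, existing prefill/decode loops only change the class they construct. A minimal sketch with illustrative hyperparameters (caches must be set up before the first call):

import torch

from build.model import Model, ModelArgs, TransformerArgs

model = Model(ModelArgs(TransformerArgs(dim=64, n_layers=2, n_heads=4, vocab_size=256)))
model.setup_caches(max_batch_size=1, max_seq_length=64)

prompt = torch.randint(0, 256, (1, 8))
logits = model(prompt, input_pos=torch.arange(8))        # prefill
next_tok = logits[:, -1].argmax(dim=-1, keepdim=True)
logits = model(next_tok, input_pos=torch.tensor([8]))    # one decode step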


class Transformer(nn.Module):
def __init__(self, config: TransformerArgs) -> None:
super().__init__()
@@ -180,7 +239,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
self.config.dim // self.config.n_heads,
self.config.block_size * 2,
self.config.rope_base,
use_scaled = self.config.use_scaled_rope,
use_scaled=self.config.use_scaled_rope,
)
self.register_buffer("freqs_cis", freqs_cis, persistent=True)
causal_mask = torch.tril(
@@ -201,27 +260,6 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
# print(f"logits shape: {logits.shape}")
return logits

@classmethod
def from_name(cls, name: str):
return cls(TransformerArgs.from_name(name))

@classmethod
def from_table(cls, name: str):
return cls(TransformerArgs.from_table(name))

@classmethod
def from_params(cls, params_path: str):
return cls(TransformerArgs.from_params(params_path))

@classmethod
def from_gguf(cls, gguf_path: str, **kwargs):
from build.gguf_loader import load_model_and_state_dict

model, state_dict = load_model_and_state_dict(gguf_path, **kwargs)
if state_dict != {}:
model.load_state_dict(state_dict, assign=True)
return model


class TransformerBlock(nn.Module):
def __init__(self, config: TransformerArgs) -> None:
@@ -388,6 +426,7 @@ def apply_scaling(freqs: torch.Tensor):
new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


def precompute_freqs_cis(
n_elem: int, seq_len: int, base: int = 10000, dtype=None, use_scaled: bool = False
) -> Tensor:
2 changes: 1 addition & 1 deletion distributed/parallelize_llama.py
@@ -59,7 +59,7 @@ def apply_tp(
# after we apply TP to the model. Because we don't want to change model code
# when applying TP. We need to have change to ensure KVCache has the correct
# size as k and v.
model.config.n_local_heads = model.config.n_local_heads // tp_mesh.size()
model.config.text_transformer_args.n_local_heads = model.config.text_transformer_args.n_local_heads // tp_mesh.size()

# Apply tensor parallelism to every transformer block
for transformer_block in model.layers:
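For orientation, the rewrite above keeps the KV-cache allocation consistent with the number of heads each rank owns after tensor-parallel sharding; a tiny sketch of the arithmetic with assumed numbers:

n_local_heads = 8   # KV heads in the full model (illustrative)
tp_degree = 4       # tensor-parallel mesh size (illustrative)
per_rank_kv_heads = n_local_heads // tp_degree  # setup_caches() then sizes KVCache for 2 heads per rank
assert per_rank_kv_heads == 2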
8 changes: 4 additions & 4 deletions docs/ADVANCED-USERS.md
@@ -112,21 +112,21 @@ architecture, provided you have the model weights in llama format, the
model parameters and the tokenizer model used by your language model.

Some common models are recognized by torchchat based on their filename
through `Transformer.from_name()` to perform a fuzzy match against a
through `Model.from_name()` to perform a fuzzy match against a
table of known model architectures. Alternatively, you can specify the
index into that table with the option `--params-table ${INDEX}` where
the index is the lookup key in [the list of known
configurations](https://github.com/pytorch/torchchat/tree/main/build/known_model_params).
For example, for the stories15M model, this would be expressed as
`--params-table stories15M`. (We use the model constructor
`Transformer.from_table()`)
`Model.from_table()`)

For models using a configuration not in the list of known
configurations, you can construct the model by initializing the
`TransformerArgs` dataclass that controls model construction from a
parameter json using the `params-path ${PARAMS_PATH}` containing the
appropriate model parameters to initialize the `TransformerArgs` for the
model. (We use the model constructor `Transformer.from_params()`).
appropriate model parameters to initialize the `ModelArgs` for the
model. (We use the model constructor `Model.from_params()`).

The parameter file should be in JSON format specifying these
parameters. You can find the `TransformerArgs` data class in
(Diffs for the remaining four changed files are not shown here.)
