Gguf (pytorch#165)
* new gguf parsing for Q40 that conforms with pytorch's quantization stack

* updates

* add q6_k and clean up q40

* cleanup

* fix bugs in q6k. Refactor code

* clean up gguf stuff + add test

* fix gguf CI test

* change to mac

* add pre torch
metascroy authored and malfet committed Jul 17, 2024
1 parent 4f97c9b commit 2d9b722
Showing 5 changed files with 322 additions and 230 deletions.
49 changes: 49 additions & 0 deletions .github/workflows/gguf_load.yml
@@ -0,0 +1,49 @@
name: Compile main

on:
  push:
    branches:
      - main
  pull_request:
  workflow_dispatch:

jobs:
  gguf-load-test:
    strategy:
      matrix:
        runner: [macos-14]
    runs-on: ${{matrix.runner}}
    steps:
      - name: Checkout repo
        uses: actions/checkout@v2
      - name: Setup Python
        uses: actions/setup-python@v2
        with:
          python-version: 3.11
      - name: Print machine info
        run: |
          uname -a
          if [ $(uname -s) == Darwin ]; then
            sysctl machdep.cpu.brand_string
            sysctl machdep.cpu.core_count
          fi
      - name: Install requirements
        run: |
          echo "Installing pip packages"
          pip install gguf
          pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
          pip install -r requirements.txt
      - name: Download GGUF files
        run: |
          mkdir gguf_files
          wget -O gguf_files/llama-2-7b.Q4_0.gguf "https://huggingface.co/TheBloke/Llama-2-7B-GGUF/resolve/main/llama-2-7b.Q4_0.gguf?download=true"
      - name: Load files
        run: |
          touch test.py
          echo "from gguf_loader import load_llama_from_gguf_file" >> test.py
          echo "pt = load_llama_from_gguf_file(\"gguf_files/llama-2-7b.Q4_0.gguf\")" >> test.py
          cat test.py
          python test.py
          echo "Tests complete."
140 changes: 60 additions & 80 deletions gguf_util/loader.py → gguf_loader.py
@@ -13,7 +13,8 @@
from pathlib import Path
from typing import Any, Mapping, Dict
import logging

from quantize import WeightOnlyInt4Linear, pack_scales_and_zeros, group_dequantize_tensor_from_qparams
from gguf_util import F16, F32, Q4_0, Q6_K
import gguf

import torch
@@ -118,7 +119,6 @@ def _build_model_args(metadata: dict[str, Any]) -> GGUFModelArgs:
),
)


def _fqn_lookup(fqn: str, module: torch.nn.Module) -> Any:
    if fqn == "":
        return module
@@ -147,80 +147,60 @@ def _fqn_last(fqn: str) -> str:
    return atoms[-1]


# def _load_by_state_dict(pt_model: torch.nn.Module, state_dict: Dict[str, Any], fqn: str, gguf_tensor: ReaderTensor) -> bool:
#     if gguf_tensor.tensor_type in (gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16):
#         reversed_shape = gguf_tensor.shape[::-1]
#         new_tensor = gguf_tensor.data.reshape(reversed_shape)
#         state_dict[fqn] = torch.from_numpy(new_tensor)
#         return True
#     elif gguf_tensor.tensor_type == gguf.GGMLQuantizationType.Q4_0 and gguf_tensor.name == "token_embd.weight":
#         unpacked = Q4_0.to_float(torch.from_numpy(gguf_tensor.data.reshape(-1, 18)))
#         state_dict[fqn] = unpacked.reshape(
#             pt_model.config.vocab_size, pt_model.config.dim
#         )
#         return True
#     return False


# def _load_by_parameter(pt_model: torch.nn.Module, fqn: str, gguf_tensor: ReaderTensor) -> bool:
#     assert isinstance(_fqn_lookup(fqn, pt_model), torch.nn.Parameter)
#     parent: torch.nn.Module = _fqn_lookup(_fqn_up(fqn), pt_model)

#     if gguf_tensor.tensor_type == gguf.GGMLQuantizationType.Q4_0 and isinstance(parent, torch.nn.Linear) and _fqn_last(fqn) == "weight":
#         packed = torch.from_numpy(gguf_tensor.data).reshape(-1, 18)
#         scale = Q4_0._unpack_two_uint8(packed[:, :2]).to(dtype=torch.float16)
#         parent.weight = torch.nn.Parameter(
#             Q4_0.GGMLInt4LinearWeight(packed, scale, parent.weight.shape)
#         )
#         pt_model = pt_model.to(dtype=torch.float32)
#         return True

#     return False


# def _load_weights(pt_model: torch.nn.Module, weight_map: Dict[str, ReaderTensor]) -> None:
#     loaded_by_state_dict: Set[str] = set()
#     loaded_by_parameter: Set[str] = set()

#     # state_dict pass
#     logger.info("Loading weights by state_dict.")
#     state_dict = {}
#     for fqn in pt_model.state_dict():
#         if fqn not in weight_map:
#             continue
#         tensor = weight_map[fqn]
#         loaded = _load_by_state_dict(pt_model, state_dict, fqn, tensor)
#         if loaded:
#             loaded_by_state_dict.add(fqn)

#     # allow partial loading
#     pt_model.load_state_dict(state_dict, strict=False)

#     # parameter pass
#     logger.info("Loading weights by parameter.")
#     for fqn, param in pt_model.named_parameters():
#         if fqn not in weight_map:
#             continue
#         tensor = weight_map[fqn]
#         loaded = _load_by_parameter(pt_model, fqn, tensor)
#         if loaded:
#             loaded_by_parameter.add(fqn)

#     # Sanity checks
#     for fqn in loaded_by_state_dict:
#         if not(fqn not in loaded_by_parameter):
#             msg = f"{fqn} was loaded by both state_dict and parameter"
#             raise Exception(msg)

#     for fqn in weight_map:
#         if not (fqn in (loaded_by_state_dict | loaded_by_parameter)):
#             msg = f"{fqn} in weight_map was not loaded"
#             raise Exception(msg)

#     for fqn in pt_model.state_dict():
#         if not (fqn in (loaded_by_state_dict | loaded_by_parameter)):
#             msg = f"{fqn} in model.state_dict() was not loaded"
#             raise Exception(msg)
def load_weights(pt_model: torch.nn.Module, weight_map: Dict[str, ReaderTensor], inner_k_tiles = 8) -> torch.nn.Module:
    fqns = []
    for fqn in pt_model.state_dict():
        assert _fqn_last(fqn) == "weight"
        fqns.append(_fqn_up(fqn))

    state_dict = {}
    for fqn in fqns:
        mod = _fqn_lookup(fqn, pt_model)

        t = weight_map[f"{fqn}.weight"]

        if isinstance(mod, torch.nn.Linear) and t.tensor_type == gguf.GGMLQuantizationType.Q4_0:
            assert not mod.bias
            out_features = mod.out_features
            in_features = mod.in_features
            assert all(t.shape == (in_features, out_features))

            q, s, z = Q4_0.unpack(t)
            scales_and_zeros = pack_scales_and_zeros(s, z)
            weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)

            state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu')
            state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu')

            parent = _fqn_lookup(_fqn_up(fqn), pt_model)
            setattr(
                parent,
                _fqn_last(fqn),
                WeightOnlyInt4Linear(
                    in_features, out_features,
                    bias=False,
                    groupsize=Q4_0.group_size,
                    inner_k_tiles=inner_k_tiles,
                    use_cuda=False
                )
            )
        else:
            # All other weights are dequantized to float
            if t.tensor_type == gguf.GGMLQuantizationType.Q4_0:
                as_float = group_dequantize_tensor_from_qparams(*Q4_0.unpack(t), Q4_0.n_bit, Q4_0.group_size)
            elif t.tensor_type == gguf.GGMLQuantizationType.Q6_K:
                as_float = group_dequantize_tensor_from_qparams(*Q6_K.unpack(t), Q6_K.n_bit, Q6_K.group_size)
            elif t.tensor_type == gguf.GGMLQuantizationType.F16:
                as_float = F16.unpack(t)
            elif t.tensor_type == gguf.GGMLQuantizationType.F32:
                as_float = F32.unpack(t)
            else:
                raise ValueError(f"Unsupported tensor type {t.tensor_type}")

            state_dict[f"{fqn}.weight"] = as_float.to('cpu')

    pt_model.load_state_dict(state_dict)
    return pt_model


def _get_metadata(reader: gguf.GGUFReader) -> dict[str, Any]:
@@ -245,7 +225,7 @@ def _get_metadata(reader: gguf.GGUFReader) -> dict[str, Any]:

    return metadata

# TODO: finish weight loading

def load_llama_from_gguf_file(gguf_file: str) -> torch.nn.Module:
"""
Load a LLaMa model from a GGUF file and return a PT nn.Module.
@@ -274,7 +254,7 @@ def load_llama_from_gguf_file(gguf_file: str) -> torch.nn.Module:
        _convert_gguf_tensor_name_to_llama_nn(tensor.name): tensor
        for tensor in gguf_weights.tensors
    }
    # logger.info("Loading GGUF weights into PT model.")
    # _load_weights(pt_model, weight_map)

    return pt_model, weight_map
    logger.info("Loading weights into state_dict")
    pt_model = load_weights(pt_model, weight_map, inner_k_tiles=8)
    return pt_model
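
For context on the Q4_0 path in load_weights: GGUF's Q4_0 type stores weights in 18-byte blocks, a 2-byte fp16 scale followed by 32 packed 4-bit values, which is what the reshape(-1, 18) in the removed code reflects. Below is a minimal NumPy sketch of dequantizing one such block under the standard ggml layout assumption; the repository's Q4_0.unpack and group_dequantize_tensor_from_qparams remain the authoritative path and may arrange groups differently to conform with pytorch's quantization stack:

import numpy as np

# Hedged sketch of ggml-style Q4_0 dequantization for a single 18-byte block.
# Assumed layout: bytes 0-1 hold an fp16 scale d, bytes 2-17 pack 32 4-bit
# values, with low nibbles giving the first 16 elements and high nibbles the last 16.
def dequantize_q4_0_block(block: np.ndarray) -> np.ndarray:
    assert block.dtype == np.uint8 and block.shape == (18,)
    d = np.frombuffer(block[:2].tobytes(), dtype=np.float16)[0]
    qs = block[2:]
    lo = (qs & 0x0F).astype(np.int8) - 8   # first 16 quantized values
    hi = (qs >> 4).astype(np.int8) - 8     # last 16 quantized values
    return np.concatenate([lo, hi]).astype(np.float32) * np.float32(d)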