Skip to content

Commit

Permalink
Add support for save quantized checkpoint in llama code (#553)
Browse files Browse the repository at this point in the history
Summary:
The goal is to upload a torchao quantized model to huggingface so that we can run the model in huggingface

Test Plan:
python generate.py -q int4wo-32 --save

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 authored and jainapurva committed Aug 7, 2024
1 parent 665f574 commit beba1d5
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
17 changes: 14 additions & 3 deletions scripts/hf_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ def format_value(value):

print(tabulate(main_table, headers=['Task', 'Metrics'], tablefmt='grid'))

def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, batch_size, max_length):
def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, save, batch_size, max_length):

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision)

if compile:
if quantization == "autoquant" and compile:
model = torch.compile(model, mode="max-autotune", fullgraph=True)

if quantization == "int8dq":
Expand All @@ -57,6 +57,10 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi
quantize_(model.to(device=device), int4_weight_only())
elif quantization == "autoquant":
model = autoquant(model.to(device=device))

if quantization != "autoquant" and compile:
model = torch.compile(model, mode="max-autotune", fullgraph=True)

with torch.no_grad():
result = evaluate(
HFLM(
Expand All @@ -70,6 +74,12 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi

pretty_print_nested_results(result)

if save:
# This doesn't work yet: https://github.com/huggingface/transformers/issues/32364
# model.save_pretrained("quantized_model_test", safe_serialization=False)
file_name = repo_id.split("/")[-1] + "-" + quantization + ".pt"
torch.save(model.state_dict(), file_name)


if __name__ == '__main__':
import argparse
Expand All @@ -81,8 +91,9 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi
parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "None"], help='Which quantization technique to apply')
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
parser.add_argument('--save', action='store_true', help='Whether to save the model.')
parser.add_argument('--batch_size', type=int, default=1, help='Batch size to use for evaluation, note int8wo and int4wo work best with small batchsizes, int8dq works better with large batchsizes')
parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')

args = parser.parse_args()
run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.compile, args.batch_size, args.max_length)
run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.compile, args.save, args.batch_size, args.max_length)
10 changes: 9 additions & 1 deletion torchao/_models/llama/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import sys
import time
from pathlib import Path
Expand Down Expand Up @@ -165,6 +166,7 @@ def main(
checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
quantization: Optional[str] = None,
kv_cache_quantization: bool = False,
save: bool = False,
compile: bool = True,
compile_prefill: bool = False,
profile: Optional[Path] = None,
Expand Down Expand Up @@ -238,6 +240,11 @@ def main(

model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9

if save:
output_dir = str(checkpoint_path.cwd())
filename = str(checkpoint_path.name).split(".")[0]
torch.save(model.state_dict(), os.path.join(output_dir, filename + f"-{quantization}.pt"))

if compile:
print("Compiling Model")
global decode_one_token, prefill
Expand Down Expand Up @@ -362,6 +369,7 @@ def callback(x):
parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant')
parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
parser.add_argument('--save', action='store_true', help='Whether to save the quantized model.')
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
Expand All @@ -372,5 +380,5 @@ def callback(x):
args = parser.parse_args()
main(
args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result
args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.save, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result
)

0 comments on commit beba1d5

Please sign in to comment.