
Commit e417818

new gemlite integration using pip install
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent a8fd19c commit e417818

9 files changed: 8 additions, 688 deletions


torchao/_models/llama/benchmark_results.txt

Lines changed: 2 additions & 1 deletion

@@ -69,4 +69,5 @@ bs4
 20241008155928, tok/s= 49.45, mem/s= 214.18 GB/s, peak_mem= 7.81 GB, model_size= 4.33 GB quant: gemlite-4-64, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --quantization gemlite-4-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float16 --compile --num_samples 1 --max_new_tokens 200 --batch_size 4 --top_k 200 --temperature 0.8
 20241008160515, tok/s= 51.74, mem/s= 224.09 GB/s, peak_mem= 7.79 GB, model_size= 4.33 GB quant: gemlite-4-64, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization gemlite-4-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 1 --max_new_tokens 200 --batch_size 4 --top_k 200 --temperature 0.8
 
-20241029013738, tok/s= 12.81, mem/s= 1.40 GB/s, peak_mem=14.55 GB, model_size= 0.11 GB quant: gemlite-4-128, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --quantization gemlite-4-128 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float16 --num_samples 1 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20241029013738, tok/s= 12.81, mem/s= 1.40 GB/s, peak_mem=14.55 GB, model_size= 0.11 GB quant: gemlite-4-128, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --quantization gemlite-4-128 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float16 --num_samples 1 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20241029015254, tok/s= 12.73, mem/s= 1.39 GB/s, peak_mem=14.55 GB, model_size= 0.11 GB quant: gemlite-4-128, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --quantization gemlite-4-128 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float16 --num_samples 1 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8

torchao/_models/llama/benchmarks.sh

Lines changed: 1 addition & 1 deletion

@@ -2,7 +2,7 @@ export CHECKPOINT_PATH=../../../checkpoints # path to checkpoints folder
 
 # README BENCHMARKS
 export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemlite-4-128 --num_samples 1 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemlite-4-128 --num_samples 1 --write_result benchmark_results.txt --batch_size 16
 
 
 # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt

torchao/_models/llama/generate.py

Lines changed: 4 additions & 2 deletions

@@ -240,6 +240,9 @@ def main(
     assert group_size in [64, 128, 256], f"group_size needs to be in [64, 128, 256], got {group_size} for gemlite-<W_nbits>-<group_size>"
     assert precision == torch.float16, f"gemlite only supports float16 precision, got {precision}"
 
+    quant_config = BaseQuantizeConfig(nbits=W_nbits, group_size=group_size, quant_zero=False, quant_scale=False, axis=1)
+    quant_config['weight_quant_params']['optimize'] = False
+
     def replace_fn(mod):
         if not isinstance(mod, torch.nn.Linear):
             return mod
@@ -250,8 +253,7 @@ def replace_fn(mod):
         compute_dtype = mod.weight.dtype
         input_dtype, output_dtype = DType.FP16, DType.FP16
 
-        quant_config = BaseQuantizeConfig(nbits=W_nbits, group_size=group_size, quant_zero=False, quant_scale=False, axis=1)
-        quant_config['weight_quant_params']['optimize'] = False
+
         hqq_layer = HQQLinear(mod, quant_config=quant_config, compute_dtype=compute_dtype, device=device, del_orig=False)
         orig_shape = (out_features, in_features)
         gemlite_linear = GemLiteLinearTriton(W_nbits=W_nbits,
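
For readers following the generate.py hunk above, a minimal sketch of the hoisted configuration: quant_config is now built once and captured by replace_fn instead of being rebuilt for every nn.Linear. The replace_fn body is abbreviated; the actual code in generate.py goes on to repack the HQQ-quantized weights into a GemLiteLinearTriton layer, and the values of W_nbits, group_size and device here are illustrative assumptions, not taken from the commit.

import torch
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear

W_nbits, group_size, device = 4, 64, "cuda"  # illustrative settings

# Built once, outside replace_fn (the change in this commit), so every
# replaced nn.Linear shares the same HQQ quantization settings.
quant_config = BaseQuantizeConfig(nbits=W_nbits, group_size=group_size,
                                  quant_zero=False, quant_scale=False, axis=1)
quant_config['weight_quant_params']['optimize'] = False

def replace_fn(mod):
    if not isinstance(mod, torch.nn.Linear):
        return mod
    # HQQ quantizes the float16 weights; generate.py then transfers the packed
    # weights, scales and zeros into a GemLiteLinearTriton kernel (omitted here).
    hqq_layer = HQQLinear(mod, quant_config=quant_config,
                          compute_dtype=mod.weight.dtype,
                          device=device, del_orig=False)
    return hqq_layer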

torchao/_models/llama/model.py

Lines changed: 1 addition & 1 deletion

@@ -19,7 +19,7 @@ def prepare_inputs_for_model(inps, max_new_tokens=1):
     if inps.dim() > 2:
         raise ValueError(f"Expected input to be of dim 1 or 2, but got {inps.dim()}")
 
-    input_pos = torch.arange(0, inps.size(-1), device=inps.device)
+    input_pos = torch.arange(0, inps.numel(), device=inps.device)
     return (inps.view(1, -1), input_pos)
 
 @dataclass
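
This change matters once generate.py is benchmarked with --batch_size > 1: prepare_inputs_for_model flattens the prompt to shape (1, batch*seq) via view(1, -1), so input_pos must cover every element rather than only the last dimension. A small self-contained check (the shapes are illustrative, not taken from the benchmark):

import torch

inps = torch.zeros(16, 8, dtype=torch.int64)   # e.g. batch_size=16, seq_len=8

old_pos = torch.arange(0, inps.size(-1))       # 8 positions: too few for the flattened view
new_pos = torch.arange(0, inps.numel())        # 128 positions: one per flattened token

flat = inps.view(1, -1)                        # shape (1, 128), as returned by prepare_inputs_for_model
assert flat.size(-1) == new_pos.numel()        # holds with the new code
assert flat.size(-1) != old_pos.numel()        # the old code would mismatch for batched input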

torchao/quantization/prototype/gemlite/__init__.py

Whitespace-only changes.

torchao/quantization/prototype/gemlite/core.py

Lines changed: 0 additions & 238 deletions
This file was deleted.

torchao/quantization/prototype/gemlite/triton_kernels/__init__.py

Lines changed: 0 additions & 4 deletions
This file was deleted.
