-
Notifications
You must be signed in to change notification settings - Fork 266
[Quantization] Attention/ KV Cache Refactor #1651
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
e014b64
squash
kylesayrs b98e61a
target attention for hooks, allow non-forced zps
kylesayrs 4a38e0d
fix bug where observers were called twice
kylesayrs 5004894
Merge branch 'main' into kylesayrs/transform-online
kylesayrs 13972d6
add todo
kylesayrs a985fa7
Merge branch 'main' into kylesayrs/transform-online
kylesayrs 3667bda
zoinks!
kylesayrs File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,87 @@ | ||
| from datasets import load_dataset | ||
| from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
|
||
| from llmcompressor import oneshot | ||
| from llmcompressor.modifiers.quantization import QuantizationModifier | ||
| from llmcompressor.utils import dispatch_for_generation | ||
| from compressed_tensors.quantization import QuantizationScheme, QuantizationArgs | ||
|
|
||
| # Select model and load it. | ||
| model_id = "meta-llama/Meta-Llama-3-8B-Instruct" | ||
| model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto") | ||
| tokenizer = AutoTokenizer.from_pretrained(model_id) | ||
|
|
||
| # Select calibration dataset. | ||
| DATASET_ID = "HuggingFaceH4/ultrachat_200k" | ||
| DATASET_SPLIT = "train_sft" | ||
|
|
||
| # Select number of samples. 512 samples is a good place to start. | ||
| # Increasing the number of samples can improve accuracy. | ||
| NUM_CALIBRATION_SAMPLES = 512 | ||
dsikka marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| MAX_SEQUENCE_LENGTH = 2048 | ||
|
|
||
| # Load dataset and preprocess. | ||
| ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") | ||
| ds = ds.shuffle(seed=42) | ||
|
|
||
|
|
||
| def preprocess(example): | ||
| return { | ||
| "text": tokenizer.apply_chat_template( | ||
| example["messages"], | ||
| tokenize=False, | ||
| ) | ||
| } | ||
|
|
||
|
|
||
| ds = ds.map(preprocess) | ||
|
|
||
|
|
||
| # Tokenize inputs. | ||
| def tokenize(sample): | ||
| return tokenizer( | ||
| sample["text"], | ||
| padding=False, | ||
| max_length=MAX_SEQUENCE_LENGTH, | ||
| truncation=True, | ||
| add_special_tokens=False, | ||
| ) | ||
|
|
||
|
|
||
| ds = ds.map(tokenize, remove_columns=ds.column_names) | ||
|
|
||
| # Configure the quantization algorithm to run. | ||
| recipe = QuantizationModifier( | ||
| config_groups={ | ||
| "attention": QuantizationScheme( | ||
| targets=["LlamaAttention"], | ||
| input_activations=QuantizationArgs( | ||
| num_bits=8, type="float", strategy="attn_head" | ||
| ), | ||
| ) | ||
| } | ||
| ) | ||
|
|
||
| # Apply algorithms. | ||
| oneshot( | ||
| model=model, | ||
| dataset=ds, | ||
| recipe=recipe, | ||
| max_seq_length=MAX_SEQUENCE_LENGTH, | ||
| num_calibration_samples=NUM_CALIBRATION_SAMPLES, | ||
| ) | ||
|
|
||
| # Confirm generations of the quantized model look sane. | ||
| print("\n\n") | ||
| print("========== SAMPLE GENERATION ==============") | ||
| dispatch_for_generation(model) | ||
| sample = tokenizer("Hello my name is", return_tensors="pt") | ||
| sample = {key: value.to(model.device) for key, value in sample.items()} | ||
| output = model.generate(**sample, max_new_tokens=100) | ||
| print(tokenizer.decode(output[0])) | ||
| print("==========================================\n\n") | ||
|
|
||
| # Save to disk compressed. | ||
| SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-attention-fp8-head" | ||
| model.save_pretrained(SAVE_DIR, save_compressed=True) | ||
| tokenizer.save_pretrained(SAVE_DIR) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,4 @@ | ||
| # ruff: noqa | ||
|
|
||
| from .cache import * | ||
| from .gptq import * | ||
| from .quantization import * |
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.