Skip to content

Conversation

kylesayrs
Copy link
Collaborator

@kylesayrs kylesayrs commented Aug 26, 2025

Purpose

  • Support fully-expressive attention and kv cache quantization
  • Support running kv cache quantization evals with hf transformers
10cf70de-d58b-4e78-9851-bab24e91d228

Prerequisites

Changes

New Classes

  • Add hookable attention and kvcache implementations which are registered to the attention module as submodules
    • QuantizedAttentionImpl injects itself into the model by registering a new attention implementation called ct_hooked_attention overriding model.config._attn_implementation to be the new implementation name
    • QuantizedKVCache injects itself into the model by overriding the past_key_values input kwarg to attention, and wrapping the functionality of the original cache
    • Calibration and transform hooks can be added to these modules via the hook functions
      • register_query_hook,
      • register_key_hook
      • register_value_hook

Quantization Lifecycle Changes

  • Apply
    • The kv_cache_scheme field of the quantization config is now used to call initialize_hooked_kv_cache
    • Attention modules can now be targeted, and are used to call initialize_hooked_attention if attention modules are explicitly targeted (see is_narrow_match)
    • Remove logic for "merging" kv cache schemes (this doesn't really make any sense, I'm not sure why it was ever included)
  • Initialize
    • Hooked kv cache and attention modules have their quantization parameters initialized by initialize_module_for_quantization
    • The presence of attention or kvcache submodules is what determines whether attention or kv cache only quantization is being applied
  • Serialization
    • QuantizationConfig.from_pretrained was cleaned up with additional comments
    • The kv_cache_scheme field is added if there are any attention modules with a quantization_scheme attached

Helpers

  • is_narrow_match is used to check that attention modules are being specifically targeted (rather than targeting all modules in a layer)
  • get_num_attn_heads, get_num_kv_heads, get_head_dim get attention config values from config

Testing

  • Added tests for is_narrow_match
  • Added tests for added attention and kvcache classes
  • Quantized models
    • kylesayrs/Llama-3.2-1B-Instruct-attention-fp8-head
    • kylesayrs/Llama-3.2-1B-Instruct-attention-nvfp4-head

Evaluation

eval.py
import sys
import lm_eval

model_id = sys.argv[1]

print(model_id)
results = lm_eval.simple_evaluate(
    # 3) hf serialized
    model="hf",
    model_args={
        "pretrained": model_id,
        "add_bos_token": False,
        "dtype": "auto",
        "device_map": "cuda",
        #"max_length": 128000,
    },
    device="cuda",
    # 3/)

    #tasks=["gsm8k_platinum", "mmlu_llama", "longbench2_single"],
    tasks=["gsm8k_platinum"],
    batch_size=64,
    apply_chat_template=True,
    fewshot_as_multiturn=True,
)
print(model_id)
print(lm_eval.utils.make_table(results))
compress.py
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.utils import dispatch_for_generation
from compressed_tensors.quantization import QuantizationScheme, QuantizationArgs

# Select model and load it.
#model_id = "Qwen/Qwen2.5-14B-Instruct-1M"
model_id = "meta-llama/Llama-3.1-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Select calibration dataset.
DATASET_ID = "ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Configure the quantization algorithm to run.
args = QuantizationArgs(
    num_bits=8,
    type="float",
    strategy="attn_head",
    symmetric=True,
    observer="static_minmax",
)
recipe = QuantizationModifier(
    # config_groups={
    #     "attention": QuantizationScheme(
    #         #targets=["Qwen2Attention"],
    #         targets=["LlamaAttention"],
    #         input_activations=args,
    #     )
    # }
    kv_cache_scheme=args,
)

# Apply algorithms.
oneshot(
    model=model,
    dataset=DATASET_ID,
    splits={"calibration": f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]"},
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + f"-KV-FP8-{args.strategy}-{args.observer}"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
Model GSM8K
Llama-3.1-8B-Instruct 0.8337
Llama-3.1-8B-Instruct-KV-FP8-Tensor 0.8271
Llama-3.1-8B-Instruct-KV-FP8-Head 0.8354
Llama-3.1-8B-Instruct-QKV-FP8-Tensor 0.8321
Llama-3.1-8B-Instruct-QKV-FP8-Head 0.8238

Copy link
Collaborator

@brian-dellabetta brian-dellabetta left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good, though i have a number of questions and minor suggestions

Copy link
Collaborator

@dsikka dsikka left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the goal is to use this generally for kv_cache and attn quantize, can we move the initialize_hooked_attention and initialize_hooked_kv_cache to initialize.py?

I understand we haven't hooked them in yet for those workflows but I think these belong there.

dsikka
dsikka previously approved these changes Sep 2, 2025
Copy link
Collaborator

@dsikka dsikka left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do a pass through on any missing docstring, otherwise lgtm.
nice work

Base automatically changed from kylesayrs/transform-simplify-key to main September 8, 2025 18:46
@dsikka dsikka dismissed stale reviews from brian-dellabetta and themself September 8, 2025 18:46

The base branch was changed.

@kylesayrs kylesayrs force-pushed the kylesayrs/r3-only branch 2 times, most recently from e224a5d to 05ec17e Compare October 8, 2025 19:20
@kylesayrs kylesayrs changed the base branch from main to kylesayrs/add-attn-head-strat October 8, 2025 19:20
Copy link
Collaborator

@brian-dellabetta brian-dellabetta left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Following for the most part. A few clarifications, but this makes sense to me

@kylesayrs kylesayrs marked this pull request as draft October 8, 2025 21:06
@kylesayrs kylesayrs force-pushed the kylesayrs/add-attn-head-strat branch from d084c5e to e3f24d4 Compare October 9, 2025 14:19
@kylesayrs kylesayrs changed the base branch from kylesayrs/add-attn-head-strat to main October 9, 2025 18:14
@kylesayrs kylesayrs dismissed brian-dellabetta’s stale review October 9, 2025 18:14

The base branch was changed.

@kylesayrs kylesayrs changed the base branch from main to kylesayrs/add-attn-head-strat October 9, 2025 18:15
Base automatically changed from kylesayrs/add-attn-head-strat to main October 9, 2025 20:11
@kylesayrs
Copy link
Collaborator Author

@kylesayrs kylesayrs marked this pull request as ready for review October 13, 2025 20:41
@kylesayrs
Copy link
Collaborator Author

Last nightly worked, but e2e failed due to model storage issues
https://github.com/neuralmagic/llm-compressor-testing/actions/runs/18483826999

@kylesayrs kylesayrs force-pushed the kylesayrs/r3-only branch 2 times, most recently from 4cc5ace to 9ead292 Compare October 14, 2025 04:21
Copy link
Collaborator

@brian-dellabetta brian-dellabetta left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can resolve the global var thread, I have another new comment we might want to consider in a follow-up but marking this as approved. Cool stuff! Excited to see it in action

Copy link
Collaborator

@dsikka dsikka left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just some questions. Otherwise, LGTM

if scheme.weights is not None:
raise ValueError(
"Cannot apply weight quantization to attention. "
"Instead, target (q|k|v)_proj"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This error doesnt make a lot of sense / took me a while to realize you're saying that if you want to do weight quantization, you should target the linear layers in the attn block, not attention itself.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this clearer?

raise ValueError(
  "Cannot apply weight quantization to attention. "
  "Instead, target the (q|k|v)_proj submodule layers of attention"

"""
if not hasattr(module, KV_CACHE_ATTR):
module.register_module(KV_CACHE_ATTR, QuantizedKVCache(model.config, module))
module.register_forward_pre_hook(_kv_cache_attention_hook, with_kwargs=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm reading this correctly, _kv_cache_attention_hook is called before every forward pass? So we're replacing the kv_cache before every forward pass with the new quantized cache?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's exactly correct. I've buffed up the docstrings to make this clearer.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QuantizedKVCache injects itself into the model by overriding the past_key_values input kwarg to attention, and wrapping the functionality of the original cache

# ----- hooks ----- #


def register_key_hook(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't seem to find where the key / value hooks get registered

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These hooks are used to attach observer hooks (and any other hooks we might want to add in the future), see here

# infer format
if format is None:
if quantization_status == QuantizationStatus.COMPRESSED:
if model_status == QuantizationStatus.COMPRESSED:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this is unrelated but defaulting to int doesnt make a lot of sense either

Copy link
Collaborator Author

@kylesayrs kylesayrs Oct 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. This was the original behavior of this logic.

quantization_status = None
ignore = {}
quantization_type_names = set()
from compressed_tensors.quantization.lifecycle.initialize import (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for cleaning this up. It doesn't seem like we're adding anything here, apart from how we're fetching the kv_cache scheme?

I still find our ignore logic very confusing

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I entirely agree, I've created an issue to track potential removal #494.

This PR does not change behavior, only makes the existing logic easier to read and adds this line to infer kv cache scheme

# attention quantization implies kv cache quantization
if is_attention_module(submodule):
    kv_cache_scheme = submodule.quantization_scheme.input_activations

Copy link
Collaborator

@dsikka dsikka left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the sake of completeness, do you mind adding your kv_cache and attn quantized sample models to this PR description?

)
else:
ret = (key_states, value_states)
self.past_key_values = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we set this to None?

Copy link
Collaborator Author

@kylesayrs kylesayrs Oct 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ensures that the cache is only used once. This should theoretically never be a problem, since the self.past_key_values attribute is always written to by the _kv_cache_attention_hook, but this is done just for peace of mind and to avoid dangling references, even if they are weak.

Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>
@kylesayrs
Copy link
Collaborator Author

Copy link
Collaborator

@brian-dellabetta brian-dellabetta left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

impressive work!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants