Skip to content

Conversation

@kylesayrs
Copy link
Collaborator

@kylesayrs kylesayrs commented Jul 16, 2025

Purpose

recipe = QuantizationModifier(
    config_groups={
        "attention": QuantizationScheme(
            targets=["LlamaAttention"],
            input_activations=QuantizationArgs(
                num_bits=8, type="float", strategy="tensor"
            ),
        )
    }
)
{
  "quantization_config": {
    "config_groups": {
      "group_0": {
        "format": null,
        "input_activations": {
          "dynamic": false,
          "num_bits": 8,
          "observer": "minmax",
          "strategy": "tensor",
          "symmetric": true,
          "type": "float"
        },
        "output_activations": null,
        "targets": [
          "LlamaAttention"
        ],
        "weights": null
      }
    },
    "format": "dense",
    "ignore": [],
    "kv_cache_scheme": {
      "dynamic": false,
      "group_size": null,
      "num_bits": 8,
      "observer": "minmax",
      "strategy": "tensor",
      "symmetric": true,
      "type": "float"
    },
    "quant_method": "compressed-tensors",
    "quantization_status": "frozen",
  },
}

Prerequisites

Changes

Testing

  • Kv cache regression tests pass
  • Able to quantize attention with scripts (will add to examples once loadable in vllm)
    • kylesayrs/Llama-3.2-1B-Instruct-attention-fp8-head
    • kylesayrs/Llama-3.2-1B-Instruct-attention-nvfp4-head
  • Nightly passes (in progress)

Evaluation

eval.py
import sys
import lm_eval

model_id = sys.argv[1]

print(model_id)
results = lm_eval.simple_evaluate(
    # 3) hf serialized
    model="hf",
    model_args={
        "pretrained": model_id,
        "add_bos_token": False,
        "dtype": "auto",
        "device_map": "cuda",
        #"max_length": 128000,
    },
    device="cuda",
    # 3/)

    #tasks=["gsm8k_platinum", "mmlu_llama", "longbench2_single"],
    tasks=["gsm8k_platinum"],
    batch_size=64,
    apply_chat_template=True,
    fewshot_as_multiturn=True,
)
print(model_id)
print(lm_eval.utils.make_table(results))
compress.py
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.utils import dispatch_for_generation
from compressed_tensors.quantization import QuantizationScheme, QuantizationArgs

# Select model and load it.
#model_id = "Qwen/Qwen2.5-14B-Instruct-1M"
model_id = "meta-llama/Llama-3.1-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Select calibration dataset.
DATASET_ID = "ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Configure the quantization algorithm to run.
args = QuantizationArgs(
    num_bits=8,
    type="float",
    strategy="attn_head",
    symmetric=True,
    observer="static_minmax",
)
recipe = QuantizationModifier(
    # config_groups={
    #     "attention": QuantizationScheme(
    #         #targets=["Qwen2Attention"],
    #         targets=["LlamaAttention"],
    #         input_activations=args,
    #     )
    # }
    kv_cache_scheme=args,
)

# Apply algorithms.
oneshot(
    model=model,
    dataset=DATASET_ID,
    splits={"calibration": f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]"},
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + f"-KV-FP8-{args.strategy}-{args.observer}"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
Model GSM8K
nm-testing/Llama-3.1-8B-Instruct 0.8337
nm-testing/Llama-3.1-8B-Instruct-KV-FP8-Tensor 0.8271
nm-testing/Llama-3.1-8B-Instruct-KV-FP8-Head 0.8354
nm-testing/Llama-3.1-8B-Instruct-QKV-FP8-Tensor 0.8321
nm-testing/Llama-3.1-8B-Instruct-QKV-FP8-Head 0.8238

@github-actions
Copy link

👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

Note: This is required to complete the testing suite, please only add the label once the PR is code complete and local testing has been performed.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @kylesayrs, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates the foundational components for applying 'online rotations' (specifically R1 and R2 from the SpinQuant paper) into the llmcompressor framework. It primarily introduces a new SpinQuantModifier that leverages novel model transformation utilities, such as embedding normalization and norm-linear fusion, to prepare models for more effective quantization. Additionally, it refines the handling of tied word embeddings, ensuring compatibility and robustness across various model configurations.

Highlights

  • New Feature: SpinQuantModifier: Introduced a new SpinQuantModifier to apply 'offline' rotations (R1 and R2) from the SpinQuant paper. These rotations transform model weights and activations to improve quantization accuracy without introducing runtime overhead.
  • Model Transformation Utilities: Added new utilities for normalizing embedding layers and fusing norm layers into subsequent linear layers. These are crucial preprocessing steps for applying SpinQuant rotations, ensuring transform invariance.
  • Improved Tied Word Embedding Handling: Refactored and enhanced the utility for untying word embeddings. The updated implementation is more robust, correctly handling cases where embeddings are tied, especially with offloaded parameters, and centralizes the untying logic.
  • Example Usage and Integration: Provided new example scripts (compress_model.py, spinquant_example.py) demonstrating how to use the SpinQuantModifier for model compression. The modifier is also integrated into the data-free pipeline for seamless application.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for SpinQuant online rotations, a technique for improving quantization performance. It adds a new SpinQuantModifier, along with utilities for model transformation like layer fusion and embedding normalization. The changes also include updates to the data-free pipeline, improvements to handling tied word embeddings, and new example scripts and tests.

My review identified a critical bug in a Pydantic validator within the new SpinQuantModifier that prevents it from being used. I've also pointed out a few medium-severity issues, including a required argument missing in a script, brittle directory name construction, a documentation typo, and a maintainability concern with a hardcoded pipeline selection. Addressing these points will improve the correctness and robustness of the new features.

@kylesayrs kylesayrs changed the base branch from main to bdellabe/transform-modifier July 16, 2025 20:39
Base automatically changed from bdellabe/transform-modifier to main August 13, 2025 15:03
@kylesayrs kylesayrs force-pushed the kylesayrs/transform-online branch from a9b2f51 to 49e1d90 Compare August 20, 2025 01:47
@kylesayrs
Copy link
Collaborator Author

@kylesayrs kylesayrs changed the title [Transform] Online Rotations [Quantization] Attention/ KV Cache Refactor Sep 12, 2025
@kylesayrs kylesayrs force-pushed the kylesayrs/transform-online branch 2 times, most recently from 3a5a04a to b85337f Compare October 9, 2025 16:39
@kylesayrs kylesayrs changed the base branch from main to kylesayrs/observers-refactor October 9, 2025 16:39
@kylesayrs kylesayrs force-pushed the kylesayrs/transform-online branch 2 times, most recently from 9dda155 to 0b6624f Compare October 13, 2025 20:10
@kylesayrs
Copy link
Collaborator Author

@kylesayrs kylesayrs marked this pull request as ready for review October 13, 2025 20:57
@kylesayrs
Copy link
Collaborator Author

Last nightly worked, but e2e failed due to model storage issues
https://github.com/neuralmagic/llm-compressor-testing/actions/runs/18483826999

HDCharles
HDCharles previously approved these changes Oct 14, 2025
Copy link
Collaborator

@HDCharles HDCharles left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks a lot cleaner

rahul-tuli
rahul-tuli previously approved these changes Oct 14, 2025
Copy link
Collaborator

@rahul-tuli rahul-tuli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

Base automatically changed from kylesayrs/observers-refactor to main October 14, 2025 19:42
@kylesayrs kylesayrs dismissed stale reviews from rahul-tuli and HDCharles October 14, 2025 19:42

The base branch was changed.

@kylesayrs kylesayrs force-pushed the kylesayrs/transform-online branch from f1b8e5a to 57bee27 Compare October 14, 2025 22:01
HDCharles
HDCharles previously approved these changes Oct 15, 2025
@kylesayrs
Copy link
Collaborator Author

HDCharles
HDCharles previously approved these changes Oct 20, 2025
HDCharles
HDCharles previously approved these changes Oct 21, 2025
@kylesayrs kylesayrs added the ready When a PR is ready for review label Oct 23, 2025
if not hasattr(module, "quantization_scheme"):
continue
if not hasattr(module, "quantization_scheme"):
hooks
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing a return here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah! Not sure how that sneaked through

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your comment, this is addressed now.


# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we know how much data is actually needed for get decent results?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

512 is enough for the basic GSM8K evals

update_offload_parameter(module, f"{base_name}_scale", scale)
update_offload_parameter(module, f"{base_name}_zero_point", zero_point)
if hasattr(module, f"{base_name}_zero_point"):
update_offload_parameter(module, f"{base_name}_zero_point", zero_point)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we usually dont run asym quant - why wasn't this a problem before?

Copy link
Collaborator Author

@kylesayrs kylesayrs Oct 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is added because KV cache quantization is weird: it's the only scheme which does not have a compressor. For that reason, it's the only scheme where we cannot force zero points (vllm throws an error if zero points are present)

What should happen

  1. KV cache quant is initialized with forced zero points
  2. Calibration happens and zero points are updated (and stay zero if symmetric)
  3. Model does not have compressed weights and is saved in frozen state
  4. vLLM loads and throws away zero points if symmetric (only symmetric is implemented atm)

(alternatively, in step (3) we write an "attention" compressor which throws away zero points)

What this refactor does to avoid this

  1. KV cache quant is initialized without forced zero points
  2. Calibration happens and zero points are updated, only if they're present (they're not unless asymmetric)
  3. Model does not have compressed weights and is saved in frozen state (without zero points)
  4. vLLM loads scales only (since only symmetric kv cache quantization is supported atm)


quantized_kv_cache = QuantizedKVParameterCache(scheme.output_activations)
setattr(module, "kv_cache", quantized_kv_cache)
def calibrate_value_hook(module: Module, value_states: torch.Tensor):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so clean

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comes from aligning with the patterns and abstractions we've already created, not creating new ones which don't integrate and therefore don't provide as many features.

Signed-off-by: Kyle Sayers <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready When a PR is ready for review

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug]: KV cache quant - cannot be loaded in vllm [Bug]: k_scale and v_scale is zero after kv cache fp8 quantization

5 participants