Skip to content

[AWQ] Generalize AWQ quantization#1961

Merged
HDCharles merged 32 commits intomainfrom
kylesayrs/awq-generalize-quant
Dec 11, 2025
Merged

[AWQ] Generalize AWQ quantization#1961
HDCharles merged 32 commits intomainfrom
kylesayrs/awq-generalize-quant

Conversation

@kylesayrs
Copy link
Collaborator

@kylesayrs kylesayrs commented Oct 22, 2025

Summary

To allow for arbitrary heterogeneous quantization schemes, this PR switches several helpers from AutoAWQ to the observer and QDQ logic. AWQ no longer constrains that the quantization config needs to have the same settings for group_size, symmetric, and num_bits for each config_group.

Resolves #1657

Prerequisites:

Test plan

  • When running llm-compressor/examples/awq/llama_example.py with this (with duo_scaling="both") and logging the best configuration of (ratio, duo_scaling), I see a good mix of Falses and Trues. i.e. a good percentage of best_scales were found with duo_scaling=False and a good percentage were found with duo_scaling=True. Generated model output looks good.
  • When using awq_one_shot.py (pasted below), Wikitext PPL is consistent for w4a16 and w4a16_asym on this branch when compared to main, and better than what was reported in a previous AWQ PR, but those might have been differently configured. For W4A16_ASYM, the results are both 13.41 for main and this branch. This is what we've been historically using to test regressions.
Scheme Wikitext PPL RTN AWQ main AWQ this branch
W4A16 13.784 13.477 13.426
W4A16_ASYM 13.606 13.346 13.377
  • I see a small regression in recovery when running CADENCE=weekly TEST_DATA_FILE=~/projects/llm-compressor/tests/lmeval/configs/w4a16_awq_sym.yaml pytest -s ~/projects/llm-compressor/tests/lmeval/test_lmeval.py on this branch, which causes the test to fail. This persists even when using pseudo_quantize_tensor instead of call_observer/forward_quantize, as shown in this diff. I get the same result in this diff, so at least that means quantization logic in CT is consistent with AutoAWQ
    Output:
<main>
2025-11-17T18:26:04.682699+0000 | _validate_recovery | INFO - ✓ exact_match,strict-match                 | Base: 0.7650 | Compressed: 0.7090 | Recovery: 92.68% ↑ | Threshold: ≥92.00%
2025-11-17T18:26:04.682811+0000 | _validate_recovery | INFO - ✓ exact_match,flexible-extract             | Base: 0.7630 | Compressed: 0.7100 | Recovery: 93.05% ↑ | Threshold: ≥93.00%
<this branch>
2025-11-17T17:55:00.648672+0000 | _validate_recovery | ERROR - ✗ exact_match,strict-match                 | Base: 0.7650 | Compressed: 0.6950 | Recovery: 90.85% ↑ | Threshold: ≥92.00%
2025-11-17T17:55:00.648967+0000 | _validate_recovery | ERROR - ✗ exact_match,flexible-extract             | Base: 0.7630 | Compressed: 0.6960 | Recovery: 91.22% ↑ | Threshold: ≥93.00%

This is already a pretty high drop in recovery, should we revisit this test?

  • Further regression testing against main was done in this commit see run.sh as of that commit which was removed in the final PR. Results look reasonable comparing branch and main, some up some down, within margin of error.

    Test Group Quantization (w4a16_awq_sym)

    Branch Metric Base Compressed Recovery
    On Branch exact_match,strict-match 0.7620 0.7170 94.09% ↑
    On Branch exact_match,flexible-extract 0.7600 0.7130 93.82% ↑
    On Main exact_match,strict-match 0.7620 0.7090 93.04%
    On Main exact_match,flexible-extract 0.7600 0.7060 92.89%

    Test Tensor Quantization (int8_tensor)

    Branch Metric Base Compressed Recovery
    On Branch exact_match,strict-match 0.7620 0.7220 94.75% ↓
    On Branch exact_match,flexible-extract 0.7600 0.7240 95.26% ↓
    On Main exact_match,strict-match 0.7620 0.7280 95.54%
    On Main exact_match,flexible-extract 0.7600 0.7310 96.18%

    Test Channel Quantization (fp8_dynamic)

    Branch Metric Base Compressed Recovery
    On Branch exact_match,strict-match 0.7650 0.7610 99.48%
    On Branch exact_match,flexible-extract 0.7630 0.7580 99.34%

    Test Block Quantization (fp8_block)

    Branch Metric Base Compressed Recovery
    On Branch exact_match,strict-match 0.7650 0.7720 100.92%
    On Branch exact_match,flexible-extract 0.7630 0.7690 100.79%
awq_oneshot.py script ```python import os

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

from llmcompressor import oneshot, active_session
from llmcompressor.utils import dispatch_for_generation
from llmcompressor.modifiers.awq import AWQModifier, AWQMapping
from llmcompressor.modifiers.quantization import QuantizationModifier
from compressed_tensors.quantization import (
QuantizationArgs,
QuantizationScheme,
QuantizationStrategy,
QuantizationType,
)

MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"

SAVE_DIR = MODEL_ID.split("/")[-1] + "-awq-asym"

Configure the quantization algorithm to run.

recipe = [
AWQModifier(
ignore=[
"lm_head",
"re:.*mlp.gate$",
"re:.mlp.shared_expert_gate$",
"re:visual.
",
],
scheme="W4A16_ASYM",
duo_scaling="both",
targets=["Linear"],
# offload_device=torch.device("cpu"),
),
]

Select calibration dataset.

DATASET_ID = "mit-han-lab/pile-val-backup"
DATASET_SPLIT = "validation"

Select number of samples. 256 samples is a good place to start.

Increasing the number of samples can improve accuracy.

NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 512

def get_calib_dataset(tokenizer):
from datasets import load_dataset

ds = load_dataset(
    DATASET_ID,
    split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES*10}]",
)

def preprocess(example):
    return {"input_ids": tokenizer.encode(example["text"].strip())}

ds = (
    ds.shuffle(seed=42)
    .map(preprocess, remove_columns=ds.column_names)
    .select(range(NUM_CALIBRATION_SAMPLES))
)

return ds

if name == "main":
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID, torch_dtype="auto", trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

###
### Apply algorithms.
###
oneshot(
    model=model,
    dataset=get_calib_dataset(tokenizer),
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    log_dir=None,
    trust_remote_code_model=True,
)

# Confirm generations of the quantized model look sane.
dispatch_for_generation(model)
print("\n\n")
print("========== SAMPLE GENERATION ==============")
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)

##
### Apply algorithms.
##

## LM EVAL

active_session().reset()
del model
del tokenizer
torch.cuda.empty_cache()

import lm_eval
from lm_eval.utils import make_table

results = lm_eval.simple_evaluate(
    model="vllm",
    model_args={
        "pretrained": SAVE_DIR,
        "add_bos_token": True,
        "dtype": "bfloat16",
        "gpu_memory_utilization": 0.7,
        "max_model_len": 4096,
        # "max_num_batched_tokens": 128,
        # "max_num_seqs": 128,
    },
    tasks=["wikitext"],
    batch_size=128,
)
print(make_table(results))
</details>

@github-actions
Copy link

👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

Note: This is required to complete the testing suite, please only add the label once the PR is code complete and local testing has been performed.

@brian-dellabetta brian-dellabetta changed the title [WIP] Generalize AWQ quantization [AWQ] Generalize AWQ quantization Nov 13, 2025
Copy link
Collaborator Author

@kylesayrs kylesayrs left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that so long as you feel confident that _compute_layer_means is going to work as expected for all the supported strategies, then I think this looks good to me!

Copy link
Collaborator Author

@kylesayrs kylesayrs left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approve from my side

HDCharles
HDCharles previously approved these changes Nov 17, 2025
fynnsu
fynnsu previously approved these changes Nov 17, 2025
Copy link
Collaborator

@fynnsu fynnsu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, added a couple comments below!

@HDCharles HDCharles dismissed stale reviews from fynnsu and themself via 4e480ce December 8, 2025 18:47
@HDCharles HDCharles force-pushed the kylesayrs/awq-generalize-quant branch from bdcdca4 to 4e480ce Compare December 8, 2025 18:47
@HDCharles HDCharles marked this pull request as ready for review December 10, 2025 15:01
@HDCharles HDCharles force-pushed the kylesayrs/awq-generalize-quant branch from 9d2d033 to 57ade6b Compare December 10, 2025 15:04
Copy link
Collaborator Author

@kylesayrs kylesayrs left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, approved from my side

@kylesayrs kylesayrs added the ready When a PR is ready for review label Dec 10, 2025
@HDCharles HDCharles force-pushed the kylesayrs/awq-generalize-quant branch from 57ade6b to 157cf48 Compare December 10, 2025 19:00
kylesayrs and others added 4 commits December 10, 2025 14:07
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
Summary

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Summary

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Summary

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Summary

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Summary

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Summary

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Summary

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Summary

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Summary

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Summary

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Summary

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Summary

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Summary

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Summary

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
@HDCharles HDCharles force-pushed the kylesayrs/awq-generalize-quant branch from 157cf48 to 7d79953 Compare December 10, 2025 19:08
Summary

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Copy link
Collaborator Author

@kylesayrs kylesayrs left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approved

@HDCharles HDCharles enabled auto-merge (squash) December 11, 2025 14:53
@HDCharles HDCharles merged commit 03e694a into main Dec 11, 2025
10 of 12 checks passed
@HDCharles HDCharles deleted the kylesayrs/awq-generalize-quant branch December 11, 2025 15:35
dsikka pushed a commit that referenced this pull request Jan 23, 2026
We identified a number of inefficiencies and fixes after the AWQ
generalization
[PR](#1961), this PR
largely implements them, see details below. Note I previously made this
[speed improvements
PR](#2188) which had
some issues that have been fixed in this one, that PR is going to be
closed.

# BENCHMARKS

to iterate more quickly i ran these tests on models with most of their
[layers
removed](https://github.com/vllm-project/llm-compressor/blob/1f036248f0310b8e95d488fc5d20831bcc9b62b7/examples/awq/llama_example.py#L69),
the actual improvement should be a bit better since the layers which
were removed are where the improvement happens. To replicate these
numbers see the [first
commit](1f03624#diff-208ced55cba2d38b1bc9b03b5f79ae3483c0849cc708a6c1131c231aae5d4b3dR69)

| Runtime (min) | Improvement | PR    | Base  |
|---------------|-------------|-------|-------|
| llama_no_off  |        8.7% |  6.17 |  6.76 |
| llama_off     |        5.6% | 10.46 | 11.08 |
| moe_no_off*    |        2.8% |  6.67 |  6.86 |
| moe_off*       |        1.9% |  7.57 |  7.72 |

| Memory (GB)     |        |       |       |
|-----------------|--------|-------|-------|
| llama_PR_no_off |   7.8% |  9.22 |    10 |
| llama_PR_off    |  17.6% |  3.66 |  4.44 |
| moe_PR_no_off   |  -5.2% | 11.61 | 11.04 |
| moe_PR_off**      | -24.3% |  2.92 |  2.35 |

\*The actual speedup for MoE is going to be higher than this. These
numbers are for a single layer being quantized so the calibration
overhead is going to attenuate the gains.

\*\* This worsening of memory is due to the weights being cached on
device, the layernorm -> up + gate proj mapping has to cache the up +
gate linears for the entire MLP layer which is fairly large.


# SUMMARY:
changes:

- targetted weight cache, no offloading
- previously in compute_best_scale we would record the entire state dict
of the parent and store it on cpu
- now only record the balance layer weights and store those on device
since they are generally small
- reduce load/write/rewrite weights
- during grid search we have to repeatedly updaste the weight to use
scaled and fake quantized versions of the weight. Previously this was
done by writing the original value, calculating the scaled value and
then writing that (2 writes)
- we instead calculate the scaled value directly from the on-device
cached value and write it once
- fake quantize only on device weight
  - previously the on/offloaded balance layer weight was updated
  - we now just update the on-device value
- compute loss
  - we slightly optimize  compute loss to reduce device movement
- note a number of approaches to improve the loss computation were
attempted including
- progressively calculating the loss while running the samples to get
int_w_outputs to avoid storing the whole int_w_output on device (was
slower and seems to saves the same memory as deleting int_w_outputs
after its used)
- using torch.cat to combine the fp16 outputs and int_w_outputs into a
single tensors so we can do only a single mse calculation (torch.cat
briefly doubles memory usage so it created a significant memory
increase)
- avoiding torch.cat by preallocating a flat tensor of the needed size
and progressively storing chunks into it (slow)
- torch compile on compute loss and/or run samples (would likely speed
up the run samples code but the offloading framework doesn't work well
with it)
- del int_w_outputs, simply deleting this intermediate value after its
used saves a significant amount of memory
- change default offload device behavior, normally None, now by default
we check whether we use the default MoE Mapping and if so, offload to
cpu by default

TEST PLAN:
(see first commit tests)

---------

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Signed-off-by: HDCharles <39544797+HDCharles@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com>
Etelis pushed a commit to Etelis/llm-compressor that referenced this pull request Jan 24, 2026
We identified a number of inefficiencies and fixes after the AWQ
generalization
[PR](vllm-project#1961), this PR
largely implements them, see details below. Note I previously made this
[speed improvements
PR](vllm-project#2188) which had
some issues that have been fixed in this one, that PR is going to be
closed.

# BENCHMARKS

to iterate more quickly i ran these tests on models with most of their
[layers
removed](https://github.com/vllm-project/llm-compressor/blob/1f036248f0310b8e95d488fc5d20831bcc9b62b7/examples/awq/llama_example.py#L69),
the actual improvement should be a bit better since the layers which
were removed are where the improvement happens. To replicate these
numbers see the [first
commit](vllm-project@1f03624#diff-208ced55cba2d38b1bc9b03b5f79ae3483c0849cc708a6c1131c231aae5d4b3dR69)

| Runtime (min) | Improvement | PR    | Base  |
|---------------|-------------|-------|-------|
| llama_no_off  |        8.7% |  6.17 |  6.76 |
| llama_off     |        5.6% | 10.46 | 11.08 |
| moe_no_off*    |        2.8% |  6.67 |  6.86 |
| moe_off*       |        1.9% |  7.57 |  7.72 |

| Memory (GB)     |        |       |       |
|-----------------|--------|-------|-------|
| llama_PR_no_off |   7.8% |  9.22 |    10 |
| llama_PR_off    |  17.6% |  3.66 |  4.44 |
| moe_PR_no_off   |  -5.2% | 11.61 | 11.04 |
| moe_PR_off**      | -24.3% |  2.92 |  2.35 |

\*The actual speedup for MoE is going to be higher than this. These
numbers are for a single layer being quantized so the calibration
overhead is going to attenuate the gains.

\*\* This worsening of memory is due to the weights being cached on
device, the layernorm -> up + gate proj mapping has to cache the up +
gate linears for the entire MLP layer which is fairly large.


# SUMMARY:
changes:

- targetted weight cache, no offloading
- previously in compute_best_scale we would record the entire state dict
of the parent and store it on cpu
- now only record the balance layer weights and store those on device
since they are generally small
- reduce load/write/rewrite weights
- during grid search we have to repeatedly updaste the weight to use
scaled and fake quantized versions of the weight. Previously this was
done by writing the original value, calculating the scaled value and
then writing that (2 writes)
- we instead calculate the scaled value directly from the on-device
cached value and write it once
- fake quantize only on device weight
  - previously the on/offloaded balance layer weight was updated
  - we now just update the on-device value
- compute loss
  - we slightly optimize  compute loss to reduce device movement
- note a number of approaches to improve the loss computation were
attempted including
- progressively calculating the loss while running the samples to get
int_w_outputs to avoid storing the whole int_w_output on device (was
slower and seems to saves the same memory as deleting int_w_outputs
after its used)
- using torch.cat to combine the fp16 outputs and int_w_outputs into a
single tensors so we can do only a single mse calculation (torch.cat
briefly doubles memory usage so it created a significant memory
increase)
- avoiding torch.cat by preallocating a flat tensor of the needed size
and progressively storing chunks into it (slow)
- torch compile on compute loss and/or run samples (would likely speed
up the run samples code but the offloading framework doesn't work well
with it)
- del int_w_outputs, simply deleting this intermediate value after its
used saves a significant amount of memory
- change default offload device behavior, normally None, now by default
we check whether we use the default MoE Mapping and if so, offload to
cpu by default

TEST PLAN:
(see first commit tests)

---------

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Signed-off-by: HDCharles <39544797+HDCharles@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com>
Etelis pushed a commit to Etelis/llm-compressor that referenced this pull request Jan 25, 2026
We identified a number of inefficiencies and fixes after the AWQ
generalization
[PR](vllm-project#1961), this PR
largely implements them, see details below. Note I previously made this
[speed improvements
PR](vllm-project#2188) which had
some issues that have been fixed in this one, that PR is going to be
closed.

# BENCHMARKS

to iterate more quickly i ran these tests on models with most of their
[layers
removed](https://github.com/vllm-project/llm-compressor/blob/1f036248f0310b8e95d488fc5d20831bcc9b62b7/examples/awq/llama_example.py#L69),
the actual improvement should be a bit better since the layers which
were removed are where the improvement happens. To replicate these
numbers see the [first
commit](vllm-project@1f03624#diff-208ced55cba2d38b1bc9b03b5f79ae3483c0849cc708a6c1131c231aae5d4b3dR69)

| Runtime (min) | Improvement | PR    | Base  |
|---------------|-------------|-------|-------|
| llama_no_off  |        8.7% |  6.17 |  6.76 |
| llama_off     |        5.6% | 10.46 | 11.08 |
| moe_no_off*    |        2.8% |  6.67 |  6.86 |
| moe_off*       |        1.9% |  7.57 |  7.72 |

| Memory (GB)     |        |       |       |
|-----------------|--------|-------|-------|
| llama_PR_no_off |   7.8% |  9.22 |    10 |
| llama_PR_off    |  17.6% |  3.66 |  4.44 |
| moe_PR_no_off   |  -5.2% | 11.61 | 11.04 |
| moe_PR_off**      | -24.3% |  2.92 |  2.35 |

\*The actual speedup for MoE is going to be higher than this. These
numbers are for a single layer being quantized so the calibration
overhead is going to attenuate the gains.

\*\* This worsening of memory is due to the weights being cached on
device, the layernorm -> up + gate proj mapping has to cache the up +
gate linears for the entire MLP layer which is fairly large.

# SUMMARY:
changes:

- targetted weight cache, no offloading
- previously in compute_best_scale we would record the entire state dict
of the parent and store it on cpu
- now only record the balance layer weights and store those on device
since they are generally small
- reduce load/write/rewrite weights
- during grid search we have to repeatedly updaste the weight to use
scaled and fake quantized versions of the weight. Previously this was
done by writing the original value, calculating the scaled value and
then writing that (2 writes)
- we instead calculate the scaled value directly from the on-device
cached value and write it once
- fake quantize only on device weight
  - previously the on/offloaded balance layer weight was updated
  - we now just update the on-device value
- compute loss
  - we slightly optimize  compute loss to reduce device movement
- note a number of approaches to improve the loss computation were
attempted including
- progressively calculating the loss while running the samples to get
int_w_outputs to avoid storing the whole int_w_output on device (was
slower and seems to saves the same memory as deleting int_w_outputs
after its used)
- using torch.cat to combine the fp16 outputs and int_w_outputs into a
single tensors so we can do only a single mse calculation (torch.cat
briefly doubles memory usage so it created a significant memory
increase)
- avoiding torch.cat by preallocating a flat tensor of the needed size
and progressively storing chunks into it (slow)
- torch compile on compute loss and/or run samples (would likely speed
up the run samples code but the offloading framework doesn't work well
with it)
- del int_w_outputs, simply deleting this intermediate value after its
used saves a significant amount of memory
- change default offload device behavior, normally None, now by default
we check whether we use the default MoE Mapping and if so, offload to
cpu by default

TEST PLAN:
(see first commit tests)

---------

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Signed-off-by: HDCharles <39544797+HDCharles@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com>
Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
cajeonrh pushed a commit to cajeonrh/llm-compressor that referenced this pull request Feb 10, 2026
We identified a number of inefficiencies and fixes after the AWQ
generalization
[PR](vllm-project#1961), this PR
largely implements them, see details below. Note I previously made this
[speed improvements
PR](vllm-project#2188) which had
some issues that have been fixed in this one, that PR is going to be
closed.

# BENCHMARKS

to iterate more quickly i ran these tests on models with most of their
[layers
removed](https://github.com/vllm-project/llm-compressor/blob/1f036248f0310b8e95d488fc5d20831bcc9b62b7/examples/awq/llama_example.py#L69),
the actual improvement should be a bit better since the layers which
were removed are where the improvement happens. To replicate these
numbers see the [first
commit](vllm-project@1f03624#diff-208ced55cba2d38b1bc9b03b5f79ae3483c0849cc708a6c1131c231aae5d4b3dR69)

| Runtime (min) | Improvement | PR    | Base  |
|---------------|-------------|-------|-------|
| llama_no_off  |        8.7% |  6.17 |  6.76 |
| llama_off     |        5.6% | 10.46 | 11.08 |
| moe_no_off*    |        2.8% |  6.67 |  6.86 |
| moe_off*       |        1.9% |  7.57 |  7.72 |

| Memory (GB)     |        |       |       |
|-----------------|--------|-------|-------|
| llama_PR_no_off |   7.8% |  9.22 |    10 |
| llama_PR_off    |  17.6% |  3.66 |  4.44 |
| moe_PR_no_off   |  -5.2% | 11.61 | 11.04 |
| moe_PR_off**      | -24.3% |  2.92 |  2.35 |

\*The actual speedup for MoE is going to be higher than this. These
numbers are for a single layer being quantized so the calibration
overhead is going to attenuate the gains.

\*\* This worsening of memory is due to the weights being cached on
device, the layernorm -> up + gate proj mapping has to cache the up +
gate linears for the entire MLP layer which is fairly large.


# SUMMARY:
changes:

- targetted weight cache, no offloading
- previously in compute_best_scale we would record the entire state dict
of the parent and store it on cpu
- now only record the balance layer weights and store those on device
since they are generally small
- reduce load/write/rewrite weights
- during grid search we have to repeatedly updaste the weight to use
scaled and fake quantized versions of the weight. Previously this was
done by writing the original value, calculating the scaled value and
then writing that (2 writes)
- we instead calculate the scaled value directly from the on-device
cached value and write it once
- fake quantize only on device weight
  - previously the on/offloaded balance layer weight was updated
  - we now just update the on-device value
- compute loss
  - we slightly optimize  compute loss to reduce device movement
- note a number of approaches to improve the loss computation were
attempted including
- progressively calculating the loss while running the samples to get
int_w_outputs to avoid storing the whole int_w_output on device (was
slower and seems to saves the same memory as deleting int_w_outputs
after its used)
- using torch.cat to combine the fp16 outputs and int_w_outputs into a
single tensors so we can do only a single mse calculation (torch.cat
briefly doubles memory usage so it created a significant memory
increase)
- avoiding torch.cat by preallocating a flat tensor of the needed size
and progressively storing chunks into it (slow)
- torch compile on compute loss and/or run samples (would likely speed
up the run samples code but the offloading framework doesn't work well
with it)
- del int_w_outputs, simply deleting this intermediate value after its
used saves a significant amount of memory
- change default offload device behavior, normally None, now by default
we check whether we use the default MoE Mapping and if so, offload to
cpu by default

TEST PLAN:
(see first commit tests)

---------

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Signed-off-by: HDCharles <39544797+HDCharles@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready When a PR is ready for review

Projects

None yet

Development

Successfully merging this pull request may close these issues.

W4fp8 AWQ

4 participants