Add Mistral & Mixtral support #174

Open · wants to merge 2 commits into main
3 changes: 3 additions & 0 deletions README.md
@@ -117,6 +117,9 @@ The detailed support list:
| [Vicuna-v1.1](/scripts/vicuna_example.sh) | 7B/13B | ✅ | |
| [LLaVA-v0](/scripts/llava_example.sh) | 13B | ✅ | |
| [VILA](/scripts/vila_example.sh) | 7B/13B | ✅ | |
| [Mistral](/scripts/mistral_example.sh) | 7B | ✅ | |
| [Mixtral](/scripts/mixtral_example.sh) | 8x7B | ✅ | |


Note: The table above only lists models for which we have prepared the [AWQ search results](https://huggingface.co/datasets/mit-han-lab/awq-model-zoo/tree/main). AWQ also supports other models such as LLaVA-v1.5 7B; you may need to run the [AWQ search](#usage) yourself to quantize them.

5 changes: 5 additions & 0 deletions awq/entry.py
@@ -130,6 +130,8 @@ def build_model_and_enc(model_path):
"BloomBlock",
"MPTBlock",
"DecoderLayer",
"MistralDecoderLayer",
"MixtralDecoderLayer",
],
**kwargs,
)
@@ -208,6 +210,7 @@ def build_model_and_enc(model_path):
model, max_memory if len(max_memory) > 0 else None
)
}

device_map = infer_auto_device_map(
model,
# TODO: can we remove this?
@@ -217,6 +220,8 @@
"BloomBlock",
"MPTBlock",
"DecoderLayer",
"MistralDecoderLayer",
"MixtralDecoderLayer",
],
**kwargs,
)
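For context, a minimal sketch of what the added class names buy us in no_split_module_classes (assuming accelerate is installed; model here is a hypothetical, already-loaded Mistral/Mixtral checkpoint): accelerate treats each listed decoder layer as an atomic unit and never splits it across devices.

from accelerate import infer_auto_device_map

# Hypothetical standalone usage, mirroring the calls in awq/entry.py above.
# Every Mistral/Mixtral decoder layer (attention + MLP/MoE experts) stays on one device.
device_map = infer_auto_device_map(
    model,
    no_split_module_classes=["MistralDecoderLayer", "MixtralDecoderLayer"],
)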
93 changes: 92 additions & 1 deletion awq/quantize/auto_scale.py
@@ -5,6 +5,8 @@
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
from transformers.models.opt.modeling_opt import OPTDecoderLayer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
from transformers.models.mistral.modeling_mistral import MistralRMSNorm
from transformers.models.mixtral.modeling_mixtral import MixtralRMSNorm
from transformers.activations import GELUActivation

from .qmodule import ScaledActivation
@@ -439,6 +441,95 @@ def _auto_get_scale(prev_op, layers, inp, module2inspect=None, kwargs={}):
inp=input_feat["mlp.dense_4h_to_h"],
)
)
elif "mistral" in str(module.__class__).lower():
# attention input
scales_list.append(
_auto_get_scale(
prev_op=module.input_layernorm,
layers=[
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
],
inp=input_feat["self_attn.q_proj"],
module2inspect=module.self_attn,
kwargs=module_kwargs,
)
)
# attn out
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
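# For GQA checkpoints such as Mistral-7B (32 attention heads, 8 KV heads), v_proj and
# o_proj have different weight shapes, so this v_proj -> o_proj scaling step is skipped.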
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
scales_list.append(
_auto_get_scale(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat["self_attn.o_proj"],
)
)
# fc1
scales_list.append(
_auto_get_scale(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.gate_proj, module.mlp.up_proj],
inp=input_feat["mlp.gate_proj"],
module2inspect=module.mlp,
)
)
# fc2
scales_list.append(
_auto_get_scale(
prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj],
inp=input_feat["mlp.down_proj"],
)
)
elif "mixtral" in str(module.__class__).lower():
# attention input
scales_list.append(
_auto_get_scale(
prev_op=module.input_layernorm,
layers=[
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
],
inp=input_feat["self_attn.q_proj"],
module2inspect=module.self_attn,
kwargs=module_kwargs,
)
)
# attn out
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
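# Same caveat as the Mistral branch above: Mixtral-8x7B also uses GQA, so the shape
# check below skips the v_proj -> o_proj scaling.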
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
scales_list.append(
_auto_get_scale(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat["self_attn.o_proj"],
)
)
# fc1
scales_list.append(
_auto_get_scale(
prev_op=module.post_attention_layernorm,
layers=[
w
for expert in module.block_sparse_moe.experts
for w in [expert.w1, expert.w3]
],
inp=input_feat["block_sparse_moe"],
module2inspect=module.block_sparse_moe,
)
)
# fc2
for i, expert in enumerate(module.block_sparse_moe.experts):
scales_list.append(
_auto_get_scale(
prev_op=expert.w3,
layers=[expert.w2],
inp=input_feat[f"block_sparse_moe.experts.{i}.w2"],
)
)
else:
raise NotImplementedError(f"{type(module)} not supported yet!")
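To summarize the Mixtral branch above, a hypothetical helper (the function name and return format are ours, not part of this diff) listing the scaling groups produced for one MixtralDecoderLayer under the standard transformers module layout:

def mixtral_scaling_groups(layer):
    """Each entry: (prev op that absorbs 1/s, linears whose inputs are scaled by s)."""
    groups = [
        # attention input: input_layernorm feeds q/k/v
        (layer.input_layernorm,
         [layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj]),
        # (v_proj -> o_proj is added only when their weight shapes match, i.e. non-GQA)
        # MoE "fc1": every expert's w1 and w3 share the post_attention_layernorm output
        (layer.post_attention_layernorm,
         [w for e in layer.block_sparse_moe.experts for w in (e.w1, e.w3)]),
    ]
    # MoE "fc2": each expert's w2 is scaled against its own w3
    groups += [(e.w3, [e.w2]) for e in layer.block_sparse_moe.experts]
    return groups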

@@ -458,7 +549,7 @@ def apply_scale(module, scales_list, input_feat_dict=None):
if isinstance(prev_op, nn.Linear):
assert len(layers) == 1
scale_fc_fc(prev_op, layers[0], scales)
elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)):
elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm, MistralRMSNorm, MixtralRMSNorm)):
scale_ln_fcs(prev_op, layers, scales)
elif isinstance(prev_op, (nn.GELU, BloomGelu, GELUActivation)):
new_module = ScaledActivation(prev_op, scales)
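The widened isinstance check above is the only change apply_scale needs, because the Mistral and Mixtral RMSNorm classes expose the same learnable weight vector as LlamaRMSNorm, which the existing scale_ln_fcs path already handles. A minimal sanity sketch (assuming a transformers version that ships both model families, e.g. >= 4.36):

import torch
from transformers.models.mistral.modeling_mistral import MistralRMSNorm
from transformers.models.mixtral.modeling_mixtral import MixtralRMSNorm

for cls in (MistralRMSNorm, MixtralRMSNorm):
    norm = cls(32)  # hidden_size
    # same interface scale_ln_fcs relies on: a 1-D weight of length hidden_size
    assert isinstance(norm.weight, torch.nn.Parameter)
    assert norm.weight.shape == (32,)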
38 changes: 35 additions & 3 deletions awq/quantize/pre_quant.py
@@ -18,8 +18,17 @@


def get_named_linears(module):
return {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)}

named_linears = {}
for name, m in module.named_modules():
if isinstance(m, nn.Linear):
# exclude the gate layer
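# ("gate" here matches block_sparse_moe.gate, the MoE router; excluding it keeps the
# router in full precision rather than scaling/quantizing it)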
if "mixtral" in str(module.__class__).lower():
if "gate" in name:
continue

named_linears[name] = m

return named_linears

def get_blocks(model):
if model.__class__.__name__ == "LlamaForCausalLM":
@@ -39,6 +48,10 @@ def get_blocks(model):
layers = model.transformer.h
elif "neox" in str(model.__class__).lower():
layers = model.gpt_neox.layers
elif "mistral" in str(model.__class__).lower():
layers = model.model.layers
elif "mixtral" in str(model.__class__).lower():
layers = model.model.layers
else:
raise NotImplementedError(type(model))
return layers
@@ -73,6 +86,10 @@ def move_embed(model, device):
model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(device)
model.gpt_neox.emb_dropout = model.gpt_neox.emb_dropout.to(device)
model.embed_out = model.embed_out.to(device)
elif "mistral" in str(model.__class__).lower():
model.model.embed_tokens = model.model.embed_tokens.to(device)
elif "mixtral" in str(model.__class__).lower():
model.model.embed_tokens = model.model.embed_tokens.to(device)
else:
raise NotImplementedError(type(model))

@@ -129,10 +146,18 @@ def forward(self, inp, **kwargs):
model(samples.to(next(model.parameters()).device))
except ValueError: # work with early exit
pass

# From AutoAWQ
# Update the layer kwargs with `prepare_inputs_for_generation` method
# that takes care of everything to avoid unexpected errors.
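# (e.g. it fills in kwargs such as attention_mask / position_ids that each decoder
# layer expects during calibration)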
layer_kwargs = model.prepare_inputs_for_generation(samples, **layer_kwargs)
# Pop the input_ids as they are not needed at all.
layer_kwargs.pop("input_ids")

del samples
layers[0] = layers[0].module # restore
inps = inps[0]

layers[0] = layers[0].cpu()
move_embed(model, "cpu")

@@ -158,6 +183,13 @@ def cache_input_hook(m, x, y, name, feat_dict):

input_feat = defaultdict(list)
handles = []
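# For Mixtral, also hook the whole block_sparse_moe module so the hidden states entering
# the MoE block are cached as input_feat["block_sparse_moe"]; auto_scale uses them when
# scaling every expert's w1/w3 against post_attention_layernorm.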

if "mixtral" in str(model.__class__).lower():
named_linears = {
**named_linears,
"block_sparse_moe": layer.block_sparse_moe,
}

for name in named_linears:
handles.append(
named_linears[name].register_forward_hook(
25 changes: 25 additions & 0 deletions scripts/mistral_example.sh
@@ -0,0 +1,25 @@
MODEL=Mistral-7B-Instruct-v0.2

# run AWQ search (optional; we provide pre-computed results)
python -m awq.entry --model_path /dataset/mistral/$MODEL \
--w_bit 4 --q_group_size 128 \
--run_awq --dump_awq awq_cache/$MODEL-w4-g128.pt

# evaluate the AWQ quantized model (simulated pseudo-quantization)
python -m awq.entry --model_path /dataset/mistral/$MODEL \
--tasks wikitext \
--w_bit 4 --q_group_size 128 \
--load_awq awq_cache/$MODEL-w4-g128.pt \
--q_backend fake

# generate real quantized weights (w4)
python -m awq.entry --model_path /dataset/mistral/$MODEL \
--w_bit 4 --q_group_size 128 \
--load_awq awq_cache/$MODEL-w4-g128.pt \
--q_backend real --dump_quant quant_cache/$MODEL-w4-g128-awq.pt

# load and evaluate the real quantized model (lower GPU memory usage)
python -m awq.entry --model_path /dataset/mistral/$MODEL \
--tasks wikitext \
--w_bit 4 --q_group_size 128 \
--load_quant quant_cache/$MODEL-w4-g128-awq.pt
25 changes: 25 additions & 0 deletions scripts/mixtral_example.sh
@@ -0,0 +1,25 @@
MODEL=Mixtral-8x7B-Instruct-v0.1

# run AWQ search (optional; we provide pre-computed results)
python -m awq.entry --model_path /dataset/mixtral/$MODEL \
--w_bit 4 --q_group_size 128 \
--run_awq --dump_awq awq_cache/$MODEL-w4-g128.pt

# evaluate the AWQ quantized model (simulated pseudo-quantization)
python -m awq.entry --model_path /dataset/mixtral/$MODEL \
--tasks wikitext \
--w_bit 4 --q_group_size 128 \
--load_awq awq_cache/$MODEL-w4-g128.pt \
--q_backend fake

# generate real quantized weights (w4)
python -m awq.entry --model_path /dataset/mixtral/$MODEL \
--w_bit 4 --q_group_size 128 \
--load_awq awq_cache/$MODEL-w4-g128.pt \
--q_backend real --dump_quant quant_cache/$MODEL-w4-g128-awq.pt

# load and evaluate the real quantized model (lower GPU memory usage)
python -m awq.entry --model_path /dataset/mixtral/$MODEL \
--tasks wikitext \
--w_bit 4 --q_group_size 128 \
--load_quant quant_cache/$MODEL-w4-g128-awq.pt